有些人可能天生对蛇之类的图片感到敏感,又或者我们希望避免小孩子上网时浏览到不健康的内容,这时就可能需要为网页中的图片建立一个前置过滤系统。

以前要做到这一点可能既麻烦又耗费资源;而有了 tf.js 之后,实现起来就方便多了。

MobileNet

MobileNet 是 Google 提出的面向移动端的模型,其核心是采用了深度可分离卷积(depthwise separable convolution),不仅可以降低模型的计算复杂度,还能大大减小模型体积。

参照 tfjs 的 demo,我们使用一个训练好的 MobileNet 模型。demo 的 index.js 中是这样指定模型地址的:

// URL of the hosted MobileNet model used by the tfjs demo (the "0.25" variant
// according to its path; input size 224 per the same path).
// tslint:disable-next-line:max-line-length
const MOBILENET_MODEL_PATH = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';

demo 里用的是 0.25 规模的模型,准确率不太理想。

https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.75_224/model.json

我没找到可以直接下载完整 0.75 模型的地方,通过猜测路径发现 0.75 版本确实存在,于是写了个脚本把其中的分片(shard)文件逐个下载下来(脚本假定已先把 model.json 保存到 mobilenet-0.75/ 目录):

var fs = require('fs');
var fetch = require('node-fetch');
var path = require('path');

// Downloads every weight shard listed in a locally saved model.json from the
// hosted MobileNet 0.75 directory into ./mobilenet-0.75/.
// Relies on `fs`, `path` and `fetch` (node-fetch) being required above.
const MODEL_DIR = 'mobilenet-0.75';
const REMOTE_BASE = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.75_224/';

const config = JSON.parse(fs.readFileSync(path.join(MODEL_DIR, 'model.json'), 'utf-8'));

(async () => {
  for (const { paths } of config.weightsManifest) {
    for (const shard of paths) {
      const url = new URL(shard, REMOTE_BASE).href;
      const resp = await fetch(url);
      // fetch() only rejects on network errors; without this check an HTTP
      // error page (404/403) would silently be written out as a weight shard.
      if (!resp.ok) {
        throw new Error(`Failed to download ${url}: HTTP ${resp.status}`);
      }
      // Downloads run sequentially to keep the load on the server modest.
      fs.writeFileSync(path.join(MODEL_DIR, shard), await resp.buffer());
    }
  }
})().catch((err) => {
  // Surface failures instead of leaving an unhandled rejection.
  console.error(err);
  process.exitCode = 1;
});
background.js

准备一个分类接口:接收一个图片地址,返回该图片的分类列表。

加载模型
// Local copy of the 0.75 MobileNet model. The tfjs demo's hosted 0.25 model is
// https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json
const MOBILENET_MODEL_PATH = 'mobilenet-0.75/model.json';
// Side length (px) of the square image batch fed to the model.
const IMAGE_SIZE = 224;
// Number of classes reported per image.
const TOPK_PREDICTIONS = 10;
// Holds the loaded model; stays undefined until loading has finished (or if it failed).
let mobilenet;

(async () => {
  try {
    mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
    // Warm up the model with an all-zero batch. Running it inside tidy()
    // disposes both the zeros input and the prediction, so nothing leaks.
    tf.tidy(() => {
      mobilenet.predict(tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3]));
    });
  } catch (err) {
    console.error('Failed to load MobileNet model from', MOBILENET_MODEL_PATH, err);
  }
})();
分类接口

接收一个 Image 对象,返回可能性最高的若干(Top-K)分类。代码改编自 tfjs 的 MobileNet demo:

/**
 * Classifies an image element with the loaded MobileNet model.
 *
 * The pixels are read from `imgElement`, resized to IMAGE_SIZE x IMAGE_SIZE,
 * scaled from [0, 255] to [-1, 1] and run through `mobilenet` as a batch of one.
 *
 * @param imgElement Image source accepted by tf.browser.fromPixels (the
 *     background listener passes an <img> element).
 * @returns Promise resolving to up to TOPK_PREDICTIONS {className, probability}
 *     entries, best first.
 */
async function predict(imgElement) {
  // NOTE(review): status() is not defined in this snippet; presumably a UI/log
  // helper carried over from the tfjs demo - confirm it exists in this page.
  status('Predicting...');

  // tidy() frees every intermediate tensor; only the returned logits survive.
  const logits = tf.tidy(() => {
    const pixels = tf.cast(tf.browser.fromPixels(imgElement), 'float32');
    // Resize so the reshape below works for any input size; pixel values are
    // unchanged when the image already is IMAGE_SIZE x IMAGE_SIZE.
    const resized = tf.image.resizeBilinear(pixels, [IMAGE_SIZE, IMAGE_SIZE]);
    const offset = tf.scalar(127.5);
    // Normalize the image from [0, 255] to [-1, 1].
    const normalized = resized.sub(offset).div(offset);
    // Reshape to a single-element batch so it can be passed to predict().
    const batched = normalized.reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]);
    return mobilenet.predict(batched);
  });

  try {
    return await getTopKClasses(logits, TOPK_PREDICTIONS);
  } finally {
    // The logits escape tidy(), so they must be released explicitly;
    // otherwise every classified image leaks one tensor.
    logits.dispose();
  }
}
/**
 * Returns the `topK` highest-scoring ImageNet classes from a model output.
 *
 * No softmax is applied here: the raw output values are reported unchanged as
 * `probability`. NOTE(review): this assumes the model's final layer already
 * produces probabilities - confirm for the loaded model.
 *
 * @param logits Tensor-like object whose data() resolves to one score per class.
 * @param topK Maximum number of classes to return; clamped to the number of scores.
 * @returns Array of {className, probability}, sorted by descending score.
 */
async function getTopKClasses(logits, topK) {
  const values = await logits.data();
  // Clamp so a topK larger than the class count no longer indexes past the end.
  const count = Math.max(0, Math.min(topK, values.length));

  return Array.from(values, (value, index) => ({ value, index }))
    .sort((a, b) => b.value - a.value)
    .slice(0, count)
    .map(({ value, index }) => ({
      className: _imagenet_classes.IMAGENET_CLASSES[index],
      probability: value,
    }));
}

将图片 URL 转成 base64(data URL)。插件的 manifest.json 需要声明 `"permissions": ["http://*/*", "https://*/*"]`:

/**
 * Loads an image URL and re-encodes it as a data URL by drawing it onto a canvas
 * (canvas.toDataURL() with its default format).
 *
 * @param {string} srcUrl URL of the image to load.
 * @param {function(string)} cb Called with the data URL once the image is encoded.
 * @param {function(*)=} onError Optional; called with the error if the image fails
 *     to load or the canvas cannot be exported (e.g. a tainted canvas). When
 *     omitted the failure is logged. Without this hook a failure would never be
 *     reported to anyone.
 */
function getDataUrl(srcUrl, cb, onError) {
  const fail = (err) => {
    if (typeof onError === 'function') {
      onError(err);
    } else {
      console.error('getDataUrl failed for', srcUrl, err);
    }
  };

  const tmpImage = new Image();
  // Handlers are attached before src is set so no load/error event can be missed.
  tmpImage.onload = () => {
    let dataUrl;
    try {
      const canvas = document.createElement('canvas');
      canvas.width = tmpImage.width;
      canvas.height = tmpImage.height;
      canvas.getContext('2d').drawImage(tmpImage, 0, 0);
      dataUrl = canvas.toDataURL();
    } catch (err) {
      fail(err);
      return;
    }
    // Invoked outside the try so errors thrown by cb are not reported as load failures.
    cb(dataUrl);
  };
  tmpImage.onerror = fail;
  tmpImage.src = srcUrl;
}

监听页面发来的请求:把图片 URL 转成 base64 并缩放到 224×224,拿到分类结果后判断是否包含要过滤的分类,再把结果发回发起请求的页面。

// Class-name substrings that cause an image to be blocked.
const keywords = ['snake', 'cobra'];
// NOTE(review): IMAGE_SIZE is also declared next to the model-loading code; if
// both end up in the same script scope this `const` is a redeclaration error.
const IMAGE_SIZE = 224;

// Classifies each image URL sent by a content script and pushes the verdict
// back to that tab as {action: 'ret', src, status, classes}, where status is
// true when a predicted class matches one of `keywords`.
// Any failure answers status=false, i.e. the image is shown (fail-open).
chrome.runtime.onMessage.addListener(function (request, sender) {
  // Results go back via tabs.sendMessage, so senders without a tab cannot be answered.
  if (!sender.tab) return;
  const tabId = sender.tab.id;
  const src = request.src;
  console.log('onMessage', src);

  const reply = (status, classes) => {
    chrome.tabs.sendMessage(tabId, {
      action: 'ret',
      src: src,
      status: status,
      classes: classes,
    });
  };

  getDataUrl(src, (dataUrl) => {
    const img = document.createElement('img');
    img.width = IMAGE_SIZE;
    img.height = IMAGE_SIZE;
    img.onerror = () => reply(false);
    img.onload = async () => {
      try {
        const classes = await predict(img);
        const hitBlack = classes.some(({ className }) =>
          keywords.some((keyword) => className.includes(keyword)));
        console.log('classes', classes);
        reply(hitBlack, classes);
      } catch (err) {
        console.error('classification failed for', src, err);
        reply(false);
      }
    };
    img.src = dataUrl;
  }, (err) => {
    // The image could not be loaded/encoded; answer anyway so the page is not
    // left waiting with the image hidden.
    console.error('could not load', src, err);
    reply(false);
  });
});

content_script

内容页 js 需要先把页面上的图片透明度设为 0 隐藏起来,同时监听页面以发现新加入的图片,然后把这些图片都发给 background.js 进行分类。

// <img> elements waiting for (or given) a classification verdict, keyed by image URL.
const eleCache = {};

/**
 * Hides every not-yet-processed <img> wider than 50px and asks the background
 * page to classify it. Processed images are marked with a `fetching` attribute
 * (request sent) or `classify` attribute (verdict received) and skipped later.
 */
function findImages() {
  for (const image of document.getElementsByTagName('img')) {
    if (image.getAttribute('classify') != null) continue;
    if (image.getAttribute('fetching') != null) continue;
    // Small images (icons, spacers, or not yet laid out) are left alone.
    if (image.width <= 50) continue;
    // No URL yet (e.g. lazy-loaded): hiding it now would keep it hidden forever,
    // because the request could never succeed and it would not be retried.
    if (!image.src) continue;

    // Change only opacity so the page's other inline styles are preserved.
    image.style.opacity = '0';
    const src = image.src;
    if (!eleCache[src]) eleCache[src] = [];
    eleCache[src].push(image);
    image.setAttribute('fetching', 1);

    try {
      // No response callback: the verdict arrives as a separate message, and a
      // callback would only produce "message port closed" lastError noise.
      chrome.runtime.sendMessage({ src: src });
    } catch (err) {
      // Best effort; e.g. throws once the extension context has been invalidated.
      console.warn('image-filter: could not request classification for', src, err);
    }
  }
}

// Poll for new images every 50 ms during the first 5 seconds after this
// content script starts, then stop polling.
const timeStart = Date.now();
const timer = setInterval(function () {
  const elapsed = Date.now() - timeStart;
  if (elapsed > 5000) {
    clearInterval(timer);
    return;
  }
  findImages();
}, 50);


// Scans once, then re-scans on any DOM change below <body> (added/removed nodes,
// attribute changes, text changes) so images inserted later are filtered too.
const startImageObserver = function () {
  findImages();
  // Fall back to the root element on documents without a <body>.
  const target = document.body || document.documentElement;
  const observer = new MutationObserver(function () {
    findImages();
  });
  observer.observe(target, { attributes: true, childList: true, characterData: true, subtree: true });
};

// Depending on the manifest's run_at setting, this script may be injected after
// the load event has already fired; a window.onload handler would then never
// run, so start immediately in that case. addEventListener also avoids
// clobbering any other onload handler.
if (document.readyState === 'complete') {
  startImageObserver();
} else {
  window.addEventListener('load', startImageObserver);
}

监听分类结果并显示不需要屏蔽的图片

// Applies a verdict pushed back by the background page
// ({action: 'ret', src, status, classes}) to every cached <img> with that URL.
chrome.runtime.onMessage.addListener(function (request) {
  if (request.action !== 'ret') return;
  const src = request.src;
  const blocked = request.status;
  const classes = request.classes;
  const els = eleCache[src] || [];

  if (blocked) {
    // Blocked images are simply left hidden.
    console.log('image', src, 'blocked');
  } else {
    // Change only opacity so the page's other inline styles are preserved.
    els.forEach(function (el) {
      el.style.opacity = '1';
    });
    console.log('image', src, 'notinblacklist');
  }

  // Expose the predicted classes on the element (as JSON and in alt).
  const alts = (classes || []).map(function (c) {
    return c.className;
  });
  els.forEach(function (el) {
    el.setAttribute('classify', 1);
    el.setAttribute('cats', JSON.stringify(classes));
    el.setAttribute('alt', alts.join('\n'));
  });

  // These elements are done; drop the references so the cache does not grow forever.
  delete eleCache[src];
});

至此,一个基于 tfjs 的简单前端图片过滤系统就实现了。

此外,我们还可以增加一个 UI 来设置要过滤的分类,或者引入其他模型来满足各种特殊的过滤需求。

下载image-filter.zip