nsfw模型在python、nodejs、java等平台的AI鉴图应用
Published in:2024-06-05 |
Words: 3.4k | Reading time: 19min | reading:

简介

Detecting Not-Suitable-For-Work (NSFW) content is a high-demand task in computer vision. While there are many types of NSFW content, here we focus on pornographic images and videos.

The Yahoo Open-NSFW model, originally developed with the Caffe framework, has been a favourite choice, but the project has since been discontinued and Caffe is becoming less popular. Please see the description on the Yahoo project page for context, definitions, and model training details.

This Open-NSFW 2 project provides a Keras implementation of the Yahoo model, with references to its previous third-party TensorFlow 1 implementation. Note that although Keras 3 is compatible with TensorFlow, JAX, and PyTorch, this model is currently only guaranteed to work with TensorFlow and JAX.

A simple API is provided for making predictions on images and videos.

相关技术

  • TensorFlow
  • fastapi
  • loguru
  • pillow
  • requests
  • nodejs
  • express
  • tfjs-node
  • body-parser
  • multiparty
  • nsfwjs

python版本

usage(模型使用)

load image

1
2
3
4
5
6
7
8
def load_image(image_path):
    """Load an image file and turn it into one standardized float16 model input.

    Args:
        image_path: path of the image file to read.

    Returns:
        numpy array of shape (_IMAGE_SIZE, _IMAGE_SIZE, 3), produced by
        ``standardize`` and then cast to float16.
    """
    # `with` closes the underlying file handle once the pixels have been copied.
    with Image.open(image_path) as img:
        # Force 3 channels: PNGs may be RGBA or palette-based and some JPEGs are
        # greyscale, which would otherwise yield arrays with the wrong channel count.
        # convert()/resize() return new, fully loaded images, so using `img` after
        # the file is closed is safe.
        img = img.convert("RGB").resize((_IMAGE_SIZE, _IMAGE_SIZE))
    data = np.asarray(img, dtype="float32")
    data = standardize(data)
    return data.astype(np.float16, copy=False)

predict image

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def predict(image_path):
    """Classify one image with the exported TF1 SavedModel in ``_MODEL_DIR``.

    The model is (re)loaded on every call, which is simple but costly; cache a
    session if this runs on a hot path.

    Args:
        image_path: path of the image file to classify.

    Returns:
        dict: ``{"class": label_of_argmax, "probability": {label: prob, ...}}``,
        where labels come from ``_LABEL_MAP`` and probabilities are Python floats
        (so the result is JSON-serialisable).
    """
    # Load into a private graph: the default graph is shared process-wide, so
    # re-loading the SavedModel into it on every call would keep accumulating
    # duplicate ops and make the name-based tensor lookups below ambiguous.
    graph = tf.Graph()
    with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
        tf.compat.v1.saved_model.loader.load(
            sess, [tf.compat.v1.saved_model.tag_constants.SERVING], _MODEL_DIR)

        # inputs: input placeholder receiving the image batch.
        # probabilities_op: softmax output with one probability per class.
        # class_index_op: ArgMax output, i.e. the index of the most probable class.
        inputs = graph.get_tensor_by_name("input_tensor:0")
        probabilities_op = graph.get_tensor_by_name("softmax_tensor:0")
        class_index_op = graph.get_tensor_by_name("ArgMax:0")

        image_data = load_image(image_path)
        # The graph expects a full batch, so the single image is repeated
        # _BATCH_SIZE times; only the first row of each output is used.
        probabilities, class_index = sess.run(
            [probabilities_op, class_index_op],
            feed_dict={inputs: [image_data] * _BATCH_SIZE})

    # Map class indices to labels; float()/int() turn numpy scalars into
    # plain Python numbers.
    probabilities_dict = {
        _LABEL_MAP.get(i): float(p) for i, p in enumerate(probabilities[0])
    }
    pre_label = _LABEL_MAP.get(int(class_index[0]))
    return {"class": pre_label, "probability": probabilities_dict}

上述代码主要用于预测给定图像路径对应的图像的分类标签及其概率。

js版本

导入lib

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// Node.js built-in modules
const fs = require('fs');
const path = require('path');

// Third-party dependencies
const express = require('express');
const bodyparser = require('body-parser');
const afs = require('fs-extra');
const multiparty = require('multiparty');
const imgJS = require('image-js');
const nsfw = require('nsfwjs');
// tfjs-node registers the native TensorFlow backend. The pure-JS alternative is
// require('@tensorflow/tfjs'), which is much slower on a server.
const tf = require('@tensorflow/tfjs-node');

// Prediction classes that isSafeContent() counts towards the "safe" probability.
const safeContent = ['Drawing', 'Neutral'];

// Create the Express app and parse urlencoded and JSON request bodies.
const app = express();
app.use(bodyparser.urlencoded({ extended: false }));
app.use(bodyparser.json());

通过接口接收数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// Directory multiparty writes uploads into; relative paths resolve against process.cwd().
const UPLOAD_DIR = './tempImgs';
// Maximum accepted upload size (3 MiB).
const MAX_IMG_BYTES = 1024 * 1024 * 3;
// Accepted image extensions, matched case-insensitively (e.g. "IMG_001.JPG").
const IMG_EXT_REG = /\.(png|jpe?g)$/i;
// Directory URL of the local NSFWJS graph model; nsfwjs appends "model.json" itself.
const MODEL_BASE_URL = `file://${path.join('./', 'model')}/`;

// Make sure the upload directory exists before the first request writes into it.
fs.mkdirSync(UPLOAD_DIR, { recursive: true });

let nsfwModelPromise = null;

/**
 * Load the NSFWJS model once and share it between requests.
 * Concurrent first requests await the same promise; a failed load is not
 * cached, so a later request can retry.
 * @returns {Promise<object>} the loaded nsfwjs model (exposes `classify`)
 */
function getNsfwModel() {
  if (nsfwModelPromise === null) {
    nsfwModelPromise = nsfw
      .load(MODEL_BASE_URL, { size: 224, type: 'graph' })
      .catch((err) => {
        nsfwModelPromise = null;
        throw err;
      });
  }
  return nsfwModelPromise;
}

/**
 * Best-effort removal of every temp file multiparty wrote for this request.
 * Errors are ignored on purpose: cleanup must never change the HTTP response.
 * @param {object|undefined} files - the `files` map passed to form.parse's callback
 */
function removeUploadedFiles(files) {
  for (const uploaded of Object.values(files ?? {}).flat()) {
    fs.unlink(uploaded.path, () => {});
  }
}

/**
 * POST /checkImg: classify an uploaded image (multipart field "file").
 * Response codes: 0 classified, -1 missing file, -2 too large,
 * -3 unsupported type, -9 processing failure.
 */
app.post('/checkImg', (req, res) => {
  const form = new multiparty.Form({ uploadDir: UPLOAD_DIR });
  form.parse(req, async (err, fields, files) => {
    try {
      const file = files?.file?.[0];
      if (err || !file) {
        console.log(files);
        console.log(fields);
        return res.send({
          code: -1,
          msg: '请上传file图片资源(form-data格式)',
          data: {},
        });
      }
      console.log('files.file[0]:', file);
      if (file.size > MAX_IMG_BYTES) {
        return res.send({
          code: -2,
          msg: '被检测图片最大3M',
          data: {},
        });
      }
      const originImgName = file.originalFilename || file.path;
      if (!IMG_EXT_REG.test(originImgName)) {
        return res.send({
          code: -3,
          msg: '仅仅支持(png、jpeg、jpg)类型图片检测',
          data: {},
        });
      }

      const model = await getNsfwModel();
      const img = await convert(file);
      let predictions;
      try {
        predictions = await model.classify(img);
      } finally {
        // tfjs-node tensors hold native memory that must be released explicitly.
        img.dispose();
      }
      const { isSafe, imgType } = isSafeContent(predictions);
      console.log('是否安全:', predictions, isSafe);
      return res.send({
        code: 0,
        msg: isSafe ? '图片合规' : '图片可能存在不合规的风险,请核查',
        data: {
          isSafe,
          imgType,
          predictions,
        },
      });
    } catch (error) {
      // Errors thrown inside this async callback would otherwise leave the request hanging.
      console.error('图片核查失败:', error);
      if (!res.headersSent) {
        return res.send({
          code: -9,
          msg: '图片核查失败,请重试',
          data: {},
        });
      }
      return undefined;
    } finally {
      removeUploadedFiles(files);
    }
  });
});


// Start the HTTP server on port 3006 and log once it is listening.
app.listen(3006, function onListening() {
  console.log("图片鉴黄服务器启动成功!port:3006");
});

图片数据转换(类别标签映射与图片转 tensor)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// Display label (Chinese) for each NSFWJS class name; looked up by isSafeContent().
let imgTypeoObj = {
  'Drawing': '艺术性的',
  'Neutral': '中性的',
  'Sexy': '性感的',
  'Porn': '色情的',
  'Hentai': '变态的',
};

/**
 * Decode an uploaded image file into an int32 tensor of shape
 * [height, width, 3] holding 8-bit RGB values (alpha is dropped).
 * @param {{path: string}} file - multiparty file descriptor; only `path` is read
 * @returns {Promise<object>} a tf.Tensor3D owned by the caller (dispose() it after use)
 */
const convert = async (file) => {
  // Normalise to 8-bit RGBA first. image-js keeps the source's channel layout and
  // bit depth (e.g. an RGB PNG without alpha has 3 channels, a greyscale PNG 1-2),
  // so the fixed stride of 4 used below would otherwise read the wrong samples.
  const image = (await imgJS.Image.load(file.path)).rgba8();
  const numChannels = 3;
  const numPixels = image.width * image.height;
  const values = new Int32Array(numPixels * numChannels);

  for (let i = 0; i < numPixels; i++) {
    for (let c = 0; c < numChannels; c++) {
      // Source data is RGBA (stride 4); copy R, G, B and skip A.
      values[i * numChannels + c] = image.data[i * 4 + c];
    }
  }

  return tf.tensor3d(values, [image.height, image.width, numChannels], 'int32');
};

废弃数据清空

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/**
 * Delete everything inside `dirPath` (files, symlinks and sub-directories,
 * recursively) while keeping `dirPath` itself, so it can keep serving as an
 * upload directory. A missing directory is a no-op.
 * @param {string} dirPath - directory to empty
 */
function emptyDirSync(dirPath) {
  if (!fs.existsSync(dirPath)) {
    return;
  }
  for (const entry of fs.readdirSync(dirPath)) {
    // force: ignore entries that vanished in the meantime (e.g. removed by a concurrent
    // request). rmSync removes a symlink itself rather than the directory it points to.
    fs.rmSync(path.join(dirPath, entry), { recursive: true, force: true });
  }
}

内容检测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

/**
 * Decide whether an image is safe from its NSFWJS predictions.
 * The input array is not modified.
 *
 * @param {Array<{className: string, probability: number}>} predictions
 * @returns {{isSafe: boolean, imgType: string}} `isSafe` is true when the summed
 *   probability of the classes listed in `safeContent` exceeds 0.5; `imgType` is
 *   the `imgTypeoObj` label of the most probable class, or '' when there is none.
 */
const isSafeContent = (predictions) => {
  const list = Array.isArray(predictions) ? predictions : [];
  let safeProbability = 0;
  let top = null;
  for (const { className, probability } of list) {
    if (safeContent.includes(className)) {
      safeProbability += probability;
    }
    // Strict ">" keeps the first of equally likely classes, like a stable descending sort.
    if (top === null || probability > top.probability) {
      top = { className, probability };
    }
  }
  return {
    isSafe: safeProbability > 0.5,
    imgType: top === null ? '' : imgTypeoObj[top.className] ?? '',
  };
};

使用库

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
{
"main": "app.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1",
"start": "node app.js"
},
"keywords": [],
"author": "",
"license": "ISC",
"dependencies": {
"@tensorflow/tfjs-node": "latest",
"body-parser": "1.20.2",
"express": "4.18.2",
"fs-extra": "11.1.1",
"image-js": "0.35.4",
"multiparty": "4.2.3",
"nsfwjs": "2.4.2"
}
}

程序使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Install dependencies and start the Node.js service
npm install # yarn install
npm run start # yarn run start

# Show the current yarn registry
yarn config get registry

# Switch yarn to the npmmirror registry (mirror hosted in China)
yarn config set registry https://registry.npmmirror.com

# Restore the official yarn registry
yarn config set registry https://registry.yarnpkg.com

# Remove the custom registry setting
yarn config delete registry

# npm: use the npmmirror registry (the old taobao mirror moved to this domain)
npm config set registry https://registry.npmmirror.com

# Install cnpm from the mirror
npm install -g cnpm --registry=https://registry.npmmirror.com

# Mirror for the Node.js source headers node-gyp downloads when compiling native addons
npm set disturl https://npmmirror.com/dist

# Clear the npm cache
npm cache clean --force

# Mirrors for packages that download prebuilt binaries. These are .npmrc entries,
# not shell commands, so append them to ~/.npmrc instead of executing them.
cat >> ~/.npmrc <<'EOF'
sharp_binary_host = https://npmmirror.com/mirrors/sharp
sharp_libvips_binary_host = https://npmmirror.com/mirrors/sharp-libvips
profiler_binary_host_mirror = https://npmmirror.com/mirrors/node-inspector/
fse_binary_host_mirror = https://npmmirror.com/mirrors/fsevents
node_sqlite3_binary_host_mirror = https://npmmirror.com/mirrors
sqlite3_binary_host_mirror = https://npmmirror.com/mirrors
sqlite3_binary_site = https://npmmirror.com/mirrors/sqlite3
sass_binary_site = https://npmmirror.com/mirrors/node-sass
electron_mirror = https://npmmirror.com/mirrors/electron/
puppeteer_download_host = https://npmmirror.com/mirrors
chromedriver_cdnurl = https://npmmirror.com/mirrors/chromedriver
operadriver_cdnurl = https://npmmirror.com/mirrors/operadriver
phantomjs_cdnurl = https://npmmirror.com/mirrors/phantomjs
python_mirror = https://npmmirror.com/mirrors/python
registry = https://registry.npmmirror.com
disturl = https://npmmirror.com/dist
EOF

# Open only the service port (3006) instead of stopping firewalld altogether
firewall-cmd --permanent --add-port=3006/tcp
firewall-cmd --reload

java版本

基础软件安装

1
2
# NOTE(review): nvm is the Node.js version manager, unrelated to the Java build below; confirm this belongs in the Java section.
# Without a version argument, nvm presumably reads it from .nvmrc; TODO confirm which version is intended.
nvm install
# NOTE(review): `clean` does not appear to be a documented nvm subcommand (nvm-sh offers `nvm cache clear`); confirm intent.
nvm clean
  • pom.xml引入库
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    <?xml version="1.0" encoding="UTF-8"?>
    <project xmlns="http://maven.apache.org/POM/4.0.0"
             xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
             xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
        <modelVersion>4.0.0</modelVersion>

        <!-- The cn.hserver dependencies below declare no <version>; their versions come from this parent. -->
        <parent>
            <groupId>cn.hserver</groupId>
            <artifactId>hserver-parent</artifactId>
            <version>3.4.0</version>
        </parent>

        <groupId>org.example</groupId>
        <artifactId>hserver-nsfw</artifactId>
        <version>1.0-SNAPSHOT</version>

        <dependencies>
            <!-- HServer core -->
            <dependency>
                <groupId>cn.hserver</groupId>
                <artifactId>hserver</artifactId>
            </dependency>
            <!-- HServer web plugin (controllers, request handling) -->
            <dependency>
                <groupId>cn.hserver</groupId>
                <artifactId>hserver-plugin-web</artifactId>
            </dependency>

            <!-- TensorFlow Java API used to load and run the SavedModel -->
            <dependency>
                <groupId>org.tensorflow</groupId>
                <artifactId>tensorflow-core-api</artifactId>
                <version>0.4.0</version>
            </dependency>
            <!-- NOTE(review): the platform artifact likely already depends on tensorflow-core-api; confirm whether the entry above is redundant -->
            <dependency>
                <groupId>org.tensorflow</groupId>
                <artifactId>tensorflow-core-platform</artifactId>
                <version>0.4.0</version>
            </dependency>

            <!-- Utility libraries -->
            <dependency>
                <groupId>cn.hutool</groupId>
                <artifactId>hutool-all</artifactId>
                <version>5.8.16</version>
            </dependency>
            <dependency>
                <groupId>commons-io</groupId>
                <artifactId>commons-io</artifactId>
                <version>2.11.0</version>
            </dependency>
        </dependencies>

        <!-- Build a single runnable jar -->
        <build>
            <finalName>${project.artifactId}</finalName>
            <plugins>
                <!-- Assembly plugin: bundles all dependencies into one jar and sets the entry point -->
                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-assembly-plugin</artifactId>
                    <configuration>
                        <finalName>${project.artifactId}</finalName>
                        <appendAssemblyId>false</appendAssemblyId>
                        <descriptorRefs>
                            <descriptorRef>jar-with-dependencies</descriptorRef>
                        </descriptorRefs>
                        <archive>
                            <!-- Change this to your own main class -->
                            <manifest>
                                <mainClass>net.hserver.Main</mainClass>
                            </manifest>
                        </archive>
                    </configuration>
                    <executions>
                        <execution>
                            <id>make-assembly</id>
                            <phase>package</phase>
                            <goals>
                                <goal>single</goal>
                            </goals>
                        </execution>
                    </executions>
                </plugin>
            </plugins>
        </build>

    </project>

实体类定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75

package net.hserver.bean;

/**
 * Per-class scores produced by the NSFW model.
 * Scores are expected in model output order: drawings, hentai, neutral, porn, sexy.
 */
public class NsfwRes {
    private float drawings;
    private float hentai;
    private float neutral;
    private float porn;
    private float sexy;

    public NsfwRes() {
    }

    /**
     * Build a result from the raw model output.
     *
     * @param all scores in the order drawings, hentai, neutral, porn, sexy;
     *            any extra trailing values are ignored
     * @throws IllegalArgumentException if {@code all} is null or has fewer than 5 values
     */
    public NsfwRes(float[] all) {
        if (all == null || all.length < 5) {
            throw new IllegalArgumentException("Expected at least 5 scores, got "
                    + (all == null ? "null" : String.valueOf(all.length)));
        }
        this.drawings = all[0];
        this.hentai = all[1];
        this.neutral = all[2];
        this.porn = all[3];
        this.sexy = all[4];
    }

    /** Format a score in [0, 1] as a percentage with two decimals, independent of the JVM locale. */
    private static String percent(float score) {
        return String.format(java.util.Locale.ROOT, "%.2f", score * 100);
    }

    @Override
    public String toString() {
        return "Drawings-绘画:" + percent(drawings) + "\n"
                + "Hentai-18禁:" + percent(hentai) + "\n"
                + "Neutral-中性:" + percent(neutral) + "\n"
                + "Porn-色情:" + percent(porn) + "\n"
                + "Sexy-性感:" + percent(sexy);
    }

    public float getDrawings() {
        return drawings;
    }

    public void setDrawings(float drawings) {
        this.drawings = drawings;
    }

    public float getHentai() {
        return hentai;
    }

    public void setHentai(float hentai) {
        this.hentai = hentai;
    }

    public float getNeutral() {
        return neutral;
    }

    public void setNeutral(float neutral) {
        this.neutral = neutral;
    }

    public float getPorn() {
        return porn;
    }

    public void setPorn(float porn) {
        this.porn = porn;
    }

    public float getSexy() {
        return sexy;
    }

    public void setSexy(float sexy) {
        this.sexy = sexy;
    }
}

Controller定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
package net.hserver.controller;

import cn.hserver.core.ioc.annotation.Autowired;
import cn.hserver.core.server.util.JsonResult;
import cn.hserver.plugin.web.annotation.Controller;
import cn.hserver.plugin.web.annotation.POST;
import cn.hserver.plugin.web.context.PartFile;
import cn.hserver.plugin.web.interfaces.HttpRequest;
import net.hserver.bean.NsfwRes;
import net.hserver.service.Nsfw;

@Controller
public class TestController {

    /**
     * POST /check: score the image uploaded in multipart field "file".
     * Responds with code 400 when no file is sent and 500 when the image cannot be
     * decoded or scored. The temporary upload is always deleted.
     */
    @POST("/check")
    public JsonResult check(HttpRequest request) {
        PartFile file = request.queryFile("file");
        if (file == null) {
            return JsonResult.ok().put("code", 400).put("msg", "请上传file图片资源(form-data格式)");
        }
        try {
            NsfwRes score = Nsfw.getInstance().getScore(file.getData());
            if (score == null) {
                // Nsfw.getScore returns null when decoding or inference fails.
                return JsonResult.ok().put("code", 500).put("msg", "图片无法解析或检测失败");
            }
            return JsonResult.ok().put("score", score);
        } finally {
            // Runs even if model loading throws, so temp files never accumulate.
            file.deleteTempCacheFile();
        }
    }

}


输入图片处理与模型调用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
package net.hserver.service;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.types.TFloat32;
import java.awt.*;
import java.awt.image.BufferedImage;

public class ImagePreprocessingHelper {

    /** Default side length, in pixels, of the square model input. */
    private static final int DEFAULT_SIZE = 224;

    /**
     * Resize an image to 224x224 and convert it into a [1, 224, 224, 3] float tensor
     * with RGB values scaled to [0, 1].
     *
     * @param image decoded image; must not be null
     * @return a new tensor owned by the caller (close it after use)
     * @throws IllegalArgumentException if {@code image} is null
     */
    public static TFloat32 preprocessImage(BufferedImage image) {
        return preprocessImage(image, DEFAULT_SIZE, DEFAULT_SIZE);
    }

    /**
     * Resize an image to {@code height x width} and convert it into a
     * [1, height, width, 3] float tensor with RGB values scaled to [0, 1].
     *
     * @throws IllegalArgumentException if {@code image} is null or a size is not positive
     */
    public static TFloat32 preprocessImage(BufferedImage image, int height, int width) {
        // ImageIO.read returns null for data it cannot decode. Without this check the null
        // would reach drawImage(null, ...), which draws nothing (OpenJDK), silently
        // producing an all-black input.
        if (image == null) {
            throw new IllegalArgumentException("Input image is null");
        }
        if (height <= 0 || width <= 0) {
            throw new IllegalArgumentException("Target size must be positive: " + height + "x" + width);
        }
        BufferedImage resized = resizeImage(image, width, height);
        Shape shape = Shape.of(1, height, width, 3);
        return TFloat32.tensorOf(shape, DataBuffers.of(imageToFloatArray(resized, height, width)));
    }

    /** Scale {@code originalImage} to the target size with bilinear interpolation. */
    private static BufferedImage resizeImage(BufferedImage originalImage, int targetWidth, int targetHeight) {
        // Premultiplied-ARGB canvas; getRGB() later returns non-premultiplied sRGB values.
        BufferedImage resizedImg = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_ARGB_PRE);
        Graphics2D g2 = resizedImg.createGraphics();
        try {
            g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            g2.drawImage(originalImage, 0, 0, targetWidth, targetHeight, null);
        } finally {
            g2.dispose();
        }
        originalImage.flush();
        return resizedImg;
    }

    /** Flatten an H x W image into row-major [R, G, B, R, G, B, ...] floats in [0, 1]. */
    private static float[] imageToFloatArray(BufferedImage image, int H, int W) {
        if (image == null) {
            throw new IllegalArgumentException("Input image is null");
        }
        // Bulk read: argb[y * W + x] holds the packed ARGB pixel at (x, y).
        int[] argb = image.getRGB(0, 0, W, H, null, 0, W);
        float[] floatArray = new float[H * W * 3];
        int index = 0;
        for (int pixel : argb) {
            floatArray[index++] = ((pixel >> 16) & 0xFF) / 255.0f;
            floatArray[index++] = ((pixel >> 8) & 0xFF) / 255.0f;
            floatArray[index++] = (pixel & 0xFF) / 255.0f;
        }
        image.flush();
        return floatArray;
    }
}

nsfw模型输出检测结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package net.hserver.service;

import cn.hserver.HServerApplication;
import net.hserver.bean.NsfwRes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.*;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat32;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.util.HashMap;
import java.util.Map;


/**
 * Singleton wrapper around the NSFW SavedModel (https://github.com/gantman/nsfw_model).
 * All getScore overloads return null on failure and log the cause.
 */
public class Nsfw {
    private static final Logger log = LoggerFactory.getLogger(Nsfw.class);

    /** System property that overrides the SavedModel directory. */
    private static final String MODEL_PATH_PROPERTY = "nsfw.model.path";
    /** Environment variable consulted when the system property is not set. */
    private static final String MODEL_PATH_ENV = "NSFW_MODEL_PATH";
    /** Default SavedModel directory, relative to the working directory. */
    private static final String DEFAULT_MODEL_PATH = "./mobilenet_v2_140_224/";

    private static final Nsfw nsfw = new Nsfw();
    private volatile SavedModelBundle model;

    public NsfwRes getScore(String path) {
        return getScore(new File(path));
    }

    public NsfwRes getScore(File file) {
        try {
            return getScore(ImageIO.read(file));
        } catch (Exception e) {
            log.error("Failed to read image file {}", file, e);
            return null;
        }
    }

    public NsfwRes getScore(byte[] file) {
        try (ByteArrayInputStream bis = new ByteArrayInputStream(file)) {
            return getScore(ImageIO.read(bis));
        } catch (Exception e) {
            log.error("Failed to decode image bytes", e);
            return null;
        }
    }

    public NsfwRes getScore(BufferedImage file) {
        if (file == null) {
            // ImageIO.read yields null when no reader supports the data.
            log.warn("Image could not be decoded (unsupported or corrupt format)");
            return null;
        }
        try (TFloat32 input = ImagePreprocessingHelper.preprocessImage(file)) {
            Map<String, Tensor> inputs = new HashMap<>();
            inputs.put("input", input);
            Map<String, Tensor> outputs = model.call(inputs);
            try {
                TFloat32 prediction = (TFloat32) outputs.get("prediction");
                if (prediction == null) {
                    throw new IllegalStateException("Model output 'prediction' not found");
                }
                // Size the buffer from the tensor itself so read() always fits.
                float[] scores = new float[(int) prediction.shape().size()];
                prediction.read(DataBuffers.of(scores));
                return new NsfwRes(scores);
            } finally {
                // Close every output tensor, not only the one we read, to free native memory.
                for (Tensor tensor : outputs.values()) {
                    tensor.close();
                }
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            return null;
        }
    }

    /**
     * Return the shared instance, loading the model on first use.
     * synchronized so concurrent first requests cannot load the model twice.
     */
    public static synchronized Nsfw getInstance() {
        if (nsfw.model == null) {
            // Configure the model location via -Dnsfw.model.path or NSFW_MODEL_PATH;
            // e.g. unpack a zipped model from resources to a directory at startup.
            nsfw.model = SavedModelBundle.load(resolveModelPath());
        }
        return nsfw;
    }

    /** System property, then environment variable, then the default directory. */
    private static String resolveModelPath() {
        String fromProperty = System.getProperty(MODEL_PATH_PROPERTY);
        if (fromProperty != null && !fromProperty.isEmpty()) {
            return fromProperty;
        }
        String fromEnv = System.getenv(MODEL_PATH_ENV);
        if (fromEnv != null && !fromEnv.isEmpty()) {
            return fromEnv;
        }
        return DEFAULT_MODEL_PATH;
    }
}

测试结果

1
2
3
2023-04-21 15:54:54.312 DEBUG --- [server_business@1] c.h.p.web.handlers.DispatcherHandler     [ 390] [c0a80107-lgq9ak0z-9143-1] : 地址:/check 方法:POST 耗时:1294/ms 来源:127.0.0.1
2023-04-21 16:01:43.481 INFO --- [server_business@2] net.hserver.service.NsfwServiceImpl [ 39] [c0a80107-lgq9jc0r-9143-2] : time-consuming : 915 ms
2023-04-21 16:01:43.483 DEBUG --- [server_business@2] c.h.p.web.handlers.DispatcherHandler [ 390] [c0a80107-lgq9jc0r-9143-2] : 地址:/check 方法:POST 耗时:928/ms 来源:127.0.0.1
  • result json
    1
    {'msg': '操作成功', 'score': {'drawings': 0.86819315, 'hentai': 0.13177933, 'neutral': 1.712211e-06, 'porn': 1.654509e-05, 'sexy': 9.231081e-06}, 'code': 200}
Prev:
LCD(Low-Code Development)Overview 低代码开发概述
Next:
在Python中使用loguru库捕获异常出现bug的修正方法