blog-web/source/_posts/算法/TensorFlow.js初见(2).md
2021-07-07 18:59:42 +08:00

74 lines
2.1 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

---
title: TensorFlow.js初见(2)
date: 2021-04-27 17:13:57
tags:
- TensorFlow
- 机器学习
categories:
- 算法
---
模型训练好之后,就可以使用该模型进行预测了
代码依然是nodejs环境的运行方式
<!-- more -->
```typescript
import * as tf from '@tensorflow/tfjs-node'
import * as fs from 'fs'
(async function(){
// 加载之前训练好的模型
const model = await tf.loadLayersModel('http://localhost:8080/model.json')
// 打印模型的摘要信息
model.summary()
const imgBuffer = fs.readFileSync(`${process.cwd()}/resource/test/可回收物-帆布鞋.jpg`)
// 对图片数据的处理方式和训练的过程一样
const x = img2x(imgBuffer)
// 执行预测
const pred = <tf.Tensor>model.predict(x)
pred.print()
console.log(pred.arraySync()[0])
})()
/**
* 图片数据处理
* @param buffer 图片数据Buffer
* @returns
*/
const img2x = (buffer: Buffer) => {
// tf.tidy 执行后就会清除所有的中间张量并释放它们的GPU内存(相当于优化运行过程, 这一层包装也可以不要)
return tf.tidy(() => {
// 图片格式转换
const imgTs = tf.node.decodeImage(new Uint8Array(buffer))
// 图片尺寸转换
const imgTsResized = tf.image.resizeBilinear(imgTs, [224, 224])
// 将像素值归一化到[-1, 1]
/**
* 图片像素值是[0, 255]
* 先减去 255 / 2, 此时区间是[-127.5, 127.5]
* 再除以 255 / 2, 此时区间是[-1, 1]
* reshape进行模型转换
* 224,224代表尺寸 3代表RGB图片 1代表把图片放在数字1(拓展一维)
*/
return imgTsResized.toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3])
})
}
```
这里在本地使用`http-server`启动了一个HTTP服务方便加载模型
上面代码最后输出的执行结果是
```
[
0.03434975817799568,
0.001036567147821188,
0.9645556211471558,
0.00005802588930237107
]
```
与4种类型相对应
```
["其他垃圾","厨余垃圾","可回收物","有害垃圾"]
```
显然与可回收物的匹配度较高,其他几种的匹配度较低
> 实际预测的结果与模型的设定以及训练的素材数量相关