blog-web/source/_posts/算法/TensorFlow.js初见(2).md
2021-07-07 18:59:42 +08:00

2.1 KiB
Raw Blame History

title date tags categories
TensorFlow.js初见(2) 2021-04-27 17:13:57
TensorFlow
机器学习
算法

模型训练好之后,就可以使用该模型进行预测了 代码依然是nodejs环境的运行方式

import * as tf from '@tensorflow/tfjs-node'
import * as fs from 'fs'

(async function(){
  // 加载之前训练好的模型
  const model = await tf.loadLayersModel('http://localhost:8080/model.json')
  // 打印模型的摘要信息
  model.summary()

  const imgBuffer = fs.readFileSync(`${process.cwd()}/resource/test/可回收物-帆布鞋.jpg`)
  // 对图片数据的处理方式和训练的过程一样
  const x = img2x(imgBuffer)
  // 执行预测
  const pred = <tf.Tensor>model.predict(x)
  pred.print()
  console.log(pred.arraySync()[0])
})()

/**
 * 图片数据处理
 * @param buffer 图片数据Buffer
 * @returns 
 */
 const img2x = (buffer: Buffer) => {
  // tf.tidy 执行后就会清除所有的中间张量并释放它们的GPU内存(相当于优化运行过程, 这一层包装也可以不要)
  return tf.tidy(() => {
    // 图片格式转换
    const imgTs = tf.node.decodeImage(new Uint8Array(buffer))
    // 图片尺寸转换
    const imgTsResized = tf.image.resizeBilinear(imgTs, [224, 224])
    // 将像素值归一化到[-1, 1]
    /**
     * 图片像素值是[0, 255]
     * 先减去 255 / 2, 此时区间是[-127.5, 127.5]
     * 再除以 255 / 2, 此时区间是[-1, 1]
     * reshape进行模型转换
     * 224,224代表尺寸 3代表RGB图片 1代表把图片放在数字1(拓展一维)
     */
    return imgTsResized.toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3])
  })
}

这里在本地使用http-server启动了一个HTTP服务方便加载模型 上面代码最后输出的执行结果是

[
  0.03434975817799568,
  0.001036567147821188,
  0.9645556211471558,
  0.00005802588930237107
]

与4种类型相对应

["其他垃圾","厨余垃圾","可回收物","有害垃圾"]

显然与可回收物的匹配度较高,其他几种的匹配度较低

实际预测的结果与模型的设定以及训练的素材数量相关