-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathclassifier_movinet.js
96 lines (81 loc) · 2.17 KB
/
classifier_movinet.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
const path = require('path')
const fsSync = require('fs')
const _ = require('lodash')
// TensorFlow.js build in use, plus helpers only needed in pure-JS mode.
let tf
let getPort
let StaticServer
// True when the pure-JavaScript tfjs build is in use (explicitly requested or
// as a fallback); main() then serves the model over HTTP instead of loading
// it from disk.
let PUREJS = false

/**
 * Switches this script to pure-JavaScript mode: loads the plain tfjs build
 * together with the helpers main() uses to serve the model over HTTP
 * (a free port finder and a static file server).
 */
function usePureJs() {
  tf = require('@tensorflow/tfjs')
  getPort = require('get-port')
  StaticServer = require('static-server')
  PUREJS = true
}

if (process.env.RECOGNIZE_PUREJS === 'true') {
  usePureJs()
} else {
  try {
    // Native bindings: GPU build when requested, CPU build otherwise.
    tf = process.env.RECOGNIZE_GPU === 'true'
      ? require('@tensorflow/tfjs-node-gpu')
      : require('@tensorflow/tfjs-node')
  } catch (e) {
    // The native build could not be loaded. Fall back to pure JS, and load
    // the HTTP-serving helpers too: pure-JS mode in main() calls getPort()
    // and StaticServer, which would otherwise be undefined here.
    console.error(e)
    console.error('Trying js-only mode')
    usePureJs()
  }
}
const Movinet = require('./movinet/MovinetModel.js')
const { downloadAll } = require('./model-manager.js')
// Expect one or more file paths as arguments, or a single `-` to read
// newline-separated paths from stdin (see main()).
if (process.argv.length < 3) {
  throw new Error('Incorrect arguments: node classifier_movinet.js ...<VIDEO_FILES> | node classifier_movinet.js -')
}
// Number of top-scoring classes requested from the model per file.
const TOP_K = 6
// Minimum probability a class needs in order to be reported as a label.
const LABEL_THRESHOLD = 0.85

/**
 * Splits newline-separated stdin into file paths. A single trailing newline
 * is ignored so it does not produce an empty path (which would otherwise fail
 * inference and emit a spurious `[]` line); empty input yields no paths.
 * Blank lines in the middle are kept so output stays aligned with input lines.
 *
 * @param {string} input Raw stdin contents.
 * @returns {string[]} One entry per input line.
 */
function parseStdinPaths(input) {
  const body = input.endsWith('\n') ? input.slice(0, -1) : input
  return body === '' ? [] : body.split('\n')
}

/**
 * Classifies each input file with the MoViNet-A3 model and prints one JSON
 * array of unique labels per file to stdout, in input order.
 *
 * Paths come from the command-line arguments, or from stdin (one per line)
 * when the first argument is `-`. Every candidate returned by the model is
 * logged to stderr; only classes with probability >= LABEL_THRESHOLD are
 * printed. A failure on one file is logged to stderr and reported as `[]`,
 * so each input still gets exactly one output line.
 *
 * @returns {Promise<void>}
 */
async function main() {
  const modelPath = path.resolve(__dirname, '..', 'models', 'movinet-a3')
  const modelFileName = 'model.json'

  // Download models on first run, before anything serves or loads them.
  if (!fsSync.existsSync(modelPath)) {
    await downloadAll()
  }

  let modelUrl
  if (PUREJS) {
    // Pure-JS tfjs cannot load the model from a plain file path here, so the
    // model directory is served over a local HTTP server instead.
    // See https://github.com/tensorflow/tfjs/issues/4927
    const port = await getPort()
    const server = new StaticServer({ rootPath: modelPath, port })
    await new Promise(resolve => server.start(resolve))
    modelUrl = `http://localhost:${port}/${modelFileName}`
  } else {
    modelUrl = `${modelPath}/`
  }

  const model = await Movinet.create(modelUrl)

  let filePaths
  if (process.argv[2] === '-') {
    // get-stdin is loaded via dynamic import (the original code imports it this way).
    const getStdin = (await import('get-stdin')).default
    filePaths = parseStdinPaths(await getStdin())
  } else {
    filePaths = process.argv.slice(2)
  }

  for (const filePath of filePaths) {
    try {
      const results = await model.inference(filePath, { topK: TOP_K })
      // Log every candidate (accepted or not) to stderr for diagnostics.
      results.forEach(result => console.error(result))
      const labels = results
        .filter(result => result.probability >= LABEL_THRESHOLD)
        .map(result => result.className)
      console.log(JSON.stringify(_.uniq(labels)))
    } catch (e) {
      console.error(e)
      console.log('[]')
    }
  }
}
// Pick the backend matching the tfjs build actually loaded above: the native
// 'tensorflow' backend for tfjs-node(-gpu), 'cpu' for pure JS. Using PUREJS
// (not the env var) also covers the fallback after native bindings failed to
// load, where 'tensorflow' is not available.
// Exit explicitly afterwards so open handles (such as the static model server
// started in pure-JS mode, which is never stopped) do not keep the process alive.
tf.setBackend(PUREJS ? 'cpu' : 'tensorflow')
  .then(() => main())
  .then(() => process.exit(0))
  .catch(e => {
    console.error(e)
    process.exit(1)
  })