Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions discojs/discojs-core/src/validation/validator.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const simplefaceMock = {
displayInformation: {},
trainingInformation: {
modelID: 'simple_face-model',
batchSize: 1,
batchSize: 4,
dataType: 'image',
IMAGE_H: 200,
IMAGE_W: 200,
Expand All @@ -27,9 +27,10 @@ describe('validator', () => {
const files: string[][] = ['child/', 'adult/']
.map((subdir: string) => fs.readdirSync(dir + subdir)
.map((file: string) => dir + subdir + file))
const labels = files.flatMap((files, index) => Array(files.length).fill(index))

const data: data.Data = (await new node.data.NodeImageLoader(simplefaceMock)
.loadAll(files.flat(), { labels: files.flatMap((files, index) => Array(files.length).fill(index)) })).train
.loadAll(files.flat(), { labels })).train
const validator = new Validator(
simplefaceMock,
new ConsoleLogger(),
Expand All @@ -42,14 +43,16 @@ describe('validator', () => {
console.log('data.size was undefined')
}
assert(
validator.visitedSamples() === data.size,
`expected ${size} visited samples but got ${validator.visitedSamples()}`
validator.visitedSamples === data.size,
`expected ${size} visited samples but got ${validator.visitedSamples}`
)
assert(
validator.accuracy() > 0.3,
`expected accuracy greater than 0.3 but got ${validator.accuracy()}`
validator.accuracy > 0.3,
`expected accuracy greater than 0.3 but got ${validator.accuracy}`
)
console.table(validator.confusionMatrix)
}).timeout(10_000)

// TODO: fix titanic model (nan accuracy)
// it('works for titanic', async () => {
// const data: Data = await new NodeTabularLoader(titanic.task, ',')
Expand Down
44 changes: 35 additions & 9 deletions discojs/discojs-core/src/validation/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { tf, data, Task, Logger, Client, GraphInformant, Memory, ModelSource, Fe
export class Validator {
private readonly graphInformant = new GraphInformant()
private size = 0
private _confusionMatrix: number[][] | undefined

constructor (
public readonly task: Task,
Expand All @@ -26,7 +27,7 @@ export class Validator {
}
}

async assess (data: data.Data): Promise<Array<{groundTruth: number, pred: number, features: Features}>> {
async assess (data: data.Data, useConfusionMatrix?: boolean): Promise<Array<{groundTruth: number, pred: number, features: Features}>> {
const batchSize = this.task.trainingInformation?.batchSize
if (batchSize === undefined) {
throw new TypeError('batch size is undefined')
Expand Down Expand Up @@ -61,16 +62,34 @@ export class Validator {

hits += List(pred).zip(List(ys)).filter(([p, y]) => p === y).size

// TODO: Confusion Matrix stats

const currentAccuracy = hits / this.size
this.graphInformant.updateAccuracy(currentAccuracy)
} else {
throw new TypeError('missing feature/label in dataset')
throw new Error('missing feature/label in dataset')
}
})
this.logger.success(`Obtained validation accuracy of ${this.accuracy()}`)
this.logger.success(`Visited ${this.visitedSamples()} samples`)
this.logger.success(`Obtained validation accuracy of ${this.accuracy}`)
this.logger.success(`Visited ${this.visitedSamples} samples`)

if (useConfusionMatrix) {
try {
this._confusionMatrix = tf.math.confusionMatrix(
[],
[],
0
).arraySync()
} catch (e: any) {
console.error(e instanceof Error ? e.message : e.toString())
throw new Error('Failed to compute the confusion matrix')
}
}

return List(groundTruth).zip(List(predictions)).zip(List(features)).map(([[gt, p], f]) => ({ groundTruth: gt, pred: p, features: f })).toArray()
return List(groundTruth)
.zip(List(predictions), List(features))
.map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f }))
.toArray()
}

async predict (data: data.Data): Promise<number[]> {
Expand All @@ -82,7 +101,10 @@ export class Validator {
const model = await this.getModel()
const predictions: number[] = []

await data.dataset.batch(batchSize).forEachAsync(e => predictions.push(...Array.from((model.predict(e as tf.Tensor, { batchSize: batchSize }) as tf.Tensor).argMax(1).dataSync())))
await data.dataset
.batch(batchSize)
.forEachAsync(e =>
predictions.push(...(model.predict(e as tf.Tensor, { batchSize: batchSize }) as tf.Tensor).argMax(1).arraySync() as number[]))

return predictions
}
Expand All @@ -99,15 +121,19 @@ export class Validator {
throw new Error('cannot identify model')
}

accuracyData (): List<number> {
get accuracyData (): List<number> {
return this.graphInformant.data()
}

accuracy (): number {
get accuracy (): number {
return this.graphInformant.accuracy()
}

visitedSamples (): number {
get visitedSamples (): number {
return this.size
}

get confusionMatrix (): number[][] | undefined {
return this._confusionMatrix
}
}
4 changes: 1 addition & 3 deletions discojs/discojs-web/src/memory/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ export class IndexedDB extends Memory {
throw new TypeError('source incomplete')
}

const version = (source.version === undefined || source.version === 0)
? ''
: source.version
const version = source.version ?? 0

return `indexeddb://${path.join(source.type, source.taskID, source.name)}@${version}`
}
Expand Down
Loading