Skip to content

Commit 0fe4872

Browse files
committed
Return loss value when training
1 parent 240de6d commit 0fe4872

File tree

1 file changed

+4
-2
lines changed
  • dsl/src/main/scala/com/thoughtworks/deeplearning/dsl

1 file changed

+4
-2
lines changed

dsl/src/main/scala/com/thoughtworks/deeplearning/dsl/package.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ package object dsl {
5858
Batch.Aux[InputData, InputDelta],
5959
Batch.Aux[OutputData, OutputDelta]],
6060
outputDataIsOutputDelta: OutputData <:< OutputDelta
61-
): Unit = {
61+
): OutputData = {
6262
val outputBatch = toLiteral.forward(Literal[InputData](inputData)).open()
6363
try {
64-
outputBatch.backward(outputDataIsOutputDelta(outputBatch.value))
64+
val loss = outputBatch.value
65+
outputBatch.backward(outputDataIsOutputDelta(loss))
66+
loss
6567
} finally {
6668
outputBatch.close()
6769
}

0 commit comments

Comments
 (0)