We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 240de6d commit 0fe4872Copy full SHA for 0fe4872
dsl/src/main/scala/com/thoughtworks/deeplearning/dsl/package.scala
@@ -58,10 +58,12 @@ package object dsl {
58
Batch.Aux[InputData, InputDelta],
59
Batch.Aux[OutputData, OutputDelta]],
60
outputDataIsOutputDelta: OutputData <:< OutputDelta
61
- ): Unit = {
+ ): OutputData = {
62
val outputBatch = toLiteral.forward(Literal[InputData](inputData)).open()
63
try {
64
- outputBatch.backward(outputDataIsOutputDelta(outputBatch.value))
+ val loss = outputBatch.value
65
+ outputBatch.backward(outputDataIsOutputDelta(loss))
66
+ loss
67
} finally {
68
outputBatch.close()
69
}
0 commit comments