Skip to content
5 changes: 2 additions & 3 deletions src/loss/loss_priori.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ Format of function signature and outputs are compatible with the Lux ecosystem.
- `(; y_pred = y_pred)`: Named tuple containing the predicted values `y_pred`.
"""
function loss_priori_lux(model, ps, st, (x, y), device = identity)
    # Forward pass; `Lux.apply(model, x, ps, st)` returns `(y_pred, new_state)`.
    y_pred, st_ = Lux.apply(model, x, ps, st)
    # Relative (normalized) squared error: sum(|ŷ - y|²) / sum(|y|²).
    loss = sum(abs2, y_pred .- y) / sum(abs2, y)
    # NOTE(review): `device` is accepted for backward compatibility with existing
    # callers but is no longer applied to the scalar loss (an earlier revision
    # wrapped the loss in `device(...)`) — confirm callers do not rely on it.
    return loss, st_, (; y_pred = y_pred)
end
5 changes: 1 addition & 4 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,10 @@ function train(model, ps, st, train_dataloader, loss_function;
if tstate === nothing
tstate = Lux.Training.TrainState(model, ps, st, alg)
end
@warn "Training state type: $(typeof(tstate))"
loss::Float32 = 0 #NOP TODO: check compatibility with precision of data
@info "Lux Training started"
for epoch in 1:nepochs
data = ignore_derivatives() do
dev(train_dataloader())
end
data = train_dataloader()
_, loss, _, tstate = Lux.Training.single_train_step!(
ad_type, loss_function, data, tstate)
if callback !== nothing
Expand Down
Loading