Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ We're closing 8 user-reported issues or feature requests and are totalling 30 co
- Add NDCG, MAP, and MRR ranking metrics to Kaun metrics (@tmattio)
- Add BLEU, ROUGE, and METEOR metrics to Kaun for pre-tokenized sequences, removing tokenizer dependencies (@tmattio)
- Add SSIM, IoU, and Dice metrics for vision workloads in Kaun (@tmattio)
- Fix training to reinitialize the dataset each epoch, avoiding iterator exhaustion (#147, @Shocker444)

### Talon

Expand Down
58 changes: 30 additions & 28 deletions kaun/example/m0-mnist-basics/mnist.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,46 +24,46 @@ let metrics =
let train () =
(* Datasets *)
Printf.printf "Creating datasets...\n%!";
let start = Unix.gettimeofday () in
let train_data = Kaun_datasets.mnist ~train:true ~flatten:false () in
Printf.printf " MNIST train data loaded in %.2fs\n%!"
let make_train_ds () =
let rng = Rune.Rng.key 42 in
let shuffle_start = Unix.gettimeofday () in
let start = Unix.gettimeofday () in
let train_data = Kaun_datasets.mnist ~train:true ~flatten:false () in
Printf.printf " MNIST train data loaded in %.2fs\n%!"
(Unix.gettimeofday () -. start);
let train_data_shuffled =
Kaun.Dataset.shuffle ~rng ~buffer_size:60000 train_data
in
Printf.printf " Shuffle done in %.2fs\n%!"
(Unix.gettimeofday () -. shuffle_start);
let batch_start = Unix.gettimeofday () in

Printf.printf " Batching done in %.2fs\n%!"
(Unix.gettimeofday () -. batch_start);

let shuffle_start = Unix.gettimeofday () in
let rng = Rune.Rng.key 42 in
let train_data_shuffled =
Kaun.Dataset.shuffle ~rng ~buffer_size:60000 train_data
in
Printf.printf " Shuffle done in %.2fs\n%!"
(Unix.gettimeofday () -. shuffle_start);

let batch_start = Unix.gettimeofday () in
let train_ds =
Printf.printf "Train dataset created in %.2fs\n%!"
(Unix.gettimeofday () -. start);
Kaun.Dataset.batch_map 32
(fun batch ->
let images, labels = Array.split batch in
let batched_images = Rune.stack ~axis:0 (Array.to_list images) in
let batched_labels = Rune.stack ~axis:0 (Array.to_list labels) in
(batched_images, batched_labels))
(fun batch ->
let images, labels = Array.split batch in
let batched_images = Rune.stack ~axis:0 (Array.to_list images) in
let batched_labels = Rune.stack ~axis:0 (Array.to_list labels) in
(batched_images, batched_labels))
train_data_shuffled
in
Printf.printf " Batching done in %.2fs\n%!"
(Unix.gettimeofday () -. batch_start);

Printf.printf "Train dataset created in %.2fs\n%!"
let make_test_ds () =
let start = Unix.gettimeofday () in
let test_data = Kaun_datasets.mnist ~train:false ~flatten:false () in
Printf.printf "Test dataset created in %.2fs\n%!"
(Unix.gettimeofday () -. start);

let start = Unix.gettimeofday () in
let test_ds =
Kaun_datasets.mnist ~train:false ~flatten:false ()
|> Kaun.Dataset.batch_map 100 (fun batch ->
Kaun.Dataset.batch_map 100 (fun batch ->
let images, labels = Array.split batch in
let batched_images = Rune.stack ~axis:0 (Array.to_list images) in
let batched_labels = Rune.stack ~axis:0 (Array.to_list labels) in
(batched_images, batched_labels))
test_data
in
Printf.printf "Test dataset created in %.2fs\n%!"
(Unix.gettimeofday () -. start);

(* Initialize model with dummy input to get params *)
Printf.printf "Initializing model...\n%!";
Expand All @@ -83,6 +83,8 @@ let train () =

(* Training *)
Printf.printf "Starting training iteration...\n%!";
let train_ds = make_train_ds () in
let test_ds = make_test_ds () in
Metrics.Collection.reset metrics;
Kaun.Dataset.iter
(fun (x_batch, y_batch) ->
Expand Down
Loading