1 change: 1 addition & 0 deletions CHANGES.md
@@ -49,6 +49,7 @@ All notable changes to this project will be documented in this file.

### Kaun

- Honor text dataset encodings via incremental Uutf decoding (#122, @Satarupa22-SD).
- Preserve empty sequential modules when unflattening so indices stay aligned for checkpoint round-tripping (@tmattio).
- Prevent `Training.fit`/`evaluate` from consuming entire datasets eagerly and fail fast when a dataset yields no batches, avoiding hangs and division-by-zero crashes (@tmattio).
- Allow metric history to tolerate metrics that appear or disappear between epochs so dynamic metric sets no longer raise during training (@tmattio).
169 changes: 105 additions & 64 deletions kaun/lib/kaun/dataset/dataset.ml
@@ -185,72 +185,112 @@ let from_tensors (x, y) =
}

(* Text Data Sources *)
let from_text_file ?(encoding = `UTF8) ?(chunk_size = 65536) path =
match encoding with
| `UTF8 | `ASCII | `LATIN1 ->
let enc_name =
match encoding with
| `UTF8 -> "utf-8"
| `ASCII -> "us-ascii"
| `LATIN1 -> "iso-8859-1"
in
let uenc_opt = Uutf.encoding_of_string enc_name in
let make_decoder () = Uutf.decoder ?encoding:uenc_opt `Manual in
let handle_ref = ref None in
let file_size = ref 0 in
let offset = ref 0 in
let closed = ref false in
let buf = Buffer.create 512 in
let lines_queue = Queue.create () in
let decoder = ref (make_decoder ()) in

let open_handle () =
let handle = create_mmap path in
file_size := handle.size;
handle_ref := Some handle;
handle
in
let ensure_handle () =
match !handle_ref with Some h -> h | None -> open_handle ()
in
let close_handle () =
match !handle_ref with
| None -> ()
| Some h ->
(* Closing twice raises EBADF; swallow it because reset can
reopen. *)
(try close_mmap h with
| Unix.Unix_error (Unix.EBADF, _, _) -> ()
| exn -> raise exn);
handle_ref := None
in
ignore (open_handle ());

let from_text_file ?encoding ?(chunk_size = 65536) path =
let _ = encoding in
(* TODO: Handle different encodings *)
let handle = create_mmap path in
let offset = ref 0 in
let buffer = ref "" in
let buffer_pos = ref 0 in
let closed = ref false in

let rec next_line () =
if !closed then None
else
(* Look for newline in buffer *)
try
let nl_pos = String.index_from !buffer !buffer_pos '\n' in
let line = String.sub !buffer !buffer_pos (nl_pos - !buffer_pos) in
buffer_pos := nl_pos + 1;
Some line
with Not_found ->
(* Need more data *)
if !offset >= handle.size then
(* End of file - return remaining buffer if any *)
if !buffer_pos < String.length !buffer then (
let line =
String.sub !buffer !buffer_pos
(String.length !buffer - !buffer_pos)
in
buffer := "";
buffer_pos := 0;
Some line)
else (
close_mmap handle;
closed := true;
None)
else
(* Read next chunk *)
let chunk =
read_mmap_chunk handle ~offset:!offset ~length:chunk_size
in
offset := !offset + String.length chunk;

(* Append to remaining buffer *)
if !buffer_pos < String.length !buffer then
buffer :=
String.sub !buffer !buffer_pos
(String.length !buffer - !buffer_pos)
^ chunk
else buffer := chunk;
buffer_pos := 0;
next_line ()
in
let push_line_from_buf () =
let line = Buffer.contents buf in
Buffer.clear buf;
Queue.add line lines_queue
in

let reset () =
offset := 0;
buffer := "";
buffer_pos := 0;
closed := false
in
let rec fill_queue () =
if Queue.is_empty lines_queue && not !closed then
match Uutf.decode !decoder with
| `Uchar u ->
if Uchar.to_int u = 0x000A then push_line_from_buf ()
else Uutf.Buffer.add_utf_8 buf u;
if Queue.is_empty lines_queue then fill_queue ()
| `Malformed _ ->
Uutf.Buffer.add_utf_8 buf Uutf.u_rep;
fill_queue ()
| `Await ->
if !offset >= !file_size then (
Uutf.Manual.src !decoder (Bytes.create 0) 0 0;
fill_queue ())
else
let handle = ensure_handle () in
let chunk =
read_mmap_chunk handle ~offset:!offset ~length:chunk_size
in
offset := !offset + String.length chunk;
if chunk = "" then (
Uutf.Manual.src !decoder (Bytes.create 0) 0 0;
fill_queue ())
else
let bytes = Bytes.of_string chunk in
Uutf.Manual.src !decoder bytes 0 (Bytes.length bytes);
fill_queue ()
| `End ->
if Buffer.length buf > 0 then push_line_from_buf ();
close_handle ();
closed := true
in

{
next = next_line;
cardinality = (fun () -> Unknown);
reset = Some reset;
spec = (fun () -> Scalar "string");
}
let rec next_line () =
if not (Queue.is_empty lines_queue) then Some (Queue.take lines_queue)
else if !closed then None
else (
fill_queue ();
if not (Queue.is_empty lines_queue) then Some (Queue.take lines_queue)
else if !closed then None
else next_line ())
in

let reset () =
Buffer.clear buf;
Queue.clear lines_queue;
offset := 0;
closed := false;
decoder := make_decoder ();
close_handle ();
ignore (open_handle ())
in

{
next = next_line;
cardinality = (fun () -> Unknown);
reset = Some reset;
spec = (fun () -> Scalar "string");
}
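
For readers unfamiliar with the Uutf manual-source protocol the loop above relies on: Uutf.decode returns `Await whenever the decoder needs more bytes, the caller answers with Uutf.Manual.src, and an empty slice signals end of input. A minimal, self-contained sketch of the same pattern (a hypothetical helper, not part of this PR; it needs only the uutf library):

(* Hypothetical standalone helper, not part of this PR: decode a UTF-8
   string in fixed-size chunks and split it on U+000A, using the same
   Uutf manual-source protocol as from_text_file above. *)
let decode_lines ?(chunk_size = 8) s =
  let dec = Uutf.decoder ~encoding:`UTF_8 `Manual in
  let buf = Buffer.create 64 in
  let lines = ref [] in
  let pos = ref 0 in
  let rec loop () =
    match Uutf.decode dec with
    | `Uchar u when Uchar.to_int u = 0x000A ->
        (* Newline: flush the accumulated line. *)
        lines := Buffer.contents buf :: !lines;
        Buffer.clear buf;
        loop ()
    | `Uchar u ->
        Uutf.Buffer.add_utf_8 buf u;
        loop ()
    | `Malformed _ ->
        (* Undecodable bytes become U+FFFD, as in the PR. *)
        Uutf.Buffer.add_utf_8 buf Uutf.u_rep;
        loop ()
    | `Await ->
        (if !pos >= String.length s then
           (* An empty slice tells the decoder no more input is coming. *)
           Uutf.Manual.src dec Bytes.empty 0 0
         else begin
           let len = min chunk_size (String.length s - !pos) in
           Uutf.Manual.src dec (Bytes.of_string (String.sub s !pos len)) 0 len;
           pos := !pos + len
         end);
        loop ()
    | `End ->
        if Buffer.length buf > 0 then lines := Buffer.contents buf :: !lines
  in
  loop ();
  List.rev !lines

On the UTF-8 test input below, decode_lines "hello \xF0\x9F\x98\x8A\nsecond\n" evaluates to [ "hello 😊"; "second" ], even when the 4-byte emoji straddles a chunk boundary.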

let from_text_files ?(encoding = `UTF8) ?(chunk_size = 65536) paths =
let current_file = ref 0 in
@@ -262,7 +302,8 @@ let from_text_files ?(encoding = `UTF8) ?(chunk_size = 65536) paths =
if !current_file >= List.length paths then None
else
let path = List.nth paths !current_file in
current_dataset := Some (from_text_file ~encoding ~chunk_size path);
let ds = from_text_file ~encoding ~chunk_size path in
current_dataset := Some ds;
incr current_file;
next ()
| Some ds -> (
10 changes: 8 additions & 2 deletions kaun/lib/kaun/dataset/dataset.mli
@@ -70,15 +70,21 @@ val from_file : (string -> 'a) -> string -> 'a t
(** {2 Text Data Sources} *)

val from_text_file :
?encoding:[ `UTF8 | `ASCII ] -> ?chunk_size:int -> string -> string t
?encoding:[ `UTF8 | `ASCII | `LATIN1 ] ->
?chunk_size:int ->
string ->
string t
(** [from_text_file ?encoding ?chunk_size path] creates a memory-mapped text
dataset yielding lines as strings.
- [encoding]: Text encoding (default: UTF8).
- [chunk_size]: Size of chunks to read at once (default: 64KB).
  The file is memory-mapped and read lazily in chunks. *)
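
For illustration, the encoding variants above map onto IANA charset names that Uutf resolves when the decoder is created; a minimal sketch mirroring the mapping in dataset.ml:

(* Sketch of the variant-to-charset mapping used by the implementation;
   Uutf.encoding_of_string resolves IANA names to decoder encodings. *)
let uutf_encoding = function
  | `UTF8 -> Uutf.encoding_of_string "utf-8"
  | `ASCII -> Uutf.encoding_of_string "us-ascii"
  | `LATIN1 -> Uutf.encoding_of_string "iso-8859-1"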

val from_text_files :
?encoding:[ `UTF8 | `ASCII ] -> ?chunk_size:int -> string list -> string t
?encoding:[ `UTF8 | `ASCII | `LATIN1 ] ->
?chunk_size:int ->
string list ->
string t
(** [from_text_files paths] creates a dataset from multiple text files. Files
are processed sequentially, without loading them all into memory. *)
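
For illustration, a sketch of intended use; the shard paths are hypothetical:

(* Illustration only: the paths are hypothetical. Files are opened one at
   a time, so only the current file's mmap handle is live. *)
let shards =
  from_text_files ~encoding:`UTF8 ~chunk_size:65536
    [ "data/part-000.txt"; "data/part-001.txt" ]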

2 changes: 1 addition & 1 deletion kaun/lib/kaun/dune
@@ -3,4 +3,4 @@
(library
(name kaun)
(public_name kaun)
(libraries rune unix str nx nx.core nx.io yojson domainslib))
(libraries rune unix str nx nx.core nx.io yojson domainslib uutf))
45 changes: 45 additions & 0 deletions kaun/test/test_dataset.ml
@@ -69,6 +69,23 @@ let test_from_text_file () =
[ "line1"; "line2"; "line3" ]
collected)

(* Test for utf8 *)
let test_from_text_file_utf8 () =
let content = "hello \xF0\x9F\x98\x8A\nsecond\n" in
with_temp_file content (fun path ->
let ds = from_text_file ~encoding:`UTF8 path in
let lines = collect_dataset ds in
Alcotest.(check (list string))
"utf8 emoji preserved" [ "hello 😊"; "second" ] lines)

(* Test for Latin1 *)
let test_from_text_file_latin1 () =
let content = "caf\xE9\nna\xEFve\n" in
with_temp_file content (fun path ->
let ds = from_text_file ~encoding:`LATIN1 path in
let lines = collect_dataset ds in
Alcotest.(check (list string)) "latin1 decoded" [ "café"; "naïve" ] lines)

let test_from_text_file_large_lines () =
let line = String.make 1000 'x' in
let content = line ^ "\n" ^ line ^ "\n" in
@@ -80,6 +97,29 @@ let test_from_text_file_large_lines () =
(fun l -> Alcotest.(check int) "line length" 1000 (String.length l))
collected)

let test_from_text_file_reset () =
let content = "line1\nline2\n" in
with_temp_file content (fun path ->
let dataset = from_text_file path in
let expected = [ "line1"; "line2" ] in
let first_pass = collect_dataset dataset in
Alcotest.(check (list string)) "first pass" expected first_pass;
reset dataset;
let second_pass = collect_dataset dataset in
Alcotest.(check (list string)) "after reset" expected second_pass)

let test_from_text_file_reset_mid_stream () =
let content = "alpha\nbeta\ngamma\n" in
with_temp_file content (fun path ->
let dataset = from_text_file path in
let first_chunk = collect_n 1 dataset in
Alcotest.(check (list string))
"consumed first element" [ "alpha" ] first_chunk;
reset dataset;
let refreshed = collect_n 2 dataset in
Alcotest.(check (list string))
"after reset first two elements" [ "alpha"; "beta" ] refreshed)

let test_from_text_files () =
let content1 = "file1_line1\nfile1_line2\n" in
let content2 = "file2_line1\nfile2_line2\n" in
@@ -618,8 +658,13 @@ let () =
( "text_files",
[
test_case "from_text_file" `Quick test_from_text_file;
test_case "from_text_file_utf8" `Quick test_from_text_file_utf8;
test_case "from_text_file_latin1" `Quick test_from_text_file_latin1;
test_case "from_text_file_large_lines" `Quick
test_from_text_file_large_lines;
test_case "from_text_file_reset" `Quick test_from_text_file_reset;
test_case "from_text_file_reset_mid_stream" `Quick
test_from_text_file_reset_mid_stream;
test_case "from_text_files" `Quick test_from_text_files;
test_case "from_jsonl" `Quick test_from_jsonl;
test_case "from_jsonl_custom_field" `Quick