Skip to content

Commit fa00a08

Browse files
fix(kaun): Honor text dataset encodings (#122)
Co-authored-by: Thibaut Mattio <thibaut.mattio@gmail.com>
1 parent 21395da commit fa00a08

File tree

5 files changed

+160
-67
lines changed

5 files changed

+160
-67
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -49,6 +49,7 @@ All notable changes to this project will be documented in this file.
4949

5050
### Kaun
5151

52+
- Honor text dataset encodings via incremental Uutf decoding (#122, @Satarupa22-SD).
5253
- Preserve empty sequential modules when unflattening so indices stay aligned for checkpoint round-tripping (@tmattio)
5354
- Prevent `Training.fit`/`evaluate` from consuming entire datasets eagerly and fail fast when a dataset yields no batches, avoiding hangs and division-by-zero crashes (@tmattio)
5455
- Allow metric history to tolerate metrics that appear or disappear between epochs so dynamic metric sets no longer raise during training (@tmattio)

kaun/lib/kaun/dataset/dataset.ml

Lines changed: 105 additions & 64 deletions
Original file line number · Diff line number · Diff line change
@@ -185,72 +185,112 @@ let from_tensors (x, y) =
185185
}
186186

187187
(* Text Data Sources *)
188+
let from_text_file ?(encoding = `UTF8) ?(chunk_size = 65536) path =
189+
match encoding with
190+
| `UTF8 | `ASCII | `LATIN1 ->
191+
let enc_name =
192+
match encoding with
193+
| `UTF8 -> "utf-8"
194+
| `ASCII -> "us-ascii"
195+
| `LATIN1 -> "iso-8859-1"
196+
in
197+
let uenc_opt = Uutf.encoding_of_string enc_name in
198+
let make_decoder () = Uutf.decoder ?encoding:uenc_opt `Manual in
199+
let handle_ref = ref None in
200+
let file_size = ref 0 in
201+
let offset = ref 0 in
202+
let closed = ref false in
203+
let buf = Buffer.create 512 in
204+
let lines_queue = Queue.create () in
205+
let decoder = ref (make_decoder ()) in
206+
207+
let open_handle () =
208+
let handle = create_mmap path in
209+
file_size := handle.size;
210+
handle_ref := Some handle;
211+
handle
212+
in
213+
let ensure_handle () =
214+
match !handle_ref with Some h -> h | None -> open_handle ()
215+
in
216+
let close_handle () =
217+
match !handle_ref with
218+
| None -> ()
219+
| Some h ->
220+
(* Closing twice raises EBADF; swallow it because reset can
221+
reopen. *)
222+
(try close_mmap h with
223+
| Unix.Unix_error (Unix.EBADF, _, _) -> ()
224+
| exn -> raise exn);
225+
handle_ref := None
226+
in
227+
ignore (open_handle ());
188228

189-
let from_text_file ?encoding ?(chunk_size = 65536) path =
190-
let _ = encoding in
191-
(* TODO: Handle different encodings *)
192-
let handle = create_mmap path in
193-
let offset = ref 0 in
194-
let buffer = ref "" in
195-
let buffer_pos = ref 0 in
196-
let closed = ref false in
197-
198-
let rec next_line () =
199-
if !closed then None
200-
else
201-
(* Look for newline in buffer *)
202-
try
203-
let nl_pos = String.index_from !buffer !buffer_pos '\n' in
204-
let line = String.sub !buffer !buffer_pos (nl_pos - !buffer_pos) in
205-
buffer_pos := nl_pos + 1;
206-
Some line
207-
with Not_found ->
208-
(* Need more data *)
209-
if !offset >= handle.size then
210-
(* End of file - return remaining buffer if any *)
211-
if !buffer_pos < String.length !buffer then (
212-
let line =
213-
String.sub !buffer !buffer_pos
214-
(String.length !buffer - !buffer_pos)
215-
in
216-
buffer := "";
217-
buffer_pos := 0;
218-
Some line)
219-
else (
220-
close_mmap handle;
221-
closed := true;
222-
None)
223-
else
224-
(* Read next chunk *)
225-
let chunk =
226-
read_mmap_chunk handle ~offset:!offset ~length:chunk_size
227-
in
228-
offset := !offset + String.length chunk;
229-
230-
(* Append to remaining buffer *)
231-
if !buffer_pos < String.length !buffer then
232-
buffer :=
233-
String.sub !buffer !buffer_pos
234-
(String.length !buffer - !buffer_pos)
235-
^ chunk
236-
else buffer := chunk;
237-
buffer_pos := 0;
238-
next_line ()
239-
in
229+
let push_line_from_buf () =
230+
let line = Buffer.contents buf in
231+
Buffer.clear buf;
232+
Queue.add line lines_queue
233+
in
240234

241-
let reset () =
242-
offset := 0;
243-
buffer := "";
244-
buffer_pos := 0;
245-
closed := false
246-
in
235+
let rec fill_queue () =
236+
if Queue.is_empty lines_queue && not !closed then
237+
match Uutf.decode !decoder with
238+
| `Uchar u ->
239+
if Uchar.to_int u = 0x000A then push_line_from_buf ()
240+
else Uutf.Buffer.add_utf_8 buf u;
241+
if Queue.is_empty lines_queue then fill_queue ()
242+
| `Malformed _ ->
243+
Uutf.Buffer.add_utf_8 buf Uutf.u_rep;
244+
fill_queue ()
245+
| `Await ->
246+
if !offset >= !file_size then (
247+
Uutf.Manual.src !decoder (Bytes.create 0) 0 0;
248+
fill_queue ())
249+
else
250+
let handle = ensure_handle () in
251+
let chunk =
252+
read_mmap_chunk handle ~offset:!offset ~length:chunk_size
253+
in
254+
offset := !offset + String.length chunk;
255+
if chunk = "" then (
256+
Uutf.Manual.src !decoder (Bytes.create 0) 0 0;
257+
fill_queue ())
258+
else
259+
let bytes = Bytes.of_string chunk in
260+
Uutf.Manual.src !decoder bytes 0 (Bytes.length bytes);
261+
fill_queue ()
262+
| `End ->
263+
if Buffer.length buf > 0 then push_line_from_buf ();
264+
close_handle ();
265+
closed := true
266+
in
247267

248-
{
249-
next = next_line;
250-
cardinality = (fun () -> Unknown);
251-
reset = Some reset;
252-
spec = (fun () -> Scalar "string");
253-
}
268+
let rec next_line () =
269+
if not (Queue.is_empty lines_queue) then Some (Queue.take lines_queue)
270+
else if !closed then None
271+
else (
272+
fill_queue ();
273+
if not (Queue.is_empty lines_queue) then Some (Queue.take lines_queue)
274+
else if !closed then None
275+
else next_line ())
276+
in
277+
278+
let reset () =
279+
Buffer.clear buf;
280+
Queue.clear lines_queue;
281+
offset := 0;
282+
closed := false;
283+
decoder := make_decoder ();
284+
close_handle ();
285+
ignore (open_handle ())
286+
in
287+
288+
{
289+
next = next_line;
290+
cardinality = (fun () -> Unknown);
291+
reset = Some reset;
292+
spec = (fun () -> Scalar "string");
293+
}
254294

255295
let from_text_files ?(encoding = `UTF8) ?(chunk_size = 65536) paths =
256296
let current_file = ref 0 in
@@ -262,7 +302,8 @@ let from_text_files ?(encoding = `UTF8) ?(chunk_size = 65536) paths =
262302
if !current_file >= List.length paths then None
263303
else
264304
let path = List.nth paths !current_file in
265-
current_dataset := Some (from_text_file ~encoding ~chunk_size path);
305+
let ds = from_text_file ~encoding ~chunk_size path in
306+
current_dataset := Some ds;
266307
incr current_file;
267308
next ()
268309
| Some ds -> (

kaun/lib/kaun/dataset/dataset.mli

Lines changed: 8 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -70,15 +70,21 @@ val from_file : (string -> 'a) -> string -> 'a t
7070
(** {2 Text Data Sources} *)
7171

7272
val from_text_file :
73-
?encoding:[ `UTF8 | `ASCII ] -> ?chunk_size:int -> string -> string t
73+
?encoding:[ `UTF8 | `ASCII | `LATIN1 ] ->
74+
?chunk_size:int ->
75+
string ->
76+
string t
7477
(** [from_text_file ?encoding ?chunk_size path] creates a memory-mapped text
7578
dataset yielding lines as strings.
7679
- [encoding]: Text encoding (default: UTF8)
7780
- [chunk_size]: Size of chunks to read at once (default: 64KB) The file is
7881
memory-mapped and read lazily in chunks. *)
7982

8083
val from_text_files :
81-
?encoding:[ `UTF8 | `ASCII ] -> ?chunk_size:int -> string list -> string t
84+
?encoding:[ `UTF8 | `ASCII | `LATIN1 ] ->
85+
?chunk_size:int ->
86+
string list ->
87+
string t
8288
(** [from_text_files paths] creates a dataset from multiple text files. Files
8389
are processed sequentially without loading all into memory. *)
8490

kaun/lib/kaun/dune

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -3,4 +3,4 @@
33
(library
44
(name kaun)
55
(public_name kaun)
6-
(libraries rune unix str nx nx.core nx.io yojson domainslib))
6+
(libraries rune unix str nx nx.core nx.io yojson domainslib uutf))

kaun/test/test_dataset.ml

Lines changed: 45 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -69,6 +69,23 @@ let test_from_text_file () =
6969
[ "line1"; "line2"; "line3" ]
7070
collected)
7171

72+
(* Test for utf8 *)
73+
let test_from_text_file_utf8 () =
74+
let content = "hello \xF0\x9F\x98\x8A\nsecond\n" in
75+
with_temp_file content (fun path ->
76+
let ds = from_text_file ~encoding:`UTF8 path in
77+
let lines = collect_dataset ds in
78+
Alcotest.(check (list string))
79+
"utf8 emoji preserved" [ "hello 😊"; "second" ] lines)
80+
81+
(* Test for Latin1 *)
82+
let test_from_text_file_latin1 () =
83+
let content = "caf\xE9\nna\xEFve\n" in
84+
with_temp_file content (fun path ->
85+
let ds = from_text_file ~encoding:`LATIN1 path in
86+
let lines = collect_dataset ds in
87+
Alcotest.(check (list string)) "latin1 decoded" [ "café"; "naïve" ] lines)
88+
7289
let test_from_text_file_large_lines () =
7390
let line = String.make 1000 'x' in
7491
let content = line ^ "\n" ^ line ^ "\n" in
@@ -80,6 +97,29 @@ let test_from_text_file_large_lines () =
8097
(fun l -> Alcotest.(check int) "line length" 1000 (String.length l))
8198
collected)
8299

100+
let test_from_text_file_reset () =
101+
let content = "line1\nline2\n" in
102+
with_temp_file content (fun path ->
103+
let dataset = from_text_file path in
104+
let expected = [ "line1"; "line2" ] in
105+
let first_pass = collect_dataset dataset in
106+
Alcotest.(check (list string)) "first pass" expected first_pass;
107+
reset dataset;
108+
let second_pass = collect_dataset dataset in
109+
Alcotest.(check (list string)) "after reset" expected second_pass)
110+
111+
let test_from_text_file_reset_mid_stream () =
112+
let content = "alpha\nbeta\ngamma\n" in
113+
with_temp_file content (fun path ->
114+
let dataset = from_text_file path in
115+
let first_chunk = collect_n 1 dataset in
116+
Alcotest.(check (list string))
117+
"consumed first element" [ "alpha" ] first_chunk;
118+
reset dataset;
119+
let refreshed = collect_n 2 dataset in
120+
Alcotest.(check (list string))
121+
"after reset first two elements" [ "alpha"; "beta" ] refreshed)
122+
83123
let test_from_text_files () =
84124
let content1 = "file1_line1\nfile1_line2\n" in
85125
let content2 = "file2_line1\nfile2_line2\n" in
@@ -618,8 +658,13 @@ let () =
618658
( "text_files",
619659
[
620660
test_case "from_text_file" `Quick test_from_text_file;
661+
test_case "from_text_file_utf8" `Quick test_from_text_file_utf8;
662+
test_case "from_text_file_latin1" `Quick test_from_text_file_latin1;
621663
test_case "from_text_file_large_lines" `Quick
622664
test_from_text_file_large_lines;
665+
test_case "from_text_file_reset" `Quick test_from_text_file_reset;
666+
test_case "from_text_file_reset_mid_stream" `Quick
667+
test_from_text_file_reset_mid_stream;
623668
test_case "from_text_files" `Quick test_from_text_files;
624669
test_case "from_jsonl" `Quick test_from_jsonl;
625670
test_case "from_jsonl_custom_field" `Quick

0 commit comments

Comments
 (0)