Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions nx-datasets/example/california.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ let () =
let labels_1d = Nx.reshape [| n_samples |] labels in
let labels_f32 = astype_f32 labels_1d in

let longitude = Nx.slice [ Nx.R (0, n_samples); Nx.R (0, 1) ] features in
let latitude = Nx.slice [ Nx.R (0, n_samples); Nx.R (1, 2) ] features in
(* Nx.slice produces a column of shape [n; 1]. Reshape it to a 1-D vector of
shape [n] so that Hugin's scatter receives a vector, not a 2-D column. *)
let longitude_col = Nx.slice [ Nx.R (0, n_samples); Nx.R (0, 1) ] features in
let latitude_col = Nx.slice [ Nx.R (0, n_samples); Nx.R (1, 2) ] features in

let longitude = Nx.reshape [| n_samples |] longitude_col in
let latitude = Nx.reshape [| n_samples |] latitude_col in

let longitude_f32 = astype_f32 longitude in
let latitude_f32 = astype_f32 latitude in

Expand Down
34 changes: 22 additions & 12 deletions nx-datasets/lib/datasets/airline_passengers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@ let data_path = dataset_dir ^ data_filename
let url =
"https://raw.githubusercontent.com/jbrownlee/Datasets/master/airline-passengers.csv"

(* Logging source for this loader, created under the name
   "nx.datasets.airline_passengers". *)
let src =
Logs.Src.create "nx.datasets.airline_passengers"
~doc:"Airline passengers loader"

(* [Log] is the [Logs.LOG] module obtained from [src]; the [Log.info] and
   [Log.warn] calls in [load] go through it. *)
module Log = (val Logs.src_log src : Logs.LOG)

(* Runs [ensure_file] on [url] and [data_path]; [load] calls this first.
   NOTE(review): presumably fetches the CSV when it is not already cached;
   confirm against the definition of [ensure_file]. *)
let ensure_dataset () = ensure_file url data_path

let load () =
ensure_dataset ();
Printf.printf "Loading Airline Passengers dataset...\n%!";
Log.info (fun m -> m "Loading Airline Passengers dataset...");

let header, data_rows_iter =
try
Expand Down Expand Up @@ -58,9 +65,11 @@ let load () =
let row_list = Csv.Row.to_list row in
(* Convert Row.t to string list *)
if List.length row_list <> List.length header then
Printf.eprintf "Warning: Row %d has %d columns, expected %d\n%!"
(List.length acc + 1)
(List.length row_list) (List.length header);
Log.warn (fun m ->
m "Row %d has %d columns, expected %d (header: %s)"
(List.length acc + 1)
(List.length row_list) (List.length header)
(String.concat ", " header));

(* Check length before accessing *)
if List.length row_list > passenger_col_index then
Expand All @@ -73,12 +82,13 @@ let load () =
let passenger_int = parse_int_cell ~context passenger_str in
passenger_int :: acc
else (
Printf.eprintf
"Warning: Row %d is shorter than expected (%d < %d), skipping \
passenger value.\n\
%!"
(List.length acc + 1)
(List.length row_list) (passenger_col_index + 1);
Log.warn (fun m ->
m
"Row %d is shorter than expected (%d < %d), skipping \
passenger value. Missing column: %s"
(List.length acc + 1)
(List.length row_list) (passenger_col_index + 1)
passenger_col_name);
-1 :: acc (* Placeholder for missing data *)))
~init:[] data_rows_iter
with
Expand All @@ -101,7 +111,7 @@ let load () =
let num_samples = List.length data_rows_rev in
if num_samples = 0 then
failwith "No data rows loaded from airline-passengers.csv";
Printf.printf "Found %d samples.\n%!" num_samples;
Log.info (fun m -> m "Found %d samples." num_samples);

(* Create Bigarray and populate (data is reversed from fold_left) *)
let passengers = Array1.create int32 c_layout num_samples in
Expand All @@ -110,5 +120,5 @@ let load () =
passengers.{num_samples - 1 - i} <- Int32.of_int passenger_val)
data_rows_rev;

Printf.printf "Airline Passengers loading complete.\n%!";
Log.info (fun m -> m "Airline Passengers loading complete.");
passengers
13 changes: 10 additions & 3 deletions nx-datasets/lib/datasets/breast_cancer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
open Bigarray
open Dataset_utils

(* Logging source for this loader, created under the name
   "nx.datasets.breast_cancer". *)
let src =
Logs.Src.create "nx.datasets.breast_cancer"
~doc:"Breast Cancer dataset loader"

(* [Log] is the [Logs.LOG] module obtained from [src]; the [Log.info] calls in
   [load] go through it. *)
module Log = (val Logs.src_log src : Logs.LOG)

let dataset_name = "breast-cancer-wisconsin"
let dataset_dir = get_cache_dir dataset_name
let data_filename = "wdbc.data"
Expand All @@ -24,7 +31,7 @@ let encode_label label row =

let load () =
ensure_dataset ();
Printf.printf "Loading Breast Cancer Wisconsin dataset...\n%!";
Log.info (fun m -> m "Loading Breast Cancer Wisconsin dataset...");

let data_rows =
try
Expand All @@ -51,7 +58,7 @@ let load () =
let expected_cols = 32 in
let num_features = 30 in

Printf.printf "Found %d samples.\n%!" num_samples;
Log.info (fun m -> m "Found %d samples." num_samples);

let features = Array2.create float64 c_layout num_samples num_features in
let labels = Array1.create int c_layout num_samples in
Expand All @@ -75,5 +82,5 @@ let load () =
done)
data_rows;

Printf.printf "Breast Cancer Wisconsin loading complete.\n%!";
Log.info (fun m -> m "Breast Cancer Wisconsin loading complete.");
(features, labels)
45 changes: 28 additions & 17 deletions nx-datasets/lib/datasets/california_housing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
open Bigarray
open Dataset_utils

(* Logging source for this loader, created under the name
   "nx.datasets.california_housing". *)
let src =
Logs.Src.create "nx.datasets.california_housing"
~doc:"California Housing loader"

(* [Log] is the [Logs.LOG] module obtained from [src]; the [Log.info] and
   [Log.warn] calls in [load] go through it. *)
module Log = (val Logs.src_log src : Logs.LOG)

let dataset_name = "california-housing"
let dataset_dir = get_cache_dir dataset_name
let data_filename = "housing.csv"
Expand All @@ -28,7 +35,7 @@ let calculate_mean_non_nan column_data =

let load () =
ensure_dataset ();
Printf.printf "Loading California Housing dataset...\n%!";
Log.info (fun m -> m "Loading California Housing dataset...");

let header, all_data_rows =
try
Expand Down Expand Up @@ -85,8 +92,9 @@ let load () =
List.find_index (( = ) "total_bedrooms") header
in

Printf.printf "Found %d samples. Loading %d features + target '%s'.\n%!"
num_samples num_features target_name;
Log.info (fun m ->
m "Found %d samples. Loading %d features + target '%s'." num_samples
num_features target_name);

let parsed_features_temp = Array.make_matrix num_samples num_features nan in
let parsed_labels_temp = Array.make num_samples nan in
Expand All @@ -95,8 +103,10 @@ let load () =
List.iteri
(fun i row ->
if List.length row <> List.length header then
Printf.eprintf "Warning: Row %d has %d columns, expected %d\n%!" (i + 1)
(List.length row) (List.length header);
Log.warn (fun m ->
m "Row %d has %d columns, expected %d (header: %s)" (i + 1)
(List.length row) (List.length header)
(String.concat ", " header));

List.iteri
(fun j feature_idx ->
Expand All @@ -106,29 +116,30 @@ let load () =
parsed_features_temp.(i).(j) <- v_float;
if Some feature_idx = total_bedrooms_index_opt then
total_bedrooms_col_temp := v_float :: !total_bedrooms_col_temp)
else (
Printf.eprintf
"Warning: Row %d missing feature column %d ('%s'). Setting NaN.\n\
%!"
(i + 1) feature_idx (List.nth feature_names j);
else
let feature_name = List.nth feature_names j in
Log.warn (fun m ->
m "Row %d missing feature column %d ('%s'). Setting NaN."
(i + 1) feature_idx feature_name);
parsed_features_temp.(i).(j) <- nan;
if Some feature_idx = total_bedrooms_index_opt then
total_bedrooms_col_temp := nan :: !total_bedrooms_col_temp))
total_bedrooms_col_temp := nan :: !total_bedrooms_col_temp)
feature_indices;

if List.length row > target_index then
let label_str = List.nth row target_index in
parsed_labels_temp.(i) <- parse_float_or_nan label_str
else (
Printf.eprintf
"Warning: Row %d missing target column %d ('%s'). Setting NaN.\n%!"
(i + 1) target_index target_name;
Log.warn (fun m ->
m "Row %d missing target column %d ('%s'). Setting NaN." (i + 1)
target_index target_name);
parsed_labels_temp.(i) <- nan))
data_rows_str;

let total_bedrooms_mean = calculate_mean_non_nan !total_bedrooms_col_temp in
Printf.printf "Calculated mean for 'total_bedrooms' (for imputation): %f\n%!"
total_bedrooms_mean;
Log.info (fun m ->
m "Calculated mean for 'total_bedrooms' (for imputation): %f"
total_bedrooms_mean);
let total_bedrooms_feature_index =
match List.find_index (( = ) "total_bedrooms") feature_names with
| Some idx -> idx
Expand Down Expand Up @@ -163,5 +174,5 @@ let load () =
else labels.{i} <- label_v
done;

Printf.printf "California Housing loading complete.\n%!";
Log.info (fun m -> m "California Housing loading complete.");
(features, labels)
13 changes: 9 additions & 4 deletions nx-datasets/lib/datasets/cifar10.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@ let url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
let tar_path = base_dir ^ Filename.basename url
let test_batch_rel_path = archive_dir_name ^ "/test_batch"

(* Logging source for this loader, created under the name
   "nx.datasets.cifar10". *)
let src = Logs.Src.create "nx.datasets.cifar10" ~doc:"CIFAR10 dataset loader"

(* [Log] is the [Logs.LOG] module obtained from [src]; [read_cifar_batch] logs
   through [Log.debug] and [load] through [Log.info]. *)
module Log = (val Logs.src_log src : Logs.LOG)

(* Runs [ensure_extracted_archive] with the dataset [url], [tar_path] as the
   archive path, [base_dir] as the extraction directory and
   [test_batch_rel_path] as the check file; [load] calls this first.
   NOTE(review): presumably the check file decides whether download and
   extraction can be skipped; confirm against [ensure_extracted_archive]. *)
let ensure_dataset () =
ensure_extracted_archive ~url ~archive_path:tar_path ~extract_dir:base_dir
~check_file:test_batch_rel_path

let read_cifar_batch filename =
Printf.printf "Reading batch file: %s\n%!" filename;
Log.debug (fun m -> m "Reading batch file: %s" filename);
let ic = open_in_bin filename in
let s =
try really_input_string ic (in_channel_length ic)
Expand All @@ -32,7 +37,7 @@ let read_cifar_batch filename =
(Printf.sprintf "File %s has unexpected size %d" filename num_bytes);

let num_images = num_bytes / bytes_per_image in
Printf.printf "Found %d images in %s.\n%!" num_images filename;
Log.debug (fun m -> m "Found %d images in %s." num_images filename);

let images =
Genarray.create int8_unsigned c_layout [| num_images; 32; 32; 3 |]
Expand Down Expand Up @@ -64,7 +69,7 @@ let read_cifar_batch filename =

let load () =
ensure_dataset ();
Printf.printf "Loading CIFAR-10 dataset...\n%!";
Log.info (fun m -> m "Loading CIFAR-10 dataset...");

let train_batches_files =
List.init 5 (fun i -> dataset_dir ^ Printf.sprintf "data_batch_%d" (i + 1))
Expand Down Expand Up @@ -107,5 +112,5 @@ let load () =
let test_batch_file = dataset_dir ^ "test_batch" in
let test_images, test_labels = read_cifar_batch test_batch_file in

Printf.printf "CIFAR-10 loading complete.\n%!";
Log.info (fun m -> m "CIFAR-10 loading complete.");
((train_images, train_labels), (test_images, test_labels))
14 changes: 10 additions & 4 deletions nx-datasets/lib/datasets/diabetes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
open Bigarray
open Dataset_utils

(* Logging source for this loader, created under the name
   "nx.datasets.diabetes". *)
let src = Logs.Src.create "nx.datasets.diabetes" ~doc:"Diabetes dataset loader"

(* [Log] is the [Logs.LOG] module obtained from [src]; the [Log.info] calls in
   [load] go through it. *)
module Log = (val Logs.src_log src : Logs.LOG)

let dataset_name = "diabetes-sklearn"
let dataset_dir = get_cache_dir dataset_name
let data_filename = "diabetes.tab.txt"
Expand All @@ -11,7 +16,7 @@ let ensure_dataset () = ensure_file url data_path

let load () =
ensure_dataset ();
Printf.printf "Loading Diabetes (sklearn version) dataset...\n%!";
Log.info (fun m -> m "Loading Diabetes (sklearn version) dataset...");

let header, data_rows_iter =
try
Expand Down Expand Up @@ -106,8 +111,9 @@ let load () =
let num_samples = List.length labels_rev in

if num_samples = 0 then failwith "No data rows loaded from diabetes.tab.txt";
Printf.printf "Found %d samples with %d features and target '%s'.\n%!"
num_samples num_features target_col_name;
Log.info (fun m ->
m "Found %d samples with %d features and target '%s'." num_samples
num_features target_col_name);

let features_ba = Array2.create float64 c_layout num_samples num_features in
let labels_ba = Array1.create float64 c_layout num_samples in
Expand All @@ -124,5 +130,5 @@ let load () =
(fun i label_val -> labels_ba.{num_samples - 1 - i} <- label_val)
labels_rev;

Printf.printf "Diabetes loading complete.\n%!";
Log.info (fun m -> m "Diabetes loading complete.");
(features_ba, labels_ba)
10 changes: 7 additions & 3 deletions nx-datasets/lib/datasets/iris.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
open Bigarray
open Dataset_utils

(* Logging source for this loader, created under the name
   "nx.datasets.iris". *)
let src = Logs.Src.create "nx.datasets.iris" ~doc:"Iris dataset loader"

(* [Log] is the [Logs.LOG] module obtained from [src]; the [Log.info] calls in
   [load] go through it. *)
module Log = (val Logs.src_log src : Logs.LOG)

let dataset_name = "iris"
let dataset_dir = get_cache_dir dataset_name
let data_filename = "iris.data"
Expand All @@ -24,7 +28,7 @@ let encode_label s =

let load () =
ensure_dataset ();
Printf.printf "Loading Iris dataset...\n%!";
Log.info (fun m -> m "Loading Iris dataset...");

let data_rows =
try
Expand All @@ -50,7 +54,7 @@ let load () =
let num_features = 4 in

if num_samples = 0 then failwith "No data loaded from iris.data";
Printf.printf "Found %d samples.\n%!" num_samples;
Log.info (fun m -> m "Found %d samples" num_samples);

let features = Array2.create float64 c_layout num_samples num_features in
let labels = Array1.create int32 c_layout num_samples in
Expand All @@ -73,5 +77,5 @@ let load () =
labels.{i} <- encode_label label_str)
data_rows;

Printf.printf "Iris loading complete.\n%!";
Log.info (fun m -> m "Iris loading complete");
(features, labels)
Loading