Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ All notable changes to this project will be documented in this file.

### Nx-datasets

- Fix cache directory resolution to respect `RAVEN_CACHE_ROOT` (or fall back to `XDG_CACHE_HOME`/`HOME`), allowing custom cache locations. (#128, @Arsalaan-Alam)
- Switch CIFAR-10 loader to the binary archive so parsing succeeds again (@tmattio)
- Add a CIFAR-10 example (@tmattio)
- Standardize dataset examples on `Logs` (@tmattio)
Expand Down
24 changes: 14 additions & 10 deletions nx-datasets/lib/dataset_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,19 @@ let mkdir_p path perm =
initial_prefix components);
()

(* Legacy cache-location helper: derives the dataset cache root from the
   [HOME] environment variable. Both values are computed once, when this
   module is initialised. *)
module Xdg = struct
  (* Home directory as given by [HOME]; raises [Failure] at module
     initialisation when the variable is not set. *)
  let home =
    match Sys.getenv_opt "HOME" with
    | Some dir -> dir
    | None -> failwith "HOME environment variable not set."

  (* Root under which every dataset gets its own sub-directory; always ends
     with a "/". *)
  let cache_base = String.concat "" [ home; "/.cache/ocaml-nx/datasets/" ]
end

(* [get_cache_dir dataset_name] is the cache directory of [dataset_name]:
   [Xdg.cache_base] followed by the dataset name and a trailing "/". *)
let get_cache_dir dataset_name =
  Printf.sprintf "%s%s/" Xdg.cache_base dataset_name
(* [get_cache_base_dir ?getenv ()] returns the root directory of the cache.

   A non-empty [RAVEN_CACHE_ROOT] is returned verbatim. Otherwise the root is
   the "raven" sub-directory of the cache directory reported by [Xdg] for the
   same environment lookup function.

   @param getenv environment lookup, defaults to [Sys.getenv_opt]. *)
let get_cache_base_dir ?(getenv = Sys.getenv_opt) () =
  let xdg_root () =
    let dirs = Xdg.create ~env:getenv () in
    Filename.concat (Xdg.cache_dir dirs) "raven"
  in
  match getenv "RAVEN_CACHE_ROOT" with
  | None | Some "" -> xdg_root ()
  | Some root -> root

(* [get_cache_dir ?getenv dataset_name] is
   [<cache root>/datasets/<dataset_name>] joined with [Filename.concat] and
   guaranteed to end with [Filename.dir_sep].

   @param getenv environment lookup forwarded to [get_cache_base_dir],
   defaults to [Sys.getenv_opt]. *)
let get_cache_dir ?(getenv = Sys.getenv_opt) dataset_name =
  let root = get_cache_base_dir ~getenv () in
  let dir = Filename.concat (Filename.concat root "datasets") dataset_name in
  let len = String.length dir in
  (* Append a separator only when the joined path does not already end with
     one (it does when [dataset_name] is empty). *)
  let ends_with_sep = len > 0 && dir.[len - 1] = Filename.dir_sep.[0] in
  if ends_with_sep then dir else dir ^ Filename.dir_sep

(* [mkdir_p dir] calls the two-argument [mkdir_p] defined earlier in this
   file with permissions [0o755], and ignores [Unix.EEXIST] so that an already
   existing directory is not an error. Any other [Unix.Unix_error] is
   propagated. *)
let mkdir_p dir =
  match mkdir_p dir 0o755 with
  | () -> ()
  | exception Unix.Unix_error (Unix.EEXIST, _, _) -> ()
Expand All @@ -77,7 +81,7 @@ let download_file url dest_path =
h#set_timeout 300;
(* 5 minutes *)
(* Provide a user agent *)
h#set_useragent "ocaml-nx-datasets/1.0.0";
h#set_useragent "raven/1.0.0";

let oc = open_out_bin dest_path in
let result =
Expand Down
25 changes: 22 additions & 3 deletions nx-datasets/lib/dataset_utils.mli
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@
(** Utilities for downloading and managing datasets. *)

val get_cache_dir : string -> string
val get_cache_dir : ?getenv:(string -> string option) -> string -> string
(** Return the platform-specific cache directory path for the given dataset.

The default location is "~/.cache/ocaml-nx/datasets/[dataset_name]/".
The cache root is resolved using the following priority order:
+ [RAVEN_CACHE_ROOT] environment variable (highest priority; used as the
cache root whenever it is set and non-empty)
+ [XDG_CACHE_HOME] environment variable (if [RAVEN_CACHE_ROOT] is unset or
empty)
+ [$HOME/.cache] (fallback, default behavior)

The resolved path will be "[cache_root]/datasets/[dataset_name]/", where
[cache_root] is either [RAVEN_CACHE_ROOT] or "[cache_home]/raven", with
[cache_home] being [XDG_CACHE_HOME] or, failing that, [$HOME/.cache]. The
path uses platform-appropriate directory separators and ends with a trailing
separator.

{2 Parameters}
- dataset_name: the name of the dataset.

{2 Returns}
- the cache directory path, including trailing slash. *)
- the cache directory path, including a trailing directory separator.

@param getenv
optional environment lookup function (defaults to [Sys.getenv_opt]) to
facilitate testing.

{2 Environment Variables}
- [RAVEN_CACHE_ROOT]: Custom cache directory root (overrides all other
settings)
- [XDG_CACHE_HOME]: XDG Base Directory cache location (standard on
Linux/Unix)
- [HOME]: User home directory (used for fallback cache location) *)

val download_file : string -> string -> unit
(** Download a file from a URL to a destination path.
Expand Down
2 changes: 1 addition & 1 deletion nx-datasets/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
(library
(name nx_datasets)
(public_name nx-datasets)
(libraries unix zip curl csv nx bigarray_ext logs))
(libraries unix zip curl csv nx bigarray_ext logs xdg))
3 changes: 3 additions & 0 deletions nx-datasets/lib/nx_datasets.ml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@ let load_airline_passengers () =

(* Include generators inline *)
include Generators

(* Public alias of [Dataset_utils.get_cache_dir]: same optional [?getenv]
   lookup function, same dataset-name argument, same resulting path. *)
let get_cache_dir = Dataset_utils.get_cache_dir
19 changes: 19 additions & 0 deletions nx-datasets/lib/nx_datasets.mli
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,25 @@
generate synthetic datasets for testing and experimentation. Real datasets
are downloaded and cached in the platform-specific cache directory. *)

(** {2 Cache Management}

Helpers for inspecting dataset cache locations. *)

val get_cache_dir : ?getenv:(string -> string option) -> string -> string
(** [get_cache_dir ?getenv dataset_name] resolves the cache directory for the
    dataset called [dataset_name].

    The cache root is chosen in this order:
    + [RAVEN_CACHE_ROOT], used verbatim when it is set and non-empty;
    + [$XDG_CACHE_HOME/raven], when [XDG_CACHE_HOME] is set to an absolute
      path;
    + [$HOME/.cache/raven] otherwise (on Windows the vendored [Xdg] library
      supplies the default cache folder instead of [$HOME/.cache]).

    The returned path is [<root>/datasets/<dataset_name>], built with
    platform-specific separators and always ending with a trailing separator.

    @param getenv
      optional environment getter (defaults to [Sys.getenv_opt]); pass a
      custom function to resolve against a fake environment, e.g. in tests.
    @param dataset_name name of the dataset.
    @return cache directory path with trailing separator. *)

(** {2 Loading Real Datasets}

Functions to load classic machine learning datasets as Nx tensors. *)
Expand Down
5 changes: 5 additions & 0 deletions nx-datasets/test/dune
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@
(name test_generators)
(package nx-datasets)
(libraries nx nx-datasets alcotest))

(test
(name test_dataset_utils)
(package nx-datasets)
(libraries nx nx-datasets alcotest))
65 changes: 65 additions & 0 deletions nx-datasets/test/test_dataset_utils.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
open Alcotest

(* [build_expected base dataset] is the path the tests expect for [dataset]
   under cache root [base]: [base/datasets/dataset], terminated by exactly one
   directory separator. *)
let build_expected base dataset =
  let joined = Filename.concat (Filename.concat base "datasets") dataset in
  let len = String.length joined in
  let ends_with_sep = len > 0 && joined.[len - 1] = Filename.dir_sep.[0] in
  if ends_with_sep then joined else joined ^ Filename.dir_sep

(* [getenv_of_list env] turns the association list [env] into an
   environment-lookup function: the first binding of [var] wins, and an
   unbound variable yields [None]. *)
let getenv_of_list env var = env |> List.assoc_opt var

(* Checks the precedence of cache roots: RAVEN_CACHE_ROOT first, then
   XDG_CACHE_HOME, then HOME/.cache. Environment variables are served from an
   in-memory association list passed as [~getenv]. *)
let test_cache_dir_resolution () =
  let tmp = Filename.get_temp_dir_name () in
  let home_dir = Filename.concat tmp "home-dir" in
  let custom_cache_dir = Filename.concat tmp "nx-cache" in
  let xdg_cache_dir = Filename.concat tmp "xdg-cache" in
  (* Fake environment: the two cache variables, followed by the home
     directory under both HOME and USERPROFILE. *)
  let fake_env ~raven ~xdg =
    [
      ("RAVEN_CACHE_ROOT", raven);
      ("XDG_CACHE_HOME", xdg);
      ("HOME", home_dir);
      ("USERPROFILE", home_dir);
    ]
  in
  (* Resolve [dataset] against [env] and compare with the path expected under
     [root]. *)
  let expect_root msg ~env ~root dataset =
    let actual =
      Nx_datasets.get_cache_dir ~getenv:(getenv_of_list env) dataset
    in
    check string msg (build_expected root dataset) actual
  in

  (* RAVEN_CACHE_ROOT has highest priority *)
  expect_root "RAVEN_CACHE_ROOT takes priority"
    ~env:(fake_env ~raven:custom_cache_dir ~xdg:xdg_cache_dir)
    ~root:custom_cache_dir "iris";

  (* XDG_CACHE_HOME is used when RAVEN_CACHE_ROOT is unset or empty *)
  expect_root "XDG_CACHE_HOME used when RAVEN_CACHE_ROOT unset"
    ~env:(fake_env ~raven:"" ~xdg:xdg_cache_dir)
    ~root:(Filename.concat xdg_cache_dir "raven")
    "mnist";

  (* HOME fallback when neither cache env var is provided *)
  expect_root "Falls back to HOME/.cache when no env vars set"
    ~env:(fake_env ~raven:"" ~xdg:"")
    ~root:(Filename.concat (Filename.concat home_dir ".cache") "raven")
    "cifar10"

(* Test entry point: registers the single cache-resolution case and hands the
   suite to Alcotest. *)
let () =
  let precedence_case =
    test_case "Environment variable precedence" `Quick
      test_cache_dir_resolution
  in
  run "Dataset Utils" [ ("Cache Directory Resolution", [ precedence_case ]) ]
1 change: 1 addition & 0 deletions nx-datasets/vendor/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
(vendored_dirs *)
21 changes: 21 additions & 0 deletions nx-datasets/vendor/xdg/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
The MIT License

Copyright (c) 2016 Jane Street Group, LLC <opensource@janestreet.com>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
7 changes: 7 additions & 0 deletions nx-datasets/vendor/xdg/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(library
(name xdg)
(public_name nx-datasets.xdg)
(libraries unix)
(foreign_stubs
(language c)
(names xdg_stubs)))
94 changes: 94 additions & 0 deletions nx-datasets/vendor/xdg/xdg.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
(* Resolved base directories together with the inputs used to compute them.
   The directory fields are mutable only because [create] first builds the
   record with placeholders and then fills them in; nothing else in this file
   writes to them. *)
type t =
{ env : string -> string option (* environment lookup supplied to [create] *)
; win32 : bool (* selects USERPROFILE vs HOME and the known-folder defaults *)
; home_dir : string (* value of the home variable, "" when it is unset *)
; mutable cache_dir : string (* "" until filled in by [create] *)
; mutable config_dir : string (* "" until filled in by [create] *)
; mutable data_dir : string (* "" until filled in by [create] *)
; mutable state_dir : string (* "" until filled in by [create] *)
; mutable runtime_dir : string option (* XDG_RUNTIME_DIR, [None] when unset *)
}

(* Infix path join: [dir / entry] is [Filename.concat dir entry]. Shadows
   integer division for the rest of this file. *)
let ( / ) dir entry = Filename.concat dir entry

(* Identifies which folder [get_known_folder_path] should look up. In this
   file [InternetCache] is requested for the cache directory and
   [LocalAppData] for the config, data and state directories.
   NOTE(review): the constructor order is presumably relied upon by the C
   stub - confirm in xdg_stubs before reordering. *)
type known_folder =
| InternetCache
| LocalAppData

(* C stub returning the path of the requested known folder, or [None] when no
   path is available. In this file it is only called from [make], and only
   when the [win32] flag of the record is set.
   NOTE(review): behaviour on non-Windows builds is defined by xdg_stubs,
   which is not visible here - confirm. *)
external get_known_folder_path
: known_folder
-> string option
= "dune_xdg__get_known_folder_path"

(* [make t env_var unix_default win32_folder] resolves one base directory.

   The fallback is computed first: [unix_default] when [t.win32] is false,
   otherwise the known folder [win32_folder] (or "" when the stub returns
   [None]). The value of [env_var] then overrides the fallback only when it
   is set to a non-relative path; an unset or relative value (including the
   empty string) yields the fallback. *)
let make t env_var unix_default win32_folder =
  let fallback =
    match t.win32 with
    | false -> unix_default
    | true ->
      (match get_known_folder_path win32_folder with
       | Some folder -> folder
       | None -> "")
  in
  match t.env env_var with
  | Some dir when not (Filename.is_relative dir) -> dir
  | Some _ | None -> fallback
;;

(* Computes the cache directory: [XDG_CACHE_HOME], defaulting to
   [<home>/.cache] or the [InternetCache] known folder. Used by [create]; an
   accessor with the same name is defined later in the file. *)
let cache_dir t =
  make t "XDG_CACHE_HOME" (t.home_dir / ".cache") InternetCache
;;

(* Computes the config directory: [XDG_CONFIG_HOME], defaulting to
   [<home>/.config] or the [LocalAppData] known folder. Used by [create]; an
   accessor with the same name is defined later in the file. *)
let config_dir t =
  make t "XDG_CONFIG_HOME" (t.home_dir / ".config") LocalAppData
;;

(* Computes the data directory: [XDG_DATA_HOME], defaulting to
   [<home>/.local/share] or the [LocalAppData] known folder. Used by
   [create]; an accessor with the same name is defined later in the file. *)
let data_dir t =
  let unix_default = t.home_dir / ".local" / "share" in
  make t "XDG_DATA_HOME" unix_default LocalAppData
;;

(* Computes the state directory: [XDG_STATE_HOME], defaulting to
   [<home>/.local/state] or the [LocalAppData] known folder. Used by
   [create]; an accessor with the same name is defined later in the file. *)
let state_dir t =
  let unix_default = t.home_dir / ".local" / "state" in
  make t "XDG_STATE_HOME" unix_default LocalAppData
;;

(* [create ?win32 ~env ()] resolves every base directory once and returns the
   resulting record.

   @param win32 whether to use the Windows conventions (USERPROFILE and known
   folders); defaults to [Sys.win32].
   @param env environment lookup function used for every variable read. *)
let create ?win32 ~env () =
  let win32 =
    match win32 with
    | Some flag -> flag
    | None -> Sys.win32
  in
  let home_var = if win32 then "USERPROFILE" else "HOME" in
  let home_dir =
    match env home_var with
    | Some dir -> dir
    | None -> ""
  in
  (* The directory computations only need [env], [win32] and [home_dir], so a
     seed record with placeholder directories is enough to run them. *)
  let seed =
    { env
    ; win32
    ; home_dir
    ; cache_dir = ""
    ; config_dir = ""
    ; data_dir = ""
    ; state_dir = ""
    ; runtime_dir = None
    }
  in
  (* Bound one by one so the lookups happen in a fixed order: cache, config,
     data, state, runtime. *)
  let cache = cache_dir seed in
  let config = config_dir seed in
  let data = data_dir seed in
  let state = state_dir seed in
  let runtime = env "XDG_RUNTIME_DIR" in
  { seed with
    cache_dir = cache
  ; config_dir = config
  ; data_dir = data
  ; state_dir = state
  ; runtime_dir = runtime
  }
;;

(* Accessor: the home directory stored by [create]. *)
let home_dir { home_dir; _ } = home_dir
(* Accessor: the config directory stored by [create]. Shadows the computing
   function of the same name defined above. *)
let config_dir { config_dir; _ } = config_dir
(* Accessor: the data directory stored by [create]. Shadows the computing
   function of the same name defined above. *)
let data_dir { data_dir; _ } = data_dir
(* Accessor: the cache directory stored by [create]. Shadows the computing
   function of the same name defined above. *)
let cache_dir { cache_dir; _ } = cache_dir
(* Accessor: the state directory stored by [create]. Shadows the computing
   function of the same name defined above. *)
let state_dir { state_dir; _ } = state_dir
(* Accessor: the runtime directory stored by [create], [None] when
   [XDG_RUNTIME_DIR] was not set. *)
let runtime_dir { runtime_dir; _ } = runtime_dir
26 changes: 26 additions & 0 deletions nx-datasets/vendor/xdg/xdg.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
(** Base directories. Values of type {!t} are created using {!create}, which
    resolves every directory once; the accessors below only return the stored
    result. *)
type t

(** The user's home directory. Uses [$USERPROFILE] on Windows, [$HOME]
    otherwise. This is the empty string when that variable is not set. *)
val home_dir : t -> string

(** The directory where the application should read/write config files.
    This is [$XDG_CONFIG_HOME] when it is set to a non-relative path;
    otherwise [<home>/.config], or a platform folder on Windows. *)
val config_dir : t -> string

(** The directory where the application should read/write data files.
    This is [$XDG_DATA_HOME] when it is set to a non-relative path; otherwise
    [<home>/.local/share], or a platform folder on Windows. *)
val data_dir : t -> string

(** The directory where the application should read/write cached files.
    This is [$XDG_CACHE_HOME] when it is set to a non-relative path;
    otherwise [<home>/.cache], or a platform folder on Windows. *)
val cache_dir : t -> string

(** The directory where the application should read/write state files.
    This is [$XDG_STATE_HOME] when it is set to a non-relative path;
    otherwise [<home>/.local/state], or a platform folder on Windows. *)
val state_dir : t -> string

(** The directory where the application should store socket files. This is
    the value of [$XDG_RUNTIME_DIR], or [None] when it is not set; there is
    no default. *)
val runtime_dir : t -> string option

(** Constructor of type {!t}. [~win32] (default: {!Sys.win32}) determines
    whether to use Win32-specific APIs. [~env] is the function to get
    environment variables, typically {!Sys.getenv_opt}; every variable
    mentioned above is read through it. *)
val create : ?win32:bool -> env:(string -> string option) -> unit -> t
Loading
Loading