Skip to content

Commit 6ecba89

Browse files
committed
Use snapshot API to save/load dqn and reinforce
1 parent 9d8adee commit 6ecba89

File tree

5 files changed

+450
-117
lines changed

5 files changed

+450
-117
lines changed

fehu/lib/algorithms/dqn.ml

Lines changed: 163 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
open Kaun
22

3+
module Snapshot = Checkpoint.Snapshot
4+
35
type config = {
46
learning_rate : float;
57
gamma : float;
@@ -47,6 +49,151 @@ type update_metrics = {
4749
loss : float;
4850
}
4951

52+
(* Record key under which [to_snapshot] stores the snapshot format
   identifier and under which [of_snapshot] looks it up. *)
let dqn_schema_key = "schema"
53+
(* Format identifier written by [to_snapshot]; [of_snapshot] rejects a
   snapshot whose schema string differs from this value. *)
let dqn_schema_value = "fehu.dqn/1"
54+
55+
(** [config_to_snapshot c] encodes every hyperparameter of [c] as one entry of
    a snapshot record, keyed by the field's own name. Float fields are encoded
    with [Snapshot.float] and integer fields with [Snapshot.int]. *)
let config_to_snapshot (c : config) : Snapshot.t =
  let float_entry name value = (name, Snapshot.float value) in
  let int_entry name value = (name, Snapshot.int value) in
  Snapshot.record
    [
      float_entry "learning_rate" c.learning_rate;
      float_entry "gamma" c.gamma;
      float_entry "epsilon_start" c.epsilon_start;
      float_entry "epsilon_end" c.epsilon_end;
      float_entry "epsilon_decay" c.epsilon_decay;
      int_entry "batch_size" c.batch_size;
      int_entry "buffer_capacity" c.buffer_capacity;
      int_entry "target_update_freq" c.target_update_freq;
    ]
67+
68+
(** [config_of_snapshot snapshot] decodes a [config] from a snapshot record
    such as the one built by [config_to_snapshot].

    Float fields also accept an int scalar, which is widened with
    [float_of_int]. Int fields also accept a float scalar, but only when it
    holds an integral, finite value: a fractional, NaN or infinite float is
    reported as an error instead of being silently truncated.

    @return
      [Ok config] when all eight fields are present and well-typed, otherwise
      [Error msg] describing the first field that failed (fields are decoded
      in the order they are listed below). *)
let config_of_snapshot (snapshot : Snapshot.t) : (config, string) result =
  let ( let* ) = Result.bind in
  let open Snapshot in
  match snapshot with
  | Record record ->
      let find_float field =
        match Snapshot.Record.find_opt field record with
        | Some (Scalar (Float value)) -> Ok value
        | Some (Scalar (Int value)) -> Ok (float_of_int value)
        | Some _ ->
            Error (Printf.sprintf "DQN config field %s must be float" field)
        | None -> Error (Printf.sprintf "Missing DQN config field %s" field)
      in
      let find_int field =
        match Snapshot.Record.find_opt field record with
        | Some (Scalar (Int value)) -> Ok value
        | Some (Scalar (Float value)) ->
            (* Integral floats are tolerated; anything that would lose
               information under [int_of_float] is rejected. *)
            if Stdlib.Float.is_integer value then Ok (int_of_float value)
            else
              Error
                (Printf.sprintf
                   "DQN config field %s must be an integer, got %g" field
                   value)
        | Some _ ->
            Error (Printf.sprintf "DQN config field %s must be int" field)
        | None -> Error (Printf.sprintf "Missing DQN config field %s" field)
      in
      let* learning_rate = find_float "learning_rate" in
      let* gamma = find_float "gamma" in
      let* epsilon_start = find_float "epsilon_start" in
      let* epsilon_end = find_float "epsilon_end" in
      let* epsilon_decay = find_float "epsilon_decay" in
      let* batch_size = find_int "batch_size" in
      let* buffer_capacity = find_int "buffer_capacity" in
      let* target_update_freq = find_int "target_update_freq" in
      Ok
        {
          learning_rate;
          gamma;
          epsilon_start;
          epsilon_end;
          epsilon_decay;
          batch_size;
          buffer_capacity;
          target_update_freq;
        }
  | _ -> Error "DQN config must be a record"
110+
111+
(** [to_snapshot t] captures the persistent part of agent [t] as a snapshot
    record: the schema tag, [n_actions], the RNG, the configuration, the
    serialized optimizer state and both parameter trees. The network modules,
    the optimizer algorithm and the replay buffer are not part of the record. *)
let to_snapshot (t : t) : Snapshot.t =
  let schema_entry = (dqn_schema_key, Snapshot.string dqn_schema_value) in
  let scalar_entries =
    [ ("n_actions", Snapshot.int t.n_actions); ("rng", Snapshot.rng t.rng) ]
  in
  let state_entries =
    [
      ("config", config_to_snapshot t.config);
      ("optimizer_state", Optimizer.serialize t.opt_state);
    ]
  in
  let param_entries =
    [
      ("q_params", Snapshot.ptree t.q_params);
      ("target_params", Snapshot.ptree t.target_params);
    ]
  in
  (* Concatenation keeps the entries in the same order as a single literal. *)
  Snapshot.record
    ((schema_entry :: scalar_entries) @ state_entries @ param_entries)
122+
123+
(** [of_snapshot ~q_network ~optimizer snapshot] rebuilds an agent from a
    snapshot record laid out like the one [to_snapshot] produces.

    The network module and the optimizer algorithm are not stored in the
    snapshot, so the caller supplies them; [q_network] is also used as the
    target network. The replay buffer is created empty with the restored
    [config.buffer_capacity].

    A record without a [dqn_schema_key] field is accepted (presumably to
    tolerate snapshots written without the tag -- TODO confirm); when the
    field is present it must be the string [dqn_schema_value].

    The [n_actions] and [rng] fields may be stored as float scalars, but only
    integral, finite values are accepted; anything else is an error rather
    than being truncated.

    @return
      [Ok agent], or [Error msg] where [msg] starts with
      ["Dqn.of_snapshot: "] and names the offending field when one is known. *)
let of_snapshot ~(q_network : module_) ~(optimizer : Optimizer.algorithm)
    (snapshot : Snapshot.t) : (t, string) result =
  (* Everything taken from [Result] is bound before [Snapshot] is opened, so
     nothing exported by [Snapshot] can capture these names. *)
  let ( let* ) = Result.bind in
  let context = "Dqn.of_snapshot: " in
  let error msg = Error (context ^ msg) in
  let with_context result =
    Result.map_error (fun msg -> context ^ msg) result
  in
  let int_of_integral_float ~field value =
    if Stdlib.Float.is_integer value then Ok (int_of_float value)
    else
      error (Printf.sprintf "field %s must be an integer, got %g" field value)
  in
  let open Snapshot in
  match snapshot with
  | Record record ->
      let validate_schema () =
        match Snapshot.Record.find_opt dqn_schema_key record with
        | None -> Ok ()
        | Some (Scalar (String value)) ->
            if String.equal value dqn_schema_value then Ok ()
            else error ("unsupported schema " ^ value)
        | Some _ -> error "invalid schema field"
      in
      let find field =
        match Snapshot.Record.find_opt field record with
        | Some value -> Ok value
        | None -> error ("missing field " ^ field)
      in
      let decode_int ~field = function
        | Scalar (Int value) -> Ok value
        | Scalar (Float value) -> int_of_integral_float ~field value
        | _ -> error ("expected int scalar for field " ^ field)
      in
      let decode_rng ~field = function
        | Scalar (Int seed) -> Ok (Rune.Rng.key seed)
        | Scalar (Float value) ->
            let* seed = int_of_integral_float ~field value in
            Ok (Rune.Rng.key seed)
        | _ -> error ("expected rng scalar for field " ^ field)
      in
      let decode_ptree field =
        let* node = find field in
        with_context (Snapshot.to_ptree node)
      in
      let* () = validate_schema () in
      (* Fields are decoded in a fixed order so that the first problem found
         is the one reported. *)
      let* n_actions_node = find "n_actions" in
      let* n_actions = decode_int ~field:"n_actions" n_actions_node in
      let* rng_node = find "rng" in
      let* rng = decode_rng ~field:"rng" rng_node in
      let* config_node = find "config" in
      let* config = with_context (config_of_snapshot config_node) in
      let* q_params = decode_ptree "q_params" in
      let* target_params = decode_ptree "target_params" in
      let* opt_state_node = find "optimizer_state" in
      let* opt_state =
        with_context (Optimizer.restore optimizer opt_state_node)
      in
      let replay_buffer =
        Fehu.Buffer.Replay.create ~capacity:config.buffer_capacity
      in
      Ok
        {
          q_network;
          q_params;
          target_network = q_network;
          target_params;
          optimizer;
          opt_state;
          replay_buffer;
          rng;
          n_actions;
          config;
        }
  | _ -> error "expected snapshot record"
196+
50197
let create ~q_network ~n_actions ~rng config =
51198
let keys = Rune.Rng.split ~n:2 rng in
52199

@@ -356,81 +503,19 @@ let learn t ~env ~total_timesteps
356503
357504
t
358505
359-
(** [save t dir] writes agent [t] into directory [dir], creating [dir] with
    mode [0o755] when it does not exist yet (only the last path component is
    created; parents must already exist).

    Files written: [q_params.safetensors], [target_params.safetensors] and
    [metadata.json] (holding [n_actions], the RNG seed and every [config]
    field). Optimizer state and the replay buffer are not written.

    The metadata channel is closed even when JSON serialization raises.

    NOTE(review): exceptions from [Unix.mkdir], the checkpointer, [open_out]
    or Yojson propagate unchanged -- their exact types are not visible here. *)
let save t dir =
  if not (Sys.file_exists dir) then Unix.mkdir dir 0o755;
  let checkpointer = Kaun.Checkpoint.Checkpointer.create () in
  Kaun.Checkpoint.Checkpointer.save_file checkpointer
    ~path:(Filename.concat dir "q_params.safetensors")
    ~params:t.q_params ();
  Kaun.Checkpoint.Checkpointer.save_file checkpointer
    ~path:(Filename.concat dir "target_params.safetensors")
    ~params:t.target_params ();
  let rng_seed = Rune.Rng.to_int t.rng in
  let metadata =
    `Assoc
      [
        ("n_actions", `Int t.n_actions);
        ("rng_seed", `Int rng_seed);
        ("learning_rate", `Float t.config.learning_rate);
        ("gamma", `Float t.config.gamma);
        ("epsilon_start", `Float t.config.epsilon_start);
        ("epsilon_end", `Float t.config.epsilon_end);
        ("epsilon_decay", `Float t.config.epsilon_decay);
        ("batch_size", `Int t.config.batch_size);
        ("buffer_capacity", `Int t.config.buffer_capacity);
        ("target_update_freq", `Int t.config.target_update_freq);
      ]
  in
  let metadata_path = Filename.concat dir "metadata.json" in
  let oc = open_out metadata_path in
  Fun.protect
    ~finally:(fun () -> close_out_noerr oc)
    (fun () ->
      Yojson.Basic.to_channel oc metadata;
      (* Close explicitly on the success path so a failed flush is still
         reported; the [finally] close is then a no-op. *)
      close_out oc)
385-
386-
(** [load dir] rebuilds an agent from the files found in directory [dir]:
    [metadata.json] supplies the configuration, [n_actions] and the RNG seed,
    while [q_params.safetensors] and [target_params.safetensors] supply the
    two parameter trees (restored as float32).

    The Q-network architecture is fixed inside this function (2 input
    features, a hidden layer of 8 units, [n_actions] outputs) and is shared by
    the target network. The optimizer is a fresh Adam instance using the saved
    learning rate, its state is re-initialised from the Q parameters, and the
    replay buffer starts empty. *)
let load dir =
  let in_dir file = Filename.concat dir file in
  let metadata = Yojson.Basic.from_file (in_dir "metadata.json") in
  let open Yojson.Basic.Util in
  let float_field key = metadata |> member key |> to_float in
  let int_field key = metadata |> member key |> to_int in
  let config =
    {
      learning_rate = float_field "learning_rate";
      gamma = float_field "gamma";
      epsilon_start = float_field "epsilon_start";
      epsilon_end = float_field "epsilon_end";
      epsilon_decay = float_field "epsilon_decay";
      batch_size = int_field "batch_size";
      buffer_capacity = int_field "buffer_capacity";
      target_update_freq = int_field "target_update_freq";
    }
  in
  let n_actions = int_field "n_actions" in
  let rng = Rune.Rng.key (int_field "rng_seed") in
  let checkpointer = Kaun.Checkpoint.Checkpointer.create () in
  (* Example: fixed architecture for 2 input features and n_actions *)
  let q_network =
    Kaun.Layer.sequential
      [
        Kaun.Layer.linear ~in_features:2 ~out_features:8 ();
        Kaun.Layer.relu ();
        Kaun.Layer.linear ~in_features:8 ~out_features:n_actions ();
      ]
  in
  let restore_params file =
    Kaun.Checkpoint.Checkpointer.restore_file checkpointer ~path:(in_dir file)
      ~dtype:Rune.float32
  in
  let q_params = restore_params "q_params.safetensors" in
  let target_params = restore_params "target_params.safetensors" in
  let optimizer = Optimizer.adam ~lr:config.learning_rate () in
  let opt_state = optimizer.init q_params in
  let replay_buffer =
    Fehu.Buffer.Replay.create ~capacity:config.buffer_capacity
  in
  {
    q_network;
    q_params;
    target_network = q_network;
    target_params;
    optimizer;
    opt_state;
    replay_buffer;
    rng;
    n_actions;
    config;
  }
506+
(** [save_to_file t ~path] encodes [t] with [to_snapshot] and hands it to
    [Checkpoint.write_snapshot_file_with] for writing at [path].

    @raise Failure
      with a message prefixed by ["Dqn.save_to_file: "] when the checkpoint
      layer reports an error. *)
let save_to_file (t : t) ~path =
  let encode () = to_snapshot t in
  match Checkpoint.write_snapshot_file_with ~path ~encode with
  | Error err ->
      let reason = Checkpoint.error_to_string err in
      failwith (Printf.sprintf "Dqn.save_to_file: %s" reason)
  | Ok () -> ()
513+
514+
(** [load_from_file ~path ~q_network ~optimizer] reads the snapshot stored at
    [path] through [Checkpoint.load_snapshot_file_with] and decodes it with
    [of_snapshot], using the supplied network module and optimizer algorithm.

    @return
      [Ok agent] on success; otherwise [Error msg], where [msg] is the
      checkpoint error rendered by [Checkpoint.error_to_string]. *)
let load_from_file ~path ~(q_network : module_)
    ~(optimizer : Optimizer.algorithm) =
  let decode snapshot = of_snapshot ~q_network ~optimizer snapshot in
  Checkpoint.load_snapshot_file_with ~path ~decode
  |> Result.map_error Checkpoint.error_to_string

fehu/lib/algorithms/dqn.mli

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -340,21 +340,22 @@ val learn :
340340
ensures the replay buffer has diverse samples before Q-network updates
341341
start. *)
342342

343-
(** [save agent path] saves the agent state into directory [path].
344-
345-
Creates the following files:
346-
- q_params.safetensors: Q-network weights
347-
- target_params.safetensors: Target network weights
348-
- metadata.json: Config, n_actions, and RNG seed
349-
Note: The optimizer state is not saved (no opt_state file is written).
350-
Note: The replay buffer is not saved. *)
351-
val save : t -> string -> unit
352-
353-
(** [load path] loads an agent from the checkpoint directory [path].
354-
355-
@param path Directory containing the saved checkpoint
356-
The signature takes only the directory: there is no [~q_network] or
357-
[~n_actions] argument. The number of actions is read from metadata.json.
358-
NOTE(review): which exceptions are raised for missing or malformed files is
not visible from this signature -- confirm against the implementation.
359-
Note: The replay buffer starts empty and optimizer is reinitialized. *)
360-
val load : string -> t
343+
(** [save_to_file agent ~path] writes the agent state to a snapshot file.
344+
345+
The snapshot stores the full configuration, RNG seed, Q/target parameters,
346+
and optimizer state. The replay buffer is intentionally omitted.

@param path Location of the snapshot file to write
@raise Failure if the snapshot cannot be written; the message starts with
["Dqn.save_to_file: "] followed by the checkpoint error. *)
347+
val save_to_file : t -> path:string -> unit
348+
349+
(** [load_from_file ~path ~q_network ~optimizer] restores an agent from a
350+
snapshot file created by {!save_to_file}.
351+
352+
The caller must supply the network architecture and optimizer algorithm so
353+
the parameters and optimizer state can be reconstructed safely. Returns
354+
either the restored agent or an error message describing what went wrong.
355+
356+
The replay buffer starts empty after loading.

@param path Location of the snapshot file to read
@param q_network Network module for the restored agent; it is also used as
the target network
@param optimizer Optimizer algorithm used to restore the saved optimizer
state
@return [Ok agent] on success, or [Error msg] when the file cannot be read
or decoded *)
357+
val load_from_file :
358+
path:string ->
359+
q_network:Kaun.module_ ->
360+
optimizer:Kaun.Optimizer.algorithm ->
361+
(t, string) result

0 commit comments

Comments
 (0)