11open Kaun
22
3+ module Snapshot = Checkpoint. Snapshot
4+
35type config = {
46 learning_rate : float ;
57 gamma : float ;
@@ -47,6 +49,151 @@ type update_metrics = {
4749 loss : float ;
4850}
4951
(** Record key under which a DQN snapshot stores its schema identifier. *)
let dqn_schema_key : string = "schema"

(** Schema identifier written into, and accepted from, DQN snapshots. *)
let dqn_schema_value : string = "fehu.dqn/1"
54+
(** [config_to_snapshot c] encodes every hyper-parameter of [c] as one entry
    of a snapshot record: the learning rate, discount and epsilon settings
    as float scalars, the batch size, buffer capacity and target update
    frequency as int scalars. Entry names equal the record field names. *)
let config_to_snapshot (c : config) : Snapshot.t =
  let float_entry name value = (name, Snapshot.float value) in
  let int_entry name value = (name, Snapshot.int value) in
  Snapshot.record
    [
      float_entry "learning_rate" c.learning_rate;
      float_entry "gamma" c.gamma;
      float_entry "epsilon_start" c.epsilon_start;
      float_entry "epsilon_end" c.epsilon_end;
      float_entry "epsilon_decay" c.epsilon_decay;
      int_entry "batch_size" c.batch_size;
      int_entry "buffer_capacity" c.buffer_capacity;
      int_entry "target_update_freq" c.target_update_freq;
    ]
67+
(** [config_of_snapshot snapshot] decodes a {!config} from a snapshot record
    laid out as by {!config_to_snapshot}.

    Float fields also accept [Int] scalars (converted with [float_of_int]).
    Int fields also accept [Float] scalars, but only when the value is
    integral: a fractional, NaN or infinite value is rejected rather than
    silently truncated.

    @return
      [Error msg] when [snapshot] is not a record, when a field is missing,
      or when a field holds a value of the wrong kind. Fields are checked in
      declaration order and the first failure is reported. *)
let config_of_snapshot (snapshot : Snapshot.t) : (config, string) result =
  let ( let* ) = Result.bind in
  let open Snapshot in
  match snapshot with
  | Record record ->
      let missing field =
        Error (Printf.sprintf "Missing DQN config field %s" field)
      in
      let find_float field =
        match Snapshot.Record.find_opt field record with
        | Some (Scalar (Float value)) -> Ok value
        | Some (Scalar (Int value)) -> Ok (float_of_int value)
        | Some _ ->
            Error (Printf.sprintf "DQN config field %s must be float" field)
        | None -> missing field
      in
      let find_int field =
        match Snapshot.Record.find_opt field record with
        | Some (Scalar (Int value)) -> Ok value
        (* Only integral floats convert losslessly; [Float.is_integer] is
           false for NaN and infinities, so those fall through to the
           error case below. *)
        | Some (Scalar (Float value)) when Stdlib.Float.is_integer value ->
            Ok (int_of_float value)
        | Some _ ->
            Error (Printf.sprintf "DQN config field %s must be int" field)
        | None -> missing field
      in
      let* learning_rate = find_float "learning_rate" in
      let* gamma = find_float "gamma" in
      let* epsilon_start = find_float "epsilon_start" in
      let* epsilon_end = find_float "epsilon_end" in
      let* epsilon_decay = find_float "epsilon_decay" in
      let* batch_size = find_int "batch_size" in
      let* buffer_capacity = find_int "buffer_capacity" in
      let* target_update_freq = find_int "target_update_freq" in
      Ok
        {
          learning_rate;
          gamma;
          epsilon_start;
          epsilon_end;
          epsilon_decay;
          batch_size;
          buffer_capacity;
          target_update_freq;
        }
  | _ -> Error "DQN config must be a record"
110+
(** [to_snapshot t] captures the persistent state of agent [t] as a snapshot
    record: the schema tag, the number of actions, the RNG key, the
    hyper-parameters (via {!config_to_snapshot}), the serialized optimizer
    state, and the parameter trees of the online and target networks.

    The network modules, the optimizer algorithm and the replay buffer are
    not part of the snapshot. *)
let to_snapshot (t : t) : Snapshot.t =
  let ({ n_actions; rng; config; opt_state; q_params; target_params; _ } : t) =
    t
  in
  Snapshot.record
    [
      (dqn_schema_key, Snapshot.string dqn_schema_value);
      ("n_actions", Snapshot.int n_actions);
      ("rng", Snapshot.rng rng);
      ("config", config_to_snapshot config);
      ("optimizer_state", Optimizer.serialize opt_state);
      ("q_params", Snapshot.ptree q_params);
      ("target_params", Snapshot.ptree target_params);
    ]
122+
(** [of_snapshot ~q_network ~optimizer snapshot] rebuilds an agent from a
    snapshot record laid out as by {!to_snapshot}.

    The network module and optimizer algorithm are not stored in the
    snapshot and must be supplied by the caller; [q_network] is used for
    both the online and the target network. The replay buffer is created
    anew with the restored [buffer_capacity]; its previous contents are not
    restored.

    A snapshot without a schema entry is accepted; a schema entry that is
    present must be a string equal to {!dqn_schema_value}.

    @return
      [Error msg] on the first problem found, with [msg] always prefixed by
      ["Dqn.of_snapshot: "] and, for per-field decoding failures, naming the
      offending field. Integer-valued fields reject non-integral floats
      instead of truncating them. *)
let of_snapshot ~(q_network : module_) ~(optimizer : Optimizer.algorithm)
    (snapshot : Snapshot.t) : (t, string) result =
  let ( let* ) = Result.bind in
  let prefix = "Dqn.of_snapshot: " in
  let error msg = Error (prefix ^ msg) in
  (* Tag a decoder's bare error message with the prefix and field name. *)
  let in_field field outcome =
    Result.map_error
      (fun msg -> Printf.sprintf "%sfield %s: %s" prefix field msg)
      outcome
  in
  let open Snapshot in
  match snapshot with
  | Record record ->
      let validate_schema () =
        match Snapshot.Record.find_opt dqn_schema_key record with
        | None -> Ok ()
        | Some (Scalar (String value)) ->
            if String.equal value dqn_schema_value then Ok ()
            else error ("unsupported schema " ^ value)
        | Some _ -> error "invalid schema field"
      in
      let* () = validate_schema () in
      let find field =
        match Snapshot.Record.find_opt field record with
        | Some value -> Ok value
        | None -> error ("missing field " ^ field)
      in
      (* Look a field up, then run [decoder] on it with errors attributed to
         that field. *)
      let decode field decoder =
        let* node = find field in
        in_field field (decoder node)
      in
      let decode_int = function
        | Scalar (Int value) -> Ok value
        (* [Float.is_integer] is false for NaN/infinities and fractional
           values, so only lossless conversions are accepted. *)
        | Scalar (Float value) when Stdlib.Float.is_integer value ->
            Ok (int_of_float value)
        | _ -> Error "expected int scalar"
      in
      let decode_rng = function
        | Scalar (Int seed) -> Ok (Rune.Rng.key seed)
        | Scalar (Float value) when Stdlib.Float.is_integer value ->
            Ok (Rune.Rng.key (int_of_float value))
        | _ -> Error "expected rng scalar"
      in
      let* n_actions = decode "n_actions" decode_int in
      let* rng = decode "rng" decode_rng in
      let* config = decode "config" config_of_snapshot in
      let* q_params = decode "q_params" Snapshot.to_ptree in
      let* target_params = decode "target_params" Snapshot.to_ptree in
      let* opt_state = decode "optimizer_state" (Optimizer.restore optimizer) in
      let replay_buffer =
        Fehu.Buffer.Replay.create ~capacity:config.buffer_capacity
      in
      Ok
        {
          q_network;
          q_params;
          target_network = q_network;
          target_params;
          optimizer;
          opt_state;
          replay_buffer;
          rng;
          n_actions;
          config;
        }
  | _ -> error "expected snapshot record"
196+
50197let create ~q_network ~n_actions ~rng config =
51198 let keys = Rune.Rng. split ~n: 2 rng in
52199
@@ -356,81 +503,19 @@ let learn t ~env ~total_timesteps
356503
357504 t
358505
(** [save t dir] writes agent [t] into directory [dir], creating [dir] with
    mode [0o755] when no file of that name exists.

    Three files are written: [q_params.safetensors] and
    [target_params.safetensors] hold the online and target parameters, and
    [metadata.json] holds the number of actions, the RNG seed and every
    hyper-parameter of [t.config].

    The metadata channel is closed even when JSON serialization raises; the
    exception is then re-raised. *)
let save t dir =
  if not (Sys.file_exists dir) then Unix.mkdir dir 0o755;
  let checkpointer = Kaun.Checkpoint.Checkpointer.create () in
  Kaun.Checkpoint.Checkpointer.save_file checkpointer
    ~path:(Filename.concat dir "q_params.safetensors")
    ~params:t.q_params ();
  Kaun.Checkpoint.Checkpointer.save_file checkpointer
    ~path:(Filename.concat dir "target_params.safetensors")
    ~params:t.target_params ();
  let rng_seed = Rune.Rng.to_int t.rng in
  let metadata =
    `Assoc
      [
        ("n_actions", `Int t.n_actions);
        ("rng_seed", `Int rng_seed);
        ("learning_rate", `Float t.config.learning_rate);
        ("gamma", `Float t.config.gamma);
        ("epsilon_start", `Float t.config.epsilon_start);
        ("epsilon_end", `Float t.config.epsilon_end);
        ("epsilon_decay", `Float t.config.epsilon_decay);
        ("batch_size", `Int t.config.batch_size);
        ("buffer_capacity", `Int t.config.buffer_capacity);
        ("target_update_freq", `Int t.config.target_update_freq);
      ]
  in
  let metadata_path = Filename.concat dir "metadata.json" in
  let oc = open_out metadata_path in
  (* Close on both paths: a normal close reports flush errors, while the
     failure path closes quietly so the original exception is preserved. *)
  match Yojson.Basic.to_channel oc metadata with
  | () -> close_out oc
  | exception exn ->
      close_out_noerr oc;
      raise exn
385-
(** [load dir] rebuilds an agent from the files found in directory [dir]:
    hyper-parameters, action count and RNG seed come from [metadata.json];
    online and target parameters come from [q_params.safetensors] and
    [target_params.safetensors], restored as [float32].

    The Q-network architecture is fixed in this function, the optimizer is
    Adam at the stored learning rate with state initialised from the
    restored Q parameters, and the replay buffer is created anew. *)
let load dir =
  let in_dir name = Filename.concat dir name in
  let metadata = Yojson.Basic.from_file (in_dir "metadata.json") in
  let float_field key =
    Yojson.Basic.Util.to_float (Yojson.Basic.Util.member key metadata)
  in
  let int_field key =
    Yojson.Basic.Util.to_int (Yojson.Basic.Util.member key metadata)
  in
  let config =
    {
      learning_rate = float_field "learning_rate";
      gamma = float_field "gamma";
      epsilon_start = float_field "epsilon_start";
      epsilon_end = float_field "epsilon_end";
      epsilon_decay = float_field "epsilon_decay";
      batch_size = int_field "batch_size";
      buffer_capacity = int_field "buffer_capacity";
      target_update_freq = int_field "target_update_freq";
    }
  in
  let n_actions = int_field "n_actions" in
  let rng = Rune.Rng.key (int_field "rng_seed") in
  let checkpointer = Kaun.Checkpoint.Checkpointer.create () in
  (* Example: fixed architecture for 2 input features and n_actions *)
  let q_network =
    let module Layer = Kaun.Layer in
    Layer.sequential
      [
        Layer.linear ~in_features:2 ~out_features:8 ();
        Layer.relu ();
        Layer.linear ~in_features:8 ~out_features:n_actions ();
      ]
  in
  let restore file_name =
    Kaun.Checkpoint.Checkpointer.restore_file checkpointer
      ~path:(in_dir file_name) ~dtype:Rune.float32
  in
  let q_params = restore "q_params.safetensors" in
  let target_params = restore "target_params.safetensors" in
  let optimizer = Optimizer.adam ~lr:config.learning_rate () in
  let opt_state = optimizer.init q_params in
  let replay_buffer =
    Fehu.Buffer.Replay.create ~capacity:config.buffer_capacity
  in
  {
    q_network;
    q_params;
    target_network = q_network;
    target_params;
    optimizer;
    opt_state;
    replay_buffer;
    rng;
    n_actions;
    config;
  }
(** [save_to_file t ~path] writes the snapshot of agent [t] (see
    {!to_snapshot}) to [path].

    @raise Failure
      with a message starting with ["Dqn.save_to_file: "] when the
      checkpoint writer reports an error. *)
let save_to_file (t : t) ~path =
  (* The snapshot is built lazily, only when the writer asks for it. *)
  let encode () = to_snapshot t in
  match Checkpoint.write_snapshot_file_with ~path ~encode with
  | Error err ->
      let reason = Checkpoint.error_to_string err in
      failwith ("Dqn.save_to_file: " ^ reason)
  | Ok () -> ()
513+
(** [load_from_file ~path ~q_network ~optimizer] reads the snapshot stored at
    [path] and decodes it with {!of_snapshot}, using the supplied network
    module and optimizer algorithm.

    @return
      [Ok agent] on success, or [Error msg] where [msg] is the checkpoint
      error rendered with [Checkpoint.error_to_string]. *)
let load_from_file ~path ~(q_network : module_)
    ~(optimizer : Optimizer.algorithm) =
  let decode snapshot = of_snapshot ~q_network ~optimizer snapshot in
  Checkpoint.load_snapshot_file_with ~path ~decode
  |> Result.map_error Checkpoint.error_to_string
0 commit comments