33import jax
44import numpy as np
55
6+ from keras .src .backend .common import global_state
7+ from keras .src .random import seed_generator
68from keras .src .utils import jax_utils
9+ from keras .src .utils import rng_utils
710
811
912def list_devices (device_type = None ):
@@ -185,6 +188,52 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name):
185188 return global_batch_array
186189
187190
def initialize_rng():
    """Synchronize the global random seed across processes.

    Consistent seeding is required for consistent initialization in
    multi-host settings. When no seed has been set through
    `keras.config.set_random_seed()`, a default seed is generated locally,
    summed over all CPU devices, and installed as the global seed. If a
    global seed generator already exists but its configured seed is `None`,
    it is replaced by a generator built from the global seed with the same
    name and backend.
    """
    shared_seed = rng_utils.get_random_seed()

    # A seed that was already set (e.g. by the user) is left untouched.
    if shared_seed is None:
        all_cpus = jax.devices("cpu")
        local_cpu_count = jax.local_device_count("cpu")

        # One locally generated seed, replicated once per local CPU device.
        # uint32 keeps the value inside [0, 2^32 - 1] and avoids signed
        # integer overflow when the seeds are summed.
        host_seed = seed_generator.make_default_seed()
        per_device_seeds = jax.numpy.asarray(
            [host_seed] * local_cpu_count,
            dtype=jax.numpy.uint32,
        )

        def _sum_over_devices(x):
            return jax.lax.psum(x, "all")

        # Every device ends up holding the same sum, so reading the first
        # entry yields a single seed that all processes agree on.
        summed = jax.pmap(
            _sum_over_devices,
            axis_name="all",
            devices=all_cpus,
        )(per_device_seeds)
        shared_seed = summed.item(0)
        rng_utils.set_random_seed(shared_seed)

    # An existing global seed generator without an initialized seed is
    # rebuilt from the shared seed; anything else is left alone.
    existing = global_state.get_global_attribute("global_seed_generator")
    if existing is None:
        return
    if existing.get_config()["seed"] is not None:
        return

    replacement = seed_generator.SeedGenerator(
        seed=shared_seed,
        name=existing.name,
        backend=existing.backend,
    )
    global_state.set_global_attribute("global_seed_generator", replacement)
235+
236+
188237def initialize (job_addresses , num_processes , process_id ):
189238 if job_addresses and "," in job_addresses :
190239 # When user provide all the job addresses, we will split and get the
@@ -208,6 +257,9 @@ def initialize(job_addresses, num_processes, process_id):
208257 process_id = process_id ,
209258 )
210259
260+ # Ensure the random number generator is initialized across processes.
261+ initialize_rng ()
262+
211263
212264def num_processes ():
213265 """Return the number of processes for the current distribution setting."""
0 commit comments