Skip to content

Commit

Permalink
init state in cpu first and move it to neuron afterwards
Browse files Browse the repository at this point in the history
  • Loading branch information
penxujun committed Mar 14, 2024
1 parent 1a0b9a3 commit 8ac9323
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions axlearn/common/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,16 +573,44 @@ def _init_state(prng_key: Tensor, prebuilt_model_state: NestedTensor):
learner=learner_params,
)

def _init_state_cpu(prng_key: Tensor, prebuilt_model_state: NestedTensor):
prng_key, init_key = jax.random.split(prng_key)

cpu_device = jax.devices("cpu")[0]
with jax.default_device(cpu_device):
logging.info("prebuilt_model_state: %s", utils.shapes(prebuilt_model_state))
model_params = self.model.initialize_parameters_recursively(
init_key,
prebuilt=prebuilt_model_state,
)

return prng_key, model_params

def _move_state_to_neuron(prng_key: Tensor, model_params):
model_params = jax.device_put(model_params)
self.vlog(
1, "tree_structure(model_params)=%s", jax.tree_util.tree_structure(model_params)
)
learner_params = self.learner.init(self._opt_params(model_params))
return TrainerState(
prng_key=prng_key,
model=model_params,
learner=learner_params,
)

logging.info("prebuilt_model_state_partition_spec: %s", prebuilt_model_state_partition_spec)
logging.info("trainer_state_partition_specs: %s", self._trainer_state_partition_specs)
init_computation = pjit(
_init_state,
#_init_state,
_move_state_to_neuron,
in_shardings=(None, prebuilt_model_state_partition_spec),
out_shardings=self._trainer_state_partition_specs,
)
self._step_log("Initializing trainer state.")
with self.mesh():
self._trainer_state = init_computation(prng_key, prebuilt_model_state)
#self._trainer_state = init_computation(prng_key, prebuilt_model_state)
prng_key, model_params = _init_state_cpu(prng_key, prebuilt_model_state)
self._trainer_state = init_computation(prng_key, model_params)

def _log_trainer_state_stats(self):
total_num_params = count_model_params(self._trainer_state.model)
Expand Down

0 comments on commit 8ac9323

Please sign in to comment.