I want to have a JAX or NNX jitted function that consumes and returns GPU-sharded tensors. However, inside the function, I also want to perform some CPU operations on the tensors, and I want to see high CPU usage among the cores according to my Task Manager.

I have a PPO agent implemented in NNX. The RL environment is also an NNX Module. For a given action, the environment produces an observation of audio. Let's suppose the audio is produced highly sequentially, like by an RNN. Because it's so sequential and we don't even need gradients at this step, it may actually be optimal to produce the audio on the CPU despite the GPU-to-CPU-to-GPU communication overhead. This audio production step is the rollout subroutine in the train_step function. Once the audio is back on the GPU, the GPU-based agent policy/critic are updated. The entire train_step function is decorated with both nnx.jit and nnx.shard_map.

I think the code below is a minimal reproduction of the concept above. It's adapted from 04_data_parallel_with_jit.py. The loss function isn't meant to go down since it's a bogus model. The issue is that I don't see high usage among many of my CPU cores: only 2 or so are at full throttle and another 2 are at medium usage. A larger batch size of 128 doesn't change this. My hope was that JAX would be good at doing SIMD CPU operations for the RNN. How can I use more of my CPU in this example? Note that ideally we would also nnx.jit(device="cpu")-decorate the rollout subroutine.

```python
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
import jax
# jax.config.update("jax_platform_name", "cpu")
from jax import lax, numpy as jnp
import numpy as np
import optax
from flax import nnx
from jax.experimental import mesh_utils
# create a mesh + shardings
num_devices = jax.local_device_count()
mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh((num_devices,)), ('data',)
)
rep_spec = jax.sharding.PartitionSpec()
dp_spec = jax.sharding.PartitionSpec('data')
model_sharding = jax.NamedSharding(mesh, rep_spec)
data_sharding = jax.NamedSharding(mesh, dp_spec)
rollout_model = nnx.RNN(nnx.GRUCell(in_features=1, hidden_features=1, rngs=nnx.Rngs(0)))
rollout_model.eval()
train_model = nnx.Linear(in_features=1, out_features=1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(train_model, optax.adamw(1e-2))
# replicate state
state1 = jax.device_put(nnx.state((rollout_model, optimizer)), model_sharding)
state2 = jax.device_put(nnx.state((train_model, optimizer)), model_sharding)
nnx.update((rollout_model, optimizer), state1)
nnx.update((train_model, optimizer), state2)
# visualize model sharding
print('model sharding')
jax.debug.visualize_array_sharding(train_model.kernel.value)
batch_size = 64 # pick based on number of cpu cores
def move_model_to_cpu(model):
    cpus = jax.devices("cpu")
    graph_def, state = nnx.split(model)
    graph_def, state = jax.device_put((graph_def, state), cpus[0])
    model_cpu = nnx.merge(graph_def, state)
    return model_cpu
@nnx.jit
@nnx.shard_map(
    mesh=mesh,
    in_specs=(rep_spec, rep_spec, dp_spec, dp_spec),
    out_specs=rep_spec,
)
def train_step(model: nnx.Linear, optimizer: nnx.Optimizer, x, y):
    def loss_fn(model: nnx.Linear):
        # some arbitrary gpu math:
        x_sq = x * x
        # the rest should be on CPU:
        cpus = jax.devices("cpu")
        rollout_model_cpu = move_model_to_cpu(rollout_model)
        x_cpu = jax.device_put(x_sq, cpus[0])
        # this step should be slow and intense on many CPU cores:
        y_pred_cpu = rollout_model_cpu(x_cpu)  # [B, T, 1]
        assert y_pred_cpu.ndim == 3 and y_pred_cpu.shape[-1] == 1
        # back to GPU:
        gpus = jax.devices("gpu")
        y_pred = jax.device_put(y_pred_cpu, gpus[0])
        y_pred = lax.stop_gradient(y_pred)  # don't want to backprop through cpu operations.
        # this forward/backprop step should be fast since it's non-sequential
        y_pred = model(y_pred)
        return jnp.mean((y - y_pred) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    loss, grads = lax.pmean((loss, grads), 'data')
    optimizer.update(grads)
    return loss
def dataset(steps, batch_size):
    for _ in range(steps):
        x = np.random.uniform(-2, 2, size=(batch_size, 200_000, 1))
        y = x + 0.1 + np.random.normal(0, 0.1, size=x.shape)
        yield x, y
for step, (x, y) in enumerate(dataset(1000, batch_size)):
    # shard data
    x, y = jax.device_put((x, y), data_sharding)
    # train
    loss = train_step(train_model, optimizer, x, y)
    if step == 0:
        print('data sharding')
        jax.debug.visualize_array_sharding(jnp.squeeze(x, axis=-1))
    if step % 1 == 0:
        print(f'step={step}, loss={loss}')
```

output:
Replies: 1 comment 1 reply
@DBraun If I understand the code correctly, you would like to generate the training data using the rollout_model. Could you generate that data offline instead?
Thanks for your reply. rollout_model isn't static, so generating offline won't work. I found a solution by passing an nnx.pmap(backend="cpu")-decorated function to jax.pure_callback. Before calling it, I have to reshape the input so that its leading axis matches what is set via --xla_force_host_platform_device_count. I see all CPU cores in high usage! The key takeaway is summarized here: jax-ml/jax#5022 (comment)
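Roughly, the pattern looks like the following. This is a simplified sketch rather than my exact code: it assumes XLA_FLAGS is set before jax is imported, that the batch size divides evenly by the forced host device count, and that rollout_model_cpu is a CPU copy of the rollout model (e.g. made with the move_model_to_cpu helper above). The names rollout_cpu, rollout_callback, and rollout are placeholders, and the exact handling of module arguments by nnx.pmap may differ between Flax versions.

```python
import os
# Must be set before importing jax so the CPU backend exposes multiple "devices".
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import numpy as np
from flax import nnx

num_cpu_devices = jax.local_device_count(backend='cpu')  # 8 with the flag above

# pmap over the CPU backend: each forced host device gets one slice of the batch,
# so the sequential RNN rollout runs on several cores in parallel.
# (Assumes nnx.pmap broadcasts the module when its in_axes entry is None.)
@nnx.pmap(backend='cpu', in_axes=(None, 0), out_axes=0)
def rollout_cpu(model, x):  # x: [num_cpu_devices, B // num_cpu_devices, T, 1]
    return model(x)

def rollout_callback(x):
    # Runs on the host via jax.pure_callback; x arrives as a NumPy array.
    b, t, f = x.shape  # b must be divisible by num_cpu_devices
    x = x.reshape(num_cpu_devices, b // num_cpu_devices, t, f)
    y = rollout_cpu(rollout_model_cpu, x)  # rollout_model_cpu: CPU copy of the GRU rollout model
    return np.asarray(y).reshape(b, t, -1)

def rollout(x):
    # Callable from inside the jitted/shard_mapped train_step. With
    # hidden_features=1 the GRU output has the same [B, T, 1] shape as its input.
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(rollout_callback, out_shape, x)
```

Inside the nnx.jit/nnx.shard_map-decorated train_step, rollout(x_sq) replaces the jax.device_put round trip; the result comes back as an ordinary traced array that can still be wrapped in lax.stop_gradient before the GPU model consumes it.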