Thanks for your reply. `rollout_model` isn't static, so generating offline won't work. I found a solution: pass a function decorated with `nnx.pmap(backend="cpu")` to `jax.pure_callback`. Before calling it, I reshape the input so that its leading axis matches the device count set via `--xla_force_host_platform_device_count`. With this, all CPU cores show high utilization!

The key takeaway is summarized here: jax-ml/jax#5022 (comment)
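Below is a minimal sketch of that pattern under stated assumptions. `rollout_step` is a hypothetical stand-in for the per-device work of `rollout_model`. Plain `jax.pmap(..., backend="cpu")` is used in place of `nnx.pmap` so the sketch has no Flax dependency. The device count of 4 is arbitrary. Note that `XLA_FLAGS` must be set before JAX initializes its backends, which in practice means before `import jax`. The batch size is assumed to be divisible by the device count.

```python
import os

# Must be set before JAX initializes; 4 host "devices" is an arbitrary choice.
N_CPU = 4
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={N_CPU}"

import numpy as np
import jax
import jax.numpy as jnp


def rollout_step(x):
    # Hypothetical stand-in for the per-device work done by rollout_model.
    return jnp.tanh(x) * 2.0


# jax.pmap swapped in for nnx.pmap; each CPU "device" gets one slice of axis 0.
cpu_rollout = jax.pmap(rollout_step, backend="cpu")


def _host_rollout(x):
    # Runs on the host: x arrives as a NumPy array of shape (batch, features).
    batch = x.shape[0]
    assert batch % N_CPU == 0, "batch must be divisible by the CPU device count"
    # Split the leading axis into (devices, batch_per_device, features).
    sharded = x.reshape(N_CPU, batch // N_CPU, *x.shape[1:])
    out = cpu_rollout(sharded)
    # Merge the device axis back into the batch axis.
    return np.asarray(out).reshape(x.shape).astype(x.dtype)


@jax.jit
def rollout(x):
    # pure_callback needs the output shape/dtype declared up front.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_rollout, result_shape, x)
```

Calling `rollout(jnp.ones((8, 3)))` from jitted code then fans the batch out across the 4 host devices inside the callback.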

```python
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless requir…
```

@DBraun

Answer selected by DBraun
Category
Q&A
2 participants