Description
Would it be possible to convert a `pyspiel` game's `State` object to a dictionary of array-likes and back again, in an efficient way? If that is currently not supported, would it be possible to add this feature?
At the moment, it seems to me that this is not possible. `pyspiel` states are implemented in C++ and bound to Python with pybind11, and none of the bound methods or properties appear to provide a dict-of-arrays representation of the state.
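Until something like this is bound, one workaround that seems to need only existing `pyspiel` methods is to store the action history as a fixed-length integer array and rebuild the state by replaying it from `game.new_initial_state()`. Below is a minimal sketch; `state_to_dict` and `dict_to_state` are hypothetical names, and it assumes a state is fully determined by `state.history()` (which, as far as I know, also records chance outcomes). Replaying costs O(history length) per conversion, so this is not the efficient conversion asked for above.

```python
import numpy as np


def state_to_dict(state, max_len):
    """Encode a pyspiel-style state as a dict of fixed-shape NumPy arrays.

    Hypothetical helper: it only calls state.history(). The history is padded
    with -1 up to max_len so every encoded state has the same shape, which is
    what batching (e.g. in a TensorDict) needs.
    """
    history = np.asarray(state.history(), dtype=np.int64)
    if history.size > max_len:
        raise ValueError(f"history length {history.size} exceeds max_len={max_len}")
    padded = np.full(max_len, -1, dtype=np.int64)
    padded[: history.size] = history
    return {"history": padded, "length": np.asarray(history.size, dtype=np.int64)}


def dict_to_state(game, encoded):
    """Rebuild a state by replaying the stored actions from the initial state."""
    state = game.new_initial_state()
    for action in encoded["history"][: int(encoded["length"])]:
        state.apply_action(int(action))
    return state
```

With `pyspiel`, this would be called as `dict_to_state(game, state_to_dict(state, max_len))`; for a game without chance nodes such as `tic_tac_toe`, `game.max_game_length()` should be a suitable `max_len`.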
I'm asking about this because I am looking into adding an environment wrapper class for OpenSpiel to TorchRL. Ideally, the wrapper would be stateless, so the state would need to be provided to the wrapper's step function as part of a TensorDict, which is a dictionary of array-likes.
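For the stateless wrapper described above, a step function could take the encoded state in and return a new encoded state without keeping anything between calls. A minimal sketch, assuming the bound `state.serialize()` / `game.deserialize_state()` pair is used as the encoding, with the serialized string stored as a `uint8` array; `encode_state` and `stateless_step` are hypothetical names, and the output keys are illustrative rather than TorchRL's conventions. One drawback is that the byte array's length varies between states, so it would need padding before states could be stacked into a batched TensorDict.

```python
import numpy as np


def encode_state(state):
    """Store pyspiel's string serialization of a state as a uint8 array."""
    return np.frombuffer(state.serialize().encode("utf-8"), dtype=np.uint8).copy()


def stateless_step(game, td, action):
    """Apply one action to the state held in `td` and return a new dict.

    Nothing is cached between calls: the state is rebuilt from
    td["state_bytes"] on every step, and the input dict is not modified.
    """
    state = game.deserialize_state(td["state_bytes"].tobytes().decode("utf-8"))
    state.apply_action(action)
    return {
        "state_bytes": encode_state(state),
        # Per-player rewards for this transition.
        "reward": np.asarray(state.rewards(), dtype=np.float32),
        "done": np.asarray(state.is_terminal()),
    }
```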
Some other RL environment libraries, such as Brax and Jumanji, support dict-of-arrays representations. For example, in Jumanji:
```python
import jax
import jax.numpy as jnp
import jumanji

env = jumanji.make('Snake-v1')
key = jax.random.PRNGKey(0)
state, _ = env.reset(key)

def state_to_dict_of_arrays(state):
    res = {}
    for name, value in state.items():
        if hasattr(value, '_fields'):
            # NamedTuple fields (e.g. positions) become a nested dict of arrays.
            res[name] = {}
            for field in value._fields:
                res[name][field] = jnp.asarray(getattr(value, field))
        else:
            res[name] = jnp.asarray(value)
    return res

state_to_dict_of_arrays(state)
```

```
{'body': Array([[False, False, False, False, False, False, False, False, False,
        False, False, False],
       ...
       [False, False, False, False, False, False, False, False, False,
        False, False, False]], dtype=bool),
 'body_state': Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 ...
   'col': Array([2, 4], dtype=int32)},
 'length': Array(1, dtype=int32),
 'step_count': Array(0, dtype=int32),
 'action_mask': Array([ True,  True,  True,  True], dtype=bool),
 'key': Array([2467461003,  428148500], dtype=uint32)}
```
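The example above only covers the state-to-dict direction. Going back again could look like the following minimal sketch, assuming the Jumanji state is a dataclass (chex dataclasses, which Jumanji states appear to use, are built on the standard `dataclasses` module) and that nested dicts correspond to NamedTuple fields; `dict_of_arrays_to_state` is a hypothetical name, and an existing state is passed in as a template for the types.

```python
import dataclasses


def dict_of_arrays_to_state(template, d):
    """Rebuild a state from the nested dict produced by state_to_dict_of_arrays.

    Assumes `template` is a dataclass instance and that each nested dict
    corresponds to a NamedTuple field of the template.
    """
    changes = {}
    for name, value in d.items():
        old = getattr(template, name)
        if isinstance(value, dict) and hasattr(old, '_fields'):
            # Rebuild the NamedTuple (e.g. a position) from its field dict.
            changes[name] = type(old)(**value)
        else:
            changes[name] = value
    # dataclasses.replace returns a new instance; `template` is left unchanged.
    return dataclasses.replace(template, **changes)
```

Usage would be `restored = dict_of_arrays_to_state(state, state_to_dict_of_arrays(state))`.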