-
Notifications
You must be signed in to change notification settings - Fork 185
Description
File "zoo/metadrive/config/metadrive_sampled_efficientzero_config.py", line 103, in <module>
train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step)
File "/root/autodl-tmp/LightZero/lzero/entry/train_muzero.py", line 154, in train_muzero
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
File "/root/autodl-tmp/LightZero/lzero/worker/muzero_evaluator.py", line 287, in eval
policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id)
File "/root/autodl-tmp/LightZero/lzero/policy/sampled_efficientzero.py", line 963, in _forward_eval
network_output = self._eval_model.initial_inference(data)
File "/root/autodl-tmp/LightZero/lzero/model/sampled_efficientzero_model.py", line 270, in initial_inference
latent_state = self._representation(obs)
File "/root/autodl-tmp/LightZero/lzero/model/sampled_efficientzero_model.py", line 330, in _representation
latent_state = self.representation_network(observation)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/autodl-tmp/LightZero/lzero/model/common.py", line 646, in forward
x = self.downsample_net(x)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/autodl-tmp/LightZero/lzero/model/common.py", line 336, in forward
x = self.conv1(x)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 458, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [32, 5, 3, 3], expected input[3, 84, 84, 5] to have 5 channels, but got 84 channels instead
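There's something wrong when running metadrive_sampled_efficientzero_config.py. The error suggests that, during evaluation, the observation reaches the first convolution in channels-last layout (3, 84, 84, 5), while the layer (weight [32, 5, 3, 3]) expects channels-first input of shape (N, 5, H, W).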