From 83f10f0b294915116230bbed17926195a67e4bae Mon Sep 17 00:00:00 2001
From: dyyoungg
Date: Sun, 14 Jul 2024 18:24:41 +0800
Subject: [PATCH] fix(pu): fix empty_keys_values in init_infer

---
 lzero/entry/train_unizero.py                    | 2 +-
 lzero/model/unizero_world_models/world_model.py | 2 +-
 zoo/atari/config/atari_unizero_config.py        | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py
index 969b1f947..b8f9e4484 100644
--- a/lzero/entry/train_unizero.py
+++ b/lzero/entry/train_unizero.py
@@ -61,7 +61,7 @@ def train_unizero(
                                                 game_buffer_classes[create_cfg.policy.type])
 
     # Set device based on CUDA availability
-    cfg.policy.device = cfg.policy.model.world_model.device if torch.cuda.is_available() else 'cpu'
+    cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
     logging.info(f'cfg.policy.device: {cfg.policy.device}')
 
     # Compile the configuration
diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py
index e4edcef51..ef31d951c 100644
--- a/lzero/model/unizero_world_models/world_model.py
+++ b/lzero/model/unizero_world_models/world_model.py
@@ -430,7 +430,7 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor
         if current_obs_embeddings is not None:
             if max(buffer_action) == -1:
                 # First step in an episode
-                self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n,
+                self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0],
                                                                                   max_tokens=self.context_length)
                 # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}")
                 outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings},
diff --git a/zoo/atari/config/atari_unizero_config.py b/zoo/atari/config/atari_unizero_config.py
index 844c97e4e..1c549010f 100644
--- a/zoo/atari/config/atari_unizero_config.py
+++ b/zoo/atari/config/atari_unizero_config.py
@@ -54,8 +54,8 @@
             max_blocks=num_unroll_steps,
             max_tokens=2 * num_unroll_steps,  # NOTE: each timestep has 2 tokens: obs and action
             context_length=2 * infer_context_length,
-            # device='cuda',
-            device='cpu',
+            device='cuda',
+            # device='cpu',
             action_space_size=action_space_size,
             num_layers=4,
             num_heads=8,