Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TMP: feature(pu): add sampled_unizero multitask pipeline #311

Open
wants to merge 138 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
138 commits
Select commit Hold shift + click to select a range
dd2c95c
feature(pu): add UniZero multitask related pipeline
puyuan1996 Jul 5, 2024
8769a5c
polish(pu): polish unizero_multitask config
dyyoungg Jul 8, 2024
c342ce1
fix(pu): fix empty_keys_values in init_infer
dyyoungg Jul 11, 2024
6eb772a
feature(pu): add softmoe head option in unizero_multitask
dyyoungg Jul 11, 2024
71f55b4
fix(pu): fix unizero reset in muzero_collector
dyyoungg Jul 12, 2024
445fd70
polish(pu): polish unizero-multitask config
dyyoungg Jul 14, 2024
4954581
fix(pu): fix replay ratio
dyyoungg Jul 16, 2024
44304bf
feature(pu): add moe option of feedforward in transformer backbone
dyyoungg Jul 16, 2024
d6be21a
feature(pu): add value_priority in unizero_multitask
dyyoungg Jul 17, 2024
fde51cc
polish(pu): polish value_priority in unizero_multitask
dyyoungg Jul 17, 2024
b460d2f
sync code
dyyoungg Jul 18, 2024
5117459
fix(pu): fix moe in feedforward layer of transformer and polish configs
dyyoungg Jul 19, 2024
2495d60
feature(pu): add mistralai moe in transformer feedforward and head of…
dyyoungg Jul 23, 2024
95886bd
polish(pu): polish quantize_state_hash and deepcopy
PaParaZz1 Aug 18, 2024
0e49a30
fix(pu): fix np.array dtype bug in buffer
PaParaZz1 Aug 18, 2024
00147f4
polish(pu): use 0 deepcopy in kv_cache operation in collect/eval phas…
PaParaZz1 Aug 19, 2024
b40c71b
polish(pu): use custom deepcopy for kv_cache
PaParaZz1 Aug 22, 2024
2cc81be
polish(pu): use value_array rather than value_list in compute_target_…
PaParaZz1 Aug 22, 2024
bc5332f
polish(pu): optimize compute_target_policy_non_re
PaParaZz1 Aug 22, 2024
a6c6a8e
polish(pu): optimize kv_caching update()
PaParaZz1 Aug 22, 2024
b5dcdcc
polish(pu): kv_cache_dict no to_cpu
PaParaZz1 Aug 22, 2024
5b0cbd4
polish(pu): optimize custom kv_cache copy
PaParaZz1 Aug 22, 2024
0035829
polish(pu): kv_cache_dict no to_cpu
PaParaZz1 Aug 22, 2024
043727b
feature(pu): add unizero ddp config
PaParaZz1 Aug 23, 2024
d568008
fix(pu): fix unizero ddp
dyyoungg Aug 23, 2024
d349137
sync code
dyyoungg Aug 25, 2024
3a344aa
polish(pu): use de kv_cacheepcopy only in recur_infer load
PaParaZz1 Aug 26, 2024
40053f7
Merge branch 'dev-efficiency' of https://github.com/opendilab/LightZe…
dyyoungg Aug 26, 2024
61a1139
sync code
dyyoungg Aug 26, 2024
0e545c7
polish(pu): polish suz dmc config
dyyoungg Aug 26, 2024
bb38a10
sync code
dyyoungg Aug 26, 2024
b813be7
Merge branch 'dev-efficiency' of https://github.com/opendilab/LightZe…
dyyoungg Aug 27, 2024
715d17e
polish(pu): use share_polol for kv_cache in recurrent_inference and u…
jiayilee65 Aug 27, 2024
39d6bbe
polish(pu): all kv_cache copy use predefined share_pool
jiayilee65 Aug 27, 2024
f18be2a
polish(pu): unuse decoder_net and lpips in ddp config
dyyoungg Aug 28, 2024
1d010d3
sync code
dyyoungg Aug 28, 2024
a2ce5a8
feature(pu): add dmc save_replay_gif option
puyuan1996 Aug 30, 2024
abf8924
sync code
dyyoungg Sep 2, 2024
9de4096
polish(pu): polish sampled muzero ctree
dyyoungg Sep 2, 2024
54416e6
Merge branch 'polish-unizero-cont' of https://github.com/opendilab/Li…
dyyoungg Sep 3, 2024
5804ff2
test(pu): add sac cheetah config
dyyoungg Sep 3, 2024
fabffd2
fix(pu): fix render_image in dmc_env
dyyoungg Sep 3, 2024
7d0d4c7
fix(pu): fix reanalyze in sampled unizero
dyyoungg Sep 3, 2024
e666d12
polish(pu): polish policy projector
dyyoungg Sep 4, 2024
fea98ee
feature(pu): add muzero_segment_collector.py
dyyoungg Sep 5, 2024
0391f4c
polish(pu): use uniform prior in ucb_score of suz mcts
dyyoungg Sep 5, 2024
2a376ec
fix(pu): fix self.action_mask_dict init bug
dyyoungg Sep 5, 2024
0121b63
test(pu): use clamp0.9->1
dyyoungg Sep 10, 2024
ed4773b
polish(pu): polish suz
dyyoungg Sep 12, 2024
d36196e
fix(pu): fix muzero_segment_collector
dyyoungg Sep 12, 2024
51e10f2
fix(pu): uz target-value obs also use aug when use_aug=True
dyyoungg Sep 12, 2024
31543c3
sync code
dyyoungg Sep 13, 2024
8615899
fix(pu): fix last_game_segment bug in muzero_segment_collector.py
dyyoungg Sep 13, 2024
4c969f6
fix(pu): one episode done then return in muzero_segment_collector.py
dyyoungg Sep 13, 2024
cf2fd81
fix(pu): fix muzero_collector
dyyoungg Sep 14, 2024
bff16f7
polish(pu): polish unizero config and polish sample from segments
dyyoungg Sep 16, 2024
f0ff953
fix(pu): fix reanalyze in uz
dyyoungg Sep 17, 2024
91d48c1
polish(pu): add batch config and bash
dyyoungg Sep 17, 2024
1c8b92b
polish(pu): polish uz configs
dyyoungg Sep 20, 2024
1eed401
feature(pu): add unizero buffer_reanalyze variant
dyyoungg Sep 20, 2024
380f693
fix(pu): fix uz reanalyze_buffer
dyyoungg Sep 20, 2024
c43cdd4
polish(pu): polish configs
dyyoungg Sep 22, 2024
05a2ec3
feature(pu): add atari_muzero_segment_config
dyyoungg Sep 23, 2024
639c2e1
Merge branch 'dev-efficieny-plus-tune-uz-atari100k' of https://github…
dyyoungg Sep 23, 2024
81b47d2
fix(pu): fix sampled_unizero reanalyze_policy
dyyoungg Sep 23, 2024
a78fa70
polish(pu):polish configs
dyyoungg Sep 24, 2024
dc2d454
polish(pu):polish suz configs
dyyoungg Sep 24, 2024
f634e1f
polish(pu):polish configs
dyyoungg Sep 24, 2024
c19b203
Merge branch 'dev-efficieny-plus-tune-uz-atari100k' of https://github…
dyyoungg Sep 24, 2024
f536f3c
fix(pu): fix root value in suz buffer
dyyoungg Sep 25, 2024
f3e6d8d
fix(pu): fix suz ctree
dyyoungg Sep 25, 2024
b7243ea
polish(pu): polish uz related configs, segment collector, train_entry
puyuan1996 Sep 26, 2024
dba9ca7
polish(pu): polish unizero world_model
puyuan1996 Sep 26, 2024
d5fff6d
polish(pu): polish reanalyze in buffer
puyuan1996 Sep 26, 2024
29197d2
fix(pu): fix entry import and nparray object bug in buffer
puyuan1996 Sep 26, 2024
eb268ac
polish(pu): polish configs
puyuan1996 Sep 26, 2024
5e75d09
polish(pu): polish configs
dyyoungg Sep 26, 2024
31221d8
Merge branch 'dev-efficieny-plus-tune-uz-atari100k-polish' of https:/…
dyyoungg Sep 27, 2024
32ad2d0
polish(pu): fix collector, polish configs
dyyoungg Sep 27, 2024
4a18e33
fix(pu): fix truncation segment sample in buffer
dyyoungg Sep 28, 2024
cb37b29
fix(pu): fix segment sample for uz in buffer
dyyoungg Sep 28, 2024
39051c5
fix(pu): use origin buffer
dyyoungg Sep 29, 2024
5842e89
fix(pu): fixvaluebugV8
jiayilee65 Sep 29, 2024
361d6c6
sync code
dyyoungg Sep 30, 2024
dafb655
fix(pu): fix target action when calculating bootstrap value in unizero
dyyoungg Oct 2, 2024
f7792c0
fix(pu): fix target-action in sampled_unizero buffer
dyyoungg Oct 3, 2024
858006c
polish(pu): delete wrongly added files
dyyoungg Oct 3, 2024
1e13f9b
polish(pu): polish entry/buffer/ctree, and fix index+1 bug in compute…
puyuan1996 Oct 3, 2024
0027834
polish(pu): polish buffer and config
puyuan1996 Oct 8, 2024
46b7096
polish(pu): rename train_xxx_reanalyze to train_xxx_segment
puyuan1996 Oct 8, 2024
a2e1611
polish(pu): polish world_model
puyuan1996 Oct 8, 2024
5d384bb
polish(pu): polish entry comments
puyuan1996 Oct 8, 2024
b7f0f0f
fix(pu): fix reward shape bug in dmc
puyuan1996 Oct 9, 2024
ea16e48
fix(pu): polish sample_orig_reanalyze_batch and fix sample_orig_data …
puyuan1996 Oct 12, 2024
b4ae014
polish(pu): polish comments in _sample_orig_reanalyze_batch
puyuan1996 Oct 16, 2024
6abef12
Merge remote-tracking branch 'origin/dev-unizero-multitask-v2' into d…
puyuan1996 Oct 16, 2024
cc7cc66
polish(pu): adapt muzero_multitask to segment_collector
puyuan1996 Oct 16, 2024
43e0966
fix(pu): fix unizero task_id
dyyoungg Oct 16, 2024
ad01aab
fix(pu): fix unizero task_id
dyyoungg Oct 16, 2024
144e7d4
Merge branch 'dev-uz-multitask-v3' of https://github.com/opendilab/Li…
dyyoungg Oct 17, 2024
9bb1189
polish(pu): polish uz_mt config
dyyoungg Nov 4, 2024
04730e1
polish(pu): polish uz_mt config
dyyoungg Nov 6, 2024
f1d62e0
polish(pu): polish uz_mt config
dyyoungg Nov 6, 2024
025527c
feature(pu): add uz_mt_ddp config
dyyoungg Nov 7, 2024
122bcf2
feature(pu): add atari100k unizero_multitask ddp config
dyyoungg Nov 9, 2024
558c048
fix(pu): fix unizero_multitask ddp config
dyyoungg Nov 9, 2024
140df70
fix(pu): fix unizero_multitask ddp_v2 log in collector/evaluator/tb_l…
dyyoungg Nov 11, 2024
92fd026
fix(pu): fix uz_mt ddp_v2 learner log
dyyoungg Nov 12, 2024
4799b74
feature(pu): add eval_async use ThreadPool
dyyoungg Nov 13, 2024
aeb997c
fix(pu): fix ddp bug when task_num>gpu_num
dyyoungg Nov 13, 2024
9f4fba9
fix(pu): fix log_buffer_memory_usage in ddp setting
dyyoungg Nov 13, 2024
d30cf00
feature(pu): add allocated_batch_sizes option
dyyoungg Nov 13, 2024
77cb8b9
fix(pu): add timeout in eval_async
dyyoungg Nov 14, 2024
eeaf986
fix(pu): use stop_event to quit eval() when timeout in eval_async
dyyoungg Nov 15, 2024
0fb4263
polish(pu): polish unizero_mt configs
dyyoungg Nov 19, 2024
aaa2793
sync code
dyyoungg Nov 24, 2024
13fbe4c
sync code
dyyoungg Nov 24, 2024
d5842f1
feature(pu): add muzero multitask (and its ddp version) pipeline
puyuan1996 Nov 28, 2024
7723f13
polish(pu): polish configs
puyuan1996 Nov 28, 2024
f1e8d8c
polish(pu): polish config
dyyoungg Nov 29, 2024
62c8a96
polish(pu): polish config
dyyoungg Dec 1, 2024
99c08e2
Merge branch 'dev-mz-multitask-ddp' of https://github.com/opendilab/L…
dyyoungg Dec 1, 2024
1edcba3
fix(pu): fix embed dim in uz_multitask pipeline
dyyoungg Dec 2, 2024
4b195eb
feature(pu): add uz finetune config
dyyoungg Dec 2, 2024
13183e7
feature(pu): add uz eval-tsne config
dyyoungg Dec 2, 2024
1e37cae
fix(pu): add uz eval-tsne config
dyyoungg Dec 2, 2024
29298e6
polish(pu): polish tsne-plot legend
dyyoungg Dec 3, 2024
ffdf4db
polish(pu): polish atari multitask related configs
puyuan1996 Dec 18, 2024
2b0af34
polish(pu): polish unizero/muzero multitask related entry
puyuan1996 Dec 18, 2024
0bd688e
polish(pu): delete unused files
puyuan1996 Dec 18, 2024
69a1842
Merge remote-tracking branch 'origin/main' into dev-mz-multitask-ddp
puyuan1996 Dec 18, 2024
d06ce61
polish(pu): polish policy/model in multitask settings
puyuan1996 Dec 18, 2024
3a88f46
feature(pu): add sampled_unizero multitask pipeline
puyuan1996 Dec 24, 2024
563548b
fix(pu): fix sampled_unizero multitask ddp pipeline
puyuan1996 Dec 24, 2024
d8705e6
fix(pu): fix sampled_unizero multitask ddp pipeline
puyuan1996 Dec 24, 2024
a5b38b6
sync code
puyuan1996 Dec 25, 2024
1c8c4fb
polish(pu): polish suz dmc multitask configs
puyuan1996 Dec 26, 2024
aedf83b
fix(pu): fix self.last_batch_obs_eval
Jan 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feature(pu): add uz finetune config
  • Loading branch information
dyyoungg committed Dec 2, 2024
commit 4b195eb58071ceddf856125ad2bc35936a4f2725
101 changes: 100 additions & 1 deletion lzero/policy/unizero_multitask.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import sys
from collections import defaultdict
from typing import List, Dict, Tuple, Union
from typing import List, Dict, Any, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -1224,3 +1224,102 @@ def recompute_pos_emb_diff_and_clear_cache(self) -> None:
model.world_model.precompute_pos_emb_diff_kv()
model.world_model.clear_caches()
torch.cuda.empty_cache()

def _state_dict_learn(self) -> Dict[str, Any]:
"""
Overview:
Return the state_dict of learn mode, usually including model, target_model and optimizer.
Returns:
- state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
"""
return {
'model': self._learn_model.state_dict(),
'target_model': self._target_model.state_dict(),
'optimizer_world_model': self._optimizer_world_model.state_dict(),
}

# def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
# """
# Overview:
# Load the state_dict variable into policy learn mode.
# Arguments:
# - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
# """
# self._learn_model.load_state_dict(state_dict['model'])
# self._target_model.load_state_dict(state_dict['target_model'])
# self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model'])

# ========== TODO ==========
def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
"""
Overview:
Load the state_dict variable into policy learn mode, excluding multi-task related parameters.
Arguments:
- state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously.
"""
# 定义需要排除的参数前缀
exclude_prefixes = [
'_orig_mod.world_model.head_policy_multi_task.',
'_orig_mod.world_model.head_value_multi_task.',
'_orig_mod.world_model.head_rewards_multi_task.',
'_orig_mod.world_model.head_observations_multi_task.',
'_orig_mod.world_model.task_emb.'
]

# 定义需要排除的具体参数(如果有特殊情况)
exclude_keys = [
'_orig_mod.world_model.task_emb.weight',
'_orig_mod.world_model.task_emb.bias', # 如果存在则添加
# 添加其他需要排除的具体参数名
]

def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]:
"""
过滤掉需要排除的参数。
"""
filtered = {}
for k, v in state_dict_loader.items():
if any(k.startswith(prefix) for prefix in exclude_prefixes):
print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除
continue
if k in exclude_keys:
print(f"Excluding specific parameter: {k}") # 调试用
continue
filtered[k] = v
return filtered

# 过滤并加载 'model' 部分
if 'model' in state_dict:
model_state_dict = state_dict['model']
filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys)
missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False)
if missing_keys:
print(f"Missing keys when loading _learn_model: {missing_keys}")
if unexpected_keys:
print(f"Unexpected keys when loading _learn_model: {unexpected_keys}")
else:
print("No 'model' key found in the state_dict.")

# 过滤并加载 'target_model' 部分
if 'target_model' in state_dict:
target_model_state_dict = state_dict['target_model']
filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys)
missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False)
if missing_keys:
print(f"Missing keys when loading _target_model: {missing_keys}")
if unexpected_keys:
print(f"Unexpected keys when loading _target_model: {unexpected_keys}")
else:
print("No 'target_model' key found in the state_dict.")

# 加载优化器的 state_dict,不需要过滤,因为优化器通常不包含模型参数
if 'optimizer_world_model' in state_dict:
optimizer_state_dict = state_dict['optimizer_world_model']
try:
self._optimizer_world_model.load_state_dict(optimizer_state_dict)
except Exception as e:
print(f"Error loading optimizer state_dict: {e}")
else:
print("No 'optimizer_world_model' key found in the state_dict.")

# 如果需要,还可以加载其他部分,例如 scheduler 等
Loading
Loading