Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pu): fix smz compile_args and num_simulations bug in world_model #297

Merged
merged 2 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lzero/config/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
__TITLE__ = "LightZero"

#: Version of this project.
__VERSION__ = "0.0.3"
__VERSION__ = "0.1.0"

#: Short description of the project, will be included in ``setup.py``.
__DESCRIPTION__ = 'A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkits.'
Expand Down
1 change: 1 addition & 0 deletions lzero/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
from .utils import *
8 changes: 4 additions & 4 deletions lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ namespace tree
}
}

std::vector<std::vector<float> > CNode::get_trajectory()
std::vector<std::vector<float>> CNode::get_trajectory()
{
/*
Overview:
Expand All @@ -629,7 +629,7 @@ namespace tree
best_action = node->best_action;
}

std::vector<std::vector<float> > traj_return;
std::vector<std::vector<float>> traj_return;
for (int i = 0; i < traj.size(); ++i)
{
traj_return.push_back(traj[i].value);
Expand Down Expand Up @@ -676,7 +676,7 @@ namespace tree
this->num_of_sampled_actions = 20;
}

CRoots::CRoots(int root_num, std::vector<std::vector<float> > legal_actions_list, int action_space_size, int num_of_sampled_actions, bool continuous_action_space)
CRoots::CRoots(int root_num, std::vector<std::vector<float>> legal_actions_list, int action_space_size, int num_of_sampled_actions, bool continuous_action_space)
{
/*
Overview:
Expand Down Expand Up @@ -728,7 +728,7 @@ namespace tree

CRoots::~CRoots() {}

void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float>> &noises, const std::vector<float> &rewards, const std::vector<std::vector<float>> &policies, std::vector<int> &to_play_batch)
{
/*
Overview:
Expand Down
4 changes: 2 additions & 2 deletions lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ namespace tree
CNode(float prior, std::vector<CAction> &legal_actions, int action_space_size, int num_of_sampled_actions, bool continuous_action_space);
~CNode();

// 辅助采样函数
std::pair<std::vector<std::vector<float>>, std::vector<float>> sample_actions(
// Auxiliary sampling function
std::pair<std::vector<std::vector<float> >, std::vector<float> > sample_actions(
const std::vector<float>& mu,
const std::vector<float>& sigma,
int num_samples,
Expand Down
2 changes: 2 additions & 0 deletions lzero/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .image_transform import Intensity, RandomCrop, ImageTransforms
from .utils import *
from .common import *
3 changes: 1 addition & 2 deletions lzero/model/unizero_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def __init__(
self.encoder_hook = FeatureAndGradientHook()
self.encoder_hook.setup_hooks(self.representation_network)

self.tokenizer = Tokenizer(with_lpips=True, encoder=self.representation_network,
decoder_network=self.decoder_network)
self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network)
self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)')
Expand Down
9 changes: 6 additions & 3 deletions lzero/model/unizero_world_models/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from einops import rearrange
from torch.nn import functional as F

from lzero.model.unizero_world_models.lpips import LPIPS


class LossWithIntermediateLosses:
def __init__(self, **kwargs):
Expand Down Expand Up @@ -47,7 +45,12 @@ def __init__(self, encoder=None, decoder_network=None, with_lpips: bool = False)
with_lpips (bool, optional): Whether to use LPIPS for perceptual loss. Defaults to False.
"""
super().__init__()
self.lpips = LPIPS().eval() if with_lpips else None
if with_lpips:
from lzero.model.unizero_world_models.lpips import LPIPS
self.lpips = LPIPS().eval()
else:
self.lpips = None

self.encoder = encoder
self.decoder_network = decoder_network

Expand Down
2 changes: 1 addition & 1 deletion lzero/model/unizero_world_models/world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
# TODO: check the size of the shared pool
# for self.kv_cache_recurrent_infer
# If needed, recurrent_infer should store the results of the one MCTS search.
self.num_simulations = self.config.num_simulations
self.num_simulations = getattr(self.config, 'num_simulations', 50)
self.shared_pool_size = int(self.num_simulations*self.env_num)
self.shared_pool_recur_infer = [None] * self.shared_pool_size
self.shared_pool_index = 0
Expand Down
1 change: 0 additions & 1 deletion lzero/policy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)



# modified from https://github.com/karpathy/nanoGPT/blob/master/model.py#L263
def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type):
# start with all of the candidate parameters
Expand Down
1 change: 1 addition & 0 deletions lzero/reward_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .rnd_reward_model import *
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ pytest
pooltool-billiards>=0.3.1
line_profiler
xxhash
einops
20 changes: 17 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
# limitations under the License.
import os
import re
import sys
from distutils.core import setup

import numpy as np
from Cython.Build import cythonize
from setuptools import find_packages, Extension
from Cython.Build import cythonize # this line should be after 'from setuptools import find_packages'

here = os.path.abspath(os.path.dirname(__file__))

Expand All @@ -32,6 +33,19 @@ def _load_req(file: str):
for item in [_REQ_PATTERN.fullmatch(reqpath) for reqpath in os.listdir()] if item
}

# Set C++11 compile parameters according to the operating system
extra_compile_args = []
extra_link_args = []

if sys.platform == 'win32':
# Use the VS compiler on Windows platform
extra_compile_args = ["/std:c++11"]
extra_link_args = ["/std:c++11"]
else:
# Linux/macOS Platform
extra_compile_args = ["-std=c++11"]
extra_link_args = ["-std=c++11"]


def find_pyx(path=None):
path = path or os.path.join(here, 'lzero')
Expand Down Expand Up @@ -60,8 +74,8 @@ def find_cython_extensions(path=None):
extname, [item],
include_dirs=[np.get_include()],
language="c++",
# extra_compile_args=["/std:c++latest"], # only for Windows
# extra_link_args=["/std:c++latest"], # only for Windows
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
))

return extensions
Expand Down
Loading