fix(pu): fix smz compile_args and num_simulations bug in world_model (#297)

* fix(pu): fix smz compile_args

* fix(pu): fix num_simulations bug in world_model
puyuan1996 authored Nov 19, 2024
1 parent 60be9e3 commit dfaebaa
Showing 12 changed files with 37 additions and 17 deletions.
2 changes: 1 addition & 1 deletion lzero/config/meta.py
@@ -7,7 +7,7 @@
__TITLE__ = "LightZero"

#: Version of this project.
__VERSION__ = "0.0.3"
__VERSION__ = "0.1.0"

#: Short description of the project, will be included in ``setup.py``.
__DESCRIPTION__ = 'A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkits.'
1 change: 1 addition & 0 deletions lzero/entry/__init__.py
@@ -9,3 +9,4 @@
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
+from .utils import *
8 changes: 4 additions & 4 deletions lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp
@@ -610,7 +610,7 @@ namespace tree
}
}

-std::vector<std::vector<float> > CNode::get_trajectory()
+std::vector<std::vector<float>> CNode::get_trajectory()
{
/*
Overview:
@@ -629,7 +629,7 @@
best_action = node->best_action;
}

-std::vector<std::vector<float> > traj_return;
+std::vector<std::vector<float>> traj_return;
for (int i = 0; i < traj.size(); ++i)
{
traj_return.push_back(traj[i].value);
@@ -676,7 +676,7 @@ namespace tree
this->num_of_sampled_actions = 20;
}

-CRoots::CRoots(int root_num, std::vector<std::vector<float> > legal_actions_list, int action_space_size, int num_of_sampled_actions, bool continuous_action_space)
+CRoots::CRoots(int root_num, std::vector<std::vector<float>> legal_actions_list, int action_space_size, int num_of_sampled_actions, bool continuous_action_space)
{
/*
Overview:
@@ -728,7 +728,7 @@

CRoots::~CRoots() {}

-void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
+void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float>> &noises, const std::vector<float> &rewards, const std::vector<std::vector<float>> &policies, std::vector<int> &to_play_batch)
{
/*
Overview:
4 changes: 2 additions & 2 deletions lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.h
@@ -55,8 +55,8 @@ namespace tree
CNode(float prior, std::vector<CAction> &legal_actions, int action_space_size, int num_of_sampled_actions, bool continuous_action_space);
~CNode();

-// 辅助采样函数
-std::pair<std::vector<std::vector<float>>, std::vector<float>> sample_actions(
+// Auxiliary sampling function
+std::pair<std::vector<std::vector<float> >, std::vector<float> > sample_actions(
const std::vector<float>& mu,
const std::vector<float>& sigma,
int num_samples,
2 changes: 2 additions & 0 deletions lzero/model/__init__.py
@@ -1 +1,3 @@
from .image_transform import Intensity, RandomCrop, ImageTransforms
+from .utils import *
+from .common import *
3 changes: 1 addition & 2 deletions lzero/model/unizero_model.py
@@ -134,8 +134,7 @@ def __init__(
self.encoder_hook = FeatureAndGradientHook()
self.encoder_hook.setup_hooks(self.representation_network)

-self.tokenizer = Tokenizer(with_lpips=True, encoder=self.representation_network,
-                           decoder_network=self.decoder_network)
+self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network)
self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)')
9 changes: 6 additions & 3 deletions lzero/model/unizero_world_models/tokenizer.py
@@ -9,8 +9,6 @@
from einops import rearrange
from torch.nn import functional as F

-from lzero.model.unizero_world_models.lpips import LPIPS


class LossWithIntermediateLosses:
def __init__(self, **kwargs):
@@ -47,7 +45,12 @@ def __init__(self, encoder=None, decoder_network=None, with_lpips: bool = False)
with_lpips (bool, optional): Whether to use LPIPS for perceptual loss. Defaults to False.
"""
super().__init__()
-self.lpips = LPIPS().eval() if with_lpips else None
+if with_lpips:
+    from lzero.model.unizero_world_models.lpips import LPIPS
+    self.lpips = LPIPS().eval()
+else:
+    self.lpips = None
+
self.encoder = encoder
self.decoder_network = decoder_network

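Note: with this change LPIPS becomes opt-in. A usage sketch (not part of the diff, assuming a working LightZero install) of the two construction paths of Tokenizer:

# Usage sketch only; None stands in for real encoder/decoder networks.
from lzero.model.unizero_world_models.tokenizer import Tokenizer

# Default path: with_lpips is False, so the LPIPS module is never imported.
tokenizer = Tokenizer(encoder=None, decoder_network=None)
assert tokenizer.lpips is None

# Opt-in path: only this branch performs the deferred LPIPS import.
tokenizer_with_lpips = Tokenizer(encoder=None, decoder_network=None, with_lpips=True)
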
2 changes: 1 addition & 1 deletion lzero/model/unizero_world_models/world_model.py
@@ -114,7 +114,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
# TODO: check the size of the shared pool
# for self.kv_cache_recurrent_infer
# If needed, recurrent_infer should store the results of the one MCTS search.
-self.num_simulations = self.config.num_simulations
+self.num_simulations = getattr(self.config, 'num_simulations', 50)
self.shared_pool_size = int(self.num_simulations*self.env_num)
self.shared_pool_recur_infer = [None] * self.shared_pool_size
self.shared_pool_index = 0
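The getattr fallback is what fixes the num_simulations bug: a minimal sketch (illustrative numbers, a plain namespace instead of the real TransformerConfig) showing that a config without num_simulations now falls back to 50 instead of raising AttributeError, while the recurrent-inference pool keeps the same sizing rule:

# Sketch only; SimpleNamespace and the numbers are illustrative, not LightZero defaults.
from types import SimpleNamespace

cfg = SimpleNamespace(env_num=8)  # deliberately has no num_simulations attribute
num_simulations = getattr(cfg, 'num_simulations', 50)  # -> 50, no AttributeError
shared_pool_size = int(num_simulations * cfg.env_num)  # 50 * 8 = 400 cache slots
shared_pool_recur_infer = [None] * shared_pool_size
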
1 change: 0 additions & 1 deletion lzero/policy/utils.py
@@ -198,7 +198,6 @@ def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)



# modified from https://github.com/karpathy/nanoGPT/blob/master/model.py#L263
def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type):
# start with all of the candidate parameters
1 change: 1 addition & 0 deletions lzero/reward_model/__init__.py
@@ -0,0 +1 @@
+from .rnd_reward_model import *
1 change: 1 addition & 0 deletions requirements.txt
@@ -11,3 +11,4 @@ pytest
pooltool-billiards>=0.3.1
line_profiler
xxhash
+einops
20 changes: 17 additions & 3 deletions setup.py
@@ -11,11 +11,12 @@
# limitations under the License.
import os
import re
+import sys
from distutils.core import setup

import numpy as np
-from Cython.Build import cythonize
from setuptools import find_packages, Extension
+from Cython.Build import cythonize # this line should be after 'from setuptools import find_packages'

here = os.path.abspath(os.path.dirname(__file__))

@@ -32,6 +33,19 @@ def _load_req(file: str):
for item in [_REQ_PATTERN.fullmatch(reqpath) for reqpath in os.listdir()] if item
}

+# Set C++11 compile parameters according to the operating system
+extra_compile_args = []
+extra_link_args = []
+
+if sys.platform == 'win32':
+    # Use the VS compiler on Windows platform
+    extra_compile_args = ["/std:c++11"]
+    extra_link_args = ["/std:c++11"]
+else:
+    # Linux/macOS Platform
+    extra_compile_args = ["-std=c++11"]
+    extra_link_args = ["-std=c++11"]


def find_pyx(path=None):
path = path or os.path.join(here, 'lzero')
@@ -60,8 +74,8 @@ def find_cython_extensions(path=None):
extname, [item],
include_dirs=[np.get_include()],
language="c++",
# extra_compile_args=["/std:c++latest"], # only for Windows
# extra_link_args=["/std:c++latest"], # only for Windows
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
))

return extensions
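For illustration, a self-contained sketch (the extension name "example_ext" and source "example.pyx" are hypothetical placeholders) of how the platform-dependent flags above attach to a single Cython/C++ extension; in setup.py itself the extensions are discovered by find_cython_extensions(). Rebuilding, for example with pip install -e . or python setup.py build_ext --inplace, picks up the new flags.

# Sketch only; extension name and source file are placeholders.
import sys
from setuptools import Extension

# Same flag selection as the block added in this commit.
cxx_flag = "/std:c++11" if sys.platform == 'win32' else "-std=c++11"

ext = Extension(
    "example_ext",
    ["example.pyx"],
    language="c++",
    extra_compile_args=[cxx_flag],
    extra_link_args=[cxx_flag],
)
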
