forked from ku-dmlab/arc_trajectory_generator
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
67 lines (52 loc) · 2.09 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import gymnasium as gym
import hydra
from omegaconf import DictConfig
from omegaconf import OmegaConf
from arcle.loaders import Loader
from loader import SizeConstrainedLoader
import wandb
from ppo.ppo import learn
class TestLoader(Loader):
    """Loader that fabricates random grids instead of reading task files.

    Each call to ``parse`` draws four new ``size_x`` by ``size_y`` grids of
    integers in ``[0, 10)`` from a generator seeded with 12345.
    """

    def __init__(self, size_x, size_y, **kwargs):
        """Record the grid dimensions and create the seeded generator.

        Args:
            size_x: Length of the first axis of every generated grid.
            size_y: Length of the second axis of every generated grid.
            **kwargs: Forwarded unchanged to ``Loader.__init__``.
        """
        self.size_x, self.size_y = size_x, size_y
        # The generator is created before the base initializer runs;
        # presumably Loader.__init__ calls parse(), which needs self.rng.
        # TODO confirm against the arcle Loader implementation.
        self.rng = np.random.default_rng(12345)
        super().__init__(**kwargs)

    def get_path(self, **kwargs):
        """Return a one-element list holding an empty string.

        ``parse`` in this class never reads from a path.
        """
        return ['']

    def pick(self, **kwargs):
        """Generate a new random sample and return it; ``kwargs`` is unused."""
        samples = self.parse()
        return samples[0]

    def parse(self, **kwargs):
        """Build one sample made of four random uint8 grids.

        Returns:
            A one-element list. Its item is a 5-tuple: four single-element
            lists, each wrapping one ``(size_x, size_y)`` uint8 array, followed
            by a dict with the key ``'desc'``.
        """
        shape = (self.size_x, self.size_y)

        def random_grid():
            # Draw with the default integer dtype and cast afterwards so the
            # random stream matches filling a preallocated uint8 array.
            return self.rng.integers(0, 10, size=shape).astype(np.uint8)

        # Draw order is significant: it fixes which values land in which grid.
        first, second, third, fourth = [random_grid() for _ in range(4)]
        description = {'desc': "just for test"}
        return [([first], [second], [third], [fourth], description)]
@hydra.main(config_path="ppo", config_name="ppo_config_random")
def main(cfg: DictConfig) -> None:
    """Start a wandb run, build the ARCLE environment and launch training.

    Args:
        cfg: Hydra-provided configuration. This function reads
            ``cfg.env.use_arc``, ``cfg.env.grid_x``, ``cfg.env.grid_y`` and
            ``cfg.env.num_colors``, and passes the whole ``cfg`` on to
            ``learn``.
    """
    wandb.init(project="arc_traj_gen", config=OmegaConf.to_container(cfg))

    env_cfg = cfg.env
    # Only the data loader depends on the use_arc switch; every other
    # environment argument is shared between the two modes.
    if env_cfg.use_arc:
        loader = SizeConstrainedLoader(env_cfg.grid_x)
    else:
        loader = TestLoader(env_cfg.grid_x, env_cfg.grid_y)

    env = gym.make(
        'ARCLE/O2ARCv2Env-v0',
        data_loader=loader,
        max_trial=3,
        max_grid_size=(env_cfg.grid_x, env_cfg.grid_y),
        colors=env_cfg.num_colors,
    )
    learn(cfg, env)
# Invoke the Hydra-decorated entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()