
Commit 2a5f8ba

test with smaller k
1 parent 5571750 commit 2a5f8ba

File tree

1 file changed: +91 -0 lines


Test_gen/GRU_atari_k3.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # just use CPU
from sandbox.rocky.tf.algos.trpo import TRPO
from sandbox.rocky.tf.algos.vpg import VPG
from Algo.trpo_transfer import TRPO_t
from Algo.vpg_transfer import VPG_t
from Algo.npo_transfer import NPO_t
from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer, FiniteDifferenceHvp
from sandbox.rocky.tf.optimizers.penalty_lbfgs_optimizer import PenaltyLbfgsOptimizer

from Env.atari import AtariEnv
from rllab.misc.instrument import stub, run_experiment_lite
from Policy_gen.qmdp_policy import QMDPPolicy
from sandbox.rocky.tf.policies.categorical_gru_policy import CategoricalGRUPolicy

import lasagne.nonlinearities as NL
from sandbox.rocky.tf.envs.base import TfEnv
from rllab.misc import logger
import os.path as osp
import tensorflow as tf
from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler
import joblib
import dill

import sys

# usage: python Test_gen/GRU_atari_k3.py <game_name> <mask_num>
game_name = sys.argv[1]
mask_num = int(sys.argv[2])


env = TfEnv(AtariEnv(mask_num, game_name))
qmdp_param = {}
qmdp_param['K'] = 3
qmdp_param['obs_len'] = env.spec.observation_space.flat_dim
qmdp_param['num_action'] = env.spec.action_space.flat_dim
qmdp_param['num_state'] = 32  # env.spec.observation_space.flat_dim
qmdp_param['info_len'] = qmdp_param['num_state']
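# The commit message ("test with smaller k") refers to K = 3 above; see the
# note and runnable sketch after the diff for what K presumably controls.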
# log_dir = "./Data/FixMapStartState"
log_dir = "./Data/qmdp_" + game_name + '_' + str(mask_num)

tabular_log_file = osp.join(log_dir, "progress.csv")
text_log_file = osp.join(log_dir, "debug.log")
params_log_file = osp.join(log_dir, "params.json")
pkl_file = osp.join(log_dir, "params.pkl")

logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
prev_snapshot_dir = logger.get_snapshot_dir()
prev_mode = logger.get_snapshot_mode()
logger.set_snapshot_dir(log_dir)
logger.set_snapshot_mode("gaplast")
logger.set_snapshot_gap(1000)
logger.set_log_tabular_only(False)
logger.push_prefix("[%s] " % (game_name + '_' + str(mask_num)))

from Algo import parallel_sampler
parallel_sampler.initialize(n_parallel=1)
parallel_sampler.set_seed(0)

policy = QMDPPolicy(
    env_spec=env.spec,
    name="QMDP",
    qmdp_param=qmdp_param,
)

baseline = LinearFeatureBaseline(env_spec=env.spec)

with tf.Session() as sess:
    # writer = tf.summary.FileWriter(logdir=log_dir)

    algo = VPG_t(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=2048,  # 2 * env._wrapped_env.params['traj_limit']
        max_path_length=200,
        n_itr=10000,
        discount=0.95,
        step_size=0.01,
        record_rewards=True,
        transfer=False,
    )

    algo.train(sess)
    # tf.summary.merge_all()
    # print(sess.graph)
    # writer.add_graph(sess.graph)
    # writer.close()
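
A note on the commit message ("test with smaller k"): by the standard QMDP formulation, qmdp_param['K'] would be the number of value-iteration backups the planner unrolls before averaging Q-values under the belief, so a smaller K means a shallower planning horizon. The sketch below is a minimal illustration under that assumption; T, R, and belief are hypothetical stand-ins, not the internals of Policy_gen.qmdp_policy.QMDPPolicy.

import numpy as np

def qmdp_action_values(T, R, belief, K=3, discount=0.95):
    """K-step QMDP backup on a tabular model.

    T: [A, S, S] transition probabilities; R: [A, S] rewards;
    belief: [S] distribution over hidden states.
    """
    num_action, num_state = R.shape
    V = np.zeros(num_state)
    Q = np.zeros((num_action, num_state))
    for _ in range(K):                   # smaller K = shallower lookahead
        Q = R + discount * T.dot(V)      # Q[a, s] = R[a, s] + g * sum_s' T[a, s, s'] * V[s']
        V = Q.max(axis=0)                # greedy backup
    return Q.dot(belief)                 # QMDP: expected Q under the belief

# Smoke test on a random model; S = 32 mirrors num_state above.
rng = np.random.default_rng(0)
A, S = 4, 32
T = rng.random((A, S, S))
T /= T.sum(axis=-1, keepdims=True)       # normalize rows into distributions
R = rng.random((A, S))
b = np.full(S, 1.0 / S)                  # uniform belief
print(qmdp_action_values(T, R, b, K=3))  # one value per action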
