From efe9ffda5b2cf61af15bddbdbaccd2d77e7bd14f Mon Sep 17 00:00:00 2001 From: critter-mj Date: Wed, 8 Apr 2020 18:15:10 +0900 Subject: [PATCH] [modify] Game_State.to_numpy output oracle --- data_proc.py | 28 ++++++++++++++++++++++++++-- lib/mjtypes.py | 3 ++- main.py | 9 +++++++-- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/data_proc.py b/data_proc.py index 689972c..cc8082e 100644 --- a/data_proc.py +++ b/data_proc.py @@ -1,4 +1,5 @@ import pathlib +import glob import subprocess from lib.util import * @@ -98,7 +99,7 @@ def dump_child(self, dir_path, tenhou_id, action_type, X, Y): out_dir = pathlib.Path(out_dir_pathstr) if not out_dir.is_dir(): out_dir.mkdir(parents=True) - np.savez(out_dir_pathstr + "/" + action_type + "_" + tenhou_id, X, Y) + np.savez_compressed(out_dir_pathstr + "/" + action_type + "_" + tenhou_id, X, Y) X.clear() Y.clear() @@ -115,5 +116,28 @@ def proc_tenhou_mjailog(tenhou_id): dp = Data_Processor() game_record = read_log_json("tenhou_mjailog/" + tenhou_id[:4] + "/" + tenhou_id[:8] + "/" + tenhou_id + ".json") dp.process_record(game_record) - dp.dump("tenhou_npz", tenhou_id) + dp.dump("tenhou_npz", tenhou_id) + +def proc_batch_tenhou_mjailog(prefix, update): + if len(prefix) < 4: + print("proc_batch_tenhou_mjailog prefix too short") + return + target = "" + if len(prefix) <= 8: + target = "tenhou_mjailog/" + prefix[:4] + "/" + prefix + "*/*.json" + else: + target = "tenhou_mjailog/" + prefix[:4] + "/" + prefix[:8] + "/" + prefix + "*.json" + + file_list = glob.glob(target) + for file_name in file_list: + file_name = file_name.replace('\\', '/') + tenhou_id = file_name.split('/')[-1].split('.')[0] + + if not update: + discard_path = pathlib.Path("tenhou_npz/discard/" + tenhou_id[:4] + "/" + tenhou_id[:8] + "/discard_" + tenhou_id + ".npz") + if discard_path.is_file(): + continue + + print("process:", tenhou_id) + proc_tenhou_mjailog(tenhou_id) \ No newline at end of file diff --git a/lib/mjtypes.py b/lib/mjtypes.py index 933180f..81a9bd9 100644 --- a/lib/mjtypes.py +++ b/lib/mjtypes.py @@ -491,7 +491,8 @@ def to_json(self, view_pid): def to_numpy(self, my_pid): # my_pid is 0,1,2,3 - ps = np.concatenate([self.player_state[(my_pid + i)%4].to_numpy(i == 0) for i in range(4)]) + #ps = np.concatenate([self.player_state[(my_pid + i)%4].to_numpy(i == 0) for i in range(4)]) + ps = np.concatenate([self.player_state[(my_pid + i)%4].to_numpy(True) for i in range(4)]) return ps def get_game_state_start_kyoku(action_json_dict): diff --git a/main.py b/main.py index 2fd3ed7..bd995ed 100644 --- a/main.py +++ b/main.py @@ -304,10 +304,13 @@ def confirm_end_kyoku(): def main(args): if args.dump_feature_tenhou: - if args.tenhou_id == None: + if args.tenhou_id != None: + proc_tenhou_mjailog(args.tenhou_id) + elif args.prefix != None: + proc_batch_tenhou_mjailog(args.prefix, args.update) + else: print("please specify tenhou_id") return - proc_tenhou_mjailog(args.tenhou_id) elif args.dump_feature: if args.out_dir == None: print("please specify out_dir") @@ -335,5 +338,7 @@ def main(args): parser.add_argument('--out_dir') parser.add_argument('--dump_feature_tenhou', action='store_true') parser.add_argument('--tenhou_id') + parser.add_argument('--prefix') + parser.add_argument('--update', action='store_true') args = parser.parse_args() main(args)