From e1b1a8cf18f8916fc127d3cdd0fda2b05599101d Mon Sep 17 00:00:00 2001 From: quantumiracle Date: Tue, 15 Nov 2022 13:38:38 -0500 Subject: [PATCH] fix instruct parse --- scripts/test_multi_stage.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/scripts/test_multi_stage.py b/scripts/test_multi_stage.py index 1af9dcf..8939e86 100644 --- a/scripts/test_multi_stage.py +++ b/scripts/test_multi_stage.py @@ -22,10 +22,12 @@ def parse_instruct(instruct, type): valid = True if type == 'rotate': - instruct_list = instruct[4:].split('. ') #remove ' A: ' - instruct = instruct.split('.')[0] # take first of cot + # processed_instruct = instruct[4:] #remove ' A: ' + processed_instruct = instruct[instruct.index('Rotate'):] # start from 'Rotate', for case: '\nRotate ***' + instruct_list = processed_instruct.split('. ') + first_instruct = processed_instruct.split('.')[0] # take first of cot # Rotate blue triangle by 71 degrees. - words = instruct.split(' ') + words = first_instruct.split(' ') words = [i for i in words if i not in unuseful_words] # filter anchor_pos1 = words.index('Rotate') @@ -61,7 +63,9 @@ def parse_instruct(instruct, type): } elif type == 'stack': - instruct_list = instruct[4:].split('. ') #remove ' A: ' + # processed_instruct = instruct[4:] #remove ' A: ' + processed_instruct = instruct[instruct.index('Put'):] # start from 'Put', for case: '\nPut ***' + instruct_list = processed_instruct.split('. ') drag_obj_list = [] base_obj_list = [] drag_color_list = [] @@ -112,7 +116,9 @@ def parse_instruct(instruct, type): } elif type == 'put': - instruct_list = instruct[4:].split('. ') #remove ' A: ' + # processed_instruct = instruct[4:] #remove ' A: ' + processed_instruct = instruct[instruct.index('Put'):] # start from 'Put', for case: '\nPut ***' + instruct_list = processed_instruct.split('. ') drag_obj_list = [] base_obj_list = [] drag_color_list = [] @@ -384,7 +390,7 @@ def rollout(policy, task_type, seed, device, prefix='', num_prompts=100, cots=3, logger.save(f'data/{prefix}') if __name__ == "__main__": - task_type = ['stack', 'rotate', 'put'][1] + task_type = ['stack', 'rotate', 'put'][-1] model_size = ['4M', '200M'][-1] model_ckpt = f'../models/{model_size}.ckpt' device = 'cpu'