Skip to content

Commit

Permalink
fix instruct parse
Browse files Browse the repository at this point in the history
  • Loading branch information
quantumiracle committed Nov 15, 2022
1 parent a10fd9c commit e1b1a8c
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions scripts/test_multi_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ def parse_instruct(instruct, type):
valid = True

if type == 'rotate':
instruct_list = instruct[4:].split('. ') #remove ' A: '
instruct = instruct.split('.')[0] # take first of cot
# processed_instruct = instruct[4:] #remove ' A: '
processed_instruct = instruct[instruct.index('Rotate'):] # start from 'Rotate', for case: '\nRotate ***'
instruct_list = processed_instruct.split('. ')
first_instruct = processed_instruct.split('.')[0] # take first of cot
# Rotate blue triangle by 71 degrees.
words = instruct.split(' ')
words = first_instruct.split(' ')
words = [i for i in words if i not in unuseful_words] # filter

anchor_pos1 = words.index('Rotate')
Expand Down Expand Up @@ -61,7 +63,9 @@ def parse_instruct(instruct, type):
}

elif type == 'stack':
instruct_list = instruct[4:].split('. ') #remove ' A: '
# processed_instruct = instruct[4:] #remove ' A: '
processed_instruct = instruct[instruct.index('Put'):] # start from 'Put', for case: '\nPut ***'
instruct_list = processed_instruct.split('. ')
drag_obj_list = []
base_obj_list = []
drag_color_list = []
Expand Down Expand Up @@ -112,7 +116,9 @@ def parse_instruct(instruct, type):
}

elif type == 'put':
instruct_list = instruct[4:].split('. ') #remove ' A: '
# processed_instruct = instruct[4:] #remove ' A: '
processed_instruct = instruct[instruct.index('Put'):] # start from 'Put', for case: '\nPut ***'
instruct_list = processed_instruct.split('. ')
drag_obj_list = []
base_obj_list = []
drag_color_list = []
Expand Down Expand Up @@ -384,7 +390,7 @@ def rollout(policy, task_type, seed, device, prefix='', num_prompts=100, cots=3,
logger.save(f'data/{prefix}')

if __name__ == "__main__":
task_type = ['stack', 'rotate', 'put'][1]
task_type = ['stack', 'rotate', 'put'][-1]
model_size = ['4M', '200M'][-1]
model_ckpt = f'../models/{model_size}.ckpt'
device = 'cpu'
Expand Down

0 comments on commit e1b1a8c

Please sign in to comment.