Skip to content

Commit

Permalink
fix random start state for aa and fc
Browse files Browse the repository at this point in the history
  • Loading branch information
StephAO committed Oct 25, 2022
1 parent 63d4b81 commit 054939d
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions src/overcooked_ai_py/mdp/overcooked_mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,13 @@ def start_state_fn():
def get_fully_random_start_state_fn(self, mlam):
def start_state_fn(random_pos=False, random_dir=False, max_random_objs=0):
if random_pos:
valid_positions = self.get_valid_joint_player_positions()
if self.layout_name == 'asymmetric_advantages':
valid_positions = itertools.product([(7, 1), (5, 2), (6, 2), (7, 2), (5, 3), (6, 3), (7, 3)],
[(1, 1), (1, 2), (1, 2), (3, 2), (1, 3), (2, 3), (3, 3)])
elif self.layout_name == 'forced_coordination':
valid_positions = itertools.product([(1, 1), (1, 2), (1, 3)], [(3, 1), (3, 2), (3, 3)])
else:
valid_positions = self.get_valid_joint_player_positions()
start_pos = valid_positions[np.random.choice(len(valid_positions))]
else:
start_pos = self.start_player_positions
Expand Down Expand Up @@ -1090,7 +1096,7 @@ def start_state_fn(random_pos=False, random_dir=False, max_random_objs=0):

# Randomize counter items
# For each counter space, add up to max_objects objects on free counters
free_counters = self.find_free_counters_valid_for_both_players(start_state, mlam)
free_counters = self.get_empty_counter_locations(start_state)
max_num_objs = min(max_random_objs, len(free_counters))
if max_num_objs == 0:
return start_state
Expand All @@ -1115,7 +1121,13 @@ def start_state_fn(p_idx=0, curr_subtask='unknown', random_pos=False, random_dir
n_random_objs = num_random_objects if num_random_objects is not None else max_random_objs
t_idx = (p_idx + 1) % 2
if random_pos:
valid_positions = self.get_valid_joint_player_positions()
if self.layout_name == 'asymmetric_advantages':
valid_positions = itertools.product([(7, 1), (5, 2), (6, 2), (7, 2), (5, 3), (6, 3), (7, 3)],
[(1, 1), (1, 2), (1, 2), (3, 2), (1, 3), (2, 3), (3, 3)])
elif self.layout_name == 'forced_coordination':
valid_positions = itertools.product([(1, 1), (1, 2), (1, 3)], [(3, 1), (3, 2), (3, 3)])
else:
valid_positions = self.get_valid_joint_player_positions()
start_pos = valid_positions[np.random.choice(len(valid_positions))]
else:
start_pos = self.start_player_positions
Expand All @@ -1127,7 +1139,7 @@ def start_state_fn(p_idx=0, curr_subtask='unknown', random_pos=False, random_dir
player = start_state.players[p_idx]
if curr_subtask in ['get_onion_from_dispenser', 'get_plate_from_dish_rack',
'get_onion_from_counter', 'get_plate_from_counter', 'get_soup_from_counter']:
# The respective items must exist on a counter somewhere
# The respective items must exist on a reachable counter somewhere
if 'from_counter' in curr_subtask:
obj_name = None
if curr_subtask == 'get_onion_from_counter':
Expand Down Expand Up @@ -1195,14 +1207,16 @@ def start_state_fn(p_idx=0, curr_subtask='unknown', random_pos=False, random_dir
if 'closer' in curr_subtask:
# Make sure there is a counter free to put the object on
n_random_objs -= 1
free_counters = self.find_free_counters_valid_for_both_players(start_state, mlam)
free_counters = self.get_empty_counter_locations(start_state)
n_random_objs = min(n_random_objs, len(free_counters))

if n_random_objs == 0:
return start_state

if num_random_objects is not None:
num_objs = n_random_objs
else:
num_objs = np.random.randint(min(n_random_objs, len(free_counters)))

if num_objs == 0:
return start_state
num_objs = np.random.randint(n_random_objs)

counter_indices = np.random.choice(len(free_counters), size=num_objs, replace=False)
for counter_idx in counter_indices:
Expand Down

0 comments on commit 054939d

Please sign in to comment.