Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pufferlib/config/ocean/drive.ini
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ map_dir = "resources/drive/binaries/training"
num_maps = 10000
; Determines which step of the trajectory to initialize the agents at upon reset
init_steps = 0
; Options: "control_vehicles", "control_agents", "control_wosac", "control_sdc_only"
; Options: "control_vehicles", "control_agents", "control_wosac", "control_sdc_only", "control_mixed_play"
control_mode = "control_vehicles"
; Options: "create_all_valid", "create_only_controlled"
init_mode = "create_all_valid"
; Sets the maximum number of controllable agents per scene, ONLY used if control_mode is "control_mixed_play"
max_controlled_agents = 32

[train]
seed=42
Expand Down
6 changes: 6 additions & 0 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {
int init_steps = unpack(kwargs, "init_steps");
int goal_behavior = unpack(kwargs, "goal_behavior");
float goal_target_distance = unpack(kwargs, "goal_target_distance");
int max_controlled_agents = unpack(kwargs, "max_controlled_agents");

clock_gettime(CLOCK_REALTIME, &ts);
srand(ts.tv_nsec); // Always use random sampling with replacement
Expand Down Expand Up @@ -104,8 +105,10 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {
env->init_steps = init_steps;
env->goal_behavior = goal_behavior;
env->goal_target_distance = goal_target_distance;
env->max_controlled_agents = max_controlled_agents;
snprintf(map_file, sizeof(map_file), "%s/map_%03d.bin", map_dir, map_id);
env->entities = load_map_binary(map_file, env);
// Count the number of controllable agents in map
set_active_agents(env);

// Skip map if it doesn't contain any controllable agents
Expand Down Expand Up @@ -218,6 +221,7 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
OVERRIDE_FLOAT(goal_target_distance);
OVERRIDE_FLOAT(goal_radius);
OVERRIDE_FLOAT(goal_speed);
OVERRIDE_INT(max_controlled_agents);

#undef OVERRIDE_INT
#undef OVERRIDE_FLOAT
Expand Down Expand Up @@ -264,6 +268,8 @@ static int my_log(PyObject *dict, Log *log) {
assign_to_dict(dict, "dnf_rate", log->dnf_rate);
assign_to_dict(dict, "completion_rate", log->completion_rate);
assign_to_dict(dict, "lane_alignment_rate", log->lane_alignment_rate);
assign_to_dict(dict, "perc_controlled", log->perc_controlled);
assign_to_dict(dict, "perc_other", log->perc_other);
assign_to_dict(dict, "offroad_per_agent", log->offroad_per_agent);
assign_to_dict(dict, "collisions_per_agent", log->collisions_per_agent);
assign_to_dict(dict, "goals_sampled_this_episode", log->goals_sampled_this_episode);
Expand Down
29 changes: 22 additions & 7 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#define CONTROL_AGENTS 1
#define CONTROL_WOSAC 2
#define CONTROL_SDC_ONLY 3
#define CONTROL_MIXED_PLAY 4

// Minimum distance to goal position
#define MIN_DISTANCE_TO_GOAL 2.0f
Expand Down Expand Up @@ -163,6 +164,8 @@ struct Log {
float active_agent_count;
float expert_static_agent_count;
float static_agent_count;
float perc_controlled;
float perc_other;
};

typedef struct Entity Entity;
Expand Down Expand Up @@ -317,7 +320,6 @@ struct Drive {
float reward_goal_post_respawn;
float goal_radius;
float goal_speed;
int max_controlled_agents;
int logs_capacity;
int goal_behavior;
float goal_target_distance;
Expand All @@ -330,6 +332,7 @@ struct Drive {
int *tracks_to_predict_indices;
int init_mode;
int control_mode;
int max_controlled_agents;
};

void add_log(Drive *env) {
Expand Down Expand Up @@ -374,6 +377,9 @@ void add_log(Drive *env) {
env->log.active_agent_count += env->active_agent_count;
env->log.expert_static_agent_count += env->expert_static_agent_count;
env->log.static_agent_count += env->static_agent_count;
int total = env->active_agent_count + env->static_agent_count;
env->log.perc_controlled += (float)env->active_agent_count / (float)total;
env->log.perc_other += (float)env->static_agent_count / (float)total;
env->log.n += 1;
}
}
Expand Down Expand Up @@ -1201,10 +1207,9 @@ void compute_agent_metrics(Drive *env, int agent_idx) {
return;
}

bool should_control_agent(Drive *env, int agent_idx) {

bool should_control_agent(Drive *env, int agent_idx, int control_limit) {
// Check if we have room for more agents or are already at capacity
if (env->active_agent_count >= env->num_agents) {
if (env->active_agent_count >= control_limit) {
return false;
}

Expand Down Expand Up @@ -1267,6 +1272,13 @@ void set_active_agents(Drive *env) {
env->num_agents = MAX_AGENTS;
}

int control_limit;
if (env->control_mode == CONTROL_MIXED_PLAY) {
control_limit = (env->max_controlled_agents < env->num_agents) ? env->max_controlled_agents : env->num_agents;
} else {
control_limit = env->num_agents;
}

// If we have a SDC index (WOMD), initialize it first:
int sdc_index = env->sdc_track_index;

Expand Down Expand Up @@ -1310,17 +1322,17 @@ void set_active_agents(Drive *env) {
// Determine if this agent should be policy-controlled
bool is_controlled = false;

is_controlled = should_control_agent(env, i);
is_controlled = should_control_agent(env, i, control_limit);

if (is_controlled) {
active_agent_indices[env->active_agent_count] = i;
env->active_agent_count++;
env->entities[i].active_agent = 1;
} else if (env->init_mode != INIT_ONLY_CONTROLLABLE_AGENTS) {
static_agent_indices[env->static_agent_count] = i;
env->static_agent_count++;
env->static_agent_count++; // Includes expert replay and static agents
env->entities[i].active_agent = 0;
if (env->entities[i].mark_as_expert == 1 || env->active_agent_count == env->num_agents) {
if (env->entities[i].mark_as_expert == 1 || env->active_agent_count == control_limit) {
expert_static_agent_indices[env->expert_static_agent_count] = i;
env->expert_static_agent_count++;
env->entities[i].mark_as_expert = 1;
Expand All @@ -1341,6 +1353,9 @@ void set_active_agents(Drive *env) {
for (int i = 0; i < env->expert_static_agent_count; i++) {
env->expert_static_agent_indices[i] = expert_static_agent_indices[i];
}
// printf("Total actors: %d, Active agents: %d, Static agents: %d, Expert static agents: %d\n", env->num_actors,
// env->active_agent_count, env->static_agent_count, env->expert_static_agent_count);
// printf("Control mode: %d, max controlled agents: %d\n", env->control_mode, env->max_controlled_agents);

return;
}
Expand Down
16 changes: 9 additions & 7 deletions pufferlib/ocean/drive/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def __init__(
num_agents=512,
action_type="discrete",
dynamics_model="classic",
max_controlled_agents=-1,
buf=None,
seed=1,
init_steps=0,
init_mode="create_all_valid",
control_mode="control_vehicles",
max_controlled_agents=32,
map_dir="resources/drive/binaries/training",
):
# env
Expand All @@ -63,6 +63,7 @@ def __init__(
self.termination_mode = termination_mode
self.resample_frequency = resample_frequency
self.dynamics_model = dynamics_model
self.max_controlled_agents = max_controlled_agents

# Observation space calculation
self.ego_features = {"classic": binding.EGO_FEATURES_CLASSIC, "jerk": binding.EGO_FEATURES_JERK}.get(
Expand Down Expand Up @@ -96,9 +97,11 @@ def __init__(
self.control_mode = 2
elif self.control_mode_str == "control_sdc_only":
self.control_mode = 3
elif self.control_mode_str == "control_mixed_play":
self.control_mode = 4
else:
raise ValueError(
f"control_mode must be one of 'control_vehicles', 'control_wosac', or 'control_agents'. Got: {self.control_mode_str}"
f"control_mode must be one of 'control_vehicles', 'control_wosac', 'control_agents' or 'control_mixed_play'. Got: {self.control_mode_str}"
)
if self.init_mode_str == "create_all_valid":
self.init_mode = 0
Expand Down Expand Up @@ -140,7 +143,6 @@ def __init__(
raise ValueError(
f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). Please reduce num_maps or add more maps to resources/drive/binaries."
)
self.max_controlled_agents = int(max_controlled_agents)

# Iterate through all maps to count total agents that can be initialized for each map
agent_offsets, map_ids, num_envs = binding.shared(
Expand All @@ -150,9 +152,9 @@ def __init__(
init_mode=self.init_mode,
control_mode=self.control_mode,
init_steps=self.init_steps,
max_controlled_agents=self.max_controlled_agents,
goal_behavior=self.goal_behavior,
goal_target_distance=self.goal_target_distance,
max_controlled_agents=self.max_controlled_agents,
)

self.num_agents = agent_offsets[-1]
Expand Down Expand Up @@ -186,14 +188,14 @@ def __init__(
dt=dt,
episode_length=(int(episode_length) if episode_length is not None else None),
termination_mode=(int(self.termination_mode) if self.termination_mode is not None else 0),
max_controlled_agents=self.max_controlled_agents,
map_id=map_ids[i],
max_agents=nxt - cur,
ini_file="pufferlib/config/ocean/drive.ini",
init_steps=init_steps,
init_mode=self.init_mode,
control_mode=self.control_mode,
map_dir=map_dir,
max_controlled_agents=self.max_controlled_agents,
)
env_ids.append(env_id)

Expand All @@ -218,11 +220,11 @@ def resample_maps(self):
init_mode=self.init_mode,
control_mode=self.control_mode,
init_steps=self.init_steps,
max_controlled_agents=self.max_controlled_agents,
goal_behavior=self.goal_behavior,
goal_target_distance=self.goal_target_distance,
goal_speed=self.goal_speed,
map_dir=self.map_dir,
max_controlled_agents=self.max_controlled_agents,
)
self.agent_offsets = agent_offsets
self.map_ids = map_ids
Expand Down Expand Up @@ -253,7 +255,6 @@ def resample_maps(self):
offroad_behavior=self.offroad_behavior,
dt=self.dt,
episode_length=(int(self.episode_length) if self.episode_length is not None else None),
max_controlled_agents=self.max_controlled_agents,
map_id=map_ids[i],
max_agents=nxt - cur,
ini_file="pufferlib/config/ocean/drive.ini",
Expand All @@ -262,6 +263,7 @@ def resample_maps(self):
control_mode=self.control_mode,
map_dir=self.map_dir,
termination_mode=(int(self.termination_mode) if self.termination_mode is not None else 0),
max_controlled_agents=self.max_controlled_agents,
)
env_ids.append(env_id)
self.c_envs = binding.vectorize(*env_ids)
Expand Down
5 changes: 5 additions & 0 deletions pufferlib/ocean/env_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ typedef struct {
int init_steps;
int init_mode;
int control_mode;
int max_controlled_agents;
char map_dir[256];
} env_init_config;

Expand Down Expand Up @@ -104,6 +105,8 @@ static int handler(void *config, const char *section, const char *name, const ch
env_config->control_mode = 2;
} else if (strcmp(value, "\"control_sdc_only\"") == 0 || strcmp(value, "control_sdc_only") == 0) {
env_config->control_mode = 3;
} else if (strcmp(value, "\"control_mixed_play\"") == 0 || strcmp(value, "control_mixed_play") == 0) {
env_config->control_mode = 4;
} else {
printf("Warning: Unknown control_mode value '%s', defaulting to CONTROL_VEHICLES\n", value);
env_config->control_mode = 0; // Default to CONTROL_VEHICLES
Expand All @@ -114,6 +117,8 @@ static int handler(void *config, const char *section, const char *name, const ch
env_config->map_dir[sizeof(env_config->map_dir) - 1] = '\0';
}
// printf("Parsed map_dir: '%s'\n", env_config->map_dir);
} else if (MATCH("env", "max_controlled_agents")) {
env_config->max_controlled_agents = atoi(value);
} else {
return 0; // Unknown section/name, indicate failure to handle
}
Expand Down