Skip to content

Commit 7dd7751

Browse files
updated types and export sizes
1 parent 56d0d39 commit 7dd7751

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

src/types.hpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ struct AgentID {
7676

7777
struct Goal
7878
{
79-
madrona::math::Vector2 position;
79+
madrona::math::Vector3 position;
8080
};
8181

8282
// WorldReset is a per-world singleton component that causes the current
@@ -185,19 +185,19 @@ struct AgentID {
185185
return SelfObservation{
186186
.speed = 0,
187187
.vehicle_size = {0, 0},
188-
.goal = {.position = {0, 0}},
188+
.goal = {.position = {0, 0, 0}},
189189
.collisionState = 0,
190190
.id = -1};
191191
}
192192
};
193193

194-
const size_t SelfObservationExportSize = 7;
194+
const size_t SelfObservationExportSize = 8;
195195

196196
static_assert(sizeof(SelfObservation) == sizeof(float) * SelfObservationExportSize);
197197

198198
struct MapObservation
199199
{
200-
madrona::math::Vector2 position;
200+
madrona::math::Vector3 position;
201201
Scale scale;
202202
float heading;
203203
float type;
@@ -207,7 +207,7 @@ struct AgentID {
207207
static inline MapObservation zero()
208208
{
209209
return MapObservation{
210-
.position = {0, 0},
210+
.position = {0, 0, 0},
211211
.scale = madrona::math::Diag3x3{0, 0, 0},
212212
.heading = 0,
213213
.type = static_cast<float>(EntityType::None),
@@ -217,14 +217,14 @@ struct AgentID {
217217
}
218218
};
219219

220-
const size_t MapObservationExportSize = 9;
220+
const size_t MapObservationExportSize = 10;
221221

222222
static_assert(sizeof(MapObservation) == sizeof(float) * MapObservationExportSize);
223223

224224
struct PartnerObservation
225225
{
226226
float speed;
227-
madrona::math::Vector2 position;
227+
madrona::math::Vector3 position;
228228
float heading;
229229
VehicleSize vehicle_size;
230230
float type;
@@ -233,7 +233,7 @@ struct AgentID {
233233
static inline PartnerObservation zero() {
234234
return PartnerObservation{
235235
.speed = 0,
236-
.position = {0, 0},
236+
.position = {0, 0, 0},
237237
.heading = 0,
238238
.vehicle_size = {0, 0},
239239
.type = static_cast<float>(EntityType::None),
@@ -255,7 +255,7 @@ struct AgentID {
255255
PartnerObservation obs[consts::kMaxAgentCount - 1];
256256
};
257257

258-
const size_t PartnerObservationExportSize = 8;
258+
const size_t PartnerObservationExportSize = 9;
259259

260260
static_assert(sizeof(PartnerObservations) == sizeof(float) *
261261
(consts::kMaxAgentCount - 1) * PartnerObservationExportSize);
@@ -275,7 +275,7 @@ struct AgentID {
275275
{
276276
float depth;
277277
float encodedType;
278-
madrona::math::Vector2 position;
278+
madrona::math::Vector3 position;
279279
};
280280

281281
// Linear depth values and entity type in a circle around the agent
@@ -286,7 +286,7 @@ struct AgentID {
286286
LidarSample samplesRoadLines[consts::numLidarSamples];
287287
};
288288

289-
const size_t LidarExportSize = 3 * consts::numLidarSamples * 4;
289+
const size_t LidarExportSize = 3 * consts::numLidarSamples * 5;
290290

291291
static_assert(sizeof(Lidar) == sizeof(float) * LidarExportSize);
292292
// Number of steps remaining in the episode. Allows non-recurrent policies
@@ -311,14 +311,14 @@ struct AgentID {
311311

312312
struct Trajectory
313313
{
314-
madrona::math::Vector2 positions[consts::kTrajectoryLength];
314+
madrona::math::Vector3 positions[consts::kTrajectoryLength];
315315
madrona::math::Vector2 velocities[consts::kTrajectoryLength];
316316
float headings[consts::kTrajectoryLength];
317317
float valids[consts::kTrajectoryLength];
318318
Action inverseActions[consts::kTrajectoryLength];
319319
};
320320

321-
const size_t TrajectoryExportSize = 2 * 2 * consts::kTrajectoryLength + 2 * consts::kTrajectoryLength + ActionExportSize * consts::kTrajectoryLength;
321+
const size_t TrajectoryExportSize = 3 * 2 * consts::kTrajectoryLength + 2 * consts::kTrajectoryLength + ActionExportSize * consts::kTrajectoryLength;
322322

323323
static_assert(sizeof(Trajectory) == sizeof(float) * TrajectoryExportSize);
324324

0 commit comments

Comments
 (0)