Skip to content
This repository was archived by the owner on Dec 28, 2023. It is now read-only.

Commit bd69f20

Browse files
author
Omegastick
committed
Implement combining RolloutStorages
1 parent 3782452 commit bd69f20

File tree

2 files changed

+128
-3
lines changed

2 files changed

+128
-3
lines changed

include/cpprl/storage.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class RolloutStorage
2727
int64_t hidden_state_size,
2828
torch::Device device);
2929

30+
// Builds a single storage out of several existing ones: each stored tensor
// (observations, hidden states, rewards, value predictions, returns, action
// log probs, actions, masks) is read from every entry of individual_storages
// and concatenated along dimension 1. `device` is recorded as this storage's
// device and the step counter starts at 0. Every pointer is dereferenced, so
// none may be null.
RolloutStorage(std::vector<RolloutStorage *> individual_storages, torch::Device device);
31+
3032
void after_update();
3133
void compute_returns(torch::Tensor next_value,
3234
bool use_gae,

src/storage.cpp

Lines changed: 126 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ RolloutStorage::RolloutStorage(int64_t num_steps,
2121
: device(device), num_steps(num_steps), step(0)
2222
{
2323
std::vector<int64_t> observations_shape{num_steps + 1, num_processes};
24-
observations_shape.insert(observations_shape.end(), obs_shape.begin(),
25-
obs_shape.end());
24+
observations_shape.insert(observations_shape.end(),
25+
obs_shape.begin(), obs_shape.end());
2626
observations = torch::zeros(observations_shape, torch::TensorOptions(device));
2727
hidden_states = torch::zeros({num_steps + 1, num_processes,
2828
hidden_state_size},
@@ -48,6 +48,59 @@ RolloutStorage::RolloutStorage(int64_t num_steps,
4848
masks = torch::ones({num_steps + 1, num_processes, 1}, torch::TensorOptions(device));
4949
}
5050

51+
RolloutStorage::RolloutStorage(std::vector<RolloutStorage *> individual_storages,
52+
torch::Device device)
53+
: device(device), step(0)
54+
{
55+
std::vector<torch::Tensor> observations_vec;
56+
std::transform(individual_storages.begin(), individual_storages.end(),
57+
std::back_inserter(observations_vec),
58+
[](RolloutStorage *storage) { return storage->get_observations(); });
59+
observations = torch::cat(observations_vec, 1);
60+
61+
std::vector<torch::Tensor> hidden_states_vec;
62+
std::transform(individual_storages.begin(), individual_storages.end(),
63+
std::back_inserter(hidden_states_vec),
64+
[](RolloutStorage *storage) { return storage->get_hidden_states(); });
65+
hidden_states = torch::cat(hidden_states_vec, 1);
66+
67+
std::vector<torch::Tensor> rewards_vec;
68+
std::transform(individual_storages.begin(), individual_storages.end(),
69+
std::back_inserter(rewards_vec),
70+
[](RolloutStorage *storage) { return storage->get_rewards(); });
71+
rewards = torch::cat(rewards_vec, 1);
72+
73+
std::vector<torch::Tensor> value_predictions_vec;
74+
std::transform(individual_storages.begin(), individual_storages.end(),
75+
std::back_inserter(value_predictions_vec),
76+
[](RolloutStorage *storage) { return storage->get_value_predictions(); });
77+
value_predictions = torch::cat(value_predictions_vec, 1);
78+
79+
std::vector<torch::Tensor> returns_vec;
80+
std::transform(individual_storages.begin(), individual_storages.end(),
81+
std::back_inserter(returns_vec),
82+
[](RolloutStorage *storage) { return storage->get_returns(); });
83+
returns = torch::cat(returns_vec, 1);
84+
85+
std::vector<torch::Tensor> action_log_probs_vec;
86+
std::transform(individual_storages.begin(), individual_storages.end(),
87+
std::back_inserter(action_log_probs_vec),
88+
[](RolloutStorage *storage) { return storage->get_action_log_probs(); });
89+
action_log_probs = torch::cat(action_log_probs_vec, 1);
90+
91+
std::vector<torch::Tensor> actions_vec;
92+
std::transform(individual_storages.begin(), individual_storages.end(),
93+
std::back_inserter(actions_vec),
94+
[](RolloutStorage *storage) { return storage->get_actions(); });
95+
actions = torch::cat(actions_vec, 1);
96+
97+
std::vector<torch::Tensor> masks_vec;
98+
std::transform(individual_storages.begin(), individual_storages.end(),
99+
std::back_inserter(masks_vec),
100+
[](RolloutStorage *storage) { return storage->get_masks(); });
101+
masks = torch::cat(masks_vec, 1);
102+
}
103+
51104
void RolloutStorage::after_update()
52105
{
53106
observations[0].copy_(observations[-1]);
@@ -381,7 +434,7 @@ TEST_CASE("RolloutStorage")
381434
}
382435
}
383436

384-
SUBCASE("after_update() copies last observation, hidden state and mask to "
437+
SUBCASE("after_update() copies last observation, moves hidden state and mask to "
385438
"the 0th timestep")
386439
{
387440
RolloutStorage storage(3, 2, {3}, ActionSpace{"Discrete", {3}}, 2, torch::kCPU);
@@ -445,5 +498,75 @@ TEST_CASE("RolloutStorage")
445498
auto generator = storage.recurrent_generator(torch::rand({3, 5, 1}), 5);
446499
generator->next();
447500
}
501+
502+
SUBCASE("Can combine multiple storages into one")
{
    // Per-step value prediction / reward and mask written into every storage,
    // in insertion order.
    const float step_values[3] = {1, 2, 3};
    const float step_masks[3] = {1, 0, 1};

    // Five single-process storages, each filled with three steps.
    std::vector<RolloutStorage> storages;
    for (int storage_index = 0; storage_index < 5; ++storage_index)
    {
        storages.push_back({3, 1, {4}, ActionSpace{"Discrete", {3}}, 5, torch::kCPU});
        RolloutStorage &current = storages[storage_index];

        for (int step_index = 0; step_index < 3; ++step_index)
        {
            std::vector<float> value_preds{step_values[step_index]};
            std::vector<float> rewards{step_values[step_index]};
            std::vector<float> masks{step_masks[step_index]};
            current.insert(torch::zeros({1, 4}),
                           torch::zeros({1, 5}),
                           torch::zeros({1, 1}),
                           torch::zeros({1, 1}),
                           torch::from_blob(value_preds.data(), {1, 1}),
                           torch::from_blob(rewards.data(), {1, 1}),
                           torch::from_blob(masks.data(), {1, 1}));
        }
    }

    // The combining constructor takes raw pointers to the storages.
    std::vector<RolloutStorage *> pointers;
    for (RolloutStorage &storage : storages)
    {
        pointers.push_back(&storage);
    }

    RolloutStorage combined_storage(pointers, torch::kCPU);

    // Dimension 1 should now hold one entry per combined storage (5).
    const torch::Tensor observations = combined_storage.get_observations();
    CHECK(observations.size(0) == 4);
    CHECK(observations.size(1) == 5);

    const torch::Tensor hidden_states = combined_storage.get_hidden_states();
    CHECK(hidden_states.size(0) == 4);
    CHECK(hidden_states.size(1) == 5);
    CHECK(hidden_states.size(2) == 5);

    const torch::Tensor rewards = combined_storage.get_rewards();
    CHECK(rewards.size(0) == 3);
    CHECK(rewards.size(1) == 5);
    CHECK(rewards.size(2) == 1);

    const torch::Tensor value_predictions = combined_storage.get_value_predictions();
    CHECK(value_predictions.size(0) == 4);
    CHECK(value_predictions.size(1) == 5);
    CHECK(value_predictions.size(2) == 1);

    const torch::Tensor returns = combined_storage.get_returns();
    CHECK(returns.size(0) == 4);
    CHECK(returns.size(1) == 5);
    CHECK(returns.size(2) == 1);

    const torch::Tensor action_log_probs = combined_storage.get_action_log_probs();
    CHECK(action_log_probs.size(0) == 3);
    CHECK(action_log_probs.size(1) == 5);
    CHECK(action_log_probs.size(2) == 1);

    const torch::Tensor actions = combined_storage.get_actions();
    CHECK(actions.size(0) == 3);
    CHECK(actions.size(1) == 5);
    CHECK(actions.size(2) == 1);

    const torch::Tensor masks = combined_storage.get_masks();
    CHECK(masks.size(0) == 4);
    CHECK(masks.size(1) == 5);
    CHECK(masks.size(2) == 1);
}
448571
}
449572
}

0 commit comments

Comments
 (0)