@@ -21,8 +21,8 @@ RolloutStorage::RolloutStorage(int64_t num_steps,
2121 : device(device), num_steps(num_steps), step(0 )
2222{
2323 std::vector<int64_t > observations_shape{num_steps + 1 , num_processes};
24- observations_shape.insert (observations_shape.end (), obs_shape. begin (),
25- obs_shape.end ());
24+ observations_shape.insert (observations_shape.end (),
25+ obs_shape.begin (), obs_shape. end ());
2626 observations = torch::zeros (observations_shape, torch::TensorOptions (device));
2727 hidden_states = torch::zeros ({num_steps + 1 , num_processes,
2828 hidden_state_size},
@@ -48,6 +48,59 @@ RolloutStorage::RolloutStorage(int64_t num_steps,
4848 masks = torch::ones ({num_steps + 1 , num_processes, 1 }, torch::TensorOptions (device));
4949}
5050
51+ RolloutStorage::RolloutStorage (std::vector<RolloutStorage *> individual_storages,
52+ torch::Device device)
53+ : device(device), step(0 )
54+ {
55+ std::vector<torch::Tensor> observations_vec;
56+ std::transform (individual_storages.begin (), individual_storages.end (),
57+ std::back_inserter (observations_vec),
58+ [](RolloutStorage *storage) { return storage->get_observations (); });
59+ observations = torch::cat (observations_vec, 1 );
60+
61+ std::vector<torch::Tensor> hidden_states_vec;
62+ std::transform (individual_storages.begin (), individual_storages.end (),
63+ std::back_inserter (hidden_states_vec),
64+ [](RolloutStorage *storage) { return storage->get_hidden_states (); });
65+ hidden_states = torch::cat (hidden_states_vec, 1 );
66+
67+ std::vector<torch::Tensor> rewards_vec;
68+ std::transform (individual_storages.begin (), individual_storages.end (),
69+ std::back_inserter (rewards_vec),
70+ [](RolloutStorage *storage) { return storage->get_rewards (); });
71+ rewards = torch::cat (rewards_vec, 1 );
72+
73+ std::vector<torch::Tensor> value_predictions_vec;
74+ std::transform (individual_storages.begin (), individual_storages.end (),
75+ std::back_inserter (value_predictions_vec),
76+ [](RolloutStorage *storage) { return storage->get_value_predictions (); });
77+ value_predictions = torch::cat (value_predictions_vec, 1 );
78+
79+ std::vector<torch::Tensor> returns_vec;
80+ std::transform (individual_storages.begin (), individual_storages.end (),
81+ std::back_inserter (returns_vec),
82+ [](RolloutStorage *storage) { return storage->get_returns (); });
83+ returns = torch::cat (returns_vec, 1 );
84+
85+ std::vector<torch::Tensor> action_log_probs_vec;
86+ std::transform (individual_storages.begin (), individual_storages.end (),
87+ std::back_inserter (action_log_probs_vec),
88+ [](RolloutStorage *storage) { return storage->get_action_log_probs (); });
89+ action_log_probs = torch::cat (action_log_probs_vec, 1 );
90+
91+ std::vector<torch::Tensor> actions_vec;
92+ std::transform (individual_storages.begin (), individual_storages.end (),
93+ std::back_inserter (actions_vec),
94+ [](RolloutStorage *storage) { return storage->get_actions (); });
95+ actions = torch::cat (actions_vec, 1 );
96+
97+ std::vector<torch::Tensor> masks_vec;
98+ std::transform (individual_storages.begin (), individual_storages.end (),
99+ std::back_inserter (masks_vec),
100+ [](RolloutStorage *storage) { return storage->get_masks (); });
101+ masks = torch::cat (masks_vec, 1 );
102+ }
103+
51104void RolloutStorage::after_update ()
52105{
53106 observations[0 ].copy_ (observations[-1 ]);
@@ -381,7 +434,7 @@ TEST_CASE("RolloutStorage")
381434 }
382435 }
383436
384- SUBCASE (" after_update() copies last observation, hidden state and mask to "
437+ SUBCASE (" after_update() copies last observation, moves hidden state and mask to "
385438 " the 0th timestep" )
386439 {
387440 RolloutStorage storage (3 , 2 , {3 }, ActionSpace{" Discrete" , {3 }}, 2 , torch::kCPU );
@@ -445,5 +498,75 @@ TEST_CASE("RolloutStorage")
445498 auto generator = storage.recurrent_generator (torch::rand ({3 , 5 , 1 }), 5 );
446499 generator->next ();
447500 }
501+
502+ SUBCASE (" Can combine multiple storages into one" )
503+ {
504+ std::vector<RolloutStorage> storages;
505+ for (int i = 0 ; i < 5 ; ++i)
506+ {
507+ storages.push_back ({3 , 1 , {4 }, ActionSpace{" Discrete" , {3 }}, 5 , torch::kCPU });
508+
509+ std::vector<float > value_preds{1 };
510+ std::vector<float > rewards{1 };
511+ std::vector<float > masks{1 };
512+ storages[i].insert (torch::zeros ({1 , 4 }),
513+ torch::zeros ({1 , 5 }),
514+ torch::zeros ({1 , 1 }),
515+ torch::zeros ({1 , 1 }),
516+ torch::from_blob (value_preds.data (), {1 , 1 }),
517+ torch::from_blob (rewards.data (), {1 , 1 }),
518+ torch::from_blob (masks.data (), {1 , 1 }));
519+ value_preds = {2 };
520+ rewards = {2 };
521+ masks = {0 };
522+ storages[i].insert (torch::zeros ({1 , 4 }),
523+ torch::zeros ({1 , 5 }),
524+ torch::zeros ({1 , 1 }),
525+ torch::zeros ({1 , 1 }),
526+ torch::from_blob (value_preds.data (), {1 , 1 }),
527+ torch::from_blob (rewards.data (), {1 , 1 }),
528+ torch::from_blob (masks.data (), {1 , 1 }));
529+ value_preds = {3 };
530+ rewards = {3 };
531+ masks = {1 };
532+ storages[i].insert (torch::zeros ({1 , 4 }),
533+ torch::zeros ({1 , 5 }),
534+ torch::zeros ({1 , 1 }),
535+ torch::zeros ({1 , 1 }),
536+ torch::from_blob (value_preds.data (), {1 , 1 }),
537+ torch::from_blob (rewards.data (), {1 , 1 }),
538+ torch::from_blob (masks.data (), {1 , 1 }));
539+ }
540+
541+ std::vector<RolloutStorage *> pointers;
542+ std::transform (storages.begin (), storages.end (), std::back_inserter (pointers),
543+ [](RolloutStorage &storage) { return &storage; });
544+
545+ RolloutStorage combined_storage (pointers, torch::kCPU );
546+
547+ CHECK (combined_storage.get_observations ().size (0 ) == 4 );
548+ CHECK (combined_storage.get_observations ().size (1 ) == 5 );
549+ CHECK (combined_storage.get_hidden_states ().size (0 ) == 4 );
550+ CHECK (combined_storage.get_hidden_states ().size (1 ) == 5 );
551+ CHECK (combined_storage.get_hidden_states ().size (2 ) == 5 );
552+ CHECK (combined_storage.get_rewards ().size (0 ) == 3 );
553+ CHECK (combined_storage.get_rewards ().size (1 ) == 5 );
554+ CHECK (combined_storage.get_rewards ().size (2 ) == 1 );
555+ CHECK (combined_storage.get_value_predictions ().size (0 ) == 4 );
556+ CHECK (combined_storage.get_value_predictions ().size (1 ) == 5 );
557+ CHECK (combined_storage.get_value_predictions ().size (2 ) == 1 );
558+ CHECK (combined_storage.get_returns ().size (0 ) == 4 );
559+ CHECK (combined_storage.get_returns ().size (1 ) == 5 );
560+ CHECK (combined_storage.get_returns ().size (2 ) == 1 );
561+ CHECK (combined_storage.get_action_log_probs ().size (0 ) == 3 );
562+ CHECK (combined_storage.get_action_log_probs ().size (1 ) == 5 );
563+ CHECK (combined_storage.get_action_log_probs ().size (2 ) == 1 );
564+ CHECK (combined_storage.get_actions ().size (0 ) == 3 );
565+ CHECK (combined_storage.get_actions ().size (1 ) == 5 );
566+ CHECK (combined_storage.get_actions ().size (2 ) == 1 );
567+ CHECK (combined_storage.get_masks ().size (0 ) == 4 );
568+ CHECK (combined_storage.get_masks ().size (1 ) == 5 );
569+ CHECK (combined_storage.get_masks ().size (2 ) == 1 );
570+ }
448571}
449572}
0 commit comments