Realm backend #1592

Draft: wants to merge 69 commits into base: master
Changes from 1 commit
Commits
69 commits
6adb290
temporary weight adjust index
reyna-abhyankar Aug 25, 2024
61697c2
Loss function
reyna-abhyankar Aug 27, 2024
b56c046
Add cuda test for loss function
reyna-abhyankar Aug 27, 2024
f75a3d4
Format
reyna-abhyankar Aug 27, 2024
f74711f
Refactor and build optimizer kernels, op
reyna-abhyankar Aug 27, 2024
40c6252
Finish optimizer local backing
reyna-abhyankar Aug 27, 2024
ad9b9ea
Format
reyna-abhyankar Aug 27, 2024
1ddfade
E2E update test
reyna-abhyankar Aug 27, 2024
dde9496
Format
reyna-abhyankar Aug 27, 2024
59635d8
Small fixes
reyna-abhyankar Sep 11, 2024
103ef07
Format
reyna-abhyankar Sep 11, 2024
f48f9ff
Fix test and small issues
reyna-abhyankar Sep 18, 2024
189c9c8
Format
reyna-abhyankar Sep 18, 2024
d93f464
Merge branch 'repo-refactor' into local-e2e-training
reyna-abhyankar Oct 1, 2024
b5647c8
Pass tests after merge
reyna-abhyankar Oct 1, 2024
f5ff91e
Fix input/weight differentiation
reyna-abhyankar Oct 1, 2024
7470e71
Fix signature to use unified rep
reyna-abhyankar Oct 1, 2024
deece1b
Fix model training instance abstraction
reyna-abhyankar Oct 1, 2024
1d3cc94
Change subcase test name
reyna-abhyankar Oct 1, 2024
3cf5d08
Quick fixes
reyna-abhyankar Oct 16, 2024
79ef4c9
Refactor training backing and instance
reyna-abhyankar Oct 22, 2024
a73b1c3
Expose op folders publicly
reyna-abhyankar Nov 13, 2024
c6fed29
Add tensor type, operate over reduced tensor
reyna-abhyankar Nov 13, 2024
0cdfb1a
Fixes
reyna-abhyankar Jan 7, 2025
9d252b3
Remove tensor lower
reyna-abhyankar Jan 15, 2025
895c117
Add tensor and task lowering scheme
reyna-abhyankar Jan 17, 2025
66d61eb
feat: add realm-backend subdir
chenzhuofu Jan 21, 2025
8d0cfec
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Jan 21, 2025
411017d
Build local exec
reyna-abhyankar Jan 22, 2025
759abdd
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Jan 22, 2025
bcd1408
chore: duplicate some files from local-execution
chenzhuofu Jan 22, 2025
5e11568
Merge branch 'master' of github.com:flexflow/flexflow-train into real…
chenzhuofu Jan 28, 2025
1c55cf7
Merge branch 'master' of github.com:flexflow/flexflow-train into real…
chenzhuofu Jan 28, 2025
b9144ad
chore: update legion
chenzhuofu Jan 30, 2025
66647a2
feat: add legion related code
chenzhuofu Jan 30, 2025
0128abb
Disaggregate local backend
reyna-abhyankar Feb 1, 2025
277f8c2
Update task binding interface and cost estimator
reyna-abhyankar Feb 1, 2025
377c6aa
Merge master into local execution
reyna-abhyankar Feb 4, 2025
6f689a4
feat: add Future wrapper for func result
chenzhuofu Feb 5, 2025
fe2bc21
feat: add realm-backend draft impl
chenzhuofu Feb 5, 2025
8efaec7
Build
reyna-abhyankar Feb 6, 2025
1dc1398
Format
reyna-abhyankar Feb 6, 2025
17ad5c8
Split task spec files
reyna-abhyankar Feb 6, 2025
639c2c1
Delete outdated sim environment file
reyna-abhyankar Feb 6, 2025
c408ebb
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Feb 8, 2025
a697044
Finish API
reyna-abhyankar Feb 13, 2025
187a8d5
Add tests for allocated and unallocated
reyna-abhyankar Feb 13, 2025
a0f8113
Fix nonnegative
reyna-abhyankar Feb 13, 2025
b1eab94
Format
reyna-abhyankar Feb 13, 2025
b532c50
Pass allocated-unallocated tests
reyna-abhyankar Feb 13, 2025
f28e5c2
Update task registry tests
reyna-abhyankar Feb 13, 2025
7887183
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Feb 16, 2025
9c16d76
feat: intial implementation of realm-backend
chenzhuofu Feb 19, 2025
89752fa
Move local tensor backing to dtgen
reyna-abhyankar Feb 22, 2025
aef8ad5
Remove lowered tensor source
reyna-abhyankar Feb 22, 2025
f0a4285
Loss and update tests
reyna-abhyankar Feb 24, 2025
9047edc
Merge master
reyna-abhyankar Feb 24, 2025
350babf
Passing tests after merge issues
reyna-abhyankar Feb 24, 2025
aef7c6e
Pass gpu tests
reyna-abhyankar Feb 25, 2025
6c84fb3
chore: fix typo
chenzhuofu Feb 26, 2025
d6aa7ad
chore: update realm allocator impl
chenzhuofu Feb 27, 2025
419cca8
chore: eliminate std::optional<float>
chenzhuofu Mar 3, 2025
2c0b573
feat: buildable realm-backend
chenzhuofu Mar 5, 2025
ebe06cf
Merge commit 'aef8ad58196f7b7f724fc7f0a1a65af24ee12acd' of github.com…
chenzhuofu Mar 5, 2025
062825e
chore: Move realm tensor backing to dtgen
chenzhuofu Mar 5, 2025
d82fa2a
Merge commit '350babf3584c3d99e76e4dc0f72a658aa0222afc' of github.com…
chenzhuofu Mar 5, 2025
7c53bb3
chore: minor
chenzhuofu Mar 5, 2025
403ec78
Merge commit 'aef7c6e3c3087f15b4c90792148f170da84f6f7c' of github.com…
chenzhuofu Mar 5, 2025
bf57d1d
chore: remove deprecated file
chenzhuofu Mar 5, 2025
Pass allocated-unallocated tests
reyna-abhyankar committed Feb 13, 2025
commit b532c5023861ea8f0391c0aef4dc86e42cda0d22
6 changes: 6 additions & 0 deletions lib/kernels/src/legion_dim.cc
@@ -13,10 +13,16 @@
ff_dim.value.unwrap_nonnegative() - 1}};
}

ff_dim_t legion_dim_from_ff_dim(legion_dim_t legion_dim,
nonnegative_int num_dimensions) {
return ff_dim_t{nonnegative_int{num_dimensions.unwrap_nonnegative() -
legion_dim.value.unwrap_nonnegative() - 1}};
}

ff_dim_t ff_dim_from_legion_dim(legion_dim_t legion_dim,
nonnegative_int num_dimensions) {
return ff_dim_t{nonnegative_int{num_dimensions.unwrap_nonnegative() -
legion_dim.value.unwrap_nonnegative() - 1}};
}

} // namespace FlexFlow
@@ -11,6 +11,8 @@ struct GradientTensorSource {

gradient_tensor_t new_gradient_tensor();

void reset();

private:
static size_t next_available_gradient_tensor_id;
};
@@ -11,6 +11,8 @@ struct OptimizerTensorSource {

optimizer_tensor_t new_optimizer_tensor();

void reset();

private:
static size_t next_available_optimizer_tensor_id;
};
4 changes: 4 additions & 0 deletions lib/local-execution/src/gradient_tensor_source.cc
@@ -11,4 +11,8 @@ gradient_tensor_t GradientTensorSource::new_gradient_tensor() {
GradientTensorSource::next_available_gradient_tensor_id++};
}

void GradientTensorSource::reset() {
GradientTensorSource::next_available_gradient_tensor_id = 0;
}

} // namespace FlexFlow
11 changes: 7 additions & 4 deletions lib/local-execution/src/local_tensor_backing.cc
@@ -11,104 +11,104 @@

namespace FlexFlow {

LocalTensorBacking::LocalTensorBacking(
AllocatedTensors const &allocated_tensors,
UnallocatedTensors const &unallocated_tensors,
Allocator const &allocator)
: tensor_gradient_mapping(allocated_tensors.gradient_mapping),
tensor_optimizer_mapping(allocated_tensors.optimizer_mapping),
allocator(allocator) {

// handle already-allocated tensors
for (std::pair<TensorTypeVariant, GenericTensorAccessorW> const
&tensor_type_backing : allocated_tensors.tensor_type_backings) {
lowered_tensor_t lowered_tensor =
this->insert_tensor(tensor_type_backing.first);
this->tensor_backings.insert({lowered_tensor, tensor_type_backing.second});
}

// allocate new tensors
this->tensor_gradient_mapping.insert(
unallocated_tensors.gradient_mapping.begin(),
unallocated_tensors.gradient_mapping.end());

for (std::pair<tensor_guid_t, std::vector<optimizer_tensor_t>> const
&unallocated_optimizer_tensors :
unallocated_tensors.optimizer_mapping) {
if (this->tensor_optimizer_mapping.count(
unallocated_optimizer_tensors.first)) {
for (optimizer_tensor_t const &optimizer_tensor :
unallocated_optimizer_tensors.second) {
this->tensor_optimizer_mapping[unallocated_optimizer_tensors.first]
.push_back(optimizer_tensor);
}
} else {
this->tensor_optimizer_mapping.insert({unallocated_optimizer_tensors});
}
}

for (std::pair<TensorTypeVariant, TensorShape> const &tensor_type_shape :
unallocated_tensors.tensor_type_shapes) {
lowered_tensor_t lowered_tensor =
this->insert_tensor(tensor_type_shape.first);
GenericTensorAccessorW tensor_backing =
this->allocator.allocate_tensor(tensor_type_shape.second);
this->tensor_backings.insert({lowered_tensor, tensor_backing});
}
};

lowered_tensor_t
LocalTensorBacking::insert_tensor(TensorTypeVariant const &tensor_type) {
lowered_tensor_t lowered_tensor =
this->lowered_tensor_source.new_lowered_tensor();
tensor_type.visit<std::nullopt_t>(overload{
[&](tensor_guid_t const &tensor_guid) {
this->tensor_lowering_mapping.insert({tensor_guid, lowered_tensor});
return std::nullopt;
},
[&](gradient_tensor_t const &gradient_tensor) {
this->gradient_tensor_lowering_mapping.insert(
{gradient_tensor, lowered_tensor});
return std::nullopt;
},
[&](optimizer_tensor_t const &optimizer_tensor) {
this->optimizer_tensor_lowering_mapping.insert(
{optimizer_tensor, lowered_tensor});
return std::nullopt;
},
[&](loss_tensor_t const &loss_tensor) {
this->loss_tensor_lowering_mapping.insert(
{loss_tensor, lowered_tensor});
return std::nullopt;
},
[&](auto const &any_tensor) {
throw mk_runtime_error(
fmt::format("Unhandled tensor type {}", any_tensor));
}});
return lowered_tensor;
}

GenericTensorAccessorW
LocalTensorBacking::get_tensor(TensorTypeVariant const &tensor_type) const {
lowered_tensor_t lowered_tensor =
tensor_type.visit<lowered_tensor_t>(overload{
[&](tensor_guid_t const &tensor_guid) {
return this->tensor_lowering_mapping.at(tensor_guid);
},
[&](gradient_tensor_t const &gradient_tensor) {
return this->gradient_tensor_lowering_mapping.at(gradient_tensor);
},
[&](optimizer_tensor_t const &optimizer_tensor) {
return this->optimizer_tensor_lowering_mapping.at(optimizer_tensor);
},
[&](loss_tensor_t const &loss_tensor) {
return this->loss_tensor_lowering_mapping.at(loss_tensor);
},
[&](auto const &any_tensor) {
throw mk_runtime_error(
fmt::format("Unhandled tensor type {}", any_tensor));
}});
return this->tensor_backings.at(lowered_tensor);
}

UnallocatedTensors generate_unallocated_tensors(
@@ -135,7 +135,7 @@
gradient_tensor_t gradient_tensor =
gradient_tensor_source.new_gradient_tensor();
tensor_type_shapes.insert(
{TensorTypeVariant{tensor_guid}, tensor_attrs.shape});
{TensorTypeVariant{gradient_tensor}, tensor_attrs.shape});
gradient_mapping.insert({tensor_guid, gradient_tensor});
}
}
@@ -168,8 +168,7 @@
tensor_attrs_mapping) {
tensor_guid_t tensor_guid = tensor_guid_attrs.first;
TensorAttrs tensor_attrs = tensor_guid_attrs.second;
if (tensor_attrs.create_gradients == CreateGrad::YES &&
!allocated_tensors.optimizer_mapping.count(tensor_guid)) {
if (tensor_attrs.create_gradients == CreateGrad::YES) {
std::vector<optimizer_tensor_t> optimizer_tensors;

int num_optimizer_tensors_to_allocate =
@@ -178,6 +177,7 @@
num_optimizer_tensors_to_allocate -=
allocated_tensors.optimizer_mapping.at(tensor_guid).size();
}
std::cout << num_optimizer_tensors_to_allocate;

for (int i = 0; i < num_optimizer_tensors_to_allocate; ++i) {
optimizer_tensor_t optimizer_tensor =
@@ -186,7 +186,10 @@
tensor_type_shapes.insert(
{TensorTypeVariant{optimizer_tensor}, tensor_attrs.shape});
}
optimizer_mapping.insert({tensor_guid, optimizer_tensors});

if (num_optimizer_tensors_to_allocate > 0) {
optimizer_mapping.insert({tensor_guid, optimizer_tensors});
}
}
}

@@ -194,18 +197,18 @@
tensor_type_shapes, gradient_mapping, optimizer_mapping};
}

TensorSlotsBacking construct_tensor_slots_backing(
LocalTensorBacking const &local_tensor_backing,
TaskBinding const &binding) {
TensorSlotsBacking mapping;

for (std::pair<SlotTensorTypeId, TensorTypeVariant> const &tensor_binding :
binding.get_tensor_bindings()) {
mapping.insert({tensor_binding.first,
local_tensor_backing.get_tensor(tensor_binding.second)});
}

return mapping;
}

} // namespace FlexFlow
3 changes: 1 addition & 2 deletions lib/local-execution/src/local_training_backing.cc
@@ -19,227 +19,226 @@
AllocatedTensors const &allocated_tensors,
ComputationGraph const &computation_graph,
RuntimeArgConfig const &runtime_arg_config)
: computation_graph(computation_graph),
task_registry(construct_task_registry(computation_graph)),
local_tensor_backing(allocated_tensors,
generate_unallocated_tensors(
allocated_tensors,
get_all_tensor_attrs(this->computation_graph),
this->gradient_tensor_source),
allocator),
local_args_backing(initialize_args_backing(this->task_registry,
this->computation_graph,
runtime_arg_config,
this->local_tensor_backing)){};

LocalTrainingBacking::LocalTrainingBacking(
Allocator const &allocator,
AllocatedTensors const &allocated_tensors,
ComputationGraph const &computation_graph,
RuntimeArgConfig const &runtime_arg_config,
OptimizerAttrs const &optimizer_attrs)
: computation_graph(computation_graph),
task_registry(construct_task_registry(computation_graph)),
local_tensor_backing(allocated_tensors,
generate_unallocated_tensors_with_optimizer(
allocated_tensors,
get_all_tensor_attrs(this->computation_graph),
this->gradient_tensor_source,
this->optimizer_tensor_source,
optimizer_attrs),
allocator),
local_args_backing(initialize_args_backing(this->task_registry,
this->computation_graph,
runtime_arg_config,
this->local_tensor_backing)){};

LocalArgsBacking
initialize_args_backing(TaskRegistry const &task_registry,
ComputationGraph const &cg,
RuntimeArgConfig const &runtime_arg_config,
LocalTensorBacking const &local_tensor_backing) {
std::unordered_map<layer_guid_t, DeviceSpecificDeviceStates>
per_device_op_states;
for (layer_guid_t const &node : topological_ordering(cg)) {
if (registry_contains_task_for_layer(
task_registry, node, OpTaskType::INIT)) {
ComputationGraphOpAttrs attrs = get_layer_attrs(cg, node).attrs;

TaskInvocation invocation =
lower_to_task_invocation(init(attrs),
node,
get_incoming_inputs(cg, node),
get_incoming_input_shapes(cg, node),
get_outgoing_tensors(cg, node),
get_incoming_weights(cg, node),
local_tensor_backing.tensor_gradient_mapping,
std::nullopt);
TaskArgumentAccessor accessor = get_task_arg_accessor(
local_tensor_backing,
make_args_backing_with_empty_device_states(runtime_arg_config),
invocation);
TaskSignatureAndImpl task_sig_impl =
task_registry.task_mapping.at(invocation.task_id);
auto fn = task_sig_impl.impl_function.get<InitOpTaskImplFunction>()
.function_ptr;
DeviceSpecificDeviceStates device_state = fn(accessor);
per_device_op_states.insert({node, device_state});
}
}

return LocalArgsBacking{runtime_arg_config, per_device_op_states};
}

std::optional<float> call_task_impl(TaskRegistry const &task_registry,
task_id_t const &task_id,
TaskArgumentAccessor const &acc) {
TaskSignatureAndImpl task_sig_impl = task_registry.task_mapping.at(task_id);
auto fn =
task_sig_impl.impl_function.get<FwdBwdOpTaskImplFunction>().function_ptr;
return fn(acc);
}

std::optional<float>
execute_forward(LocalTrainingBacking const &local_training_backing,
layer_guid_t const &operator_node,
Allocator &allocator) {
layer_guid_t const &operator_node) {
if (registry_contains_task_for_layer(local_training_backing.task_registry,
operator_node,
OpTaskType::FWD)) {
ComputationGraphOpAttrs attrs =
get_layer_attrs(local_training_backing.computation_graph, operator_node)
.attrs;

std::optional<DeviceSpecificDeviceStates> device_state =
get_per_device_op_state_if_exists(
local_training_backing.local_args_backing, operator_node);
TaskInvocation invocation = lower_to_task_invocation(
forward(attrs),
operator_node,
get_incoming_inputs(local_training_backing.computation_graph,
operator_node),
get_incoming_input_shapes(local_training_backing.computation_graph,
operator_node),
get_outgoing_tensors(local_training_backing.computation_graph,
operator_node),
get_incoming_weights(local_training_backing.computation_graph,
operator_node),
local_training_backing.local_tensor_backing.tensor_gradient_mapping,
device_state);
TaskArgumentAccessor accessor =
get_task_arg_accessor(local_training_backing.local_tensor_backing,
local_training_backing.local_args_backing,
invocation);
return call_task_impl(
local_training_backing.task_registry, invocation.task_id, accessor);
} else {
return std::nullopt;
}
}

void compute_loss(LocalTrainingBacking const &local_training_backing,
LossAttrs const &loss_attrs,
tensor_guid_t const &logit_tensor,
loss_tensor_t const &label_tensor) {
TaskInvocation loss_invocation = backward(
loss_attrs,
logit_tensor,
local_training_backing.local_tensor_backing.tensor_gradient_mapping.at(
logit_tensor),
label_tensor);
// TODO: https://github.com/flexflow/flexflow-train/issues/1442
// assert(is_invocation_valid(get_loss_bwd_signature(), loss_invocation));
TaskArgumentAccessor loss_accessor =
get_task_arg_accessor(local_training_backing.local_tensor_backing,
local_training_backing.local_args_backing,
loss_invocation);
TaskImplFunction loss_impl_fn = get_loss_bwd_task_impl();
loss_impl_fn.get<GenericTaskImplFunction>().function_ptr(loss_accessor);
}

std::optional<float>
execute_backward(LocalTrainingBacking const &local_training_backing,
layer_guid_t const &operator_node) {
if (registry_contains_task_for_layer(local_training_backing.task_registry,
operator_node,
OpTaskType::BWD)) {
ComputationGraphOpAttrs attrs =
get_layer_attrs(local_training_backing.computation_graph, operator_node)
.attrs;

std::optional<DeviceSpecificDeviceStates> device_state =
get_per_device_op_state_if_exists(
local_training_backing.local_args_backing, operator_node);
TaskInvocation invocation = lower_to_task_invocation(
backward(attrs),
operator_node,
get_incoming_inputs(local_training_backing.computation_graph,
operator_node),
get_incoming_input_shapes(local_training_backing.computation_graph,
operator_node),
get_outgoing_tensors(local_training_backing.computation_graph,
operator_node),
get_incoming_weights(local_training_backing.computation_graph,
operator_node),
local_training_backing.local_tensor_backing.tensor_gradient_mapping,
device_state);
TaskArgumentAccessor accessor =
get_task_arg_accessor(local_training_backing.local_tensor_backing,
local_training_backing.local_args_backing,
invocation);
return call_task_impl(
local_training_backing.task_registry, invocation.task_id, accessor);
} else {
return std::nullopt;
}
}

void execute_update(LocalTrainingBacking const &local_training_backing,
layer_guid_t const &node,
OptimizerAttrs const &optimizer_attrs) {
LayerAttrs layer_attrs =
get_layer_attrs(local_training_backing.computation_graph, node);
if (layer_attrs.attrs.has<WeightAttrs>()) {
// get tensors
tensor_guid_t weight_tensor = get_only(
get_outgoing_tensors(local_training_backing.computation_graph, node));

gradient_tensor_t weight_grad_tensor =
local_training_backing.local_tensor_backing.tensor_gradient_mapping.at(
weight_tensor);
std::vector<optimizer_tensor_t> optimizer_buffer_tensors =
local_training_backing.local_tensor_backing.tensor_optimizer_mapping.at(
weight_tensor);

// get invocation
TaskInvocation invocation = get_update_invocation(optimizer_attrs,
weight_tensor,
weight_grad_tensor,
optimizer_buffer_tensors);

// TODO: https://github.com/flexflow/flexflow-train/issues/1442
// assert(is_invocation_valid(get_update_signature(attrs), invocation));

// execute update
TaskArgumentAccessor accessor =
get_task_arg_accessor(local_training_backing.local_tensor_backing,
local_training_backing.local_args_backing,
invocation);
TaskImplFunction update_impl_fn = get_update_task_impl(optimizer_attrs);
update_impl_fn.get<GenericTaskImplFunction>().function_ptr(accessor);
}
}

TaskArgumentAccessor
get_task_arg_accessor(LocalTensorBacking const &local_tensor_backing,
LocalArgsBacking const &local_args_backing,
TaskInvocation const &invocation) {
TensorSlotsBacking tensor_slots_backing =
construct_tensor_slots_backing(local_tensor_backing, invocation.binding);
ArgSlotsBacking arg_slots_backing = construct_arg_slots_backing(
invocation.binding, local_args_backing.runtime_arg_config);
return TaskArgumentAccessor::create<LocalTaskArgumentAccessor>(
local_tensor_backing.allocator, tensor_slots_backing, arg_slots_backing);
}

} // namespace FlexFlow
4 changes: 4 additions & 0 deletions lib/local-execution/src/optimizer_tensor_source.cc
@@ -11,4 +11,8 @@ optimizer_tensor_t OptimizerTensorSource::new_optimizer_tensor() {
OptimizerTensorSource::next_available_optimizer_tensor_id++};
}

void OptimizerTensorSource::reset() {
OptimizerTensorSource::next_available_optimizer_tensor_id = 0;
}

} // namespace FlexFlow
128 changes: 46 additions & 82 deletions lib/local-execution/test/src/test_unallocated_tensors.cc
@@ -17,15 +17,26 @@ using namespace ::FlexFlow;
TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("UnallocatedTensors") {
MockTensorGuidSource tensor_guid_source;
GradientTensorSource gradient_tensor_source;
OptimizerTensorSource optimizer_tensor_source;

gradient_tensor_source.reset();
optimizer_tensor_source.reset();

Allocator allocator = create_local_cpu_memory_allocator();

tensor_guid_t mock_tensor_1 = tensor_guid_source.new_mock_tensor_guid();
tensor_guid_t mock_tensor_2 = tensor_guid_source.new_mock_tensor_guid();
tensor_guid_t mock_tensor_3_with_grad =
tensor_guid_source.new_mock_tensor_guid();

gradient_tensor_t grad_tensor =
gradient_tensor_source.new_gradient_tensor();
optimizer_tensor_t optimizer_tensor_1 =
optimizer_tensor_source.new_optimizer_tensor();
optimizer_tensor_t optimizer_tensor_2 =
optimizer_tensor_source.new_optimizer_tensor();

TensorAttrs tensor_attrs_1_no_grad = TensorAttrs{
TensorShape{TensorDims{FFOrdered<nonnegative_int>{16_n, 10_n}},
DataType::FLOAT},
@@ -61,13 +72,10 @@ TEST_SUITE(FF_TEST_SUITE) {
SUBCASE("Without optimizer") {
SUBCASE("AllocatedTensors is empty") {
AllocatedTensors empty = AllocatedTensors{{}, {}, {}};
GradientTensorSource gradient_tensor_source;
gradient_tensor_source.reset();
UnallocatedTensors result = generate_unallocated_tensors(
empty, tensor_attrs_mapping, gradient_tensor_source);

GradientTensorSource mock_gradient_tensor_source;
gradient_tensor_t grad_tensor =
mock_gradient_tensor_source.new_gradient_tensor();
std::unordered_map<TensorTypeVariant, TensorShape>
correct_tensor_type_shapes = {
{TensorTypeVariant{mock_tensor_1},
@@ -93,15 +101,12 @@ TEST_SUITE(FF_TEST_SUITE) {
},
{},
{}};
GradientTensorSource gradient_tensor_source;

gradient_tensor_source.reset();
UnallocatedTensors result =
generate_unallocated_tensors(allocated_forward_tensors,
tensor_attrs_mapping,
gradient_tensor_source);

GradientTensorSource mock_gradient_tensor_source;
gradient_tensor_t grad_tensor =
mock_gradient_tensor_source.new_gradient_tensor();
std::unordered_map<TensorTypeVariant, TensorShape>
correct_tensor_type_shapes = {
{TensorTypeVariant{mock_tensor_2},
@@ -127,15 +132,13 @@ TEST_SUITE(FF_TEST_SUITE) {
},
{},
{}};
GradientTensorSource gradient_tensor_source;

gradient_tensor_source.reset();
UnallocatedTensors result =
generate_unallocated_tensors(allocated_forward_tensors,
tensor_attrs_mapping,
gradient_tensor_source);

GradientTensorSource mock_gradient_tensor_source;
gradient_tensor_t grad_tensor =
mock_gradient_tensor_source.new_gradient_tensor();
std::unordered_map<TensorTypeVariant, TensorShape>
correct_tensor_type_shapes = {
{TensorTypeVariant{grad_tensor},
@@ -149,9 +152,7 @@ TEST_SUITE(FF_TEST_SUITE) {
}

SUBCASE("AllocatedTensors contains only gradient tensor") {
GradientTensorSource gradient_tensor_source;
gradient_tensor_t grad_tensor =
gradient_tensor_source.new_gradient_tensor();

AllocatedTensors allocated_forward_tensors = AllocatedTensors{
{
{TensorTypeVariant{grad_tensor}, tensor_backing_3},
@@ -178,9 +179,7 @@ TEST_SUITE(FF_TEST_SUITE) {
}

SUBCASE("AllocatedTensors contains mixture") {
GradientTensorSource gradient_tensor_source;
gradient_tensor_t grad_tensor =
gradient_tensor_source.new_gradient_tensor();

AllocatedTensors allocated_forward_tensors = AllocatedTensors{
{
{TensorTypeVariant{mock_tensor_1}, tensor_backing_1},
@@ -206,9 +205,7 @@ TEST_SUITE(FF_TEST_SUITE) {
}

SUBCASE("Fully AllocatedTensors") {
GradientTensorSource gradient_tensor_source;
gradient_tensor_t grad_tensor =
gradient_tensor_source.new_gradient_tensor();

AllocatedTensors allocated_forward_tensors = AllocatedTensors{
{
{TensorTypeVariant{mock_tensor_1}, tensor_backing_1},
@@ -235,8 +232,8 @@ TEST_SUITE(FF_TEST_SUITE) {
OptimizerAttrs attrs =
OptimizerAttrs{SGDOptimizerAttrs{0.0, momentum, false, 0.0}};
AllocatedTensors empty = AllocatedTensors{{}, {}, {}};
GradientTensorSource gradient_tensor_source;
OptimizerTensorSource optimizer_tensour_source;

gradient_tensor_source.reset();
UnallocatedTensors result =
generate_unallocated_tensors_with_optimizer(
empty,
@@ -245,9 +242,9 @@ TEST_SUITE(FF_TEST_SUITE) {
optimizer_tensor_source,
attrs);

GradientTensorSource mock_gradient_tensor_source;
gradient_tensor_source.reset();
UnallocatedTensors correct = generate_unallocated_tensors(
empty, tensor_attrs_mapping, mock_gradient_tensor_source);
empty, tensor_attrs_mapping, gradient_tensor_source);
CHECK(result == correct);
}
SUBCASE("with momentum") {
@@ -257,8 +254,9 @@ TEST_SUITE(FF_TEST_SUITE) {

SUBCASE("unallocated") {
AllocatedTensors empty = AllocatedTensors{{}, {}, {}};
GradientTensorSource gradient_tensor_source;
OptimizerTensorSource optimizer_tensour_source;

gradient_tensor_source.reset();
optimizer_tensor_source.reset();
UnallocatedTensors result =
generate_unallocated_tensors_with_optimizer(
empty,
@@ -267,13 +265,6 @@ TEST_SUITE(FF_TEST_SUITE) {
optimizer_tensor_source,
attrs);

GradientTensorSource mock_gradient_tensor_source;
gradient_tensor_t grad_tensor =
mock_gradient_tensor_source.new_gradient_tensor();
OptimizerTensorSource mock_optimizer_tensour_source;
optimizer_tensor_t optimizer_tensor =
mock_optimizer_tensour_source.new_optimizer_tensor();

std::unordered_map<TensorTypeVariant, TensorShape>
correct_tensor_type_shapes = {
{TensorTypeVariant{mock_tensor_1},
@@ -284,26 +275,25 @@ TEST_SUITE(FF_TEST_SUITE) {
tensor_attrs_3_with_grad.shape},
{TensorTypeVariant{grad_tensor},
tensor_attrs_3_with_grad.shape},
{TensorTypeVariant{optimizer_tensor},
{TensorTypeVariant{optimizer_tensor_1},
tensor_attrs_3_with_grad.shape},
};
UnallocatedTensors correct = UnallocatedTensors{
correct_tensor_type_shapes,
{{mock_tensor_3_with_grad, grad_tensor}},
{{mock_tensor_3_with_grad, {optimizer_tensor}}}};
{{mock_tensor_3_with_grad, {optimizer_tensor_1}}}};

CHECK(result == correct);
}

SUBCASE("allocated") {
OptimizerTensorSource optimizer_tensour_source;
optimizer_tensor_t optimizer_tensor =
optimizer_tensour_source.new_optimizer_tensor();

AllocatedTensors allocated_optimizer_tensor = AllocatedTensors{
{{TensorTypeVariant{optimizer_tensor}, tensor_backing_3}},
{{TensorTypeVariant{optimizer_tensor_1}, tensor_backing_3}},
{},
{{mock_tensor_3_with_grad, {optimizer_tensor}}}};
GradientTensorSource gradient_tensor_source;
{{mock_tensor_3_with_grad, {optimizer_tensor_1}}}};

gradient_tensor_source.reset();
UnallocatedTensors result =
generate_unallocated_tensors_with_optimizer(
allocated_optimizer_tensor,
@@ -312,10 +302,6 @@ TEST_SUITE(FF_TEST_SUITE) {
optimizer_tensor_source,
attrs);

GradientTensorSource mock_gradient_tensor_source;
gradient_tensor_t grad_tensor =
mock_gradient_tensor_source.new_gradient_tensor();

std::unordered_map<TensorTypeVariant, TensorShape>
correct_tensor_type_shapes = {
{TensorTypeVariant{mock_tensor_1},
@@ -348,8 +334,9 @@ TEST_SUITE(FF_TEST_SUITE) {
/*epsilon=*/1e-8}};
SUBCASE("Empty") {
AllocatedTensors empty = AllocatedTensors{{}, {}, {}};
GradientTensorSource gradient_tensor_source;
OptimizerTensorSource optimizer_tensour_source;

gradient_tensor_source.reset();
optimizer_tensor_source.reset();
UnallocatedTensors result =
generate_unallocated_tensors_with_optimizer(
empty,
@@ -358,15 +345,6 @@ TEST_SUITE(FF_TEST_SUITE) {
optimizer_tensor_source,
attrs);

GradientTensorSource mock_gradient_tensor_source;
gradient_tensor_t grad_tensor =
mock_gradient_tensor_source.new_gradient_tensor();
OptimizerTensorSource mock_optimizer_tensour_source;
optimizer_tensor_t optimizer_tensor_1 =
mock_optimizer_tensour_source.new_optimizer_tensor();
optimizer_tensor_t optimizer_tensor_2 =
mock_optimizer_tensour_source.new_optimizer_tensor();

std::unordered_map<TensorTypeVariant, TensorShape>
correct_tensor_type_shapes = {
{TensorTypeVariant{mock_tensor_1},
@@ -391,14 +369,16 @@ TEST_SUITE(FF_TEST_SUITE) {
CHECK(result == correct);
}
SUBCASE("Partially allocated") {
OptimizerTensorSource optimizer_tensour_source;
optimizer_tensor_t optimizer_tensor_1 =
optimizer_tensour_source.new_optimizer_tensor();
gradient_tensor_source.reset();
optimizer_tensor_source.reset();
optimizer_tensor_t optimizer_tensor_pre_allocated =
optimizer_tensor_source.new_optimizer_tensor();
AllocatedTensors allocated_optimizer_tensor = AllocatedTensors{
{{TensorTypeVariant{optimizer_tensor_1}, tensor_backing_3}},
{{TensorTypeVariant{optimizer_tensor_pre_allocated},
tensor_backing_3}},
{},
{{mock_tensor_3_with_grad, {optimizer_tensor_1}}}};
GradientTensorSource gradient_tensor_source;
{{mock_tensor_3_with_grad, {optimizer_tensor_pre_allocated}}}};

UnallocatedTensors result =
generate_unallocated_tensors_with_optimizer(
allocated_optimizer_tensor,
@@ -407,14 +387,6 @@ TEST_SUITE(FF_TEST_SUITE) {
optimizer_tensor_source,
attrs);

GradientTensorSource mock_gradient_tensor_source;
gradient_tensor_t grad_tensor =
mock_gradient_tensor_source.new_gradient_tensor();
OptimizerTensorSource mock_optimizer_tensour_source;
optimizer_tensor_source.new_optimizer_tensor();
optimizer_tensor_t optimizer_tensor_2 =
optimizer_tensour_source.new_optimizer_tensor();

std::unordered_map<TensorTypeVariant, TensorShape>
correct_tensor_type_shapes = {
{TensorTypeVariant{mock_tensor_1},
@@ -437,18 +409,14 @@ TEST_SUITE(FF_TEST_SUITE) {
}

SUBCASE("Fully allocated") {
OptimizerTensorSource optimizer_tensour_source;
optimizer_tensor_t optimizer_tensor_1 =
optimizer_tensour_source.new_optimizer_tensor();
optimizer_tensor_t optimizer_tensor_2 =
optimizer_tensour_source.new_optimizer_tensor();
AllocatedTensors allocated_optimizer_tensor = AllocatedTensors{
{{TensorTypeVariant{optimizer_tensor_1}, tensor_backing_3},
{TensorTypeVariant{optimizer_tensor_2}, tensor_backing_3}},
{},
{{mock_tensor_3_with_grad,
{optimizer_tensor_1, optimizer_tensor_2}}}};
GradientTensorSource gradient_tensor_source;

gradient_tensor_source.reset();
UnallocatedTensors result =
generate_unallocated_tensors_with_optimizer(
allocated_optimizer_tensor,
@@ -457,10 +425,6 @@ TEST_SUITE(FF_TEST_SUITE) {
optimizer_tensor_source,
attrs);

GradientTensorSource mock_gradient_tensor_source;
gradient_tensor_t grad_tensor =
mock_gradient_tensor_source.new_gradient_tensor();

std::unordered_map<TensorTypeVariant, TensorShape>
correct_tensor_type_shapes = {
{TensorTypeVariant{mock_tensor_1},
Expand Down