Realm backend #1592

Draft
wants to merge 69 commits into base: master
Changes from 1 commit
Commits
69 commits
6adb290
temporary weight adjust index
reyna-abhyankar Aug 25, 2024
61697c2
Loss function
reyna-abhyankar Aug 27, 2024
b56c046
Add cuda test for loss function
reyna-abhyankar Aug 27, 2024
f75a3d4
Format
reyna-abhyankar Aug 27, 2024
f74711f
Refactor and build optimizer kernels, op
reyna-abhyankar Aug 27, 2024
40c6252
Finish optimizer local backing
reyna-abhyankar Aug 27, 2024
ad9b9ea
Format
reyna-abhyankar Aug 27, 2024
1ddfade
E2E update test
reyna-abhyankar Aug 27, 2024
dde9496
Format
reyna-abhyankar Aug 27, 2024
59635d8
Small fixes
reyna-abhyankar Sep 11, 2024
103ef07
Format
reyna-abhyankar Sep 11, 2024
f48f9ff
Fix test and small issues
reyna-abhyankar Sep 18, 2024
189c9c8
Format
reyna-abhyankar Sep 18, 2024
d93f464
Merge branch 'repo-refactor' into local-e2e-training
reyna-abhyankar Oct 1, 2024
b5647c8
Pass tests after merge
reyna-abhyankar Oct 1, 2024
f5ff91e
Fix input/weight differentiation
reyna-abhyankar Oct 1, 2024
7470e71
Fix signature to use unified rep
reyna-abhyankar Oct 1, 2024
deece1b
Fix model training instance abstraction
reyna-abhyankar Oct 1, 2024
1d3cc94
Change subcase test name
reyna-abhyankar Oct 1, 2024
3cf5d08
Quick fixes
reyna-abhyankar Oct 16, 2024
79ef4c9
Refactor training backing and instance
reyna-abhyankar Oct 22, 2024
a73b1c3
Expose op folders publicly
reyna-abhyankar Nov 13, 2024
c6fed29
Add tensor type, operate over reduced tensor
reyna-abhyankar Nov 13, 2024
0cdfb1a
Fixes
reyna-abhyankar Jan 7, 2025
9d252b3
Remove tensor lower
reyna-abhyankar Jan 15, 2025
895c117
Add tensor and task lowering scheme
reyna-abhyankar Jan 17, 2025
66d61eb
feat: add realm-backend subdir
chenzhuofu Jan 21, 2025
8d0cfec
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Jan 21, 2025
411017d
Build local exec
reyna-abhyankar Jan 22, 2025
759abdd
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Jan 22, 2025
bcd1408
chore: duplicate some files from local-execution
chenzhuofu Jan 22, 2025
5e11568
Merge branch 'master' of github.com:flexflow/flexflow-train into real…
chenzhuofu Jan 28, 2025
1c55cf7
Merge branch 'master' of github.com:flexflow/flexflow-train into real…
chenzhuofu Jan 28, 2025
b9144ad
chore: update legion
chenzhuofu Jan 30, 2025
66647a2
feat: add legion related code
chenzhuofu Jan 30, 2025
0128abb
Disaggregate local backend
reyna-abhyankar Feb 1, 2025
277f8c2
Update task binding interface and cost estimator
reyna-abhyankar Feb 1, 2025
377c6aa
Merge master into local execution
reyna-abhyankar Feb 4, 2025
6f689a4
feat: add Future wrapper for func result
chenzhuofu Feb 5, 2025
fe2bc21
feat: add realm-backend draft impl
chenzhuofu Feb 5, 2025
8efaec7
Build
reyna-abhyankar Feb 6, 2025
1dc1398
Format
reyna-abhyankar Feb 6, 2025
17ad5c8
Split task spec files
reyna-abhyankar Feb 6, 2025
639c2c1
Delete outdated sim environment file
reyna-abhyankar Feb 6, 2025
c408ebb
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Feb 8, 2025
a697044
Finish API
reyna-abhyankar Feb 13, 2025
187a8d5
Add tests for allocated and unallocated
reyna-abhyankar Feb 13, 2025
a0f8113
Fix nonnegative
reyna-abhyankar Feb 13, 2025
b1eab94
Format
reyna-abhyankar Feb 13, 2025
b532c50
Pass allocated-unallocated tests
reyna-abhyankar Feb 13, 2025
f28e5c2
Update task registry tests
reyna-abhyankar Feb 13, 2025
7887183
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Feb 16, 2025
9c16d76
feat: intial implementation of realm-backend
chenzhuofu Feb 19, 2025
89752fa
Move local tensor backing to dtgen
reyna-abhyankar Feb 22, 2025
aef8ad5
Remove lowered tensor source
reyna-abhyankar Feb 22, 2025
f0a4285
Loss and update tests
reyna-abhyankar Feb 24, 2025
9047edc
Merge master
reyna-abhyankar Feb 24, 2025
350babf
Passing tests after merge issues
reyna-abhyankar Feb 24, 2025
aef7c6e
Pass gpu tests
reyna-abhyankar Feb 25, 2025
6c84fb3
chore: fix typo
chenzhuofu Feb 26, 2025
d6aa7ad
chore: update realm allocator impl
chenzhuofu Feb 27, 2025
419cca8
chore: eliminate std::optional<float>
chenzhuofu Mar 3, 2025
2c0b573
feat: buildable realm-backend
chenzhuofu Mar 5, 2025
ebe06cf
Merge commit 'aef8ad58196f7b7f724fc7f0a1a65af24ee12acd' of github.com…
chenzhuofu Mar 5, 2025
062825e
chore: Move realm tensor backing to dtgen
chenzhuofu Mar 5, 2025
d82fa2a
Merge commit '350babf3584c3d99e76e4dc0f72a658aa0222afc' of github.com…
chenzhuofu Mar 5, 2025
7c53bb3
chore: minor
chenzhuofu Mar 5, 2025
403ec78
Merge commit 'aef7c6e3c3087f15b4c90792148f170da84f6f7c' of github.com…
chenzhuofu Mar 5, 2025
bf57d1d
chore: remove deprecated file
chenzhuofu Mar 5, 2025
Build
reyna-abhyankar committed Feb 6, 2025
commit 8efaec7f2590bc4b8613c9f742910119d67df71a
5 changes: 1 addition & 4 deletions lib/kernels/include/kernels/array_shape.h
@@ -18,7 +18,7 @@ struct ArrayShape {
explicit ArrayShape(nonnegative_int *dims, nonnegative_int num_dims);
explicit ArrayShape(TensorShape const &shape);
explicit ArrayShape(std::vector<nonnegative_int> const &);
explicit ArrayShape(LegionTensorDims const &);
explicit ArrayShape(LegionOrdered<nonnegative_int> const &);

/**
* @brief Alias of ArrayShape::num_elements for compatibility with
@@ -53,9 +53,6 @@ struct ArrayShape {
ArrayShape sub_shape(std::optional<legion_dim_t> start,
std::optional<legion_dim_t> end) const;

bool operator==(ArrayShape const &) const;
bool operator!=(ArrayShape const &) const;

public:
LegionOrdered<nonnegative_int> dims;

2 changes: 2 additions & 0 deletions lib/kernels/include/kernels/legion_dim.h
@@ -10,6 +10,8 @@ legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value);

legion_dim_t legion_dim_from_ff_dim(ff_dim_t, nonnegative_int num_dimensions);

ff_dim_t ff_dim_from_legion_dim(legion_dim_t, nonnegative_int num_dimensions);

template <typename T>
using LegionOrdered = DimOrdered<legion_dim_t, T>;

1 change: 0 additions & 1 deletion lib/kernels/src/allocation.cc
@@ -13,7 +13,6 @@ void Allocator::deallocate(void *ptr) {

GenericTensorAccessorW
Allocator::allocate_tensor(TensorShape const &tensor_shape) {
return {tensor_shape.data_type, ArrayShape{tensor_shape}, ptr};
void *ptr =
this->allocate(get_size_in_bytes(tensor_shape).unwrap_nonnegative());
return {tensor_shape.data_type, ArrayShape{tensor_shape}, ptr};
38 changes: 15 additions & 23 deletions lib/kernels/src/array_shape.cc
@@ -22,7 +22,7 @@ ArrayShape::ArrayShape(TensorShape const &shape)
ArrayShape::ArrayShape(std::vector<nonnegative_int> const &input_dims)
: dims(input_dims) {}

ArrayShape::ArrayShape(LegionTensorDims const &legion_tensor_dims)
ArrayShape::ArrayShape(LegionOrdered<nonnegative_int> const &legion_tensor_dims)
: dims(legion_tensor_dims) {}

nonnegative_int ArrayShape::get_volume() const {
@@ -58,23 +58,23 @@ nonnegative_int ArrayShape::at(ff_dim_t idx) const {

ArrayShape ArrayShape::sub_shape(std::optional<ff_dim_t> start,
std::optional<ff_dim_t> end) const {
std::optional<legion_dim_t> legion_start =
return ArrayShape{legion_ordered_from_ff_ordered(slice(ff_ordered_from_legion_ordered(this->dims), start, end))};
}

ArrayShape ArrayShape::sub_shape(std::optional<legion_dim_t> start,
std::optional<legion_dim_t> end) const {
std::optional<ff_dim_t> legion_start =
transform(start, [&](auto const &start_unwrapped) {
return legion_dim_from_ff_dim(start_unwrapped, num_dims());
return ff_dim_from_legion_dim(start_unwrapped, num_dims());
});

std::optional<legion_dim_t> legion_end =
std::optional<ff_dim_t> legion_end =
transform(end, [&](auto const &end_unwrapped) {
return legion_dim_from_ff_dim(end_unwrapped, num_dims());
return ff_dim_from_legion_dim(end_unwrapped, num_dims());
});
return this->sub_shape(legion_start, legion_end);
}

ArrayShape ArrayShape::sub_shape(std::optional<legion_dim_t> start,
std::optional<legion_dim_t> end) const {
return ArrayShape{slice(this->dims, start, end)};
}

bool ArrayShape::operator==(ArrayShape const &other) const {
return this->tie() == other.tie();
}
@@ -83,11 +83,11 @@ bool ArrayShape::operator!=(ArrayShape const &other) const {
return this->tie() != other.tie();
}

ArrayShape ArrayShape::sub_shape(
std::optional<std::variant<ff_dim_t, legion_dim_t>> start,
std::optional<std::variant<ff_dim_t, legion_dim_t>> end) const {
NOT_IMPLEMENTED();
}
// ArrayShape ArrayShape::sub_shape(
// std::optional<std::variant<ff_dim_t, legion_dim_t>> start,
// std::optional<std::variant<ff_dim_t, legion_dim_t>> end) const {
// NOT_IMPLEMENTED();
// }

std::optional<nonnegative_int> ArrayShape::at_maybe(legion_dim_t index) const {
if (index.value < dims.size()) {
@@ -114,14 +114,6 @@ TensorShape get_tensor_shape(ArrayShape const &shape, DataType dtype) {
dtype};
}

bool ArrayShape::operator==(ArrayShape const &other) const {
return this->dims == other.dims;
}

bool ArrayShape::operator!=(ArrayShape const &other) const {
return this->dims != other.dims;
}

std::string format_as(ArrayShape const &x) {
std::ostringstream oss;
oss << "<ArrayShape";
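Note (reviewer sketch): the ff-ordered sub_shape overload above now converts the dims to ff order, slices there, and converts back to legion order. A minimal standalone sketch of that scheme, using plain std::vector<int> and a local slice helper as stand-ins for the project's LegionOrdered/FFOrdered types and slice utility; that the two orderings are simple reversals of each other is an assumption made here for illustration only.

#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

// ff order and legion order are assumed to be reversals of each other.
std::vector<int> ff_from_legion(std::vector<int> dims) {
  std::reverse(dims.begin(), dims.end());
  return dims;
}
std::vector<int> legion_from_ff(std::vector<int> dims) {
  std::reverse(dims.begin(), dims.end());
  return dims;
}

// Half-open slice [start, end); std::nullopt means "from the beginning" / "to the end".
std::vector<int> slice(std::vector<int> const &v,
                       std::optional<int> start,
                       std::optional<int> end) {
  int s = start.value_or(0);
  int e = end.value_or(static_cast<int>(v.size()));
  return std::vector<int>(v.begin() + s, v.begin() + e);
}

int main() {
  std::vector<int> legion_dims = {2, 3, 4, 5};
  // Sub-shape over ff-ordered bounds: convert, slice in ff order, convert back.
  std::vector<int> sub =
      legion_from_ff(slice(ff_from_legion(legion_dims), 1, std::nullopt));
  assert((sub == std::vector<int>{2, 3, 4}));
}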
3 changes: 2 additions & 1 deletion lib/kernels/src/cuda/ops/concat_kernels.cu
@@ -16,6 +16,7 @@
#include "device.h"
#include "kernels/concat_kernels.h"
#include "kernels/legion_dim.h"
#include "utils/nonnegative_int/nonnegative_int.h"
#include <cassert>

namespace FlexFlow {
@@ -27,7 +28,7 @@ void calc_blk_size(size_t &num_blocks,
ArrayShape const &shape,
ff_dim_t axis) {
legion_dim_t axis_legion_dim = legion_dim_from_ff_dim(axis, shape.num_dims());
blk_size = shape.sub_shape(legion_dim_t{0}, axis_legion_dim).num_elements().unwrap_nonnegative();
blk_size = shape.sub_shape(legion_dim_t{nonnegative_int{0}}, axis_legion_dim).num_elements().unwrap_nonnegative();
num_blocks = shape.sub_shape(axis, std::nullopt).num_elements().unwrap_nonnegative();
}

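Note (reviewer sketch): the only change in this file wraps the bare literal in nonnegative_int. A brief illustration of why that wrapping becomes necessary, assuming legion_dim_t is a strong typedef over nonnegative_int with explicit constructors; the *_like types below are stand-ins, not the project's definitions.

// Stand-in strong typedefs: explicit constructors block implicit conversion from int.
struct nonnegative_int_like {
  explicit nonnegative_int_like(int v) : value(v) {}
  int value;
};

struct legion_dim_like {
  explicit legion_dim_like(nonnegative_int_like v) : value(v) {}
  nonnegative_int_like value;
};

int main() {
  // legion_dim_like d{0};                      // would not compile: no implicit conversion from int
  legion_dim_like d{nonnegative_int_like{0}};   // explicit wrapping, as in the hunk above
  return d.value.value;
}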
6 changes: 6 additions & 0 deletions lib/kernels/src/legion_dim.cc
@@ -13,4 +13,10 @@ legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim,
ff_dim.value.unwrap_nonnegative() - 1}};
}

ff_dim_t legion_dim_from_ff_dim(legion_dim_t legion_dim,
nonnegative_int num_dimensions) {
return ff_dim_t{nonnegative_int{num_dimensions.unwrap_nonnegative() -
legion_dim.value.unwrap_nonnegative() - 1}};
}

} // namespace FlexFlow
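Note (reviewer sketch): both conversions between ff and legion dimension indices reverse the index against the total dimension count. A minimal sketch of that round trip, with plain int standing in for ff_dim_t, legion_dim_t, and nonnegative_int.

#include <cassert>

// legion_dim = num_dims - ff_dim - 1, and the inverse has the same form.
int legion_from_ff(int ff_dim, int num_dims) {
  return num_dims - ff_dim - 1;
}

int ff_from_legion(int legion_dim, int num_dims) {
  return num_dims - legion_dim - 1;
}

int main() {
  int num_dims = 4;
  assert(legion_from_ff(0, num_dims) == 3); // outermost ff dim is the last legion dim
  for (int d = 0; d < num_dims; d++) {
    // The two conversions are inverses of each other.
    assert(ff_from_legion(legion_from_ff(d, num_dims), num_dims) == d);
  }
}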
28 changes: 1 addition & 27 deletions lib/local-execution/src/local-execution/ops/transpose.cc
@@ -28,24 +28,8 @@ enum Slots {
OUTPUT, // tensor
ATTRS,
PROFILING,
PER_DEVICE_STATE,
};

OpTaskInvocation init(TransposeAttrs const &attrs) {
OpTaskBinding binding;
binding.bind_arg(ATTRS, attrs);
return {task_id_t::TRANSPOSE_INIT_TASK_ID, binding};
}

static DeviceSpecificDeviceStates
init_task_impl(TaskArgumentAccessor const &acc) {
auto const &attrs = acc.get_argument<TransposeAttrs>(ATTRS);
std::vector<ff_dim_t> perm = inner_to_outer_idxs(attrs.perm);
TransposePerDeviceState per_device_state = init_kernel(perm.size(), perm);

return DeviceSpecificDeviceStates{
DeviceSpecific<TransposePerDeviceState>::create(per_device_state)};
}

OpTaskInvocation forward(TransposeAttrs const &attrs) {
OpTaskBinding binding;
@@ -95,9 +79,6 @@ OpTaskInvocation backward(TransposeAttrs const &attrs) {
return {task_id_t::TRANSPOSE_BWD_TASK_ID, binding};
}

TaskImplFunction get_transpose_init_task_impl() {
return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}};
}

TaskImplFunction get_transpose_fwd_task_impl() {
return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}};
@@ -107,13 +88,6 @@ TaskImplFunction get_transpose_bwd_task_impl() {
return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}};
}

OpTaskSignature get_transpose_init_signature() {
OpTaskSignature init(OpTaskType::INIT);

init.add_arg_slot<TransposeAttrs>(ATTRS);
init.add_return_value<TransposePerDeviceState>();
return init;
}

OpTaskSignature get_transpose_fwd_signature() {
OpTaskSignature fwd(OpTaskType::FWD);
@@ -131,7 +105,7 @@ OpTaskSignature get_transpose_bwd_signature() {
}

std::vector<task_id_t> get_task_ids(TransposeAttrs const &) {
return {task_id_t::TRANSPOSE_INIT_TASK_ID, task_id_t::TRANSPOSE_FWD_TASK_ID, task_id_t::TRANSPOSE_BWD_TASK_ID};
return {task_id_t::TRANSPOSE_FWD_TASK_ID, task_id_t::TRANSPOSE_BWD_TASK_ID};
}

} // namespace FlexFlow
2 changes: 1 addition & 1 deletion lib/local-execution/src/local_cost_estimator.cc
@@ -19,7 +19,7 @@ namespace FlexFlow {
LocalCostEstimator::LocalCostEstimator(RuntimeArgConfig const &config)
: runtime_arg_config(config) {}

static ComputationGraph const &
static ComputationGraph
create_computation_graph_for_local_cost_estimation(
PCGOperatorAttrs const &op,
std::vector<ParallelTensorShape> const &inputs,
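Note (reviewer sketch): dropping the const & on the return type avoids handing back a reference to a graph that is local to the function. A standalone illustration of the distinction, with a stand-in Graph type rather than the project's ComputationGraph.

#include <string>

struct Graph {
  std::string name;
};

// What the old signature allowed (returning a reference to a local) dangles:
// Graph const &make_graph_bad() { Graph g{"local"}; return g; }  // undefined behavior

// What the new signature does: return by value, so the graph is moved/copied out.
Graph make_graph() {
  Graph g{"local"};
  return g;
}

int main() {
  Graph g = make_graph();
  return g.name.empty() ? 1 : 0;
}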
45 changes: 23 additions & 22 deletions lib/local-execution/src/loss_functions.cc
@@ -17,6 +17,7 @@
#include "kernels/loss_function_kernels.h"
#include "local-execution/loss_functions.h"
#include "local-execution/profiling.h"
#include "utils/nonnegative_int/nonnegative_int.h"

namespace FlexFlow {

@@ -54,52 +55,52 @@ static void backward_task_impl(TaskArgumentAccessor const &acc) {
auto logit_grad = acc.get_tensor_grad<Permissions::RW>(LOGIT_GRAD);
auto logit = acc.get_tensor<Permissions::RO>(LOGIT);
auto label = acc.get_loss_tensor<Permissions::RO>(LABEL);
int batch_size = logit.shape.at(legion_dim_t{1});
int batch_size = logit.shape.at(legion_dim_t{nonnegative_int{1}}).unwrap_nonnegative();
// assuming logit shape is [batch dim, num classes]

LossFunction loss_type = get_loss_function(attrs);
float scale_factor = 1.0f / batch_size;
if (loss_type == LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE) {
assert(logit.shape.get_volume() == label.shape.get_volume());
scale_factor = 2.0f / logit.shape.get_volume();
scale_factor = 2.0f / logit.shape.get_volume().unwrap_nonnegative();
}

if (loss_type == LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY) {
// label shape is [batch dim, 1]
auto scce_attrs = attrs.get<SparseCategoricalCrossEntropyLossAttrs>();
size_t ndim = logit.shape.num_dims();
int num_classes = logit.shape.at(legion_dim_t{0});
size_t ndim = logit.shape.num_dims().unwrap_nonnegative();
int num_classes = logit.shape.at(legion_dim_t{nonnegative_int{0}}).unwrap_nonnegative();
assert(logit_grad.shape == logit.shape);
int k = 1;
if (scce_attrs.replace_labels) {
k = logit.shape.at(legion_dim_t(ndim - 1)) /
k = logit.shape.at(legion_dim_t(nonnegative_int{ndim - 1})).unwrap_nonnegative() /
label.shape.at(legion_dim_t(
ndim - 1)); // TODO FIXME something seems wrong here, isn't the
nonnegative_int{ndim - 1})).unwrap_nonnegative(); // TODO FIXME something seems wrong here, isn't the
// numerator guaranteed to be 1? <--- this is not the
// case because of the potential parallel dim
}
assert(label.shape.sub_shape(legion_dim_t(1), std::nullopt) ==
logit.shape.sub_shape(legion_dim_t(1), std::nullopt));
assert(k * label.shape.at(legion_dim_t(ndim - 1)) ==
logit.shape.at(legion_dim_t(ndim - 1)));
assert(label.shape.at(legion_dim_t(0)) == 1);
assert(label.shape.sub_shape(legion_dim_t(nonnegative_int{1}), std::nullopt) ==
logit.shape.sub_shape(legion_dim_t(nonnegative_int{1}), std::nullopt));
assert(k * label.shape.at(legion_dim_t(nonnegative_int{ndim - 1})).unwrap_nonnegative() ==
logit.shape.at(legion_dim_t(nonnegative_int{ndim - 1})).unwrap_nonnegative());
assert(label.shape.at(legion_dim_t(nonnegative_int{0})).unwrap_nonnegative() == 1);

profile(sparse_categorical_crossentropy_loss_backward_kernel,
profiling,
"[SparseCategoricalCrossEntropyLoss] backward_time = %.2lfms\n",
get_float_ptr(logit_grad),
get_float_ptr(logit),
reinterpret_cast<int const *>(get_float_ptr(label)),
get_volume(logit.shape),
get_volume(logit_grad.shape),
get_volume(logit.shape).unwrap_nonnegative(),
get_volume(logit_grad.shape).unwrap_nonnegative(),
batch_size,
num_classes,
k,
scale_factor);
} else {
assert(logit.shape == label.shape);
assert(logit_grad.shape == logit.shape);
int num_channels = logit.shape.at(legion_dim_t{0});
int num_channels = logit.shape.at(legion_dim_t{nonnegative_int{0}}).unwrap_nonnegative();
switch (loss_type) {
case LossFunction::CATEGORICAL_CROSSENTROPY: {
profile(categorical_crossentropy_loss_backward_kernel,
@@ -108,8 +109,8 @@ static void backward_task_impl(TaskArgumentAccessor const &acc) {
get_float_ptr(logit_grad),
get_float_ptr(logit),
get_float_ptr(label),
get_volume(logit.shape),
get_volume(logit_grad.shape),
get_volume(logit.shape).unwrap_nonnegative(),
get_volume(logit_grad.shape).unwrap_nonnegative(),
scale_factor);
break;
}
@@ -120,8 +121,8 @@ static void backward_task_impl(TaskArgumentAccessor const &acc) {
get_float_ptr(logit_grad),
get_float_ptr(logit),
get_float_ptr(label),
get_volume(logit.shape),
get_volume(logit_grad.shape),
get_volume(logit.shape).unwrap_nonnegative(),
get_volume(logit_grad.shape).unwrap_nonnegative(),
scale_factor);
break;
}
@@ -131,15 +132,15 @@ static void backward_task_impl(TaskArgumentAccessor const &acc) {
"[IdentityLoss] backward_time = %.2lfms\n",
get_float_ptr(logit_grad),
get_float_ptr(logit),
get_volume(logit.shape),
get_volume(logit_grad.shape),
get_volume(logit.shape).unwrap_nonnegative(),
get_volume(logit_grad.shape).unwrap_nonnegative(),
scale_factor);
break;
}
default:
throw mk_runtime_error(
throw mk_runtime_error(fmt::format(
"Unsupported loss function {}. Please report this as an issue.",
loss_type);
loss_type));
}
}
}
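Note (reviewer sketch): the hunks above keep the existing gradient scale-factor logic while unwrapping nonnegative_int values to plain ints. A standalone sketch of that selection logic; the enum and helper below are illustrative stand-ins, not the project's definitions.

#include <cassert>

enum class LossFunction {
  CATEGORICAL_CROSSENTROPY,
  SPARSE_CATEGORICAL_CROSSENTROPY,
  MEAN_SQUARED_ERROR_AVG_REDUCE,
  IDENTITY,
};

// 1/batch_size by default; 2/num_elements for mean-squared-error with average reduction.
float loss_grad_scale(LossFunction loss_type, int batch_size, int logit_volume) {
  float scale_factor = 1.0f / batch_size;
  if (loss_type == LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE) {
    scale_factor = 2.0f / logit_volume;
  }
  return scale_factor;
}

int main() {
  assert(loss_grad_scale(LossFunction::CATEGORICAL_CROSSENTROPY, 32, 320) == 1.0f / 32);
  assert(loss_grad_scale(LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE, 32, 320) == 2.0f / 320);
}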
16 changes: 8 additions & 8 deletions lib/local-execution/src/optimizer.cc
@@ -59,11 +59,11 @@ static void sgd_update_task_impl(TaskArgumentAccessor const &acc) {
auto profiling = acc.get_argument<ProfilingSettings>(PROFILING);

assert(weight.shape == weight_grad.shape);
size_t size = weight_grad.shape.get_volume();
int size = weight_grad.shape.get_volume().unwrap_nonnegative();

assert(weight_grad.shape.get_volume() & weight.shape.get_volume() == 0);
size_t num_replicas =
weight_grad.shape.get_volume() / weight.shape.get_volume();
assert(weight_grad.shape.get_volume().unwrap_nonnegative() & weight.shape.get_volume().unwrap_nonnegative() == 0);
int num_replicas =
weight_grad.shape.get_volume().unwrap_nonnegative() / weight.shape.get_volume().unwrap_nonnegative();

float *sgd_v_ptr;
if (attrs.momentum > 0.0f) {
@@ -153,11 +153,11 @@ static void adam_update_task_impl(TaskArgumentAccessor const &acc) {
auto profiling = acc.get_argument<ProfilingSettings>(PROFILING);

assert(weight.shape == weight_grad.shape);
size_t size = weight_grad.shape.get_volume();
int size = weight_grad.shape.get_volume().unwrap_nonnegative();

assert(weight_grad.shape.get_volume() % weight.shape.get_volume() == 0);
size_t num_replicas =
weight_grad.shape.get_volume() / weight.shape.get_volume();
assert(weight_grad.shape.get_volume().unwrap_nonnegative() % weight.shape.get_volume().unwrap_nonnegative() == 0);
int num_replicas =
weight_grad.shape.get_volume().unwrap_nonnegative() / weight.shape.get_volume().unwrap_nonnegative();

if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) {
auto handle = acc.get_argument<PerDeviceFFHandle>(HANDLE);
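Note (reviewer sketch): both update tasks derive a replica count by dividing the gradient volume by the weight volume. A minimal sketch of that computation, assuming the divisibility check is the modulo form used in the Adam hunk (the SGD hunk above still uses a bitwise & in its assert).

#include <cassert>

// Replica count between a possibly replicated gradient and its weight.
int compute_num_replicas(int weight_grad_volume, int weight_volume) {
  assert(weight_grad_volume % weight_volume == 0);
  return weight_grad_volume / weight_volume;
}

int main() {
  // A 1024-element weight whose gradient is replicated 4 times.
  assert(compute_num_replicas(4 * 1024, 1024) == 4);
}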
6 changes: 3 additions & 3 deletions lib/local-execution/src/task_registry.cc
@@ -36,8 +36,8 @@ void register_tasks_for_layer(TaskRegistry &task_registry,
task_registry.backward_task_ids[op_id] = task_id;
break;
default:
throw mk_runtime_error("Invalid OpTaskType, got {}",
task_signature_impl.task_signature.type);
throw mk_runtime_error(fmt::format("Invalid OpTaskType, got {}",
task_signature_impl.task_signature.type));
}
task_registry.task_mapping.insert({task_id, task_signature_impl});
}
@@ -58,7 +58,7 @@ bool registry_contains_task_for_layer(TaskRegistry const &task_registry,
task_ids = task_registry.backward_task_ids;
break;
default:
throw mk_runtime_error("Invalid OpTaskType, got {}", op_task_type);
throw mk_runtime_error(fmt::format("Invalid OpTaskType, got {}", op_task_type));
}

return task_ids.at(op).has_value();
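Note (reviewer sketch): both hunks in this file switch to formatting the message with fmt::format before constructing the error. The same pattern in standalone form, with std::runtime_error standing in for the project's mk_runtime_error helper.

#include <fmt/format.h>
#include <stdexcept>

// Format first, then hand a single string to the error constructor.
std::runtime_error invalid_op_task_type_error(int op_task_type) {
  return std::runtime_error(
      fmt::format("Invalid OpTaskType, got {}", op_task_type));
}

int main() {
  try {
    throw invalid_op_task_type_error(42);
  } catch (std::runtime_error const &e) {
    fmt::print("{}\n", e.what()); // prints: Invalid OpTaskType, got 42
  }
}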