Commit 2c0b573
feat: buildable realm-backend
chenzhuofu committed Mar 5, 2025
1 parent 419cca8 commit 2c0b573
Showing 2 changed files with 33 additions and 37 deletions.
8 changes: 5 additions & 3 deletions lib/realm-backend/include/realm-backend/task_result.h
@@ -47,17 +47,18 @@ template <typename T> class Future {
 public:
   explicit Future(std::shared_ptr<SharedState<T>> state)
       : state_(std::move(state)) {}
+  explicit Future() = default;
   explicit Future(T value) : value_(std::move(value)) {}
   void set_event(Realm::Event e) { state_->set_event(e); }
   T get() {
-    value_ = state_->get_value();
-    return value_;
+    value_ = std::make_optional(state_->get_value());
+    return value_.value();
   }
   void wait() { state_->wait(); }

 private:
   std::shared_ptr<SharedState<T>> state_;
-  T value_;
+  std::optional<T> value_ = std::nullopt;
 };

 // Specialization of Future for the `void` type, as it does not carry a value.
@@ -67,6 +68,7 @@ template <> class Future<void> {
       : state_(std::move(state)) {}
   explicit Future() = default;
   void set_event(Realm::Event e) { state_->set_event(e); }
+  void get() { state_->wait(); }
   void wait() { state_->wait(); }

 private:
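The switch to a std::optional<T> member does two things: it lets `explicit Future() = default;` compile even when T has no default constructor, and it separates "no value fetched yet" from a cached result. Below is a minimal, self-contained sketch of the pattern with a stand-in SharedState built on a condition variable; the real SharedState wraps a Realm::Event and is not part of this diff. Unlike the committed get(), the sketch guards the fetch with has_value(), so a value-constructed Future whose state_ is null is also safe to read.

#include <condition_variable>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>

template <typename T> struct SharedState { // stand-in; an assumption
  std::mutex m;
  std::condition_variable cv;
  std::optional<T> slot;
  void set_value(T v) {
    { std::lock_guard<std::mutex> g(m); slot = std::move(v); }
    cv.notify_all();
  }
  T get_value() { // blocks until a producer stores a value
    std::unique_lock<std::mutex> lk(m);
    cv.wait(lk, [&] { return slot.has_value(); });
    return *slot;
  }
};

template <typename T> class Future {
public:
  explicit Future(std::shared_ptr<SharedState<T>> state)
      : state_(std::move(state)) {}
  explicit Future() = default; // legal even if T has no default ctor
  explicit Future(T value) : value_(std::move(value)) {}
  T get() {
    if (!value_.has_value()) // deviation from the diff: fetch only once
      value_ = std::make_optional(state_->get_value());
    return value_.value();
  }
private:
  std::shared_ptr<SharedState<T>> state_;
  std::optional<T> value_ = std::nullopt;
};

int main() {
  auto state = std::make_shared<SharedState<float>>();
  Future<float> f(state);
  std::thread producer([&] { state->set_value(3.0f); });
  float v = f.get(); // waits for the producer, then caches the value
  producer.join();
  return v == 3.0f ? 0 : 1;
}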
62 changes: 28 additions & 34 deletions lib/realm-backend/src/realm_training_backing.cc
@@ -53,19 +53,16 @@ RealmTrainingBacking::RealmTrainingBacking(
   // allocators.push_back(create_realm_memory_allocator(p));

   // register tasks for realm
-  for (layer_guid_t const &node :
-       topological_ordering(this->computation_graph)) {
-    ComputationGraphOpAttrs attrs =
-        get_layer_attrs(this->computation_graph, node).attrs;
-    if (attrs.has<OpTaskInvocation>()) {
-      OpTaskInvocation op_task_invocation = attrs.get<OpTaskInvocation>();
-      std::vector<task_id_t> task_ids = get_task_ids(attrs);
-      for (task_id_t task_id : task_ids) {
-        TaskSignatureAndImpl task_signature_impl =
-            this->task_registry.task_mapping.at(task_id);
+  std::unordered_map<layer_guid_t, LayerAttrs> const &layer_attrs_mapping =
+      get_layer_attrs_mapping(this->computation_graph);
+  for (std::pair<layer_guid_t, LayerAttrs> const &layer_attrs :
+       layer_attrs_mapping) {
+    ComputationGraphOpAttrs attrs = layer_attrs.second.attrs;
+    std::vector<task_id_t> task_ids = get_task_ids(attrs);
+    for (task_id_t task_id : task_ids) {
+      TaskSignatureAndImpl task_signature_impl = get_task_sig_impl(task_id);
       // TODO: multi gpu
       register_wrapper_tasks(worker_procs[0], task_id, task_signature_impl);
     }
-    }
   }
 }
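The rewritten loop drops both the topological ordering (registration order does not matter) and the attrs.has<OpTaskInvocation>() guard, and it replaces the direct `this->task_registry.task_mapping.at(task_id)` lookup with `get_task_sig_impl`. That helper is not defined in this diff; a plausible sketch, assuming it simply forwards to a registry of known task implementations:

// Hypothetical: get_task_sig_impl is called but not defined in this commit.
// Assumed shape: resolve a task_id_t against the same kind of mapping the
// removed code read from this->task_registry.task_mapping.
TaskSignatureAndImpl get_task_sig_impl(task_id_t task_id) {
  std::unordered_map<task_id_t, TaskSignatureAndImpl> const &mapping =
      get_global_task_registry().task_mapping; // assumed accessor
  return mapping.at(task_id); // throws std::out_of_range for unknown ids
}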
@@ -99,19 +96,16 @@ RealmTrainingBacking::RealmTrainingBacking(
 }

   // register tasks for realm
-  for (layer_guid_t const &node :
-       topological_ordering(this->computation_graph)) {
-    ComputationGraphOpAttrs attrs =
-        get_layer_attrs(this->computation_graph, node).attrs;
-    if (attrs.has<OpTaskInvocation>()) {
-      OpTaskInvocation op_task_invocation = attrs.get<OpTaskInvocation>();
-      std::vector<task_id_t> task_ids = get_task_ids(attrs);
-      for (task_id_t task_id : task_ids) {
-        TaskSignatureAndImpl task_signature_impl =
-            this->task_registry.task_mapping.at(task_id);
+  std::unordered_map<layer_guid_t, LayerAttrs> const &layer_attrs_mapping =
+      get_layer_attrs_mapping(this->computation_graph);
+  for (std::pair<layer_guid_t, LayerAttrs> const &layer_attrs :
+       layer_attrs_mapping) {
+    ComputationGraphOpAttrs attrs = layer_attrs.second.attrs;
+    std::vector<task_id_t> task_ids = get_task_ids(attrs);
+    for (task_id_t task_id : task_ids) {
+      TaskSignatureAndImpl task_signature_impl = get_task_sig_impl(task_id);
       // TODO: multi gpu
       register_wrapper_tasks(worker_procs[0], task_id, task_signature_impl);
     }
-    }
   }
 }
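Both constructors run the same registration loop; `register_wrapper_tasks` itself is outside this diff. A sketch of what it might look like using Realm's standard registration call, Processor::register_task_by_kind; everything other than that Realm API is an assumption here:

#include "realm.h" // Realm runtime (Legion's low-level layer)

// Hypothetical wrapper registration: bind one entry point per task_id on the
// worker processor's kind. The TaskFuncID must match the id later passed to
// Processor::spawn in execute_forward/execute_backward.
static void generic_wrapper(void const *args, size_t arglen,
                            void const *userdata, size_t userlen,
                            Realm::Processor p) {
  // deserialize a RealmTaskArgs<...> from `args` and run its impl_function
}

void register_wrapper_tasks(Realm::Processor proc, task_id_t task_id,
                            TaskSignatureAndImpl const &sig_impl) {
  // sig_impl could be forwarded through the user_data arguments; unused here.
  Realm::Event done = Realm::Processor::register_task_by_kind(
      proc.kind(), /*global=*/false,
      static_cast<Realm::Processor::TaskFuncID>(task_id),
      Realm::CodeDescriptor(generic_wrapper), Realm::ProfilingRequestSet());
  done.external_wait(); // registration is asynchronous; wait from a non-task thread
}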
@@ -168,7 +162,7 @@ initialize_args_backing(RealmTrainingBacking *backing,
   return RealmArgsBacking{runtime_arg_config, per_device_op_states};
 }

-Future<std::optional<float>>
+Future<float>
 execute_forward(RealmTrainingBacking &realm_training_backing,
                 layer_guid_t const &operator_node) {
   if (registry_contains_task_for_layer(realm_training_backing.task_registry,
@@ -199,22 +193,22 @@ execute_forward(RealmTrainingBacking &realm_training_backing,
         realm_training_backing.task_registry.task_mapping.at(task_id)
             .impl_function;
     // TODO: multi gpu launching
-    Promise<std::optional<float>> promise(realm_training_backing.master_mem);
-    Future<std::optional<float>> future = promise.get_future();
-    RealmTaskArgs<std::optional<float>> args{task_id, impl_function, accessor,
-                                             std::move(promise)};
+    Promise<float> promise(realm_training_backing.master_mem);
+    Future<float> future = promise.get_future();
+    RealmTaskArgs<float> args{task_id, impl_function, accessor,
+                              std::move(promise)};
     Event e = realm_training_backing.worker_procs[0].spawn(
         static_cast<Processor::TaskFuncID>(task_id), &args, sizeof(args),
         realm_training_backing.worker_events[0]);
     realm_training_backing.worker_events[0] = e;
     future.set_event(e);
     return future;
   } else {
-    return Future<std::optional<float>>(std::nullopt);
+    return Future<float>(0.0f);
   }
 }

-Future<std::optional<float>>
+Future<float>
 execute_backward(RealmTrainingBacking &realm_training_backing,
                  layer_guid_t const &operator_node) {
   if (registry_contains_task_for_layer(realm_training_backing.task_registry,
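Two things are worth noting about the new execute_forward (and its twin execute_backward below). First, the fallback changed from std::nullopt to Future<float>(0.0f), so callers can no longer distinguish "this layer has no registered task" from a genuine zero. Second, tasks on worker_procs[0] are serialized by event chaining: each spawn waits on the previously stored completion event, then replaces it. A self-contained sketch of that idiom, where the task ids and empty payloads are illustrative assumptions:

// Each spawn waits on the previous completion event, so tasks submitted to
// one processor run in submission order.
enum {
  TASK_A = Realm::Processor::TASK_ID_FIRST_AVAILABLE, // assumed ids
  TASK_B,
};

Realm::Event run_in_order(Realm::Processor proc) {
  Realm::Event chain = Realm::Event::NO_EVENT;
  chain = proc.spawn(TASK_A, nullptr, 0, chain); // runs first
  chain = proc.spawn(TASK_B, nullptr, 0, chain); // starts after TASK_A finishes
  return chain; // triggers once the whole chain has completed
}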
Expand Down Expand Up @@ -245,18 +239,18 @@ execute_backward(RealmTrainingBacking &realm_training_backing,
         realm_training_backing.task_registry.task_mapping.at(task_id)
             .impl_function;
     // TODO: multi gpu launching
-    Promise<std::optional<float>> promise(realm_training_backing.master_mem);
-    Future<std::optional<float>> future = promise.get_future();
-    RealmTaskArgs<std::optional<float>> args{task_id, impl_function, accessor,
-                                             std::move(promise)};
+    Promise<float> promise(realm_training_backing.master_mem);
+    Future<float> future = promise.get_future();
+    RealmTaskArgs<float> args{task_id, impl_function, accessor,
+                              std::move(promise)};
     Event e = realm_training_backing.worker_procs[0].spawn(
         static_cast<Processor::TaskFuncID>(task_id), &args, sizeof(args),
         realm_training_backing.worker_events[0]);
     realm_training_backing.worker_events[0] = e;
     future.set_event(e);
     return future;
   } else {
-    return Future<std::optional<float>>(std::nullopt);
+    return Future<float>(0.0f);
   }
 }
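On the worker side, the spawned task receives the RealmTaskArgs<float> bytes and must fulfill the Promise so that Future<float>::get() unblocks. Note that spawn byte-copies the argument struct, so the moved-in Promise must survive that memcpy. Neither the wrapper body nor the Promise API appears in this diff; a hypothetical sketch of the hand-off, where RealmTaskArgs<float>'s layout, impl_function's call shape, and Promise<float>::set_value are all assumptions inferred from the surrounding code:

#include <cassert>

static void forward_task_wrapper(void const *args, size_t arglen,
                                 void const *, size_t, Realm::Processor) {
  assert(arglen == sizeof(RealmTaskArgs<float>));
  auto *a = static_cast<RealmTaskArgs<float> *>(const_cast<void *>(args));
  float result = a->impl_function(a->accessor); // run the real kernel
  a->promise.set_value(result); // unblocks Future<float>::get() on the caller
}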

