-
Notifications
You must be signed in to change notification settings - Fork 236
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: intial implementation of realm-backend
- Loading branch information
1 parent
7887183
commit 9c16d76
Showing
15 changed files
with
1,042 additions
and
468 deletions.
There are no files selected for viewing
30 changes: 30 additions & 0 deletions
30
lib/realm-backend/include/realm-backend/allocated_tensors.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#ifndef _FLEXFLOW_LOCAL_EXECUTION_ALLOCATED_TENSORS_H | ||
#define _FLEXFLOW_LOCAL_EXECUTION_ALLOCATED_TENSORS_H | ||
|
||
#include "realm-backend/allocated_tensors.dtg.h" | ||
#include "pcg/computation_graph.h" | ||
|
||
namespace FlexFlow { | ||
|
||
bool are_allocated_forward_tensors_valid( | ||
AllocatedTensors const &, | ||
std::unordered_map<tensor_guid_t, TensorAttrs> const &); | ||
bool are_allocated_gradient_tensors_valid( | ||
AllocatedTensors const &, | ||
std::unordered_map<tensor_guid_t, TensorAttrs> const &); | ||
bool are_allocated_optimizer_tensors_valid( | ||
AllocatedTensors const &, | ||
std::unordered_map<tensor_guid_t, TensorAttrs> const &); | ||
|
||
bool are_allocated_tensors_valid( | ||
AllocatedTensors const &, | ||
std::unordered_map<tensor_guid_t, TensorAttrs> const &); | ||
|
||
bool is_allocated_tensor_backing_valid( | ||
TensorTypeVariant const &, | ||
std::unordered_map<TensorTypeVariant, GenericTensorAccessorW> const &, | ||
ArrayShape const &); | ||
|
||
} // namespace FlexFlow | ||
|
||
#endif |
32 changes: 32 additions & 0 deletions
32
lib/realm-backend/include/realm-backend/allocated_tensors.struct.toml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
namespace = "FlexFlow" | ||
name = "AllocatedTensors" | ||
features = [ | ||
"eq", | ||
"fmt", | ||
"hash", | ||
] | ||
|
||
includes = [ | ||
"task-spec/tensor_type_t.dtg.h", | ||
"kernels/accessor.h", | ||
"realm-backend/realm_allocator.h" | ||
] | ||
|
||
src_includes = [ | ||
"utils/hash/unordered_map.h", | ||
"utils/fmt/unordered_map.h", | ||
"utils/hash/vector.h", | ||
"utils/fmt/vector.h" | ||
] | ||
|
||
[[fields]] | ||
name = "tensor_type_backings" | ||
type = "std::unordered_map<::FlexFlow::TensorTypeVariant, std::pair<::FlexFlow::RealmRegion,::FlexFlow::TensorShape>>" | ||
|
||
[[fields]] | ||
name = "gradient_mapping" | ||
type = "std::unordered_map<::FlexFlow::tensor_guid_t, ::FlexFlow::gradient_tensor_t>" | ||
|
||
[[fields]] | ||
name = "optimizer_mapping" | ||
type = "std::unordered_map<::FlexFlow::tensor_guid_t, std::vector<::FlexFlow::optimizer_tensor_t>>" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 17 additions & 17 deletions
34
lib/realm-backend/include/realm-backend/realm_args_backing.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,38 @@ | ||
#ifndef _FLEXFLOW_REALM_BACKEND_REALM_ARGS_BACKING_H | ||
#define _FLEXFLOW_REALM_BACKEND_REALM_ARGS_BACKING_H | ||
|
||
#include "local-execution/op_task_invocation.h" | ||
#include "local-execution/per_device_op_state.h" | ||
#include "local-execution/runtime_arg_config.h" | ||
#include "local-execution/task_invocation.dtg.h" | ||
#include "pcg/computation_graph.h" | ||
#include "pcg/layer_guid_t.dtg.h" | ||
#include "realm-backend/realm_task_argument_accessor.h" | ||
#include "realm-backend/task_result.h" | ||
#include "task-spec/op_task_invocation.h" | ||
#include "task-spec/per_device_op_state.h" | ||
#include "task-spec/runtime_arg_config.h" | ||
#include "task-spec/task_invocation.dtg.h" | ||
|
||
namespace FlexFlow { | ||
|
||
struct RealmArgsBacking { | ||
RealmArgsBacking(RuntimeArgConfig const &); | ||
|
||
public: | ||
void add_per_device_op_state(layer_guid_t const &, | ||
Future<DeviceSpecificDeviceStates> &&); | ||
|
||
ArgSlotsBacking construct_arg_slots_backing(TaskBinding const &) const; | ||
|
||
ConcreteArgSpec lower_to_concrete_arg_spec(RuntimeArgRefSpec const &) const; | ||
ConcreteArgSpec lower_to_concrete_arg_spec(OpArgRefSpec const &, | ||
ComputationGraph const &, | ||
layer_guid_t const &) const; | ||
RealmArgsBacking(RuntimeArgConfig const &, | ||
std::unordered_map<layer_guid_t, DeviceSpecificDeviceStates> const &); | ||
|
||
public: | ||
// arguments | ||
RuntimeArgConfig runtime_arg_config; | ||
std::unordered_map<layer_guid_t, DeviceSpecificDeviceStates> | ||
per_device_op_states; | ||
RuntimeArgConfig runtime_arg_config; | ||
}; | ||
|
||
RealmArgsBacking | ||
make_args_backing_with_empty_device_states(RuntimeArgConfig const &); | ||
|
||
std::optional<DeviceSpecificDeviceStates> | ||
get_per_device_op_state_if_exists(RealmArgsBacking const &, | ||
layer_guid_t const &); | ||
|
||
ArgSlotsBacking construct_arg_slots_backing(TaskBinding const &, | ||
RuntimeArgConfig const &); | ||
|
||
} // namespace FlexFlow | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.