forked from llvm/torch-mlir
Propagate parameter names to TorchMlirComputation (llvm#1420)
* Propagate parameter name to MLIR
* Add TorchMlirNode constructor hook
* Make func_op mutable, so that a subclass backend can modify func_op
* Clean up unnecessary changes
* Remove unnecessary attribute case
* Address PR comments
1 parent 8f608c0 · commit fa5a8e2
Showing 12 changed files with 160 additions and 50 deletions.
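The heart of the change is carrying a human-readable parameter name for each graph input into the generated MLIR function, so that TorchMlirComputation (and any subclass backend that inspects its now-mutable func_op) can map function arguments back to tensor names. As a rough illustration of the idea only, not the commit's code, a name could be attached to a function argument as an argument attribute; the attribute key "lazy.name" and the helper below are assumptions:

#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"

// Hypothetical sketch: record a parameter name on an MLIR function argument.
// The attribute key "lazy.name" is illustrative, not the commit's choice.
void annotateParameterName(mlir::func::FuncOp func_op, unsigned arg_index,
                           llvm::StringRef param_name) {
  mlir::OpBuilder builder(func_op.getContext());
  // setArgAttr stores the name as a discardable argument attribute, so a
  // backend consuming the module can recover which input it corresponds to.
  func_op.setArgAttr(arg_index, "lazy.name", builder.getStringAttr(param_name));
}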
python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp
47 changes: 47 additions & 0 deletions
#include "tensor_utils.h"

#include "../generated/LazyIr.h"
#include "../mlir_node.h"

namespace torch {
namespace lazy {

// Returns true if the value was produced by a DetachCopy node.
bool is_detach_copy(const torch::lazy::Value& value) {
  return value->op() == torch::lazy::DetachCopy::ClassOpKind();
}

// Walks through any chain of DetachCopy nodes and returns the underlying
// DeviceData node, or nullptr if the value does not bottom out in one.
torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) {
  if (!value) {
    return nullptr;
  }
  torch::lazy::TorchMlirNode* node =
      dynamic_cast<torch::lazy::TorchMlirNode*>(value.node.get());
  while (node) {
    if (node->op() == torch::lazy::DeviceData::ClassOpKind()) {
      return dynamic_cast<torch::lazy::DeviceData*>(node);
    } else if (node->op() == torch::lazy::DetachCopy::ClassOpKind()) {
      node = node->mlir_node(0);
    } else {
      break;
    }
  }
  // No DeviceData found behind the value.
  return nullptr;
}

torch::lazy::DeviceData* device_data_cast(
    const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device
) {
  if (!device) {
    device = torch::lazy::GetBackendDevice(tensor);
  }
  TORCH_CHECK(device);
  torch::lazy::LazyTensorPtr lazy_tensor =
      torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device);
  if (lazy_tensor) {
    return device_data_cast(lazy_tensor->GetIrValue());
  }
  return nullptr;
}

} // namespace lazy
} // namespace torch
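For context, a hedged usage sketch of these helpers (not part of the diff; the surrounding function and the choice of a lazy input tensor are assumed for illustration): a lowering step can call device_data_cast to see whether a value, possibly wrapped in detach-copies, is backed by a graph input, which is exactly the node whose parameter name the commit propagates.

#include <ATen/core/Tensor.h>
#include "tensor_utils.h"

// Hypothetical usage sketch only: check whether a tensor is backed by a
// DeviceData graph input. `tensor` is assumed to live on a lazy device.
void inspect_graph_input(const at::Tensor& tensor) {
  torch::lazy::DeviceData* device_data =
      torch::lazy::device_data_cast(tensor, /*device=*/c10::nullopt);
  if (device_data != nullptr) {
    // This node represents a computation parameter; the commit's purpose is
    // to let a name associated with it reach TorchMlirComputation.
  }
}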