forked from llvm/torch-mlir
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Propagate device data names (llvm#1157)
* Propagate device data names * Address PR comment * Add example usage * Add test for device data names * Make TorchMlirComputation fields protected * Add lazy backend device data name unit tests * Disable lazy backend tests if LTC is disabled * Add comments
- Loading branch information
1 parent
84d345c
commit 0af5578
Showing
14 changed files
with
242 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
|
||
# RUN: %PYTHON %s | FileCheck %s | ||
|
||
|
||
import torch | ||
import torch._lazy | ||
|
||
import torch_mlir.reference_lazy_backend._REFERENCE_LAZY_BACKEND as lazy_backend | ||
|
||
from run_test import run_test | ||
|
||
lazy_backend._initialize() | ||
|
||
device = "lazy" | ||
|
||
|
||
# CHECK: 0 input tensors found | ||
# ----- | ||
# CHECK: PASS - test_no_device_data_name | ||
@run_test | ||
def test_no_device_data_name(): | ||
x = torch.tensor(1).to(device) | ||
y = torch.tensor(2).to(device) | ||
z = x + y | ||
torch._lazy.mark_step() | ||
|
||
|
||
# CHECK: Input tensor: input_x | ||
# CHECK: 1 input tensors found | ||
# ----- | ||
# CHECK: PASS - test_device_data_name | ||
@run_test | ||
def test_device_data_name(): | ||
x = torch.tensor(1).to(device) | ||
y = torch.tensor(2).to(device) | ||
|
||
lazy_backend.set_parameter_name(x, "input_x") | ||
|
||
z = x + y | ||
torch._lazy.mark_step() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
|
||
# RUN: true | ||
|
||
|
||
def run_test(*args, XPASS=False, XFAIL=False): | ||
def _run_test(test): | ||
test_name = test.__name__ | ||
try: | ||
test() | ||
print(("X" if XPASS else "") + f"PASS - {test_name}") | ||
except Exception as e: | ||
print(("X" if XFAIL else "") + f"FAIL - {test_name}") | ||
print("Errors: ", e) | ||
print(flush=True) | ||
|
||
if len(args): | ||
_run_test(args[0]) | ||
else: | ||
return _run_test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
31 changes: 31 additions & 0 deletions
31
python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#pragma once | ||
|
||
#include <string> | ||
#include <sstream> | ||
#include <vector> | ||
|
||
|
||
template <typename T> | ||
std::ostream& string_join(std::ostream& out, const std::vector<T>& v, const std::string& delimiter) { | ||
size_t i = 0; | ||
for (const T& e : v) { | ||
if ((i++) > 0) { out << delimiter; } | ||
out << e; | ||
} | ||
return out; | ||
} | ||
|
||
template <typename T> | ||
std::string string_join(const std::vector<T>& v, const std::string& delimiter) { | ||
std::ostringstream joined; | ||
string_join(joined, v, delimiter); | ||
return joined.str(); | ||
} | ||
|
||
|
||
/* | ||
* Returns true if str starts with prefix | ||
*/ | ||
inline bool startswith(const std::string& str, const std::string& prefix) { | ||
return str.rfind(prefix, 0) == 0; | ||
} |
30 changes: 30 additions & 0 deletions
30
python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#pragma once | ||
|
||
#include "torch/csrc/lazy/backend/backend_device.h" | ||
#include "torch/csrc/lazy/core/tensor.h" | ||
|
||
#include "../ops/device_data.h" | ||
|
||
|
||
namespace torch { | ||
namespace lazy { | ||
|
||
inline torch::lazy::DeviceData* device_data_cast( | ||
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt | ||
) { | ||
if (!device) { | ||
device = torch::lazy::GetBackendDevice(tensor); | ||
} | ||
TORCH_CHECK(device); | ||
torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device); | ||
if (lazy_tensor) { | ||
torch::lazy::Value param_value = lazy_tensor->GetIrValue(); | ||
if (param_value && param_value->op() == torch::lazy::DeviceData::ClassOpKind()) { | ||
return dynamic_cast<torch::lazy::DeviceData*>(param_value.node.get()); | ||
} | ||
} | ||
return nullptr; | ||
} | ||
|
||
} // namespace lazy | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters