Skip to content

Commit

Permalink
WIP: No PyTorch dep
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Feb 10, 2023
1 parent 089018b commit e2ab670
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 16 deletions.
10 changes: 7 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,13 @@ if(TORCH_MLIR_ENABLE_STABLEHLO)
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
endif()

option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
option(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS "Build Torch dialect MLIR Python bindings but neither JIT IR Importer nor LTC backend" OFF)
if(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
set(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER 0)
set(TORCH_MLIR_ENABLE_LTC 0)
endif()

if(TORCH_MLIR_ENABLE_LTC)
set(ENV{TORCH_MLIR_ENABLE_LTC} 1)
Expand Down Expand Up @@ -109,7 +115,6 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR OR TORCH_MLIR_OUT_OF_TREE_
# Don't try to compile the python extensions at the moment. We need
# to import lots of dependencies from AddMLIRPython to make this work.
set(MLIR_ENABLE_BINDINGS_PYTHON 1)
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)

set(TORCH-MLIR_BUILT_STANDALONE 1)
set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
Expand All @@ -119,7 +124,6 @@ else()
# In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir

option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)

# TODO: Fix this upstream so that global include directories are not needed.
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
Expand Down Expand Up @@ -190,7 +194,7 @@ add_custom_target(check-torch-mlir-all)
add_dependencies(check-torch-mlir-all
check-torch-mlir
check-torch-mlir-dialects
check-torch-mlir-capi
# check-torch-mlir-capi
)

if(MLIR_ENABLE_BINDINGS_PYTHON)
Expand Down
18 changes: 10 additions & 8 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,16 @@ endif()
declare_mlir_python_sources(TorchMLIRPythonSources)
declare_mlir_python_sources(TorchMLIRPythonExtensions)

declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES
__init__.py
compiler_utils.py
dynamo.py
)
if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES
__init__.py
compiler_utils.py
dynamo.py
)
endif()

declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

# If true, enable LTC build by default
TORCH_MLIR_ENABLE_LTC_DEFAULT = True
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = int(os.environ.get('TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False))

# Build phase discovery is unreliable. Just tell it what phases to run.
class CustomBuild(_build):
Expand Down Expand Up @@ -90,6 +91,7 @@ def run(self):
f"-DCMAKE_C_VISIBILITY_PRESET=hidden",
f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden",
f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}",
f"-DTORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS={'ON' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'OFF'}",
]

os.makedirs(cmake_build_dir, exist_ok=True)
Expand Down Expand Up @@ -159,9 +161,7 @@ def build_extension(self, ext):
ext_modules=[
CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"),
],
install_requires=[
"numpy",
f"torch=={torch.__version__}".split("+", 1)[0],
],
install_requires=["numpy", ] + [
f"torch=={torch.__version__}".split("+", 1)[0], ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else [],
zip_safe=False,
)
10 changes: 9 additions & 1 deletion test/python/smoketest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# RUN: %PYTHON %s
# RUN: %PYTHON %s | FileCheck %s

import torch_mlir.ir
from torch_mlir.dialects import torch

with torch_mlir.ir.Context() as ctx:
torch.register_dialect(ctx)
with torch_mlir.ir.Location.unknown() as loc:
module = torch_mlir.ir.Module.create(loc)
with torch_mlir.ir.InsertionPoint.at_block_begin(module.body):
n = torch.ConstantNoneOp()
# CHECK: module {
# CHECK: %none = torch.constant.none
# CHECK: }
module.operation.print()

0 comments on commit e2ab670

Please sign in to comment.