|
8 | 8 | from torch_tensorrt import EngineCapability, Device
|
9 | 9 | from torch_tensorrt.fx.utils import LowerPrecision
|
10 | 10 |
|
11 |
| -from torch_tensorrt.dynamo.backend._settings import CompilationSettings |
12 | 11 | from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
|
13 | 12 | from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
|
14 | 13 | from torch_tensorrt.dynamo.backend._defaults import (
|
@@ -62,6 +61,10 @@ def compile(
|
62 | 61 |
|
63 | 62 | inputs = prepare_inputs(inputs, prepare_device(device))
|
64 | 63 |
|
| 64 | + if not isinstance(enabled_precisions, collections.abc.Collection): |
| 65 | + enabled_precisions = [enabled_precisions] |
| 66 | + |
| 67 | + # Parse user-specified enabled precisions |
65 | 68 | if (
|
66 | 69 | torch.float16 in enabled_precisions
|
67 | 70 | or torch_tensorrt.dtype.half in enabled_precisions
|
@@ -123,19 +126,12 @@ def create_backend(
|
123 | 126 | Returns:
|
124 | 127 | Backend for torch.compile
|
125 | 128 | """
|
126 |
| - if debug: |
127 |
| - logger.setLevel(logging.DEBUG) |
128 |
| - |
129 |
| - settings = CompilationSettings( |
| 129 | + return partial( |
| 130 | + torch_tensorrt_backend, |
130 | 131 | debug=debug,
|
131 | 132 | precision=precision,
|
132 | 133 | workspace_size=workspace_size,
|
133 | 134 | min_block_size=min_block_size,
|
134 | 135 | torch_executed_ops=torch_executed_ops,
|
135 | 136 | pass_through_build_failures=pass_through_build_failures,
|
136 | 137 | )
|
137 |
| - |
138 |
| - return partial( |
139 |
| - torch_tensorrt_backend, |
140 |
| - settings=settings, |
141 |
| - ) |
0 commit comments