|
16 | 16 | from torch_tensorrt.dynamo._defaults import (
|
17 | 17 | DEBUG,
|
18 | 18 | DEVICE,
|
| 19 | + DISABLE_TF32, |
| 20 | + DLA_GLOBAL_DRAM_SIZE, |
| 21 | + DLA_LOCAL_DRAM_SIZE, |
| 22 | + DLA_SRAM_SIZE, |
19 | 23 | ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
|
| 24 | + ENGINE_CAPABILITY, |
20 | 25 | MAX_AUX_STREAMS,
|
21 | 26 | MIN_BLOCK_SIZE,
|
| 27 | + NUM_AVG_TIMING_ITERS, |
22 | 28 | OPTIMIZATION_LEVEL,
|
23 | 29 | PASS_THROUGH_BUILD_FAILURES,
|
24 | 30 | PRECISION,
|
| 31 | + REFIT, |
25 | 32 | REQUIRE_FULL_COMPILATION,
|
| 33 | + SPARSE_WEIGHTS, |
26 | 34 | TRUNCATE_LONG_AND_DOUBLE,
|
27 | 35 | USE_FAST_PARTITIONER,
|
28 | 36 | USE_PYTHON_RUNTIME,
|
@@ -51,17 +59,18 @@ def compile(
|
51 | 59 | inputs: Any,
|
52 | 60 | *,
|
53 | 61 | device: Optional[Union[Device, torch.device, str]] = DEVICE,
|
54 |
| - disable_tf32: bool = False, |
55 |
| - sparse_weights: bool = False, |
| 62 | + disable_tf32: bool = DISABLE_TF32, |
| 63 | + sparse_weights: bool = SPARSE_WEIGHTS, |
56 | 64 | enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
|
57 |
| - refit: bool = False, |
| 65 | + engine_capability: EngineCapability = ENGINE_CAPABILITY, |
| 66 | + refit: bool = REFIT, |
58 | 67 | debug: bool = DEBUG,
|
59 | 68 | capability: EngineCapability = EngineCapability.default,
|
60 |
| - num_avg_timing_iters: int = 1, |
| 69 | + num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS, |
61 | 70 | workspace_size: int = WORKSPACE_SIZE,
|
62 |
| - dla_sram_size: int = 1048576, |
63 |
| - dla_local_dram_size: int = 1073741824, |
64 |
| - dla_global_dram_size: int = 536870912, |
| 71 | + dla_sram_size: int = DLA_SRAM_SIZE, |
| 72 | + dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE, |
| 73 | + dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE, |
65 | 74 | calibrator: object = None,
|
66 | 75 | truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
|
67 | 76 | require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
|
@@ -200,6 +209,13 @@ def compile(
|
200 | 209 | "use_fast_partitioner": use_fast_partitioner,
|
201 | 210 | "enable_experimental_decompositions": enable_experimental_decompositions,
|
202 | 211 | "require_full_compilation": require_full_compilation,
|
| 212 | + "disable_tf32": disable_tf32, |
| 213 | + "sparse_weights": sparse_weights, |
| 214 | + "refit": refit, |
| 215 | + "engine_capability": engine_capability, |
| 216 | + "dla_sram_size": dla_sram_size, |
| 217 | + "dla_local_dram_size": dla_local_dram_size, |
| 218 | + "dla_global_dram_size": dla_global_dram_size, |
203 | 219 | }
|
204 | 220 |
|
205 | 221 | settings = CompilationSettings(**compilation_options)
|
|
0 commit comments