Skip to content

Commit c40d96b

Browse files
committed
Merge remote-tracking branch 'upstream/main' into unity
2 parents 474c06b + 5d4c01e commit c40d96b

File tree

13 files changed

+314
-159
lines changed

13 files changed

+314
-159
lines changed

python/tvm/driver/tvmc/compiler.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from tvm import autotvm, auto_scheduler
3232
from tvm import relay
3333
from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
34-
from tvm.ir.instrument import PassInstrument, PassTimingInstrument
34+
from tvm.ir.instrument import PassInstrument, PassTimingInstrument, PassPrintingInstrument
3535
from tvm.ir.memory_pools import WorkspaceMemoryPools
3636
from tvm.target import Target
3737
from tvm.relay.backend import Executor, Runtime
@@ -162,6 +162,18 @@ def add_compile_parser(subparsers, _, json_params):
162162
action="store_true",
163163
help="print compilation time per pass",
164164
)
165+
parser.add_argument(
166+
"--print-ir-before",
167+
help="print IR before each named pass of a comma-separated list of pass names."
168+
"e.g. '--print-ir-before [tir.SplitHostDevice,tir.ConvertSSA]' ",
169+
default="",
170+
)
171+
parser.add_argument(
172+
"--print-ir-after",
173+
help="print IR after each named pass of a comma-separated list of pass names."
174+
"e.g. '--print-ir-after [tir.SplitHostDevice,tir.ConvertSSA]' ",
175+
default="",
176+
)
165177
for one_entry in json_params:
166178
parser.set_defaults(**one_entry)
167179

@@ -220,6 +232,8 @@ def drive_compile(args):
220232
workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
221233
),
222234
print_pass_times=args.print_pass_times,
235+
print_ir_before=args.print_ir_before,
236+
print_ir_after=args.print_ir_after,
223237
**transform_args,
224238
)
225239

@@ -247,6 +261,8 @@ def compile_model(
247261
mod_name: Optional[str] = "default",
248262
workspace_pools: Optional[WorkspaceMemoryPools] = None,
249263
print_pass_times: bool = False,
264+
print_ir_before: Optional[List[str]] = None,
265+
print_ir_after: Optional[List[str]] = None,
250266
instruments: Optional[Sequence[PassInstrument]] = None,
251267
desired_layout: Optional[str] = None,
252268
desired_layout_ops: Optional[List[str]] = None,
@@ -295,7 +311,7 @@ def compile_model(
295311
needs to be generated.
296312
disabled_pass: str, optional
297313
Comma-separated list of passes which needs to be disabled
298-
during compilation
314+
during compilation.
299315
pass_context_configs: list[str], optional
300316
List of strings containing a set of configurations to be passed to the
301317
PassContext.
@@ -310,6 +326,10 @@ def compile_model(
310326
compilation.
311327
print_pass_times: bool
312328
To enable printing a breakdown of compilation times by pass. Disabled by default.
329+
print_ir_before: list[str], optional
330+
To print IR before each named pass of a comma-separated list of passes.
331+
print_ir_after: list[str], optional
332+
To print IR after each named pass of a comma-separated list of passes.
313333
instruments: Optional[Sequence[PassInstrument]]
314334
The list of pass instrument implementations.
315335
desired_layout: str, optional
@@ -369,6 +389,12 @@ def compile_model(
369389
timing_inst = PassTimingInstrument()
370390
instruments = [timing_inst] if instruments is None else [timing_inst] + instruments
371391

392+
if print_ir_before or print_ir_after:
393+
print_ir_instr = PassPrintingInstrument(
394+
print_before_pass_names=print_ir_before, print_after_pass_names=print_ir_after
395+
)
396+
instruments = [print_ir_instr] if instruments is None else [print_ir_instr] + instruments
397+
372398
with tvm.transform.PassContext(
373399
opt_level=opt_level,
374400
config=config,
@@ -581,7 +607,6 @@ def dump_operation_offloads(mod: tvm.ir.IRModule, initial_mod: tvm.ir.IRModule,
581607
save_to_file = all([dump_path != "-", dump_path != ""])
582608

583609
if print_to_console or save_to_file:
584-
585610
operations_distribution = analyze_operations_distribution(mod)
586611

587612
def annotate_f(x):

python/tvm/ir/instrument.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,21 @@ def render():
255255
profiles = timing_inst.render()
256256
"""
257257
return _ffi_instrument_api.RenderTimePassProfiles()
258+
259+
260+
@pass_instrument
class PassPrintingInstrument:
    """A pass instrument that prints the IR before and/or after selected passes.

    Parameters
    ----------
    print_before_pass_names : Union[str, list[str], None]
        Names of passes before which the IR should be printed. A
        comma-separated string (as produced by the tvmc CLI) is also accepted.
    print_after_pass_names : Union[str, list[str], None]
        Names of passes after which the IR should be printed.
    """

    @staticmethod
    def _as_name_list(names):
        # Normalize None / "" / "a,b" / ["a", "b"] into a list of pass names.
        # compile_model defaults both arguments to None, and a plain `in`
        # check against a raw string would do substring matching (e.g.
        # "tir.Split" in "tir.SplitHostDevice"), so always coerce to a list.
        if not names:
            return []
        if isinstance(names, str):
            return [name.strip() for name in names.split(",") if name.strip()]
        return list(names)

    def __init__(self, print_before_pass_names, print_after_pass_names):
        self.print_before_pass_names = self._as_name_list(print_before_pass_names)
        self.print_after_pass_names = self._as_name_list(print_after_pass_names)

    def run_before_pass(self, mod, pass_info):
        # Invoked by the PassContext before each pass executes.
        if pass_info.name in self.print_before_pass_names:
            print(f"Print IR before: {pass_info.name}\n{mod}\n\n")

    def run_after_pass(self, mod, pass_info):
        # Invoked by the PassContext after each pass executes.
        if pass_info.name in self.print_after_pass_names:
            print(f"Print IR after: {pass_info.name}\n{mod}\n\n")

python/tvm/topi/arm_cpu/arm_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,13 @@ def get_tiling_B_transformed(interleave_A, in_dtype):
7474
# we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
7575
tile_N = 4
7676
tile_K = 16
77+
# In non-quantized cases, A is not interleaved.
78+
elif in_dtype == "float16" and target.features.has_fp16_simd:
79+
# Each load from B' contains 32 elements (i.e. 32 columns from B)
80+
# We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
81+
tile_N = 32
82+
tile_K = 4
7783
else:
78-
# In non-quantized cases, A is not interleaved.
7984
# Each load from B' contains 16 elements (i.e. 16 columns from B)
8085
# We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
8186
tile_N = 16

python/tvm/topi/arm_cpu/conv2d_gemm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tvm import te
2323
from tvm.topi import nn
2424
from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
25+
from tvm.topi.arm_cpu.arm_utils import get_tiling_B_transformed
2526
from ..utils import get_const_tuple, get_const_int
2627
from ..nn.utils import get_pad_tuple
2728
from .tensor_intrin import (
@@ -339,7 +340,15 @@ def compute_conv2d_gemm_without_weight_transform(
339340
),
340341
name="C",
341342
)
342-
zero = tvm.tir.const(0)
343+
# Ensure padding on the N axis does not get removed during tir passes
344+
# by adding a dummy reference to the specific padded area of the result
345+
if in_dtype == "float16" and target.features.has_fp16_simd:
346+
zero = (
347+
tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
348+
- tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
349+
)
350+
else:
351+
zero = tvm.tir.const(0)
343352

344353
# Reshape the result into a convolution output
345354
out_shape = (batches, OH, OW, OC)
@@ -454,14 +463,14 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
454463
C = out.op.input_tensors[0]
455464
A = C.op.input_tensors[0]
456465
in_type = A.dtype
466+
y_tile_size, _ = get_tiling_B_transformed(False, in_type)
457467

458468
# Computation
459469
b, x, y = C.op.axis
460470
(k,) = C.op.reduce_axis
461471

462472
if in_type in ["int8", "uint8"]:
463473
k_outer, k_inner = s[C].split(k, 16)
464-
y_tile_size = 16
465474
x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
466475
s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner)
467476
gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1)
@@ -470,9 +479,8 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
470479
s[C].parallel(x_outer)
471480
else:
472481
k_outer, k_inner = s[C].split(k, 4)
473-
y_tile_size = 16
474482
x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
475-
y_inner_outer, y_inner_inner = s[C].split(y_inner, 4)
483+
y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4)
476484
b_x_outer_fused = s[C].fuse(b, x_outer)
477485
s[C].parallel(b_x_outer_fused)
478486
s[C].reorder(

src/relay/transforms/pattern_utils.h

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -468,43 +468,43 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
468468
* \param i element index
469469
* \return Converted scalar value, or None if conversion failed
470470
*/
471-
static inline std::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
471+
template <typename T>
472+
static inline std::optional<T> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
472473
if (array->dtype.code == kDLInt) {
473474
if (array->dtype.bits == 8) {
474-
return std::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
475+
return std::optional<T>(reinterpret_cast<int8_t*>(array->data)[i]);
475476
} else if (array->dtype.bits == 16) {
476-
return std::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
477+
return std::optional<T>(reinterpret_cast<int16_t*>(array->data)[i]);
477478
} else if (array->dtype.bits == 32) {
478-
return std::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
479+
return std::optional<T>(reinterpret_cast<int32_t*>(array->data)[i]);
479480
} else if (array->dtype.bits == 64) {
480-
return std::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
481+
return std::optional<T>(reinterpret_cast<int64_t*>(array->data)[i]);
481482
}
482483
} else if (array->dtype.code == kDLUInt) {
483484
if (array->dtype.bits == 1) { // bool
484-
return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
485+
return std::optional<T>(reinterpret_cast<uint8_t*>(array->data)[i]);
485486
} else if (array->dtype.bits == 8) {
486-
return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
487+
return std::optional<T>(reinterpret_cast<uint8_t*>(array->data)[i]);
487488
} else if (array->dtype.bits == 16) {
488-
return std::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
489+
return std::optional<T>(reinterpret_cast<uint16_t*>(array->data)[i]);
489490
} else if (array->dtype.bits == 32) {
490-
return std::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
491+
return std::optional<T>(reinterpret_cast<uint32_t*>(array->data)[i]);
491492
} else if (array->dtype.bits == 64) {
492-
return std::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
493+
return std::optional<T>(reinterpret_cast<uint64_t*>(array->data)[i]);
493494
}
494495
} else if (array->dtype.code == kDLFloat) {
495496
if (array->dtype.bits == 16) {
496-
return std::optional<long double>(
497-
__extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
498-
reinterpret_cast<uint16_t*>(array->data)[i]));
497+
return std::optional<T>(__extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
498+
reinterpret_cast<uint16_t*>(array->data)[i]));
499499
}
500500
if (array->dtype.bits == 32) {
501-
return std::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
501+
return std::optional<T>(reinterpret_cast<float*>(array->data)[i]);
502502
} else if (array->dtype.bits == 64) {
503-
return std::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
503+
return std::optional<T>(reinterpret_cast<double*>(array->data)[i]);
504504
}
505505
} else if (array->dtype.code == kDLBfloat) {
506506
if (array->dtype.bits == 16) {
507-
return std::optional<long double>(__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
507+
return std::optional<T>(__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
508508
reinterpret_cast<uint16_t*>(array->data)[i]));
509509
}
510510
}
@@ -517,8 +517,15 @@ static inline std::optional<long double> TryToScalar(const runtime::NDArray& arr
517517
* \param i element index
518518
* \return Converted scalar value
519519
*/
520+
/*!
 * \brief Convert a scalar element of an NDArray to the requested C++ type.
 * \param array Input NDArray.
 * \param i element index
 * \return Converted scalar value; ICHECK-fails if the dtype is unsupported.
 */
template <typename T>
static inline T ToScalar(const runtime::NDArray& array, size_t i = 0) {
  // Delegate to the fallible conversion, then assert success: an empty
  // optional here means the NDArray dtype has no handled case.
  auto try_value = TryToScalar<T>(array, i);
  ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
  return try_value.value();
}
526+
520527
static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
521-
auto try_value = TryToScalar(array, i);
528+
auto try_value = TryToScalar<long double>(array, i);
522529
ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
523530
return try_value.value();
524531
}
@@ -534,7 +541,7 @@ static inline Array<Integer> ToVector(const runtime::NDArray& array) {
534541
size_t len = array.Shape().front();
535542
Array<Integer> out;
536543
for (size_t i = 0; i < len; ++i) {
537-
long double elem_val = ToScalar(array, i);
544+
uint64_t elem_val = ToScalar<uint64_t>(array, i);
538545
out.push_back(Integer(IntImm(DataType::Int(32), static_cast<int64_t>(elem_val))));
539546
}
540547
return out;

src/relay/transforms/simplify_expr.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ class EliminateIdentityRewrite : public DFPatternRewrite {
794794
if (!IsScalar(GetRef<Expr>(constant))) {
795795
return false;
796796
}
797-
auto value = TryToScalar(constant->data, 0);
797+
auto value = TryToScalar<long double>(constant->data, 0);
798798
if (!value) {
799799
// unsupported dtype
800800
return false;

0 commit comments

Comments
 (0)