Commit

Merge remote-tracking branch 'upstream/main' into unity

tqchen committed Jan 11, 2024
2 parents 474c06b + 5d4c01e commit c40d96b
Showing 13 changed files with 314 additions and 159 deletions.
31 changes: 28 additions & 3 deletions python/tvm/driver/tvmc/compiler.py
@@ -31,7 +31,7 @@
 from tvm import autotvm, auto_scheduler
 from tvm import relay
 from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
-from tvm.ir.instrument import PassInstrument, PassTimingInstrument
+from tvm.ir.instrument import PassInstrument, PassTimingInstrument, PassPrintingInstrument
 from tvm.ir.memory_pools import WorkspaceMemoryPools
 from tvm.target import Target
 from tvm.relay.backend import Executor, Runtime
@@ -162,6 +162,18 @@ def add_compile_parser(subparsers, _, json_params):
         action="store_true",
         help="print compilation time per pass",
     )
+    parser.add_argument(
+        "--print-ir-before",
+        help="print IR before each named pass, given a comma-separated list of pass names. "
+        "e.g. '--print-ir-before [tir.SplitHostDevice,tir.ConvertSSA]'",
+        default="",
+    )
+    parser.add_argument(
+        "--print-ir-after",
+        help="print IR after each named pass, given a comma-separated list of pass names. "
+        "e.g. '--print-ir-after [tir.SplitHostDevice,tir.ConvertSSA]'",
+        default="",
+    )
     for one_entry in json_params:
         parser.set_defaults(**one_entry)
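For orientation, a hedged sketch of driving the new options through the tvmc Python API; the model file name is a placeholder, and forwarding these keyword arguments mirrors the compile_model signature below, so treat it as an assumption rather than verified behaviour:

# Minimal sketch: compile a model and dump IR around two named passes.
from tvm.driver import tvmc

model = tvmc.load("model.onnx")  # hypothetical input model
package = tvmc.compiler.compile_model(
    model,
    target="llvm",
    print_ir_before=["tir.SplitHostDevice"],  # IR printed before this pass
    print_ir_after=["tir.ConvertSSA"],        # IR printed after this pass
)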

@@ -220,6 +232,8 @@ def drive_compile(args):
             workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
         ),
         print_pass_times=args.print_pass_times,
+        print_ir_before=args.print_ir_before,
+        print_ir_after=args.print_ir_after,
         **transform_args,
     )

@@ -247,6 +261,8 @@ def compile_model(
     mod_name: Optional[str] = "default",
     workspace_pools: Optional[WorkspaceMemoryPools] = None,
     print_pass_times: bool = False,
+    print_ir_before: Optional[List[str]] = None,
+    print_ir_after: Optional[List[str]] = None,
     instruments: Optional[Sequence[PassInstrument]] = None,
     desired_layout: Optional[str] = None,
     desired_layout_ops: Optional[List[str]] = None,
@@ -295,7 +311,7 @@ def compile_model(
         needs to be generated.
     disabled_pass: str, optional
         Comma-separated list of passes which needs to be disabled
-        during compilation
+        during compilation.
     pass_context_configs: list[str], optional
         List of strings containing a set of configurations to be passed to the
         PassContext.
@@ -310,6 +326,10 @@
         compilation.
     print_pass_times: bool
         To enable printing a breakdown of compilation times by pass. Disabled by default.
+    print_ir_before: list[str], optional
+        Print the IR before each pass named in this list of pass names.
+    print_ir_after: list[str], optional
+        Print the IR after each pass named in this list of pass names.
     instruments: Optional[Sequence[PassInstrument]]
         The list of pass instrument implementations.
     desired_layout: str, optional
@@ -369,6 +389,12 @@ def compile_model(
         timing_inst = PassTimingInstrument()
         instruments = [timing_inst] if instruments is None else [timing_inst] + instruments

+    if print_ir_before or print_ir_after:
+        print_ir_instr = PassPrintingInstrument(
+            print_before_pass_names=print_ir_before, print_after_pass_names=print_ir_after
+        )
+        instruments = [print_ir_instr] if instruments is None else [print_ir_instr] + instruments
+
     with tvm.transform.PassContext(
         opt_level=opt_level,
         config=config,
@@ -581,7 +607,6 @@ def dump_operation_offloads(mod: tvm.ir.IRModule, initial_mod: tvm.ir.IRModule,
     save_to_file = all([dump_path != "-", dump_path != ""])

     if print_to_console or save_to_file:
-
         operations_distribution = analyze_operations_distribution(mod)

         def annotate_f(x):
18 changes: 18 additions & 0 deletions python/tvm/ir/instrument.py
@@ -255,3 +255,21 @@ def render():
         profiles = timing_inst.render()
         """
         return _ffi_instrument_api.RenderTimePassProfiles()
+
+
+@pass_instrument
+class PassPrintingInstrument:
+    """A pass instrument that prints the IR before and/or after
+    each pass named in the given lists of pass names."""
+
+    def __init__(self, print_before_pass_names, print_after_pass_names):
+        self.print_before_pass_names = print_before_pass_names
+        self.print_after_pass_names = print_after_pass_names
+
+    def run_before_pass(self, mod, pass_info):
+        if pass_info.name in self.print_before_pass_names:
+            print(f"Print IR before: {pass_info.name}\n{mod}\n\n")
+
+    def run_after_pass(self, mod, pass_info):
+        if pass_info.name in self.print_after_pass_names:
+            print(f"Print IR after: {pass_info.name}\n{mod}\n\n")
7 changes: 6 additions & 1 deletion python/tvm/topi/arm_cpu/arm_utils.py
@@ -74,8 +74,13 @@ def get_tiling_B_transformed(interleave_A, in_dtype):
         # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
         tile_N = 4
         tile_K = 16
+    # In non-quantized cases, A is not interleaved.
+    elif in_dtype == "float16" and target.features.has_fp16_simd:
+        # Each load from B' contains 32 elements (i.e. 32 columns from B)
+        # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
+        tile_N = 32
+        tile_K = 4
     else:
-        # In non-quantized cases, A is not interleaved.
         # Each load from B' contains 16 elements (i.e. 16 columns from B)
         # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
         tile_N = 16
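A quick, hedged check of the new branch. The target string is illustrative, and the assumptions that get_tiling_B_transformed reads the current Target context and returns (tile_N, tile_K) follow the call site in conv2d_gemm.py below, not verified documentation:

import tvm
from tvm.topi.arm_cpu.arm_utils import get_tiling_B_transformed

# Hypothetical AArch64 target whose features report fp16 SIMD (+fullfp16).
with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+fullfp16"):
    tile_N, tile_K = get_tiling_B_transformed(False, "float16")
    print(tile_N, tile_K)  # expected 32 4, per the new branch above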
16 changes: 12 additions & 4 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -22,6 +22,7 @@
 from tvm import te
 from tvm.topi import nn
 from tvm.autotvm.task.space import AnnotateEntity, ReorderEntity, OtherOptionEntity
+from tvm.topi.arm_cpu.arm_utils import get_tiling_B_transformed
 from ..utils import get_const_tuple, get_const_int
 from ..nn.utils import get_pad_tuple
 from .tensor_intrin import (
@@ -339,7 +340,15 @@ def compute_conv2d_gemm_without_weight_transform(
         ),
         name="C",
     )
-    zero = tvm.tir.const(0)
+    # Ensure padding on the N axis does not get removed during tir passes
+    # by adding a dummy reference to the specific padded area of the result
+    if in_dtype == "float16" and target.features.has_fp16_simd:
+        zero = (
+            tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
+            - tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
+        )
+    else:
+        zero = tvm.tir.const(0)

     # Reshape the result into a convolution output
     out_shape = (batches, OH, OW, OC)
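The subtraction folds to zero arithmetically but keeps a live read of C's padded last column, which is what stops simplification passes from dropping the padding. A standalone sketch of the same idiom, with illustrative shapes and names:

import tvm
from tvm import te

M, N_padded = 4, 36
C = te.placeholder((1, M, N_padded), name="C", dtype="float16")

# Evaluates to 0 for every element, yet records a dependence on the padded tail.
zero = (
    tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
    - tvm.tir.const(1, C.dtype) * C[0, 0, N_padded - 1]
)
out = te.compute((1, M, 32), lambda b, i, j: C[b, i, j] + zero, name="out")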
@@ -454,14 +463,14 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
     C = out.op.input_tensors[0]
     A = C.op.input_tensors[0]
     in_type = A.dtype
+    y_tile_size, _ = get_tiling_B_transformed(False, in_type)

     # Computation
     b, x, y = C.op.axis
     (k,) = C.op.reduce_axis

     if in_type in ["int8", "uint8"]:
         k_outer, k_inner = s[C].split(k, 16)
-        y_tile_size = 16
         x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
         s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner)
         gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1)
@@ -470,9 +479,8 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
         s[C].parallel(x_outer)
     else:
         k_outer, k_inner = s[C].split(k, 4)
-        y_tile_size = 16
         x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=y_tile_size)
-        y_inner_outer, y_inner_inner = s[C].split(y_inner, 4)
+        y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4)
         b_x_outer_fused = s[C].fuse(b, x_outer)
         s[C].parallel(b_x_outer_fused)
         s[C].reorder(
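One behavioural note on the split change: s[C].split(y_inner, 4) fixes the inner extent at 4, while s[C].split(y_inner, nparts=4) fixes the outer extent at 4. A hedged, standalone toy schedule showing the difference:

import tvm
from tvm import te

# With extent 32: factor=4 yields loop extents (8, 4); nparts=4 yields (4, 8).
A = te.placeholder((32,), name="A")
B = te.compute((32,), lambda i: A[i] * 2, name="B")

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=4)  # inner loop extent is 4
print(tvm.lower(s, [A, B], simple_mode=True))

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], nparts=4)  # outer loop extent is 4
print(tvm.lower(s, [A, B], simple_mode=True))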
43 changes: 25 additions & 18 deletions src/relay/transforms/pattern_utils.h
@@ -468,43 +468,43 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
  * \param i element index
  * \return Converted scalar value, or None if conversion failed
  */
-static inline std::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
+template <typename T>
+static inline std::optional<T> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
   if (array->dtype.code == kDLInt) {
     if (array->dtype.bits == 8) {
-      return std::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<int8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 16) {
-      return std::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<int16_t*>(array->data)[i]);
     } else if (array->dtype.bits == 32) {
-      return std::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<int32_t*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return std::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<int64_t*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLUInt) {
     if (array->dtype.bits == 1) {  // bool
-      return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<uint8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 8) {
-      return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<uint8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 16) {
-      return std::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<uint16_t*>(array->data)[i]);
     } else if (array->dtype.bits == 32) {
-      return std::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<uint32_t*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return std::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<uint64_t*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLFloat) {
     if (array->dtype.bits == 16) {
-      return std::optional<long double>(
-          __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
-              reinterpret_cast<uint16_t*>(array->data)[i]));
+      return std::optional<T>(__extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
+          reinterpret_cast<uint16_t*>(array->data)[i]));
     }
     if (array->dtype.bits == 32) {
-      return std::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<float*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return std::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<double*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLBfloat) {
     if (array->dtype.bits == 16) {
-      return std::optional<long double>(__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
+      return std::optional<T>(__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
           reinterpret_cast<uint16_t*>(array->data)[i]));
     }
   }
@@ -517,8 +517,15 @@ static inline std::optional<long double> TryToScalar(const runtime::NDArr
  * \param i element index
  * \return Converted scalar value
  */
+template <typename T>
+static inline T ToScalar(const runtime::NDArray& array, size_t i = 0) {
+  auto try_value = TryToScalar<T>(array, i);
+  ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
+  return try_value.value();
+}
+
 static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
-  auto try_value = TryToScalar(array, i);
+  auto try_value = TryToScalar<long double>(array, i);
   ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
   return try_value.value();
 }
@@ -534,7 +541,7 @@ static inline Array<Integer> ToVector(const runtime::NDArray& array) {
   size_t len = array.Shape().front();
   Array<Integer> out;
   for (size_t i = 0; i < len; ++i) {
-    long double elem_val = ToScalar(array, i);
+    uint64_t elem_val = ToScalar<uint64_t>(array, i);
     out.push_back(Integer(IntImm(DataType::Int(32), static_cast<int64_t>(elem_val))));
   }
   return out;
2 changes: 1 addition & 1 deletion src/relay/transforms/simplify_expr.cc
@@ -794,7 +794,7 @@ class EliminateIdentityRewrite : public DFPatternRewrite {
     if (!IsScalar(GetRef<Expr>(constant))) {
      return false;
     }
-    auto value = TryToScalar(constant->data, 0);
+    auto value = TryToScalar<long double>(constant->data, 0);
     if (!value) {
       // unsupported dtype
       return false;