Revert "[SLM] Allow modules to define pre-processing of weights" (#16777
Browse files Browse the repository at this point in the history
)

Revert "[SLM] Allow modules to define pre-processing of weights (#16757)"

This reverts commit 1cccc3b.
tqchen authored Mar 25, 2024
1 parent 5a8d928 commit ef46f4e
Showing 8 changed files with 58 additions and 498 deletions.
python/tvm/relax/frontend/nn/core.py (17 changes: 1 addition & 16 deletions)
@@ -591,22 +591,7 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]:
         The computed result.
     """
     if not isinstance(expr, rx.DataflowVar):
-        block_builder = BlockBuilder.current()
-        if block_builder is None:
-            # Normalize to make sure we have valid StructInfo, but
-            # wait until we are actually building the function to
-            # flatten nested expressions.
-            #
-            # TODO(Lunderberg): Make this easier to call. Infering
-            # struct info for a nested expression should be doable in
-            # a free function, without requiring an active
-            # BlockBuilder and an active FunctionFrame.
-            builder = BlockBuilder()
-            with builder.function("dummy_scope", params=[]):
-                expr = builder.normalize(expr)
-                builder.emit_func_output([])
-        else:
-            expr = BlockBuilder.current().emit(expr, name)
+        expr = BlockBuilder.current().emit(expr, name)
     if isinstance(expr.struct_info_, TensorStructInfo):
         return Tensor(_expr=expr)
     if isinstance(expr.struct_info_, TupleStructInfo):
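After this revert, `wrap_nested` again assumes an active `BlockBuilder`, which the exporter sets up while tracing a module's methods. Below is a minimal sketch of how such a trace is typically driven; it assumes the `relax.frontend.nn` API around this commit, and the module, shapes, and names are made up for illustration.

```python
# Hypothetical example, not part of this commit: exporting a tiny nn.Module.
# Every nn.op call inside forward() goes through wrap_nested(), which after
# this revert always emits into the BlockBuilder created by export_tvm().
from tvm.relax.frontend import nn


class TinyModule(nn.Module):
    def __init__(self):
        self.linear = nn.Linear(in_features=8, out_features=4)

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        return nn.op.silu(self.linear(x))


mod, params = TinyModule().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor((1, 8), "float32")}}
)
print(mod)        # the traced Relax IRModule, including a "forward" function
print(params[0])  # ("linear.weight", <nn.Parameter>)
```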
python/tvm/relax/frontend/nn/exporter.py (40 changes: 21 additions & 19 deletions)
@@ -111,8 +111,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
             return result

         # pylint: enable=protected-access
-
-        params = _params()
+        params = None
         effects = _effects()
         ext_mods = self.extern_mods
         with self:
@@ -122,6 +121,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
                     outputs = _emit_effect_init(self.builder, effects)
                 self.builder.emit_func_output(outputs, params=[])
             for method_name, method_spec in zip(spec.method_names, spec.method_specs):
+                params = _params()  # Re-initialize so symbolic shapes not shared across methods
                 len_args = len(method_spec.arg_specs)
                 len_effects = {
                     "packed": 1,
@@ -135,18 +135,9 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
                 with self.builder.dataflow():
                     outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
                 self.builder.emit_func_output(outputs, inputs)
-
-                # TODO(Lunderberg): Make a `ir.transform.ConvertSSA`,
-                # similar to the existing `tir.transform.ConvertSSA`,
-                # that converts an entire module to SSA, including TIR
-                # variable definitions used in either TIR or Relax.
-                mod = self.builder.get()
-                mod[method_name] = rx.utils.copy_with_new_vars(mod[method_name])
-
         mod = self.builder.finalize()
         assert rx.analysis.well_formed(mod)

-        mod = rx.transform.CanonicalizeBindings()(mod)
         return mod, params, ext_mods


@@ -170,6 +161,8 @@ def _emit_method(  # pylint: disable=too-many-locals,too-many-branches,too-many-
     effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]],
 ):
     # pylint: disable=protected-access
+    # symbolic shape's name mapping to its tir.Var for reuse
+    str2var_params: typing.Dict[str, tir.Var] = {}

     def _unwrap_ret(expr: typing.Any) -> typing.Any:
         if isinstance(expr, (core.Tensor, core.Object)):
@@ -183,26 +176,35 @@ def _unwrap_ret(expr: typing.Any) -> typing.Any:
     def _convert_input(arg):
         if isinstance(arg, tir.Var):
             return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
-        elif isinstance(arg, (core.Tensor, core.Object)):
+        if isinstance(arg, (core.Tensor, core.Object)):
             return arg._expr  # pylint: disable=protected-access
-        elif isinstance(arg, _spec.Tuple):
+        if isinstance(arg, _spec.Tuple):
             return rx.Var(
                 arg.name,
                 struct_info=TupleStructInfo(
                     [_convert_input(arg_i).struct_info for arg_i in arg.elements]
                 ),
             )
-        elif isinstance(arg, rx.Expr):
-            return arg
-        else:
-            raise TypeError(f"Unsupported input type: {type(arg)}")
+        raise TypeError(f"Unsupported input type: {type(arg)}")

     def _params(mode: str) -> typing.List[rx.Var]:
         inputs: typing.List[rx.Var] = []

-        for name, param in params:
-            inputs.append(param._expr)
+        def _get_var(shape_var: tir.Var) -> tir.Var:
+            name = shape_var.name
+            if name in str2var_params:
+                return str2var_params[name]
+            var = tir.Var(name, "int64")
+            str2var_params[name] = var
+            return var
+
+        for name, param in params:
+            # Make sure the a symbolic shape is not re-registered (same as _method_spec_to_inputs)
+            # e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1` for `embed_tokens`
+            new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in param.shape]
+            var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr
+            inputs.append(var)
+            param._expr = var
         if mode == "none":
             return []
         if mode == "plain":
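The `_get_var` helper and `str2var_params` dictionary restored above deduplicate symbolic dimensions by name, so parameters that share a dimension name also share one `tir.Var` within a method. A standalone sketch of the same pattern (illustrative only, not the TVM source; the shapes and names are made up):

```python
# Illustrative sketch of per-name reuse of symbolic shape variables: dimensions
# carrying the same name map to a single tir.Var, so `vocab_size` in lm_head and
# embed_tokens stays one shared variable instead of `vocab_size` / `vocab_size_1`.
from typing import Dict, List, Union

from tvm import tir

str2var: Dict[str, tir.Var] = {}


def get_var(shape_var: tir.Var) -> tir.Var:
    # Return the cached variable for this dimension name, creating it on first use.
    if shape_var.name not in str2var:
        str2var[shape_var.name] = tir.Var(shape_var.name, "int64")
    return str2var[shape_var.name]


def canonicalize(shape: List[Union[int, tir.Var]]) -> List[Union[int, tir.Var]]:
    # Replace every symbolic dimension with its shared variable; keep ints as-is.
    return [get_var(dim) if isinstance(dim, tir.Var) else dim for dim in shape]


embed_tokens_shape = canonicalize([tir.Var("vocab_size", "int64"), 4096])
lm_head_shape = canonicalize([4096, tir.Var("vocab_size", "int64")])
assert embed_tokens_shape[0].same_as(lm_head_shape[1])  # one shared "vocab_size" Var
```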
(Diffs for the remaining 6 changed files are not shown here.)
