Adds support to use brevitas quantized weights for stateless_llama #179

Merged: 2 commits, Nov 17, 2023
60 changes: 35 additions & 25 deletions python/shark_turbine/transforms/quantization/mm_group_quant.py
@@ -63,19 +63,22 @@ def match(self, op: Operation):
m=m,
n=n,
k=k,
element_type=self.builder.get_tensor_element_type(op.operands[0].type),
element_type=self.builder.get_tensor_element_type(
op.operands[0].type
),
)


# TODO (ian): Make more generalizable using RenameParametersPass. Currently hardcoded for brevitas quantization
GROUP_MATMUL_TEMPLATE = r"""
module {{
util.global private @{param_name}.quant {{noinline}} : tensor<{k}x{n_div}xi8>
util.global private @{param_name}.quant.scale {{noinline}} : tensor<{k}x{group0}x{element_type}>
util.global private @{param_name}.quant.zero_point {{noinline}} : tensor<{k}x{group0}x{element_type}>
util.global private @{param_name} {{noinline}} = #stream.parameter.named<"model"::"{param_name}"> : tensor<{k}x{n_div}xi8>
util.global private @{param_name}.quant.scale {{noinline}} = #stream.parameter.named<"model"::"{param_name}_scale"> : tensor<{k}x{group0}x{element_type}>
util.global private @{param_name}.quant.zero_point {{noinline}} = #stream.parameter.named<"model"::"{param_name}_zp"> : tensor<{k}x{group0}x{element_type}>

func.func private @compute_mm_group_quant(%a : tensor<{m}x{n}x{element_type}>) -> tensor<{m}x{k}x{element_type}> {{
%c0 = arith.constant 0 : index
%weight_raw = util.global.load @{param_name}.quant : tensor<{k}x{n_div}xi8>
%weight_raw = util.global.load @{param_name} : tensor<{k}x{n_div}xi8>
%m = tensor.dim %a, %c0 : tensor<{m}x{n}x{element_type}>
%k = tensor.dim %weight_raw, %c0 : tensor<{k}x{n_div}xi8>
%scale = util.global.load @{param_name}.quant.scale : tensor<{k}x{group0}x{element_type}>
@@ -131,7 +134,9 @@ def __init__(self, root_op: Operation, *, group_size: int = 128):

def run(self):
globals = self.globals
mms = match_children(self.funcs, TransposedMMMatcher(globals, self.builder))
mms = match_children(
self.funcs, TransposedMMMatcher(globals, self.builder)
)

for mr in mms:
if mr.k is None or mr.n is None:
@@ -145,27 +150,32 @@

def rewrite(self, mr: TransposedMMResult):
none_to_q = lambda x: "?" if x is None else x
inline_module_asm = GROUP_MATMUL_TEMPLATE.format(
param_name=mr.param_name,
lowp_type="i4",
m=none_to_q(mr.m),
n=none_to_q(mr.n),
k=none_to_q(mr.k),
n_div=mr.n // 2,
group0=mr.n // self.group_size,
group1=self.group_size,
element_type=mr.element_type,
)
# TODO (ian): make generalizable and not specific for brevitas
if "lm_head.weight" not in mr.param_name:
inline_module_asm = GROUP_MATMUL_TEMPLATE.format(
# TODO (ian): Fix skipping the "_params." portion of the name to match safetensor format with RenameParametersPass
param_name=mr.param_name[8:],
lowp_type="i4",
m=none_to_q(mr.m),
n=none_to_q(mr.n),
k=none_to_q(mr.k),
n_div=mr.n // 2,
group0=mr.n // self.group_size,
group1=self.group_size,
element_type=mr.element_type,
)

inline_module = Operation.parse(inline_module_asm, context=self.context)
actual_callee_name = self.merge_module(inline_module).translate_symbol(
"compute_mm_group_quant"
)
with InsertionPoint(mr.op), mr.op.location:
results = self.builder.call_native(
actual_callee_name, [mr.op.result.type], mr.op.operands[0]
inline_module = Operation.parse(
inline_module_asm, context=self.context
)
self.replace_op(mr.op, *results)
actual_callee_name = self.merge_module(
inline_module
).translate_symbol("compute_mm_group_quant")
with InsertionPoint(mr.op), mr.op.location:
results = self.builder.call_native(
actual_callee_name, [mr.op.result.type], mr.op.operands[0]
)
self.replace_op(mr.op, *results)


if __name__ == "__main__":
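As a usage note, the rewriter above can also be driven directly over a saved torch-level MLIR file rather than through the exporter. The sketch below is an assumption for illustration: the input file name and the `iree.compiler.ir` import location are placeholders; only the `MMGroupQuantRewriterPass(root_op, group_size=...)` / `.run()` API comes from this diff.

```python
# Hypothetical standalone driver for MMGroupQuantRewriterPass; the file name
# and import location are assumptions, only the pass API comes from this PR.
from iree.compiler.ir import Context, Operation

from shark_turbine.transforms.quantization import mm_group_quant

with open("stateless_llama.mlir", "r") as f:  # placeholder input file
    asm = f.read()

with Context() as ctx:
    module_op = Operation.parse(asm, context=ctx)
    # Rewrites transposed matmuls into calls to compute_mm_group_quant, which
    # loads the i4 weight, scale, and zero point from the
    # #stream.parameter.named<"model"::...> globals emitted by the template.
    mm_group_quant.MMGroupQuantRewriterPass(module_op, group_size=128).run()
    print(module_op)
```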
50 changes: 32 additions & 18 deletions python/turbine_models/custom_models/stateless_llama.py
@@ -69,7 +69,12 @@ def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):


def export_transformer_model(
hf_model_name, hf_auth_token, compile_to, external_weights=None, external_weight_file=None, quantization=None,
hf_model_name,
hf_auth_token,
compile_to,
external_weights=None,
external_weight_file=None,
quantization=None,
):
state_schema = pytree.treespec_loads(json_schema)

@@ -83,6 +88,7 @@ def export_transformer_model(
use_fast=False,
use_auth_token=hf_auth_token,
)

# TODO: generate these values instead of magic numbers
HEADS = 32
HIDDEN_DIM = 128
@@ -97,12 +103,14 @@ def export_transformer_model(
if external_weights == "safetensors":
mod_params = dict(mod.named_parameters())
for name in mod_params:
mapper["params."+name]=name
mapper["params." + name] = name
if external_weight_file:
safetensors.torch.save_file(mod_params, external_weight_file)

elif external_weights=="gguf":
tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS)
elif external_weights == "gguf":
tensor_mapper = remap_gguf.TensorNameMap(
remap_gguf.MODEL_ARCH.LLAMA, HEADS
)
mapper = tensor_mapper.mapping

class StateUpdateModule(CompiledModule):
@@ -115,7 +123,9 @@ class StateUpdateModule(CompiledModule):
global_state = export_global(abstractify(global_pkv), mutable=True)
global_seq_step = export_global(AbstractIndex, mutable=True)

def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
def run_initialize(
self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
):
init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ]
token, *state = self.initialize(x, constraints=init_const)
self.global_seq_step = IREE.tensor_dim(
@@ -135,9 +145,12 @@ def run_forward(self, x=AbstractTensor(1, None, dtype=torch.int64)):
self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
)
forw_const = [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + [
x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) for x in state_arg[1:]
x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1))
for x in state_arg[1:]
]
token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
token, *state_update = self.forward(
x, *state_arg, constraints=forw_const
)
for i in range(HEADS * 2):
update = IREE.tensor_reshape(
state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
@@ -171,32 +184,32 @@ def forward(token0: torch.Tensor, *state0_flat):
state0 = pytree.tree_unflatten(state0_flat, state_schema)
result = mod.forward(token0, past_key_values=state0)
state1_flat, _ = pytree.tree_flatten(result.past_key_values)
state1_flat = [torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat]
state1_flat = [
torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat
]
token1 = torch.argmax(result.logits[:, -1, :], dim=1)
token1 = token1[None, :]
return token1, *state1_flat

import_to = "IMPORT" if compile_to == "torch" else "INPUT"
import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = StateUpdateModule(context=Context(), import_to=import_to)

# TODO: Integrate with external parameters to actually be able to run
# TODO: Make more generalizable to be able to quantize with all compile_to options
if quantization == "int4" and compile_to == "torch":
if quantization == "int4" and not compile_to == "linalg":
from shark_turbine.transforms.quantization import mm_group_quant

mm_group_quant.MMGroupQuantRewriterPass(
CompiledModule.get_mlir_module(inst).operation
).run()

module_str = str(CompiledModule.get_mlir_module(inst))

safe_name = hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
if compile_to != "vmfb":
return module_str, tokenizer
else:
flags = [
"--iree-input-type=tm_tensor",
"--iree-input-type=torch",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--mlir-print-debuginfo",
"--mlir-print-op-on-diagnostic=false",
@@ -227,13 +240,9 @@

def run_vmfb_comparison(args):
config = ireert.Config("local-task")
print(args.external_weight_file)

if args.external_weight_file:
from pathlib import Path

index = ireert.ParameterIndex()

index.load(args.external_weight_file)

safe_name = args.hf_model_name.split("/")[-1].strip()
@@ -244,12 +253,17 @@
mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb")
else:
sys.exit("no vmfb_path provided, required for run_vmfb")
vm_modules = [mod, ireert.create_hal_module(config.vm_instance, config.device)]

vm_modules = [
mod,
ireert.create_hal_module(config.vm_instance, config.device),
]
if args.external_weight_file:
param_module = ireert.create_io_parameters_module(
config.vm_instance, index.create_provider(scope="model")
)
vm_modules.insert(0, param_module)

ctx = ireert.SystemContext(
vm_modules=vm_modules,
config=config,
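Putting the two files together, the new int4 path could be exercised roughly as follows. This is a minimal sketch: the model id, auth token, and output file name are placeholders; only the function signature, the `compile_to`/`quantization` gating, and the runtime parameter-module pairing come from the diff above.

```python
# Sketch of driving the new quantization path end to end; identifiers marked
# as placeholders are assumptions for illustration, not values from this PR.
from turbine_models.custom_models import stateless_llama

module_str, tokenizer = stateless_llama.export_transformer_model(
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
    hf_auth_token="hf_xxx",                          # placeholder token
    compile_to="torch",        # anything but "linalg" lets the int4 rewriter run
    external_weights="safetensors",
    quantization="int4",
)

# The rewritten module pulls its weights from #stream.parameter.named<"model"::...>
# globals, so at runtime the compiled module is paired with a brevitas-quantized
# safetensors file (holding "<name>", "<name>_scale", "<name>_zp" tensors) loaded
# through ireert.ParameterIndex / create_io_parameters_module, as
# run_vmfb_comparison does above.
with open("stateless_llama_int4.mlir", "w") as f:  # placeholder output path
    f.write(module_str)
```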