diff --git a/README.md b/README.md index f7522fa1e..0043df22d 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,7 @@ pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # f Nightly Release ```Shell -pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124 +pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124 ``` From source diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py index 174038d20..ebf9e1e73 100644 --- a/benchmarks/benchmark_aq.py +++ b/benchmarks/benchmark_aq.py @@ -17,6 +17,7 @@ _replace_with_custom_fn_if_matches_filter, ) import copy +from torchao.utils import unwrap_tensor_subclass def _int8wo_api(mod, **kwargs): if TORCH_VERSION_AT_LEAST_2_4: @@ -133,15 +134,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): WARMUP = 20 RUNS = 100 + torch._dynamo.reset() m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True) benchmark_model(m_ref, WARMUP, example_inputs) ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs) + torch._dynamo.reset() m = torch.compile(m, mode='max-autotune', fullgraph=True) benchmark_model(m, WARMUP, example_inputs) elapsed_time = benchmark_model(m, RUNS, example_inputs) - + torch._dynamo.reset() m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True) benchmark_model(m_bf16, WARMUP, example_inputs) bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index a5a05f041..91d344ac1 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + """ This is a script to estimate the benefit from converting a `torch.nn.Linear` layer to float8, by estimating the difference in e2e GPU kernel time between: @@ -45,26 +51,10 @@ import torch import torch.utils.benchmark as benchmark -BYTES_PER_EL_FLOAT8 = 1 -BYTES_PER_EL_BF16 = 2 - -# https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity -H100_BF16_PEAK_TOPS = 989e12 -H100_FP8_PEAK_TOPS = 1979e12 - -# 2.4 TB per second, custom to Meta's H100 variant -H100_PEAK_MEM_BW_BYTES_SEC = 2.4e12 - -# based on quick experimental observation with sample large inputs -H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6 - -# based on previous experience looking at pointwise triton kernels with large inputs, -# which would hit about 2.2k GBPS on Meta's H100 variant -H100_PCT_ACHIEVABLE_MEM_BW = 0.92 - -# Source: run a triton kernel with a single element read/write on an H100 and -# measure GPU time from the trace -TRITON_KERNEL_1_ELEMENT_TIME_SEC = 0.002 * 0.001 +from torchao.float8.roofline_utils import ( + get_gemm_time_sympy, + get_float8_mem_sympy, +) def benchmark_fn_in_sec(f, *args, **kwargs): @@ -78,90 +68,6 @@ def benchmark_fn_in_sec(f, *args, **kwargs): return measurement.mean -def get_tensor_memory_traffic_bytes( - dim0, - dim1, - scaling_type: str, - fuse_with_prev=False, - model_torch_compile_limitations=False, -): - # assumes input bf16, output f8 - numel = dim0 * dim1 - - if scaling_type == "dynamic": - # x_bf16 = ... 
- # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp - # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs - # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8 - - if fuse_with_prev: - kernel_1_rw = 0 - else: - # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) - kernel_1_rw = BYTES_PER_EL_BF16 * numel - - # kernel 3: read in bf16, write twice in float8 (row-major and col-major) - kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel - - if model_torch_compile_limitations: - # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) - # has an extra memory read of the input in fp8 - # context: https://github.com/pytorch/pytorch/issues/130015 - tc_adjustment = numel * BYTES_PER_EL_FLOAT8 - else: - tc_adjustment = 0 - - return kernel_1_rw + kernel_3_rw + tc_adjustment - - else: - assert scaling_type == "delayed", "unsupported" - # x_bf16 = ... - # kernel 1: x_bf16 -> max_abs_stage_1_and_to_float8 -> x_float8, tmp - # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs - # kernel 3 (not modeled): scale -> reciprocal -> inv_scale - - if fuse_with_prev: - kernel_1_r = 0 - else: - kernel_1_r = numel * BYTES_PER_EL_BF16 - # write twice: once in row major, once in col-major - kernel_1_w = numel * BYTES_PER_EL_FLOAT8 * 2 - - if model_torch_compile_limitations: - # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) - # has an extra memory read of the input in fp8 - # context: https://github.com/pytorch/pytorch/issues/130015 - tc_adjustment = numel * BYTES_PER_EL_FLOAT8 - - # https://github.com/pytorch/pytorch/issues/128063 - # instead of - # kernel 1: x_bf16 -> max(abs(x)), x_fp8 - # kernel 2: not modeled - # kernel 3: not modeled - # we get - # kernel 1: x_bf16 -> max(abs(x)) - # reads: same as before - # writes: 0 - # ... - # kernel 4: x_bf16, scale -> x_fp8 - # reads: numel * BYTES_PER_EL_BF16 - # writes: 2 * numel * BYTES_PER_EL_FLOAT8 - # Note that assuming worst case, this issue brings the memory - # traffic for delayed scaling to be equal to that of dynamic scaling. 
- tc_adjustment += ( - # subtract writes from kernel 1 - -1 * 2 * numel * BYTES_PER_EL_FLOAT8 - # add reads for kernel 4 - + numel * BYTES_PER_EL_BF16 - # add writes for kernel 4 - + 2 * numel * BYTES_PER_EL_FLOAT8 - ) - else: - tc_adjustment = 0 - - return kernel_1_r + kernel_1_w + tc_adjustment - - def get_gemm_times_cache(gemm_benchmarks_file: str): cache = {} with open(gemm_benchmarks_file, 'r') as f: @@ -176,114 +82,6 @@ def get_gemm_times_cache(gemm_benchmarks_file: str): return cache -def get_gemm_time_sympy(M, K, N, dtype): - gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N - if dtype is torch.bfloat16: - peak_tops = H100_BF16_PEAK_TOPS - elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - peak_tops = H100_FP8_PEAK_TOPS - gemm_time_s = gemm_ops / peak_tops / H100_PCT_ACHIEVABLE_GEMM_TOPS - return gemm_time_s - - -def get_float8_mem_sympy( - M, - K, - N, - model_torch_compile_limitations: bool = False, - scaling_type_input: str = "dynamic", - scaling_type_weight: str = "dynamic", - scaling_type_grad_output: str = "dynamic", -): - - assert scaling_type_input in ("dynamic", "delayed"), "unsupported" - assert scaling_type_weight in ("dynamic", "delayed"), "unsupported" - assert scaling_type_grad_output in ("dynamic", "delayed"), "unsupported" - - # there are three gemms in the fwd/bwd of a linear: - # - # input @ weight_t = output - # MxK @ KxN => MxN - # - # grad_output @ weight = grad_input - # MxN @ NxK => MxK - # - # input_t @ grad_output = grad_weight - # KxM @ MxN => KxN - - # - # forward - output - # - fwd_fp8_input_mem = get_tensor_memory_traffic_bytes( - M, K, scaling_type_input, fuse_with_prev=True, - model_torch_compile_limitations=model_torch_compile_limitations) - fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes( - K, N, scaling_type_weight, fuse_with_prev=False, - model_torch_compile_limitations=model_torch_compile_limitations) - fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem - - # - # backward - grad_input - # - gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes( - M, N, scaling_type_grad_output, fuse_with_prev=True, - model_torch_compile_limitations=model_torch_compile_limitations) - # already casted, assuming that we save weight from fw to bw - # TODO: model this if FSDP float8 all-gather is on - # TODO: model this if we don't save weight from fw to bw, and recompute instead - gi_fp8_weight_mem = 0 - - # - # backward - grad_weight - # - # TODO: model this if we don't save fp8 input from fw to bw - gw_fp8_input_t_mem = 0 # already casted - # this should be always 0 - gw_fp8_grad_output_mem = 0 # already casted - - bwd_fp8_total_mem = \ - gi_fp8_grad_output_mem + gi_fp8_weight_mem + \ - gw_fp8_input_t_mem + gw_fp8_grad_output_mem - fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem - fp8_mem_time_s = ( - fp8_total_mem / H100_PEAK_MEM_BW_BYTES_SEC / H100_PCT_ACHIEVABLE_MEM_BW - ) - - # Adjust final estimate for small kernel launches - # note that we do this adjustment here because we are assuming a minimal - # kernel overhead in the units of seconds, and the per-gemm-input memory - # estimations are in the units of bytes. 
- num_extra_kernels = 0 - if scaling_type_input == "dynamic": - # second stage of max-abs reduction - num_extra_kernels += 1 - elif scaling_type_input == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 - if scaling_type_weight == "dynamic": - # second stage of max-abs reduction - num_extra_kernels += 1 - elif scaling_type_weight == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 - if scaling_type_grad_output == "dynamic": - # second stage of max-abs reduction - num_extra_kernels += 1 - elif scaling_type_grad_output == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 - - extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC - - return fp8_mem_time_s + extra_kernel_overhead_s - - def run( outfile: str, gemm_time_strategy: str = "benchmarks", diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 344a3a71a..de3ed04e8 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -1,5 +1,5 @@ # pre-train a mini Llama2 on TinyStories with INT8 quantized training -# pip install transformers sentencepiece wandb +# pip install huggingface_hub sentencepiece wandb # # BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile # INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only @@ -9,21 +9,33 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import argparse +from functools import partial from pathlib import Path import numpy as np import torch import wandb +from torch.utils.checkpoint import checkpoint from tqdm import tqdm -from transformers import LlamaConfig, LlamaForCausalLM +from torchao._models.llama.model import ModelArgs, Transformer from torchao.prototype import low_bit_optim from torchao.prototype.quantized_training import int8_weight_only_quantized_training from torchao.quantization.quant_api import quantize_ -def get_loss(model: LlamaForCausalLM, batch: torch.Tensor): - return model(batch, labels=batch).loss +# hack from fairseq +# https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py +def enable_activation_checkpointing(m: torch.nn.Module): + assert not hasattr(m, "_forward") + m._forward = m.forward + m.forward = partial(checkpoint, m.forward) + + +def get_loss(model: Transformer, batch: torch.Tensor): + logits = model(batch)[:, :-1].flatten(0, 1) + labels = batch[:, 1:].flatten() + return torch.nn.functional.cross_entropy(logits, labels) def get_tinystories(): @@ -91,17 +103,19 @@ def get_tinystories(): if args.seed is not None: torch.manual_seed(args.seed) - config = LlamaConfig( - hidden_size=args.d_model, + config = ModelArgs( + block_size=args.seq_len, + n_layer=args.depth, + n_head=args.d_model // args.head_dim, + dim=args.d_model, intermediate_size=args.ffn_size, - num_hidden_layers=args.depth, - num_attention_heads=args.d_model // args.head_dim, - max_position_embeddings=args.seq_len, - use_cache=False, ) - model = LlamaForCausalLM(config).bfloat16().cuda() + model = Transformer(config).bfloat16().cuda() + with torch.device("cuda"): + model.setup_caches(args.batch_size, args.seq_len, 
training=True) if args.activation_checkpointing: - model.gradient_checkpointing_enable() + for layer in model.layers: + enable_activation_checkpointing(layer) if args.quantize == "int8_weight_only": quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) elif args.quantize is not None: diff --git a/scripts/download.py b/scripts/download.py index 571e03adb..3fc89e712 100644 --- a/scripts/download.py +++ b/scripts/download.py @@ -15,7 +15,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) - from huggingface_hub import snapshot_download os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) try: - snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token) + snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors") except HTTPError as e: if e.response.status_code == 401: print("You need to pass a valid `--hf_token=...` to download private checkpoints.") diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 5260c7d55..d1e11ab82 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -51,6 +51,23 @@ def test_weights_only(self): else: _ = torch.load(f, weights_only=False) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_to_device(self): + from torchao.quantization import quantize_ + for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]: + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(l) + ql.to("cuda") + + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(l) + ql.to(device="cuda") + + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(l) + ql.cuda() + + if __name__ == "__main__": run_tests() diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 387e11e8b..15b85942b 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -36,6 +36,21 @@ def __init__(self, scale, device): def forward(self, x): return self.net(x) +@pytest.mark.parametrize("bit_width", bit_widths) +@pytest.mark.parametrize("group_size", group_sizes) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build") +def test_uintx_quant_on_cpu_then_move_to_cuda(bit_width, group_size): + scale = 512 + fp16_mod_on_cpu = Linear16(scale, "cpu") + quantize_(fp16_mod_on_cpu, uintx_weight_only(bit_width, group_size=group_size)) + test_input_on_cpu = torch.randn(scale*2, dtype=torch.float16, device="cpu") + output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu) + fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda") + test_input_on_cuda = test_input_on_cpu.to("cuda") + output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda) + assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), "The output of the model on CPU and CUDA should be close" + @pytest.mark.parametrize("bit_width", bit_widths) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index 30aa73548..cf5bd8fa0 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -3,7 +3,7 @@ import pytest import threading import unittest -from 
typing import Any, List +from typing import Any, List, Optional from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -59,7 +59,7 @@ def init_multi_module(self) -> nn.Module: self.broadcast_module(module) return module - def init_transformer(self, weight_tying: bool) -> nn.Module: + def init_transformer(self, weight_tying: bool, dtype: Optional[torch.dtype] = None) -> nn.Module: torch.manual_seed(42) args = ModelArgs( n_layers=3, @@ -70,6 +70,8 @@ def init_transformer(self, weight_tying: bool) -> nn.Module: vocab_size=32, ) module = Transformer(args).cuda() + if dtype is not None: + module = module.to(dtype=dtype) self.broadcast_module(module) return module @@ -96,6 +98,7 @@ def test_transformer_parity(self): ScalingType.DELAYED, ], "compile_transformer_block": [False, True], + "dtype": [torch.float32, torch.bfloat16], }, self._test_transformer_parity, ) @@ -106,6 +109,7 @@ def _test_transformer_parity( precompute: bool, scaling_type_weight: ScalingType, compile_transformer_block: bool, + dtype: Optional[torch.dtype] = None, ): if not enable_fsdp_float8_all_gather and precompute: return @@ -117,7 +121,7 @@ def _test_transformer_parity( # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to # fp8 for that tied weight, incorrectly using fp8 for the embedding. weight_tying = not enable_fsdp_float8_all_gather - module = self.init_transformer(weight_tying=weight_tying).cuda() + module = self.init_transformer(weight_tying=weight_tying, dtype=dtype) ref_module = copy.deepcopy(module) float8_linear_config1 = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index 58a170964..eaaccd7a5 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -150,7 +150,7 @@ def __init__(self, config: ModelArgs) -> None: self.max_batch_size = -1 self.max_seq_length = -1 - def setup_caches(self, max_batch_size, max_seq_length): + def setup_caches(self, max_batch_size, max_seq_length, training: bool = False): if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: return head_dim = self.config.dim // self.config.n_head @@ -163,16 +163,21 @@ def setup_caches(self, max_batch_size, max_seq_length): dtype = self.output.scales.dtype elif hasattr(self.output, "scales_and_zeros"): dtype = self.output.scales_and_zeros.dtype - for b in self.layers: - b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) + if not training: + for b in self.layers: + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype) self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: assert self.freqs_cis is not None, "Caches must be initialized first" - mask = self.causal_mask[None, None, input_pos] - freqs_cis = self.freqs_cis[input_pos] + if input_pos is not None: + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + else: + mask = None + freqs_cis = self.freqs_cis[:idx.shape[1]] x = self.tok_embeddings(idx) for i, layer in enumerate(self.layers): @@ -194,7 +199,7 @@ def __init__(self, config: ModelArgs) -> None: self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 
self.attention_norm = RMSNorm(config.dim, config.norm_eps) - def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: + def forward(self, x: Tensor, input_pos: Optional[Tensor], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor: h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -224,7 +229,7 @@ def load_hook(self, state_dict, prefix, *args): wv = state_dict.pop(prefix + "wv.weight") state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) - def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + def forward(self, x: Tensor, freqs_cis: Tensor, mask: Optional[Tensor], input_pos: Optional[Tensor] = None) -> Tensor: bsz, seqlen, _ = x.shape kv_size = self.n_local_heads * self.head_dim @@ -244,7 +249,10 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + if mask is not None: + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + else: + y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True) y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index ef96f11e7..2f0b11319 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -16,7 +16,6 @@ pack_tinygemm_scales_and_zeros, ) from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import find_multiple from torchao.dtypes.utils import ( _implements, _dispatch__torch_function__, @@ -29,14 +28,18 @@ ) from torch.utils._python_dispatch import is_traceable_wrapper_subclass from dataclasses import dataclass -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + find_multiple, + TorchAOBaseTensor, + TORCH_VERSION_AT_LEAST_2_5, +) aten = torch.ops.aten ############################### # Base Layout Tensor Subclass # ############################### -class AQTLayout(torch.Tensor): +class AQTLayout(TorchAOBaseTensor): """ Base class for the layout tensor for `AffineQuantizedTensor` """ @@ -61,19 +64,6 @@ def __repr__(self): layout_type = self.get_layout_type() return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})" - def _get_to_kwargs(self, *args, **kwargs): - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - memory_format = ( - memory_format if memory_format is not None else torch.preserve_format - ) - kwargs = { - "device": device, - "dtype": dtype, - "memory_format": memory_format, - } - return kwargs ############################## # Tensor Subclass Definition # @@ -83,7 +73,7 @@ def _get_to_kwargs(self, *args, **kwargs): def _register_quantized_linear_dispatch(dispatch_condition, impl): _QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl -class AffineQuantizedTensor(torch.Tensor): +class AffineQuantizedTensor(TorchAOBaseTensor): """ Affine quantized tensor subclass. 
Affine quantization means we quantize the floating point tensor with an affine transformation: quantized_tensor = float_tensor / scale + zero_point @@ -223,7 +213,7 @@ def from_float( input_float = layout_type.pre_process(input_float) scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) - + int_data = layout_type.post_process(int_data) layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type) @@ -273,25 +263,9 @@ def from_float_static( def layout_type(self) -> LayoutType: return self.layout_tensor.layout_type - def _get_to_kwargs(self, *args, **kwargs): - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - memory_format = ( - memory_format if memory_format is not None else torch.preserve_format - ) - kwargs = { - "device": device, - "dtype": dtype, - "memory_format": memory_format, - } - return kwargs - def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs.pop("device") - # not supported yet - kwargs.pop("memory_format") return self.__class__( self.layout_tensor.to(device), self.block_size, @@ -446,6 +420,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + if func is aten.t.default: tensor = args[0] new = tensor.__class__( @@ -576,10 +555,10 @@ def from_plain( scale: torch.Tensor, zero_point: torch.Tensor, layout_type: LayoutType - ): - + ): + assert isinstance(layout_type, TensorCoreTiledLayoutType) - + if TORCH_VERSION_AT_LEAST_2_5: int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" @@ -617,6 +596,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + if func is aten.t.default: """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 90516ea19..b386f85ae 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -339,7 +339,7 @@ def copy_(func, *args, **kwargs): # Convert Non NF4Tensor into NF4 for copy in if not isinstance(copy_in, NF4Tensor): copy_in_nf4 = NF4Tensor.from_tensor( - copy_in, original.block_size, original.scaler_block_size + copy_in.to(original.device), original.block_size, original.scaler_block_size ) return original.copy_(copy_in_nf4) diff --git a/torchao/dtypes/uintx/Uintx.py b/torchao/dtypes/uintx/Uintx.py index 9fdaab0f4..12187f8d8 100644 --- a/torchao/dtypes/uintx/Uintx.py +++ b/torchao/dtypes/uintx/Uintx.py @@ -105,6 +105,35 @@ def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1): return cls(shards, int_data.shape, bit_width, 
pack_dim) + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + } + return kwargs + + def to(self, *args, **kwargs): + if "copy" in kwargs: + return super().to(*args, **kwargs) + kwargs = self._get_to_kwargs(*args, **kwargs) + if "device" in kwargs: + return self.__class__( + list(shard.to(kwargs["device"]) for shard in self.get_shards()), + self.packed_shape, + self.bit_width, + self.pack_dim, + ) + return super().to(*args, **kwargs) + + + implements = UintxTensor.implements diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt index 1c1a779db..5227ff109 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt @@ -1,4 +1,8 @@ -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. cmake_minimum_required(VERSION 3.19) project(benchmarks) diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp index d03a3bfca..926d47523 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp index 631bab42d..8e3ec0516 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp index 942855c01..868f01648 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
#include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh index 1c38bc39e..08f835836 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh @@ -1,4 +1,10 @@ #!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + # Call script with sh build_and_run_benchmarks.sh {BENCHAMRK} SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h index 8f6fe4a5b..fce5abba4 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h index 6bd06e0df..4861edbee 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h @@ -1,3 +1,9 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + #pragma once #define TORCHAO_ALWAYS_INLINE __attribute__((always_inline)) diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h index b74714809..b76b146ba 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h index c30949d72..d4d3f391f 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
#pragma once #include diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h index 626bff348..19d4fe5bd 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h index efcce4bb2..2fcd8d131 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include @@ -218,7 +222,21 @@ void kernel_impl( if constexpr (has_clamp) { res = clamp(res, clamp_min, clamp_max); } - vst1q_f32(output + m_idx * output_m_stride + n_idx, res); + + // Store result + int remaining = n - n_idx; + float* store_loc = output + m_idx * output_m_stride + n_idx; + if (remaining >= 4) { + vst1q_f32(store_loc, res); + } else if (remaining >= 3) { + vst1_f32(store_loc, vget_low_f32(res)); + *(store_loc + 2) = res[2]; + } else if (remaining >= 2) { + vst1_f32(store_loc, vget_low_f32(res)); + } else { + *(store_loc) = res[0]; + } + } // n_idx activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr); } // m_idx diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h index 37f254c98..4974e909d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
#pragma once #include @@ -290,8 +294,34 @@ void kernel_impl( res_0123 = vec_clamp(res_0123, vec_min, vec_max); res_4567 = vec_clamp(res_4567, vec_min, vec_max); } - vst1q_f32(output + m_idx * output_m_stride + n_idx, res_0123); - vst1q_f32(output + m_idx * output_m_stride + n_idx + 4, res_4567); + + // Store result + int remaining = n - n_idx; + float* store_loc = output + m_idx * output_m_stride + n_idx; + if (remaining >= 8) { + vst1q_f32(store_loc, res_0123); + vst1q_f32(store_loc + 4, res_4567); + } else if (remaining >= 7) { + vst1q_f32(store_loc, res_0123); + vst1_f32(store_loc + 4, vget_low_f32(res_4567)); + *(store_loc + 6) = res_4567[2]; + } else if (remaining >= 6) { + vst1q_f32(store_loc, res_0123); + vst1_f32(store_loc + 4, vget_low_f32(res_4567)); + } else if (remaining >= 5) { + vst1q_f32(store_loc, res_0123); + *(store_loc + 4) = res_4567[0]; + } else if (remaining >= 4) { + vst1q_f32(store_loc, res_0123); + } else if (remaining >= 3) { + vst1_f32(store_loc, vget_low_f32(res_0123)); + *(store_loc + 2) = res_0123[2]; + } else if (remaining >= 2) { + vst1_f32(store_loc, vget_low_f32(res_0123)); + } else { + *store_loc = res_0123[0]; + } + } // n_idx activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr); } // m_idx diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h index 7c2ba7d07..a67e2b0d1 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h index 2607a2371..cf3af21b5 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp index 3aed6d019..523fd9360 100644 --- a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
#include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h index af4983659..a8214cc44 100644 --- a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h +++ b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp b/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp index ab1f26180..3aa7f4a5d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp @@ -1,15 +1,22 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include +#include int32_t torchao::kernels::cpu::aarch64::reduction::compute_sum( const int8_t* vals, int size) { + assert(size >= 1); + int32_t res = 0; int i = 0; #pragma unroll(4) - for (; i < size; i += 16) { + for (; i + 15 < size; i += 16) { int8x16_t vec_vals = vld1q_s8(vals + i); res += (int)(vaddlvq_s8(vec_vals)); } diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp b/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp index ed7ca01bb..1516f3cef 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp @@ -1,23 +1,37 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
#include +#include void torchao::kernels::cpu::aarch64::reduction::find_min_and_max( float32_t& min, float32_t& max, const float32_t* vals, int size) { - float32x4_t mins = vdupq_n_f32(0.0); - float32x4_t maxes = vdupq_n_f32(0.0); + assert(size > 0); + + // Needed in case size < 4 so we don't compare to + // uninitialized min/max values + min = vals[0]; + max = min; + int i = 0; - for (; i < size; i += 8) { - float32x4_t v1 = vld1q_f32(vals + i); - float32x4_t v2 = vld1q_f32(vals + i + 4); - mins = vminq_f32(v1, v2); - maxes = vmaxq_f32(v1, v2); + if (i + 3 < size) { + float32x4_t mins = vld1q_f32(vals + i); + float32x4_t maxes = mins; + i += 4; + for (; i + 3 < size; i += 4) { + float32x4_t v = vld1q_f32(vals + i); + mins = vminq_f32(mins, v); + maxes = vmaxq_f32(maxes, v); + } + min = vminvq_f32(mins); + max = vmaxvq_f32(maxes); } - min = vminvq_f32(mins); - max = vmaxvq_f32(maxes); // Remainder while (i < size) { diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h b/torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h index 25110f4f3..f027c8530 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h +++ b/torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index 1b78f25b9..8e281ed79 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -1,4 +1,8 @@ -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. cmake_minimum_required(VERSION 3.19) project(tests) @@ -35,6 +39,14 @@ target_link_libraries( dep ) +add_executable(test_reduction test_reduction.cpp) +target_link_libraries( + test_reduction + PRIVATE + GTest::gtest_main + dep +) + add_executable(test_bitpacking test_bitpacking.cpp) target_link_libraries( test_bitpacking @@ -61,6 +73,7 @@ target_link_libraries( include(GoogleTest) gtest_discover_tests(test_quantization) +gtest_discover_tests(test_reduction) gtest_discover_tests(test_bitpacking) gtest_discover_tests(test_linear) gtest_discover_tests(test_valpacking) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 308455206..ce8861ac6 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -1,4 +1,10 @@ #!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. 
export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests @@ -7,7 +13,8 @@ cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/e cmake --build ${CMAKE_OUT} # Run - ${CMAKE_OUT}/test_quantization - ${CMAKE_OUT}/test_bitpacking - ${CMAKE_OUT}/test_linear - ${CMAKE_OUT}/test_valpacking +${CMAKE_OUT}/test_quantization +${CMAKE_OUT}/test_reduction +${CMAKE_OUT}/test_bitpacking +${CMAKE_OUT}/test_linear +${CMAKE_OUT}/test_valpacking diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp index 9e530da8e..28a46f8e0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 4b61c162e..22a2ed0f8 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include #include @@ -10,12 +14,11 @@ float kTol = 0.0001; template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot() { - int m = 7; - int k = 128; - int n = 13; - int group_size = 32; - +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot( + int m, + int k, + int n, + int group_size) { auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -50,7 +53,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot test_case.weight_scales.data(), /*weight_zeros=*/test_case.weight_zeros.data()); - std::vector output(m * k); + std::vector output(m * n); kernel( output.data(), /*output_m_stride=*/n, @@ -72,70 +75,53 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, Standard) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = false; - constexpr bool has_clamp = false; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasWeightZeros) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = true; - constexpr bool has_bias = false; - constexpr bool has_clamp = false; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + true /*has_weight_zeros*/, + false /*has_bias*/, + false 
/*has_clamp*/>( + /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasBias) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = true; - constexpr bool has_clamp = false; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/>( + /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasClamp) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = false; - constexpr bool has_clamp = true; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); } template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot() { - int m = 7; - int k = 64; - int n = 13; - int group_size = 16; - +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot( + int m, + int k, + int n, + int group_size) { auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -170,7 +156,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot test_case.weight_scales.data(), /*weight_zeros=*/test_case.weight_zeros.data()); - std::vector output(m * k); + std::vector output(m * n); kernel( output.data(), /*output_m_stride=*/n, @@ -192,70 +178,66 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, Standard) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = false; - constexpr bool has_clamp = false; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasWeightZeros) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = true; - constexpr bool has_bias = false; - constexpr bool has_clamp = false; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + true /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasBias) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = true; - constexpr bool has_clamp = false; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/>( + /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } TEST( 
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasClamp) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = false; - constexpr bool has_clamp = true; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } -template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot() { - int m = 7; - int k = 64; - int n = 13; - int group_size = 16; +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, + NLessThan4) { + for (int n = 1; n < 4; n++) { + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); + } +} +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( + int m, + int k, + int n, + int group_size) { auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -290,7 +272,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot test_case.weight_scales.data(), /*weight_zeros=*/test_case.weight_zeros.data()); - std::vector output(m * k); + std::vector output(m * n); kernel( output.data(), /*output_m_stride=*/n, @@ -312,59 +294,56 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, Standard) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = false; - constexpr bool has_clamp = false; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasWeightZeros) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = true; - constexpr bool has_bias = false; - constexpr bool has_clamp = false; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + true /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasBias) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = true; - constexpr bool has_clamp = false; - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/>( + /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasClamp) { - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = false; - constexpr bool has_clamp = true; - 
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp>(); + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, + NLessThan8) { + for (int n = 1; n < 8; n++) { + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); + } } diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp index 6fac44244..74fc5ef52 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_reduction.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_reduction.cpp new file mode 100644 index 000000000..16eb87fbb --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_reduction.cpp @@ -0,0 +1,60 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include +#include +#include +#include + +TEST(test_find_min_and_sum, SizeHasRemainderAfterDivideBy4) { + auto vals = torchao::get_random_vector(19, -1.0, 1.0); + float vmin, vmax; + torchao::kernels::cpu::aarch64::reduction::find_min_and_max( + vmin, vmax, vals.data(), vals.size()); + + auto expected_vmin = *std::min_element(vals.begin(), vals.end()); + auto expected_vmax = *std::max_element(vals.begin(), vals.end()); + EXPECT_EQ(vmin, expected_vmin); + EXPECT_EQ(vmax, expected_vmax); +} + +TEST(test_find_min_and_sum, SizeSmallerThan4) { + auto vals = torchao::get_random_vector(3, -1.0, 1.0); + float vmin, vmax; + torchao::kernels::cpu::aarch64::reduction::find_min_and_max( + vmin, vmax, vals.data(), vals.size()); + + auto expected_vmin = *std::min_element(vals.begin(), vals.end()); + auto expected_vmax = *std::max_element(vals.begin(), vals.end()); + EXPECT_EQ(vmin, expected_vmin); + EXPECT_EQ(vmax, expected_vmax); +} + +TEST(test_compute_sum, ExpectedOutput) { + auto vals = torchao::get_random_lowbit_vector(/*size=*/19, /*int8*/ 3); + int sum = torchao::kernels::cpu::aarch64::reduction::compute_sum( + (int8_t*)vals.data(), vals.size()); + int expected_sum = std::accumulate(vals.begin(), vals.end(), 0); + EXPECT_EQ(sum, expected_sum); +} + +TEST(test_compute_sum, SizeHasRemainderAfterDivideBy16) { + auto vals = torchao::get_random_lowbit_vector(/*size=*/17, /*int8*/ 3); + int sum = torchao::kernels::cpu::aarch64::reduction::compute_sum( + (int8_t*)vals.data(), vals.size()); + int expected_sum = std::accumulate(vals.begin(), vals.end(), 0); + EXPECT_EQ(sum, expected_sum); +} + +TEST(test_compute_sum, SizeSmallerThan16) { + auto vals = torchao::get_random_lowbit_vector(/*size=*/3, /*int8*/ 3); + int sum = torchao::kernels::cpu::aarch64::reduction::compute_sum( + (int8_t*)vals.data(), vals.size()); + int expected_sum = std::accumulate(vals.begin(), vals.end(), 0); + EXPECT_EQ(sum, expected_sum); +} diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index 278209587..4e5083d9e 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_valpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_valpacking.cpp index 5497a62f7..02be12a67 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_valpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_valpacking.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
#include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp b/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp index ace1d1697..8cbf03695 100644 --- a/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h b/torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h index ecfb16ac8..383f71780 100644 --- a/torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h +++ b/torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt index 72aa539d2..61e5eeae2 100644 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt @@ -1,4 +1,8 @@ -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. cmake_minimum_required(VERSION 3.19) project(benchmarks) diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp b/torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp index 48df081f5..ad6563eab 100644 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp +++ b/torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include #include diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh index a5451777e..18da0e992 100644 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh +++ b/torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh @@ -1,5 +1,9 @@ #!/bin/bash -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
# Call script with sh build_and_run_benchmarks.sh {BENCHAMRK} diff --git a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h b/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h index 177ea5772..6196e69fa 100644 --- a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h +++ b/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include diff --git a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h index c006b45ce..73ca5e073 100644 --- a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once diff --git a/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt index a86005b73..73314651c 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt @@ -1,4 +1,8 @@ -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. project(examples) diff --git a/torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h b/torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h index 1bb9500bc..06e3bfc43 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h +++ b/torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once #include diff --git a/torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh b/torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh index 2d1083058..9c244e54c 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh +++ b/torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh @@ -1,11 +1,18 @@ #!/bin/bash -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. +export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" +echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ + -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples \ -B ${CMAKE_OUT} \ -DOpenMP_ROOT=$(brew --prefix libomp) diff --git a/torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp b/torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp index 00783d4b3..ba3e5b29b 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp +++ b/torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include #include diff --git a/torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp b/torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp index 5f106fcee..5fb24c683 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp +++ b/torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include #include diff --git a/torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp b/torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp index c6f0da0ac..5408e426b 100644 --- a/torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp +++ b/torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp @@ -30,7 +30,8 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight( has_weight_zeros, has_bias, has_clamp); - float output[m * n]; + + auto output = std::vector(m * n); for (auto linear_scheduling_policy : {LinearTileSchedulingPolicy::single_mc_parallel_nc, @@ -82,7 +83,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight( linear_tiling_params, linear_scheduling_policy, activation_data_buffer.get(), - output, + output.data(), m, n, k, diff --git a/torchao/experimental/kernels/cpu/macro.h b/torchao/experimental/kernels/cpu/macro.h index 441c8b8ce..62c73f1f3 100644 --- a/torchao/experimental/kernels/cpu/macro.h +++ b/torchao/experimental/kernels/cpu/macro.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
#pragma once diff --git a/torchao/experimental/kernels/cpu/memory.h b/torchao/experimental/kernels/cpu/memory.h index 55f0e6bf0..cf3220f0e 100644 --- a/torchao/experimental/kernels/cpu/memory.h +++ b/torchao/experimental/kernels/cpu/memory.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once @@ -9,12 +13,9 @@ namespace torchao { -using aligned_byte_ptr = - std::unique_ptr; +using aligned_byte_ptr = std::unique_ptr; -aligned_byte_ptr make_aligned_byte_ptr( - size_t alignment, - size_t size) { +aligned_byte_ptr make_aligned_byte_ptr(size_t alignment, size_t size) { // Adjust size to next multiple of alignment >= size size_t adjusted_size = ((size + alignment - 1) / alignment) * alignment; diff --git a/torchao/experimental/kernels/cpu/parallel.h b/torchao/experimental/kernels/cpu/parallel.h index b61223c76..b1fe4dea6 100644 --- a/torchao/experimental/kernels/cpu/parallel.h +++ b/torchao/experimental/kernels/cpu/parallel.h @@ -1,4 +1,8 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #pragma once diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 868d4f52a..54613e5b0 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -209,7 +209,7 @@ def pad_tensor_for_matmul( Args: tensor: The tensor to pad. - both: Whether to pad both dimensions or just the second dimension. + dims: Dimensions to pad. Returns: torch.Tensor: The padded tensor. diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 3f10784f7..81859de4b 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -59,7 +59,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: return # inf-norm is equivalent to max(abs(w)) - max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial + max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32) # Partial amax_tensor = torch.stack(max_weights) # Partial # clamp is dispatched through DTensor # it will issue a single all-reduce diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index f441009c4..ccf83d7ce 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -151,7 +151,7 @@ def from_float( Create an nn.Linear with fp8 compute from another nn.Linear Args: - mod (torch.nn.Linear): nn.Linear to convert + module (torch.nn.Linear): nn.Linear to convert quant_config (QuantConfig): Configuration for the weight and activation casting """ forward_config = ScaledMMConfig( diff --git a/torchao/float8/roofline_utils.py b/torchao/float8/roofline_utils.py new file mode 100644 index 000000000..490435fbf --- /dev/null +++ b/torchao/float8/roofline_utils.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch + +BYTES_PER_EL_FLOAT8 = 1 +BYTES_PER_EL_BF16 = 2 + +# https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity +H100_BF16_PEAK_TOPS = 989e12 +H100_FP8_PEAK_TOPS = 1979e12 + +# 2.4 TB per second, custom to Meta's H100 variant +H100_PEAK_MEM_BW_BYTES_SEC = 2.4e12 + +# based on quick experimental observation with sample large inputs +H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6 + +# based on previous experience looking at pointwise triton kernels with large inputs, +# which would hit about 2.2k GBPS on Meta's H100 variant +H100_PCT_ACHIEVABLE_MEM_BW = 0.92 + +# Source: run a triton kernel with a single element read/write on an H100 and +# measure GPU time from the trace +TRITON_KERNEL_1_ELEMENT_TIME_SEC = 0.002 * 0.001 + + +def get_tensor_memory_traffic_bytes( + dim0, + dim1, + scaling_type: str, + fuse_with_prev=False, + model_torch_compile_limitations=False, +): + # assumes input bf16, output f8 + numel = dim0 * dim1 + + if scaling_type == "dynamic": + # x_bf16 = ... + # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp + # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs + # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8 + + if fuse_with_prev: + kernel_1_rw = 0 + else: + # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) + kernel_1_rw = BYTES_PER_EL_BF16 * numel + + # kernel 3: read in bf16, write twice in float8 (row-major and col-major) + kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel + + if model_torch_compile_limitations: + # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) + # has an extra memory read of the input in fp8 + # context: https://github.com/pytorch/pytorch/issues/130015 + tc_adjustment = numel * BYTES_PER_EL_FLOAT8 + else: + tc_adjustment = 0 + + return kernel_1_rw + kernel_3_rw + tc_adjustment + + else: + assert scaling_type == "delayed", "unsupported" + # x_bf16 = ... + # kernel 1: x_bf16 -> max_abs_stage_1_and_to_float8 -> x_float8, tmp + # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs + # kernel 3 (not modeled): scale -> reciprocal -> inv_scale + + if fuse_with_prev: + kernel_1_r = 0 + else: + kernel_1_r = numel * BYTES_PER_EL_BF16 + # write twice: once in row major, once in col-major + kernel_1_w = numel * BYTES_PER_EL_FLOAT8 * 2 + + if model_torch_compile_limitations: + # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) + # has an extra memory read of the input in fp8 + # context: https://github.com/pytorch/pytorch/issues/130015 + tc_adjustment = numel * BYTES_PER_EL_FLOAT8 + + # https://github.com/pytorch/pytorch/issues/128063 + # instead of + # kernel 1: x_bf16 -> max(abs(x)), x_fp8 + # kernel 2: not modeled + # kernel 3: not modeled + # we get + # kernel 1: x_bf16 -> max(abs(x)) + # reads: same as before + # writes: 0 + # ... + # kernel 4: x_bf16, scale -> x_fp8 + # reads: numel * BYTES_PER_EL_BF16 + # writes: 2 * numel * BYTES_PER_EL_FLOAT8 + # Note that assuming worst case, this issue brings the memory + # traffic for delayed scaling to be equal to that of dynamic scaling. 
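+            # Worked example (editorial sketch, not part of the original
+            # comment): for an un-fused cast of a tensor with numel = N and
+            # model_torch_compile_limitations=True, delayed scaling costs
+            #   kernel_1_r = 2N bytes (bf16 read)
+            #   kernel_1_w = 2N bytes (two fp8 writes)
+            #   tc_adjustment = N + (-2N + 2N + 2N) = 3N bytes
+            # for a total of 7N bytes, matching dynamic scaling's
+            # 2N + 4N + N = 7N bytes under the same assumptions.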
+ tc_adjustment += ( + # subtract writes from kernel 1 + -1 * 2 * numel * BYTES_PER_EL_FLOAT8 + # add reads for kernel 4 + + numel * BYTES_PER_EL_BF16 + # add writes for kernel 4 + + 2 * numel * BYTES_PER_EL_FLOAT8 + ) + else: + tc_adjustment = 0 + + return kernel_1_r + kernel_1_w + tc_adjustment + + +def get_gemm_time_sympy(M, K, N, dtype): + gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N + if dtype is torch.bfloat16: + peak_tops = H100_BF16_PEAK_TOPS + elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + peak_tops = H100_FP8_PEAK_TOPS + gemm_time_s = gemm_ops / peak_tops / H100_PCT_ACHIEVABLE_GEMM_TOPS + return gemm_time_s + + +def get_float8_mem_sympy( + M, + K, + N, + model_torch_compile_limitations: bool = False, + scaling_type_input: str = "dynamic", + scaling_type_weight: str = "dynamic", + scaling_type_grad_output: str = "dynamic", +): + + assert scaling_type_input in ("dynamic", "delayed"), "unsupported" + assert scaling_type_weight in ("dynamic", "delayed"), "unsupported" + assert scaling_type_grad_output in ("dynamic", "delayed"), "unsupported" + + # there are three gemms in the fwd/bwd of a linear: + # + # input @ weight_t = output + # MxK @ KxN => MxN + # + # grad_output @ weight = grad_input + # MxN @ NxK => MxK + # + # input_t @ grad_output = grad_weight + # KxM @ MxN => KxN + + # + # forward - output + # + fwd_fp8_input_mem = get_tensor_memory_traffic_bytes( + M, K, scaling_type_input, fuse_with_prev=True, + model_torch_compile_limitations=model_torch_compile_limitations) + fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes( + K, N, scaling_type_weight, fuse_with_prev=False, + model_torch_compile_limitations=model_torch_compile_limitations) + fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem + + # + # backward - grad_input + # + gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes( + M, N, scaling_type_grad_output, fuse_with_prev=True, + model_torch_compile_limitations=model_torch_compile_limitations) + # already casted, assuming that we save weight from fw to bw + # TODO: model this if FSDP float8 all-gather is on + # TODO: model this if we don't save weight from fw to bw, and recompute instead + gi_fp8_weight_mem = 0 + + # + # backward - grad_weight + # + # TODO: model this if we don't save fp8 input from fw to bw + gw_fp8_input_t_mem = 0 # already casted + # this should be always 0 + gw_fp8_grad_output_mem = 0 # already casted + + bwd_fp8_total_mem = \ + gi_fp8_grad_output_mem + gi_fp8_weight_mem + \ + gw_fp8_input_t_mem + gw_fp8_grad_output_mem + fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem + fp8_mem_time_s = ( + fp8_total_mem / H100_PEAK_MEM_BW_BYTES_SEC / H100_PCT_ACHIEVABLE_MEM_BW + ) + + # Adjust final estimate for small kernel launches + # note that we do this adjustment here because we are assuming a minimal + # kernel overhead in the units of seconds, and the per-gemm-input memory + # estimations are in the units of bytes. 
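+    # Worked example (editorial sketch, not in the original): with all three
+    # scaling types set to "dynamic" there are 3 extra kernels (one second-stage
+    # max-abs reduction each), so the overhead term below is
+    # 3 * TRITON_KERNEL_1_ELEMENT_TIME_SEC = 3 * 2e-6 s, i.e. about 6 microseconds;
+    # with all three set to "delayed" it is 6 kernels, i.e. about 12 microseconds.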
+ num_extra_kernels = 0 + if scaling_type_input == "dynamic": + # second stage of max-abs reduction + num_extra_kernels += 1 + elif scaling_type_input == "delayed": + # second stage of max-abs reduction + num_extra_kernels += 1 + # reciprocal of scale + num_extra_kernels += 1 + if scaling_type_weight == "dynamic": + # second stage of max-abs reduction + num_extra_kernels += 1 + elif scaling_type_weight == "delayed": + # second stage of max-abs reduction + num_extra_kernels += 1 + # reciprocal of scale + num_extra_kernels += 1 + if scaling_type_grad_output == "dynamic": + # second stage of max-abs reduction + num_extra_kernels += 1 + elif scaling_type_grad_output == "delayed": + # second stage of max-abs reduction + num_extra_kernels += 1 + # reciprocal of scale + num_extra_kernels += 1 + + extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC + + return fp8_mem_time_s + extra_kernel_overhead_s diff --git a/torchao/ops.py b/torchao/ops.py index f9949af2b..5bb827163 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -115,7 +115,7 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens Args: packed_w: torch.tensor: 4D tensor with shape `(N / 8) x (K / (inner_k_tiles * 16)) x 32 x inner_k_tiles / 2`, dtype is torch.int32 scales_and_zeros: torch.tensor: 3D tensor with shape `numQGroups x N x 2`, dtype is torch.bfloat16 where numQGroups is K / qGroupSize - qGroupSize: int + group_size: int inner_k_tiles: int Returns: diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index d3faa5d4c..c8ecaa60d 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -6,7 +6,10 @@ ) from typing import Callable from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TorchAOBaseTensor, + TORCH_VERSION_AT_LEAST_2_5, +) __all__ = [ "LinearActivationQuantizedTensor", @@ -15,7 +18,7 @@ aten = torch.ops.aten -class LinearActivationQuantizedTensor(torch.Tensor): +class LinearActivationQuantizedTensor(TorchAOBaseTensor): """ Applies activation quantization for linear operator """ @@ -74,20 +77,6 @@ def _apply_fn_to_data(self, fn): self.input_quant_func, ) - def _get_to_kwargs(self, *args, **kwargs): - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - memory_format = ( - memory_format if memory_format is not None else torch.preserve_format - ) - kwargs = { - "device": device, - "dtype": dtype, - "memory_format": memory_format, - } - return kwargs - def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) return self.__class__( diff --git a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py index 6ebe458a4..6ec933435 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py @@ -12,7 +12,7 @@ def intN_weight_only(group_size=32, n=8, symmetric=False): ''' Apply int N-bit weight only quantization to a linear layer. 
Args: - `groupsize`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32] + `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32] `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2] Usage: from torchao.quantization import quantize_ diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 1ac97de3c..bd4656f6c 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -33,7 +33,7 @@ class MappingType(Enum): """How floating point number is mapped to integer number - symmetric mapping means floating point range is symetrically mapped to integer range + symmetric mapping means floating point range is symmetrically mapped to integer range let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4) we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7) e.g. scale = (10.2 - (-10.2)) / (7 - (-8)) @@ -167,7 +167,7 @@ def quantize_affine( output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float if zero_point is in integer domain, zero point is added to the quantized integer value during quantization if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) @@ -287,11 +287,11 @@ def dequantize_affine( e.g. when size is the same as the input tensor dimension, we are using per tensor quantization scale (Tensor): quantization parameter for affine quantization zero_point (Tensor): quantization parameter for affine quantization - dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + input_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor quant_min (Optional[int]): minimum quantized value for input Tensor quant_max (Optional[int]): maximum quantized value for input Tensor output_dtype (torch.dtype): dtype for output Tensor, default is fp32 - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float if zero_point is in integer domain, zero point is added to the quantized integer value during quantization if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) @@ -413,7 +413,7 @@ def fake_quantize_affine( quant_dtype (torch.dtype): desired quantized dtype for determining and validating quant_min and quant_max values. 
quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float if zero_point is in integer domain, zero point is added to the quantized integer value during quantization if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) @@ -549,7 +549,7 @@ def choose_qparams_affine( If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float if zero_point is in integer domain, zero point is added to the quantized integer value during quantization if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) diff --git a/torchao/utils.py b/torchao/utils.py index 9e0d3fb02..329d4790f 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -20,6 +20,7 @@ "_register_custom_op", "get_model_size_in_bytes", "unwrap_tensor_subclass", + "TorchAOBaseTensor", "TORCH_VERSION_AT_LEAST_2_2", "TORCH_VERSION_AT_LEAST_2_3", "TORCH_VERSION_AT_LEAST_2_4", @@ -284,6 +285,30 @@ def unwrap_tensor_subclass(model, filter_fn=None): unwrap_tensor_subclass(child) return model +class TorchAOBaseTensor(torch.Tensor): + """A util tensor subclass that provides commonly used functions + """ + def _get_to_kwargs(self, *args, **kwargs): + # `torch._C._nn._parse_to` can't handle `layout` argument + for arg in args: + if isinstance(arg, torch.layout): + args.remove(arg) + if "layout" in kwargs: + kwargs.pop("layout") + # ignoring `non_blocking` and `memory_format` args since these are not + # very useful for most of the tensor subclasses + # if in the future there are use cases that need these, we'd recommend + # to override `_get_to_kwargs` and return these args + device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + kwargs = { + "device": device, + "dtype": dtype, + } + return kwargs + + def parse_version(version_string): # Extract just the X.Y.Z part from the version string
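As a usage note for the new `TorchAOBaseTensor` helper above: the sketch below is a minimal, hypothetical subclass (the name `MyQuantTensor` and its single `payload` field are illustrative, not part of this diff) showing how `_get_to_kwargs` is intended to back a subclass's `to()`, mirroring the `LinearActivationQuantizedTensor.to` change earlier in this diff.

```python
import torch
from torchao.utils import TorchAOBaseTensor

class MyQuantTensor(TorchAOBaseTensor):
    """Hypothetical subclass wrapping a single payload tensor."""

    @staticmethod
    def __new__(cls, payload: torch.Tensor):
        # wrapper subclass carrying the payload's shape/dtype/device
        return torch.Tensor._make_wrapper_subclass(
            cls, payload.shape, dtype=payload.dtype, device=payload.device
        )

    def __init__(self, payload: torch.Tensor):
        self.payload = payload

    def to(self, *args, **kwargs):
        # _get_to_kwargs normalizes the many .to(...) call forms into
        # {"device": ..., "dtype": ...}, defaulting to the current values
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(self.payload.to(**kwargs))

# example: rebuild the wrapper around a bfloat16 copy of the payload
t = MyQuantTensor(torch.randn(4, 4))
t_bf16 = t.to(torch.bfloat16)
```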