4 changes: 1 addition & 3 deletions keras/api/_tf_keras/keras/quantizers/__init__.py
@@ -12,7 +12,5 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import quantize_and_dequantize
4 changes: 1 addition & 3 deletions keras/api/quantizers/__init__.py
@@ -12,7 +12,5 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import quantize_and_dequantize
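For callers of the removed `keras.quantizers.fake_quant_with_min_max_vars_per_channel` alias, per-channel behavior is now requested through the `axis` argument of the remaining export. A minimal migration sketch (the tensor shape and quantization range are illustrative, not taken from the PR):

```python
import numpy as np
from keras import ops, quantizers

x = ops.convert_to_tensor(
    np.random.uniform(-3.0, 3.0, size=(2, 3, 4)).astype("float32")
)

# Previously: quantizers.fake_quant_with_min_max_vars_per_channel(...)
# Now: pass `axis` to the single remaining public function instead.
y = quantizers.fake_quant_with_min_max_vars(
    x,
    min_vals=ops.convert_to_tensor([-3.0] * 4),  # one min per channel on the last axis
    max_vals=ops.convert_to_tensor([3.0] * 4),   # one max per channel on the last axis
    num_bits=8,
    narrow_range=False,
    axis=-1,
)
```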
118 changes: 94 additions & 24 deletions keras/src/quantizers/quantizers.py
@@ -4,8 +4,11 @@
from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend import KerasTensor
from keras.src.backend import any_symbolic_tensors
from keras.src.backend.common.backend_utils import canonicalize_axis
from keras.src.backend.common.backend_utils import standardize_axis_for_numpy
from keras.src.ops.operation import Operation

"""Int8-related classes and methods"""

@@ -130,17 +133,20 @@ def get_config(self):

def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
"""Adjusts and nudges the quantization range for better accuracy."""
# Use higher precision for the computation.
compute_dtype = backend.result_type(min_range.dtype, "float32")
min_range = ops.cast(min_range, compute_dtype)
max_range = ops.cast(max_range, compute_dtype)

quant_max = ops.cast(ops.subtract(ops.power(2, num_bits), 1.0), "float32")

quant_min = ops.cast(0.0 if not narrow_range else 1.0, "float32")
quant_max = (1 << num_bits) - 1
quant_min = 0 if not narrow_range else 1
diff_range = ops.subtract(max_range, min_range)

# Calculate the scale and ensure it's positive
scale = ops.divide(
ops.subtract(max_range, min_range), ops.subtract(quant_max, quant_min)
)
scale = ops.divide(diff_range, quant_max - quant_min)

inv_scale = ops.reciprocal(scale)
# Re-calculate the inverse to avoid loss of precision
inv_scale = ops.divide(quant_max - quant_min, diff_range)

# Calculate the zero point from the min range
zero_point_from_min = quant_min - ops.divide(min_range, scale)
@@ -158,17 +164,37 @@ def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
return nudged_min, nudged_max, scale, inv_scale
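To make the arithmetic above concrete, here is a small worked sketch of the scale / zero-point computation with `num_bits=8`, `narrow_range=False`, and a `[-3, 3]` range. The final clamping and rounding of the zero point is elided in this diff, so the round-half-up tie-break below is an assumption for illustration, not a quote of the hidden lines:

```python
import numpy as np

min_range, max_range = np.float32(-3.0), np.float32(3.0)
num_bits, narrow_range = 8, False

quant_max = (1 << num_bits) - 1           # 255
quant_min = 0 if not narrow_range else 1  # 0
diff_range = max_range - min_range        # 6.0

scale = diff_range / (quant_max - quant_min)       # ~0.023529
# Computing the inverse directly avoids the extra rounding step of 1.0 / scale.
inv_scale = (quant_max - quant_min) / diff_range   # 42.5

zero_point_from_min = quant_min - min_range / scale  # 127.5
# Assumed nudging step (hidden between hunks): clamp the zero point to
# [quant_min, quant_max], round it, then rebuild the range around it.
nudged_zero_point = np.clip(
    np.floor(zero_point_from_min + 0.5), quant_min, quant_max
)  # 128
nudged_min = (quant_min - nudged_zero_point) * scale  # ~-3.0118
nudged_max = (quant_max - nudged_zero_point) * scale  # ~2.9882
print(scale, inv_scale, nudged_min, nudged_max)
```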


@keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel")
class FakeQuantWithMinMaxVars(Operation):
def __init__(self, num_bits=8, narrow_range=False, axis=None):
super().__init__()
self.num_bits = num_bits
self.narrow_range = narrow_range
self.axis = axis

def call(self, inputs, min_vals, max_vals):
return fake_quant_with_min_max_vars(
inputs,
min_vals,
max_vals,
num_bits=self.num_bits,
narrow_range=self.narrow_range,
axis=self.axis,
)

def compute_output_spec(self, inputs, min_vals, max_vals):
return KerasTensor(inputs.shape, dtype=inputs.dtype)
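The `Operation` wrapper is what lets the function participate in symbolic (graph-building) calls: when the inputs are `KerasTensor`s, no computation runs and only the output spec is produced. A minimal sketch mirroring the new symbolic test:

```python
from keras import KerasTensor, quantizers

# Symbolic input: shape/dtype metadata only, no values.
x = KerasTensor((2, 3, 4))
y = quantizers.fake_quant_with_min_max_vars(x, -3.0, 3.0, axis=-1)
print(type(y).__name__, y.shape)  # -> KerasTensor (2, 3, 4)
```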


@keras_export("keras.quantizers.fake_quant_with_min_max_vars")
def fake_quant_with_min_max_vars(
inputs,
min_vals,
max_vals,
num_bits,
num_bits=8,
narrow_range=False,
axis=None,
):
"""
Perform per-tensor or per-channel fake quantization.
"""Perform per-tensor or per-channel fake quantization.

`[min_vals, max_vals]` define the clamping range for the `inputs`.

@@ -183,27 +209,68 @@ def fake_quant_with_min_max_vars(
`max_vals` to be trained.

Args:
inputs: Input tensor of float dtype.
inputs: Input Keras tensor of float dtype.
min_vals: A global minimum scalar or a per-channel minimum tensor.
max_vals: A global maximum scalar or a per-channel maximum tensor.
num_bits: Quantization bit width (e.g., `8` for int8).
narrow_range: Whether to use narrow quantization range.
num_bits: Quantization bit width (e.g., `8` for int8). Defaults to `8`.
narrow_range: Whether to use narrow quantization range. Defaults to
`False`.
axis: Axis along which to perform per-channel quantization. If `None`,
per-tensor quantization is performed. Defaults to `None`.


Returns:
Fake-quantized tensor
Tensor: A Keras tensor with fake quantization applied.
"""
if any_symbolic_tensors((inputs,)):
return FakeQuantWithMinMaxVars().symbolic_call(
inputs, min_vals, max_vals
)

inputs = ops.convert_to_tensor(inputs)
min_vals = ops.convert_to_tensor(min_vals)
max_vals = ops.convert_to_tensor(max_vals)
num_bits = int(num_bits)

if axis is not None:
axis = canonicalize_axis(axis, inputs.ndim)

# Shortcut for the TensorFlow backend using the `tf.quantization.fake_quant_*`
# APIs. This is necessary for the ops to be recognized by the TFLite converter.
if backend.backend() == "tensorflow":
import tensorflow as tf

# `tf.quantization.fake_quant_*` only supports float32.
dtype = backend.standardize_dtype(inputs.dtype)
if axis is None:
outputs = tf.quantization.fake_quant_with_min_max_vars(
ops.cast(inputs, "float32"),
ops.cast(ops.reshape(min_vals, ()), "float32"),
ops.cast(ops.reshape(max_vals, ()), "float32"),
num_bits=num_bits,
narrow_range=narrow_range,
)
return ops.cast(outputs, dtype=dtype)
else:
# `tf.quantization.fake_quant_with_min_max_vars_per_channel` only
# supports the last axis for per-channel quantization. We use
# `ops.swapaxes` for the pre- and post-processing.
last_axis = inputs.ndim - 1
inputs = ops.swapaxes(inputs, axis, last_axis)
outputs = tf.quantization.fake_quant_with_min_max_vars_per_channel(
ops.cast(inputs, "float32"),
ops.cast(min_vals, "float32"),
ops.cast(max_vals, "float32"),
num_bits=num_bits,
narrow_range=narrow_range,
)
outputs = ops.cast(outputs, dtype=dtype)
return ops.swapaxes(outputs, last_axis, axis)
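The swap-to-last-axis workaround above can be summarized independently of TensorFlow: move the quantization axis to the end, apply an op that only understands "last axis = channels", then move it back. A small sketch of the idiom using `keras.ops`; the elementwise function is a stand-in for the per-channel op:

```python
from keras import ops

def apply_on_last_axis(x, fn, axis):
    # Move `axis` to the last position, apply `fn` (which assumes
    # channels-last), then restore the original layout.
    last_axis = len(x.shape) - 1
    x = ops.swapaxes(x, axis, last_axis)
    out = fn(x)
    return ops.swapaxes(out, last_axis, axis)

x = ops.ones((2, 5, 3))
# Stand-in for tf.quantization.fake_quant_with_min_max_vars_per_channel:
# any function that treats the last axis as channels works here.
y = apply_on_last_axis(x, lambda t: t * 2.0, axis=1)
print(y.shape)  # (2, 5, 3)
```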

@ops.custom_gradient
def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
dtype = backend.standardize_dtype(x.dtype)

# Calculate quantization parameters for all channels at once
nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge(
min_val, max_val, num_bits, narrow_range
@@ -212,7 +279,9 @@ def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
quant_zero = ops.floor(
ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
)
x_clamped = ops.clip(x, nudged_min, nudged_max)
x_clamped = ops.clip(
x, ops.cast(nudged_min, x.dtype), ops.cast(nudged_max, x.dtype)
)
x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
result = ops.multiply(
ops.floor(
@@ -225,33 +294,34 @@ def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
),
scale,
)
result = ops.cast(result, dtype=dtype)

# Create gradient mask for all channels
masks = ops.cast(
(x >= nudged_min) & (x <= nudged_max),
dtype="float32",
masks = ops.logical_and(
ops.greater_equal(x, nudged_min), ops.less_equal(x, nudged_max)
)

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args

# Gradient for x
dx = ops.multiply(upstream, masks)
dx = ops.where(masks, upstream, 0.0)
axes = [i for i in range(len(dx.shape)) if i != axis]

# Gradient for min_val
# When x is clipped to min, the gradient flows to min_val
min_mask = ops.cast(x <= nudged_min, dtype="float32")
grad_min = ops.multiply(upstream, min_mask)
min_mask = ops.less_equal(x, nudged_min)
grad_min = ops.where(min_mask, upstream, 0.0)
if axis is not None:
grad_min = ops.sum(grad_min, axis=axes)
else:
grad_min = ops.sum(grad_min)

# Gradient for max_val
# When x is clipped to max, the gradient flows to max_val
max_mask = ops.cast(x >= nudged_max, dtype="float32")
grad_max = ops.multiply(upstream, max_mask)
max_mask = ops.greater_equal(x, nudged_max)
grad_max = ops.where(max_mask, upstream, 0.0)
if axis is not None:
grad_max = ops.sum(grad_max, axis=axes)
else:
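The gradient logic above routes each upstream gradient element to `x` when the input lies strictly inside the nudged range (a straight-through estimator), to `min_vals` when it sits at or below the nudged minimum, and to `max_vals` when it sits at or above the nudged maximum. A small sketch of that routing with concrete numbers, reusing the nudged range from the worked example earlier and independent of the diff's `custom_gradient` wiring:

```python
import numpy as np

x = np.array([-4.0, -1.0, 0.5, 2.0, 5.0], dtype="float32")
nudged_min, nudged_max = np.float32(-3.0118), np.float32(2.9882)
upstream = np.ones_like(x)  # pretend incoming gradient of ones

inside = (x >= nudged_min) & (x <= nudged_max)
dx = np.where(inside, upstream, 0.0)                       # [0, 1, 1, 1, 0]
grad_min = np.where(x <= nudged_min, upstream, 0.0).sum()  # 1.0 (only x = -4.0)
grad_max = np.where(x >= nudged_max, upstream, 0.0).sum()  # 1.0 (only x = 5.0)
print(dx, grad_min, grad_max)
```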
91 changes: 68 additions & 23 deletions keras/src/quantizers/quantizers_test.py
@@ -1,3 +1,6 @@
import sys

import pytest
from absl.testing import parameterized

from keras.src import backend
@@ -104,6 +107,17 @@ def test_quantize_and_dequantize(self):
# A loose assertion due to an expected quantization error
self.assertAllClose(qdq_values, values, atol=5e-1)

@parameterized.named_parameters(
("per_tensor", None),
("per_channel", -1),
)
def test_fake_quant_with_min_max_vars_symbolic(self, axis):
x = backend.KerasTensor((2, 3, 4))
y = quantizers.fake_quant_with_min_max_vars(x, -3.0, 3.0, axis=axis)

self.assertIsInstance(y, backend.KerasTensor)
self.assertEqual(y.shape, (2, 3, 4))

@parameterized.named_parameters(
[
{
@@ -334,7 +348,11 @@ def test_quantize_and_dequantize(self):
},
]
)
def test_op(
@pytest.mark.skipif(
backend.backend() not in ("tensorflow", "jax", "torch"),
reason=f"{backend.backend()} doesn't support `custom_gradient`.",
)
def test_fake_quant_with_min_max_vars(
self,
input_mins,
input_maxs,
@@ -401,6 +419,8 @@ def test_op(
initial_gradients = ops.transpose(
ops.array(initial_gradients_list, dtype="float32")
)

# Test gradients.
if backend.backend() == "tensorflow":
import tensorflow as tf

@@ -420,58 +440,60 @@
)
return initial_gradients * tape.gradient(result, inputs)

gradients = test_op(
inputs, input_mins, input_maxs, num_bits, narrow_range, axis
)
# test gradients
self.assertAllClose(gradients, expected_backprops_wrt_input)

if backend.backend() == "torch":
import torch

def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range):
def test_op(
inputs, input_mins, input_maxs, num_bits, narrow_range, axis
):
# Create tensor and enable gradient tracking
inputs = torch.tensor(
inputs, dtype=torch.float32, requires_grad=True
)

# Apply the quantization operation
result = quantizers.fake_quant_with_min_max_vars(
inputs, input_mins, input_maxs, num_bits, narrow_range
inputs, input_mins, input_maxs, num_bits, narrow_range, axis
)

# Compute gradients
result.backward(torch.ones_like(result))

return initial_gradients * inputs.grad

gradients = test_op(
inputs, input_min, input_max, num_bits, narrow_range
)
# test gradients
self.assertAllClose(gradients, expected_backprops_wrt_input)

if backend.backend() == "jax":
import jax

def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range):
def test_op(
inputs, input_mins, input_maxs, num_bits, narrow_range, axis
):
# Define the function to compute gradients for
def quantize_fn(x):
return quantizers.fake_quant_with_min_max_vars(
x, input_mins, input_maxs, num_bits, narrow_range
x, input_mins, input_maxs, num_bits, narrow_range, axis
)

_, f_vjp = jax.vjp(quantize_fn, inputs)
# NOTE:python 3.10 input_gradients = f_vjp.args[0].args[0][0] !
input_gradients = f_vjp.args[0].args[0][1]

# NOTE: When python version >= 3.10, the gradients are at
# `f_vjp.args[0].args[0][0]`. Otherwise, they are at
# `f_vjp.args[0].args[0][1]`.
if sys.version_info >= (3, 10):
input_gradients = f_vjp.args[0].args[0][0]
else:
input_gradients = f_vjp.args[0].args[0][1]

return ops.multiply(initial_gradients, input_gradients)

gradients = test_op(
inputs, input_min, input_max, num_bits, narrow_range
)
# test gradients
gradients = test_op(
inputs, input_min, input_max, num_bits, narrow_range, axis
)
if backend.backend() != "jax" or not testing.jax_uses_gpu():
# JAX GPU produces less precise numbers, causing the CI to fail.
# For example, 127.5 / 255.0 results in 0.49999997 instead of 0.5.
self.assertAllClose(gradients, expected_backprops_wrt_input)

# Test outputs.
outputs = quantizers.fake_quant_with_min_max_vars(
inputs,
input_min,
@@ -481,3 +503,26 @@ def quantize_fn(x):
axis=axis,
)
self.assertAllClose(outputs, expected)

# Test bfloat16 & float16 dtype
outputs = quantizers.fake_quant_with_min_max_vars(
ops.cast(inputs, "bfloat16"),
input_min,
input_max,
num_bits=num_bits,
narrow_range=narrow_range,
axis=axis,
)
self.assertDType(outputs, "bfloat16")
self.assertAllClose(outputs, expected)

outputs = quantizers.fake_quant_with_min_max_vars(
ops.cast(inputs, "float16"),
input_min,
input_max,
num_bits=num_bits,
narrow_range=narrow_range,
axis=axis,
)
self.assertDType(outputs, "float16")
self.assertAllClose(outputs, expected)