Simplify the implementation of quantization-related methods #19954

Merged
8 commits merged on Jul 23, 2024
keras/src/backend/torch/layer.py (3 changes: 2 additions & 1 deletion)
@@ -58,4 +58,5 @@ def _post_track_variable(self, variable):

def _post_untrack_variable(self, variable):
if hasattr(self, "torch_params"):
self.torch_params.pop(variable.path)
if variable.path in self.torch_params:
self.torch_params.pop(variable.path)
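
Note on the `torch_params` guard above: `_post_untrack_variable` can now be called for a variable whose path was never registered in (or was already removed from) `torch_params`, and the old unconditional `pop()` would turn that into a `KeyError`. A minimal sketch of the failure mode and the idempotent pattern, illustrative only and not the PR's code:

```python
# Minimal sketch, not the PR's code: why the membership check matters.
torch_params = {"dense/kernel": "stand-in for a torch.nn.Parameter"}

def untrack(path):
    # Guarded removal, as in the diff; torch_params.pop(path, None) would be equivalent.
    if path in torch_params:
        torch_params.pop(path)

untrack("dense/kernel")  # removes the entry
untrack("dense/kernel")  # second call is a no-op instead of raising KeyError
```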
keras/src/layers/core/dense.py (34 changes: 5 additions & 29 deletions)
@@ -300,13 +300,6 @@ def _check_load_own_variables(self, store):

# Quantization-related (int8 and float8) methods

def _quantization_mode_error(self, mode):
return NotImplementedError(
"Invalid quantization mode. Expected one of "
f"{dtype_policies.QUANTIZATION_MODES}. "
f"Received: quantization_mode={mode}"
)

def quantized_build(self, input_shape, mode):
if mode == "int8":
input_dim = input_shape[-1]
@@ -390,15 +383,7 @@ def _float8_build(self):
self.outputs_grad_amax_history.overwrite_with_gradient = True
self._is_quantized = True

def quantized_call(self, inputs, training=None):
if self.quantization_mode == "int8":
return self._int8_call(inputs)
elif self.quantization_mode == "float8":
return self._float8_call(inputs, training=training)
else:
raise self._quantization_mode_error(self.quantization_mode)

def _int8_call(self, inputs):
def _int8_call(self, inputs, training=None):
@ops.custom_gradient
def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
def grad_fn(*args, upstream=None):
@@ -525,22 +510,17 @@ def grad(*args, upstream=None, variables=None):
return x
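
The per-layer `quantized_call` dispatchers removed in this file (and in `einsum_dense.py` and `embedding.py` below) branched on `self.quantization_mode`; `_int8_call` now accepts `training=None` so all mode-specific methods share one signature and dispatch can live in the base class (the `Embedding` override below delegates to `super().quantized_call`). A hypothetical sketch of such a centralized dispatcher, assuming a `_{mode}_call` naming convention; the actual base-class code may differ:

```python
class QuantizedCallDispatchSketch:
    """Illustrative only: dispatch quantized_call by quantization mode."""

    quantization_mode = "int8"  # e.g. "int8" or "float8"

    def quantized_call(self, *args, **kwargs):
        # Resolve e.g. self._int8_call / self._float8_call by name.
        method = getattr(self, f"_{self.quantization_mode}_call", None)
        if method is None:
            raise NotImplementedError(
                f"Unsupported quantization_mode={self.quantization_mode!r}"
            )
        return method(*args, **kwargs)

    def _int8_call(self, inputs, training=None):
        return f"int8 path, training={training}"


print(QuantizedCallDispatchSketch().quantized_call("x", training=False))
```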

def quantize(self, mode, type_check=True):
import gc

# Prevent quantization of the subclasses
if type_check and (type(self) is not Dense):
raise self._quantize_not_implemented_error()
self._check_quantize_args(mode, self.compute_dtype)
raise self._not_implemented_error(self.quantize)

self._tracker.unlock()
if mode == "int8":
# Quantize `self._kernel` to int8 and compute corresponding scale
kernel_value, kernel_scale = quantizers.abs_max_quantize(
self._kernel, axis=0
self._kernel, axis=0, to_numpy=True
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
self._untrack_variable(self._kernel)
kernel_shape = self._kernel.shape
kernel_shape = tuple(self._kernel.shape)
del self._kernel
# Utilize a lambda expression as an initializer to prevent adding a
# large constant to the computation graph.
@@ -553,16 +533,12 @@ def quantize(self, mode, type_check=True):
self._float8_build()
else:
raise self._quantization_mode_error(mode)
self._tracker.lock()

# Set new dtype policy
if self.dtype_policy.quantization_mode is None:
policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
self.dtype_policy = policy

# Release memory manually because sometimes the backend doesn't
gc.collect()

def _get_kernel_with_merged_lora(self):
if self.dtype_policy.quantization_mode is not None:
kernel_value = self._kernel
@@ -576,7 +552,7 @@ def _get_kernel_with_merged_lora(self):
ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
)
kernel_value, kernel_scale = quantizers.abs_max_quantize(
kernel_value, axis=0
kernel_value, axis=0, to_numpy=True
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
return kernel_value, kernel_scale
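
For reference, the arithmetic behind `quantizers.abs_max_quantize(..., axis=0)` and the dequantizing division in `_int8_call` is plain per-column abs-max scaling; `to_numpy=True` appears to request that this be computed eagerly in NumPy so the quantized kernel is not traced into the graph, which is also why the result is fed back through a lambda initializer rather than as a large constant. A rough NumPy sketch of the math, not the Keras implementation:

```python
import numpy as np

def abs_max_quantize(w, axis=0, epsilon=1e-7):
    # Per-axis scale mapping the largest |w| to 127 (int8 range).
    amax = np.max(np.abs(w), axis=axis, keepdims=True)
    scale = 127.0 / (amax + epsilon)
    w_q = np.clip(np.round(w * scale), -127, 127).astype(np.int8)
    return w_q, scale

rng = np.random.default_rng(0)
kernel = rng.normal(size=(4, 3)).astype("float32")  # (input_dim, units)
x = rng.normal(size=(2, 4)).astype("float32")

kernel_q, scale = abs_max_quantize(kernel, axis=0)   # scale has shape (1, 3)
# Dequantized matmul, analogous to what the int8 call path computes:
y = (x @ kernel_q.astype("float32")) / scale
print(np.max(np.abs(y - x @ kernel)))  # small, bounded by rounding error
```

When LoRA is active, `_get_kernel_with_merged_lora` (hunk above) first adds `lora_kernel_a @ lora_kernel_b` back onto the dequantized kernel and then re-quantizes the merged result with the same routine.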
keras/src/layers/core/einsum_dense.py (34 changes: 5 additions & 29 deletions)
@@ -360,13 +360,6 @@ def _check_load_own_variables(self, store):

# Quantization-related (int8 and float8) methods

def _quantization_mode_error(self, mode):
return NotImplementedError(
"Invalid quantization mode. Expected one of "
f"{dtype_policies.QUANTIZATION_MODES}. "
f"Received: quantization_mode={mode}"
)

def quantized_build(self, input_shape, mode):
if mode == "int8":
shape_data = _analyze_einsum_string(
@@ -477,15 +470,7 @@ def _float8_build(self):
self.outputs_grad_amax_history.overwrite_with_gradient = True
self._is_quantized = True

def quantized_call(self, inputs, training=None):
if self.quantization_mode == "int8":
return self._int8_call(inputs)
elif self.quantization_mode == "float8":
return self._float8_call(inputs, training=training)
else:
raise self._quantization_mode_error(self.quantization_mode)

def _int8_call(self, inputs):
def _int8_call(self, inputs, training=None):
@ops.custom_gradient
def einsum_with_inputs_gradient(inputs, kernel, kernel_scale):
def grad_fn(*args, upstream=None):
@@ -641,14 +626,10 @@ def grad(*args, upstream=None, variables=None):
return x

def quantize(self, mode, type_check=True):
import gc

# Prevent quantization of the subclasses
if type_check and (type(self) is not EinsumDense):
raise self._quantize_not_implemented_error()
self._check_quantize_args(mode, self.compute_dtype)
raise self._not_implemented_error(self.quantize)

self._tracker.unlock()
if mode == "int8":
(
self._input_reduced_axes,
@@ -664,7 +645,7 @@ def quantize(self, mode, type_check=True):
) = _analyze_quantization_info(self.equation, self.input_spec.ndim)
# Quantize `self._kernel` to int8 and compute corresponding scale
kernel_value, kernel_scale = quantizers.abs_max_quantize(
self._kernel, axis=self._kernel_reduced_axes
self._kernel, axis=self._kernel_reduced_axes, to_numpy=True
)
kernel_scale = ops.transpose(
kernel_scale, self._kernel_transpose_axes
@@ -677,8 +658,7 @@ kernel_scale = ops.squeeze(
kernel_scale = ops.squeeze(
kernel_scale, axis=self._kernel_squeeze_axes
)
self._untrack_variable(self._kernel)
kernel_shape = self._kernel.shape
kernel_shape = tuple(self._kernel.shape)
del self._kernel
# Utilize a lambda expression as an initializer to prevent adding a
# large constant to the computation graph.
@@ -691,16 +671,12 @@ def quantize(self, mode, type_check=True):
self._float8_build()
else:
raise self._quantization_mode_error(mode)
self._tracker.lock()

# Set new dtype policy
if self.dtype_policy.quantization_mode is None:
policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
self.dtype_policy = policy

# Release memory manually because sometimes the backend doesn't
gc.collect()

def _get_kernel_with_merged_lora(self):
if self.dtype_policy.quantization_mode is not None:
kernel_value = self._kernel
@@ -714,7 +690,7 @@ def _get_kernel_with_merged_lora(self):
ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
)
kernel_value, kernel_scale = quantizers.abs_max_quantize(
kernel_value, axis=self._kernel_reduced_axes
kernel_value, axis=self._kernel_reduced_axes, to_numpy=True
)
kernel_scale = ops.transpose(
kernel_scale, self._kernel_transpose_axes
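
The `EinsumDense` version follows the same abs-max recipe, but the reduction axes come from `_analyze_quantization_info(self.equation, ...)`, and the resulting `kernel_scale` then has to be transposed, expanded, and squeezed (the `_kernel_transpose_axes` / `_kernel_squeeze_axes` handling above) so it broadcasts against the einsum output. A toy NumPy illustration for the simple equation `"ab,bc->ac"`, where the kernel's contraction axis is `b`; hedged, not the Keras code:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4)).astype("float32")       # subscripts "ab"
kernel = rng.normal(size=(4, 3)).astype("float32")  # subscripts "bc"

# Reduce over the contraction axis b (axis 0 of the kernel): one scale per c.
amax = np.max(np.abs(kernel), axis=0, keepdims=True)       # shape (1, 3)
scale = 127.0 / (amax + 1e-7)
kernel_q = np.clip(np.round(kernel * scale), -127, 127).astype(np.int8)

# After squeezing the reduced axis, the scale broadcasts against the "ac" output.
y = np.einsum("ab,bc->ac", x, kernel_q.astype("float32")) / np.squeeze(scale, axis=0)
print(np.max(np.abs(y - np.einsum("ab,bc->ac", x, kernel))))  # rounding error only
```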
keras/src/layers/core/embedding.py (26 changes: 9 additions & 17 deletions)
@@ -320,13 +320,12 @@ def _int8_build(
)
self._is_quantized = True

def quantized_call(self, inputs):
if self.quantization_mode == "int8":
return self._int8_call(inputs)
else:
def quantized_call(self, *args, **kwargs):
if self.quantization_mode != "int8":
raise self._quantization_mode_error(self.quantization_mode)
return super().quantized_call(*args, **kwargs)

def _int8_call(self, inputs):
def _int8_call(self, inputs, training=None):
# We cannot update quantized self._embeddings, so the custom gradient is
# not needed
if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"):
@@ -345,22 +344,17 @@ def _int8_call(self, inputs):
return outputs

def quantize(self, mode, type_check=True):
import gc

# Prevent quantization of the subclasses
if type_check and (type(self) is not Embedding):
raise self._quantize_not_implemented_error()
self._check_quantize_args(mode, self.compute_dtype)
raise self._not_implemented_error(self.quantize)

self._tracker.unlock()
if mode == "int8":
# Quantize `self._embeddings` to int8 and compute corresponding
# scale
embeddings_value, embeddings_scale = quantizers.abs_max_quantize(
self._embeddings, axis=-1
self._embeddings, axis=-1, to_numpy=True
)
embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
self._untrack_variable(self._embeddings)
del self._embeddings
# Utilize a lambda expression as an initializer to prevent adding a
# large constant to the computation graph.
@@ -370,16 +364,12 @@ def quantize(self, mode, type_check=True):
)
else:
raise self._quantization_mode_error(mode)
self._tracker.lock()

# Set new dtype policy
if self.dtype_policy.quantization_mode is None:
policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
self.dtype_policy = policy

# Release memory manually because sometimes the backend doesn't
gc.collect()

def _get_embeddings_with_merged_lora(self):
if self.dtype_policy.quantization_mode is not None:
embeddings_value = self._embeddings
@@ -395,7 +385,9 @@ def _get_embeddings_with_merged_lora(self):
ops.matmul(self.lora_embeddings_a, self.lora_embeddings_b),
)
embeddings_value, embeddings_scale = (
quantizers.abs_max_quantize(embeddings_value, axis=-1)
quantizers.abs_max_quantize(
embeddings_value, axis=-1, to_numpy=True
)
)
embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
return embeddings_value, embeddings_scale
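
Same idea for `Embedding`, except the abs-max reduction runs over the last axis, so each vocabulary row gets its own scale; `_int8_call` then gathers rows from the int8 table and divides by the gathered per-row scales. A hedged NumPy sketch of that lookup, not the Keras implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(10, 8)).astype("float32")  # (vocab_size, output_dim)

# One scale per vocabulary row (reduction over the embedding dimension).
amax = np.max(np.abs(embeddings), axis=-1, keepdims=True)  # shape (10, 1)
scale = 127.0 / (amax + 1e-7)
emb_q = np.clip(np.round(embeddings * scale), -127, 127).astype(np.int8)
scale = np.squeeze(scale, axis=-1)                          # shape (10,)

ids = np.array([1, 4, 4, 7])
# Gather int8 rows, then dequantize each with its own scale.
out = emb_q[ids].astype("float32") / scale[ids][:, None]
print(np.max(np.abs(out - embeddings[ids])))  # small rounding error
```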