Simplify the implementation of quantization-related methods (#19954)
* Simplify the implementation of `quantize`

* Address comments

* Add compatibility test

* Remove time.sleep and CompatibilityTest

* Save memory on the device by using numpy for `quantize`

* Increase test coverage
james77777778 authored Jul 23, 2024
1 parent da72acd commit 3ac43b1
Showing 9 changed files with 163 additions and 106 deletions.
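Note on the memory-saving change: passing `to_numpy=True` keeps the abs-max quantization math on the host, so the intermediate scale and quantized tensors never have to materialize on the accelerator. Below is a minimal NumPy sketch of that abs-max int8 scheme; the function name and the epsilon guard are illustrative, not the Keras internals.

```python
import numpy as np

def abs_max_quantize_to_int8(weights, axis=0, value_range=(-127, 127)):
    """Illustrative host-side abs-max quantization (not the Keras internals).

    scale = 127 / max(|w|) per slice, q = round(w * scale) clipped to int8;
    dequantization is then simply q / scale.
    """
    abs_max = np.max(np.abs(weights), axis=axis, keepdims=True)
    scale = value_range[1] / (abs_max + 1e-7)  # epsilon guards against /0
    quantized = np.clip(np.round(weights * scale), *value_range).astype("int8")
    return quantized, scale

# Example: per-output-channel quantization of a (16, 8) float32 kernel.
kernel = np.random.randn(16, 8).astype("float32")
q_kernel, scale = abs_max_quantize_to_int8(kernel, axis=0)
print(q_kernel.dtype, scale.shape)  # int8 (1, 8)
```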
3 changes: 2 additions & 1 deletion keras/src/backend/torch/layer.py
@@ -58,4 +58,5 @@ def _post_track_variable(self, variable):
 
     def _post_untrack_variable(self, variable):
         if hasattr(self, "torch_params"):
-            self.torch_params.pop(variable.path)
+            if variable.path in self.torch_params:
+                self.torch_params.pop(variable.path)
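The extra membership check makes untracking a no-op when a variable's path was never registered in `torch_params` (which can happen once `quantize()` starts swapping variables out). A standalone illustration of the guarded-pop pattern, unrelated to the Keras classes:

```python
# Standalone illustration of the guarded pop: removing a key that may be
# absent should be a no-op rather than raise KeyError.
torch_params = {"dense/kernel": object()}

path = "dense/kernel_scale"  # never tracked, e.g. created during quantize()
if path in torch_params:
    torch_params.pop(path)  # safe: only pops when the key exists
# An equivalent one-liner would be: torch_params.pop(path, None)
```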
34 changes: 5 additions & 29 deletions keras/src/layers/core/dense.py
@@ -300,13 +300,6 @@ def _check_load_own_variables(self, store):
 
     # Quantization-related (int8 and float8) methods
 
-    def _quantization_mode_error(self, mode):
-        return NotImplementedError(
-            "Invalid quantization mode. Expected one of "
-            f"{dtype_policies.QUANTIZATION_MODES}. "
-            f"Received: quantization_mode={mode}"
-        )
-
     def quantized_build(self, input_shape, mode):
         if mode == "int8":
             input_dim = input_shape[-1]
@@ -390,15 +383,7 @@ def _float8_build(self):
         self.outputs_grad_amax_history.overwrite_with_gradient = True
         self._is_quantized = True
 
-    def quantized_call(self, inputs, training=None):
-        if self.quantization_mode == "int8":
-            return self._int8_call(inputs)
-        elif self.quantization_mode == "float8":
-            return self._float8_call(inputs, training=training)
-        else:
-            raise self._quantization_mode_error(self.quantization_mode)
-
-    def _int8_call(self, inputs):
+    def _int8_call(self, inputs, training=None):
         @ops.custom_gradient
         def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
             def grad_fn(*args, upstream=None):
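With the per-layer `quantized_call` dispatcher deleted, the mode dispatch presumably lives once in the base `Layer`, and `_int8_call` picks up a `training` argument so every quantized call path shares one signature. A hedged sketch of what such a shared dispatcher could look like (the actual base-class code is not part of this diff):

```python
class QuantizedCallMixin:
    """Hypothetical sketch of a shared dispatcher; not the actual Keras base
    Layer code, which is outside this diff."""

    def quantized_call(self, *args, **kwargs):
        if self.quantization_mode == "int8":
            return self._int8_call(*args, **kwargs)
        elif self.quantization_mode == "float8":
            return self._float8_call(*args, **kwargs)
        else:
            raise self._quantization_mode_error(self.quantization_mode)
```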
@@ -525,22 +510,17 @@ def grad(*args, upstream=None, variables=None):
         return x
 
     def quantize(self, mode, type_check=True):
-        import gc
-
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Dense):
-            raise self._quantize_not_implemented_error()
-        self._check_quantize_args(mode, self.compute_dtype)
+            raise self._not_implemented_error(self.quantize)
 
-        self._tracker.unlock()
         if mode == "int8":
             # Quantize `self._kernel` to int8 and compute corresponding scale
             kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=0
+                self._kernel, axis=0, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
-            self._untrack_variable(self._kernel)
-            kernel_shape = self._kernel.shape
+            kernel_shape = tuple(self._kernel.shape)
             del self._kernel
             # Utilize a lambda expression as an initializer to prevent adding a
             # large constant to the computation graph.
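The comment above refers to the lambda-initializer trick: the quantized NumPy array is returned from a closure rather than passed as a literal, so tracing backends do not bake a large constant into the graph. A self-contained sketch of the idea; the commented `add_weight` call is illustrative:

```python
import numpy as np

# Precomputed on the host; in the real layer this would be the int8 kernel
# returned by `quantizers.abs_max_quantize(..., to_numpy=True)`.
quantized_kernel = np.random.randint(-127, 128, size=(512, 256), dtype=np.int8)

def kernel_initializer(shape, dtype=None):
    # Ignores shape/dtype and simply hands back the precomputed array, so it
    # is pulled in lazily instead of being inlined as a graph constant.
    return quantized_kernel

# Illustrative variable creation (argument names follow Keras conventions):
# self._kernel = self.add_weight(
#     shape=quantized_kernel.shape,
#     initializer=kernel_initializer,  # or: lambda shape, dtype: quantized_kernel
#     dtype="int8",
#     trainable=False,
# )
```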
@@ -553,16 +533,12 @@ def quantize(self, mode, type_check=True):
             self._float8_build()
         else:
             raise self._quantization_mode_error(mode)
-        self._tracker.lock()
 
         # Set new dtype policy
         if self.dtype_policy.quantization_mode is None:
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
-        # Release memory manually because sometimes the backend doesn't
-        gc.collect()
-
     def _get_kernel_with_merged_lora(self):
         if self.dtype_policy.quantization_mode is not None:
             kernel_value = self._kernel
@@ -576,7 +552,7 @@ def _get_kernel_with_merged_lora(self):
                     ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
                 )
                 kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                    kernel_value, axis=0
+                    kernel_value, axis=0, to_numpy=True
                 )
                 kernel_scale = ops.squeeze(kernel_scale, axis=0)
             return kernel_value, kernel_scale
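From user code the workflow is unchanged by this refactor: build a layer, call `quantize("int8")`, and the dtype policy records the original compute dtype. A small usage sketch assuming Keras 3's public API; the printed policy name depends on the layer's original dtype policy:

```python
import numpy as np
from keras import layers

# Build a small float32 Dense layer, then quantize it in place to int8.
layer = layers.Dense(8)
layer.build((None, 16))
layer.quantize("int8")

# The layer now holds an int8 kernel plus a per-output-channel scale, and the
# dtype policy records the original compute dtype.
print(layer.dtype_policy.name)  # e.g. "int8_from_float32"

y = layer(np.random.rand(2, 16).astype("float32"))
print(y.shape)  # (2, 8)
```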
34 changes: 5 additions & 29 deletions keras/src/layers/core/einsum_dense.py
@@ -360,13 +360,6 @@ def _check_load_own_variables(self, store):
 
     # Quantization-related (int8 and float8) methods
 
-    def _quantization_mode_error(self, mode):
-        return NotImplementedError(
-            "Invalid quantization mode. Expected one of "
-            f"{dtype_policies.QUANTIZATION_MODES}. "
-            f"Received: quantization_mode={mode}"
-        )
-
     def quantized_build(self, input_shape, mode):
         if mode == "int8":
             shape_data = _analyze_einsum_string(
@@ -477,15 +470,7 @@ def _float8_build(self):
         self.outputs_grad_amax_history.overwrite_with_gradient = True
         self._is_quantized = True
 
-    def quantized_call(self, inputs, training=None):
-        if self.quantization_mode == "int8":
-            return self._int8_call(inputs)
-        elif self.quantization_mode == "float8":
-            return self._float8_call(inputs, training=training)
-        else:
-            raise self._quantization_mode_error(self.quantization_mode)
-
-    def _int8_call(self, inputs):
+    def _int8_call(self, inputs, training=None):
         @ops.custom_gradient
         def einsum_with_inputs_gradient(inputs, kernel, kernel_scale):
             def grad_fn(*args, upstream=None):
@@ -641,14 +626,10 @@ def grad(*args, upstream=None, variables=None):
         return x
 
     def quantize(self, mode, type_check=True):
-        import gc
-
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not EinsumDense):
-            raise self._quantize_not_implemented_error()
-        self._check_quantize_args(mode, self.compute_dtype)
+            raise self._not_implemented_error(self.quantize)
 
-        self._tracker.unlock()
         if mode == "int8":
             (
                 self._input_reduced_axes,
@@ -664,7 +645,7 @@ def quantize(self, mode, type_check=True):
             ) = _analyze_quantization_info(self.equation, self.input_spec.ndim)
             # Quantize `self._kernel` to int8 and compute corresponding scale
             kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=self._kernel_reduced_axes
+                self._kernel, axis=self._kernel_reduced_axes, to_numpy=True
             )
             kernel_scale = ops.transpose(
                 kernel_scale, self._kernel_transpose_axes
@@ -677,8 +658,7 @@ def quantize(self, mode, type_check=True):
                 kernel_scale = ops.squeeze(
                     kernel_scale, axis=self._kernel_squeeze_axes
                 )
-            self._untrack_variable(self._kernel)
-            kernel_shape = self._kernel.shape
+            kernel_shape = tuple(self._kernel.shape)
             del self._kernel
             # Utilize a lambda expression as an initializer to prevent adding a
             # large constant to the computation graph.
@@ -691,16 +671,12 @@ def quantize(self, mode, type_check=True):
             self._float8_build()
         else:
             raise self._quantization_mode_error(mode)
-        self._tracker.lock()
 
         # Set new dtype policy
         if self.dtype_policy.quantization_mode is None:
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
-        # Release memory manually because sometimes the backend doesn't
-        gc.collect()
-
     def _get_kernel_with_merged_lora(self):
         if self.dtype_policy.quantization_mode is not None:
             kernel_value = self._kernel
@@ -732,7 +708,7 @@ def _argsort(seq):
                     ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
                 )
                 kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                    kernel_value, axis=self._kernel_reduced_axes
+                    kernel_value, axis=self._kernel_reduced_axes, to_numpy=True
                 )
                 kernel_scale = ops.transpose(
                     kernel_scale, self._kernel_transpose_axes
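For `EinsumDense` the kernel can have more than two dimensions, so the scale is reduced over the equation's contraction axes and later transposed/squeezed to line up with the output. A NumPy sketch of abs-max quantization over an arbitrary set of reduced axes, with illustrative shapes:

```python
import numpy as np

def abs_max_quantize_over_axes(kernel, reduced_axes, value_range=(-127, 127)):
    # Reduce |kernel| over the contraction axes only; the resulting scale
    # broadcasts over the remaining (output) axes.
    abs_max = np.max(np.abs(kernel), axis=tuple(reduced_axes), keepdims=True)
    scale = value_range[1] / (abs_max + 1e-7)
    quantized = np.clip(np.round(kernel * scale), *value_range).astype("int8")
    return quantized, scale

# Illustrative 3D einsum kernel (input_dim=16, heads=4, head_dim=32) where
# axis 0 is contracted, as in an equation like "bi,ihd->bhd".
kernel = np.random.randn(16, 4, 32).astype("float32")
q_kernel, scale = abs_max_quantize_over_axes(kernel, reduced_axes=(0,))
print(q_kernel.shape, scale.shape)  # (16, 4, 32) (1, 4, 32)
```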
26 changes: 9 additions & 17 deletions keras/src/layers/core/embedding.py
@@ -320,13 +320,12 @@ def _int8_build(
         )
         self._is_quantized = True
 
-    def quantized_call(self, inputs):
-        if self.quantization_mode == "int8":
-            return self._int8_call(inputs)
-        else:
+    def quantized_call(self, *args, **kwargs):
+        if self.quantization_mode != "int8":
             raise self._quantization_mode_error(self.quantization_mode)
+        return super().quantized_call(*args, **kwargs)
 
-    def _int8_call(self, inputs):
+    def _int8_call(self, inputs, training=None):
         # We cannot update quantized self._embeddings, so the custom gradient is
         # not needed
         if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"):
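Embedding's int8 path is a lookup rather than a matmul: rows are gathered from the frozen int8 table and dequantized with their per-row scales, which is why no custom gradient is needed. A NumPy sketch with illustrative shapes:

```python
import numpy as np

# Illustrative int8 embedding lookup: gather rows from the quantized table,
# then dequantize each row with its own scale (computed over axis=-1).
table = np.random.randn(10, 4).astype("float32")          # vocab=10, dim=4
abs_max = np.max(np.abs(table), axis=-1, keepdims=True)
scale = 127.0 / (abs_max + 1e-7)                           # one scale per row
q_table = np.clip(np.round(table * scale), -127, 127).astype("int8")

token_ids = np.array([3, 3, 7])
outputs = q_table[token_ids] / scale[token_ids]            # gather + dequantize
print(np.allclose(outputs, table[token_ids], atol=1e-1))   # close to original
```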
@@ -345,22 +344,17 @@ def _int8_call(self, inputs):
         return outputs
 
     def quantize(self, mode, type_check=True):
-        import gc
-
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Embedding):
-            raise self._quantize_not_implemented_error()
-        self._check_quantize_args(mode, self.compute_dtype)
+            raise self._not_implemented_error(self.quantize)
 
-        self._tracker.unlock()
         if mode == "int8":
             # Quantize `self._embeddings` to int8 and compute corresponding
             # scale
             embeddings_value, embeddings_scale = quantizers.abs_max_quantize(
-                self._embeddings, axis=-1
+                self._embeddings, axis=-1, to_numpy=True
             )
             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
-            self._untrack_variable(self._embeddings)
             del self._embeddings
             # Utilize a lambda expression as an initializer to prevent adding a
             # large constant to the computation graph.
@@ -370,16 +364,12 @@ def quantize(self, mode, type_check=True):
             )
         else:
             raise self._quantization_mode_error(mode)
-        self._tracker.lock()
 
         # Set new dtype policy
         if self.dtype_policy.quantization_mode is None:
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
-        # Release memory manually because sometimes the backend doesn't
-        gc.collect()
-
     def _get_embeddings_with_merged_lora(self):
         if self.dtype_policy.quantization_mode is not None:
             embeddings_value = self._embeddings
@@ -395,7 +385,9 @@ def _get_embeddings_with_merged_lora(self):
                     ops.matmul(self.lora_embeddings_a, self.lora_embeddings_b),
                 )
                 embeddings_value, embeddings_scale = (
-                    quantizers.abs_max_quantize(embeddings_value, axis=-1)
+                    quantizers.abs_max_quantize(
+                        embeddings_value, axis=-1, to_numpy=True
+                    )
                 )
                 embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             return embeddings_value, embeddings_scale
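Merging LoRA weights into an already-quantized table follows the dequantize, add `A @ B`, re-quantize pattern shown above; because the sum is re-quantized, the merge is necessarily lossy. A NumPy sketch with illustrative names and shapes:

```python
import numpy as np

def merge_lora_into_quantized(q_table, scale, lora_a, lora_b):
    """Illustrative LoRA merge for a quantized table: dequantize, add the
    low-rank update A @ B, then re-quantize (a lossy round trip)."""
    dequantized = q_table.astype("float32") / scale
    merged = dequantized + lora_a @ lora_b
    abs_max = np.max(np.abs(merged), axis=-1, keepdims=True)
    new_scale = 127.0 / (abs_max + 1e-7)
    q_merged = np.clip(np.round(merged * new_scale), -127, 127).astype("int8")
    return q_merged, new_scale

# Example shapes: vocab=10, dim=4, LoRA rank=2.
q_table = np.random.randint(-127, 128, size=(10, 4), dtype=np.int8)
scale = np.full((10, 1), 50.0, dtype="float32")
lora_a = np.random.randn(10, 2).astype("float32") * 0.01
lora_b = np.random.randn(2, 4).astype("float32") * 0.01
q_merged, new_scale = merge_lora_into_quantized(q_table, scale, lora_a, lora_b)
```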