@@ -71,6 +71,7 @@ def quantize_scope(*args):
7171 'QuantizeWrapperV2' : quantize_wrapper .QuantizeWrapperV2 ,
7272 'QuantizeLayer' : quantize_layer .QuantizeLayer ,
7373 'OutputOnlyConfig' : quantize_config_mod .OutputOnlyConfig ,
74+ 'FixedQuantizeConfig' : quantize_config_mod .FixedQuantizeConfig ,
7475 }
7576 quantization_objects .update (default_8bit_quantize_registry ._types_dict ()) # pylint: disable=protected-access
7677 quantization_objects .update (default_n_bit_quantize_registry ._types_dict ()) # pylint: disable=protected-access
@@ -472,3 +473,169 @@ def _quantize(layer): # pylint: disable=missing-docstring
472473
473474 return keras .models .clone_model (
474475 transformed_model , input_tensors = None , clone_function = _quantize )
476+
477+
478+ def _unwrap_first_input_name (inbound_nodes ):
479+ """Unwrap inbound_nodes three times to get first input name.
480+
481+ Args:
482+ inbound_nodes: A str config that indicates input node. This method assumed
483+ the inbound_nodes looks like `[[['input', 0, 0, {}]]]`.
484+
485+ Returns:
486+ Returns a str name for the first inbound node.
487+ """
488+ current = inbound_nodes
489+
490+ for _ in range (3 ):
491+ if not current :
492+ return None
493+ if not isinstance (current , list ):
494+ return None
495+ current = current [0 ]
496+
497+ if isinstance (current , str ):
498+ return current
499+
500+ return None
501+
502+
def _wrap_fixed_range(
    quantize_config, num_bits, init_min, init_max, narrow_range):
  """Wraps a quantize config in a serialized `FixedQuantizeConfig`.

  Args:
    quantize_config: The quantize config to wrap; passed through unchanged as
      the `'config'` entry given to `FixedQuantizeConfig.from_config`.
    num_bits: Value forwarded as `'num_bits'`.
    init_min: Value forwarded as `'init_min'`.
    init_max: Value forwarded as `'init_max'`.
    narrow_range: Value forwarded as `'narrow_range'`.

  Returns:
    The result of `tf.keras.utils.serialize_keras_object` applied to the
    newly built `FixedQuantizeConfig`.
  """
  fixed_config_kwargs = {
      'config': quantize_config,
      'num_bits': num_bits,
      'init_min': init_min,
      'init_max': init_max,
      'narrow_range': narrow_range,
  }
  fixed_config = quantize_config_mod.FixedQuantizeConfig.from_config(
      fixed_config_kwargs)
  return tf.keras.utils.serialize_keras_object(fixed_config)
512+
513+
514+ def _is_serialized_node_data (nested ):
515+ # Node data can be of form `[layer_name, node_id, tensor_id]` or
516+ # `[layer_name, node_id, tensor_id, kwargs]`.
517+ if (isinstance (nested , list ) and (len (nested ) in [3 , 4 ]) and
518+ isinstance (nested [0 ], str )):
519+ return True
520+ return False
521+
522+
def _nested_to_flatten_node_data_list(nested):
  """Flattens nested node data into a flat list of node data entries.

  Args:
    nested: A single serialized node data entry, or an arbitrarily nested
      structure of lists and dicts whose leaves are such entries.

  Returns:
    A list of node data entries in traversal order (dict values are visited
    in the dict's iteration order).

  Raises:
    ValueError: If a value is reached that is neither node data, a list, nor
      a dict.
  """
  if _is_serialized_node_data(nested):
    return [nested]

  if isinstance(nested, dict):
    children = nested.values()
  elif isinstance(nested, list):
    children = nested
  else:
    raise ValueError('{} is not a supported nested node data.'.format(nested))

  # Recurse into every child and concatenate the per-child results.
  return [
      node_data
      for child in children
      for node_data in _nested_to_flatten_node_data_list(child)
  ]
535+
536+
def fix_input_output_range(
    model,
    num_bits=8,
    input_min=0.0,
    input_max=1.0,
    output_min=0.0,
    output_max=1.0,
    narrow_range=False):
  """Fix the input and output ranges.

  In certain cases, a desired input/output range is known and should not be
  altered during training. This function rebuilds the model from its config
  with a `FixedQuantizer` on the input side and a `FixedQuantizeConfig`
  wrapped around the `quantize_config` of the output layer(s).

  Example:

  ```python
  model = keras.Sequential([
      layers.Dense(10, activation='relu', input_shape=(100,)),
      quantize_annotate_layer(layers.Dense(2, activation='sigmoid'))
  ])
  with quantize.quantize_scope():
    model = quantize_annotate_model(model)
    model = quantize_apply(model)
    model = fix_input_output_range(model, num_bits=4,
        input_min=0, input_max=15,
        output_min=0, output_max=15,
        narrow_range=False)
  ```

  Args:
    model: A `tf.keras` Sequential or Functional model which has been quantized.
    num_bits: Number of bits for quantization
    input_min: The lower end of quantization interval for the input.
    input_max: The upper end of quantization interval for the input.
    output_min: The lower end of quantization interval for the output.
    output_max: The upper end of quantization interval for the output.
    narrow_range: In case of 8 bits, narrow_range nudges the quantized range
      to be [-127, 127] instead of [-128, 127]. This ensures symmetric
      range has 0 as the centre.

  Returns:
    Returns a new `tf.keras` model fixed input range set to (input_min,
    input_max) and fixed output range set to (output_min, output_max).

  Raises:
    ValueError: If `model` is not a functional model and the second entry of
      its serialized layer list is missing or is not a `QuantizeLayer`.
  """
  config = model.get_config()
  fixed_input_quantizer = quantizers.FixedQuantizer(
      num_bits=num_bits,
      init_min=input_min,
      init_max=input_max,
      narrow_range=narrow_range)
  serialized_fixed_input_quantizer = tf.keras.utils.serialize_keras_object(
      fixed_input_quantizer)

  def _fix_output_range(inner_config):
    """Wraps `inner_config['quantize_config']` with the fixed output range."""
    # Layers without a `quantize_config` entry are left untouched.
    if 'quantize_config' in inner_config:
      inner_config['quantize_config'] = _wrap_fixed_range(
          inner_config['quantize_config'],
          num_bits=num_bits,
          init_min=output_min,
          init_max=output_max,
          narrow_range=narrow_range)

  if _is_functional_model(model):
    # Node data entries are `[layer_name, ...]`; only the names are needed.
    input_names = {
        node_data[0]
        for node_data in _nested_to_flatten_node_data_list(
            config['input_layers'])
    }
    output_names = {
        node_data[0]
        for node_data in _nested_to_flatten_node_data_list(
            config['output_layers'])
    }

    for layer_config in config['layers']:
      # A layer whose first inbound node is a model input gets the fixed
      # input quantizer.
      first_input_name = _unwrap_first_input_name(
          layer_config['inbound_nodes'])
      if first_input_name is not None and first_input_name in input_names:
        layer_config['config']['quantizer'] = serialized_fixed_input_quantizer

      if layer_config['config']['name'] in output_names:
        _fix_output_range(layer_config['config'])

    model = keras.Model.from_config(config)
  else:
    layer_configs = config['layers']
    # The input quantizer is expected at index 1 (presumably right after the
    # input layer -- confirm), so at least two entries are required. Checking
    # `< 2` keeps a one-layer config from raising IndexError on the lookup.
    if (len(layer_configs) < 2 or
        layer_configs[1]['class_name'] != 'QuantizeLayer'):
      raise ValueError('`model` should be already quantized.')
    layer_configs[1]['config']['quantizer'] = serialized_fixed_input_quantizer
    _fix_output_range(layer_configs[-1]['config'])

    model = keras.Sequential.from_config(config)

  return model
636+
637+
def _is_functional_model(model):
  """Checks whether `model` is a non-Sequential Keras graph-network model.

  Args:
    model: The object to inspect.

  Returns:
    False if `model` is not a `keras.Model` or is a `keras.Sequential`;
    otherwise the value of `model._is_graph_network`.
  """
  if not isinstance(model, keras.Model):
    return False
  if isinstance(model, keras.Sequential):
    return False
  return model._is_graph_network  # pylint: disable=protected-access
0 commit comments