@@ -620,12 +620,28 @@ def save_bpte_program(exec_prog, original_model: torch.nn.Module, output_name: s
620620    save_bundled_program (exec_prog , method_test_suites , output_name )
621621
622622
def quantize_model(
    exported_program, args, model: torch.nn.Module, example_inputs, compile_spec
):
    """Quantize ``model`` and export the quantized module again.

    Args:
        exported_program: Not read by this function; a freshly exported
            program built from the quantized module is returned instead.
        args: Parsed command-line namespace. ``model_name``, ``evaluate`` and
            ``evaluate_config`` are forwarded to ``quantize``.
        model: Module handed to ``quantize``.
        example_inputs: Inputs passed both to ``quantize`` and to the export
            call.
        compile_spec: Compile spec forwarded to ``quantize``.

    Returns:
        A ``(quantized_module, exported_program)`` tuple, where
        ``exported_program`` is the result of
        ``torch.export.export_for_training`` on the quantized module.
    """
    quantized = quantize(
        model,
        args.model_name,
        compile_spec,
        example_inputs,
        args.evaluate,
        args.evaluate_config,
    )

    # The lowering steps consume an exported program, so the quantized
    # module is exported again (strict mode) before being handed back.
    reexported = torch.export.export_for_training(
        quantized, example_inputs, strict=True
    )
    return quantized, reexported
640+ 
641+ 
623642def  to_edge_TOSA_delegate (
624-     exported_program ,
625-     args ,
626-     model : torch .nn .Module ,
643+     exported_program , args , model : torch .nn .Module , example_inputs 
627644):
628-     model_int8  =  None 
629645    # As we can target multiple output encodings, one must 
630646    # be specified. 
631647    compile_spec  =  get_compile_spec (
@@ -634,23 +650,13 @@ def to_edge_TOSA_delegate(
634650        args .system_config ,
635651        args .memory_mode ,
636652    )
653+ 
654+     model_int8  =  None 
637655    if  args .quantize :
638-         model  =  quantize (
639-             model ,
640-             args .model_name ,
641-             compile_spec ,
642-             example_inputs ,
643-             args .evaluate ,
644-             args .evaluate_config ,
656+         model_int8 , exported_program  =  quantize_model (
657+             exported_program , args , model , example_inputs , compile_spec 
645658        )
646-         model_int8  =  model 
647-         # Wrap quantized model back into an exported_program 
648-         exported_program  =  torch .export .export_for_training (
649-             model , example_inputs , strict = True 
650-         )
651- 
652-         if  args .intermediates :
653-             os .makedirs (args .intermediates , exist_ok = True )
659+         model  =  model_int8 
654660
655661    if  is_ethosu (compile_spec ):
656662        partitioner  =  EthosUPartitioner (compile_spec )
@@ -669,6 +675,31 @@ def to_edge_TOSA_delegate(
669675    return  model_int8 , edge 
670676
671677
def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_inputs):
    """Lower ``exported_program`` to edge without passing any partitioner.

    When ``args.quantize`` is truthy, the model is first quantized via
    ``quantize_model`` and the exported program it returns is lowered in
    place of the one passed in.

    Args:
        exported_program: Program to lower when no quantization is requested.
        args: Parsed command-line namespace. Reads ``quantize`` and, when
            quantizing, ``target``, ``intermediates``, ``system_config`` and
            ``memory_mode``.
        model: Module to quantize when ``args.quantize`` is set.
        example_inputs: Example inputs forwarded to ``quantize_model``.

    Returns:
        A ``(model_int8, edge)`` tuple. ``model_int8`` is ``None`` unless
        quantization was requested, in which case it is the quantized module.
    """
    quantized = None
    if args.quantize:
        # Quantization needs a compile spec; since several output encodings
        # can be targeted, one is built explicitly from the CLI arguments.
        spec = get_compile_spec(
            args.target,
            args.intermediates,
            args.system_config,
            args.memory_mode,
        )
        quantized, exported_program = quantize_model(
            exported_program, args, model, example_inputs, spec
        )

    # IR validity checking is switched off for this lowering.
    lowering_config = EdgeCompileConfig(_check_ir_validity=False)
    edge = to_edge_transform_and_lower(
        exported_program,
        compile_config=lowering_config,
    )
    return quantized, edge
701+ 
702+ 
672703if  __name__  ==  "__main__" :  # noqa: C901 
673704    args  =  get_args ()
674705
@@ -686,16 +717,18 @@ def to_edge_TOSA_delegate(
686717    model  =  exported_program .module ()
687718    model_fp32  =  model 
688719
720+     if  args .intermediates :
721+         os .makedirs (args .intermediates , exist_ok = True )
722+ 
689723    # Quantize if required 
690724    model_int8  =  None 
691725    if  args .delegate :
692-         model_int8 , edge  =  to_edge_TOSA_delegate (exported_program , args , model )
726+         model_int8 , edge  =  to_edge_TOSA_delegate (
727+             exported_program , args , model , example_inputs 
728+         )
693729    else :
694-         edge  =  to_edge_transform_and_lower (
695-             exported_program ,
696-             compile_config = EdgeCompileConfig (
697-                 _check_ir_validity = False ,
698-             ),
730+         model_int8 , edge  =  to_edge_no_delegate (
731+             exported_program , args , model , example_inputs 
699732        )
700733
701734    dump_delegation_info (edge , args .intermediates )
0 commit comments