21
21
22
22
from .._utils import get_output_names
23
23
from .internal_graph import InternalTorchIRGraph , InternalTorchIRNode
24
- from .ops import convert_nodes
24
+ from .ops import TorchFrontend , convert_nodes
25
25
from .quantization_ops import _dequantized_weight
26
26
from .torch_op_registry import _TORCH_OPS_REGISTRY
27
27
from .torchir_passes import (
@@ -194,8 +194,13 @@ class TranscriptionContext:
194
194
context when stepping out.
195
195
"""
196
196
197
- def __init__ (self , name : Optional [str ] = None ) -> None :
197
+ def __init__ (
198
+ self ,
199
+ name : Optional [str ] = None ,
200
+ frontend : TorchFrontend = TorchFrontend .TORCHSCRIPT ,
201
+ ) -> None :
198
202
self .name = name if name else ""
203
+ self .frontend = frontend
199
204
self ._current_graph = [{}]
200
205
self ._torch_graph = None
201
206
self ._quant_context = QuantizationContext (self )
@@ -346,6 +351,7 @@ def __init__(
346
351
self ._prog = Program ()
347
352
348
353
if isinstance (loaded_model , torch .jit .ScriptModule ):
354
+ self .context .frontend = TorchFrontend .TORCHSCRIPT
349
355
self .graph , self .params_dict , self .buffer_dict = InternalTorchIRGraph .from_torchscript (
350
356
torchscript = loaded_model , input_values = self .inputs , cut_at_symbols = cut_at_symbols
351
357
)
@@ -363,6 +369,7 @@ def __init__(
363
369
p (self .graph )
364
370
365
371
elif _HAS_TORCH_EXPORT_API and isinstance (loaded_model , ExportedProgram ):
372
+ self .context .frontend = TorchFrontend .EDGEIR
366
373
self .graph = InternalTorchIRGraph .from_edgeir (edgeir = loaded_model )
367
374
self .params_dict , self .buffer_dict = None , None
368
375
else :
0 commit comments