torch._dynamo.exc.Unsupported: Unsupported: quantized nyi in meta tensors with fake tensor propagation. #8727

Closed
@gpchowdari

Description

🐛 Describe the bug

The export call ahead of to_edge_transform_and_lower throws the error above when quantized tensors are passed as its example inputs.

Sample program to reproduce:


import torch

from torch.export import export, export_for_training, ExportedProgram
from torch.export.dynamic_shapes import Dim
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # in_features matches the last dim of x (80), per the traced graph in the logs
        self.fc = torch.nn.Linear(80, 10)

    def forward(self, x, ilens):
        xd = self.fc(x)
        il = ilens
        return xd, il


def xnnpack_quantize_n_convert(model, x, ilens):
    with torch.no_grad():
        scale = 0.1
        zero_point = 0
        x_q = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=torch.qint8)
        ilens_q = torch.quantize_per_tensor(ilens.float(), scale=scale, zero_point=zero_point, dtype=torch.qint8)
        nq_si = (x, ilens)     # float example inputs
        q_si = (x_q, ilens_q)  # quantized example inputs
        seq_len_dim = Dim("seq_len", min=1, max=1024)
        dynamic_shapes = {
            "x": {0: 1, 1: seq_len_dim},
            "ilens": {0: 1},
        }
        m = export_for_training(model, (x, ilens), dynamic_shapes=dynamic_shapes).module()

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True)
        quantizer.set_global(operator_config)
        m = prepare_pt2e(m, quantizer)
        m(x, ilens)  # calibration pass
        m = convert_pt2e(m)

        # Fails here: export cannot fake-tensorize the quantized inputs in q_si.
        exported_program: ExportedProgram = export(mod=m, args=q_si, dynamic_shapes=dynamic_shapes)
        edge = to_edge_transform_and_lower(
            exported_program,
            compile_config=EdgeCompileConfig(_check_ir_validity=True),
            partitioner=[XnnpackPartitioner()],
        )

        exec_prog = edge.to_executorch()
        print("DONE")


m = MyModule().to("cpu")
m.eval()
x = torch.randn(1, 337, 80)
ilens = torch.randn(1)
xnnpack_quantize_n_convert(m, x, ilens)
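
For comparison, a minimal sketch of the flow that avoids the error (an assumption about the intended PT2E usage, not a confirmed fix): after convert_pt2e, the quantize/dequantize ops are already embedded in the graph, so the failing export call inside xnnpack_quantize_n_convert can be given the original float example inputs (nq_si) instead of the quantized ones.

        # Hedged sketch (assumption): export the converted module with the float
        # example inputs; convert_pt2e has already inserted quantize/dequantize
        # ops into the graph, so no quantized tensor needs to cross the export
        # boundary.
        exported_program = export(mod=m, args=nq_si, dynamic_shapes=dynamic_shapes)
        edge = to_edge_transform_and_lower(
            exported_program,
            compile_config=EdgeCompileConfig(_check_ir_validity=True),
            partitioner=[XnnpackPartitioner()],
        )
        exec_prog = edge.to_executorch()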


Logs

V0226 15:33:00.389000 2568045 torch/_dynamo/convert_frame.py:1365] skipping: _wrapped_call_impl (reason: in skipfiles, file: python3.10/site-packages/torch/nn/modules/module.py)
V0226 15:33:00.389000 2568045 torch/_dynamo/convert_frame.py:1365] skipping: _call_impl (reason: in skipfiles, file: python3.10/site-packages/torch/nn/modules/module.py)
I0226 15:33:00.392000 2568045 torch/_dynamo/utils.py:1512] [0/0] ChromiumEventLogger initialized with id 56391616-de4c-468d-923a-88bf29db3beb
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0] torchdynamo start compiling forward /temp.py:24, stack (elided 4 frames):
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "/temp.py", line 67, in <module>
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     xnnpack_quantize_n_convert(m,x,ilens)
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "/temp.py", line 45, in xnnpack_quantize_n_convert
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     m = export_for_training(model, (x,ilens),dynamic_shapes= dynamic_shapes).module()
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/export/__init__.py", line 168, in export_for_training
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     return _export_for_training(
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/export/_trace.py", line 1017, in wrapper
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     ep = fn(*args, **kwargs)
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/export/exported_program.py", line 117, in wrapper
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     return fn(*args, **kwargs)
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/export/_trace.py", line 1944, in _export_for_training
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     export_artifact = export_func(  # type: ignore[operator]
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/export/_trace.py", line 1296, in _strict_export_lower_to_aten_ir
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     gm_torch_level = _export_to_torch_ir(
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/export/_trace.py", line 693, in _export_to_torch_ir
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     gm_torch_level, _ = torch._dynamo.export(
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1579, in inner
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     result_traced = opt_f(*args, **kwargs)
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     return self._call_impl(*args, **kwargs)
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     return forward_call(*args, **kwargs)
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/_dynamo/eval_frame.py", line 570, in _fn
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     return fn(*args, **kwargs)
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]   File "python3.10/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0]     return self._call_impl(*args, **kwargs)
V0226 15:33:00.396000 2568045 torch/_dynamo/convert_frame.py:941] [0/0] 
I0226 15:33:00.397000 2568045 torch/_dynamo/symbolic_convert.py:2746] [0/0] Step 1: torchdynamo start tracing forward /temp.py:24
I0226 15:33:00.398000 2568045 torch/fx/experimental/symbolic_shapes.py:3288] [0/0] create_env
V0226 15:33:00.408000 2568045 torch/_dynamo/variables/builder.py:2886] [0/0] wrap_to_fake L['x'] (1, 337, 80) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.DYNAMIC: 0>, <DimDynamic.DYNAMIC: 0>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[StrictMinMaxConstraint(warn_only=False, vr=VR[1, 1]), StrictMinMaxConstraint(warn_only=False, vr=VR[1, 1024]), None], constraint_strides=[None, None, None], view_base_context=None, tensor_source=LocalSource(local_name='x', is_input=True, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0226 15:33:00.459000 2568045 torch/fx/experimental/symbolic_shapes.py:5967] [0/0] _update_var_to_range s0 = VR[2, 1024] (update)
I0226 15:33:00.459000 2568045 torch/fx/experimental/symbolic_shapes.py:4542] [0/0] create_symbol s0 = 337 for L['x'].size()[1] [2, 1024] (_dynamo/variables/builder.py:2894 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0226 15:33:00.460000 2568045 torch/fx/experimental/symbolic_shapes.py:6781] [0/0] runtime_assert True == True [statically known]
V0226 15:33:00.463000 2568045 torch/_dynamo/output_graph.py:2194] [0/0] create_graph_input L_x_ L['x'] FakeTensor(..., size=(1, s0, 80)) at debug_level 0 before=False
V0226 15:33:00.464000 2568045 torch/_dynamo/variables/builder.py:2886] [0/0] wrap_to_fake L['ilens'] (1,) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.DYNAMIC: 0>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[StrictMinMaxConstraint(warn_only=False, vr=VR[1, 1])], constraint_strides=[None], view_base_context=None, tensor_source=LocalSource(local_name='ilens', is_input=True, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0226 15:33:00.465000 2568045 torch/_dynamo/output_graph.py:2194] [0/0] create_graph_input L_ilens_ L['ilens'] FakeTensor(..., size=(1,)) at debug_level 0 before=False
V0226 15:33:00.466000 2568045 torch/_dynamo/symbolic_convert.py:958] [0/0] [__trace_source] TRACE starts_line /temp.py:25 in forward (MyModule.forward)
V0226 15:33:00.466000 2568045 torch/_dynamo/symbolic_convert.py:958] [0/0] [__trace_source]             xd = self.fc(x)
V0226 15:33:00.469000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE LOAD_FAST self []
V0226 15:33:00.469000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE LOAD_ATTR fc [NNModuleVariable()]
V0226 15:33:00.470000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE LOAD_FAST x [NNModuleVariable()]
V0226 15:33:00.470000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [NNModuleVariable(), TensorVariable()]
V0226 15:33:00.478000 2568045 torch/fx/experimental/symbolic_shapes.py:6579] [0/0] eval Ne(s0, 1) == True [statically known]
V0226 15:33:00.480000 2568045 torch/fx/experimental/symbolic_shapes.py:6781] [0/0] runtime_assert True == True [statically known]
V0226 15:33:00.481000 2568045 torch/fx/experimental/symbolic_shapes.py:6579] [0/0] eval Eq(s0, 1) == False [statically known]
V0226 15:33:00.485000 2568045 torch/fx/experimental/symbolic_shapes.py:6579] [0/0] eval Ne(Mod(1, s0), 0) == True [statically known]
V0226 15:33:00.497000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE STORE_FAST xd [TensorVariable()]
V0226 15:33:00.497000 2568045 torch/_dynamo/symbolic_convert.py:958] [0/0] [__trace_source] TRACE starts_line /temp.py:26 in forward (MyModule.forward)
V0226 15:33:00.497000 2568045 torch/_dynamo/symbolic_convert.py:958] [0/0] [__trace_source]             il = ilens
V0226 15:33:00.498000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE LOAD_FAST ilens []
V0226 15:33:00.498000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE STORE_FAST il [TensorVariable()]
V0226 15:33:00.498000 2568045 torch/_dynamo/symbolic_convert.py:958] [0/0] [__trace_source] TRACE starts_line /temp.py:27 in forward (MyModule.forward)
V0226 15:33:00.498000 2568045 torch/_dynamo/symbolic_convert.py:958] [0/0] [__trace_source]             return xd,il
V0226 15:33:00.498000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE LOAD_FAST xd []
V0226 15:33:00.498000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE LOAD_FAST il [TensorVariable()]
V0226 15:33:00.499000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE BUILD_TUPLE 2 [TensorVariable(), TensorVariable()]
V0226 15:33:00.499000 2568045 torch/_dynamo/symbolic_convert.py:981] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TupleVariable(length=2)]
I0226 15:33:00.499000 2568045 torch/_dynamo/symbolic_convert.py:3067] [0/0] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
V0226 15:33:00.499000 2568045 torch/_dynamo/symbolic_convert.py:3071] [0/0] RETURN_VALUE triggered compile
V0226 15:33:00.499000 2568045 torch/_dynamo/output_graph.py:970] [0/0] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /temp.py, line 27 in forward>], graph_break=False)
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code] TRACED GRAPH
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code]  python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code]     def forward(self, L_x_: "f32[1, s0, 80][80*s0, 80, 1]cpu", L_ilens_: "f32[1][1]cpu"):
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code]         l_x_ = L_x_
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code]         l_ilens_ = L_ilens_
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code]         
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code]          # File: /temp.py:25 in forward, code: xd = self.fc(x)
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code]         xd: "f32[1, s0, 10][10*s0, 10, 1]cpu" = self.L__self___fc(l_x_);  l_x_ = None
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code]         return (xd, l_ilens_)
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code]         
V0226 15:33:00.502000 2568045 torch/_dynamo/output_graph.py:1355] [0/0] [__graph_code] 
I0226 15:33:00.503000 2568045 torch/_dynamo/output_graph.py:1462] [0/0] Step 2: calling compiler function dynamo_normalization_capturing_compiler
I0226 15:33:00.503000 2568045 torch/_dynamo/output_graph.py:1467] [0/0] Step 2: done compiler function dynamo_normalization_capturing_compiler
I0226 15:33:00.509000 2568045 torch/fx/experimental/symbolic_shapes.py:4670] [0/0] produce_guards
V0226 15:33:00.510000 2568045 torch/fx/experimental/symbolic_shapes.py:4890] [0/0] track_symint L['x'].size()[0] 1 None
V0226 15:33:00.510000 2568045 torch/fx/experimental/symbolic_shapes.py:4890] [0/0] track_symint L['x'].size()[1] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[1, 1024])
V0226 15:33:00.511000 2568045 torch/fx/experimental/symbolic_shapes.py:4890] [0/0] track_symint L['x'].size()[2] 80 None
V0226 15:33:00.511000 2568045 torch/fx/experimental/symbolic_shapes.py:4890] [0/0] track_symint L['x'].stride()[0] 80*s0 None
V0226 15:33:00.512000 2568045 torch/fx/experimental/symbolic_shapes.py:4890] [0/0] track_symint L['x'].stride()[1] 80 None
V0226 15:33:00.512000 2568045 torch/fx/experimental/symbolic_shapes.py:4890] [0/0] track_symint L['x'].stride()[2] 1 None
V0226 15:33:00.513000 2568045 torch/fx/experimental/symbolic_shapes.py:4890] [0/0] track_symint L['x'].storage_offset() 0 None
V0226 15:33:00.513000 2568045 torch/fx/experimental/symbolic_shapes.py:4890] [0/0] track_symint L['ilens'].size()[0] 1 None
V0226 15:33:00.513000 2568045 torch/fx/experimental/symbolic_shapes.py:4890] [0/0] track_symint L['ilens'].stride()[0] 1 None
V0226 15:33:00.514000 2568045 torch/fx/experimental/symbolic_shapes.py:4890] [0/0] track_symint L['ilens'].storage_offset() 0 None
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards] Python shape guard function:
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards] def ___make_guard_fn(IsNonOverlappingAndDenseIndicator, cast_symbool_to_symint_guardless, math, torch, ___check_type_id, ___check_obj_id, ___odict_getitem, ___key_to_id, ___dict_version, ___dict_contains, ___tuple_iterator_len, ___normalize_range_iter, ___tuple_iterator_getitem, ___get_torch_function_mode_stack_at, __math_isnan, __numpy_isnan, inf, __load_module, utils_device, device, ___from_numpy, ___as_tensor, inspect):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]     def guard(L):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         _var0 = L['x']
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         _var1 = _var0.size
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         _var2 = _var1()
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         if not (_var2[0] == 1):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]             return False
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         if not (_var2[2] == 80):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]             return False
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         _var3 = _var0.stride
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         _var4 = _var3()
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         _var5 = _var2[1]
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         if not (_var4[0] == 80 * _var5):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]             return False
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         if not (_var4[1] == 80):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]             return False
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         if not (_var4[2] == 1):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]             return False
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         if not (_var0.storage_offset() == 0):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]             return False
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         _var6 = L['ilens']
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         if not (_var6.size()[0] == 1):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]             return False
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         if not (_var6.stride()[0] == 1):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]             return False
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         if not (_var6.storage_offset() == 0):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]             return False
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         if not (2 <= _var5 and _var5 <= 1024):
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]             return False
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]         return True
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards]     return guard
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:1270] [0/0] [__guards] 
V0226 15:33:00.525000 2568045 torch/_dynamo/guards.py:2513] [0/0] [__guards] GUARDS:
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] 
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] TREE_GUARD_MANAGER:
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- RootGuardManager
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:491 in init_ambient_guards
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=1)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | +- TYPE_MATCH: ___check_type_id(L['x'], 110353158665552)                     # _dynamo/variables/builder.py:1744 in wrap_tensor
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | +- DYNAMIC_INDICES: ((L['x']._dynamo_dynamic_indices.issubset(set())) if hasattr(L['x'], '_dynamo_dynamic_indices') else True)  # _dynamo/variables/builder.py:1744 in wrap_tensor
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | +- GuardManager: source=L['self'], accessed_by=FrameLocalsGuardAccessor(key='self', framelocals_idx=0)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | +- ID_MATCH: ___check_obj_id(L['self'], 125454710324432)                   # _dynamo/output_graph.py:793 in register_attr_or_module
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | +- GuardManager: source=L['self'].__dict__, accessed_by=GetGenericDictGuardAccessor
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | | +- GuardManager: source=L['self'].training, accessed_by=DictGetItemGuardAccessor('training')
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | | | +- ID_MATCH: ___check_obj_id(L['self'].training, 110353092011008)          # _dynamo/output_graph.py:793 in register_attr_or_module
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | | +- GuardManager: source=L['self']._modules, accessed_by=DictGetItemGuardAccessor('_modules')
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | | | +- GuardManager: source=L['self'].fc, accessed_by=DictGetItemGuardAccessor('fc')
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | | | | +- ID_MATCH: ___check_obj_id(L['self'].fc, 125454710324384)                # xd = self.fc(x)  # e2e_pytorch/temp.py:25 in forward
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | | | | +- GuardManager: source=L['self'].fc.__dict__, accessed_by=GetGenericDictGuardAccessor
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | | | | | +- GuardManager: source=L['self'].fc.training, accessed_by=DictGetItemGuardAccessor('training')
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | | | | | | +- ID_MATCH: ___check_obj_id(L['self'].fc.training, 110353092011008)       # xd = self.fc(x)  # e2e_pytorch/temp.py:25 in forward
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | +- GuardManager: source=L['ilens'], accessed_by=FrameLocalsGuardAccessor(key='ilens', framelocals_idx=2)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | +- TYPE_MATCH: ___check_type_id(L['ilens'], 110353158665552)                 # _dynamo/variables/builder.py:1744 in wrap_tensor
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] | | +- DYNAMIC_INDICES: ((L['ilens']._dynamo_dynamic_indices.issubset(set())) if hasattr(L['ilens'], '_dynamo_dynamic_indices') else True)  # _dynamo/variables/builder.py:1744 in wrap_tensor
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- LAMBDA_GUARD: L['x'].size()[0] == 1  # (unknown source L['x'].size()[0], please file a bug)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- LAMBDA_GUARD: L['x'].size()[2] == 80  # (unknown source L['x'].size()[2], please file a bug)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- LAMBDA_GUARD: L['x'].stride()[0] == 80*L['x'].size()[1]  # (unknown source L['x'].stride()[0], please file a bug)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- LAMBDA_GUARD: L['x'].stride()[1] == 80  # (unknown source L['x'].stride()[1], please file a bug)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- LAMBDA_GUARD: L['x'].stride()[2] == 1  # (unknown source L['x'].stride()[2], please file a bug)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- LAMBDA_GUARD: L['x'].storage_offset() == 0  # (unknown source L['x'].storage_offset(), please file a bug)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- LAMBDA_GUARD: L['ilens'].size()[0] == 1  # (unknown source L['ilens'].size()[0], please file a bug)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- LAMBDA_GUARD: L['ilens'].stride()[0] == 1  # (unknown source L['ilens'].stride()[0], please file a bug)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- LAMBDA_GUARD: L['ilens'].storage_offset() == 0  # (unknown source L['ilens'].storage_offset(), please file a bug)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] +- LAMBDA_GUARD: 2 <= L['x'].size()[1] <= 1024  # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim)) and (_dynamo/variables/builder.py:2894 in <lambda>)
V0226 15:33:00.526000 2568045 torch/_dynamo/guards.py:2451] [0/0] [__guards] 
I0226 15:33:00.527000 2568045 torch/_dynamo/pgo.py:647] [0/0] put_code_state: no cache key, skipping
I0226 15:33:00.527000 2568045 torch/_dynamo/convert_frame.py:1059] [0/0] run_gc_after_compile: running gc
V0226 15:33:00.530000 2568045 torch/_dynamo/convert_frame.py:1365] skipping: _fn (reason: in skipfiles, file: python3.10/site-packages/torch/_dynamo/eval_frame.py)
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1609] Summary of dimension constraints:
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665] Dynamo captured graph:
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665] 
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665] class GraphModule(torch.nn.Module):
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665]     def forward(self, L_x_: "f32[1, s0, 80]", L_ilens_: "f32[1]"):
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665]         l_x_ = L_x_
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665]         l_ilens_ = L_ilens_
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665]         
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665]          # File: /temp.py:25 in forward, code: xd = self.fc(x)
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665]         xd: "f32[1, s0, 10]" = self.L__self___fc(l_x_);  l_x_ = None
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665]         return (xd, l_ilens_)
I0226 15:33:00.546000 2568045 torch/_dynamo/eval_frame.py:1665]         
V0226 15:33:00.591000 2568045 torch/_dynamo/convert_frame.py:1365] skipping: call_wrapped (reason: in skipfiles, file: python3.10/site-packages/torch/fx/graph_module.py)
V0226 15:33:00.591000 2568045 torch/_dynamo/convert_frame.py:1365] skipping: __call__ (reason: in skipfiles, file: python3.10/site-packages/torch/fx/graph_module.py)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0] torchdynamo start compiling forward <eval_with_key>.29:4, stack (elided 4 frames):
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "/temp.py", line 67, in <module>
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     xnnpack_quantize_n_convert(m,x,ilens)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "/temp.py", line 54, in xnnpack_quantize_n_convert
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     exported_program: ExportedProgram  = export(mod=m, args=q_si,dynamic_shapes = dynamic_shapes)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/export/__init__.py", line 368, in export
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     return _export(
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/export/_trace.py", line 1017, in wrapper
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     ep = fn(*args, **kwargs)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/export/exported_program.py", line 117, in wrapper
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     return fn(*args, **kwargs)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/export/_trace.py", line 2079, in _export
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     return _export_for_training(
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/export/_trace.py", line 1017, in wrapper
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     ep = fn(*args, **kwargs)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/export/exported_program.py", line 117, in wrapper
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     return fn(*args, **kwargs)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/export/_trace.py", line 1944, in _export_for_training
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     export_artifact = export_func(  # type: ignore[operator]
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/export/_trace.py", line 1296, in _strict_export_lower_to_aten_ir
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     gm_torch_level = _export_to_torch_ir(
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/export/_trace.py", line 693, in _export_to_torch_ir
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     gm_torch_level, _ = torch._dynamo.export(
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1579, in inner
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     result_traced = opt_f(*args, **kwargs)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     return self._call_impl(*args, **kwargs)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     return forward_call(*args, **kwargs)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/_dynamo/eval_frame.py", line 570, in _fn
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     return fn(*args, **kwargs)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/fx/graph_module.py", line 824, in call_wrapped
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     return self._wrapped_call(self, *args, **kwargs)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/fx/graph_module.py", line 387, in __call__
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]   File "python3.10/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0]     return self._call_impl(*args, **kwargs)
V0226 15:33:00.592000 2568045 torch/_dynamo/convert_frame.py:941] [1/0] 
I0226 15:33:00.594000 2568045 torch/_dynamo/symbolic_convert.py:2746] [1/0] Step 1: torchdynamo start tracing forward <eval_with_key>.29:4
I0226 15:33:00.595000 2568045 torch/fx/experimental/symbolic_shapes.py:3288] [1/0] create_env
V0226 15:33:00.596000 2568045 torch/_dynamo/variables/builder.py:2886] [1/0] wrap_to_fake L['x'] (1, 337, 80) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.DYNAMIC: 0>, <DimDynamic.DYNAMIC: 0>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[StrictMinMaxConstraint(warn_only=False, vr=VR[1, 1]), StrictMinMaxConstraint(warn_only=False, vr=VR[1, 1024]), None], constraint_strides=[None, None, None], view_base_context=None, tensor_source=LocalSource(local_name='x', is_input=True, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
W0226 15:33:00.597000 2568045 torch/_dynamo/utils.py:2512] [1/0] Unsupported: quantized nyi in meta tensors with fake tensor propagation.
I0226 15:33:00.597000 2568045 torch/_dynamo/convert_frame.py:1059] [1/0] run_gc_after_compile: running gc
Traceback (most recent call last):
  File "python3.10/site-packages/torch/_dynamo/utils.py", line 2507, in wrap_fake_exception
    return fn()
  File "python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2894, in <lambda>
    lambda: tx.fake_mode.from_tensor(
  File "python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2604, in from_tensor
    return self.fake_tensor_converter.from_real_tensor(
  File "python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 350, in from_real_tensor
    raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
torch._subclasses.fake_tensor.UnsupportedFakeTensorException: quantized nyi in meta tensors

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/temp.py", line 67, in <module>
    xnnpack_quantize_n_convert(m,x,ilens)
  File "/temp.py", line 54, in xnnpack_quantize_n_convert
    exported_program: ExportedProgram  = export(mod=m, args=q_si,dynamic_shapes = dynamic_shapes)
  File "python3.10/site-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "python3.10/site-packages/torch/export/_trace.py", line 1044, in wrapper
    raise e
  File "python3.10/site-packages/torch/export/_trace.py", line 1017, in wrapper
    ep = fn(*args, **kwargs)
  File "python3.10/site-packages/torch/export/exported_program.py", line 117, in wrapper
    return fn(*args, **kwargs)
  File "python3.10/site-packages/torch/export/_trace.py", line 2079, in _export
    return _export_for_training(
  File "python3.10/site-packages/torch/export/_trace.py", line 1044, in wrapper
    raise e
  File "python3.10/site-packages/torch/export/_trace.py", line 1017, in wrapper
    ep = fn(*args, **kwargs)
  File "python3.10/site-packages/torch/export/exported_program.py", line 117, in wrapper
    return fn(*args, **kwargs)
  File "python3.10/site-packages/torch/export/_trace.py", line 1944, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "python3.10/site-packages/torch/export/_trace.py", line 1296, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "python3.10/site-packages/torch/export/_trace.py", line 693, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1579, in inner
    result_traced = opt_f(*args, **kwargs)
  File "python3.10/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "python3.10/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
    return forward_call(*args, **kwargs)
  File "python3.10/site-packages/torch/_dynamo/eval_frame.py", line 570, in _fn
    return fn(*args, **kwargs)
  File "python3.10/site-packages/torch/fx/graph_module.py", line 824, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "python3.10/site-packages/torch/fx/graph_module.py", line 400, in __call__
    raise e
  File "python3.10/site-packages/torch/fx/graph_module.py", line 387, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "python3.10/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "python3.10/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
    return forward_call(*args, **kwargs)
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1400, in __call__
    return self._torchdynamo_orig_callable(
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 565, in __call__
    return _compile(
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 997, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 726, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 760, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1404, in transform_code_object
    transformations(instructions, code_options)
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 236, in _fn
    return fn(*args, **kwargs)
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 660, in transform
    tracer = InstructionTranslator(
  File "python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in __init__
    self.symbolic_locals = variables.LazyVariableTracker.realize_all(
  File "python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 139, in realize_all
    result = {k: cls.realize_all(v, cache) for k, v in list(value.items())}
  File "python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 139, in <dictcomp>
    result = {k: cls.realize_all(v, cache) for k, v in list(value.items())}
  File "python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 125, in realize_all
    result = cls.realize_all(value.realize(), cache)
  File "python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 67, in realize
    self._cache.realize()
  File "python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 33, in realize
    self.vt = VariableTracker.build(tx, self.value, source)
  File "python3.10/site-packages/torch/_dynamo/variables/base.py", line 456, in build
    return builder.VariableBuilder(tx, source)(value)
  File "python3.10/site-packages/torch/_dynamo/variables/builder.py", line 384, in __call__
    vt = self._wrap(value)
  File "python3.10/site-packages/torch/_dynamo/variables/builder.py", line 548, in _wrap
    return type_dispatch(self, value)
  File "python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1708, in wrap_tensor
    example_value = wrap_to_fake_tensor_and_record(
  File "python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2893, in wrap_to_fake_tensor_and_record
    fake_e = wrap_fake_exception(
  File "python3.10/site-packages/torch/_dynamo/utils.py", line 2513, in wrap_fake_exception
    unimplemented(msg, from_exc=e)
  File "python3.10/site-packages/torch/_dynamo/exc.py", line 379, in unimplemented
    raise Unsupported(msg, case_name=case_name) from from_exc
torch._dynamo.exc.Unsupported: Unsupported: quantized nyi in meta tensors with fake tensor propagation.

I0226 15:33:04.084000 2568109 torch/_dynamo/eval_frame.py:392] TorchDynamo attempted to trace the following frames: [
I0226 15:33:04.084000 2568109 torch/_dynamo/eval_frame.py:392] 
I0226 15:33:04.084000 2568109 torch/_dynamo/eval_frame.py:392] ]
I0226 15:33:04.094000 2568109 torch/_dynamo/utils.py:746] TorchDynamo compilation metrics:
I0226 15:33:04.094000 2568109 torch/_dynamo/utils.py:746] Function    Runtimes (s)
I0226 15:33:04.094000 2568109 torch/_dynamo/utils.py:746] ----------  --------------
V0226 15:33:04.094000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.095000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats defer_runtime_assert: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0226 15:33:04.095000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats evaluate_expr: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0226 15:33:04.095000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.096000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0226 15:33:04.096000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _find: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.097000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats has_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0226 15:33:04.097000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats size_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0226 15:33:04.097000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats simplify: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.098000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.098000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats replace: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.098000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.099000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats get_implications: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.099000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats get_axioms: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.099000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _maybe_evaluate_static_worker: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.100000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats safe_expand: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0226 15:33:04.100000 2568109 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats uninteresting_files: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
I0226 15:33:04.705000 2568045 torch/_dynamo/eval_frame.py:392] TorchDynamo attempted to trace the following frames: [
I0226 15:33:04.705000 2568045 torch/_dynamo/eval_frame.py:392]   * forward /temp.py:24
I0226 15:33:04.705000 2568045 torch/_dynamo/eval_frame.py:392]   * forward <eval_with_key>.29:4
I0226 15:33:04.705000 2568045 torch/_dynamo/eval_frame.py:392] ]
I0226 15:33:04.884000 2568045 torch/_dynamo/utils.py:746] TorchDynamo compilation metrics:
I0226 15:33:04.884000 2568045 torch/_dynamo/utils.py:746] Function                          Runtimes (s)
I0226 15:33:04.884000 2568045 torch/_dynamo/utils.py:746] ------------------------------  --------------
I0226 15:33:04.884000 2568045 torch/_dynamo/utils.py:746] _compile.compile_inner                  0.1343
I0226 15:33:04.884000 2568045 torch/_dynamo/utils.py:746] OutputGraph.call_user_compiler          0.0006
I0226 15:33:04.884000 2568045 torch/_dynamo/utils.py:746] gc                                      0.0012
V0226 15:33:04.884000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.884000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats defer_runtime_assert: CacheInfo(hits=20, misses=2, maxsize=256, currsize=2)
V0226 15:33:04.884000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats evaluate_expr: CacheInfo(hits=98, misses=8, maxsize=256, currsize=8)
V0226 15:33:04.884000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.885000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0226 15:33:04.885000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _find: CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
V0226 15:33:04.885000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats has_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0226 15:33:04.885000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats size_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0226 15:33:04.885000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats simplify: CacheInfo(hits=0, misses=4, maxsize=None, currsize=4)
V0226 15:33:04.885000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.886000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats replace: CacheInfo(hits=1266, misses=19, maxsize=None, currsize=19)
V0226 15:33:04.886000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=1, misses=4, maxsize=None, currsize=4)
V0226 15:33:04.886000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats get_implications: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.886000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats get_axioms: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0226 15:33:04.886000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats _maybe_evaluate_static_worker: CacheInfo(hits=0, misses=3, maxsize=None, currsize=3)
V0226 15:33:04.886000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats safe_expand: CacheInfo(hits=2, misses=5, maxsize=256, currsize=5)
V0226 15:33:04.887000 2568045 torch/fx/experimental/symbolic_shapes.py:164] lru_cache_stats uninteresting_files: CacheInfo(hits=62, misses=1, maxsize=None, currsize=1)
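
The root cause visible in the traceback is the FakeTensor conversion: from_real_tensor raises UnsupportedFakeTensorException for any quantized tensor, and export/dynamo wrap every input in a fake tensor before tracing. A minimal sketch of that trigger, independent of export (assuming FakeTensorMode.from_tensor behaves as in the traceback above):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Quantized tensors have no meta-tensor implementation, so fake tensor
# propagation cannot wrap them.
x_q = torch.quantize_per_tensor(
    torch.randn(1, 4), scale=0.1, zero_point=0, dtype=torch.qint8
)
FakeTensorMode().from_tensor(x_q)  # raises UnsupportedFakeTensorException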


Versions

PyTorch version: 2.7.0.dev20250131+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.31.4
Libc version: glibc-2.35

Python version: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.8.0-45-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True


Versions of relevant libraries:
[pip3] ai-edge-torch==0.3.0
[pip3] executorch==0.6.0a0+b5344c1
[pip3] numpy==2.1.3
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] optree==0.14.0
[pip3] torch==2.7.0.dev20250131+cpu
[pip3] torch_xla2==0.0.1.dev202412041639
[pip3] torchao==0.8.0+git11333ba2
[pip3] torchaudio==2.6.0.dev20250131+cpu
[pip3] torchsr==1.0.4
[pip3] torchvision==0.22.0.dev20250131+cpu
[pip3] triton==3.2.0

cc @digantdesai @mcr229 @cbilgin @mergennachin @byjlw

Metadata

Labels

module: user experience - Issues related to reducing friction for users
module: xnnpack - Issues related to xnnpack delegation and the code under backends/xnnpack/
need-user-input - The issue needs more information from the reporter before moving forward
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Status

Done