Skip to content

RuntimeError: ipex_prepack::convolution_prepack() Expected a value of type 'List[int]' for argument 'padding' but instead found type 'str'. #283

Open
@Wallfacer005CN

Description

@Wallfacer005CN

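padding='same' is supported by PyTorch, but ipex.optimize() fails on a model whose convolution uses it: weight prepacking passes the string 'same' straight to ipex_prepack::convolution_prepack(), which only accepts a list of ints for 'padding'.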

After running into too many bugs, I have decided to give up on intel-extension-for-pytorch.


RuntimeError Traceback (most recent call last)
Cell In[32], line 1
----> 1 models = train_models()

Cell In[31], line 27, in train_models()
25 model = ResNet50()
26 # model = ipex.optimize(model, dtype=torch.float32)
---> 27 trained_models = train_hyperparams(model)
29 models.extend(trained_models)
31 writer.close()

Cell In[31], line 14, in train_hyperparams(model)
12 # optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=1e-07)
13 optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
---> 14 model, optimizer = ipex.optimize(model, optimizer=optimizer)
15 loss_fn = nn.CrossEntropyLoss()
16 model, train_name, seconds = train_loop(train_loader, model, optimizer,loss_fn, loop)

File /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/frontend.py:294, in optimize(model, dtype, optimizer, level, inplace, conv_bn_folding, weights_prepack, replace_dropout_with_identity, optimize_lstm, split_master_weight_for_bf16, fuse_update_step, auto_kernel_selection, sample_input)
290 if dtype == torch.bfloat16:
291 assert core.onednn_has_bf16_support(), \
292 "BF16 weight prepack needs the cpu support avx512bw, avx512vl and avx512dq, " + \
293 "please set dtype to torch.float or set weights_prepack to False."
--> 294 optimized_model, optimized_optimizer, params_attr = utils._weight_prepack.weight_prepack_with_ipex(
295 optimized_model, optimized_optimizer, params_attr, opt_properties.auto_kernel_selection)
296 # TODO: model list, optimizer list.
297 if optimizer is None:

File /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py:331, in weight_prepack_with_ipex(module, optimizer, params_attr, auto_kernel_selection)
328 setattr(new_m, name, convert_rec(sub_m, optimizer, params_attr, auto_kernel_selection)[0])
329 return new_m, optimizer, params_attr
--> 331 opt_model, opt_optmizer, params_attr = convert_rec(module, optimizer, params_attr, auto_kernel_selection)
332 if opt_optmizer is not None:
333 setattr(opt_optmizer, 'params_attr', params_attr)

File /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py:328, in weight_prepack_with_ipex.<locals>.convert_rec(m, optimizer, params_attr, auto_kernel_selection)
326 new_m = convert(m, optimizer, params_attr, auto_kernel_selection)
327 for name, sub_m in m.named_children():
--> 328 setattr(new_m, name, convert_rec(sub_m, optimizer, params_attr, auto_kernel_selection)[0])
329 return new_m, optimizer, params_attr

File /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py:328, in weight_prepack_with_ipex.<locals>.convert_rec(m, optimizer, params_attr, auto_kernel_selection)
326 new_m = convert(m, optimizer, params_attr, auto_kernel_selection)
327 for name, sub_m in m.named_children():
--> 328 setattr(new_m, name, convert_rec(sub_m, optimizer, params_attr, auto_kernel_selection)[0])
329 return new_m, optimizer, params_attr

[... skipping similar frames: weight_prepack_with_ipex.<locals>.convert_rec at line 328 (1 times)]

File /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py:328, in weight_prepack_with_ipex.<locals>.convert_rec(m, optimizer, params_attr, auto_kernel_selection)
326 new_m = convert(m, optimizer, params_attr, auto_kernel_selection)
327 for name, sub_m in m.named_children():
--> 328 setattr(new_m, name, convert_rec(sub_m, optimizer, params_attr, auto_kernel_selection)[0])
329 return new_m, optimizer, params_attr

File /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py:326, in weight_prepack_with_ipex.<locals>.convert_rec(m, optimizer, params_attr, auto_kernel_selection)
325 def convert_rec(m, optimizer, params_attr, auto_kernel_selection):
--> 326 new_m = convert(m, optimizer, params_attr, auto_kernel_selection)
327 for name, sub_m in m.named_children():
328 setattr(new_m, name, convert_rec(sub_m, optimizer, params_attr, auto_kernel_selection)[0])

File /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py:293, in weight_prepack_with_ipex.<locals>.convert(m, optimizer, params_attr, auto_kernel_selection)
291 if weight not in params_attr:
292 params_attr[weight] = {}
--> 293 new_m = IPEX_WEIGHT_PREPACK_MODULE[type(m)](m)
294 params_attr[weight].update({
295 'op': type(m),
296 'ctx': new_m.ctx})
297 if hasattr(new_m, "weight_channels_last"):

File /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py:86, in _IPEXConv2d.__init__(self, dense_module)
85 def __init__(self, dense_module):
---> 86 super(_IPEXConv2d, self).__init__(dense_module)

File /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py:37, in _IPEXConvNd.__init__(self, dense_module)
35 self.register_parameter('bias', None)
36 # create conv op context
---> 37 self.ctx = torch.ops.ipex_prepack.convolution_prepack(
38 dense_module.weight, self.bias, self.stride, self.padding,
39 self.dilation, self.groups,
40 self.weight_channels_last, self.prepack_input_shape
41 )
43 self.weight = nn.Parameter(self.ctx.get_weight(), requires_grad = dense_module.weight.requires_grad)
45 # pack master_weight or weight_trail if needed

File /usr/local/lib/python3.8/dist-packages/torch/_ops.py:143, in OpOverloadPacket.__call__(self, *args, **kwargs)
138 def __call__(self, *args, **kwargs):
139 # overloading __call__ to ensure torch.ops.foo.bar()
140 # is still callable from JIT
141 # We save the function ptr as the op attribute on
142 # OpOverloadPacket to access it here.
--> 143 return self._op(*args, **kwargs or {})

RuntimeError: ipex_prepack::convolution_prepack() Expected a value of type 'List[int]' for argument 'padding' but instead found type 'str'.
Position: 3
Value: 'same'
Declaration: ipex_prepack::convolution_prepack(Tensor W, Tensor? B, int[] stride, int[] padding, int[] dilation, int groups, bool input_is_channels_last, int[] input_sizes) -> (__torch__.torch.classes.ipex_prepack.ConvolutionOpContext)
Cast error details: Unable to cast Python instance to C++ type (compile in debug mode for details)

Metadata

Metadata

Assignees

No one assigned

    Labels

    Bug (Something isn't working); CPU (CPU specific issues); Crash (Execution crashes)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions