Skip to content

Commit

Permalink
[FX] Remove unused code from FX decoder (openvinotoolkit#23954)
Browse files Browse the repository at this point in the history
### Details:
 - *Remove unused code from FX decoder*

### Tickets:
 - *ticket-id*
  • Loading branch information
mvafin authored Apr 11, 2024
1 parent 992874f commit 2eb09a4
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 125 deletions.
137 changes: 16 additions & 121 deletions src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

class TorchFXPythonDecoder (Decoder):

def __init__(self, pt_module, fx_gm, nodes=None, mark_node_callback=None, input_shapes=[], input_types=[]):
def __init__(self, pt_module, fx_gm=None, nodes=None, mark_node_callback=None, input_shapes=[], input_types=[]):
Decoder.__init__(self)
self.mark_node_callback = mark_node_callback
# We store every decoder created by this decoder so that none of them is deleted until the first decoder is deleted
self.m_decoders = []
self.pt_module = pt_module
self.fx_gm = fx_gm
self.fx_gm = fx_gm if fx_gm is not None else pt_module
self.input_types = [OVAny(pt_to_ov_type_map[str(t)])
for t in input_types]
self.input_shapes = input_shapes
Expand All @@ -46,14 +46,16 @@ def __init__(self, pt_module, fx_gm, nodes=None, mark_node_callback=None, input_
self._input_signature.append(value.name)
if hasattr(value, "meta") and ('tensor_meta' in value.meta.keys()) and value.meta['tensor_meta']:
found_shapes.append(value.meta['tensor_meta'].shape)
found_types.append(OVAny(pt_to_ov_type_map[str(value.meta['tensor_meta'].dtype)]))
found_types.append(
OVAny(pt_to_ov_type_map[str(value.meta['tensor_meta'].dtype)]))
else:
found_shapes.append(None)
found_types.append(None)
elif self._nodes[i].op == 'output':
# Instead of putting output index, refer to its target
uargs = self.unpack_containers(self._nodes[i].args)
self._outputs = [(arg[0], self._nodes.index(arg[1])) for arg in uargs if arg[1] is not None]
self._outputs = [(arg[0], self._nodes.index(arg[1]))
for arg in uargs if arg[1] is not None]

if not input_shapes or len(input_shapes) == 0:
self.input_shapes = found_shapes
Expand Down Expand Up @@ -270,8 +272,9 @@ def get_subgraphs(self):
return list(self.pt_module.blocks())

def get_subgraph_decoder(self, index):
decoder = TorchFXPythonDecoder(self.get_subgraphs(
)[index], self.fx_gm, mark_node_callback=self.mark_node_callback)
decoder = TorchFXPythonDecoder(self.get_subgraphs()[index],
self.fx_gm,
mark_node_callback=self.mark_node_callback)
self.m_decoders.append(decoder)
return decoder

Expand All @@ -284,8 +287,7 @@ def get_op_type(self):
return 'UNKNOWN_TYPE_' + str(self.pt_module.op)

def get_schema(self):
return ''
return self.pt_module.schema()
return 'NONE'

def outputs(self):
    """Return the node indices of the graph outputs.

    Each entry of ``self._outputs`` is a ``(name, node_index)`` pair
    collected from the FX ``output`` node; only the index part is exposed.
    """
    return [node_index for _name, node_index in self._outputs]
Expand Down Expand Up @@ -318,115 +320,15 @@ def mark_node(self, node):
return node

def as_constant(self):

if self.pt_module.op == 'get_attr':
# Extract Constant from FX module field
ret = fetch_attr(self.fx_gm, self.pt_module.target)
ov_const = torch_tensor_to_ov_const(ret, shared_memory=True)
return ov_const.outputs()

if not self.get_op_type() == 'prim::Constant':
return None
pt_value = self._raw_output(0)

pt_type_class = pt_value.type().__class__
if pt_type_class is torch.TensorType:
return self.as_constant_tensor(pt_value)
if pt_type_class is torch.ListType:
return self.as_constant_list(pt_value)
if str(pt_value.type()) in ['torch.int32', 'int']:
return make_constant(OVType.i32, Shape([]), [pt_value.toIValue()]).outputs()
if str(pt_value.type()) in ['torch.float', 'torch.FloatType', 'float']:
return make_constant(OVType.f32, Shape([]), [pt_value.toIValue()]).outputs()
if str(pt_value.type()) in ['torch.bool', 'bool']:
return make_constant(OVType.boolean, Shape([]), [pt_value.toIValue()]).outputs()

return None
assert self.pt_module.op == 'get_attr', "Only get_attr is supported"
# Extract Constant from FX module field
ret = fetch_attr(self.fx_gm, self.pt_module.target)
ov_const = torch_tensor_to_ov_const(ret, shared_memory=True)
return ov_const.outputs()

def as_string(self):
    """Return the Python string held by this node, if it is a string constant.

    Only ``prim::Constant`` nodes can carry a string payload; any other op
    type, or a constant of a non-string type, yields ``None``.
    """
    if self.get_op_type() != 'prim::Constant':
        return None

    const_value = self._raw_output(0)
    if str(const_value.type()) not in ('torch.StringType', 'str'):
        return None
    return const_value.toIValue()

def as_constant_tensor(self, pt_value):
    """Convert a TorchScript constant tensor value to an OV Constant.

    :param pt_value: TorchScript ``Value`` assumed to hold a tensor IValue.
    :return: list of outputs of the created Constant, or ``None`` when the
             value cannot be represented as a tensor constant.
    """
    ivalue = pt_value.toIValue()
    if pt_value.isCompleteTensor():
        try:
            # Normalize layout and move to CPU so data_ptr() below exposes
            # a contiguous host buffer.
            ivalue = ivalue.to(
                memory_format=torch.contiguous_format).detach().cpu()
        except Exception:  # fix: was a bare `except:`; keep best-effort semantics
            logger.warning("Tensor couldn't detach")
        if str(pt_value.type().dtype()) in pt_to_py_type_map:
            # Constant interpretation doesn't respect new-full type of PT.
            # It recognizes only tensors, and gives lists as 1D tensors and
            # scalars as Tensor scalars, so only tensor-type constants are
            # supported here.
            ovshape = PartialShape(pt_value.type().sizes())
            ovtype = pt_to_ov_type_map[str(pt_value.type().dtype())]

            # TODO: try-except here is a temporary WA for issues with data_ptr
            # that we currently cannot predict; provide better solution
            try:
                # Zero-copy path enabled by the Constant ctor accepting a raw
                # data pointer.
                # TODO Check strides and pass them somehow
                values = ivalue.data_ptr()
                ov_const = make_constant(
                    ovtype, ovshape.get_shape(), values)
            except Exception:  # fix: was a bare `except:`
                # Old variant that makes a slow element-wise data copy.
                logger.warning("Constant wasn't able to convert from data_ptr.")
                values = ivalue.flatten().tolist()
                ov_const = make_constant(
                    ovtype, ovshape.get_shape(), values)
            return ov_const.outputs()
    else:
        # Incomplete tensor can be scalar.
        # fix: bool must be tested before int — bool is a subclass of int,
        # so the original (float, int, bool) order made the bool branch
        # unreachable and encoded True/False as i64 constants.
        if isinstance(ivalue, bool):
            return make_constant(OVType.boolean, Shape([]), [ivalue]).outputs()
        if isinstance(ivalue, float):
            return make_constant(OVType.f32, Shape([]), [ivalue]).outputs()
        if isinstance(ivalue, int):
            return make_constant(OVType.i64, Shape([]), [ivalue]).outputs()

        # TODO: verify that it correctly reads incomplete consts
        if str(ivalue.type()) in pt_to_ov_type_map:
            try:
                ovshape = PartialShape(ivalue.size())
                ovtype = pt_to_ov_type_map[str(ivalue.type())]
                ov_const = make_constant(
                    ovtype, ovshape.get_shape(), ivalue.data_ptr())
            except Exception:  # fix: was a bare `except:`
                # Old variant that makes a slow data copy through numpy.
                logger.warning("Constant wasn't able to convert from data_ptr.")
                nvalues = ivalue.numpy(force=True)
                ovtype = np_to_ov_type_map[str(nvalues.dtype)]
                ovshape = PartialShape(nvalues.shape)
                ov_const = make_constant(
                    ovtype, ovshape.get_shape(), nvalues.flatten().tolist())
            return ov_const.outputs()
    return None

def as_constant_list(self, pt_value):
    """Represent a TorchScript constant list as a 1D OV Constant.

    Converters query list-typed constant attributes as 1D tensors, so the
    list is flattened into a single-dimension Constant instead of being
    modelled as a real list type (avoids massively rewriting converters).

    :param pt_value: TorchScript ``Value`` holding a list IValue.
    :return: outputs of the created Constant, or ``None`` when the element
             type is not present in ``pt_to_ov_type_map``.
    """
    pt_element_type = str(pt_value.type().getElementType())
    ivalue = pt_value.toIValue()
    if pt_element_type not in pt_to_ov_type_map:
        # Unknown element type: explicitly signal "not convertible"
        # (previously fell through and returned None implicitly).
        return None
    ovtype = pt_to_ov_type_map[pt_element_type]
    ovshape = PartialShape([len(ivalue)])
    return make_constant(ovtype, ovshape.get_shape(), ivalue).outputs()

def input_is_none(self, index):
if index >= len(self._inputs) or (isinstance(self._inputs[index], tuple) and self._inputs[index][0] is None):
return True
Expand All @@ -438,11 +340,4 @@ def debug(self):
self.pt_module.print()

def may_produce_alias(self, in_index: int, out_index: int) -> bool:
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::matmul"]:
# AliasDB::may_contain_alias sometimes return True for tensors produced by convnd, we have to workaround that
return False
try:
return self.alias_db.may_contain_alias(self._raw_input(in_index), self._raw_output(out_index))
except:
# Sometimes pytorch fails to get result with IndexError exception while these indexes exist in node
return False
return False
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, options
input_types.append(input_data.type())
input_shapes.append(input_data.size())

decoder = TorchFXPythonDecoder(gm, gm, input_shapes=input_shapes, input_types=input_types)
decoder = TorchFXPythonDecoder(gm, input_shapes=input_shapes, input_types=input_types)

im = fe.load(decoder)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np
from common.constants import test_device, test_precision
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder

from openvino.frontend import FrontEndManager
from openvino.runtime import Core, Type, PartialShape
Expand Down
2 changes: 1 addition & 1 deletion tests/model_hub_tests/pytorch/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def convert_model_impl(self, model_obj):
gm = graph.module()
print(gm.code)

decoder = TorchFXPythonDecoder(gm, gm)
decoder = TorchFXPythonDecoder(gm)
decoder._input_signature = list(self.example.keys())
ov_model = convert_model(decoder, verbose=True)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def extract_module_extensions(args):
if version.parse(torch.__version__) >= version.parse("2.2"):
model = model.run_decompositions()
gm = model.module()
decoder = TorchFXPythonDecoder(gm, gm)
decoder = TorchFXPythonDecoder(gm)
else:
decoder = TorchScriptPythonDecoder(
model,
Expand Down

0 comments on commit 2eb09a4

Please sign in to comment.