Commit b9631be
Torch export-to-ONNX script fixes

albertz committed May 9, 2023
1 parent d231b55 commit b9631be
Showing 4 changed files with 141 additions and 99 deletions.
8 changes: 2 additions & 6 deletions returnn/frontend/run_ctx.py
@@ -179,13 +179,9 @@ def mark_as_output(self, tensor: Union[Tensor, Any], name: str, *, dims: Optiona
         assert self.stage == "forward_step"
         if not isinstance(tensor, Tensor):
             assert isinstance(tensor, _backend.global_backend.RawTensorType)
-            if dims is None:
-                # We trust the user that the raw tensor has a well-defined dim order.
-                # So just create some dummy dims.
-                dims = [
-                    Dim(None, name=f"{name}-raw-axis-{i}") for i in range(_backend.global_backend.get_ndim_raw(tensor))
-                ]
             tensor = rf.convert_to_tensor(tensor, dims=dims)
+            # In case it was not specified, just accept whatever order we got.
+            dims = tensor.dims
         assert name not in self.outputs.data
         if dims is None:
             # We try some reasonable defaults, specifically: BTF or BF.
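For illustration, a minimal sketch (hypothetical shapes and config) of what this change enables: a raw backend tensor can be passed to mark_as_output() without dims, and its dim order is accepted as-is, assuming rf.convert_to_tensor() handles dims=None for raw tensors as this hunk implies:

import torch
import returnn.frontend as rf

def forward_step(*, model, extern_data):
    # Hypothetical raw output with shape (batch, time, classes).
    logits = torch.zeros(3, 7, 5)
    # dims=None: the raw tensor's own dim order is accepted as-is.
    rf.get_run_ctx().mark_as_output(logits, "output", dims=None)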
2 changes: 1 addition & 1 deletion returnn/torch/frontend/_backend.py
@@ -115,7 +115,7 @@ def get_new_dim_raw(raw_tensor: torch.Tensor, axis: int, *, name: str) -> Dim:
         :param name:
         :return: new Dim object
         """
-        return Dim(raw_tensor.size(axis), name=name)
+        return Dim(int(raw_tensor.size(axis)), name=name)

     @staticmethod
     def get_device(x: Tensor[torch.Tensor]) -> Optional[str]:
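A side note on the int() cast (my reading, not stated in the commit): Dim sizes should be plain Python ints, while torch.Tensor.size(axis) can yield a torch-wrapped integer in some export/tracing contexts; int() normalizes both cases. A trivial sketch:

import torch

t = torch.zeros(4, 8)
size = t.size(1)       # usually already a plain Python int
dim_size = int(size)   # int() also normalizes wrapped/symbolic sizes
assert dim_size == 8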
11 changes: 0 additions & 11 deletions returnn/torch/frontend/bridge.py
@@ -87,14 +87,3 @@ def rf_module(self) -> rf.Module:
     def forward(self, *args, **kwargs):
         """forward"""
         return self._rf_module(*args, **kwargs)
-
-
-class _RFTensorAsPTTensor(torch.Tensor):
-    """
-    This class is meant to be instantiated through torch_tensor_object.as_subclass(_RFTensorAsPTTensor),
-    so it doesn't have __init__().
-    """
-
-    def __init__(self, rf_tensor: rf.Tensor):
-        super().__init__()
-        self.rf_tensor = rf_tensor
219 changes: 138 additions & 81 deletions tools/export_to_onnx.py
@@ -5,9 +5,9 @@
 Since get_model() can return either a torch.nn.Module or a rf.Module, both cases must be taken into account.
 """

 from __future__ import annotations
 import torch
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional, Dict
 import argparse
 import os
 import numpy
@@ -17,9 +17,10 @@
 from returnn.tensor import Dim, Tensor, TensorDict

 # noinspection PyProtectedMember
-from returnn.torch.frontend.bridge import _RFModuleAsPTModule, _RFTensorAsPTTensor
+from returnn.torch.frontend.bridge import _RFModuleAsPTModule
 import returnn.frontend as rf
 import returnn.util.basic as util
+from returnn.torch.data.tensor_utils import tensor_numpy_to_torch_
 import returnn.__main__ as rnn


Expand Down Expand Up @@ -53,6 +54,7 @@ def init(config_filename: str, checkpoint: str, log_verbosity: int, device: str)
print("RETURNN frontend module to ONNX conversion.", file=log.v1)
rnn.returnn_greeting()
rnn.init_backend_engine()
config.typed_dict.setdefault("backend", "torch")

[Comment] Icemole (Collaborator), May 10, 2023:
Shouldn't this line come before the previous one, rnn.init_backend_engine()? The call to init_backend_engine() reads what's in the config (through BackendEngine.select_engine()) and sets the backend accordingly.

[Comment] albertz (Author, Member), May 10, 2023:
Ah, you are right.

Btw, in any case, all this code here is not really good. We should never import anything from __main__; we should rather call the corresponding functions directly. I was too lazy to rewrite all of that now.

[Comment] Icemole (Collaborator), May 10, 2023:
I think I copy-pasted some of this code from another tool when writing it, so if it's not correct here, maybe it's not correct in the other tools either?

[Comment] albertz (Author, Member), May 10, 2023:
It's not that it is incorrect; it is just very much not recommended. Some other tools have this only for historical reasons.
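A minimal sketch of the ordering Icemole suggests (same names as in the tool; untested, just illustrating the point of the thread):

# Set the default *before* the engine is selected, since
# rnn.init_backend_engine() reads config["backend"] via
# util.BackendEngine.select_engine() and fixes the choice there.
config.typed_dict.setdefault("backend", "torch")
rnn.init_backend_engine()
assert util.BackendEngine.is_torch_selected()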

     assert util.BackendEngine.is_torch_selected(), "For now only the torch backend is supported."
     rnn.init_faulthandler()

@@ -62,94 +64,52 @@ class ForwardModulePT(torch.nn.Module):
     Wrapper of a PyTorch module that's meant to call forward_step from the config when called.
     """

-    def __init__(self, pt_module: torch.nn.Module, forward_step: Callable):
+    def __init__(self, pt_module: torch.nn.Module, forward_step: Callable, extern_data: TensorDict):
         """
         :param pt_module: PyTorch module as obtained from the config.
         :param forward_step: forward_step function as obtained from the config.
+        :param extern_data:
         """
         super().__init__()

         self.model = pt_module
         self.forward_step_func = forward_step
+        self.extern_data = extern_data

-    def __call__(self, data: _RFTensorAsPTTensor):
+    def __call__(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
         Wrapper to forward_step from the config.
         """
-        extern_data = TensorDict()
-        extern_data.update({"data": data.rf_tensor}, auto_convert=True)
+        extern_data = self.extern_data.copy_template()
+        extern_data.assign_from_raw_tensor_dict_(data)
         self.forward_step_func(model=self.model, extern_data=extern_data)
-        # debug_raw_tensor_dict = rf.get_run_ctx().outputs.as_raw_tensor_dict()
-        # return debug_raw_tensor_dict  # doesn't work, as there's more than one output in the dict (output:size0, etc.)
-        return rf.get_run_ctx().outputs.data["output"].raw_tensor  # works
+        return rf.get_run_ctx().outputs.as_raw_tensor_dict()
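For illustration, a hypothetical round trip through the wrapper; the shapes are made up, and the ":size" key naming follows the raw tensor dict convention visible in this diff (output:size0, etc.):

import torch

raw_in = {
    "data": torch.randn(3, 11, 40),          # (batch, time, feature)
    "data:size0": torch.tensor(3),            # batch size
    "data:size1": torch.tensor([11, 9, 7]),   # sequence length per batch entry
}
raw_out = pt_model_fwd(raw_in)  # runs forward_step from the config
# raw_out is again a flat dict, e.g. {"output": ..., "output:size0": ..., "output:size1": ...}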


 class ForwardModuleRF(_RFModuleAsPTModule):
     """
     Wrapper of a RETURNN frontend module that's meant to call forward_step from the config when called.
     """

-    def __init__(self, rf_module: rf.Module, forward_step: Callable):
+    def __init__(self, rf_module: rf.Module, forward_step: Callable, extern_data: TensorDict):
         """
         :param rf_module: RF module as obtained from the config.
         :param forward_step: forward_step function as obtained from the config.
+        :param extern_data:
         """
         super().__init__(rf_module)

         self.forward_step_func = forward_step
+        self.extern_data = extern_data

-    def __call__(self, data: _RFTensorAsPTTensor):
+    def __call__(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
         Wrapper to forward_step from the config.
         """
-        extern_data = TensorDict()
-        extern_data.update({"data": data.rf_tensor}, auto_convert=True)
+        extern_data = self.extern_data.copy_template()
+        extern_data.assign_from_raw_tensor_dict_(data)
         self.forward_step_func(model=self.rf_module, extern_data=extern_data)
-        # debug_raw_tensor_dict = rf.get_run_ctx().outputs.as_raw_tensor_dict()
-        # return debug_raw_tensor_dict  # doesn't work, as there's more than one output in the dict (output:size0, etc.)
-        return rf.get_run_ctx().outputs.data["output"].raw_tensor  # works
-
-
-def fill_batch_time_dims(dims: Sequence[Dim]):
-    """
-    Creates random capacities for batch and time dimensions.
-    This step is prior to creating a random tensor.
-    :param dims: Input dimensions extracted from extern_data. This argument is modified in-place.
-    """
-    rnd = numpy.random.RandomState(42)
-    initial_dims = []
-    initial_raw_dims = []
-
-    # Handle batch dim(s)
-    # TODO: more refined logic
-    # TODO: what if some dim.kind is undefined? E.g. in rf-demo all dims would be batch if not by d.name.startswith...
-    local_batch_dims = [
-        d
-        for d in dims
-        if not d.is_spatial_dim()
-        and not d.is_feature_dim()
-        and not d.name.startswith("time")
-        and not d.name.startswith("in")
-    ]
-    for local_batch_dim in local_batch_dims:
-        raw_tensor = rnd.randint(5, 20, size=initial_raw_dims, dtype="int32")
-        local_batch_dim.dyn_size_ext = Tensor(
-            name=local_batch_dim.name, dims=initial_dims, dtype="int32", raw_tensor=raw_tensor
-        )
-        initial_dims.append(local_batch_dim)
-        initial_raw_dims.append(raw_tensor.size)
-
-    # Handle time dim(s) in a similar way to batch dim(s)
-    # TODO: more refined logic, also same problem as above
-    local_time_dims = [d for d in dims if d.name.startswith("time")]
-    for local_time_dim in local_time_dims:
-        raw_tensor = rnd.randint(5, 20, size=initial_raw_dims, dtype="int32")
-        local_time_dim.dyn_size_ext = Tensor(
-            name=local_time_dim.name, dims=initial_dims, dtype="int32", raw_tensor=raw_tensor
-        )
-        initial_dims.append(local_time_dim)
-        initial_raw_dims.append(raw_tensor.size)
+        return rf.get_run_ctx().outputs.as_raw_tensor_dict()


@@ -168,6 +128,7 @@ def main():

     init(config_filename=args.config, checkpoint=args.checkpoint, log_verbosity=args.verbosity, device=args.device)
     rf.init_forward_step_run_ctx()
+    rf.set_random_seed(42)

     get_model_func = config.typed_value("get_model")
     assert get_model_func, "get_model() isn't specified in the config passed as a parameter."
@@ -185,44 +146,140 @@ def main():
     assert forward_step_func is not None, "forward_step() must be defined in the config."

     extern_data_dict = config.typed_value("extern_data")
-    extern_data_aux = TensorDict()
-    extern_data_aux.update(extern_data_dict, auto_convert=True)
-    dims = extern_data_aux["data"].dims
+    extern_data = TensorDict()
+    extern_data.update(extern_data_dict, auto_convert=True)

-    fill_batch_time_dims(dims)
+    for v in extern_data.data.values():
+        _reset_tensor(v)
+    rnd = numpy.random.RandomState(42)
+    for v in extern_data.data.values():
+        _fill_random(v, rnd=rnd)
+    for v in extern_data.data.values():
+        tensor_numpy_to_torch_(v)
+    extern_data_raw = extern_data.as_raw_tensor_dict()

-    dtype = extern_data_aux["data"].dtype
-    dummy_tensor = rf.random(dims=dims, dtype=dtype, distribution="uniform")
-    # dummy_tensor = _RFTensorAsPTTensor(dummy_tensor)
-    dummy_final_tensor = dummy_tensor.raw_tensor.as_subclass(_RFTensorAsPTTensor)
-    dummy_final_tensor.rf_tensor = dummy_tensor
     if is_pt_module:
         model.load_state_dict(loaded_checkpoint["model"])
         model.eval()
-        pt_model_fwd = ForwardModulePT(model, forward_step_func)
-        # dummy_tensor = dummy_tensor.raw_tensor
+        pt_model_fwd = ForwardModulePT(model, forward_step_func, extern_data)
     else:
-        pt_model_fwd = ForwardModuleRF(model, forward_step_func)
+        pt_model_fwd = ForwardModuleRF(model, forward_step_func, extern_data)
         pt_model_fwd.load_state_dict(loaded_checkpoint["model"])
         pt_model_fwd.eval()

-    # extern_data_raw = extern_data.as_raw_tensor_dict()
-    # dummy_tensor = extern_data_raw["data"]
+    model_outputs_dict = config.typed_value("model_outputs")
+    model_outputs = TensorDict()
+    model_outputs.update(model_outputs_dict, auto_convert=True)
+    model_outputs_raw_keys = []
+    for k, v in model_outputs.data.items():
+        model_outputs_raw_keys.append(k)
+        for i, dim in enumerate(v.dims):
+            if dim.is_batch_dim() or dim.is_dynamic():

[Comment] Icemole (Collaborator), May 10, 2023:
Same as below regarding the usage of is_batch_dim().

[Comment] albertz (Author, Member), May 10, 2023:
Same, see below.

+                model_outputs_raw_keys.append(f"{k}:size{i}")
+
+    dynamic_axes = {}
+    for k, v in list(extern_data.data.items()) + list(model_outputs.data.items()):
+        dynamic_axes[k] = {i: dim.name for i, dim in enumerate(v.dims) if dim.is_dynamic() or dim.is_batch_dim()}

[Comment] Icemole (Collaborator), May 10, 2023:
Is the usage of is_batch_dim() correct here? I remember you told me in the LSTM implementation not to use this function, but instead to do the following to get the batch dimension(s):

batch_dims = [d for d in source.dims if d != spatial_dim and d != in_dim]

Moreover, is_batch_dim() checks the Dim.kind attribute, which, as you stated in the GitHub issue, is only for aesthetic purposes and shouldn't be checked:

"You never ever should access or depend on the kind."

Maybe we should find another way to get the batch dimension(s)?

[Comment] albertz (Author, Member), May 10, 2023:
In all code of the backend, is_batch_dim() is never correct, because there the batch dims are always the dims which are not specified. What really is a batch dim? The meaning is different in the backend: it refers to the dims which are not relevant for some op, i.e. where the op is applied identically to all entries. That is the only meaning in the backend when you read something like "other dims are treated as batch dims". It does not matter what is_batch_dim() returns.

However, is_batch_dim(), the global batch_dim object, and the batch dim in ExternData, in BatchInfo, in collate_batch, etc., have a different meaning: there it means the multiple sequences from the dataset which are put together into a mini-batch.

So this global batch dim, from the mini-batch, is kind of a singleton, and is_batch_dim() is supposed to return True only for it. (It gets a bit more complicated in the TF net-dict with some more special cases, but those should not matter for the RETURNN frontend.)

You are right that it checks the kind internally, which is also bad, but we assume there is really only this one dimension for which it is True. This is a historic implementation detail which will probably be cleaned up at some later point.

Further, this global batch dim tag is treated a bit differently from other dim tags, again for historical reasons. We should probably clean this up. E.g. dyn_size_ext should normally be defined when the dim is dynamic, as it usually is, but for the batch dim it is often not defined properly. For the RF, there is some code in place to actually clean this up.
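As a sketch of the distinction described above (the helper names here are hypothetical; batch_dim is the global dim object mentioned in the comment):

from returnn.tensor import Tensor, Dim, batch_dim

def op_batch_dims(source: Tensor, spatial_dim: Dim, in_dim: Dim):
    # Backend sense of "batch dims": whatever dims the op was not told about.
    # is_batch_dim() plays no role here.
    return [d for d in source.dims if d != spatial_dim and d != in_dim]

def is_global_batch(d: Dim) -> bool:
    # Dataset/mini-batch sense: the singleton dim for which
    # is_batch_dim() is supposed to return True.
    return d == batch_dim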

+        for i, dim in enumerate(v.dims):
+            if dim.dyn_size_ext:
+                dynamic_axes[f"{k}:size{i}"] = {
+                    j: dim_.name
+                    for j, dim_ in enumerate(dim.dyn_size_ext.dims)
+                    if dim_.is_dynamic() or dim_.is_batch_dim()
+                }
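For illustration, with a hypothetical extern_data {"data": (batch, time, feature)} and model_outputs {"output": (batch, time, classes)}, the loops above would produce something like:

dynamic_axes = {
    "data": {0: "batch", 1: "time"},
    "data:size0": {},                # scalar batch size: no dynamic axes
    "data:size1": {0: "batch"},      # sequence lengths vary over the batch
    "output": {0: "batch", 1: "time"},
    "output:size0": {},
    "output:size1": {0: "batch"},
}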

     export_func(
         pt_model_fwd,
-        (dummy_final_tensor,),
+        (extern_data_raw, {}),
         f=args.out_onnx_filename,
         verbose=True,
-        input_names=["data"],  # , "data_len"],
-        output_names=["classes"],
-        dynamic_axes={
-            "data": {0: "batch", 1: "time"},  # TODO: automatically infer dynamic axes
-            # "data_len": {0: "batch"},
-            # "classes": {0: "batch", 1: "time"},
-        },
+        input_names=list(extern_data_raw.keys()),
+        output_names=model_outputs_raw_keys,
+        dynamic_axes=dynamic_axes,
     )


+def _reset_tensor(x: Tensor):
+    """reset"""
+    x.batch = None
+    x.raw_tensor = None
+    for dim in x.dims:
+        dim.batch = None
+        if dim.dyn_size_ext:
+            _reset_tensor(dim.dyn_size_ext)
+
+
+def _fill_random(

[Comment] Icemole (Collaborator), May 10, 2023:
_fill_random() and _reset_tensor() are already available in tests/rf_utils.py. Maybe it would be better to move both to a common place and import them from there, maybe frontend/_utils.py? Since they are quite useful, they could also be needed by a future tool or test.

[Comment] albertz (Author, Member), May 10, 2023:
Yeah, I thought about that. I was not sure, though, how useful they would be otherwise. Also, they have a few hardcoded numbers in there, which might not be optimal for all cases.

+    x: Tensor,
+    *,
+    min_val: int = 0,
+    max_val: Optional[int] = None,
+    rnd: numpy.random.RandomState,
+    dyn_dim_max_sizes: Optional[Dict[Dim, int]] = None,
+) -> bool:
+    """fill. return whether sth was filled"""
+    if dyn_dim_max_sizes is None:
+        dyn_dim_max_sizes = {}
+    filled = False
+    while True:
+        have_unfilled = False
+        filled_this_round = False
+
+        for dim in x.dims:
+            if dim.is_batch_dim() and not dim.dyn_size_ext:
+                dim.dyn_size_ext = Tensor("batch", [], dtype="int32")
+            if not dim.dyn_size_ext:
+                continue
+            if _fill_random(
+                dim.dyn_size_ext,
+                min_val=2,
+                max_val=dyn_dim_max_sizes.get(dim, None),
+                rnd=rnd,
+                dyn_dim_max_sizes=dyn_dim_max_sizes,
+            ):
+                if dim in dyn_dim_max_sizes:
+                    # Make sure at least one of the dyn sizes matches the max size.
+                    i = rnd.randint(0, dim.dyn_size_ext.raw_tensor.size)
+                    dim.dyn_size_ext.raw_tensor.flat[i] = dyn_dim_max_sizes[dim]
+                filled = True
+                filled_this_round = True
+            if dim.dyn_size_ext.raw_tensor is None:
+                have_unfilled = True
+            elif not isinstance(dim.dyn_size_ext.raw_tensor, numpy.ndarray):
+                have_unfilled = True
+
+        if have_unfilled:
+            assert filled_this_round, f"should have filled something, {x}"
+
+        if not have_unfilled:
+            break
+
+    if x.raw_tensor is not None:
+        if not isinstance(x.raw_tensor, numpy.ndarray):
+            x.raw_tensor = None
+
+    if x.raw_tensor is None:
+        shape = [d.get_dim_value() for d in x.dims]
+        if x.dtype.startswith("int"):
+            if max_val is None:
+                max_val = rnd.randint(5, 20)
+            if x.sparse_dim and x.sparse_dim.dimension is not None:
+                max_val = x.sparse_dim.dimension
+            x.raw_tensor = rnd.randint(min_val, max_val, size=shape, dtype=x.dtype)
+        elif x.dtype.startswith("float"):
+            x.raw_tensor = rnd.normal(0.0, 1.0, size=shape).astype(x.dtype)
+        elif x.dtype.startswith("complex"):
+            real = rnd.normal(0.0, 1.0, size=shape)
+            imag = rnd.normal(0.0, 1.0, size=shape)
+            x.raw_tensor = (real + 1j * imag).astype(x.dtype)
+        else:
+            raise NotImplementedError(f"not implemented for {x} dtype {x.dtype}")
+        filled = True
+
+    assert isinstance(x.raw_tensor, numpy.ndarray)
+
+    return filled
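A minimal usage sketch for _fill_random() with a purely static dim (hypothetical; dynamic dims additionally get their dyn_size_ext filled, as in main() above):

rnd = numpy.random.RandomState(0)
feat = Dim(40, name="feature")
x = Tensor("data", dims=(feat,), dtype="float32")
_fill_random(x, rnd=rnd)
assert isinstance(x.raw_tensor, numpy.ndarray) and x.raw_tensor.shape == (40,)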


 if __name__ == "__main__":
     main()

0 comments on commit b9631be
