@@ -5,9 +5,9 @@
 Since get_model() can return either a torch.nn.Module or a rf.Module, both cases must be taken into account.
 """

 from __future__ import annotations
 import torch
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional, Dict
 import argparse
 import os
 import numpy
@@ -17,9 +17,10 @@
 from returnn.tensor import Dim, Tensor, TensorDict

 # noinspection PyProtectedMember
-from returnn.torch.frontend.bridge import _RFModuleAsPTModule, _RFTensorAsPTTensor
+from returnn.torch.frontend.bridge import _RFModuleAsPTModule
 import returnn.frontend as rf
 import returnn.util.basic as util
+from returnn.torch.data.tensor_utils import tensor_numpy_to_torch_
 import returnn.__main__ as rnn

@@ -53,6 +54,7 @@ def init(config_filename: str, checkpoint: str, log_verbosity: int, device: str)
     print("RETURNN frontend module to ONNX conversion.", file=log.v1)
     rnn.returnn_greeting()
     rnn.init_backend_engine()
+    config.typed_dict.setdefault("backend", "torch")
     assert util.BackendEngine.is_torch_selected(), "For now only the torch backend is supported."
     rnn.init_faulthandler()

@@ -62,94 +64,52 @@ class ForwardModulePT(torch.nn.Module):
     Wrapper of a PyTorch module that's meant to call forward_step from the config when called.
     """

-    def __init__(self, pt_module: torch.nn.Module, forward_step: Callable):
+    def __init__(self, pt_module: torch.nn.Module, forward_step: Callable, extern_data: TensorDict):
         """
         :param pt_module: PyTorch module as obtained from the config.
         :param forward_step: forward_step function as obtained from the config.
+        :param extern_data:
         """
         super().__init__()

         self.model = pt_module
         self.forward_step_func = forward_step
+        self.extern_data = extern_data

-    def __call__(self, data: _RFTensorAsPTTensor):
+    def __call__(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
         Wrapper to forward_step from the config.
         """
-        extern_data = TensorDict()
-        extern_data.update({"data": data.rf_tensor}, auto_convert=True)
+        extern_data = self.extern_data.copy_template()
+        extern_data.assign_from_raw_tensor_dict_(data)
         self.forward_step_func(model=self.model, extern_data=extern_data)
-        # debug_raw_tensor_dict = rf.get_run_ctx().outputs.as_raw_tensor_dict()
-        # return debug_raw_tensor_dict  # doesnt work, as there's more than one output in the dict (output:size0, etc)
-        return rf.get_run_ctx().outputs.data["output"].raw_tensor  # works
+        return rf.get_run_ctx().outputs.as_raw_tensor_dict()

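The wrapper now takes and returns plain dicts of `torch.Tensor`, keyed by the raw-tensor-dict naming that `TensorDict.as_raw_tensor_dict()` uses, which is a calling convention `torch.onnx.export` can trace. A minimal sketch of how it is driven (sketch only, not part of the commit; `model`, `forward_step_func` and `extern_data` are assumed to be set up as in `main()` below):

```python
# Sketch: drive the wrapper with a raw tensor dict, as torch.onnx.export does.
wrapped = ForwardModulePT(model, forward_step_func, extern_data)
raw_inputs = extern_data.as_raw_tensor_dict()  # {"data": ..., "data:size0": ..., ...}
raw_outputs = wrapped(raw_inputs)  # one entry per model_outputs key, plus ":size{i}" entries
```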
 class ForwardModuleRF(_RFModuleAsPTModule):
     """
     Wrapper of a RETURNN frontend module that's meant to call forward_step from the config when called.
     """

-    def __init__(self, rf_module: rf.Module, forward_step: Callable):
+    def __init__(self, rf_module: rf.Module, forward_step: Callable, extern_data: TensorDict):
         """
         :param rf_module: RF module as obtained from the config.
         :param forward_step: forward_step function as obtained from the config.
+        :param extern_data:
         """
         super().__init__(rf_module)

         self.forward_step_func = forward_step
+        self.extern_data = extern_data

-    def __call__(self, data: _RFTensorAsPTTensor):
+    def __call__(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
         Wrapper to forward_step from the config.
         """
-        extern_data = TensorDict()
-        extern_data.update({"data": data.rf_tensor}, auto_convert=True)
+        extern_data = self.extern_data.copy_template()
+        extern_data.assign_from_raw_tensor_dict_(data)
         self.forward_step_func(model=self.rf_module, extern_data=extern_data)
-        # debug_raw_tensor_dict = rf.get_run_ctx().outputs.as_raw_tensor_dict()
-        # return debug_raw_tensor_dict  # doesnt work, as there's more than one output in the dict (output:size0, etc)
-        return rf.get_run_ctx().outputs.data["output"].raw_tensor  # works
-
-
-def fill_batch_time_dims(dims: Sequence[Dim]):
-    """
-    Creates random capacities for batch and time dimensions.
-    This step is prior to creating a random tensor.
-    :param dims: Input dimensions extracted from extern_data. This argument is modified in-place.
-    """
-    rnd = numpy.random.RandomState(42)
-    initial_dims = []
-    initial_raw_dims = []
-
-    # Handle batch dim(s)
-    # TODO: more refined logic
-    # TODO: what if some dim.kind is undefined? E.g. in rf-demo all dims would be batch if not by d.name.startswith...
-    local_batch_dims = [
-        d
-        for d in dims
-        if not d.is_spatial_dim()
-        and not d.is_feature_dim()
-        and not d.name.startswith("time")
-        and not d.name.startswith("in")
-    ]
-    for local_batch_dim in local_batch_dims:
-        raw_tensor = rnd.randint(5, 20, size=initial_raw_dims, dtype="int32")
-        local_batch_dim.dyn_size_ext = Tensor(
-            name=local_batch_dim.name, dims=initial_dims, dtype="int32", raw_tensor=raw_tensor
-        )
-        initial_dims.append(local_batch_dim)
-        initial_raw_dims.append(raw_tensor.size)
-
-    # Handle time dim(s) in a similar way to batch dim(s)
-    # TODO: more refined logic, also same problem as above
-    local_time_dims = [d for d in dims if d.name.startswith("time")]
-    for local_time_dim in local_time_dims:
-        raw_tensor = rnd.randint(5, 20, size=initial_raw_dims, dtype="int32")
-        local_time_dim.dyn_size_ext = Tensor(
-            name=local_time_dim.name, dims=initial_dims, dtype="int32", raw_tensor=raw_tensor
-        )
-        initial_dims.append(local_time_dim)
-        initial_raw_dims.append(raw_tensor.size)
+        return rf.get_run_ctx().outputs.as_raw_tensor_dict()


 def main():
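Both wrappers now return `rf.get_run_ctx().outputs.as_raw_tensor_dict()` instead of just `outputs.data["output"].raw_tensor`, so the extra size entries the removed debug comment complained about (`output:size0`, etc.) become regular named ONNX outputs. For a single dynamic `"output"` entry the returned dict would look roughly like this (illustrative only; actual keys and shapes depend on `model_outputs` in the config):

```python
# Hypothetical shape of the returned raw tensor dict:
raw_outputs = {
    "output": ...,        # e.g. torch.Tensor of shape [batch, time, classes]
    "output:size0": ...,  # scalar batch size
    "output:size1": ...,  # per-sequence lengths, shape [batch]
}
```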
@@ -168,6 +128,7 @@ def main():

     init(config_filename=args.config, checkpoint=args.checkpoint, log_verbosity=args.verbosity, device=args.device)
     rf.init_forward_step_run_ctx()
+    rf.set_random_seed(42)

     get_model_func = config.typed_value("get_model")
     assert get_model_func, "get_model() isn't specified in the config passed as a parameter."
@@ -185,44 +146,140 @@ def main():
     assert forward_step_func is not None, "forward_step() must be defined in the config."

     extern_data_dict = config.typed_value("extern_data")
-    extern_data_aux = TensorDict()
-    extern_data_aux.update(extern_data_dict, auto_convert=True)
-    dims = extern_data_aux["data"].dims
+    extern_data = TensorDict()
+    extern_data.update(extern_data_dict, auto_convert=True)

-    fill_batch_time_dims(dims)
+    for v in extern_data.data.values():
+        _reset_tensor(v)
+    rnd = numpy.random.RandomState(42)
+    for v in extern_data.data.values():
+        _fill_random(v, rnd=rnd)
+    for v in extern_data.data.values():
+        tensor_numpy_to_torch_(v)
+    extern_data_raw = extern_data.as_raw_tensor_dict()

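The dummy input is no longer a single `rf.random` tensor derived from `extern_data["data"]`; instead every entry of `extern_data` is reset, filled with random numpy data (dynamic sizes included), converted to torch, and flattened into one raw dict. For illustration, a hypothetical config entry and what it yields (dim names and sizes are invented):

```python
from returnn.tensor import Dim, batch_dim

time_dim = Dim(None, name="time")      # dynamic sequence length
feature_dim = Dim(40, name="feature")  # static

# Hypothetical "extern_data" entry in the config:
extern_data = {"data": {"dims": [batch_dim, time_dim, feature_dim], "dtype": "float32"}}

# After the reset/fill/convert loop above, extern_data.as_raw_tensor_dict() holds e.g.:
#   "data"       -> torch.Tensor of shape [B, T, 40] with random floats
#   "data:size0" -> scalar int32 tensor holding the batch size B
#   "data:size1" -> int32 tensor of shape [B] with per-sequence lengths
```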
-    dtype = extern_data_aux["data"].dtype
-    dummy_tensor = rf.random(dims=dims, dtype=dtype, distribution="uniform")
-    # dummy_tensor = _RFTensorAsPTTensor(dummy_tensor)
-    dummy_final_tensor = dummy_tensor.raw_tensor.as_subclass(_RFTensorAsPTTensor)
-    dummy_final_tensor.rf_tensor = dummy_tensor
     if is_pt_module:
         model.load_state_dict(loaded_checkpoint["model"])
         model.eval()
-        pt_model_fwd = ForwardModulePT(model, forward_step_func)
-        # dummy_tensor = dummy_tensor.raw_tensor
+        pt_model_fwd = ForwardModulePT(model, forward_step_func, extern_data)
     else:
-        pt_model_fwd = ForwardModuleRF(model, forward_step_func)
+        pt_model_fwd = ForwardModuleRF(model, forward_step_func, extern_data)
         pt_model_fwd.load_state_dict(loaded_checkpoint["model"])
         pt_model_fwd.eval()

-    # extern_data_raw = extern_data.as_raw_tensor_dict()
-    # dummy_tensor = extern_data_raw["data"]
+    model_outputs_dict = config.typed_value("model_outputs")
+    model_outputs = TensorDict()
+    model_outputs.update(model_outputs_dict, auto_convert=True)
+    model_outputs_raw_keys = []
+    for k, v in model_outputs.data.items():
+        model_outputs_raw_keys.append(k)
+        for i, dim in enumerate(v.dims):
+            if dim.is_batch_dim() or dim.is_dynamic():
+                model_outputs_raw_keys.append(f"{k}:size{i}")

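As a concrete case: for a hypothetical `model_outputs` entry `"output"` with dims `[batch_dim, time_dim, classes_dim]`, the loop above yields:

```python
model_outputs_raw_keys == ["output", "output:size0", "output:size1"]
# classes_dim is static, so no "output:size2" entry is added.
```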
+    dynamic_axes = {}
+    for k, v in list(extern_data.data.items()) + list(model_outputs.data.items()):
+        dynamic_axes[k] = {i: dim.name for i, dim in enumerate(v.dims) if dim.is_dynamic() or dim.is_batch_dim()}
+        for i, dim in enumerate(v.dims):
+            if dim.dyn_size_ext:
+                dynamic_axes[f"{k}:size{i}"] = {
+                    j: dim_.name
+                    for j, dim_ in enumerate(dim.dyn_size_ext.dims)
+                    if dim_.is_dynamic() or dim_.is_batch_dim()
+                }

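For the same hypothetical `"data"` and `"output"` entries, this produces roughly the following mapping (axis names come from `dim.name`, so the exact strings depend on the config):

```python
dynamic_axes = {
    "data": {0: "batch", 1: "time"},
    "data:size0": {},            # scalar batch size, no dynamic axes
    "data:size1": {0: "batch"},  # sequence lengths vary with the batch
    "output": {0: "batch", 1: "time"},
    "output:size0": {},
    "output:size1": {0: "batch"},
}
```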
     export_func(
         pt_model_fwd,
-        (dummy_final_tensor,),
+        (extern_data_raw, {}),
         f=args.out_onnx_filename,
         verbose=True,
-        input_names=["data"],  # , "data_len"],
-        output_names=["classes"],
-        dynamic_axes={
-            "data": {0: "batch", 1: "time"},  # TODO: automatically infer dynamic axes
-            # "data_len": {0: "batch"},
-            # "classes": {0: "batch", 1: "time"},
-        },
+        input_names=list(extern_data_raw.keys()),
+        output_names=model_outputs_raw_keys,
+        dynamic_axes=dynamic_axes,
     )

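A hedged sketch of consuming the exported file with onnxruntime (assumes `onnxruntime` is installed; the input names and shapes follow the hypothetical `extern_data` above, not anything this commit guarantees):

```python
import numpy
import onnxruntime  # assumption: available in the environment

session = onnxruntime.InferenceSession("model.onnx")  # path given as args.out_onnx_filename
feed = {
    "data": numpy.random.randn(2, 17, 40).astype("float32"),
    "data:size0": numpy.array(2, dtype="int32"),
    "data:size1": numpy.array([17, 13], dtype="int32"),
}
# Outputs are returned in model_outputs_raw_keys order.
outputs = session.run(None, feed)
```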
+def _reset_tensor(x: Tensor):
+    """reset"""
+    x.batch = None
+    x.raw_tensor = None
+    for dim in x.dims:
+        dim.batch = None
+        if dim.dyn_size_ext:
+            _reset_tensor(dim.dyn_size_ext)


+def _fill_random(
+    x: Tensor,
+    *,
+    min_val: int = 0,
+    max_val: Optional[int] = None,
+    rnd: numpy.random.RandomState,
+    dyn_dim_max_sizes: Optional[Dict[Dim, int]] = None,
+) -> bool:
+    """fill. return whether sth was filled"""
+    if dyn_dim_max_sizes is None:
+        dyn_dim_max_sizes = {}
+    filled = False
+    while True:
+        have_unfilled = False
+        filled_this_round = False
+
+        for dim in x.dims:
+            if dim.is_batch_dim() and not dim.dyn_size_ext:
+                dim.dyn_size_ext = Tensor("batch", [], dtype="int32")
+            if not dim.dyn_size_ext:
+                continue
+            if _fill_random(
+                dim.dyn_size_ext,
+                min_val=2,
+                max_val=dyn_dim_max_sizes.get(dim, None),
+                rnd=rnd,
+                dyn_dim_max_sizes=dyn_dim_max_sizes,
+            ):
+                if dim in dyn_dim_max_sizes:
+                    # Make sure at least one of the dyn sizes matches the max size.
+                    i = rnd.randint(0, dim.dyn_size_ext.raw_tensor.size)
+                    dim.dyn_size_ext.raw_tensor.flat[i] = dyn_dim_max_sizes[dim]
+                filled = True
+                filled_this_round = True
+            if dim.dyn_size_ext.raw_tensor is None:
+                have_unfilled = True
+            elif not isinstance(dim.dyn_size_ext.raw_tensor, numpy.ndarray):
+                have_unfilled = True
+
+        if have_unfilled:
+            assert filled_this_round, f"should have filled something, {x}"
+
+        if not have_unfilled:
+            break
+
+    if x.raw_tensor is not None:
+        if not isinstance(x.raw_tensor, numpy.ndarray):
+            x.raw_tensor = None
+
+    if x.raw_tensor is None:
+        shape = [d.get_dim_value() for d in x.dims]
+        if x.dtype.startswith("int"):
+            if max_val is None:
+                max_val = rnd.randint(5, 20)
+            if x.sparse_dim and x.sparse_dim.dimension is not None:
+                max_val = x.sparse_dim.dimension
+            x.raw_tensor = rnd.randint(min_val, max_val, size=shape, dtype=x.dtype)
+        elif x.dtype.startswith("float"):
+            x.raw_tensor = rnd.normal(0.0, 1.0, size=shape).astype(x.dtype)
+        elif x.dtype.startswith("complex"):
+            real = rnd.normal(0.0, 1.0, size=shape)
+            imag = rnd.normal(0.0, 1.0, size=shape)
+            x.raw_tensor = (real + 1j * imag).astype(x.dtype)
+        else:
+            raise NotImplementedError(f"not implemented for {x} dtype {x.dtype}")
+        filled = True
+
+    assert isinstance(x.raw_tensor, numpy.ndarray)
+
+    return filled

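A standalone sketch of the fill pipeline on a single tensor (it mirrors the loop in `main()`; the dims are invented, and giving the dynamic dim an explicit `dyn_size_ext` template is an assumption about what extern_data entries from a config carry):

```python
import numpy
from returnn.tensor import Dim, Tensor, batch_dim

time_dim = Dim(None, name="time")
# Assumed: dynamic dims carry a dyn_size_ext template (lengths per batch entry).
time_dim.dyn_size_ext = Tensor("time:size", dims=[batch_dim], dtype="int32")
x = Tensor("data", dims=[batch_dim, time_dim, Dim(40, name="feature")], dtype="float32")

_reset_tensor(x)
rnd = numpy.random.RandomState(42)
_fill_random(x, rnd=rnd)  # fills the batch size, then the lengths, then the data
assert isinstance(x.raw_tensor, numpy.ndarray)  # shape [B, T, 40]
```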
 if __name__ == "__main__":
     main()
Review comment on the added line `config.typed_dict.setdefault("backend", "torch")`:

> Shouldn't this line be above the previous line, `rnn.init_backend_engine()`, as the call to `init_backend_engine()` gets what's in the config (through `BackendEngine.select_engine()`) and sets the backend accordingly?
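For illustration, the ordering the comment suggests would be (a sketch of the reviewer's point, not a confirmed fix; whether `select_engine()` would still pick up the default afterwards is exactly what is being asked):

```python
config.typed_dict.setdefault("backend", "torch")  # put the default in place first
rnn.init_backend_engine()  # BackendEngine.select_engine() then reads config["backend"]
```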