diff --git a/qa/common/gen_common.py b/qa/common/gen_common.py index 8bd97720b5..d574627dfd 100644 --- a/qa/common/gen_common.py +++ b/qa/common/gen_common.py @@ -24,6 +24,8 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from typing import List + # Common utilities for model generation scripts import numpy as np @@ -154,5 +156,5 @@ def np_to_torch_dtype(np_dtype): elif np_dtype == np.float64: return torch.double elif np_dtype == np_dtype_string: - return None # Not supported in Torch + return List[str] return None