Skip to content

Commit fe6db77

Browse files
authored
[Refactor] Make @Tensorclass work properly with pyright (#1042)
1 parent 59a0ce5 commit fe6db77

File tree

1 file changed

+13
-24
lines changed

1 file changed

+13
-24
lines changed

tensordict/tensorclass.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,7 @@
2424
from dataclasses import dataclass
2525
from pathlib import Path
2626
from textwrap import indent
27-
from typing import (
28-
Any,
29-
Callable,
30-
get_type_hints,
31-
List,
32-
overload,
33-
Sequence,
34-
Type,
35-
TypeVar,
36-
)
27+
from typing import Any, Callable, get_type_hints, List, Sequence, Type, TypeVar
3728

3829
import numpy as np
3930
import orjson as json
@@ -371,20 +362,8 @@ def __call__(self, cls):
371362
return clz
372363

373364

374-
@overload
375-
def tensorclass(autocast: bool = False, frozen: bool = False) -> _tensorclass_dec: ...
376-
377-
378-
@overload
379-
def tensorclass(cls: T) -> T: ...
380-
381-
382-
@overload
383-
def tensorclass(cls: T) -> T: ...
384-
385-
386365
@dataclass_transform()
387-
def tensorclass(*args, **kwargs):
366+
def tensorclass(cls=None, /, *, autocast: bool = False, frozen: bool = False):
388367
"""A decorator to create :obj:`tensorclass` classes.
389368
390369
``tensorclass`` classes are specialized :func:`dataclasses.dataclass` instances that
@@ -465,7 +444,17 @@ def tensorclass(*args, **kwargs):
465444
466445
467446
"""
468-
return _tensorclass_dec(*args, **kwargs)
447+
448+
def wrap(cls):
449+
return _tensorclass_dec(autocast, frozen)(cls)
450+
451+
# See if we're being called as @tensorclass or @tensorclass().
452+
if cls is None:
453+
# We're called with parens.
454+
return wrap
455+
456+
# We're called as @tensorclass without parens.
457+
return wrap(cls)
469458

470459

471460
@dataclass_transform()

0 commit comments

Comments
 (0)