Skip to content

Commit

Permalink
♻️ Refactor types to detect invalid extra arguments (ray-project#25541)
Browse files Browse the repository at this point in the history
Currently, each function decorated with `@ray.remote` is marked with type annotations as a `RemoteFunction` class (only used for type annotations, autocompletion, inline errors, etc). The current class takes several *type parameters*. And then it uses those parameters in the extended `func.remote()` method.

But with the current type annotations, it marks any of the unused type parameters as `None`. This means that calling the `.remote()` method would check the first (actual) arguments and the rest are marked as `None`, but that means that for type annotations it considers "correct" to pass extra `None` arguments, while actually, that would not be valid. So, this doesn't show an error, but it should:

<img width="371" alt="Screenshot 2022-06-07 at 05 38 48" src="https://user-images.githubusercontent.com/1326112/172360355-9b344220-7824-4b5c-87da-038f5b53fe04.png">

...those 2 extra `None` values should be marked as invalid.

After this PR, those invalid extra arguments would be marked as invalid:

<img width="588" alt="Screenshot 2022-06-07 at 05 42 10" src="https://user-images.githubusercontent.com/1326112/172360956-424b40d4-8197-4663-8298-617a1df37658.png">

And:

<img width="687" alt="Screenshot 2022-06-07 at 05 42 50" src="https://user-images.githubusercontent.com/1326112/172361140-eb93c675-f5d6-4e0c-b9b2-83c4801bb450.png">

## More context

I also tried the new `TypeVarTuple`, it might simplify these type annotations in the future, but it's not currently supported by mypy yet, it's a very recent addition to the language (and `typing_extensions`) so it's probably too early to adopt it.
  • Loading branch information
tiangolo authored Jun 7, 2022
1 parent 3876fcd commit 3257994
Showing 1 changed file with 77 additions and 40 deletions.
117 changes: 77 additions & 40 deletions python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,27 +115,43 @@
R = TypeVar("R")


class RemoteFunction(Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]):
def __init__(
self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]
) -> None:
class RemoteFunctionNoArgs(Generic[R]):
def __init__(self, function: Callable[[], R]) -> None:
pass

@overload
def remote(self) -> "ObjectRef[R]":
def remote(
self,
) -> "ObjectRef[R]":
...

@overload
def remote(self, __arg0: "Union[T0, ObjectRef[T0]]") -> "ObjectRef[R]":

class RemoteFunction0(Generic[R, T0]):
def __init__(self, function: Callable[[T0], R]) -> None:
pass

def remote(
self,
__arg0: "Union[T0, ObjectRef[T0]]",
) -> "ObjectRef[R]":
...

@overload

class RemoteFunction1(Generic[R, T0, T1]):
def __init__(self, function: Callable[[T0, T1], R]) -> None:
pass

def remote(
self, __arg0: "Union[T0, ObjectRef[T0]]", __arg1: "Union[T1, ObjectRef[T1]]"
self,
__arg0: "Union[T0, ObjectRef[T0]]",
__arg1: "Union[T1, ObjectRef[T1]]",
) -> "ObjectRef[R]":
...

@overload

class RemoteFunction2(Generic[R, T0, T1, T2]):
def __init__(self, function: Callable[[T0, T1, T2], R]) -> None:
pass

def remote(
self,
__arg0: "Union[T0, ObjectRef[T0]]",
Expand All @@ -144,7 +160,11 @@ def remote(
) -> "ObjectRef[R]":
...

@overload

class RemoteFunction3(Generic[R, T0, T1, T2, T3]):
def __init__(self, function: Callable[[T0, T1, T2, T3], R]) -> None:
pass

def remote(
self,
__arg0: "Union[T0, ObjectRef[T0]]",
Expand All @@ -154,7 +174,11 @@ def remote(
) -> "ObjectRef[R]":
...

@overload

class RemoteFunction4(Generic[R, T0, T1, T2, T3, T4]):
def __init__(self, function: Callable[[T0, T1, T2, T3, T4], R]) -> None:
pass

def remote(
self,
__arg0: "Union[T0, ObjectRef[T0]]",
Expand All @@ -165,7 +189,11 @@ def remote(
) -> "ObjectRef[R]":
...

@overload

class RemoteFunction5(Generic[R, T0, T1, T2, T3, T4, T5]):
def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> None:
pass

def remote(
self,
__arg0: "Union[T0, ObjectRef[T0]]",
Expand All @@ -177,7 +205,11 @@ def remote(
) -> "ObjectRef[R]":
...

@overload

class RemoteFunction6(Generic[R, T0, T1, T2, T3, T4, T5, T6]):
def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> None:
pass

def remote(
self,
__arg0: "Union[T0, ObjectRef[T0]]",
Expand All @@ -190,7 +222,11 @@ def remote(
) -> "ObjectRef[R]":
...

@overload

class RemoteFunction7(Generic[R, T0, T1, T2, T3, T4, T5, T6, T7]):
def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> None:
pass

def remote(
self,
__arg0: "Union[T0, ObjectRef[T0]]",
Expand All @@ -204,7 +240,13 @@ def remote(
) -> "ObjectRef[R]":
...

@overload

class RemoteFunction8(Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]):
def __init__(
self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]
) -> None:
pass

def remote(
self,
__arg0: "Union[T0, ObjectRef[T0]]",
Expand All @@ -219,7 +261,13 @@ def remote(
) -> "ObjectRef[R]":
...

@overload

class RemoteFunction9(Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]):
def __init__(
self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]
) -> None:
pass

def remote(
self,
__arg0: "Union[T0, ObjectRef[T0]]",
Expand All @@ -235,9 +283,6 @@ def remote(
) -> "ObjectRef[R]":
...

def remote(self, *args, **kwargs) -> "ObjectRef[R]":
...


# Visible for testing.
def _unhandled_error_handler(e: Exception):
Expand Down Expand Up @@ -2347,79 +2392,71 @@ def _make_remote(function_or_class, options):


@overload
def remote(
function: Callable[[], R]
) -> RemoteFunction[R, None, None, None, None, None, None, None, None, None, None]:
def remote(function: Callable[[], R]) -> RemoteFunctionNoArgs[R]:
...


@overload
def remote(
function: Callable[[T0], R]
) -> RemoteFunction[R, T0, None, None, None, None, None, None, None, None, None]:
def remote(function: Callable[[T0], R]) -> RemoteFunction0[R, T0]:
...


@overload
def remote(
function: Callable[[T0, T1], R]
) -> RemoteFunction[R, T0, T1, None, None, None, None, None, None, None, None]:
def remote(function: Callable[[T0, T1], R]) -> RemoteFunction1[R, T0, T1]:
...


@overload
def remote(
function: Callable[[T0, T1, T2], R]
) -> RemoteFunction[R, T0, T1, T2, None, None, None, None, None, None, None]:
def remote(function: Callable[[T0, T1, T2], R]) -> RemoteFunction2[R, T0, T1, T2]:
...


@overload
def remote(
function: Callable[[T0, T1, T2, T3], R]
) -> RemoteFunction[R, T0, T1, T2, T3, None, None, None, None, None, None]:
) -> RemoteFunction3[R, T0, T1, T2, T3]:
...


@overload
def remote(
function: Callable[[T0, T1, T2, T3, T4], R]
) -> RemoteFunction[R, T0, T1, T2, T3, T4, None, None, None, None, None]:
) -> RemoteFunction4[R, T0, T1, T2, T3, T4]:
...


@overload
def remote(
function: Callable[[T0, T1, T2, T3, T4, T5], R]
) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, None, None, None, None]:
) -> RemoteFunction5[R, T0, T1, T2, T3, T4, T5]:
...


@overload
def remote(
function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]
) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, None, None, None]:
) -> RemoteFunction6[R, T0, T1, T2, T3, T4, T5, T6]:
...


@overload
def remote(
function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]
) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, None, None]:
) -> RemoteFunction7[R, T0, T1, T2, T3, T4, T5, T6, T7]:
...


@overload
def remote(
function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]
) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, None]:
) -> RemoteFunction8[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]:
...


@overload
def remote(
function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]
) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]:
) -> RemoteFunction9[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]:
...


Expand Down

0 comments on commit 3257994

Please sign in to comment.