Skip to content

Commit abfb542

Browse files
committed
Random files
1 parent ac5c176 commit abfb542

File tree

3 files changed

+1654
-0
lines changed

3 files changed

+1654
-0
lines changed

python/arithmetic_ops.py

Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
from typing import Any
2+
from typing_extensions import assert_type
3+
4+
from torch import randn, Tensor
5+
6+
7+
TENSOR, INT, FLOAT, BOOL = randn(3), 2, 1.5, True
8+
9+
# Test deduced types of arithmetic operations between tensors, ints, floats and bools
10+
# The expected type should always be `Tensor`: `Any` and `bool` below are wrong.
11+
# See https://github.com/pytorch/pytorch/issues/145838
12+
13+
# Unary ops
14+
15+
assert_type(+TENSOR, Tensor)
16+
assert_type(-TENSOR, Tensor)
17+
assert_type(~TENSOR, Tensor)
18+
19+
# Binary ops
20+
21+
assert_type(TENSOR == TENSOR, Tensor)
22+
assert_type(TENSOR != TENSOR, Tensor)
23+
assert_type(TENSOR < TENSOR, Tensor)
24+
assert_type(TENSOR > TENSOR, Tensor)
25+
assert_type(TENSOR <= TENSOR, Tensor)
26+
assert_type(TENSOR >= TENSOR, Tensor)
27+
assert_type(TENSOR + TENSOR, Tensor)
28+
assert_type(TENSOR - TENSOR, Tensor)
29+
assert_type(TENSOR * TENSOR, Tensor)
30+
assert_type(TENSOR // TENSOR, Any)
31+
assert_type(TENSOR / TENSOR, Tensor)
32+
assert_type(TENSOR % TENSOR, Tensor)
33+
assert_type(TENSOR**TENSOR, Any)
34+
assert_type(TENSOR << TENSOR, Tensor)
35+
assert_type(TENSOR >> TENSOR, Tensor)
36+
assert_type(TENSOR & TENSOR, Tensor)
37+
assert_type(TENSOR | TENSOR, Tensor)
38+
assert_type(TENSOR ^ TENSOR, Tensor)
39+
40+
assert_type(TENSOR == BOOL, Tensor)
41+
assert_type(TENSOR != BOOL, Tensor)
42+
assert_type(TENSOR < BOOL, Tensor)
43+
assert_type(TENSOR > BOOL, Tensor)
44+
assert_type(TENSOR <= BOOL, Tensor)
45+
assert_type(TENSOR >= BOOL, Tensor)
46+
assert_type(TENSOR + BOOL, Tensor)
47+
assert_type(TENSOR - BOOL, Tensor)
48+
assert_type(TENSOR * BOOL, Tensor)
49+
assert_type(TENSOR // BOOL, Any)
50+
assert_type(TENSOR / BOOL, Tensor)
51+
assert_type(TENSOR % BOOL, Tensor)
52+
assert_type(TENSOR**BOOL, Any)
53+
assert_type(TENSOR << BOOL, Tensor)
54+
assert_type(TENSOR >> BOOL, Tensor)
55+
assert_type(TENSOR & BOOL, Tensor)
56+
assert_type(TENSOR | BOOL, Tensor)
57+
assert_type(TENSOR ^ BOOL, Tensor)
58+
59+
assert_type(BOOL == TENSOR, bool)
60+
assert_type(BOOL != TENSOR, bool)
61+
assert_type(BOOL < TENSOR, Tensor)
62+
assert_type(BOOL > TENSOR, Tensor)
63+
assert_type(BOOL <= TENSOR, Tensor)
64+
assert_type(BOOL >= TENSOR, Tensor)
65+
assert_type(BOOL + TENSOR, Tensor)
66+
assert_type(BOOL - TENSOR, Any)
67+
assert_type(BOOL * TENSOR, Tensor)
68+
assert_type(BOOL // TENSOR, Any)
69+
assert_type(BOOL / TENSOR, Any)
70+
assert_type(BOOL % TENSOR, Any)
71+
assert_type(BOOL**TENSOR, Any)
72+
assert_type(BOOL << TENSOR, Any)
73+
assert_type(BOOL >> TENSOR, Any)
74+
assert_type(BOOL & TENSOR, Tensor)
75+
assert_type(BOOL | TENSOR, Tensor)
76+
assert_type(BOOL ^ TENSOR, Tensor)
77+
78+
assert_type(TENSOR == INT, Tensor)
79+
assert_type(TENSOR != INT, Tensor)
80+
assert_type(TENSOR < INT, Tensor)
81+
assert_type(TENSOR > INT, Tensor)
82+
assert_type(TENSOR <= INT, Tensor)
83+
assert_type(TENSOR >= INT, Tensor)
84+
assert_type(TENSOR + INT, Tensor)
85+
assert_type(TENSOR - INT, Tensor)
86+
assert_type(TENSOR * INT, Tensor)
87+
assert_type(TENSOR // INT, Any)
88+
assert_type(TENSOR / INT, Tensor)
89+
assert_type(TENSOR % INT, Tensor)
90+
assert_type(TENSOR**INT, Any)
91+
assert_type(TENSOR << INT, Tensor)
92+
assert_type(TENSOR >> INT, Tensor)
93+
assert_type(TENSOR & INT, Tensor)
94+
assert_type(TENSOR | INT, Tensor)
95+
assert_type(TENSOR ^ INT, Tensor)
96+
97+
assert_type(INT == TENSOR, bool)
98+
assert_type(INT != TENSOR, bool)
99+
assert_type(INT < TENSOR, Tensor)
100+
assert_type(INT > TENSOR, Tensor)
101+
assert_type(INT <= TENSOR, Tensor)
102+
assert_type(INT >= TENSOR, Tensor)
103+
assert_type(INT + TENSOR, Tensor)
104+
assert_type(INT - TENSOR, Any)
105+
assert_type(INT * TENSOR, Tensor)
106+
assert_type(INT // TENSOR, Any)
107+
assert_type(INT / TENSOR, Any)
108+
assert_type(INT % TENSOR, Any)
109+
assert_type(INT**TENSOR, Any)
110+
assert_type(INT << TENSOR, Any)
111+
assert_type(INT >> TENSOR, Any)
112+
assert_type(INT & TENSOR, Any) # type: ignore[operator]
113+
assert_type(INT | TENSOR, Any) # type: ignore[operator]
114+
assert_type(INT ^ TENSOR, Any) # type: ignore[operator]
115+
116+
assert_type(TENSOR == FLOAT, Tensor)
117+
assert_type(TENSOR != FLOAT, Tensor)
118+
assert_type(TENSOR < FLOAT, Tensor)
119+
assert_type(TENSOR > FLOAT, Tensor)
120+
assert_type(TENSOR <= FLOAT, Tensor)
121+
assert_type(TENSOR >= FLOAT, Tensor)
122+
assert_type(TENSOR + FLOAT, Tensor)
123+
assert_type(TENSOR - FLOAT, Tensor)
124+
assert_type(TENSOR * FLOAT, Tensor)
125+
assert_type(TENSOR // FLOAT, Any)
126+
assert_type(TENSOR / FLOAT, Tensor)
127+
assert_type(TENSOR % FLOAT, Tensor)
128+
assert_type(TENSOR**FLOAT, Any)
129+
assert_type(TENSOR << FLOAT, Tensor)
130+
assert_type(TENSOR >> FLOAT, Tensor)
131+
assert_type(TENSOR & FLOAT, Tensor)
132+
assert_type(TENSOR | FLOAT, Tensor)
133+
assert_type(TENSOR ^ FLOAT, Tensor)
134+
135+
assert_type(FLOAT == TENSOR, bool)
136+
assert_type(FLOAT != TENSOR, bool)
137+
assert_type(FLOAT < TENSOR, Tensor)
138+
assert_type(FLOAT > TENSOR, Tensor)
139+
assert_type(FLOAT <= TENSOR, Tensor)
140+
assert_type(FLOAT >= TENSOR, Tensor)
141+
assert_type(FLOAT + TENSOR, Tensor)
142+
assert_type(FLOAT - TENSOR, Any)
143+
assert_type(FLOAT * TENSOR, Tensor)
144+
assert_type(FLOAT // TENSOR, Any)
145+
assert_type(FLOAT / TENSOR, Any)
146+
assert_type(FLOAT % TENSOR, Any)
147+
assert_type(FLOAT**TENSOR, Any)
148+
assert_type(FLOAT << TENSOR, Any)
149+
assert_type(FLOAT >> TENSOR, Any)
150+
assert_type(FLOAT & TENSOR, Tensor) # type: ignore[operator]
151+
assert_type(FLOAT | TENSOR, Tensor) # type: ignore[operator]
152+
assert_type(FLOAT ^ TENSOR, Tensor) # type: ignore[operator]
153+
154+
155+
class Binary:
156+
def __add__(self, other: int) -> "Binary": # type: ignore[override]
157+
return self
158+
159+
def __and__(self, other: int) -> "Binary": # type: ignore[override]
160+
return self
161+
162+
def __ceil__(self, other: int) -> "Binary": # type: ignore[override]
163+
return self
164+
165+
def __cmp__(self, other: int) -> "Binary": # type: ignore[override]
166+
return self
167+
168+
def __div__(self, other: int) -> "Binary": # type: ignore[override]
169+
return self
170+
171+
def __divmod__(self, other: int) -> "Binary": # type: ignore[override]
172+
return self
173+
174+
def __eq__(self, other: int) -> "Binary": # type: ignore[override]
175+
return self
176+
177+
def __floor__(self, other: int) -> "Binary": # type: ignore[override]
178+
return self
179+
180+
def __floordiv__(self, other: int) -> "Binary": # type: ignore[override]
181+
return self
182+
183+
def __ge__(self, other: int) -> "Binary": # type: ignore[override]
184+
return self
185+
186+
def __gt__(self, other: int) -> "Binary": # type: ignore[override]
187+
return self
188+
189+
def __iadd__(self, other: int) -> "Binary": # type: ignore[override]
190+
return self
191+
192+
def __iand__(self, other: int) -> "Binary": # type: ignore[override]
193+
return self
194+
195+
def __idiv__(self, other: int) -> "Binary": # type: ignore[override]
196+
return self
197+
198+
def __idivmod__(self, other: int) -> "Binary": # type: ignore[override]
199+
return self
200+
201+
def __ifloordiv__(self, other: int) -> "Binary": # type: ignore[override]
202+
return self
203+
204+
def __ilshift__(self, other: int) -> "Binary": # type: ignore[override]
205+
return self
206+
207+
def __imod__(self, other: int) -> "Binary": # type: ignore[override]
208+
return self
209+
210+
def __imul__(self, other: int) -> "Binary": # type: ignore[override]
211+
return self
212+
213+
def __invert__(self, other: int) -> "Binary": # type: ignore[override]
214+
return self
215+
216+
def __ior__(self, other: int) -> "Binary": # type: ignore[override]
217+
return self
218+
219+
def __ipow__(self, other: int) -> "Binary": # type: ignore[override]
220+
return self
221+
222+
def __irshift__(self, other: int) -> "Binary": # type: ignore[override]
223+
return self
224+
225+
def __isub__(self, other: int) -> "Binary": # type: ignore[override]
226+
return self
227+
228+
def __itruediv__(self, other: int) -> "Binary": # type: ignore[override]
229+
return self
230+
231+
def __ixor__(self, other: int) -> "Binary": # type: ignore[override]
232+
return self
233+
234+
def __le__(self, other: int) -> "Binary": # type: ignore[override]
235+
return self
236+
237+
def __lshift__(self, other: int) -> "Binary": # type: ignore[override]
238+
return self
239+
240+
def __lt__(self, other: int) -> "Binary": # type: ignore[override]
241+
return self
242+
243+
def __mod__(self, other: int) -> "Binary": # type: ignore[override]
244+
return self
245+
246+
def __mul__(self, other: int) -> "Binary": # type: ignore[override]
247+
return self
248+
249+
def __ne__(self, other: int) -> "Binary": # type: ignore[override]
250+
return self
251+
252+
def __or__(self, other: int) -> "Binary": # type: ignore[override]
253+
return self
254+
255+
def __pow__(self, other: int) -> "Binary": # type: ignore[override]
256+
return self
257+
258+
def __radd__(self, other: int) -> "Binary": # type: ignore[override]
259+
return self
260+
261+
def __rand__(self, other: int) -> "Binary": # type: ignore[override]
262+
return self
263+
264+
def __rdiv__(self, other: int) -> "Binary": # type: ignore[override]
265+
return self
266+
267+
def __rdivmod__(self, other: int) -> "Binary": # type: ignore[override]
268+
return self
269+
270+
def __rfloordiv__(self, other: int) -> "Binary": # type: ignore[override]
271+
return self
272+
273+
def __rlshift__(self, other: int) -> "Binary": # type: ignore[override]
274+
return self
275+
276+
def __rmod__(self, other: int) -> "Binary": # type: ignore[override]
277+
return self
278+
279+
def __rmul__(self, other: int) -> "Binary": # type: ignore[override]
280+
return self
281+
282+
def __ror__(self, other: int) -> "Binary": # type: ignore[override]
283+
return self
284+
285+
def __round__(self, other: int) -> "Binary": # type: ignore[override]
286+
return self
287+
288+
def __rpow__(self, other: int) -> "Binary": # type: ignore[override]
289+
return self
290+
291+
def __rrshift__(self, other: int) -> "Binary": # type: ignore[override]
292+
return self
293+
294+
def __rshift__(self, other: int) -> "Binary": # type: ignore[override]
295+
return self
296+
297+
def __rsub__(self, other: int) -> "Binary": # type: ignore[override]
298+
return self
299+
300+
def __rtruediv__(self, other: int) -> "Binary": # type: ignore[override]
301+
return self
302+
303+
def __rxor__(self, other: int) -> "Binary": # type: ignore[override]
304+
return self
305+
306+
def __sub__(self, other: int) -> "Binary": # type: ignore[override]
307+
return self
308+
309+
def __truediv__(self, other: int) -> "Binary": # type: ignore[override]
310+
return self
311+
312+
def __xor__(self, other: int) -> "Binary": # type: ignore[override]
313+
return self
314+
315+
316+
assert_type(Binary() + 5, Binary)
317+
assert_type(Binary() & 5, Binary)
318+
assert_type(Binary() / 5, Binary)
319+
assert_type(Binary() == 5, Binary)
320+
assert_type(Binary() // 5, Binary)
321+
assert_type(Binary() >= 5, Binary)
322+
assert_type(Binary() > 5, Binary)
323+
assert_type(Binary() <= 5, Binary)
324+
assert_type(Binary() << 5, Binary)
325+
assert_type(Binary() < 5, Binary)
326+
assert_type(Binary() % 5, Binary)
327+
assert_type(Binary() * 5, Binary)
328+
assert_type(Binary() != 5, Binary)
329+
assert_type(Binary() | 5, Binary)
330+
assert_type(Binary()**5, Binary)
331+
assert_type(Binary() >> 5, Binary)
332+
assert_type(Binary() - 5, Binary)
333+
assert_type(Binary() ^ 5, Binary)
334+
335+
assert_type(5 + Binary(), Binary)
336+
assert_type(5 & Binary(), Binary)
337+
assert_type(5 / Binary(), Binary)
338+
assert_type(5 == Binary(), bool)
339+
assert_type(5 // Binary(), Binary)
340+
assert_type(5 >= Binary(), Binary)
341+
assert_type(5 > Binary(), Binary)
342+
assert_type(5 <= Binary(), Binary)
343+
assert_type(5 << Binary(), Binary)
344+
assert_type(5 < Binary(), Binary)
345+
assert_type(5 % Binary(), Binary)
346+
assert_type(5 * Binary(), Binary)
347+
assert_type(5 != Binary(), bool)
348+
assert_type(5 | Binary(), Binary)
349+
assert_type(5**Binary(), Binary)
350+
assert_type(5 >> Binary(), Binary)
351+
assert_type(5 - Binary(), Binary)
352+
assert_type(5 ^ Binary(), Binary)

python/generic_parent.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Type, Generic, get_args, TypeVar
4+
T = TypeVar('T')
5+
6+
7+
class NameableType(Generic[T]):
8+
@classmethod
9+
def type(cls) -> Type[T]:
10+
return get_args(cls.__orig_bases__[0])[0]
11+
12+
13+
class StrType(NameableType[str]):
14+
pass
15+
16+
17+
class IntType(NameableType[int]):
18+
pass
19+
20+
21+
print(StrType.type(), IntType.type())

0 commit comments

Comments
 (0)