Skip to content

Commit 3c6a52f

Browse files
Optional -> A | None
1 parent ac6fba6 commit 3c6a52f

File tree

16 files changed

+66
-73
lines changed

16 files changed

+66
-73
lines changed

scripts/parity/gen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
import string
3-
from typing import Callable, Optional, TypeVar
3+
from typing import Callable, TypeVar
44

55
A = TypeVar("A")
66

@@ -37,7 +37,7 @@ def gen_choice(choices: list[A]) -> Callable[[], A]:
3737
return lambda: random.choice(choices)
3838

3939

40-
def gen_opt(gen_x: Callable[[], A]) -> Callable[[], Optional[A]]:
40+
def gen_opt(gen_x: Callable[[], A]) -> Callable[[], A | None]:
4141
return lambda: gen_x() if gen_bool() else None
4242

4343

src/replit_river/client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from contextlib import contextmanager
44
from dataclasses import dataclass
55
from datetime import timedelta
6-
from typing import Any, AsyncGenerator, Generator, Generic, Literal, Optional, Union
6+
from typing import Any, AsyncGenerator, Generator, Generic, Literal, Union
77

88
from opentelemetry import trace
99
from opentelemetry.trace import Span, SpanKind, Status, StatusCode
@@ -100,9 +100,9 @@ async def send_upload(
100100
self,
101101
service_name: str,
102102
procedure_name: str,
103-
init: Optional[InitType],
103+
init: InitType | None,
104104
request: AsyncIterable[RequestType],
105-
init_serializer: Optional[Callable[[InitType], Any]],
105+
init_serializer: Callable[[InitType], Any] | None,
106106
request_serializer: Callable[[RequestType], Any],
107107
response_deserializer: Callable[[Any], ResponseType],
108108
error_deserializer: Callable[[Any], ErrorType],
@@ -151,9 +151,9 @@ async def send_stream(
151151
self,
152152
service_name: str,
153153
procedure_name: str,
154-
init: Optional[InitType],
154+
init: InitType | None,
155155
request: AsyncIterable[RequestType],
156-
init_serializer: Optional[Callable[[InitType], Any]],
156+
init_serializer: Callable[[InitType], Any] | None,
157157
request_serializer: Callable[[RequestType], Any],
158158
response_deserializer: Callable[[Any], ResponseType],
159159
error_deserializer: Callable[[Any], ErrorType],
@@ -186,7 +186,7 @@ class _SpanHandle:
186186
def set_status(
187187
self,
188188
status: Union[Status, StatusCode],
189-
description: Optional[str] = None,
189+
description: str | None = None,
190190
) -> None:
191191
if self.did_set_status:
192192
return

src/replit_river/client_session.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from collections.abc import AsyncIterable
44
from datetime import timedelta
5-
from typing import Any, AsyncGenerator, Callable, Optional, Union
5+
from typing import Any, AsyncGenerator, Callable, Union
66

77
import nanoid # type: ignore
88
from aiochannel import Channel
@@ -102,9 +102,9 @@ async def send_upload(
102102
self,
103103
service_name: str,
104104
procedure_name: str,
105-
init: Optional[InitType],
105+
init: InitType | None,
106106
request: AsyncIterable[RequestType],
107-
init_serializer: Optional[Callable[[InitType], Any]],
107+
init_serializer: Callable[[InitType], Any] | None,
108108
request_serializer: Callable[[RequestType], Any],
109109
response_deserializer: Callable[[Any], ResponseType],
110110
error_deserializer: Callable[[Any], ErrorType],
@@ -241,9 +241,9 @@ async def send_stream(
241241
self,
242242
service_name: str,
243243
procedure_name: str,
244-
init: Optional[InitType],
244+
init: InitType | None,
245245
request: AsyncIterable[RequestType],
246-
init_serializer: Optional[Callable[[InitType], Any]],
246+
init_serializer: Callable[[InitType], Any] | None,
247247
request_serializer: Callable[[RequestType], Any],
248248
response_deserializer: Callable[[Any], ResponseType],
249249
error_deserializer: Callable[[Any], ErrorType],

src/replit_river/client_transport.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33
from collections.abc import Awaitable, Callable
4-
from typing import Generic, Optional
4+
from typing import Generic
55

66
import websockets
77
from pydantic import ValidationError
@@ -98,7 +98,7 @@ async def get_or_create_session(self) -> ClientSession:
9898
await existing_session.close()
9999
return await self._create_new_session()
100100

101-
async def _get_existing_session(self) -> Optional[ClientSession]:
101+
async def _get_existing_session(self) -> ClientSession | None:
102102
async with self._session_lock:
103103
if not self._sessions:
104104
return None
@@ -117,7 +117,7 @@ async def _get_existing_session(self) -> Optional[ClientSession]:
117117

118118
async def _establish_new_connection(
119119
self,
120-
old_session: Optional[ClientSession] = None,
120+
old_session: ClientSession | None = None,
121121
) -> tuple[
122122
WebSocketCommonProtocol,
123123
ControlMessageHandshakeRequest[HandshakeMetadataType],
@@ -129,7 +129,7 @@ async def _establish_new_connection(
129129
client_id = self._client_id
130130
logger.info("Attempting to establish new ws connection")
131131

132-
last_error: Optional[Exception] = None
132+
last_error: Exception | None = None
133133
for i in range(max_retry):
134134
if i > 0:
135135
logger.info(f"Retrying build handshake number {i} times")
@@ -221,7 +221,7 @@ async def _send_handshake_request(
221221
transport_id: str,
222222
to_id: str,
223223
session_id: str,
224-
handshake_metadata: Optional[HandshakeMetadataType],
224+
handshake_metadata: HandshakeMetadataType | None,
225225
websocket: WebSocketCommonProtocol,
226226
expected_session_state: ExpectedSessionState,
227227
) -> ControlMessageHandshakeRequest[HandshakeMetadataType]:
@@ -291,7 +291,7 @@ async def _establish_handshake(
291291
session_id: str,
292292
handshake_metadata: HandshakeMetadataType,
293293
websocket: WebSocketCommonProtocol,
294-
old_session: Optional[ClientSession],
294+
old_session: ClientSession | None,
295295
) -> tuple[
296296
ControlMessageHandshakeRequest[HandshakeMetadataType],
297297
ControlMessageHandshakeResponse,

src/replit_river/codegen/client.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Any,
88
Callable,
99
Literal,
10-
Optional,
1110
OrderedDict,
1211
Sequence,
1312
Set,
@@ -76,7 +75,6 @@
7675
from typing import (
7776
Any,
7877
Literal,
79-
Optional,
8078
Mapping,
8179
NotRequired,
8280
Union,
@@ -96,11 +94,11 @@
9694

9795

9896
class RiverConcreteType(BaseModel):
99-
type: Optional[str] = Field(default=None)
97+
type: str | None = Field(default=None)
10098
properties: dict[str, "RiverType"] = Field(default_factory=lambda: dict())
10199
required: Set[str] = Field(default=set())
102-
items: Optional["RiverType"] = Field(default=None)
103-
const: Optional[Union[str, int]] = Field(default=None)
100+
items: "RiverType | None" = Field(default=None)
101+
const: Union[str, int] | None = Field(default=None)
104102
patternProperties: dict[str, "RiverType"] = Field(default_factory=lambda: dict())
105103

106104

@@ -124,14 +122,14 @@ class RiverNotType(BaseModel):
124122

125123

126124
class RiverProcedure(BaseModel):
127-
init: Optional[RiverType] = Field(default=None)
125+
init: RiverType | None = Field(default=None)
128126
input: RiverType
129127
output: RiverType
130-
errors: Optional[RiverType] = Field(default=None)
128+
errors: RiverType | None = Field(default=None)
131129
type: (
132130
Literal["rpc"] | Literal["stream"] | Literal["subscription"] | Literal["upload"]
133131
)
134-
description: Optional[str] = Field(default=None)
132+
description: str | None = Field(default=None)
135133

136134

137135
class RiverService(BaseModel):
@@ -140,7 +138,7 @@ class RiverService(BaseModel):
140138

141139
class RiverSchema(BaseModel):
142140
services: dict[str, RiverService]
143-
handshakeSchema: Optional[RiverConcreteType] = Field(default=None)
141+
handshakeSchema: RiverConcreteType | None = Field(default=None)
144142

145143

146144
RiverSchemaFile = RootModel[RiverSchema]
@@ -640,7 +638,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
640638
"""
641639
)
642640
current_chunks.append(
643-
f" kind: Optional[{render_type_expr(type_name)}]{value}"
641+
f" kind: {render_type_expr(type_name)} | None{value}"
644642
)
645643
else:
646644
value = ""
@@ -663,7 +661,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
663661
reindent(
664662
" ",
665663
f"""\
666-
{name}: NotRequired[Optional[{render_type_expr(type_name)}]]
664+
{name}: NotRequired[{render_type_expr(type_name)}] | None
667665
""",
668666
)
669667
)
@@ -672,7 +670,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
672670
reindent(
673671
" ",
674672
f"""\
675-
{name}: Optional[{render_type_expr(type_name)}] = None
673+
{name}: {render_type_expr(type_name)} | None = None
676674
""",
677675
)
678676
)
@@ -772,7 +770,7 @@ def __init__(self, client: river.Client[Any]):
772770
]
773771
for name, procedure in schema.procedures.items():
774772
module_names = [ModuleName(name)]
775-
init_type: Optional[TypeExpression] = None
773+
init_type: TypeExpression | None = None
776774
if procedure.init:
777775
init_type, module_info, init_chunks, encoder_names = encode_type(
778776
procedure.init,
@@ -852,7 +850,7 @@ def __init__(self, client: river.Client[Any]):
852850
"""
853851

854852
# Init renderer
855-
render_init_method: Optional[str] = None
853+
render_init_method: str | None = None
856854
if init_type and procedure.init is not None:
857855
if input_base_class == "TypedDict":
858856
if is_literal(procedure.init):
@@ -878,7 +876,7 @@ def __init__(self, client: river.Client[Any]):
878876
)
879877

880878
# Input renderer
881-
render_input_method: Optional[str] = None
879+
render_input_method: str | None = None
882880
if input_base_class == "TypedDict":
883881
if is_literal(procedure.input):
884882
render_input_method = "lambda x: x"

src/replit_river/error_schema.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional
1+
from typing import Any
22

33
from pydantic import BaseModel
44

@@ -38,7 +38,7 @@ class RiverServiceException(RiverException):
3838
"""Exception raised by river as a result of a fault in the service running river."""
3939

4040
def __init__(
41-
self, code: str, message: str, service: Optional[str], procedure: Optional[str]
41+
self, code: str, message: str, service: str | None, procedure: str | None
4242
) -> None:
4343
self.code = code
4444
self.message = message
@@ -92,7 +92,7 @@ def stringify_exception(e: BaseException, limit: int = 10) -> str:
9292
# If there are no causes, just fall back to stringifying the exception.
9393
return str(e)
9494
causes: list[str] = []
95-
cause: Optional[BaseException] = e
95+
cause: BaseException | None = e
9696
while cause and limit:
9797
causes.append(str(cause))
9898
cause = cause.__cause__

src/replit_river/message_buffer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import logging
3-
from typing import Optional
43

54
from replit_river.rpc import TransportMessage
65
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
@@ -41,7 +40,7 @@ async def put(self, message: TransportMessage) -> None:
4140
raise MessageBufferClosedError("message buffer is closed")
4241
self.buffer.append(message)
4342

44-
async def peek(self) -> Optional[TransportMessage]:
43+
async def peek(self) -> TransportMessage | None:
4544
"""Peek the first message in the buffer, returns None if the buffer is empty."""
4645
async with self._lock:
4746
if len(self.buffer) == 0:

src/replit_river/rpc.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
Literal,
1313
Mapping,
1414
NoReturn,
15-
Optional,
1615
Sequence,
1716
TypeVar,
1817
Union,
@@ -63,22 +62,22 @@
6362
# Equivalent of https://github.com/replit/river/blob/c1345f1ff6a17a841d4319fad5c153b5bda43827/transport/message.ts#L23-L33
6463
class ExpectedSessionState(BaseModel):
6564
nextExpectedSeq: int
66-
nextSentSeq: Optional[int] = None
65+
nextSentSeq: int | None = None
6766

6867

6968
class ControlMessageHandshakeRequest(BaseModel, Generic[HandshakeMetadataType]):
7069
type: Literal["HANDSHAKE_REQ"] = "HANDSHAKE_REQ"
7170
protocolVersion: str
7271
sessionId: str
7372
expectedSessionState: ExpectedSessionState
74-
metadata: Optional[HandshakeMetadataType] = None
73+
metadata: HandshakeMetadataType | None = None
7574

7675

7776
class HandShakeStatus(BaseModel):
7877
ok: bool
79-
sessionId: Optional[str] = None
80-
reason: Optional[str] = None
81-
code: Optional[str] = None
78+
sessionId: str | None = None
79+
reason: str | None = None
80+
code: str | None = None
8281

8382

8483
class ControlMessageHandshakeResponse(BaseModel):
@@ -98,11 +97,11 @@ class TransportMessage(BaseModel):
9897
to: str
9998
seq: int
10099
ack: int
101-
serviceName: Optional[str] = None
102-
procedureName: Optional[str] = None
100+
serviceName: str | None = None
101+
procedureName: str | None = None
103102
streamId: str
104103
controlFlags: int
105-
tracing: Optional[PropagationContext] = None
104+
tracing: PropagationContext | None = None
106105
payload: Any
107106
model_config = ConfigDict(populate_by_name=True)
108107
# need this because we create TransportMessage objects with destructuring
@@ -131,8 +130,8 @@ class GrpcContext(grpc.aio.ServicerContext, Generic[RequestType, ResponseType]):
131130

132131
def __init__(self, peer: str) -> None:
133132
self._peer = peer
134-
self._abort_code: Optional[grpc.StatusCode] = None
135-
self._abort_details: Optional[str] = None
133+
self._abort_code: grpc.StatusCode | None = None
134+
self._abort_details: str | None = None
136135

137136
async def abort(
138137
self,
@@ -157,10 +156,10 @@ def invocation_metadata(self) -> None:
157156
def peer(self) -> str:
158157
return self._peer
159158

160-
def peer_identities(self) -> Optional[Iterable[bytes]]:
159+
def peer_identities(self) -> Iterable[bytes] | None:
161160
return None
162161

163-
def peer_identity_key(self) -> Optional[str]:
162+
def peer_identity_key(self) -> str | None:
164163
return None
165164

166165
async def read(self) -> RequestType:

src/replit_river/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
from typing import Mapping, Optional
3+
from typing import Mapping
44

55
import websockets
66
from websockets.exceptions import ConnectionClosed
@@ -41,7 +41,7 @@ def add_rpc_handlers(
4141

4242
async def _handshake_to_get_session(
4343
self, websocket: WebSocketServerProtocol
44-
) -> Optional[Session]:
44+
) -> Session | None:
4545
"""This is a wrapper to make sentry happy, sentry doesn't recognize the
4646
exception handling outside of a task or asyncio.wait_for. So we need to catch
4747
the errors specifically here.

0 commit comments

Comments
 (0)