Skip to content

Commit a3693ce

Browse files
committed
capture: improve DontReadFromInput typing
Have `DontReadFromInput` inherit from `TextIO`, ensuring it's fully compatible with `sys.stdin` (which has type `TextIO`).
1 parent 7d4b403 commit a3693ce

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

src/_pytest/capture.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,20 @@
66
import sys
77
from io import UnsupportedOperation
88
from tempfile import TemporaryFile
9+
from types import TracebackType
910
from typing import Any
1011
from typing import AnyStr
12+
from typing import BinaryIO
1113
from typing import Generator
1214
from typing import Generic
15+
from typing import Iterable
16+
from typing import Iterator
17+
from typing import List
1318
from typing import NamedTuple
1419
from typing import Optional
1520
from typing import TextIO
1621
from typing import Tuple
22+
from typing import Type
1723
from typing import TYPE_CHECKING
1824
from typing import Union
1925

@@ -185,19 +191,27 @@ def write(self, s: str) -> int:
185191
return self._other.write(s)
186192

187193

188-
class DontReadFromInput:
189-
encoding = None
194+
class DontReadFromInput(TextIO):
195+
@property
196+
def encoding(self) -> str:
197+
return sys.__stdin__.encoding
190198

191-
def read(self, *args):
199+
def read(self, size: int = -1) -> str:
192200
raise OSError(
193201
"pytest: reading from stdin while output is captured! Consider using `-s`."
194202
)
195203

196204
readline = read
197-
readlines = read
198-
__next__ = read
199205

200-
def __iter__(self):
206+
def __next__(self) -> str:
207+
return self.readline()
208+
209+
def readlines(self, hint: Optional[int] = -1) -> List[str]:
210+
raise OSError(
211+
"pytest: reading from stdin while output is captured! Consider using `-s`."
212+
)
213+
214+
def __iter__(self) -> Iterator[str]:
201215
return self
202216

203217
def fileno(self) -> int:
@@ -215,7 +229,7 @@ def close(self) -> None:
215229
def readable(self) -> bool:
216230
return False
217231

218-
def seek(self, offset: int) -> int:
232+
def seek(self, offset: int, whence: int = 0) -> int:
219233
raise UnsupportedOperation("redirected stdin is pseudofile, has no seek(int)")
220234

221235
def seekable(self) -> bool:
@@ -224,22 +238,34 @@ def seekable(self) -> bool:
224238
def tell(self) -> int:
225239
raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()")
226240

227-
def truncate(self, size: int) -> None:
241+
def truncate(self, size: Optional[int] = None) -> int:
228242
raise UnsupportedOperation("cannont truncate stdin")
229243

230-
def write(self, *args) -> None:
244+
def write(self, data: str) -> int:
231245
raise UnsupportedOperation("cannot write to stdin")
232246

233-
def writelines(self, *args) -> None:
247+
def writelines(self, lines: Iterable[str]) -> None:
234248
raise UnsupportedOperation("Cannot write to stdin")
235249

236250
def writable(self) -> bool:
237251
return False
238252

239-
@property
240-
def buffer(self):
253+
def __enter__(self) -> "DontReadFromInput":
241254
return self
242255

256+
def __exit__(
257+
self,
258+
type: Optional[Type[BaseException]],
259+
value: Optional[BaseException],
260+
traceback: Optional[TracebackType],
261+
) -> None:
262+
pass
263+
264+
@property
265+
def buffer(self) -> BinaryIO:
266+
# The str/bytes doesn't actually matter in this type, so OK to fake.
267+
return self # type: ignore[return-value]
268+
243269

244270
# Capture classes.
245271

testing/test_capture.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ def test_dontreadfrominput() -> None:
890890
from _pytest.capture import DontReadFromInput
891891

892892
f = DontReadFromInput()
893-
assert f.buffer is f
893+
assert f.buffer is f # type: ignore[comparison-overlap]
894894
assert not f.isatty()
895895
pytest.raises(OSError, f.read)
896896
pytest.raises(OSError, f.readlines)
@@ -906,7 +906,10 @@ def test_dontreadfrominput() -> None:
906906
pytest.raises(UnsupportedOperation, f.write, b"")
907907
pytest.raises(UnsupportedOperation, f.writelines, [])
908908
assert not f.writable()
909+
assert isinstance(f.encoding, str)
909910
f.close() # just for completeness
911+
with f:
912+
pass
910913

911914

912915
def test_captureresult() -> None:

0 commit comments

Comments
 (0)