Skip to content

[PyMySQL] Add missing stubs #14335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions stubs/PyMySQL/@tests/stubtest_allowlist.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
pymysql.connections.byte2int
pymysql.connections.int2byte
pymysql.connections.pack_int24
pymysql.cursors.Cursor.__del__
# DictCursorMixin changes method types of inherited classes, but doesn't contain much at runtime
pymysql.cursors.DictCursorMixin.__iter__
pymysql.cursors.DictCursorMixin.fetch[a-z]*
pymysql.escape_dict
pymysql.escape_sequence
pymysql.escape_string
pymysql.util
4 changes: 0 additions & 4 deletions stubs/PyMySQL/METADATA.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
version = "1.1.*"
upstream_repository = "https://github.com/PyMySQL/PyMySQL"
partial_stub = true

[tool.stubtest]
ignore_missing_stub = true
67 changes: 57 additions & 10 deletions stubs/PyMySQL/pymysql/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Final
from _typeshed import ReadableBuffer
from collections.abc import Iterable
from typing import Final, SupportsBytes, SupportsIndex

from . import connections, constants, converters, cursors
from .connections import Connection as Connection
from .constants import FIELD_TYPE as FIELD_TYPE
from .converters import escape_dict as escape_dict, escape_sequence as escape_sequence, escape_string as escape_string
from .err import (
DatabaseError as DatabaseError,
DataError as DataError,
Expand All @@ -27,14 +29,19 @@ from .times import (

VERSION: Final[tuple[str | int, ...]]
VERSION_STRING: Final[str]
version_info: tuple[int, int, int, str, int]
__version__: str

def get_client_info() -> str: ...
def install_as_MySQLdb() -> None: ...

threadsafety: int
apilevel: str
paramstyle: str

class DBAPISet(frozenset[int]):
def __ne__(self, other) -> bool: ...
def __eq__(self, other) -> bool: ...
def __ne__(self, other: object) -> bool: ...
def __eq__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...

STRING: DBAPISet
Expand All @@ -46,16 +53,56 @@ TIMESTAMP: DBAPISet
DATETIME: DBAPISet
ROWID: DBAPISet

def Binary(x) -> bytes: ...
def get_client_info() -> str: ...
def Binary(x: Iterable[SupportsIndex] | SupportsIndex | SupportsBytes | ReadableBuffer) -> bytes: ...
def thread_safe() -> bool: ...

__version__: str
version_info: tuple[int, int, int, str, int]
NULL: str

# pymysql/__init__.py says "Connect = connect = Connection = connections.Connection"
Connect = Connection
connect = Connection

def thread_safe() -> bool: ...
def install_as_MySQLdb() -> None: ...
__all__ = [
"BINARY",
"Binary",
"Connect",
"Connection",
"DATE",
"Date",
"Time",
"Timestamp",
"DateFromTicks",
"TimeFromTicks",
"TimestampFromTicks",
"DataError",
"DatabaseError",
"Error",
"FIELD_TYPE",
"IntegrityError",
"InterfaceError",
"InternalError",
"MySQLError",
"NULL",
"NUMBER",
"NotSupportedError",
"DBAPISet",
"OperationalError",
"ProgrammingError",
"ROWID",
"STRING",
"TIME",
"TIMESTAMP",
"Warning",
"apilevel",
"connect",
"connections",
"constants",
"converters",
"cursors",
"get_client_info",
"paramstyle",
"threadsafety",
"version_info",
"install_as_MySQLdb",
"__version__",
]
156 changes: 62 additions & 94 deletions stubs/PyMySQL/pymysql/connections.pyi
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from _typeshed import Incomplete
from _typeshed import FileDescriptorOrPath, Incomplete
from collections.abc import Mapping
from socket import socket as _socket
from ssl import _PasswordType
from typing import Any, AnyStr, Generic, TypeVar, overload
from typing_extensions import Self
from typing import AnyStr, Generic, TypeVar, overload
from typing_extensions import Self, deprecated

from .charset import charset_by_id as charset_by_id, charset_by_name as charset_by_name
from .constants import CLIENT as CLIENT, COMMAND as COMMAND, FIELD_TYPE as FIELD_TYPE, SERVER_STATUS as SERVER_STATUS
from .cursors import Cursor
from .util import byte2int as byte2int, int2byte as int2byte

_C = TypeVar("_C", bound=Cursor)
_C2 = TypeVar("_C2", bound=Cursor)

SSL_ENABLED: bool
DEFAULT_USER: str | None
Expand All @@ -17,68 +19,30 @@ DEFAULT_CHARSET: str
TEXT_TYPES: set[int]
MAX_PACKET_LEN: int

_C = TypeVar("_C", bound=Cursor)
_C2 = TypeVar("_C2", bound=Cursor)

def dump_packet(data): ...
def pack_int24(n): ...
def _lenenc_int(i: int) -> bytes: ...

class MysqlPacket:
connection: Any
def __init__(self, data, encoding): ...
def get_all_data(self): ...
def read(self, size): ...
def read_all(self): ...
def advance(self, length): ...
def rewind(self, position: int = 0): ...
def get_bytes(self, position, length: int = 1): ...
def read_string(self) -> bytes: ...
def read_uint8(self) -> Any: ...
def read_uint16(self) -> Any: ...
def read_uint24(self) -> Any: ...
def read_uint32(self) -> Any: ...
def read_uint64(self) -> Any: ...
def read_length_encoded_integer(self) -> int: ...
def read_length_coded_string(self) -> bytes: ...
def read_struct(self, fmt: str) -> tuple[Any, ...]: ...
def is_ok_packet(self) -> bool: ...
def is_eof_packet(self) -> bool: ...
def is_auth_switch_request(self) -> bool: ...
def is_extra_auth_data(self) -> bool: ...
def is_resultset_packet(self) -> bool: ...
def is_load_local_packet(self) -> bool: ...
def is_error_packet(self) -> bool: ...
def check_error(self): ...
def raise_for_error(self) -> None: ...
def dump(self): ...

class FieldDescriptorPacket(MysqlPacket):
def __init__(self, data, encoding): ...
def description(self): ...
def get_column_length(self): ...

class Connection(Generic[_C]):
ssl: Any
host: Any
port: Any
user: Any
password: Any
db: Any
unix_socket: Any
ssl: Incomplete
host: Incomplete
port: Incomplete
user: Incomplete
password: Incomplete
db: Incomplete
unix_socket: Incomplete
charset: str
collation: str | None
bind_address: Any
use_unicode: Any
client_flag: Any
cursorclass: Any
connect_timeout: Any
messages: Any
encoders: Any
decoders: Any
host_info: Any
sql_mode: Any
init_command: Any
bind_address: Incomplete
use_unicode: Incomplete
client_flag: Incomplete
cursorclass: Incomplete
connect_timeout: Incomplete
messages: Incomplete
encoders: Incomplete
decoders: Incomplete
host_info: Incomplete
sql_mode: Incomplete
init_command: Incomplete
max_allowed_packet: int
server_public_key: bytes
@overload
Expand All @@ -101,7 +65,7 @@ class Connection(Generic[_C]):
cursorclass: None = None, # different between overloads
init_command=None,
connect_timeout: int | None = 10,
ssl: Mapping[Any, Any] | None = None,
ssl: Mapping[Incomplete, Incomplete] | None = None,
ssl_ca=None,
ssl_cert=None,
ssl_disabled=None,
Expand All @@ -118,7 +82,7 @@ class Connection(Generic[_C]):
local_infile: Incomplete | None = False,
max_allowed_packet: int = 16777216,
defer_connect: bool | None = False,
auth_plugin_map: Mapping[Any, Any] | None = None,
auth_plugin_map: Mapping[Incomplete, Incomplete] | None = None,
read_timeout: float | None = None,
write_timeout: float | None = None,
bind_address=None,
Expand Down Expand Up @@ -147,7 +111,7 @@ class Connection(Generic[_C]):
cursorclass: type[_C] = ..., # different between overloads
init_command=None,
connect_timeout: int | None = 10,
ssl: Mapping[Any, Any] | None = None,
ssl: Mapping[Incomplete, Incomplete] | None = None,
ssl_ca=None,
ssl_cert=None,
ssl_disabled=None,
Expand All @@ -163,27 +127,28 @@ class Connection(Generic[_C]):
local_infile: Incomplete | None = False,
max_allowed_packet: int = 16777216,
defer_connect: bool | None = False,
auth_plugin_map: Mapping[Any, Any] | None = None,
auth_plugin_map: Mapping[Incomplete, Incomplete] | None = None,
read_timeout: float | None = None,
write_timeout: float | None = None,
bind_address=None,
binary_prefix: bool | None = False,
program_name=None,
server_public_key: bytes | None = None,
) -> None: ...
socket: Any
rfile: Any
wfile: Any
socket: Incomplete
rfile: Incomplete
wfile: Incomplete
def close(self) -> None: ...
@property
def open(self) -> bool: ...
def __del__(self) -> None: ...
def autocommit(self, value) -> None: ...
def get_autocommit(self) -> bool: ...
def commit(self) -> None: ...
def begin(self) -> None: ...
def rollback(self) -> None: ...
def select_db(self, db) -> None: ...
def escape(self, obj, mapping: Mapping[Any, Any] | None = None): ...
def escape(self, obj, mapping: Mapping[Incomplete, Incomplete] | None = None): ...
def literal(self, obj): ...
def escape_string(self, s: AnyStr) -> AnyStr: ...
@overload
Expand All @@ -195,7 +160,9 @@ class Connection(Generic[_C]):
def affected_rows(self): ...
def kill(self, thread_id): ...
def ping(self, reconnect: bool = True) -> None: ...
def set_charset(self, charset) -> None: ...
@deprecated("Method is deprecated. Use set_character_set() instead.")
def set_charset(self, charset: str) -> None: ...
def set_character_set(self, charset: str, collation: str | None = None) -> None: ...
def connect(self, sock: _socket | None = None) -> None: ...
def write_packet(self, payload) -> None: ...
def _read_packet(self, packet_type=...): ...
Expand All @@ -208,35 +175,36 @@ class Connection(Generic[_C]):
def show_warnings(self): ...
def __enter__(self) -> Self: ...
def __exit__(self, *exc_info: object) -> None: ...
Warning: Any
Error: Any
InterfaceError: Any
DatabaseError: Any
DataError: Any
OperationalError: Any
IntegrityError: Any
InternalError: Any
ProgrammingError: Any
NotSupportedError: Any
Warning: Incomplete
Error: Incomplete
InterfaceError: Incomplete
DatabaseError: Incomplete
DataError: Incomplete
OperationalError: Incomplete
IntegrityError: Incomplete
InternalError: Incomplete
ProgrammingError: Incomplete
NotSupportedError: Incomplete

class MySQLResult:
connection: Any
affected_rows: Any
insert_id: Any
server_status: Any
warning_count: Any
message: Any
field_count: Any
description: Any
rows: Any
has_next: Any
def __init__(self, connection: Connection[Any]) -> None: ...
first_packet: Any
connection: Incomplete
affected_rows: Incomplete
insert_id: Incomplete
server_status: Incomplete
warning_count: Incomplete
message: Incomplete
field_count: Incomplete
description: Incomplete
rows: Incomplete
has_next: Incomplete
def __init__(self, connection: Connection[Incomplete]) -> None: ...
def __del__(self) -> None: ...
first_packet: Incomplete
def read(self) -> None: ...
def init_unbuffered_query(self) -> None: ...

class LoadLocalFile:
filename: Any
connection: Connection[Any]
def __init__(self, filename: Any, connection: Connection[Any]) -> None: ...
filename: FileDescriptorOrPath
connection: Connection[Incomplete]
def __init__(self, filename: FileDescriptorOrPath, connection: Connection[Incomplete]) -> None: ...
def send_data(self) -> None: ...
Loading
Loading