Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy: Add music_assistant.common #1428

Merged
merged 12 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions music_assistant/client/music.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from __future__ import annotations

import urllib.parse
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast

from music_assistant.common.models.enums import ImageType, MediaType
from music_assistant.common.models.errors import UnsupportedFeaturedException
from music_assistant.common.models.media_items import (
Album,
AlbumTrack,
Expand All @@ -19,7 +20,6 @@
Radio,
SearchResults,
Track,
media_from_dict,
)
from music_assistant.common.models.provider import SyncTask
from music_assistant.common.models.queue_item import QueueItem
Expand All @@ -28,6 +28,21 @@
from .client import MusicAssistantClient


def media_from_dict(media_item: dict[str, Any]) -> MediaItemType:
marcelveldt marked this conversation as resolved.
Show resolved Hide resolved
"""Return MediaItem from dict."""
if media_item["media_type"] == "artist":
return Artist.from_dict(media_item)
if media_item["media_type"] == "album":
return Album.from_dict(media_item)
if media_item["media_type"] == "track":
return Track.from_dict(media_item)
if media_item["media_type"] == "playlist":
return Playlist.from_dict(media_item)
if media_item["media_type"] == "radio":
return Radio.from_dict(media_item)
raise UnsupportedFeaturedException(f"Unknown media_type: {media_item['media_type']}")


class Music:
"""Music(library) related endpoints/data for Music Assistant."""

Expand Down
2 changes: 1 addition & 1 deletion music_assistant/common/helpers/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def now_timestamp() -> float:
return now().timestamp()


def future_timestamp(**kwargs) -> float:
def future_timestamp(**kwargs: float) -> float:
Jc2k marked this conversation as resolved.
Show resolved Hide resolved
"""Return current timestamp + timedelta."""
return (now() + datetime.timedelta(**kwargs)).timestamp()

Expand Down
2 changes: 1 addition & 1 deletion music_assistant/common/helpers/global_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# global cache - we use this on a few places (as limited as possible)
# where we have no other options
_global_cache_lock = asyncio.Lock()
_global_cache = {}
_global_cache: dict[str, Any] = {}


def get_global_cache_value(key: str, default: Any = None) -> Any:
Expand Down
12 changes: 6 additions & 6 deletions music_assistant/common/helpers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import base64
from _collections_abc import dict_keys, dict_values
from types import MethodType
from typing import Any
from typing import Any, TypeVar

import aiofiles
import orjson
from mashumaro.mixins.orjson import DataClassORJSONMixin

JSON_ENCODE_EXCEPTIONS = (TypeError, ValueError)
JSON_DECODE_EXCEPTIONS = (orjson.JSONDecodeError,)
Expand Down Expand Up @@ -59,12 +60,11 @@ def json_dumps(data: Any, indent: bool = False) -> str:

json_loads = orjson.loads

TargetT = TypeVar("TargetT", bound=DataClassORJSONMixin)

async def load_json_file(path: str, target_class: type | None = None) -> dict:

async def load_json_file(path: str, target_class: type[TargetT]) -> TargetT:
marcelveldt marked this conversation as resolved.
Show resolved Hide resolved
"""Load JSON from file."""
async with aiofiles.open(path, "r") as _file:
content = await _file.read()
if target_class:
# support for a mashumaro model
return target_class.from_json(content)
return json_loads(content)
return target_class.from_json(content)
2 changes: 1 addition & 1 deletion music_assistant/common/helpers/uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
base62_length22_id_pattern = re.compile(r"^[a-zA-Z0-9]{22}$")


def valid_base62_length22(item_id) -> bool:
def valid_base62_length22(item_id: str) -> bool:
"""Validate Spotify style ID."""
return bool(base62_length22_id_pattern.match(item_id))

Expand Down
42 changes: 18 additions & 24 deletions music_assistant/common/helpers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
import re
import socket
from collections.abc import Callable
from collections.abc import Set as AbstractSet
from typing import Any, TypeVar
from urllib.parse import urlparse
from uuid import UUID

# pylint: disable=invalid-name
T = TypeVar("T")
_UNDEF: dict = {}
marcelveldt marked this conversation as resolved.
Show resolved Hide resolved
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
CALLBACK_TYPE = Callable[[], None]
# pylint: enable=invalid-name

Expand Down Expand Up @@ -50,7 +49,7 @@ def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float |
return default


def try_parse_bool(possible_bool: Any) -> str:
def try_parse_bool(possible_bool: Any) -> bool:
"""Try to parse a bool."""
if isinstance(possible_bool, bool):
return possible_bool
Expand Down Expand Up @@ -79,7 +78,7 @@ def create_sort_name(input_str: str) -> str:
return input_str.strip()


def parse_title_and_version(title: str, track_version: str | None = None):
def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]:
"""Try to parse clean track title and version from the title."""
version = ""
for splitter in [" (", " [", " - ", " (", " [", "-"]:
Expand Down Expand Up @@ -135,7 +134,7 @@ def clean_title(title: str) -> str:
return title.strip()


def get_version_substitute(version_str: str):
def get_version_substitute(version_str: str) -> str:
"""Transform provider version str to universal version type."""
version_str = version_str.lower()
# substitute edit and edition with version
Expand Down Expand Up @@ -169,7 +168,7 @@ def strip_url(line: str) -> str:
).rstrip()


def strip_dotcom(line: str):
def strip_dotcom(line: str) -> str:
"""Strip scheme-less netloc from line."""
return dot_com_pattern.sub("", line)

Expand Down Expand Up @@ -227,17 +226,17 @@ def clean_stream_title(line: str) -> str:
return line


async def get_ip():
async def get_ip() -> str:
"""Get primary IP-address for this host."""

def _get_ip():
def _get_ip() -> str:
"""Get primary IP-address for this host."""
# pylint: disable=broad-except,no-member
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
sock.connect(("10.255.255.255", 1))
_ip = sock.getsockname()[0]
_ip = str(sock.getsockname()[0])
except Exception:
_ip = "127.0.0.1"
finally:
Expand Down Expand Up @@ -273,7 +272,7 @@ async def select_free_port(range_start: int, range_end: int) -> int:
async def get_ip_from_host(dns_name: str) -> str | None:
"""Resolve (first) IP-address for given dns name."""

def _resolve():
def _resolve() -> str | None:
try:
return socket.gethostbyname(dns_name)
except Exception: # pylint: disable=broad-except
Expand All @@ -283,7 +282,7 @@ def _resolve():
return await asyncio.to_thread(_resolve)


async def get_ip_pton(ip_string: str | None = None):
async def get_ip_pton(ip_string: str | None = None) -> bytes:
"""Return socket pton for local ip."""
if ip_string is None:
ip_string = await get_ip()
Expand All @@ -294,7 +293,7 @@ async def get_ip_pton(ip_string: str | None = None):
return await asyncio.to_thread(socket.inet_pton, socket.AF_INET6, ip_string)


def get_folder_size(folderpath):
def get_folder_size(folderpath: str) -> float:
"""Return folder size in gb."""
total_size = 0
# pylint: disable=unused-variable
Expand All @@ -306,7 +305,9 @@ def get_folder_size(folderpath):
return total_size / float(1 << 30)


def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False):
def merge_dict(
base_dict: dict[Any, Any], new_dict: dict[Any, Any], allow_overwite: bool = False
) -> dict[Any, Any]:
"""Merge dict without overwriting existing values."""
final_dict = base_dict.copy()
for key, value in new_dict.items():
Expand All @@ -321,12 +322,12 @@ def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False):
return final_dict


def merge_tuples(base: tuple, new: tuple) -> tuple:
def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]:
"""Merge 2 tuples."""
return tuple(x for x in base if x not in new) + tuple(new)


def merge_lists(base: list, new: list) -> list:
def merge_lists(base: list[Any], new: list[Any]) -> list[Any]:
"""Merge 2 lists."""
return [x for x in base if x not in new] + list(new)

Expand All @@ -335,7 +336,7 @@ def get_changed_keys(
dict1: dict[str, Any],
dict2: dict[str, Any],
ignore_keys: list[str] | None = None,
) -> set[str]:
) -> AbstractSet[str]:
"""Compare 2 dicts and return set of changed keys."""
return get_changed_values(dict1, dict2, ignore_keys).keys()

Expand Down Expand Up @@ -369,7 +370,7 @@ def get_changed_values(
return changed_values


def empty_queue(q: asyncio.Queue) -> None:
def empty_queue(q: asyncio.Queue[T]) -> None:
"""Empty an asyncio Queue."""
for _ in range(q.qsize()):
try:
Expand All @@ -386,10 +387,3 @@ def is_valid_uuid(uuid_to_test: str) -> bool:
except ValueError:
return False
return str(uuid_obj) == uuid_to_test


class classproperty(property): # noqa: N801
"""Implement class property for python3.11+."""

def __get__(self, cls, owner): # noqa: D105
return classmethod(self.fget).__get__(None, owner)()
2 changes: 1 addition & 1 deletion music_assistant/common/models/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ServerInfoMessage(DataClassORJSONMixin):
)


def parse_message(raw: dict) -> MessageType:
def parse_message(raw: dict[Any, Any]) -> MessageType:
"""Parse Message from raw dict object."""
if "event" in raw:
return EventMessage.from_dict(raw)
Expand Down
Loading
Loading