From 306997740e99b735b4b162b808cfcecea508330d Mon Sep 17 00:00:00 2001 From: Madhur Tandon Date: Wed, 15 May 2024 17:51:57 +0530 Subject: [PATCH] vendor typeguard --- metaflow/_vendor/importlib_metadata.LICENSE | 13 + .../_vendor/importlib_metadata/__init__.py | 1063 ++++++ .../_vendor/importlib_metadata/_adapters.py | 68 + .../importlib_metadata/_collections.py | 30 + .../_vendor/importlib_metadata/_compat.py | 71 + .../_vendor/importlib_metadata/_functools.py | 104 + .../_vendor/importlib_metadata/_itertools.py | 73 + metaflow/_vendor/importlib_metadata/_meta.py | 48 + metaflow/_vendor/importlib_metadata/_text.py | 99 + metaflow/_vendor/importlib_metadata/py.typed | 0 metaflow/_vendor/typeguard.LICENSE | 19 + metaflow/_vendor/typeguard/__init__.py | 48 + metaflow/_vendor/typeguard/_checkers.py | 906 +++++ metaflow/_vendor/typeguard/_config.py | 108 + metaflow/_vendor/typeguard/_decorators.py | 237 ++ metaflow/_vendor/typeguard/_exceptions.py | 42 + metaflow/_vendor/typeguard/_functions.py | 307 ++ metaflow/_vendor/typeguard/_importhook.py | 213 ++ metaflow/_vendor/typeguard/_memo.py | 48 + metaflow/_vendor/typeguard/_pytest_plugin.py | 100 + metaflow/_vendor/typeguard/_suppression.py | 88 + metaflow/_vendor/typeguard/_transformer.py | 1193 +++++++ .../_vendor/typeguard/_union_transformer.py | 54 + metaflow/_vendor/typeguard/_utils.py | 169 + metaflow/_vendor/typeguard/py.typed | 0 metaflow/_vendor/typing_extensions.LICENSE | 279 ++ metaflow/_vendor/typing_extensions.py | 3053 +++++++++++++++++ metaflow/_vendor/v3_7/__init__.py | 1 - metaflow/_vendor/vendor_any.txt | 4 + metaflow/_vendor/vendor_v3_7.txt | 1 - metaflow/_vendor/{v3_7 => }/zipp.LICENSE | 0 metaflow/_vendor/{v3_7 => }/zipp.py | 0 metaflow/cmd/develop/stubs.py | 2 + metaflow/extension_support/__init__.py | 2 + metaflow/vendor.py | 1 - test/core/run_tests.py | 12 + 36 files changed, 8453 insertions(+), 3 deletions(-) create mode 100644 metaflow/_vendor/importlib_metadata.LICENSE create mode 100644 metaflow/_vendor/importlib_metadata/__init__.py create mode 100644 metaflow/_vendor/importlib_metadata/_adapters.py create mode 100644 metaflow/_vendor/importlib_metadata/_collections.py create mode 100644 metaflow/_vendor/importlib_metadata/_compat.py create mode 100644 metaflow/_vendor/importlib_metadata/_functools.py create mode 100644 metaflow/_vendor/importlib_metadata/_itertools.py create mode 100644 metaflow/_vendor/importlib_metadata/_meta.py create mode 100644 metaflow/_vendor/importlib_metadata/_text.py create mode 100644 metaflow/_vendor/importlib_metadata/py.typed create mode 100644 metaflow/_vendor/typeguard.LICENSE create mode 100644 metaflow/_vendor/typeguard/__init__.py create mode 100644 metaflow/_vendor/typeguard/_checkers.py create mode 100644 metaflow/_vendor/typeguard/_config.py create mode 100644 metaflow/_vendor/typeguard/_decorators.py create mode 100644 metaflow/_vendor/typeguard/_exceptions.py create mode 100644 metaflow/_vendor/typeguard/_functions.py create mode 100644 metaflow/_vendor/typeguard/_importhook.py create mode 100644 metaflow/_vendor/typeguard/_memo.py create mode 100644 metaflow/_vendor/typeguard/_pytest_plugin.py create mode 100644 metaflow/_vendor/typeguard/_suppression.py create mode 100644 metaflow/_vendor/typeguard/_transformer.py create mode 100644 metaflow/_vendor/typeguard/_union_transformer.py create mode 100644 metaflow/_vendor/typeguard/_utils.py create mode 100644 metaflow/_vendor/typeguard/py.typed create mode 100644 metaflow/_vendor/typing_extensions.LICENSE create mode 100644 
metaflow/_vendor/typing_extensions.py delete mode 100644 metaflow/_vendor/v3_7/__init__.py delete mode 100644 metaflow/_vendor/vendor_v3_7.txt rename metaflow/_vendor/{v3_7 => }/zipp.LICENSE (100%) rename metaflow/_vendor/{v3_7 => }/zipp.py (100%) diff --git a/metaflow/_vendor/importlib_metadata.LICENSE b/metaflow/_vendor/importlib_metadata.LICENSE new file mode 100644 index 00000000000..be7e092b0b0 --- /dev/null +++ b/metaflow/_vendor/importlib_metadata.LICENSE @@ -0,0 +1,13 @@ +Copyright 2017-2019 Jason R. Coombs, Barry Warsaw + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/metaflow/_vendor/importlib_metadata/__init__.py b/metaflow/_vendor/importlib_metadata/__init__.py new file mode 100644 index 00000000000..d6c84fb70e9 --- /dev/null +++ b/metaflow/_vendor/importlib_metadata/__init__.py @@ -0,0 +1,1063 @@ +import os +import re +import abc +import csv +import sys +from metaflow._vendor import zipp +import email +import pathlib +import operator +import textwrap +import warnings +import functools +import itertools +import posixpath +import collections + +from . import _adapters, _meta +from ._collections import FreezableDefaultDict, Pair +from ._compat import ( + NullFinder, + install, + pypy_partial, +) +from ._functools import method_cache, pass_none +from ._itertools import always_iterable, unique_everseen +from ._meta import PackageMetadata, SimplePath + +from contextlib import suppress +from importlib import import_module +from importlib.abc import MetaPathFinder +from itertools import starmap +from typing import List, Mapping, Optional, Union + + +__all__ = [ + 'Distribution', + 'DistributionFinder', + 'PackageMetadata', + 'PackageNotFoundError', + 'distribution', + 'distributions', + 'entry_points', + 'files', + 'metadata', + 'packages_distributions', + 'requires', + 'version', +] + + +class PackageNotFoundError(ModuleNotFoundError): + """The package was not found.""" + + def __str__(self): + return f"No package metadata was found for {self.name}" + + @property + def name(self): + (name,) = self.args + return name + + +class Sectioned: + """ + A simple entry point config parser for performance + + >>> for item in Sectioned.read(Sectioned._sample): + ... 
print(item)
+    Pair(name='sec1', value='# comments ignored')
+    Pair(name='sec1', value='a = 1')
+    Pair(name='sec1', value='b = 2')
+    Pair(name='sec2', value='a = 2')
+
+    >>> res = Sectioned.section_pairs(Sectioned._sample)
+    >>> item = next(res)
+    >>> item.name
+    'sec1'
+    >>> item.value
+    Pair(name='a', value='1')
+    >>> item = next(res)
+    >>> item.value
+    Pair(name='b', value='2')
+    >>> item = next(res)
+    >>> item.name
+    'sec2'
+    >>> item.value
+    Pair(name='a', value='2')
+    >>> list(res)
+    []
+    """
+
+    _sample = textwrap.dedent(
+        """
+        [sec1]
+        # comments ignored
+        a = 1
+        b = 2
+
+        [sec2]
+        a = 2
+        """
+    ).lstrip()
+
+    @classmethod
+    def section_pairs(cls, text):
+        return (
+            section._replace(value=Pair.parse(section.value))
+            for section in cls.read(text, filter_=cls.valid)
+            if section.name is not None
+        )
+
+    @staticmethod
+    def read(text, filter_=None):
+        lines = filter(filter_, map(str.strip, text.splitlines()))
+        name = None
+        for value in lines:
+            section_match = value.startswith('[') and value.endswith(']')
+            if section_match:
+                name = value.strip('[]')
+                continue
+            yield Pair(name, value)
+
+    @staticmethod
+    def valid(line):
+        return line and not line.startswith('#')
+
+
+class DeprecatedTuple:
+    """
+    Provide subscript item access for backward compatibility.
+
+    >>> recwarn = getfixture('recwarn')
+    >>> ep = EntryPoint(name='name', value='value', group='group')
+    >>> ep[:]
+    ('name', 'value', 'group')
+    >>> ep[0]
+    'name'
+    >>> len(recwarn)
+    1
+    """
+
+    _warn = functools.partial(
+        warnings.warn,
+        "EntryPoint tuple interface is deprecated. Access members by name.",
+        DeprecationWarning,
+        stacklevel=pypy_partial(2),
+    )
+
+    def __getitem__(self, item):
+        self._warn()
+        return self._key()[item]
+
+
+class EntryPoint(DeprecatedTuple):
+    """An entry point as defined by Python packaging conventions.
+
+    See `the packaging docs on entry points
+    <https://packaging.python.org/specifications/entry-points/>`_
+    for more information.
+    """
+
+    pattern = re.compile(
+        r'(?P<module>[\w.]+)\s*'
+        r'(:\s*(?P<attr>[\w.]+))?\s*'
+        r'(?P<extras>\[.*\])?\s*$'
+    )
+    """
+    A regular expression describing the syntax for an entry point,
+    which might look like:
+
+        - module
+        - package.module
+        - package.module:attribute
+        - package.module:object.attribute
+        - package.module:attr [extra1, extra2]
+
+    Other combinations are possible as well.
+
+    The expression is lenient about whitespace around the ':',
+    following the attr, and following any extras.
+    """
+
+    dist: Optional['Distribution'] = None
+
+    def __init__(self, name, value, group):
+        vars(self).update(name=name, value=value, group=group)
+
+    def load(self):
+        """Load the entry point from its definition. If only a module
+        is indicated by the value, return that module. Otherwise,
+        return the named object.
+        """
+        match = self.pattern.match(self.value)
+        module = import_module(match.group('module'))
+        attrs = filter(None, (match.group('attr') or '').split('.'))
+        return functools.reduce(getattr, attrs, module)
+
+    @property
+    def module(self):
+        match = self.pattern.match(self.value)
+        return match.group('module')
+
+    @property
+    def attr(self):
+        match = self.pattern.match(self.value)
+        return match.group('attr')
+
+    @property
+    def extras(self):
+        match = self.pattern.match(self.value)
+        return list(re.finditer(r'\w+', match.group('extras') or ''))
+
+    def _for(self, dist):
+        vars(self).update(dist=dist)
+        return self
+
+    def __iter__(self):
+        """
+        Supply iter so one may construct dicts of EntryPoints by name.
+        """
+        msg = (
+            "Construction of dict of EntryPoints is deprecated in "
+            "favor of EntryPoints."
+ ) + warnings.warn(msg, DeprecationWarning) + return iter((self.name, self)) + + def matches(self, **params): + attrs = (getattr(self, param) for param in params) + return all(map(operator.eq, params.values(), attrs)) + + def _key(self): + return self.name, self.value, self.group + + def __lt__(self, other): + return self._key() < other._key() + + def __eq__(self, other): + return self._key() == other._key() + + def __setattr__(self, name, value): + raise AttributeError("EntryPoint objects are immutable.") + + def __repr__(self): + return ( + f'EntryPoint(name={self.name!r}, value={self.value!r}, ' + f'group={self.group!r})' + ) + + def __hash__(self): + return hash(self._key()) + + +class DeprecatedList(list): + """ + Allow an otherwise immutable object to implement mutability + for compatibility. + + >>> recwarn = getfixture('recwarn') + >>> dl = DeprecatedList(range(3)) + >>> dl[0] = 1 + >>> dl.append(3) + >>> del dl[3] + >>> dl.reverse() + >>> dl.sort() + >>> dl.extend([4]) + >>> dl.pop(-1) + 4 + >>> dl.remove(1) + >>> dl += [5] + >>> dl + [6] + [1, 2, 5, 6] + >>> dl + (6,) + [1, 2, 5, 6] + >>> dl.insert(0, 0) + >>> dl + [0, 1, 2, 5] + >>> dl == [0, 1, 2, 5] + True + >>> dl == (0, 1, 2, 5) + True + >>> len(recwarn) + 1 + """ + + _warn = functools.partial( + warnings.warn, + "EntryPoints list interface is deprecated. Cast to list if needed.", + DeprecationWarning, + stacklevel=pypy_partial(2), + ) + + def _wrap_deprecated_method(method_name: str): # type: ignore + def wrapped(self, *args, **kwargs): + self._warn() + return getattr(super(), method_name)(*args, **kwargs) + + return wrapped + + for method_name in [ + '__setitem__', + '__delitem__', + 'append', + 'reverse', + 'extend', + 'pop', + 'remove', + '__iadd__', + 'insert', + 'sort', + ]: + locals()[method_name] = _wrap_deprecated_method(method_name) + + def __add__(self, other): + if not isinstance(other, tuple): + self._warn() + other = tuple(other) + return self.__class__(tuple(self) + other) + + def __eq__(self, other): + if not isinstance(other, tuple): + self._warn() + other = tuple(other) + + return tuple(self).__eq__(other) + + +class EntryPoints(DeprecatedList): + """ + An immutable collection of selectable EntryPoint objects. + """ + + __slots__ = () + + def __getitem__(self, name): # -> EntryPoint: + """ + Get the EntryPoint in self matching name. + """ + if isinstance(name, int): + warnings.warn( + "Accessing entry points by index is deprecated. " + "Cast to tuple if needed.", + DeprecationWarning, + stacklevel=2, + ) + return super().__getitem__(name) + try: + return next(iter(self.select(name=name))) + except StopIteration: + raise KeyError(name) + + def select(self, **params): + """ + Select entry points from self that match the + given parameters (typically group and/or name). + """ + return EntryPoints(ep for ep in self if ep.matches(**params)) + + @property + def names(self): + """ + Return the set of all names of all entry points. + """ + return {ep.name for ep in self} + + @property + def groups(self): + """ + Return the set of all groups of all entry points. + + For coverage while SelectableGroups is present. 
+    >>> EntryPoints().groups
+    set()
+    """
+    return {ep.group for ep in self}
+
+    @classmethod
+    def _from_text_for(cls, text, dist):
+        return cls(ep._for(dist) for ep in cls._from_text(text))
+
+    @staticmethod
+    def _from_text(text):
+        return (
+            EntryPoint(name=item.value.name, value=item.value.value, group=item.name)
+            for item in Sectioned.section_pairs(text or '')
+        )
+
+
+class Deprecated:
+    """
+    Compatibility add-in for mapping to indicate that
+    mapping behavior is deprecated.
+
+    >>> recwarn = getfixture('recwarn')
+    >>> class DeprecatedDict(Deprecated, dict): pass
+    >>> dd = DeprecatedDict(foo='bar')
+    >>> dd.get('baz', None)
+    >>> dd['foo']
+    'bar'
+    >>> list(dd)
+    ['foo']
+    >>> list(dd.keys())
+    ['foo']
+    >>> 'foo' in dd
+    True
+    >>> list(dd.values())
+    ['bar']
+    >>> len(recwarn)
+    1
+    """
+
+    _warn = functools.partial(
+        warnings.warn,
+        "SelectableGroups dict interface is deprecated. Use select.",
+        DeprecationWarning,
+        stacklevel=pypy_partial(2),
+    )
+
+    def __getitem__(self, name):
+        self._warn()
+        return super().__getitem__(name)
+
+    def get(self, name, default=None):
+        self._warn()
+        return super().get(name, default)
+
+    def __iter__(self):
+        self._warn()
+        return super().__iter__()
+
+    def __contains__(self, *args):
+        self._warn()
+        return super().__contains__(*args)
+
+    def keys(self):
+        self._warn()
+        return super().keys()
+
+    def values(self):
+        self._warn()
+        return super().values()
+
+
+class SelectableGroups(Deprecated, dict):
+    """
+    A backward- and forward-compatible result from
+    entry_points that fully implements the dict interface.
+    """
+
+    @classmethod
+    def load(cls, eps):
+        by_group = operator.attrgetter('group')
+        ordered = sorted(eps, key=by_group)
+        grouped = itertools.groupby(ordered, by_group)
+        return cls((group, EntryPoints(eps)) for group, eps in grouped)
+
+    @property
+    def _all(self):
+        """
+        Reconstruct a list of all entrypoints from the groups.
+        """
+        groups = super(Deprecated, self).values()
+        return EntryPoints(itertools.chain.from_iterable(groups))
+
+    @property
+    def groups(self):
+        return self._all.groups
+
+    @property
+    def names(self):
+        """
+        for coverage:
+        >>> SelectableGroups().names
+        set()
+        """
+        return self._all.names
+
+    def select(self, **params):
+        if not params:
+            return self
+        return self._all.select(**params)
+
+
+class PackagePath(pathlib.PurePosixPath):
+    """A reference to a path in a package"""
+
+    def read_text(self, encoding='utf-8'):
+        with self.locate().open(encoding=encoding) as stream:
+            return stream.read()
+
+    def read_binary(self):
+        with self.locate().open('rb') as stream:
+            return stream.read()
+
+    def locate(self):
+        """Return a path-like object for this path"""
+        return self.dist.locate_file(self)
+
+
+class FileHash:
+    def __init__(self, spec):
+        self.mode, _, self.value = spec.partition('=')
+
+    def __repr__(self):
+        return f'<FileHash mode: {self.mode} value: {self.value}>'
+
+
+class Distribution:
+    """A Python distribution package."""
+
+    @abc.abstractmethod
+    def read_text(self, filename):
+        """Attempt to load metadata file given by the name.
+
+        :param filename: The name of the file in the distribution info.
+        :return: The text if found, otherwise None.
+        """
+
+    @abc.abstractmethod
+    def locate_file(self, path):
+        """
+        Given a path to a file in this distribution, return a path
+        to it.
+        """
+
+    @classmethod
+    def from_name(cls, name):
+        """Return the Distribution for the given package name.
+
+        :param name: The name of the distribution package to search for.
+ :return: The Distribution instance (or subclass thereof) for the named + package, if found. + :raises PackageNotFoundError: When the named package's distribution + metadata cannot be found. + """ + for resolver in cls._discover_resolvers(): + dists = resolver(DistributionFinder.Context(name=name)) + dist = next(iter(dists), None) + if dist is not None: + return dist + else: + raise PackageNotFoundError(name) + + @classmethod + def discover(cls, **kwargs): + """Return an iterable of Distribution objects for all packages. + + Pass a ``context`` or pass keyword arguments for constructing + a context. + + :context: A ``DistributionFinder.Context`` object. + :return: Iterable of Distribution objects for all packages. + """ + context = kwargs.pop('context', None) + if context and kwargs: + raise ValueError("cannot accept context and kwargs") + context = context or DistributionFinder.Context(**kwargs) + return itertools.chain.from_iterable( + resolver(context) for resolver in cls._discover_resolvers() + ) + + @staticmethod + def at(path): + """Return a Distribution for the indicated metadata path + + :param path: a string or path-like object + :return: a concrete Distribution instance for the path + """ + return PathDistribution(pathlib.Path(path)) + + @staticmethod + def _discover_resolvers(): + """Search the meta_path for resolvers.""" + declared = ( + getattr(finder, 'find_distributions', None) for finder in sys.meta_path + ) + return filter(None, declared) + + @classmethod + def _local(cls, root='.'): + from pep517 import build, meta + + system = build.compat_system(root) + builder = functools.partial( + meta.build, + source_dir=root, + system=system, + ) + return PathDistribution(zipp.Path(meta.build_as_zip(builder))) + + @property + def metadata(self) -> _meta.PackageMetadata: + """Return the parsed metadata for this Distribution. + + The returned object will have keys that name the various bits of + metadata. See PEP 566 for details. + """ + text = ( + self.read_text('METADATA') + or self.read_text('PKG-INFO') + # This last clause is here to support old egg-info files. Its + # effect is to just end up using the PathDistribution's self._path + # (which points to the egg-info file) attribute unchanged. + or self.read_text('') + ) + return _adapters.Message(email.message_from_string(text)) + + @property + def name(self): + """Return the 'Name' metadata for the distribution package.""" + return self.metadata['Name'] + + @property + def _normalized_name(self): + """Return a normalized version of the name.""" + return Prepared.normalize(self.name) + + @property + def version(self): + """Return the 'Version' metadata for the distribution package.""" + return self.metadata['Version'] + + @property + def entry_points(self): + return EntryPoints._from_text_for(self.read_text('entry_points.txt'), self) + + @property + def files(self): + """Files in this distribution. + + :return: List of PackagePath for this distribution or None + + Result is `None` if the metadata file that enumerates files + (i.e. RECORD for dist-info or SOURCES.txt for egg-info) is + missing. + Result may be empty if the metadata exists but is empty. 
+ """ + + def make_file(name, hash=None, size_str=None): + result = PackagePath(name) + result.hash = FileHash(hash) if hash else None + result.size = int(size_str) if size_str else None + result.dist = self + return result + + @pass_none + def make_files(lines): + return list(starmap(make_file, csv.reader(lines))) + + return make_files(self._read_files_distinfo() or self._read_files_egginfo()) + + def _read_files_distinfo(self): + """ + Read the lines of RECORD + """ + text = self.read_text('RECORD') + return text and text.splitlines() + + def _read_files_egginfo(self): + """ + SOURCES.txt might contain literal commas, so wrap each line + in quotes. + """ + text = self.read_text('SOURCES.txt') + return text and map('"{}"'.format, text.splitlines()) + + @property + def requires(self): + """Generated requirements specified for this Distribution""" + reqs = self._read_dist_info_reqs() or self._read_egg_info_reqs() + return reqs and list(reqs) + + def _read_dist_info_reqs(self): + return self.metadata.get_all('Requires-Dist') + + def _read_egg_info_reqs(self): + source = self.read_text('requires.txt') + return source and self._deps_from_requires_text(source) + + @classmethod + def _deps_from_requires_text(cls, source): + return cls._convert_egg_info_reqs_to_simple_reqs(Sectioned.read(source)) + + @staticmethod + def _convert_egg_info_reqs_to_simple_reqs(sections): + """ + Historically, setuptools would solicit and store 'extra' + requirements, including those with environment markers, + in separate sections. More modern tools expect each + dependency to be defined separately, with any relevant + extras and environment markers attached directly to that + requirement. This method converts the former to the + latter. See _test_deps_from_requires_text for an example. + """ + + def make_condition(name): + return name and f'extra == "{name}"' + + def quoted_marker(section): + section = section or '' + extra, sep, markers = section.partition(':') + if extra and markers: + markers = f'({markers})' + conditions = list(filter(None, [markers, make_condition(extra)])) + return '; ' + ' and '.join(conditions) if conditions else '' + + def url_req_space(req): + """ + PEP 508 requires a space between the url_spec and the quoted_marker. + Ref python/importlib_metadata#357. + """ + # '@' is uniquely indicative of a url_req. + return ' ' * ('@' in req) + + for section in sections: + space = url_req_space(section.value) + yield section.value + space + quoted_marker(section.name) + + +class DistributionFinder(MetaPathFinder): + """ + A MetaPathFinder capable of discovering installed distributions. + """ + + class Context: + """ + Keyword arguments presented by the caller to + ``distributions()`` or ``Distribution.discover()`` + to narrow the scope of a search for distributions + in all DistributionFinders. + + Each DistributionFinder may expect any parameters + and should attempt to honor the canonical + parameters defined below when appropriate. + """ + + name = None + """ + Specific name for which a distribution finder should match. + A name of ``None`` matches all distributions. + """ + + def __init__(self, **kwargs): + vars(self).update(kwargs) + + @property + def path(self): + """ + The sequence of directory path that a distribution finder + should search. + + Typically refers to Python installed package paths such as + "site-packages" directories and defaults to ``sys.path``. 
+ """ + return vars(self).get('path', sys.path) + + @abc.abstractmethod + def find_distributions(self, context=Context()): + """ + Find distributions. + + Return an iterable of all Distribution instances capable of + loading the metadata for packages matching the ``context``, + a DistributionFinder.Context instance. + """ + + +class FastPath: + """ + Micro-optimized class for searching a path for + children. + + >>> FastPath('').children() + ['...'] + """ + + @functools.lru_cache() # type: ignore + def __new__(cls, root): + return super().__new__(cls) + + def __init__(self, root): + self.root = str(root) + + def joinpath(self, child): + return pathlib.Path(self.root, child) + + def children(self): + with suppress(Exception): + return os.listdir(self.root or '.') + with suppress(Exception): + return self.zip_children() + return [] + + def zip_children(self): + zip_path = zipp.Path(self.root) + names = zip_path.root.namelist() + self.joinpath = zip_path.joinpath + + return dict.fromkeys(child.split(posixpath.sep, 1)[0] for child in names) + + def search(self, name): + return self.lookup(self.mtime).search(name) + + @property + def mtime(self): + with suppress(OSError): + return os.stat(self.root).st_mtime + self.lookup.cache_clear() + + @method_cache + def lookup(self, mtime): + return Lookup(self) + + +class Lookup: + def __init__(self, path: FastPath): + base = os.path.basename(path.root).lower() + base_is_egg = base.endswith(".egg") + self.infos = FreezableDefaultDict(list) + self.eggs = FreezableDefaultDict(list) + + for child in path.children(): + low = child.lower() + if low.endswith((".dist-info", ".egg-info")): + # rpartition is faster than splitext and suitable for this purpose. + name = low.rpartition(".")[0].partition("-")[0] + normalized = Prepared.normalize(name) + self.infos[normalized].append(path.joinpath(child)) + elif base_is_egg and low == "egg-info": + name = base.rpartition(".")[0].partition("-")[0] + legacy_normalized = Prepared.legacy_normalize(name) + self.eggs[legacy_normalized].append(path.joinpath(child)) + + self.infos.freeze() + self.eggs.freeze() + + def search(self, prepared): + infos = ( + self.infos[prepared.normalized] + if prepared + else itertools.chain.from_iterable(self.infos.values()) + ) + eggs = ( + self.eggs[prepared.legacy_normalized] + if prepared + else itertools.chain.from_iterable(self.eggs.values()) + ) + return itertools.chain(infos, eggs) + + +class Prepared: + """ + A prepared search for metadata on a possibly-named package. + """ + + normalized = None + legacy_normalized = None + + def __init__(self, name): + self.name = name + if name is None: + return + self.normalized = self.normalize(name) + self.legacy_normalized = self.legacy_normalize(name) + + @staticmethod + def normalize(name): + """ + PEP 503 normalization plus dashes as underscores. + """ + return re.sub(r"[-_.]+", "-", name).lower().replace('-', '_') + + @staticmethod + def legacy_normalize(name): + """ + Normalize the package name as found in the convention in + older packaging tools versions and specs. + """ + return name.lower().replace('-', '_') + + def __bool__(self): + return bool(self.name) + + +@install +class MetadataPathFinder(NullFinder, DistributionFinder): + """A degenerate finder for distribution packages on the file system. + + This finder supplies only a find_distributions() method for versions + of Python that do not have a PathFinder find_distributions(). + """ + + def find_distributions(self, context=DistributionFinder.Context()): + """ + Find distributions. 
+ + Return an iterable of all Distribution instances capable of + loading the metadata for packages matching ``context.name`` + (or all names if ``None`` indicated) along the paths in the list + of directories ``context.path``. + """ + found = self._search_paths(context.name, context.path) + return map(PathDistribution, found) + + @classmethod + def _search_paths(cls, name, paths): + """Find metadata directories in paths heuristically.""" + prepared = Prepared(name) + return itertools.chain.from_iterable( + path.search(prepared) for path in map(FastPath, paths) + ) + + def invalidate_caches(cls): + FastPath.__new__.cache_clear() + + +class PathDistribution(Distribution): + def __init__(self, path: SimplePath): + """Construct a distribution. + + :param path: SimplePath indicating the metadata directory. + """ + self._path = path + + def read_text(self, filename): + with suppress( + FileNotFoundError, + IsADirectoryError, + KeyError, + NotADirectoryError, + PermissionError, + ): + return self._path.joinpath(filename).read_text(encoding='utf-8') + + read_text.__doc__ = Distribution.read_text.__doc__ + + def locate_file(self, path): + return self._path.parent / path + + @property + def _normalized_name(self): + """ + Performance optimization: where possible, resolve the + normalized name from the file system path. + """ + stem = os.path.basename(str(self._path)) + return self._name_from_stem(stem) or super()._normalized_name + + def _name_from_stem(self, stem): + name, ext = os.path.splitext(stem) + if ext not in ('.dist-info', '.egg-info'): + return + name, sep, rest = stem.partition('-') + return name + + +def distribution(distribution_name): + """Get the ``Distribution`` instance for the named package. + + :param distribution_name: The name of the distribution package as a string. + :return: A ``Distribution`` instance (or subclass thereof). + """ + return Distribution.from_name(distribution_name) + + +def distributions(**kwargs): + """Get all ``Distribution`` instances in the current environment. + + :return: An iterable of ``Distribution`` instances. + """ + return Distribution.discover(**kwargs) + + +def metadata(distribution_name) -> _meta.PackageMetadata: + """Get the metadata for the named package. + + :param distribution_name: The name of the distribution package to query. + :return: A PackageMetadata containing the parsed metadata. + """ + return Distribution.from_name(distribution_name).metadata + + +def version(distribution_name): + """Get the version string for the named package. + + :param distribution_name: The name of the distribution package to query. + :return: The version string for the package as defined in the package's + "Version" metadata key. + """ + return distribution(distribution_name).version + + +def entry_points(**params) -> Union[EntryPoints, SelectableGroups]: + """Return EntryPoint objects for all installed packages. + + Pass selection parameters (group or name) to filter the + result to entry points matching those properties (see + EntryPoints.select()). + + For compatibility, returns ``SelectableGroups`` object unless + selection parameters are supplied. In the future, this function + will return ``EntryPoints`` instead of ``SelectableGroups`` + even when no selection parameters are supplied. + + For maximum future compatibility, pass selection parameters + or invoke ``.select`` with parameters on the result. + + :return: EntryPoints or SelectableGroups for all installed packages. 
+ """ + norm_name = operator.attrgetter('_normalized_name') + unique = functools.partial(unique_everseen, key=norm_name) + eps = itertools.chain.from_iterable( + dist.entry_points for dist in unique(distributions()) + ) + return SelectableGroups.load(eps).select(**params) + + +def files(distribution_name): + """Return a list of files for the named package. + + :param distribution_name: The name of the distribution package to query. + :return: List of files composing the distribution. + """ + return distribution(distribution_name).files + + +def requires(distribution_name): + """ + Return a list of requirements for the named package. + + :return: An iterator of requirements, suitable for + packaging.requirement.Requirement. + """ + return distribution(distribution_name).requires + + +def packages_distributions() -> Mapping[str, List[str]]: + """ + Return a mapping of top-level packages to their + distributions. + + >>> import collections.abc + >>> pkgs = packages_distributions() + >>> all(isinstance(dist, collections.abc.Sequence) for dist in pkgs.values()) + True + """ + pkg_to_dist = collections.defaultdict(list) + for dist in distributions(): + for pkg in _top_level_declared(dist) or _top_level_inferred(dist): + pkg_to_dist[pkg].append(dist.metadata['Name']) + return dict(pkg_to_dist) + + +def _top_level_declared(dist): + return (dist.read_text('top_level.txt') or '').split() + + +def _top_level_inferred(dist): + return { + f.parts[0] if len(f.parts) > 1 else f.with_suffix('').name + for f in always_iterable(dist.files) + if f.suffix == ".py" + } diff --git a/metaflow/_vendor/importlib_metadata/_adapters.py b/metaflow/_vendor/importlib_metadata/_adapters.py new file mode 100644 index 00000000000..aa460d3eda5 --- /dev/null +++ b/metaflow/_vendor/importlib_metadata/_adapters.py @@ -0,0 +1,68 @@ +import re +import textwrap +import email.message + +from ._text import FoldedCase + + +class Message(email.message.Message): + multiple_use_keys = set( + map( + FoldedCase, + [ + 'Classifier', + 'Obsoletes-Dist', + 'Platform', + 'Project-URL', + 'Provides-Dist', + 'Provides-Extra', + 'Requires-Dist', + 'Requires-External', + 'Supported-Platform', + 'Dynamic', + ], + ) + ) + """ + Keys that may be indicated multiple times per PEP 566. + """ + + def __new__(cls, orig: email.message.Message): + res = super().__new__(cls) + vars(res).update(vars(orig)) + return res + + def __init__(self, *args, **kwargs): + self._headers = self._repair_headers() + + # suppress spurious error from mypy + def __iter__(self): + return super().__iter__() + + def _repair_headers(self): + def redent(value): + "Correct for RFC822 indentation" + if not value or '\n' not in value: + return value + return textwrap.dedent(' ' * 8 + value) + + headers = [(key, redent(value)) for key, value in vars(self)['_headers']] + if self._payload: + headers.append(('Description', self.get_payload())) + return headers + + @property + def json(self): + """ + Convert PackageMetadata to a JSON-compatible format + per PEP 0566. 
+ """ + + def transform(key): + value = self.get_all(key) if key in self.multiple_use_keys else self[key] + if key == 'Keywords': + value = re.split(r'\s+', value) + tk = key.lower().replace('-', '_') + return tk, value + + return dict(map(transform, map(FoldedCase, self))) diff --git a/metaflow/_vendor/importlib_metadata/_collections.py b/metaflow/_vendor/importlib_metadata/_collections.py new file mode 100644 index 00000000000..cf0954e1a30 --- /dev/null +++ b/metaflow/_vendor/importlib_metadata/_collections.py @@ -0,0 +1,30 @@ +import collections + + +# from jaraco.collections 3.3 +class FreezableDefaultDict(collections.defaultdict): + """ + Often it is desirable to prevent the mutation of + a default dict after its initial construction, such + as to prevent mutation during iteration. + + >>> dd = FreezableDefaultDict(list) + >>> dd[0].append('1') + >>> dd.freeze() + >>> dd[1] + [] + >>> len(dd) + 1 + """ + + def __missing__(self, key): + return getattr(self, '_frozen', super().__missing__)(key) + + def freeze(self): + self._frozen = lambda key: self.default_factory() + + +class Pair(collections.namedtuple('Pair', 'name value')): + @classmethod + def parse(cls, text): + return cls(*map(str.strip, text.split("=", 1))) diff --git a/metaflow/_vendor/importlib_metadata/_compat.py b/metaflow/_vendor/importlib_metadata/_compat.py new file mode 100644 index 00000000000..15927dbb753 --- /dev/null +++ b/metaflow/_vendor/importlib_metadata/_compat.py @@ -0,0 +1,71 @@ +import sys +import platform + + +__all__ = ['install', 'NullFinder', 'Protocol'] + + +try: + from typing import Protocol +except ImportError: # pragma: no cover + from metaflow._vendor.typing_extensions import Protocol # type: ignore + + +def install(cls): + """ + Class decorator for installation on sys.meta_path. + + Adds the backport DistributionFinder to sys.meta_path and + attempts to disable the finder functionality of the stdlib + DistributionFinder. + """ + sys.meta_path.append(cls()) + disable_stdlib_finder() + return cls + + +def disable_stdlib_finder(): + """ + Give the backport primacy for discovering path-based distributions + by monkey-patching the stdlib O_O. + + See #91 for more background for rationale on this sketchy + behavior. + """ + + def matches(finder): + return getattr( + finder, '__module__', None + ) == '_frozen_importlib_external' and hasattr(finder, 'find_distributions') + + for finder in filter(matches, sys.meta_path): # pragma: nocover + del finder.find_distributions + + +class NullFinder: + """ + A "Finder" (aka "MetaClassFinder") that never finds any modules, + but may find distributions. + """ + + @staticmethod + def find_spec(*args, **kwargs): + return None + + # In Python 2, the import system requires finders + # to have a find_module() method, but this usage + # is deprecated in Python 3 in favor of find_spec(). + # For the purposes of this finder (i.e. being present + # on sys.meta_path but having no other import + # system functionality), the two methods are identical. + find_module = find_spec + + +def pypy_partial(val): + """ + Adjust for variable stacklevel on partial under PyPy. + + Workaround for #327. 
+ """ + is_pypy = platform.python_implementation() == 'PyPy' + return val + is_pypy diff --git a/metaflow/_vendor/importlib_metadata/_functools.py b/metaflow/_vendor/importlib_metadata/_functools.py new file mode 100644 index 00000000000..71f66bd03cb --- /dev/null +++ b/metaflow/_vendor/importlib_metadata/_functools.py @@ -0,0 +1,104 @@ +import types +import functools + + +# from jaraco.functools 3.3 +def method_cache(method, cache_wrapper=None): + """ + Wrap lru_cache to support storing the cache data in the object instances. + + Abstracts the common paradigm where the method explicitly saves an + underscore-prefixed protected property on first call and returns that + subsequently. + + >>> class MyClass: + ... calls = 0 + ... + ... @method_cache + ... def method(self, value): + ... self.calls += 1 + ... return value + + >>> a = MyClass() + >>> a.method(3) + 3 + >>> for x in range(75): + ... res = a.method(x) + >>> a.calls + 75 + + Note that the apparent behavior will be exactly like that of lru_cache + except that the cache is stored on each instance, so values in one + instance will not flush values from another, and when an instance is + deleted, so are the cached values for that instance. + + >>> b = MyClass() + >>> for x in range(35): + ... res = b.method(x) + >>> b.calls + 35 + >>> a.method(0) + 0 + >>> a.calls + 75 + + Note that if method had been decorated with ``functools.lru_cache()``, + a.calls would have been 76 (due to the cached value of 0 having been + flushed by the 'b' instance). + + Clear the cache with ``.cache_clear()`` + + >>> a.method.cache_clear() + + Same for a method that hasn't yet been called. + + >>> c = MyClass() + >>> c.method.cache_clear() + + Another cache wrapper may be supplied: + + >>> cache = functools.lru_cache(maxsize=2) + >>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache) + >>> a = MyClass() + >>> a.method2() + 3 + + Caution - do not subsequently wrap the method with another decorator, such + as ``@property``, which changes the semantics of the function. + + See also + http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/ + for another implementation and additional justification. + """ + cache_wrapper = cache_wrapper or functools.lru_cache() + + def wrapper(self, *args, **kwargs): + # it's the first call, replace the method with a cached, bound method + bound_method = types.MethodType(method, self) + cached_method = cache_wrapper(bound_method) + setattr(self, method.__name__, cached_method) + return cached_method(*args, **kwargs) + + # Support cache clear even before cache has been created. + wrapper.cache_clear = lambda: None + + return wrapper + + +# From jaraco.functools 3.3 +def pass_none(func): + """ + Wrap func so it's not called if its first param is None + + >>> print_text = pass_none(print) + >>> print_text('text') + text + >>> print_text(None) + """ + + @functools.wraps(func) + def wrapper(param, *args, **kwargs): + if param is not None: + return func(param, *args, **kwargs) + + return wrapper diff --git a/metaflow/_vendor/importlib_metadata/_itertools.py b/metaflow/_vendor/importlib_metadata/_itertools.py new file mode 100644 index 00000000000..d4ca9b9140e --- /dev/null +++ b/metaflow/_vendor/importlib_metadata/_itertools.py @@ -0,0 +1,73 @@ +from itertools import filterfalse + + +def unique_everseen(iterable, key=None): + "List unique elements, preserving order. Remember all elements ever seen." 
+ # unique_everseen('AAAABBBCCDAABBB') --> A B C D + # unique_everseen('ABBCcAD', str.lower) --> A B C D + seen = set() + seen_add = seen.add + if key is None: + for element in filterfalse(seen.__contains__, iterable): + seen_add(element) + yield element + else: + for element in iterable: + k = key(element) + if k not in seen: + seen_add(k) + yield element + + +# copied from more_itertools 8.8 +def always_iterable(obj, base_type=(str, bytes)): + """If *obj* is iterable, return an iterator over its items:: + + >>> obj = (1, 2, 3) + >>> list(always_iterable(obj)) + [1, 2, 3] + + If *obj* is not iterable, return a one-item iterable containing *obj*:: + + >>> obj = 1 + >>> list(always_iterable(obj)) + [1] + + If *obj* is ``None``, return an empty iterable: + + >>> obj = None + >>> list(always_iterable(None)) + [] + + By default, binary and text strings are not considered iterable:: + + >>> obj = 'foo' + >>> list(always_iterable(obj)) + ['foo'] + + If *base_type* is set, objects for which ``isinstance(obj, base_type)`` + returns ``True`` won't be considered iterable. + + >>> obj = {'a': 1} + >>> list(always_iterable(obj)) # Iterate over the dict's keys + ['a'] + >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit + [{'a': 1}] + + Set *base_type* to ``None`` to avoid any special handling and treat objects + Python considers iterable as iterable: + + >>> obj = 'foo' + >>> list(always_iterable(obj, base_type=None)) + ['f', 'o', 'o'] + """ + if obj is None: + return iter(()) + + if (base_type is not None) and isinstance(obj, base_type): + return iter((obj,)) + + try: + return iter(obj) + except TypeError: + return iter((obj,)) diff --git a/metaflow/_vendor/importlib_metadata/_meta.py b/metaflow/_vendor/importlib_metadata/_meta.py new file mode 100644 index 00000000000..37ee43e6ef4 --- /dev/null +++ b/metaflow/_vendor/importlib_metadata/_meta.py @@ -0,0 +1,48 @@ +from ._compat import Protocol +from typing import Any, Dict, Iterator, List, TypeVar, Union + + +_T = TypeVar("_T") + + +class PackageMetadata(Protocol): + def __len__(self) -> int: + ... # pragma: no cover + + def __contains__(self, item: str) -> bool: + ... # pragma: no cover + + def __getitem__(self, key: str) -> str: + ... # pragma: no cover + + def __iter__(self) -> Iterator[str]: + ... # pragma: no cover + + def get_all(self, name: str, failobj: _T = ...) -> Union[List[Any], _T]: + """ + Return all values associated with a possibly multi-valued key. + """ + + @property + def json(self) -> Dict[str, Union[str, List[str]]]: + """ + A JSON-compatible form of the metadata. + """ + + +class SimplePath(Protocol): + """ + A minimal subset of pathlib.Path required by PathDistribution. + """ + + def joinpath(self) -> 'SimplePath': + ... # pragma: no cover + + def __truediv__(self) -> 'SimplePath': + ... # pragma: no cover + + def parent(self) -> 'SimplePath': + ... # pragma: no cover + + def read_text(self) -> str: + ... # pragma: no cover diff --git a/metaflow/_vendor/importlib_metadata/_text.py b/metaflow/_vendor/importlib_metadata/_text.py new file mode 100644 index 00000000000..c88cfbb2349 --- /dev/null +++ b/metaflow/_vendor/importlib_metadata/_text.py @@ -0,0 +1,99 @@ +import re + +from ._functools import method_cache + + +# from jaraco.text 3.5 +class FoldedCase(str): + """ + A case insensitive string class; behaves just like str + except compares equal when the only variation is case. 
+ + >>> s = FoldedCase('hello world') + + >>> s == 'Hello World' + True + + >>> 'Hello World' == s + True + + >>> s != 'Hello World' + False + + >>> s.index('O') + 4 + + >>> s.split('O') + ['hell', ' w', 'rld'] + + >>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta'])) + ['alpha', 'Beta', 'GAMMA'] + + Sequence membership is straightforward. + + >>> "Hello World" in [s] + True + >>> s in ["Hello World"] + True + + You may test for set inclusion, but candidate and elements + must both be folded. + + >>> FoldedCase("Hello World") in {s} + True + >>> s in {FoldedCase("Hello World")} + True + + String inclusion works as long as the FoldedCase object + is on the right. + + >>> "hello" in FoldedCase("Hello World") + True + + But not if the FoldedCase object is on the left: + + >>> FoldedCase('hello') in 'Hello World' + False + + In that case, use in_: + + >>> FoldedCase('hello').in_('Hello World') + True + + >>> FoldedCase('hello') > FoldedCase('Hello') + False + """ + + def __lt__(self, other): + return self.lower() < other.lower() + + def __gt__(self, other): + return self.lower() > other.lower() + + def __eq__(self, other): + return self.lower() == other.lower() + + def __ne__(self, other): + return self.lower() != other.lower() + + def __hash__(self): + return hash(self.lower()) + + def __contains__(self, other): + return super().lower().__contains__(other.lower()) + + def in_(self, other): + "Does self appear in other?" + return self in FoldedCase(other) + + # cache lower since it's likely to be called frequently. + @method_cache + def lower(self): + return super().lower() + + def index(self, sub): + return self.lower().index(sub.lower()) + + def split(self, splitter=' ', maxsplit=0): + pattern = re.compile(re.escape(splitter), re.I) + return pattern.split(self, maxsplit) diff --git a/metaflow/_vendor/importlib_metadata/py.typed b/metaflow/_vendor/importlib_metadata/py.typed new file mode 100644 index 00000000000..e69de29bb2d diff --git a/metaflow/_vendor/typeguard.LICENSE b/metaflow/_vendor/typeguard.LICENSE new file mode 100644 index 00000000000..07806f8af9d --- /dev/null +++ b/metaflow/_vendor/typeguard.LICENSE @@ -0,0 +1,19 @@ +This is the MIT license: http://www.opensource.org/licenses/mit-license.php + +Copyright (c) Alex Grönholm + +Permission is hereby granted, free of charge, to any person obtaining a copy of this +software and associated documentation files (the "Software"), to deal in the Software +without restriction, including without limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or +substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
diff --git a/metaflow/_vendor/typeguard/__init__.py b/metaflow/_vendor/typeguard/__init__.py new file mode 100644 index 00000000000..6781cad094b --- /dev/null +++ b/metaflow/_vendor/typeguard/__init__.py @@ -0,0 +1,48 @@ +import os +from typing import Any + +from ._checkers import TypeCheckerCallable as TypeCheckerCallable +from ._checkers import TypeCheckLookupCallback as TypeCheckLookupCallback +from ._checkers import check_type_internal as check_type_internal +from ._checkers import checker_lookup_functions as checker_lookup_functions +from ._checkers import load_plugins as load_plugins +from ._config import CollectionCheckStrategy as CollectionCheckStrategy +from ._config import ForwardRefPolicy as ForwardRefPolicy +from ._config import TypeCheckConfiguration as TypeCheckConfiguration +from ._decorators import typechecked as typechecked +from ._decorators import typeguard_ignore as typeguard_ignore +from ._exceptions import InstrumentationWarning as InstrumentationWarning +from ._exceptions import TypeCheckError as TypeCheckError +from ._exceptions import TypeCheckWarning as TypeCheckWarning +from ._exceptions import TypeHintWarning as TypeHintWarning +from ._functions import TypeCheckFailCallback as TypeCheckFailCallback +from ._functions import check_type as check_type +from ._functions import warn_on_error as warn_on_error +from ._importhook import ImportHookManager as ImportHookManager +from ._importhook import TypeguardFinder as TypeguardFinder +from ._importhook import install_import_hook as install_import_hook +from ._memo import TypeCheckMemo as TypeCheckMemo +from ._suppression import suppress_type_checks as suppress_type_checks +from ._utils import Unset as Unset + +# Re-export imports so they look like they live directly in this package +for value in list(locals().values()): + if getattr(value, "__module__", "").startswith(f"{__name__}."): + value.__module__ = __name__ + + +config: TypeCheckConfiguration + + +def __getattr__(name: str) -> Any: + if name == "config": + from ._config import global_config + + return global_config + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +# Automatically load checker lookup functions unless explicitly disabled +if "TYPEGUARD_DISABLE_PLUGIN_AUTOLOAD" not in os.environ: + load_plugins() diff --git a/metaflow/_vendor/typeguard/_checkers.py b/metaflow/_vendor/typeguard/_checkers.py new file mode 100644 index 00000000000..0c38917a32e --- /dev/null +++ b/metaflow/_vendor/typeguard/_checkers.py @@ -0,0 +1,906 @@ +from __future__ import annotations + +import collections.abc +import inspect +import sys +import types +import typing +import warnings +from enum import Enum +from inspect import Parameter, isclass, isfunction +from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase +from textwrap import indent +from typing import ( + IO, + AbstractSet, + Any, + BinaryIO, + Callable, + Dict, + ForwardRef, + List, + Mapping, + MutableMapping, + NewType, + Optional, + Sequence, + Set, + TextIO, + Tuple, + Type, + TypeVar, + Union, +) +from unittest.mock import Mock + +try: + from metaflow._vendor import typing_extensions +except ImportError: + typing_extensions = None # type: ignore[assignment] + +from ._config import ForwardRefPolicy +from ._exceptions import TypeCheckError, TypeHintWarning +from ._memo import TypeCheckMemo +from ._utils import evaluate_forwardref, get_stacklevel, get_type_name, qualified_name + +if sys.version_info >= (3, 11): + from typing import ( + Annotated, + TypeAlias, + get_args, + get_origin, 
+ get_type_hints, + is_typeddict, + ) + + SubclassableAny = Any +else: + from metaflow._vendor.typing_extensions import ( + Annotated, + TypeAlias, + get_args, + get_origin, + get_type_hints, + is_typeddict, + ) + from metaflow._vendor.typing_extensions import Any as SubclassableAny + +if sys.version_info >= (3, 10): + from importlib.metadata import entry_points + from typing import ParamSpec +else: + from metaflow._vendor.importlib_metadata import entry_points + from metaflow._vendor.typing_extensions import ParamSpec + +TypeCheckerCallable: TypeAlias = Callable[ + [Any, Any, Tuple[Any, ...], TypeCheckMemo], Any +] +TypeCheckLookupCallback: TypeAlias = Callable[ + [Any, Tuple[Any, ...], Tuple[Any, ...]], Optional[TypeCheckerCallable] +] + +checker_lookup_functions: list[TypeCheckLookupCallback] = [] + + +# Sentinel +_missing = object() + +# Lifted from mypy.sharedparse +BINARY_MAGIC_METHODS = { + "__add__", + "__and__", + "__cmp__", + "__divmod__", + "__div__", + "__eq__", + "__floordiv__", + "__ge__", + "__gt__", + "__iadd__", + "__iand__", + "__idiv__", + "__ifloordiv__", + "__ilshift__", + "__imatmul__", + "__imod__", + "__imul__", + "__ior__", + "__ipow__", + "__irshift__", + "__isub__", + "__itruediv__", + "__ixor__", + "__le__", + "__lshift__", + "__lt__", + "__matmul__", + "__mod__", + "__mul__", + "__ne__", + "__or__", + "__pow__", + "__radd__", + "__rand__", + "__rdiv__", + "__rfloordiv__", + "__rlshift__", + "__rmatmul__", + "__rmod__", + "__rmul__", + "__ror__", + "__rpow__", + "__rrshift__", + "__rshift__", + "__rsub__", + "__rtruediv__", + "__rxor__", + "__sub__", + "__truediv__", + "__xor__", +} + + +def check_callable( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not callable(value): + raise TypeCheckError("is not callable") + + if args: + try: + signature = inspect.signature(value) + except (TypeError, ValueError): + return + + argument_types = args[0] + if isinstance(argument_types, list) and not any( + type(item) is ParamSpec for item in argument_types + ): + # The callable must not have keyword-only arguments without defaults + unfulfilled_kwonlyargs = [ + param.name + for param in signature.parameters.values() + if param.kind == Parameter.KEYWORD_ONLY + and param.default == Parameter.empty + ] + if unfulfilled_kwonlyargs: + raise TypeCheckError( + f"has mandatory keyword-only arguments in its declaration: " + f'{", ".join(unfulfilled_kwonlyargs)}' + ) + + num_mandatory_args = len( + [ + param.name + for param in signature.parameters.values() + if param.kind + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + and param.default is Parameter.empty + ] + ) + has_varargs = any( + param + for param in signature.parameters.values() + if param.kind == Parameter.VAR_POSITIONAL + ) + + if num_mandatory_args > len(argument_types): + raise TypeCheckError( + f"has too many arguments in its declaration; expected " + f"{len(argument_types)} but {num_mandatory_args} argument(s) " + f"declared" + ) + elif not has_varargs and num_mandatory_args < len(argument_types): + raise TypeCheckError( + f"has too few arguments in its declaration; expected " + f"{len(argument_types)} but {num_mandatory_args} argument(s) " + f"declared" + ) + + +def check_mapping( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if origin_type is Dict or origin_type is dict: + if not isinstance(value, dict): + raise TypeCheckError("is not a dict") + if origin_type is MutableMapping or origin_type is 
collections.abc.MutableMapping: + if not isinstance(value, collections.abc.MutableMapping): + raise TypeCheckError("is not a mutable mapping") + elif not isinstance(value, collections.abc.Mapping): + raise TypeCheckError("is not a mapping") + + if args: + key_type, value_type = args + if key_type is not Any or value_type is not Any: + samples = memo.config.collection_check_strategy.iterate_samples( + value.items() + ) + for k, v in samples: + try: + check_type_internal(k, key_type, memo) + except TypeCheckError as exc: + exc.append_path_element(f"key {k!r}") + raise + + try: + check_type_internal(v, value_type, memo) + except TypeCheckError as exc: + exc.append_path_element(f"value of key {k!r}") + raise + + +def check_typed_dict( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, dict): + raise TypeCheckError("is not a dict") + + declared_keys = frozenset(origin_type.__annotations__) + if hasattr(origin_type, "__required_keys__"): + required_keys = origin_type.__required_keys__ + else: # py3.8 and lower + required_keys = declared_keys if origin_type.__total__ else frozenset() + + existing_keys = frozenset(value) + extra_keys = existing_keys - declared_keys + if extra_keys: + keys_formatted = ", ".join(f'"{key}"' for key in sorted(extra_keys, key=repr)) + raise TypeCheckError(f"has unexpected extra key(s): {keys_formatted}") + + missing_keys = required_keys - existing_keys + if missing_keys: + keys_formatted = ", ".join(f'"{key}"' for key in sorted(missing_keys, key=repr)) + raise TypeCheckError(f"is missing required key(s): {keys_formatted}") + + for key, argtype in get_type_hints(origin_type).items(): + argvalue = value.get(key, _missing) + if argvalue is not _missing: + try: + check_type_internal(argvalue, argtype, memo) + except TypeCheckError as exc: + exc.append_path_element(f"value of key {key!r}") + raise + + +def check_list( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, list): + raise TypeCheckError("is not a list") + + if args and args != (Any,): + samples = memo.config.collection_check_strategy.iterate_samples(value) + for i, v in enumerate(samples): + try: + check_type_internal(v, args[0], memo) + except TypeCheckError as exc: + exc.append_path_element(f"item {i}") + raise + + +def check_sequence( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, collections.abc.Sequence): + raise TypeCheckError("is not a sequence") + + if args and args != (Any,): + samples = memo.config.collection_check_strategy.iterate_samples(value) + for i, v in enumerate(samples): + try: + check_type_internal(v, args[0], memo) + except TypeCheckError as exc: + exc.append_path_element(f"item {i}") + raise + + +def check_set( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if origin_type is frozenset: + if not isinstance(value, frozenset): + raise TypeCheckError("is not a frozenset") + elif not isinstance(value, AbstractSet): + raise TypeCheckError("is not a set") + + if args and args != (Any,): + samples = memo.config.collection_check_strategy.iterate_samples(value) + for v in samples: + try: + check_type_internal(v, args[0], memo) + except TypeCheckError as exc: + exc.append_path_element(f"[{v}]") + raise + + +def check_tuple( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + # Specialized check for 
NamedTuples + field_types = getattr(origin_type, "__annotations__", None) + if field_types is None and sys.version_info < (3, 8): + field_types = getattr(origin_type, "_field_types", None) + + if field_types: + if not isinstance(value, origin_type): + raise TypeCheckError( + f"is not a named tuple of type {qualified_name(origin_type)}" + ) + + for name, field_type in field_types.items(): + try: + check_type_internal(getattr(value, name), field_type, memo) + except TypeCheckError as exc: + exc.append_path_element(f"attribute {name!r}") + raise + + return + elif not isinstance(value, tuple): + raise TypeCheckError("is not a tuple") + + if args: + # Python 3.6+ + use_ellipsis = args[-1] is Ellipsis + tuple_params = args[: -1 if use_ellipsis else None] + else: + # Unparametrized Tuple or plain tuple + return + + if use_ellipsis: + element_type = tuple_params[0] + samples = memo.config.collection_check_strategy.iterate_samples(value) + for i, element in enumerate(samples): + try: + check_type_internal(element, element_type, memo) + except TypeCheckError as exc: + exc.append_path_element(f"item {i}") + raise + elif tuple_params == ((),): + if value != (): + raise TypeCheckError("is not an empty tuple") + else: + if len(value) != len(tuple_params): + raise TypeCheckError( + f"has wrong number of elements (expected {len(tuple_params)}, got " + f"{len(value)} instead)" + ) + + for i, (element, element_type) in enumerate(zip(value, tuple_params)): + try: + check_type_internal(element, element_type, memo) + except TypeCheckError as exc: + exc.append_path_element(f"item {i}") + raise + + +def check_union( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + errors: dict[str, TypeCheckError] = {} + for type_ in args: + try: + check_type_internal(value, type_, memo) + return + except TypeCheckError as exc: + errors[get_type_name(type_)] = exc + + formatted_errors = indent( + "\n".join(f"{key}: {error}" for key, error in errors.items()), " " + ) + raise TypeCheckError(f"did not match any element in the union:\n{formatted_errors}") + + +def check_uniontype( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + errors: dict[str, TypeCheckError] = {} + for type_ in args: + try: + check_type_internal(value, type_, memo) + return + except TypeCheckError as exc: + errors[get_type_name(type_)] = exc + + formatted_errors = indent( + "\n".join(f"{key}: {error}" for key, error in errors.items()), " " + ) + raise TypeCheckError(f"did not match any element in the union:\n{formatted_errors}") + + +def check_class( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isclass(value): + raise TypeCheckError("is not a class") + + # Needed on Python 3.7+ + if not args: + return + + if isinstance(args[0], ForwardRef): + expected_class = evaluate_forwardref(args[0], memo) + else: + expected_class = args[0] + + if expected_class is Any: + return + elif getattr(expected_class, "_is_protocol", False): + check_protocol(value, expected_class, (), memo) + elif isinstance(expected_class, TypeVar): + check_typevar(value, expected_class, (), memo, subclass_check=True) + elif get_origin(expected_class) is Union: + errors: dict[str, TypeCheckError] = {} + for arg in get_args(expected_class): + if arg is Any: + return + + try: + check_class(value, type, (arg,), memo) + return + except TypeCheckError as exc: + errors[get_type_name(arg)] = exc + else: + formatted_errors = indent( + "\n".join(f"{key}: 
{error}" for key, error in errors.items()), " " + ) + raise TypeCheckError( + f"did not match any element in the union:\n{formatted_errors}" + ) + elif not issubclass(value, expected_class): + raise TypeCheckError(f"is not a subclass of {qualified_name(expected_class)}") + + +def check_newtype( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + check_type_internal(value, origin_type.__supertype__, memo) + + +def check_instance( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, origin_type): + raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}") + + +def check_typevar( + value: Any, + origin_type: TypeVar, + args: tuple[Any, ...], + memo: TypeCheckMemo, + *, + subclass_check: bool = False, +) -> None: + if origin_type.__bound__ is not None: + annotation = ( + Type[origin_type.__bound__] if subclass_check else origin_type.__bound__ + ) + check_type_internal(value, annotation, memo) + elif origin_type.__constraints__: + for constraint in origin_type.__constraints__: + annotation = Type[constraint] if subclass_check else constraint + try: + check_type_internal(value, annotation, memo) + except TypeCheckError: + pass + else: + break + else: + formatted_constraints = ", ".join( + get_type_name(constraint) for constraint in origin_type.__constraints__ + ) + raise TypeCheckError( + f"does not match any of the constraints " f"({formatted_constraints})" + ) + + +if sys.version_info >= (3, 8): + if typing_extensions is None: + + def _is_literal_type(typ: object) -> bool: + return typ is typing.Literal + + else: + + def _is_literal_type(typ: object) -> bool: + return typ is typing.Literal or typ is typing_extensions.Literal + +else: + + def _is_literal_type(typ: object) -> bool: + return typ is typing_extensions.Literal + + +def check_literal( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + def get_literal_args(literal_args: tuple[Any, ...]) -> tuple[Any, ...]: + retval: list[Any] = [] + for arg in literal_args: + if _is_literal_type(get_origin(arg)): + # The first check works on py3.6 and lower, the second one on py3.7+ + retval.extend(get_literal_args(arg.__args__)) + elif arg is None or isinstance(arg, (int, str, bytes, bool, Enum)): + retval.append(arg) + else: + raise TypeError( + f"Illegal literal value: {arg}" + ) # TypeError here is deliberate + + return tuple(retval) + + final_args = tuple(get_literal_args(args)) + try: + index = final_args.index(value) + except ValueError: + pass + else: + if type(final_args[index]) is type(value): + return + + formatted_args = ", ".join(repr(arg) for arg in final_args) + raise TypeCheckError(f"is not any of ({formatted_args})") from None + + +def check_literal_string( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + check_type_internal(value, str, memo) + + +def check_typeguard( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + check_type_internal(value, bool, memo) + + +def check_none( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if value is not None: + raise TypeCheckError("is not None") + + +def check_number( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if origin_type is complex and not isinstance(value, (complex, float, int)): + raise TypeCheckError("is neither complex, float 
or int") + elif origin_type is float and not isinstance(value, (float, int)): + raise TypeCheckError("is neither float or int") + + +def check_io( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if origin_type is TextIO or (origin_type is IO and args == (str,)): + if not isinstance(value, TextIOBase): + raise TypeCheckError("is not a text based I/O object") + elif origin_type is BinaryIO or (origin_type is IO and args == (bytes,)): + if not isinstance(value, (RawIOBase, BufferedIOBase)): + raise TypeCheckError("is not a binary I/O object") + elif not isinstance(value, IOBase): + raise TypeCheckError("is not an I/O object") + + +def check_protocol( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + # TODO: implement proper compatibility checking and support non-runtime protocols + if getattr(origin_type, "_is_runtime_protocol", False): + if not isinstance(value, origin_type): + raise TypeCheckError( + f"is not compatible with the {origin_type.__qualname__} protocol" + ) + else: + warnings.warn( + f"Typeguard cannot check the {origin_type.__qualname__} protocol because " + f"it is a non-runtime protocol. If you would like to type check this " + f"protocol, please use @typing.runtime_checkable", + stacklevel=get_stacklevel(), + ) + + +def check_byteslike( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, (bytearray, bytes, memoryview)): + raise TypeCheckError("is not bytes-like") + + +def check_self( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if memo.self_type is None: + raise TypeCheckError("cannot be checked against Self outside of a method call") + + if isclass(value): + if not issubclass(value, memo.self_type): + raise TypeCheckError( + f"is not an instance of the self type " + f"({qualified_name(memo.self_type)})" + ) + elif not isinstance(value, memo.self_type): + raise TypeCheckError( + f"is not an instance of the self type ({qualified_name(memo.self_type)})" + ) + + +def check_paramspec( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + pass # No-op for now + + +def check_instanceof( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, origin_type): + raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}") + + +def check_type_internal( + value: Any, + annotation: Any, + memo: TypeCheckMemo, +) -> None: + """ + Check that the given object is compatible with the given type annotation. + + This function should only be used by type checker callables. Applications should use + :func:`~.check_type` instead. 
+ + :param value: the value to check + :param annotation: the type annotation to check against + :param memo: a memo object containing configuration and information necessary for + looking up forward references + """ + + if isinstance(annotation, ForwardRef): + try: + annotation = evaluate_forwardref(annotation, memo) + except NameError: + if memo.config.forward_ref_policy is ForwardRefPolicy.ERROR: + raise + elif memo.config.forward_ref_policy is ForwardRefPolicy.WARN: + warnings.warn( + f"Cannot resolve forward reference {annotation.__forward_arg__!r}", + TypeHintWarning, + stacklevel=get_stacklevel(), + ) + + return + + if annotation is Any or annotation is SubclassableAny or isinstance(value, Mock): + return + + # Skip type checks if value is an instance of a class that inherits from Any + if not isclass(value) and SubclassableAny in type(value).__bases__: + return + + extras: tuple[Any, ...] + origin_type = get_origin(annotation) + if origin_type is Annotated: + annotation, *extras_ = get_args(annotation) + extras = tuple(extras_) + origin_type = get_origin(annotation) + else: + extras = () + + if origin_type is not None: + args = get_args(annotation) + + # Compatibility hack to distinguish between unparametrized and empty tuple + # (tuple[()]), necessary due to https://github.com/python/cpython/issues/91137 + if origin_type in (tuple, Tuple) and annotation is not Tuple and not args: + args = ((),) + else: + origin_type = annotation + args = () + + for lookup_func in checker_lookup_functions: + checker = lookup_func(origin_type, args, extras) + if checker: + checker(value, origin_type, args, memo) + return + + if isclass(origin_type): + if not isinstance(value, origin_type): + raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}") + elif type(origin_type) is str: + warnings.warn( + f"Skipping type check against {origin_type!r}; this looks like a " + f"string-form forward reference imported from another module", + TypeHintWarning, + stacklevel=get_stacklevel(), + ) + + +# Equality checks are applied to these +origin_type_checkers = { + bytes: check_byteslike, + AbstractSet: check_set, + BinaryIO: check_io, + Callable: check_callable, + collections.abc.Callable: check_callable, + complex: check_number, + dict: check_mapping, + Dict: check_mapping, + float: check_number, + frozenset: check_set, + IO: check_io, + list: check_list, + List: check_list, + Mapping: check_mapping, + MutableMapping: check_mapping, + None: check_none, + collections.abc.Mapping: check_mapping, + collections.abc.MutableMapping: check_mapping, + Sequence: check_sequence, + collections.abc.Sequence: check_sequence, + collections.abc.Set: check_set, + set: check_set, + Set: check_set, + TextIO: check_io, + tuple: check_tuple, + Tuple: check_tuple, + type: check_class, + Type: check_class, + Union: check_union, +} +if sys.version_info >= (3, 8): + origin_type_checkers[typing.Literal] = check_literal +if sys.version_info >= (3, 10): + origin_type_checkers[types.UnionType] = check_uniontype + origin_type_checkers[typing.TypeGuard] = check_typeguard +if sys.version_info >= (3, 11): + origin_type_checkers.update( + {typing.LiteralString: check_literal_string, typing.Self: check_self} + ) +if typing_extensions is not None: + # On some Python versions, these may simply be re-exports from typing, + # but exactly which Python versions is subject to change, + # so it's best to err on the safe side + # and update the dictionary on all Python versions + # if typing_extensions is installed + 
origin_type_checkers[typing_extensions.Literal] = check_literal + origin_type_checkers[typing_extensions.LiteralString] = check_literal_string + origin_type_checkers[typing_extensions.Self] = check_self + origin_type_checkers[typing_extensions.TypeGuard] = check_typeguard + + +def builtin_checker_lookup( + origin_type: Any, args: tuple[Any, ...], extras: tuple[Any, ...] +) -> TypeCheckerCallable | None: + checker = origin_type_checkers.get(origin_type) + if checker is not None: + return checker + elif is_typeddict(origin_type): + return check_typed_dict + elif isclass(origin_type) and issubclass( + origin_type, Tuple # type: ignore[arg-type] + ): + # NamedTuple + return check_tuple + elif getattr(origin_type, "_is_protocol", False): + return check_protocol + elif isinstance(origin_type, ParamSpec): + return check_paramspec + elif isinstance(origin_type, TypeVar): + return check_typevar + elif origin_type.__class__ is NewType: + # typing.NewType on Python 3.10+ + return check_newtype + elif ( + isfunction(origin_type) + and getattr(origin_type, "__module__", None) == "typing" + and getattr(origin_type, "__qualname__", "").startswith("NewType.") + and hasattr(origin_type, "__supertype__") + ): + # typing.NewType on Python 3.9 and below + return check_newtype + + return None + + +checker_lookup_functions.append(builtin_checker_lookup) + + +def load_plugins() -> None: + """ + Load all type checker lookup functions from entry points. + + All entry points from the ``typeguard.checker_lookup`` group are loaded, and the + returned lookup functions are added to :data:`typeguard.checker_lookup_functions`. + + .. note:: This function is called implicitly on import, unless the + ``TYPEGUARD_DISABLE_PLUGIN_AUTOLOAD`` environment variable is present. + """ + + for ep in entry_points(group="typeguard.checker_lookup"): + try: + plugin = ep.load() + except Exception as exc: + warnings.warn( + f"Failed to load plugin {ep.name!r}: " f"{qualified_name(exc)}: {exc}", + stacklevel=2, + ) + continue + + if not callable(plugin): + warnings.warn( + f"Plugin {ep} returned a non-callable object: {plugin!r}", stacklevel=2 + ) + continue + + checker_lookup_functions.insert(0, plugin) diff --git a/metaflow/_vendor/typeguard/_config.py b/metaflow/_vendor/typeguard/_config.py new file mode 100644 index 00000000000..04cecf84b3e --- /dev/null +++ b/metaflow/_vendor/typeguard/_config.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from collections.abc import Collection +from dataclasses import dataclass +from enum import Enum, auto +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from ._functions import TypeCheckFailCallback + +T = TypeVar("T") + + +class ForwardRefPolicy(Enum): + """ + Defines how unresolved forward references are handled. + + Members: + + * ``ERROR``: propagate the :exc:`NameError` when the forward reference lookup fails + * ``WARN``: emit a :class:`~.TypeHintWarning` if the forward reference lookup fails + * ``IGNORE``: silently skip checks for unresolveable forward references + """ + + ERROR = auto() + WARN = auto() + IGNORE = auto() + + +class CollectionCheckStrategy(Enum): + """ + Specifies how thoroughly the contents of collections are type checked. 
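
A short hedged example of the two strategies in practice, assuming check_type (defined in _functions.py later in this patch) and CollectionCheckStrategy are re-exported from the vendored package top level:

    from typing import List

    from metaflow._vendor.typeguard import (
        CollectionCheckStrategy,
        TypeCheckError,
        check_type,
    )

    items = [1, "two", 3]

    # The default FIRST_ITEM strategy samples only items[0], so this passes:
    check_type(items, List[int])

    # ALL_ITEMS walks the whole list and flags the stray string at index 1:
    try:
        check_type(
            items,
            List[int],
            collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
        )
    except TypeCheckError as exc:
        print(exc)  # item 1 of list is not an instance of int
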
+ + This has an effect on the following built-in checkers: + + * ``AbstractSet`` + * ``Dict`` + * ``List`` + * ``Mapping`` + * ``Set`` + * ``Tuple[, ...]`` (arbitrarily sized tuples) + + Members: + + * ``FIRST_ITEM``: check only the first item + * ``ALL_ITEMS``: check all items + """ + + FIRST_ITEM = auto() + ALL_ITEMS = auto() + + def iterate_samples(self, collection: Collection[T]) -> Collection[T]: + if self is CollectionCheckStrategy.FIRST_ITEM: + if len(collection): + return [next(iter(collection))] + else: + return () + else: + return collection + + +@dataclass +class TypeCheckConfiguration: + """ + You can change Typeguard's behavior with these settings. + + .. attribute:: typecheck_fail_callback + :type: Callable[[TypeCheckError, TypeCheckMemo], Any] + + Callable that is called when type checking fails. + + Default: ``None`` (the :exc:`~.TypeCheckError` is raised directly) + + .. attribute:: forward_ref_policy + :type: ForwardRefPolicy + + Specifies what to do when a forward reference fails to resolve. + + Default: ``WARN`` + + .. attribute:: collection_check_strategy + :type: CollectionCheckStrategy + + Specifies how thoroughly the contents of collections (list, dict, etc.) are + type checked. + + Default: ``FIRST_ITEM`` + + .. attribute:: debug_instrumentation + :type: bool + + If set to ``True``, the code of modules or functions instrumented by typeguard + is printed to ``sys.stderr`` after the instrumentation is done + + Requires Python 3.9 or newer. + + Default: ``False`` + """ + + forward_ref_policy: ForwardRefPolicy = ForwardRefPolicy.WARN + typecheck_fail_callback: TypeCheckFailCallback | None = None + collection_check_strategy: CollectionCheckStrategy = ( + CollectionCheckStrategy.FIRST_ITEM + ) + debug_instrumentation: bool = False + + +global_config = TypeCheckConfiguration() diff --git a/metaflow/_vendor/typeguard/_decorators.py b/metaflow/_vendor/typeguard/_decorators.py new file mode 100644 index 00000000000..53f254f7080 --- /dev/null +++ b/metaflow/_vendor/typeguard/_decorators.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +import ast +import inspect +import sys +from collections.abc import Sequence +from functools import partial +from inspect import isclass, isfunction +from types import CodeType, FrameType, FunctionType +from typing import TYPE_CHECKING, Any, Callable, ForwardRef, TypeVar, cast, overload +from warnings import warn + +from ._config import CollectionCheckStrategy, ForwardRefPolicy, global_config +from ._exceptions import InstrumentationWarning +from ._functions import TypeCheckFailCallback +from ._transformer import TypeguardTransformer +from ._utils import Unset, function_name, get_stacklevel, is_method_of, unset + +if TYPE_CHECKING: + from typeshed.stdlib.types import _Cell + + _F = TypeVar("_F") + + def typeguard_ignore(f: _F) -> _F: + """This decorator is a noop during static type-checking.""" + return f + +else: + from typing import no_type_check as typeguard_ignore # noqa: F401 + +T_CallableOrType = TypeVar("T_CallableOrType", bound=Callable[..., Any]) + + +def make_cell(value: object) -> _Cell: + return (lambda: value).__closure__[0] # type: ignore[index] + + +def find_target_function( + new_code: CodeType, target_path: Sequence[str], firstlineno: int +) -> CodeType | None: + target_name = target_path[0] + for const in new_code.co_consts: + if isinstance(const, CodeType): + if const.co_name == target_name: + if const.co_firstlineno == firstlineno: + return const + elif len(target_path) > 1: + target_code = find_target_function( + 
const, target_path[1:], firstlineno + ) + if target_code: + return target_code + + return None + + +def instrument(f: T_CallableOrType) -> FunctionType | str: + if not getattr(f, "__code__", None): + return "no code associated" + elif not getattr(f, "__module__", None): + return "__module__ attribute is not set" + elif f.__code__.co_filename == "": + return "cannot instrument functions defined in a REPL" + elif hasattr(f, "__wrapped__"): + return ( + "@typechecked only supports instrumenting functions wrapped with " + "@classmethod, @staticmethod or @property" + ) + + target_path = [item for item in f.__qualname__.split(".") if item != ""] + module_source = inspect.getsource(sys.modules[f.__module__]) + module_ast = ast.parse(module_source) + instrumentor = TypeguardTransformer(target_path, f.__code__.co_firstlineno) + instrumentor.visit(module_ast) + + if not instrumentor.target_node or instrumentor.target_lineno is None: + return "instrumentor did not find the target function" + + module_code = compile(module_ast, f.__code__.co_filename, "exec", dont_inherit=True) + new_code = find_target_function( + module_code, target_path, instrumentor.target_lineno + ) + if not new_code: + return "cannot find the target function in the AST" + + if global_config.debug_instrumentation and sys.version_info >= (3, 9): + # Find the matching AST node, then unparse it to source and print to stdout + print( + f"Source code of {f.__qualname__}() after instrumentation:" + "\n----------------------------------------------", + file=sys.stderr, + ) + print(ast.unparse(instrumentor.target_node), file=sys.stderr) + print( + "----------------------------------------------", + file=sys.stderr, + ) + + closure = f.__closure__ + if new_code.co_freevars != f.__code__.co_freevars: + # Create a new closure and find values for the new free variables + frame = cast(FrameType, inspect.currentframe()) + frame = cast(FrameType, frame.f_back) + frame_locals = cast(FrameType, frame.f_back).f_locals + cells: list[_Cell] = [] + for key in new_code.co_freevars: + if key in instrumentor.names_used_in_annotations: + # Find the value and make a new cell from it + value = frame_locals.get(key) or ForwardRef(key) + cells.append(make_cell(value)) + else: + # Reuse the cell from the existing closure + assert f.__closure__ + cells.append(f.__closure__[f.__code__.co_freevars.index(key)]) + + closure = tuple(cells) + + new_function = FunctionType(new_code, f.__globals__, f.__name__, closure=closure) + new_function.__module__ = f.__module__ + new_function.__name__ = f.__name__ + new_function.__qualname__ = f.__qualname__ + new_function.__annotations__ = f.__annotations__ + new_function.__doc__ = f.__doc__ + new_function.__defaults__ = f.__defaults__ + new_function.__kwdefaults__ = f.__kwdefaults__ + return new_function + + +@overload +def typechecked( + *, + forward_ref_policy: ForwardRefPolicy | Unset = unset, + typecheck_fail_callback: TypeCheckFailCallback | Unset = unset, + collection_check_strategy: CollectionCheckStrategy | Unset = unset, + debug_instrumentation: bool | Unset = unset, +) -> Callable[[T_CallableOrType], T_CallableOrType]: + ... + + +@overload +def typechecked(target: T_CallableOrType) -> T_CallableOrType: + ... 
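
Before the implementation of the decorator itself, a small usage sketch. It assumes the vendored top-level package re-exports typechecked and TypeCheckError, and it must run from a module on disk, since instrument() above refuses functions defined in a REPL; greet is a hypothetical example function.

    from metaflow._vendor.typeguard import TypeCheckError, typechecked


    @typechecked
    def greet(name: str) -> str:
        return f"Hello, {name}!"


    greet("world")  # OK: argument and return value both pass the injected checks

    try:
        greet(42)  # type: ignore[arg-type]
    except TypeCheckError as exc:
        print(exc)  # e.g.: argument "name" (int) is not an instance of str

As the docstring below notes, under python -O the decorator returns the target unchanged, so these checks disappear entirely in optimized mode.
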
+ + +def typechecked( + target: T_CallableOrType | None = None, + *, + forward_ref_policy: ForwardRefPolicy | Unset = unset, + typecheck_fail_callback: TypeCheckFailCallback | Unset = unset, + collection_check_strategy: CollectionCheckStrategy | Unset = unset, + debug_instrumentation: bool | Unset = unset, +) -> Any: + """ + Instrument the target function to perform run-time type checking. + + This decorator recompiles the target function, injecting code to type check + arguments, return values, yield values (excluding ``yield from``) and assignments to + annotated local variables. + + This can also be used as a class decorator. This will instrument all type annotated + methods, including :func:`@classmethod `, + :func:`@staticmethod `, and :class:`@property ` decorated + methods in the class. + + .. note:: When Python is run in optimized mode (``-O`` or ``-OO``, this decorator + is a no-op). This is a feature meant for selectively introducing type checking + into a code base where the checks aren't meant to be run in production. + + :param target: the function or class to enable type checking for + :param forward_ref_policy: override for + :attr:`.TypeCheckConfiguration.forward_ref_policy` + :param typecheck_fail_callback: override for + :attr:`.TypeCheckConfiguration.typecheck_fail_callback` + :param collection_check_strategy: override for + :attr:`.TypeCheckConfiguration.collection_check_strategy` + :param debug_instrumentation: override for + :attr:`.TypeCheckConfiguration.debug_instrumentation` + + """ + if target is None: + return partial( + typechecked, + forward_ref_policy=forward_ref_policy, + typecheck_fail_callback=typecheck_fail_callback, + collection_check_strategy=collection_check_strategy, + debug_instrumentation=debug_instrumentation, + ) + + if not __debug__: + return target + + if isclass(target): + for key, attr in target.__dict__.items(): + if is_method_of(attr, target): + retval = instrument(attr) + if isfunction(retval): + setattr(target, key, retval) + elif isinstance(attr, (classmethod, staticmethod)): + if is_method_of(attr.__func__, target): + retval = instrument(attr.__func__) + if isfunction(retval): + wrapper = attr.__class__(retval) + setattr(target, key, wrapper) + elif isinstance(attr, property): + kwargs: dict[str, Any] = dict(doc=attr.__doc__) + for name in ("fset", "fget", "fdel"): + property_func = kwargs[name] = getattr(attr, name) + if is_method_of(property_func, target): + retval = instrument(property_func) + if isfunction(retval): + kwargs[name] = retval + + setattr(target, key, attr.__class__(**kwargs)) + + return target + + # Find either the first Python wrapper or the actual function + wrapper_class: type[classmethod[Any, Any, Any]] | type[ + staticmethod[Any, Any] + ] | None = None + if isinstance(target, (classmethod, staticmethod)): + wrapper_class = target.__class__ + target = target.__func__ + + retval = instrument(target) + if isinstance(retval, str): + warn( + f"{retval} -- not typechecking {function_name(target)}", + InstrumentationWarning, + stacklevel=get_stacklevel(), + ) + return target + + if wrapper_class is None: + return retval + else: + return wrapper_class(retval) diff --git a/metaflow/_vendor/typeguard/_exceptions.py b/metaflow/_vendor/typeguard/_exceptions.py new file mode 100644 index 00000000000..625437a6499 --- /dev/null +++ b/metaflow/_vendor/typeguard/_exceptions.py @@ -0,0 +1,42 @@ +from collections import deque +from typing import Deque + + +class TypeHintWarning(UserWarning): + """ + A warning that is emitted when a 
type hint in string form could not be resolved to + an actual type. + """ + + +class TypeCheckWarning(UserWarning): + """Emitted by typeguard's type checkers when a type mismatch is detected.""" + + def __init__(self, message: str): + super().__init__(message) + + +class InstrumentationWarning(UserWarning): + """Emitted when there's a problem with instrumenting a function for type checks.""" + + def __init__(self, message: str): + super().__init__(message) + + +class TypeCheckError(Exception): + """ + Raised by typeguard's type checkers when a type mismatch is detected. + """ + + def __init__(self, message: str): + super().__init__(message) + self._path: Deque[str] = deque() + + def append_path_element(self, element: str) -> None: + self._path.append(element) + + def __str__(self) -> str: + if self._path: + return " of ".join(self._path) + " " + str(self.args[0]) + else: + return str(self.args[0]) diff --git a/metaflow/_vendor/typeguard/_functions.py b/metaflow/_vendor/typeguard/_functions.py new file mode 100644 index 00000000000..ec89e3a58ed --- /dev/null +++ b/metaflow/_vendor/typeguard/_functions.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +import sys +import warnings +from typing import Any, Callable, NoReturn, TypeVar, overload + +from . import _suppression +from ._checkers import BINARY_MAGIC_METHODS, check_type_internal +from ._config import ( + CollectionCheckStrategy, + ForwardRefPolicy, + TypeCheckConfiguration, +) +from ._exceptions import TypeCheckError, TypeCheckWarning +from ._memo import TypeCheckMemo +from ._utils import get_stacklevel, qualified_name + +if sys.version_info >= (3, 11): + from typing import Literal, Never, TypeAlias +else: + from metaflow._vendor.typing_extensions import Literal, Never, TypeAlias + +T = TypeVar("T") +TypeCheckFailCallback: TypeAlias = Callable[[TypeCheckError, TypeCheckMemo], Any] + + +@overload +def check_type( + value: object, + expected_type: type[T], + *, + forward_ref_policy: ForwardRefPolicy = ..., + typecheck_fail_callback: TypeCheckFailCallback | None = ..., + collection_check_strategy: CollectionCheckStrategy = ..., +) -> T: + ... + + +@overload +def check_type( + value: object, + expected_type: Any, + *, + forward_ref_policy: ForwardRefPolicy = ..., + typecheck_fail_callback: TypeCheckFailCallback | None = ..., + collection_check_strategy: CollectionCheckStrategy = ..., +) -> Any: + ... + + +def check_type( + value: object, + expected_type: Any, + *, + forward_ref_policy: ForwardRefPolicy = TypeCheckConfiguration().forward_ref_policy, + typecheck_fail_callback: (TypeCheckFailCallback | None) = ( + TypeCheckConfiguration().typecheck_fail_callback + ), + collection_check_strategy: CollectionCheckStrategy = ( + TypeCheckConfiguration().collection_check_strategy + ), +) -> Any: + """ + Ensure that ``value`` matches ``expected_type``. + + The types from the :mod:`typing` module do not support :func:`isinstance` or + :func:`issubclass` so a number of type specific checks are required. This function + knows which checker to call for which type. + + This function wraps :func:`~.check_type_internal` in the following ways: + + * Respects type checking suppression (:func:`~.suppress_type_checks`) + * Forms a :class:`~.TypeCheckMemo` from the current stack frame + * Calls the configured type check fail callback if the check fails + + Note that this function is independent of the globally shared configuration in + :data:`typeguard.config`. 
This means that usage within libraries is safe from being + affected configuration changes made by other libraries or by the integrating + application. Instead, configuration options have the same default values as their + corresponding fields in :class:`TypeCheckConfiguration`. + + :param value: value to be checked against ``expected_type`` + :param expected_type: a class or generic type instance + :param forward_ref_policy: see :attr:`TypeCheckConfiguration.forward_ref_policy` + :param typecheck_fail_callback: + see :attr`TypeCheckConfiguration.typecheck_fail_callback` + :param collection_check_strategy: + see :attr:`TypeCheckConfiguration.collection_check_strategy` + :return: ``value``, unmodified + :raises TypeCheckError: if there is a type mismatch + + """ + config = TypeCheckConfiguration( + forward_ref_policy=forward_ref_policy, + typecheck_fail_callback=typecheck_fail_callback, + collection_check_strategy=collection_check_strategy, + ) + + if _suppression.type_checks_suppressed or expected_type is Any: + return value + + frame = sys._getframe(1) + memo = TypeCheckMemo(frame.f_globals, frame.f_locals, config=config) + try: + check_type_internal(value, expected_type, memo) + except TypeCheckError as exc: + exc.append_path_element(qualified_name(value, add_class_prefix=True)) + if config.typecheck_fail_callback: + config.typecheck_fail_callback(exc, memo) + else: + raise + + return value + + +def check_argument_types( + func_name: str, + arguments: dict[str, tuple[Any, Any]], + memo: TypeCheckMemo, +) -> Literal[True]: + if _suppression.type_checks_suppressed: + return True + + for argname, (value, annotation) in arguments.items(): + if annotation is NoReturn or annotation is Never: + exc = TypeCheckError( + f"{func_name}() was declared never to be called but it was" + ) + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise exc + + try: + check_type_internal(value, annotation, memo) + except TypeCheckError as exc: + qualname = qualified_name(value, add_class_prefix=True) + exc.append_path_element(f'argument "{argname}" ({qualname})') + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise + + return True + + +def check_return_type( + func_name: str, + retval: T, + annotation: Any, + memo: TypeCheckMemo, +) -> T: + if _suppression.type_checks_suppressed: + return retval + + if annotation is NoReturn or annotation is Never: + exc = TypeCheckError(f"{func_name}() was declared never to return but it did") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise exc + + try: + check_type_internal(retval, annotation, memo) + except TypeCheckError as exc: + # Allow NotImplemented if this is a binary magic method (__eq__() et al) + if retval is NotImplemented and annotation is bool: + # This does (and cannot) not check if it's actually a method + func_name = func_name.rsplit(".", 1)[-1] + if func_name in BINARY_MAGIC_METHODS: + return retval + + qualname = qualified_name(retval, add_class_prefix=True) + exc.append_path_element(f"the return value ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise + + return retval + + +def check_send_type( + func_name: str, + sendval: T, + annotation: Any, + memo: TypeCheckMemo, +) -> T: + if _suppression.type_checks_suppressed: + return sendval + + if annotation is NoReturn or annotation is Never: + exc = TypeCheckError( + f"{func_name}() 
was declared never to be sent a value to but it was" + ) + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise exc + + try: + check_type_internal(sendval, annotation, memo) + except TypeCheckError as exc: + qualname = qualified_name(sendval, add_class_prefix=True) + exc.append_path_element(f"the value sent to generator ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise + + return sendval + + +def check_yield_type( + func_name: str, + yieldval: T, + annotation: Any, + memo: TypeCheckMemo, +) -> T: + if _suppression.type_checks_suppressed: + return yieldval + + if annotation is NoReturn or annotation is Never: + exc = TypeCheckError(f"{func_name}() was declared never to yield but it did") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise exc + + try: + check_type_internal(yieldval, annotation, memo) + except TypeCheckError as exc: + qualname = qualified_name(yieldval, add_class_prefix=True) + exc.append_path_element(f"the yielded value ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise + + return yieldval + + +def check_variable_assignment( + value: object, varname: str, annotation: Any, memo: TypeCheckMemo +) -> Any: + if _suppression.type_checks_suppressed: + return + + try: + check_type_internal(value, annotation, memo) + except TypeCheckError as exc: + qualname = qualified_name(value, add_class_prefix=True) + exc.append_path_element(f"value assigned to {varname} ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise + + return value + + +def check_multi_variable_assignment( + value: Any, targets: list[dict[str, Any]], memo: TypeCheckMemo +) -> Any: + if _suppression.type_checks_suppressed: + return + + if max(len(target) for target in targets) == 1: + iterated_values = [value] + else: + iterated_values = list(value) + + for expected_types in targets: + value_index = 0 + for ann_index, (varname, expected_type) in enumerate(expected_types.items()): + if varname.startswith("*"): + varname = varname[1:] + keys_left = len(expected_types) - 1 - ann_index + next_value_index = len(iterated_values) - keys_left + obj: object = iterated_values[value_index:next_value_index] + value_index = next_value_index + else: + obj = iterated_values[value_index] + value_index += 1 + + try: + check_type_internal(obj, expected_type, memo) + except TypeCheckError as exc: + qualname = qualified_name(obj, add_class_prefix=True) + exc.append_path_element(f"value assigned to {varname} ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise + + return iterated_values[0] if len(iterated_values) == 1 else iterated_values + + +def warn_on_error(exc: TypeCheckError, memo: TypeCheckMemo) -> None: + """ + Emit a warning on a type mismatch. + + This is intended to be used as an error handler in + :attr:`TypeCheckConfiguration.typecheck_fail_callback`. 
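
To make the intended wiring concrete, a minimal sketch, assuming warn_on_error, TypeCheckWarning, and check_type are re-exported at the vendored package top level:

    import warnings

    from metaflow._vendor.typeguard import (
        TypeCheckWarning,
        check_type,
        warn_on_error,
    )

    # Installing warn_on_error as the failure callback downgrades a mismatch
    # from a raised TypeCheckError to an emitted TypeCheckWarning:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        check_type("not an int", int, typecheck_fail_callback=warn_on_error)

    assert caught and issubclass(caught[0].category, TypeCheckWarning)
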
+ + """ + warnings.warn(TypeCheckWarning(str(exc)), stacklevel=get_stacklevel()) diff --git a/metaflow/_vendor/typeguard/_importhook.py b/metaflow/_vendor/typeguard/_importhook.py new file mode 100644 index 00000000000..11342951737 --- /dev/null +++ b/metaflow/_vendor/typeguard/_importhook.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import ast +import sys +import types +from collections.abc import Callable, Iterable +from importlib.abc import MetaPathFinder +from importlib.machinery import ModuleSpec, SourceFileLoader +from importlib.util import cache_from_source, decode_source +from inspect import isclass +from os import PathLike +from types import CodeType, ModuleType, TracebackType +from typing import Sequence, TypeVar +from unittest.mock import patch + +from ._config import global_config +from ._transformer import TypeguardTransformer + +if sys.version_info >= (3, 12): + from collections.abc import Buffer +else: + from metaflow._vendor.typing_extensions import Buffer + +if sys.version_info >= (3, 11): + from typing import ParamSpec +else: + from metaflow._vendor.typing_extensions import ParamSpec + +if sys.version_info >= (3, 10): + from importlib.metadata import PackageNotFoundError, version +else: + from metaflow._vendor.importlib_metadata import PackageNotFoundError, version + +try: + OPTIMIZATION = "typeguard" + "".join(version("typeguard").split(".")[:3]) +except PackageNotFoundError: + OPTIMIZATION = "typeguard" + +P = ParamSpec("P") +T = TypeVar("T") + + +# The name of this function is magical +def _call_with_frames_removed( + f: Callable[P, T], *args: P.args, **kwargs: P.kwargs +) -> T: + return f(*args, **kwargs) + + +def optimized_cache_from_source(path: str, debug_override: bool | None = None) -> str: + return cache_from_source(path, debug_override, optimization=OPTIMIZATION) + + +class TypeguardLoader(SourceFileLoader): + @staticmethod + def source_to_code( + data: Buffer | str | ast.Module | ast.Expression | ast.Interactive, + path: Buffer | str | PathLike[str] = "", + ) -> CodeType: + if isinstance(data, (ast.Module, ast.Expression, ast.Interactive)): + tree = data + else: + if isinstance(data, str): + source = data + else: + source = decode_source(data) + + tree = _call_with_frames_removed( + ast.parse, + source, + path, + "exec", + ) + + tree = TypeguardTransformer().visit(tree) + ast.fix_missing_locations(tree) + + if global_config.debug_instrumentation and sys.version_info >= (3, 9): + print( + f"Source code of {path!r} after instrumentation:\n" + "----------------------------------------------", + file=sys.stderr, + ) + print(ast.unparse(tree), file=sys.stderr) + print("----------------------------------------------", file=sys.stderr) + + return _call_with_frames_removed( + compile, tree, path, "exec", 0, dont_inherit=True + ) + + def exec_module(self, module: ModuleType) -> None: + # Use a custom optimization marker – the import lock should make this monkey + # patch safe + with patch( + "importlib._bootstrap_external.cache_from_source", + optimized_cache_from_source, + ): + super().exec_module(module) + + +class TypeguardFinder(MetaPathFinder): + """ + Wraps another path finder and instruments the module with + :func:`@typechecked ` if :meth:`should_instrument` returns + ``True``. + + Should not be used directly, but rather via :func:`~.install_import_hook`. + + .. 
versionadded:: 2.6 + """ + + def __init__(self, packages: list[str] | None, original_pathfinder: MetaPathFinder): + self.packages = packages + self._original_pathfinder = original_pathfinder + + def find_spec( + self, + fullname: str, + path: Sequence[str] | None, + target: types.ModuleType | None = None, + ) -> ModuleSpec | None: + if self.should_instrument(fullname): + spec = self._original_pathfinder.find_spec(fullname, path, target) + if spec is not None and isinstance(spec.loader, SourceFileLoader): + spec.loader = TypeguardLoader(spec.loader.name, spec.loader.path) + return spec + + return None + + def should_instrument(self, module_name: str) -> bool: + """ + Determine whether the module with the given name should be instrumented. + + :param module_name: full name of the module that is about to be imported (e.g. + ``xyz.abc``) + + """ + if self.packages is None: + return True + + for package in self.packages: + if module_name == package or module_name.startswith(package + "."): + return True + + return False + + +class ImportHookManager: + """ + A handle that can be used to uninstall the Typeguard import hook. + """ + + def __init__(self, hook: MetaPathFinder): + self.hook = hook + + def __enter__(self) -> None: + pass + + def __exit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: + self.uninstall() + + def uninstall(self) -> None: + """Uninstall the import hook.""" + try: + sys.meta_path.remove(self.hook) + except ValueError: + pass # already removed + + +def install_import_hook( + packages: Iterable[str] | None = None, + *, + cls: type[TypeguardFinder] = TypeguardFinder, +) -> ImportHookManager: + """ + Install an import hook that instruments functions for automatic type checking. + + This only affects modules loaded **after** this hook has been installed. + + :param packages: an iterable of package names to instrument, or ``None`` to + instrument all packages + :param cls: a custom meta path finder class + :return: a context manager that uninstalls the hook on exit (or when you call + ``.uninstall()``) + + .. versionadded:: 2.6 + + """ + if packages is None: + target_packages: list[str] | None = None + elif isinstance(packages, str): + target_packages = [packages] + else: + target_packages = list(packages) + + for finder in sys.meta_path: + if ( + isclass(finder) + and finder.__name__ == "PathFinder" + and hasattr(finder, "find_spec") + ): + break + else: + raise RuntimeError("Cannot find a PathFinder in sys.meta_path") + + hook = cls(target_packages, finder) + sys.meta_path.insert(0, hook) + return ImportHookManager(hook) diff --git a/metaflow/_vendor/typeguard/_memo.py b/metaflow/_vendor/typeguard/_memo.py new file mode 100644 index 00000000000..988a27122a9 --- /dev/null +++ b/metaflow/_vendor/typeguard/_memo.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import Any + +from metaflow._vendor.typeguard._config import TypeCheckConfiguration, global_config + + +class TypeCheckMemo: + """ + Contains information necessary for type checkers to do their work. + + .. attribute:: globals + :type: dict[str, Any] + + Dictionary of global variables to use for resolving forward references. + + .. attribute:: locals + :type: dict[str, Any] + + Dictionary of local variables to use for resolving forward references. + + .. 
attribute:: self_type + :type: type | None + + When running type checks within an instance method or class method, this is the + class object that the first argument (usually named ``self`` or ``cls``) refers + to. + + .. attribute:: config + :type: TypeCheckConfiguration + + Contains the configuration for a particular set of type checking operations. + """ + + __slots__ = "globals", "locals", "self_type", "config" + + def __init__( + self, + globals: dict[str, Any], + locals: dict[str, Any], + *, + self_type: type | None = None, + config: TypeCheckConfiguration = global_config, + ): + self.globals = globals + self.locals = locals + self.self_type = self_type + self.config = config diff --git a/metaflow/_vendor/typeguard/_pytest_plugin.py b/metaflow/_vendor/typeguard/_pytest_plugin.py new file mode 100644 index 00000000000..2e59c33be94 --- /dev/null +++ b/metaflow/_vendor/typeguard/_pytest_plugin.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import sys +import warnings + +from pytest import Config, Parser + +from metaflow._vendor.typeguard._config import CollectionCheckStrategy, ForwardRefPolicy, global_config +from metaflow._vendor.typeguard._exceptions import InstrumentationWarning +from metaflow._vendor.typeguard._importhook import install_import_hook +from metaflow._vendor.typeguard._utils import qualified_name, resolve_reference + + +def pytest_addoption(parser: Parser) -> None: + group = parser.getgroup("typeguard") + group.addoption( + "--typeguard-packages", + action="store", + help="comma separated name list of packages and modules to instrument for " + "type checking, or :all: to instrument all modules loaded after typeguard", + ) + group.addoption( + "--typeguard-debug-instrumentation", + action="store_true", + help="print all instrumented code to stderr", + ) + group.addoption( + "--typeguard-typecheck-fail-callback", + action="store", + help=( + "a module:varname (e.g. 
typeguard:warn_on_error) reference to a function " + "that is called (with the exception, and memo object as arguments) to " + "handle a TypeCheckError" + ), + ) + group.addoption( + "--typeguard-forward-ref-policy", + action="store", + choices=list(ForwardRefPolicy.__members__), + help=( + "determines how to deal with unresolveable forward references in type " + "annotations" + ), + ) + group.addoption( + "--typeguard-collection-check-strategy", + action="store", + choices=list(CollectionCheckStrategy.__members__), + help="determines how thoroughly to check collections (list, dict, etc)", + ) + + +def pytest_configure(config: Config) -> None: + packages_option = config.getoption("typeguard_packages") + if packages_option: + if packages_option == ":all:": + packages: list[str] | None = None + else: + packages = [pkg.strip() for pkg in packages_option.split(",")] + already_imported_packages = sorted( + package for package in packages if package in sys.modules + ) + if already_imported_packages: + warnings.warn( + f"typeguard cannot check these packages because they are already " + f"imported: {', '.join(already_imported_packages)}", + InstrumentationWarning, + stacklevel=1, + ) + + install_import_hook(packages=packages) + + debug_option = config.getoption("typeguard_debug_instrumentation") + if debug_option: + global_config.debug_instrumentation = True + + fail_callback_option = config.getoption("typeguard_typecheck_fail_callback") + if fail_callback_option: + callback = resolve_reference(fail_callback_option) + if not callable(callback): + raise TypeError( + f"{fail_callback_option} ({qualified_name(callback.__class__)}) is not " + f"a callable" + ) + + global_config.typecheck_fail_callback = callback + + forward_ref_policy_option = config.getoption("typeguard_forward_ref_policy") + if forward_ref_policy_option: + forward_ref_policy = ForwardRefPolicy.__members__[forward_ref_policy_option] + global_config.forward_ref_policy = forward_ref_policy + + collection_check_strategy_option = config.getoption( + "typeguard_collection_check_strategy" + ) + if collection_check_strategy_option: + collection_check_strategy = CollectionCheckStrategy.__members__[ + collection_check_strategy_option + ] + global_config.collection_check_strategy = collection_check_strategy diff --git a/metaflow/_vendor/typeguard/_suppression.py b/metaflow/_vendor/typeguard/_suppression.py new file mode 100644 index 00000000000..012c858e3ae --- /dev/null +++ b/metaflow/_vendor/typeguard/_suppression.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import sys +from collections.abc import Callable, Generator +from contextlib import contextmanager +from functools import update_wrapper +from threading import Lock +from typing import ContextManager, TypeVar, overload + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from metaflow._vendor.typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") + +type_checks_suppressed = 0 +type_checks_suppress_lock = Lock() + + +@overload +def suppress_type_checks(func: Callable[P, T]) -> Callable[P, T]: + ... + + +@overload +def suppress_type_checks() -> ContextManager[None]: + ... + + +def suppress_type_checks( + func: Callable[P, T] | None = None +) -> Callable[P, T] | ContextManager[None]: + """ + Temporarily suppress all type checking. + + This function has two operating modes, based on how it's used: + + #. as a context manager (``with suppress_type_checks(): ...``) + #. 
as a decorator (``@suppress_type_checks``) + + When used as a context manager, :func:`check_type` and any automatically + instrumented functions skip the actual type checking. These context managers can be + nested. + + When used as a decorator, all type checking is suppressed while the function is + running. + + Type checking will resume once no more context managers are active and no decorated + functions are running. + + Both operating modes are thread-safe. + + """ + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + global type_checks_suppressed + + with type_checks_suppress_lock: + type_checks_suppressed += 1 + + assert func is not None + try: + return func(*args, **kwargs) + finally: + with type_checks_suppress_lock: + type_checks_suppressed -= 1 + + def cm() -> Generator[None, None, None]: + global type_checks_suppressed + + with type_checks_suppress_lock: + type_checks_suppressed += 1 + + try: + yield + finally: + with type_checks_suppress_lock: + type_checks_suppressed -= 1 + + if func is None: + # Context manager mode + return contextmanager(cm)() + else: + # Decorator mode + update_wrapper(wrapper, func) + return wrapper diff --git a/metaflow/_vendor/typeguard/_transformer.py b/metaflow/_vendor/typeguard/_transformer.py new file mode 100644 index 00000000000..32d284e1740 --- /dev/null +++ b/metaflow/_vendor/typeguard/_transformer.py @@ -0,0 +1,1193 @@ +from __future__ import annotations + +import ast +import builtins +import sys +import typing +from ast import ( + AST, + Add, + AnnAssign, + Assign, + AsyncFunctionDef, + Attribute, + AugAssign, + BinOp, + BitAnd, + BitOr, + BitXor, + Call, + ClassDef, + Constant, + Dict, + Div, + Expr, + Expression, + FloorDiv, + FunctionDef, + If, + Import, + ImportFrom, + Index, + List, + Load, + LShift, + MatMult, + Mod, + Module, + Mult, + Name, + NodeTransformer, + NodeVisitor, + Pass, + Pow, + Return, + RShift, + Starred, + Store, + Str, + Sub, + Subscript, + Tuple, + Yield, + YieldFrom, + alias, + copy_location, + expr, + fix_missing_locations, + keyword, + walk, +) +from collections import defaultdict +from collections.abc import Generator, Sequence +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, ClassVar, cast, overload + +if sys.version_info >= (3, 8): + from ast import NamedExpr + +generator_names = ( + "typing.Generator", + "collections.abc.Generator", + "typing.Iterator", + "collections.abc.Iterator", + "typing.Iterable", + "collections.abc.Iterable", + "typing.AsyncIterator", + "collections.abc.AsyncIterator", + "typing.AsyncIterable", + "collections.abc.AsyncIterable", + "typing.AsyncGenerator", + "collections.abc.AsyncGenerator", +) +anytype_names = ( + "typing.Any", + "typing_extensions.Any", +) +literal_names = ( + "typing.Literal", + "typing_extensions.Literal", +) +annotated_names = ( + "typing.Annotated", + "typing_extensions.Annotated", +) +ignore_decorators = ( + "typing.no_type_check", + "typeguard.typeguard_ignore", +) +aug_assign_functions = { + Add: "iadd", + Sub: "isub", + Mult: "imul", + MatMult: "imatmul", + Div: "itruediv", + FloorDiv: "ifloordiv", + Mod: "imod", + Pow: "ipow", + LShift: "ilshift", + RShift: "irshift", + BitAnd: "iand", + BitXor: "ixor", + BitOr: "ior", +} + + +@dataclass +class TransformMemo: + node: Module | ClassDef | FunctionDef | AsyncFunctionDef | None + parent: TransformMemo | None + path: tuple[str, ...] 
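
Stepping back to suppress_type_checks() from _suppression.py above, a brief sketch of its two operating modes, again assuming the vendored top-level re-exports; untrusted_path is a hypothetical name:

    from metaflow._vendor.typeguard import check_type, suppress_type_checks

    # Context-manager mode: checks inside the block are skipped entirely.
    with suppress_type_checks():
        check_type("definitely not an int", int)  # does not raise


    # Decorator mode: checks are skipped while the function is running.
    @suppress_type_checks
    def untrusted_path() -> None:
        check_type(123, str)  # does not raise


    untrusted_path()
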
+ joined_path: Constant = field(init=False) + return_annotation: expr | None = None + yield_annotation: expr | None = None + send_annotation: expr | None = None + is_async: bool = False + local_names: set[str] = field(init=False, default_factory=set) + imported_names: dict[str, str] = field(init=False, default_factory=dict) + ignored_names: set[str] = field(init=False, default_factory=set) + load_names: defaultdict[str, dict[str, Name]] = field( + init=False, default_factory=lambda: defaultdict(dict) + ) + has_yield_expressions: bool = field(init=False, default=False) + has_return_expressions: bool = field(init=False, default=False) + memo_var_name: Name | None = field(init=False, default=None) + should_instrument: bool = field(init=False, default=True) + variable_annotations: dict[str, expr] = field(init=False, default_factory=dict) + configuration_overrides: dict[str, Any] = field(init=False, default_factory=dict) + code_inject_index: int = field(init=False, default=0) + + def __post_init__(self) -> None: + elements: list[str] = [] + memo = self + while isinstance(memo.node, (ClassDef, FunctionDef, AsyncFunctionDef)): + elements.insert(0, memo.node.name) + if not memo.parent: + break + + memo = memo.parent + if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)): + elements.insert(0, "") + + self.joined_path = Constant(".".join(elements)) + + # Figure out where to insert instrumentation code + if self.node: + for index, child in enumerate(self.node.body): + if isinstance(child, ImportFrom) and child.module == "__future__": + # (module only) __future__ imports must come first + continue + elif isinstance(child, Expr): + if isinstance(child.value, Constant) and isinstance( + child.value.value, str + ): + continue # docstring + elif sys.version_info < (3, 8) and isinstance(child.value, Str): + continue # docstring + + self.code_inject_index = index + break + + def get_unused_name(self, name: str) -> str: + memo: TransformMemo | None = self + while memo is not None: + if name in memo.local_names: + memo = self + name += "_" + else: + memo = memo.parent + + self.local_names.add(name) + return name + + def is_ignored_name(self, expression: expr | Expr | None) -> bool: + top_expression = ( + expression.value if isinstance(expression, Expr) else expression + ) + + if isinstance(top_expression, Attribute) and isinstance( + top_expression.value, Name + ): + name = top_expression.value.id + elif isinstance(top_expression, Name): + name = top_expression.id + else: + return False + + memo: TransformMemo | None = self + while memo is not None: + if name in memo.ignored_names: + return True + + memo = memo.parent + + return False + + def get_memo_name(self) -> Name: + if not self.memo_var_name: + self.memo_var_name = Name(id="memo", ctx=Load()) + + return self.memo_var_name + + def get_import(self, module: str, name: str) -> Name: + if module in self.load_names and name in self.load_names[module]: + return self.load_names[module][name] + + qualified_name = f"{module}.{name}" + if name in self.imported_names and self.imported_names[name] == qualified_name: + return Name(id=name, ctx=Load()) + + alias = self.get_unused_name(name) + node = self.load_names[module][name] = Name(id=alias, ctx=Load()) + self.imported_names[name] = qualified_name + return node + + def insert_imports(self, node: Module | FunctionDef | AsyncFunctionDef) -> None: + """Insert imports needed by injected code.""" + if not self.load_names: + return + + # Insert imports after any "from __future__ ..." 
imports and any docstring + for modulename, names in self.load_names.items(): + aliases = [ + alias(orig_name, new_name.id if orig_name != new_name.id else None) + for orig_name, new_name in sorted(names.items()) + ] + node.body.insert(self.code_inject_index, ImportFrom(modulename, aliases, 0)) + + def name_matches(self, expression: expr | Expr | None, *names: str) -> bool: + if expression is None: + return False + + path: list[str] = [] + top_expression = ( + expression.value if isinstance(expression, Expr) else expression + ) + + if isinstance(top_expression, Subscript): + top_expression = top_expression.value + elif isinstance(top_expression, Call): + top_expression = top_expression.func + + while isinstance(top_expression, Attribute): + path.insert(0, top_expression.attr) + top_expression = top_expression.value + + if not isinstance(top_expression, Name): + return False + + if top_expression.id in self.imported_names: + translated = self.imported_names[top_expression.id] + elif hasattr(builtins, top_expression.id): + translated = "builtins." + top_expression.id + else: + translated = top_expression.id + + path.insert(0, translated) + joined_path = ".".join(path) + if joined_path in names: + return True + elif self.parent: + return self.parent.name_matches(expression, *names) + else: + return False + + def get_config_keywords(self) -> list[keyword]: + if self.parent and isinstance(self.parent.node, ClassDef): + overrides = self.parent.configuration_overrides.copy() + else: + overrides = {} + + overrides.update(self.configuration_overrides) + return [keyword(key, value) for key, value in overrides.items()] + + +class NameCollector(NodeVisitor): + def __init__(self) -> None: + self.names: set[str] = set() + + def visit_Import(self, node: Import) -> None: + for name in node.names: + self.names.add(name.asname or name.name) + + def visit_ImportFrom(self, node: ImportFrom) -> None: + for name in node.names: + self.names.add(name.asname or name.name) + + def visit_Assign(self, node: Assign) -> None: + for target in node.targets: + if isinstance(target, Name): + self.names.add(target.id) + + def visit_NamedExpr(self, node: NamedExpr) -> Any: + if isinstance(node.target, Name): + self.names.add(node.target.id) + + def visit_FunctionDef(self, node: FunctionDef) -> None: + pass + + def visit_ClassDef(self, node: ClassDef) -> None: + pass + + +class GeneratorDetector(NodeVisitor): + """Detects if a function node is a generator function.""" + + contains_yields: bool = False + in_root_function: bool = False + + def visit_Yield(self, node: Yield) -> Any: + self.contains_yields = True + + def visit_YieldFrom(self, node: YieldFrom) -> Any: + self.contains_yields = True + + def visit_ClassDef(self, node: ClassDef) -> Any: + pass + + def visit_FunctionDef(self, node: FunctionDef | AsyncFunctionDef) -> Any: + if not self.in_root_function: + self.in_root_function = True + self.generic_visit(node) + self.in_root_function = False + + def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> Any: + self.visit_FunctionDef(node) + + +class AnnotationTransformer(NodeTransformer): + type_substitutions: ClassVar[dict[str, tuple[str, str]]] = { + "builtins.dict": ("typing", "Dict"), + "builtins.list": ("typing", "List"), + "builtins.tuple": ("typing", "Tuple"), + "builtins.set": ("typing", "Set"), + "builtins.frozenset": ("typing", "FrozenSet"), + } + + def __init__(self, transformer: TypeguardTransformer): + self.transformer = transformer + self._memo = transformer._memo + + def visit(self, node: AST) -> Any: + 
new_node = super().visit(node) + if isinstance(new_node, Expression) and not hasattr(new_node, "body"): + return None + + # Return None if this new node matches a variation of typing.Any + if isinstance(new_node, expr) and self._memo.name_matches( + new_node, *anytype_names + ): + return None + + return new_node + + def visit_BinOp(self, node: BinOp) -> Any: + self.generic_visit(node) + + if isinstance(node.op, BitOr): + # If either side of the operation resolved to None, return None + if not hasattr(node, "left") or not hasattr(node, "right"): + return None + + if sys.version_info < (3, 10): + union_name = self.transformer._get_import("typing", "Union") + return Subscript( + value=union_name, + slice=Index( + Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load() + ), + ctx=Load(), + ) + + return node + + def visit_Attribute(self, node: Attribute) -> Any: + if self._memo.is_ignored_name(node): + return None + + return node + + def visit_Subscript(self, node: Subscript) -> Any: + if self._memo.is_ignored_name(node.value): + return None + + # The subscript of typing(_extensions).Literal can be any arbitrary string, so + # don't try to evaluate it as code + if not self._memo.name_matches(node.value, *literal_names) and node.slice: + if isinstance(node.slice, Index): + # Python 3.7 and 3.8 + slice_value = node.slice.value # type: ignore[attr-defined] + else: + slice_value = node.slice + + if isinstance(slice_value, Tuple): + if self._memo.name_matches(node.value, *annotated_names): + # Only treat the first argument to typing.Annotated as a potential + # forward reference + items = cast( + typing.List[expr], + [self.generic_visit(slice_value.elts[0])] + + slice_value.elts[1:], + ) + else: + items = cast( + typing.List[expr], + [self.generic_visit(item) for item in slice_value.elts], + ) + + # If this is a Union and any of the items is Any, erase the entire + # annotation + if self._memo.name_matches(node.value, "typing.Union") and any( + isinstance(item, expr) + and self._memo.name_matches(item, *anytype_names) + for item in items + ): + return None + + # If all items in the subscript were Any, erase the subscript entirely + if all(item is None for item in items): + return node.value + + for index, item in enumerate(items): + if item is None: + items[index] = self.transformer._get_import("typing", "Any") + + slice_value.elts = items + else: + self.generic_visit(node) + + # If the transformer erased the slice entirely, just return the node + # value without the subscript (unless it's Optional, in which case erase + # the node entirely + if self._memo.name_matches(node.value, "typing.Optional"): + return None + elif sys.version_info >= (3, 9) and not hasattr(node, "slice"): + return node.value + elif sys.version_info < (3, 9) and not hasattr(node.slice, "value"): + return node.value + + return node + + def visit_Name(self, node: Name) -> Any: + if self._memo.is_ignored_name(node): + return None + + if sys.version_info < (3, 9): + for typename, substitute in self.type_substitutions.items(): + if self._memo.name_matches(node, typename): + new_node = self.transformer._get_import(*substitute) + return copy_location(new_node, node) + + return node + + def visit_Call(self, node: Call) -> Any: + # Don't recurse into calls + return node + + def visit_Constant(self, node: Constant) -> Any: + if isinstance(node.value, str): + expression = ast.parse(node.value, mode="eval") + new_node = self.visit(expression) + if new_node: + return copy_location(new_node.body, node) + else: + return None + + return 
node + + def visit_Str(self, node: Str) -> Any: + # Only used on Python 3.7 + expression = ast.parse(node.s, mode="eval") + new_node = self.visit(expression) + if new_node: + return copy_location(new_node.body, node) + else: + return None + + +class TypeguardTransformer(NodeTransformer): + def __init__( + self, target_path: Sequence[str] | None = None, target_lineno: int | None = None + ) -> None: + self._target_path = tuple(target_path) if target_path else None + self._memo = self._module_memo = TransformMemo(None, None, ()) + self.names_used_in_annotations: set[str] = set() + self.target_node: FunctionDef | AsyncFunctionDef | None = None + self.target_lineno = target_lineno + + @contextmanager + def _use_memo( + self, node: ClassDef | FunctionDef | AsyncFunctionDef + ) -> Generator[None, Any, None]: + new_memo = TransformMemo(node, self._memo, self._memo.path + (node.name,)) + if isinstance(node, (FunctionDef, AsyncFunctionDef)): + new_memo.should_instrument = ( + self._target_path is None or new_memo.path == self._target_path + ) + if new_memo.should_instrument: + # Check if the function is a generator function + detector = GeneratorDetector() + detector.visit(node) + + # Extract yield, send and return types where possible from a subscripted + # annotation like Generator[int, str, bool] + return_annotation = deepcopy(node.returns) + if detector.contains_yields and new_memo.name_matches( + return_annotation, *generator_names + ): + if isinstance(return_annotation, Subscript): + annotation_slice = return_annotation.slice + + # Python < 3.9 + if isinstance(annotation_slice, Index): + annotation_slice = ( + annotation_slice.value # type: ignore[attr-defined] + ) + + if isinstance(annotation_slice, Tuple): + items = annotation_slice.elts + else: + items = [annotation_slice] + + if len(items) > 0: + new_memo.yield_annotation = self._convert_annotation( + items[0] + ) + + if len(items) > 1: + new_memo.send_annotation = self._convert_annotation( + items[1] + ) + + if len(items) > 2: + new_memo.return_annotation = self._convert_annotation( + items[2] + ) + else: + new_memo.return_annotation = self._convert_annotation( + return_annotation + ) + + if isinstance(node, AsyncFunctionDef): + new_memo.is_async = True + + old_memo = self._memo + self._memo = new_memo + yield + self._memo = old_memo + + def _get_import(self, module: str, name: str) -> Name: + memo = self._memo if self._target_path else self._module_memo + return memo.get_import(module, name) + + @overload + def _convert_annotation(self, annotation: None) -> None: + ... + + @overload + def _convert_annotation(self, annotation: expr) -> expr: + ... 
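+    # The @overload stubs above only narrow the signature for static type
+    # checkers; the implementation below performs the actual conversion (a
+    # None annotation is simply passed through).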
+ + def _convert_annotation(self, annotation: expr | None) -> expr | None: + if annotation is None: + return None + + # Convert PEP 604 unions (x | y) and generic built-in collections where + # necessary, and undo forward references + new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation)) + if isinstance(new_annotation, expr): + new_annotation = ast.copy_location(new_annotation, annotation) + + # Store names used in the annotation + names = {node.id for node in walk(new_annotation) if isinstance(node, Name)} + self.names_used_in_annotations.update(names) + + return new_annotation + + def visit_Name(self, node: Name) -> Name: + self._memo.local_names.add(node.id) + return node + + def visit_Module(self, node: Module) -> Module: + self.generic_visit(node) + self._memo.insert_imports(node) + + fix_missing_locations(node) + return node + + def visit_Import(self, node: Import) -> Import: + for name in node.names: + self._memo.local_names.add(name.asname or name.name) + self._memo.imported_names[name.asname or name.name] = name.name + + return node + + def visit_ImportFrom(self, node: ImportFrom) -> ImportFrom: + for name in node.names: + if name.name != "*": + alias = name.asname or name.name + self._memo.local_names.add(alias) + self._memo.imported_names[alias] = f"{node.module}.{name.name}" + + return node + + def visit_ClassDef(self, node: ClassDef) -> ClassDef | None: + self._memo.local_names.add(node.name) + + # Eliminate top level classes not belonging to the target path + if ( + self._target_path is not None + and not self._memo.path + and node.name != self._target_path[0] + ): + return None + + with self._use_memo(node): + for decorator in node.decorator_list.copy(): + if self._memo.name_matches(decorator, "typeguard.typechecked"): + # Remove the decorator to prevent duplicate instrumentation + node.decorator_list.remove(decorator) + + # Store any configuration overrides + if isinstance(decorator, Call) and decorator.keywords: + self._memo.configuration_overrides.update( + {kw.arg: kw.value for kw in decorator.keywords if kw.arg} + ) + + self.generic_visit(node) + return node + + def visit_FunctionDef( + self, node: FunctionDef | AsyncFunctionDef + ) -> FunctionDef | AsyncFunctionDef | None: + """ + Injects type checks for function arguments, and for a return of None if the + function is annotated to return something else than Any or None, and the body + ends without an explicit "return". 
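+        The argument checks themselves are injected as a single
+        check_argument_types() call at the start of the function body.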
+
+        """
+        self._memo.local_names.add(node.name)
+
+        # Eliminate top level functions not belonging to the target path
+        if (
+            self._target_path is not None
+            and not self._memo.path
+            and node.name != self._target_path[0]
+        ):
+            return None
+
+        # Skip instrumentation if we're instrumenting the whole module and the function
+        # contains either @no_type_check or @typeguard_ignore
+        if self._target_path is None:
+            for decorator in node.decorator_list:
+                if self._memo.name_matches(decorator, *ignore_decorators):
+                    return node
+
+        with self._use_memo(node):
+            arg_annotations: dict[str, Any] = {}
+            if self._target_path is None or self._memo.path == self._target_path:
+                # Find line number we're supposed to match against
+                if node.decorator_list:
+                    first_lineno = node.decorator_list[0].lineno
+                else:
+                    first_lineno = node.lineno
+
+                for decorator in node.decorator_list.copy():
+                    if self._memo.name_matches(decorator, "typing.overload"):
+                        # Remove overloads entirely
+                        return None
+                    elif self._memo.name_matches(decorator, "typeguard.typechecked"):
+                        # Remove the decorator to prevent duplicate instrumentation
+                        node.decorator_list.remove(decorator)
+
+                        # Store any configuration overrides
+                        if isinstance(decorator, Call) and decorator.keywords:
+                            self._memo.configuration_overrides = {
+                                kw.arg: kw.value for kw in decorator.keywords if kw.arg
+                            }
+
+                if self.target_lineno == first_lineno:
+                    assert self.target_node is None
+                    self.target_node = node
+                    if node.decorator_list and sys.version_info >= (3, 8):
+                        self.target_lineno = node.decorator_list[0].lineno
+                    else:
+                        self.target_lineno = node.lineno
+
+                all_args = node.args.args + node.args.kwonlyargs
+                if sys.version_info >= (3, 8):
+                    all_args.extend(node.args.posonlyargs)
+
+                # Ensure that any type shadowed by the positional or keyword-only
+                # argument names are ignored in this function
+                for arg in all_args:
+                    self._memo.ignored_names.add(arg.arg)
+
+                # Ensure that any type shadowed by the variable positional argument name
+                # (e.g. "args" in *args) is ignored in this function
+                if node.args.vararg:
+                    self._memo.ignored_names.add(node.args.vararg.arg)
+
+                # Ensure that any type shadowed by the variable keyword argument name
+                # (e.g. "kwargs" in **kwargs) is ignored in this function
+                if node.args.kwarg:
+                    self._memo.ignored_names.add(node.args.kwarg.arg)
+
+                for arg in all_args:
+                    annotation = self._convert_annotation(deepcopy(arg.annotation))
+                    if annotation:
+                        arg_annotations[arg.arg] = annotation
+
+                if node.args.vararg:
+                    annotation_ = self._convert_annotation(node.args.vararg.annotation)
+                    if annotation_:
+                        if sys.version_info >= (3, 9):
+                            container = Name("tuple", ctx=Load())
+                        else:
+                            container = self._get_import("typing", "Tuple")
+
+                        subscript_slice: Tuple | Index = Tuple(
+                            [
+                                annotation_,
+                                Constant(Ellipsis),
+                            ],
+                            ctx=Load(),
+                        )
+                        if sys.version_info < (3, 9):
+                            subscript_slice = Index(subscript_slice, ctx=Load())
+
+                        arg_annotations[node.args.vararg.arg] = Subscript(
+                            container, subscript_slice, ctx=Load()
+                        )
+
+                if node.args.kwarg:
+                    annotation_ = self._convert_annotation(node.args.kwarg.annotation)
+                    if annotation_:
+                        if sys.version_info >= (3, 9):
+                            container = Name("dict", ctx=Load())
+                        else:
+                            container = self._get_import("typing", "Dict")
+
+                        subscript_slice = Tuple(
+                            [
+                                Name("str", ctx=Load()),
+                                annotation_,
+                            ],
+                            ctx=Load(),
+                        )
+                        if sys.version_info < (3, 9):
+                            subscript_slice = Index(subscript_slice, ctx=Load())
+
+                        arg_annotations[node.args.kwarg.arg] = Subscript(
+                            container, subscript_slice, ctx=Load()
+                        )
+
+                if arg_annotations:
+                    self._memo.variable_annotations.update(arg_annotations)
+
+            self.generic_visit(node)
+
+            if arg_annotations:
+                annotations_dict = Dict(
+                    keys=[Constant(key) for key in arg_annotations.keys()],
+                    values=[
+                        Tuple([Name(key, ctx=Load()), annotation], ctx=Load())
+                        for key, annotation in arg_annotations.items()
+                    ],
+                )
+                func_name = self._get_import(
+                    "typeguard._functions", "check_argument_types"
+                )
+                args = [
+                    self._memo.joined_path,
+                    annotations_dict,
+                    self._memo.get_memo_name(),
+                ]
+                node.body.insert(
+                    self._memo.code_inject_index, Expr(Call(func_name, args, []))
+                )
+
+            # Add a checked "return None" to the end if there's no explicit return
+            # Skip if the return annotation is None or Any
+            if (
+                self._memo.return_annotation
+                and (not self._memo.is_async or not self._memo.has_yield_expressions)
+                and not isinstance(node.body[-1], Return)
+                and (
+                    not isinstance(self._memo.return_annotation, Constant)
+                    or self._memo.return_annotation.value is not None
+                )
+            ):
+                func_name = self._get_import(
+                    "typeguard._functions", "check_return_type"
+                )
+                return_node = Return(
+                    Call(
+                        func_name,
+                        [
+                            self._memo.joined_path,
+                            Constant(None),
+                            self._memo.return_annotation,
+                            self._memo.get_memo_name(),
+                        ],
+                        [],
+                    )
+                )
+
+                # Replace a placeholder "pass" at the end
+                if isinstance(node.body[-1], Pass):
+                    copy_location(return_node, node.body[-1])
+                    del node.body[-1]
+
+                node.body.append(return_node)
+
+            # Insert code to create the call memo, if it was ever needed for this
+            # function
+            if self._memo.memo_var_name:
+                memo_kwargs: dict[str, Any] = {}
+                if self._memo.parent and isinstance(self._memo.parent.node, ClassDef):
+                    for decorator in node.decorator_list:
+                        if (
+                            isinstance(decorator, Name)
+                            and decorator.id == "staticmethod"
+                        ):
+                            break
+                        elif (
+                            isinstance(decorator, Name)
+                            and decorator.id == "classmethod"
+                        ):
+                            memo_kwargs["self_type"] = Name(
+                                id=node.args.args[0].arg, ctx=Load()
+                            )
+                            break
+                    else:
+                        if node.args.args:
+                            if node.name == "__new__":
+                                memo_kwargs["self_type"] = Name(
+                                    id=node.args.args[0].arg, ctx=Load()
+                                )
+                            else:
+                                memo_kwargs["self_type"] = Attribute(
+                                    Name(id=node.args.args[0].arg, ctx=Load()),
+                                    "__class__",
+                                    ctx=Load(),
+                                )
+
+                # Construct the function reference
+                # Nested functions get special treatment: the function name is added
+                # to free variables (and the closure of the resulting function)
+                names: list[str] = [node.name]
+                memo = self._memo.parent
+                while memo:
+                    if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
+                        # This is a nested function. Use the function name as-is.
+                        del names[:-1]
+                        break
+                    elif not isinstance(memo.node, ClassDef):
+                        break
+
+                    names.insert(0, memo.node.name)
+                    memo = memo.parent
+
+                config_keywords = self._memo.get_config_keywords()
+                if config_keywords:
+                    memo_kwargs["config"] = Call(
+                        self._get_import("dataclasses", "replace"),
+                        [self._get_import("typeguard._config", "global_config")],
+                        config_keywords,
+                    )
+
+                self._memo.memo_var_name.id = self._memo.get_unused_name("memo")
+                memo_store_name = Name(id=self._memo.memo_var_name.id, ctx=Store())
+                globals_call = Call(Name(id="globals", ctx=Load()), [], [])
+                locals_call = Call(Name(id="locals", ctx=Load()), [], [])
+                memo_expr = Call(
+                    self._get_import("typeguard", "TypeCheckMemo"),
+                    [globals_call, locals_call],
+                    [keyword(key, value) for key, value in memo_kwargs.items()],
+                )
+                node.body.insert(
+                    self._memo.code_inject_index,
+                    Assign([memo_store_name], memo_expr),
+                )
+
+            self._memo.insert_imports(node)
+
+            # Remove any placeholder "pass" at the end
+            if isinstance(node.body[-1], Pass):
+                del node.body[-1]
+
+        return node
+
+    def visit_AsyncFunctionDef(
+        self, node: AsyncFunctionDef
+    ) -> FunctionDef | AsyncFunctionDef | None:
+        return self.visit_FunctionDef(node)
+
+    def visit_Return(self, node: Return) -> Return:
+        """This injects type checks into "return" statements."""
+        self.generic_visit(node)
+        if (
+            self._memo.return_annotation
+            and self._memo.should_instrument
+            and not self._memo.is_ignored_name(self._memo.return_annotation)
+        ):
+            func_name = self._get_import("typeguard._functions", "check_return_type")
+            old_node = node
+            retval = old_node.value or Constant(None)
+            node = Return(
+                Call(
+                    func_name,
+                    [
+                        self._memo.joined_path,
+                        retval,
+                        self._memo.return_annotation,
+                        self._memo.get_memo_name(),
+                    ],
+                    [],
+                )
+            )
+            copy_location(node, old_node)
+
+        return node
+
+    def visit_Yield(self, node: Yield) -> Yield | Call:
+        """
+        This injects type checks into "yield" expressions, checking both the yielded
+        value and the value sent back to the generator, when appropriate.
+
+        """
+        self._memo.has_yield_expressions = True
+        self.generic_visit(node)
+
+        if (
+            self._memo.yield_annotation
+            and self._memo.should_instrument
+            and not self._memo.is_ignored_name(self._memo.yield_annotation)
+        ):
+            func_name = self._get_import("typeguard._functions", "check_yield_type")
+            yieldval = node.value or Constant(None)
+            node.value = Call(
+                func_name,
+                [
+                    self._memo.joined_path,
+                    yieldval,
+                    self._memo.yield_annotation,
+                    self._memo.get_memo_name(),
+                ],
+                [],
+            )
+
+        if (
+            self._memo.send_annotation
+            and self._memo.should_instrument
+            and not self._memo.is_ignored_name(self._memo.send_annotation)
+        ):
+            func_name = self._get_import("typeguard._functions", "check_send_type")
+            old_node = node
+            call_node = Call(
+                func_name,
+                [
+                    self._memo.joined_path,
+                    old_node,
+                    self._memo.send_annotation,
+                    self._memo.get_memo_name(),
+                ],
+                [],
+            )
+            copy_location(call_node, old_node)
+            return call_node
+
+        return node
+
+    def visit_AnnAssign(self, node: AnnAssign) -> Any:
+        """
+        This injects a type check into a local variable annotation-assignment within a
+        function body.
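+        For example, "x: int = foo()" is rewritten (roughly) to
+        "x: int = check_variable_assignment(foo(), 'x', int, memo)".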
+ + """ + self.generic_visit(node) + + if ( + isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) + and node.annotation + and isinstance(node.target, Name) + ): + self._memo.ignored_names.add(node.target.id) + annotation = self._convert_annotation(deepcopy(node.annotation)) + if annotation: + self._memo.variable_annotations[node.target.id] = annotation + if node.value: + func_name = self._get_import( + "typeguard._functions", "check_variable_assignment" + ) + node.value = Call( + func_name, + [ + node.value, + Constant(node.target.id), + annotation, + self._memo.get_memo_name(), + ], + [], + ) + + return node + + def visit_Assign(self, node: Assign) -> Any: + """ + This injects a type check into a local variable assignment within a function + body. The variable must have been annotated earlier in the function body. + + """ + self.generic_visit(node) + + # Only instrument function-local assignments + if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)): + targets: list[dict[Constant, expr | None]] = [] + check_required = False + for target in node.targets: + elts: Sequence[expr] + if isinstance(target, Name): + elts = [target] + elif isinstance(target, Tuple): + elts = target.elts + else: + continue + + annotations_: dict[Constant, expr | None] = {} + for exp in elts: + prefix = "" + if isinstance(exp, Starred): + exp = exp.value + prefix = "*" + + if isinstance(exp, Name): + self._memo.ignored_names.add(exp.id) + name = prefix + exp.id + annotation = self._memo.variable_annotations.get(exp.id) + if annotation: + annotations_[Constant(name)] = annotation + check_required = True + else: + annotations_[Constant(name)] = None + + targets.append(annotations_) + + if check_required: + # Replace missing annotations with typing.Any + for item in targets: + for key, expression in item.items(): + if expression is None: + item[key] = self._get_import("typing", "Any") + + if len(targets) == 1 and len(targets[0]) == 1: + func_name = self._get_import( + "typeguard._functions", "check_variable_assignment" + ) + target_varname = next(iter(targets[0])) + node.value = Call( + func_name, + [ + node.value, + target_varname, + targets[0][target_varname], + self._memo.get_memo_name(), + ], + [], + ) + elif targets: + func_name = self._get_import( + "typeguard._functions", "check_multi_variable_assignment" + ) + targets_arg = List( + [ + Dict(keys=list(target), values=list(target.values())) + for target in targets + ], + ctx=Load(), + ) + node.value = Call( + func_name, + [node.value, targets_arg, self._memo.get_memo_name()], + [], + ) + + return node + + def visit_NamedExpr(self, node: NamedExpr) -> Any: + """This injects a type check into an assignment expression (a := foo()).""" + self.generic_visit(node) + + # Only instrument function-local assignments + if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) and isinstance( + node.target, Name + ): + self._memo.ignored_names.add(node.target.id) + + # Bail out if no matching annotation is found + annotation = self._memo.variable_annotations.get(node.target.id) + if annotation is None: + return node + + func_name = self._get_import( + "typeguard._functions", "check_variable_assignment" + ) + node.value = Call( + func_name, + [ + node.value, + Constant(node.target.id), + annotation, + self._memo.get_memo_name(), + ], + [], + ) + + return node + + def visit_AugAssign(self, node: AugAssign) -> Any: + """ + This injects a type check into an augmented assignment expression (a += 1). 
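+        For example, "a += 1" is rewritten (roughly) to
+        "a = check_variable_assignment(iadd(a, 1), 'a', <annotation>, memo)",
+        where the in-place operator function is looked up in aug_assign_functions.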
+ + """ + self.generic_visit(node) + + # Only instrument function-local assignments + if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) and isinstance( + node.target, Name + ): + # Bail out if no matching annotation is found + annotation = self._memo.variable_annotations.get(node.target.id) + if annotation is None: + return node + + # Bail out if the operator is not found (newer Python version?) + try: + operator_func_name = aug_assign_functions[node.op.__class__] + except KeyError: + return node + + operator_func = self._get_import("operator", operator_func_name) + operator_call = Call( + operator_func, [Name(node.target.id, ctx=Load()), node.value], [] + ) + check_call = Call( + self._get_import("typeguard._functions", "check_variable_assignment"), + [ + operator_call, + Constant(node.target.id), + annotation, + self._memo.get_memo_name(), + ], + [], + ) + return Assign(targets=[node.target], value=check_call) + + return node + + def visit_If(self, node: If) -> Any: + """ + This blocks names from being collected from a module-level + "if typing.TYPE_CHECKING:" block, so that they won't be type checked. + + """ + self.generic_visit(node) + + # Fix empty node body (caused by removal of classes/functions not on the target + # path) + if not node.body: + node.body.append(Pass()) + + if ( + self._memo is self._module_memo + and isinstance(node.test, Name) + and self._memo.name_matches(node.test, "typing.TYPE_CHECKING") + ): + collector = NameCollector() + collector.visit(node) + self._memo.ignored_names.update(collector.names) + + return node diff --git a/metaflow/_vendor/typeguard/_union_transformer.py b/metaflow/_vendor/typeguard/_union_transformer.py new file mode 100644 index 00000000000..fcd6349d35a --- /dev/null +++ b/metaflow/_vendor/typeguard/_union_transformer.py @@ -0,0 +1,54 @@ +""" +Transforms lazily evaluated PEP 604 unions into typing.Unions, for compatibility with +Python versions older than 3.10. 
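+For example, the hint "str | int | None" is compiled to code that evaluates
+Union[Union[str, int], None] (which typing flattens to Union[str, int, None]).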
+""" +from __future__ import annotations + +from ast import ( + BinOp, + BitOr, + Index, + Load, + Name, + NodeTransformer, + Subscript, + fix_missing_locations, + parse, +) +from ast import Tuple as ASTTuple +from types import CodeType +from typing import Any, Dict, FrozenSet, List, Set, Tuple, Union + +type_substitutions = { + "dict": Dict, + "list": List, + "tuple": Tuple, + "set": Set, + "frozenset": FrozenSet, + "Union": Union, +} + + +class UnionTransformer(NodeTransformer): + def __init__(self, union_name: Name | None = None): + self.union_name = union_name or Name(id="Union", ctx=Load()) + + def visit_BinOp(self, node: BinOp) -> Any: + self.generic_visit(node) + if isinstance(node.op, BitOr): + return Subscript( + value=self.union_name, + slice=Index( + ASTTuple(elts=[node.left, node.right], ctx=Load()), ctx=Load() + ), + ctx=Load(), + ) + + return node + + +def compile_type_hint(hint: str) -> CodeType: + parsed = parse(hint, "", "eval") + UnionTransformer().visit(parsed) + fix_missing_locations(parsed) + return compile(parsed, "", "eval", flags=0) diff --git a/metaflow/_vendor/typeguard/_utils.py b/metaflow/_vendor/typeguard/_utils.py new file mode 100644 index 00000000000..2470c85757a --- /dev/null +++ b/metaflow/_vendor/typeguard/_utils.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import inspect +import sys +from importlib import import_module +from inspect import currentframe +from types import CodeType, FrameType, FunctionType +from typing import TYPE_CHECKING, Any, Callable, ForwardRef, Union, cast +from weakref import WeakValueDictionary + +if TYPE_CHECKING: + from ._memo import TypeCheckMemo + +if sys.version_info >= (3, 10): + from typing import get_args, get_origin + + def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any: + return forwardref._evaluate(memo.globals, memo.locals, frozenset()) + +else: + from metaflow._vendor.typing_extensions import get_args, get_origin + + evaluate_extra_args: tuple[frozenset[Any], ...] = ( + (frozenset(),) if sys.version_info >= (3, 9) else () + ) + + def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any: + from ._union_transformer import compile_type_hint, type_substitutions + + if not forwardref.__forward_evaluated__: + forwardref.__forward_code__ = compile_type_hint(forwardref.__forward_arg__) + + try: + return forwardref._evaluate(memo.globals, memo.locals, *evaluate_extra_args) + except NameError: + if sys.version_info < (3, 10): + # Try again, with the type substitutions (list -> List etc.) 
in place + new_globals = memo.globals.copy() + new_globals.setdefault("Union", Union) + if sys.version_info < (3, 9): + new_globals.update(type_substitutions) + + return forwardref._evaluate( + new_globals, memo.locals or new_globals, *evaluate_extra_args + ) + + raise + + +if sys.version_info >= (3, 8): + from typing import final +else: + from metaflow._vendor.typing_extensions import final + + +_functions_map: WeakValueDictionary[CodeType, FunctionType] = WeakValueDictionary() + + +def get_type_name(type_: Any) -> str: + name: str + for attrname in "__name__", "_name", "__forward_arg__": + candidate = getattr(type_, attrname, None) + if isinstance(candidate, str): + name = candidate + break + else: + origin = get_origin(type_) + candidate = getattr(origin, "_name", None) + if candidate is None: + candidate = type_.__class__.__name__.strip("_") + + if isinstance(candidate, str): + name = candidate + else: + return "(unknown)" + + args = get_args(type_) + if args: + if name == "Literal": + formatted_args = ", ".join(repr(arg) for arg in args) + else: + formatted_args = ", ".join(get_type_name(arg) for arg in args) + + name += f"[{formatted_args}]" + + module = getattr(type_, "__module__", None) + if module and module not in (None, "typing", "typing_extensions", "builtins"): + name = module + "." + name + + return name + + +def qualified_name(obj: Any, *, add_class_prefix: bool = False) -> str: + """ + Return the qualified name (e.g. package.module.Type) for the given object. + + Builtins and types from the :mod:`typing` package get special treatment by having + the module name stripped from the generated name. + + """ + if obj is None: + return "None" + elif inspect.isclass(obj): + prefix = "class " if add_class_prefix else "" + type_ = obj + else: + prefix = "" + type_ = type(obj) + + module = type_.__module__ + qualname = type_.__qualname__ + name = qualname if module in ("typing", "builtins") else f"{module}.{qualname}" + return prefix + name + + +def function_name(func: Callable[..., Any]) -> str: + """ + Return the qualified name of the given function. + + Builtins and types from the :mod:`typing` package get special treatment by having + the module name stripped from the generated name. 
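+    For example, function_name(len) returns "len", while a function "f" defined
+    in module "m" is reported as "m.f".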
+ + """ + # For partial functions and objects with __call__ defined, __qualname__ does not + # exist + module = getattr(func, "__module__", "") + qualname = (module + ".") if module not in ("builtins", "") else "" + return qualname + getattr(func, "__qualname__", repr(func)) + + +def resolve_reference(reference: str) -> Any: + modulename, varname = reference.partition(":")[::2] + if not modulename or not varname: + raise ValueError(f"{reference!r} is not a module:varname reference") + + obj = import_module(modulename) + for attr in varname.split("."): + obj = getattr(obj, attr) + + return obj + + +def is_method_of(obj: object, cls: type) -> bool: + return ( + inspect.isfunction(obj) + and obj.__module__ == cls.__module__ + and obj.__qualname__.startswith(cls.__qualname__ + ".") + ) + + +def get_stacklevel() -> int: + level = 1 + frame = cast(FrameType, currentframe()).f_back + while frame and frame.f_globals.get("__name__", "").startswith("typeguard."): + level += 1 + frame = frame.f_back + + return level + + +@final +class Unset: + __slots__ = () + + def __repr__(self) -> str: + return "" + + +unset = Unset() diff --git a/metaflow/_vendor/typeguard/py.typed b/metaflow/_vendor/typeguard/py.typed new file mode 100644 index 00000000000..e69de29bb2d diff --git a/metaflow/_vendor/typing_extensions.LICENSE b/metaflow/_vendor/typing_extensions.LICENSE new file mode 100644 index 00000000000..f26bcf4d2de --- /dev/null +++ b/metaflow/_vendor/typing_extensions.LICENSE @@ -0,0 +1,279 @@ +A. HISTORY OF THE SOFTWARE +========================== + +Python was created in the early 1990s by Guido van Rossum at Stichting +Mathematisch Centrum (CWI, see https://www.cwi.nl) in the Netherlands +as a successor of a language called ABC. Guido remains Python's +principal author, although it includes many contributions from others. + +In 1995, Guido continued his work on Python at the Corporation for +National Research Initiatives (CNRI, see https://www.cnri.reston.va.us) +in Reston, Virginia where he released several versions of the +software. + +In May 2000, Guido and the Python core development team moved to +BeOpen.com to form the BeOpen PythonLabs team. In October of the same +year, the PythonLabs team moved to Digital Creations, which became +Zope Corporation. In 2001, the Python Software Foundation (PSF, see +https://www.python.org/psf/) was formed, a non-profit organization +created specifically to own Python-related Intellectual Property. +Zope Corporation was a sponsoring member of the PSF. + +All Python releases are Open Source (see https://opensource.org for +the Open Source Definition). Historically, most, but not all, Python +releases have also been GPL-compatible; the table below summarizes +the various releases. + + Release Derived Year Owner GPL- + from compatible? (1) + + 0.9.0 thru 1.2 1991-1995 CWI yes + 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes + 1.6 1.5.2 2000 CNRI no + 2.0 1.6 2000 BeOpen.com no + 1.6.1 1.6 2001 CNRI yes (2) + 2.1 2.0+1.6.1 2001 PSF no + 2.0.1 2.0+1.6.1 2001 PSF yes + 2.1.1 2.1+2.0.1 2001 PSF yes + 2.1.2 2.1.1 2002 PSF yes + 2.1.3 2.1.2 2002 PSF yes + 2.2 and above 2.1.1 2001-now PSF yes + +Footnotes: + +(1) GPL-compatible doesn't mean that we're distributing Python under + the GPL. All Python licenses, unlike the GPL, let you distribute + a modified version without making your changes open source. The + GPL-compatible licenses make it possible to combine Python with + other software that is released under the GPL; the others don't. 
+ +(2) According to Richard Stallman, 1.6.1 is not GPL-compatible, + because its license has a choice of law clause. According to + CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 + is "not incompatible" with the GPL. + +Thanks to the many outside volunteers who have worked under Guido's +direction to make these releases possible. + + +B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +=============================================================== + +Python software and documentation are licensed under the +Python Software Foundation License Version 2. + +Starting with Python 3.8.6, examples, recipes, and other code in +the documentation are dual licensed under the PSF License Version 2 +and the Zero-Clause BSD license. + +Some software incorporated into Python is under different licenses. +The licenses are listed with code falling under that license. + + +PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +-------------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023 Python Software Foundation; +All Rights Reserved" are retained in Python alone or in any derivative version +prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. 
+ + +BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +------------------------------------------- + +BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 + +1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +Individual or Organization ("Licensee") accessing and otherwise using +this software in source or binary form and its associated +documentation ("the Software"). + +2. Subject to the terms and conditions of this BeOpen Python License +Agreement, BeOpen hereby grants Licensee a non-exclusive, +royalty-free, world-wide license to reproduce, analyze, test, perform +and/or display publicly, prepare derivative works, distribute, and +otherwise use the Software alone or in any derivative version, +provided, however, that the BeOpen Python License is retained in the +Software, alone or in any derivative version prepared by Licensee. + +3. BeOpen is making the Software available to Licensee on an "AS IS" +basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +5. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +6. This License Agreement shall be governed by and interpreted in all +respects by the law of the State of California, excluding conflict of +law provisions. Nothing in this License Agreement shall be deemed to +create any relationship of agency, partnership, or joint venture +between BeOpen and Licensee. This License Agreement does not grant +permission to use BeOpen trademarks or trade names in a trademark +sense to endorse or promote products or services of Licensee, or any +third party. As an exception, the "BeOpen Python" logos available at +http://www.pythonlabs.com/logos.html may be used according to the +permissions granted on that web page. + +7. By copying, installing or otherwise using the software, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. + + +CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +--------------------------------------- + +1. This LICENSE AGREEMENT is between the Corporation for National +Research Initiatives, having an office at 1895 Preston White Drive, +Reston, VA 20191 ("CNRI"), and the Individual or Organization +("Licensee") accessing and otherwise using Python 1.6.1 software in +source or binary form and its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, CNRI +hereby grants Licensee a nonexclusive, royalty-free, world-wide +license to reproduce, analyze, test, perform and/or display publicly, +prepare derivative works, distribute, and otherwise use Python 1.6.1 +alone or in any derivative version, provided, however, that CNRI's +License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +1995-2001 Corporation for National Research Initiatives; All Rights +Reserved" are retained in Python 1.6.1 alone or in any derivative +version prepared by Licensee. 
Alternately, in lieu of CNRI's License +Agreement, Licensee may substitute the following text (omitting the +quotes): "Python 1.6.1 is made available subject to the terms and +conditions in CNRI's License Agreement. This Agreement together with +Python 1.6.1 may be located on the internet using the following +unique, persistent identifier (known as a handle): 1895.22/1013. This +Agreement may also be obtained from a proxy server on the internet +using the following URL: http://hdl.handle.net/1895.22/1013". + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python 1.6.1 or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python 1.6.1. + +4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. This License Agreement shall be governed by the federal +intellectual property law of the United States, including without +limitation the federal copyright law, and, to the extent such +U.S. federal law does not apply, by the law of the Commonwealth of +Virginia, excluding Virginia's conflict of law provisions. +Notwithstanding the foregoing, with regard to derivative works based +on Python 1.6.1 that incorporate non-separable material that was +previously distributed under the GNU General Public License (GPL), the +law of the Commonwealth of Virginia shall govern this License +Agreement only as to issues arising under or with respect to +Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +License Agreement shall be deemed to create any relationship of +agency, partnership, or joint venture between CNRI and Licensee. This +License Agreement does not grant permission to use CNRI trademarks or +trade name in a trademark sense to endorse or promote products or +services of Licensee, or any third party. + +8. By clicking on the "ACCEPT" button where indicated, or by copying, +installing or otherwise using Python 1.6.1, Licensee agrees to be +bound by the terms and conditions of this License Agreement. + + ACCEPT + + +CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +-------------------------------------------------- + +Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +The Netherlands. All rights reserved. 
+ +Permission to use, copy, modify, and distribute this software and its +documentation for any purpose and without fee is hereby granted, +provided that the above copyright notice appear in all copies and that +both that copyright notice and this permission notice appear in +supporting documentation, and that the name of Stichting Mathematisch +Centrum or CWI not be used in advertising or publicity pertaining to +distribution of the software without specific, written prior +permission. + +STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +ZERO-CLAUSE BSD LICENSE FOR CODE IN THE PYTHON DOCUMENTATION +---------------------------------------------------------------------- + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +PERFORMANCE OF THIS SOFTWARE. diff --git a/metaflow/_vendor/typing_extensions.py b/metaflow/_vendor/typing_extensions.py new file mode 100644 index 00000000000..85f5950eaf4 --- /dev/null +++ b/metaflow/_vendor/typing_extensions.py @@ -0,0 +1,3053 @@ +import abc +import collections +import collections.abc +import functools +import inspect +import operator +import sys +import types as _types +import typing +import warnings + +__all__ = [ + # Super-special typing primitives. + 'Any', + 'ClassVar', + 'Concatenate', + 'Final', + 'LiteralString', + 'ParamSpec', + 'ParamSpecArgs', + 'ParamSpecKwargs', + 'Self', + 'Type', + 'TypeVar', + 'TypeVarTuple', + 'Unpack', + + # ABCs (from collections.abc). + 'Awaitable', + 'AsyncIterator', + 'AsyncIterable', + 'Coroutine', + 'AsyncGenerator', + 'AsyncContextManager', + 'Buffer', + 'ChainMap', + + # Concrete collection types. + 'ContextManager', + 'Counter', + 'Deque', + 'DefaultDict', + 'NamedTuple', + 'OrderedDict', + 'TypedDict', + + # Structural checks, a.k.a. protocols. + 'SupportsAbs', + 'SupportsBytes', + 'SupportsComplex', + 'SupportsFloat', + 'SupportsIndex', + 'SupportsInt', + 'SupportsRound', + + # One-off things. 
+ 'Annotated', + 'assert_never', + 'assert_type', + 'clear_overloads', + 'dataclass_transform', + 'deprecated', + 'get_overloads', + 'final', + 'get_args', + 'get_origin', + 'get_original_bases', + 'get_protocol_members', + 'get_type_hints', + 'IntVar', + 'is_protocol', + 'is_typeddict', + 'Literal', + 'NewType', + 'overload', + 'override', + 'Protocol', + 'reveal_type', + 'runtime', + 'runtime_checkable', + 'Text', + 'TypeAlias', + 'TypeAliasType', + 'TypeGuard', + 'TYPE_CHECKING', + 'Never', + 'NoReturn', + 'Required', + 'NotRequired', + + # Pure aliases, have always been in typing + 'AbstractSet', + 'AnyStr', + 'BinaryIO', + 'Callable', + 'Collection', + 'Container', + 'Dict', + 'ForwardRef', + 'FrozenSet', + 'Generator', + 'Generic', + 'Hashable', + 'IO', + 'ItemsView', + 'Iterable', + 'Iterator', + 'KeysView', + 'List', + 'Mapping', + 'MappingView', + 'Match', + 'MutableMapping', + 'MutableSequence', + 'MutableSet', + 'Optional', + 'Pattern', + 'Reversible', + 'Sequence', + 'Set', + 'Sized', + 'TextIO', + 'Tuple', + 'Union', + 'ValuesView', + 'cast', + 'no_type_check', + 'no_type_check_decorator', +] + +# for backward compatibility +PEP_560 = True +GenericMeta = type + +# The functions below are modified copies of typing internal helpers. +# They are needed by _ProtocolMeta and they provide support for PEP 646. + + +class _Sentinel: + def __repr__(self): + return "" + + +_marker = _Sentinel() + + +def _check_generic(cls, parameters, elen=_marker): + """Check correct count for parameters of a generic cls (internal helper). + This gives a nice error message in case of count mismatch. + """ + if not elen: + raise TypeError(f"{cls} is not a generic class") + if elen is _marker: + if not hasattr(cls, "__parameters__") or not cls.__parameters__: + raise TypeError(f"{cls} is not a generic class") + elen = len(cls.__parameters__) + alen = len(parameters) + if alen != elen: + if hasattr(cls, "__parameters__"): + parameters = [p for p in cls.__parameters__ if not _is_unpack(p)] + num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters) + if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples): + return + raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};" + f" actual {alen}, expected {elen}") + + +if sys.version_info >= (3, 10): + def _should_collect_from_parameters(t): + return isinstance( + t, (typing._GenericAlias, _types.GenericAlias, _types.UnionType) + ) +elif sys.version_info >= (3, 9): + def _should_collect_from_parameters(t): + return isinstance(t, (typing._GenericAlias, _types.GenericAlias)) +else: + def _should_collect_from_parameters(t): + return isinstance(t, typing._GenericAlias) and not t._special + + +def _collect_type_vars(types, typevar_types=None): + """Collect all type variable contained in types in order of + first appearance (lexicographic order). For example:: + + _collect_type_vars((T, List[S, T])) == (T, S) + """ + if typevar_types is None: + typevar_types = typing.TypeVar + tvars = [] + for t in types: + if ( + isinstance(t, typevar_types) and + t not in tvars and + not _is_unpack(t) + ): + tvars.append(t) + if _should_collect_from_parameters(t): + tvars.extend([t for t in t.__parameters__ if t not in tvars]) + return tuple(tvars) + + +NoReturn = typing.NoReturn + +# Some unconstrained type variables. These are used by the container types. +# (These are not for export.) +T = typing.TypeVar('T') # Any type. +KT = typing.TypeVar('KT') # Key type. +VT = typing.TypeVar('VT') # Value type. 
+T_co = typing.TypeVar('T_co', covariant=True) # Any type covariant containers. +T_contra = typing.TypeVar('T_contra', contravariant=True) # Ditto contravariant. + + +if sys.version_info >= (3, 11): + from typing import Any +else: + + class _AnyMeta(type): + def __instancecheck__(self, obj): + if self is Any: + raise TypeError("typing_extensions.Any cannot be used with isinstance()") + return super().__instancecheck__(obj) + + def __repr__(self): + if self is Any: + return "typing_extensions.Any" + return super().__repr__() + + class Any(metaclass=_AnyMeta): + """Special type indicating an unconstrained type. + - Any is compatible with every type. + - Any assumed to have all methods. + - All values assumed to be instances of Any. + Note that all the above statements are true from the point of view of + static type checkers. At runtime, Any should not be used with instance + checks. + """ + def __new__(cls, *args, **kwargs): + if cls is Any: + raise TypeError("Any cannot be instantiated") + return super().__new__(cls, *args, **kwargs) + + +ClassVar = typing.ClassVar + + +class _ExtensionsSpecialForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + +# On older versions of typing there is an internal class named "Final". +# 3.8+ +if hasattr(typing, 'Final') and sys.version_info[:2] >= (3, 7): + Final = typing.Final +# 3.7 +else: + class _FinalForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + Final = _FinalForm('Final', + doc="""A special typing construct to indicate that a name + cannot be re-assigned or overridden in a subclass. + For example: + + MAX_SIZE: Final = 9000 + MAX_SIZE += 1 # Error reported by type checker + + class Connection: + TIMEOUT: Final[int] = 10 + class FastConnector(Connection): + TIMEOUT = 1 # Error reported by type checker + + There is no runtime checking of these properties.""") + +if sys.version_info >= (3, 11): + final = typing.final +else: + # @final exists in 3.8+, but we backport it for all versions + # before 3.11 to keep support for the __final__ attribute. + # See https://bugs.python.org/issue46342 + def final(f): + """This decorator can be used to indicate to type checkers that + the decorated method cannot be overridden, and decorated class + cannot be subclassed. For example: + + class Base: + @final + def done(self) -> None: + ... + class Sub(Base): + def done(self) -> None: # Error reported by type checker + ... + @final + class Leaf: + ... + class Other(Leaf): # Error reported by type checker + ... + + There is no runtime checking of these properties. The decorator + sets the ``__final__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + """ + try: + f.__final__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. 
+ pass + return f + + +def IntVar(name): + return typing.TypeVar(name) + + +# A Literal bug was fixed in 3.11.0, 3.10.1 and 3.9.8 +if sys.version_info >= (3, 10, 1): + Literal = typing.Literal +else: + def _flatten_literal_params(parameters): + """An internal helper for Literal creation: flatten Literals among parameters""" + params = [] + for p in parameters: + if isinstance(p, _LiteralGenericAlias): + params.extend(p.__args__) + else: + params.append(p) + return tuple(params) + + def _value_and_type_iter(params): + for p in params: + yield p, type(p) + + class _LiteralGenericAlias(typing._GenericAlias, _root=True): + def __eq__(self, other): + if not isinstance(other, _LiteralGenericAlias): + return NotImplemented + these_args_deduped = set(_value_and_type_iter(self.__args__)) + other_args_deduped = set(_value_and_type_iter(other.__args__)) + return these_args_deduped == other_args_deduped + + def __hash__(self): + return hash(frozenset(_value_and_type_iter(self.__args__))) + + class _LiteralForm(_ExtensionsSpecialForm, _root=True): + def __init__(self, doc: str): + self._name = 'Literal' + self._doc = self.__doc__ = doc + + def __getitem__(self, parameters): + if not isinstance(parameters, tuple): + parameters = (parameters,) + + parameters = _flatten_literal_params(parameters) + + val_type_pairs = list(_value_and_type_iter(parameters)) + try: + deduped_pairs = set(val_type_pairs) + except TypeError: + # unhashable parameters + pass + else: + # similar logic to typing._deduplicate on Python 3.9+ + if len(deduped_pairs) < len(val_type_pairs): + new_parameters = [] + for pair in val_type_pairs: + if pair in deduped_pairs: + new_parameters.append(pair[0]) + deduped_pairs.remove(pair) + assert not deduped_pairs, deduped_pairs + parameters = tuple(new_parameters) + + return _LiteralGenericAlias(self, parameters) + + Literal = _LiteralForm(doc="""\ + A type that can be used to indicate to type checkers + that the corresponding value has a value literally equivalent + to the provided parameter. For example: + + var: Literal[4] = 4 + + The type checker understands that 'var' is literally equal to + the value 4 and no other value. + + Literal[...] cannot be subclassed. There is no runtime + checking verifying that the parameter is actually a value + instead of a type.""") + + +_overload_dummy = typing._overload_dummy + + +if hasattr(typing, "get_overloads"): # 3.11+ + overload = typing.overload + get_overloads = typing.get_overloads + clear_overloads = typing.clear_overloads +else: + # {module: {qualname: {firstlineno: func}}} + _overload_registry = collections.defaultdict( + functools.partial(collections.defaultdict, dict) + ) + + def overload(func): + """Decorator for overloaded functions/methods. + + In a stub file, place two or more stub definitions for the same + function in a row, each decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + + In a non-stub file (i.e. a regular .py file), do the same but + follow it with an implementation. The implementation should *not* + be decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + def utf8(value): + # implementation goes here + + The overloads for a function can be retrieved at runtime using the + get_overloads() function. 
+ """ + # classmethod and staticmethod + f = getattr(func, "__func__", func) + try: + _overload_registry[f.__module__][f.__qualname__][ + f.__code__.co_firstlineno + ] = func + except AttributeError: + # Not a normal function; ignore. + pass + return _overload_dummy + + def get_overloads(func): + """Return all defined overloads for *func* as a sequence.""" + # classmethod and staticmethod + f = getattr(func, "__func__", func) + if f.__module__ not in _overload_registry: + return [] + mod_dict = _overload_registry[f.__module__] + if f.__qualname__ not in mod_dict: + return [] + return list(mod_dict[f.__qualname__].values()) + + def clear_overloads(): + """Clear all overloads in the registry.""" + _overload_registry.clear() + + +# This is not a real generic class. Don't use outside annotations. +Type = typing.Type + +# Various ABCs mimicking those in collections.abc. +# A few are simply re-exported for completeness. + + +Awaitable = typing.Awaitable +Coroutine = typing.Coroutine +AsyncIterable = typing.AsyncIterable +AsyncIterator = typing.AsyncIterator +Deque = typing.Deque +ContextManager = typing.ContextManager +AsyncContextManager = typing.AsyncContextManager +DefaultDict = typing.DefaultDict + +# 3.7.2+ +if hasattr(typing, 'OrderedDict'): + OrderedDict = typing.OrderedDict +# 3.7.0-3.7.2 +else: + OrderedDict = typing._alias(collections.OrderedDict, (KT, VT)) + +Counter = typing.Counter +ChainMap = typing.ChainMap +AsyncGenerator = typing.AsyncGenerator +Text = typing.Text +TYPE_CHECKING = typing.TYPE_CHECKING + + +_PROTO_ALLOWLIST = { + 'collections.abc': [ + 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', 'Buffer', + ], + 'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'], + 'typing_extensions': ['Buffer'], +} + + +_EXCLUDED_ATTRS = { + "__abstractmethods__", "__annotations__", "__weakref__", "_is_protocol", + "_is_runtime_protocol", "__dict__", "__slots__", "__parameters__", + "__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__", + "__subclasshook__", "__orig_class__", "__init__", "__new__", + "__protocol_attrs__", "__callable_proto_members_only__", +} + +if sys.version_info < (3, 8): + _EXCLUDED_ATTRS |= { + "_gorg", "__next_in_mro__", "__extra__", "__tree_hash__", "__args__", + "__origin__" + } + +if sys.version_info >= (3, 9): + _EXCLUDED_ATTRS.add("__class_getitem__") + +if sys.version_info >= (3, 12): + _EXCLUDED_ATTRS.add("__type_params__") + +_EXCLUDED_ATTRS = frozenset(_EXCLUDED_ATTRS) + + +def _get_protocol_attrs(cls): + attrs = set() + for base in cls.__mro__[:-1]: # without object + if base.__name__ in {'Protocol', 'Generic'}: + continue + annotations = getattr(base, '__annotations__', {}) + for attr in (*base.__dict__, *annotations): + if (not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRS): + attrs.add(attr) + return attrs + + +def _maybe_adjust_parameters(cls): + """Helper function used in Protocol.__init_subclass__ and _TypedDictMeta.__new__. + + The contents of this function are very similar + to logic found in typing.Generic.__init_subclass__ + on the CPython main branch. + """ + tvars = [] + if '__orig_bases__' in cls.__dict__: + tvars = _collect_type_vars(cls.__orig_bases__) + # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. + # If found, tvars must be a subset of it. + # If not found, tvars is it. + # Also check for and reject plain Generic, + # and reject multiple Generic[...] and/or Protocol[...]. 
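+    # (For instance, "class C(Generic[T], Protocol[T])" is rejected below,
+    # since two generic bases would each supply a type variable list.)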
+        gvars = None
+        for base in cls.__orig_bases__:
+            if (isinstance(base, typing._GenericAlias) and
+                    base.__origin__ in (typing.Generic, Protocol)):
+                # for error messages
+                the_base = base.__origin__.__name__
+                if gvars is not None:
+                    raise TypeError(
+                        "Cannot inherit from Generic[...]"
+                        " and/or Protocol[...] multiple times.")
+                gvars = base.__parameters__
+        if gvars is None:
+            gvars = tvars
+        else:
+            tvarset = set(tvars)
+            gvarset = set(gvars)
+            if not tvarset <= gvarset:
+                s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
+                s_args = ', '.join(str(g) for g in gvars)
+                raise TypeError(f"Some type variables ({s_vars}) are"
+                                f" not listed in {the_base}[{s_args}]")
+            tvars = gvars
+    cls.__parameters__ = tuple(tvars)
+
+
+def _caller(depth=2):
+    try:
+        return sys._getframe(depth).f_globals.get('__name__', '__main__')
+    except (AttributeError, ValueError):  # For platforms without _getframe()
+        return None
+
+
+# The performance of runtime-checkable protocols is significantly improved on Python 3.12,
+# so we backport the 3.12 version of Protocol to Python <=3.11
+if sys.version_info >= (3, 12):
+    Protocol = typing.Protocol
+else:
+    def _allow_reckless_class_checks(depth=3):
+        """Allow instance and class checks for special stdlib modules.
+        The abc and functools modules indiscriminately call isinstance() and
+        issubclass() on the whole MRO of a user class, which may contain protocols.
+        """
+        return _caller(depth) in {'abc', 'functools', None}
+
+    def _no_init(self, *args, **kwargs):
+        if type(self)._is_protocol:
+            raise TypeError('Protocols cannot be instantiated')
+
+    if sys.version_info >= (3, 8):
+        # Inheriting from typing._ProtocolMeta isn't actually desirable,
+        # but is necessary to allow typing.Protocol and typing_extensions.Protocol
+        # to mix without getting TypeErrors about "metaclass conflict"
+        _typing_Protocol = typing.Protocol
+        _ProtocolMetaBase = type(_typing_Protocol)
+    else:
+        _typing_Protocol = _marker
+        _ProtocolMetaBase = abc.ABCMeta
+
+    class _ProtocolMeta(_ProtocolMetaBase):
+        # This metaclass is somewhat unfortunate,
+        # but is necessary for several reasons...
+        #
+        # NOTE: DO NOT call super() in any methods in this class
+        # That would call the methods on typing._ProtocolMeta on Python 3.8-3.11
+        # and those are slow
+        def __new__(mcls, name, bases, namespace, **kwargs):
+            if name == "Protocol" and len(bases) < 2:
+                pass
+            elif {Protocol, _typing_Protocol} & set(bases):
+                for base in bases:
+                    if not (
+                        base in {object, typing.Generic, Protocol, _typing_Protocol}
+                        or base.__name__ in _PROTO_ALLOWLIST.get(base.__module__, [])
+                        or is_protocol(base)
+                    ):
+                        raise TypeError(
+                            f"Protocols can only inherit from other protocols, "
+                            f"got {base!r}"
+                        )
+            return abc.ABCMeta.__new__(mcls, name, bases, namespace, **kwargs)
+
+        def __init__(cls, *args, **kwargs):
+            abc.ABCMeta.__init__(cls, *args, **kwargs)
+            if getattr(cls, "_is_protocol", False):
+                cls.__protocol_attrs__ = _get_protocol_attrs(cls)
+                # PEP 544 prohibits using issubclass()
+                # with protocols that have non-method members.
+                cls.__callable_proto_members_only__ = all(
+                    callable(getattr(cls, attr, None)) for attr in cls.__protocol_attrs__
+                )
+
+        def __subclasscheck__(cls, other):
+            if cls is Protocol:
+                return type.__subclasscheck__(cls, other)
+            if (
+                getattr(cls, '_is_protocol', False)
+                and not _allow_reckless_class_checks()
+            ):
+                if not isinstance(other, type):
+                    # Same error message as for issubclass(1, int).
+                    raise TypeError('issubclass() arg 1 must be a class')
+                if (
+                    not cls.__callable_proto_members_only__
+                    and cls.__dict__.get("__subclasshook__") is _proto_hook
+                ):
+                    raise TypeError(
+                        "Protocols with non-method members don't support issubclass()"
+                    )
+                if not getattr(cls, '_is_runtime_protocol', False):
+                    raise TypeError(
+                        "Instance and class checks can only be used with "
+                        "@runtime_checkable protocols"
+                    )
+            return abc.ABCMeta.__subclasscheck__(cls, other)
+
+        def __instancecheck__(cls, instance):
+            # We need this method for situations where attributes are
+            # assigned in __init__.
+            if cls is Protocol:
+                return type.__instancecheck__(cls, instance)
+            if not getattr(cls, "_is_protocol", False):
+                # i.e., it's a concrete subclass of a protocol
+                return abc.ABCMeta.__instancecheck__(cls, instance)
+
+            if (
+                not getattr(cls, '_is_runtime_protocol', False) and
+                not _allow_reckless_class_checks()
+            ):
+                raise TypeError("Instance and class checks can only be used with"
+                                " @runtime_checkable protocols")
+
+            if abc.ABCMeta.__instancecheck__(cls, instance):
+                return True
+
+            for attr in cls.__protocol_attrs__:
+                try:
+                    val = inspect.getattr_static(instance, attr)
+                except AttributeError:
+                    break
+                if val is None and callable(getattr(cls, attr, None)):
+                    break
+            else:
+                return True
+
+            return False
+
+        def __eq__(cls, other):
+            # Hack so that typing.Generic.__class_getitem__
+            # treats typing_extensions.Protocol
+            # as equivalent to typing.Protocol on Python 3.8+
+            if abc.ABCMeta.__eq__(cls, other) is True:
+                return True
+            return (
+                cls is Protocol and other is getattr(typing, "Protocol", object())
+            )
+
+        # This has to be defined, or the abc-module cache
+        # complains about classes with this metaclass being unhashable,
+        # if we define only __eq__!
+        def __hash__(cls) -> int:
+            return type.__hash__(cls)
+
+    @classmethod
+    def _proto_hook(cls, other):
+        if not cls.__dict__.get('_is_protocol', False):
+            return NotImplemented
+
+        for attr in cls.__protocol_attrs__:
+            for base in other.__mro__:
+                # Check if the member appears in the class dictionary...
+                if attr in base.__dict__:
+                    if base.__dict__[attr] is None:
+                        return NotImplemented
+                    break
+
+                # ...or in annotations, if it is a sub-protocol.
+                annotations = getattr(base, '__annotations__', {})
+                if (
+                    isinstance(annotations, collections.abc.Mapping)
+                    and attr in annotations
+                    and is_protocol(other)
+                ):
+                    break
+            else:
+                return NotImplemented
+        return True
+
+    if sys.version_info >= (3, 8):
+        class Protocol(typing.Generic, metaclass=_ProtocolMeta):
+            __doc__ = typing.Protocol.__doc__
+            __slots__ = ()
+            _is_protocol = True
+            _is_runtime_protocol = False
+
+            def __init_subclass__(cls, *args, **kwargs):
+                super().__init_subclass__(*args, **kwargs)
+
+                # Determine if this is a protocol or a concrete subclass.
+                if not cls.__dict__.get('_is_protocol', False):
+                    cls._is_protocol = any(b is Protocol for b in cls.__bases__)
+
+                # Set (or override) the protocol subclass hook.
+                if '__subclasshook__' not in cls.__dict__:
+                    cls.__subclasshook__ = _proto_hook
+
+                # Prohibit instantiation for protocol classes
+                if cls._is_protocol and cls.__init__ is Protocol.__init__:
+                    cls.__init__ = _no_init
+
+    else:
+        class Protocol(metaclass=_ProtocolMeta):
+            # There is quite a lot of overlapping code with typing.Generic.
+            # Unfortunately it is hard to avoid this on Python <3.8,
+            # as the typing module on Python 3.7 doesn't let us subclass typing.Generic!
+            """Base class for protocol classes.
Protocol classes are defined as:: + + class Proto(Protocol): + def meth(self) -> int: + ... + + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self) -> int: + return 0 + + def func(x: Proto) -> int: + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with + @typing_extensions.runtime_checkable act + as simple-minded runtime-checkable protocols that check + only the presence of given attributes, ignoring their type signatures. + + Protocol classes can be generic, they are defined as:: + + class GenProto(Protocol[T]): + def meth(self) -> T: + ... + """ + __slots__ = () + _is_protocol = True + _is_runtime_protocol = False + + def __new__(cls, *args, **kwds): + if cls is Protocol: + raise TypeError("Type Protocol cannot be instantiated; " + "it can only be used as a base class") + return super().__new__(cls) + + @typing._tp_cache + def __class_getitem__(cls, params): + if not isinstance(params, tuple): + params = (params,) + if not params and cls is not typing.Tuple: + raise TypeError( + f"Parameter list to {cls.__qualname__}[...] cannot be empty") + msg = "Parameters to generic types must be types." + params = tuple(typing._type_check(p, msg) for p in params) + if cls is Protocol: + # Generic can only be subscripted with unique type variables. + if not all(isinstance(p, typing.TypeVar) for p in params): + i = 0 + while isinstance(params[i], typing.TypeVar): + i += 1 + raise TypeError( + "Parameters to Protocol[...] must all be type variables." + f" Parameter {i + 1} is {params[i]}") + if len(set(params)) != len(params): + raise TypeError( + "Parameters to Protocol[...] must all be unique") + else: + # Subscripting a regular Generic subclass. + _check_generic(cls, params, len(cls.__parameters__)) + return typing._GenericAlias(cls, params) + + def __init_subclass__(cls, *args, **kwargs): + if '__orig_bases__' in cls.__dict__: + error = typing.Generic in cls.__orig_bases__ + else: + error = typing.Generic in cls.__bases__ + if error: + raise TypeError("Cannot inherit from plain Generic") + _maybe_adjust_parameters(cls) + + # Determine if this is a protocol or a concrete subclass. + if not cls.__dict__.get('_is_protocol', None): + cls._is_protocol = any(b is Protocol for b in cls.__bases__) + + # Set (or override) the protocol subclass hook. + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = _proto_hook + + # Prohibit instantiation for protocol classes + if cls._is_protocol and cls.__init__ is Protocol.__init__: + cls.__init__ = _no_init + + +if sys.version_info >= (3, 8): + runtime_checkable = typing.runtime_checkable +else: + def runtime_checkable(cls): + """Mark a protocol class as a runtime protocol, so that it + can be used with isinstance() and issubclass(). Raise TypeError + if applied to a non-protocol class. + + This allows a simple-minded structural check very similar to the + one-offs in collections.abc such as Hashable. + """ + if not ( + (isinstance(cls, _ProtocolMeta) or issubclass(cls, typing.Generic)) + and getattr(cls, "_is_protocol", False) + ): + raise TypeError('@runtime_checkable can be only applied to protocol classes,' + f' got {cls!r}') + cls._is_runtime_protocol = True + return cls + + +# Exists for backwards compatibility. 
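+# A minimal sketch of a runtime-checkable protocol in use (illustrative
+# only; ``SupportsClose`` and ``Resource`` are hypothetical names, not part
+# of this module)::
+#
+#     @runtime_checkable
+#     class SupportsClose(Protocol):
+#         def close(self) -> None: ...
+#
+#     class Resource:
+#         def close(self) -> None:
+#             pass
+#
+#     assert isinstance(Resource(), SupportsClose)
+#
+# The ``runtime`` alias below predates the ``runtime_checkable`` name.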
+runtime = runtime_checkable + + +# Our version of runtime-checkable protocols is faster on Python 3.7-3.11 +if sys.version_info >= (3, 12): + SupportsInt = typing.SupportsInt + SupportsFloat = typing.SupportsFloat + SupportsComplex = typing.SupportsComplex + SupportsBytes = typing.SupportsBytes + SupportsIndex = typing.SupportsIndex + SupportsAbs = typing.SupportsAbs + SupportsRound = typing.SupportsRound +else: + @runtime_checkable + class SupportsInt(Protocol): + """An ABC with one abstract method __int__.""" + __slots__ = () + + @abc.abstractmethod + def __int__(self) -> int: + pass + + @runtime_checkable + class SupportsFloat(Protocol): + """An ABC with one abstract method __float__.""" + __slots__ = () + + @abc.abstractmethod + def __float__(self) -> float: + pass + + @runtime_checkable + class SupportsComplex(Protocol): + """An ABC with one abstract method __complex__.""" + __slots__ = () + + @abc.abstractmethod + def __complex__(self) -> complex: + pass + + @runtime_checkable + class SupportsBytes(Protocol): + """An ABC with one abstract method __bytes__.""" + __slots__ = () + + @abc.abstractmethod + def __bytes__(self) -> bytes: + pass + + @runtime_checkable + class SupportsIndex(Protocol): + __slots__ = () + + @abc.abstractmethod + def __index__(self) -> int: + pass + + @runtime_checkable + class SupportsAbs(Protocol[T_co]): + """ + An ABC with one abstract method __abs__ that is covariant in its return type. + """ + __slots__ = () + + @abc.abstractmethod + def __abs__(self) -> T_co: + pass + + @runtime_checkable + class SupportsRound(Protocol[T_co]): + """ + An ABC with one abstract method __round__ that is covariant in its return type. + """ + __slots__ = () + + @abc.abstractmethod + def __round__(self, ndigits: int = 0) -> T_co: + pass + + +if sys.version_info >= (3, 13): + # The standard library TypedDict in Python 3.8 does not store runtime information + # about which (if any) keys are optional. See https://bugs.python.org/issue38834 + # The standard library TypedDict in Python 3.9.0/1 does not honour the "total" + # keyword with old-style TypedDict(). See https://bugs.python.org/issue42059 + # The standard library TypedDict below Python 3.11 does not store runtime + # information about optional and required keys when using Required or NotRequired. + # Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11. + # Aaaand on 3.12 we add __orig_bases__ to TypedDict + # to enable better runtime introspection. + # On 3.13 we deprecate some odd ways of creating TypedDicts. + TypedDict = typing.TypedDict + _TypedDictMeta = typing._TypedDictMeta + is_typeddict = typing.is_typeddict +else: + # 3.10.0 and later + _TAKES_MODULE = "module" in inspect.signature(typing._type_check).parameters + + if sys.version_info >= (3, 8): + _fake_name = "Protocol" + else: + _fake_name = "_Protocol" + + class _TypedDictMeta(type): + def __new__(cls, name, bases, ns, total=True): + """Create new typed dict class object. + + This method is called when TypedDict is subclassed, + or when TypedDict is instantiated. This way + TypedDict supports all three syntax forms described in its docstring. + Subclasses and instances of TypedDict return actual dictionaries. 
+ """ + for base in bases: + if type(base) is not _TypedDictMeta and base is not typing.Generic: + raise TypeError('cannot inherit from both a TypedDict type ' + 'and a non-TypedDict base class') + + if any(issubclass(b, typing.Generic) for b in bases): + generic_base = (typing.Generic,) + else: + generic_base = () + + # typing.py generally doesn't let you inherit from plain Generic, unless + # the name of the class happens to be "Protocol" (or "_Protocol" on 3.7). + tp_dict = type.__new__(_TypedDictMeta, _fake_name, (*generic_base, dict), ns) + tp_dict.__name__ = name + if tp_dict.__qualname__ == _fake_name: + tp_dict.__qualname__ = name + + if not hasattr(tp_dict, '__orig_bases__'): + tp_dict.__orig_bases__ = bases + + annotations = {} + own_annotations = ns.get('__annotations__', {}) + msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" + if _TAKES_MODULE: + own_annotations = { + n: typing._type_check(tp, msg, module=tp_dict.__module__) + for n, tp in own_annotations.items() + } + else: + own_annotations = { + n: typing._type_check(tp, msg) + for n, tp in own_annotations.items() + } + required_keys = set() + optional_keys = set() + + for base in bases: + annotations.update(base.__dict__.get('__annotations__', {})) + required_keys.update(base.__dict__.get('__required_keys__', ())) + optional_keys.update(base.__dict__.get('__optional_keys__', ())) + + annotations.update(own_annotations) + for annotation_key, annotation_type in own_annotations.items(): + annotation_origin = get_origin(annotation_type) + if annotation_origin is Annotated: + annotation_args = get_args(annotation_type) + if annotation_args: + annotation_type = annotation_args[0] + annotation_origin = get_origin(annotation_type) + + if annotation_origin is Required: + required_keys.add(annotation_key) + elif annotation_origin is NotRequired: + optional_keys.add(annotation_key) + elif total: + required_keys.add(annotation_key) + else: + optional_keys.add(annotation_key) + + tp_dict.__annotations__ = annotations + tp_dict.__required_keys__ = frozenset(required_keys) + tp_dict.__optional_keys__ = frozenset(optional_keys) + if not hasattr(tp_dict, '__total__'): + tp_dict.__total__ = total + return tp_dict + + __call__ = dict # static method + + def __subclasscheck__(cls, other): + # Typed dicts are only for static structural subtyping. + raise TypeError('TypedDict does not support instance and class checks') + + __instancecheck__ = __subclasscheck__ + + def TypedDict(__typename, __fields=_marker, *, total=True, **kwargs): + """A simple typed namespace. At runtime it is equivalent to a plain dict. + + TypedDict creates a dictionary type such that a type checker will expect all + instances to have a certain set of keys, where each key is + associated with a value of a consistent type. This expectation + is not checked at runtime. + + Usage:: + + class Point2D(TypedDict): + x: int + y: int + label: str + + a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK + b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check + + assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first') + + The type info can be accessed via the Point2D.__annotations__ dict, and + the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets. + TypedDict supports an additional equivalent form:: + + Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str}) + + By default, all keys must be present in a TypedDict. 
It is possible + to override this by specifying totality:: + + class Point2D(TypedDict, total=False): + x: int + y: int + + This means that a Point2D TypedDict can have any of the keys omitted. A type + checker is only expected to support a literal False or True as the value of + the total argument. True is the default, and makes all items defined in the + class body be required. + + The Required and NotRequired special forms can also be used to mark + individual keys as being required or not required:: + + class Point2D(TypedDict): + x: int # the "x" key must always be present (Required is the default) + y: NotRequired[int] # the "y" key can be omitted + + See PEP 655 for more details on Required and NotRequired. + """ + if __fields is _marker or __fields is None: + if __fields is _marker: + deprecated_thing = "Failing to pass a value for the 'fields' parameter" + else: + deprecated_thing = "Passing `None` as the 'fields' parameter" + + example = f"`{__typename} = TypedDict({__typename!r}, {{}})`" + deprecation_msg = ( + f"{deprecated_thing} is deprecated and will be disallowed in " + "Python 3.15. To create a TypedDict class with 0 fields " + "using the functional syntax, pass an empty dictionary, e.g. " + ) + example + "." + warnings.warn(deprecation_msg, DeprecationWarning, stacklevel=2) + __fields = kwargs + elif kwargs: + raise TypeError("TypedDict takes either a dict or keyword arguments," + " but not both") + if kwargs: + warnings.warn( + "The kwargs-based syntax for TypedDict definitions is deprecated " + "in Python 3.11, will be removed in Python 3.13, and may not be " + "understood by third-party type checkers.", + DeprecationWarning, + stacklevel=2, + ) + + ns = {'__annotations__': dict(__fields)} + module = _caller() + if module is not None: + # Setting correct module is necessary to make typed dict classes pickleable. + ns['__module__'] = module + + td = _TypedDictMeta(__typename, (), ns, total=total) + td.__orig_bases__ = (TypedDict,) + return td + + _TypedDict = type.__new__(_TypedDictMeta, 'TypedDict', (), {}) + TypedDict.__mro_entries__ = lambda bases: (_TypedDict,) + + if hasattr(typing, "_TypedDictMeta"): + _TYPEDDICT_TYPES = (typing._TypedDictMeta, _TypedDictMeta) + else: + _TYPEDDICT_TYPES = (_TypedDictMeta,) + + def is_typeddict(tp): + """Check if an annotation is a TypedDict class + + For example:: + class Film(TypedDict): + title: str + year: int + + is_typeddict(Film) # => True + is_typeddict(Union[list, str]) # => False + """ + # On 3.8, this would otherwise return True + if hasattr(typing, "TypedDict") and tp is typing.TypedDict: + return False + return isinstance(tp, _TYPEDDICT_TYPES) + + +if hasattr(typing, "assert_type"): + assert_type = typing.assert_type + +else: + def assert_type(__val, __typ): + """Assert (to the type checker) that the value is of the given type. + + When the type checker encounters a call to assert_type(), it + emits an error if the value is not of the specified type:: + + def greet(name: str) -> None: + assert_type(name, str) # ok + assert_type(name, int) # type checker error + + At runtime this returns the first argument unchanged and otherwise + does nothing. 
+ """ + return __val + + +if hasattr(typing, "Required"): + get_type_hints = typing.get_type_hints +else: + # replaces _strip_annotations() + def _strip_extras(t): + """Strips Annotated, Required and NotRequired from a given type.""" + if isinstance(t, _AnnotatedAlias): + return _strip_extras(t.__origin__) + if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired): + return _strip_extras(t.__args__[0]) + if isinstance(t, typing._GenericAlias): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return t.copy_with(stripped_args) + if hasattr(_types, "GenericAlias") and isinstance(t, _types.GenericAlias): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return _types.GenericAlias(t.__origin__, stripped_args) + if hasattr(_types, "UnionType") and isinstance(t, _types.UnionType): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return functools.reduce(operator.or_, stripped_args) + + return t + + def get_type_hints(obj, globalns=None, localns=None, include_extras=False): + """Return type hints for an object. + + This is often the same as obj.__annotations__, but it handles + forward references encoded as string literals, adds Optional[t] if a + default value equal to None is set and recursively replaces all + 'Annotated[T, ...]', 'Required[T]' or 'NotRequired[T]' with 'T' + (unless 'include_extras=True'). + + The argument may be a module, class, method, or function. The annotations + are returned as a dictionary. For classes, annotations include also + inherited members. + + TypeError is raised if the argument is not of a type that can contain + annotations, and an empty dictionary is returned if no annotations are + present. + + BEWARE -- the behavior of globalns and localns is counterintuitive + (unless you are familiar with how eval() and exec() work). The + search order is locals first, then globals. + + - If no dict arguments are passed, an attempt is made to use the + globals from obj (or the respective module's globals for classes), + and these are also used as the locals. If the object does not appear + to have globals, an empty dictionary is used. + + - If one dict argument is passed, it is used for both globals and + locals. + + - If two dict arguments are passed, they specify globals and + locals, respectively. + """ + if hasattr(typing, "Annotated"): + hint = typing.get_type_hints( + obj, globalns=globalns, localns=localns, include_extras=True + ) + else: + hint = typing.get_type_hints(obj, globalns=globalns, localns=localns) + if include_extras: + return hint + return {k: _strip_extras(t) for k, t in hint.items()} + + +# Python 3.9+ has PEP 593 (Annotated) +if hasattr(typing, 'Annotated'): + Annotated = typing.Annotated + # Not exported and not a public API, but needed for get_origin() and get_args() + # to work. + _AnnotatedAlias = typing._AnnotatedAlias +# 3.7-3.8 +else: + class _AnnotatedAlias(typing._GenericAlias, _root=True): + """Runtime representation of an annotated type. + + At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't' + with extra annotations. The alias behaves like a normal typing alias, + instantiating is the same as instantiating the underlying type, binding + it to types is also the same. 
+ """ + def __init__(self, origin, metadata): + if isinstance(origin, _AnnotatedAlias): + metadata = origin.__metadata__ + metadata + origin = origin.__origin__ + super().__init__(origin, origin) + self.__metadata__ = metadata + + def copy_with(self, params): + assert len(params) == 1 + new_type = params[0] + return _AnnotatedAlias(new_type, self.__metadata__) + + def __repr__(self): + return (f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, " + f"{', '.join(repr(a) for a in self.__metadata__)}]") + + def __reduce__(self): + return operator.getitem, ( + Annotated, (self.__origin__,) + self.__metadata__ + ) + + def __eq__(self, other): + if not isinstance(other, _AnnotatedAlias): + return NotImplemented + if self.__origin__ != other.__origin__: + return False + return self.__metadata__ == other.__metadata__ + + def __hash__(self): + return hash((self.__origin__, self.__metadata__)) + + class Annotated: + """Add context specific metadata to a type. + + Example: Annotated[int, runtime_check.Unsigned] indicates to the + hypothetical runtime_check module that this type is an unsigned int. + Every other consumer of this type can ignore this metadata and treat + this type as int. + + The first argument to Annotated must be a valid type (and will be in + the __origin__ field), the remaining arguments are kept as a tuple in + the __extra__ field. + + Details: + + - It's an error to call `Annotated` with less than two arguments. + - Nested Annotated are flattened:: + + Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3] + + - Instantiating an annotated type is equivalent to instantiating the + underlying type:: + + Annotated[C, Ann1](5) == C(5) + + - Annotated can be used as a generic type alias:: + + Optimized = Annotated[T, runtime.Optimize()] + Optimized[int] == Annotated[int, runtime.Optimize()] + + OptimizedList = Annotated[List[T], runtime.Optimize()] + OptimizedList[int] == Annotated[List[int], runtime.Optimize()] + """ + + __slots__ = () + + def __new__(cls, *args, **kwargs): + raise TypeError("Type Annotated cannot be instantiated.") + + @typing._tp_cache + def __class_getitem__(cls, params): + if not isinstance(params, tuple) or len(params) < 2: + raise TypeError("Annotated[...] should be used " + "with at least two arguments (a type and an " + "annotation).") + allowed_special_forms = (ClassVar, Final) + if get_origin(params[0]) in allowed_special_forms: + origin = params[0] + else: + msg = "Annotated[t, ...]: t must be a type." + origin = typing._type_check(params[0], msg) + metadata = tuple(params[1:]) + return _AnnotatedAlias(origin, metadata) + + def __init_subclass__(cls, *args, **kwargs): + raise TypeError( + f"Cannot subclass {cls.__module__}.Annotated" + ) + +# Python 3.8 has get_origin() and get_args() but those implementations aren't +# Annotated-aware, so we can't use those. Python 3.9's versions don't support +# ParamSpecArgs and ParamSpecKwargs, so only Python 3.10's versions will do. +if sys.version_info[:2] >= (3, 10): + get_origin = typing.get_origin + get_args = typing.get_args +# 3.7-3.9 +else: + try: + # 3.9+ + from typing import _BaseGenericAlias + except ImportError: + _BaseGenericAlias = typing._GenericAlias + try: + # 3.9+ + from typing import GenericAlias as _typing_GenericAlias + except ImportError: + _typing_GenericAlias = typing._GenericAlias + + def get_origin(tp): + """Get the unsubscripted version of a type. + + This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar + and Annotated. 
Return None for unsupported types. Examples:: + + get_origin(Literal[42]) is Literal + get_origin(int) is None + get_origin(ClassVar[int]) is ClassVar + get_origin(Generic) is Generic + get_origin(Generic[T]) is Generic + get_origin(Union[T, int]) is Union + get_origin(List[Tuple[T, T]][int]) == list + get_origin(P.args) is P + """ + if isinstance(tp, _AnnotatedAlias): + return Annotated + if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias, _BaseGenericAlias, + ParamSpecArgs, ParamSpecKwargs)): + return tp.__origin__ + if tp is typing.Generic: + return typing.Generic + return None + + def get_args(tp): + """Get type arguments with all substitutions performed. + + For unions, basic simplifications used by Union constructor are performed. + Examples:: + get_args(Dict[str, int]) == (str, int) + get_args(int) == () + get_args(Union[int, Union[T, int], str][int]) == (int, str) + get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) + get_args(Callable[[], T][int]) == ([], int) + """ + if isinstance(tp, _AnnotatedAlias): + return (tp.__origin__,) + tp.__metadata__ + if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias)): + if getattr(tp, "_special", False): + return () + res = tp.__args__ + if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: + res = (list(res[:-1]), res[-1]) + return res + return () + + +# 3.10+ +if hasattr(typing, 'TypeAlias'): + TypeAlias = typing.TypeAlias +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm + def TypeAlias(self, parameters): + """Special marker indicating that an assignment should + be recognized as a proper type alias definition by type + checkers. + + For example:: + + Predicate: TypeAlias = Callable[..., bool] + + It's invalid when used anywhere except as in the example above. + """ + raise TypeError(f"{self} is not subscriptable") +# 3.7-3.8 +else: + TypeAlias = _ExtensionsSpecialForm( + 'TypeAlias', + doc="""Special marker indicating that an assignment should + be recognized as a proper type alias definition by type + checkers. 
+ + For example:: + + Predicate: TypeAlias = Callable[..., bool] + + It's invalid when used anywhere except as in the example + above.""" + ) + + +def _set_default(type_param, default): + if isinstance(default, (tuple, list)): + type_param.__default__ = tuple((typing._type_check(d, "Default must be a type") + for d in default)) + elif default != _marker: + type_param.__default__ = typing._type_check(default, "Default must be a type") + else: + type_param.__default__ = None + + +def _set_module(typevarlike): + # for pickling: + def_mod = _caller(depth=3) + if def_mod != 'typing_extensions': + typevarlike.__module__ = def_mod + + +class _DefaultMixin: + """Mixin for TypeVarLike defaults.""" + + __slots__ = () + __init__ = _set_default + + +# Classes using this metaclass must provide a _backported_typevarlike ClassVar +class _TypeVarLikeMeta(type): + def __instancecheck__(cls, __instance: Any) -> bool: + return isinstance(__instance, cls._backported_typevarlike) + + +# Add default and infer_variance parameters from PEP 696 and 695 +class TypeVar(metaclass=_TypeVarLikeMeta): + """Type variable.""" + + _backported_typevarlike = typing.TypeVar + + def __new__(cls, name, *constraints, bound=None, + covariant=False, contravariant=False, + default=_marker, infer_variance=False): + if hasattr(typing, "TypeAliasType"): + # PEP 695 implemented, can pass infer_variance to typing.TypeVar + typevar = typing.TypeVar(name, *constraints, bound=bound, + covariant=covariant, contravariant=contravariant, + infer_variance=infer_variance) + else: + typevar = typing.TypeVar(name, *constraints, bound=bound, + covariant=covariant, contravariant=contravariant) + if infer_variance and (covariant or contravariant): + raise ValueError("Variance cannot be specified with infer_variance.") + typevar.__infer_variance__ = infer_variance + _set_default(typevar, default) + _set_module(typevar) + return typevar + + def __init_subclass__(cls) -> None: + raise TypeError(f"type '{__name__}.TypeVar' is not an acceptable base type") + + +# Python 3.10+ has PEP 612 +if hasattr(typing, 'ParamSpecArgs'): + ParamSpecArgs = typing.ParamSpecArgs + ParamSpecKwargs = typing.ParamSpecKwargs +# 3.7-3.9 +else: + class _Immutable: + """Mixin to indicate that object should not be copied.""" + __slots__ = () + + def __copy__(self): + return self + + def __deepcopy__(self, memo): + return self + + class ParamSpecArgs(_Immutable): + """The args for a ParamSpec object. + + Given a ParamSpec object P, P.args is an instance of ParamSpecArgs. + + ParamSpecArgs objects have a reference back to their ParamSpec: + + P.args.__origin__ is P + + This type is meant for runtime introspection and has no special meaning to + static type checkers. + """ + def __init__(self, origin): + self.__origin__ = origin + + def __repr__(self): + return f"{self.__origin__.__name__}.args" + + def __eq__(self, other): + if not isinstance(other, ParamSpecArgs): + return NotImplemented + return self.__origin__ == other.__origin__ + + class ParamSpecKwargs(_Immutable): + """The kwargs for a ParamSpec object. + + Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs. + + ParamSpecKwargs objects have a reference back to their ParamSpec: + + P.kwargs.__origin__ is P + + This type is meant for runtime introspection and has no special meaning to + static type checkers. 
+ """ + def __init__(self, origin): + self.__origin__ = origin + + def __repr__(self): + return f"{self.__origin__.__name__}.kwargs" + + def __eq__(self, other): + if not isinstance(other, ParamSpecKwargs): + return NotImplemented + return self.__origin__ == other.__origin__ + +# 3.10+ +if hasattr(typing, 'ParamSpec'): + + # Add default parameter - PEP 696 + class ParamSpec(metaclass=_TypeVarLikeMeta): + """Parameter specification.""" + + _backported_typevarlike = typing.ParamSpec + + def __new__(cls, name, *, bound=None, + covariant=False, contravariant=False, + infer_variance=False, default=_marker): + if hasattr(typing, "TypeAliasType"): + # PEP 695 implemented, can pass infer_variance to typing.TypeVar + paramspec = typing.ParamSpec(name, bound=bound, + covariant=covariant, + contravariant=contravariant, + infer_variance=infer_variance) + else: + paramspec = typing.ParamSpec(name, bound=bound, + covariant=covariant, + contravariant=contravariant) + paramspec.__infer_variance__ = infer_variance + + _set_default(paramspec, default) + _set_module(paramspec) + return paramspec + + def __init_subclass__(cls) -> None: + raise TypeError(f"type '{__name__}.ParamSpec' is not an acceptable base type") + +# 3.7-3.9 +else: + + # Inherits from list as a workaround for Callable checks in Python < 3.9.2. + class ParamSpec(list, _DefaultMixin): + """Parameter specification variable. + + Usage:: + + P = ParamSpec('P') + + Parameter specification variables exist primarily for the benefit of static + type checkers. They are used to forward the parameter types of one + callable to another callable, a pattern commonly found in higher order + functions and decorators. They are only valid when used in ``Concatenate``, + or s the first argument to ``Callable``. In Python 3.10 and higher, + they are also supported in user-defined Generics at runtime. + See class Generic for more information on generic types. An + example for annotating a decorator:: + + T = TypeVar('T') + P = ParamSpec('P') + + def add_logging(f: Callable[P, T]) -> Callable[P, T]: + '''A type-safe decorator to add logging to a function.''' + def inner(*args: P.args, **kwargs: P.kwargs) -> T: + logging.info(f'{f.__name__} was called') + return f(*args, **kwargs) + return inner + + @add_logging + def add_two(x: float, y: float) -> float: + '''Add two numbers together.''' + return x + y + + Parameter specification variables defined with covariant=True or + contravariant=True can be used to declare covariant or contravariant + generic types. These keyword arguments are valid, but their actual semantics + are yet to be decided. See PEP 612 for details. + + Parameter specification variables can be introspected. e.g.: + + P.__name__ == 'T' + P.__bound__ == None + P.__covariant__ == False + P.__contravariant__ == False + + Note that only parameter specification variables defined in global scope can + be pickled. + """ + + # Trick Generic __parameters__. 
+ __class__ = typing.TypeVar + + @property + def args(self): + return ParamSpecArgs(self) + + @property + def kwargs(self): + return ParamSpecKwargs(self) + + def __init__(self, name, *, bound=None, covariant=False, contravariant=False, + infer_variance=False, default=_marker): + super().__init__([self]) + self.__name__ = name + self.__covariant__ = bool(covariant) + self.__contravariant__ = bool(contravariant) + self.__infer_variance__ = bool(infer_variance) + if bound: + self.__bound__ = typing._type_check(bound, 'Bound must be a type.') + else: + self.__bound__ = None + _DefaultMixin.__init__(self, default) + + # for pickling: + def_mod = _caller() + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + def __repr__(self): + if self.__infer_variance__: + prefix = '' + elif self.__covariant__: + prefix = '+' + elif self.__contravariant__: + prefix = '-' + else: + prefix = '~' + return prefix + self.__name__ + + def __hash__(self): + return object.__hash__(self) + + def __eq__(self, other): + return self is other + + def __reduce__(self): + return self.__name__ + + # Hack to get typing._type_check to pass. + def __call__(self, *args, **kwargs): + pass + + +# 3.7-3.9 +if not hasattr(typing, 'Concatenate'): + # Inherits from list as a workaround for Callable checks in Python < 3.9.2. + class _ConcatenateGenericAlias(list): + + # Trick Generic into looking into this for __parameters__. + __class__ = typing._GenericAlias + + # Flag in 3.8. + _special = False + + def __init__(self, origin, args): + super().__init__(args) + self.__origin__ = origin + self.__args__ = args + + def __repr__(self): + _type_repr = typing._type_repr + return (f'{_type_repr(self.__origin__)}' + f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]') + + def __hash__(self): + return hash((self.__origin__, self.__args__)) + + # Hack to get typing._type_check to pass in Generic. + def __call__(self, *args, **kwargs): + pass + + @property + def __parameters__(self): + return tuple( + tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec)) + ) + + +# 3.7-3.9 +@typing._tp_cache +def _concatenate_getitem(self, parameters): + if parameters == (): + raise TypeError("Cannot take a Concatenate of no types.") + if not isinstance(parameters, tuple): + parameters = (parameters,) + if not isinstance(parameters[-1], ParamSpec): + raise TypeError("The last parameter to Concatenate should be a " + "ParamSpec variable.") + msg = "Concatenate[arg, ...]: each arg must be a type." + parameters = tuple(typing._type_check(p, msg) for p in parameters) + return _ConcatenateGenericAlias(self, parameters) + + +# 3.10+ +if hasattr(typing, 'Concatenate'): + Concatenate = typing.Concatenate + _ConcatenateGenericAlias = typing._ConcatenateGenericAlias # noqa: F811 +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm + def Concatenate(self, parameters): + """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a + higher order function which adds, removes or transforms parameters of a + callable. + + For example:: + + Callable[Concatenate[int, P], int] + + See PEP 612 for detailed information. 
+ """ + return _concatenate_getitem(self, parameters) +# 3.7-8 +else: + class _ConcatenateForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + return _concatenate_getitem(self, parameters) + + Concatenate = _ConcatenateForm( + 'Concatenate', + doc="""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a + higher order function which adds, removes or transforms parameters of a + callable. + + For example:: + + Callable[Concatenate[int, P], int] + + See PEP 612 for detailed information. + """) + +# 3.10+ +if hasattr(typing, 'TypeGuard'): + TypeGuard = typing.TypeGuard +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm + def TypeGuard(self, parameters): + """Special typing form used to annotate the return type of a user-defined + type guard function. ``TypeGuard`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeGuard[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeGuard`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the type inside ``TypeGuard``. + + For example:: + + def is_str(val: Union[str, float]): + # "isinstance" type guard + if isinstance(val, str): + # Type of ``val`` is narrowed to ``str`` + ... + else: + # Else, type of ``val`` is narrowed to ``float``. + ... + + Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower + form of ``TypeA`` (it can even be a wider form) and this may lead to + type-unsafe results. The main reason is to allow for things like + narrowing ``List[object]`` to ``List[str]`` even though the latter is not + a subtype of the former, since ``List`` is invariant. The responsibility of + writing type-safe type guards is left to the user. + + ``TypeGuard`` also works with type variables. For more information, see + PEP 647 (User-Defined Type Guards). + """ + item = typing._type_check(parameters, f'{self} accepts only a single type.') + return typing._GenericAlias(self, (item,)) +# 3.7-3.8 +else: + class _TypeGuardForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type') + return typing._GenericAlias(self, (item,)) + + TypeGuard = _TypeGuardForm( + 'TypeGuard', + doc="""Special typing form used to annotate the return type of a user-defined + type guard function. ``TypeGuard`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". 
+ + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeGuard[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeGuard`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the type inside ``TypeGuard``. + + For example:: + + def is_str(val: Union[str, float]): + # "isinstance" type guard + if isinstance(val, str): + # Type of ``val`` is narrowed to ``str`` + ... + else: + # Else, type of ``val`` is narrowed to ``float``. + ... + + Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower + form of ``TypeA`` (it can even be a wider form) and this may lead to + type-unsafe results. The main reason is to allow for things like + narrowing ``List[object]`` to ``List[str]`` even though the latter is not + a subtype of the former, since ``List`` is invariant. The responsibility of + writing type-safe type guards is left to the user. + + ``TypeGuard`` also works with type variables. For more information, see + PEP 647 (User-Defined Type Guards). + """) + + +# Vendored from cpython typing._SpecialFrom +class _SpecialForm(typing._Final, _root=True): + __slots__ = ('_name', '__doc__', '_getitem') + + def __init__(self, getitem): + self._getitem = getitem + self._name = getitem.__name__ + self.__doc__ = getitem.__doc__ + + def __getattr__(self, item): + if item in {'__name__', '__qualname__'}: + return self._name + + raise AttributeError(item) + + def __mro_entries__(self, bases): + raise TypeError(f"Cannot subclass {self!r}") + + def __repr__(self): + return f'typing_extensions.{self._name}' + + def __reduce__(self): + return self._name + + def __call__(self, *args, **kwds): + raise TypeError(f"Cannot instantiate {self!r}") + + def __or__(self, other): + return typing.Union[self, other] + + def __ror__(self, other): + return typing.Union[other, self] + + def __instancecheck__(self, obj): + raise TypeError(f"{self} cannot be used with isinstance()") + + def __subclasscheck__(self, cls): + raise TypeError(f"{self} cannot be used with issubclass()") + + @typing._tp_cache + def __getitem__(self, parameters): + return self._getitem(self, parameters) + + +if hasattr(typing, "LiteralString"): + LiteralString = typing.LiteralString +else: + @_SpecialForm + def LiteralString(self, params): + """Represents an arbitrary literal string. + + Example:: + + from metaflow._vendor.typing_extensions import LiteralString + + def query(sql: LiteralString) -> ...: + ... + + query("SELECT * FROM table") # ok + query(f"SELECT * FROM {input()}") # not ok + + See PEP 675 for details. + + """ + raise TypeError(f"{self} is not subscriptable") + + +if hasattr(typing, "Self"): + Self = typing.Self +else: + @_SpecialForm + def Self(self, params): + """Used to spell the type of "self" in classes. + + Example:: + + from typing import Self + + class ReturnsSelf: + def parse(self, data: bytes) -> Self: + ... + return self + + """ + + raise TypeError(f"{self} is not subscriptable") + + +if hasattr(typing, "Never"): + Never = typing.Never +else: + @_SpecialForm + def Never(self, params): + """The bottom type, a type that has no members. 
+ + This can be used to define a function that should never be + called, or a function that never returns:: + + from metaflow._vendor.typing_extensions import Never + + def never_call_me(arg: Never) -> None: + pass + + def int_or_str(arg: int | str) -> None: + never_call_me(arg) # type checker error + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + never_call_me(arg) # ok, arg is of type Never + + """ + + raise TypeError(f"{self} is not subscriptable") + + +if hasattr(typing, 'Required'): + Required = typing.Required + NotRequired = typing.NotRequired +elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm + def Required(self, parameters): + """A special typing construct to mark a key of a total=False TypedDict + as required. For example: + + class Movie(TypedDict, total=False): + title: Required[str] + year: int + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + + There is no runtime checking that a required key is actually provided + when instantiating a related TypedDict. + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + @_ExtensionsSpecialForm + def NotRequired(self, parameters): + """A special typing construct to mark a key of a TypedDict as + potentially missing. For example: + + class Movie(TypedDict): + title: str + year: NotRequired[int] + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + +else: + class _RequiredForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + Required = _RequiredForm( + 'Required', + doc="""A special typing construct to mark a key of a total=False TypedDict + as required. For example: + + class Movie(TypedDict, total=False): + title: Required[str] + year: int + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + + There is no runtime checking that a required key is actually provided + when instantiating a related TypedDict. + """) + NotRequired = _RequiredForm( + 'NotRequired', + doc="""A special typing construct to mark a key of a TypedDict as + potentially missing. For example: + + class Movie(TypedDict): + title: str + year: NotRequired[int] + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + """) + + +_UNPACK_DOC = """\ +Type unpack operator. + +The type unpack operator takes the child types from some container type, +such as `tuple[int, str]` or a `TypeVarTuple`, and 'pulls them out'. For +example: + + # For some generic class `Foo`: + Foo[Unpack[tuple[int, str]]] # Equivalent to Foo[int, str] + + Ts = TypeVarTuple('Ts') + # Specifies that `Bar` is generic in an arbitrary number of types. + # (Think of `Ts` as a tuple of an arbitrary number of individual + # `TypeVar`s, which the `Unpack` is 'pulling out' directly into the + # `Generic[]`.) + class Bar(Generic[Unpack[Ts]]): ... + Bar[int] # Valid + Bar[int, str] # Also valid + +From Python 3.11, this can also be done using the `*` operator: + + Foo[*tuple[int, str]] + class Bar(Generic[*Ts]): ... + +The operator can also be used along with a `TypedDict` to annotate +`**kwargs` in a function signature. 
For instance: + + class Movie(TypedDict): + name: str + year: int + + # This function expects two keyword arguments - *name* of type `str` and + # *year* of type `int`. + def foo(**kwargs: Unpack[Movie]): ... + +Note that there is only some runtime checking of this operator. Not +everything the runtime allows may be accepted by static type checkers. + +For more information, see PEP 646 and PEP 692. +""" + + +if sys.version_info >= (3, 12): # PEP 692 changed the repr of Unpack[] + Unpack = typing.Unpack + + def _is_unpack(obj): + return get_origin(obj) is Unpack + +elif sys.version_info[:2] >= (3, 9): + class _UnpackSpecialForm(_ExtensionsSpecialForm, _root=True): + def __init__(self, getitem): + super().__init__(getitem) + self.__doc__ = _UNPACK_DOC + + class _UnpackAlias(typing._GenericAlias, _root=True): + __class__ = typing.TypeVar + + @_UnpackSpecialForm + def Unpack(self, parameters): + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return _UnpackAlias(self, (item,)) + + def _is_unpack(obj): + return isinstance(obj, _UnpackAlias) + +else: + class _UnpackAlias(typing._GenericAlias, _root=True): + __class__ = typing.TypeVar + + class _UnpackForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return _UnpackAlias(self, (item,)) + + Unpack = _UnpackForm('Unpack', doc=_UNPACK_DOC) + + def _is_unpack(obj): + return isinstance(obj, _UnpackAlias) + + +if hasattr(typing, "TypeVarTuple"): # 3.11+ + + # Add default parameter - PEP 696 + class TypeVarTuple(metaclass=_TypeVarLikeMeta): + """Type variable tuple.""" + + _backported_typevarlike = typing.TypeVarTuple + + def __new__(cls, name, *, default=_marker): + tvt = typing.TypeVarTuple(name) + _set_default(tvt, default) + _set_module(tvt) + return tvt + + def __init_subclass__(self, *args, **kwds): + raise TypeError("Cannot subclass special typing classes") + +else: + class TypeVarTuple(_DefaultMixin): + """Type variable tuple. + + Usage:: + + Ts = TypeVarTuple('Ts') + + In the same way that a normal type variable is a stand-in for a single + type such as ``int``, a type variable *tuple* is a stand-in for a *tuple* + type such as ``Tuple[int, str]``. + + Type variable tuples can be used in ``Generic`` declarations. + Consider the following example:: + + class Array(Generic[*Ts]): ... + + The ``Ts`` type variable tuple here behaves like ``tuple[T1, T2]``, + where ``T1`` and ``T2`` are type variables. To use these type variables + as type parameters of ``Array``, we must *unpack* the type variable tuple using + the star operator: ``*Ts``. The signature of ``Array`` then behaves + as if we had simply written ``class Array(Generic[T1, T2]): ...``. + In contrast to ``Generic[T1, T2]``, however, ``Generic[*Shape]`` allows + us to parameterise the class with an *arbitrary* number of type parameters. + + Type variable tuples can be used anywhere a normal ``TypeVar`` can. + This includes class definitions, as shown above, as well as function + signatures and variable annotations:: + + class Array(Generic[*Ts]): + + def __init__(self, shape: Tuple[*Ts]): + self._shape: Tuple[*Ts] = shape + + def get_shape(self) -> Tuple[*Ts]: + return self._shape + + shape = (Height(480), Width(640)) + x: Array[Height, Width] = Array(shape) + y = abs(x) # Inferred type is Array[Height, Width] + z = x + x # ... is Array[Height, Width] + x.get_shape() # ... is tuple[Height, Width] + + """ + + # Trick Generic __parameters__. 
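+        # (As for the ParamSpec backport above: making instances look like
+        # ``typing.TypeVar`` lets Generic collect a TypeVarTuple into
+        # ``__parameters__``.)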
+ __class__ = typing.TypeVar + + def __iter__(self): + yield self.__unpacked__ + + def __init__(self, name, *, default=_marker): + self.__name__ = name + _DefaultMixin.__init__(self, default) + + # for pickling: + def_mod = _caller() + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + self.__unpacked__ = Unpack[self] + + def __repr__(self): + return self.__name__ + + def __hash__(self): + return object.__hash__(self) + + def __eq__(self, other): + return self is other + + def __reduce__(self): + return self.__name__ + + def __init_subclass__(self, *args, **kwds): + if '_root' not in kwds: + raise TypeError("Cannot subclass special typing classes") + + +if hasattr(typing, "reveal_type"): + reveal_type = typing.reveal_type +else: + def reveal_type(__obj: T) -> T: + """Reveal the inferred type of a variable. + + When a static type checker encounters a call to ``reveal_type()``, + it will emit the inferred type of the argument:: + + x: int = 1 + reveal_type(x) + + Running a static type checker (e.g., ``mypy``) on this example + will produce output similar to 'Revealed type is "builtins.int"'. + + At runtime, the function prints the runtime type of the + argument and returns it unchanged. + + """ + print(f"Runtime type is {type(__obj).__name__!r}", file=sys.stderr) + return __obj + + +if hasattr(typing, "assert_never"): + assert_never = typing.assert_never +else: + def assert_never(__arg: Never) -> Never: + """Assert to the type checker that a line of code is unreachable. + + Example:: + + def int_or_str(arg: int | str) -> None: + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + assert_never(arg) + + If a type checker finds that a call to assert_never() is + reachable, it will emit an error. + + At runtime, this throws an exception when called. + + """ + raise AssertionError("Expected code to be unreachable") + + +if sys.version_info >= (3, 12): + # dataclass_transform exists in 3.11 but lacks the frozen_default parameter + dataclass_transform = typing.dataclass_transform +else: + def dataclass_transform( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + frozen_default: bool = False, + field_specifiers: typing.Tuple[ + typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], + ... + ] = (), + **kwargs: typing.Any, + ) -> typing.Callable[[T], T]: + """Decorator that marks a function, class, or metaclass as providing + dataclass-like behavior. + + Example: + + from metaflow._vendor.typing_extensions import dataclass_transform + + _T = TypeVar("_T") + + # Used on a decorator function + @dataclass_transform() + def create_model(cls: type[_T]) -> type[_T]: + ... + return cls + + @create_model + class CustomerModel: + id: int + name: str + + # Used on a base class + @dataclass_transform() + class ModelBase: ... + + class CustomerModel(ModelBase): + id: int + name: str + + # Used on a metaclass + @dataclass_transform() + class ModelMeta(type): ... + + class ModelBase(metaclass=ModelMeta): ... + + class CustomerModel(ModelBase): + id: int + name: str + + Each of the ``CustomerModel`` classes defined in this example will now + behave similarly to a dataclass created with the ``@dataclasses.dataclass`` + decorator. For example, the type checker will synthesize an ``__init__`` + method. 
+ + The arguments to this decorator can be used to customize this behavior: + - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be + True or False if it is omitted by the caller. + - ``order_default`` indicates whether the ``order`` parameter is + assumed to be True or False if it is omitted by the caller. + - ``kw_only_default`` indicates whether the ``kw_only`` parameter is + assumed to be True or False if it is omitted by the caller. + - ``frozen_default`` indicates whether the ``frozen`` parameter is + assumed to be True or False if it is omitted by the caller. + - ``field_specifiers`` specifies a static list of supported classes + or functions that describe fields, similar to ``dataclasses.field()``. + + At runtime, this decorator records its arguments in the + ``__dataclass_transform__`` attribute on the decorated object. + + See PEP 681 for details. + + """ + def decorator(cls_or_fn): + cls_or_fn.__dataclass_transform__ = { + "eq_default": eq_default, + "order_default": order_default, + "kw_only_default": kw_only_default, + "frozen_default": frozen_default, + "field_specifiers": field_specifiers, + "kwargs": kwargs, + } + return cls_or_fn + return decorator + + +if hasattr(typing, "override"): + override = typing.override +else: + _F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) + + def override(__arg: _F) -> _F: + """Indicate that a method is intended to override a method in a base class. + + Usage: + + class Base: + def method(self) -> None: ... + pass + + class Child(Base): + @override + def method(self) -> None: + super().method() + + When this decorator is applied to a method, the type checker will + validate that it overrides a method with the same name on a base class. + This helps prevent bugs that may occur when a base class is changed + without an equivalent change to a child class. + + There is no runtime checking of these properties. The decorator + sets the ``__override__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + + See PEP 698 for details. + + """ + try: + __arg.__override__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return __arg + + +if hasattr(typing, "deprecated"): + deprecated = typing.deprecated +else: + _T = typing.TypeVar("_T") + + def deprecated( + __msg: str, + *, + category: typing.Optional[typing.Type[Warning]] = DeprecationWarning, + stacklevel: int = 1, + ) -> typing.Callable[[_T], _T]: + """Indicate that a class, function or overload is deprecated. + + Usage: + + @deprecated("Use B instead") + class A: + pass + + @deprecated("Use g instead") + def f(): + pass + + @overload + @deprecated("int support is deprecated") + def g(x: int) -> int: ... + @overload + def g(x: str) -> int: ... + + When this decorator is applied to an object, the type checker + will generate a diagnostic on usage of the deprecated object. + + The warning specified by ``category`` will be emitted on use + of deprecated objects. For functions, that happens on calls; + for classes, on instantiation. If the ``category`` is ``None``, + no warning is emitted. The ``stacklevel`` determines where the + warning is emitted. If it is ``1`` (the default), the warning + is emitted at the direct caller of the deprecated object; if it + is higher, it is emitted further up the stack. 
+
+        The decorator sets the ``__deprecated__``
+        attribute on the decorated object to the deprecation message
+        passed to the decorator. If applied to an overload, the decorator
+        must be after the ``@overload`` decorator for the attribute to
+        exist on the overload as returned by ``get_overloads()``.
+
+        See PEP 702 for details.
+
+        """
+        def decorator(__arg: _T) -> _T:
+            if category is None:
+                __arg.__deprecated__ = __msg
+                return __arg
+            elif isinstance(__arg, type):
+                original_new = __arg.__new__
+                has_init = __arg.__init__ is not object.__init__
+
+                @functools.wraps(original_new)
+                def __new__(cls, *args, **kwargs):
+                    warnings.warn(__msg, category=category, stacklevel=stacklevel + 1)
+                    if original_new is not object.__new__:
+                        return original_new(cls, *args, **kwargs)
+                    # Mirrors a similar check in object.__new__.
+                    elif not has_init and (args or kwargs):
+                        raise TypeError(f"{cls.__name__}() takes no arguments")
+                    else:
+                        return original_new(cls)
+
+                __arg.__new__ = staticmethod(__new__)
+                __arg.__deprecated__ = __new__.__deprecated__ = __msg
+                return __arg
+            elif callable(__arg):
+                @functools.wraps(__arg)
+                def wrapper(*args, **kwargs):
+                    warnings.warn(__msg, category=category, stacklevel=stacklevel + 1)
+                    return __arg(*args, **kwargs)
+
+                __arg.__deprecated__ = wrapper.__deprecated__ = __msg
+                return wrapper
+            else:
+                raise TypeError(
+                    "@deprecated decorator with non-None category must be applied to "
+                    f"a class or callable, not {__arg!r}"
+                )
+
+        return decorator
+
+
+# We have to do some monkey patching to deal with the dual nature of
+# Unpack/TypeVarTuple:
+# - We want Unpack to be a kind of TypeVar so it gets accepted in
+#   Generic[Unpack[Ts]]
+# - We want it to *not* be treated as a TypeVar for the purposes of
+#   counting generic parameters, so that when we subscript a generic,
+#   the runtime doesn't try to substitute the Unpack with the subscripted type.
+if not hasattr(typing, "TypeVarTuple"):
+    typing._collect_type_vars = _collect_type_vars
+    typing._check_generic = _check_generic
+
+
+# Backport typing.NamedTuple as it exists in Python 3.13.
+# In 3.11, support for defining generic `NamedTuple`s was added.
+# This was explicitly disallowed in 3.9-3.10, and only half-worked in <=3.8.
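+# (For example, `class Group(NamedTuple, Generic[T])` raises a TypeError on
+# 3.9 and 3.10, whereas the backported NamedTuple below accepts it.)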
+# On 3.12, we added __orig_bases__ to call-based NamedTuples +# On 3.13, we deprecated kwargs-based NamedTuples +if sys.version_info >= (3, 13): + NamedTuple = typing.NamedTuple +else: + def _make_nmtuple(name, types, module, defaults=()): + fields = [n for n, t in types] + annotations = {n: typing._type_check(t, f"field {n} annotation must be a type") + for n, t in types} + nm_tpl = collections.namedtuple(name, fields, + defaults=defaults, module=module) + nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = annotations + # The `_field_types` attribute was removed in 3.9; + # in earlier versions, it is the same as the `__annotations__` attribute + if sys.version_info < (3, 9): + nm_tpl._field_types = annotations + return nm_tpl + + _prohibited_namedtuple_fields = typing._prohibited + _special_namedtuple_fields = frozenset({'__module__', '__name__', '__annotations__'}) + + class _NamedTupleMeta(type): + def __new__(cls, typename, bases, ns): + assert _NamedTuple in bases + for base in bases: + if base is not _NamedTuple and base is not typing.Generic: + raise TypeError( + 'can only inherit from a NamedTuple type and Generic') + bases = tuple(tuple if base is _NamedTuple else base for base in bases) + types = ns.get('__annotations__', {}) + default_names = [] + for field_name in types: + if field_name in ns: + default_names.append(field_name) + elif default_names: + raise TypeError(f"Non-default namedtuple field {field_name} " + f"cannot follow default field" + f"{'s' if len(default_names) > 1 else ''} " + f"{', '.join(default_names)}") + nm_tpl = _make_nmtuple( + typename, types.items(), + defaults=[ns[n] for n in default_names], + module=ns['__module__'] + ) + nm_tpl.__bases__ = bases + if typing.Generic in bases: + if hasattr(typing, '_generic_class_getitem'): # 3.12+ + nm_tpl.__class_getitem__ = classmethod(typing._generic_class_getitem) + else: + class_getitem = typing.Generic.__class_getitem__.__func__ + nm_tpl.__class_getitem__ = classmethod(class_getitem) + # update from user namespace without overriding special namedtuple attributes + for key in ns: + if key in _prohibited_namedtuple_fields: + raise AttributeError("Cannot overwrite NamedTuple attribute " + key) + elif key not in _special_namedtuple_fields and key not in nm_tpl._fields: + setattr(nm_tpl, key, ns[key]) + if typing.Generic in bases: + nm_tpl.__init_subclass__() + return nm_tpl + + def NamedTuple(__typename, __fields=_marker, **kwargs): + """Typed version of namedtuple. + + Usage:: + + class Employee(NamedTuple): + name: str + id: int + + This is equivalent to:: + + Employee = collections.namedtuple('Employee', ['name', 'id']) + + The resulting class has an extra __annotations__ attribute, giving a + dict that maps field names to types. (The field names are also in + the _fields attribute, which is part of the namedtuple API.) + An alternative equivalent functional syntax is also accepted:: + + Employee = NamedTuple('Employee', [('name', str), ('id', int)]) + """ + if __fields is _marker: + if kwargs: + deprecated_thing = "Creating NamedTuple classes using keyword arguments" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "Use the class-based or functional syntax instead." + ) + else: + deprecated_thing = "Failing to pass a value for the 'fields' parameter" + example = f"`{__typename} = NamedTuple({__typename!r}, [])`" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. 
" + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + example + "." + elif __fields is None: + if kwargs: + raise TypeError( + "Cannot pass `None` as the 'fields' parameter " + "and also specify fields using keyword arguments" + ) + else: + deprecated_thing = "Passing `None` as the 'fields' parameter" + example = f"`{__typename} = NamedTuple({__typename!r}, [])`" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + example + "." + elif kwargs: + raise TypeError("Either list of fields or keywords" + " can be provided to NamedTuple, not both") + if __fields is _marker or __fields is None: + warnings.warn( + deprecation_msg.format(name=deprecated_thing, remove="3.15"), + DeprecationWarning, + stacklevel=2, + ) + __fields = kwargs.items() + nt = _make_nmtuple(__typename, __fields, module=_caller()) + nt.__orig_bases__ = (NamedTuple,) + return nt + + _NamedTuple = type.__new__(_NamedTupleMeta, 'NamedTuple', (), {}) + + # On 3.8+, alter the signature so that it matches typing.NamedTuple. + # The signature of typing.NamedTuple on >=3.8 is invalid syntax in Python 3.7, + # so just leave the signature as it is on 3.7. + if sys.version_info >= (3, 8): + NamedTuple.__text_signature__ = '(typename, fields=None, /, **kwargs)' + + def _namedtuple_mro_entries(bases): + assert NamedTuple in bases + return (_NamedTuple,) + + NamedTuple.__mro_entries__ = _namedtuple_mro_entries + + +if hasattr(collections.abc, "Buffer"): + Buffer = collections.abc.Buffer +else: + class Buffer(abc.ABC): + """Base class for classes that implement the buffer protocol. + + The buffer protocol allows Python objects to expose a low-level + memory buffer interface. Before Python 3.12, it is not possible + to implement the buffer protocol in pure Python code, or even + to check whether a class implements the buffer protocol. In + Python 3.12 and higher, the ``__buffer__`` method allows access + to the buffer protocol from Python code, and the + ``collections.abc.Buffer`` ABC allows checking whether a class + implements the buffer protocol. + + To indicate support for the buffer protocol in earlier versions, + inherit from this ABC, either in a stub file or at runtime, + or use ABC registration. This ABC provides no methods, because + there is no Python-accessible methods shared by pre-3.12 buffer + classes. It is useful primarily for static checks. + + """ + + # As a courtesy, register the most common stdlib buffer classes. + Buffer.register(memoryview) + Buffer.register(bytearray) + Buffer.register(bytes) + + +# Backport of types.get_original_bases, available on 3.12+ in CPython +if hasattr(_types, "get_original_bases"): + get_original_bases = _types.get_original_bases +else: + def get_original_bases(__cls): + """Return the class's "original" bases prior to modification by `__mro_entries__`. + + Examples:: + + from typing import TypeVar, Generic + from metaflow._vendor.typing_extensions import NamedTuple, TypedDict + + T = TypeVar("T") + class Foo(Generic[T]): ... + class Bar(Foo[int], float): ... + class Baz(list[str]): ... 
+ Eggs = NamedTuple("Eggs", [("a", int), ("b", str)]) + Spam = TypedDict("Spam", {"a": int, "b": str}) + + assert get_original_bases(Bar) == (Foo[int], float) + assert get_original_bases(Baz) == (list[str],) + assert get_original_bases(Eggs) == (NamedTuple,) + assert get_original_bases(Spam) == (TypedDict,) + assert get_original_bases(int) == (object,) + """ + try: + return __cls.__orig_bases__ + except AttributeError: + try: + return __cls.__bases__ + except AttributeError: + raise TypeError( + f'Expected an instance of type, not {type(__cls).__name__!r}' + ) from None + + +# NewType is a class on Python 3.10+, making it pickleable +# The error message for subclassing instances of NewType was improved on 3.11+ +if sys.version_info >= (3, 11): + NewType = typing.NewType +else: + class NewType: + """NewType creates simple unique types with almost zero + runtime overhead. NewType(name, tp) is considered a subtype of tp + by static type checkers. At runtime, NewType(name, tp) returns + a dummy callable that simply returns its argument. Usage:: + UserId = NewType('UserId', int) + def name_by_id(user_id: UserId) -> str: + ... + UserId('user') # Fails type check + name_by_id(42) # Fails type check + name_by_id(UserId(42)) # OK + num = UserId(5) + 1 # type: int + """ + + def __call__(self, obj): + return obj + + def __init__(self, name, tp): + self.__qualname__ = name + if '.' in name: + name = name.rpartition('.')[-1] + self.__name__ = name + self.__supertype__ = tp + def_mod = _caller() + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + def __mro_entries__(self, bases): + # We defined __mro_entries__ to get a better error message + # if a user attempts to subclass a NewType instance. bpo-46170 + supercls_name = self.__name__ + + class Dummy: + def __init_subclass__(cls): + subcls_name = cls.__name__ + raise TypeError( + f"Cannot subclass an instance of NewType. " + f"Perhaps you were looking for: " + f"`{subcls_name} = NewType({subcls_name!r}, {supercls_name})`" + ) + + return (Dummy,) + + def __repr__(self): + return f'{self.__module__}.{self.__qualname__}' + + def __reduce__(self): + return self.__qualname__ + + if sys.version_info >= (3, 10): + # PEP 604 methods + # It doesn't make sense to have these methods on Python <3.10 + + def __or__(self, other): + return typing.Union[self, other] + + def __ror__(self, other): + return typing.Union[other, self] + + +if hasattr(typing, "TypeAliasType"): + TypeAliasType = typing.TypeAliasType +else: + def _is_unionable(obj): + """Corresponds to is_unionable() in unionobject.c in CPython.""" + return obj is None or isinstance(obj, ( + type, + _types.GenericAlias, + _types.UnionType, + TypeAliasType, + )) + + class TypeAliasType: + """Create named, parameterized type aliases. + + This provides a backport of the new `type` statement in Python 3.12: + + type ListOrSet[T] = list[T] | set[T] + + is equivalent to: + + T = TypeVar("T") + ListOrSet = TypeAliasType("ListOrSet", list[T] | set[T], type_params=(T,)) + + The name ListOrSet can then be used as an alias for the type it refers to. + + The type_params argument should contain all the type parameters used + in the value of the type alias. If the alias is not generic, this + argument is omitted. + + Static type checkers should only support type aliases declared using + TypeAliasType that follow these rules: + + - The first argument (the name) must be a string literal. + - The TypeAliasType instance must be immediately assigned to a variable + of the same name. 
(For example, 'X = TypeAliasType("Y", int)' is invalid, + as is 'X, Y = TypeAliasType("X", int), TypeAliasType("Y", int)'). + + """ + + def __init__(self, name: str, value, *, type_params=()): + if not isinstance(name, str): + raise TypeError("TypeAliasType name must be a string") + self.__value__ = value + self.__type_params__ = type_params + + parameters = [] + for type_param in type_params: + if isinstance(type_param, TypeVarTuple): + parameters.extend(type_param) + else: + parameters.append(type_param) + self.__parameters__ = tuple(parameters) + def_mod = _caller() + if def_mod != 'typing_extensions': + self.__module__ = def_mod + # Setting this attribute closes the TypeAliasType from further modification + self.__name__ = name + + def __setattr__(self, __name: str, __value: object) -> None: + if hasattr(self, "__name__"): + self._raise_attribute_error(__name) + super().__setattr__(__name, __value) + + def __delattr__(self, __name: str) -> Never: + self._raise_attribute_error(__name) + + def _raise_attribute_error(self, name: str) -> Never: + # Match the Python 3.12 error messages exactly + if name == "__name__": + raise AttributeError("readonly attribute") + elif name in {"__value__", "__type_params__", "__parameters__", "__module__"}: + raise AttributeError( + f"attribute '{name}' of 'typing.TypeAliasType' objects " + "is not writable" + ) + else: + raise AttributeError( + f"'typing.TypeAliasType' object has no attribute '{name}'" + ) + + def __repr__(self) -> str: + return self.__name__ + + def __getitem__(self, parameters): + if not isinstance(parameters, tuple): + parameters = (parameters,) + parameters = [ + typing._type_check( + item, f'Subscripting {self.__name__} requires a type.' + ) + for item in parameters + ] + return typing._GenericAlias(self, tuple(parameters)) + + def __reduce__(self): + return self.__name__ + + def __init_subclass__(cls, *args, **kwargs): + raise TypeError( + "type 'typing_extensions.TypeAliasType' is not an acceptable base type" + ) + + # The presence of this method convinces typing._type_check + # that TypeAliasTypes are types. + def __call__(self): + raise TypeError("Type alias is not callable") + + if sys.version_info >= (3, 10): + def __or__(self, right): + # For forward compatibility with 3.12, reject Unions + # that are not accepted by the built-in Union. + if not _is_unionable(right): + return NotImplemented + return typing.Union[self, right] + + def __ror__(self, left): + if not _is_unionable(left): + return NotImplemented + return typing.Union[left, self] + + +if hasattr(typing, "is_protocol"): + is_protocol = typing.is_protocol + get_protocol_members = typing.get_protocol_members +else: + def is_protocol(__tp: type) -> bool: + """Return True if the given type is a Protocol. + + Example:: + + >>> from typing_extensions import Protocol, is_protocol + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> is_protocol(P) + True + >>> is_protocol(int) + False + """ + return ( + isinstance(__tp, type) + and getattr(__tp, '_is_protocol', False) + and __tp != Protocol + ) + + def get_protocol_members(__tp: type) -> typing.FrozenSet[str]: + """Return the set of members defined in a Protocol. + + Example:: + + >>> from typing_extensions import Protocol, get_protocol_members + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> get_protocol_members(P) + frozenset({'a', 'b'}) + + Raise a TypeError for arguments that are not Protocols. 
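+
+        For example::
+
+            >>> get_protocol_members(int)
+            Traceback (most recent call last):
+            ...
+            TypeError: <class 'int'> is not a Protocol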
+ """ + if not is_protocol(__tp): + raise TypeError(f'{__tp!r} is not a Protocol') + if hasattr(__tp, '__protocol_attrs__'): + return frozenset(__tp.__protocol_attrs__) + return frozenset(_get_protocol_attrs(__tp)) + + +# Aliases for items that have always been in typing. +# Explicitly assign these (rather than using `from typing import *` at the top), +# so that we get a CI error if one of these is deleted from typing.py +# in a future version of Python +AbstractSet = typing.AbstractSet +AnyStr = typing.AnyStr +BinaryIO = typing.BinaryIO +Callable = typing.Callable +Collection = typing.Collection +Container = typing.Container +Dict = typing.Dict +ForwardRef = typing.ForwardRef +FrozenSet = typing.FrozenSet +Generator = typing.Generator +Generic = typing.Generic +Hashable = typing.Hashable +IO = typing.IO +ItemsView = typing.ItemsView +Iterable = typing.Iterable +Iterator = typing.Iterator +KeysView = typing.KeysView +List = typing.List +Mapping = typing.Mapping +MappingView = typing.MappingView +Match = typing.Match +MutableMapping = typing.MutableMapping +MutableSequence = typing.MutableSequence +MutableSet = typing.MutableSet +Optional = typing.Optional +Pattern = typing.Pattern +Reversible = typing.Reversible +Sequence = typing.Sequence +Set = typing.Set +Sized = typing.Sized +TextIO = typing.TextIO +Tuple = typing.Tuple +Union = typing.Union +ValuesView = typing.ValuesView +cast = typing.cast +no_type_check = typing.no_type_check +no_type_check_decorator = typing.no_type_check_decorator diff --git a/metaflow/_vendor/v3_7/__init__.py b/metaflow/_vendor/v3_7/__init__.py deleted file mode 100644 index 22ae0c5f40e..00000000000 --- a/metaflow/_vendor/v3_7/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Empty file \ No newline at end of file diff --git a/metaflow/_vendor/vendor_any.txt b/metaflow/_vendor/vendor_any.txt index 9c48802aae6..d6ea255e3e1 100644 --- a/metaflow/_vendor/vendor_any.txt +++ b/metaflow/_vendor/vendor_any.txt @@ -1,2 +1,6 @@ click==7.1.2 packaging==23.0 +importlib_metadata==4.8.3 +typeguard==4.0.1 +typing_extensions==4.7.0 +zipp==3.6.0 diff --git a/metaflow/_vendor/vendor_v3_7.txt b/metaflow/_vendor/vendor_v3_7.txt deleted file mode 100644 index 7e7876b678f..00000000000 --- a/metaflow/_vendor/vendor_v3_7.txt +++ /dev/null @@ -1 +0,0 @@ -zipp==3.6.0 \ No newline at end of file diff --git a/metaflow/_vendor/v3_7/zipp.LICENSE b/metaflow/_vendor/zipp.LICENSE similarity index 100% rename from metaflow/_vendor/v3_7/zipp.LICENSE rename to metaflow/_vendor/zipp.LICENSE diff --git a/metaflow/_vendor/v3_7/zipp.py b/metaflow/_vendor/zipp.py similarity index 100% rename from metaflow/_vendor/v3_7/zipp.py rename to metaflow/_vendor/zipp.py diff --git a/metaflow/cmd/develop/stubs.py b/metaflow/cmd/develop/stubs.py index 9bd409b3092..49c00486c3f 100644 --- a/metaflow/cmd/develop/stubs.py +++ b/metaflow/cmd/develop/stubs.py @@ -23,6 +23,8 @@ def _check_stubs_supported(): if _py_ver >= (3, 4): if _py_ver >= (3, 8): from importlib import metadata + elif _py_ver >= (3, 7): + from metaflow._vendor import importlib_metadata as metadata elif _py_ver >= (3, 6): from metaflow._vendor.v3_6 import importlib_metadata as metadata else: diff --git a/metaflow/extension_support/__init__.py b/metaflow/extension_support/__init__.py index 99b7de39207..ffd50603e60 100644 --- a/metaflow/extension_support/__init__.py +++ b/metaflow/extension_support/__init__.py @@ -262,6 +262,8 @@ def multiload_all(modules, extension_point, dst_globals): if _py_ver >= (3, 8): from importlib import metadata + elif _py_ver >= 
(3, 7): + from metaflow._vendor import importlib_metadata as metadata elif _py_ver >= (3, 6): from metaflow._vendor.v3_6 import importlib_metadata as metadata else: diff --git a/metaflow/vendor.py b/metaflow/vendor.py index ebabd2e8f70..56b952324cc 100644 --- a/metaflow/vendor.py +++ b/metaflow/vendor.py @@ -13,7 +13,6 @@ "vendor_any.txt", "vendor_v3_5.txt", "vendor_v3_6.txt", - "vendor_v3_7.txt", "pip.LICENSE", } diff --git a/test/core/run_tests.py b/test/core/run_tests.py index dec44bcabe8..51f24af3ae3 100644 --- a/test/core/run_tests.py +++ b/test/core/run_tests.py @@ -11,6 +11,18 @@ from metaflow._vendor import click +skip_api_executor = False + +try: + from metaflow.click_api import ( + MetaflowAPI, + extract_all_params, + click_to_python_types, + ) + from metaflow.cli import start, run +except RuntimeError: + skip_api_executor = True + from metaflow_test import MetaflowTest from metaflow_test.formatter import FlowFormatter