Skip to content

Commit

Permalink
Separate compatibility shim from canonical EntryPoints container.
Browse files Browse the repository at this point in the history
  • Loading branch information
jaraco committed Feb 22, 2021
1 parent e3d1b93 commit 9d55a33
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions importlib_metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,23 +165,12 @@ class EntryPoints(tuple):

__slots__ = ()

def __getitem__(self, name) -> Union[EntryPoint, 'EntryPoints']:
def __getitem__(self, name): # -> EntryPoint:
try:
match = next(iter(self.select(name=name)))
return match
return next(iter(self.select(name=name)))
except StopIteration:
if name in self.groups:
return self._group_getitem(name)
raise KeyError(name)

def _group_getitem(self, name):
"""
For backward compatability, supply .__getitem__ for groups.
"""
msg = "GroupedEntryPoints.__getitem__ is deprecated for groups. Use select."
warnings.warn(msg, DeprecationWarning)
return self.select(group=name)

def select(self, **params):
return EntryPoints(ep for ep in self if ep.matches(**params))

Expand All @@ -193,6 +182,23 @@ def names(self):
def groups(self):
return set(ep.group for ep in self)

@classmethod
def _from_text_for(cls, text, dist):
return cls(ep._for(dist) for ep in EntryPoint._from_text(text))


class LegacyGroupedEntryPoints(EntryPoints):
def __getitem__(self, name) -> Union[EntryPoint, 'EntryPoints']:
try:
return super().__getitem__(name)
except KeyError:
if name not in self.groups:
raise

msg = "GroupedEntryPoints.__getitem__ is deprecated for groups. Use select."
warnings.warn(msg, DeprecationWarning)
return self.select(group=name)

def get(self, group, default=None):
"""
For backward compatibility, supply .get
Expand All @@ -202,9 +208,10 @@ def get(self, group, default=None):
is_flake8 or warnings.warn(msg, DeprecationWarning)
return self.select(group=group) or default

@classmethod
def _from_text_for(cls, text, dist):
return cls(ep._for(dist) for ep in EntryPoint._from_text(text))
def select(self, **params):
if not params:
return self
return super().select(**params)


class PackagePath(pathlib.PurePosixPath):
Expand Down Expand Up @@ -704,7 +711,7 @@ def entry_points(**params):
eps = itertools.chain.from_iterable(
dist.entry_points for dist in unique(distributions())
)
return EntryPoints(eps).select(**params)
return LegacyGroupedEntryPoints(eps).select(**params)


def files(distribution_name):
Expand Down

0 comments on commit 9d55a33

Please sign in to comment.