Skip to content

Commit

Permalink
Refactor case type check
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardt committed Mar 20, 2024
1 parent 28a177f commit 7f97843
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions magma/sum_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,16 @@ def activated_cases(self):
def num_tags(self):
return len(self._tag_map)

def _check_case_type(self, T):
if T.undirected_t not in self._tag_map:
raise TypeError(f"Unexpected case type {T}")

def activate_case(self, T):
if self._active_case is not None:
raise TypeError("Cannot have more than one active case")
if any(T is x for x in self._activated_cases):
raise TypeError(f"Cannot call case({T}) twice")
if isinstance(T, Kind):
if T.undirected_t not in self._tag_map:
raise TypeError(f"Unexpected case type {T}")
else:
if not isinstance(T, Type):
raise TypeError(f"Unexpected case type {T}")
self._check_case_type(T)

self._active_case = T
self._activated_cases.append(T)
Expand Down Expand Up @@ -250,3 +249,7 @@ def __new__(mcs, name, bases, namespace):
class Enum2(Sum, metaclass=Enum2Meta):
def _get_tag(self, driver):
return self._tag_map[driver]

def _check_case_type(self, value):
if not isinstance(value, Type):
raise TypeError(f"Unexpected case value {value}")

0 comments on commit 7f97843

Please sign in to comment.