Skip to content

Commit b9bf288

Browse files
authored
Merge pull request #32 from jedymatt/remove-class-registry
Remove class registry
2 parents 90b5923 + e5b99e8 commit b9bf288

File tree

5 files changed

+87
-151
lines changed

5 files changed

+87
-151
lines changed

src/sqlalchemyseed/class_registry.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

src/sqlalchemyseed/errors.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,13 @@ class UnsupportedClassError(Exception):
4141
class NotInModuleError(Exception):
4242
"""Raised when a value is not found in module"""
4343
pass
44+
45+
46+
class InvalidModelPath(Exception):
47+
"""Raised when an invalid model path is invoked"""
48+
pass
49+
50+
51+
class UnsupportedClassError(Exception):
52+
"""Raised when an unsupported class is invoked"""
53+
pass

src/sqlalchemyseed/seeder.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from sqlalchemy.orm.relationships import RelationshipProperty
3232
from sqlalchemy.sql import schema
3333

34-
from . import class_registry, validator, errors, util
34+
from . import validator, errors, util
3535

3636

3737
class AbstractSeeder(abc.ABC):
@@ -146,7 +146,6 @@ class Seeder(AbstractSeeder):
146146

147147
def __init__(self, session: sqlalchemy.orm.Session = None, ref_prefix="!"):
148148
self.session = session
149-
self._class_registry = class_registry.ClassRegistry()
150149
self._instances = []
151150
self.ref_prefix = ref_prefix
152151

@@ -156,15 +155,14 @@ def instances(self):
156155

157156
def get_model_class(self, entity, parent: Entity):
158157
if self.__model_key in entity:
159-
return self._class_registry.register_class(entity[self.__model_key])
158+
return util.get_model_class(entity[self.__model_key])
160159
# parent is not None
161160
return parent.referenced_class
162161

163162
def seed(self, entities, add_to_session=True):
164163
validator.validate(entities=entities, ref_prefix=self.ref_prefix)
165164

166165
self._instances.clear()
167-
self._class_registry.clear()
168166

169167
self._pre_seed(entities)
170168

@@ -231,7 +229,6 @@ class HybridSeeder(AbstractSeeder):
231229

232230
def __init__(self, session: sqlalchemy.orm.Session, ref_prefix: str = '!'):
233231
self.session = session
234-
self._class_registry = class_registry.ClassRegistry()
235232
self._instances = []
236233
self.ref_prefix = ref_prefix
237234

@@ -245,7 +242,7 @@ def get_model_class(self, entity, parent: Entity):
245242

246243
if self.__model_key in entity:
247244
class_path = entity[self.__model_key]
248-
return self._class_registry.register_class(class_path)
245+
return util.get_model_class(class_path)
249246

250247
# parent is not None
251248
return parent.referenced_class
@@ -255,7 +252,6 @@ def seed(self, entities):
255252
entities=entities, ref_prefix=self.ref_prefix)
256253

257254
self._instances.clear()
258-
self._class_registry.clear()
259255

260256
self._pre_seed(entities)
261257

src/sqlalchemyseed/util.py

Lines changed: 74 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,55 @@
33
"""
44

55

6+
from functools import lru_cache
7+
import importlib
8+
from typing import Iterable
9+
610
from sqlalchemy import inspect
11+
from sqlalchemyseed import errors
712

813

914
def iter_ref_kwargs(kwargs: dict, ref_prefix: str):
10-
"""Iterate kwargs with name prefix or references"""
15+
"""
16+
Iterate kwargs with name prefix or references
17+
"""
1118
for attr_name, value in kwargs.items():
1219
if attr_name.startswith(ref_prefix):
1320
# removed prefix
1421
yield attr_name[len(ref_prefix):], value
1522

1623

24+
def iter_kwargs_with_prefix(kwargs: dict, prefix: str):
25+
"""
26+
Iterate kwargs(dict) that has the specified prefix.
27+
"""
28+
for key, value in kwargs.items():
29+
if str(key).startswith(prefix):
30+
yield key, value
31+
32+
33+
def iterate_json(json: dict, key_prefix: str):
34+
"""
35+
Iterate through json that has matching key prefix
36+
"""
37+
for key, value in json.items():
38+
has_prefix = str(key).startswith(key_prefix)
39+
40+
if has_prefix:
41+
# removed prefix
42+
yield key[len(key_prefix):], value
43+
44+
45+
def iterate_json_no_prefix(json: dict, key_prefix: str):
46+
"""
47+
Iterate through json that has no matching key prefix
48+
"""
49+
for key, value in json.items():
50+
has_prefix = str(key).startswith(key_prefix)
51+
if not has_prefix:
52+
yield key, value
53+
54+
1755
def iter_non_ref_kwargs(kwargs: dict, ref_prefix: str):
1856
"""Iterate kwargs, skipping item with name prefix or references"""
1957
for attr_name, value in kwargs.items():
@@ -33,22 +71,44 @@ def is_supported_class(class_):
3371
def generate_repr(instance: object) -> str:
3472
"""
3573
Generate repr of object instance
36-
37-
Example:
38-
```
39-
class Person(Base):
40-
...
41-
def __repr__(self):
42-
return generate_repr(self)
43-
```
44-
45-
Output format:
46-
```
47-
"<Person(id='1',name='John Doe')>"
48-
```
4974
"""
5075
class_name = instance.__class__.__name__
5176
insp = inspect(instance)
5277
attributes = {column.key: column.value for column in insp.attrs}
5378
str_attributes = ",".join(f"{k}='{v}'" for k, v in attributes.items())
5479
return f"<{class_name}({str_attributes})>"
80+
81+
82+
def find_item(json: Iterable, keys: list):
83+
"""
84+
Finds item of json from keys
85+
"""
86+
return find_item(json[keys[0]], keys[1:]) if keys else json
87+
88+
89+
# check if class is a sqlalchemy model
90+
def is_model(class_):
91+
"""
92+
Check if class is a sqlalchemy model
93+
"""
94+
insp = inspect(class_, raiseerr=False)
95+
return insp is not None and insp.is_mapper
96+
97+
98+
# get sqlalchemy model class from path
99+
@lru_cache(maxsize=None)
100+
def get_model_class(path: str):
101+
"""
102+
Get sqlalchemy model class from path
103+
"""
104+
try:
105+
module_name, class_name = path.rsplit(".", 1)
106+
module = importlib.import_module(module_name)
107+
except (ImportError, AttributeError) as e:
108+
raise errors.InvalidModelPath(path=path, error=e)
109+
110+
class_ = getattr(module, class_name)
111+
if not is_model(class_):
112+
raise errors.UnsupportedClassError(path=path)
113+
114+
return class_

tests/test_class_registry.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

0 commit comments

Comments
 (0)