|
| 1 | +""" |
| 2 | +attribute module containing helper functions for instrumented attribute. |
| 3 | +""" |
| 4 | + |
| 5 | +from functools import lru_cache |
| 6 | +from inspect import isclass |
| 7 | + |
| 8 | +from sqlalchemy.orm import ColumnProperty, RelationshipProperty |
| 9 | +from sqlalchemy.orm.attributes import InstrumentedAttribute, get_attribute, set_attribute |
| 10 | + |
| 11 | + |
| 12 | +def instrumented_attribute(class_or_instance, key: str): |
| 13 | + """ |
| 14 | + Returns instrumented attribute from the class or instance. |
| 15 | + """ |
| 16 | + |
| 17 | + if isclass(class_or_instance): |
| 18 | + return getattr(class_or_instance, key) |
| 19 | + |
| 20 | + return getattr(class_or_instance.__class__, key) |
| 21 | + |
| 22 | + |
| 23 | +def attr_is_relationship(instrumented_attr: InstrumentedAttribute): |
| 24 | + """ |
| 25 | + Check if instrumented attribute property is a RelationshipProperty |
| 26 | + """ |
| 27 | + return isinstance(instrumented_attr.property, RelationshipProperty) |
| 28 | + |
| 29 | + |
| 30 | +def attr_is_column(instrumented_attr: InstrumentedAttribute): |
| 31 | + """ |
| 32 | + Check if instrumented attribute property is a ColumnProperty |
| 33 | + """ |
| 34 | + return isinstance(instrumented_attr.property, ColumnProperty) |
| 35 | + |
| 36 | + |
| 37 | +def set_instance_attribute(instance, key, value): |
| 38 | + """ |
| 39 | + Set attribute value of instance |
| 40 | + """ |
| 41 | + |
| 42 | + instr_attr: InstrumentedAttribute = getattr(instance.__class__, key) |
| 43 | + |
| 44 | + if attr_is_relationship(instr_attr) and instr_attr.property.uselist: |
| 45 | + get_attribute(instance, key).append(value) |
| 46 | + else: |
| 47 | + set_attribute(instance, key, value) |
| 48 | + |
| 49 | +@lru_cache() |
| 50 | +def foreign_key_column(instrumented_attr: InstrumentedAttribute): |
| 51 | + """ |
| 52 | + Returns the table name of the first foreignkey. |
| 53 | + """ |
| 54 | + return next(iter(instrumented_attr.foreign_keys)).column |
| 55 | + |
| 56 | +@lru_cache() |
| 57 | +def referenced_class(instrumented_attr: InstrumentedAttribute): |
| 58 | + """ |
| 59 | + Returns class that the attribute is referenced to. |
| 60 | + """ |
| 61 | + |
| 62 | + if attr_is_relationship(instrumented_attr): |
| 63 | + return instrumented_attr.mapper.class_ |
| 64 | + |
| 65 | + table_name = foreign_key_column(instrumented_attr).table.name |
| 66 | + |
| 67 | + return next(filter( |
| 68 | + lambda mapper: mapper.class_.__tablename__ == table_name, |
| 69 | + instrumented_attr.parent.registry.mappers |
| 70 | + )).class_ |
0 commit comments