Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 53 additions & 17 deletions src/biotite/structure/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def filter_polymer(array, min_size=2, pol_type="peptide"):
return np.concatenate(list(bool_idx))


def filter_intersection(array, intersect):
def filter_intersection(array, intersect, categories=None):
"""
Filter all atoms of one array that exist also in another array.

Expand All @@ -414,6 +414,10 @@ def filter_intersection(array, intersect):
The array to be filtered.
intersect : AtomArray
Atoms in `array` that also exists in `intersect` are filtered.
categories : iterable of str
If specified, the given annotation categories are checked for equality in both
arrays.
By default, all common annotation categories are checked.

Returns
-------
Expand All @@ -434,23 +438,30 @@ def filter_intersection(array, intersect):
>>> print(array1.chain_id)
['B' 'C' 'D']
"""
filter = np.full(array.array_length(), True, dtype=bool)
intersect_categories = intersect.get_annotation_categories()
# Check atom equality only for categories,
# which exist in both arrays
categories = [
category
for category in array.get_annotation_categories()
if category in intersect_categories
]
for i in range(array.array_length()):
subfilter = np.full(intersect.array_length(), True, dtype=bool)
if categories is None:
intersect_categories = intersect.get_annotation_categories()
# Check atom equality only for categories,
# which exist in both arrays
categories = [
category
for category in array.get_annotation_categories()
if category in intersect_categories
]
else:
for category in categories:
subfilter &= (
intersect.get_annotation(category) == array.get_annotation(category)[i]
)
filter[i] = subfilter.any()
return filter
if category not in array.get_annotation_categories():
raise ValueError(f"Category {category} does not exist in 'array'")
if category not in intersect.get_annotation_categories():
raise ValueError(f"Category {category} does not exist in 'intersect'")

# Implicitly expect that the annotation array dtypes are the same for both
structured_dtype = np.dtype(
[(name, array.get_annotation(name).dtype) for name in categories]
)
array_annotations = _annotations_to_structured(array, structured_dtype)
intersect_annotations = _annotations_to_structured(intersect, structured_dtype)
# Identify the intersection of the two annotation arrays
return np.isin(array_annotations, intersect_annotations)


def filter_first_altloc(atoms, altloc_ids):
Expand Down Expand Up @@ -608,3 +619,28 @@ def filter_highest_occupancy_altloc(atoms, altloc_ids, occupancies):
pass

return altloc_filter


def _annotations_to_structured(atoms, structured_dtype):
"""
Convert atom annotations into a single structured `ndarray`.

Parameters
----------
atoms : AtomArray, shape=(n,)
The annotation arrays are taken from this structure.
structured_dtype : dtype
The dtype of the structured array to be created.
The fields of the dtype determine which annotations are taken from `atoms`.

Returns
-------
structured : ndarray, shape=(n,), dtype=structured_dtype
The structured array.
"""
if structured_dtype.fields is None:
raise TypeError("dtype must be structured")
structured = np.zeros(atoms.array_length(), dtype=structured_dtype)
for field in structured_dtype.fields:
structured[field] = atoms.get_annotation(field)
return structured
Loading