|
| 1 | +from dataclasses import dataclass |
| 2 | +from typing import Any |
| 3 | + |
| 4 | +from DPF.filters.data_filter import DataFilter |
| 5 | +from DPF.modalities import ModalityName |
| 6 | +from DPF.types import ModalityToDataMapping |
| 7 | + |
| 8 | + |
| 9 | +@dataclass |
| 10 | +class ComplexFilterPreprocessedData: |
| 11 | + key: str |
| 12 | + preprocessed_values: dict[int, Any] |
| 13 | + |
| 14 | + |
| 15 | +class ComplexDataFilter(DataFilter): |
| 16 | + |
| 17 | + def __init__( |
| 18 | + self, |
| 19 | + datafilters: list[DataFilter], |
| 20 | + workers: int, |
| 21 | + pbar: bool = True, |
| 22 | + _pbar_position: int = 0 |
| 23 | + ): |
| 24 | + super().__init__(pbar, _pbar_position) |
| 25 | + self.datafilters = datafilters |
| 26 | + self.workers = workers |
| 27 | + |
| 28 | + assert len(self.datafilters) > 0 |
| 29 | + assert all( |
| 30 | + i.key_column == self.datafilters[0].key_column for i in self.datafilters |
| 31 | + ) # check all filters have same key col |
| 32 | + |
| 33 | + @property |
| 34 | + def modalities(self) -> list[ModalityName]: |
| 35 | + modals = [] |
| 36 | + for datafilter in self.datafilters: |
| 37 | + modals.extend(datafilter.modalities) |
| 38 | + return list(set(modals)) |
| 39 | + |
| 40 | + @property |
| 41 | + def key_column(self) -> str: |
| 42 | + return self.datafilters[0].key_column |
| 43 | + |
| 44 | + @property |
| 45 | + def metadata_columns(self) -> list[str]: |
| 46 | + meta_cols = [] |
| 47 | + for datafilter in self.datafilters: |
| 48 | + meta_cols.extend(datafilter.metadata_columns) |
| 49 | + return list(set(meta_cols)) |
| 50 | + |
| 51 | + @property |
| 52 | + def result_columns(self) -> list[str]: |
| 53 | + result_cols = [] |
| 54 | + for datafilter in self.datafilters: |
| 55 | + result_cols.extend(datafilter.result_columns) |
| 56 | + return list(set(result_cols)) |
| 57 | + |
| 58 | + @property |
| 59 | + def dataloader_kwargs(self) -> dict[str, Any]: |
| 60 | + return { |
| 61 | + "num_workers": self.workers, |
| 62 | + "batch_size": 1, |
| 63 | + "drop_last": False, |
| 64 | + } |
| 65 | + |
| 66 | + def preprocess_data( |
| 67 | + self, |
| 68 | + modality2data: ModalityToDataMapping, |
| 69 | + metadata: dict[str, Any] |
| 70 | + ) -> Any: |
| 71 | + key = metadata[self.key_column] |
| 72 | + preprocessed_results = {} |
| 73 | + for ind, datafilter in enumerate(self.datafilters): |
| 74 | + preprocessed_results[ind] = datafilter.preprocess_data(modality2data, metadata) |
| 75 | + return ComplexFilterPreprocessedData(key, preprocessed_results) |
| 76 | + |
| 77 | + def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]: |
| 78 | + results = {} |
| 79 | + preprocessed_data: ComplexFilterPreprocessedData = batch[0] |
| 80 | + for ind, datafilter in enumerate(self.datafilters): |
| 81 | + filter_results_data = datafilter.process_batch([preprocessed_data.preprocessed_values[ind]]) |
| 82 | + for col, value in filter_results_data.items(): |
| 83 | + results[col] = value |
| 84 | + results[self.key_column] = [preprocessed_data.key] |
| 85 | + |
| 86 | + return results |
0 commit comments