-
Notifications
You must be signed in to change notification settings - Fork 0
Datalake [2/3?]: Dataset and to_HF() #271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
57990e6
26814e9
6cc7cb5
da0bc39
da79c1a
559b901
2c4b7b4
ddd8944
85ca358
2a67de5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| from .datalake import Datalake | ||
| from .types import Datum | ||
| from .dataset import Dataset | ||
| from .datum import Datum | ||
|
|
||
| __all__ = ["Datalake", "Datum"] | ||
| __all__ = ["Datalake", "Datum", "Dataset"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| import pathlib | ||
| from typing import Any | ||
|
|
||
| from datasets import Image, List, Sequence, Value | ||
|
|
||
# Mapping from datalake contract names to the Hugging Face `datasets`
# feature types used when exporting to a Hugging Face dataset.
# "default", "regression" and "segmentation" have no HF mapping yet.
contracts_to_hf_type = {
    # A single image, loaded from a file path.
    "image": Image(),
    # A predicted label plus its confidence in [0, 1].
    "classification": {
        "label": Value("string"),
        "confidence": Value("float"),
    },
    # A variable-length list of fixed-length-4 [x1, y1, x2, y2] boxes.
    "bbox": {
        "bbox": List(Sequence(Value("float"), length=4)),
    },
}
|
|
||
|
|
||
def validate_contract(data: Any, contract: str) -> None:
    """Validate *data* against a named datalake contract.

    Supported contracts:
        - "default": any data is accepted.
        - "image": data must be a ``pathlib.Path`` to an image file.
        - "classification": data must be a dict with a "label" key and a
          numeric "confidence" key in [0, 1].
        - "bbox": data must be a dict with a "bbox" key holding a list of
          [x1, y1, x2, y2] boxes (non-negative numbers, x2 >= x1, y2 >= y1).
        - "regression", "segmentation": accepted without validation (TODO).

    Args:
        data: The payload to validate; expected shape depends on ``contract``.
        contract: Name of the contract to validate against.

    Raises:
        ValueError: If ``data`` does not satisfy ``contract``, or if
            ``contract`` is not a supported contract name.
    """

    def _is_number(value: Any) -> bool:
        # bool is a subclass of int, but True/False are not valid numeric data.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if contract == "default":
        pass
    elif contract == "image":
        # PosixPath/WindowsPath are subclasses of Path, so Path alone suffices
        # (and also covers Windows, which the old PosixPath check did not name).
        if not isinstance(data, pathlib.Path):
            raise ValueError(f"Data must be a path to an image, got {type(data)}")
        # TODO: check if this is actually an image
    elif contract == "classification":
        if not isinstance(data, dict):
            raise ValueError(f"Data must be a dictionary, got {type(data)}")
        if "label" not in data:
            raise ValueError("Data must contain a 'label' key")
        if "confidence" not in data:
            raise ValueError("Data must contain a 'confidence' key")
        # Accept ints as well as floats (e.g. a confidence of exactly 0 or 1).
        if not _is_number(data["confidence"]):
            raise ValueError(f"Confidence must be a float, got {type(data['confidence'])}")
        if data["confidence"] < 0 or data["confidence"] > 1:
            raise ValueError("Confidence must be between 0 and 1")
    elif contract == "bbox":
        if not isinstance(data, dict):
            raise ValueError(f"Data must be a dictionary, got {type(data)}")
        if "bbox" not in data:
            raise ValueError("Data must contain a 'bbox' key")
        if not isinstance(data["bbox"], list):
            raise ValueError(f"Bbox must be a list, got {type(data['bbox'])}")
        for entry in data["bbox"]:
            if not isinstance(entry, list):
                raise ValueError(f"Bbox must be a list of lists, got {type(entry)}")
            if len(entry) != 4:
                raise ValueError("Bbox must be a list of lists of 4 elements")
            # Accept ints as well as floats for pixel coordinates.
            if not all(_is_number(x) for x in entry):
                raise ValueError("Bbox must be a list of lists of floats")
            x1, y1, x2, y2 = entry
            # Coordinates are (x1, y1, x2, y2): all non-negative, box non-inverted.
            if x1 < 0 or y1 < 0 or x2 < 0 or y2 < 0:
                raise ValueError("Bbox coordinates must be non-negative")
            if x2 < x1 or y2 < y1:
                raise ValueError("Bbox must have x2 >= x1 and y2 >= y1")
    elif contract == "regression":
        pass
    elif contract == "segmentation":
        pass
    else:
        raise ValueError(f"Unsupported contract: {contract}")
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,9 @@ | |
| from mindtrace.core import Mindtrace | ||
| from mindtrace.database import MongoMindtraceODMBackend | ||
| from mindtrace.database.core.exceptions import DocumentNotFoundError | ||
| from mindtrace.datalake.types import Datum | ||
| from mindtrace.datalake.contracts import validate_contract | ||
| from mindtrace.datalake.dataset import Dataset | ||
| from mindtrace.datalake.datum import Datum | ||
| from mindtrace.registry import Registry | ||
| from mindtrace.registry.backends.local_registry_backend import LocalRegistryBackend | ||
|
|
||
|
|
@@ -50,14 +52,42 @@ def __init__(self, mongo_db_uri: str, mongo_db_name: str) -> None: | |
| db_uri=self.mongo_db_uri, | ||
| ) | ||
| self.registries: Dict[str, Registry] = {} | ||
| self.dataset_database: MongoMindtraceODMBackend[Dataset] = MongoMindtraceODMBackend[Dataset]( | ||
| model_cls=Dataset, | ||
| db_name=self.mongo_db_name, | ||
| db_uri=self.mongo_db_uri, | ||
| ) | ||
|
|
||
| async def initialize(self): | ||
| """ | ||
| Initialize the datalake by setting up database connections. | ||
|
|
||
| This method initializes both the datum database and dataset database backends, | ||
| establishing connections to MongoDB and preparing them for use. | ||
|
|
||
| Raises: | ||
| Exception: If database initialization fails | ||
| """ | ||
| await self.datum_database.initialize() | ||
| await self.dataset_database.initialize() | ||
|
|
||
| @classmethod | ||
| async def create(cls, mongo_db_uri: str, mongo_db_name: str) -> "Datalake": | ||
| """ | ||
| Create a Datalake instance from a configuration dictionary. | ||
| Create and initialize a Datalake instance. | ||
|
|
||
| This is a convenience class method that creates a Datalake instance | ||
| and initializes it in a single call. | ||
|
|
||
| Args: | ||
| mongo_db_uri: MongoDB connection URI | ||
| mongo_db_name: Name of the MongoDB database to use | ||
|
|
||
| Returns: | ||
| An initialized Datalake instance ready for use | ||
|
|
||
| Raises: | ||
| Exception: If database initialization fails | ||
| """ | ||
| datalake = cls(mongo_db_uri=mongo_db_uri, mongo_db_name=mongo_db_name) | ||
| await datalake.initialize() | ||
|
|
@@ -67,21 +97,47 @@ async def add_datum( | |
| self, | ||
| data: Any, | ||
| metadata: Dict[str, Any], | ||
| contract: Optional[str] = None, | ||
| registry_uri: Optional[str] = None, | ||
| derived_from: Optional[PydanticObjectId] = None, | ||
| project_id: Optional[str] = None, | ||
| line_id: Optional[str] = None, | ||
| ) -> Datum: | ||
| """ | ||
| Add a datum to the datalake asynchronously. | ||
|
|
||
| This method validates the data according to the specified contract, | ||
| stores it either in the database or in a registry backend, and returns | ||
| the created datum with an assigned ID. | ||
|
|
||
| Args: | ||
| data: The data to store | ||
| metadata: Metadata associated with the datum | ||
| registry_uri: Optional registry URI for external storage | ||
| derived_from: Optional ID of the parent datum | ||
| data: The data to store. Format depends on the contract: | ||
| - "image": Must be a pathlib.Path or pathlib.PosixPath to an image file | ||
| - "classification": Must be a dict with "label" (str) and "confidence" (float 0-1) | ||
| - "bbox": Must be a dict with "bbox" key containing a list of lists of 4 floats | ||
| - "default": Any data type | ||
| metadata: Metadata dictionary associated with the datum | ||
| contract: Optional contract type specifying the data format. If None, defaults to "default". | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. "type" might be better naming for this usage?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For me "type" has a precise meaning in Python which these "contract"s are very closely related to but not exactly, and so it's useful to use a different word?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. oh ok, I thought contracts were type primitives (classification etc.), thus suggested as such. In that case, does it make sense coupling some stuff under a single contract, or in cases where we produce multiple outputs from a model, should we commit multiple datums or aggregate them within a single contract and data field? Like if a visual inspection model produces 3 classification fields and 1 severity field, what should we do in general? |
||
| Supported contracts: "default", "image", "classification", "bbox", "regression", "segmentation" | ||
| registry_uri: Optional registry URI for external storage. If provided, data is stored | ||
| in the registry backend instead of the database. | ||
| derived_from: Optional ID of the parent datum this datum was derived from | ||
| project_id: Name of the project this datum belongs to | ||
| line_id: Name of the line this datum belongs to | ||
|
|
||
| Returns: | ||
| The created datum with assigned ID | ||
| The created Datum instance with assigned ID | ||
|
|
||
| Raises: | ||
| ValueError: If data doesn't match the contract requirements | ||
| Exception: If database or registry operations fail | ||
| """ | ||
|
|
||
| if contract is None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think on datum we can have a required project_id/line_id key to help discriminate between unrelated tasks so we don't accidentally get it while querying, or potentially to speed up queries. Datasets would have them too, I'm assuming.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yep, keep meaning to do this, done. |
||
| contract = "default" | ||
|
|
||
| validate_contract(data, contract) | ||
|
|
||
| if registry_uri: | ||
| # Store in registry | ||
| uuid = str(uuid4()) | ||
|
|
@@ -94,6 +150,9 @@ async def add_datum( | |
| registry_key=uuid, | ||
| derived_from=derived_from, | ||
| metadata=metadata, | ||
| contract=contract, | ||
| project_id=project_id or "default_project", | ||
| line_id=line_id or "default_line", | ||
| ) | ||
| else: | ||
| # Store in database | ||
|
|
@@ -103,6 +162,9 @@ async def add_datum( | |
| registry_key=None, | ||
| derived_from=derived_from, | ||
| metadata=metadata, | ||
| contract=contract, | ||
| project_id=project_id or "default_project", | ||
| line_id=line_id or "default_line", | ||
| ) | ||
| inserted_datum = await self.datum_database.insert(datum) | ||
| return inserted_datum | ||
|
|
@@ -208,6 +270,79 @@ async def get_indirectly_derived_data(self, datum_id: PydanticObjectId) -> list[ | |
|
|
||
| return result | ||
|
|
||
| async def add_dataset(self, dataset: Dataset) -> Dataset: | ||
| """ | ||
| Add a dataset to the datalake asynchronously. | ||
|
|
||
| Args: | ||
| dataset: The Dataset instance to store | ||
|
|
||
| Returns: | ||
| The created dataset with assigned ID | ||
|
|
||
| Raises: | ||
| Exception: If database insert fails | ||
| """ | ||
| inserted_dataset = await self.dataset_database.insert(dataset) | ||
| return inserted_dataset | ||
|
|
||
| async def get_dataset(self, dataset_id: PydanticObjectId | None) -> Dataset: | ||
| """ | ||
| Retrieve a dataset by its ID. | ||
|
|
||
| Args: | ||
| dataset_id: The unique identifier of the dataset to retrieve | ||
|
|
||
| Returns: | ||
| The dataset if found | ||
|
|
||
| Raises: | ||
| DocumentNotFoundError: If the dataset is not found | ||
| """ | ||
| if dataset_id is None: | ||
| raise DocumentNotFoundError("Dataset ID is None") | ||
| dataset = await self.dataset_database.get(dataset_id) | ||
| return dataset | ||
|
|
||
| async def get_datasets(self, dataset_ids: list[PydanticObjectId]) -> list[Dataset]: | ||
| """ | ||
| Retrieve multiple datasets by their IDs. | ||
|
|
||
| Args: | ||
| dataset_ids: List of unique identifiers of the datasets to retrieve | ||
|
|
||
| Returns: | ||
| List of datasets | ||
|
|
||
| Raises: | ||
| Exception: If database queries fail | ||
| """ | ||
| return await asyncio.gather(*[self.get_dataset(dataset_id) for dataset_id in dataset_ids]) | ||
|
|
||
| async def find_datasets(self, filter: dict[str, Any] | None = None) -> list[Dataset]: | ||
| """ | ||
| Find datasets matching the given filter. | ||
|
|
||
| This method searches for datasets using a MongoDB-style filter dictionary. | ||
| If no filter is provided, returns all datasets in the database. | ||
|
|
||
| Args: | ||
| filter: MongoDB-style filter dictionary. Examples: | ||
| - {"name": "my_dataset"} - find datasets with specific name | ||
| - {"metadata.project": "test_project"} - find datasets by metadata | ||
| - None - returns all datasets | ||
|
|
||
| Returns: | ||
| List of Dataset instances matching the filter | ||
|
|
||
| Raises: | ||
| Exception: If database query fails | ||
| """ | ||
| if filter is None: | ||
| filter = {} | ||
| datasets = await self.dataset_database.find(filter) | ||
| return list(datasets) | ||
|
|
||
| @overload | ||
| async def query_data( | ||
| self, query: list[dict[str, Any]] | dict[str, Any], datums_wanted: int | None = None, transpose: bool = False | ||
|
|
@@ -339,12 +474,12 @@ async def query_data( | |
| by using MongoDB's native aggregation capabilities instead of multiple round trips. | ||
|
|
||
| Args: | ||
| query: Same syntax as query_data - list of queries or single query | ||
| query: Same syntax as query_data_legacy - list of queries or single query | ||
| datums_wanted: Maximum number of results to return | ||
| transpose: Whether to return dict of lists (True) or list of dicts (False) | ||
|
|
||
| Returns: | ||
| Same format as query_data - list of dictionaries or dictionary of lists | ||
| Same format as query_data_legacy - list of dictionaries or dictionary of lists | ||
|
|
||
| Note: | ||
| This optimized version handles common cases but may fall back to the original | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, I'm not super clear on how the Database module works on the back end. Having a single DB with two different docs seems like what we want but not what's currently doable with the MT Database?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc: @mazen-elabd
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@canelbirlik Sorry just seen this comment,, The initialize is deprecated now. Still works, but deprecated. Will make sure to look into this and support it.