490 changes: 490 additions & 0 deletions docs/features/metadata_tracking.md

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions main.py
@@ -147,6 +147,14 @@ def check_model_availability(task_name):
default=False,
)

parser.add_argument(
"--disable_metadata",
"-dm",
type=ast.literal_eval,
default=False,
help="Disable metadata collection (default: False)",
)

args = parser.parse_args()

start = time.time()
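
Note: because the flag is parsed with type=ast.literal_eval, the CLI expects a Python literal ("True"/"False") rather than a free-form string. A minimal sketch of that behavior, reusing only the parser and flag names from the diff (everything else is illustrative):

import argparse
import ast

parser = argparse.ArgumentParser()
parser.add_argument("--disable_metadata", "-dm", type=ast.literal_eval, default=False)

# The string "True" is evaluated as a Python literal, so the parsed value is a real bool.
args = parser.parse_args(["--disable_metadata", "True"])
print(args.disable_metadata, type(args.disable_metadata))  # True <class 'bool'>
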
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -48,6 +48,8 @@ nav:
- Output Record Generator: concepts/processors/output_record_generator/README.md
- Schema Validator: concepts/schema_validator/README.md
- Structured Output: concepts/structured_output/README.md
- Features:
- Metadata Tracking: features/metadata_tracking.md
- Tutorials:
- Agent Simulation: tutorials/agent_simulation_tutorial.md
- Agent Simulation with Tools: tutorials/agent_tool_simulation_tutorial.md
1 change: 1 addition & 0 deletions sygra/__init__.py
@@ -301,6 +301,7 @@ def get_model_info(model_name: str) -> dict[str, Any]:

# Main exports
__all__ = [
"__version__",
# Main classes
"Workflow",
"create_graph",
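
Since __version__ is now listed in __all__, it ships with the package's public exports; a quick hedged check (the printed value depends on the installed release, the one shown is a placeholder):

import sygra

print(sygra.__version__)  # e.g. "0.1.0" -- placeholder, not the actual version
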
124 changes: 119 additions & 5 deletions sygra/core/base_task_executor.py
@@ -25,6 +25,7 @@
from sygra.core.graph.langgraph.graph_builder import LangGraphBuilder
from sygra.core.graph.sygra_state import SygraState
from sygra.logger.logger_config import logger
from sygra.metadata.metadata_collector import get_metadata_collector
from sygra.processors.output_record_generator import BaseOutputGenerator
from sygra.tools.toolkits.data_quality.processor import DataQuality
from sygra.utils import constants, utils
@@ -60,6 +61,55 @@ def __init__(self, args: Any, graph_config_dict: Optional[dict] = None):
graph_properties=graph_props,
)
self.output_generator: Optional[BaseOutputGenerator] = self._init_output_generator()
self._init_metadata_collector(args)

def _init_metadata_collector(self, args):
"""Initialize metadata collection for this execution."""
collector = get_metadata_collector()

# Check if metadata collection should be disabled
disable_metadata = getattr(args, "disable_metadata", False)
if disable_metadata:
collector.set_enabled(False)
logger.info("Metadata collection disabled via --disable_metadata flag")
return

# Reset collector to clear any data from health checks or previous runs
collector.reset()
collector.set_execution_context(
task_name=self.task_name,
run_name=getattr(args, "run_name", None),
output_dir=getattr(args, "output_dir", None),
batch_size=getattr(args, "batch_size", 50),
checkpoint_interval=getattr(args, "checkpoint_interval", 100),
resumable=getattr(args, "resume", False),
debug=getattr(args, "debug", False),
)

# Set dataset metadata if available
if self.source_config is not None:
source_path = None
# Try to get source path from repo_id (HuggingFace) or file_path (local files)
if self.source_config.repo_id:
source_path = self.source_config.repo_id
elif self.source_config.file_path:
source_path = self.source_config.file_path

# Use dataset version and hash captured before transformations
dataset_version = getattr(self, "_dataset_version", None)
dataset_hash = getattr(self, "_dataset_hash", None)

collector.set_dataset_metadata(
source_type=(
str(self.source_config.type.value)
if hasattr(self.source_config.type, "value")
else str(self.source_config.type)
),
source_path=source_path,
start_index=getattr(args, "start_index", 0),
dataset_version=dataset_version,
dataset_hash=dataset_hash,
)

@staticmethod
def _configure_resume_behavior(args: Any, config_resumable: bool) -> bool:
@@ -328,6 +378,9 @@ def _load_source_data(
reader = self._get_data_reader()
full_data = self._read_data(reader)

# Capture dataset metadata from reader (which stores it before conversion)
self._capture_dataset_metadata(full_data, reader)

# Apply transformations to the dataset
full_data = self.apply_transforms(self.source_config, full_data)

@@ -340,6 +393,43 @@

return full_data

def _capture_dataset_metadata(self, dataset: Any, reader: Any) -> None:
"""Capture dataset version and hash before transformations."""
try:
# First try to get from reader (HuggingFaceHandler stores it before conversion)
if hasattr(reader, "dataset_version") and hasattr(reader, "dataset_hash"):
self._dataset_version = reader.dataset_version
self._dataset_hash = reader.dataset_hash
logger.debug(
f"Captured dataset metadata from reader: version={self._dataset_version}, hash={self._dataset_hash}"
)
return

# Fallback: try to extract from dataset object directly
import datasets

if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)):
self._dataset_version = None
self._dataset_hash = None

# Try to get version info
if hasattr(dataset, "info") and dataset.info:
version_obj = getattr(dataset.info, "version", None)
if version_obj:
self._dataset_version = str(version_obj)

# Try to get fingerprint/hash
if hasattr(dataset, "_fingerprint"):
self._dataset_hash = dataset._fingerprint
elif hasattr(dataset, "n_shards"):
self._dataset_hash = f"iterable_{dataset.n_shards}_shards"

logger.debug(
f"Captured dataset metadata from dataset: version={self._dataset_version}, hash={self._dataset_hash}"
)
except Exception as e:
logger.debug(f"Could not capture dataset metadata: {e}")

def _generate_empty_dataset(self) -> list[dict]:
"""Generate empty dataset with specified number of records"""
num_records = self.args.num_records
@@ -533,11 +623,14 @@ def execute(self):
logger.info("Graph compiled successfully")
logger.info("\n" + compiled_graph.get_graph().draw_ascii())

ts_suffix = (
""
if not self.args.output_with_ts
else "_" + str(datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
)
# Create timestamp for output file
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
ts_suffix = "" if not self.args.output_with_ts else "_" + run_timestamp

# Update metadata collector with the run timestamp so metadata filename matches output filename
if self.args.output_with_ts:
collector = get_metadata_collector()
collector.execution_context.run_timestamp = run_timestamp

num_records_total = self.args.num_records
if isinstance(self.dataset, list):
@@ -657,6 +750,27 @@ def execute(self):
if dataset_processor.resume_manager:
dataset_processor.resume_manager.force_save_state(is_final=True)

self._save_metadata(dataset_processor)

def _save_metadata(self, dataset_processor=None):
"""Finalize and save execution metadata."""
try:
from sygra.metadata.metadata_collector import get_metadata_collector

collector = get_metadata_collector()

# Update dataset metadata with actual processed count
if dataset_processor:
collector.dataset_metadata.num_records_processed = (
dataset_processor.num_records_processed + dataset_processor.failed_records
)

collector.finalize_execution()
metadata_path = collector.save_metadata()
logger.info(f"Run metadata saved to: {metadata_path}")
except Exception as e:
logger.warning(f"Failed to save metadata: {e}")


class DefaultTaskExecutor(BaseTaskExecutor):
"""
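
The collector module itself (sygra/metadata/metadata_collector.py) is not part of this diff, so its exact API is not shown; the call sites above only assume a process-wide singleton with enable/reset/record hooks. A hypothetical, much-reduced sketch of that shape, not the real implementation:

# Hypothetical sketch: names mirror the call sites in the diff, everything else is assumed.
class MetadataCollector:
    def __init__(self):
        self.enabled = True
        self.success_count = 0
        self.failure_count = 0

    def set_enabled(self, enabled: bool) -> None:
        self.enabled = enabled

    def reset(self) -> None:
        # Drop any counts accumulated by health checks or previous runs.
        self.enabled = True
        self.success_count = 0
        self.failure_count = 0

    def record_processed_record(self, success: bool) -> None:
        if not self.enabled:
            return
        if success:
            self.success_count += 1
        else:
            self.failure_count += 1

_collector = None

def get_metadata_collector() -> MetadataCollector:
    # Lazily create one shared collector per process.
    global _collector
    if _collector is None:
        _collector = MetadataCollector()
    return _collector
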
9 changes: 9 additions & 0 deletions sygra/core/dataset/dataset_processor.py
Expand Up @@ -15,6 +15,7 @@
from sygra.core.resumable_execution import ResumableExecutionManager
from sygra.data_mapper.mapper import DataMapper
from sygra.logger.logger_config import logger
from sygra.metadata.metadata_collector import get_metadata_collector
from sygra.utils import constants, graph_utils, multimodal_processor, utils
from sygra.validators.schema_validator_base import SchemaValidator

@@ -200,6 +201,10 @@ async def _add_graph_result(self, output: dict[str, Any], record: dict[str, Any]
# Check if execution had an error - don't mark as processed if it did
if self.is_error_code_in_output(output):
self.failed_records += 1
# Record failed record in metadata collector
collector = get_metadata_collector()
collector.record_processed_record(success=False)

# For resumable execution, remove from in-process but don't mark as processed
if self.resumable and self.resume_manager:
logger.warning(
@@ -239,6 +244,10 @@ async def _add_graph_result(self, output: dict[str, Any], record: dict[str, Any]
self.graph_results.append(output)
self.num_records_processed += 1

# Record successful record in metadata collector
collector = get_metadata_collector()
collector.record_processed_record(success=True)

# all the code below should refer to total_records_with_error (not self.num_records_processed)
total_records_with_error = self.num_records_processed + self.failed_records

32 changes: 25 additions & 7 deletions sygra/core/dataset/huggingface_handler.py
@@ -15,13 +15,7 @@
from datasets import config as ds_config
from datasets import load_dataset
from datasets.utils.metadata import MetadataConfigs # type: ignore[import-untyped]
from huggingface_hub import (
CommitOperationAdd,
DatasetCard,
DatasetCardData,
HfApi,
HfFileSystem,
)
from huggingface_hub import CommitOperationAdd, DatasetCard, DatasetCardData, HfApi, HfFileSystem

from sygra.core.dataset.data_handler_base import DataHandler
from sygra.core.dataset.dataset_config import DataSourceConfig, OutputConfig
@@ -245,6 +239,28 @@ def _read_shard(self, path: str) -> list[dict[str, Any]]:
df = pd.read_parquet(io.BytesIO(f.read()))
return cast(list[dict[str, Any]], df.to_dict(orient="records"))

def _store_dataset_metadata(self, dataset: Dataset) -> None:
"""Store dataset metadata as instance variables for later retrieval."""
try:
self.dataset_version = None
self.dataset_hash = None

# Extract version from dataset info
if hasattr(dataset, "info") and dataset.info:
version_obj = getattr(dataset.info, "version", None)
if version_obj:
self.dataset_version = str(version_obj)

# Extract fingerprint/hash
if hasattr(dataset, "_fingerprint"):
self.dataset_hash = dataset._fingerprint

logger.debug(
f"Stored dataset metadata: version={self.dataset_version}, hash={self.dataset_hash}"
)
except Exception as e:
logger.debug(f"Could not store dataset metadata: {e}")

def _load_dataset_by_split(self, split) -> Union[Dataset, IterableDataset]:
"""Load dataset for a specific split."""
if not self.source_config:
@@ -279,6 +295,8 @@ def _read_dataset(self) -> Union[list[dict[str, Any]], Iterator[dict[str, Any]]]
return cast(Iterator[dict[str, Any]], ds)
else:
ds_concrete = cast(Dataset, ds)
# Store dataset metadata before converting to list (which loses metadata)
self._store_dataset_metadata(ds_concrete)
return cast(list[dict[str, Any]], ds_concrete.to_pandas().to_dict(orient="records"))

except Exception as e:
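
For reference, _store_dataset_metadata only reads attributes that the datasets library already exposes on a loaded Dataset object. An illustrative check (the dataset name is arbitrary and loading it requires network access):

from datasets import load_dataset

ds = load_dataset("imdb", split="train")
print(getattr(ds.info, "version", None))   # DatasetInfo version, e.g. 1.0.0
print(getattr(ds, "_fingerprint", None))   # internal Arrow fingerprint string
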
5 changes: 2 additions & 3 deletions sygra/core/graph/backend_factory.py
@@ -51,13 +51,12 @@ def create_multi_llm_runnable(self, llm_dict: dict, post_process):
pass

@abstractmethod
def create_weighted_sampler_runnable(self, weighted_sampler_function, attr_config):
def create_weighted_sampler_runnable(self, exec_wrapper):
"""
Abstract method to create weighted sampler runnable.
Args:
weighted_sampler_function: Weighted sampler function
attr_config: attributes from the weighted sampler node
exec_wrapper: Async function wrapper to execute
Returns:
Any: backend specific runnable object like Runnable for backend=Langgraph
9 changes: 4 additions & 5 deletions sygra/core/graph/langgraph/langgraph_factory.py
@@ -58,18 +58,17 @@ def create_multi_llm_runnable(self, llm_dict: dict, post_process):
runnable_inputs = {k: v.to_backend() for k, v in llm_dict.items()}
return RunnableParallel(**runnable_inputs) | RunnableLambda(post_process)

def create_weighted_sampler_runnable(self, weighted_sampler_function, attr_config):
def create_weighted_sampler_runnable(self, exec_wrapper):
"""
Abstract method to create weighted sampler runnable.
Create weighted sampler runnable.

Args:
weighted_sampler_function: Weighted sampler function
attr_config: attributes from the weighted sampler node
exec_wrapper: Async function wrapper to execute

Returns:
Any: backend specific runnable object like Runnable for backend=Langgraph
"""
return RunnableLambda(partial(weighted_sampler_function, attr_config))
return RunnableLambda(exec_wrapper)

def create_connector_runnable(self):
"""
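
With the new signature, callers pre-bind any node attributes themselves (for example via functools.partial) and pass a single async wrapper; the factory only has to wrap it in a RunnableLambda. A hedged sketch of the call shape, where the wrapper name and state contents are made up:

from langchain_core.runnables import RunnableLambda

async def exec_wrapper(state: dict) -> dict:
    # Stand-in for the weighted sampler: pick a value and merge it into the state.
    return {**state, "sampled_attr": "value"}

runnable = RunnableLambda(exec_wrapper)
# e.g. await runnable.ainvoke({"id": 1}) -> {"id": 1, "sampled_attr": "value"}
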