Skip to content

Add Synchronous processing #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafog/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.2.2"
__version__ = "3.3.0"
3 changes: 1 addition & 2 deletions datafog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .__about__ import __version__
from .config import OperationType
from .main import DataFog, OCRPIIAnnotator, TextPIIAnnotator
from .main import DataFog, TextPIIAnnotator
from .processing.image_processing.donut_processor import DonutProcessor
from .processing.image_processing.image_downloader import ImageDownloader
from .processing.image_processing.pytesseract_processor import PytesseractProcessor
Expand All @@ -13,7 +13,6 @@
"DonutProcessor",
"DataFog",
"ImageService",
"OCRPIIAnnotator",
"OperationType",
"SparkService",
"TextPIIAnnotator",
Expand Down
65 changes: 27 additions & 38 deletions datafog/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import logging
from logging import INFO
from typing import List

from .config import OperationType
Expand All @@ -10,7 +9,7 @@
from .services.text_service import TextService

logger = logging.getLogger("datafog_logger")
logger.setLevel(INFO)
logger.setLevel(logging.INFO)


class DataFog:
Expand All @@ -37,7 +36,7 @@
self.logger.info(f"Operations: {operations}")

async def run_ocr_pipeline(self, image_urls: List[str]):
"""Run the OCR pipeline asynchronously."""
"""Run the OCR pipeline asynchronously on a list of images provided via url."""
try:
extracted_text = await self.image_service.ocr_extract(image_urls)
self.logger.info(f"OCR extraction completed for {len(image_urls)} images.")
Expand All @@ -46,7 +45,7 @@
)

if OperationType.ANNOTATE_PII in self.operations:
annotated_text = await self.text_service.batch_annotate_texts(
annotated_text = await self.text_service.batch_annotate_text_async(
extracted_text
)
self.logger.info(
Expand All @@ -59,55 +58,45 @@
self.logger.error(f"Error in run_ocr_pipeline: {str(e)}")
raise

async def run_text_pipeline(self, texts: List[str]):
"""Run the text pipeline asynchronously."""
async def run_text_pipeline(self, str_list: List[str]):
"""Run the text pipeline asynchronously on a list of input text."""
try:
self.logger.info(f"Starting text pipeline with {len(texts)} texts.")
self.logger.info(f"Starting text pipeline with {len(str_list)} texts.")
if OperationType.ANNOTATE_PII in self.operations:
annotated_text = await self.text_service.batch_annotate_texts(texts)
annotated_text = await self.text_service.batch_annotate_text_async(
str_list
)
self.logger.info(
f"Text annotation completed with {len(annotated_text)} annotations."
)
return annotated_text

self.logger.info("No annotation operation found; returning original texts.")
return texts
return str_list

Check warning on line 75 in datafog/main.py

View check run for this annotation

Codecov / codecov/patch

datafog/main.py#L75

Added line #L75 was not covered by tests
except Exception as e:
self.logger.error(f"Error in run_text_pipeline: {str(e)}")
raise

def _add_attributes(self, attributes: dict):
"""Add multiple attributes."""
for key, value in attributes.items():
pass


class OCRPIIAnnotator:
def __init__(self):
self.image_service = ImageService(use_donut=True, use_tesseract=False)
self.text_annotator = SpacyPIIAnnotator.create()
self.spark_service: SparkService = None

async def run(self, image_urls: List[str], output_path=None):
def run_text_pipeline_sync(self, str_list: List[str]):
    """Run the text pipeline synchronously on a list of input text.

    Args:
        str_list: Texts to process.

    Returns:
        The result of ``self.text_service.batch_annotate_text_sync`` when
        ``OperationType.ANNOTATE_PII`` is among ``self.operations``;
        otherwise ``str_list`` itself, unchanged.

    Raises:
        Exception: Any error raised while processing is logged and re-raised.
    """
    try:
        self.logger.info(f"Starting text pipeline with {len(str_list)} texts.")
        if OperationType.ANNOTATE_PII in self.operations:
            annotated_text = self.text_service.batch_annotate_text_sync(str_list)
            self.logger.info(
                f"Text annotation completed with {len(annotated_text)} annotations."
            )
            return annotated_text

        self.logger.info("No annotation operation found; returning original texts.")
        return str_list
    except Exception as e:
        # Name this method rather than the async variant, so a log reader can
        # tell which pipeline actually failed.
        self.logger.error(f"Error in run_text_pipeline_sync: {str(e)}")
        raise

Check warning on line 95 in datafog/main.py

View check run for this annotation

Codecov / codecov/patch

datafog/main.py#L91-L95

Added lines #L91 - L95 were not covered by tests

finally:
# Ensure Spark resources are released
# self.spark_processor.spark.stop()
def _add_attributes(self, attributes: dict):
"""Add multiple attributes."""
for key, value in attributes.items():

Check warning on line 99 in datafog/main.py

View check run for this annotation

Codecov / codecov/patch

datafog/main.py#L99

Added line #L99 was not covered by tests
pass


Expand Down
24 changes: 18 additions & 6 deletions datafog/services/text_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,24 @@ class TextService:
def __init__(self):
self.annotator = SpacyPIIAnnotator.create()

async def annotate_text(self, text):
"""Asynchronously annotate a single piece of text."""
def annotate_text_sync(self, text):
    """Synchronously annotate a text string.

    Args:
        text: The text handed to ``self.annotator.annotate``.

    Returns:
        Whatever ``self.annotator.annotate(text)`` returns.
    """
    import logging

    # Progress goes to the logger instead of print(), and only the length is
    # reported: indexing text.split()[0] raised IndexError for empty or
    # whitespace-only input and echoed part of the input to stdout.
    log = logging.getLogger("datafog_logger")
    log.debug("Starting annotation of text (%d characters)", len(text))
    res = self.annotator.annotate(text)
    log.debug("Done annotating text (%d characters)", len(text))
    return res

def batch_annotate_text_sync(self, texts: list):
"""Synchronously annotate a list of text input."""
results = [self.annotate_text_sync(text) for text in texts]
return dict(zip(texts, results, strict=True))

async def annotate_text_async(self, text):
    """Annotate a single text string asynchronously.

    The annotator's ``annotate`` call is dispatched through
    ``asyncio.to_thread`` and its result is returned once awaited.
    """
    annotate = self.annotator.annotate
    outcome = await asyncio.to_thread(annotate, text)
    return outcome

async def batch_annotate_texts(self, texts: list):
"""Asynchronously annotate a batch of texts."""
tasks = [self.annotate_text(text) for text in texts]
async def batch_annotate_text_async(self, text: list):
"""Asynchronously annotate a list of text input."""
tasks = [self.annotate_text_async(txt) for txt in text]
results = await asyncio.gather(*tasks)
return dict(zip(texts, results, strict=True))
return dict(zip(text, results, strict=True))
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
long_description = f.read()

# Use a single source of truth for the version
__version__ = "3.2.2"
__version__ = "3.3.0"

project_urls = {
"Homepage": "https://datafog.ai",
Expand Down
15 changes: 15 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ def test_textpii_annotator():
# assert "Satya Nadella" in annotated_text[0].get("PER", []), "PII not annotated correctly."


def test_datafog_text_annotation_sync():
    """Check that the synchronous text pipeline reports the expected entities."""
    sample = ["Joe Biden is the President of the United States."]
    result = DataFog().run_text_pipeline_sync(sample)

    # A falsy result means the pipeline returned nothing at all.
    assert result

    expectations = (
        ("Joe Biden", "Joe Biden not found in annotated results."),
        ("the United States", "United States not found in annotated results."),
    )
    for entity, failure_message in expectations:
        assert search_nested_dict(result, entity), failure_message


@pytest.mark.asyncio
async def test_datafog_text_annotation():
"""Test DataFog class for text annotation."""
Expand Down
Loading