24 changes: 23 additions & 1 deletion README.md
@@ -190,7 +190,7 @@ docker run --rm -e LANGEXTRACT_API_KEY="your-api-key" langextract python your_sc

## API Key Setup for Cloud Models

-When using LangExtract with cloud-hosted models (like Gemini), you'll need to
When using LangExtract with cloud-hosted models (like Gemini or OpenAI), you'll need to
set up an API key. On-device models don't require an API key. For developers
using local LLMs, LangExtract offers built-in support for Ollama and can be
extended to other third-party APIs by updating the inference endpoints.
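
For local inference, a rough sketch of the Ollama path (illustrative only; the exact class and constructor parameters are defined in `langextract/inference.py` and may differ):

```python
from langextract import inference

# Hypothetical usage -- the class and parameter names here are assumptions.
model = inference.OllamaLanguageModel(model='gemma2:2b')
```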
@@ -201,6 +201,7 @@ Get API keys from:

* [AI Studio](https://aistudio.google.com/app/apikey) for Gemini models
* [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/sdks/overview) for enterprise use
* [OpenAI Platform](https://platform.openai.com/api-keys) for OpenAI models

### Setting up API key in your environment
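
A minimal sketch using `python-dotenv` (a declared dependency) to load the key from a `.env` file:

```python
import os

from dotenv import load_dotenv

load_dotenv()  # read variables from a local .env file into the environment
api_key = os.environ.get("LANGEXTRACT_API_KEY")
```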

@@ -250,6 +251,27 @@ result = lx.extract(
)
```

## Using OpenAI Models

LangExtract also supports OpenAI models. Example OpenAI configuration:

```python
import os

import langextract as lx
from langextract.inference import OpenAILanguageModel

result = lx.extract(
text_or_documents=input_text,
prompt_description=prompt,
examples=examples,
language_model_type=OpenAILanguageModel,
model_id="gpt-4o",
api_key=os.environ.get('OPENAI_API_KEY'),
fence_output=True,
use_schema_constraints=False
)
```

Note: OpenAI models require `fence_output=True` and `use_schema_constraints=False` because LangExtract doesn't implement schema constraints for OpenAI yet.
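
For intuition, here is a minimal sketch of what fence extraction involves (illustrative only; the real logic lives in `resolver.py`):

```python
import re

# Illustrative only: extract the raw payload from a ```json ... ``` fence.
_FENCE_RE = re.compile(r"```(?:json|yaml)?\s*(.*?)```", re.DOTALL)

def strip_fences(raw: str) -> str:
    """Return the fenced payload if present, else the stripped input."""
    match = _FENCE_RE.search(raw)
    return match.group(1).strip() if match else raw.strip()
```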

## More Examples

Additional examples of LangExtract in action:
170 changes: 169 additions & 1 deletion langextract/inference.py
@@ -24,6 +24,7 @@
from typing import Any

from google import genai
import openai
import requests
from typing_extensions import override
import yaml
@@ -383,7 +384,174 @@ def infer(
yield [result]

def parse_output(self, output: str) -> Any:
"""Parses Gemini output as JSON or YAML."""
"""Parses Gemini output as JSON or YAML.

Note: This expects raw JSON/YAML without code fences.
Code fence extraction is handled by resolver.py.
"""
try:
if self.format_type == data.FormatType.JSON:
return json.loads(output)
else:
return yaml.safe_load(output)
except Exception as e:
raise ValueError(
f'Failed to parse output as {self.format_type.name}: {str(e)}'
) from e


@dataclasses.dataclass(init=False)
class OpenAILanguageModel(BaseLanguageModel):
"""Language model inference using OpenAI's API with structured output."""

model_id: str = 'gpt-4o-mini'
api_key: str | None = None
organization: str | None = None
format_type: data.FormatType = data.FormatType.JSON
temperature: float = 0.0
max_workers: int = 10
_client: openai.OpenAI | None = dataclasses.field(
default=None, repr=False, compare=False
)
_extra_kwargs: dict[str, Any] = dataclasses.field(
default_factory=dict, repr=False, compare=False
)

def __init__(
self,
model_id: str = 'gpt-4o-mini',
api_key: str | None = None,
organization: str | None = None,
format_type: data.FormatType = data.FormatType.JSON,
temperature: float = 0.0,
max_workers: int = 10,
**kwargs,
) -> None:
"""Initialize the OpenAI language model.

Args:
model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o').
api_key: API key for OpenAI service.
organization: Optional OpenAI organization ID.
format_type: Output format (JSON or YAML).
temperature: Sampling temperature.
max_workers: Maximum number of parallel API calls.
**kwargs: Ignored extra parameters so callers can pass a superset of
arguments shared across back-ends without raising ``TypeError``.
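
    Example (illustrative):
      model = OpenAILanguageModel(api_key='your-api-key')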
"""
self.model_id = model_id
self.api_key = api_key
self.organization = organization
self.format_type = format_type
self.temperature = temperature
self.max_workers = max_workers
self._extra_kwargs = kwargs or {}

if not self.api_key:
raise ValueError('API key not provided.')

# Initialize the OpenAI client
self._client = openai.OpenAI(
api_key=self.api_key, organization=self.organization
)

super().__init__(
constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE)
)

def _process_single_prompt(self, prompt: str, config: dict) -> ScoredOutput:
"""Process a single prompt and return a ScoredOutput."""
try:
# Prepare the system message for structured output
system_message = ''
if self.format_type == data.FormatType.JSON:
system_message = (
'You are a helpful assistant that responds in JSON format.'
)
elif self.format_type == data.FormatType.YAML:
system_message = (
'You are a helpful assistant that responds in YAML format.'
)

# Create the chat completion using the v1.x client API
response = self._client.chat.completions.create(
model=self.model_id,
messages=[
{'role': 'system', 'content': system_message},
{'role': 'user', 'content': prompt},
],
temperature=config.get('temperature', self.temperature),
max_tokens=config.get('max_output_tokens'),
top_p=config.get('top_p'),
n=1,
)

# Extract the response text using the v1.x response format
output_text = response.choices[0].message.content

return ScoredOutput(score=1.0, output=output_text)

except Exception as e:
raise InferenceOutputError(f'OpenAI API error: {str(e)}') from e

def infer(
self, batch_prompts: Sequence[str], **kwargs
) -> Iterator[Sequence[ScoredOutput]]:
"""Runs inference on a list of prompts via OpenAI's API.

Args:
batch_prompts: A list of string prompts.
**kwargs: Additional generation params (temperature, top_p, etc.)

Yields:
Lists of ScoredOutputs.
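
    Example (illustrative):
      model = OpenAILanguageModel(api_key='your-api-key')
      for scored_outputs in model.infer(['Return {"done": true}']):
        print(scored_outputs[0].output)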
"""
config = {
'temperature': kwargs.get('temperature', self.temperature),
}
if 'max_output_tokens' in kwargs:
config['max_output_tokens'] = kwargs['max_output_tokens']
if 'top_p' in kwargs:
config['top_p'] = kwargs['top_p']

# Use parallel processing for batches larger than 1
if len(batch_prompts) > 1 and self.max_workers > 1:
with concurrent.futures.ThreadPoolExecutor(
max_workers=min(self.max_workers, len(batch_prompts))
) as executor:
future_to_index = {
executor.submit(
self._process_single_prompt, prompt, config.copy()
): i
for i, prompt in enumerate(batch_prompts)
}

results: list[ScoredOutput | None] = [None] * len(batch_prompts)
for future in concurrent.futures.as_completed(future_to_index):
index = future_to_index[future]
try:
results[index] = future.result()
except Exception as e:
raise InferenceOutputError(
f'Parallel inference error: {str(e)}'
) from e

for result in results:
if result is None:
raise InferenceOutputError('Failed to process one or more prompts')
yield [result]
else:
# Sequential processing for single prompt or worker
for prompt in batch_prompts:
result = self._process_single_prompt(prompt, config.copy())
yield [result]

def parse_output(self, output: str) -> Any:
"""Parses OpenAI output as JSON or YAML.

Note: This expects raw JSON/YAML without code fences.
Code fence extraction is handled by resolver.py.
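
    Example (illustrative):
      model.parse_output('{"name": "John", "age": 30}')
      # -> {'name': 'John', 'age': 30}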
"""
try:
if self.format_type == data.FormatType.JSON:
return json.loads(output)
      else:
        return yaml.safe_load(output)
    except Exception as e:
      raise ValueError(
          f'Failed to parse output as {self.format_type.name}: {str(e)}'
      ) from e
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -18,7 +18,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "langextract"
version = "1.0.2"
version = "1.0.3"
description = "LangExtract: A library for extracting structured data from language models"
readme = "README.md"
requires-python = ">=3.10"
@@ -35,10 +35,11 @@ dependencies = [
"ml-collections>=0.1.0",
"more-itertools>=8.0.0",
"numpy>=1.20.0",
"openai>=0.27.0",
"openai>=1.50.0",
"pandas>=1.3.0",
"pydantic>=1.8.0",
"python-dotenv>=0.19.0",
"PyYAML>=6.0",
"requests>=2.25.0",
"tqdm>=4.64.0",
"typing-extensions>=4.0.0"
@@ -55,9 +56,8 @@ dev = [
"pyink~=24.3.0",
"isort>=5.13.0",
"pylint>=3.0.0",
"pytest>=7.4.0",
"pytype>=2024.10.11",
"tox>=4.0.0",
"tox>=4.0.0"
]
test = [
"pytest>=7.4.0",
116 changes: 116 additions & 0 deletions tests/inference_test.py
@@ -16,6 +16,7 @@

from absl.testing import absltest

from langextract import data
from langextract import inference


@@ -93,5 +94,120 @@ def test_ollama_infer(self, mock_ollama_query):
self.assertEqual(results, expected_results)


class TestOpenAILanguageModel(absltest.TestCase):

@mock.patch("openai.OpenAI")
def test_openai_infer(self, mock_openai_class):
# Mock the OpenAI client and chat completion response
mock_client = mock.Mock()
mock_openai_class.return_value = mock_client

# Mock response structure for v1.x API
mock_response = mock.Mock()
mock_response.choices = [
mock.Mock(message=mock.Mock(content='{"name": "John", "age": 30}'))
]
mock_client.chat.completions.create.return_value = mock_response

# Create model instance
model = inference.OpenAILanguageModel(
model_id="gpt-4o-mini", api_key="test-api-key", temperature=0.5
)

# Test inference
batch_prompts = ["Extract name and age from: John is 30 years old"]
results = list(model.infer(batch_prompts))

# Verify API was called correctly
mock_client.chat.completions.create.assert_called_once_with(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": (
"You are a helpful assistant that responds in JSON format."
),
},
{
"role": "user",
"content": "Extract name and age from: John is 30 years old",
},
],
temperature=0.5,
max_tokens=None,
top_p=None,
n=1,
)

# Check results
expected_results = [[
inference.ScoredOutput(score=1.0, output='{"name": "John", "age": 30}')
]]
self.assertEqual(results, expected_results)

def test_openai_parse_output_json(self):
model = inference.OpenAILanguageModel(
api_key="test-key", format_type=data.FormatType.JSON
)

# Test valid JSON parsing
output = '{"key": "value", "number": 42}'
parsed = model.parse_output(output)
self.assertEqual(parsed, {"key": "value", "number": 42})

# Test invalid JSON
with self.assertRaises(ValueError) as context:
model.parse_output("invalid json")
self.assertIn("Failed to parse output as JSON", str(context.exception))

def test_openai_parse_output_yaml(self):
model = inference.OpenAILanguageModel(
api_key="test-key", format_type=data.FormatType.YAML
)

# Test valid YAML parsing
output = "key: value\nnumber: 42"
parsed = model.parse_output(output)
self.assertEqual(parsed, {"key": "value", "number": 42})

# Test invalid YAML
with self.assertRaises(ValueError) as context:
model.parse_output("invalid: yaml: bad")
self.assertIn("Failed to parse output as YAML", str(context.exception))

def test_openai_no_api_key_raises_error(self):
with self.assertRaises(ValueError) as context:
inference.OpenAILanguageModel(api_key=None)
self.assertEqual(str(context.exception), "API key not provided.")

@mock.patch("openai.OpenAI")
def test_openai_temperature_zero(self, mock_openai_class):
# Test that temperature=0.0 is properly passed through
mock_client = mock.Mock()
mock_openai_class.return_value = mock_client

mock_response = mock.Mock()
mock_response.choices = [
mock.Mock(message=mock.Mock(content='{"result": "test"}'))
]
mock_client.chat.completions.create.return_value = mock_response

model = inference.OpenAILanguageModel(
api_key="test-key", temperature=0.0 # Testing zero temperature
)

list(model.infer(["test prompt"]))

# Verify temperature=0.0 was passed to the API
mock_client.chat.completions.create.assert_called_with(
model="gpt-4o-mini",
messages=mock.ANY,
temperature=0.0,
max_tokens=None,
top_p=None,
n=1,
)


if __name__ == "__main__":
absltest.main()