-
Notifications
You must be signed in to change notification settings - Fork 484
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add phone number and zip code custom types
- Loading branch information
Showing
10 changed files
with
147 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Type constraints | ||
|
||
We can ask completions to be restricted to valid python types: | ||
|
||
```python | ||
from outlines import models, generate | ||
|
||
model = models.transformers("mistralai/Mistral-7B-v0.1") | ||
generator = generate.format(model, int) | ||
answer = generator("When I was 6 my sister was half my age. Now I’m 70 how old is my sister?") | ||
print(answer) | ||
# 67 | ||
``` | ||
|
||
The following types are currently available: | ||
|
||
- int | ||
- float | ||
- bool | ||
- datetime.date | ||
- datetime.time | ||
- datetime.datetime | ||
- We also provide [custom types](types.md) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,49 @@ | ||
# Type constraints | ||
# Custom types | ||
|
||
We can ask completions to be restricted to valid python types: | ||
Outlines provides custom Pydantic types so you can focus on your use case rather than on writing regular expressions: | ||
|
||
- Using `outlines.types.ZipCode` will generate valid US Zip(+4) codes. | ||
- Using `outlines.types.PhoneNumber` will generate valid US phone numbers. | ||
|
||
You can use these types in Pydantic schemas for JSON-structured generation: | ||
|
||
```python | ||
from pydantic import BaseModel | ||
|
||
from outlines import models, generate, types | ||
|
||
|
||
class Client(BaseModel): | ||
name: str | ||
phone_number: types.PhoneNumber | ||
zip_code: types.ZipCode | ||
|
||
|
||
model = models.transformers("mistralai/Mistral-7B-v0.1") | ||
generator = generate.json(model, Client) | ||
result = generator( | ||
"Create a client profile with the fields name, phone_number and zip_code" | ||
) | ||
print(result) | ||
# name='Tommy' phone_number='129-896-5501' zip_code='50766' | ||
``` | ||
|
||
Or simply with `outlines.generate.format`: | ||
|
||
```python | ||
from outlines import models, generate | ||
from pydantic import BaseModel | ||
|
||
from outlines import models, generate, types | ||
|
||
|
||
model = models.transformers("mistralai/Mistral-7B-v0.1") | ||
generator = generate.format(model, int) | ||
answer = generator("When I was 6 my sister was half my age. Now I’m 70 how old is my sister?") | ||
print(answer) | ||
# 67 | ||
generator = generate.format(model, types.PhoneNumber) | ||
result = generator( | ||
"Return a US Phone number: " | ||
) | ||
print(result) | ||
# 334-253-2630 | ||
``` | ||
|
||
The following types are currently available: | ||
|
||
- int | ||
- float | ||
- bool | ||
- datetime.date | ||
- datetime.time | ||
- datetime.datetime | ||
We plan on adding many more custom types. If you have found yourself writing regular expressions to generate fields of a given type, or if you could benefit from more specific types don't hesite to [submit a PR](https://github.com/outlines-dev/outlines/pulls) or [open an issue](https://github.com/outlines-dev/outlines/issues/new/choose). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .phone_numbers import PhoneNumber | ||
from .zip_codes import ZipCode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
"""Phone number types. | ||
We currently only support US phone numbers. We can however imagine having custom types | ||
for each country, for instance leveraging the `phonenumbers` library. | ||
""" | ||
from pydantic import WithJsonSchema | ||
from typing_extensions import Annotated | ||
|
||
US_PHONE_NUMBER = r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}" | ||
|
||
|
||
PhoneNumber = Annotated[ | ||
str, | ||
WithJsonSchema({"type": "string", "pattern": US_PHONE_NUMBER}), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""Zip code types. | ||
We currently only support US Zip Codes. | ||
""" | ||
from pydantic import WithJsonSchema | ||
from typing_extensions import Annotated | ||
|
||
# This matches Zip and Zip+4 codes | ||
US_ZIP_CODE = r"\d{5}(?:-\d{4})?" | ||
|
||
|
||
ZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import re | ||
|
||
import pytest | ||
from pydantic import BaseModel | ||
|
||
from outlines import types | ||
from outlines.fsm.types import python_types_to_regex | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"custom_type,test_string,should_match", | ||
[ | ||
(types.PhoneNumber, "12", False), | ||
(types.PhoneNumber, "(123) 123-1234", True), | ||
(types.PhoneNumber, "123-123-1234", True), | ||
(types.ZipCode, "12", False), | ||
(types.ZipCode, "12345", True), | ||
(types.ZipCode, "12345-1234", True), | ||
], | ||
) | ||
def test_phone_number(custom_type, test_string, should_match): | ||
class Model(BaseModel): | ||
attr: custom_type | ||
|
||
schema = Model.model_json_schema() | ||
assert schema["properties"]["attr"]["type"] == "string" | ||
regex_str = schema["properties"]["attr"]["pattern"] | ||
does_match = re.match(regex_str, test_string) is not None | ||
assert does_match is should_match | ||
|
||
regex_str, format_fn = python_types_to_regex(custom_type) | ||
assert isinstance(format_fn(1), str) | ||
does_match = re.match(regex_str, test_string) is not None | ||
assert does_match is should_match |