Skip to content

Commit

Permalink
Some refinements to signature, allowing instantiation and fixing lint…
Browse files Browse the repository at this point in the history
… issues.
  • Loading branch information
thomasahle committed Mar 3, 2024
1 parent 82524ba commit 6ba4130
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 150 deletions.
256 changes: 136 additions & 120 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,42 @@
import dsp
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo
from typing import Type, Union, Dict, Tuple
from typing import Type, Union, Dict, Tuple # noqa: UP035
import re

from dspy.signatures.field import InputField, OutputField, new_to_old_field


def signature_to_template(signature):
"""Convert from new to legacy format"""
def signature_to_template(signature) -> dsp.Template:
"""Convert from new to legacy format."""
return dsp.Template(
signature.instructions,
**{name: new_to_old_field(field) for name, field in signature.fields.items()},
)


def _default_instructions(cls):
inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields.keys()])
outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields.keys()])
def _default_instructions(cls) -> str:
inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields])
outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields])
return f"Given the fields {inputs_}, produce the fields {outputs_}."


class SignatureMeta(type(BaseModel)):
def __new__(mcs, name, bases, namespace, **kwargs):
def __call__(cls, *args, **kwargs): # noqa: ANN002
if cls is Signature:
return make_signature(*args, **kwargs)
return super().__call__(*args, **kwargs)

def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804
# Set `str` as the default type for all fields
raw_annotations = namespace.get("__annotations__", {})
for name, field in namespace.items():
for name, _field in namespace.items():
if not name.startswith("__") and name not in raw_annotations:
raw_annotations[name] = str
namespace["__annotations__"] = raw_annotations

# Let Pydantic do its thing
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
cls = super().__new__(mcs, signature_name, bases, namespace, **kwargs)

if cls.__doc__ is None:
cls.__doc__ = _default_instructions(cls)
Expand Down Expand Up @@ -69,17 +74,20 @@ def signature(cls) -> str:
def instructions(cls) -> str:
return getattr(cls, "__doc__", "")

def with_instructions(cls, instructions: str):
def with_instructions(cls, instructions: str) -> Type["Signature"]:
return Signature(cls.fields, instructions)

@property
def fields(cls):
def fields(cls) -> dict[str, FieldInfo]:
# Make sure to give input fields before output fields
return {**cls.input_fields, **cls.output_fields}

def with_updated_fields(cls, name, type_=None, **kwargs):
"""Returns a new Signature type with the field, name, updated
with fields[name].json_schema_extra[key] = value."""
def with_updated_fields(cls, name, type_=None, **kwargs) -> Type["Signature"]:
"""Update the field, name, in a new Signature type.
Returns a new Signature type with the field, name, updated
with fields[name].json_schema_extra[key] = value.
"""
fields_copy = deepcopy(cls.fields)
fields_copy[name].json_schema_extra = {
**fields_copy[name].json_schema_extra,
Expand All @@ -90,27 +98,23 @@ def with_updated_fields(cls, name, type_=None, **kwargs):
return Signature(fields_copy, cls.instructions)

@property
def input_fields(cls):
def input_fields(cls) -> dict[str, FieldInfo]:
return cls._get_fields_with_type("input")

@property
def output_fields(cls):
def output_fields(cls) -> dict[str, FieldInfo]:
return cls._get_fields_with_type("output")

def _get_fields_with_type(cls, field_type):
return {
k: v
for k, v in cls.model_fields.items()
if v.json_schema_extra["__dspy_field_type"] == field_type
}
def _get_fields_with_type(cls, field_type) -> dict[str, FieldInfo]:
return {k: v for k, v in cls.model_fields.items() if v.json_schema_extra["__dspy_field_type"] == field_type}

def prepend(cls, name, field, type_=None):
def prepend(cls, name, field, type_=None) -> Type["Signature"]:
return cls.insert(0, name, field, type_)

def append(cls, name, field, type_=None):
def append(cls, name, field, type_=None) -> Type["Signature"]:
return cls.insert(-1, name, field, type_)

def insert(cls, index: int, name: str, field, type_: Type = None):
def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signature"]:
# It's posisble to set the type as annotation=type in pydantic.Field(...)
# But this may be annoying for users, so we allow them to pass the type
if type_ is None:
Expand All @@ -122,11 +126,7 @@ def insert(cls, index: int, name: str, field, type_: Type = None):
output_fields = list(cls.output_fields.items())

# Choose the list to insert into based on the field type
lst = (
input_fields
if field.json_schema_extra["__dspy_field_type"] == "input"
else output_fields
)
lst = input_fields if field.json_schema_extra["__dspy_field_type"] == "input" else output_fields
# We support negative insert indices
if index < 0:
index += len(lst) + 1
Expand All @@ -137,83 +137,7 @@ def insert(cls, index: int, name: str, field, type_: Type = None):
new_fields = dict(input_fields + output_fields)
return Signature(new_fields, cls.instructions)

def _parse_signature(cls, signature: str) -> Tuple[Type, Field]:
pattern = r"^\s*[\w\s,]+\s*->\s*[\w\s,]+\s*$"
if not re.match(pattern, signature):
raise ValueError(f"Invalid signature format: '{signature}'")

fields = {}
inputs_str, outputs_str = map(str.strip, signature.split("->"))
inputs = [v.strip() for v in inputs_str.split(",") if v.strip()]
outputs = [v.strip() for v in outputs_str.split(",") if v.strip()]
for name in inputs:
fields[name] = (str, InputField())
for name in outputs:
fields[name] = (str, OutputField())

return fields

def __call__(
cls,
signature: Union[str, Dict[str, Tuple[type, FieldInfo]]],
instructions: str = None,
):
"""
Creates a new Signature type with the given fields and instructions.
Note:
Even though we're calling a type, we're not making an instance of the type.
In general we don't allow instances of Signature types to be made. The call
syntax is only for your convenience.
Parameters:
signature: Format: "input1, input2 -> output1, output2"
instructions: Optional prompt for the signature.
"""

if isinstance(signature, str):
fields = cls._parse_signature(signature)
else:
fields = signature

# Validate the fields, this is important because we sometimes forget the
# slightly unintuitive syntax with tuples of (type, Field)
fixed_fields = {}
for name, type_field in fields.items():
assert isinstance(
name, str,
), f"Field names must be strings, not {type(name)}"
if isinstance(type_field, FieldInfo):
type_ = type_field.annotation
field = type_field
else:
assert isinstance(
type_field, tuple,
), f"Field values must be tuples, not {type(type_field)}"
type_, field = type_field
# It might be better to be explicit about the type, but it currently would break
# program of thought and teleprompters, so we just silently default to string.
if type_ is None:
type_ = str
assert isinstance(type_, type) or isinstance(
typing.get_origin(type_), type,
), f"Field types must be types, not {type(type_)}"
assert isinstance(
field, FieldInfo,
), f"Field values must be Field instances, not {type(field)}"
fixed_fields[name] = (type_, field)

# Fixing the fields shouldn't change the order
assert list(fixed_fields.keys()) == list(fields.keys())

# Default prompt when no instructions are provided
if instructions is None:
sig = Signature(signature, "") # Simple way to parse input/output fields
instructions = _default_instructions(sig)

signature = create_model("Signature", __base__=Signature, **fixed_fields)
signature.__doc__ = instructions
return signature

def equals(cls, other):
def equals(cls, other) -> bool:
"""Compare the JSON schema of two Pydantic models."""
if not isinstance(other, type) or not issubclass(other, BaseModel):
return False
Expand All @@ -226,50 +150,142 @@ def equals(cls, other):
return True

def __repr__(cls):
"""
Outputs something on the form:
"""Output a representation of the signature.
Uses the form:
Signature(question, context -> answer
question: str = InputField(desc="..."),
context: List[str] = InputField(desc="..."),
answer: int = OutputField(desc="..."),
)
).
"""
field_reprs = []
for name, field in cls.fields.items():
field_reprs.append(f"{name} = Field({field})")
field_repr = "\n ".join(field_reprs)
return (
f"Signature({cls.signature}\n"
f" instructions={repr(cls.instructions)}\n"
f" {field_repr}\n)"
)
return f"{cls.__name__}({cls.signature}\n instructions={repr(cls.instructions)}\n {field_repr}\n)"


class Signature(BaseModel, metaclass=SignatureMeta):
"""A signature for a predictor.
You typically subclass it, like this:
class MySignature(Signature):
input: str = InputField(desc="...")
output: int = OutputField(desc="...")
You can call Signature("input1, input2 -> output1, output2") to create a new signature type.
You can also include instructions, Signature("input -> output", "This is a test").
But it's generally better to use the make_signature function.
If you are not sure if your input is a string representation, (like "input1, input2 -> output1, output2"),
or a signature, you can use the ensure_signature function.
For compatibility with the legacy dsp format, you can use the signature_to_template function.
"""

pass


def ensure_signature(signature):
def ensure_signature(signature: str | Type[Signature]) -> Signature:
if signature is None:
return None
if isinstance(signature, str):
return Signature(signature)
return signature


def infer_prefix(attribute_name: str) -> str:
"""Infers a prefix from an attribute name."""
def make_signature(
signature: Union[str, Dict[str, Tuple[type, FieldInfo]]],
instructions: str = None,
signature_name: str = "StringSignature",
) -> Type[Signature]:
"""Create a new Signature type with the given fields and instructions.
Note:
Even though we're calling a type, we're not making an instance of the type.
In general, instances of Signature types are not allowed to be made. The call
syntax is provided for convenience.
Args:
signature: The signature format, specified as "input1, input2 -> output1, output2".
instructions: An optional prompt for the signature.
signature_name: An optional name for the new signature type.
"""
fields = _parse_signature(signature) if isinstance(signature, str) else signature

# Validate the fields, this is important because we sometimes forget the
# slightly unintuitive syntax with tuples of (type, Field)
fixed_fields = {}
for name, type_field in fields.items():
if not isinstance(name, str):
raise ValueError(f"Field names must be strings, not {type(name)}")
if isinstance(type_field, FieldInfo):
type_ = type_field.annotation
field = type_field
else:
if not isinstance(type_field, tuple):
raise ValueError(f"Field values must be tuples, not {type(type_field)}")
type_, field = type_field
# It might be better to be explicit about the type, but it currently would break
# program of thought and teleprompters, so we just silently default to string.
if type_ is None:
type_ = str
if not isinstance(type_, type) and not isinstance(typing.get_origin(type_), type):
raise ValueError(f"Field types must be types, not {type(type_)}")
if not isinstance(field, FieldInfo):
raise ValueError(f"Field values must be Field instances, not {type(field)}")
fixed_fields[name] = (type_, field)

# Fixing the fields shouldn't change the order
assert list(fixed_fields.keys()) == list(fields.keys()) # noqa: S101

# Default prompt when no instructions are provided
if instructions is None:
sig = Signature(signature, "") # Simple way to parse input/output fields
instructions = _default_instructions(sig)

return create_model(
signature_name,
__base__=Signature,
__doc__=instructions,
**fixed_fields,
)


def _parse_signature(signature: str) -> Tuple[Type, Field]:
pattern = r"^\s*[\w\s,]+\s*->\s*[\w\s,]+\s*$"
if not re.match(pattern, signature):
raise ValueError(f"Invalid signature format: '{signature}'")

fields = {}
inputs_str, outputs_str = map(str.strip, signature.split("->"))
inputs = [v.strip() for v in inputs_str.split(",") if v.strip()]
outputs = [v.strip() for v in outputs_str.split(",") if v.strip()]
for name in inputs:
fields[name] = (str, InputField())
for name in outputs:
fields[name] = (str, OutputField())

return fields


def infer_prefix(attribute_name: str) -> str:
"""Infer a prefix from an attribute name."""
# Convert camelCase to snake_case, but handle sequences of capital letters properly
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", attribute_name)
intermediate_name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1)

# Insert underscores around numbers to ensure spaces in the final output
with_underscores_around_numbers = re.sub(
r"([a-zA-Z])(\d)", r"\1_\2", intermediate_name,
r"([a-zA-Z])(\d)",
r"\1_\2",
intermediate_name,
)
with_underscores_around_numbers = re.sub(
r"(\d)([a-zA-Z])", r"\1_\2", with_underscores_around_numbers,
r"(\d)([a-zA-Z])",
r"\1_\2",
with_underscores_around_numbers,
)

# Convert snake_case to 'Proper Title Case', but ensure acronyms are uppercased
Expand Down
14 changes: 0 additions & 14 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -232,28 +232,14 @@ ignore = [
"ANN003",
# utf-8 encoding skip
"UP009",
# First argument of a method should be named `self`
"N805",
# 1 blank line required between summary line and description
"D205",
# Missing return type annotation for special method `__init__`
"ANN204",
# Avoid using the generic variable name `df` for DataFrames
"PD901",
# Unnecessary assignment to `df` before `return` statement
"RET504",
# commented code
"ERA001",
# Star-arg unpacking after a keyword argument is strongly discouraged
"B026",
# Missing type annotation for function argument `self`
"ANN001",
# Dynamically typed expressions (typing.Any) are disallowed in `wrapper`
"ANN401",
# Unnecessary `elif` after `return` statement
"RET505",
# Within an `except` clause, raise exceptions with `raise
"B904",
# We don't need docstrings for every method
"ANN202",
"D107",
Expand Down
Loading

0 comments on commit 6ba4130

Please sign in to comment.