Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for generating templates using Ollama provider #1180

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion langtest/augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def __init__(
raise ImportError(Errors.E097())

except Exception as e_msg:
raise Errors.E095(e=e_msg)
raise Exception(Errors.E095(e=e_msg))

if show_templates:
[print(template) for template in self.__templates]
Expand Down Expand Up @@ -610,6 +610,7 @@ def __generate_templates(
from langtest.augmentation.utils import (
generate_templates_azoi, # azoi means Azure OpenAI
generate_templates_openai,
generate_templates_ollama,
)

params = model_config.copy() if model_config else {}
Expand All @@ -620,5 +621,8 @@ def __generate_templates(
elif model_config and model_config.get("provider") == "azure":
return generate_templates_azoi(template, num_extra_templates, params)

elif model_config and model_config.get("provider") == "ollama":
return generate_templates_ollama(template, num_extra_templates, params)

else:
return generate_templates_openai(template, num_extra_templates)
53 changes: 53 additions & 0 deletions langtest/augmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def generate_templates_azoi(

if "provider" in model_config:
del model_config["provider"]

if "model" in model_config:
del model_config["model"]

client = openai.AzureOpenAI(**model_config)
Expand Down Expand Up @@ -151,6 +153,8 @@ def generate_templates_openai(

if "provider" in model_config:
del model_config["provider"]

if "model" in model_config:
del model_config["model"]

client = openai.OpenAI(**model_config)
Expand Down Expand Up @@ -180,3 +184,52 @@ def generate_templates_openai(
generated_response.remove_invalid_templates(template)

return generated_response.templates[:num_extra_templates]


def generate_templates_ollama(
    template: str, num_extra_templates: int, model_config: OpenAIConfig = OpenAIConfig()
):
    """Generate new templates based on the provided template using a local Ollama model.

    Args:
        template: Source template string; the variable structure inside it must
            be preserved in every generated variation.
        num_extra_templates: Number of additional templates to request.
        model_config: Provider configuration. The ``model`` key selects the
            Ollama model (defaults to ``llama3.1``); ``provider`` and ``model``
            keys are stripped before the client call.

    Returns:
        A list of at most ``num_extra_templates`` validated template strings.

    Raises:
        ValueError: If the Ollama server reports an error (e.g. the requested
            model has not been pulled locally).
    """
    import ollama

    # Resolve the model name first, before the keys are removed below.
    model_name = model_config.get("model", "llama3.1")
    try:
        # "provider"/"model" are meta-configuration, not client options;
        # drop them so they are not forwarded to the Ollama client.
        if "provider" in model_config:
            del model_config["provider"]

        if "model" in model_config:
            del model_config["model"]

        client = ollama.Client()

        prompt = (
            f"Based on the provided template, create {num_extra_templates} new and unique templates that are "
            "variations on this theme. Present these as a list, with each template as a quoted string. The list should "
            "contain only the templates, without any additional text or explanation. Ensure that the structure of "
            "these variables remains consistent in each generated template. Note: don't add any extra variables and ignore typo errors.\n\n"
            "Template:\n"
            f"{template}\n"
        )
        response = client.chat(
            model=model_name,
            messages=[
                {
                    "role": "system",
                    "content": f"Action: Generate up to {num_extra_templates} templates and ensure that the structure of the variables within the templates remains unchanged and don't add any extra variables.",
                },
                {"role": "user", "content": prompt},
            ],
            # Constrain the model output to the Templates JSON schema so it can
            # be parsed with model_validate_json below.
            format=Templates.model_json_schema(),
        )

        generated_response = Templates.model_validate_json(response.message.content)
        # Discard generated templates whose variables don't match the original.
        generated_response.remove_invalid_templates(template)
        return generated_response.templates[:num_extra_templates]
    except ollama.ResponseError as e:
        # e.args may contain non-string payloads (e.g. status codes); coerce
        # each element to str before searching to avoid a TypeError here.
        if any("model" in str(arg) for arg in e.args):
            raise ValueError(
                f"Model not found: {e}, please pull model using `ollama pull {model_name}`"
            ) from e
        raise ValueError(f"Error in response: {e}") from e