Skip to content

Commit

Permalink
refactor: enhance variable resolution and prompt rendering (promptfoo…
Browse files Browse the repository at this point in the history
  • Loading branch information
mldangelo authored Dec 12, 2024
1 parent aa16a25 commit 3a7d1fd
Show file tree
Hide file tree
Showing 4 changed files with 624 additions and 72 deletions.
4 changes: 4 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@
"Evals",
"exfiltration",
"figcaption",
"footgun",
"globbed",
"Groq",
"helicone",
"Huggingface",
"icontains",
"jailbreaking",
"jamba",
"langfuse",
"leaderboard",
"leetspeak",
"Lightbox",
Expand All @@ -43,6 +46,7 @@
"openai",
"overreliance",
"OWASP",
"portkey",
"Probs",
"promptfoo",
"promptfooconfig",
Expand Down
174 changes: 126 additions & 48 deletions src/evaluatorHelpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,38 +32,112 @@ export async function extractTextFromPDF(pdfPath: string): Promise<string> {
}
}

export function resolveVariables(
variables: Record<string, string | object>,
): Record<string, string | object> {
let resolved = true;
const regex = /\{\{\s*(\w+)\s*\}\}/; // Matches {{variableName}}, {{ variableName }}, etc.

let iterations = 0;
do {
resolved = true;
for (const key of Object.keys(variables)) {
if (typeof variables[key] !== 'string') {
type VariableValue = string | object | number | boolean;
type Variables = Record<string, VariableValue>;

/**
 * Returns a shallow copy of `variables` in which every string value has had a
 * single trailing newline character (`\n`) stripped, if one is present.
 *
 * A stray newline at the end of a variable tends to cause issues with JSON prompts.
 * Non-string values are carried over untouched and the input object is not mutated.
 *
 * @param variables - Map of variable names to their values
 * @returns A new object with the same keys and newline-trimmed string values
 */
export function trimTrailingNewlines(variables: Variables): Variables {
  const copy: Variables = { ...variables };
  for (const name of Object.keys(copy)) {
    const current = copy[name];
    // Only the final "\n" is dropped; earlier newlines (and any "\r") are kept.
    if (typeof current === 'string' && current.endsWith('\n')) {
      copy[name] = current.slice(0, -1);
    }
  }
  return copy;
}

/**
 * Helper function that resolves variables within a single string.
 *
 * Makes exactly one left-to-right pass over `value` and substitutes every
 * placeholder matched by `regex` whose variable holds a string. Placeholders that
 * refer to undefined variables are left as-is (nunjucks handles them later), and so
 * are placeholders whose variable is not a string.
 *
 * Text inserted by a substitution is not re-scanned during the same call, so a
 * self-referential variable cannot make this function loop forever; nested
 * references are picked up by the bounded iteration in `resolveVariables`.
 *
 * @param value - The string containing variables to resolve
 * @param variables - Object containing variable values
 * @param regex - Regular expression for matching variables; capture group 1 must be
 *   the variable name
 * @returns The string with all resolvable variables replaced
 */
function resolveString(value: string, variables: Variables, regex: RegExp): string {
  // `replace` only visits every match when the pattern is global, so upgrade a
  // non-global pattern instead of silently handling just the first placeholder.
  const pattern = regex.global ? regex : new RegExp(regex.source, `${regex.flags}g`);

  // A function replacer is used on purpose: with a plain string, sequences such as
  // `$&` or `$$` inside a variable's value would be interpreted as replacement
  // patterns and corrupt the output.
  return value.replace(pattern, (placeholder: string, varName: string): string => {
    const replacement = variables[varName];
    return typeof replacement === 'string' ? replacement : placeholder;
  });
}

/**
* Resolves variables within string values of an object, replacing {{varName}} with
* the corresponding value from the variables object.
*
* Example:
* Input: { greeting: "Hello {{name}}!", name: "World" }
* Output: { greeting: "Hello World!", name: "World" }
*
* @param variables - Object containing variable names and their values
* @returns A new object with all variables resolved
*/
export function resolveVariables(variables: Variables): Variables {
const regex = /\{\{\s*(\w+)\s*\}\}/g; // Matches {{variableName}}, {{ variableName }}, etc.
const resolvedVars: Variables = trimTrailingNewlines(variables);
const MAX_ITERATIONS = 5;

// Iterate up to MAX_ITERATIONS times to handle nested variables
for (let iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
let hasChanges = false;

// Process each variable in the object
for (const [key, value] of Object.entries(resolvedVars)) {
// Skip non-string values as they can't contain variable references
if (typeof value !== 'string') {
continue;
}
const value = variables[key] as string;
const match = regex.exec(value);
if (match) {
const [placeholder, varName] = match;
if (variables[varName] === undefined) {
// Do nothing - final nunjucks render will fail if necessary.
// logger.warn(`Variable "${varName}" not found for substitution.`);
} else {
variables[key] = value.replace(placeholder, variables[varName] as string);
resolved = false; // Indicate that we've made a replacement and should check again
}

// Try to resolve any variables in this string
const newValue = resolveString(value, resolvedVars, regex);

// Only update if the value actually changed
if (newValue !== value) {
resolvedVars[key] = newValue;
hasChanges = true;
}
}
iterations++;
} while (!resolved && iterations < 5);

return variables;
// If no changes were made in this iteration, we're done
if (!hasChanges) {
break;
}
}

return resolvedVars;
}

/**
 * Reports whether a prompt should be treated as a plain string prompt.
 *
 * A prompt whose first non-whitespace character is `{` or `[` is likely JSON/YAML
 * and is therefore not a string prompt; anything else (including an empty or
 * whitespace-only prompt) is.
 */
const isStringPrompt = (s: string): boolean => {
  const leadingChar = s.trimStart().charAt(0);
  return !(leadingChar === '{' || leadingChar === '[');
};

export async function renderPrompt(
prompt: Prompt,
vars: Record<string, string | object>,
Expand Down Expand Up @@ -156,22 +230,16 @@ export async function renderPrompt(
}
}

// Remove any trailing newlines from vars, as this tends to be a footgun for JSON prompts.
for (const key of Object.keys(vars)) {
if (typeof vars[key] === 'string') {
vars[key] = (vars[key] as string).replace(/\n$/, '');
}
}

// Resolve variable mappings
resolveVariables(vars);
const resolvedVars: Variables = resolveVariables(vars);

// Third party integrations
// Handle third party integrations first
if (prompt.raw.startsWith('portkey://')) {
const { getPrompt } = await import('./integrations/portkey');
const portKeyResult = await getPrompt(prompt.raw.slice('portkey://'.length), vars);
const portKeyResult = await getPrompt(prompt.raw.slice('portkey://'.length), resolvedVars);
return JSON.stringify(portKeyResult.messages);
} else if (prompt.raw.startsWith('langfuse://')) {
}
if (prompt.raw.startsWith('langfuse://')) {
const { getPrompt } = await import('./integrations/langfuse');
const langfusePrompt = prompt.raw.slice('langfuse://'.length);

Expand All @@ -183,38 +251,48 @@ export async function renderPrompt(

const langfuseResult = await getPrompt(
helper,
vars,
resolvedVars,
promptType,
version === 'latest' ? undefined : Number(version),
);
return langfuseResult;
} else if (prompt.raw.startsWith('helicone://')) {
}
if (prompt.raw.startsWith('helicone://')) {
const { getPrompt } = await import('./integrations/helicone');
const heliconePrompt = prompt.raw.slice('helicone://'.length);
const [id, version] = heliconePrompt.split(':');
const [majorVersion, minorVersion] = version ? version.split('.') : [undefined, undefined];
const heliconeResult = await getPrompt(
id,
vars,
resolvedVars,
majorVersion === undefined ? undefined : Number(majorVersion),
minorVersion === undefined ? undefined : Number(minorVersion),
);
return heliconeResult;
}

// Render prompt
// If JSON autoescape is disabled, or the prompt is a plain string prompt (its first
// non-whitespace character is neither { nor [), render it directly with nunjucks.
if (getEnvBool('PROMPTFOO_DISABLE_JSON_AUTOESCAPE') || isStringPrompt(basePrompt)) {
return nunjucks.renderString(basePrompt, resolvedVars);
}
try {
if (getEnvBool('PROMPTFOO_DISABLE_JSON_AUTOESCAPE')) {
return nunjucks.renderString(basePrompt, vars);
}

const parsed = JSON.parse(basePrompt);

// Everything should be a JSON at this point. Use yaml parser because it's more forgiving
const parsed = yaml.load(basePrompt) as Record<string, any>;
// The _raw_ prompt is valid JSON. That means that the user likely wants to substitute vars _within_ the JSON itself.
// Recursively walk the JSON structure. If we find a string, render it with nunjucks.
return JSON.stringify(renderVarsInObject(parsed, vars), null, 2);
const rendered = renderVarsInObject<Variables>(parsed, resolvedVars);
if (typeof rendered === 'object' && rendered !== null) {
return JSON.stringify(rendered, null, 2);
}
if (typeof rendered === 'string') {
return rendered;
}
throw new Error(`Unknown rendered type: ${typeof rendered}`);
} catch {
return nunjucks.renderString(basePrompt, vars);
// If YAML/JSON parsing fails, render as basic text
return nunjucks.renderString(basePrompt, resolvedVars);
}
}

Expand Down
34 changes: 19 additions & 15 deletions src/util/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { parse as csvParse } from 'csv-parse/sync';
import { stringify } from 'csv-stringify/sync';
import dedent from 'dedent';
import dotenv from 'dotenv';
import { desc, eq, like, and, sql, not } from 'drizzle-orm';
import { and, desc, eq, like, not, sql } from 'drizzle-orm';
import deepEqual from 'fast-deep-equal';
import * as fs from 'fs';
import { globSync } from 'glob';
Expand All @@ -15,13 +15,13 @@ import { TERMINAL_MAX_WIDTH } from '../constants';
import { getDbSignalPath, getDb } from '../database';
import {
datasetsTable,
evalResultsTable,
evalsTable,
evalsToDatasetsTable,
evalsToPromptsTable,
evalsToTagsTable,
promptsTable,
tagsTable,
evalResultsTable,
} from '../database/tables';
import { getEnvBool } from '../envars';
import { getDirectory, importModule } from '../esm';
Expand All @@ -32,31 +32,32 @@ import Eval, { createEvalId, getSummaryOfLatestEvals } from '../models/eval';
import type EvalResult from '../models/evalResult';
import { generateIdFromPrompt } from '../models/prompt';
import {
type EvalWithMetadata,
isApiProvider,
isProviderOptions,
OutputFileExtension,
ResultFailureReason,
type CompletedPrompt,
type CsvRow,
type EvaluateResult,
type EvaluateTable,
type EvaluateTableOutput,
type EvaluateSummaryV2,
type EvalWithMetadata,
type NunjucksFilterMap,
type OutputFile,
type PromptWithMetadata,
type ResultLightweight,
type ResultsFile,
type TestCase,
type TestCasesWithMetadata,
type TestCasesWithMetadataPrompt,
type UnifiedConfig,
type OutputFile,
type CompletedPrompt,
type CsvRow,
type ResultLightweight,
isApiProvider,
isProviderOptions,
OutputFileExtension,
type EvaluateSummaryV2,
ResultFailureReason,
} from '../types';
import invariant from '../util/invariant';
import { getConfigDirectoryPath } from './config/manage';
import { sha256 } from './createHash';
import { convertTestResultsToTableRow, getHeaderForTable } from './exportToFile';
import { getHeaderForTable } from './exportToFile';
import { convertTestResultsToTableRow } from './exportToFile';
import { isJavascriptFile } from './file';
import { getNunjucksEngine } from './templates';

Expand Down Expand Up @@ -984,7 +985,10 @@ export function resultIsForTestCase(result: EvaluateResult, testCase: TestCase):
return varsMatch(testCase.vars, result.vars) && providersMatch;
}

export function renderVarsInObject<T>(obj: T, vars?: Record<string, string | object>): T {
export function renderVarsInObject<T>(
obj: T,
vars?: Record<string, string | object | number | boolean>,
): T {
// Renders nunjucks template strings with context variables
if (!vars || getEnvBool('PROMPTFOO_DISABLE_TEMPLATING')) {
return obj;
Expand All @@ -993,7 +997,7 @@ export function renderVarsInObject<T>(obj: T, vars?: Record<string, string | obj
return nunjucks.renderString(obj, vars) as unknown as T;
}
if (Array.isArray(obj)) {
return obj.map((item) => renderVarsInObject(item, vars)) as unknown as T;
return obj.map((item) => renderVarsInObject<T>(item, vars)) as unknown as T;
}
if (typeof obj === 'object' && obj !== null) {
const result: Record<string, unknown> = {};
Expand Down
Loading

0 comments on commit 3a7d1fd

Please sign in to comment.