Skip to content

Commit

Permalink
feat: Add Cloudflare AI Provider (promptfoo#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobdr authored May 30, 2024
1 parent 9e89fad commit dad977d
Show file tree
Hide file tree
Showing 14 changed files with 868 additions and 8 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,17 @@ Here are some of the available scripts:
- `db:generate`: Generate new db migrations (and create the db if it doesn't already exist). Note that after generating a new migration, you'll have to `npm i` to copy the migrations into `dist/`.
- `db:migrate`: Run existing db migrations (and create the db if it doesn't already exist)

To run the CLI during development you can run a command like: `npm run local -- eval --config $(readlink -f ./examples/cloudflare-ai/chat_config.yaml)`, where any parts of the command after `--` are passed through to our CLI entrypoint. Since the Next dev server isn't supported in this mode, see the instructions above for running the web server.

# [» View full documentation «](https://promptfoo.dev/docs/intro)

[providers-docs]: https://promptfoo.dev/docs/providers


### Adding a New Provider

1. Create an implementation in `src/providers/SOME_PROVIDER_FILE`
2. Update `loadApiProvider` in `src/providers.ts` to load your provider via string
3. Add test cases in `test/providers.test.ts`
1. Test the actual provider implementation
2. Test loading the provider via a `loadApiProvider` test
5 changes: 5 additions & 0 deletions examples/cloudflare-ai/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Cloudflare Workers AI

See the [full docs for Workers AI](https://developers.cloudflare.com/workers-ai/) and the [Cloudflare REST docs](https://developers.cloudflare.com/api/operations/workers-ai-post-run-model) for further reference on endpoints.

The basic example in this repo shows the difference in outputs between the "chat" and "completion" versions of the provider.
25 changes: 25 additions & 0 deletions examples/cloudflare-ai/chat_advanced_configuration.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
prompts:
- Tell me a really funny joke about {{topic}}. The joke should contain the word {{topic}}

providers:
- id: cloudflare-ai:chat:@cf/meta/llama-3-8b-instruct
config:
accountId: YOUR_ACCOUNT_ID_HERE
# ===============
# It is not recommended to keep your API key on the config file since it is a secret value.
# Use the CLOUDFLARE_API_KEY environment variable or set the apiKeyEnvar value
# in the config
# apiKey: YOUR_API_KEY_HERE
# apiKeyEnvar: SOME_ENV_HAR_CONTAINING_THE_API_KEY
# ===============
# Additional model parameters that are passed through in the HTTP request body to the Cloudflare REST API call
# to run the model
max_tokens: 800
seed: 1

tests:
- vars:
topic: birds
assert:
- type: icontains
value: "{{topic}}"
18 changes: 18 additions & 0 deletions examples/cloudflare-ai/chat_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
prompts:
- What is the capital of the {{country}}?

providers:
- cloudflare-ai:completion:@cf/meta/llama-3-8b-instruct
- cloudflare-ai:chat:@cf/meta/llama-3-8b-instruct

tests:
- vars:
country: United States
assert:
- type: icontains
value: Washington, D.C.
- vars:
country: England
assert:
- type: icontains
value: London
35 changes: 35 additions & 0 deletions examples/cloudflare-ai/embedding_configuration.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
prompts:
- This prompt will be evaluated for embedding similarity

defaultTest:
options:
provider:
embedding:
id: cloudflare-ai:embedding:@cf/baai/bge-base-en-v1.5
config:
# accountId: YOUR_ACCOUNT_ID_HERE



providers:
- id: cloudflare-ai:chat:@cf/meta/llama-3-8b-instruct
config:
# accountId: YOUR_ACCOUNT_ID_HERE

# ===============
# It is not recommended to keep your API key on the config file since it is a secret value.
# Use the CLOUDFLARE_API_KEY environment variable or set the apiKeyEnvar value
# in the config
# apiKey: YOUR_API_KEY_HERE
# apiKeyEnvar: SOME_ENV_HAR_CONTAINING_THE_API_KEY
# ===============
# Additional model parameters that are passed through in the HTTP request body to the Cloudflare REST API call
# to run the model
# max_tokens: 800
# seed: 1

tests:
- assert:
- type: similar
value: embedding similarity
threshold: 0.6
10 changes: 4 additions & 6 deletions jest.config.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
/** @type {import('ts-jest/dist/types').InitialOptionsTsJest} */
import type { Config } from 'jest';
import type { TsJestTransformerOptions } from 'ts-jest';

const tsJestConfig: TsJestTransformerOptions & Record<string, unknown> = { useESM: true };

const config: Config = {
transform: {
'\\.[jt]sx?$': 'ts-jest',
},
globals: {
'ts-jest': {
useESM: true,
},
'^.+\\.m?[tj]sx?$': ['ts-jest', tsJestConfig],
},
/*
moduleNameMapper: {
Expand Down
40 changes: 40 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
"promptfoo": "dist/src/main.js"
},
"scripts": {
"bin": "dist/src/main.js",
"tsc": "tsc",
"local": "ts-node --esm --files src/main.ts",
"local": "ts-node --cwdMode --transpileOnly src/main.ts",
"local:web": "cd src/web/nextui && npm run dev",
"install:nextui": "cd src/web/nextui && npm install",
"build:clean": "rm -rf dist",
Expand Down Expand Up @@ -71,6 +72,7 @@
"@types/semver": "^7.5.0",
"@types/uuid": "^9.0.2",
"babel-jest": "^29.5.0",
"cloudflare": "^3.2.0",
"drizzle-kit": "^0.20.13",
"jest": "^29.5.0",
"jest-watch-typeahead": "^2.2.2",
Expand Down
53 changes: 53 additions & 0 deletions site/docs/providers/cloudflare-ai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Cloudflare Workers AI

This provider supports the [models](https://developers.cloudflare.com/workers-ai/models/) provided by Cloudflare Workers AI, a serverless edge embedding and inference runtime.

## Required Configuration

Calling the Workers AI requires the user to supply a Cloudflare account ID and API key with sufficient permissions to invoke the Workers AI REST endpoints.

```bash
export CLOUDFLARE_ACCOUNT_ID=YOUR_ACCOUNT_ID_HERE
export CLOUDFLARE_API_KEY=YOUR_API_KEY_HERE
```

The Cloudflare account ID is not secret and therefore it is safe to put it in your `promptfoo` configuration file. The Cloudflare API key is secret, so while you can provide it in the config, this is **HIGHLY NOT RECOMMENDED** as it might lead to abuse. See below for an example safe configuration:


```yaml
prompts:
- Tell me a really funny joke about {{topic}}. The joke should contain the word {{topic}}

providers:
- id: cloudflare-ai:chat:@cf/meta/llama-3-8b-instruct
config:
accountId: YOUR_ACCOUNT_ID_HERE
# It is not recommended to keep your API key on the config file since it is a secret value.
# Use the CLOUDFLARE_API_KEY environment variable or set the apiKeyEnvar value
# in the config
# apiKey: YOUR_API_KEY_HERE
# apiKeyEnvar: SOME_ENV_HAR_CONTAINING_THE_API_KEY

tests:
- vars:
topic: birds
assert:
- type: icontains
value: "{{topic}}"
```
In addition to `apiKeyEnvar` allowed environment variable redirection for the `CLOUDFLARE_API_KEY` value, the `accountIdEnvar` can be used to similarly redirect to a value for the `CLOUDFLARE_ACCOUNT_ID`.

## Available Models and Model Parameters

Cloudflare is constantly adding new models to its inventory. See their [official list of models](https://developers.cloudflare.com/workers-ai/models/) for a list of supported models. Different models support different parameters, which is supported by supplying those parameters as additional keys of the config object in the `promptfoo` config file.

For an example of how advanced embedding configuration should be supplied, see `examples/cloudflare-ai/embedding_configuration.yaml`

For an example of how advanced completion/chat configuration should be supplied, see `examples/cloudflare-ai/chat_advanced_configuration.yaml`

Different models support different parameters. While this provider strives to be relatively flexible, it is possible that not all possible parameters have been added. If you need support for a parameter that is not yet supported, please open up a PR to add support.

## Future Improvements

- [ ] Allow for the pass through of all generic configuration parameters for Cloudflare REST API
6 changes: 6 additions & 0 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,12 @@ async function main() {
const rowsLeft = summary.table.body.length - 25;
logger.info(`... ${rowsLeft} more row${rowsLeft === 1 ? '' : 's'} not shown ...\n`);
}
} else if (summary.stats.failures !== 0) {
logger.debug(
`At least one evaluation failure occurred. This might be caused by the underlying call to the provider, or a test failure. Context: \n${JSON.stringify(
summary.results,
)}`,
);
}

const { outputPath } = config;
Expand Down
27 changes: 27 additions & 0 deletions src/providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import {
import { AwsBedrockCompletionProvider, AwsBedrockEmbeddingProvider } from './providers/bedrock';
import { PythonProvider } from './providers/pythonCompletion';
import { CohereChatCompletionProvider } from './providers/cohere';
import * as CloudflareAiProviders from './providers/cloudflare-ai';
import { BAMChatProvider, BAMEmbeddingProvider } from './providers/bam';
import { PortkeyChatCompletionProvider } from './providers/portkey';
import { HttpProvider } from './providers/http';
Expand Down Expand Up @@ -289,6 +290,32 @@ export async function loadApiProvider(
`Invalid BAM provider: ${providerPath}. Use one of the following providers: bam:chat:<model name>`,
);
}
} else if (providerPath.startsWith('cloudflare-ai:')) {
// Load Cloudflare AI
const splits = providerPath.split(':');
const modelType = splits[1];
const deploymentName = splits[2];

if (modelType === 'chat') {
ret = new CloudflareAiProviders.CloudflareAiChatCompletionProvider(
deploymentName,
providerOptions,
);
} else if (modelType === 'embedding' || modelType === 'embeddings') {
ret = new CloudflareAiProviders.CloudflareAiEmbeddingProvider(
deploymentName,
providerOptions,
);
} else if (modelType === 'completion') {
ret = new CloudflareAiProviders.CloudflareAiCompletionProvider(
deploymentName,
providerOptions,
);
} else {
throw new Error(
`Unknown Cloudflare AI model type: ${modelType}. Use one of the following providers: cloudflare-ai:chat:<model name>, cloudflare-ai:completion:<model name>, cloudflare-ai:embedding:`,
);
}
} else if (providerPath.startsWith('webhook:')) {
const webhookUrl = providerPath.substring('webhook:'.length);
ret = new WebhookProvider(webhookUrl, providerOptions);
Expand Down
Loading

0 comments on commit dad977d

Please sign in to comment.