diff --git a/js/src/lib/inferenceSnippets/inputs.ts b/js/src/lib/inferenceSnippets/inputs.ts index 33e9b4d69..6efbf8060 100644 --- a/js/src/lib/inferenceSnippets/inputs.ts +++ b/js/src/lib/inferenceSnippets/inputs.ts @@ -65,12 +65,15 @@ const inputsSentenceSimilarity = () => const inputsFeatureExtraction = () => `"Today is a sunny day and I'll get some ice cream."`; +const inputsImageClassification = () => `"cats.jpg"`; + const modelInputSnippets: { [key in PipelineType]?: (model: ModelData) => string; } = { "conversational": inputsConversational, "feature-extraction": inputsFeatureExtraction, "fill-mask": inputsFillMask, + "image-classification": inputsImageClassification, "question-answering": inputsQuestionAnswering, "sentence-similarity": inputsSentenceSimilarity, "summarization": inputsSummarization, @@ -84,13 +87,21 @@ const modelInputSnippets: { }; // Use noWrap to put the whole snippet on a single line (removing new lines and tabulations) -export function getModelInputSnippet(model: ModelData, noWrap = false): string { +// Use noQuotes to strip quotes from start & end (example: "abc" -> abc) +export function getModelInputSnippet(model: ModelData, noWrap = false, noQuotes = false): string { if (model.pipeline_tag) { const inputs = modelInputSnippets[model.pipeline_tag]; if (inputs) { - return noWrap - ? inputs(model).replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ") - : inputs(model); + let result = inputs(model); + if (noWrap) { + result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " "); + } + if (noQuotes) { + const REGEX_QUOTES = /^"(.+)"$/s; + const match = result.match(REGEX_QUOTES); + result = match ? 
match[1] : result; + } + return result; } } return "No input example has been defined for this model task."; diff --git a/js/src/lib/inferenceSnippets/serveCurl.ts b/js/src/lib/inferenceSnippets/serveCurl.ts index 62ccf14e2..4facc3090 100644 --- a/js/src/lib/inferenceSnippets/serveCurl.ts +++ b/js/src/lib/inferenceSnippets/serveCurl.ts @@ -1,43 +1,53 @@ import type { PipelineType, ModelData } from "../interfaces/Types"; import { getModelInputSnippet } from "./inputs"; -export const bodyBasic = (model: ModelData): string => - `-d '{"inputs": ${getModelInputSnippet(model, true)}}'`; +export const snippetBasic = (model: ModelData, accessToken: string): string => + `curl https://api-inference.huggingface.co/models/${model.id} \\ + -X POST \\ + -d '{"inputs": ${getModelInputSnippet(model, true)}}' \\ + -H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" +`; -export const bodyZeroShotClassification = (model: ModelData): string => - `-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}'`; +export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string => + `curl https://api-inference.huggingface.co/models/${model.id} \\ + -X POST \\ + -d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\ + -H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" +`; -export const curlSnippetBodies: - Partial string>> = +export const snippetFile = (model: ModelData, accessToken: string): string => + `curl https://api-inference.huggingface.co/models/${model.id} \\ + -X POST \\ + --data-binary '@${getModelInputSnippet(model, true, true)}' \\ + -H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" +`; + +export const curlSnippets: + Partial string>> = { // Same order as in js/src/lib/interfaces/Types.ts - "text-classification": bodyBasic, - "token-classification": bodyBasic, - "table-question-answering": bodyBasic, - 
"question-answering": bodyBasic, - "zero-shot-classification": bodyZeroShotClassification, - "translation": bodyBasic, - "summarization": bodyBasic, - "conversational": bodyBasic, - "feature-extraction": bodyBasic, - "text-generation": bodyBasic, - "text2text-generation": bodyBasic, - "fill-mask": bodyBasic, - "sentence-similarity": bodyBasic, + "text-classification": snippetBasic, + "token-classification": snippetBasic, + "table-question-answering": snippetBasic, + "question-answering": snippetBasic, + "zero-shot-classification": snippetZeroShotClassification, + "translation": snippetBasic, + "summarization": snippetBasic, + "conversational": snippetBasic, + "feature-extraction": snippetBasic, + "text-generation": snippetBasic, + "text2text-generation": snippetBasic, + "fill-mask": snippetBasic, + "sentence-similarity": snippetBasic, + "image-classification": snippetFile, }; export function getCurlInferenceSnippet(model: ModelData, accessToken: string): string { - const body = model.pipeline_tag && model.pipeline_tag in curlSnippetBodies - ? curlSnippetBodies[model.pipeline_tag]?.(model) ?? "" + return model.pipeline_tag && model.pipeline_tag in curlSnippets + ? curlSnippets[model.pipeline_tag]?.(model, accessToken) ?? 
"" : ""; - - return `curl https://api-inference.huggingface.co/models/${model.id} \\ - -X POST \\ - ${body} \\ - -H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" -`; } export function hasCurlInferenceSnippet(model: ModelData): boolean { - return !!model.pipeline_tag && model.pipeline_tag in curlSnippetBodies; + return !!model.pipeline_tag && model.pipeline_tag in curlSnippets; } diff --git a/js/src/lib/inferenceSnippets/serveJs.ts b/js/src/lib/inferenceSnippets/serveJs.ts index b021ab5c5..c7c80e64f 100644 --- a/js/src/lib/inferenceSnippets/serveJs.ts +++ b/js/src/lib/inferenceSnippets/serveJs.ts @@ -1,37 +1,26 @@ import type { PipelineType, ModelData } from "../interfaces/Types"; import { getModelInputSnippet } from "./inputs"; -export const bodyBasic = (model: ModelData): string => - `{"inputs": ${getModelInputSnippet(model)}}`; - -export const bodyZeroShotClassification = (model: ModelData): string => - `{"inputs": ${getModelInputSnippet(model)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}`; +export const snippetBasic = (model: ModelData, accessToken: string): string => + `async function query(data) { + const response = await fetch( + "https://api-inference.huggingface.co/models/${model.id}", + { + headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" }, + method: "POST", + body: JSON.stringify(data), + } + ); + const result = await response.json(); + return result; +} -export const jsSnippetBodies: - Partial string>> = -{ - // Same order as in js/src/lib/interfaces/Types.ts - "text-classification": bodyBasic, - "token-classification": bodyBasic, - "table-question-answering": bodyBasic, - "question-answering": bodyBasic, - "zero-shot-classification": bodyZeroShotClassification, - "translation": bodyBasic, - "summarization": bodyBasic, - "conversational": bodyBasic, - "feature-extraction": bodyBasic, - "text-generation": bodyBasic, - "text2text-generation": bodyBasic, - "fill-mask": bodyBasic, - "sentence-similarity": 
bodyBasic, -}; +query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { + console.log(JSON.stringify(response)); +});`; -export function getJsInferenceSnippet(model: ModelData, accessToken: string): string { - const body = model.pipeline_tag && model.pipeline_tag in jsSnippetBodies - ? jsSnippetBodies[model.pipeline_tag]?.(model) ?? "" - : ""; - - return `async function query(data) { +export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string => + `async function query(data) { const response = await fetch( "https://api-inference.huggingface.co/models/${model.id}", { @@ -44,11 +33,55 @@ export function getJsInferenceSnippet(model: ModelData, accessToken: string): st return result; } -query(${body}).then((response) => { +query({"inputs": ${getModelInputSnippet(model)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => { + console.log(JSON.stringify(response)); +});`; + +export const snippetFile = (model: ModelData, accessToken: string): string => + `async function query(filename) { + const data = require("fs").readFileSync(filename); + const response = await fetch( + "https://api-inference.huggingface.co/models/${model.id}", + { + headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" }, + method: "POST", + body: data, + } + ); + const result = await response.json(); + return result; +} + +query(${getModelInputSnippet(model)}).then((response) => { console.log(JSON.stringify(response)); });`; + +export const jsSnippets: + Partial string>> = +{ + // Same order as in js/src/lib/interfaces/Types.ts + "text-classification": snippetBasic, + "token-classification": snippetBasic, + "table-question-answering": snippetBasic, + "question-answering": snippetBasic, + "zero-shot-classification": snippetZeroShotClassification, + "translation": snippetBasic, + "summarization": snippetBasic, + "conversational": snippetBasic, + "feature-extraction": snippetBasic, + "text-generation": snippetBasic, + 
"text2text-generation": snippetBasic, + "fill-mask": snippetBasic, + "sentence-similarity": snippetBasic, + "image-classification": snippetFile, +}; + +export function getJsInferenceSnippet(model: ModelData, accessToken: string): string { + return model.pipeline_tag && model.pipeline_tag in jsSnippets + ? jsSnippets[model.pipeline_tag]?.(model, accessToken) ?? "" + : ""; } export function hasJsInferenceSnippet(model: ModelData): boolean { - return !!model.pipeline_tag && model.pipeline_tag in jsSnippetBodies; + return !!model.pipeline_tag && model.pipeline_tag in jsSnippets; } diff --git a/js/src/lib/inferenceSnippets/servePython.ts b/js/src/lib/inferenceSnippets/servePython.ts index 73fb8d4b2..2ea9ffa1d 100644 --- a/js/src/lib/inferenceSnippets/servePython.ts +++ b/js/src/lib/inferenceSnippets/servePython.ts @@ -1,39 +1,57 @@ import type { PipelineType, ModelData } from "../interfaces/Types"; import { getModelInputSnippet } from "./inputs"; -export const bodyZeroShotClassification = (model: ModelData): string => - `output = query({ +export const snippetZeroShotClassification = (model: ModelData): string => + `def query(payload): + response = requests.post(API_URL, headers=headers, json=payload) + return response.json() + +output = query({ "inputs": ${getModelInputSnippet(model)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}, })`; -export const bodyBasic = (model: ModelData): string => - `output = query({ +export const snippetBasic = (model: ModelData): string => + `def query(payload): + response = requests.post(API_URL, headers=headers, json=payload) + return response.json() + +output = query({ "inputs": ${getModelInputSnippet(model)}, })`; -export const pythonSnippetBodies: +export const snippetFile = (model: ModelData): string => + `def query(filename): + with open(filename, "rb") as f: + data = f.read() + response = requests.request("POST", API_URL, headers=headers, data=data) + return response.json() + +output = 
query(${getModelInputSnippet(model)})`; + +export const pythonSnippets: Partial string>> = { // Same order as in js/src/lib/interfaces/Types.ts - "text-classification": bodyBasic, - "token-classification": bodyBasic, - "table-question-answering": bodyBasic, - "question-answering": bodyBasic, - "zero-shot-classification": bodyZeroShotClassification, - "translation": bodyBasic, - "summarization": bodyBasic, - "conversational": bodyBasic, - "feature-extraction": bodyBasic, - "text-generation": bodyBasic, - "text2text-generation": bodyBasic, - "fill-mask": bodyBasic, - "sentence-similarity": bodyBasic, + "text-classification": snippetBasic, + "token-classification": snippetBasic, + "table-question-answering": snippetBasic, + "question-answering": snippetBasic, + "zero-shot-classification": snippetZeroShotClassification, + "translation": snippetBasic, + "summarization": snippetBasic, + "conversational": snippetBasic, + "feature-extraction": snippetBasic, + "text-generation": snippetBasic, + "text2text-generation": snippetBasic, + "fill-mask": snippetBasic, + "sentence-similarity": snippetBasic, + "image-classification": snippetFile, }; export function getPythonInferenceSnippet(model: ModelData, accessToken: string): string { - const body = model.pipeline_tag && model.pipeline_tag in pythonSnippetBodies - ? pythonSnippetBodies[model.pipeline_tag]?.(model) ?? "" + const body = model.pipeline_tag && model.pipeline_tag in pythonSnippets + ? pythonSnippets[model.pipeline_tag]?.(model) ?? "" : ""; return `import requests @@ -41,13 +59,9 @@ export function getPythonInferenceSnippet(model: ModelData, accessToken: string) API_URL = "https://api-inference.huggingface.co/models/${model.id}" headers = {"Authorization": ${accessToken ? 
`"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}} -def query(payload): - response = requests.post(API_URL, headers=headers, json=payload) - return response.json() - ${body}`; } export function hasPythonInferenceSnippet(model: ModelData): boolean { - return !!model.pipeline_tag && model.pipeline_tag in pythonSnippetBodies; + return !!model.pipeline_tag && model.pipeline_tag in pythonSnippets; }