Generalize inference snippets #106

Merged (4 commits) on Apr 19, 2022

19 changes: 15 additions & 4 deletions js/src/lib/inferenceSnippets/inputs.ts
@@ -65,12 +65,15 @@ const inputsSentenceSimilarity = () =>
const inputsFeatureExtraction = () =>
`"Today is a sunny day and I'll get some ice cream."`;

const inputsImageClassification = () => `"cats.jpg"`;

const modelInputSnippets: {
[key in PipelineType]?: (model: ModelData) => string;
} = {
"conversational": inputsConversational,
"feature-extraction": inputsFeatureExtraction,
"fill-mask": inputsFillMask,
"image-classification": inputsImageClassification,
"question-answering": inputsQuestionAnswering,
"sentence-similarity": inputsSentenceSimilarity,
"summarization": inputsSummarization,
@@ -84,13 +87,21 @@ const modelInputSnippets: {
};

// Use noWrap to put the whole snippet on a single line (removing new lines and tabulations)
export function getModelInputSnippet(model: ModelData, noWrap = false): string {
// Use noQuotes to strip quotes from start & end (example: "abc" -> abc)
export function getModelInputSnippet(model: ModelData, noWrap = false, noQuotes = false): string {
if (model.pipeline_tag) {
const inputs = modelInputSnippets[model.pipeline_tag];
if (inputs) {
return noWrap
? inputs(model).replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ")
: inputs(model);
let result = inputs(model);
if (noWrap) {
result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ");
}
if (noQuotes) {
const REGEX_QUOTES = /^"(.+)"$/s;
const match = result.match(REGEX_QUOTES);
result = match ? match[1] : result;
}
return result;
}
}
return "No input example has been defined for this model task.";
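
For reference, a minimal usage sketch (not part of the diff; the model stub below is hypothetical) showing how the two flags compose for the new image-classification input:

import type { ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

// Hypothetical stub: only the fields the snippet helpers read are set.
const model = { id: "my-user/my-image-model", pipeline_tag: "image-classification" } as ModelData;

getModelInputSnippet(model);             // '"cats.jpg"'  (raw example input, still quoted)
getModelInputSnippet(model, true);       // '"cats.jpg"'  (noWrap collapses newlines/tabs; a no-op for one-liners)
getModelInputSnippet(model, true, true); // 'cats.jpg'    (noQuotes strips the surrounding quotes, e.g. for file paths)
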
66 changes: 38 additions & 28 deletions js/src/lib/inferenceSnippets/serveCurl.ts
@@ -1,43 +1,53 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

export const bodyBasic = (model: ModelData): string =>
`-d '{"inputs": ${getModelInputSnippet(model, true)}}'`;
export const snippetBasic = (model: ModelData, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

export const bodyZeroShotClassification = (model: ModelData): string =>
`-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}'`;
export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

export const curlSnippetBodies:
Partial<Record<PipelineType, (model: ModelData) => string>> =
export const snippetFile = (model: ModelData, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
--data-binary '@${getModelInputSnippet(model, true, true)}' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

export const curlSnippets:
Partial<Record<PipelineType, (model: ModelData, accessToken: string) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": bodyBasic,
"token-classification": bodyBasic,
"table-question-answering": bodyBasic,
"question-answering": bodyBasic,
"zero-shot-classification": bodyZeroShotClassification,
"translation": bodyBasic,
"summarization": bodyBasic,
"conversational": bodyBasic,
"feature-extraction": bodyBasic,
"text-generation": bodyBasic,
"text2text-generation": bodyBasic,
"fill-mask": bodyBasic,
"sentence-similarity": bodyBasic,
"text-classification": snippetBasic,
"token-classification": snippetBasic,
"table-question-answering": snippetBasic,
"question-answering": snippetBasic,
"zero-shot-classification": snippetZeroShotClassification,
"translation": snippetBasic,
"summarization": snippetBasic,
"conversational": snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"image-classification": snippetFile,
};

export function getCurlInferenceSnippet(model: ModelData, accessToken: string): string {
const body = model.pipeline_tag && model.pipeline_tag in curlSnippetBodies
? curlSnippetBodies[model.pipeline_tag]?.(model) ?? ""
return model.pipeline_tag && model.pipeline_tag in curlSnippets
? curlSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
: "";

return `curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
${body} \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;
}

export function hasCurlInferenceSnippet(model: ModelData): boolean {
return !!model.pipeline_tag && model.pipeline_tag in curlSnippetBodies;
return !!model.pipeline_tag && model.pipeline_tag in curlSnippets;
}
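
A rough sketch of driving the generalized curl dispatcher; the model id and token are placeholders, not values from this PR. Each map entry now returns the complete command, URL and auth header included:

import type { ModelData } from "../interfaces/Types";
import { getCurlInferenceSnippet, hasCurlInferenceSnippet } from "./serveCurl";

// Hypothetical model stub and placeholder token.
const model = { id: "my-user/my-zero-shot-model", pipeline_tag: "zero-shot-classification" } as ModelData;

if (hasCurlInferenceSnippet(model)) {
	// Dispatches to snippetZeroShotClassification: a single curl command that POSTs
	// the task's example inputs together with the candidate_labels parameters and
	// sends "Authorization: Bearer hf_xxx".
	console.log(getCurlInferenceSnippet(model, "hf_xxx"));
}
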
95 changes: 64 additions & 31 deletions js/src/lib/inferenceSnippets/serveJs.ts
@@ -1,37 +1,26 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

export const bodyBasic = (model: ModelData): string =>
`{"inputs": ${getModelInputSnippet(model)}}`;

export const bodyZeroShotClassification = (model: ModelData): string =>
`{"inputs": ${getModelInputSnippet(model)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}`;
export const snippetBasic = (model: ModelData, accessToken: string): string =>
`async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
method: "POST",
body: JSON.stringify(data),
}
);
const result = await response.json();
return result;
}

export const jsSnippetBodies:
Partial<Record<PipelineType, (model: ModelData) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": bodyBasic,
"token-classification": bodyBasic,
"table-question-answering": bodyBasic,
"question-answering": bodyBasic,
"zero-shot-classification": bodyZeroShotClassification,
"translation": bodyBasic,
"summarization": bodyBasic,
"conversational": bodyBasic,
"feature-extraction": bodyBasic,
"text-generation": bodyBasic,
"text2text-generation": bodyBasic,
"fill-mask": bodyBasic,
"sentence-similarity": bodyBasic,
};
query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
console.log(JSON.stringify(response));
});`;

export function getJsInferenceSnippet(model: ModelData, accessToken: string): string {
const body = model.pipeline_tag && model.pipeline_tag in jsSnippetBodies
? jsSnippetBodies[model.pipeline_tag]?.(model) ?? ""
: "";

return `async function query(data) {
export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string =>
`async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
@@ -44,11 +33,55 @@ export function getJsInferenceSnippet(model: ModelData, accessToken: string): st
return result;
}

query(${body}).then((response) => {
query({"inputs": ${getModelInputSnippet(model)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => {
console.log(JSON.stringify(response));
});`;

export const snippetFile = (model: ModelData, accessToken: string): string =>
`async function query(filename) {
const data = fs.readFileSync(filename);
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
method: "POST",
body: data,
}
);
const result = await response.json();
return result;
}

query(${getModelInputSnippet(model)}).then((response) => {
console.log(JSON.stringify(response));
});`;

export const jsSnippets:
Partial<Record<PipelineType, (model: ModelData, accessToken: string) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": snippetBasic,
"token-classification": snippetBasic,
"table-question-answering": snippetBasic,
"question-answering": snippetBasic,
"zero-shot-classification": snippetZeroShotClassification,
"translation": snippetBasic,
"summarization": snippetBasic,
"conversational": snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"image-classification": snippetFile,
};

export function getJsInferenceSnippet(model: ModelData, accessToken: string): string {
return model.pipeline_tag && model.pipeline_tag in jsSnippets
? jsSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
: "";
}

export function hasJsInferenceSnippet(model: ModelData): boolean {
return !!model.pipeline_tag && model.pipeline_tag in jsSnippetBodies;
return !!model.pipeline_tag && model.pipeline_tag in jsSnippets;
}
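
A similar sketch for the JS side (again with a hypothetical model stub). The emitted image-classification snippet reads the file via fs, so it assumes a Node.js context where fs is in scope:

import type { ModelData } from "../interfaces/Types";
import { getJsInferenceSnippet } from "./serveJs";

// Hypothetical image model; passing an empty token keeps the {API_TOKEN} placeholder.
const model = { id: "my-user/my-image-model", pipeline_tag: "image-classification" } as ModelData;

console.log(getJsInferenceSnippet(model, ""));
// Prints an async query("cats.jpg") example that POSTs the raw file bytes
// and logs the JSON response returned by the Inference API.
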
64 changes: 39 additions & 25 deletions js/src/lib/inferenceSnippets/servePython.ts
@@ -1,53 +1,67 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

export const bodyZeroShotClassification = (model: ModelData): string =>
`output = query({
export const snippetZeroShotClassification = (model: ModelData): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()

output = query({
"inputs": ${getModelInputSnippet(model)},
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`;

export const bodyBasic = (model: ModelData): string =>
`output = query({
export const snippetBasic = (model: ModelData): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()

output = query({
"inputs": ${getModelInputSnippet(model)},
})`;

export const pythonSnippetBodies:
export const snippetFile = (model: ModelData): string =>
`def query(filename):
with open(filename, "rb") as f:
data = f.read()
response = requests.request("POST", API_URL, headers=headers, data=data)
return json.loads(response.content.decode("utf-8"))

output = query(${getModelInputSnippet(model)})`;

export const pythonSnippets:
Partial<Record<PipelineType, (model: ModelData) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": bodyBasic,
"token-classification": bodyBasic,
"table-question-answering": bodyBasic,
"question-answering": bodyBasic,
"zero-shot-classification": bodyZeroShotClassification,
"translation": bodyBasic,
"summarization": bodyBasic,
"conversational": bodyBasic,
"feature-extraction": bodyBasic,
"text-generation": bodyBasic,
"text2text-generation": bodyBasic,
"fill-mask": bodyBasic,
"sentence-similarity": bodyBasic,
"text-classification": snippetBasic,
"token-classification": snippetBasic,
"table-question-answering": snippetBasic,
"question-answering": snippetBasic,
"zero-shot-classification": snippetZeroShotClassification,
"translation": snippetBasic,
"summarization": snippetBasic,
"conversational": snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"image-classification": snippetFile,
};

export function getPythonInferenceSnippet(model: ModelData, accessToken: string): string {
const body = model.pipeline_tag && model.pipeline_tag in pythonSnippetBodies
? pythonSnippetBodies[model.pipeline_tag]?.(model) ?? ""
const body = model.pipeline_tag && model.pipeline_tag in pythonSnippets
? pythonSnippets[model.pipeline_tag]?.(model) ?? ""
: "";

return `import requests

API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}

def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()

${body}`;
}

export function hasPythonInferenceSnippet(model: ModelData): boolean {
return !!model.pipeline_tag && model.pipeline_tag in pythonSnippetBodies;
return !!model.pipeline_tag && model.pipeline_tag in pythonSnippets;
}
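
And a sketch for the Python side, with a placeholder token. Because each task body now defines its own query(), file-based tasks can accept a filename instead of a JSON payload:

import type { ModelData } from "../interfaces/Types";
import { getPythonInferenceSnippet } from "./servePython";

// Hypothetical text model and placeholder token (not a real credential).
const model = { id: "my-user/my-text-classifier", pipeline_tag: "text-classification" } as ModelData;

// Returns a ready-to-paste Python program: the requests import, API_URL, headers,
// and the task-specific body (snippetBasic here, which ships its own query()).
console.log(getPythonInferenceSnippet(model, "hf_xxx"));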