Generalize inference snippets (huggingface#106)
* Snippet for `img-cls` task

* Chore

* modelInputSnippets lexicographical ordering

* Better renamings
mishig25 authored Apr 19, 2022
1 parent c539e60 commit 9cf841d
Showing 4 changed files with 156 additions and 88 deletions.
19 changes: 15 additions & 4 deletions js/src/lib/inferenceSnippets/inputs.ts
@@ -65,12 +65,15 @@ const inputsSentenceSimilarity = () =>
const inputsFeatureExtraction = () =>
`"Today is a sunny day and I'll get some ice cream."`;

const inputsImageClassification = () => `"cats.jpg"`;

const modelInputSnippets: {
[key in PipelineType]?: (model: ModelData) => string;
} = {
"conversational": inputsConversational,
"feature-extraction": inputsFeatureExtraction,
"fill-mask": inputsFillMask,
"image-classification": inputsImageClassification,
"question-answering": inputsQuestionAnswering,
"sentence-similarity": inputsSentenceSimilarity,
"summarization": inputsSummarization,
@@ -84,13 +87,21 @@ const modelInputSnippets: {
};

// Use noWrap to put the whole snippet on a single line (removing new lines and tabulations)
export function getModelInputSnippet(model: ModelData, noWrap = false): string {
// Use noQuotes to strip quotes from start & end (example: "abc" -> abc)
export function getModelInputSnippet(model: ModelData, noWrap = false, noQuotes = false): string {
if (model.pipeline_tag) {
const inputs = modelInputSnippets[model.pipeline_tag];
if (inputs) {
return noWrap
? inputs(model).replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ")
: inputs(model);
let result = inputs(model);
if (noWrap) {
result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ");
}
if (noQuotes) {
const REGEX_QUOTES = /^"(.+)"$/s;
const match = result.match(REGEX_QUOTES);
result = match ? match[1] : result;
}
return result;
}
}
return "No input example has been defined for this model task.";
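
The new noQuotes flag lets callers reuse the same input example with or without its surrounding quotes (e.g. for curl's --data-binary upload). A minimal usage sketch, not part of the commit — the model id and the `as ModelData` cast are only illustrative:

import type { ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

// Hypothetical model object; only `id` and `pipeline_tag` matter for snippet generation.
const model = { id: "google/vit-base-patch16-224", pipeline_tag: "image-classification" } as ModelData;

getModelInputSnippet(model);             // '"cats.jpg"' — quotes kept, suitable for JSON bodies
getModelInputSnippet(model, true);       // '"cats.jpg"' — noWrap is a no-op for a single-line input
getModelInputSnippet(model, true, true); // 'cats.jpg'   — noQuotes strips the outer quotes for file paths
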
66 changes: 38 additions & 28 deletions js/src/lib/inferenceSnippets/serveCurl.ts
@@ -1,43 +1,53 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

export const bodyBasic = (model: ModelData): string =>
`-d '{"inputs": ${getModelInputSnippet(model, true)}}'`;
export const snippetBasic = (model: ModelData, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

export const bodyZeroShotClassification = (model: ModelData): string =>
`-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}'`;
export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

export const curlSnippetBodies:
Partial<Record<PipelineType, (model: ModelData) => string>> =
export const snippetFile = (model: ModelData, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
--data-binary '@${getModelInputSnippet(model, true, true)}' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

export const curlSnippets:
Partial<Record<PipelineType, (model: ModelData, accessToken: string) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": bodyBasic,
"token-classification": bodyBasic,
"table-question-answering": bodyBasic,
"question-answering": bodyBasic,
"zero-shot-classification": bodyZeroShotClassification,
"translation": bodyBasic,
"summarization": bodyBasic,
"conversational": bodyBasic,
"feature-extraction": bodyBasic,
"text-generation": bodyBasic,
"text2text-generation": bodyBasic,
"fill-mask": bodyBasic,
"sentence-similarity": bodyBasic,
"text-classification": snippetBasic,
"token-classification": snippetBasic,
"table-question-answering": snippetBasic,
"question-answering": snippetBasic,
"zero-shot-classification": snippetZeroShotClassification,
"translation": snippetBasic,
"summarization": snippetBasic,
"conversational": snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"image-classification": snippetFile,
};

export function getCurlInferenceSnippet(model: ModelData, accessToken: string): string {
const body = model.pipeline_tag && model.pipeline_tag in curlSnippetBodies
? curlSnippetBodies[model.pipeline_tag]?.(model) ?? ""
return model.pipeline_tag && model.pipeline_tag in curlSnippets
? curlSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
: "";

return `curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
${body} \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;
}

export function hasCurlInferenceSnippet(model: ModelData): boolean {
return !!model.pipeline_tag && model.pipeline_tag in curlSnippetBodies;
return !!model.pipeline_tag && model.pipeline_tag in curlSnippets;
}
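
With each per-task function now owning the whole command, the file-based task renders a --data-binary upload instead of a JSON -d body. A rough sketch of how this is consumed (illustrative model id; the comment shows the approximate rendered output, not text copied from the repo):

import type { ModelData } from "../interfaces/Types";
import { getCurlInferenceSnippet, hasCurlInferenceSnippet } from "./serveCurl";

const model = { id: "google/vit-base-patch16-224", pipeline_tag: "image-classification" } as ModelData;

if (hasCurlInferenceSnippet(model)) {
	console.log(getCurlInferenceSnippet(model, ""));
	// Roughly:
	// curl https://api-inference.huggingface.co/models/google/vit-base-patch16-224 \
	//	-X POST \
	//	--data-binary '@cats.jpg' \
	//	-H "Authorization: Bearer {API_TOKEN}"
}
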
95 changes: 64 additions & 31 deletions js/src/lib/inferenceSnippets/serveJs.ts
@@ -1,37 +1,26 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

export const bodyBasic = (model: ModelData): string =>
`{"inputs": ${getModelInputSnippet(model)}}`;

export const bodyZeroShotClassification = (model: ModelData): string =>
`{"inputs": ${getModelInputSnippet(model)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}`;
export const snippetBasic = (model: ModelData, accessToken: string): string =>
`async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
method: "POST",
body: JSON.stringify(data),
}
);
const result = await response.json();
return result;
}
export const jsSnippetBodies:
Partial<Record<PipelineType, (model: ModelData) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": bodyBasic,
"token-classification": bodyBasic,
"table-question-answering": bodyBasic,
"question-answering": bodyBasic,
"zero-shot-classification": bodyZeroShotClassification,
"translation": bodyBasic,
"summarization": bodyBasic,
"conversational": bodyBasic,
"feature-extraction": bodyBasic,
"text-generation": bodyBasic,
"text2text-generation": bodyBasic,
"fill-mask": bodyBasic,
"sentence-similarity": bodyBasic,
};
query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
console.log(JSON.stringify(response));
});`;

export function getJsInferenceSnippet(model: ModelData, accessToken: string): string {
const body = model.pipeline_tag && model.pipeline_tag in jsSnippetBodies
? jsSnippetBodies[model.pipeline_tag]?.(model) ?? ""
: "";

return `async function query(data) {
export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string =>
`async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
@@ -44,11 +33,55 @@ export function getJsInferenceSnippet(model: ModelData, accessToken: string): string {
return result;
}
query(${body}).then((response) => {
query({"inputs": ${getModelInputSnippet(model)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => {
console.log(JSON.stringify(response));
});`;

export const snippetFile = (model: ModelData, accessToken: string): string =>
`async function query(filename) {
const data = fs.readFileSync(filename);
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
method: "POST",
body: data,
}
);
const result = await response.json();
return result;
}
query(${getModelInputSnippet(model)}).then((response) => {
console.log(JSON.stringify(response));
});`;

export const jsSnippets:
Partial<Record<PipelineType, (model: ModelData, accessToken: string) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": snippetBasic,
"token-classification": snippetBasic,
"table-question-answering": snippetBasic,
"question-answering": snippetBasic,
"zero-shot-classification": snippetZeroShotClassification,
"translation": snippetBasic,
"summarization": snippetBasic,
"conversational": snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"image-classification": snippetFile,
};

export function getJsInferenceSnippet(model: ModelData, accessToken: string): string {
return model.pipeline_tag && model.pipeline_tag in jsSnippets
? jsSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
: "";
}

export function hasJsInferenceSnippet(model: ModelData): boolean {
return !!model.pipeline_tag && model.pipeline_tag in jsSnippetBodies;
return !!model.pipeline_tag && model.pipeline_tag in jsSnippets;
}
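
The JS module follows the same map-based dispatch. One caveat: the generated file-upload snippet calls fs.readFileSync without importing fs, so the code a user copies assumes a Node.js context where fs is already in scope (e.g. import fs from "fs"). A usage sketch with an illustrative model id:

import type { ModelData } from "../interfaces/Types";
import { getJsInferenceSnippet, hasJsInferenceSnippet } from "./serveJs";

const model = { id: "google/vit-base-patch16-224", pipeline_tag: "image-classification" } as ModelData;

if (hasJsInferenceSnippet(model)) {
	// Renders the fetch-based `query(filename)` snippet; an empty token falls back to {API_TOKEN}.
	console.log(getJsInferenceSnippet(model, ""));
}
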
64 changes: 39 additions & 25 deletions js/src/lib/inferenceSnippets/servePython.ts
@@ -1,53 +1,67 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

export const bodyZeroShotClassification = (model: ModelData): string =>
`output = query({
export const snippetZeroShotClassification = (model: ModelData): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
output = query({
"inputs": ${getModelInputSnippet(model)},
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`;

export const bodyBasic = (model: ModelData): string =>
`output = query({
export const snippetBasic = (model: ModelData): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
output = query({
"inputs": ${getModelInputSnippet(model)},
})`;

export const pythonSnippetBodies:
export const snippetFile = (model: ModelData): string =>
`def query(filename):
with open(filename, "rb") as f:
data = f.read()
response = requests.request("POST", API_URL, headers=headers, data=data)
return json.loads(response.content.decode("utf-8"))
output = query(${getModelInputSnippet(model)})`;

export const pythonSnippets:
Partial<Record<PipelineType, (model: ModelData) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": bodyBasic,
"token-classification": bodyBasic,
"table-question-answering": bodyBasic,
"question-answering": bodyBasic,
"zero-shot-classification": bodyZeroShotClassification,
"translation": bodyBasic,
"summarization": bodyBasic,
"conversational": bodyBasic,
"feature-extraction": bodyBasic,
"text-generation": bodyBasic,
"text2text-generation": bodyBasic,
"fill-mask": bodyBasic,
"sentence-similarity": bodyBasic,
"text-classification": snippetBasic,
"token-classification": snippetBasic,
"table-question-answering": snippetBasic,
"question-answering": snippetBasic,
"zero-shot-classification": snippetZeroShotClassification,
"translation": snippetBasic,
"summarization": snippetBasic,
"conversational": snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"image-classification": snippetFile,
};

export function getPythonInferenceSnippet(model: ModelData, accessToken: string): string {
const body = model.pipeline_tag && model.pipeline_tag in pythonSnippetBodies
? pythonSnippetBodies[model.pipeline_tag]?.(model) ?? ""
const body = model.pipeline_tag && model.pipeline_tag in pythonSnippets
? pythonSnippets[model.pipeline_tag]?.(model) ?? ""
: "";

return `import requests
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
${body}`;
}

export function hasPythonInferenceSnippet(model: ModelData): boolean {
return !!model.pipeline_tag && model.pipeline_tag in pythonSnippetBodies;
return !!model.pipeline_tag && model.pipeline_tag in pythonSnippets;
}
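
In the Python module each per-task body now carries its own def query, which is what lets the file-based task read raw bytes instead of posting JSON. A sketch of rendering it (illustrative model id; note the file-based body calls json.loads, so the copied Python snippet presumably needs an import json of its own):

import type { ModelData } from "../interfaces/Types";
import { getPythonInferenceSnippet, hasPythonInferenceSnippet } from "./servePython";

const model = { id: "google/vit-base-patch16-224", pipeline_tag: "image-classification" } as ModelData;

if (hasPythonInferenceSnippet(model)) {
	// Renders the `import requests` / API_URL / headers boilerplate followed by the
	// file-reading `def query(filename)` body and `output = query("cats.jpg")`.
	console.log(getPythonInferenceSnippet(model, ""));
}
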
