Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize inference snippets #106

Merged
merged 4 commits into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions js/src/lib/inferenceSnippets/inputs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ const inputsSentenceSimilarity = () =>
const inputsFeatureExtraction = () =>
`"Today is a sunny day and I'll get some ice cream."`;

const inputsImageClassification = () => `"cats.jpg"`;

const modelInputSnippets: {
[key in PipelineType]?: (model: ModelData) => string;
} = {
Expand All @@ -81,16 +83,24 @@ const modelInputSnippets: {
"token-classification": inputsTokenClassification,
"translation": inputsTranslation,
"zero-shot-classification": inputsZeroShotClassification,
"image-classification": inputsImageClassification,
mishig25 marked this conversation as resolved.
Show resolved Hide resolved
};

// Use noWrap to put the whole snippet on a single line (removing new lines and tabulations)
export function getModelInputSnippet(model: ModelData, noWrap = false): string {
// Use noQuotes to strip quotes from start & end (example: "abc" -> abc)
export function getModelInputSnippet(model: ModelData, noWrap = false, noQuotes = false): string {
const REGEX_QUOTES = /^"(.+)"$/s;
if (model.pipeline_tag) {
const inputs = modelInputSnippets[model.pipeline_tag];
if (inputs) {
return noWrap
? inputs(model).replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ")
: inputs(model);
let result = inputs(model);
if (noWrap) {
result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ");
}
if (noQuotes && result.match(REGEX_QUOTES)) {
result = result.match(REGEX_QUOTES)[1];
}
return result;
mishig25 marked this conversation as resolved.
Show resolved Hide resolved
}
}
return "No input example has been defined for this model task.";
Expand Down
62 changes: 36 additions & 26 deletions js/src/lib/inferenceSnippets/serveCurl.ts
Original file line number Diff line number Diff line change
@@ -1,41 +1,51 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

export const bodyBasic = (model: ModelData): string =>
`-d '{"inputs": ${getModelInputSnippet(model, true)}}'`;
export const snippetBasic = (model: ModelData, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

export const bodyZeroShotClassification = (model: ModelData): string =>
`-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}'`;
export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

export const snippetFile = (model: ModelData, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
--data-binary '@${getModelInputSnippet(model, true, true)}' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

export const curlSnippetBodies:
mishig25 marked this conversation as resolved.
Show resolved Hide resolved
Partial<Record<PipelineType, (model: ModelData) => string>> =
Partial<Record<PipelineType, (model: ModelData, accessToken: string) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": bodyBasic,
"token-classification": bodyBasic,
"table-question-answering": bodyBasic,
"question-answering": bodyBasic,
"zero-shot-classification": bodyZeroShotClassification,
"translation": bodyBasic,
"summarization": bodyBasic,
"conversational": bodyBasic,
"feature-extraction": bodyBasic,
"text-generation": bodyBasic,
"text2text-generation": bodyBasic,
"fill-mask": bodyBasic,
"sentence-similarity": bodyBasic,
"text-classification": snippetBasic,
"token-classification": snippetBasic,
"table-question-answering": snippetBasic,
"question-answering": snippetBasic,
"zero-shot-classification": snippetZeroShotClassification,
"translation": snippetBasic,
"summarization": snippetBasic,
"conversational": snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"image-classification": snippetFile,
};

export function getCurlInferenceSnippet(model: ModelData, accessToken: string): string {
const body = model.pipeline_tag && model.pipeline_tag in curlSnippetBodies
? curlSnippetBodies[model.pipeline_tag]?.(model) ?? ""
return model.pipeline_tag && model.pipeline_tag in curlSnippetBodies
? curlSnippetBodies[model.pipeline_tag]?.(model, accessToken) ?? ""
: "";

return `curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
${body} \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;
}

export function hasCurlInferenceSnippet(model: ModelData): boolean {
Expand Down
93 changes: 63 additions & 30 deletions js/src/lib/inferenceSnippets/serveJs.ts
Original file line number Diff line number Diff line change
@@ -1,37 +1,26 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

export const bodyBasic = (model: ModelData): string =>
`{"inputs": ${getModelInputSnippet(model)}}`;

export const bodyZeroShotClassification = (model: ModelData): string =>
`{"inputs": ${getModelInputSnippet(model)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}`;
export const snippetBasic = (model: ModelData, accessToken: string): string =>
`async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
method: "POST",
body: JSON.stringify(data),
}
);
const result = await response.json();
return result;
}

export const jsSnippetBodies:
Partial<Record<PipelineType, (model: ModelData) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": bodyBasic,
"token-classification": bodyBasic,
"table-question-answering": bodyBasic,
"question-answering": bodyBasic,
"zero-shot-classification": bodyZeroShotClassification,
"translation": bodyBasic,
"summarization": bodyBasic,
"conversational": bodyBasic,
"feature-extraction": bodyBasic,
"text-generation": bodyBasic,
"text2text-generation": bodyBasic,
"fill-mask": bodyBasic,
"sentence-similarity": bodyBasic,
};
query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
console.log(JSON.stringify(response));
});`;

export function getJsInferenceSnippet(model: ModelData, accessToken: string): string {
const body = model.pipeline_tag && model.pipeline_tag in jsSnippetBodies
? jsSnippetBodies[model.pipeline_tag]?.(model) ?? ""
: "";

return `async function query(data) {
export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string =>
`async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
Expand All @@ -44,9 +33,53 @@ export function getJsInferenceSnippet(model: ModelData, accessToken: string): st
return result;
}

query(${body}).then((response) => {
query({"inputs": ${getModelInputSnippet(model)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => {
console.log(JSON.stringify(response));
});`;

export const snippetFile = (model: ModelData, accessToken: string): string =>
`async function query(filename) {
const data = fs.readFileSync(filename);
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
method: "POST",
body: data,
}
);
const result = await response.json();
return result;
}

query(${getModelInputSnippet(model)}).then((response) => {
console.log(JSON.stringify(response));
});`;

export const jsSnippetBodies:
Partial<Record<PipelineType, (model: ModelData, accessToken: string) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": snippetBasic,
"token-classification": snippetBasic,
"table-question-answering": snippetBasic,
"question-answering": snippetBasic,
"zero-shot-classification": snippetZeroShotClassification,
"translation": snippetBasic,
"summarization": snippetBasic,
"conversational": snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"image-classification": snippetFile,
};

export function getJsInferenceSnippet(model: ModelData, accessToken: string): string {
return model.pipeline_tag && model.pipeline_tag in jsSnippetBodies
? jsSnippetBodies[model.pipeline_tag]?.(model, accessToken) ?? ""
: "";
}

export function hasJsInferenceSnippet(model: ModelData): boolean {
Expand Down
56 changes: 35 additions & 21 deletions js/src/lib/inferenceSnippets/servePython.ts
Original file line number Diff line number Diff line change
@@ -1,34 +1,52 @@
import type { PipelineType, ModelData } from "../interfaces/Types";
import { getModelInputSnippet } from "./inputs";

export const bodyZeroShotClassification = (model: ModelData): string =>
`output = query({
export const snippetZeroShotClassification = (model: ModelData): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()

output = query({
"inputs": ${getModelInputSnippet(model)},
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`;

export const bodyBasic = (model: ModelData): string =>
`output = query({
export const snippetBasic = (model: ModelData): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()

output = query({
"inputs": ${getModelInputSnippet(model)},
})`;

export const snippetFile = (model: ModelData): string =>
`def query(filename):
with open(filename, "rb") as f:
data = f.read()
response = requests.request("POST", API_URL, headers=headers, data=data)
return json.loads(response.content.decode("utf-8"))

output = query(${getModelInputSnippet(model)})`;

export const pythonSnippetBodies:
Partial<Record<PipelineType, (model: ModelData) => string>> =
{
// Same order as in js/src/lib/interfaces/Types.ts
"text-classification": bodyBasic,
"token-classification": bodyBasic,
"table-question-answering": bodyBasic,
"question-answering": bodyBasic,
"zero-shot-classification": bodyZeroShotClassification,
"translation": bodyBasic,
"summarization": bodyBasic,
"conversational": bodyBasic,
"feature-extraction": bodyBasic,
"text-generation": bodyBasic,
"text2text-generation": bodyBasic,
"fill-mask": bodyBasic,
"sentence-similarity": bodyBasic,
"text-classification": snippetBasic,
"token-classification": snippetBasic,
"table-question-answering": snippetBasic,
"question-answering": snippetBasic,
"zero-shot-classification": snippetZeroShotClassification,
"translation": snippetBasic,
"summarization": snippetBasic,
"conversational": snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"image-classification": snippetFile,
};

export function getPythonInferenceSnippet(model: ModelData, accessToken: string): string {
Expand All @@ -41,10 +59,6 @@ export function getPythonInferenceSnippet(model: ModelData, accessToken: string)
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}

def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()

${body}`;
}

Expand Down