forked from huggingface/hub-docs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generalize
inference snippets
(huggingface#106)
* Snippet for `img-cls` task * Chore * modelInputSnippets lexiographical ordering * Better renamings
- Loading branch information
Showing
4 changed files
with
156 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,53 @@ | ||
import type { PipelineType, ModelData } from "../interfaces/Types"; | ||
import { getModelInputSnippet } from "./inputs"; | ||
|
||
export const bodyBasic = (model: ModelData): string => | ||
`-d '{"inputs": ${getModelInputSnippet(model, true)}}'`; | ||
export const snippetBasic = (model: ModelData, accessToken: string): string => | ||
`curl https://api-inference.huggingface.co/models/${model.id} \\ | ||
-X POST \\ | ||
-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\ | ||
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" | ||
`; | ||
|
||
export const bodyZeroShotClassification = (model: ModelData): string => | ||
`-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}'`; | ||
export const snippetZeroShotClassification = (model: ModelData, accessToken: string): string => | ||
`curl https://api-inference.huggingface.co/models/${model.id} \\ | ||
-X POST \\ | ||
-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\ | ||
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" | ||
`; | ||
|
||
export const curlSnippetBodies: | ||
Partial<Record<PipelineType, (model: ModelData) => string>> = | ||
export const snippetFile = (model: ModelData, accessToken: string): string => | ||
`curl https://api-inference.huggingface.co/models/${model.id} \\ | ||
-X POST \\ | ||
--data-binary '@${getModelInputSnippet(model, true, true)}' \\ | ||
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" | ||
`; | ||
|
||
export const curlSnippets: | ||
Partial<Record<PipelineType, (model: ModelData, accessToken: string) => string>> = | ||
{ | ||
// Same order as in js/src/lib/interfaces/Types.ts | ||
"text-classification": bodyBasic, | ||
"token-classification": bodyBasic, | ||
"table-question-answering": bodyBasic, | ||
"question-answering": bodyBasic, | ||
"zero-shot-classification": bodyZeroShotClassification, | ||
"translation": bodyBasic, | ||
"summarization": bodyBasic, | ||
"conversational": bodyBasic, | ||
"feature-extraction": bodyBasic, | ||
"text-generation": bodyBasic, | ||
"text2text-generation": bodyBasic, | ||
"fill-mask": bodyBasic, | ||
"sentence-similarity": bodyBasic, | ||
"text-classification": snippetBasic, | ||
"token-classification": snippetBasic, | ||
"table-question-answering": snippetBasic, | ||
"question-answering": snippetBasic, | ||
"zero-shot-classification": snippetZeroShotClassification, | ||
"translation": snippetBasic, | ||
"summarization": snippetBasic, | ||
"conversational": snippetBasic, | ||
"feature-extraction": snippetBasic, | ||
"text-generation": snippetBasic, | ||
"text2text-generation": snippetBasic, | ||
"fill-mask": snippetBasic, | ||
"sentence-similarity": snippetBasic, | ||
"image-classification": snippetFile, | ||
}; | ||
|
||
export function getCurlInferenceSnippet(model: ModelData, accessToken: string): string { | ||
const body = model.pipeline_tag && model.pipeline_tag in curlSnippetBodies | ||
? curlSnippetBodies[model.pipeline_tag]?.(model) ?? "" | ||
return model.pipeline_tag && model.pipeline_tag in curlSnippets | ||
? curlSnippets[model.pipeline_tag]?.(model, accessToken) ?? "" | ||
: ""; | ||
|
||
return `curl https://api-inference.huggingface.co/models/${model.id} \\ | ||
-X POST \\ | ||
${body} \\ | ||
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" | ||
`; | ||
} | ||
|
||
export function hasCurlInferenceSnippet(model: ModelData): boolean { | ||
return !!model.pipeline_tag && model.pipeline_tag in curlSnippetBodies; | ||
return !!model.pipeline_tag && model.pipeline_tag in curlSnippets; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,67 @@ | ||
import type { PipelineType, ModelData } from "../interfaces/Types"; | ||
import { getModelInputSnippet } from "./inputs"; | ||
|
||
export const bodyZeroShotClassification = (model: ModelData): string => | ||
`output = query({ | ||
export const snippetZeroShotClassification = (model: ModelData): string => | ||
`def query(payload): | ||
response = requests.post(API_URL, headers=headers, json=payload) | ||
return response.json() | ||
output = query({ | ||
"inputs": ${getModelInputSnippet(model)}, | ||
"parameters": {"candidate_labels": ["refund", "legal", "faq"]}, | ||
})`; | ||
|
||
export const bodyBasic = (model: ModelData): string => | ||
`output = query({ | ||
export const snippetBasic = (model: ModelData): string => | ||
`def query(payload): | ||
response = requests.post(API_URL, headers=headers, json=payload) | ||
return response.json() | ||
output = query({ | ||
"inputs": ${getModelInputSnippet(model)}, | ||
})`; | ||
|
||
export const pythonSnippetBodies: | ||
export const snippetFile = (model: ModelData): string => | ||
`def query(filename): | ||
with open(filename, "rb") as f: | ||
data = f.read() | ||
response = requests.request("POST", API_URL, headers=headers, data=data) | ||
return json.loads(response.content.decode("utf-8")) | ||
output = query(${getModelInputSnippet(model)})`; | ||
|
||
export const pythonSnippets: | ||
Partial<Record<PipelineType, (model: ModelData) => string>> = | ||
{ | ||
// Same order as in js/src/lib/interfaces/Types.ts | ||
"text-classification": bodyBasic, | ||
"token-classification": bodyBasic, | ||
"table-question-answering": bodyBasic, | ||
"question-answering": bodyBasic, | ||
"zero-shot-classification": bodyZeroShotClassification, | ||
"translation": bodyBasic, | ||
"summarization": bodyBasic, | ||
"conversational": bodyBasic, | ||
"feature-extraction": bodyBasic, | ||
"text-generation": bodyBasic, | ||
"text2text-generation": bodyBasic, | ||
"fill-mask": bodyBasic, | ||
"sentence-similarity": bodyBasic, | ||
"text-classification": snippetBasic, | ||
"token-classification": snippetBasic, | ||
"table-question-answering": snippetBasic, | ||
"question-answering": snippetBasic, | ||
"zero-shot-classification": snippetZeroShotClassification, | ||
"translation": snippetBasic, | ||
"summarization": snippetBasic, | ||
"conversational": snippetBasic, | ||
"feature-extraction": snippetBasic, | ||
"text-generation": snippetBasic, | ||
"text2text-generation": snippetBasic, | ||
"fill-mask": snippetBasic, | ||
"sentence-similarity": snippetBasic, | ||
"image-classification": snippetFile, | ||
}; | ||
|
||
export function getPythonInferenceSnippet(model: ModelData, accessToken: string): string { | ||
const body = model.pipeline_tag && model.pipeline_tag in pythonSnippetBodies | ||
? pythonSnippetBodies[model.pipeline_tag]?.(model) ?? "" | ||
const body = model.pipeline_tag && model.pipeline_tag in pythonSnippets | ||
? pythonSnippets[model.pipeline_tag]?.(model) ?? "" | ||
: ""; | ||
|
||
return `import requests | ||
API_URL = "https://api-inference.huggingface.co/models/${model.id}" | ||
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}} | ||
def query(payload): | ||
response = requests.post(API_URL, headers=headers, json=payload) | ||
return response.json() | ||
${body}`; | ||
} | ||
|
||
export function hasPythonInferenceSnippet(model: ModelData): boolean { | ||
return !!model.pipeline_tag && model.pipeline_tag in pythonSnippetBodies; | ||
return !!model.pipeline_tag && model.pipeline_tag in pythonSnippets; | ||
} |