Skip to content

Commit

Permalink
feat: aws sagemaker compatible image (#147)
Browse files Browse the repository at this point in the history
The only difference is that now it pushes to
registry.internal.huggingface.tech/api-inference/community/text-generation-inference/sagemaker:...
instead of
registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sagemaker-...

---------

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
  • Loading branch information
OlivierDehaene and philschmid authored Mar 29, 2023
1 parent c9bdaa8 commit d503e8f
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 3 deletions.
46 changes: 45 additions & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,48 @@ jobs:
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max

build-and-push-sagemaker-image:
needs:
- build-and-push-image
runs-on: ubuntu-latest
steps:
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v2.0.0
with:
install: true
- name: Checkout repository
uses: actions/checkout@v3
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4
- name: Login to internal Container Registry
uses: docker/login-action@v2.1.0
with:
username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4.3.0
with:
flavor: |
latest=auto
images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference/sagemaker
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
- name: Build and push Docker image
uses: docker/build-push-action@v2
with:
context: .
file: Dockerfile
push: ${{ github.event_name != 'pull_request' }}
platforms: 'linux/amd64'
target: sagemaker
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
13 changes: 12 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ COPY router router
COPY launcher launcher
RUN cargo build --release

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as base

ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
Expand Down Expand Up @@ -76,5 +76,16 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

# AWS Sagemaker compatbile image
FROM base as sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Original image
FROM base

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
10 changes: 9 additions & 1 deletion router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,11 +529,19 @@ pub async fn run(
// Create router
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
// Base routes
.route("/", post(compat_generate))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/", get(health))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
// Base Health route
.route("/health", get(health))
// Inference API health route
.route("/", get(health))
// AWS Sagemaker health route
.route("/ping", get(health))
// Prometheus metrics route
.route("/metrics", get(metrics))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
Expand Down
20 changes: 20 additions & 0 deletions sagemaker-entrypoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/bin/bash

if [[ -z "${HF_MODEL_ID}" ]]; then
echo "HF_MODEL_ID must be set"
exit 1
fi

if [[ -n "${HF_MODEL_REVISION}" ]]; then
export REVISION="${HF_MODEL_REVISION}"
fi

if [[ -n "${SM_NUM_GPUS}" ]]; then
export NUM_SHARD="${SM_NUM_GPUS}"
fi

if [[ -n "${HF_MODEL_QUANTIZE}" ]]; then
export QUANTIZE="${HF_MODEL_QUANTIZE}"
fi

text-generation-launcher --port 8080

0 comments on commit d503e8f

Please sign in to comment.