-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
- Loading branch information
0 parents
commit 0cef693
Showing
12 changed files
with
1,125 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Local Python version pin file
.python-version
# Python bytecode caches
__pycache__/
# macOS Finder metadata; this is a file, so the pattern must not end with "/"
# (a trailing slash would restrict the match to directories)
.DS_Store
# Editor settings
.vscode/
# Packaged archives
*.tgz
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#####################
# BANANA DOCKERFILE #
#####################
# Builds the inference image: CUDA base, build toolchain and CMake, system
# libraries plus youtube-dl, Python requirements, the GeoEstimation sources
# with the service files copied in, the S2 cell CSVs and the ONNX model.

# Must use cuda version 11+
FROM nvcr.io/nvidia/cuda:11.4.2-cudnn8-devel-ubuntu20.04

# Build-time only: keeps apt non-interactive without persisting the setting
# into the environment of the running container.
ARG DEBIAN_FRONTEND=noninteractive
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}

# Toolchain + CMake 3.21.0. The downloaded tarball and the apt package lists
# are removed in the same layer so they do not end up in the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
        aria2 \
        ca-certificates \
        g++ \
        gcc \
        git \
        make \
        python3-dev \
        python3-numpy \
        python3-pip \
        python3-setuptools \
        python3-wheel \
    && aria2c -q -d /tmp -o cmake-3.21.0-linux-x86_64.tar.gz https://github.com/Kitware/CMake/releases/download/v3.21.0/cmake-3.21.0-linux-x86_64.tar.gz \
    && tar -zxf /tmp/cmake-3.21.0-linux-x86_64.tar.gz --strip-components=1 -C /usr \
    && rm -f /tmp/cmake-3.21.0-linux-x86_64.tar.gz \
    && rm -rf /var/lib/apt/lists/*
WORKDIR /

# Install system libraries (ffmpeg, BLAS/LAPACK, gfortran, libsm6), wget and
# youtube-dl from the Debian package pool.
# NOTE(review): the .deb is fetched over plain HTTP and installed without a
# checksum check - consider verifying a pinned sha256 before installing.
RUN apt-get update && apt-get install -y \
        ffmpeg \
        gfortran \
        git \
        libblas-dev \
        liblapack-dev \
        libsm6 \
        python3-pip \
        wget \
    && wget http://ftp.de.debian.org/debian/pool/main/y/youtube-dl/youtube-dl_2021.12.17-1_all.deb \
    && apt-get install -y ./youtube-dl_2021.12.17-1_all.deb \
    && rm -f youtube-dl_2021.12.17-1_all.deb \
    && rm -rf /var/lib/apt/lists/*

# Install python packages (requirements are copied on their own so this layer
# is only rebuilt when the requirements file changes)
RUN pip3 install --no-cache-dir --upgrade pip
COPY services/banana/banana_requirements.txt requirements.txt
RUN pip3 install --no-cache-dir -r requirements.txt

# Clone model sources; only the latest commit is needed in the image
RUN git clone --depth 1 https://github.com/samhita-alla/GeoEstimation.git

# Plain local files: COPY is sufficient, ADD's extra behaviours are not needed
COPY services/banana/server.py \
     services/banana/app.py \
     app/post_processing.py \
     app/pre_processing.py \
     app/capture_video_frames.py \
     /GeoEstimation/

WORKDIR /GeoEstimation

# S2 cell partitionings fetched into resources/s2_cells
RUN mkdir -p resources/s2_cells && \
    wget -nc https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_5000.csv -O resources/s2_cells/cells_50_5000.csv && \
    wget -nc https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_2000.csv -O resources/s2_cells/cells_50_2000.csv && \
    wget -nc https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_1000.csv -O resources/s2_cells/cells_50_1000.csv

# Download model
RUN wget https://huggingface.co/Samhita/geolocator/resolve/main/geolocator.onnx

EXPOSE 8000

# Exec form: python3 is started directly instead of through /bin/sh -c, so it
# receives the stop signal itself.
CMD ["python3", "-u", "server.py"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Model Development and Inference with Flyte and Banana |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
from transformers import AutoModelForSequenceClassification | ||
|
||
|
||
# Init is run on server startup.
# Loads the model and its tokenizer once and keeps them as module-level
# globals ("model", "tokenizer", "device") for reuse by inference().
def init():
    """Load the sequence-classification model and tokenizer into globals.

    The model directory is read from the ``MODEL_DIR`` environment variable
    (the original code referenced an undefined ``model_dir`` name and failed
    with ``NameError``).

    Raises:
        RuntimeError: if ``MODEL_DIR`` is not set or is empty.
    """
    global model, tokenizer, device

    import os

    model_dir = os.environ.get("MODEL_DIR")
    if not model_dir:
        raise RuntimeError(
            "MODEL_DIR environment variable must point to the model directory"
        )

    # Imported here because the file-level import only brings in the model class.
    from transformers import AutoTokenizer

    # Pick the first GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # NOTE(review): assumes the tokenizer files live next to the model weights
    # in MODEL_DIR - confirm against how the model is exported.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir, num_labels=5
    )
    model.to(device)
    model.eval()  # inference only: disables dropout and similar training behaviour


# Inference is run for every server call.
# Uses the globals prepared by init().
def inference(model_inputs: dict) -> dict:
    """Classify the ``prompt`` contained in ``model_inputs``.

    Args:
        model_inputs: request payload; the text to classify is read from its
            ``"prompt"`` key.

    Returns:
        ``{"message": "No prompt provided"}`` when the payload is not a dict or
        has no prompt; otherwise a JSON-serializable dict with the predicted
        class index (``"label"``) and the per-class probabilities (``"scores"``).
    """
    # Parse out the arguments; tolerate a missing or non-dict payload.
    prompt = model_inputs.get("prompt", None) if isinstance(model_inputs, dict) else None
    if prompt is None:
        return {"message": "No prompt provided"}

    # The model consumes token tensors, not raw text.
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)

    # Run the model without tracking gradients.
    with torch.no_grad():
        logits = model(**encoded).logits

    scores = torch.softmax(logits, dim=-1)[0]

    # Plain Python types only, so the HTTP layer can serialize the result.
    return {"label": int(scores.argmax()), "scores": scores.tolist()}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
sanic |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Do not edit if deploying to Banana Serverless | ||
# This file is boilerplate for the http server, and follows a strict interface. | ||
|
||
# Instead, edit the init() and inference() functions in app.py | ||
|
||
from sanic import Sanic, response | ||
import subprocess | ||
import app as user_src | ||
|
||
# Run the user-supplied init() once at import time so that whatever it loads
# stays available globally and is reused by every request.
user_src.init()

# HTTP application object that the route handlers are registered on
server = Sanic("flyte_banana_app")
|
||
|
||
# Healthchecks verify that the environment is correct on Banana Serverless
@server.route("/healthcheck", methods=["GET"])
def healthcheck(request):
    """Report service health and whether a GPU is visible to the container."""
    # Dependency-free GPU probe: nvidia-smi exits with status 0 on success.
    probe = subprocess.run("nvidia-smi", shell=True)
    gpu_visible = probe.returncode == 0

    return response.json({"state": "healthy", "gpu": gpu_visible})
|
||
|
||
# Inference POST handler at '/' is called for every http call from Banana
@server.route("/", methods=["POST"])
def inference(request):
    """Decode the request payload, run the user inference and return JSON.

    The body is normally a JSON object. If the parsed body is itself a string
    (a JSON document that was encoded twice), it is decoded a second time.
    The original code attempted this via ``response.json.loads``, which does
    not exist; the resulting error was hidden by a bare ``except``, so the
    second decode never happened.
    """
    import json  # local import: only needed for the double-encoded case

    model_inputs = request.json
    if isinstance(model_inputs, str):
        try:
            model_inputs = json.loads(model_inputs)
        except ValueError:
            # Not JSON after all: pass the string through unchanged, which is
            # what the original fallback did.
            pass

    output = user_src.inference(model_inputs)

    return response.json(output)
|
||
|
||
# Start the HTTP server only when this file is executed as a script:
# listen on every interface, port 8000, with a single worker.
if __name__ == "__main__":
    server.run(host="0.0.0.0", port=8000, workers=1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Flyte task image: PyTorch/CUDA runtime base with the workflow code and its
# Python dependencies installed into a virtual environment.
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime

WORKDIR /root
ENV VENV=/opt/venv \
    LANG=C.UTF-8 \
    LC_ALL=C.UTF-8 \
    PYTHONPATH=/root

# Compiler toolchain; the apt package lists are dropped in the same layer so
# they do not add to the image size.
RUN apt-get update && apt-get install -y --no-install-recommends build-essential \
    && rm -rf /var/lib/apt/lists/*

# Runs before the virtual environment is put on PATH, so awscli is installed
# with the base image's pip3 rather than into ${VENV}.
RUN pip3 install --no-cache-dir awscli

# Virtual environment
RUN python3 -m venv ${VENV}
ENV PATH="${VENV}/bin:$PATH"

# Install Python dependencies (copied separately from the workflow code so
# this layer is reused until requirements.txt changes)
COPY ./requirements.txt /root/requirements.txt
RUN pip install --no-cache-dir -r /root/requirements.txt

COPY workflows /root/workflows

# This tag is supplied by the build script and will be used to determine the version
# when registering tasks, workflows, and launch plans
ARG tag
ENV FLYTE_INTERNAL_IMAGE=$tag
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
datasets | ||
transformers | ||
evaluate | ||
torch | ||
scikit-learn | ||
flytekit |
Oops, something went wrong.