Commit

first commit
Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
samhita-alla committed Feb 20, 2023
0 parents commit 0cef693
Showing 12 changed files with 1,125 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
.python-version
__pycache__/
.DS_Store
.vscode/
*.tgz
43 changes: 43 additions & 0 deletions Dockerfile
@@ -0,0 +1,43 @@
#####################
# BANANA DOCKERFILE #
#####################

# Must use CUDA version 11+
FROM nvcr.io/nvidia/cuda:11.4.2-cudnn8-devel-ubuntu20.04
ENV DEBIAN_FRONTEND=noninteractive
ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
RUN apt-get update && apt-get install -y --no-install-recommends \
        python3-dev ca-certificates g++ python3-numpy gcc make git \
        python3-setuptools python3-wheel python3-pip aria2 && \
    aria2c -q -d /tmp -o cmake-3.21.0-linux-x86_64.tar.gz https://github.com/Kitware/CMake/releases/download/v3.21.0/cmake-3.21.0-linux-x86_64.tar.gz && \
    tar -zxf /tmp/cmake-3.21.0-linux-x86_64.tar.gz --strip=1 -C /usr
WORKDIR /

# Install git, wget, ffmpeg, BLAS/LAPACK libraries, and youtube-dl
RUN apt-get update && apt-get install -y git wget gfortran libsm6 libblas-dev liblapack-dev ffmpeg python3-pip && \
    wget http://ftp.de.debian.org/debian/pool/main/y/youtube-dl/youtube-dl_2021.12.17-1_all.deb && \
    apt-get install -y ./youtube-dl_2021.12.17-1_all.deb

# Install python packages
RUN pip3 install --upgrade pip
ADD services/banana/banana_requirements.txt requirements.txt
RUN pip3 install -r requirements.txt

# Clone the GeoEstimation repository
RUN git clone https://github.com/samhita-alla/GeoEstimation.git

ADD services/banana/server.py GeoEstimation/
ADD services/banana/app.py GeoEstimation/
ADD app/post_processing.py GeoEstimation/
ADD app/pre_processing.py GeoEstimation/
ADD app/capture_video_frames.py GeoEstimation/

WORKDIR /GeoEstimation

# Download the S2 cell partitioning files used by GeoEstimation
RUN mkdir -p resources/s2_cells && \
    wget -nc https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_5000.csv -O resources/s2_cells/cells_50_5000.csv && \
    wget -nc https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_2000.csv -O resources/s2_cells/cells_50_2000.csv && \
    wget -nc https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_1000.csv -O resources/s2_cells/cells_50_1000.csv

# Download model
RUN wget https://huggingface.co/Samhita/geolocator/resolve/main/geolocator.onnx

EXPOSE 8000

CMD python3 -u server.py
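Note: the image above downloads geolocator.onnx into the working directory. A minimal sketch of loading that model with onnxruntime follows; it assumes onnxruntime is available in the container (it is not listed in the files shown here), and the (1, 3, 224, 224) dummy input shape is only an assumption, not taken from the real model:

import numpy as np
import onnxruntime as ort

# Load the ONNX model downloaded in the Dockerfile; fall back to CPU if no GPU is visible.
session = ort.InferenceSession(
    "geolocator.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

# Query the model's actual input name; the dummy shape below is an assumption.
input_name = session.get_inputs()[0].name
dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: dummy})
print([o.shape for o in outputs])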
1 change: 1 addition & 0 deletions README.md
@@ -0,0 +1 @@
# Model Development and Inference with Flyte and Banana
31 changes: 31 additions & 0 deletions banana/app.py
@@ -0,0 +1,31 @@
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder path to the fine-tuned model weights; point this at the real model directory.
model_dir = "model"


# init() is run on server startup.
# Load your model to GPU as a global variable here using the variable name "model".
def init():
    global model
    global tokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir, num_labels=5
    ).to(device)


# inference() is run for every server call.
# Reference your preloaded global model variable here.
def inference(model_inputs: dict) -> dict:
    global model, tokenizer

    # Parse out your arguments
    prompt = model_inputs.get("prompt", None)
    if prompt is None:
        return {"message": "No prompt provided"}

    # Tokenize the prompt and run the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Return the results as a dictionary
    return {"logits": logits.tolist()}
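A quick local smoke test of this module could look like the following, assuming the model weights that init() expects are available at the configured path; the prompt value is illustrative only:

import app

# Load the model once, then call inference() the same way server.py does.
app.init()
print(app.inference({"prompt": "example input text"}))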
1 change: 1 addition & 0 deletions banana/requirements.in
@@ -0,0 +1 @@
sanic
44 changes: 44 additions & 0 deletions banana/server.py
@@ -0,0 +1,44 @@
# Do not edit if deploying to Banana Serverless.
# This file is boilerplate for the HTTP server and follows a strict interface.

# Instead, edit the init() and inference() functions in app.py.

import json
import subprocess

from sanic import Sanic, response

import app as user_src

# We do the model load-to-GPU step on server startup
# so the model object is available globally for reuse.
user_src.init()

# Create the HTTP server app
server = Sanic("flyte_banana_app")


# Healthchecks verify that the environment is correct on Banana Serverless
@server.route("/healthcheck", methods=["GET"])
def healthcheck(request):
    # Dependency-free way to check whether a GPU is visible
    gpu = False
    out = subprocess.run("nvidia-smi", shell=True)
    if out.returncode == 0:  # success state on shell command
        gpu = True

    return response.json({"state": "healthy", "gpu": gpu})


# The inference POST handler at '/' is called for every HTTP call from Banana
@server.route("/", methods=["POST"])
def inference(request):
    # request.json may already be a dict; fall back to it if string parsing fails
    try:
        model_inputs = json.loads(request.json)
    except Exception:
        model_inputs = request.json

    output = user_src.inference(model_inputs)

    return response.json(output)


if __name__ == "__main__":
    server.run(host="0.0.0.0", port=8000, workers=1)
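Once the container is running, a client could exercise the two routes roughly as follows; the localhost URL is an assumption and should be replaced with the deployed Banana endpoint:

import requests

BASE_URL = "http://localhost:8000"  # assumed local address; use the deployed URL in production

# Health check: confirms the server is up and whether a GPU is visible.
print(requests.get(f"{BASE_URL}/healthcheck", timeout=10).json())

# Inference: POST a JSON payload matching what app.inference() expects.
resp = requests.post(BASE_URL, json={"prompt": "example input text"}, timeout=60)
print(resp.json())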
27 changes: 27 additions & 0 deletions flyte/Dockerfile
@@ -0,0 +1,27 @@
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime

WORKDIR /root
ENV VENV /opt/venv
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8
ENV PYTHONPATH /root

RUN apt-get update && apt-get install -y build-essential
RUN pip3 install awscli

# Virtual environment
RUN python3 -m venv ${VENV}
ENV PATH="${VENV}/bin:$PATH"

# Install Python dependencies
COPY ./requirements.txt /root
RUN pip install -r /root/requirements.txt

COPY workflows /root/workflows

# This tag is supplied by the build script and will be used to determine the version
# when registering tasks, workflows, and launch plans
ARG tag
ENV FLYTE_INTERNAL_IMAGE $tag
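The workflows package copied into this image is not shown in the excerpt above; a hypothetical minimal Flyte task and workflow of the kind it could contain (all names below are placeholders) might look like:

from flytekit import task, workflow


@task
def train_model(num_labels: int) -> str:
    # Training logic (using datasets/transformers/evaluate from requirements.in) would go here.
    return f"trained a {num_labels}-label classifier"


@workflow
def training_wf(num_labels: int = 5) -> str:
    return train_model(num_labels=num_labels)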
6 changes: 6 additions & 0 deletions flyte/requirements.in
@@ -0,0 +1,6 @@
datasets
transformers
evaluate
torch
scikit-learn
flytekit