Skip to content

Commit 2c4a689

Browse files
committed
initial import
1 parent f022e27 commit 2c4a689

13 files changed

+227
-0
lines changed

Procfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
web: gunicorn app:app

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# PyTorch Flask API
2+
3+
4+
## Requirements
5+
6+
Install them from `requirements.txt`:
7+
8+
pip install -r requirements.txt
9+
10+
11+
## Local Deployment
12+
13+
Run the server:
14+
15+
python app.py
16+
17+
18+
## Heroku Deployment
19+
20+
[![Deploy](https://www.herokucdn.com/deploy/button.svg)](https://heroku.com/deploy?template=https://github.com/avinassh/pytorch-flask-api-heroku)

app.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"name": "PyTorch Image Detection",
3+
"description": "Image Classifier built using PyTorch",
4+
"repository": "https://github.com/avinassh/pytorch-flask-api-heroku",
5+
"keywords": ["python", "flask", "pytorch", "bootstrap"]
6+
}

app.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import os
2+
3+
from flask import Flask, render_template, request, redirect
4+
5+
from inference import get_prediction
6+
from commons import format_class_name
7+
8+
app = Flask(__name__)
9+
10+
11+
@app.route('/', methods=['GET', 'POST'])
12+
def upload_file():
13+
if request.method == 'POST':
14+
if 'file' not in request.files:
15+
return redirect(request.url)
16+
file = request.files.get('file')
17+
if not file:
18+
return
19+
img_bytes = file.read()
20+
class_id, class_name = get_prediction(image_bytes=img_bytes)
21+
class_name = format_class_name(class_name)
22+
return render_template('result.html', class_id=class_id,
23+
class_name=class_name)
24+
return render_template('index.html')
25+
26+
27+
if __name__ == '__main__':
28+
app.run(debug=True, port=int(os.environ.get('PORT', 5000)))

commons.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import io
2+
3+
4+
from PIL import Image
5+
from torchvision import models
6+
import torchvision.transforms as transforms
7+
8+
9+
def get_model():
10+
model = models.densenet121(pretrained=True)
11+
model.eval()
12+
return model
13+
14+
15+
def transform_image(image_bytes):
16+
my_transforms = transforms.Compose([transforms.Resize(255),
17+
transforms.CenterCrop(224),
18+
transforms.ToTensor(),
19+
transforms.Normalize(
20+
[0.485, 0.456, 0.406],
21+
[0.229, 0.224, 0.225])])
22+
image = Image.open(io.BytesIO(image_bytes))
23+
return my_transforms(image).unsqueeze(0)
24+
25+
26+
# ImageNet classes are often of the form `can_opener` or `Egyptian_cat`
27+
# will use this method to properly format it so that we get
28+
# `Can Opener` or `Egyptian Cat`
29+
def format_class_name(class_name):
30+
class_name = class_name.replace('_', ' ')
31+
class_name = class_name.title()
32+
return class_name

imagenet_class_index.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

inference.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import json
2+
3+
from commons import get_model, transform_image
4+
5+
model = get_model()
6+
imagenet_class_index = json.load(open('imagenet_class_index.json'))
7+
8+
9+
def get_prediction(image_bytes):
10+
try:
11+
tensor = transform_image(image_bytes=image_bytes)
12+
outputs = model.forward(tensor)
13+
except Exception:
14+
return 0, 'error'
15+
_, y_hat = outputs.max(1)
16+
predicted_idx = str(y_hat.item())
17+
return imagenet_class_index[predicted_idx]

requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Flask==1.0.3
2+
torchvision==0.3.0
3+
numpy==1.16.4
4+
Pillow==6.0.0
5+
gunicorn==19.9.0

runtime.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
python-3.7.1

static/pytorch.png

11.1 KB
Loading

static/style.css

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
html,
2+
body {
3+
height: 100%;
4+
}
5+
6+
body {
7+
display: -ms-flexbox;
8+
display: flex;
9+
-ms-flex-align: center;
10+
align-items: center;
11+
padding-top: 40px;
12+
padding-bottom: 40px;
13+
background-color: #f5f5f5;
14+
}
15+
16+
.form-signin {
17+
width: 100%;
18+
max-width: 330px;
19+
padding: 15px;
20+
margin: auto;
21+
}
22+
23+
.form-signin .form-control {
24+
position: relative;
25+
box-sizing: border-box;
26+
height: auto;
27+
padding: 10px;
28+
font-size: 16px;
29+
}

templates/index.html

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
<!doctype html>
2+
<html lang="en">
3+
<head>
4+
<meta charset="utf-8">
5+
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
6+
<link rel="stylesheet" href="//stackpath.bootstrapcdn.com/bootstrap/4.2.1/css/bootstrap.min.css" integrity="sha384-GJzZqFGwb1QTTN6wy59ffF1BuGJpLSa9DkKMp0DgiMDm4iYMj70gZWKYbI706tWS" crossorigin="anonymous">
7+
<style>
8+
.bd-placeholder-img {
9+
font-size: 1.125rem;
10+
text-anchor: middle;
11+
}
12+
13+
@media (min-width: 768px) {
14+
.bd-placeholder-img-lg {
15+
font-size: 3.5rem;
16+
}
17+
}
18+
</style>
19+
<link rel="stylesheet" href="/static/style.css">
20+
21+
<title>Image Prediction using PyTorch</title>
22+
</head>
23+
<body class="text-center">
24+
<form class="form-signin" method=post enctype=multipart/form-data>
25+
<img class="mb-4" src="/static/pytorch.png" alt="" width="72">
26+
<h1 class="h3 mb-3 font-weight-normal">Upload any image</h1>
27+
<input type="file" name="file" class="form-control-file" id="inputfile">
28+
<br/>
29+
<button class="btn btn-lg btn-primary btn-block" type="submit">Upload</button>
30+
<p class="mt-5 mb-3 text-muted">Built using Pytorch, Flask and Love</p>
31+
</form>
32+
<script src="//code.jquery.com/jquery-3.3.1.slim.min.js" integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo" crossorigin="anonymous"></script>
33+
<script src="//cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.6/umd/popper.min.js" integrity="sha384-wHAiFfRlMFy6i5SRaxvfOCifBUQy1xHdJ/yoi7FRNXMRBu5WHdZYu1hA6ZOblgut" crossorigin="anonymous"></script>
34+
<script src="//stackpath.bootstrapcdn.com/bootstrap/4.2.1/js/bootstrap.min.js" integrity="sha384-B0UglyR+jN6CkvvICOB2joaf5I4l3gm9GU6Hc1og6Ls7i6U/mkkaduKaBhlAXv9k" crossorigin="anonymous"></script>
35+
<script type="text/javascript">
36+
$('#inputfile').bind('change', function() {
37+
let fileSize = this.files[0].size/1024/1024; // this gives in MB
38+
if (fileSize > 1) {
39+
$("#inputfile").val(null);
40+
alert('file is too big. images more than 1MB are not allowed')
41+
return
42+
}
43+
44+
let ext = $('#inputfile').val().split('.').pop().toLowerCase();
45+
if($.inArray(ext, ['jpg','jpeg']) == -1) {
46+
$("#inputfile").val(null);
47+
alert('only jpeg/jpg files are allowed!');
48+
}
49+
});
50+
</script>
51+
</body>
52+
</html>

templates/result.html

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
<!doctype html>
2+
<html lang="en">
3+
<head>
4+
<meta charset="utf-8">
5+
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
6+
<link rel="stylesheet" href="//stackpath.bootstrapcdn.com/bootstrap/4.2.1/css/bootstrap.min.css" integrity="sha384-GJzZqFGwb1QTTN6wy59ffF1BuGJpLSa9DkKMp0DgiMDm4iYMj70gZWKYbI706tWS" crossorigin="anonymous">
7+
<style>
8+
.bd-placeholder-img {
9+
font-size: 1.125rem;
10+
text-anchor: middle;
11+
}
12+
13+
@media (min-width: 768px) {
14+
.bd-placeholder-img-lg {
15+
font-size: 3.5rem;
16+
}
17+
}
18+
</style>
19+
<link rel="stylesheet" href="/static/style.css">
20+
21+
<title>Image Prediction using PyTorch</title>
22+
</head>
23+
<body class="text-center">
24+
<form class="form-signin" method=post enctype=multipart/form-data>
25+
<img class="mb-4" src="/static/pytorch.png" alt="" width="72">
26+
<h1 class="h3 mb-3 font-weight-normal">Prediction</h1>
27+
<h5 class="h5 mb-3 font-weight-normal">Detected Image: {{ class_name }}</h5>
28+
<h5 class="h6 mb-3 font-weight-normal">ImageNet Class ID: {{ class_id }}</h5>
29+
<p class="mt-5 mb-3 text-muted">Built using Pytorch, Flask and Love</p>
30+
</form>
31+
<script src="//code.jquery.com/jquery-3.3.1.slim.min.js" integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo" crossorigin="anonymous"></script>
32+
<script src="//cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.6/umd/popper.min.js" integrity="sha384-wHAiFfRlMFy6i5SRaxvfOCifBUQy1xHdJ/yoi7FRNXMRBu5WHdZYu1hA6ZOblgut" crossorigin="anonymous"></script>
33+
<script src="//stackpath.bootstrapcdn.com/bootstrap/4.2.1/js/bootstrap.min.js" integrity="sha384-B0UglyR+jN6CkvvICOB2joaf5I4l3gm9GU6Hc1og6Ls7i6U/mkkaduKaBhlAXv9k" crossorigin="anonymous"></script>
34+
</body>
35+
</html>

0 commit comments

Comments
 (0)