Skip to content

Commit

Permalink
Merge pull request #1 from danielbank/prediction-api
Browse files Browse the repository at this point in the history
Image Classification API
  • Loading branch information
danielbank authored Sep 14, 2019
2 parents da02ec0 + 87c340e commit ef6013d
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ Cargo.lock

/target
#**/*.rs.bk

.DS_Store
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ gotham = "0.4.0"
hyper = "0.12"
image = "0.22.2"
mime = "0.3"
tract-core = "0.4.0"
tract-tensorflow = "0.4.0"
url = "2.1.0"
tract-core = "0.4.2"
tract-tensorflow = "0.4.2"
url = "2.1.0"
regex = "1.3.1"
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
# offline-ml

Offline ML with Rust

```
curl -i -X POST -F "image=@/Users/DanielBank/Desktop/grace_hopper.jpg" http://127.0.0.1:7878/
```
73 changes: 52 additions & 21 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
//! An example of decoding requests from an HTML form element
extern crate futures;
extern crate gotham;
extern crate hyper;
Expand All @@ -8,36 +6,70 @@ extern crate url;

use futures::{future, Future, Stream};
use hyper::{Body, StatusCode};
use url::form_urlencoded;

use gotham::handler::{HandlerFuture, IntoHandlerError};
use gotham::helpers::http::response::create_response;
use gotham::router::builder::{build_simple_router, DefineSingleRoute, DrawRoutes};
use gotham::router::Router;
use gotham::state::{FromState, State};
use regex::bytes::Regex;

const HELLO_WORLD: &str = "Hello World!";

fn say_hello(state: State) -> (State, &'static str) {
(state, HELLO_WORLD)
}
use tract_core::ndarray;
use tract_core::prelude::*;

/// Extracts the elements of the POST request and responds with the form keys and values
fn form_handler(mut state: State) -> Box<HandlerFuture> {
/// Extracts the image from a POST request and responds with a prediction tuple (probability, class)
fn prediction_handler(mut state: State) -> Box<HandlerFuture> {
let f = Body::take_from(&mut state)
.concat2()
.then(|full_body| match full_body {
Ok(valid_body) => {
// load the model
let mut model = tract_tensorflow::tensorflow()
.model_for_path("mobilenet_v2_1.4_224_frozen.pb")
.unwrap();

// specify input type and shape
model
.set_input_fact(
0,
TensorFact::dt_shape(f32::datum_type(), tvec!(1, 224, 224, 3)),
)
.unwrap();

// optimize the model and get an execution plan
let model = model.into_optimized().unwrap();
let plan = SimplePlan::new(&model).unwrap();

// extract the image from the body as input
let body_content = valid_body.into_bytes();
// Perform decoding on request body
let form_data = form_urlencoded::parse(&body_content).into_owned();
// Add form keys and values to response body
let mut res_body = String::new();
for (key, value) in form_data {
let res_body_line = format!("{}: {}\n", key, value);
res_body.push_str(&res_body_line);
}
let res = create_response(&state, StatusCode::OK, mime::TEXT_PLAIN, res_body);
let re = Regex::new(r"\r\n\r\n").unwrap();
let contents: Vec<_> = re.split(body_content.as_ref()).collect();
let image = image::load_from_memory(contents[1]).unwrap().to_rgb();
let resized = image::imageops::resize(&image, 224, 224, ::image::FilterType::Triangle);
let image: Tensor = ndarray::Array4::from_shape_fn((1, 224, 224, 3), |(_, y, x, c)| {
resized[(x as _, y as _)][c] as f32 / 255.0
})
.into();

// run the plan on the input
let result = plan.run(tvec!(image)).unwrap();

// find and display the max value with its index
let best = result[0]
.to_array_view::<f32>()
.unwrap()
.iter()
.cloned()
.zip(1..)
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

// respond with the prediction tuple
let res = create_response(
&state,
StatusCode::OK,
mime::TEXT_PLAIN,
format!("{:?}", best.unwrap()),
);
future::ok((state, res))
}
Err(e) => future::err((state, e.into_handler_error())),
Expand All @@ -49,8 +81,7 @@ fn form_handler(mut state: State) -> Box<HandlerFuture> {
/// Create a `Router`
fn router() -> Router {
build_simple_router(|route| {
route.get("/").to(say_hello);
route.post("/").to(form_handler);
route.post("/").to(prediction_handler);
})
}

Expand Down

0 comments on commit ef6013d

Please sign in to comment.