From 87c340efcf309a2b6f5087897e3aafcb59ba4573 Mon Sep 17 00:00:00 2001 From: Daniel Bank Date: Fri, 13 Sep 2019 17:32:52 -0700 Subject: [PATCH] Image Classification API --- .gitignore | 2 ++ Cargo.toml | 7 ++--- README.md | 5 ++++ src/main.rs | 73 ++++++++++++++++++++++++++++++++++++++--------------- 4 files changed, 63 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index 154e80c..4c6331f 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ Cargo.lock /target #**/*.rs.bk + +.DS_Store \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index a586680..28dbeab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ gotham = "0.4.0" hyper = "0.12" image = "0.22.2" mime = "0.3" -tract-core = "0.4.0" -tract-tensorflow = "0.4.0" -url = "2.1.0" \ No newline at end of file +tract-core = "0.4.2" +tract-tensorflow = "0.4.2" +url = "2.1.0" +regex = "1.3.1" \ No newline at end of file diff --git a/README.md b/README.md index df5dd8c..ed855f4 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,7 @@ # offline-ml + Offline ML with Rust + +``` +curl -i -X POST -F "image=@/Users/DanielBank/Desktop/grace_hopper.jpg" http://127.0.0.1:7878/ +``` diff --git a/src/main.rs b/src/main.rs index aecfb02..82bd4f3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,3 @@ -//! An example of decoding requests from an HTML form element - extern crate futures; extern crate gotham; extern crate hyper; @@ -8,36 +6,70 @@ extern crate url; use futures::{future, Future, Stream}; use hyper::{Body, StatusCode}; -use url::form_urlencoded; use gotham::handler::{HandlerFuture, IntoHandlerError}; use gotham::helpers::http::response::create_response; use gotham::router::builder::{build_simple_router, DefineSingleRoute, DrawRoutes}; use gotham::router::Router; use gotham::state::{FromState, State}; +use regex::bytes::Regex; -const HELLO_WORLD: &str = "Hello World!"; - -fn say_hello(state: State) -> (State, &'static str) { - (state, HELLO_WORLD) -} +use tract_core::ndarray; +use tract_core::prelude::*; -/// Extracts the elements of the POST request and responds with the form keys and values -fn form_handler(mut state: State) -> Box { +/// Extracts the image from a POST request and responds with a prediction tuple (probability, class) +fn prediction_handler(mut state: State) -> Box { let f = Body::take_from(&mut state) .concat2() .then(|full_body| match full_body { Ok(valid_body) => { + // load the model + let mut model = tract_tensorflow::tensorflow() + .model_for_path("mobilenet_v2_1.4_224_frozen.pb") + .unwrap(); + + // specify input type and shape + model + .set_input_fact( + 0, + TensorFact::dt_shape(f32::datum_type(), tvec!(1, 224, 224, 3)), + ) + .unwrap(); + + // optimize the model and get an execution plan + let model = model.into_optimized().unwrap(); + let plan = SimplePlan::new(&model).unwrap(); + + // extract the image from the body as input let body_content = valid_body.into_bytes(); - // Perform decoding on request body - let form_data = form_urlencoded::parse(&body_content).into_owned(); - // Add form keys and values to response body - let mut res_body = String::new(); - for (key, value) in form_data { - let res_body_line = format!("{}: {}\n", key, value); - res_body.push_str(&res_body_line); - } - let res = create_response(&state, StatusCode::OK, mime::TEXT_PLAIN, res_body); + let re = Regex::new(r"\r\n\r\n").unwrap(); + let contents: Vec<_> = re.split(body_content.as_ref()).collect(); + let image = image::load_from_memory(contents[1]).unwrap().to_rgb(); + let resized = image::imageops::resize(&image, 224, 224, ::image::FilterType::Triangle); + let image: Tensor = ndarray::Array4::from_shape_fn((1, 224, 224, 3), |(_, y, x, c)| { + resized[(x as _, y as _)][c] as f32 / 255.0 + }) + .into(); + + // run the plan on the input + let result = plan.run(tvec!(image)).unwrap(); + + // find and display the max value with its index + let best = result[0] + .to_array_view::() + .unwrap() + .iter() + .cloned() + .zip(1..) + .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + + // respond with the prediction tuple + let res = create_response( + &state, + StatusCode::OK, + mime::TEXT_PLAIN, + format!("{:?}", best.unwrap()), + ); future::ok((state, res)) } Err(e) => future::err((state, e.into_handler_error())), @@ -49,8 +81,7 @@ fn form_handler(mut state: State) -> Box { /// Create a `Router` fn router() -> Router { build_simple_router(|route| { - route.get("/").to(say_hello); - route.post("/").to(form_handler); + route.post("/").to(prediction_handler); }) }