forked from apache/mxnet
[R] Add example of image classification
Showing 3 changed files with 140 additions and 2 deletions.
@@ -1,11 +1,20 @@
MXNet R-Package
===============
This is an on-going effort to support mxnet in R, stay tuned.
The MXNet R package brings flexible and efficient GPU computing and deep learning to R.

Bleeding edge Installation
- It enables you to write seamless tensor/matrix computation with multiple GPUs in R.
- It also enables you to construct and customize state-of-the-art deep learning models in R,
  and apply them to tasks such as image classification and data science challenges.

Installation
------------
- First build ```../lib/libmxnet.so``` by following [Build Instruction](../doc/build.md)
- Type ```R CMD INSTALL R-package``` in the root folder.

Examples
--------
- [Classify Real World Image with MXNet R Package](vignettes/classifyRealImageWithPretrainedModel.Rmd)

Contributor Guide for R
-----------------------
Check out the [Contributor Guideline](../doc/contribute.md#r-package)
R-package/vignettes/classifyRealImageWithPretrainedModel.Rmd (128 additions, 0 deletions)
@@ -0,0 +1,128 @@
Classify Real-World Images with Pre-trained Model
=================================================
MXNet is a flexible and efficient deep learning framework. One of the cool things that a deep learning
algorithm can do is classify real-world images.

In this example we will show how to use a pre-trained Inception-BatchNorm network to predict the class of
a real-world image. The network architecture is described in [1].

The pre-trained Inception-BatchNorm network can be downloaded from [this link](http://webdocs.cs.ualberta.ca/~bx3/data/Inception.zip).
This model gives recent state-of-the-art prediction accuracy on the ImageNet dataset.

Package Loading
---------------
To get started, we load the mxnet package with ```require```.

```{r}
require(mxnet)
```

In this example, we also need the imager package to load and preprocess the images in R.

```{r}
require(imager)
```

Load the Pretrained Model
-------------------------
Make sure you unzip the pre-trained model into the current folder. We can then use the model
loading function to load the model into R.

```{r}
model = mx.model.load("Inception/Inception_BN", iteration=39)
```

We also need to load the mean image, which is used for preprocessing, using ```mx.nd.load```.

```{r}
mean.img = as.array(mx.nd.load("Inception/mean_224.nd")[["mean_img"]])
```

Load and Preprocess the Image
-----------------------------
Now we are ready to classify a real image. In this example, we simply take the parrots image
from the imager package, but you can always change it to another image.

Load and plot the image:

```{r}
im <- load.image(system.file("extdata/parrots.png", package="imager"))
plot(im)
```

Before feeding the image to the deep net, we need to do some preprocessing
to make the image fit the input requirements of the network. The preprocessing
includes cropping, resizing, and subtraction of the mean.
Because mxnet is deeply integrated with R, we can do all the processing in an R function.

The preprocessing function:

```{r}
preproc.image <- function(im, mean.image) {
  # crop the center square; imager stores images as (x, y, z, channel)
  shape <- dim(im)
  short.edge <- min(shape[1:2])
  xx <- floor((shape[1] - short.edge) / 2) + 1
  xend <- xx + short.edge - 1
  yy <- floor((shape[2] - short.edge) / 2) + 1
  yend <- yy + short.edge - 1
  cropped <- im[xx:xend, yy:yend,,]
  # resize to 224 x 224, as required by the model input
  resized <- resize(cropped, 224, 224)
  # convert to array (x, y, channel)
  arr <- as.array(resized)
  dim(arr) <- c(224, 224, 3)
  # change to the format of mxnet (channel, height, width)
  sample <- aperm(arr, c(3, 2, 1))
  # subtract the mean, using the mean.image argument rather than a global
  normed <- sample - mean.image
  # reshape to the format needed by mxnet
  dim(normed) <- c(1, 3, 224, 224)
  return(normed)
}
```
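
The crop and axis-reordering steps can be checked on a tiny array without mxnet or imager. The sketch below is illustrative only: `toy` and `center.crop.range` are hypothetical names that do not appear in the vignette, and the 6 x 4 array merely stands in for a real image.

```r
# Toy stand-in for an image: 6 x 4 pixels, 3 channels (hypothetical data, not imager output).
toy <- array(seq_len(6 * 4 * 3), dim = c(6, 4, 3))

# Same index arithmetic as the preprocessing function:
# a centered window whose length equals the shorter edge.
center.crop.range <- function(len, side) {
  start <- floor((len - side) / 2) + 1
  c(start, start + side - 1)
}

side <- min(dim(toy)[1:2])
r1 <- center.crop.range(dim(toy)[1], side)   # indices 2..5 along the longer axis
r2 <- center.crop.range(dim(toy)[2], side)   # indices 1..4 along the shorter axis
cropped <- toy[r1[1]:r1[2], r2[1]:r2[2], , drop = FALSE]

# Reverse the axis order, as aperm(arr, c(3, 2, 1)) does in the vignette.
permuted <- aperm(cropped, c(3, 2, 1))

print(dim(cropped))
print(dim(permuted))
```

After the permutation, element `[channel, j, i]` of `permuted` holds what was at `[i, j, channel]` in `cropped`, which is how the (x, y, channel) array becomes (channel, height, width).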

We use the preprocessing function defined above to get the normalized image.

```{r}
normed <- preproc.image(im, mean.img)
```

Classify the Image
------------------
Now we are ready to classify the image! We can use the predict function
to get the probability over classes.

```{r}
prob <- predict(model, X=normed)
dim(prob)
```

As you can see, ```prob``` is a 1 by 1000 array, which gives the probability
over the 1000 image classes of the input.

We can use ```max.col``` to get the class index.

```{r}
max.idx <- max.col(prob)
max.idx
```
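
It is often useful to look at several likely classes rather than only the best one. The following is a minimal base R sketch, assuming a one-row probability matrix shaped like ```prob```; the six values are made up for illustration and are not model output.

```r
# Hypothetical 1 x 6 probability matrix standing in for the 1 x 1000 model output.
prob <- matrix(c(0.05, 0.10, 0.50, 0.20, 0.03, 0.12), nrow = 1)

# Index of the single most likely class.
# Unlike max.col, whose default ties.method is "random", which.max returns the first maximum.
top1 <- which.max(prob[1, ])

# Indices of the k most likely classes, most likely first.
k <- 3
topk <- order(prob[1, ], decreasing = TRUE)[seq_len(k)]

print(top1)
print(topk)
print(prob[1, topk])
```

The same `order` call should work on the real 1 by 1000 output, with `k` set to however many candidates you want to inspect.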

The index by itself does not mean much, so let us see what it really corresponds to.
We can read the names of the classes from the following file.

```{r}
synsets <- readLines("Inception/synset.txt")
```

And let us see what it really is:

```{r}
print(paste0("Predicted Top-class: ", synsets[[max.idx]]))
```
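
If each line of the synset file has the form of a WordNet ID followed by a comma-separated list of names, the label can be shortened for display. That layout is an assumption here, since the vignette does not show the file's contents, and the two sample lines below are illustrative stand-ins rather than lines read from ```Inception/synset.txt```.

```r
# Hypothetical lines in the assumed "<WordNet ID> <comma-separated names>" layout.
synsets <- c("n01440764 tench, Tinca tinca",
             "n01443537 goldfish, Carassius auratus")

# Split each line at the first space only.
wnid  <- sub(" .*$", "", synsets)
label <- sub("^[^ ]+ ", "", synsets)

# Keep just the first name for a compact label.
short.label <- sub(",.*$", "", label)

print(data.frame(wnid, short.label))
```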

Actually, I did not know what the word meant when I first saw it.
So I searched the web to check it out, and it does get the right answer :)

Reference
---------
[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." arXiv preprint arXiv:1502.03167 (2015).