-
Notifications
You must be signed in to change notification settings - Fork 661
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'upstream/master' into master
- Loading branch information
Showing
13 changed files
with
416 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# DJL Easy | ||
|
||
## Overview | ||
|
||
This module is a no deep learning knowledge required wrapper over DJL. Instead of worrying about finding a model or how to train, this will provide a simple recommendation for your deep learning application. It is the easiest way to get started with DJL and get a solution for your deep learning problem. | ||
|
||
## List of Applications | ||
|
||
This module contains the following applications: | ||
|
||
- Image Classification - take an image and classify the main subject of the image. | ||
|
||
|
||
## Documentation | ||
|
||
The latest javadocs can be found on the [djl.ai website](https://javadoc.io/doc/ai.djl/easy/latest/index.html). | ||
|
||
You can also build the latest javadocs locally using the following command: | ||
|
||
```sh | ||
# for Linux/macOS: | ||
./gradlew javadoc | ||
|
||
# for Windows: | ||
..\gradlew javadoc | ||
``` | ||
The javadocs output is built in the build/doc/javadoc folder. | ||
|
||
|
||
## Installation | ||
You can pull the module from the central Maven repository by including the following dependency in your `pom.xml` file: | ||
|
||
```xml | ||
<dependency> | ||
<groupId>ai.djl</groupId> | ||
<artifactId>easy</artifactId> | ||
<version>0.9.0-SNAPSHOT</version> | ||
</dependency> | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
dependencies { | ||
api project(":api") | ||
api project(":basicdataset") | ||
testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" | ||
|
||
api project(":testing") | ||
api("org.testng:testng:${testng_version}") { | ||
exclude group: "junit", module: "junit" | ||
} | ||
|
||
// Current engines and model zoos used for inference | ||
runtimeOnly project(":model-zoo") | ||
// runtimeOnly project(":pytorch:pytorch-engine") | ||
// runtimeOnly project(":pytorch:pytorch-model-zoo") | ||
// runtimeOnly "ai.djl.pytorch:pytorch-native-auto:${pytorch_version}" | ||
// runtimeOnly project(":tensorflow:tensorflow-engine") | ||
// runtimeOnly "ai.djl.tensorflow:tensorflow-native-auto:${tensorflow_version}" | ||
// runtimeOnly project(":onnxruntime:onnxruntime-engine") | ||
// runtimeOnly "com.microsoft.onnxruntime:onnxruntime:${onnxruntime_version}" | ||
runtimeOnly project(":mxnet:mxnet-engine") | ||
runtimeOnly project(":mxnet:mxnet-model-zoo") | ||
runtimeOnly "ai.djl.mxnet:mxnet-native-auto:${mxnet_version}" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../gradlew |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
/* | ||
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.easy; | ||
|
||
/** | ||
* Describes the speed/accuracy tradeoff. | ||
* | ||
* <p>In deep learning, it is often possible to improve the accuracy of a model by using a larger | ||
* model. However, this then results in slower latency and worse throughput. So, there is a tradeoff | ||
* between the choices of speed and accuracy. | ||
*/ | ||
public enum Performance { | ||
/** Fast prioritizes speed over accuracy. */ | ||
FAST, | ||
|
||
/** Balanced has a more even tradeoff of speed and accuracy. */ | ||
BALANCED, | ||
|
||
/** Accurate prioritizes accuracy over speed. */ | ||
ACCURATE | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/* | ||
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.easy; | ||
|
||
import ai.djl.engine.Engine; | ||
import ai.djl.repository.zoo.ModelZoo; | ||
|
||
/** | ||
* A set of utilities for requiring a {@link ModelZoo}. | ||
* | ||
* <p>Throws an exception if the {@link ModelZoo} is not available. | ||
*/ | ||
public final class RequireZoo { | ||
|
||
private RequireZoo() {} | ||
|
||
/** Requires {@code ai.djl.basicmodelzoo.BasicModelZoo}. */ | ||
public static void basic() { | ||
if (!ModelZoo.hasModelZoo("ai.djl.zoo")) { | ||
throw new IllegalStateException( | ||
"The basic model zoo is required, but not found." | ||
+ "Please install it by following http://docs.djl.ai/model-zoo/index.html#installation"); | ||
} | ||
} | ||
|
||
/** Requires {@code ai.djl.mxnet.zoo.MxModelZoo}. */ | ||
public static void mxnet() { | ||
if (!ModelZoo.hasModelZoo("ai.djl.mxnet")) { | ||
throw new IllegalStateException( | ||
"The MXNet model zoo is required, but not found." | ||
+ "Please install it by following http://docs.djl.ai/mxnet/mxnet-model-zoo/index.html#installation"); | ||
} | ||
if (!Engine.hasEngine("MXNet")) { | ||
throw new IllegalStateException( | ||
"The MXNet engine is required, but not found." | ||
+ "Please install it by following http://docs.djl.ai/mxnet/mxnet-engine/index.html#installation"); | ||
} | ||
} | ||
|
||
/** Requires {@code ai.djl.pytorch.zoo.PtModelZoo}. */ | ||
public static void pytorch() { | ||
if (!ModelZoo.hasModelZoo("ai.djl.pytorch")) { | ||
throw new IllegalStateException( | ||
"The PyTorch model zoo is required, but not found." | ||
+ "Please install it by following http://docs.djl.ai/pytorch/pytorch-model-zoo/index.html#installation"); | ||
} | ||
if (!Engine.hasEngine("PyTorch")) { | ||
throw new IllegalStateException( | ||
"The PyTorch engine is required, but not found." | ||
+ "Please install it by following http://docs.djl.ai/pytorch/pytorch-engine/index.html#installation"); | ||
} | ||
} | ||
|
||
/** Requires {@code ai.djl.tensorflow.zoo.TfModelZoo}. */ | ||
public static void tensorflow() { | ||
if (!ModelZoo.hasModelZoo("ai.djl.tensorflow")) { | ||
throw new IllegalStateException( | ||
"The TensorFlow model zoo is required, but not found." | ||
+ "Please install it by following http://docs.djl.ai/tensorflow/tensorflow-model-zoo/index.html#installation"); | ||
} | ||
if (!Engine.hasEngine("TensorFlow")) { | ||
throw new IllegalStateException( | ||
"The TensorFlow engine is required, but not found." | ||
+ "Please install it by following http://docs.djl.ai/tensorflow/tensorflow-engine/index.html#installation"); | ||
} | ||
} | ||
} |
114 changes: 114 additions & 0 deletions
114
djl-easy/src/main/java/ai/djl/easy/cv/ImageClassification.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.easy.cv; | ||
|
||
import ai.djl.Application.CV; | ||
import ai.djl.MalformedModelException; | ||
import ai.djl.easy.Performance; | ||
import ai.djl.easy.RequireZoo; | ||
import ai.djl.modality.Classifications; | ||
import ai.djl.repository.zoo.Criteria; | ||
import ai.djl.repository.zoo.ModelNotFoundException; | ||
import ai.djl.repository.zoo.ModelZoo; | ||
import ai.djl.repository.zoo.ZooModel; | ||
import java.io.IOException; | ||
|
||
/** ImageClassification takes an image and classifies the main subject of the image. */ | ||
public final class ImageClassification { | ||
private ImageClassification() {} | ||
|
||
/** | ||
* Returns a pretrained and ready to use image classification model from our model zoo. | ||
* | ||
* @param input the input class between {@link ai.djl.modality.cv.Image}, {@link | ||
* java.nio.file.Path}, {@link java.net.URL}, and {@link java.io.InputStream} | ||
* @param classes what {@link Classes} the image is classified into | ||
* @param performance the performance tradeoff (see {@link Performance} | ||
* @param <I> the input type | ||
* @return a pretrained and ready to use model from our model zoo | ||
* @throws MalformedModelException if the model zoo model is broken | ||
* @throws ModelNotFoundException if the model could not be found | ||
* @throws IOException if the model could not be loaded | ||
*/ | ||
public static <I> ZooModel<I, Classifications> pretrained( | ||
Class<I> input, Classes classes, Performance performance) | ||
throws MalformedModelException, ModelNotFoundException, IOException { | ||
Criteria.Builder<I, Classifications> criteria = | ||
Criteria.builder() | ||
.setTypes(input, Classifications.class) | ||
.optApplication(CV.IMAGE_CLASSIFICATION); | ||
|
||
switch (classes) { | ||
case IMAGENET: | ||
RequireZoo.mxnet(); | ||
criteria.optGroupId("ai.djl.mxnet") | ||
.optArtifactId("resnet") | ||
.optFilter("dataset", "imagenet"); | ||
switch (performance) { | ||
case FAST: | ||
criteria.optFilter("layers", "18"); | ||
break; | ||
case BALANCED: | ||
criteria.optFilter("layers", "50"); | ||
break; | ||
case ACCURATE: | ||
criteria.optFilter("layers", "152"); | ||
break; | ||
default: | ||
throw new IllegalArgumentException("Unknown performance"); | ||
} | ||
break; | ||
case DIGITS: | ||
RequireZoo.basic(); | ||
criteria.optGroupId("ai.djl.zoo") | ||
.optArtifactId("mlp") | ||
.optFilter("dataset", "mnist"); | ||
break; | ||
default: | ||
throw new IllegalArgumentException("Unknown classes"); | ||
} | ||
|
||
return ModelZoo.loadModel(criteria.build()); | ||
} | ||
|
||
/* | ||
I am leaving this commented out as an example of what the DJL-Easy train API should look like. | ||
public static <I> ZooModel<I, Classifications> train(Class<I> input, Dataset dataset, Performance performance) { | ||
throw new UnsupportedOperationException("Not yet implemented"); | ||
} | ||
*/ | ||
|
||
/** | ||
* The possible classes to classify the images into. | ||
* | ||
* <p>The classes available depends on the data that the model was trained with. | ||
*/ | ||
public enum Classes { | ||
|
||
/** | ||
* Imagenet is a standard dataset of 1000 diverse classes. | ||
* | ||
* <p>The dataset can be found at {@link ai.djl.basicdataset.ImageNet}. You can <a | ||
* href="https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/mxnet/synset.txt">view | ||
* the list of classes here</a>. | ||
*/ | ||
IMAGENET, | ||
|
||
/** | ||
* Classify images of the digits 0-9. | ||
* | ||
* <p>This contains models trained using the {@link ai.djl.basicdataset.Mnist} dataset. | ||
*/ | ||
DIGITS | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/* | ||
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
|
||
/** | ||
* Contains easy pretrained models and training for Computer Vision({@link ai.djl.Application.CV}) | ||
* tasks. | ||
*/ | ||
package ai.djl.easy.cv; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/* | ||
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
|
||
/** | ||
* Contains a no deep learning knowledge required wrapper over DJL. | ||
* | ||
* <p><a href="https://docs.djl.ai/easy/index.html">See more details</a>. | ||
*/ | ||
package ai.djl.easy; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
<html> | ||
<head> | ||
<meta charset="UTF-8"> | ||
</head> | ||
<body> | ||
<p>This document is the API specification for the Deep Java Library (DJL) easy API.</p> | ||
|
||
<p> | ||
The easy module contains a no deep learning knowledge required wrapper over DJL. | ||
See <a href="https://github.com/awslabs/djl/tree/master/easy">here</a> for more details. | ||
</p> | ||
|
||
</body> | ||
</html> |
Oops, something went wrong.