forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
jni-java wrapper for pytorchScript api (pytorch#25084)
Summary: TLDR; initial commit of android java-jni wrapper of pytorchscript c++ api The main idea is to provide java interface for android developers to use pytorchscript modules. java API tries to repeat semantic of c++ and python pytorchscript API org.pytorch.Module (wrapper of torch::jit::script::Module) - static Module load(String path) - IValue forward(IValue... inputs) - IValue runMethod(String methodName, IValue... inputs) org.pytorch.Tensor (semantic of at::Tensor) - newFloatTensor(long[] dims, float[] data) - newFloatTensor(long[] dims, FloatBuffer data) - newIntTensor(long[] dims, int[] data) - newIntTensor(long[] dims, IntBuffer data) - newByteTensor(long[] dims, byte[] data) - newByteTensor(long[] dims, ByteBuffer data) org.pytorch.IValue (semantic of at::IValue) - static factory methods to create pytorchscript supported types Examples of usage api could be found in PytorchInstrumentedTests.java: Module module = Module.load(path); IValue input = IValue.tensor(Tensor.newByteTensor(new long[]{1}, Tensor.allocateByteBuffer(1))); IValue output = module.forward(input); Tensor outputTensor = output.getTensor(); ThreadSafety: Api is not thread safe, all synchronization must be done on caller side. Mutability: org.pytorch.Tensor buffer is DirectBuffer with native byte order, can be created with static factory methods specifing DirectBuffer. At the moment org.pytorch.Tensor does not hold at::Tensor on jni side, it has: long[] dimensions, type, DirectByteBuffer blobData Input tensors are mutable (can be modified and used for the next inference), Uses values from buffer on the momment of Module#forward or Module#runMethod calls. Buffers of input tensors is used directly by input at::Tensor Output is copied from output at::Tensor and is immutable. Dependencies: Jni level is implemented with usage of fbjni library, that was developed in Facebook, and was already used and opensourced in several opensource projects, added to the repo as submodule from personal account to be able to switch submodule when fbjni will be opensourced separately. ghstack-source-id: b39c848359a70d717f2830a15265e4aa122279c0 Pull Request resolved: pytorch#25084 Pull Request resolved: pytorch#25105 Reviewed By: dreiss Differential Revision: D16988107 Pulled By: IvanKobzarev fbshipit-source-id: 41ca7c9869f8370b8504c2ef8a96047cc16516d4
- Loading branch information
1 parent
3a59a9b
commit d62bca9
Showing
17 changed files
with
1,715 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
local.properties | ||
**/*.iml | ||
.gradle | ||
gradlew* | ||
.idea/* | ||
.externalNativeBuild | ||
build | ||
pytorch_android/src/main/cpp/libtorch_include/x86/** | ||
pytorch_android/src/main/cpp/libtorch_include/x86_64/** | ||
pytorch_android/src/main/cpp/libtorch_include/armeabi-v7a/** | ||
pytorch_android/src/main/cpp/libtorch_include/arm64-v8a/** | ||
pytorch_android/src/main/jniLibs/x86/** | ||
pytorch_android/src/main/jniLibs/x86_64/** | ||
pytorch_android/src/main/jniLibs/armeabi-v7a/** | ||
pytorch_android/src/main/jniLibs/arm64-v8a/** |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
buildscript { | ||
ext { | ||
minSdkVersion = 21 | ||
targetSdkVersion = 28 | ||
compileSdkVersion = 28 | ||
buildToolsVersion = '28.0.3' | ||
|
||
coreVersion = "1.2.0" | ||
extJUnitVersion = "1.1.1" | ||
runnerVersion = "1.2.0" | ||
rulesVersion = "1.2.0" | ||
junitVersion = "4.12" | ||
} | ||
|
||
repositories { | ||
google() | ||
mavenLocal() | ||
mavenCentral() | ||
jcenter() | ||
} | ||
|
||
dependencies { | ||
classpath 'com.android.tools.build:gradle:3.3.2' | ||
} | ||
} | ||
|
||
allprojects { | ||
repositories { | ||
google() | ||
jcenter() | ||
} | ||
} | ||
|
||
ext.deps = [ | ||
jsr305: 'com.google.code.findbugs:jsr305:3.0.1', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
cmake_minimum_required(VERSION 3.4.1) | ||
project(pytorch CXX) | ||
set(CMAKE_CXX_STANDARD 11) | ||
set(CMAKE_VERBOSE_MAKEFILE ON) | ||
|
||
set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) | ||
set(libtorch_include_DIR ${pytorch_android_DIR}/libtorch_include/${ANDROID_ABI}) | ||
|
||
set(libtorch_SO ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libtorch.so) | ||
set(libc10_SO ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libc10.so) | ||
|
||
message(STATUS "libtorch dir:${libtorch_DIR}") | ||
|
||
add_library(libtorch SHARED IMPORTED) | ||
set_property(TARGET libtorch PROPERTY IMPORTED_LOCATION ${libtorch_SO}) | ||
|
||
add_library(libc10 SHARED IMPORTED ${libc10_SO}) | ||
set_property(TARGET libc10 PROPERTY IMPORTED_LOCATION ${libc10_SO}) | ||
|
||
file(GLOB pytorch_android_SOURCES | ||
${pytorch_android_DIR}/*.cpp | ||
) | ||
|
||
add_library(pytorch SHARED | ||
${pytorch_android_SOURCES} | ||
) | ||
|
||
target_compile_options(pytorch PRIVATE | ||
-fexceptions | ||
) | ||
|
||
target_include_directories(pytorch PUBLIC | ||
${libtorch_include_DIR} | ||
) | ||
|
||
set(BUILD_DIR ${CMAKE_SOURCE_DIR}/build) | ||
file(MAKE_DIRECTORY ${BUILD_DIR}) | ||
|
||
set(fbjni_DIR ${CMAKE_CURRENT_LIST_DIR}/../libs/fbjni/) | ||
set(fbjni_BUILD_DIR ${BUILD_DIR}/fbjni/${ANDROID_ABI}) | ||
|
||
add_subdirectory(${fbjni_DIR} ${fbjni_BUILD_DIR}) | ||
|
||
target_link_libraries(pytorch | ||
fbjni | ||
libtorch | ||
libc10 | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
apply plugin: 'com.android.library' | ||
|
||
android { | ||
compileSdkVersion rootProject.compileSdkVersion | ||
buildToolsVersion rootProject.buildToolsVersion | ||
|
||
defaultConfig { | ||
minSdkVersion rootProject.minSdkVersion | ||
targetSdkVersion rootProject.targetSdkVersion | ||
versionCode 1 | ||
versionName "1.0" | ||
|
||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" | ||
ndk { | ||
abiFilters "armeabi-v7a", "arm64-v8a", "x86", "x86_64" | ||
} | ||
} | ||
buildTypes { | ||
debug { | ||
minifyEnabled false | ||
} | ||
release { | ||
minifyEnabled false | ||
} | ||
} | ||
sourceSets { | ||
main { | ||
jniLibs.srcDirs = ['src/main/jniLibs'] | ||
} | ||
} | ||
externalNativeBuild { | ||
cmake { | ||
path "CMakeLists.txt" | ||
} | ||
} | ||
|
||
useLibrary 'android.test.runner' | ||
useLibrary 'android.test.base' | ||
useLibrary 'android.test.mock' | ||
} | ||
|
||
dependencies { | ||
implementation project(':fbjni') | ||
|
||
implementation 'com.android.support:appcompat-v7:28.0.0' | ||
|
||
testImplementation 'junit:junit:' + rootProject.junitVersion | ||
testImplementation 'androidx.test:core:' + rootProject.coreVersion | ||
|
||
androidTestImplementation 'junit:junit:' + rootProject.junitVersion | ||
androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion | ||
androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion | ||
androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion | ||
androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import torch | ||
from typing import List, Tuple, Dict | ||
|
||
OUTPUT_DIR = "src/androidTest/assets/" | ||
|
||
def scriptAndSave(module, fileName): | ||
print('-'*80) | ||
script_module = torch.jit.script(module) | ||
print(script_module.graph) | ||
outputFileName = OUTPUT_DIR + fileName | ||
script_module.save(outputFileName) | ||
print("Saved to " + outputFileName) | ||
print('='*80) | ||
|
||
class Test(torch.jit.ScriptModule): | ||
def __init__(self): | ||
super(Test, self).__init__() | ||
|
||
@torch.jit.script_method | ||
def forward(self, input): | ||
return None | ||
|
||
@torch.jit.script_method | ||
def eqBool(self, input): | ||
# type: (bool) -> bool | ||
return input | ||
|
||
@torch.jit.script_method | ||
def eqInt(self, input): | ||
# type: (int) -> int | ||
return input | ||
|
||
@torch.jit.script_method | ||
def eqFloat(self, input): | ||
# type: (float) -> float | ||
return input | ||
|
||
@torch.jit.script_method | ||
def eqTensor(self, input): | ||
# type: (Tensor) -> Tensor | ||
return input | ||
|
||
@torch.jit.script_method | ||
def eqDictStrKeyIntValue(self, input): | ||
# type: (Dict[str, int]) -> Dict[str, int] | ||
return input | ||
|
||
@torch.jit.script_method | ||
def eqDictIntKeyIntValue(self, input): | ||
# type: (Dict[int, int]) -> Dict[int, int] | ||
return input | ||
|
||
@torch.jit.script_method | ||
def eqDictFloatKeyIntValue(self, input): | ||
# type: (Dict[float, int]) -> Dict[float, int] | ||
return input | ||
|
||
@torch.jit.script_method | ||
def listIntSumReturnTuple(self, input): | ||
# type: (List[int]) -> Tuple[List[int], int] | ||
sum = 0 | ||
for x in input: | ||
sum += x | ||
return (input, sum) | ||
|
||
@torch.jit.script_method | ||
def listBoolConjunction(self, input): | ||
# type: (List[bool]) -> bool | ||
res = True | ||
for x in input: | ||
res = res and x | ||
return res | ||
|
||
@torch.jit.script_method | ||
def listBoolDisjunction(self, input): | ||
# type: (List[bool]) -> bool | ||
res = False | ||
for x in input: | ||
res = res or x | ||
return res | ||
|
||
@torch.jit.script_method | ||
def tupleIntSumReturnTuple(self, input): | ||
# type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int] | ||
sum = 0 | ||
for x in input: | ||
sum += x | ||
return (input, sum) | ||
|
||
@torch.jit.script_method | ||
def optionalIntIsNone(self, input): | ||
# type: (Optional[int]) -> bool | ||
return input is None | ||
|
||
@torch.jit.script_method | ||
def intEq0None(self, input): | ||
# type: (int) -> Optional[int] | ||
if input == 0: | ||
return None | ||
return input | ||
|
||
|
||
scriptAndSave(Test(), "test.pt") |
Binary file not shown.
Oops, something went wrong.