Skip to content

Commit

Permalink
jni-java wrapper for pytorchScript api (pytorch#25084)
Browse files Browse the repository at this point in the history
Summary:
TLDR; initial commit of android java-jni wrapper of pytorchscript c++ api

The main idea is to provide java interface for android developers to use pytorchscript modules.
java API tries to repeat semantic of c++ and python pytorchscript API

org.pytorch.Module (wrapper of torch::jit::script::Module)
 - static Module load(String path)
 - IValue forward(IValue... inputs)
 - IValue runMethod(String methodName, IValue... inputs)

org.pytorch.Tensor (semantic of at::Tensor)
 - newFloatTensor(long[] dims, float[] data)
 - newFloatTensor(long[] dims, FloatBuffer data)

 - newIntTensor(long[] dims, int[] data)
 - newIntTensor(long[] dims, IntBuffer data)

 - newByteTensor(long[] dims, byte[] data)
 - newByteTensor(long[] dims, ByteBuffer data)

org.pytorch.IValue (semantic of at::IValue)
 - static factory methods to create pytorchscript supported types

Examples of usage api could be found in PytorchInstrumentedTests.java:

Module module = Module.load(path);
IValue input = IValue.tensor(Tensor.newByteTensor(new long[]{1}, Tensor.allocateByteBuffer(1)));
IValue output = module.forward(input);
Tensor outputTensor = output.getTensor();

ThreadSafety:
Api is not thread safe, all synchronization must be done on caller side.

Mutability:
org.pytorch.Tensor buffer is DirectBuffer with native byte order, can be created with static factory methods specifing DirectBuffer.
At the moment org.pytorch.Tensor does not hold at::Tensor on jni side, it has: long[] dimensions, type, DirectByteBuffer blobData

Input tensors are mutable (can be modified and used for the next inference),
Uses values from buffer on the momment of Module#forward or Module#runMethod calls.
Buffers of input tensors is used directly by input at::Tensor

Output is copied from output at::Tensor and is immutable.

Dependencies:
Jni level is implemented with usage of fbjni library, that was developed in Facebook,
and was already used and opensourced in several opensource projects,
added to the repo as submodule from personal account to be able to switch submodule
when fbjni will be opensourced separately.

ghstack-source-id: b39c848359a70d717f2830a15265e4aa122279c0
Pull Request resolved: pytorch#25084
Pull Request resolved: pytorch#25105

Reviewed By: dreiss

Differential Revision: D16988107

Pulled By: IvanKobzarev

fbshipit-source-id: 41ca7c9869f8370b8504c2ef8a96047cc16516d4
  • Loading branch information
IvanKobzarev authored and facebook-github-bot committed Aug 23, 2019
1 parent 3a59a9b commit d62bca9
Show file tree
Hide file tree
Showing 17 changed files with 1,715 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,7 @@
path = third_party/tbb
url = https://github.com/01org/tbb
branch = tbb_2018
[submodule "android/libs/fbjni"]
ignore = dirty
path = android/libs/fbjni
url = https://github.com/IvanKobzarev/fbjni.git
15 changes: 15 additions & 0 deletions android/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
local.properties
**/*.iml
.gradle
gradlew*
.idea/*
.externalNativeBuild
build
pytorch_android/src/main/cpp/libtorch_include/x86/**
pytorch_android/src/main/cpp/libtorch_include/x86_64/**
pytorch_android/src/main/cpp/libtorch_include/armeabi-v7a/**
pytorch_android/src/main/cpp/libtorch_include/arm64-v8a/**
pytorch_android/src/main/jniLibs/x86/**
pytorch_android/src/main/jniLibs/x86_64/**
pytorch_android/src/main/jniLibs/armeabi-v7a/**
pytorch_android/src/main/jniLibs/arm64-v8a/**
36 changes: 36 additions & 0 deletions android/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
buildscript {
ext {
minSdkVersion = 21
targetSdkVersion = 28
compileSdkVersion = 28
buildToolsVersion = '28.0.3'

coreVersion = "1.2.0"
extJUnitVersion = "1.1.1"
runnerVersion = "1.2.0"
rulesVersion = "1.2.0"
junitVersion = "4.12"
}

repositories {
google()
mavenLocal()
mavenCentral()
jcenter()
}

dependencies {
classpath 'com.android.tools.build:gradle:3.3.2'
}
}

allprojects {
repositories {
google()
jcenter()
}
}

ext.deps = [
jsr305: 'com.google.code.findbugs:jsr305:3.0.1',
]
1 change: 1 addition & 0 deletions android/libs/fbjni
Submodule fbjni added at dc9169
48 changes: 48 additions & 0 deletions android/pytorch_android/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
cmake_minimum_required(VERSION 3.4.1)
project(pytorch CXX)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_VERBOSE_MAKEFILE ON)

set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
set(libtorch_include_DIR ${pytorch_android_DIR}/libtorch_include/${ANDROID_ABI})

set(libtorch_SO ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libtorch.so)
set(libc10_SO ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libc10.so)

message(STATUS "libtorch dir:${libtorch_DIR}")

add_library(libtorch SHARED IMPORTED)
set_property(TARGET libtorch PROPERTY IMPORTED_LOCATION ${libtorch_SO})

add_library(libc10 SHARED IMPORTED ${libc10_SO})
set_property(TARGET libc10 PROPERTY IMPORTED_LOCATION ${libc10_SO})

file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/*.cpp
)

add_library(pytorch SHARED
${pytorch_android_SOURCES}
)

target_compile_options(pytorch PRIVATE
-fexceptions
)

target_include_directories(pytorch PUBLIC
${libtorch_include_DIR}
)

set(BUILD_DIR ${CMAKE_SOURCE_DIR}/build)
file(MAKE_DIRECTORY ${BUILD_DIR})

set(fbjni_DIR ${CMAKE_CURRENT_LIST_DIR}/../libs/fbjni/)
set(fbjni_BUILD_DIR ${BUILD_DIR}/fbjni/${ANDROID_ABI})

add_subdirectory(${fbjni_DIR} ${fbjni_BUILD_DIR})

target_link_libraries(pytorch
fbjni
libtorch
libc10
)
55 changes: 55 additions & 0 deletions android/pytorch_android/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
apply plugin: 'com.android.library'

android {
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion

defaultConfig {
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 1
versionName "1.0"

testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
ndk {
abiFilters "armeabi-v7a", "arm64-v8a", "x86", "x86_64"
}
}
buildTypes {
debug {
minifyEnabled false
}
release {
minifyEnabled false
}
}
sourceSets {
main {
jniLibs.srcDirs = ['src/main/jniLibs']
}
}
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}

useLibrary 'android.test.runner'
useLibrary 'android.test.base'
useLibrary 'android.test.mock'
}

dependencies {
implementation project(':fbjni')

implementation 'com.android.support:appcompat-v7:28.0.0'

testImplementation 'junit:junit:' + rootProject.junitVersion
testImplementation 'androidx.test:core:' + rootProject.coreVersion

androidTestImplementation 'junit:junit:' + rootProject.junitVersion
androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion
androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion
androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion
androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion
}
103 changes: 103 additions & 0 deletions android/pytorch_android/generate_test_torchscripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import torch
from typing import List, Tuple, Dict

OUTPUT_DIR = "src/androidTest/assets/"

def scriptAndSave(module, fileName):
print('-'*80)
script_module = torch.jit.script(module)
print(script_module.graph)
outputFileName = OUTPUT_DIR + fileName
script_module.save(outputFileName)
print("Saved to " + outputFileName)
print('='*80)

class Test(torch.jit.ScriptModule):
def __init__(self):
super(Test, self).__init__()

@torch.jit.script_method
def forward(self, input):
return None

@torch.jit.script_method
def eqBool(self, input):
# type: (bool) -> bool
return input

@torch.jit.script_method
def eqInt(self, input):
# type: (int) -> int
return input

@torch.jit.script_method
def eqFloat(self, input):
# type: (float) -> float
return input

@torch.jit.script_method
def eqTensor(self, input):
# type: (Tensor) -> Tensor
return input

@torch.jit.script_method
def eqDictStrKeyIntValue(self, input):
# type: (Dict[str, int]) -> Dict[str, int]
return input

@torch.jit.script_method
def eqDictIntKeyIntValue(self, input):
# type: (Dict[int, int]) -> Dict[int, int]
return input

@torch.jit.script_method
def eqDictFloatKeyIntValue(self, input):
# type: (Dict[float, int]) -> Dict[float, int]
return input

@torch.jit.script_method
def listIntSumReturnTuple(self, input):
# type: (List[int]) -> Tuple[List[int], int]
sum = 0
for x in input:
sum += x
return (input, sum)

@torch.jit.script_method
def listBoolConjunction(self, input):
# type: (List[bool]) -> bool
res = True
for x in input:
res = res and x
return res

@torch.jit.script_method
def listBoolDisjunction(self, input):
# type: (List[bool]) -> bool
res = False
for x in input:
res = res or x
return res

@torch.jit.script_method
def tupleIntSumReturnTuple(self, input):
# type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]
sum = 0
for x in input:
sum += x
return (input, sum)

@torch.jit.script_method
def optionalIntIsNone(self, input):
# type: (Optional[int]) -> bool
return input is None

@torch.jit.script_method
def intEq0None(self, input):
# type: (int) -> Optional[int]
if input == 0:
return None
return input


scriptAndSave(Test(), "test.pt")
Binary file not shown.
Loading

0 comments on commit d62bca9

Please sign in to comment.