Skip to content

Commit e09f33c

Browse files
authored
Android check pte exists
Differential Revision: D74850727 Pull Request resolved: #10931
1 parent 1244672 commit e09f33c

File tree

3 files changed

+36
-25
lines changed

3 files changed

+36
-25
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -96,20 +96,9 @@ public void testModuleLoadForwardExplicit() throws IOException{
9696
assertTrue(results[0].isTensor());
9797
}
9898

99-
@Test
99+
@Test(expected = RuntimeException.class)
100100
public void testModuleLoadNonExistantFile() throws IOException{
101101
Module module = Module.load(getTestFilePath(MISSING_FILE_NAME));
102-
103-
EValue[] results = module.forward();
104-
assertEquals(null, results);
105-
}
106-
107-
@Test
108-
public void testModuleLoadMethodNonExistantFile() throws IOException{
109-
Module module = Module.load(getTestFilePath(MISSING_FILE_NAME));
110-
111-
int loadMethod = module.loadMethod(FORWARD_METHOD);
112-
assertEquals(loadMethod, ACCESS_FAILED);
113102
}
114103

115104
@Test
@@ -146,11 +135,11 @@ public void testForwardOnDestroyedModule() throws IOException{
146135
assertEquals(loadMethod, OK);
147136

148137
module.destroy();
149-
138+
150139
EValue[] results = module.forward();
151140
assertEquals(0, results.length);
152141
}
153-
142+
154143
@Test
155144
public void testForwardFromMultipleThreads() throws InterruptedException, IOException {
156145
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
@@ -169,7 +158,7 @@ public void run() {
169158
assertTrue(results[0].isTensor());
170159
completed.incrementAndGet();
171160
} catch (InterruptedException e) {
172-
161+
173162
}
174163
}
175164
};

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import android.util.Log;
1212
import com.facebook.soloader.nativeloader.NativeLoader;
1313
import com.facebook.soloader.nativeloader.SystemDelegate;
14+
import java.io.File;
1415
import java.util.concurrent.locks.Lock;
1516
import java.util.concurrent.locks.ReentrantLock;
1617
import org.pytorch.executorch.annotations.Experimental;
@@ -52,6 +53,10 @@ public static Module load(final String modelPath, int loadMode) {
5253
if (!NativeLoader.isInitialized()) {
5354
NativeLoader.init(new SystemDelegate());
5455
}
56+
File modelFile = new File(modelPath);
57+
if (!modelFile.canRead() || !modelFile.isFile()) {
58+
throw new RuntimeException("Cannot load model path " + modelPath);
59+
}
5560
return new Module(new NativePeer(modelPath, loadMode));
5661
}
5762

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import com.facebook.jni.annotations.DoNotStrip;
1313
import com.facebook.soloader.nativeloader.NativeLoader;
1414
import com.facebook.soloader.nativeloader.SystemDelegate;
15+
import java.io.File;
1516
import org.pytorch.executorch.annotations.Experimental;
1617

1718
/**
@@ -41,33 +42,49 @@ public class LlmModule {
4142
private static native HybridData initHybrid(
4243
int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath);
4344

45+
/**
46+
* Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
47+
* data path.
48+
*/
49+
public LlmModule(
50+
int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) {
51+
File modelFile = new File(modulePath);
52+
if (!modelFile.canRead() || !modelFile.isFile()) {
53+
throw new RuntimeException("Cannot load model path " + modulePath);
54+
}
55+
File tokenizerFile = new File(tokenizerPath);
56+
if (!tokenizerFile.canRead() || !tokenizerFile.isFile()) {
57+
throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath);
58+
}
59+
mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataPath);
60+
}
61+
4462
/** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */
4563
public LlmModule(String modulePath, String tokenizerPath, float temperature) {
46-
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
64+
this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
4765
}
4866

4967
/**
5068
* Constructs a LLM Module for a model with given model path, tokenizer, temperature and data
5169
* path.
5270
*/
5371
public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
54-
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
72+
this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
5573
}
5674

5775
/** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
5876
public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
59-
mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null);
77+
this(modelType, modulePath, tokenizerPath, temperature, null);
6078
}
6179

6280
/** Constructs a LLM Module for a model with the given LlmModuleConfig */
6381
public LlmModule(LlmModuleConfig config) {
64-
mHybridData =
65-
initHybrid(
66-
config.getModelType(),
67-
config.getModulePath(),
68-
config.getTokenizerPath(),
69-
config.getTemperature(),
70-
config.getDataPath());
82+
this(
83+
config.getModelType(),
84+
config.getModulePath(),
85+
config.getTokenizerPath(),
86+
config.getTemperature(),
87+
config.getDataPath());
7188
}
7289

7390
public void resetNative() {

0 commit comments

Comments
 (0)