Skip to content
This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Commit ea95b96

Browse files
authored
Merge 2cad147 into f57240f
2 parents f57240f + 2cad147 commit ea95b96

File tree

7 files changed

+33
-23
lines changed

7 files changed

+33
-23
lines changed

frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public class ModelConfig {
3232
*/
3333
private List<Integer> deviceIds;
3434
/** this variable is auto calculated based on torchrun nproc-per-node. */
35-
private int parallelLevel = 1;
35+
private int parallelLevel;
3636
/** the model parallel type can be tp, pp, pptp */
3737
private ParallelType parallelType = ParallelType.NONE;
3838
/** torchrun config */
@@ -259,9 +259,8 @@ public int getParallelLevel() {
259259
}
260260

261261
public void setParallelLevel(int parallelLevel) {
262-
if (parallelLevel <= 0) {
263-
logger.warn("Invalid parallelLevel:{}, set as 1", parallelLevel);
264-
this.parallelLevel = 1;
262+
if (parallelLevel < 0) {
263+
logger.warn("Invalid parallelLevel:{}, set as 0", parallelLevel);
265264
return;
266265
}
267266
this.parallelLevel = parallelLevel;

frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public void testInvalidYamlConfig() throws InvalidModelException, IOException {
4343
Assert.assertEquals(modelConfig.getMaxBatchDelay(), 100);
4444
Assert.assertEquals(modelConfig.getResponseTimeout(), 120);
4545
Assert.assertNotEquals(modelConfig.getDeviceType(), ModelConfig.DeviceType.GPU);
46-
Assert.assertEquals(modelConfig.getParallelLevel(), 1);
46+
Assert.assertEquals(modelConfig.getParallelLevel(), 0);
4747
Assert.assertNotEquals(modelConfig.getParallelType(), ModelConfig.ParallelType.PPTP);
4848
Assert.assertNull(modelConfig.getDeviceIds());
4949
}

frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public class Model {
4040
private int maxWorkers;
4141
private int batchSize;
4242
private int maxBatchDelay;
43-
private int parallelLevel = 1;
43+
private int parallelLevel;
4444
private long maxRetryTimeoutInMill = 5 * 60 * 1000;
4545
private long clientTimeoutInMills;
4646
private ModelConfig.ParallelType parallelType = ModelConfig.ParallelType.NONE;
@@ -71,7 +71,7 @@ public Model(ModelArchive modelArchive, int queueSize) {
7171
this.modelArchive = modelArchive;
7272
if (modelArchive != null && modelArchive.getModelConfig() != null) {
7373
continuousBatching = modelArchive.getModelConfig().isContinuousBatching();
74-
if (modelArchive.getModelConfig().getParallelLevel() > 1
74+
if (modelArchive.getModelConfig().getParallelLevel() > 0
7575
&& modelArchive.getModelConfig().getParallelType()
7676
!= ModelConfig.ParallelType.NONE) {
7777
parallelLevel = modelArchive.getModelConfig().getParallelLevel();
@@ -138,7 +138,7 @@ public JsonObject getModelState(boolean isDefaultVersion) {
138138
modelInfo.addProperty(BATCH_SIZE, getBatchSize());
139139
modelInfo.addProperty(MAX_BATCH_DELAY, getMaxBatchDelay());
140140
modelInfo.addProperty(RESPONSE_TIMEOUT, getResponseTimeout());
141-
if (parallelLevel > 1) {
141+
if (parallelLevel > 0) {
142142
modelInfo.addProperty(PARALLEL_LEVEL, parallelLevel);
143143
}
144144

frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ public CompletableFuture<Integer> updateModel(
461461
throw new ModelVersionNotFoundException(
462462
"Model version: " + versionId + " does not exist for model: " + modelName);
463463
}
464-
if (model.getParallelLevel() > 1 && model.getDeviceType() == ModelConfig.DeviceType.GPU) {
464+
if (model.getParallelLevel() > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) {
465465
/**
466466
* Current capacity check for LMI is based on single node. TODO: multiple nodes check
467467
* will be based on --proc-per-node + numCores.

frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,14 +211,17 @@ private void addThreads(
211211
int gpuId = -1;
212212

213213
if (maxGpu > 0) {
214-
if (model.isHasCfgDeviceIds() || model.getParallelLevel() > 1) {
214+
if (model.isHasCfgDeviceIds() || model.getParallelLevel() > 0) {
215215
gpuId =
216216
model.getGpuCounter()
217217
.getAndAccumulate(
218218
maxGpu,
219219
(prev, maxGpuId) ->
220-
(prev + model.getParallelLevel()) % maxGpuId);
221-
if (model.getParallelLevel() == 1) {
220+
(prev + model.getParallelLevel() > 0
221+
? model.getParallelLevel()
222+
: 1)
223+
% maxGpuId);
224+
if (model.getParallelLevel() == 0) {
222225
gpuId = model.getDeviceIds().get(gpuId);
223226
}
224227
} else {
@@ -235,7 +238,7 @@ private void addThreads(
235238
aggregator = new BatchAggregator(model);
236239
}
237240
int currentPort =
238-
model.getParallelLevel() > 1
241+
model.getParallelLevel() > 0
239242
? configManager.isDebug()
240243
? distributionPort.get()
241244
: distributionPort.getAndAdd(model.getParallelLevel())

frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ public void startWorker(int port, String deviceIds)
115115
modelPath.getAbsolutePath(),
116116
model.getModelArchive().getManifest().getModel().getHandler())));
117117

118-
if (model.getParallelLevel() > 1) {
118+
if (model.getParallelLevel() > 0) {
119119
attachRunner(argl, envp, port, deviceIds);
120-
} else if (model.getParallelLevel() == 1) {
120+
} else if (model.getParallelLevel() == 0) {
121121
argl.add(EnvironmentUtils.getPythonRunTime(model));
122122
}
123123

@@ -153,7 +153,7 @@ public void startWorker(int port, String deviceIds)
153153
argl.add(configManager.getMetricsConfigPath());
154154

155155
try {
156-
latch = new CountDownLatch(model.getParallelLevel());
156+
latch = new CountDownLatch(model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
157157

158158
String[] args = argl.toArray(new String[argl.size()]);
159159
String[] envs = envp.toArray(new String[envp.size()]);

frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ public WorkerThread(
9999
this.listener = listener;
100100
startTime = System.currentTimeMillis();
101101
lifeCycle = new WorkerLifeCycle(configManager, model);
102-
replies = new ArrayBlockingQueue<>(model.getParallelLevel());
102+
replies =
103+
new ArrayBlockingQueue<>(
104+
model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
103105
this.workerThreadTimeMetric =
104106
MetricCache.getInstance().getMetricFrontend("WorkerThreadTime");
105107
this.workerLoadTimeMetric = MetricCache.getInstance().getMetricFrontend("WorkerLoadTime");
@@ -198,10 +200,10 @@ public void run() {
198200
|| ((req.getCommand() == WorkerCommands.PREDICT
199201
|| req.getCommand()
200202
== WorkerCommands.STREAMPREDICT)
201-
&& model.getParallelLevel() > 1
203+
&& model.getParallelLevel() > 0
202204
&& model.getParallelType()
203205
!= ModelConfig.ParallelType.PP)
204-
? model.getParallelLevel()
206+
? model.getParallelLevel() > 0 ? model.getParallelLevel() : 1
205207
: 1;
206208
for (int i = 0; backendChannel.size() > 0 && i < repeats; i++) {
207209
backendChannel.get(i).writeAndFlush(req).sync();
@@ -305,7 +307,10 @@ public void run() {
305307
// WorkerThread is running in thread pool, the thread will be assigned to next
306308
// Runnable once this worker is finished. If currentThread keep holding the reference
307309
// of the thread, currentThread.interrupt() might kill next worker.
308-
for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) {
310+
for (int i = 0;
311+
backendChannel.size() > 0
312+
&& i < (model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
313+
i++) {
309314
backendChannel.get(i).disconnect();
310315
}
311316
currentThread.set(null);
@@ -346,7 +351,7 @@ private void connect() throws WorkerInitializationException, InterruptedExceptio
346351
String modelName = model.getModelName();
347352
String modelVersion = model.getVersion();
348353
setState(WorkerState.WORKER_STARTED, HttpURLConnection.HTTP_OK);
349-
final int parallelLevel = model.getParallelLevel();
354+
final int parallelLevel = model.getParallelLevel() > 0 ? model.getParallelLevel() : 1;
350355
final CountDownLatch latch = new CountDownLatch(parallelLevel);
351356
final int responseBufferSize = configManager.getMaxResponseSize();
352357
try {
@@ -449,7 +454,10 @@ public int getPid() {
449454
public void shutdown() {
450455
running.set(false);
451456
setState(WorkerState.WORKER_SCALED_DOWN, HttpURLConnection.HTTP_OK);
452-
for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) {
457+
for (int i = 0;
458+
backendChannel.size() > 0
459+
&& i < (model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
460+
i++) {
453461
if (backendChannel.get(i) != null) {
454462
backendChannel.get(i).close();
455463
}
@@ -522,7 +530,7 @@ public void retry() {
522530

523531
private String getDeviceIds() {
524532
List<Integer> deviceIds;
525-
if (gpuId == -1 || model.getParallelLevel() == 1) {
533+
if (gpuId == -1 || model.getParallelLevel() == 0) {
526534
return null;
527535
} else if (model.isHasCfgDeviceIds()) {
528536
return model.getDeviceIds().subList(gpuId, gpuId + model.getParallelLevel()).stream()

0 commit comments

Comments (0)