Skip to content

Commit

Permalink
Assign random seed for lmi dist (deepjavalibrary#912)
Browse files Browse the repository at this point in the history
* Assign random seed

* review comments

* format Java

* Add comments

* format Python

* minor refactor

---------

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
2 people authored and KexinFeng committed Aug 16, 2023
1 parent af16d2c commit 128697c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
9 changes: 8 additions & 1 deletion engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def inference(self, inputs):
parameters = []
batch = inputs.get_batches()
first = True
for item in batch:
for i, item in enumerate(batch):
input_map = decode(item, content_type)
_inputs = input_map.pop("inputs", input_map)
if isinstance(_inputs, list):
Expand All @@ -224,6 +224,13 @@ def inference(self, inputs):
"In order to enable dynamic batching, all input batches must have the same parameters"
)

seed_key = 'seed' if inputs.is_batch() else f'batch_{i}.seed'
if item.contains_key(seed_key):
seed = parameters[i].get("seed")
if not seed:
# set server provided seed if seed is not part of request
parameters[i]["seed"] = item.get_as_string(key=seed_key)

outputs = Output()

if self.rolling_batch_type:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def preprocess_requests(self, requests, **kwargs):
top_p=param.get("top_p", 1.0),
typical_p=param.get("typical_p", 1.0),
do_sample=param.get("do_sample", False),
)
seed=int(param.get("seed", 0)))
stop_parameters = StoppingCriteriaParameters(
stop_sequences=param.get("stop_sequences", []),
max_new_tokens=param.get("max_new_tokens", 30))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;
import ai.djl.util.PairList;
import ai.djl.util.RandomUtils;

import com.google.gson.JsonObject;

Expand Down Expand Up @@ -80,6 +81,11 @@ public void run() {
batch.setProperties(req.input.getProperties());
}
batch.add(prefix, req.getRequest());
String seed = req.getSeed();
if (seed != null) {
String seedPrefix = "batch_" + i + ".seed";
batch.add(seedPrefix, req.seed);
}
}
batch.addProperty("batch_size", String.valueOf(size));

Expand Down Expand Up @@ -123,7 +129,7 @@ public Output addInput(Input input, int timeout) throws TranslateException {
throw new TranslateException("Time out in: " + timeout);
}
}
Request req = new Request(input);
Request req = new Request(input, String.valueOf(RandomUtils.nextInt()));
list.add(req);
canRead.signal();
return req.output;
Expand All @@ -147,12 +153,14 @@ private static final class Request {
Output output;
String nextToken;
boolean last;
String seed;

Request(Input input) {
Request(Input input, String seed) {
    // Values handed in by the caller.
    this.input = input;
    this.seed = seed;
    // Fresh output whose content is supplied by the chunked data supplier.
    this.data = new ChunkedBytesSupplier();
    this.output = new Output();
    this.output.add(this.data);
}

BytesSupplier getRequest() {
Expand All @@ -162,6 +170,20 @@ BytesSupplier getRequest() {
return input.getData();
}

/**
 * Returns the seed to send along with this request's first forward call.
 *
 * <p>LMI Dist needs a seed so that every process in the MPI samples the same token. Because
 * NextTokenChooserParameters is constructed during the first forward and preserved for all
 * later forward calls of the request, the seed is only reported while no next token has been
 * recorded for this request.
 *
 * @return the seed while {@code nextToken} is unset, otherwise {@code null}
 */
String getSeed() {
    return nextToken == null ? seed : null;
}

void addResponse(String json) {
JsonObject element = JsonUtils.GSON.fromJson(json, JsonObject.class);
last = element.get("last").getAsBoolean();
Expand Down

0 comments on commit 128697c

Please sign in to comment.