From 128697c89a318db0dd8dec8389d3d43508145c52 Mon Sep 17 00:00:00 2001 From: Sindhu Somasundaram <56774226+sindhuvahinis@users.noreply.github.com> Date: Thu, 6 Jul 2023 13:58:44 -0700 Subject: [PATCH] Assign random seed for lmi dist (#912) * Assign random seed * review comments * format Java * Add comments * format Python * minor refactor --------- Co-authored-by: Frank Liu --- .../python/setup/djl_python/huggingface.py | 9 ++++++- .../rolling_batch/lmi_dist_rolling_batch.py | 2 +- .../ai/djl/python/engine/RollingBatch.java | 26 +++++++++++++++++-- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py index e6d597a4a..4935cd79c 100644 --- a/engines/python/setup/djl_python/huggingface.py +++ b/engines/python/setup/djl_python/huggingface.py @@ -206,7 +206,7 @@ def inference(self, inputs): parameters = [] batch = inputs.get_batches() first = True - for item in batch: + for i, item in enumerate(batch): input_map = decode(item, content_type) _inputs = input_map.pop("inputs", input_map) if isinstance(_inputs, list): @@ -224,6 +224,13 @@ def inference(self, inputs): "In order to enable dynamic batching, all input batches must have the same parameters" ) + seed_key = 'seed' if inputs.is_batch() else f'batch_{i}.seed' + if item.contains_key(seed_key): + seed = parameters[i].get("seed") + if not seed: + # set server provided seed if seed is not part of request + parameters[i]["seed"] = item.get_as_string(key=seed_key) + outputs = Output() if self.rolling_batch_type: diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index cbb1cd7df..4018b2a06 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -125,7 +125,7 @@ def preprocess_requests(self, requests, **kwargs): top_p=param.get("top_p", 1.0), typical_p=param.get("typical_p", 1.0), do_sample=param.get("do_sample", False), - ) + seed=int(param.get("seed", 0))) stop_parameters = StoppingCriteriaParameters( stop_sequences=param.get("stop_sequences", []), max_new_tokens=param.get("max_new_tokens", 30)) diff --git a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java index 525f5048c..eed735d53 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java +++ b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java @@ -19,6 +19,7 @@ import ai.djl.translate.TranslateException; import ai.djl.util.JsonUtils; import ai.djl.util.PairList; +import ai.djl.util.RandomUtils; import com.google.gson.JsonObject; @@ -80,6 +81,11 @@ public void run() { batch.setProperties(req.input.getProperties()); } batch.add(prefix, req.getRequest()); + String seed = req.getSeed(); + if (seed != null) { + String seedPrefix = "batch_" + i + ".seed"; + batch.add(seedPrefix, req.seed); + } } batch.addProperty("batch_size", String.valueOf(size)); @@ -123,7 +129,7 @@ public Output addInput(Input input, int timeout) throws TranslateException { throw new TranslateException("Time out in: " + timeout); } } - Request req = new Request(input); + Request req = new Request(input, String.valueOf(RandomUtils.nextInt())); list.add(req); canRead.signal(); return req.output; @@ -147,12 +153,14 @@ private static final class Request { Output output; String nextToken; boolean last; + String seed; - Request(Input input) { + Request(Input input, String seed) { this.input = input; data = new ChunkedBytesSupplier(); output = new Output(); output.add(data); + this.seed = seed; } BytesSupplier getRequest() { @@ -162,6 +170,20 @@ BytesSupplier getRequest() { return input.getData(); } + /** + * Seed is required for LMI Dist for sampling for all processes in the MPI to generate the + * same token. NextTokenChooserParameters is constructed during first forward and preserved + * for all forward calls of the request. + * + * @return seed, only for first forward + */ + String getSeed() { + if (nextToken != null) { + return null; + } + return seed; + } + void addResponse(String json) { JsonObject element = JsonUtils.GSON.fromJson(json, JsonObject.class); last = element.get("last").getAsBoolean();