Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unused device parameter from all rolling batch classes #1538

Merged
merged 1 commit into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions engines/python/setup/djl_python/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,8 @@ def initialize(self, properties: dict):
}
if "output_formatter" in properties:
kwargs["output_formatter"] = properties.get("output_formatter")
self.rolling_batch = DeepSpeedRollingBatch(self.model,
self.properties.device,
properties, **kwargs)
self.rolling_batch = DeepSpeedRollingBatch(self.model, properties,
**kwargs)
else:
self.create_model_pipeline()
self.logger.info(
Expand Down
4 changes: 2 additions & 2 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def initialize(self, properties: dict):
self.hf_configs.rolling_batch.value, self.hf_configs.is_mpi,
self.model_config)
self.rolling_batch = _rolling_batch_cls(
self.hf_configs.model_id_or_path, self.hf_configs.device,
properties, **self.hf_configs.kwargs)
self.hf_configs.model_id_or_path, properties,
**self.hf_configs.kwargs)
self.initialized = True
return
elif is_streaming_enabled(self.hf_configs.enable_streaming):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@

class DeepSpeedRollingBatch(RollingBatch):

def __init__(self, model: InferenceEngine, device, properties, **kwargs):
def __init__(self, model: InferenceEngine, properties, **kwargs):
"""
Initializes the LmiDistRollingBatch.
Initializes the DeepSpeedRollingBatch.

:param model_id_or_path: model id or path
:param device: model loaded device
:param properties: other properties of the model, such as decoder strategy
:param kwargs passed while loading the model
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@

class LmiDistRollingBatch(RollingBatch):

def __init__(self, model_id_or_path, device, properties, **kwargs):
def __init__(self, model_id_or_path, properties, **kwargs):
"""
Initializes the LmiDistRollingBatch.

:param model_id_or_path: model id or path
:param device: model loaded device
:param properties: other properties of the model, such as decoder strategy
:param kwargs passed while loading the model
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@ def enable_flash():

class SchedulerRollingBatch(RollingBatch):

def __init__(self, model_id_or_path, device, properties, **kwargs):
def __init__(self, model_id_or_path, properties, **kwargs):
"""
Initializes the rolling batch scheduler.

:param model_id_or_path: model id or path
:param device: model loaded device
:param properties: other properties of the model, such as decoder strategy
:param kwargs passed while loading the model
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class TRTLLMRollingBatch(RollingBatch):

def __init__(self, model_id_or_path, device, properties, **kwargs):
def __init__(self, model_id_or_path, properties, **kwargs):
"""
Initializes the TRTLLMRollingBatch.
:param model_id_or_path: model id or path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
class VLLMRollingBatch(RollingBatch):

# TODO: Make properties is the only parameter, after refactoring all rolling batch handlers
def __init__(self, model_id_or_path, device, properties, **kwargs):
def __init__(self, model_id_or_path, properties, **kwargs):
"""
Initializes the VLLMRollingBatch.
:param properties: other properties of the model, such as decoder strategy
Expand Down
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def initialize(self, properties: dict):
self.trt_configs = TensorRtLlmProperties(**properties)

self.rolling_batch = TRTLLMRollingBatch(
self.trt_configs.model_id_or_path, None, properties, **properties)
self.trt_configs.model_id_or_path, properties, **properties)
self.initialized = True
return

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def init_rolling_batch(rolling_batch_type: str, model_id: str,
else:
rolling_batcher_cls = get_rolling_batch_class_from_str(
rolling_batch_type)
return rolling_batcher_cls(model_id, device, properties, **properties)
return rolling_batcher_cls(model_id, properties, **properties)


def print_rank0(content):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@

# ===================== lmi ============================
print("=========== before =========")
rolling_batch = SchedulerRollingBatch(model_id, device, properties)
# rolling_batch = LmiDistRollingBatch(model_id, device, properties)
rolling_batch = SchedulerRollingBatch(model_id, properties)
# rolling_batch = LmiDistRollingBatch(model_id, properties)
rolling_batch.output_formatter = None
print("reach here")

Expand Down
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/tests/test_lmi_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_models(self):
device = int(os.environ.get("RANK", 0))
properties["device"] = int(os.environ.get("RANK", 0))

rolling_batch = LmiDistRollingBatch(model_id, device, properties)
rolling_batch = LmiDistRollingBatch(model_id, properties)
rolling_batch.output_formatter = None

gen = Generator(rolling_batch=rolling_batch)
Expand Down
Loading