Skip to content

Commit

Permalink
Set sharded to false when tensor parallel degree (tp) is 1
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Jun 28, 2023
1 parent bfb5466 commit 4829afd
Showing 1 changed file with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)

import torch
import logging

ARCHITECTURE_2_BATCH_CLS = {
"RWForCausalLM": FlashCausalLMBatch,
Expand Down Expand Up @@ -57,6 +56,7 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
"""

super().__init__(device)
self.properties = properties
self.batch_cls = None
self._init_model(kwargs, model_id_or_path)
self.batch_id_counter = 0
Expand All @@ -66,9 +66,10 @@ def _init_model(self, kwargs, model_id_or_path):
self.config = AutoConfig.from_pretrained(model_id_or_path,
**kwargs)
self.batch_cls = get_batch_cls_from_architecture(self.config.architectures[0])
sharded = int(self.properties.get("tensor_parallel_degree", "-1")) > 1
self.model = get_model(model_id_or_path,
revision=None,
sharded=True,
sharded=sharded,
quantize=None,
trust_remote_code=kwargs.get("trust_remote_code"))

Expand Down

0 comments on commit 4829afd

Please sign in to comment.