
Commit 35bf7fd

enable support for device map auto
1 parent 01aeae1 commit 35bf7fd

File tree: 4 files changed, +17 -6 lines changed

- README.md
- syncode/common.py
- syncode/infer.py
- syncode/language_model.py

README.md

Lines changed: 2 additions & 0 deletions

@@ -207,6 +207,7 @@ export HF_ACCESS_TOKEN="your_huggingface_api_key"
 
 - `task_id` (int, optional): Problem task id for selecting a problem from a Dataset.
 
+- `use_auto` (bool, optional): Use auto device mapping. Defaults to False.
 - `kwargs`(void, optional): Currently supported `kwargs` are `max_length`, `max_new_tokens`, `min_length`, `min_new_tokens`, `early_stopping`, `do_sample`, `num_beams`, `use_cache`, `temperature`, `top_k`, `top_p`, `num_return_sequences`, `pad_token_id`, and `eos_token_id`. Refer to the [HuggingFace Text Generation Documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation) for more information.
 
 
@@ -237,6 +238,7 @@ python3 syncode/infer.py
 --new_mask_store [True, False]
 --parser ["lr", "lalr"]
 --task_id [task_id]
+--use_auto [True, False]
 ```
 </details>
 
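For orientation, a hedged Python-API sketch of the newly documented option. Only `use_auto` comes from this commit; the `Syncode` import path, the `model`/`grammar` keyword names, and the `infer` call are assumptions about typical SynCode usage rather than facts from this diff.

```python
# Hedged sketch: enabling auto device mapping from the Python API.
# Only `use_auto` is introduced by this commit; the other names are assumptions.
from syncode import Syncode  # assumed public import path

syn = Syncode(
    model="bigcode/tiny_starcoder_py",  # illustrative checkpoint
    grammar="python",                   # illustrative grammar choice
    use_auto=True,                      # forwards device_map='auto' when loading the model
)
print(syn.infer("def add(a, b):"))      # assumed inference entry point
```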

syncode/common.py

Lines changed: 10 additions & 4 deletions

@@ -11,16 +11,22 @@
 HF_ACCESS_TOKEN = os.environ['HF_ACCESS_TOKEN'] if 'HF_ACCESS_TOKEN' in os.environ else None
 
 
-def load_model(model_name, device, quantize):
+def load_model(model_name, device, quantize, use_auto = False):
     if model_name == 'test':
         model = AutoModelForCausalLM.from_pretrained('bigcode/tiny_starcoder_py').to(device)
     elif model_name == 'test-instruct':
         model = AutoModelForCausalLM.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
     else:
-        if (quantize):
-            model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
+        if use_auto:
+            if (quantize):
+                model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = 'auto').eval()
+            else:
+                model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = 'auto').eval()
         else:
-            model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
+            if (quantize):
+                model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
+            else:
+                model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
     return model
 
 def load_tokenizer(model_name):
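For reference, the `device_map='auto'` path added above defers placement to HuggingFace `accelerate` instead of calling `.to(device)`. A minimal standalone sketch of that mechanism, assuming `transformers` and `accelerate` are installed; the checkpoint is the small one the repo already uses for its 'test' path:

```python
# Minimal sketch of what the use_auto branch does: let accelerate place the weights.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "bigcode/tiny_starcoder_py",   # small checkpoint, also used by the repo's 'test' path
    torch_dtype=torch.bfloat16,
    device_map="auto",             # accelerate spreads layers over available GPUs/CPU
).eval()                           # no .to(device): placement is already decided

# Per-module placement chosen by accelerate (only set when device_map is used).
print(getattr(model, "hf_device_map", None))
print(model.device)                # device holding the first parameters, e.g. cuda:0
```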

syncode/infer.py

Lines changed: 4 additions & 1 deletion

@@ -50,6 +50,7 @@ def __init__(
         parser: Literal["lr", "lalr"] = "lalr",
         seed: Optional[int] = None,
         opp: bool = True,
+        use_auto: bool = False,
         **kwargs
     ):
         # Check inputs
@@ -85,7 +86,7 @@ def __init__(
         self.grammar = Grammar(grammar) if self._is_grammar_mode() else None
 
         # Load model and tokenizer
-        model = common.load_model(self.model_name, device, quantize)
+        model = common.load_model(self.model_name, device, quantize, use_auto)
         tokenizer = common.load_tokenizer(self.model_name)
 
         # Initialize grammar decoder if needed
@@ -259,6 +260,7 @@ def main(
     parse_output_only: bool = True,
     prompt_type: str = 'original',
     format_tabs: bool = False,
+    use_auto: bool = False,
     **kwargs
     ):
     """Run Syncode with the specified configuration.
@@ -309,6 +311,7 @@ def main(
         seed=seed,
         opp=opp,
         parse_output_only=parse_output_only,
+        use_auto=use_auto,
         **kwargs
     )
 
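A quick usage sketch of the flag as it reaches the loader. The `load_model` signature matches the change in syncode/common.py above; the import path and checkpoint name are illustrative assumptions.

```python
# Hedged sketch: calling the updated loader directly with the new flag.
from syncode import common  # assumed import path for syncode/common.py

model = common.load_model(
    "bigcode/tiny_starcoder_py",  # illustrative checkpoint
    device="cuda",                # not used for placement when use_auto=True
    quantize=True,                # takes the bfloat16 branch, as in the diff
    use_auto=True,                # passes device_map='auto' instead of .to(device)
)
print(model.device)
```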

syncode/language_model.py

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ def __init__(
         self.prompt_template = prompt_template
         self.model: PreTrainedModel = model
         self.tokenizer = tokenizer
-        self.device = device
+        self.device = self.model.device
         self.best_of = best_of
         self._before_prediction_hook = before_prediction_hook
         self.logits_processor = grammar_decoder
