Skip to content

Commit 0243367

Browse files
committed
enable support for device map auto
1 parent eb9a4a7 commit 0243367

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ export HF_ACCESS_TOKEN="your_huggingface_api_key"
207207

208208
- `task_id` (int, optional): Problem task id for selecting a problem from a Dataset.
209209

210-
- `use_auto_device_mapping` (bool, optional): Use auto device mapping. Defaults to False.
210+
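- `device_map` (str, optional): Device map to use when loading the model, e.g. `auto` to let the loader place the model across available devices. Defaults to None (no device map is applied).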
211211
- `kwargs` (void, optional): Currently supported `kwargs` are `max_length`, `max_new_tokens`, `min_length`, `min_new_tokens`, `early_stopping`, `do_sample`, `num_beams`, `use_cache`, `temperature`, `top_k`, `top_p`, `num_return_sequences`, `pad_token_id`, and `eos_token_id`. Refer to the [HuggingFace Text Generation Documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation) for more information.
212212

213213

@@ -238,7 +238,7 @@ python3 syncode/infer.py
238238
--new_mask_store [True, False]
239239
--parser ["lr", "lalr"]
240240
--task_id [task_id]
241-
--use_auto_device_mapping [True, False]
241+
--device_map [device_map]
242242
```
243243
</details>
244244

syncode/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
HF_ACCESS_TOKEN = os.environ['HF_ACCESS_TOKEN'] if 'HF_ACCESS_TOKEN' in os.environ else None
1212

1313

14-
def load_model(model_name, device, quantize, use_auto_device_mapping = False):
14+
def load_model(model_name, device, quantize, device_map = None):
1515
if model_name == 'test':
1616
model = AutoModelForCausalLM.from_pretrained('bigcode/tiny_starcoder_py').to(device)
1717
elif model_name == 'test-instruct':
1818
model = AutoModelForCausalLM.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
1919
else:
20-
if use_auto_device_mapping:
20+
if device_map is not None:
2121
if (quantize):
22-
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = 'auto').eval()
22+
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = device_map).eval()
2323
else:
24-
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = 'auto').eval()
24+
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = device_map).eval()
2525
else:
2626
if (quantize):
2727
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)

syncode/infer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(
5050
parser: Literal["lr", "lalr"] = "lalr",
5151
seed: Optional[int] = None,
5252
opp: bool = True,
53-
use_auto_device_mapping: bool = False,
53+
device_map: Optional[str] = None,
5454
**kwargs
5555
):
5656
# Check inputs
@@ -86,7 +86,7 @@ def __init__(
8686
self.grammar = Grammar(grammar) if self._is_grammar_mode() else None
8787

8888
# Load model and tokenizer
89-
model = common.load_model(self.model_name, device, quantize, use_auto_device_mapping)
89+
model = common.load_model(self.model_name, device, quantize, device_map)
9090
tokenizer = common.load_tokenizer(self.model_name)
9191

9292
# Initialize grammar decoder if needed
@@ -260,7 +260,7 @@ def main(
260260
parse_output_only: bool = True,
261261
prompt_type: str = 'original',
262262
format_tabs: bool = False,
263-
use_auto_device_mapping: bool = False,
263+
device_map: Optional[str] = None,
264264
**kwargs
265265
):
266266
"""Run Syncode with the specified configuration.
@@ -311,7 +311,7 @@ def main(
311311
seed=seed,
312312
opp=opp,
313313
parse_output_only=parse_output_only,
314-
use_auto_device_mapping=use_auto_device_mapping,
314+
device_map=device_map,
315315
**kwargs
316316
)
317317

0 commit comments

Comments
 (0)