Skip to content

Commit

Permalink
use dict get to better work with load_in_4bit conditions
Browse files Browse the repository at this point in the history
Co-authored-by: Niklas Muennighoff <n.muennighoff@gmail.com>
  • Loading branch information
kabachuha and Muennighoff authored Jun 3, 2024
1 parent 4ef1a60 commit ea17214
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion gritlm/gritlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
if self.embed_eos:
assert self.embed_eos in self.tokenizer.vocab, f"EOS token {self.embed_eos} not in vocab"
self.model.eval()
if not("device_map" in kwargs) and not("load_in_4bit" in kwargs and kwargs["load_in_4bit"]) and not("load_in_8bit" in kwargs and kwargs["load_in_8bit"]):
if not("device_map" in kwargs) and not(kwargs.get("load_in_4bit", False)) and not(kwargs.get("load_in_8bit", False)):
self.model.to(self.device)
# Parallelize embedding model
if mode == 'embedding':
Expand Down

0 comments on commit ea17214

Please sign in to comment.