langchain_huggingface: Fix multiple GPU usage bug in from_model_id function #23628
base: master
@@ -96,7 +96,17 @@ def from_model_id(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )

        if device_map is not None:
            if device is not None:
                logger.warning(
                    "Both `device` and `device_map` are specified. "
                    "`device` will override `device_map`. "
                    "You will most likely encounter unexpected behavior."
                    "Please remove `device` and keep "
                    "`device_map`."
                )
            model_kwargs["device_map"] = device_map
            device = None
Same reason as in the comment above: this line was a temporary fix. If the default value of `device` is changed, this line is no longer needed.

        _model_kwargs = model_kwargs or {}
        tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)

@@ -219,7 +229,6 @@ def from_model_id(
            model=model,
            tokenizer=tokenizer,
            device=device,
Would it make sense to keep this passthrough in case transformers uses it in the future?

Similar reason as above. There is also a similar parameter validation to check if both …
            device_map=device_map,
            batch_size=batch_size,
            model_kwargs=_model_kwargs,
            **_pipeline_kwargs,

@@ -262,7 +271,6 @@ def _generate(
        text_generations: List[str] = []
        pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
        skip_prompt = kwargs.get("skip_prompt", False)
kenchanLOL marked this conversation as resolved.

        for i in range(0, len(prompts), self.batch_size):
            batch_prompts = prompts[i : i + self.batch_size]
If we simply escalate this from a warning to an error, it would always be triggered whenever `device_map` is set, because the default value of `device` is -1 (CPU) in the current implementation. I checked the transformers documentation and found that the default value of `device` for `pipeline` has changed from -1 to None. I tested both values and they behave the same, loading the model onto the CPU; the only difference is the log message printed in the terminal.

Although the log message for `device=-1` looks better, -1 is a legacy value that could cause conflicts, so updating the default to None would align the behavior with the latest practice and prevent unexpected errors in the future. So I think we should also change the default value of `device` on line 77 when changing this from a warning to an error.
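
To make the proposal concrete, here is a minimal sketch of how the check could look once `device` defaults to None and the warning becomes an error. The helper name `resolve_device_args` and its return shape are hypothetical and used only for illustration; they are not part of this PR or of transformers, and the error message is shortened from the one in the diff.

```python
from typing import Any, Dict, Optional, Tuple


def resolve_device_args(
    device: Optional[int] = None,
    device_map: Optional[str] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[int], Dict[str, Any]]:
    """Hypothetical helper: validate `device` / `device_map` before loading."""
    # Copy so the caller's dict is never mutated; this also covers
    # the case where model_kwargs is None.
    resolved_kwargs = dict(model_kwargs or {})

    if device_map is not None:
        if device is not None:
            # With a default of None (instead of the legacy -1), this only
            # fires when the caller explicitly passed both arguments.
            raise ValueError(
                "Both `device` and `device_map` are specified. "
                "Please remove `device` and keep `device_map`."
            )
        resolved_kwargs["device_map"] = device_map

    return device, resolved_kwargs


# Setting only device_map no longer trips the check.
device, kwargs = resolve_device_args(device_map="auto")
```

With the old default of -1, the inner `device is not None` test would be true on every call that sets `device_map`, which is why the default has to change together with the switch from warning to error.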