
Commit 053abeb

Fix petals.client, petals.models
1 parent dedf09e commit 053abeb

4 files changed: +25 −49 lines changed


setup.cfg

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ install_requires =
     bitsandbytes==0.38.0.post2
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
+    tokenizers>=0.13.3
     transformers>=4.30.1,<5.0.0
     speedtest-cli==2.1.3
     hivemind==1.1.8

src/petals/client/lm_head.py

Lines changed: 19 additions & 38 deletions
@@ -1,13 +1,14 @@
 import dataclasses
 import platform
-from typing import Union
+from typing import Optional, Union

 import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from hivemind import get_logger
 from torch import nn
+from transformers import PretrainedConfig

 logger = get_logger(__name__)

@@ -21,15 +22,16 @@ class LMHeadConfig:


 class LMHead(nn.Module):
-    """
-    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
-    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
-    """
-
-    def __init__(self, config: LMHeadConfig, word_embeddings: nn.Embedding):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
-        self.word_embeddings = word_embeddings
+
+        if not config.tie_word_embeddings:
+            self.weight = nn.Parameter(torch.zeros((config.vocab_size, config.hidden_size), requires_grad=False))
+        else:
+            self.weight = None  # Will be set to get_input_embeddings().weight during loading the model
+        self.bias = None
+        self.in_features = config.hidden_size  # Similar to nn.Linear attributes
+        self.out_features = config.vocab_size

         self.use_chunked_forward = config.use_chunked_forward
         if self.use_chunked_forward == "auto":
@@ -45,35 +47,17 @@ def __init__(self, config: LMHeadConfig, word_embeddings: nn.Embedding):
         self.chunked_forward_step = config.chunked_forward_step
         self._bf16_warning_shown = False

-    @property
-    def in_features(self) -> int:
-        return self.word_embeddings.num_embeddings
-
-    @property
-    def out_features(self) -> int:
-        return self.word_embeddings.embedding_dim
-
-    @property
-    def weight(self):
-        return self.word_embeddings.weight
-
-    @property
-    def bias(self):
-        return None
-
     def forward(self, hidden_states):
-        word_embeddings = self.word_embeddings.weight
-
         if (
-            word_embeddings.dtype in [torch.float16, torch.bfloat16]
-            and word_embeddings.device.type == "cpu"
+            self.weight.dtype in [torch.float16, torch.bfloat16]
+            and self.weight.device.type == "cpu"
             and self.use_chunked_forward
         ):
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
-            hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings)
+            hidden_states = hidden_states.to(self.weight.dtype)
+            lm_logits = F.linear(hidden_states, self.weight)
         return lm_logits

     def chunked_forward(self, hidden_states):
@@ -83,20 +67,17 @@ def chunked_forward(self, hidden_states):
         assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"

         if not self._bf16_warning_shown:
-            if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
+            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
                 logger.warning(
                     "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
                     "Consider loading the model with torch_dtype='float32'"
                 )
                 self._bf16_warning_shown = True

-        word_embeddings = self.word_embeddings.weight
-        num_embeddings = self.word_embeddings.num_embeddings
-
         hidden_states = hidden_states.float()
-        output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
+        output = torch.empty(*hidden_states.shape[:-1], self.out_features)

-        for i in range(0, num_embeddings, self.chunked_forward_step):
-            chunk = word_embeddings[i : i + self.chunked_forward_step].float()
+        for i in range(0, self.out_features, self.chunked_forward_step):
+            chunk = self.weight[i : i + self.chunked_forward_step].float()
             output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
         return output
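
For context, a minimal standalone sketch of the chunked projection pattern that chunked_forward uses above: the vocabulary dimension is processed in slices so that only one small fp32 copy of the fp16/bf16 weight exists at a time. The sizes and step below are illustrative, not taken from the commit.

import torch
import torch.nn.functional as F

hidden_size, vocab_size, step = 64, 1000, 256                         # illustrative sizes
weight = torch.randn(vocab_size, hidden_size, dtype=torch.bfloat16)   # stands in for the (tied) embedding weight
hidden_states = torch.randn(2, 5, hidden_size)                        # (batch, seq, hidden), already fp32

output = torch.empty(*hidden_states.shape[:-1], vocab_size)
for i in range(0, vocab_size, step):
    chunk = weight[i : i + step].float()                              # upcast one vocab slice at a time
    output[..., i : i + step] = F.linear(hidden_states, chunk)

# Matches projecting with the fully upcast weight, without materializing it all at once
assert torch.allclose(output, F.linear(hidden_states, weight.float()), atol=1e-4)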

src/petals/models/bloom/model.py

Lines changed: 3 additions & 6 deletions
@@ -96,7 +96,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
     _keys_to_ignore_on_load_missing = (
         BloomForCausalLM._keys_to_ignore_on_load_missing
         + DistributedBloomModel._keys_to_ignore_on_load_missing
-        + [r"^lm_head.word_embeddings\.weight$"]  # Missing since they are shared with input embeddings
+        + [r"^lm_head\."]  # Missing since they are shared with input embeddings
     )
     _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected

@@ -105,16 +105,13 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
         self.transformer = DistributedBloomModel(config)
-        self.lm_head = LMHead(config, self.transformer.word_embeddings)
+        self.lm_head = LMHead(config)

         # Initialize weights and apply final processing
         self.post_init()

     def get_output_embeddings(self):
-        return self.lm_head.word_embeddings
-
-    def set_output_embeddings(self, new_embeddings: torch.Tensor):
-        self.lm_head.word_embeddings = new_embeddings
+        return self.lm_head


 class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
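
For context, a rough sketch (not the transformers source) of why returning the head itself from get_output_embeddings() is enough for the tied case: the usual weight-tying step points the output module's weight at the input embedding matrix, so lm_head.weight is filled in without allocating an extra tensor. TinyModel and TinyLMHead below are hypothetical stand-ins.

import torch
from torch import nn

class TinyLMHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = None            # filled in when weights are tied, as in the new LMHead

class TinyModel(nn.Module):
    def __init__(self, vocab_size=16, hidden_size=8):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = TinyLMHead()

    def get_input_embeddings(self):
        return self.word_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def tie_weights(self):
        # mimics the idea of sharing one tensor between input and output embeddings
        self.get_output_embeddings().weight = self.get_input_embeddings().weight

model = TinyModel()
model.tie_weights()
assert model.lm_head.weight is model.word_embeddings.weight   # no extra copy is created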

src/petals/models/llama/model.py

Lines changed: 2 additions & 5 deletions
@@ -115,16 +115,13 @@ class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Ll
     def __init__(self, config: DistributedLlamaConfig):
         LlamaPreTrainedModel.__init__(self, config)
         self.model = DistributedLlamaModel(config)
-        self.lm_head = LMHead(config, nn.Embedding(config.vocab_size, config.hidden_size))
+        self.lm_head = LMHead(config)

         # Initialize weights and apply final processing
         self.post_init()

     def get_output_embeddings(self):
-        return self.lm_head.word_embeddings
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head.word_embeddings = new_embeddings
+        return self.lm_head

     @property
     def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
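
For the untied case (LLaMA configs usually set tie_word_embeddings=False), the new head keeps a placeholder weight that the checkpoint loader overwrites. A minimal, hypothetical illustration with made-up sizes; it freezes the parameter directly rather than reproducing the commit's exact constructor call.

import torch
from torch import nn

class UntiedHead(nn.Module):
    def __init__(self, vocab_size=32, hidden_size=8):
        super().__init__()
        # frozen placeholder, later overwritten with real values from the checkpoint
        self.weight = nn.Parameter(torch.zeros(vocab_size, hidden_size), requires_grad=False)

head = UntiedHead()
head.load_state_dict({"weight": torch.randn(32, 8)})   # stand-in for real lm_head weights
print(head.weight.requires_grad)                       # False: stays out of training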
