
Commit 1ef0656

Support loading LLaMA and BLOOM blocks from existing repos
1 parent: 675bacb

File tree

10 files changed: +258 / -76 lines


src/petals/cli/run_server.py

Lines changed: 0 additions & 3 deletions
@@ -89,9 +89,6 @@ def main():
     parser.add_argument('--alloc_timeout', type=float, default=60,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
-    parser.add_argument('--revision', type=str, default='main',
-                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
-                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")

     parser.add_argument('--throughput',
                         type=lambda value: value if value in ['auto', 'eval'] else float(value),
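
The type= lambda on --throughput lets the flag carry either a keyword or a number: the literal strings 'auto' and 'eval' pass through as strings, anything else must parse as a float. A standalone sketch of that parsing behavior (the default value below is assumed for the sketch, not taken from the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--throughput',
                    type=lambda value: value if value in ['auto', 'eval'] else float(value),
                    default='auto')  # default assumed for this sketch

print(parser.parse_args(['--throughput', 'eval']).throughput)  # 'eval' (kept as a string)
print(parser.parse_args(['--throughput', '12.5']).throughput)  # 12.5 (parsed as a float)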

src/petals/llama/__init__.py

Whitespace-only changes.

src/petals/llama/block.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+"""
+LLaMA intermediate layer
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+import os
+from typing import Optional, Tuple
+
+import torch.nn.quantized.dynamic.modules.linear
+import transformers
+from packaging import version
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+# if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
+#     assert (
+#         version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+#     ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0"
+
+
+class WrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs
+    ):
+        return super().forward(hidden_states, *args, past_key_value=layer_past, **kwargs)
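
The wrapper exists so the server can keep passing its attention cache under the layer_past name it already uses for BLOOM blocks, while the upstream LlamaDecoderLayer expects past_key_value. A minimal usage sketch, assuming a transformers version contemporary with this commit (where LlamaDecoderLayer takes only a config) and a deliberately tiny, made-up config:

import torch
from transformers import LlamaConfig

from petals.llama.block import WrappedLlamaBlock

# Tiny illustrative config, not a real checkpoint.
config = LlamaConfig(hidden_size=256, intermediate_size=512, num_attention_heads=4, num_hidden_layers=1)
block = WrappedLlamaBlock(config)

seq_len = 8
hidden_states = torch.randn(1, seq_len, config.hidden_size)
position_ids = torch.arange(seq_len).unsqueeze(0)
# Additive causal mask: 0 on/below the diagonal, a large negative value above it.
attn_mask = torch.triu(torch.full((1, 1, seq_len, seq_len), torch.finfo(torch.float32).min), diagonal=1)

# Petals-style call: the cache goes in as `layer_past` and is forwarded as `past_key_value`.
outputs = block(hidden_states, attention_mask=attn_mask, position_ids=position_ids, layer_past=None, use_cache=True)
hidden_states_out, present_key_value = outputs[0], outputs[-1]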

src/petals/llama/modeling_utils.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+"""
+PyTorch BLOOM model that implements several memory-efficient modes.
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+
+import platform
+
+import psutil
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from hivemind import get_logger
+from torch import nn
+from transformers import BloomConfig
+
+logger = get_logger(__name__)
+
+
+class LMHead(nn.Module):
+    """
+    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
+    In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
+    """
+
+    def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
+        super().__init__()
+        self.word_embeddings = word_embeddings
+
+        self.use_chunked_forward = config.use_chunked_forward
+        if self.use_chunked_forward == "auto":
+            if platform.machine() == "x86_64":
+                # Import of cpufeature may crash on non-x86_64 machines
+                from cpufeature import CPUFeature
+
+                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
+                # Otherwise, it's ~8x slower.
+                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            else:
+                self.use_chunked_forward = True
+        self.chunked_forward_step = config.chunked_forward_step
+        self._bf16_warning_shown = False
+
+    @property
+    def in_features(self) -> int:
+        return self.word_embeddings.num_embeddings
+
+    @property
+    def out_features(self) -> int:
+        return self.word_embeddings.embedding_dim
+
+    @property
+    def weight(self):
+        return self.word_embeddings.weight
+
+    @property
+    def bias(self):
+        return None
+
+    def forward(self, hidden_states):
+        word_embeddings = self.word_embeddings.weight
+
+        if (
+            word_embeddings.dtype in [torch.float16, torch.bfloat16]
+            and word_embeddings.device.type == "cpu"
+            and self.use_chunked_forward
+        ):
+            lm_logits = self.chunked_forward(hidden_states)
+        else:
+            # Switch dtype in case word_embeddings are fp16/bf16
+            hidden_states = hidden_states.to(word_embeddings.dtype)
+            lm_logits = F.linear(hidden_states, word_embeddings)
+        return lm_logits
+
+    def chunked_forward(self, hidden_states):
+        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
+        chunked_forward_step: provides trade-off between efficiency and extra memory consumption.
+        """
+        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
+
+        if not self._bf16_warning_shown:
+            if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
+                logger.warning(
+                    "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
+                    "Consider loading the model with torch_dtype='float32'"
+                )
+            self._bf16_warning_shown = True
+
+        word_embeddings = self.word_embeddings.weight
+        num_embeddings = self.word_embeddings.num_embeddings
+
+        hidden_states = hidden_states.float()
+        output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
+
+        for i in range(0, num_embeddings, self.chunked_forward_step):
+            chunk = word_embeddings[i : i + self.chunked_forward_step].float()
+            output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
+        return output
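
chunked_forward() exists because fp16/bf16 matmuls are very slow on CPUs without AVX512: it upcasts chunked_forward_step rows of the embedding matrix to fp32 at a time, so the logits are computed in fp32 without ever materializing a full fp32 copy of the vocabulary. A rough usage sketch; the config object here is a hypothetical stand-in carrying only the two Petals-specific fields that LMHead reads:

from types import SimpleNamespace

import torch
from torch import nn

from petals.llama.modeling_utils import LMHead

# Stand-in for the Petals model config: just the two attributes LMHead needs.
config = SimpleNamespace(use_chunked_forward=True, chunked_forward_step=4096)

word_embeddings = nn.Embedding(num_embeddings=32000, embedding_dim=512).to(torch.bfloat16)
lm_head = LMHead(config, word_embeddings)

hidden_states = torch.randn(1, 8, 512, dtype=torch.bfloat16)
logits = lm_head(hidden_states)  # routed through chunked_forward(): fp32 matmuls over 4096-row chunks
print(logits.shape)  # torch.Size([1, 8, 32000])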

src/petals/server/backend.py

Lines changed: 8 additions & 9 deletions
@@ -1,4 +1,3 @@
-"""Code for serving bloom blocks via hivemind-server"""
 from __future__ import annotations

 from collections import Counter
@@ -12,8 +11,7 @@
 from hivemind.utils import get_logger
 from tensor_parallel import TensorParallel
 from tensor_parallel.tensor_parallel import PerDeviceTensors
-from transformers import BloomConfig
-from transformers.models.bloom.modeling_bloom import BloomAttention
+from transformers import PretrainedConfig

 from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import Handle, MemoryCache
@@ -24,17 +22,17 @@


 class TransformerBackend(ModuleBackend):
-    """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
+    """A wrapper for a transformer block that can process requests for forward, backward and inference"""

-    def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
+    def __init__(self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
         self.memory_cache = memory_cache
         for name, param in self.module.named_parameters():
-            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
+            assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
         for name, buf in self.module.named_buffers():
-            assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
+            assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"

         max_batch_size = self.forward_pool.max_batch_size
         device = self.module.devices[self.module.output_device_index]
@@ -53,9 +51,10 @@ def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backen
         self.shard_num_heads = []
         for shard in self.module.module_shards:
             for submodule in shard.modules():
-                if isinstance(submodule, BloomAttention):
+                if isinstance(submodule, config.attn_class):
                     self.shard_num_heads.append(submodule.num_heads)
-        assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head
+        assert len(self.shard_num_heads) == len(self.module.devices)
+        assert sum(self.shard_num_heads) == config.n_head

         self.inference_schema = (
             (
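
With these changes the backend no longer hard-codes BLOOM classes: it asks the config which attention class to look for (config.attn_class) and, in block_utils.py below, which block class to instantiate (config.block_class). How those attributes get attached is not shown in this excerpt; presumably the model-specific config classes elsewhere in the commit set them, roughly along these lines (the class names are illustrative, not taken from the diff, and WrappedBloomBlock is assumed to still be importable from petals.bloom.block, as in the import that block_utils.py drops below):

from transformers import BloomConfig, LlamaConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from transformers.models.llama.modeling_llama import LlamaAttention

from petals.bloom.block import WrappedBloomBlock
from petals.llama.block import WrappedLlamaBlock


class DistributedBloomConfig(BloomConfig):  # illustrative name, not from the diff
    block_class = WrappedBloomBlock
    attn_class = BloomAttention


class DistributedLlamaConfig(LlamaConfig):  # illustrative name, not from the diff
    block_class = WrappedLlamaBlock
    attn_class = LlamaAttention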

src/petals/server/block_utils.py

Lines changed: 5 additions & 7 deletions
@@ -2,13 +2,11 @@

 import torch
 from accelerate import init_empty_weights
-from transformers import BloomConfig
+from transformers import PretrainedConfig

-from petals.bloom.block import WrappedBloomBlock

-
-def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
-    """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
+def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
+    """If dtype is "auto", resolves it using the config. Returns `dtype` intact otherwise."""

     if dtype == "auto" or dtype is None:
         dtype = config.torch_dtype
@@ -18,7 +16,7 @@ def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) ->


 def get_block_size(
-    config: BloomConfig,
+    config: PretrainedConfig,
     location: str,
     *,
     dtype: Optional[Union[str, torch.dtype]] = None,
@@ -31,7 +29,7 @@ def get_block_size(
     ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'

     with init_empty_weights(include_buffers=True):
-        block = WrappedBloomBlock(config)
+        block = config.block_class(config)
     n_params = sum(param.numel() for param in block.parameters())

     if location == "memory" and load_in_8bit:
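
Because the block class now comes from the config, the same meta-device size estimate works for any supported model. A rough sketch of the pattern used by get_block_size(..., location="memory"); the config values are illustrative, and setting block_class by hand stands in for what the Petals config classes are assumed to provide:

import torch
from accelerate import init_empty_weights
from transformers import BloomConfig

from petals.bloom.block import WrappedBloomBlock  # assumed still importable at this commit

config = BloomConfig(hidden_size=1024, n_head=16)  # illustrative values
config.block_class = WrappedBloomBlock  # normally attached by the Petals config class

# Build the block on the meta device so no real weights are allocated, then count parameters.
with init_empty_weights(include_buffers=True):
    block = config.block_class(config)

n_params = sum(param.numel() for param in block.parameters())
bytes_per_value = torch.finfo(torch.bfloat16).bits // 8  # 2 bytes per bf16 value
print(f"~{n_params * bytes_per_value / 2**20:.0f} MiB per block in bfloat16")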
