Commit b094f17

Support blocks sharded across multiple files
1 parent: d0a6a00

File tree

2 files changed: +46 −42 lines

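For context on what "sharded" means here: large checkpoints on the Hugging Face Hub are split into several `pytorch_model-XXXXX-of-XXXXX.bin` shard files, and a `pytorch_model.bin.index.json` file maps every parameter name to the shard that stores it. A single transformer block's parameters can therefore land in two different files, which the previous loader refused to handle. Below is a hypothetical excerpt of such an index; the parameter names and shard filenames are illustrative only:

```python
# Hypothetical excerpt of pytorch_model.bin.index.json: block 12's
# parameters straddle two shard files, so loading the block now
# requires reading (parts of) both.
index = {
    "weight_map": {
        "h.12.self_attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
        "h.12.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00003.bin",
    }
}
```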

src/petals/server/from_pretrained.py

Lines changed: 40 additions & 39 deletions
```diff
@@ -6,7 +6,6 @@
  - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
 
 """
-import itertools
 import json
 import time
 from typing import Dict, Optional, Union
@@ -26,6 +25,7 @@
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 
 logger = get_logger(__name__)
+logger.setLevel("DEBUG")
 
 
 class AutoBlockConfig:
@@ -75,11 +75,6 @@ def load_pretrained_block(
         cache_dir=cache_dir,
         max_disk_space=max_disk_space,
     )
-    state_dict = {
-        param_name[len(block_prefix) :]: param
-        for param_name, param in state_dict.items()
-        if param_name.startswith(block_prefix)
-    }
 
     # dummy load, check that keys match
     report = block.load_state_dict(state_dict, strict=True)
@@ -107,7 +102,6 @@ def _load_state_dict_from_repo(
     use_auth_token: Optional[str] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
-    min_backoff: float = 5,
 ) -> StateDict:
     index_file = get_file_from_repo(
         model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir
@@ -118,59 +112,66 @@ def _load_state_dict_from_repo(
     filenames = {
         filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
     }
-    if len(filenames) > 1:
-        raise RuntimeError(
-            f"Block {block_prefix}* is stored in {filenames}, but Petals can't load blocks divided into multiple files yet"
+    if not filenames:
+        raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
+    logger.debug(f"Loading {block_prefix}* from {filenames}")
+
+    state_dict = {}
+    for filename in filenames:
+        shard_state_dict = _load_state_dict_from_file(
+            model_name, filename, use_auth_token=use_auth_token, cache_dir=cache_dir, max_disk_space=max_disk_space
         )
-    [filename] = filenames
-    logger.debug(f"Loading {block_prefix}* from {filename}")
+        shard_state_dict = {
+            param_name[len(block_prefix) :]: param
+            for param_name, param in shard_state_dict.items()
+            if param_name.startswith(block_prefix)
+        }  # Remove unused parameters from memory
+        state_dict.update(shard_state_dict)
+    return state_dict
+
 
+def _load_state_dict_from_file(
+    model_name: str,
+    filename: str,
+    *,
+    use_auth_token: Optional[str],
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+    delay: float = 30,
+) -> StateDict:
     # First, try to find the weights locally
     try:
         with allow_cache_reads(cache_dir):
-            return _load_state_dict_from_file(
+            path = get_file_from_repo(
                 model_name, filename, use_auth_token=use_auth_token, cache_dir=cache_dir, local_files_only=True
             )
+            if path is not None:
+                return torch.load(path, map_location="cpu")
     except Exception:
         logger.debug(
-            f"Failed to load block {block_prefix}* from cache, proceeding to downloading the block", exc_info=True
+            f"Failed to load file {filename} from cache. The file will be downloaded again", exc_info=True
         )
 
     # If not found, ensure that we have enough disk space to download them (maybe remove something)
-    for attempt_no in itertools.count():
+    while True:
         try:
             with allow_cache_writes(cache_dir):
                 url = hf_hub_url(model_name, filename)
                 file_size = get_hf_file_metadata(url, token=use_auth_token).size
-                gib = 1024**3
-                logger.debug(f"Shard size for {filename}: {file_size / gib:.2f} GiB")
-
-                free_disk_space_for(model_name, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                if file_size is not None:
+                    free_disk_space_for(model_name, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                else:
+                    logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")
 
-                return _load_state_dict_from_file(
+                path = get_file_from_repo(
                     model_name, filename, use_auth_token=use_auth_token, cache_dir=cache_dir, local_files_only=False
                 )
+                if path is None:
+                    raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
+                return torch.load(path, map_location="cpu")
         except Exception as e:
-            delay = min_backoff * (2**attempt_no)
-            logger.warning(
-                f"Failed to load block {block_prefix}* from HF Hub (retry in {delay:.0f} sec)", exc_info=True
-            )
+            logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
             time.sleep(delay)
 
 
-def _load_state_dict_from_file(
-    model_name: str, filename: str, *, use_auth_token: Optional[str], cache_dir: str, local_files_only: bool
-) -> StateDict:
-    path = get_file_from_repo(
-        model_name,
-        filename=filename,
-        use_auth_token=use_auth_token,
-        cache_dir=cache_dir,
-        local_files_only=local_files_only,
-    )
-    if path is None:
-        raise RuntimeError(f"Failed to load file {filename} from repo {model_name}")
-    return torch.load(path, map_location="cpu")
-
-
 DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
```
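To see the new loading path end to end, here is a minimal, self-contained sketch of the shard-merging pattern that `_load_state_dict_from_repo` now follows. It assumes the shard files are already on local disk; `load_block_state_dict` and its arguments are illustrative stand-ins, not part of the Petals API:

```python
import json

import torch


def load_block_state_dict(index_path: str, block_prefix: str) -> dict:
    """Merge one block's parameters from a sharded checkpoint (simplified sketch)."""
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]

    # Collect every shard file that stores at least one parameter of this block
    filenames = {fn for name, fn in weight_map.items() if name.startswith(block_prefix)}
    if not filenames:
        raise RuntimeError(f"Block {block_prefix}* not found in the index")

    # Load each shard, keep only this block's tensors (with the prefix
    # stripped), and merge them into a single state dict. Dropping the
    # non-matching keys right away frees the rest of the shard from memory.
    state_dict = {}
    for filename in filenames:
        shard = torch.load(filename, map_location="cpu")
        state_dict.update(
            {name[len(block_prefix):]: tensor for name, tensor in shard.items() if name.startswith(block_prefix)}
        )
    return state_dict
```

Two side effects of the refactor are worth noting: the prefix-stripping comprehension moved from `load_pretrained_block` into the per-shard loop, so at most one shard's unused tensors are held in memory at a time, and the download retry now sleeps a fixed `delay` (30 s by default) instead of the old exponential backoff (`min_backoff * (2**attempt_no)`), which is why the `itertools` import could be dropped.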

src/petals/utils/disk_cache.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -9,6 +9,7 @@
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
+logger.setLevel("DEBUG")
 
 DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals"))
 
@@ -57,13 +58,16 @@ def free_disk_space_for(
     available_space = shutil.disk_usage(cache_dir).free - os_quota
     if max_disk_space is not None:
         available_space = min(available_space, max_disk_space - occupied_space)
+
+    gib = 1024**3
+    logger.debug(f"Disk space: required {size / gib:.1f} GiB, available {available_space / gib:.1f} GiB")
     if size <= available_space:
         return
 
     revisions = [revision for repo in model_repos for revision in repo.revisions]
     revisions.sort(key=lambda rev: max([item.blob_last_accessed for item in rev.files], default=rev.last_modified))
 
-    # Remove as few least recently used blocks as possible
+    # Remove as few least recently used shards as possible
     pending_removal = []
     freed_space = 0
     extra_space_needed = size - available_space
@@ -73,9 +77,8 @@ def free_disk_space_for(
         if freed_space >= extra_space_needed:
             break
 
-    gib = 1024**3
     if pending_removal:
-        logger.info(f"Removing {len(pending_removal)} blocks to free {freed_space / gib:.1f} GiB of disk space")
+        logger.info(f"Removing {len(pending_removal)} shards to free {freed_space / gib:.1f} GiB of disk space")
         delete_strategy = cache_info.delete_revisions(*pending_removal)
         delete_strategy.execute()
 
```
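The eviction policy that `free_disk_space_for` applies (unchanged by this commit apart from the log wording) fits in a few lines: sort cached revisions by their most recent access and drop the least recently used ones until the shortfall is covered. The `Revision` dataclass below is a hypothetical stand-in for `huggingface_hub`'s cache-info objects:

```python
from dataclasses import dataclass


@dataclass
class Revision:
    size_on_disk: int     # bytes
    last_accessed: float  # unix timestamp of the most recent blob access


def pick_revisions_to_evict(revisions: list[Revision], extra_space_needed: int) -> list[Revision]:
    # Least recently used first, so shards still in active use survive longest
    revisions = sorted(revisions, key=lambda rev: rev.last_accessed)
    pending_removal, freed_space = [], 0
    for rev in revisions:
        pending_removal.append(rev)
        freed_space += rev.size_on_disk
        if freed_space >= extra_space_needed:
            break
    return pending_removal
```

As in the real function, this removes as few least recently used shards as possible: it stops as soon as the accumulated `size_on_disk` covers the space still needed.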
