66 - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
77
88"""
9- import itertools
109import json
1110import time
1211from typing import Dict , Optional , Union
2625from petals .utils .disk_cache import DEFAULT_CACHE_DIR , allow_cache_reads , allow_cache_writes , free_disk_space_for
2726
2827logger = get_logger (__name__ )
28+ logger .setLevel ("DEBUG" )
2929
3030
3131class AutoBlockConfig :
@@ -75,11 +75,6 @@ def load_pretrained_block(
7575 cache_dir = cache_dir ,
7676 max_disk_space = max_disk_space ,
7777 )
78- state_dict = {
79- param_name [len (block_prefix ) :]: param
80- for param_name , param in state_dict .items ()
81- if param_name .startswith (block_prefix )
82- }
8378
8479 # dummy load, check that keys match
8580 report = block .load_state_dict (state_dict , strict = True )
@@ -107,7 +102,6 @@ def _load_state_dict_from_repo(
107102 use_auth_token : Optional [str ] = None ,
108103 cache_dir : str ,
109104 max_disk_space : Optional [int ] = None ,
110- min_backoff : float = 5 ,
111105) -> StateDict :
112106 index_file = get_file_from_repo (
113107 model_name , filename = "pytorch_model.bin.index.json" , use_auth_token = use_auth_token , cache_dir = cache_dir
@@ -118,59 +112,66 @@ def _load_state_dict_from_repo(
118112 filenames = {
119113 filename for param_name , filename in index ["weight_map" ].items () if param_name .startswith (block_prefix )
120114 }
121- if len (filenames ) > 1 :
122- raise RuntimeError (
123- f"Block { block_prefix } * is stored in { filenames } , but Petals can't load blocks divided into multiple files yet"
115+ if not filenames :
116+ raise RuntimeError (f"Block { block_prefix } * not found in the index: { index ['weight_map' ]} " )
117+ logger .debug (f"Loading { block_prefix } * from { filenames } " )
118+
119+ state_dict = {}
120+ for filename in filenames :
121+ shard_state_dict = _load_state_dict_from_file (
122+ model_name , filename , use_auth_token = use_auth_token , cache_dir = cache_dir , max_disk_space = max_disk_space
124123 )
125- [filename ] = filenames
126- logger .debug (f"Loading { block_prefix } * from { filename } " )
124+ shard_state_dict = {
125+ param_name [len (block_prefix ) :]: param
126+ for param_name , param in shard_state_dict .items ()
127+ if param_name .startswith (block_prefix )
128+ } # Remove unused parameters from memory
129+ state_dict .update (shard_state_dict )
130+ return state_dict
131+
127132
133+ def _load_state_dict_from_file (
134+ model_name : str ,
135+ filename : str ,
136+ * ,
137+ use_auth_token : Optional [str ],
138+ cache_dir : str ,
139+ max_disk_space : Optional [int ] = None ,
140+ delay : float = 30 ,
141+ ) -> StateDict :
128142 # First, try to find the weights locally
129143 try :
130144 with allow_cache_reads (cache_dir ):
131- return _load_state_dict_from_file (
145+ path = get_file_from_repo (
132146 model_name , filename , use_auth_token = use_auth_token , cache_dir = cache_dir , local_files_only = True
133147 )
148+ if path is not None :
149+ return torch .load (path , map_location = "cpu" )
134150 except Exception :
135151 logger .debug (
136- f"Failed to load block { block_prefix } * from cache, proceeding to downloading the block " , exc_info = True
152+ f"Failed to load block { block_index } from cache. The block will be downloaded again " , exc_info = True
137153 )
138154
139155 # If not found, ensure that we have enough disk space to download them (maybe remove something)
140- for attempt_no in itertools . count () :
156+ while True :
141157 try :
142158 with allow_cache_writes (cache_dir ):
143159 url = hf_hub_url (model_name , filename )
144160 file_size = get_hf_file_metadata (url , token = use_auth_token ).size
145- gib = 1024 ** 3
146- logger . debug ( f"Shard size for { filename } : { file_size / gib :.2f } GiB" )
147-
148- free_disk_space_for ( model_name , file_size , cache_dir = cache_dir , max_disk_space = max_disk_space )
161+ if file_size is not None :
162+ free_disk_space_for ( model_name , file_size , cache_dir = cache_dir , max_disk_space = max_disk_space )
163+ else :
164+ logger . warning ( f"Failed to fetch size of file { filename } from repo { model_name } " )
149165
150- return _load_state_dict_from_file (
166+ path = get_file_from_repo (
151167 model_name , filename , use_auth_token = use_auth_token , cache_dir = cache_dir , local_files_only = False
152168 )
169+ if path is None :
170+ raise RuntimeError (f"File { filename } does not exist in repo { model_name } " )
171+ return torch .load (path , map_location = "cpu" )
153172 except Exception as e :
154- delay = min_backoff * (2 ** attempt_no )
155- logger .warning (
156- f"Failed to load block { block_prefix } * from HF Hub (retry in { delay :.0f} sec)" , exc_info = True
157- )
173+ logger .warning (f"Failed to load file { filename } from HF Hub (retry in { delay :.0f} sec)" , exc_info = True )
158174 time .sleep (delay )
159175
160176
161- def _load_state_dict_from_file (
162- model_name : str , filename : str , * , use_auth_token : Optional [str ], cache_dir : str , local_files_only : bool
163- ) -> StateDict :
164- path = get_file_from_repo (
165- model_name ,
166- filename = filename ,
167- use_auth_token = use_auth_token ,
168- cache_dir = cache_dir ,
169- local_files_only = local_files_only ,
170- )
171- if path is None :
172- raise RuntimeError (f"Failed to load file { filename } from repo { model_name } " )
173- return torch .load (path , map_location = "cpu" )
174-
175-
176177DTYPE_MAP = dict (bfloat16 = torch .bfloat16 , float16 = torch .float16 , float32 = torch .float32 , auto = "auto" )
0 commit comments