@@ -58,7 +58,7 @@ def load_text_data(self):
     def load_transformer_model(self):
         """Loads the GPT model with pre-trained weights."""
         ckpt_path = os.path.join(self.base_dir, 'transformer', self.gpt_ckpt_dir, 'ckpt.pt')
-        checkpoint = torch.load(ckpt_path, map_location=self.device)
+        checkpoint = torch.load(ckpt_path, map_location=self.device, weights_only=False)
         gpt_conf = GPTConfig(**checkpoint['model_args'])
         transformer = HookedGPT(gpt_conf)
         state_dict = checkpoint['model']
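
Note on the `weights_only` change above: PyTorch 2.6 flipped the default of `torch.load` to `weights_only=True`, which refuses to unpickle anything beyond tensors and basic containers. Checkpoints that pickle custom objects then fail to load unless `weights_only=False` is passed, which restores full unpickling and should only be used on trusted files. A minimal sketch of the failure mode; `TrainConfig` and `demo_ckpt.pt` are illustrative names, not from this repo:

import torch

# A checkpoint containing a custom class instance, saved the old-fashioned way.
class TrainConfig:
    lr = 3e-4

torch.save({'model_state': {}, 'config': TrainConfig()}, 'demo_ckpt.pt')

try:
    torch.load('demo_ckpt.pt')  # PyTorch >= 2.6: raises UnpicklingError
except Exception as err:
    print(f'safe load failed: {err}')

# Explicit opt-in restores the old behavior; only do this for trusted files.
ckpt = torch.load('demo_ckpt.pt', weights_only=False)
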
@@ -100,7 +100,7 @@ def load_autoencoder_model(self):
         state_dict = autoencoder_ckpt['autoencoder']
         n_features, n_ffwd = state_dict['encoder.weight'].shape  # H, F
         l1_coeff = autoencoder_ckpt['config']['l1_coeff']
-        from autoencoder import AutoEncoder
+        from autoencoder_architecture import AutoEncoder
 
         autoencoder = AutoEncoder(n_ffwd, n_features, lam=l1_coeff).to(self.device)
         autoencoder.load_state_dict(state_dict)
@@ -114,7 +114,7 @@ def get_text_batch(self, num_contexts):
         Y = torch.stack([torch.from_numpy(self.text_data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix])
         return X.to(device=self.device), Y.to(device=self.device)
 
-    def get_autoencoder_data_batch(self, step, batch_size=8192):
+    def get_autoencoder_data_batch(self, step, batch_size: int):
         """
         Retrieves a batch of autoencoder data based on the step and batch size.
         It loads the next data partition if the batch exceeds the current partition.
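
The docstring above describes batches that may cross a partition boundary. A minimal, self-contained sketch of that pattern follows; the class and attribute names (`PartitionedData`, `examples_per_partition`, etc.) are assumptions for illustration, not the repository's implementation:

import torch

class PartitionedData:
    """Sketch: serves fixed-size batches from equal-length data partitions."""
    def __init__(self, partitions):
        self.partitions = partitions                      # list of 1-D tensors
        self.examples_per_partition = len(partitions[0])
        self.partition_id = 0
        self.data = partitions[0]

    def get_batch(self, step, batch_size):
        per_part = self.examples_per_partition
        start = (step * batch_size) % per_part
        end = start + batch_size
        if end <= per_part:
            return self.data[start:end]
        # The batch crosses a boundary: take the tail of the current
        # partition, switch to the next one, and take its head.
        head = self.data[start:]
        self.partition_id += 1
        self.data = self.partitions[self.partition_id]
        return torch.cat([head, self.data[: end - per_part]])

loader = PartitionedData([torch.arange(0, 100), torch.arange(100, 200)])
for step in range(5):
    b = loader.get_batch(step, batch_size=30)
    print(step, b[0].item(), b[-1].item())  # step 3 spans both partitions
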
@@ -141,14 +141,12 @@ def get_autoencoder_data_batch(self, step, batch_size=8192):
         return batch.to(self.device)
 
     def load_next_autoencoder_partition(self, partition_id):
-        """
-        Loads the specified partition of the autoencoder data.
-        """
+        """Loads the specified partition of the autoencoder data."""
         file_path = os.path.join(self.autoencoder_data_dir, f'{partition_id}.pt')
-        self.autoencoder_data = torch.load(file_path)
+        self.autoencoder_data = torch.load(file_path, weights_only=False)
         return self.autoencoder_data
 
-    def select_resampling_data(self, size=819200):
+    def select_resampling_data(self, size: int):
         """
         Selects a subset of autoencoder data for resampling, distributed evenly across partitions.
         """
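
With the defaults removed from `get_autoencoder_data_batch` and `select_resampling_data`, every call site must now pass the sizes explicitly. A hedged usage sketch; only the values 8192 and 819200 (the old defaults) come from this diff, and `loader` stands for an instance of the class being patched (construction not shown):

batch = loader.get_autoencoder_data_batch(step=0, batch_size=8192)
resampling_data = loader.select_resampling_data(size=819200)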