Skip to content

Commit

Permalink
Clean up rp sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
mzio committed Sep 20, 2024
1 parent c187136 commit ec36da9
Show file tree
Hide file tree
Showing 29 changed files with 27 additions and 10 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
4 changes: 3 additions & 1 deletion src/dataloaders/preprocess_rp_contig.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,12 @@ def main():
_data_attr = distill_config['dataset']['dataset_config']['train_data']
_data_attr = '-d='.join(_data_attr).replace('/', '_').replace('.json', '')
_data_attr = _data_attr.replace('[','_').replace(']','')

dataset_config = distill_config.dataset.dataset_config

# fname = f'd={_data_attr}-nts={num_train_samples}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}'
fname = f'd={_data_attr}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}'
fname = join('./src/dataloaders', fname)
fname = join(dataset_config['dataloaders_dir'], 'redpajama_sample_indices', fname)

# Rank samples by effective sequence length
_train_esl = train_esl.mean(0).mean(0).mean(-1) # num_samples
Expand Down
33 changes: 24 additions & 9 deletions src/dataloaders/redpajama_sample_contig.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,11 @@ def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
_data_attr = _data_attr.replace('[','_').replace(']','')

# fname = f'd={_data_attr}-nts={num_train_samples}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}'
fname = f'd={_data_attr}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}'
fname = join(dataset_config['dataloaders_dir'], fname)


try:
fname = f'd={_data_attr}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}'
fname = join(dataset_config['dataloaders_dir'], fname)
fname = join(dataset_config['dataloaders_dir'], 'redpajama_sample_indices', fname)
if dataset_config['filter_window'] > 0:
sorted_idx = np.load(f'{fname}_l{window:03d}.npy')
else:
Expand All @@ -145,14 +143,31 @@ def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
_train_esl = train_esl.mean(0).mean(0).mean(-1) # num_samples
sorted_idx = torch.argsort(_train_esl, dim=-1, descending=True)
# Save indices to generated filename
fname = f'd={_data_attr}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}'
fname = join(dataset_config['dataloaders_dir'], 'redpajama_sample_indices', fname)
np.save(f'{fname}.npy', sorted_idx)
print(f'-> Top {num_train_samples} saved to {fname}!')

# _train_esl = train_esl[..., -128:].mean(0).mean(0).mean(-1) # num_samples
# sorted_idx = torch.argsort(_train_esl, dim=-1, descending=True)
# # Save indices to generated filename
# np.save(f'{fname}_l128.npy', sorted_idx)
# print(f'-> Top {num_train_samples} saved to {fname}!')

# Also sort by computing sequence lengths over last window tokens
for window in [1, 2, 4, 8, 16, 32, 64, 128]:
_train_esl = train_esl[..., -window:].mean(0).mean(0).mean(-1) # num_samples
sorted_idx = torch.argsort(_train_esl, dim=-1, descending=True)
# Save indices to generated filename
try:
_fname = f'{fname}_l{window:03d}.npy'
np.save(_fname, sorted_idx)
print(f'-> Samples saved to {_fname}!')

# Also save top samples
sample_idx = sorted_idx[:num_train_samples].numpy()
_fname = f'{fname}-nts={num_train_samples}_l{window:03d}.npy'
np.save(_fname, sample_idx) # sorted_idx)
print(f'-> Top {num_train_samples} saved to {_fname}!')
except:
sample_idx = sorted_idx[:num_train_samples].numpy()
_fname = f'{fname}-nts={num_train_samples}_l{window:03d}.npy'
np.save(_fname, sample_idx) # sorted_idx)
print(f'-> Top {num_train_samples} saved to {_fname}!')

sample_idx = sorted_idx[:num_train_samples].numpy()
train_set.filtered_samples = [train_set.filtered_samples[ix] for ix in sample_idx]
Expand Down

0 comments on commit ec36da9

Please sign in to comment.