
Commit 014784c

update filtering and decontamination
1 parent 283033d commit 014784c

File tree

1 file changed

+95
-38
lines changed


preprocessing/filtering.py

Lines changed: 95 additions & 38 deletions
@@ -23,7 +23,7 @@
 ALL_FILTERS = ["basic", "basic_per_extension", "stars", "comments", "fertility", "xml", "html", "large_and_small_files"]
 THRESHOLDS_FERTILITY = {"python": 2.5, "java": 2.9, "javascript": 2.6}
 
-
+LANG = "language"
 class MultiChoice:
     def __init__(self, choices):
         self.choices = choices
@@ -63,7 +63,7 @@ def parse_args():
 def get_comments_ratio(examples):
     """Get ratio of comments to code in each example. Requires a language argument"""
     ratio_list = []
-    for code, language in zip(examples["content"], examples["lang"]):
+    for code, language in zip(examples["content"], examples[LANG]):
         ratio_list.append(get_nl_ratio(code, language.lower()))
     return {"nl_ratio": ratio_list}
 
@@ -89,6 +89,17 @@ def basic_filters(example):
         return False
     return True
 
+def add_stats(example):
+    """Add extra stats:
+    - size of text, mean and max line length of file
+    - % alphanumeric characters
+    - extracts file extension"""
+    size = len(example["content"])
+    line_lengths = [len(line) for line in example["content"].splitlines()]
+    alpha_frac = np.mean([c.isalnum() for c in example["content"]])
+    ext = example["path"].split(".")[-1]
+    return {"size": size, "avg_line_length": np.mean(line_lengths), "max_line_length": max(line_lengths), "alphanum_fraction": alpha_frac, "ext": ext}
+
 
 def basic_filters_per_extension(example, ext_to_filter):
     """Filter files based on line length and % alphanumeric characters.
@@ -97,7 +108,7 @@ def basic_filters_per_extension(example, ext_to_filter):
     # extension `None` is an empty string in the csv
     try:
         (include, line_max, line_mean, alphanum_frac, alphabetic_frac) = ext_to_filter[(language_format_from_dataset(
-            example["lang"]), example["ext"] if example["ext"] is not None else ""
+            example[LANG]), example["ext"] if example["ext"] is not None else ""
         )]
     except KeyError as e:
         # Some extensions are not in the csv. This happens for dockerfiles.
@@ -187,7 +198,7 @@ def char_token_ratio(examples, tokenizer):
 def filter_tokenizer(examples):
     """Filter files based on char to token ratio"""
     values = []
-    for ratio, lang in zip(examples["fertility_ratio"], examples["lang"]):
+    for ratio, lang in zip(examples["fertility_ratio"], examples[LANG]):
         if ratio < THRESHOLDS_FERTILITY[lang.lower()]:
             values.append(False)
         else:
@@ -202,7 +213,7 @@ def filter_xml(example):
 
 def filter_html(example):
     """Filter HTML files based on displayed text VS code ratio"""
-    assert example["lang"] == "HTML", "Filter is only for html examples"
+    assert example[LANG] == "HTML", "Filter is only for html examples"
     html = example["content"]
     try:
         soup = BeautifulSoup(html, features="html.parser")
@@ -226,6 +237,8 @@ def filter_large_and_small_files(example):
 def get_size_text(example):
     return {"size": len(example["content"])}
 
+def get_ext(example):
+    return {"ext": example["path"].split(".")[-1]}
 
 LICENSE_COLUMNS = ['max_stars_repo_licenses', 'max_issues_repo_licenses', 'max_forks_repo_licenses']
 def fix_license_cols(example):
@@ -234,6 +247,7 @@ def fix_license_cols(example):
     return example
 
 
+
 if __name__ == "__main__":
     args = parse_args()
     print(f"Selected filters: {args.filters}")
@@ -258,20 +272,29 @@ def fix_license_cols(example):
     # Load dataset
     t_start = time.time()
     logger.info(f" ===== Loading {args.dataset_name} and subset {args.subset}=====")
+    # assert that out_path/data doesn't exist yet
+    import os
+    if os.path.exists(f"{args.out_path}/data"):
+        raise ValueError(f"Output path already exists: {args.out_path}/data, delete it before filtering")
+
     dataset = load_dataset(
-        args.dataset_name, split=args.split, data_dir=args.subset, use_auth_token=True, num_proc=args.num_workers
+        args.dataset_name, split=args.split, use_auth_token=True, num_proc=args.num_workers
     )
     logger.info(f"Dataset loaded in {time.time() - t_start:.2f} seconds")
     logger.info(f"Dataset: {dataset}")
     if "size" not in dataset.column_names:
-        logger.info("Add text size column")
-        dataset = dataset.map(get_size_text)
+        logger.info("Add text size column, ext and line stats")
+        dataset = dataset.map(add_stats, num_proc=args.num_workers)
     if args.fix_license_columns:
         dataset = dataset.map(fix_license_cols, num_proc=args.num_workers)
     logger.info(
-        f"Dataset size before any filtering: {len(dataset)} examples, {sum(dataset['size']) / 1e9:.2f} GB"
+        f"Dataset size before any filtering: {len(dataset)} examples, {sum(dataset['size']) / 1e9:.2f} GB and columns: {dataset.column_names}"
+    )
+    # filter out non-permissive data
+    dataset = dataset.filter(lambda x: x["license_type"] != "non_permissive")
+    logger.info(
+        f"Dataset size after non permissive filtering: {len(dataset)} examples, {sum(dataset['size']) / 1e9:.2f} GB"
     )
-
     # Run pre-processing if needed
     if "stars" in filters:
         logger.info(f"===== Processing dataset to add proper stars column=====")
@@ -335,6 +358,8 @@ def fix_license_cols(example):
         elif filter == "basic_per_extension":
             assert args.per_extension_filter_csv is not None
             language = language_format_from_data_dir(args.subset.split("/")[-1]) if args.subset is not None else None
+            language = "python"
+            logger.info(f"selected language: {language}")
             logger.info(
                 f"===== Language: {language}. Basic filtering with line_max, avg_line, alphanum_frac and alphabetic_frac given by : {args.per_extension_filter_csv} ====="
             )
@@ -536,6 +561,65 @@ def fix_license_cols(example):
         )
         dataset = ds
 
+
+    # Run decontamination
+    if args.run_decontamination:
+        logger.info(
+            f"===== Running decontamination ====="
+        )
+        import sys
+        import os
+        sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
+        from decontamination.benchmark_data import FILTER_OUT
+
+        FILTER_OUT.pop('apps_docstrings', None)
+        FILTER_OUT.pop('gsm8k_questions', None)
+        logger.info(f"FILTER OUT Benchmarks: {FILTER_OUT.keys()}")
+        def decontaminate(samples, filter_out=FILTER_OUT):
+            """
+            filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be
+            filtered-out.
+            Return a list where each element is True if the corresponding file should be included in the dataset.
+            Otherwise, the element is False.
+            """
+            output = []
+
+            for content in samples["content"]:
+                content = content.lower()
+                matched = False
+                for benchmark, substrings in filter_out.items():
+                    for substring in substrings:
+                        if substring.lower() in content:
+                            matched = True
+                            break
+                    if matched:
+                        break
+                # we keep files that are not matched
+                output.append(not matched)
+
+            return output
+
+        old_size = len(dataset)
+        old_size_gb = sum(dataset["size"])
+        dataset = dataset.filter(decontaminate, batched=True, batch_size=10_000, num_proc=64)
+        filtered_size_gb = sum(dataset["size"])
+        logger.info(
+            f"Removed {old_size - len(dataset)} files from {old_size} (i.e {(old_size - len(dataset)) * 100 / old_size}%)"
+        )
+        logger.info(
+            f"Dataset size after decontamination: {len(dataset)} examples, {filtered_size_gb / 1e9:.2f} GB"
+        )
+
+    if args.add_metadata:
+        from add_content_with_meta import content_with_meta
+
+        logger.info("===== Adding content with metadata =====")
+        dataset = dataset.map(
+            content_with_meta,
+            remove_columns=["content"],
+            num_proc=args.num_workers,
+        )
+
     # Save dataset
     logger.info(
         f"Final dataset has {len(dataset)} samples and {sum(dataset['size']) / 1e9:.2f} GB of code"
@@ -548,7 +632,7 @@ def fix_license_cols(example):
         dataset.push_to_hub(args.remote_repo)
     else:
         print(
-            f"Saving the dataset in manual shards in a clone of {args.hub_username + args.remote_repo}"
+            f"Saving the dataset in manual shards in a clone of {args.hub_username}/{args.remote_repo}"
         )
         try:
             save_manual_shards(
@@ -557,30 +641,3 @@ def fix_license_cols(example):
             logger.info(f"Dataset successfully saved at {args.out_path}/{args.subset} in {time.time() - t_start:.2f} seconds")
         except FileExistsError:
             logger.warning(f"Output dir already exists at {args.out_path}/{args.subset}. Will not save filtered data")
-
-    # Run decontamination
-    if args.run_decontamination:
-        import sys
-        import os
-        sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
-        from decontamination.find_substrings import SubstringFilterer
-
-        output_dir_decontaminated = f"{args.out_path}_decontaminate/{args.subset}"
-
-        filterer = SubstringFilterer(
-            output_dir=output_dir_decontaminated,
-            cached_decontamination_dir=None,  # no previous cached run
-            split_languages=False,
-            cache_retrieval_key="",
-            data_dir=output_dir_decontaminated
-        )
-
-        filtered = filterer.run(dataset, args.num_workers, args.batch_size)
-
-        filtered_size_gb = sum(filtered["size"])
-        logger.info(
-            f"Removed {len(dataset) - len(filtered)} / {len(dataset)} files"
-        )
-        logger.info(
-            f"Dataset size after decontamination: {len(filtered)} examples, {filtered_size_gb / 1e9:.2f} GB"
-        )
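For readers skimming the diff: the commit replaces the SubstringFilterer-based decontamination with an in-script batched substring filter over FILTER_OUT (imported from decontamination.benchmark_data). Below is a minimal, self-contained sketch of that behavior on a toy in-memory dataset; toy_filter_out and the two sample files are hypothetical stand-ins, not the real benchmark data.

# Minimal sketch of the substring-based decontamination added in this commit.
# The real FILTER_OUT dict comes from decontamination.benchmark_data; the toy
# mapping and toy dataset below are hypothetical, for illustration only.
from datasets import Dataset

toy_filter_out = {
    "humaneval_like": ["def has_close_elements(numbers"],
    "mbpp_like": ["write a function to find the minimum cost path"],
}

def decontaminate(samples, filter_out=toy_filter_out):
    """Return one boolean per file: True means keep (no benchmark substring found)."""
    keep = []
    for content in samples["content"]:
        lowered = content.lower()
        matched = any(
            substring.lower() in lowered
            for substrings in filter_out.values()
            for substring in substrings
        )
        keep.append(not matched)
    return keep

ds = Dataset.from_dict({
    "content": [
        "def add(a, b):\n    return a + b\n",
        "def has_close_elements(numbers, threshold):\n    ...\n",
    ]
})
clean = ds.filter(decontaminate, batched=True, batch_size=2)
print(len(ds), "->", len(clean))  # 2 -> 1: the file containing a benchmark substring is dropped

Because the filter runs batched, the function receives a batch dict and must return one boolean per row, which is exactly what the committed decontaminate function does.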

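Likewise, a small sketch of the per-file statistics that the new add_stats mapping attaches before filtering (size, line-length stats, alphanumeric fraction, extension). The one-file toy dataset and the empty-file guard are assumptions for illustration; the column names follow the diff.

# Sketch of the stats columns added via dataset.map(add_stats, ...).
# The toy dataset below is hypothetical; column names match the diff.
import numpy as np
from datasets import Dataset

def add_stats(example):
    content = example["content"]
    line_lengths = [len(line) for line in content.splitlines()] or [0]  # empty-file guard, not in the commit
    return {
        "size": len(content),
        "avg_line_length": float(np.mean(line_lengths)),
        "max_line_length": max(line_lengths),
        "alphanum_fraction": float(np.mean([c.isalnum() for c in content])) if content else 0.0,
        "ext": example["path"].split(".")[-1],
    }

ds = Dataset.from_dict({
    "content": ["print('hello')\n# a comment\n"],
    "path": ["example/hello.py"],
})
ds = ds.map(add_stats)
print(ds[0]["ext"], ds[0]["size"], ds[0]["max_line_length"])  # py 27 14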