-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathprepare_code_data.py
44 lines (38 loc) · 1.73 KB
/
prepare_code_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description:
"""
import argparse
import sys
from sklearn.model_selection import train_test_split
sys.path.append("..")
from codeassist import create_dataset
def main():
    """Collect source code for one language and split it into train/valid/test files.

    Command-line options:
        --save_dir   directory passed to ``create_dataset`` as ``save_dir``; the
                     split files are also written under it (default: "download").
        --num_repos  value passed to ``create_dataset`` as ``each_limit_repos``
                     (default: 3).
        --code       language key, one of "python", "java", "cpp"
                     (default: "python").

    The samples returned for ``--code`` are split with a fixed ``random_state``
    so the split is reproducible: 20% test, then 25% of the remaining 80%
    (about 20% overall) for validation, leaving about 60% for training. Each
    split is written with ``create_dataset.merge_and_save`` to
    ``<save_dir>/<code>/{train,valid,test}.txt``.

    Exits with a non-zero status (via ``sys.exit``) when collection is
    interrupted or yields no samples for the requested language, instead of
    failing with a ``KeyError`` on the empty result.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_dir", type=str, default="download", help="Save dataset directory")
    parser.add_argument("--num_repos", type=int, default=3, help="Number of repos to use")
    parser.add_argument("--code", default="python", const='python', nargs='?',
                        choices=['python', 'java', 'cpp'], help="Download code language source code dataset")
    args = parser.parse_args()
    print(args)

    sources = {}
    try:
        sources = create_dataset.get_source_code_by_language(
            code_languages=args.code,
            save_dir=args.save_dir,
            each_limit_repos=args.num_repos,
        )
    except KeyboardInterrupt:
        # Ctrl-C aborts the assignment above, so `sources` is still empty.
        # The missing samples are reported below instead of raising KeyError.
        print("Interrupted while collecting source code.")

    samples = (sources or {}).get(args.code)
    # Checked via len() so list- or array-like results both work.
    if samples is None or len(samples) == 0:
        sys.exit(f"No {args.code} source code collected; nothing to save.")

    # 20% test first, then 0.25 x 0.8 = 0.2 of the total for validation.
    x_train, x_test = train_test_split(samples, test_size=0.2, random_state=1)
    x_train, x_val = train_test_split(x_train, test_size=0.25, random_state=1)

    train_file = f'{args.save_dir}/{args.code}/train.txt'
    valid_file = f'{args.save_dir}/{args.code}/valid.txt'
    test_file = f'{args.save_dir}/{args.code}/test.txt'
    create_dataset.merge_and_save(x_train, train_file)
    create_dataset.merge_and_save(x_val, valid_file)
    create_dataset.merge_and_save(x_test, test_file)
    print(f'Save train file: {train_file}, valid file: {valid_file}, test file: {test_file}')


if __name__ == '__main__':
    main()