-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy path1_split_afhq.py
77 lines (52 loc) · 1.41 KB
/
1_split_afhq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""Re-split the AFHQ train/val index files.

Every training record is assigned to a class by its trailing `` <label>``
suffix (labels ``0``, ``1`` and ``2``). Each class is shuffled independently
and its first ``give_to_val`` records are moved into the validation set; the
remaining records stay in the training set. The records already present in
the original validation file are kept and written first.
"""
import random

# Input index files (one record per line) and the output files to write.
ori_train = '../afhq/train_ori.txt'
ori_val = '../afhq/val_ori.txt'
save_train = '../afhq/train.txt'
save_val = '../afhq/val.txt'
# Number of records moved from each training class into the validation set.
give_to_val = 1500
# Class labels, matched as the final `` <label>`` suffix of a record.
LABELS = ('0', '1', '2')


def read_records(path):
    """Return the stripped, non-empty lines of the UTF-8 text file *path*.

    Blank lines are skipped so they never turn into empty records in the
    output files.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [line for line in (raw.strip() for raw in f) if line]


def split_records(train_records, val_records, n_to_val=give_to_val,
                  labels=LABELS, rng=random):
    """Move up to *n_to_val* shuffled records per class from train to val.

    Args:
        train_records: training records. Each is assigned to the first label
            in *labels* that it ends with (preceded by a space). Records that
            match no label are dropped.
        val_records: existing validation records, kept first and in order.
        n_to_val: number of records taken from each class. A class with
            fewer records is moved to validation entirely.
        labels: class labels, processed (and shuffled) in this order.
        rng: object providing ``shuffle``. Use the ``random`` module (the
            default) or a seeded ``random.Random`` for a reproducible split.

    Returns:
        ``(new_train, new_val)`` lists. Both are grouped by class in *labels*
        order, and *new_val* starts with *val_records*.
    """
    groups = {label: [] for label in labels}
    for record in train_records:
        for label in labels:
            # Same first-match semantics as an if/elif chain over the labels.
            if record.endswith(' ' + label):
                groups[label].append(record)
                break

    new_train = []
    new_val = list(val_records)
    for label in labels:
        members = groups[label]
        rng.shuffle(members)
        new_val.extend(members[:n_to_val])
        new_train.extend(members[n_to_val:])
    return new_train, new_val


def write_records(path, records):
    """Write *records* to the UTF-8 file *path*, one newline-terminated line each."""
    with open(path, 'w', encoding='utf-8') as f:
        # Stream lines instead of building one string with repeated `+=`.
        f.writelines(record + '\n' for record in records)


def main():
    """Split the module-level input files and write the train/val outputs."""
    new_train, new_val = split_records(read_records(ori_train),
                                       read_records(ori_val))
    write_records(save_train, new_train)
    write_records(save_val, new_val)


if __name__ == '__main__':
    main()