Skip to content

Commit 381b50d

Browse files
committed
Improvement:
Add docstrings and comments; adjust code style
1 parent c017fdf commit 381b50d

File tree

1 file changed

+85
-2
lines changed

1 file changed

+85
-2
lines changed

learning_rate.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""Calculate data for learning rate calculation.
2+
3+
Taking data from sampling_learning_rate.py
4+
Output may be processed by plot_learning_rate.py
5+
"""
16
# @Author: Joey Teng
27
# @Email: joey.teng.dev@gmail.com
38
# @Filename: learning_rate.py
@@ -22,6 +27,19 @@
2227

2328

2429
def split_data_target(dataset):
30+
"""Split the input CSV files into X, y vectors for sklearn implementations.
31+
32+
Args:
33+
dataset (list): List of list of floats.
34+
[
35+
[0...n - 1]: X, feature vector
36+
[-1]: y, label
37+
]
38+
39+
Returns:
40+
tuple: (X, y) for sklearn implementations
41+
42+
"""
2543
try:
2644
return ([[float(element)
2745
for element in row.strip().split(',')[:-1]]
@@ -34,15 +52,59 @@ def split_data_target(dataset):
3452

3553

3654
def generate_training_sets(dataset, percentage, copies):
    """Resample the dataset to generate smaller training sets.

    Each new training set is built by shuffling a deep copy of the
    dataset and keeping the first percentage% instances, so no instance
    appears in any one new training set more than once.

    Args:
        dataset (list): List of vectors (features + label).
        percentage (number that supports __mul__ and __floordiv__):
            This decides the size of each new training set generated
            relative to the population.
        copies (int): The number of new training sets required.

    Returns:
        list: list of new training datasets
            list of list of vectors

    """
    # Size is invariant across copies; compute it once outside the loop.
    size = len(dataset) * percentage // 100
    training_sets = []
    # Idiomatic for-loop over range instead of a manual while-countdown.
    for _ in range(copies):
        # Deep copy so shuffling (and later caller mutation) never
        # touches the original dataset or sibling training sets.
        population = copy.deepcopy(dataset)
        random.shuffle(population)
        training_sets.append(population[:size])
    return training_sets
4381

4482

4583
def generate_result(datasets, classifier, path):
84+
"""Generate the learning rate accuracies.
85+
86+
Args:
87+
datasets (dict): {
88+
'test set': testing set for the specific dataset
89+
'remainder': instances in the dataset but not testing set
90+
}
91+
classifier (func): a function that will return an instance of
92+
sklearn classifier.
93+
path (str): path of the dataset, for logging only.
94+
95+
Returns:
96+
dict: dict of dict {
97+
percentage: results under respective portion of training data {
98+
'raw' (list): raw accuracy values [
99+
accuracy values of each training set-testing set pairs
100+
]
101+
'average': average of 'raw'
102+
'standard deviation': standard deviation of 'raw'
103+
'range': range of 'raw'
104+
}
105+
}
106+
107+
"""
46108
results = []
47109
for dataset in datasets:
48110
test_set = dataset['test set']
@@ -84,11 +146,15 @@ def generate_result(datasets, classifier, path):
84146

85147

86148
def RandomForestClassifier():
    """Return a Random Forest classifier preconfigured for these experiments."""
    # Fixed ensemble size so every run uses an identical configuration.
    n_estimators = 64
    classifier = sklearn.ensemble.RandomForestClassifier(n_estimators=n_estimators)
    return classifier
88151

89152

90153
def main(path):
91-
"""main"""
154+
"""Start main function here.
155+
156+
Run tasks and dump result files.
157+
"""
92158
print("{} Start".format(path), flush=True)
93159

94160
datasets = json.load(open(path, 'r'))
@@ -103,6 +169,15 @@ def main(path):
103169

104170

105171
def traverse(paths):
172+
"""Traverse to append all files in child folders into the task queue.
173+
174+
Args:
175+
paths (list): Paths of all folders to be detected
176+
177+
Returns:
178+
list: Paths of all files added in the task queue
179+
180+
"""
106181
print("Starting Traverse Through", flush=True)
107182
files = []
108183
while paths:
@@ -119,6 +194,14 @@ def traverse(paths):
119194

120195

121196
def parse_path():
197+
"""Parse the arguments.
198+
199+
No argument is required for calling this function.
200+
201+
Returns:
202+
Namespace: parsed arguments enclosed by an object defined in argparse
203+
204+
"""
122205
parser = argparse.ArgumentParser(
123206
description="Generate Datasets for Detecting Learning Rate")
124207
parser.add_argument('-r', action='store', nargs='+', default=[],

0 commit comments

Comments
 (0)