Skip to content

Commit 8894065

Browse files
committed
Improvement in train_test_split function
Improvement in train_test_split function Shuffling has been removed
1 parent bce45f8 commit 8894065

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

learning.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,14 +1049,34 @@ def grade_learner(predict, tests):
10491049
return mean(int(predict(X) == y) for X, y in tests)
10501050

10511051

1052-
def train_test_split(dataset, start, end):
1053-
"""Reserve dataset.examples[start:end] for test; train on the remainder."""
1054-
start = int(start)
1055-
end = int(end)
1056-
examples = dataset.examples
1057-
train = examples[:start] + examples[end:]
1058-
val = examples[start:end]
1059-
return train, val
1052+
def train_test_split(dataset, start = None, end = None, test_split = None):
1053+
"""If you are giving 'start' and 'end' as a parameter,
1054+
then it will return testing set from index 'start' to 'end'
1055+
and rest for training.
1056+
If you give 'test_split' as parameter then it will first shuffle the
1057+
dataset then return test_split * 100% as testing set and rest as
1058+
training set.
1059+
"""
1060+
1061+
if start == None and end != None:
1062+
raise ValueError("'start' parameter is missing")
1063+
1064+
if start != None and end == None:
1065+
raise ValueError("'end' parameter is missing")
1066+
1067+
if test_split == None:
1068+
examples = dataset.examples
1069+
train = examples[:start] + examples[end:]
1070+
val = examples[start:end]
1071+
return train, val
1072+
else:
1073+
examples = dataset.examples
1074+
total_size = len(examples)
1075+
val_size = int(total_size * test_split)
1076+
train_size = total_size - val_size
1077+
train = examples[:train_size]
1078+
val = examples[train_size:total_size]
1079+
return train, val
10601080

10611081

10621082
def cross_validation(learner, size, dataset, k=10, trials=1):

0 commit comments

Comments
 (0)