Skip to content

Commit

Permalink
Control of sampling_rate for experiment train and test
Browse files (browse the repository at this point in the history)
  • Loading branch information
cgpotts committed Mar 17, 2020
1 parent 455fda5 commit fd0c75f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
15 changes: 10 additions & 5 deletions rel_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,8 @@ def macro_average_results(results):
return avg_result


def evaluate(splits, classifier, test_split='dev', verbose=True):
test_kbts_by_rel, true_labels_by_rel = splits[test_split].build_dataset()
def evaluate(splits, classifier, test_split='dev', sampling_rate=0.1, verbose=True):
test_kbts_by_rel, true_labels_by_rel = splits[test_split].build_dataset(sampling_rate=sampling_rate)
results = {}
if verbose:
print_statistics_header()
Expand All @@ -475,10 +475,11 @@ def train_models(
split_name='train',
model_factory=(lambda: LogisticRegression(
fit_intercept=True, solver='liblinear', random_state=42)),
sampling_rate=0.1,
vectorize=True,
verbose=True):
train_dataset = splits[split_name]
train_o, train_y = train_dataset.build_dataset()
train_o, train_y = train_dataset.build_dataset(sampling_rate=sampling_rate)
train_X, vectorizer = train_dataset.featurize(
train_o, featurizers, vectorize=vectorize)
models = {}
Expand All @@ -493,9 +494,9 @@ def train_models(
'vectorize': vectorize}


def predict(splits, train_result, split_name='dev', vectorize=True):
def predict(splits, train_result, split_name='dev', sampling_rate=0.1, vectorize=True):
assess_dataset = splits[split_name]
assess_o, assess_y = assess_dataset.build_dataset()
assess_o, assess_y = assess_dataset.build_dataset(sampling_rate=sampling_rate)
test_X, _ = assess_dataset.featurize(
assess_o,
featurizers=train_result['featurizers'],
Expand Down Expand Up @@ -531,19 +532,23 @@ def experiment(
test_split='dev',
model_factory=(lambda: LogisticRegression(
fit_intercept=True, solver='liblinear', random_state=42)),
train_sampling_rate=0.1,
test_sampling_rate=0.1,
vectorize=True,
verbose=True):
train_result = train_models(
splits,
featurizers=featurizers,
split_name=train_split,
model_factory=model_factory,
sampling_rate=train_sampling_rate,
vectorize=vectorize,
verbose=verbose)
predictions, test_y = predict(
splits,
train_result,
split_name=test_split,
sampling_rate=test_sampling_rate,
vectorize=vectorize)
evaluate_predictions(
predictions,
Expand Down
2 changes: 2 additions & 0 deletions test/test_rel_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def test_experiment(featurizer, vectorize, corpus, kb):
train_split='tiny_train',
test_split='tiny_dev',
featurizers=[featurizer],
train_sampling_rate=0.2,
test_sampling_rate=0.2,
vectorize=vectorize,
verbose=False)

Expand Down

0 comments on commit fd0c75f

Please sign in to comment.