Skip to content

Commit c72efc4

Browse files
committed
Move and rename files
1 parent f5d0805 commit c72efc4

File tree

4 files changed

+14
-7
lines changed

4 files changed

+14
-7
lines changed

training/__init__.py

Whitespace-only changes.

metrics.py renamed to training/evaluation_util.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33

44
def compute_rouge_scores(pred_seq, target_seq):
5+
"""
6+
:param pred_seq: Predicted sequence
7+
:param target_seq: Target sequence
8+
:return: a pair (rouge_2, rouge_l) containing the rouge-2 and rouge-l scores given pred_seq
9+
and target_seq
10+
"""
511
rouge = Rouge()
612
pred_seq_str = ' '.join([str(x) for x in pred_seq])
713
target_seq = ' '.join([str(x) for x in target_seq])

train.py renamed to training/train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
22
import argparse
33

4-
from models.lstm_to_lstm import Seq2Seq
5-
from models.lstm_to_lstm_full_training import train_iters
4+
from models.full_model import FullModel
5+
from training.train_model import train_iters
66
from models.lstm_encoder import LSTMEncoder
77
from models.lstm_decoder import LSTMDecoder
88
from data_processing.data_util import prepare_tokens, prepare_data
@@ -33,10 +33,10 @@ def main():
3333
graph_encoder = GATEncoder(hidden_size, hidden_size)
3434
else:
3535
graph_encoder = GCNEncoder(hidden_size, hidden_size)
36-
model = Seq2Seq(encoder=encoder, graph_encoder=graph_encoder, decoder=decoder,
37-
device=opt.device)
36+
model = FullModel(encoder=encoder, graph_encoder=graph_encoder, decoder=decoder,
37+
device=opt.device)
3838
else:
39-
model = Seq2Seq(encoder=encoder, decoder=decoder, device=opt.device)
39+
model = FullModel(encoder=encoder, decoder=decoder, device=opt.device)
4040

4141
train_iters(model, opt.iterations, pairs, print_every=opt.print_every, model_dir=model_dir,
4242
lang=lang, graph=opt.graph)

models/lstm_to_lstm_full_training.py renamed to training/train_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import unicode_literals, print_function, division
22
import random
3-
from data_processing import tensors_from_pair_tokens, plot_loss, tensors_from_pair_tokens_graph
3+
from data_processing.data_util import tensors_from_pair_tokens, plot_loss, \
4+
tensors_from_pair_tokens_graph
45

56
import torch
67
import torch.nn as nn
78
from torch import optim
89
from sklearn.metrics import f1_score
910
import numpy as np
10-
from metrics import compute_rouge_scores
11+
from training.evaluation_util import compute_rouge_scores
1112
import pickle
1213

1314

0 commit comments

Comments
 (0)