diff --git a/example_config/MIMIC-50/cnn.yml b/example_config/MIMIC-50/cnn.yml
index fdeff34c..27bdff64 100644
--- a/example_config/MIMIC-50/cnn.yml
+++ b/example_config/MIMIC-50/cnn.yml
@@ -20,8 +20,8 @@ monitor_metrics: ['P@1','P@3','P@5']
 val_metric: P@5
 
 # model
-model_name: KimCNN
-init_weight: xavier_uniform
+model_name: VanillaCNN
+init_weight: null
 network_config:
   activation: tanh
   dropout: 0.2
diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py
index f30e0467..1665bd2b 100644
--- a/libmultilabel/nn/model.py
+++ b/libmultilabel/nn/model.py
@@ -24,6 +24,8 @@ class MultiLabelModel(pl.LightningModule):
         log_path (str): Path to a directory holding the log files and models.
         silent (bool, optional): Enable silent mode. Defaults to False.
         save_k_predictions (int, optional): Save top k predictions on test set. Defaults to 0.
+        device (torch.device, optional): Device on which to initialize the network. Used only in
+            the `reproduce` branch, for reproducibility of the model. Defaults to None.
""" def __init__( @@ -38,6 +40,7 @@ def __init__( log_path=None, silent=False, save_k_predictions=0, + device=None, # for reproducibility **kwargs ): super().__init__() diff --git a/libmultilabel/nn/networks/__init__.py b/libmultilabel/nn/networks/__init__.py index 996fc8ed..5752429d 100644 --- a/libmultilabel/nn/networks/__init__.py +++ b/libmultilabel/nn/networks/__init__.py @@ -3,6 +3,7 @@ from .bigru import BiGRU from .caml import CAML from .kim_cnn import KimCNN +from .vanilla_cnn import VanillaCNN from .xml_cnn import XMLCNN diff --git a/libmultilabel/nn/networks/vanilla_cnn.py b/libmultilabel/nn/networks/vanilla_cnn.py new file mode 100644 index 00000000..ddef4153 --- /dev/null +++ b/libmultilabel/nn/networks/vanilla_cnn.py @@ -0,0 +1,43 @@ +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_ + +from ..networks.base import BaseModel + + +class VanillaCNN(BaseModel): + def __init__( + self, + embed_vecs, + num_classes, + filter_sizes=None, + num_filter_per_size=500, + dropout=0.2, + activation='tanh' + ): + super(VanillaCNN, self).__init__(embed_vecs, dropout, activation) + if len(filter_sizes) != 1: + raise ValueError(f'VanillaCNN expect 1 filter size. Got filter_sizes={filter_sizes}') + filter_size = filter_sizes[0] + + num_filter_per_size = num_filter_per_size + + self.conv = nn.Conv1d(embed_vecs.shape[1], num_filter_per_size, kernel_size=filter_size) + xavier_uniform_(self.conv.weight) + + self.linear = nn.Linear(num_filter_per_size, num_classes) + xavier_uniform_(self.linear.weight) + + + def forward(self, input): + h = self.embedding(input['text']) # (batch_size, length, embed_dim) + h = self.embed_drop(h) + h = h.transpose(1, 2) # (batch_size, embed_dim, length) + + h = self.conv(h) # (batch_size, num_filter, length) + h = F.max_pool1d(self.activation(h), kernel_size=h.size()[2]) # (batch_size, num_filter, 1) + h = h.squeeze(dim=2) # batch_size, num_filter + + h = self.linear(h) + return {'logits': h} +