3 changes: 1 addition & 2 deletions docs/examples/plot_KimCNN_quickstart.py
@@ -56,7 +56,6 @@
model_name=model_name,
network_config=network_config,
classes=classes,
word_dict=word_dict,
embed_vecs=embed_vecs,
learning_rate=learning_rate,
monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
@@ -66,7 +65,7 @@
# * ``model_name`` leads ``init_model`` function to find a network model.
# * ``network_config`` contains the configurations of a network model.
# * ``classes`` is the label set of the data.
# * ``init_weight``, ``word_dict`` and ``embed_vecs`` are not used on a bert-base model, so we can ignore them.
# * ``embed_vecs`` contains the pre-trained word vectors.
# * ``monitor_metrics`` includes metrics you would like to track.
#
#
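For orientation, a minimal sketch of the ``init_model`` call that the bullets above describe, as it looks once ``word_dict`` is dropped from the signature. The label set, embedding tensor, network configuration, and learning rate are illustrative placeholders, not values taken from the quickstart:

import torch

from libmultilabel.nn.nn_utils import init_model

# Illustrative placeholders; the quickstart builds these from the training data.
classes = ["label_0", "label_1", "label_2"]   # label set of the data
embed_vecs = torch.rand(50000, 300)           # pre-trained word vectors (vocab_size, embed_dim)
network_config = {}                           # fill in the KimCNN hyperparameters from the quickstart config

model = init_model(
    model_name="KimCNN",
    network_config=network_config,
    classes=classes,
    embed_vecs=embed_vecs,
    learning_rate=0.0003,                     # illustrative value
    monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
)
# word_dict is no longer passed to the model; the trainer keeps it and hands it
# to the dataset loader instead (see torch_trainer.py below).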
1 change: 0 additions & 1 deletion docs/examples/plot_bert_quickstart.py
@@ -70,7 +70,6 @@
# * ``model_name`` leads ``init_model`` function to find a network model.
# * ``network_config`` contains the configurations of a network model.
# * ``classes`` is the label set of the data.
# * ``init_weight``, ``word_dict`` and ``embed_vecs`` are not used on a bert-base model, so we can ignore them.
# * ``monitor_metrics`` includes metrics you would like to track.
#
#
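The BERT quickstart's call omits the embedding-related arguments entirely (the removed bullet noted they are unused for a bert-base model). A hedged sketch, reusing the ``lm_weight`` key that torch_trainer.py checks for; the model name and weight identifier are assumptions, not taken from this diff:

from libmultilabel.nn.nn_utils import init_model

classes = ["label_0", "label_1", "label_2"]            # illustrative label set
network_config = {"lm_weight": "bert-base-uncased"}    # assumed weight identifier; lm_weight marks a language-model network

model = init_model(
    model_name="BERT",                                 # assumed model name for the BERT quickstart
    network_config=network_config,
    classes=classes,
    learning_rate=0.00003,                             # illustrative value
    monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
)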
21 changes: 13 additions & 8 deletions libmultilabel/nn/attentionxml.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import logging
import os
import pickle
from functools import partial
from pathlib import Path
from typing import Generator, Sequence, Optional
@@ -33,6 +35,7 @@

class PLTTrainer:
CHECKPOINT_NAME = "model_"
WORD_DICT_NAME = "word_dict.pickle"

def __init__(
self,
@@ -261,7 +264,6 @@ def fit(self, datasets):
model_name="AttentionXML_0",
network_config=self.network_config,
classes=clusters,
word_dict=self.word_dict,
embed_vecs=self.embed_vecs,
init_weight=self.init_weight,
log_path=self.log_path,
@@ -380,7 +382,6 @@ def fit(self, datasets):

model_1 = PLTModel(
classes=self.classes,
word_dict=self.word_dict,
network=network,
log_path=self.log_path,
learning_rate=self.learning_rate,
@@ -427,7 +428,11 @@ def test(self, dataset):
save_k_predictions=self.save_k_predictions,
metrics=self.metrics,
)
self.word_dict = model_1.word_dict

word_dict_path = os.path.join(os.path.dirname(self.get_best_model_path(level=1)), self.WORD_DICT_NAME)
if os.path.exists(word_dict_path):
with open(word_dict_path, "rb") as f:
self.word_dict = pickle.load(f)
classes = model_1.classes

test_x = self.reformat_text(dataset)
@@ -489,9 +494,11 @@ def reformat_text(self, dataset):
# Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length.
encoded_text = list(
map(
lambda text: torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
if text
else torch.tensor([self.word_dict[UNK]], dtype=torch.int64),
lambda text: (
torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
if text
else torch.tensor([self.word_dict[UNK]], dtype=torch.int64)
),
[instance["text"][: self.max_seq_length] for instance in dataset],
)
)
@@ -519,15 +526,13 @@ class PLTModel(Model):
def __init__(
self,
classes,
word_dict,
network,
loss_function="binary_cross_entropy_with_logits",
log_path=None,
**kwargs,
):
super().__init__(
classes=classes,
word_dict=word_dict,
network=network,
loss_function=loss_function,
log_path=log_path,
12 changes: 1 addition & 11 deletions libmultilabel/nn/model.py
@@ -181,27 +181,17 @@ class Model(MultiLabelModel):

Args:
classes (list): List of class names.
word_dict (dict): A dictionary for mapping tokens to indices.
network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN).
loss_function (str, optional): Loss function name (i.e., binary_cross_entropy_with_logits,
cross_entropy). Defaults to 'binary_cross_entropy_with_logits'.
log_path (str): Path to a directory holding the log files and models.
"""

def __init__(
self,
classes,
word_dict,
network,
loss_function="binary_cross_entropy_with_logits",
log_path=None,
**kwargs
):
def __init__(self, classes, network, loss_function="binary_cross_entropy_with_logits", log_path=None, **kwargs):
super().__init__(num_classes=len(classes), log_path=log_path, **kwargs)
self.save_hyperparameters(
ignore=["log_path"]
) # If log_path is saved, loading the checkpoint will cause an error since each experiment has unique log_path (result_dir).
self.word_dict = word_dict
self.classes = classes
self.network = network
self.configure_loss_function(loss_function)
3 changes: 0 additions & 3 deletions libmultilabel/nn/nn_utils.py
@@ -37,7 +37,6 @@ def init_model(
model_name,
network_config,
classes,
word_dict=None,
embed_vecs=None,
init_weight=None,
log_path=None,
@@ -61,7 +60,6 @@
model_name (str): Model to be used such as KimCNN.
network_config (dict): Configuration for defining the network.
classes (list): List of class names.
word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
embed_vecs (torch.Tensor, optional): The pre-trained word vectors of shape
(vocab_size, embed_dim). Defaults to None.
init_weight (str): Weight initialization method from `torch.nn.init`.
@@ -98,7 +96,6 @@

model = Model(
classes=classes,
word_dict=word_dict,
network=network,
log_path=log_path,
learning_rate=learning_rate,
4 changes: 2 additions & 2 deletions tests/nn/components.py
@@ -20,7 +20,7 @@ def get_name(self):
return "token_to_id"

def get_from_trainer(self, trainer):
return trainer.model.word_dict
return trainer.word_dict

def compare(self, a, b):
return a == b
@@ -34,7 +34,7 @@ def get_name(self):
return "embed_vecs"

def get_from_trainer(self, trainer):
return trainer.model.embed_vecs
return trainer.embed_vecs

def compare(self, a, b):
return (a == b).all()
61 changes: 35 additions & 26 deletions torch_trainer.py
@@ -1,5 +1,6 @@
import logging
import os
import pickle

import numpy as np
from lightning.pytorch.callbacks import ModelCheckpoint
@@ -25,6 +26,8 @@ class TorchTrainer:
Defaults to True.
"""

WORD_DICT_NAME = "word_dict.pickle"

def __init__(
self,
config: dict,
@@ -44,6 +47,11 @@ def __init__(
self.device = init_device(use_cpu=config.cpu)
self.config = config

# Set dataset meta info
self.embed_vecs = embed_vecs
self.word_dict = word_dict
self.classes = classes

# Load pretrained tokenizer for dataset loader
self.tokenizer = None
tokenize_text = "lm_weight" not in config.network_config
@@ -69,8 +77,9 @@ def __init__(
# Note that AttentionXML produces two models. checkpoint_path directs to model_1
if config.checkpoint_path is None:
if self.config.embed_file is not None:
logging.info("Load word dictionary ")
word_dict, embed_vecs = data_utils.load_or_build_text_dict(
word_dict_path = os.path.join(self.checkpoint_dir, self.WORD_DICT_NAME)
logging.info(f"Load and cache the word dictionary into {word_dict_path}.")
self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
dataset=self.datasets["train"] + self.datasets["val"],
vocab_file=config.vocab_file,
min_vocab_freq=config.min_vocab_freq,
@@ -79,9 +88,11 @@ def __init__(
normalize_embed=config.normalize_embed,
embed_cache_dir=config.embed_cache_dir,
)
with open(word_dict_path, "wb") as f:
pickle.dump(self.word_dict, f)

if not classes:
classes = data_utils.load_or_build_label(
if not self.classes:
self.classes = data_utils.load_or_build_label(
self.datasets, self.config.label_file, self.config.include_test_labels
)

@@ -98,15 +109,12 @@ def __init__(
f"Add {self.config.val_metric} to `monitor_metrics`."
)
self.config.monitor_metrics += [self.config.val_metric]
self.trainer = PLTTrainer(self.config, classes=classes, embed_vecs=embed_vecs, word_dict=word_dict)
self.trainer = PLTTrainer(
self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict
)
return
self._setup_model(
classes=classes,
word_dict=word_dict,
embed_vecs=embed_vecs,
log_path=self.log_path,
checkpoint_path=config.checkpoint_path,
)

self._setup_model(log_path=self.log_path, checkpoint_path=config.checkpoint_path)
self.trainer = init_trainer(
checkpoint_dir=self.checkpoint_dir,
epochs=config.epochs,
@@ -125,19 +133,13 @@ def __init__(

def _setup_model(
self,
classes: list = None,
word_dict: dict = None,
embed_vecs=None,
log_path: str = None,
checkpoint_path: str = None,
):
"""Setup model from checkpoint if a checkpoint path is passed in or specified in the config.
Otherwise, initialize model from scratch.

Args:
classes(list): List of class names.
word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
log_path (str): Path to the log file. The log file contains the validation
results for each epoch and the test results. If the `log_path` is None, no performance
results will be logged.
@@ -149,11 +151,16 @@ def _setup_model(
if checkpoint_path is not None:
logging.info(f"Loading model from `{checkpoint_path}` with the previously saved hyper-parameter...")
self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
word_dict_path = os.path.join(os.path.dirname(checkpoint_path), self.WORD_DICT_NAME)
if os.path.exists(word_dict_path):
with open(word_dict_path, "rb") as f:
self.word_dict = pickle.load(f)
else:
logging.info("Initialize model from scratch.")
if self.config.embed_file is not None:
logging.info("Load word dictionary ")
word_dict, embed_vecs = data_utils.load_or_build_text_dict(
word_dict_path = os.path.join(self.checkpoint_dir, self.WORD_DICT_NAME)
logging.info(f"Load and cache the word dictionary into {word_dict_path}.")
self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
dataset=self.datasets["train"],
vocab_file=self.config.vocab_file,
min_vocab_freq=self.config.min_vocab_freq,
@@ -162,8 +169,11 @@
normalize_embed=self.config.normalize_embed,
embed_cache_dir=self.config.embed_cache_dir,
)
if not classes:
classes = data_utils.load_or_build_label(
with open(word_dict_path, "wb") as f:
pickle.dump(self.word_dict, f)

if not self.classes:
self.classes = data_utils.load_or_build_label(
self.datasets, self.config.label_file, self.config.include_test_labels
)

@@ -184,9 +194,8 @@ def _setup_model(
self.model = init_model(
model_name=self.config.model_name,
network_config=dict(self.config.network_config),
classes=classes,
word_dict=word_dict,
embed_vecs=embed_vecs,
classes=self.classes,
embed_vecs=self.embed_vecs,
init_weight=self.config.init_weight,
log_path=log_path,
learning_rate=self.config.learning_rate,
@@ -222,7 +231,7 @@ def _get_dataset_loader(self, split, shuffle=False):
batch_size=self.config.batch_size if split == "train" else self.config.eval_batch_size,
shuffle=shuffle,
data_workers=self.config.data_workers,
word_dict=self.model.word_dict,
word_dict=self.word_dict,
tokenizer=self.tokenizer,
add_special_tokens=self.config.add_special_tokens,
)
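Viewed together, the torch_trainer.py changes move the vocabulary off the model object: the trainer builds and keeps ``word_dict`` itself, caches it as ``word_dict.pickle`` beside the checkpoints, and reloads it when a checkpoint path is supplied. A rough sketch of that save/load pattern; the paths and the example vocabulary are hypothetical:

import os
import pickle

WORD_DICT_NAME = "word_dict.pickle"                 # same constant as in torch_trainer.py and attentionxml.py
checkpoint_path = "runs/example/best_model.ckpt"    # hypothetical checkpoint path
word_dict = {"<unk>": 0, "hello": 1}                # hypothetical token-to-index mapping

# Saving: the trainer dumps the vocabulary next to the checkpoints it writes.
checkpoint_dir = os.path.dirname(checkpoint_path)
os.makedirs(checkpoint_dir, exist_ok=True)
with open(os.path.join(checkpoint_dir, WORD_DICT_NAME), "wb") as f:
    pickle.dump(word_dict, f)

# Loading: when resuming from a checkpoint, the trainer looks for the pickle in
# the same directory and keeps its current word_dict if the file is absent.
word_dict_path = os.path.join(checkpoint_dir, WORD_DICT_NAME)
if os.path.exists(word_dict_path):
    with open(word_dict_path, "rb") as f:
        word_dict = pickle.load(f)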