4 changes: 2 additions & 2 deletions README.md
@@ -10,8 +10,8 @@ This is an on-going development so many improvements are still being made. Comme

## Environments
- Python: 3.10+
- CUDA: 11.8, 12.1 (if training neural networks by GPU)
- Pytorch: 2.0.1+
- CUDA: 11.8, 12.1, 12.6 (if training neural networks by GPU)
- Pytorch: 2.3.0+

If you have a different version of CUDA, follow the installation instructions for PyTorch LTS at their [website](https://pytorch.org/).

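A quick sanity check that a local setup matches these requirements (a minimal sketch; it only assumes PyTorch is already installed):

```python
# Verify Python, PyTorch, and CUDA versions against the README requirements.
import sys
import torch

print(sys.version_info >= (3, 10))  # Python 3.10+
print(torch.__version__)            # expect 2.3.0 or newer
print(torch.version.cuda)           # e.g. "11.8", "12.1", or "12.6"
print(torch.cuda.is_available())    # True when training neural networks by GPU
```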
2 changes: 1 addition & 1 deletion docs/cli/nn.rst
@@ -77,7 +77,7 @@ If a model was trained before by this package, the training procedure can start

To use your own word embeddings or vocabulary set, specify the following parameters:

- **embed_file**: choose one of the pretrained embeddings defined in `torchtext <https://pytorch.org/text/0.9.0/vocab.html#torchtext.vocab.Vocab.load_vectors>`_ or specify the path to your word embeddings with each line containing a word followed by its vectors. Example:
- **embed_file**: choose one of the pretrained embeddings: `glove.6B.50d`, `glove.6B.100d`, `glove.6B.200d`, `glove.6B.300d`, `glove.42B.300d`, `glove.840B.300d`, or specify the path to your word embeddings with each line containing a word followed by its vectors. Example:

.. code-block::

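A small sketch of the expected embedding-file layout and how one line is parsed (the file name and numbers are made up; each line is a word followed by its whitespace-separated vector):

```python
# my_embeddings.txt (hypothetical) looks like:
#   the 0.41 0.12 -0.07 0.22
#   cat -0.33 0.95 0.18 0.04
with open("my_embeddings.txt") as f:
    for line in f:
        word, vector = line.rstrip().split(" ", 1)
        values = [float(v) for v in vector.split()]
        # every line should have the same number of values (the embedding dimension)
        print(word, len(values))
```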
2 changes: 1 addition & 1 deletion docs/examples/plot_KimCNN_quickstart.py
@@ -32,7 +32,7 @@
# To run KimCNN, LibMultiLabel tokenizes documents and uses an embedding vector for each word.
# Thus, ``tokenize_text=True`` is set.
#
# We choose ``glove.6B.300d`` from torchtext as embedding vectors.
# We choose ``glove.6B.300d`` as embedding vectors.

datasets = load_datasets("data/rcv1/train.txt", "data/rcv1/test.txt", tokenize_text=True)
classes = load_or_build_label(datasets)
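A hedged sketch of how the quickstart typically continues from here, assuming the keyword names follow the `load_or_build_text_dict` docstring shown later in this diff (the GloVe file is downloaded and cached on first use):

```python
from libmultilabel.nn.data_utils import load_or_build_text_dict

# Build the vocabulary from the tokenized training split and load glove.6B.300d weights.
word_dict, embed_vecs = load_or_build_text_dict(
    dataset=datasets["train"],
    embed_file="glove.6B.300d",
)
print(len(word_dict), embed_vecs.shape)  # vocab size and (vocab_size, 300)
```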
2 changes: 1 addition & 1 deletion libmultilabel/nn/attentionxml.py
@@ -489,7 +489,7 @@ def reformat_text(self, dataset):
# Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length.
encoded_text = list(
map(
lambda text: torch.tensor([self.word_dict[word] for word in text], dtype=torch.int64)
lambda text: torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
if text
else torch.tensor([self.word_dict[UNK]], dtype=torch.int64),
[instance["text"][: self.max_seq_length] for instance in dataset],
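The `self.word_dict.get(word, self.word_dict[UNK])` lookup above is the plain-dict replacement for torchtext's default index; a toy illustration with a made-up vocabulary:

```python
import torch

UNK, PAD = "<unk>", "<pad>"
word_dict = {PAD: 0, UNK: 1, "deep": 2, "learning": 3}  # hypothetical vocabulary

text = ["deep", "learning", "rocks"]  # "rocks" is out of vocabulary
ids = torch.tensor([word_dict.get(w, word_dict[UNK]) for w in text], dtype=torch.int64)
print(ids)  # tensor([2, 3, 1]) -- unknown words fall back to the <unk> index
```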
192 changes: 125 additions & 67 deletions libmultilabel/nn/data_utils.py
@@ -1,7 +1,12 @@
import csv
import gc
import logging
import os
import re
import warnings
import zipfile
from urllib.request import urlretrieve
from collections import Counter, OrderedDict

import pandas as pd
import torch
@@ -11,14 +16,21 @@
from sklearn.preprocessing import MultiLabelBinarizer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torchtext.vocab import build_vocab_from_iterator, pretrained_aliases, Vocab
from tqdm import tqdm

transformers.logging.set_verbosity_error()
warnings.simplefilter(action="ignore", category=FutureWarning)

UNK = "<unk>"
PAD = "<pad>"
GLOVE_WORD_EMBEDDING = {
"glove.42B.300d",
"glove.840B.300d",
"glove.6B.50d",
"glove.6B.100d",
"glove.6B.200d",
"glove.6B.300d",
}


class TextDataset(Dataset):
@@ -31,8 +43,7 @@ class TextDataset(Dataset):
add_special_tokens (bool, optional): Whether to add the special tokens. Defaults to True.
tokenizer (transformers.PreTrainedTokenizerBase, optional): HuggingFace's tokenizer of
the transformer-based pretrained language model. Defaults to None.
word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
map tokens to indices. Defaults to None.
word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
"""

def __init__(
@@ -55,7 +66,7 @@ def __init__(
self.num_classes = len(self.classes)
self.label_binarizer = MultiLabelBinarizer().fit([classes])

if not isinstance(self.word_dict, Vocab) ^ isinstance(self.tokenizer, transformers.PreTrainedTokenizerBase):
if not isinstance(self.word_dict, dict) ^ isinstance(self.tokenizer, transformers.PreTrainedTokenizerBase):
raise ValueError("Please specify exactly one of word_dict or tokenizer")

def __len__(self):
@@ -71,7 +82,7 @@ def __getitem__(self, index):
else:
input_ids = self.tokenizer.encode(data["text"], add_special_tokens=False)
else:
input_ids = [self.word_dict[word] for word in data["text"]]
input_ids = [self.word_dict.get(word, self.word_dict[UNK]) for word in data["text"]]
return {
"text": torch.LongTensor(input_ids[: self.max_seq_length]),
"label": torch.IntTensor(self.label_binarizer.transform([data["label"]])[0]),
@@ -128,8 +139,7 @@ def get_dataset_loader(
add_special_tokens (bool, optional): Whether to add the special tokens. Defaults to True.
tokenizer (transformers.PreTrainedTokenizerBase, optional): HuggingFace's tokenizer of
the transformer-based pretrained language model. Defaults to None.
word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
map tokens to indices. Defaults to None.
word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.

Returns:
torch.utils.data.DataLoader: A pytorch DataLoader.
@@ -154,6 +164,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
Args:
data (Union[str, pandas.DataFrame]): Training, test, or validation data in file or dataframe.
is_test (bool, optional): Whether the data is for test or not. Defaults to False.
tokenize_text (bool, optional): Whether to tokenize text. Defaults to True.
remove_no_label_data (bool, optional): Whether to remove training/validation instances that have no labels.
This is effective only when is_test=False. Defaults to False.

@@ -265,35 +276,34 @@ def load_or_build_text_dict(
):
"""Build or load the vocabulary from the training dataset or the predefined `vocab_file`.
The pretrained embedding can be either from a self-defined `embed_file` or from one of
the vectors defined in torchtext.vocab.pretrained_aliases
(https://github.com/pytorch/text/blob/main/torchtext/vocab/vectors.py).
the vectors: `glove.6B.50d`, `glove.6B.100d`, `glove.6B.200d`, `glove.6B.300d`, `glove.42B.300d`, or `glove.840B.300d`.

Args:
dataset (list): List of training instances with index, label, and tokenized text.
vocab_file (str, optional): Path to a file holding vocabularies. Defaults to None.
min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
embed_file (str): Path to a file holding pre-trained embeddings.
embed_file (str): Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding. Defaults to None.
embed_cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
silent (bool, optional): Enable silent mode. Defaults to False.
normalize_embed (bool, optional): Whether the embedding of each word is normalized to a unit vector. Defaults to False.

Returns:
tuple[torchtext.vocab.Vocab, torch.Tensor]: A vocab object which maps tokens to indices and the pre-trained word vectors of shape (vocab_size, embed_dim).
tuple[dict, torch.Tensor]: A dictionary which maps tokens to indices and the pre-trained word vectors of shape (vocab_size, embed_dim).
"""
if vocab_file:
logging.info(f"Load vocab from {vocab_file}")
with open(vocab_file, "r") as fp:
vocab_list = [[vocab.strip() for vocab in fp.readlines()]]
# Keep PAD index 0 to align `padding_idx` of
# class Embedding in libmultilabel.nn.networks.modules.
vocabs = build_vocab_from_iterator(vocab_list, min_freq=1, specials=[PAD, UNK])
word_dict = _build_word_dict(vocab_list, min_vocab_freq=1, specials=[PAD, UNK])
else:
vocab_list = [set(data["text"]) for data in dataset]
vocabs = build_vocab_from_iterator(vocab_list, min_freq=min_vocab_freq, specials=[PAD, UNK])
vocabs.set_default_index(vocabs[UNK])
logging.info(f"Read {len(vocabs)} vocabularies.")
word_dict = _build_word_dict(vocab_list, min_vocab_freq=min_vocab_freq, specials=[PAD, UNK])

logging.info(f"Read {len(word_dict)} vocabularies.")

embedding_weights = get_embedding_weights_from_file(vocabs, embed_file, silent, embed_cache_dir)
embedding_weights = get_embedding_weights_from_file(word_dict, embed_file, silent, embed_cache_dir)

if normalize_embed:
# To have better precision for calculating the normalization, we convert the original
@@ -306,7 +316,41 @@
embedding_weights[i] = vector / float(torch.linalg.norm(vector) + 1e-6)
embedding_weights = embedding_weights.float()

return vocabs, embedding_weights
return word_dict, embedding_weights
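A usage sketch for the `vocab_file` branch above, assuming the keyword names follow the docstring (both file paths are hypothetical):

```python
from libmultilabel.nn.data_utils import load_or_build_text_dict

word_dict, embed_vecs = load_or_build_text_dict(
    dataset=datasets["train"],         # unused when vocab_file is given, per the branch above
    vocab_file="data/rcv1/vocab.txt",  # one token per line; <pad> stays at index 0
    embed_file="my_embeddings.txt",    # "<word> v1 v2 ... vd" per line
    normalize_embed=True,              # rescale each embedding to unit norm
)
```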


def _build_word_dict(vocab_list, min_vocab_freq=1, specials=None):
r"""Build word dictionary, modified from `torchtext.vocab.build-vocab-from-iterator`
(https://docs.pytorch.org/text/stable/vocab.html#build-vocab-from-iterator)

Args:
vocab_list (list): A list of token lists (or sets), one per training instance.
min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
specials (list, optional): Special tokens (e.g., <unk>, <pad>) placed at the beginning of the vocabulary. Defaults to None.

Returns:
dict: A dictionary which maps tokens to indices.
"""

counter = Counter()
for tokens in vocab_list:
counter.update(tokens)

# sort by descending frequency, then lexicographically
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
ordered_dict = OrderedDict(sorted_by_freq_tuples)

# add special tokens at the beginning
tokens = specials or []
for token, freq in ordered_dict.items():
if freq >= min_vocab_freq:
tokens.append(token)

# build token to indices dict
word_dict = dict()
for idx, token in enumerate(tokens):
word_dict[token] = idx
return word_dict
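A toy run of `_build_word_dict` to show the resulting ordering, i.e. specials first, then tokens sorted by descending frequency with lexicographic tie-breaking (the helper is private; it is called directly here only for illustration):

```python
from libmultilabel.nn.data_utils import _build_word_dict

vocab_list = [["cat", "sat", "cat"], ["dog", "sat"]]  # toy tokenized instances
word_dict = _build_word_dict(vocab_list, min_vocab_freq=1, specials=["<pad>", "<unk>"])
print(word_dict)
# {'<pad>': 0, '<unk>': 1, 'cat': 2, 'sat': 3, 'dog': 4}
```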


def load_or_build_label(datasets, label_file=None, include_test_labels=False):
@@ -344,70 +388,84 @@ def load_or_build_label(datasets, label_file=None, include_test_labels=False):
return classes


def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache=None):
"""If the word exists in the embedding file, load the pretrained word embedding.
Otherwise, assign a zero vector to that word.
def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache_dir=None):
"""Obtain the word embeddings from file. If the word exists in the embedding file,
load the pretrained word embedding. Otherwise, assign a zero vector to that word.
If the given `embed_file` is the name of a pretrained GloVe embedding, the function
will first download the corresponding file.

Args:
word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
embed_file (str): Path to a file holding pre-trained embeddings.
word_dict (dict): A dictionary for mapping tokens to indices.
embed_file (str): Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding.
silent (bool, optional): Enable silent mode. Defaults to False.
cache (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.

Returns:
torch.Tensor: Embedding weights (vocab_size, embed_size).
"""
# Load pretrained word embedding
load_embedding_from_file = embed_file not in pretrained_aliases
if load_embedding_from_file:
logging.info(f"Load pretrained embedding from file: {embed_file}.")
with open(embed_file) as f:
word_vectors = f.readlines()
embed_size = len(word_vectors[0].split()) - 1
vector_dict = {}
for word_vector in tqdm(word_vectors, disable=silent):
word, vector = word_vector.rstrip().split(" ", 1)
vector = torch.Tensor(list(map(float, vector.split())))
vector_dict[word] = vector
else:
logging.info(f"Load pretrained embedding from torchtext.")
# Adapted from https://pytorch.org/text/0.9.0/_modules/torchtext/vocab.html#Vocab.load_vectors.
if embed_file not in pretrained_aliases:
raise ValueError(
"Got embed_file {}, but allowed pretrained "
"vectors are {}".format(embed_file, list(pretrained_aliases.keys()))
)

# Hotfix: Glove URLs are outdated in Torchtext
# (https://github.com/pytorch/text/blob/main/torchtext/vocab/vectors.py#L213-L217)
pretrained_cls = pretrained_aliases[embed_file]
if embed_file.startswith("glove"):
for name, url in pretrained_cls.func.url.items():
file_name = url.split("/")[-1]
pretrained_cls.func.url[name] = f"https://huggingface.co/stanfordnlp/glove/resolve/main/{file_name}"

vector_dict = pretrained_cls(cache=cache)
embed_size = vector_dict.dim

embedding_weights = torch.zeros(len(word_dict), embed_size)
if embed_file in GLOVE_WORD_EMBEDDING:
embed_file = _download_glove_embedding(embed_file, cache_dir=cache_dir)
elif not os.path.isfile(embed_file):
raise ValueError(
"Got embed_file {}, but allowed pretrained " "embeddings are {}".format(embed_file, GLOVE_WORD_EMBEDDING)
)

logging.info(f"Load pretrained embedding from {embed_file}.")
with open(embed_file) as f:
word_vectors = f.readlines()
embed_size = len(word_vectors[0].split()) - 1

if load_embedding_from_file:
# Add UNK embedding
# AttentionXML: np.random.uniform(-1.0, 1.0, embed_size)
# CAML: np.random.randn(embed_size)
unk_vector = torch.randn(embed_size)
embedding_weights[word_dict[UNK]] = unk_vector
vector_dict = {}
for word_vector in tqdm(word_vectors, disable=silent):
word, vector = word_vector.rstrip().split(" ", 1)
vector = torch.Tensor(list(map(float, vector.split())))
vector_dict[word] = vector

embedding_weights = torch.zeros(len(word_dict), embed_size)
# Add UNK embedding
# AttentionXML: np.random.uniform(-1.0, 1.0, embed_size)
# CAML: np.random.randn(embed_size)
unk_vector = torch.randn(embed_size)
embedding_weights[word_dict[UNK]] = unk_vector

# Store pretrained word embedding
vec_counts = 0
for word in word_dict.get_itos():
# The condition can be used to process the word that does not in the embedding file.
# Note that torchtext vector object has already dealt with this,
# so we can directly make a query without addtional handling.
if (load_embedding_from_file and word in vector_dict) or not load_embedding_from_file:
for word in word_dict.keys():
if word in vector_dict:
embedding_weights[word_dict[word]] = vector_dict[word]
vec_counts += 1

logging.info(f"loaded {vec_counts}/{len(word_dict)} word embeddings")
logging.info(f"Loaded {vec_counts}/{len(word_dict)} word embeddings")

return embedding_weights
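A small end-to-end check of the lookup behavior above, using a throwaway two-dimensional embedding file (path and values are made up):

```python
from libmultilabel.nn.data_utils import PAD, UNK, get_embedding_weights_from_file

with open("tiny_embeddings.txt", "w") as f:
    f.write("cat 0.1 0.2\ndog 0.3 0.4\n")

word_dict = {PAD: 0, UNK: 1, "cat": 2, "bird": 3}
weights = get_embedding_weights_from_file(word_dict, "tiny_embeddings.txt", silent=True)
print(weights.shape)  # torch.Size([4, 2])
print(weights[2])     # tensor([0.1000, 0.2000]) -- "cat" found in the file
print(weights[3])     # tensor([0., 0.])         -- "bird" missing, stays zero
# weights[1] (<unk>) is a random vector; weights[0] (<pad>) stays zero.
```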


def _download_glove_embedding(embed_name, cache_dir=None):
"""Download pretrained glove embedding from https://huggingface.co/stanfordnlp/glove/tree/main.

Args:
embed_name (str): The name of the pretrained GloVe embedding.
cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.

Returns:
str: Path to the file that contains the cached embeddings.
"""
cache_dir = ".vector_cache" if cache_dir is None else cache_dir
cached_embed_file = f"{cache_dir}/{embed_name}.txt"
if os.path.isfile(cached_embed_file):
return cached_embed_file
os.makedirs(cache_dir, exist_ok=True)

remote_embed_file = re.sub(r"6B.*", "6B", embed_name) + ".zip"
url = f"https://huggingface.co/stanfordnlp/glove/resolve/main/{remote_embed_file}"
logging.info(f"Downloading pretrained embeddings from {url}.")
try:
zip_file, _ = urlretrieve(url, f"{cache_dir}/{remote_embed_file}")
with zipfile.ZipFile(zip_file, "r") as zf:
zf.extractall(cache_dir)
except Exception as e:
os.remove(zip_file)
raise e
logging.info(f"Downloaded pretrained embeddings {embed_name} to {cached_embed_file}.")
return cached_embed_file
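For the GloVe aliases, the helper above resolves the name to a cached text file; a sketch of the cache behavior (the private helper is called directly only to illustrate where the file ends up):

```python
from libmultilabel.nn.data_utils import _download_glove_embedding

# First call downloads e.g. glove.6B.zip from the Hugging Face mirror and extracts it;
# subsequent calls return the cached .txt immediately.
path = _download_glove_embedding("glove.6B.300d", cache_dir=".vector_cache")
print(path)  # .vector_cache/glove.6B.300d.txt
```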
2 changes: 1 addition & 1 deletion libmultilabel/nn/model.py
@@ -181,7 +181,7 @@ class Model(MultiLabelModel):

Args:
classes (list): List of class names.
word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
word_dict (dict): A dictionary for mapping tokens to indices.
network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN).
loss_function (str, optional): Loss function name (i.e., binary_cross_entropy_with_logits,
cross_entropy). Defaults to 'binary_cross_entropy_with_logits'.
3 changes: 1 addition & 2 deletions libmultilabel/nn/nn_utils.py
@@ -61,8 +61,7 @@ def init_model(
model_name (str): Model to be used such as KimCNN.
network_config (dict): Configuration for defining the network.
classes (list): List of class names.
word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
map tokens to indices. Defaults to None.
word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
embed_vecs (torch.Tensor, optional): The pre-trained word vectors of shape
(vocab_size, embed_dim). Defaults to None.
init_weight (str): Weight initialization method from `torch.nn.init`.
2 changes: 1 addition & 1 deletion main.py
@@ -141,7 +141,7 @@ def add_all_arguments(parser):
# pretrained vocab / embeddings
parser.add_argument("--vocab_file", type=str, help="Path to a file holding vocabuaries (default: %(default)s)")
parser.add_argument(
"--embed_file", type=str, help="Path to a file holding pre-trained embeddings (default: %(default)s)"
"--embed_file", type=str, help="Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding (default: %(default)s)"
)
parser.add_argument("--label_file", type=str, help="Path to a file holding all labels (default: %(default)s)")

6 changes: 2 additions & 4 deletions requirements_nn.txt
@@ -1,8 +1,6 @@
nltk
lightning
# https://github.com/pytorch/text/releases
torch<=2.3
torch
torchmetrics==0.10.3
torchtext
# https://github.com/huggingface/transformers/issues/38464
transformers<=4.51.3
transformers
4 changes: 2 additions & 2 deletions search_params.py
@@ -25,8 +25,8 @@ def train_libmultilabel_tune(config, datasets, classes, word_dict):
Args:
config (dict): Config of the experiment.
datasets (dict): A dictionary of datasets.
classes(list): List of class names.
word_dict(torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
classes (list): List of class names.
word_dict (dict): A dictionary for mapping tokens to indices.
"""

# Ray converts AttributeDict to dict