 import csv
 import gc
 import logging
+import os
+import re
 import warnings
+import zipfile
+from urllib.request import urlretrieve
+from collections import Counter, OrderedDict
 
 import pandas as pd
 import torch
 from sklearn.preprocessing import MultiLabelBinarizer
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
-from torchtext.vocab import build_vocab_from_iterator, pretrained_aliases, Vocab
 from tqdm import tqdm
 
 transformers.logging.set_verbosity_error()
 warnings.simplefilter(action="ignore", category=FutureWarning)
 
 UNK = "<unk>"
 PAD = "<pad>"
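+# Names of the pretrained GloVe embeddings mirrored at
+# https://huggingface.co/stanfordnlp/glove (see `_download_glove_embedding` below).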
+GLOVE_WORD_EMBEDDING = {
+    "glove.42B.300d",
+    "glove.840B.300d",
+    "glove.6B.50d",
+    "glove.6B.100d",
+    "glove.6B.200d",
+    "glove.6B.300d",
+}
 
 
 class TextDataset(Dataset):
@@ -31,8 +43,7 @@ class TextDataset(Dataset):
         add_special_tokens (bool, optional): Whether to add the special tokens. Defaults to True.
         tokenizer (transformers.PreTrainedTokenizerBase, optional): HuggingFace's tokenizer of
             the transformer-based pretrained language model. Defaults to None.
-        word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
-            map tokens to indices. Defaults to None.
+        word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
     """
 
     def __init__(
@@ -55,7 +66,7 @@ def __init__(
         self.num_classes = len(self.classes)
         self.label_binarizer = MultiLabelBinarizer().fit([classes])
 
-        if not isinstance(self.word_dict, Vocab) ^ isinstance(self.tokenizer, transformers.PreTrainedTokenizerBase):
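+        # Note: `(not A) ^ B` is equivalent to `not (A ^ B)`, so this raises
+        # unless exactly one of word_dict and tokenizer is provided.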
+        if not isinstance(self.word_dict, dict) ^ isinstance(self.tokenizer, transformers.PreTrainedTokenizerBase):
             raise ValueError("Please specify exactly one of word_dict or tokenizer")
 
     def __len__(self):
@@ -71,7 +82,7 @@ def __getitem__(self, index):
             else:
                 input_ids = self.tokenizer.encode(data["text"], add_special_tokens=False)
         else:
-            input_ids = [self.word_dict[word] for word in data["text"]]
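+            # Fall back to the UNK index for out-of-vocabulary words.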
+            input_ids = [self.word_dict.get(word, self.word_dict[UNK]) for word in data["text"]]
         return {
             "text": torch.LongTensor(input_ids[: self.max_seq_length]),
             "label": torch.IntTensor(self.label_binarizer.transform([data["label"]])[0]),
@@ -128,8 +139,7 @@ def get_dataset_loader(
         add_special_tokens (bool, optional): Whether to add the special tokens. Defaults to True.
         tokenizer (transformers.PreTrainedTokenizerBase, optional): HuggingFace's tokenizer of
             the transformer-based pretrained language model. Defaults to None.
-        word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
-            map tokens to indices. Defaults to None.
+        word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
 
     Returns:
         torch.utils.data.DataLoader: A pytorch DataLoader.
@@ -154,6 +164,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
     Args:
         data (Union[str, pandas.DataFrame]): Training, test, or validation data in file or dataframe.
         is_test (bool, optional): Whether the data is for test or not. Defaults to False.
+        tokenize_text (bool, optional): Whether to tokenize the text. Defaults to True.
         remove_no_label_data (bool, optional): Whether to remove training/validation instances that have no labels.
             This is effective only when is_test=False. Defaults to False.
 
@@ -265,35 +276,34 @@ def load_or_build_text_dict(
 ):
     """Build or load the vocabulary from the training dataset or the predefined `vocab_file`.
     The pretrained embedding can be either from a self-defined `embed_file` or from one of
-    the vectors defined in torchtext.vocab.pretrained_aliases
-    (https://github.com/pytorch/text/blob/main/torchtext/vocab/vectors.py).
+    the pretrained GloVe vectors: `glove.6B.50d`, `glove.6B.100d`, `glove.6B.200d`,
+    `glove.6B.300d`, `glove.42B.300d`, or `glove.840B.300d`.
 
     Args:
         dataset (list): List of training instances with index, label, and tokenized text.
         vocab_file (str, optional): Path to a file holding vocabularies. Defaults to None.
         min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
-        embed_file (str): Path to a file holding pre-trained embeddings.
+        embed_file (str): Path to a file holding pre-trained embeddings or the name of a pretrained GloVe embedding. Defaults to None.
         embed_cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
         silent (bool, optional): Enable silent mode. Defaults to False.
         normalize_embed (bool, optional): Whether the embedding of each word is normalized to a unit vector. Defaults to False.
 
     Returns:
-        tuple[torchtext.vocab.Vocab, torch.Tensor]: A vocab object which maps tokens to indices and the pre-trained word vectors of shape (vocab_size, embed_dim).
+        tuple[dict, torch.Tensor]: A dictionary which maps tokens to indices and the pre-trained word vectors of shape (vocab_size, embed_dim).
     """
     if vocab_file:
         logging.info(f"Load vocab from {vocab_file}")
         with open(vocab_file, "r") as fp:
             vocab_list = [[vocab.strip() for vocab in fp.readlines()]]
         # Keep PAD index 0 to align `padding_idx` of
         # class Embedding in libmultilabel.nn.networks.modules.
-        vocabs = build_vocab_from_iterator(vocab_list, min_freq=1, specials=[PAD, UNK])
+        word_dict = _build_word_dict(vocab_list, min_vocab_freq=1, specials=[PAD, UNK])
     else:
         vocab_list = [set(data["text"]) for data in dataset]
-        vocabs = build_vocab_from_iterator(vocab_list, min_freq=min_vocab_freq, specials=[PAD, UNK])
-    vocabs.set_default_index(vocabs[UNK])
-    logging.info(f"Read {len(vocabs)} vocabularies.")
+        word_dict = _build_word_dict(vocab_list, min_vocab_freq=min_vocab_freq, specials=[PAD, UNK])
+
+    logging.info(f"Read {len(word_dict)} vocabularies.")
 
-    embedding_weights = get_embedding_weights_from_file(vocabs, embed_file, silent, embed_cache_dir)
+    embedding_weights = get_embedding_weights_from_file(word_dict, embed_file, silent, embed_cache_dir)
 
     if normalize_embed:
         # To have better precision for calculating the normalization, we convert the original
@@ -306,7 +316,41 @@ def load_or_build_text_dict(
             embedding_weights[i] = vector / float(torch.linalg.norm(vector) + 1e-6)
         embedding_weights = embedding_weights.float()
 
-    return vocabs, embedding_weights
+    return word_dict, embedding_weights
+
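+# A minimal usage sketch (hypothetical variable name), assuming `train_data` is a list of
+# training instances with tokenized text:
+#   word_dict, embedding_weights = load_or_build_text_dict(train_data, embed_file="glove.6B.300d")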
+
+def _build_word_dict(vocab_list, min_vocab_freq=1, specials=None):
+    r"""Build the word dictionary, modified from `torchtext.vocab.build_vocab_from_iterator`
+    (https://docs.pytorch.org/text/stable/vocab.html#build-vocab-from-iterator).
+
+    Args:
+        vocab_list: An iterable of lists (or sets) of tokens.
+        min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
+        specials: Special tokens (e.g., <unk>, <pad>) to add at the beginning. Defaults to None.
+
+    Returns:
+        dict: A dictionary which maps tokens to indices.
+    """
+    counter = Counter()
+    for tokens in vocab_list:
+        counter.update(tokens)
+
+    # Sort by descending frequency, then lexicographically.
+    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
+    ordered_dict = OrderedDict(sorted_by_freq_tuples)
+
+    # Add special tokens at the beginning; copy `specials` so the caller's list is not mutated.
+    tokens = list(specials) if specials else []
+    for token, freq in ordered_dict.items():
+        if freq >= min_vocab_freq:
+            tokens.append(token)
+
+    # Build the token-to-index dictionary.
+    word_dict = {token: idx for idx, token in enumerate(tokens)}
+    return word_dict
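+
+# For illustration (hypothetical tokens): with specials=[PAD, UNK],
+#   _build_word_dict([["b", "a", "b"]], specials=[PAD, UNK])
+# returns {"<pad>": 0, "<unk>": 1, "b": 2, "a": 3}, since "b" (frequency 2) sorts before "a".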
 
 
 def load_or_build_label(datasets, label_file=None, include_test_labels=False):
@@ -344,70 +388,84 @@ def load_or_build_label(datasets, label_file=None, include_test_labels=False):
     return classes
 
 
-def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache=None):
-    """If the word exists in the embedding file, load the pretrained word embedding.
-    Otherwise, assign a zero vector to that word.
+def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache_dir=None):
+    """Obtain the word embeddings from a file. If a word exists in the embedding file,
+    load its pretrained word embedding; otherwise, assign a zero vector to that word.
+    If the given `embed_file` is the name of a pretrained GloVe embedding, the function
+    first downloads the corresponding file.
 
     Args:
-        word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
-        embed_file (str): Path to a file holding pre-trained embeddings.
+        word_dict (dict): A dictionary for mapping tokens to indices.
+        embed_file (str): Path to a file holding pre-trained embeddings or the name of a pretrained GloVe embedding.
         silent (bool, optional): Enable silent mode. Defaults to False.
-        cache (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
+        cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
 
     Returns:
         torch.Tensor: Embedding weights (vocab_size, embed_size).
     """
-    # Load pretrained word embedding
-    load_embedding_from_file = embed_file not in pretrained_aliases
-    if load_embedding_from_file:
-        logging.info(f"Load pretrained embedding from file: {embed_file}.")
-        with open(embed_file) as f:
-            word_vectors = f.readlines()
-        embed_size = len(word_vectors[0].split()) - 1
-        vector_dict = {}
-        for word_vector in tqdm(word_vectors, disable=silent):
-            word, vector = word_vector.rstrip().split(" ", 1)
-            vector = torch.Tensor(list(map(float, vector.split())))
-            vector_dict[word] = vector
-    else:
-        logging.info(f"Load pretrained embedding from torchtext.")
-        # Adapted from https://pytorch.org/text/0.9.0/_modules/torchtext/vocab.html#Vocab.load_vectors.
-        if embed_file not in pretrained_aliases:
-            raise ValueError(
-                "Got embed_file {}, but allowed pretrained "
-                "vectors are {}".format(embed_file, list(pretrained_aliases.keys()))
-            )
-
-        # Hotfix: GloVe URLs are outdated in torchtext
-        # (https://github.com/pytorch/text/blob/main/torchtext/vocab/vectors.py#L213-L217)
-        pretrained_cls = pretrained_aliases[embed_file]
-        if embed_file.startswith("glove"):
-            for name, url in pretrained_cls.func.url.items():
-                file_name = url.split("/")[-1]
-                pretrained_cls.func.url[name] = f"https://huggingface.co/stanfordnlp/glove/resolve/main/{file_name}"
-
-        vector_dict = pretrained_cls(cache=cache)
-        embed_size = vector_dict.dim
 
-    embedding_weights = torch.zeros(len(word_dict), embed_size)
+    if embed_file in GLOVE_WORD_EMBEDDING:
+        embed_file = _download_glove_embedding(embed_file, cache_dir=cache_dir)
+    elif not os.path.isfile(embed_file):
+        raise ValueError(
+            "Got embed_file {}, but allowed pretrained embeddings are {}".format(embed_file, GLOVE_WORD_EMBEDDING)
+        )
+
+    logging.info(f"Load pretrained embedding from {embed_file}.")
+    with open(embed_file) as f:
+        word_vectors = f.readlines()
+    embed_size = len(word_vectors[0].split()) - 1
 
-    if load_embedding_from_file:
-        # Add UNK embedding
-        # AttentionXML: np.random.uniform(-1.0, 1.0, embed_size)
-        # CAML: np.random.randn(embed_size)
-        unk_vector = torch.randn(embed_size)
-        embedding_weights[word_dict[UNK]] = unk_vector
+    vector_dict = {}
+    for word_vector in tqdm(word_vectors, disable=silent):
+        word, vector = word_vector.rstrip().split(" ", 1)
+        vector = torch.Tensor(list(map(float, vector.split())))
+        vector_dict[word] = vector
+
+    embedding_weights = torch.zeros(len(word_dict), embed_size)
+    # Add UNK embedding
+    # AttentionXML: np.random.uniform(-1.0, 1.0, embed_size)
+    # CAML: np.random.randn(embed_size)
+    unk_vector = torch.randn(embed_size)
+    embedding_weights[word_dict[UNK]] = unk_vector
 
     # Store pretrained word embedding
     vec_counts = 0
-    for word in word_dict.get_itos():
-        # The condition can be used to process the word that is not in the embedding file.
-        # Note that the torchtext vector object has already dealt with this,
-        # so we can directly make a query without additional handling.
-        if (load_embedding_from_file and word in vector_dict) or not load_embedding_from_file:
+    for word in word_dict.keys():
+        if word in vector_dict:
             embedding_weights[word_dict[word]] = vector_dict[word]
             vec_counts += 1
 
-    logging.info(f"loaded {vec_counts}/{len(word_dict)} word embeddings")
+    logging.info(f"Loaded {vec_counts}/{len(word_dict)} word embeddings")
 
     return embedding_weights
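+
+# A minimal usage sketch (hypothetical file name), assuming each line of the embedding
+# file has the form "<word> <v1> <v2> ... <vn>":
+#   weights = get_embedding_weights_from_file({"<pad>": 0, "<unk>": 1, "cat": 2}, "custom.50d.txt")
+#   weights.shape  # torch.Size([3, 50])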
+
+
+def _download_glove_embedding(embed_name, cache_dir=None):
+    """Download the pretrained GloVe embedding from https://huggingface.co/stanfordnlp/glove/tree/main.
+
+    Args:
+        embed_name (str): The name of the pretrained GloVe embedding.
+        cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
+
+    Returns:
+        str: Path to the file that contains the cached embeddings.
+    """
+    cache_dir = ".vector_cache" if cache_dir is None else cache_dir
+    cached_embed_file = f"{cache_dir}/{embed_name}.txt"
+    if os.path.isfile(cached_embed_file):
+        return cached_embed_file
+    os.makedirs(cache_dir, exist_ok=True)
+
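+    # The glove.6B.* vectors share a single archive (glove.6B.zip); the 42B and
+    # 840B vectors each have their own zip file.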
+    remote_embed_file = re.sub(r"6B.*", "6B", embed_name) + ".zip"
+    url = f"https://huggingface.co/stanfordnlp/glove/resolve/main/{remote_embed_file}"
+    logging.info(f"Downloading pretrained embeddings from {url}.")
+    zip_file = f"{cache_dir}/{remote_embed_file}"
+    try:
+        urlretrieve(url, zip_file)
+        with zipfile.ZipFile(zip_file, "r") as zf:
+            zf.extractall(cache_dir)
+    except Exception:
+        # Clean up a partially downloaded archive before re-raising.
+        if os.path.isfile(zip_file):
+            os.remove(zip_file)
+        raise
+    logging.info(f"Downloaded pretrained embeddings {embed_name} to {cached_embed_file}.")
+    return cached_embed_file
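+
+# For illustration: _download_glove_embedding("glove.6B.100d") fetches glove.6B.zip,
+# extracts it into the cache directory, and returns ".vector_cache/glove.6B.100d.txt".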