This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Re-write IMDB dataset in torchtext.experimental.datasets #651
Merged

Commits (119)
45d53de  Move PennTreebank, WikiText103, WikiText2 to torchtext.legacy
1f95483  Some initial work.
2d3ebe2  Merge branch 'master' into legacy_language_modeling
97af9d0  Re-write three datasets.
544b069  Merge branch 'master' into legacy_language_modeling
cc127de  Update tests.
97cfd05  Move legacy docs for language modeling dataset.
0ac3e18  Update docs.
56046fa  Minor debug
9962732  Update test.
ad7938e  Minor change in tests.
3ff1cce  Flake8
361f688  Merge branch 'master' into legacy_language_modeling
c2de141  Move imdb to legacy.
cc1ae4d  Move two funct to data/functional.py.
f4018cc  Fix <'unk'> compability issue.
ff329f9  Minor changes.
eb7a6bc  Merge branch 'legacy_language_modeling' into new_imdb
821d7e7  Draft
65c470c  Update unit tests.
d3d6f1b  Bump up version.
b78fc88  Update docs.
aeef855  Minor change.
96cd268  Merge branch 'master' into legacy_language_modeling
25336b9  Minor change
b0c7204  Merge branch 'legacy_language_modeling' into new_imdb
4819f18  Add flags for train/valid/test/
48cb0a8  Update docs.
5a33660  Merge branch 'legacy_language_modeling' into new_imdb
fb1e13a  Move IMDB to text_classification.py
290ef2e  Minor docs.
f49691c  Update docs.
7d70298  Add returned_dataset flag to determin subset data.
885c572  Merge branch 'legacy_language_modeling' into new_imdb
59528f3  Add returned_dataset flag.
aa85215  docs.
0588f1d  A small bug.
f01037d  Remove some printout.
f2ea3f1  Remove unk token.
a32712d  Use data_select.
d217294  Support a string in data_select.
cb902d4  Use torch.tensor instead of torch.Tensor
3a05197  remove duplicate code.
ac99329  Minor change in doc.
3a342c0  Change the extracted_files.
149cbc4  Docs.
6cfe9c9  get_data_path
33f8480  Merge branch 'legacy_language_modeling' into new_imdb
8074e98  revision.
2133a82  docs.
d262da2  Add test.
78c5765  Minor fix.
297d1cc  Remove <unk> token.
c1fc8e7  Merge branch 'legacy_language_modeling' into new_imdb
d548bf6  Replace _data with data.
e77758e  Change create_data_from_iterator to double iter.
6d49f40  Add select_to_index.
1f60293  check subset.
8bb1cb2  Error if dataset is empty.
15095a4  Merge branch 'legacy_language_modeling' into new_imdb
9136678  minor revise.
6a50f2a  filter output is iterable.
fd76dfd  Merge branch 'legacy_language_modeling' into new_imdb
a29f4bd  flake8
26cc92a  Merge branch 'legacy_language_modeling' into new_imdb
3aaaaef  Remove underline.
9206e63  Add a claimer in README.rst
e2ba8bf  revise create_data_from_iterator
0993540  Remove a printout.
1cbc096  Merge branch 'legacy_language_modeling' into new_imdb
81055a0  Remove version num in legacy.
ab04a3b  Merge branch 'legacy_language_modeling' into new_imdb
9dc4752  remove read_text_iterator func
367a340  Update README.
b54b883  Update the test case after not using read_text_iterator
1478d13  rename to numericalize_tokens_from_iterator
675fdc8  Merge branch 'legacy_language_modeling' into new_imdb
f4bfc45  revision.
cf7c188  flake8
03dfc27  minor
3ea2971  Merge branch 'legacy_language_modeling' into new_imdb
bcc9452  resolve comflict
90f4ae3  Add tokenizer to text_classification API
6a7370c  IMDB calls _setup_datasets
acc9bb3  Remove include_unk
b29e87b  materialize tokens at very end.
8e938d2  change version number.
67d33fc  change tokenizer to None.
3aa0ab4  remove tqdm from build_vocab.
e1c1e3b  add data_select option
a4d450f  minor
6143ba6  some changes for pos and neg.
9c87338  Fix typeerror.
caf6ef2  Fix a typo.
f689b46  pass an iterator group to _setup_datasets func.
03d79dc  Fix doc.
bf63dcf  Put tqdm back.
a5aaac3  fix doc
6b5c659  Fix test.
57c5414  minor edit.
ccbcba7  Move new imdb dataset to prototype folder.
eeac195  Move new imdb dataset to prototype folder (part 2).
6cd5017  Move legacy imdb back to the main folder.
cca58bf  add underscore to internal function.
2acb8a1  add _initiate_datasets func
9228a59  update docs
adbda15  add prototype to __init__ file
7be9df2  flake8
50707ae  Flake8
888122c  Remove imdb in legacy
51a3408  add prototype in examples/text_classification
40494bb  docs.
65db243  combine _generate_data_iterators and _generate_imdb_data_iterators fu…
a42ff14  re-name prototype to experimental.
09160bd  Move text classfication datasets back.
55c8a31  Remove text classification datasets from experimental.
92fd553  Remove prototype from examples/text_classification
b26a43d  remove experimental marker from text classification dataset.
367eaf3  split IMDB and text classification into two PRs.
Diff view
torchtext/__init__.py:

```diff
@@ -3,11 +3,13 @@
 from . import utils
 from . import vocab
 from . import legacy
+from . import experimental

 __version__ = '0.4.0'

 __all__ = ['data',
            'datasets',
            'utils',
            'vocab',
-           'legacy']
+           'legacy',
+           'experimental']
```

> **Contributor** (inline on this hunk): You probably need to delete this
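The effect of this hunk can be sanity-checked once the branch is installed; this snippet is illustrative and not part of the diff:

```python
# Illustrative check, not part of the PR: the experimental subpackage
# is now importable and listed in the public namespace.
import torchtext
from torchtext.experimental import datasets

print(torchtext.__all__)  # [..., 'vocab', 'legacy', 'experimental']
```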
torchtext/experimental/__init__.py (new file):

```python
from . import datasets

__all__ = ['datasets']
```
torchtext/experimental/datasets/__init__.py (new file):

```python
from .text_classification import IMDB

__all__ = ['IMDB']
```
torchtext/experimental/datasets/text_classification.py (new file, 142 lines added):

```python
import logging
import torch
import io
from torchtext.utils import download_from_url, extract_archive
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.vocab import Vocab
from torchtext.datasets import TextClassificationDataset

URLS = {
    'IMDB':
        'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
}


def _create_data_from_iterator(vocab, iterator, removed_tokens):
    for cls, tokens in iterator:
        yield cls, iter(map(lambda x: vocab[x],
                        filter(lambda x: x not in removed_tokens, tokens)))


def _imdb_iterator(key, extracted_files, tokenizer, ngrams, yield_cls=False):
    for fname in extracted_files:
        if 'urls' in fname:
            continue
        elif key in fname and ('pos' in fname or 'neg' in fname):
            with io.open(fname, encoding="utf8") as f:
                label = 1 if 'pos' in fname else 0
                if yield_cls:
                    yield label, ngrams_iterator(tokenizer(f.read()), ngrams)
                else:
                    yield ngrams_iterator(tokenizer(f.read()), ngrams)
```
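For reference, `_imdb_iterator` pushes each tokenized review through torchtext's `ngrams_iterator`, which yields the original tokens followed by space-joined n-grams; roughly:

```python
from torchtext.data.utils import ngrams_iterator

# Each token is yielded as-is, then the space-joined bigrams follow.
tokens = ['the', 'movie', 'was', 'great']
print(list(ngrams_iterator(tokens, 2)))
# ['the', 'movie', 'was', 'great', 'the movie', 'movie was', 'was great']
```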
> **Contributor** (on `_generate_data_iterators`): I don't think you need to repeat all of this, but you can use it from the torchtext/datasets/text_classification file
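A minimal sketch of what that suggestion could look like, assuming the private helpers in `torchtext/datasets/text_classification` keep their current names (`_create_data_from_iterator` below is an assumption about that module's internals, not a stable API):

```python
# Sketch only: import the shared pieces instead of redefining them in the
# experimental package. _create_data_from_iterator is a private name, so
# this import is an assumption about the module's internals.
from torchtext.datasets import TextClassificationDataset
from torchtext.datasets.text_classification import _create_data_from_iterator

# Only the IMDB-specific directory walk (_imdb_iterator above) would then
# need to live in this file; vocab building and tensor materialization
# would come from the shared module.
```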
The remainder of the new file:

```python
def _generate_data_iterators(dataset_name, root, ngrams, tokenizer, data_select):
    if not tokenizer:
        tokenizer = get_tokenizer("basic_english")

    if not set(data_select).issubset(set(('train', 'test'))):
        raise TypeError('Given data selection {} is not supported!'.format(data_select))

    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    iters_group = {}
    if 'train' in data_select:
        iters_group['vocab'] = _imdb_iterator('train', extracted_files,
                                              tokenizer, ngrams)
    for item in data_select:
        iters_group[item] = _imdb_iterator(item, extracted_files,
                                           tokenizer, ngrams, yield_cls=True)
    return iters_group


def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None,
                    removed_tokens=[], tokenizer=None,
                    data_select=('train', 'test')):
    if isinstance(data_select, str):
        data_select = [data_select]

    iters_group = _generate_data_iterators(dataset_name, root, ngrams,
                                           tokenizer, data_select)

    if vocab is None:
        if 'vocab' not in iters_group.keys():
            raise TypeError("Must pass a vocab if train is not selected.")
        logging.info('Building Vocab based on train data')
        vocab = build_vocab_from_iterator(iters_group['vocab'])
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))

    data = {}
    for item in iters_group.keys():
        data[item] = {}
        data[item]['data'] = []
        data[item]['labels'] = []
        logging.info('Creating {} data'.format(item))
        data_iter = _create_data_from_iterator(vocab, iters_group[item], removed_tokens)
        for cls, tokens in data_iter:
            data[item]['data'].append((torch.tensor(cls),
                                       torch.tensor([token_id for token_id in tokens])))
            data[item]['labels'].append(cls)
        data[item]['labels'] = set(data[item]['labels'])

    return tuple(TextClassificationDataset(vocab, data[item]['data'],
                                           data[item]['labels']) for item in data_select)


def IMDB(*args, **kwargs):
    """ Defines IMDB datasets.
        The labels include:
            - 0 : Negative
            - 1 : Positive

    Create sentiment analysis dataset: IMDB

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from a string of text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        removed_tokens: tokens removed from the output dataset (Default: [])
        tokenizer: the tokenizer used to preprocess raw text data.
            The default one is the basic_english tokenizer. The spacy
            tokenizer is supported as well. A custom tokenizer is a callable
            function with input of a string and output of a token list.
        data_select: a string or tuple for the returned datasets
            (Default: ('train', 'test'))
            By default, both datasets (train and test) are generated. Users
            could also choose any one or both of them, for example
            ('train', 'test') or just the string 'train'. If 'train' is not
            in the tuple or string, a vocab object should be provided which
            will be used to process the test data.

    Examples:
        >>> from torchtext.experimental.datasets import IMDB
        >>> from torchtext.data.utils import get_tokenizer
        >>> train, test = IMDB(ngrams=3)
        >>> tokenizer = get_tokenizer("spacy")
        >>> train, test = IMDB(tokenizer=tokenizer)
        >>> train, = IMDB(tokenizer=tokenizer, data_select='train')
    """

    return _setup_datasets(*(("IMDB",) + args), **kwargs)


DATASETS = {
    'IMDB': IMDB
}


LABELS = {
    'IMDB': {0: 'Negative',
             1: 'Positive'}
}
```
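To make the intended workflow concrete, here is a usage sketch based on the docstring examples above; the vocab-reuse pattern follows from the `data_select` logic in `_setup_datasets`, with `get_vocab()` being the accessor on `TextClassificationDataset`:

```python
# Usage sketch based on the docstring examples above.
from torchtext.experimental.datasets import IMDB

# Default: build the vocab from the train split, return (train, test).
train, test = IMDB()

# Load only the train split; note the trailing comma, since a 1-tuple is returned.
train, = IMDB(data_select='train')

# Load only the test split. Because 'train' is not selected, no vocab can be
# built, so the one from the train dataset must be passed in explicitly.
test, = IMDB(data_select='test', vocab=train.get_vocab())
```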