Merged
Changes from all commits
Commits
119 commits
45d53de
Move PennTreebank, WikiText103, WikiText2 to torchtext.legacy
Oct 23, 2019
1f95483
Some initial work.
Oct 25, 2019
2d3ebe2
Merge branch 'master' into legacy_language_modeling
Oct 25, 2019
97af9d0
Re-write three datasets.
Oct 29, 2019
544b069
Merge branch 'master' into legacy_language_modeling
Oct 29, 2019
cc127de
Update tests.
Oct 29, 2019
97cfd05
Move legacy docs for language modeling dataset.
Oct 29, 2019
0ac3e18
Update docs.
Oct 29, 2019
56046fa
Minor debug
Oct 31, 2019
9962732
Update test.
Oct 31, 2019
ad7938e
Minor change in tests.
Oct 31, 2019
3ff1cce
Flake8
Oct 31, 2019
361f688
Merge branch 'master' into legacy_language_modeling
Nov 1, 2019
c2de141
Move imdb to legacy.
Nov 4, 2019
cc1ae4d
Move two funct to data/functional.py.
Nov 5, 2019
f4018cc
Fix <'unk'> compability issue.
Nov 5, 2019
ff329f9
Minor changes.
Nov 5, 2019
eb7a6bc
Merge branch 'legacy_language_modeling' into new_imdb
Nov 5, 2019
821d7e7
Draft
Nov 5, 2019
65c470c
Update unit tests.
Nov 5, 2019
d3d6f1b
Bump up version.
Nov 6, 2019
b78fc88
Update docs.
Nov 6, 2019
aeef855
Minor change.
Nov 6, 2019
96cd268
Merge branch 'master' into legacy_language_modeling
Nov 11, 2019
25336b9
Minor change
Nov 11, 2019
b0c7204
Merge branch 'legacy_language_modeling' into new_imdb
Nov 18, 2019
4819f18
Add flags for train/valid/test/
Nov 18, 2019
48cb0a8
Update docs.
Nov 19, 2019
5a33660
Merge branch 'legacy_language_modeling' into new_imdb
Nov 19, 2019
fb1e13a
Move IMDB to text_classification.py
Nov 19, 2019
290ef2e
Minor docs.
Nov 19, 2019
f49691c
Update docs.
Nov 20, 2019
7d70298
Add returned_dataset flag to determin subset data.
Nov 20, 2019
885c572
Merge branch 'legacy_language_modeling' into new_imdb
Nov 20, 2019
59528f3
Add returned_dataset flag.
Nov 20, 2019
aa85215
docs.
Nov 20, 2019
0588f1d
A small bug.
Nov 20, 2019
f01037d
Remove some printout.
Nov 21, 2019
f2ea3f1
Remove unk token.
Nov 21, 2019
a32712d
Use data_select.
Nov 21, 2019
d217294
Support a string in data_select.
Nov 21, 2019
cb902d4
Use torch.tensor instead of torch.Tensor
Nov 21, 2019
3a05197
remove duplicate code.
Nov 21, 2019
ac99329
Minor change in doc.
Nov 21, 2019
3a342c0
Change the extracted_files.
Nov 21, 2019
149cbc4
Docs.
Nov 21, 2019
6cfe9c9
get_data_path
Nov 21, 2019
33f8480
Merge branch 'legacy_language_modeling' into new_imdb
Nov 21, 2019
8074e98
revision.
Nov 21, 2019
2133a82
docs.
Nov 21, 2019
d262da2
Add test.
Nov 21, 2019
78c5765
Minor fix.
Nov 21, 2019
297d1cc
Remove <unk> token.
Nov 22, 2019
c1fc8e7
Merge branch 'legacy_language_modeling' into new_imdb
Nov 22, 2019
d548bf6
Replace _data with data.
Nov 22, 2019
e77758e
Change create_data_from_iterator to double iter.
Nov 22, 2019
6d49f40
Add select_to_index.
Nov 22, 2019
1f60293
check subset.
Nov 22, 2019
8bb1cb2
Error if dataset is empty.
Nov 22, 2019
15095a4
Merge branch 'legacy_language_modeling' into new_imdb
Nov 22, 2019
9136678
minor revise.
Nov 22, 2019
6a50f2a
filter output is iterable.
Nov 25, 2019
fd76dfd
Merge branch 'legacy_language_modeling' into new_imdb
Nov 25, 2019
a29f4bd
flake8
Nov 25, 2019
26cc92a
Merge branch 'legacy_language_modeling' into new_imdb
Nov 25, 2019
3aaaaef
Remove underline.
Nov 25, 2019
9206e63
Add a claimer in README.rst
Nov 25, 2019
e2ba8bf
revise create_data_from_iterator
Nov 25, 2019
0993540
Remove a printout.
Nov 25, 2019
1cbc096
Merge branch 'legacy_language_modeling' into new_imdb
Nov 25, 2019
81055a0
Remove version num in legacy.
Nov 25, 2019
ab04a3b
Merge branch 'legacy_language_modeling' into new_imdb
Nov 25, 2019
9dc4752
remove read_text_iterator func
Nov 26, 2019
367a340
Update README.
Nov 26, 2019
b54b883
Update the test case after not using read_text_iterator
Nov 26, 2019
1478d13
rename to numericalize_tokens_from_iterator
Nov 26, 2019
675fdc8
Merge branch 'legacy_language_modeling' into new_imdb
Nov 26, 2019
f4bfc45
revision.
Nov 26, 2019
cf7c188
flake8
Nov 26, 2019
03dfc27
minor
Nov 26, 2019
3ea2971
Merge branch 'legacy_language_modeling' into new_imdb
Nov 26, 2019
bcc9452
resolve comflict
Nov 26, 2019
90f4ae3
Add tokenizer to text_classification API
Nov 26, 2019
6a7370c
IMDB calls _setup_datasets
Nov 26, 2019
acc9bb3
Remove include_unk
Nov 26, 2019
b29e87b
materialize tokens at very end.
Nov 26, 2019
8e938d2
change version number.
Nov 26, 2019
67d33fc
change tokenizer to None.
Nov 26, 2019
3aa0ab4
remove tqdm from build_vocab.
Nov 26, 2019
e1c1e3b
add data_select option
Nov 26, 2019
a4d450f
minor
Nov 26, 2019
6143ba6
some changes for pos and neg.
Nov 27, 2019
9c87338
Fix typeerror.
Nov 27, 2019
caf6ef2
Fix a typo.
Nov 27, 2019
f689b46
pass an iterator group to _setup_datasets func.
Nov 27, 2019
03d79dc
Fix doc.
Nov 27, 2019
bf63dcf
Put tqdm back.
Nov 27, 2019
a5aaac3
fix doc
Nov 27, 2019
6b5c659
Fix test.
Nov 27, 2019
57c5414
minor edit.
Nov 30, 2019
ccbcba7
Move new imdb dataset to prototype folder.
Dec 2, 2019
eeac195
Move new imdb dataset to prototype folder (part 2).
Dec 2, 2019
6cd5017
Move legacy imdb back to the main folder.
Dec 2, 2019
cca58bf
add underscore to internal function.
Dec 2, 2019
2acb8a1
add _initiate_datasets func
Dec 2, 2019
9228a59
update docs
Dec 2, 2019
adbda15
add prototype to __init__ file
Dec 2, 2019
7be9df2
flake8
Dec 2, 2019
50707ae
Flake8
Dec 2, 2019
888122c
Remove imdb in legacy
Dec 2, 2019
51a3408
add prototype in examples/text_classification
Dec 2, 2019
40494bb
docs.
Dec 2, 2019
65db243
combine _generate_data_iterators and _generate_imdb_data_iterators fu…
Dec 3, 2019
a42ff14
re-name prototype to experimental.
Dec 3, 2019
09160bd
Move text classfication datasets back.
Dec 3, 2019
55c8a31
Remove text classification datasets from experimental.
Dec 3, 2019
92fd553
Remove prototype from examples/text_classification
Dec 3, 2019
b26a43d
remove experimental marker from text classification dataset.
Dec 3, 2019
367eaf3
split IMDB and text classification into two PRs.
Dec 4, 2019
16 changes: 16 additions & 0 deletions test/data/test_builtin_datasets.py
@@ -104,3 +104,19 @@ def test_text_classification(self):
conditional_remove(datafile)
datafile = os.path.join(self.project_root, ".data", "ag_news_csv.tar.gz")
conditional_remove(datafile)

@slow
def test_imdb(self):
from torchtext.experimental.datasets import IMDB
# smoke test to ensure IMDB works properly
train_dataset, test_dataset = IMDB()
self.assertEqual(len(train_dataset), 25000)
self.assertEqual(len(test_dataset), 25000)

# Delete the dataset after we're done to save disk space on CI
datafile = os.path.join(self.project_root, ".data", "imdb")
conditional_remove(datafile)
datafile = os.path.join(self.project_root, ".data", "aclImdb")
conditional_remove(datafile)
datafile = os.path.join(self.project_root, ".data", "aclImdb_v1.tar.gz")
conditional_remove(datafile)
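
For reference, a standalone sketch that mirrors the smoke test above and can be run outside the test suite (note the first call downloads and extracts the aclImdb_v1.tar.gz archive, roughly 80 MB, into .data):

# Standalone version of the IMDB smoke test above; mirrors its assertions.
from torchtext.experimental.datasets import IMDB

train_dataset, test_dataset = IMDB()
assert len(train_dataset) == 25000
assert len(test_dataset) == 25000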
4 changes: 3 additions & 1 deletion torchtext/__init__.py
@@ -3,11 +3,13 @@
from . import utils
from . import vocab
from . import legacy
from . import experimental

__version__ = '0.4.0'

__all__ = ['data',
'datasets',
'utils',
'vocab',
'legacy']
'legacy',
Review comment (Contributor): You probably need to delete this
'experimental']
3 changes: 3 additions & 0 deletions torchtext/experimental/__init__.py
@@ -0,0 +1,3 @@
from . import datasets

__all__ = ['datasets']
3 changes: 3 additions & 0 deletions torchtext/experimental/datasets/__init__.py
@@ -0,0 +1,3 @@
from .text_classification import IMDB

__all__ = ['IMDB']
142 changes: 142 additions & 0 deletions torchtext/experimental/datasets/text_classification.py
@@ -0,0 +1,142 @@
import logging
import torch
import io
from torchtext.utils import download_from_url, extract_archive
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.vocab import Vocab
from torchtext.datasets import TextClassificationDataset

URLS = {
'IMDB':
'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
}


def _create_data_from_iterator(vocab, iterator, removed_tokens):
for cls, tokens in iterator:
yield cls, iter(map(lambda x: vocab[x],
filter(lambda x: x not in removed_tokens, tokens)))


def _imdb_iterator(key, extracted_files, tokenizer, ngrams, yield_cls=False):
for fname in extracted_files:
if 'urls' in fname:
continue
elif key in fname and ('pos' in fname or 'neg' in fname):
with io.open(fname, encoding="utf8") as f:
label = 1 if 'pos' in fname else 0
if yield_cls:
yield label, ngrams_iterator(tokenizer(f.read()), ngrams)
else:
yield ngrams_iterator(tokenizer(f.read()), ngrams)


def _generate_data_iterators(dataset_name, root, ngrams, tokenizer, data_select):
Review comment (Contributor): I don't think you need to repeat all of this, but you can use it from the torchtext/datasets/text_classification file

if not tokenizer:
tokenizer = get_tokenizer("basic_english")

if not set(data_select).issubset(set(('train', 'test'))):
raise TypeError('Given data selection {} is not supported!'.format(data_select))

dataset_tar = download_from_url(URLS[dataset_name], root=root)
extracted_files = extract_archive(dataset_tar)

iters_group = {}
if 'train' in data_select:
iters_group['vocab'] = _imdb_iterator('train', extracted_files,
tokenizer, ngrams)
for item in data_select:
iters_group[item] = _imdb_iterator(item, extracted_files,
tokenizer, ngrams, yield_cls=True)
return iters_group


def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None,
removed_tokens=[], tokenizer=None,
data_select=('train', 'test')):

if isinstance(data_select, str):
data_select = [data_select]

iters_group = _generate_data_iterators(dataset_name, root, ngrams,
tokenizer, data_select)

if vocab is None:
if 'vocab' not in iters_group.keys():
raise TypeError("Must pass a vocab if train is not selected.")
logging.info('Building Vocab based on train data')
vocab = build_vocab_from_iterator(iters_group['vocab'])
else:
if not isinstance(vocab, Vocab):
raise TypeError("Passed vocabulary is not of type Vocab")
logging.info('Vocab has {} entries'.format(len(vocab)))

data = {}
for item in iters_group.keys():
data[item] = {}
data[item]['data'] = []
data[item]['labels'] = []
logging.info('Creating {} data'.format(item))
data_iter = _create_data_from_iterator(vocab, iters_group[item], removed_tokens)
for cls, tokens in data_iter:
data[item]['data'].append((torch.tensor(cls),
torch.tensor([token_id for token_id in tokens])))
data[item]['labels'].append(cls)
data[item]['labels'] = set(data[item]['labels'])

return tuple(TextClassificationDataset(vocab, data[item]['data'],
data[item]['labels']) for item in data_select)


def IMDB(*args, **kwargs):
""" Defines IMDB datasets.
The labels includes:
- 0 : Negative
- 1 : Positive
Create sentiment analysis dataset: IMDB
Separately returns the training and test dataset
Arguments:
root: Directory where the datasets are saved. Default: ".data"
ngrams: a contiguous sequence of n items from s string text.
Default: 1
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
removed_tokens: removed tokens from output dataset (Default: [])
tokenizer: the tokenizer used to preprocess raw text data.
The default one is basic_english tokenizer in fastText. spacy tokenizer
is supported as well. A custom tokenizer is callable
function with input of a string and output of a token list.
data_select: a string or tuple for the returned datasets
(Default: ('train', 'test'))
By default, all the three datasets (train, test, valid) are generated. Users
could also choose any one or two of them, for example ('train', 'test') or
just a string 'train'. If 'train' is not in the tuple or string, a vocab
object should be provided which will be used to process valid and/or test
data.
Examples:
>>> from torchtext.experimental.datasets import IMDB
>>> from torchtext.data.utils import get_tokenizer
>>> train, test = IMDB(ngrams=3)
>>> tokenizer = get_tokenizer("spacy")
>>> train, test = IMDB(tokenizer=tokenizer)
>>> train, = IMDB(tokenizer=tokenizer, data_select='train')
"""

return _setup_datasets(*(("IMDB",) + args), **kwargs)


DATASETS = {
'IMDB': IMDB
}


LABELS = {
'IMDB': {0: 'Negative',
1: 'Positive'}
}
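
A minimal usage sketch of the data_select/vocab workflow this file introduces. The get_vocab() accessor is assumed to come from the existing TextClassificationDataset in torchtext/datasets; everything else follows the code above:

# Build the vocab from the train split only, then reuse it for test,
# as the docstring requires when 'train' is not selected.
# get_vocab() is assumed to be provided by TextClassificationDataset.
from torchtext.experimental.datasets import IMDB
from torchtext.experimental.datasets.text_classification import LABELS

train, = IMDB(data_select='train')
vocab = train.get_vocab()
test, = IMDB(vocab=vocab, data_select='test')

# Each item is a (label tensor, token-id tensor) pair.
label, text = test[0]
print(LABELS['IMDB'][int(label)])  # 'Negative' or 'Positive'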