Skip to content

Commit

Permalink
Fix serialization bug (OpenNMT#1188)
Browse files Browse the repository at this point in the history
* In dataset base, remove __reduce_ex__ and override __getattr__.
``torchtext.Dataset.__getattr__`` is a generator function, which doesn't
play well with pickle. Overriding it with a regular method that returns
a generator (when appropriate) seems to fix the issue without changing the API.
  • Loading branch information
flauted authored and vince62s committed Jan 21, 2019
1 parent 0185202 commit d94e6af
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions onmt/inputters/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,6 @@ class DatasetBase(Dataset):
the same structure as in the fields argument passed to the constructor.
"""

def __getstate__(self):
    """Return the instance attribute dictionary as the state to pickle."""
    return vars(self)

def __setstate__(self, _d):
    """Restore instance attributes by merging the state ``_d`` in place."""
    instance_attrs = vars(self)
    instance_attrs.update(_d)

def __reduce_ex__(self, proto):
    """Return the pickle reduction of this dataset for protocol ``proto``.

    Delegates to the parent class implementation.
    """
    # This is a hack. Something is broken with torch pickle.
    # ``proto`` has to be forwarded: ``object.__reduce_ex__`` requires the
    # protocol argument and raises ``TypeError`` when called without it.
    return super(DatasetBase, self).__reduce_ex__(proto)

def __init__(self, fields, src_examples_iter, tgt_examples_iter,
filter_pred=None):

Expand Down Expand Up @@ -90,6 +80,15 @@ def __init__(self, fields, src_examples_iter, tgt_examples_iter,

super(DatasetBase, self).__init__(examples, fields, filter_pred)

def __getattr__(self, attr):
    """Lazily collect the values of a field across all examples.

    Python calls this only after normal attribute lookup has failed.

    Args:
        attr: name of the attribute being looked up.

    Returns:
        A generator yielding ``getattr(x, attr)`` for every ``x`` in
        ``self.examples``, when ``attr`` is contained in ``self.fields``.

    Raises:
        AttributeError: if ``fields`` is not yet set on the instance, or
            if ``attr`` is not contained in ``self.fields``.
    """
    missing = "'{}' object has no attribute '{}'".format(
        type(self).__name__, attr)
    # avoid infinite recursion when fields isn't defined: reading
    # ``self.fields`` below would re-enter ``__getattr__``
    # (presumably this happens while pickle rebuilds the object -- confirm)
    if 'fields' not in vars(self):
        raise AttributeError(missing)
    if attr not in self.fields:
        raise AttributeError(missing)
    return (getattr(x, attr) for x in self.examples)

def save(self, path, remove_fields=True):
if remove_fields:
self.fields = []
Expand Down

0 comments on commit d94e6af

Please sign in to comment.