1- import torch
2- from collections import defaultdict
31import re
2+ from collections import defaultdict
3+
4+ import torch
45
5- from fastNLP .core .dataset import DataSet
6- from fastNLP .core .vocabulary import Vocabulary
76from fastNLP .core .batch import Batch
7+ from fastNLP .core .dataset import DataSet
88from fastNLP .core .sampler import SequentialSampler
9+ from fastNLP .core .vocabulary import Vocabulary
910
1011
11- class Processor :
12+ class Processor ( object ) :
1213 def __init__ (self , field_name , new_added_field_name ):
1314 self .field_name = field_name
1415 if new_added_field_name is None :
@@ -17,7 +18,7 @@ def __init__(self, field_name, new_added_field_name):
1718 self .new_added_field_name = new_added_field_name
1819
1920 def process (self , * args , ** kwargs ):
20- pass
21+ raise NotImplementedError
2122
2223 def __call__ (self , * args , ** kwargs ):
2324 return self .process (* args , ** kwargs )
@@ -132,27 +133,29 @@ def process(self, dataset):
132133
133134
134135class IndexerProcessor (Processor ):
135- def __init__ (self , vocab , field_name , new_added_field_name , delete_old_field = False ):
136+ def __init__ (self , vocab , field_name , new_added_field_name , delete_old_field = False , is_input = True ):
136137
137138 assert isinstance (vocab , Vocabulary ), "Only Vocabulary class is allowed, not {}." .format (type (vocab ))
138139
139140 super (IndexerProcessor , self ).__init__ (field_name , new_added_field_name )
140141 self .vocab = vocab
141142 self .delete_old_field = delete_old_field
143+ self .is_input = is_input
142144
143145 def set_vocab (self , vocab ):
144146 assert isinstance (vocab , Vocabulary ), "Only Vocabulary class is allowed, not {}." .format (type (vocab ))
145147
146148 self .vocab = vocab
147149
148150 def process (self , dataset ):
149- assert isinstance (dataset , DataSet ), "Only Dataset class is allowed, not {}." .format (type (dataset ))
151+ assert isinstance (dataset , DataSet ), "Only DataSet class is allowed, not {}." .format (type (dataset ))
150152 for ins in dataset :
151153 tokens = ins [self .field_name ]
152154 index = [self .vocab .to_index (token ) for token in tokens ]
153155 ins [self .new_added_field_name ] = index
154156
155- dataset ._set_need_tensor (** {self .new_added_field_name : True })
157+ if self .is_input :
158+ dataset .set_input (self .new_added_field_name )
156159
157160 if self .delete_old_field :
158161 dataset .delete_field (self .field_name )
@@ -161,6 +164,9 @@ def process(self, dataset):
161164
162165
163166class VocabProcessor (Processor ):
167+ """Build vocabulary with a field in the data set.
168+
169+ """
164170 def __init__ (self , field_name ):
165171 super (VocabProcessor , self ).__init__ (field_name , None )
166172 self .vocab = Vocabulary ()
@@ -178,17 +184,20 @@ def get_vocab(self):
178184
179185
class SeqLenProcessor(Processor):
    """Add a field that stores the length of another field for each instance.

    :param field_name: name of the field whose length is measured with ``len``.
    :param new_added_field_name: name passed to ``Processor.__init__`` as the
        field that receives the computed length (default ``'seq_lens'``).
    :param is_input: when true, ``process`` calls ``dataset.set_input`` on the
        newly added field after all lengths have been written.
    """

    def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
        super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
        self.is_input = is_input

    def process(self, dataset):
        """Write ``len(ins[field_name])`` into every instance of ``dataset``.

        :param dataset: the ``DataSet`` to annotate; it is modified in place.
        :return: the same ``dataset`` object that was passed in.
        :raises AssertionError: if ``dataset`` is not a ``DataSet``.
        """
        # Message spells the class name "DataSet", matching IndexerProcessor.
        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            ins[self.new_added_field_name] = len(ins[self.field_name])
        if self.is_input:
            dataset.set_input(self.new_added_field_name)
        return dataset
191199
200+
192201class ModelProcessor (Processor ):
193202 def __init__ (self , model , seq_len_field_name = 'seq_lens' , batch_size = 32 ):
194203 """
@@ -238,6 +247,7 @@ def set_model_device(self, device):
238247 device = torch .device (device )
239248 self .model .to (device )
240249
250+
241251class Index2WordProcessor (Processor ):
242252 def __init__ (self , vocab , field_name , new_added_field_name ):
243253 super (Index2WordProcessor , self ).__init__ (field_name , new_added_field_name )
@@ -251,26 +261,28 @@ def process(self, dataset):
251261
252262
class SetTensorProcessor(Processor):
    """Forward a per-field flag mapping to ``DataSet._set_need_tensor``.

    :param field_dict: mapping from field name to the flag value for that
        field; these entries override ``default``.
    :param default: flag value used for every field of the data set that is
        not listed in ``field_dict``.
    """
    # TODO: remove it. It is strange.

    def __init__(self, field_dict, default=False):
        # This processor works on the whole data set, so no field names are
        # registered with the base class.
        super(SetTensorProcessor, self).__init__(None, None)
        self.default = default
        self.field_dict = field_dict

    def process(self, dataset):
        """Build the flag mapping and apply it to ``dataset``.

        :param dataset: object exposing ``get_all_fields`` and
            ``_set_need_tensor``.
        :return: the same ``dataset`` object that was passed in.
        """
        # Start with every known field set to the default, then let the
        # explicit entries win.
        flags = dict.fromkeys(dataset.get_all_fields().keys(), self.default)
        flags.update(self.field_dict)
        dataset._set_need_tensor(**flags)
        return dataset
264275
265276
class SetIsTargetProcessor(Processor):
    """Forward a per-field flag mapping to ``DataSet.set_target``.

    :param field_dict: mapping from field name to the flag value for that
        field; these entries override ``default``.
    :param default: flag value used for every field of the data set that is
        not listed in ``field_dict``.
    """
    # TODO: remove it.

    def __init__(self, field_dict, default=False):
        # No single source/target field applies here, hence the two ``None``s.
        super(SetIsTargetProcessor, self).__init__(None, None)
        self.default = default
        self.field_dict = field_dict

    def process(self, dataset):
        """Build the flag mapping and pass it to ``dataset.set_target``.

        :param dataset: object exposing ``get_all_fields`` and ``set_target``.
        :return: the same ``dataset`` object that was passed in.
        """
        all_field_names = dataset.get_all_fields().keys()
        target_flags = dict.fromkeys(all_field_names, self.default)
        # Explicitly configured fields take precedence over the default.
        target_flags.update(self.field_dict)
        dataset.set_target(**target_flags)
        return dataset
0 commit comments