diff --git a/pytext/data/tensorizers.py b/pytext/data/tensorizers.py
index ce3e10ab2..e07168e8b 100644
--- a/pytext/data/tensorizers.py
+++ b/pytext/data/tensorizers.py
@@ -2,6 +2,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
 import json
+import re
 from typing import List, Optional, Tuple, Type
 
 import torch
@@ -226,10 +227,6 @@ def tensorize(self, batch):
         return (pad_and_tensorize(characters), pad_and_tensorize(lengths))
 
 
-class FloatVectorTensorizer(Tensorizer):
-    """TODO: support for dense features."""
-
-
 class LabelTensorizer(Tensorizer):
     """Numberize labels."""
 
@@ -329,7 +326,10 @@ def __init__(self, column: str):
         self.column = column
 
     def numberize(self, row):
-        res = json.loads(row[self.column].replace(" ", ","))
+        # Normalize separators before JSON parsing: an optional comma followed
+        # by one or more spaces becomes a single comma, so "0.1 0.2", "0.1, 0.2"
+        # and "0.1,  0.2" all parse. Spaces next to a bracket are not handled.
+        res = json.loads(re.sub(r",? +", ",", row[self.column]))
         if type(res) is not list:
             raise ValueError(f"{res} is not a valid float list")
         return [float(n) for n in res]
diff --git a/pytext/data/test/tensorizers_test.py b/pytext/data/test/tensorizers_test.py
index c738e65b7..a56ae0c05 100644
--- a/pytext/data/test/tensorizers_test.py
+++ b/pytext/data/test/tensorizers_test.py
@@ -9,6 +9,7 @@
 from pytext.data.tensorizers import (
     ByteTensorizer,
     CharacterTokenTensorizer,
+    FloatListTensorizer,
     LabelTensorizer,
     TokenTensorizer,
     initialize_tensorizers,
@@ -138,3 +139,18 @@ def test_create_label_tensors(self):
         self.assertEqual(1, tensor)
         with self.assertRaises(Exception):
             tensor = next(tensors)
+
+    def test_create_float_list_tensor(self):
+        """Every supported separator style numberizes to the same float list."""
+        tensorizer = FloatListTensorizer(column="dense")
+        rows = [
+            {"dense": "[0.1,0.2]"},  # comma
+            {"dense": "[0.1, 0.2]"},  # comma with single space
+            {"dense": "[0.1,  0.2]"},  # comma with multiple spaces
+            {"dense": "[0.1 0.2]"},  # space
+            {"dense": "[0.1  0.2]"},  # multiple spaces
+        ]
+
+        tensors = (tensorizer.numberize(row) for row in rows)
+        for tensor in tensors:
+            self.assertEqual([0.1, 0.2], tensor)