Skip to content

Commit

Permalink
Fix FloatListTensorizer (facebookresearch#450)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#450

1 make FloatListTensorizer work with comma/space as separator
2 remove duplicated FloatVectorTensorizer

Reviewed By: gardenia22

Differential Revision: D14791878

fbshipit-source-id: 4396a71c8eb2d2494774fbbd23c94e947dc072c7
  • Loading branch information
seayoung1112 authored and facebook-github-bot committed Apr 5, 2019
1 parent f501bf9 commit 6abb019
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
7 changes: 2 additions & 5 deletions pytext/data/tensorizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import json
import re
from typing import List, Optional, Tuple, Type

import torch
Expand Down Expand Up @@ -226,10 +227,6 @@ def tensorize(self, batch):
return (pad_and_tensorize(characters), pad_and_tensorize(lengths))


class FloatVectorTensorizer(Tensorizer):
"""TODO: support for dense features."""


class LabelTensorizer(Tensorizer):
"""Numberize labels."""

Expand Down Expand Up @@ -329,7 +326,7 @@ def __init__(self, column: str):
self.column = column

def numberize(self, row):
res = json.loads(row[self.column].replace(" ", ","))
res = json.loads(re.sub(r",? +", ",", row[self.column]))
if type(res) is not list:
raise ValueError(f"{res} is not a valid float list")
return [float(n) for n in res]
Expand Down
15 changes: 15 additions & 0 deletions pytext/data/test/tensorizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pytext.data.tensorizers import (
ByteTensorizer,
CharacterTokenTensorizer,
FloatListTensorizer,
LabelTensorizer,
TokenTensorizer,
initialize_tensorizers,
Expand Down Expand Up @@ -138,3 +139,17 @@ def test_create_label_tensors(self):
self.assertEqual(1, tensor)
with self.assertRaises(Exception):
tensor = next(tensors)

def test_create_float_list_tensor(self):
tensorizer = FloatListTensorizer(column="dense")
rows = [
{"dense": "[0.1,0.2]"}, # comma
{"dense": "[0.1, 0.2]"}, # comma with single space
{"dense": "[0.1,  0.2]"}, # comma with multiple spaces
{"dense": "[0.1 0.2]"}, # space
{"dense": "[0.1  0.2]"}, # multiple spaces
]

tensors = (tensorizer.numberize(row) for row in rows)
for tensor in tensors:
self.assertEqual([0.1, 0.2], tensor)

0 comments on commit 6abb019

Please sign in to comment.