forked from votrubac/pytext
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_utils.py
77 lines (60 loc) · 2.51 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
from enum import Enum
from typing import List, NamedTuple, Optional
from pytext.utils.path import PYTEXT_HOME
# Directory containing the test fixture files. Defaults to
# <PYTEXT_HOME>/tests/data; set PYTEXT_TEST_DATA in the environment to
# point somewhere else.
TEST_DATA_DIR = os.getenv(
    "PYTEXT_TEST_DATA",
    os.path.join(PYTEXT_HOME, "tests/data"),
)

# Directory containing the config files used by tests. Defaults to
# <PYTEXT_HOME>/demo/configs; override via PYTEXT_TEST_CONFIG.
TEST_CONFIG_DIR = os.getenv(
    "PYTEXT_TEST_CONFIG",
    os.path.join(PYTEXT_HOME, "demo/configs"),
)
def test_file(filename):
    """Return the path of the fixture ``filename`` inside ``TEST_DATA_DIR``.

    The two parts are combined with ``os.path.join``, so an absolute
    ``filename`` replaces the directory entirely.
    """
    fixture_dir = TEST_DATA_DIR
    return os.path.join(fixture_dir, filename)
class TestFileName(Enum):
    """Identifiers for the fixture files that live under TEST_DATA_DIR.

    Each member's value is the fixture's file name. ``str(member)`` returns
    that bare name, which is what ``test_file`` expects to join onto the
    data directory.
    """

    # Despite the ``Test`` prefix this is not a test case. pytest inspects the
    # namespace of every test module, so importing this class there makes
    # pytest try to collect it and emit a PytestCollectionWarning (Enum defines
    # its own constructor). ``__test__ = False`` opts out of collection; dunder
    # names are never turned into Enum members, so the member set is unchanged.
    __test__ = False

    def __str__(self):
        # Render as the raw file name rather than "TestFileName.MEMBER".
        return str(self.value)

    TRAIN_DENSE_FEATURES_TINY_TSV = "train_dense_features_tiny.tsv"
    TEST_PERSONALIZATION_OPPOSITE_INPUTS_TSV = (
        "test_personalization_opposite_inputs.tsv"
    )
    TEST_PERSONALIZATION_SAME_INPUTS_TSV = "test_personalization_same_inputs.tsv"
    TEST_PERSONALIZATION_SINGLE_USER_TSV = "test_personalization_single_user.tsv"
class TestFileMetadata(NamedTuple):
    """Describes one fixture file and the columns it contains.

    Fields:
        filename: path to the fixture file.
        field_names: the file's column names (presumably in file order --
            confirm against the reader), or None if not specified.
        dense_col_name: name of the column holding dense features, if any.
        dense_feat_dim: presumably the length of each dense feature vector,
            if any.
        uid_col_name: name of the user-id column, if any (used by the
            personalization fixtures).
    """

    filename: str
    field_names: Optional[List[str]] = None
    dense_col_name: Optional[str] = None
    dense_feat_dim: Optional[int] = None
    uid_col_name: Optional[str] = None

    # Not a test case despite the ``Test`` prefix. Opt out of pytest
    # collection so importing this class into a test module does not emit a
    # PytestCollectionWarning. It is unannotated, so it is a plain class
    # attribute, not a tuple field.
    __test__ = False
# Registry of fixture metadata, keyed by TestFileName.
# The three personalization fixtures share one schema and differ only in the
# file they point at, so their entries are generated from a single template.
TEST_FILE_NAME_TO_METADATA = {
    TestFileName.TRAIN_DENSE_FEATURES_TINY_TSV: TestFileMetadata(
        filename=test_file(str(TestFileName.TRAIN_DENSE_FEATURES_TINY_TSV)),
        field_names=["label", "slots", "text", "dense_features"],
        dense_col_name="dense_features",
        dense_feat_dim=10,
    ),
    **{
        personalization_file: TestFileMetadata(
            filename=test_file(str(personalization_file)),
            field_names=["label", "text", "dense_features", "uid"],
            dense_col_name="dense_features",
            dense_feat_dim=10,
            uid_col_name="uid",
        )
        for personalization_file in (
            TestFileName.TEST_PERSONALIZATION_OPPOSITE_INPUTS_TSV,
            TestFileName.TEST_PERSONALIZATION_SAME_INPUTS_TSV,
            TestFileName.TEST_PERSONALIZATION_SINGLE_USER_TSV,
        )
    },
}
def get_test_file_metadata(test_file_id: TestFileName) -> TestFileMetadata:
    """Look up the metadata registered for ``test_file_id``.

    Raises:
        KeyError: if ``test_file_id`` has no entry in
            ``TEST_FILE_NAME_TO_METADATA``.
    """
    metadata = TEST_FILE_NAME_TO_METADATA[test_file_id]
    return metadata