1
1
from dataclasses import dataclass
2
- from typing import Iterable , Tuple , Optional , Sequence
2
+ from typing import Iterable , Tuple , Optional , Sequence , List , cast
3
3
4
4
import torch
5
5
6
6
7
7
@dataclass
8
8
class Path :
9
- from_token : torch . Tensor # [max token parts]
10
- path_node : torch . Tensor # [path length]
11
- to_token : torch . Tensor # [max token parts]
9
+ from_token : List [ int ] # [max token parts]
10
+ path_node : List [ int ] # [path length]
11
+ to_token : List [ int ] # [max token parts]
12
12
13
13
14
14
@dataclass
15
15
class LabeledPathContext :
16
- label : torch . Tensor # [max label parts]
16
+ label : List [ int ] # [max label parts]
17
17
path_contexts : Sequence [Path ]
18
18
19
19
20
+ def transpose (list_of_lists : List [List [int ]]) -> List [List [int ]]:
21
+ return [cast (List [int ], it ) for it in zip (* list_of_lists )]
22
+
23
+
20
24
class BatchedLabeledPathContext :
21
25
def __init__ (self , all_samples : Sequence [Optional [LabeledPathContext ]]):
22
26
samples = [s for s in all_samples if s is not None ]
23
27
24
28
# [max label parts; batch size]
25
- self .labels = torch .cat ( [s .label . unsqueeze ( 1 ) for s in samples ], dim = 1 )
29
+ self .labels = torch .tensor ( transpose ( [s .label for s in samples ]), dtype = torch . long )
26
30
# [batch size]
27
31
self .contexts_per_label = torch .tensor ([len (s .path_contexts ) for s in samples ])
28
32
29
33
# [max token parts; n contexts]
30
- self .from_token = torch .cat ([path .from_token .unsqueeze (1 ) for s in samples for path in s .path_contexts ], dim = 1 )
34
+ self .from_token = torch .tensor (
35
+ transpose ([path .from_token for s in samples for path in s .path_contexts ]), dtype = torch .long
36
+ )
31
37
# [path length; n contexts]
32
- self .path_nodes = torch .cat ([path .path_node .unsqueeze (1 ) for s in samples for path in s .path_contexts ], dim = 1 )
38
+ self .path_nodes = torch .tensor (
39
+ transpose ([path .path_node for s in samples for path in s .path_contexts ]), dtype = torch .long
40
+ )
33
41
# [max token parts; n contexts]
34
- self .to_token = torch .cat ([path .to_token .unsqueeze (1 ) for s in samples for path in s .path_contexts ], dim = 1 )
42
+ self .to_token = torch .tensor (
43
+ transpose ([path .to_token for s in samples for path in s .path_contexts ]), dtype = torch .long
44
+ )
35
45
36
46
def __len__ (self ) -> int :
37
47
return len (self .contexts_per_label )
@@ -53,8 +63,8 @@ def move_to_device(self, device: torch.device):
53
63
54
64
@dataclass
55
65
class TypedPath (Path ):
56
- from_type : torch . Tensor # [max type parts]
57
- to_type : torch . Tensor # [max type parts]
66
+ from_type : List [ int ] # [max type parts]
67
+ to_type : List [ int ] # [max type parts]
58
68
59
69
60
70
@dataclass
@@ -67,6 +77,10 @@ def __init__(self, all_samples: Sequence[Optional[LabeledTypedPathContext]]):
67
77
super ().__init__ (all_samples )
68
78
samples = [s for s in all_samples if s is not None ]
69
79
# [max type parts; n contexts]
70
- self .from_type = torch .cat ([path .from_type .unsqueeze (1 ) for s in samples for path in s .path_contexts ], dim = 1 )
80
+ self .from_type = torch .tensor (
81
+ transpose ([path .from_type for s in samples for path in s .path_contexts ]), dtype = torch .long
82
+ )
71
83
# [max type parts; n contexts]
72
- self .to_type = torch .cat ([path .to_type .unsqueeze (1 ) for s in samples for path in s .path_contexts ], dim = 1 )
84
+ self .to_type = torch .tensor (
85
+ transpose ([path .to_type for s in samples for path in s .path_contexts ]), dtype = torch .long
86
+ )
0 commit comments