-
-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathsplit.v
More file actions
142 lines (131 loc) · 5.22 KB
/
split.v
File metadata and controls
142 lines (131 loc) · 5.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
module vtl
// array_split splits a tensor into `ind` sub-tensors along `axis`.
// Unlike split, `ind` does not have to divide the axis length evenly:
// for an axis of length l split into n sections, the first l % n
// sub-tensors have size l / n + 1 and the remaining ones have size l / n.
// Returns an error if `ind` is not positive or if `axis` is not a valid
// axis of `t` (this function uses direct_array_access, so indexing
// t.shape with a bad axis would otherwise be unchecked).
@[direct_array_access]
pub fn (t &Tensor[T]) array_split[T](ind int, axis int) ![]&Tensor[T] {
	if ind <= 0 {
		return error('array_split: number of sections must be larger than 0, got ${ind}')
	}
	if axis < 0 || axis >= t.shape.len {
		return error('array_split: axis ${axis} is out of bounds for tensor of rank ${t.shape.len}')
	}
	ntotal := t.shape[axis]
	neach := ntotal / ind
	extras := ntotal % ind
	// div_points holds ind + 1 boundaries: 0 followed by the running total of
	// section sizes. The first `extras` sections take one extra element so that
	// all ntotal elements along the axis are covered.
	mut div_points := []int{cap: ind + 1}
	div_points << 0
	mut acc := 0
	for i in 0 .. ind {
		size := if i < extras { neach + 1 } else { neach }
		acc += size
		div_points << acc
	}
	return t.splitter[T](axis, ind, div_points)
}
// array_split_expl splits a tensor along `axis` at the explicit positions
// listed in `ind`. The boundaries used are [0, ind..., t.shape[axis]], and
// each consecutive pair delimits one section. For example, ind = [2, 3]
// yields t[:2], t[2:3] and t[3:] along `axis`, i.e. ind.len + 1 sub-tensors.
// The entries of `ind` are passed to splitter as-is; how unsorted or
// out-of-range positions behave depends on slice_hilo (NOTE(review): confirm).
pub fn (t &Tensor[T]) array_split_expl[T](ind []int, axis int) ![]&Tensor[T] {
	axis_len := t.shape[axis]
	mut bounds := []int{cap: ind.len + 2}
	bounds << 0
	bounds << ind
	bounds << axis_len
	return t.splitter[T](axis, bounds.len - 1, bounds)
}
// split splits a tensor into `ind` equally sized sub-tensors along `axis`.
// Returns an error (it does not panic) if `ind` is not positive, if `axis`
// is not a valid axis of `t`, or if the axis length is not evenly divisible
// by `ind`. Use array_split when an uneven division is acceptable.
pub fn (t &Tensor[T]) split[T](ind int, axis int) ![]&Tensor[T] {
	// Guard before the modulo below, which would otherwise divide by zero.
	if ind <= 0 {
		return error('split: number of sections must be larger than 0, got ${ind}')
	}
	if axis < 0 || axis >= t.shape.len {
		return error('split: axis ${axis} is out of bounds for tensor of rank ${t.shape.len}')
	}
	n := t.shape[axis]
	if n % ind != 0 {
		return error('Array split does not result in an equal division')
	}
	return t.array_split[T](ind, axis)
}
// split_expl splits a tensor into sub-tensors at the explicit positions in
// `ind` along `axis`. For example, ind = [2, 3] with axis = 0 yields:
//   t[:2]
//   t[2:3]
//   t[3:]
// It simply forwards to array_split_expl.
pub fn (t &Tensor[T]) split_expl[T](ind []int, axis int) ![]&Tensor[T] {
	parts := t.array_split_expl[T](ind, axis)!
	return parts
}
// hsplit splits a tensor into `ind` equal sub-tensors horizontally
// (column-wise). See split for details.
// The split is along axis 1, except for 1-D tensors, which have no second
// axis and are split along axis 0 instead.
pub fn (t &Tensor[T]) hsplit[T](ind int) ![]&Tensor[T] {
	split_axis := if t.rank() == 1 { 0 } else { 1 }
	return t.split[T](ind, split_axis)
}
// hsplit_expl splits a tensor horizontally (column-wise) at the explicit
// positions in `ind`. See split_expl for details.
// The split is along axis 1, except for 1-D tensors, which have no second
// axis and are split along axis 0 instead.
pub fn (t &Tensor[T]) hsplit_expl[T](ind []int) ![]&Tensor[T] {
	split_axis := if t.rank() == 1 { 0 } else { 1 }
	return t.split_expl[T](ind, split_axis)
}
// vsplit splits a tensor into `ind` equal sub-tensors vertically (row-wise),
// i.e. along axis 0. See split for details.
// Returns an error for tensors of rank lower than 2.
pub fn (t &Tensor[T]) vsplit[T](ind int) ![]&Tensor[T] {
	if t.rank() >= 2 {
		return t.split[T](ind, 0)
	}
	return error('vsplit only works on tensors of >= 2 dimensions')
}
// vsplit_expl splits a tensor vertically (row-wise), i.e. along axis 0, at
// the explicit positions in `ind`. See split_expl for details.
// Returns an error for tensors of rank lower than 2.
pub fn (t &Tensor[T]) vsplit_expl[T](ind []int) ![]&Tensor[T] {
	if t.rank() >= 2 {
		return t.split_expl[T](ind, 0)
	}
	return error('vsplit only works on tensors of >= 2 dimensions')
}
// dsplit splits a tensor into `ind` equal sub-tensors along the third axis
// (depth, axis 2). See split for details.
// Returns an error for tensors of rank lower than 3.
pub fn (t &Tensor[T]) dsplit[T](ind int) ![]&Tensor[T] {
	if t.rank() >= 3 {
		return t.split[T](ind, 2)
	}
	return error('dsplit only works on arrays of 3 or more dimensions')
}
// dsplit_expl splits a tensor along the third axis (depth, axis 2) at the
// explicit positions in `ind`. See split_expl for details.
// Returns an error for tensors of rank lower than 3.
pub fn (t &Tensor[T]) dsplit_expl[T](ind []int) ![]&Tensor[T] {
	if t.rank() >= 3 {
		return t.split_expl[T](ind, 2)
	}
	return error('dsplit only works on arrays of 3 or more dimensions')
}
// splitter is the shared implementation behind every split variant.
// It produces `n` sub-tensors. Section i spans [div_points[i], div_points[i + 1])
// along `axis`, so div_points must hold at least n + 1 entries.
// Returns an error when div_points is too short, or when swapaxes/slice_hilo fail.
@[direct_array_access]
fn (t &Tensor[T]) splitter[T](axis int, n int, div_points []int) ![]&Tensor[T] {
	// direct_array_access disables bounds checks, so make sure div_points[n]
	// exists before the loop below reads it.
	if n > 0 && div_points.len <= n {
		return error('splitter error, div_points.len <= n')
	}
	// Bring the split axis to the front so each section is a hi/lo slice on
	// the leading axis, then swap the same pair back for each section.
	swapped := t.swapaxes(axis, 0)!
	mut parts := []&Tensor[T]{}
	mut i := 0
	for i < n {
		section := swapped.slice_hilo([div_points[i]], [div_points[i + 1]])!
		parts << section.swapaxes(axis, 0)!
		i++
	}
	return parts
}