Skip to content

Commit 38c1ec5

Browse files
committed
more standard casting procedure
1 parent 23492ff commit 38c1ec5

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

vector32.go

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@ func (v *vector32) SetData(d anyvec.NumericList) {
4646
})
4747
}
4848

49-
func (v *vector32) Set(v1 anyvec.Vector) {
49+
func (v *vector32) Set(other anyvec.Vector) {
50+
v1 := other.(*vector32)
5051
v.assertCompat(v1, false)
5152
v.run(func() error {
52-
buf1 := v1.(*vector32).buffer
53+
buf1 := v1.buffer
5354
if buf1 == nil {
5455
if v.buffer != nil {
5556
return cuda.ClearBuffer(v.buffer)
@@ -89,16 +90,16 @@ func (v *vector32) Slice(start, end int) anyvec.Vector {
8990
return res
9091
}
9192

92-
func (v *vector32) SetSlice(start int, v1 anyvec.Vector) {
93-
src := v1.(*vector32)
94-
if src.Len() > v.Len()-start {
93+
func (v *vector32) SetSlice(start int, other anyvec.Vector) {
94+
v1 := other.(*vector32)
95+
if v1.Len() > v.Len()-start {
9596
panic("index out of range")
96-
} else if start <= -src.Len() {
97+
} else if start <= -v1.Len() {
9798
return
9899
}
99100

100101
v.run(func() error {
101-
if src.buffer == nil && v.buffer == nil {
102+
if v.buffer == nil && v1.buffer == nil {
102103
return nil
103104
}
104105
dstStart := start
@@ -107,7 +108,7 @@ func (v *vector32) SetSlice(start int, v1 anyvec.Vector) {
107108
dstStart = 0
108109
srcStart = -start
109110
}
110-
copyCount := src.Len() - srcStart
111+
copyCount := v1.Len() - srcStart
111112
if v.Len()-dstStart < copyCount {
112113
copyCount = v.Len() - dstStart
113114
}
@@ -116,10 +117,10 @@ func (v *vector32) SetSlice(start int, v1 anyvec.Vector) {
116117
}
117118
dst := cuda.Slice(v.buffer, uintptr(dstStart)*4,
118119
uintptr(dstStart+copyCount)*4)
119-
if src.buffer == nil {
120+
if v1.buffer == nil {
120121
return cuda.ClearBuffer(dst)
121122
}
122-
srcSlice := cuda.Slice(src.buffer, uintptr(srcStart)*4,
123+
srcSlice := cuda.Slice(v1.buffer, uintptr(srcStart)*4,
123124
uintptr(srcStart+copyCount)*4)
124125
return cuda.CopyBuffer(dst, srcSlice)
125126
})
@@ -139,42 +140,43 @@ func (v *vector32) AddScaler(s anyvec.Numeric) {
139140
panic("nyi")
140141
}
141142

142-
func (v *vector32) Dot(v1 anyvec.Vector) anyvec.Numeric {
143+
func (v *vector32) Dot(other anyvec.Vector) anyvec.Numeric {
144+
v1 := other.(*vector32)
143145
v.assertCompat(v1, true)
144146
var res float32
145147
v.runSync(func() error {
146-
v32 := v1.(*vector32)
147-
if err := lazyInitAll(true, v, v32); err != nil {
148+
if err := lazyInitAll(true, v, v1); err != nil {
148149
return err
149150
}
150-
return v.creator.Handle.blas.Sdot(v.Len(), v.buffer, 1, v32.buffer, 1, &res)
151+
return v.creator.Handle.blas.Sdot(v.Len(), v.buffer, 1, v1.buffer, 1, &res)
151152
})
152153
return res
153154
}
154155

155-
func (v *vector32) Add(v1 anyvec.Vector) {
156-
v.axpy(1, v1)
156+
func (v *vector32) Add(other anyvec.Vector) {
157+
v.axpy(1, other.(*vector32))
157158
}
158159

159-
func (v *vector32) Sub(v1 anyvec.Vector) {
160-
v.axpy(-1, v1)
160+
func (v *vector32) Sub(other anyvec.Vector) {
161+
v.axpy(-1, other.(*vector32))
161162
}
162163

163-
func (v *vector32) Mul(v1 anyvec.Vector) {
164+
func (v *vector32) Mul(other anyvec.Vector) {
165+
v1 := other.(*vector32)
164166
v.assertCompat(v1, false)
165167
panic("nyi")
166168
}
167169

168-
func (v *vector32) Div(v1 anyvec.Vector) {
170+
func (v *vector32) Div(other anyvec.Vector) {
171+
v1 := other.(*vector32)
169172
v.assertCompat(v1, false)
170-
v132 := v1.(*vector32)
171173
v.run(func() error {
172-
if err := lazyInitAll(true, v, v132); err != nil {
174+
if err := lazyInitAll(true, v, v1); err != nil {
173175
return err
174176
}
175177
grid, block := v.kernelSizes()
176178
return v.creator.Handle.kernels32.Launch("kernel", grid, 1, 1,
177-
block, 1, 1, 0, v.buffer, v132.buffer, v.Len())
179+
block, 1, 1, 0, v.buffer, v1.buffer, v.Len())
178180
})
179181
}
180182

@@ -206,25 +208,24 @@ func (v *vector32) Gemm(transA, transB bool, m, n, k int,
206208
})
207209
}
208210

209-
func (v *vector32) axpy(scaler float32, v1 anyvec.Vector) {
211+
func (v *vector32) axpy(scaler float32, v1 *vector32) {
210212
v.assertCompat(v1, false)
211-
v32 := v1.(*vector32)
212213
v.run(func() error {
213-
if v32.buffer == nil {
214+
if v1.buffer == nil {
214215
return nil
215216
} else if v.buffer == nil {
216217
if err := v.lazyInit(false); err != nil {
217218
return err
218219
}
219-
if err := cuda.CopyBuffer(v.buffer, v32.buffer); err != nil {
220+
if err := cuda.CopyBuffer(v.buffer, v1.buffer); err != nil {
220221
return err
221222
}
222223
if scaler == 1 {
223224
return nil
224225
}
225226
return v.creator.Handle.blas.Sscal(v.Len(), scaler, v.buffer, 1)
226227
}
227-
return v.creator.Handle.blas.Saxpy(v.Len(), scaler, v32.buffer, 1,
228+
return v.creator.Handle.blas.Saxpy(v.Len(), scaler, v1.buffer, 1,
228229
v.buffer, 1)
229230
})
230231
}
@@ -252,11 +253,10 @@ func (v *vector32) lazyInit(clear bool) error {
252253
return nil
253254
}
254255

255-
func (v *vector32) assertCompat(v1 anyvec.Vector, readOnly bool) {
256-
v132 := v1.(*vector32)
257-
if readOnly || v == v132 {
256+
func (v *vector32) assertCompat(v1 *vector32, readOnly bool) {
257+
if readOnly || v == v1 {
258258
panic("vectors cannot be equal")
259-
} else if v.Len() != v132.Len() {
259+
} else if v.Len() != v1.Len() {
260260
panic("length mismatch")
261261
}
262262
}

0 commit comments

Comments
 (0)