@@ -46,10 +46,11 @@ func (v *vector32) SetData(d anyvec.NumericList) {
4646 })
4747}
4848
49- func (v * vector32 ) Set (v1 anyvec.Vector ) {
49+ func (v * vector32 ) Set (other anyvec.Vector ) {
50+ v1 := other .(* vector32 )
5051 v .assertCompat (v1 , false )
5152 v .run (func () error {
52- buf1 := v1 .( * vector32 ). buffer
53+ buf1 := v1 .buffer
5354 if buf1 == nil {
5455 if v .buffer != nil {
5556 return cuda .ClearBuffer (v .buffer )
@@ -89,16 +90,16 @@ func (v *vector32) Slice(start, end int) anyvec.Vector {
8990 return res
9091}
9192
92- func (v * vector32 ) SetSlice (start int , v1 anyvec.Vector ) {
93- src := v1 .(* vector32 )
94- if src .Len () > v .Len ()- start {
93+ func (v * vector32 ) SetSlice (start int , other anyvec.Vector ) {
94+ v1 := other .(* vector32 )
95+ if v1 .Len () > v .Len ()- start {
9596 panic ("index out of range" )
96- } else if start <= - src .Len () {
97+ } else if start <= - v1 .Len () {
9798 return
9899 }
99100
100101 v .run (func () error {
101- if src .buffer == nil && v .buffer == nil {
102+ if v .buffer == nil && v1 .buffer == nil {
102103 return nil
103104 }
104105 dstStart := start
@@ -107,7 +108,7 @@ func (v *vector32) SetSlice(start int, v1 anyvec.Vector) {
107108 dstStart = 0
108109 srcStart = - start
109110 }
110- copyCount := src .Len () - srcStart
111+ copyCount := v1 .Len () - srcStart
111112 if v .Len ()- dstStart < copyCount {
112113 copyCount = v .Len () - dstStart
113114 }
@@ -116,10 +117,10 @@ func (v *vector32) SetSlice(start int, v1 anyvec.Vector) {
116117 }
117118 dst := cuda .Slice (v .buffer , uintptr (dstStart )* 4 ,
118119 uintptr (dstStart + copyCount )* 4 )
119- if src .buffer == nil {
120+ if v1 .buffer == nil {
120121 return cuda .ClearBuffer (dst )
121122 }
122- srcSlice := cuda .Slice (src .buffer , uintptr (srcStart )* 4 ,
123+ srcSlice := cuda .Slice (v1 .buffer , uintptr (srcStart )* 4 ,
123124 uintptr (srcStart + copyCount )* 4 )
124125 return cuda .CopyBuffer (dst , srcSlice )
125126 })
@@ -139,42 +140,43 @@ func (v *vector32) AddScaler(s anyvec.Numeric) {
139140 panic ("nyi" )
140141}
141142
142- func (v * vector32 ) Dot (v1 anyvec.Vector ) anyvec.Numeric {
143+ func (v * vector32 ) Dot (other anyvec.Vector ) anyvec.Numeric {
144+ v1 := other .(* vector32 )
143145 v .assertCompat (v1 , true )
144146 var res float32
145147 v .runSync (func () error {
146- v32 := v1 .(* vector32 )
147- if err := lazyInitAll (true , v , v32 ); err != nil {
148+ if err := lazyInitAll (true , v , v1 ); err != nil {
148149 return err
149150 }
150- return v .creator .Handle .blas .Sdot (v .Len (), v .buffer , 1 , v32 .buffer , 1 , & res )
151+ return v .creator .Handle .blas .Sdot (v .Len (), v .buffer , 1 , v1 .buffer , 1 , & res )
151152 })
152153 return res
153154}
154155
155- func (v * vector32 ) Add (v1 anyvec.Vector ) {
156- v .axpy (1 , v1 )
156+ func (v * vector32 ) Add (other anyvec.Vector ) {
157+ v .axpy (1 , other .( * vector32 ) )
157158}
158159
159- func (v * vector32 ) Sub (v1 anyvec.Vector ) {
160- v .axpy (- 1 , v1 )
160+ func (v * vector32 ) Sub (other anyvec.Vector ) {
161+ v .axpy (- 1 , other .( * vector32 ) )
161162}
162163
163- func (v * vector32 ) Mul (v1 anyvec.Vector ) {
164+ func (v * vector32 ) Mul (other anyvec.Vector ) {
165+ v1 := other .(* vector32 )
164166 v .assertCompat (v1 , false )
165167 panic ("nyi" )
166168}
167169
168- func (v * vector32 ) Div (v1 anyvec.Vector ) {
170+ func (v * vector32 ) Div (other anyvec.Vector ) {
171+ v1 := other .(* vector32 )
169172 v .assertCompat (v1 , false )
170- v132 := v1 .(* vector32 )
171173 v .run (func () error {
172- if err := lazyInitAll (true , v , v132 ); err != nil {
174+ if err := lazyInitAll (true , v , v1 ); err != nil {
173175 return err
174176 }
175177 grid , block := v .kernelSizes ()
176178 return v .creator .Handle .kernels32 .Launch ("kernel" , grid , 1 , 1 ,
177- block , 1 , 1 , 0 , v .buffer , v132 .buffer , v .Len ())
179+ block , 1 , 1 , 0 , v .buffer , v1 .buffer , v .Len ())
178180 })
179181}
180182
@@ -206,25 +208,24 @@ func (v *vector32) Gemm(transA, transB bool, m, n, k int,
206208 })
207209}
208210
209- func (v * vector32 ) axpy (scaler float32 , v1 anyvec. Vector ) {
211+ func (v * vector32 ) axpy (scaler float32 , v1 * vector32 ) {
210212 v .assertCompat (v1 , false )
211- v32 := v1 .(* vector32 )
212213 v .run (func () error {
213- if v32 .buffer == nil {
214+ if v1 .buffer == nil {
214215 return nil
215216 } else if v .buffer == nil {
216217 if err := v .lazyInit (false ); err != nil {
217218 return err
218219 }
219- if err := cuda .CopyBuffer (v .buffer , v32 .buffer ); err != nil {
220+ if err := cuda .CopyBuffer (v .buffer , v1 .buffer ); err != nil {
220221 return err
221222 }
222223 if scaler == 1 {
223224 return nil
224225 }
225226 return v .creator .Handle .blas .Sscal (v .Len (), scaler , v .buffer , 1 )
226227 }
227- return v .creator .Handle .blas .Saxpy (v .Len (), scaler , v32 .buffer , 1 ,
228+ return v .creator .Handle .blas .Saxpy (v .Len (), scaler , v1 .buffer , 1 ,
228229 v .buffer , 1 )
229230 })
230231}
@@ -252,11 +253,10 @@ func (v *vector32) lazyInit(clear bool) error {
252253 return nil
253254}
254255
255- func (v * vector32 ) assertCompat (v1 anyvec.Vector , readOnly bool ) {
256- v132 := v1 .(* vector32 )
257- if readOnly || v == v132 {
256+ func (v * vector32 ) assertCompat (v1 * vector32 , readOnly bool ) {
257+ if readOnly || v == v1 {
258258 panic ("vectors cannot be equal" )
259- } else if v .Len () != v132 .Len () {
259+ } else if v .Len () != v1 .Len () {
260260 panic ("length mismatch" )
261261 }
262262}
0 commit comments