@@ -10,6 +10,10 @@ type vector32 struct {
1010 creator * Creator32
1111 size int
1212
13+ // Used to detect overlap.
14+ bufferID * int
15+ start int
16+
1317 // May be nil for lazy evaluations.
1418 buffer cuda.Buffer
1519}
@@ -22,6 +26,13 @@ func (v *vector32) Len() int {
2226 return v .size
2327}
2428
29+ func (v * vector32 ) Overlaps (v1 anyvec.Vector ) bool {
30+ v1Vec := v1 .(* vector32 )
31+ return v1Vec .bufferID == v .bufferID &&
32+ v .start < v1Vec .start + v1Vec .Len () &&
33+ v1Vec .start < v .start + v .Len ()
34+ }
35+
2536func (v * vector32 ) Data () anyvec.NumericList {
2637 res := make ([]float32 , v .Len ())
2738 v .runSync (func () error {
@@ -74,54 +85,20 @@ func (v *vector32) Slice(start, end int) anyvec.Vector {
7485 if start < 0 || start > end || end > v .Len () {
7586 panic ("index out of range" )
7687 }
77- res := & vector32 {creator : v .creator , size : end - start }
78- v .run (func () (err error ) {
79- if v .buffer == nil {
80- return nil
81- }
82- subSlice := cuda .Slice (v .buffer , uintptr (start )* 4 , uintptr (end )* 4 )
83- res .buffer , err = cuda .AllocBuffer (v .creator .Handle .allocator ,
84- uintptr (res .size )* 4 )
85- if err != nil {
86- return err
87- }
88- return cuda .CopyBuffer (res .buffer , subSlice )
89- })
90- return res
91- }
92-
93- func (v * vector32 ) SetSlice (start int , other anyvec.Vector ) {
94- v1 := other .(* vector32 )
95- if start <= - v1 .Len () || start >= v .Len () {
96- return
88+ res := & vector32 {
89+ creator : v .creator ,
90+ size : end - start ,
91+ bufferID : v .bufferID ,
92+ start : v .start + start ,
9793 }
98-
99- v .run (func () error {
100- if v .buffer == nil && v1 .buffer == nil {
101- return nil
102- }
103- dstStart := start
104- srcStart := 0
105- if start < 0 {
106- dstStart = 0
107- srcStart = - start
108- }
109- copyCount := v1 .Len () - srcStart
110- if v .Len ()- dstStart < copyCount {
111- copyCount = v .Len () - dstStart
112- }
94+ v .run (func () (err error ) {
11395 if err := v .lazyInit (true ); err != nil {
11496 return err
11597 }
116- dst := cuda .Slice (v .buffer , uintptr (dstStart )* 4 ,
117- uintptr (dstStart + copyCount )* 4 )
118- if v1 .buffer == nil {
119- return cuda .ClearBuffer (dst )
120- }
121- srcSlice := cuda .Slice (v1 .buffer , uintptr (srcStart )* 4 ,
122- uintptr (srcStart + copyCount )* 4 )
123- return cuda .CopyBuffer (dst , srcSlice )
98+ res .buffer = cuda .Slice (v .buffer , uintptr (start )* 4 , uintptr (end )* 4 )
99+ return nil
124100 })
101+ return res
125102}
126103
127104func (v * vector32 ) Scale (s anyvec.Numeric ) {
@@ -202,8 +179,8 @@ func (v *vector32) Gemm(transA, transB bool, m, n, k int,
202179 betaFloat := beta .(float32 )
203180 a32 := a .(* vector32 )
204181 b32 := b .(* vector32 )
205- if a32 == v || b32 == v {
206- panic ("vectors cannot be equal " )
182+ if v . Overlaps ( a32 ) || v . Overlaps ( b32 ) {
183+ panic ("invalid overlap " )
207184 }
208185 v .run (func () error {
209186 if err := lazyInitAll (true , v , a32 , b32 ); err != nil {
@@ -229,8 +206,8 @@ func (v *vector32) Gemv(trans bool, m, n int, alpha anyvec.Numeric, a anyvec.Vec
229206 betaFloat := beta .(float32 )
230207 x32 := x .(* vector32 )
231208 a32 := a .(* vector32 )
232- if x32 == v || a32 == v {
233- panic ("vectors cannot be equal " )
209+ if v . Overlaps ( x32 ) || v . Overlaps ( a32 ) {
210+ panic ("invalid overlap " )
234211 }
235212 v .run (func () error {
236213 if err := lazyInitAll (true , v , x32 , a32 ); err != nil {
@@ -292,8 +269,8 @@ func (v *vector32) lazyInit(clear bool) error {
292269}
293270
294271func (v * vector32 ) assertCompat (v1 * vector32 , readOnly bool ) {
295- if ! readOnly && v == v1 {
296- panic ("vectors cannot be equal " )
272+ if ! readOnly && v . Overlaps ( v1 ) {
273+ panic ("invalid overlap " )
297274 } else if v .Len () != v1 .Len () {
298275 panic ("length mismatch" )
299276 }
0 commit comments