Skip to content

Commit 78b9b01

Browse files
committed
new slice semantics
1 parent b36ed44 commit 78b9b01

File tree

3 files changed

+40
-61
lines changed

3 files changed

+40
-61
lines changed

creator32.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ func (c *Creator32) MakeNumericList(x []float64) anyvec.NumericList {
2828
// MakeVector creates a zero'd out anyvec.Vector.
2929
func (c *Creator32) MakeVector(size int) anyvec.Vector {
3030
return &vector32{
31-
creator: c,
32-
size: size,
31+
bufferID: new(int),
32+
creator: c,
33+
size: size,
3334
}
3435
}
3536

@@ -57,8 +58,9 @@ func (c *Creator32) Concat(v ...anyvec.Vector) anyvec.Vector {
5758
}
5859

5960
res := &vector32{
60-
creator: c,
61-
size: totalLen,
61+
creator: c,
62+
size: totalLen,
63+
bufferID: new(int),
6264
}
6365

6466
c.run(func() error {

vector32.go

Lines changed: 26 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ type vector32 struct {
1010
creator *Creator32
1111
size int
1212

13+
// Used to detect overlap.
14+
bufferID *int
15+
start int
16+
1317
// May be nil for lazy evaluations.
1418
buffer cuda.Buffer
1519
}
@@ -22,6 +26,13 @@ func (v *vector32) Len() int {
2226
return v.size
2327
}
2428

29+
func (v *vector32) Overlaps(v1 anyvec.Vector) bool {
30+
v1Vec := v1.(*vector32)
31+
return v1Vec.bufferID == v.bufferID &&
32+
v.start < v1Vec.start+v1Vec.Len() &&
33+
v1Vec.start < v.start+v.Len()
34+
}
35+
2536
func (v *vector32) Data() anyvec.NumericList {
2637
res := make([]float32, v.Len())
2738
v.runSync(func() error {
@@ -74,54 +85,20 @@ func (v *vector32) Slice(start, end int) anyvec.Vector {
7485
if start < 0 || start > end || end > v.Len() {
7586
panic("index out of range")
7687
}
77-
res := &vector32{creator: v.creator, size: end - start}
78-
v.run(func() (err error) {
79-
if v.buffer == nil {
80-
return nil
81-
}
82-
subSlice := cuda.Slice(v.buffer, uintptr(start)*4, uintptr(end)*4)
83-
res.buffer, err = cuda.AllocBuffer(v.creator.Handle.allocator,
84-
uintptr(res.size)*4)
85-
if err != nil {
86-
return err
87-
}
88-
return cuda.CopyBuffer(res.buffer, subSlice)
89-
})
90-
return res
91-
}
92-
93-
func (v *vector32) SetSlice(start int, other anyvec.Vector) {
94-
v1 := other.(*vector32)
95-
if start <= -v1.Len() || start >= v.Len() {
96-
return
88+
res := &vector32{
89+
creator: v.creator,
90+
size: end - start,
91+
bufferID: v.bufferID,
92+
start: v.start + start,
9793
}
98-
99-
v.run(func() error {
100-
if v.buffer == nil && v1.buffer == nil {
101-
return nil
102-
}
103-
dstStart := start
104-
srcStart := 0
105-
if start < 0 {
106-
dstStart = 0
107-
srcStart = -start
108-
}
109-
copyCount := v1.Len() - srcStart
110-
if v.Len()-dstStart < copyCount {
111-
copyCount = v.Len() - dstStart
112-
}
94+
v.run(func() (err error) {
11395
if err := v.lazyInit(true); err != nil {
11496
return err
11597
}
116-
dst := cuda.Slice(v.buffer, uintptr(dstStart)*4,
117-
uintptr(dstStart+copyCount)*4)
118-
if v1.buffer == nil {
119-
return cuda.ClearBuffer(dst)
120-
}
121-
srcSlice := cuda.Slice(v1.buffer, uintptr(srcStart)*4,
122-
uintptr(srcStart+copyCount)*4)
123-
return cuda.CopyBuffer(dst, srcSlice)
98+
res.buffer = cuda.Slice(v.buffer, uintptr(start)*4, uintptr(end)*4)
99+
return nil
124100
})
101+
return res
125102
}
126103

127104
func (v *vector32) Scale(s anyvec.Numeric) {
@@ -202,8 +179,8 @@ func (v *vector32) Gemm(transA, transB bool, m, n, k int,
202179
betaFloat := beta.(float32)
203180
a32 := a.(*vector32)
204181
b32 := b.(*vector32)
205-
if a32 == v || b32 == v {
206-
panic("vectors cannot be equal")
182+
if v.Overlaps(a32) || v.Overlaps(b32) {
183+
panic("invalid overlap")
207184
}
208185
v.run(func() error {
209186
if err := lazyInitAll(true, v, a32, b32); err != nil {
@@ -229,8 +206,8 @@ func (v *vector32) Gemv(trans bool, m, n int, alpha anyvec.Numeric, a anyvec.Vec
229206
betaFloat := beta.(float32)
230207
x32 := x.(*vector32)
231208
a32 := a.(*vector32)
232-
if x32 == v || a32 == v {
233-
panic("vectors cannot be equal")
209+
if v.Overlaps(x32) || v.Overlaps(a32) {
210+
panic("invalid overlap")
234211
}
235212
v.run(func() error {
236213
if err := lazyInitAll(true, v, x32, a32); err != nil {
@@ -292,8 +269,8 @@ func (v *vector32) lazyInit(clear bool) error {
292269
}
293270

294271
func (v *vector32) assertCompat(v1 *vector32, readOnly bool) {
295-
if !readOnly && v == v1 {
296-
panic("vectors cannot be equal")
272+
if !readOnly && v.Overlaps(v1) {
273+
panic("invalid overlap")
297274
} else if v.Len() != v1.Len() {
298275
panic("length mismatch")
299276
}

vector32_extra.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ func (v *vector32) Sum() anyvec.Numeric {
5252

5353
func (v *vector32) ScaleChunks(other anyvec.Vector) {
5454
v1 := other.(*vector32)
55-
if v == v1 {
56-
panic("inputs overlap")
55+
if v.Overlaps(v1) {
56+
panic("invalid overlap")
5757
} else if v.Len()%v1.Len() != 0 {
5858
panic("scaler count must divide vector size")
5959
}
@@ -70,8 +70,8 @@ func (v *vector32) ScaleChunks(other anyvec.Vector) {
7070

7171
func (v *vector32) AddChunks(other anyvec.Vector) {
7272
v1 := other.(*vector32)
73-
if v == v1 {
74-
panic("inputs overlap")
73+
if v.Overlaps(v1) {
74+
panic("invalid overlap")
7575
} else if v.Len()%v1.Len() != 0 {
7676
panic("scaler count must divide vector size")
7777
}
@@ -154,8 +154,8 @@ func (v *vector32) ScaleRepeated(other anyvec.Vector) {
154154
}
155155

156156
func (v *vector32) repeatedOp(kernel string, v1 *vector32) {
157-
if v == v1 {
158-
panic("inputs overlap")
157+
if v.Overlaps(v1) {
158+
panic("invalid overlap")
159159
} else if v1.Len() == 0 {
160160
panic("repeated vector cannot be empty")
161161
}
@@ -434,8 +434,8 @@ func (v *vector32) BatchedGemm(transA, transB bool, num, m, n, k int, alpha anyv
434434
b32 := b.(*vector32)
435435
alpha32 := alpha.(float32)
436436
beta32 := beta.(float32)
437-
if a32 == v || b32 == v {
438-
panic("vectors cannot be equal")
437+
if v.Overlaps(a32) || v.Overlaps(b32) {
438+
panic("invalid overlap")
439439
}
440440
v.creator.run(func() error {
441441
if err := lazyInitAll(true, a32, b32, v); err != nil {

0 commit comments

Comments
 (0)