Skip to content

Commit 2713f72

Browse files
committed
basic vector primitives
1 parent bc1eb93 commit 2713f72

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

creator32.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package cudavec
2+
3+
import "github.com/unixpickle/anyvec"
4+
5+
// A Creator32 is an anyvec.Creator for vectors using
6+
// float32 numerics and []float32 slice types.
7+
type Creator32 struct {
8+
Handle *Handle
9+
}
10+
11+
// MakeNumeric creates a float32.
12+
func (c *Creator32) MakeNumeric(x float64) anyvec.Numeric {
13+
return float32(x)
14+
}
15+
16+
// MakeNumericList creates a []float32.
17+
func (c *Creator32) MakeNumericList(x []float64) anyvec.NumericList {
18+
res := make([]float32, len(x))
19+
for i, k := range x {
20+
res[i] = float32(k)
21+
}
22+
return res
23+
}
24+
25+
// MakeVector creates a zero'd out anyvec.Vector.
26+
func (c *Creator32) MakeVector(size int) anyvec.Vector {
27+
panic("nyi")
28+
}
29+
30+
// MakeVectorData creates an anyvec.Vector with the
31+
// specified contents.
32+
func (c *Creator32) MakeVectorData(dObj anyvec.NumericList) anyvec.Vector {
33+
panic("nyi")
34+
}
35+
36+
// Concat concatenates vectors.
37+
func (c *Creator32) Concat(v ...anyvec.Vector) anyvec.Vector {
38+
panic("nyi")
39+
}
40+
41+
// MakeMapper creates a mapper.
42+
func (c *Creator32) MakeMapper(inSize int, table []int) anyvec.Mapper {
43+
panic("nyi")
44+
}

vector32.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package cudavec
2+
3+
import (
4+
"github.com/unixpickle/anyvec"
5+
"github.com/unixpickle/cuda"
6+
)
7+
8+
type vector32 struct {
9+
creator *Creator32
10+
size int
11+
12+
// May be nil for lazy evaluations.
13+
buffer cuda.Buffer
14+
}
15+
16+
func (v *vector32) Creator() anyvec.Creator {
17+
return v.creator
18+
}
19+
20+
func (v *vector32) Len() int {
21+
return v.size
22+
}
23+
24+
func (v *vector32) Data() anyvec.NumericList {
25+
res := make([]float32, v.Len())
26+
v.runSync(func() error {
27+
if v.buffer != nil {
28+
return cuda.ReadBuffer(res, v.buffer)
29+
}
30+
return nil
31+
})
32+
return res
33+
}
34+
35+
func (v *vector32) SetData(d anyvec.NumericList) {
36+
slice := d.([]float32)
37+
if len(slice) > v.Len() {
38+
panic("index out of bounds")
39+
}
40+
v.runSync(func() error {
41+
if err := v.lazyInit(len(slice) < v.Len()); err != nil {
42+
return err
43+
}
44+
return cuda.WriteBuffer(v.buffer, slice)
45+
})
46+
}
47+
48+
func (v *vector32) Scale(s anyvec.Numeric) {
49+
v.run(func() error {
50+
if v.buffer == nil {
51+
return nil
52+
}
53+
return v.creator.Handle.blas.Sscal(v.Len(), s.(float32), v.buffer, 1)
54+
})
55+
}
56+
57+
func (v *vector32) run(f func() error) <-chan error {
58+
return v.creator.Handle.context.Run(func() error {
59+
if err := f(); err != nil {
60+
panic(err)
61+
}
62+
return nil
63+
})
64+
}
65+
66+
func (v *vector32) runSync(f func() error) {
67+
<-v.run(f)
68+
}
69+
70+
func (v *vector32) lazyInit(clear bool) error {
71+
if v.buffer != nil {
72+
return nil
73+
}
74+
var err error
75+
v.buffer, err = cuda.AllocBuffer(v.creator.Handle.allocator, uintptr(v.Len())*4)
76+
if err != nil {
77+
return err
78+
}
79+
if clear {
80+
return cuda.ClearBuffer(v.buffer)
81+
}
82+
return nil
83+
}

0 commit comments

Comments
 (0)