-
-
Notifications
You must be signed in to change notification settings - Fork 7
/
session.go
201 lines (173 loc) · 4.26 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
package mango
import (
"context"
"reflect"
"sync"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
)
// Session is a MongoDB session. It bundles the connection settings
// (connection string, pool limit, default database name), the driver
// client created by Connect, and the query modifiers read by One and All.
type Session struct {
	client      *mongo.Client     // assigned by Connect
	collection  *mongo.Collection // queried by One and All; not assigned in the visible code
	maxPoolSize uint64            // set by SetPoolLimit, read by Connect
	db, uri     string            // default database name ("test" if empty); connection string from New
	m           sync.RWMutex      // locked by SetPoolLimit, C, Collection and DB
	filter      interface{}       // query filter passed to Find/FindOne
	limit       *int64            // nil = no limit; set by Limit, read by All
	project     interface{}       // NOTE(review): not read in the visible code — confirm usage
	skip        *int64            // nil = no skip; set by Skip, read by All
	sort        interface{}       // nil = no sort; set by Sort, read by All
}
// New returns a Session for the given MongoDB connection string. No
// connection is opened here; the driver client is created by Connect.
//
// Relevant documentation:
//
// https://docs.mongodb.com/manual/reference/connection-string/
func New(uri string) *Session {
	return &Session{uri: uri}
}
// C is a short alias for Collection. It returns a handle to the named
// collection in the session's database. If no database name has been set,
// the name defaults to "test", and that default is stored on the session.
//
// It delegates to Collection so the two entry points cannot drift apart;
// Collection takes the session lock itself.
func (s *Session) C(collection string) *Collection {
	return s.Collection(collection)
}
// Collection returns a handle to the named collection in the session's
// database. If the database name is empty it is first set to "test"; that
// assignment persists on the session. The session lock is held for the
// whole call.
func (s *Session) Collection(collection string) *Collection {
	s.m.Lock()
	defer s.m.Unlock()

	if s.db == "" {
		s.db = "test"
	}
	coll := s.client.Database(s.db).Collection(collection)
	return &Collection{collection: coll}
}
// SetPoolLimit records the maximum size of each server's connection pool.
// The value is read by Connect when it builds the client options, so it must
// be called before Connect to take effect.
func (s *Session) SetPoolLimit(limit uint64) {
	s.m.Lock()
	s.maxPoolSize = limit
	s.m.Unlock()
}
// Connect builds a driver client from the session's connection string and
// connects it, passing a 20-second timeout context to client.Connect.
// On success the client is stored on the session; on failure the session's
// existing client is left untouched and the error is returned.
//
// A pool limit from SetPoolLimit is applied only when it is non-zero. The
// driver treats an explicit max pool size of 0 as "unlimited", so applying
// it unconditionally would override a maxPoolSize given in the URI (and the
// driver's default) whenever SetPoolLimit was never called.
func (s *Session) Connect() error {
	// Snapshot the settings under the same lock SetPoolLimit uses. Do not
	// hold it across the network call, which can take up to the timeout.
	s.m.Lock()
	uri, poolSize := s.uri, s.maxPoolSize
	s.m.Unlock()

	opt := options.Client().ApplyURI(uri)
	if poolSize > 0 {
		opt.SetMaxPoolSize(poolSize)
	}

	client, err := mongo.NewClient(opt)
	if err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()
	if err = client.Connect(ctx); err != nil {
		return err
	}

	// C, Collection and DB read s.client under s.m, so publish it under s.m.
	s.m.Lock()
	s.client = client
	s.m.Unlock()
	return nil
}
// Ping checks that the session's client can reach the deployment's primary.
// It always uses the primary read preference (there is no read-preference
// argument) and context.TODO(), so no deadline is imposed by this method.
// The client is not nil-checked; Connect must have succeeded first.
func (s *Session) Ping() error {
	return s.client.Ping(context.TODO(), readpref.Primary())
}
// Client exposes the underlying driver client, which is assigned by Connect.
// The field is returned as-is, without taking the session lock.
func (s *Session) Client() *mongo.Client {
	c := s.client
	return c
}
// DB returns a handle to the database with the given name. Unlike
// Collection, the name is used exactly as given (no "test" fallback) and
// the session's default database name is not modified.
func (s *Session) DB(db string) *Database {
	s.m.Lock()
	defer s.m.Unlock()

	handle := s.client.Database(db)
	return &Database{database: handle}
}
// Limit records the maximum number of documents All should return and
// returns s so calls can be chained. The value is forwarded to the driver's
// Find options, where a negative limit requests a single batch only.
// It stays set on the session; nothing here clears it after a query.
func (s *Session) Limit(limit int64) *Session {
	n := limit
	s.limit = &n
	return s
}
// Skip records how many matching documents All should skip before it starts
// returning results, and returns s so calls can be chained. The value stays
// set on the session for later queries.
func (s *Session) Skip(skip int64) *Session {
	n := skip
	s.skip = &n
	return s
}
// Sort records the sort specification (for example a bson.D document) that
// All passes to the driver, and returns s so calls can be chained. A nil
// value means no sort is requested.
func (s *Session) Sort(sort interface{}) *Session {
	s.sort = sort
	return s
}
// One decodes the first document that matches the session filter into
// result. The raw document is fetched with FindOne, then unmarshaled with
// bson.Unmarshal. Any error from either step is returned unchanged. The
// session's sort, skip and limit settings are not applied here.
func (s *Session) One(result interface{}) error {
	raw, err := s.collection.FindOne(context.TODO(), s.filter).DecodeBytes()
	if err != nil {
		return err
	}
	return bson.Unmarshal(raw, result)
}
// All runs a Find with the session's filter, sort, limit and skip, and
// decodes every returned document into the slice that result points to.
//
// result must be a pointer to a slice, or a pointer to an interface holding
// a slice; anything else panics (mirroring mgo's Iter.All). The slice's
// backing array is reused when its capacity allows, but previous contents
// never appear in the result. On success *result holds exactly the decoded
// documents, in cursor order.
//
// The query and the iteration share one 5-second timeout. Errors from Find,
// from decoding a document, or from the cursor during iteration are
// returned. On error *result is not reassigned, although elements inside its
// existing capacity may already have been overwritten.
func (s *Session) All(result interface{}) error {
	resultv := reflect.ValueOf(result)
	if resultv.Kind() != reflect.Ptr {
		panic("result argument must be a slice address")
	}
	slicev := resultv.Elem()
	if slicev.Kind() == reflect.Interface {
		slicev = slicev.Elem()
	}
	if slicev.Kind() != reflect.Slice {
		panic("result argument must be a slice address")
	}
	// Truncate to length 0 but keep the capacity, so reflect.Append fills the
	// backing array from index 0. Extending to full capacity here would make
	// the appended documents land after stale elements.
	slicev = slicev.Slice(0, 0)
	elemt := slicev.Type().Elem()

	opt := options.Find()
	if s.sort != nil {
		opt.SetSort(s.sort)
	}
	if s.limit != nil {
		opt.SetLimit(*s.limit)
	}
	if s.skip != nil {
		opt.SetSkip(*s.skip)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cur, err := s.collection.Find(ctx, s.filter, opt)
	if err != nil {
		return err
	}
	// Defer only after the error check: cur is nil when Find fails.
	// The Close error is ignored; the result has already been decided.
	defer cur.Close(ctx)

	for cur.Next(ctx) {
		elemp := reflect.New(elemt)
		if err = bson.Unmarshal(cur.Current, elemp.Interface()); err != nil {
			return err
		}
		slicev = reflect.Append(slicev, elemp.Elem())
	}
	// Next returns false both when results are exhausted and when iteration
	// fails (e.g. the timeout fires). Err tells the two cases apart, so it
	// must be checked after the loop.
	if err = cur.Err(); err != nil {
		return err
	}

	resultv.Elem().Set(slicev)
	return nil
}