@@ -57,7 +57,7 @@ type Config struct {
57
57
SSLKey string `json:"ssl-key"`
58
58
VerifyCert bool `json:"verify-cert"`
59
59
ServerId uint64 `json:"server-id"`
60
- BinLogFile string `json:"binlog-file"`
60
+ BinlogFile string `json:"binlog-file"`
61
61
Timeout time.Duration
62
62
}
63
63
@@ -89,6 +89,8 @@ type Conn struct {
89
89
writeBuf * bytes.Buffer
90
90
StausFlags * StatusFlags
91
91
Listener * net.Listener
92
+ packetHeader * PacketHeader
93
+ scanPos uint64
92
94
}
93
95
94
96
func newBinlogConn (config * Config ) Conn {
@@ -168,40 +170,34 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
168
170
}
169
171
170
172
// Listen for auth response.
171
- r , err := c .listen ()
172
- switch r .(type ) {
173
- case * OKPacket : // Login successful.
174
- // Reset sequence to 0 now that connection phase is over
175
- // and command phase has started.
176
- c .sequenceId = 0
177
-
178
- // Tell server to stream binlog.
179
- err = c .startBinLogStream ()
180
- if err != nil {
181
- return nil , err
182
- }
173
+ _ , err = c .readPacket ()
174
+ if err != nil {
175
+ return nil , err
176
+ }
183
177
184
- err = c .listenForBinlog ()
185
- if err != nil {
186
- return nil , err
187
- }
188
- case * ErrorPacket : // Bad login.
189
- ep := r .(* ErrorPacket )
190
- em := fmt .Sprintf ("Error %d: %s" , ep .ErrorCode , ep .ErrorMessage )
191
- err = errors .New (em )
178
+ // Auth was successful.
179
+ c .sequenceId = 0
180
+
181
+ // Register as a slave
182
+ err = c .registerAsSlave ()
183
+ if err != nil {
184
+ return nil , err
185
+ }
186
+
187
+ _ , err = c .readPacket ()
188
+ if err != nil {
192
189
return nil , err
193
- case * AuthMoreDataPacket :
194
- panic (errors .New ("Unexpected AuthMoreDataPacket" ))
195
190
}
196
191
197
192
return c , err
198
193
}
199
194
200
- func (c * Conn ) listen () (interface {}, error ) {
195
+ func (c * Conn ) readPacket () (interface {}, error ) {
201
196
ph , err := c .getPacketHeader ()
202
197
if err != nil {
203
198
return nil , err
204
199
}
200
+
205
201
c .sequenceId ++
206
202
207
203
var res interface {}
@@ -222,7 +218,6 @@ func (c *Conn) listen() (interface{}, error) {
222
218
return nil , c .Flush ()
223
219
}
224
220
}
225
-
226
221
case StatusEOF :
227
222
fallthrough
228
223
case StatusOK :
@@ -236,12 +231,11 @@ func (c *Conn) listen() (interface{}, error) {
236
231
return nil , err
237
232
}
238
233
239
- err = errors .New (
240
- fmt .Sprintf (
241
- "Error %d: %s" ,
242
- res .(* ErrorPacket ).ErrorCode ,
243
- res .(* ErrorPacket ).ErrorMessage ,
244
- ))
234
+ err = fmt .Errorf (
235
+ "Error %d: %s" ,
236
+ res .(* ErrorPacket ).ErrorCode ,
237
+ res .(* ErrorPacket ).ErrorMessage ,
238
+ )
245
239
246
240
return res , err
247
241
}
@@ -260,18 +254,27 @@ type PacketHeader struct {
260
254
Status uint64
261
255
}
262
256
263
- func (c * Conn ) getPacketHeader () (PacketHeader , error ) {
257
+ func (c * Conn ) getPacketHeader () (* PacketHeader , error ) {
264
258
ph := PacketHeader {}
265
259
ph .Length = c .getInt (TypeFixedInt , 3 )
260
+
261
+ if ph .Length == 0 {
262
+ err := errors .New ("EOF" )
263
+ return nil , err
264
+ }
265
+
266
266
ph .SequenceID = c .getInt (TypeFixedInt , 1 )
267
267
ph .Status = c .getInt (TypeFixedInt , 1 )
268
268
269
269
err := c .scanner .Err ()
270
270
if err != nil {
271
- return ph , err
271
+ return & ph , err
272
272
}
273
273
274
- return ph , nil
274
+ c .packetHeader = & ph
275
+ c .scanPos = 0
276
+
277
+ return & ph , nil
275
278
}
276
279
277
280
func init () {
@@ -289,33 +292,12 @@ func (c *Conn) readBytes(l uint64) *bytes.Buffer {
289
292
} else {
290
293
panic (err ) // @TODO Handle this gracefully.
291
294
}
292
-
293
- return nil
294
295
}
295
296
296
297
b = append (b , c .scanner .Bytes ()... )
297
298
}
298
299
299
- return bytes .NewBuffer (b )
300
- }
301
-
302
- func (c * Conn ) getBytesUntilEOF () * bytes.Buffer {
303
- l := uint64 (1 )
304
- s := c .readBytes (l )
305
- b := s .Bytes ()
306
-
307
- for true {
308
- if uint64 (s .Len ()) != l || s .Bytes ()[0 ] == NullByte {
309
- break
310
- }
311
-
312
- s := c .readBytes (uint64 (l ))
313
- if s == EOF || s == nil {
314
- return bytes .NewBuffer (b )
315
- }
316
-
317
- b = append (b , s .Bytes ()... )
318
- }
300
+ c .scanPos += uint64 (len (b ))
319
301
320
302
return bytes .NewBuffer (b )
321
303
}
@@ -325,7 +307,7 @@ func (c *Conn) getBytesUntilNull() *bytes.Buffer {
325
307
s := c .readBytes (l )
326
308
b := s .Bytes ()
327
309
328
- for true {
310
+ for {
329
311
if uint64 (s .Len ()) != l || s .Bytes ()[0 ] == NullByte {
330
312
break
331
313
}
@@ -376,8 +358,15 @@ func (c *Conn) getString(t int, l uint64) string {
376
358
}
377
359
378
360
func (c * Conn ) decRestOfPacketString () string {
379
- b := c .getBytesUntilEOF ()
380
- return string (b .Bytes ())
361
+ b := c .getRemainingBytes ()
362
+ return b .String ()
363
+ }
364
+
365
+ func (c * Conn ) getRemainingBytes () * bytes.Buffer {
366
+ l := (c .packetHeader .Length - 1 ) - c .scanPos
367
+ b := c .readBytes (l )
368
+
369
+ return b
381
370
}
382
371
383
372
func (c * Conn ) decNullTerminatedString () string {
@@ -651,7 +640,7 @@ type StatusFlags struct {
651
640
}
652
641
653
642
type OKPacket struct {
654
- PacketHeader
643
+ * PacketHeader
655
644
Header uint64
656
645
AffectedRows uint64
657
646
LastInsertID uint64
@@ -661,7 +650,7 @@ type OKPacket struct {
661
650
SessionStateInfo string
662
651
}
663
652
664
- func (c * Conn ) decodeOKPacket (ph PacketHeader ) (* OKPacket , error ) {
653
+ func (c * Conn ) decodeOKPacket (ph * PacketHeader ) (* OKPacket , error ) {
665
654
op := OKPacket {}
666
655
op .PacketHeader = ph
667
656
op .Header = ph .Status
@@ -684,14 +673,14 @@ func (c *Conn) decodeOKPacket(ph PacketHeader) (*OKPacket, error) {
684
673
}
685
674
686
675
type ErrorPacket struct {
687
- PacketHeader
676
+ * PacketHeader
688
677
ErrorCode uint64
689
678
ErrorMessage string
690
679
SQLStateMarker string
691
680
SQLState string
692
681
}
693
682
694
- func (c * Conn ) decodeErrorPacket (ph PacketHeader ) (* ErrorPacket , error ) {
683
+ func (c * Conn ) decodeErrorPacket (ph * PacketHeader ) (* ErrorPacket , error ) {
695
684
ep := ErrorPacket {}
696
685
ep .PacketHeader = ph
697
686
ep .ErrorCode = c .getInt (TypeFixedInt , 2 )
0 commit comments