Skip to content

Commit 447c64a

Browse files
committed
fix return types
only return types which are allowed by database/sql/driver : int64 float64 bool []byte time.Time
1 parent 74a6452 commit 447c64a

File tree

5 files changed

+45
-20
lines changed

5 files changed

+45
-20
lines changed

const.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,8 @@ const (
7474
COM_STMT_FETCH
7575
)
7676

77-
type FieldType byte
78-
7977
const (
80-
FIELD_TYPE_DECIMAL FieldType = iota
78+
FIELD_TYPE_DECIMAL byte = iota
8179
FIELD_TYPE_TINY
8280
FIELD_TYPE_SHORT
8381
FIELD_TYPE_LONG
@@ -96,7 +94,7 @@ const (
9694
FIELD_TYPE_BIT
9795
)
9896
const (
99-
FIELD_TYPE_NEWDECIMAL FieldType = iota + 0xf6
97+
FIELD_TYPE_NEWDECIMAL byte = iota + 0xf6
10098
FIELD_TYPE_ENUM
10199
FIELD_TYPE_SET
102100
FIELD_TYPE_TINY_BLOB

driver_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ func TestFloat(t *testing.T) {
249249
mustExec(t, db, "DROP TABLE IF EXISTS test")
250250

251251
types := [2]string{"FLOAT", "DOUBLE"}
252-
in := float64(42.23)
253-
var out float64
252+
in := float32(42.23)
253+
var out float32
254254
var rows *sql.Rows
255255

256256
for _, v := range types {

packets.go

+18-13
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
434434
pos += n + 1 + 2 + 4
435435

436436
// Field type [byte]
437-
columns[i].fieldType = FieldType(data[pos])
437+
columns[i].fieldType = data[pos]
438438
pos++
439439

440440
// Flags [16 bit uint]
@@ -561,26 +561,26 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
561561
// build NULL-bitmap
562562
if args[i] == nil {
563563
bitMask += 1 << uint(i)
564-
paramTypes[i<<1] = byte(FIELD_TYPE_NULL)
564+
paramTypes[i<<1] = FIELD_TYPE_NULL
565565
continue
566566
}
567567

568568
// cache types and values
569569
switch args[i].(type) {
570570
case int64:
571-
paramTypes[i<<1] = byte(FIELD_TYPE_LONGLONG)
571+
paramTypes[i<<1] = FIELD_TYPE_LONGLONG
572572
paramValues[i] = uint64ToBytes(uint64(args[i].(int64)))
573573
pktLen += 8
574574
continue
575575

576576
case float64:
577-
paramTypes[i<<1] = byte(FIELD_TYPE_DOUBLE)
577+
paramTypes[i<<1] = FIELD_TYPE_DOUBLE
578578
paramValues[i] = uint64ToBytes(math.Float64bits(args[i].(float64)))
579579
pktLen += 8
580580
continue
581581

582582
case bool:
583-
paramTypes[i<<1] = byte(FIELD_TYPE_TINY)
583+
paramTypes[i<<1] = FIELD_TYPE_TINY
584584
pktLen++
585585
if args[i].(bool) {
586586
paramValues[i] = []byte{0x01}
@@ -590,7 +590,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
590590
continue
591591

592592
case []byte:
593-
paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
593+
paramTypes[i<<1] = FIELD_TYPE_STRING
594594
val := args[i].([]byte)
595595
paramValues[i] = append(
596596
lengthEncodedIntegerToBytes(uint64(len(val))),
@@ -600,7 +600,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
600600
continue
601601

602602
case string:
603-
paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
603+
paramTypes[i<<1] = FIELD_TYPE_STRING
604604
val := []byte(args[i].(string))
605605
paramValues[i] = append(
606606
lengthEncodedIntegerToBytes(uint64(len(val))),
@@ -610,7 +610,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
610610
continue
611611

612612
case time.Time:
613-
paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
613+
paramTypes[i<<1] = FIELD_TYPE_STRING
614614
val := []byte(args[i].(time.Time).Format(TIME_FORMAT))
615615
paramValues[i] = append(
616616
lengthEncodedIntegerToBytes(uint64(len(val))),
@@ -718,7 +718,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
718718
// Numeric Typs
719719
case FIELD_TYPE_TINY:
720720
if unsigned {
721-
dest[i] = uint64(data[pos])
721+
dest[i] = int64(data[pos])
722722
} else {
723723
dest[i] = int64(int8(data[pos]))
724724
}
@@ -727,7 +727,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
727727

728728
case FIELD_TYPE_SHORT, FIELD_TYPE_YEAR:
729729
if unsigned {
730-
dest[i] = uint64(binary.LittleEndian.Uint16(data[pos : pos+2]))
730+
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
731731
} else {
732732
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
733733
}
@@ -736,7 +736,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
736736

737737
case FIELD_TYPE_INT24, FIELD_TYPE_LONG:
738738
if unsigned {
739-
dest[i] = uint64(binary.LittleEndian.Uint32(data[pos : pos+4]))
739+
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
740740
} else {
741741
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
742742
}
@@ -745,15 +745,20 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
745745

746746
case FIELD_TYPE_LONGLONG:
747747
if unsigned {
748-
dest[i] = binary.LittleEndian.Uint64(data[pos : pos+8])
748+
val := binary.LittleEndian.Uint64(data[pos : pos+8])
749+
if val > math.MaxInt64 {
750+
dest[i] = uint64ToString(val)
751+
} else {
752+
dest[i] = int64(val)
753+
}
749754
} else {
750755
dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
751756
}
752757
pos += 8
753758
continue
754759

755760
case FIELD_TYPE_FLOAT:
756-
dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
761+
dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
757762
pos += 4
758763
continue
759764

rows.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717

1818
type mysqlField struct {
1919
name string
20-
fieldType FieldType
20+
fieldType byte
2121
flags FieldFlag
2222
}
2323

utils.go

+22
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,28 @@ func uint64ToBytes(n uint64) []byte {
128128
}
129129
}
130130

131+
func uint64ToString(n uint64) []byte {
132+
var a [20]byte
133+
i := 20
134+
135+
// U+0030 = 0
136+
// ...
137+
// U+0039 = 9
138+
139+
var q uint64
140+
for n >= 10 {
141+
i--
142+
q = n / 10
143+
a[i] = uint8(n-q*10) + 0x30
144+
n = q
145+
}
146+
147+
i--
148+
a[i] = uint8(n) + 0x30
149+
150+
return a[i:]
151+
}
152+
131153
func readLengthEnodedString(b []byte) ([]byte, int, error) {
132154
// Get length
133155
num, _, n := readLengthEncodedInteger(b)

0 commit comments

Comments
 (0)