diff --git a/v2/connection.go b/v2/connection.go index d8c57949..c1d87644 100755 --- a/v2/connection.go +++ b/v2/connection.go @@ -426,7 +426,7 @@ func (conn *Connection) OpenWithContext(ctx context.Context) error { // start check for context session.StartContext(ctx) defer session.EndContext() - err := session.Connect() + err := session.Connect(ctx) if err != nil { return err } diff --git a/v2/data_set.go b/v2/data_set.go index 4925050c..93887846 100644 --- a/v2/data_set.go +++ b/v2/data_set.go @@ -153,7 +153,7 @@ func (dataSet *DataSet) Scan(dest ...interface{}) error { continue } colInfo := (*dataSet.cols)[srcIndex+processedFields] - if strings.ToUpper(colInfo.Name) != strings.ToUpper(name) { + if !strings.EqualFold(colInfo.Name, name) { continue } err := dataSet.setObjectValue(reflect.ValueOf(dest[destIndex]).Elem().Field(x), srcIndex+processedFields) @@ -167,7 +167,6 @@ func (dataSet *DataSet) Scan(dest ...interface{}) error { srcIndex = srcIndex + processedFields - 1 continue } - } // else err := dataSet.setObjectValue(reflect.ValueOf(dest[destIndex]).Elem(), srcIndex) @@ -193,39 +192,40 @@ func (dataSet *DataSet) Scan(dest ...interface{}) error { // for non-supported type // error means error occur during operation func (dataSet *DataSet) setObjectValue(obj reflect.Value, colIndex int) error { - value := dataSet.currentRow[colIndex] + //value := dataSet.currentRow[colIndex] col := (*dataSet.cols)[colIndex] - if value == nil { - return setNull(obj) - } - if obj.Kind() == reflect.Interface { - obj.Set(reflect.ValueOf(value)) - return nil - } - switch val := value.(type) { - case int64: - return setNumber(obj, float64(val)) - case float64: - return setNumber(obj, val) - case string: - return setString(obj, val) - case time.Time: - return setTime(obj, val) - case []byte: - return setBytes(obj, val) - case bool: - if val { - return setNumber(obj, 1) - } else { - return setNumber(obj, 0) - } - default: - if col.cusType != nil && col.cusType.typ == obj.Type() { - obj.Set(reflect.ValueOf(value)) - return nil - } - return fmt.Errorf("can't assign value: %v to object of type: %v", value, obj.Type().Name()) - } + //if value == nil { + // return setNull(obj) + //} + //if obj.Kind() == reflect.Interface { + // obj.Set(reflect.ValueOf(value)) + // return nil + //} + return setFieldValue(obj, col.cusType, dataSet.currentRow[colIndex]) + //switch val := value.(type) { + //case int64: + // return setNumber(obj, float64(val)) + //case float64: + // return setNumber(obj, val) + //case string: + // return setString(obj, val) + //case time.Time: + // return setTime(obj, val) + //case []byte: + // return setBytes(obj, val) + //case bool: + // if val { + // return setNumber(obj, 1) + // } else { + // return setNumber(obj, 0) + // } + //default: + // if col.cusType != nil && col.cusType.typ == obj.Type() { + // obj.Set(reflect.ValueOf(value)) + // return nil + // } + // return fmt.Errorf("can't assign value: %v to object of type: %v", value, obj.Type().Name()) + //} //err := setFieldValue(obj, col.cusType, dataSet.currentRow[colIndex]) //if err != nil { // return err diff --git a/v2/network/session.go b/v2/network/session.go index d6914f2c..b7fb8bf5 100755 --- a/v2/network/session.go +++ b/v2/network/session.go @@ -39,8 +39,8 @@ type SessionState struct { } type Session struct { - ctx context.Context - oldCtx context.Context + //ctx context.Context + //oldCtx context.Context conn net.Conn sslConn *tls.Conn reader *bufio.Reader @@ -52,7 +52,7 @@ type Session struct { outBuffer bytes.Buffer index int breakIndex int - doneContext chan struct{} + doneContext []chan struct{} TimeZone []byte TTCVersion uint8 HasEOSCapability bool @@ -85,7 +85,7 @@ func NewSessionWithInputBufferForDebug(input []byte) *Session { Tracer: trace.NilTracer(), } return &Session{ - ctx: context.Background(), + //ctx: context.Background(), conn: nil, inBuffer: input, index: 0, @@ -100,7 +100,7 @@ func NewSessionWithInputBufferForDebug(input []byte) *Session { func NewSession(connOption *ConnectionOption) *Session { return &Session{ - ctx: context.Background(), + //ctx: context.Background(), conn: nil, inBuffer: nil, index: 0, @@ -177,14 +177,15 @@ func (session *Session) ResetBuffer() { } func (session *Session) StartContext(ctx context.Context) { - session.oldCtx = session.ctx - session.ctx = ctx - session.doneContext = make(chan struct{}) - go func() { + //session.oldCtx = session.ctx + //session.ctx = ctx + done := make(chan struct{}) + session.doneContext = append(session.doneContext, done) + go func(done chan struct{}) { var err error var tracer = session.Context.ConnOption.Tracer select { - case <-session.doneContext: + case <-done: return case <-ctx.Done(): if session.Connected { @@ -199,14 +200,27 @@ func (session *Session) StartContext(ctx context.Context) { session.Disconnect() } } - }() + }(done) } func (session *Session) EndContext() { - if session.doneContext != nil { - close(session.doneContext) + index := len(session.doneContext) - 1 + if index >= 0 { + done := session.doneContext[index] + if done != nil { + close(done) + done = nil + } + } + if index == 0 { session.doneContext = nil + } else { + session.doneContext = session.doneContext[:index] } - session.ctx = session.oldCtx + //if session.doneContext != nil { + // close(session.doneContext) + // session.doneContext = nil + //} + //session.ctx = session.oldCtx } func (session *Session) initRead() error { @@ -452,7 +466,7 @@ func (session *Session) BreakConnection() error { // check if the client need to use SSL // then send connect packet to the server and // receive either accept, redirect or refuse packet -func (session *Session) Connect() error { +func (session *Session) Connect(ctx context.Context) error { connOption := session.Context.ConnOption session.Disconnect() connOption.Tracer.Print("Connect") @@ -482,9 +496,9 @@ func (session *Session) Connect() error { } addr := host.networkAddr() if len(session.Context.ConnOption.UnixAddress) > 0 { - session.conn, err = dialer.DialContext(session.ctx, "unix", session.Context.ConnOption.UnixAddress) + session.conn, err = dialer.DialContext(ctx, "unix", session.Context.ConnOption.UnixAddress) } else { - session.conn, err = dialer.DialContext(session.ctx, "tcp", addr) + session.conn, err = dialer.DialContext(ctx, "tcp", addr) } if err != nil { @@ -548,7 +562,7 @@ func (session *Session) Connect() error { connOption.AddServer(srv) } host = connOption.GetActiveServer(true) - return session.Connect() + return session.Connect(ctx) } if refusePacket, ok := pck.(*RefusePacket); ok { refusePacket.extractErrCode() @@ -564,7 +578,7 @@ func (session *Session) Connect() error { session.Disconnect() return &refusePacket.Err } - return session.Connect() + return session.Connect(ctx) } return errors.New("connection refused by the server due to unknown reason") } diff --git a/v2/utils.go b/v2/utils.go index 09353bdc..b5e3e319 100644 --- a/v2/utils.go +++ b/v2/utils.go @@ -570,9 +570,11 @@ func setFieldValue(fieldValue reflect.Value, cust *customType, input interface{} } if fieldValue.Kind() == reflect.Ptr && fieldValue.Elem().Kind() == reflect.Interface { fieldValue.Elem().Set(reflect.ValueOf(input)) + return nil } if fieldValue.Kind() == reflect.Interface { fieldValue.Set(reflect.ValueOf(input)) + return nil } //if fieldValue.CanAddr() { // if scan, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {