diff --git a/sql/mapper.go b/sql/mapper.go index f0d540b..712e41d 100644 --- a/sql/mapper.go +++ b/sql/mapper.go @@ -131,7 +131,20 @@ func (m *mapper) handleCast(actual *expr.Call, args *[]driver.NamedValue, item * if len(*args) == 0 { return fmt.Errorf("missing cast argument %v", item.Alias) } - m.values[i] = (*args)[0].Value + + switch (*args)[0].Value.(type) { + case string: + m.values[i] = (*args)[0].Value.(string) + case time.Time: + m.values[i] = (*args)[0].Value.(time.Time).Format(time.RFC3339) + case *time.Time: + if ts := (*args)[0].Value; ts != nil { + m.values[i] = ts.(*time.Time).Format(time.RFC3339) + } + + default: + m.values[i] = (*args)[0].Value + } *args = (*args)[1:] } case "int": @@ -180,6 +193,38 @@ func (m *mapper) handleCast(actual *expr.Call, args *[]driver.NamedValue, item * } else { return fmt.Errorf("%v unsupported cast argument type: %T", item.Alias, actual.Args[0]) } + case "time", "datetime", "timestamp": + if _, ok := actual.Args[0].(*expr.Placeholder); ok { + if len(*args) == 0 { + return fmt.Errorf("missing cast argument %v", item.Alias) + } + + var tValue *time.Time + switch (*args)[0].Value.(type) { + case time.Time: + t := (*args)[0].Value.(time.Time) + tValue = &t + case *time.Time: + tValue = (*args)[0].Value.(*time.Time) + + case string: + t, err := time.Parse(time.RFC3339, (*args)[0].Value.(string)) + if err != nil { + return fmt.Errorf("%v invalid time: %v %w", item.Alias, (*args)[0].Value, err) + } + tValue = &t + default: + return fmt.Errorf("%v unsupported int argument type: %T", item.Alias, (*args)[0].Value) + } + if tValue == nil { + m.values[i] = nil + } else { + m.values[i] = *tValue + } + *args = (*args)[1:] + } else { + return fmt.Errorf("%v unsupported cast argument type: %T", item.Alias, actual.Args[0]) + } case "float": if _, ok := actual.Args[0].(*expr.Placeholder); ok { diff --git a/sql/rows.go b/sql/rows.go index 76c78a6..038a97d 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -89,8 +89,8 @@ func (r *Rows) Next(dest []driver.Value) error { if r.resourceIndex >= len(r.resources) || r.isFalsePredicate { return io.EOF } - if len(dest) != len(r.mapper.byPos)+len(r.mapper.values) { - return fmt.Errorf("expected %v, but had %v", len(r.mapper.byPos)+len(r.mapper.values), len(dest)) + if len(dest) != len(r.mapper.byPos) { + return fmt.Errorf("expected %v, but had %v", len(r.mapper.byPos), len(dest)) } res := r.resource() has, err := res.Next() diff --git a/sql/statement.go b/sql/statement.go index 5953bf7..297f332 100644 --- a/sql/statement.go +++ b/sql/statement.go @@ -271,6 +271,8 @@ func (s *Statement) autodetectType(ctx context.Context, res resources) (reflect. raw = strings.TrimSpace(raw[idx+4 : len(raw)-1]) } switch raw { + case "time", "datetime", "timestamp": + field = reflect.StructField{Name: item.Alias, Type: reflect.TypeOf(time.Time{})} case "char": field = reflect.StructField{Name: item.Alias, Type: reflect.TypeOf("")} case "int":