Skip to content

feat: support configuring custom time.Location for datetime encoding and decoding via DSN #260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Other supported formats are listed below.
* `true`/`mandatory`/`yes`/`1`/`t` - Data sent between client and server is encrypted.
* `app name` - The application name (default is go-mssqldb)
* `authenticator` - Can be used to specify use of a registered authentication provider. (e.g. ntlm, winsspi (on windows) or krb5 (on linux))
* `timezone` - Sets the time zone used by the driver when parsing time types. For example: `timezone=Asia/Shanghai`. Supports [IANA](https://www.iana.org/time-zones) time zone names.

### Connection parameters for ODBC and ADO style connection strings

Expand Down
15 changes: 8 additions & 7 deletions bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ const (
sqlTimeFormat = "15:04:05.9999999"
)

func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns}
func (c *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
b := Bulk{ctx: context.Background(), cn: c, tablename: table, headerSent: false, columnsName: columns}
b.Debug = false
return &b
}

func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) {
b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns}
func (c *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) {
b := Bulk{ctx: ctx, cn: c, tablename: table, headerSent: false, columnsName: columns}
b.Debug = false
return &b
}
Expand Down Expand Up @@ -207,7 +207,7 @@ func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x",
col.ColName, col.ti.TypeId)
}
err = col.ti.Writer(buf, param.ti, param.buffer)
err = col.ti.Writer(buf, param.ti, param.buffer, b.cn.sess.encoding)
if err != nil {
return nil, fmt.Errorf("bulkcopy: %s", err.Error())
}
Expand Down Expand Up @@ -318,6 +318,7 @@ func (b *Bulk) getMetadata(ctx context.Context) (err error) {
func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) {
res.ti.Size = col.ti.Size
res.ti.TypeId = col.ti.TypeId
loc := getTimezone(b.cn)

switch valuer := val.(type) {
case driver.Valuer:
Expand Down Expand Up @@ -487,7 +488,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
res.ti.Size = len(res.buffer)
case string:
var t time.Time
if t, err = time.ParseInLocation(sqlDateFormat, val, time.UTC); err != nil {
if t, err = time.ParseInLocation(sqlDateFormat, val, loc); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
}
res.buffer = encodeDate(t)
Expand All @@ -511,7 +512,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
}

if col.ti.Size == 4 {
res.buffer = encodeDateTim4(t)
res.buffer = encodeDateTim4(t, loc)
res.ti.Size = len(res.buffer)
} else if col.ti.Size == 8 {
res.buffer = encodeDateTime(t)
Expand Down
26 changes: 26 additions & 0 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,21 @@ const (
MultiSubnetFailover = "multisubnetfailover"
NoTraceID = "notraceid"
GuidConversion = "guid conversion"
Timezone = "timezone"
)

type EncodeParameters struct {
// Properly convert GUIDs, using correct byte endianness
GuidConversion bool
// Timezone is the timezone to use for encoding and decoding datetime values.
Timezone *time.Location
}

func (e EncodeParameters) GetTimezone() *time.Location {
if e.Timezone == nil {
return time.UTC
}
return e.Timezone
}

type Config struct {
Expand Down Expand Up @@ -301,6 +311,9 @@ func Parse(dsn string) (Config, error) {
p := Config{
ProtocolParameters: map[string]interface{}{},
Protocols: []string{},
Encoding: EncodeParameters{
Timezone: time.UTC,
},
}

activityid, uerr := uuid.NewRandom()
Expand All @@ -325,6 +338,15 @@ func Parse(dsn string) (Config, error) {
p.LogFlags = Log(flags)
}

tz, ok := params[Timezone]
if ok {
location, err := time.LoadLocation(tz)
if err != nil {
return p, fmt.Errorf("invalid timezone '%s': %s", tz, err.Error())
}
p.Encoding.Timezone = location
}

p.Database = params[Database]
p.User = params[UserID]
p.Password = params[Password]
Expand Down Expand Up @@ -612,6 +634,10 @@ func (p Config) URL() *url.URL {
q.Add(GuidConversion, strconv.FormatBool(p.Encoding.GuidConversion))
}

if tz := p.Encoding.Timezone; tz != nil && tz != time.UTC {
q.Add(Timezone, tz.String())
}

if len(q) > 0 {
res.RawQuery = q.Encode()
}
Expand Down
5 changes: 5 additions & 0 deletions msdsn/conn_str_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func TestInvalidConnectionString(t *testing.T) {
"applicationintent=ReadOnly",
"disableretry=invalid",
"multisubnetfailover=invalid",
"timezone=invalid",

// ODBC mode
"odbc:password={",
Expand All @@ -35,6 +36,7 @@ func TestInvalidConnectionString(t *testing.T) {
"odbc:=", // unexpected =
"odbc: =",
"odbc:password={some} a",
"odbc:timezone=invalid",

// URL mode
"sqlserver://\x00",
Expand Down Expand Up @@ -107,6 +109,7 @@ func TestValidConnectionString(t *testing.T) {
{"", func(p Config) bool { return p.DisableRetry == disableRetryDefault }},
{"MultiSubnetFailover=true;NoTraceID=true", func(p Config) bool { return p.MultiSubnetFailover && p.NoTraceID }},
{"MultiSubnetFailover=false", func(p Config) bool { return !p.MultiSubnetFailover }},
{"timezone=Asia/Shanghai", func(p Config) bool { return p.Encoding.Timezone.String() == "Asia/Shanghai" }},
{"Pwd=placeholder", func(p Config) bool { return p.Password == "placeholder" }},
// those are supported currently, but maybe should not be
{"someparam", func(p Config) bool { return true }},
Expand Down Expand Up @@ -164,6 +167,7 @@ func TestValidConnectionString(t *testing.T) {
{"odbc:server=somehost;user id=someuser;password=somepass; disableretry = 1 ", func(p Config) bool {
return p.Host == "somehost" && p.User == "someuser" && p.Password == "somepass" && p.DisableRetry
}},
{"odbc:timezone={Asia/Shanghai}", func(p Config) bool { return p.Encoding.Timezone.String() == "Asia/Shanghai" }},

// URL mode
{"sqlserver://somehost?connection+timeout=30", func(p Config) bool {
Expand Down Expand Up @@ -196,6 +200,7 @@ func TestValidConnectionString(t *testing.T) {
{"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1&guid+conversion=true", func(p Config) bool {
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption && p.Encoding.GuidConversion
}},
{"sqlserver://someuser@somehost?timezone=Asia%2FShanghai", func(p Config) bool { return p.Encoding.Timezone.String() == "Asia/Shanghai" }},
}
for _, ts := range connStrings {
p, err := Parse(ts.connStr)
Expand Down
1 change: 1 addition & 0 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.Size = 0
return
}

switch valuer := val.(type) {
// sql.Nullxxx integer types return an int64. We want the original type, to match the SQL type size.
case sql.NullByte:
Expand Down
6 changes: 4 additions & 2 deletions mssql_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
}

func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
loc := getTimezone(s.c)

switch val := val.(type) {
case VarChar:
res.ti.TypeId = typeBigVarChar
Expand Down Expand Up @@ -174,12 +176,12 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
res.ti.Size = len(res.buffer)
case civil.Date:
res.ti.TypeId = typeDateN
res.buffer = encodeDate(val.In(time.UTC))
res.buffer = encodeDate(val.In(loc))
res.ti.Size = len(res.buffer)
case civil.DateTime:
res.ti.TypeId = typeDateTime2N
res.ti.Scale = 7
res.buffer = encodeDateTime2(val.In(time.UTC), int(res.ti.Scale))
res.buffer = encodeDateTime2(val.In(loc), int(res.ti.Scale))
res.ti.Size = len(res.buffer)
case civil.Time:
res.ti.TypeId = typeTimeN
Expand Down
89 changes: 89 additions & 0 deletions queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2809,3 +2809,92 @@ func TestAdminConnection(t *testing.T) {
t.Fatalf("Tcp connection not made. Protocol: %s", protocol)
}
}

func TestCustomTimezone(t *testing.T) {

t.Run("without custom timezone", func(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()
_, err := conn.Exec("create table test (ts datetime)")
defer conn.Exec("drop table test")
if err != nil {
t.Fatal("create table failed with error", err)
}

inputTime := time.Date(2025, 5, 26, 15, 30, 0, 0, time.FixedZone("UTC+8", 8*60*60))
_, err = conn.Exec("insert into test (ts) values (@ts)", sql.Named("ts", inputTime.Format("2006-01-02 15:04:05")))
if err != nil {
t.Fatal("insert failed:", err)
}

var resultTime time.Time
err = conn.QueryRow("select ts from test").Scan(&resultTime)
if err != nil {
t.Fatal("QueryRow failed:", err)
}

if inputTime.Truncate(time.Second).Equal(resultTime) {
t.Errorf("Expected result time to differ from input time due to timezone loss,\ninput: %v\nresult: %v", inputTime, resultTime)
} else {
t.Logf("Input time and result time differ as expected:\ninput: %v\nresult: %v", inputTime, resultTime)
}
})

t.Run("with custom timezone", func(t *testing.T) {
t.Setenv("TIME_ZONE", "Asia/Shanghai") // UTC+8 timezone
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()
_, err := conn.Exec("create table test (ts datetime)")
defer conn.Exec("drop table test")
if err != nil {
t.Fatal("create table failed with error", err)
}

inputTime := time.Date(2025, 5, 26, 15, 30, 0, 0, time.FixedZone("UTC+8", 8*60*60))
_, err = conn.Exec("insert into test (ts) values (@ts)", sql.Named("ts", inputTime.Format("2006-01-02 15:04:05")))
if err != nil {
t.Fatal("insert failed:", err)
}

var resultTime time.Time
err = conn.QueryRow("select ts from test").Scan(&resultTime)
if err != nil {
t.Fatal("QueryRow failed:", err)
}

if !inputTime.Truncate(time.Second).Equal(resultTime) {
t.Errorf("Expected result time to match input time with custom timezone,\ninput: %v\nresult: %v", inputTime, resultTime)
}
})

t.Run("datetimeoffset with custom timezone", func(t *testing.T) {
t.Setenv("TIME_ZONE", "Asia/Shanghai")
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()
_, err := conn.Exec("create table test_offset (ts datetimeoffset)")
defer conn.Exec("drop table test_offset")
if err != nil {
t.Fatal(err)
}

inputTime := time.Date(2025, 5, 26, 15, 30, 0, 0, time.FixedZone("UTC+5", 5*60*60))
_, err = conn.Exec("insert into test_offset (ts) values (@ts)", sql.Named("ts", inputTime))
if err != nil {
t.Fatal(err)
}

var resultTime time.Time
err = conn.QueryRow("select ts from test_offset").Scan(&resultTime)
if err != nil {
t.Fatal(err)
}

if !inputTime.Equal(resultTime) {
t.Errorf("expected %v, got %v", inputTime, resultTime)
}
})

}
2 changes: 1 addition & 1 deletion rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
if err != nil {
return
}
err = param.ti.Writer(buf, param.ti, param.buffer)
err = param.ti.Writer(buf, param.ti, param.buffer, encoding)
if err != nil {
return
}
Expand Down
9 changes: 9 additions & 0 deletions tds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ func GetConnParams() (*msdsn.Config, error) {
Password: os.Getenv("SQLPASSWORD"),
LogFlags: logFlags,
Parameters: make(map[string]string),
Encoding: msdsn.EncodeParameters{
Timezone: time.UTC,
},
}
if c.Instance == "" {
c.Instance = os.Getenv("SQLINSTANCE")
Expand All @@ -332,6 +335,12 @@ func GetConnParams() (*msdsn.Config, error) {
if os.Getenv("COLUMNENCRYPTION") != "" {
c.ColumnEncryption = true
}
if os.Getenv("TIME_ZONE") != "" {
tz, err := time.LoadLocation(os.Getenv("TIME_ZONE"))
if err == nil {
c.Encoding.Timezone = tz
}
}
return c, nil
}
// try loading connection string from file
Expand Down
10 changes: 10 additions & 0 deletions timezone.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package mssql

import "time"

func getTimezone(c *Conn) *time.Location {
if c != nil && c.sess != nil {
return c.sess.encoding.GetTimezone()
}
return time.UTC
}
10 changes: 5 additions & 5 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ func readCekTableEntry(r *tdsBuffer) cekTableEntry {
// http://msdn.microsoft.com/en-us/library/dd357254.aspx
func parseRow(ctx context.Context, r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) error {
for i, column := range columns {
columnContent := column.ti.Reader(&column.ti, r, nil)
columnContent := column.ti.Reader(&column.ti, r, nil, s.encoding)
if columnContent == nil {
row[i] = columnContent
continue
Expand All @@ -792,7 +792,7 @@ func parseRow(ctx context.Context, r *tdsBuffer, s *tdsSession, columns []column
return err
}
// Decrypt
row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, buffer, column.cryptoMeta)
row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, buffer, column.cryptoMeta, s.encoding)
} else {
row[i] = columnContent
}
Expand Down Expand Up @@ -865,14 +865,14 @@ func parseNbcRow(ctx context.Context, r *tdsBuffer, s *tdsSession, columns []col
row[i] = nil
continue
}
columnContent := col.ti.Reader(&col.ti, r, nil)
columnContent := col.ti.Reader(&col.ti, r, nil, s.encoding)
if col.isEncrypted() {
buffer, err := decryptColumn(ctx, col, s, columnContent)
if err != nil {
return err
}
// Decrypt
row[i] = col.cryptoMeta.typeInfo.Reader(&col.cryptoMeta.typeInfo, buffer, col.cryptoMeta)
row[i] = col.cryptoMeta.typeInfo.Reader(&col.cryptoMeta.typeInfo, buffer, col.cryptoMeta, s.encoding)
} else {
row[i] = columnContent
}
Expand Down Expand Up @@ -933,7 +933,7 @@ func parseReturnValue(r *tdsBuffer, s *tdsSession) (nv namedValue) {
}

ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata, s.encoding)
nv.Value = ti2.Reader(&ti2, r, cryptoMetadata)
nv.Value = ti2.Reader(&ti2, r, cryptoMetadata, s.encoding)

return
}
Expand Down
2 changes: 1 addition & 1 deletion tvp_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
if err != nil {
return nil, fmt.Errorf("failed to make tvp parameter row col: %s", err)
}
columnStr[columnStrIdx].ti.Writer(buf, param.ti, param.buffer)
columnStr[columnStrIdx].ti.Writer(buf, param.ti, param.buffer, encoding)
}
}
buf.WriteByte(_TVP_END_TOKEN)
Expand Down
Loading