Skip to content

* Added validation of WithTxControl option in non-interactive methods… #1738

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Added validation for the WithTxControl option in the non-interactive methods of Client and Session in the query service.

## v3.108.1
* Supported `json.Marshaller` query parameter in `database/sql` driver

Expand Down
2 changes: 1 addition & 1 deletion examples/basic/native/query/series.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func read(ctx context.Context, c query.Client, prefix string) error {
FROM
%s
`, "`"+path.Join(prefix, "series")+"`"),
query.WithTxControl(query.TxControl(query.BeginTx(query.WithSnapshotReadOnly()))),
query.WithTxControl(query.SnapshotReadOnlyTxControl()),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pay attention to the line. I have changed the txcontrol in the example.

)
if err != nil {
return err
Expand Down
35 changes: 35 additions & 0 deletions internal/query/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package query

import (
"context"
"errors"
"time"

"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
Expand All @@ -20,6 +21,7 @@ import (
"github.com/ydb-platform/ydb-go-sdk/v3/internal/types"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xreflect"
"github.com/ydb-platform/ydb-go-sdk/v3/query"
"github.com/ydb-platform/ydb-go-sdk/v3/retry"
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
Expand All @@ -32,6 +34,8 @@ var (
_ sessionPool = (*pool.Pool[*Session, Session])(nil)
)

var errNoCommit = xerrors.Wrap(errors.New("WithTxControl option is not allowed without CommitTx() option in Client methods, as these methods are non-interactive. You can either add the CommitTx() option to TxControl or use query.*TxControl methods (e.g., query.SnapshotReadOnlyTxControl) which already include the commit flag")) //nolint:lll
Copy link
Preview

Copilot AI May 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider splitting the error message or adding inline comments to clarify the error rationale for future maintainers.

Suggested change
var errNoCommit = xerrors.Wrap(errors.New("WithTxControl option is not allowed without CommitTx() option in Client methods, as these methods are non-interactive. You can either add the CommitTx() option to TxControl or use query.*TxControl methods (e.g., query.SnapshotReadOnlyTxControl) which already include the commit flag")) //nolint:lll
// errNoCommit is raised when the WithTxControl option is used without the CommitTx() option.
// This is because Client methods are non-interactive and require explicit commit control.
// To resolve this, either add the CommitTx() option to TxControl or use query.*TxControl methods
// (e.g., query.SnapshotReadOnlyTxControl) that already include the commit flag.
var errNoCommit = xerrors.Wrap(errors.New("WithTxControl option is not allowed without CommitTx() option in Client methods."))
var errNoCommitResolution = xerrors.Wrap(errors.New("You can either add the CommitTx() option to TxControl or use query.*TxControl methods (e.g., query.SnapshotReadOnlyTxControl) which already include the commit flag")) //nolint:lll

Copilot uses AI. Check for mistakes.


type (
sessionPool interface {
closer.Closer
Expand Down Expand Up @@ -173,6 +177,10 @@ func (c *Client) ExecuteScript(
),
}

if err := checkTxControlWithCommit(settings.TxControl()); err != nil {
return nil, err
}

request, grpcOpts, err := executeQueryScriptRequest(q, settings)
if err != nil {
return op, xerrors.WithStackTrace(err)
Expand Down Expand Up @@ -320,6 +328,10 @@ func (c *Client) QueryRow(ctx context.Context, q string, opts ...options.Execute

settings := options.ExecuteSettings(opts...)

if err := checkTxControlWithCommit(settings.TxControl()); err != nil {
return nil, err
}

onDone := trace.QueryOnQueryRow(c.config.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Client).QueryRow"),
q, settings.Label(),
Expand Down Expand Up @@ -366,6 +378,11 @@ func (c *Client) Exec(ctx context.Context, q string, opts ...options.Execute) (f
defer cancel()

settings := options.ExecuteSettings(opts...)

if err := checkTxControlWithCommit(settings.TxControl()); err != nil {
return err
}

onDone := trace.QueryOnExec(c.config.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Client).Exec"),
q,
Expand Down Expand Up @@ -415,6 +432,11 @@ func (c *Client) Query(ctx context.Context, q string, opts ...options.Execute) (
defer cancel()

settings := options.ExecuteSettings(opts...)

if err := checkTxControlWithCommit(settings.TxControl()); err != nil {
return nil, err
}

onDone := trace.QueryOnQuery(c.config.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Client).Query"),
q, settings.Label(),
Expand Down Expand Up @@ -470,6 +492,10 @@ func (c *Client) QueryResultSet(
err error
)

if err := checkTxControlWithCommit(settings.TxControl()); err != nil {
return nil, err
}

onDone := trace.QueryOnQueryResultSet(c.config.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Client).QueryResultSet"),
q, settings.Label(),
Expand Down Expand Up @@ -612,6 +638,15 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, cfg *config.Config) *
}
}

// checkTxControlWithCommit validates the transaction control object to ensure it includes a commit flag.
func checkTxControlWithCommit(txControl options.TxControl) error {
if !xreflect.IsContainsNilPointer(txControl) && !txControl.Commit() {
return xerrors.WithStackTrace(errNoCommit)
}

return nil
}

func poolTrace(t *trace.Query) *pool.Trace {
return &pool.Trace{
OnNew: func(ctx *context.Context, call stack.Caller) func(limit int) {
Expand Down
24 changes: 22 additions & 2 deletions internal/query/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ func (s *Session) QueryResultSet(
onDone(finalErr)
}()

r, err := s.execute(ctx, q, options.ExecuteSettings(opts...), withTrace(s.trace))
settings := options.ExecuteSettings(opts...)
if err := checkTxControlWithCommit(settings.TxControl()); err != nil {
return nil, err
}

r, err := s.execute(ctx, q, settings, withTrace(s.trace))
if err != nil {
return nil, xerrors.WithStackTrace(err)
}
Expand Down Expand Up @@ -75,7 +80,12 @@ func (s *Session) QueryRow(ctx context.Context, q string, opts ...options.Execut
onDone(finalErr)
}()

row, err := s.queryRow(ctx, q, options.ExecuteSettings(opts...), withTrace(s.trace))
settings := options.ExecuteSettings(opts...)
if err := checkTxControlWithCommit(settings.TxControl()); err != nil {
return nil, err
}

row, err := s.queryRow(ctx, q, settings, withTrace(s.trace))
if err != nil {
return nil, xerrors.WithStackTrace(err)
}
Expand Down Expand Up @@ -154,6 +164,11 @@ func (s *Session) execute(

func (s *Session) Exec(ctx context.Context, q string, opts ...options.Execute) (finalErr error) {
settings := options.ExecuteSettings(opts...)

if err := checkTxControlWithCommit(settings.TxControl()); err != nil {
return err
}

onDone := trace.QueryOnSessionExec(s.trace, &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Session).Exec"),
s,
Expand Down Expand Up @@ -182,6 +197,11 @@ func (s *Session) Exec(ctx context.Context, q string, opts ...options.Execute) (

func (s *Session) Query(ctx context.Context, q string, opts ...options.Execute) (_ query.Result, finalErr error) {
settings := options.ExecuteSettings(opts...)

if err := checkTxControlWithCommit(settings.TxControl()); err != nil {
return nil, err
}

onDone := trace.QueryOnSessionQuery(s.trace, &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Session).Query"),
s,
Expand Down
32 changes: 32 additions & 0 deletions internal/xreflect/is_nil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package xreflect

import "reflect"

func IsContainsNilPointer(v any) bool {
if v == nil {
return true
}

rVal := reflect.ValueOf(v)

return isValPointToNil(rVal)
}

func isValPointToNil(v reflect.Value) bool {
kind := v.Kind()
var res bool
switch kind {
case reflect.Slice:
return false
case reflect.Chan, reflect.Func, reflect.Map, reflect.UnsafePointer:
res = v.IsNil()
case reflect.Pointer, reflect.Interface:
elem := v.Elem()
if v.IsNil() {
return true
}
res = isValPointToNil(elem)
}

return res
}
92 changes: 92 additions & 0 deletions internal/xreflect/is_nil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package xreflect

import (
"testing"
)

func TestIsContainsNilPointer(t *testing.T) {
var nilIntPointer *int
vInterface := nilIntPointer

// Test cases for different nil and non-nil scenarios
tests := []struct {
name string
input any
expected bool
}{
{
name: "nil interface",
input: nil,
expected: true,
},
{
name: "nil pointer to int",
input: (*int)(nil),
expected: true,
},
{
name: "non-nil pointer to int",
input: new(int),
expected: false,
},
{
name: "nil slice",
input: []int(nil),
expected: false,
},
{
name: "empty slice",
input: []int{},
expected: false,
},
{
name: "nil map",
input: map[string]int(nil),
expected: true,
},
{
name: "empty map",
input: map[string]int{},
expected: false,
},
{
name: "nil channel",
input: (chan int)(nil),
expected: true,
},
{
name: "non-nil channel",
input: make(chan int),
expected: false,
},
{
name: "nil function",
input: (func())(nil),
expected: true,
},
{
name: "nested nil pointer",
input: &nilIntPointer,
expected: true,
},
{
name: "interface with stored nil pointer",
input: vInterface,
expected: true,
},
{
name: "non-nil interface value",
input: interface{}("test"),
expected: false,
},
}

// Execute all test cases
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsContainsNilPointer(tt.input); got != tt.expected {
t.Errorf("IsContainsNilPointer() = %v, want %v", got, tt.expected)
}
})
}
}
8 changes: 4 additions & 4 deletions tests/integration/database_sql_with_tx_control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ func TestDatabaseSqlWithTxControl(t *testing.T) {
ydb.WithTxControl(
tx.WithTxControlHook(ctx, func(txControl *tx.Control) {
hookCalled = true
require.Equal(t, tx.SerializableReadWriteTxControl(), txControl)
require.Equal(t, tx.SerializableReadWriteTxControl(tx.CommitTx()), txControl)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added an explicit CommitTx to the test.

}),
tx.SerializableReadWriteTxControl(),
tx.SerializableReadWriteTxControl(tx.CommitTx()),
),
db, func(ctx context.Context, cc *sql.Conn) error {
_, err := db.QueryContext(ctx, "SELECT 1")
Expand All @@ -56,9 +56,9 @@ func TestDatabaseSqlWithTxControl(t *testing.T) {
ydb.WithTxControl(
tx.WithTxControlHook(ctx, func(txControl *tx.Control) {
hookCalled = true
require.Equal(t, tx.SerializableReadWriteTxControl(), txControl)
require.Equal(t, tx.SerializableReadWriteTxControl(tx.CommitTx()), txControl)
}),
tx.SerializableReadWriteTxControl(),
tx.SerializableReadWriteTxControl(tx.CommitTx()),
),
db, func(ctx context.Context, cc *sql.Conn) error {
_, err := db.QueryContext(ctx, "SELECT 1")
Expand Down
18 changes: 9 additions & 9 deletions tests/integration/query_regression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ DECLARE $val AS UUID;
SELECT CAST($val AS Utf8)`,
query.WithIdempotent(),
query.WithParameters(ydb.ParamsBuilder().Param("$val").UUIDWithIssue1501Value(id).Build()),
query.WithTxControl(tx.SerializableReadWriteTxControl()),
query.WithTxControl(tx.SnapshotReadOnlyTxControl()),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pay attention to the changes. I have changed txControl in tests

)

require.NoError(t, err)
Expand All @@ -71,7 +71,7 @@ DECLARE $val AS Text;
SELECT CAST($val AS UUID)`,
query.WithIdempotent(),
query.WithParameters(ydb.ParamsBuilder().Param("$val").Text(idString).Build()),
query.WithTxControl(tx.SerializableReadWriteTxControl()),
query.WithTxControl(tx.SnapshotReadOnlyTxControl()),
)

require.NoError(t, err)
Expand All @@ -97,7 +97,7 @@ DECLARE $val AS Text;
SELECT CAST($val AS UUID)`,
query.WithIdempotent(),
query.WithParameters(ydb.ParamsBuilder().Param("$val").Text(idString).Build()),
query.WithTxControl(tx.SerializableReadWriteTxControl()),
query.WithTxControl(tx.SnapshotReadOnlyTxControl()),
)

require.NoError(t, err)
Expand Down Expand Up @@ -125,7 +125,7 @@ DECLARE $val AS Text;
SELECT CAST($val AS UUID)`,
query.WithIdempotent(),
query.WithParameters(ydb.ParamsBuilder().Param("$val").Text(idString).Build()),
query.WithTxControl(tx.SerializableReadWriteTxControl()),
query.WithTxControl(tx.SnapshotReadOnlyTxControl()),
)

require.NoError(t, err)
Expand All @@ -151,7 +151,7 @@ DECLARE $val AS Text;
SELECT CAST($val AS UUID)`,
query.WithIdempotent(),
query.WithParameters(ydb.ParamsBuilder().Param("$val").Text(idString).Build()),
query.WithTxControl(tx.SerializableReadWriteTxControl()),
query.WithTxControl(tx.SnapshotReadOnlyTxControl()),
)

require.NoError(t, err)
Expand Down Expand Up @@ -180,7 +180,7 @@ DECLARE $val AS UUID;
SELECT $val`,
query.WithIdempotent(),
query.WithParameters(ydb.ParamsBuilder().Param("$val").UUIDWithIssue1501Value(id).Build()),
query.WithTxControl(tx.SerializableReadWriteTxControl()),
query.WithTxControl(tx.SnapshotReadOnlyTxControl()),
)

require.NoError(t, err)
Expand All @@ -207,7 +207,7 @@ DECLARE $val AS UUID;

SELECT CAST($val AS Utf8)`,
query.WithIdempotent(),
query.WithTxControl(query.SerializableReadWriteTxControl()),
query.WithTxControl(query.SnapshotReadOnlyTxControl()),
query.WithParameters(ydb.ParamsBuilder().Param("$val").Uuid(id).Build()),
)

Expand All @@ -233,7 +233,7 @@ DECLARE $val AS Utf8;
SELECT CAST($val AS UUID)`,
query.WithIdempotent(),
query.WithParameters(ydb.ParamsBuilder().Param("$val").Text(idString).Build()),
query.WithTxControl(query.SerializableReadWriteTxControl()),
query.WithTxControl(query.SnapshotReadOnlyTxControl()),
)

require.NoError(t, err)
Expand Down Expand Up @@ -261,7 +261,7 @@ DECLARE $val AS UUID;
SELECT $val`,
query.WithIdempotent(),
query.WithParameters(ydb.ParamsBuilder().Param("$val").Uuid(id).Build()),
query.WithTxControl(query.SerializableReadWriteTxControl()),
query.WithTxControl(query.SnapshotReadOnlyTxControl()),
)

require.NoError(t, err)
Expand Down
Loading