Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8892,6 +8892,18 @@ from typestable`,
{"abc"},
},
},
{
Query: "select count(distinct cast(i as decimal)) from mytable;",
Expected: []sql.Row{
{3},
},
},
{
Query: "select count(distinct null);",
Expected: []sql.Row{
{0},
},
},
}

var KeylessQueries = []QueryTest{
Expand Down
21 changes: 21 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -5167,6 +5167,27 @@ CREATE TABLE tab3 (
},
},
},
{
Name: "count distinct decimals",
SetUpScript: []string{
"create table t (i int, j int)",
"insert into t values (1, 11), (11, 1)",
},
Assertions: []ScriptTestAssertion{
{
Query: "select count(distinct i, j) from t;",
Expected: []sql.Row{
{2},
},
},
{
Query: "select count(distinct cast(i as decimal), cast(j as decimal)) from t;",
Expected: []sql.Row{
{2},
},
},
},
},
}

var SpatialScriptTests = []ScriptTest{
Expand Down
2 changes: 1 addition & 1 deletion sql/expression/function/aggregation/count_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestCountDistinctEvalStar(t *testing.T) {
require.NoError(b.Update(ctx, sql.NewRow(1)))
require.NoError(b.Update(ctx, sql.NewRow(nil)))
require.NoError(b.Update(ctx, sql.NewRow(1, 2, 3)))
require.Equal(int64(5), evalBuffer(t, b))
require.Equal(int64(4), evalBuffer(t, b))
}

func TestCountDistinctEvalString(t *testing.T) {
Expand Down
30 changes: 24 additions & 6 deletions sql/expression/function/aggregation/unary_agg_buffers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"fmt"
"reflect"

"github.com/mitchellh/hashstructure"
"github.com/cespare/xxhash/v2"
"github.com/shopspring/decimal"

"github.com/dolthub/go-mysql-server/sql"
Expand Down Expand Up @@ -402,7 +402,7 @@ func (c *countDistinctBuffer) Update(ctx *sql.Context, row sql.Row) error {
if _, ok := c.exprs[0].(*expression.Star); ok {
value = row
} else {
val := make([]interface{}, len(c.exprs))
val := make(sql.Row, len(c.exprs))
for i, expr := range c.exprs {
v, err := expr.Eval(ctx, row)
if err != nil {
Expand All @@ -417,12 +417,30 @@ func (c *countDistinctBuffer) Update(ctx *sql.Context, row sql.Row) error {
value = val
}

hash, err := hashstructure.Hash(value, nil)
if err != nil {
return fmt.Errorf("count distinct unable to hash value: %s", err)
var str string
for _, val := range value.(sql.Row) {
// skip nil values
if val == nil {
return nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test for this case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and/or comment confirming that any nil in a row won't contribute to distinct count

}
v, _, err := types.Text.Convert(val)
if err != nil {
return err
}
vv, ok := v.(string)
if !ok {
return fmt.Errorf("count distinct unable to hash value: %s", err)
}
str += vv + ","
}

c.seen[hash] = struct{}{}
hash := xxhash.New()
_, err := hash.WriteString(str)
if err != nil {
return err
}
h := hash.Sum64()
c.seen[h] = struct{}{}

return nil
}
Expand Down