Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8818,6 +8818,14 @@ from typestable`,
{"6.141592653589793"},
},
},
{
Query: "select exp(i) from mytable;",
Expected: []sql.Row{
{math.Exp(1)},
{math.Exp(2)},
{math.Exp(3)},
},
},
}

var KeylessQueries = []QueryTest{
Expand Down
66 changes: 66 additions & 0 deletions sql/expression/function/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,3 +867,69 @@ func (p *Pi) Children() []sql.Expression {
func (p *Pi) WithChildren(children ...sql.Expression) (sql.Expression, error) {
return sql.NillaryWithChildren(p, children...)
}

type Exp struct {
*UnaryFunc
}

func NewExp(arg sql.Expression) sql.Expression {
return &Exp{NewUnaryFunc(arg, "EXP", types.Float64)}
}

var _ sql.FunctionExpression = (*Exp)(nil)
var _ sql.CollationCoercible = (*Exp)(nil)

// Description implements sql.FunctionExpression
func (e *Exp) Description() string {
return "returns e raised to the power of the argument given."
}

// Type implements the Expression interface.
func (e *Exp) Type() sql.Type {
return types.Float64
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (e *Exp) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 5
}

// Eval implements the Expression interface.
func (e *Exp) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
if e.Child == nil {
return nil, nil
}

val, err := e.Child.Eval(ctx, row)
if err != nil {
return nil, err
}

if val == nil {
return nil, err
}

v, _, err := types.Float64.Convert(val)
if err != nil {
// TODO: truncate
ctx.Warn(1292, "Truncated incorrect DOUBLE value: '%v'", val)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nil check maybe

v = 0.0
}

vv := v.(float64)
res := math.Exp(vv)

if math.IsNaN(res) || math.IsInf(res, 0) {
return nil, nil
}

return res, nil
}

// WithChildren implements the Expression interface.
func (e *Exp) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1)
}
return NewExp(children[0]), nil
}
100 changes: 100 additions & 0 deletions sql/expression/function/math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,103 @@ func TestPi(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, -1.0, res)
}

func TestExp(t *testing.T) {
tests := []struct {
name string
arg sql.Expression
exp interface{}
err bool
skip bool
}{
{
name: "null argument",
arg: nil,
exp: nil,
},
{
name: "zero",
arg: expression.NewLiteral(int64(0), types.Int64),
exp: math.Exp(0),
},
{
name: "one",
arg: expression.NewLiteral(int64(1), types.Int64),
exp: math.Exp(1),
},
{
name: "ten",
arg: expression.NewLiteral(int64(10), types.Int64),
exp: math.Exp(10),
},
{
name: "negative",
arg: expression.NewLiteral(int64(-1), types.Int64),
exp: math.Exp(-1),
},
{
name: "float64 1.1",
arg: expression.NewLiteral(1.1, types.Float64),
exp: math.Exp(1.1),
},
{
name: "decimal 1.1",
arg: expression.NewLiteral(decimal.NewFromFloat(1.1), types.DecimalType_{}),
exp: math.Exp(1.1),
},
{
name: "float64 -12.34",
arg: expression.NewLiteral(-12.34, types.Float64),
exp: math.Exp(-12.34),
},
{
name: "decimal is -12.34",
arg: expression.NewLiteral(decimal.NewFromFloat(-12.34), types.DecimalType_{}),
exp: math.Exp(-12.34),
},
{
name: "invalid string is 0",
arg: expression.NewLiteral("notanumber", types.Text),
exp: math.Exp(0),
},
{
name: "empty string",
arg: expression.NewLiteral("", types.Text),
exp: math.Exp(0),
},
{
name: "numerical string",
arg: expression.NewLiteral("10", types.Text),
exp: math.Exp(10),
},
{
// we don't do truncation yet
// https://github.com/dolthub/dolt/issues/7302
name: "scientific string is truncated",
arg: expression.NewLiteral("1e1", types.Text),
exp: "",
err: false,
skip: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skip {
t.Skip()
}

ctx := sql.NewEmptyContext()
f := NewExp(tt.arg)

res, err := f.Eval(ctx, nil)
if tt.err {
require.Error(t, err)
return
}

require.NoError(t, err)
require.Equal(t, tt.exp, res)
})
}
}
1 change: 1 addition & 0 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ var BuiltIns = []sql.Function{
sql.Function1{Name: "dayofweek", Fn: NewDayOfWeek},
sql.Function1{Name: "dayofyear", Fn: NewDayOfYear},
sql.Function1{Name: "degrees", Fn: NewDegrees},
sql.Function1{Name: "exp", Fn: NewExp},
sql.Function2{Name: "extract", Fn: NewExtract},
sql.Function2{Name: "find_in_set", Fn: NewFindInSet},
sql.Function1{Name: "first", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewFirst(e) }},
Expand Down