Skip to content

Commit

Permalink
validate sql (#21)
Browse files Browse the repository at this point in the history
validate sql
  • Loading branch information
scottlepp authored Oct 21, 2024
1 parent 50cd057 commit 31cc2a2
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 26 deletions.
106 changes: 102 additions & 4 deletions duck/duckdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ import (
"fmt"
"os"
"os/exec"
"strings"

"github.com/grafana/grafana-plugin-sdk-go/backend/log"
sdk "github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana-plugin-sdk-go/data/framestruct"
"github.com/hairyhenderson/go-which"
"github.com/iancoleman/orderedmap"
"github.com/jeremywohl/flatten"
"github.com/scottlepp/go-duck/duck/data"
)

Expand Down Expand Up @@ -105,6 +107,10 @@ func (d *DuckDB) Query(query string) (string, error) {

// QueryFrames will load a dataframe into a view named RefID, and run the query against that view
func (d *DuckDB) QueryFrames(name string, query string, frames []*sdk.Frame) (string, bool, error) {
err := d.validate(query)
if err != nil {
return "", false, err
}
data := FrameData{
cacheDuration: d.cacheDuration,
cache: &d.cache,
Expand All @@ -123,15 +129,21 @@ func wipe(dirs map[string]string) {
}
}

func (d *DuckDB) QueryFramesInto(name string, query string, frames []*sdk.Frame, f *sdk.Frame) error {
func (d *DuckDB) QueryFramesToFrames(name string, query string, frames []*sdk.Frame) (*sdk.Frame, error) {
err := d.validate(query)
if err != nil {
return nil, err
}

f := &sdk.Frame{}
res, cached, err := d.QueryFrames(name, query, frames)
if err != nil {
return err
return nil, err
}

err = resultsToFrame(name, res, f, frames)
if err != nil {
return err
return nil, err
}
if cached {
for _, frame := range frames {
Expand All @@ -145,7 +157,7 @@ func (d *DuckDB) QueryFramesInto(name string, query string, frames []*sdk.Frame,
frame.Meta.Notices = append(frame.Meta.Notices, notice)
}
}
return nil
return f, nil
}

// Destroy will remove database files created by duckdb
Expand Down Expand Up @@ -294,3 +306,89 @@ func getTempDir() string {
}
return temp
}

// Keys and key suffixes related to the serialized SQL AST that validate inspects.
const (
// TABLE_NAME is the "table_name" key.
// NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
TABLE_NAME = "table_name"
// ERROR is the key suffix validate looks for in the flattened AST to detect error flags.
ERROR = ".error"
// ERROR_MESSAGE is the ".error_message" key suffix.
// NOTE(review): not referenced in the visible part of this file — presumably the companion of ERROR; confirm.
ERROR_MESSAGE = ".error_message"
)

// validate checks rawSQL before it is executed against user data.
//
// The statement is passed to DuckDB's json_serialize_sql, the JSON result is
// decoded, and the resulting AST is flattened to dotted keys. The query is
// rejected when:
//   - the command fails to run or its output is not the expected JSON,
//   - no AST is returned,
//   - the AST carries an error flag (top-level or nested),
//   - the AST has no statements,
//   - a FROM clause uses a table function (for example read_csv or read_json),
//   - a FROM clause references a table name containing "." (for example a
//     file such as 'test.parquet').
//
// It returns nil when the query is acceptable and a descriptive error otherwise.
func (d *DuckDB) validate(rawSQL string) error {
	// Double the single quotes so the statement can be embedded in a SQL string literal.
	rawSQL = strings.ReplaceAll(rawSQL, "'", "''")
	cmd := fmt.Sprintf("SELECT json_serialize_sql('%s')", rawSQL)
	ret, err := d.RunCommands([]string{cmd})
	if err != nil {
		logger.Error("error validating sql", "error", err.Error(), "sql", rawSQL, "cmd", cmd)
		return fmt.Errorf("error validating sql: %w", err)
	}

	var result []map[string]any
	if err := json.Unmarshal([]byte(ret), &result); err != nil {
		logger.Error("error converting json sql to ast", "error", err.Error(), "ret", ret)
		return fmt.Errorf("error converting json to ast: %w", err)
	}

	// Without this early return, indexing result[0] below would panic on an empty result.
	if len(result) == 0 {
		logger.Error("no ast returned", "ret", ret)
		return fmt.Errorf("no ast returned: %s", ret)
	}

	// The row holds a single column whose value is the AST object; take the first value found.
	var ast map[string]any
	for _, v := range result[0] {
		validAst, ok := v.(map[string]any)
		if !ok {
			logger.Error("invalid sql", "sql", ret)
			return fmt.Errorf("invalid sql: %s", ret)
		}
		ast = validAst
		break
	}

	// A top-level "error" entry must be a boolean false; anything else is treated as a failure.
	if errMsg := ast["error"]; errMsg != nil {
		if failed, ok := errMsg.(bool); !ok || failed {
			logger.Error("error in ast", "error", ret)
			return fmt.Errorf("error in ast: %v", ret)
		}
	}

	if ast["statements"] == nil {
		logger.Error("no statements in ast", "ast", ast)
		return fmt.Errorf("no statements in ast: %v", ast)
	}

	flat, err := flatten.Flatten(ast, "", flatten.DotStyle)
	if err != nil {
		logger.Error("error flattening ast", "error", err.Error(), "ast", ast)
		return fmt.Errorf("error flattening ast: %w", err)
	}

	for k, v := range flat {
		// Any nested "<path>.error" flag set to true marks the statement as invalid.
		if strings.HasSuffix(k, ERROR) {
			if failed, ok := v.(bool); ok && failed {
				logger.Error("error in sql", "error", k)
				return fmt.Errorf("error in sql: %s", k)
			}
		}
		// Table functions in FROM are rejected outright.
		if strings.Contains(k, "from_table.function.function_name") {
			logger.Error("function not allowed", "function", v)
			return fmt.Errorf("function not allowed: %v", v)
		}
		// A dotted table name is rejected (e.g. FROM 'test.parquet').
		if strings.HasSuffix(k, "from_table.table_name") {
			if table, ok := v.(string); ok && strings.Contains(table, ".") {
				logger.Error("table names with . not allowed", "table", table)
				return fmt.Errorf("table names with . not allowed: %s", table)
			}
		}
	}

	return nil
}
107 changes: 85 additions & 22 deletions duck/duckdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ func TestCommands(t *testing.T) {
assert.Contains(t, res, `[{"i":1,"j":5}]`)
}

// TestDotCommands runs the ".databases" dot command through RunCommands and
// expects the output to contain "memory".
func TestDotCommands(t *testing.T) {
	db := NewInMemoryDB()

	out, runErr := db.RunCommands([]string{".databases"})
	if runErr != nil {
		t.Fail()
		return
	}

	assert.Contains(t, out, `memory`)
}

func TestCommandsDocker(t *testing.T) {
db := NewInMemoryDB(Opts{Docker: true})

Expand Down Expand Up @@ -74,6 +88,66 @@ func TestQueryFrame(t *testing.T) {
assert.Contains(t, res, `[{"value":"test"}]`)
}

// TestQueryAgg checks that an aggregate query (min) over a single frame is
// accepted by QueryFrames and yields the expected JSON row.
func TestQueryAgg(t *testing.T) {
	db := NewInMemoryDB()

	const refID = "foo"
	input := data.NewFrame(refID, data.NewField("value", nil, []string{"test"}))
	input.RefID = refID

	out, _, err := db.QueryFrames(refID, "select min(value) as value from foo", []*data.Frame{input})

	assert.Nil(t, err)
	assert.Contains(t, out, `[{"value":"test"}]`)
}

// TestQueryJson expects QueryFrames to return an error for a query that reads
// from the read_json table function.
func TestQueryJson(t *testing.T) {
	db := NewInMemoryDB()

	const refID = "foo"
	input := data.NewFrame(refID, data.NewField("value", nil, []string{"test"}))
	input.RefID = refID

	_, _, queryErr := db.QueryFrames(refID, "SELECT * FROM read_json('todos.json')", []*data.Frame{input})

	assert.NotNil(t, queryErr)
}

// TestValid expects QueryFrames to return an error when the query is a dot
// command (".databases") followed by a newline instead of a SQL statement.
func TestValid(t *testing.T) {
	db := NewInMemoryDB()

	const refID = "foo"
	input := data.NewFrame(refID, data.NewField("value", nil, []string{"test"}))
	input.RefID = refID

	dotCommand := fmt.Sprintf(".databases %s", newline)
	_, _, queryErr := db.QueryFrames(refID, dotCommand, []*data.Frame{input})

	assert.NotNil(t, queryErr)
}

// TestQueryFrameNoFileRead expects QueryFrames to return an error for every
// query that tries to read from a file: table functions, a quoted file name
// used as a table, and COPY ... FROM.
func TestQueryFrameNoFileRead(t *testing.T) {
	db := NewInMemoryDB()

	const refID = "foo"
	input := data.NewFrame(refID, data.NewField("value", nil, []string{"test"}))
	input.RefID = refID
	frames := []*data.Frame{input}

	fileQueries := []string{
		"SELECT * FROM read_csv('flights.csv')",
		"SELECT * FROM read_json('flights.json')",
		"SELECT * FROM 'test.parquet'",
		"COPY test FROM 'test.parquet'",
	}

	// assert.NotNil does not stop the test, so every query is exercised in order.
	for _, q := range fileQueries {
		_, _, queryErr := db.QueryFrames(refID, q, frames)
		assert.NotNil(t, queryErr)
	}
}

func TestQueryFrameCache(t *testing.T) {
opts := Opts{
CacheDuration: 5,
Expand Down Expand Up @@ -153,8 +227,7 @@ func TestQueryFrameIntoFrame(t *testing.T) {

frames := []*data.Frame{frame, frame2}

model := &data.Frame{}
err := db.QueryFramesInto("foo", "select * from foo order by value desc", frames, model)
model, err := db.QueryFramesToFrames("foo", "select * from foo order by value desc", frames)
assert.Nil(t, err)

assert.Equal(t, 2, model.Rows())
Expand All @@ -178,8 +251,7 @@ func TestQueryFrameIntoFrameDocker(t *testing.T) {

frames := []*data.Frame{frame, frame2}

model := &data.Frame{}
err := db.QueryFramesInto("foo", "select * from foo order by value desc", frames, model)
model, err := db.QueryFramesToFrames("foo", "select * from foo order by value desc", frames)
assert.Nil(t, err)

assert.Equal(t, 2, model.Rows())
Expand All @@ -203,8 +275,7 @@ func TestQueryFrameIntoFrameMultipleColumns(t *testing.T) {

frames := []*data.Frame{frame}

model := &data.Frame{}
err := db.QueryFramesInto("B", "select * from A", frames, model)
model, err := db.QueryFramesToFrames("B", "select * from A", frames)
assert.Nil(t, err)

assert.Equal(t, "Z State", model.Fields[0].Name)
Expand All @@ -230,8 +301,7 @@ func TestMultiFrame(t *testing.T) {

frames := []*data.Frame{frame, frame2}

model := &data.Frame{}
err := db.QueryFramesInto("foo", "select * from foo", frames, model)
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
assert.Nil(t, err)

assert.Equal(t, 2, model.Rows())
Expand All @@ -257,8 +327,7 @@ func TestMultiFrame2(t *testing.T) {

frames := []*data.Frame{frame, frame2}

model := &data.Frame{}
err := db.QueryFramesInto("foo", "select * from foo", frames, model)
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
assert.Nil(t, err)

assert.Equal(t, 2, model.Rows())
Expand All @@ -281,8 +350,7 @@ func TestTimestamps(t *testing.T) {

frames := []*data.Frame{frame}

model := &data.Frame{}
err = db.QueryFramesInto("foo", "select * from foo", frames, model)
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
assert.Nil(t, err)

assert.Equal(t, 1, model.Rows())
Expand Down Expand Up @@ -311,8 +379,7 @@ func TestTimeSeries(t *testing.T) {

frames := []*data.Frame{frame}

model := &data.Frame{}
err = db.QueryFramesInto("foo", "select * from foo", frames, model)
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
assert.Nil(t, err)

assert.Equal(t, data.FrameTypeTimeSeriesWide, model.Meta.Type)
Expand Down Expand Up @@ -341,8 +408,7 @@ func TestTimeSeriesWide(t *testing.T) {

frames := []*data.Frame{frame}

model := &data.Frame{}
err = db.QueryFramesInto("foo", "select * from foo", frames, model)
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
assert.Nil(t, err)

assert.Equal(t, data.FrameTypeTimeSeriesWide, model.Meta.Type)
Expand Down Expand Up @@ -377,8 +443,7 @@ func TestLabels(t *testing.T) {

frames := []*data.Frame{frame, frame2}

model := &data.Frame{}
err := db.QueryFramesInto("foo", "select * from foo", frames, model)
model, err := db.QueryFramesToFrames("foo", "select * from foo", frames)
assert.Nil(t, err)

assert.Equal(t, 2, model.Rows())
Expand Down Expand Up @@ -427,8 +492,7 @@ func TestLabelsMultiFrame(t *testing.T) {
frames := []*data.Frame{frame, frame2}

// TODO - ordering is broken!
model := &data.Frame{}
err = db.QueryFramesInto("foo", "select * from foo order by timestamp desc", frames, model)
model, err := db.QueryFramesToFrames("foo", "select * from foo order by timestamp desc", frames)
assert.Nil(t, err)

assert.Equal(t, 4, model.Rows())
Expand Down Expand Up @@ -459,8 +523,7 @@ func TestTimeSeriesAggregate(t *testing.T) {

frames := []*data.Frame{frame}

model := &data.Frame{}
err = db.QueryFramesInto("foo", "select CURRENT_TIMESTAMP, min(time) as t, 1 as j from foo group by category", frames, model)
model, err := db.QueryFramesToFrames("foo", "select min(time) as t, 1 as j from foo group by category", frames)
assert.Nil(t, err)

assert.Equal(t, data.FrameTypeTimeSeriesWide, model.Meta.Type)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ require (
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/hashicorp/go-hclog v1.6.3 // indirect
github.com/jeremywohl/flatten v1.0.1
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/asmfmt v1.3.2 // indirect
github.com/klauspost/compress v1.17.9 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ github.com/iancoleman/orderedmap v0.3.0 h1:5cbR2grmZR/DiVt+VJopEhtVs9YGInGIxAoMJ
github.com/iancoleman/orderedmap v0.3.0/go.mod h1:XuLcCUkdL5owUCQeF2Ue9uuw1EptkJDkXXS7VoV7XGE=
github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY=
github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q=
github.com/jeremywohl/flatten v1.0.1 h1:LrsxmB3hfwJuE+ptGOijix1PIfOoKLJ3Uee/mzbgtrs=
github.com/jeremywohl/flatten v1.0.1/go.mod h1:4AmD/VxjWcI5SRB0n6szE2A6s2fsNHDLO0nAlMHgfLQ=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
Expand Down

0 comments on commit 31cc2a2

Please sign in to comment.