diff --git a/internal/pkg/store/sql/database.go b/internal/pkg/store/sql/database.go index 1ce30deed..74c04a14d 100644 --- a/internal/pkg/store/sql/database.go +++ b/internal/pkg/store/sql/database.go @@ -1,4 +1,4 @@ -// Copyright 2022-2022 EMQ Technologies Co., Ltd. +// Copyright 2022-2024 EMQ Technologies Co., Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,8 +14,31 @@ package sql -import "database/sql" +import ( + "database/sql" + "regexp" +) type Database interface { Apply(f func(db *sql.DB) error) error } + +// isValidTableName checks if the given string is a valid database table name. +func isValidTableName(tableName string) bool { + // Check if the table name is empty + if tableName == "" { + return false + } + + // Regular expression to match valid table names + // ^[a-zA-Z_][a-zA-Z0-9_]*$ + // ^[a-zA-Z_] ensures the name starts with a letter or underscore + // [a-zA-Z0-9_]*$ ensures the rest of the name consists of letters, digits, or underscores + validTableNamePattern := `^[a-zA-Z_][a-zA-Z0-9/_]*$` + + // Compile the regular expression + re := regexp.MustCompile(validTableNamePattern) + + // Check if the table name matches the pattern + return re.MatchString(tableName) +} diff --git a/internal/pkg/store/sql/sqlKv.go b/internal/pkg/store/sql/sqlKv.go index 38c4be5e8..3659f9fbc 100644 --- a/internal/pkg/store/sql/sqlKv.go +++ b/internal/pkg/store/sql/sqlKv.go @@ -1,4 +1,4 @@ -// Copyright 2021-2023 EMQ Technologies Co., Ltd. +// Copyright 2021-2024 EMQ Technologies Co., Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -31,6 +31,9 @@ type sqlKvStore struct { } func createSqlKvStore(database Database, table string) (*sqlKvStore, error) { + if !isValidTableName(table) { + return nil, fmt.Errorf("invalid table name: %s", table) + } store := &sqlKvStore{ database: database, table: table, @@ -157,12 +160,12 @@ func (kv *sqlKvStore) Delete(key string) error { if nil != err || 0 == len(tmp) { return errorx.NewWithCode(errorx.NOT_FOUND, fmt.Sprintf("%s is not found", key)) } - query = fmt.Sprintf("DELETE FROM '%s' WHERE key='%s';", kv.table, key) + query = fmt.Sprintf("DELETE FROM '%s' WHERE key=?;", kv.table) stmt, err = db.Prepare(query) if err != nil { return err } - _, err = stmt.Exec() + _, err = stmt.Exec(key) return err }) } diff --git a/internal/pkg/store/sql/sqlKv_test.go b/internal/pkg/store/sql/sqlKv_test.go index cc57257ee..d698c9057 100644 --- a/internal/pkg/store/sql/sqlKv_test.go +++ b/internal/pkg/store/sql/sqlKv_test.go @@ -1,4 +1,4 @@ -// Copyright 2021-2022 EMQ Technologies Co., Ltd. +// Copyright 2021-2024 EMQ Technologies Co., Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,6 +20,9 @@ import ( "path/filepath" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/lf-edge/ekuiper/v2/internal/pkg/store/definition" "github.com/lf-edge/ekuiper/v2/internal/pkg/store/sql/sqlite" "github.com/lf-edge/ekuiper/v2/internal/pkg/store/test/common" @@ -78,6 +81,27 @@ func TestSqlKvGetKeyedState(t *testing.T) { common.TestKvGetKeyedState(ks, t) } +func TestInvalidTableName(t *testing.T) { + absPath, err := filepath.Abs("test") + require.NoError(t, err) + err = deleteIfExists(absPath) + assert.NoError(t, err) + config := definition.Config{ + Type: "sqlite", + Redis: definition.RedisConfig{}, + Sqlite: definition.SqliteConfig{ + Path: absPath, + Name: SDbName, + }, + } + db, _ := sqlite.NewSqliteDatabase(config, "sqliteKV.db") + err = db.Connect() + require.NoError(t, err) + builder := NewStoreBuilder(db.(Database)) + _, err = builder.CreateStore("1_abc") + require.EqualError(t, err, "invalid table name: 1_abc") +} + func deleteIfExists(abs string) error { absPath := path.Join(abs, SDbName) if f, _ := os.Stat(absPath); f != nil { diff --git a/internal/pkg/store/sql/sqlTs.go b/internal/pkg/store/sql/sqlTs.go index e963b25ad..8abbf724d 100644 --- a/internal/pkg/store/sql/sqlTs.go +++ b/internal/pkg/store/sql/sqlTs.go @@ -1,4 +1,4 @@ -// Copyright 2022-2023 EMQ Technologies Co., Ltd. +// Copyright 2022-2024 EMQ Technologies Co., Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -30,6 +30,9 @@ type ts struct { } func createSqlTs(database Database, table string) (*ts, error) { + if !isValidTableName(table) { + return nil, fmt.Errorf("invalid table name: %s", table) + } store := &ts{ database: database, table: table, diff --git a/internal/pkg/store/sql/sqlTs_test.go b/internal/pkg/store/sql/sqlTs_test.go index 0d8e7addd..a77afccd5 100644 --- a/internal/pkg/store/sql/sqlTs_test.go +++ b/internal/pkg/store/sql/sqlTs_test.go @@ -1,4 +1,4 @@ -// Copyright 2021-2022 EMQ Technologies Co., Ltd. +// Copyright 2021-2024 EMQ Technologies Co., Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,6 +20,9 @@ import ( "path/filepath" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/lf-edge/ekuiper/v2/internal/pkg/store/definition" "github.com/lf-edge/ekuiper/v2/internal/pkg/store/sql/sqlite" "github.com/lf-edge/ekuiper/v2/internal/pkg/store/test/common" @@ -108,6 +111,26 @@ func setupTSqlKv() (ts2.Tskv, definition.Database, string) { return store, db, absPath } +func TestInvalidTsTableName(t *testing.T) { + absPath, err := filepath.Abs("test") + require.NoError(t, err) + err = deleteIfExists(absPath) + assert.NoError(t, err) + config := definition.Config{ + Type: "sqlite", + Sqlite: definition.SqliteConfig{ + Path: absPath, + Name: TDbName, + }, + } + db, _ := sqlite.NewSqliteDatabase(config, "sqliteKV.db") + err = db.Connect() + require.NoError(t, err) + builder := NewTsBuilder(db.(Database)) + _, err = builder.CreateTs("1_abc") + require.EqualError(t, err, "invalid table name: 1_abc") +} + func cleanTSqlKv(db definition.Database, abs string) { if err := db.Disconnect(); err != nil { panic(err)