Skip to content

implement startup script #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions cmd/sqlcmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,28 +283,43 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
}
}
}
once := false
if args.InitialQuery != "" {
s.Query = args.InitialQuery
} else if args.Query != "" {
once = true
s.Query = args.Query
}

// connect using no overrides
err = s.ConnectDb(nil, line == nil)
if err != nil {
return 1, err
}

iactive := args.InputFile == nil && args.Query == ""
if iactive || s.Query != "" {
err = s.Run(once, false)
} else {
for f := range args.InputFile {
if err = s.IncludeFile(args.InputFile[f], true); err != nil {
s.WriteError(s.GetError(), err)
s.Exitcode = 1
break
script := vars.StartupScriptFile()
if !args.DisableCmdAndWarn && len(script) > 0 {
f, fileErr := os.Open(script)
if fileErr != nil {
s.WriteError(s.GetError(), sqlcmd.InvalidVariableValue(sqlcmd.SQLCMDINI, script))
} else {
_ = f.Close()
// IncludeFile won't return an error for a SQL error, but ExitCode will be non-zero if -b was passed on the commandline
err = s.IncludeFile(script, true)
}
}

if err == nil && s.Exitcode == 0 {
once := false
if args.InitialQuery != "" {
s.Query = args.InitialQuery
} else if args.Query != "" {
once = true
s.Query = args.Query
}
iactive := args.InputFile == nil && args.Query == ""
if iactive || s.Query != "" {
err = s.Run(once, false)
} else {
for f := range args.InputFile {
if err = s.IncludeFile(args.InputFile[f], true); err != nil {
s.WriteError(s.GetError(), err)
s.Exitcode = 1
break
}
}
}
}
Expand Down
24 changes: 24 additions & 0 deletions cmd/sqlcmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,30 @@ func TestConditionsForPasswordPrompt(t *testing.T) {
}
}

func TestStartupScript(t *testing.T) {
o, err := os.CreateTemp("", "sqlcmdmain")
assert.NoError(t, err, "os.CreateTemp")
defer os.Remove(o.Name())
defer o.Close()
args = newArguments()
args.OutputFile = o.Name()
args.Query = "set nocount on"
if canTestAzureAuth() {
args.UseAad = true
}
vars := sqlcmd.InitializeVariables(true)
setVars(vars, &args)
vars.Set(sqlcmd.SQLCMDINI, "testdata/select100.sql")
vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0")
exitCode, err := run(vars, &args)
assert.NoError(t, err, "run")
assert.Equal(t, 0, exitCode, "exitCode")
bytes, err := os.ReadFile(o.Name())
if assert.NoError(t, err, "os.ReadFile") {
assert.Equal(t, "100"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run")
}
}

// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set
func canTestAzureAuth() bool {
server := os.Getenv(sqlcmd.SQLCMDSERVER)
Expand Down
9 changes: 9 additions & 0 deletions pkg/sqlcmd/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package sqlcmd
import (
"errors"
"fmt"
"strings"
)

// ErrorPrefix is the prefix for all sqlcmd-generated errors
Expand Down Expand Up @@ -56,6 +57,14 @@ func UndefinedVariable(variable string) *VariableError {
}
}

// InvalidVariableValue indicates the variable was set to an invalid value
func InvalidVariableValue(variable string, value string) *VariableError {
return &VariableError{
Variable: variable,
MessageFormat: "The environment variable: '%s' has invalid value: '" + strings.ReplaceAll(value, `%`, `%%`) + "'.",
}
}

// CommandError indicates syntax errors for specific sqlcmd commands
type CommandError struct {
Command string
Expand Down
5 changes: 5 additions & 0 deletions pkg/sqlcmd/variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ func (v Variables) Format() string {
return "horizontal"
}

// StartupScriptFile is the path to the file that contains the startup script
func (v Variables) StartupScriptFile() string {
return v[SQLCMDINI]
}

// TextEditor is the query editor application launched by the :ED command
func (v Variables) TextEditor() string {
return v[SQLCMDEDITOR]
Expand Down