Skip to content

Implement :ED #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/sqlcmd/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ parse:
if err == nil {
i = min(i, b.rawlen)
empty := isEmptyLine(b.raw, 0, i)
appendLine := b.quote != 0 || b.comment || !empty
appendLine := true
if !b.comment && command != nil && empty {
appendLine = false
}
Expand Down
56 changes: 52 additions & 4 deletions pkg/sqlcmd/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ type Command struct {
action func(*Sqlcmd, []string, uint) error
// Name of the command
name string
// whether the command is a system command
isSystem bool
}

// Commands is the set of sqlcmd command implementations
Expand Down Expand Up @@ -89,9 +91,16 @@ func newCommands() Commands {
name: "CONNECT",
},
"EXEC": {
regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(?:[ \t]+(.*$)|$)`),
action: execCommand,
name: "EXEC",
regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(?:[ \t]+(.*$)|$)`),
action: execCommand,
name: "EXEC",
isSystem: true,
},
"EDIT": {
regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`),
action: editCommand,
name: "EDIT",
isSystem: true,
},
}
}
Expand All @@ -103,8 +112,13 @@ func (c Commands) DisableSysCommands(exitOnCall bool) {
if exitOnCall {
f = errorDisabled
}
c["EXEC"].action = f
for _, cmd := range c {
if cmd.isSystem {
cmd.action = f
}
}
}

func (c Commands) matchCommand(line string) (*Command, []string) {
for _, cmd := range c {
matchedCommand := cmd.regex.FindStringSubmatch(line)
Expand Down Expand Up @@ -411,6 +425,40 @@ func execCommand(s *Sqlcmd, args []string, line uint) error {
return nil
}

func editCommand(s *Sqlcmd, args []string, line uint) error {
if args != nil && strings.TrimSpace(args[0]) != "" {
return InvalidCommandError("ED", line)
}
file, err := os.CreateTemp("", "sq*.sql")
if err != nil {
return err
}
fileName := file.Name()
defer os.Remove(fileName)
text := s.batch.String()
if s.batch.State() == "-" {
text = fmt.Sprintf("%s%s", text, SqlcmdEol)
}
_, err = file.WriteString(text)
if err != nil {
return err
}
file.Close()
cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`)
cmd.Stderr = s.GetError()
cmd.Stdout = s.GetOutput()
err = cmd.Run()
if err != nil {
return err
}
wasEcho := s.echoFileLines
s.echoFileLines = true
s.batch.Reset(nil)
_ = s.IncludeFile(fileName, false)
s.echoFileLines = wasEcho
return nil
}

func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) {
var b *strings.Builder
end := len(arg)
Expand Down
11 changes: 11 additions & 0 deletions pkg/sqlcmd/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,14 @@ func TestDisableSysCommandBlocksExec(t *testing.T) {
assert.Equal(t, 1, s.Exitcode, "ExitCode after error")
}
}

func TestEditCommand(t *testing.T) {
s, buf := setupSqlCmdWithMemoryOutput(t)
defer buf.Close()
s.vars.Set(SQLCMDEDITOR, "echo select 5000> ")
c := []string{"set nocount on", "go", "select 100", ":ed", "go"}
err := runSqlCmd(t, s, c)
if assert.NoError(t, err, ":ed should not raise error") {
assert.Equal(t, "1> select 5000"+SqlcmdEol+"5000"+SqlcmdEol+SqlcmdEol, buf.buf.String(), "Incorrect output from query after :ed command")
}
}
2 changes: 2 additions & 0 deletions pkg/sqlcmd/exec_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ func comSpec() string {
// /bin/sh will be a link to the shell
return `/bin/sh`
}

const defaultEditor = "vi"
2 changes: 2 additions & 0 deletions pkg/sqlcmd/exec_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ func comSpec() string {
// /bin/sh will be a link to the shell
return `/bin/sh`
}

const defaultEditor = "vi"
2 changes: 2 additions & 0 deletions pkg/sqlcmd/exec_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ func comSpec() string {
func comArgs(args string) string {
return `/c ` + args
}

const defaultEditor = "notepad.exe"
18 changes: 13 additions & 5 deletions pkg/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type Sqlcmd struct {
out io.WriteCloser
err io.WriteCloser
batch *Batch
echoFileLines bool
// Exitcode is returned to the operating system when the process exits
Exitcode int
Connect *ConnectSettings
Expand Down Expand Up @@ -310,6 +311,7 @@ func (s *Sqlcmd) IncludeFile(path string, processAll bool) error {
buf := make([]byte, maxLineBuffer)
scanner.Buffer(buf, maxLineBuffer)
curLine := s.batch.read
echoFileLines := s.echoFileLines
s.batch.read = func() (string, error) {
if !scanner.Scan() {
err := scanner.Err()
Expand All @@ -318,14 +320,20 @@ func (s *Sqlcmd) IncludeFile(path string, processAll bool) error {
}
return "", err
}
return scanner.Text(), nil
t := scanner.Text()
if echoFileLines {
_, _ = s.GetOutput().Write([]byte(s.Prompt() + t + SqlcmdEol))
}
return t, nil
}
err = s.Run(false, processAll)
s.batch.read = curLine
if s.batch.State() == "=" {
s.batch.batchline = 1
} else {
s.batch.batchline = b + 1
if !s.echoFileLines {
if s.batch.State() == "=" {
s.batch.batchline = 1
} else {
s.batch.batchline = b + 1
}
}
return err
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/sqlcmd/variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ func (v Variables) Format() string {
return "horizontal"
}

// TextEditor is the query editor application launched by the :ED command
func (v Variables) TextEditor() string {
return v[SQLCMDEDITOR]
}

func mustValue(val string) int64 {
var n int64
_, err := fmt.Sscanf(val, "%d", &n)
Expand All @@ -193,7 +198,7 @@ func mustValue(val string) int64 {
var defaultVariables = Variables{
SQLCMDCOLSEP: " ",
SQLCMDCOLWIDTH: "0",
SQLCMDEDITOR: "edit.com",
SQLCMDEDITOR: defaultEditor,
SQLCMDERRORLEVEL: "0",
SQLCMDHEADERS: "0",
SQLCMDLOGINTIMEOUT: "30",
Expand Down