diff --git a/pkg/os/shell/.shell_windows.go.swp b/pkg/os/shell/.shell_windows.go.swp new file mode 100644 index 0000000000..95b99bdbfb Binary files /dev/null and b/pkg/os/shell/.shell_windows.go.swp differ diff --git a/pkg/os/shell/shell_windows.go b/pkg/os/shell/shell_windows.go index 9f4c8d36c1..e047d1d9a5 100644 --- a/pkg/os/shell/shell_windows.go +++ b/pkg/os/shell/shell_windows.go @@ -55,6 +55,30 @@ func getNameAndItsPpid(pid uint32) (exefile string, parentid uint32, err error) return name, pe.ParentProcessID, nil } +func analyzeShell(shell string, shellpid uint32) (string, error) { + switch { + case strings.Contains(strings.ToLower(shell), "powershell"): + return "powershell", nil + case strings.Contains(strings.ToLower(shell), "pwsh"): + return "powershell", nil + case strings.Contains(strings.ToLower(shell), "cmd"): + return "cmd", nil + default: + shell, _, err := getNameAndItsPpid(shellpid) + if err != nil { + return "cmd", err // defaulting to cmd + } + switch { + case strings.Contains(strings.ToLower(shell), "powershell"): + return "powershell", nil + case strings.Contains(strings.ToLower(shell), "cmd"): + return "cmd", nil + default: + return "cmd", nil // this could be either powershell or cmd, defaulting to cmd + } + } +} + func detect() (string, error) { shell := os.Getenv("SHELL") @@ -67,32 +91,22 @@ func detect() (string, error) { if err != nil { return "cmd", err // defaulting to cmd } - switch { - case strings.Contains(strings.ToLower(shell), "powershell"): - return "powershell", nil - case strings.Contains(strings.ToLower(shell), "pwsh"): - return "powershell", nil - case strings.Contains(strings.ToLower(shell), "cmd"): - return "cmd", nil - default: - shell, _, err := getNameAndItsPpid(shellppid) - if err != nil { - return "cmd", err // defaulting to cmd - } - switch { - case strings.Contains(strings.ToLower(shell), "powershell"): - return "powershell", nil - case strings.Contains(strings.ToLower(shell), "cmd"): - return "cmd", nil - default: - return "cmd", nil // this could be either powershell or cmd, defaulting to cmd - } - } + return analyzeShell(shell, shellppid) } if os.Getenv("__fish_bin_dir") != "" { return "fish", nil } - return filepath.Base(shell), nil + baseShell := filepath.Base(shell) + + if baseShell != "powershell" && baseShell != "cmd" { + pid := os.Getpid() + if pid < 0 || pid > math.MaxUint32 { + return "", fmt.Errorf("integer overflow for pid: %v", pid) + } + return analyzeShell(shell, uint32(pid)) + } + + return baseShell, nil }