Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: modernize code #378

Merged
merged 18 commits into from
Dec 22, 2024
Prev Previous commit
Next Next commit
chore: modernize code
Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de>
  • Loading branch information
jkroepke committed Dec 22, 2024
commit 630d49be30f1371efe3534001002d49b25455130
21 changes: 8 additions & 13 deletions internal/oauth2/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -231,24 +230,20 @@ func TestRefreshReAuth(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

conf, openVPNClient, managementInterface, _, _, httpClient, logger := testutils.SetupMockEnvironment(context.Background(), t, tt.conf, tt.rt)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

t.Cleanup(func() {
if t.Failed() {
t.Log(logger.String())
}
})
conf, openVPNClient, managementInterface, _, _, httpClient, logger := testutils.SetupMockEnvironment(ctx, t, tt.conf, tt.rt)

wg := sync.WaitGroup{}
wg.Add(1)

errCh := make(chan error, 1)

go func() {
defer wg.Done()

err := openVPNClient.Connect(context.Background())
if err != nil && !errors.Is(err, io.EOF) {
assert.NoError(t, err)
}
errCh <- openVPNClient.Connect(ctx)
}()

managementInterfaceConn, err := managementInterface.Accept()
Expand Down Expand Up @@ -406,10 +401,10 @@ func TestRefreshReAuth(t *testing.T) {

testutils.SendMessage(t, managementInterfaceConn, "SUCCESS: %s command succeeded", strings.SplitN(auth, " ", 2)[0])

time.Sleep(time.Millisecond * 50)

openVPNClient.Shutdown()

wg.Wait()
require.NoError(t, <-errCh, logger.String())
})
}
}
13 changes: 6 additions & 7 deletions internal/openvpn/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ import (
"time"

"github.com/jkroepke/openvpn-auth-oauth2/internal/openvpn/connection"
"github.com/jkroepke/openvpn-auth-oauth2/internal/utils"
)

// handlePassword enters the password on the OpenVPN management interface connection.
func (c *Client) handlePassword() error {
func (c *Client) handlePassword(ctx context.Context) error {
buf := make([]byte, 15)

err := c.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
Expand All @@ -33,15 +32,15 @@ func (c *Client) handlePassword() error {
return fmt.Errorf("set read deadline: %w", err)
}

c.logger.Debug(utils.StringConcat("password probe: ", string(buf)))
c.logger.LogAttrs(ctx, slog.LevelDebug, "password probe: "+string(buf))

switch {
case string(buf) == "ENTER PASSWORD:":
if c.conf.OpenVpn.Password == "" {
return errors.New("management password required")
}

if err = c.sendPassword(); err != nil {
if err = c.sendPassword(ctx); err != nil {
return err
}
case c.conf.OpenVpn.Password != "":
Expand All @@ -55,8 +54,8 @@ func (c *Client) handlePassword() error {
}

// sendPassword enters the password on the OpenVPN management interface connection.
func (c *Client) sendPassword() error {
if err := c.rawCommand(c.conf.OpenVpn.Password.String()); err != nil {
func (c *Client) sendPassword(ctx context.Context) error {
if err := c.rawCommand(ctx, c.conf.OpenVpn.Password.String()); err != nil {
return fmt.Errorf("error from password command: %w", err)
}

Expand Down Expand Up @@ -197,7 +196,7 @@ func (c *Client) handleCommands(ctx context.Context, errCh chan<- error) {
return
}

if err := c.rawCommand(command); err != nil {
if err := c.rawCommand(ctx, command); err != nil {
errCh <- err

return
Expand Down
10 changes: 5 additions & 5 deletions internal/openvpn/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
c.oauth2 = client
}

func (c *Client) Connect(ctx context.Context) error {

Check failure on line 48 in internal/openvpn/main.go

View workflow job for this annotation

GitHub Actions / lint

calculated cyclomatic complexity for function Connect is 12, max is 10 (cyclop)
var err error

ctx, cancel := context.WithCancel(ctx)
Expand All @@ -63,7 +63,7 @@
c.scanner.Split(bufio.ScanLines)
c.scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)

if err = c.handlePassword(); err != nil {
if err = c.handlePassword(ctx); err != nil {
return fmt.Errorf("openvpn management error: %w", err)
}

Expand Down Expand Up @@ -226,9 +226,9 @@
}

// rawCommand passes command to a given connection (adds logging and EOL character).
func (c *Client) rawCommand(cmd string) error {
if c.logger.Enabled(context.Background(), slog.LevelDebug) {
c.logger.Debug(cmd)
func (c *Client) rawCommand(ctx context.Context, cmd string) error {
if c.logger.Enabled(ctx, slog.LevelDebug) {
c.logger.LogAttrs(ctx, slog.LevelDebug, "send command", slog.String("command", cmd))
}

c.commandsBuffer.Reset()
Expand Down Expand Up @@ -272,7 +272,7 @@
}
}

if c.scanner.Err() != nil {
if c.closed.Load() == 0 && c.scanner.Err() != nil {
return fmt.Errorf("scanner error: %w", c.scanner.Err())
}

Expand Down
6 changes: 2 additions & 4 deletions internal/openvpn/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,7 @@ func TestClientInvalidPassword(t *testing.T) {

err = openVPNClient.Connect(context.Background())

require.Error(t, err)
assert.Equal(t, "unable to connect to openvpn management interface: invalid password", err.Error())
require.EqualError(t, err, "openvpn management error: unable to connect to openvpn management interface: invalid password")
}

func TestClientInvalidVersion(t *testing.T) {
Expand Down Expand Up @@ -510,8 +509,7 @@ func TestClientInvalidVersion(t *testing.T) {

err = <-errCh

require.Error(t, err)
assert.Equal(t, tt.err, err.Error())
require.EqualError(t, err, tt.err, tt.err)
})
}
}
Expand Down
10 changes: 5 additions & 5 deletions internal/openvpn/passthrough.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (c *Client) handlePassThroughClient(ctx context.Context, conn net.Conn) {
panic(fmt.Errorf("%w %s", ErrUnknownProtocol, c.conf.OpenVpn.Passthrough.Address.Scheme))
}

logger.Info("pass-through: accepted connection")
logger.LogAttrs(ctx, slog.LevelInfo, "pass-through: accepted connection")

scanner := bufio.NewScanner(conn)
scanner.Split(bufio.ScanLines)
Expand All @@ -154,10 +154,10 @@ func (c *Client) handlePassThroughClient(ctx context.Context, conn net.Conn) {
c.passThroughConnected.CompareAndSwap(0, 1)

if err = c.handlePassThroughClientCommands(ctx, conn, logger, scanner); err != nil {
logger.Warn(err.Error())
logger.LogAttrs(ctx, slog.LevelWarn, err.Error())
}

logger.Info("pass-through: closed connection")
logger.LogAttrs(ctx, slog.LevelInfo, "pass-through: closed connection")
}

func (c *Client) handlePassThroughClientCommands(ctx context.Context, conn net.Conn, logger *slog.Logger, scanner *bufio.Scanner) error {
Expand All @@ -178,7 +178,7 @@ func (c *Client) handlePassThroughClientCommands(ctx context.Context, conn net.C
switch {
case strings.HasPrefix(line, "client-deny"), strings.HasPrefix(line, "client-auth"):
c.writeToPassThroughClient("ERROR: command not allowed")
logger.Warn("pass-through: client send client-deny or client-auth message, ignoring...")
logger.LogAttrs(ctx, slog.LevelWarn, "pass-through: client send client-deny or client-auth message, ignoring...")

continue
case line == "hold":
Expand All @@ -193,7 +193,7 @@ func (c *Client) handlePassThroughClientCommands(ctx context.Context, conn net.C

resp, err = c.SendCommand(line, true)
if err != nil {
logger.Warn(fmt.Errorf("pass-through: error from command '%s': %w", line, err).Error())
logger.LogAttrs(ctx, slog.LevelWarn, fmt.Errorf("pass-through: error from command '%s': %w", line, err).Error())
} else {
c.writeToPassThroughClient(strings.TrimSpace(resp))
}
Expand Down
Loading