Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pia): port forward using server IP instead of gateway ip #2254

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions internal/models/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,19 @@ type Connection struct {
Protocol string `json:"protocol"`
// Hostname is used for IPVanish, IVPN, Privado
// and Windscribe for TLS verification.
// It is used for PIA for port forwarding.
Hostname string `json:"hostname"`
// PubKey is the public key of the VPN server,
// used only for Wireguard.
PubKey string `json:"pubkey"`
// ServerName is used for PIA for port forwarding
ServerName string `json:"server_name,omitempty"`
// PortForward is used for PIA for port forwarding
PortForward bool `json:"port_forward"`
}

func (c *Connection) Equal(other Connection) bool {
return c.IP.Compare(other.IP) == 0 && c.Port == other.Port &&
c.Protocol == other.Protocol && c.Hostname == other.Hostname &&
c.PubKey == other.PubKey && c.ServerName == other.ServerName &&
c.PortForward == other.PortForward
c.PubKey == other.PubKey && c.PortForward == other.PortForward
}

// UpdateEmptyWith updates each field of the connection where the
Expand Down
18 changes: 9 additions & 9 deletions internal/portforward/service/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type Settings struct {
PortForwarder PortForwarder
Filepath string
Interface string // needed for PIA and ProtonVPN, tun0 for example
ServerName string // needed for PIA
ServerHostname string // needed for PIA
CanPortForward bool // needed for PIA
ListeningPort uint16
}
Expand All @@ -23,7 +23,7 @@ func (s Settings) Copy() (copied Settings) {
copied.PortForwarder = s.PortForwarder
copied.Filepath = s.Filepath
copied.Interface = s.Interface
copied.ServerName = s.ServerName
copied.ServerHostname = s.ServerHostname
copied.CanPortForward = s.CanPortForward
copied.ListeningPort = s.ListeningPort
return copied
Expand All @@ -34,16 +34,16 @@ func (s *Settings) OverrideWith(update Settings) {
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
s.ServerHostname = gosettings.OverrideWithComparable(s.ServerHostname, update.ServerHostname)
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
s.ListeningPort = gosettings.OverrideWithComparable(s.ListeningPort, update.ListeningPort)
}

var (
ErrPortForwarderNotSet = errors.New("port forwarder not set")
ErrServerNameNotSet = errors.New("server name not set")
ErrFilepathNotSet = errors.New("file path not set")
ErrInterfaceNotSet = errors.New("interface not set")
ErrPortForwarderNotSet = errors.New("port forwarder not set")
ErrServerHostnameNotSet = errors.New("server hostname not set")
ErrFilepathNotSet = errors.New("file path not set")
ErrInterfaceNotSet = errors.New("interface not set")
)

func (s *Settings) Validate(forStartup bool) (err error) {
Expand All @@ -64,8 +64,8 @@ func (s *Settings) Validate(forStartup bool) (err error) {
return fmt.Errorf("%w", ErrPortForwarderNotSet)
case s.Interface == "":
return fmt.Errorf("%w", ErrInterfaceNotSet)
case s.PortForwarder.Name() == providers.PrivateInternetAccess && s.ServerName == "":
return fmt.Errorf("%w", ErrServerNameNotSet)
case s.PortForwarder.Name() == providers.PrivateInternetAccess && s.ServerHostname == "":
return fmt.Errorf("%w", ErrServerHostnameNotSet)
}
return nil
}
2 changes: 1 addition & 1 deletion internal/portforward/service/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
Logger: s.logger,
Gateway: gateway,
Client: s.client,
ServerName: s.settings.ServerName,
ServerHostname: s.settings.ServerHostname,
CanPortForward: s.settings.CanPortForward,
}
port, err := s.settings.PortForwarder.PortForward(ctx, obj)
Expand Down
12 changes: 6 additions & 6 deletions internal/provider/custom/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ func getOpenVPNConnection(extractor Extractor,
connection.Port = customPort
}

if len(selection.Names) > 0 {
// Set the server name for PIA port forwarding code used
if len(selection.Hostnames) > 0 {
// Set the server hostname for PIA port forwarding code used
// together with the custom provider.
connection.ServerName = selection.Names[0]
connection.Hostname = selection.Hostnames[0]
connection.PortForward = true
}

Expand All @@ -59,10 +59,10 @@ func getWireguardConnection(selection settings.ServerSelection) (
Protocol: constants.UDP,
PubKey: selection.Wireguard.PublicKey,
}
if len(selection.Names) > 0 {
// Set the server name for PIA port forwarding code used
if len(selection.Hostnames) > 0 {
// Set the server hostname for PIA port forwarding code used
// together with the custom provider.
connection.ServerName = selection.Names[0]
connection.Hostname = selection.Hostnames[0]
connection.PortForward = true
}
return connection
Expand Down
50 changes: 0 additions & 50 deletions internal/provider/privateinternetaccess/httpclient.go

This file was deleted.

51 changes: 0 additions & 51 deletions internal/provider/privateinternetaccess/httpclient_test.go

This file was deleted.

48 changes: 15 additions & 33 deletions internal/provider/privateinternetaccess/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"io"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"strconv"
Expand All @@ -27,14 +26,11 @@ var (
// PortForward obtains a VPN server side port forwarded from PIA.
func (p *Provider) PortForward(ctx context.Context,
objects utils.PortForwardObjects) (port uint16, err error) {
switch {
case objects.ServerName == "":
panic("server name cannot be empty")
case !objects.Gateway.IsValid():
panic("gateway is not set")
if objects.ServerHostname == "" {
panic("server hostname cannot be empty")
}

serverName := objects.ServerName
serverName := objects.ServerHostname

logger := objects.Logger

Expand All @@ -43,11 +39,6 @@ func (p *Provider) PortForward(ctx context.Context,
return 0, nil
}

privateIPClient, err := newHTTPClient(serverName)
if err != nil {
return 0, fmt.Errorf("creating custom HTTP client: %w", err)
}

data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil {
return 0, fmt.Errorf("reading saved port forwarded data: %w", err)
Expand All @@ -66,8 +57,7 @@ func (p *Provider) PortForward(ctx context.Context,
}

if !dataFound || expired {
client := objects.Client
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
data, err = refreshPIAPortForwardData(ctx, objects.Client, objects.ServerHostname,
p.portForwardPath, p.authFilePath)
if err != nil {
return 0, fmt.Errorf("refreshing port forward data: %w", err)
Expand All @@ -77,7 +67,7 @@ func (p *Provider) PortForward(ctx context.Context,
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))

// First time binding
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
if err := bindPort(ctx, objects.Client, objects.ServerHostname, data); err != nil {
return 0, fmt.Errorf("binding port: %w", err)
}

Expand All @@ -90,16 +80,8 @@ var (

func (p *Provider) KeepPortForward(ctx context.Context,
objects utils.PortForwardObjects) (err error) {
switch {
case objects.ServerName == "":
panic("server name cannot be empty")
case !objects.Gateway.IsValid():
panic("gateway is not set")
}

privateIPClient, err := newHTTPClient(objects.ServerName)
if err != nil {
return fmt.Errorf("creating custom HTTP client: %w", err)
if objects.ServerHostname == "" {
panic("server hostname cannot be empty")
}

data, err := readPIAPortForwardData(p.portForwardPath)
Expand All @@ -124,7 +106,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
}
return ctx.Err()
case <-keepAliveTimer.C:
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
err = bindPort(ctx, objects.Client, objects.ServerHostname, data)
if err != nil {
return fmt.Errorf("binding port: %w", err)
}
Expand All @@ -136,14 +118,14 @@ func (p *Provider) KeepPortForward(ctx context.Context,
}
}

func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
gateway netip.Addr, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
func refreshPIAPortForwardData(ctx context.Context, client *http.Client,
serverHostname, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
data.Token, err = fetchToken(ctx, client, authFilePath)
if err != nil {
return data, fmt.Errorf("fetching token: %w", err)
}

data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, client, serverHostname, data.Token)
if err != nil {
return data, fmt.Errorf("fetching port forwarding data: %w", err)
}
Expand Down Expand Up @@ -319,15 +301,15 @@ func getOpenvpnCredentials(authFilePath string) (
return username, password, nil
}

func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) (
func fetchPortForwardData(ctx context.Context, client *http.Client, serverHostname, token string) (
port uint16, signature string, expiration time.Time, err error) {
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}

queryParams := make(url.Values)
queryParams.Add("token", token)
url := url.URL{
Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"),
Host: net.JoinHostPort(serverHostname, "19999"),
Path: "/getSignature",
RawQuery: queryParams.Encode(),
}
Expand Down Expand Up @@ -373,7 +355,7 @@ var (
ErrBadResponse = errors.New("bad response received")
)

func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) {
func bindPort(ctx context.Context, client *http.Client, serverHostname string, data piaPortForwardData) (err error) {
payload, err := packPayload(data.Port, data.Token, data.Expiration)
if err != nil {
return fmt.Errorf("serializing payload: %w", err)
Expand All @@ -384,7 +366,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data
queryParams.Add("signature", data.Signature)
bindPortURL := url.URL{
Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"),
Host: net.JoinHostPort(serverHostname, "19999"),
Path: "/bindPort",
RawQuery: queryParams.Encode(),
}
Expand Down
1 change: 0 additions & 1 deletion internal/provider/utils/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ func GetConnection(provider string,
Port: port,
Protocol: protocol,
Hostname: hostname,
ServerName: server.ServerName,
PortForward: server.PortForward,
PubKey: server.WgPubKey, // Wireguard
}
Expand Down
7 changes: 3 additions & 4 deletions internal/provider/utils/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ import (
type PortForwardObjects struct {
// Logger is a logger, used by both Private Internet Access and ProtonVPN.
Logger Logger
// Gateway is the VPN gateway IP address, used by Private Internet Access
// and ProtonVPN.
// Gateway is the VPN gateway IP address, used by ProtonVPN.
Gateway netip.Addr
// Client is used to query the VPN gateway for Private Internet Access.
Client *http.Client
// ServerName is used by Private Internet Access for port forwarding.
ServerName string
// ServerHostname is used by Private Internet Access for port forwarding.
ServerHostname string
// CanPortForward is used by Private Internet Access for port forwarding.
CanPortForward bool
}
Expand Down
Loading
Loading