@@ -34,34 +34,25 @@ import (
34
34
// accepts one client connection at a time
35
35
type BasicPortForwarding struct {
36
36
port IPortSession
37
- stream * net.Conn
38
- listener * net.Listener
37
+ stream net.Conn
38
+ listener net.Listener
39
39
sessionId string
40
40
portParameters PortParameters
41
41
session session.Session
42
42
}
43
43
44
- // getNewListener returns a new listener to given address and type like tcp, unix etc.
45
- var getNewListener = func (listenerType string , listenerAddress string ) (listener net.Listener , err error ) {
46
- return net .Listen (listenerType , listenerAddress )
47
- }
48
-
49
- // acceptConnection returns connection to the listener
50
- var acceptConnection = func (log log.T , listener net.Listener ) (tcpConn net.Conn , err error ) {
51
- return listener .Accept ()
52
- }
53
-
54
44
// IsStreamNotSet checks if stream is not set
55
45
func (p * BasicPortForwarding ) IsStreamNotSet () (status bool ) {
56
46
return p .stream == nil
57
47
}
58
48
59
49
// Stop closes the stream
60
50
func (p * BasicPortForwarding ) Stop () {
51
+ p .listener .Close ()
61
52
if p .stream != nil {
62
- ( * p .stream ) .Close ()
53
+ p .stream .Close ()
63
54
}
64
- os . Exit ( 0 )
55
+ return
65
56
}
66
57
67
58
// InitializeStreams establishes connection and initializes the stream
@@ -77,7 +68,7 @@ func (p *BasicPortForwarding) InitializeStreams(log log.T, agentVersion string)
77
68
func (p * BasicPortForwarding ) ReadStream (log log.T ) (err error ) {
78
69
msg := make ([]byte , config .StreamDataPayloadSize )
79
70
for {
80
- numBytes , err := ( * p .stream ) .Read (msg )
71
+ numBytes , err := p .stream .Read (msg )
81
72
if err != nil {
82
73
log .Debugf ("Reading from port %s failed with error: %v. Close this connection, listen and accept new one." ,
83
74
p .portParameters .PortNumber , err )
@@ -108,7 +99,7 @@ func (p *BasicPortForwarding) ReadStream(log log.T) (err error) {
108
99
109
100
// WriteStream writes data to stream
110
101
func (p * BasicPortForwarding ) WriteStream (outputMessage message.ClientMessage ) error {
111
- _ , err := ( * p .stream ) .Write (outputMessage .Payload )
102
+ _ , err := p .stream .Write (outputMessage .Payload )
112
103
return err
113
104
}
114
105
@@ -120,41 +111,40 @@ func (p *BasicPortForwarding) startLocalConn(log log.T) (err error) {
120
111
localPortNumber = "0"
121
112
}
122
113
123
- var listener net.Listener
124
- if listener , err = p .startLocalListener (log , localPortNumber ); err != nil {
114
+ if err = p .startLocalListener (log , localPortNumber ); err != nil {
125
115
log .Errorf ("Unable to open tcp connection to port. %v" , err )
126
116
return err
127
117
}
128
118
129
- var tcpConn net.Conn
130
- if tcpConn , err = acceptConnection (log , listener ); err != nil {
131
- log .Errorf ("Failed to accept connection with error. %v" , err )
132
- return err
119
+ if p .stream , err = p .listener .Accept (); err != nil {
120
+ if p .session .DataChannel .IsSessionEnded () == false {
121
+ log .Errorf ("Failed to accept connection with error. %v" , err )
122
+ return err
123
+ }
124
+ }
125
+ if p .session .DataChannel .IsSessionEnded () == false {
126
+ log .Infof ("Connection accepted for session %s." , p .sessionId )
127
+ fmt .Printf ("Connection accepted for session %s.\n " , p .sessionId )
133
128
}
134
- log .Infof ("Connection accepted for session %s." , p .sessionId )
135
- fmt .Printf ("Connection accepted for session %s.\n " , p .sessionId )
136
-
137
- p .listener = & listener
138
- p .stream = & tcpConn
139
129
140
130
return
141
131
}
142
132
143
133
// startLocalListener starts a local listener to given address
144
- func (p * BasicPortForwarding ) startLocalListener (log log.T , portNumber string ) (listener net. Listener , err error ) {
134
+ func (p * BasicPortForwarding ) startLocalListener (log log.T , portNumber string ) (err error ) {
145
135
var displayMessage string
146
136
switch p .portParameters .LocalConnectionType {
147
137
case "unix" :
148
- if listener , err = getNewListener (p .portParameters .LocalConnectionType , p .portParameters .LocalUnixSocket ); err != nil {
138
+ if p . listener , err = net . Listen (p .portParameters .LocalConnectionType , p .portParameters .LocalUnixSocket ); err != nil {
149
139
return
150
140
}
151
141
displayMessage = fmt .Sprintf ("Unix socket %s opened for sessionId %s." , p .portParameters .LocalUnixSocket , p .sessionId )
152
142
default :
153
- if listener , err = getNewListener ("tcp" , "localhost:" + portNumber ); err != nil {
143
+ if p . listener , err = net . Listen ("tcp" , "localhost:" + portNumber ); err != nil {
154
144
return
155
145
}
156
146
// get port number the TCP listener opened
157
- p .portParameters .LocalPortNumber = strconv .Itoa (listener .Addr ().(* net.TCPAddr ).Port )
147
+ p .portParameters .LocalPortNumber = strconv .Itoa (p . listener .Addr ().(* net.TCPAddr ).Port )
158
148
displayMessage = fmt .Sprintf ("Port %s opened for sessionId %s." , p .portParameters .LocalPortNumber , p .sessionId )
159
149
}
160
150
@@ -171,29 +161,31 @@ func (p *BasicPortForwarding) handleControlSignals(log log.T) {
171
161
<- c
172
162
fmt .Println ("Terminate signal received, exiting." )
173
163
164
+ p .session .DataChannel .EndSession ()
174
165
if version .DoesAgentSupportTerminateSessionFlag (log , p .session .DataChannel .GetAgentVersion ()) {
175
166
if err := p .session .DataChannel .SendFlag (log , message .TerminateSession ); err != nil {
176
167
log .Errorf ("Failed to send TerminateSession flag: %v" , err )
177
168
}
178
169
fmt .Fprintf (os .Stdout , "\n \n Exiting session with sessionId: %s.\n \n " , p .sessionId )
179
- p .Stop ()
180
170
} else {
181
171
p .session .TerminateSession (log )
182
172
}
173
+ p .Stop ()
183
174
}()
184
175
}
185
176
186
177
// reconnect closes existing connection, listens to new connection and accept it
187
178
func (p * BasicPortForwarding ) reconnect (log log.T ) (err error ) {
188
179
// close existing connection as it is in a state from which data cannot be read
189
- ( * p .stream ) .Close ()
180
+ p .stream .Close ()
190
181
191
182
// wait for new connection on listener and accept it
192
- var conn net.Conn
193
- if conn , err = acceptConnection (log , * p .listener ); err != nil {
194
- return log .Errorf ("Failed to accept connection with error. %v" , err )
183
+ if p .stream , err = p .listener .Accept (); err != nil {
184
+ if p .session .DataChannel .IsSessionEnded () == false {
185
+ log .Errorf ("Failed to accept connection with error. %v" , err )
186
+ return err
187
+ }
195
188
}
196
- p .stream = & conn
197
189
198
190
return
199
191
}
0 commit comments