Skip to content

Commit b4e8805

Browse files
committed
Integrate fixes from open PRs:
aws#104 aws#99
1 parent 82dc729 commit b4e8805

File tree

12 files changed

+162
-114
lines changed

12 files changed

+162
-114
lines changed

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.2.0.0
1+
1.2.694.0

src/datachannel/streaming.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ type IDataChannel interface {
6262
RegisterOutputStreamHandler(handler OutputStreamDataMessageHandler, isSessionSpecificHandler bool)
6363
DeregisterOutputStreamHandler(handler OutputStreamDataMessageHandler)
6464
IsSessionTypeSet() chan bool
65+
EndSession() error
66+
IsSessionEnded() bool
6567
IsStreamMessageResendTimeout() chan bool
6668
GetSessionType() string
6769
SetSessionType(sessionType string)
@@ -106,6 +108,8 @@ type DataChannel struct {
106108
isSessionTypeSet chan bool
107109
sessionProperties interface{}
108110

111+
isSessionEnded bool
112+
109113
// Used to detect if resending a streaming message reaches timeout
110114
isStreamMessageResendTimeout chan bool
111115

@@ -187,6 +191,7 @@ func (dataChannel *DataChannel) Initialize(log log.T, clientId string, sessionId
187191
dataChannel.wsChannel = &communicator.WebSocketChannel{}
188192
dataChannel.encryptionEnabled = false
189193
dataChannel.isSessionTypeSet = make(chan bool, 1)
194+
dataChannel.isSessionEnded = false
190195
dataChannel.isStreamMessageResendTimeout = make(chan bool, 1)
191196
dataChannel.sessionType = ""
192197
dataChannel.IsAwsCliUpgradeNeeded = isAwsCliUpgradeNeeded
@@ -199,7 +204,7 @@ func (dataChannel *DataChannel) SetWebsocket(log log.T, channelUrl string, chann
199204

200205
// FinalizeHandshake sends the token for service to acknowledge the connection.
201206
func (dataChannel *DataChannel) FinalizeDataChannelHandshake(log log.T, tokenValue string) (err error) {
202-
uuid.SwitchFormat(uuid.CleanHyphen)
207+
uuid.SwitchFormat(uuid.FormatCanonical)
203208
uid := uuid.NewV4().String()
204209

205210
log.Infof("Sending token through data channel %s to acknowledge connection", dataChannel.wsChannel.GetStreamUrl())
@@ -773,7 +778,7 @@ func (dataChannel *DataChannel) HandleAcknowledgeMessage(
773778
}
774779

775780
// handleChannelClosedMessage exits the shell
776-
func (dataChannel DataChannel) HandleChannelClosedMessage(log log.T, stopHandler Stop, sessionId string, outputMessage message.ClientMessage) {
781+
func (dataChannel *DataChannel) HandleChannelClosedMessage(log log.T, stopHandler Stop, sessionId string, outputMessage message.ClientMessage) {
777782
var (
778783
channelClosedMessage message.ChannelClosed
779784
err error
@@ -788,6 +793,8 @@ func (dataChannel DataChannel) HandleChannelClosedMessage(log log.T, stopHandler
788793
} else {
789794
fmt.Fprintf(os.Stdout, "\n\nSessionId: %s : %s\n\n", sessionId, channelClosedMessage.Output)
790795
}
796+
dataChannel.EndSession()
797+
dataChannel.Close(log)
791798

792799
stopHandler()
793800
}
@@ -850,7 +857,7 @@ func (dataChannel *DataChannel) CalculateRetransmissionTimeout(log log.T, stream
850857
func (dataChannel *DataChannel) ProcessKMSEncryptionHandshakeAction(log log.T, actionParams json.RawMessage) (err error) {
851858

852859
if dataChannel.IsAwsCliUpgradeNeeded {
853-
return errors.New("Installed version of CLI does not support Session Manager encryption feature. Please upgrade to the latest version of your CLI (e.g., AWS CLI).")
860+
return errors.New("installed version of CLI does not support Session Manager encryption feature. Please upgrade to the latest version of your CLI (e.g., AWS CLI)")
854861
}
855862
kmsEncRequest := message.KMSEncryptionRequest{}
856863
json.Unmarshal(actionParams, &kmsEncRequest)
@@ -882,7 +889,7 @@ func (dataChannel *DataChannel) ProcessSessionTypeHandshakeAction(actionParams j
882889
dataChannel.sessionProperties = sessTypeReq.Properties
883890
return nil
884891
default:
885-
return errors.New(fmt.Sprintf("Unknown session type %s", sessTypeReq.SessionType))
892+
return fmt.Errorf("Unknown session type %s", sessTypeReq.SessionType)
886893
}
887894
}
888895

@@ -891,6 +898,17 @@ func (dataChannel *DataChannel) IsSessionTypeSet() chan bool {
891898
return dataChannel.isSessionTypeSet
892899
}
893900

901+
// IsSessionEnded check if session has ended
902+
func (dataChannel *DataChannel) IsSessionEnded() bool {
903+
return dataChannel.isSessionEnded
904+
}
905+
906+
// IsSessionEnded check if session has ended
907+
func (dataChannel *DataChannel) EndSession() error {
908+
dataChannel.isSessionEnded = true
909+
return nil
910+
}
911+
894912
// IsStreamMessageResendTimeout checks if resending a streaming message reaches timeout
895913
func (dataChannel *DataChannel) IsStreamMessageResendTimeout() chan bool {
896914
return dataChannel.isStreamMessageResendTimeout

src/message/messageparser.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,31 +167,31 @@ func getUuid(log log.T, byteArray []byte, offset int) (result uuid.UUID, err err
167167
byteArrayLength := len(byteArray)
168168
if offset > byteArrayLength-1 || offset+16-1 > byteArrayLength-1 || offset < 0 {
169169
log.Error("getUuid failed: Offset is invalid.")
170-
return nil, errors.New("Offset is outside the byte array.")
170+
return uuid.Nil.UUID(), errors.New("Offset is outside the byte array.")
171171
}
172172

173173
leastSignificantLong, err := getLong(log, byteArray, offset)
174174
if err != nil {
175175
log.Error("getUuid failed: failed to get uuid LSBs Long value.")
176-
return nil, errors.New("Failed to get uuid LSBs long value.")
176+
return uuid.Nil.UUID(), errors.New("Failed to get uuid LSBs long value.")
177177
}
178178

179179
leastSignificantBytes, err := longToBytes(log, leastSignificantLong)
180180
if err != nil {
181181
log.Error("getUuid failed: failed to get uuid LSBs bytes value.")
182-
return nil, errors.New("Failed to get uuid LSBs bytes value.")
182+
return uuid.Nil.UUID(), errors.New("Failed to get uuid LSBs bytes value.")
183183
}
184184

185185
mostSignificantLong, err := getLong(log, byteArray, offset+8)
186186
if err != nil {
187187
log.Error("getUuid failed: failed to get uuid MSBs Long value.")
188-
return nil, errors.New("Failed to get uuid MSBs long value.")
188+
return uuid.Nil.UUID(), errors.New("Failed to get uuid MSBs long value.")
189189
}
190190

191191
mostSignificantBytes, err := longToBytes(log, mostSignificantLong)
192192
if err != nil {
193193
log.Error("getUuid failed: failed to get uuid MSBs bytes value.")
194-
return nil, errors.New("Failed to get uuid MSBs bytes value.")
194+
return uuid.Nil.UUID(), errors.New("Failed to get uuid MSBs bytes value.")
195195
}
196196

197197
uuidBytes := append(mostSignificantBytes, leastSignificantBytes...)
@@ -414,7 +414,7 @@ func putBytes(log log.T, byteArray []byte, offsetStart int, offsetEnd int, input
414414

415415
// putUuid puts the 128 bit uuid to an array of bytes starting from the offset.
416416
func putUuid(log log.T, byteArray []byte, offset int, input uuid.UUID) (err error) {
417-
if input == nil {
417+
if uuid.IsNil(input) {
418418
log.Error("putUuid failed: input is null.")
419419
return errors.New("putUuid failed: input is null.")
420420
}
@@ -494,7 +494,7 @@ func SerializeClientMessageWithAcknowledgeContent(log log.T, acknowledgeContent
494494
return
495495
}
496496

497-
uuid.SwitchFormat(uuid.CleanHyphen)
497+
uuid.SwitchFormat(uuid.FormatCanonical)
498498
messageId := uuid.NewV4()
499499
clientMessage := ClientMessage{
500500
MessageType: AcknowledgeMessage,

src/sessionmanagerplugin/session/portsession/basicportforwarding.go

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -34,34 +34,25 @@ import (
3434
// accepts one client connection at a time
3535
type BasicPortForwarding struct {
3636
port IPortSession
37-
stream *net.Conn
38-
listener *net.Listener
37+
stream net.Conn
38+
listener net.Listener
3939
sessionId string
4040
portParameters PortParameters
4141
session session.Session
4242
}
4343

44-
// getNewListener returns a new listener to given address and type like tcp, unix etc.
45-
var getNewListener = func(listenerType string, listenerAddress string) (listener net.Listener, err error) {
46-
return net.Listen(listenerType, listenerAddress)
47-
}
48-
49-
// acceptConnection returns connection to the listener
50-
var acceptConnection = func(log log.T, listener net.Listener) (tcpConn net.Conn, err error) {
51-
return listener.Accept()
52-
}
53-
5444
// IsStreamNotSet checks if stream is not set
5545
func (p *BasicPortForwarding) IsStreamNotSet() (status bool) {
5646
return p.stream == nil
5747
}
5848

5949
// Stop closes the stream
6050
func (p *BasicPortForwarding) Stop() {
51+
p.listener.Close()
6152
if p.stream != nil {
62-
(*p.stream).Close()
53+
p.stream.Close()
6354
}
64-
os.Exit(0)
55+
return
6556
}
6657

6758
// InitializeStreams establishes connection and initializes the stream
@@ -77,7 +68,7 @@ func (p *BasicPortForwarding) InitializeStreams(log log.T, agentVersion string)
7768
func (p *BasicPortForwarding) ReadStream(log log.T) (err error) {
7869
msg := make([]byte, config.StreamDataPayloadSize)
7970
for {
80-
numBytes, err := (*p.stream).Read(msg)
71+
numBytes, err := p.stream.Read(msg)
8172
if err != nil {
8273
log.Debugf("Reading from port %s failed with error: %v. Close this connection, listen and accept new one.",
8374
p.portParameters.PortNumber, err)
@@ -108,7 +99,7 @@ func (p *BasicPortForwarding) ReadStream(log log.T) (err error) {
10899

109100
// WriteStream writes data to stream
110101
func (p *BasicPortForwarding) WriteStream(outputMessage message.ClientMessage) error {
111-
_, err := (*p.stream).Write(outputMessage.Payload)
102+
_, err := p.stream.Write(outputMessage.Payload)
112103
return err
113104
}
114105

@@ -120,41 +111,40 @@ func (p *BasicPortForwarding) startLocalConn(log log.T) (err error) {
120111
localPortNumber = "0"
121112
}
122113

123-
var listener net.Listener
124-
if listener, err = p.startLocalListener(log, localPortNumber); err != nil {
114+
if err = p.startLocalListener(log, localPortNumber); err != nil {
125115
log.Errorf("Unable to open tcp connection to port. %v", err)
126116
return err
127117
}
128118

129-
var tcpConn net.Conn
130-
if tcpConn, err = acceptConnection(log, listener); err != nil {
131-
log.Errorf("Failed to accept connection with error. %v", err)
132-
return err
119+
if p.stream, err = p.listener.Accept(); err != nil {
120+
if p.session.DataChannel.IsSessionEnded() == false {
121+
log.Errorf("Failed to accept connection with error. %v", err)
122+
return err
123+
}
124+
}
125+
if p.session.DataChannel.IsSessionEnded() == false {
126+
log.Infof("Connection accepted for session %s.", p.sessionId)
127+
fmt.Printf("Connection accepted for session %s.\n", p.sessionId)
133128
}
134-
log.Infof("Connection accepted for session %s.", p.sessionId)
135-
fmt.Printf("Connection accepted for session %s.\n", p.sessionId)
136-
137-
p.listener = &listener
138-
p.stream = &tcpConn
139129

140130
return
141131
}
142132

143133
// startLocalListener starts a local listener to given address
144-
func (p *BasicPortForwarding) startLocalListener(log log.T, portNumber string) (listener net.Listener, err error) {
134+
func (p *BasicPortForwarding) startLocalListener(log log.T, portNumber string) (err error) {
145135
var displayMessage string
146136
switch p.portParameters.LocalConnectionType {
147137
case "unix":
148-
if listener, err = getNewListener(p.portParameters.LocalConnectionType, p.portParameters.LocalUnixSocket); err != nil {
138+
if p.listener, err = net.Listen(p.portParameters.LocalConnectionType, p.portParameters.LocalUnixSocket); err != nil {
149139
return
150140
}
151141
displayMessage = fmt.Sprintf("Unix socket %s opened for sessionId %s.", p.portParameters.LocalUnixSocket, p.sessionId)
152142
default:
153-
if listener, err = getNewListener("tcp", "localhost:"+portNumber); err != nil {
143+
if p.listener, err = net.Listen("tcp", "localhost:"+portNumber); err != nil {
154144
return
155145
}
156146
// get port number the TCP listener opened
157-
p.portParameters.LocalPortNumber = strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)
147+
p.portParameters.LocalPortNumber = strconv.Itoa(p.listener.Addr().(*net.TCPAddr).Port)
158148
displayMessage = fmt.Sprintf("Port %s opened for sessionId %s.", p.portParameters.LocalPortNumber, p.sessionId)
159149
}
160150

@@ -171,29 +161,31 @@ func (p *BasicPortForwarding) handleControlSignals(log log.T) {
171161
<-c
172162
fmt.Println("Terminate signal received, exiting.")
173163

164+
p.session.DataChannel.EndSession()
174165
if version.DoesAgentSupportTerminateSessionFlag(log, p.session.DataChannel.GetAgentVersion()) {
175166
if err := p.session.DataChannel.SendFlag(log, message.TerminateSession); err != nil {
176167
log.Errorf("Failed to send TerminateSession flag: %v", err)
177168
}
178169
fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId)
179-
p.Stop()
180170
} else {
181171
p.session.TerminateSession(log)
182172
}
173+
p.Stop()
183174
}()
184175
}
185176

186177
// reconnect closes existing connection, listens to new connection and accept it
187178
func (p *BasicPortForwarding) reconnect(log log.T) (err error) {
188179
// close existing connection as it is in a state from which data cannot be read
189-
(*p.stream).Close()
180+
p.stream.Close()
190181

191182
// wait for new connection on listener and accept it
192-
var conn net.Conn
193-
if conn, err = acceptConnection(log, *p.listener); err != nil {
194-
return log.Errorf("Failed to accept connection with error. %v", err)
183+
if p.stream, err = p.listener.Accept(); err != nil {
184+
if p.session.DataChannel.IsSessionEnded() == false {
185+
log.Errorf("Failed to accept connection with error. %v", err)
186+
return err
187+
}
195188
}
196-
p.stream = &conn
197189

198190
return
199191
}

0 commit comments

Comments
 (0)