diff --git a/example.lpi b/example.lpi index 6cfef86..2399ffb 100644 --- a/example.lpi +++ b/example.lpi @@ -1,16 +1,16 @@ - + + - <UseAppBundle Value="False"/> <ResourceType Value="res"/> @@ -24,7 +24,6 @@ </PublishOptions> <RunParams> <FormatVersion Value="2"/> - <Modes Count="0"/> </RunParams> <Units Count="3"> <Unit0> diff --git a/example.lpr b/example.lpr index 8a8ad42..6c38dd6 100644 --- a/example.lpr +++ b/example.lpr @@ -2,8 +2,8 @@ {$mode objfpc}{$H+} -uses {$IFDEF UNIX} {$IFDEF UseCThreads} - cthreads, {$ENDIF} {$ENDIF} +uses {$IFDEF UNIX} + cthreads, {$ENDIF} Classes, SysUtils, WebSocket, @@ -40,8 +40,7 @@ TSocketHandler = class(TThreadedWebsocketHandler) var str: string; begin - WriteLn('Connected to ', NetAddrToStr( - ACommunication.SocketStream.RemoteAddress.sin_addr)); + WriteLn('Connected to ', ACommunication.SocketStream.RemoteAddress.Address); ACommunication.OnRecieveMessage := @MessageRecieved; ACommunication.OnClose := @ConnectionClosed; while ACommunication.Open do @@ -55,8 +54,8 @@ TSocketHandler = class(TThreadedWebsocketHandler) finally Free; end; - WriteLn('Message sent to ', - NetAddrToStr(ACommunication.SocketStream.RemoteAddress.sin_addr), ': ', str); + WriteLn('Message to ', ACommunication.SocketStream.RemoteAddress.Address, + ': ', str); end; socket.Stop(True); end; @@ -66,8 +65,7 @@ TSocketHandler = class(TThreadedWebsocketHandler) Comm: TWebsocketCommunincator; begin Comm := TWebsocketCommunincator(Sender); - WriteLn('Connection to ', NetAddrToStr(Comm.SocketStream.RemoteAddress.sin_addr), - ' closed'); + WriteLn('Connection to ', Comm.SocketStream.RemoteAddress.Address, ' closed'); end; procedure TSocketHandler.MessageRecieved(Sender: TObject); @@ -83,8 +81,7 @@ TSocketHandler = class(TThreadedWebsocketHandler) for m in Messages do if m is TWebsocketStringMessage then begin - WriteLn('Message from ', - NetAddrToStr(Comm.SocketStream.RemoteAddress.sin_addr), + WriteLn('Message from ', Comm.SocketStream.RemoteAddress.Address, ': ', TWebsocketStringMessage(m).Data); end; finally diff --git a/websocket.pas b/websocket.pas index 4337c26..d731f41 100644 --- a/websocket.pas +++ b/websocket.pas @@ -8,6 +8,21 @@ interface Classes, SysUtils, ssockets, fgl, sha1, base64, utilities, Sockets; type + + { EWebsocketError } + + EWebsocketError = class(Exception) + private + FCode: integer; + public + constructor Create(const msg: string; ACode: integer); + property Code: integer read FCode; + end; + + EWebsocketWriteError = class(EWebsocketError); + + EWebsocketReadError = class(EWebsocketError); + { TRequestHeaders } TRequestHeaders = class(specialize TFPGMap<string, string>) @@ -73,11 +88,37 @@ TWebsocketMessageList = class(specialize TFPGList<TWebsocketMessage>); TWebsocketMessageOwnerList = class(specialize TFPGObjectList<TWebsocketMessage>); TLockedWebsocketMessageList = class(specialize TThreadedObject<TWebsocketMessageList>); + TNetAddress = record + Address: string; + Port: integer; + end; + + { TLockedSocketStream } + + TLockedSocketStream = class + private + FLocalAddress: TNetAddress; + FRemoteAddress: TNetAddress; + FStream: TSocketStream; + FLock: TRTLCriticalSection; + function isOpen: boolean; + public + constructor Create(const AStream: TSocketStream); + destructor Destroy; override; + + function Lock: TSocketStream; + procedure Unlock; + procedure CloseStream; + property Open: boolean read isOpen; + property RemoteAddress: TNetAddress read FRemoteAddress; + property LocalAddress: TNetAddress read FLocalAddress; + end; + { TWebsocketMessageStream } TWebsocketMessageStream = class(TStream) private - FDataStream: TSocketStream; + FDataStream: TLockedSocketStream; FMaxFrameSize: int64; FMessageType: TWebsocketMessageType; FBuffer: TBytes; @@ -87,7 +128,7 @@ TWebsocketMessageStream = class(TStream) procedure WriteDataFrame(Finished: boolean = False); public - constructor Create(const ADataStream: TSocketStream; + constructor Create(const ADataStream: TLockedSocketStream; AMessageType: TWebsocketMessageType = wmtString; AMaxFrameLen: int64 = 125; AMaskKey: integer = -1); destructor Destroy; override; @@ -100,17 +141,17 @@ TWebsocketMessageStream = class(TStream) TWebsocketCommunincator = class private - FStream: TSocketStream; + FStream: TLockedSocketStream; FMessages: TLockedWebsocketMessageList; FMaskMessages: boolean; FAssumeMaskedMessages: boolean; FOnRecieveMessage: TNotifyEvent; FOnClose: TNotifyEvent; - FOpen: boolean; FExpectClose: boolean; function GenerateMask: integer; + function GetOpen: boolean; public - constructor Create(AStream: TSocketStream; AMaskMessage: boolean; + constructor Create(AStream: TLockedSocketStream; AMaskMessage: boolean; AssumeMaskedMessages: boolean); destructor Destroy; override; @@ -125,8 +166,8 @@ TWebsocketCommunincator = class property OnRecieveMessage: TNotifyEvent read FOnRecieveMessage write FOnRecieveMessage; property OnClose: TNotifyEvent read FOnClose write FOnClose; - property SocketStream: TSocketStream read FStream; - property Open: boolean read FOpen; + property SocketStream: TLockedSocketStream read FStream; + property Open: boolean read GetOpen; end; { TWebsocketHandler } @@ -331,8 +372,8 @@ TAcceptingThread = class(TPoolableThread) HandlerThreadPool: TLockedHandlerThreadPool; AcceptingThreadPool: TLockedAcceptingThreadPool; -function CreateAcceptingThread(const AHandshakeHandler: TWebsocketHandshakeHandler): -TAcceptingThread; inline; +function CreateAcceptingThread( + const AHandshakeHandler: TWebsocketHandshakeHandler): TAcceptingThread; inline; var pool: TAcceptingThreadPool; begin @@ -362,8 +403,8 @@ function CreateHandlerThread(const ACommunicator: TWebsocketCommunincator; end; end; -function CreateRecieverThread( - const ACommunicator: TWebsocketCommunincator): TWebsocketRecieverThread; inline; +function CreateRecieverThread(const ACommunicator: TWebsocketCommunincator): +TWebsocketRecieverThread; inline; var pool: TRecieverThreadPool; begin @@ -406,6 +447,64 @@ function DoHeaderKeyCompare(const Key1, Key2: string): integer; Result := CompareStr(Key1.ToLower, Key2.ToLower); end; +{ EWebsocketError } + +constructor EWebsocketError.Create(const msg: string; ACode: integer); +begin + inherited Create(msg); + FCode := ACode; +end; + +{ TLockedSocketStream } + +function TLockedSocketStream.isOpen: boolean; +begin + Lock; + try + Result := Assigned(FStream); + finally + Unlock; + end; +end; + +constructor TLockedSocketStream.Create(const AStream: TSocketStream); +begin + FLocalAddress.Address := NetAddrToStr(AStream.LocalAddress.sin_addr); + FLocalAddress.Port := AStream.LocalAddress.sin_port; + FRemoteAddress.Address := NetAddrToStr(AStream.RemoteAddress.sin_addr); + FRemoteAddress.Port := AStream.LocalAddress.sin_port; + FStream := AStream; + InitCriticalSection(FLock); +end; + +destructor TLockedSocketStream.Destroy; +begin + CloseStream; + DoneCriticalsection(FLock); + inherited Destroy; +end; + +function TLockedSocketStream.Lock: TSocketStream; +begin + EnterCriticalsection(FLock); + Result := FStream; +end; + +procedure TLockedSocketStream.Unlock; +begin + LeaveCriticalsection(FLock); +end; + +procedure TLockedSocketStream.CloseStream; +begin + Lock; + try + FreeAndNil(FStream); + finally + Unlock; + end; +end; + procedure TAcceptingThread.DoExecute; begin FHandshakeHandler.PerformHandshake; @@ -438,6 +537,7 @@ procedure TWebsocketRecieverThread.DoExecute; while not Terminated and not FStopped and FCommunicator.Open do begin FCommunicator.RecieveMessage; + Sleep(1000); Yield; end; end; @@ -456,14 +556,18 @@ function TWebsocketCommunincator.GenerateMask: integer; Result := integer(Random(DWord.MaxValue)); end; -constructor TWebsocketCommunincator.Create(AStream: TSocketStream; +function TWebsocketCommunincator.GetOpen: boolean; +begin + Result := FStream.Open; +end; + +constructor TWebsocketCommunincator.Create(AStream: TLockedSocketStream; AMaskMessage: boolean; AssumeMaskedMessages: boolean); begin FStream := AStream; FMaskMessages := AMaskMessage; FAssumeMaskedMessages := AssumeMaskedMessages; FMessages := TLockedWebsocketMessageList.Create(TWebsocketMessageList.Create); - FOpen := True; FExpectClose := False; end; @@ -471,13 +575,14 @@ destructor TWebsocketCommunincator.Destroy; begin // Ending communication => Close stream Close(True); + FStream.Free; FMessages.Free; inherited Destroy; end; procedure TWebsocketCommunincator.Close(ForceClose: boolean); begin - if not FOpen then + if not Open then Exit; if not ForceClose then begin @@ -485,14 +590,68 @@ procedure TWebsocketCommunincator.Close(ForceClose: boolean); FExpectClose := True; Exit; end; - FOpen := False; if Assigned(FOnClose) then FOnClose(Self); - FStream.Free; + FStream.CloseStream; end; procedure TWebsocketCommunincator.RecieveMessage; + procedure ReadData(var buffer; const len: int64); + var + ToRead: longint; + Read: longint; + LeftToRead: int64; + TotalRead: int64; + oldTO: integer; + Stream: TSocketStream; + const + IOTimeoutError = 11; + WaitingTime = 10; + begin + TotalRead := 0; + repeat + // how much we are trying to read at a time + LeftToRead := len - TotalRead; + if LeftToRead > ToRead.MaxValue then + ToRead := ToRead.MaxValue + else + ToRead := LeftToRead; + // Reading + + Stream := FStream.Lock; + try + if not Assigned(Stream) then + begin + raise EWebsocketReadError.Create('Socket already closed', 0); + end; + oldTO := Stream.IOTimeout; + Stream.IOTimeout := 1; + try + Read := Stream.Read(PByte(@buffer)[TotalRead], ToRead); + if Read < 0 then + begin + // on Error + if Stream.LastError <> IOTimeoutError then + raise EWebsocketReadError.Create('error reading from stream', + Stream.LastError); + end + else + begin + // Increase the amount to read + TotalRead += Read; + end; + finally + Stream.IOTimeout := oldTO; + end; + finally + FStream.Unlock; + end; + if (TotalRead < len) and (Read <> ToRead) then // not finished, wait for some data + Sleep(WaitingTime); + until TotalRead >= len; + end; + procedure AddMessageToList(Message: TWebsocketMessage); var lst: TWebsocketMessageList; @@ -512,8 +671,49 @@ procedure TWebsocketCommunincator.RecieveMessage; end; end; + function ProcessSpecialMessages(messageType: TWebsocketMessageType; + var buffer; const buffLen: int64): boolean; + var + str: UTF8String; + begin + Result := True; + case messageType of + wmtClose: + begin + // If we didn't send the original close, return the message + if not FExpectClose then + WriteMessage(wmtClose).Free; + // Close the stream (true to not send a message + Close(True); + end; + wmtPing: + begin + // On ping send pong, with same content + with WriteMessage(wmtPong) do + try + if buffLen > 0 then + Write(PByte(@buffer)[0], buffLen); + finally + Free; + end; + end; + wmtPong: + begin + // lift pong message to message queue, so user can handle it + SetLength(str, buffLen); + if buffLen > 0 then + Move(buffer, str[1], buffLen); + AddMessageToList(TWebsocketPongMessage.Create(str)); + end + else + Result := False; + end; + end; + var Header: TWebsocketFrameHeader; + len64: int64; + len16: word; len: int64; MaskRec: TMaskRec; buffer: TBytes; @@ -522,6 +722,7 @@ procedure TWebsocketCommunincator.RecieveMessage; outputStream: TMemoryStream; messageType: TWebsocketMessageType; str: UTF8String; + w: word; begin Message := nil; outputStream := TMemoryStream.Create; @@ -530,18 +731,25 @@ procedure TWebsocketCommunincator.RecieveMessage; repeat if not Open then Exit; - Header := WordToFrameHeader(FStream.ReadWord); + ReadData(w, 2); + Header := WordToFrameHeader(w); if Header.OPCode <> wmtContinue then messageType := TWebsocketMessageType(Header.OPCode); if Header.PayloadLen < 126 then len := Header.PayloadLen else if Header.PayloadLen = 126 then - len := NToHs(FStream.ReadWord) + begin + ReadData(len16, SizeOf(len16)); + len := NToHs(len16); + end else - len := ntohll(FStream.ReadQWord); + begin + ReadData(len64, SizeOf(len64)); + len := ntohll(len64); + end; if Header.Mask then begin - MaskRec.Key := integer(FStream.ReadDWord); + ReadData(MaskRec.Key, SizeOf(MaskRec.Key)); end else if FAssumeMaskedMessages then begin @@ -552,7 +760,7 @@ procedure TWebsocketCommunincator.RecieveMessage; SetLength(buffer, len); if len > 0 then begin - FStream.ReadBuffer(buffer[0], len); + ReadData(buffer[0], len); if Header.Mask then begin // As this is 64 bit, to be 32 bit compatible we can't use a for loop @@ -565,40 +773,13 @@ procedure TWebsocketCommunincator.RecieveMessage; end; end; // Handling special messages - case messageType of - wmtClose: - begin - // If we didn't send the original close, return the message - if not FExpectClose then - WriteMessage(wmtClose).Free; - // Close the stream (true to not send a message - Close(True); - end; - wmtPing: - begin - // On ping send pong, with same content - with WriteMessage(wmtPong) do - try - if len > 0 then - Write(buffer[0], len); - finally - Free; - end; - end; - wmtPong: - begin - // lift pong message to message queue, so user can handle it - SetLength(str, len); - if len > 0 then - Move(buffer[0], str[1], len); - AddMessageToList(TWebsocketPongMessage.Create(str)); - end; - else - begin - // This is a dataframe, so save data for concatination of fragments - if len > 0 then - outputStream.WriteBuffer(buffer[0], len); - end; + if ProcessSpecialMessages(messageType, PByte(buffer)^, len) then + Continue + else + begin + // This is a dataframe, so save data for concatination of fragments + if len > 0 then + outputStream.WriteBuffer(buffer[0], len); end; until Header.Fin; // Read whole message @@ -622,11 +803,13 @@ procedure TWebsocketCommunincator.RecieveMessage; outputStream.Free; end; except - On e: EReadError do + On e: EWebsocketReadError do begin - // Stream has been closed - // FIXME: Some way to verify that? - Close(True); + if e.Code = 0 then + begin + // Stream has been closed + Close(True); + end; end; end; end; @@ -662,52 +845,67 @@ procedure TWebsocketMessageStream.WriteDataFrame(Finished: boolean); Header: TWebsocketFrameHeader; i: int64; MaskRec: TMaskRec; + Stream: TSocketStream; begin - Header.Fin := Finished; - Header.Mask := (FMaskKey <> -1); - if FFirstWrite then - Header.OPCode := FMessageType - else - Header.OPCode := wmtContinue; - // Compute size - if FCurrentLen < 126 then - Header.PayloadLen := FCurrentLen - else if FCurrentLen <= word.MaxValue then - Header.PayloadLen := 126 - else - Header.PayloadLen := 127; - // Write header - FDataStream.WriteWord(FrameHEaderToWord(Header)); - // Write size if it exceeds 125 - if (FCurrentLen > 125) then - begin - if (FCurrentLen <= word.MaxValue) then - FDataStream.WriteWord(htons(word(FCurrentLen))) - else - FDataStream.WriteQWord(htonll(QWord(FCurrentLen))); - end; - if Header.Mask then - begin - // If we use a mask - MaskRec.Key := FMaskKey; - // First: Transmit mask Key - FDataStream.WriteBuffer(MaskRec.Bytes[0], 4); - // 2. Encode Message - // As this is 64 bit, to be 32 bit compatible we can't use a for loop - i := 0; - while i < FCurrentLen do + Stream := FDataStream.Lock; + try + if not Assigned(Stream) then begin - FBuffer[i] := FBuffer[i] xor MaskRec.Bytes[i mod 4]; - Inc(i); + raise EWebsocketWriteError.Create('Stream already closed', 0); + end; + try + Header.Fin := Finished; + Header.Mask := (FMaskKey <> -1); + if FFirstWrite then + Header.OPCode := FMessageType + else + Header.OPCode := wmtContinue; + // Compute size + if FCurrentLen < 126 then + Header.PayloadLen := FCurrentLen + else if FCurrentLen <= word.MaxValue then + Header.PayloadLen := 126 + else + Header.PayloadLen := 127; + // Write header + Stream.WriteWord(FrameHEaderToWord(Header)); + // Write size if it exceeds 125 + if (FCurrentLen > 125) then + begin + if (FCurrentLen <= word.MaxValue) then + Stream.WriteWord(htons(word(FCurrentLen))) + else + Stream.WriteQWord(htonll(QWord(FCurrentLen))); + end; + if Header.Mask then + begin + // If we use a mask + MaskRec.Key := FMaskKey; + // First: Transmit mask Key + Stream.WriteBuffer(MaskRec.Bytes[0], 4); + // 2. Encode Message + // As this is 64 bit, to be 32 bit compatible we can't use a for loop + i := 0; + while i < FCurrentLen do + begin + FBuffer[i] := FBuffer[i] xor MaskRec.Bytes[i mod 4]; + Inc(i); + end; + end; + // Write Message payload + Stream.WriteBuffer(FBuffer[0], FCurrentLen); + // Reset state for next data + FCurrentLen := 0; + except + on E: EWriteError do + raise EWebsocketWriteError.Create(e.Message, Stream.LastError); end; + finally + FDataStream.Unlock; end; - // Write Message payload - FDataStream.WriteBuffer(FBuffer[0], FCurrentLen); - // Reset state for next data - FCurrentLen := 0; end; -constructor TWebsocketMessageStream.Create(const ADataStream: TSocketStream; +constructor TWebsocketMessageStream.Create(const ADataStream: TLockedSocketStream; AMessageType: TWebsocketMessageType; AMaxFrameLen: int64; AMaskKey: integer); begin FDataStream := ADataStream; @@ -1042,7 +1240,8 @@ procedure TWebsocketHandshakeHandler.PerformHandshake; finally RequestData.Headers.Free; end; - Comm := TWebsocketCommunincator.Create(FStream, False, True); + Comm := TWebsocketCommunincator.Create(TLockedSocketStream.Create(FStream), + False, True); finally // Not needed anymore, we can now die in piece. // All information requier for the rest is now on the stack