Skip to content

Commit 7a116c3

Browse files
committed
Crypto: Add nonce generation from pysrtp
1 parent a81a9ba commit 7a116c3

File tree

4 files changed

+58
-52
lines changed

4 files changed

+58
-52
lines changed

tests/test_crypto.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44

55
def test_decrypt(test_data: dict, crypto_context: srtp_crypto.SrtpContext):
66
rtp_packet_raw = test_data['rtp_connection_probing.bin']
7-
rtp_header, rtp_body = rtp_packet_raw[:12], rtp_packet_raw[12:]
87

9-
plaintext = crypto_context.decrypt(rtp_body, aad=rtp_header)
8+
plaintext = crypto_context.decrypt_packet(rtp_packet_raw)
109
with pytest.raises(Exception):
1110
# Skip 1 byte of "additional data" to ensure invalid data
12-
crypto_context.decrypt(rtp_body, aad=rtp_header[1:])
11+
crypto_context.decrypt(rtp_packet_raw[:-1])
1312

1413
assert plaintext is not None
1514

xcloud/protocol/srtp_crypto.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
1111
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
1212

13+
from . import utils
14+
1315
class TransformDirection(Enum):
1416
Encrypt = 0
1517
Decrypt = 1
@@ -110,6 +112,9 @@ def __init__(self, master_keys: SrtpMasterKeys):
110112
"""
111113
MS-SRTP context
112114
"""
115+
self.roc = 0
116+
self.seq = 0
117+
113118
self.master_keys = master_keys
114119
self.session_keys = SrtpContext._derive_session_keys(
115120
self.master_keys.master_key, self.master_keys.master_salt
@@ -149,19 +154,13 @@ def _derive_single_key(master_key, master_salt, key_index: int = 0, max_bytes: i
149154
import binascii
150155
'''SRTP key derivation, https://tools.ietf.org/html/rfc3711#section-4.3'''
151156

152-
def bytes_to_int(b):
153-
return int.from_bytes(b, byteorder='big')
154-
155-
def int_to_bytes(i, n_bytes):
156-
return i.to_bytes(n_bytes, byteorder='big')
157-
158157
assert len(master_key) == 128 // 8
159158
assert len(master_salt) == 112 // 8
160-
salt = bytes_to_int(master_salt)
159+
salt = utils.bytes_to_int(master_salt)
161160

162161
DIV = lambda x, y: 0 if y == 0 else x // y
163162
prng = lambda iv: SrtpContext._crypt_ctr_oneshot(
164-
master_key, int_to_bytes(iv, 16), b'\x00' * 16, max_bytes=max_bytes
163+
master_key, utils.int_to_bytes(iv, 16), b'\x00' * 16, max_bytes=max_bytes
165164
)
166165
r = DIV(pkt_i, key_derivation_rate) # pkt_i is always 48 bits
167166
derive_key_from_label = lambda label: prng(
@@ -189,19 +188,26 @@ def _decrypt(ctx: AESGCM, nonce: bytes, data: bytes, aad: bytes) -> bytes:
189188
def _encrypt(ctx: AESGCM, nonce: bytes, data: bytes, aad: bytes) -> bytes:
190189
return ctx.encrypt(nonce, data, aad)
191190

192-
def _get_transformed_nonce(self, transform_direction: TransformDirection) -> bytes:
193-
# Skip first 2 bytes of Nonce key
194-
nonce = bytearray(self.session_keys.salt_key[2:])
195-
# TODO: Implement transform logic
196-
# FIXME: Just tranforming the Nonce to a known value for
197-
# our single test packet
198-
nonce[-1] += 1
199-
return nonce
200-
201-
def decrypt(self, data: bytes, aad: bytes) -> bytes:
202-
nonce = self._get_transformed_nonce(TransformDirection.Decrypt)
203-
return SrtpContext._decrypt(self.decryptor_ctx, nonce, data, aad)
191+
@staticmethod
192+
def packet_index(roc, seq):
193+
return seq + (roc << 16)
194+
195+
@staticmethod
196+
def _calc_iv(salt, ssrc, pkt_i):
197+
salt = utils.bytes_to_int(salt)
198+
iv = ((ssrc << (48)) + pkt_i) ^ salt
199+
return utils.int_to_bytes(iv, 12)
204200

205-
def encrypt(self, data: bytes, aad: bytes) -> RtpPacket:
206-
nonce = self._get_transformed_nonce(TransformDirection.Encrypt)
207-
return SrtpContext._encrypt(self.decryptor_ctx, nonce, data, aad)
201+
def decrypt_packet(self, rtp_packet: bytes) -> RtpPacket:
202+
rtp_header = rtp_packet[:12]
203+
parsed = RtpPacket.parse(rtp_packet)
204+
205+
if parsed.sequence_number < self.seq:
206+
self.roc += 1
207+
self.seq = parsed.sequence_number
208+
pkt_i = SrtpContext.packet_index(self.roc, self.seq)
209+
iv = SrtpContext._calc_iv(self.session_keys.salt_key[2:], parsed.ssrc, pkt_i)
210+
211+
decrypted_payload = SrtpContext._decrypt(self.decryptor_ctx, iv, parsed.payload, rtp_header)
212+
parsed.payload = decrypted_payload
213+
return parsed

xcloud/protocol/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def bytes_to_int(b):
2+
return int.from_bytes(b, byteorder='big')
3+
4+
def int_to_bytes(i, n_bytes):
5+
return i.to_bytes(n_bytes, byteorder='big')

xcloud/scripts/pcap_reader.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class XcloudPcapParser:
2323
def __init__(self, srtp_key: Optional[str]):
2424
self.crypto: Optional[srtp_crypto.SrtpContext] = None
2525
if srtp_key:
26-
self.crypto = srtp_crypto.SrtpContext.from_base64(srtp_key)
26+
self.outgoing_crypto = srtp_crypto.SrtpContext.from_base64(srtp_key)
27+
self.incoming_crypto = srtp_crypto.SrtpContext.from_base64(srtp_key)
2728

2829
@property
2930
def PACKET_TYPES(self):
@@ -33,35 +34,30 @@ def PACKET_TYPES(self):
3334
(teredo.TeredoPacket.parse, self.get_info_teredo)
3435
]
3536

36-
def get_info_stun(self, stun: stun.Message) -> None:
37+
def get_info_stun(self, stun: stun.Message, is_client: bool) -> None:
3738
return f'STUN: {stun}'
3839

39-
def brute_force_nonce(self, nonce_orig: bytes) -> Generator:
40-
for byte1 in range(0, 0xFF):
41-
for byte2 in range(0, 0xFF):
42-
nonce_transform = b''.join([nonce_orig[:5], struct.pack('!B', byte1), nonce_orig[6:11], struct.pack('!B', byte2)])
43-
yield nonce_transform
44-
45-
def get_info_rtp(self, rtp: rtp.RtpPacket) -> None:
40+
def get_info_rtp(self, rtp: rtp.RtpPacket, is_client: bool) -> None:
4641
try:
4742
payload_name = packets.PayloadType(rtp.payload_type)
4843
except:
4944
payload_name = '<UNKNOWN>'
5045

51-
info_str = f'RTP: {payload_name.name} {rtp} SSRC={rtp.ssrc}'
52-
if self.crypto:
53-
rtp_packet_serialized = rtp.serialize()
54-
rtp_header, rtp_data = rtp_packet_serialized[:12], rtp_packet_serialized[12:]
55-
nonce_orig = self.crypto.session_keys.nonce_key[2:]
56-
for nonce_transformed in self.brute_force_nonce(nonce_orig):
57-
try:
58-
decrypted = self.crypto._decrypt(self.crypto.decryptor_ctx, nonce_transformed, rtp_data, rtp_header)
59-
info_str += "\n" + hexdump(decrypted, result='return') + "\n"
60-
except Exception:
61-
pass
46+
direction = 'OUT -> ' if is_client else '<- IN '
47+
info_str = f'{direction} RTP: {payload_name.name} {rtp} SSRC={rtp.ssrc}'
48+
if self.incoming_crypto and self.outgoing_crypto:
49+
rtp_packet = rtp.serialize()
50+
try:
51+
if is_client:
52+
rtp_decrypted = self.outgoing_crypto.decrypt_packet(rtp_packet)
53+
else:
54+
rtp_decrypted = self.incoming_crypto.decrypt_packet(rtp_packet)
55+
info_str += "\n" + hexdump(rtp_decrypted.payload, result='return') + "\n"
56+
except Exception:
57+
info_str += "\n DECRYPTION FAILED \n"
6258
return info_str
6359

64-
def get_info_teredo(self, teredo: teredo.TeredoPacket) -> None:
60+
def get_info_teredo(self, teredo: teredo.TeredoPacket, is_client: bool) -> None:
6561
info = f'TEREDO: {teredo}'
6662
if teredo.ipv6.next_header != ipv6.NO_NEXT_HEADER:
6763
data = teredo.ipv6.data
@@ -77,7 +73,8 @@ def get_info_general(self, packet: Any) -> Optional[str]:
7773
for cls, info_func in self.PACKET_TYPES:
7874
try:
7975
instance = cls(data)
80-
info = info_func(instance)
76+
is_client = (packet.dport == 54881)
77+
info = info_func(instance, is_client)
8178
return info
8279
except:
8380
pass
@@ -96,16 +93,15 @@ def packet_filter(self, filepath):
9693
continue
9794

9895
ip = eth.data
99-
subpacket = ip.data
100-
if not isinstance(subpacket, dpkt.udp.UDP):
96+
if not isinstance(ip.data, dpkt.udp.UDP):
10197
continue
10298

103-
yield(subpacket, ts)
99+
yield(ip, ts)
104100

105101

106102
def parse_file(self, pcap_filepath: str) -> None:
107103
for packet, timestamp in self.packet_filter(pcap_filepath):
108-
info = self.get_info_general(packet)
104+
info = self.get_info_general(packet.data)
109105
if info:
110106
print(info)
111107

0 commit comments

Comments
 (0)