|
10 | 10 | from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes |
11 | 11 | from cryptography.hazmat.primitives.ciphers.aead import AESGCM |
12 | 12 |
|
| 13 | +from . import utils |
| 14 | + |
13 | 15 | class TransformDirection(Enum): |
14 | 16 | Encrypt = 0 |
15 | 17 | Decrypt = 1 |
@@ -110,6 +112,9 @@ def __init__(self, master_keys: SrtpMasterKeys): |
110 | 112 | """ |
111 | 113 | MS-SRTP context |
112 | 114 | """ |
| 115 | + self.roc = 0 |
| 116 | + self.seq = 0 |
| 117 | + |
113 | 118 | self.master_keys = master_keys |
114 | 119 | self.session_keys = SrtpContext._derive_session_keys( |
115 | 120 | self.master_keys.master_key, self.master_keys.master_salt |
@@ -149,19 +154,13 @@ def _derive_single_key(master_key, master_salt, key_index: int = 0, max_bytes: i |
149 | 154 | import binascii |
150 | 155 | '''SRTP key derivation, https://tools.ietf.org/html/rfc3711#section-4.3''' |
151 | 156 |
|
152 | | - def bytes_to_int(b): |
153 | | - return int.from_bytes(b, byteorder='big') |
154 | | - |
155 | | - def int_to_bytes(i, n_bytes): |
156 | | - return i.to_bytes(n_bytes, byteorder='big') |
157 | | - |
158 | 157 | assert len(master_key) == 128 // 8 |
159 | 158 | assert len(master_salt) == 112 // 8 |
160 | | - salt = bytes_to_int(master_salt) |
| 159 | + salt = utils.bytes_to_int(master_salt) |
161 | 160 |
|
162 | 161 | DIV = lambda x, y: 0 if y == 0 else x // y |
163 | 162 | prng = lambda iv: SrtpContext._crypt_ctr_oneshot( |
164 | | - master_key, int_to_bytes(iv, 16), b'\x00' * 16, max_bytes=max_bytes |
| 163 | + master_key, utils.int_to_bytes(iv, 16), b'\x00' * 16, max_bytes=max_bytes |
165 | 164 | ) |
166 | 165 | r = DIV(pkt_i, key_derivation_rate) # pkt_i is always 48 bits |
167 | 166 | derive_key_from_label = lambda label: prng( |
@@ -189,19 +188,26 @@ def _decrypt(ctx: AESGCM, nonce: bytes, data: bytes, aad: bytes) -> bytes: |
189 | 188 | def _encrypt(ctx: AESGCM, nonce: bytes, data: bytes, aad: bytes) -> bytes: |
190 | 189 | return ctx.encrypt(nonce, data, aad) |
191 | 190 |
|
192 | | - def _get_transformed_nonce(self, transform_direction: TransformDirection) -> bytes: |
193 | | - # Skip first 2 bytes of Nonce key |
194 | | - nonce = bytearray(self.session_keys.salt_key[2:]) |
195 | | - # TODO: Implement transform logic |
196 | | - # FIXME: Just tranforming the Nonce to a known value for |
197 | | - # our single test packet |
198 | | - nonce[-1] += 1 |
199 | | - return nonce |
200 | | - |
201 | | - def decrypt(self, data: bytes, aad: bytes) -> bytes: |
202 | | - nonce = self._get_transformed_nonce(TransformDirection.Decrypt) |
203 | | - return SrtpContext._decrypt(self.decryptor_ctx, nonce, data, aad) |
| 191 | + @staticmethod |
| 192 | + def packet_index(roc, seq): |
| 193 | + return seq + (roc << 16) |
| 194 | + |
| 195 | + @staticmethod |
| 196 | + def _calc_iv(salt, ssrc, pkt_i): |
| 197 | + salt = utils.bytes_to_int(salt) |
| 198 | + iv = ((ssrc << (48)) + pkt_i) ^ salt |
| 199 | + return utils.int_to_bytes(iv, 12) |
204 | 200 |
|
205 | | - def encrypt(self, data: bytes, aad: bytes) -> RtpPacket: |
206 | | - nonce = self._get_transformed_nonce(TransformDirection.Encrypt) |
207 | | - return SrtpContext._encrypt(self.decryptor_ctx, nonce, data, aad) |
| 201 | + def decrypt_packet(self, rtp_packet: bytes) -> RtpPacket: |
| 202 | + rtp_header = rtp_packet[:12] |
| 203 | + parsed = RtpPacket.parse(rtp_packet) |
| 204 | + |
| 205 | + if parsed.sequence_number < self.seq: |
| 206 | + self.roc += 1 |
| 207 | + self.seq = parsed.sequence_number |
| 208 | + pkt_i = SrtpContext.packet_index(self.roc, self.seq) |
| 209 | + iv = SrtpContext._calc_iv(self.session_keys.salt_key[2:], parsed.ssrc, pkt_i) |
| 210 | + |
| 211 | + decrypted_payload = SrtpContext._decrypt(self.decryptor_ctx, iv, parsed.payload, rtp_header) |
| 212 | + parsed.payload = decrypted_payload |
| 213 | + return parsed |
0 commit comments