|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import enum |
| 4 | +import struct |
| 5 | +from typing import Final, Iterable, NamedTuple, Sequence |
| 6 | + |
| 7 | + |
def generate():
    """Print every dictionary line, in order, to standard output.

    The table below mixes comments, blank separator lines and named
    byte-string entries; each item is passed to ``emit`` for formatting.
    """
    sample_oid = OID((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12))
    table: list[LineItem] = [
        Comment("This file is GENERATED! DO NOT MODIFY!"),
        Comment("Instead, modify the content of make-dicts.py"),
        Line(),
        # --- Raw values that are not complete elements ---
        Comment("Random values"),
        Entry("int32_1729", encode_value(1729)),
        Entry("int64_1729", struct.pack("<q", 1729)),
        Entry("prefixed_string", make_string("prefixed-string")),
        Entry("empty_obj", wrap_obj(b"")),
        Entry("c_string", make_cstring("null-terminated-string")),
        Line(),
        # --- One complete element per supported value type ---
        Comment("Elements"),
        Entry("null_elem", element("N", None)),
        Entry("undef_elem", element("U", Undefined)),
        Entry("string_elem", element("S", "string-value")),
        Entry("empty_bin_elem", element("Bg", Binary(0, b""))),
        Entry("empty_regex_elem", element("Rx0", Regex("", ""))),
        Entry("simple_regex_elem", element("Rx1", Regex("foo", "ig"))),
        Entry("encrypted_bin_elem", element("Be", Binary(6, b"meow"))),
        Entry("empty_obj_elem", element("Obj0", Doc())),
        Entry(
            "code_w_s_elem",
            element("Clz", CodeWithScope("void 0;", Doc([Elem("foo", "bar")]))),
        ),
        Entry("code_elem", element("Js", Code("() => 0;"))),
        Entry("symbol_elem", element("Sym", Symbol("symbol"))),
        Entry("oid_elem", element("OID", sample_oid)),
        Entry("dbpointer_elem", element("dbp", DBPointer(String("db"), sample_oid))),
        Line(),
        # --- Inputs with a nul byte in the middle of text ---
        Comment("Embedded nul"),
        Comment("This string contains an embedded null, which is abnormal but valid"),
        Entry("string_with_null", element("S0", "string\0value")),
        Comment("This regex has an embedded null, which is invalid"),
        Entry("bad_regex_elem", element("RxB", Regex("f\0oo", "ig"))),
        Comment("This element's key contains an embedded null, which is invalid"),
        Entry("bad_key_elem", element("foo\0bar", "string")),
        Line(),
        # --- Whole objects, including one deliberately truncated ---
        Comment("Objects"),
        Entry("obj_with_string", wrap_obj(element("single-elem", "foo"))),
        Entry("obj_with_null", wrap_obj(element("null", None))),
        # Slicing off the last byte drops the object's trailing nul.
        Entry("obj_missing_term", wrap_obj(b"")[:-1]),
    ]

    for line_item in table:
        emit(line_item)
| 55 | + |
| 56 | + |
# Recursive alias: either raw bytes, or an iterable whose members are
# themselves BytesIter values. flatten() collapses one into plain bytes.
BytesIter = bytes | Iterable["BytesIter"]
| 58 | + |
| 59 | + |
def flatten(b: BytesIter) -> bytes:
    """Collapse an arbitrarily nested iterable of ``bytes`` into one ``bytes``.

    A ``bytes`` argument is returned unchanged; anything else is iterated and
    each member is flattened recursively before the pieces are concatenated.
    """
    if not isinstance(b, bytes):
        return b"".join(flatten(part) for part in b)
    return b
| 65 | + |
| 66 | + |
def len_prefix(b: BytesIter) -> bytes:
    """Return the flattened input preceded by its byte length as a little-endian int32."""
    payload = flatten(b)
    return encode_value(len(payload)) + payload
| 72 | + |
| 73 | + |
def make_cstring(s: str) -> bytes:
    """Return the UTF-8 encoding of ``s`` followed by a single nul byte."""
    encoded = s.encode("utf-8")
    return encoded + b"\x00"
| 77 | + |
| 78 | + |
def make_string(s: str) -> bytes:
    """Build a length-prefixed string: int32 size, UTF-8 text, nul terminator.

    The size prefix counts the trailing nul as well as the text.
    """
    terminated = make_cstring(s)
    return len_prefix(terminated)
| 82 | + |
| 83 | + |
def wrap_obj(items: BytesIter) -> bytes:
    """Frame already-encoded bytes as a BSON object.

    The result is an int32 total size, the flattened ``items``, and a
    closing nul byte.
    """
    body = flatten(items)
    # The size covers the 4-byte prefix itself and the 1-byte terminator.
    total_size = 5 + len(body)
    return b"".join((encode_value(total_size), body, b"\x00"))
| 89 | + |
| 90 | + |
class UndefinedType:
    """Marker type for the "undefined" value, whose encoded payload is empty."""

    def __bytes__(self) -> bytes:
        """Return the (empty) payload."""
        return bytes()
| 94 | + |
| 95 | + |
class Binary(NamedTuple):
    """A binary value: a one-byte subtype tag plus the raw data."""

    tag: int
    data: bytes

    def __bytes__(self) -> bytes:
        """Encode as int32 data length, then the tag byte, then the data."""
        size = encode_value(len(self.data))
        subtype = bytes((self.tag,))
        return b"".join((size, subtype, self.data))
| 102 | + |
| 103 | + |
class OID(NamedTuple):
    """An object ID given as a tuple of twelve integer octets."""

    octets: tuple[int, int, int, int, int, int, int, int, int, int, int, int]

    def __bytes__(self) -> bytes:
        """Return the octets, in order, as a ``bytes`` object."""
        (octets,) = self
        return bytes(octets)
| 109 | + |
| 110 | + |
class DBPointer(NamedTuple):
    """A DBPointer value: a string followed by an object ID."""

    db: String
    oid: OID

    def __bytes__(self) -> bytes:
        """Concatenate the encoded ``db`` string and the encoded ``oid``."""
        return b"".join(part.__bytes__() for part in (self.db, self.oid))
| 117 | + |
| 118 | + |
class Regex(NamedTuple):
    """A regular expression value: pattern text plus option letters."""

    rx: str
    opts: str

    def __bytes__(self) -> bytes:
        """Encode as two nul-terminated strings: the pattern, then the options."""
        pattern, options = self
        return b"".join(map(make_cstring, (pattern, options)))
| 125 | + |
| 126 | + |
class Elem(NamedTuple):
    """A key/value pair that encodes itself as one document element."""

    key: str
    val: ValueType

    def __bytes__(self) -> bytes:
        """Delegate to ``element`` so the type tag is inferred from ``val``."""
        key, val = self
        return element(key, val)
| 133 | + |
| 134 | + |
class Doc(NamedTuple):
    """A document made of zero or more ``Elem`` entries (empty by default)."""

    items: Sequence[Elem] = ()

    def __bytes__(self) -> bytes:
        """Encode every element in order and frame the result with ``wrap_obj``."""
        encoded_elems = [entry.__bytes__() for entry in self.items]
        return wrap_obj(encoded_elems)
| 140 | + |
| 141 | + |
class String(NamedTuple):
    """Explicit wrapper around text that encodes as a length-prefixed string."""

    s: str

    def __bytes__(self) -> bytes:
        """Return the length-prefixed, nul-terminated encoding of the text."""
        (text,) = self
        return make_string(text)
| 147 | + |
| 148 | + |
class Symbol(String):
    """A symbol value: encoded exactly like ``String`` but given its own type tag."""
| 151 | + |
| 152 | + |
class Code(String):
    """A code value: encoded exactly like ``String`` but given its own type tag."""
| 155 | + |
| 156 | + |
class CodeWithScope(NamedTuple):
    """Code text paired with a scope document."""

    code: str
    scope: Doc

    def __bytes__(self) -> bytes:
        """Encode as int32 total size, the code string, then the scope document."""
        code_bytes = make_string(self.code)
        scope_bytes = self.scope.__bytes__()
        # The leading int32 counts its own 4 bytes plus both payloads.
        total = 4 + len(code_bytes) + len(scope_bytes)
        return b"".join((encode_value(total), code_bytes, scope_bytes))
| 166 | + |
| 167 | + |
# Shared instance used to request an "undefined" element.
Undefined: Final = UndefinedType()
# Union of the Python types handled by encode_value() and element().
ValueType = (
    int
    | str
    | float
    | None
    | UndefinedType
    | Binary
    | Regex
    | Doc
    | CodeWithScope
    | String
    | Symbol
    | Code
    | OID
    | DBPointer
)
| 185 | + |
| 186 | + |
| 187 | +def encode_value(val: ValueType) -> bytes: |
| 188 | + match val: |
| 189 | + case int(n): |
| 190 | + return struct.pack("<i", n) |
| 191 | + case str(s): |
| 192 | + return make_string(s) |
| 193 | + case float(f): |
| 194 | + return struct.pack("<d", f) |
| 195 | + case ( |
| 196 | + Doc() |
| 197 | + | Binary() |
| 198 | + | Regex() |
| 199 | + | UndefinedType() |
| 200 | + | CodeWithScope() |
| 201 | + | Code() |
| 202 | + | String() |
| 203 | + | Symbol() |
| 204 | + | DBPointer() |
| 205 | + | OID() |
| 206 | + ) as d: |
| 207 | + return d.__bytes__() |
| 208 | + case None: |
| 209 | + return b"" |
| 210 | + |
| 211 | + |
class Tag(enum.Enum):
    """One-byte type tags written at the start of an encoded element."""

    EOD = 0x00
    Double = 0x01
    UTF8 = 0x02
    Doc = 0x03
    Binary = 0x05
    Undefined = 0x06
    OID = 0x07
    Null = 0x0A
    Regex = 0x0B
    DBPointer = 0x0C
    Code = 0x0D
    Symbol = 0x0E
    CodeWithScope = 0x0F
    Int32 = 0x10
    Int64 = 0x12
| 228 | + |
| 229 | + |
| 230 | +def element(key: str, value: ValueType, *, type: None | Tag = None) -> bytes: |
| 231 | + if type is not None: |
| 232 | + return flatten([bytes([type.value]), make_cstring(key), encode_value(value)]) |
| 233 | + |
| 234 | + match value: |
| 235 | + case int(): |
| 236 | + type = Tag.Int32 |
| 237 | + |
| 238 | + case float(): |
| 239 | + type = Tag.Double |
| 240 | + case None: |
| 241 | + type = Tag.Null |
| 242 | + case UndefinedType(): |
| 243 | + type = Tag.Undefined |
| 244 | + case Binary(): |
| 245 | + type = Tag.Binary |
| 246 | + case Regex(): |
| 247 | + type = Tag.Regex |
| 248 | + case Doc(): |
| 249 | + type = Tag.Doc |
| 250 | + case CodeWithScope(): |
| 251 | + type = Tag.CodeWithScope |
| 252 | + case Code(): |
| 253 | + type = Tag.Code |
| 254 | + case Symbol(): |
| 255 | + type = Tag.Symbol |
| 256 | + case str() | String(): # Must appear after Code()/Symbol() |
| 257 | + type = Tag.UTF8 |
| 258 | + case OID(): |
| 259 | + type = Tag.OID |
| 260 | + case DBPointer(): |
| 261 | + type = Tag.DBPointer |
| 262 | + return element(key, value, type=type) |
| 263 | + |
| 264 | + |
class Entry(NamedTuple):
    """A named byte string; ``emit`` prints it as ``key="escaped value"``."""

    key: str
    "The key for the entry. Only for human readability"
    value: bytes
    "The arbitrary bytes that make up the entry"
| 270 | + |
| 271 | + |
class Comment(NamedTuple):
    """A comment line; ``emit`` prints it prefixed with ``# ``."""

    txt: str
| 274 | + |
| 275 | + |
class Line(NamedTuple):
    """A line of text that ``emit`` prints verbatim; blank by default."""

    txt: str = ""
| 278 | + |
| 279 | + |
# Anything that emit() knows how to print as one output line.
LineItem = Entry | Comment | Line
| 281 | + |
| 282 | + |
def escape(b: bytes) -> Iterable[str]:
    """Yield one text fragment per input byte, hex-escaping where needed.

    Printable ASCII bytes are yielded as the character itself. Every other
    byte is yielded as a ``\\xNN`` escape with two lowercase hex digits. The
    backslash byte is also escaped (as ``\\x5c``): left as-is, a literal
    backslash in the data would be indistinguishable from the start of one of
    the escapes this function produces.

    Args:
        b: The raw bytes to render.

    Yields:
        A one-character string or a four-character ``\\xNN`` escape per byte.
    """
    for u8 in b:
        ch = chr(u8)  # 0 <= u8 <= 255
        if ch.isascii() and ch.isprintable() and ch != "\\":
            yield ch
        else:
            # Not ASCII, not printable, or the escape character itself.
            yield f"\\x{u8:02x}"
| 292 | + |
| 293 | + |
def emit(item: LineItem):
    """Print one output line for ``item``.

    * ``Line``    -> its text, unchanged
    * ``Comment`` -> ``# `` followed by its text
    * ``Entry``   -> ``key="value"`` with the value bytes escaped; double
      quotes inside the value are written as ``\\x22``
    """
    if isinstance(item, Line):
        print(item.txt)
    elif isinstance(item, Comment):
        print(f"# {item.txt}")
    elif isinstance(item, Entry):
        escaped = "".join(escape(item.value)).replace('"', r"\x22")
        print(f'{item.key}="{escaped}"')
| 304 | + |
| 305 | + |
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    generate()
0 commit comments