-
-
Notifications
You must be signed in to change notification settings - Fork 8
/
__init__.py
executable file
·479 lines (411 loc) · 17.4 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
from __future__ import annotations
import re
from collections import defaultdict
from dataclasses import dataclass, field
from functools import cached_property, partial
from typing import TYPE_CHECKING
import panflute as pf
from panflute.tools import convert_text
from .helper import cancel_emph, cite_to_id_mode, cite_to_ref, merge_emph, parse_markdown_as_inline, to_emph
from .util import setup_logging
if TYPE_CHECKING:
from typing import Union
from panflute.elements import Doc, Element
THM_DEF = list[Union[str, dict[str, str], dict[str, list[str]]]]
__version__: str = "2.0.0"
PARENT_COUNTERS: set[str] = {
"part",
"chapter",
"section",
"subsection",
"subsubsection",
"paragraph",
"subparagraph",
}
STYLES: tuple[str, ...] = ("plain", "definition", "remark")
METADATA_KEY: str = "amsthm"
REF_REGEX = re.compile(r"^\\(ref|eqref)\{(.*)\}$")
LATEX_LIKE: set[str] = {"latex", "beamer"}
PLAIN_OR_DEF: set[str] = {"plain", "definition"}
COUNTER_DEPTH_DEFAULT: int = 0
logger = setup_logging()
def parse_info(info: str | None) -> list[Element]:
"""Convert theorem info to panflute AST inline elements."""
return [pf.Str(r"(")] + parse_markdown_as_inline(info) + [pf.Str(r")")] if info else []
@dataclass
class NewTheorem:
style: str
env_name: str
text: str = ""
parent_counter: str | None = None
shared_counter: str | None = None
numbered: bool = True
"""A LaTeX amsthm new theorem.
:param parent_counter: for LaTeX output, controlling the number of numbers in a theorem.
Should be used with counter_depth to match LaTeX and non-LaTeX output.
"""
def __post_init__(self) -> None:
if self.env_name.endswith("*"):
self.env_name = self.env_name[:-1]
self.numbered = False
if not self.text:
logger.debug("Defaulting text to %s", self.env_name)
self.text = self.env_name
if (parent_counter := self.parent_counter) is not None and parent_counter not in PARENT_COUNTERS:
logger.warning("Unsupported parent_counter %s, ignoring.", parent_counter)
if self.numbered and parent_counter is not None and self.shared_counter is not None:
logger.warning("Dropping shared_counter as parent_counter is defined.")
self.shared_counter = None
@property
def latex(self) -> str:
res = [r"\newtheorem"]
if not self.numbered:
res.append(f"*{{{self.env_name}}}{{{self.text}}}")
elif self.shared_counter is None:
if self.parent_counter is None:
res.append(f"{{{self.env_name}}}{{{self.text}}}")
else:
res.append(f"{{{self.env_name}}}{{{self.text}}}[{self.parent_counter}]")
else:
res.append(f"{{{self.env_name}}}[{self.shared_counter}]{{{self.text}}}")
return "".join(res)
@property
def class_name(self) -> str:
"""Name in pandoc div classes.
It cannot have space.
"""
return self.env_name.replace(" ", "_")
@property
def counter_name(self) -> str:
return self.env_name if self.shared_counter is None else self.shared_counter
def to_panflute_theorem_header(
self,
options: DocOptions,
id: str | None,
info: str | None,
) -> list[pf.Element]:
"""Return a theorem header as panflute AST.
This mutates `options.theorem_counters`, `options.identifiers` in-place.
"""
TextType: type[Element]
text = self.text
# text and number separated by Space
NumberType: type[Element]
theorem_number: str | None
if self.numbered:
counter_name = self.counter_name
options.theorem_counters[counter_name] += 1
theorem_counter = options.theorem_counters[counter_name]
theorem_number = ".".join([str(i) for i in options.header_counters] + [str(theorem_counter)])
if id:
options.identifiers[id] = theorem_number
else:
theorem_number = None
# no additional styling here
info_list = parse_info(info)
# append TextType of ".", Space
# cases: PLAIN_OR_DEF, theorem_number, info_list
if self.style in PLAIN_OR_DEF:
TextType = pf.Strong
NumberType = pf.Strong
else:
TextType = pf.Emph
NumberType = pf.Str
# We are normalizing the Emph/Strong boundary manually by having 6 cases
if theorem_number is None:
if info_list:
res = [TextType(pf.Str(text)), pf.Space] + info_list + [TextType(pf.Str(".")), pf.Space]
else:
res = [TextType(pf.Str(f"{text}.")), pf.Space]
else:
if TextType is NumberType:
if info_list:
res = (
[TextType(pf.Str(f"{text} {theorem_number}")), pf.Space]
+ info_list
+ [TextType(pf.Str(".")), pf.Space]
)
else:
res = [TextType(pf.Str(f"{text} {theorem_number}.")), pf.Space]
else:
if info_list:
res = (
[TextType(pf.Str(text)), pf.Space, pf.Str(theorem_number), pf.Space]
+ info_list
+ [TextType(pf.Str(".")), pf.Space]
)
else:
res = [
TextType(pf.Str(text)),
pf.Space,
pf.Str(theorem_number),
TextType(pf.Str(".")),
pf.Space,
]
return res
@dataclass
class Proof(NewTheorem):
style: str = "proof"
env_name: str = "proof"
text: str = "proof"
parent_counter: str | None = None
shared_counter: str | None = None
numbered: bool = False
def to_panflute_theorem_header(
self,
options: DocOptions,
id: str | None,
info: str | None,
) -> list[pf.Element]:
"""Return a theorem header as panflute AST."""
if info is None:
return [pf.Emph(pf.Str("Proof.")), pf.Space]
else:
# put it into a Para then walk
ast = parse_markdown_as_inline(info)
info_list = pf.Para(*ast)
info_list.walk(to_emph)
info_list.walk(cancel_emph)
info_list.walk(merge_emph)
return list(info_list.content) + [pf.Emph(pf.Str(".")), pf.Space]
@dataclass
class DocOptions:
"""Document options.
:param: counter_depth: can be n=0-6 inclusive.
n means n+1 numbers shown in non-LaTeX outputs.
e.g. n=1 means x.y, where x is the heading 1 counter, y is the theorem counter.
Should be used with parent_counter to match LaTeX and non-LaTeX output.
"""
theorems: dict[str, NewTheorem] = field(default_factory=dict)
counter_depth: int = COUNTER_DEPTH_DEFAULT
counter_ignore_headings: set[str] = field(default_factory=set)
def __post_init__(self) -> None:
try:
self.counter_depth = int(self.counter_depth)
except ValueError:
logger.warning("counter_depth must be int, default to 1.")
self.counter_depth = COUNTER_DEPTH_DEFAULT
# initial count is zero
# should be += 1 before using
self.header_counters: list[int] = [0] * self.counter_depth
self.reset_theorem_counters()
# from identifiers to numbers
self.identifiers: dict[str, str] = {}
def reset_theorem_counters(self) -> None:
self.theorem_counters: dict[str, int] = defaultdict(int)
@cached_property
def theorems_set(self) -> set[str]:
return set(self.theorems)
@classmethod
def from_doc(
cls,
doc: Doc,
) -> DocOptions:
options: dict[
str,
dict[str, str | dict[str, str] | THM_DEF],
] = doc.get_metadata(METADATA_KEY, {})
name_to_text: dict[str, str] = options.get("name_to_text", {}) # type: ignore[assignment, arg-type]
parent_counter: str = options.get("parent_counter", None) # type: ignore[assignment, arg-type]
theorems: dict[str, NewTheorem] = {}
for style in STYLES:
option: THM_DEF = options.get(style, []) # type: ignore[assignment]
for opt in option:
if isinstance(opt, dict):
for key, value in opt.items():
# key
theorem = NewTheorem(style, key, text=name_to_text.get(key, ""), parent_counter=parent_counter)
theorems[theorem.class_name] = theorem
# value(s)
if isinstance(value, list):
for v in value:
theorem = NewTheorem(style, v, text=name_to_text.get(v, ""), shared_counter=key)
theorems[theorem.class_name] = theorem
else:
v = value
theorem = NewTheorem(style, v, text=name_to_text.get(v, ""), shared_counter=key)
theorems[theorem.class_name] = theorem
else:
key = opt
theorem = NewTheorem(style, key, text=name_to_text.get(key, ""), parent_counter=parent_counter)
theorems[theorem.class_name] = theorem
# proof is predefined in amsthm
theorems["proof"] = Proof()
return cls(
theorems,
counter_depth=options.get("counter_depth", COUNTER_DEPTH_DEFAULT), # type: ignore[arg-type] # will be verified at __post_init__
counter_ignore_headings=set(options.get("counter_ignore_headings", set())),
)
@property
def latex(self) -> str:
cur_style: str = ""
res: list[str] = []
for theorem in self.theorems.values():
# proof is predefined in amsthm
if not isinstance(theorem, Proof):
if theorem.style != cur_style:
cur_style = theorem.style
res.append(f"\\theoremstyle{{{cur_style}}}")
res.append(theorem.latex)
return "\n".join(res)
@property
def to_panflute(self) -> pf.RawBlock:
return pf.RawBlock(self.latex, format="latex")
def prepare(doc: Doc) -> None:
doc._amsthm = options = DocOptions.from_doc(doc)
if doc.format in LATEX_LIKE:
doc.content.insert(0, options.to_panflute)
def amsthm(elem: Element, doc: Doc) -> None:
"""General amsthm transformation working for all document types.
Essentially we replicate LaTeX amsthm behavior in this filter.
"""
options: DocOptions = doc._amsthm
if isinstance(elem, pf.Header):
if elem.level <= options.counter_depth:
header_string = None
if (counter_ignore_headings := options.counter_ignore_headings) and (
header_string := pf.stringify(elem)
) in counter_ignore_headings:
logger.debug("Ignoring header %s in header_counters as it is in counter_ignore_headings", header_string)
else:
# Header.level is 1-indexed, while list is 0-indexed
options.header_counters[elem.level - 1] += 1
# reset deeper levels
for i in range(elem.level, options.counter_depth):
options.header_counters[i] = 0
logger.debug(
"Header encounter: %s, current counter: %s", header_string or elem, options.header_counters
)
options.reset_theorem_counters()
elif isinstance(elem, pf.Div):
environments: set[str] = options.theorems_set.intersection(elem.classes)
if environments:
if len(environments) != 1:
logger.warning("Multiple environments found: %s", environments)
return None
environment = environments.pop()
theorem = options.theorems[environment]
info = elem.attributes.get("info", None)
id = elem.identifier
res = theorem.to_panflute_theorem_header(options, id, info)
# theorem body
if theorem.style == "plain":
elem.walk(to_emph)
elem.walk(cancel_emph)
elem.walk(merge_emph)
try:
# insert in the beginning of the first block element
for r in reversed(res):
elem.content[0].content.insert(0, r)
except AttributeError:
# if fail, insert a Para before content
elem.content.insert(0, pf.Para(*res))
r = pf.RawInline("<span style='float: right'>◻</span>", format="html")
try:
# insert in the end of the last block element
if theorem.style == "proof":
elem.content[-1].content.append(r)
except AttributeError:
# if fail, append a Para
elem.content.append(pf.Para(r))
def resolve_ref(elem: Element, doc: Doc) -> pf.Str | None:
"""Resolve references to theorem numbers.
Consider this as post-process ref for general output formats.
"""
options: DocOptions = doc._amsthm
# from [@...] to number
if isinstance(elem, pf.Cite):
if (temp := cite_to_id_mode(elem)) is not None and (id := temp[0]) in options.identifiers:
mode = temp[1]
# @[...]
if mode == "NormalCitation":
return pf.Str(f"({options.identifiers[id]})")
# @...
elif mode == "AuthorInText":
return pf.Str(options.identifiers[id])
else:
logger.warning("Unknown citation mode %s from Cite: %s. Ignoring...", mode, elem)
return None
# from \ref{...} to number
elif isinstance(elem, pf.RawInline) and elem.format == "tex":
text = elem.text
if matches := REF_REGEX.findall(text):
if len(matches) != 1:
logger.warning("Ignoring ref matching in %s: %s", text, matches)
return None
ref_type, id = matches[0]
if id in options.identifiers:
if ref_type == "eqref":
return pf.Str(f"({options.identifiers[id]})")
else:
return pf.Str(options.identifiers[id])
return None
def collect_ref_id(elem: Element, doc: Doc) -> None:
"""Only collect all amsthm environment id.
This should be used before the `amsthm_latex` filter.
This is done in 2 passes as the id may be cited/referenced earlier than definition.
Consider this as pre-process of ref for LaTeX output.
`options.identifiers` modified in-place.
"""
# check if it is a Div, and the class is an amsthm environment
options: DocOptions = doc._amsthm
environments: set[str]
if isinstance(elem, pf.Div) and (environments := options.theorems_set.intersection(elem.classes)):
if len(environments) != 1:
logger.warning("Multiple environments found: %s", environments)
return None
if id := elem.identifier:
# in LaTeX output, we only need to keep a reference of the id
# the numbering (value of this dict) is handled by LaTeX
options.identifiers[id] = ""
return None
def amsthm_latex(elem: Element, doc: Doc) -> pf.RawBlock | None:
"""Transform amsthm defintion to LaTeX package specifications."""
# check if it is a Div, and the class is an amsthm environment
options: DocOptions = doc._amsthm
if isinstance(elem, pf.Div):
environments: set[str] = options.theorems_set.intersection(elem.classes)
if environments:
if len(environments) != 1:
logger.warning("Multiple environments found: %s", environments)
return None
environment = environments.pop()
theorem = options.theorems[environment]
div_content = pf.convert_text(elem, input_format="panflute", output_format="latex")
info = elem.attributes.get("info", None)
id = elem.identifier
res = [f"\\begin{{{theorem.env_name}}}"]
if info:
# wrap in Para for walk
ast = pf.Para(*parse_markdown_as_inline(info))
ast.walk(partial(cite_to_ref, check_id=options.identifiers))
ast = convert_text(ast, input_format="panflute", output_format="latex").strip()
res += [f"[{ast}]"]
if id:
res.append(f"\\label{{{id}}}")
res.append(f"\n{div_content}\n\\end{{{theorem.env_name}}}")
return pf.RawBlock("".join(res), format="latex")
# check if pf.Cite is done inside cite_to_ref
else:
return cite_to_ref(elem, doc, options.identifiers)
return None
def action1(elem: Element, doc: Doc) -> pf.RawBlock | None:
if doc.format in LATEX_LIKE:
collect_ref_id(elem, doc)
else:
amsthm(elem, doc)
return None
def action2(elem: Element, doc: Doc) -> pf.Str | pf.RawInline | None:
if doc.format in LATEX_LIKE:
return amsthm_latex(elem, doc)
else:
return resolve_ref(elem, doc)
def finalize(doc: Doc) -> None:
del doc._amsthm
def main(doc: Doc | None = None) -> None:
return pf.run_filters(
(action1, action2),
prepare=prepare,
finalize=finalize,
doc=doc,
)