Skip to content

Commit fca8f9f

Browse files
Add extensible prover implementation (#103)
This PR introduces a module `kprovex` that defines an extensible prover based on `APRPRover`. It has the the following submodules: * `kprovex.api`: the plugin API definition. A plugin provides the K definition to the prover, as well as functions for loading and printing proofs of that definition. * `kprovex._default`: semantics-agnostic defaults for loading and printing proofs. * `kprovex._loader`: plugin loader. * `kprovex._kprovex`: prover implementation. Furthermore, it adds a simple plugin implementation for `riscv-semantics`. The advantage of this architecture is that `kprovex` can be upstreamed to `pyk` without breaking it or the `riscv-semantics` prover. (Naturally, imports still need to be adjusted.) --------- Co-authored-by: devops <devops@runtimeverification.com>
1 parent fb931c8 commit fca8f9f

File tree

9 files changed

+454
-3
lines changed

9 files changed

+454
-3
lines changed

package/version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.1.100
1+
0.1.101

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "kriscv"
7-
version = "0.1.100"
7+
version = "0.1.101"
88
description = "K tooling for the RISC-V architecture"
99
readme = "README.md"
1010
requires-python = "~=3.10"
@@ -28,6 +28,9 @@ kriscv-asm = "kriscv.devtools:kriscv_asm"
2828
[project.entry-points.kdist]
2929
riscv-semantics = "kriscv.kdist.plugin"
3030

31+
[project.entry-points.kprovex]
32+
riscv = "kriscv.symtools:KRiscVPlugin"
33+
3134
[dependency-groups]
3235
dev = [
3336
"autoflake",

src/kriscv/kprovex/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._kprovex import KProveX, create_prover

src/kriscv/kprovex/_default.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from pathlib import Path
7+
from typing import Final
8+
9+
from pyk.kast import KInner
10+
from pyk.proof.reachability import APRProof
11+
12+
from .api import Config, Init, Show
13+
14+
15+
def init_from_claims(config: Config, spec_file: Path, claim_id: str) -> APRProof:
16+
from pyk.ktool.claim_loader import ClaimLoader
17+
from pyk.proof.reachability import APRProof
18+
19+
spec_module, claim_label = claim_id.split('.', 1)
20+
include_dirs = config.dist.source_dirs
21+
22+
claims = ClaimLoader(config.kprove).load_claims(
23+
spec_file=spec_file,
24+
spec_module_name=spec_module,
25+
claim_labels=[claim_label],
26+
include_dirs=include_dirs,
27+
)
28+
(claim,) = claims
29+
30+
proof = APRProof.from_claim(
31+
config.kprove.definition,
32+
claim=claim,
33+
logs={},
34+
proof_dir=config.proof_dir,
35+
)
36+
return proof
37+
38+
39+
def show_pretty_term(config: Config, term: KInner) -> str:
40+
from pyk.konvert import kast_to_kore
41+
from pyk.kore.tools import kore_print
42+
43+
kore = kast_to_kore(config.definition, term)
44+
text = kore_print(kore, definition_dir=config.dist.haskell_dir)
45+
return text
46+
47+
48+
# Check signatures
49+
_default_init: Final[Init] = init_from_claims
50+
_default_show: Final[Show] = show_pretty_term

src/kriscv/kprovex/_kprovex.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from functools import cached_property
5+
from pathlib import Path
6+
from typing import TYPE_CHECKING, final
7+
8+
from pyk.proof.reachability import APRProof, APRProver
9+
10+
from .api import Config
11+
12+
if TYPE_CHECKING:
13+
from pyk.proof import ProofStatus
14+
from pyk.proof.show import APRProofNodePrinter
15+
from pyk.utils import BugReport
16+
17+
from .api import Init, Plugin, Show
18+
19+
20+
def create_prover(plugin_id: str, proof_dir: str | Path, *, bug_report: BugReport | None = None) -> KProveX:
21+
from ._loader import PLUGINS
22+
23+
if plugin_id not in PLUGINS:
24+
raise ValueError(f'Unknown plugin: {plugin_id}')
25+
26+
plugin = PLUGINS[plugin_id]
27+
proof_dir = Path(proof_dir)
28+
29+
return KProveX(
30+
plugin=plugin,
31+
proof_dir=proof_dir,
32+
bug_report=bug_report,
33+
)
34+
35+
36+
@final
37+
@dataclass
38+
class KProveX:
39+
plugin: Plugin
40+
proof_dir: Path
41+
bug_report: BugReport | None
42+
43+
def __init__(
44+
self,
45+
plugin: Plugin,
46+
proof_dir: Path,
47+
*,
48+
bug_report: BugReport | None = None,
49+
):
50+
self.plugin = plugin
51+
self.proof_dir = proof_dir
52+
self.bug_report = bug_report
53+
54+
proof_dir.mkdir(parents=True, exist_ok=True)
55+
56+
@cached_property
57+
def config(self) -> Config:
58+
return Config(
59+
dist=self.plugin.dist(),
60+
proof_dir=self.proof_dir,
61+
bug_report=self.bug_report,
62+
)
63+
64+
def init_proof(
65+
self,
66+
spec_file: str | Path,
67+
claim_id: str,
68+
*,
69+
init_id: str | None = None,
70+
exist_ok: bool = False,
71+
) -> str:
72+
spec_file = Path(spec_file)
73+
init = self._load_init(init_id=init_id)
74+
proof = init(config=self.config, spec_file=spec_file, claim_id=claim_id)
75+
if not exist_ok and APRProof.proof_data_exists(proof.id, self.proof_dir):
76+
raise ValueError(f'Proof with id already exists: {proof.id}')
77+
78+
proof.write_proof_data()
79+
return proof.id
80+
81+
def list_proofs(self) -> list[str]:
82+
raise ValueError('TODO')
83+
84+
def list_nodes(self, proof_id: str) -> list[int]:
85+
proof = self._load_proof(proof_id)
86+
return [node.id for node in proof.kcfg.nodes]
87+
88+
def advance_proof(
89+
self,
90+
proof_id: str,
91+
*,
92+
max_depth: int | None = None,
93+
max_iterations: int | None = None,
94+
) -> ProofStatus:
95+
proof = self._load_proof(proof_id)
96+
97+
with self.config.explore(id=proof_id) as kcfg_explore:
98+
prover = APRProver(
99+
kcfg_explore=kcfg_explore,
100+
execute_depth=max_depth,
101+
)
102+
prover.advance_proof(proof, max_iterations=max_iterations)
103+
104+
return proof.status
105+
106+
def show_proof(
107+
self,
108+
proof_id: str,
109+
*,
110+
show_id: str | None = None,
111+
truncate: bool = False,
112+
) -> str:
113+
from pyk.proof.show import APRProofShow
114+
115+
proof = self._load_proof(proof_id)
116+
node_printer = self._proof_node_printer(proof, show_id=show_id, full_printer=False)
117+
proof_show = APRProofShow(self.config.definition, node_printer=node_printer)
118+
lines = proof_show.show(proof)
119+
if truncate:
120+
lines = [_truncate(line, 120) for line in lines]
121+
return '\n'.join(lines)
122+
123+
def view_proof(
124+
self,
125+
proof_id: str,
126+
*,
127+
show_id: str | None = None,
128+
) -> None:
129+
from pyk.proof.tui import APRProofViewer
130+
131+
proof = self._load_proof(proof_id)
132+
node_printer = self._proof_node_printer(proof, show_id=show_id, full_printer=False)
133+
viewer = APRProofViewer(proof, self.config.kprove, node_printer=node_printer)
134+
viewer.run()
135+
136+
def prune_node(self, proof_id: str, node_id: str) -> list[int]:
137+
proof = self._load_proof(proof_id)
138+
res = proof.prune(node_id)
139+
proof.write_proof_data()
140+
return res
141+
142+
def show_node(
143+
self,
144+
proof_id: str,
145+
node_id: str,
146+
*,
147+
show_id: str | None = None,
148+
truncate: bool = False,
149+
) -> str:
150+
proof = self._load_proof(proof_id)
151+
node_printer = self._proof_node_printer(proof, show_id=show_id, full_printer=True)
152+
kcfg = proof.kcfg
153+
node = kcfg.node(node_id)
154+
lines = node_printer.print_node(kcfg, node)
155+
if truncate:
156+
lines = [_truncate(line, 120) for line in lines]
157+
return '\n'.join(lines)
158+
159+
# Private helpers
160+
161+
def _load_proof(self, proof_id: str) -> APRProof:
162+
return APRProof.read_proof_data(proof_dir=self.proof_dir, id=proof_id)
163+
164+
def _load_init(self, *, init_id: str | None) -> Init:
165+
if init_id is None:
166+
from . import _default
167+
168+
return _default.init_from_claims
169+
170+
inits = self.plugin.inits()
171+
if init_id not in inits:
172+
raise ValueError(f'Unknown init function: {init_id}')
173+
174+
return inits[init_id]
175+
176+
def _load_show(self, *, show_id: str | None) -> Show:
177+
if show_id is None:
178+
from . import _default
179+
180+
return _default.show_pretty_term
181+
182+
shows = self.plugin.shows()
183+
if show_id not in shows:
184+
raise ValueError(f'Unknown show function: {show_id}')
185+
186+
return shows[show_id]
187+
188+
def _proof_node_printer(
189+
self,
190+
proof: APRProof,
191+
*,
192+
show_id: str | None = None,
193+
full_printer: bool = False,
194+
) -> APRProofNodePrinter:
195+
from pyk.cterm.show import CTermShow
196+
from pyk.proof.show import APRProofNodePrinter
197+
198+
show = self._load_show(show_id=show_id)
199+
printer = lambda kast: show(self.config, kast)
200+
return APRProofNodePrinter(
201+
proof=proof,
202+
cterm_show=CTermShow(
203+
printer=printer,
204+
minimize=False,
205+
break_cell_collections=False,
206+
),
207+
full_printer=full_printer,
208+
)
209+
210+
211+
def _truncate(line: str, n: int) -> str:
212+
if len(line) <= n:
213+
return line
214+
return line[: n - 3] + '...'

src/kriscv/kprovex/_loader.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from __future__ import annotations
2+
3+
import importlib
4+
import logging
5+
import re
6+
from typing import TYPE_CHECKING
7+
8+
from pyk.utils import FrozenDict
9+
10+
from .api import Plugin
11+
12+
if TYPE_CHECKING:
13+
from importlib.metadata import EntryPoint
14+
from typing import Final
15+
16+
17+
_LOGGER: Final = logging.getLogger(__name__)
18+
19+
20+
def _load_plugins() -> FrozenDict[str, Plugin]:
21+
entry_points = importlib.metadata.entry_points(group='kprovex')
22+
plugins: FrozenDict[str, Plugin] = FrozenDict(
23+
(entry_point.name, plugin) for entry_point in entry_points if (plugin := _load_plugin(entry_point)) is not None
24+
)
25+
return plugins
26+
27+
28+
def _load_plugin(entry_point: EntryPoint) -> Plugin | None:
29+
if not _valid_id(entry_point.name):
30+
_LOGGER.warning(f'Invalid entry point name, skipping: {entry_point.name}')
31+
return None
32+
33+
_LOGGER.info(f'Loading entry point: {entry_point.name}')
34+
try:
35+
module_name, class_name = entry_point.value.split(':')
36+
except ValueError:
37+
_LOGGER.error(f'Invalid entry point value: {entry_point.value}', exc_info=True)
38+
return None
39+
40+
try:
41+
_LOGGER.info(f'Importing module: {module_name}')
42+
module = importlib.import_module(module_name)
43+
except Exception:
44+
_LOGGER.error(f'Module {module_name} cannot be imported', exc_info=True)
45+
return None
46+
47+
try:
48+
_LOGGER.info(f'Loading plugin: {class_name}')
49+
cls = getattr(module, class_name)
50+
except AttributeError:
51+
_LOGGER.error(f'Class {class_name} not found in module {module_name}', exc_info=True)
52+
return None
53+
54+
if not issubclass(cls, Plugin):
55+
_LOGGER.error(f'Class {class_name} is not a Plugin', exc_info=True)
56+
return None
57+
58+
try:
59+
_LOGGER.info(f'Instantiating plugin: {class_name}')
60+
plugin = cls()
61+
except TypeError:
62+
_LOGGER.error(f'Cannot instantiate plugin {class_name}', exc_info=True)
63+
return None
64+
65+
return plugin
66+
67+
68+
_ID_PATTERN = re.compile('[a-z0-9]+(-[a-z0-9]+)*')
69+
70+
71+
def _valid_id(s: str) -> bool:
72+
return _ID_PATTERN.fullmatch(s) is not None
73+
74+
75+
PLUGINS: Final = _load_plugins()

0 commit comments

Comments
 (0)