First draft of Preprocessor and DB functionalities
NikosDelijohn committed Oct 16, 2024
1 parent 9bb96b2 commit 9dbcb78
Showing 1 changed file with 155 additions and 13 deletions.
168 changes: 155 additions & 13 deletions src/testcrush/a0.py
@@ -6,15 +6,19 @@
import random
import csv
import time
import lark
import sqlite3
import io

from testcrush.utils import get_logger, compile_assembly, zip_archive
import testcrush.grammars.transformers as transformers
from testcrush.utils import get_logger, compile_assembly, zip_archive, Singleton
from testcrush import asm, zoix
from typing import Any

log = get_logger()


class CSVCompactionStatistics():
class CSVCompactionStatistics(metaclass=Singleton):
"""Manages I/O operations on the CSV file which logs the statistics of the A0."""
_header = ["asm_source", "removed_codeline", "compiles", "lsim_ok",
"tat", "fsim_ok", "coverage", "verdict"]
@@ -32,7 +36,144 @@ def __iadd__(self, rowline: dict):
return self


class A0():
class Preprocessor(metaclass=Singleton):
"""Filters out candidate instructions"""

_trace_db = ".trace.db"

def __init__(self, txt_fault_report: pathlib.Path, processor_trace: pathlib.Path, **mappings) -> 'Preprocessor':

factory = transformers.TraceTransformerFactory()
transformer, grammar = factory(mappings.get("processor_name"))

with open(processor_trace) as src:
trace_raw = src.read()

parser = lark.Lark(grammar=grammar, start="start", parser="lalr", transformer=transformer)
self.trace = parser.parse(trace_raw)

factory = transformers.FaultReportTransformerFactory()
transformer, grammar = factory("FaultList")

fault_report = zoix.TxtFaultReport(txt_fault_report)
parser = lark.Lark(grammar=grammar, start="start", parser="lalr", transformer=transformer)
self.fault_list: list[zoix.Fault] = parser.parse(fault_report.extract("FaultList"))

self._create_trace_db()
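
# A minimal sketch of the parsed artifacts at this point (hypothetical values,
# assuming a CSV-style trace like the one in query_trace_db's docstring below):
# self.trace      -> ["Time,Cycle,PC,Instruction", "10ns,1,00000004,and", ...]
# self.fault_list -> [zoix.Fault(...), zoix.Fault(...), ...]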

def _create_trace_db(self):
"""
Transforms the trace of the DUT into a single-table SQLite database. The CSV header is mapped to the
DB column names and each row of the CSV body becomes a DB row entry.
"""

# If pre-existent db is found, delete it.
db = pathlib.Path(self._trace_db)
if db.exists():
db.unlink()

con = sqlite3.connect(self._trace_db)
cursor = con.cursor()

header: list[str] = self.trace[0].split(',')
header = list(map(lambda column_name: f"\"{column_name}\"", header))
header = ", ".join(header)

cursor.execute(f"CREATE TABLE trace({header})")

body: list[str] = self.trace[1:]

with io.StringIO('\n'.join(body)) as source:

for row in csv.reader(source):

cursor.execute(f"INSERT INTO trace VALUES ({', '.join(['?'] * len(row))})", row)

con.commit()
con.close()
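
# A minimal sketch of what _create_trace_db emits for a hypothetical 4-column
# trace (values are illustrative, not taken from a real run):
#
#   CREATE TABLE trace("Time", "Cycle", "PC", "Instruction")
#   INSERT INTO trace VALUES (?, ?, ?, ?)  -- e.g., ('10ns', '1', '00000004', 'and')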

def query_trace_db(self, select: str, where: dict[str, str],
history: int = 5, allow_multiple: bool = False) -> list[tuple[str, ...]]:
"""
Perform a query with the specified parameters.
Assuming that the DB looks like this:
::
Time || Cycle || PC       || Instruction
-----||-------||----------||------------
10ns || 1     || 00000004 || and
20ns || 2     || 00000008 || or   <-*
30ns || 3     || 0000000c || xor  <-|
40ns || 4     || 00000010 || sll  <-|
50ns || 5     || 00000014 || j    <-|
60ns || 6     || 0000004c || addi <-*
70ns || 7     || 00000050 || wfi
If you perform a query with ``select="PC"`` and ``where={"PC": "0000004c", "Time": "60ns"}``, the search
results in a window of 1+4 ``PC`` values, indicated by ``<-`` in the snapshot above. The size of the window
defaults to 5 but can be freely selected by the user.
Args:
select (str): The field to select in the query.
where (dict[str, str]): A dictionary of column-value conditions that the matching row(s) must satisfy.
history (int, optional): The number of trace rows to return, counting backwards from and including each
matching row. Defaults to 5.
allow_multiple (bool, optional): Whether the ``where`` conditions may match more than one row.
Defaults to False.
Returns:
list[tuple[str, ...]]: A list of query results (tuples of strings) matching the criteria.
"""

db = pathlib.Path(self._trace_db)
if not db.exists():
raise FileNotFoundError("Trace DB not found")

columns = where.keys()

# Column names were quoted when the table was created, so quote them here
# too in order to tolerate names that contain spaces or other specials.
query = f"""
SELECT ROWID
FROM trace
WHERE {' AND '.join([f'"{x}" = ?' for x in columns])}
"""

values = where.values()
with sqlite3.connect(db) as con:

cursor = con.cursor()

cursor.execute(query, tuple(values))
rowids = cursor.fetchall()

if not rowids:
raise ValueError(f"No row found for {', '.join([f'{k}={v}' for k, v in where.items()])}")

if len(rowids) > 1 and not allow_multiple:
raise ValueError(f"Query resulted in multiple ROWIDs for \
{', '.join([f'{k}={v}' for k, v in where.items()])}")

result = list()
for rowid, in rowids:

query_with_history = f"""
SELECT {select} FROM trace
WHERE ROWID <= ?
ORDER BY ROWID DESC
LIMIT ?
"""

cursor.execute(query_with_history, (rowid, history))
result += cursor.fetchall()[::-1]

return result

def prune_candidates(self, candidates: list[asm.Codeline]):
# TODO:
...


class A0(metaclass=Singleton):

"""Implements the A0 compaction algorithm of https://doi.org/10.1109/TC.2016.2643663"""

def __init__(self, isa: str, a0_asm_sources: list[str], a0_settings: dict[str, Any]) -> "A0":
@@ -42,6 +183,11 @@ def __init__(self, isa: str, a0_asm_sources: list[str], a0_settings: dict[str, A
chunksize=1)
for asm_file in a0_asm_sources]

# Flatten candidates list
self.all_instructions: list[asm.Codeline] = [(asm_id, codeline) for asm_id, asm in
enumerate(self.assembly_sources) for codeline in asm.get_code()]
self.path_to_id = {v: k for k, v in enumerate(a0_asm_sources)}
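# e.g., with two hypothetical sources "a.S" and "b.S":
# self.all_instructions -> [(0, Codeline(...)), (0, Codeline(...)), (1, Codeline(...)), ...]
# self.path_to_id       -> {"a.S": 0, "b.S": 1}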

self.assembly_compilation_instructions: list[str] = a0_settings.get("assembly_compilation_instructions")
self.fsim_report: zoix.CSVFaultReport = zoix.CSVFaultReport(
fault_summary=pathlib.Path(a0_settings.get("csv_fault_summary")),
@@ -235,28 +381,24 @@ def _restore(asm_source: int) -> None:
# they will be modified in-place.
zip_archive(f"../backup_{unique_id}", *[asm.get_asm_source() for asm in self.assembly_sources])

# Flatten candidates list
all_instructions: list[asm.Codeline] = [(asm_id, codeline) for asm_id, asm in
enumerate(self.assembly_sources) for codeline in asm.get_code()]

# Randomize order for Step 2
for i in range(times_to_shuffle):
random.shuffle(all_instructions)
for _ in range(times_to_shuffle):
random.shuffle(self.all_instructions)

iteration_stats = dict.fromkeys(CSVCompactionStatistics._header)

iteration_stats["tat"] = initial_tat
iteration_stats["coverage"] = initial_coverage

total_iterations = len(all_instructions)
total_iterations = len(self.all_instructions)

# Step 2: Select instructions in a random order
old_stl_stats = (initial_tat, initial_coverage)
while len(all_instructions) != 0:
while len(self.all_instructions) != 0:

print(f"""
#############
# ITERATION {total_iterations - len(all_instructions) + 1} / {total_iterations}
# ITERATION {total_iterations - len(self.all_instructions) + 1} / {total_iterations}
#############
""")

@@ -265,7 +407,7 @@
stats += iteration_stats
iteration_stats = dict.fromkeys(CSVCompactionStatistics._header)

asm_id, codeline = all_instructions.pop(0)
asm_id, codeline = self.all_instructions.pop(0)
asm_source_file = self.assembly_sources[asm_id].get_asm_source().name

iteration_stats["asm_source"] = asm_source_file
