Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = pytest-testmon
version = 1.3.6
version = 1.3.7
license = AGPL
author_email = tibor.arpas@infinit.sk
author = Tibor Arpas, Tomas Matlovic, Daniel Hahler, Martin Racak
Expand Down Expand Up @@ -46,4 +46,5 @@ pytest11 =
testmon = testmon.pytest_testmon
tox =
testmon = testmon.tox_testmon

console_scripts =
testmon-merge-db = testmon:merge_db
18 changes: 18 additions & 0 deletions testmon/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

def merge_db():
import argparse
from testmon.db import DB, merge_dbs

parser = argparse.ArgumentParser()
parser.add_argument('dbs', metavar='N', type=str, nargs='+')
parser.add_argument('--output', metavar='N', type=str, nargs='?', default="merged")
parser.add_argument('--environment', metavar='N', type=str, nargs='?', default="default")

args = parser.parse_args()
databases = args.dbs
output_db = args.output
env = args.environment

db_1 = DB(datafile=databases[0], environment=env)
db_2 = DB(datafile=databases[1], environment=env)
merge_dbs(merged_datafile=output_db, db_1=db_1, db_2=db_2)
66 changes: 56 additions & 10 deletions testmon/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import sqlite3

from collections import namedtuple
from sqlite3 import Binary
from typing import List, Optional

from testmon.process_code import (
blob_to_checksums,
Expand All @@ -22,7 +24,7 @@ class TestmonDbException(Exception):
pass


def connect(datafile):
def connect(datafile: os.PathLike):
connection = sqlite3.connect(datafile)

connection.execute("PRAGMA synchronous = OFF")
Expand All @@ -31,9 +33,31 @@ def connect(datafile):
connection.row_factory = sqlite3.Row
return connection

def merge_dbs(merged_datafile, db_1: "DB", db_2: "DB") -> "DB":
if db_1.env != db_2.env:
raise

memory_db = DB(":memory:", environment=db_1.env)
file_db = DB(merged_datafile, environment=db_1.env)
with memory_db:
with db_1:
for data in db_1.all_data():
memory_db.insert_node_fingerprints(data["name"], fingerprints=[data], failed=data["failed"],
duration=data["duration"])

with db_2:
for data in db_2.all_data():
memory_db.insert_node_fingerprints(data["name"], fingerprints=[data], failed=data["failed"],
duration=data["duration"])

with file_db:
memory_db.con.backup(file_db.con)

return file_db


class DB:
def __init__(self, datafile, environment="default"):
def __init__(self, datafile: os.PathLike, environment: str = "default"):
new_db = not os.path.exists(datafile)

connection = connect(datafile)
Expand All @@ -45,7 +69,7 @@ def __init__(self, datafile, environment="default"):
if new_db or old_format:
self.init_tables()

def _check_data_version(self, datafile):
def _check_data_version(self, datafile: os.PathLike) -> bool:
stored_data_version = self._fetch_data_version()

if int(stored_data_version) == DATA_VERSION:
Expand All @@ -56,14 +80,14 @@ def _check_data_version(self, datafile):
self.con = connect(datafile)
return True

def __enter__(self):
def __enter__(self) -> "DB":
self.con = self.con.__enter__()
return self

def __exit__(self, *args, **kwargs):
self.con.__exit__(*args, **kwargs)

def update_mtimes(self, new_mtimes):
def update_mtimes(self, new_mtimes: float):
with self.con as con:
con.executemany(
"UPDATE fingerprint SET mtime=?, checksum=? WHERE id = ?", new_mtimes
Expand All @@ -80,7 +104,7 @@ def remove_unused_fingerprints(self):
"""
)

def fetch_or_create_fingerprint(self, filename, mtime, checksum, method_checksums):
def fetch_or_create_fingerprint(self, filename: str, mtime: float, checksum: str, method_checksums: Binary) -> int:
cursor = self.con.cursor()
try:
cursor.execute(
Expand Down Expand Up @@ -113,7 +137,7 @@ def fetch_or_create_fingerprint(self, filename, mtime, checksum, method_checksum
return fingerprint_id

def insert_node_fingerprints(
self, nodeid, fingerprints, failed=False, duration=None
self, nodeid: str, fingerprints: Fingerprints, failed: bool = False, duration: Optional[float] = None
):
with self.con as con:
cursor = con.cursor()
Expand Down Expand Up @@ -151,15 +175,15 @@ def _fetch_data_version(self):

return con.execute("PRAGMA user_version").fetchone()[0]

def _write_attribute(self, attribute, data, environment=None):
def _write_attribute(self, attribute: str, data: dict, environment: Optional[str] = None):
dataid = (environment or self.env) + ":" + attribute
with self.con as con:
con.execute(
"INSERT OR REPLACE INTO metadata VALUES (?, ?)",
[dataid, json.dumps(data)],
)

def _fetch_attribute(self, attribute, default=None, environment=None):
def _fetch_attribute(self, attribute: str, default=None, environment=None):
cursor = self.con.execute(
"SELECT data FROM metadata WHERE dataid=?",
[(environment or self.env) + ":" + attribute],
Expand Down Expand Up @@ -214,7 +238,7 @@ def init_tables(self):

connection.execute(f"PRAGMA user_version = {DATA_VERSION}")

def get_changed_file_data(self, changed_fingerprints):
def get_changed_file_data(self, changed_fingerprints: Fingerprints):
in_clause_questionsmarks = ", ".join("?" * len(changed_fingerprints))
result = []
for row in self.con.execute(
Expand Down Expand Up @@ -300,3 +324,25 @@ def filenames_fingerprints(self):
)

return [dict(row) for row in cursor]

def all_data(self) -> List[dict]:
cursor = self.con.execute(
"""
SELECT
n.name,
n.duration,
n.failed,
f.filename,
f.method_checksums,
f.mtime,
f.checksum
FROM
node n, node_fingerprint nfp, fingerprint f
WHERE
n.id = nfp.node_id AND
nfp.fingerprint_id = f.id AND
environment = ?
""",
(self.env,),
)
return [dict(row) for row in cursor]