Skip to content

Commit

Permalink
wipe_db() implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Vayras authored and cdecker committed Jan 29, 2024
1 parent af4b244 commit 464dc02
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
21 changes: 19 additions & 2 deletions contrib/pyln-testing/pyln/testing/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
from typing import Dict, List, Optional, Union


class Sqlite3Db(object):
class BaseDb(object):
def wipe_db(self):
raise NotImplementedError("wipe_db method must be implemented by the subclass")


class Sqlite3Db(BaseDb):
def __init__(self, path: str) -> None:
self.path = path
self.provider = None
Expand All @@ -32,6 +37,8 @@ def query(self, query: str) -> Union[List[Dict[str, Union[int, bytes]]], List[Di

db.row_factory = sqlite3.Row
c = db.cursor()
# Don't get upset by concurrent writes; wait for up to 5 seconds!
c.execute("PRAGMA busy_timeout = 5000")
c.execute(query)
rows = c.fetchall()

Expand All @@ -55,8 +62,12 @@ def execute(self, query: str) -> None:
def stop(self):
pass

def wipe_db(self):
if os.path.exists(self.path):
os.remove(self.path)


class PostgresDb(object):
class PostgresDb(BaseDb):
def __init__(self, dbname, port):
self.dbname = dbname
self.port = port
Expand Down Expand Up @@ -102,6 +113,12 @@ def stop(self):
cur.execute("DROP DATABASE {};".format(self.dbname))
cur.close()

def wipe_db(self):
cur = self.conn.cursor()
cur.execute(f"DROP DATABASE IF EXISTS {self.dbname};")
cur.execute(f"CREATE DATABASE {self.dbname};")
cur.close()


class SqliteDbProvider(object):
def __init__(self, directory: str) -> None:
Expand Down
16 changes: 14 additions & 2 deletions tests/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
import subprocess
import time

class BaseDb(object):
def wipe_db(self):
raise NotImplementedError("wipe_db method must be implemented by the subclass")

class Sqlite3Db(object):
class Sqlite3Db(BaseDb):
def __init__(self, path):
self.path = path
self.provider = None
Expand Down Expand Up @@ -50,8 +53,11 @@ def execute(self, query):
c.close()
db.close()

def wipe_db(self):
if os.path.exists(self.path):
os.remove(self.path)

class PostgresDb(object):
class PostgresDb(BaseDb):
def __init__(self, dbname, port):
self.dbname = dbname
self.port = port
Expand Down Expand Up @@ -89,6 +95,12 @@ def execute(self, query):
cur.execute(query)


def wipe_db(self):
cur = self.conn.cursor()
cur.execute(f"DROP DATABASE IF EXISTS {self.dbname};")
cur.execute(f"CREATE DATABASE {self.dbname};")
cur.close()

class SqliteDbProvider(object):
def __init__(self, directory):
self.directory = directory
Expand Down

0 comments on commit 464dc02

Please sign in to comment.