Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ruff #15

Merged
merged 5 commits into from
Jan 27, 2025
Merged

Ruff #15

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ name: CI
on:
push:
branches:
- master
- main
pull_request:
branches:
- master
- main

jobs:

Expand Down
24 changes: 13 additions & 11 deletions itemdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
__version__ = "1.2.0"
version_info = tuple(map(int, __version__.split(".")))

__all__ = ["ItemDB", "AsyncItemDB", "asyncify"]
__all__ = ["AsyncItemDB", "ItemDB", "asyncify"]

json_encode = json.JSONEncoder(ensure_ascii=True).encode
json_decode = json.JSONDecoder().decode
Expand Down Expand Up @@ -168,11 +168,11 @@ def get_indices(self, table_name):
except KeyError:
pass
except TypeError:
raise TypeError(f"Table name must be str, not {table_name}.")
raise TypeError(f"Table name must be str, not {table_name}.") from None

# Check table name
if not isinstance(table_name, str):
raise TypeError(f"Table name must be str, not {table_name}")
raise TypeError(f"Table name must be str, not {table_name}") from None
elif not table_name.isidentifier():
raise ValueError(f"Table name must be an identifier, not '{table_name}'")

Expand Down Expand Up @@ -284,9 +284,8 @@ def _ensure_table_helper2(self, table_name, indices):
index_key = fieldname.lstrip("!")
if fieldname not in found_indices:
if fieldname.startswith("!"):
raise IndexError(
f"Cannot add unique index {fieldname!r} after the table has been created."
)
when = "after the table has been created"
raise IndexError(f"Cannot add unique index {fieldname!r} {when}.")
elif fieldname in {x.lstrip("!") for x in found_indices}:
raise IndexError(f"Given index {fieldname!r} should be unique.")
cur.execute(f"ALTER TABLE {table_name} ADD {index_key};")
Expand Down Expand Up @@ -338,7 +337,9 @@ def rename_table(self, table_name, new_table_name):
"""
self.get_indices(table_name) # Fail with KeyError for invalid table name
if not (isinstance(new_table_name, str) and new_table_name.isidentifier()):
raise TypeError(f"Table name must be a str identifier, not '{table_name}'")
raise TypeError(
f"Table name must be a str identifier, not '{table_name}'"
) from None
cur = self._cur
if cur is None:
raise IOError("Can only use rename_table() within a transaction.")
Expand Down Expand Up @@ -393,7 +394,7 @@ def count(self, table_name, query, *save_args):
return cur.fetchone()[0]
except sqlite3.OperationalError as err:
if "no such column" in str(err).lower():
raise IndexError(str(err))
raise IndexError(str(err)) from None
raise err
finally:
cur.close()
Expand Down Expand Up @@ -458,7 +459,7 @@ def select(self, table_name, query, *save_args):
return [json_decode(x[0]) for x in cur]
except sqlite3.OperationalError as err:
if "no such column" in str(err).lower():
raise IndexError(str(err))
raise IndexError(str(err)) from None
raise err
finally:
cur.close()
Expand Down Expand Up @@ -526,8 +527,9 @@ def put(self, table_name, *items):
elif fieldname.startswith("!"):
raise IndexError(f"Item does not have required field {index_key!r}")

cmd = "INSERT OR REPLACE INTO"
cur.execute(
f"INSERT OR REPLACE INTO {table_name} ({index_keys}) VALUES ({row_plac})",
f"{cmd} {table_name} ({index_keys}) VALUES ({row_plac})",
row_vals,
)

Expand Down Expand Up @@ -588,7 +590,7 @@ def delete(self, table_name, query, *save_args):
cur.execute(f"DELETE FROM {table_name} WHERE {query}", save_args)
except sqlite3.OperationalError as err:
if "no such column" in str(err).lower():
raise IndexError(str(err))
raise IndexError(str(err)) from None
raise err
finally:
cur.close()
Expand Down
21 changes: 13 additions & 8 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gc
import sys
import time
import asyncio
import threading
Expand All @@ -11,6 +12,9 @@

side_effect = [0]

root_loop = asyncio.new_event_loop()
asyncio.set_event_loop(root_loop)


def plain_func(x):
time.sleep(1) # emulate io
Expand All @@ -24,13 +28,12 @@ def plain_func_that_errors(x):

def swait(co):
"""Sync-wait for the given coroutine, and return the result."""
return asyncio.get_event_loop().run_until_complete(co)
return root_loop.run_until_complete(co)


def swait_multiple(cos):
"""Sync-wait for the given coroutines."""
# asyncio.get_event_loop().run_until_complete(asyncio.wait(cos)) # stopped working
asyncio.get_event_loop().run_until_complete(asyncio.gather(*cos))
root_loop.run_until_complete(asyncio.gather(*cos))


def test_asyncify1():
Expand Down Expand Up @@ -70,7 +73,7 @@ def test_asyncif2():

# Run it multiple times
t0 = time.perf_counter()
swait_multiple([func(3) for i in range(5)])
swait_multiple([func(3) for _ in range(5)])
assert side_effect[0] == 50
t1 = time.perf_counter()
assert (t1 - t0) < 2
Expand All @@ -89,10 +92,10 @@ def test_AsyncItemDB_threads():
time.sleep(0.1)
assert threading.active_count() < 20

dbs1 = swait(_test_AsyncItemDB_threads()) # noqa
dbs1 = swait(_test_AsyncItemDB_threads())
assert threading.active_count() > 100

dbs2 = swait(_test_AsyncItemDB_threads()) # noqa
dbs2 = swait(_test_AsyncItemDB_threads())
time.sleep(0.1)
assert threading.active_count() > 200

Expand All @@ -112,7 +115,7 @@ def test_AsyncItemDB_threads():

async def _test_AsyncItemDB_threads():
dbs = []
for i in range(100):
for _ in range(100):
dbs.append(await AsyncItemDB(":memory:"))
return dbs

Expand All @@ -137,7 +140,9 @@ async def _test_AsyncItemDB():
with raises(IOError): # Put needs to be used under a context
await db.put("items", dict(id=1, mt=100))

with raises(Exception): # Normal with not allowed
# Normal with not allowed
Exc = AttributeError if sys.version_info < (3, 11) else TypeError
with raises(Exc):
with db:
pass

Expand Down
15 changes: 7 additions & 8 deletions tests/test_itemdb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
""" The main usage tests.
"""
"""The main usage tests."""

import os
import time
Expand Down Expand Up @@ -362,7 +361,7 @@ def test_multiple_items():
)
assert set(x["id"] for x in db.select_all("items")) == {1, 2, 3, 4}
for x in db.select_all("items"):
x["mt"] == 102
assert x["mt"] == 102

# Lets take it further
with db:
Expand Down Expand Up @@ -463,7 +462,7 @@ def run_fast_transaction2():

def run_read():
db = ItemDB(filename)
for i in range(30):
for _i in range(30):
time.sleep(0.05)
item = db.select_one("items", "id == 3")
read.append(item["value"])
Expand Down Expand Up @@ -495,7 +494,7 @@ def run_slow_transaction():
with db:
time.sleep(0.2)

threads = [threading.Thread(target=run_slow_transaction) for i in range(3)]
threads = [threading.Thread(target=run_slow_transaction) for _ in range(3)]
t0 = time.perf_counter()
for t in threads:
t.start()
Expand Down Expand Up @@ -524,7 +523,7 @@ def test_database_race_conditions():
ItemDB(filename).ensure_table("items", "!id")

def push_a_bunch():
for i in range(n_writes):
for _ in range(n_writes):
id = random.randint(1, 10)
mt = random.randint(1000, 2000)
tracking[id].append(mt)
Expand All @@ -535,7 +534,7 @@ def push_a_bunch():

# Prepare, start, and join threads
t0 = time.perf_counter()
threads = [threading.Thread(target=push_a_bunch) for i in range(n_threads)]
threads = [threading.Thread(target=push_a_bunch) for _ in range(n_threads)]
for t in threads:
t.start()
for t in threads:
Expand All @@ -553,7 +552,7 @@ def push_a_bunch():
id = item["id"]
assert item["mt"] == max(tracking[id])

return items
# return items


def test_threaded_access():
Expand Down
3 changes: 1 addition & 2 deletions tests/test_management.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
""" Tests related to db management and tables.
"""
"""Tests related to db management and tables."""

import os
import tempfile
Expand Down
17 changes: 9 additions & 8 deletions tests/test_speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


async def run_in_thread(func, *args, **kwargs):
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
future = loop.create_future()

def thread_func():
Expand All @@ -40,7 +40,7 @@ async def work_async1(self):

async def work_async2(self):
# Using a thread pool
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
return await loop.run_in_executor(executor, self.work)


Expand All @@ -65,7 +65,7 @@ def check_speed_of_async():
# work_async1 0.0002502583791895948
# work_async2 0.00015008997599039617

loop = asyncio.get_event_loop()
loop = asyncio.new_event_loop()
for i in range(3):
method_name = f"work_async{i}"
t = loop.run_until_complete(do_some_work(method_name))
Expand Down Expand Up @@ -99,7 +99,8 @@ async def work():

def check_speed_of_async_itemdb():
co = _check_speed_of_async_itemdb()
asyncio.get_event_loop().run_until_complete(co)
loop = asyncio.new_event_loop()
loop.run_until_complete(co)


async def _check_speed_of_async_itemdb():
Expand All @@ -113,20 +114,20 @@ async def _check_speed_of_async_itemdb():
time.sleep(0.1)
t0 = time.perf_counter()

for i in range(n):
for _i in range(n):
await do_work_using_asyncify()

t1 = time.perf_counter()
print(f"with asyncify: {(t1 - t0)/n:0.9f} s")
print(f"with asyncify: {(t1 - t0) / n:0.9f} s")

time.sleep(0.1)
t0 = time.perf_counter()

for i in range(n):
for _i in range(n):
await do_work_using_asyncitemdb()

t1 = time.perf_counter()
print(f"with AsyncItemDB: {(t1 - t0)/n:0.9f} s")
print(f"with AsyncItemDB: {(t1 - t0) / n:0.9f} s")


if __name__ == "__main__":
Expand Down
17 changes: 11 additions & 6 deletions tests/test_sqlite_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
tempdir = tempfile.gettempdir()
n_threads = 40
n_writes = 10
RETURN_ITEMS = False


def create_db(dbname):
Expand Down Expand Up @@ -77,7 +78,7 @@ def xxtest_sqlite_nonlocking():
t0 = time.perf_counter()
threads = [
threading.Thread(target=write_to_db, args=(filename, False))
for i in range(n_threads)
for _ in range(n_threads)
]
for t in threads:
t.start()
Expand All @@ -89,7 +90,7 @@ def xxtest_sqlite_nonlocking():
con = sqlite3.connect(filename)
items = [item for item in con.execute("SELECT id, mt FROM items")]

print(f"{t1-t0} s, for {n_threads*n_writes} writes, saving {len(items)} items")
print(f"{t1 - t0} s, for {n_threads * n_writes} writes, saving {len(items)} items")

assert 10 <= len(items) <= 120
fails = 0
Expand All @@ -99,7 +100,9 @@ def xxtest_sqlite_nonlocking():
fails += 1
mts[item[0]] = item[1]
assert fails > 0
return items

if RETURN_ITEMS:
return items


def test_sqlite_locking():
Expand All @@ -113,7 +116,7 @@ def test_sqlite_locking():
# Prepare, start, and join threads
t0 = time.perf_counter()
threads = [
threading.Thread(target=write_to_db, args=(filename,)) for i in range(n_threads)
threading.Thread(target=write_to_db, args=(filename,)) for _ in range(n_threads)
]
for t in threads:
t.start()
Expand All @@ -125,7 +128,7 @@ def test_sqlite_locking():
con = sqlite3.connect(filename)
items = [item for item in con.execute("SELECT id, mt FROM items")]

print(f"{t1-t0} s, for {n_threads*n_writes} writes, saving {len(items)} items")
print(f"{t1 - t0} s, for {n_threads * n_writes} writes, saving {len(items)} items")

assert 10 <= len(items) <= 100
# assert len(items) == n_threads * n_writes
Expand All @@ -134,9 +137,11 @@ def test_sqlite_locking():
assert item[1] >= mts.get(item[0], -9999)
mts[item[0]] = item[1]

return items
if RETURN_ITEMS:
return items


if __name__ == "__main__":
RETURN_ITEMS = True
items = xxtest_sqlite_nonlocking()
items = test_sqlite_locking()
Loading