diff --git a/src/jmbase/__init__.py b/src/jmbase/__init__.py index 7f16d8eff..abc8cdfb4 100644 --- a/src/jmbase/__init__.py +++ b/src/jmbase/__init__.py @@ -8,7 +8,7 @@ EXIT_SUCCESS, hexbin, dictchanger, listchanger, JM_WALLET_NAME_PREFIX, JM_APP_NAME, IndentedHelpFormatterWithNL, wrapped_urlparse, - bdict_sdict_convert, random_insert) + bdict_sdict_convert, random_insert, dict_factory) from .proof_of_work import get_pow, verify_pow from .twisted_utils import (stop_reactor, is_hs_uri, get_tor_agent, get_nontor_agent, JMHiddenService, diff --git a/src/jmbase/support.py b/src/jmbase/support.py index e97bf2cf9..6c60b3dab 100644 --- a/src/jmbase/support.py +++ b/src/jmbase/support.py @@ -7,6 +7,7 @@ from os import path, environ from functools import wraps from optparse import IndentedHelpFormatter +from sqlite3 import Cursor, Row from typing import List import urllib.parse as urlparse @@ -356,3 +357,7 @@ def get_free_tcp_ports(num_ports: int) -> List[int]: for s in sockets: s.close() return ports + +def dict_factory(cursor: Cursor, row: Row) -> dict: + fields = [column[0] for column in cursor.description] + return {key: value for key, value in zip(fields, row)} diff --git a/src/jmclient/wallet_utils.py b/src/jmclient/wallet_utils.py index 686ab16e1..5dccab543 100644 --- a/src/jmclient/wallet_utils.py +++ b/src/jmclient/wallet_utils.py @@ -20,7 +20,7 @@ from jmclient.wallet_service import WalletService from jmbase.support import (get_password, jmprint, EXIT_FAILURE, EXIT_ARGERROR, utxo_to_utxostr, hextobin, bintohex, - IndentedHelpFormatterWithNL) + IndentedHelpFormatterWithNL, dict_factory) from .cryptoengine import TYPE_P2PKH, TYPE_P2SH_P2WPKH, TYPE_P2WPKH, \ TYPE_SEGWIT_WALLET_FIDELITY_BONDS @@ -815,13 +815,6 @@ def wallet_change_passphrase(walletservice, return True -def dict_factory(cursor, row): - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d - - def wallet_fetch_history(wallet, options): # sort txes in a db because python can be really bad with large lists con = sqlite3.connect(":memory:") diff --git a/src/jmdaemon/orderbookwatch.py b/src/jmdaemon/orderbookwatch.py index 796675945..ba62087e3 100644 --- a/src/jmdaemon/orderbookwatch.py +++ b/src/jmdaemon/orderbookwatch.py @@ -8,17 +8,10 @@ from jmdaemon.protocol import JM_VERSION from jmdaemon import fidelity_bond_sanity_check -from jmbase.support import get_log, joinmarket_alert +from jmbase.support import dict_factory, get_log, joinmarket_alert log = get_log() -def dict_factory(cursor, row): - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d - - class JMTakerError(Exception): pass diff --git a/test/jmbase/test_base_support.py b/test/jmbase/test_base_support.py index 68734e5b7..024187915 100644 --- a/test/jmbase/test_base_support.py +++ b/test/jmbase/test_base_support.py @@ -1,7 +1,9 @@ #! /usr/bin/env python -import pytest import copy -from jmbase import random_insert +import pytest +import sqlite3 + +from jmbase import dict_factory, random_insert def test_color_coded_logging(): # TODO @@ -30,3 +32,14 @@ def test_random_insert(list1, list2): i_x = list1.index(x) i_y = list1.index(y) assert i_y > i_x + +def test_dict_factory(): + con = sqlite3.connect(":memory:") + con.row_factory = dict_factory + db = con.cursor() + db.execute("CREATE TABLE test (one TEXT, two TEXT)") + db.execute("INSERT INTO test VALUES (?, ?)", [ "one", "two" ]) + res = db.execute("SELECT * FROM test") + row = res.fetchone() + assert row["one"] == "one" + assert row["two"] == "two"