Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 60ac169

Browse files
idling11sungchun12
andauthored
[to#811]Fix special characters in PG url and Mysql connection reconnect (#812)
* [to#811]Fix special characters in PG * add TestSpecialCharacterPassword UT for pg * format import * format fix * quick nits * add newline * format changed files --------- Co-authored-by: Sung Won Chung <sungwonchung3@gmail.com>
1 parent 243997a commit 60ac169

File tree

3 files changed

+53
-3
lines changed

3 files changed

+53
-3
lines changed

data_diff/databases/mysql.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, ClassVar, Dict, Type
1+
from typing import Any, ClassVar, Dict, Type, Union
22

33
import attrs
44

@@ -20,6 +20,7 @@
2020
import_helper,
2121
ConnectError,
2222
BaseDialect,
23+
ThreadLocalInterpreter,
2324
)
2425
from data_diff.databases.base import (
2526
MD5_HEXDIGITS,
@@ -148,3 +149,11 @@ def create_connection(self):
148149
elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR:
149150
raise ConnectError("Database does not exist") from e
150151
raise ConnectError(*e.args) from e
152+
153+
def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]):
154+
"This method runs in a worker thread"
155+
if self._init_error:
156+
raise self._init_error
157+
if not self.thread_local.conn.is_connected():
158+
self.thread_local.conn.ping(reconnect=True, attempts=3, delay=5)
159+
return self._query_conn(self.thread_local.conn, sql_code)

data_diff/databases/postgresql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Any, ClassVar, Dict, List, Type
2-
2+
from urllib.parse import unquote
33
import attrs
44

55
from data_diff.abcs.database_types import (
@@ -168,6 +168,7 @@ def create_connection(self):
168168

169169
pg = import_postgresql()
170170
try:
171+
self._args["password"] = unquote(self._args["password"])
171172
self._conn = pg.connect(
172173
**self._args, keepalives=1, keepalives_idle=5, keepalives_interval=2, keepalives_count=2
173174
)

tests/test_postgresql.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import unittest
22

3+
from urllib.parse import quote
34
from data_diff.queries.api import table, commit
45
from data_diff import TableSegment, HashDiffer
56
from data_diff import databases as db
6-
from tests.common import get_conn, random_table_suffix
7+
from tests.common import get_conn, random_table_suffix, connect
8+
from data_diff import connect_to_table
79

810

911
class TestUUID(unittest.TestCase):
@@ -113,3 +115,41 @@ def test_100_fields(self):
113115
id_ = diff[0][1][0]
114116
result = (id_,) + tuple("1" for x in range(100))
115117
self.assertEqual(diff, [("-", result)])
118+
119+
120+
class TestSpecialCharacterPassword(unittest.TestCase):
121+
def setUp(self) -> None:
122+
self.connection = get_conn(db.PostgreSQL)
123+
124+
table_suffix = random_table_suffix()
125+
126+
self.table_name = f"table{table_suffix}"
127+
self.table = table(self.table_name)
128+
129+
def test_special_char_password(self):
130+
password = "passw!!!@rd"
131+
# Setup user with special character '@' in password
132+
self.connection.query("DROP USER IF EXISTS test;", None)
133+
self.connection.query(f"CREATE USER test WITH PASSWORD '{password}';", None)
134+
135+
password_quoted = quote(password)
136+
db_config = {
137+
"driver": "postgresql",
138+
"host": "localhost",
139+
"port": 5432,
140+
"dbname": "postgres",
141+
"user": "test",
142+
"password": password_quoted,
143+
}
144+
145+
# verify pythonic connection method
146+
connect_to_table(
147+
db_config,
148+
self.table_name,
149+
)
150+
151+
# verify connection method with URL string unquoted after it's verified
152+
db_url = f"postgresql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['dbname']}"
153+
154+
connection_verified = connect(db_url)
155+
assert connection_verified._args.get("password") == password

0 commit comments

Comments
 (0)