Skip to content

Commit bfee4c8

Browse files
committed
Address PR comments
1 parent 6ef68d4 commit bfee4c8

File tree

2 files changed

+71
-52
lines changed

2 files changed

+71
-52
lines changed

redisgraph_bulk_loader/bulk_update.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,54 +2,47 @@
22
import csv
33
import redis
44
import click
5-
from redis import ResponseError
5+
from redisgraph import Graph
66
from timeit import default_timer as timer
77

88

99
def utf8len(s):
1010
return len(s.encode('utf-8'))
1111

1212

13+
# Count number of rows in file.
14+
def count_entities(filename):
15+
entities_count = 0
16+
with open(filename, 'rt') as f:
17+
entities_count = sum(1 for line in f)
18+
return entities_count
19+
20+
1321
class BulkUpdate:
1422
"""Handler class for emitting bulk update commands"""
15-
def __init__(self, graph, max_token_size, separator, no_header, filename, query, variable_name, client):
23+
def __init__(self, graph_name, max_token_size, separator, no_header, filename, query, variable_name, client):
1624
self.separator = separator
1725
self.no_header = no_header
1826
self.query = " ".join(["UNWIND $rows AS", variable_name, query])
1927
self.buffer_size = 0
2028
self.max_token_size = max_token_size * 1024 * 1024 - utf8len(self.query)
21-
self.graph = graph
2229
self.filename = filename
23-
self.client = client
30+
self.graph_name = graph_name
31+
self.graph = Graph(graph_name, client)
2432
self.statistics = {}
2533

26-
# Count number of rows in file.
27-
def count_entities(self):
28-
entities_count = 0
29-
with open(self.filename, 'rt') as f:
30-
entities_count = sum(1 for line in f)
31-
return entities_count
32-
3334
def update_statistics(self, result):
34-
for raw_stat in result[0]:
35-
stat = raw_stat.split(": ")
36-
key = stat[0]
35+
for key, new_val in result.statistics.items():
3736
try:
3837
val = self.statistics[key]
3938
except KeyError:
4039
val = 0
41-
val += float(stat[1].split(" ")[0])
40+
val += new_val
4241
self.statistics[key] = val
4342

4443
def emit_buffer(self, rows):
4544
command = " ".join([rows, self.query])
46-
try:
47-
result = self.client.execute_command("GRAPH.QUERY", self.graph, command)
48-
except ResponseError as e:
49-
raise e
50-
# If we encountered a run-time error, the last response element will be an exception.
51-
if isinstance(result[-1], ResponseError):
52-
raise result[-1]
45+
result = self.graph.query(command)
5346
self.update_statistics(result)
5447

5548
def quote_string(self, cell):
@@ -65,8 +58,14 @@ def quote_string(self, cell):
6558
cell = "".join(["\"", cell, "\""])
6659
return cell
6760

61+
# Raise an exception if the query triggers a compile-time error
62+
def validate_query(self):
63+
command = " ".join(["CYPHER rows=[]", self.query])
64+
# The plan call will raise an error if the query is malformed or invalid.
65+
self.graph.execution_plan(command)
66+
6867
def process_update_csv(self):
69-
entity_count = self.count_entities()
68+
entity_count = count_entities(self.filename)
7069

7170
with open(self.filename, 'rt') as f:
7271
if self.no_header is False:
@@ -75,7 +74,7 @@ def process_update_csv(self):
7574
reader = csv.reader(f, delimiter=self.separator, skipinitialspace=True, quoting=csv.QUOTE_NONE, escapechar='\\')
7675

7776
rows_strs = []
78-
with click.progressbar(reader, length=entity_count, label=self.graph) as reader:
77+
with click.progressbar(reader, length=entity_count, label=self.graph_name) as reader:
7978
for row in reader:
8079
# Prepare the string representation of the current row.
8180
row = ",".join([self.quote_string(cell) for cell in row])
@@ -145,6 +144,7 @@ def bulk_update(graph, host, port, password, unix_socket_path, query, variable_n
145144
pass
146145

147146
updater = BulkUpdate(graph, max_token_size, separator, no_header, csv, query, variable_name, client)
147+
updater.validate_query()
148148
updater.process_update_csv()
149149

150150
end_time = timer()

test/test_bulk_update.py

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,35 @@ def test02_traversal_updates(self):
9797
["c", "c2"]]
9898
self.assertEqual(query_result.result_set, expected_result)
9999

100-
def test03_custom_delimiter(self):
101-
"""Validate that non-comma delimiters produce the correct results."""
100+
def test03_datatypes(self):
101+
"""Validate that all RedisGraph datatypes are supported by the bulk updater."""
102102
graphname = "tmpgraph2"
103103
# Write temporary files
104+
with open('/tmp/csv.tmp', mode='w') as csv_file:
105+
out = csv.writer(csv_file)
106+
out.writerow([0, 1.5, "true", "string", "[1, 'nested_str']"])
107+
108+
runner = CliRunner()
109+
res = runner.invoke(bulk_update, ['--csv', '/tmp/csv.tmp',
110+
'--query', 'CREATE (a:L) SET a.intval = row[0], a.doubleval = row[1], a.boolval = row[2], a.stringval = row[3], a.arrayval = row[4]',
111+
'--no-header',
112+
graphname], catch_exceptions=False)
113+
114+
self.assertEqual(res.exit_code, 0)
115+
self.assertIn('Nodes created: 1', res.output)
116+
self.assertIn('Properties set: 5', res.output)
117+
118+
tmp_graph = Graph(graphname, self.redis_con)
119+
query_result = tmp_graph.query('MATCH (a) RETURN a.intval, a.doubleval, a.boolval, a.stringval, a.arrayval')
120+
121+
# Validate that the expected results are all present in the graph
122+
expected_result = [[0, 1.5, True, "string", "[1,'nested_str']"]]
123+
self.assertEqual(query_result.result_set, expected_result)
124+
125+
def test04_custom_delimiter(self):
126+
"""Validate that non-comma delimiters produce the correct results."""
127+
graphname = "tmpgraph3"
128+
# Write temporary files
104129
with open('/tmp/csv.tmp', mode='w') as csv_file:
105130
out = csv.writer(csv_file, delimiter='|')
106131
out.writerow(["id", "name"])
@@ -140,7 +165,7 @@ def test03_custom_delimiter(self):
140165
self.assertNotIn('Nodes created', res.output)
141166
self.assertNotIn('Properties set', res.output)
142167

143-
def test04_custom_variable_name(self):
168+
def test05_custom_variable_name(self):
144169
"""Validate that the user can specify the name of the 'row' query variable."""
145170
graphname = "variable_name"
146171
runner = CliRunner()
@@ -178,9 +203,9 @@ def test04_custom_variable_name(self):
178203
['Valerie Abigail Arad', 31, 'female', 'married']]
179204
self.assertEqual(query_result.result_set, expected_result)
180205

181-
def test05_no_header(self):
206+
def test06_no_header(self):
182207
"""Validate that the '--no-header' option works properly."""
183-
graphname = "tmpgraph3"
208+
graphname = "tmpgraph4"
184209
# Write temporary files
185210
with open('/tmp/csv.tmp', mode='w') as csv_file:
186211
out = csv.writer(csv_file)
@@ -208,7 +233,7 @@ def test05_no_header(self):
208233
[5, "b"]]
209234
self.assertEqual(query_result.result_set, expected_result)
210235

211-
def test06_batched_update(self):
236+
def test07_batched_update(self):
212237
"""Validate that updates performed over multiple batches produce the correct results."""
213238
graphname = "batched_update"
214239

@@ -238,9 +263,9 @@ def test06_batched_update(self):
238263
expected_result = [[prop_str]]
239264
self.assertEqual(query_result.result_set, expected_result)
240265

241-
def test07_runtime_error(self):
266+
def test08_runtime_error(self):
242267
"""Validate that run-time errors are captured by the bulk updater."""
243-
graphname = "tmpgraph1"
268+
graphname = "tmpgraph5"
244269

245270
# Write temporary files
246271
with open('/tmp/csv.tmp', mode='w') as csv_file:
@@ -255,9 +280,21 @@ def test07_runtime_error(self):
255280
self.assertNotEqual(res.exit_code, 0)
256281
self.assertIn("Cannot merge node", str(res.exception))
257282

258-
def test07_invalid_inputs(self):
283+
def test09_compile_time_error(self):
284+
"""Validate that malformed queries trigger an early exit from the bulk updater."""
285+
graphname = "tmpgraph5"
286+
runner = CliRunner()
287+
res = runner.invoke(bulk_update, ['--csv', '/tmp/csv.tmp',
288+
'--query', 'CREATE (:L {val: row[0], val2: undefined_identifier})',
289+
'--no-header',
290+
graphname])
291+
292+
self.assertNotEqual(res.exit_code, 0)
293+
self.assertIn("undefined_identifier not defined", str(res.exception))
294+
295+
def test10_invalid_inputs(self):
259296
"""Validate that the bulk updater handles invalid inputs correctly."""
260-
graphname = "tmpgraph1"
297+
graphname = "tmpgraph6"
261298

262299
# Attempt to insert a non-existent CSV file.
263300
runner = CliRunner()
@@ -267,21 +304,3 @@ def test07_invalid_inputs(self):
267304

268305
self.assertNotEqual(res.exit_code, 0)
269306
self.assertIn("No such file", str(res.exception))
270-
271-
# Write temporary files
272-
with open('/tmp/csv.tmp', mode='w') as csv_file:
273-
out = csv.writer(csv_file)
274-
out.writerow(["id", "name"])
275-
out.writerow([0, "a"])
276-
out.writerow([5, "b"])
277-
out.writerow([3, "c"])
278-
279-
# Attempt to access a non-existent column.
280-
res = runner.invoke(bulk_update, ['--csv', '/tmp/csv.tmp',
281-
'--query', 'CREATE (:L {val: row[3]})',
282-
graphname])
283-
284-
# self.assertNotEqual(res.exit_code, 0)
285-
# import ipdb
286-
# ipdb.set_trace()
287-
# self.assertIn("No such file", str(res.exception))

0 commit comments

Comments
 (0)