Skip to content

Commit 03eb011

Browse files
committed
Generalize config extraction for arrow endpoints
1 parent d8f221c commit 03eb011

File tree

4 files changed

+87
-93
lines changed

4 files changed

+87
-93
lines changed
graphdatascience/procedure_surface/arrow/arrow_config_converter.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Optional, Any, Dict
2+
3+
from graphdatascience import Graph
4+
5+
6+
class ArrowConfigConverter:
    """Turn Python-style keyword arguments into an Arrow endpoint config dict.

    Keys are rewritten from snake_case to camelCase, entries whose value is
    ``None`` are dropped, and nested dictionaries are processed recursively
    with the same rules.
    """

    @staticmethod
    def build_configuration(G: Graph, **kwargs: Optional[Any]) -> dict[str, Any]:
        """Build the configuration dictionary for a job on graph ``G``.

        Args:
            G: Graph object; only ``G.name()`` is read and stored under
                ``"graphName"``.
            **kwargs: Configuration options keyed by snake_case name. Options
                whose value is ``None`` are omitted from the result.

        Returns:
            A new dict containing ``"graphName"`` plus every non-``None``
            option under its camelCase key.
        """
        config: dict[str, Any] = {"graphName": G.name()}
        # Applied after "graphName", so a converted kwarg with that key wins.
        config.update(ArrowConfigConverter._process_dict_values(kwargs))
        return config

    @staticmethod
    def _convert_to_camel_case(name: str) -> str:
        """Convert a snake_case string to camelCase.

        Words written in a single case (``node_labels``, ``NODE_LABELS``) are
        normalised: the first word is lower-cased, later words are capitalised.
        Words that already mix upper and lower case keep their interior casing
        so that an existing camelCase key such as ``nodeLabels`` is returned
        unchanged instead of being flattened to ``nodelabels``.
        """
        head, *tail = name.split("_")
        converted = [ArrowConfigConverter._recase_word(head, leading=True)]
        converted.extend(ArrowConfigConverter._recase_word(word, leading=False) for word in tail)
        return "".join(converted)

    @staticmethod
    def _recase_word(word: str, leading: bool) -> str:
        """Re-case one underscore-separated word of a key.

        Args:
            word: The word to convert; may be empty (e.g. from ``"a__b"``).
            leading: ``True`` for the first word of the key, which starts
                lower-case; ``False`` for later words, which start upper-case.
        """
        is_mixed_case = word != word.lower() and word != word.upper()
        if is_mixed_case:
            # Only touch the first character; the rest already carries
            # meaningful capitalisation (e.g. "componentSize").
            first = word[0].lower() if leading else word[0].upper()
            return first + word[1:]
        return word.lower() if leading else word.capitalize()

    @staticmethod
    def _process_dict_values(input_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``input_dict`` with camelCase keys and no ``None`` values.

        Nested dictionaries are converted recursively; every other value is
        stored as-is (lists are not descended into).
        """
        result: Dict[str, Any] = {}
        for key, value in input_dict.items():
            if value is None:
                continue
            camel_key = ArrowConfigConverter._convert_to_camel_case(key)
            if isinstance(value, dict):
                value = ArrowConfigConverter._process_dict_values(value)
            result[camel_key] = value
        return result

graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py

Lines changed: 49 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pandas import DataFrame
44

5+
from .arrow_config_converter import ArrowConfigConverter
56
from ...arrow_client.authenticated_arrow_client import AuthenticatedArrowClient
67
from ...arrow_client.v2.job_client import JobClient
78
from ...arrow_client.v2.mutation_client import MutationClient
@@ -33,19 +34,18 @@ def mutate(
3334
consecutive_ids: Optional[bool] = None,
3435
relationship_weight_property: Optional[str] = None,
3536
) -> WccMutateResult:
36-
config = self._build_configuration(
37+
config = ArrowConfigConverter.build_configuration(
3738
G,
38-
concurrency,
39-
consecutive_ids,
40-
job_id,
41-
log_progress,
42-
None,
43-
node_labels,
44-
relationship_types,
45-
relationship_weight_property,
46-
seed_property,
47-
sudo,
48-
threshold,
39+
concurrency = concurrency,
40+
consecutive_ids = consecutive_ids,
41+
job_id = job_id,
42+
log_progress = log_progress,
43+
node_labels = node_labels,
44+
relationship_types = relationship_types,
45+
relationship_weight_property = relationship_weight_property,
46+
seed_property = seed_property,
47+
sudo = sudo,
48+
threshold = threshold,
4949
)
5050

5151
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
@@ -79,19 +79,18 @@ def stats(
7979
consecutive_ids: Optional[bool] = None,
8080
relationship_weight_property: Optional[str] = None,
8181
) -> WccStatsResult:
82-
config = self._build_configuration(
82+
config = ArrowConfigConverter.build_configuration(
8383
G,
84-
concurrency,
85-
consecutive_ids,
86-
job_id,
87-
log_progress,
88-
None,
89-
node_labels,
90-
relationship_types,
91-
relationship_weight_property,
92-
seed_property,
93-
sudo,
94-
threshold,
84+
concurrency = concurrency,
85+
consecutive_ids = consecutive_ids,
86+
job_id = job_id,
87+
log_progress = log_progress,
88+
node_labels = node_labels,
89+
relationship_types = relationship_types,
90+
relationship_weight_property = relationship_weight_property,
91+
seed_property = seed_property,
92+
sudo = sudo,
93+
threshold = threshold,
9594
)
9695

9796
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
@@ -122,19 +121,19 @@ def stream(
122121
consecutive_ids: Optional[bool] = None,
123122
relationship_weight_property: Optional[str] = None,
124123
) -> DataFrame:
125-
config = self._build_configuration(
124+
config = ArrowConfigConverter.build_configuration(
126125
G,
127-
concurrency,
128-
consecutive_ids,
129-
job_id,
130-
log_progress,
131-
min_component_size,
132-
node_labels,
133-
relationship_types,
134-
relationship_weight_property,
135-
seed_property,
136-
sudo,
137-
threshold,
126+
concurrency = concurrency,
127+
consecutive_ids = consecutive_ids,
128+
job_id = job_id,
129+
log_progress = log_progress,
130+
min_component_size = min_component_size,
131+
node_labels = node_labels,
132+
relationship_types = relationship_types,
133+
relationship_weight_property = relationship_weight_property,
134+
seed_property = seed_property,
135+
sudo = sudo,
136+
threshold = threshold,
138137
)
139138

140139
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
@@ -158,19 +157,20 @@ def write(
158157
relationship_weight_property: Optional[str] = None,
159158
write_concurrency: Optional[int] = None,
160159
) -> WccWriteResult:
161-
config = self._build_configuration(
160+
161+
config = ArrowConfigConverter.build_configuration(
162162
G,
163-
concurrency,
164-
consecutive_ids,
165-
job_id,
166-
log_progress,
167-
min_component_size,
168-
node_labels,
169-
relationship_types,
170-
relationship_weight_property,
171-
seed_property,
172-
sudo,
173-
threshold,
163+
concurrency = concurrency,
164+
consecutive_ids = consecutive_ids,
165+
job_id = job_id,
166+
log_progress = log_progress,
167+
min_component_size = min_component_size,
168+
node_labels = node_labels,
169+
relationship_types = relationship_types,
170+
relationship_weight_property = relationship_weight_property,
171+
seed_property = seed_property,
172+
sudo = sudo,
173+
threshold = threshold,
174174
)
175175

176176
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
@@ -192,48 +192,4 @@ def write(
192192
computation_result["postProcessingMillis"],
193193
computation_result["nodePropertiesWritten"],
194194
computation_result["configuration"],
195-
)
196-
197-
@staticmethod
198-
def _build_configuration(
199-
G: Graph,
200-
concurrency: Optional[int],
201-
consecutive_ids: Optional[bool],
202-
job_id: Optional[str],
203-
log_progress: Optional[bool],
204-
min_component_size: Optional[int],
205-
node_labels: Optional[List[str]],
206-
relationship_types: Optional[List[str]],
207-
relationship_weight_property: Optional[str],
208-
seed_property: Optional[str],
209-
sudo: Optional[bool],
210-
threshold: Optional[float],
211-
) -> dict[str, Any]:
212-
config: dict[str, Any] = {
213-
"graphName": G.name(),
214-
}
215-
216-
if min_component_size is not None:
217-
config["minComponentSize"] = min_component_size
218-
if threshold is not None:
219-
config["threshold"] = threshold
220-
if relationship_types is not None:
221-
config["relationshipTypes"] = relationship_types
222-
if node_labels is not None:
223-
config["nodeLabels"] = node_labels
224-
if sudo is not None:
225-
config["sudo"] = sudo
226-
if log_progress is not None:
227-
config["logProgress"] = log_progress
228-
if concurrency is not None:
229-
config["concurrency"] = concurrency
230-
if job_id is not None:
231-
config["jobId"] = job_id
232-
if seed_property is not None:
233-
config["seedProperty"] = seed_property
234-
if consecutive_ids is not None:
235-
config["consecutiveIds"] = consecutive_ids
236-
if relationship_weight_property is not None:
237-
config["relationshipWeightProperty"] = relationship_weight_property
238-
239-
return config
195+
)

graphdatascience/tests/unit/procedure_surface/__init__.py

Whitespace-only changes.

graphdatascience/tests/unit/procedure_surface/arrow/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)