Skip to content

Commit ed37a4d

Browse files
committed
Use config converter also for Cypher endpoints
1 parent 03eb011 commit ed37a4d

File tree

5 files changed

+158
-184
lines changed

5 files changed

+158
-184
lines changed

graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py

Lines changed: 54 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Any, List, Optional
1+
from typing import List, Optional
22

33
from pandas import DataFrame
44

5-
from .arrow_config_converter import ArrowConfigConverter
5+
from graphdatascience.procedure_surface.config_converter import ConfigConverter
6+
67
from ...arrow_client.authenticated_arrow_client import AuthenticatedArrowClient
78
from ...arrow_client.v2.job_client import JobClient
89
from ...arrow_client.v2.mutation_client import MutationClient
@@ -34,18 +35,18 @@ def mutate(
3435
consecutive_ids: Optional[bool] = None,
3536
relationship_weight_property: Optional[str] = None,
3637
) -> WccMutateResult:
37-
config = ArrowConfigConverter.build_configuration(
38-
G,
39-
concurrency = concurrency,
40-
consecutive_ids = consecutive_ids,
41-
job_id = job_id,
42-
log_progress = log_progress,
43-
node_labels = node_labels,
44-
relationship_types = relationship_types,
45-
relationship_weight_property = relationship_weight_property,
46-
seed_property = seed_property,
47-
sudo = sudo,
48-
threshold = threshold,
38+
config = ConfigConverter.convert_to_gds_config(
39+
graph_name=G.name(),
40+
concurrency=concurrency,
41+
consecutive_ids=consecutive_ids,
42+
job_id=job_id,
43+
log_progress=log_progress,
44+
node_labels=node_labels,
45+
relationship_types=relationship_types,
46+
relationship_weight_property=relationship_weight_property,
47+
seed_property=seed_property,
48+
sudo=sudo,
49+
threshold=threshold,
4950
)
5051

5152
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
@@ -79,18 +80,18 @@ def stats(
7980
consecutive_ids: Optional[bool] = None,
8081
relationship_weight_property: Optional[str] = None,
8182
) -> WccStatsResult:
82-
config = ArrowConfigConverter.build_configuration(
83-
G,
84-
concurrency = concurrency,
85-
consecutive_ids = consecutive_ids,
86-
job_id = job_id,
87-
log_progress = log_progress,
88-
node_labels = node_labels,
89-
relationship_types = relationship_types,
90-
relationship_weight_property = relationship_weight_property,
91-
seed_property = seed_property,
92-
sudo = sudo,
93-
threshold = threshold,
83+
config = ConfigConverter.convert_to_gds_config(
84+
graph_name=G.name(),
85+
concurrency=concurrency,
86+
consecutive_ids=consecutive_ids,
87+
job_id=job_id,
88+
log_progress=log_progress,
89+
node_labels=node_labels,
90+
relationship_types=relationship_types,
91+
relationship_weight_property=relationship_weight_property,
92+
seed_property=seed_property,
93+
sudo=sudo,
94+
threshold=threshold,
9495
)
9596

9697
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
@@ -121,19 +122,19 @@ def stream(
121122
consecutive_ids: Optional[bool] = None,
122123
relationship_weight_property: Optional[str] = None,
123124
) -> DataFrame:
124-
config = ArrowConfigConverter.build_configuration(
125-
G,
126-
concurrency = concurrency,
127-
consecutive_ids = consecutive_ids,
128-
job_id = job_id,
129-
log_progress = log_progress,
130-
min_component_size = min_component_size,
131-
node_labels = node_labels,
132-
relationship_types = relationship_types,
133-
relationship_weight_property = relationship_weight_property,
134-
seed_property = seed_property,
135-
sudo = sudo,
136-
threshold = threshold,
125+
config = ConfigConverter.convert_to_gds_config(
126+
graph_name=G.name(),
127+
concurrency=concurrency,
128+
consecutive_ids=consecutive_ids,
129+
job_id=job_id,
130+
log_progress=log_progress,
131+
min_component_size=min_component_size,
132+
node_labels=node_labels,
133+
relationship_types=relationship_types,
134+
relationship_weight_property=relationship_weight_property,
135+
seed_property=seed_property,
136+
sudo=sudo,
137+
threshold=threshold,
137138
)
138139

139140
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
@@ -157,20 +158,19 @@ def write(
157158
relationship_weight_property: Optional[str] = None,
158159
write_concurrency: Optional[int] = None,
159160
) -> WccWriteResult:
160-
161-
config = ArrowConfigConverter.build_configuration(
162-
G,
163-
concurrency = concurrency,
164-
consecutive_ids = consecutive_ids,
165-
job_id = job_id,
166-
log_progress = log_progress,
167-
min_component_size = min_component_size,
168-
node_labels = node_labels,
169-
relationship_types = relationship_types,
170-
relationship_weight_property = relationship_weight_property,
171-
seed_property = seed_property,
172-
sudo = sudo,
173-
threshold = threshold,
161+
config = ConfigConverter.convert_to_gds_config(
162+
graph_name=G.name(),
163+
concurrency=concurrency,
164+
consecutive_ids=consecutive_ids,
165+
job_id=job_id,
166+
log_progress=log_progress,
167+
min_component_size=min_component_size,
168+
node_labels=node_labels,
169+
relationship_types=relationship_types,
170+
relationship_weight_property=relationship_weight_property,
171+
seed_property=seed_property,
172+
sudo=sudo,
173+
threshold=threshold,
174174
)
175175

176176
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
@@ -192,4 +192,4 @@ def write(
192192
computation_result["postProcessingMillis"],
193193
computation_result["nodePropertiesWritten"],
194194
computation_result["configuration"],
195-
)
195+
)
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,33 @@
1-
from typing import Optional, Any, Dict
1+
from typing import Any, Dict, Optional
22

3-
from graphdatascience import Graph
4-
5-
6-
class ArrowConfigConverter:
73

4+
class ConfigConverter:
    """Convert Python-style keyword arguments into a GDS endpoint configuration.

    Keys are translated from snake_case to camelCase (recursively for nested
    dictionaries) and entries whose value is ``None`` are dropped, so callers
    can pass optional parameters straight through from their signatures.
    """

    @staticmethod
    def convert_to_gds_config(**kwargs: Any) -> dict[str, Any]:
        """Build a GDS config dict from keyword arguments.

        ``None``-valued arguments are omitted; all keys, including those in
        nested dictionaries, are converted to camelCase.

        Returns:
            A new dict suitable for passing to a GDS endpoint.
        """
        # The processed kwargs already form the complete config; no need to
        # build an empty dict and update it.
        return ConfigConverter._process_dict_values(kwargs)

    @staticmethod
    def _convert_to_camel_case(name: str) -> str:
        """Convert a snake_case string to camelCase."""
        first, *rest = name.split("_")
        return first.lower() + "".join(word.capitalize() for word in rest)

    @staticmethod
    def _process_dict_values(input_dict: dict[str, Any]) -> dict[str, Any]:
        """Camel-case all keys and drop ``None`` values, recursing into nested dicts."""
        result: dict[str, Any] = {}
        for key, value in input_dict.items():
            if value is None:
                # Unset optional parameters must not appear in the config.
                continue
            camel_key = ConfigConverter._convert_to_camel_case(key)
            if isinstance(value, dict):
                result[camel_key] = ConfigConverter._process_dict_values(value)
            else:
                result[camel_key] = value
        return result

0 commit comments

Comments
 (0)