Skip to content

Commit 1f568b9

Browse files
committed
Add implementation for Wcc.estimate for Arrow endpoints
1 parent 7a09284 commit 1f568b9

File tree

8 files changed

+64
-38
lines changed

8 files changed

+64
-38
lines changed

graphdatascience/procedure_surface/api/wcc_endpoints.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def write(
246246
@abstractmethod
247247
def estimate(
248248
self,
249-
graph_name: Optional[str] = None,
249+
G: Optional[Graph] = None,
250250
projection_config: Optional[dict[str, Any]] = None,
251251
) -> EstimationResult:
252252
"""
@@ -259,8 +259,8 @@ def estimate(
259259
260260
Parameters
261261
----------
262-
graph_name : Optional[str], optional
263-
Name of the graph to be used in the estimation
262+
G : Optional[Graph], optional
263+
The graph to be used in the estimation
264264
projection_config : Optional[dict[str, Any]], optional
265265
Configuration dictionary for the projection.
266266

graphdatascience/procedure_surface/arrow/arrow_wcc_endpoints.py renamed to graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import json
12
from typing import Any, List, Optional
23

34
from pandas import DataFrame
45

56
from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
7+
from ...arrow_client.v2.data_mapper_utils import deserialize_single
68
from ...arrow_client.v2.job_client import JobClient
79
from ...arrow_client.v2.mutation_client import MutationClient
810
from ...arrow_client.v2.write_back_client import WriteBackClient
@@ -175,6 +177,15 @@ def write(
175177
return WccWriteResult(**computation_result)
176178

177179
def estimate(
178-
self, graph_name: Optional[str] = None, projection_config: Optional[dict[str, Any]] = None
180+
self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
179181
) -> EstimationResult:
180-
pass
182+
if G is not None:
183+
payload = {"graphName": G.name()}
184+
elif projection_config is not None:
185+
payload = projection_config
186+
else:
187+
raise ValueError("Either graph_name or projection_config must be provided.")
188+
189+
res = self._arrow_client.do_action_with_retry("v2/community.wcc.estimate", json.dumps(payload).encode("utf-8"))
190+
191+
return EstimationResult(**deserialize_single(res))

graphdatascience/procedure_surface/cypher/wcc_cypher_endpoints.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import OrderedDict
12
from typing import Any, List, Optional, Union
23

34
from pandas import DataFrame
@@ -177,18 +178,20 @@ def write(
177178
return WccWriteResult(**result.to_dict())
178179

179180
def estimate(
180-
self, graph_name: Optional[str] = None, projection_config: Optional[dict[str, Any]] = None
181+
self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
181182
) -> EstimationResult:
182-
config: Union[str, dict[str, Any]] = {}
183+
config: Union[dict[str, Any]] = OrderedDict()
183184

184-
if graph_name is not None:
185-
config = graph_name
185+
if G is not None:
186+
config["graphNameOrConfiguration"] = G.name()
186187
elif projection_config is not None:
187-
config = projection_config
188+
config["graphNameOrConfiguration"] = projection_config
188189
else:
189190
raise ValueError("Either graph_name or projection_config must be provided.")
190191

191-
params = CallParameters(config=config)
192+
config["algoConfig"] = {}
193+
194+
params = CallParameters(**config)
192195

193196
result = self._query_runner.call_procedure(endpoint="gds.wcc.stats.estimate", params=params).squeeze()
194197

graphdatascience/tests/integrationV2/procedure_surface/arrow/conftest.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pytest
66
from testcontainers.core.container import DockerContainer
77
from testcontainers.core.waiting_utils import wait_for_logs
8-
from testcontainers.neo4j import Neo4jContainer
98

109
from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication
1110
from graphdatascience.arrow_client.arrow_info import ArrowInfo
@@ -28,27 +27,12 @@ def password_file() -> Generator[str, None, None]:
2827
os.rmdir(temp_dir)
2928

3029

31-
@pytest.fixture(scope="session")
32-
def neo4j_database_container() -> Generator[Neo4jContainer, None, None]:
33-
neo4j_image = os.getenv("NEO4J_DATABASE_IMAGE", "neo4j:5.11-enterprise")
34-
35-
neo4j_container = (
36-
Neo4jContainer(
37-
image=neo4j_image,
38-
)
39-
.with_env("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes")
40-
.with_env("NEO4J_dbms_security_procedures_unrestricted", "gds.*")
41-
.with_env("NEO4J_dbms_security_procedures_allowlist", "gds.*")
42-
)
43-
44-
with neo4j_container as neo4j_db:
45-
wait_for_logs(neo4j_db)
46-
yield neo4j_db
47-
48-
4930
@pytest.fixture(scope="session")
5031
def session_container(password_file: str) -> Generator[DockerContainer, None, None]:
51-
session_image = os.getenv("GDS_SESSION_IMAGE")
32+
session_image = os.getenv(
33+
"GDS_SESSION_IMAGE",
34+
"europe-west1-docker.pkg.dev/aura-docker-images/aura/gds-session:97ac47f7928c0533b3099539f0f5b3058a52c203",
35+
)
5236

5337
if session_image is None:
5438
raise ValueError("GDS_SESSION_IMAGE environment variable is not set")

graphdatascience/tests/integrationV2/procedure_surface/arrow/test_wcc_arrow_endpoints.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from graphdatascience import Graph
77
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
8-
from graphdatascience.procedure_surface.arrow.arrow_wcc_endpoints import WccArrowEndpoints
8+
from graphdatascience.procedure_surface.arrow.wcc_arrow_endpoints import WccArrowEndpoints
99

1010

1111
class MockGraph(Graph):
@@ -71,3 +71,17 @@ def test_wcc_mutate(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph) -> No
7171
assert result.post_processing_millis >= 0
7272
assert result.mutate_millis >= 0
7373
assert result.node_properties_written == 3
74+
75+
76+
def test_wcc_estimate(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph) -> None:
77+
result = wcc_endpoints.estimate(sample_graph)
78+
79+
assert result.node_count == 3
80+
assert result.relationship_count == 1
81+
assert "Bytes" in result.required_memory
82+
# assert result.tree_view > 0
83+
# assert result.map_view > 0
84+
assert result.bytes_min > 0
85+
assert result.bytes_max > 0
86+
assert result.heap_percentage_min > 0
87+
assert result.heap_percentage_max > 0

graphdatascience/tests/integrationV2/procedure_surface/cypher/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
@pytest.fixture(scope="session")
14-
def neo4j_database_container() -> Generator[Neo4jContainer, None, None]:
14+
def gds_plugin_container() -> Generator[Neo4jContainer, None, None]:
1515
neo4j_image = os.getenv("NEO4J_DATABASE_IMAGE", "neo4j:enterprise")
1616

1717
neo4j_container = (
@@ -28,8 +28,8 @@ def neo4j_database_container() -> Generator[Neo4jContainer, None, None]:
2828

2929

3030
@pytest.fixture
31-
def query_runner(neo4j_database_container: DockerContainer) -> Generator[QueryRunner, None, None]:
31+
def query_runner(gds_plugin_container: DockerContainer) -> Generator[QueryRunner, None, None]:
3232
yield Neo4jQueryRunner.create_for_db(
33-
f"bolt://localhost:{neo4j_database_container.get_exposed_port(7687)}",
33+
f"bolt://localhost:{gds_plugin_container.get_exposed_port(7687)}",
3434
("neo4j", "password"),
3535
)

graphdatascience/tests/integrationV2/procedure_surface/cypher/test_wcc_cypher_endpoints.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from graphdatascience import Graph, QueryRunner
6-
from graphdatascience.procedure_surface.arrow.arrow_wcc_endpoints import WccArrowEndpoints
6+
from graphdatascience.procedure_surface.arrow.wcc_arrow_endpoints import WccArrowEndpoints
77
from graphdatascience.procedure_surface.cypher.wcc_cypher_endpoints import WccCypherEndpoints
88

99

@@ -73,3 +73,17 @@ def test_wcc_mutate(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph) -> No
7373
assert result.post_processing_millis >= 0
7474
assert result.mutate_millis >= 0
7575
assert result.node_properties_written == 3
76+
77+
78+
def test_wcc_estimate(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph) -> None:
79+
result = wcc_endpoints.estimate(sample_graph)
80+
81+
assert result.node_count == 3
82+
assert result.relationship_count == 1
83+
assert "Bytes" in result.required_memory
84+
# assert result.tree_view > 0
85+
# assert result.map_view > 0
86+
assert result.bytes_min > 0
87+
assert result.bytes_max > 0
88+
assert result.heap_percentage_min > 0
89+
assert result.heap_percentage_max > 0

graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def test_estimate_with_graph_name(graph: Graph) -> None:
337337

338338
query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.stats.estimate": pd.DataFrame([result])})
339339

340-
estimate = WccCypherEndpoints(query_runner).estimate(graph_name=graph.name())
340+
estimate = WccCypherEndpoints(query_runner).estimate(G=graph)
341341

342342
assert estimate.node_count == 100
343343
assert estimate.relationship_count == 200
@@ -348,7 +348,7 @@ def test_estimate_with_graph_name(graph: Graph) -> None:
348348
assert len(query_runner.queries) == 1
349349
assert "gds.wcc.stats.estimate" in query_runner.queries[0]
350350
params = query_runner.params[0]
351-
assert params["config"] == "test_graph"
351+
assert params["graphNameOrConfiguration"] == "test_graph"
352352

353353

354354
def test_estimate_with_projection_config() -> None:

0 commit comments

Comments
 (0)