Skip to content

Commit 6b131c6

Browse files
committed
Add configuration for testcontainers and gds sessions
1 parent 713cc5c commit 6b131c6

File tree

6 files changed

+277
-0
lines changed

6 files changed

+277
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from typing import List, Optional, Any
2+
3+
from pandas import DataFrame
4+
5+
from ..api.estimation_result import EstimationResult
6+
from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
7+
from ...arrow_client.v2.job_client import JobClient
8+
from ...arrow_client.v2.mutation_client import MutationClient
9+
from ...arrow_client.v2.write_back_client import WriteBackClient
10+
from ...graph.graph_object import Graph
11+
from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult
12+
from ..utils.config_converter import ConfigConverter
13+
14+
WCC_ENDPOINT = "v2/community.wcc"
15+
16+
17+
class WccArrowEndpoints(WccEndpoints):
18+
def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[WriteBackClient] = None):
19+
self._arrow_client = arrow_client
20+
self._write_back_client = write_back_client
21+
22+
def mutate(
23+
self,
24+
G: Graph,
25+
mutate_property: str,
26+
threshold: Optional[float] = None,
27+
relationship_types: Optional[List[str]] = None,
28+
node_labels: Optional[List[str]] = None,
29+
sudo: Optional[bool] = None,
30+
log_progress: Optional[bool] = None,
31+
username: Optional[str] = None,
32+
concurrency: Optional[int] = None,
33+
job_id: Optional[str] = None,
34+
seed_property: Optional[str] = None,
35+
consecutive_ids: Optional[bool] = None,
36+
relationship_weight_property: Optional[str] = None,
37+
) -> WccMutateResult:
38+
config = ConfigConverter.convert_to_gds_config(
39+
graph_name=G.name(),
40+
concurrency=concurrency,
41+
consecutive_ids=consecutive_ids,
42+
job_id=job_id,
43+
log_progress=log_progress,
44+
node_labels=node_labels,
45+
relationship_types=relationship_types,
46+
relationship_weight_property=relationship_weight_property,
47+
seed_property=seed_property,
48+
sudo=sudo,
49+
threshold=threshold,
50+
)
51+
52+
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
53+
54+
mutate_result = MutationClient.mutate_node_property(self._arrow_client, job_id, mutate_property)
55+
computation_result = JobClient.get_summary(self._arrow_client, job_id)
56+
57+
computation_result["nodePropertiesWritten"] = mutate_result.node_properties_written
58+
computation_result["mutateMillis"] = 0
59+
60+
return WccMutateResult(**computation_result)
61+
62+
def stats(
63+
self,
64+
G: Graph,
65+
threshold: Optional[float] = None,
66+
relationship_types: Optional[List[str]] = None,
67+
node_labels: Optional[List[str]] = None,
68+
sudo: Optional[bool] = None,
69+
log_progress: Optional[bool] = None,
70+
username: Optional[str] = None,
71+
concurrency: Optional[int] = None,
72+
job_id: Optional[str] = None,
73+
seed_property: Optional[str] = None,
74+
consecutive_ids: Optional[bool] = None,
75+
relationship_weight_property: Optional[str] = None,
76+
) -> WccStatsResult:
77+
config = ConfigConverter.convert_to_gds_config(
78+
graph_name=G.name(),
79+
concurrency=concurrency,
80+
consecutive_ids=consecutive_ids,
81+
job_id=job_id,
82+
log_progress=log_progress,
83+
node_labels=node_labels,
84+
relationship_types=relationship_types,
85+
relationship_weight_property=relationship_weight_property,
86+
seed_property=seed_property,
87+
sudo=sudo,
88+
threshold=threshold,
89+
)
90+
91+
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
92+
computation_result = JobClient.get_summary(self._arrow_client, job_id)
93+
94+
return WccStatsResult(**computation_result)
95+
96+
def stream(
97+
self,
98+
G: Graph,
99+
min_component_size: Optional[int] = None,
100+
threshold: Optional[float] = None,
101+
relationship_types: Optional[List[str]] = None,
102+
node_labels: Optional[List[str]] = None,
103+
sudo: Optional[bool] = None,
104+
log_progress: Optional[bool] = None,
105+
username: Optional[str] = None,
106+
concurrency: Optional[int] = None,
107+
job_id: Optional[str] = None,
108+
seed_property: Optional[str] = None,
109+
consecutive_ids: Optional[bool] = None,
110+
relationship_weight_property: Optional[str] = None,
111+
) -> DataFrame:
112+
config = ConfigConverter.convert_to_gds_config(
113+
graph_name=G.name(),
114+
concurrency=concurrency,
115+
consecutive_ids=consecutive_ids,
116+
job_id=job_id,
117+
log_progress=log_progress,
118+
min_component_size=min_component_size,
119+
node_labels=node_labels,
120+
relationship_types=relationship_types,
121+
relationship_weight_property=relationship_weight_property,
122+
seed_property=seed_property,
123+
sudo=sudo,
124+
threshold=threshold,
125+
)
126+
127+
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
128+
return JobClient.stream_results(self._arrow_client, G.name(), job_id)
129+
130+
def write(
131+
self,
132+
G: Graph,
133+
write_property: str,
134+
min_component_size: Optional[int] = None,
135+
threshold: Optional[float] = None,
136+
relationship_types: Optional[List[str]] = None,
137+
node_labels: Optional[List[str]] = None,
138+
sudo: Optional[bool] = None,
139+
log_progress: Optional[bool] = None,
140+
username: Optional[str] = None,
141+
concurrency: Optional[int] = None,
142+
job_id: Optional[str] = None,
143+
seed_property: Optional[str] = None,
144+
consecutive_ids: Optional[bool] = None,
145+
relationship_weight_property: Optional[str] = None,
146+
write_concurrency: Optional[int] = None,
147+
) -> WccWriteResult:
148+
config = ConfigConverter.convert_to_gds_config(
149+
graph_name=G.name(),
150+
concurrency=concurrency,
151+
consecutive_ids=consecutive_ids,
152+
job_id=job_id,
153+
log_progress=log_progress,
154+
min_component_size=min_component_size,
155+
node_labels=node_labels,
156+
relationship_types=relationship_types,
157+
relationship_weight_property=relationship_weight_property,
158+
seed_property=seed_property,
159+
sudo=sudo,
160+
threshold=threshold,
161+
)
162+
163+
job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
164+
computation_result = JobClient.get_summary(self._arrow_client, job_id)
165+
166+
if self._write_back_client is None:
167+
raise Exception("Write back client is not initialized")
168+
169+
write_millis = self._write_back_client.write(
170+
G.name(), job_id, write_concurrency if write_concurrency is not None else concurrency
171+
)
172+
173+
computation_result["writeMillis"] = write_millis
174+
175+
return WccWriteResult(**computation_result)
176+
177+
def estimate(self, graph_name: Optional[str] = None,
178+
projection_config: Optional[dict[str, Any]] = None) -> EstimationResult:
179+
pass
180+
181+

graphdatascience/tests/integrationV2/__init__.py

Whitespace-only changes.

graphdatascience/tests/integrationV2/procedure_surface/__init__.py

Whitespace-only changes.

graphdatascience/tests/integrationV2/procedure_surface/arrow/__init__.py

Whitespace-only changes.
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import os
2+
import tempfile
3+
from typing import Generator
4+
5+
import pytest
6+
from testcontainers.core.container import DockerContainer
7+
from testcontainers.core.waiting_utils import wait_for_logs
8+
from testcontainers.neo4j import Neo4jContainer
9+
10+
from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication
11+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
12+
from graphdatascience.arrow_client.v2.write_back_client import WriteBackClient
13+
from graphdatascience.query_runner.arrow_info import ArrowInfo
14+
15+
16+
@pytest.fixture(scope="session")
17+
def password_file():
18+
"""Create a temporary file and return its path."""
19+
temp_dir = tempfile.mkdtemp()
20+
temp_file_path = os.path.join(temp_dir, "password")
21+
22+
with open(temp_file_path, "w") as f:
23+
f.write("password")
24+
25+
yield temp_dir
26+
27+
# Clean up the file and directory
28+
os.unlink(temp_file_path)
29+
os.rmdir(temp_dir)
30+
31+
32+
@pytest.fixture(scope="session")
33+
def neo4j_database_container() -> Generator[Neo4jContainer, None, None]:
34+
neo4j_image = os.getenv("NEO4J_DATABASE_IMAGE", "neo4j:5.11-enterprise")
35+
36+
neo4j_container = (
37+
Neo4jContainer(
38+
image=neo4j_image,
39+
)
40+
.with_env("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes")
41+
.with_env("NEO4J_dbms_security_procedures_unrestricted", "gds.*")
42+
.with_env("NEO4J_dbms_security_procedures_allowlist", "gds.*")
43+
)
44+
45+
with neo4j_container as neo4j_db:
46+
wait_for_logs(neo4j_db)
47+
yield neo4j_db
48+
49+
50+
@pytest.fixture(scope="session")
51+
def session_container(password_file: str) -> Generator[DockerContainer, None, None]:
52+
session_image = os.getenv("GDS_SESSION_IMAGE")
53+
54+
if session_image is None:
55+
raise ValueError("GDS_SESSION_IMAGE environment variable is not set")
56+
57+
session_container = (
58+
DockerContainer(
59+
image=session_image,
60+
)
61+
.with_env("ALLOW_LIST", "DEFAULT")
62+
.with_env("DNS_NAME", "gds-session")
63+
.with_env("PAGE_CACHE_SIZE", "100M")
64+
.with_exposed_ports(8491)
65+
.with_network_aliases(["gds-session"])
66+
.with_volume_mapping(password_file, "/passwords")
67+
)
68+
69+
with session_container as session_container:
70+
wait_for_logs(session_container, "Running GDS tasks: 0")
71+
yield session_container
72+
stdout, stderr = session_container.get_logs()
73+
print(stdout)
74+
75+
76+
@pytest.fixture
77+
def arrow_client(session_container: DockerContainer) -> AuthenticatedArrowClient:
78+
"""Create an authenticated Arrow client connected to the session container."""
79+
host = session_container.get_container_host_ip()
80+
port = session_container.get_exposed_port(8491)
81+
82+
return AuthenticatedArrowClient.create(
83+
arrow_info=ArrowInfo(f"{host}:{port}", True, True, ["v1", "v2"]),
84+
auth=UsernamePasswordAuthentication("neo4j", "password"),
85+
encrypted=False,
86+
)
87+
88+
89+
@pytest.fixture
90+
def write_back_client(neo4j_session_container: DockerContainer) -> WriteBackClient:
91+
"""Create a write-back client for the session container."""
92+
host = neo4j_session_container.get_container_host_ip()
93+
port = neo4j_session_container.get_exposed_port(8491)
94+
95+
return WriteBackClient(host=host, port=port, username="neo4j", password="test_password")

requirements/dev/dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ enum-tools[sphinx] == 0.12.0
99
types-requests
1010
types-tqdm
1111
types-python-dateutil
12+
testcontainers >= 4.0

0 commit comments

Comments
 (0)