Skip to content

Commit ae2f8cb

Browse files
committed
Add feature flag to enabled explicit APIs
1 parent ed37a4d commit ae2f8cb

File tree

8 files changed

+158
-79
lines changed

8 files changed

+158
-79
lines changed

graphdatascience/graph_data_science.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
import warnings
45
from types import TracebackType
56
from typing import Any, Optional, Type, Union
@@ -10,12 +11,13 @@
1011

1112
from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication
1213
from graphdatascience.arrow_client.arrow_info import ArrowInfo
13-
from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints
14+
from graphdatascience.procedure_surface.cypher.wcc_cypher_endpoints import WccCypherEndpoints
1415

1516
from .call_builder import IndirectCallBuilder
1617
from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
1718
from .error.uncallable_namespace import UncallableNamespace
1819
from .graph.graph_proc_runner import GraphProcRunner
20+
from .procedure_surface.api.wcc_endpoints import WccEndpoints
1921
from .query_runner.arrow_query_runner import ArrowQueryRunner
2022
from .query_runner.neo4j_query_runner import Neo4jQueryRunner
2123
from .query_runner.query_runner import QueryRunner
@@ -118,15 +120,20 @@ def __init__(
118120
self._query_runner.set_show_progress(show_progress)
119121
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
120122

121-
self._wcc_endpoints = WccCypherEndpoints(self._query_runner)
123+
self._wcc_endpoints: Optional[WccEndpoints] = None
124+
if os.environ.get("ENABLE_EXPLICIT_ENDPOINTS") is not None:
125+
self._wcc_endpoints = WccCypherEndpoints(self._query_runner)
122126

123127
@property
124128
def graph(self) -> GraphProcRunner:
125129
return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version)
126130

127-
# @property
128-
# def wcc(self) -> WccEndpoints:
129-
# return self._wcc_endpoints
131+
@property
132+
def wcc(self) -> Union[WccEndpoints, IndirectCallBuilder]:
133+
if self._wcc_endpoints is None:
134+
return IndirectCallBuilder(self._query_runner, f"gds.{self._namespace}.wcc", self._server_version)
135+
136+
return self._wcc_endpoints
130137

131138
@property
132139
def util(self) -> UtilProcRunner:

graphdatascience/procedure_surface/api/wcc_endpoints.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,9 @@ class WccMutateResult:
253253
node_properties_written: int
254254
configuration: dict[str, Any]
255255

256+
def __getitem__(self, item: str) -> Any:
257+
return getattr(self, item)
258+
256259

257260
@dataclass(frozen=True, repr=True)
258261
class WccStatsResult:
@@ -263,6 +266,9 @@ class WccStatsResult:
263266
post_processing_millis: int
264267
configuration: dict[str, Any]
265268

269+
def __getitem__(self, item: str) -> Any:
270+
return getattr(self, item)
271+
266272

267273
@dataclass(frozen=True, repr=True)
268274
class WccWriteResult:
@@ -274,3 +280,6 @@ class WccWriteResult:
274280
post_processing_millis: int
275281
node_properties_written: int
276282
configuration: dict[str, Any]
283+
284+
def __getitem__(self, item: str) -> Any:
285+
return getattr(self, item)

graphdatascience/session/aura_graph_data_science.py

Lines changed: 122 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import annotations
22

3+
import os
34
from typing import Any, Callable, Optional, Union
45

56
from pandas import DataFrame
67

78
from graphdatascience import QueryRunner, ServerVersion
89
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
910
from graphdatascience.arrow_client.arrow_info import ArrowInfo
11+
from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient
12+
from graphdatascience.arrow_client.v2.write_back_client import WriteBackClient
1013
from graphdatascience.call_builder import IndirectCallBuilder
1114
from graphdatascience.endpoints import (
1215
AlphaRemoteEndpoints,
@@ -15,6 +18,8 @@
1518
)
1619
from graphdatascience.error.uncallable_namespace import UncallableNamespace
1720
from graphdatascience.graph.graph_remote_proc_runner import GraphRemoteProcRunner
21+
from graphdatascience.procedure_surface.api.wcc_endpoints import WccEndpoints
22+
from graphdatascience.procedure_surface.arrow.wcc_arrow_endpoints import WccArrowEndpoints
1823
from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
1924
from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient
2025
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
@@ -24,15 +29,11 @@
2429
from graphdatascience.utils.util_remote_proc_runner import UtilRemoteProcRunner
2530

2631

27-
class AuraGraphDataScience(DirectEndpoints, UncallableNamespace):
28-
"""
29-
Primary API class for interacting with Neo4j database + Graph Data Science Session.
30-
Always bind this object to a variable called `gds`.
31-
"""
32+
class AuraGraphDataScienceFactory:
33+
"""Factory class for creating AuraGraphDataScience instances with all required components."""
3234

33-
@classmethod
34-
def create(
35-
cls,
35+
def __init__(
36+
self,
3637
session_bolt_connection_info: DbmsConnectionInfo,
3738
arrow_authentication: Optional[ArrowAuthentication],
3839
db_endpoint: Optional[Union[Neo4jQueryRunner, DbmsConnectionInfo]],
@@ -41,74 +42,114 @@ def create(
4142
arrow_tls_root_certs: Optional[bytes] = None,
4243
bookmarks: Optional[Any] = None,
4344
show_progress: bool = True,
44-
) -> AuraGraphDataScience:
45-
session_bolt_query_runner = Neo4jQueryRunner.create_for_session(
46-
endpoint=session_bolt_connection_info.uri,
47-
auth=session_bolt_connection_info.get_auth(),
48-
show_progress=show_progress,
45+
):
46+
self.session_bolt_connection_info = session_bolt_connection_info
47+
self.arrow_authentication = arrow_authentication
48+
self.db_endpoint = db_endpoint
49+
self.delete_fn = delete_fn
50+
self.arrow_disable_server_verification = arrow_disable_server_verification
51+
self.arrow_tls_root_certs = arrow_tls_root_certs
52+
self.bookmarks = bookmarks
53+
self.show_progress = show_progress
54+
55+
def create(self) -> AuraGraphDataScience:
56+
"""Create and configure an AuraGraphDataScience instance."""
57+
session_bolt_query_runner = self._create_session_bolt_query_runner()
58+
arrow_info = ArrowInfo.create(session_bolt_query_runner)
59+
session_arrow_query_runner = self._create_session_arrow_query_runner(session_bolt_query_runner, arrow_info)
60+
session_arrow_client = self._create_session_arrow_client(arrow_info, session_bolt_query_runner)
61+
gds_version = session_bolt_query_runner.server_version()
62+
63+
session_query_runner: QueryRunner
64+
65+
if self.db_endpoint is not None:
66+
db_bolt_query_runner = self._create_db_bolt_query_runner()
67+
session_query_runner = SessionQueryRunner.create(
68+
session_arrow_query_runner, db_bolt_query_runner, session_arrow_client, self.show_progress
69+
)
70+
wcc_endpoints = self._create_wcc_endpoints(arrow_info, session_bolt_query_runner, db_bolt_query_runner)
71+
else:
72+
session_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner)
73+
wcc_endpoints = self._create_wcc_endpoints(arrow_info, session_bolt_query_runner, None)
74+
75+
return AuraGraphDataScience(
76+
query_runner=session_query_runner,
77+
wcc_endpoints=wcc_endpoints,
78+
delete_fn=self.delete_fn,
79+
gds_version=gds_version,
4980
)
5081

51-
arrow_info = ArrowInfo.create(session_bolt_query_runner)
52-
session_arrow_query_runner = ArrowQueryRunner.create(
82+
def _create_session_bolt_query_runner(self) -> Neo4jQueryRunner:
83+
return Neo4jQueryRunner.create_for_session(
84+
endpoint=self.session_bolt_connection_info.uri,
85+
auth=self.session_bolt_connection_info.get_auth(),
86+
show_progress=self.show_progress,
87+
)
88+
89+
def _create_session_arrow_query_runner(
90+
self, session_bolt_query_runner: Neo4jQueryRunner, arrow_info: ArrowInfo
91+
) -> ArrowQueryRunner:
92+
return ArrowQueryRunner.create(
5393
fallback_query_runner=session_bolt_query_runner,
5494
arrow_info=arrow_info,
55-
arrow_authentication=arrow_authentication,
95+
arrow_authentication=self.arrow_authentication,
5696
encrypted=session_bolt_query_runner.encrypted(),
57-
disable_server_verification=arrow_disable_server_verification,
58-
tls_root_certs=arrow_tls_root_certs,
97+
disable_server_verification=self.arrow_disable_server_verification,
98+
tls_root_certs=self.arrow_tls_root_certs,
5999
)
60100

61-
# TODO: merge with the gds_arrow_client created inside ArrowQueryRunner
62-
session_arrow_client = GdsArrowClient.create(
101+
def _create_session_arrow_client(
102+
self, arrow_info: ArrowInfo, session_bolt_query_runner: Neo4jQueryRunner
103+
) -> GdsArrowClient:
104+
return GdsArrowClient.create(
63105
arrow_info,
64-
arrow_authentication,
106+
self.arrow_authentication,
65107
session_bolt_query_runner.encrypted(),
66-
arrow_disable_server_verification,
67-
arrow_tls_root_certs,
108+
self.arrow_disable_server_verification,
109+
self.arrow_tls_root_certs,
68110
)
69111

70-
gds_version = session_bolt_query_runner.server_version()
71-
72-
if db_endpoint is not None:
73-
if isinstance(db_endpoint, Neo4jQueryRunner):
74-
db_bolt_query_runner = db_endpoint
75-
else:
76-
db_bolt_query_runner = Neo4jQueryRunner.create_for_db(
77-
db_endpoint.uri,
78-
db_endpoint.get_auth(),
79-
aura_ds=True,
80-
show_progress=False,
81-
database=db_endpoint.database,
82-
)
83-
db_bolt_query_runner.set_bookmarks(bookmarks)
84-
85-
session_query_runner = SessionQueryRunner.create(
86-
session_arrow_query_runner, db_bolt_query_runner, session_arrow_client, show_progress
87-
)
88-
return cls(
89-
query_runner=session_query_runner,
90-
delete_fn=delete_fn,
91-
gds_version=gds_version,
112+
def _create_db_bolt_query_runner(self) -> Neo4jQueryRunner:
113+
if isinstance(self.db_endpoint, Neo4jQueryRunner):
114+
db_bolt_query_runner = self.db_endpoint
115+
elif isinstance(self.db_endpoint, DbmsConnectionInfo):
116+
db_bolt_query_runner = Neo4jQueryRunner.create_for_db(
117+
self.db_endpoint.uri,
118+
self.db_endpoint.get_auth(),
119+
aura_ds=True,
120+
show_progress=False,
121+
database=self.db_endpoint.database,
92122
)
93123
else:
94-
standalone_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner)
95-
return cls(
96-
query_runner=standalone_query_runner,
97-
delete_fn=delete_fn,
98-
gds_version=gds_version,
124+
raise ValueError("db_endpoint must be a Neo4jQueryRunner or a DbmsConnectionInfo")
125+
126+
db_bolt_query_runner.set_bookmarks(self.bookmarks)
127+
return db_bolt_query_runner
128+
129+
def _create_wcc_endpoints(
130+
self, arrow_info: ArrowInfo, session_bolt_query_runner: Neo4jQueryRunner, db_query_runner: Optional[QueryRunner]
131+
) -> Optional[WccEndpoints]:
132+
wcc_endpoints: Optional[WccEndpoints] = None
133+
if os.environ.get("ENABLE_EXPLICIT_ENDPOINTS") is not None:
134+
arrow_client = AuthenticatedArrowClient.create(
135+
arrow_info,
136+
self.arrow_authentication,
137+
session_bolt_query_runner.encrypted(),
138+
self.arrow_disable_server_verification,
139+
self.arrow_tls_root_certs,
99140
)
100141

101-
def __init__(
102-
self,
103-
query_runner: QueryRunner,
104-
delete_fn: Callable[[], bool],
105-
gds_version: ServerVersion,
106-
):
107-
self._query_runner = query_runner
108-
self._delete_fn = delete_fn
109-
self._server_version = gds_version
142+
write_back_client = WriteBackClient(arrow_client, db_query_runner) if db_query_runner is not None else None
110143

111-
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
144+
wcc_endpoints = WccArrowEndpoints(arrow_client, write_back_client)
145+
return wcc_endpoints
146+
147+
148+
class AuraGraphDataScience(DirectEndpoints, UncallableNamespace):
149+
"""
150+
Primary API class for interacting with Neo4j database + Graph Data Science Session.
151+
Always bind this object to a variable called `gds`.
152+
"""
112153

113154
def run_cypher(
114155
self,
@@ -133,6 +174,20 @@ def run_cypher(
133174
"""
134175
return self._query_runner.run_cypher(query, params, database, False)
135176

177+
def __init__(
178+
self,
179+
query_runner: QueryRunner,
180+
delete_fn: Callable[[], bool],
181+
gds_version: ServerVersion,
182+
wcc_endpoints: Optional[WccEndpoints] = None,
183+
):
184+
self._query_runner = query_runner
185+
self._delete_fn = delete_fn
186+
self._server_version = gds_version
187+
self._wcc_endpoints = wcc_endpoints
188+
189+
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
190+
136191
@property
137192
def graph(self) -> GraphRemoteProcRunner:
138193
return GraphRemoteProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version)
@@ -149,6 +204,13 @@ def alpha(self) -> AlphaRemoteEndpoints:
149204
def beta(self) -> BetaEndpoints:
150205
return BetaEndpoints(self._query_runner, "gds.beta", self._server_version)
151206

207+
@property
208+
def wcc(self) -> Union[WccEndpoints, IndirectCallBuilder]:
209+
if self._wcc_endpoints is None:
210+
return IndirectCallBuilder(self._query_runner, f"gds.{self._namespace}.wcc", self._server_version)
211+
212+
return self._wcc_endpoints
213+
152214
def __getattr__(self, attr: str) -> IndirectCallBuilder:
153215
return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version)
154216

graphdatascience/session/dedicated_sessions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from graphdatascience.session.aura_api import AuraApi
1212
from graphdatascience.session.aura_api_responses import SessionDetails
1313
from graphdatascience.session.aura_api_token_authentication import AuraApiTokenAuthentication
14-
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
14+
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience, AuraGraphDataScienceFactory
1515
from graphdatascience.session.cloud_location import CloudLocation
1616
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1717
from graphdatascience.session.session_info import SessionInfo
@@ -210,9 +210,9 @@ def _construct_client(
210210
arrow_authentication: ArrowAuthentication,
211211
db_runner: Optional[Neo4jQueryRunner],
212212
) -> AuraGraphDataScience:
213-
return AuraGraphDataScience.create(
213+
return AuraGraphDataScienceFactory(
214214
session_bolt_connection_info=session_bolt_connection_info,
215215
arrow_authentication=arrow_authentication,
216216
db_endpoint=db_runner,
217217
delete_fn=lambda: self._aura_api.delete_session(session_id=session_id),
218-
)
218+
).create()

graphdatascience/tests/integration/conftest.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from graphdatascience.graph_data_science import GraphDataScience
1111
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
1212
from graphdatascience.server_version.server_version import ServerVersion
13-
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
13+
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience, AuraGraphDataScienceFactory
1414
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1515

1616
URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
@@ -92,12 +92,13 @@ def gds_without_arrow() -> Generator[GraphDataScience, None, None]:
9292

9393
@pytest.fixture(scope="package", autouse=False)
9494
def gds_with_cloud_setup(request: pytest.FixtureRequest) -> Generator[AuraGraphDataScience, None, None]:
95-
_gds = AuraGraphDataScience.create(
95+
_gds = AuraGraphDataScienceFactory(
9696
session_bolt_connection_info=DbmsConnectionInfo(URI, AUTH[0], AUTH[1]),
9797
arrow_authentication=UsernamePasswordAuthentication(AUTH[0], AUTH[1]),
9898
db_endpoint=DbmsConnectionInfo(AURA_DB_URI, AURA_DB_AUTH[0], AURA_DB_AUTH[1]),
9999
delete_fn=lambda: True,
100-
)
100+
).create()
101+
101102
_gds.set_database(DB)
102103

103104
yield _gds
@@ -107,12 +108,12 @@ def gds_with_cloud_setup(request: pytest.FixtureRequest) -> Generator[AuraGraphD
107108

108109
@pytest.fixture(scope="package", autouse=False)
109110
def standalone_aura_gds() -> Generator[AuraGraphDataScience, None, None]:
110-
_gds = AuraGraphDataScience.create(
111+
_gds = AuraGraphDataScienceFactory(
111112
session_bolt_connection_info=DbmsConnectionInfo(URI, AUTH[0], AUTH[1]),
112113
arrow_authentication=UsernamePasswordAuthentication(AUTH[0], AUTH[1]),
113114
db_endpoint=None,
114115
delete_fn=lambda: True,
115-
)
116+
).create()
116117

117118
yield _gds
118119

graphdatascience/tests/integration/test_remote_graph_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_remote_projection_and_writeback_custom_database_name(gds_with_cloud_set
6262
assert projection_result["nodeCount"] == 2
6363
assert projection_result["relationshipCount"] == 1
6464

65-
write_result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc")
65+
write_result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc") # type: ignore
6666

6767
assert write_result["nodePropertiesWritten"] == 2
6868
count_wcc_nodes_query = "MATCH (n WHERE n.wcc IS NOT NULL) RETURN count(*) AS c"
@@ -234,6 +234,6 @@ def test_empty_graph_write_back(
234234

235235
assert G.node_count() == 0
236236

237-
result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc")
237+
result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc") # type: ignore
238238

239239
assert result["nodePropertiesWritten"] == 0

0 commit comments

Comments
 (0)