1
1
from __future__ import annotations
2
2
3
+ import os
3
4
from typing import Any , Callable , Optional , Union
4
5
5
6
from pandas import DataFrame
6
7
7
8
from graphdatascience import QueryRunner , ServerVersion
8
9
from graphdatascience .arrow_client .arrow_authentication import ArrowAuthentication
9
10
from graphdatascience .arrow_client .arrow_info import ArrowInfo
11
+ from graphdatascience .arrow_client .authenticated_arrow_client import AuthenticatedArrowClient
12
+ from graphdatascience .arrow_client .v2 .write_back_client import WriteBackClient
10
13
from graphdatascience .call_builder import IndirectCallBuilder
11
14
from graphdatascience .endpoints import (
12
15
AlphaRemoteEndpoints ,
15
18
)
16
19
from graphdatascience .error .uncallable_namespace import UncallableNamespace
17
20
from graphdatascience .graph .graph_remote_proc_runner import GraphRemoteProcRunner
21
+ from graphdatascience .procedure_surface .api .wcc_endpoints import WccEndpoints
22
+ from graphdatascience .procedure_surface .arrow .wcc_arrow_endpoints import WccArrowEndpoints
18
23
from graphdatascience .query_runner .arrow_query_runner import ArrowQueryRunner
19
24
from graphdatascience .query_runner .gds_arrow_client import GdsArrowClient
20
25
from graphdatascience .query_runner .neo4j_query_runner import Neo4jQueryRunner
24
29
from graphdatascience .utils .util_remote_proc_runner import UtilRemoteProcRunner
25
30
26
31
27
- class AuraGraphDataScience (DirectEndpoints , UncallableNamespace ):
28
- """
29
- Primary API class for interacting with Neo4j database + Graph Data Science Session.
30
- Always bind this object to a variable called `gds`.
31
- """
32
+ class AuraGraphDataScienceFactory :
33
+ """Factory class for creating AuraGraphDataScience instances with all required components."""
32
34
33
- @classmethod
34
- def create (
35
- cls ,
35
+ def __init__ (
36
+ self ,
36
37
session_bolt_connection_info : DbmsConnectionInfo ,
37
38
arrow_authentication : Optional [ArrowAuthentication ],
38
39
db_endpoint : Optional [Union [Neo4jQueryRunner , DbmsConnectionInfo ]],
@@ -41,74 +42,114 @@ def create(
41
42
arrow_tls_root_certs : Optional [bytes ] = None ,
42
43
bookmarks : Optional [Any ] = None ,
43
44
show_progress : bool = True ,
44
- ) -> AuraGraphDataScience :
45
- session_bolt_query_runner = Neo4jQueryRunner .create_for_session (
46
- endpoint = session_bolt_connection_info .uri ,
47
- auth = session_bolt_connection_info .get_auth (),
48
- show_progress = show_progress ,
45
+ ):
46
+ self .session_bolt_connection_info = session_bolt_connection_info
47
+ self .arrow_authentication = arrow_authentication
48
+ self .db_endpoint = db_endpoint
49
+ self .delete_fn = delete_fn
50
+ self .arrow_disable_server_verification = arrow_disable_server_verification
51
+ self .arrow_tls_root_certs = arrow_tls_root_certs
52
+ self .bookmarks = bookmarks
53
+ self .show_progress = show_progress
54
+
55
+ def create (self ) -> AuraGraphDataScience :
56
+ """Create and configure an AuraGraphDataScience instance."""
57
+ session_bolt_query_runner = self ._create_session_bolt_query_runner ()
58
+ arrow_info = ArrowInfo .create (session_bolt_query_runner )
59
+ session_arrow_query_runner = self ._create_session_arrow_query_runner (session_bolt_query_runner , arrow_info )
60
+ session_arrow_client = self ._create_session_arrow_client (arrow_info , session_bolt_query_runner )
61
+ gds_version = session_bolt_query_runner .server_version ()
62
+
63
+ session_query_runner : QueryRunner
64
+
65
+ if self .db_endpoint is not None :
66
+ db_bolt_query_runner = self ._create_db_bolt_query_runner ()
67
+ session_query_runner = SessionQueryRunner .create (
68
+ session_arrow_query_runner , db_bolt_query_runner , session_arrow_client , self .show_progress
69
+ )
70
+ wcc_endpoints = self ._create_wcc_endpoints (arrow_info , session_bolt_query_runner , db_bolt_query_runner )
71
+ else :
72
+ session_query_runner = StandaloneSessionQueryRunner (session_arrow_query_runner )
73
+ wcc_endpoints = self ._create_wcc_endpoints (arrow_info , session_bolt_query_runner , None )
74
+
75
+ return AuraGraphDataScience (
76
+ query_runner = session_query_runner ,
77
+ wcc_endpoints = wcc_endpoints ,
78
+ delete_fn = self .delete_fn ,
79
+ gds_version = gds_version ,
49
80
)
50
81
51
- arrow_info = ArrowInfo .create (session_bolt_query_runner )
52
- session_arrow_query_runner = ArrowQueryRunner .create (
82
+ def _create_session_bolt_query_runner (self ) -> Neo4jQueryRunner :
83
+ return Neo4jQueryRunner .create_for_session (
84
+ endpoint = self .session_bolt_connection_info .uri ,
85
+ auth = self .session_bolt_connection_info .get_auth (),
86
+ show_progress = self .show_progress ,
87
+ )
88
+
89
+ def _create_session_arrow_query_runner (
90
+ self , session_bolt_query_runner : Neo4jQueryRunner , arrow_info : ArrowInfo
91
+ ) -> ArrowQueryRunner :
92
+ return ArrowQueryRunner .create (
53
93
fallback_query_runner = session_bolt_query_runner ,
54
94
arrow_info = arrow_info ,
55
- arrow_authentication = arrow_authentication ,
95
+ arrow_authentication = self . arrow_authentication ,
56
96
encrypted = session_bolt_query_runner .encrypted (),
57
- disable_server_verification = arrow_disable_server_verification ,
58
- tls_root_certs = arrow_tls_root_certs ,
97
+ disable_server_verification = self . arrow_disable_server_verification ,
98
+ tls_root_certs = self . arrow_tls_root_certs ,
59
99
)
60
100
61
- # TODO: merge with the gds_arrow_client created inside ArrowQueryRunner
62
- session_arrow_client = GdsArrowClient .create (
101
+ def _create_session_arrow_client (
102
+ self , arrow_info : ArrowInfo , session_bolt_query_runner : Neo4jQueryRunner
103
+ ) -> GdsArrowClient :
104
+ return GdsArrowClient .create (
63
105
arrow_info ,
64
- arrow_authentication ,
106
+ self . arrow_authentication ,
65
107
session_bolt_query_runner .encrypted (),
66
- arrow_disable_server_verification ,
67
- arrow_tls_root_certs ,
108
+ self . arrow_disable_server_verification ,
109
+ self . arrow_tls_root_certs ,
68
110
)
69
111
70
- gds_version = session_bolt_query_runner .server_version ()
71
-
72
- if db_endpoint is not None :
73
- if isinstance (db_endpoint , Neo4jQueryRunner ):
74
- db_bolt_query_runner = db_endpoint
75
- else :
76
- db_bolt_query_runner = Neo4jQueryRunner .create_for_db (
77
- db_endpoint .uri ,
78
- db_endpoint .get_auth (),
79
- aura_ds = True ,
80
- show_progress = False ,
81
- database = db_endpoint .database ,
82
- )
83
- db_bolt_query_runner .set_bookmarks (bookmarks )
84
-
85
- session_query_runner = SessionQueryRunner .create (
86
- session_arrow_query_runner , db_bolt_query_runner , session_arrow_client , show_progress
87
- )
88
- return cls (
89
- query_runner = session_query_runner ,
90
- delete_fn = delete_fn ,
91
- gds_version = gds_version ,
112
+ def _create_db_bolt_query_runner (self ) -> Neo4jQueryRunner :
113
+ if isinstance (self .db_endpoint , Neo4jQueryRunner ):
114
+ db_bolt_query_runner = self .db_endpoint
115
+ elif isinstance (self .db_endpoint , DbmsConnectionInfo ):
116
+ db_bolt_query_runner = Neo4jQueryRunner .create_for_db (
117
+ self .db_endpoint .uri ,
118
+ self .db_endpoint .get_auth (),
119
+ aura_ds = True ,
120
+ show_progress = False ,
121
+ database = self .db_endpoint .database ,
92
122
)
93
123
else :
94
- standalone_query_runner = StandaloneSessionQueryRunner (session_arrow_query_runner )
95
- return cls (
96
- query_runner = standalone_query_runner ,
97
- delete_fn = delete_fn ,
98
- gds_version = gds_version ,
124
+ raise ValueError ("db_endpoint must be a Neo4jQueryRunner or a DbmsConnectionInfo" )
125
+
126
+ db_bolt_query_runner .set_bookmarks (self .bookmarks )
127
+ return db_bolt_query_runner
128
+
129
+ def _create_wcc_endpoints (
130
+ self , arrow_info : ArrowInfo , session_bolt_query_runner : Neo4jQueryRunner , db_query_runner : Optional [QueryRunner ]
131
+ ) -> Optional [WccEndpoints ]:
132
+ wcc_endpoints : Optional [WccEndpoints ] = None
133
+ if os .environ .get ("ENABLE_EXPLICIT_ENDPOINTS" ) is not None :
134
+ arrow_client = AuthenticatedArrowClient .create (
135
+ arrow_info ,
136
+ self .arrow_authentication ,
137
+ session_bolt_query_runner .encrypted (),
138
+ self .arrow_disable_server_verification ,
139
+ self .arrow_tls_root_certs ,
99
140
)
100
141
101
- def __init__ (
102
- self ,
103
- query_runner : QueryRunner ,
104
- delete_fn : Callable [[], bool ],
105
- gds_version : ServerVersion ,
106
- ):
107
- self ._query_runner = query_runner
108
- self ._delete_fn = delete_fn
109
- self ._server_version = gds_version
142
+ write_back_client = WriteBackClient (arrow_client , db_query_runner ) if db_query_runner is not None else None
110
143
111
- super ().__init__ (self ._query_runner , namespace = "gds" , server_version = self ._server_version )
144
+ wcc_endpoints = WccArrowEndpoints (arrow_client , write_back_client )
145
+ return wcc_endpoints
146
+
147
+
148
+ class AuraGraphDataScience (DirectEndpoints , UncallableNamespace ):
149
+ """
150
+ Primary API class for interacting with Neo4j database + Graph Data Science Session.
151
+ Always bind this object to a variable called `gds`.
152
+ """
112
153
113
154
def run_cypher (
114
155
self ,
@@ -133,6 +174,20 @@ def run_cypher(
133
174
"""
134
175
return self ._query_runner .run_cypher (query , params , database , False )
135
176
177
+ def __init__ (
178
+ self ,
179
+ query_runner : QueryRunner ,
180
+ delete_fn : Callable [[], bool ],
181
+ gds_version : ServerVersion ,
182
+ wcc_endpoints : Optional [WccEndpoints ] = None ,
183
+ ):
184
+ self ._query_runner = query_runner
185
+ self ._delete_fn = delete_fn
186
+ self ._server_version = gds_version
187
+ self ._wcc_endpoints = wcc_endpoints
188
+
189
+ super ().__init__ (self ._query_runner , namespace = "gds" , server_version = self ._server_version )
190
+
136
191
@property
137
192
def graph (self ) -> GraphRemoteProcRunner :
138
193
return GraphRemoteProcRunner (self ._query_runner , f"{ self ._namespace } .graph" , self ._server_version )
@@ -149,6 +204,13 @@ def alpha(self) -> AlphaRemoteEndpoints:
149
204
def beta (self ) -> BetaEndpoints :
150
205
return BetaEndpoints (self ._query_runner , "gds.beta" , self ._server_version )
151
206
207
+ @property
208
+ def wcc (self ) -> Union [WccEndpoints , IndirectCallBuilder ]:
209
+ if self ._wcc_endpoints is None :
210
+ return IndirectCallBuilder (self ._query_runner , f"gds.{ self ._namespace } .wcc" , self ._server_version )
211
+
212
+ return self ._wcc_endpoints
213
+
152
214
def __getattr__ (self , attr : str ) -> IndirectCallBuilder :
153
215
return IndirectCallBuilder (self ._query_runner , f"gds.{ attr } " , self ._server_version )
154
216
0 commit comments