Skip to content

Commit 6902aea

Browse files
authored
Merge pull request #942 from neo4j/sampling_endpoints
AV2 - graph.sample endpoints
2 parents 4ab5d2f + c24762a commit 6902aea

File tree

8 files changed

+596
-0
lines changed

8 files changed

+596
-0
lines changed

graphdatascience/procedure_surface/api/catalog_endpoints.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from graphdatascience import Graph
1111
from graphdatascience.procedure_surface.api.base_result import BaseResult
12+
from graphdatascience.procedure_surface.api.graph_sampling_endpoints import GraphSamplingEndpoints
1213

1314

1415
class CatalogEndpoints(ABC):
@@ -65,6 +66,11 @@ def filter(
6566
"""
6667
pass
6768

69+
@property
70+
@abstractmethod
71+
def sample(self) -> GraphSamplingEndpoints:
72+
pass
73+
6874

6975
class GraphListResult(BaseResult):
7076
graph_name: str
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from typing import List, Optional
5+
6+
from graphdatascience import Graph
7+
from graphdatascience.procedure_surface.api.base_result import BaseResult
8+
9+
10+
class GraphSamplingEndpoints(ABC):
11+
"""
12+
Abstract base class defining the API for graph sampling algorithms algorithm.
13+
"""
14+
15+
@abstractmethod
16+
def rwr(
17+
self,
18+
G: Graph,
19+
graph_name: str,
20+
start_nodes: Optional[List[int]] = None,
21+
restart_probability: Optional[float] = None,
22+
sampling_ratio: Optional[float] = None,
23+
node_label_stratification: Optional[bool] = None,
24+
relationship_weight_property: Optional[str] = None,
25+
relationship_types: Optional[List[str]] = None,
26+
node_labels: Optional[List[str]] = None,
27+
sudo: Optional[bool] = None,
28+
log_progress: Optional[bool] = None,
29+
username: Optional[str] = None,
30+
concurrency: Optional[int] = None,
31+
job_id: Optional[str] = None,
32+
) -> GraphSamplingResult:
33+
"""
34+
Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog.
35+
36+
This method performs a random walk, beginning from a set of nodes (if provided),
37+
where at each step there is a probability to restart back at the original nodes.
38+
The result is turned into a new graph induced by the random walks and stored in the catalog.
39+
40+
Parameters
41+
----------
42+
G : Graph
43+
The input graph on which the Random Walk with Restart (RWR) will be
44+
performed.
45+
graph_name : str
46+
The name of the new graph in the catalog.
47+
start_nodes : list of int, optional
48+
A list of node IDs to start the random walk from. If not provided, all
49+
nodes are used as potential starting points.
50+
restart_probability : float, optional
51+
The probability of restarting back to the original node at each step.
52+
Should be a value between 0 and 1. If not specified, a default value is used.
53+
sampling_ratio : float, optional
54+
The ratio of nodes to sample during the computation. This value should
55+
be between 0 and 1. If not specified, no sampling is performed.
56+
node_label_stratification : bool, optional
57+
If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph.
58+
relationship_weight_property : str, optional
59+
The name of the property on relationships to use as weights during
60+
the random walk. If not specified, the relationships are treated as
61+
unweighted.
62+
relationship_types : list of str, optional
63+
The relationship types used to select relationships for this algorithm run.
64+
node_labels : list of str, optional
65+
The node labels used to select nodes for this algorithm run.
66+
sudo : bool, optional
67+
Override memory estimation limits. Use with caution as this can lead to
68+
memory issues if the estimation is significantly wrong.
69+
log_progress : bool, optional
70+
If True, logs the progress of the computation.
71+
username : str, optional
72+
The username to attribute the procedure run to
73+
concurrency : int, optional
74+
The number of concurrent threads used for the algorithm execution.
75+
job_id : str, optional
76+
An identifier for the job that can be used for monitoring and cancellation
77+
78+
Returns
79+
-------
80+
GraphSamplingResult
81+
The result of the Random Walk with Restart (RWR), including the sampled
82+
nodes and their scores.
83+
"""
84+
pass
85+
86+
@abstractmethod
87+
def cnarw(
88+
self,
89+
G: Graph,
90+
graph_name: str,
91+
start_nodes: Optional[List[int]] = None,
92+
restart_probability: Optional[float] = None,
93+
sampling_ratio: Optional[float] = None,
94+
node_label_stratification: Optional[bool] = None,
95+
relationship_weight_property: Optional[str] = None,
96+
relationship_types: Optional[List[str]] = None,
97+
node_labels: Optional[List[str]] = None,
98+
sudo: Optional[bool] = None,
99+
log_progress: Optional[bool] = None,
100+
username: Optional[str] = None,
101+
concurrency: Optional[int] = None,
102+
job_id: Optional[str] = None,
103+
) -> GraphSamplingResult:
104+
"""
105+
Computes a set of Random Walks with Restart (RWR) for the given graph and stores the result as a new graph in the catalog.
106+
107+
This method performs a random walk, beginning from a set of nodes (if provided),
108+
where at each step there is a probability to restart back at the original nodes.
109+
The result is turned into a new graph induced by the random walks and stored in the catalog.
110+
111+
Parameters
112+
----------
113+
G : Graph
114+
The input graph on which the Random Walk with Restart (RWR) will be
115+
performed.
116+
graph_name : str
117+
The name of the new graph in the catalog.
118+
start_nodes : list of int, optional
119+
A list of node IDs to start the random walk from. If not provided, all
120+
nodes are used as potential starting points.
121+
restart_probability : float, optional
122+
The probability of restarting back to the original node at each step.
123+
Should be a value between 0 and 1. If not specified, a default value is used.
124+
sampling_ratio : float, optional
125+
The ratio of nodes to sample during the computation. This value should
126+
be between 0 and 1. If not specified, no sampling is performed.
127+
node_label_stratification : bool, optional
128+
If True, the algorithm tries to preserve the label distribution of the original graph in the sampled graph.
129+
relationship_weight_property : str, optional
130+
The name of the property on relationships to use as weights during
131+
the random walk. If not specified, the relationships are treated as
132+
unweighted.
133+
relationship_types : list of str, optional
134+
The relationship types used to select relationships for this algorithm run.
135+
node_labels : list of str, optional
136+
The node labels used to select nodes for this algorithm run.
137+
sudo : bool, optional
138+
Override memory estimation limits. Use with caution as this can lead to
139+
memory issues if the estimation is significantly wrong.
140+
log_progress : bool, optional
141+
If True, logs the progress of the computation.
142+
username : str, optional
143+
The username to attribute the procedure run to
144+
concurrency : int, optional
145+
The number of concurrent threads used for the algorithm execution.
146+
job_id : str, optional
147+
An identifier for the job that can be used for monitoring and cancellation
148+
149+
Returns
150+
-------
151+
GraphSamplingResult
152+
The result of the Random Walk with Restart (RWR), including the sampled
153+
nodes and their scores.
154+
"""
155+
pass
156+
157+
158+
class GraphSamplingResult(BaseResult):
159+
graph_name: str
160+
from_graph_name: str
161+
node_count: int
162+
relationship_count: int
163+
start_node_count: int
164+
project_millis: int

graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
GraphFilterResult,
1414
GraphListResult,
1515
)
16+
from graphdatascience.procedure_surface.api.graph_sampling_endpoints import GraphSamplingEndpoints
17+
from graphdatascience.procedure_surface.arrow.graph_sampling_arrow_endpoints import GraphSamplingArrowEndpoints
1618
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
1719
from graphdatascience.query_runner.protocol.project_protocols import ProjectProtocol
1820
from graphdatascience.query_runner.termination_flag import TerminationFlag
@@ -116,6 +118,10 @@ def filter(
116118

117119
return GraphFilterResult(**JobClient.get_summary(self._arrow_client, job_id))
118120

121+
@property
122+
def sample(self) -> GraphSamplingEndpoints:
123+
return GraphSamplingArrowEndpoints(self._arrow_client)
124+
119125
def _arrow_config(self) -> dict[str, Any]:
120126
connection_info = self._arrow_client.advertised_connection_info()
121127

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, List, Optional
4+
5+
from graphdatascience import Graph
6+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
7+
from graphdatascience.arrow_client.v2.job_client import JobClient
8+
from graphdatascience.procedure_surface.api.graph_sampling_endpoints import (
9+
GraphSamplingEndpoints,
10+
GraphSamplingResult,
11+
)
12+
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
13+
14+
15+
class GraphSamplingArrowEndpoints(GraphSamplingEndpoints):
16+
def __init__(self, arrow_client: AuthenticatedArrowClient):
17+
self._arrow_client = arrow_client
18+
19+
def rwr(
20+
self,
21+
G: Graph,
22+
graph_name: str,
23+
start_nodes: Optional[List[int]] = None,
24+
restart_probability: Optional[float] = None,
25+
sampling_ratio: Optional[float] = None,
26+
node_label_stratification: Optional[bool] = None,
27+
relationship_weight_property: Optional[str] = None,
28+
relationship_types: Optional[List[str]] = None,
29+
node_labels: Optional[List[str]] = None,
30+
sudo: Optional[bool] = None,
31+
log_progress: Optional[bool] = None,
32+
username: Optional[str] = None,
33+
concurrency: Optional[Any] = None,
34+
job_id: Optional[Any] = None,
35+
) -> GraphSamplingResult:
36+
config = ConfigConverter.convert_to_gds_config(
37+
from_graph_name=G.name(),
38+
graph_name=graph_name,
39+
start_nodes=start_nodes,
40+
restart_probability=restart_probability,
41+
sampling_ratio=sampling_ratio,
42+
node_label_stratification=node_label_stratification,
43+
relationship_weight_property=relationship_weight_property,
44+
relationship_types=relationship_types,
45+
node_labels=node_labels,
46+
sudo=sudo,
47+
log_progress=log_progress,
48+
username=username,
49+
concurrency=concurrency,
50+
job_id=job_id,
51+
)
52+
53+
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.rwr", config)
54+
55+
return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, job_id))
56+
57+
def cnarw(
58+
self,
59+
G: Graph,
60+
graph_name: str,
61+
start_nodes: Optional[List[int]] = None,
62+
restart_probability: Optional[float] = None,
63+
sampling_ratio: Optional[float] = None,
64+
node_label_stratification: Optional[bool] = None,
65+
relationship_weight_property: Optional[str] = None,
66+
relationship_types: Optional[List[str]] = None,
67+
node_labels: Optional[List[str]] = None,
68+
sudo: Optional[bool] = None,
69+
log_progress: Optional[bool] = None,
70+
username: Optional[str] = None,
71+
concurrency: Optional[Any] = None,
72+
job_id: Optional[Any] = None,
73+
) -> GraphSamplingResult:
74+
config = ConfigConverter.convert_to_gds_config(
75+
from_graph_name=G.name(),
76+
graph_name=graph_name,
77+
start_nodes=start_nodes,
78+
restart_probability=restart_probability,
79+
sampling_ratio=sampling_ratio,
80+
node_label_stratification=node_label_stratification,
81+
relationship_weight_property=relationship_weight_property,
82+
relationship_types=relationship_types,
83+
node_labels=node_labels,
84+
sudo=sudo,
85+
log_progress=log_progress,
86+
username=username,
87+
concurrency=concurrency,
88+
job_id=job_id,
89+
)
90+
91+
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.cnarw", config)
92+
93+
return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, job_id))
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, List, Optional
4+
5+
from ...call_parameters import CallParameters
6+
from ...graph.graph_object import Graph
7+
from ...query_runner.query_runner import QueryRunner
8+
from ..api.graph_sampling_endpoints import GraphSamplingEndpoints, GraphSamplingResult
9+
from ..utils.config_converter import ConfigConverter
10+
11+
12+
class GraphSamplingCypherEndpoints(GraphSamplingEndpoints):
13+
def __init__(self, query_runner: QueryRunner):
14+
self._query_runner = query_runner
15+
16+
def rwr(
17+
self,
18+
G: Graph,
19+
graph_name: str,
20+
start_nodes: Optional[List[int]] = None,
21+
restart_probability: Optional[float] = None,
22+
sampling_ratio: Optional[float] = None,
23+
node_label_stratification: Optional[bool] = None,
24+
relationship_weight_property: Optional[str] = None,
25+
relationship_types: Optional[List[str]] = None,
26+
node_labels: Optional[List[str]] = None,
27+
sudo: Optional[bool] = None,
28+
log_progress: Optional[bool] = None,
29+
username: Optional[str] = None,
30+
concurrency: Optional[Any] = None,
31+
job_id: Optional[Any] = None,
32+
) -> GraphSamplingResult:
33+
config = ConfigConverter.convert_to_gds_config(
34+
start_nodes=start_nodes,
35+
restart_probability=restart_probability,
36+
sampling_ratio=sampling_ratio,
37+
node_label_stratification=node_label_stratification,
38+
relationship_weight_property=relationship_weight_property,
39+
relationship_types=relationship_types,
40+
node_labels=node_labels,
41+
sudo=sudo,
42+
log_progress=log_progress,
43+
username=username,
44+
concurrency=concurrency,
45+
job_id=job_id,
46+
)
47+
48+
params = CallParameters(
49+
graph_name=graph_name,
50+
from_graph_name=G.name(),
51+
config=config,
52+
)
53+
params.ensure_job_id_in_config()
54+
55+
result = self._query_runner.call_procedure(endpoint="gds.graph.sample.rwr", params=params).squeeze()
56+
return GraphSamplingResult(**result.to_dict())
57+
58+
def cnarw(
59+
self,
60+
G: Graph,
61+
graph_name: str,
62+
start_nodes: Optional[List[int]] = None,
63+
restart_probability: Optional[float] = None,
64+
sampling_ratio: Optional[float] = None,
65+
node_label_stratification: Optional[bool] = None,
66+
relationship_weight_property: Optional[str] = None,
67+
relationship_types: Optional[List[str]] = None,
68+
node_labels: Optional[List[str]] = None,
69+
sudo: Optional[bool] = None,
70+
log_progress: Optional[bool] = None,
71+
username: Optional[str] = None,
72+
concurrency: Optional[Any] = None,
73+
job_id: Optional[Any] = None,
74+
) -> GraphSamplingResult:
75+
config = ConfigConverter.convert_to_gds_config(
76+
start_nodes=start_nodes,
77+
restart_probability=restart_probability,
78+
sampling_ratio=sampling_ratio,
79+
node_label_stratification=node_label_stratification,
80+
relationship_weight_property=relationship_weight_property,
81+
relationship_types=relationship_types,
82+
node_labels=node_labels,
83+
sudo=sudo,
84+
log_progress=log_progress,
85+
username=username,
86+
concurrency=concurrency,
87+
job_id=job_id,
88+
)
89+
90+
params = CallParameters(
91+
graph_name=graph_name,
92+
from_graph_name=G.name(),
93+
config=config,
94+
)
95+
params.ensure_job_id_in_config()
96+
97+
result = self._query_runner.call_procedure(endpoint="gds.graph.sample.cnarw", params=params).squeeze()
98+
return GraphSamplingResult(**result.to_dict())

0 commit comments

Comments
 (0)