Skip to content

Commit 8e83471

Browse files
committed
WIP
1 parent e684959 commit 8e83471

File tree

1 file changed

+52
-7
lines changed

1 file changed

+52
-7
lines changed

graphdatascience/gnn/gnn_nc_runner.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from typing import Any, List
3+
import time
34

45
from ..error.illegal_attr_checker import IllegalAttrChecker
56
from ..error.uncallable_namespace import UncallableNamespace
@@ -21,6 +22,21 @@ def make_graph_sage_config(self, graph_sage_config):
2122
final_sage_config.update(graph_sage_config)
2223
return final_sage_config
2324

25+
def get_logs(self, job_id: str, offset=0) -> "Series[Any]": # noqa: F821
26+
return self._query_runner.run_query(
27+
"RETURN gds.remoteml.getLogs($job_id, $offset)",
28+
params={
29+
"job_id": job_id,
30+
"offset": offset
31+
}).squeeze()
32+
33+
def get_train_result(self, model_name: str) -> "Series[Any]": # noqa: F821
34+
return self._query_runner.run_query(
35+
"RETURN gds.remoteml.getTrainResult($model_name)",
36+
params={
37+
"model_name": model_name
38+
}).squeeze()
39+
2440
def train(
2541
self,
2642
graph_name: str,
@@ -30,7 +46,8 @@ def train(
3046
relationship_types: List[str],
3147
target_node_label: str = None,
3248
node_labels: List[str] = None,
33-
graph_sage_config = None
49+
graph_sage_config = None,
50+
logging_interval: int = 5
3451
) -> "Series[Any]": # noqa: F821
3552
mlConfigMap = {
3653
"featureProperties": feature_properties,
@@ -49,19 +66,36 @@ def train(
4966
mlTrainingConfig = json.dumps(mlConfigMap)
5067

5168
# token and uri will be injected by arrow_query_runner
52-
self._query_runner.run_query(
53-
"CALL gds.upload.graph($config)",
69+
job_id = self._query_runner.run_query(
70+
"CALL gds.upload.graph($config) YIELD jobId",
5471
params={
5572
"config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
5673
},
57-
)
74+
).jobId[0]
75+
76+
received_logs = 0
77+
training_done = False
78+
while not training_done:
79+
for log in self.get_logs(job_id, offset=received_logs):
80+
print(log)
81+
received_logs += 1
82+
try:
83+
self.get_train_result(model_name)
84+
training_done = True
85+
except Exception:
86+
time.sleep(logging_interval)
87+
88+
return job_id
89+
90+
5891

5992
def predict(
6093
self,
6194
graph_name: str,
6295
model_name: str,
6396
mutateProperty: str,
6497
predictedProbabilityProperty: str = None,
98+
logging_interval = 5
6599
) -> "Series[Any]": # noqa: F821
66100
mlConfigMap = {
67101
"job_type": "predict",
@@ -71,9 +105,20 @@ def predict(
71105
mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty
72106

73107
mlTrainingConfig = json.dumps(mlConfigMap)
74-
self._query_runner.run_query(
75-
"CALL gds.upload.graph($config)",
108+
job_id = self._query_runner.run_query(
109+
"CALL gds.upload.graph($config) YIELD jobId",
76110
params={
77111
"config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
78112
},
79-
) # type: ignore
113+
).jobId[0]
114+
115+
received_logs = 0
116+
prediction_done = False
117+
while not prediction_done:
118+
for log in self.get_logs(job_id, offset=received_logs):
119+
print(log)
120+
received_logs += 1
121+
if log == "Prediction job completed":
122+
prediction_done = True
123+
if not prediction_done:
124+
time.sleep(logging_interval)

0 commit comments

Comments
 (0)