1
1
import json
2
2
from typing import Any , List
3
+ import time
3
4
4
5
from ..error .illegal_attr_checker import IllegalAttrChecker
5
6
from ..error .uncallable_namespace import UncallableNamespace
@@ -21,6 +22,21 @@ def make_graph_sage_config(self, graph_sage_config):
21
22
final_sage_config .update (graph_sage_config )
22
23
return final_sage_config
23
24
25
def get_logs(self, job_id: str, offset=0) -> "Series[Any]":  # noqa: F821
    """Fetch log output of a remote ML job via ``gds.remoteml.getLogs``.

    Args:
        job_id: Identifier of the job whose logs are requested.
        offset: Value forwarded as the ``$offset`` query parameter
            (presumably the number of log entries to skip -- confirm
            against the server-side function).

    Returns:
        The query result after ``squeeze()`` has been applied to it.
    """
    query_params = {"job_id": job_id, "offset": offset}
    query_result = self._query_runner.run_query(
        "RETURN gds.remoteml.getLogs($job_id, $offset)",
        params=query_params,
    )
    # squeeze() collapses the single-cell result down to its content.
    return query_result.squeeze()
32
+
33
def get_train_result(self, model_name: str) -> "Series[Any]":  # noqa: F821
    """Fetch the training result of a model via ``gds.remoteml.getTrainResult``.

    Args:
        model_name: Name of the model whose training result is requested.

    Returns:
        The query result after ``squeeze()`` has been applied to it.
    """
    query_params = {"model_name": model_name}
    query_result = self._query_runner.run_query(
        "RETURN gds.remoteml.getTrainResult($model_name)",
        params=query_params,
    )
    # squeeze() collapses the single-cell result down to its content.
    return query_result.squeeze()
39
+
24
40
def train (
25
41
self ,
26
42
graph_name : str ,
@@ -30,7 +46,8 @@ def train(
30
46
relationship_types : List [str ],
31
47
target_node_label : str = None ,
32
48
node_labels : List [str ] = None ,
33
- graph_sage_config = None
49
+ graph_sage_config = None ,
50
+ logging_interval : int = 5
34
51
) -> "Series[Any]" : # noqa: F821
35
52
mlConfigMap = {
36
53
"featureProperties" : feature_properties ,
@@ -49,19 +66,36 @@ def train(
49
66
mlTrainingConfig = json .dumps (mlConfigMap )
50
67
51
68
# token and uri will be injected by arrow_query_runner
52
- self ._query_runner .run_query (
53
- "CALL gds.upload.graph($config)" ,
69
+ job_id = self ._query_runner .run_query (
70
+ "CALL gds.upload.graph($config) YIELD jobId " ,
54
71
params = {
55
72
"config" : {"mlTrainingConfig" : mlTrainingConfig , "graphName" : graph_name , "modelName" : model_name },
56
73
},
57
- )
74
+ ).jobId [0 ]
75
+
76
+ received_logs = 0
77
+ training_done = False
78
+ while not training_done :
79
+ for log in self .get_logs (job_id , offset = received_logs ):
80
+ print (log )
81
+ received_logs += 1
82
+ try :
83
+ self .get_train_result (model_name )
84
+ training_done = True
85
+ except Exception :
86
+ time .sleep (logging_interval )
87
+
88
+ return job_id
89
+
90
+
58
91
59
92
def predict (
60
93
self ,
61
94
graph_name : str ,
62
95
model_name : str ,
63
96
mutateProperty : str ,
64
97
predictedProbabilityProperty : str = None ,
98
+ logging_interval = 5
65
99
) -> "Series[Any]" : # noqa: F821
66
100
mlConfigMap = {
67
101
"job_type" : "predict" ,
@@ -71,9 +105,20 @@ def predict(
71
105
mlConfigMap ["predictedProbabilityProperty" ] = predictedProbabilityProperty
72
106
73
107
mlTrainingConfig = json .dumps (mlConfigMap )
74
- self ._query_runner .run_query (
75
- "CALL gds.upload.graph($config)" ,
108
+ job_id = self ._query_runner .run_query (
109
+ "CALL gds.upload.graph($config) YIELD jobId " ,
76
110
params = {
77
111
"config" : {"mlTrainingConfig" : mlTrainingConfig , "graphName" : graph_name , "modelName" : model_name },
78
112
},
79
- ) # type: ignore
113
+ ).jobId [0 ]
114
+
115
+ received_logs = 0
116
+ prediction_done = False
117
+ while not prediction_done :
118
+ for log in self .get_logs (job_id , offset = received_logs ):
119
+ print (log )
120
+ received_logs += 1
121
+ if log == "Prediction job completed" :
122
+ prediction_done = True
123
+ if not prediction_done :
124
+ time .sleep (logging_interval )
0 commit comments