Skip to content

Commit

Permalink
Add both local and dist trainer, and a Data class to make example s…
Browse files Browse the repository at this point in the history
…impler

Signed-off-by: Tao He <sighingnow@gmail.com>
  • Loading branch information
sighingnow committed Nov 14, 2022
1 parent 711f677 commit 10ad969
Show file tree
Hide file tree
Showing 25 changed files with 1,190 additions and 194 deletions.
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,13 @@ dynamic_graph_service/dist/
**/_static
**/_templates
**/apis/apis/apis

# ignore logs
**/*.INFO
**/*.WARNING
**/*.ERROR
**/*.FATAL
**/*.log.INFO.*
**/*.log.WARNING.*
**/*.log.ERROR.*
**/*.log.FATAL.*
2 changes: 2 additions & 0 deletions graphlearn/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ examples/tf/*.tar.gz
# core dump
/core

# ignore generated __init__.py
/__init__.py
19 changes: 13 additions & 6 deletions graphlearn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,12 @@ endif()

# get_target_import_location(<var> <target>)
#
# Looks up the on-disk location of the imported target <target> and stores it
# in <var>.  The properties IMPORTED_LOCATION, IMPORTED_LOCATION_NOCONFIG,
# IMPORTED_LOCATION_DEBUG and IMPORTED_LOCATION_RELEASE are probed in that
# order and the first one that is set wins.
#
# If none of them is set, <var> ends up as "<var>-NOTFOUND" (the value
# get_target_property() produces for a missing property).  If <target> is not
# a target at all, <var> is left untouched.
#
# This is a macro, so <var> is set directly in the caller's scope.
macro (get_target_import_location var target)
  if (TARGET ${target})
    # A single ordered probe is enough: the loop starts with the
    # configuration-independent IMPORTED_LOCATION, so no separate lookup is
    # needed before it.
    foreach (prop IMPORTED_LOCATION IMPORTED_LOCATION_NOCONFIG IMPORTED_LOCATION_DEBUG IMPORTED_LOCATION_RELEASE)
      get_target_property (${var} ${target} ${prop})
      if (NOT ("${${var}}" STREQUAL "${var}-NOTFOUND"))
        break ()
      endif ()
    endforeach ()
  endif ()
endmacro ()

Expand Down Expand Up @@ -335,6 +337,10 @@ else()
set (PROFILING_FLAG CLOSE)
endif()

# Pick the C++ dialect for the optional integrations.  WITH_VINEYARD switches
# GL_CXX_DIALECT to c++14; WITH_HIACTOR is checked afterwards, so when both
# options are enabled its gnu++17 setting overrides the c++14 one.
if (WITH_VINEYARD)
set (GL_CXX_DIALECT "c++14")
endif()

if (WITH_HIACTOR)
set (GL_CXX_DIALECT "gnu++17")
endif()
Expand Down Expand Up @@ -487,6 +493,7 @@ if (TESTING)
${GL_SERVICE_DIR}/*.cpp)
# fixme: disable thread_unittest now
list (FILTER GL_TEST_FILES EXCLUDE REGEX ".*/thread_unittest\\.cpp$")
list (FILTER GL_TEST_FILES EXCLUDE REGEX ".*/thread_dag_scheduler_unittest\\.cpp$") # unknown and unreproducible coredump
list (FILTER GL_TEST_FILES EXCLUDE REGEX ".*/waitable_event_unittest\\.cpp$")
list (FILTER GL_TEST_FILES EXCLUDE REGEX ".*/vineyard_storage_unittest\\.cpp$")
add_gl_tests (${GL_TEST_FILES})
Expand All @@ -503,7 +510,7 @@ if (TESTING)
add_gl_tests (${ACTOR_TEST_FILES})
endif ()

if (WITH_VINEYARD)
if (WITH_VINEYARD AND TARGET vineyard_storage_unittest)
target_compile_options(vineyard_storage_unittest PRIVATE "-std=c++14")
endif ()
endif()
Expand Down Expand Up @@ -539,10 +546,10 @@ add_custom_command (TARGET python
COMMAND cp -f ${GL_SETUP_DIR}/gl.__init__.py ${GL_PYTHON_DIR}/__init__.py
COMMAND echo "__version__ = ${VERSION}" >> ${GL_PYTHON_DIR}/__init__.py
COMMAND echo "__git_version__ = '${GIT_BRANCH}-${GIT_VERSION}'" >> ${GL_PYTHON_DIR}/__init__.py
COMMAND OPEN_KNN=${KNN_FLAG} ${GL_PYTHON_BIN} ${GL_SETUP_DIR}/setup.py build_ext --inplace
COMMAND OPEN_KNN=${KNN_FLAG} ${GL_PYTHON_BIN} ${GL_SETUP_DIR}/setup.py bdist_wheel
COMMAND ${CMAKE_COMMAND} -E make_directory "${GL_BUILT_BIN_DIR}/ge_data/data"
COMMAND ${CMAKE_COMMAND} -E make_directory "${GL_BUILT_BIN_DIR}/ge_data/ckpt"
COMMAND rm -f ${GL_PYTHON_DIR}/__init__.py
WORKING_DIRECTORY ${GL_ROOT}
VERBATIM)

Expand Down
50 changes: 50 additions & 0 deletions graphlearn/examples/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os
import sys

# Directory two levels above this file.  It is put at the front of
# ``sys.path`` only while the optional imports below run, so that
# ``import graphlearn...`` can be resolved from there.
_GL_IMPORT_ROOT = os.path.join(os.path.dirname(__file__), '..', '..')

try:
    sys.path.insert(0, _GL_IMPORT_ROOT)

    # PyTorch examples are optional: skip them when the PyTorch backend (or
    # anything it needs) cannot be imported.
    try:
        import graphlearn.python.nn.pytorch
        from .pytorch.gcn.gcn import GCN as TorchGCN
    except Exception:
        pass

    # TensorFlow examples are optional as well.  ``Exception`` (rather than a
    # bare ``except``) keeps KeyboardInterrupt/SystemExit propagating, and
    # matches the handler used for the PyTorch imports above.
    try:
        import graphlearn.python.nn.tf

        from .tf.trainer import LocalTrainer, DistTrainer

        # backwards compatibility
        LocalTFTrainer = LocalTrainer
        DistTFTrainer = DistTrainer

        from .tf.bipartite_sage.bipartite_sage import BipartiteGraphSAGE
        from .tf.bipartite_sage.hetero_edge_inducer import HeteroEdgeInducer
        from .tf.ego_bipartite_sage.ego_bipartite_sage import EgoBipartiteGraphSAGE
        from .tf.ego_gat.ego_gat import EgoGAT
        from .tf.ego_rgcn.ego_rgcn import EgoRGCN
        from .tf.ego_sage.ego_sage import EgoGraphSAGE
        from .tf.sage.edge_inducer import EdgeInducer
        from .tf.seal.edge_cn_inducer import EdgeCNInducer
        from .tf.ultra_gcn.ultra_gcn import UltraGCN
    except Exception:
        pass

finally:
    # Undo the temporary sys.path entry.  If an imported module already
    # removed it, there is nothing left to clean up; do not let the clean-up
    # itself raise and hide whatever happened above.
    try:
        sys.path.remove(_GL_IMPORT_ROOT)
    except ValueError:
        pass
14 changes: 14 additions & 0 deletions graphlearn/examples/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2021-2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
14 changes: 14 additions & 0 deletions graphlearn/examples/tf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2021-2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
17 changes: 17 additions & 0 deletions graphlearn/examples/tf/bipartite_sage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2021-2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from .bipartite_sage import BipartiteGraphSAGE
from .hetero_edge_inducer import HeteroEdgeInducer
16 changes: 16 additions & 0 deletions graphlearn/examples/tf/ego_bipartite_sage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2021-2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from .ego_bipartite_sage import EgoBipartiteGraphSAGE
99 changes: 99 additions & 0 deletions graphlearn/examples/tf/ego_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import datetime
import json
import os
import sys

import numpy as np
try:
# https://www.tensorflow.org/guide/migrate
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
except ImportError:
import tensorflow as tf

import graphlearn as gl
import graphlearn.python.nn.tf as tfg
from graphlearn.python.utils import parse_nbrs_num

class EgoData:
  """Builds the train/test/val queries, datasets and embeddings for a model.

  For each split the constructor runs the same pipeline: build the query via
  `query()`, wrap it in a `tfg.Dataset`, fetch the dataset's data dict,
  reformat it with `reformat_node_feature()` and pass the result to
  `model.forward()`. The intermediate and final objects are stored as

    query_<split>, dataset_<split>, <split>_iterator, <split>_dict,
    <split>_embedding

  for `<split>` in `train`, `test` and `val`.

  `query()` and `reformat_node_feature()` are hooks that subclasses must
  override; the base implementations raise `NotImplementedError`.
  """

  def __init__(self, graph, model, nbrs_num=None, sampler='random',
               train_batch_size=128, test_batch_size=128, val_batch_size=128):
    """Stores the configuration and builds all three splits.

    Args:
      graph: Passed unchanged to `query()` as its `graph` argument.
      model: Object whose `forward(x, nbrs_num)` produces the embedding.
      nbrs_num: Neighbor-count specification; normalized with
        `parse_nbrs_num` before being stored and passed to `model.forward`.
      sampler: Stored as `self.sampler`; not read by this base class.
      train_batch_size: Stored as `self.train_batch_size`; not read by this
        base class (presumably consumed by `query()` overrides).
      test_batch_size: Same as above, stored as `self.test_batch_size`.
      val_batch_size: Same as above, stored as `self.val_batch_size`.
    """
    self.graph = graph
    self.model = model
    self.nbrs_num = parse_nbrs_num(nbrs_num)
    self.train_batch_size = train_batch_size
    self.test_batch_size = test_batch_size
    self.val_batch_size = val_batch_size
    self.sampler = sampler

    # Only the train split is built with tfg.conf.training enabled.
    self._build_split('train', gl.Mask.TRAIN, training=True)
    self._build_split('test', gl.Mask.TEST, training=False)
    self._build_split('val', gl.Mask.VAL, training=False)

  def _build_split(self, split, mask, training):
    """Builds query, dataset and embedding for one split and stores them.

    Attributes are assigned one by one, as soon as each object exists, so
    hook overrides that run later in the pipeline can already see them.

    Args:
      split: 'train', 'test' or 'val'. Used both for the attribute names and
        as the node alias looked up with `get_node()` on the query.
      mask: The `gl.Mask` value handed to `query()`.
      training: Value assigned to `tfg.conf.training` before building.
    """
    tfg.conf.training = training

    query = self.query(self.graph, mask)
    setattr(self, 'query_' + split, query)

    dataset = tfg.Dataset(query, window=10)
    setattr(self, 'dataset_' + split, dataset)
    setattr(self, split + '_iterator', dataset.iterator)

    data_dict = dataset.get_data_dict()
    setattr(self, split + '_dict', data_dict)

    embedding = self.model.forward(
      self.reformat_node_feature(
        data_dict,
        query.list_alias(),
        tfg.FeatureHandler('feature_handler',
                           query.get_node(split).decoder.feature_spec),
      ),
      self.nbrs_num
    )
    setattr(self, split + '_embedding', embedding)

  def query(self, graph, mask=gl.Mask.TRAIN):
    """Builds the query for one split. Must be overridden by subclasses.

    The constructor passes the result to `tfg.Dataset`, calls `list_alias()`
    on it, and looks up the node aliased 'train', 'test' or 'val' (matching
    `mask`) via `get_node()`, so the returned query has to provide those.

    Args:
      graph: The graph given to the constructor.
      mask: One of `gl.Mask.TRAIN`, `gl.Mask.TEST`, `gl.Mask.VAL`.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError(
      "%s must override query()" % type(self).__name__)

  def reformat_node_feature(self, data_dict, alias_list, feature_handler):
    """Converts a dataset data dict into the input of `model.forward`.

    Must be overridden by subclasses.

    Args:
      data_dict: Result of `get_data_dict()` on the split's `tfg.Dataset`.
      alias_list: Result of `list_alias()` on the split's query.
      feature_handler: `tfg.FeatureHandler` created from the feature spec of
        the split's root node.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError(
      "%s must override reformat_node_feature()" % type(self).__name__)
16 changes: 16 additions & 0 deletions graphlearn/examples/tf/ego_gat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2021-2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from .ego_gat import EgoGAT
19 changes: 19 additions & 0 deletions graphlearn/examples/tf/ego_gcn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# [GCN](https://arxiv.org/abs/1609.02907)
## Introduction
Here we implement a GCN based on fixed-size neighbor sampling.

## How to run
### Node classification
Here we use the Cora dataset as an example:

1. Prepare data
```shell script
cd ../../data/
python cora.py
```

2. Train
```shell script
cd ../tf/ego_gcn/
python train_supervised.py
```
17 changes: 17 additions & 0 deletions graphlearn/examples/tf/ego_gcn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2021-2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from .ego_gcn import EgoGCN
from .ego_gcn_data import EgoGCNData
Loading

0 comments on commit 10ad969

Please sign in to comment.