-
Notifications
You must be signed in to change notification settings - Fork 266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GraphScope on 3/3 part 1: add local/dist trainer, and a Data
class to make the example simpler
#234
Conversation
575a9ba
to
304df75
Compare
304df75
to
e624512
Compare
Data
class to make the example simpler
e624512
to
893effb
Compare
…impler Signed-off-by: Tao He <sighingnow@gmail.com>
893effb
to
10ad969
Compare
graphlearn/examples/tf/ego_data.py
Outdated
self.dataset_train = tfg.Dataset(self.query_train, window=10) | ||
self.train_iterator = self.dataset_train.iterator | ||
self.train_dict = self.dataset_train.get_data_dict() | ||
self.train_embedding = self.model.forward( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's better not to encapsulate the model training into the Data
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is moving self.{train,val,test}_embedding
to outside and still keeping other field in Data
class acceptable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move all model-related data outside.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
graphlearn/examples/tf/trainer.py
Outdated
writeGFile.close() | ||
print("Profiling data save to %s success." % save_path) | ||
if self.profiling: | ||
outs = self.run_and_profiling(train_ops, local_step) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the if
branch, we call run_and_profiling
which contains self.sess.run
, but in the else
branch, the self.sess.run
function is called directly, which doesn't feel very corresponding. Maybe we should wrap the timeline saving separately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think move self.profiling
to run_and_profiling()
and use run_and_profiling
directly would look better.
graphlearn/examples/tf/ego_data.py
Outdated
import graphlearn.python.nn.tf as tfg | ||
from graphlearn.python.utils import parse_nbrs_num | ||
|
||
class EgoData: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A rough code of my opinion, just for reference.
Base class:
class EgoSampleLoaderBase: # just for example, maybe we could find a better class name
def __init__(self, graph, nbr_num, sampler, batch_size, mask="train"):
# ..
if mask == 'train':
tfg.conf.training = True
self.sample_query = self.query(graph, mask)
ds = tfg.Dataset(self.query_train, window=10)
self._iterator = ...
def _query(self):
raise NotImplementedError...
def _format(self, ...):
raise NotImplementedError...
@property
def iterator(self):
return self._iterator
def as_list(self):
return self._format()
@property
def src(self):
return self._data_dict['seed']
def hop(self, idx):
return self._data_dict['hop1'] # Just for example
Example inherit class
class EgoRGCNSampleLoader:
def _query(self):
# ...
def _format(self):
# ...
Usage in train.py
graph = g.init()
model = EgoRGCN(...)
train_sample = EgoRGCNSamplLoader(g, nbr_num, "random", 128, 'train')
train_emb = model.forward(train_sample.as_list())
loss = loss_fn(train_emb, train_sample.src.labels)
trainer = Trainer(train_sample.iterator, loss)
trainer.run()
# for test
test_sample_loader = EgoRGCNSamplLoader(g, nbr_num, "random", 128, 'test')
# ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks fine, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is better to put all sampling and data preprocessing into a SampleLoader
or NeighborLoader
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
from graphlearn.python.nn.tf.layers.linear_layer import LinearLayer | ||
|
||
|
||
class EgoGCNConv(EgoConv): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems like EgoRGCNConv, not EgoGCN. You can use EgoSAGEConv with aggr='gcn' as EgoGCNConv.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay. will test that. Thanks!
act_func=tf.nn.relu, | ||
dropout=0.0, | ||
**kwargs): | ||
"""EgoGraph based RGCN. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These args are for RGCN not for GCN.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, will fix.
Signed-off-by: Tao He <sighingnow@gmail.com>
graphlearn/examples/tf/trainer.py
Outdated
|
||
def _close_session(): | ||
if self.sess is not None: | ||
self.sess.close() | ||
atexit.register(_close_session) | ||
|
||
def train(self, iterator, loss, learning_rate, epochs=10, hooks=[], **kwargs): | ||
def run_and_profiling(self, train_ops, local_step): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
run_step
is better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. Has been renamed to run_step
.
train_data = EgoRGCNDataLoader(g, gl.Mask.TRAIN, FLAGS.sampler, FLAGS.train_batch_size, | ||
node_type='i', nbrs_num=nbrs_num, num_relations=FLAGS.num_relations) | ||
train_embedding = model.forward(train_data.as_list(), nbrs_num) | ||
loss = supervised_loss(train_embedding, train_data['seed'].labels) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The user cannot know 'seed', it is inside 'train_data, so maybe just use API like
seed().labels` ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have changed back to train/test/val
based on the mask
parameter (keep the previous behaviour) and expose some helpers train_labels
, test_labels
, val_labels
for accessing.
Signed-off-by: Tao He <sighingnow@gmail.com>
Signed-off-by: Tao He <sighingnow@gmail.com>
graphlearn/examples/tf/ego_data.py
Outdated
return self._dataset.get_egograph(key) | ||
|
||
@property | ||
def train_ego(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here should keep only a single ego()
property in which you can call get_egograph
according to self._mask
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we need src_ego
, dst_ego
and neg_dst_ego
three methods(neg_dst_ego
for unsupervised model). train_ego
or test_ego
is just the case of src_ego
when mask is Train or Test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not put train/test/val queries all together for supporting user who only want to run save embedding phase.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not put train/test/val queries all together for supporting user who only want to run save embedding phase.
Fixed.
Signed-off-by: Tao He <sighingnow@gmail.com>
graphlearn/examples/tf/trainer.py
Outdated
""" | ||
""" | ||
|
||
def add_initializer(self, iterator): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unused code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Deleted. Thanks for raising it up.
graphlearn/examples/tf/trainer.py
Outdated
|
||
Args: | ||
cluster_spec: TensorFlow ClusterSpec. | ||
job_name: name of this worker. | ||
task_index: index of this worker. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rm this unused arg.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Docstring has been revised in trainer.py.
self.sync_barrier = None | ||
self.global_step = None | ||
self.is_local = None | ||
|
||
def context(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise NotImplementedError
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
def _close_session(): | ||
if self.sess is not None: | ||
self.sess.close() | ||
atexit.register(_close_session) | ||
|
||
def train(self, iterator, loss, learning_rate, epochs=10, hooks=[], **kwargs): | ||
def run_step(self, train_ops, local_step): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise NotImplementedError
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
print('Start testing ...') | ||
total_test_acc = [] | ||
local_step = 0 | ||
last_local_step = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the LocalTrainer
can also use global_step
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is to make the logs less confusing.
graphlearn/examples/tf/ego_data.py
Outdated
@@ -0,0 +1,118 @@ | |||
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ego_data.py ->ego_data_loader.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, renamed ego_rgcn_data_loader.py
and ego_sage_data_loader.py
as well.
graphlearn/examples/tf/ego_data.py
Outdated
prefix = ('train', 'test', 'val')[self._mask.value - 1] | ||
return self._data_dict[prefix].labels | ||
|
||
def as_list(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change to x_list
which means the input node feature(processed) list.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
Signed-off-by: Tao He <sighingnow@gmail.com>
def dst_ego(self): | ||
''' Alias for `self.get_egograph('dst')`. | ||
''' | ||
return self.get_egograph('dst') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that not all the queries in sub class contains 'src', 'dst' and 'neg_dst'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They only be called when needed, otherwise we would need user to hard code get_egograph("src"), ...
in their train_(un)supervised.py
.
def src_ego(self): | ||
''' Alias for `self.get_egograph('src')`. | ||
''' | ||
if self._mask is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The base class should not provide default implementation for src_ego
, dst_ego
, neg_dst_ego
, because 'src' or 'dst' should only be valid when query use it in derived class.
You can just raise NotImplementedError here.
''' | ||
return self.get_egograph('neg_dst') | ||
|
||
@property |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
labels
, x_list
and _format
these interfaces are not so common. It is best to implement them in the required subclasses such as EgoRGCNDataLoader, not in the base class.
… subclasses Signed-off-by: Tao He <sighingnow@gmail.com>
This pull request
Data
class to gcn and rgcn to make the example code simpler, as discussedadd aEgo
based GCN to restore the previous GCN example, and show howEgoXXXData
works.