Skip to content

Commit

Permalink
wip: query loader fixed
Browse files Browse the repository at this point in the history
now need to address isolated nodes detection and removal
  • Loading branch information
wey-gu committed Jul 4, 2023
1 parent 7b5b616 commit aaec99d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ with open('example/homogeneous_graph.yaml', 'r') as f:
nebula_loader = NebulaLoader(nebula_config, feature_mapper)
homo_dgl_graph = nebula_loader.load()

# or query based
query = """
MATCH p=()-->() RETURN p
"""
nebula_loader = NebulaLoader(nebula_config, feature_mapper, query=query, query_space="basketballplayer")
homo_dgl_graph = nebula_loader.load()

nx_g = homo_dgl_graph.to_networkx()
nx.draw(nx_g, with_labels=True, pos=nx.spring_layout(nx_g))
```
Expand Down
29 changes: 21 additions & 8 deletions nebula_dgl/nebula_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def validate_feature_mapper(self, feature_mapper: Dict):
for space in m_client.list_spaces()}
self.vertex_tag_schema_dict = {}
self.tag_feature_dict = {}
self.prop_pos_index = {}
self._validate_vertex_tags(m_client, feature_mapper)
self.edge_type_schema_dict = {}
self.edge_feature_dict = {}
Expand All @@ -196,6 +197,12 @@ def _validate_vertex_tags(self, m_client: MetaClient, feature_mapper: Dict):
tag.tag_name.decode(): tag for tag in m_client.list_tags(
self.spaces_dict[space_name])
}
# build self.prop_pos_index
if tag_name not in self.prop_pos_index:
self.prop_pos_index[tag_name] = dict()
tag = self.vertex_tag_schema_dict[space_name][tag_name]
for index, prop in enumerate(tag.schema.columns):
self.prop_pos_index[tag_name][prop.name.decode()] = index

# ensure tag exists
assert tag_name in self.vertex_tag_schema_dict[space_name], \
Expand Down Expand Up @@ -280,8 +287,16 @@ def _validate_edge_types(self, m_client: MetaClient, feature_mapper: Dict):
self.spaces_dict[space_name])
}

# ensure edge exists
# build self.prop_pos_index
edge_name = edge_type.get('name')
if edge_name not in self.prop_pos_index:
self.prop_pos_index[edge_name] = dict()
edge = self.edge_type_schema_dict[space_name][edge_name]
for index, prop in enumerate(edge.schema.columns):
self.prop_pos_index[edge_name][prop.name.decode()] = index

# ensure edge exists

assert edge_name in self.edge_type_schema_dict[space_name], \
'edge {} does not exist'.format(edge_name)
if space_name not in self.edge_feature_dict:
Expand Down Expand Up @@ -376,6 +391,7 @@ def transform_function(prop_values):
for prop in feature_props]
feature_prop_values = []
for index, prop_name in enumerate(feature_prop_names):
#raw_value = prop_values[self.prop_pos_index[tag_or_edge][prop_name]]
raw_value = prop_values[prop_pos_index[prop_name]]
# convert byte value according to type
feature_prop_values.append(
Expand Down Expand Up @@ -607,12 +623,13 @@ def _load_in_query_mode(self) -> DGLHeteroGraph:
vertex_index = 0
transform_function = self.get_feature_transform_function(
tag_features, prop_names)
for vertex_id, prop_values in g['nodes'][tag_name].items():
for vertex_id, prop_map in g['nodes'][tag_name].items():
_vertex_id_dict[vertex_id] = vertex_index
vertex_index += 1
# feature data for vertex(node)
if not tag_features:
continue
prop_values = [prop_map.get(prop_name) for prop_name in prop_names]
feature_values = transform_function(prop_values)
for index, feature_name in enumerate(tag_features):
feature = tag_features[feature_name]
Expand Down Expand Up @@ -647,18 +664,13 @@ def _load_in_query_mode(self) -> DGLHeteroGraph:
props.add(prop['name'])
prop_names = list(props)

graph_storage_client = self.get_storage_client()
resp = graph_storage_client.scan_edge(
space_name=space_name,
edge_name=edge_name,
prop_names=prop_names)
transform_function = self.get_feature_transform_function(
edge_features, prop_names)
start_vertices, end_vertices = [], []
start_vertex_id_dict = vertex_id_dict[space_name][start_vertex_tag]
end_vertex_id_dict = vertex_id_dict[space_name][end_vertex_tag]

for edge_tuple, prop_values in g['edges'][edge_name].items():
for edge_tuple, prop_map in g['edges'][edge_name].items():
start_vertices.append(
start_vertex_id_dict[edge_tuple[0]]
)
Expand All @@ -668,6 +680,7 @@ def _load_in_query_mode(self) -> DGLHeteroGraph:
# feature data for edge
if not edge_features:
continue
prop_values = [prop_map.get(prop_name) for prop_name in prop_names]
feature_values = transform_function(prop_values)
for index, feature_name in enumerate(edge_features):
feature = edge_features[feature_name]
Expand Down

0 comments on commit aaec99d

Please sign in to comment.