Wikidata5M dataset (#7864)
Implements the Wikidata5M transductive and inductive link prediction datasets.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
3 people authored Aug 17, 2023
1 parent 2b33bca commit 58e325f
Showing 3 changed files with 134 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `Wikidata5M` dataset ([#7864](https://github.com/pyg-team/pytorch_geometric/pull/7864))
- Added TorchScript support inside `BasicGNN` models ([#7865](https://github.com/pyg-team/pytorch_geometric/pull/7865))
- Added a `batch_size` argument to `unbatch` functionalities ([#7851](https://github.com/pyg-team/pytorch_geometric/pull/7851))
- Added a distributed example using `graphlearn-for-pytorch` ([#7402](https://github.com/pyg-team/pytorch_geometric/pull/7402))
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
@@ -71,6 +71,7 @@
from .hydro_net import HydroNet
from .airfrans import AirfRANS
from .jodie import JODIEDataset
from .wikidata import Wikidata5M

from .dbp15k import DBP15K
from .aminer import AMiner
@@ -174,6 +175,7 @@
    'HydroNet',
    'AirfRANS',
    'JODIEDataset',
    'Wikidata5M',
]

hetero_datasets = [
131 changes: 131 additions & 0 deletions torch_geometric/datasets/wikidata.py
@@ -0,0 +1,131 @@
import os
import os.path as osp
from typing import Callable, Dict, List, Optional

import torch

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_tar,
)


class Wikidata5M(InMemoryDataset):
    r"""The Wikidata-5M dataset from the `"KEPLER: A Unified Model for
    Knowledge Embedding and Pre-trained Language Representation"
    <https://arxiv.org/pdf/1911.06136.pdf>`_ paper,
    containing 4,594,485 entities, 822 relations,
    20,614,279 train triples, 5,163 validation triples, and 5,133 test
    triples.
    `Wikidata-5M <https://deepgraphlearning.github.io/project/wikidata5m>`_
    is a large-scale knowledge graph dataset with an aligned corpus
    extracted from Wikidata.

    Args:
        root (str): Root directory where the dataset should be saved.
        setting (str, optional):
            If :obj:`"transductive"`, loads the transductive dataset.
            If :obj:`"inductive"`, loads the inductive dataset.
            (default: :obj:`"transductive"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """
    def __init__(
        self,
        root: str,
        setting: str = 'transductive',
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
    ):
        if setting not in {'transductive', 'inductive'}:
            raise ValueError(f"Invalid 'setting' argument (got '{setting}')")

        self.setting = setting

        self.urls = [
            ('https://www.dropbox.com/s/7jp4ib8zo3i6m10/'
             'wikidata5m_text.txt.gz?dl=1'),
            'https://uni-bielefeld.sciebo.de/s/yuBKzBxsEc9j3hy/download',
        ]
        if self.setting == 'inductive':
            self.urls.append('https://www.dropbox.com/s/csed3cgal3m7rzo/'
                             'wikidata5m_inductive.tar.gz?dl=1')
        else:
            self.urls.append('https://www.dropbox.com/s/6sbhm0rwo4l73jq/'
                             'wikidata5m_transductive.tar.gz?dl=1')

        super().__init__(root, transform, pre_transform)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return [
            'wikidata5m_text.txt.gz',
            'download',
            f'wikidata5m_{self.setting}_train.txt',
            f'wikidata5m_{self.setting}_valid.txt',
            f'wikidata5m_{self.setting}_test.txt',
        ]

    @property
    def processed_file_names(self) -> str:
        return f'{self.setting}_data.pt'

    def download(self):
        for url in self.urls:
            download_url(url, self.raw_dir)
        path = osp.join(self.raw_dir, f'wikidata5m_{self.setting}.tar.gz')
        extract_tar(path, self.raw_dir)
        os.remove(path)

    def process(self):
        import gzip

        # Map each entity identifier to a consecutive integer index:
        entity_to_id: Dict[str, int] = {}
        with gzip.open(self.raw_paths[0], 'rt') as f:
            for i, line in enumerate(f):
                values = line.strip().split('\t')
                entity_to_id[values[0]] = i

        # Pre-computed text embeddings serve as node features:
        x = torch.load(self.raw_paths[1])

        edge_index = []
        edge_type = []
        split_index = []

        # Read (head, relation, tail) triples for train/valid/test splits:
        rel_to_id: Dict[str, int] = {}
        for split, path in enumerate(self.raw_paths[2:]):
            with open(path, 'r') as f:
                for line in f:
                    head, rel, tail = line[:-1].split('\t')
                    edge_index.append([entity_to_id[head], entity_to_id[tail]])
                    if rel not in rel_to_id:
                        rel_to_id[rel] = len(rel_to_id)
                    edge_type.append(rel_to_id[rel])
                    split_index.append(split)

        edge_index = torch.tensor(edge_index).t().contiguous()
        edge_type = torch.tensor(edge_type)
        split_index = torch.tensor(split_index)

        data = Data(
            x=x,
            edge_index=edge_index,
            edge_type=edge_type,
            train_mask=split_index == 0,
            val_mask=split_index == 1,
            test_mask=split_index == 2,
        )

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])
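
A minimal usage sketch (not part of the diff), assuming the download URLs above are reachable and that the root directory `data/Wikidata5M` (a hypothetical path) is writable:

from torch_geometric.datasets import Wikidata5M

# Downloads, extracts, and processes the transductive split on first use;
# pass setting='inductive' for the inductive split instead:
dataset = Wikidata5M(root='data/Wikidata5M', setting='transductive')
data = dataset[0]  # A single Data object holding the full graph.

# Triples are stored edge-wise; the boolean masks recover each split:
train_edges = data.edge_index[:, data.train_mask]
train_types = data.edge_type[data.train_mask]
print(data.x.size())  # Pre-computed text embeddings as node features.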
