From 58e325fcfd52ea13ae5c676e9b892b369eac551e Mon Sep 17 00:00:00 2001
From: Moritz Blum <31183934+moritzblum@users.noreply.github.com>
Date: Thu, 17 Aug 2023 21:02:21 +0900
Subject: [PATCH] `Wikidata5M` dataset (#7864)

Implements the Wikidata5M transductive and inductive link prediction
datasets.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s
---
 CHANGELOG.md                         |   1 +
 torch_geometric/datasets/__init__.py |   2 +
 torch_geometric/datasets/wikidata.py | 131 +++++++++++++++++++++++++++
 3 files changed, 134 insertions(+)
 create mode 100644 torch_geometric/datasets/wikidata.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 49fe53f4f98f..456be23a1e3f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added the `Wikidata5M` dataset ([#7864](https://github.com/pyg-team/pytorch_geometric/pull/7864))
 - Added TorchScript support inside `BasicGNN` models ([#7865](https://github.com/pyg-team/pytorch_geometric/pull/7865))
 - Added a `batch_size` argument to `unbatch` functionalities ([#7851](https://github.com/pyg-team/pytorch_geometric/pull/7851))
 - Added a distributed example using `graphlearn-for-pytorch` ([#7402](https://github.com/pyg-team/pytorch_geometric/pull/7402))
diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py
index 399feed97b82..e7287d5633e3 100644
--- a/torch_geometric/datasets/__init__.py
+++ b/torch_geometric/datasets/__init__.py
@@ -71,6 +71,7 @@
 from .hydro_net import HydroNet
 from .airfrans import AirfRANS
 from .jodie import JODIEDataset
+from .wikidata import Wikidata5M
 from .dbp15k import DBP15K
 from .aminer import AMiner
@@ -174,6 +175,7 @@
     'HydroNet',
     'AirfRANS',
     'JODIEDataset',
+    'Wikidata5M',
 ]
 
 hetero_datasets = [
diff --git a/torch_geometric/datasets/wikidata.py b/torch_geometric/datasets/wikidata.py
new file mode 100644
index 000000000000..c06b82e7cb46
--- /dev/null
+++ b/torch_geometric/datasets/wikidata.py
@@ -0,0 +1,131 @@
+import os
+import os.path as osp
+from typing import Callable, Dict, List, Optional
+
+import torch
+
+from torch_geometric.data import (
+    Data,
+    InMemoryDataset,
+    download_url,
+    extract_tar,
+)
+
+
+class Wikidata5M(InMemoryDataset):
+    r"""The Wikidata-5M dataset from the `"KEPLER: A Unified Model for
+    Knowledge Embedding and Pre-trained Language Representation"
+    <https://arxiv.org/abs/1911.06136>`_ paper,
+    containing 4,594,485 entities, 822 relations, 20,614,279 train triples,
+    5,163 validation triples, and 5,133 test triples.
+
+    `Wikidata-5M <https://deepgraphlearning.github.io/project/wikidata5m>`_
+    is a large-scale knowledge graph dataset with an aligned corpus
+    extracted from Wikidata.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        setting (str, optional):
+            If :obj:`"transductive"`, loads the transductive dataset.
+            If :obj:`"inductive"`, loads the inductive dataset.
+            (default: :obj:`"transductive"`)
+        transform (callable, optional): A function/transform that takes in an
+            :obj:`torch_geometric.data.Data` object and returns a transformed
+            version. The data object will be transformed before every access.
+            (default: :obj:`None`)
+        pre_transform (callable, optional): A function/transform that takes in
+            an :obj:`torch_geometric.data.Data` object and returns a
+            transformed version. The data object will be transformed before
+            being saved to disk. (default: :obj:`None`)
+    """
+    def __init__(
+        self,
+        root: str,
+        setting: str = 'transductive',
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+    ):
+        if setting not in {'transductive', 'inductive'}:
+            raise ValueError(f"Invalid 'setting' argument (got '{setting}')")
+
+        self.setting = setting
+
+        self.urls = [
+            ('https://www.dropbox.com/s/7jp4ib8zo3i6m10/'
+             'wikidata5m_text.txt.gz?dl=1'),
+            'https://uni-bielefeld.sciebo.de/s/yuBKzBxsEc9j3hy/download',
+        ]
+        if self.setting == 'inductive':
+            self.urls.append('https://www.dropbox.com/s/csed3cgal3m7rzo/'
+                             'wikidata5m_inductive.tar.gz?dl=1')
+        else:
+            self.urls.append('https://www.dropbox.com/s/6sbhm0rwo4l73jq/'
+                             'wikidata5m_transductive.tar.gz?dl=1')
+
+        super().__init__(root, transform, pre_transform)
+        self.load(self.processed_paths[0])
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return [
+            'wikidata5m_text.txt.gz',
+            'download',
+            f'wikidata5m_{self.setting}_train.txt',
+            f'wikidata5m_{self.setting}_valid.txt',
+            f'wikidata5m_{self.setting}_test.txt',
+        ]
+
+    @property
+    def processed_file_names(self) -> str:
+        return f'{self.setting}_data.pt'
+
+    def download(self):
+        for url in self.urls:
+            download_url(url, self.raw_dir)
+        path = osp.join(self.raw_dir, f'wikidata5m_{self.setting}.tar.gz')
+        extract_tar(path, self.raw_dir)
+        os.remove(path)
+
+    def process(self):
+        import gzip
+
+        entity_to_id: Dict[str, int] = {}
+        with gzip.open(self.raw_paths[0], 'rt') as f:
+            for i, line in enumerate(f):
+                values = line.strip().split('\t')
+                entity_to_id[values[0]] = i
+
+        x = torch.load(self.raw_paths[1])
+
+        edge_index = []
+        edge_type = []
+        split_index = []
+
+        rel_to_id: Dict[str, int] = {}
+        for split, path in enumerate(self.raw_paths[2:]):
+            with open(path, 'r') as f:
+                for line in f:
+                    head, rel, tail = line[:-1].split('\t')
+                    edge_index.append([entity_to_id[head], entity_to_id[tail]])
+                    if rel not in rel_to_id:
+                        rel_to_id[rel] = len(rel_to_id)
+                    edge_type.append(rel_to_id[rel])
+                    split_index.append(split)
+
+        edge_index = torch.tensor(edge_index).t().contiguous()
+        edge_type = torch.tensor(edge_type)
+        split_index = torch.tensor(split_index)
+
+        data = Data(
+            x=x,
+            edge_index=edge_index,
+            edge_type=edge_type,
+            train_mask=split_index == 0,
+            val_mask=split_index == 1,
+            test_mask=split_index == 2,
+        )
+
+        if self.pre_transform is not None:
+            data = self.pre_transform(data)
+
+        self.save([data], self.processed_paths[0])
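
For reviewers trying out the branch, here is a minimal usage sketch (not part of the patch): it loads the transductive split and recovers per-split triples from the edge-level masks that `process()` attaches. The `root` path is illustrative.

```python
# Minimal usage sketch, assuming a PyG checkout containing this patch;
# the `root` path is an arbitrary example, not part of the patch.
from torch_geometric.datasets import Wikidata5M

dataset = Wikidata5M(root='data/Wikidata5M', setting='transductive')
data = dataset[0]  # a single `Data` object holding all splits

# `process()` stores every triple in one `edge_index`/`edge_type` pair
# and marks each edge's split with boolean masks:
train_edge_index = data.edge_index[:, data.train_mask]  # [2, 20614279]
train_edge_type = data.edge_type[data.train_mask]
val_edge_index = data.edge_index[:, data.val_mask]      # [2, 5163]
test_edge_index = data.edge_index[:, data.test_mask]    # [2, 5133]

# The inductive split is selected the same way:
# dataset = Wikidata5M(root='data/Wikidata5M', setting='inductive')
```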
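Design note: rather than materializing three separate graphs, the patch keeps all 20,624,575 triples (20,614,279 + 5,163 + 5,133) in a single `Data` object and distinguishes splits via edge-level `train_mask`/`val_mask`/`test_mask`, so downstream transforms and samplers operate on the full graph. Node features `x` come from the second download (the sciebo `download` file), which `process()` reads via `torch.load`; the code assumes its rows are aligned with the line order of `wikidata5m_text.txt.gz`, since that order defines `entity_to_id`. Note also that relation IDs in `edge_type` are assigned in first-seen order across the train/valid/test files, so they are an artifact of processing rather than fixed by the raw data.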