Skip to content

Commit

Permalink
add node weights for metis wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Chendi Qian committed Mar 25, 2021
1 parent 54d8418 commit 247154f
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 23 deletions.
30 changes: 26 additions & 4 deletions csrc/cpu/metis_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> vweights,
int64_t num_parts, bool recursive) {
#ifdef WITH_METIS
CHECK_CPU(rowptr);
Expand All @@ -22,22 +23,33 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
CHECK_INPUT(optional_value.value().numel() == col.numel());
}

if (vweights.has_value()) {
CHECK_CPU(vweights.value());
CHECK_INPUT(vweights.value().dim() == 1);
CHECK_INPUT(vweights.value().numel() == rowptr.numel() - 1);
}

int64_t nvtxs = rowptr.numel() - 1;
int64_t ncon = 1;
auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>();
int64_t *adjwgt = NULL;
if (optional_value.has_value())
adjwgt = optional_value.value().data_ptr<int64_t>();

int64_t *vwgt = NULL;
if (vweights.has_value())
vwgt = vweights.value().data_ptr<int64_t>();

int64_t objval = -1;
auto part = torch::empty(nvtxs, rowptr.options());
auto part_data = part.data_ptr<int64_t>();

if (recursive) {
METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, vwgt, NULL, adjwgt,
&num_parts, NULL, NULL, NULL, &objval, part_data);
} else {
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, vwgt, NULL, adjwgt,
&num_parts, NULL, NULL, NULL, &objval, part_data);
}

Expand All @@ -52,6 +64,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
// --partitions64bit
torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> vweights,
int64_t num_parts, bool recursive,
int64_t num_workers) {
#ifdef WITH_MTMETIS
Expand All @@ -63,13 +76,22 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
CHECK_INPUT(optional_value.value().numel() == col.numel());
}

if (vweights.has_value()) {
CHECK_CPU(vweights.value());
CHECK_INPUT(vweights.value().dim() == 1);
CHECK_INPUT(vweights.value().numel() == rowptr.numel() - 1);
}

mtmetis_vtx_type nvtxs = rowptr.numel() - 1;
mtmetis_vtx_type ncon = 1;
mtmetis_adj_type *xadj = (mtmetis_adj_type *)rowptr.data_ptr<int64_t>();
mtmetis_vtx_type *adjncy = (mtmetis_vtx_type *)col.data_ptr<int64_t>();
mtmetis_wgt_type *adjwgt = NULL;
if (optional_value.has_value())
adjwgt = optional_value.value().data_ptr<int64_t>();
mtmetis_wgt_type *vwgt = NULL;
if (vweights.has_value())
vwgt = vweights.value().data_ptr<int64_t>();
mtmetis_pid_type nparts = num_parts;
mtmetis_wgt_type objval = -1;
auto part = torch::empty(nvtxs, rowptr.options());
Expand All @@ -79,10 +101,10 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
opts[MTMETIS_OPTION_NTHREADS] = num_workers;

if (recursive) {
MTMETIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
MTMETIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, vwgt, NULL, adjwgt,
&nparts, NULL, NULL, opts, &objval, part_data);
} else {
MTMETIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
MTMETIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, vwgt, NULL, adjwgt,
&nparts, NULL, NULL, opts, &objval, part_data);
}

Expand Down
2 changes: 2 additions & 0 deletions csrc/cpu/metis_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> vweights,
int64_t num_parts, bool recursive);

torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> vweights,
int64_t num_parts, bool recursive,
int64_t num_workers);
6 changes: 4 additions & 2 deletions csrc/metis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }

torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> vweights,
int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
Expand All @@ -21,12 +22,13 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return partition_cpu(rowptr, col, optional_value, num_parts, recursive);
return partition_cpu(rowptr, col, optional_value, vweights, num_parts, recursive);
}
}

torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> vweights,
int64_t num_parts, bool recursive,
int64_t num_workers) {
if (rowptr.device().is_cuda()) {
Expand All @@ -36,7 +38,7 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return mt_partition_cpu(rowptr, col, optional_value, num_parts, recursive,
return mt_partition_cpu(rowptr, col, optional_value, vweights, num_parts, recursive,
num_workers);
}
}
Expand Down
38 changes: 22 additions & 16 deletions test/test_metis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,26 @@ def test_metis(device):
value2 = torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6)
value3 = torch.ones(6 * 6, device=device).view(6, 6)

vwgts = torch.rand(6, device=device)

for value in [value1, value2, value3]:
mat = SparseTensor.from_dense(value)

_, partptr, perm = mat.partition(num_parts=2, recursive=False,
weighted=True)
assert partptr.numel() == 3
assert perm.numel() == 6

_, partptr, perm = mat.partition(num_parts=2, recursive=False,
weighted=False)
assert partptr.numel() == 3
assert perm.numel() == 6

_, partptr, perm = mat.partition(num_parts=1, recursive=False,
weighted=True)
assert partptr.numel() == 2
assert perm.numel() == 6
for vwgt in [None, vwgts]:
mat = SparseTensor.from_dense(value)

_, partptr, perm = mat.partition(num_parts=2, recursive=False,
vweights=vwgt,
weighted=True)
assert partptr.numel() == 3
assert perm.numel() == 6

_, partptr, perm = mat.partition(num_parts=2, recursive=False,
vweights=vwgt,
weighted=False)
assert partptr.numel() == 3
assert perm.numel() == 6

_, partptr, perm = mat.partition(num_parts=1, recursive=False,
vweights=vwgt,
weighted=True)
assert partptr.numel() == 2
assert perm.numel() == 6
9 changes: 8 additions & 1 deletion torch_sparse/metis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:

def partition(
src: SparseTensor, num_parts: int, recursive: bool = False,
vweights: torch.tensor = None,
weighted: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:

Expand All @@ -41,7 +42,13 @@ def partition(
else:
value = None

cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
if vweights is not None:
assert vweights.numel() == rowptr.numel() - 1
vweights = vweights.view(-1).detach().cpu()
if vweights.is_floating_point():
vweights = weight2metis(vweights)

cluster = torch.ops.torch_sparse.partition(rowptr, col, value, vweights, num_parts,
recursive)
cluster = cluster.to(src.device())

Expand Down

0 comments on commit 247154f

Please sign in to comment.