Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- PR #881 Raft integration infrastructure

## Improvements
- PR #917 Remove gunrock option from Betweenness Centrality
- PR #764 Updated sssp and bfs with GraphCSR, removed gdf_column, added nullptr weights test for sssp
- PR #765 Remove gdf_column from connected components
- PR #780 Remove gdf_column from cuhornet features
Expand Down
20 changes: 6 additions & 14 deletions cpp/include/algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,6 @@ void overlap_list(experimental::GraphCSRView<VT, ET, WT> const &graph,
VT const *second,
WT *result);

enum class cugraph_bc_implem_t {
CUGRAPH_DEFAULT = 0, ///> Native cugraph implementation
CUGRAPH_GUNROCK ///> Gunrock implementation
};

/**
*
* @brief ForceAtlas2 is a continuous graph layout algorithm
Expand Down Expand Up @@ -270,7 +265,7 @@ void force_atlas2(experimental::GraphCOOView<VT, ET, WT> &graph,
* Betweenness centrality for a vertex is the sum of the fraction of
* all pairs shortest paths that pass through the vertex.
*
* Note that both the native and the gunrock implementations do not support a weighted graph.
* The current implementation does not support a weighted graph.
*
* @throws cugraph::logic_error with a custom message when an error
* occurs.
Expand All @@ -295,19 +290,16 @@ void force_atlas2(experimental::GraphCOOView<VT, ET, WT> &graph,
* @param[in] vertices If specified, host array of vertex ids to estimate betweenness
* centrality, these vertices will serve as sources for the traversal algorihtm to obtain
* shortest path counters.
* @param[in] implem Cugraph currently supports 2 implementations: native and
* gunrock
*
*/
template <typename VT, typename ET, typename WT, typename result_t>
void betweenness_centrality(experimental::GraphCSRView<VT, ET, WT> const &graph,
result_t *result,
bool normalized = true,
bool endpoints = false,
WT const *weight = nullptr,
VT k = 0,
VT const *vertices = nullptr,
cugraph_bc_implem_t implem = cugraph_bc_implem_t::CUGRAPH_DEFAULT);
bool normalized = true,
bool endpoints = false,
WT const *weight = nullptr,
VT k = 0,
VT const *vertices = nullptr);

enum class cugraph_cc_t {
CUGRAPH_WEAK = 0, ///> Weakly Connected Components
Expand Down
123 changes: 4 additions & 119 deletions cpp/src/centrality/betweenness_centrality.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@

#include <utilities/error_utils.h>

#include <gunrock/gunrock.h>

#include "betweenness_centrality.cuh"

namespace cugraph {
Expand Down Expand Up @@ -295,92 +293,6 @@ void betweenness_centrality(experimental::GraphCSRView<VT, ET, WT> const &graph,
}
} // namespace detail

namespace gunrock {

// NOTE: sample_seeds is not really available anymore, as it has been
// replaced by k and vertices parameters, delegating the random
// generation to somewhere else (i.e python's side)
template <typename VT, typename ET, typename WT, typename result_t>
void betweenness_centrality(experimental::GraphCSRView<VT, ET, WT> const &graph,
result_t *result,
bool normalize,
VT const *sample_seeds = nullptr,
VT number_of_sample_seeds = 0)
{
cudaStream_t stream{nullptr};

//
// gunrock currently (as of 2/28/2020) only operates on a graph and results in
// host memory. [That is, the first step in gunrock is to allocate device memory
// and copy the data into device memory, the last step is to allocate host memory
// and copy the results into the host memory]
//
// They are working on fixing this. In the meantime, to get the features into
// cuGraph we will first copy the graph back into local memory and when we are finished
// copy the result back into device memory.
//
std::vector<ET> v_offsets(graph.number_of_vertices + 1);
std::vector<VT> v_indices(graph.number_of_edges);
std::vector<float> v_result(graph.number_of_vertices);
std::vector<float> v_sigmas(graph.number_of_vertices);
std::vector<int> v_labels(graph.number_of_vertices);

// fill them
CUDA_TRY(cudaMemcpy(v_offsets.data(),
graph.offsets,
sizeof(ET) * (graph.number_of_vertices + 1),
cudaMemcpyDeviceToHost));
CUDA_TRY(cudaMemcpy(
v_indices.data(), graph.indices, sizeof(VT) * graph.number_of_edges, cudaMemcpyDeviceToHost));

if (sample_seeds == nullptr) {
bc(graph.number_of_vertices,
graph.number_of_edges,
v_offsets.data(),
v_indices.data(),
-1,
v_result.data(),
v_sigmas.data(),
v_labels.data());
} else {
//
// Gunrock, as currently implemented
// doesn't support this method.
//
CUGRAPH_FAIL("gunrock doesn't currently support sampling seeds");
}

// copy to results
CUDA_TRY(cudaMemcpy(
result, v_result.data(), sizeof(result_t) * graph.number_of_vertices, cudaMemcpyHostToDevice));

// Rescale result (Based on normalize and directed/undirected)
if (normalize) {
if (graph.number_of_vertices > 2) {
float denominator = (graph.number_of_vertices - 1) * (graph.number_of_vertices - 2);

thrust::transform(rmm::exec_policy(stream)->on(stream),
result,
result + graph.number_of_vertices,
result,
[denominator] __device__(float f) { return (f * 2) / denominator; });
}
} else {
//
// gunrock answer needs to be doubled to match networkx
//
if (graph.prop.directed) {
thrust::transform(rmm::exec_policy(stream)->on(stream),
result,
result + graph.number_of_vertices,
result,
[] __device__(float f) { return (f * 2); });
}
}
}

} // namespace gunrock

/**
* @param[out] result array<result_t>(number_of_vertices)
* @param[in] normalize bool True -> Apply normalization
Expand All @@ -396,34 +308,9 @@ void betweenness_centrality(experimental::GraphCSRView<VT, ET, WT> const &graph,
bool endpoints,
WT const *weight,
VT k,
VT const *vertices,
cugraph_bc_implem_t implem)
VT const *vertices)
{
// FIXME: Gunrock call returns float and not result_t hence the implementation
// switch
if ((typeid(result_t) == typeid(double)) && (implem == cugraph_bc_implem_t::CUGRAPH_GUNROCK)) {
implem = cugraph_bc_implem_t::CUGRAPH_DEFAULT;
std::cerr << "[WARN] result_t type is 'double', switching to default "
<< "implementation" << std::endl;
}
//
// NOTE: gunrock implementation doesn't yet support the unused parameters:
// - endpoints
// - weight
// - k
// - vertices
//
// These parameters are present in the API to support future features.
//
if (implem == cugraph_bc_implem_t::CUGRAPH_DEFAULT) {
detail::betweenness_centrality(graph, result, normalize, endpoints, weight, k, vertices);
} else if (implem == cugraph_bc_implem_t::CUGRAPH_GUNROCK) {
gunrock::betweenness_centrality(graph, result, normalize);
} else {
CUGRAPH_FAIL(
"Invalid Betweenness Centrality implementation, please refer to cugraph_bc_implem_t for "
"valid implementations");
}
detail::betweenness_centrality(graph, result, normalize, endpoints, weight, k, vertices);
}

template void betweenness_centrality<int, int, float, float>(
Expand All @@ -433,16 +320,14 @@ template void betweenness_centrality<int, int, float, float>(
bool,
float const *,
int,
int const *,
cugraph_bc_implem_t);
int const *);
template void betweenness_centrality<int, int, double, double>(
experimental::GraphCSRView<int, int, double> const &,
double *,
bool,
bool,
double const *,
int,
int const *,
cugraph_bc_implem_t);
int const *);

} // namespace cugraph
6 changes: 2 additions & 4 deletions cpp/tests/centrality/betweenness_centrality_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ class Tests_BC : public ::testing::TestWithParam<BC_Usecase> {
endpoints,
static_cast<WT *>(nullptr),
configuration.number_of_sources_,
sources_ptr,
cugraph::cugraph_bc_implem_t::CUGRAPH_DEFAULT),
sources_ptr),
cugraph::logic_error);
return;
} else {
Expand All @@ -294,8 +293,7 @@ class Tests_BC : public ::testing::TestWithParam<BC_Usecase> {
endpoints,
static_cast<WT *>(nullptr),
configuration.number_of_sources_,
sources_ptr,
cugraph::cugraph_bc_implem_t::CUGRAPH_DEFAULT);
sources_ptr);
}
cudaDeviceSynchronize();
CUDA_TRY(cudaMemcpy(result.data(),
Expand Down
7 changes: 1 addition & 6 deletions python/cugraph/centrality/betweenness_centrality.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,12 @@ from libcpp cimport bool

cdef extern from "algorithms.hpp" namespace "cugraph":

ctypedef enum cugraph_bc_implem_t:
CUGRAPH_DEFAULT "cugraph::cugraph_bc_implem_t::CUGRAPH_DEFAULT"
CUGRAPH_GUNROCK "cugraph::cugraph_bc_implem_t::CUGRAPH_GUNROCK"

cdef void betweenness_centrality[VT,ET,WT,result_t](
const GraphCSRView[VT,ET,WT] &graph,
result_t *result,
bool normalized,
bool endpoints,
const WT *weight,
VT k,
const VT *vertices,
cugraph_bc_implem_t implem) except +
const VT *vertices) except +

25 changes: 1 addition & 24 deletions python/cugraph/centrality/betweenness_centrality.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

# NOTE: result_type=float could ne an intuitive way to indicate the result type
def betweenness_centrality(G, k=None, normalized=True,
weight=None, endpoints=False, implementation=None,
weight=None, endpoints=False,
seed=None, result_dtype=np.float64):
"""
Compute the betweenness centrality for all nodes of the graph G from a
Expand Down Expand Up @@ -63,13 +63,6 @@ def betweenness_centrality(G, k=None, normalized=True,
If true, include the endpoints in the shortest path counts.
(Not Supported)

implementation : string, optional, default=None
if implementation is None or "default", uses native cugraph,
if "gunrock" uses gunrock based bc.
The default version supports normalized, k and seed options.
"gunrock" might be faster when considering all the sources, but
only return float results and consider all the vertices as sources.

seed : optional
if k is specified and k is an integer, use seed to initialize the
random number generator.
Expand All @@ -79,7 +72,6 @@ def betweenness_centrality(G, k=None, normalized=True,

result_dtype : np.float32 or np.float64, optional, default=np.float64
Indicate the data type of the betweenness centrality scores
Using double automatically switch implementation to "default"

Returns
-------
Expand All @@ -103,29 +95,15 @@ def betweenness_centrality(G, k=None, normalized=True,
>>> bc = cugraph.betweenness_centrality(G)
"""

#
# Some features not implemented in gunrock implementation, failing fast,
# but passing parameters through
#
# vertices is intended to be a cuDF series that contains a sampling of
# k vertices out of the graph.
#
# NOTE: cuDF doesn't currently support sampling, but there is a python
# workaround.
#
vertices = None
if implementation is None:
implementation = "default"
if implementation not in ["default", "gunrock"]:
raise ValueError("Only two implementations are supported: 'default' "
"and 'gunrock'")

if k is not None:
if implementation == "gunrock":
raise ValueError("sampling feature of betweenness "
"centrality not currently supported "
"with gunrock implementation, "
"please use None or 'default'")
# In order to compare with pre-set sources,
# k can either be a list or an integer or None
# int: Generate an random sample with k elements
Expand Down Expand Up @@ -171,6 +149,5 @@ def betweenness_centrality(G, k=None, normalized=True,
endpoints,
weight,
k, vertices,
implementation,
result_dtype)
return df
19 changes: 3 additions & 16 deletions python/cugraph/centrality/betweenness_centrality_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# cython: language_level = 3

from cugraph.centrality.betweenness_centrality cimport betweenness_centrality as c_betweenness_centrality
from cugraph.centrality.betweenness_centrality cimport cugraph_bc_implem_t
from cugraph.structure.graph_new cimport *
from cugraph.utilities.unrenumber import unrenumber
from libcpp cimport bool
Expand All @@ -31,23 +30,13 @@ import numpy.ctypeslib as ctypeslib


def betweenness_centrality(input_graph, normalized, endpoints, weight, k,
vertices, implementation, result_dtype):
vertices, result_dtype):
"""
Call betweenness centrality
"""
# NOTE: This is based on the fact that the call to the wrapper already
# checked for the validity of the implementation parameter
cdef cugraph_bc_implem_t bc_implementation = cugraph_bc_implem_t.CUGRAPH_DEFAULT
cdef GraphCSRView[int, int, float] graph_float
cdef GraphCSRView[int, int, double] graph_double

if (implementation == "default"): # Redundant
bc_implementation = cugraph_bc_implem_t.CUGRAPH_DEFAULT
elif (implementation == "gunrock"):
bc_implementation = cugraph_bc_implem_t.CUGRAPH_GUNROCK
else:
raise ValueError()

if not input_graph.adjlist:
input_graph.view_adj_list()

Expand Down Expand Up @@ -94,8 +83,7 @@ def betweenness_centrality(input_graph, normalized, endpoints, weight, k,
<float*> c_betweenness,
normalized, endpoints,
<float*> c_weight, c_k,
<int*> c_vertices,
<cugraph_bc_implem_t> bc_implementation)
<int*> c_vertices)
graph_float.get_vertex_identifiers(<int*>c_identifier)
elif result_dtype == np.float64:
graph_double = GraphCSRView[int, int, double](<int*>c_offsets, <int*>c_indices,
Expand All @@ -107,8 +95,7 @@ def betweenness_centrality(input_graph, normalized, endpoints, weight, k,
<double*> c_betweenness,
normalized, endpoints,
<double*> c_weight, c_k,
<int*> c_vertices,
<cugraph_bc_implem_t> bc_implementation)
<int*> c_vertices)
graph_double.get_vertex_identifiers(<int*>c_identifier)
else:
raise TypeError("result type for betweenness centrality can only be "
Expand Down
Loading