diff --git a/.gitignore b/.gitignore index 8b07c26aa608..7125c18decb9 100644 --- a/.gitignore +++ b/.gitignore @@ -76,6 +76,7 @@ R-package/inst/* *.tar.gz *.tgz R-package/man/*.Rd +R-package/R/mxnet_generated.R # data *.rec diff --git a/Makefile b/Makefile index c2720acd4d33..e3f4cd0b42e1 100644 --- a/Makefile +++ b/Makefile @@ -300,8 +300,10 @@ rpkg: echo "import(methods)" >> R-package/NAMESPACE R CMD INSTALL R-package Rscript -e "require(mxnet); mxnet:::mxnet.export(\"R-package\")" + rm -rf R-package/NAMESPACE Rscript -e "require(roxygen2); roxygen2::roxygenise(\"R-package\")" R CMD build --no-build-vignettes R-package + rm -rf mxnet_current_r.tar.gz mv mxnet_*.tar.gz mxnet_current_r.tar.gz scalapkg: diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index 8d8e7532110f..0d19050f1506 100644 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -15,7 +15,7 @@ BugReports: https://github.com/dmlc/mxnet/issues Imports: methods, Rcpp (>= 0.12.1), - DiagrammeR (>= 0.8.1), + DiagrammeR (>= 0.9.0), data.table, jsonlite, magrittr, diff --git a/R-package/R/viz.graph.R b/R-package/R/viz.graph.R index 695e768a2406..0587ce432fcc 100644 --- a/R-package/R/viz.graph.R +++ b/R-package/R/viz.graph.R @@ -9,10 +9,10 @@ #' @importFrom data.table := #' @importFrom data.table setkey #' @importFrom jsonlite fromJSON -#' @importFrom DiagrammeR create_nodes +#' @importFrom DiagrammeR create_node_df #' @importFrom DiagrammeR create_graph -#' @importFrom DiagrammeR create_edges -#' @importFrom DiagrammeR combine_edges +#' @importFrom DiagrammeR create_edge_df +#' @importFrom DiagrammeR combine_edfs #' @importFrom DiagrammeR render_graph #' #' @param model a \code{string} representing the path to a file containing the \code{JSon} of a model dump or the actual model dump. @@ -106,8 +106,8 @@ graph.viz <- function(model, graph.title = "Computation graph", graph.title.font mx.model.nodes[,id] %>% unique %>% setdiff(nodes.to.keep) %>% sort nodes <- - create_nodes( - nodes = mx.model.nodes[id %in% nodes.to.keep, id], + create_node_df( + n = length(mx.model.nodes[id %in% nodes.to.keep, id]), label = mx.model.nodes[id %in% nodes.to.keep, label], type = "lower", style = "filled", @@ -118,6 +118,8 @@ graph.viz <- function(model, graph.title = "Computation graph", graph.title.font width = "1.3", height = "0.8034" ) + + nodes$id <- mx.model.nodes[id %in% nodes.to.keep, id] mx.model.nodes[,has.connection:= sapply(inputs, function(x) length(x) > 0)] @@ -132,24 +134,24 @@ graph.viz <- function(model, graph.title = "Computation graph", graph.title.font origin <- nodes.to.insert[i, inputs][[1]][,1] %>% setdiff(nodes.to.remove) %>% unique destination <- rep(current.id, length(origin)) - edges.temp <- create_edges(from = origin, - to = destination, + edges.temp <- create_edge_df(from = as.character(origin), + to = as.character(destination), relationship = "leading_to") if (is.null(edges)) edges <- edges.temp else - edges <- combine_edges(edges.temp, edges) + edges <- combine_edfs(edges.temp, edges) } graph <- create_graph( nodes_df = nodes, - edges_df = edges, - directed = TRUE, + edges_df = edges#, +# directed = TRUE#, # node_attrs = c("fontname = Helvetica"), - graph_attrs = paste0("label = \"", graph.title, "\"") %>% c(paste0("fontname = ", graph.title.font.name)) %>% c(paste0("fontsize = ", graph.title.font.size)) %>% c("labelloc = t"), +# graph_attrs = paste0("label = \"", graph.title, "\"") %>% c(paste0("fontname = ", graph.title.font.name)) %>% c(paste0("fontsize = ", graph.title.font.size)) %>% c("labelloc = t"), # node_attrs = "fontname = Helvetica", - edge_attrs = c("color = gray20", "arrowsize = 0.8", "arrowhead = vee") + # edge_attrs = c("color = gray20", "arrowsize = 0.8", "arrowhead = vee") ) return(render_graph(graph, width = graph.width.px, height = graph.height.px))