Skip to content

Commit f4a7d7a

Browse files
committed
simplify symbol remapping
In Keras 2.13, almost all symbols moved from `keras.*` to `keras.src.*`. The first approach registered aliases for all the necessary S3 methods, and checked for both `keras.src.*` and `keras.src.*` in various parts in the sources. The new, simpler approach involves registereing a class filter with reticualte, so any S3 class name starting with `keras.src.*` is automatically renamed to `keras.*`.
1 parent b2c3a32 commit f4a7d7a

File tree

11 files changed

+18
-71
lines changed

11 files changed

+18
-71
lines changed

NAMESPACE

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,26 @@
33
S3method("$",py_R6ClassGenerator)
44
S3method(as.data.frame,keras_training_history)
55
S3method(compile,keras.engine.training.Model)
6-
S3method(compile,keras.src.engine.training.Model)
76
S3method(evaluate,keras.engine.training.Model)
8-
S3method(evaluate,keras.src.engine.training.Model)
97
S3method(export_savedmodel,keras.engine.training.Model)
10-
S3method(export_savedmodel,keras.src.engine.training.Model)
118
S3method(fit,keras.engine.training.Model)
12-
S3method(fit,keras.src.engine.training.Model)
139
S3method(format,keras.engine.training.Model)
14-
S3method(format,keras.src.engine.training.Model)
1510
S3method(pillar::type_sum,py_R6ClassGenerator)
1611
S3method(plot,keras.engine.training.Model)
17-
S3method(plot,keras.src.engine.training.Model)
1812
S3method(plot,keras_training_history)
1913
S3method(predict,keras.engine.training.Model)
20-
S3method(predict,keras.src.engine.training.Model)
2114
S3method(print,keras.engine.training.Model)
22-
S3method(print,keras.src.engine.training.Model)
2315
S3method(print,keras_training_history)
2416
S3method(print,kerastools.model.RModel)
2517
S3method(print,py_R6ClassGenerator)
2618
S3method(py_str,keras.engine.training.Model)
27-
S3method(py_str,keras.src.engine.training.Model)
2819
S3method(py_to_r,keras.utils.generic_utils.SharedObjectConfig)
2920
S3method(py_to_r_wrapper,keras.engine.base_layer.Layer)
3021
S3method(py_to_r_wrapper,keras.engine.training.Model)
31-
S3method(py_to_r_wrapper,keras.src.engine.base_layer.Layer)
32-
S3method(py_to_r_wrapper,keras.src.engine.training.Model)
3322
S3method(py_to_r_wrapper,kerastools.model.RModel)
3423
S3method(r_to_py,R6ClassGenerator)
3524
S3method(r_to_py,keras_layer_wrapper)
3625
S3method(summary,keras.engine.training.Model)
37-
S3method(summary,keras.src.engine.training.Model)
3826
S3method(summary,kerastools.model.RModel)
3927
export("%<-%")
4028
export("%<-active%")

R/layers-core.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,4 +562,4 @@ compose_layer.keras.models.Sequential <- function(object, layer, ...) {
562562
}
563563

564564
compose_layer.keras.engine.sequential.Sequential <- compose_layer.keras.models.Sequential
565-
compose_layer.keras.src.engine.sequential.Sequential <- compose_layer.keras.models.Sequential
565+
# compose_layer.keras.src.engine.sequential.Sequential <- compose_layer.keras.models.Sequential

R/layers-embedding.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ layer_embedding <-
6565
input_dim = as.integer,
6666
output_dim = as.integer,
6767
input_length = as_nullable_integer,
68-
batch_size = as_nullable_integer,
68+
batch_size = as_nullable_integer
6969
), ignore = "object")
7070
create_layer(keras$layers$Embedding, object, args)
7171
}

R/layers-normalization.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
#' })
5858
#' ```
5959
#'
60+
#' @param object Layer or model object
61+
#'
6062
#' @param axis Integer, the axis that should be normalized (typically the features
6163
#' axis). For instance, after a `Conv2D` layer with
6264
#' `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.

R/model-persistence.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,7 @@ model_from_yaml <- function(yaml, custom_objects = NULL) {
342342
#' @export
343343
serialize_model <- function(model, include_optimizer = TRUE) {
344344

345-
if (!inherits(model, c("keras.engine.training.Model",
346-
"keras.src.engine.training.Model")))
345+
if (!inherits(model, "keras.engine.training.Model"))
347346
stop("You must pass a Keras model object to serialize_model")
348347

349348
# write hdf5 file to temp file

R/model.R

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -555,8 +555,7 @@ resolve_input_data <- function(x, y = NULL) {
555555
args$x <- as_generator(x)
556556
} else if (inherits(x, "python.builtin.iterator")) {
557557
args$x <- x
558-
} else if (inherits(x, c("keras.src.utils.data_utils.Sequence",
559-
"keras.utils.data_utils.Sequence"))) {
558+
} else if (inherits(x, "keras.utils.data_utils.Sequence")) {
560559
args$x <- x
561560
} else {
562561
if (!is.null(x))
@@ -577,8 +576,7 @@ resolve_validation_data <- function(validation_data) {
577576
args$validation_data <- as_generator(validation_data)
578577
else if (inherits(validation_data, "python.builtin.iterator"))
579578
args$validation_data <- validation_data
580-
else if (inherits(validation_data, c("keras.src.utils.data_utils.Sequence",
581-
"keras.utils.data_utils.Sequence")))
579+
else if (inherits(validation_data, "keras.utils.data_utils.Sequence"))
582580
args$validation_data <- validation_data
583581
else {
584582
args$validation_data <- keras_array(validation_data)
@@ -1336,8 +1334,6 @@ is_main_thread_generator.keras_preprocessing.sequence.TimeseriesGenerator <- fun
13361334
FALSE
13371335
}
13381336

1339-
is_main_thread_generator.keras.src.preprocessing.sequence.TimeseriesGenerator <-
1340-
is_main_thread_generator.keras_preprocessing.sequence.TimeseriesGenerator
13411337

13421338
is_tensorflow_dataset <- function(x) {
13431339
inherits(x, "tensorflow.python.data.ops.dataset_ops.DatasetV2") ||

R/package.R

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ keras <- NULL
107107
# delay load keras
108108
keras <<- import(implementation_module, delay_load = list(
109109

110-
priority = 10,
110+
priority = 10, # tensorflow priority == 5
111111

112112
environment = "r-tensorflow",
113113

@@ -149,13 +149,15 @@ keras <- NULL
149149
if (identical(module, "tensorflow.keras"))
150150
module <- "tensorflow.python.keras"
151151

152+
# replace "tensorflow.python.keras.*" with "keras.*"
152153
classes <- sub(paste0("^", module), "keras", classes)
153154

155+
# All python symbols moved in v2.13 under .src
156+
classes <- sub("^keras\\.src\\.", "keras.", classes)
154157

155158
# let KerasTensor inherit all the S3 methods of tf.Tensor, but
156159
# KerasTensor methods take precedence.
157-
if(any(c("keras.src.engine.keras_tensor.KerasTensor",
158-
"keras.engine.keras_tensor.KerasTensor") %in% classes))
160+
if(any("keras.engine.keras_tensor.KerasTensor" %in% classes))
159161
classes <- unique(c("keras.engine.keras_tensor.KerasTensor",
160162
"tensorflow.tensor",
161163
classes))

R/zzz.R

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +0,0 @@
1-
2-
## keras internal symbols moved under `src` in TF 2.12
3-
#' @export
4-
compile.keras.src.engine.training.Model <- compile.keras.engine.training.Model
5-
6-
#' @export
7-
evaluate.keras.src.engine.training.Model <- evaluate.keras.engine.training.Model
8-
9-
#' @export
10-
export_savedmodel.keras.src.engine.training.Model <- export_savedmodel.keras.engine.training.Model
11-
12-
#' @export
13-
fit.keras.src.engine.training.Model <- fit.keras.engine.training.Model
14-
15-
#' @export
16-
format.keras.src.engine.training.Model <- format.keras.engine.training.Model
17-
18-
#' @export
19-
plot.keras.src.engine.training.Model <- plot.keras.engine.training.Model
20-
21-
#' @export
22-
predict.keras.src.engine.training.Model <- predict.keras.engine.training.Model
23-
24-
#' @export
25-
print.keras.src.engine.training.Model <- print.keras.engine.training.Model
26-
27-
#' @export
28-
py_str.keras.src.engine.training.Model <- py_str.keras.engine.training.Model
29-
30-
#' @export
31-
py_to_r_wrapper.keras.src.engine.base_layer.Layer <- py_to_r_wrapper.keras.engine.base_layer.Layer
32-
33-
#' @export
34-
py_to_r_wrapper.keras.src.engine.training.Model <- py_to_r_wrapper.keras.engine.training.Model
35-
36-
#' @export
37-
summary.keras.src.engine.training.Model <- summary.keras.engine.training.Model
38-
39-
as_generator.keras.src.utils.data_utils.Sequence <- as_generator.keras_preprocessing.sequence.TimeseriesGenerator

man/layer_batch_normalization.Rd

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-metrics.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,7 @@ test_metric <- function(metric, ...) {
143143
m <- metric(...)
144144

145145
expect_s3_class(m, c("keras.metrics.Metric",
146-
'keras.metrics.base_metric.Metric',
147-
'keras.src.metrics.base_metric.Metric'))
146+
'keras.metrics.base_metric.Metric'))
148147

149148
define_model() %>%
150149
compile(loss = loss,

0 commit comments

Comments
 (0)