Skip to content

tf.keras #38

Open
Open
@xuchengggg

Description

def compile(self, learning_rate, momentum):
    """Gets the model ready for training. Adds losses, regularization, and
    metrics. Then calls the Keras compile() function.

    learning_rate: Learning rate passed to the SGD optimizer.
    momentum: Momentum passed to the SGD optimizer.

    Operates on self.keras_model in place; nothing is returned.
    """
    model = self.keras_model

    def weighted_loss(name, layer_output):
        # Mean of a loss layer's output scaled by its configured weight;
        # names missing from LOSS_WEIGHTS get a weight of 1.
        return (
            tf.reduce_mean(layer_output, keepdims=True)
            * self.config.LOSS_WEIGHTS.get(name, 1.))

    # Optimizer object
    # NOTE(review): `lr` is the legacy keyword name; newer Keras releases
    # call it `learning_rate` -- confirm against the installed version.
    optimizer = keras.optimizers.SGD(
        lr=learning_rate, momentum=momentum,
        clipnorm=self.config.GRADIENT_CLIP_NORM)

    # Add Losses
    # First, clear previously set losses to avoid duplication.
    # NOTE(review): these are private Keras attributes and may not be
    # honoured by every Keras version -- verify when upgrading.
    model._losses = []
    model._per_input_losses = {}
    loss_names = [
        "rpn_class_loss", "rpn_bbox_loss",
        "mrcnn_class_loss", "mrcnn_bbox_loss", "mrcnn_mask_loss"]
    for name in loss_names:
        layer_output = model.get_layer(name).output
        # Compare by identity rather than with `in`: list membership falls
        # back to Tensor.__eq__, which TF 2.x overloads as an element-wise
        # op. Under legacy object equality the two checks are equivalent.
        if any(layer_output is existing for existing in model.losses):
            continue
        model.add_loss(weighted_loss(name, layer_output))

    # Add L2 Regularization
    # Skip gamma and beta weights of batch normalization layers.
    # Each penalty is divided by the weight's element count.
    reg_losses = [
        keras.regularizers.l2(self.config.WEIGHT_DECAY)(w)
        / tf.cast(tf.size(w), tf.float32)
        for w in model.trainable_weights
        if 'gamma' not in w.name and 'beta' not in w.name]
    model.add_loss(tf.add_n(reg_losses))

    # Compile
    # Every output gets a None loss: the losses were attached via add_loss().
    model.compile(
        optimizer=optimizer,
        loss=[None] * len(model.outputs))

    # Add metrics for losses
    # Keras versions that expose a `metrics_tensors` list keep the original
    # code path; where it is absent (e.g. tf.keras 2.x) the metric is
    # registered through add_metric() instead.
    has_metrics_tensors = hasattr(model, "metrics_tensors")
    for name in loss_names:
        if name in model.metrics_names:
            continue
        loss = weighted_loss(name, model.get_layer(name).output)
        if has_metrics_tensors:
            model.metrics_names.append(name)
            model.metrics_tensors.append(loss)
        else:
            model.add_metric(loss, name=name, aggregation="mean")

这是 Mask R-CNN 的代码，这部分要怎么修改才能在 TensorFlow 2.0 的 tf.keras 下运行呢？

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions