Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce KVariable data class to encapsulate all the information about a single parameter of the layer #324

Merged
merged 2 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.util.getDType
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Variable

/**
 * Holds everything needed to work with a single parameter of a [Layer]: its graph [name],
 * its [shape], the TensorFlow [variable] itself, the operand that initializes it, and an
 * optional [regularizer].
 *
 * Instances are produced by [createVariable], which also registers the variable,
 * its initializer and its regularizer with the graph bookkeeping.
 *
 * @property [name] name under which the variable is registered in the TensorFlow graph
 * @property [shape] shape of the variable's tensor
 * @property [variable] the underlying TensorFlow [Variable] object (float dtype)
 * @property [initializerOperand] operand that, when run, assigns the initial value to [variable]
 * @property [regularizer] regularizer applied to this variable, or `null` if none
 */
public data class KVariable(
val name: String,
val shape: Shape,
val variable: Variable<Float>,
val initializerOperand: Operand<Float>,
val regularizer: Regularizer?
)

/**
 * Creates a TensorFlow variable named [variableName] with the given [shape], registers it —
 * together with its initializer operand and the optional [regularizer] — in the [kGraph],
 * and wraps everything into a [KVariable].
 *
 * @param [tf] TensorFlow ops builder used to declare the variable
 * @param [kGraph] graph that tracks layer variables, initializers and regularizers
 * @param [variableName] name under which the variable is added to the graph
 * @param [isTrainable] whether the variable is registered as trainable in [kGraph]
 * @param [shape] shape of the variable's tensor
 * @param [fanIn] number of input units, passed to the [initializer]
 * @param [fanOut] number of output units, passed to the [initializer]
 * @param [initializer] produces the operand that assigns the initial value
 * @param [regularizer] regularizer to register for the variable, or `null` for none
 */
internal fun createVariable(
    tf: Ops,
    kGraph: KGraph,
    variableName: String,
    isTrainable: Boolean,
    shape: Shape,
    fanIn: Int,
    fanOut: Int,
    initializer: Initializer,
    regularizer: Regularizer?
): KVariable {
    val variable = tf.withName(variableName).variable(shape, getDType())
    val initializerOperand = initializer.apply(fanIn, fanOut, tf, variable, variableName)

    // Register the variable and its companion operations with the graph bookkeeping.
    kGraph.addLayerVariable(variable, isTrainable)
    kGraph.addInitializer(variableName, initializerOperand)
    regularizer?.let { kGraph.addVariableRegularizer(variable, it) }

    return KVariable(
        name = variableName,
        shape = shape,
        variable = variable,
        initializerOperand = initializerOperand,
        regularizer = regularizer
    )
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
/*
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.TrainableModel
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.extension.convertTensorToMultiDimArray
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Variable

/**
* Base abstract class for all layers.
Expand All @@ -34,12 +31,6 @@ public abstract class Layer(public var name: String) {
/** Model where this layer is used. */
public var parentModel: TrainableModel? = null

/** Returns number of input parameters. */
protected var fanIn: Int = Int.MIN_VALUE

/** Returns number of output parameters. */
protected var fanOut: Int = Int.MIN_VALUE

/** Returns inbound layers. */
public var inboundLayers: MutableList<Layer> = mutableListOf()

Expand Down Expand Up @@ -108,47 +99,24 @@ public abstract class Layer(public var name: String) {
return forward(tf, input[0], isTraining, numberOfLosses)
}

/**
* Adds a new weight tensor to the layer
*
* @param name variable name
* @param variable variable to add
* @return the created variable.
*/
protected fun addWeight(
tf: Ops,
kGraph: KGraph,
name: String,
variable: Variable<Float>,
initializer: Initializer,
regularizer: Regularizer? = null
): Variable<Float> {
// require(fanIn != Int.MIN_VALUE) { "fanIn should be calculated before initialization for variable $name" }
// require(fanOut != Int.MIN_VALUE) { "fanOut should be calculated before initialization for variable $name" }

val initOp = initializer.apply(fanIn, fanOut, tf, variable, name)
kGraph.addLayerVariable(variable, isTrainable)
kGraph.addInitializer(name, initOp)
if (regularizer != null) kGraph.addVariableRegularizer(variable, regularizer)
return variable
}

/**
 * Functional-API entry point: records the given [layers] as this layer's [inboundLayers]
 * and returns this layer so calls can be chained, e.g. `Dense(10)(input)`.
 */
public operator fun invoke(vararg layers: Layer): Layer {
    inboundLayers = mutableListOf(*layers)
    return this
}

/** Extract weights values by variable names. */
protected fun extractWeights(variableNames: List<String>): Map<String, Array<*>> {
/**
 * Extracts the current values of the given [variables] from the parent model's session.
 * `null` entries are skipped, so optional (unused) parameters can be passed directly.
 *
 * @return map from variable name to its value as a multi-dimensional array
 * @throws IllegalArgumentException if this layer is not attached to a model yet
 */
protected fun extractWeights(vararg variables: KVariable?): Map<String, Array<*>> {
    val model = parentModel
    require(model != null) { "Layer $name is not related to any model!" }

    val names = variables.mapNotNull { it?.name }
    val runner = model.session.runner()
    names.forEach { runner.fetch(it) }

    val values = runner.run().map { it.convertTensorToMultiDimArray() }
    return names.zip(values).toMap()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

Expand All @@ -8,14 +8,14 @@ package org.jetbrains.kotlinx.dl.api.core.layer.activation
import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros
import org.jetbrains.kotlinx.dl.api.core.layer.KVariable
import org.jetbrains.kotlinx.dl.api.core.layer.createVariable
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.shape.numElements
import org.jetbrains.kotlinx.dl.api.core.shape.toLongArray
import org.jetbrains.kotlinx.dl.api.core.util.getDType
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Variable

/**
* Parametric Rectified Linear Unit.
Expand All @@ -40,15 +40,16 @@ public class PReLU(
/**
* TODO: support for constraint (alphaConstraint) should be added
*/
private lateinit var alphaShape: Shape
private lateinit var alpha: Variable<Float>
private val alphaVariableName = if (name.isNotEmpty()) name + "_" + "alpha" else "alpha"

private lateinit var alpha: KVariable
private fun alphaVariableName(): String =
if (name.isNotEmpty()) "${name}_alpha" else "alpha"

override var weights: Map<String, Array<*>>
get() = extractWeights(listOf(alphaVariableName))
get() = extractWeights(alpha)
set(value) = assignWeights(value)
override val paramCount: Int
get() = alphaShape.numElements().toInt()
get() = alpha.shape.numElements().toInt()

init {
isTrainable = true
Expand All @@ -61,19 +62,28 @@ public class PReLU(
alphaShapeArray[axis - 1] = 1
}
}
alphaShape = Shape.make(alphaShapeArray[0], *alphaShapeArray.drop(1).toLongArray())

fanIn = inputShape.size(inputShape.numDimensions() - 1).toInt()
fanOut = fanIn
val fanIn = inputShape.size(inputShape.numDimensions() - 1).toInt()
val fanOut = fanIn

alpha = tf.withName(alphaVariableName).variable(alphaShape, getDType())
alpha = addWeight(tf, kGraph, alphaVariableName, alpha, alphaInitializer, alphaRegularizer)
val alphaShape = Shape.make(alphaShapeArray[0], *alphaShapeArray.drop(1).toLongArray())
alpha = createVariable(
tf,
kGraph,
alphaVariableName(),
isTrainable,
alphaShape,
fanIn,
fanOut,
alphaInitializer,
alphaRegularizer
)
}

override fun forward(tf: Ops, input: Operand<Float>): Operand<Float> {
// It's equivalent to: `-alpha * relu(-x) + relu(x)`
val positive = tf.nn.relu(input)
val negative = tf.math.mul(tf.math.neg(alpha), tf.nn.relu(tf.math.neg(input)))
val negative = tf.math.mul(tf.math.neg(alpha.variable), tf.nn.relu(tf.math.neg(input)))
return tf.math.add(positive, negative)
}

Expand Down
Loading