-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Keras-like API Advanced Activations, dropout and noise layers #2222
Changes from 1 commit
24bc07b
cb68227
8866a38
0c92054
ac28817
ecf21de
03e1ed5
866735d
c64eff5
50296a2
3054fec
c90ab99
d263320
9538c76
3ebaa40
8e83192
ed6f307
218cd41
07fce1d
daac88b
437f478
8d2ecb0
081649f
411465f
b133247
f9f3b81
082a310
6ce745c
8bfc875
4392d45
5a75157
32ff46e
9c596f2
f8beee3
6fabd80
1ada8b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright 2016 The BigDL Authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.intel.analytics.bigdl.nn.keras | ||
|
||
import com.intel.analytics.bigdl.nn.abstractnn._ | ||
import com.intel.analytics.bigdl.tensor.Tensor | ||
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric | ||
import com.intel.analytics.bigdl.utils.Shape | ||
|
||
import scala.reflect.ClassTag | ||
|
||
/**
 * Keras-style wrapper for the S-shaped Rectified Linear Unit layer.
 *
 * Delegates the actual computation to [[com.intel.analytics.bigdl.nn.SReLU]],
 * stripping the batch dimension from the input shape before construction.
 *
 * NOTE(review): `SharedAxes` breaks the lowerCamelCase parameter convention,
 * but renaming it would break callers using named arguments — confirm whether
 * a rename is acceptable before this API is released.
 *
 * @param SharedAxes axes along which parameters are shared, or null for none.
 * @param inputShape the expected input shape (without the batch dimension).
 */
class SReLU[T: ClassTag](SharedAxes: Array[Int] = null,
    var inputShape: Shape = null)(implicit ev: TensorNumeric[T])
  extends KerasLayer[Tensor[T], Tensor[T], T](KerasLayer.addBatch(inputShape)) {

  override def doBuild(inputShape: Shape): AbstractModule[Tensor[T], Tensor[T], T] = {
    val shape = inputShape.toSingle().toArray
    // Drop the leading batch dimension; the underlying SReLU expects the
    // per-sample shape only.
    val nonBatchShape = shape.slice(1, shape.length)
    // Single construction site (the previous if/else duplicated the cast);
    // only forward SharedAxes when the caller actually supplied it.
    val layer =
      if (SharedAxes == null) {
        com.intel.analytics.bigdl.nn.SReLU(nonBatchShape)
      } else {
        com.intel.analytics.bigdl.nn.SReLU(nonBatchShape, SharedAxes)
      }
    layer.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
  }
}
|
||
|
||
/**
 * Companion factory for [[SReLU]], mirroring the class's default arguments
 * so the layer can be created without `new`.
 */
object SReLU {
  def apply[@specialized(Float, Double) T: ClassTag](
      SharedAxes: Array[Int] = null,
      inputShape: Shape = null)(implicit ev: TensorNumeric[T]): SReLU[T] =
    new SReLU[T](SharedAxes, inputShape)
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* Copyright 2016 The BigDL Authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.intel.analytics.bigdl.keras.nn | ||
|
||
import com.intel.analytics.bigdl.keras.KerasBaseSpec | ||
import com.intel.analytics.bigdl.nn.keras.{Sequential => KSequential} | ||
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule | ||
import com.intel.analytics.bigdl.nn.keras.SReLU | ||
import com.intel.analytics.bigdl.tensor.Tensor | ||
import com.intel.analytics.bigdl.utils.Shape | ||
|
||
/**
 * Tests that the Keras-style SReLU wrapper matches Keras's SReLU output and
 * gradient for both the default configuration and shared axes.
 */
class SReLUSpec extends KerasBaseSpec {

  "SReLU" should "be the same as Keras" in {
    // Per review: commented-out lines and the a_left_init / t_right_init
    // arguments were removed from the Keras snippet.
    val kerasCode =
      """
        |input_tensor = Input(shape=[2, 3])
        |input = np.random.uniform(-1, 1, [1, 2, 3])
        |output_tensor = SReLU()(input_tensor)
        |model = Model(input=input_tensor, output=output_tensor)
      """.stripMargin
    val seq = KSequential[Float]()
    val srelu = SReLU[Float](null, inputShape = Shape(2, 3))
    seq.add(srelu)
    checkOutputAndGrad(seq.asInstanceOf[AbstractModule[Tensor[Float], Tensor[Float], Float]],
      kerasCode)
  }

  // Per review: both tests use 3D input, so this one is named for what it
  // actually exercises — the shared_axes parameter.
  "SReLU with shared axes" should "be the same as Keras" in {
    val kerasCode =
      """
        |input_tensor = Input(shape=[3, 24])
        |input = np.random.random([2, 3, 24])
        |output_tensor = SReLU(shared_axes=[1, 2])(input_tensor)
        |model = Model(input=input_tensor, output=output_tensor)
      """.stripMargin
    val seq = KSequential[Float]()
    val srelu = SReLU[Float](Array(1, 2), inputShape = Shape(3, 24))
    seq.add(srelu)
    checkOutputAndGrad(seq.asInstanceOf[AbstractModule[Tensor[Float], Tensor[Float], Float]],
      kerasCode)
  }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems these if/else branches could be consolidated, since both construct the layer and differ only in whether SharedAxes is passed?