Skip to content

Commit

Permalink
RegularizerAPI and UnitTest
Browse files Browse the repository at this point in the history
  • Loading branch information
SchoenTannenbaum committed May 20, 2024
1 parent f5ba382 commit 5f9fce5
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 7 deletions.
11 changes: 10 additions & 1 deletion src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,14 @@ public interface IRegularizer
[JsonProperty("config")]
IDictionary<string, object> Config { get; }
Tensor Apply(RegularizerArgs args);
}
}

/// <summary>
/// Provides access to the built-in Keras weight regularizers, both by
/// canonical class name and through default-configured singletons.
/// </summary>
public interface IRegularizerApi
{
    /// <summary>
    /// Looks up a regularizer by its canonical class name (e.g. "L1", "L2", "L1L2").
    /// </summary>
    /// <param name="name">Canonical regularizer class name; must not be null.</param>
    /// <returns>A default-configured <see cref="IRegularizer"/> for the given name.</returns>
    IRegularizer GetRegularizerFromName(string name);

    /// <summary>L1 regularizer with the default penalty coefficient.</summary>
    IRegularizer L1 { get; }

    /// <summary>L2 regularizer with the default penalty coefficient.</summary>
    IRegularizer L2 { get; }

    /// <summary>Combined L1+L2 regularizer with the default penalty coefficients.</summary>
    IRegularizer L1L2 { get; }
}

}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Operations/Regularizers/L1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class L1 : IRegularizer
float _l1;
private readonly Dictionary<string, object> _config;

public string ClassName => "L2";
public string ClassName => "L1";
public virtual IDictionary<string, object> Config => _config;

public L1(float l1 = 0.01f)
Expand Down
44 changes: 39 additions & 5 deletions src/TensorFlowNET.Keras/Regularizers.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,51 @@
namespace Tensorflow.Keras
using Tensorflow.Operations.Regularizers;

namespace Tensorflow.Keras
{
/// <summary>
/// Concrete implementation of <see cref="IRegularizerApi"/> exposing the Keras
/// regularizer factory methods and name-based lookup.
/// </summary>
public class Regularizers : IRegularizerApi
{
    // Maps the canonical class name ("L1", "L2", "L1L2") to a default-configured
    // instance, used by GetRegularizerFromName.
    private static Dictionary<string, IRegularizer> _nameActivationMap;

    /// <summary>Creates an L1 regularizer with the given penalty coefficient.</summary>
    public IRegularizer l1(float l1 = 0.01f)
        => new L1(l1);

    /// <summary>Creates an L2 regularizer with the given penalty coefficient.</summary>
    public IRegularizer l2(float l2 = 0.01f)
        => new L2(l2);

    //From TF source
    //# The default value for l1 and l2 are different from the value in l1_l2
    //# for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2
    //# and no l1 penalty.
    public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f)
        => new L1L2(l1, l2);

    static Regularizers()
    {
        // BUG FIX: the original registered all three instances under the single
        // key "L1", so lookups for "L2" and "L1L2" threw "not found" and "L1"
        // silently resolved to an L1L2 instance. Each name now maps to its own type.
        _nameActivationMap = new Dictionary<string, IRegularizer>
        {
            ["L1"] = new L1(),
            ["L2"] = new L2(),
            ["L1L2"] = new L1L2(),
        };
    }

    /// <summary>Default-configured L1 regularizer.</summary>
    public IRegularizer L1 => l1();

    /// <summary>Default-configured L2 regularizer.</summary>
    public IRegularizer L2 => l2();

    /// <summary>Default-configured L1L2 regularizer.</summary>
    public IRegularizer L1L2 => l1l2();

    /// <summary>
    /// Looks up a default-configured regularizer by its canonical class name.
    /// </summary>
    /// <param name="name">One of "L1", "L2", "L1L2".</param>
    /// <returns>The registered <see cref="IRegularizer"/> instance.</returns>
    /// <exception cref="Exception">Thrown when <paramref name="name"/> is null or not registered.</exception>
    public IRegularizer GetRegularizerFromName(string name)
    {
        if (name == null)
        {
            // Kept as Exception (not ArgumentNullException) to preserve the
            // exception contract existing callers may rely on.
            throw new Exception("Regularizer name cannot be null");
        }
        if (!_nameActivationMap.TryGetValue(name, out var res))
        {
            throw new Exception($"Regularizer {name} not found");
        }
        return res;
    }
}
}
48 changes: 48 additions & 0 deletions test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json.Linq;
using System.Collections.Generic;
using System.Linq;
using System.Xml.Linq;
using Tensorflow.Keras.Engine;
Expand Down Expand Up @@ -129,6 +130,53 @@ public void TestModelBeforeTF2_5()
}


[TestMethod]
public void BiasRegularizerSaveAndLoad()
{
    // Integration test: a model whose conv layers carry L1 / L2 / L1L2 bias
    // regularizers must survive a save/load round trip in "tf" format and
    // remain trainable afterwards.
    const int epochs = 1;
    const int batchSize = 8;
    var inputShape = new Shape(227, 227, 3);

    var model = keras.Sequential(new List<ILayer>()
    {
        tf.keras.layers.InputLayer((227, 227, 3)),
        tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),

        // One conv layer per regularizer flavor so each serializer path is exercised.
        tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1L2),
        tf.keras.layers.BatchNormalization(),

        tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L2),
        tf.keras.layers.BatchNormalization(),

        tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

        tf.keras.layers.Flatten(),

        tf.keras.layers.Dense(1000, activation: "linear"),
        tf.keras.layers.Softmax(1)
    });

    model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

    var trainData = new RandomDataSet(inputShape, 16);
    model.fit(trainData.Data, trainData.Labels, batchSize, epochs);

    model.save(@"./bias_regularizer_save_and_load", save_format: "tf");

    // Reload and verify the restored model can be compiled and fit again.
    var restored = tf.keras.models.load_model(@"./bias_regularizer_save_and_load");
    restored.summary();

    restored.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

    var refitData = new RandomDataSet(inputShape, 16);
    restored.fit(refitData.Data, refitData.Labels, batchSize, epochs);
}


[TestMethod]
public void CreateConcatenateModelSaveAndLoad()
Expand Down

0 comments on commit 5f9fce5

Please sign in to comment.