diff --git a/examples/ConsoleTest/Program.cs b/examples/ConsoleTest/Program.cs
index 1f148a48..8b5fe0d3 100644
--- a/examples/ConsoleTest/Program.cs
+++ b/examples/ConsoleTest/Program.cs
@@ -13,10 +13,7 @@ internal class Program
 {
     private static void Main(string[] args)
     {
-        var im_fname = Utils.Download("https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/dog.jpg", "dog.jpg");
-
-        var (x_img, img) = Yolo.LoadTest(im_fname, @short: 416);
-        Img.ImShow(img);
+        var arrays = NDArray.LoadNpz(@"C:\Users\deepa\Downloads\imdb.npz");
         Console.ReadLine();
     }
 }
diff --git a/src/MxNet.Keras/Backend/MxNetBackend.cs b/src/MxNet.Keras/Backend/MxNetBackend.cs
index 6c2e5b87..829fe90b 100644
--- a/src/MxNet.Keras/Backend/MxNetBackend.cs
+++ b/src/MxNet.Keras/Backend/MxNetBackend.cs
@@ -719,6 +719,12 @@ public static KerasSymbol GreaterEqual(KerasSymbol x, KerasSymbol y)
     return new KerasSymbol(sym.BroadcastGreaterEqual(x.Symbol, y.Symbol));
 }
 
+public static KerasSymbol GreaterEqual(KerasSymbol x, float y)
+{
+    var y_sym = sym.Full(y, x.Shape, dtype: x.DType);
+    return new KerasSymbol(sym.BroadcastGreaterEqual(x.Symbol, y_sym));
+}
+
 public static KerasSymbol Less(KerasSymbol x, KerasSymbol y)
 {
     return new KerasSymbol(sym.BroadcastLesser(x.Symbol, y.Symbol));
diff --git a/src/MxNet.Keras/Callbacks/BaseLogger.cs b/src/MxNet.Keras/Callbacks/BaseLogger.cs
index 7d0814a2..73202916 100644
--- a/src/MxNet.Keras/Callbacks/BaseLogger.cs
+++ b/src/MxNet.Keras/Callbacks/BaseLogger.cs
@@ -1,29 +1,76 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 
 namespace MxNet.Keras.Callbacks
 {
     public class BaseLogger : Callback
     {
+        public int seen;
+
+        public string[] stateful_metrics;
+
+        public Dictionary<string, float> totals;
+
         public BaseLogger(string[] stateful_metrics = null)
         {
-            throw new NotImplementedException();
+            this.stateful_metrics = stateful_metrics != null ? stateful_metrics : new string[0];
         }
 
         public override void OnEpochBegin(int epoch, Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            seen = 0;
+            totals = new Dictionary<string, float>();
         }
 
         public override void OnBatchEnd(int batch, Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            if (logs == null)
+                logs = new Dictionary<string, float>();
+
+            var batch_size = logs.ContainsKey("size") ? (int)logs["size"] : 0;
+            this.seen += batch_size;
+            foreach (var log in logs)
+            {
+                var k = log.Key;
+                var v = log.Value;
+                if (this.stateful_metrics.Contains(k))
+                {
+                    this.totals[k] = v;
+                }
+                else if (this.totals.ContainsKey(k))
+                {
+                    this.totals[k] += v * batch_size;
+                }
+                else
+                {
+                    this.totals[k] = v * batch_size;
+                }
+            }
         }
 
         public override void OnEpochEnd(int epoch, Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            if (logs != null)
+            {
+                string[] metrics = (string[])this.@params["metrics"];
+                foreach (var k in metrics)
+                {
+                    if (this.totals.ContainsKey(k))
+                    {
+                        // Make value available to next callbacks.
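+                        // Worked example (hypothetical numbers): two batches of
+                        // size 32 and 64 reporting loss 0.5 and 0.2 give
+                        // totals["loss"] = 0.5*32 + 0.2*64 = 28.8 and seen = 96,
+                        // so the epoch-level loss logged below is 28.8/96 = 0.3.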
+                        if (this.stateful_metrics.Contains(k))
+                        {
+                            logs[k] = this.totals[k];
+                        }
+                        else
+                        {
+                            logs[k] = this.totals[k] / this.seen;
+                        }
+                    }
+                }
+            }
         }
     }
 }
diff --git a/src/MxNet.Keras/Callbacks/CSVLogger.cs b/src/MxNet.Keras/Callbacks/CSVLogger.cs
index 0154fc73..5004be85 100644
--- a/src/MxNet.Keras/Callbacks/CSVLogger.cs
+++ b/src/MxNet.Keras/Callbacks/CSVLogger.cs
@@ -1,19 +1,66 @@
-using System;
+using CsvHelper;
+using System;
 using System.Collections.Generic;
+using System.IO;
 using System.Text;
 
 namespace MxNet.Keras.Callbacks
 {
     public class CSVLogger : Callback
     {
+        public Dictionary<string, string> _open_args;
+
+        public bool append;
+
+        public bool append_header;
+
+        public FileStream csv_file;
+
+        public string file_flags;
+
+        public string filename;
+
+        public object keys;
+
+        public string sep;
+
+        public CsvWriter writer;
+
         public CSVLogger(string filename, string separator = ",", bool append = false)
         {
-            throw new NotImplementedException();
+            this.sep = separator;
+            this.filename = filename;
+            this.append = append;
+            this.writer = null;
+            this.keys = null;
+            this.append_header = true;
+            this.file_flags = "";
+            this._open_args = new Dictionary<string, string> {
+                { "newline", "\n" }
+            };
         }
 
         public override void OnTrainBegin(Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            if (this.append && File.Exists(this.filename))
+            {
+                // Keras writes the header only when appending to an empty file.
+                append_header = File.ReadAllLines(filename).Length == 0;
+            }
+
+            // Honor the append flag; File.OpenWrite always writes from offset 0.
+            this.csv_file = new FileStream(filename, this.append ? FileMode.Append : FileMode.Create, FileAccess.Write);
         }
 
@@ -23,7 +70,8 @@ public override void OnEpochEnd(int epoch, Dictionary<string, float> logs = null
 
         public override void OnTrainEnd(Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            csv_file.Close();
+            writer = null;
         }
     }
 }
diff --git a/src/MxNet.Keras/Callbacks/Callback.cs b/src/MxNet.Keras/Callbacks/Callback.cs
index ff90db18..48c8af08 100644
--- a/src/MxNet.Keras/Callbacks/Callback.cs
+++ b/src/MxNet.Keras/Callbacks/Callback.cs
@@ -7,19 +7,26 @@ namespace MxNet.Keras.Callbacks
 {
     public abstract class Callback
     {
+        public Model model;
+
+        public Dictionary<string, object> @params;
+
+        public NDArrayList validation_data;
+
         public Callback()
         {
-            throw new NotImplementedException();
+            this.validation_data = null;
+            this.model = null;
         }
 
         public virtual void SetParams(Dictionary<string, object> @params)
         {
-            throw new NotImplementedException();
+            this.@params = @params;
         }
 
         public virtual void SetModel(Model model)
         {
-            throw new NotImplementedException();
+            this.model = model;
         }
 
         public virtual void OnEpochBegin(int epoch, Dictionary<string, float> logs = null)
diff --git a/src/MxNet.Keras/Callbacks/CallbackList.cs b/src/MxNet.Keras/Callbacks/CallbackList.cs
index 0e7a4a25..a9eb5b5d 100644
--- a/src/MxNet.Keras/Callbacks/CallbackList.cs
+++ b/src/MxNet.Keras/Callbacks/CallbackList.cs
@@ -1,55 +1,139 @@
 using MxNet.Keras.Engine;
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 
 namespace MxNet.Keras.Callbacks
 {
     public class CallbackList
     {
+        public double _delta_t_batch;
+
+        public List<double> _delta_ts_batch_begin;
+
+        public List<double> _delta_ts_batch_end;
+
+        public DateTime _t_enter_batch;
+
+        public List<Callback> callbacks;
+
+        public int queue_length;
+
        public CallbackList(Callback[] callbacks = null, int queue_length = 10)
        {
-            throw new NotImplementedException();
+            if (callbacks != null)
+                this.callbacks = callbacks.ToList();
+            else
+                this.callbacks = new List<Callback>();
+
+            // Initialize the timing buffers up front; OnBatchBegin/OnBatchEnd Add() into them.
+            this._delta_ts_batch_begin = new List<double>();
+            this._delta_ts_batch_end = new List<double>();
+
+            this.queue_length = queue_length;
        }
 
        public void Append(Callback callback)
        {
-            throw new 
NotImplementedException(); + this.callbacks.Add(callback); } public void SetParams(Dictionary @params) { - throw new NotImplementedException(); + foreach (var callback in this.callbacks) + { + callback.SetParams(@params); + } } public void SetModel(Model model) { - throw new NotImplementedException(); + foreach (var callback in this.callbacks) + { + callback.SetModel(model); + } } public void OnEpochBegin(int epoch, Dictionary logs = null) { - throw new NotImplementedException(); + if (logs == null) + logs = new Dictionary(); + + foreach (var callback in this.callbacks) + { + callback.OnEpochBegin(epoch, logs); + } + + _delta_t_batch = 0; + } + + public void OnEpochEnd(int epoch, Dictionary logs = null) + { + if (logs == null) + logs = new Dictionary(); + + foreach (var callback in this.callbacks) + { + callback.OnEpochEnd(epoch, logs); + } } public void OnBatchBegin(int batch, Dictionary logs = null) { - throw new NotImplementedException(); + if (logs == null) + logs = new Dictionary(); + + var t_before_callbacks = DateTime.Now; + foreach (var callback in this.callbacks) + { + callback.OnBatchBegin(batch, logs); + } + this._delta_ts_batch_begin.Add((DateTime.Now - t_before_callbacks).TotalMilliseconds); + var delta_t_median = _delta_ts_batch_begin.Average(); + if (this._delta_t_batch > 0.0 && delta_t_median > 0.95 * this._delta_t_batch && delta_t_median > 0.1) + { + Logger.Warning($"Method on_batch_begin() is slow compared to the batch update ({delta_t_median}). Check your callbacks."); + } + + this._t_enter_batch = DateTime.Now; } public void OnBatchEnd(int batch, Dictionary logs = null) { - throw new NotImplementedException(); + if (logs == null) + logs = new Dictionary(); + + this._delta_t_batch = (DateTime.Now - this._t_enter_batch).TotalMilliseconds; + var t_before_callbacks = DateTime.Now; + foreach (var callback in this.callbacks) + { + callback.OnBatchEnd(batch, logs); + } + this._delta_ts_batch_end.Add((DateTime.Now - t_before_callbacks).TotalMilliseconds); + var delta_t_median = _delta_ts_batch_end.Average(); + if (this._delta_t_batch > 0.0 && (delta_t_median > 0.95 * this._delta_t_batch && delta_t_median > 0.1)) + { + Logger.Warning($"In your callbacks, method `on_batch_end()` is slow compared to a model step ({delta_t_median} vs {_delta_t_batch}). Check your callbacks."); + } } public void OnTrainBegin(Dictionary logs = null) { - throw new NotImplementedException(); + if (logs == null) + logs = new Dictionary(); + + foreach (var callback in this.callbacks) + { + callback.OnTrainBegin(logs); + } } public void OnTrainEnd(Dictionary logs = null) { - throw new NotImplementedException(); + if (logs == null) + logs = new Dictionary(); + + foreach (var callback in this.callbacks) + { + callback.OnTrainEnd(logs); + } } } } diff --git a/src/MxNet.Keras/Callbacks/EarlyStopping.cs b/src/MxNet.Keras/Callbacks/EarlyStopping.cs index cd6dc2cd..3567b9b8 100644 --- a/src/MxNet.Keras/Callbacks/EarlyStopping.cs +++ b/src/MxNet.Keras/Callbacks/EarlyStopping.cs @@ -6,30 +6,152 @@ namespace MxNet.Keras.Callbacks { public class EarlyStopping : Callback { + public float? baseline; + + public float best; + + public NDArrayList best_weights; + + public float min_delta; + + public string monitor; + + public Func monitor_op; + + public int patience; + + public bool restore_best_weights; + + public int stopped_epoch; + + public int verbose; + + public int wait; + public EarlyStopping(string monitor= "val_loss", float min_delta= 0, int patience= 0, int verbose= 0, string mode= "auto", float? 
baseline= null, bool restore_best_weights= false) { - throw new NotImplementedException(); + this.monitor = monitor; + this.baseline = baseline; + this.patience = patience; + this.verbose = verbose; + this.min_delta = min_delta; + this.wait = 0; + this.stopped_epoch = 0; + this.restore_best_weights = restore_best_weights; + this.best_weights = null; + if (!new List { + "auto", + "min", + "max" + }.Contains(mode)) + { + Logger.Warning($"EarlyStopping mode {mode} is unknown, fallback to auto mode."); + mode = "auto"; + } + + if (mode == "min") + { + this.monitor_op = Lesser; + } + else if (mode == "max") + { + this.monitor_op = Greater; + } + else if (this.monitor.Contains("acc")) + { + this.monitor_op = Greater; + } + else + { + this.monitor_op = Lesser; + } + + if (this.monitor_op == Greater) + { + this.min_delta *= 1; + } + else + { + this.min_delta *= -1; + } } + private bool Lesser(float l, float r) => l < r; + + private bool Greater(float l, float r) => l > r; + public override void OnTrainBegin(Dictionary logs = null) { - throw new NotImplementedException(); + this.wait = 0; + this.stopped_epoch = 0; + if (this.baseline != null) + { + this.best = this.baseline.Value; + } + else + { + this.best = this.monitor_op == Lesser ? float.PositiveInfinity : float.NegativeInfinity; + } } public override void OnEpochEnd(int epoch, Dictionary logs = null) { - throw new NotImplementedException(); + var current = this.GetMonitorValue(logs); + if (current == null) + { + return; + } + + if (this.monitor_op(current.Value - this.min_delta, this.best)) + { + this.best = current.Value; + this.wait = 0; + if (this.restore_best_weights) + { + this.best_weights = this.model.GetWeights(); + } + } + else + { + this.wait += 1; + if (this.wait >= this.patience) + { + this.stopped_epoch = epoch; + this.model.stop_training = true; + if (this.restore_best_weights) + { + if (this.verbose > 0) + { + Console.WriteLine("Restoring model weights from the end of the best epoch"); + } + + this.model.SetWeights(this.best_weights); + } + } + } } public override void OnTrainEnd(Dictionary logs = null) { - throw new NotImplementedException(); + if (this.stopped_epoch > 0 && this.verbose > 0) + { + Console.WriteLine($"Epoch {stopped_epoch + 1}: early stopping"); + } } - public float GetMonitorValue(Dictionary logs) + public float? GetMonitorValue(Dictionary logs) { - throw new NotImplementedException(); + float? monitor_value = null; + if (logs.ContainsKey(monitor)) + monitor_value = logs[this.monitor]; + + if (monitor_value == null) + { + Logger.Warning($"Early stopping conditioned on metric `{monitor}` which is not available. Available metrics are: {string.Join(",", logs.Keys)}"); + } + + return monitor_value; } } } diff --git a/src/MxNet.Keras/Callbacks/History.cs b/src/MxNet.Keras/Callbacks/History.cs index dae724d8..1c91121c 100644 --- a/src/MxNet.Keras/Callbacks/History.cs +++ b/src/MxNet.Keras/Callbacks/History.cs @@ -6,14 +6,29 @@ namespace MxNet.Keras.Callbacks { public class History : Callback { + public List epoch; + + public Dictionary> history; + public override void OnTrainBegin(Dictionary logs = null) { - throw new NotImplementedException(); + epoch = new List(); + history = new Dictionary>(); } public override void OnEpochEnd(int epoch, Dictionary logs = null) { - throw new NotImplementedException(); + logs = logs != null ? 
logs : new Dictionary<string, float>();
+            this.epoch.Add(epoch);
+            foreach (var log in logs)
+            {
+                var k = log.Key;
+                var v = log.Value;
+                if (!history.ContainsKey(k))
+                    this.history.Add(k, new List<float>());
+
+                history[k].Add(v);
+            }
         }
     }
 }
diff --git a/src/MxNet.Keras/Callbacks/LearningRateScheduler.cs b/src/MxNet.Keras/Callbacks/LearningRateScheduler.cs
index 5738c36b..aeaf00d1 100644
--- a/src/MxNet.Keras/Callbacks/LearningRateScheduler.cs
+++ b/src/MxNet.Keras/Callbacks/LearningRateScheduler.cs
@@ -1,24 +1,47 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using K = MxNet.Keras.MxNetBackend;
 
 namespace MxNet.Keras.Callbacks
 {
     public class LearningRateScheduler : Callback
     {
+        public Func<int, float, float> schedule;
+
+        public int verbose;
+
         public LearningRateScheduler(Func<int, float, float> schedule, int verbose = 0)
         {
-            throw new NotImplementedException();
+            this.schedule = schedule;
+            this.verbose = verbose;
         }
 
         public override void OnEpochBegin(int epoch, Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            var lr = this.model.optimizer.LearningRate;
+            try
+            {
+                lr = this.schedule(epoch, lr);
+            }
+            catch (Exception)
+            {
+                // old API for backward compatibility
+                lr = this.schedule(epoch, 0);
+            }
+
+            this.model.optimizer.LearningRate = lr;
+
+            if (this.verbose > 0)
+            {
+                Console.WriteLine($"\nEpoch {epoch + 1}: LearningRateScheduler setting learning rate to {lr}.");
+            }
         }
 
         public override void OnEpochEnd(int epoch, Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            logs = logs != null ? logs : new Dictionary<string, float>();
+            logs["lr"] = this.model.optimizer.LearningRate;
         }
     }
 }
diff --git a/src/MxNet.Keras/Callbacks/ModelCheckpoint.cs b/src/MxNet.Keras/Callbacks/ModelCheckpoint.cs
index 84d5c9ee..b93d62dc 100644
--- a/src/MxNet.Keras/Callbacks/ModelCheckpoint.cs
+++ b/src/MxNet.Keras/Callbacks/ModelCheckpoint.cs
@@ -6,15 +6,125 @@ namespace MxNet.Keras.Callbacks
 {
     public class ModelCheckpoint : Callback
     {
+        public float best;
+
+        public int epochs_since_last_save;
+
+        public string filepath;
+
+        public string monitor;
+
+        public Func<float, float, bool> monitor_op;
+
+        public int period;
+
+        public bool save_best_only;
+
+        public bool save_weights_only;
+
+        public int verbose;
+
         public ModelCheckpoint(string filepath, string monitor = "val_loss", int verbose = 0, bool save_best_only = false, bool save_weights_only = false, string mode = "auto", int period = 1)
         {
-            throw new NotImplementedException();
+            this.monitor = monitor;
+            this.verbose = verbose;
+            this.filepath = filepath;
+            this.save_best_only = save_best_only;
+            this.save_weights_only = save_weights_only;
+            this.period = period;
+            this.epochs_since_last_save = 0;
+            if (!new List<string> { "auto", "min", "max" }.Contains(mode))
+            {
+                Logger.Warning($"ModelCheckpoint mode {mode} is unknown, fallback to auto mode.");
+                mode = "auto";
+            }
+
+            if (mode == "min")
+            {
+                this.monitor_op = Lesser;
+                this.best = float.PositiveInfinity;
+            }
+            else if (mode == "max")
+            {
+                this.monitor_op = Greater;
+                this.best = float.NegativeInfinity;
+            }
+            else if (this.monitor.Contains("acc") || this.monitor.StartsWith("fmeasure"))
+            {
+                this.monitor_op = Greater;
+                this.best = float.NegativeInfinity;
+            }
+            else
+            {
+                this.monitor_op = Lesser;
+                // Lower is better here, so start from +infinity; starting from
+                // -infinity would make every epoch look like a regression.
+                this.best = float.PositiveInfinity;
+            }
         }
 
+        private bool Lesser(float l, float r) => l < r;
+
+        private bool Greater(float l, float r) => l > r;
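+        // Usage sketch (assumes a compiled model; the "{0:D2}" epoch placeholder
+        // is a hypothetical pattern for the string.Format call in OnEpochEnd below):
+        //   var ckpt = new ModelCheckpoint("weights-{0:D2}.params",
+        //       monitor: "val_loss", save_best_only: true, mode: "min");
+        //   model.Fit(x, y, epochs: 10, callbacks: new Callback[] { ckpt });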
         public override void OnEpochEnd(int epoch, Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            logs = logs != null ? logs : new Dictionary<string, float>();
+            this.epochs_since_last_save += 1;
+            if (this.epochs_since_last_save >= this.period)
+            {
+                this.epochs_since_last_save = 0;
+                var filepath = string.Format(this.filepath, epoch);
+                if (this.save_best_only)
+                {
+                    float? current = null;
+                    if (logs.ContainsKey(monitor))
+                        current = logs[this.monitor];
+
+                    if (current == null)
+                    {
+                        Logger.Warning($"Unable to calculate the metric for determining the best model. Can save best model only with {monitor} available, skipping.");
+                    }
+                    else if (this.monitor_op(current.Value, this.best))
+                    {
+                        if (this.verbose > 0)
+                        {
+                            Console.WriteLine($"\nEpoch {epoch + 1}: {monitor} improved from {best} to {current}, saving model to {filepath}");
+                        }
+
+                        this.best = current.Value;
+                        if (this.save_weights_only)
+                        {
+                            this.model.SaveWeights(filepath, overwrite: true);
+                        }
+                        else
+                        {
+                            this.model.Save(filepath, overwrite: true);
+                        }
+                    }
+                    else if (this.verbose > 0)
+                    {
+                        Console.WriteLine($"\nEpoch {epoch + 1}: {monitor} did not improve from {best}");
+                    }
+                }
+                else
+                {
+                    if (this.verbose > 0)
+                    {
+                        Console.WriteLine($"\nEpoch {epoch + 1}: saving model to {filepath}");
+                    }
+                    if (this.save_weights_only)
+                    {
+                        this.model.SaveWeights(filepath, overwrite: true);
+                    }
+                    else
+                    {
+                        this.model.Save(filepath, overwrite: true);
+                    }
+                }
+            }
         }
     }
 }
diff --git a/src/MxNet.Keras/Callbacks/ProgbarLogger.cs b/src/MxNet.Keras/Callbacks/ProgbarLogger.cs
index 4fc38cdc..015f9aee 100644
--- a/src/MxNet.Keras/Callbacks/ProgbarLogger.cs
+++ b/src/MxNet.Keras/Callbacks/ProgbarLogger.cs
@@ -1,4 +1,5 @@
-using System;
+using MxNet.Keras.Utils;
+using System;
 using System.Collections.Generic;
 using System.Text;
 
@@ -6,34 +7,124 @@ namespace MxNet.Keras.Callbacks
 {
     public class ProgbarLogger : Callback
     {
+        public int epochs;
+
+        public Dictionary<string, float> log_values;
+
+        public Progbar progbar;
+
+        public int seen;
+
+        public string[] stateful_metrics;
+
+        public int target;
+
+        public bool use_steps;
+
+        public int verbose;
+
         public ProgbarLogger(string count_mode = "samples", string[] stateful_metrics = null)
         {
-            throw new NotImplementedException();
+            if (count_mode == "samples")
+            {
+                this.use_steps = false;
+            }
+            else if (count_mode == "steps")
+            {
+                this.use_steps = true;
+            }
+            else
+            {
+                throw new Exception("Unknown `count_mode`: " + count_mode);
+            }
+
+            if (stateful_metrics != null)
+            {
+                this.stateful_metrics = stateful_metrics;
+            }
+            else
+            {
+                this.stateful_metrics = new string[0];
+            }
         }
 
         public override void OnTrainBegin(Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            this.verbose = (int)this.@params["verbose"];
+            this.epochs = (int)this.@params["epochs"];
         }
 
         public override void OnEpochBegin(int epoch, Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            if (this.verbose > 0)
+            {
+                // string.Format has no printf-style %d placeholders; interpolate instead.
+                Console.WriteLine($"Epoch {epoch + 1}/{this.epochs}");
+                if (this.use_steps)
+                {
+                    target = (int)this.@params["steps"];
+                }
+                else
+                {
+                    target = (int)this.@params["samples"];
+                }
+                this.progbar = new Progbar(target: this.target, verbose: this.verbose, stateful_metrics: this.stateful_metrics);
+            }
+
+            this.seen = 0;
         }
 
         public override void OnBatchBegin(int batch, Dictionary<string, float> logs = null)
         {
-            throw new NotImplementedException();
+            if (this.seen < this.target)
+            {
+                this.log_values = new Dictionary<string, float>();
+            }
         }
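+        // The training loop is expected to populate @params before these
+        // callbacks fire; a hypothetical example of its contents:
+        //   { "verbose": 1, "epochs": 10, "samples": 60000, "steps": 0,
+        //     "metrics": new[] { "loss", "acc", "val_loss", "val_acc" } }
+        // The (int) and (string[]) casts above and below rely on that shape.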
(int)logs["size"] : 0; + if (this.use_steps) + { + this.seen += 1; + } + else + { + this.seen += batch_size; + } + + string[] metrics = (string[])this.@params["metrics"]; + foreach (var k in metrics) + { + if (logs.ContainsKey(k)) + { + this.log_values.Add(k, logs[k]); + } + } + + // Skip progbar update for the last batch; + // will be handled by on_epoch_end. + if (this.verbose > 0 && this.seen < this.target) { + this.progbar.Update(this.seen, this.log_values); +} +} public override void OnEpochEnd(int epoch, Dictionary logs = null) { - throw new NotImplementedException(); + logs = logs != null ? logs : new Dictionary(); + string[] metrics = (string[])this.@params["metrics"]; + foreach (var k in metrics) + { + if (logs.ContainsKey(k)) + { + this.log_values.Add(k, logs[k]); + } + } + if (this.verbose > 0) + { + this.progbar.Update(this.seen, this.log_values); + } } } } diff --git a/src/MxNet.Keras/Callbacks/ReduceLROnPlateau.cs b/src/MxNet.Keras/Callbacks/ReduceLROnPlateau.cs index befe41dc..a8edcf39 100644 --- a/src/MxNet.Keras/Callbacks/ReduceLROnPlateau.cs +++ b/src/MxNet.Keras/Callbacks/ReduceLROnPlateau.cs @@ -6,29 +6,134 @@ namespace MxNet.Keras.Callbacks { public class ReduceLROnPlateau : Callback { + public float best; + + public int cooldown; + + public int cooldown_counter; + + public float factor; + + public float min_delta; + + public float min_lr; + + public string mode; + + public string monitor; + + public Func monitor_op; + + public int patience; + + public int verbose; + + public int wait; + public ReduceLROnPlateau(string monitor= "val_loss", float factor= 0.1f, int patience= 10, int verbose= 0, string mode= "auto", float min_delta= 1e-4f, int cooldown= 0, float min_lr= 0) { - throw new NotImplementedException(); + this.monitor = monitor; + if (factor >= 1.0) + { + throw new Exception("ReduceLROnPlateau does not support a factor >= 1.0."); + } + + this.factor = factor; + this.min_lr = min_lr; + this.min_delta = min_delta; + this.patience = patience; + this.verbose = verbose; + this.cooldown = cooldown; + this.cooldown_counter = 0; + this.wait = 0; + this.best = 0; + this.mode = mode; + this.monitor_op = null; + this.Reset(); } internal void Reset() { - throw new NotImplementedException(); + if (!new List { + "auto", + "min", + "max" + }.Contains(this.mode)) + { + Logger.Warning($"Learning Rate Plateau Reducing mode {mode} is unknown, fallback to auto mode."); + this.mode = "auto"; + } + if (this.mode == "min" || this.mode == "auto" && !this.monitor.Contains("acc")) + { + this.monitor_op = (a, b) => (a < b - this.min_delta); + this.best = float.PositiveInfinity; + } + else + { + this.monitor_op = (a, b) => (a > b + this.min_delta); + this.best = float.NegativeInfinity; + } + + this.cooldown_counter = 0; + this.wait = 0; } public override void OnTrainBegin(Dictionary logs = null) { - throw new NotImplementedException(); + Reset(); } public override void OnEpochEnd(int epoch, Dictionary logs = null) { - throw new NotImplementedException(); + logs = logs != null ? logs : new Dictionary(); + logs["lr"] = this.model.optimizer.LearningRate; + float? current = null; + if (logs.ContainsKey(monitor)) + current = logs[monitor]; + + if (current == null) + { + Logger.Warning($"Reduce LR on plateau conditioned on metric `{monitor}` which is not available. 
Available metrics are: {string.Join(",", logs.Keys)}"); + } + else + { + if (this.InCoolDown()) + { + this.cooldown_counter -= 1; + this.wait = 0; + } + if (this.monitor_op(current.Value, this.best)) + { + this.best = current.Value; + this.wait = 0; + } + else if (!this.InCoolDown()) + { + this.wait += 1; + if (this.wait >= this.patience) + { + var old_lr = this.model.optimizer.LearningRate; + if (old_lr > this.min_lr) + { + var new_lr = old_lr * this.factor; + new_lr = Math.Max(new_lr, this.min_lr); + this.model.optimizer.LearningRate = new_lr; + if (this.verbose > 0) + { + Console.WriteLine($"\nEpoch {epoch + 1}: ReduceLROnPlateau reducing learning rate to {new_lr}."); + } + + this.cooldown_counter = this.cooldown; + this.wait = 0; + } + } + } + } } - public bool IsCoolDown() + public bool InCoolDown() { - throw new NotImplementedException(); + return this.cooldown_counter > 0; } } } diff --git a/src/MxNet.Keras/Callbacks/RemoteMonitor.cs b/src/MxNet.Keras/Callbacks/RemoteMonitor.cs index 95fe02c8..1d47852b 100644 --- a/src/MxNet.Keras/Callbacks/RemoteMonitor.cs +++ b/src/MxNet.Keras/Callbacks/RemoteMonitor.cs @@ -1,19 +1,68 @@ using System; using System.Collections.Generic; +using System.Dynamic; +using System.Net.Http; using System.Text; +using Flurl.Http; +using Flurl.Http.Content; namespace MxNet.Keras.Callbacks { public class RemoteMonitor : Callback { - public RemoteMonitor(string root= "http://localhost:9000", string path= "/publish/epoch/end/", string field= "data", Dictionary headers= null, bool send_as_json= false) + public string field; + + public Dictionary headers; + + public string path; + + public string root; + + public bool send_as_json; + + public RemoteMonitor(string root= "http://localhost:9000", string path= "/publish/epoch/end/", string field= "data", Dictionary headers= null, bool send_as_json= true) { - throw new NotImplementedException(); + this.root = root; + this.path = path; + this.field = field; + this.headers = headers == null ? new Dictionary() : headers; + this.send_as_json = send_as_json; } public override void OnEpochEnd(int epoch, Dictionary logs = null) { - throw new NotImplementedException(); + logs = logs != null ? 
logs : new Dictionary(); + var send = new Dictionary(); + + send["epoch"] = epoch; + foreach (var log in logs) + { + var k = log.Key; + var v = log.Value; + send[k] = v; + } + + try + { + var req = new FlurlRequest(new Flurl.Url(this.root + this.path)); + foreach (var item in headers) + { + req.Headers.Add(item.Key, item.Value); + } + + if (this.send_as_json) + { + req.PostJsonAsync(send).Wait(); + } + else + { + throw new NotSupportedException("Only json format supported"); + } + } + catch + { + Logger.Warning("Warning: could not reach RemoteMonitor root server at " + this.root.ToString()); + } } } } diff --git a/src/MxNet.Keras/Callbacks/TerminateOnNaN.cs b/src/MxNet.Keras/Callbacks/TerminateOnNaN.cs index 583116df..371245cc 100644 --- a/src/MxNet.Keras/Callbacks/TerminateOnNaN.cs +++ b/src/MxNet.Keras/Callbacks/TerminateOnNaN.cs @@ -8,7 +8,18 @@ public class TerminateOnNaN : Callback { public override void OnBatchEnd(int batch, Dictionary logs = null) { - throw new NotImplementedException(); + if (logs == null) + logs = new Dictionary(); + + if (logs.ContainsKey("loss")) + { + float loss = logs["loss"]; + if (float.IsNaN(loss) || float.IsInfinity(loss)) + { + Console.WriteLine($"Batch {batch}: Invalid loss, terminating training"); + this.model.stop_training = true; + } + } } } } diff --git a/src/MxNet.Keras/Constraints/Constraint.cs b/src/MxNet.Keras/Constraints/Constraint.cs index 62acf67c..ddcc6c09 100644 --- a/src/MxNet.Keras/Constraints/Constraint.cs +++ b/src/MxNet.Keras/Constraints/Constraint.cs @@ -8,6 +8,9 @@ public abstract class Constraint { public abstract KerasSymbol Call(KerasSymbol w); - public abstract ConfigDict GetConfig(); + public virtual ConfigDict GetConfig() + { + return new ConfigDict(); + } } } diff --git a/src/MxNet.Keras/Constraints/MaxNorm.cs b/src/MxNet.Keras/Constraints/MaxNorm.cs index dcf949f1..11ff00c2 100644 --- a/src/MxNet.Keras/Constraints/MaxNorm.cs +++ b/src/MxNet.Keras/Constraints/MaxNorm.cs @@ -1,19 +1,36 @@ using System; using System.Collections.Generic; using System.Text; +using K = MxNet.Keras.MxNetBackend; namespace MxNet.Keras.Constraints { public class MaxNorm : Constraint { + private int axis; + + private float max_value; + + public MaxNorm(int max_value = 2, int axis = 0) + { + this.max_value = max_value; + this.axis = axis; + } + public override KerasSymbol Call(KerasSymbol w) { - throw new NotImplementedException(); + var norms = K.Sqrt(K.Sum(K.Square(w), axis: this.axis, keepdims: true)); + var desired = K.Clip(norms, 0, this.max_value); + w *= desired / (K.Epsilon() + norms); + return w; } public override ConfigDict GetConfig() { - throw new NotImplementedException(); + ConfigDict dict = new ConfigDict(); + dict.Add("max_value", max_value); + dict.Add("axis", axis); + return dict; } } } diff --git a/src/MxNet.Keras/Constraints/MinMaxNorm.cs b/src/MxNet.Keras/Constraints/MinMaxNorm.cs index dae9742d..9164fe99 100644 --- a/src/MxNet.Keras/Constraints/MinMaxNorm.cs +++ b/src/MxNet.Keras/Constraints/MinMaxNorm.cs @@ -1,24 +1,53 @@ using System; using System.Collections.Generic; using System.Text; +using K = MxNet.Keras.MxNetBackend; namespace MxNet.Keras.Constraints { public class MinMaxNorm : Constraint { - public MinMaxNorm(float min_value= 0, float max_value = 1, float rate = 1, int axis = 0) + public int axis; + + public float max_value; + + public float min_value; + + public float rate; + + public MinMaxNorm(float min_value = 0, float max_value = 1, float rate = 1, int axis = 0) { - throw new NotImplementedException(); + this.min_value 
= min_value;
+            this.max_value = max_value;
+            this.rate = rate;
+            this.axis = axis;
         }
 
         public override KerasSymbol Call(KerasSymbol w)
         {
-            throw new NotImplementedException();
+            var norms = K.Sqrt(K.Sum(K.Square(w), axis: this.axis, keepdims: true));
+            var desired = this.rate * K.Clip(norms, this.min_value, this.max_value) + (1 - this.rate) * norms;
+            w *= desired / (K.Epsilon() + norms);
+            return w;
         }
 
         public override ConfigDict GetConfig()
         {
-            throw new NotImplementedException();
+            return new ConfigDict {
+                { "min_value", this.min_value },
+                { "max_value", this.max_value },
+                { "rate", this.rate },
+                { "axis", this.axis }
+            };
         }
     }
 }
diff --git a/src/MxNet.Keras/Constraints/NonNeg.cs b/src/MxNet.Keras/Constraints/NonNeg.cs
index 15e10dd6..18fc244a 100644
--- a/src/MxNet.Keras/Constraints/NonNeg.cs
+++ b/src/MxNet.Keras/Constraints/NonNeg.cs
@@ -1,6 +1,7 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using K = MxNet.Keras.MxNetBackend;
 
 namespace MxNet.Keras.Constraints
 {
@@ -8,12 +9,8 @@ public class NonNeg : Constraint
     {
         public override KerasSymbol Call(KerasSymbol w)
         {
-            throw new NotImplementedException();
-        }
-
-        public override ConfigDict GetConfig()
-        {
-            throw new NotImplementedException();
+            w *= K.Cast(K.GreaterEqual(w, 0), K.FloatX());
+            return w;
         }
     }
 }
diff --git a/src/MxNet.Keras/Constraints/UnitNorm.cs b/src/MxNet.Keras/Constraints/UnitNorm.cs
index aa0a07fb..1b57eda8 100644
--- a/src/MxNet.Keras/Constraints/UnitNorm.cs
+++ b/src/MxNet.Keras/Constraints/UnitNorm.cs
@@ -1,24 +1,29 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using K = MxNet.Keras.MxNetBackend;
 
 namespace MxNet.Keras.Constraints
 {
     public class UnitNorm : Constraint
     {
+        private int axis;
+
         public UnitNorm(int axis = 0)
         {
-            throw new NotImplementedException();
+            this.axis = axis;
         }
 
         public override KerasSymbol Call(KerasSymbol w)
         {
-            throw new NotImplementedException();
+            return w / (K.Epsilon() + K.Sqrt(K.Sum(K.Square(w), axis: this.axis, keepdims: true)));
         }
 
         public override ConfigDict GetConfig()
         {
-            throw new NotImplementedException();
+            ConfigDict dict = new ConfigDict();
+            dict.Add("axis", axis);
+            return dict;
         }
     }
 }
diff --git a/src/MxNet.Keras/Datasets/BostonHousing.cs b/src/MxNet.Keras/Datasets/BostonHousing.cs
index c374b25c..0b6f0fa9 100644
--- a/src/MxNet.Keras/Datasets/BostonHousing.cs
+++ b/src/MxNet.Keras/Datasets/BostonHousing.cs
@@ -1,5 +1,8 @@
-using System;
+using MxNet.Keras.Utils;
+using NumpyDotNet;
+using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Text;
 
 namespace MxNet.Keras.Datasets
@@ -8,7 +11,26 @@ public class BostonHousing
     {
         public static ((NDArray, NDArray), (NDArray, NDArray)) LoadData(string path = "boston_housing.npz", float test_split = 0.2f, int seed = 113)
         {
-            throw new NotImplementedException();
+            Debug.Assert(0 <= test_split && test_split < 1);
+            path = DataUtils.GetFile(path, origin: "https://s3.amazonaws.com/keras-datasets/boston_housing.npz", file_hash: "f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5");
+
+            var arrays = NDArray.LoadNpz(path);
+            var x = arrays[0];
+            var y = arrays[1];
+
+            mx.Seed(seed);
+            NDArray indices = np.arange(x.Shape[0]);
+            indices = nd.Shuffle(indices.AsType(DType.Int32));
+            x = x[indices];
+            y = y[indices];
+            int n = x.Shape[0];
+            // Keras keeps the first (1 - test_split) fraction for training, so the
+            // split index is n * (1 - test_split), not n * test_split.
+            int split_at = Convert.ToInt32((1 - test_split) * n);
+
+            var x_train = x[$":{split_at}"];
+            var y_train = y[$":{split_at}"];
+            var x_test = x[$"{split_at}:"];
+            var y_test = y[$"{split_at}:"];
+            return ((x_train, y_train), (x_test, y_test));
         }
     }
 }
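A quick sanity check of the corrected split, as a sketch (assuming the NDArray API used above; the Boston housing dataset has 506 rows):

    var ((x_train, y_train), (x_test, y_test)) = BostonHousing.LoadData(test_split: 0.2f);
    // Expect roughly 405 train / 101 test rows.
    Console.WriteLine($"{x_train.Shape[0]} train, {x_test.Shape[0]} test");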
diff --git a/src/MxNet.Keras/Datasets/IMDB.cs b/src/MxNet.Keras/Datasets/IMDB.cs
index 46de3bf1..f702fa22 100644
--- a/src/MxNet.Keras/Datasets/IMDB.cs
+++ b/src/MxNet.Keras/Datasets/IMDB.cs
@@ -1,4 +1,6 @@
-using System;
+using MxNet.Keras.Utils;
+using NumpyDotNet;
+using System;
 using System.Collections.Generic;
 using System.Text;
 
@@ -8,6 +10,86 @@ public class IMDB
     {
         public static ((NDArray, NDArray), (NDArray, NDArray)) LoadData(string path = "imdb.npz", int? num_words = null, int skip_top = 0, int? maxlen = null, int seed = 113, int start_char = 1, int oov_char = 2, int index_from = 3)
         {
+            path = DataUtils.GetFile(path, origin: "https://s3.amazonaws.com/text-datasets/imdb.npz", file_hash: "599dadb1135973df5b59232a0e9a887c");
+            var arrays = NDArray.LoadNpz(path);
+            // LoadNpz yields x_train, labels_train, x_test, labels_test in order;
+            // the labels for the test split are at index 3 (index 4 is out of range).
+            var x_train = arrays[0];
+            var labels_train = arrays[1];
+            var x_test = arrays[2];
+            var labels_test = arrays[3];
+
+            mx.Seed(seed);
+            NDArray indices = nd.Arange(0, x_train.Shape[0]);
+            indices = nd.Shuffle(indices.AsType(DType.Int32));
+
+            x_train = x_train[indices];
+            labels_train = labels_train[indices];
+
+            indices = nd.Arange(0, x_test.Shape[0]);
+            indices = nd.Shuffle(indices.AsType(DType.Int32));
+            x_test = x_test[indices];
+            labels_test = labels_test[indices];
+            ndarray xs = nd.Concat(new List {
+                x_train,
+                x_test
+            });
+
+            ndarray labels = nd.Concat(new List {
+                labels_train,
+                labels_test
+            });
+
+            //if (start_char != 0)
+            //{
+            //    xs = (from x in xs
+            //          select (new List {
+            //              start_char
+            //          } + (from w in x
+            //               select (w + index_from)).ToList())).ToList();
+            //}
+            //else if (index_from)
+            //{
+            //    xs = (from x in xs
+            //          select (from w in x
+            //                  select (w + index_from)).ToList()).ToList();
+            //}
+            //if (maxlen)
+            //{
+            //    var _tup_1 = _remove_long_seq(maxlen, xs, labels);
+            //    xs = _tup_1.Item1;
+            //    labels = _tup_1.Item2;
+            //    if (!xs)
+            //    {
+            //        throw new ValueError("After filtering for sequences shorter than maxlen=" + maxlen.ToString() + ", no sequence was kept. Increase maxlen.");
+            //    }
+            //}
+            //if (!num_words)
+            //{
+            //    num_words = max((from x in xs
+            //                     select max(x)).ToList());
+            //}
+            //// by convention, use 2 as OOV word
+            //// reserve 'index_from' (=3 by default) characters:
+            //// 0 (padding), 1 (start), 2 (OOV)
+            //if (oov_char != null)
+            //{
+            //    xs = (from x in xs
+            //          select (from w in x
+            //                  select skip_top <= w < num_words ? w : oov_char).ToList()).ToList();
+            //}
+            //else
+            //{
+            //    xs = (from x in xs
+            //          select (from w in x
+            //                  where skip_top <= w < num_words
+            //                  select w).ToList()).ToList();
+            //}
+            //var idx = x_train.Count;
+            //x_train = np.array(xs[::idx]);
+            //var y_train = np.array(labels[::idx]);
+            //x_test = np.array(xs[idx]);
+            //var y_test = np.array(labels[idx]);
+            //return Tuple.Create((x_train, y_train), (x_test, y_test));
+            throw new NotImplementedException();
         }
     }
diff --git a/src/MxNet.Keras/Engine/Model.cs b/src/MxNet.Keras/Engine/Model.cs
index 4993f307..e9ce5721 100644
--- a/src/MxNet.Keras/Engine/Model.cs
+++ b/src/MxNet.Keras/Engine/Model.cs
@@ -4,14 +4,20 @@
 using System.Collections.Generic;
 using System.Text;
 using MxNet.Keras.Utils;
+using MxNet.Keras.Optimizers;
+
 namespace MxNet.Keras.Engine
 {
-    public class Model
+    public class Model : Network
     {
         internal NDArrayDict _args;
 
         internal NDArrayDict _auxs;
 
+        internal bool stop_training;
+
+        internal Optimizer optimizer;
+
         public void Compile(Optimizer optimizer, string loss = null, string[] metrics = null, float[] loss_weights = null, string sample_weight_mode = null)
         {
             throw new NotImplementedException();
diff --git a/src/MxNet.Keras/Engine/Network.cs b/src/MxNet.Keras/Engine/Network.cs
index bd6fc69c..3c717b9c 100644
--- a/src/MxNet.Keras/Engine/Network.cs
+++ b/src/MxNet.Keras/Engine/Network.cs
@@ -103,12 +103,12 @@ public void ResetStates()
         throw new NotImplementedException();
     }
 
-    public Symbol[] GetWeights()
+    public NDArrayList GetWeights()
     {
         throw new NotImplementedException();
     }
 
-    public void GetWeights(Symbol[] weights)
+    public void SetWeights(NDArrayList weights)
     {
         throw new NotImplementedException();
     }
diff --git a/src/MxNet.Keras/History.cs b/src/MxNet.Keras/History.cs
deleted file mode 100644
index 68d08ea5..00000000
--- a/src/MxNet.Keras/History.cs
+++ /dev/null
@@ -1,10 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace MxNet.Keras
-{
-    public class History : Dictionary
-    {
-    }
-}
diff --git a/src/MxNet.Keras/Initializers/Constant.cs b/src/MxNet.Keras/Initializers/Constant.cs
index 5fd1790d..ce2c6b67 100644
--- a/src/MxNet.Keras/Initializers/Constant.cs
+++ b/src/MxNet.Keras/Initializers/Constant.cs
@@ -1,24 +1,32 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using K = MxNet.Keras.MxNetBackend;
 
 namespace MxNet.Keras.Initializers
 {
     public class Constant : Initializer
     {
+        private readonly float value;
+
         public Constant(float value = 0)
         {
-            throw new NotImplementedException();
+            this.value = value;
         }
 
-        public override KerasSymbol Call(Shape shap, DType dtype = null)
+        public override KerasSymbol Call(Shape shape, DType dtype = null)
         {
-            throw new NotImplementedException();
+            // Fill with the configured value, not a hard-coded zero.
+            return K.Constant(this.value, shape: shape, dtype: dtype);
         }
 
         public override ConfigDict GetConfig()
         {
-            throw new NotImplementedException();
+            return new ConfigDict {
+                { "value", this.value }
+            };
         }
     }
 }
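The new-file initializers that follow are thin wrappers over VarianceScaling, differing only in their (scale, mode, distribution) arguments. A hedged usage sketch (the Shape constructor and DType.Float32 forms are assumptions, not part of this diff):

    // Glorot/Xavier uniform draw for a 784x256 kernel:
    // uniform in [-limit, limit] with limit = sqrt(3 * scale / fan_avg).
    var init = new GlorotUniform(42);
    KerasSymbol kernel = init.Call(new Shape(784, 256), DType.Float32);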
diff --git a/src/MxNet.Keras/Initializers/GlorotNormal.cs b/src/MxNet.Keras/Initializers/GlorotNormal.cs
new file mode 100644
index 00000000..ef0ed070
--- /dev/null
+++ b/src/MxNet.Keras/Initializers/GlorotNormal.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace MxNet.Keras.Initializers
+{
+    public class GlorotNormal : VarianceScaling
+    {
+        public GlorotNormal(int? seed) : base(scale: 1, mode: "fan_avg", distribution: "normal", seed: seed)
+        {
+        }
+    }
+}
diff --git a/src/MxNet.Keras/Initializers/GlorotUniform.cs b/src/MxNet.Keras/Initializers/GlorotUniform.cs
new file mode 100644
index 00000000..cc17e9e7
--- /dev/null
+++ b/src/MxNet.Keras/Initializers/GlorotUniform.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace MxNet.Keras.Initializers
+{
+    public class GlorotUniform : VarianceScaling
+    {
+        public GlorotUniform(int? seed) : base(scale: 1, mode: "fan_avg", distribution: "uniform", seed: seed)
+        {
+        }
+    }
+}
diff --git a/src/MxNet.Keras/Initializers/HeNormal.cs b/src/MxNet.Keras/Initializers/HeNormal.cs
new file mode 100644
index 00000000..8bea03c9
--- /dev/null
+++ b/src/MxNet.Keras/Initializers/HeNormal.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace MxNet.Keras.Initializers
+{
+    public class HeNormal : VarianceScaling
+    {
+        public HeNormal(int? seed) : base(scale: 2, mode: "fan_in", distribution: "normal", seed: seed)
+        {
+        }
+    }
+}
diff --git a/src/MxNet.Keras/Initializers/HeUniform.cs b/src/MxNet.Keras/Initializers/HeUniform.cs
new file mode 100644
index 00000000..2eede1f9
--- /dev/null
+++ b/src/MxNet.Keras/Initializers/HeUniform.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace MxNet.Keras.Initializers
+{
+    public class HeUniform : VarianceScaling
+    {
+        public HeUniform(int? seed) : base(scale: 2, mode: "fan_in", distribution: "uniform", seed: seed)
+        {
+        }
+    }
+}
diff --git a/src/MxNet.Keras/Initializers/Identity.cs b/src/MxNet.Keras/Initializers/Identity.cs
index 794f1870..43a37dc4 100644
--- a/src/MxNet.Keras/Initializers/Identity.cs
+++ b/src/MxNet.Keras/Initializers/Identity.cs
@@ -1,24 +1,66 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
+using K = MxNet.Keras.MxNetBackend;
 
 namespace MxNet.Keras.Initializers
 {
     public class Identity : Initializer
     {
+        private readonly float gain;
+
         public Identity(float gain = 1)
         {
-            throw new NotImplementedException();
+            this.gain = gain;
         }
 
-        public override KerasSymbol Call(Shape shap, DType dtype = null)
+        public override KerasSymbol Call(Shape shape, DType dtype = null)
         {
-            throw new NotImplementedException();
+            if (shape.Dimension != 2)
+            {
+                throw new Exception("Identity matrix initializer can only be used for 2D matrices.");
+            }
+
+            if (shape.Data.Max() % shape.Data.Min() != 0)
+            {
+                throw new Exception("Long side should be multiple of short side.");
+            }
+
+            if (shape[0] == shape[1])
+            {
+                return this.gain * sym.Eye(shape[0]);
+            }
+            else if (shape[0] > shape[1])
+            {
+                // Tall matrix: stack (shape[1] x shape[1]) identities along axis 0.
+                List<Symbol> list = new List<Symbol>();
+                for (int i = 0; i < shape[0] / shape[1]; i++)
+                {
+                    list.Add(sym.Eye(shape[1]));
+                }
+
+                return this.gain * sym.Concat(list, dim: 0);
+            }
+            else
+            {
+                // Wide matrix: tile (shape[0] x shape[0]) identities along axis 1;
+                // tiling (shape[1] x shape[1]) blocks along axis 0 would yield the
+                // wrong output shape.
+                List<Symbol> list = new List<Symbol>();
+                for (int i = 0; i < shape[1] / shape[0]; i++)
+                {
+                    list.Add(sym.Eye(shape[0]));
+                }
+
+                return this.gain * sym.Concat(list, dim: 1);
+            }
         }
 
         public override ConfigDict GetConfig()
         {
-            throw new NotImplementedException();
+            return new ConfigDict {
+                { "gain", this.gain }
+            };
         }
     }
 }
diff --git a/src/MxNet.Keras/Initializers/Initializer.cs b/src/MxNet.Keras/Initializers/Initializer.cs
index af681872..2bde8c23 100644
--- a/src/MxNet.Keras/Initializers/Initializer.cs
+++ b/src/MxNet.Keras/Initializers/Initializer.cs
@@ -8,6 +8,9 @@ public abstract class Initializer
     {
         public abstract KerasSymbol Call(Shape shap, DType dtype = 
null); - public abstract ConfigDict GetConfig(); + public virtual ConfigDict GetConfig() + { + return new ConfigDict(); + } } } diff --git a/src/MxNet.Keras/Initializers/LecunNormal.cs b/src/MxNet.Keras/Initializers/LecunNormal.cs new file mode 100644 index 00000000..8ae7e3c3 --- /dev/null +++ b/src/MxNet.Keras/Initializers/LecunNormal.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace MxNet.Keras.Initializers +{ + public class LecunNormal : VarianceScaling + { + public LecunNormal(int? seed) : base(scale: 1, mode: "fan_in", distribution: "normal", seed: seed) + { + + } + } +} diff --git a/src/MxNet.Keras/Initializers/LecunUniform.cs b/src/MxNet.Keras/Initializers/LecunUniform.cs new file mode 100644 index 00000000..b47ed62d --- /dev/null +++ b/src/MxNet.Keras/Initializers/LecunUniform.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace MxNet.Keras.Initializers +{ + public class LecunUniform : VarianceScaling + { + public LecunUniform(int? seed) : base(scale: 1, mode: "fan_in", distribution: "uniform", seed: seed) + { + + } + } +} diff --git a/src/MxNet.Keras/Initializers/Ones.cs b/src/MxNet.Keras/Initializers/Ones.cs index c9bdf47b..d2a02b92 100644 --- a/src/MxNet.Keras/Initializers/Ones.cs +++ b/src/MxNet.Keras/Initializers/Ones.cs @@ -1,19 +1,15 @@ using System; using System.Collections.Generic; using System.Text; +using K = MxNet.Keras.MxNetBackend; namespace MxNet.Keras.Initializers { public class Ones : Initializer { - public override KerasSymbol Call(Shape shap, DType dtype = null) + public override KerasSymbol Call(Shape shape, DType dtype = null) { - throw new NotImplementedException(); - } - - public override ConfigDict GetConfig() - { - throw new NotImplementedException(); + return K.Constant(1, shape: shape, dtype: dtype); } } } diff --git a/src/MxNet.Keras/Initializers/Orthogonal.cs b/src/MxNet.Keras/Initializers/Orthogonal.cs deleted file mode 100644 index 3b9be165..00000000 --- a/src/MxNet.Keras/Initializers/Orthogonal.cs +++ /dev/null @@ -1,24 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace MxNet.Keras.Initializers -{ - public class Orthogonal : Initializer - { - public Orthogonal(float gain = 1, int? seed = null) - { - throw new NotImplementedException(); - } - - public override KerasSymbol Call(Shape shap, DType dtype = null) - { - throw new NotImplementedException(); - } - - public override ConfigDict GetConfig() - { - throw new NotImplementedException(); - } - } -} diff --git a/src/MxNet.Keras/Initializers/RandomNormal.cs b/src/MxNet.Keras/Initializers/RandomNormal.cs index 8c847ec4..4e174907 100644 --- a/src/MxNet.Keras/Initializers/RandomNormal.cs +++ b/src/MxNet.Keras/Initializers/RandomNormal.cs @@ -1,24 +1,44 @@ using System; using System.Collections.Generic; using System.Text; +using K = MxNet.Keras.MxNetBackend; namespace MxNet.Keras.Initializers { public class RandomNormal : Initializer { + private float mean; + + private int? seed; + + private float stddev; + public RandomNormal(float mean = 0, float stddev = 0.05f, int? 
seed = null) { - throw new NotImplementedException(); + this.mean = mean; + this.stddev = stddev; + this.seed = seed; } - public override KerasSymbol Call(Shape shap, DType dtype = null) + public override KerasSymbol Call(Shape shape, DType dtype = null) { - throw new NotImplementedException(); + return K.RandomNormal(shape, this.mean, this.stddev, dtype: dtype, seed: this.seed); } public override ConfigDict GetConfig() { - throw new NotImplementedException(); + return new ConfigDict { + { + "mean", + this.mean}, + { + "stddev", + this.stddev}, + { + "seed", + this.seed + } + }; } } } diff --git a/src/MxNet.Keras/Initializers/RandomUniform.cs b/src/MxNet.Keras/Initializers/RandomUniform.cs index c30a1c84..f10c7ce2 100644 --- a/src/MxNet.Keras/Initializers/RandomUniform.cs +++ b/src/MxNet.Keras/Initializers/RandomUniform.cs @@ -1,24 +1,44 @@ using System; using System.Collections.Generic; using System.Text; +using K = MxNet.Keras.MxNetBackend; namespace MxNet.Keras.Initializers { public class RandomUniform : Initializer { + private float maxval; + + private float minval; + + private int? seed; + public RandomUniform(float minval= -0.05f, float maxval= 0.05f, int? seed= null) { - throw new NotImplementedException(); + this.minval = minval; + this.maxval = maxval; + this.seed = seed; } - public override KerasSymbol Call(Shape shap, DType dtype = null) + public override KerasSymbol Call(Shape shape, DType dtype = null) { - throw new NotImplementedException(); + return K.RandomUniform(shape, this.minval, this.maxval, dtype: dtype, seed: this.seed); } public override ConfigDict GetConfig() { - throw new NotImplementedException(); + return new ConfigDict { + { + "minval", + this.minval}, + { + "maxval", + this.maxval}, + { + "seed", + this.seed + } + }; } } } diff --git a/src/MxNet.Keras/Initializers/TruncatedNormal.cs b/src/MxNet.Keras/Initializers/TruncatedNormal.cs index f733c1d4..60e74d7b 100644 --- a/src/MxNet.Keras/Initializers/TruncatedNormal.cs +++ b/src/MxNet.Keras/Initializers/TruncatedNormal.cs @@ -1,24 +1,44 @@ using System; using System.Collections.Generic; using System.Text; +using K = MxNet.Keras.MxNetBackend; namespace MxNet.Keras.Initializers { public class TruncatedNormal : Initializer { + private float mean; + + private int? seed; + + private float stddev; + public TruncatedNormal(float mean = 0, float stddev = 0.05f, int? seed = null) { - throw new NotImplementedException(); + this.mean = mean; + this.stddev = stddev; + this.seed = seed; } - public override KerasSymbol Call(Shape shap, DType dtype = null) + public override KerasSymbol Call(Shape shape, DType dtype = null) { - throw new NotImplementedException(); + return K.TruncatedNormal(shape, this.mean, this.stddev, dtype: dtype, seed: this.seed); } public override ConfigDict GetConfig() { - throw new NotImplementedException(); + return new ConfigDict { + { + "mean", + this.mean}, + { + "stddev", + this.stddev}, + { + "seed", + this.seed + } + }; } } } diff --git a/src/MxNet.Keras/Initializers/VarianceScaling.cs b/src/MxNet.Keras/Initializers/VarianceScaling.cs index bd708e95..769f5cae 100644 --- a/src/MxNet.Keras/Initializers/VarianceScaling.cs +++ b/src/MxNet.Keras/Initializers/VarianceScaling.cs @@ -1,24 +1,132 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using K = MxNet.Keras.MxNetBackend; namespace MxNet.Keras.Initializers { public class VarianceScaling : Initializer { + private string distribution; + + private string mode; + + private float scale; + + private int? 
seed; + public VarianceScaling(float scale= 1, string mode= "fan_in", string distribution= "normal", int? seed= null) { - throw new NotImplementedException(); + if (scale <= 0) + { + throw new Exception("`scale` must be a positive float. Got:" + scale); + } + + mode = mode.ToLower(); + if (!(new string[]{ "fan_in", "fan_out", "fan_avg"}).Contains(mode)) { + throw new Exception("Invalid `mode` argument: expected on of {\"fan_in\", \"fan_out\", \"fan_avg\"} but got " + mode); + } + + distribution = distribution.ToLower(); + if (!(new string[]{ "normal", "uniform"}).Contains(distribution)) { + throw new Exception("Invalid `distribution` argument: expected one of {\"normal\", \"uniform\"} but got " + distribution); + } + + this.scale = scale; + this.mode = mode; + this.distribution = distribution; + this.seed = seed; + } + + public static (int, int) _compute_fans(Shape shape, string data_format = "channels_last") + { + int receptive_field_size; + int fan_out; + int fan_in; + if (shape.Dimension == 2) + { + fan_in = shape[0]; + fan_out = shape[1]; + } + + else if ((new int[]{3, 4, 5}).Contains(shape.Dimension)) + { + // Assuming convolution kernels (1D, 2D or 3D). + // TH kernel shape: (depth, input_depth, ...) + // TF kernel shape: (..., input_depth, depth) + if (data_format == "channels_first") + { + receptive_field_size = shape.Data.Skip(2).ToList().Aggregate((a, b) => a * b); + fan_in = shape[1] * receptive_field_size; + fan_out = shape[0] * receptive_field_size; + } + else if (data_format == "channels_last") + { + receptive_field_size = shape.Data.Take(shape.Data.Length - 2).ToList().Aggregate((a, b) => a * b); + fan_in = shape[shape.Data.Length - 2] * receptive_field_size; + fan_out = shape[shape.Data.Length - 1] * receptive_field_size; + } + else + { + throw new Exception("Invalid data_format: " + data_format); + } + } + else + { + // No specific assumptions. + fan_in = Convert.ToInt32(Math.Sqrt(shape.Size)); + fan_out = Convert.ToInt32(Math.Sqrt(shape.Size)); + } + return (fan_in, fan_out); } - public override KerasSymbol Call(Shape shap, DType dtype = null) + public override KerasSymbol Call(Shape shape, DType dtype = null) { - throw new NotImplementedException(); + var (fan_in, fan_out) = _compute_fans(shape); + var scale = this.scale; + if (this.mode == "fan_in") + { + scale /= Math.Max(1, fan_in); + } + else if (this.mode == "fan_out") + { + scale /= Math.Max(1, fan_out); + } + else + { + scale /= Math.Max(1, (fan_in + fan_out) / 2); + } + if (this.distribution == "normal") + { + // 0.879... = scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) 
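+                // Worked example: GlorotNormal on a (784, 256) kernel has
+                // fan_avg = (784 + 256) / 2 = 520 and scale = 1/520, so
+                // stddev = sqrt(1/520) / 0.8796 ≈ 0.0499, truncated at two stddev.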
+ float stddev = (float)Math.Sqrt(scale) / 0.8796256610342398f; + return K.TruncatedNormal(shape, 0, stddev, dtype: dtype, seed: this.seed); + } + else + { + float limit = (float)Math.Sqrt(3.0 * scale); + return K.RandomUniform(shape, -limit, limit, dtype: dtype, seed: this.seed); + } } public override ConfigDict GetConfig() { - throw new NotImplementedException(); + return new ConfigDict{ + { + "scale", + this.scale}, + { + "mode", + this.mode}, + { + "distribution", + this.distribution}, + { + "seed", + this.seed + } + }; } } } diff --git a/src/MxNet.Keras/Initializers/Zeros.cs b/src/MxNet.Keras/Initializers/Zeros.cs index 0daaa017..17655a1b 100644 --- a/src/MxNet.Keras/Initializers/Zeros.cs +++ b/src/MxNet.Keras/Initializers/Zeros.cs @@ -1,19 +1,15 @@ using System; using System.Collections.Generic; using System.Text; +using K = MxNet.Keras.MxNetBackend; namespace MxNet.Keras.Initializers { public class Zeros : Initializer { - public override KerasSymbol Call(Shape shap, DType dtype = null) + public override KerasSymbol Call(Shape shape, DType dtype = null) { - throw new NotImplementedException(); - } - - public override ConfigDict GetConfig() - { - throw new NotImplementedException(); + return K.Constant(0, shape: shape, dtype: dtype); } } } diff --git a/src/MxNet.Keras/MxNet.Keras.csproj b/src/MxNet.Keras/MxNet.Keras.csproj index 783308c6..21fe0658 100644 --- a/src/MxNet.Keras/MxNet.Keras.csproj +++ b/src/MxNet.Keras/MxNet.Keras.csproj @@ -5,6 +5,7 @@ + diff --git a/src/MxNet.Keras/Optimizers/Adadelta.cs b/src/MxNet.Keras/Optimizers/Adadelta.cs index a761552b..d8b443ea 100644 --- a/src/MxNet.Keras/Optimizers/Adadelta.cs +++ b/src/MxNet.Keras/Optimizers/Adadelta.cs @@ -6,17 +6,35 @@ namespace MxNet.Keras.Optimizers { public class Adadelta : MxNet.Optimizers.AdaDelta, IOptimizer { - public Adadelta(float lr = 1, float rho = 0.95F, float epsilon = 1E-08F, float decay = 0, float? clipnorm = null) : base(rho, decay, epsilon) + public Adadelta(float lr = 1, float rho = 0.95F, float epsilon = 1E-08F, float decay = 0, float? clipnorm = null) : base(lr, rho, decay, epsilon) { - throw new NotImplementedException(); + Lr = lr; + Decay = decay; + ClipGradient = clipnorm; } - public KerasSymbol Lr { get; set; } - public KerasSymbol Decay { get; set; } + public float Lr { get; set; } + public float Decay { get; set; } public ConfigDict GetConfig() { - throw new NotImplementedException(); + var config = new ConfigDict { + { + "lr", + this.LearningRate}, + { + "rho", + this.Rho}, + { + "decay", + this.Decay}, + { + "epsilon", + this.Epsilon + } + }; + + return config; } } } diff --git a/src/MxNet.Keras/Optimizers/Adagrad.cs b/src/MxNet.Keras/Optimizers/Adagrad.cs index 4095ca70..ba89bf56 100644 --- a/src/MxNet.Keras/Optimizers/Adagrad.cs +++ b/src/MxNet.Keras/Optimizers/Adagrad.cs @@ -6,17 +6,33 @@ namespace MxNet.Keras.Optimizers { public class Adagrad : MxNet.Optimizers.AdaGrad, IOptimizer { - public Adagrad(float lr = 0.01f, float eps = 1e-08F, float? clipnorm = null) : base(lr, epsilon: eps) + public Adagrad(float lr = 0.01f, float eps = 1e-08F, float decay = 0, float? 
clipnorm = null) : base(lr, epsilon: eps) { - throw new NotImplementedException(); + Lr = lr; + Decay = decay; + ClipGradient = clipnorm; } - public KerasSymbol Lr { get; set; } - public KerasSymbol Decay { get; set; } + public float Lr { get; set; } + public float Decay { get; set; } public ConfigDict GetConfig() { - throw new NotImplementedException(); + var config = new ConfigDict { + { + "lr", + this.Lr + }, + { + "decay", + this.Decay}, + { + "epsilon", + this.Epsilon + } + }; + + return config; } } } diff --git a/src/MxNet.Keras/Optimizers/Adam.cs b/src/MxNet.Keras/Optimizers/Adam.cs index b74d13b6..20df147e 100644 --- a/src/MxNet.Keras/Optimizers/Adam.cs +++ b/src/MxNet.Keras/Optimizers/Adam.cs @@ -8,15 +8,34 @@ public class Adam : MxNet.Optimizers.Adam, IOptimizer { public Adam(float lr = 0.001F, float beta1 = 0.9F, float beta2 = 0.999F, float epsilon = 1E-08F, float decay = 0, float? clipnorm = null) : base(lr, beta1, beta2, epsilon) { - throw new NotImplementedException(); + Lr = lr; + Decay = decay; + ClipGradient = clipnorm; } - public KerasSymbol Lr { get; set; } - public KerasSymbol Decay { get; set; } + public float Lr { get; set; } + public float Decay { get; set; } public ConfigDict GetConfig() { - throw new NotImplementedException(); + return new ConfigDict { + { + "lr", + this.Lr}, + { + "beta_1", + this.Beta1}, + { + "beta_2", + this.Beta2}, + { + "decay", + this.Decay}, + { + "epsilon", + this.Epsilon + } + }; } } } diff --git a/src/MxNet.Keras/Optimizers/Adamax.cs b/src/MxNet.Keras/Optimizers/Adamax.cs index e34c0036..5dd31f11 100644 --- a/src/MxNet.Keras/Optimizers/Adamax.cs +++ b/src/MxNet.Keras/Optimizers/Adamax.cs @@ -8,15 +8,31 @@ public class Adamax : MxNet.Optimizers.Adamax, IOptimizer { public Adamax(float lr = 0, float beta1 = 0.9F, float beta2 = 0.999F, float decay = 0, float? clipnorm = null) : base(lr, beta1, beta2) { - throw new NotImplementedException(); + Lr = lr; + Decay = decay; + ClipGradient = clipnorm; } - public KerasSymbol Lr { get; set; } - public KerasSymbol Decay { get; set; } + public float Lr { get; set; } + public float Decay { get; set; } public ConfigDict GetConfig() { - throw new NotImplementedException(); + return new ConfigDict{ + { + "lr", + this.Lr}, + { + "beta_1", + this.Beta1}, + { + "beta_2", + this.Beta2}, + { + "decay", + this.Decay + } + }; } } } diff --git a/src/MxNet.Keras/Optimizers/Nadam.cs b/src/MxNet.Keras/Optimizers/Nadam.cs index 74a4510f..1e890bf1 100644 --- a/src/MxNet.Keras/Optimizers/Nadam.cs +++ b/src/MxNet.Keras/Optimizers/Nadam.cs @@ -8,15 +8,34 @@ public class Nadam : MxNet.Optimizers.Nadam, IOptimizer { public Nadam(float lr = 0.001F, float beta1 = 0.9F, float beta2 = 0.999F, float epsilon = 1E-08F, float decay = 0, float? 
         public ConfigDict GetConfig()
         {
-            throw new NotImplementedException();
+            return new ConfigDict {
+                { "lr", this.Lr },
+                { "beta_1", this.Beta1 },
+                { "beta_2", this.Beta2 },
+                { "decay", this.Decay },
+                { "epsilon", this.Epsilon }
+            };
         }
     }
 }
diff --git a/src/MxNet.Keras/Optimizers/Adamax.cs b/src/MxNet.Keras/Optimizers/Adamax.cs
index e34c0036..5dd31f11 100644
--- a/src/MxNet.Keras/Optimizers/Adamax.cs
+++ b/src/MxNet.Keras/Optimizers/Adamax.cs
@@ -8,15 +8,31 @@
     public class Adamax : MxNet.Optimizers.Adamax, IOptimizer
     {
         public Adamax(float lr = 0, float beta1 = 0.9F, float beta2 = 0.999F, float decay = 0, float? clipnorm = null) : base(lr, beta1, beta2)
         {
-            throw new NotImplementedException();
+            Lr = lr;
+            Decay = decay;
+            ClipGradient = clipnorm;
         }
 
-        public KerasSymbol Lr { get; set; }
-        public KerasSymbol Decay { get; set; }
+        public float Lr { get; set; }
+        public float Decay { get; set; }
 
         public ConfigDict GetConfig()
         {
-            throw new NotImplementedException();
+            return new ConfigDict {
+                { "lr", this.Lr },
+                { "beta_1", this.Beta1 },
+                { "beta_2", this.Beta2 },
+                { "decay", this.Decay }
+            };
         }
     }
 }
diff --git a/src/MxNet.Keras/Optimizers/Nadam.cs b/src/MxNet.Keras/Optimizers/Nadam.cs
index 74a4510f..1e890bf1 100644
--- a/src/MxNet.Keras/Optimizers/Nadam.cs
+++ b/src/MxNet.Keras/Optimizers/Nadam.cs
@@ -8,15 +8,34 @@
     public class Nadam : MxNet.Optimizers.Nadam, IOptimizer
     {
         public Nadam(float lr = 0.001F, float beta1 = 0.9F, float beta2 = 0.999F, float epsilon = 1E-08F, float decay = 0, float? clipnorm = null, float schedule_decay = 0.004F) : base(lr, beta1, beta2, epsilon, schedule_decay)
         {
-            throw new NotImplementedException();
+            Lr = lr;
+            Decay = decay;
+            ClipGradient = clipnorm;
         }
 
-        public KerasSymbol Lr { get; set; }
-        public KerasSymbol Decay { get; set; }
+        public float Lr { get; set; }
+        public float Decay { get; set; }
 
         public ConfigDict GetConfig()
         {
-            throw new NotImplementedException();
+            return new ConfigDict {
+                { "lr", this.Lr },
+                { "beta_1", this.Beta1 },
+                { "beta_2", this.Beta2 },
+                { "epsilon", this.Epsilon },
+                { "schedule_decay", this.ScheduleDecay }
+            };
         }
     }
 }
diff --git a/src/MxNet.Keras/Optimizers/Optimizer.cs b/src/MxNet.Keras/Optimizers/Optimizer.cs
index c60edd3a..62c91668 100644
--- a/src/MxNet.Keras/Optimizers/Optimizer.cs
+++ b/src/MxNet.Keras/Optimizers/Optimizer.cs
@@ -6,9 +6,9 @@ namespace MxNet.Keras.Optimizers
 {
     public interface IOptimizer
     {
-        KerasSymbol Lr { get; set; }
+        float Lr { get; set; }
 
-        KerasSymbol Decay { get; set; }
+        float Decay { get; set; }
 
         ConfigDict GetConfig();
     }
diff --git a/src/MxNet.Keras/Optimizers/RMSprop.cs b/src/MxNet.Keras/Optimizers/RMSprop.cs
index 3a358aef..291a625f 100644
--- a/src/MxNet.Keras/Optimizers/RMSprop.cs
+++ b/src/MxNet.Keras/Optimizers/RMSprop.cs
@@ -6,12 +6,38 @@ namespace MxNet.Keras.Optimizers
 {
     public class RMSprop : MxNet.Optimizers.RMSProp, IOptimizer
     {
-        public KerasSymbol Lr { get; set; }
-        public KerasSymbol Decay { get; set; }
+        public RMSprop(float lr = 0.001f, float rho = 0.9f, float epsilon = 1E-08f, float decay = 0, float? clipnorm = null)
+            : base(learning_rate: lr, gamma1: rho, epsilon: epsilon, clip_weights: clipnorm)
+        {
+            Lr = lr;
+            Decay = decay;
+            ClipGradient = clipnorm;
+        }
+
+        public float Lr { get; set; }
+        public float Decay { get; set; }
 
         public ConfigDict GetConfig()
         {
-            throw new NotImplementedException();
+            return new ConfigDict {
+                { "lr", this.Lr },
+                { "rho", this.Gamma1 },
+                { "decay", this.Decay },
+                { "epsilon", this.Epsilon }
+            };
         }
     }
 }
diff --git a/src/MxNet.Keras/Optimizers/SGD.cs b/src/MxNet.Keras/Optimizers/SGD.cs
index a1e322ef..07348b9e 100644
--- a/src/MxNet.Keras/Optimizers/SGD.cs
+++ b/src/MxNet.Keras/Optimizers/SGD.cs
@@ -6,12 +6,16 @@ namespace MxNet.Keras.Optimizers
 {
     public class SGD : MxNet.Optimizers.SGD, IOptimizer
     {
-        public KerasSymbol Lr { get; set; }
-        public KerasSymbol Decay { get; set; }
+        public float Lr { get; set; }
+        public float Decay { get; set; }
+
+        private int aggregate_num = 1;
 
         public SGD(float lr = 0.01f, float momentum = 0, float decay = 0, bool nesterov = false, float? clipnorm = null) : base(lr, momentum)
         {
-            throw new NotImplementedException();
+            Lr = lr;
+            Decay = decay;
+            ClipGradient = clipnorm;
         }
 
         public override float GetLr(int index)
@@ -26,7 +30,20 @@ public override float[] GetLrs(int[] indices)
 
         public ConfigDict GetConfig()
         {
-            throw new NotImplementedException();
+            var config = new ConfigDict {
+                { "lr", this.Lr },
+                { "momentum", this.momentum },
+                { "decay", this.Decay }
+            };
+
+            return config;
         }
     }
 }
diff --git a/src/MxNet.Keras/Regularizers/L1L2.cs b/src/MxNet.Keras/Regularizers/L1L2.cs
index 08473c7f..7df34dbd 100644
--- a/src/MxNet.Keras/Regularizers/L1L2.cs
+++ b/src/MxNet.Keras/Regularizers/L1L2.cs
@@ -1,24 +1,48 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
-
+using K = MxNet.Keras.MxNetBackend;
 namespace MxNet.Keras.Regularizers
 {
     public class L1L2 : Regularizer
     {
+        public float l1;
+
+        public float l2;
+
         public L1L2(float l1 = 0, float l2 = 0)
         {
-            throw new NotImplementedException();
+            this.l1 = l1;
+            this.l2 = l2;
         }
 
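+        // Penalty computed in Call below: l1 * sum(|x|) + l2 * sum(x^2); each term is
+        // added only when its coefficient is positive.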
         public override KerasSymbol Call(KerasSymbol x)
         {
-            throw new NotImplementedException();
+            // Accumulate null-safely: KerasSymbol's operator+ is not guaranteed to
+            // accept a null left operand.
+            KerasSymbol regularization = null;
+            if (this.l1 > 0)
+            {
+                var l1Term = K.Sum(this.l1 * K.Abs(x), null);
+                regularization = regularization == null ? l1Term : regularization + l1Term;
+            }
+            if (this.l2 > 0)
+            {
+                var l2Term = K.Sum(this.l2 * K.Square(x), null);
+                regularization = regularization == null ? l2Term : regularization + l2Term;
+            }
+
+            return regularization;
         }
 
         public ConfigDict GetConfig()
        {
-            throw new NotImplementedException();
+            return new ConfigDict {
+                { "l1", this.l1 },
+                { "l2", this.l2 }
+            };
        }
    }
 }
diff --git a/src/MxNet.Keras/Regularizers/Regularizer.cs b/src/MxNet.Keras/Regularizers/Regularizer.cs
index 2552a7b2..bdc3e8a7 100644
--- a/src/MxNet.Keras/Regularizers/Regularizer.cs
+++ b/src/MxNet.Keras/Regularizers/Regularizer.cs
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Text;
 
+
 namespace MxNet.Keras.Regularizers
 {
     public class Regularizer
@@ -11,9 +12,9 @@ public virtual KerasSymbol Call(KerasSymbol x)
             return x;
         }
 
-        public static Regularizer FromConfig(Type cls, ConfigDict config)
+        public static Regularizer FromConfig(ConfigDict config)
         {
-            throw new NotImplementedException();
+            return Deserialize(config);
         }
 
         public static Regularizer L1(float l = 0.01f)
@@ -31,19 +32,45 @@ public static Regularizer L1_L2(float l1 = 0.01f, float l2 = 0.01f)
             return new L1L2(l1: l1, l2: l2);
         }
 
-        public static string Serialize(Regularizer regularizer)
+        public static ConfigDict Serialize(Regularizer regularizer)
         {
-            throw new NotImplementedException();
+            return Utils.GenericUtils.SerializeKerasObject(regularizer);
         }
 
-        public static Regularizer Deserialize(ConfigDict config, object[] custom_objects = null)
+        public static Regularizer Deserialize(ConfigDict config, CustomObjects custom_objects = null)
         {
-            throw new NotImplementedException();
+            return (Regularizer)Utils.GenericUtils.DeserializeKerasObject(config, custom_objects: custom_objects, printable_module_name: "regularizer");
         }
 
         public static Regularizer Get(object identifier)
         {
-            throw new NotImplementedException();
+            if (identifier == null)
+            {
+                return null;
+            }
+
+            if (identifier is ConfigDict)
+            {
+                return Deserialize((ConfigDict)identifier);
+            }
+            else if (identifier is string)
+            {
+                // A bare string such as "l1_l2" is wrapped into a minimal config and
+                // resolved through Deserialize.
+                ConfigDict config = new ConfigDict {
+                    { "class_name", identifier.ToString() },
+                    { "config", new Dictionary<string, object>() }
+                };
+                return Deserialize(config);
+            }
+            else
+            {
+                throw new Exception("Could not interpret regularizer identifier: " + identifier.ToString());
+            }
         }
     }
 }
diff --git a/src/MxNet.Keras/Utils/GenericUtils.cs b/src/MxNet.Keras/Utils/GenericUtils.cs
index fe673714..26fe4dff 100644
--- a/src/MxNet.Keras/Utils/GenericUtils.cs
+++ b/src/MxNet.Keras/Utils/GenericUtils.cs
@@ -17,12 +17,12 @@ public static CustomObjectScope GetCustomObjects()
             throw new NotImplementedException();
         }
 
-        public static ConfigData SerializeKerasObject(object instance)
+        public static ConfigDict SerializeKerasObject(object instance)
         {
             throw new NotImplementedException();
         }
 
-        public static object DeserializeKerasObject(object identifier, string module_objects = "", CustomObjectScope custom_objects = null, string printable_module_name = "object")
+        public static object DeserializeKerasObject(object identifier, string module_objects = "", CustomObjects custom_objects = null, string printable_module_name = "object")
         {
             throw new NotImplementedException();
         }
diff --git a/src/MxNet.Keras/Utils/Progbar.cs b/src/MxNet.Keras/Utils/Progbar.cs
index 7d53ab7c..120926cb 100644
--- a/src/MxNet.Keras/Utils/Progbar.cs
+++ b/src/MxNet.Keras/Utils/Progbar.cs
@@ -11,7 +11,7 @@ public Progbar(int target, int width= 30, int verbose= 1, float interval= 0.05f,
             throw new NotImplementedException();
         }
 
-        public void Update(int current, Dictionary values = null)
+        public void Update(int current, Dictionary values = null)
         {
             throw new NotImplementedException();
         }
diff --git a/src/MxNet/MxNet.csproj b/src/MxNet/MxNet.csproj
index 4f70e9d0..cfd52fe3 100644
--- a/src/MxNet/MxNet.csproj
+++ b/src/MxNet/MxNet.csproj
@@ -56,10 +56,10 @@ MXNet is more than a deep learning project. It is a collection of blue prints an
 
-
+
 
-
+
 
diff --git a/src/MxNet/NDArray/NDArray.cs b/src/MxNet/NDArray/NDArray.cs
index 4e7bb049..396b741f 100644
--- a/src/MxNet/NDArray/NDArray.cs
+++ b/src/MxNet/NDArray/NDArray.cs
@@ -23,6 +23,8 @@ limitations under the License.
 using mx_uint = System.UInt32;
 using mx_float = System.Single;
 using size_t = System.UInt64;
+using System.IO.Compression;
+using System.IO;
 
 // ReSharper disable once CheckNamespace
 namespace MxNet
@@ -346,6 +348,138 @@ public static NDArray LoadCV2Mat(OpenCvSharp.Mat img, Context context = null)
             return ret;
         }
 
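+        // Minimal NPY/NPZ reader (sketch). An .npz file is a zip archive of .npy
+        // entries; each entry begins with the 6-byte magic "\x93NUMPY", a major/minor
+        // version byte pair, a little-endian ushort header length (format v1), and an
+        // ASCII Python-dict header holding 'descr' (dtype code), 'fortran_order' and
+        // 'shape', followed by the raw element buffer.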
+        public static NDArrayList LoadNpz(string file)
+        {
+            NDArrayList result = new NDArrayList();
+            using (ZipArchive zip = ZipFile.OpenRead(file))
+            {
+                foreach (ZipArchiveEntry entry in zip.Entries)
+                {
+                    Stream fs = entry.Open();
+                    BinaryReader reader = new BinaryReader(fs);
+                    var magic = reader.ReadChars(6);   // "\x93NUMPY"
+                    var maj = reader.ReadByte();
+                    var min = reader.ReadByte();
+                    int headerLength = reader.ReadUInt16();
+                    string header = new string(reader.ReadChars(headerLength)).Trim();
+
+                    string mark = "'descr': '";
+                    int s = header.IndexOf(mark) + mark.Length;
+                    int e = header.IndexOf("'", s + 1);
+                    string type = header.Substring(s, e - s);
+
+                    DType dtype = GetNpyType(type);
+                    mark = "'fortran_order': ";
+                    s = header.IndexOf(mark) + mark.Length;
+                    e = header.IndexOf(",", s + 1);
+                    bool fortran = bool.Parse(header.Substring(s, e - s));
+
+                    if (fortran)
+                        throw new NotSupportedException("Fortran-ordered .npy entries are not supported.");
+
+                    mark = "'shape': (";
+                    s = header.IndexOf(mark) + mark.Length;
+                    e = header.IndexOf(")", s + 1);
+                    var shapeSplit = header.Substring(s, e - s).Split(',');
+                    List<int> shapeInt = new List<int>();
+                    foreach (var element in shapeSplit)
+                    {
+                        if (!string.IsNullOrWhiteSpace(element))
+                        {
+                            shapeInt.Add(Convert.ToInt32(element));
+                        }
+                    }
+
+                    Shape shape = new Shape(shapeInt);
+                    if (dtype == DType.Int32)
+                    {
+                        List<int> data = new List<int>();
+                        for (int i = 0; i < shape.Size; i++)
+                        {
+                            data.Add(reader.ReadInt32());
+                        }
+
+                        var x = nd.Array(data.ToArray()).AsType(dtype).Reshape(shape);
+                        result.Add(x);
+                    }
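+                    // Each branch must consume exactly the element width implied by
+                    // 'descr' (1 byte for i1/u1, 4 for i4/f4, 8 for i8/f8), otherwise
+                    // the stream position desynchronizes from the payload.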
+                    else if (dtype == DType.Int8)
+                    {
+                        List<sbyte> data = new List<sbyte>();
+                        for (int i = 0; i < shape.Size; i++)
+                        {
+                            data.Add(reader.ReadSByte());
+                        }
+
+                        var x = nd.Array(data.ToArray()).AsType(dtype).Reshape(shape);
+                        result.Add(x);
+                    }
+                    else if (dtype == DType.Int64)
+                    {
+                        List<long> data = new List<long>();
+                        for (int i = 0; i < shape.Size; i++)
+                        {
+                            // ReadInt64, not ReadSByte: i8 elements are 8 bytes wide.
+                            data.Add(reader.ReadInt64());
+                        }
+
+                        var x = nd.Array(data.ToArray()).AsType(dtype).Reshape(shape);
+                        result.Add(x);
+                    }
+                    else if (dtype == DType.Float32)
+                    {
+                        List<float> data = new List<float>();
+                        for (int i = 0; i < shape.Size; i++)
+                        {
+                            // ReadSingle, not ReadSByte: f4 elements are 4 bytes wide.
+                            data.Add(reader.ReadSingle());
+                        }
+
+                        var x = nd.Array(data.ToArray()).AsType(dtype).Reshape(shape);
+                        result.Add(x);
+                    }
+                    else if (dtype == DType.Float64)
+                    {
+                        List<double> data = new List<double>();
+                        for (int i = 0; i < shape.Size; i++)
+                        {
+                            // ReadDouble, not ReadSByte: f8 elements are 8 bytes wide.
+                            data.Add(reader.ReadDouble());
+                        }
+
+                        // Values are routed through float before AsType, mirroring the
+                        // Uint8 branch; note this truncates double precision.
+                        var x = nd.Array(data.Select(i => (float)i).ToArray()).AsType(dtype).Reshape(shape);
+                        result.Add(x);
+                    }
+                    else if (dtype == DType.Uint8)
+                    {
+                        var data = reader.ReadBytes(shape.Size);
+
+                        var x = nd.Array(data.Select(i => (float)i).ToArray()).Reshape(shape).AsType(dtype);
+                        result.Add(x);
+                    }
+                }
+            }
+
+            return result;
+        }
+
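+        // Maps NumPy 'descr' codes to MxNet dtypes. Substring(1) drops the byte-order
+        // character, so the data is assumed little-endian ('<') or order-free ('|');
+        // big-endian ('>') payloads would be silently misread.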
+        private static DType GetNpyType(string dtype)
+        {
+            string typeCode = dtype.Substring(1);
+
+            if (typeCode == "i1")
+                return DType.Int8;
+            if (typeCode == "u1")
+                return DType.Uint8;
+            if (typeCode == "i4")
+                return DType.Int32;
+            if (typeCode == "i8")
+                return DType.Int64;
+            if (typeCode == "f4")
+                return DType.Float32;
+            if (typeCode == "f8")
+                return DType.Float64;
+
+            // "i2" (int16) is deliberately unmapped: MxNet has no int16 dtype and the
+            // read loop above has no 2-byte branch, so mapping it to Int32 would
+            // silently misread the payload.
+            throw new NotSupportedException("Unsupported npy dtype: " + dtype);
+        }
+
         public static NDArray NewFromSharedMem(int shared_pid, int shared_id, Shape shape, DType dtype)
         {
             NativeMethods.MXNDArrayCreateFromSharedMemEx(shared_pid, shared_id, shape.Data, shape.Dimension,
@@ -577,7 +711,7 @@ public NDArray this[string slice]
             }
         }
 
-        public NDArray this[NDArray slice] => nd.SliceLike(this, slice);
+        public NDArray this[NDArray indices] => nd.Take(this, indices);
 
         public NDArray Detach()
         {
diff --git a/src/MxNet/Optimizers/AdaDelta.cs b/src/MxNet/Optimizers/AdaDelta.cs
index 74ea9c58..ebfb712d 100644
--- a/src/MxNet/Optimizers/AdaDelta.cs
+++ b/src/MxNet/Optimizers/AdaDelta.cs
@@ -17,8 +17,9 @@ namespace MxNet.Optimizers
 {
     public class AdaDelta : Optimizer
     {
-        public AdaDelta(float rho = 0.95f, float decayRate = 0, float epsilon = 1e-07f)
+        public AdaDelta(float lr = 1, float rho = 0.95f, float decayRate = 0, float epsilon = 1e-07f)
         {
+            LearningRate = lr;
             Rho = rho;
             Epsilon = epsilon;
         }
diff --git a/src/MxNet/Optimizers/SGD.cs b/src/MxNet/Optimizers/SGD.cs
index 56174ad5..ed2d73ce 100644
--- a/src/MxNet/Optimizers/SGD.cs
+++ b/src/MxNet/Optimizers/SGD.cs
@@ -20,8 +20,8 @@ namespace MxNet.Optimizers
 {
     public class SGD : Optimizer
     {
-        private readonly bool lazy_update;
-        private readonly float momentum;
+        public readonly bool lazy_update;
+        public readonly float momentum;
 
         public SGD(float learning_rate= 0.1f, float momentum = 0, bool lazy_update = true, bool multi_precision = false)
             : base(learning_rate: learning_rate, multi_precision: multi_precision)
diff --git a/src/MxNet/Sym/Ops.cs b/src/MxNet/Sym/Ops.cs
index 6d8cdf5e..92dd04e0 100644
--- a/src/MxNet/Sym/Ops.cs
+++ b/src/MxNet/Sym/Ops.cs
@@ -7889,7 +7889,7 @@ public static Symbol Zeros(Shape shape = null, Context ctx = null, DType dtype =
         /// <param name="ctx">Context of output, in format [cpu|gpu|cpu_pinned](n).Only used for imperative calls.</param>
         /// <param name="dtype">Target data type.</param>
        /// <returns>returns new symbol</returns>
-        public static Symbol Eye(Tuple N, int M = 0, int k = 0, Context ctx = null, DType dtype = null,
+        public static Symbol Eye(int N, int M = 0, int k = 0, Context ctx = null, DType dtype = null,
             string symbol_name = "")
         {
             if (dtype == null) dtype = DType.Float32;