Skip to content

Commit

Permalink
Bug fixes with HybridBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
deepakkumar1984 committed Apr 18, 2021
1 parent 52a4576 commit 88ee022
Show file tree
Hide file tree
Showing 15 changed files with 100 additions and 47 deletions.
1 change: 1 addition & 0 deletions Onnx.Net
Submodule Onnx.Net added at 08d53f
21 changes: 15 additions & 6 deletions examples/BasicExamples/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using MxNet;
using MxNet.Numpy;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

Expand All @@ -12,14 +13,22 @@ class Program
static void Main(string[] args)
{
//Console.WriteLine("Runnin XOR Example......");
XORGate.Run();
//XORGate.Run();
//CrashCourse_NN.Run();
//LogisticRegressionExplained.Run();
var methods = mx.GetAllRegisteredOperators();
var y = np.full(new Shape(3, 3), 0.6);
var x = np.random.power(y, new Shape(3, 3));

var z = np.linalg.cholesky(x);
//var methods = mx.GetAllRegisteredCApiOperators();
//var y = np.full(new Shape(3, 3), 0.6);
//var x = np.random.power(y, new Shape(3, 3));
//var fc = npx.fully_connected(x, y, null, 3);
//var z = np.linalg.cholesky(x);
DateTime start = DateTime.Now;
var x = np.random.uniform(size: new Shape(3000, 1000));
var y = np.random.uniform(size: new Shape(1000, 3000));
//var d = 0.5f * np.sqrt(x) + np.sin(y) * np.log(x) - np.exp(y);
var d = np.dot(x, y);
//var v = d.data.GetValue(0);
//Console.WriteLine(v);
Console.WriteLine("Duration: " + (DateTime.Now - start).TotalMilliseconds / 1000);
}

private static void GenerateFOps()
Expand Down
18 changes: 8 additions & 10 deletions src/MxNet/Gluon/Block/HybridBlock.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ public class HybridBlock : Block

public HybridBlock() : base()
{
this._v2 = true;
this._cached_graph = (null, null);
this._v2 = false;
this._cached_graph = null;
this._cached_op = null;
this._out_format = null;
this._in_format = null;
Expand Down Expand Up @@ -97,7 +97,6 @@ public HybridBlock(Dictionary<string, Block> blocks, bool loadkeys = false)
{
if (_cached_graph == null)
{
var inputs = new SymbolList();
var (flatten_args, _in_format) = Flatten(args, "input");
this._in_format = _in_format.ToList();
var flatten_inputs = new List<NDArrayOrSymbolList>();
Expand Down Expand Up @@ -140,12 +139,11 @@ public HybridBlock(Dictionary<string, Block> blocks, bool loadkeys = false)
var @params = new SymbolDict();
foreach (var item in _reg_params) @params[item.Key] = item.Value.Var();

foreach (var input in grouped_inputs)
outputs.Add(HybridForward((input, new NDArrayOrSymbol(@params.Values))));
outputs.Add(HybridForward((grouped_inputs, @params.Values.NDArrayOrSymbol)));
}

var (@out, _out_format) = Flatten(outputs, "output");
_cached_graph = (inputs, _Symbol.Group(@out.ToSymbols()));
_cached_graph = (symbol_inputs, _Symbol.Group(@out.ToSymbols()));
}

return _cached_graph.Value;
Expand Down Expand Up @@ -194,7 +192,7 @@ public HybridBlock(Dictionary<string, Block> blocks, bool loadkeys = false)
using(var ag = Autograd.Pause())
{
DeferredCompute.Context();
@out = base.Call(args);
@out = Call(args);
}

var (flatten_out, out_format) = Flatten(@out, "output");
Expand Down Expand Up @@ -646,7 +644,7 @@ private void InterAttrs(string infer_fn, string attr, NDArrayOrSymbolList argume

var collectedValues = CollectParams().Values();
for (var i = 0; i < collectedValues.Length; i++)
collectedValues[i]._shape = sdict[collectedValues[i]._var_name];
collectedValues[i]._shape = sdict[collectedValues[i].Name];
}
else if (infer_fn == "infer_type")
{
Expand All @@ -666,11 +664,11 @@ private void InterAttrs(string infer_fn, string attr, NDArrayOrSymbolList argume

var collectedValues = CollectParams().Values();
for (var i = 0; i < collectedValues.Length; i++)
collectedValues[i].DataType = sdict[collectedValues[i]._var_name];
collectedValues[i].DataType = sdict[collectedValues[i].Name];
}
}

public void InferShape(NDArrayOrSymbolList args)
public virtual void InferShape(NDArrayOrSymbolList args)
{
if (!this._v2)
{
Expand Down
2 changes: 1 addition & 1 deletion src/MxNet/Gluon/Losses/Loss.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public Loss(float? weight = null, int? batch_axis = null, string prefix = "", Pa

public override NDArrayOrSymbolList HybridForward(NDArrayOrSymbolList args)
{
return HybridForward(args[0], args.Length > 2 ? args[1] : null, args.Length > 3 ? args[2] : null);
return HybridForward(args[0], args.Length > 1 ? args[1] : null, args.Length > 2 ? args[2] : null);
}

public virtual NDArrayOrSymbol HybridForward(NDArrayOrSymbol pred, NDArrayOrSymbol label,
Expand Down
4 changes: 2 additions & 2 deletions src/MxNet/Gluon/Parameter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ internal void FinishDeferredInit()
{
if (data == null)
{
data = nd.Zeros(Shape, dtype: DataType, ctx: ctx[0]).ToSType(Stype);
data = np.zeros(Shape, dtype: DataType, ctx: ctx[0]).ToSType(Stype);

if (init == null)
{
Expand Down Expand Up @@ -535,7 +535,7 @@ public _Symbol Var()
}
_var = _Symbol.Var(_var_name, shape: Shape, dtype: DataType, lr_mult: Lr_Mult, wd_mult: Wd_Mult, init: Init,
stype: Stype);

return _var;
}

Expand Down
2 changes: 1 addition & 1 deletion src/MxNet/Initializers/Uniform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public Uniform(float scale = 0.07f)

public override void InitWeight(string name, ref ndarray arr)
{
arr = nd.Random.Uniform(-Scale, Scale, arr.shape);
arr = np.random.uniform(-Scale, Scale, arr.shape);
}
}
}
2 changes: 1 addition & 1 deletion src/MxNet/MxNet.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<LangVersion>latest</LangVersion>
<Platforms>AnyCPU;x64</Platforms>
<Version>2.0.0.1</Version>
<Authors>@deepakkumar1984, @horker</Authors>
<Authors>@deepakkumar1984</Authors>
<Product />
<Description>C# Binding for the Apache MxNet library. NDArray, Symbolic and Gluon Supported

Expand Down
3 changes: 3 additions & 0 deletions src/MxNet/NDArray/NumpyExtensions/_api_internals.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ public override bool TryInvokeMember(InvokeMemberBinder binder, object[] args, o
if (k == "ctx" && value == null)
value = Context.CurrentContext;

if (k.Contains("weight") || k.Contains("bias"))
if (value == null)
continue;
if (value == null)
value = "None";

Expand Down
4 changes: 2 additions & 2 deletions src/MxNet/NDArray/NumpyExtensions/npx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ public static ndarray dropout(ndarray data, float p= 0.5f, string mode= "trainin

public static ndarray embedding(ndarray data, ndarray weight, int input_dim, int output_dim, DType dtype= null, bool sparse_grad= false)
{
return _api_internal.dropout(data: data, weight: weight, input_dim: input_dim, output_dim: output_dim, dtype: dtype, sparse_grad: sparse_grad);
return _api_internal.embedding(data: data, weight: weight, input_dim: input_dim, output_dim: output_dim, dtype: dtype, sparse_grad: sparse_grad);
}

public static ndarray fully_connected(ndarray x, ndarray weight, ndarray bias, int num_hidden, bool no_bias= true, bool flatten= true)
{
return _api_internal.fully_connected(data: x, weight: weight, bias: bias, num_hidden: num_hidden, no_bias: no_bias, flatten: flatten);
return _api_internal.fully_connected(x: x, weight: weight, bias: bias, num_hidden: num_hidden, no_bias: no_bias, flatten: flatten);
}

public static ndarray layer_norm(ndarray data, ndarray gamma, ndarray beta, int axis= -1, float eps= 9.99999975e-06f, bool output_mean_var= false)
Expand Down
3 changes: 3 additions & 0 deletions src/MxNet/NDArray/Shape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ public override int GetHashCode()

public override string ToString()
{
if (Dimension == 1)
return $"({Data[0]},)";

return $"({string.Join(",", Enumerable.Range(0, Dimension).Select(i => Data[i].ToString()))})";
}

Expand Down
49 changes: 42 additions & 7 deletions src/MxNet/NDArrayOrSymbol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ public static implicit operator _Symbol(NDArrayOrSymbol x)

public static implicit operator NDArrayOrSymbol(NDArrayOrSymbolList x)
{
return x[0];
return new NDArrayOrSymbol(x);
}

public static implicit operator NDArrayOrSymbol(float x)
Expand Down Expand Up @@ -476,7 +476,9 @@ public NDArrayOrSymbolList(NDArrayOrSymbol x)

public NDArrayOrSymbolList((NDArrayOrSymbol, NDArrayOrSymbol) args)
{
data = new List<NDArrayOrSymbol> { args.Item1, args.Item2 };
data = new List<NDArrayOrSymbol>();
data.AddRange(args.Item1.List);
data.AddRange(args.Item2.List);
}

public NDArrayOrSymbolList((NDArrayOrSymbolList, NDArrayOrSymbol) args)
Expand All @@ -488,7 +490,40 @@ public NDArrayOrSymbolList((NDArrayOrSymbolList, NDArrayOrSymbol) args)
public NDArrayOrSymbolList((NDArrayOrSymbol, NDArrayOrSymbolList) args)
{
data = new List<NDArrayOrSymbol> { args.Item1 };
data.Add(new NDArrayOrSymbol(args.Item2));
foreach (var item in args.Item2)
{
data.Add(item);
}
}

public NDArrayOrSymbolList((ndarray, ndarray) args)
{
data.Add(args.Item1);
data.Add(args.Item2);
}

public NDArrayOrSymbolList((_Symbol, _Symbol) args)
{
data.Add(args.Item1);
data.Add(args.Item2);
}

public NDArrayOrSymbolList((ndarray, NDArrayList) args)
{
data.Add(args.Item1);
foreach (var item in args.Item2)
{
data.Add(item);
}
}

public NDArrayOrSymbolList((_Symbol, SymbolList) args)
{
data.Add(args.Item1);
foreach (var item in args.Item2)
{
data.Add(item);
}
}

public NDArrayOrSymbolList((NDArrayOrSymbol, NDArrayOrSymbol, NDArrayOrSymbol) args)
Expand Down Expand Up @@ -580,28 +615,28 @@ public SymbolList Symbols
public void Deconstruct(out NDArrayOrSymbol x0, out NDArrayOrSymbol x1)
{
x0 = this[0];
x1 = this.Length > 2 ? this[1] : null;
x1 = this.Length > 1 ? this[1] : null;
}

public void Deconstruct(out NDArrayOrSymbol x0, out NDArrayOrSymbol x1, out NDArrayOrSymbol x2)
{
x0 = this[0];
x1 = this.Length > 2 ? this[1] : null;
x1 = this.Length > 1 ? this[1] : null;
x2 = this.Length > 2 ? this[2] : null;
}

public void Deconstruct(out NDArrayOrSymbol x0, out NDArrayOrSymbol x1, out NDArrayOrSymbol x2, out NDArrayOrSymbol x3)
{
x0 = this[0];
x1 = this.Length > 2 ? this[1] : null;
x1 = this.Length > 1 ? this[1] : null;
x2 = this.Length > 2 ? this[2] : null;
x3 = this.Length > 3 ? this[3] : null;
}

public void Deconstruct(out NDArrayOrSymbol x0, out NDArrayOrSymbol x1, out NDArrayOrSymbol x2, out NDArrayOrSymbol x3, out NDArrayOrSymbol x4)
{
x0 = this[0];
x1 = this.Length > 2 ? this[1] : null;
x1 = this.Length > 1 ? this[1] : null;
x2 = this.Length > 2 ? this[2] : null;
x3 = this.Length > 3 ? this[3] : null;
x4 = this.Length > 4 ? this[4] : null;
Expand Down
2 changes: 1 addition & 1 deletion src/MxNet/Sym/Numpy/_Symbol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ public static _Symbol Var(string name, Dictionary<string, string> attr = null, S
attr.Add("__wd_mult__", wd_mult.Value.ToString());

if (dtype != null)
attr.Add("__dtype__", dtype.Name);
attr.Add("__dtype__", dtype);

if (init != null)
{
Expand Down
20 changes: 10 additions & 10 deletions src/MxNet/Sym/NumpyExtensions/_api_internals.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,21 @@ public override bool TryInvokeMember(InvokeMemberBinder binder, object[] args, o
if (k == "ctx" && value == null)
value = Context.CurrentContext;

if (k.Contains("weight") || k.Contains("bias"))
if (value == null)
continue;

if (value == null)
value = "None";

var argType = value.GetType();
if (argType.Name == "ndarray")
if (argType.Name == "_Symbol")
{
op.SetInput(k, (ndarray)value);
op.Set((_Symbol)value);
}
else if (argType.Name == "NDArrayList")
else if (argType.Name == "SymbolList")
{
op.SetInput((NDArrayList)value);
op.SetInput((SymbolList)value);
}
else
{
Expand All @@ -73,13 +77,9 @@ public override bool TryInvokeMember(InvokeMemberBinder binder, object[] args, o
}

if (multiple)
{
NDArrayList list = new NDArrayList();
op.Invoke(list);
result = list;
}
result = op.CreateNpSymbol().ToList();
else
result = op.Invoke();
result = op.CreateNpSymbol();

return true;
}
Expand Down
10 changes: 5 additions & 5 deletions src/MxNet/Sym/NumpyExtensions/npx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ public static _Symbol dropout(_Symbol data, float p = 0.5f, string mode = "train

public static _Symbol embedding(_Symbol data, _Symbol weight, int input_dim, int output_dim, DType dtype = null, bool sparse_grad = false)
{
return _api_internal.dropout(data: data, weight: weight, input_dim: input_dim, output_dim: output_dim, dtype: dtype, sparse_grad: sparse_grad);
return _api_internal.embedding(data: data, weight: weight, input_dim: input_dim, output_dim: output_dim, dtype: dtype, sparse_grad: sparse_grad);
}

public static _Symbol fully_connected(_Symbol x, _Symbol weight, _Symbol bias, int num_hidden, bool no_bias = true, bool flatten = true)
{
return _api_internal.dropout(data: x, weight: weight, bias: bias, num_hidden: num_hidden, no_bias: no_bias, flatten: flatten);
return _api_internal.fully_connected(data: x, weight: weight, bias: bias, num_hidden: num_hidden, no_bias: no_bias, flatten: flatten);
}

public static _Symbol layer_norm(_Symbol data, _Symbol gamma, _Symbol beta, int axis = -1, float eps = 9.99999975e-06f, bool output_mean_var = false)
{
return _api_internal.dropout(data: data, gamma: gamma, beta: beta, axis: axis, eps: eps, output_mean_var: output_mean_var);
return _api_internal.layer_norm(data: data, gamma: gamma, beta: beta, axis: axis, eps: eps, output_mean_var: output_mean_var);
}

public static _Symbol pooling(_Symbol data, int[] kernel, int[] stride = null, int[] pad = null, string pool_type = "max",
Expand Down Expand Up @@ -114,7 +114,7 @@ public static _Symbol roi_pooling(_Symbol data, _Symbol rois, int[] pooled_size,

public static _Symbol smooth_l1(_Symbol data, float scalar)
{
return _api_internal.roi_pooling(data: data, scalar: scalar);
return _api_internal.smooth_l1(data: data, scalar: scalar);
}

public static _Symbol sigmoid(_Symbol data)
Expand Down Expand Up @@ -149,7 +149,7 @@ public static _Symbol one_hot(_Symbol data, long depth, double on_value = 1.0, d

public static _Symbol pick(_Symbol data, _Symbol index, int axis = -1, string mode = "clip", bool keepdims = false)
{
return _api_internal.one_hot(data: data, index: index, axis: axis, mode: keepdims, dtype: keepdims);
return _api_internal.pick(data: data, index: index, axis: axis, mode: keepdims, dtype: keepdims);
}

public static _Symbol reshape_like(_Symbol lhs, _Symbol rhs, int? lhs_begin = null, int? lhs_end = null, int? rhs_begin = null, int? rhs_end = null)
Expand Down
6 changes: 5 additions & 1 deletion src/MxNet/Sym/Operator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,11 @@ public Operator SetInput(string name, Symbol symbol)

public Operator SetInput(SymbolList symbols)
{
foreach (var item in symbols) _InputSymbols.Add(item.GetHandle());
foreach (var item in symbols)
{
if (item != null)
_InputSymbols.Add(item.GetHandle());
}

return this;
}
Expand Down

0 comments on commit 88ee022

Please sign in to comment.