Skip to content

Hide much of Microsoft.ML.Model namespace. #2649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Core/Data/ModelHeader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,8 @@ public static string GetLoaderSigAlt(ref ModelHeader header)
/// This is used to simplify version checking boiler-plate code. It is an optional
/// utility type.
/// </summary>
public readonly struct VersionInfo
[BestFriend]
internal readonly struct VersionInfo
{
public readonly ulong ModelSignature;
public readonly uint VerWrittenCur;
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Core/Data/ModelLoadContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ namespace Microsoft.ML.Model
/// amount of boiler plate code. It can also be used when loading from a single stream,
/// for implementors of ICanSaveInBinaryFormat.
/// </summary>
public sealed partial class ModelLoadContext : IDisposable
[BestFriend]
internal sealed partial class ModelLoadContext : IDisposable
{
/// <summary>
/// When in repository mode, this is the repository we're reading from. It is null when
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/ModelLoading.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.ML.Model
[BestFriend]
internal delegate void SignatureLoadModel(ModelLoadContext ctx);

public sealed partial class ModelLoadContext : IDisposable
internal sealed partial class ModelLoadContext : IDisposable
{
public const string ModelStreamName = "Model.key";
internal const string NameBinary = "Model.bin";
Expand Down
45 changes: 30 additions & 15 deletions src/Microsoft.ML.Core/Data/ModelSaveContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,21 @@ public sealed partial class ModelSaveContext : IDisposable
/// When in repository mode, this is the repository we're writing to. It is null when
/// in single-stream mode.
/// </summary>
public readonly RepositoryWriter Repository;
[BestFriend]
internal readonly RepositoryWriter Repository;

/// <summary>
/// When in repository mode, this is the directory we're reading from. Null means the root
/// of the repository. It is always null in single-stream mode.
/// </summary>
public readonly string Directory;
[BestFriend]
internal readonly string Directory;

/// <summary>
/// The main stream writer.
/// </summary>
public readonly BinaryWriter Writer;
[BestFriend]
internal readonly BinaryWriter Writer;

/// <summary>
/// The strings that will be saved in the main stream's string table.
Expand All @@ -49,7 +52,8 @@ public sealed partial class ModelSaveContext : IDisposable
/// <summary>
/// The min file position of the main stream.
/// </summary>
public readonly long FpMin;
[BestFriend]
internal readonly long FpMin;

/// <summary>
/// The wrapped entry.
Expand All @@ -69,7 +73,8 @@ public sealed partial class ModelSaveContext : IDisposable
/// <summary>
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
/// </summary>
public bool InRepository { get { return Repository != null; } }
[BestFriend]
internal bool InRepository => Repository != null;

/// <summary>
/// Create a <see cref="ModelSaveContext"/> supporting saving to a repository, for implementors of <see cref="ICanSaveModel"/>.
Expand Down Expand Up @@ -125,7 +130,8 @@ internal ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null)
ModelHeader.BeginWrite(Writer, out FpMin, out Header);
}

public void CheckAtModel()
[BestFriend]
internal void CheckAtModel()
{
_ectx.Check(Writer.BaseStream.Position == FpMin + Header.FpModel);
}
Expand All @@ -135,13 +141,15 @@ public void CheckAtModel()
/// <see cref="Done"/> is called.
/// </summary>
/// <param name="ver"></param>
public void SetVersionInfo(VersionInfo ver)
[BestFriend]
internal void SetVersionInfo(VersionInfo ver)
{
ModelHeader.SetVersionInfo(ref Header, ver);
_loaderAssemblyName = ver.LoaderAssemblyName;
}

public void SaveTextStream(string name, Action<TextWriter> action)
[BestFriend]
internal void SaveTextStream(string name, Action<TextWriter> action)
{
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
_ectx.CheckNonEmpty(name, nameof(name));
Expand All @@ -156,7 +164,8 @@ public void SaveTextStream(string name, Action<TextWriter> action)
}
}

public void SaveBinaryStream(string name, Action<BinaryWriter> action)
[BestFriend]
internal void SaveBinaryStream(string name, Action<BinaryWriter> action)
{
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
_ectx.CheckNonEmpty(name, nameof(name));
Expand All @@ -175,7 +184,8 @@ public void SaveBinaryStream(string name, Action<BinaryWriter> action)
/// Puts a string into the context pool, and writes the integer code of the string ID
/// to the write stream. If str is null, this writes -1 and doesn't add it to the pool.
/// </summary>
public void SaveStringOrNull(string str)
[BestFriend]
internal void SaveStringOrNull(string str)
{
if (str == null)
Writer.Write(-1);
Expand All @@ -187,13 +197,15 @@ public void SaveStringOrNull(string str)
/// Puts a string into the context pool, and writes the integer code of the string ID
/// to the write stream. Checks that str is not null.
/// </summary>
public void SaveString(string str)
[BestFriend]
internal void SaveString(string str)
{
_ectx.CheckValue(str, nameof(str));
Writer.Write(Strings.Add(str).Id);
}

public void SaveString(ReadOnlyMemory<char> str)
[BestFriend]
internal void SaveString(ReadOnlyMemory<char> str)
{
Writer.Write(Strings.Add(str).Id);
}
Expand All @@ -202,13 +214,15 @@ public void SaveString(ReadOnlyMemory<char> str)
/// Puts a string into the context pool, and writes the integer code of the string ID
/// to the write stream.
/// </summary>
public void SaveNonEmptyString(string str)
[BestFriend]
internal void SaveNonEmptyString(string str)
{
_ectx.CheckParam(!string.IsNullOrEmpty(str), nameof(str));
Writer.Write(Strings.Add(str).Id);
}

public void SaveNonEmptyString(ReadOnlyMemory<Char> str)
[BestFriend]
internal void SaveNonEmptyString(ReadOnlyMemory<char> str)
{
Writer.Write(Strings.Add(str).Id);
}
Expand All @@ -217,7 +231,8 @@ public void SaveNonEmptyString(ReadOnlyMemory<Char> str)
/// Commit the save operation. This completes writing of the main stream. When in repository
/// mode, it disposes <see cref="Writer"/> (but not <see cref="Repository"/>).
/// </summary>
public void Done()
[BestFriend]
internal void Done()
{
_ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!");
ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings, _loaderAssemblyName);
Expand Down
17 changes: 11 additions & 6 deletions src/Microsoft.ML.Core/Data/ModelSaving.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ namespace Microsoft.ML.Model
public sealed partial class ModelSaveContext : IDisposable
{
/// <summary>
/// Save a sub model to the given sub directory. This requires InRepository to be true.
/// Save a sub model to the given sub directory. This requires <see cref="InRepository"/> to be <see langword="true"/>.
/// </summary>
public void SaveModel<T>(T value, string name)
[BestFriend]
internal void SaveModel<T>(T value, string name)
where T : class
{
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
Expand All @@ -23,7 +24,8 @@ public void SaveModel<T>(T value, string name)
/// <summary>
/// Save the object by calling TrySaveModel then falling back to .net serialization.
/// </summary>
public static void SaveModel<T>(RepositoryWriter rep, T value, string path)
[BestFriend]
internal static void SaveModel<T>(RepositoryWriter rep, T value, string path)
where T : class
{
if (value == null)
Expand Down Expand Up @@ -55,7 +57,8 @@ public static void SaveModel<T>(RepositoryWriter rep, T value, string path)
/// <summary>
/// Save to a single-stream by invoking the given action.
/// </summary>
public static void Save(BinaryWriter writer, Action<ModelSaveContext> fn)
[BestFriend]
internal static void Save(BinaryWriter writer, Action<ModelSaveContext> fn)
{
Contracts.CheckValue(writer, nameof(writer));
Contracts.CheckValue(fn, nameof(fn));
Expand All @@ -68,9 +71,11 @@ public static void Save(BinaryWriter writer, Action<ModelSaveContext> fn)
}

/// <summary>
/// Save to the given sub directory by invoking the given action. This requires InRepository to be true.
/// Save to the given sub directory by invoking the given action. This requires
/// <see cref="InRepository"/> to be <see langword="true"/>.
/// </summary>
public void SaveSubModel(string dir, Action<ModelSaveContext> fn)
[BestFriend]
internal void SaveSubModel(string dir, Action<ModelSaveContext> fn)
{
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
_ectx.CheckNonEmpty(dir, nameof(dir));
Expand Down
11 changes: 7 additions & 4 deletions src/Microsoft.ML.Core/Data/Repository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ internal interface ICanSaveInBinaryFormat
}

/// <summary>
/// Abstraction around a ZipArchive or other hierarchical storage.
/// Abstraction around a <see cref="ZipArchive"/> or other hierarchical storage.
/// </summary>
public abstract class Repository : IDisposable
[BestFriend]
internal abstract class Repository : IDisposable
{
public sealed class Entry : IDisposable
{
Expand Down Expand Up @@ -289,7 +290,8 @@ protected Entry AddEntry(string pathEnt, Stream stream)
}
}

public sealed class RepositoryWriter : Repository
[BestFriend]
internal sealed class RepositoryWriter : Repository
{
private const string DirTrainingInfo = "TrainingInfo";

Expand Down Expand Up @@ -429,7 +431,8 @@ public void Commit()
}
}

public sealed class RepositoryReader : Repository
[BestFriend]
internal sealed class RepositoryReader : Repository
{
private ZipArchive _archive;

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private static VersionInfo GetVersionInfo()
}

// Factory for SignatureLoadModel.
public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TransformWrapper));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ public static class TransformerChain
{
Copy link
Contributor

@artidoro artidoro Feb 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that I think about it, I suspect that this entire class can be made internal. We should expose these methods in MLContext if we need them, and they are not there yet. Please let me know if I am wrong. #Closed

Copy link
Contributor Author

@TomFinley TomFinley Feb 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's all fine, but, maybe we could internalize things unrelated to the issue and this PR in subsequent PRs? My goal @artidoro is not to internalize the entire assembly in one go, just some specific types and infrastructure I see in the Microsoft.ML.Model namespace, as described in the issue, title of the PR, and so on. Perhaps we could evaluate the worth of the PR along those dimensions? #Resolved

Copy link
Contributor Author

@TomFinley TomFinley Feb 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So to be clear, I don't really intend to do this. Unless I literally can't get a signoff any other way. #Resolved

public const string LoaderSignature = "TransformerChain";

Copy link
Contributor

@artidoro artidoro Feb 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be made internal. We usually do that with the LoaderSignature elsewhere. #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No doubt. However I wasn't trying to do that.


In reply to: 258701066 [](ancestors = 258701066)

public static TransformerChain<ITransformer> Create(IHostEnvironment env, ModelLoadContext ctx)
private static TransformerChain<ITransformer> Create(IHostEnvironment env, ModelLoadContext ctx)
=> new TransformerChain<ITransformer>(env, ctx);

/// <summary>
Expand Down
11 changes: 7 additions & 4 deletions src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@ namespace Microsoft.ML.Internal.Internallearn
/// </summary>
public abstract class ModelParametersBase<TOutput> : ICanSaveModel, IPredictorProducing<TOutput>
{
public const string NormalizerWarningFormat =
private const string NormalizerWarningFormat =
"Ignoring integrated normalizer while loading a predictor of type {0}.{1}" +
" Please refer to https://aka.ms/MLNetIssue for assistance with converting legacy models.";

protected readonly IHost Host;
[BestFriend]
private protected readonly IHost Host;

protected ModelParametersBase(IHostEnvironment env, string name)
[BestFriend]
private protected ModelParametersBase(IHostEnvironment env, string name)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonWhiteSpace(name, nameof(name));
Host = env.Register(name);
}

protected ModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
[BestFriend]
private protected ModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonWhiteSpace(name, nameof(name));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
// Multi-class evaluator adds four per-instance columns: "Assigned", "Top scores", "Top classes" and "Log-loss".
private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
{
// If the label column is a key without text key values, convert it to I8, just for saving the per-instance
// If the label column is a key without text key values, convert it to double, just for saving the per-instance
// text file, since if there are different key counts the columns cannot be appended.
string labelName = schema.Label.Value.Name;
if (!perInst.Schema.TryGetColumnIndex(labelName, out int labelColIndex))
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ public abstract class CalibratorTransformer<TICalibrator> : RowToRowTransformerB
private TICalibrator _calibrator;
private readonly string _loaderSignature;

internal CalibratorTransformer(IHostEnvironment env, TICalibrator calibrator, string loaderSignature)
private protected CalibratorTransformer(IHostEnvironment env, TICalibrator calibrator, string loaderSignature)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer<TICalibrator>)))
{
_loaderSignature = loaderSignature;
_calibrator = calibrator;
}

// Factory method for SignatureLoadModel.
internal CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, string loaderSignature)
private protected CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, string loaderSignature)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer<TICalibrator>)))
{
Contracts.AssertValue(ctx);
Expand Down Expand Up @@ -195,7 +195,7 @@ private protected override void SaveModel(ModelSaveContext ctx)

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper<TICalibrator>(this, _calibrator, schema);

protected VersionInfo GetVersionInfo()
private protected VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "CALTRANS",
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ public void Add(uint hash, T key)
/// Simple utility class for saving a <see cref="VBuffer{T}"/> of ReadOnlyMemory
/// as a model, both in a binary and more easily human readable form.
/// </summary>
public static class TextModelHelper
[BestFriend]
internal static class TextModelHelper
{
private const string LoaderSignature = "TextSpanBuffer";

Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ protected OneToOneTransformerBase(IHost host, params (string outputColumnName, s
ColumnPairs = columns;
}

protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx) : base(host)
[BestFriend]
private protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx) : base(host)
{
// *** Binary format ***
// int: number of added columns
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/ValueMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ private static bool CheckModelVersion(ModelLoadContext ctx, VersionInfo versionI
}
}

protected static ValueMappingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
private protected static ValueMappingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ public LinearModelParameters(IHostEnvironment env, string name, in VBuffer<float
_weightsDenseLock = new object();
}

protected LinearModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx)
private protected LinearModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx)
: base(env, name, ctx)
{
// *** Binary format ***
Expand Down Expand Up @@ -547,12 +547,14 @@ private protected override void SaveAsIni(TextWriter writer, RoleMappedSchema sc

public abstract class RegressionModelParameters : LinearModelParameters
{
public RegressionModelParameters(IHostEnvironment env, string name, in VBuffer<float> weights, float bias)
[BestFriend]
private protected RegressionModelParameters(IHostEnvironment env, string name, in VBuffer<float> weights, float bias)
: base(env, name, in weights, bias)
{
}

protected RegressionModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx)
[BestFriend]
private protected RegressionModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx)
: base(env, name, ctx)
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ public async Task ContractsCheck()
VerifyCS.Diagnostic(ContractsCheckAnalyzer.SimpleMessageDiagnostic.Rule).WithLocation(basis + 32, 35).WithArguments("Check", "\"Less fine: \" + env.GetType().Name"),
VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(basis + 34, 17).WithArguments("CheckUserArg", "name", "\"p\""),
VerifyCS.Diagnostic(ContractsCheckAnalyzer.DecodeMessageWithLoadContextDiagnostic.Rule).WithLocation(basis + 39, 41).WithArguments("CheckDecode", "\"This message is suspicious\""),
new DiagnosticResult("CS0117", DiagnosticSeverity.Error).WithLocation("Test1.cs", 220, 70).WithMessage("'MessageSensitivity' does not contain a definition for 'UserData'"),
new DiagnosticResult("CS0117", DiagnosticSeverity.Error).WithLocation("Test1.cs", 231, 70).WithMessage("'MessageSensitivity' does not contain a definition for 'Schema'"),
new DiagnosticResult("CS1061", DiagnosticSeverity.Error).WithLocation("Test1.cs", 747, 21).WithMessage("'IHostEnvironment' does not contain a definition for 'IsCancelled' and no accessible extension method 'IsCancelled' accepting a first argument of type 'IHostEnvironment' could be found (are you missing a using directive or an assembly reference?)"),
};

var test = new VerifyCS.Test
Expand Down
Loading