Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft modification to redirect logs to test output #4710

Merged
merged 10 commits into from
Feb 1, 2020
29 changes: 23 additions & 6 deletions src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ private sealed class ConsoleWriter
private readonly ConsoleEnvironment _parent;
private readonly TextWriter _out;
private readonly TextWriter _err;
private readonly TextWriter _test;

private readonly bool _colorOut;
private readonly bool _colorErr;
Expand All @@ -35,7 +36,7 @@ private sealed class ConsoleWriter
private const int _maxDots = 50;
private int _dots;

public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter errWriter)
public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter errWriter, TextWriter testWriter = null)
{
Contracts.AssertValue(parent);
Contracts.AssertValue(outWriter);
Expand All @@ -44,6 +45,7 @@ public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter
_parent = parent;
_out = outWriter;
_err = errWriter;
_test = testWriter;

_colorOut = outWriter == Console.Out;
_colorErr = outWriter == Console.Error;
Expand Down Expand Up @@ -86,10 +88,19 @@ public void PrintMessage(IMessageSource sender, ChannelMessage msg)
string prefix = WriteAndReturnLinePrefix(msg.Sensitivity, wr);
var commChannel = sender as PipeBase<ChannelMessage>;
if (commChannel?.Verbose == true)
{
WriteHeader(wr, commChannel);
if (_test != null)
WriteHeader(_test, commChannel);
}
if (msg.Kind == ChannelMessageKind.Warning)
{
wr.Write("Warning: ");
_test?.Write("Warning: ");
}
_parent.PrintMessageNormalized(wr, msg.Message, true, prefix);
if (_test != null)
_parent.PrintMessageNormalized(_test, msg.Message, true);
if (toColor)
Console.ResetColor();
}
Expand Down Expand Up @@ -340,6 +351,9 @@ protected override void Dispose(bool disposing)
private volatile ConsoleWriter _consoleWriter;
private readonly MessageSensitivity _sensitivityFlags;

// This object is used to write to the test log along with the console if the host process is a test environment
private TextWriter _testWriter;

/// <summary>
/// Create an ML.NET <see cref="IHostEnvironment"/> for local execution, with console feedback.
/// </summary>
Expand All @@ -348,10 +362,11 @@ protected override void Dispose(bool disposing)
/// <param name="sensitivity">Allowed message sensitivity.</param>
/// <param name="outWriter">Text writer to print normal messages to.</param>
/// <param name="errWriter">Text writer to print error messages to.</param>
/// <param name="testWriter">Optional TextWriter to write messages if the host is a test environment.</param>
public ConsoleEnvironment(int? seed = null, bool verbose = false,
MessageSensitivity sensitivity = MessageSensitivity.All,
TextWriter outWriter = null, TextWriter errWriter = null)
: this(RandomUtils.Create(seed), verbose, sensitivity, outWriter, errWriter)
TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
: this(RandomUtils.Create(seed), verbose, sensitivity, outWriter, errWriter, testWriter)
{
}

Expand All @@ -364,14 +379,16 @@ public ConsoleEnvironment(int? seed = null, bool verbose = false,
/// <param name="sensitivity">Allowed message sensitivity.</param>
/// <param name="outWriter">Text writer to print normal messages to.</param>
/// <param name="errWriter">Text writer to print error messages to.</param>
/// <param name="testWriter">Optional TextWriter to write messages if the host is a test environment.</param>
private ConsoleEnvironment(Random rand, bool verbose = false,
MessageSensitivity sensitivity = MessageSensitivity.All,
TextWriter outWriter = null, TextWriter errWriter = null)
TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
: base(rand, verbose, nameof(ConsoleEnvironment))
{
Contracts.CheckValueOrNull(outWriter);
Contracts.CheckValueOrNull(errWriter);
_consoleWriter = new ConsoleWriter(this, outWriter ?? Console.Out, errWriter ?? Console.Error);
_testWriter = testWriter;
_consoleWriter = new ConsoleWriter(this, outWriter ?? Console.Out, errWriter ?? Console.Error, testWriter);
_sensitivityFlags = sensitivity;
AddListener<ChannelMessage>(PrintMessage);
}
Expand Down Expand Up @@ -444,7 +461,7 @@ public OutputRedirector(ConsoleEnvironment env, TextWriter newOutWriter, TextWri
Contracts.AssertValue(newOutWriter);
Contracts.AssertValue(newErrWriter);
_root = env.Root;
_newConsoleWriter = new ConsoleWriter(_root, newOutWriter, newErrWriter);
_newConsoleWriter = new ConsoleWriter(_root, newOutWriter, newErrWriter, _root._testWriter);
_oldConsoleWriter = Interlocked.Exchange(ref _root._consoleWriter, _newConsoleWriter);
Contracts.AssertValue(_oldConsoleWriter);
}
Expand Down
30 changes: 30 additions & 0 deletions src/Microsoft.ML.Data/LoggingEventArgs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Runtime;

namespace Microsoft.ML
{
Expand All @@ -20,9 +21,38 @@ public LoggingEventArgs(string message)
Message = message;
}

/// <summary>
/// Initializes a new instane of <see cref="LoggingEventArgs"/> class that includes the kind and source of the message
/// </summary>
/// <param name="message"> The message being logged </param>
/// <param name="kind"> The type of message <see cref="ChannelMessageKind"/> </param>
/// <param name="source"> The source of the message </param>
public LoggingEventArgs(string message, ChannelMessageKind kind, string source)
{
RawMessage = message;
Kind = kind;
Source = source;
Message = $"[Source={Source}, Kind={Kind}] {RawMessage}";
}

/// <summary>
/// Gets the source component of the event
/// </summary>
public string Source { get; }

/// <summary>
/// Gets the type of message
/// </summary>
public ChannelMessageKind Kind { get; }

/// <summary>
/// Gets the message being logged.
/// </summary>
public string Message { get; }

/// <summary>
/// Gets the original message that doesn't include the source and kind
/// </summary>
public string RawMessage { get; }
}
}
4 changes: 1 addition & 3 deletions src/Microsoft.ML.Data/MLContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
if (log == null)
return;

var msg = $"[Source={source.FullName}, Kind={message.Kind}] {message.Message}";

log(this, new LoggingEventArgs(msg));
log(this, new LoggingEventArgs(message.Message, message.Kind, source.FullName));
}

string IExceptionContext.ContextDescription => _env.ContextDescription;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,8 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParame
if (stateGCHandle.IsAllocated)
stateGCHandle.Free();
}

ch.Info($"Bias: {bias}, Weights: [{String.Join(",", weights.DenseValues())}]");
return CreatePredictor(weights, bias);
}

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Vision/DnnRetrainTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ internal DnnRetrainTransformer(IHostEnvironment env, Session session, string[] o

_env = env;
_session = session;
_modelLocation = modelLocation;
_modelLocation = Path.IsPathRooted(modelLocation) ? modelLocation : Path.Combine(Directory.GetCurrentDirectory(), modelLocation);
_isTemporarySavedModel = isTemporarySavedModel;
_addBatchDimensionInput = addBatchDimensionInput;
_inputs = inputColumnNames;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ maml.exe CV tr=SymSGD{nt=1} threads=- norm=No dout=%Output% data=%Data% seed=1
Not adding a normalizer.
Data fully loaded into memory.
Initial learning rate is tuned to 100.000000
Bias: -468.3528, Weights: [4.515409,75.74901,22.2914,-10.50209,-28.58107,44.81024,23.8734,13.20304,2.448269]
Not training a calibrator because it is not needed.
Not adding a normalizer.
Data fully loaded into memory.
Initial learning rate is tuned to 100.000000
Bias: -484.2862, Weights: [-12.78704,140.4291,121.9383,37.5274,-129.8139,70.9061,-89.37057,81.64314,-32.32779]
Not training a calibrator because it is not needed.
Warning: The predictor produced non-finite prediction values on 8 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable.
TEST POSITIVE RATIO: 0.3785 (134.0/(134.0+220.0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ maml.exe TrainTest test=%Data% tr=SymSGD{nt=1} norm=No dout=%Output% data=%Data%
Not adding a normalizer.
Data fully loaded into memory.
Initial learning rate is tuned to 100.000000
Bias: -448.1, Weights: [-0.3852913,49.29393,-3.424153,16.76877,-25.15009,23.68305,-6.658058,13.76585,4.843107]
Not training a calibrator because it is not needed.
Warning: The predictor produced non-finite prediction values on 16 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable.
TEST POSITIVE RATIO: 0.3499 (239.0/(239.0+444.0))
Expand Down
12 changes: 11 additions & 1 deletion test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ protected BaseTestBaseline(ITestOutputHelper output) : base(output)
private string _baselineBuildStringDir;

// The writer to write to test log files.
protected TestLogger TestLogger;
protected StreamWriter LogWriter;
private protected ConsoleEnvironment _env;
protected IHostEnvironment Env => _env;
Expand All @@ -97,12 +98,21 @@ protected override void Initialize()

string logPath = Path.Combine(logDir, FullTestName + LogSuffix);
LogWriter = OpenWriter(logPath);
_env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter)

TestLogger = new TestLogger(Output);
_env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter, testWriter: TestLogger)
.AddStandardComponents();
ML = new MLContext(42);
ML.Log += LogTestOutput;
ML.AddStandardComponents();
}

private void LogTestOutput(object sender, LoggingEventArgs e)
{
if (e.Kind >= MessageKindToLog)
Output.WriteLine(e.Message);
}

// This method is used by subclass to dispose of disposable objects
// such as LocalEnvironment.
// It is called as a first step in test clean up.
Expand Down
12 changes: 12 additions & 0 deletions test/Microsoft.ML.TestFramework/BaseTestClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
using System.Reflection;
using System.Threading;
using Microsoft.ML.Internal.Internallearn.Test;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.TestFrameworkCommon.Attributes;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.TestFramework
Expand All @@ -18,6 +21,8 @@ public class BaseTestClass : IDisposable
public string TestName { get; set; }
public string FullTestName { get; set; }

public ChannelMessageKind MessageKindToLog;

static BaseTestClass()
{
AppDomain.CurrentDomain.UnhandledException += (sender, e) =>
Expand Down Expand Up @@ -54,6 +59,13 @@ public BaseTestClass(ITestOutputHelper output)
FullTestName = test.TestCase.TestMethod.TestClass.Class.Name + "." + test.TestCase.TestMethod.Method.Name;
TestName = test.TestCase.TestMethod.Method.Name;

MessageKindToLog = ChannelMessageKind.Error;
var attributes = test.TestCase.TestMethod.Method.GetCustomAttributes(typeof(LogMessageKind));
foreach (var attrib in attributes)
{
MessageKindToLog = attrib.GetNamedArgument<ChannelMessageKind>("MessageKind");
}

// write to the console when a test starts and stops so we can identify any test hangs/deadlocks in CI
Console.WriteLine($"Starting test: {FullTestName}");
Initialize();
Expand Down
51 changes: 51 additions & 0 deletions test/Microsoft.ML.TestFramework/TestLogger.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.IO;
using System.Text;
using Xunit.Abstractions;

namespace Microsoft.ML.TestFramework
{
public sealed class TestLogger : TextWriter
{
private Encoding _encoding;
private ITestOutputHelper _testOutput;

public override Encoding Encoding => _encoding;

public TestLogger(ITestOutputHelper testOutput)
{
_testOutput = testOutput;
_encoding = new UnicodeEncoding();
}

public override void Write(char value)
{
_testOutput.WriteLine($"{value}");
}

public override void Write(string value)
{
if (value.EndsWith("\r\n"))
value = value.Substring(0, value.Length - 2);
_testOutput.WriteLine(value);
}

public override void Write(string format, params object[] args)
{
if (format.EndsWith("\r\n"))
format = format.Substring(0, format.Length - 2);

_testOutput.WriteLine(format, args);
}

public override void Write(char[] buffer, int index, int count)
{
var span = buffer.AsSpan(index, count);
if ((span.Length >= 2) && (span[count - 2] == '\r') && (span[count - 1] == '\n'))
span = span.Slice(0, count - 2);
_testOutput.WriteLine(span.ToString());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.TestFrameworkCommon.Attributes
{
public sealed class LogMessageKind : Attribute
{
public ChannelMessageKind MessageKind { get; }
public LogMessageKind(ChannelMessageKind messageKind)
{
MessageKind = messageKind;
}
}
}