Skip to content

Fix RLCSE for new metrics format #391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 95 additions & 39 deletions src/jit-rl-cse/MLCSE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
// (or SPMI server mode where we can just send requests to a long-running instance?)
public class MLCSE
{
public static string spmiCollection = @"D:\spmi\mch\d7bbeb5a-aa7d-43ec-b29e-6f24dd3bca9c.windows.x64\aspnet.run.windows.x64.checked.mch";
public static string spmiCollection = @"d:\spmi\mch\b8a05f18-503e-47e4-9193-931c50b151d1.windows.x64\aspnet.run.windows.x64.checked.mch";
public static string checkedCoreRoot = @"c:\repos\runtime0\artifacts\tests\coreclr\Windows.x64.Checked\Tests\Core_Root";
public static string dumpDir = @"d:\bugs\cse-metrics";

Expand All @@ -60,7 +60,7 @@ public class MLCSE
public static Dictionary<Method, State> Best = new Dictionary<Method, State>();

// Number of methods to consider
public static int numMethods = 100;
public static int numMethods = 20;

public static void Main(string[] args)
{
Expand All @@ -83,13 +83,17 @@ public static void Main(string[] args)

// methodsToUse = ["61266"];

// methodsToUse = new List<Method>() { "31866", "35481", "31554" };

// SaveDump(methodsToUse.First());

if (methodsToUse.Count() == 0)
{
methodsToUse = GetMethodSample();
}



// Optionally build a data set that describes the features of the CSE candidates
// (useful for normalizing things)
bool doGatherFeatures = false;
Expand All @@ -111,10 +115,7 @@ public static void Main(string[] args)

if (forgetMCMC)
{
// not quite working... sigh
Q.Clear();
V.Clear();
CollectionData.BuildMethodList(spmiCollection, checkedCoreRoot);
Forget();
}
}

Expand Down Expand Up @@ -146,6 +147,19 @@ static void ComputeBaseline(Method m)
V[baselineState] = data;
}

// Reset everything learned so far while keeping the set of known states and
// state/action pairs:
//   - Best is emptied.
//   - Every V entry gets bestPerfScore reset to its basePerfScore, and its
//     averagePerfScore and numVisits reset to 0.
//   - Every Q entry gets its count and perfScore reset to 0.
//
static void Forget()
{
    Best.Clear();

    // Iterate over snapshots of the keys rather than the live Keys collections:
    // the loops below assign through the dictionary indexer, and writing to a
    // dictionary while enumerating it is only documented as safe for Remove/Clear.
    foreach (State s in new List<State>(V.Keys))
    {
        StateData sd = V[s];
        sd.bestPerfScore = sd.basePerfScore;
        sd.averagePerfScore = 0;
        sd.numVisits = 0;

        // Write back so the reset also sticks if the data type has value semantics.
        V[s] = sd;
    }

    foreach (StateAndAction sa in new List<StateAndAction>(Q.Keys))
    {
        StateAndActionData sad = Q[sa];
        sad.count = 0;
        sad.perfScore = 0;

        // Write back so the reset also sticks if the data type has value semantics.
        Q[sa] = sad;
    }
}

// Get or compute the baseline state for a method. This is the terminal state
// reached by the default jit CSE policy.
static State BaselineState(Method m)
Expand Down Expand Up @@ -240,6 +254,8 @@ static IEnumerable<Method> GetMethodSample()
return (numCse > 0) && (minCandidatesToExplore <= numCand) && (numCand <= maxCandidatesToExplore);
}).Select(s => s.method);

Console.WriteLine($"{methods.Count()} methods with between {minCandidatesToExplore} and {maxCandidatesToExplore} cses, {(randomSample ? " randomly" : "")} choosing {maxMethodsToExplore}.");

// optionally randomly shuffle
if (randomSample)
{
Expand All @@ -261,14 +277,14 @@ static void EvaluateGreedyPolicy(string parameters, int runNumber = 0)

// Filter output to just per-method metrics lines.
//
var metricLines = greedyContents.Split(Environment.NewLine).Where(l => l.StartsWith(@"Total bytes of code", StringComparison.Ordinal));
var metricLines = greedyContents.Split(Environment.NewLine).Where(l => l.StartsWith(@"; Total bytes of code", StringComparison.Ordinal));

// Parse each of these. Ignore methods with 0 cse candidates.
//
var methodsAndScores = metricLines.Where(l => MetricsParser.GetNumCand(l) > 0).Select(l => { return (MetricsParser.GetMethodIndex(l), MetricsParser.GetPerfScore(l)); });
// var methodsAndScoresAndBaselines = methodsAndScores.Select(x => { return (x.Item1, x.Item2, V[BaselineState(x.Item1)].basePerfScore); });

uint count = (uint) methodsAndScores.Count();
uint count = (uint)methodsAndScores.Count();
double logSum = 0;
uint nBetter = 0;
uint nWorse = 0;
Expand All @@ -281,7 +297,7 @@ static void EvaluateGreedyPolicy(string parameters, int runNumber = 0)
Method bestMethod = "-1";
uint nRatio = 0;

foreach(var methodAndScore in methodsAndScores)
foreach (var methodAndScore in methodsAndScores)
{
Method method = methodAndScore.Item1;
double score = methodAndScore.Item2;
Expand Down Expand Up @@ -322,6 +338,8 @@ static void EvaluateGreedyPolicy(string parameters, int runNumber = 0)
Console.WriteLine($"Best: {bestMethod.spmiIndex,6} @ {best,7:F4}");
Console.WriteLine($"Worst: {worstMethod.spmiIndex,6} @ {worst,7:F4}");
Console.WriteLine();


//Console.WriteLine(metricLines.Where(l => MetricsParser.GetMethodIndex(l) == bestMethod.spmiIndex).First());
//Console.WriteLine(metricLines.Where(l => MetricsParser.GetMethodIndex(l) == worstMethod.spmiIndex).First());

Expand All @@ -336,8 +354,8 @@ static void PolicyGradient(IEnumerable<Method> methods)
{
// number of times we cycle through the methods
int nRounds = 10_000;
// how many trials per method each cycle
int nIter = 10;
// how many trials per method each cycle (minibatch)
int nIter = 25;
// how often to show results
bool showEvery = true;
uint showEveryInterval = 1;
Expand All @@ -358,17 +376,17 @@ static void PolicyGradient(IEnumerable<Method> methods)
// random salt
int salt = 6;
// learning rate
double alpha = 0.16;
double alpha = 0.02;
// just show tabular results
bool showTabular = true;
// how often to recap baseline/best/greedy
int summaryInterval = 50;
int summaryInterval = 25;
// show greedy policy in summary intervals
bool showGreedy = true;
// save QV dot files each summary interval?
bool saveQVdot = false;

// Initial parameter set. Must be non-empty. Jit will fill in 0 for any missing params.
// Initial parameter set. Must be non-empty. Jit will fill in 0 for any missing params.
string parameters = "0";
string prevParameters = parameters;
int nSameParams = 0;
Expand Down Expand Up @@ -458,7 +476,7 @@ static void PolicyGradient(IEnumerable<Method> methods)
string[] batchNewParams = new string[nIter];
string[] batchRuns = new string[nIter];

Parallel.For (0, nIter, i =>
Parallel.For(0, nIter, i =>
{
{
using StringWriter sw = new StringWriter();
Expand All @@ -468,7 +486,7 @@ static void PolicyGradient(IEnumerable<Method> methods)
//
int iterSalt = salt * nIter * nRounds + r * nIter + i;

List<string> policyOptions = new List<string>() {$"JitRLCSE={parameters}", $"JitRLCSEAlpha={alpha}", $"JitRandomCSE={iterSalt}"};
List<string> policyOptions = new List<string>() { $"JitRLCSE={parameters}", $"JitRLCSEAlpha={alpha}", $"JitRandomCSE={iterSalt}" };

if (showPolicyEvaluations)
{
Expand Down Expand Up @@ -520,7 +538,7 @@ static void PolicyGradient(IEnumerable<Method> methods)

for (int s = 0; s < subScores.Count() - 1; s++)
{
rewards.Add((subScores[s] - policyScore) / baselineScore);
rewards.Add((subScores[s] - subScores[s + 1]) / baselineScore);
}

string rewardString = String.Join(",", rewards);
Expand All @@ -531,7 +549,7 @@ static void PolicyGradient(IEnumerable<Method> methods)
sw.Write($" rewards: {String.Join(",", rewards.Select(x => $"{x,7:F4}")),-30}");
}

List<string> updateOptions = new List<string>() {$"JitRLCSE={parameters}", $"JitRLCSEAlpha={alpha}", $"JitRandomCSE={iterSalt}", $"JitReplayCSE={policySequence}", $"JitReplayCSEReward={rewardString}"};
List<string> updateOptions = new List<string>() { $"JitRLCSE={parameters}", $"JitRLCSEAlpha={alpha}", $"JitRandomCSE={iterSalt}", $"JitReplayCSE={policySequence}", $"JitReplayCSEReward={rewardString}" };

if (showPolicyUpdates)
{
Expand Down Expand Up @@ -564,23 +582,56 @@ static void PolicyGradient(IEnumerable<Method> methods)

// Optionally save dumps for certain sequences
// We do this as separate run to not mess up metrics parsing...
if (method.spmiIndex == "61266" && (updateSequence == "1,2,0" || updateSequence == "1,0"))
// Todo: parameterize this
if (i == 0 && method.spmiIndex == "6276")
{
string cleanSequence = updateSequence.Replace(',', '_');
string dumpFile = Path.Combine(dumpDir, $"dump-{method.spmiIndex}-{cleanSequence}.d");
if (!File.Exists(dumpFile))
{
updateOptions.Add($"JitDump=*");
updateOptions.Add($"JitStdOutFile={dumpFile}");
string dumpRun = SPMI.Run(method.spmiIndex, updateOptions);
List<string> dumpOptions = new List<string>(updateOptions);
dumpOptions.Add($"JitDump=*");
dumpOptions.Add($"JitStdOutFile={dumpFile}");
string dumpRun = SPMI.Run(method.spmiIndex, dumpOptions);
sw.WriteLine($" ---> saved dump to {dumpFile}");
}

string dasmFile = Path.Combine(dumpDir, $"dump-{method.spmiIndex}-{cleanSequence}.dasm");
if (!File.Exists(dasmFile))
{
List<string> dasmOptions = new List<string>(updateOptions);
updateOptions.Add($"JitDisasm=*");
updateOptions.Add($"JitStdOutFile={dasmFile}");
string dasmRun = SPMI.Run(method.spmiIndex, updateOptions);
sw.WriteLine($" ---> saved dasm to {dasmFile}");
}

string baseSequence = "baseline";
string baseDumpFile = Path.Combine(dumpDir, $"dump-{method.spmiIndex}-{baseSequence}.d");
if (!File.Exists(baseDumpFile))
{
List<string> dumpOptions = new List<string>();
dumpOptions.Add($"JitDump=*");
dumpOptions.Add($"JitStdOutFile={baseDumpFile}");
string dumpRun = SPMI.Run(method.spmiIndex, dumpOptions);
sw.WriteLine($" ---> saved baseline dump to {baseDumpFile}");
}

string baseDasmFile = Path.Combine(dumpDir, $"dump-{method.spmiIndex}-{baseSequence}.dasm");
if (!File.Exists(baseDasmFile))
{
List<string> dasmOptions = new List<string>();
dasmOptions.Add($"JitDisasm=*");
dasmOptions.Add($"JitStdOutFile={baseDasmFile}");
string dasmRun = SPMI.Run(method.spmiIndex, dasmOptions);
sw.WriteLine($" ---> saved baseline dasm to {baseDasmFile}");
}
}

batchDetails[i] = sw.ToString();
}
});


// Post-process the batch
//
Expand Down Expand Up @@ -625,7 +676,7 @@ static void PolicyGradient(IEnumerable<Method> methods)

int numValid = validPerfScores.Count();

if (averageParams != null)
if (averageParams != null)
{
if (numValid > 1)
{
Expand Down Expand Up @@ -718,7 +769,7 @@ static void PolicyGradient(IEnumerable<Method> methods)
{
Console.Write($" params: {String.Join(",", MetricsParser.ToDoubles(parameters).Select(x => $"{x,7:F4}"))}");
}

Console.WriteLine();

// If parameters stay same for 50 iterations, stop.
Expand Down Expand Up @@ -837,7 +888,7 @@ static void DumpPolicyGradientStatus(IEnumerable<Method> methods, bool showPolic
{
// Todo: record these as they may be unique...
//
List<string> greedyOptions = new List<string>{$"JitRLCSE={parameters}", $"JitRLCSEGreedy=1"};
List<string> greedyOptions = new List<string> { $"JitRLCSE={parameters}", $"JitRLCSEGreedy=1" };

if (showPolicyUpdates)
{
Expand Down Expand Up @@ -932,7 +983,7 @@ static void MCMC(IEnumerable<Method> methods)
// Show the Markov Chain
bool showMC = false;
// Draw the Markov Chain (tree)
bool showMCDot = true;
bool showMCDot = false;

// Enable random MCMC mode
bool doRandomTrials = true;
Expand Down Expand Up @@ -993,7 +1044,7 @@ static void MCMC(IEnumerable<Method> methods)
//for (int i = 0; i < maxCase; i++)
Parallel.For(0, maxCase, i =>
{
List<string> policyOptions = new List<string>() {$"JitCSEHash=0"};
List<string> policyOptions = new List<string>() { $"JitCSEHash=0" };

if (doRandom && (i != 0))
{
Expand Down Expand Up @@ -1131,7 +1182,7 @@ static void MCMC(IEnumerable<Method> methods)

// Get the "current" value of a state for a method.
//
static double GetValue(Dictionary<State, StateData>V, Method method, State state)
static double GetValue(Dictionary<State, StateData> V, Method method, State state)
{
if (!V.ContainsKey(state))
{
Expand Down Expand Up @@ -1185,18 +1236,23 @@ static bool QVUpdate(Dictionary<StateAndAction, StateAndActionData> Q, Dictionar
state = nextState;
}

// Update V -- we can do this here for terminal states as they have no children.
// Create or update V[state] -- we can do this here for terminal states as they have no children.
//
if (!V.ContainsKey(state))
StateData? sd = null;
if (!V.TryGetValue(state, out sd))
{
StateData d = new StateData() { bestPerfScore = perfScore, averagePerfScore = perfScore, numVisits = 1 };
V[state] = d;
sd = new StateData();
V.Add(state, sd);
}
else

if (sd.numVisits == 0)
{
V[state].numVisits++;
sd.bestPerfScore = perfScore;
sd.averagePerfScore = perfScore;
}

sd.numVisits++;

// See if this is a new best state.
//
State best = BestState(method);
Expand Down Expand Up @@ -1358,12 +1414,12 @@ static void QVDumpDot(Method method, TextWriter? tw = null)
}
static void GatherFeatures(IEnumerable<Method> methods)
{
Console.WriteLine($"Gathering CSE features..." );
Console.WriteLine($"Gathering CSE features...");
Stopwatch s = new Stopwatch();
s.Start();
foreach (var method in methods)
{
List<string> policyOptions = new List<string>{$"JitCSEHash=0"};
List<string> policyOptions = new List<string> { $"JitCSEHash=0" };
policyOptions.Add($"JitRLCSE=0");
policyOptions.Add($"JitReplayCSE=1");
policyOptions.Add($"JitReplayCSEReward=1");
Expand Down Expand Up @@ -1416,7 +1472,7 @@ public static IEnumerable<Method> BuildMethodList(string spmiCollection, string

// Filter output to just per-method metrics line.
//
var metricLines = File.ReadLines(cseIndexFile).Where(l => l.StartsWith(@"Total bytes of code", StringComparison.Ordinal));
var metricLines = File.ReadLines(cseIndexFile).Where(l => l.StartsWith(@"; Total bytes of code", StringComparison.Ordinal));

// Parse each of these. Ignore methods with 0 cse candidates.
//
Expand Down Expand Up @@ -1895,7 +1951,7 @@ public static byte[] GetColor(double val, double min, double max)
byte r = (byte)(data[index, 0] * 255);
byte g = (byte)(data[index, 1] * 255);
byte b = (byte)(data[index, 2] * 255);
return new byte[] {r, g, b};
return new byte[] { r, g, b };
}

private static readonly double[,] data =
Expand Down