Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes MCTS and adapts Tester to it #181

Merged
merged 20 commits into from
Jan 8, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
tester bug
  • Loading branch information
AlexHartmann00 committed Jan 7, 2022
commit d85d1a318f9091e63e7a21fb553cb09b100dae46
11 changes: 9 additions & 2 deletions app/src/main/java/Bamboo/controller/MCTS/MCTS.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public class MCTS implements Agent
{
private Color colour;
private Node root;
private boolean testing = false;
public Mutable<Integer> iterations = new Mutable<>(200);
public Mutable<Float> c = new Mutable<>(1f);
public Mutable<Heuristic> heuristic = new Mutable<>(new Uniform());
Expand Down Expand Up @@ -49,8 +50,12 @@ public Vector getNextMove(Game game)
root = lastMove;
else
root = new Node(game.getGrid(), game.getCurrentPlayer().getColor(), null, null);

for(int i = 0; i < iterations.get(); i++)
UCB.C = c.get();
int mutableValue = 0;
if(testing)
mutableValue = Math.round((float)(Number)iterations.get());
int iter = testing ? mutableValue : iterations.get();
for(int i = 0; i < iter; i++)
{
Node n = root.select();
if(n != null)
Expand All @@ -77,11 +82,13 @@ public Mutable<Integer> getDepth() {

/**
 * Exposes the mutable iteration-count holder of this agent.
 *
 * <p>Side effect: calling this accessor switches the agent into "testing"
 * mode by setting {@code testing} to {@code true}. The visible part of
 * {@code getNextMove} then reads the iteration count through a
 * {@code Number} cast and {@code Math.round} instead of using the
 * {@code Integer} directly (presumably because an external caller may store
 * a non-Integer value in the holder; confirm against the Tester).
 *
 * @return the live {@code iterations} holder, not a copy
 */
@Override
public Mutable<Integer> getIterations() {
    this.testing = true;
    return iterations;
}

/**
 * Exposes the mutable holder of the {@code c} parameter, which the visible
 * part of {@code getNextMove} copies into {@code UCB.C} before searching.
 *
 * <p>Side effect: calling this accessor sets {@code testing} to
 * {@code true}, exactly as {@code getIterations()} does.
 *
 * @return the live {@code c} holder, not a copy
 */
@Override
public Mutable<Float> getC() {
    this.testing = true;
    return c;
}

Expand Down
1 change: 0 additions & 1 deletion app/src/main/java/Bamboo/model/GameWithoutGUI.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ private void takeTurn(){
addMoveToLogString(move);
remainingTiles.remove(move);
this.grid.setTile(move,currentPlayer.getColor());
//System.out.println("Agent " + currentPlayer.getName() + " placed color " + currentPlayer.getColor() + " at " + move.toString());
toggleTurn();
}

Expand Down
65 changes: 47 additions & 18 deletions app/src/test/java/Bamboo/TestingAPI/Tester.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import Bamboo.controller.*;
import Bamboo.controller.heuristics.Heuristic;
import Bamboo.model.GameWithoutGUI;
import org.checkerframework.checker.units.qual.A;

import java.awt.Color;
import java.awt.desktop.SystemSleepEvent;
import java.io.IOException;
import java.util.ArrayList;

Expand All @@ -19,6 +21,8 @@ public class Tester {
private int replications;
private String fileName = "experiment.csv";
private ArrayList<Iterator> variables = new ArrayList<>();
private ArrayList<TesterAgent> variableTargets = new ArrayList<>();
private ArrayList<Variable> variableLabels = new ArrayList<>();
private float redPercentage = 0.5f;
private int gamesPlayed = 0;
private int localGameCounter = 0;
Expand All @@ -30,6 +34,7 @@ public class Tester {
private int total;
private boolean hasRun = false;
private Plan plan;
private float[][] table;
private ArrayList<String> colnames = new ArrayList<>();

public Tester(AgentType agent, int size) throws IOException {
Expand All @@ -53,30 +58,18 @@ public Tester() throws IOException{
/**
 * Stores the starting colour used for subsequent games.
 *
 * @param color the colour to start with; {@code Color.WHITE} is the value
 *              that {@code getWinner} checks for specially
 */
public void setStartingColor(Color color) {
    this.startingColor = color;
}
/**
 * Resets the starting colour to {@code Color.WHITE}, the value that
 * {@code getWinner} compares against when building a game.
 */
public void resetStartingColor() {
    this.startingColor = Color.WHITE;
}

public float[][] run(){
public float[][] run() throws IOException {
if(!hasRun)colnames.add("WinRate");
if(TRACK_TIME && !hasRun)colnames.add("ms");
makePlan();
int cols = plan.getCols() + 1;
System.out.println("Cols: " + cols + ", plan: " + plan);
float[] results = new float[plan.getRows()*this.replications];
if(TRACK_TIME)cols += 1;
float[][] table = new float[plan.getRows()][cols];
table = new float[plan.getRows()][cols];
writeHeader();
for(int i = 0; i < plan.getRows(); i++){
int[] indices = plan.getRowIndices(i);
for(int var = 0; var < plan.getCols(); var++){
if(!this.variables.get(var).isEmpty()){
Iterator currentVar = variables.get(var);
//TODO: Get numeric if numeric, other values otherwise
if(currentVar.isNumeric())
currentVar.getReference().set(currentVar.getValues()[indices[var]]);
else
currentVar.getReference().set(currentVar.getNon_numeric_values()[indices[var]]);
}
table[i][var] = indices[var];
}
results[i] = getWinPercentage();
results[i] = getWinPercentage(i);
table[i][plan.getCols()] = results[i];
if(TRACK_TIME)table[i][cols-1] = elapsed;
writeRow(table[i]);
Expand All @@ -87,8 +80,22 @@ public float[][] run(){
}

//Gets winner from one game
private Agent getWinner(){
private Agent getWinner(int planRow) throws IOException {
this.gamesPlayed ++;
agent1 = AgentFactory.makeAgent(player1,Color.RED);
agent2 = AgentFactory.makeAgent(player2,Color.BLUE);
refreshReferences();
int[] indices = plan.getRowIndices(planRow);
for(int var = 0; var < plan.getCols(); var++){
if(!this.variables.get(var).isEmpty()){
Iterator currentVar = variables.get(var);
if(currentVar.isNumeric())
currentVar.getReference().set(currentVar.getValues()[indices[var]]);
else
currentVar.getReference().set(currentVar.getNon_numeric_values()[indices[var]]);
}
table[planRow][var] = indices[var];
}
Settings settings = new Settings(agent1, agent2, ((Number)boardSize.get()).intValue());
GameWithoutGUI game;
if(startingColor == Color.WHITE)
Expand All @@ -105,11 +112,11 @@ private Agent getWinner(){
return winner;
}

private float getWinPercentage(){
private float getWinPercentage(int planRow) throws IOException {
int sum = 0;
long before = System.nanoTime();
for(int i = 0; i < replications; i++){
if(getWinner() == agent1)
if(getWinner(planRow) == agent1)
sum++;
}
long after = System.nanoTime();
Expand Down Expand Up @@ -198,20 +205,25 @@ public void addVariable(Variable v, float value){
Mutable ref = VariableFactory.getValueFromVariable(v,this.getAgent1(),this);
Iterator variable = new Iterator<>(ref,value);
colnames.add(v.toString());
variableTargets.add(TesterAgent.AGENT_1);
variableLabels.add(v);
pushVariable(variable);
}

/**
 * Registers an experiment variable for agent 1 that takes each of the
 * supplied values.
 *
 * <p>Besides queuing the variable, this records its column name, its target
 * agent ({@code AGENT_1}) and its {@link Variable} label, so that the
 * parallel lists {@code variableTargets} and {@code variableLabels} stay
 * index-aligned with the registered variables.
 *
 * @param v     the variable to vary
 * @param value the values the variable should take
 */
public void addVariable(Variable v, float[] value) {
    Mutable reference = VariableFactory.getValueFromVariable(v, getAgent1(), this);
    Iterator sweep = new Iterator<>(reference, value);

    final String columnName = v.toString();
    colnames.add(columnName);
    variableTargets.add(TesterAgent.AGENT_1);
    variableLabels.add(v);

    pushVariable(sweep);
}

/**
 * Registers an experiment variable for agent 1 that sweeps a numeric range.
 *
 * <p>Besides queuing the variable, this records its column name, its target
 * agent ({@code AGENT_1}) and its {@link Variable} label. The label must be
 * recorded for every registered variable: {@code refreshReferences()} reads
 * {@code variableLabels.get(i)} for each index of {@code variables}, so a
 * missing entry either throws {@code IndexOutOfBoundsException} or pairs
 * later variables with the wrong label.
 *
 * @param v    the variable to vary
 * @param min  lower bound passed to the {@code Iterator}
 * @param max  upper bound passed to the {@code Iterator}
 * @param step increment passed to the {@code Iterator}
 */
public void addVariable(Variable v, float min, float max, float step){
    Mutable ref = VariableFactory.getValueFromVariable(v,this.getAgent1(),this);
    Iterator variable = new Iterator<>(ref,min,max,step);
    colnames.add(v.toString());
    variableTargets.add(TesterAgent.AGENT_1);
    // Previously omitted in this overload only; keeps variableLabels
    // index-aligned with variables and variableTargets.
    variableLabels.add(v);
    pushVariable(variable);
}

Expand All @@ -220,6 +232,8 @@ public void addVariable(TesterAgent agent,Variable v, float value){
Mutable ref = VariableFactory.getValueFromVariable(v,a,this);
Iterator variable = new Iterator<>(ref,value);
colnames.add(agent.toString() + "_" + v.toString());
variableTargets.add(agent);
variableLabels.add(v);
pushVariable(variable);
}

Expand All @@ -228,6 +242,8 @@ public void addVariable(TesterAgent agent, Variable v, float[] value){
Mutable ref = VariableFactory.getValueFromVariable(v,a,this);
Iterator variable = new Iterator<>(ref,value);
colnames.add(agent.toString() + "_" + v.toString());
variableTargets.add(agent);
variableLabels.add(v);
pushVariable(variable);
}

Expand All @@ -236,6 +252,8 @@ public void addVariable(TesterAgent agent ,Variable v, float min, float max, flo
Mutable ref = VariableFactory.getValueFromVariable(v,a,this);
Iterator variable = new Iterator<>(ref,min,max,step);
colnames.add(agent.toString() + "_" + v.toString());
variableTargets.add(agent);
variableLabels.add(v);
pushVariable(variable);
}

Expand All @@ -244,9 +262,20 @@ public void addVariable(TesterAgent agent, Variable v, Heuristics[] values){
Mutable ref = VariableFactory.getValueFromVariable(v,a,this);
Iterator variable = new Iterator<>(ref,values);
colnames.add(agent.toString() + "_" + v.toString());
variableTargets.add(agent);
variableLabels.add(v);
pushVariable(variable);
}

/**
 * Re-resolves the {@code Mutable} reference of every registered variable.
 *
 * <p>{@code getWinner} builds fresh agent instances for each game and then
 * calls this method; each variable is looked up again through
 * {@code VariableFactory} using its stored label and target agent, and the
 * result replaces the variable's previous reference.
 *
 * <p>Relies on {@code variables}, {@code variableLabels} and
 * {@code variableTargets} being index-aligned.
 */
private void refreshReferences() {
    final int count = variables.size();
    for (int index = 0; index < count; index++) {
        Iterator variable = variables.get(index);
        Variable label = variableLabels.get(index);
        TesterAgent target = variableTargets.get(index);
        variable.setReference(
                VariableFactory.getValueFromVariable(
                        label,
                        TesterAgentFactory.getAgentReference(this, target),
                        this));
    }
}

public void addMetric(Metrics m){
switch(m){
case ELAPSED_TIME -> this.TRACK_TIME = true;
Expand Down
8 changes: 8 additions & 0 deletions app/src/test/java/Bamboo/TestingAPI/TestingAPITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ public class TestingAPITest {
tester.run();
}

/**
 * Disabled experiment run: an MCTS tester of size 3 in which agent 1's
 * ITERATIONS is swept with (min 1, max 200, step 100) and its C with
 * (min 0.1, max 1, step 0.5), using 2 replications.
 */
@Disabled
@Test
void newMCTSTest() throws IOException {
    Tester experiment = new Tester(AgentType.MCTS, 3);

    // Both swept variables target agent 1.
    experiment.addVariable(TesterAgent.AGENT_1, Variable.ITERATIONS, 1, 200, 100);
    experiment.addVariable(TesterAgent.AGENT_1, Variable.C, 0.1f, 1f, 0.5f);

    experiment.setReplications(2);
    experiment.run();
}

@Disabled
@Test void neuralNetTest() throws IOException{
Expand Down