Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean TicTacToe games up #1311

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
688e839
Fix UltimateTTT::MaxGameLength
odanrc Dec 20, 2024
2c7e362
Fix PhantomTTT::MaxGameLength
odanrc Dec 20, 2024
47ea176
Add an alias for the board in TicTacToe
odanrc Dec 15, 2024
39d2dc7
Create a class for a TTT board
odanrc Dec 20, 2024
f17d1d5
Fix misuses of TTT::kNumCells in UltTTT
odanrc Dec 21, 2024
28a4b9e
Generalize TTT hasLine
odanrc Dec 21, 2024
bed642a
Add a TTT Board::Size function
odanrc Dec 21, 2024
88ddeec
Force caller to inform GridBoard dimensions
odanrc Jan 5, 2025
df7eeb5
Rewrite UltimateTTT::toString to use consts
odanrc Jan 5, 2025
fd5cf45
Add functions to get dimensions of GridBoard
odanrc Jan 5, 2025
8136dfc
Remove uses of kNumCells in PhantomTTT
odanrc Jan 6, 2025
ad2b908
Replace all uses of kNumCells in PhantomTTT
odanrc Jan 6, 2025
72af30c
Replace all uses of kNumCells in UltimateTTT
odanrc Jan 6, 2025
5446308
Make PhantomTTT use GridBoard members
odanrc Jan 25, 2025
817945f
Reorder tensor view in Ultimate TTT
odanrc Jan 28, 2025
4e98694
Deduplicate UltTTT::ObservationTensorShape
odanrc Jan 29, 2025
a8c72e3
Rewrite UltTTT's ToString to avoid dup
odanrc Feb 1, 2025
5f5eadb
Remove mentions to global TTT consts in BoardHasLine
odanrc Feb 1, 2025
8c334f6
Make rows and cols params in TTT
odanrc Feb 1, 2025
372fe4b
Create a GridBoard::ToString in TTT
odanrc Feb 8, 2025
8ff4b26
Check legal actions for a board in TTT
odanrc Feb 9, 2025
3076456
Add TTT::GridBoard::Tile
odanrc Feb 16, 2025
4f82435
PlayerToState replaced by PlayerToComponent
odanrc Feb 16, 2025
d5e3d44
Add ToString to Tile and Component
odanrc Feb 16, 2025
301989a
Reduce use of CellState
odanrc Feb 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 47 additions & 63 deletions open_spiel/games/phantom_ttt/phantom_ttt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@ namespace phantom_ttt {
namespace {

using tic_tac_toe::kCellStates;
using tic_tac_toe::kNumCells;
using tic_tac_toe::kNumCols;
using tic_tac_toe::kNumRows;

using tic_tac_toe::CellState;

using tic_tac_toe::PlayerToState;
using tic_tac_toe::StateToString;

// Facts about the game.
const GameType kGameType{
Expand Down Expand Up @@ -93,19 +85,22 @@ ImperfectRecallPTTTGame::ImperfectRecallPTTTGame(const GameParameters& params)
: PhantomTTTGame(params, kImperfectRecallGameType) {}

PhantomTTTState::PhantomTTTState(std::shared_ptr<const Game> game,
ObservationType obs_type)
: State(game), state_(game), obs_type_(obs_type) {
std::fill(begin(x_view_), end(x_view_), CellState::kEmpty);
std::fill(begin(o_view_), end(o_view_), CellState::kEmpty);
ObservationType obs_type, size_t rows,
size_t cols)
: State(game), state_(game, rows, cols), obs_type_(obs_type),
x_view_(rows, cols), o_view_(rows, cols) {
if (obs_type_ == ObservationType::kRevealNumTurns) {
// Reserve 0 for the player and 10 as "I don't know."
bits_per_action_ = kNumCells + 2;
// Longest sequence is 17 moves, e.g. 0011223344556677889
longest_sequence_ = 2 * kNumCells - 1;
// Reserve 1 bit to select the player and another bit as "I don't know."
bits_per_action_ = NumDistinctActions() + 2;
// The longest sequence happens when each player perfectly mimics
// the moves of the other player, e.g. 0011223344556677889. Once
// the last possible action is taken by any of the players, the
// game ends (i.e., this action is not mimicable by the other player).
longest_sequence_ = 2 * NumDistinctActions() - 1;
} else {
SPIEL_CHECK_EQ(obs_type_, ObservationType::kRevealNothing);
bits_per_action_ = kNumCells;
longest_sequence_ = kNumCells;
bits_per_action_ = NumDistinctActions();
longest_sequence_ = NumDistinctActions();
}
}

Expand All @@ -115,14 +110,14 @@ void PhantomTTTState::DoApplyAction(Action move) {
auto& cur_view = cur_player == 0 ? x_view_ : o_view_;

// Two cases: either there is a mark already there, or not.
if (state_.BoardAt(move) == CellState::kEmpty) {
if (state_.BoardAt(move).IsEmpty()) {
// No mark on board, so play this normally.
state_.ApplyAction(move);
}

// Update current player's view, and action sequence.
SPIEL_CHECK_EQ(cur_view[move], CellState::kEmpty);
cur_view[move] = state_.BoardAt(move);
SPIEL_CHECK_EQ(cur_view.At(move).IsEmpty(), true);
cur_view.At(move) = state_.BoardAt(move);
action_sequence_.push_back(std::pair<int, Action>(cur_player, move));

// Note: do not modify player's turn here, it will have been done above
Expand All @@ -131,31 +126,13 @@ void PhantomTTTState::DoApplyAction(Action move) {

std::vector<Action> PhantomTTTState::LegalActions() const {
if (IsTerminal()) return {};
std::vector<Action> moves;
const Player player = CurrentPlayer();
const auto& cur_view = player == 0 ? x_view_ : o_view_;

for (Action move = 0; move < kNumCells; ++move) {
if (cur_view[move] == CellState::kEmpty) {
moves.push_back(move);
}
}

return moves;
const auto& cur_view = CurrentPlayer() == 0 ? x_view_ : o_view_;
return state_.LegalActions(cur_view);
}

std::string PhantomTTTState::ViewToString(Player player) const {
const auto& cur_view = player == 0 ? x_view_ : o_view_;
std::string str;
for (int r = 0; r < kNumRows; ++r) {
for (int c = 0; c < kNumCols; ++c) {
absl::StrAppend(&str, StateToString(cur_view[r * kNumCols + c]));
}
if (r < (kNumRows - 1)) {
absl::StrAppend(&str, "\n");
}
}
return str;
return cur_view.ToString();
}

std::string PhantomTTTState::ActionSequenceToString(Player player) const {
Expand Down Expand Up @@ -197,22 +174,23 @@ void PhantomTTTState::InformationStateTensor(Player player,
SPIEL_CHECK_GE(player, 0);
SPIEL_CHECK_LT(player, num_players_);

// First 27 bits encodes the player's view in the same way as TicTacToe.
// First bits encodes the player's view in the same way as TicTacToe.
// Then the action sequence follows (one-hot encoded, per action).
// Encoded in the same way as InformationStateAsString, so full sequences
// which may contain action value 10 to represent "I don't know."
const auto& player_view = player == 0 ? x_view_ : o_view_;
SPIEL_CHECK_EQ(values.size(), kNumCells * kCellStates +
SPIEL_CHECK_EQ(values.size(), NumDistinctActions() * kCellStates +
longest_sequence_ * bits_per_action_);
std::fill(values.begin(), values.end(), 0.);
for (int cell = 0; cell < kNumCells; ++cell) {
values[kNumCells * static_cast<int>(player_view[cell]) + cell] = 1.0;
for (int cell = 0; cell < player_view.Size(); ++cell) {
values[NumDistinctActions() * TileToState(player_view.At(cell)) + cell] =
1.0;
}

// Now encode the sequence. Each (player, action) pair uses 11 bits:
// - first bit is the player taking the action (0 or 1)
// - next 10 bits is the one-hot encoded action (10 = "I don't know")
int offset = kNumCells * kCellStates;
int offset = NumDistinctActions() * kCellStates;
for (const auto& player_with_action : action_sequence_) {
if (player_with_action.first == player) {
// Always include the observing player's actions.
Expand All @@ -230,7 +208,7 @@ void PhantomTTTState::InformationStateTensor(Player player,
// If the number of turns are revealed, then each of the other player's
// actions will show up as unknowns.
values[offset] = player_with_action.first;
values[offset + 1 + kNumCells] = 1.0; // I don't know.
values[offset + 1 + NumDistinctActions()] = 1.0; // I don't know.
offset += bits_per_action_;
} else {
// Do not reveal anything about the number of actions taken by opponent.
Expand All @@ -256,15 +234,16 @@ void PhantomTTTState::ObservationTensor(Player player,
SPIEL_CHECK_EQ(values.size(), game_->ObservationTensorSize());
std::fill(values.begin(), values.end(), 0.);

// First 27 bits encodes the player's view in the same way as TicTacToe.
// First bits encodes the player's view in the same way as TicTacToe.
const auto& player_view = player == 0 ? x_view_ : o_view_;
for (int cell = 0; cell < kNumCells; ++cell) {
values[kNumCells * static_cast<int>(player_view[cell]) + cell] = 1.0;
for (int cell = 0; cell < player_view.Size(); ++cell) {
values[NumDistinctActions() * TileToState(player_view.At(cell)) + cell] =
1.0;
}

// Then a one-hot to represent total number of turns.
if (obs_type_ == ObservationType::kRevealNumTurns) {
values[kNumCells * kCellStates + action_sequence_.size()] = 1.0;
values[NumDistinctActions() * kCellStates + action_sequence_.size()] = 1.0;
}
}

Expand All @@ -276,15 +255,16 @@ void PhantomTTTState::UndoAction(Player player, Action move) {
Action last_move = action_sequence_.back().second;
SPIEL_CHECK_EQ(last_move, move);

if (state_.BoardAt(move) == PlayerToState(player)) {
if (state_.BoardAt(move).component_ ==
tic_tac_toe::PlayerToComponent(player)) {
// If the board has a mark that is the undoing player, then this was
// a successful move. Undo as normal.
state_.UndoAction(player, move);
}

// Undo the action from that player's view, and pop from the action seq
auto& player_view = player == 0 ? x_view_ : o_view_;
player_view[move] = CellState::kEmpty;
player_view.At(move).Clear();
action_sequence_.pop_back();

history_.pop_back();
Expand All @@ -300,29 +280,33 @@ PhantomTTTGame::PhantomTTTGame(const GameParameters& params, GameType game_type)
std::string obs_type = ParameterValue<std::string>("obstype");
if (obs_type == "reveal-nothing") {
obs_type_ = ObservationType::kRevealNothing;
bits_per_action_ = kNumCells;
longest_sequence_ = kNumCells;
bits_per_action_ = game_->NumDistinctActions();
longest_sequence_ = game_->NumDistinctActions();
} else if (obs_type == "reveal-numturns") {
obs_type_ = ObservationType::kRevealNumTurns;
// Reserve 0 for the player and 10 as "I don't know."
bits_per_action_ = kNumCells + 2;
// Longest sequence is 17 moves, e.g. 0011223344556677889
longest_sequence_ = 2 * kNumCells - 1;
// Reserve 1 bit to select the player and another bit as "I don't know."
bits_per_action_ = game_->NumDistinctActions() + 2;
// The longest sequence happens when each player perfectly mimics
// the moves of the other player, e.g. 0011223344556677889. Once
// the last possible action is taken by any of the players, the
// game ends (i.e., this action is not mimicable by the other player).
longest_sequence_ = 2 * game_->NumDistinctActions() - 1;
} else {
SpielFatalError(absl::StrCat("Unrecognized observation type: ", obs_type));
}
}

std::vector<int> PhantomTTTGame::InformationStateTensorShape() const {
// Enc
return {1, kNumCells * kCellStates + longest_sequence_ * bits_per_action_};
return {1, game_->NumDistinctActions() * kCellStates +
longest_sequence_ * bits_per_action_};
}

std::vector<int> PhantomTTTGame::ObservationTensorShape() const {
if (obs_type_ == ObservationType::kRevealNothing) {
return {kNumCells * kCellStates};
return {game_->NumDistinctActions() * kCellStates};
} else if (obs_type_ == ObservationType::kRevealNumTurns) {
return {kNumCells * kCellStates + longest_sequence_};
return {game_->NumDistinctActions() * kCellStates + longest_sequence_};
} else {
SpielFatalError("Unknown observation type");
}
Expand Down
24 changes: 14 additions & 10 deletions open_spiel/games/phantom_ttt/phantom_ttt.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#ifndef OPEN_SPIEL_GAMES_PHANTOM_TTT_H_
#define OPEN_SPIEL_GAMES_PHANTOM_TTT_H_

#include <array>
#include <map>
#include <memory>
#include <string>
Expand Down Expand Up @@ -53,7 +52,8 @@ enum class ObservationType {
// State of an in-play game.
class PhantomTTTState : public State {
public:
PhantomTTTState(std::shared_ptr<const Game> game, ObservationType obs_type);
PhantomTTTState(std::shared_ptr<const Game> game, ObservationType obs_type,
size_t rows, size_t cols);

// Forward to underlying game state
Player CurrentPlayer() const override { return state_.CurrentPlayer(); }
Expand Down Expand Up @@ -89,8 +89,8 @@ class PhantomTTTState : public State {

// TODO(author2): Use the base class history_ instead.
std::vector<std::pair<int, Action>> action_sequence_;
std::array<tic_tac_toe::CellState, tic_tac_toe::kNumCells> x_view_;
std::array<tic_tac_toe::CellState, tic_tac_toe::kNumCells> o_view_;
tic_tac_toe::GridBoard x_view_;
tic_tac_toe::GridBoard o_view_;
};

// Game object.
Expand All @@ -99,7 +99,8 @@ class PhantomTTTGame : public Game {
PhantomTTTGame(const GameParameters& params, GameType game_type);
std::unique_ptr<State> NewInitialState() const override {
return std::unique_ptr<State>(
new PhantomTTTState(shared_from_this(), obs_type_));
new PhantomTTTState(shared_from_this(), obs_type_, game_->Rows(),
game_->Cols()));
}
int NumDistinctActions() const override {
return game_->NumDistinctActions();
Expand All @@ -118,12 +119,14 @@ class PhantomTTTGame : public Game {
// These will depend on the obstype parameter.
std::vector<int> InformationStateTensorShape() const override;
std::vector<int> ObservationTensorShape() const override;
int MaxGameLength() const override { return tic_tac_toe::kNumCells * 2 - 1; }
int MaxGameLength() const override { return game_->MaxGameLength() * 2 - 1; }

ObservationType obs_type() const { return obs_type_; }

private:
protected:
std::shared_ptr<const tic_tac_toe::TicTacToeGame> game_;

private:
ObservationType obs_type_;
int bits_per_action_;
int longest_sequence_;
Expand All @@ -134,8 +137,8 @@ class PhantomTTTGame : public Game {
class ImperfectRecallPTTTState : public PhantomTTTState {
public:
ImperfectRecallPTTTState(std::shared_ptr<const Game> game,
ObservationType obs_type)
: PhantomTTTState(game, obs_type) {}
ObservationType obs_type, size_t rows, size_t cols)
: PhantomTTTState(game, obs_type, rows, cols) {}
std::string InformationStateString(Player player) const override {
SPIEL_CHECK_GE(player, 0);
SPIEL_CHECK_LT(player, num_players_);
Expand All @@ -151,7 +154,8 @@ class ImperfectRecallPTTTGame : public PhantomTTTGame {
explicit ImperfectRecallPTTTGame(const GameParameters& params);
std::unique_ptr<State> NewInitialState() const override {
return std::unique_ptr<State>(
new ImperfectRecallPTTTState(shared_from_this(), obs_type()));
new ImperfectRecallPTTTState(shared_from_this(), obs_type(),
game_->Rows(), game_->Cols()));
}
};

Expand Down
Loading
Loading