Skip to content

Commit

Permalink
Fen token.
Browse files Browse the repository at this point in the history
  • Loading branch information
akuroiwa committed Nov 20, 2022
1 parent ed414d3 commit d264a8f
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
5 changes: 3 additions & 2 deletions chess_classification/chess_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def train_and_eval(self, train_json, eval_json):

def predict_fen(self, fen):
# Make predictions with the model
prediction, raw_outputs = self.model.predict([fen])
fen_token = fen.replace("/", " ")
# fen_token = ' '.join(list(fen))
prediction, raw_outputs = self.model.predict([fen_token])
# print(prediction, raw_outputs)
return prediction, raw_outputs

5 changes: 4 additions & 1 deletion chess_classification/genPgn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ async def main(path="train-pgn", loop=1, time=20, fen=None) -> None:
os.makedirs(path, exist_ok=True)

for i in range(loop):
board = chess.Board(fen)
if fen:
board = chess.Board(fen)
else:
board = chess.Board()
game = chess.pgn.Game()
game.headers["Event"]
game.setup(board)
Expand Down
10 changes: 8 additions & 2 deletions chess_classification/importPgn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ def readPgn(path):
for file in os.listdir(path):
if file.endswith(".pgn"):
file_path = os.path.join(path, file)
train_df = train_df.append(importPgn(file_path), ignore_index=True)
# train_df = train_df.append(importPgn(file_path), ignore_index=True)
df = pd.DataFrame(importPgn(file_path), columns=["text", "labels"])
train_df = pd.concat([train_df, df], ignore_index=True)
train_df.to_json(os.path.join(path, "fen.json"))

def importPgn(pgn_file):
Expand All @@ -35,7 +37,11 @@ def importPgn(pgn_file):
board = game.board()
for move in game.mainline_moves():
board.push(move)
train_df = train_df.append({"text": board.fen(), "labels": result_label}, ignore_index=True)
fen_token = board.fen().replace("/", " ")
# fen_token = ' '.join(list(board.fen()))
# train_df = train_df.append({"text": fen_token, "labels": result_label}, ignore_index=True)
df = pd.DataFrame([{"text": fen_token, "labels": result_label}], columns=["text", "labels"])
train_df = pd.concat([train_df, df], ignore_index=True)
except:
break
return train_df
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='chess_classification',
version='0.0.3',
version='0.0.4',
url='https://github.com/akuroiwa/chess-classification',
# # PyPI url
# download_url='',
Expand Down

0 comments on commit d264a8f

Please sign in to comment.