Skip to content

Commit b1f8a99

Browse files
committed
README, links, formatting
1 parent 6b70ae2 commit b1f8a99

18 files changed

+127
-95
lines changed

README.md

+80-69
Large diffs are not rendered by default.

chess_transformers/__init__.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,8 @@
1-
__all__ = ["configs", "data", "train", "play", "transformers"]
1+
__all__ = [
2+
"transformers",
3+
"configs",
4+
"data",
5+
"train",
6+
"eval",
7+
"play",
8+
]

chess_transformers/configs/data/LE1222.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
###############################
1212

1313
DATA_FOLDER = os.path.join(
14-
os.environ["CT_DATA_FOLDER"], NAME
14+
os.environ.get("CT_DATA_FOLDER"), NAME
1515
) # folder containing all data files
1616
H5_FILE = NAME + ".h5" # H5 file containing data
1717
MAX_MOVE_SEQUENCE_LENGTH = 10 # expected maximum length of move sequences

chess_transformers/configs/data/LE1222x.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
###############################
1212

1313
DATA_FOLDER = os.path.join(
14-
os.environ["CT_DATA_FOLDER"], NAME
14+
os.environ.get("CT_DATA_FOLDER"), NAME
1515
) # folder containing all data files
1616
H5_FILE = NAME + ".h5" # H5 file containing data
1717
MAX_MOVE_SEQUENCE_LENGTH = 10 # expected maximum length of move sequences
+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__all__ = ["LE1222", "LE1222x"]
1+
__all__ = ["LE1222", "LE1222x"]

chess_transformers/configs/models/CT-E-20.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,18 @@
6565
USE_AMP = True # use automatic mixed precision training?
6666
CRITERION = LabelSmoothedCE # training criterion (loss)
6767
OPTIMIZER = torch.optim.Adam # optimizer
68-
LOGS_DIR = os.path.join(os.environ["CT_LOGS_FOLDER"], NAME) # logs folder
68+
LOGS_DIR = (
69+
os.path.join(os.environ.get("CT_LOGS_FOLDER"), NAME)
70+
if os.environ.get("CT_LOGS_FOLDER")
71+
else None
72+
) # logs folder
6973

7074
###############################
7175
######### Checkpoints #########
7276
###############################
7377

7478
CHECKPOINT_FOLDER = os.path.join(
75-
os.environ["CT_CHECKPOINTS_FOLDER"], NAME
79+
os.environ.get("CT_CHECKPOINTS_FOLDER"), NAME
7680
) # folder containing checkpoints
7781
TRAINING_CHECKPOINT = (
7882
NAME + ".pt"
@@ -91,10 +95,10 @@
9195
########## Stockfish ##########
9296
###############################
9397

94-
STOCKFISH_PATH = os.environ["CT_STOCKFISH_PATH"] # path to Stockfish engine
95-
FAIRY_STOCKFISH_PATH = os.environ[
98+
STOCKFISH_PATH = os.environ.get("CT_STOCKFISH_PATH") # path to Stockfish engine
99+
FAIRY_STOCKFISH_PATH = os.environ.get(
96100
"CT_FAIRY_STOCKFISH_PATH"
97-
] # path to Fairy Stockfish engine
101+
) # path to Fairy Stockfish engine
98102
LICHESS_LEVELS = {
99103
1: {"SKILL": -9, "DEPTH": 5, "TIME_CONSTRAINT": 0.050},
100104
2: {"SKILL": -5, "DEPTH": 5, "TIME_CONSTRAINT": 0.100},
@@ -105,6 +109,8 @@
105109
7: {"SKILL": 16, "DEPTH": 13, "TIME_CONSTRAINT": 0.500},
106110
8: {"SKILL": 20, "DEPTH": 22, "TIME_CONSTRAINT": 1.000},
107111
} # from https://github.com/lichess-org/fishnet/blob/dc4be23256e3e5591578f0901f98f5835a138d73/src/api.rs#L224
108-
EVAL_GAMES_FOLDER = os.path.join(
109-
os.environ["CT_EVAL_GAMES_FOLDER"], NAME
112+
EVAL_GAMES_FOLDER = (
113+
os.path.join(os.environ.get("CT_EVAL_GAMES_FOLDER"), NAME)
114+
if os.environ.get("CT_EVAL_GAMES_FOLDER")
115+
else None
110116
) # folder where games against Stockfish are saved in PGN files

chess_transformers/configs/models/CT-ED-45.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,18 @@
6565
USE_AMP = True # use automatic mixed precision training?
6666
CRITERION = LabelSmoothedCE # training criterion (loss)
6767
OPTIMIZER = torch.optim.Adam # optimizer
68-
LOGS_DIR = os.path.join(os.environ["CT_LOGS_FOLDER"], NAME) # logs folder
68+
LOGS_DIR = (
69+
os.path.join(os.environ.get("CT_LOGS_FOLDER"), NAME)
70+
if os.environ.get("CT_LOGS_FOLDER")
71+
else None
72+
) # logs folder
6973

7074
###############################
7175
######### Checkpoints #########
7276
###############################
7377

7478
CHECKPOINT_FOLDER = os.path.join(
75-
os.environ["CT_CHECKPOINTS_FOLDER"], NAME
79+
os.environ.get("CT_CHECKPOINTS_FOLDER"), NAME
7680
) # folder containing checkpoints
7781
TRAINING_CHECKPOINT = (
7882
NAME + ".pt"
@@ -91,10 +95,10 @@
9195
########## Stockfish ##########
9296
###############################
9397

94-
STOCKFISH_PATH = os.environ["CT_STOCKFISH_PATH"] # path to Stockfish engine
95-
FAIRY_STOCKFISH_PATH = os.environ[
98+
STOCKFISH_PATH = os.environ.get("CT_STOCKFISH_PATH") # path to Stockfish engine
99+
FAIRY_STOCKFISH_PATH = os.environ.get(
96100
"CT_FAIRY_STOCKFISH_PATH"
97-
] # path to Fairy Stockfish engine
101+
) # path to Fairy Stockfish engine
98102
LICHESS_LEVELS = {
99103
1: {"SKILL": -9, "DEPTH": 5, "TIME_CONSTRAINT": 0.050},
100104
2: {"SKILL": -5, "DEPTH": 5, "TIME_CONSTRAINT": 0.100},
@@ -105,6 +109,8 @@
105109
7: {"SKILL": 16, "DEPTH": 13, "TIME_CONSTRAINT": 0.500},
106110
8: {"SKILL": 20, "DEPTH": 22, "TIME_CONSTRAINT": 1.000},
107111
} # from https://github.com/lichess-org/fishnet/blob/dc4be23256e3e5591578f0901f98f5835a138d73/src/api.rs#L224
108-
EVAL_GAMES_FOLDER = os.path.join(
109-
os.environ["CT_EVAL_GAMES_FOLDER"], NAME
112+
EVAL_GAMES_FOLDER = (
113+
os.path.join(os.environ.get("CT_EVAL_GAMES_FOLDER"), NAME)
114+
if os.environ.get("CT_EVAL_GAMES_FOLDER")
115+
else None
110116
) # folder where games against Stockfish are saved in PGN files
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__all__ = ["CT-E-19", "CT-ED-45"]
1+
__all__ = ["CT-E-19", "CT-ED-45"]

chess_transformers/data/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
__all__ = ["prep"]
1+
__all__ = ["prep", "utils"]
2+
3+
from chess_transformers.data.prep import prepare_data

chess_transformers/eval/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__all__ = ["metrics"]

chess_transformers/play/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__all__ = ["moves", "matchups", "clocks", "exceptions", "utils"]
1+
__all__ = ["moves", "play", "clocks", "exceptions", "utils"]
22

33
from chess_transformers.play.play import human_v_model, model_v_engine, model_v_model, warm_up

chess_transformers/play/moves.py

-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import os
2-
import json
31
import chess
42
import torch
5-
from datetime import date
63
from IPython.display import clear_output, Markdown, display
74

85
from chess_transformers.play.utils import (

chess_transformers/play/play.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from IPython.utils import io
33
from IPython.display import display, Markdown, clear_output
44

5-
from chess_transformers.play.utils import get_pgn, in_notebook, print_text, print_board
65
from chess_transformers.play.exceptions import OutOfTime
76
from chess_transformers.play.moves import model_move, engine_move, human_move
7+
from chess_transformers.play.utils import get_pgn, in_notebook, print_text, print_board
88

99

1010
def model_v_engine(

chess_transformers/train/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
__all__ = ["utils", "datasets"]
1+
__all__ = ["utils", "datasets", "average_checkpoints", "train"]
2+
3+
from chess_transformers.train.train import train_model

img/ct_e_20.png

2.96 KB
Loading

img/ct_ed_45.png

8.84 KB
Loading

img/logo.png

-3.65 MB
Loading

img/logo_old.png

-370 KB
Binary file not shown.

0 commit comments

Comments
 (0)