-
Notifications
You must be signed in to change notification settings - Fork 406
[Feature,Refactor] Chess improvements: fen, pgn, pixels, san #2702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -4,26 +4,51 @@ | |||||
# LICENSE file in the root directory of this source tree. | ||||||
from __future__ import annotations | ||||||
|
||||||
import importlib.util | ||||||
import io | ||||||
from typing import Dict, Optional | ||||||
|
||||||
import torch | ||||||
from PIL import Image | ||||||
from tensordict import TensorDict, TensorDictBase | ||||||
from torchrl.data import Categorical, Composite, NonTensor, Unbounded | ||||||
|
||||||
from torchrl.envs import EnvBase | ||||||
from torchrl.envs.common import _EnvPostInit | ||||||
|
||||||
from torchrl.envs.utils import _classproperty | ||||||
|
||||||
|
||||||
class ChessEnv(EnvBase): | ||||||
class _HashMeta(_EnvPostInit): | ||||||
def __call__(cls, *args, **kwargs): | ||||||
instance = super().__call__(*args, **kwargs) | ||||||
if kwargs.get("include_hash"): | ||||||
from torchrl.envs import Hash | ||||||
|
||||||
in_keys = [] | ||||||
out_keys = [] | ||||||
if instance.include_san: | ||||||
in_keys.append("san") | ||||||
out_keys.append("san_hash") | ||||||
if instance.include_fen: | ||||||
in_keys.append("fen") | ||||||
out_keys.append("fen_hash") | ||||||
if instance.include_pgn: | ||||||
in_keys.append("pgn") | ||||||
out_keys.append("pgn_hash") | ||||||
return instance.append_transform(Hash(in_keys, out_keys)) | ||||||
return instance | ||||||
|
||||||
|
||||||
class ChessEnv(EnvBase, metaclass=_HashMeta): | ||||||
"""A chess environment that follows the TorchRL API. | ||||||
|
||||||
Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__. | ||||||
|
||||||
Args: | ||||||
stateful (bool): Whether to keep track of the internal state of the board. | ||||||
If False, the state will be stored in the observation and passed back | ||||||
to the environment on each call. Default: ``False``. | ||||||
to the environment on each call. Default: ``True``. | ||||||
|
||||||
.. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape. | ||||||
Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves, | ||||||
|
@@ -90,28 +115,76 @@ class ChessEnv(EnvBase): | |||||
""" | ||||||
|
||||||
_hash_table: Dict[int, str] = {} | ||||||
_PNG_RESTART = """[Event "?"] | ||||||
[Site "?"] | ||||||
[Date "????.??.??"] | ||||||
[Round "?"] | ||||||
[White "?"] | ||||||
[Black "?"] | ||||||
[Result "*"] | ||||||
|
||||||
*""" | ||||||
|
||||||
@_classproperty | ||||||
def lib(cls): | ||||||
try: | ||||||
import chess | ||||||
import chess.pgn | ||||||
except ImportError: | ||||||
raise ImportError( | ||||||
"The `chess` library could not be found. Make sure you installed it through `pip install chess`." | ||||||
) | ||||||
return chess | ||||||
|
||||||
def __init__(self, stateful: bool = False): | ||||||
def __init__( | ||||||
self, | ||||||
*, | ||||||
stateful: bool = True, | ||||||
include_san: bool = False, | ||||||
include_fen: bool = False, | ||||||
include_pgn: bool = False, | ||||||
include_hash: bool = False, | ||||||
pixels: bool = False, | ||||||
): | ||||||
chess = self.lib | ||||||
super().__init__() | ||||||
self.full_observation_spec = Composite( | ||||||
hashing=Unbounded(shape=(), dtype=torch.int64), | ||||||
fen=NonTensor(shape=()), | ||||||
turn=Categorical(n=2, dtype=torch.bool, shape=()), | ||||||
) | ||||||
self.include_san = include_san | ||||||
self.include_fen = include_fen | ||||||
self.include_pgn = include_pgn | ||||||
if include_san: | ||||||
self.full_observation_spec["san"] = NonTensor(shape=(), example_data="Nc6") | ||||||
if include_pgn: | ||||||
self.full_observation_spec["pgn"] = NonTensor( | ||||||
shape=(), example_data=self._PNG_RESTART | ||||||
) | ||||||
if include_fen: | ||||||
self.full_observation_spec["fen"] = NonTensor(shape=(), example_data="any") | ||||||
if not stateful and not (include_pgn or include_fen): | ||||||
raise RuntimeError( | ||||||
"At least one state representation (pgn or fen) must be enabled when stateful " | ||||||
f"is {stateful}." | ||||||
) | ||||||
|
||||||
self.stateful = stateful | ||||||
|
||||||
if not self.stateful: | ||||||
self.full_state_spec = self.full_observation_spec.clone() | ||||||
|
||||||
self.pixels = pixels | ||||||
if pixels: | ||||||
if importlib.util.find_spec("cairosvg") is None: | ||||||
raise ImportError( | ||||||
"Please install cairosvg to use this environment with pixel rendering." | ||||||
) | ||||||
if importlib.util.find_spec("torchvision") is None: | ||||||
raise ImportError( | ||||||
"Please install torchvision to use this environment with pixel rendering." | ||||||
) | ||||||
self.full_observation_spec["pixels"] = Unbounded(shape=()) | ||||||
|
||||||
self.full_action_spec = Composite( | ||||||
action=Categorical(n=-1, shape=(), dtype=torch.int64) | ||||||
) | ||||||
|
@@ -132,41 +205,126 @@ def _is_done(self, board): | |||||
|
||||||
def _reset(self, tensordict=None): | ||||||
fen = None | ||||||
pgn = None | ||||||
if tensordict is not None: | ||||||
fen = self._get_fen(tensordict).data | ||||||
dest = tensordict.empty() | ||||||
if self.include_fen: | ||||||
fen = self._get_fen(tensordict).data | ||||||
dest = tensordict.empty() | ||||||
if self.include_pgn: | ||||||
fen = self._get_pgn(tensordict).data | ||||||
|
fen = self._get_pgn(tensordict).data | |
pgn = self._get_pgn(tensordict).data |
Also _get_pgn
doesn't exist, but I figure you realize that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we could have another option that has a more minimal representation as well, if we want to train without vision. For instance, OpenSpiel's chess env has a representation where each square of the board is given a size 20 array of ones and zeros, which is somehow used to represent which piece is in that square and which color it is
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I was using this because it's builtin - but a more minimal representation could also be cool!
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this function can be removed, in favor of a direct TensorDict.get
call?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking that one day we might let people choose the fen / pgn key for themselves but it's true that as of now it's not incredibly useful
Uh oh!
There was an error while loading. Please reload this page.