Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ocean/chess/binding.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "chess.h"
#define OBS_SIZE 1082
// Before embedding approach
// #define OBS_SIZE 1082
#define OBS_SIZE 167
#define NUM_ATNS 1
#define ACT_SIZES {97}
#define OBS_TENSOR_T ByteTensor
Expand Down
121 changes: 83 additions & 38 deletions ocean/chess/chess.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ enum {
OBS_SIZE = 998
};
*/

/*
Selfplay branch obs layout before embedding approach
enum {
O_BOARD = 0,
O_SIDE = 768,
Expand All @@ -326,6 +329,30 @@ enum {
O_PASS_VALID = 1081,
OBS_SIZE = 1082
};
*/
enum {
O_BOARD = 0,
O_SIDE = 64,
O_CASTLE = 65,
O_EP = 69,
O_RULE50 = 78,
O_REPETITION = 79,
O_SELF_CHECK = 80,
O_OPP_CHECK = 81,
O_PICK_PHASE = 82,
O_SELECTED_PIECE = 83,
O_VALID_FROM_COUNT = 84,
O_VALID_FROM = 85,
O_VALID_TO_COUNT = 101,
O_VALID_TO = 102,
O_VALID_PROMOS = 134,
O_PASS_VALID = 166,
OBS_SIZE = 167
};

#define CHESS_MAX_VALID_FROM 16
#define CHESS_MAX_VALID_TO 32
#define CHESS_NULL_SQ 64

#define PASS_ACTION 96
#define NUM_ACTIONS 97
Expand Down Expand Up @@ -1552,32 +1579,28 @@ void populate_observations(Chess* env) {
int flip = player * 56;


// our pieces
// Compact ego-centric board: one byte per square. 0 = empty,
// own P..K = 1..6, enemy P..K = 7..12.
for (int pt = PAWN; pt <= KING; pt++) {
Bitboard bb = pieces_cp(pos, player, pt);
int plane = pt - 1; // 0-5
while (bb) {
Square sq = pop_lsb(&bb);
board_planes[plane * 64 + (sq ^ flip)] = 1;
board_planes[sq ^ flip] = (uint8_t)pt;
}
}

// Their pieces (planes 6-11)
for (int pt = PAWN; pt <= KING; pt++) {
Bitboard bb = pieces_cp(pos, them, pt);
int plane = 6 + (pt - 1); // 6-11
while (bb) {
Square sq = pop_lsb(&bb);
board_planes[plane * 64 + (sq ^ flip)] = 1;
board_planes[sq ^ flip] = (uint8_t)(6 + pt);
}
}

ChessColor side_to_move = pos->sideToMove;

uint8_t* side_onehot = player_obs + O_SIDE;
side_onehot[(pos->sideToMove == us) ? 0 : 1] = 1;
player_obs[O_SIDE] = (pos->sideToMove == us) ? 1 : 0;

uint8_t* castle_onehot = player_obs + O_CASTLE;
uint8_t castle_rights = pos->castlingRights;
if (player == 1) {
uint8_t flipped = 0;
Expand All @@ -1587,55 +1610,75 @@ void populate_observations(Chess* env) {
if (castle_rights & WHITE_OOO) flipped |= BLACK_OOO;
castle_rights = flipped;
}
castle_onehot[castle_rights] = 1;
player_obs[O_CASTLE + 0] = (castle_rights & WHITE_OO) ? 1 : 0;
player_obs[O_CASTLE + 1] = (castle_rights & WHITE_OOO) ? 1 : 0;
player_obs[O_CASTLE + 2] = (castle_rights & BLACK_OO) ? 1 : 0;
player_obs[O_CASTLE + 3] = (castle_rights & BLACK_OOO) ? 1 : 0;

uint8_t* ep_onehot = player_obs + O_EP;
if (pos->epSquare < 64) {
int ep_sq = (player == 1) ? (pos->epSquare ^ 56) : pos->epSquare;
ep_onehot[ep_sq] = 1;
int ep_sq = (player == 1) ? (pos->epSquare ^ 56) : pos->epSquare;
player_obs[O_EP + file_of((Square)ep_sq)] = 1;
} else {
ep_onehot[64] = 1;
player_obs[O_EP + 8] = 1;
}

uint8_t* valid_pieces = player_obs + O_VALID_PIECES;
uint8_t* valid_dests = player_obs + O_VALID_DESTS;

uint8_t* valid_from_indices = player_obs + O_VALID_FROM;
uint8_t* valid_to_indices = player_obs + O_VALID_TO;
for (int k = 0; k < CHESS_MAX_VALID_FROM; k++) valid_from_indices[k] = CHESS_NULL_SQ;
for (int k = 0; k < CHESS_MAX_VALID_TO; k++) valid_to_indices[k] = CHESS_NULL_SQ;
int valid_from_count = 0;
int valid_to_count = 0;

int player_idx = (int)us;

if (side_to_move == us) {
if (env->pick_phase[player_idx] == 0) {
if (env->legal_moves.count > 0) {
for (int i = 0; i < env->legal_moves.count; i++) {
Square from = from_sq(env->legal_moves.moves[i].move);
int view_from = (player == 1) ? (from ^ 56) : from;
valid_pieces[view_from] = 1;
Bitboard added = 0;
for (int i = 0; i < env->legal_moves.count; i++) {
Square from = from_sq(env->legal_moves.moves[i].move);
int view_from = (player == 1) ? (from ^ 56) : from;
Bitboard bit = sq_bb((Square)view_from);
if (!(added & bit)) {
added |= bit;
if (valid_from_count < CHESS_MAX_VALID_FROM) {
valid_from_indices[valid_from_count++] = (uint8_t)view_from;
}
if (fill_mask) my_mask[view_from] = 1;
}
}
} else {
if (env->valid_destinations[player_idx].count > 0) {
for (int i = 0; i < env->valid_destinations[player_idx].count; i++) {
Square to = to_sq(env->valid_destinations[player_idx].moves[i].move);
int view_to = (player == 1) ? (to ^ 56) : to;
valid_dests[view_to] = 1;
Bitboard added = 0;
for (int i = 0; i < env->valid_destinations[player_idx].count; i++) {
Square to = to_sq(env->valid_destinations[player_idx].moves[i].move);
int view_to = (player == 1) ? (to ^ 56) : to;
Bitboard bit = sq_bb((Square)view_to);
if (!(added & bit)) {
added |= bit;
if (valid_to_count < CHESS_MAX_VALID_TO) {
valid_to_indices[valid_to_count++] = (uint8_t)view_to;
}
if (fill_mask) my_mask[view_to] = 1;
}
}
}
}
player_obs[O_VALID_FROM_COUNT] = (uint8_t)valid_from_count;
player_obs[O_VALID_TO_COUNT] = (uint8_t)valid_to_count;
player_obs[O_PASS_VALID] = (side_to_move != us) ? 255 : 0;
if (fill_mask && side_to_move != us) {
my_mask[PASS_ACTION] = 1;
}

uint8_t* phase_onehot = player_obs + O_PICK_PHASE;
phase_onehot[env->pick_phase[player_idx]] = 1;

uint8_t* selected_piece_plane = player_obs + O_SELECTED_PIECE;

player_obs[O_PICK_PHASE] = env->pick_phase[player_idx] ? 1 : 0;

uint8_t selected_byte = (uint8_t)CHESS_NULL_SQ;
if (env->pick_phase[player_idx] == 1 && env->selected_square[player_idx] != SQ_NONE) {
int view_selected = (player == 1) ? (env->selected_square[player_idx] ^ 56) : env->selected_square[player_idx];
selected_piece_plane[view_selected] = 1;
int view_selected = (player == 1)
? (env->selected_square[player_idx] ^ 56)
: env->selected_square[player_idx];
selected_byte = (uint8_t)view_selected;
}
player_obs[O_SELECTED_PIECE] = selected_byte;

uint8_t* valid_promos = player_obs + O_VALID_PROMOS;

Expand All @@ -1654,9 +1697,11 @@ void populate_observations(Chess* env) {
player_obs[O_SELF_CHECK] = is_check(pos, us) ? 255 : 0;
player_obs[O_OPP_CHECK] = is_check(pos, them) ? 255 : 0;

player_obs[O_RULE50] = (uint8_t)((pos->rule50 * 255) / 100);
int rule50 = pos->rule50;
if (rule50 > 100) rule50 = 100;
player_obs[O_RULE50] = (uint8_t)((rule50 * 255) / 100);

uint8_t rep_val = 255;
uint8_t rep_val = 0;
if (env->undo_stack_ptr >= 4) {
uint8_t plies = env->undo_stack[env->undo_stack_ptr - 1].pliesFromNull;
if (plies >= 4) {
Expand All @@ -1668,7 +1713,7 @@ void populate_observations(Chess* env) {
}
}
if (repetitions >= 2) {
rep_val = 0;
rep_val = 255;
} else if (repetitions == 1) {
rep_val = 128;
}
Expand Down
Loading
Loading