From 1046ec2d22ca4b72f38319eb186501b7a50412e9 Mon Sep 17 00:00:00 2001 From: bullardji <102577354+bullardji@users.noreply.github.com> Date: Sun, 3 May 2026 14:26:29 -0400 Subject: [PATCH 1/2] New embedding architecture + added #include to pufferlib.cu to fix --profile --- ocean/chess/binding.c | 4 +- ocean/chess/chess.h | 121 +++++++++++------ src/ocean.cu | 301 ++++++++++++++++++++++++++++++++++++++++++ src/pufferlib.cu | 1 + 4 files changed, 388 insertions(+), 39 deletions(-) diff --git a/ocean/chess/binding.c b/ocean/chess/binding.c index 83a004cca2..99f5767189 100644 --- a/ocean/chess/binding.c +++ b/ocean/chess/binding.c @@ -1,5 +1,7 @@ #include "chess.h" -#define OBS_SIZE 1082 +// Before embedding approach +// #define OBS_SIZE 1082 +#define OBS_SIZE 167 #define NUM_ATNS 1 #define ACT_SIZES {97} #define OBS_TENSOR_T ByteTensor diff --git a/ocean/chess/chess.h b/ocean/chess/chess.h index 81d74d000e..9dcefdc9d9 100644 --- a/ocean/chess/chess.h +++ b/ocean/chess/chess.h @@ -309,6 +309,9 @@ enum { OBS_SIZE = 998 }; */ + +/* +Selfplay branch obs layout before embedding approach enum { O_BOARD = 0, O_SIDE = 768, @@ -326,6 +329,30 @@ enum { O_PASS_VALID = 1081, OBS_SIZE = 1082 }; +*/ +enum { + O_BOARD = 0, + O_SIDE = 64, + O_CASTLE = 65, + O_EP = 69, + O_RULE50 = 78, + O_REPETITION = 79, + O_SELF_CHECK = 80, + O_OPP_CHECK = 81, + O_PICK_PHASE = 82, + O_SELECTED_PIECE = 83, + O_VALID_FROM_COUNT = 84, + O_VALID_FROM = 85, + O_VALID_TO_COUNT = 101, + O_VALID_TO = 102, + O_VALID_PROMOS = 134, + O_PASS_VALID = 166, + OBS_SIZE = 167 +}; + +#define CHESS_MAX_VALID_FROM 16 +#define CHESS_MAX_VALID_TO 32 +#define CHESS_NULL_SQ 64 #define PASS_ACTION 96 #define NUM_ACTIONS 97 @@ -1552,32 +1579,28 @@ void populate_observations(Chess* env) { int flip = player * 56; - // our pieces + // Compact ego-centric board: one byte per square. 0 = empty, + // own P..K = 1..6, enemy P..K = 7..12. for (int pt = PAWN; pt <= KING; pt++) { Bitboard bb = pieces_cp(pos, player, pt); - int plane = pt - 1; // 0-5 while (bb) { Square sq = pop_lsb(&bb); - board_planes[plane * 64 + (sq ^ flip)] = 1; + board_planes[sq ^ flip] = (uint8_t)pt; } } - // Their pieces (planes 6-11) for (int pt = PAWN; pt <= KING; pt++) { Bitboard bb = pieces_cp(pos, them, pt); - int plane = 6 + (pt - 1); // 6-11 while (bb) { Square sq = pop_lsb(&bb); - board_planes[plane * 64 + (sq ^ flip)] = 1; + board_planes[sq ^ flip] = (uint8_t)(6 + pt); } } ChessColor side_to_move = pos->sideToMove; - uint8_t* side_onehot = player_obs + O_SIDE; - side_onehot[(pos->sideToMove == us) ? 0 : 1] = 1; + player_obs[O_SIDE] = (pos->sideToMove == us) ? 1 : 0; - uint8_t* castle_onehot = player_obs + O_CASTLE; uint8_t castle_rights = pos->castlingRights; if (player == 1) { uint8_t flipped = 0; @@ -1587,55 +1610,75 @@ void populate_observations(Chess* env) { if (castle_rights & WHITE_OOO) flipped |= BLACK_OOO; castle_rights = flipped; } - castle_onehot[castle_rights] = 1; + player_obs[O_CASTLE + 0] = (castle_rights & WHITE_OO) ? 1 : 0; + player_obs[O_CASTLE + 1] = (castle_rights & WHITE_OOO) ? 1 : 0; + player_obs[O_CASTLE + 2] = (castle_rights & BLACK_OO) ? 1 : 0; + player_obs[O_CASTLE + 3] = (castle_rights & BLACK_OOO) ? 1 : 0; - uint8_t* ep_onehot = player_obs + O_EP; if (pos->epSquare < 64) { - int ep_sq = (player == 1) ? (pos->epSquare ^ 56) : pos->epSquare; - ep_onehot[ep_sq] = 1; + int ep_sq = (player == 1) ? (pos->epSquare ^ 56) : pos->epSquare; + player_obs[O_EP + file_of((Square)ep_sq)] = 1; } else { - ep_onehot[64] = 1; + player_obs[O_EP + 8] = 1; } - uint8_t* valid_pieces = player_obs + O_VALID_PIECES; - uint8_t* valid_dests = player_obs + O_VALID_DESTS; - + uint8_t* valid_from_indices = player_obs + O_VALID_FROM; + uint8_t* valid_to_indices = player_obs + O_VALID_TO; + for (int k = 0; k < CHESS_MAX_VALID_FROM; k++) valid_from_indices[k] = CHESS_NULL_SQ; + for (int k = 0; k < CHESS_MAX_VALID_TO; k++) valid_to_indices[k] = CHESS_NULL_SQ; + int valid_from_count = 0; + int valid_to_count = 0; + int player_idx = (int)us; - + if (side_to_move == us) { if (env->pick_phase[player_idx] == 0) { - if (env->legal_moves.count > 0) { - for (int i = 0; i < env->legal_moves.count; i++) { - Square from = from_sq(env->legal_moves.moves[i].move); - int view_from = (player == 1) ? (from ^ 56) : from; - valid_pieces[view_from] = 1; + Bitboard added = 0; + for (int i = 0; i < env->legal_moves.count; i++) { + Square from = from_sq(env->legal_moves.moves[i].move); + int view_from = (player == 1) ? (from ^ 56) : from; + Bitboard bit = sq_bb((Square)view_from); + if (!(added & bit)) { + added |= bit; + if (valid_from_count < CHESS_MAX_VALID_FROM) { + valid_from_indices[valid_from_count++] = (uint8_t)view_from; + } if (fill_mask) my_mask[view_from] = 1; } } } else { - if (env->valid_destinations[player_idx].count > 0) { - for (int i = 0; i < env->valid_destinations[player_idx].count; i++) { - Square to = to_sq(env->valid_destinations[player_idx].moves[i].move); - int view_to = (player == 1) ? (to ^ 56) : to; - valid_dests[view_to] = 1; + Bitboard added = 0; + for (int i = 0; i < env->valid_destinations[player_idx].count; i++) { + Square to = to_sq(env->valid_destinations[player_idx].moves[i].move); + int view_to = (player == 1) ? (to ^ 56) : to; + Bitboard bit = sq_bb((Square)view_to); + if (!(added & bit)) { + added |= bit; + if (valid_to_count < CHESS_MAX_VALID_TO) { + valid_to_indices[valid_to_count++] = (uint8_t)view_to; + } if (fill_mask) my_mask[view_to] = 1; } } } } + player_obs[O_VALID_FROM_COUNT] = (uint8_t)valid_from_count; + player_obs[O_VALID_TO_COUNT] = (uint8_t)valid_to_count; player_obs[O_PASS_VALID] = (side_to_move != us) ? 255 : 0; if (fill_mask && side_to_move != us) { my_mask[PASS_ACTION] = 1; } - - uint8_t* phase_onehot = player_obs + O_PICK_PHASE; - phase_onehot[env->pick_phase[player_idx]] = 1; - - uint8_t* selected_piece_plane = player_obs + O_SELECTED_PIECE; + + player_obs[O_PICK_PHASE] = env->pick_phase[player_idx] ? 1 : 0; + + uint8_t selected_byte = (uint8_t)CHESS_NULL_SQ; if (env->pick_phase[player_idx] == 1 && env->selected_square[player_idx] != SQ_NONE) { - int view_selected = (player == 1) ? (env->selected_square[player_idx] ^ 56) : env->selected_square[player_idx]; - selected_piece_plane[view_selected] = 1; + int view_selected = (player == 1) + ? (env->selected_square[player_idx] ^ 56) + : env->selected_square[player_idx]; + selected_byte = (uint8_t)view_selected; } + player_obs[O_SELECTED_PIECE] = selected_byte; uint8_t* valid_promos = player_obs + O_VALID_PROMOS; @@ -1654,9 +1697,11 @@ void populate_observations(Chess* env) { player_obs[O_SELF_CHECK] = is_check(pos, us) ? 255 : 0; player_obs[O_OPP_CHECK] = is_check(pos, them) ? 255 : 0; - player_obs[O_RULE50] = (uint8_t)((pos->rule50 * 255) / 100); + int rule50 = pos->rule50; + if (rule50 > 100) rule50 = 100; + player_obs[O_RULE50] = (uint8_t)((rule50 * 255) / 100); - uint8_t rep_val = 255; + uint8_t rep_val = 0; if (env->undo_stack_ptr >= 4) { uint8_t plies = env->undo_stack[env->undo_stack_ptr - 1].pliesFromNull; if (plies >= 4) { @@ -1668,7 +1713,7 @@ void populate_observations(Chess* env) { } } if (repetitions >= 2) { - rep_val = 0; + rep_val = 255; } else if (repetitions == 1) { rep_val = 128; } diff --git a/src/ocean.cu b/src/ocean.cu index baaa9b7be6..3175a3b297 100644 --- a/src/ocean.cu +++ b/src/ocean.cu @@ -570,6 +570,293 @@ static void* nmmo3_encoder_create_weights(void* self) { static void nmmo3_encoder_free_weights(void* weights) { free(weights); } static void nmmo3_encoder_free_activations(void* activations) { free(activations); } +// ---- Chess constants ---- +// Chess uses learnable piece-square and action-context embeddings. The encoder +// reads compact obs bytes, sums the matching rows from board_w/move_context_w +// into hidden space, adds a learned projection of rule metadata, then applies ReLU. + +static constexpr int CH_BOARD = 0; +static constexpr int CH_SIDE = 64; +static constexpr int CH_CASTLE = 65; +static constexpr int CH_EP = 69; +static constexpr int CH_RULE50 = 78; +static constexpr int CH_REPETITION = 79; +static constexpr int CH_SELF_CHECK = 80; +static constexpr int CH_OPP_CHECK = 81; +static constexpr int CH_PICK_PHASE = 82; +static constexpr int CH_SELECTED = 83; +static constexpr int CH_VALID_FROM_COUNT = 84; +static constexpr int CH_VALID_FROM = 85; +static constexpr int CH_VALID_TO_COUNT = 101; +static constexpr int CH_VALID_TO = 102; +static constexpr int CH_VALID_PROMOS = 134; +static constexpr int CH_PASS_VALID = 166; +static constexpr int CH_OBS_SIZE = 167; +static constexpr int CH_BOARD_SQUARES = 64; +static constexpr int CH_PIECE_TYPES = 12; +static constexpr int CH_BOARD_FEATURES = CH_BOARD_SQUARES * CH_PIECE_TYPES; +static constexpr int CH_MAX_VALID_FROM = 16; +static constexpr int CH_MAX_VALID_TO = 32; +static constexpr int CH_NULL_SQ = 64; +static constexpr int CH_MOVE_CONTEXT_SELECTED = 0; +static constexpr int CH_MOVE_CONTEXT_VALID_FROM = 64; +static constexpr int CH_MOVE_CONTEXT_VALID_TO = 128; +static constexpr int CH_MOVE_CONTEXT_PHASE = 192; +static constexpr int CH_MOVE_CONTEXT_FEATURES = 194; +static constexpr int CH_META = 51; // side/castle/ep/clocks/checks/promos/pass + +// ---- Chess kernels ---- + +__global__ void chess_meta_kernel( + precision_t* __restrict__ meta, + const precision_t* __restrict__ obs, + int B, int obs_size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * CH_META) return; + int b = idx / CH_META; + int m = idx % CH_META; + int src = CH_PASS_VALID; + if (m == 0) { + src = CH_SIDE; + } else if (m < 5) { + src = CH_CASTLE + (m - 1); + } else if (m < 14) { + src = CH_EP + (m - 5); + } else if (m == 14) { + src = CH_RULE50; + } else if (m == 15) { + src = CH_REPETITION; + } else if (m == 16) { + src = CH_SELF_CHECK; + } else if (m == 17) { + src = CH_OPP_CHECK; + } else if (m < 50) { + src = CH_VALID_PROMOS + (m - 18); + } + float x = to_float(obs[b * obs_size + src]); + if (x > 1.0f) x *= (1.0f / 255.0f); + meta[idx] = from_float(x); +} + +__global__ void chess_embed_forward_kernel( + precision_t* __restrict__ out, + const precision_t* __restrict__ obs, + const precision_t* __restrict__ board_w, + const precision_t* __restrict__ move_context_w, + const precision_t* __restrict__ bias, + int B, int hidden, int obs_size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * hidden) return; + int b = idx / hidden; + int h = idx % hidden; + + const precision_t* obs_row = obs + b * obs_size; + float acc = to_float(out[idx]) + to_float(bias[h]); + + #pragma unroll 8 + for (int sq = 0; sq < CH_BOARD_SQUARES; sq++) { + int piece = (int)to_float(obs_row[CH_BOARD + sq]); + if (piece > 0) { + int feat = (piece - 1) * CH_BOARD_SQUARES + sq; + acc += to_float(board_w[feat * hidden + h]); + } + } + + int phase = (int)to_float(obs_row[CH_PICK_PHASE]) > 0 ? 1 : 0; + acc += to_float(move_context_w[(CH_MOVE_CONTEXT_PHASE + phase) * hidden + h]); + + int selected = (int)to_float(obs_row[CH_SELECTED]); + if (selected < CH_BOARD_SQUARES) { + acc += to_float(move_context_w[(CH_MOVE_CONTEXT_SELECTED + selected) * hidden + h]); + } + + int from_count = (int)to_float(obs_row[CH_VALID_FROM_COUNT]); + for (int k = 0; k < from_count; k++) { + int sq = (int)to_float(obs_row[CH_VALID_FROM + k]); + acc += to_float(move_context_w[(CH_MOVE_CONTEXT_VALID_FROM + sq) * hidden + h]); + } + + int to_count = (int)to_float(obs_row[CH_VALID_TO_COUNT]); + for (int k = 0; k < to_count; k++) { + int sq = (int)to_float(obs_row[CH_VALID_TO + k]); + acc += to_float(move_context_w[(CH_MOVE_CONTEXT_VALID_TO + sq) * hidden + h]); + } + + out[idx] = from_float(fmaxf(0.0f, acc)); +} + +__global__ void chess_embed_backward_kernel( + float* __restrict__ board_wgrad_f, + float* __restrict__ move_context_wgrad_f, + const precision_t* __restrict__ grad, + const precision_t* __restrict__ obs, + int B, int hidden, int obs_size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * hidden) return; + int b = idx / hidden; + int h = idx % hidden; + float g = to_float(grad[idx]); + if (g == 0.0f) return; + + const precision_t* obs_row = obs + b * obs_size; + + #pragma unroll 8 + for (int sq = 0; sq < CH_BOARD_SQUARES; sq++) { + int piece = (int)to_float(obs_row[CH_BOARD + sq]); + if (piece > 0) { + int feat = (piece - 1) * CH_BOARD_SQUARES + sq; + atomicAdd(&board_wgrad_f[feat * hidden + h], g); + } + } + + int phase = (int)to_float(obs_row[CH_PICK_PHASE]) > 0 ? 1 : 0; + atomicAdd(&move_context_wgrad_f[(CH_MOVE_CONTEXT_PHASE + phase) * hidden + h], g); + + int selected = (int)to_float(obs_row[CH_SELECTED]); + if (selected < CH_BOARD_SQUARES) { + atomicAdd(&move_context_wgrad_f[(CH_MOVE_CONTEXT_SELECTED + selected) * hidden + h], g); + } + + int from_count = (int)to_float(obs_row[CH_VALID_FROM_COUNT]); + for (int k = 0; k < from_count; k++) { + int sq = (int)to_float(obs_row[CH_VALID_FROM + k]); + atomicAdd(&move_context_wgrad_f[(CH_MOVE_CONTEXT_VALID_FROM + sq) * hidden + h], g); + } + + int to_count = (int)to_float(obs_row[CH_VALID_TO_COUNT]); + for (int k = 0; k < to_count; k++) { + int sq = (int)to_float(obs_row[CH_VALID_TO + k]); + atomicAdd(&move_context_wgrad_f[(CH_MOVE_CONTEXT_VALID_TO + sq) * hidden + h], g); + } +} + +// ---- Chess encoder structs ---- + +struct ChessEmbedEncoderWeights { + PrecisionTensor board_w, move_context_w, meta_w, bias; + int obs_size, hidden; +}; + +struct ChessEmbedEncoderActivations { + PrecisionTensor meta, out, saved_obs; + PrecisionTensor board_wgrad, move_context_wgrad, meta_wgrad, bgrad; + FloatTensor board_wgrad_f, move_context_wgrad_f; +}; + +// ---- Chess encoder interface ---- + +static ChessEmbedEncoderWeights* chess_embed_encoder_create(int obs_size, int hidden) { + ChessEmbedEncoderWeights* ew = (ChessEmbedEncoderWeights*)calloc(1, sizeof(ChessEmbedEncoderWeights)); + ew->obs_size = obs_size; + ew->hidden = hidden; + return ew; +} + +static PrecisionTensor chess_embed_encoder_forward(void* w, void* activations, + PrecisionTensor input, cudaStream_t stream) { + ChessEmbedEncoderWeights* ew = (ChessEmbedEncoderWeights*)w; + ChessEmbedEncoderActivations* a = (ChessEmbedEncoderActivations*)activations; + int B = input.shape[0]; + + if (a->saved_obs.data) puf_copy(&a->saved_obs, &input, stream); + chess_meta_kernel<<>>( + a->meta.data, input.data, B, ew->obs_size); + puf_mm(&a->meta, &ew->meta_w, &a->out, stream); + chess_embed_forward_kernel<<hidden), BLOCK_SIZE, 0, stream>>>( + a->out.data, input.data, ew->board_w.data, ew->move_context_w.data, ew->bias.data, + B, ew->hidden, ew->obs_size); + return a->out; +} + +static void chess_embed_encoder_backward(void* w, void* activations, + PrecisionTensor grad, cudaStream_t stream) { + ChessEmbedEncoderWeights* ew = (ChessEmbedEncoderWeights*)w; + ChessEmbedEncoderActivations* a = (ChessEmbedEncoderActivations*)activations; + int B = grad.shape[0]; + int H = ew->hidden; + + n3_relu_backward_kernel<<>>( + grad.data, a->out.data, B * H); + bias_grad_kernel<<>>(a->bgrad.data, grad.data, B, H); + puf_mm_tn(&grad, &a->meta, &a->meta_wgrad, stream); + + int board_n = CH_BOARD_FEATURES * H; + int move_context_n = CH_MOVE_CONTEXT_FEATURES * H; + cudaMemsetAsync(a->board_wgrad_f.data, 0, board_n * sizeof(float), stream); + cudaMemsetAsync(a->move_context_wgrad_f.data, 0, move_context_n * sizeof(float), stream); + chess_embed_backward_kernel<<>>( + a->board_wgrad_f.data, a->move_context_wgrad_f.data, grad.data, a->saved_obs.data, + B, H, ew->obs_size); + n3_float_to_precision_kernel<<>>( + a->board_wgrad.data, a->board_wgrad_f.data, board_n); + n3_float_to_precision_kernel<<>>( + a->move_context_wgrad.data, a->move_context_wgrad_f.data, move_context_n); +} + +static void chess_embed_encoder_init_weights(void* w, uint64_t* seed, cudaStream_t stream) { + ChessEmbedEncoderWeights* ew = (ChessEmbedEncoderWeights*)w; + puf_normal_init(&ew->board_w, 0.02f, (*seed)++, stream); + puf_normal_init(&ew->move_context_w, 0.02f, (*seed)++, stream); + PrecisionTensor wt = {.data = ew->meta_w.data, .shape = {ew->hidden, CH_META}}; + puf_kaiming_init(&wt, 1.0f, (*seed)++, stream); + cudaMemsetAsync(ew->bias.data, 0, numel(ew->bias.shape) * sizeof(precision_t), stream); +} + +static void chess_embed_encoder_reg_params(void* w, Allocator* alloc) { + ChessEmbedEncoderWeights* ew = (ChessEmbedEncoderWeights*)w; + ew->board_w = {.shape = {CH_BOARD_FEATURES, ew->hidden}}; + ew->move_context_w = {.shape = {CH_MOVE_CONTEXT_FEATURES, ew->hidden}}; + ew->meta_w = {.shape = {ew->hidden, CH_META}}; + ew->bias = {.shape = {ew->hidden}}; + alloc_register(alloc, &ew->board_w); + alloc_register(alloc, &ew->move_context_w); + alloc_register(alloc, &ew->meta_w); + alloc_register(alloc, &ew->bias); +} + +static void chess_embed_encoder_reg_train(void* w, void* activations, + Allocator* acts, Allocator* grads, int B_TT) { + ChessEmbedEncoderWeights* ew = (ChessEmbedEncoderWeights*)w; + ChessEmbedEncoderActivations* a = (ChessEmbedEncoderActivations*)activations; + *a = {}; + a->meta = {.shape = {B_TT, CH_META}}; + a->out = {.shape = {B_TT, ew->hidden}}; + a->saved_obs = {.shape = {B_TT, ew->obs_size}}; + alloc_register(acts, &a->meta); + alloc_register(acts, &a->out); + alloc_register(acts, &a->saved_obs); + + a->board_wgrad = {.shape = {CH_BOARD_FEATURES, ew->hidden}}; + a->move_context_wgrad = {.shape = {CH_MOVE_CONTEXT_FEATURES, ew->hidden}}; + a->meta_wgrad = {.shape = {ew->hidden, CH_META}}; + a->bgrad = {.shape = {ew->hidden}}; + a->board_wgrad_f = {.shape = {CH_BOARD_FEATURES, ew->hidden}}; + a->move_context_wgrad_f = {.shape = {CH_MOVE_CONTEXT_FEATURES, ew->hidden}}; + alloc_register(grads, &a->board_wgrad); + alloc_register(grads, &a->move_context_wgrad); + alloc_register(grads, &a->meta_wgrad); + alloc_register(grads, &a->bgrad); + alloc_register(acts, &a->board_wgrad_f); + alloc_register(acts, &a->move_context_wgrad_f); +} + +static void chess_embed_encoder_reg_rollout(void* w, void* activations, + Allocator* alloc, int B) { + ChessEmbedEncoderWeights* ew = (ChessEmbedEncoderWeights*)w; + ChessEmbedEncoderActivations* a = (ChessEmbedEncoderActivations*)activations; + a->meta = {.shape = {B, CH_META}}; + a->out = {.shape = {B, ew->hidden}}; + alloc_register(alloc, &a->meta); + alloc_register(alloc, &a->out); +} + +static void* chess_embed_encoder_create_weights(void* self) { + Encoder* e = (Encoder*)self; + return chess_embed_encoder_create(e->in_dim, e->out_dim); +} +static void chess_embed_encoder_free_weights(void* weights) { free(weights); } +static void chess_embed_encoder_free_activations(void* activations) { free(activations); } + // Override encoder vtable for known ocean environments. No-op for unknown envs. static void create_custom_encoder(const std::string& env_name, Encoder* enc) { if (env_name == "nmmo3") { @@ -586,5 +873,19 @@ static void create_custom_encoder(const std::string& env_name, Encoder* enc) { .in_dim = enc->in_dim, .out_dim = enc->out_dim, .activation_size = sizeof(NMMO3EncoderActivations), }; + } else if (env_name == "chess") { + *enc = Encoder{ + .forward = chess_embed_encoder_forward, + .backward = chess_embed_encoder_backward, + .init_weights = chess_embed_encoder_init_weights, + .reg_params = chess_embed_encoder_reg_params, + .reg_train = chess_embed_encoder_reg_train, + .reg_rollout = chess_embed_encoder_reg_rollout, + .create_weights = chess_embed_encoder_create_weights, + .free_weights = chess_embed_encoder_free_weights, + .free_activations = chess_embed_encoder_free_activations, + .in_dim = enc->in_dim, .out_dim = enc->out_dim, + .activation_size = sizeof(ChessEmbedEncoderActivations), + }; } } diff --git a/src/pufferlib.cu b/src/pufferlib.cu index 91f5975435..89552f40c3 100644 --- a/src/pufferlib.cu +++ b/src/pufferlib.cu @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "models.cu" From 15b4243dcf166476973636c32d59bb22b759e485 Mon Sep 17 00:00:00 2001 From: bullardji <102577354+bullardji@users.noreply.github.com> Date: Mon, 4 May 2026 00:09:21 -0400 Subject: [PATCH 2/2] Update ocean.cu --- src/ocean.cu | 198 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 123 insertions(+), 75 deletions(-) diff --git a/src/ocean.cu b/src/ocean.cu index 3175a3b297..6cf8adf954 100644 --- a/src/ocean.cu +++ b/src/ocean.cu @@ -598,6 +598,7 @@ static constexpr int CH_BOARD_FEATURES = CH_BOARD_SQUARES * CH_PIECE_TYPES; static constexpr int CH_MAX_VALID_FROM = 16; static constexpr int CH_MAX_VALID_TO = 32; static constexpr int CH_NULL_SQ = 64; +static constexpr int CH_MAX_MOVE_CONTEXT_ACTIVE = 2 + CH_MAX_VALID_FROM + CH_MAX_VALID_TO; static constexpr int CH_MOVE_CONTEXT_SELECTED = 0; static constexpr int CH_MOVE_CONTEXT_VALID_FROM = 64; static constexpr int CH_MOVE_CONTEXT_VALID_TO = 128; @@ -605,6 +606,26 @@ static constexpr int CH_MOVE_CONTEXT_PHASE = 192; static constexpr int CH_MOVE_CONTEXT_FEATURES = 194; static constexpr int CH_META = 51; // side/castle/ep/clocks/checks/promos/pass +__constant__ int CH_META_SRC[CH_META] = { + CH_SIDE, + CH_CASTLE + 0, CH_CASTLE + 1, CH_CASTLE + 2, CH_CASTLE + 3, + CH_EP + 0, CH_EP + 1, CH_EP + 2, CH_EP + 3, CH_EP + 4, + CH_EP + 5, CH_EP + 6, CH_EP + 7, CH_EP + 8, + CH_RULE50, + CH_REPETITION, + CH_SELF_CHECK, + CH_OPP_CHECK, + CH_VALID_PROMOS + 0, CH_VALID_PROMOS + 1, CH_VALID_PROMOS + 2, CH_VALID_PROMOS + 3, + CH_VALID_PROMOS + 4, CH_VALID_PROMOS + 5, CH_VALID_PROMOS + 6, CH_VALID_PROMOS + 7, + CH_VALID_PROMOS + 8, CH_VALID_PROMOS + 9, CH_VALID_PROMOS + 10, CH_VALID_PROMOS + 11, + CH_VALID_PROMOS + 12, CH_VALID_PROMOS + 13, CH_VALID_PROMOS + 14, CH_VALID_PROMOS + 15, + CH_VALID_PROMOS + 16, CH_VALID_PROMOS + 17, CH_VALID_PROMOS + 18, CH_VALID_PROMOS + 19, + CH_VALID_PROMOS + 20, CH_VALID_PROMOS + 21, CH_VALID_PROMOS + 22, CH_VALID_PROMOS + 23, + CH_VALID_PROMOS + 24, CH_VALID_PROMOS + 25, CH_VALID_PROMOS + 26, CH_VALID_PROMOS + 27, + CH_VALID_PROMOS + 28, CH_VALID_PROMOS + 29, CH_VALID_PROMOS + 30, CH_VALID_PROMOS + 31, + CH_PASS_VALID +}; + // ---- Chess kernels ---- __global__ void chess_meta_kernel( @@ -615,24 +636,7 @@ __global__ void chess_meta_kernel( if (idx >= B * CH_META) return; int b = idx / CH_META; int m = idx % CH_META; - int src = CH_PASS_VALID; - if (m == 0) { - src = CH_SIDE; - } else if (m < 5) { - src = CH_CASTLE + (m - 1); - } else if (m < 14) { - src = CH_EP + (m - 5); - } else if (m == 14) { - src = CH_RULE50; - } else if (m == 15) { - src = CH_REPETITION; - } else if (m == 16) { - src = CH_SELF_CHECK; - } else if (m == 17) { - src = CH_OPP_CHECK; - } else if (m < 50) { - src = CH_VALID_PROMOS + (m - 18); - } + int src = CH_META_SRC[m]; float x = to_float(obs[b * obs_size + src]); if (x > 1.0f) x *= (1.0f / 255.0f); meta[idx] = from_float(x); @@ -645,44 +649,65 @@ __global__ void chess_embed_forward_kernel( const precision_t* __restrict__ move_context_w, const precision_t* __restrict__ bias, int B, int hidden, int obs_size) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= B * hidden) return; - int b = idx / hidden; - int h = idx % hidden; - + int b = blockIdx.y; + int h = blockIdx.x * blockDim.x + threadIdx.x; const precision_t* obs_row = obs + b * obs_size; - float acc = to_float(out[idx]) + to_float(bias[h]); - - #pragma unroll 8 - for (int sq = 0; sq < CH_BOARD_SQUARES; sq++) { - int piece = (int)to_float(obs_row[CH_BOARD + sq]); - if (piece > 0) { - int feat = (piece - 1) * CH_BOARD_SQUARES + sq; - acc += to_float(board_w[feat * hidden + h]); + + __shared__ int board_features[CH_BOARD_SQUARES]; + __shared__ int board_count; + __shared__ int move_context_features[CH_MAX_MOVE_CONTEXT_ACTIVE]; + __shared__ int move_context_count; + + if (threadIdx.x == 0) { + int n = 0; + #pragma unroll 8 + for (int sq = 0; sq < CH_BOARD_SQUARES; sq++) { + int piece = (int)to_float(obs_row[CH_BOARD + sq]); + if (piece > 0) { + board_features[n++] = (piece - 1) * CH_BOARD_SQUARES + sq; + } } - } + board_count = n; - int phase = (int)to_float(obs_row[CH_PICK_PHASE]) > 0 ? 1 : 0; - acc += to_float(move_context_w[(CH_MOVE_CONTEXT_PHASE + phase) * hidden + h]); + n = 0; + int phase = (int)to_float(obs_row[CH_PICK_PHASE]) > 0 ? 1 : 0; + move_context_features[n++] = CH_MOVE_CONTEXT_PHASE + phase; - int selected = (int)to_float(obs_row[CH_SELECTED]); - if (selected < CH_BOARD_SQUARES) { - acc += to_float(move_context_w[(CH_MOVE_CONTEXT_SELECTED + selected) * hidden + h]); + int selected = (int)to_float(obs_row[CH_SELECTED]); + if (selected < CH_BOARD_SQUARES) { + move_context_features[n++] = CH_MOVE_CONTEXT_SELECTED + selected; + } + + int from_count = (int)to_float(obs_row[CH_VALID_FROM_COUNT]); + for (int k = 0; k < from_count; k++) { + int sq = (int)to_float(obs_row[CH_VALID_FROM + k]); + move_context_features[n++] = CH_MOVE_CONTEXT_VALID_FROM + sq; + } + + int to_count = (int)to_float(obs_row[CH_VALID_TO_COUNT]); + for (int k = 0; k < to_count; k++) { + int sq = (int)to_float(obs_row[CH_VALID_TO + k]); + move_context_features[n++] = CH_MOVE_CONTEXT_VALID_TO + sq; + } + + move_context_count = n; } - int from_count = (int)to_float(obs_row[CH_VALID_FROM_COUNT]); - for (int k = 0; k < from_count; k++) { - int sq = (int)to_float(obs_row[CH_VALID_FROM + k]); - acc += to_float(move_context_w[(CH_MOVE_CONTEXT_VALID_FROM + sq) * hidden + h]); + __syncthreads(); + if (h >= hidden) return; + + int out_idx = b * hidden + h; + float acc = to_float(out[out_idx]) + to_float(bias[h]); + + for (int i = 0; i < board_count; i++) { + acc += to_float(board_w[board_features[i] * hidden + h]); } - int to_count = (int)to_float(obs_row[CH_VALID_TO_COUNT]); - for (int k = 0; k < to_count; k++) { - int sq = (int)to_float(obs_row[CH_VALID_TO + k]); - acc += to_float(move_context_w[(CH_MOVE_CONTEXT_VALID_TO + sq) * hidden + h]); + for (int i = 0; i < move_context_count; i++) { + acc += to_float(move_context_w[move_context_features[i] * hidden + h]); } - out[idx] = from_float(fmaxf(0.0f, acc)); + out[out_idx] = from_float(fmaxf(0.0f, acc)); } __global__ void chess_embed_backward_kernel( @@ -691,42 +716,63 @@ __global__ void chess_embed_backward_kernel( const precision_t* __restrict__ grad, const precision_t* __restrict__ obs, int B, int hidden, int obs_size) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= B * hidden) return; - int b = idx / hidden; - int h = idx % hidden; - float g = to_float(grad[idx]); - if (g == 0.0f) return; - + int b = blockIdx.y; + int h = blockIdx.x * blockDim.x + threadIdx.x; const precision_t* obs_row = obs + b * obs_size; - #pragma unroll 8 - for (int sq = 0; sq < CH_BOARD_SQUARES; sq++) { - int piece = (int)to_float(obs_row[CH_BOARD + sq]); - if (piece > 0) { - int feat = (piece - 1) * CH_BOARD_SQUARES + sq; - atomicAdd(&board_wgrad_f[feat * hidden + h], g); + __shared__ int board_features[CH_BOARD_SQUARES]; + __shared__ int board_count; + __shared__ int move_context_features[CH_MAX_MOVE_CONTEXT_ACTIVE]; + __shared__ int move_context_count; + + if (threadIdx.x == 0) { + int n = 0; + #pragma unroll 8 + for (int sq = 0; sq < CH_BOARD_SQUARES; sq++) { + int piece = (int)to_float(obs_row[CH_BOARD + sq]); + if (piece > 0) { + board_features[n++] = (piece - 1) * CH_BOARD_SQUARES + sq; + } } - } + board_count = n; + + n = 0; + int phase = (int)to_float(obs_row[CH_PICK_PHASE]) > 0 ? 1 : 0; + move_context_features[n++] = CH_MOVE_CONTEXT_PHASE + phase; - int phase = (int)to_float(obs_row[CH_PICK_PHASE]) > 0 ? 1 : 0; - atomicAdd(&move_context_wgrad_f[(CH_MOVE_CONTEXT_PHASE + phase) * hidden + h], g); + int selected = (int)to_float(obs_row[CH_SELECTED]); + if (selected < CH_BOARD_SQUARES) { + move_context_features[n++] = CH_MOVE_CONTEXT_SELECTED + selected; + } - int selected = (int)to_float(obs_row[CH_SELECTED]); - if (selected < CH_BOARD_SQUARES) { - atomicAdd(&move_context_wgrad_f[(CH_MOVE_CONTEXT_SELECTED + selected) * hidden + h], g); + int from_count = (int)to_float(obs_row[CH_VALID_FROM_COUNT]); + for (int k = 0; k < from_count; k++) { + int sq = (int)to_float(obs_row[CH_VALID_FROM + k]); + move_context_features[n++] = CH_MOVE_CONTEXT_VALID_FROM + sq; + } + + int to_count = (int)to_float(obs_row[CH_VALID_TO_COUNT]); + for (int k = 0; k < to_count; k++) { + int sq = (int)to_float(obs_row[CH_VALID_TO + k]); + move_context_features[n++] = CH_MOVE_CONTEXT_VALID_TO + sq; + } + + move_context_count = n; } - int from_count = (int)to_float(obs_row[CH_VALID_FROM_COUNT]); - for (int k = 0; k < from_count; k++) { - int sq = (int)to_float(obs_row[CH_VALID_FROM + k]); - atomicAdd(&move_context_wgrad_f[(CH_MOVE_CONTEXT_VALID_FROM + sq) * hidden + h], g); + __syncthreads(); + if (h >= hidden) return; + + int grad_idx = b * hidden + h; + float g = to_float(grad[grad_idx]); + if (g == 0.0f) return; + + for (int i = 0; i < board_count; i++) { + atomicAdd(&board_wgrad_f[board_features[i] * hidden + h], g); } - int to_count = (int)to_float(obs_row[CH_VALID_TO_COUNT]); - for (int k = 0; k < to_count; k++) { - int sq = (int)to_float(obs_row[CH_VALID_TO + k]); - atomicAdd(&move_context_wgrad_f[(CH_MOVE_CONTEXT_VALID_TO + sq) * hidden + h], g); + for (int i = 0; i < move_context_count; i++) { + atomicAdd(&move_context_wgrad_f[move_context_features[i] * hidden + h], g); } } @@ -762,7 +808,8 @@ static PrecisionTensor chess_embed_encoder_forward(void* w, void* activations, chess_meta_kernel<<>>( a->meta.data, input.data, B, ew->obs_size); puf_mm(&a->meta, &ew->meta_w, &a->out, stream); - chess_embed_forward_kernel<<hidden), BLOCK_SIZE, 0, stream>>>( + dim3 embed_grid(grid_size(ew->hidden), B); + chess_embed_forward_kernel<<>>( a->out.data, input.data, ew->board_w.data, ew->move_context_w.data, ew->bias.data, B, ew->hidden, ew->obs_size); return a->out; @@ -784,7 +831,8 @@ static void chess_embed_encoder_backward(void* w, void* activations, int move_context_n = CH_MOVE_CONTEXT_FEATURES * H; cudaMemsetAsync(a->board_wgrad_f.data, 0, board_n * sizeof(float), stream); cudaMemsetAsync(a->move_context_wgrad_f.data, 0, move_context_n * sizeof(float), stream); - chess_embed_backward_kernel<<>>( + dim3 embed_grid(grid_size(H), B); + chess_embed_backward_kernel<<>>( a->board_wgrad_f.data, a->move_context_wgrad_f.data, grad.data, a->saved_obs.data, B, H, ew->obs_size); n3_float_to_precision_kernel<<>>(