diff --git a/src/arm/64/refmvs.S b/src/arm/64/refmvs.S index 10584a580..0638cf456 100644 --- a/src/arm/64/refmvs.S +++ b/src/arm/64/refmvs.S @@ -386,20 +386,20 @@ function load_tmvs_neon, export=1 movrel x1, div_mult_tbl 10: // nloop - ldr w16, [x29, x10, lsl #2] // ref2cur = rf->mfmv_ref2cur[n] - cmp w16, #-32 // instead of INT_MIN, we can use smaller constants - b.lt 9f // if (ref2cur == INT_MIN) continue + ldrsb w16, [x29, x10] // ref2cur = rf->mfmv_ref2cur[n] + cmp w16, #-32 + b.eq 9f // if (ref2cur == INVALID_REF2CUR) continue add x17, x10, #(RMVSF_MFMV_REF - RMVSF_MFMV_REF2CUR) // n - (&rf->mfmv_ref - &rf->mfmv_ref2cur) mov x20, #4 ldrb w17, [x29, x17] // ref = rf->mfmv_ref[n] ldr x13, [x29, #(RMVSF_RP_REF - RMVSF_MFMV_REF2CUR)] - mov w28, #28 // 7 * sizeof(int) + sub x21, x10, x10, lsl #3 // -(n * 7) smaddl x20, row_start8, wstride5, x20 // row_start8 * stride * sizeof(refmvs_temporal_block) + 4 mov w12, row_start8 // y = row_start8 - add x21, x29, #(RMVSF_MFMV_REF2REF - RMVSF_MFMV_REF2CUR - 4) // &rf->mfmv_ref2ref - 1 + add x28, x29, #(RMVSF_MFMV_REF2REF - RMVSF_MFMV_REF2CUR - 1) // &rf->mfmv_ref2ref - 1 ldr x13, [x13, x17, lsl #3] // rf->rp_ref[ref] - smaddl x28, w28, w10, x21 // rf->mfmv_ref2ref[n] - 1 + sub x28, x28, x21 // rf->mfmv_ref2ref[n] - 1 sub w17, w17, #4 // ref_sign = ref - 4 add x13, x13, x20 // r = &rf->rp_ref[ref][row_start8 * stride].ref dup v0.2s, w17 // ref_sign @@ -418,7 +418,7 @@ function load_tmvs_neon, export=1 ldrb w22, [x23, x11] // b_ref = rb->ref cbz w22, 6f // if (!b_ref) continue - ldr w24, [x28, x22, lsl #2] // ref2ref = rf->mfmv_ref2ref[n][b_ref - 1] + ldrb w24, [x28, x22] // ref2ref = rf->mfmv_ref2ref[n][b_ref - 1] cbz w24, 6f // if (!ref2ref) continue ldrh w20, [x1, x24, lsl #1] // div_mult[ref2ref] diff --git a/src/arm/asm-offsets.h b/src/arm/asm-offsets.h index 538de71b3..595ac6900 100644 --- a/src/arm/asm-offsets.h +++ b/src/arm/asm-offsets.h @@ -47,12 +47,12 @@ #define RMVSF_IH8 20 #define RMVSF_MFMV_REF 53 #define 
RMVSF_MFMV_REF2CUR 56 -#define RMVSF_MFMV_REF2REF 68 -#define RMVSF_N_MFMVS 152 -#define RMVSF_RP_REF 168 -#define RMVSF_RP_PROJ 176 -#define RMVSF_RP_STRIDE 184 -#define RMVSF_N_TILE_THREADS 200 +#define RMVSF_MFMV_REF2REF 59 +#define RMVSF_N_MFMVS 80 +#define RMVSF_RP_REF 96 +#define RMVSF_RP_PROJ 104 +#define RMVSF_RP_STRIDE 112 +#define RMVSF_N_TILE_THREADS 128 #endif #endif /* ARM_ASM_OFFSETS_H */ diff --git a/src/refmvs.c b/src/refmvs.c index 6e8309b36..874a2a9c6 100644 --- a/src/refmvs.c +++ b/src/refmvs.c @@ -710,7 +710,7 @@ static void load_tmvs_c(const refmvs_frame *const rf, int tile_row_idx, rp_proj = &rf->rp_proj[16 * stride * tile_row_idx]; for (int n = 0; n < rf->n_mfmvs; n++) { const int ref2cur = rf->mfmv_ref2cur[n]; - if (ref2cur == INT_MIN) continue; + if (ref2cur == INVALID_REF2CUR) continue; const int ref = rf->mfmv_ref[n]; const int ref_sign = ref - 4; @@ -835,7 +835,7 @@ int dav1d_refmvs_init_frame(refmvs_frame *const rf, rf->n_blocks = n_blocks; } - const unsigned poc = frm_hdr->frame_offset; + const int poc = frm_hdr->frame_offset; for (int i = 0; i < 7; i++) { const int poc_diff = get_poc_diff(seq_hdr->order_hint_n_bits, ref_poc[i], poc); @@ -874,15 +874,15 @@ int dav1d_refmvs_init_frame(refmvs_frame *const rf, rf->mfmv_ref[rf->n_mfmvs++] = 1; // last2 for (int n = 0; n < rf->n_mfmvs; n++) { - const unsigned rpoc = ref_poc[rf->mfmv_ref[n]]; + const int rpoc = ref_poc[rf->mfmv_ref[n]]; const int diff1 = get_poc_diff(seq_hdr->order_hint_n_bits, rpoc, frm_hdr->frame_offset); if (abs(diff1) > 31) { - rf->mfmv_ref2cur[n] = INT_MIN; + rf->mfmv_ref2cur[n] = INVALID_REF2CUR; } else { rf->mfmv_ref2cur[n] = rf->mfmv_ref[n] < 4 ? 
-diff1 : diff1; for (int m = 0; m < 7; m++) { - const unsigned rrpoc = ref_ref_poc[rf->mfmv_ref[n]][m]; + const int rrpoc = ref_ref_poc[rf->mfmv_ref[n]][m]; const int diff2 = get_poc_diff(seq_hdr->order_hint_n_bits, rpoc, rrpoc); // unsigned comparison also catches the < 0 case diff --git a/src/refmvs.h b/src/refmvs.h index a5bcdad7e..23d0f2273 100644 --- a/src/refmvs.h +++ b/src/refmvs.h @@ -38,10 +38,11 @@ #include "src/tables.h" #define INVALID_MV 0x80008000 +#define INVALID_REF2CUR (-32) PACKED(typedef struct refmvs_temporal_block { mv mv; - int8_t ref; + uint8_t ref; }) refmvs_temporal_block; CHECK_SIZE(refmvs_temporal_block, 5); @@ -72,8 +73,8 @@ typedef struct refmvs_frame { uint8_t sign_bias[7], mfmv_sign[7]; int8_t pocdiff[7]; uint8_t mfmv_ref[3]; - int mfmv_ref2cur[3]; - int mfmv_ref2ref[3][7]; + int8_t mfmv_ref2cur[3]; + uint8_t mfmv_ref2ref[3][7]; int n_mfmvs; int n_blocks; diff --git a/src/refmvs.rs b/src/refmvs.rs index 188bdcfab..c4d9057c3 100644 --- a/src/refmvs.rs +++ b/src/refmvs.rs @@ -21,11 +21,13 @@ use crate::intra_edge::EdgeFlags; use crate::levels::{BlockSize, Mv, UnalignedMv}; use crate::wrap_fn_ptr::wrap_fn_ptr; +const INVALID_REF2CUR: i8 = -32; + #[derive(Clone, Copy, Default, PartialEq, Eq)] #[repr(C, packed)] pub struct RefMvsTemporalBlock { pub mv: UnalignedMv, - pub r#ref: i8, + pub r#ref: u8, } const _: () = assert!(mem::size_of::<RefMvsTemporalBlock>() == 5); @@ -107,8 +109,8 @@ pub(crate) struct AsmRefMvsFrame<'a> { pub mfmv_sign: [u8; 7], pub pocdiff: [i8; 7], pub mfmv_ref: [u8; 3], - pub mfmv_ref2cur: [i32; 3], - pub mfmv_ref2ref: [[i32; 7]; 3], + pub mfmv_ref2cur: [i8; 3], + pub mfmv_ref2ref: [[u8; 7]; 3], pub n_mfmvs: i32, pub n_blocks: i32, pub rp: *mut RefMvsTemporalBlock, @@ -134,8 +136,8 @@ pub struct RefMvsFrame { pub mfmv_sign: [u8; 7], pub pocdiff: [i8; 7], pub mfmv_ref: [u8; 3], - pub mfmv_ref2cur: [i32; 3], - pub mfmv_ref2ref: [[i32; 7]; 3], + pub mfmv_ref2cur: [i8; 3], + pub mfmv_ref2ref: [[u8; 7]; 3], pub n_mfmvs: i32, pub n_blocks: u32,
// TODO: The C code uses a single buffer to store `rp_proj` and `r` to minimize @@ -1451,7 +1453,7 @@ fn load_tmvs_rust( } for n in 0..rf.n_mfmvs { let ref2cur = rf.mfmv_ref2cur[n as usize]; - if ref2cur == i32::MIN { + if ref2cur == INVALID_REF2CUR { continue; } let r#ref = rf.mfmv_ref[n as usize]; @@ -1474,7 +1476,7 @@ fn load_tmvs_rust( x += 1; continue; } - let offset = mv_projection(rb.mv.into_aligned(), ref2cur, ref2ref); + let offset = mv_projection(rb.mv.into_aligned(), ref2cur.into(), ref2ref.into()); let mut pos_x = x + apply_sign((offset.x as i32).abs() >> 6, offset.x as i32 ^ ref_sign); let pos_y = @@ -1490,7 +1492,7 @@ fn load_tmvs_rust( rp_proj_offset + (pos as isize + pos_x as isize) as usize, ) = RefMvsTemporalBlock { mv: rb.mv, - r#ref: ref2ref as i8, + r#ref: ref2ref, }; } x += 1; @@ -1578,7 +1580,7 @@ fn save_tmvs_rust( { Some(RefMvsTemporalBlock { mv: mv.into_unaligned(), - r#ref, + r#ref: r#ref as u8, }) } else { None @@ -1696,14 +1698,18 @@ pub(crate) fn rav1d_refmvs_init_frame( frm_hdr.frame_offset as i32, ); if diff1.abs() > 31 { - rf.mfmv_ref2cur[n] = i32::MIN; + rf.mfmv_ref2cur[n] = INVALID_REF2CUR; } else { - rf.mfmv_ref2cur[n] = if rf.mfmv_ref[n] < 4 { -diff1 } else { diff1 }; + rf.mfmv_ref2cur[n] = if rf.mfmv_ref[n] < 4 { + -diff1 as i8 + } else { + diff1 as i8 + }; for m in 0..7 { let rrpoc = ref_ref_poc[rf.mfmv_ref[n] as usize][m]; let diff2 = get_poc_diff(seq_hdr.order_hint_n_bits, rpoc as i32, rrpoc as i32); // unsigned comparison also catches the < 0 case - rf.mfmv_ref2ref[n][m] = if diff2 as u32 > 31 { 0 } else { diff2 }; + rf.mfmv_ref2ref[n][m] = if diff2 as u32 > 31 { 0 } else { diff2 as u8 }; } } } diff --git a/src/x86/refmvs.asm b/src/x86/refmvs.asm index 085c9b3de..84d70cf39 100644 --- a/src/x86/refmvs.asm +++ b/src/x86/refmvs.asm @@ -104,8 +104,8 @@ struc rf .mfmv_sign: resb 7 .pocdiff: resb 7 .mfmv_ref: resb 3 - .mfmv_ref2cur: resd 3 - .mfmv_ref2ref: resd 3*7 + .mfmv_ref2cur: resb 3 + .mfmv_ref2ref: resb 3*7 .n_mfmvs: resd 
1 .n_blocks: resd 1 .rp: resq 1 @@ -432,7 +432,7 @@ cglobal load_tmvs, 6, 15, 4, -0x50, rf, tridx, xstart, xend, ystart, yend, \ mov [rsp+0x38], yendd mov [rsp+0x20], xstartid xor nd, nd - xor n7d, n7d + lea n7q, [rfq+rf.mfmv_ref2ref-1] imul r9, strideq ; ystart * stride mov [rsp+0x48], rfq mov [rsp+0x18], stride5q @@ -443,8 +443,8 @@ cglobal load_tmvs, 6, 15, 4, -0x50, rf, tridx, xstart, xend, ystart, yend, \ DEFINE_ARGS y, off, xstart, xend, ystart, rf, n7, refsign, \ ref, rp_ref, xendi, xstarti, _, _, n mov rfq, [rsp+0x48] - mov refd, [rfq+rf.mfmv_ref2cur+nq*4] - cmp refd, 0x80000000 + movsx refd, byte [rfq+rf.mfmv_ref2cur+nq] + cmp refd, -32 ; INVALID_REF2CUR je .next_n mov [rsp+0x40], refd mov offq, [rsp+0x00] ; ystart * stride * 5 @@ -473,12 +473,10 @@ cglobal load_tmvs, 6, 15, 4, -0x50, rf, tridx, xstart, xend, ystart, yend, \ .xloop: lea rbd, [xq*5] add rbq, srcq - movsx refd, byte [rbq+4] + movzx refd, byte [rbq+4] test refd, refd jz .next_x_bad_ref - mov rfq, [rsp+0x48] - lea ref2refd, [(rf.mfmv_ref2ref/4)+n7q+refq-1] - mov ref2refd, [rfq+ref2refq*4] ; rf->mfmv_ref2ref[n][b_ref-1] + movzx ref2refd, byte [n7q+refq] ; rf->mfmv_ref2ref[n][b_ref-1] test ref2refd, ref2refd jz .next_x_bad_ref lea fracq, [mv_proj] @@ -554,7 +552,7 @@ cglobal load_tmvs, 6, 15, 4, -0x50, rf, tridx, xstart, xend, ystart, yend, \ mov nd, [rsp+0x14] mov ystartd, [rsp+0x24] .next_n: - add n7d, 7 + add n7q, 7 inc nd cmp nd, [rsp+0x0c] ; n_mfmvs jne .nloop diff --git a/tests/checkasm/refmvs.c b/tests/checkasm/refmvs.c index 4d082cb3f..7205420e8 100644 --- a/tests/checkasm/refmvs.c +++ b/tests/checkasm/refmvs.c @@ -49,7 +49,7 @@ static inline int get_min_mv_val(const int idx) { else return (idx - 36) * 10000; } -static inline void gen_tmv(refmvs_temporal_block *const rb, const int *ref2ref) { +static inline void gen_tmv(refmvs_temporal_block *const rb, const uint8_t *const ref2ref) { rb->ref = rnd() % 7; if (!rb->ref) return; static const int x_prob[] = {