diff --git a/crates/core_arch/missing-x86.md b/crates/core_arch/missing-x86.md index e9f68eb9e6..3a82f9761f 100644 --- a/crates/core_arch/missing-x86.md +++ b/crates/core_arch/missing-x86.md @@ -1,41 +1,4 @@ -
["AMX-BF16"]

- - * [ ] [`__tile_dpbf16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbf16ps) -

- - -
["AMX-COMPLEX"]

- - * [ ] [`__tile_cmmimfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_cmmimfp16ps) - * [ ] [`__tile_cmmrlfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_cmmrlfp16ps) -

- - -
["AMX-FP16"]

- - * [ ] [`__tile_dpfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpfp16ps) -

- - -
["AMX-INT8"]

- - * [ ] [`__tile_dpbssd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbssd) - * [ ] [`__tile_dpbsud`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbsud) - * [ ] [`__tile_dpbusd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbusd) - * [ ] [`__tile_dpbuud`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbuud) -

- - -
["AMX-TILE"]

- - * [ ] [`__tile_loadd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_loadd) - * [ ] [`__tile_stored`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_stored) - * [ ] [`__tile_stream_loadd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_stream_loadd) - * [ ] [`__tile_zero`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_zero) -

- -
["AVX512_FP16"]

* [ ] [`_mm256_set1_pch`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_set1_pch) diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index b3b3e86750..7d693ec1a5 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -1,14 +1,34 @@ +use crate::core_arch::x86_64::{__tile1024i, Tile}; use crate::core_arch::{simd::*, x86::*}; #[cfg(test)] use stdarch_test::assert_instr; -/// Load tile configuration from a 64-byte memory location specified by mem_addr. +/// Load tile configuration from a 64-byte memory location specified by `mem_addr`. /// The tile configuration format is specified below, and includes the tile type pallette, /// the number of bytes per row, and the number of rows. If the specified pallette_id is zero, /// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed. /// Any invalid configurations will result in #GP fault. /// +/// ```intel +/// // format of memory payload. each field is a byte. +/// 0: palette +/// 1: start_row +/// 2-15: reserved, must be zero +/// 16-17: tile0.colsb +/// 18-19: tile1.colsb +/// 20-21: tile2.colsb +/// ... +/// 30-31: tile7.colsb +/// 32-47: reserved, must be zero +/// 48: tile0.rows +/// 49: tile1.rows +/// 50: tile2.rows +/// ... +/// 55: tile7.rows +/// 56-63: reserved, must be zero +/// ``` +/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875) #[inline] #[target_feature(enable = "amx-tile")] @@ -18,8 +38,8 @@ pub unsafe fn _tile_loadconfig(mem_addr: *const u8) { ldtilecfg(mem_addr); } -/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr. -/// The tile configuration format is specified below, and includes the tile type pallette, +/// Stores the current tile configuration to a 64-byte memory location specified by `mem_addr`. +/// The tile configuration format is as specified in [`_tile_loadconfig`], and includes the tile type pallette, /// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879) @@ -31,7 +51,7 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { sttilecfg(mem_addr); } -/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig. +/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration previously configured via [`_tile_loadconfig`]. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877) #[inline] @@ -41,7 +61,19 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_loadd(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloadd64(DST as i8, base, stride); + tileloadd64(DST as i8, base, stride as u64); +} + +/// Load tile rows from memory specified by base address and stride into destination tile dst. The shape +/// of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_loadd&ig_expand=6877) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tileloadd))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_loadd(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloadd64_internal((*dst).rows, (*dst).cols, base, stride as u64); } /// Release the tile configuration to return to the init state, which releases all storage it currently holds. @@ -55,7 +87,7 @@ pub unsafe fn _tile_release() { tilerelease(); } -/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig. +/// Store the tile specified by src to memory specified by base address and stride using the tile configuration previously configured via [`_tile_loadconfig`]. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881) #[inline] @@ -65,11 +97,23 @@ pub unsafe fn _tile_release() { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tilestored64(DST as i8, base, stride); + tilestored64(DST as i8, base, stride as u64); } -/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration -/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will +/// Store the tile specified by src to memory specified by base address and stride. The shape of the tile +/// is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_stored&ig_expand=6881) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tilestored))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_stored(base: *mut u8, stride: usize, src: __tile1024i) { + tilestored64_internal(src.rows, src.cols, base, stride as u64, src.tile); +} + +/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration +/// previously configured via [`_tile_loadconfig`]. This intrinsic provides a hint to the implementation that the data will /// likely not be reused in the near future and the data caching can be optimized accordingly. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883) @@ -80,10 +124,24 @@ pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stream_loadd(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddt164(DST as i8, base, stride); + tileloaddt164(DST as i8, base, stride as u64); +} + +/// Load tile rows from memory specified by base address and stride into destination tile dst. The shape +/// of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// This intrinsic provides a hint to the implementation that the data will likely not be reused in the +/// near future and the data caching can be optimized accordingly. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_stream_loadd&ig_expand=6883) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tileloaddt1))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_stream_loadd(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloaddt164_internal((*dst).rows, (*dst).cols, base, stride as u64); } -/// Zero the tile specified by tdest. +/// Zero the tile specified by `tdest`. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885) #[inline] @@ -96,6 +154,18 @@ pub unsafe fn _tile_zero() { tilezero(DST as i8); } +/// Zero the tile specified by `dst`. The shape of the tile is specified in the struct of [`__tile1024i`]. +/// The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_zero&ig_expand=6885) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tilezero))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_zero(dst: *mut __tile1024i) { + (*dst).tile = tilezero_internal((*dst).rows, (*dst).cols); +} + /// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b, /// accumulating the intermediate single-precision (32-bit) floating-point elements /// with elements in dst, and store the 32-bit result back to tile dst. @@ -113,6 +183,20 @@ pub unsafe fn _tile_dpbf16ps() { tdpbf16ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, +/// accumulating the intermediate single-precision (32-bit) floating-point elements +/// with elements in dst, and store the 32-bit result back to tile dst. The shape of the tile +/// is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbf16ps&ig_expand=6864) +#[inline] +#[target_feature(enable = "amx-bf16")] +#[cfg_attr(test, assert_instr(tdpbf16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbf16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbf16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. @@ -131,6 +215,21 @@ pub unsafe fn _tile_dpbssd() { tdpbssd(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding +/// signed 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbssd&ig_expand=6866) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbssd))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbssd(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbssd_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. @@ -149,6 +248,21 @@ pub unsafe fn _tile_dpbsud() { tdpbsud(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding +/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbsud&ig_expand=6868) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbsud))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbsud(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbsud_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. @@ -167,6 +281,21 @@ pub unsafe fn _tile_dpbusd() { tdpbusd(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding +/// signed 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbusd&ig_expand=6870) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbusd))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbusd(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbusd_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. @@ -185,6 +314,21 @@ pub unsafe fn _tile_dpbuud() { tdpbuud(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding +/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbuud&ig_expand=6872) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbuud))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbuud(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbuud_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, /// accumulating the intermediate single-precision (32-bit) floating-point elements /// with elements in dst, and store the 32-bit result back to tile dst. @@ -202,6 +346,20 @@ pub unsafe fn _tile_dpfp16ps() { tdpfp16ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, +/// accumulating the intermediate single-precision (32-bit) floating-point elements +/// with elements in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpfp16ps&ig_expand=6874) +#[inline] +#[target_feature(enable = "amx-fp16")] +#[cfg_attr(test, assert_instr(tdpfp16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpfp16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. /// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. /// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b), @@ -223,6 +381,24 @@ pub unsafe fn _tile_cmmimfp16ps() { tcmmimfp16ps(DST as i8, A as i8, B as i8); } +/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. +/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. +/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b), +/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). +/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of +/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added, +/// and then accumulated into the corresponding row and column of dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_cmmimfp16ps&ig_expand=6860) +#[inline] +#[target_feature(enable = "amx-complex")] +#[cfg_attr(test, assert_instr(tcmmimfp16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cmmimfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tcmmimfp16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. /// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. /// Calculates the real part of the result. For each possible combination of (row of a, column of b), @@ -244,6 +420,24 @@ pub unsafe fn _tile_cmmrlfp16ps() { tcmmrlfp16ps(DST as i8, A as i8, B as i8); } +/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. +/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. +/// Calculates the real part of the result. For each possible combination of (row of a, column of b), +/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). +/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of +/// the a element is multiplied with the imaginary part of the corresponding b elements. +/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_cmmrlfp16ps&ig_expand=6862) +#[inline] +#[target_feature(enable = "amx-complex")] +#[cfg_attr(test, assert_instr(tcmmrlfp16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cmmrlfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tcmmrlfp16ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2) /// floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -263,6 +457,19 @@ pub unsafe fn _tile_dpbf8ps() { tdpbf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2) +/// floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdpbf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8 /// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -282,6 +489,19 @@ pub unsafe fn _tile_dpbhf8ps() { tdpbhf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8 +/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdpbhf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbhf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbhf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8 /// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -301,6 +521,19 @@ pub unsafe fn _tile_dphbf8ps() { tdphbf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8 +/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdphbf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dphbf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdphbf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3) /// floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -320,8 +553,21 @@ pub unsafe fn _tile_dphf8ps() { tdphf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3) +/// floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdphf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dphf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdphf8ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Load tile rows from memory specified by base address and stride into destination tile dst -/// using the tile configuration previously configured via _tile_loadconfig. +/// using the tile configuration previously configured via [`_tile_loadconfig`]. /// Additionally, this intrinsic indicates the source memory location is likely to become /// read-shared by multiple processors, i.e., read in the future by at least one other processor /// before it is written, assuming it is ever written in the future. @@ -335,11 +581,24 @@ pub unsafe fn _tile_dphf8ps() { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddrs64(DST as i8, base, stride); + tileloaddrs64(DST as i8, base, stride as u64); +} + +/// Load tile rows from memory specified by base address and stride into destination tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// Additionally, this intrinsic indicates the source memory location is likely to become +/// read-shared by multiple processors, i.e., read in the future by at least one other processor +/// before it is written, assuming it is ever written in the future. +#[inline] +#[target_feature(enable = "amx-movrs")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tileloaddrs))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_loaddrs(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloaddrs64_internal((*dst).rows, (*dst).cols, base, stride as u64); } /// Load tile rows from memory specified by base address and stride into destination tile dst -/// using the tile configuration previously configured via _tile_loadconfig. +/// using the tile configuration previously configured via [`_tile_loadconfig`]. /// Provides a hint to the implementation that the data would be reused but does not need /// to be resident in the nearest cache levels. /// Additionally, this intrinsic indicates the source memory location is likely to become @@ -355,7 +614,22 @@ pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stream_loaddrs(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddrst164(DST as i8, base, stride); + tileloaddrst164(DST as i8, base, stride as u64); +} + +/// Load tile rows from memory specified by base address and stride into destination tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// Provides a hint to the implementation that the data would be reused but does not need +/// to be resident in the nearest cache levels. +/// Additionally, this intrinsic indicates the source memory location is likely to become +/// read-shared by multiple processors, i.e., read in the future by at least one other processor +/// before it is written, assuming it is ever written in the future. +#[inline] +#[target_feature(enable = "amx-movrs")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tileloaddrst1))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_stream_loaddrs(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloaddrst164_internal((*dst).rows, (*dst).cols, base, stride as u64); } /// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit) @@ -383,6 +657,25 @@ pub unsafe fn _tile_mmultf32ps() { tmmultf32ps(DST as i8, A as i8, B as i8); } +/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit) +/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the +/// results into a packed single precision tile. +/// For each possible combination of (row of a, column of b), it performs +/// - convert to TF32 +/// - multiply the corresponding elements of a and b +/// - accumulate the results into the corresponding row and column of dst using round-to-nearest-even +/// rounding mode. +/// Output FP32 denormals are always flushed to zero, input single precision denormals are always +/// handled and *not* treated as zero. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-tf32")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tmmultf32ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_mmultf32ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tmmultf32ps_internal(a.rows, b.cols, a.cols, (*dst).tile, a.tile, b.tile); +} + /// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer /// elements to packed single-precision (32-bit) floating-point elements. #[inline] @@ -414,6 +707,17 @@ pub unsafe fn _tile_cvtrowd2psi() -> __m512 { tcvtrowd2psi(TILE as i8, ROW as u32).as_m512() } +/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer +/// elements to packed single-precision (32-bit) floating-point elements. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowd2ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowd2ps(src: __tile1024i, row: u32) -> __m512 { + tcvtrowd2ps_internal(src.rows, src.cols, src.tile, row).as_m512() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. @@ -447,6 +751,18 @@ pub unsafe fn _tile_cvtrowps2phhi() -> __m512h tcvtrowps2phhi(TILE as i8, ROW as u32).as_m512h() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phh))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2phh(src: __tile1024i, row: u32) -> __m512h { + tcvtrowps2phh_internal(src.rows, src.cols, src.tile, row).as_m512h() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. @@ -480,6 +796,18 @@ pub unsafe fn _tile_cvtrowps2phli() -> __m512h tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phl))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2phl(src: __tile1024i, row: u32) -> __m512h { + tcvtrowps2phl_internal(src.rows, src.cols, src.tile, row).as_m512h() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. @@ -513,6 +841,18 @@ pub unsafe fn _tile_cvtrowps2bf16hi() -> __m512 tcvtrowps2bf16hi(TILE as i8, ROW as u32).as_m512bh() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16h))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2bf16h(src: __tile1024i, row: u32) -> __m512bh { + tcvtrowps2bf16h_internal(src.rows, src.cols, src.tile, row).as_m512bh() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. @@ -546,6 +886,18 @@ pub unsafe fn _tile_cvtrowps2bf16li() -> __m512 tcvtrowps2bf16li(TILE as i8, ROW as u32).as_m512bh() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16l))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2bf16l(src: __tile1024i, row: u32) -> __m512bh { + tcvtrowps2bf16l_internal(src.rows, src.cols, src.tile, row).as_m512bh() +} + /// Moves one row of tile data into a zmm vector register #[inline] #[rustc_legacy_const_generics(0)] @@ -575,83 +927,170 @@ pub unsafe fn _tile_movrowi() -> __m512i { tilemovrowi(TILE as i8, ROW as u32).as_m512i() } +/// Moves one row of tile data into a zmm vector register +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tilemovrow))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_movrow(src: __tile1024i, row: u32) -> __m512i { + tilemovrow_internal(src.rows, src.cols, src.tile, row).as_m512i() +} + #[allow(improper_ctypes)] -unsafe extern "C" { +unsafe extern "unadjusted" { #[link_name = "llvm.x86.ldtilecfg"] fn ldtilecfg(mem_addr: *const u8); #[link_name = "llvm.x86.sttilecfg"] fn sttilecfg(mem_addr: *mut u8); + #[link_name = "llvm.x86.tileloadd64"] - fn tileloadd64(dst: i8, base: *const u8, stride: usize); + fn tileloadd64(dst: i8, base: *const u8, stride: u64); + #[link_name = "llvm.x86.tileloadd64.internal"] + fn tileloadd64_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tileloaddt164"] - fn tileloaddt164(dst: i8, base: *const u8, stride: usize); + fn tileloaddt164(dst: i8, base: *const u8, stride: u64); + #[link_name = "llvm.x86.tileloaddt164.internal"] + fn tileloaddt164_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tilerelease"] fn tilerelease(); + #[link_name = "llvm.x86.tilestored64"] - fn tilestored64(dst: i8, base: *mut u8, stride: usize); + fn tilestored64(dst: i8, base: *mut u8, stride: u64); + #[link_name = "llvm.x86.tilestored64.internal"] + fn tilestored64_internal(rows: u16, cols: u16, base: *mut u8, stride: u64, src: Tile); + #[link_name = "llvm.x86.tilezero"] fn tilezero(dst: i8); + #[link_name = "llvm.x86.tilezero.internal"] + fn tilezero_internal(rows: u16, cols: u16) -> Tile; + #[link_name = "llvm.x86.tdpbf16ps"] fn tdpbf16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbf16ps.internal"] + fn tdpbf16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbuud"] fn tdpbuud(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbuud.internal"] + fn tdpbuud_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbusd"] fn tdpbusd(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbusd.internal"] + fn tdpbusd_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbsud"] fn tdpbsud(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbsud.internal"] + fn tdpbsud_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbssd"] fn tdpbssd(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbssd.internal"] + fn tdpbssd_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpfp16ps"] fn tdpfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpfp16ps.internal"] + fn tdpfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tcmmimfp16ps"] fn tcmmimfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tcmmimfp16ps.internal"] + fn tcmmimfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tcmmrlfp16ps"] fn tcmmrlfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tcmmrlfp16ps.internal"] + fn tcmmrlfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbf8ps"] fn tdpbf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbf8ps.internal"] + fn tdpbf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbhf8ps"] fn tdpbhf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbhf8ps.internal"] + fn tdpbhf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdphbf8ps"] fn tdphbf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdphbf8ps.internal"] + fn tdphbf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdphf8ps"] fn tdphf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdphf8ps.internal"] + fn tdphf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tileloaddrs64"] - fn tileloaddrs64(dst: i8, base: *const u8, stride: usize); + fn tileloaddrs64(dst: i8, base: *const u8, stride: u64); + #[link_name = "llvm.x86.tileloaddrs64.internal"] + fn tileloaddrs64_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tileloaddrst164"] - fn tileloaddrst164(dst: i8, base: *const u8, stride: usize); + fn tileloaddrst164(dst: i8, base: *const u8, stride: u64); + #[link_name = "llvm.x86.tileloaddrst164.internal"] + fn tileloaddrst164_internal(rows: u16, cols: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tmmultf32ps"] fn tmmultf32ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tmmultf32ps.internal"] + fn tmmultf32ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tcvtrowd2ps"] fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16; #[link_name = "llvm.x86.tcvtrowd2psi"] fn tcvtrowd2psi(tile: i8, row: u32) -> f32x16; + #[link_name = "llvm.x86.tcvtrowd2ps.internal"] + fn tcvtrowd2ps_internal(rows: u16, cols: u16, src: Tile, row: u32) -> f32x16; + #[link_name = "llvm.x86.tcvtrowps2phh"] fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32; #[link_name = "llvm.x86.tcvtrowps2phhi"] fn tcvtrowps2phhi(tile: i8, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phh.internal"] + fn tcvtrowps2phh_internal(rows: u16, cols: u16, src: Tile, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phl"] fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32; #[link_name = "llvm.x86.tcvtrowps2phli"] fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phl.internal"] + fn tcvtrowps2phl_internal(rows: u16, cols: u16, src: Tile, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16h"] fn tcvtrowps2bf16h(tile: i8, row: u32) -> u16x32; #[link_name = "llvm.x86.tcvtrowps2bf16hi"] fn tcvtrowps2bf16hi(tile: i8, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16h.internal"] + fn tcvtrowps2bf16h_internal(rows: u16, cols: u16, src: Tile, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16l"] fn tcvtrowps2bf16l(tile: i8, row: u32) -> u16x32; #[link_name = "llvm.x86.tcvtrowps2bf16li"] fn tcvtrowps2bf16li(tile: i8, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16l.internal"] + fn tcvtrowps2bf16l_internal(rows: u16, cols: u16, src: Tile, row: u32) -> u16x32; + #[link_name = "llvm.x86.tilemovrow"] fn tilemovrow(tile: i8, row: u32) -> i32x16; #[link_name = "llvm.x86.tilemovrowi"] fn tilemovrowi(tile: i8, row: u32) -> i32x16; + #[link_name = "llvm.x86.tilemovrow.internal"] + fn tilemovrow_internal(rows: u16, cols: u16, src: Tile, row: u32) -> i32x16; } #[cfg(test)] mod tests { use crate::core_arch::x86::_mm_cvtness_sbh; use crate::core_arch::x86_64::*; - use core::{array, mem::transmute}; + use core::array; + use core::mem::MaybeUninit; use stdarch_test::simd_test; #[cfg(target_os = "linux")] use syscalls::{Sysno, syscall}; @@ -704,19 +1143,23 @@ mod tests { #[cfg(target_os = "linux")] #[target_feature(enable = "amx-tile")] #[inline] - unsafe fn _init_amx() { + fn _init_amx() { let mut ret: usize; let mut xfeatures: usize = 0; - ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize) - .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed"); + ret = unsafe { + syscall!(Sysno::arch_prctl, 0x1022, &raw mut xfeatures) + .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed") + }; if ret != 0 { panic!("Failed to get XFEATURES"); } else { match 0b11 & (xfeatures >> 17) { 0 => panic!("AMX is not available"), 1 => { - ret = syscall!(Sysno::arch_prctl, 0x1023, 18) - .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed"); + ret = unsafe { + syscall!(Sysno::arch_prctl, 0x1023, 18) + .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed") + }; if ret != 0 { panic!("Failed to enable AMX"); } @@ -727,6 +1170,18 @@ mod tests { } } + impl __tile1024i { + #[inline] + #[target_feature(enable = "amx-tile")] + fn zeroed(rows: u16, cols: u16) -> Self { + Self { + rows, + cols, + tile: unsafe { super::tilezero_internal(rows, cols) }, + } + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_loadconfig() { unsafe { @@ -759,12 +1214,26 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mut out = [[1_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[0; 64]; 16]); } } + #[simd_test(enable = "amx-tile")] + fn test__tile_zero() { + unsafe { + _init_amx(); + + let tile = __tile1024i::zeroed(16, 64); + + let mut out = [[1_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[0; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_stored() { unsafe { @@ -776,12 +1245,26 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mut out = [[1_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[0; 64]; 16]); } } + #[simd_test(enable = "amx-tile")] + fn test__tile_stored() { + unsafe { + _init_amx(); + + let tile = __tile1024i::zeroed(16, 64); + + let mut out = [[1_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[0; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_loadd() { unsafe { @@ -793,14 +1276,30 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_loadd::<0>(&mat as *const i8 as *const u8, 64); + _tile_loadd::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } } + #[simd_test(enable = "amx-tile")] + fn test__tile_loadd() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_loadd(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_stream_loadd() { unsafe { @@ -812,14 +1311,30 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64); + _tile_stream_loadd::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } } + #[simd_test(enable = "amx-tile")] + fn test__tile_stream_loadd() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_stream_loadd(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_release() { unsafe { @@ -827,14 +1342,15 @@ mod tests { } } - #[simd_test(enable = "amx-bf16,avx512f")] + const BF16_1: u16 = 0x3f80; + const BF16_2: u16 = 0x4000; + + #[simd_test(enable = "amx-bf16")] fn test_tile_dpbf16ps() { unsafe { _init_amx(); - let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits(); - let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits(); - let ones: [u8; 1024] = transmute([bf16_1; 512]); - let twos: [u8; 1024] = transmute([bf16_2; 512]); + let ones = [BF16_1; 512]; + let twos = [BF16_2; 512]; let mut res = [[0f32; 16]; 16]; let mut config = __tilecfg::default(); config.palette = 1; @@ -844,15 +1360,36 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbf16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } } + #[simd_test(enable = "amx-bf16,avx512f")] + fn test__tile_dpbf16ps() { + unsafe { + _init_amx(); + let ones = [BF16_1; 512]; + let twos = [BF16_2; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpbf16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[64f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbssd() { unsafe { @@ -868,15 +1405,36 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); - _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbssd::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[128_i32; 16]; 16]); } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbssd() { + unsafe { + _init_amx(); + let ones = [-1_i8; 1024]; + let twos = [-2_i8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpbssd(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbsud() { unsafe { @@ -892,15 +1450,36 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbsud::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[-128_i32; 16]; 16]); } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbsud() { + unsafe { + _init_amx(); + let ones = [-1_i8; 1024]; + let twos = [2_u8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbsud(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[-128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbusd() { unsafe { @@ -916,15 +1495,36 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbusd::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[-128_i32; 16]; 16]); } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbusd() { + unsafe { + _init_amx(); + let ones = [1_u8; 1024]; + let twos = [-2_i8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpbusd(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[-128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbuud() { unsafe { @@ -940,15 +1540,36 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbuud::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[128_i32; 16]; 16]); } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbuud() { + unsafe { + _init_amx(); + let ones = [1_u8; 1024]; + let twos = [2_u8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbuud(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp16")] fn test_tile_dpfp16ps() { unsafe { @@ -964,15 +1585,36 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } } + #[simd_test(enable = "amx-fp16")] + fn test__tile_dpfp16ps() { + unsafe { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpfp16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[64f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-complex")] fn test_tile_cmmimfp16ps() { unsafe { @@ -988,15 +1630,36 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_cmmimfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } } + #[simd_test(enable = "amx-complex")] + fn test__tile_cmmimfp16ps() { + unsafe { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_cmmimfp16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[64f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-complex")] fn test_tile_cmmrlfp16ps() { unsafe { @@ -1012,15 +1675,36 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_cmmrlfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[0f32; 16]; 16]); } } + #[simd_test(enable = "amx-complex")] + fn test__tile_cmmrlfp16ps() { + unsafe { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_cmmrlfp16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[0f32; 16]; 16]); + } + } + const BF8_ONE: u8 = 0x3c; const BF8_TWO: u8 = 0x40; const HF8_ONE: u8 = 0x38; @@ -1041,8 +1725,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1050,6 +1734,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dpbf8ps() { + unsafe { + _init_amx(); + let ones = [BF8_ONE; 1024]; + let twos = [BF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp8")] fn test_tile_dpbhf8ps() { unsafe { @@ -1065,8 +1770,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbhf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1074,6 +1779,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dpbhf8ps() { + unsafe { + _init_amx(); + let ones = [BF8_ONE; 1024]; + let twos = [HF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbhf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp8")] fn test_tile_dphbf8ps() { unsafe { @@ -1089,8 +1815,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dphbf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1098,6 +1824,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dphbf8ps() { + unsafe { + _init_amx(); + let ones = [HF8_ONE; 1024]; + let twos = [BF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dphbf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp8")] fn test_tile_dphf8ps() { unsafe { @@ -1113,8 +1860,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dphf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1122,6 +1869,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dphf8ps() { + unsafe { + _init_amx(); + let ones = [HF8_ONE; 1024]; + let twos = [HF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dphf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-movrs")] fn test_tile_loaddrs() { unsafe { @@ -1133,14 +1901,30 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64); + _tile_loaddrs::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } } + #[simd_test(enable = "amx-movrs")] + fn test__tile_loaddrs() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_loaddrs(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-movrs")] fn test_tile_stream_loaddrs() { unsafe { @@ -1152,14 +1936,30 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64); + _tile_stream_loaddrs::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } } + #[simd_test(enable = "amx-movrs")] + fn test__tile_stream_loaddrs() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_stream_loaddrs(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_movrow() { unsafe { @@ -1223,6 +2023,22 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_movrow() { + unsafe { + _init_amx(); + let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_movrow(tile, i); + assert_eq!(*row.as_u8x64().as_array(), [i as _; _]); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowd2ps() { unsafe { @@ -1262,6 +2078,22 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowd2ps() { + unsafe { + _init_amx(); + let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowd2ps(tile, i); + assert_eq!(*row.as_f32x16().as_array(), [i as _; _]); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2phh() { unsafe { @@ -1306,6 +2138,25 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2phh() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2phh(tile, i); + assert_eq!( + *row.as_f16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ }) + ); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2phl() { unsafe { @@ -1350,6 +2201,25 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2phl() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2phl(tile, i); + assert_eq!( + *row.as_f16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 }) + ); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2bf16h() { unsafe { @@ -1402,6 +2272,29 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2bf16h() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2bf16h(tile, i); + assert_eq!( + *row.as_u16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { + 0 + } else { + _mm_cvtness_sbh(i as _).to_bits() + }) + ); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2bf16l() { unsafe { @@ -1454,6 +2347,29 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2bf16l() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2bf16l(tile, i); + assert_eq!( + *row.as_u16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { + _mm_cvtness_sbh(i as _).to_bits() + } else { + 0 + }) + ); + } + } + } + #[simd_test(enable = "amx-tf32")] fn test_tile_mmultf32ps() { unsafe { @@ -1480,4 +2396,26 @@ mod tests { assert_eq!(res, expected); } } + + #[simd_test(enable = "amx-tf32")] + fn test__tile_mmultf32ps() { + unsafe { + _init_amx(); + let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _]; + let mut res = [[0.0; 16]; 16]; + + let mut tile_a = __tile1024i::zeroed(16, 64); + let mut tile_b = __tile1024i::zeroed(16, 64); + let mut tile_c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut tile_a, a.as_ptr().cast(), 64); + __tile_loadd(&mut tile_b, b.as_ptr().cast(), 64); + __tile_mmultf32ps(&mut tile_c, tile_a, tile_b); + __tile_stored(res.as_mut_ptr().cast(), 64, tile_c); + + let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32)); + assert_eq!(res, expected); + } + } } diff --git a/crates/core_arch/src/x86_64/mod.rs b/crates/core_arch/src/x86_64/mod.rs index 46384176e0..ffc2daaefa 100644 --- a/crates/core_arch/src/x86_64/mod.rs +++ b/crates/core_arch/src/x86_64/mod.rs @@ -3,6 +3,20 @@ #[macro_use] mod macros; +// Any 1024-byte vector should work +type Tile = crate::core_arch::simd::Simd; + +/// A tile register, used by AMX instructions. +// TODO: add more docs +#[derive(Copy, Clone, Debug)] +#[allow(non_camel_case_types)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub struct __tile1024i { + pub rows: u16, + pub cols: u16, + tile: Tile, +} + mod fxsr; #[stable(feature = "simd_x86", since = "1.27.0")] pub use self::fxsr::*; diff --git a/crates/stdarch-test/src/lib.rs b/crates/stdarch-test/src/lib.rs index ecaf95f617..c468ebd12b 100644 --- a/crates/stdarch-test/src/lib.rs +++ b/crates/stdarch-test/src/lib.rs @@ -172,6 +172,10 @@ pub fn assert(shim_addr: usize, fnname: &str, expected: &str) { // vst1q_p64_x4_nop : #instructions = 33 >= 22 (limit) "nop" if fnname.contains("vst1q_p64") => 34, + // AMX intrinsics generate a lot of move instructions to load/store the tile registers + // due to Rust ABI + _ if fnname.contains("___tile") => 165, + // Original limit was 20 instructions, but ARM DSP Intrinsics // are exactly 20 instructions long. So, bump the limit to 22 // instead of adding here a long list of exceptions. diff --git a/crates/stdarch-verify/src/lib.rs b/crates/stdarch-verify/src/lib.rs index f7304ab326..5412ab466a 100644 --- a/crates/stdarch-verify/src/lib.rs +++ b/crates/stdarch-verify/src/lib.rs @@ -202,6 +202,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream { "_MM_MANTISSA_NORM_ENUM" => quote! { &MM_MANTISSA_NORM_ENUM }, "_MM_MANTISSA_SIGN_ENUM" => quote! { &MM_MANTISSA_SIGN_ENUM }, "_MM_PERM_ENUM" => quote! { &MM_PERM_ENUM }, + "__tile1024i" => quote! { &TILE1024I }, "bool" => quote! { &BOOL }, "bf16" => quote! { &BF16 }, "f16" => quote! { &F16 }, diff --git a/crates/stdarch-verify/tests/x86-intel.rs b/crates/stdarch-verify/tests/x86-intel.rs index 024a873de1..aad19ca55a 100644 --- a/crates/stdarch-verify/tests/x86-intel.rs +++ b/crates/stdarch-verify/tests/x86-intel.rs @@ -62,6 +62,7 @@ static MM_CMPINT_ENUM: Type = Type::MM_CMPINT_ENUM; static MM_MANTISSA_NORM_ENUM: Type = Type::MM_MANTISSA_NORM_ENUM; static MM_MANTISSA_SIGN_ENUM: Type = Type::MM_MANTISSA_SIGN_ENUM; static MM_PERM_ENUM: Type = Type::MM_PERM_ENUM; +static TILE1024I: Type = Type::TILE1024I; static TUPLE: Type = Type::Tuple; static CPUID: Type = Type::CpuidResult; @@ -102,6 +103,7 @@ enum Type { CpuidResult, Never, Ordering, + TILE1024I, } stdarch_verify::x86_functions!(static FUNCTIONS); @@ -774,6 +776,7 @@ fn equate( (&Type::MMASK32, "__mmask32") => {} (&Type::MMASK16, "__mmask16") => {} (&Type::MMASK8, "__mmask8") => {} + (&Type::TILE1024I, "__tile1024i") => {} (&Type::MutPtr(_type), "void*") | (&Type::ConstPtr(_type), "void const*") => { let pointed_type = pointed_type(intrinsic)?; @@ -812,6 +815,7 @@ fn equate( (&Type::MutPtr(&Type::M512BH), "__m512bh*") => {} (&Type::MutPtr(&Type::M512I), "__m512i*") => {} (&Type::MutPtr(&Type::M512D), "__m512d*") => {} + (&Type::MutPtr(&Type::TILE1024I), "__tile1024i*") => {} (&Type::ConstPtr(&Type::PrimFloat(16)), "_Float16 const*") => {} (&Type::ConstPtr(&Type::PrimFloat(32)), "float const*") => {}