diff --git a/csrc/multi_tensor_apply.cuh b/csrc/multi_tensor_apply.cuh
index 4a4795caf..cb95507c6 100644
--- a/csrc/multi_tensor_apply.cuh
+++ b/csrc/multi_tensor_apply.cuh
@@ -5,6 +5,8 @@
 #include
 #include
+#include
+
 // #include
 
 // This header is the one-stop shop for all your multi-tensor apply needs.
@@ -22,6 +24,17 @@ struct TensorListMetadata {
   int start_tensor_this_launch;
 };
 
+inline bool tensor_lists_require_64bit_indexing(const std::vector<std::vector<at::Tensor>>& tensor_lists) {
+  for (const auto& tensor_list : tensor_lists) {
+    for (const auto& tensor : tensor_list) {
+      if (tensor.numel() > INT_MAX) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 template <typename T, typename U, typename... ArgTypes>
 __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop_flag, T tl, U callable, ArgTypes... args) {
diff --git a/csrc/multi_tensor_l2norm_kernel.cu b/csrc/multi_tensor_l2norm_kernel.cu
index 5b3c477d6..a1c39cead 100644
--- a/csrc/multi_tensor_l2norm_kernel.cu
+++ b/csrc/multi_tensor_l2norm_kernel.cu
@@ -25,18 +25,18 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int s
   ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
 }
 
-template <typename x_t>
+template <typename x_t, typename index_t>
 struct L2NormFunctor {
-  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
                                              float* output, float* output_per_tensor, bool per_tensor,
                                              int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
 
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    index_t tensor_loc = static_cast<index_t>(tl.block_to_tensor[blockIdx.x]);
+    index_t chunk_idx = static_cast<index_t>(tl.block_to_chunk[blockIdx.x]);
+    index_t n = tl.sizes[tensor_loc];
 
     x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx * chunk_size;
@@ -54,9 +54,9 @@ struct L2NormFunctor {
 
     // to make things simple, we put aligned case in a different code path
     if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
-      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
+      for (index_t i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
         // load
-        load_store(r_x, x, 0, i_start);
+        load_store(r_x, x, 0, static_cast<int>(i_start));
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           float next = static_cast<float>(r_x[ii]);
@@ -64,10 +64,10 @@ struct L2NormFunctor {
         }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          int i = i_start + threadIdx.x + ii * blockDim.x;
+          index_t i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             float next = static_cast<float>(x[i]);
             vals[ii] += next * next;
@@ -90,18 +90,18 @@ struct L2NormFunctor {
   }
 };
 
-template <typename x_t>
+template <typename x_t, typename index_t>
 struct UnscaleL2NormFunctor {
-  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
                                              const float* inv_scale, float* output, float* output_per_tensor,
                                              bool per_tensor, int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
 
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    index_t tensor_loc = static_cast<index_t>(tl.block_to_tensor[blockIdx.x]);
+    index_t chunk_idx = static_cast<index_t>(tl.block_to_chunk[blockIdx.x]);
+    index_t n = tl.sizes[tensor_loc];
 
     x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx * chunk_size;
@@ -119,9 +119,9 @@ struct UnscaleL2NormFunctor {
 
     // to make things simple, we put aligned case in a different code path
     if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
-      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
+      for (index_t i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
         // load
-        load_store(r_x, x, 0, i_start);
+        load_store(r_x, x, 0, static_cast<int>(i_start));
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           float next = static_cast<float>(r_x[ii]) * (*inv_scale);
@@ -129,10 +129,10 @@ struct UnscaleL2NormFunctor {
         }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          int i = i_start + threadIdx.x + ii * blockDim.x;
+          index_t i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             float next = static_cast<float>(x[i]) * (*inv_scale);
             vals[ii] += next * next;
@@ -156,18 +156,18 @@
 };
 
 // Probably better to template, but since we are not likely to support other norm
-template <typename x_t>
+template <typename x_t, typename index_t>
 struct MaxNormFunctor {
-  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
                                              float* output, float* output_per_tensor, bool per_tensor,
                                              int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
 
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    index_t tensor_loc = static_cast<index_t>(tl.block_to_tensor[blockIdx.x]);
+    index_t chunk_idx = static_cast<index_t>(tl.block_to_chunk[blockIdx.x]);
+    index_t n = tl.sizes[tensor_loc];
 
     x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx * chunk_size;
@@ -185,9 +185,9 @@ struct MaxNormFunctor {
 
     // to make things simple, we put aligned case in a different code path
     if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
-      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
+      for (index_t i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
         // load
-        load_store(r_x, x, 0, i_start);
+        load_store(r_x, x, 0, static_cast<int>(i_start));
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           float next = static_cast<float>(r_x[ii]);
@@ -195,10 +195,10 @@ struct MaxNormFunctor {
         }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          int i = i_start + threadIdx.x + ii * blockDim.x;
+          index_t i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             float next = static_cast<float>(x[i]);
             vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
@@ -291,6 +291,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(int chunk_size, at::
                                                             std::vector<std::vector<at::Tensor>> tensor_lists,
                                                             at::optional<bool> per_tensor_python) {
   bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
+  bool requires_64bit_indexing = tensor_lists_require_64bit_indexing(tensor_lists);
 
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
   auto output = at::zeros({320}, float_options);
@@ -314,9 +315,16 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(int chunk_size, at::
 
   DISPATCH_FLOAT_HALF_AND_BFLOAT(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
-     multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0>(),
-                           output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
-                           per_tensor, max_chunks_per_tensor);)
+     if (requires_64bit_indexing) {
+       multi_tensor_apply<1>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                             L2NormFunctor<scalar_t_0, int64_t>(), output.data_ptr<float>(),
+                             per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
+                             max_chunks_per_tensor);
+     } else {
+       multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0, int32_t>(),
+                             output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
+                             per_tensor, max_chunks_per_tensor);
+     })
 
   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());
@@ -339,6 +347,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(int chunk_si
                                                                     at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
   bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
+  bool requires_64bit_indexing = tensor_lists_require_64bit_indexing(tensor_lists);
 
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
   auto output = at::zeros({320}, float_options);
@@ -362,10 +371,17 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(int chunk_si
 
   DISPATCH_FLOAT_HALF_AND_BFLOAT(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
-     multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<scalar_t_0>(),
-                           inv_scale.data_ptr<float>(), output.data_ptr<float>(),
-                           per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
-                           max_chunks_per_tensor);)
+     if (requires_64bit_indexing) {
+       multi_tensor_apply<1>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                             UnscaleL2NormFunctor<scalar_t_0, int64_t>(), inv_scale.data_ptr<float>(),
+                             output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
+                             per_tensor, max_chunks_per_tensor);
+     } else {
+       multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                             UnscaleL2NormFunctor<scalar_t_0, int32_t>(), inv_scale.data_ptr<float>(),
+                             output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
+                             per_tensor, max_chunks_per_tensor);
+     })
 
   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());
@@ -390,6 +406,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(int chunk_si
 void multi_tensor_norm_out_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
                                 at::Tensor out, const float alpha, const float beta, const int norm_type) {
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
+  bool requires_64bit_indexing = tensor_lists_require_64bit_indexing(tensor_lists);
   TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");
   // we don't need global thus uses empty here
   auto output = at::empty({320}, float_options);
@@ -410,16 +427,29 @@ void multi_tensor_norm_out_cuda(int chunk_size, at::Tensor noop_flag, std::vecto
   output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
 
   if (norm_type == 0) {
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
-                            multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                                  MaxNormFunctor<scalar_t_0>(), output.data_ptr<float>(),
-                                                  output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);)
+    DISPATCH_FLOAT_AND_HALF(
+       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
+       if (requires_64bit_indexing) {
+         multi_tensor_apply<1>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                               MaxNormFunctor<scalar_t_0, int64_t>(), output.data_ptr<float>(),
+                               output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);
+       } else {
+         multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, MaxNormFunctor<scalar_t_0, int32_t>(),
+                               output.data_ptr<float>(), output_per_tensor.data_ptr<float>(), true,
+                               max_chunks_per_tensor);
+       })
   } else {
     DISPATCH_FLOAT_HALF_AND_BFLOAT(
        tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
-       multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0>(),
-                             output.data_ptr<float>(), output_per_tensor.data_ptr<float>(), true,
-                             max_chunks_per_tensor);)
+       if (requires_64bit_indexing) {
+         multi_tensor_apply<1>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                               L2NormFunctor<scalar_t_0, int64_t>(), output.data_ptr<float>(),
+                               output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);
+       } else {
+         multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0, int32_t>(),
+                               output.data_ptr<float>(), output_per_tensor.data_ptr<float>(), true,
+                               max_chunks_per_tensor);
+       })
   }
 
   AT_CUDA_CHECK(cudaGetLastError());
diff --git a/csrc/multi_tensor_l2norm_kernel_mp.cu b/csrc/multi_tensor_l2norm_kernel_mp.cu
index 586839ac6..f829c51a1 100644
--- a/csrc/multi_tensor_l2norm_kernel_mp.cu
+++ b/csrc/multi_tensor_l2norm_kernel_mp.cu
@@ -25,18 +25,18 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int s
   ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
 }
 
-template <typename x_t>
+template <typename x_t, typename index_t>
 struct L2NormFunctor {
-  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
                                              float* output, float* output_per_tensor, bool per_tensor,
                                              int max_chunks_per_tensor) {
     if (*noop_gmem) {
      return;
     }
 
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    index_t tensor_loc = static_cast<index_t>(tl.block_to_tensor[blockIdx.x]);
+    index_t chunk_idx = static_cast<index_t>(tl.block_to_chunk[blockIdx.x]);
+    index_t n = tl.sizes[tensor_loc];
 
     x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx * chunk_size;
@@ -54,9 +54,9 @@ struct L2NormFunctor {
 
     // to make things simple, we put aligned case in a different code path
     if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
-      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
+      for (index_t i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
         // load
-        load_store(r_x, x, 0, i_start);
+        load_store(r_x, x, 0, static_cast<int>(i_start));
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           float next = static_cast<float>(r_x[ii]);
@@ -64,10 +64,10 @@ struct L2NormFunctor {
         }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          int i = i_start + threadIdx.x + ii * blockDim.x;
+          index_t i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             float next = static_cast<float>(x[i]);
             vals[ii] += next * next;
@@ -122,6 +122,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(int chunk_size, a
                                                                std::vector<std::vector<at::Tensor>> tensor_lists,
                                                                at::optional<bool> per_tensor_python) {
   bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
+  bool requires_64bit_indexing = tensor_lists_require_64bit_indexing(tensor_lists);
 
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
   auto output = at::zeros({320}, float_options);
@@ -145,9 +146,16 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(int chunk_size, a
 
   DISPATCH_FLOAT_HALF_AND_BFLOAT(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_mp_cuda",
-     multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0>(),
-                           output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
-                           per_tensor, max_chunks_per_tensor);)
+     if (requires_64bit_indexing) {
+       multi_tensor_apply<1>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                             L2NormFunctor<scalar_t_0, int64_t>(), output.data_ptr<float>(),
+                             per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
+                             max_chunks_per_tensor);
+     } else {
+       multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0, int32_t>(),
+                             output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
+                             per_tensor, max_chunks_per_tensor);
+     })
 
   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());
diff --git a/csrc/multi_tensor_l2norm_scale_kernel.cu b/csrc/multi_tensor_l2norm_scale_kernel.cu
index 17ee1e30d..ce620c61b 100644
--- a/csrc/multi_tensor_l2norm_scale_kernel.cu
+++ b/csrc/multi_tensor_l2norm_scale_kernel.cu
@@ -25,18 +25,18 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int s
   ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
 }
 
-template <typename in_t, typename out_t>
+template <typename in_t, typename out_t, typename index_t>
 struct L2NormScaleFunctor {
-  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<2>& tl,
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<2>& tl,
                                              float* output, float* output_per_tensor, float scale, bool per_tensor,
                                              int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
 
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    index_t tensor_loc = static_cast<index_t>(tl.block_to_tensor[blockIdx.x]);
+    index_t chunk_idx = static_cast<index_t>(tl.block_to_chunk[blockIdx.x]);
+    index_t n = tl.sizes[tensor_loc];
 
     in_t* in = (in_t*)tl.addresses[0][tensor_loc];
     in += chunk_idx * chunk_size;
@@ -59,9 +59,9 @@ struct L2NormScaleFunctor {
 
     // to make things simple, we put aligned case in a different code path
     if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) {
-      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
+      for (index_t i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
         // load
-        load_store(r_in, in, 0, i_start);
+        load_store(r_in, in, 0, static_cast<int>(i_start));
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           float next = static_cast<float>(r_in[ii]);
@@ -69,14 +69,14 @@ struct L2NormScaleFunctor {
           vals[ii] += next * next;
           // finite = finite && isfinite(r_in[ii]);
         }
-        load_store(out, r_out, i_start, 0);
+        load_store(out, r_out, static_cast<int>(i_start), 0);
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           r_in[ii] = 0;
-          int i = i_start + threadIdx.x + ii * blockDim.x;
+          index_t i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             r_in[ii] = in[i];
             float next = static_cast<float>(in[i]);
@@ -110,18 +110,18 @@ struct L2NormScaleFunctor {
   }
 };
 // Probably better to template, but since we are not likely to support other norm
-template <typename x_t>
+template <typename x_t, typename index_t>
 struct MaxNormFunctor {
-  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
                                              float* output, float* output_per_tensor, bool per_tensor,
                                              int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
 
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    index_t tensor_loc = static_cast<index_t>(tl.block_to_tensor[blockIdx.x]);
+    index_t chunk_idx = static_cast<index_t>(tl.block_to_chunk[blockIdx.x]);
+    index_t n = tl.sizes[tensor_loc];
 
     x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx * chunk_size;
@@ -139,9 +139,9 @@ struct MaxNormFunctor {
 
     // to make things simple, we put aligned case in a different code path
     if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
-      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
+      for (index_t i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
         // load
-        load_store(r_x, x, 0, i_start);
+        load_store(r_x, x, 0, static_cast<int>(i_start));
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           float next = static_cast<float>(r_x[ii]);
@@ -149,10 +149,10 @@ struct MaxNormFunctor {
         }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          int i = i_start + threadIdx.x + ii * blockDim.x;
+          index_t i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             float next = static_cast<float>(x[i]);
             vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
@@ -204,6 +204,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(int chunk_size
                                                                   std::vector<std::vector<at::Tensor>> tensor_lists,
                                                                   float scale, at::optional<bool> per_tensor_python) {
   bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
+  bool requires_64bit_indexing = tensor_lists_require_64bit_indexing(tensor_lists);
 
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
   auto output = at::zeros({320}, float_options);
@@ -229,10 +230,17 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(int chunk_size
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_scale_cuda",
      DISPATCH_FLOAT_AND_HALF(
         tensor_lists[1][0].scalar_type(), 1, "multi_tensor_l2norm_scale_cuda",
-        multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                              L2NormScaleFunctor<scalar_t_0, scalar_t_1>(), output.data_ptr<float>(),
-                              per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, scale, per_tensor,
-                              max_chunks_per_tensor);))
+        if (requires_64bit_indexing) {
+          multi_tensor_apply<2>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                                L2NormScaleFunctor<scalar_t_0, scalar_t_1, int64_t>(), output.data_ptr<float>(),
+                                per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, scale, per_tensor,
+                                max_chunks_per_tensor);
+        } else {
+          multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                                L2NormScaleFunctor<scalar_t_0, scalar_t_1, int32_t>(), output.data_ptr<float>(),
+                                per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, scale, per_tensor,
+                                max_chunks_per_tensor);
+        }))
 
   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());
diff --git a/tests/L0/run_optimizers/test_large_tensor_l2norm.py b/tests/L0/run_optimizers/test_large_tensor_l2norm.py
new file mode 100644
index 000000000..02eb7f8da
--- /dev/null
+++ b/tests/L0/run_optimizers/test_large_tensor_l2norm.py
@@ -0,0 +1,80 @@
+import unittest
+
+import torch
+
+try:
+    import amp_C
+    from apex.multi_tensor_apply import multi_tensor_applier
+except ImportError:
+    HAS_APEX = False
+else:
+    HAS_APEX = True
+
+from torch.testing._internal.common_device_type import largeTensorTest
+
+INT32_MAX = 2_147_483_647
+LARGE_NUMEL = INT32_MAX + 1
+
+
+@unittest.skipIf(not HAS_APEX, "`apex` is not found.")
+class LargeTensorL2NormTest(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.noop_flag = torch.zeros([1], dtype=torch.int32, device="cuda")
+
+    def _make_large_tensor(self, dtype=torch.float16):
+        tensor = torch.zeros(LARGE_NUMEL, dtype=dtype, device="cuda")
+        tensor[0] = 3
+        tensor[-1] = 4
+        return tensor
+
+    @largeTensorTest("5GB", "cuda")
+    def test_multi_tensor_l2norm_large_tensor(self):
+        tensor = self._make_large_tensor(torch.float16)
+
+        expected = torch.norm(tensor, 2.0).float().unsqueeze(0)
+        actual, _ = multi_tensor_applier(
+            amp_C.multi_tensor_l2norm,
+            self.noop_flag,
+            [[tensor]],
+            False,
+        )
+
+        torch.testing.assert_close(actual, expected)
+
+    @largeTensorTest("5GB", "cuda")
+    def test_multi_tensor_l2norm_mp_large_tensor(self):
+        tensor = self._make_large_tensor(torch.float16)
+
+        expected = torch.norm(tensor, 2.0).float().unsqueeze(0)
+        actual, _ = multi_tensor_applier(
+            amp_C.multi_tensor_l2norm_mp,
+            self.noop_flag,
+            [[tensor]],
+            False,
+        )
+
+        torch.testing.assert_close(actual, expected)
+
+    @largeTensorTest("9GB", "cuda")
+    def test_multi_tensor_l2norm_scale_large_tensor(self):
+        tensor = self._make_large_tensor(torch.float16)
+        scaled = torch.empty_like(tensor)
+        scale = 0.5
+
+        expected = torch.norm(tensor, 2.0).float().unsqueeze(0)
+        actual, _ = multi_tensor_applier(
+            amp_C.multi_tensor_l2norm_scale,
+            self.noop_flag,
+            [[tensor], [scaled]],
+            scale,
+            False,
+        )
+
+        torch.testing.assert_close(actual, expected)
+        self.assertEqual(scaled[0].item(), 1.5)
+        self.assertEqual(scaled[-1].item(), 2.0)
+
+
+if __name__ == "__main__":
+    unittest.main()
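
Reviewer note (illustration only, not part of the patch): every dispatch site in this change follows the same host-side pattern. The tensor lists are scanned once, and if any tensor holds more than INT_MAX elements, BLOCK_SIZE and chunk_size are promoted to int64_t and the functor is instantiated with a 64-bit index_t; otherwise the existing 32-bit instantiation is kept. The standalone sketch below shows only that selection logic with a deliberately simplified kernel: needs_64bit_indexing mirrors the tensor_lists_require_64bit_indexing helper added in multi_tensor_apply.cuh, while square_sum_kernel and square_sum are hypothetical stand-ins for the real functors and multi_tensor_apply, assuming a single float tensor as input.

// index_dispatch_sketch.cu -- minimal, self-contained illustration of the
// 32-bit vs. 64-bit index dispatch; names are hypothetical stand-ins.
#include <climits>
#include <cstdint>
#include <vector>

#include <ATen/ATen.h>

// Toy kernel templated on the index type, mirroring how the patch threads
// index_t through the functors: any value that can exceed INT_MAX (element
// counts, chunk offsets, loop counters) is computed in index_t.
template <typename scalar_t, typename index_t>
__global__ void square_sum_kernel(const scalar_t* x, index_t n, float* out) {
  const index_t i =
      static_cast<index_t>(blockIdx.x) * static_cast<index_t>(blockDim.x) + static_cast<index_t>(threadIdx.x);
  if (i < n) {
    const float v = static_cast<float>(x[i]);
    atomicAdd(out, v * v);
  }
}

// Same check the patch adds to multi_tensor_apply.cuh: any tensor with more
// than INT_MAX elements forces the 64-bit instantiation.
inline bool needs_64bit_indexing(const std::vector<std::vector<at::Tensor>>& tensor_lists) {
  for (const auto& list : tensor_lists) {
    for (const auto& t : list) {
      if (t.numel() > INT_MAX) return true;
    }
  }
  return false;
}

// Host-side selection, mirroring the if/else the patch wraps around each
// multi_tensor_apply<1>(...) call site: 64-bit indexing only when needed.
void square_sum(const at::Tensor& x, at::Tensor& out) {
  const int threads = 512;
  const int64_t n = x.numel();
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  if (needs_64bit_indexing({{x}})) {
    square_sum_kernel<float, int64_t><<<blocks, threads>>>(x.data_ptr<float>(), n, out.data_ptr<float>());
  } else {
    square_sum_kernel<float, int32_t><<<blocks, threads>>>(x.data_ptr<float>(), static_cast<int32_t>(n),
                                                           out.data_ptr<float>());
  }
}

Keeping the 32-bit instantiation as the default means the common case pays nothing for this change; only tensor lists that actually contain a tensor with more than INT_MAX elements take the int64_t path, which is exactly what the new test exercises with its 2_147_483_648-element fp16 tensors.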