13 changes: 13 additions & 0 deletions csrc/multi_tensor_apply.cuh
@@ -5,6 +5,8 @@
 #include <assert.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#include <climits>
+
 // #include <iostream>
 
 // This header is the one-stop shop for all your multi-tensor apply needs.
@@ -22,6 +24,17 @@ struct TensorListMetadata {
   int start_tensor_this_launch;
 };
 
+inline bool tensor_lists_require_64bit_indexing(const std::vector<std::vector<at::Tensor>>& tensor_lists) {
+  for (const auto& tensor_list : tensor_lists) {
+    for (const auto& tensor : tensor_list) {
+      if (tensor.numel() > INT_MAX) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 template <typename T, typename U, typename... ArgTypes>
 __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop_flag, T tl, U callable,
                                           ArgTypes... args) {
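Reviewer note on the new helper: a tensor whose `numel()` exceeds `INT_MAX` (2^31 − 1) cannot be addressed with 32-bit indices, which is why the kernels below gain an `index_t` template parameter. The overflow happens in offset arithmetic such as `x += chunk_idx * chunk_size` before any element is touched. A minimal host-side sketch of the failure mode (hypothetical chunk geometry; signed overflow is formally undefined behavior, in practice it wraps):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  int chunk_size = 65536;  // elements per chunk (hypothetical)
  int chunk_idx = 40000;   // this chunk starts beyond the 2^31 element mark

  // 40000 * 65536 = 2,621,440,000 > INT_MAX: the 32-bit product wraps
  // before it can be added to the base pointer.
  int64_t wrapped = static_cast<int64_t>(chunk_idx * chunk_size);

  // Widening either operand first keeps the offset exact.
  int64_t exact = static_cast<int64_t>(chunk_idx) * chunk_size;

  std::printf("32-bit offset: %lld\n64-bit offset: %lld\n",
              static_cast<long long>(wrapped), static_cast<long long>(exact));
  return 0;
}
```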
112 changes: 71 additions & 41 deletions csrc/multi_tensor_l2norm_kernel.cu
@@ -25,18 +25,18 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int s
   ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
 }
 
-template <typename x_t>
+template <typename x_t, typename index_t>
 struct L2NormFunctor {
-  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
                                              float* output, float* output_per_tensor, bool per_tensor,
                                              int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
 
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    index_t tensor_loc = static_cast<index_t>(tl.block_to_tensor[blockIdx.x]);
+    index_t chunk_idx = static_cast<index_t>(tl.block_to_chunk[blockIdx.x]);
+    index_t n = tl.sizes[tensor_loc];
 
     x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx * chunk_size;
@@ -54,20 +54,20 @@ struct L2NormFunctor {
 
     // to make things simple, we put aligned case in a different code path
     if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
-      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
+      for (index_t i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
         // load
-        load_store(r_x, x, 0, i_start);
+        load_store(r_x, x, 0, static_cast<int>(i_start));
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           float next = static_cast<float>(r_x[ii]);
           vals[ii] += next * next;
         }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          int i = i_start + threadIdx.x + ii * blockDim.x;
+          index_t i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             float next = static_cast<float>(x[i]);
             vals[ii] += next * next;
@@ -90,18 +90,18 @@ }
   }
 };
 
-template <typename x_t>
+template <typename x_t, typename index_t>
 struct UnscaleL2NormFunctor {
-  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
                                              const float* inv_scale, float* output, float* output_per_tensor,
                                              bool per_tensor, int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
 
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    index_t tensor_loc = static_cast<index_t>(tl.block_to_tensor[blockIdx.x]);
+    index_t chunk_idx = static_cast<index_t>(tl.block_to_chunk[blockIdx.x]);
+    index_t n = tl.sizes[tensor_loc];
 
     x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx * chunk_size;
@@ -119,20 +119,20 @@ struct UnscaleL2NormFunctor {
 
     // to make things simple, we put aligned case in a different code path
     if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
-      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
+      for (index_t i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
         // load
-        load_store(r_x, x, 0, i_start);
+        load_store(r_x, x, 0, static_cast<int>(i_start));
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           float next = static_cast<float>(r_x[ii]) * (*inv_scale);
           vals[ii] += next * next;
         }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          int i = i_start + threadIdx.x + ii * blockDim.x;
+          index_t i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             float next = static_cast<float>(x[i]) * (*inv_scale);
             vals[ii] += next * next;
@@ -156,18 +156,18 @@ struct UnscaleL2NormFunctor {
 };
 
 // Probably better to template, but since we are not likely to support other norm
-template <typename x_t>
+template <typename x_t, typename index_t>
 struct MaxNormFunctor {
-  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
                                              float* output, float* output_per_tensor, bool per_tensor,
                                              int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
 
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    index_t tensor_loc = static_cast<index_t>(tl.block_to_tensor[blockIdx.x]);
+    index_t chunk_idx = static_cast<index_t>(tl.block_to_chunk[blockIdx.x]);
+    index_t n = tl.sizes[tensor_loc];
 
     x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx * chunk_size;
@@ -185,20 +185,20 @@ struct MaxNormFunctor {
 
     // to make things simple, we put aligned case in a different code path
     if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
-      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
+      for (index_t i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
         // load
-        load_store(r_x, x, 0, i_start);
+        load_store(r_x, x, 0, static_cast<int>(i_start));
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           float next = static_cast<float>(r_x[ii]);
           vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
         }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          int i = i_start + threadIdx.x + ii * blockDim.x;
+          index_t i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             float next = static_cast<float>(x[i]);
             vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
@@ -291,6 +291,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(int chunk_size, at::
                                                             std::vector<std::vector<at::Tensor>> tensor_lists,
                                                             at::optional<bool> per_tensor_python) {
   bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
+  bool requires_64bit_indexing = tensor_lists_require_64bit_indexing(tensor_lists);
 
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
   auto output = at::zeros({320}, float_options);
@@ -314,9 +315,16 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(int chunk_size, at::
 
   DISPATCH_FLOAT_HALF_AND_BFLOAT(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
-      multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0>(),
-                            output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
-                            per_tensor, max_chunks_per_tensor);)
+      if (requires_64bit_indexing) {
+        multi_tensor_apply<1>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                              L2NormFunctor<scalar_t_0, int64_t>(), output.data_ptr<float>(),
+                              per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
+                              max_chunks_per_tensor);
+      } else {
+        multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0, int32_t>(),
+                              output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
+                              per_tensor, max_chunks_per_tensor);
+      })
 
   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());
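Note on the shape of this change: the `if (requires_64bit_indexing)` branch is spliced into the dispatch macro's variadic argument, so each scalar type instantiates both the `int32_t` and `int64_t` functor specializations and selects one at runtime. A rough sketch of the expansion pattern such a macro follows (an illustrative assumption, not the actual definition from apex's type_shim header):

```cpp
// Hedged sketch: bind the runtime dtype to a compile-time alias, then
// splice the caller's whole if/else back in via __VA_ARGS__ (the commas
// inside the multi_tensor_apply calls are re-joined by the variadic
// expansion, which is why the branch can live inside the macro call).
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_SKETCH(TYPE, LEVEL, NAME, ...) \
  switch (TYPE) {                                                     \
    case at::ScalarType::Float: {                                     \
      using scalar_t_##LEVEL = float;                                 \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::Half: {                                      \
      using scalar_t_##LEVEL = at::Half;                              \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::BFloat16: {                                  \
      using scalar_t_##LEVEL = at::BFloat16;                          \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    default:                                                          \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
```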
Expand All @@ -339,6 +347,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(int chunk_si
at::Tensor inv_scale,
at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
bool requires_64bit_indexing = tensor_lists_require_64bit_indexing(tensor_lists);

auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
@@ -362,10 +371,17 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(int chunk_si
 
   DISPATCH_FLOAT_HALF_AND_BFLOAT(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
-      multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<scalar_t_0>(),
-                            inv_scale.data_ptr<float>(), output.data_ptr<float>(),
-                            per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
-                            max_chunks_per_tensor);)
+      if (requires_64bit_indexing) {
+        multi_tensor_apply<1>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                              UnscaleL2NormFunctor<scalar_t_0, int64_t>(), inv_scale.data_ptr<float>(),
+                              output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
+                              per_tensor, max_chunks_per_tensor);
+      } else {
+        multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                              UnscaleL2NormFunctor<scalar_t_0, int32_t>(), inv_scale.data_ptr<float>(),
+                              output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
+                              per_tensor, max_chunks_per_tensor);
+      })
 
   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());
@@ -390,6 +406,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(int chunk_si
 void multi_tensor_norm_out_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
                                 at::Tensor out, const float alpha, const float beta, const int norm_type) {
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
+  bool requires_64bit_indexing = tensor_lists_require_64bit_indexing(tensor_lists);
   TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");
   // we don't need global thus uses empty here
   auto output = at::empty({320}, float_options);
@@ -410,16 +427,29 @@ void multi_tensor_norm_out_cuda(int chunk_size, at::Tensor noop_flag, std::vecto
     output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
 
   if (norm_type == 0) {
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
-                            multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                                  MaxNormFunctor<scalar_t_0>(), output.data_ptr<float>(),
-                                                  output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);)
+    DISPATCH_FLOAT_AND_HALF(
+        tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
+        if (requires_64bit_indexing) {
+          multi_tensor_apply<1>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                                MaxNormFunctor<scalar_t_0, int64_t>(), output.data_ptr<float>(),
+                                output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);
+        } else {
+          multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, MaxNormFunctor<scalar_t_0, int32_t>(),
+                                output.data_ptr<float>(), output_per_tensor.data_ptr<float>(), true,
+                                max_chunks_per_tensor);
+        })
   } else {
     DISPATCH_FLOAT_HALF_AND_BFLOAT(
         tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
-        multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0>(),
-                              output.data_ptr<float>(), output_per_tensor.data_ptr<float>(), true,
-                              max_chunks_per_tensor);)
+        if (requires_64bit_indexing) {
+          multi_tensor_apply<1>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                                L2NormFunctor<scalar_t_0, int64_t>(), output.data_ptr<float>(),
+                                output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);
+        } else {
+          multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0, int32_t>(),
+                                output.data_ptr<float>(), output_per_tensor.data_ptr<float>(), true,
+                                max_chunks_per_tensor);
+        })
   }
   AT_CUDA_CHECK(cudaGetLastError());
 
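One detail worth calling out: the aligned fast path keeps `load_store(r_x, x, 0, static_cast<int>(i_start))` on a 32-bit offset even in the `int64_t` instantiation. That narrowing looks safe because `i_start` counts ILP-wide vectors within a single chunk, and the chunk sizes used in practice are small constants far below `INT_MAX`; only tensor-level offsets need 64 bits. A standalone host-side sketch of the vectorized-copy trick the kernels rely on (a stand-in for the real `load_store` above; ILP = 4 and the aggregate type are assumptions for illustration):

```cpp
#include <cstdio>

constexpr int ILP = 4;

// Stand-in for the kernel's helper: reinterpret both pointers as 4-wide
// aggregates so one assignment moves ILP scalars (on the GPU this compiles
// to a single 128-bit load/store for float).
template <typename T>
struct alignas(sizeof(T) * ILP) VecT {
  T v[ILP];
};

template <typename T>
void load_store(T* dst, T* src, int dst_offset, int src_offset) {
  using LT = VecT<T>;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

int main() {
  alignas(sizeof(float) * ILP) float src[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  alignas(sizeof(float) * ILP) float dst[ILP] = {};
  load_store(dst, src, 0, 1);  // copies src[4..7] in one aggregate move
  std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);
  return 0;
}
```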
32 changes: 20 additions & 12 deletions csrc/multi_tensor_l2norm_kernel_mp.cu
@@ -25,18 +25,18 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int s
   ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
 }
 
-template <typename x_t>
+template <typename x_t, typename index_t>
 struct L2NormFunctor {
-  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,
                                              float* output, float* output_per_tensor, bool per_tensor,
                                              int max_chunks_per_tensor) {
     if (*noop_gmem) {
       return;
     }
 
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    index_t tensor_loc = static_cast<index_t>(tl.block_to_tensor[blockIdx.x]);
+    index_t chunk_idx = static_cast<index_t>(tl.block_to_chunk[blockIdx.x]);
+    index_t n = tl.sizes[tensor_loc];
 
     x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx * chunk_size;
@@ -54,20 +54,20 @@ struct L2NormFunctor {
 
     // to make things simple, we put aligned case in a different code path
    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
-      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
+      for (index_t i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
         // load
-        load_store(r_x, x, 0, i_start);
+        load_store(r_x, x, 0, static_cast<int>(i_start));
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
           float next = static_cast<float>(r_x[ii]);
           vals[ii] += next * next;
         }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          int i = i_start + threadIdx.x + ii * blockDim.x;
+          index_t i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
             float next = static_cast<float>(x[i]);
             vals[ii] += next * next;
@@ -122,6 +122,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(int chunk_size, a
                                                                std::vector<std::vector<at::Tensor>> tensor_lists,
                                                                at::optional<bool> per_tensor_python) {
   bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
+  bool requires_64bit_indexing = tensor_lists_require_64bit_indexing(tensor_lists);
 
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
   auto output = at::zeros({320}, float_options);
@@ -145,9 +146,16 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(int chunk_size, a
 
   DISPATCH_FLOAT_HALF_AND_BFLOAT(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_mp_cuda",
-      multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0>(),
-                            output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
-                            per_tensor, max_chunks_per_tensor);)
+      if (requires_64bit_indexing) {
+        multi_tensor_apply<1>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                              L2NormFunctor<scalar_t_0, int64_t>(), output.data_ptr<float>(),
+                              per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
+                              max_chunks_per_tensor);
+      } else {
+        multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0, int32_t>(),
+                              output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
+                              per_tensor, max_chunks_per_tensor);
+      })
 
   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());
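End to end, the dispatch change only selects which functor instantiation runs; the C++ entry points keep their signatures. A hedged usage sketch (assumes a libtorch build with these extension functions linked in and enough GPU memory for a tensor of more than 2^31 half-precision elements; the declaration mirrors the definition above):

```cpp
#include <torch/torch.h>

#include <cstdio>
#include <tuple>
#include <vector>

// Declared here for the sketch; defined in multi_tensor_l2norm_kernel.cu.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

int main() {
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);

  // numel > INT_MAX, so tensor_lists_require_64bit_indexing returns true
  // and the int64_t specialization is launched for the whole list.
  auto big = torch::ones({(int64_t{1} << 31) + 1}, opts);
  auto small = torch::ones({1024}, opts);
  auto noop = torch::zeros({1}, torch::dtype(torch::kInt).device(torch::kCUDA));

  auto result = multi_tensor_l2norm_cuda(2048 * 32, noop, {{big, small}}, false);
  std::printf("global L2 norm: %f\n", std::get<0>(result).item<float>());
  return 0;
}
```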