68 changes: 61 additions & 7 deletions csrc/multi_tensor_apply.cuh
@@ -1,14 +1,50 @@
#ifdef TORCH_STABLE_ONLY
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/types.h>

#include "stable_abi_utils.h"
#else
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#endif

#include <assert.h>
#include <vector>  // std::vector in multi_tensor_apply's signature

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

// Namespace aliases for dual-build support
#ifdef TORCH_STABLE_ONLY
namespace apex_tensor {
using Tensor = torch::stable::Tensor;
using MemoryFormat = apex::stable::MemoryFormat;
namespace device = torch::headeronly;

inline bool is_contiguous_any_format(const Tensor& t) {
return apex::stable::is_contiguous(t, MemoryFormat::Contiguous) ||
apex::stable::is_contiguous(t, MemoryFormat::ChannelsLast) ||
apex::stable::is_contiguous(t, MemoryFormat::ChannelsLast3d);
}
} // namespace apex_tensor
#else
namespace apex_tensor {
using Tensor = at::Tensor;
using MemoryFormat = at::MemoryFormat;
namespace device = at;

inline bool is_contiguous_any_format(const Tensor& t) {
return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast) ||
t.is_contiguous(at::MemoryFormat::ChannelsLast3d);
}
} // namespace apex_tensor
#endif
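With these aliases, code downstream of this header can be written once against apex_tensor:: and compile in either ABI mode. A minimal sketch (the check_inputs helper is an editor's illustration, not part of this diff):

// Illustrative only: apex_tensor::Tensor resolves to torch::stable::Tensor or
// at::Tensor depending on whether TORCH_STABLE_ONLY is defined.
inline void check_inputs(const std::vector<apex_tensor::Tensor>& tensors) {
  for (const auto& t : tensors) {
    assert(apex_tensor::is_contiguous_any_format(t));  // dense NCHW/NHWC/NDHWC only
  }
}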

// TODO: Kernel arg size limit may be <4KB for some other cards (e.g., Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
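The TODO above could be made explicit with a compile-time check; a sketch, assuming it is placed after the TensorListMetadata definition later in this header and that 4096 bytes is the applicable limit on this hardware:

// Sketch (would go after the TensorListMetadata<depth> definition below):
//   static_assert(sizeof(TensorListMetadata<6>) <= 4096,  // deepest depth used here
//                 "TensorListMetadata exceeds the 4 KB kernel-argument limit");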
@@ -30,21 +66,20 @@ __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists, T callable, ArgTypes... args) {
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const apex_tensor::Tensor& noop_flag,
const std::vector<std::vector<apex_tensor::Tensor>>& tensor_lists, T callable,
ArgTypes... args) {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
TORCH_CHECK(ref_device.type() == apex_tensor::device::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
bool contiguous_memory = apex_tensor::is_contiguous_any_format(tensor_lists[l][t]);
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
@@ -55,8 +90,22 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor

TensorListMetadata<depth> tl;

#ifdef TORCH_STABLE_ONLY
// Stable ABI: device guard and stream management
auto device = tensor_lists[0][0].device();
int32_t device_index = static_cast<int32_t>(device.index());

// Use stable ABI DeviceGuard for proper device context
torch::stable::accelerator::DeviceGuard device_guard(device_index);

// Get the current CUDA stream using the stable ABI C API; fail loudly rather
// than silently falling back to the default stream on error
void* stream_ptr = nullptr;
auto err = aoti_torch_get_current_cuda_stream(device_index, &stream_ptr);
STD_TORCH_CHECK(err == AOTI_TORCH_SUCCESS, "failed to query the current CUDA stream");
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
#else
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
#endif

tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
@@ -82,7 +131,12 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
    chunk_size, static_cast<int*>(noop_flag.data_ptr()), tl, callable, args...);

#ifdef TORCH_STABLE_ONLY
// STD_TORCH_CHECK is a macro, so it takes no namespace qualifier; launch_err
// avoids shadowing the `err` from the stream query above
cudaError_t launch_err = cudaGetLastError();
STD_TORCH_CHECK(launch_err == cudaSuccess, "CUDA kernel launch failed: %s", cudaGetErrorString(launch_err));
#else
AT_CUDA_CHECK(cudaGetLastError());
#endif

// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
258 changes: 258 additions & 0 deletions csrc/stable_abi_utils.h
@@ -0,0 +1,258 @@
#pragma once

#ifdef TORCH_STABLE_ONLY

// Stable ABI headers
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/dispatcher.h>
#include <torch/csrc/stable/ivalue.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/types.h>

// Standard library headers used below
#include <cstdio>     // snprintf
#include <optional>   // std::optional
#include <stdexcept>  // std::runtime_error
#include <vector>     // std::vector

namespace apex {
namespace stable {

// ============================================================================
// MemoryFormat Contiguity Checking Workaround
// ============================================================================
// The stable ABI's Tensor::is_contiguous() doesn't accept a MemoryFormat
// parameter. This provides a workaround for checking different memory layouts.

enum class MemoryFormat { Contiguous, ChannelsLast, ChannelsLast3d, Preserve };

// Check if a tensor is contiguous in a specific memory format. Note: this is
// a strict dense-stride check; unlike at::Tensor::is_contiguous(format), it
// may reject tensors with size-1 dimensions that ATen would still accept.
inline bool is_contiguous(const torch::stable::Tensor& tensor, MemoryFormat format) {
Collaborator comment: Given that torch::stable is still under development, wouldn't it perhaps be legitimate to wait until this gets implemented upstream?

using namespace torch::stable;

// For standard contiguous check, use the stable ABI method
if (format == MemoryFormat::Contiguous) {
return tensor.is_contiguous();
}

// For ChannelsLast and ChannelsLast3d, we need custom logic
// Get tensor properties
auto sizes = tensor.sizes();
auto strides = tensor.strides();
int64_t ndim = tensor.dim();

if (format == MemoryFormat::ChannelsLast) {
// NHWC format requires ndim == 4
if (ndim != 4) return false;

// For ChannelsLast (NHWC), a dense layout requires:
//   stride_c = 1, stride_w = C, stride_h = W * C, stride_n = H * W * C
// i.e. ascending strides: strides[1] < strides[3] < strides[2] < strides[0]
int64_t C = sizes[1], H = sizes[2], W = sizes[3];
int64_t stride_c = strides[1];
int64_t stride_w = strides[3];
int64_t stride_h = strides[2];
int64_t stride_n = strides[0];

// Check if strides match a dense NHWC layout
return (stride_c == 1) && (stride_w == C) && (stride_h == W * C) && (stride_n == H * W * C);
}

if (format == MemoryFormat::ChannelsLast3d) {
// NDHWC format requires ndim == 5
if (ndim != 5) return false;

// For ChannelsLast3d (NDHWC), a dense layout requires:
//   stride_c = 1, stride_w = C, stride_h = W * C, stride_d = H * W * C,
//   stride_n = D * H * W * C
int64_t C = sizes[1], D = sizes[2], H = sizes[3], W = sizes[4];
int64_t stride_c = strides[1];
int64_t stride_w = strides[4];
int64_t stride_h = strides[3];
int64_t stride_d = strides[2];
int64_t stride_n = strides[0];

// Check if strides match a dense NDHWC layout
return (stride_c == 1) && (stride_w == C) && (stride_h == W * C) && (stride_d == H * W * C) &&
       (stride_n == D * H * W * C);
}

if (format == MemoryFormat::Preserve) {
// Preserve means "keep current format" - not applicable for checking contiguity
return false;
}

return false;
}
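To make the stride arithmetic above concrete, a small worked example (sizes chosen purely for illustration):

// Worked example with sizes = {2, 3, 4, 5} as (N, C, H, W):
//   dense NHWC strides {60, 1, 15, 3} pass the ChannelsLast branch, since
//   stride_c = 1, stride_w = C = 3, stride_h = W*C = 15, stride_n = H*W*C = 60;
//   dense row-major strides {60, 20, 5, 1} pass only the Contiguous branch.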

// ============================================================================
// Type Conversion Utilities
// ============================================================================

// Convert stable ScalarType to string for error messages
inline const char* scalar_type_name(torch::headeronly::ScalarType type) {
using namespace torch::headeronly;
switch (type) {
case kByte:
return "Byte";
case kChar:
return "Char";
case kShort:
return "Short";
case kInt:
return "Int";
case kLong:
return "Long";
case kHalf:
return "Half";
case kFloat:
return "Float";
case kDouble:
return "Double";
case kBool:
return "Bool";
case kBFloat16:
return "BFloat16";
case kFloat8_e5m2:
return "Float8_e5m2";
case kFloat8_e4m3fn:
return "Float8_e4m3fn";
default:
return "Unknown";
}
}

// ============================================================================
// Error Checking Macros
// ============================================================================

// Guarded in case torch/headeronly already provides a macro of the same name.
#ifndef STD_TORCH_CHECK
#define STD_TORCH_CHECK(cond, ...) \
  do { \
    if (!(cond)) { \
      char buffer[1024]; \
      snprintf(buffer, sizeof(buffer), __VA_ARGS__); \
      throw std::runtime_error(buffer); \
    } \
  } while (0)
#endif

#define STD_TORCH_CHECK_EQ(a, b, ...) STD_TORCH_CHECK((a) == (b), __VA_ARGS__)
#define STD_TORCH_CHECK_NE(a, b, ...) STD_TORCH_CHECK((a) != (b), __VA_ARGS__)
#define STD_TORCH_CHECK_GT(a, b, ...) STD_TORCH_CHECK((a) > (b), __VA_ARGS__)
#define STD_TORCH_CHECK_GE(a, b, ...) STD_TORCH_CHECK((a) >= (b), __VA_ARGS__)
#define STD_TORCH_CHECK_LT(a, b, ...) STD_TORCH_CHECK((a) < (b), __VA_ARGS__)
#define STD_TORCH_CHECK_LE(a, b, ...) STD_TORCH_CHECK((a) <= (b), __VA_ARGS__)
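A brief usage sketch of the printf-style macros above (the tensor t and the messages are illustrative, not part of this diff):

// Usage sketch:
//   STD_TORCH_CHECK(t.dim() == 4, "expected a 4-D tensor, got %lld dims",
//                   static_cast<long long>(t.dim()));
//   STD_TORCH_CHECK_GE(chunk_size, 1, "chunk_size must be >= 1, got %lld",
//                      static_cast<long long>(chunk_size));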

// ============================================================================
// Boxed Calling Convention Helpers
// ============================================================================

// Helper to extract tensor from IValue stack
inline torch::stable::Tensor tensor_from_stack(torch::stable::StableIValue* stack, int idx) {
return stack[idx].toTensor();
}

// Helper to extract int64 from IValue stack
inline int64_t int64_from_stack(torch::stable::StableIValue* stack, int idx) { return stack[idx].toInt(); }

// Helper to extract double from IValue stack
inline double double_from_stack(torch::stable::StableIValue* stack, int idx) { return stack[idx].toDouble(); }

// Helper to extract bool from IValue stack
inline bool bool_from_stack(torch::stable::StableIValue* stack, int idx) { return stack[idx].toBool(); }

// Helper to extract optional tensor from IValue stack
inline std::optional<torch::stable::Tensor> optional_tensor_from_stack(torch::stable::StableIValue* stack, int idx) {
if (stack[idx].isNone()) {
return std::nullopt;
}
return stack[idx].toTensor();
}

// Helper to extract tensor list from IValue stack
inline std::vector<torch::stable::Tensor> tensor_list_from_stack(torch::stable::StableIValue* stack, int idx) {
auto list = stack[idx].toList();
std::vector<torch::stable::Tensor> result;
result.reserve(list.size());
for (size_t i = 0; i < list.size(); ++i) {
result.push_back(list.get(i).toTensor());
}
return result;
}

// Helper to put tensor to IValue stack
inline void tensor_to_stack(torch::stable::StableIValue* stack, int idx, const torch::stable::Tensor& tensor) {
stack[idx] = torch::stable::StableIValue::from(tensor);
}

// Helper to put tuple to IValue stack
inline void tuple_to_stack(torch::stable::StableIValue* stack, int idx,
const std::vector<torch::stable::Tensor>& tensors) {
std::vector<torch::stable::StableIValue> ivalues;
ivalues.reserve(tensors.size());
for (const auto& t : tensors) {
ivalues.push_back(torch::stable::StableIValue::from(t));
}
stack[idx] = torch::stable::StableIValue::fromTuple(ivalues);
}

// Helper to put list to IValue stack
inline void tensor_list_to_stack(torch::stable::StableIValue* stack, int idx,
const std::vector<torch::stable::Tensor>& tensors) {
std::vector<torch::stable::StableIValue> ivalues;
ivalues.reserve(tensors.size());
for (const auto& t : tensors) {
ivalues.push_back(torch::stable::StableIValue::from(t));
}
stack[idx] = torch::stable::StableIValue::fromList(ivalues);
}
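A minimal sketch of how these stack helpers could be composed into a boxed wrapper; boxed_my_op, its argument layout, and my_op_impl are hypothetical and not part of this diff:

// Hypothetical unboxed implementation, assumed to be defined elsewhere.
torch::stable::Tensor my_op_impl(const torch::stable::Tensor& noop_flag,
                                 const std::vector<torch::stable::Tensor>& tensors, double scale);

// Boxed wrapper for a schema like: my_op(Tensor, Tensor[], float) -> Tensor
inline void boxed_my_op(torch::stable::StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  (void)num_args;
  (void)num_outputs;  // unused in this sketch
  auto noop_flag = tensor_from_stack(stack, 0);
  auto tensors = tensor_list_from_stack(stack, 1);
  double scale = double_from_stack(stack, 2);
  tensor_to_stack(stack, 0, my_op_impl(noop_flag, tensors, scale));  // output replaces stack[0]
}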

// ============================================================================
// Device and Stream Utilities
// ============================================================================

// Check if tensor is on CUDA
inline bool is_cuda(const torch::stable::Tensor& tensor) { return tensor.device().type() == torch::headeronly::kCUDA; }

// Get CUDA device index
inline int64_t get_device_index(const torch::stable::Tensor& tensor) {
STD_TORCH_CHECK(is_cuda(tensor), "Tensor must be on CUDA device");
return tensor.device().index();
}

// ============================================================================
// Common Tensor Checks
// ============================================================================

inline void check_cuda(const torch::stable::Tensor& tensor, const char* name) {
STD_TORCH_CHECK(is_cuda(tensor), "%s must be a CUDA tensor", name);
}

inline void check_contiguous(const torch::stable::Tensor& tensor, const char* name) {
STD_TORCH_CHECK(tensor.is_contiguous(), "%s must be contiguous", name);
}

inline void check_same_device(const torch::stable::Tensor& t1, const torch::stable::Tensor& t2, const char* name1,
const char* name2) {
STD_TORCH_CHECK(t1.device() == t2.device(), "%s and %s must be on the same device", name1, name2);
}

inline void check_same_dtype(const torch::stable::Tensor& t1, const torch::stable::Tensor& t2, const char* name1,
const char* name2) {
STD_TORCH_CHECK(t1.scalar_type() == t2.scalar_type(), "%s and %s must have the same dtype, got %s and %s", name1,
name2, scalar_type_name(t1.scalar_type()), scalar_type_name(t2.scalar_type()));
}

} // namespace stable
} // namespace apex

#else // !TORCH_STABLE_ONLY

// When not using stable ABI, provide no-op definitions or traditional includes
#include <ATen/ATen.h>
#include <torch/extension.h>

namespace apex {
namespace stable {

// Map to traditional PyTorch MemoryFormat for non-stable builds
using MemoryFormat = at::MemoryFormat;

// Use traditional is_contiguous in non-stable builds
inline bool is_contiguous(const at::Tensor& tensor, at::MemoryFormat format) { return tensor.is_contiguous(format); }

} // namespace stable
} // namespace apex

#endif // TORCH_STABLE_ONLY