diff --git a/csrc/multi_tensor_apply.cuh b/csrc/multi_tensor_apply.cuh
index 4a4795caf..d431f4c50 100644
--- a/csrc/multi_tensor_apply.cuh
+++ b/csrc/multi_tensor_apply.cuh
@@ -1,14 +1,50 @@
+#ifdef TORCH_STABLE_ONLY
+#include <torch/csrc/stable/tensor.h>
+#include <torch/csrc/stable/accelerator.h>
+#include <torch/headeronly/core/DeviceType.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+
+#include "stable_abi_utils.h"
+#else
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
-#include <assert.h>
 #include <c10/cuda/CUDAGuard.h>
+#endif
+
+#include <assert.h>
 
 // #include <iostream>
 
 // This header is the one-stop shop for all your multi-tensor apply needs.
 
+// Namespace aliases for dual-build support
+#ifdef TORCH_STABLE_ONLY
+namespace apex_tensor {
+using Tensor = torch::stable::Tensor;
+using MemoryFormat = apex::stable::MemoryFormat;
+namespace device = torch::headeronly;
+
+inline bool is_contiguous_any_format(const Tensor& t) {
+  return apex::stable::is_contiguous(t, MemoryFormat::Contiguous) ||
+         apex::stable::is_contiguous(t, MemoryFormat::ChannelsLast) ||
+         apex::stable::is_contiguous(t, MemoryFormat::ChannelsLast3d);
+}
+}  // namespace apex_tensor
+#else
+namespace apex_tensor {
+using Tensor = at::Tensor;
+using MemoryFormat = at::MemoryFormat;
+namespace device = at;
+
+inline bool is_contiguous_any_format(const Tensor& t) {
+  return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+         t.is_contiguous(at::MemoryFormat::ChannelsLast3d);
+}
+}  // namespace apex_tensor
+#endif
+
 // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
 constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
 constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
@@ -30,21 +66,20 @@ __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop
 }
 
 template <int depth, typename T, typename... ArgTypes>
-void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor& noop_flag,
-                        const std::vector<std::vector<at::Tensor>>& tensor_lists, T callable, ArgTypes... args) {
+void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const apex_tensor::Tensor& noop_flag,
+                        const std::vector<std::vector<apex_tensor::Tensor>>& tensor_lists, T callable,
+                        ArgTypes... args) {
   TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
   int len0 = tensor_lists[0].size();
   TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
   auto ref_device = tensor_lists[0][0].device();
-  TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
+  TORCH_CHECK(ref_device.type() == apex_tensor::device::kCUDA, "expected input to be on cuda");
   for (int l = 0; l < tensor_lists.size(); l++)  // No range-based for because I need indices
   {
     TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
     for (int t = 0; t < tensor_lists[l].size(); t++) {
       // TODO: Print which tensor fails.
-      bool contiguous_memory = tensor_lists[l][t].is_contiguous();
-      contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||
-                           tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
+      bool contiguous_memory = apex_tensor::is_contiguous_any_format(tensor_lists[l][t]);
       TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
       TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
       TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
@@ -55,8 +90,22 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
 
   TensorListMetadata<depth> tl;
 
+#ifdef TORCH_STABLE_ONLY
+  // Stable ABI: device guard and stream management
+  auto device = tensor_lists[0][0].device();
+  int32_t device_index = static_cast<int32_t>(device.index());
+
+  // Use the stable ABI DeviceGuard for proper device context
+  torch::stable::accelerator::DeviceGuard device_guard(device_index);
+
+  // Get the current CUDA stream using the stable ABI C API
+  void* stream_ptr = nullptr;
+  auto err = aoti_torch_get_current_cuda_stream(device_index, &stream_ptr);
+  cudaStream_t stream = (err == AOTI_TORCH_SUCCESS) ? reinterpret_cast<cudaStream_t>(stream_ptr) : nullptr;
+#else
   const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
   auto stream = at::cuda::getCurrentCUDAStream();
+#endif
 
   tl.start_tensor_this_launch = 0;
   int loc_block_info = 0;
@@ -82,7 +131,12 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
       if (tensors_full || blocks_full || last_chunk) {
         multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(chunk_size, noop_flag.data_ptr<int>(), tl,
                                                                              callable, args...);
+#ifdef TORCH_STABLE_ONLY
+        cudaError_t err = cudaGetLastError();
+        STD_TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: %s", cudaGetErrorString(err));
+#else
         AT_CUDA_CHECK(cudaGetLastError());
+#endif
 
         // Reset.  The control flow possibilities here make my brain hurt.
         loc_block_info = 0;
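The `apex_tensor` alias layer above is the seam that lets frontend code compile unchanged against either backend. A minimal sketch of a consumer, assuming only the aliases defined in this header (the helper function itself is illustrative, not part of the patch):

```cpp
#include <vector>

#include "multi_tensor_apply.cuh"

// Returns true only if every tensor in every list is contiguous in one of the
// supported layouts (standard, NHWC, or NDHWC). The same code compiles against
// torch::stable::Tensor or at::Tensor depending on TORCH_STABLE_ONLY.
inline bool all_contiguous(const std::vector<std::vector<apex_tensor::Tensor>>& tensor_lists) {
  for (const auto& list : tensor_lists) {
    for (const auto& t : list) {
      if (!apex_tensor::is_contiguous_any_format(t)) return false;
    }
  }
  return true;
}
```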
diff --git a/csrc/stable_abi_utils.h b/csrc/stable_abi_utils.h
new file mode 100644
index 000000000..1b7470967
--- /dev/null
+++ b/csrc/stable_abi_utils.h
@@ -0,0 +1,258 @@
+#pragma once
+
+#ifdef TORCH_STABLE_ONLY
+
+// Stable ABI headers
+#include <torch/csrc/stable/tensor.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/headeronly/core/ScalarType.h>
+#include <cstdio>
+#include <stdexcept>
+#include <optional>
+#include <vector>
+
+namespace apex {
+namespace stable {
+
+// ============================================================================
+// MemoryFormat Contiguity Checking Workaround
+// ============================================================================
+// The stable ABI's Tensor::is_contiguous() doesn't take a MemoryFormat
+// parameter. This provides a workaround for checking different memory layouts.
+
+enum class MemoryFormat { Contiguous, ChannelsLast, ChannelsLast3d, Preserve };
+
+// Check if a tensor is contiguous in a specific memory format
+inline bool is_contiguous(const torch::stable::Tensor& tensor, MemoryFormat format) {
+  using namespace torch::stable;
+
+  // For the standard contiguity check, use the stable ABI method
+  if (format == MemoryFormat::Contiguous) {
+    return tensor.is_contiguous();
+  }
+
+  // For ChannelsLast and ChannelsLast3d we need custom logic.
+  // Get tensor properties
+  auto sizes = tensor.sizes();
+  auto strides = tensor.strides();
+  int64_t ndim = tensor.dim();
+
+  if (format == MemoryFormat::ChannelsLast) {
+    // NHWC format requires ndim == 4
+    if (ndim != 4) return false;
+
+    // For ChannelsLast (NHWC), the strides must satisfy:
+    // stride_c == 1, stride_w == C, stride_h == W * C, stride_n == H * W * C,
+    // i.e. stride order strides[1] < strides[3] < strides[2] < strides[0]
+    int64_t C = sizes[1], H = sizes[2], W = sizes[3];
+    int64_t stride_c = strides[1];
+    int64_t stride_w = strides[3];
+    int64_t stride_h = strides[2];
+    int64_t stride_n = strides[0];
+
+    // Check if the strides match an NHWC layout
+    return (stride_c == 1) && (stride_w == C) && (stride_h == W * C) && (stride_n == H * W * C);
+  }
+
+  if (format == MemoryFormat::ChannelsLast3d) {
+    // NDHWC format requires ndim == 5
+    if (ndim != 5) return false;
+
+    // For ChannelsLast3d (NDHWC), similar logic for 5D tensors
+    int64_t C = sizes[1], D = sizes[2], H = sizes[3], W = sizes[4];
+    int64_t stride_c = strides[1];
+    int64_t stride_w = strides[4];
+    int64_t stride_h = strides[3];
+    int64_t stride_d = strides[2];
+    int64_t stride_n = strides[0];
+
+    // Check if the strides match an NDHWC layout
+    return (stride_c == 1) && (stride_w == C) && (stride_h == W * C) && (stride_d == H * W * C) &&
+           (stride_n == D * H * W * C);
+  }
+
+  if (format == MemoryFormat::Preserve) {
+    // Preserve means "keep the current format" - not applicable when checking contiguity
+    return false;
+  }
+
+  return false;
+}
+
+// ============================================================================
+// Type Conversion Utilities
+// ============================================================================
+
+// Convert a stable ScalarType to a string for error messages
+inline const char* scalar_type_name(torch::headeronly::ScalarType type) {
+  using namespace torch::headeronly;
+  switch (type) {
+    case kByte:
+      return "Byte";
+    case kChar:
+      return "Char";
+    case kShort:
+      return "Short";
+    case kInt:
+      return "Int";
+    case kLong:
+      return "Long";
+    case kHalf:
+      return "Half";
+    case kFloat:
+      return "Float";
+    case kDouble:
+      return "Double";
+    case kBool:
+      return "Bool";
+    case kBFloat16:
+      return "BFloat16";
+    case kFloat8_e5m2:
+      return "Float8_e5m2";
+    case kFloat8_e4m3fn:
+      return "Float8_e4m3fn";
+    default:
+      return "Unknown";
+  }
+}
+
+// ============================================================================
+// Error Checking Macros
+// ============================================================================
+
+#define STD_TORCH_CHECK(cond, ...)                   \
+  do {                                               \
+    if (!(cond)) {                                   \
+      char buffer[1024];                             \
+      snprintf(buffer, sizeof(buffer), __VA_ARGS__); \
+      throw std::runtime_error(buffer);              \
+    }                                                \
+  } while (0)
+
+#define STD_TORCH_CHECK_EQ(a, b, ...) STD_TORCH_CHECK((a) == (b), __VA_ARGS__)
+#define STD_TORCH_CHECK_NE(a, b, ...) STD_TORCH_CHECK((a) != (b), __VA_ARGS__)
+#define STD_TORCH_CHECK_GT(a, b, ...) STD_TORCH_CHECK((a) > (b), __VA_ARGS__)
+#define STD_TORCH_CHECK_GE(a, b, ...) STD_TORCH_CHECK((a) >= (b), __VA_ARGS__)
+#define STD_TORCH_CHECK_LT(a, b, ...) STD_TORCH_CHECK((a) < (b), __VA_ARGS__)
+#define STD_TORCH_CHECK_LE(a, b, ...) STD_TORCH_CHECK((a) <= (b), __VA_ARGS__)
+
+// ============================================================================
+// Boxed Calling Convention Helpers
+// ============================================================================
+
+// Helper to extract a tensor from the IValue stack
+inline torch::stable::Tensor tensor_from_stack(torch::stable::StableIValue* stack, int idx) {
+  return stack[idx].toTensor();
+}
+
+// Helper to extract an int64 from the IValue stack
+inline int64_t int64_from_stack(torch::stable::StableIValue* stack, int idx) { return stack[idx].toInt(); }
+
+// Helper to extract a double from the IValue stack
+inline double double_from_stack(torch::stable::StableIValue* stack, int idx) { return stack[idx].toDouble(); }
+
+// Helper to extract a bool from the IValue stack
+inline bool bool_from_stack(torch::stable::StableIValue* stack, int idx) { return stack[idx].toBool(); }
+
+// Helper to extract an optional tensor from the IValue stack
+inline std::optional<torch::stable::Tensor> optional_tensor_from_stack(torch::stable::StableIValue* stack, int idx) {
+  if (stack[idx].isNone()) {
+    return std::nullopt;
+  }
+  return stack[idx].toTensor();
+}
+
+// Helper to extract a tensor list from the IValue stack
+inline std::vector<torch::stable::Tensor> tensor_list_from_stack(torch::stable::StableIValue* stack, int idx) {
+  auto list = stack[idx].toList();
+  std::vector<torch::stable::Tensor> result;
+  result.reserve(list.size());
+  for (size_t i = 0; i < list.size(); ++i) {
+    result.push_back(list.get(i).toTensor());
+  }
+  return result;
+}
+
+// Helper to put a tensor onto the IValue stack
+inline void tensor_to_stack(torch::stable::StableIValue* stack, int idx, const torch::stable::Tensor& tensor) {
+  stack[idx] = torch::stable::StableIValue::from(tensor);
+}
+
+// Helper to put a tuple onto the IValue stack
+inline void tuple_to_stack(torch::stable::StableIValue* stack, int idx,
+                           const std::vector<torch::stable::Tensor>& tensors) {
+  std::vector<torch::stable::StableIValue> ivalues;
+  ivalues.reserve(tensors.size());
+  for (const auto& t : tensors) {
+    ivalues.push_back(torch::stable::StableIValue::from(t));
+  }
+  stack[idx] = torch::stable::StableIValue::fromTuple(ivalues);
+}
+
+// Helper to put a list onto the IValue stack
+inline void tensor_list_to_stack(torch::stable::StableIValue* stack, int idx,
+                                 const std::vector<torch::stable::Tensor>& tensors) {
+  std::vector<torch::stable::StableIValue> ivalues;
+  ivalues.reserve(tensors.size());
+  for (const auto& t : tensors) {
+    ivalues.push_back(torch::stable::StableIValue::from(t));
+  }
+  stack[idx] = torch::stable::StableIValue::fromList(ivalues);
+}
+
+// ============================================================================
+// Device and Stream Utilities
+// ============================================================================
+
+// Check if a tensor is on CUDA
+inline bool is_cuda(const torch::stable::Tensor& tensor) { return tensor.device().type() == torch::headeronly::kCUDA; }
+
+// Get the CUDA device index
+inline int64_t get_device_index(const torch::stable::Tensor& tensor) {
+  STD_TORCH_CHECK(is_cuda(tensor), "Tensor must be on CUDA device");
+  return tensor.device().index();
+}
+
+// ============================================================================
+// Common Tensor Checks
+// ============================================================================
+
+inline void check_cuda(const torch::stable::Tensor& tensor, const char* name) {
+  STD_TORCH_CHECK(is_cuda(tensor), "%s must be a CUDA tensor", name);
+}
+
+inline void check_contiguous(const torch::stable::Tensor& tensor, const char* name) {
+  STD_TORCH_CHECK(tensor.is_contiguous(), "%s must be contiguous", name);
+}
+
+inline void check_same_device(const torch::stable::Tensor& t1, const torch::stable::Tensor& t2, const char* name1,
+                              const char* name2) {
+  STD_TORCH_CHECK(t1.device() == t2.device(), "%s and %s must be on the same device", name1, name2);
+}
+
+inline void check_same_dtype(const torch::stable::Tensor& t1, const torch::stable::Tensor& t2, const char* name1,
+                             const char* name2) {
+  STD_TORCH_CHECK(t1.scalar_type() == t2.scalar_type(), "%s and %s must have the same dtype, got %s and %s", name1,
+                  name2, scalar_type_name(t1.scalar_type()), scalar_type_name(t2.scalar_type()));
+}
+
+}  // namespace stable
+}  // namespace apex
+
+#else  // !TORCH_STABLE_ONLY
+
+// When not using the stable ABI, provide no-op definitions or traditional includes
+#include <ATen/ATen.h>
+#include <c10/core/MemoryFormat.h>
+
+namespace apex {
+namespace stable {
+
+// Map to the traditional PyTorch MemoryFormat for non-stable builds
+using MemoryFormat = at::MemoryFormat;
+
+// Use the traditional is_contiguous in non-stable builds
+inline bool is_contiguous(const at::Tensor& tensor, at::MemoryFormat format) { return tensor.is_contiguous(format); }
+
+}  // namespace stable
+}  // namespace apex
+
+#endif  // TORCH_STABLE_ONLY
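The NHWC/NDHWC predicates above are pure stride arithmetic, so the expected stride patterns can be sanity-checked without constructing a tensor. A self-contained sketch mirroring the 4-D branch (helper name and values are illustrative only):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Mirrors the ChannelsLast (NHWC) branch of apex::stable::is_contiguous,
// applied to raw sizes/strides in NCHW dimension order.
bool is_channels_last_4d(const std::array<int64_t, 4>& sizes, const std::array<int64_t, 4>& strides) {
  const int64_t C = sizes[1], H = sizes[2], W = sizes[3];
  return strides[1] == 1 && strides[3] == C && strides[2] == W * C && strides[0] == H * W * C;
}

int main() {
  // For N=2, C=3, H=4, W=5 stored NHWC, element (n, c, h, w) lives at
  // n*60 + h*15 + w*3 + c, i.e. strides (60, 1, 15, 3).
  std::array<int64_t, 4> sizes{2, 3, 4, 5};
  std::array<int64_t, 4> nhwc{60, 1, 15, 3};
  std::array<int64_t, 4> nchw{60, 20, 5, 1};  // plain contiguous layout
  std::printf("nhwc=%d nchw=%d\n", is_channels_last_4d(sizes, nhwc),
              is_channels_last_4d(sizes, nchw));  // prints: nhwc=1 nchw=0
}
```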
diff --git a/csrc/type_shim.h b/csrc/type_shim.h
index 9812293c5..883b97925 100644
--- a/csrc/type_shim.h
+++ b/csrc/type_shim.h
@@ -1,5 +1,39 @@
+#ifdef TORCH_STABLE_ONLY
+#include <torch/headeronly/core/ScalarType.h>
+#include <string>
+
+#include "stable_abi_utils.h"
+
+// Error macro for stable ABI
+#define APEX_ERROR(...) STD_TORCH_CHECK(false, __VA_ARGS__)
+
+// Namespace and type aliases for stable ABI
+namespace apex_internal {
+using ScalarType = torch::headeronly::ScalarType;
+using Half = torch::headeronly::Half;
+using BFloat16 = torch::headeronly::BFloat16;
+
+inline std::string toString(ScalarType type) { return std::string(apex::stable::scalar_type_name(type)); }
+}  // namespace apex_internal
+
+#else  // !TORCH_STABLE_ONLY
+
 #include <ATen/ATen.h>
 
+// Error macro for traditional API
+#define APEX_ERROR(...) AT_ERROR(__VA_ARGS__)
+
+// Namespace and type aliases for traditional API
+namespace apex_internal {
+using ScalarType = at::ScalarType;
+using Half = at::Half;
+using BFloat16 = at::BFloat16;
+
+inline std::string toString(at::ScalarType type) { return std::string(c10::toString(type)); }
+}  // namespace apex_internal
+
+#endif  // TORCH_STABLE_ONLY
+
 // Forward/backward compatiblity hack around
 // https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
 // pending more future-proof guidance from upstream.
@@ -13,251 +47,251 @@
 //   //operator at::ScalarType(){ return payload.; };
 // };
-#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                \
-  switch (TYPE) {                                                      \
-    case at::ScalarType::Float: {                                      \
-      using scalar_t_##LEVEL = float;                                  \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::Half: {                                       \
-      using scalar_t_##LEVEL = at::Half;                               \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    default:                                                           \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
+#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                                \
+  switch (TYPE) {                                                                      \
+    case apex_internal::ScalarType::Float: {                                           \
+      using scalar_t_##LEVEL = float;                                                  \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::Half: {                                            \
+      using scalar_t_##LEVEL = apex_internal::Half;                                    \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    default:                                                                           \
+      APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPE), "'"); \
   }
 
-#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...)         \
-  switch (TYPE) {                                                      \
-    case at::ScalarType::Float: {                                      \
-      using scalar_t_##LEVEL = float;                                  \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::Half: {                                       \
-      using scalar_t_##LEVEL = at::Half;                               \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::BFloat16: {                                   \
-      using scalar_t_##LEVEL = at::BFloat16;                           \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    default:                                                           \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...)                         \
+  switch (TYPE) {                                                                      \
+    case apex_internal::ScalarType::Float: {                                           \
+      using scalar_t_##LEVEL = float;                                                  \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::Half: {                                            \
+      using scalar_t_##LEVEL = apex_internal::Half;                                    \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::BFloat16: {                                        \
+      using scalar_t_##LEVEL = apex_internal::BFloat16;                                \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    default:                                                                           \
+      APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPE), "'"); \
   }
 
-#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...)           \
-  switch (TYPE) {                                                      \
-    case at::ScalarType::Float: {                                      \
-      using scalar_t_##LEVEL = float;                                  \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::Half: {                                       \
-      using scalar_t_##LEVEL = at::Half;                               \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::Byte: {                                       \
-      using scalar_t_##LEVEL = uint8_t;                                \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    default:                                                           \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
+#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...)                           \
+  switch (TYPE) {                                                                      \
+    case apex_internal::ScalarType::Float: {                                           \
+      using scalar_t_##LEVEL = float;                                                  \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::Half: {                                            \
+      using scalar_t_##LEVEL = apex_internal::Half;                                    \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::Byte: {                                            \
+      using scalar_t_##LEVEL = uint8_t;                                                \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    default:                                                                           \
+      APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPE), "'"); \
   }
 
-#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)         \
-  switch (TYPE) {                                                      \
-    case at::ScalarType::Double: {                                     \
-      using scalar_t_##LEVEL = double;                                 \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::Float: {                                      \
-      using scalar_t_##LEVEL = float;                                  \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::Half: {                                       \
-      using scalar_t_##LEVEL = at::Half;                               \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    default:                                                           \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
+#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                         \
+  switch (TYPE) {                                                                      \
+    case apex_internal::ScalarType::Double: {                                          \
+      using scalar_t_##LEVEL = double;                                                 \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::Float: {                                           \
+      using scalar_t_##LEVEL = float;                                                  \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::Half: {                                            \
+      using scalar_t_##LEVEL = apex_internal::Half;                                    \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    default:                                                                           \
+      APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPE), "'"); \
   }
 
-#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...)  \
-  switch (TYPE) {                                                      \
-    case at::ScalarType::Double: {                                     \
-      using scalar_t_##LEVEL = double;                                 \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::Float: {                                      \
-      using scalar_t_##LEVEL = float;                                  \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::Half: {                                       \
-      using scalar_t_##LEVEL = at::Half;                               \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::BFloat16: {                                   \
-      using scalar_t_##LEVEL = at::BFloat16;                           \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    default:                                                           \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
+#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...)                  \
+  switch (TYPE) {                                                                      \
+    case apex_internal::ScalarType::Double: {                                          \
+      using scalar_t_##LEVEL = double;                                                 \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::Float: {                                           \
+      using scalar_t_##LEVEL = float;                                                  \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::Half: {                                            \
+      using scalar_t_##LEVEL = apex_internal::Half;                                    \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::BFloat16: {                                        \
+      using scalar_t_##LEVEL = apex_internal::BFloat16;                                \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    default:                                                                           \
+      APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPE), "'"); \
   }
 
-#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...)              \
-  switch (TYPE) {                                                      \
-    case at::ScalarType::Double: {                                     \
-      using scalar_t_##LEVEL = double;                                 \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::Float: {                                      \
-      using scalar_t_##LEVEL = float;                                  \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    default:                                                           \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
+#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...)                              \
+  switch (TYPE) {                                                                      \
+    case apex_internal::ScalarType::Double: {                                          \
+      using scalar_t_##LEVEL = double;                                                 \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::Float: {                                           \
+      using scalar_t_##LEVEL = float;                                                  \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    default:                                                                           \
+      APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPE), "'"); \
   }
 
-#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                      \
-  switch (TYPE) {                                                      \
-    case at::ScalarType::Half: {                                       \
-      using scalar_t = at::Half;                                       \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    case at::ScalarType::BFloat16: {                                   \
-      using scalar_t = at::BFloat16;                                   \
-      __VA_ARGS__;                                                     \
-      break;                                                           \
-    }                                                                  \
-    default:                                                           \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
+#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                                      \
+  switch (TYPE) {                                                                      \
+    case apex_internal::ScalarType::Half: {                                            \
+      using scalar_t = apex_internal::Half;                                            \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    case apex_internal::ScalarType::BFloat16: {                                        \
+      using scalar_t = apex_internal::BFloat16;                                        \
+      __VA_ARGS__;                                                                     \
+      break;                                                                           \
+    }                                                                                  \
+    default:                                                                           \
+      APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPE), "'"); \
   }
 
-#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...)  \
-  switch (TYPEIN) {                                                             \
-    case at::ScalarType::Float: {                                               \
-      using scalar_t_in = float;                                                \
-      switch (TYPEOUT) {                                                        \
-        case at::ScalarType::Float: {                                           \
-          using scalar_t_out = float;                                           \
-          __VA_ARGS__;                                                          \
-          break;                                                                \
-        }                                                                       \
-        case at::ScalarType::Half: {                                            \
-          using scalar_t_out = at::Half;                                        \
-          __VA_ARGS__;                                                          \
-          break;                                                                \
-        }                                                                       \
-        case at::ScalarType::BFloat16: {                                        \
-          using scalar_t_out = at::BFloat16;                                    \
-          __VA_ARGS__;                                                          \
-          break;                                                                \
-        }                                                                       \
-        default:                                                                \
-          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'");    \
-      }                                                                         \
-      break;                                                                    \
-    }                                                                           \
-    case at::ScalarType::Half: {                                                \
-      using scalar_t_in = at::Half;                                             \
-      using scalar_t_out = at::Half;                                            \
-      __VA_ARGS__;                                                              \
-      break;                                                                    \
-    }                                                                           \
-    case at::ScalarType::BFloat16: {                                            \
-      using scalar_t_in = at::BFloat16;                                         \
-      using scalar_t_out = at::BFloat16;                                        \
-      __VA_ARGS__;                                                              \
-      break;                                                                    \
-    }                                                                           \
-    default:                                                                    \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");         \
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...)              \
+  switch (TYPEIN) {                                                                         \
+    case apex_internal::ScalarType::Float: {                                                \
+      using scalar_t_in = float;                                                            \
+      switch (TYPEOUT) {                                                                    \
+        case apex_internal::ScalarType::Float: {                                            \
+          using scalar_t_out = float;                                                       \
+          __VA_ARGS__;                                                                      \
+          break;                                                                            \
+        }                                                                                   \
+        case apex_internal::ScalarType::Half: {                                             \
+          using scalar_t_out = apex_internal::Half;                                         \
+          __VA_ARGS__;                                                                      \
+          break;                                                                            \
+        }                                                                                   \
+        case apex_internal::ScalarType::BFloat16: {                                         \
+          using scalar_t_out = apex_internal::BFloat16;                                     \
+          __VA_ARGS__;                                                                      \
+          break;                                                                            \
+        }                                                                                   \
+        default:                                                                            \
+          APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPEOUT), "'"); \
+      }                                                                                     \
+      break;                                                                                \
+    }                                                                                       \
+    case apex_internal::ScalarType::Half: {                                                 \
+      using scalar_t_in = apex_internal::Half;                                              \
+      using scalar_t_out = apex_internal::Half;                                             \
+      __VA_ARGS__;                                                                          \
+      break;                                                                                \
+    }                                                                                       \
+    case apex_internal::ScalarType::BFloat16: {                                             \
+      using scalar_t_in = apex_internal::BFloat16;                                          \
+      using scalar_t_out = apex_internal::BFloat16;                                         \
+      __VA_ARGS__;                                                                          \
+      break;                                                                                \
+    }                                                                                       \
+    default:                                                                                \
+      APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPEIN), "'");    \
   }
 
-#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...)  \
-  switch (TYPEIN) {                                                                    \
-    case at::ScalarType::Double: {                                                     \
-      using scalar_t_in = double;                                                      \
-      switch (TYPEOUT) {                                                               \
-        case at::ScalarType::Double: {                                                 \
-          using scalar_t_out = double;                                                 \
-          __VA_ARGS__;                                                                 \
-          break;                                                                       \
-        }                                                                              \
-        case at::ScalarType::Float: {                                                  \
-          using scalar_t_out = float;                                                  \
-          __VA_ARGS__;                                                                 \
-          break;                                                                       \
-        }                                                                              \
-        case at::ScalarType::Half: {                                                   \
-          using scalar_t_out = at::Half;                                               \
-          __VA_ARGS__;                                                                 \
-          break;                                                                       \
-        }                                                                              \
-        case at::ScalarType::BFloat16: {                                               \
-          using scalar_t_out = at::BFloat16;                                           \
-          __VA_ARGS__;                                                                 \
-          break;                                                                       \
-        }                                                                              \
-        default:                                                                       \
-          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'");           \
-      }                                                                                \
-      break;                                                                           \
-    }                                                                                  \
-    case at::ScalarType::Float: {                                                      \
-      using scalar_t_in = float;                                                       \
-      switch (TYPEOUT) {                                                               \
-        case at::ScalarType::Float: {                                                  \
-          using scalar_t_out = float;                                                  \
-          __VA_ARGS__;                                                                 \
-          break;                                                                       \
-        }                                                                              \
-        case at::ScalarType::Half: {                                                   \
-          using scalar_t_out = at::Half;                                               \
-          __VA_ARGS__;                                                                 \
-          break;                                                                       \
-        }                                                                              \
-        case at::ScalarType::BFloat16: {                                               \
-          using scalar_t_out = at::BFloat16;                                           \
-          __VA_ARGS__;                                                                 \
-          break;                                                                       \
-        }                                                                              \
-        default:                                                                       \
-          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'");           \
-      }                                                                                \
-      break;                                                                           \
-    }                                                                                  \
-    case at::ScalarType::Half: {                                                       \
-      using scalar_t_in = at::Half;                                                    \
-      using scalar_t_out = at::Half;                                                   \
-      __VA_ARGS__;                                                                     \
-      break;                                                                           \
-    }                                                                                  \
-    case at::ScalarType::BFloat16: {                                                   \
-      using scalar_t_in = at::BFloat16;                                                \
-      using scalar_t_out = at::BFloat16;                                               \
-      __VA_ARGS__;                                                                     \
-      break;                                                                           \
-    }                                                                                  \
-    default:                                                                           \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");                \
+#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...)         \
+  switch (TYPEIN) {                                                                           \
+    case apex_internal::ScalarType::Double: {                                                 \
+      using scalar_t_in = double;                                                             \
+      switch (TYPEOUT) {                                                                      \
+        case apex_internal::ScalarType::Double: {                                             \
+          using scalar_t_out = double;                                                        \
+          __VA_ARGS__;                                                                        \
+          break;                                                                              \
+        }                                                                                     \
+        case apex_internal::ScalarType::Float: {                                              \
+          using scalar_t_out = float;                                                         \
+          __VA_ARGS__;                                                                        \
+          break;                                                                              \
+        }                                                                                     \
+        case apex_internal::ScalarType::Half: {                                               \
+          using scalar_t_out = apex_internal::Half;                                           \
+          __VA_ARGS__;                                                                        \
+          break;                                                                              \
+        }                                                                                     \
+        case apex_internal::ScalarType::BFloat16: {                                           \
+          using scalar_t_out = apex_internal::BFloat16;                                       \
+          __VA_ARGS__;                                                                        \
+          break;                                                                              \
+        }                                                                                     \
+        default:                                                                              \
+          APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPEOUT), "'"); \
+      }                                                                                       \
+      break;                                                                                  \
+    }                                                                                         \
+    case apex_internal::ScalarType::Float: {                                                  \
+      using scalar_t_in = float;                                                              \
+      switch (TYPEOUT) {                                                                      \
+        case apex_internal::ScalarType::Float: {                                              \
+          using scalar_t_out = float;                                                         \
+          __VA_ARGS__;                                                                        \
+          break;                                                                              \
+        }                                                                                     \
+        case apex_internal::ScalarType::Half: {                                               \
+          using scalar_t_out = apex_internal::Half;                                           \
+          __VA_ARGS__;                                                                        \
+          break;                                                                              \
+        }                                                                                     \
+        case apex_internal::ScalarType::BFloat16: {                                           \
+          using scalar_t_out = apex_internal::BFloat16;                                       \
+          __VA_ARGS__;                                                                        \
+          break;                                                                              \
+        }                                                                                     \
+        default:                                                                              \
+          APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPEOUT), "'"); \
+      }                                                                                       \
+      break;                                                                                  \
+    }                                                                                         \
+    case apex_internal::ScalarType::Half: {                                                   \
+      using scalar_t_in = apex_internal::Half;                                                \
+      using scalar_t_out = apex_internal::Half;                                               \
+      __VA_ARGS__;                                                                            \
+      break;                                                                                  \
+    }                                                                                         \
+    case apex_internal::ScalarType::BFloat16: {                                               \
+      using scalar_t_in = apex_internal::BFloat16;                                            \
+      using scalar_t_out = apex_internal::BFloat16;                                           \
+      __VA_ARGS__;                                                                            \
+      break;                                                                                  \
+    }                                                                                         \
+    default:                                                                                  \
+      APEX_ERROR(#NAME, " not implemented for '", apex_internal::toString(TYPEIN), "'");      \
   }
 
 template <typename T>
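Call sites of these dispatch macros look identical in both builds; only the `apex_internal` aliases and the `APEX_ERROR` expansion change underneath. A usage sketch (`scale_kernel_launcher` is hypothetical, not part of the patch):

```cpp
#include "type_shim.h"

// Assumed to be defined elsewhere (e.g. in a .cu file); shown here only to
// give the dispatch macro something to instantiate.
template <typename scalar_t_0>
void scale_kernel_launcher(void* data, int64_t numel, float scale);

void scale_dispatch(apex_internal::ScalarType dtype, void* data, int64_t numel, float scale) {
  // Expands to a switch over Float/Half/BFloat16; any other dtype reaches
  // APEX_ERROR, which maps to AT_ERROR or STD_TORCH_CHECK(false, ...).
  DISPATCH_FLOAT_HALF_AND_BFLOAT(dtype, 0, "scale_dispatch",
                                 scale_kernel_launcher<scalar_t_0>(data, numel, scale);)
}
```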
diff --git a/setup.py b/setup.py
index 32b218d0d..80ec4edb7 100644
--- a/setup.py
+++ b/setup.py
@@ -20,6 +20,11 @@
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
+# Check for Stable ABI build
+USE_STABLE_ABI = os.environ.get("TORCH_STABLE_ONLY", "0") == "1"
+if USE_STABLE_ABI:
+    print("[apex] Building with LibTorch Stable ABI support (TORCH_STABLE_ONLY=1)")
+
 # Allow environment variables to specify build flags for PEP 517 compatibility
 ENV_TO_FLAG = {
     "APEX_CPP_EXT": "--cpp_ext",
@@ -114,6 +119,76 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     return True
 
 
+def prepare_stable_abi_sources(sources):
+    """
+    Convert source file paths to stable ABI versions if TORCH_STABLE_ONLY is set.
+    For dual-build support, replaces .cpp files with _stable.cpp equivalents.
+
+    Example: "csrc/amp_C_frontend.cpp" -> "csrc/amp_C_frontend_stable.cpp"
+    """
+    if not USE_STABLE_ABI:
+        return sources
+
+    stable_sources = []
+    for src in sources:
+        if src.endswith(".cpp"):
+            # Replace .cpp with _stable.cpp
+            stable_src = src[:-4] + "_stable.cpp"
+            stable_sources.append(stable_src)
+        else:
+            # Keep .cu files as-is (they should work with the stable ABI via the updated headers)
+            stable_sources.append(src)
+    return stable_sources
+
+
+def add_stable_abi_compile_args(extra_compile_args):
+    """
+    Add -DTORCH_STABLE_ONLY to compiler flags when building with the stable ABI.
+    """
+    if not USE_STABLE_ABI:
+        return extra_compile_args
+
+    # Make a copy to avoid modifying the original
+    args = extra_compile_args.copy() if extra_compile_args else {}
+
+    # Add the TORCH_STABLE_ONLY define for both the cxx and nvcc compilers
+    if "cxx" not in args:
+        args["cxx"] = []
+    if "nvcc" not in args:
+        args["nvcc"] = []
+
+    args["cxx"] = args["cxx"] + ["-DTORCH_STABLE_ONLY"]
+    args["nvcc"] = args["nvcc"] + ["-DTORCH_STABLE_ONLY"]
+
+    return args
+
+
+def StableCUDAExtension(name, sources, extra_compile_args=None, **kwargs):
+    """
+    Wrapper for CUDAExtension that automatically handles stable ABI source substitution and flags.
+    """
+    stable_sources = prepare_stable_abi_sources(sources)
+    stable_compile_args = add_stable_abi_compile_args(
+        extra_compile_args if extra_compile_args else {}
+    )
+    return CUDAExtension(
+        name=name, sources=stable_sources, extra_compile_args=stable_compile_args, **kwargs
+    )
+
+
+def StableCppExtension(name, sources, extra_compile_args=None, **kwargs):
+    """
+    Wrapper for CppExtension that automatically handles stable ABI source substitution and flags.
+    """
+    stable_sources = prepare_stable_abi_sources(sources)
+    stable_compile_args = add_stable_abi_compile_args(
+        extra_compile_args if extra_compile_args else {}
+    )
+    return CppExtension(
+        name=name, sources=stable_sources, extra_compile_args=stable_compile_args, **kwargs
+    )
+
+
 if not torch.cuda.is_available():
     # https://github.com/NVIDIA/apex/issues/486
     # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
@@ -168,7 +243,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
 if has_flag("--cpp_ext", "APEX_CPP_EXT"):
     if "--cpp_ext" in sys.argv:
         sys.argv.remove("--cpp_ext")
-    ext_modules.append(CppExtension("apex_C", ["csrc/flatten_unflatten.cpp"]))
+    ext_modules.append(StableCppExtension("apex_C", ["csrc/flatten_unflatten.cpp"]))
 
 _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
 
@@ -178,7 +253,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     sys.argv.remove("--distributed_adam")
     raise_if_cuda_home_none("--distributed_adam")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="distributed_adam_cuda",
            sources=[
                "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp",
@@ -197,7 +272,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     sys.argv.remove("--distributed_lamb")
     raise_if_cuda_home_none("--distributed_lamb")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="distributed_lamb_cuda",
            sources=[
                "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp",
@@ -218,7 +293,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="amp_C",
            sources=[
                "csrc/amp_C_frontend.cpp",
@@ -249,7 +324,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
        )
     )
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="syncbn",
            sources=["csrc/syncbn.cpp", "csrc/welford.cu"],
            extra_compile_args={
@@ -260,7 +335,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     )
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="fused_layer_norm_cuda",
            sources=["csrc/layer_norm_cuda.cpp", "csrc/layer_norm_cuda_kernel.cu"],
            extra_compile_args={
@@ -271,7 +346,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     )
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="mlp_cuda",
            sources=["csrc/mlp.cpp", "csrc/mlp_cuda.cu"],
            extra_compile_args={
@@ -281,7 +356,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
        )
     )
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="fused_dense_cuda",
            sources=["csrc/fused_dense.cpp", "csrc/fused_dense_cuda.cu"],
            extra_compile_args={
@@ -292,7 +367,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     )
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="scaled_upper_triang_masked_softmax_cuda",
            sources=[
                "csrc/megatron/scaled_upper_triang_masked_softmax.cpp",
@@ -313,7 +388,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     )
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="generic_scaled_masked_softmax_cuda",
            sources=[
                "csrc/megatron/generic_scaled_masked_softmax.cpp",
@@ -334,7 +409,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     )
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="scaled_masked_softmax_cuda",
            sources=[
                "csrc/megatron/scaled_masked_softmax.cpp",
@@ -355,7 +430,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     )
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="scaled_softmax_cuda",
            sources=[
                "csrc/megatron/scaled_softmax.cpp",
@@ -376,7 +451,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     )
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="fused_rotary_positional_embedding",
            sources=[
                "csrc/megatron/fused_rotary_positional_embedding.cpp",
@@ -397,7 +472,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     )
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="fused_weight_gradient_mlp_cuda",
            include_dirs=[os.path.join(this_dir, "csrc")],
            sources=[
@@ -454,7 +529,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     sys.argv.remove("--bnp")
     raise_if_cuda_home_none("--bnp")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="bnp",
            sources=[
                "apex/contrib/csrc/groupbn/batch_norm.cu",
@@ -484,7 +559,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     xentropy_ver = datetime.today().strftime("%y.%m.%d")
     print(f"`--xentropy` setting version of {xentropy_ver}")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="xentropy_cuda",
            sources=[
                "apex/contrib/csrc/xentropy/interface.cpp",
@@ -503,7 +578,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     sys.argv.remove("--focal_loss")
     raise_if_cuda_home_none("--focal_loss")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="focal_loss_cuda",
            sources=[
                "apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp",
@@ -523,7 +598,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     raise_if_cuda_home_none("--group_norm")
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="group_norm_cuda",
            sources=[
                "apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp",
@@ -583,7 +658,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     sys.argv.remove("--index_mul_2d")
     raise_if_cuda_home_none("--index_mul_2d")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="fused_index_mul_2d",
            sources=[
                "apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp",
@@ -602,7 +677,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     sys.argv.remove("--deprecated_fused_adam")
     raise_if_cuda_home_none("--deprecated_fused_adam")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="fused_adam_cuda",
            sources=[
                "apex/contrib/csrc/optimizers/fused_adam_cuda.cpp",
@@ -621,7 +696,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     sys.argv.remove("--deprecated_fused_lamb")
     raise_if_cuda_home_none("--deprecated_fused_lamb")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="fused_lamb_cuda",
            sources=[
                "apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp",
@@ -649,7 +724,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     raise_if_cuda_home_none("--fast_layer_norm")
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="fast_layer_norm",
            sources=[
                "apex/contrib/csrc/layer_norm/ln_api.cpp",
@@ -701,7 +776,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
         cc_flag.append("arch=compute_110,code=sm_110")
 
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="fmhalib",
            sources=[
                "apex/contrib/csrc/fmha/fmha_api.cpp",
@@ -752,7 +827,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
        ]
     )
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="fast_multihead_attn",
            sources=[
                "apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp",
@@ -792,7 +867,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     sys.argv.remove("--transducer")
     raise_if_cuda_home_none("--transducer")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="transducer_joint_cuda",
            sources=[
                "apex/contrib/csrc/transducer/transducer_joint.cpp",
@@ -809,7 +884,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
        )
     )
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="transducer_loss_cuda",
            sources=[
                "apex/contrib/csrc/transducer/transducer_loss.cpp",
@@ -854,7 +929,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     sys.argv.remove("--peer_memory")
     raise_if_cuda_home_none("--peer_memory")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="peer_memory_cuda",
            sources=[
                "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
@@ -978,7 +1053,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     sys.argv.remove("--gpu_direct_storage")
     raise_if_cuda_home_none("--gpu_direct_storage")
     ext_modules.append(
-        CUDAExtension(
+        StableCUDAExtension(
            name="_apex_gpu_direct_storage",
            sources=[
                "apex/contrib/csrc/gpu_direct_storage/gds.cpp",
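For dual-build sources, prepare_stable_abi_sources() swaps each frontend .cpp for a _stable.cpp sibling, which must exist next to it in the tree. A skeleton of what such a file might contain, using the stack helpers from stable_abi_utils.h (the file name, op, and omitted registration are illustrative; this patch does not itself add these frontends):

```cpp
// csrc/example_frontend_stable.cpp (hypothetical)
#include <cstdint>

#include "stable_abi_utils.h"

#ifdef TORCH_STABLE_ONLY
// Boxed kernel: reads arguments off the StableIValue stack with the helpers
// from stable_abi_utils.h and writes the result back the same way.
static void example_identity_boxed(torch::stable::StableIValue* stack, uint64_t /*num_args*/,
                                   uint64_t /*num_outputs*/) {
  auto input = apex::stable::tensor_from_stack(stack, 0);
  apex::stable::check_cuda(input, "input");
  apex::stable::check_contiguous(input, "input");
  apex::stable::tensor_to_stack(stack, 0, input);  // identity: return the input unchanged
}
#endif
```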