Diffstat (limited to 'sci-libs/caffe2/files/caffe2-2.4.0-exclude-aotriton.patch')
-rw-r--r-- | sci-libs/caffe2/files/caffe2-2.4.0-exclude-aotriton.patch | 65
1 file changed, 65 insertions, 0 deletions
diff --git a/sci-libs/caffe2/files/caffe2-2.4.0-exclude-aotriton.patch b/sci-libs/caffe2/files/caffe2-2.4.0-exclude-aotriton.patch
new file mode 100644
index 000000000000..72ab792b2278
--- /dev/null
+++ b/sci-libs/caffe2/files/caffe2-2.4.0-exclude-aotriton.patch
@@ -0,0 +1,65 @@
+Disables aotriton download when both USE_FLASH_ATTENTION and USE_MEM_EFF_ATTENTION cmake flags are OFF
+Backports upstream PR to 2.3.0: https://github.com/pytorch/pytorch/pull/130197
+--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
++++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+@@ -24,7 +24,7 @@
+ #include <c10/core/SymInt.h>
+ #include <c10/util/string_view.h>
+ 
+-#if USE_ROCM
++#if defined(USE_ROCM) && (defined(USE_MEM_EFF_ATTENTION) || defined(USE_FLASH_ATTENTION))
+ #include <aotriton/flash.h>
+ #endif
+ 
+@@ -207,7 +207,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
+   // Check that the gpu is capable of running flash attention
+   using sm80 = SMVersion<8, 0>;
+   using sm90 = SMVersion<9, 0>;
+-#if USE_ROCM
++#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)
+   auto stream = at::cuda::getCurrentCUDAStream().stream();
+   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+       auto dprops = at::cuda::getCurrentDeviceProperties();
+@@ -238,7 +238,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
+   // Mem Efficient attention supports hardware in the range [sm_50, sm_90]
+   using sm50 = SMVersion<5, 0>;
+   using sm90 = SMVersion<9, 0>;
+-#if USE_ROCM
++#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
+   auto stream = at::cuda::getCurrentCUDAStream().stream();
+   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+       auto dprops = at::cuda::getCurrentDeviceProperties();
+@@ -623,7 +623,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
+       array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
+   constexpr auto less_than_sm80_mem_efficient_dtypes =
+       array_of<at::ScalarType>(at::kHalf, at::kFloat);
+-#ifdef USE_ROCM
++#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
+   constexpr auto aotriton_mem_efficient_dtypes =
+       array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
+ #endif
+@@ -668,7 +668,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
+     }
+   }
+ 
+-#ifdef USE_ROCM
++#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
+   return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
+ #else
+   auto dprop = at::cuda::getCurrentDeviceProperties();
+--- a/cmake/Dependencies.cmake
++++ b/cmake/Dependencies.cmake
+@@ -1095,10 +1095,12 @@ if(USE_ROCM)
+     message(STATUS "Disabling Kernel Assert for ROCm")
+   endif()
+ 
+-  include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
+   if(USE_CUDA)
+     caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
+   endif()
++  if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
++    include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
++  endif()
+ else()
+   caffe2_update_option(USE_ROCM OFF)
+ endif()
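
Note (not part of the patch above): the C++ hunks rely on a standard preprocessor gating pattern, replacing a bare "#if USE_ROCM" with "#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)" so the aotriton-dependent code is only compiled when the matching backend is enabled. The standalone sketch below illustrates that pattern; only the macro names mirror PyTorch's build flags, while the file name and printed strings are illustrative.

// guard_sketch.cpp - illustrative only; not part of the patch above.
// The macros mirror PyTorch's USE_ROCM / USE_FLASH_ATTENTION /
// USE_MEM_EFF_ATTENTION build flags; the messages are placeholders.
#include <cstdio>

#if defined(USE_ROCM) && (defined(USE_MEM_EFF_ATTENTION) || defined(USE_FLASH_ATTENTION))
// Compiled only when ROCm and at least one attention backend are defined,
// so nothing in this branch can reference aotriton in other configurations.
static const char* backend_msg = "aotriton-backed attention available";
#else
static const char* backend_msg = "aotriton not required for this build";
#endif

int main() {
  std::puts(backend_msg);
  return 0;
}

Compiling the sketch with only -DUSE_ROCM=1 selects the fallback branch, which is analogous to the patched Dependencies.cmake skipping External/aotriton.cmake when neither USE_FLASH_ATTENTION nor USE_MEM_EFF_ATTENTION is set.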