Diffstat:
 sci-libs/caffe2/files/caffe2-2.4.0-exclude-aotriton.patch | 65 ++++++++++++++
 1 file changed, 65 insertions(+), 0 deletions(-)
diff --git a/sci-libs/caffe2/files/caffe2-2.4.0-exclude-aotriton.patch b/sci-libs/caffe2/files/caffe2-2.4.0-exclude-aotriton.patch
new file mode 100644
index 000000000000..72ab792b2278
--- /dev/null
+++ b/sci-libs/caffe2/files/caffe2-2.4.0-exclude-aotriton.patch
@@ -0,0 +1,65 @@
+Disables the aotriton download when both the USE_FLASH_ATTENTION and USE_MEM_EFF_ATTENTION CMake flags are OFF
+Backports upstream PR https://github.com/pytorch/pytorch/pull/130197 to 2.4.0
+--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
++++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+@@ -24,7 +24,7 @@
+ #include <c10/core/SymInt.h>
+ #include <c10/util/string_view.h>
+
+-#if USE_ROCM
++#if defined(USE_ROCM) && (defined(USE_MEM_EFF_ATTENTION) || defined(USE_FLASH_ATTENTION))
+ #include <aotriton/flash.h>
+ #endif
+
+@@ -207,7 +207,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
+ // Check that the gpu is capable of running flash attention
+ using sm80 = SMVersion<8, 0>;
+ using sm90 = SMVersion<9, 0>;
+-#if USE_ROCM
++#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
+ if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+ auto dprops = at::cuda::getCurrentDeviceProperties();
+@@ -238,7 +238,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
+ // Mem Efficient attention supports hardware in the range [sm_50, sm_90]
+ using sm50 = SMVersion<5, 0>;
+ using sm90 = SMVersion<9, 0>;
+-#if USE_ROCM
++#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
+ if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+ auto dprops = at::cuda::getCurrentDeviceProperties();
+@@ -623,7 +623,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
+ array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
+ constexpr auto less_than_sm80_mem_efficient_dtypes =
+ array_of<at::ScalarType>(at::kHalf, at::kFloat);
+-#ifdef USE_ROCM
++#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
+ constexpr auto aotriton_mem_efficient_dtypes =
+ array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
+ #endif
+@@ -668,7 +668,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
+ }
+ }
+
+-#ifdef USE_ROCM
++#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
+ return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
+ #else
+ auto dprop = at::cuda::getCurrentDeviceProperties();
+--- a/cmake/Dependencies.cmake
++++ b/cmake/Dependencies.cmake
+@@ -1095,10 +1095,12 @@ if(USE_ROCM)
+ message(STATUS "Disabling Kernel Assert for ROCm")
+ endif()
+
+- include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
+ if(USE_CUDA)
+ caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
+ endif()
++ if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
++ include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
++ endif()
+ else()
+ caffe2_update_option(USE_ROCM OFF)
+ endif()
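
For context, the cmake/Dependencies.cmake hunk above reduces to a simple guard: the aotriton external project (and therefore its download) is only included when a ROCm build actually enables an attention backend that needs it. Below is a minimal standalone sketch of that pattern, reusing the option names from the patch; the project boilerplate and the status message are illustrative assumptions, not part of PyTorch's build.

    cmake_minimum_required(VERSION 3.18)
    project(aotriton_guard_demo NONE)

    option(USE_ROCM "Build with ROCm support" ON)
    option(USE_FLASH_ATTENTION "Enable flash attention kernels" OFF)
    option(USE_MEM_EFF_ATTENTION "Enable memory-efficient attention kernels" OFF)

    if(USE_ROCM)
      if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
        # In PyTorch this include is what triggers the aotriton fetch/build
        # (External/aotriton.cmake path as in the patched tree).
        include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
      else()
        message(STATUS "aotriton not required: both attention backends are OFF")
      endif()
    endif()

Configuring such a build with -DUSE_FLASH_ATTENTION=OFF -DUSE_MEM_EFF_ATTENTION=OFF skips the include entirely, which is the behaviour the patch description refers to.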