-
Notifications
You must be signed in to change notification settings - Fork 13.9k
Add support for CUMSUM and TRI for CUDA. #17584
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
|
For cumsum we should use https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceScan.html and use this kernel as a fallback |
|
I have a small optimization for the tri kernel :) Benchmark Results: 1. llama.cpp benchmark (50 runs each)
2. Profiler Statistics rtx 2070 (Nsight)
@@ -1,16 +1,7 @@
#include "tri.cuh"
#include "ggml.h"
-// Triangle type comparison - determines which elements to keep
-__device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) {
- switch (type) {
- case GGML_TRI_TYPE_LOWER: return i < r;
- case GGML_TRI_TYPE_LOWER_DIAG: return i <= r;
- case GGML_TRI_TYPE_UPPER: return i > r;
- case GGML_TRI_TYPE_UPPER_DIAG: return i >= r;
- default: return false;
- }
-}
+
template<typename T>
static __global__ void tri_kernel(
@@ -31,10 +22,22 @@ static __global__ void tri_kernel(
const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+ // Optimization: Avoid control flow (switch) inside the hot loop.
+ // Map the 4 triangle types to a generic "split point" and "keep direction" logic.
+ // LOWER / UPPER_DIAG: Split at 'r' (i1). LOWER_DIAG / UPPER: Split at 'r + 1'.
+ int add_to_split = 0;
+ if (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) {
+ add_to_split = 1;
+ }
+ int64_t split_point = i1 + add_to_split;
+ bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
+
// Each thread processes elements at stride blockDim.x
for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
- dst_row[i0] = tri_compare(i0, i1, ttype)
- ? src_row[i0] : static_cast<T>(0.f);
+ // If prefix_keep is true, keep (i0 < split_point). Else, keep (i0 >= split_point).
+ bool keep = ((i0 < split_point) == prefix_keep);
+ dst_row[i0] = keep ? src_row[i0] : T(0);
}
} |
| const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); | ||
| T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As with the other kernel, preferably calculate strides in units of float in host code and pass those.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This kernel is generic, though — should I still calculate the strides in units of float even when T itself might be half?
| // Load value and compute prefix sum within warp | ||
| float val = static_cast<float>(src_row[i0]); | ||
| val = warp_prefix_inclusive_sum(val); | ||
| dst_row[i0] = static_cast<T>(val); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be much preferable to store the temporary results in registers or shared memory rather than global memory.
| __device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) { | ||
| switch (type) { | ||
| case GGML_TRI_TYPE_LOWER: return i < r; | ||
| case GGML_TRI_TYPE_LOWER_DIAG: return i <= r; | ||
| case GGML_TRI_TYPE_UPPER: return i > r; | ||
| case GGML_TRI_TYPE_UPPER_DIAG: return i >= r; | ||
| default: return false; | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is going to be very slow in GPU code. Preferably make this a constexpr function and provide the ggml_tri_type at compile time as a template parameter.
| const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); | ||
| T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Preferably calculate the stride in host code.
|
Regarding the implementation proposed by @wsbagnsv1: if one were to do something like that, then in my opinion the correct way would be to calculate start and end points for copying and for zeroing, and to then simply do 2 loops over those areas. If at all possible, a conditional statement inside the loop should be avoided. But that would potentially make the kernel less flexible if other patterns for the operation are needed.
Extracted and adapted kernels by @gabe-l-hart from #16623