Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix as comment
  • Loading branch information
zhang hui committed Nov 11, 2025
commit fd18344cf15c479a488da9ae951d5cb3c6db56b5
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ static const char * cu_get_error_str(CUresult err) {
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

#if defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_WMMA)
#if defined(GGML_USE_HIP) && defined(RDNA4)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4)

Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/convert.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ template<typename dst_t, typename src_t>
return __float2bfloat16(float(x));
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
return __bfloat162float(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
return __float22half2_rn(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
return __float22bfloat162_rn(x);
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {
Expand Down
42 changes: 16 additions & 26 deletions ggml/src/ggml-cuda/mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,17 @@ namespace ggml_cuda_mma {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
#else
Expand Down Expand Up @@ -263,7 +265,6 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
half2 x[ne] = {{0.0f, 0.0f}};

Expand All @@ -276,18 +277,19 @@ namespace ggml_cuda_mma {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
#endif // defined(RDNA4)
#else
static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};
Expand Down Expand Up @@ -339,7 +341,6 @@ namespace ggml_cuda_mma {
static constexpr int J = J_;

#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

Expand All @@ -352,18 +353,19 @@ namespace ggml_cuda_mma {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
#endif // defined(RDNA4)
#else
static constexpr int ne = I * J / WARP_SIZE;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
Expand Down Expand Up @@ -435,18 +437,10 @@ namespace ggml_cuda_mma {
xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
// Special tile size to load <16, 8> as <16, 16> for half2 and __hip_bfloat162
if constexpr (I == 16 && J == 8 && (std::is_same<T, half2>::value || std::is_same<T, nv_bfloat162>::value)) {
constexpr int RDNA4_WMMA_MEM_N = 4;
using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) int32_t;
reinterpret_cast<TxN_t&>(t.x[0]) = reinterpret_cast<const TxN_t&>(xs0[t.get_i(0) * stride + t.get_j(0)]);
} else {
constexpr int RDNA4_WMMA_MEM_N = 8;
using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) T;
reinterpret_cast<TxN_t&>(t.x[0]) = reinterpret_cast<const TxN_t&>(xs0[t.get_i(0) * stride + t.get_j(0)]);
}
#endif // defined(RDNA4)
constexpr int nbytes = sizeof(t.x);
// Special case for RDNA3 fp16 and bf16 wmma, the size is 32 bytes.
constexpr int alignment = nbytes > ggml_cuda_get_max_cpy_bytes() ? ggml_cuda_get_max_cpy_bytes() : nbytes;
ggml_cuda_memcpy_1<sizeof(t.x), alignment>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
Expand Down Expand Up @@ -734,14 +728,12 @@ namespace ggml_cuda_mma {
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
Expand All @@ -761,14 +753,12 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
Expand Down
35 changes: 8 additions & 27 deletions ggml/src/ggml-cuda/mmf.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "mma.cuh"
#include "common.cuh"
#include "convert.cuh"

using namespace ggml_cuda_mma;

Expand Down Expand Up @@ -150,27 +151,15 @@ static __device__ __forceinline__ void mul_mat_f_impl(
#if !defined(GGML_USE_HIP)
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
#else
if constexpr (std::is_same<T, half2>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp);
} else if constexpr (std::is_same<T, nv_bfloat162>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp);
} else {
static_assert(0, "unsupported type");
}
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T, float2>(tmp);
#endif // !defined(GGML_USE_HIP)
} else {
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
#if !defined(GGML_USE_HIP)
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
#else
if constexpr (std::is_same<T, half2>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp);
} else if constexpr (std::is_same<T, nv_bfloat162>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp);
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T, float2>(tmp);
#endif // !defined(GGML_USE_HIP)
}
}
Expand Down Expand Up @@ -448,13 +437,7 @@ static __device__ __forceinline__ void mul_mat_f_ids_impl(
#if !defined(GGML_USE_HIP)
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
#else
if constexpr (std::is_same<T, half2>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp);
} else if constexpr (std::is_same<T, nv_bfloat162>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp);
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T, float2>(tmp);
#endif // !defined(GGML_USE_HIP)
}

Expand Down Expand Up @@ -651,11 +634,8 @@ void mul_mat_f_cuda(
cudaStream_t stream, const mmf_ids_data * ids_data) {
typedef tile<16, 8, T> tile_A_16;
typedef tile<32, 8, T> tile_A_32;
#if defined(AMD_WMMA_AVAILABLE)
typedef tile<16, 8, T> tile_B;
#else
typedef tile< 8, 8, T> tile_B;
#endif // defined(AMD_WMMA_AVAILABLE)
typedef tile<16, 8, T> tile_B_16;
typedef tile< 8, 8, T> tile_B_8;

GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
Expand All @@ -682,7 +662,8 @@ void mul_mat_f_cuda(

constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
Expand Down