// webgpu_mm.js
// Highest throughput (GFLOPS) measured so far; run() compares against it and
// overwrites it whenever a configuration is faster.
var best = 0;
/**
 * Append one line of text to the page's output area.
 *
 * Every argument is concatenated with a single trailing space, and the
 * resulting text is placed in a new <div> appended to the element whose id
 * is "out".
 *
 * @param {...*} args - Values to print on one line.
 */
function log(...args) {
  const container = document.getElementById('out');
  const line = document.createElement('div');
  line.textContent = args.reduce((text, piece) => text + (piece + ' '), '');
  container.appendChild(line);
}
/**
 * CPU reference matrix multiplication: C = A * B.
 *
 * All three matrices are flat arrays indexed row-major by this function:
 * A is read as M x K, B as K x N, and C is written as M x N. Every element
 * of C in that range is overwritten.
 *
 * @param {ArrayLike<number>} A - Left operand.
 * @param {ArrayLike<number>} B - Right operand.
 * @param {Array<number>|Float32Array} C - Output, written in place.
 * @param {number} M - Rows of A and C.
 * @param {number} N - Columns of B and C.
 * @param {number} K - Columns of A / rows of B.
 */
function mm_ref(A, B, C, M, N, K) {
  for (let row = 0; row < M; ++row) {
    const aRowStart = row * K;
    const cRowStart = row * N;
    for (let col = 0; col < N; ++col) {
      // Accumulate in ascending inner-index order, starting from zero.
      let acc = 0;
      for (let inner = 0; inner < K; ++inner) {
        acc += A[aRowStart + inner] * B[inner * N + col];
      }
      C[cRowStart + col] = acc;
    }
  }
}
/**
 * Verify the generated GPU kernel against the CPU reference implementation.
 *
 * Builds the kernel for the given sizes/options, runs it once on random
 * inputs, reads the result back and compares it element-wise with mm_ref().
 *
 * @param {object} device - GPU device used to create buffers and submit work.
 * @param {number} M - Rows of A and C.
 * @param {number} N - Columns of B and C.
 * @param {number} K - Columns of A / rows of B.
 * @param {object} opt - Kernel generation options, forwarded unchanged.
 * @returns {Promise<boolean>} true when every element differs from the
 *   reference by less than the tolerance; false otherwise (the failure is
 *   logged once, with both arrays dumped to the console).
 */
async function check(device, M, N, K, opt) {
  const TOLERANCE = 0.0001;
  const mm = createMatrixMultiplication(device, M, N, K, opt);
  const [A, A_cpu] = randGPU(device, M * K, true);
  const [B, B_cpu] = randGPU(device, K * N, true);
  const [C, C_cpu] = randGPU(device, M * N, true);
  device.getQueue().submit([mm(A, B, C)]);
  const gpu_out = await toCPU(device, C, M * N);
  // C_cpu is overwritten with the reference result.
  mm_ref(A_cpu, B_cpu, C_cpu, M, N, K);
  let max_diff = 0;
  for (let i = 0; i < M * N; ++i) {
    const diff = Math.abs(gpu_out[i] - C_cpu[i]);
    if (Number.isNaN(diff)) {
      // Every comparison with NaN is false, so without this a NaN result
      // would never raise max_diff and the check would wrongly pass.
      max_diff = NaN;
      break;
    }
    if (diff > max_diff) {
      max_diff = diff;
    }
  }
  if (max_diff < TOLERANCE) {
    return true;
  }
  log("fail! max diff:", max_diff);
  console.log(gpu_out, C_cpu);
  return false;
}
/**
 * Create a GPU storage buffer filled with random 32-bit floats.
 *
 * Values are Math.random() / 511.91. The buffer is created already mapped,
 * filled from a separate CPU-side array, and unmapped before returning.
 *
 * @param {object} device - GPU device providing createBufferMapped().
 * @param {number} numel - Number of float elements.
 * @param {boolean} [return_cpu_ref=false] - Also return the CPU-side copy.
 * @returns {object|Array} The GPU buffer, or [gpuBuffer, Float32Array] when
 *   return_cpu_ref is truthy.
 */
function randGPU(device, numel, return_cpu_ref = false) {
  const byteLength = numel * 4; // 4 bytes per float
  const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC;
  const [gpuBuffer, mappedBytes] = device.createBufferMapped({ size: byteLength, usage });

  // Kept separate from the mapped range so it can be handed back to the caller.
  const values = new Float32Array(numel);
  for (let idx = 0; idx < numel; ++idx) {
    values[idx] = Math.random() / 511.91;
  }

  new Float32Array(mappedBytes).set(values);
  gpuBuffer.unmap();

  return return_cpu_ref ? [gpuBuffer, values] : gpuBuffer;
}
/**
 * Read a GPU buffer of floats back to the CPU.
 *
 * Copies numel * 4 bytes from gpu_array into a freshly created readable
 * buffer, submits the copy, and resolves once that buffer has been mapped.
 *
 * @param {object} device - GPU device.
 * @param {object} gpu_array - Source GPU buffer.
 * @param {number} numel - Number of float elements to read.
 * @returns {Promise<Float32Array>} View over the mapped contents.
 */
async function toCPU(device, gpu_array, numel) {
  const byteLength = numel * 4;
  const staging = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });

  const encoder = device.createCommandEncoder();
  encoder.copyBufferToBuffer(gpu_array, 0, staging, 0, byteLength);
  device.getQueue().submit([encoder.finish()]);

  const mapped = await staging.mapReadAsync();
  return new Float32Array(mapped);
}
function generateMatrixMultiplicationKernelOpt(M, N, K, opt) {
if (opt.use_matrix) {
return __generateMatrixMultiplicationKernelOpt(M, N, K, opt);
}
if (opt.vec_width > opt.n_unroll) {
log("error generating kernel. check options");
return "";
}
let A_type = 'float4';
if (opt.k_unroll == 1) {
A_type = 'float';
}
let source = ``;
if (opt.swap_threads) {
source += `[numthreads(${opt.y_threads}, ${opt.x_threads}, 1)]`;
} else {
source += `[numthreads(${opt.x_threads}, ${opt.y_threads}, 1)]`;
}
source += `
compute void main(constant ${A_type}[] A : register(u0),
constant float4[] B : register(u1),
device float4[] C : register(u2),
float3 threadID : SV_DispatchThreadID) {`;
if (opt.swap_threads) {
source += `
uint n = uint(threadID.x);
uint m = uint(threadID.y);\n`;
} else {
source += `
uint m = uint(threadID.x);
uint n = uint(threadID.y);\n`;
}
for (let m = 0; m < opt.m_unroll; ++m) {
for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
source += `
float4 result_${m}_${n} = float4(0.0, 0.0, 0.0, 0.0);`;
}
}
source += `
for (uint k = 0; k < ${K / opt.k_unroll}; k++) {`;
for (let m = 0; m < opt.m_unroll; ++m) {
for (let k = 0; k < Math.max(opt.k_unroll / opt.vec_width, 1); ++k) {
const idx = `(m * ${opt.m_unroll} + ${m}) * ${K / opt.vec_width} + (k * ${opt.k_unroll / opt.vec_width} + ${k})`;
source += `
${A_type} a_${m}_${k} = A[${idx}];`;
}
}
for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
for (let k = 0; k < opt.k_unroll; ++k) {
const idx = `(k * ${opt.k_unroll} + ${k}) * ${N / opt.vec_width} + (n * ${opt.n_unroll / opt.vec_width} + ${n})`;
source += `
float4 b_${n}_${k} = B[${idx}];`;
}
}
if (opt.use_mad) { // have to unroll to use mad
for (let m = 0; m < opt.m_unroll; ++m) {
for (let k = 0; k < Math.max(opt.k_unroll / opt.vec_width, 1); ++k) {
if (opt.k_unroll == 1) {
source += `
float4 a_${m}_${k}_v = float4(a_${m}_${k}, a_${m}_${k}, a_${m}_${k}, a_${m}_${k});`;
} else {
source += `
float4 a_${m}_${k}_x = float4(a_${m}_${k}.x, a_${m}_${k}.x, a_${m}_${k}.x, a_${m}_${k}.x);
float4 a_${m}_${k}_y = float4(a_${m}_${k}.y, a_${m}_${k}.y, a_${m}_${k}.y, a_${m}_${k}.y);
float4 a_${m}_${k}_z = float4(a_${m}_${k}.z, a_${m}_${k}.z, a_${m}_${k}.z, a_${m}_${k}.z);
float4 a_${m}_${k}_w = float4(a_${m}_${k}.w, a_${m}_${k}.w, a_${m}_${k}.w, a_${m}_${k}.w);`
}
}
}
}
if (opt.k_unroll == 1) {
for (let k = 0; k < Math.max(opt.k_unroll / opt.vec_width, 1); ++k) {
for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
for (let m = 0; m < opt.m_unroll; ++m) {
if (opt.use_mad) {
source += `
result_${m}_${n} = mad(a_${m}_${k}_v, b_${n}_${k}, result_${m}_${n});`
} else {
source += `
result_${m}_${n} += mul(a_${m}_${k}.v, b_${n}_${k});`
}
}
}
}
} else {
const lets = ['x', 'y', 'z', 'w'];
for (let l of lets) {
for (let k = 0; k < Math.max(opt.k_unroll / opt.vec_width, 1); ++k) {
for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
for (let m = 0; m < opt.m_unroll; ++m) {
if (opt.use_mad) {
source += `
result_${m}_${n} = mad(a_${m}_${k}_${l}, b_${n}_${k * opt.vec_width + lets.indexOf(l)}, result_${m}_${n});`;
} else {
source += `
result_${m}_${n} += mul(a_${m}_${k}.${l}, b_${n}_${k * opt.vec_width + lets.indexOf(l)});`;
}
}
}
}
}
}
source += `
}`;
for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
for (let m = 0; m < opt.m_unroll; ++m) {
const idx = `(m * ${opt.m_unroll} + ${m}) * ${N / opt.vec_width} + (n * ${opt.n_unroll / opt.vec_width} + ${n})`;
source += `
C[${idx}] = result_${m}_${n};`;
}
}
source += `\n}`;
const dispatch = [M / opt.x_threads / opt.m_unroll, N / opt.y_threads / opt.n_unroll, 1];
if (opt.swap_threads) {
const x = dispatch[0];
const y = dispatch[1];
dispatch[0] = y;
dispatch[1] = x;
}
return [source, dispatch];
}
/**
 * Matrix-typed variant of the kernel generator: accumulates float4x4 tiles
 * instead of individual float4 vectors.
 *
 * @param {number} M - Rows of A and C.
 * @param {number} N - Columns of B and C.
 * @param {number} K - Columns of A / rows of B.
 * @param {object} opt - Options read here: x_threads, y_threads, m_unroll,
 *   n_unroll, k_unroll, vec_width.
 * @returns {Array} [source, dispatch]: the shader text and a 3-element array
 *   of workgroup counts.
 */
function __generateMatrixMultiplicationKernelOpt(M, N, K, opt) {
  const mTiles = opt.m_unroll / opt.vec_width;
  const nTiles = opt.n_unroll / opt.vec_width;
  const kTiles = opt.k_unroll / opt.vec_width;

  // Comma-separated constructor arguments, one name per vector lane.
  const argList = (nameAt) => {
    let text = ``;
    for (let i = 0; i < opt.vec_width; ++i) {
      text += nameAt(i);
      if (i != opt.vec_width - 1) {
        text += `, `;
      }
    }
    return text;
  };

  // Shader text is collected line by line and joined with newlines at the end.
  const lines = [
    `[numthreads(${opt.x_threads}, ${opt.y_threads}, 1)]`,
    `compute void main(constant float4[] A : register(u0),`,
    `constant float4[] B : register(u1),`,
    `device float4[] C : register(u2),`,
    `float3 threadID : SV_DispatchThreadID) {`,
    `uint m = uint(threadID.x);`,
    `uint n = uint(threadID.y);`,
    ``,
  ];

  // Zero-initialised accumulators, one float4x4 per output tile.
  for (let mi = 0; mi < mTiles; ++mi) {
    for (let ni = 0; ni < nTiles; ++ni) {
      lines.push(`float4x4 result_${mi}_${ni} = float4x4(${"0.0, ".repeat(15)}0.0);`);
    }
  }

  lines.push(`for (uint k = 0; k < ${K / opt.k_unroll}; k++) {`);

  // Load A vectors.
  for (let mi = 0; mi < opt.m_unroll; ++mi) {
    for (let ki = 0; ki < kTiles; ++ki) {
      const idx = `(m * ${opt.m_unroll} + ${mi}) * ${K / opt.vec_width} + (k * ${opt.k_unroll / opt.vec_width} + ${ki})`;
      lines.push(`float4 a_${mi}_${ki} = A[${idx}];`);
    }
  }

  // Assemble A matrices from vec_width consecutive A vectors.
  for (let mi = 0; mi < mTiles; ++mi) {
    for (let ki = 0; ki < kTiles; ++ki) {
      const members = argList((i) => `a_${mi * opt.vec_width + i}_${ki}`);
      lines.push(`float4x4 a_m_${mi}_${ki} = float4x4(${members});`);
    }
  }

  // Load B vectors.
  for (let ni = 0; ni < nTiles; ++ni) {
    for (let ki = 0; ki < opt.k_unroll; ++ki) {
      const idx = `(k * ${opt.k_unroll} + ${ki}) * ${N / opt.vec_width} + (n * ${opt.n_unroll / opt.vec_width} + ${ni})`;
      lines.push(`float4 b_${ni}_${ki} = B[${idx}];`);
    }
  }

  // Assemble B matrices from vec_width consecutive B vectors.
  for (let ni = 0; ni < nTiles; ++ni) {
    for (let ki = 0; ki < kTiles; ++ki) {
      const members = argList((i) => `b_${ni}_${ki * opt.vec_width + i}`);
      lines.push(`float4x4 b_m_${ni}_${ki} = float4x4(${members});`);
    }
  }

  // Multiply-accumulate the tiles.
  for (let ni = 0; ni < nTiles; ++ni) {
    for (let mi = 0; mi < mTiles; ++mi) {
      for (let ki = 0; ki < kTiles; ++ki) {
        lines.push(`result_${mi}_${ni} += mul(b_m_${ni}_${ki}, a_m_${mi}_${ki});`);
      }
    }
  }

  lines.push(`}`); // closes the k loop

  // Write each accumulator back to C, one element of the matrix per store.
  for (let mi = 0; mi < mTiles; ++mi) {
    for (let ni = 0; ni < nTiles; ++ni) {
      for (let i = 0; i < opt.vec_width; ++i) {
        const idx = `(m * ${opt.m_unroll} + ${mi * opt.vec_width + i}) * ${N / opt.vec_width} + (n * ${opt.n_unroll / opt.vec_width} + ${ni})`;
        lines.push(`C[${idx}] = result_${mi}_${ni}[${i}];`);
      }
    }
  }

  lines.push(`}`); // closes main

  const dispatch = [M / opt.x_threads / opt.m_unroll, N / opt.y_threads / opt.n_unroll, 1];
  return [lines.join('\n'), dispatch];
}
/**
 * Fixed-configuration kernel generator: 8x8 threads per group, one 4x4 tile
 * of C per thread, accumulated as a single float4x4.
 *
 * @param {number} M - Rows of A and C.
 * @param {number} N - Columns of B and C.
 * @param {number} K - Columns of A / rows of B.
 * @param {object} opt - Accepted for signature parity; not read.
 * @returns {Array} [source, dispatch]: the shader text and a 3-element array
 *   of workgroup counts.
 */
function _generateMatrixMultiplicationKernelOpt(M, N, K, opt) {
  const kVecs = K / 4;
  const nVecs = N / 4;
  const rows = [0, 1, 2, 3];

  // The leading and trailing empty entries give the text its opening and
  // closing newline once joined.
  const lines = [
    ``,
    `[numthreads(8, 8, 1)]`,
    `compute void main(constant float4[] A: register(u0),`,
    `constant float4[] B: register(u1),`,
    `device float4[] C: register(u2),`,
    `float3 threadID : SV_DispatchThreadID) {`,
    `uint m = uint(threadID.x);`,
    `uint n = uint(threadID.y);`,
    `float4x4 result = float4x4(`,
    `0.0, 0.0, 0.0, 0.0,`,
    `0.0, 0.0, 0.0, 0.0,`,
    `0.0, 0.0, 0.0, 0.0,`,
    `0.0, 0.0, 0.0, 0.0);`,
    `for (uint k = 0; k < ${kVecs}; ++k) {`,
    ...rows.map((r) => `float4 a${r} = A[(m * 4 + ${r}) * ${kVecs} + k];`),
    ...rows.map((r) => `float4 b${r} = B[(k * 4 + ${r}) * ${nVecs} + n];`),
    `float4x4 b = float4x4(b0, b1, b2, b3);`,
    `float4x4 a = float4x4(a0, a1, a2, a3);`,
    `result += mul(b, a);`,
    `}`,
    ...rows.map((r) => `C[(m * 4 + ${r}) * ${nVecs} + n] = result[${r}];`),
    `}`,
    ``,
  ];

  return [lines.join('\n'), [M / 8 / 4, N / 8 / 4, 1]];
}
/**
 * Build a compute pipeline for the generated matrix-multiplication kernel and
 * return a function that records one multiplication.
 *
 * @param {object} device - GPU device.
 * @param {number} M - Rows of A and C.
 * @param {number} N - Columns of B and C.
 * @param {number} K - Columns of A / rows of B.
 * @param {object} opt - Kernel generation options, forwarded to the generator.
 * @returns {function(object, object, object): object} mm(A, B, C): binds the
 *   three buffers to slots 0, 1 and 2, records a single dispatch, and returns
 *   the finished command buffer (it does not submit it).
 */
function createMatrixMultiplication(device, M, N, K, opt) {
  // Three storage buffers, all visible to the compute stage.
  const stage = GPUShaderStage.COMPUTE;
  const layoutEntries = [0, 1, 2].map((slot) => ({
    binding: slot,
    visibility: stage,
    type: "storage-buffer",
  }));
  const bindGroupLayout = device.createBindGroupLayout({ bindings: layoutEntries });
  const pipelineLayout = device.createPipelineLayout({
    bindGroupLayouts: [bindGroupLayout],
  });

  // Compile the generated shader into a pipeline.
  const [source, dispatch] = generateMatrixMultiplicationKernelOpt(M, N, K, opt);
  const shaderModule = device.createShaderModule({ code: source });
  const computePipeline = device.createComputePipeline({
    layout: pipelineLayout,
    computeStage: { module: shaderModule, entryPoint: "main" },
  });

  const mm = (A, B, C) => {
    const encoder = device.createCommandEncoder();
    // Each operand is bound with its size in bytes (4 bytes per float).
    const operands = [
      [A, M * K * 4],
      [B, N * K * 4],
      [C, M * N * 4],
    ];
    const bindGroup = device.createBindGroup({
      layout: bindGroupLayout,
      bindings: operands.map(([buffer, size], binding) => ({
        binding,
        resource: { buffer, size },
      })),
    });
    const pass = encoder.beginComputePass();
    pass.setPipeline(computePipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatch(dispatch[0], dispatch[1], dispatch[2]);
    pass.endPass();
    return encoder.finish();
  };
  return mm;
}
/**
 * Validate and benchmark one kernel configuration.
 *
 * Steps: acquire a device, run check() on a quarter-size problem (returning
 * silently if it fails), warm up with 8 chained multiplications, then time 10
 * chained 1024x1024x1024 multiplications including the read-back of the
 * result. The throughput is logged, and when it beats the global `best` the
 * element with id "best" is updated with the kernel source and dispatch
 * parameters.
 *
 * @param {object} opt - Kernel generation options.
 * @returns {Promise<void>}
 */
async function run(opt) {
  if (!navigator.gpu) {
    log("WebGPU not found.");
    return;
  }
  const adapter = await navigator.gpu.requestAdapter();
  const device = await adapter.requestDevice();

  const M = 1024;
  const N = 1024;
  const K = 1024;

  // Correctness gate on a smaller problem before spending time benchmarking.
  const passed = await check(device, M / 4, N / 4, K / 4, opt);
  if (!passed) {
    return;
  }

  const mm = createMatrixMultiplication(device, M, N, K, opt);
  const freshOperands = () => ({
    A: randGPU(device, M * K),
    B: randGPU(device, K * N),
    C: randGPU(device, M * N),
  });
  // Each entry names the (left, right, output) buffers of one multiplication.
  const submitChain = (order, bufs) =>
    device.getQueue().submit(order.map((s) => mm(bufs[s[0]], bufs[s[1]], bufs[s[2]])));

  // warmup
  const warmupOrder = ['ABC', 'CBA', 'ACB', 'BAC', 'ABC', 'CBA', 'ACB', 'BAC'];
  const warmupBufs = freshOperands();
  submitChain(warmupOrder, warmupBufs);
  const warmup_res = await toCPU(device, warmupBufs.C, M * N);
  console.log(warmup_res[0]);

  // timed section: submission plus read-back of C
  const benchOrder = ['ABC', 'CBA', 'ACB', 'BAC', 'CBA', 'ABC', 'CBA', 'ACB', 'BAC', 'CBA'];
  const benchBufs = freshOperands();
  const t0 = performance.now();
  submitChain(benchOrder, benchBufs);
  const result = await toCPU(device, benchBufs.C, M * N);
  console.log(result[0]);
  const t1 = performance.now();

  // 2 floating-point operations per multiply-add, for each multiplication in the chain.
  const flops = M * N * K * 2 * benchOrder.length;
  const elapsed = t1 - t0;
  const gflops = flops / (elapsed * 1e6);
  log("gflops:", gflops, "time:", elapsed);

  if (gflops > best) {
    best = gflops;
    const best_elem = document.getElementById('best');
    const [source, dispatch] = generateMatrixMultiplicationKernelOpt(M, N, K, opt);
    best_elem.textContent = `best: ${gflops.toFixed(2)} gflops\n${source}\n\ndispatch params: ${dispatch}`;
  }
}
/**
 * Sweep kernel-generation options and benchmark every combination with run().
 *
 * Iterates over unroll factors (n, k, m), workgroup sizes (x, y), the
 * matrix-typed kernel toggle and the mad() toggle. Matrix kernels are skipped
 * when m or k is below 4. Each attempt is logged before it is run, and the
 * attempts are executed one after another.
 *
 * @returns {Promise<void>}
 */
async function try_opts() {
  if (!navigator.gpu) {
    log("WebGPU not found. Be sure to use Safari and enable WebGPU in Develop > Experimental Features.");
    log(" ");
    log("If you don't have Safari, you can try TensorFlow's WebGL backend: https://jott.live/html/tf_mm.html");
    return;
  }
  log("Attempting to naively optimize matrix multiplication...\nattempts below:\n");
  for (const n of [4, 8, 16]) {
    for (const k of [4, 8, 16]) {
      for (const m of [2, 4, 8, 16]) {
        for (const x of [2, 4, 8, 16]) {
          for (const y of [2, 4, 8, 16]) {
            for (const use_matrix of [0, 1]) {
              for (const use_mad of [0, 1]) {
                if (use_matrix && (m < 4 || k < 4)) { continue; }
                // pretty much always bad to swap
                for (const swap of [0]) {
                  const opt = {
                    n_unroll: n,
                    m_unroll: m,
                    k_unroll: k,
                    x_threads: x,
                    y_threads: y,
                    // Must be spelled `use_matrix`: that is the property the
                    // kernel generator checks to select the matrix kernel.
                    use_matrix: use_matrix,
                    use_mad: use_mad,
                    swap_threads: swap,
                    vec_width: 4,
                  };
                  log(`n: ${n} k: ${k} m: ${m} tx: ${x} ty: ${y} swap: ${swap} use matrix: ${use_matrix} use mad: ${use_mad}`);
                  await run(opt);
                }
              }
            }
          }
        }
      }
    }
  }
}
// Kick off the option sweep once the window's 'load' event fires.
window.addEventListener('load', try_opts);