// webgpu_mm.js
// BUG FIX: this header line was previously bare text, which evaluates as a
// property access (`webgpu_mm.js`) and throws a ReferenceError on load.

// Best GFLOPS seen so far across all option sweeps (updated by run()).
let best = 0;

function log(...args) {
  // Append one <div> line to the #out element, with each argument
  // stringified and followed by a single space (same as the original
  // per-argument `textContent +=` accumulation).
  const container = document.getElementById('out');
  const line = document.createElement('div');
  line.textContent = args.reduce((text, arg) => text + arg + ' ', '');
  container.appendChild(line);
}

// CPU reference implementation: C (M x N) = A (M x K) * B (K x N),
// all matrices row-major and dense. Writes into C in place.
function mm_ref(A, B, C, M, N, K) {
  for (let row = 0; row < M; ++row) {
    for (let col = 0; col < N; ++col) {
      let acc = 0;
      for (let i = 0; i < K; ++i) {
        acc += A[row * K + i] * B[i * N + col];
      }
      C[row * N + col] = acc;
    }
  }
}

// Validate the generated GPU kernel against the CPU reference on random
// inputs of size M x K, K x N. Returns true when every output element is
// within 1e-4 of the reference, false (and logs details) otherwise.
async function check(device, M, N, K, opt) {
  const mm = createMatrixMultiplication(device, M, N, K, opt);
  const [A, A_cpu] = randGPU(device, M * K, true);
  const [B, B_cpu] = randGPU(device, K * N, true);
  const [C, C_cpu] = randGPU(device, M * N, true);

  device.getQueue().submit([mm(A, B, C)]);
  const gpu_out = await toCPU(device, C, M * N);

  // C_cpu's random contents are overwritten by the reference result.
  mm_ref(A_cpu, B_cpu, C_cpu, M, N, K);

  let max_diff = 0;
  for (let i = 0; i < M * N; ++i) {
    const diff = Math.abs(gpu_out[i] - C_cpu[i]);
    console.assert(diff < 0.0001);
    if (diff > max_diff) {
      max_diff = diff;
    }
  }

  if (max_diff < 0.0001) {
    //log("pass! max diff:", max_diff);
    return true;
  }
  log("fail! max diff:", max_diff);
  console.log(gpu_out, C_cpu);
  return false;
}

// Allocate a mapped GPU storage buffer holding `numel` floats and fill it
// with small random values (divided down, presumably to keep fp32
// accumulation error below the 1e-4 check tolerance — TODO confirm).
// Returns the GPU buffer, or [gpu, Float32Array copy] when
// return_cpu_ref is true.
function randGPU(device, numel, return_cpu_ref = false) {
  const [gpu, cpu] = device.createBufferMapped({
    size: numel * 4, // sizeof float
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });
  const values = new Float32Array(numel);
  for (let i = 0; i < numel; ++i) {
    values[i] = Math.random() / 511.91;
  }
  new Float32Array(cpu).set(values);
  gpu.unmap();
  return return_cpu_ref ? [gpu, values] : gpu;
}

// Copy `numel` floats from a GPU buffer into a fresh MAP_READ staging
// buffer and return the mapped contents as a Float32Array. Awaiting the
// map also acts as a fence for previously submitted GPU work.
async function toCPU(device, gpu_array, numel) {
  const byteLength = numel * 4;
  const staging = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });
  const encoder = device.createCommandEncoder();
  encoder.copyBufferToBuffer(gpu_array, 0, staging, 0, byteLength);
  device.getQueue().submit([encoder.finish()]);
  return new Float32Array(await staging.mapReadAsync());
}

// Generate a WHLSL compute kernel computing C = A * B, plus its dispatch
// dimensions, specialized by `opt`:
//   m_unroll / n_unroll / k_unroll - output tile / K-depth per thread
//   x_threads / y_threads          - threadgroup shape
//   vec_width                      - float4 lane width (expected 4)
//   use_mad / use_matrix / swap_threads - codegen strategy toggles
// Returns [source, dispatch]. Assumes M, N, K are divisible by the
// corresponding thread-count * unroll products — TODO confirm at callers.
function generateMatrixMultiplicationKernelOpt(M, N, K, opt) {
  if (opt.use_matrix) {
    return __generateMatrixMultiplicationKernelOpt(M, N, K, opt);
  }
  if (opt.vec_width > opt.n_unroll) {
    log("error generating kernel. check options");
    // BUG FIX: previously returned a bare "", but every caller destructures
    // [source, dispatch]; keep the tuple shape consistent on the error path.
    return ["", [0, 0, 0]];
  }
  // With no K unrolling, each thread reads A one scalar float at a time.
  let A_type = 'float4';
  if (opt.k_unroll == 1) {
    A_type = 'float';
  }

  let source = ``;
  if (opt.swap_threads) {
    source += `[numthreads(${opt.y_threads}, ${opt.x_threads}, 1)]`;
  } else {
    source += `[numthreads(${opt.x_threads}, ${opt.y_threads}, 1)]`;
  }

  source += `
compute void main(constant ${A_type}[] A : register(u0),
                  constant float4[] B : register(u1),
                  device float4[] C : register(u2),
                  float3 threadID : SV_DispatchThreadID) {`;
  if (opt.swap_threads) {
    source += `
  uint n = uint(threadID.x);
  uint m = uint(threadID.y);\n`;
  } else {
    source += `
  uint m = uint(threadID.x);
  uint n = uint(threadID.y);\n`;
  }
  // Accumulator registers: m_unroll x (n_unroll / vec_width) float4s.
  for (let m = 0; m < opt.m_unroll; ++m) {
    for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
      source += `
  float4 result_${m}_${n} = float4(0.0, 0.0, 0.0, 0.0);`;
    }
  }
  source += `
  for (uint k = 0; k < ${K / opt.k_unroll}; k++) {`;
  // Load the A operands for this K step.
  for (let m = 0; m < opt.m_unroll; ++m) {
    for (let k = 0; k < Math.max(opt.k_unroll / opt.vec_width, 1); ++k) {
      const idx = `(m * ${opt.m_unroll} + ${m}) * ${K / opt.vec_width} + (k * ${opt.k_unroll / opt.vec_width} + ${k})`;
      source += `
    ${A_type} a_${m}_${k} = A[${idx}];`;
    }
  }
  // Load the B operands (always full float4 rows).
  for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
    for (let k = 0; k < opt.k_unroll; ++k) {
      const idx = `(k * ${opt.k_unroll} + ${k}) * ${N / opt.vec_width} + (n * ${opt.n_unroll / opt.vec_width} + ${n})`;
      source += `
    float4 b_${n}_${k} = B[${idx}];`;
    }
  }
  if (opt.use_mad) { // have to unroll to use mad
    // mad() takes float4 operands, so splat each A component into a float4.
    for (let m = 0; m < opt.m_unroll; ++m) {
      for (let k = 0; k < Math.max(opt.k_unroll / opt.vec_width, 1); ++k) {
        if (opt.k_unroll == 1) {
          source += `
    float4 a_${m}_${k}_v = float4(a_${m}_${k}, a_${m}_${k}, a_${m}_${k}, a_${m}_${k});`;
        } else {
          source += `
    float4 a_${m}_${k}_x = float4(a_${m}_${k}.x, a_${m}_${k}.x, a_${m}_${k}.x, a_${m}_${k}.x);
    float4 a_${m}_${k}_y = float4(a_${m}_${k}.y, a_${m}_${k}.y, a_${m}_${k}.y, a_${m}_${k}.y);
    float4 a_${m}_${k}_z = float4(a_${m}_${k}.z, a_${m}_${k}.z, a_${m}_${k}.z, a_${m}_${k}.z);
    float4 a_${m}_${k}_w = float4(a_${m}_${k}.w, a_${m}_${k}.w, a_${m}_${k}.w, a_${m}_${k}.w);`;
        }
      }
    }
  }
  if (opt.k_unroll == 1) {
    for (let k = 0; k < Math.max(opt.k_unroll / opt.vec_width, 1); ++k) {
      for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
        for (let m = 0; m < opt.m_unroll; ++m) {
          if (opt.use_mad) {
            source += `
    result_${m}_${n} = mad(a_${m}_${k}_v, b_${n}_${k}, result_${m}_${n});`;
          } else {
            // BUG FIX: a_${m}_${k} is a scalar float here (A_type is
            // 'float'), so the previous `.v` swizzle was invalid and the
            // generated shader could not compile. Scalar * float4 mul
            // mirrors the component form used in the k_unroll > 1 branch.
            source += `
    result_${m}_${n} += mul(a_${m}_${k}, b_${n}_${k});`;
          }
        }
      }
    }
  } else {
    // Expand the dot products component-by-component across x/y/z/w.
    const lets = ['x', 'y', 'z', 'w'];
    for (let l of lets) {
      for (let k = 0; k < Math.max(opt.k_unroll / opt.vec_width, 1); ++k) {
        for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
          for (let m = 0; m < opt.m_unroll; ++m) {
            if (opt.use_mad) {
              source += `
    result_${m}_${n} = mad(a_${m}_${k}_${l}, b_${n}_${k * opt.vec_width + lets.indexOf(l)}, result_${m}_${n});`;
            } else {
              source += `
    result_${m}_${n} += mul(a_${m}_${k}.${l}, b_${n}_${k * opt.vec_width + lets.indexOf(l)});`;
            }
          }
        }
      }
    }
  }
  source += `
  }`;
  // Write the accumulated tile back to C.
  for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
    for (let m = 0; m < opt.m_unroll; ++m) {
      const idx = `(m * ${opt.m_unroll} + ${m}) * ${N / opt.vec_width} + (n * ${opt.n_unroll / opt.vec_width} + ${n})`;
      source += `
  C[${idx}] = result_${m}_${n};`;
    }
  }
  source += `\n}`;
  // Each thread covers m_unroll x n_unroll outputs.
  const dispatch = [M / opt.x_threads / opt.m_unroll, N / opt.y_threads / opt.n_unroll, 1];
  if (opt.swap_threads) {
    const x = dispatch[0];
    const y = dispatch[1];
    dispatch[0] = y;
    dispatch[1] = x;
  }
  return [source, dispatch];
}

// use matrices instead of vectors
// float4x4 variant of the kernel generator: accumulates 4x4 output tiles of
// C = A * B with matrix-matrix mul() instead of scalar-vector products.
// Returns [source, dispatch].
// NOTE(review): the generated indexing assumes vec_width == 4 and
// m_unroll, k_unroll >= 4 (the caller's sweep skips smaller values when
// use_matrix is set) — confirm before reusing with other options.
function __generateMatrixMultiplicationKernelOpt(M, N, K, opt) {
  let source = ``;
  source += `[numthreads(${opt.x_threads}, ${opt.y_threads}, 1)]`;
  source += `
compute void main(constant float4[] A : register(u0),
                  constant float4[] B : register(u1),
                  device float4[] C : register(u2),
                  float3 threadID : SV_DispatchThreadID) {`;
    source += `
  uint m = uint(threadID.x);
  uint n = uint(threadID.y);\n`;
  // Accumulators: one zero-initialized float4x4 per 4x4 output sub-tile.
  for (let m = 0; m < opt.m_unroll / opt.vec_width; ++m) {
    for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {

      source += `
  float4x4 result_${m}_${n} = float4x4(${"0.0, ".repeat(15)}0.0);`;
    }
  }

  source += `
  for (uint k = 0; k < ${K / opt.k_unroll}; k++) {`;

  // load A vecs
  for (let m = 0; m < opt.m_unroll; ++m) {
    for (let k = 0; k < opt.k_unroll / opt.vec_width; ++k) {
      const idx = `(m * ${opt.m_unroll} + ${m}) * ${K / opt.vec_width} + (k * ${opt.k_unroll / opt.vec_width} + ${k})`;
      source += `
    float4 a_${m}_${k} = A[${idx}];`;
    }
  }
  // make A matrices: pack vec_width consecutive A row-vectors into a float4x4
  for (let m = 0; m < opt.m_unroll / opt.vec_width; ++m) {
    for (let k = 0; k < opt.k_unroll / opt.vec_width; ++k) {
      source += `
    float4x4 a_m_${m}_${k} = float4x4(`;
      for (let i = 0; i < opt.vec_width; ++i) {
        source += `a_${m * opt.vec_width + i}_${k}`;
        if (i != opt.vec_width - 1) {
          source += `, `;
        }
      }
      source += `);`;
    }
  }
  // load B vecs
  for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
    for (let k = 0; k < opt.k_unroll; ++k) {
      const idx = `(k * ${opt.k_unroll} + ${k}) * ${N / opt.vec_width} + (n * ${opt.n_unroll / opt.vec_width} + ${n})`;
      source += `
    float4 b_${n}_${k} = B[${idx}];`;
    }
  }
  // make B matrices: pack vec_width consecutive B row-vectors into a float4x4
  for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
    for (let k = 0; k < opt.k_unroll / opt.vec_width; ++k) {
      source += `
    float4x4 b_m_${n}_${k} = float4x4(`;
      for (let i = 0; i < opt.vec_width; ++i) {
        source += `b_${n}_${k * opt.vec_width + i}`;
        if (i != opt.vec_width - 1) {
          source += `, `;
        }
      }
      source += `);`;
    }
  }

  // multiply matrices
  // NOTE(review): operand order is mul(b, a) — the same convention the
  // hand-written _generate... variant below uses.
  for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
    for (let m = 0; m < opt.m_unroll / opt.vec_width; ++m) {
      for (let k = 0; k < opt.k_unroll / opt.vec_width; ++k) {
        source += `
    result_${m}_${n} += mul(b_m_${n}_${k}, a_m_${m}_${k});`
      }
    }
  }

  source += `
  }`; // k

  // write to C, one float4 row of each result tile at a time
  for (let m = 0; m < opt.m_unroll / opt.vec_width; ++m) {
    for (let n = 0; n < opt.n_unroll / opt.vec_width; ++n) {
      for (let i = 0; i < opt.vec_width; ++i) {
        const idx = `(m * ${opt.m_unroll} + ${m * opt.vec_width + i}) * ${N / opt.vec_width} + (n * ${opt.n_unroll / opt.vec_width} + ${n})`;
        source += `
  C[${idx}] = result_${m}_${n}[${i}];`
      }
    }
  }
  

  source += `
}`; // main
  // Each thread computes m_unroll x n_unroll outputs.
  const dispatch = [M / opt.x_threads / opt.m_unroll, N / opt.y_threads / opt.n_unroll, 1];
  return [source, dispatch];
}

// Hand-written baseline kernel: fixed 8x8 threadgroups, each thread
// accumulating a 4x4 tile of C with float4x4 mul(). `opt` is accepted for
// signature parity with the generated variants but is not consulted.
// Returns [source, dispatch].
function _generateMatrixMultiplicationKernelOpt(M, N, K, opt) {
  const kVecs = K / 4; // float4s per row of A / per column group of B
  const nVecs = N / 4; // float4s per row of B and C
  const source = `
[numthreads(8, 8, 1)]
compute void main(constant float4[] A: register(u0),
                  constant float4[] B: register(u1),
                  device float4[] C: register(u2),
                  float3 threadID : SV_DispatchThreadID) {
  uint m = uint(threadID.x);
  uint n = uint(threadID.y);
  float4x4 result = float4x4(
    0.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0);
  for (uint k = 0; k < ${kVecs}; ++k) {
    float4 a0 = A[(m * 4 + 0) * ${kVecs} + k];
    float4 a1 = A[(m * 4 + 1) * ${kVecs} + k];
    float4 a2 = A[(m * 4 + 2) * ${kVecs} + k];
    float4 a3 = A[(m * 4 + 3) * ${kVecs} + k];

    float4 b0 = B[(k * 4 + 0) * ${nVecs} + n];
    float4 b1 = B[(k * 4 + 1) * ${nVecs} + n];
    float4 b2 = B[(k * 4 + 2) * ${nVecs} + n];
    float4 b3 = B[(k * 4 + 3) * ${nVecs} + n];

    float4x4 b = float4x4(b0, b1, b2, b3);
    float4x4 a = float4x4(a0, a1, a2, a3);

    result += mul(b, a);
  }
  C[(m * 4 + 0) * ${nVecs} + n] = result[0];
  C[(m * 4 + 1) * ${nVecs} + n] = result[1];
  C[(m * 4 + 2) * ${nVecs} + n] = result[2];
  C[(m * 4 + 3) * ${nVecs} + n] = result[3];
}
`;
  // 8 threads per axis, each covering a 4-wide tile.
  const dispatch = [M / 8 / 4, N / 8 / 4, 1];
  return [source, dispatch];
}

// Compile the generated kernel for the given sizes/options into a compute
// pipeline and return a recorder: mm(A, B, C) encodes one dispatch writing
// A * B into C and returns the finished command buffer.
function createMatrixMultiplication(device, M, N, K, opt) {
  const visibility = GPUShaderStage.COMPUTE;
  const type = "storage-buffer";

  // Three storage buffers: A (u0), B (u1), C (u2).
  const bindGroupLayout = device.createBindGroupLayout({
    bindings: [0, 1, 2].map((binding) => ({ binding, visibility, type })),
  });

  const pipelineLayout = device.createPipelineLayout({
    bindGroupLayouts: [bindGroupLayout],
  });

  const [source, dispatch] = generateMatrixMultiplicationKernelOpt(M, N, K, opt);

  const computePipeline = device.createComputePipeline({
    layout: pipelineLayout,
    computeStage: {
      module: device.createShaderModule({ code: source }),
      entryPoint: "main",
    },
  });

  // Recorder closure; a fresh bind group is built per call so the same
  // pipeline can be reused with different buffer triples.
  const mm = (A, B, C) => {
    const encoder = device.createCommandEncoder();
    const bindGroup = device.createBindGroup({
      layout: bindGroupLayout,
      bindings: [
        { binding: 0, resource: { buffer: A, size: M * K * 4 } },
        { binding: 1, resource: { buffer: B, size: N * K * 4 } },
        { binding: 2, resource: { buffer: C, size: M * N * 4 } },
      ],
    });

    const pass = encoder.beginComputePass();
    pass.setPipeline(computePipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatch(dispatch[0], dispatch[1], dispatch[2]);
    pass.endPass();
    return encoder.finish();
  };

  return mm;
}

// Benchmark one kernel configuration: verify correctness on a smaller
// problem first, then time 10 chained 1024^3 matmuls and report GFLOPS.
// Updates the global `best` / #best element when a new record is set.
async function run(opt) {
  if (!navigator.gpu) {
    log("WebGPU not found.");
    return;
  }
  const adapter = await navigator.gpu.requestAdapter();
  const device = await adapter.requestDevice();

  const M = 1024;
  const N = 1024;
  const K = 1024;

  // Correctness check on a 256^3 problem (CPU reference is O(n^3)).
  let f = await check(device, M / 4, N / 4, K / 4, opt);
  //let f = await check(device, M, N, K, opt);
  if (!f) {
    return;
  }

  const mm = createMatrixMultiplication(device, M, N, K, opt);

  let A = randGPU(device, M * K);
  let B = randGPU(device, K * N);
  let C = randGPU(device, M * N);

  // warmup
  // The three buffers are rotated through input/output roles so each
  // dispatch depends on the previous one.
  device.getQueue().submit([
    mm(A, B, C),
    mm(C, B, A),
    mm(A, C, B),
    mm(B, A, C),
    mm(A, B, C),
    mm(C, B, A),
    mm(A, C, B),
    mm(B, A, C),
  ]);
  // Readback forces the warmup work to complete before timing starts.
  const warmup_res = await toCPU(device, C, M * N);
  console.log(warmup_res[0]);

  //log("benchmarking...");
  A = randGPU(device, M * K);
  B = randGPU(device, K * N);
  C = randGPU(device, M * N);
  const t0 = performance.now();
  device.getQueue().submit([
    mm(A, B, C),
    mm(C, B, A),
    mm(A, C, B),
    mm(B, A, C),
    mm(C, B, A),
    mm(A, B, C),
    mm(C, B, A),
    mm(A, C, B),
    mm(B, A, C),
    mm(C, B, A),
  ]);

  // The awaited readback acts as a fence: t1 is taken only after all 10
  // matmuls (and the copy) have finished on the GPU.
  const result = await toCPU(device, C, M * N);
  console.log(result[0]);

  const t1 = performance.now();
  // 10 matmuls, 2 flops (mul + add) per inner-product term.
  const flops = M * N * K * 2 * 10;
  // elapsed ms * 1e6 = ns, so this is flops-per-nanosecond == GFLOP/s.
  const gflops = flops / ((t1 - t0) * 1e6);
  log("gflops:", gflops, "time:", t1 - t0);
  if (gflops > best) {
    best = gflops;
    let best_elem = document.getElementById('best');
    // Regenerate the kernel source so the winning shader can be displayed.
    const [source, dispatch] = generateMatrixMultiplicationKernelOpt(M, N, K, opt);
    best_elem.textContent = 'best: ' + gflops.toFixed(2) + ' gflops\n' + source + '\n\n'
     + 'dispatch params: ' + dispatch;
  }
}

// Naive grid search over kernel-generation parameters: for each unroll /
// threadgroup / strategy combination, run a correctness check plus a
// benchmark (via run) and log the result.
async function try_opts() {
  if (!navigator.gpu) {
    log("WebGPU not found. Be sure to use Safari and enable WebGPU in Develop > Experimental Features.");
    log(" ");
    log("If you don't have Safari, you can try TensorFlow's WebGL backend: https://jott.live/html/tf_mm.html");
    return;
  }
  log("Attempting to naively optimize matrix multiplication...\nattempts below:\n");
  for (let n of [4, 8, 16]) {
    for (let k of [4, 8, 16]) {
      for (let m of [2, 4, 8, 16]) {
        for (let x of [2, 4, 8, 16]) {
          for (let y of [2, 4, 8, 16]) {
            for (let use_matrix of [0, 1]) {
              for (let use_mad of [0, 1]) {
                // float4x4 kernels need at least one full vector in m and k
                if (use_matrix && (m < 4 || k < 4)) { continue; }
                // pretty much always bad to swap
                for (let swap of [0]) {
                  const opt = {
                    n_unroll: n,
                    m_unroll: m,
                    k_unroll: k,
                    x_threads: x,
                    y_threads: y,
                    // BUG FIX: the kernel generator reads opt.use_matrix;
                    // this key was previously named `matrix`, so the
                    // float4x4 path was never exercised by the sweep.
                    use_matrix: use_matrix,
                    use_mad: use_mad,
                    swap_threads: swap,
                    vec_width: 4
                  };
                  log(`n: ${n} k: ${k} m: ${m} tx: ${x} ty: ${y} swap: ${swap} use matrix: ${use_matrix} use mad: ${use_mad}`);
                  await run(opt);
                }
              }
            }
          }
        }
      }
    }
  }
}
// Kick off the parameter sweep once the page (and the #out / #best
// elements the loggers write to) has loaded.
window.addEventListener('load', try_opts);