
[MPS] MPSNDArray error: product of dimension sizes > 2**31 #84039


Closed
junukwon7 opened this issue Aug 25, 2022 · 38 comments
Labels
module: mps (Related to Apple Metal Performance Shaders framework), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@junukwon7

junukwon7 commented Aug 25, 2022

🐛 Describe the bug

Full error message (no traceback):

AppleInternal/Library/BuildRoots/20d6c351-ee94-11ec-bcaf-7247572f23b4/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:705: failed assertion '[MPSNDArray initWithDevice:descriptor:] Error: product of dimension sizes > 2**31 '

How to reproduce

  1. Install stable-diffusion using the instructions for macOS.
  2. Run python scripts/txt2img.py --prompt "a horse" --plms --n_samples 1 --n_rows 1 --n_iter 1: this runs well.
  3. Add the flags --W 1024 --H 1024 (width and height, respectively) and it will raise the error.
  • The default width and height are 512, so running without the flags means 512x512.

I'm looking for a way to reproduce it without installing the whole pipeline, and I'll update this issue soon.

Edit: repro by @Birch-san

from torch import einsum, ones
import argparse

parser = argparse.ArgumentParser(description='mpsndarray test')
parser.add_argument('--n_samples', type=int, default=2)
args = parser.parse_args()
n_samples = args.n_samples

einsum('b i d, b j d -> b i j', ones(16 * n_samples, 4096, 40, device='mps'), ones(16 * n_samples, 4096, 40, device='mps')).shape

print(n_samples, 'passed')

It fails when n_samples is 2 or greater than 7, which looks pretty weird.

About VRAM?

As you would all expect, the error seems to be something about VRAM. However, a few questions remain.

  1. The error seems to come from a size exceeding INT_MAX (2**31).
    The error doesn't occur at --W 512 --H 512 or lower resolutions.
  2. The error is a software issue.
    Unlike errors like CUDA out of memory, this error isn't about the real memory limit.
    If the error were due to lack of VRAM, the command above (--W 1024 --H 1024) should run on an M1 Max with 64GB, since --W 512 --H 512 runs well on my M1 8GB MacBook. Also, the 2**31 limit is a fixed number, which would not change with the current memory usage.

So, my expectation is that something is being computed in 32-bit, which shouldn't be.
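
For a rough sense of scale (back-of-the-envelope arithmetic only, based on the repro shapes above and assuming the token count scales with image resolution; whether the limit counts elements or bytes isn't confirmed here):

heads_x_batch = 16           # first dim of the repro tensors for n_samples == 1
tokens_512  = 64 * 64        # 512x512 image -> 64x64 latent -> 4096 tokens
tokens_1024 = 128 * 128      # 1024x1024 image -> 128x128 latent -> 16384 tokens
print(heads_x_batch * tokens_512 ** 2)    # 268_435_456, well below 2**31
print(heads_x_batch * tokens_1024 ** 2)   # 4_294_967_296 = 2**32, above 2**31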

This might not be torch's problem - maybe (surely) Metal's.

However, any help will be gratefully accepted.

Thanks.

Versions

PyTorch version: 1.13.0.dev20220824
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.5.1 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: version 3.24.1
Libc version: N/A

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:14) [Clang 12.0.1 ] (64-bit runtime)
Python platform: macOS-12.5.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.2
[pip3] pytorch-lightning==1.7.2
[pip3] torch==1.13.0.dev20220824
[pip3] torch-fidelity==0.3.0
[pip3] torchaudio==0.13.0.dev20220824
[pip3] torchmetrics==0.9.3
[pip3] torchvision==0.14.0.dev20220824
[conda] numpy 1.23.2 py38h579d673_0 conda-forge
[conda] pytorch 1.13.0.dev20220824 py3.8_0 pytorch-nightly
[conda] pytorch-lightning 1.7.2 pypi_0 pypi
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torchaudio 0.13.0.dev20220824 py38_cpu pytorch-nightly
[conda] torchmetrics 0.9.3 pypi_0 pypi
[conda] torchvision 0.14.0.dev20220824 py38_cpu pytorch-nightly

cc @kulinseth @albanD

@DenisVieriu97 DenisVieriu97 added the module: mps Related to Apple Metal Performance Shaders framework label Aug 25, 2022
@Karric

Karric commented Aug 28, 2022

I had the same issue, but with an M1 Mac + TensorFlow (for Mac); no PyTorch, but I am using NumPy. Not sure if that's helpful.

@junukwon7
Author

I had the same issue, but with an M1 Mac + TensorFlow (for Mac); no PyTorch, but I am using NumPy. Not sure if that's helpful.

@Karric Thanks for your reply. It seems the error is something about Metal, then. Could you provide more information about the situation with the M1 Mac + TensorFlow?

Thanks.

@dagitses dagitses added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Aug 29, 2022
@i3oc9i

i3oc9i commented Sep 5, 2022

I have the same issue on my Mac Studio Ultra with 128GB :(

@pcuenca

pcuenca commented Sep 6, 2022

We are having the same problem in 🤗 diffusers, but I fail to see where the offending dimensions are being generated or used. The error arises when attempting to run the unet module of the model on input latents with shape (4, 4, 64, 64). This is what torchinfo.summary has to say about the output shapes of that model's layers:

===================================================================================================================
Layer (type:depth-idx)                                            Output Shape              Param #
===================================================================================================================
UNet2DConditionModel                                              [4, 4, 64, 64]            510,807,680
├─Conv2d: 1-4                                                     [4, 320, 64, 64]          (recursive)
├─Timesteps: 1-2                                                  [4, 320]                  --
├─TimestepEmbedding: 1-3                                          [4, 1280]                 --
│    └─Linear: 2-1                                                [4, 1280]                 410,880
│    └─SiLU: 2-2                                                  [4, 1280]                 --
│    └─Linear: 2-3                                                [4, 1280]                 1,639,680
├─Conv2d: 1-4                                                     [4, 320, 64, 64]          (recursive)
├─ModuleList: 1                                                   --                        --
│    └─CrossAttnDownBlock2D: 2-4                                  [4, 320, 32, 32]          --
│    └─CrossAttnDownBlock2D: 2-5                                  [4, 640, 16, 16]          --
│    └─CrossAttnDownBlock2D: 2-6                                  [4, 1280, 8, 8]           --
│    └─DownBlock2D: 2-7                                           [4, 1280, 8, 8]           --
├─UNetMidBlock2DCrossAttn: 1-5                                    [4, 1280, 8, 8]           --
│    └─ModuleList: 2                                              --                        --
│    │    └─ResnetBlock2D: 3-1                                    [4, 1280, 8, 8]           31,138,560
│    └─ModuleList: 2                                              --                        --
│    │    └─SpatialTransformer: 3-2                               [4, 1280, 8, 8]           34,760,960
│    └─ModuleList: 2                                              --                        --
│    │    └─ResnetBlock2D: 3-3                                    [4, 1280, 8, 8]           31,138,560
├─ModuleList: 1                                                   --                        --
│    └─UpBlock2D: 2-8                                             [4, 1280, 16, 16]         --
│    └─CrossAttnUpBlock2D: 2-9                                    [4, 1280, 32, 32]         --
│    └─CrossAttnUpBlock2D: 2-10                                   [4, 640, 64, 64]          --
│    └─CrossAttnUpBlock2D: 2-11                                   [4, 320, 64, 64]          --
├─GroupNorm: 1-6                                                  [4, 320, 64, 64]          640
├─SiLU: 1-7                                                       [4, 320, 64, 64]          --
├─Conv2d: 1-8                                                     [4, 4, 64, 64]            11,524
===================================================================================================================

It looks to me like those sizes shouldn't trigger any internal Metal limitation on 32-bit sizes being exceeded. Is there any hint as to what might be causing this? Thank you!

@pcuenca

pcuenca commented Sep 6, 2022

In fact, as @FahimF reports, the crash does not occur for larger batches such as those that would result in latent shapes of (6, 4, 64, 64) or (8, 4, 64, 64). Is there any way to debug where the problem is being triggered?

@Birch-san

Birch-san commented Sep 7, 2022

I have a repro (from stable-diffusion's attention forward-pass).

https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L180

from torch import einsum, ones
# crashes with "product of dimension sizes > 2**31"
# this is equivalent to invoking stable-diffusion with --n_samples 2
einsum('b i d, b j d -> b i j', ones(32, 4096, 40, device='mps'), ones(32, 4096, 40, device='mps')).shape

# doesn't crash, even though it's bigger
# this is equivalent to invoking stable-diffusion with --n_samples 3
einsum('b i d, b j d -> b i j', ones(48, 4096, 40, device='mps'), ones(48, 4096, 40, device='mps')).shape

Perhaps related to this?

#80808 (comment)

Depending on the size/number of dimensions, different algorithm might get selected leading to small differences.

Perhaps at the larger size of 48 we hit a tensor size that persuades einsum() to use a different algorithm, which doesn't crash?

I'm not sure whether this is necessarily the same line of stable-diffusion code on which @junukwon7's run is crashing (i.e. due to image dimensions), but if we're lucky it's the same mechanism.

@junukwon7
Author

junukwon7 commented Sep 7, 2022

@Birch-san

Thanks for your repro!
I've tried running a bash loop over it, and found out that 16*i fails when i == 2 or i >= 8.

from torch import einsum, ones
import argparse

parser = argparse.ArgumentParser(description='mpsndarray test')
parser.add_argument('--n_samples', type=int, default=2)
args = parser.parse_args()
n_samples = args.n_samples

einsum('b i d, b j d -> b i j', ones(16 * n_samples, 4096, 40, device='mps'), ones(16 * n_samples, 4096, 40, device='mps')).shape

print(n_samples, 'passed')

I saved this as a Python file and ran:

for i in {1..80}; do python3 mpsndarray.py --n_samples ${i}; done

The result was:

1 pass
2 ERR
3 ~ 7 pass
8 ~ 35 ERR
36 ~ RuntimeError: Invalid buffer size: 36.00 GB

M1 Max MBP with 64GB RAM
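
For reference (back-of-the-envelope arithmetic only; it matches the n_samples == 2 crash and the 36 GB buffer error, but doesn't explain why 3-7 pass): the einsum output of shape (16*n, 4096, 4096) in float32 is exactly n GiB.

out_bytes = lambda n: 16 * n * 4096 * 4096 * 4   # float32 output of shape (16*n, 4096, 4096)
print(out_bytes(1))    # 1_073_741_824 bytes = 1 GiB, passes
print(out_bytes(2))    # 2_147_483_648 bytes = exactly 2**31, crashes
print(out_bytes(36))   # 38_654_705_664 bytes = 36 GiB, matches "Invalid buffer size: 36.00 GB"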

@pcuenca

pcuenca commented Sep 9, 2022

Regarding the reproducibility code, @patil-suraj replaced einsum with matmul in the diffusers codebase and the problem still occurs in exactly the same way.

In fact, following up on @Birch-san's example, this behaves exactly like they describe:

import torch

# Crashes
t1 = torch.rand((32, 4096, 40))
t2 = torch.rand((32, 4096, 40))
torch.matmul(t1.to("mps"), t2.to("mps").transpose(1, 2)).shape

# Doesn't crash, even though it's bigger
t1 = torch.rand((48, 4096, 40))
t2 = torch.rand((48, 4096, 40))
torch.matmul(t1.to("mps"), t2.to("mps").transpose(1, 2)).shape

In addition, this is something different but very weird:

import torch
from torch import einsum

t1 = torch.rand((48, 4096, 40))
t2 = torch.rand((48, 4096, 40))

x_mps = einsum('b i d, b j d -> b i j', t1.to('mps'), t2.to('mps'))
x_cpu = einsum('b i d, b j d -> b i j', t1, t2)
print((x_mps.to("cpu") - x_cpu).abs().max())
# tensor(9.4567) !?

As you can see, the output seems to be wrong. However, if we do other operations first, the einsum is now correct:

t1 = torch.rand((48, 4096, 40))
t2 = torch.rand((48, 4096, 40))

x_mm_mps = torch.matmul(t1.to("mps"), t2.to("mps").transpose(1, 2))
x_mm_cpu = torch.matmul(t1, t2.transpose(1, 2))
print((x_mm_mps.to("cpu") - x_mm_cpu).abs().max())
# tensor(0.)

x_mps = einsum('b i d, b j d -> b i j', t1.to('mps'), t2.to('mps'))
x_cpu = einsum('b i d, b j d -> b i j', t1, t2)
print((x_mps.to("cpu") - x_cpu).abs().max())
# tensor(0.)

Is there any guidance or hint as to what could be going on here?

TL;DR:

  • torch.matmul crashes for some sizes when using mps.
  • torch.einsum produces inconsistent results "sometimes". Doing some previous operations with similar memory requirements seems to prevent the issue, maybe?

@DenisVieriu97
Collaborator

Thank you for the report @junukwon7 and for the repro code @pcuenca, @Birch-san. This is an issue on the MPS side - we are working on a fix.

@malfet
Contributor

malfet commented Sep 20, 2022

I have an even simpler 1-line reproducer:

% python -c "import torch;print(torch.ones(32, 4096, 4096, device='mps').shape)"
/AppleInternal/Library/BuildRoots/5381bdfb-27e8-11ed-bdc1-96898e02b808/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:705: failed assertion `[MPSTemporaryNDArray initWithDevice:descriptor:] Error: product of dimension sizes > 2**31'

@malfet
Contributor

malfet commented Sep 20, 2022

And I can reproduce the matmul crash using the following pure-MPS script:

import MetalPerformanceShadersGraph

let graph = MPSGraph()
let x = graph.constant(1, shape: [32, 4096, 40], dataType: .float32)
let y = graph.constant(1, shape: [32, 40, 4096], dataType: .float32)
let z = graph.matrixMultiplication(primary: x, secondary: y, name: nil)
let device = MTLCreateSystemDefaultDevice()!
let buf = device.makeBuffer(length: 16384)!
let td = MPSGraphTensorData(buf, shape: [64, 64], dataType: .int32)
let cmdBuf = MPSCommandBuffer(from: device.makeCommandQueue()!)
graph.encode(to: cmdBuf, feeds: [:], targetOperations: nil, resultsDictionary: [z:td], executionDescriptor: nil)
cmdBuf.commit()

@kulinseth
Collaborator

And I can reproduce the matmul crash using the following pure-MPS script: […]

Thanks @malfet. As Denis mentioned above, this is an MPS-side issue, so I expected it would be reproducible by a simple test case outside of Torch as well. There were a couple of issues: one was that the heap logic which allocates the size for an MPSNDArray (the basic struct for allocating a tensor in MPS) was incorrectly erroring out on the product of dimension sizes when compared with the allocated buffer size; the other was a restriction on how large a buffer can be allocated. Both conditions have been updated and fixed, and the fixes will be available in upcoming updates. Having said this, we do have a current limit on the size of an NDArray that can be created, which is 2^32. So, for instance, (63, 4096, 4096) in 32-bit values will work, but 64 will not.
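
For reference, if the 2^32 limit is counted in total bytes (an assumption, not stated above), the (63 vs 64, 4096, 4096) example lines up:

print(63 * 4096 * 4096 * 4)   # 4_227_858_432 bytes, just under 2**32
print(64 * 4096 * 4096 * 4)   # 4_294_967_296 bytes, exactly 2**32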

@junukwon7
Author

@kulinseth Thanks for your response.
May I ask why the devs limited the size to 32-bit values?

@malfet
Contributor

malfet commented Sep 27, 2022

@junukwon7 I don't know the exact details, but I assume using 32-bit indexes results in faster kernels, as one can perform twice as many 32-bit operations per SIMD instruction compared to 64-bit ones.

Consider the AVX2 instruction set as an example: _mm256_add_epi32 and _mm256_add_epi64 have the same latency, but the former can do 8 32-bit additions while the latter does 4 64-bit ones. That is, more threads are needed to accomplish the same computation on CPU using 64-bit indices than 32-bit ones.

And as the majority of tensors have fewer than 4 billion elements, it makes perfect sense to use 32-bit indices while performing operations on them.
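
As a concrete illustration (back-of-the-envelope only, not the actual MPS code), the byte size of the failing (32, 4096, 4096) float32 tensor is one past what a signed 32-bit field can hold:

nbytes = 32 * 4096 * 4096 * 4                # 2_147_483_648 == 2**31
wrapped = (nbytes + 2**31) % 2**32 - 2**31   # value if stored in a signed 32-bit integer
print(nbytes > 2**31 - 1, wrapped)           # True -2147483648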

@Vargol

Vargol commented Oct 12, 2022

Hi.

Is the fix for this in the nightlies? I'm still seeing this error under some strange circumstances.

For example, in the following code, note that the first tensor is actually larger than the second, yet it is the second that fails due to the > 2**31 error:

import torch

mps_device = torch.device("mps")

ix = torch.zeros(1, 256, 960, 1024, device=mps_device)
iy = torch.nn.functional.interpolate(ix, scale_factor=2.0, mode="nearest")

print(str(iy.device))
print(str(iy.shape))

ix = torch.zeros(1, 256, 512, 1024, device=mps_device)
iy = torch.nn.functional.interpolate(ix, scale_factor=2.0, mode="nearest")

print(str(iy.device))
print(str(iy.shape))

gives

% python test.py
mps:0
torch.Size([1, 256, 1920, 2048])
/AppleInternal/Library/BuildRoots/a0876c02-1788-11ed-b9c4-96898e02b808/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:705: failed assertion `[MPSTemporaryNDArray initWithDevice:descriptor:] Error: product of dimension sizes > 2**31'
zsh: abort      python test.py

using this version of pytorch

 % pip list | grep torch
torch              1.14.0.dev20221011
torchaudio         0.13.0.dev20221010
torchvision        0.15.0.dev20221010

It seems to be that exact number of values that fails:

if ix = torch.zeros(1, 256, 512, 1025, device=mps_device), torch.nn.functional.interpolate works
if ix = torch.zeros(2, 256, 256, 1024, device=mps_device), torch.nn.functional.interpolate fails.
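
For reference (my arithmetic only, not a confirmed explanation): the interpolate output in the failing cases is exactly 2**31 bytes of float32, while the output of the first, larger case is well past that and still works:

fp32 = 4
print(1 * 256 * 1920 * 2048 * fp32)   # 4_026_531_840 bytes - works
print(1 * 256 * 1024 * 2048 * fp32)   # 2_147_483_648 bytes = exactly 2**31 - fails
print(2 * 256 * 512 * 2048 * fp32)    # 2_147_483_648 bytes = exactly 2**31 - fails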

@kulinseth
Collaborator

Is the fix for this in the nightlies? I'm still seeing this error under some strange circumstances. […]

@Vargol, the fix is not in the PyTorch nightlies; the required fix is needed on the OS side, in the MPS library. Can you please try with the latest Ventura beta build? Locally it seems to work and outputs:

mps:0
torch.Size([1, 256, 1920, 2048])
mps:0
torch.Size([1, 256, 1024, 2048])

@Vargol

Vargol commented Oct 13, 2022

No plans on installing a Beta OS, especially one where some of the tools I use are not fully working yet :-)

@kulinseth
Collaborator

No plans on installing a Beta OS, especially one where some of the tools I use are not fully working yet :-)

Sure, you can wait till the official release and try it out.

@pcuenca

pcuenca commented Oct 17, 2022

Update after testing PyTorch 1.13.0 (from the test channel) on Ventura 13.0 Beta (22A5373b).

  • My previous crash repro, and the ones reported by @malfet, now work as reported by @kulinseth :)
  • However, this Python script crashes:
import torch

t1 = torch.ones((32, 4096, 4096))
t2 = torch.ones((32, 4096, 1))
torch.matmul(t1.to("mps"), t2.to("mps")).shape

Output:

/AppleInternal/Library/BuildRoots/48415f5a-4155-11ed-be84-7ef33c48bc85/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:705: failed assertion `[MPSTemporaryNDArray initWithDevice:descriptor:] Error: product of dimension sizes > 2**31'

But this Swift script (just changed the dimensions in @malfet's example) doesn't:

import MetalPerformanceShadersGraph

let graph = MPSGraph()
let x = graph.constant(1, shape: [32, 4096, 4096], dataType: .float32)
let y = graph.constant(1, shape: [32, 4096, 1], dataType: .float32)
let z = graph.matrixMultiplication(primary: x, secondary: y, name: nil)
let device = MTLCreateSystemDefaultDevice()!
let buf = device.makeBuffer(length: 16384)!
let td = MPSGraphTensorData(buf, shape: [64, 64], dataType: .int32)
let cmdBuf = MPSCommandBuffer(from: device.makeCommandQueue()!)
graph.encode(to: cmdBuf, feeds: [:], targetOperations: nil, resultsDictionary: [z:td], executionDescriptor: nil)
cmdBuf.commit()
  • The following diffusers code still crashes:
from diffusers import StableDiffusionPipeline

sdm = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
).to("mps")

prompt = "A painting of a squirrel eating a burger"
num_samples = 2

images = sdm(prompt, num_images_per_prompt=num_samples).images
for i, image in enumerate(images):
    image.save(f"squirrel_{i}.png")
  • The previous diffusers code does not crash if num_samples is increased to 3, but the resulting images are corrupt.

@JacopoMangiavacchi

I'm having the same issue with the VToonify model from https://github.com/williamyang1991/VToonify

@pcuenca

pcuenca commented Dec 2, 2022

Gently pinging @kulinseth. Do we have confirmation of whether this lies on the MPS side or in PyTorch itself, given that the MPSGraph version worked fine when I checked?
#84039 (comment)

@kulinseth
Collaborator

@pcuenca Sorry for the delay in response. The underlying issue you linked in your comment occurs during creation of the NDArray buffer (in the MPS framework), which is shared between PyTorch and MPSGraph.
You are indeed right, it's coming from PyTorch. The difference in behavior comes from how PyTorch's torch.matmul layer is implemented (the generic implementation uses transposes and reshapes while building the graph).

I will take a look at how we can improve this layer, and I will update here with more details soon.

@pcuenca

pcuenca commented Jan 4, 2023

Hi @kulinseth, sorry for the ping :) It isn't urgent, but a broad time-frame estimation would be awesome here. Thanks a lot for your work!

@kulinseth
Collaborator

kulinseth commented Jan 6, 2023

@pcuenca broadly we are targeting a fix or a workaround for this issue in the 2.0 timeline.

@pcuenca

pcuenca commented Jan 6, 2023

Thanks, @kulinseth! Happy to test on any of the nightlies when it makes it there.

@tux-o-matic

@kulinseth, which version of Python is Apple testing the MPS backend with? I'm seeing many MPS-backed ML tasks work on Python 3.9 but nothing newer.

@Birch-san

Birch-san commented Jan 14, 2023

Can you give some examples? Because I've done stable diffusion inference and TI training just fine on Python 3.10, with mainline master branch, and have done inference just fine with kulinseth's master branch. On Ventura 13.1 public beta 4.

@tux-o-matic

@Birch-san, I'm asking because I've seen the same stable diffusion project and different Torch 2.0 nightlies run with MPS on Python 3.9 but throw the error from this issue when running with Python 3.10.
And like many, I'm hoping to see a nightly build of Torch on Python 3.11 for arm64.

@pcuenca

pcuenca commented Mar 8, 2023

@kulinseth My repro above works fine on macOS Ventura 13.3 beta, thanks! However, inference using diffusers with some model architectures still fails (but it works for others). For example, the following snippet fails with the error 'NDArray dimension length > INT_MAX':

from diffusers import StableDiffusionPipeline

sdm = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
).to("mps")

prompt = "A painting of a squirrel eating a burger"
num_samples = 2

images = sdm(prompt, num_images_per_prompt=num_samples, num_inference_steps=20).images
for i, image in enumerate(images):
    image.save(f"squirrel_{i}.png")

As a workaround, enabling attention slicing in the snippet above works fine (sdm.enable_attention_slicing()). Using model stabilityai/stable-diffusion-2-1-base (different attention heads) works fine too.
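
For context, attention slicing avoids materializing the full (batch*heads, tokens, tokens) score tensor at once. A minimal sketch of the idea (a hypothetical helper, not diffusers' actual implementation):

import torch

def sliced_attention(q, k, v, slice_size=8):
    # Compute attention in slices over the batch*heads dimension so the large
    # (tokens x tokens) score matrix only exists for slice_size entries at a time.
    out = torch.empty_like(q)
    scale = q.shape[-1] ** -0.5
    for i in range(0, q.shape[0], slice_size):
        scores = torch.matmul(q[i:i + slice_size], k[i:i + slice_size].transpose(-2, -1)) * scale
        out[i:i + slice_size] = torch.matmul(scores.softmax(dim=-1), v[i:i + slice_size])
    return out

q = torch.rand((32, 4096, 40), device="mps")
k = torch.rand((32, 4096, 40), device="mps")
v = torch.rand((32, 4096, 40), device="mps")
print(sliced_attention(q, k, v).shape)   # torch.Size([32, 4096, 40])

Each slice's score matrix is only (slice_size, 4096, 4096), which stays well under the limit.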

I'll try to isolate a more specific reproduction.

@kulinseth
Collaborator

Thanks @pcuenca, that will help. Meanwhile we will look at the example you provided. There are indeed integer limits on dimension lengths for the underlying buffers currently.

@qqaatw
Collaborator

qqaatw commented Mar 26, 2023

@kulinseth My repro above works fine on macOS Ventura 13.3 beta, thanks! However, inference using diffusers with some model architectures still fails (but it works for others). […]

With the repro, the problem can be narrowed down to:

import torch

device = "mps"
for bsz in range(16):
    if bsz in [4]:  # only bsz == 4 fails.
        continue
    query = torch.randn((bsz, 8, 4096, 40), device=device)
    key = torch.randn((bsz, 8, 4096, 40), device=device)
    value = torch.randn((bsz, 8, 4096, 40), device=device)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, dropout_p=0.0, is_causal=False
    )
    # or simply
    hidden_states = torch.matmul(query, key.transpose(-2, -1))
    # or
    hidden_states = torch.bmm(query.view(-1, 4096, 40),  key.transpose(-2, -1).view(-1, 40, 4096))

Since it only fails with bsz == 4, and the failing point is when runMPSGraph is executed in the bmm MPS implementation, I assume the graph gathering (Placeholder creation) has no issue. I suspect there is an optimized (or not optimized) path for matrix multiplication in the MPS library, triggered by some condition related to the matrix sizes, and that condition has a bug that lets a resulting matrix larger than 2**31 get onto that path.
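
For what it's worth (arithmetic only): bsz == 4 is exactly the case where the score tensor's float32 byte size lands on 2**31, matching the pattern elsewhere in this thread where exactly 2**31 bytes fails while larger sizes pass:

bsz, heads, tokens = 4, 8, 4096
score_bytes = bsz * heads * tokens * tokens * 4   # (4, 8, 4096, 4096) in float32
print(score_bytes == 2**31)                       # True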

@malfet
Contributor

malfet commented Jan 9, 2024

I can not reproduce any of the failures reported above while running on macOS Sonoma, which makes me hopeful some of the issues were addressed. cc: @kulinseth

@kulinseth
Collaborator


That's great @malfet, thanks for confirming. Can we add the above test to our test_mps when we enable the Sonoma runners?

@malfet
Contributor

malfet commented Jan 9, 2024


@kulinseth I already did, accidentally, while working on 64-bit index select (see #116942), and am now working on a PR that raises an exception if one tries to allocate a 4GB+ tensor on Ventura and changes the skip to xfail.

@kulinseth
Collaborator


Awesome, thanks. The PR for 64-bit looks good to me; we can close this issue with that.

@bghira

bghira commented Mar 25, 2024

#116942 is now closed

@rusnov

rusnov commented Dec 17, 2024

Is this issue still relevant? I am getting the following error with M4 Sequoia:
/AppleInternal/Library/BuildRoots/.../Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:850: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: total bytes of NDArray > 2**32'

The minimal code to produce this error:

import torch
device = torch.device("mps")
mask_bool = torch.triu(torch.ones(1024, 1024, device=device), diagonal=1).bool()
attn_scores = torch.rand(48, 25, 1024, 1024, device=device)
attn_scores.masked_fill_(mask_bool, 0)

Created Issue: #143477

@jhavukainen
Collaborator

The initial issue here was first addressed in commit 92f282c, with the supported range further expanded in commit afa313e. Note that it requires the user to be on macOS 15; otherwise the user only gets the error and a recommendation to update. This is because tiling the operation requires an API that was only released in macOS 15.0.

@rusnov the masked_fill_ issue you noted seems separate, although I'm pretty sure the root cause and solution will be similar. I'll continue that thread in the issue you filed. Thanks for bringing this up!

I'll close this issue as addressed. Please reopen if you feel this is inaccurate.
