[MPS] MPSNDArray error: product of dimension sizes > 2**31 #84039
Comments
I had the same issue, but with an M1 Mac + TensorFlow (for Mac); no PyTorch, but I am using numpy. Not sure if that's helpful. |
@Karric Thanks for your reply. It seems like the error is something about Metal then. Could you provide more information about the situation with M1 Mac + TensorFlow? Thanks. |
I have the same issue on my Mac Studio Ultra with 128G :( |
We are having the same problem in 🤗 diffusers, but I fail to see where the offending dimensions are being generated or used. The error arises when attempting to run the
It looks to me that those sizes shouldn't trigger any internal Metal limitation about 32-bit sizes being exceeded. Is there any hint as to what might be causing this? Thank you! |
In fact, as @FahimF reports, the crash does not occur for larger batches such as those that would result in latent shapes of |
I have a repro (from stable-diffusion's attention forward-pass):

```python
from torch import einsum, ones

# crashes with "product of dimension sizes > 2**31"
# this is equivalent to invoking stable-diffusion with --n_samples 2
einsum('b i d, b j d -> b i j', ones(32, 4096, 40, device='mps'), ones(32, 4096, 40, device='mps')).shape

# doesn't crash, even though it's bigger
# this is equivalent to invoking stable-diffusion with --n_samples 3
einsum('b i d, b j d -> b i j', ones(48, 4096, 40, device='mps'), ones(48, 4096, 40, device='mps')).shape
```

Perhaps related to this? Perhaps at the larger size of 48, we hit a tensor size which persuades … I'm not sure whether this is necessarily the same line of stable-diffusion code on which @junukwon7's run is crashing (i.e. due to image dimensions), but if we're lucky it's the same mechanism. |
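One back-of-the-envelope observation (my own arithmetic, not from the thread): the output of the crashing einsum has shape (32, 4096, 4096), which as float32 is exactly 2**31 bytes, while the 48-batch output is even larger and yet is reported not to crash, so a plain byte-size threshold alone does not explain the asymmetry.

```python
# Hedged size check of the two einsum outputs above ('b i d, b j d -> b i j'
# with inputs of shape (b, 4096, 40) produces a (b, 4096, 4096) result).
def out_bytes(b, n, itemsize=4):
    return b * n * n * itemsize

print(out_bytes(32, 4096))            # 2147483648, i.e. exactly 2**31 bytes (crashes)
print(out_bytes(48, 4096))            # 3221225472, larger, yet reported not to crash
print(out_bytes(32, 4096) == 2**31)   # True
```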
Thanks for your repro!

```python
from torch import einsum, ones
import argparse

parser = argparse.ArgumentParser(description='mpsndarray test')
parser.add_argument('--n_samples', type=int, default=2)
args = parser.parse_args()
n_samples = args.n_samples

einsum('b i d, b j d -> b i j', ones(16 * n_samples, 4096, 40, device='mps'), ones(16 * n_samples, 4096, 40, device='mps')).shape
print(n_samples, 'passed')
```

I made this into a python file and ran `for i in {1..80}; do python3 mpsndarray.py --n_samples ${i}; done`. The result was:

M1 Max MBP with 64GB RAM |
Regarding the reproducibility code, @patil-suraj replaced … In fact, following up on @Birch-san's example, this behaves exactly like they describe:

```python
import torch

# Crashes
t1 = torch.rand((32, 4096, 40))
t2 = torch.rand((32, 4096, 40))
torch.matmul(t1.to("mps"), t2.to("mps").transpose(1, 2)).shape

# Doesn't crash, even though it's bigger
t1 = torch.rand((48, 4096, 40))
t2 = torch.rand((48, 4096, 40))
torch.matmul(t1.to("mps"), t2.to("mps").transpose(1, 2)).shape
```

In addition, this is something different but very weird:

```python
from torch import einsum

t1 = torch.rand((48, 4096, 40))
t2 = torch.rand((48, 4096, 40))
x_mps = einsum('b i d, b j d -> b i j', t1.to('mps'), t2.to('mps'))
x_cpu = einsum('b i d, b j d -> b i j', t1, t2)
print((x_mps.to("cpu") - x_cpu).abs().max())
# tensor(9.4567) !?
```

As you can see, the output seems to be wrong. However, if we do other operations first, the result is correct:

```python
t1 = torch.rand((48, 4096, 40))
t2 = torch.rand((48, 4096, 40))
x_mm_mps = torch.matmul(t1.to("mps"), t2.to("mps").transpose(1, 2))
x_mm_cpu = torch.matmul(t1, t2.transpose(1, 2))
print((x_mm_mps.to("cpu") - x_mm_cpu).abs().max())
# tensor(0.)

x_mps = einsum('b i d, b j d -> b i j', t1.to('mps'), t2.to('mps'))
x_cpu = einsum('b i d, b j d -> b i j', t1, t2)
print((x_mps.to("cpu") - x_cpu).abs().max())
# tensor(0.)
```

Is there any guidance or hint as to what could be going on here?

TL;DR: … |
Thank you for the report @junukwon7 and for the repro code @pcuenca, @Birch-san. This is an issue on the MPS side - we are working on a fix. |
I have an even simpler 1-line reproducer:
|
And I can reproduce the matmul crash using the following pure-MPS script:

```swift
import MetalPerformanceShadersGraph

let graph = MPSGraph()
let x = graph.constant(1, shape: [32, 4096, 40], dataType: .float32)
let y = graph.constant(1, shape: [32, 40, 4096], dataType: .float32)
let z = graph.matrixMultiplication(primary: x, secondary: y, name: nil)
let device = MTLCreateSystemDefaultDevice()!
let buf = device.makeBuffer(length: 16384)!
let td = MPSGraphTensorData(buf, shape: [64, 64], dataType: .int32)
let cmdBuf = MPSCommandBuffer(from: device.makeCommandQueue()!)
graph.encode(to: cmdBuf, feeds: [:], targetOperations: nil, resultsDictionary: [z: td], executionDescriptor: nil)
cmdBuf.commit()
```
|
Thanks @malfet. As Denis mentioned above, this is an MPS-side issue, so I expect it can be reproduced by a simple test case outside of Torch as well. There were a couple of issues: one was that the heap logic which allocated the size for the MPSNDArray (the basic struct for allocating a Tensor in MPS) was incorrectly erroring out when comparing the dimension product size against the allocated buffer size. And there was a restriction on how large a buffer could be allocated. Both conditions have been updated and fixed; the fixes will be available in upcoming updates. Having said this, we do have a current limit on the NDArray which can be created, which is 2^32. So, for instance, this size will work: (63, 4096, 4096) with 32-bit values, but not with 64-bit values. |
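For reference, a quick check of that example (my own arithmetic; interpreting the 2^32 limit as a byte count is my assumption, based on the fact that the dtype matters in the statement above):

```python
# Hedged sanity check of the (63, 4096, 4096) example above, assuming the
# 2**32 limit applies to the buffer size in bytes (my interpretation, not
# confirmed in the thread).
numel = 63 * 4096 * 4096      # 1,056,964,608 elements
print(numel * 4)              # 4,227,858,432 bytes as float32: just under 2**32
print(numel * 8)              # 8,455,716,864 bytes as float64: well over 2**32
print(2**32)                  # 4,294,967,296
```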
@kulinseth Thanks for your response. |
@junukwon7 I don't know the exact details, but I assume using 32-bit indexes results in faster kernels, as one can perform twice as many 32-bit operations per SIMD instruction compared to 64-bit ones; consider the AVX2 instruction set as an example. And as the majority of tensors have fewer than 4 billion elements, it makes perfect sense to use 32-bit indices while performing operations on them. |
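A minimal numeric sketch of that lane-count argument (my own illustration, assuming 256-bit SIMD registers as in AVX2; Apple GPU specifics are not covered here):

```python
# Lane counts per SIMD register width, and the range a 32-bit index can address.
register_bits = 256
print(register_bits // 32)   # 8 index operations per instruction with 32-bit indices
print(register_bits // 64)   # 4 per instruction with 64-bit indices
print(2**32)                 # ~4.29e9: largest element count addressable with 32-bit indices
```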
Hi. Is the fix for this in the nightlies? I'm still seeing this error under some strange circumstances, e.g. in the following code; note the first tensor is actually larger than the second one, which fails:

gives:

using this version of pytorch:

It seems to be that exact number of values that fails: with `ix = torch.zeros(1, 256, 512, 1025, device=mps_device)`, `torch.nn.functional.interpolate` works. |
@Vargol, the fix is not in the PyTorch nightlies. The required fix is needed in the OS-side MPS library. Can you please try with the latest Ventura beta build? Locally it seems to work and outputs:
|
No plans on installing a beta OS, especially one where some of the tools I use are not fully working yet :-) |
Sure, you can wait till official release and try it out. |
Update after testing PyTorch 1.13.0 from …

```python
import torch

t1 = torch.ones((32, 4096, 4096))
t2 = torch.ones((32, 4096, 1))
torch.matmul(t1.to("mps"), t2.to("mps")).shape
```

Output:

But this Swift script (just changed the dimensions in @malfet's example) doesn't:

```swift
import MetalPerformanceShadersGraph

let graph = MPSGraph()
let x = graph.constant(1, shape: [32, 4096, 4096], dataType: .float32)
let y = graph.constant(1, shape: [32, 4096, 1], dataType: .float32)
let z = graph.matrixMultiplication(primary: x, secondary: y, name: nil)
let device = MTLCreateSystemDefaultDevice()!
let buf = device.makeBuffer(length: 16384)!
let td = MPSGraphTensorData(buf, shape: [64, 64], dataType: .int32)
let cmdBuf = MPSCommandBuffer(from: device.makeCommandQueue()!)
graph.encode(to: cmdBuf, feeds: [:], targetOperations: nil, resultsDictionary: [z: td], executionDescriptor: nil)
cmdBuf.commit()
```

```python
from diffusers import StableDiffusionPipeline

sdm = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
).to("mps")

prompt = "A painting of a squirrel eating a burger"
num_samples = 2

images = sdm(prompt, num_images_per_prompt=num_samples).images
for i, image in enumerate(images):
    image.save(f"squirrel_{i}.png")
```
|
I'm having the same issue with the VToonify model from https://github.com/williamyang1991/VToonify |
Gently pinging @kulinseth. Do we have confirmation whether this lies on the |
@pcuenca Sorry for the delay in response. The underlying issue which you linked in your comment comes during creation of the NDArray buffer (in the MPS framework), which is shared between PyTorch and MPSGraph. I will take a look at how we can improve this layer. I will update here with more details soon. |
Hi @kulinseth, sorry for the ping :) It isn't urgent, but a broad time-frame estimation would be awesome here. Thanks a lot for your work! |
@pcuenca broadly we are targeting a fix or a workaround for this issue in 2.0 timeline |
Thanks, @kulinseth! Happy to test on any of the nightlies when it makes it there. |
@kulinseth, which version of Python is Apple testing the MPS backend with? I'm seeing many MPS backed ML tasks work on Python 3.9 but nothing newer. |
Can you give some examples? Because I've done stable diffusion inference and TI training just fine on Python 3.10, with mainline master branch, and have done inference just fine with kulinseth's master branch. On Ventura 13.1 public beta 4. |
@Birch-san, I'm asking because I've seen the same stable diffusion project and different Torch 2.0 nightlies run with MPS on Python 3.9 but throw the error from this issue when running with Python 3.10. |
@kulinseth My repro above works fine on macOS Ventura 13.3 beta, thanks! However, inference using the following still fails:

```python
from diffusers import StableDiffusionPipeline

sdm = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
).to("mps")

prompt = "A painting of a squirrel eating a burger"
num_samples = 2

images = sdm(prompt, num_images_per_prompt=num_samples, num_inference_steps=20).images
for i, image in enumerate(images):
    image.save(f"squirrel_{i}.png")
```

As a workaround, enabling attention slicing in the snippet above works fine (…). I'll try to isolate a more specific reproduction. |
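For readers hitting the same thing: the exact call used above is elided, but attention slicing is a standard diffusers pipeline switch, so the workaround presumably looks like this minimal sketch (an assumption on my part, not confirmed by the commenter):

```python
# Hedged sketch of the attention-slicing workaround referenced above.
# enable_attention_slicing() splits the attention matmul into smaller chunks,
# keeping each intermediate tensor below the problematic size.
from diffusers import StableDiffusionPipeline

sdm = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
).to("mps")
sdm.enable_attention_slicing()

images = sdm("A painting of a squirrel eating a burger", num_images_per_prompt=2).images
```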
Thanks @pcuenca, that will help. Meanwhile we will look at the example you provided. There are indeed integer limits on dimension lengths for the underlying buffers currently. |
With the repro, the problem can be narrowed down to:

```python
import torch

device = "mps"
for bsz in range(16):
    if bsz in [4]:  # only bsz == 4 fails.
        continue
    query = torch.randn((bsz, 8, 4096, 40), device=device)
    key = torch.randn((bsz, 8, 4096, 40), device=device)
    value = torch.randn((bsz, 8, 4096, 40), device=device)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, dropout_p=0.0, is_causal=False
    )
    # or simply
    hidden_states = torch.matmul(query, key.transpose(-2, -1))
    # or
    hidden_states = torch.bmm(query.view(-1, 4096, 40), key.transpose(-2, -1).view(-1, 40, 4096))
```

Since it only fails with bsz==4 and the failing point is when … |
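A possibly relevant observation (my own arithmetic, not from the thread): with bsz == 4 and 8 heads, the attention-score tensor has 32 × 4096 × 4096 entries, which as float32 is exactly 2**31 bytes, the same boundary as the earlier 32-batch einsum case.

```python
# Hedged arithmetic: size of the attention-score tensor for the failing
# bsz == 4 case (4 samples x 8 heads collapse into a batch of 32).
bsz, heads, tokens = 4, 8, 4096
score_elems = bsz * heads * tokens * tokens   # 536,870,912 elements
print(score_elems * 4)                        # 2,147,483,648 bytes as float32
print(score_elems * 4 == 2**31)               # True
```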
I cannot reproduce any of the failures reported above while running on MacOS Sonoma, which makes me hopeful some of the issues were addressed. cc: @kulinseth |
That's great @malfet, thanks for confirming. Can we add the above test to our test_mps when we enable the Sonoma runners? |
@kulinseth I already did that accidentally while working on 64-bit index select (see #116942), and I'm now working on a PR that raises an exception if one tries to allocate a 4GB+ tensor on Ventura and changes the skip to xfail. |
Awesome, thanks. The PR for 64-bit looks good to me; we can close this issue with that. |
#116942 is now closed |
Is this issue still relevant? I am getting the following error with M4 Sequoia:

The minimalistic code to produce this error:

```python
import torch

device = torch.device("mps")
mask_bool = torch.triu(torch.ones(1024, 1024, device=device), diagonal=1).bool()
attn_scores = torch.rand(48, 25, 1024, 1024, device=device)
attn_scores.masked_fill_(mask_bool, 0)
```

Created Issue: #143477 |
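For scale (my own arithmetic, and only a hunch that it is related): the attn_scores tensor in that snippet is roughly 5 GB as float32, which is beyond even the 2^32-byte range discussed earlier in the thread.

```python
# Hedged size check of attn_scores from the snippet above.
numel = 48 * 25 * 1024 * 1024     # 1,258,291,200 elements
print(numel * 4)                  # 5,033,164,800 bytes as float32
print(numel * 4 > 2**32)          # True: larger than 4,294,967,296
```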
This initial issue has been addressed, first in commit 92f282c, with the range further expanded in commit afa313e. Note that it requires the user to be on macOS 15; otherwise the user only gets the error and a recommendation to update. This is because tiling the operation requires an API that was only released in macOS 15.0. @rusnov, the masked_fill_ issue you noted seems separate, although I'm pretty sure the root cause and solution will be similar. I'll continue that thread in the issue you filed. Thanks for bringing this up! I'll close this issue as addressed. Please reopen if you feel this is inaccurate. |
🐛 Describe the bug

Full error message (no traceback):

How to reproduce

- `python scripts/txt2img.py --prompt "a horse" --plms --n_samples 1 --n_rows 1 --n_iter 1` runs well.
- Add the `--W 1024 --H 1024` flag, which means width and height respectively, then it'll return the error.

I'm finding a way to reproduce it without installing the whole procedure, so I'll update the procedure soon.

Edit: repro by @Birch-san

It fails when n_samples is 2 or over 7, which looks pretty weird.

About VRAM?

As you would all expect, the error seems to be something about VRAM. However, questions remain:

- The limit is `INT_MAX` (2**31).
- The error doesn't occur at `--W 512 --H 512` or lower resolution.
- Unlike errors like `CUDA out of memory`, this error isn't about the real memory limit.
- If the error was due to lack of VRAM, the code above (`--W 1024 --H 1024`) should run on an M1 Max 64GB, since `--W 512 --H 512` runs well on my M1 8GB MacBook. Also, the limit 2**31 is a fixed number, which would not change with the current memory usage.

So, my expectation is that something is being computed in 32-bit, which shouldn't be.
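A rough sanity check of that expectation (my own arithmetic; it assumes the attention-score shapes scale with resolution the way the (16 * n_samples, 4096, 40) repro above suggests, i.e. 4096 tokens corresponding to 512x512 after 8x downsampling):

```python
# Hedged estimate of attention-score element counts at different resolutions.
# batch_per_sample = 16 and the /8 downsampling factor are inferred from the
# repro in this thread, not confirmed for every layer of the model.
def attn_score_elems(width, height, n_samples=1, batch_per_sample=16):
    tokens = (width // 8) * (height // 8)
    return batch_per_sample * n_samples * tokens * tokens

print(attn_score_elems(512, 512))    # 268,435,456 elements: well under 2**31
print(attn_score_elems(1024, 1024))  # 4,294,967,296 elements: past 2**31 even at n_samples=1
print(2**31)                         # 2,147,483,648
```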
This might not be torch's problem - maybe (surely) Metal's.
However, all help will be gratefully accepted.
Thanks.
Versions
PyTorch version: 1.13.0.dev20220824
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 12.5.1 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: version 3.24.1
Libc version: N/A
Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:14) [Clang 12.0.1 ] (64-bit runtime)
Python platform: macOS-12.5.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.2
[pip3] pytorch-lightning==1.7.2
[pip3] torch==1.13.0.dev20220824
[pip3] torch-fidelity==0.3.0
[pip3] torchaudio==0.13.0.dev20220824
[pip3] torchmetrics==0.9.3
[pip3] torchvision==0.14.0.dev20220824
[conda] numpy 1.23.2 py38h579d673_0 conda-forge
[conda] pytorch 1.13.0.dev20220824 py3.8_0 pytorch-nightly
[conda] pytorch-lightning 1.7.2 pypi_0 pypi
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torchaudio 0.13.0.dev20220824 py38_cpu pytorch-nightly
[conda] torchmetrics 0.9.3 pypi_0 pypi
[conda] torchvision 0.14.0.dev20220824 py38_cpu pytorch-nightly
cc @kulinseth @albanD