Skip to content

Commit 0ed3090

Browse files
authored
Move non-NF4 tensor to device prior to quantization on copy (#737)
1 parent 68e4643 commit 0ed3090

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchao/dtypes/nf4tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def copy_(func, *args, **kwargs):
339339
# Convert Non NF4Tensor into NF4 for copy in
340340
if not isinstance(copy_in, NF4Tensor):
341341
copy_in_nf4 = NF4Tensor.from_tensor(
342-
copy_in, original.block_size, original.scaler_block_size
342+
copy_in.to(original.device), original.block_size, original.scaler_block_size
343343
)
344344
return original.copy_(copy_in_nf4)
345345

0 commit comments

Comments
 (0)