Skip to content

Commit 24890d7

Browse files
datumboxpmeier
andauthored
Fix issues with get_image_size() (#6581)
* Fix bug on `get_image_size()` and move it to deprecated. Introduce generic named spatial/channel equivalents. * Update tests and fix mypy issues. * Remove the use of get_image_size from ElasticTransform. * Fix linter * Apply suggestions from code review. * Update torchvision/prototype/transforms/functional/_deprecated.py Co-authored-by: Philip Meier <[email protected]> * Further changes from code review. * Fix linter Co-authored-by: Philip Meier <[email protected]>
1 parent 6b2e0a0 commit 24890d7

File tree

5 files changed

+25
-9
lines changed

5 files changed

+25
-9
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,8 @@ def erase_image_tensor():
531531
and name
532532
not in {
533533
"to_image_tensor",
534+
"get_num_channels",
535+
"get_spatial_size",
534536
"get_image_num_channels",
535537
"get_image_size",
536538
}

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
convert_color_space,
1010
get_dimensions,
1111
get_image_num_channels,
12-
get_image_size,
12+
get_num_channels,
13+
get_spatial_size,
1314
) # usort: skip
1415

1516
from ._augment import erase, erase_image_pil, erase_image_tensor
@@ -125,4 +126,4 @@
125126
to_pil_image,
126127
)
127128

128-
from ._deprecated import rgb_to_grayscale, to_grayscale # usort: skip
129+
from ._deprecated import get_image_size, rgb_to_grayscale, to_grayscale, to_tensor # usort: skip

torchvision/prototype/transforms/functional/_deprecated.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Any, Union
2+
from typing import Any, List, Union
33

44
import PIL.Image
55
import torch
@@ -50,3 +50,11 @@ def to_tensor(inpt: Any) -> torch.Tensor:
5050
"Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`."
5151
)
5252
return _F.to_tensor(inpt)
53+
54+
55+
def get_image_size(inpt: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
56+
warnings.warn(
57+
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
58+
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
59+
)
60+
return _F.get_image_size(inpt)

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,19 @@ def get_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image])
3434
return list(get_chw(image))
3535

3636

37-
def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
37+
def get_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
3838
num_channels, *_ = get_chw(image)
3939
return num_channels
4040

4141

42-
def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
43-
_, *image_size = get_chw(image)
44-
return image_size
42+
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
43+
# deprecating the old names.
44+
get_image_num_channels = get_num_channels
45+
46+
47+
def get_spatial_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
48+
_, *size = get_chw(image)
49+
return size
4550

4651

4752
def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:

torchvision/transforms/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,8 +2162,8 @@ def forward(self, tensor: Tensor) -> Tensor:
21622162
Returns:
21632163
PIL Image or Tensor: Transformed image.
21642164
"""
2165-
size = F.get_image_size(tensor)[::-1]
2166-
displacement = self.get_params(self.alpha, self.sigma, size)
2165+
_, height, width = F.get_dimensions(tensor)
2166+
displacement = self.get_params(self.alpha, self.sigma, [height, width])
21672167
return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)
21682168

21692169
def __repr__(self):

0 commit comments

Comments
 (0)