
Commit d95fbaf

datumbox and vfdev-5 authored
[prototype] Optimize Center Crop performance (#6880)
* Reducing unnecessary method calls
* Optimize pad branch.
* Remove unnecessary call to crop_image_tensor
* Fix linter

Co-authored-by: vfdev <[email protected]>
1 parent 72c5952 commit d95fbaf

2 files changed: +29 -32 lines changed


torchvision/prototype/transforms/functional/_color.py

Lines changed: 6 additions & 8 deletions
@@ -2,7 +2,7 @@
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 
-from ._meta import _rgb_to_gray, convert_dtype_image_tensor, get_dimensions_image_tensor, get_num_channels_image_tensor
+from ._meta import _rgb_to_gray, convert_dtype_image_tensor
 
 
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -45,7 +45,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
     if saturation_factor < 0:
         raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
 
-    c = get_num_channels_image_tensor(image)
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
 
@@ -75,7 +75,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
     if contrast_factor < 0:
         raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
 
-    c = get_num_channels_image_tensor(image)
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
@@ -101,7 +101,7 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> feat
 
 
 def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
-    num_channels, height, width = get_dimensions_image_tensor(image)
+    num_channels, height, width = image.shape[-3:]
     if num_channels not in (1, 3):
         raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
 
@@ -210,8 +210,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
     if not (-0.5 <= hue_factor <= 0.5):
         raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
 
-    c = get_num_channels_image_tensor(image)
-
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
 
@@ -342,8 +341,7 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
 
 
 def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    c = get_num_channels_image_tensor(image)
-
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
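
Every hunk above applies the same pattern: for a (..., C, H, W) image tensor the channel count is simply the third-from-last entry of .shape, so reading it directly avoids a helper call to get_num_channels_image_tensor in each transform. A minimal standalone sketch of the new code path (the tensor below is made up for illustration, not taken from the commit):

import torch

# Hypothetical batch of RGB images with shape (*, C, H, W); any leading dims work.
image = torch.rand(4, 3, 32, 32)

# Old path: num_channels = get_num_channels_image_tensor(image)
# New path: the channel count is read straight off the shape.
num_channels = image.shape[-3]
if num_channels not in (1, 3):
    raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")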

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 23 additions & 24 deletions
@@ -16,12 +16,7 @@
 )
 from torchvision.transforms.functional_tensor import _parse_pad_padding
 
-from ._meta import (
-    convert_format_bounding_box,
-    get_dimensions_image_tensor,
-    get_spatial_size_image_pil,
-    get_spatial_size_image_tensor,
-)
+from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
 
 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -120,9 +115,9 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: bool = False,
 ) -> torch.Tensor:
-    num_channels, old_height, old_width = get_dimensions_image_tensor(image)
+    shape = image.shape
+    num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
-    extra_dims = image.shape[:-3]
 
     if image.numel() > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)
@@ -134,7 +129,7 @@ def resize_image_tensor(
             antialias=antialias,
         )
 
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
 @torch.jit.unused
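
resize_image_tensor now captures image.shape once and reuses it both to flatten the leading dimensions and to restore them on return, instead of slicing image.shape a second time for extra_dims; affine_image_tensor, rotate_image_tensor and _pad_with_scalar_fill below get the same treatment. A rough standalone sketch of that flatten-process-restore pattern (apply_per_image and the flip are stand-ins, not code from the commit):

import torch

def apply_per_image(image: torch.Tensor) -> torch.Tensor:
    # Cache the shape once; it is reused to flatten the leading dims and to restore them.
    shape = image.shape
    num_channels, height, width = shape[-3:]
    batched = image.reshape(-1, num_channels, height, width)
    batched = batched.flip(-1)  # placeholder for the real size-preserving kernel
    return batched.reshape(shape)  # use shape[:-3] + (C, new_H, new_W) when the size changes

out = apply_per_image(torch.rand(2, 5, 3, 8, 8))
print(out.shape)  # torch.Size([2, 5, 3, 8, 8]): the leading dims survive the round trip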
@@ -270,8 +265,8 @@ def affine_image_tensor(
     if image.numel() == 0:
         return image
 
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
     image = image.reshape(-1, num_channels, height, width)
 
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -285,7 +280,7 @@ def affine_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
 
     output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
-    return output.reshape(extra_dims + (num_channels, height, width))
+    return output.reshape(shape)
 
 
 @torch.jit.unused
@@ -511,8 +506,8 @@ def rotate_image_tensor(
     fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
 
     center_f = [0.0, 0.0]
     if center is not None:
@@ -538,7 +533,7 @@ def rotate_image_tensor(
     else:
         new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
 
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
 @torch.jit.unused
@@ -675,8 +670,8 @@ def _pad_with_scalar_fill(
     fill: Union[int, float, None],
     padding_mode: str = "constant",
 ) -> torch.Tensor:
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
 
     if image.numel() > 0:
         image = _FT.pad(
@@ -688,7 +683,7 @@ def _pad_with_scalar_fill(
         new_height = height + top + bottom
         new_width = width + left + right
 
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
 # TODO: This should be removed once pytorch pad supports non-scalar padding values
@@ -1130,7 +1125,8 @@ def elastic(
 
 def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
     if isinstance(output_size, numbers.Number):
-        return [int(output_size), int(output_size)]
+        s = int(output_size)
+        return [s, s]
     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
         return [output_size[0], output_size[0]]
     else:
@@ -1156,18 +1152,21 @@ def _center_crop_compute_crop_anchor(
 
 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_height, image_width = get_spatial_size_image_tensor(image)
+    shape = image.shape
+    if image.numel() == 0:
+        return image.reshape(shape[:-2] + (crop_height, crop_width))
+    image_height, image_width = shape[-2:]
 
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        image = pad_image_tensor(image, padding_ltrb, fill=0)
+        image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
 
-        image_height, image_width = get_spatial_size_image_tensor(image)
+        image_height, image_width = image.shape[-2:]
         if crop_width == image_width and crop_height == image_height:
             return image
 
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
-    return crop_image_tensor(image, crop_top, crop_left, crop_height, crop_width)
+    return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
 
 
 @torch.jit.unused
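
The center-crop path above drops the pad_image_tensor and crop_image_tensor wrappers: it short-circuits empty inputs, pads with the plain torch pad only when the requested crop is larger than the image, and crops by slicing the last two dimensions. A rough standalone approximation of that flow, using torch.nn.functional.pad directly because _FT.torch_pad is internal to torchvision; the padding math here is simplified and may differ slightly from _center_crop_compute_padding:

import torch
import torch.nn.functional as F

def center_crop_sketch(image: torch.Tensor, crop_height: int, crop_width: int) -> torch.Tensor:
    image_height, image_width = image.shape[-2:]

    if crop_height > image_height or crop_width > image_width:
        # Zero-pad so the image is at least as large as the crop.
        # F.pad takes padding for the last dims first: (left, right, top, bottom).
        pad_w = max(crop_width - image_width, 0)
        pad_h = max(crop_height - image_height, 0)
        image = F.pad(image, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=0.0)
        image_height, image_width = image.shape[-2:]

    # Crop by slicing instead of calling a separate crop helper.
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return image[..., crop_top : crop_top + crop_height, crop_left : crop_left + crop_width]

print(center_crop_sketch(torch.rand(3, 20, 30), 8, 8).shape)    # torch.Size([3, 8, 8])
print(center_crop_sketch(torch.rand(3, 20, 30), 32, 32).shape)  # torch.Size([3, 32, 32])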
@@ -1332,7 +1331,7 @@ def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    image_height, image_width = get_spatial_size_image_tensor(image)
+    image_height, image_width = image.shape[-2:]
 
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
