 )
 from torchvision.transforms.functional_tensor import _parse_pad_padding
 
-from ._meta import (
-    convert_format_bounding_box,
-    get_dimensions_image_tensor,
-    get_spatial_size_image_pil,
-    get_spatial_size_image_tensor,
-)
+from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
 
 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -120,9 +115,9 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: bool = False,
 ) -> torch.Tensor:
-    num_channels, old_height, old_width = get_dimensions_image_tensor(image)
+    shape = image.shape
+    num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
-    extra_dims = image.shape[:-3]
 
     if image.numel() > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)
@@ -134,7 +129,7 @@ def resize_image_tensor(
             antialias=antialias,
         )
 
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
 @torch.jit.unused
@@ -270,8 +265,8 @@ def affine_image_tensor(
     if image.numel() == 0:
         return image
 
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
     image = image.reshape(-1, num_channels, height, width)
 
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -285,7 +280,7 @@ def affine_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
 
     output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
-    return output.reshape(extra_dims + (num_channels, height, width))
+    return output.reshape(shape)
 
 
 @torch.jit.unused
@@ -511,8 +506,8 @@ def rotate_image_tensor(
     fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
 
     center_f = [0.0, 0.0]
     if center is not None:
@@ -538,7 +533,7 @@ def rotate_image_tensor(
     else:
         new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
 
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
 @torch.jit.unused
@@ -675,8 +670,8 @@ def _pad_with_scalar_fill(
     fill: Union[int, float, None],
     padding_mode: str = "constant",
 ) -> torch.Tensor:
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
 
     if image.numel() > 0:
         image = _FT.pad(
@@ -688,7 +683,7 @@ def _pad_with_scalar_fill(
         new_height = height + top + bottom
         new_width = width + left + right
 
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
 # TODO: This should be removed once pytorch pad supports non-scalar padding values
@@ -1130,7 +1125,8 @@ def elastic(
 
 def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
     if isinstance(output_size, numbers.Number):
-        return [int(output_size), int(output_size)]
+        s = int(output_size)
+        return [s, s]
     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
         return [output_size[0], output_size[0]]
     else:
@@ -1156,18 +1152,21 @@ def _center_crop_compute_crop_anchor(
 
 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_height, image_width = get_spatial_size_image_tensor(image)
+    shape = image.shape
+    if image.numel() == 0:
+        return image.reshape(shape[:-2] + (crop_height, crop_width))
+    image_height, image_width = shape[-2:]
 
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        image = pad_image_tensor(image, padding_ltrb, fill=0)
+        image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
 
-        image_height, image_width = get_spatial_size_image_tensor(image)
+        image_height, image_width = image.shape[-2:]
         if crop_width == image_width and crop_height == image_height:
             return image
 
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
-    return crop_image_tensor(image, crop_top, crop_left, crop_height, crop_width)
+    return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
@@ -1332,7 +1331,7 @@ def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    image_height, image_width = get_spatial_size_image_tensor(image)
+    image_height, image_width = image.shape[-2:]
 
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
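The hunks above all follow the same shape-handling pattern: read `image.shape` once, flatten any leading batch dimensions into a single one before running the 4D kernel, then restore them around the new spatial size with `shape[:-3] + (num_channels, new_height, new_width)`. Reading `.shape` directly also drops the per-call `get_dimensions_image_tensor` / `get_spatial_size_image_tensor` helpers that the diff removes. The sketch below is not part of this diff; it is a minimal standalone illustration of that pattern, where the helper name `resize_hw_batched` is hypothetical and a plain `torch.nn.functional.interpolate` call stands in for the internal `_FT.resize` kernel.

```python
import torch
import torch.nn.functional as F


def resize_hw_batched(image: torch.Tensor, new_height: int, new_width: int) -> torch.Tensor:
    # Hypothetical helper (not torchvision API) mirroring the reshape pattern above.
    shape = image.shape  # read the shape once
    num_channels, old_height, old_width = shape[-3:]

    if image.numel() > 0:
        # Collapse every leading dimension (batch, frames, ...) into one.
        image = image.reshape(-1, num_channels, old_height, old_width)
        image = F.interpolate(image, size=[new_height, new_width], mode="bilinear", align_corners=False)

    # Put the original leading dimensions back around the new spatial size.
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


# Example: a (2, 4) stack of 3-channel 32x32 float images resized to 24x24.
batch = torch.rand(2, 4, 3, 32, 32)
print(resize_hw_batched(batch, 24, 24).shape)  # torch.Size([2, 4, 3, 24, 24])
```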