@@ -183,12 +183,8 @@ def clamp_bounding_box(
183
183
return convert_format_bounding_box (xyxy_boxes , BoundingBoxFormat .XYXY , format )
184
184
185
185
186
- def _split_alpha (image : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
187
- return image [..., :- 1 , :, :], image [..., - 1 :, :, :]
188
-
189
-
190
186
def _strip_alpha (image : torch .Tensor ) -> torch .Tensor :
191
- image , alpha = _split_alpha (image )
187
+ image , alpha = torch . tensor_split (image , indices = ( - 1 ,), dim = - 3 )
192
188
if not torch .all (alpha == _FT ._max_value (alpha .dtype )):
193
189
raise RuntimeError (
194
190
"Stripping the alpha channel if it contains values other than the max value is not supported."
@@ -237,7 +233,7 @@ def convert_color_space_image_tensor(
237
233
elif old_color_space == ColorSpace .GRAY_ALPHA and new_color_space == ColorSpace .RGB :
238
234
return _gray_to_rgb (_strip_alpha (image ))
239
235
elif old_color_space == ColorSpace .GRAY_ALPHA and new_color_space == ColorSpace .RGB_ALPHA :
240
- image , alpha = _split_alpha (image )
236
+ image , alpha = torch . tensor_split (image , indices = ( - 1 ,), dim = - 3 )
241
237
return _add_alpha (_gray_to_rgb (image ), alpha )
242
238
elif old_color_space == ColorSpace .RGB and new_color_space == ColorSpace .GRAY :
243
239
return _rgb_to_gray (image )
@@ -248,7 +244,7 @@ def convert_color_space_image_tensor(
248
244
elif old_color_space == ColorSpace .RGB_ALPHA and new_color_space == ColorSpace .GRAY :
249
245
return _rgb_to_gray (_strip_alpha (image ))
250
246
elif old_color_space == ColorSpace .RGB_ALPHA and new_color_space == ColorSpace .GRAY_ALPHA :
251
- image , alpha = _split_alpha (image )
247
+ image , alpha = torch . tensor_split (image , indices = ( - 1 ,), dim = - 3 )
252
248
return _add_alpha (_rgb_to_gray (image ), alpha )
253
249
elif old_color_space == ColorSpace .RGB_ALPHA and new_color_space == ColorSpace .RGB :
254
250
return _strip_alpha (image )
0 commit comments