@@ -183,8 +183,12 @@ def clamp_bounding_box(
183
183
return convert_format_bounding_box (xyxy_boxes , BoundingBoxFormat .XYXY , format )
184
184
185
185
186
+ def _split_alpha (image : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
187
+ return torch .tensor_split (image , indices = (- 1 ,), dim = - 3 )
188
+
189
+
186
190
def _strip_alpha (image : torch .Tensor ) -> torch .Tensor :
187
- image , alpha = torch . tensor_split (image , indices = ( - 1 ,), dim = - 3 )
191
+ image , alpha = _split_alpha (image )
188
192
if not torch .all (alpha == _FT ._max_value (alpha .dtype )):
189
193
raise RuntimeError (
190
194
"Stripping the alpha channel if it contains values other than the max value is not supported."
@@ -233,7 +237,7 @@ def convert_color_space_image_tensor(
233
237
elif old_color_space == ColorSpace .GRAY_ALPHA and new_color_space == ColorSpace .RGB :
234
238
return _gray_to_rgb (_strip_alpha (image ))
235
239
elif old_color_space == ColorSpace .GRAY_ALPHA and new_color_space == ColorSpace .RGB_ALPHA :
236
- image , alpha = torch . tensor_split (image , indices = ( - 1 ,), dim = - 3 )
240
+ image , alpha = _split_alpha (image )
237
241
return _add_alpha (_gray_to_rgb (image ), alpha )
238
242
elif old_color_space == ColorSpace .RGB and new_color_space == ColorSpace .GRAY :
239
243
return _rgb_to_gray (image )
@@ -244,7 +248,7 @@ def convert_color_space_image_tensor(
244
248
elif old_color_space == ColorSpace .RGB_ALPHA and new_color_space == ColorSpace .GRAY :
245
249
return _rgb_to_gray (_strip_alpha (image ))
246
250
elif old_color_space == ColorSpace .RGB_ALPHA and new_color_space == ColorSpace .GRAY_ALPHA :
247
- image , alpha = torch . tensor_split (image , indices = ( - 1 ,), dim = - 3 )
251
+ image , alpha = _split_alpha (image )
248
252
return _add_alpha (_rgb_to_gray (image ), alpha )
249
253
elif old_color_space == ColorSpace .RGB_ALPHA and new_color_space == ColorSpace .RGB :
250
254
return _strip_alpha (image )
0 commit comments