MultiLabelBinarizer.fit_transform: keep the dtype of the CSR indices

Remapping yt.indices through the `inverse` array returned by np.unique
could change the dtype of yt.indices, so that it no longer matched
yt.indptr. The result is now cast back to the dtype yt.indices already
had, and the sparse-output tests assert that indices and indptr agree.

np.asarray(..., dtype=...) is used rather than np.array(..., copy=False):
since NumPy 2.0, copy=False raises ValueError whenever a copy (such as a
dtype cast) cannot be avoided, while asarray copies only when required on
every NumPy version.

NOTE(review): this patch arrived with all newlines collapsed. Line breaks
were restored from the hunk line counts; leading indentation and the
position of the blank context line in label.py are reconstructed, so
confirm them against the target tree (or apply with --ignore-whitespace).
The post-image blob id on the label.py index line predates the asarray
change.

diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py
index 7a391b3f60b19..52156b23e682b 100644
--- a/sklearn/preprocessing/label.py
+++ b/sklearn/preprocessing/label.py
@@ -732,7 +732,8 @@ def fit_transform(self, y):
         class_mapping = np.empty(len(tmp), dtype=dtype)
         class_mapping[:] = tmp
         self.classes_, inverse = np.unique(class_mapping, return_inverse=True)
-        yt.indices = np.take(inverse, yt.indices)
+        # ensure yt.indices keeps its current dtype
+        yt.indices = np.asarray(inverse[yt.indices], dtype=yt.indices.dtype)
 
         if not self.sparse_output:
             yt = yt.toarray()
diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py
index baf1cfbc8bddd..f48ad29bd29b5 100644
--- a/sklearn/preprocessing/tests/test_label.py
+++ b/sklearn/preprocessing/tests/test_label.py
@@ -226,6 +226,8 @@ def test_sparse_output_multilabel_binarizer():
             got = mlb.fit_transform(inp())
             assert_equal(issparse(got), sparse_output)
             if sparse_output:
+                # verify CSR assumption that indices and indptr have same dtype
+                assert_equal(got.indices.dtype, got.indptr.dtype)
                 got = got.toarray()
             assert_array_equal(indicator_mat, got)
             assert_array_equal([1, 2, 3], mlb.classes_)
@@ -236,6 +238,8 @@ def test_sparse_output_multilabel_binarizer():
             got = mlb.fit(inp()).transform(inp())
             assert_equal(issparse(got), sparse_output)
             if sparse_output:
+                # verify CSR assumption that indices and indptr have same dtype
+                assert_equal(got.indices.dtype, got.indptr.dtype)
                 got = got.toarray()
             assert_array_equal(indicator_mat, got)
             assert_array_equal([1, 2, 3], mlb.classes_)