-
Notifications
You must be signed in to change notification settings - Fork 531
/
Copy pathop_fill_test.cpp
159 lines (125 loc) · 5.04 KB
/
op_fill_test.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <gtest/gtest.h>
using namespace ::testing;
using executorch::aten::Scalar;
using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using torch::executor::testing::TensorFactory;
/**
 * Fixture for the `fill` out-variant kernels.
 *
 * Provides thin wrappers around the two `fill_outf` overloads (scalar fill
 * value and tensor fill value) plus dtype-parameterized helpers that run one
 * fill and compare the result with the expected tensor.
 *
 * NOTE(review): `context_` is not declared in this class; presumably it is
 * supplied by `OperatorTest` -- confirm against TestUtil.h.
 */
class OpFillTest : public OperatorTest {
 protected:
  /// Calls the `fill_outf` overload that takes a `Scalar` fill value.
  Tensor&
  op_fill_scalar_out(const Tensor& self, const Scalar& other, Tensor& out) {
    return torch::executor::aten::fill_outf(context_, self, other, out);
  }

  /// Calls the `fill_outf` overload that takes a `Tensor` fill value.
  Tensor&
  op_fill_tensor_out(const Tensor& self, const Tensor& other, Tensor& out) {
    return torch::executor::aten::fill_outf(context_, self, other, out);
  }

  /**
   * Fills a tensor of shape `sizes` and dtype `DTYPE` with a scalar and checks
   * the result.
   *
   * Non-Bool dtypes: `out` starts as all 0s, is filled with 1, and must end up
   * as all 1s.
   * Bool: the fill value is `false`, so `out` starts as all `true`. Starting
   * from all `false` would make the check pass even if the kernel wrote
   * nothing.
   *
   * @param sizes Shape used for `self`, `out` and the expected tensor.
   */
  template <ScalarType DTYPE>
  void test_fill_scalar_out(std::vector<int32_t>&& sizes) {
    TensorFactory<DTYPE> tf;
    const bool is_bool = DTYPE == ScalarType::Bool;

    // Before: `out` holds the opposite of the value it is about to receive.
    Tensor self = tf.zeros(sizes);
    Tensor out = is_bool ? tf.ones(sizes) : tf.zeros(sizes);

    // Fill value: 1 for numeric dtypes, `false` for Bool.
    Scalar other = 1;
    if (is_bool) {
      other = false;
    }
    op_fill_scalar_out(self, other, out);

    // After: every element of `out` equals the fill value.
    Tensor exp_out = tf.full(sizes, 1);
    if (is_bool) {
      exp_out = tf.full(sizes, false);
    }

    // Check `out` matches expected output.
    EXPECT_TENSOR_EQ(out, exp_out);
  }

  /**
   * Fills a tensor of shape `sizes` and dtype `DTYPE` from a zero-dimensional
   * tensor holding 1 and checks that `out` goes from all 0s to all 1s.
   *
   * @param sizes Shape used for `self`, `out` and the expected tensor.
   */
  template <ScalarType DTYPE>
  void test_fill_tensor_out(std::vector<int32_t>&& sizes) {
    TensorFactory<DTYPE> tf;

    // Before: `out` consists of 0s.
    Tensor self = tf.zeros(sizes);
    Tensor out = tf.zeros(sizes);

    // The fill value is a zero-dimensional tensor containing a single 1.
    Tensor other = tf.ones({});
    op_fill_tensor_out(self, other, out);

    // After: `out` consists of 1s.
    Tensor exp_out = tf.full(sizes, 1);

    // Check `out` matches expected output.
    EXPECT_TENSOR_EQ(out, exp_out);
  }
};
// Runs the fixture helper `FN` (scalar or tensor variant of the `fill_out`
// test) for dtype `DTYPE` over a fixed set of shapes: zero-dimensional, a
// single element, several unit dimensions, a shape with a zero-sized
// dimension, and an ordinary 3-D shape. `self` and `out` are created with the
// given sizes, while the scalar/tensor fill value is a singleton.
//
// The body is wrapped in `do { ... } while (false)` so that the expansion is
// exactly one statement: it stays correct under an unbraced `if`/`else` and
// the caller's trailing `;` terminates it instead of adding an empty
// statement.
#define TEST_FILL_OUT(FN, DTYPE)        \
  do {                                  \
    FN<ScalarType::DTYPE>({});          \
    FN<ScalarType::DTYPE>({1});         \
    FN<ScalarType::DTYPE>({1, 1, 1});   \
    FN<ScalarType::DTYPE>({2, 0, 4});   \
    FN<ScalarType::DTYPE>({2, 3, 4});   \
  } while (false)
// Scalar variant: defines the test `<DTYPE>ScalarInputSupport`, which runs
// `test_fill_scalar_out` over every shape listed in TEST_FILL_OUT. The first
// macro argument is ignored.
// NOTE(review): ET_FORALL_REALHBBF16_TYPES presumably invokes the generator
// once per (ctype, dtype) pair -- confirm against its definition.
#define GENERATE_SCALAR_INPUT_SUPPORT_TEST(CTYPE_UNUSED, DTYPE) \
  TEST_F(OpFillTest, DTYPE##ScalarInputSupport) {               \
    TEST_FILL_OUT(test_fill_scalar_out, DTYPE);                 \
  }

ET_FORALL_REALHBBF16_TYPES(GENERATE_SCALAR_INPUT_SUPPORT_TEST)
// Tensor variant: defines the test `<DTYPE>TensorInputSupport`, which runs
// `test_fill_tensor_out` over every shape listed in TEST_FILL_OUT. The first
// macro argument is ignored.
// NOTE(review): ET_FORALL_REALHBBF16_TYPES presumably invokes the generator
// once per (ctype, dtype) pair -- confirm against its definition.
#define GENERATE_TENSOR_INPUT_SUPPORT_TEST(CTYPE_UNUSED, DTYPE) \
  TEST_F(OpFillTest, DTYPE##TensorInputSupport) {               \
    TEST_FILL_OUT(test_fill_tensor_out, DTYPE);                 \
  }

ET_FORALL_REALHBBF16_TYPES(GENERATE_TENSOR_INPUT_SUPPORT_TEST)
// The tensor overload is expected to fail for every `other` built here; each
// one has at least one dimension, and two of them hold more than one element.
TEST_F(OpFillTest, MismatchedOtherPropertiesDies) {
  TensorFactory<ScalarType::Int> int_factory;

  // `self` and `out` agree in both shape and dtype, so any failure below is
  // attributable to `other` alone.
  Tensor input = int_factory.zeros({1});
  Tensor destination = int_factory.zeros({1});

  // One dimension, one element.
  Tensor one_dim_single = int_factory.zeros({1});
  EXPECT_EQ(one_dim_single.dim(), 1);
  EXPECT_EQ(one_dim_single.numel(), 1);

  // One dimension, two elements.
  Tensor one_dim_pair = int_factory.zeros({2});
  EXPECT_EQ(one_dim_pair.dim(), 1);
  EXPECT_EQ(one_dim_pair.numel(), 2);

  // Two dimensions, nine elements.
  Tensor two_dim_grid = int_factory.zeros({3, 3});
  EXPECT_EQ(two_dim_grid.dim(), 2);
  EXPECT_EQ(two_dim_grid.numel(), 9);

  // Each of the fill values above must make the kernel report a failure.
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_fill_tensor_out(input, one_dim_single, destination));
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_fill_tensor_out(input, one_dim_pair, destination));
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_fill_tensor_out(input, two_dim_grid, destination));
}
// The scalar overload is expected to fail when `out` has a different shape
// than `self`.
TEST_F(OpFillTest, MismatchedOutputShapesDies) {
  using torch::executor::testing::SupportedFeatures;

  // Not applicable when running against ATen, which accepts `self` and `out`
  // of different shapes.
  if (SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched output shape";
  }

  TensorFactory<ScalarType::Int> int_factory;

  // Same dtype, different shapes: {1} versus {2, 2}.
  Tensor input = int_factory.zeros({1});
  Tensor destination = int_factory.zeros({2, 2});

  // The fill must be rejected because of the shape mismatch.
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_fill_scalar_out(input, 0, destination));
}
// The scalar overload is expected to fail when `out` has a different dtype
// than `self`.
TEST_F(OpFillTest, MismatchedOutputDtypeDies) {
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Float> float_factory;

  // Same shape {2, 2}, different dtypes: Byte input versus Float output.
  Tensor input = byte_factory.zeros({2, 2});
  Tensor destination = float_factory.ones({2, 2});

  // The fill must be rejected because of the dtype mismatch.
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_fill_scalar_out(input, 0.0, destination));
}