[inductor cpp] support vectorization for index_expr that depends on tiling itervar or with indirect indexing#114545
[inductor cpp] support vectorization for index_expr that depends on tiling itervar or with indirect indexing#114545jgong5 wants to merge 18 commits intogh/jgong5/31/basefrom
Conversation
…iling itervar or with indirect indexing [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114545
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 81cb361 with merge base f6dfbff ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…epends on tiling itervar or with indirect indexing" cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
lezcano
left a comment
There was a problem hiding this comment.
Couldn't node dtype propagation be done in common.py for all CSEVariables that have an FX node, similar to how we push forward the bounds from the FX nodes into variables?
…epends on tiling itervar or with indirect indexing"
As the title, this PR enables vectorization for the situation when the the index_expr depends on vectorized itervar. There are two cases here:
1. The vectorized itervar has constant stride in the index_expr. We vectorize the index_expr with `Vectorized<int32>::arange` for this case.
2. Otherwise, we load the index_expr vector in a non-contiguous way with a loop.
Below is the generated code for the first case from the test `test_concat_inner_vec`. Here `x1` is the index_expr and depends on the vectorized itervar `x1`. It has constant stride 1. We vectorized it with arange. We use `all_zero` to implement a short-cut for masks to avoid unnecessary execution of nested masked regions which are invalid.
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(155L); x1+=static_cast<long>(1L))
{
auto tmp0 = c10::convert<long>(x1);
auto tmp1 = static_cast<long>(0);
auto tmp2 = tmp0 >= tmp1;
auto tmp3 = static_cast<long>(35);
auto tmp4 = tmp0 < tmp3;
auto tmp5 = [&]
{
auto tmp6 = in_ptr0[static_cast<long>(x1 + (35L*x0))];
return tmp6;
}
;
auto tmp7 = tmp4 ? tmp5() : static_cast<decltype(tmp5())>(0.0);
auto tmp8 = tmp0 >= tmp3;
auto tmp9 = static_cast<long>(155);
auto tmp10 = tmp0 < tmp9;
auto tmp11 = [&]
{
auto tmp12 = in_ptr1[static_cast<long>((-35L) + x1 + (120L*x0))];
return tmp12;
}
;
...
```
After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(144L); x1+=static_cast<long>(16L))
{
auto tmp0 = c10::convert<int>(x1);
auto tmp1 = at::vec::Vectorized<int32_t>::arange(tmp0, 1);
auto tmp2 = static_cast<int>(0);
auto tmp3 = at::vec::Vectorized<int>(tmp2);
auto tmp4 = to_float_mask(tmp1 >= tmp3);
auto tmp5 = static_cast<int>(35);
auto tmp6 = at::vec::Vectorized<int>(tmp5);
auto tmp7 = to_float_mask(tmp1 < tmp6);
auto tmp8 = [&]
{
auto tmp9 = masked_load(in_ptr0 + static_cast<long>(x1 + (35L*x0)), to_float_mask(tmp7));
return tmp9;
}
;
auto tmp10 =
[&]
{
if (all_zero(to_float_mask(tmp7)))
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp8())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp8(), to_float_mask(tmp7));
}
}
()
;
...
```
Below is the generated code for the second case from the test case `test_expr_vec_non_contiguous`. Here, the index_expr is `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))` which depends on the vectorized itervar `x2` and doesn't have constant stride. So, we load the index_expr vector with a loop. (In fact, this can be further optimized since the index_expr is invariant with the data points in the range [x2, x2+16). So it can be regarded as a scalar. This will be optimized in the follow-up PR.) The code uses `vector_lane_mask_check` to implement the masked version of non-contiguous load.
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
{
auto tmp0 = c10::convert<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
auto tmp1 = static_cast<long>(2048);
auto tmp2 = tmp0 < tmp1;
auto tmp3 = [&]
{
auto tmp4 = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (2048L*(static_cast<long>(x1) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
return tmp4;
}
;
auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
tmp_acc0 = max_propagate_nan(tmp_acc0, tmp5);
}
out_ptr0[static_cast<long>(x1 + (1024L*x0))] = tmp_acc0;
}
}
}
```
After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
{
{
#pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
{
auto tmp0 =
[&]
{
__at_align__ std::array<int, 16> tmpbuf;
#pragma GCC unroll 16
for (long x1_inner = 0; x1_inner < 16; x1_inner++)
{
tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
}
return at::vec::Vectorized<int>::loadu(tmpbuf.data());
}
()
;
auto tmp1 = static_cast<int>(2048);
auto tmp2 = at::vec::Vectorized<int>(tmp1);
auto tmp3 = to_float_mask(tmp0 < tmp2);
auto tmp4 = [&]
{
auto tmp5 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x1_inner = 0; x1_inner < 16; x1_inner++)
{
if (vector_lane_mask_check(tmp3, x1_inner))
{
tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data());
}
()
;
return tmp5;
}
;
auto tmp6 =
[&]
{
if (all_zero(to_float_mask(tmp3)))
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
}
}
()
;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
}
tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
}
}
}
}
```
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler
[ghstack-poisoned]
The node-level dtype propagation is already in common.py (see pytorch/torch/_inductor/codegen/common.py Line 148 in c7e9c15 |
| def get_result_size(dtype: torch.dtype) -> str: | ||
| result_size = f"{self.tiling_factor}" | ||
| assert dtype.itemsize <= 4 | ||
| size_multiplier = 4 // dtype.itemsize | ||
| if size_multiplier > 1: | ||
| result_size += f" * {size_multiplier}" | ||
| return result_size |
There was a problem hiding this comment.
| def get_result_size(dtype: torch.dtype) -> str: | |
| result_size = f"{self.tiling_factor}" | |
| assert dtype.itemsize <= 4 | |
| size_multiplier = 4 // dtype.itemsize | |
| if size_multiplier > 1: | |
| result_size += f" * {size_multiplier}" | |
| return result_size | |
| def get_result_size(dtype: torch.dtype) -> int: | |
| assert dtype.itemsize <= 4 | |
| return self.tiling_factor * (4 // dtype.itemsize) |
|
I see that the dtype on CSEVars is a preexisting issue. What I meant is that that logic could potentially go in pytorch/torch/_inductor/codegen/common.py Lines 1047 to 1064 in 8ce8f9b so that it is generic for every CSEVariable. |
lezcano
left a comment
There was a problem hiding this comment.
A few comments, but looks good
…epends on tiling itervar or with indirect indexing"
As the title, this PR enables vectorization for the situation when the the index_expr depends on vectorized itervar. There are two cases here:
1. The vectorized itervar has constant stride in the index_expr. We vectorize the index_expr with `Vectorized<int32>::arange` for this case.
2. Otherwise, we load the index_expr vector in a non-contiguous way with a loop.
Below is the generated code for the first case from the test `test_concat_inner_vec`. Here `x1` is the index_expr and depends on the vectorized itervar `x1`. It has constant stride 1. We vectorized it with arange. We use `all_zero` to implement a short-cut for masks to avoid unnecessary execution of nested masked regions which are invalid.
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(155L); x1+=static_cast<long>(1L))
{
auto tmp0 = c10::convert<long>(x1);
auto tmp1 = static_cast<long>(0);
auto tmp2 = tmp0 >= tmp1;
auto tmp3 = static_cast<long>(35);
auto tmp4 = tmp0 < tmp3;
auto tmp5 = [&]
{
auto tmp6 = in_ptr0[static_cast<long>(x1 + (35L*x0))];
return tmp6;
}
;
auto tmp7 = tmp4 ? tmp5() : static_cast<decltype(tmp5())>(0.0);
auto tmp8 = tmp0 >= tmp3;
auto tmp9 = static_cast<long>(155);
auto tmp10 = tmp0 < tmp9;
auto tmp11 = [&]
{
auto tmp12 = in_ptr1[static_cast<long>((-35L) + x1 + (120L*x0))];
return tmp12;
}
;
...
```
After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(144L); x1+=static_cast<long>(16L))
{
auto tmp0 = c10::convert<int>(x1);
auto tmp1 = at::vec::Vectorized<int32_t>::arange(tmp0, 1);
auto tmp2 = static_cast<int>(0);
auto tmp3 = at::vec::Vectorized<int>(tmp2);
auto tmp4 = to_float_mask(tmp1 >= tmp3);
auto tmp5 = static_cast<int>(35);
auto tmp6 = at::vec::Vectorized<int>(tmp5);
auto tmp7 = to_float_mask(tmp1 < tmp6);
auto tmp8 = [&]
{
auto tmp9 = masked_load(in_ptr0 + static_cast<long>(x1 + (35L*x0)), to_float_mask(tmp7));
return tmp9;
}
;
auto tmp10 =
[&]
{
if (all_zero(to_float_mask(tmp7)))
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp8())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp8(), to_float_mask(tmp7));
}
}
()
;
...
```
Below is the generated code for the second case from the test case `test_expr_vec_non_contiguous`. Here, the index_expr is `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))` which depends on the vectorized itervar `x2` and doesn't have constant stride. So, we load the index_expr vector with a loop. (In fact, this can be further optimized since the index_expr is invariant with the data points in the range [x2, x2+16). So it can be regarded as a scalar. This will be optimized in the follow-up PR.) The code uses `vector_lane_mask_check` to implement the masked version of non-contiguous load.
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
{
auto tmp0 = c10::convert<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
auto tmp1 = static_cast<long>(2048);
auto tmp2 = tmp0 < tmp1;
auto tmp3 = [&]
{
auto tmp4 = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (2048L*(static_cast<long>(x1) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
return tmp4;
}
;
auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
tmp_acc0 = max_propagate_nan(tmp_acc0, tmp5);
}
out_ptr0[static_cast<long>(x1 + (1024L*x0))] = tmp_acc0;
}
}
}
```
After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
{
{
#pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
{
auto tmp0 =
[&]
{
__at_align__ std::array<int, 16> tmpbuf;
#pragma GCC unroll 16
for (long x1_inner = 0; x1_inner < 16; x1_inner++)
{
tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
}
return at::vec::Vectorized<int>::loadu(tmpbuf.data());
}
()
;
auto tmp1 = static_cast<int>(2048);
auto tmp2 = at::vec::Vectorized<int>(tmp1);
auto tmp3 = to_float_mask(tmp0 < tmp2);
auto tmp4 = [&]
{
auto tmp5 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x1_inner = 0; x1_inner < 16; x1_inner++)
{
if (vector_lane_mask_check(tmp3, x1_inner))
{
tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data());
}
()
;
return tmp5;
}
;
auto tmp6 =
[&]
{
if (all_zero(to_float_mask(tmp3)))
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
}
}
()
;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
}
tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
}
}
}
}
```
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler
[ghstack-poisoned]
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…range (#116387) For the test `test_expr_vec_non_contiguous`. The index_expr `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))` is invariant under the vector range of `x2`. Before change ```c++ #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { { #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())}) float tmp_acc0 = -std::numeric_limits<float>::infinity(); at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()); for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L)) { auto tmp0 = [&] { __at_align__ std::array<int, 16> tmpbuf; #pragma GCC unroll 16 for (long x1_inner = 0; x1_inner < 16; x1_inner++) { tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L))); } return at::vec::Vectorized<int>::loadu(tmpbuf.data()); } () ; auto tmp1 = static_cast<int>(2048); auto tmp2 = at::vec::Vectorized<int>(tmp1); auto tmp3 = to_float_mask(tmp0 < tmp2); auto tmp4 = [&] { auto tmp5 = [&] { __at_align__ std::array<float, 16> tmpbuf; #pragma GCC unroll 16 for (long x1_inner = 0; x1_inner < 16; x1_inner++) { if (vector_lane_mask_check(tmp3, x1_inner)) { tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))]; } } return at::vec::Vectorized<float>::loadu(tmpbuf.data()); } () ; return tmp5; } ; auto tmp6 = [&] { if (all_zero(to_float_mask(tmp3))) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3)); } } () ; tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6); } tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0))); } } } } ``` After change ```c++ #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { { #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())}) float tmp_acc0 = -std::numeric_limits<float>::infinity(); at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()); for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L)) { auto tmp0 = c10::convert<int>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))); auto tmp1 = static_cast<int>(2048); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = [&] { __at_align__ std::array<float, 16> tmpbuf; #pragma GCC unroll 16 for (long x1_inner = 0; x1_inner < 16; x1_inner++) { if (tmp2 != 0) { tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))]; } } return at::vec::Vectorized<float>::loadu(tmpbuf.data()); } () ; return tmp4; } ; auto tmp5 = [&] { if (all_zero(to_float_mask(tmp2))) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp3())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp3(), to_float_mask(tmp2)); } } () ; tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp5); } tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0))); } } } } ``` Pull Request resolved: #116387 Approved by: https://github.com/EikanWang, https://github.com/lezcano ghstack dependencies: #114545
Stack from ghstack (oldest at bottom):
As the title, this PR enables vectorization for the situation when the the index_expr depends on vectorized itervar. There are two cases here:
Vectorized<int32>::arangefor this case.Below is the generated code for the first case from the test
test_concat_inner_vec. Herex1is the index_expr and depends on the vectorized itervarx1. It has constant stride 1. We vectorized it with arange. We useall_zeroto implement a short-cut for masks to avoid unnecessary execution of nested masked regions which are invalid.Before:
After:
Below is the generated code for the second case from the test case
test_expr_vec_non_contiguous. Here, the index_expr is31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))which depends on the vectorized itervarx2and doesn't have constant stride. So, we load the index_expr vector with a loop. (In fact, this can be further optimized since the index_expr is invariant with the data points in the range [x2, x2+16). So it can be regarded as a scalar. This will be optimized in the follow-up PR.) The code usesvector_lane_mask_checkto implement the masked version of non-contiguous load.Before:
After:
cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler