Skip to content

[inductor cpp] support vectorization for index_expr that depends on tiling itervar or with indirect indexing#114545

Closed
jgong5 wants to merge 18 commits intogh/jgong5/31/basefrom
gh/jgong5/31/head
Closed

[inductor cpp] support vectorization for index_expr that depends on tiling itervar or with indirect indexing#114545
jgong5 wants to merge 18 commits intogh/jgong5/31/basefrom
gh/jgong5/31/head

Conversation

@jgong5
Copy link
Copy Markdown
Collaborator

@jgong5 jgong5 commented Nov 26, 2023

Stack from ghstack (oldest at bottom):

As the title, this PR enables vectorization for the situation when the the index_expr depends on vectorized itervar. There are two cases here:

  1. The vectorized itervar has constant stride in the index_expr. We vectorize the index_expr with Vectorized<int32>::arange for this case.
  2. Otherwise, we load the index_expr vector in a non-contiguous way with a loop.

Below is the generated code for the first case from the test test_concat_inner_vec. Here x1 is the index_expr and depends on the vectorized itervar x1. It has constant stride 1. We vectorized it with arange. We use all_zero to implement a short-cut for masks to avoid unnecessary execution of nested masked regions which are invalid.
Before:

            #pragma omp for  collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(155L); x1+=static_cast<long>(1L))
                {
                    auto tmp0 = c10::convert<long>(x1);
                    auto tmp1 = static_cast<long>(0);
                    auto tmp2 = tmp0 >= tmp1;
                    auto tmp3 = static_cast<long>(35);
                    auto tmp4 = tmp0 < tmp3;
                    auto tmp5 = [&]
                    {
                        auto tmp6 = in_ptr0[static_cast<long>(x1 + (35L*x0))];
                        return tmp6;
                    }
                    ;
                    auto tmp7 = tmp4 ? tmp5() : static_cast<decltype(tmp5())>(0.0);
                    auto tmp8 = tmp0 >= tmp3;
                    auto tmp9 = static_cast<long>(155);
                    auto tmp10 = tmp0 < tmp9;
                    auto tmp11 = [&]
                    {
                        auto tmp12 = in_ptr1[static_cast<long>((-35L) + x1 + (120L*x0))];
                        return tmp12;
                    }
                    ;
...

After:

            #pragma omp for 
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(144L); x1+=static_cast<long>(16L))
                {
                    auto tmp0 = c10::convert<int>(x1);
                    auto tmp1 = at::vec::Vectorized<int32_t>::arange(tmp0, 1);
                    auto tmp2 = static_cast<int>(0);
                    auto tmp3 = at::vec::Vectorized<int>(tmp2);
                    auto tmp4 = to_float_mask(tmp1 >= tmp3);
                    auto tmp5 = static_cast<int>(35);
                    auto tmp6 = at::vec::Vectorized<int>(tmp5);
                    auto tmp7 = to_float_mask(tmp1 < tmp6);
                    auto tmp8 = [&]
                    {
                        auto tmp9 = masked_load(in_ptr0 + static_cast<long>(x1 + (35L*x0)), to_float_mask(tmp7));
                        return tmp9;
                    }
                    ;
                    auto tmp10 =
                    [&]
                    {
                        if (all_zero(to_float_mask(tmp7)))
                        {
                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                        }
                        else
                        {
                            return decltype(tmp8())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp8(), to_float_mask(tmp7));
                        }
                    }
                    ()
                    ;
...

Below is the generated code for the second case from the test case test_expr_vec_non_contiguous. Here, the index_expr is 31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)) which depends on the vectorized itervar x2 and doesn't have constant stride. So, we load the index_expr vector with a loop. (In fact, this can be further optimized since the index_expr is invariant with the data points in the range [x2, x2+16). So it can be regarded as a scalar. This will be optimized in the follow-up PR.) The code uses vector_lane_mask_check to implement the masked version of non-contiguous load.
Before:

            #pragma omp for  collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    {
                        float tmp_acc0 = -std::numeric_limits<float>::infinity();
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
                            auto tmp1 = static_cast<long>(2048);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (2048L*(static_cast<long>(x1) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                                return tmp4;
                            }
                            ;
                            auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            tmp_acc0 = max_propagate_nan(tmp_acc0, tmp5);
                        }
                        out_ptr0[static_cast<long>(x1 + (1024L*x0))] = tmp_acc0;
                    }
                }
            }

After:

            #pragma omp for 
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
                {
                    {
                        #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
                        float tmp_acc0 = -std::numeric_limits<float>::infinity();
                        at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                        {
                            auto tmp0 =
                            [&]
                            {
                                __at_align__ std::array<int, 16> tmpbuf;
                                #pragma GCC unroll 16
                                for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                {
                                    tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
                                }
                                return at::vec::Vectorized<int>::loadu(tmpbuf.data());
                            }
                            ()
                            ;
                            auto tmp1 = static_cast<int>(2048);
                            auto tmp2 = at::vec::Vectorized<int>(tmp1);
                            auto tmp3 = to_float_mask(tmp0 < tmp2);
                            auto tmp4 = [&]
                            {
                                auto tmp5 =
                                [&]
                                {
                                    __at_align__ std::array<float, 16> tmpbuf;
                                    #pragma GCC unroll 16
                                    for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                    {
                                        if (vector_lane_mask_check(tmp3, x1_inner))
                                        {
                                            tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                                        }
                                    }
                                    return at::vec::Vectorized<float>::loadu(tmpbuf.data());
                                }
                                ()
                                ;
                                return tmp5;
                            }
                            ;
                            auto tmp6 =
                            [&]
                            {
                                if (all_zero(to_float_mask(tmp3)))
                                {
                                    return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                }
                                else
                                {
                                    return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
                                }
                            }
                            ()
                            ;
                            tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
                        }
                        tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
                    }
                }
            }
        }

cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

…iling itervar or with indirect indexing

[ghstack-poisoned]
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Nov 26, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114545

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 81cb361 with merge base f6dfbff (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
jgong5 pushed a commit that referenced this pull request Nov 26, 2023
…iling itervar or with indirect indexing

ghstack-source-id: e30d45b
Pull Request resolved: #114545
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@jgong5 jgong5 added the topic: not user facing topic category label Nov 27, 2023
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
jgong5 pushed a commit that referenced this pull request Nov 27, 2023
…iling itervar or with indirect indexing

ghstack-source-id: 1feb391
Pull Request resolved: #114545
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
jgong5 pushed a commit that referenced this pull request Dec 2, 2023
…iling itervar or with indirect indexing

ghstack-source-id: a8213a1
Pull Request resolved: #114545
Comment thread torch/_inductor/codegen/cpp.py
Comment thread torch/_inductor/codegen/cpp.py
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
jgong5 pushed a commit that referenced this pull request Dec 4, 2023
…iling itervar or with indirect indexing

ghstack-source-id: c5212cc
Pull Request resolved: #114545
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
jgong5 pushed a commit that referenced this pull request Dec 19, 2023
…iling itervar or with indirect indexing

ghstack-source-id: f5007f7
Pull Request resolved: #114545
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
jgong5 pushed a commit that referenced this pull request Dec 24, 2023
…iling itervar or with indirect indexing

ghstack-source-id: 40a678b
Pull Request resolved: #114545
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
Jiong Gong added 2 commits December 25, 2023 10:23
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
jgong5 pushed a commit that referenced this pull request Dec 25, 2023
…iling itervar or with indirect indexing

ghstack-source-id: 5c9cbe0
Pull Request resolved: #114545
…epends on tiling itervar or with indirect indexing"

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
Copy link
Copy Markdown
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't node dtype propagation be done in common.py for all CSEVariables that have an FX node, similar to how we push forward the bounds from the FX nodes into variables?

Comment thread torch/_inductor/codegen/cpp_prefix.h Outdated
Comment thread torch/_inductor/codegen/cpp_prefix.h
Comment thread torch/_inductor/codegen/cpp.py
Comment thread torch/_inductor/codegen/cpp.py Outdated
Comment thread torch/_inductor/codegen/cpp.py Outdated
Comment thread torch/_inductor/codegen/cpp.py Outdated
Comment thread torch/_inductor/codegen/cpp.py
…epends on tiling itervar or with indirect indexing"


As the title, this PR enables vectorization for the situation when the the index_expr depends on vectorized itervar. There are two cases here:
1. The vectorized itervar has constant stride in the index_expr. We vectorize the index_expr with `Vectorized<int32>::arange` for this case.
2. Otherwise, we load the index_expr vector in a non-contiguous way with a loop.

Below is the generated code for the first case from the test `test_concat_inner_vec`. Here `x1` is the index_expr and depends on the vectorized itervar `x1`. It has constant stride 1. We vectorized it with arange. We use `all_zero` to implement a short-cut for masks to avoid unnecessary execution of nested masked regions which are invalid.
Before:
```c++
            #pragma omp for  collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(155L); x1+=static_cast<long>(1L))
                {
                    auto tmp0 = c10::convert<long>(x1);
                    auto tmp1 = static_cast<long>(0);
                    auto tmp2 = tmp0 >= tmp1;
                    auto tmp3 = static_cast<long>(35);
                    auto tmp4 = tmp0 < tmp3;
                    auto tmp5 = [&]
                    {
                        auto tmp6 = in_ptr0[static_cast<long>(x1 + (35L*x0))];
                        return tmp6;
                    }
                    ;
                    auto tmp7 = tmp4 ? tmp5() : static_cast<decltype(tmp5())>(0.0);
                    auto tmp8 = tmp0 >= tmp3;
                    auto tmp9 = static_cast<long>(155);
                    auto tmp10 = tmp0 < tmp9;
                    auto tmp11 = [&]
                    {
                        auto tmp12 = in_ptr1[static_cast<long>((-35L) + x1 + (120L*x0))];
                        return tmp12;
                    }
                    ;
...
```
After:
```c++
            #pragma omp for 
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(144L); x1+=static_cast<long>(16L))
                {
                    auto tmp0 = c10::convert<int>(x1);
                    auto tmp1 = at::vec::Vectorized<int32_t>::arange(tmp0, 1);
                    auto tmp2 = static_cast<int>(0);
                    auto tmp3 = at::vec::Vectorized<int>(tmp2);
                    auto tmp4 = to_float_mask(tmp1 >= tmp3);
                    auto tmp5 = static_cast<int>(35);
                    auto tmp6 = at::vec::Vectorized<int>(tmp5);
                    auto tmp7 = to_float_mask(tmp1 < tmp6);
                    auto tmp8 = [&]
                    {
                        auto tmp9 = masked_load(in_ptr0 + static_cast<long>(x1 + (35L*x0)), to_float_mask(tmp7));
                        return tmp9;
                    }
                    ;
                    auto tmp10 =
                    [&]
                    {
                        if (all_zero(to_float_mask(tmp7)))
                        {
                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                        }
                        else
                        {
                            return decltype(tmp8())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp8(), to_float_mask(tmp7));
                        }
                    }
                    ()
                    ;
...
```

Below is the generated code for the second case from the test case `test_expr_vec_non_contiguous`. Here, the index_expr is `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))` which depends on the vectorized itervar `x2` and doesn't have constant stride. So, we load the index_expr vector with a loop. (In fact, this can be further optimized since the index_expr is invariant with the data points in the range [x2, x2+16). So it can be regarded as a scalar. This will be optimized in the follow-up PR.) The code uses `vector_lane_mask_check` to implement the masked version of non-contiguous load.
Before:
```c++
            #pragma omp for  collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    {
                        float tmp_acc0 = -std::numeric_limits<float>::infinity();
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
                            auto tmp1 = static_cast<long>(2048);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (2048L*(static_cast<long>(x1) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                                return tmp4;
                            }
                            ;
                            auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            tmp_acc0 = max_propagate_nan(tmp_acc0, tmp5);
                        }
                        out_ptr0[static_cast<long>(x1 + (1024L*x0))] = tmp_acc0;
                    }
                }
            }
```
After:
```c++
            #pragma omp for 
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
                {
                    {
                        #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
                        float tmp_acc0 = -std::numeric_limits<float>::infinity();
                        at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                        {
                            auto tmp0 =
                            [&]
                            {
                                __at_align__ std::array<int, 16> tmpbuf;
                                #pragma GCC unroll 16
                                for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                {
                                    tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
                                }
                                return at::vec::Vectorized<int>::loadu(tmpbuf.data());
                            }
                            ()
                            ;
                            auto tmp1 = static_cast<int>(2048);
                            auto tmp2 = at::vec::Vectorized<int>(tmp1);
                            auto tmp3 = to_float_mask(tmp0 < tmp2);
                            auto tmp4 = [&]
                            {
                                auto tmp5 =
                                [&]
                                {
                                    __at_align__ std::array<float, 16> tmpbuf;
                                    #pragma GCC unroll 16
                                    for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                    {
                                        if (vector_lane_mask_check(tmp3, x1_inner))
                                        {
                                            tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                                        }
                                    }
                                    return at::vec::Vectorized<float>::loadu(tmpbuf.data());
                                }
                                ()
                                ;
                                return tmp5;
                            }
                            ;
                            auto tmp6 =
                            [&]
                            {
                                if (all_zero(to_float_mask(tmp3)))
                                {
                                    return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                }
                                else
                                {
                                    return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
                                }
                            }
                            ()
                            ;
                            tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
                        }
                        tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
                    }
                }
            }
        }
```

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@jgong5 jgong5 requested a review from lezcano December 25, 2023 14:31
@jgong5
Copy link
Copy Markdown
Collaborator Author

jgong5 commented Dec 25, 2023

Couldn't node dtype propagation be done in common.py for all CSEVariables that have an FX node, similar to how we push forward the bounds from the FX nodes into variables?

The node-level dtype propagation is already in common.py (see

class DataTypePropagation:
)

Comment thread torch/_inductor/codegen/cpp.py Outdated
Comment on lines +1745 to +1751
def get_result_size(dtype: torch.dtype) -> str:
result_size = f"{self.tiling_factor}"
assert dtype.itemsize <= 4
size_multiplier = 4 // dtype.itemsize
if size_multiplier > 1:
result_size += f" * {size_multiplier}"
return result_size
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def get_result_size(dtype: torch.dtype) -> str:
result_size = f"{self.tiling_factor}"
assert dtype.itemsize <= 4
size_multiplier = 4 // dtype.itemsize
if size_multiplier > 1:
result_size += f" * {size_multiplier}"
return result_size
def get_result_size(dtype: torch.dtype) -> int:
assert dtype.itemsize <= 4
return self.tiling_factor * (4 // dtype.itemsize)

Comment thread torch/_inductor/codegen/cpp_prefix.h Outdated
@lezcano
Copy link
Copy Markdown
Collaborator

lezcano commented Dec 25, 2023

I see that the dtype on CSEVars is a preexisting issue. What I meant is that that logic could potentially go in

def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
def inner(*args, **kwargs):
# TritonTemplateKernel has no current_node
buf_bounds = ValueRanges.unknown()
if hasattr(V.interpreter, "current_node"):
fx_node = V.interpreter.current_node
assert isinstance(self.node_to_bounds, dict)
buf_bounds = self.node_to_bounds.get(
fx_node, ValueRanges.unknown()
)
csevar = self.cse.generate(
self.compute,
getattr(parent_handler, name)(*args, **kwargs), # type: ignore[has-type]
bounds=buf_bounds,
)
csevar.update_on_args(name, args, kwargs)
return csevar

so that it is generic for every CSEVariable.

Copy link
Copy Markdown
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comments, but looks good

Comment thread torch/_inductor/codegen/cpp.py
…epends on tiling itervar or with indirect indexing"


As the title, this PR enables vectorization for the situation when the the index_expr depends on vectorized itervar. There are two cases here:
1. The vectorized itervar has constant stride in the index_expr. We vectorize the index_expr with `Vectorized<int32>::arange` for this case.
2. Otherwise, we load the index_expr vector in a non-contiguous way with a loop.

Below is the generated code for the first case from the test `test_concat_inner_vec`. Here `x1` is the index_expr and depends on the vectorized itervar `x1`. It has constant stride 1. We vectorized it with arange. We use `all_zero` to implement a short-cut for masks to avoid unnecessary execution of nested masked regions which are invalid.
Before:
```c++
            #pragma omp for  collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(155L); x1+=static_cast<long>(1L))
                {
                    auto tmp0 = c10::convert<long>(x1);
                    auto tmp1 = static_cast<long>(0);
                    auto tmp2 = tmp0 >= tmp1;
                    auto tmp3 = static_cast<long>(35);
                    auto tmp4 = tmp0 < tmp3;
                    auto tmp5 = [&]
                    {
                        auto tmp6 = in_ptr0[static_cast<long>(x1 + (35L*x0))];
                        return tmp6;
                    }
                    ;
                    auto tmp7 = tmp4 ? tmp5() : static_cast<decltype(tmp5())>(0.0);
                    auto tmp8 = tmp0 >= tmp3;
                    auto tmp9 = static_cast<long>(155);
                    auto tmp10 = tmp0 < tmp9;
                    auto tmp11 = [&]
                    {
                        auto tmp12 = in_ptr1[static_cast<long>((-35L) + x1 + (120L*x0))];
                        return tmp12;
                    }
                    ;
...
```
After:
```c++
            #pragma omp for 
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(144L); x1+=static_cast<long>(16L))
                {
                    auto tmp0 = c10::convert<int>(x1);
                    auto tmp1 = at::vec::Vectorized<int32_t>::arange(tmp0, 1);
                    auto tmp2 = static_cast<int>(0);
                    auto tmp3 = at::vec::Vectorized<int>(tmp2);
                    auto tmp4 = to_float_mask(tmp1 >= tmp3);
                    auto tmp5 = static_cast<int>(35);
                    auto tmp6 = at::vec::Vectorized<int>(tmp5);
                    auto tmp7 = to_float_mask(tmp1 < tmp6);
                    auto tmp8 = [&]
                    {
                        auto tmp9 = masked_load(in_ptr0 + static_cast<long>(x1 + (35L*x0)), to_float_mask(tmp7));
                        return tmp9;
                    }
                    ;
                    auto tmp10 =
                    [&]
                    {
                        if (all_zero(to_float_mask(tmp7)))
                        {
                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                        }
                        else
                        {
                            return decltype(tmp8())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp8(), to_float_mask(tmp7));
                        }
                    }
                    ()
                    ;
...
```

Below is the generated code for the second case from the test case `test_expr_vec_non_contiguous`. Here, the index_expr is `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))` which depends on the vectorized itervar `x2` and doesn't have constant stride. So, we load the index_expr vector with a loop. (In fact, this can be further optimized since the index_expr is invariant with the data points in the range [x2, x2+16). So it can be regarded as a scalar. This will be optimized in the follow-up PR.) The code uses `vector_lane_mask_check` to implement the masked version of non-contiguous load.
Before:
```c++
            #pragma omp for  collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    {
                        float tmp_acc0 = -std::numeric_limits<float>::infinity();
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
                            auto tmp1 = static_cast<long>(2048);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (2048L*(static_cast<long>(x1) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                                return tmp4;
                            }
                            ;
                            auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            tmp_acc0 = max_propagate_nan(tmp_acc0, tmp5);
                        }
                        out_ptr0[static_cast<long>(x1 + (1024L*x0))] = tmp_acc0;
                    }
                }
            }
```
After:
```c++
            #pragma omp for 
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
                {
                    {
                        #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
                        float tmp_acc0 = -std::numeric_limits<float>::infinity();
                        at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                        {
                            auto tmp0 =
                            [&]
                            {
                                __at_align__ std::array<int, 16> tmpbuf;
                                #pragma GCC unroll 16
                                for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                {
                                    tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
                                }
                                return at::vec::Vectorized<int>::loadu(tmpbuf.data());
                            }
                            ()
                            ;
                            auto tmp1 = static_cast<int>(2048);
                            auto tmp2 = at::vec::Vectorized<int>(tmp1);
                            auto tmp3 = to_float_mask(tmp0 < tmp2);
                            auto tmp4 = [&]
                            {
                                auto tmp5 =
                                [&]
                                {
                                    __at_align__ std::array<float, 16> tmpbuf;
                                    #pragma GCC unroll 16
                                    for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                    {
                                        if (vector_lane_mask_check(tmp3, x1_inner))
                                        {
                                            tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                                        }
                                    }
                                    return at::vec::Vectorized<float>::loadu(tmpbuf.data());
                                }
                                ()
                                ;
                                return tmp5;
                            }
                            ;
                            auto tmp6 =
                            [&]
                            {
                                if (all_zero(to_float_mask(tmp3)))
                                {
                                    return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                }
                                else
                                {
                                    return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
                                }
                            }
                            ()
                            ;
                            tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
                        }
                        tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
                    }
                }
            }
        }
```

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@jgong5
Copy link
Copy Markdown
Collaborator Author

jgong5 commented Dec 26, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 26, 2023
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Dec 26, 2023
…range (#116387)

For the test `test_expr_vec_non_contiguous`. The index_expr `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))` is invariant under the vector range of `x2`.
Before change
```c++
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
                {
                    {
                        #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
                        float tmp_acc0 = -std::numeric_limits<float>::infinity();
                        at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                        {
                            auto tmp0 =
                            [&]
                            {
                                __at_align__ std::array<int, 16> tmpbuf;
                                #pragma GCC unroll 16
                                for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                {
                                    tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
                                }
                                return at::vec::Vectorized<int>::loadu(tmpbuf.data());
                            }
                            ()
                            ;
                            auto tmp1 = static_cast<int>(2048);
                            auto tmp2 = at::vec::Vectorized<int>(tmp1);
                            auto tmp3 = to_float_mask(tmp0 < tmp2);
                            auto tmp4 = [&]
                            {
                                auto tmp5 =
                                [&]
                                {
                                    __at_align__ std::array<float, 16> tmpbuf;
                                    #pragma GCC unroll 16
                                    for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                    {
                                        if (vector_lane_mask_check(tmp3, x1_inner))
                                        {
                                            tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                                        }
                                    }
                                    return at::vec::Vectorized<float>::loadu(tmpbuf.data());
                                }
                                ()
                                ;
                                return tmp5;
                            }
                            ;
                            auto tmp6 =
                            [&]
                            {
                                if (all_zero(to_float_mask(tmp3)))
                                {
                                    return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                }
                                else
                                {
                                    return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
                                }
                            }
                            ()
                            ;
                            tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
                        }
                        tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
                    }
                }
            }
        }
```
After change
```c++
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
                {
                    {
                        #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
                        float tmp_acc0 = -std::numeric_limits<float>::infinity();
                        at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
                            auto tmp1 = static_cast<int>(2048);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 =
                                [&]
                                {
                                    __at_align__ std::array<float, 16> tmpbuf;
                                    #pragma GCC unroll 16
                                    for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                                    {
                                        if (tmp2 != 0)
                                        {
                                            tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                                        }
                                    }
                                    return at::vec::Vectorized<float>::loadu(tmpbuf.data());
                                }
                                ()
                                ;
                                return tmp4;
                            }
                            ;
                            auto tmp5 =
                            [&]
                            {
                                if (all_zero(to_float_mask(tmp2)))
                                {
                                    return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                }
                                else
                                {
                                    return decltype(tmp3())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp3(), to_float_mask(tmp2));
                                }
                            }
                            ()
                            ;
                            tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp5);
                        }
                        tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
                    }
                }
            }
        }
```

Pull Request resolved: #116387
Approved by: https://github.com/EikanWang, https://github.com/lezcano
ghstack dependencies: #114545
@facebook-github-bot facebook-github-bot deleted the gh/jgong5/31/head branch December 29, 2023 15:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants