-
Notifications
You must be signed in to change notification settings - Fork 512
/
Copy pathaten_autograd_ops.h
37 lines (32 loc) · 1.51 KB
/
aten_autograd_ops.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#pragma once
#include <torch/script.h>
namespace torch_xla {
// Returns true if dilation is non-trivial (not 1) in at least one dimension.
bool IsNonTrivialDilation(at::IntArrayRef dilation);
namespace aten_autograd_ops {
// Custom autograd node for 2-D max pooling. Deriving from
// torch::autograd::Function wires forward/backward into PyTorch's autograd
// graph via the CRTP pattern. Definitions live outside this header
// (presumably in the matching .cpp — only declarations are visible here).
struct MaxPool2dAutogradFunction
    : public torch::autograd::Function<MaxPool2dAutogradFunction> {
  // Forward pass. `ctx` is the autograd context used to stash anything the
  // backward pass needs; the remaining parameters mirror the
  // at::max_pool2d signature (kernel_size, stride, padding, dilation,
  // ceil_mode). Returns the pooled tensor.
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               torch::Tensor self,
                               torch::IntArrayRef kernel_size,
                               torch::IntArrayRef stride,
                               torch::IntArrayRef padding,
                               torch::IntArrayRef dilation, bool ceil_mode);
  // Backward pass: consumes the incoming gradients and returns one gradient
  // entry per forward input (non-differentiable inputs get undefined
  // tensors, per the torch::autograd::Function contract).
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output);
};
// Custom autograd node for 3-D max pooling; structurally identical to the
// 2-D variant but the IntArrayRef parameters are expected to carry
// three-element values. Definitions are out of view of this header.
struct MaxPool3dAutogradFunction
    : public torch::autograd::Function<MaxPool3dAutogradFunction> {
  // Forward pass; parameters mirror at::max_pool3d. `ctx` stores state for
  // the backward pass.
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               torch::Tensor self,
                               torch::IntArrayRef kernel_size,
                               torch::IntArrayRef stride,
                               torch::IntArrayRef padding,
                               torch::IntArrayRef dilation, bool ceil_mode);
  // Backward pass: returns one gradient per forward input.
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output);
};
} // namespace aten_autograd_ops
} // namespace torch_xla