快速数论变换NTT
时间: 2025-08-08 09:11:54 浏览: 4
<think>快速数论变换(NTT)是离散傅里叶变换(DFT)在有限域(模运算)上的实现,常用于高效计算多项式乘法(卷积)。与FFT相比,NTT在整数环上运算,避免了浮点数精度问题,特别适用于模数意义下的多项式运算(如密码学、组合数学)。
NTT要求模数必须满足存在原根(即模数是形如$p = c \cdot 2^k + 1$的质数),常见模数如998244353(原根为3)、1004535809(原根为3)等。
### NTT原理
设质数$p$,其原根为$g$,则模$p$意义下的$n$次单位根为$g^{(p-1)/n}$,其中$n$是2的幂次且$n \mid (p-1)$。
NTT正变换:
$$ X_k = \sum_{j=0}^{n-1} x_j \cdot (g^{(p-1)/n})^{kj} \mod p $$
逆变换(INTT):
$$ x_j = n^{-1} \sum_{k=0}^{n-1} X_k \cdot (g^{-(p-1)/n})^{kj} \mod p $$
### C++实现NTT
以下代码实现了NTT及其逆变换,并包含多项式乘法示例。
```cpp
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const long long MOD = 998244353; // 常用NTT模数
const long long ROOT = 3; // 模数的原根
const long long ROOT_INV = 332748118; // 原根的逆元
// 快速幂取模
long long mod_pow(long long base, long long exp, long long mod = MOD) {
long long result = 1;
while (exp) {
if (exp & 1)
result = (result * base) % mod;
base = (base * base) % mod;
exp >>= 1;
}
return result;
}
// 求逆元
inline long long mod_inv(long long a, long long mod = MOD) {
return mod_pow(a, mod - 2, mod);
}
// 位逆序置换
void bit_reverse(vector<long long>& a) {
int n = a.size();
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j ^= bit;
if (i < j)
swap(a[i], a[j]);
}
}
// NTT核心函数(invert=false为正变换,invert=true为逆变换)
void ntt(vector<long long>& a, bool invert) {
int n = a.size();
bit_reverse(a);
for (int len = 2; len <= n; len <<= 1) {
long long wlen = invert ? ROOT_INV : ROOT;
// 计算单位根:wlen = g^{(p-1)/len} mod p
for (int i = len; i < (MOD-1)>>1; i <<= 1)
wlen = (wlen * wlen) % MOD;
for (int i = 0; i < n; i += len) {
long long w = 1;
for (int j = 0; j < len / 2; j++) {
long long u = a[i+j];
long long v = (a[i+j+len/2] * w) % MOD;
a[i+j] = (u + v) % MOD;
a[i+j+len/2] = (u - v + MOD) % MOD;
w = (w * wlen) % MOD;
}
}
}
if (invert) {
long long inv_n = mod_inv(n, MOD);
for (long long& x : a)
x = (x * inv_n) % MOD;
}
}
// 多项式乘法(卷积)
vector<long long> poly_multiply(vector<long long> a, vector<long long> b) {
int n = 1;
while (n < a.size() + b.size())
n <<= 1;
a.resize(n);
b.resize(n);
ntt(a, false);
ntt(b, false);
for (int i = 0; i < n; i++)
a[i] = (a[i] * b[i]) % MOD;
ntt(a, true);
return a;
}
int main() {
// 示例:计算 (x+2)(3x+4) = 3x^2 + 10x + 8
vector<long long> A = {2, 1}; // 多项式A: x+2
vector<long long> B = {4, 3}; // 多项式B: 3x+4
vector<long long> C = poly_multiply(A, B);
cout << "Coefficients of product polynomial: ";
for (long long x : C) {
cout << x << " ";
}
// 输出: 8 10 3 (对应常数项、一次项、二次项系数)
// 因为:8 + 10x + 3x^2
return 0;
}
```
### 关键步骤解析
1. **位逆序置换**:
- 通过位操作将数组元素按分治所需顺序重新排列,避免递归。
- 时间复杂度$O(n)$。
2. **蝶形运算**:
- 在每一层循环中,将数组分成大小为`len`的块,对每块进行蝶形运算。
- 使用单位根`wlen`的幂次进行线性变换。
3. **逆变换处理**:
- 逆变换时,使用原根的逆元`ROOT_INV`作为单位根。
- 最后乘以$n^{-1}$($n$的模逆元)实现归一化。
4. **多项式乘法**:
- 通过三次NTT(两次正变换、一次逆变换)实现卷积。
- 复杂度$O(n \log n)$,远优于朴素$O(n^2)$。
### 注意事项
- 数组长度必须为2的幂次,不足时补零。
- 模数必须满足$p = c \cdot 2^k + 1$,且$2^k$大于等于多项式长度。
- 原根和原根逆元需要预先计算。
### 性能优化
1. 预处理单位根:避免重复计算单位根的幂次。
2. 使用迭代版避免递归开销。
3. 使用constexpr和inline加速模运算。
§§相关问题§§:
1. NTT与FFT在性能和适用场景上有何区别?
2. 如何实现任意模数下的多项式乘法(如非NTT友好模数)?
3. 请解释如何利用NTT计算大整数乘法?
4. 在密码学中,NTT有哪些具体应用场景?</think># 快速数论变换(NTT)详解
快速数论变换(NTT)是离散傅里叶变换(DFT)在有限域上的实现,特别适用于模数意义下的多项式乘法。与FFT相比,NTT使用整数运算而非浮点数,避免了精度问题,特别适合密码学、多项式运算等需要精确整数的场景。
## NTT数学原理
给定质数模数 $p$ 和原根 $g$,NTT定义在模 $p$ 的整数环上。对于长度为 $n$ 的多项式($n$ 是2的幂),NTT变换为:
$$X_k = \sum_{j=0}^{n-1} x_j \cdot g^{j k \frac{p-1}{n}} \mod p$$
逆变换(INTT)为:
$$x_j = n^{-1} \sum_{k=0}^{n-1} X_k \cdot g^{-j k \frac{p-1}{n}} \mod p$$
其中 $g$ 是模 $p$ 的原根,满足 $g^{(p-1)/n}$ 是 $n$ 次单位根。
## C++实现NTT
以下是完整的NTT实现,包含多项式乘法示例:
```cpp
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
// 常用NTT模数(998244353 = 119×2^23 + 1)
const long long MOD = 998244353;
const long long ROOT = 3; // 模数的原根
const long long ROOT_INV = 332748118; // 原根的逆元 (3^{-1} mod MOD)
// 快速幂取模
long long mod_pow(long long base, long long exp, long long mod = MOD) {
long long result = 1;
while (exp) {
if (exp & 1)
result = (result * base) % mod;
base = (base * base) % mod;
exp >>= 1;
}
return result;
}
// 求逆元
inline long long mod_inv(long long a, long long mod = MOD) {
return mod_pow(a, mod - 2, mod);
}
// 位逆序置换
void bit_reverse(vector<long long>& a) {
int n = a.size();
for (int i = 1, j = 0; i < n; i++) {
// 计算下一个反转索引
int bit = n >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j ^= bit;
if (i < j)
swap(a[i], a[j]);
}
}
// NTT核心函数(invert=false为正变换,invert=true为逆变换)
void ntt(vector<long long>& a, bool invert) {
int n = a.size();
bit_reverse(a);
for (int len = 2; len <= n; len <<= 1) {
// 计算当前长度的单位根
long long wlen = invert ? ROOT_INV : ROOT;
for (int i = len; i < (MOD-1)>>1; i <<= 1)
wlen = (wlen * wlen) % MOD;
// 蝶形变换
for (int i = 0; i < n; i += len) {
long long w = 1;
for (int j = 0; j < len / 2; j++) {
long long u = a[i+j];
long long v = (a[i+j+len/2] * w) % MOD;
a[i+j] = (u + v) % MOD;
a[i+j+len/2] = (u - v + MOD) % MOD;
w = (w * wlen) % MOD;
}
}
}
// 逆变换需要乘以n^{-1}
if (invert) {
long long inv_n = mod_inv(n, MOD);
for (long long& x : a)
x = (x * inv_n) % MOD;
}
}
// 多项式乘法(卷积)
vector<long long> poly_multiply(vector<long long> a, vector<long long> b) {
// 扩展为2的幂次
int n = 1;
while (n < a.size() + b.size())
n <<= 1;
a.resize(n);
b.resize(n);
// NTT变换
ntt(a, false);
ntt(b, false);
// 点值乘法
for (int i = 0; i < n; i++)
a[i] = (a[i] * b[i]) % MOD;
// 逆变换
ntt(a, true);
return a;
}
int main() {
// 示例1:计算 (x+2)(3x+4) = 3x^2 + 10x + 8
vector<long long> A = {2, 1}; // 多项式A: x+2
vector<long long> B = {4, 3}; // 多项式B: 3x+4
vector<long long> C = poly_multiply(A, B);
cout << "Coefficients of (x+2)(3x+4): ";
for (long long x : C) {
cout << x << " ";
}
// 输出: 8 10 3 (对应常数项、一次项、二次项系数)
// 示例2:计算 (x^2 + 2x + 3)(4x + 5)
vector<long long> D = {3, 2, 1}; // x^2 + 2x + 3
vector<long long> E = {5, 4}; // 4x + 5
vector<long long> F = poly_multiply(D, E);
cout << "\nCoefficients of (x^2+2x+3)(4x+5): ";
for (long long x : F) {
cout << x << " ";
}
// 输出: 15 22 14 4 (对应常数项、一次项、二次项、三次项系数)
// 即 4x^3 + 14x^2 + 22x + 15
return 0;
}
```
## 关键组件解析
### 1. 位逆序置换 (bit_reverse)
```cpp
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j ^= bit;
if (i < j)
swap(a[i], a[j]);
}
```
- 将数组元素按分治需要的顺序重新排列
- 时间复杂度 $O(n)$,是迭代实现NTT的关键
### 2. 蝶形变换 (Butterfly Operation)
```cpp
for (int i = 0; i < n; i += len) {
long long w = 1;
for (int j = 0; j < len / 2; j++) {
long long u = a[i+j];
long long v = (a[i+j+len/2] * w) % MOD;
a[i+j] = (u + v) % MOD;
a[i+j+len/2] = (u - v + MOD) % MOD;
w = (w * wlen) % MOD;
}
}
```
- 核心计算单元,处理长度为`len`的分组
- 使用单位根的幂次`w`进行线性组合
- 每次迭代处理长度加倍的分组
### 3. 逆变换处理
```cpp
if (invert) {
long long inv_n = mod_inv(n, MOD);
for (long long& x : a)
x = (x * inv_n) % MOD;
}
```
- 逆变换后需要乘以$n^{-1}$进行归一化
- 使用费马小定理计算模逆元
## NTT模数选择
常用NTT友好质数:
1. $998244353 = 119×2^{23} + 1$(原根=3)
2. $1004535809 = 479×2^{21} + 1$(原根=3)
3. $469762049 = 7×2^{26} + 1$(原根=3)
这些质数满足$p = k×2^m + 1$的形式,支持足够大的变换长度。
## 性能优化技巧
1. **预处理单位根**:避免重复计算$g^{(p-1)/len}$
2. **使用constexpr计算**:编译时计算常量表达式
3. **循环展开**:手动展开内层循环
4. **SIMD指令**:使用AVX2等指令集并行化计算
5. **缓存友好访问**:优化内存访问模式
## 应用场景
1. 多项式乘法(卷积)
2. 大整数乘法(将整数视为多项式)
3. 密码学(格密码、同态加密)
4. 组合计数(生成函数计算)
5. 信号处理(有限域上的滤波)
## 时间复杂度
| 操作 | 时间复杂度 |
|------|------------|
| 单次NTT/INTT | $O(n \log n)$ |
| 多项式乘法 | $O(n \log n)$ |
| 大整数乘法 | $O(n \log n)$ |
相比朴素$O(n^2)$算法,NTT在$n>100$时具有显著优势。
阅读全文
相关推荐


















