本学习笔记主要是用来总结一些多项式基础技术,包括 FFT、NTT、MTT。
0. 引入
有两个多项式 f ( x ) = ∑ k = 0 n − 1 f k x k f(x)=\sum_{k=0}^{n-1} f_kx^k f ( x ) = ∑ k = 0 n − 1 f k x k 和 g ( x ) = ∑ k = 0 n − 1 g k x k g(x)=\sum_{k=0}^{n-1} g_kx^k g ( x ) = ∑ k = 0 n − 1 g k x k (把 x = 10 x=10 x = 1 0 带进去,这就是一个高精度整数的一个形式),求 ( f × g ) ( x ) (f\times g)(x) ( f × g ) ( x ) 。
朴素的乘法就是逐位相乘,时间是 O ( n 2 ) O(n^2) O ( n 2 ) 。如果 n n n 级别达到 1 0 5 10^5 1 0 5 呢?
我们考虑用 n n n 个数值 a i a_i a i 代进 f f f 求值,得到 f i ′ f'_i f i ′ ,以此表示 f f f 。也就是,我们用 n n n 个点 ( a i , f i ′ ) (a_i,f'_i) ( a i , f i ′ ) 表示了一个 n − 1 n-1 n − 1 次多项式,显然这是唯一对应的。这种表示多项式的方式,我们称为点值表示法 。
我们用 O ( n ) O(n) O ( n ) 的时间,将 f i ′ f'_i f i ′ 和 g i ′ g'_i g i ′ 相乘,得到的点值也就是 ( f × g ) i ′ (f\times g)'_i ( f × g ) i ′ ,再将其变换回 ( f × g ) ( x ) (f\times g)(x) ( f × g ) ( x ) 即可。
总结一下,我们从多项式的系数表示法 ,变换 为点值表示法 ,进行乘法,再变换回系数表示法 。乘法是 O ( n ) O(n) O ( n ) 的,用朴素的方式求 n n n 个点的值却还是 O ( n 2 ) O(n^2) O ( n 2 ) 的时间。如果有一些特性使得变换的时间优于 O ( n 2 ) O(n^2) O ( n 2 ) ,算法就得到了优化。
下文默认 n n n 可以被表示为 2 k 2^k 2 k (如果不够补全即可)(原因是便于下文进行大量的 n 2 \frac{n}{2} 2 n 计算)。
1. FFT
FFT(Fast Fourier Transform,快速傅立叶变换)是可以在 O ( n log n ) O(n\log n) O ( n log n ) 时间内借助复数完成的变换。
DFT 概念
DFT(Discrete Fourier Transform,离散傅里叶变换)选用的 n n n 个参数为 a k = ω n k = cos ( 2 π k n ) + i sin ( 2 π k n ) a_k=\omega_n^k=\cos(\frac{2\pi k}{n})+i\sin(\frac{2\pi k}{n}) a k = ω n k = cos ( n 2 π k ) + i sin ( n 2 π k ) 。容易发现这就是复平面上单位圆的 n n n 平分点,以逆时针顺序,从 ( 1 , 0 ) (1,0) ( 1 , 0 ) 开始。
它有一些性质(可以用其几何意义解释):
ω n 0 = ω n n = 1 \omega_n^0=\omega_n^n=1 ω n 0 = ω n n = 1
ω n n 2 = − 1 \omega_n^{\frac{n}{2}}=-1 ω n 2 n = − 1
ω n k = ω m n m k \omega_n^k=\omega_{mn}^{mk} ω n k = ω m n m k
ω n a ω n b = ω n a + b \omega_n^a\omega_n^b=\omega_n^{a+b} ω n a ω n b = ω n a + b
( ω n k ) m = ω n m k (\omega_n^k)^m=\omega_n^{mk} ( ω n k ) m = ω n m k
概括地说,你可以把它按照 1 1 1 的 k n \frac{k}{n} n k 次方来运算。
从 DFT 到 FFT
首先处理 n = 1 n=1 n = 1 的情况,ω 1 0 = 1 \omega_1^0=1 ω 1 0 = 1 ,f ′ ( 1 ) = f 0 f'(1)=f_0 f ′ ( 1 ) = f 0 。
我们把下标按奇偶分开,即
f 0 ′ ( x ) = ∑ k = 0 n 2 − 1 f 2 k x k f'_{0}(x)=\sum_{k=0}^{\frac{n}{2}-1}f_{2k}x^{k}
f 0 ′ ( x ) = k = 0 ∑ 2 n − 1 f 2 k x k
f 1 ′ ( x ) = ∑ k = 0 n 2 − 1 f 2 k + 1 x k f'_{1}(x)=\sum_{k=0}^{\frac{n}{2}-1}f_{2k+1}x^{k}
f 1 ′ ( x ) = k = 0 ∑ 2 n − 1 f 2 k + 1 x k
f ′ ( x ) = f 0 ′ ( x 2 ) + x f 1 ′ ( x 2 ) f'(x)=f'_0(x^2)+xf'_1(x^2)
f ′ ( x ) = f 0 ′ ( x 2 ) + x f 1 ′ ( x 2 )
更详细地,设 k < n 2 k<\frac{n}{2} k < 2 n ,
f ′ ( ω n k ) = f 0 ′ ( ω n 2 k ) + ω n k f 1 ′ ( ω n 2 k ) = f 0 ′ ( ω n 2 k ) + ω n k f 1 ′ ( ω n 2 k ) \begin{aligned} f'(\omega_n^k)&=f'_0(\omega_n^{2k})+\omega_n^kf'_1(\omega_n^{2k}) \\
&=f'_0(\omega_{\frac{n}{2}}^{k})+\omega_n^kf'_1(\omega_{\frac{n}{2}}^{k}) \end{aligned} f ′ ( ω n k ) = f 0 ′ ( ω n 2 k ) + ω n k f 1 ′ ( ω n 2 k ) = f 0 ′ ( ω 2 n k ) + ω n k f 1 ′ ( ω 2 n k )
f ′ ( ω n k + n 2 ) = f 0 ′ ( ω n 2 k + n ) + ω n k + n 2 f 1 ′ ( ω n 2 k + n ) = f 0 ′ ( ω n 2 k ) − ω n k f 1 ′ ( ω n 2 k ) \begin{aligned} f'(\omega_n^{k+\frac{n}{2}})&=f'_0(\omega_n^{2k+n})+\omega_n^{k+\frac{n}{2}}f'_1(\omega_n^{2k+n}) \\
&=f'_0(\omega_{\frac{n}{2}}^{k})-\omega_n^kf'_1(\omega_{\frac{n}{2}}^{k}) \end{aligned} f ′ ( ω n k + 2 n ) = f 0 ′ ( ω n 2 k + n ) + ω n k + 2 n f 1 ′ ( ω n 2 k + n ) = f 0 ′ ( ω 2 n k ) − ω n k f 1 ′ ( ω 2 n k )
因此,每次计算都可以把 f ′ f' f ′ 分为 f 0 ′ f'_0 f 0 ′ 和 f 1 ′ f'_1 f 1 ′ 两个大小为 n 2 \frac{n}{2} 2 n 的问题来解决,然后 O ( n ) O(n) O ( n ) 计算 f ′ f' f ′ 。时间为 T ( n ) = O ( n ) + 2 T ( n 2 ) = O ( n log n ) T(n)=O(n)+2T(\frac{n}{2})=O(n\log n) T ( n ) = O ( n ) + 2 T ( 2 n ) = O ( n log n ) 。
从 DFT 到 IDFT
IDFT(Inverse Discrete Fourier Transform,离散傅立叶逆变换)就是要完成从点值表示法转换回系数表示法。
我们直接考虑对原数组再做一遍傅里叶变换,但是参数变为 ω n − k \omega_n^{-k} ω n − k ,看看结果如何:
f ′ ′ ( x ) = ∑ a = 0 n − 1 f ′ ( − ω n a ) x a = ∑ a = 0 n − 1 x a ∑ b = 0 n − 1 f ( ω n b ) ω n − a b = ∑ a = 0 n − 1 x a ∑ b = 0 n − 1 ω n − a b ∑ c = 0 n − 1 f c ω n b c = ∑ a = 0 n − 1 x a ∑ c = 0 n − 1 f c ∑ b = 0 n − 1 ω n ( c − a ) b \begin{aligned} f''(x)&=\sum_{a=0}^{n-1}f'(-\omega_{n}^{a})x^a \\
&=\sum_{a=0}^{n-1}x^a\sum_{b=0}^{n-1} f(\omega_n^b)\omega_{n}^{-ab} \\
&= \sum_{a=0}^{n-1}x^a\sum_{b=0}^{n-1}\omega_{n}^{-ab}\sum_{c=0}^{n-1}f_c\omega_{n}^{bc}\\
&= \sum_{a=0}^{n-1}x^a\sum_{c=0}^{n-1}f_c\sum_{b=0}^{n-1} \omega_{n}^{(c-a)b} \end{aligned} f ′ ′ ( x ) = a = 0 ∑ n − 1 f ′ ( − ω n a ) x a = a = 0 ∑ n − 1 x a b = 0 ∑ n − 1 f ( ω n b ) ω n − a b = a = 0 ∑ n − 1 x a b = 0 ∑ n − 1 ω n − a b c = 0 ∑ n − 1 f c ω n b c = a = 0 ∑ n − 1 x a c = 0 ∑ n − 1 f c b = 0 ∑ n − 1 ω n ( c − a ) b
c − a ≡ 0 ( m o d n ) c-a\equiv 0\pmod n c − a ≡ 0 ( m o d n ) 即 a = c a=c a = c 时,∑ b = 0 n − 1 ω n ( c − a ) b = ∑ b = 0 n − 1 1 = n \sum_{b=0}^{n-1} \omega_{n}^{(c-a)b}=\sum_{b=0}^{n-1} 1=n ∑ b = 0 n − 1 ω n ( c − a ) b = ∑ b = 0 n − 1 1 = n 。
否则,
∑ b = 0 n − 1 ω n ( c − a ) b = ∑ b = 0 n − 1 ( ω n c − a ) b = 1 − ω n n ( c − a ) 1 − ω n c − a = 1 − 1 1 − ω n c − a = 0 \begin{aligned} \sum_{b=0}^{n-1} \omega_{n}^{(c-a)b}&=\sum_{b=0}^{n-1} (\omega_{n}^{c-a})^b \\
&=\frac{1-\omega_{n}^{n(c-a)}}{1-\omega_{n}^{c-a}} \\
&=\frac{1-1}{1-\omega_{n}^{c-a}}=0 \end{aligned} b = 0 ∑ n − 1 ω n ( c − a ) b = b = 0 ∑ n − 1 ( ω n c − a ) b = 1 − ω n c − a 1 − ω n n ( c − a ) = 1 − ω n c − a 1 − 1 = 0
所以,
f ′ ′ ( x ) = ∑ a = 0 n − 1 x a × n f a f''(x)= \sum_{a=0}^{n-1}x^a\times nf_a
f ′ ′ ( x ) = a = 0 ∑ n − 1 x a × n f a
也就是说,在再做一遍参数为 ω n − k \omega_n^{-k} ω n − k 的傅立叶变换,再将每个数 × 1 n \times \frac{1}{n} × n 1 ,就得回原式。
具体实现
笔者没写过递归写法,但是显而易见地,递归常数相当大。
所以我们考虑倍增,只过回推的过程,让 n n n 从 1 1 1 开始不断翻倍。
顺推的时候,把偶数下标的值顺次移到较低的 n 2 \frac{n}{2} 2 n 位,把奇数下标的值顺次移到较高的 n 2 \frac{n}{2} 2 n 位。这样子的话,设 f ′ ′ f'' f ′ ′ 为更深一层递归所得的 f ′ f' f ′ ,则 f 0 , i ′ = f i ′ ′ f'_{0,i}=f''_i f 0 , i ′ = f i ′ ′ ,f 1 , i ′ = f i + n 2 ′ ′ f'_{1,i}=f''_{i+\frac{n}{2}} f 1 , i ′ = f i + 2 n ′ ′ ,。
模拟 f f f 的下推过程,以下以 n = 8 n=8 n = 8 为例:
1 2 3 4 0 1 2 3 4 5 6 7 0 2 4 6|1 3 5 7 0 4|2 6|1 5|3 7 0|4|2|6|1|5|3|7
用二进制表示:
1 2 3 4 000 001 010 011 100 101 110 111 000 010 100 110|001 011 101 111 000 100|010 100|001 101|011 111 000|100|010|100|001|101|011|111
发现目标位置的下标是当前下标在二进制意义下的翻转。
于是整个过程就是:
将每位和下标在二进制意义下翻转的位交换。
最外层从小到大枚举当前层每单位长度,下一层枚举每连续两个单位,并且依照公式递推。
如果是逆变换,就将每个数 × 1 n \times \frac{1}{n} × n 1 。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 void fft (Cplx* a, int tp) { for (int i (0 ); i != L; ++i) if (i < tgt[i]) swap (a[i], a[tgt[i]]); for (int i (1 ); i != L; i <<= 1 ) { for (int j (0 ); j != L; j += i * 2 ) { for (int k (0 ); k != i; ++k) { Cplx t ((tp == 1 ? angle[i + k] : conj(angle[i + k])) * a[j + i + k]) ; a[j + i + k] = a[j + k] - t; a[j + k] = a[j + k] + t; } } } if (tp == -1 ) for (int i (0 ); i != L; ++i) a[i] = a[i] * Cplx (1.0 / L, 0 ); }
预处理:
1 2 3 4 5 6 for (L = 1 ; L <= N + M; L <<= 1 ) ;for (int i (0 ); i != L; ++i) tgt[i] = (tgt[i >> 1 ] >> 1 ) | ((i & 1 ) * (L >> 1 ));for (int i (1 ); i != L; i <<= 1 ) for (int j (0 ); j != i; ++j) angle[i + j] = Cplx (cos (PI / i * j), sin (PI / i * j));
独属于 FFT 的乘法优化
我们可以充分利用复数,将多项式乘法优化至只做 2 2 2 遍 FFT。
考虑对 h ( x ) = f ( x ) + i g ( x ) h(x)=f(x)+ig(x) h ( x ) = f ( x ) + i g ( x ) 进行平方,得到 h 2 ( x ) = ( f 2 ( x ) − g 2 ( x ) ) + 2 i f ( x ) g ( x ) h^2(x)=(f^2(x)-g^2(x))+2if(x)g(x) h 2 ( x ) = ( f 2 ( x ) − g 2 ( x ) ) + 2 i f ( x ) g ( x ) 。虚部乘 − i 2 \frac{-i}{2} 2 − i 就是乘积。
平方时直接将点值本身平方即可,总共 1 1 1 遍 DFT,1 1 1 遍 IDFT。
修订于 20240414:也可以用下方 MTT 第二种办法中的原理,同时求出两个 FFT 结果;但是这两种方法都会带来不可忽视的精度损失。
2. NTT
NTT(Number Theory Transform,数论变换)是系数在模意义下的离散傅里叶变换。
原根
g g g 是模 p p p 的原根,当且仅当满足 g m ≡ 1 ( m o d p ) g^m\equiv 1\pmod p g m ≡ 1 ( m o d p ) 的最小正整数 m = φ ( p ) m=\varphi(p) m = φ ( p ) 。当 p p p 是素数时,有 g m m o d p ( 0 < m ≤ p − 1 ) g^m \bmod p(0<m\le p-1) g m m o d p ( 0 < m ≤ p − 1 ) 互不相等。
NTT 常用的三个模数:
998244353 = 7 × 17 × 2 23 + 1 1004535809 = 479 × 2 21 + 1 469762049 = 7 × 2 26 + 1 \begin{aligned}998244353&=7\times 17\times 2^{23}+1 \\
1004535809&=479\times 2^{21}+1 \\
469762049&=7\times 2^{26}+1\end{aligned} 9 9 8 2 4 4 3 5 3 1 0 0 4 5 3 5 8 0 9 4 6 9 7 6 2 0 4 9 = 7 × 1 7 × 2 2 3 + 1 = 4 7 9 × 2 2 1 + 1 = 7 × 2 2 6 + 1
它们都形如 p = n × 2 m + 1 p=n\times 2^m+1 p = n × 2 m + 1 (这意味着 φ ( p ) = p − 1 \varphi(p)=p-1 φ ( p ) = p − 1 是 2 m 2^m 2 m 的倍数,而且这些 m m m 足够大),都有原根 g = 3 g=3 g = 3 。
原根的性质
原根满足 FFT 处所述单位元的许多性质。
令 ω n k = g k ( φ ( p ) − 1 ) n \omega_n^k=g^{\frac{k(\varphi(p)-1)}{n}} ω n k = g n k ( φ ( p ) − 1 ) ,容易发现 n n n 形如 2 k 2^k 2 k 时,指数为整数。
ω n 0 = ω n n = g 0 = 1 \omega_n^0=\omega_n^n=g^0=1 ω n 0 = ω n n = g 0 = 1
( ω n n 2 ) 2 = 1 (\omega_n^{\frac{n}{2}})^2=1 ( ω n 2 n ) 2 = 1 ,而且由原根的定义,ω n n 2 = g p − 1 2 ≠ g p − 1 = 1 \omega_n^{\frac{n}{2}}=g^{\frac{p-1}{2}}\neq g^{p-1}=1 ω n 2 n = g 2 p − 1 = g p − 1 = 1 ,所以ω n n 2 = − 1 \omega_n^{\frac{n}{2}}=-1 ω n 2 n = − 1
其余的都直接是分数的基本性质和幂运算的规则可以解释的。
应用
事实上,直接用 g g g 的幂替换单位元就可以了。g − 1 g^{-1} g − 1 也就是求逆元。
NTT 在对系数取模的多项式乘法中很常用。一些时候,即使不需要取模,如果模数大于系数可能的最大值,同样可以用 NTT 来保证精度(例如高精度乘法)。
代码不放了,差不多。
3. MTT
MTT 要实现的是任意模数多项式乘法。n n n 为 1 0 5 10^5 1 0 5 级别,p p p 为 1 0 9 10^9 1 0 9 级别。
NTT 对模数有限制,不能直接使用。而系数最大值为 n p 2 np^2 n p 2 ,可达 1 0 23 10^{23} 1 0 2 3 。过程中使用浮点数而不能取模的 FFT 精度又不高。都需要进行改造。
NTT 结合 exCRT
我们做 3 3 3 轮,分别求出 ( f × g ) ( x ) (f\times g)(x) ( f × g ) ( x ) 对 3 3 3 个不同模数取模的结果,然后用扩展中国剩余定理的思路合并。
具体而言,我们知道 h k ≡ ( f × g ) ( x ) ( m o d p k ) h_k\equiv (f\times g)(x) \pmod {p_k} h k ≡ ( f × g ) ( x ) ( m o d p k ) ,求h = ( f × g ) ( x ) h=(f\times g)(x) h = ( f × g ) ( x ) 。
那么
h 1 + x p 1 ≡ h 2 ( m o d p 2 ) h_1+xp_1\equiv h_2\pmod {p_2}
h 1 + x p 1 ≡ h 2 ( m o d p 2 )
x ≡ ( h 2 − h 1 ) × p 1 − 1 ( m o d p 2 ) x\equiv (h_2-h_1)\times p_1^{-1}\pmod {p_2}
x ≡ ( h 2 − h 1 ) × p 1 − 1 ( m o d p 2 )
求得 x x x 以后,令 h ′ = h 1 + x p 1 h'=h_1+xp_1 h ′ = h 1 + x p 1 ,则 h ≡ h ′ ( m o d p 1 p 2 ) h\equiv h'\pmod {p_1p_2} h ≡ h ′ ( m o d p 1 p 2 ) 。
h ′ + y p 1 p 2 = h 3 ( m o d p 3 ) h'+yp_1p_2=h_3\pmod {p_3}
h ′ + y p 1 p 2 = h 3 ( m o d p 3 )
y = ( h 3 − h ′ ) × ( p 1 p 2 ) − 1 ( m o d p 3 ) y=(h_3-h')\times (p_1p_2)^{-1}\pmod {p_3}
y = ( h 3 − h ′ ) × ( p 1 p 2 ) − 1 ( m o d p 3 )
求得 y y y 以后,令 h ′ ′ = h ′ + y p 1 p 2 h''=h'+yp_1p_2 h ′ ′ = h ′ + y p 1 p 2 ,则 h ≡ h ′ ′ ( m o d p 1 p 2 p 3 ) h\equiv h''\pmod {p_1p_2p_3} h ≡ h ′ ′ ( m o d p 1 p 2 p 3 ) 。
我们可以选取 NTT 章节所提出的三个模数作为 p 1 p_1 p 1 、p 2 p_2 p 2 、p 3 p_3 p 3 ,那么 p 1 p 2 p 3 > 4 × 1 0 26 p_1p_2p_3>4\times 10^{26} p 1 p 2 p 3 > 4 × 1 0 2 6 ,远超系数最大值,所以 h = h ′ ′ h=h'' h = h ′ ′ 。也正是因此,我们不能直接用 CRT,可能会爆精度。
在计算 h ′ h' h ′ 、h ′ ′ h'' h ′ ′ 的过程中对模数取模即可。
总共要跑 9 次 NTT(正向 6 次,逆向 3 次),常数较大。但是有一种写法,用结构体分别存下三个余数,同时跑 NTT,就相当于是 3 遍 NTT 但是乘上 3 的常数,速度得到了提高。
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 const int TP = 3 ;const int DVS[] = {998244353 , 1004535809 , 469762049 };const int INV0 = 669690699 , INV01 = 354521948 ;const int G = 3 ;int N, M, P, L, tgt[MXN];struct Num { int n[TP]; } A[MXN], B[MXN];Num getnum (int a = 0 , int b = 0 , int c = 0 ) { Num n; n.n[0 ] = a, n.n[1 ] = b, n.n[2 ] = c; return n; } Num operator +(Num a, Num b) { Num n; for (int i (0 ); i != TP; ++i) n.n[i] = (a.n[i] + b.n[i]) % DVS[i]; return n; } Num operator -(Num a, Num b) { Num n; for (int i (0 ); i != TP; ++i) n.n[i] = (a.n[i] - b.n[i] + DVS[i]) % DVS[i]; return n; } Num operator *(Num a, Num b) { Num n; for (int i (0 ); i != TP; ++i) n.n[i] = a.n[i] * 1LL * b.n[i] % DVS[i]; return n; }int pw (int b, int p, int m) { int s (1 ) ; while (p) { if (p & 1 ) s = s * 1LL * b % m; b = b * 1LL * b % m; p >>= 1 ; } return s; }void ntt (Num* a, bool tp) { for (int i (0 ); i != L; ++i) if (i < tgt[i]) swap (a[i], a[tgt[i]]); for (int i (1 ); i != L; i <<= 1 ) { Num bas (getnum(pw(G, (DVS[0 ] - 1 ) / (i * 2 ), DVS[0 ]), pw(G, (DVS[1 ] - 1 ) / (i * 2 ), DVS[1 ]), pw(G, (DVS[2 ] - 1 ) / (i * 2 ), DVS[2 ]))) ; if (tp) for (int j (0 ); j != TP; ++j) bas.n[j] = pw (bas.n[j], DVS[j] - 2 , DVS[j]); for (int j (0 ); j != L; j += i * 2 ) { Num e (getnum(1 , 1 , 1 )) ; for (int k (0 ); k != i; ++k, e = e * bas) { Num w (e * a[j + i + k]) ; a[j + i + k] = a[j + k] - w; a[j + k] = a[j + k] + w; } } } if (tp) { Num bas (getnum(pw(L, DVS[0 ] - 2 , DVS[0 ]), pw(L, DVS[1 ] - 2 , DVS[1 ]), pw(L, DVS[2 ] - 2 , DVS[2 ]))) ; for (int i (0 ); i != L; ++i) a[i] = a[i] * bas; } }int main () { cin >> N >> M >> P; for (int i (0 ), x (0 ); i <= N; ++i) { cin >> x; for (int j (0 ); j != TP; ++j) A[i].n[j] = x % DVS[j]; } for (int i (0 ), x (0 ); i <= M; ++i) { cin >> x; for (int j (0 ); j != TP; ++j) B[i].n[j] = x % DVS[j]; } for (L = 1 ; L <= N + M; L <<= 1 ) ; for (int i (1 ); i != L; ++i) tgt[i] = (tgt[i >> 1 ] >> 1 ) + (i & 1 ) * (L >> 1 ); ntt (A, 0 ), ntt (B, 0 ); for (int i (0 ); i != L; ++i) A[i] = A[i] * B[i]; ntt (A, 1 ); for (int i (0 ); i <= N + M; ++i) { int K1 ((A[i].n[1 ] - A[i].n[0 ] + DVS[1 ]) % DVS[1 ] * 1LL * INV0 % DVS[1 ]) ; long long x (A[i].n[0 ] + K1 * 1LL * DVS[0 ]) ; int K2 (((A[i].n[2 ] - x) * 1LL % DVS[2 ] + DVS[2 ]) % DVS[2 ] * INV01 % DVS[2 ]) ; cout << (x % P + K2 * 1LL * DVS[0 ] % P * DVS[1 ] % P) % P << ' ' ; } cout << endl; return 0 ; }
拆系数 FFT
我们把 f k f_k f k 表示为 b f 0 , k + f 1 , k ( f 1 , k < b ) bf_{0,k}+f_{1,k}(f_{1,k}<b) b f 0 , k + f 1 , k ( f 1 , k < b ) 的形式。则 f k g k = b 2 f 0 , k g 0 , k + b ( f 0 , k g 1 , k + f 1 , k g 0 , k ) + f 1 , k g 1 , k f_kg_k=b^2f_{0,k}g_{0,k}+b(f_{0,k}g_{1,k}+f_{1,k}g_{0,k})+f_{1,k}g_{1,k} f k g k = b 2 f 0 , k g 0 , k + b ( f 0 , k g 1 , k + f 1 , k g 0 , k ) + f 1 , k g 1 , k 。分别求出 f 0 f_0 f 0 、f 1 f_1 f 1 和 g 0 g_0 g 0 、g 1 g_1 g 1 两两之间的乘积即可。FFT 中系数最大值为 n b 2 nb^2 n b 2 、n p np n p 、n × p 2 b 2 n\times \frac{p^2}{b^2} n × b 2 p 2 的最大值。b = p b=\sqrt{p} b = p 的时候值域在 n p np n p 的级别,大约是 1 0 14 10^{14} 1 0 1 4 ,可以承受。
要跑 8 次 FFT(正向 4 次,逆向 4 次)。考虑优化。
还是充分发挥复数的优势,构造 A ( x ) = f 0 ( x ) g 0 ( x ) + i f 0 ( x ) g 1 ( x ) A(x)=f_0(x)g_0(x)+if_0(x)g_1(x) A ( x ) = f 0 ( x ) g 0 ( x ) + i f 0 ( x ) g 1 ( x ) 、B ( x ) = f 1 ( x ) g 0 ( x ) + i f 1 ( x ) g 1 ( x ) B(x)=f_1(x)g_0(x)+if_1(x)g_1(x) B ( x ) = f 1 ( x ) g 0 ( x ) + i f 1 ( x ) g 1 ( x ) ,那么逆向的变换就只要做 2 2 2 次。
考虑正向的变换还有没有优化的余地。我们要用更少的次数,同时求出 f 0 f_0 f 0 、 f 1 f_1 f 1 、 g 0 g_0 g 0 、 g 1 g_1 g 1 的傅里叶变换。
考虑这样的两个共轭的多项式 p ( x ) = s ( x ) + i t ( x ) p(x)=s(x)+it(x) p ( x ) = s ( x ) + i t ( x ) ,q ( x ) = s − i t ( x ) q(x)=s-it(x) q ( x ) = s − i t ( x ) 。
对其进行傅里叶变换,
p ′ ( ω n k ) = ∑ k = 0 n − 1 ( cos ( 2 π k n ) + i sin ( 2 π k n ) ) ( s k + i t k ) p'(\omega_n^k)=\sum_{k=0}^{n-1}(\cos(\frac{2\pi k}{n})+i\sin(\frac{2\pi k}{n}))(s_k+it_k)
p ′ ( ω n k ) = k = 0 ∑ n − 1 ( cos ( n 2 π k ) + i sin ( n 2 π k ) ) ( s k + i t k )
令 X = 2 π k n X=\frac{2\pi k}{n} X = n 2 π k ,简化式子,
p ′ ( ω n k ) = ∑ k = 0 n − 1 ( cos ( X ) + i sin ( X ) ) ( s k + i t k ) = ∑ k = 0 n − 1 ( cos ( X ) + i sin ( X ) ) ( s k + i t k ) = ∑ k = 0 n − 1 ( cos ( X ) s k − sin ( X ) t k ) + i ( sin ( X ) s k + cos ( X ) t k ) \begin{aligned} p'(\omega_n^k)&=\sum_{k=0}^{n-1}(\cos(X)+i\sin(X))(s_k+it_k) \\
&=\sum_{k=0}^{n-1}(\cos(X)+i\sin(X))(s_k+it_k) \\
&=\sum_{k=0}^{n-1}(\cos(X)s_k-\sin(X)t_k)+i(\sin(X)s_k+\cos(X)t_k) \end{aligned} p ′ ( ω n k ) = k = 0 ∑ n − 1 ( cos ( X ) + i sin ( X ) ) ( s k + i t k ) = k = 0 ∑ n − 1 ( cos ( X ) + i sin ( X ) ) ( s k + i t k ) = k = 0 ∑ n − 1 ( cos ( X ) s k − sin ( X ) t k ) + i ( sin ( X ) s k + cos ( X ) t k )
q ′ ( ω n k ) = ∑ k = 0 n − 1 ( cos ( X ) + i sin ( X ) ) ( s k − i t k ) = ∑ k = 0 n − 1 ( cos ( − X ) − i sin ( − X ) ) ( s k − i t k ) = ∑ k = 0 n − 1 ( cos ( − X ) s k − sin ( − X ) t k ) − i ( sin ( − X ) s k + cos ( − X ) t k ) = p ′ ( ω n ( n − k ) m o d n ) ‾ \begin{aligned} q'(\omega_n^k)&=\sum_{k=0}^{n-1}(\cos(X)+i\sin(X))(s_k-it_k) \\
&=\sum_{k=0}^{n-1}(\cos(-X)-i\sin(-X))(s_k-it_k) \\
&=\sum_{k=0}^{n-1}(\cos(-X)s_k-\sin(-X)t_k)-i(\sin(-X)s_k+\cos(-X)t_k) \\
&=\overline{p'(\omega_n^{(n-k)\bmod n})} \end{aligned} q ′ ( ω n k ) = k = 0 ∑ n − 1 ( cos ( X ) + i sin ( X ) ) ( s k − i t k ) = k = 0 ∑ n − 1 ( cos ( − X ) − i sin ( − X ) ) ( s k − i t k ) = k = 0 ∑ n − 1 ( cos ( − X ) s k − sin ( − X ) t k ) − i ( sin ( − X ) s k + cos ( − X ) t k ) = p ′ ( ω n ( n − k ) m o d n )
也就是说,我们可以用 p ′ p' p ′ 求出 q ′ q' q ′ 。那么我们就可以把两个数列合起来 DFT。
令 p ( x ) = f 0 ( x ) + i f 1 ( x ) p(x)=f_0(x)+if_1(x) p ( x ) = f 0 ( x ) + i f 1 ( x ) ,q ( x ) = f 0 ( x ) − i f 1 ( x ) q(x)=f_0(x)-if_1(x) q ( x ) = f 0 ( x ) − i f 1 ( x ) 。
f 0 , k ′ = p k ′ + q k ′ 2 = p k ′ + p ( n − k ) m o d n ′ ‾ 2 \begin{aligned}f'_{0,k}&=\frac{p'_k+q'_k}{2} \\
&=\frac{p'_k+\overline{p'_{(n-k)\bmod n}}}{2}\end{aligned} f 0 , k ′ = 2 p k ′ + q k ′ = 2 p k ′ + p ( n − k ) m o d n ′
f 1 , k ′ = p k ′ − q k ′ 2 i = − i × p k ′ − p ( n − k ) m o d n ′ ‾ 2 \begin{aligned}f'_{1,k}&=\frac{p'_k-q'_k}{2i} \\
&=-i\times \frac{p'_k-\overline{p'_{(n-k)\bmod n}}}{2}\end{aligned} f 1 , k ′ = 2 i p k ′ − q k ′ = − i × 2 p k ′ − p ( n − k ) m o d n ′
所以我们只需要做一次 FFT 求出 p ′ ( x ) p'(x) p ′ ( x ) 就可以求出 f 0 ′ ( x ) f'_0(x) f 0 ′ ( x ) 和 f 1 ′ ( x ) f'_1(x) f 1 ′ ( x ) 。 g 0 ′ ( x ) g'_0(x) g 0 ′ ( x ) 和 g 1 ′ ( x ) g'_1(x) g 1 ′ ( x ) 同理。然后再通过这构造 A ( x ) A(x) A ( x ) 和 B ( x ) B(x) B ( x ) 即可。总共 2 次 DFT,2 次 IDFT。
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 const int RDX = 31623 ;int N, M, L, P, tgt[MXN];void fft (Cplx* a, int tp) { for (int i (0 ); i != L; ++i) if (i < tgt[i]) swap (a[i], a[tgt[i]]); for (int i (1 ); i != L; i <<= 1 ) { for (int j (0 ); j != L; j += i * 2 ) { for (int k (0 ); k != i; ++k) { Cplx t ((tp == 1 ? angle[i + k] : conj(angle[i + k])) * a[j + i + k]) ; a[j + i + k] = a[j + k] - t; a[j + k] = a[j + k] + t; } } } if (tp == -1 ) for (int i (0 ); i != L; ++i) a[i] = a[i] * Cplx (1.0 / L, 0 ); }int main () { cin >> N >> M >> P; for (int i (0 ), x (0 ); i <= N; ++i) { cin >> x; A[i].r = x / RDX, A[i].i = x % RDX; } for (int i (0 ), x (0 ); i <= M; ++i) { cin >> x; B[i].r = x / RDX, B[i].i = x % RDX; } for (L = 1 ; L <= N + M; L <<= 1 ) ; for (int i (0 ); i != L; ++i) tgt[i] = (tgt[i >> 1 ] >> 1 ) | ((i & 1 ) * (L >> 1 )); for (int i (1 ); i != L; i <<= 1 ) for (int j (0 ); j != i; ++j) angle[i + j] = Cplx (cos (PI / i * j), sin (PI / i * j)); fft (A, 1 ), fft (B, 1 ); for (int i (0 ); i != L; ++i) { Cplx a0 ((conj(A[(L - i) % L]) + A[i]) * Cplx(0.5 , 0 )) , a1 ((conj(A[(L - i) % L]) - A[i]) * Cplx(0 , 0.5 )) , b0 ((conj(B[(L - i) % L]) + B[i]) * Cplx(0.5 , 0 )) , b1 ((conj(B[(L - i) % L]) - B[i]) * Cplx(0 , 0.5 )) ; C[i] = a0 * b0 + Cplx (0 , 1 ) * a0 * b1; D[i] = a1 * b0 + Cplx (0 , 1 ) * a1 * b1; } fft (C, -1 ), fft (D, -1 ); for (int i (0 ); i <= N + M; ++i) { long long a0b0 (C[i].r + 0.5 ) , a0b1 (C[i].i + 0.5 ) , a1b0 (D[i].r + 0.5 ) , a1b1 (D[i].i + 0.5 ) ; cout << (a0b0 % P * RDX % P * RDX % P + ((a0b1 + a1b0) % P * 1LL * RDX % P + a1b1 % P) % P) % P << ' ' ; } cout << endl; return 0 ; }