前言
又到了愉快的数学时间了!
由于蒟蒻作者数学特别烂,贴个友链然后直接上代码吧
本博客由某谷搬运过来,目的只是存板子。
(update 2021.2.2) 更新了部分代码,已经基本学懂,懒得更新讲解了。u1s1,到了高中后有了一定数学基础,就是比初中傻白甜的时候学得快。
(update 2021.2.3) 学懂了蝴蝶变换之后又更新了一波板子,不得不说,迭代版本真的快得离谱。
(update 2021.7.27) 折叠了代码,增加文章可读性。
FFT
友链
练习
代码
板题代码
$\text{Mine(递归)}$
//12252024832524
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
// Shorthand for a one-type-parameter template header, used by the I/O and min/max helpers.
#define TT template<typename T>
using namespace std;
using LL = long long;
// Capacity of the coefficient arrays: 2^21 slots plus a little slack.
const int MAXN = (1 << 21) | 5;
// Full-precision pi.
const double PI = acos(-1.0);
// Degrees of the two input polynomials (main reads lena + 1 and lenb + 1 coefficients).
int lena, lenb;
// Reads one integer from stdin.
// Every non-digit before the number is skipped; a '-' among them makes the result negative.
// Returns 0 if the input ends before any digit is seen.
LL Read()
{
    LL x = 0, f = 1;
    // Keep getchar()'s result as int so EOF is distinguishable from a real character;
    // stored in a char, EOF never matches and the skip loop below would spin forever.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c ^ 48);   // c ^ 48 == c - '0' for digit characters
        c = getchar();
    }
    return x * f;
}
// Writes the decimal digits of x (meant for x >= 0; Put handles the sign).
// Lower digits are stacked in a small buffer and flushed after the leading digit,
// so the output goes from the most significant digit down.
TT void Put1(T x)
{
    char pending[40];   // room for any integer type up to 128 bits
    int cnt = 0;
    while (x > 9)
    {
        pending[cnt++] = x % 10 ^ 48;   // digit d -> '0' + d (48 == '0')
        x /= 10;
    }
    putchar(x % 10 ^ 48);
    while (cnt > 0) putchar(pending[--cnt]);
}
// Prints integer x in decimal, then character c if c is non-negative (c = -1: no separator).
// c is an int rather than a char: with a plain char the default -1 turns into 255 on
// platforms where char is unsigned, and a stray byte would be printed. Callers passing a
// char literal still work through the usual promotion.
TT void Put(T x, int c = -1)
{
    if (x < 0)
    {
        putchar('-');
        // Peel off the last digit before negating: -(x / 10) and -(x % 10) never overflow,
        // whereas -x does for the most negative value of T.
        const T head = -(x / 10);
        if (head > 0) Put1(head);
        putchar(static_cast<int>('0' - x % 10));
    }
    else
        Put1(x);
    if (c >= 0) putchar(c);
}
// Generic helpers: the larger / smaller of two values, and the absolute value.
TT T Max(T x, T y) { if (x > y) return x; return y; }
TT T Min(T x, T y) { if (x < y) return x; return y; }
TT T Abs(T x) { if (x < 0) return -x; return x; }
// Minimal complex number: x is the real part, y the imaginary part.
struct cp
{
    double x, y;
    // Leaves x and y uninitialized; the global arrays below are zero-initialized anyway.
    cp() {}
    cp(double x1, double y1) : x(x1), y(y1) {}
    cp operator + (const cp &A) const { return cp(x + A.x, y + A.y); }
    cp operator - (const cp &A) const { return cp(x - A.x, y - A.y); }
    // (x + y i)(A.x + A.y i) = (x A.x - y A.y) + (x A.y + y A.x) i
    cp operator * (const cp &A) const
    {
        const double re = x * A.x - y * A.y;
        const double im = x * A.y + y * A.x;
        return cp(re, im);
    }
};
// Coefficient arrays of the two input polynomials (later overwritten by their transforms).
cp a[MAXN], b[MAXN];
void FFT(int len,cp * a,int f)
{
if(len == 1) return;
cp a1[len >> 1],a2[len >> 1];
for(int i = 0;i < len;i += 2) a1[i >> 1] = a[i],a2[i >> 1] = a[i+1];
FFT(len>>1,a1,f);
FFT(len>>1,a2,f);
cp w = cp(cos(2*PI/len),f*sin(2*PI/len)),k = cp(1,0);
len >>= 1;
for(int i = 0;i < len;++ i,k = k * w)
{
a[i] = a1[i] + k * a2[i];
a[i+len] = a1[i] - k * a2[i];
}
}
// Reads two polynomials (degrees, then coefficients), multiplies them with FFT and prints the
// lena + lenb + 1 coefficients of the product separated by spaces.
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    lena = Read(); lenb = Read();
    for (int i = 0; i <= lena; ++i) a[i].x = Read();
    for (int i = 0; i <= lenb; ++i) b[i].x = Read();
    // Transform length: smallest power of two strictly greater than the product degree.
    int len = 1;
    while (len <= lena + lenb) len <<= 1;
    FFT(len, a, 1);
    FFT(len, b, 1);
    // Pointwise product of the len transformed values (indices 0 .. len-1 only).
    for (int i = 0; i < len; ++i) a[i] = a[i] * b[i];
    FFT(len, a, -1);
    // FFT() leaves the inverse unscaled: divide by len, then round to nearest.
    // floor(v + 0.5) is used because (int)(v + 0.5) truncates toward zero and is off by one
    // for negative coefficients (Read accepts negative input).
    for (int i = 0; i <= lena + lenb; ++i) Put(static_cast<int>(floor(a[i].x / len + 0.5)), ' ');
    return 0;
}
$\text{Mine(迭代)}$
//12252024832524
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
// Shorthand for a one-type-parameter template header.
#define TT template<typename T>
using namespace std;
using LL = long long;
// Capacity of the coefficient / bit-reversal arrays: 2^21 plus a little slack.
const int MAXN = (1 << 21) | 5;
const double PI = acos(-1.0);
// lena, lenb: input degrees; len: transform length (a power of two);
// l: log2(len) - 1, the shift that moves the lowest index bit to the top in rev[].
int lena, lenb, len = 1, l = -1;
// rev[i]: i with its log2(len) low bits reversed (filled in main).
int rev[MAXN];
// Reads one integer from stdin.
// Every non-digit before the number is skipped; a '-' among them makes the result negative.
// Returns 0 if the input ends before any digit is seen.
LL Read()
{
    LL x = 0, f = 1;
    // Keep getchar()'s result as int so EOF is distinguishable from a real character;
    // stored in a char, EOF never matches and the skip loop below would spin forever.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c ^ 48);   // c ^ 48 == c - '0' for digit characters
        c = getchar();
    }
    return x * f;
}
// Writes the decimal digits of x (meant for x >= 0; Put handles the sign).
// Lower digits are stacked in a small buffer and flushed after the leading digit,
// so the output goes from the most significant digit down.
TT void Put1(T x)
{
    char pending[40];   // room for any integer type up to 128 bits
    int cnt = 0;
    while (x > 9)
    {
        pending[cnt++] = x % 10 ^ 48;   // digit d -> '0' + d (48 == '0')
        x /= 10;
    }
    putchar(x % 10 ^ 48);
    while (cnt > 0) putchar(pending[--cnt]);
}
// Prints integer x in decimal, then character c if c is non-negative (c = -1: no separator).
// c is an int rather than a char: with a plain char the default -1 turns into 255 on
// platforms where char is unsigned, and a stray byte would be printed. Callers passing a
// char literal still work through the usual promotion.
TT void Put(T x, int c = -1)
{
    if (x < 0)
    {
        putchar('-');
        // Peel off the last digit before negating: -(x / 10) and -(x % 10) never overflow,
        // whereas -x does for the most negative value of T.
        const T head = -(x / 10);
        if (head > 0) Put1(head);
        putchar(static_cast<int>('0' - x % 10));
    }
    else
        Put1(x);
    if (c >= 0) putchar(c);
}
// Generic helpers: the larger / smaller of two values, and the absolute value.
TT T Max(T x, T y) { if (x > y) return x; return y; }
TT T Min(T x, T y) { if (x < y) return x; return y; }
TT T Abs(T x) { if (x < 0) return -x; return x; }
// Minimal complex number: x is the real part, y the imaginary part.
struct cp
{
    double x, y;
    // Leaves x and y uninitialized; the global arrays below are zero-initialized anyway.
    cp() {}
    cp(double x1, double y1) : x(x1), y(y1) {}
    cp operator + (const cp &A) const { return cp(x + A.x, y + A.y); }
    cp operator - (const cp &A) const { return cp(x - A.x, y - A.y); }
    // (x + y i)(A.x + A.y i) = (x A.x - y A.y) + (x A.y + y A.x) i
    cp operator * (const cp &A) const
    {
        const double re = x * A.x - y * A.y;
        const double im = x * A.y + y * A.x;
        return cp(re, im);
    }
};
// Coefficient arrays of the two input polynomials (later overwritten by their transforms).
cp a[MAXN], b[MAXN];
// Iterative radix-2 FFT over a[0 .. len-1], using the globals len and rev[] set up in main.
// opt = 1: forward transform; opt = -1: inverse transform, including the division of the
// real parts by len.
void FFT(cp *a, int opt)
{
    // Bit-reversal permutation, so the butterflies below can run bottom-up in place.
    for (int idx = 0; idx < len; ++idx)
    {
        const int partner = rev[idx];
        if (idx < partner) swap(a[idx], a[partner]);
    }
    // half: size of the two blocks being merged into one block of size 2 * half.
    for (int half = 1; half < len; half <<= 1)
    {
        const cp step = cp(cos(PI / half), opt * sin(PI / half));   // (2*half)-th root of unity
        const int block = half << 1;
        for (int start = 0; start < len; start += block)
        {
            cp tw = cp(1, 0);
            for (int off = 0; off < half; ++off)
            {
                cp &lo = a[start + off];
                cp &hi = a[start + off + half];
                const cp X = lo, Y = tw * hi;
                lo = X + Y;
                hi = X - Y;
                tw = tw * step;
            }
        }
    }
    if (opt != -1) return;
    for (int idx = 0; idx < len; ++idx) a[idx].x /= len;
}
// Reads two polynomials (degrees, then coefficients), multiplies them with the iterative FFT
// and prints the lena + lenb + 1 product coefficients separated by spaces.
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    lena = Read(); lenb = Read();
    for (int i = 0; i <= lena; ++i) a[i].x = Read();
    for (int i = 0; i <= lenb; ++i) b[i].x = Read();
    // Smallest power of two above the product degree; l ends up as log2(len) - 1.
    while (len <= lena + lenb) len <<= 1, l++;
    // Bit-reversal table, each entry built from the already-computed entry for i >> 1.
    for (int i = 0; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l);
    FFT(a, 1);
    FFT(b, 1);
    for (int i = 0; i < len; ++i) a[i] = a[i] * b[i];
    FFT(a, -1);   // already divides the real parts by len
    // Round to nearest with floor(v + 0.5): (int)(v + 0.5) truncates toward zero and would be
    // off by one for negative coefficients (Read accepts negative input).
    for (int i = 0; i <= lena + lenb; ++i) Put(static_cast<int>(floor(a[i].x + 0.5)), ' ');
    return 0;
}
$\text{French's code(迭代)}$
#include<iostream>
#include<algorithm>
#include<cmath>
using namespace std;
// Capacity of every global array.
#define maxn 10000005
// Let pair<double,double> read like a complex number: .x = real part, .y = imaginary part.
// These two macros are #undef'd later in the file.
#define x first
#define y second
const double pi = acos(-1.0);
// n, m: degrees of the two inputs; limit: transform length; l: log2(limit).
int n, m;
int limit = 1;
int l;
// a, b: coefficients / transformed values of the two inputs; r: bit-reversal table.
pair<double, double> a[maxn], b[maxn];
int r[maxn];
// Complex arithmetic on pair<double,double>: .first is the real part, .second the imaginary part.
pair<double,double> operator + (pair<double,double> a, pair<double,double> b)
{
    return make_pair(a.first + b.first, a.second + b.second);
}
pair<double,double> operator - (pair<double,double> a, pair<double,double> b)
{
    return make_pair(a.first - b.first, a.second - b.second);
}
// (p + q i)(s + t i) = (p s - q t) + (p t + q s) i
pair<double,double> operator * (pair<double,double> a, pair<double,double> b)
{
    const double re = a.first * b.first - a.second * b.second;
    const double im = a.first * b.second + a.second * b.first;
    return make_pair(re, im);
}
#undef x
#undef y
// Iterative radix-2 FFT on a[0 .. limit-1]; t = 1 forward, t = -1 inverse (unscaled, so the
// caller divides by limit). Uses the globals limit, r[] (bit-reversal table) and pi.
void fft(pair<double,double> *a, int t)
{
    typedef pair<double,double> Cplx;
    for (int i = 0; i < limit; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);
    for (int mid = 1; mid < limit; mid <<= 1)
    {
        const Cplx wn = make_pair(cos(pi / mid), t * sin(pi / mid));   // (2*mid)-th root
        const int span = mid << 1;
        for (int j = 0; j < limit; j += span)
        {
            Cplx w(1, 0);
            for (int k = 0; k < mid; k++)
            {
                const Cplx u = a[j + k];
                const Cplx v = w * a[j + k + mid];
                a[j + k] = u + v;
                a[j + k + mid] = u - v;
                w = w * wn;
            }
        }
    }
}
// Multiplies the two polynomials (forward transforms, pointwise product, inverse transform)
// and prints the n + m + 1 product coefficients separated by spaces.
void work()
{
    fft(a, 1);
    fft(b, 1);
    // Only indices 0 .. limit-1 hold transformed values.
    for (int i = 0; i < limit; i++)
        a[i] = a[i] * b[i];
    fft(a, -1);
    // fft() does not scale the inverse: divide by limit, then round to nearest.
    // floor(v + 0.5) is used because (int)(v + 0.5) truncates toward zero and is off by one
    // for negative results (the inputs are read as arbitrary doubles).
    for (int i = 0; i <= n + m; i++)
        cout << static_cast<int>(floor(a[i].first / limit + 0.5)) << ' ';
}
// Sets limit to the smallest power of two greater than n + m (with l = log2(limit)) and fills
// the bit-reversal table r[] for l-bit indices.
void perpare()
{
    for (; limit <= n + m; ++l) limit <<= 1;
    const int top = l - 1;   // shift that moves bit 0 to the highest index bit
    for (int i = 0; i < limit; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << top);
}
int main()
{
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=0;i<=n;i++)
cin>>a[i].first;
for(int i=0;i<=m;i++)
cin>>b[i].first;
perpare();
work();
return 0;
}
NTT
正题
难得一见的讲解:
由于 $FFT$ 使用浮点数运算,会有精度问题,而且不能取模,所以 $NTT$ 就诞生了。
我们只需将 $FFT$ 中的单位根 $\omega_n$ 换成 $g^{(p-1)/n}$ 就好了,其中 $p$ 是 $NTT$ 的模数,$g$ 是 $p$ 的原根(下面的代码中 $p=998244353$,$g=3$)。
如果我们不需要取模,只需要找一个足够大的模数(大于结果中所有系数)就好了,这样取模就相当于没有取模。
当然,逆变换时要用 $g^{-1}$ 代替 $g$,最后除以 $len$ 的时候改为乘 $len$ 在模 $p$ 意义下的逆元就好了。
练习
其实你可以用 $NTT$ 过所有 $FFT$ 的题。(好像并不是)
代码
板题代码
$\text{Mine(递归)}$
//12252024832524
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
// Shorthand for a one-type-parameter template header.
#define TT template<typename T>
using namespace std;
using LL = long long;
// Capacity of the coefficient arrays: 2^21 plus a little slack.
const int MAXN = (1 << 21) | 5;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int MOD = 998244353;
const int PHI = MOD - 1;      // phi(MOD) for the prime MOD, i.e. 998244352
const int GINV = 332748118;   // inverse of G modulo MOD: 3 * 332748118 = MOD + 1
const int G = 3;              // primitive root of MOD
// Degrees of the two input polynomials, and their coefficient arrays.
int lena, lenb;
int a[MAXN], b[MAXN];
// Reads one integer from stdin.
// Every non-digit before the number is skipped; a '-' among them makes the result negative.
// Returns 0 if the input ends before any digit is seen.
LL Read()
{
    LL x = 0, f = 1;
    // Keep getchar()'s result as int so EOF is distinguishable from a real character;
    // stored in a char, EOF never matches and the skip loop below would spin forever.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c ^ 48);   // c ^ 48 == c - '0' for digit characters
        c = getchar();
    }
    return x * f;
}
// Writes the decimal digits of x (meant for x >= 0; Put handles the sign).
// Lower digits are stacked in a small buffer and flushed after the leading digit,
// so the output goes from the most significant digit down.
TT void Put1(T x)
{
    char pending[40];   // room for any integer type up to 128 bits
    int cnt = 0;
    while (x > 9)
    {
        pending[cnt++] = x % 10 ^ 48;   // digit d -> '0' + d (48 == '0')
        x /= 10;
    }
    putchar(x % 10 ^ 48);
    while (cnt > 0) putchar(pending[--cnt]);
}
// Prints integer x in decimal, then character c if c is non-negative (c = -1: no separator).
// c is an int rather than a char: with a plain char the default -1 turns into 255 on
// platforms where char is unsigned, and a stray byte would be printed. Callers passing a
// char literal still work through the usual promotion.
TT void Put(T x, int c = -1)
{
    if (x < 0)
    {
        putchar('-');
        // Peel off the last digit before negating: -(x / 10) and -(x % 10) never overflow,
        // whereas -x does for the most negative value of T.
        const T head = -(x / 10);
        if (head > 0) Put1(head);
        putchar(static_cast<int>('0' - x % 10));
    }
    else
        Put1(x);
    if (c >= 0) putchar(c);
}
// Generic helpers: the larger / smaller of two values, and the absolute value.
TT T Max(T x, T y) { if (x > y) return x; return y; }
TT T Min(T x, T y) { if (x < y) return x; return y; }
TT T Abs(T x) { if (x < 0) return -x; return x; }
// Returns x^y mod MOD by square-and-multiply, consuming the bits of y from the lowest one.
int qpow(int x, int y)
{
    long long result = 1, base = x;
    for (; y != 0; y >>= 1)
    {
        if (y & 1) result = result * base % MOD;
        base = base * base % MOD;
    }
    return (int)result;
}
void NTT(int len,int *a,int f)
{
if(len == 1) return;
int a1[len >> 1],a2[len >> 1];
for(int i = 0;i < len;i += 2) a1[i >> 1] = a[i],a2[i >> 1] = a[i+1];
NTT(len>>1,a1,f);
NTT(len>>1,a2,f);
int w = qpow(f == 1 ? G : GINV,PHI/len),k = 1;
len >>= 1;
for(int i = 0;i < len;++ i,k = 1ll * k * w % MOD)
{
a[i] = (a1[i] + 1ll * k * a2[i]) % MOD;
a[i+len] = (a1[i] - 1ll * k * a2[i]) % MOD;
}
}
// Reads two polynomials (degrees, then coefficients), multiplies them modulo MOD with the
// recursive NTT and prints the lena + lenb + 1 product coefficients separated by spaces.
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    lena = Read(); lenb = Read();
    // Reduce each coefficient into [0, MOD) up front: Read() returns a long long that may be
    // negative or too large for an int.
    for (int i = 0; i <= lena; ++i) a[i] = (Read() % MOD + MOD) % MOD;
    for (int i = 0; i <= lenb; ++i) b[i] = (Read() % MOD + MOD) % MOD;
    int len = 1;
    while (len <= lena + lenb) len <<= 1;
    NTT(len, a, 1);
    NTT(len, b, 1);
    // Pointwise product of the len transformed values (indices 0 .. len-1 only).
    for (int i = 0; i < len; ++i) a[i] = 1ll * a[i] * b[i] % MOD;
    NTT(len, a, -1);
    // The inverse NTT is unscaled: multiply by len^(-1) = len^(MOD-2) (Fermat's little theorem).
    const int invlen = qpow(len, MOD - 2);
    // "+ MOD) % MOD" maps any negative residue into [0, MOD).
    for (int i = 0; i <= lena + lenb; ++i) Put((1ll * a[i] * invlen % MOD + MOD) % MOD, ' ');
    return 0;
}
$\text{Mine(迭代)}$
//12252024832524
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
// Shorthand for a one-type-parameter template header.
#define TT template<typename T>
using namespace std;
using LL = long long;
// Capacity of the coefficient / bit-reversal arrays: 2^21 plus a little slack.
const int MAXN = (1 << 21) | 5;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int MOD = 998244353;
const int PHI = MOD - 1;      // phi(MOD) for the prime MOD, i.e. 998244352
const int GINV = 332748118;   // inverse of G modulo MOD: 3 * 332748118 = MOD + 1
const int G = 3;              // primitive root of MOD
// lena, lenb: input degrees; len: transform length (a power of two);
// l: log2(len) - 1, the shift that moves the lowest index bit to the top in rev[].
int lena, lenb, len = 1, l = -1;
// a, b: coefficients / transformed values; rev[i]: i with its log2(len) low bits reversed.
int a[MAXN], b[MAXN], rev[MAXN];
// Reads one integer from stdin.
// Every non-digit before the number is skipped; a '-' among them makes the result negative.
// Returns 0 if the input ends before any digit is seen.
LL Read()
{
    LL x = 0, f = 1;
    // Keep getchar()'s result as int so EOF is distinguishable from a real character;
    // stored in a char, EOF never matches and the skip loop below would spin forever.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c ^ 48);   // c ^ 48 == c - '0' for digit characters
        c = getchar();
    }
    return x * f;
}
// Writes the decimal digits of x (meant for x >= 0; Put handles the sign).
// Lower digits are stacked in a small buffer and flushed after the leading digit,
// so the output goes from the most significant digit down.
TT void Put1(T x)
{
    char pending[40];   // room for any integer type up to 128 bits
    int cnt = 0;
    while (x > 9)
    {
        pending[cnt++] = x % 10 ^ 48;   // digit d -> '0' + d (48 == '0')
        x /= 10;
    }
    putchar(x % 10 ^ 48);
    while (cnt > 0) putchar(pending[--cnt]);
}
// Prints integer x in decimal, then character c if c is non-negative (c = -1: no separator).
// c is an int rather than a char: with a plain char the default -1 turns into 255 on
// platforms where char is unsigned, and a stray byte would be printed. Callers passing a
// char literal still work through the usual promotion.
TT void Put(T x, int c = -1)
{
    if (x < 0)
    {
        putchar('-');
        // Peel off the last digit before negating: -(x / 10) and -(x % 10) never overflow,
        // whereas -x does for the most negative value of T.
        const T head = -(x / 10);
        if (head > 0) Put1(head);
        putchar(static_cast<int>('0' - x % 10));
    }
    else
        Put1(x);
    if (c >= 0) putchar(c);
}
// Generic helpers: the larger / smaller of two values, and the absolute value.
TT T Max(T x, T y) { if (x > y) return x; return y; }
TT T Min(T x, T y) { if (x < y) return x; return y; }
TT T Abs(T x) { if (x < 0) return -x; return x; }
// Returns x^y mod MOD by square-and-multiply, consuming the bits of y from the lowest one.
int qpow(int x, int y)
{
    long long result = 1, base = x;
    for (; y != 0; y >>= 1)
    {
        if (y & 1) result = result * base % MOD;
        base = base * base % MOD;
    }
    return (int)result;
}
// Iterative radix-2 number-theoretic transform modulo MOD on a[0 .. len-1], using the globals
// len and rev[] set up in main.
// opt = 1: forward transform with root G; opt = -1: inverse transform with root GINV,
// followed by multiplication with len^(-1) mod MOD.
void NTT(int *a, int opt)
{
    // Bit-reversal permutation, so the butterflies below can run bottom-up in place.
    for (int i = 0; i < len; ++i) if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int i = 1; i < len; i <<= 1)
    {
        // w = g^((MOD-1)/(2i)) is a primitive (2i)-th root of unity modulo MOD.
        const int w = qpow(opt == 1 ? G : GINV, PHI / (i << 1));
        for (int j = 0, p = i << 1; j < len; j += p)
        {
            int mi = 1;
            for (int k = 0; k < i; ++k, mi = 1ll * mi * w % MOD)
            {
                const int X = a[j + k], Y = 1ll * mi * a[i + j + k] % MOD;
                a[j + k] = (X + Y) % MOD;
                a[i + j + k] = (X - Y + MOD) % MOD;
            }
        }
    }
    // Only the inverse transform needs the 1/len factor, so only it computes the modular
    // inverse (Fermat: len^(MOD-2)); forward transforms skip the exponentiation entirely.
    if (opt == -1)
    {
        const int invlen = qpow(len, MOD - 2);
        for (int i = 0; i < len; ++i) a[i] = 1ll * a[i] * invlen % MOD;
    }
}
// Reads two polynomials (degrees, then coefficients), multiplies them modulo MOD with the
// iterative NTT and prints the lena + lenb + 1 product coefficients separated by spaces.
int main()
{
    // freopen(".in","r",stdin);
    // freopen(".out","w",stdout);
    lena = Read(); lenb = Read();
    // Reduce each coefficient into [0, MOD): NTT() only keeps residues non-negative when its
    // input is non-negative, and the result is printed as-is below. Read() may also return
    // values too large for an int.
    for (int i = 0; i <= lena; ++i) a[i] = (Read() % MOD + MOD) % MOD;
    for (int i = 0; i <= lenb; ++i) b[i] = (Read() % MOD + MOD) % MOD;
    // Smallest power of two above the product degree; l ends up as log2(len) - 1.
    while (len <= lena + lenb) len <<= 1, l++;
    // Bit-reversal table, each entry built from the already-computed entry for i >> 1.
    for (int i = 0; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l);
    NTT(a, 1);
    NTT(b, 1);
    for (int i = 0; i < len; ++i) a[i] = 1ll * a[i] * b[i] % MOD;
    NTT(a, -1);   // includes the multiplication by len^(-1)
    for (int i = 0; i <= lena + lenb; ++i) Put(a[i], ' ');
    return 0;
}