题面
给定 (a{n}) 和 (b{m}) 和质数 (p)。对于每个 (a_i),生成一个集合,刚开始只有 (1),每次从集合中选一个元素 (c),对于所有的 (j),如果满足 (c imes a_i^{b_j}mod p) 不在当前集合,就把它加入集合。求 (n) 个集合并的大小。
数据范围:(1le nle 10^4),(1le mle 10^5),(2le ple 10^9)。
题解
之前 VP
场上没做出来后来学了 zhouyuyang
的做法,然后发现学了个假做法。
Hack:2 1 13 3 5 1
;答案:6
;错误输出:7
。
设 (varphi=p-1),(g=gcd(varphi,gcd_{i=1}^m b_i)),(G) 是 (p) 任意一个原根。
很明显每个集合元素就是 ({a_i^{gk}mod p|kin N})。
(u^{x}equiv u^{xmod varphi}pmod p)。(xin[0,varphi)) 与 (f(x)=G^{x}in [1,p)) 一一对应。
所以搞出所有 (varphi) 的因数,可以得到对于所有的 (i),最小的 (c_i) 满足 ({a_i}^{gc_i}equiv 1pmod p)。那么这个集合有 (c_i) 个元素。
考虑是哪 (c_i) 个元素:({a^{g},a^{2g},...,a^{(c_i-1)g}})((mod p),下同)。
或者说:({G^{0},G^{varphi/c_i},G^{2varphi/c_i},...,G^{(c_i-1)varphi/c_i}})。这两个集合是等价的。
设 (s_i=frac{varphi}{c_i}),求 (n) 个集合的并集,问题转化为:有多少个 (1le xle varphi) 满足 (exists s_i|x)。
很多人这里写错了,可是 Codeforces
的数据不强,误导了好多人。
设 (d|varphi:f(d)) 表示 (d) 的倍数中只被 (d) 标记不被 (d) 的其他是 (varphi) 的因数的倍数标记的数的个数。
如果 (exists s_i|d),(f(d)=frac{varphi}{d}-sum_{dk|varphi}f(dk));否则 (f(d)=0)。
然后答案是 (sum_{d|varphi} f(d)),时间复杂度 (Theta((n+m)log p+d(varphi)(n+d(varphi))))。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef double db;
#define x first
#define y second
#define bg begin()
#define ed end()
#define pb push_back
#define mp make_pair
#define sz(a) int((a).size())
#define R(i,n) for(int i(0);i<(n);++i)
#define L(i,n) for(int i((n)-1);i>=0;--i)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;
//Data
const int N=1e4,M=1e5,d_N=4e4;
int n,m,p,a[N],b[M],c[N];
int d_n,pd[d_N],f[d_N];
//Math
int gcd(int a,int b){return a?gcd(b%a,a):b;}
int mypow(int a,int x=p-2,int res=1){
for(;x;x>>=1,a=1ll*a*a%p)
(x&1)&&(res=1ll*res*a%p);
return res;
}
//Main
int main(){
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n>>m>>p;
int phi=p-1,g=phi;
R(i,n) cin>>a[i];
R(i,m) cin>>b[i],g=gcd(g,b[i]);
for(int d=1;d*d<=phi;++d)if(phi%d==0)
pd[d_n++]=d,d*d<phi?pd[d_n++]=phi/d:0;
sort(pd,pd+d_n);
R(i,n){
int t=mypow(a[i],g);
R(j,d_n)if(mypow(t,pd[j])==1)
{c[i]=phi/pd[j]; break;}
}
int ns=0;
L(i,d_n){
bool mark=false;
R(j,n)if(pd[i]%c[j]==0){mark=true; break;}
if(mark){
f[i]=phi/pd[i];
for(int j=i+1;j<d_n;++j)
if(pd[j]%pd[i]==0) f[i]-=f[j];
ns+=f[i];
}
}
cout<<ns<<'
';
return 0;
}
祝大家学习愉快!