题目大意:
给定一个有N个点的树,问其中有多少条路径满足他们的边权连成的数对M取余为0。
其中gcd(M,10)=1。
题解:
很亲民的点分治题目,对每一层点分治,预处理每个点到当前根的数字并对m取余,和当前根到每个点的数字取余。
我们得到公式:x1*10dep2+x2=0 mod(m)
dep2指终点的深度,x2指从根到终点数字,通过exgcd即可算出,若想使结果mod m=0,则出发点到根的数字必须是x1。
从公式还可以看出,若枚举出发点,我还需要知道终点的dep,不好做。
所以解法就出来了,先将所有出发点的值丢进一个map(或hash),然后枚举每个子树,先将该子树的出发点值去掉,然后根据终点的值和map计算答案,最后再将该子树出发点的值加进去。
这样可以保证计算答案的不重不漏。还有就是注意根节点的处理。
时间cf上1s2,如果map改hash,应该能更快一些。
#include<iostream> #include<cstdio> #include<cstdlib> #include<string> #include<cstring> #include<cmath> #include<ctime> #include<algorithm> #include<iomanip> #include<map> #include<queue> using namespace std; #define mem1(i,j) memset(i,j,sizeof(i)) #define mem2(i,j) memcpy(i,j,sizeof(i)) #define LL long long #define up(i,j,n) for(int i=(j);i<=(n);++i) #define down(i,n,j) for(int i=n;i>=j;--i) #define Auto(i,x) for(int i=linkk[x];i;i=e[i].next) #define FILE "dealing" #define poi vec #define db double #define eps 1e-10 #define mid ((l+r)>>1) const int maxn=101000,inf=1000000000; int read(){ int x=0,f=1,ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch<='9'&&ch>='0'){x=(x<<1)+(x<<3)+ch-'0',ch=getchar();} return f*x; } inline bool cmax(int& a,int b){return a<b?a=b,true:false;} inline bool cmin(int& a,int b){return a>b?a=b,true:false;} LL n,m,fac[maxn],ans=0; void exgcd(LL a,LL b,LL& d,LL& x,LL& y){ if(b==0){x=1,y=0,d=a;return;} exgcd(b,a%b,d,x,y); LL t=x;x=y;y=t-a/b*x; } struct node{int y,next,v;}e[maxn<<1];int len=0,linkk[maxn<<1]; void insert(int x,int y,int v){e[++len].y=y;e[len].v=v;;e[len].next=linkk[x];linkk[x]=len;} int siz[maxn],fa[maxn],root,q[maxn],vis[maxn],Max[maxn],Min=0,dep[maxn],head,tail; bool b[maxn];LL w[maxn],r[maxn]; void getsize(int x){ head=tail=0;q[++tail]=x;Min=inf; while(++head<=tail){ int x=q[head];b[x]=1;Max[x]=0; siz[x]=1; for(int i=linkk[x];i;i=e[i].next) if(!b[e[i].y]&&!vis[e[i].y])fa[e[i].y]=x,q[++tail]=e[i].y; } down(j,tail,1){ int x=q[j]; for(int i=linkk[x];i;i=e[i].next){ if(e[i].y==fa[x]||vis[e[i].y])continue; siz[x]+=siz[e[i].y]; cmax(Max[x],siz[e[i].y]); } cmax(Max[x],tail-siz[x]); if(cmin(Min,Max[x]))root=x;//求出根节点为root } up(i,1,tail)b[q[i]]=0; } map<int,int> t;//一只桶 void getans(LL x1,LL x2,LL c){ LL d,x,y; exgcd(x1,x2,d,x,y); x=(x*c/d%m+m)%m; if(t.count(x))ans+=t[x]; } void add(int x,int num){ head=tail=0;q[++tail]=x;r[x]=w[x]=num;dep[x]=1; while(++head<=tail){ int x=q[head];b[x]=1; if(t.count(r[x]))t[r[x]]++; else t[r[x]]=1; for(int i=linkk[x];i;i=e[i].next){ if(b[e[i].y]||vis[e[i].y])continue; dep[e[i].y]=dep[x]+1; w[e[i].y]=(w[x]*10+e[i].v)%m;//接收端 r[e[i].y]=(r[x]+fac[dep[x]]*e[i].v)%m;//发出端 q[++tail]=e[i].y; } } up(i,1,tail)b[q[i]]=0; } void cal(int x,int num){ head=tail=0;q[++tail]=x;r[x]=w[x]=num;dep[x]=1; while(++head<=tail){ int x=q[head];b[x]=1;t[r[x]]--; for(int i=linkk[x];i;i=e[i].next){ if(b[e[i].y]||vis[e[i].y])continue; q[++tail]=e[i].y; } } up(i,1,tail)b[q[i]]=0,getans(-fac[dep[q[i]]],m,w[q[i]]); up(i,1,tail)t[r[q[i]]]++; } void solve(int rt){ getsize(rt); t.clear();vis[root]=1;b[root]=1; t[0]=1; for(int i=linkk[root];i;i=e[i].next){ if(vis[e[i].y])continue; add(e[i].y,e[i].v%m); } for(int i=linkk[root];i;i=e[i].next){ if(vis[e[i].y])continue; cal(e[i].y,e[i].v%m); } ans+=t[0]-1; for(int i=linkk[root];i;i=e[i].next){ if(vis[e[i].y])continue; solve(e[i].y); } } int main(){ freopen(FILE".in","r",stdin); freopen(FILE".out","w",stdout); n=read(),m=read(); fac[0]=1;up(i,1,n)fac[i]=fac[i-1]*10%m; up(i,1,n-1){ int x=read()+1,y=read()+1,v=read(); insert(x,y,v);insert(y,x,v); } solve(1); cout<<ans<<endl; return 0; }