zoukankan      html  css  js  c++  java
  • 2019杭电多校三 C. Yukikaze and Demons (点分治)

    大意: 给定树, 每个点有一个十进制数位, 求有多少条路径组成的十进制数被$k$整除.

    点分治, 可以参考CF715C, 转化为求$10^a x+bequiv 0(modspace k)$的$x$的个数.

    要注意

    • $tmp$不要设成全局!!
    • 如果$ ext{y%z==0}$的话, 那么$ ext{x%y%z==x%z}$
    #include <iostream>
    #include <algorithm>
    #include <cstdio>
    #include <math.h>
    #include <set>
    #include <map>
    #include <queue>
    #include <string>
    #include <string.h>
    #include <bitset>
    #define REP(i,a,n) for(int i=a;i<=n;++i)
    #define PER(i,a,n) for(int i=n;i>=a;--i)
    #define hr putchar(10)
    #define pb push_back
    #define lc (o<<1)
    #define rc (lc|1)
    #define mid ((l+r)>>1)
    #define ls lc,l,mid
    #define rs rc,mid+1,r
    #define x first
    #define y second
    #define io std::ios::sync_with_stdio(false)
    #define endl '
    '
    using namespace std;
    typedef long long ll;
    typedef pair<int,int> pii;
    
    
    
    
    const int N = 1e5+10;
    int sum, n, rt, m, p10[N];
    int sz[N], mx[N], vis[N], b[N];
    char s[N];
    vector<int> g[N];
    ll ans, ans1, Phi;
    
    int gcd(int a, int b) {return b?gcd(b,a%b):a;}
    int exgcd(int a, int b, int &x, int &y) {
        int d;
        if (b) d=exgcd(b,a%b,y,x), y-=a/b*x;
        else d=a,x=1,y=0;
        return d;
    }
    bool chk(int &a, int &b, int &p) {
        //ax=b(mod p)是否有解
        int x, k, d = exgcd(a,p,x,k);
        if (b%d==0) a=1,p/=d,b=(b/d*x%p+p)%p;
        return a==1;
    }
    
    void getrt(int x, int fa) {
        mx[x]=0, sz[x]=1;
        for (int y:g[x]) if (!vis[y]&&y!=fa) {
            getrt(y,x),sz[x]+=sz[y];
            mx[x]=max(mx[x],sz[y]);
        }
        mx[x]=max(mx[x],sum-sz[x]);
        if (mx[rt]>mx[x]) rt=x;
    }
    
    int ID(int x) {
        return lower_bound(b+1,b+1+*b,x)-b;
    }
    
    map<int,int> mp[40];
    
    //mp[i][j] 记录10^h*x=y(mod m)的y的个数, 其中y = j (mod b[i]), b[i] = m/gcd(10^h,m)
    void dfs1(int x, int fa, int dep, int down) {
        //求10^dep*x=(m-down)%m
        int a = p10[dep], b = (m-down)%m, p = m;
        if (chk(a,b,p)) { 
            auto &u = mp[ID(p)];
            if (u.count(b)) ans += u[b];
        }
        for (int y:g[x]) if (!vis[y]&&y!=fa) {
            dfs1(y,x,dep+1,((ll)down*10ll+s[y])%m);
        }
    }
    int up[40];
    void dfs2(int x, int fa, int dep) {
        REP(i,1,*b) { 
            ++mp[i][up[i]];
        }
    	int tmp[40];
        for (int y:g[x]) if (!vis[y]&&y!=fa) {
            REP(i,1,*b) tmp[i]=up[i],up[i]=((ll)s[y]*p10[dep]+up[i])%b[i];
            dfs2(y,x,dep+1);
            REP(i,1,*b) up[i]=tmp[i];
        }
    }
    void dfs3(int x, int fa, int down, int dep, int up) {
        ans1 += !up+!down;
        for (int y:g[x]) if (!vis[y]&&y!=fa) {
            dfs3(y,x,((ll)down*10+s[y])%m,dep+1,((ll)s[y]*p10[dep]+up)%m);
        }
    }
    
    vector<int> q;
    void calc(int x) {
        REP(i,1,*b) mp[i].clear();
        if (s[x]%m==0) ++ans1;
        for (int y:q) {
            dfs1(y,x,1,s[y]%m);
            REP(i,1,*b) up[i] = (s[x]+10ll*s[y])%b[i];
            dfs2(y,x,2);
            dfs3(y,x,(10ll*s[x]+s[y])%m,2,(s[x]+10ll*s[y])%m);
        }
    }
    
    void solve(int x) {
        vis[x] = 1;
        q.clear();
        for (int y:g[x]) if (!vis[y]) q.pb(y);
        calc(x);
        reverse(q.begin(),q.end());
        calc(x);
        for (int y:g[x]) if (!vis[y]) {
            mx[rt=0]=n,sum=sz[y];
            getrt(y,0), solve(rt);
        }
    }
    
    
    void work() {
        scanf("%d%d%s", &n, &m, s+1);
        REP(i,1,n) p10[i]=p10[i-1]*10ll%m;
        REP(i,1,n) s[i]-='0';
        ans = ans1 = 0;
        REP(i,1,n) g[i].clear(),vis[i]=0;
        REP(i,2,n) {
            int u, v;
            scanf("%d%d", &u, &v);
            g[u].pb(v);
            g[v].pb(u);
        }
        if (m==1) return printf("%lld
    ", (ll)n*n),void();
        *b = 0;
        REP(i,0,min(n,30)) b[++*b]=m/gcd(p10[i],m);
        sort(b+1,b+1+*b),*b=unique(b+1,b+1+*b)-b-1;
        sum=mx[rt=0]=n,getrt(1,0),solve(rt);
        printf("%lld
    ", ans+ans1/2);
    }
    
    int main() {
        p10[0]=1;
        int t;
        scanf("%d", &t);
        while (t--) work();
    }
    
  • 相关阅读:
    Python网络编程 —— 粘包问题及解决方法
    Python网络编程 —— socket(套接字)及通信
    Python网络编程 —— 网络基础知识
    Python
    MySQL 之 数据的导出与导入
    MySQL 之 慢查询优化及慢日志管理
    MySQL 之 索引进阶
    MySQL 之 索引
    MySQL 之 事务
    MySQL 之 表的存储引擎
  • 原文地址:https://www.cnblogs.com/uid001/p/11266021.html
Copyright © 2011-2022 走看看