zoukankan      html  css  js  c++  java
  • 洛谷 P5293 [HNOI2019]白兔之舞

    有一张顶点数为 ((L+1) imes n) 的有向图。这张图的每个顶点由一个二元组((u,v))表示((0le ule L,1le vle n))
    这张图不是简单图,对于任意两个顶点 ((u_1,v_1)(u_2,v_2)),如果 (u_1<u_2),则从 ((u_1,v_1))((u_2,v_2)) 一共有 (w[v_1][v_2]) 条不同的边,如果 (u_1ge u_2) 则没有边。

    白兔将在这张图上上演一支舞曲。白兔初始时位于该有向图的顶点 ((0,x))

    白兔将会跳若干步。每一步,白兔会从当前顶点沿任意一条出边跳到下一个顶点。白兔可以在任意时候停止跳舞(也可以没有跳就直接结束)。当到达第一维为 (L) 的顶点就不得不停止,因为该顶点没有出边。

    假设白兔停止时,跳了 (m) 步,白兔会把这只舞曲给记录下来成为一个序列。序列的第 (i) 个元素为它第 (i) 步经过的边。

    问题来了:给定正整数 (k)(y)(1le yle n)),对于每个 (t)(0le t<k)),求有多少种舞曲(假设其长度为 (m))满足 (m mod k=t),且白兔最后停在了坐标第二维为 (y) 的顶点?

    两支舞曲不同定义为它们的长度((m))不同或者存在某一步它们所走的边不同。

    输出的结果对 (p) 取模。

    对于全部数据,(p) 为一个质数,(10^8<p<2^{30})(1le nle 3)(1le xle n)(1le yle n)(0le w(i,j)<p)(1le kle 65536)(k)(p-1) 的约数,(1le Lle 10^8)


    首先可以考虑dp,设 (f_{i,j}) 表示走了 (i) 步,最后第二维停在了 (j) 上的方案数。

    这个东西不好dp,再设一个 (g_{i,j}) 表示挨着走了 (i) 个格子,最后第二维走到了 (j) 上的方案数,这个就可以dp了,有:

    [g_{i,j}=sum_{k=1}^ng_{i-1,k}w_{k,j} ]

    这个是非常可以矩阵优化的,那么就可以写成:

    [G_i=G_0W^i ]

    然后 (f_{i,j}) 可以看作从 (L) 个格子中选了了 (i) 个落脚点,就有:

    [f_{i,j}={Lchoose i}sum_{k=1}^ng_{i,k} ]

    最后我们的答案也就变成了:

    [ans_t=sum_{m=0}^L[m\%k==t]f_{m,y} ]

    这东西熟啊,直接单位根反演,就有:

    [egin{aligned}&sum_{m=0}^Lsum_{d=0}^{k-1}frac{1}{k}omega_k^{(m-t)d}f_{m,y}\=&frac{1}{k}sum_{d=0}^{k-1}omega_k^{-td}sum_{m=0}^Lomega_k^{md}f_{m,y}end{aligned} ]

    (f_{m,y}) 替换成 (G_m) 的系数的形式

    [egin{aligned}ans_t&=frac{1}{k}sum_{d=0}^{k-1}omega_k^{-td}sum_{m=0}omega_k^{md}{Lchoose m}[y]G_0W^m\&=frac{1}{k}sum_{d=0}^{k-1}omega_k^{-td}[y]G_0sum_{m=0}^L{Lchoose m}(omega_k^dW)^m\&=frac{1}{k}sum_{d=0}^{k-1}omega_k^{-td}[y]G_0(omega_k^dW+I)^Lend{aligned} ]

    其中 (I) 是单位矩阵,设 (F_d=[y]G_0(omega_k^dW+I)^L) ,那么这个东西可以通过矩阵快速幂预处理出来,然后就有:

    [ans_t=frac{1}{k}sum_{d=0}^{k-1}omega_k^{-td}F_d ]

    这个是个循环卷积,我们用 ({i+jchoose 2}-{ichoose2}-{jchoose 2}) 来替换 (omega_k) 的系数就有:

    [egin{aligned}ans_t&=frac{1}{k}sum_{d=0}^{k-1}omega_k^{-({t+dchoose 2}-{tchoose2}-{dchoose 2})}F_d\&=frac{1}{k}omega_k^{tchoose2}sum_{d=0}^{k-1}omega_k^{-{t+dchoose2}}omega_k^{{dchoose2}}F_dend{aligned} ]

    发现这是个差卷积,然后模数不固定,还要做mtt。

    Code

    #include <iostream>
    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    #include <cmath>
    const int N = 65536;
    const int M = 1e6;
    const long double Pi = acos(-1.0);
    using namespace std;
    int n,k,L,x,y,p,g,rev[M + 5],maxn,lg,prime[N + 5],pcnt,ow[N + 5],A[M + 5],B[M + 5],C[M + 5];
    struct Matrix
    {
        int a[4][4];
    }G,W,nw,I;
    Matrix operator *(Matrix a,Matrix b)
    {
        Matrix c;
        for (int i = 1;i <= n;i++)
            for (int j = 1;j <= n;j++)
                c.a[i][j] = 0;
        for (int i = 1;i <= n;i++)
            for (int j = 1;j <= n;j++)
                for (int k = 1;k <= n;k++)
                    c.a[i][j] += 1ll * a.a[i][k] * b.a[k][j] % p,c.a[i][j] %= p;
        return c;
    }
    struct node
    {
        double x,y;
        node conj(){return (node){x,-y};}
    }w[M + 5],c[M + 5],d[M + 5],x1[M + 5],x2[M + 5],x3[M + 5],aa[M + 5],bb[M + 5],cc[M + 5],dd[M + 5],conj;
    node operator +(node a,node b){return (node){a.x + b.x,a.y + b.y};}
    node operator -(node a,node b){return (node){a.x - b.x,a.y - b.y};}
    node operator *(node a,node b){return (node){a.x * b.x - a.y * b.y,a.x * b.y + a.y * b.x};}
    int mypow(int a,int x,int p){int s = 1;for (;x;x & 1 ? s = 1ll * s * a % p : 0,a = 1ll * a * a % p,x >>= 1);return s;}
    void prework(int n)
    {
        maxn = 1;lg = 0;
        while (maxn <= n)
            maxn <<= 1,lg++;
        for (int i = 0;i < maxn;i++)
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg - 1);
    }
    int getphi(int n)
    {
        int ans = n;
        for (int i = 2;i * i <= n;i++)
            if (n % i == 0)
            {
                while (n % i == 0)
                {
                    n /= i;
                    ans = ans / i * (i - 1);
                }
            }
        if (n != 1)
            ans = ans / n * (n - 1);
        return ans;
    }
    void getprime(int n)
    {
        for (int i = 2;i * i <= n;i++)
            if (n % i == 0)
            {
                prime[++pcnt] = i;
                while (n % i == 0)
                    n /= i;
            }
        if (n != 1)
            prime[++pcnt] = n;
    }
    int get(int n)
    {
        int phi = getphi(n);
        getprime(phi);
        for (int i = 1;i < n;i++)
            if (mypow(i,phi,n) == 1)
            {
                int fl = 1;
                for (int j = 1;j <= pcnt;j++)
                    if (mypow(i,phi / prime[j],n) == 1)
                    {
                        fl = 0;
                        break;
                    }
                if (fl)
                    return i;
            }
    }
    void fft(node *a,int typ)
    {
        for (int i = 0;i < maxn;i++)
            if (i < rev[i])
                swap(a[i],a[rev[i]]);
        for (int i = 1;i < maxn;i <<= 1)
            for (int j = 0;j < maxn;j += i << 1)
                for (int k = 0;k < i;k++)
                {
                    node x = a[j + k],t = (node){w[k + i].x,w[k + i].y * typ} * a[j + k + i];
                    a[j + k] = x + t;
                    a[j + k + i] = x - t;
                }
        if (typ == -1)
            for (int i = 0;i < maxn;i++)
                a[i].x /= maxn,a[i].y /= maxn;
    }
    Matrix mpow(Matrix a,int x)
    {
        Matrix s = I;
        while (x)
        {
            if (x & 1)
                s = s * a;
            a = a * a;
            x >>= 1;
        }
        return s;
    }
    void getmatrix(int d)
    {
        memset(G.a,0,sizeof(G.a));
        G.a[1][x] = 1;
        memset(nw.a,0,sizeof(nw.a));
        for (int i = 1;i <= n;i++)
            for (int j = 1;j <= n;j++)
                nw.a[i][j] = 1ll * ow[d] * W.a[i][j] % p + I.a[i][j];
        G = G * mpow(nw,L);
    }
    int C2(int n){return 1ll * n * (n - 1) / 2 % k;}
    void mtt(int *a,int *b,int n,int p)
    {
        for (int i = 0;i < n;i++)
            a[i] %= p,b[i] %= p;
        int bs = 32768;
        for (int i = 0;i < n;i++)
            c[i] = (node){a[i] / bs,a[i] % bs};
        fft(c,1);
        d[0] = c[0].conj();
        for (int i = 1;i < maxn;i++)
            d[i] = c[maxn - i].conj();
        for (int i = 0;i < maxn;i++)
            aa[i] = (c[i] + d[i]) * (node){0.5,0},bb[i] = (c[i] - d[i]) * (node){0,-0.5};
        for (int i = 0;i < maxn;i++)
            c[i] = d[i] = (node){0,0};
        for (int i = 0;i < n;i++)
            c[i] = (node){b[i] / bs,b[i] % bs};
        fft(c,1);
        d[0] = c[0].conj();
        for (int i = 1;i < maxn;i++)
            d[i] = c[maxn - i].conj();
        for (int i = 0;i < maxn;i++)
            cc[i] = (c[i] + d[i]) * (node){0.5,0},dd[i] = (c[i] - d[i]) * (node){0,-0.5};
        for (int i = 0;i < maxn;i++)
            x1[i] = aa[i] * cc[i],x2[i] = aa[i] * dd[i] + cc[i] * bb[i],x3[i] = bb[i] * dd[i];
        for (int i = 0;i < maxn;i++)
            x1[i] = x1[i] + x3[i] * (node){0,1};
        fft(x1,-1);
        fft(x2,-1);
        for (int i = 0;i < n;i++)
            a[i] = ((1ll * ((long long)(x1[i].x + 0.1)) % p * bs % p * bs % p + 1ll * ((long long)(x2[i].x + 0.1) % p) * bs % p) % p + ((long long)(x1[i].y + 0.1)) % p) % p;
    }
    int main()
    {
        scanf("%d%d%d%d%d%d",&n,&k,&L,&x,&y,&p);
        for (int i = 1;i <= n;i++)
            for (int j = 1;j <= n;j++)
                scanf("%d",&W.a[i][j]);
        g = get(p);
        ow[1] = mypow(g,(p - 1) / k,p);
        for (int i = 2;i <= k;i++)
            ow[i % k] = 1ll * ow[i - 1] * ow[1] % p;
        prework(k * 4);
        for (int i = 1;i < maxn;i <<= 1)
            for (int j = 0;j < i;j++)
                w[i + j] = (node){cos(Pi * j / i),sin(Pi * j / i)};
        for (int i = 1;i <= 3;i++)
            I.a[i][i] = 1;
        int lim = k * 2;
        for (int d = 0;d < lim;d++)
            A[d] = ow[(k - C2(d)) % k];
        for (int d = 0;d < k;d++)
        {
            getmatrix(d);
            B[d] = 1ll * ow[C2(d)] * G.a[1][y] % p; 
        }
        reverse(A,A + lim);
        mtt(A,B,lim,p);
        reverse(A,A + lim);
        int ik = mypow(k,p - 2,p);
        for (int d = 0;d < k;d++)
        {
            A[d] = 1ll * ik * ow[C2(d)] % p * A[d] % p;
            printf("%d
    ",(A[d] + p) % p);
        }
        return 0;
    }
    
    
  • 相关阅读:
    SGC强制最低128位加密,公钥支持ECC加密算法的SSL证书
    python学习笔记(一)
    eclipse中启动 Eclipse 弹出“Failed to load the JNI shared library jvm.dll”错误
    外键建立失败
    scala函数式编程(一)
    idea环境下建立maven工程并运行scala程序
    scala中option、None、some对象
    Java与mysql数据库编程中遇见“Before start of result set at com.mysql.jdbc.SQLError.createSQLException” 的解决办法
    hive表的存储路径查找以及表的大小
    red hat7 系统可以ping通ip地址但是不能ping通域名
  • 原文地址:https://www.cnblogs.com/sdlang/p/14328501.html
Copyright © 2011-2022 走看看