zoukankan      html  css  js  c++  java
  • codeforces 852B

    http://codeforces.com/contest/852/problem/B

    题意:有一幅有向图,除了源点和汇点有 L 层,每层 n 个点。 第 i+1 层的每个点到 第 i+2 层的每个点都有一条边,边的权值为有向边终点的权值。求源点到汇点的路径长度能被 m 整除的个数。

    题解:快速幂。a[i] 表示从第 1 层到第 a 层总路径长度为 i (i % m) 的数目,b[j] 表示从第 a+1层到 第 a+1 层(也就是自己层)总路径长度为 j (j % m) 的数目,那么第 a+2 层的 a[(i+j)%m] = a[i]*b[j]。

       暴力做法,从第一层开始,一层一层的乘上去,这样显然会超时。

       仔细看一下,从第 2 层到第 L-1 层,每次乘的操作是相同的,可以用快速幂先把第 2 层到第 L-1 层乘起来。

    #include<iostream>
    #include<cstring>
    #include<algorithm>
    #include<cstdio> 
    #define mod 1000000007
    using namespace std;
    const int MAXN = 100000+10;
    int a[1000010];
    int n, l, m;
    struct node
    {
        long long num[110];
        node()
        {
            memset(num, 0x0000, sizeof(num));
        }
    };
    node Begin, End, mid;
    node mul(node la, node lb)
    {
        node aa = node();
        for(int i = 0; i < m; i++)
        {
            for(int j = 0; j < m; j++)
            {
                int k = (i+j)%m;
                aa.num[k] += la.num[i] * lb.num[j] % mod;
                aa.num[k] %= mod;
            }
        }
        return aa;
    }
    node fast(node nod, int k)
    {
        node sum = nod;
        k--;
        while(k)
        {
            if(k&1)
            {
                sum = mul(sum, nod);
            }
            k >>= 1;
            nod = mul(nod, nod);
        }
        return sum;
    }
    
    int main (void)
    {
        ios::sync_with_stdio(false);
        Begin = node();
        End = node();
        mid = node();
        cin >> n >> l >> m;
        for(int i = 1; i <= n; i++)
        {
            int x; cin >> x;
            Begin.num[x%m]++;
        }
        for(int i = 1; i <= n; i++)
        {
            int x; cin >> x;
            a[i] = x;
            mid.num[x%m]++;
        }
        for(int i = 1; i <= n; i++)
        {
            int x; cin >> x;
            End.num[(x+a[i])%m]++;
        }
        
        node nod;
        if( l-2 > 0 )
        {
            nod = fast(mid, l-2);    
            nod = mul(nod, Begin);
            nod = mul(nod, End); 
        }
        else
        {
            nod = mul(Begin, End);
        }
        
        long long ans = 0;
        for(int j = 0; j <= 100; j++)
        {
            if(j%m==0)
            {
                ans += nod.num[j];
                ans %= mod;
            }
        }
        cout << ans;
    }
  • 相关阅读:
    简单理解OOP——面向对象编程
    SpringMVC拦截器
    Vue简洁及基本用法
    springMVC实现文件上传下载
    Python笔记⑤爬虫
    Python笔记4
    Python笔记3
    Python基础语法笔记2
    Python基础入门语法1
    Navicat连接mysql时候出现1251错误代码
  • 原文地址:https://www.cnblogs.com/lkcc/p/7471852.html
Copyright © 2011-2022 走看看