zoukankan      html  css  js  c++  java
  • HDU4436 str2int 子串之和

    vjudge传送


    这题还真挺难,没想出来。


    首先我们要做的是,为了判断所有的重复串,要把这些串放在一个SAM里,但要用特殊字符分隔开来(这里用数字10,以缩小字符集大小)。
    接下来我想的是在后缀链接树上dfs,但复杂度是(O(sum len[i])(i in 叶子节点)),这个复杂度上限是(O(n^2))的。


    题解的做法是在转移边上拓扑排序(转移边构成了一个DAG)。
    (sum[u])表示到节点(u)上的数字和,(num[u])表示到(u)的路径条数,于是对于一条转移边((uoverset{i}{ ightarrow} v))
    (num[v] += num[u]),
    (sum[v] += sum[u]*10+i*num[u]).
    刚开始我没想明白为什么要维护(num[u]),后来手动模拟后才知道,(num[u])表示的是到(u)的路径条数,对应的就是以(u)结尾不同数字个数,所以当这些数字后面都加了一个数码(i)后,对于(v)的影响就是加上了(i*num[u]).
    时间复杂度(O(2nK)(K)为字符集大小())

    #include<cstdio>
    #include<iostream>
    #include<cmath>
    #include<algorithm>
    #include<cstring>
    #include<cstdlib>
    #include<cctype>
    #include<vector>
    #include<queue>
    #include<assert.h>
    #include<ctime>
    using namespace std;
    #define enter puts("") 
    #define space putchar(' ')
    #define Mem(a, x) memset(a, x, sizeof(a))
    #define In inline
    #define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
    typedef long long ll;
    typedef double db;
    const int INF = 0x3f3f3f3f;
    const db eps = 1e-8;
    const int maxn = 1e5 + 5;
    const int mod = 2012;
    const int maxs = 12;
    In ll read()
    {
    	ll ans = 0;
    	char ch = getchar(), las = ' ';
    	while(!isdigit(ch)) las = ch, ch = getchar();
    	while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
    	if(las == '-') ans = -ans;
    	return ans;
    }
    In void write(ll x)
    {
    	if(x < 0) x = -x, putchar('-');
    	if(x >= 10) write(x / 10);
    	putchar(x % 10 + '0');
    }
    
    int n, len, p[maxn];
    char s2[maxn], s[maxn];
    struct Sam
    {
    	int tra[maxn << 1][maxs], link[maxn << 1], len[maxn << 1], cnt, las;
    	In void init() 
    	{
    		link[cnt = las = 0] = -1;
    		Mem(tra[0], 0), Mem(buc, 0);
    		Mem(sum, 0), Mem(num, 0);
    	}
    	In void insert(int c, int id)
    	{
    		int now = ++cnt, p = las; Mem(tra[now], 0);
    		len[now] = len[las] + 1;
    		while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
    		if(p == -1) link[now] = 0;
    		else
    		{
    			int q = tra[p][c];
    			if(len[q] == len[p] + 1) link[now] = q;
    			else
    			{
    				int clo = ++cnt; 
    				len[clo] = len[p] + 1;
    				memcpy(tra[clo], tra[q], sizeof(tra[q]));
    				link[clo] = link[q]; link[q] = link[now] = clo;
    				while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
    			}
    
    		}
    		las = now;		
    	}
    	int buc[maxn << 1], pos[maxn << 1], sum[maxn << 1], num[maxn << 1];
    	In int solve()
    	{
    		for(int i = 1; i <= cnt; ++i) buc[len[i]]++;
    		for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
    		for(int i = 0; i <= cnt; ++i) pos[buc[len[i]]--] = i;
    		num[0] = 1;
    		for(int i = 0; i <= cnt; ++i)
    		{
    			int now = pos[i];
    			for(int j = 0; j < 10; ++j)
    			{
    				if((!now && !j) || !tra[now][j]) continue;
    				int v = tra[now][j];
    				int tp = (sum[now] * 10 + j * num[now]) % mod;
    				sum[v] = (sum[v] + tp) % mod;
    				num[v] = (num[v] + num[now]) % mod;			
    			}
    		}
    		int ans = 0;
    		for(int i = 1; i <= cnt; ++i) ans = (ans + sum[i]) % mod;
    		return ans;
    	}
    }S;
    
    int main()
    {
    	while(scanf("%d", &n) != EOF)
    	{
    		len = 0, S.init();
    		for(int i = 1; i <= n; ++i)
    		{
    			scanf("%s", s2);
    			int l = strlen(s2); S.insert(10, len);
    			for(int j = 0; j < l; ++j) S.insert(s2[j] - '0', len);
    		}
    		write(S.solve()), enter;
    	}
    	return 0;
    }
    
  • 相关阅读:
    Python-文件阅读(open函数)
    列表推导式练习
    Python-集合(set)
    Python-元组(tuple)
    Python-函数-聚合和打散
    Python-列表-非count的计数方法
    Python-字典(dict)
    Python-列表(list)
    Python-字符串
    求三个元素的最大值,和最小值。
  • 原文地址:https://www.cnblogs.com/mrclr/p/14556995.html
Copyright © 2011-2022 走看看