题意:
给你n m q,表示在这一组数据中所有的01串长度均为n,然后给你一个含有m个元素的multiset,之后有q次询问。每次询问会给你一个01串t和一个给定常数k,让你输出串t和multiset里面多少个元素的“Wu”值不超过k。对于“Wu”值的定义:如果两个01串s和t在位置i上满足s[i]==t[i],那么加上w[i],处理完s和t的所有n位之后的结果即为这两个01串的“Wu”值。
n<12,k<100,m<5e5
思路:
n很小,k也很小,所以串的状态最多2^12次,预处理出sum[i][j]为串x(x转化为二进制i)与multiset里的wu值为j的数量
预处理复杂度O($2^n*2^n*n$)
询问的时候也可以与处理一下sum,不过这题k很小
代码:
#include<iostream> #include<iomanip> #include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #include<string> #include<stack> #include<queue> #include<deque> #include<set> #include<vector> #include<map> #include<functional> #include<list> #define fst first #define sc second #define pb push_back #define mp(a,b) make_pair(a,b) #define mem(a,b) memset(a,b,sizeof(a)) #define lson l,mid,root<<1 #define rson mid+1,r,root<<1|1 #define lc root<<1 #define rc root<<1|1 #define lowbit(x) ((x)&(-x)) #pragma Gcc optimize(2) using namespace std; typedef double db; typedef long double ldb; typedef long long ll; typedef unsigned long long ull; typedef pair<int,int> PI; typedef pair<ll,ll> PLL; const int maxn = 5e5 + 100; const int maxm = 5e3 + 100; const double eps = 1e-10; const int inf = 0x3f3f3f3f; const double pi = acos(-1.0); int scan(){ int res=0,ch,flag=0; if((ch=getchar())=='-') flag=1; else if(ch>='0'&&ch<='9') res=ch-'0'; while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0'; return flag?-res:res; } int w[maxn]; int cnt[4096 + 100]; int sum[4096 + 100][100 + 10]; int main(){ int n, m, q; scanf("%d %d %d", &n, &m, &q); mem(cnt, 0); mem(sum, 0); for(int i = 1; i <= n; i++){ scanf("%d", &w[i]); } for(int i = 1; i <= m; i++){ char s[20]; scanf("%s", s); int x = 0; for(int j = 0; j < n; j++){ if(s[j]=='1')x += 1<<(n-j-1); } cnt[x]++; } for(int i = 0; i <= (1<<12); i++){ for(int j = 0; j <= (1<<12); j++){ if(!cnt[j])continue; int tmp = 0; for(int k = 0; k <= 12; k++){ if((i&(1<<k))==(j&(1<<k)))tmp+=w[n-k]; if(tmp > 100) break; } if(tmp <= 100) sum[i][tmp] += cnt[j]; } } for(int i = 1; i <= q; i++){ char s[20];int c; scanf("%s %d", s, &c); int x = 0; for(int j = 0 ; j < n; j++) if(s[j]=='1')x += 1<<(n-j-1); int ans = 0; for(int j = 0; j <= c; j++)ans+=sum[x][j]; printf("%d ", ans); } return 0; } /* */