题目链接:https://acm.hdu.edu.cn/showproblem.php?pid=6966
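题目大意:计算 \(\sum_{p=1}^{n}\sum_{k=1}^{+\infty} d_{p,k}\, c_p^{k}\)(答案对 \(10^9+7\) 取模),其中 \(d_{p,k}=\sum_{k=i\oplus j} a_i\, b_j\ \left(1\le i\le \left\lfloor\frac{n}{p}\right\rfloor,\ 1\le j\le \left\lfloor\frac{n}{p}\right\rfloor\right)\),\(\oplus\) 运算为三进制下按位取 \(\gcd\),即 \(k_t=\gcd(i_t,j_t)\)(\(k_t\) 表示 \(k\) 的三进制第 \(t\) 位)。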
题目思路:对每个 \(p\),直接对长度约为 \(\frac{n}{p}\) 的序列做 FWT 求出 \(d_{p,k}\),总时间复杂度为 \(\sum_{p=1}^{n} O\!\left(\frac{n}{p}\log\frac{n}{p}\right)=O(n\log^2 n)\)。
难点在于构造一个可逆的 \(3\times 3\) 变换矩阵
\[
\begin{pmatrix}
c(0,0) & c(0,1) & c(0,2) \\
c(1,0) & c(1,1) & c(1,2) \\
c(2,0) & c(2,1) & c(2,2)
\end{pmatrix}
\]
使得对任意 \(x,y,z\in\{0,1,2\}\) 都有 \(c(x,y)\,c(x,z) = c(x,y\oplus z) = c(x,\gcd(y,z))\)。
以第 \(x=0\) 行为例(每一行的约束形式都相同):由 \(c(0,0)\,c(0,0)=c(0,0)\) 得 \(c(0,0)=1\) 或 \(0\);同理每个元素都满足 \(c(x,y)^2=c(x,y)\),因此只能取 \(0\) 或 \(1\)。除了自己乘自己,其余需要满足的等式有
\[
\left\{
\begin{matrix}
c(0,0)\,c(0,1)=c(0,1)\\
c(0,0)\,c(0,2)=c(0,2)\\
c(0,1)\,c(0,2)=c(0,1)
\end{matrix}
\right.
\]
在 \(\{0,1\}^3\) 中直接暴力枚举,满足条件的 \((c(x,0),c(x,1),c(x,2))\) 只有四组:\((0,0,0),(1,0,0),(1,0,1),(1,1,1)\)。矩阵的每一行都必须取其中一组;为保证矩阵可逆,不能取全零行且三行互不相同,因此三行只能是剩下的三组,这里取矩阵
\[
\begin{pmatrix}
1&0&0\\
1&1&1\\
1&0&1
\end{pmatrix}
\]
及其逆
\[
\begin{pmatrix}
1&0&0\\
0&1&-1\\
-1&0&1
\end{pmatrix}
\]
AC代码:
#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;

// Short type aliases carried over from the contest template.
using PII = pair<int, int>;
using PDI = pair<double, int>;
using ll = long long;
using ull = unsigned long long;

// Template constants. In this solution only N (array capacity), mod and ll
// are actually referenced; the rest are kept for template compatibility.
const int INF = 0x3f3f3f3f;
const int N = 4000000 + 10;      // capacity of every global work array
const int M = 40000000 + 10;
const int base = 1000000000;
const int P = 131;
const int mod = 1000000000 + 7;  // all arithmetic is done modulo 1e9+7
const double eps = 1e-12;
const double PI = acos(-1.0);

// a, b, c: input sequences, read 1-indexed by main (slot 0 stays zero).
// Note: cal() has a parameter named c that shadows the array c here.
int a[N], b[N], c[N];
// aa, bb, cc: scratch buffers for the transforms performed in cal().
int aa[N], bb[N], cc[N];
// In-place transform for the digit-wise ternary gcd convolution.
//
// Per base-3 digit, gcd(0,t) = t, gcd(1,t) = 1 and gcd(2,2) = 2, i.e. gcd acts
// like max under the order 0 < 2 < 1. The forward transform is therefore a
// per-digit "prefix sum" in that order. For digit values (0, 1, 2) holding
// (x, y, z) it produces
//     (x, x + y + z, x + z)          when flag == 1,
// and any other flag applies the inverse
//     (X, Y - Z, Z - X).
// After the forward transform, the gcd-convolution becomes a pointwise product.
//
// n is expected to be a power of 3 (cal() always passes one). Entries are
// assumed to already lie in [0, mod); the single "+ mod" correction in the
// inverse relies on that.
void FWT(int a[], int n, int flag)
{
    const bool forward = (flag == 1);
    for (int step = 1; step < n; step *= 3)
    {
        const int block = step * 3;
        for (int start = 0; start < n; start += block)
        {
            // Three consecutive sub-blocks: current digit = 0, 1, 2.
            int *d0 = a + start;
            int *d1 = d0 + step;
            int *d2 = d1 + step;
            for (int k = 0; k < step; ++k)
            {
                const int v0 = d0[k];
                const int v1 = d1[k];
                const int v2 = d2[k];
                // d0[k] is left unchanged in both directions.
                if (forward)
                {
                    d1[k] = ((v0 + v1) % mod + v2) % mod;
                    d2[k] = (v0 + v2) % mod;
                }
                else
                {
                    d1[k] = (v1 - v2 + mod) % mod;
                    d2[k] = (v2 - v0 + mod) % mod;
                }
            }
        }
    }
}
int cal(int n, int c)
{
int tot = 1;
while (tot <= n)
tot *= 3;
for (int i = 0; i <= n; ++i)
{
aa[i] = a[i];
bb[i] = b[i];
}
for (int i = n + 1; i < tot; ++i)
{
aa[i] = 0;
bb[i] = 0;
}
FWT(aa, tot, 1), FWT(bb, tot, 1);
for (int i = 0; i < tot; ++i)
cc[i] = (ll)aa[i] * bb[i] % mod;
FWT(cc, tot, -1);
int res = 0, ck = c;
for (int k = 1; k < tot; ++k, ck = (ll)ck * c % mod)
res = (res + (ll)cc[k] * ck) % mod;
return res;
}
// Reads n followed by the sequences a[1..n], b[1..n] and c[1..n], then prints
//     sum_{p=1}^{n} cal(n / p, c[p])   (mod `mod`).
// Returns 1 if the input is truncated or malformed.
int main()
{
    int n;
    if (scanf("%d", &n) != 1)
        return 1;
    // All three sequences are stored 1-indexed; slot 0 keeps its zero value.
    for (int i = 1; i <= n; ++i)
        if (scanf("%d", &a[i]) != 1)
            return 1;
    for (int i = 1; i <= n; ++i)
        if (scanf("%d", &b[i]) != 1)
            return 1;
    for (int i = 1; i <= n; ++i)
        if (scanf("%d", &c[i]) != 1)
            return 1;

    int ans = 0;
    for (int p = 1; p <= n; ++p)
        ans = (ans + cal(n / p, c[p])) % mod;
    // The format string must contain the escape "\n", not a raw line break:
    // a newline inside a string literal does not compile.
    printf("%d\n", ans);
    return 0;
}