描述
两个序列 (x, y),可以将一个序列每个值同时加非负整数 (c),其中一个序列可以循环移位,要求最小化:
[sum_{i = 1}^{n}(x_i - y_i) ^ 2
]
题解
循环移位 (Leftrightarrow) 断环成链。显然那个序列循环移位不影响,而且强制加值在 (x) 上, (c) 可以为负整数(可以理解为如果是负数,则把这个的绝对值加到 (y) 上,差保持不变),不妨让 y 移位,将 y 数组复制一倍到末尾,由于循环移位,所以 (sum_{i=1}^{n} y_{j + i} = sum_{i=1}^{n} y_i)。
[ans = min{ sum_{i = 1}^{n} (x_i + c - y_{j + i}) ^ 2 }
]
把里面这个东西拿出来:
[sum_{i = 1}^{n} (x_i + c - y_{j + i}) ^ 2 = sum_{i=1}^{n}x_i ^ 2 + sum_{i=1}^{n}y_i^2 + nc^2 + 2csum(x_i - y_i) - 2 sum x_i y_{i + j}
]
要让这个式子尽量小:
- 前两项是定值
- 第 (3, 4) 项是一个关于 (c) 的开口向上二次函数,由于要求取整数,所以算对称轴,算一下最近的两个整数取最优值即可。
比较棘手的是最后一项 (sum x_i y_{i + j}) (可以忽略系数),感觉可以转化成卷积的形式,用套路性的反转序列试试看:
新建一个数组 (z),令 (z_{2n - i + 1} = y_i)
(sum x_i y_{i + j} = sum x_i z_{2n + 1 - i - j}),很显然的一个卷积,即从 (j) 位开始的答案记录在了 (2n - j + 1) 位的系数上。
Tips
- C++ 如果是负数整除会上取整,注意特判
时间复杂度
(O(Nlog_2N))
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long LL;
const int N = 2e5 + 5;
const double PI = acos(-1);
int n, m, c, lim = 1, len, rev[N], x[N], y[N];
LL ans = 0;
struct CP{
double x, y;
CP operator + (const CP &b) const { return (CP){ x + b.x, y + b.y }; }
CP operator - (const CP &b) const { return (CP){ x - b.x, y - b.y }; }
CP operator * (const CP &b) const { return (CP){ x * b.x - y * b.y, x * b.y + y * b.x }; }
} F[N], G[N];
void FFT(CP a[], int opt) {
for (int i = 0; i < lim; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int m = 1; m <= lim; m <<= 1) {
CP wn = (CP){ cos(2 * PI / m), opt * sin(2 * PI / m) };
for (int i = 0; i < lim; i += m) {
CP w = (CP){ 1, 0 };
for (int j = 0; j < (m >> 1); j++) {
CP u = a[i + j], t = w * a[i + j + (m >> 1)];
a[i + j] = u + t, a[i + j + (m >> 1)] = u - t;
w = w * wn;
}
}
}
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", x + i), ans += x[i] * x[i], c += x[i];
for (int i = 1; i <= n; i++) scanf("%d", y + i), y[i + n] = y[i], ans += y[i] * y[i], c -= y[i];
// x = l 是对称轴
int l = -c / n;
if (c > 0) l--;
ans += min(l * l * n + 2 * l * c, (l + 1) * (l + 1) * n + 2 * (l + 1) * c);
LL v = 0;
for (int i = 1; i <= n; i++) F[i].x = x[i];
for (int i = 1; i <= 2 * n; i++) G[i].x = y[2 * n - i + 1];
while (lim <= 2 * n) lim <<= 1, ++len;
for (int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
FFT(F, 1); FFT(G, 1);
for (int i = 0; i < lim; i++) F[i] = F[i] * G[i];
FFT(F, -1);
for (int i = n + 1; i <= 2 * n; i++) v = max(v, (LL)(F[i].x / lim + 0.5));
printf("%lld
", ans - 2 * v);
return 0;
}