看到题目肯定首先想到要求LCA(其实是我菜),可乍一看,n与q的规模为5W,
求LCA的复杂度为(O(logN)),那么总时间复杂度为(O(nq log n))。
怎么搞呢?
会树上差分的都知道,要对一条链进行操作,比如说链上的节点权值(+p),就要对两个端点分别(+p),然后对(LCA)及其父亲分别(-p)。
和这个思想差不多,设两个点为(u, v),那么求(dep(LCA(u,v)))只需要把((root,u))之前的路径所有点权值(+1),然后统计((root,v))路径上的权值和就行了,具体请自行脑补我觉得很好理解啊,这个过程显然可以用树剖来完成。
然后就是差分,注意不是树上差分,(sum_{i=l}^{r}LCA(i,z)=sum_{i=1}^{r}LCA(i,z)-sum_{i=1}^{l-1}LCA(i,z))
考虑离线,设每个询问为((l,r,z)),定义一个结构体,包含(x)(端点),(k)(-1为左端点,1为右端点),(z)(要求LCA的点),(id)(是哪个询问),用2个空间分别存储左端点(l)和右端点(r),然后按端点编号排序。此时存放询问的数组应该是有(2q)个的。定义一个指针(cur),表示当前已经把前(cur)个点到根的路径(+1)了。枚举(i(1~2q)),如果是左端点,把(cur)更新到(x-1),((while(cur<x-1) Updata(1,++cur);)),把(ans[id])减去(z)到根的权值和,相当于(-sum_{i=1}^{l-1}LCA(i,z))。如果是右端点,同理,把(cur)更新到(x),然后(ans[id])加上...这题就愉快地切掉了。时间复杂度(O(n log^2 n))
手写树剖一遍过,好激动
#include <cstdio>
#include <algorithm>
using std::swap;
using std::sort;
#define left (now << 1)
#define right (now << 1 | 1)
const int MAXN = 50010;
const int MOD = 201314;
struct Edge{
int next, to;
}e[MAXN];
int head[MAXN], num, n, q, a, b, c;
int fa[MAXN], maxson[MAXN], size[MAXN], top[MAXN], dfsx[MAXN], ID, pos[MAXN], dep[MAXN], ans[MAXN];
inline void Add(int from, int to){
e[++num].to = to;
e[num].next = head[from];
head[from] = num;
}
int sum[MAXN << 2], lazy[MAXN << 2];
inline void pushdown(int now, int len){
if(lazy[now]){
sum[left] += lazy[now] * (len - (len >> 1));
sum[right] += lazy[now] * (len >> 1);
lazy[left] += lazy[now];
lazy[right] += lazy[now];
lazy[now] = 0;
}
}
inline void pushup(int now){
sum[now] = sum[left] + sum[right];
}
/*void Build(int now, int l, int r){
if(l != r){
int mid = (l + r) >> 1;
Build(left, l, mid);
Build(right, mid + 1, r);
pushup(now);
}
else sum[now] = a[l];
}*/
void Change(int now, int l, int r, int wl, int wr, int p){
if(l > wr || r < wl) return ;
if(l >= wl && r <= wr){
sum[now] += p * (r - l + 1); lazy[now] += p; return ;
}
pushdown(now, r - l + 1);
int mid = (l + r) >> 1;
Change(left, l, mid, wl, wr, p);
Change(right, mid + 1, r, wl, wr, p);
pushup(now);
}
int Query(int now, int l, int r, int wl, int wr){
if(l > wr || r < wl) return 0;
if(l >= wl && r <= wr) return sum[now];
pushdown(now, r - l + 1);
int ans = 0;
int mid = (l + r) >> 1;
ans += Query(left, l, mid, wl, wr);
ans += Query(right, mid + 1, r, wl, wr);
return ans;
}
struct ASK{
int x, z, k, id;
bool operator < (const ASK &A) const{ //按照端点排序
return x == A.x ? k < A.k : x < A.x;
}
}Ans[MAXN << 1];
void dfs1(int u){
dep[u] = dep[fa[u]] + 1;
size[u] = 1;
for(int i = head[u]; i; i = e[i].next){
dfs1(e[i].to);
size[u] += size[e[i].to];
if(size[e[i].to] > size[maxson[u]])
maxson[u] = e[i].to;
}
}
void dfs2(int u, int t){
dfsx[++ID] = u;
top[u] = t;
if(maxson[u]) dfs2(maxson[u], t);
for(int i = head[u]; i; i = e[i].next){
if(e[i].to != maxson[u])
dfs2(e[i].to, e[i].to);
}
}
void Updata(int u, int v, int p){
int x = top[u], y = top[v];
while(x != y){
if(dep[x] > dep[y]) swap(x, y), swap(u, v);
Change(1, 1, n, pos[y], pos[v], p);
u = fa[x]; x = top[u];
v = fa[y]; y = top[v];
}
if(dep[u] > dep[v]) swap(u, v);
Change(1, 1, n, pos[u], pos[v], p);
}
int Solve(int u, int v){
int ans = 0;
int x = top[u], y = top[v];
while(x != y){
if(dep[x] > dep[y]) swap(x, y), swap(u, v);
ans += Query(1, 1, n, pos[y], pos[v]);
u = fa[x]; x = top[u];
v = fa[y]; y = top[v];
}
if(dep[u] > dep[v]) swap(u, v);
ans += Query(1, 1, n, pos[u], pos[v]);
return ans;
}
int cur;
int main(){
scanf("%d%d", &n, &q);
for(int i = 2; i <= n; ++i)
scanf("%d", &fa[i]), Add(fa[i] += 1, i);
for(int i = 1; i <= q; ++i){
scanf("%d%d%d", &a, &b, &c);
Ans[i].x = a + 1;
Ans[i + q].x = b + 1;
Ans[i].z = Ans[i + q].z = c + 1;
Ans[i].k = -1;
Ans[i + q].k = 1;
Ans[i].id = Ans[i + q].id = i;
}
dfs1(1);
dfs2(1, 1);
for(int i = 1; i <= n; ++i)
pos[dfsx[i]] = i;
sort(Ans + 1, Ans + q * 2 + 1);
for(int i = 1; i <= 2 * q; ++i){
if(Ans[i].k < 0) //是左端点
while(cur < Ans[i].x - 1)
Updata(1, ++cur, 1);
else //是右端点
while(cur < Ans[i].x)
Updata(1, ++cur, 1);
ans[Ans[i].id] += Ans[i].k * Solve(1, Ans[i].z);
}
for(int i = 1; i <= q; ++i)
printf("%d
", (ans[i] + MOD) % MOD);
return 0;
}