VII.[HNOI2015]开店
首先,第一种方法便是动态点分治。
我们先考虑忽略年龄限制的情形。
我们考虑正常求一个点到另一个点的距离应该怎么求——
一般来说,我们会用\(dis(i,j)=dep_i+dep_j-2*dep_{lca(i,j)}\)对吧?
这个东西相当于将路径划分成两个部分,其中每个部分的长度都易于求出。上面我们采取了\(lca(i,j)\)作为分割点。
那如果我们在点分树上求路径长度,又该如何呢?
我们或许还是能自然地想到\(dep_i+dep_j-2*dep_{lca(i,j)}\),其中\(dep_i\)是\(i\)到点分树的根的距离,而\(lca\)是点分树上的最近公共祖先。
但是很明显,这是错的——点分树上的父子关系极松,这意味着不一定有\(dis(i,j)=dep_j-dep_i\),其中\(i\)是\(j\)的祖先。
那怎么办呢?
这父子关系再松,有一点也是满足的——两点在点分树上的lca,一定在原树上两点间路径上。
换言之,必有
\(dis(i,j)=dis\Big(i,lca(i,j)\Big)+dis\Big(j,lca(i,j)\Big)\),其中\(lca(i,j)\)为点分树上lca,而\(dis\)为原树上距离。
原树上的距离,我们可以直接用ST表在\(O(1)\)时间内求出。故这个东西可以转到点分树上求出。
我们要求
即
如果我们换成枚举\(lca(i,x)\),则有
注意到我们这里出现了两个东西:\(cnt\)和\(sum\)。其中,\(cnt\)意为子树中所有 \(lca(x,j)=i\) 的\(j\)的数量,而\(sum\)意为子树中所有 \(lca(x,j)=i\) 的\(dis(i,j)\)之和。
显然,如果\(lca(x,j)=i\),它们只要满足来自\(i\)在点分树中的不同子树即可。即:\(i\)在点分树上的子树,挖去\(x\)所在的那颗子树即可。
我们可以预处理出\(i\)的所有儿子的数量,以及它们到\(i\)的距离和。我们用一个\(vecsf\)维护。然后,在每个节点处,维护子树中所有节点数量以及它们到父亲的距离和(方便在父亲处减掉这些东西),记为\(vecfa\)。
则我们只需要不断跳父亲,然后从\(vecsf\)与\(vecfa\)中对应加减即可。
现在有了年龄限制,那有如何呢?
好办。之前的\(vecsf\)与\(vecfa\)可以只用一个值表示即可,我们现在把它换成vector
。在vector
内部按照颜色排序,并作前缀和。最终只要在vector
中二分即可回答询问。
但因为实现的不好,它MLE了。
代码:
#pragma GCC optimize(3)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,lim,val[100100],dep[100100],mn[200100][20],in[100100],LG[200100],tot,fa[100100],admin;
ll las;
namespace Tree{
int sz[100100],SZ,msz[100100],ROOT,head[100100],cnt;
struct node{
int to,next,val;
}edge[200100];
void ae(int u,int v,int w){
edge[cnt].next=head[u],edge[cnt].to=v,edge[cnt].val=w,head[u]=cnt++;
edge[cnt].next=head[v],edge[cnt].to=u,edge[cnt].val=w,head[v]=cnt++;
}
bool vis[100100];
void getsz(int x,int fa){
sz[x]=1;
for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getsz(edge[i].to,x),sz[x]+=sz[edge[i].to];
}
void getroot(int x,int fa){
sz[x]=1,msz[x]=0;
for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getroot(edge[i].to,x),sz[x]+=sz[edge[i].to],msz[x]=max(msz[x],sz[edge[i].to]);
msz[x]=max(msz[x],SZ-sz[x]);
if(msz[x]<msz[ROOT])ROOT=x;
}
void solve(int x){
getsz(x,0);
vis[x]=true;
for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])ROOT=0,SZ=sz[edge[i].to],getroot(edge[i].to,0),fa[ROOT]=x,solve(ROOT);
}
void getural(int x,int fa){
mn[++tot][0]=x,in[x]=tot;
for(int i=head[x];i!=-1;i=edge[i].next)if(edge[i].to!=fa)dep[edge[i].to]=dep[x]+edge[i].val,getural(edge[i].to,x),mn[++tot][0]=x;
}
}
int MIN(int i,int j){
return dep[i]<dep[j]?i:j;
}
int LCA(int i,int j){
i=in[i],j=in[j];
if(i>j)swap(i,j);
int k=LG[j-i+1];
return MIN(mn[i][k],mn[j-(1<<k)+1][k]);
}
int DIS(int i,int j){
return dep[i]+dep[j]-dep[LCA(i,j)]*2;
}
namespace cdt{
vector<int>v[100100];
vector<pair<int,ll> >vecfa[100100],vecsf[100100];
void prepvec(int x,int z){
if(fa[z])vecfa[z].push_back(make_pair(val[x],DIS(x,fa[z])));
vecsf[z].push_back(make_pair(val[x],DIS(x,z)));
for(auto y:v[x])prepvec(y,z);
}
ll calc(int x,int L,int R){
ll res=0;
int u=x;
while(x){
int l=lower_bound(vecsf[x].begin(),vecsf[x].end(),make_pair(L,-1ll))-vecsf[x].begin()-1;
int r=upper_bound(vecsf[x].begin(),vecsf[x].end(),make_pair(R,0x3f3f3f3f3f3f3f3fll))-vecsf[x].begin()-1;
res+=vecsf[x][r].second-vecsf[x][l].second;
res+=1ll*DIS(u,x)*(r-l);
if(!fa[x])break;
l=lower_bound(vecfa[x].begin(),vecfa[x].end(),make_pair(L,-1ll))-vecfa[x].begin()-1;
r=upper_bound(vecfa[x].begin(),vecfa[x].end(),make_pair(R,0x3f3f3f3f3f3f3f3fll))-vecfa[x].begin()-1;
res-=vecfa[x][r].second-vecfa[x][l].second;
res-=1ll*DIS(u,fa[x])*(r-l);
x=fa[x];
}
return res;
}
void prepare(){
for(int i=1;i<=n;i++)if(fa[i])v[fa[i]].push_back(i);
for(int i=1;i<=n;i++){
prepvec(i,i);
vecfa[i].push_back(make_pair(-1,0)),vecsf[i].push_back(make_pair(-1,0));
sort(vecfa[i].begin(),vecfa[i].end());
sort(vecsf[i].begin(),vecsf[i].end());
for(int j=1;j<vecfa[i].size();j++)vecfa[i][j].second+=vecfa[i][j-1].second;
for(int j=1;j<vecsf[i].size();j++)vecsf[i][j].second+=vecsf[i][j-1].second;
}
}
}
void read(int &x){
x=0;
char c=getchar();
while(c>'9'||c<'0')c=getchar();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
}
int main(){
read(n),read(m),read(lim),memset(Tree::head,-1,sizeof(Tree::head));
for(int i=1;i<=n;i++)read(val[i]);
for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),Tree::ae(x,y,z);
Tree::msz[0]=n+1,Tree::SZ=n,Tree::getroot(1,0),admin=Tree::ROOT,Tree::solve(Tree::ROOT);
Tree::getural(1,0);
for(int i=2;i<=tot;i++)LG[i]=LG[i>>1]+1;
for(int j=1;j<=LG[tot];j++)for(int i=1;i+(1<<j)-1<=tot;i++)mn[i][j]=MIN(mn[i][j-1],mn[i+(1<<(j-1))][j-1]);
cdt::prepare();
for(int i=1,x,l,r;i<=m;i++){
read(x),read(l),read(r),l=(las+l)%lim,r=(las+r)%lim;
if(l>r)swap(l,r);
printf("%lld\n",las=cdt::calc(x,l,r));
}
return 0;
}
然后,第二种方法便是主席树+树剖。
仍然先忽略年龄限制,我们直接使用\(dis(i,j)=dep_i+dep_j-2*dep_{lca(i,j)}\),得到
化简得
前两个可以很轻松预处理出来,但是最后一部分呢?
考虑我们每个节点向上走一直走到根,在每条边上维护一个计数器,然后每次被经过了的边上的计数器就加\(1\)。然后询问的时候,从询问的点出发向上走,对于每条经过的边,答案加上(计数器大小*边权)即可。
这个很好理解,因为\(dep_{lca(i,x)}\),就等于所有既是\(i\)的祖先,又是\(x\)的祖先的边的边权和。刚才的操作的意义,就是求这些边的边权和。
如果我们用树剖的话,复杂度就是\(O(n\log^2 n)\)的。
但是,问题来了,加上边权限制怎么办?
回忆一下,在之前动态点分治的方法中,我们用前缀和实现了这一效果;而在这里,我们也可以用前缀和——只不过是对树剖时建的线段树作前缀和。线段树的前缀和,就是主席树咯。
具体而言,我们在建树的时候,将所有节点按照年龄排序后插入,这样询问时就可以直接在对应的主席树上减以下即可。
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define mid ((l+r)>>1)
int n,m,lim,a[150010],dis[150010],fa[150010],son[150010],dfn[150010],rev[150010],sz[150010],top[150010],head[150010],tot,cnt,cpt,rt[150010],num[150010],val[150010];
ll sum[150010],ans,add[150010];
struct Edge{
int to,next,val;
}edge[300100];
void ae(int u,int v,int w){
edge[cnt].next=head[u],edge[cnt].to=v,edge[cnt].val=w,head[u]=cnt++;
edge[cnt].next=head[v],edge[cnt].to=u,edge[cnt].val=w,head[v]=cnt++;
}
vector<int>ds,v[150010];
void dfs1(int x){
sz[x]=1;
for(int i=head[x],y;i!=-1;i=edge[i].next){
if((y=edge[i].to)==fa[x])continue;
fa[y]=x,dis[y]=dis[x]+edge[i].val,val[y]=edge[i].val;
dfs1(y);
sz[x]+=sz[y];
if(sz[y]>sz[son[x]])son[x]=y;
}
}
void dfs2(int x){
if(son[x])dfn[++tot]=son[x],rev[son[x]]=tot,top[son[x]]=top[x],dfs2(son[x]);
for(int i=head[x],y;i!=-1;i=edge[i].next){
y=edge[i].to;
if(y==fa[x]||y==son[x])continue;
dfn[++tot]=y,rev[y]=tot,top[y]=y,dfs2(edge[i].to);
}
}
struct SegTree{
int lson,rson,era,tag;
ll sum;
}seg[10001000];
void build(int &x,int l,int r){
x=++cpt;
if(l==r)return;
build(seg[x].lson,l,mid),build(seg[x].rson,mid+1,r);
}
void pushup(int x,int l,int r){
seg[x].sum=seg[seg[x].lson].sum+(sum[mid]-sum[l-1])*seg[seg[x].lson].tag+seg[seg[x].rson].sum+(sum[r]-sum[mid])*seg[seg[x].rson].tag;
}
void modify(int pre,int &x,int l,int r,int L,int R,int tim){
if(l>R||r<L)return;
if(seg[x].era!=tim)x=++cpt,seg[x]=seg[pre],seg[x].era=tim;
// printf("%d:(%d,%d):(%d,%d):%d\n",x,l,r,L,R,tim);
if(L<=l&&r<=R){seg[x].tag++;return;}
modify(seg[pre].lson,seg[x].lson,l,mid,L,R,tim),modify(seg[pre].rson,seg[x].rson,mid+1,r,L,R,tim),pushup(x,l,r);
}
ll query(int x,int l,int r,int L,int R,int tag){
if(!x||l>R||r<L)return 0;
tag+=seg[x].tag;
// printf("%d:(%d,%d):(%d,%d):%d\n",x,l,r,L,R,tag);
if(L<=l&&r<=R)return (sum[r]-sum[l-1])*tag+seg[x].sum;
return query(seg[x].lson,l,mid,L,R,tag)+query(seg[x].rson,mid+1,r,L,R,tag);
}
void initjump(int x){
int col=a[x];
while(x){
modify(rt[col-1],rt[col],1,n,rev[top[x]],rev[x],col);
x=fa[top[x]];
}
}
ll queryjump(int x,int L,int R){
ll res=0;
while(x){
res+=query(rt[R],1,n,rev[top[x]],rev[x],0)-query(rt[L-1],1,n,rev[top[x]],rev[x],0);
x=fa[top[x]];
}
// printf("%d\n",res);
return res;
}
int main(){
scanf("%d%d%d",&n,&m,&lim),memset(head,-1,sizeof(head));
for(int i=1;i<=n;i++)scanf("%d",&a[i]),ds.push_back(a[i]);
for(int i=1,x,y,z;i<n;i++)scanf("%d%d%d",&x,&y,&z),ae(x,y,z);
dfs1(1),top[1]=rev[1]=dfn[1]=tot=1,dfs2(1);
// for(int x=1;x<=n;x++)printf("%d::FA:%d SN:%d SZ:%d DN:%d RV:%d DS:%d TP:%d\n",x,fa[x],son[x],sz[x],dfn[x],rev[x],dis[x],top[x]);
sort(ds.begin(),ds.end()),ds.resize(unique(ds.begin(),ds.end())-ds.begin());
for(int i=1;i<=n;i++)a[i]=lower_bound(ds.begin(),ds.end(),a[i])-ds.begin()+1,v[a[i]].push_back(i),add[a[i]]+=dis[i],num[a[i]]++;
// for(int i=1;i<=n;i++)printf("(%d:%d)",i,a[i]);puts("");
for(int i=1;i<=n;i++)sum[i]=sum[i-1]+val[dfn[i]],add[i]+=add[i-1],num[i]+=num[i-1];
build(rt[0],1,n);
for(int i=1;i<=ds.size();i++){
rt[i]=++cpt,seg[rt[i]]=seg[rt[i-1]],seg[rt[i]].era=i;
for(auto x:v[i])initjump(x);
}
for(int i=1,x,l,r;i<=m;i++){
scanf("%d%d%d",&x,&l,&r);
l=(ans+l)%lim,r=(ans+r)%lim;
if(l>r)swap(l,r);
l=lower_bound(ds.begin(),ds.end(),l)-ds.begin()+1;
r=upper_bound(ds.begin(),ds.end(),r)-ds.begin();
// printf("%d %d %d\n",x,l,r);
if(l>r){printf("%lld\n",ans=0);continue;}
// printf("%d,%d\n",add[r]-add[l-1],num[r]-num[l-1]);
ans=(add[r]-add[l-1])+1ll*dis[x]*(num[r]-num[l-1])-2ll*queryjump(x,l,r);
printf("%lld\n",ans);
}
return 0;
}