题目传送门(内部题21)
输入格式
第一行一个字符串$str$,表示数据类型。
第二行一个正整数$k$,表示集合$K$的大小,保证$k>1$。
接下来$k$行每行$k$个数,第$i$行第$j$个数表示$ioplus j$的值,数据保证满足上文所提到的运算律。
接下来一行一个正整数$n$,表示城市的总数。
接下来$n-1$行,每行三个数$u,v,c$表示在城市$u$和城市$v$之间有一条边权为$c$的道路。
接下来一行$n$个数,第$i$个数表示城市$i$的敌对城市$x_i$,若$x_i=0$表示第$i$座城市非常热爱和平,否则保证有$x_{x_i}=i$。
输出格式
输出共一行,表示所有路径的旅者最终手上的数之和。
样例
样例输入:
F
2
0 1
1 0
4
1 2 1
2 3 1
2 4 0
0 0 4 3
样例输出:
6
数据范围与提示
题解
从$8$月$23$号打到$11$月$6$号,一有空就打,终于$AC$了(同桌疯狂吐槽我菜我也没办法……
怎么改出来的我已经不清楚了,根挤牙膏似的……
先转化一下题意:
在一棵$n$个节点的树上有一些坏点对,树上的边有边权,定义一种交换群运算,求所有不含坏点对的路径上的边权依次进行运算后的结果的和。
保证坏点对在以$1$为根的树上无父子关系。
由于题目保证了互相影响的点对一定没有父子关系,因此我们如果有了一条路径,只要$dfs$其中的一边判断另一边是否有这个点就行了。
因此我们对每个点$dfs$遍历每个轻儿子,因为求出了$dfs$序所以每个子树是一段连续的区间。我们$ban$掉一个点相当于$ban$掉了它的整个子树,因此我们访问到一个点就用线段树将这个区间的答案置$0$,因为逆元存在,对于每个点统计到所有可达的在已经访问的子树中的点到根的运算值即可。
原本的做法需要使用区间覆盖,考虑树的子树区间肯定只会互相包含或互不相交,实际上每次被删去的都是一些完整的区间。
考虑能否统计每个区间被减去了多少次。假设当前统计到节点$x$,与其对应的点为$u$,那么我们从$u$开始$dfs$,如果走到一个对应$x$的祖先的点就直接不访问,这样我们就能遍历到所有删去整个$x$的子树区间的点了,当然也可能$x$和它对应的点根本不会被整个删去,这意味着$x$到$lca(x,u)$的路径上存在一个对应$u$的祖先的点,这个可以随便维护一下。
我们用总的贡献减去$ban$掉的点对对答案的贡献,我们只要能够对于每个$x$求出这样的点对数目,最后计算的过程也可以用前面的$Theta(nk)$的做法优化。当然每对点只要选一边统计就行了。
问题可以转化为给两棵$n$个节点的树,每个标号$x$有一个$k$维向量的权$w_x$。对于所有的$x$,求出$sum w_y$,满足$x$在第一棵树上是$y$的祖先且$y$在第二棵树上是$x$的祖先,同时在$x$和$y$之间没有合法的$y$。
在第二棵树上$dfs$,只要在第一棵树上求出相同标号点子树的贡献即可。
对于新加入的点,贡献减去其子树中所有贡献,再将最近的父亲减去其贡献即可,可以树状数组维护。
注意要用树链剖分求$LCA$,而且读入要使用$AE86$。
不然就会$downarrow$
时间复杂度:$Theta(nklog n)$。
期望得分:$100$分。
实际得分:$96sim 100$分(看卡常了)。
代码时刻
#include<bits/stdc++.h>
using namespace std;
struct rec{int nxt,to,w;}e[500000];
int head[200001],cnt;
char opt[5];
int n,k;
int Map[129][129],pos[129][129],f[200001][129],g[200001][129];
int a[200001];
int dp[200001],dfn[200001],end[200001],depth[200001],t[200001],fa[200001],val[200001],wzc[129],qf[129],qg[129],tim;
int top[200001],size[200001],son[200001];
int tr[200001][129],trx[200001];
bool vis[200001];
long long ans;
namespace ae86{
const int bufl=1<<20;
char buf[bufl],*s=buf,*t=buf;
inline int fetch(){
if(s==t){t=(s=buf)+fread(buf,1,bufl,stdin);if(s==t)return EOF;}
return*s++;
}
inline int read(){
int a=0,b=1,c=fetch();
while(!isdigit(c))b^=c=='-',c=fetch();
while(isdigit(c))a=a*10+c-48,c=fetch();
return b?a:-a;
}
}
using ae86::read;
void add(int x,int y,int w)
{
e[++cnt].nxt=head[x];
e[cnt].to=y;
e[cnt].w=w;
head[x]=cnt;
}
void dfs(int x)
{
dfn[x]=++tim;
t[tim]=x;
size[x]=1;
for(int i=head[x];i;i=e[i].nxt)
{
if(depth[e[i].to])continue;
fa[e[i].to]=x;
dp[e[i].to]=Map[dp[x]][e[i].w];
depth[e[i].to]=depth[x]+1;
dfs(e[i].to);
for(int j=0;j<=k;j++)f[x][Map[e[i].w][j]]+=f[e[i].to][j];
f[x][e[i].w]++;val[e[i].to]=e[i].w;
size[x]+=size[e[i].to];
if(size[e[i].to]>size[son[x]])son[x]=e[i].to;
}
end[x]=tim;
}
int LCA(int x,int y)
{
while(top[x]!=top[y])
depth[top[x]]>depth[top[y]]?x=fa[top[x]]:y=fa[top[y]];
return depth[x]<depth[y]?x:y;
}
void dfs(int x,int topp)
{
top[x]=topp;
if(son[x])dfs(son[x],topp);
for(int i=head[x];i;i=e[i].nxt)if(!top[e[i].to])dfs(e[i].to,e[i].to);
}
void add(int x,int res)
{
for(int i=x;i<=n;i+=i&-i)
for(int j=0;j<=k;j++)
tr[i][j]+=g[res][j];
}
void del(int x,int res)
{
for(int i=x;i<=n;i+=i&-i)
for(int j=0;j<=k;j++)
tr[i][j]-=g[res][j];
}
void add(int x){for(int i=x;i<=n;i+=i&-i)trx[i]++;}
void del(int x){for(int i=x;i<=n;i+=i&-i)trx[i]--;}
int ask(int x){int res=0;for(int i=x;i;i-=i&-i)res+=trx[i];return res;}
void askf(int x)
{
for(int i=x;i;i-=i&-i)
for(int j=0;j<=k;j++)
wzc[j]+=tr[i][j];
}
void askr(int x)
{
for(int i=x;i;i-=i&-i)
for(int j=0;j<=k;j++)
wzc[j]-=tr[i][j];
}
void get(int res1,int res2,int x)
{
int fail1=0,fail2=0;
int res=Map[dp[x]][dp[x]];
for(int i=0;i<=k;i++)
{
if(f[res1][i])qf[++fail1]=i;
if(g[res2][i])qg[++fail2]=i;
}
for(int i=1;i<=fail1;i++)
for(int j=1;j<=fail2;j++)
ans-=1LL*f[res1][qf[i]]*g[res2][qg[j]]*pos[Map[qf[i]][qg[j]]][res];
}
void get(int x)
{
int y=a[x];
if(y)
{
if(ask(dfn[y]))vis[x]=1;
else
{
memset(wzc,0,sizeof(wzc));
askf(end[y]);
askr(dfn[y]-1);
for(int i=0;i<=k;i++)g[y][i]=f[y][i]-wzc[i];
get(x,y,LCA(x,y));
add(dfn[y],y);
del(end[y]+1);
add(dfn[y]);
}
}
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=fa[x])get(e[i].to);
if(!vis[x]&&a[x])
{
del(dfn[a[x]],a[x]);
add(end[a[x]]+1);
del(dfn[a[x]]);
}
}
int main()
{
scanf("%s",opt+1);
k=read()-1;
int now=-1;
for(int i=0;i<=k;i++)
for(int j=0;j<=k;j++)
{
Map[i][j]=read();
pos[Map[i][j]][i]=j;
}
for(int i=0;i<=k;i++)
{
now=i;
for(int j=0;j<=k;j++)
if(Map[i][j]!=j){now=-1;break;}
if(now>=0)break;
}
n=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read(),c=read();
add(u,v,c);
add(v,u,c);
}
for(int i=1;i<=n;i++)a[i]=read();
dp[1]=now;
depth[1]=1;
dfs(1);dfs(1,1);
for(int i=0;i<=k;i++)
{
g[1][i]=f[1][i];
ans+=1LL*g[1][i]*i;
}
for(int i=2;i<=n;i++)
{
int x=t[i];
int fat=fa[x];
int v=val[x];
for(int j=0;j<=k;j++)g[fat][Map[v][j]]-=f[x][j];
g[fat][v]--;
for(int j=0;j<=k;j++)g[x][j]=f[x][j]+g[fat][pos[j][v]];
g[x][v]++;
for(int j=0;j<=k;j++){g[fat][Map[v][j]]+=f[x][j];ans+=1LL*g[x][j]*j;}
g[fat][v]++;
}
for(int i=2;i<=n;i++)
{
f[i][now]++;
for(int j=0;j<=k;j++)g[i][Map[dp[i]][j]]=f[i][j];
for(int j=0;j<=k;j++)f[i][j]=g[i][j];
}
get(1);
printf("%lld",ans);
return 0;
}
rp++