题解:
会了消耗战之后,这题的难点就只在统计ans了。
我原来求最长(短)链只会保留次优值,然后开三个数组写得特别麻烦。。。
今天学习了,orz POPOQQQ
inline void dfs(int x) { f[x]=v[x];g[x]=0; mi[x]=v[x]?0:inf; mx[x]=v[x]?0:-inf; for4(i,x) { dfs(y); ans+=(g[x]+f[x]*e[i].w)*f[y]+g[y]*f[x]; f[x]+=f[y]; g[x]+=g[y]+(ll)e[i].w*f[y]; ans1=min(ans1,mi[x]+mi[y]+e[i].w); ans2=max(ans2,mx[x]+mx[y]+e[i].w); mi[x]=min(mi[x],mi[y]+e[i].w); mx[x]=max(mx[x],mx[y]+e[i].w); } head[x]=0; }
改为直接枚举最长链的端点的lca,并且充分利用mi,mx数组的“前缀和”性质。orzzzzzzzzzzzzz
然后总和的统计也很简单,分两部分算贡献即可。
代码:
1 #include<cstdio> 2 #include<cstdlib> 3 #include<cmath> 4 #include<cstring> 5 #include<algorithm> 6 #include<iostream> 7 #include<vector> 8 #include<map> 9 #include<set> 10 #include<queue> 11 #include<string> 12 #define inf 1000000000 13 #define maxn 1000000+5 14 #define maxm 100000+5 15 #define eps 1e-10 16 #define ll long long 17 #define pa pair<int,int> 18 #define for0(i,n) for(int i=0;i<=(n);i++) 19 #define for1(i,n) for(int i=1;i<=(n);i++) 20 #define for2(i,x,y) for(int i=(x);i<=(y);i++) 21 #define for3(i,x,y) for(int i=(x);i>=(y);i--) 22 #define for4(i,x) for(int i=head[x],y=e[i].go;i;i=e[i].next,y=e[i].go) 23 #define mod 1000000007 24 using namespace std; 25 inline int read() 26 { 27 int x=0,f=1;char ch=getchar(); 28 while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} 29 while(ch>='0'&&ch<='9'){x=10*x+ch-'0';ch=getchar();} 30 return x*f; 31 } 32 int n,m,id[maxn],cnt,dep[maxn],fa[maxn][21],f[maxn],mi[maxn],mx[maxn]; 33 ll g[maxn]; 34 bool v[maxn]; 35 int ans1,ans2; 36 ll ans; 37 inline int lca(int x,int y) 38 { 39 if(dep[x]<dep[y])swap(x,y); 40 int t=dep[x]-dep[y]; 41 for0(i,20)if(t&(1<<i))x=fa[x][i]; 42 if(x==y)return x; 43 for3(i,20,0)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i]; 44 return fa[x][0]; 45 } 46 struct graph 47 { 48 int head[maxn],tot; 49 struct edge{int go,next,w;}e[2*maxn]; 50 inline void add(int x,int y) 51 { 52 e[++tot]=(edge){y,head[x],0};head[x]=tot; 53 e[++tot]=(edge){x,head[y],0};head[y]=tot; 54 } 55 inline void addd(int x,int y) 56 { 57 e[++tot]=(edge){y,head[x],dep[y]-dep[x]};head[x]=tot; 58 } 59 inline void dfs(int x,int f) 60 { 61 id[x]=++cnt; 62 for1(i,20) 63 if(dep[x]>=1<<i)fa[x][i]=fa[fa[x][i-1]][i-1]; 64 else break; 65 for4(i,x)if(y!=f) 66 { 67 dep[y]=dep[x]+1;fa[y][0]=x; 68 dfs(y,x); 69 } 70 } 71 inline void dfs(int x) 72 { 73 f[x]=v[x];g[x]=0; 74 mi[x]=v[x]?0:inf; 75 mx[x]=v[x]?0:-inf; 76 for4(i,x) 77 { 78 dfs(y); 79 ans+=(g[x]+f[x]*e[i].w)*f[y]+g[y]*f[x]; 80 f[x]+=f[y]; 81 g[x]+=g[y]+(ll)e[i].w*f[y]; 82 ans1=min(ans1,mi[x]+mi[y]+e[i].w); 83 ans2=max(ans2,mx[x]+mx[y]+e[i].w); 84 mi[x]=min(mi[x],mi[y]+e[i].w); 85 mx[x]=max(mx[x],mx[y]+e[i].w); 86 } 87 head[x]=0; 88 } 89 }G1,G2; 90 int a[maxn],sta[maxn],top; 91 inline bool cmp(int x,int y){return id[x]<id[y];} 92 int main() 93 { 94 freopen("input.txt","r",stdin); 95 freopen("output.txt","w",stdout); 96 n=read(); 97 for1(i,n-1)G1.add(read(),read()); 98 G1.dfs(1,0); 99 int T=read(); 100 while(T--) 101 { 102 m=read();ans=0;ans1=inf;ans2=-inf; 103 for1(i,m)a[i]=read(); 104 sort(a+1,a+m+1,cmp); 105 for1(i,m)v[a[i]]=1; 106 sta[top=1]=1;G2.tot=0; 107 for1(i,m) 108 { 109 int x=a[i],f=lca(x,sta[top]); 110 while(dep[f]<dep[sta[top]]) 111 { 112 if(dep[f]>=dep[sta[top-1]]) 113 { 114 G2.addd(f,sta[top--]); 115 if(sta[top]!=f)sta[++top]=f; 116 break; 117 } 118 G2.addd(sta[top-1],sta[top]);top--; 119 } 120 if(sta[top]!=x)sta[++top]=x; 121 } 122 while(--top)G2.addd(sta[top],sta[top+1]); 123 G2.dfs(1); 124 printf("%lld %d %d ",ans,ans1,ans2); 125 for1(i,m)v[a[i]]=0; 126 } 127 return 0; 128 }