•参考资料
•题意
给出一棵树,根节点为1。
每条边有一个权值,树上有红色结点 m 个,其花费为 0 ,其余为黑色;
每个黑色结点的花费为其到最近红色祖先的经过的路径权值之和。
有 q 次询问,每次给出一个点集;
问将树上任意一个结点涂成红色结点后,点集中所有点的花费的最大值的最小是多少。
•题解
相关变量解释:
sum : 每次询问中询问的点集个数
a[ ] : 存储每次询问到的点集
costR[i] : 结点 i 距其最近红色祖先的花费
预处理每个点到根的距离cost、到最近红色祖先的距离 costR 和 ST 表。
对于每次询问,将a[ ] 按 costR 从大到小排序,在 0~costR[a[0]] 范围内二分答案;
对所有大于答案的点求它们的公共祖先(利用ST表可以O(1)求两点的公共祖先),将其涂红;
之后计算每个大于答案的点的新花费是否小于答案。
•Code
View Code1 #include<iostream> 2 #include<vector> 3 #include<cstdio> 4 #include<cmath> 5 #include<algorithm> 6 #include<cstring> 7 using namespace std; 8 #define pb push_back 9 #define ll long long 10 #define mem(a,b) (memset(a,b,sizeof a)) 11 const int maxn=1e5+50; 12 13 int n,m,q; 14 //===============Restore Graph============ 15 struct Node 16 { 17 int to; 18 ll w; 19 Node(int to,int w):to(to),w(w){} 20 }; 21 vector<Node >G[maxn]; 22 void addEdge(int u,int v,int w) 23 { 24 G[u].pb(Node(v,w)); 25 G[v].pb(Node(u,w)); 26 } 27 //========================================= 28 int vs[2*maxn];//欧拉序列,范围区间为 [1,total] 29 int depth[2*maxn];//欧拉序列对应的深度序列 30 int pos[maxn];//pos[i] : 结点 i 再欧拉序列中第一次出现的位置 31 ll cost[maxn];//cost[i] : 结点 i 距根据点的距离 32 ll costR[maxn];//costR[i] : 结点 i 距最近红色祖先结点的距离,初始化为 -1 33 int total;//欧拉序列的大小 34 void dfs(int u,int f,int dep,ll dis) 35 { 36 vs[++total]=u; 37 depth[total]=dep; 38 pos[u]=total; 39 cost[u]=dis; 40 for(int i=0;i < G[u].size();++i) 41 { 42 Node e=G[u][i]; 43 if (e.to == f) 44 continue; 45 costR[e.to]=(costR[e.to] == 0 ? 0:costR[u]+e.w); 46 dfs(e.to,u,dep+1,dis+e.w); 47 vs[++total]=u; 48 depth[total]=dep; 49 } 50 } 51 //==================RMQ====================== 52 struct Node2 53 { 54 int mm[2 * maxn]; 55 int dp[2 * maxn][20]; 56 void ST() 57 { 58 int n=total; 59 mm[0] = -1; 60 for (int i = 1; i <= n; i++) 61 { 62 mm[i]=((i&(i-1))==0) ? mm[i - 1] + 1:mm[i - 1]; 63 dp[i][0]=i; 64 } 65 for (int j=1;j <= mm[n];j++) 66 for (int i=1;i+(1<<j)-1 <= n;i++) 67 if(depth[dp[i][j - 1]] < depth[dp[i+(1<<(j-1))][j-1]]) 68 dp[i][j]=dp[i][j-1]; 69 else 70 dp[i][j]=dp[i+(1<<(j-1))][j-1]; 71 } 72 int Lca(int u, int v) 73 { 74 u=pos[u],v=pos[v]; 75 if (u > v) 76 swap(u, v); 77 int k = mm[v-u+1]; 78 if(depth[dp[u][k]] <= depth[dp[v-(1<<k)+1][k]]) 79 return vs[dp[u][k]]; 80 return vs[dp[v-(1<<k)+1][k]]; 81 } 82 }_rmq; 83 //========================================== 84 int a[maxn]; 85 int sum; 86 bool cmp(int a, int b) 87 { 88 return costR[a] > costR[b]; 89 } 90 bool Check(ll x) 91 { 92 if(costR[a[0]] <= x) 93 return true; 94 int lca=a[0]; 95 for(int i=1;i < sum;i++) 96 { 97 if(costR[a[i]] <= x) 98 break; 99 lca=_rmq.Lca(lca,a[i]); 100 } 101 for(int i = 0;i < sum;i++) 102 { 103 if(costR[a[i]] <= x) 104 return true; 105 if(cost[a[i]]-cost[lca] > x) 106 return false; 107 } 108 return true; 109 } 110 void Solve() 111 { 112 dfs(1,-1,0,0); 113 _rmq.ST(); 114 while(q--) 115 { 116 scanf("%d",&sum); 117 for (int i=0;i < sum; i++) 118 scanf("%d",&a[i]); 119 sort(a,a+sum,cmp); 120 ll l=0,r=costR[a[0]]; 121 while(l < r) 122 { 123 ll mid=(l+r)/2; 124 if(Check(mid)) 125 r=mid; 126 else 127 l=mid + 1; 128 } 129 printf("%lld ",l); 130 } 131 } 132 void init() 133 { 134 mem(costR,-1); 135 total=0; 136 for(int i=0;i < maxn;++i) 137 G[i].clear(); 138 } 139 int main() 140 { 141 int t; 142 scanf("%d", &t); 143 while(t--) 144 { 145 init(); 146 scanf("%d%d%d",&n,&m,&q); 147 while(m--) 148 { 149 int red; 150 scanf("%d",&red); 151 costR[red]=0; 152 } 153 costR[1]=0; 154 for(int i=1;i<n;i++) 155 { 156 int u,v,w; 157 scanf("%d%d%d",&u,&v,&w); 158 addEdge(u,v,w); 159 } 160 Solve(); 161 } 162 return 0; 163 }
•出现的问题
1、用 vector 存储图比用 链式前向星存储图要慢
(1)vector :
(2)链式前向星:
2、平常一直在用的RMQ会超时
TLE1 //=====================RMQ=================== 2 struct Node1 3 { 4 int dp[20][2*maxn]; 5 void Preset() 6 { 7 for(int i=0;i < 2*maxn;++i) 8 dp[0][i]=i; 9 } 10 void ST() 11 { 12 int k=log(total)/log(2); 13 for(int i=1;i <= k;++i) 14 for(int j=1;j <= (total-(1<<i)+1);++j) 15 if(depth[dp[i-1][j]] > depth[dp[i-1][j+(1<<(i-1))]]) 16 dp[i][j]=dp[i-1][j+(1<<(i-1))]; 17 else 18 dp[i][j]=dp[i-1][j]; 19 } 20 int Lca(int u,int v) 21 { 22 u=pos[u],v=pos[v]; 23 if(u > v) 24 swap(u,v); 25 int k=log(v-u+1)/log(2); 26 if(depth[dp[k][u]] > depth[dp[k][v-(1<<k)+1]]) 27 return vs[dp[k][v-(1<<k)+1]]; 28 return vs[dp[k][u]]; 29 } 30 }_rmq; 31 //===========================================AC1 //==================RMQ====================== 2 struct Node2 3 { 4 int mm[2 * maxn]; 5 int dp[2 * maxn][20]; 6 void ST() 7 { 8 int n=total; 9 mm[0] = -1; 10 for (int i = 1; i <= n; i++) 11 { 12 mm[i]=((i&(i-1))==0) ? mm[i - 1] + 1:mm[i - 1]; 13 dp[i][0]=i; 14 } 15 for (int j=1;j <= mm[n];j++) 16 for (int i=1;i+(1<<j)-1 <= n;i++) 17 if(depth[dp[i][j - 1]] < depth[dp[i+(1<<(j-1))][j-1]]) 18 dp[i][j]=dp[i][j-1]; 19 else 20 dp[i][j]=dp[i+(1<<(j-1))][j-1]; 21 } 22 int Lca(int u, int v) 23 { 24 u=pos[u],v=pos[v]; 25 if (u > v) 26 swap(u, v); 27 int k = mm[v-u+1]; 28 if(depth[dp[u][k]] <= depth[dp[v-(1<<k)+1][k]]) 29 return vs[dp[u][k]]; 30 return vs[dp[v-(1<<k)+1][k]]; 31 } 32 }_rmq; 33 //==========================================3、cost[ ] 很有用,如果 Check( ) 中不加
if(cost[a[i]]-cost[lca] > x)
return false;会返回 WA,具体为什么,明天再好好想想%%%%%%%%%
分割线:2019.5.8
中石油的这场重现赛又让我回想起了这道题留下的疑惑;
现在再想想这道题,思路清晰了些许;
一些不理解的地方瞬间顿悟了;
ST表处理RMQ中,会多次求解 log2(x),这种算式是比较耗时的,我们预处理出所需的log2(x);
logTwo[i]=log2(i);如何预处理呢?
首先想一下,三位数的二进制数的最大值为 111(2),四位数的二进制数的最小值为 1000(2);
两者的关系是 (111)&(1000) = 0 , 而对于任意三位二进制数 x,y ,(x&y) != 0;
有了这个关系后,就可以这么预处理了:
logTwo[0]=-1; for(int i=1;i <= n;++i) logTwo[i]=(i&(i-1)) == 0 ? logTwo[i-1]+1:logTwo[i-1];这就是之前一直不理解的ST表加速的地方;
•Code
View Code1 #include<bits/stdc++.h> 2 using namespace std; 3 #define ll long long 4 #define mem(a,b) memset(a,b,sizeof(a)) 5 #define INFll 0x3f3f3f3f3f3f3f3f 6 const int maxn=1e5+50; 7 8 int n,m,q; 9 ll C[maxn];///C[i]:节点i到根节点1的花费 10 ll CR[maxn];///CR[i]:节点i到其最近的红色祖先节点的花费 11 int num; 12 int head[maxn]; 13 struct Edge 14 { 15 int to; 16 ll w; 17 int next; 18 }G[maxn<<1]; 19 void addEdge(int u,int v,ll w) 20 { 21 G[num]={v,w,head[u]}; 22 head[u]=num++; 23 } 24 struct LCA 25 { 26 int vs[maxn<<1];///欧拉序列 27 int dep[maxn<<1];///欧拉序列中的节点对应的深度序列 28 int pos[maxn<<1];///pos[i]:节点i在欧拉序列中第一次出现的位置 29 int cnt; 30 int logTwo[maxn<<1];///logTwo[i]:log2(i) 31 int dp[maxn<<1][20];///dp[i][j]:[i,i+2^j-1]深度最小的点的下标(欧拉序列中的下标) 32 void DFS(int u,int f,int depth,ll dist) 33 { 34 vs[++cnt]=u; 35 dep[cnt]=depth; 36 pos[u]=cnt; 37 C[u]=dist; 38 for(int i=head[u];~i;i=G[i].next) 39 { 40 int v=G[i].to; 41 ll w=G[i].w; 42 if(v == f) 43 continue; 44 CR[v]=min(CR[v],CR[u]+w); 45 DFS(v,u,depth+1,dist+w); 46 vs[++cnt]=u; 47 dep[cnt]=depth; 48 } 49 } 50 void ST() 51 { 52 logTwo[0]=-1; 53 for(int i=1;i <= cnt;++i) 54 { 55 dp[i][0]=i; 56 ///:后的语句写错了,刚开始写成了logTwo[i],debug了好一会 57 logTwo[i]=(i&(i-1)) == 0 ? logTwo[i-1]+1:logTwo[i-1]; 58 } 59 for(int k=1;k <= logTwo[cnt];++k) 60 for(int i=1;i+(1<<k)-1 <= cnt;++i) 61 if(dep[dp[i][k-1]] > dep[dp[i+(1<<(k-1))][k-1]]) 62 dp[i][k]=dp[i+(1<<(k-1))][k-1]; 63 else 64 dp[i][k]=dp[i][k-1]; 65 } 66 void lcaInit(int root) 67 { 68 cnt=0; 69 DFS(root,root,0,0); 70 ST(); 71 } 72 int lca(int u,int v)///返回节点u,v的LCA 73 { 74 u=pos[u]; 75 v=pos[v]; 76 77 if(u > v) 78 swap(u,v); 79 80 int k=logTwo[v-u+1]; 81 if(dep[dp[u][k]] > dep[dp[v-(1<<k)+1][k]]) 82 return vs[dp[v-(1<<k)+1][k]]; 83 else 84 return vs[dp[u][k]]; 85 } 86 }_lca; 87 88 int qCnt; 89 int query[maxn<<1]; 90 91 bool Check(ll mid) 92 { 93 int lca=0;///不满足条件的点的LCA 94 for(int i=1;i <= qCnt;++i) 95 { 96 if(CR[query[i]] <= mid) 97 continue; 98 if(lca == 0) 99 lca=query[i]; 100 else/// > mid的点LCA 101 lca=_lca.lca(lca,query[i]); 102 } 103 104 for(int i=1;i <= qCnt;++i) 105 { 106 if(CR[query[i]] <= mid) 107 continue; 108 109 ///如果将lca点涂红后还不能使其 <= mid,返回false 110 if(C[query[i]]-C[lca] > mid) 111 return false; 112 } 113 return true; 114 } 115 void Solve() 116 { 117 _lca.lcaInit(1); 118 119 for(int i=1;i <= q;++i) 120 { 121 scanf("%d",&qCnt); 122 123 ll l=-1,r=0; 124 for(int j=1;j <= qCnt;++j) 125 { 126 scanf("%d",query+j); 127 r=max(r,CR[query[j]]); 128 } 129 130 while(r-l > 1) 131 { 132 ll mid=l+((r-l)>>1); 133 if(Check(mid)) 134 r=mid; 135 else 136 l=mid; 137 } 138 printf("%lld ",r); 139 } 140 } 141 void Init() 142 { 143 num=0; 144 mem(head,-1); 145 mem(CR,INFll);///初始化为最大值 146 } 147 int main() 148 { 149 // freopen("C:\Users\hyacinthLJP\Desktop\in&&out\contest","r",stdin); 150 int test; 151 scanf("%d",&test); 152 while(test--) 153 { 154 Init(); 155 scanf("%d%d%d",&n,&m,&q); 156 for(int i=1;i <= m;++i) 157 { 158 int red; 159 scanf("%d",&red); 160 CR[red]=0; 161 } 162 CR[1]=0; 163 for(int i=1;i < n;++i) 164 { 165 int u,v,w; 166 scanf("%d%d%d",&u,&v,&w); 167 addEdge(u,v,w); 168 addEdge(v,u,w); 169 } 170 Solve(); 171 } 172 return 0; 173 }