题目:https://www.luogu.org/problemnew/show/P4178
点分治。如果把每次的 dis 和 K-dis 都离散化,用树状数组找,是O(n*logn*logn),会T7个点。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const int N=4e4+5; int n,hd[N],xnt,to[N<<1],nxt[N<<1],w[N<<1],f[N<<1],siz[N],ans,mn,rt; ll dis[N],tis[N],tp[N<<1],tnt,K; bool vis[N],sj[N]; void add(int x,int y,ll z) { to[++xnt]=y;nxt[xnt]=hd[x];w[xnt]=z;hd[x]=xnt; to[++xnt]=x;nxt[xnt]=hd[y];w[xnt]=z;hd[y]=xnt; } void getrt(int cr,int fa,int s) { siz[cr]=1;int mx=0; for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa) { getrt(v,cr,s);siz[cr]+=siz[v];mx=max(mx,siz[v]); } mx=max(mx,s-siz[cr]); if(mx<mn)mn=mx,rt=cr; } void add(int x){for(;x<=tnt;x+=(x&-x))f[x]++;} int query(int x){int ret=0;for(;x;x-=(x&-x))ret+=f[x];return ret;} void dfs(int cr,int fa,ll lj) { dis[cr]=lj;sj[cr]=1; for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa) dfs(v,cr,lj+w[i]); } int calc(int cr,ll w) { memset(sj,0,sizeof sj);tnt=0;dfs(cr,0,w); for(int i=1;i<=n;i++) if(sj[i]&&dis[i]<=K) { tis[i]=K-dis[i];tp[++tnt]=dis[i];tp[++tnt]=tis[i]; // printf("dis[%d]=%lld tis[%d]=%lld ",i,dis[i],i,tis[i]); } sort(tp+1,tp+tnt+1);tnt=unique(tp+1,tp+tnt+1)-tp-1; int ret=0; for(int i=1;i<=n;i++) if(sj[i]&&dis[i]<=K) { dis[i]=lower_bound(tp+1,tp+tnt+1,dis[i])-tp; tis[i]=lower_bound(tp+1,tp+tnt+1,tis[i])-tp; // printf("dis[%d]=%lld tis[%d]=%lld ",i,dis[i],i,tis[i]); ret+=query(tis[i]);add(dis[i]); } memset(f,0,sizeof f); return ret; } void solve(int cr,int s) { // printf("rt=%d ",cr); vis[cr]=1; ans+=calc(cr,0); // printf("cr=%d ans=%d ",cr,ans); for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]) { ans-=calc(v,w[i]); int ts=(siz[cr]>siz[v]?siz[v]:s-siz[cr]);//-siz[cr]!!! mn=N;getrt(v,0,ts);solve(rt,ts); } } int main() { scanf("%d",&n);int x,y;ll z; for(int i=1;i<n;i++) { scanf("%d%d%lld",&x,&y,&z);add(x,y,z); } scanf("%lld",&K); mn=N;getrt(1,0,n);solve(rt,n); printf("%d ",ans); return 0; }
应当排序后枚举两个指针。(代码中两种方法时间一样)
如果把 ts=s-siz[cr] 写成 ts=s-siz[v] ,就会T7个点(?)!!!
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const int N=4e4+5; int n,hd[N],xnt,to[N<<1],nxt[N<<1],w[N<<1],siz[N],mn,rt,sta[N],top,K,ans; bool vis[N]; void add(int x,int y,int z) { to[++xnt]=y;nxt[xnt]=hd[x];w[xnt]=z;hd[x]=xnt; to[++xnt]=x;nxt[xnt]=hd[y];w[xnt]=z;hd[y]=xnt; } void getrt(int cr,int fa,int s) { siz[cr]=1;int mx=0; for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa) { getrt(v,cr,s);siz[cr]+=siz[v];mx=max(mx,siz[v]); } mx=max(mx,s-siz[cr]); if(mx<mn)mn=mx,rt=cr; } void dfs(int cr,int fa,int lj) { sta[++top]=lj; for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa) dfs(v,cr,lj+w[i]); } int calc(int cr,int w) { int ret=0;dfs(cr,0,w); // l=1;r=0; // sort(sta+l,sta+r+1); // while(l<=r) // if(sta[l]+sta[r]<=K)ret+=r-l,l++; // else r--; sort(sta+1,sta+top+1);int p=top; for(int i=1;i<=top;i++) { while(sta[p]+sta[i]>K&&p)p--;if(!p)break; ret+=p-(p>=i); } top=0; // printf("cr=%d ret=%d ",cr,ret); return ret>>1; } void solve(int cr,int s) { // printf("rt=%d ",cr); vis[cr]=1; ans+=calc(cr,0); // printf("cr=%d ans=%d ",cr,ans); for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]) { ans-=calc(v,w[i]); int ts=(siz[cr]>siz[v]?siz[v]:s-siz[cr]);//s-siz[cr]!!! mn=N;getrt(v,0,ts);solve(rt,ts); } } int main() { scanf("%d",&n); for(int i=1,x,y,z;i<n;i++) { scanf("%d%d%d",&x,&y,&z);add(x,y,z); } scanf("%d",&K); mn=N;getrt(1,0,n);solve(rt,n); printf("%d ",ans); return 0; }