网络升级 「Adera 6」杯省选模拟赛
总时限 16 s $ quad $ 总内存 256 MiB
题目描述
Rainbow所在学校的网络中有 $ n $ 台计算机,由 $ n-1 $ 条电缆相连(即构成树形)。
其中第i条电缆连接 $ a_i、b_i $ 两台计算机,传输时间为ti。
当然,网络中任意两台计算机 $ a、b $ 传输数据所需时间就是a到b的路径上所有电缆的传输时间之和。
网络的效率关键在于传输时间最长的两台计算机之间传输数据所需要的时间,记为 $ μ $ 。
现在Rainbow所在的学校要对网络进行升级,升级的目标就是减小 $ μ $ 的值。
对于第i条电缆,可以花费pi的价钱把它升级为光缆,光缆依然连接 $ a_i $ 和 $ b_i $ ,不过传输时间快到可以忽略不计!
现在学校要选择一些电缆进行升级,使得升级之后μ的值减小的前提下,花费的价钱最少。
输入格式
第一行一个整数 $ n $ 。
接下来n-1行每行四个整数 $ a_i、b_i、t_i、p_i $ 。
输出格式
输出升级之后 $ μ $ 的值减小的前提下,花费的最少价钱。
样例输入1
4
1 2 3 3
1 3 8 33
1 4 3 7
样例输出1
10
样例输入2
4
1 2 3 5
2 3 5 2
3 4 5 4
样例输出2
2
数据范围与约定
对于10%的数据,$ 1 le n le 10 $ 。
对于40%的数据,$ 1 le n le 1000 $ 。
对于100% 的数据,$ 1 le a_i,b_i le n le 100000,1 le t_i,p_i le 10000 $ ,计算机和电缆的编号均从 $ 1 $ 开始。
題解
-
找到直徑中點 $ p $ ,設直徑兩端到 $ p $ 的距離分別爲 $ d_{far},d_{near} $ 。
-
僅考慮所有可能成爲直徑的點和邊構成的樹
-
對於 $ p $ 的子節點 $ s_1,s_2, dots , s_k $ ,樹形動規求切斷 $ s_i $ 與子樹中所有葉子節點聯係的代價
-
若 $ d_{far} = d_{near} $ ,則只留一個 $ s_i $ ,其餘都要切斷
-
若 $ d_{far} > d_{near} $ ,則要麽切斷唯一的一個 $ d_{far} $ ,要麽切斷所有的 $ d_{near} $
代碼
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>
using namespace std;
const int u=100010;
int head[u],ver[2*u],Next[2*u],edge[2*u],cost[2*u],dis[u],pre[u],f[u],d[u],a[u],b[u],v[u],c[u],fa[u];
int n,tot,m,t,i,x,y,z,w,ans,ct,temp,now,cnt;
void add(int x,int y,int z,int w)
{
ver[++tot]=y,Next[tot]=head[x],edge[tot]=z,cost[tot]=w,head[x]=tot;
}
int bfs(int s)
{
memset(dis,-1,sizeof(dis));
memset(v,0,sizeof(v));
queue<int> q;
int i,x,y;
q.push(s); dis[s]=pre[s]=0; v[s]=1;
while(q.size())
{
x=q.front(); q.pop();
for(i=head[x];i;i=Next[i])
if(dis[y=ver[i]]<dis[x]+edge[i]&&!v[y])
{
dis[y]=dis[x]+edge[i];
v[y]=1,q.push(y),pre[y]=i;
}
}
for(i=x=1;i<=n;i++) if(dis[i]>dis[x]) x=i;
return x;
}
void center()
{
x=bfs(1);
y=bfs(x);
for(i=pre[y];i;i=pre[ver[i^1]]) a[++t]=i^1;
for(i=1;i<=t;i++)
if(2*dis[ver[a[i]]]<dis[y]&&2*dis[ver[a[i]^1]]>=dis[y]) ct=ver[a[i]^1];
memset(v,0,sizeof(v)); t=0;
}
void dfs(int x)
{
v[x]=1;
for(int i=head[x];i;i=Next[i])
if(!v[ver[i]])
{
dfs(ver[i]),fa[ver[i]]=i;
d[x]=max(d[x],d[ver[i]]+edge[i]);
}
v[x]=0;
}
int dp(int x)
{
v[x]=1;
for(int i=head[x];i;i=Next[i])
if(!v[ver[i]]&&d[ver[i]]+edge[i]==d[x]) f[x]+=dp(ver[i]);
v[x]=0;
if(!f[x]) f[x]=1<<30;
return min(f[x],cost[fa[x]]);
}
void print(int x)
{
v[x]=1;
if(cost[fa[x]]<f[x]) c[++cnt]=fa[x]>>1;
else for(int i=head[x];i;i=Next[i])
if(!v[ver[i]]&&d[ver[i]]+edge[i]==d[x]) print(ver[i]);
}
void solve()
{
dfs(ct);
for(i=head[ct];i;i=Next[i])
if(d[ver[i]]+edge[i]==dis[ct]) a[++m]=ver[i];
else if(d[ver[i]]+edge[i]==dis[y]-dis[ct]) b[++t]=ver[i];
v[ct]=1;
for(i=1;i<=m;i++)
{
now=dp(a[i]);
ans+=now;
if(now>temp) temp=now,x=i;
}
for(i=1,now=0;i<=t;i++) now+=dp(b[i]);
if(dis[ct]!=dis[y]&&temp>now)
{
ans-=temp-now;
for(i=1;i<=m;i++)
if(i!=x) print(a[i]);
for(i=1;i<=t;i++) print(b[i]);
}
else for(i=1;i<=m;i++) print(a[i]);
}
int main()
{
cin>>n;
for(tot=i=1;i<n;i++)
{
scanf("%d%d%d%d",&x,&y,&z,&w);
add(x,y,z,w),add(y,x,z,w);
}
center();
solve();
cout<<ans<<endl;
/*cout<<cnt<<endl;
for(i=1;i<=cnt;i++) printf("%d ",c[i]);*/
return 0;
}
/*
用时
60 ms
占用内存
356 KiB
*/