https://www.luogu.com.cn/problem/P6190
矩阵优化的最短路
其实这题里的矩阵运算并不是乘法,由此说明在转移的时候并不一定要往矩阵乘法的原本形式上靠,只要符合结合律能快速“幂”就行
题目描述
C 国由 (n) 座城市与 (m) 条 有向 道路组成,城市与道路都从 (1) 开始编号,经过 (i) 号道路需要 (t_i) 的费用。
现在你要从 (1) 号城市出发去 (n) 号城市,你可以施展最多 (k) 次魔法,使得通过下一条道路时,需要的费用变为原来的相反数,即费用从 (t_i) 变为 (-t_i)。
请你算一算,你至少要花费多少费用才能完成这次旅程。
注意:使用魔法只是改变一次的花费,而不改变一条道路自身的 (t_i)
最终的费用可以为负,并且一个城市可以经过多次(包括 (n) 号城市)。
输入格式
输入的第一行有三个整数,分别代表城市数 (n),道路数 (m) 和魔法次数限制 (k)。
第 (2) 到第 ((m+1)) 行,每行三个整数。第 ((i+1)) 行的整数 (u_i,v_i,t_i)
输出格式
输出一行一个整数表示最小的花费。
先写出dp的方程,(f(p,i,j)) 表示的是用了 (p) 次魔法,从 (i) 到 (j) 的最短路长度
- (p=0),直接 floyd
- (p=1),枚举每个边 (k),则 (f(1,i,j)=min(f(1,i,j),f(0,i,u_k)+f(0,v_k,j)-t_k)),就是在边 (k) 使用了魔法
- (p>1),这时 (k) 表示的是枚举的点,(f(x+1,i,j)=min(f(x+1,i,j),f(x,i,k)+f(1,k,j)))
然后前两种可以直接按式子跑出来,第三种要用矩阵加速
第三种里,这个取 (min) 相当于一般矩阵乘法里的求和,然后 (f(x,i,k)+f(1,k,j)) 里的加号相当于矩阵乘法里的乘号
所以乘一个 (f(1)) 的矩阵,也就能从上一个矩阵转移到下一个矩阵了
然后 很容易猜出(其实证明不难),这个是有结合律的,可以快速幂,给 (f(1)) 矩阵做 (k) “次方”
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<map>
#include<iomanip>
#include<cstring>
#define reg register
#define EN std::puts("")
#define LL long long
inline int read(){
register int x=0;register int y=1;
register char c=std::getchar();
while(c<'0'||c>'9'){if(c=='-') y=0;c=std::getchar();}
while(c>='0'&&c<='9'){x=x*10+(c^48);c=std::getchar();}
return y?x:-x;
}
struct data{
LL a[105][105];
}ans,a;
LL dis[106][106],t[2506];
int u[2506],v[2506];
int n,m;
inline void pre(int k_){
for(reg int k=1;k<=n;k++)
for(reg int i=1;i<=n;i++)
for(reg int j=1;j<=n;j++)
dis[i][j]=std::min(dis[i][j],dis[i][k]+dis[k][j]);
if(!k_) std::printf("%lld",dis[1][n]),std::exit(0);
for(reg int i=1;i<=n;i++)
for(reg int j=1;j<=n;j++) a.a[i][j]=dis[i][j];
for(reg int k=1;k<=m;k++)
for(reg int i=1;i<=n;i++)
for(reg int j=1;j<=n;j++)
a.a[i][j]=std::min(a.a[i][j],dis[i][u[k]]+dis[v[k]][j]-t[k]);
if(k_==1) std::printf("%lld",a.a[1][n]),std::exit(0);
for(reg int i=1;i<=n;i++)
for(reg int j=1;j<=n;j++) ans.a[i][j]=a.a[i][j];
}
inline data mul(data a,data b){
data c;
for(reg int i=1;i<=n;i++)
for(reg int j=1;j<=n;j++) c.a[i][j]=1e18;
for(reg int i=1;i<=n;i++)
for(reg int j=1;j<=n;j++)
for(reg int k=1;k<=n;k++)
c.a[i][j]=std::min(c.a[i][j],a.a[i][k]+b.a[k][j]);
return c;
}
inline void power(int b){
while(b){
if(b&1) ans=mul(ans,a);
b>>=1;a=mul(a,a);
}
}
int main(){
n=read();m=read();int k=read();
std::memset(dis,0x3f,sizeof dis);std::memset(a.a,0x3f,sizeof a.a);
for(reg int i=1;i<=n;i++) dis[i][i]=a.a[i][i]=0;
for(reg int i=1;i<=m;i++){
u[i]=read();v[i]=read();dis[u[i]][v[i]]=t[i]=read();
}
pre(k);
power(k-1);
std::printf("%lld",ans.a[1][n]);
return 0;
}