这道题看了题解后才会的。。果然是国家集训队的神仙题,思维独特。
首先,若方程$ \sum_{i=1}^{n}a_ix_i=k $有非负整数解,那么显然对于每一个$ a_i $,方程$ \sum_{i=1}^{n}a_ix_i=k+a_i $也必有非负整数解(把对应的$ x_i $加一即可)。于是若取$ Min=\min(a_i) $,那么对于任意$ j \in [0,Min) $,若对于自然数$ k $($ k \equiv j \pmod{Min} $),方程$ \sum_{i=1}^{n}a_ix_i=k $有解,则对于一切自然数$ B>k $($ B \equiv j \pmod{Min} $),方程$ \sum_{i=1}^{n}a_ix_i=B $都必有解。因此,我们只需对每一个$ j $求出对应的最小的$ k $值。(我代码中实际求的是$ \lfloor k/Min \rfloor $,取$ Min $是为了让状态数尽可能小)
但是,这个值怎么求呢?根据上文,若$ \sum_{i=1}^{n}a_ix_i=k $有解,则我们可以尝试将$ k $分别加上每一个$ a_i $得到新的合法数值。这其实相当于对于每一个$ i \in [1,n] $,从$ k \bmod Min $向$ (k+a_i) \bmod Min $连一条长度为$ a_i $的边。同时,我们要使$ k $尽可能小,所以要在原图中跑最短路求解。
代码:
dijkstra+堆(这个如果dijkstra写得不好在洛谷上会tle)
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <algorithm>
#include <functional>
#include <queue>
#include <utility>
#include <vector>
#include <map>

// Counts the integers k in [l, r] for which a_1*x_1 + ... + a_n*x_n = k has a
// solution in non-negative integers (Dijkstra variant).
//
// Input  (stdin):  n l r, followed by the n coefficients a_i.
// Output (stdout): the count.
//
// Method: let mn be the smallest positive coefficient. For every residue j
// modulo mn, find the smallest representable value congruent to j. Every
// larger value with the same residue is representable too (keep adding mn).
// The search is a shortest-path problem on the residues 0..mn-1.

typedef long long ll;

namespace {

// Marks a residue that no combination of the coefficients can produce.
const ll kUnreached = -1;

// Reads one decimal integer from stdin, skipping any non-digit characters.
// A '-' seen while skipping makes the result negative. Stops at EOF instead
// of spinning forever.
ll read_int() {
    int ch = getchar();
    ll sign = 1;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    ll value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}

// For each residue j in [0, mn) returns the smallest q such that q*mn + j is
// representable by the (positive) coefficients, or kUnreached if no such
// value exists.
//
// The graph is implicit: residue u has one edge per coefficient c, leading to
// (u + c) % mn with weight (u + c) / mn. Splitting c into (c / mn, c % mn)
// up front turns the per-edge division into a single compare-and-subtract.
std::vector<ll> min_quotients(const std::vector<ll>& coef, ll mn) {
    std::vector<std::pair<ll, ll> > steps;  // (c / mn, c % mn)
    steps.reserve(coef.size());
    for (std::size_t j = 0; j < coef.size(); ++j)
        steps.push_back(std::make_pair(coef[j] / mn, coef[j] % mn));

    typedef std::pair<ll, ll> Item;  // (distance, residue)
    std::priority_queue<Item, std::vector<Item>, std::greater<Item> > heap;
    std::vector<ll> dist(static_cast<std::size_t>(mn), kUnreached);
    std::vector<char> done(static_cast<std::size_t>(mn), 0);

    dist[0] = 0;
    heap.push(Item(0, 0));
    while (!heap.empty()) {
        const Item top = heap.top();
        heap.pop();
        const ll u = top.second;
        if (done[u]) continue;  // stale entry, residue already settled
        done[u] = 1;
        for (std::size_t j = 0; j < steps.size(); ++j) {
            ll v = u + steps[j].second;
            ll w = top.first + steps[j].first;
            if (v >= mn) {  // wrapped past mn: one more multiple of mn
                v -= mn;
                ++w;
            }
            // Push only on strict improvement so the heap stays small.
            if (dist[v] == kUnreached || w < dist[v]) {
                dist[v] = w;
                heap.push(Item(w, v));
            }
        }
    }
    return dist;
}

// Number of values first, first + mn, first + 2*mn, ... that are <= x.
ll count_upto(ll x, ll first, ll mn) {
    return x < first ? 0 : (x - first) / mn + 1;
}

}  // namespace

int main() {
    const ll n = read_int();
    const ll l = read_int();
    const ll r = read_int();

    // A zero coefficient contributes nothing to the sum, so it is dropped.
    // Keeping it would make the modulus zero. Coefficients are assumed to be
    // non-negative.
    std::vector<ll> coef;
    for (ll i = 0; i < n; ++i) {
        const ll value = read_int();
        if (value > 0) coef.push_back(value);
    }

    ll ans = 0;
    if (coef.empty()) {
        // Only k = 0 is representable.
        if (l <= 0 && 0 <= r) ans = 1;
    } else {
        const ll mn = *std::min_element(coef.begin(), coef.end());
        const std::vector<ll> dist = min_quotients(coef, mn);
        for (ll j = 0; j < mn; ++j) {
            // Residues that can never be reached (gcd of the coefficients
            // greater than 1) contribute nothing.
            if (dist[j] == kUnreached) continue;
            // Computed in 64 bits: the product can exceed the int range.
            const ll first = dist[j] * mn + j;
            ans += count_upto(r, first, mn) - count_upto(l - 1, first, mn);
        }
    }
    printf("%lld ", ans);
    return 0;
}
spfa(这个就很清真了,跑得飞快)
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <algorithm>
#include <queue>
#include <utility>
#include <vector>
#include <map>

// Counts the integers k in [l, r] for which a_1*x_1 + ... + a_n*x_n = k has a
// solution in non-negative integers (SPFA / queue-based Bellman-Ford variant).
//
// Input  (stdin):  n l r, followed by the n coefficients a_i.
// Output (stdout): the count.
//
// Method: let mn be the smallest positive coefficient. For every residue j
// modulo mn, find the smallest representable value congruent to j. Every
// larger value with the same residue is representable too (keep adding mn).
// The search is a shortest-path problem on the residues 0..mn-1.

typedef long long ll;

namespace {

// Marks a residue that no combination of the coefficients can produce.
const ll kUnreached = -1;

// Reads one decimal integer from stdin, skipping any non-digit characters.
// A '-' seen while skipping makes the result negative. Stops at EOF instead
// of spinning forever.
ll read_int() {
    int ch = getchar();
    ll sign = 1;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    ll value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}

// For each residue j in [0, mn) returns the smallest q such that q*mn + j is
// representable by the (positive) coefficients, or kUnreached if no such
// value exists.
//
// The graph is implicit: residue u has one edge per coefficient c, leading to
// (u + c) % mn with weight (u + c) / mn. Splitting c into (c / mn, c % mn)
// up front turns the per-edge division into a single compare-and-subtract.
std::vector<ll> min_quotients(const std::vector<ll>& coef, ll mn) {
    std::vector<std::pair<ll, ll> > steps;  // (c / mn, c % mn)
    steps.reserve(coef.size());
    for (std::size_t j = 0; j < coef.size(); ++j)
        steps.push_back(std::make_pair(coef[j] / mn, coef[j] % mn));

    std::vector<ll> dist(static_cast<std::size_t>(mn), kUnreached);
    std::vector<char> queued(static_cast<std::size_t>(mn), 0);
    // A growable FIFO: a residue may be enqueued many times over the run, so
    // a fixed-size linear buffer is not safe here.
    std::queue<ll> pending;

    dist[0] = 0;
    pending.push(0);
    queued[0] = 1;
    while (!pending.empty()) {
        const ll u = pending.front();
        pending.pop();
        queued[u] = 0;
        for (std::size_t j = 0; j < steps.size(); ++j) {
            ll v = u + steps[j].second;
            ll w = dist[u] + steps[j].first;
            if (v >= mn) {  // wrapped past mn: one more multiple of mn
                v -= mn;
                ++w;
            }
            if (dist[v] == kUnreached || w < dist[v]) {
                dist[v] = w;
                if (!queued[v]) {
                    queued[v] = 1;
                    pending.push(v);
                }
            }
        }
    }
    return dist;
}

// Number of values first, first + mn, first + 2*mn, ... that are <= x.
ll count_upto(ll x, ll first, ll mn) {
    return x < first ? 0 : (x - first) / mn + 1;
}

}  // namespace

int main() {
    const ll n = read_int();
    const ll l = read_int();
    const ll r = read_int();

    // A zero coefficient contributes nothing to the sum, so it is dropped.
    // Keeping it would make the modulus zero. Coefficients are assumed to be
    // non-negative.
    std::vector<ll> coef;
    for (ll i = 0; i < n; ++i) {
        const ll value = read_int();
        if (value > 0) coef.push_back(value);
    }

    ll ans = 0;
    if (coef.empty()) {
        // Only k = 0 is representable.
        if (l <= 0 && 0 <= r) ans = 1;
    } else {
        const ll mn = *std::min_element(coef.begin(), coef.end());
        const std::vector<ll> dist = min_quotients(coef, mn);
        for (ll j = 0; j < mn; ++j) {
            // Residues that can never be reached (gcd of the coefficients
            // greater than 1) contribute nothing.
            if (dist[j] == kUnreached) continue;
            // Computed in 64 bits: the product can exceed the int range.
            const ll first = dist[j] * mn + j;
            ans += count_upto(r, first, mn) - count_upto(l - 1, first, mn);
        }
    }
    printf("%lld ", ans);
    return 0;
}