同余最短路&转圈背包算法学习笔记(超详细)
一、问题引入
当你想要解决一个完全背包计数问题,但是 \(M\) 的范围太大,那么你就可以使用同余最短路。
二、算法推导过程
首先对于一个完全背包计数问题,我们要知道如果 \(x\) 这个数能凑出来,那么 \(x+a_i,x+2a_i,x+3a_i,\dots\) 一定都能凑出来,所以说,我们随便找一个 \(a_i\),找到对于 \([0,a_i-1]\) 中的每个数 \(k\) 的最小的能凑出来的数并且这个数模 \(a_i\) 等于 \(k\),这个时候从最小的可以想到最短路,于是我们就有了一个大胆的想法,首先为了节省时间复杂度,把模数设定成最小的 \(a_i\),然后进行最短路,最短路的过程就是不断尝试增加其它的 \(a_i\),然后再取模,你会发现,这是对的!哦对,刚刚只是找到了对于 \([0,a_i-1]\) 中的每个数 \(k\) 的最小的能凑出来的数并且这个数模 \(a_i\) 等于 \(k\),统计的话还要统计能加多少次 \(a_i\)。
时间复杂度:\(O(E \log V) = O(a_i(n-1) \log a_i)\)。
三、同余最短路模板
#include<bits/stdc++.h>
using namespace std;
const int N = ;//数据范围
int a[N];
int f[N];
int vis[N];
struct node
{
int x;
int w;
bool operator<(const node&a)const
{
return w>a.w;
}
};
signed main()
{
int n;
long long m;
scanf("%d %lld",&n,&m);
for(int i = 1;i<=n;i++)
{
scanf("%d",&a[i]);
}
sort(a+1,a+n+1);
memset(f,0x3f,sizeof(f));
f[0] = 0;
priority_queue<node>q;
q.push({0,0});
while(q.size())
{
node x = q.top();
q.pop();
if(vis[x.x])
{
continue;
}
vis[x.x] = 1;
for(int i = 2;i<=n;i++)
{
int v = (x.x+a[i])%a[1];
if(f[v]>f[x.x]+a[i])
{
f[v] = f[x.x]+a[i];
q.push({v,f[v]});
}
}
}
long long ans = 0;
for(int i = 0;i<a[1];i++)
{
if(f[i]<=m)
{
ans+=(m-f[i])/a[1]+1;
}
}
printf("%lld",ans-1);
return 0;
}
注意:这只是一个板子,应用时请随机应变。
四、同余最短路例题
U553673 硬币问题
同余最短路模板题,代码放上供参考:
#include<bits/stdc++.h>
using namespace std;
const int N = 1e6+5;
int a[N];
long long f[N];
int vis[N];
struct node
{
int x;
int w;
bool operator<(const node&a)const
{
return w>a.w;
}
};
signed main()
{
int n;
long long m;
scanf("%d %lld",&n,&m);
for(int i = 1;i<=n;i++)
{
scanf("%d",&a[i]);
}
sort(a+1,a+n+1);
memset(f,0x3f,sizeof(f));
f[0] = 0;
priority_queue<node>q;
q.push({0,0});
while(q.size())
{
node x = q.top();
q.pop();
if(vis[x.x])
{
continue;
}
vis[x.x] = 1;
for(int i = 2;i<=n;i++)
{
int v = (x.x+a[i])%a[1];
if(f[v]>f[x.x]+a[i])
{
f[v] = f[x.x]+a[i];
q.push({v,f[v]});
}
}
}
long long ans = 0;
for(int i = 0;i<a[1];i++)
{
if(f[i]<=m)
{
ans+=(m-f[i])/a[1]+1;
}
}
printf("%lld",ans-1);
return 0;
}
结果稀里糊涂地拿到了最优解(不要看最快的提交,因为数据很水然后我没开 long long 结果过了,所以看第二快的提交)……
P3403 跳楼机
只需要把 \(m\) 减一(因为题目是从第 \(1\) 层开始,然而我们的代码是从第 \(0\) 层),然后正常套模板就行了。
代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+5;
int a[N];
long long f[N];
int vis[N];
struct node
{
int x;
long long w;
bool operator<(const node&a)const
{
return w>a.w;
}
};
signed main()
{
int n = 3;
long long m;
scanf("%lld",&m);
m--;
for(int i = 1;i<=n;i++)
{
scanf("%d",&a[i]);
}
sort(a+1,a+n+1);
memset(f,0x3f,sizeof(f));
f[0] = 0;
priority_queue<node>q;
q.push({0,0});
while(q.size())
{
node x = q.top();
q.pop();
if(vis[x.x])
{
continue;
}
vis[x.x] = 1;
for(int i = 2;i<=n;i++)
{
int v = (x.x+a[i])%a[1];
if(f[v]>f[x.x]+a[i])
{
f[v] = f[x.x]+a[i];
q.push({v,f[v]});
}
}
}
long long ans = 0;
for(int i = 0;i<a[1];i++)
{
if(f[i]<=m)
{
ans+=(m-f[i])/a[1]+1;
}
}
printf("%lld",ans);
return 0;
}
由于作者水平不行,所以只能做这些题,后面还会有更多例题,敬请期待!
转圈背包由于太难,作者理解之后再更新!
但是似乎我发现转圈背包虽然理论上比同余最短路快,但是实际上效率远不如同余最短路,特别是在大数据,这是因为迪杰斯特拉在不被卡的情况下时间复杂度比 \(O(E \log V)\) 快得多。
五、转圈背包
转圈背包是同余最短路的另一种写法,这种写法时间复杂度更优,但是理解起来比较有难度。
首先我们重新定义 \(dp\) 值表示的是同余最短路中的最短路数组,那么在模 \(w\) 的意义下,面值为 \(w_b\) 的物品最多被用几次就没有意义了?你可能认为最多被用 \(w-1\) 次,虽然对了,但是这个上限不够好,其实最好的上限应该是 \(\frac{w}{\gcd(w,w_b)}-1\),也就是说用 \(\frac{w}{\gcd(w,w_b)}\) 次 \(w_b\) 就会出现 \(w\) 的倍数,至于为什么,见下:
然后假定 \(g = \gcd(w,w_b)\),那么:
因为 \(\gcd(n,m) = 1\),那么根据数论小知识:
那么 \(k\) 的最小取值就是 \(n\),也就是 \(\frac{w}{\gcd(w,w_b)}\)。
于是,我们就可以设计转移了(转移比较复杂,这里给图):
你会发现时间复杂度还是太高,考虑优化。
我们把所有这个物品用的次数中相同余数的点缩成一个点,就变成了这样:
我们只是转化了一下,时间复杂度并没有变,那么该如何优化呢?由于我们从任意一个余数出发原本只能走 \(\gcd(w,w_b)-1\) 次就停了下来,这个时候转圈背包的神奇之处就出现了,我们从每个环中最小的余数开始走,走 \(\gcd(w,w_b)-1\) 次之后不用停下来,继续走,然后每个点的状态就是这样的:
这个时候你可能会认为每个点对应的状态和刚刚的原始转移不一样,但是每个点统计的答案都起码大于等于 \(3\) 了,由于用超过 \(3\) 次就没用了,所以说这个转移方法转移次数不仅更优,而且也不会影响到答案。
这就是——大名鼎鼎的转圈背包!至于为什么叫转圈背包,那是因为这个转移很像在转圈。
哦对,对于每一个物品,它对应的余数环并不止一个,而是有 \(\gcd(w,w_b)\) 个,因为他总共有 \(w\) 个余数,然后每一个环的长度是 \(\frac{w}{\gcd(w,w_b)}\),所以环的数量就是 \(w \div \frac{w}{\gcd(w,w_b)} = \gcd(w,w_b)\)。
时间复杂度是 \(O((n-1)w)\)。
六、转圈背包模板
#include<bits/stdc++.h>
using namespace std;
const int N = ;//数据范围
int a[N];
long long f[N];
signed main()
{
int n;
long long m;
scanf("%d %lld",&n,&m);
for(int i = 1;i<=n;i++)
{
scanf("%d",&a[i]);
}
sort(a+1,a+n+1);//小优化1,取最小的点,这样余数就会变少
sort(a+2,a+n+1,greater<int>());//小优化2,我不知道原因
memset(f,0x3f,sizeof(f));
f[0] = 0;
for(int i = 2;i<=n;i++)
{
for(int j = 0,quan = __gcd(a[i],a[1])-1;j<=quan;j++)
{
for(int k = j,num = 0;num<2;num+=(k == j))
{
int o = (k+a[i])%a[1];
f[o] = min(f[o],f[k]+a[i]);
k = o;
}
}
}
long long ans = 0;
for(int i = 0;i<a[1];i++)
{
if(f[i]<=m)
{
ans+=(m-f[i])/a[1]+1;
}
}
printf("%lld",ans-1);
return 0;
}
注意:这只是一个板子,应用时请随机应变。