首页 > 代码库 > poj 3111 K Best ,二分,牛顿迭代
poj 3111 K Best ,二分,牛顿迭代
poj 3111 K Best
有n个物品的重量和价值分别是wi和vi。从中选出k个物品使得单位重量的价值最大。
题解:
1、二分做法
2、牛顿迭代
效率比较:
二分做法:
转换成判断是否存在选取K个物品的集合S满足下面的条件:
sigma(vi) / sigma(wi) >= x {vi∈S, wi∈S}
--> simga(vi - x*wi) >= 0
这样我们对 yi= vi - x*wi {1<=i<=n}从大到小排序,计算sum(yi) {1<=i<=k}
如果sum(yi){1<=i<=k}>=0 ,则说明 sigma(vi) / sigma(wi) >= x, 成立。
那么我们只要二分x {注意精度},就能找到单位重量价值最大的k个物品。
#include<cstdio> #include<algorithm> using namespace std; const int maxn = 100000 + 10; const int INF = 1e7; int v[maxn], w[maxn]; struct node { double val; int id; bool operator < (const node& rhs) const { return val > rhs.val; } }; node f[maxn]; int ans[maxn]; int n, k; //sigma(vi)/sigma(wi) >= x //-> sigma(vi - x*wi) >= 0 int ok(double x) { for(int i=0; i<n; ++i) { f[i].val = v[i] - x*w[i]; f[i].id = i+1; } sort(f, f+n); double sum = 0; for(int i=0; i<k; ++i) { sum += f[i].val; ans[i] = f[i].id; } return sum >= 0; } int main() { scanf("%d%d", &n, &k); for(int i=0; i<n; ++i) { scanf("%d%d", &v[i], &w[i]); } double l = 0, r = INF; //for(int i=0; i<50; ++i) { while(r-l>1e-8){ double mid = (l+r)/2; if(ok(mid)) l = mid; else r = mid; } for(int i=0; i<k; ++i) { printf("%d", f[i].id); if(i<k-1) printf(" "); else printf("\n"); } return 0; }
牛顿迭代
关于牛顿迭代法详见:点击打开链接
先取前k个元素算出S0 =∑(vi/wi) 作为初始值
然后对每一个元素(n个)求yi=vi-s0*wi
对yi从大到小排序,取前k个元素算出S,
重复上面的运算(每次循环后把S的值赋给S0,然后新一轮循环时S有通过S0计算出来),直到fabs(S-S0)<=eps,满足精度要求。
正确性证明:
证明其正确性,只要证明每次迭代的S都比上一次的大即可,也即迭代过程中S是单调递增的,因为给定的是有限集,故可以肯定,S必存在最大值,即该迭代过程是收敛的。下面证明单调性:
假设上轮得到的S1,则在n个元素中必存在k个元素使S1=∑(vi/wi),变形可得到∑vi-S1*∑wi=0,
现对每个元素求yi=vi-S1*wi,可知必存在k个元素使∑yi=∑vi-s1*∑wi=0, 所以当我们按y排序并取前k个元素作为求其∑y时,其∑y>=0,
然后对和式变形即可得到S1=((∑v-∑y)/∑w)<=(∑v/∑w)=s2,即此迭代过程是∑y是收敛的,当等号成立时,此S即为最大值。
#include<cstdio> #include<cmath> #include<algorithm> using namespace std; const int maxn = 100000 + 10; const double eps = 1e-8; int v[maxn], w[maxn]; struct node { double val; int v, w; int idx; bool operator < (const node& rhs) const { return val > rhs.val; } }; node f[maxn]; int n, k; double Get() { double sumv = 0, sumw = 0; for(int i=0; i<k; ++i) { sumv += f[i].v; sumw += f[i].w; } return sumv/sumw; } int main() { scanf("%d%d", &n, &k); for(int i=0; i<n; ++i) { scanf("%d%d", &f[i].v, &f[i].w); f[i].idx = i+1; } double s1, s2 = Get(); do { s1 = s2; for(int i=0; i<n; ++i) { f[i].val = f[i].v - s1*f[i].w; } sort(f, f+n); s2 = Get(); } while(fabs(s2-s1)>=eps); for(int i=0; i<k; ++i) { printf("%d%c", f[i].idx, " \n"[i==n-1] ); } return 0; }