首页 > 代码库 > poj 3013 Big Christmas Tree (dij+优先队列优化 求最短路)

poj 3013 Big Christmas Tree (dij+优先队列优化 求最短路)

模板

题意:给你一个图,1总是为根,每个边有单位价值,每个点有权重。

每条边的价值 = sum(后继节点权重)*边的单位价值。

求树的最小价值,即构成一棵树的n-1条边的最小价值。


算法:

1、因为每个边的价值都要乘以后来访问的节点的权重,而走到后来访问的点必经过这条边。

实际上总价值就是  到每个点的最短路径*这个点的权重。

2、但是这个题 数据量真的太大了,50000个点,50000条边。

写普通的dij算法tle。

必须加优先队列优化- -

据说spfa也能过,但是spfa算法不稳定- -,一般没有负权,则用优先队列或堆优化的dijkstra算法

应该能解决问题。

3、坑点:点为0或者1时,价值为0,要特判,否则也会tle。


#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
#define maxn 50010

const __int64 INF = 10000000000;

using namespace std;

struct node
{
    int to,next,val;
}edge[maxn*2];

int v,head[maxn],c[maxn],cnt;
long long dis[maxn];
bool vis[maxn];
typedef pair<long long,int> PII;
priority_queue<PII, vector<PII> ,greater<PII> > q;

void add(int x,int y,int z)
{
    edge[cnt].to = y;
    edge[cnt].val = z;
    edge[cnt].next = head[x];
    head[x] = cnt++;
}

long long dij()
{
    for(int i=2;i<=v;i++)
        dis[i] = INF;
    while(!q.empty())
        q.pop();
    int sum = 0;
    long long ret = 0;
    long long x;
    int y;
    dis[1] = 0;
    q.push(make_pair(dis[1],1));
    while(!q.empty())
    {
        PII cur = q.top();
        q.pop();
        x = cur.first;
        y = cur.second;
        if(vis[y]) continue;
        vis[y] = true;
        sum++;
        ret += x*c[y];
        for(int i=head[y];i!=-1;i=edge[i].next)
        {
            int u = edge[i].to,p = edge[i].val;
            if(dis[u]>dis[y]+p)
            {
                dis[u] = dis[y]+p;
                q.push(make_pair(dis[u],u));
            }
        }
    }
    if(sum<v) return -1;
    else return ret;
}

int main()
{
    int T,w,a,b,cost;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%d",&v,&w);
        memset(head,-1,sizeof(head));
        cnt = 0;
        for(int i=1;i<=v;i++)
            scanf("%d",&c[i]);
        for(int i=0;i<w;i++)
        {
           scanf("%d%d%d",&a,&b,&cost);
           add(a,b,cost);
           add(b,a,cost);
        }
        if(v<=1)
        {
            printf("0\n");
            continue;
        }

        memset(vis,0,sizeof(vis));
        long long ans = dij();
        if(ans == -1) printf("No Answer\n");
        else printf("%I64d\n",ans);
    }
    return 0;
}