首页 > 代码库 > 初涉A*剪枝

初涉A*剪枝

挖坑防忘，天亮补题解。

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <queue>
#include <cmath>
#include <stack>
#include <map>

// MSVC linker directive reserving a large stack. Toolchains that do not
// understand "#pragma comment" may ignore it. (The stray trailing ';' that
// used to follow this directive has been removed.)
#pragma comment(linker, "/STACK:1024000000")
#define EPS (1e-8)
#define LL long long
// Previously "unsigned long long LL", which expanded to the invalid type
// "unsigned long long long long" the first time ULL was used.
#define ULL unsigned long long
// __int64 is a compiler-specific type (e.g. MSVC); this alias only compiles
// where that type exists.
#define _LL __int64
#define _INF 0x3f3f3f3f
#define Mod 1000000007

using namespace std;

// Capacity of head[] (indexed by vertex); edge[] holds 2 * MAXN entries
// because every input edge is stored once forward and once reversed.
const int MAXN = 100100;

// Adjacency-list edge record. Each input edge u->v is stored twice by the
// reader: as a forward edge (ty == 1) and as a reversed edge (ty == 2).
struct N
{
    int u;     // tail vertex; not read by the search routines
    int v;     // head vertex (where the edge leads)
    int w;     // edge weight
    int next;  // index of the next edge in the same list, -1 ends the list
    int ty;    // 1 = forward edge, 2 = reverse edge
} edge[2 * MAXN];

// head[x] = index in edge[] of the first edge leaving x, or -1 for none.
// main() resets it with memset(-1) before every test case.
int head[MAXN];

// Next free slot in edge[], i.e. the number of edges stored so far.
int Top = 0;

void Link(int u,int v,int w,int ty)
{
    edge[Top].ty = ty;
    edge[Top].v = v;
    edge[Top].w = w;
    edge[Top].next = head[u];
    head[u] = Top++;
}

// Per-vertex search results (vertex ids are assumed to be < 1010).
//   ans[v]  - number of times the k-shortest search has popped v so far
//   w[v][i] - length of the i-th walk found to v (1-based; w[v][0] unused)
//   H[v]    - exact distance from v to the target, -1 if unreachable
int ans[1010];

int w[1010][1010];

int H[1010];

// State kept in the priority queues.
//   v - vertex
//   g - cost accumulated so far
//   h - heuristic estimate of the remaining cost
//   f - priority key (g + h in the A* search; plain distance in Init_H)
// operator< is inverted so that std::priority_queue<Q>, a max-heap, hands
// back the state with the smallest f first.
struct Q
{
    int v;
    int g;
    int h;
    int f;

    bool operator < (const Q &other) const
    {
        return f > other.f;
    }
};

void Init_H(int t)
{
    Q s,f;
    s.f = 0;
    s.v = t;

    memset(H,-1,sizeof(H));
    H[t] = 0;

    priority_queue<Q> q;
    q.push(s);

    while(q.empty() == false)
    {
        f = q.top();
        q.pop();

        H[f.v] = (H[f.v] == -1 ? f.f : H[f.v]);

        for(int p = head[f.v]; p != -1; p = edge[p].next)
        {
            if(edge[p].ty == 2)
            {
                if(H[edge[p].v] == -1)
                {
                    s.v = edge[p].v;
                    s.f = f.f + edge[p].w;
                    q.push(s);
                }
            }
        }
    }
}

void bfs(int h,int e,int k)
{
    Q f,s;
    priority_queue<Q> q;
    f.g = 0;
    f.h = H[h];
    f.f = f.h + f.g;
    f.v = h;

    q.push(f);

    while(q.empty() == false)
    {
        f = q.top();
        q.pop();

        if(f.g != 0)
        {
            w[f.v][++ans[f.v]] = f.g;

            if(f.v == e && ans[f.v] == k)
                return ;
        }


        for(int p = head[f.v]; p != -1; p = edge[p].next)
        {
            if(edge[p].ty == 1)
            {
                s.v = edge[p].v;
                s.h = H[edge[p].v];
                s.g = f.g + edge[p].w;
                s.f = s.h + s.g;
                q.push(s);
            }
        }

    }

}

// Reads test cases until end of input. Each case consists of:
//   "n m", then m directed edges "u v w", then "s t k".
// For each case, prints the length of the k-th shortest s->t walk, or -1 if
// it does not exist.
// NOTE(review): assumes vertex ids < 1010 and 1 <= k < 1010 (array sizes).
int main()
{
    int n, m;

    // Stop on EOF *or* malformed input. The old "!= EOF" test spun forever
    // when scanf matched fewer than two fields without consuming input.
    while (scanf("%d %d", &n, &m) == 2)
    {
        memset(head, -1, sizeof(head));
        Top = 0;

        bool ok = true;
        for (int i = 0; i < m; ++i)
        {
            int u, v, ww;
            if (scanf("%d %d %d", &u, &v, &ww) != 3)
            {
                ok = false;
                break;
            }
            Link(u, v, ww, 1);  // forward edge, walked by the A* search
            Link(v, u, ww, 2);  // reverse edge, used to compute H[]
        }

        int s, t, k;
        if (!ok || scanf("%d %d %d", &s, &t, &k) != 3)
            break;

        Init_H(t);

        memset(ans, 0, sizeof(ans));

        bfs(s, t, k);

        if (ans[t] < k)
            printf("-1\n");
        else
            printf("%d\n", w[t][k]);
    }
    return 0;
}