首页 > 代码库 > 求逆序对(线段树版)

求逆序对(线段树版)

一个序列a1,a2,a3...aN,求出满足:ai > aj 且 i < j 的个数。

一个最容易想到的方法就是枚举所有的i,j看看是否满足,显然是O(n^2)的复杂度。不够好。

可以这样考虑,开一个数组保存这n个数出现的位置和对应的次数,这个数组要开到a数组里最大的那个数MAX,也就是hash,初始状态数组里没有元素,每个数对应的个数都是0.

如果考虑第i个数,找到比它大的所有的数 的个数,查找的范围即 ai+1~MAX,这就是到i这个位置的逆序对的总和,接着把a[i]这个数添加到数组里,也就是a[i]这个位置的数量加1。一直进行到n结束,逆序对就求了出来。

这样做得复杂度依然是O(n^2),但查找和增加的操作可用线段树解决,这样复杂度就降到了O(nlogn)。

还有一个问题,如果a[i]可以达到10^9甚至更大,数组都开不下,即便开的下,时间上也不能承受,这样就要用到离散化,将n个数映射到1~n的范围内,这个操作排序加二分可轻松解决。所有数控制在n 的范围内,线段树解决是非常理想的。


以POJ2299为例 : 题目就是要求逆序对。详见代码:


#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <set>
#include <stack>
#include <cctype>
#include <algorithm>
#define lson o<<1, l, m
#define rson o<<1|1, m+1, r
using namespace std;
typedef long long LL;
const int maxn = 500500;
const int MAX = 0x3f3f3f3f;
int n, a, b, in[maxn], tt[maxn], fu[maxn], f[maxn];
LL num[maxn<<2];
int bs(int v, int x, int y) {
    while(x < y) {
        int m = (x+y) >> 1;
        if(fu[m] >= v) y = m;
        else x = m+1;
    }
    return x;
}
void up(int o) {
    num[o] = num[o<<1] + num[o<<1|1];
}
void build(int o, int l, int r) {
    num[o] = 0;
    if(l == r) return ;
    int m = (l+r) >> 1;
    build(lson);
    build(rson);
}
void update(int o, int l, int r) {
    if(l == r) {
        num[o]++;
        return ;
    }
    int m = (l+r) >> 1;
    if(a <= m) update(lson);
    else update(rson);
    up(o);
}
LL query(int o, int l, int r) {
    if(a <= l && r <= b) return num[o];
    int m = (l+r) >> 1;
    LL ans = 0;
    if(a <= m) ans += query(lson);
    if(m < b ) ans += query(rson);
    return ans ;
}
int main()
{
    while(cin >> n, n) {
        for(int i = 0; i < n; i++) {
            scanf("%d", &in[i]);
            tt[i] = in[i];  //tt记录原序列
        }
        sort(in, in+n);
        int k = 0;
        fu[k++] = in[0];   //fu为辅助数组
        for(int i = 1; i < n; i++)
            if(in[i] != in[i-1]) fu[k++] = in[i];
        b = 0;
        for(int i = 0; i < n; i++) {  //离散过程,二分
            f[i] = bs(tt[i], 0, k-1);
            b = max(b, f[i]);
        }
        LL ans = 0;
        build(1, 0, b);
        for(int i = 0; i < n; i++) {
            a = f[i] + 1;    // 查询f[i]+1~n的个数,个数就是f[i]当前的逆序对总数
            ans += query(1, 0, b);
            a = f[i];  // 将f[i]添加到数组中
            update(1, 0, b);
        }
        cout << ans << endl;
    }
    return 0;
}