首页 > 代码库 > POJ1990 MooFest 树状数组(Binary Indexed Tree,BIT)

POJ1990 MooFest 树状数组(Binary Indexed Tree,BIT)

         N头牛排成一列,每头牛的听力是Vi,每头牛的位置Pi,任意两头牛i,j相互交流时两头牛都至少需要发出声音的大小为max(Vi,Vj) * |Pi-Pj|,求这N头牛两两交流总共发出的声音大小是多少。N,V,P都是1-20000的范围。

        这题首先对Vi从小到大进行排序,排序过后就可以依次计算i,将所有比Vi小的牛到i之间的距离之和乘以Vi得到Ri,然后累加Ri就是最终结果。问题是Ri具体该如何求。

假设听力比Vi小的牛并且位置也比Pi小的牛的个数为Ci,并且这些距离之和为Si,听力比Vi小的所有牛的距离之和为Ti。则Ri = Vi *(Ci* Pi - Si+ Ti -Si - (i - Ci ) * Pi )。Ci,Si我们用树状数组来求,而Ti则可以直接求解。

#include <stdlib.h>
#include <stdio.h>
#include <vector>
#include <math.h>
#include <string.h>
#include <string>
#include <iostream>
#include <queue>
#include <list>
#include <algorithm>
#include <stack>
#include <map>

#include<iostream>  
#include<cstdio>  
using namespace std;

long long Total[20000];
int Count[20000];
long long S[20000];
long long BIT[20001];

struct MyStruct
{
	int pos;
	int v;
};

MyStruct ms[20000];

int compp(const void* a1, const void* a2)
{
	return ((MyStruct*)a1)->v - ((MyStruct*)a2)->v;
}

template <class TYPE>
void BITAdd(TYPE array[], int i, TYPE addvalue, int n)
{
	while (i <= n)
	{
		array[i] += addvalue;
		i += i & -i;
	}
}

template <class TYPE>
TYPE BITGet(TYPE array[], int i)
{
	TYPE ss = 0;
	while (i > 0)
	{
		ss += array[i];
		i -= i & -i;
	}
	return ss;
}

int main()
{
	int N;
#ifdef _DEBUG
	freopen("d:\\in.txt", "r", stdin);
#endif
	scanf("%d", &N);
	for (int i = 0; i < N; i++)
	{
		scanf("%d %d", &ms[i].v, &ms[i].pos);
	}
	qsort(ms, N, sizeof(MyStruct), compp);
	memset(Count, 0, sizeof(Count));
	memset(Total, 0, sizeof(Total));
	memset(BIT, 0, sizeof(BIT));
	memset(S, 0, sizeof(S));
	long long sum = 0;
	for (int i = 0; i < N;i++)
	{
		Count[i] = BITGet<long long>(BIT, ms[i].pos);
		//BITAdd<int>(Count, ms[i].pos, 1, N);
		BITAdd<long long>(BIT, ms[i].pos, 1, 20000);
		if (i != 0)
		{
			Total[i] = sum;
		}
		sum += ms[i].pos;
	}
	memset(BIT, 0, sizeof(BIT));
	for (int i = 0; i < N; i++)
	{
		S[i] = BITGet<long long>(BIT, ms[i].pos);
		BITAdd<long long>(BIT, ms[i].pos, ms[i].pos, 20000);
	}
	long long res = 0;
	for (int i = 1; i < N;i++)
	{
		res += (Count[i] * ms[i].pos - S[i] + Total[i] - S[i] - (i - Count[i]) * ms[i].pos) * ms[i].v;
	}
	printf("%I64d\n", res);
	return 0;
}