首页 > 代码库 > POJ 1436 Horizontally Visible Segments(线段树建图+枚举)

POJ 1436 Horizontally Visible Segments(线段树建图+枚举)

题目连接:http://poj.org/problem?id=1436

题意:给一些线段,每个线段有三个值y1, y2, x代表起点为(x, y1),终点为(x, y2)的线段。当从一个线段可以作水平线到另一个线段并且不穿过其他线段时,就称这两个线段时水平可见的。当三个线段可以两两水平可见,就称为形成一个线段三角。问:在这些线段中有多少个这样的线段三角?

分析:可以把每条线段看做是一个点,如果它和其他线段是水平可见的,就将这两个点相连,由于是无向图,就是你能看到我,我也能看到你,所以需要连接两次(map[a][b] = 1; map[b][a] = 1;)。然后问题就是如何生成存放线段关系的图,可以用线段树来求。假如现在遍历到第i条线段,那么第i条线段所能看到的是前面区间最新的线段,如图:


线段5只能看到1、2、4,不能看到3,所以可以用线段树区间修改的方法来生成这个树。

初始化set数组(存放线段的标号)全为-1,当遍历到某条线段时,对线段的那段区间进行查询,如果查询到的区间有值,说明可以看到线段set[o],然后即相连,如果查询到的区间一直是-1,就一直深搜,直到到达叶子节点返回。查询结束之后将当前线段插入到线段树中,即把区间更新。因为都是整数,所以某两条线段可见可能是在0.5的时候可见,所以需要预处理一下线段的区间,将区间扩大至2倍即可。

建完图之后暴力枚举每三个点即可。

代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 8000;
struct segment{
	int a, b, x;
	bool operator < (const segment& temp) const {
		return x < temp.x; 
	}
}a[maxn+10];
int set[8*maxn+10];
bool map[maxn+10][maxn+10];
void push_down(int o){
	if (set[o] != -1){
		//将其左孩子和右孩子赋为和父亲一样的标号 
		set[o<<1] = set[o];
		set[o<<1|1] = set[o];
		set[o] = -1;//因为该区间已经不是一条线段全部覆盖了,所以标记重新初始化 
	}
}
void query(int o, int L, int R, int i){
	if (set[o] != -1){//区间有值,说明可以看到set[o],将其连接 
		map[i][set[o]] = 1;
		map[set[o]][i] = 1;
		return;
	}
	if (L == R){//没有值,还是初始状态,到叶子节点返回 
		return;
	}
	int m = (L+R)>>1;
	if (a[i].a <= m) query(o<<1, L, m, i);
	if (a[i].b > m) query(o<<1|1, m+1, R, i);
}
void insert(int o, int L, int R, int i){
	if (a[i].a <= L && a[i].b >= R){
		set[o] = i;//找到区间,将节点的值赋为i,表示这段区间可以看到第i条线段 
		return;
	}
	push_down(o);//节点信息向下更新 
	int m = (L+R)>>1;
	if (a[i].a <= m) insert(o<<1, L, m, i);
	if (a[i].b > m) insert(o<<1|1, m+1, R, i);
}
int main(){
	int T, n;
	scanf("%d", &T);
	while (T--){
		scanf("%d", &n);
		memset(set, -1, sizeof(set));//初始化为-1 
		memset(map, false, sizeof(map));//初始化为不相连 
		int L = (1<<31)-1, R = -1;
		for (int i=0; i<n; i++){
			scanf("%d%d%d", &a[i].a, &a[i].b, &a[i].x);
			a[i].a*=2;//预处理 
			a[i].b*=2;
			L = min(L, a[i].a);
			R = max(R, a[i].b);
		}
		sort(a, a+n);//将线段按照x从小到大排序 
		for (int i=0; i<n; i++){
			query(1, L, R, i);
			insert(1, L, R, i);
		}
		int ans = 0;
		//枚举三个线段,符合条件结果+1 
		for (int i=0; i<n; i++){
			for (int j=i+1; j<n; j++){
				if (map[i][j]){
					for (int k=j+1; k<n; k++){
						if (map[i][k] && map[j][k]){
							ans++;
						}
					}
				}
			}
		}
		printf("%d\n", ans);
	}
	return 0;
}