线段树
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
对于线段树中的每一个非叶子节点[a,b],它的左儿子表示的区间为[a,(a+b)/2],右儿子表示的区间为[(a+b)/2+1,b]。因此线段树是平衡二叉树,最后的子节点数目为N,即整个线段区间的长度。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,因此有时需要离散化让空间压缩。
线段树有很多模板,而且基本上每道题都是稍稍改动或者根本不需改动就可以直接使用线段树的模板,这里提供几个相对简洁的模板:
//单点替换、单点增减、区间求和、区间最值
#include <cstdio>
#include <algorithm>
using namespace std;
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
const int maxn = 222222;
int MAX[maxn<<2];
int MIN[maxn<<2];
int SUM[maxn<<2];
int max(int a,int b){if(a>b)return a;else return b;}
int min(int a,int b){if(a<b)return a;else return b;}
void PushUP(int rt)
{
MAX[rt] = max(MAX[rt<<1] , MAX[rt<<1|1]);
MIN[rt] = min(MIN[rt<<1] , MIN[rt<<1|1]);
SUM[rt] = SUM[rt<<1] + SUM[rt<<1|1];
}
void build(int l,int r,int rt) {
if (l == r)
{
scanf("%d",&MAX[rt]);
MIN[rt] = MAX[rt];
SUM[rt] = MAX[rt];
//printf("mi = %d\n",MIN[rt]);
// printf("ma = %d\n",MAX[rt]);
return ;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
PushUP(rt);
}
void update(int p,int tihuan,int l,int r,int rt)
{
if (l == r) {
MAX[rt] = tihuan;
MIN[rt] = tihuan;
SUM[rt] = tihuan;
return ;
}
int m = (l + r) >> 1;
if (p <= m) update(p , tihuan ,lson);
else update(p , tihuan , rson);
PushUP(rt);
}
void update1(int p,int add,int l,int r,int rt)
{
if (l == r) {
SUM[rt] = SUM[rt] + add;
return ;
}
int m = (l + r) >> 1;
if (p <= m) update1(p , add ,lson);
else update1(p , add , rson);
PushUP(rt);
}
int query(int L,int R,int l,int r,int rt)
{
if (L <= l && r <= R)
{
return MAX[rt];
}
int m = (l + r) >> 1;
int ret = -1;
if (L <= m) ret = max(ret , query(L , R , lson));
if (R > m) ret = max(ret , query(L , R , rson));
return ret;
}
int query1(int L,int R,int l,int r,int rt)
{
if (L <= l && r <= R)
{
return MIN[rt];
}
int m = (l + r) >> 1;
int ret = 99999;
if (L <= m) ret = min(ret , query1(L , R , lson));
if (R > m) ret = min(ret , query1(L , R , rson));
return ret;
}
int queryhe(int L,int R,int l,int r,int rt)
{
if (L <= l && r <= R)
{
return SUM[rt];
}
int m = (l + r) >> 1;
int ret = 0;
if (L <= m) ret += queryhe(L , R , lson);
if (R > m) ret += queryhe(L , R , rson);
return ret;
}
int main()
{
int n , m;
while (~scanf("%d%d",&n,&m))
{
build(1 , n , 1);
while (m --) {
char op[2];
int a , b;
scanf("%s%d%d",op,&a,&b);
if (op[0] == 'Q') //区间求最大
{
/* for(int i = 1;i<=10;i++)
printf("%d ",MAX[i]);
puts("");*/
printf("%d\n",query(a , b , 1 , n , 1));
}
else if(op[0]=='U') //单点替换
update(a , b , 1 , n , 1);
else if(op[0]=='M')//区间求最小
{
/*for(int i = 1;i<=10;i++)
printf("%d ",MIN[i]);
puts("");*/
printf("%d\n",query1(a , b , 1 , n , 1));
}
else if(op[0]=='H')//区间求和
{
printf("%d\n",queryhe(a , b , 1 , n , 1));
}
else if(op[0]=='S')//单点增加
{
scanf("%d%d",&a,&b);
update1(a , b , 1 , n , 1);
}
else if(op[0]=='E')//单点减少
{
scanf("%d%d",&a,&b);
update1(a , -b , 1 , n , 1);
}
}
}
return 0;
}
//区间替换
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#define max(a,b) (a>b)?a:b
#define min(a,b) (a>b)?b:a
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
#define LL long long
const int maxn = 100100;
using namespace std;
int lazy[maxn<<2];
int sum[maxn<<2];
void PushUp(int rt)//由左孩子、右孩子向上更新父节点
{
sum[rt] = sum[rt<<1] + sum[rt<<1|1];
}
void PushDown(int rt,int m) //向下更新
{
if (lazy[rt]) //懒惰标记
{
lazy[rt<<1] = lazy[rt<<1|1] = lazy[rt];
sum[rt<<1] = (m - (m >> 1)) * lazy[rt];
sum[rt<<1|1] = ((m >> 1)) * lazy[rt];
lazy[rt] = 0;
}
}
void build(int l,int r,int rt)//建树
{
lazy[rt] = 0;
if (l== r)
{
scanf("%d",&sum[rt]);
return ;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
PushUp(rt);
}
void update(int L,int R,int c,int l,int r,int rt)//更新
{
//if(L>l||R>r) return;
if (L <= l && r <= R)
{
lazy[rt] = c;
sum[rt] = c * (r - l + 1);
//printf("%d %d %d %d %d\n", rt, sum[rt], c, l, r);
return ;
}
PushDown(rt , r - l + 1);
int m = (l + r) >> 1;
if (L <= m) update(L , R , c , lson);
if (R > m) update(L , R , c , rson);
PushUp(rt);
}
LL query(int L,int R,int l,int r,int rt)
{
if (L <= l && r <= R)
{
//printf("%d\n", sum[rt]);
return sum[rt];
}
PushDown(rt , r - l + 1);
int m = (l + r) >> 1;
LL ret = 0;
if (L <= m) ret += query(L , R , lson);
if (m < R) ret += query(L , R , rson);
return ret;
}
int main()
{
int n , m;
char str[5];
while(scanf("%d%d",&n,&m))
{
build(1 , n , 1);
while (m--)
{
scanf("%s",str);
int a , b , c;
if(str[0]=='T')
{
scanf("%d%d%d",&a,&b,&c);
update(a , b , c , 1 , n , 1);
}
else if(str[0]=='Q')
{
scanf("%d%d",&a,&b);
cout<<query(a,b,1,n,1)<<endl;
}
}
}
return 0;
}
//区间增减
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#define max(a,b) (a>b)?a:b
#define min(a,b) (a>b)?b:a
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
#define LL __int64
const int maxn = 100100;
using namespace std;
LL lazy[maxn<<2];
LL sum[maxn<<2];
void putup(int rt)
{
sum[rt] = sum[rt<<1] + sum[rt<<1|1];
}
void putdown(int rt,int m)
{
if (lazy[rt])
{
lazy[rt<<1] += lazy[rt];
lazy[rt<<1|1] += lazy[rt];
sum[rt<<1] += lazy[rt] * (m - (m >> 1));
sum[rt<<1|1] += lazy[rt] * (m >> 1);
lazy[rt] = 0;
}
}
void build(int l,int r,int rt) {
lazy[rt] = 0;
if (l == r)
{
scanf("%I64d",&sum[rt]);
return ;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
putup(rt);
}
void update(int L,int R,int c,int l,int r,int rt)
{
if (L <= l && r <= R)
{
lazy[rt] += c;
sum[rt] += (LL)c * (r - l + 1);
return ;
}
putdown(rt , r - l + 1);
int m = (l + r) >> 1;
if (L <= m) update(L , R , c , lson);
if (m < R) update(L , R , c , rson);
putup(rt);
}
LL query(int L,int R,int l,int r,int rt)
{
if (L <= l && r <= R)
{
return sum[rt];
}
putdown(rt , r - l + 1);
int m = (l + r) >> 1;
LL ret = 0;
if (L <= m) ret += query(L , R , lson);
if (m < R) ret += query(L , R , rson);
return ret;
}
int main()
{
int n , m;int a , b , c;
char str[5];
scanf("%d%d",&n,&m);
build(1 , n , 1);
while (m--)
{
scanf("%s",str);
if (str[0] == 'Q')
{
scanf("%d%d",&a,&b);
printf("%I64d\n",query(a , b , 1 , n , 1));
}
else if(str[0]=='C')
{
scanf("%d%d%d",&a,&b,&c);
update(a , b , c , 1 , n , 1);
}
}
return 0;
}
最简单的应用就是记录线段是否被覆盖,并随时查询当前被覆盖线段的总长度。那么此时可以在结点结构中加入一个变量int count;代表当前结点代表的子树中被覆盖的线段长度和。这样就要在插入(删除)当中维护这个count值,于是当前的覆盖总值就是根节点的count值了。
另外也可以将count换成bool cover;支持查找一个结点或线段是否被覆盖。
实际上,通过在结点上记录不同的数据,线段树还可以完成很多不同的任务。例如,如果每次插入操作是在一条线段上每个位置均加k,而查询操作是计算一条线段上的总和,那么在结点上需要记录的值为sum。
这里会遇到一个问题:为了使所有sum值都保持正确,每一次插入操作可能要更新O(N)个sum值,从而使时间复杂度退化为O(N)。
解决方案是Lazy思想:对整个结点进行的操作,先在结点上做标记,而并非真正执行,直到根据查询操作的需要分成两部分。
根据Lazy思想,我们可以在不代表原线段的结点上增加一个值toadd,即为对这个结点,留待以后执行的插入操作k值的总和。对整个结点插入时,只更新sum和toadd值而不向下进行,这样时间复杂度可证明为O(logN)。
对一个toadd值为0的结点整个进行查询时,直接返回存储在其中的sum值;而若对toadd不为0的一部分进行查询,则要更新其左右子结点的sum值,然后把toadd值传递下去,再对这个查询本身,左右子结点分别递归下去。时间复杂度也是O(nlogN)。
例题
- hdu 1754 I Hate It
题意:给出一个学生成绩的序列,有两个操作:1.修改一个学生的成绩,2.查询学生A到学生B之间所有学生的最高分。 典型的单点更新求区间最值问题,直接用线段树模板即可。
代码如下:
#include<stdio.h>
#include<string.h>
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define max 200010
int N,M;
int sum[max<<2];
int Max(int a,int b)
{
return a>b?a:b;
}
void PushUp(int rt)
{
sum[rt]=Max(sum[rt<<1],sum[rt<<1|1]);
}
void build(int l,int r,int rt)
{
if(l==r)
{
scanf("%d",&sum[rt]);
return ;
}
int m=(l+r)>>1;
build(lson);
build(rson);
PushUp(rt);
}
int query(int L,int R,int l,int r,int rt)
{
if(l>=L&&r<=R)
{
return sum[rt];
}
int maxn=0;
int m=(l+r)>>1;
if(L<=m)
maxn=Max(maxn,query(L,R,lson));
if(R>m)
maxn=Max(maxn,query(L,R,rson));
return maxn;
}
void update(int L,int num,int l,int r,int rt)
{
if(r==l)
{
sum[rt]=num;
return ;
}
int m=(l+r)>>1;
if(L<=m)
update(L,num,lson);
else
update(L,num,rson);
PushUp(rt);
}
int main()
{
char op[2];
int A,B;
while(scanf("%d%d",&N,&M)!=EOF)
{
memset(sum,0,sizeof(sum));
build(1,N,1);
for(int i=0;i<M;i++)
{
scanf("%s%d%d",op,&A,&B);
if(op[0]=='Q')
{
printf("%d\n",query(A,B,1,N,1));
}
else
{
update(A,B,1,N,1);
}
}
}
return 0;
}
题意参见树状数组例题部分。
需要用到线段树的,update:成段增减,query:区间求和
介绍Lazy思想:lazy-tag思想,记录每一个线段树节点的变化值,当这部分线段的一致性被破坏我们就将这个变化值传递给子区间,大大增加了线段树的效率。
在此通俗的解释Lazy意思,比如现在需要对[a,b]区间值进行加c操作,那么就从根节点[1,n]开始调用update函数进行操作,如果刚好执行到一个子节点,它的节点标记为rt,这时tree[rt].l == a && tree[rt].r == b 这时我们可以一步更新此时rt节点的sum[rt]的值,sum[rt] += c * (tree[rt].r - tree[rt].l + 1),注意关键的时刻来了,如果此时按照常规的线段树的update操作,这时候还应该更新rt子节点的sum[]值,而Lazy思想恰恰是暂时不更新rt子节点的sum[]值,到此就return,直到下次需要用到rt子节点的值的时候才去更新,这样避免许多可能无用的操作,从而节省时间 。
代码如下:
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 100005;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
__int64 sum[N<<2],add[N<<2];
struct Node
{
int l,r;
int mid()
{
return (l+r)>>1;
}
} tree[N<<2];
void PushUp(int rt)
{
sum[rt] = sum[rt<<1] + sum[rt<<1|1];
}
void PushDown(int rt,int m)
{
if(add[rt])
{
add[rt<<1] += add[rt];
add[rt<<1|1] += add[rt];
sum[rt<<1] += add[rt] * (m - (m>>1));
sum[rt<<1|1] += add[rt] * (m>>1);
add[rt] = 0;
}
}
void build(int l,int r,int rt)
{
tree[rt].l = l;
tree[rt].r = r;
add[rt] = 0;
if(l == r)
{
scanf("%I64d",&sum[rt]);
return ;
}
int m = tree[rt].mid();
build(lson);
build(rson);
PushUp(rt);
}
void update(int c,int l,int r,int rt)
{
if(tree[rt].l == l && r == tree[rt].r)
{
add[rt] += c;
sum[rt] += (__int64)c * (r-l+1);
return;
}
if(tree[rt].l == tree[rt].r) return;
PushDown(rt,tree[rt].r - tree[rt].l + 1);
int m = tree[rt].mid();
if(r <= m) update(c,l,r,rt<<1);
else if(l > m) update(c,l,r,rt<<1|1);
else
{
update(c,l,m,rt<<1);
update(c,m+1,r,rt<<1|1);
}
PushUp(rt);
}
__int64 query(int l,int r,int rt)
{
if(l == tree[rt].l && r == tree[rt].r)
{
return sum[rt];
}
PushDown(rt,tree[rt].r - tree[rt].l + 1);
int m = tree[rt].mid();
__int64 res = 0;
if(r <= m) res += query(l,r,rt<<1);
else if(l > m) res += query(l,r,rt<<1|1);
else
{
res += query(l,m,rt<<1);
res += query(m+1,r,rt<<1|1);
}
return res;
}
int main()
{
int n,m;
while(~scanf("%d %d",&n,&m))
{
build(1,n,1);
while(m--)
{
char ch[2];
scanf("%s",ch);
int a,b,c;
if(ch[0] == 'Q')
{
scanf("%d %d", &a,&b);
printf("%I64d\n",query(a,b,1));
}
else
{
scanf("%d %d %d",&a,&b,&c);
update(c,a,b,1);
}
}
}
return 0;
}
题意:求矩形的面积并
题解: 求矩形的并,由于矩形的位置可以多变,因此矩形的面积一下子不好求,这个时候,可以采用“分割”的思想,即把整块的矩形面积分割成几个小矩形的面积,然后求和就行了。
这里我们可以这样做,把每个矩形投影到y坐标轴上来,然后我们可以枚举矩形的 x 坐标,然后检测当前相邻x坐标上y方向的合法长度,两种相乘就是面积。然后关键就是如何用线段树来维护那个 “合法长度”
线段树的节点这样定义
struct node {
int left,right,cov;
double len;
}
cov 表示当前节点区间是否被覆盖,len 是当前区间的合法长度
然后我们通过“扫描线”的方法来进行扫描,枚举 x 的竖边,矩形的左边那条竖边就是入边,右边那条就是出边了。然后把所有这些竖边按照 x 坐标递增排序,每次进行插入操作,由于坐标不一定为整数,因此需要进行离散化处理。每次插入时如果当前区间被完全覆盖,那么就要对 cov 域进行更新。入边 +1 出边 -1,更新完毕后判断当前节点的 cov 域是否大于 0
,如果大于 0,那么当前节点的 len 域就是节点所覆盖的区间。否则,如果是叶子节点,则 len=0。
如果内部节点,则 len=左右儿子的 len 之和。
代码如下:
#include <algorithm>
#include <iostream>
using namespace std;
#define L(x) ( x << 1 )
#define R(x) ( x << 1 | 1 )
double y[1000];
struct Line
{
double x, y1, y2;
int flag;
} line[300];
struct Node
{
int l, r, cover;
double lf, rf, len;
} node[1000];
bool cmp ( Line a, Line b )
{
return a.x < b.x;
}
void length ( int u )
{
if ( node[u].cover > 0 )
{
node[u].len = node[u].rf - node[u].lf;
return;
}
else if ( node[u].l + 1 == node[u].r )
node[u].len = 0; /* 叶子节点,len 为 0 */
else
node[u].len = node[L(u)].len + node[R(u)].len;
}
void build ( int u, int l, int r )
{
node[u].l = l; node[u].r = r;
node[u].lf = y[l]; node[u].rf = y[r];
node[u].len = node[u].cover = 0;
if ( l + 1 == r ) return;
int mid = ( l + r ) / 2;
build ( L(u), l, mid );
build ( R(u), mid, r );
}
void update ( int u, Line e )
{
if ( e.y1 == node[u].lf && e.y2 == node[u].rf )
{
node[u].cover += e.flag;
length ( u );
return;
}
if ( e.y1 >= node[R(u)].lf )
update ( R(u), e );
else if ( e.y2 <= node[L(u)].rf )
update ( L(u), e );
else
{
Line temp = e;
temp.y2 = node[L(u)].rf;
update ( L(u), temp );
temp = e;
temp.y1 = node[R(u)].lf;
update ( R(u), temp );
}
length ( u );
}
int main()
{
//freopen("a.txt","r",stdin);
int n, t, i, Case = 0;
double x1, y1, x2, y2, ans;
while ( scanf("%d",&n) && n )
{
for ( i = t = 1; i <= n; i++, t++ )
{
scanf("%lf%lf%lf%lf",&x1, &y1, &x2, &y2 );
line[t].x = x1;
line[t].y1 = y1;
line[t].y2 = y2;
line[t].flag = 1;
y[t] = y1;
t++;
line[t].x = x2;
line[t].y1 = y1;
line[t].y2 = y2;
line[t].flag = -1;
y[t] = y2;
}
sort ( line + 1, line + t, cmp );
sort ( y + 1, y + t );
build ( 1, 1, t-1 );
update ( 1, line[1] );
ans = 0;
for ( i = 2; i < t; i++ )
{
ans += node[1].len * ( line[i].x - line[i-1].x );
update ( 1, line[i] );
}
printf ( "Test case #%d\n", ++Case );
printf ( "Total explored area: %.2lf\n\n", ans );
}
return 0;
}
二维线段树
与一维线段树类似,把线段树的每一个区间端点想象为一棵新的线段树。我们可以用树套树的方式实现,即每个外层线段树的节点对应于一颗内层线段树。如果外层线段树根对应的区间是x方向的[1,n],那么内层线段树根节点对应的区间是y方向的[1,m],那么整个线段是可以存在一个n行m列的二维数组中。
也可以用一个外层线段树节点力存一颗内层线段树的方式来实现。
所有性质与线段树类似,插入,删除,查找等时间复杂度为O(logn * logm)
我们用一道例题来详细解释二维线段树的用法:
题意:每次操作可以是编辑某个矩形区域,这个区域的0改为1,1改为0,每次查询只查询某一个点的值是0还是1.
使用二维线段树,在修改的时候只需要先找到第一维的对应区间,在在这个区间的弟二维中查找对应区间,再做修改即可。而查找的时候,由于不同的第一维区间可能会有包含关系,所以需要对每个目标所在第一维区间查找第二维区间。
比如线段树的区间大小是3×3,那么在查找第一维区间是[1,2],第二维区间是[1,2]时,就需要在线段树第一维的[1,3]和[1,2]两个区间对第二维进行查找,因为修改操作的时候可能修改了第一维的[1,3]区间,同时也修改了[1,2]区间,这样的话就不能仅仅只查找某一个第一维的区间。
至于本题的解法,我们可以在修改时标记某一个节点,那么这个节点以下的区间就都是要修改的,当我们在查找的时候,只需要统计查找到这个点时,一路上有多少个被修改的区间,是偶数说明呗修改回来了,是奇数那就是被修改了。
代码如下:
#include <stdio.h>
#include <string.h>
#define xlson kx<<1, xl, mid
#define xrson kx<<1|1, mid+1, xr
#define ylson ky<<1, yl, mid
#define yrson ky<<1|1, mid+1, yr
#define MAXN 1005
#define mem(a) memset(a, 0, sizeof(a))
bool tree[MAXN<<2][MAXN<<2];
int X, N, T;
int num, X1, X2, Y1, Y2;
char ch;
void editY(int kx,int ky,int yl,int yr)
{
if(Y1<=yl && yr<=Y2)
{
tree[kx][ky] = !tree[kx][ky];
return ;
}
int mid = (yl+yr)>>1;
if(Y1 <= mid) editY(kx,ylson);
if(Y2 > mid) editY(kx,yrson);
}
void editX(int kx,int xl,int xr)
{
if(X1<=xl && xr<=X2)
{
editY(kx,1,1,N);
return ;
}
int mid = (xl+xr)>>1;
if(X1 <= mid) editX(xlson);
if(X2 > mid) editX(xrson);
}
void queryY(int kx,int ky,int yl,int yr)
{
if(tree[kx][ky]) num ++;
if(yl==yr) return ;
int mid = (yl+yr)>>1;
if(Y1 <= mid) queryY(kx,ylson);
else queryY(kx,yrson);
}
void queryX(int kx,int xl,int xr)
{
queryY(kx,1,1,N);
if(xl==xr) return ;
int mid = (xl+xr)>>1;
if(X1 <= mid)queryX(xlson);
else queryX(xrson);
}
int main()
{
while(~scanf("%d", &X))while(X--)
{
mem(tree);
scanf("%d %d%*c", &N,&T);
for(int i=0;i<T;i++)
{
scanf("%c %d %d%*c",&ch,&X1,&Y1);
if(ch == 'C')
{
scanf("%d %d%*c", &X2, &Y2);
editX(1,1,N);
}
else
{
num = 0;
queryX(1,1,N);
if(num & 1)printf("1\n");
else printf("0\n");
}
}
if(X) printf("\n");
}
return 0;
}
其实使用二维数组也可以解这道题,只是把线段树部分换为树状数组。代码更为短小简洁。
代码如下:
#include<cstdio>
#include<string>
#include<iostream>
#define N 1005
int c[N][N],n;
int bit(int n)
{
return n&(-n);
}
int sum(int x,int y)
{
int ans=0;
for(int i=x;i>0;i-=bit(i))
for(int j=y;j>0;j-=bit(j))
{
ans+=c[i][j];
}
return ans;
}
void up(int x,int y,int k)
{
for(int i=x;i<=n;i+=bit(i))
for(int j=y;j<=n;j+=bit(j))
{
c[i][j]+=k;
}
}
int main()
{
int cc,t,x1,x2,y1,y2,a,b;
char ch;
scanf("%d",&cc);
for(int i=0;i<cc;i++)
{
memset(c,0,sizeof(c));
scanf("%d%d",&n,&t);
getchar();
for(int j=0;j<t;j++)
{
scanf("%c",&ch);
if(ch=='C')
{
scanf("%d%d%d%d",&x1,&y1,&x2,&y2);
getchar();
up(x1,y1,1);
up(x1,y2+1,1);
up(x2+1,y1,1);
up(x2+1,y2+1,1);
}
else
{
scanf("%d%d",&a,&b);
getchar();
printf("%d\n",sum(a,b)%2);
}
}
printf("\n");
}
return 0;
}