高斯费马,树上开花

前言

在昨天(2019/1/23)晚上的数论 & 图论爆炸后,作者开始接受线段树的洗礼……


初始约定:

在本文中,很多时候会出现[begin , end]

这表示一段数组上的区间:从begin 到 end

其中,begin为开头,end为结尾。

就像下图:

5jEGTp.png

但值得一提的是,begin在这个区间内,但end不在


何为线段树?

老惯例,先看百度百科:

线段树(Segment Tree)是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
——百度百科

算了……

不看了……

还是看下通俗点的解释吧。

线段树就是一种有点特殊的二叉树(废话)

说它特殊,是因为他的每一个结点不只代表一个数,而是代表一段区间。

就像下图:

kZgrhd.png

在上图中,树上的根节点代表整个序列之和(ps. [0 , 12])

树上每一个结点下面的[begin , end]代表下面的黄色点中相应区间的和。

每一层向下分裂时,都把本层的区间拆为两部分,直到拆到不能再拆为止,不能再拆的部分就成了树的叶子结点,叶子结点代表单个元素。

(ps.这就是为什么树上的叶子结点与下面的黄色点一一对应。)

同样,线段树的每一颗子树也是一颗线段树。


线段树的建立

在使用线段树之前,我们需要先种一颗线段树。

但……

我们如何把像下图一样的数列搬到树上呐?

eLADCd.jpg

很简单

首先我们需要一个结构体,来表示线段树的每一个节点

struct Node{
int num;
int add_tag;
……
}

其中,$num$代表的是节点的值,而$tag$指的是该节点的标记。

标记的类型有很多种,比如add_tag , mul_tag,每种标记都有不同的作用。

现在仔细观察一下,我们可以很容易的发现最底层的叶子节点的值就是原数列上的值,而叶子节点的$begin = end$。

所以我们可以直接修改叶子节点啊!

就像下图:

eLA5Cj.jpg

之后我们进行同样的操作,建立$4$节点的另外一颗子树,而$4$号节点就是这两颗子树的和

就像这样:

eLEpx1.jpg

继续循环建立子树,直到种好一整颗线段树为止

eLEBZT.jpg

这样,我们就建好了一整颗线段树啦OvO

Code

inline int lson(x){     //快速求左子节点
return x << 2;
}

inline int rson(x){ //快速求右子节点
return x << 2 | 1;
}

void pushup(int x){ //父节点的值等于两个子节点的和
node[x].num = node[lson(x)].num + node[rson(x)].num;
}

void build(int x , int l , int r){
if(l == r){
node[x].nun = a[l] //a为原数组
return ;
}
int mid = (l + r) >> 1;
build(lson(x) , l , mid);
build(rson(x) , mid + 1 , r);
pushup(x);
}

线段树的基本操作:

跟其他数据结构一样,线段树也有它的基本操作。下面,我们就来看看。

区间查询

在前面我们已经说过,线段树的每一个节点代表的不仅仅是一个数,而是一段区间

那么……

我们如何去查找一段区间的和呐?

或者说……

我们是不是不用遍历区间内的每一个叶子节点就能查找到区间之和呐?

答案是肯定的。

因为线段树中一个节点代表的是一段区间

因此……

如果两个节点都包含在所要搜索的区间之内,仅仅搜到它们的父节点即可得出答案。

之后把所有搜到的节点的值相加,即可得出答案

可以证明,拆分出区间的数量为$O(log n)$

来看这个图:

kZgrhd.png

下面我们要查找[1 , 6]的区间范围和

首先,我们从根节点开始遍历,在遍历[1 , 1]时,需要遍历到叶子结点。

keXl8I.png

但之后在遍历[2 , 3]时,就没那么麻烦了,直接到父节点遍历即可

keX4R1.png

[4 , 6]更加快捷,直接遍历到第三层即可

kmpys1.png

之后把每一次的结果加起来,就是最后的结果了。

最后提醒一句,查询之前一定要下放标记,即pushdown。下放的解释见代码注释

Code

//  pushdown为标记下放函数
void pushdown(int x , int l , int r){
if(!node[x].add_tag){ //没有标记就直接return
return ;
}
int mid = (l + r) >> 1; //求出左右子树分界,即中间点。
node[lson(x)].add_tag += node[x].add_tag; //为左子节点下放标记
node[lson(x)].num += node[x].add_tag * (mid - l + 1); //更改左子节点的值
node[rson(x)].add_tag += node[x].add_tag; //同上
node[rson(x)].num += node[x].add_tag * (r - mid);
node[x].add_tag = 0 //不要忘记清空标记
}

long long range_query(int x, int l , int r , int ql , int qr){
if(ql <= l && r <= qr){ //如果要查询的区间已经包含这个节点,直接return
return node[x].num;
}
pushdown(x , l , r); //否则,下放标记
int mid = (l + r) >> 1; //左右子树分界
long long ans = 0;
if(ql <= mid){ //如果需要查询的区间在左子树中
ans += range_query(lson(x) , l , mid , ql , qr);
//就在左子树中进行查询
}
if(qr >= mid + 1){ //如果需要查询的区间在右子树中
ans += range_query(rson(x) , mid + 1 , r , ql , qr);
//就在右子树中进行查询
}
return ans;
}

是不是很简单啊(逃)


单点修改

线段树每个节点表示的是一段区间,因此,像普通二叉树那样直接修改节点是不行的。

但是……

线段树的每一个叶子节点仅代表了一个数啊QvQ

所以,我们可以从最简单的叶子节点入手,修改后更新它的每一个父节点,再继续更新父节点的父节点,再继续更新他的父节点的父节点的父节点 …… 直到更新到根节点为止。

每一次修改,更新的节点数最多为树高。

很显然,时间复杂度为$O(log n)$

来看这个图:

kZgrhd.png

现在我们来把[3 , 3]的值改成5。

首先, 我们先遍历到[3 , 3],并把它改为5

km8RTf.png

之后回溯,到它的父节点[2 , 3],并把父节点的值改为71

km820P.png

继续回溯,直到改完根节点为止。下面是成品图:

km8gmt.png

代码的话……就先不写了

直接把下面区间修改的代码搬来用吧。

好像也不难吧(逃 X2)


区间修改

现在,我们已经知道如何修改线段树中单个叶子节点的值了

那……

如果给你一个begin和一个end,要求将[begin , end]中的每个元素都+2,你怎样修改这整个区间的值呐?(ps.一个个修改会TLE哦)

其实……

我们可以把要修改的区间拆分成少量的节点,把修改的数直接修改在这些节点上就可以了。

就像下图:

kZgrhd.png

现在,我们把区间[0 , 5]的值全部改为 1

首先,在第二层,我们把代表[0 , 3]区间的点变成 1

kmtzUx.png

之后,在第三层,我们把代表[4 , 5]区间的点变成 1

kmtXr9.png

最后我们向上回溯,更新上面的节点,直到回溯至根节点。

kmtxV1.png

这样,我们就完成了对线段树区间的修改。

同样,在修改时也一定不要忘了下放与上传标记,即pushdown

Code

//区间加
void range_add(int x , int l , int r , int ql , int qr , long long add){
if(ql <= l && r <= ql){ //如果查询区间直接包含该节点
node[x].add_tag += add; //直接叠加标记
node[x].num = add * (r - l + 1); //加上值
return ;
}
pushdown(x , l , r); //否则,下放标记
int mid = (l + r) >> 1; //求出左右子树的分界
if(ql <= mid){ //如果需要加的区间在左子树中
range_add(lson(x) , l , mid , ql , qr , add); //就加
}
if(qr >= mid + 1){//同上
range_add(rson(x) , mid + 1 , r , ql , qr , add);
}
pushup(x);//上传标记
}


应该不算多难吧(逃 X3)


附件[Segment Tree]

线段树 1

代码描述:见LuoguP3372

Code

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<deque>
#include<algorithm>
#include<vector>
#include<stack>
#include<cmath>
#include<cstdlib>
#include<iomanip>

#define N 500001
#define I_int inline int
#define ll long long

using namespace std;

int n , m;
ll a[N];

namespace Segtree{
struct Node{
ll num;
ll add_tag;
}node[(N << 2) + 1];

I_int lson(int x){
return x << 1;
}

I_int rson(int x){
return x << 1 | 1;
}

void pushup(int x){
node[x].num = node[lson(x)].num + node[rson(x)].num;
}

void pushdown(int x , int l , int r){
if(!node[x].add_tag){
return ;
}
int mid = (l + r) >> 1;
node[lson(x)].add_tag += node[x].add_tag;
node[rson(x)].add_tag += node[x].add_tag;
node[lson(x)].num += node[x].add_tag * (mid - l + 1);
node[rson(x)].num += node[x].add_tag * (r - mid);
node[x].add_tag = 0;
}

void build(int x , int l , int r){
if(l == r){
node[x].num = a[r];
return ;
}
int mid = (l + r) >> 1;
build(lson(x) , l , mid);
build(rson(x) , mid + 1 , r);
pushup(x);
}

void range_add(int x , int l , int r , int ql , int qr , ll v){
if(ql <= l && r <= qr){
node[x].add_tag += v;
node[x].num += v * (r - l + 1);
return ;
}
pushdown(x , l , r);
int mid = (l + r) >> 1;
if(ql <= mid){
range_add(lson(x) , l , mid , ql , qr , v);
}
if(qr >= mid + 1){
range_add(rson(x) , mid + 1 , r , ql , qr , v);
}
pushup(x);
}

ll range_query(int x , int l , int r , int ql , int qr){
if(ql <= l && r <= qr){
return node[x].num;
}
pushdown(x , l , r);
int mid = (l + r) >> 1;
ll ans = 0;
if(ql <= mid){
ans += range_query(lson(x) , l , mid , ql , qr);
}
if(qr >= mid + 1){
ans += range_query(rson(x) , mid + 1 , r , ql , qr);
}
return ans;
}

}

using namespace Segtree;

int main(){
int opt;
ll x , y , k;
cin >> n >> m;
for(int i = 1; i <= n; i ++){
cin >> a[i];
}
build(1 , 1 , n);
for(int i = 1; i <= m; i ++){
cin >> opt;
if(opt == 1){
cin >> x >> y >> k;
range_add(1 , 1 , n , x , y , k);
}
else if(opt == 2){
cin >> x >> y;
cout << range_query(1 , 1 , n , x , y) << endl;
}
}
return 0;
}

线段树 2

代码描述:见LuoguP3373

Code

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<queue>
#include<deque>
#include<stack>
#include<vector>
#include<iomanip>
#include<cstdlib>
#include<cmath>

#define N 500001
#define ll long long
#define I_int inline int

using namespace std;

int n , m , p;
int a[N];

namespace Segtree{
struct Node{
ll num;
ll add_tag;
ll mul_tag;
}node[(N << 2) + 1];

int lson(int x){
return x << 1;
}

int rson(int x){
return x << 1 | 1;
}

void pushup(int x){
node[x].num = node[lson(x)].num + node[rson(x)].num;
node[x].num %= p;
}

void pushdown(int x , int l , int r){
if(node[x].mul_tag != 1){
node[lson(x)].mul_tag *= node[x].mul_tag;
node[lson(x)].mul_tag %= p;
node[lson(x)].add_tag *= node[x].mul_tag;
node[lson(x)].add_tag %= p;
node[lson(x)].num *= node[x].mul_tag;
node[lson(x)].num %= p;

node[rson(x)].mul_tag *= node[x].mul_tag;
node[rson(x)].mul_tag %= p;
node[rson(x)].add_tag *= node[x].mul_tag;
node[rson(x)].add_tag %= p;
node[rson(x)].num *= node[x].mul_tag;
node[rson(x)].num %= p;
node[x].mul_tag = 1;
}
if(node[x].add_tag){
int mid = (l + r) >> 1;
node[lson(x)].add_tag += node[x].add_tag;
node[lson(x)].add_tag %= p;
node[lson(x)].num += node[x].add_tag * (mid - l + 1);
node[lson(x)].num %= p;

node[rson(x)].add_tag += node[x].add_tag;
node[rson(x)].add_tag %= p;
node[rson(x)].num += node[x].add_tag * (r - mid);
node[rson(x)].num %= p;
node[x].add_tag = 0;
}
}

void build(int x , int l , int r){
node[x].mul_tag = 1;
if(l == r){
node[x].num = a[r];
return ;
}
int mid = (l + r) >> 1;
build(lson(x) , l , mid);
build(rson(x) , mid + 1 , r);
pushup(x);
}

void range_add(int x , int l , int r , int ql , int qr , ll v){
if(ql <= l && r <= qr){
node[x].add_tag += v;
node[x].add_tag %= p;
node[x].num += v * (r - l + 1);
node[x].num %= p;
return;
}
pushdown(x , l , r);
int mid = (l + r) >> 1;
if(ql <= mid){
range_add(lson(x) , l , mid , ql , qr , v);
}
if(qr >= mid + 1){
range_add(rson(x) , mid + 1 , r , ql , qr , v);
}
pushup(x);
}

void range_mul(int x , int l , int r , int ql , int qr , ll v){
if(ql <= l && r <= qr){
node[x].mul_tag *= v;
node[x].mul_tag %= p;
node[x].add_tag *= v;
node[x].add_tag %= p;
node[x].num *= v;
node[x].num %= p;
return ;
}
pushdown(x , l , r);
int mid = (l + r) >> 1;
if(ql <= mid){
range_mul(lson(x) , l , mid , ql , qr , v);
}
if(qr >= mid + 1){
range_mul(rson(x) , mid + 1 , r , ql , qr , v);
}
pushup(x);
}
ll range_query(int x , int l , int r , int ql , int qr){
if(ql <= l && qr >= r){
return node[x].num % p;
}
pushdown(x , l , r);
int mid = (l + r) >> 1;
ll ans = 0;
if(ql <= mid){
ans += range_query(lson(x) , l , mid , ql , qr);
ans %= p;
}
if(qr >= mid + 1){
ans += range_query(rson(x) , mid + 1 , r , ql , qr);
ans %= p;
}
return ans % p;
}
}

using namespace Segtree;

int main(){
int opt;
ll x , y , k;
cin >> n >> m >> p;
for(int i = 1; i <= n; i ++){
cin >> a[i];
}
build(1 , 1 , n);
for(int i = 1; i <= m; i ++){
cin >> opt;
if(opt == 1){
cin >> x >> y >> k;
range_mul(1 , 1 , n , x , y , k % p);
}
else if(opt == 2){
cin >> x >> y >> k;
range_add(1 , 1 , n , x , y , k % p);
}
else if(opt == 3){
cin >> x >> y;
cout << range_query(1 , 1 , n , x , y) << endl;
}
}
return 0;
}

参考资料

  1. Menci’s Blog & 课件

鸣谢Menci在线段树学习中给予我的帮助OvO


THE END