0
点赞
收藏
分享

微信扫一扫

Codeforces Round #742 (Div. 2)E. Non-Decreasing Dilemma(线段树统计上升序列个数)


题解:线段树内每个节点维护区间内上升序列个数sum,以左端点为起点的最长上升子序列长度(记为ll)和以右端点为起点的最长上升子序列长度(记为rr)。
关于维护信息:pushup合并左右儿子区间到父节点时,如果左儿子右端点大于右儿子左端点则父节点的sum即左右儿子的sum之和,否则还要加上左儿子右端跟右儿子左端组合的情况(即ll*rr)。
关于ll 和 rr 的维护这里就不赘述了,详见代码。
AC代码:

#include <bits/stdc++.h>
#define int long long
#define x first
#define y second
using namespace std;
typedef pair<int,int> PII;
const int N=1e6+5,INF=1e16;
int n,m;
int w[N];

struct node{
int l,r;
int sum;//sum表示当前区间内上升序列的个数
int ll,rr;//ll(rr)分别表示以左端点开头(右端点结尾)的最长上升序列的长度
}tr[N*2];

int get_len(node t){
return t.r-t.l+1;
}

void pushup(int u){
auto l=tr[u<<1],r=tr[u<<1|1];

//更新sum
tr[u].sum=l.sum+r.sum;
if(w[l.r]<=w[r.l])tr[u].sum+=l.rr*r.ll;

//更新ll
if(l.ll==get_len(l)){
tr[u].ll=l.ll;
if(w[l.r]<=w[r.l])tr[u].ll+=r.ll;
}
else tr[u].ll=l.ll;

//更新rr
if(r.rr==get_len(r)){
tr[u].rr=r.rr;
if(w[l.r]<=w[r.l])tr[u].rr+=l.rr;
}
else tr[u].rr=r.rr;
}

void build(int u,int l,int r){
if(l==r){
tr[u]={l,r,1,1,1};
return;
}

tr[u]={l,r,0,0,0};

int mid=l+r>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
pushup(u);

}


void modify(int u,int x,int v){
if(tr[u].l==tr[u].r){
if(tr[u].l==x)w[x]=v;
//pushup(u);
return;
}

int mid=tr[u].l+tr[u].r>>1;
if(x<=mid) modify(u<<1,x,v);
if(x>mid) modify(u<<1|1,x,v);
pushup(u);

//cout<<tr[u].l<<" "<<tr[u].r<<"****"<<tr[u].ll<<" "<<tr[u].rr<<"*****"<<tr[u].sum<<" "<<w[tr[u].l]<<endl;
}

int query(int u,int l,int r){
if(tr[u].l>=l&&tr[u].r<=r)return tr[u].sum;

int mid=tr[u].l+tr[u].r>>1;
if(r<=mid) return query(u<<1,l,r);
else if(l>mid) return query(u<<1|1,l,r);
else{
int res=query(u<<1,l,r)+query(u<<1|1,l,r);

auto L=tr[u<<1],R=tr[u<<1|1];
int k=0;
if(w[L.r]<=w[R.l])k=min(L.r-l+1,L.rr)*min(r-R.l+1,R.ll);

return res+k;
}
}


main(){
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>w[i];

build(1,1,n);

while(m--){
int op,a,b;
cin>>op>>a>>b;
if(op==1) modify(1,a,b);
else cout<<query(1,a,b)<<endl;
}
}


举报

相关推荐

0 条评论