题解

给定一个序列,要求实现区间取$min$(即区间内$a_i$更新为$min(a_i,val)$),区间求$max$,区间求$sum$。

做法:

维护区间最大值$mx$,区间严格次大值$sx$,区间最大值出现的次数$cx$,然后进行分类讨论:

  1、$val>=mx$,明显对区间无影响,退出;

  2、$sx<val<mx$,此时$mx$会被更改成$val$,而$sx,cx$则不变,对区间和$sum$的贡献为$(mx-val)*cx$;(此处不能使$sx=val$,否则在更新后$mx=cx$,此时$cx$也需要改变)

  3、$val<=sx$,跳到当前节点的两个子节点去处理这种情况再$pushup$回来。

在以上算法中,区间取$min$操作实际上被转化为了对区间的最大值集体减小一个数的操作,且这次操作不会让最大值比严格次大值更小。

另外,我们发现$mx$本身就可以作为取$min$的标记,所以我们无需另外打标记。

时间复杂度为$O(log^2n)$。

代码

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6+5;
int Case;
int n,m;
int a[N];

struct STB{
int tot;
int ls[N<<1],rs[N<<1];
int mx[N<<1];//区间最大值
int sx[N<<1];//区间严格次大值
int cx[N<<1];//区间最大值出现的次数
ll sum[N<<1];

void up(int x){
sum[x]=sum[ls[x]]+sum[rs[x]];
mx[x]=max(mx[ls[x]],mx[rs[x]]);
sx[x]=max(sx[ls[x]],sx[rs[x]]);
if(mx[ls[x]]!=mx[rs[x]])sx[x]=max(sx[x],min(mx[ls[x]],mx[rs[x]]));
cx[x]=0;
if(mx[x]==mx[ls[x]])cx[x]+=cx[ls[x]];
if(mx[x]==mx[rs[x]])cx[x]+=cx[rs[x]];
}

int build(int l,int r){
int x=++tot;
if(l==r){
mx[x]=sum[x]=a[l];
sx[x]=-1e9;
cx[x]=1;
return x;
}
int mid=l+r>>1;
ls[x]=build(l,mid);
rs[x]=build(mid+1,r);
up(x);
return x;
}

void downmin(int x,int val){
sum[x]-=1ll*(mx[x]-val)*cx[x];
mx[x]=val;
}

void down(int x){
if(mx[x]<mx[ls[x]])downmin(ls[x],mx[x]);
if(mx[x]<mx[rs[x]])downmin(rs[x],mx[x]);
}

void changemin(int x,int l,int r,int L,int R,int val){
if(val>=mx[x])return;
if(L<=l&&r<=R&&val>sx[x]){
downmin(x,val);
return;
}
down(x);
int mid=l+r>>1;
if(L<=mid)changemin(ls[x],l,mid,L,R,val);
if(R>mid)changemin(rs[x],mid+1,r,L,R,val);
up(x);
}

ll getsum(int x,int l,int r,int L,int R){
if(L<=l&&r<=R)return sum[x];
down(x);
int mid=l+r>>1;
ll sum=0;
if(L<=mid)sum+=getsum(ls[x],l,mid,L,R);
if(R>mid)sum+=getsum(rs[x],mid+1,r,L,R);
up(x);
return sum;
}

int getmax(int x,int l,int r,int L,int R){
if(L<=l&&r<=R)return mx[x];
down(x);
int mid=l+r>>1;
int mx=0;
if(L<=mid)mx=max(mx,getmax(ls[x],l,mid,L,R));
if(R>mid)mx=max(mx,getmax(rs[x],mid+1,r,L,R));
up(x);
return mx;
}

void clear(){
for(int i=1;i<=tot;++i){
ls[i]=rs[i]=0;
mx[i]=sx[i]=cx[i]=0;
sum[i]=0;
}
tot=0;
}
}T;

int main(){
scanf("%d",&Case);
while(Case--){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i)scanf("%d",&a[i]);
T.clear();
T.build(1,n);
for(int i=1,o,x,y,t;i<=m;++i){
scanf("%d%d%d",&o,&x,&y);
if(o==0){
scanf("%d",&t);
T.changemin(1,1,n,x,y,t);
}
else if(o==1){
printf("%d\n",T.getmax(1,1,n,x,y));
}
else{
printf("%lld\n",T.getsum(1,1,n,x,y));
}
}
}
}