bzoj4785:[Zjoi2017]树状数组(树套树)

Author Avatar
Sakits 4月 10, 2018

题解

  这题当作存一下树套树板子吧…
  首先可以一眼看出其实算的是后缀和。
  若l>1l>1,要查一个区间[l,r][l,r]的时候,查的实际上是[l1,r1][l-1,r-1],所以只需要看l1,rl-1,r这两个位置修改的次数奇偶性是否相同即可,那么我们可以用平面上的一个点(l1,r)(l-1,r)来表示这个概率,用一个树套树来维护即可。
  两个点被修改的概率分别为p1,p2p_1,p_2,那么修改次数奇偶性相同的概率就是p1p2+(1p1)(1p2)p_1p_2+(1-p_1)(1-p_2)
  若l=1l=1,实际上就是查rr的前缀后缀修改次数奇偶性相同的概率,这个也是可以维护的。
  修改的时候,[1,l1],[r+1,n][1,l-1],[r+1,n][l,r][l,r]相同概率是pp[l,r][l,r]内部相同概率是1p21-p*2,合并的时候用上式合并即可。维护前缀后缀修改次数奇偶性相等的话,只有rr被修改的时候才可能相等,所以[l,r][l,r]内概率为pp[1,l1],[r+1,n][1,l-1],[r+1,n]概率为00,合并同理。

代码

#include<iostream> 
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<cmath> 
#include<algorithm> 
#define MOD(x) ((x)>=mod?(x)-mod:(x))
#define lt (tree[x].ls)
#define rt (tree[x].rs)
using namespace std;
const int maxn=100010, inf=1e9, mod=998244353;
struct poi{int sum, ls, rs;}tree[maxn*280]; 
int n, m, ty, l, r, ans, tott;
int root[1<<20];

char buf[40000010],*ptr=buf-1;
inline void read(int &k)
{
	char c=*++ptr; k=0;
	while(c<48 || c>57) c=*++ptr;
	while(c>=48 && c<=57) k=k*10+c-'0', c=*++ptr;
}

inline int power(int a, int b)
{
	int ans=1;
	for(;b;b>>=1, a=1ll*a*a%mod)
	if(b&1) ans=1ll*ans*a%mod;
	return ans;
}

inline void addone(int &x, int y){x=(1ll*x*y+1ll*(1-x+mod)*(1-y+mod))%mod;}

void update2(int &x, int l, int r, int cl, int cr, int delta)
{
	if(!x) x=++tott, tree[x].sum=1;
	if(cl<=l && r<=cr){addone(tree[x].sum, delta); return;}
	int mid=(l+r)>>1;
	if(cl<=mid) update2(lt, l, mid, cl, cr, delta);
	if(cr>mid) update2(rt, mid+1, r, cl, cr, delta);
}

void query2(int x, int l, int r, int cx)
{
	if(!x) return; 
	addone(ans, tree[x].sum); if(l==r) return;
	int mid=(l+r)>>1;
	if(cx<=mid) query2(lt, l, mid, cx);
	else query2(rt, mid+1, r, cx);
}

void update1(int x, int l, int r, int cl, int cr, int cl2, int cr2, int delta)
{
	if(cl<=l && r<=cr){update2(root[x], 1, n, cl2, cr2, delta); return;}
	int mid=(l+r)>>1;
	if(cl<=mid) update1(x<<1, l, mid, cl, cr, cl2, cr2, delta);
	if(cr>mid) update1(x<<1|1, mid+1, r, cl, cr, cl2, cr2, delta); 
}

void query1(int x, int l, int r, int cx, int cx2)
{
	query2(root[x], 1, n, cx2);
	if(l==r) return;
	int mid=(l+r)>>1;
	if(cx<=mid) query1(x<<1, l, mid, cx, cx2);
	else query1(x<<1|1, mid+1, r, cx, cx2); 
}

int main()
{
	fread(buf,1,sizeof(buf),stdin); read(n); read(m);
	for(int i=1;i<=m;i++)
	{
		read(ty); read(l); read(r);
		if(ty==1)
		{
			int p=power(r-l+1, mod-2);
			if(l>1) 
			{
				update1(1, 1, n, 1, l-1, l, r, MOD(1-p+mod));
				update2(root[0], 1, n, 1, l-1, 0);
			}
			if(r<n)
			{
				update1(1, 1, n, l, r, r+1, n, MOD(1-p+mod));
				update2(root[0], 1, n, r+1, n, 0);
			}
			if(l!=r) update1(1, 1, n, l, r, l, r, MOD(MOD(1-p+mod)-p+mod));
			update2(root[0], 1, n, l, r, p);
		}
		else
		{
			ans=1; 
			if(l==1) query2(root[0], 1, n, r);
			else query1(1, 1, n, l-1, r);
			printf("%d\n", ans);
		}
	}
}