bzoj3516:国王奇遇记加强版(数论)

Author Avatar
Sakits 4月 02, 2018

题解

  这题题解的推法看是看懂了,但是考场上遇到了真不知道怎么推,当套路用又太局限了。于是我尝试自己推了推,得到的式子和题解的截然不同,最后居然也可以AC…
  推错了一万个地方,谢谢Blue233333帮忙查错,清清太强啦!!!
  开推
  要求

i=1nimmi\sum_{i=1}^n i^m\cdot m^i

  因为nn太大,尝试化成只与mm有关的递推式。
  设

Sk=i=1nikmiS_k=\sum_{i=1}^n i^k\cdot m^i

  则有:

Sk+(n+1)kmn+1=i=1n+1ikmi=m+i=1n(i+1)kmi+1=m+i=1nj=0k(kj)ijmi+1=m+j=0k(kj)i=1nijmi+1=m+j=0k(kj)Sjm\begin{aligned} S_k+(n+1)^k\cdot m^{n+1}&=\sum_{i=1}^{n+1}i^k\cdot m^i\\ &=m+\sum_{i=1}^n (i+1)^k\cdot m^{i+1}\\ &=m+\sum_{i=1}^n\sum_{j=0}^k\binom{k}{j}i^j\cdot m^{i+1}\\ &=m+\sum_{j=0}^k\binom{k}{j}\sum_{i=1}^ni^j\cdot m^{i+1}\\ &=m+\sum_{j=0}^k\binom{k}{j}S_j\cdot m \end{aligned}

  由上式得:

Sk+(n+1)kmn+1=m+j=0k(kj)SjmSk+(n+1)kmn+1=m+j=0k1(kj)Sjm+Skm(m1)Sk=(n+1)kmn+1j=0k1(kj)SjmmSk=(n+1)kmn+1j=0k1(kj)Sjmmm1\begin{gathered} S_k+(n+1)^k\cdot m^{n+1}=m+\sum_{j=0}^k\binom{k}{j}S_j\cdot m\\ S_k+(n+1)^k\cdot m^{n+1}=m+\sum_{j=0}^{k-1}\binom{k}{j}S_j\cdot m+S_k\cdot m\\ (m-1)S_k=(n+1)^k\cdot m^{n+1}-\sum_{j=0}^{k-1}\binom{k}{j}S_j\cdot m-m\\ S_k=\frac{(n+1)^k\cdot m^{n+1}-\sum_{j=0}^{k-1}\binom{k}{j}S_j\cdot m-m}{m-1} \end{gathered}

  推导结束,由此我们可以以O(m2)O(m^2)的时间复杂度求出SmS_m

代码

#include<iostream> 
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<cmath> 
#include<algorithm> 
#define MOD(x) ((x)>=mod?(x)-mod:(x))
using namespace std;
const int maxn=1010, inf=1e9, mod=1e9+7;
int n, m;
int c[maxn][maxn], f[maxn];
inline void read(int &k)
{
	int f=1; k=0; char c=getchar();
	while(c<'0' || c>'9') c=='-'&&(f=-1), c=getchar();
	while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
	k*=f;
}
inline int power(int a, int b)
{
	int ans=1;
	for(;b;b>>=1, a=1ll*a*a%mod)
	if(b&1) ans=1ll*ans*a%mod;
	return ans;	
} 
int main()
{
	read(n); read(m);
	if(m==1) return printf("%lld\n", (1ll*(n+1)*n>>1)%mod), 0;
	c[0][0]=1; 
	for(int i=1;i<=m;i++) 
	{
		c[i][0]=1;
		for(int j=1;j<=i;j++) c[i][j]=MOD(c[i-1][j]+c[i-1][j-1]);
	}
	int sum=power(m, n+1), fm=power(m-1, mod-2);
	f[0]=1ll*(sum-m+mod)*fm%mod;
	for(int i=1;i<=m;i++)
	{
		sum=1ll*sum*(n+1)%mod; f[i]=MOD(sum-m+mod);
		for(int j=0;j<i;j++)
			f[i]=(f[i]-1ll*c[i][j]*f[j]%mod*m%mod+mod)%mod;
		f[i]=1ll*f[i]*fm%mod;
	}
	printf("%d\n", f[m]);
}