Codeforces 961G. Partitions(斯特林数)

Author Avatar
Sakits 5月 13, 2018

题解

  套路题,虽然题解的做法不是套路,但是有一种很套路的做法…
  首先很显然可以看出答案为以下式子:

ans=cnt×i=1nwians=cnt\times \sum_{i=1}^nw_i

  其中cntcnt就是一个元素在每种方案中被分到的集合大小之和
  显然可以枚举集合大小,一开始忘记乘上集合大小列的式子和题解不一样懵逼了好久= =…
  推导如下:

cnt=i=1ni(n1i1){nik1}=i=1ni(n1i1)j=0k1(1)jj!×(kj1)ni(kj1)!=j=0k1(1)jj!(kj1)!i=1ni(n1i1)(kj1)ni\begin{aligned} cnt&=\sum_{i=1}^ni\binom{n-1}{i-1} \begin{Bmatrix} n-i\\ k-1 \end{Bmatrix}\\ &=\sum_{i=1}^ni\binom{n-1}{i-1}\sum_{j=0}^{k-1}\frac{(-1)^j}{j!}\times \frac{(k-j-1)^{n-i}}{(k-j-1)!}\\ &=\sum_{j=0}^{k-1}\frac{(-1)^j}{j!(k-j-1)!}\sum_{i=1}^{n}i\binom{n-1}{i-1}(k-j-1)^{n-i}\\ \end{aligned}

  考虑后面的组合数部分:

i=1ni(n1i1)(kj1)ni=i=1n(n1i1)(kj1)ni+i=1n(i1)(n1i1)(kj1)ni=i=0n1(n1i)(kj)ni1+i=0n1i(n1i)(kj1)ni1=(kj)n1+(n1)i=1n1(n2i1)(kj1)ni1=(kj)n1+(n1)i=0n2(n2i)(kj1)ni2=(kj)n1+(n1)(kj)n2\begin{aligned} &\sum_{i=1}^{n}i\binom{n-1}{i-1}(k-j-1)^{n-i}\\ =&\sum_{i=1}^{n}\binom{n-1}{i-1}(k-j-1)^{n-i}+\sum_{i=1}^{n}(i-1)\binom{n-1}{i-1}(k-j-1)^{n-i}\\ =&\sum_{i=0}^{n-1}\binom{n-1}{i}(k-j)^{n-i-1}+\sum_{i=0}^{n-1}i\binom{n-1}{i}(k-j-1)^{n-i-1}\\ =&(k-j)^{n-1}+(n-1)\sum_{i=1}^{n-1}\binom{n-2}{i-1}(k-j-1)^{n-i-1}\\ =&(k-j)^{n-1}+(n-1)\sum_{i=0}^{n-2}\binom{n-2}{i}(k-j-1)^{n-i-2}\\ =&(k-j)^{n-1}+(n-1)(k-j)^{n-2}\\ \end{aligned}

  然后代回去就完了。推导2小时,代码5分钟
  套路在于把斯特林数的容斥系数拿去和组合数配对用二项式定理优化成数幂。

代码

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<algorithm>
#define MOD(x) ((x)>=mod?(x)-mod:(x))
#define ll long long
using namespace std;
const int maxn=500010, inf=1e9, mod=1e9+7;
int n, k, sum, ans, x;
int fac[maxn], inv[maxn];

inline void read(int &k)
{
	int f=1; k=0; char c=getchar();
	while(c<'0' || c>'9') c=='-'&&(f=-1), c=getchar();
	while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
	k*=f;	
}

inline int power(int a, int b)
{
	if(b<0) return 0; int ans=1;
	for(;b;b>>=1, a=1ll*a*a%mod)
	if(b&1) ans=1ll*ans*a%mod;
	return ans;
}
 
int main()
{
	read(n); read(k);
	for(int i=1;i<=n;i++) read(x), sum=MOD(sum+x);
	fac[0]=1; for(int i=1;i<=k;i++) fac[i]=1ll*fac[i-1]*i%mod;
	inv[k]=power(fac[k], mod-2); for(int i=k;i;i--) inv[i-1]=1ll*inv[i]*i%mod;
	for(int j=0;j<k;j++)
	ans=(ans+1ll*power(mod-1, j)*inv[j]%mod*inv[k-j-1]%mod*(power(k-j+mod, n-1)+1ll*(n-1)*power(k-j+mod, n-2)%mod))%mod;
	printf("%lld\n", 1ll*sum*ans%mod);
}