母函数
算法简述
普通母函数:对于序列构造一函数称函数G(x)是序列的母函数。
指数型母函数:对于序列函数称为序列的指数型母函数。这样对于一个多重集,其中重复了次,重复了次,重复了次,如果从n个元素中取r个元素排列,不同的排列数所对应的指数型母函数为。
主要运用场合及思路
母函数可以帮助我们有效的优化算法,解决问题,往往需要我们有良好的分析问题的能力,能将问题转化为母函数的模型上,而且往往需要进行数学运算求解。
模板
普通母函数(hdu2082):
#include <stdio.h>
#include <string.h>
int a[55],b[55]; //a数组保存最后的结果
int main(){
int n;
scanf("%d",&n);
while(n --){
memset(a,0,sizeof(a));
memset(b,0,sizeof(b));
a[0] = 1;
int num;
for(int i = 1;i <= 26;i ++){
scanf("%d",&num);
for(int j = 0;j <= 50;j ++) //j表示前面i个表达式累乘得到的表达式里第j个变量
for(int k = 0;k <= num && k*i+j <= 50;k ++)//k表示的是此时计算的多项式的第k个指数
b[k*i+j] += a[j];
for(int j = 0;j <= 50;j ++){
a[j] = b[j];
b[j] = 0;
}
}
int ans = 0;
for(int i = 1;i <= 50;i ++)
ans += a[i];
printf("%d\n",ans);
}
return 0;
}
指数型母函数(hdu1521):
#include <stdio.h>
#include <algorithm>
#include <stdio.h>
using namespace std;
double fac[] = {1,1,2,6,24,120,720,5040,40320,362880,3628800};
int main(){
double num1[11],num2[11],a[11];
int n,m;
while(~scanf("%d%d",&n,&m)){
for(int i = 0;i < n;i ++)
scanf("%lf",&a[i]);
for(int i = 0;i <= m;i ++)
num2[i] = num1[i] = 0.0;
for(int i = 0;i <= a[0];i ++)
num1[i] = 1.0 / fac[i];
for(int i = 1;i < n;i ++){
for(int j = 0;j <= m;j ++)
for(int k = 0;k <= a[i] && k + j <= m;k ++)
num2[j+k] += num1[j]/fac[k];
for(int j = 0;j <= m;j ++){
num1[j] = num2[j];
num2[j] = 0;
}
}
printf("%.0lf\n",num1[m] *1.0* fac[m]);
}
return 0;
}
例题
hdu5616
题意:有n个质量已知的砝码和一个天平,砝码可以放在天平的左端或者右端,求能否称出某个质量。
思路:此题有多种解法,可以构造母函数:,注意到此时可以把砝码放在天平的左端或者右端,所以的指数的系数可正可负,最后若求能否称出,只需查看的系数是否为0就可以了。
参考代码:
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
int a[2010],b[2010];
int num[25];
int main(){
int t;
scanf("%d",&t);
while(t --){
int n,sum = 0;
scanf("%d",&n);
for(int i = 1;i <= n;i ++){
scanf("%d",&num[i]);
sum += num[i];
}
memset(a,0,sizeof(a));
memset(b,0,sizeof(b));
a[0] = 1;
for(int i = 1;i <= n;i ++){
for(int j = 0;j <= sum;j ++)
for(int k = 0;k <= 1 && k*num[i]+j <= sum;k ++){
b[k*num[i]+j] += a[j];
b[abs(k*num[i]-j)] += a[j];
}
for(int j = 0;j <= sum;j ++){
a[j] = b[j];
b[j] = 0;
}
}
int m;
scanf("%d",&m);
while(m --){
int w;
scanf("%d",&w);
if(w > sum) {
printf("NO\n");
continue;
}
if(a[w]) printf("YES\n");
else printf("NO\n");
}
}
return 0;
}
poj3734
题意:有一排砖,数量为n,有红蓝绿黄4种颜色,其中染成红和绿颜色的砖块的数量必须为偶数个,求可有多少种染色方案。
思路:根据题意构造一个指数型母函数:,根据泰勒展开,所以,所以最终答案为。
参考代码:
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
#define mod 10007
int quick(int a,int b){
int ans = 1;
while(b){
if(b&1) ans = ans * a % mod;
a = a * a % mod;
b >>= 1;
}
return ans;
}
int main(){
int t;
scanf("%d",&t);
while(t --){
int n;
scanf("%d",&n);
printf("%d\n",(quick(4,n-1)+quick(2,n-1)) % mod);
}
return 0;
}
poj 1322 Chocolate
题意:一个口袋中装有巧克力,巧克力的颜色有c种。现从口袋中取出一个巧克力,若取出的巧克力与桌上已有巧克力颜色相同,则将两个巧克力都取走,否则将取出的巧克力放在桌上。设从口袋中取出每种颜色的巧克力的概率均等。求取出n个巧克力后桌面上剩余m个巧克力的概率。
思路:由于从口袋中取n个巧克力对应有种情况,这种情况会考虑到不同颜色的巧克力之间的排列关系,所以适用于指数型母函数来解决。 桌上剩余的m个巧克力颜色一定互不相同,所以此题可以转化为有m种巧克力取出了奇数次,c-m中巧克力取出了偶数次,与上一题类似,因此我们可以构造母函数:。最终的情况数就是G(x)中的系数,n!是指数型母函数所需要乘的,表示从c种颜色中取m种颜色取了奇数次。因此最终的结果就是。
参考代码:
#include <stdio.h>
double po[111], ne[111], pp[111], nn[111];
double powmod(double x, int n) {
double ret = 1;
while(n) {
if(n&1) ret *= x;
x *= x;
n /= 2;
}
return ret;
}
int c, n, m;
// 由于C(n, k)可能会很大,不能直接预处理出组合数
double cal(double ret, int n, int k) {
if(n-k < k) k = n-k;
for(int i = n;i > n-k; i--)
ret *= i;
for(int i = 1;i <= k; i++)
ret /= i;
return ret;
}
void solve() {
int i, j;
for(i = 0;i <= c; i++) {
po[i] = ne[i] = pp[i] = nn[i] = 0;
}
double chu = powmod(1.0/2, m);
for(i = 0;i <= m; i++) {
int now = i-m+i;
int flag = 1;
if((m-i)&1) flag = -1;
if(now >= 0) po[now] += cal(chu*flag, m, i); // 保存e^(kx)的系数
else ne[-now] += cal(chu*flag, m, i); // 保存e^(-kx)的系数
}
chu = powmod(1.0/2, c-m);
for(i = 0;i <= c-m; i++) {
double cur = cal(chu, c-m, i);
for(j = 0;j <= m; j++) {
int now = j+i-(c-m)+i;
if(now >= 0) pp[now] += po[j]*cur; // 直接合并系数
else nn[-now] += po[j]*cur;
}
for(j = 0;j <= m; j++) {
int now = -j + i-(c-m)+i;
if(now >= 0) pp[now] += ne[j]*cur;
else nn[-now] += ne[j]*cur;
}
}
double ans = 0;
for(i = 1;i <= c; i++) {
ans += cal( pp[i]*powmod((double)i/c, n), c, m);
}
for(i = 1;i <= c; i++) {
if(n&1) nn[i] = -nn[i];
ans += cal( nn[i]*powmod((double)i/c, n), c, m);
}
printf("%.3lf\n", ans);
}
int main() {
while(scanf("%d", &c) != -1 && c) {
scanf("%d%d", &n, &m);
if(m > n || m > c || (n-m)%2==1) {
puts("0.000"); continue;
}
// 尤其要注意n等于0 && m等于0 要特判
if(n == 0 && m == 0) {
puts("1.000"); continue;
}
solve();
}
return 0;
}