AVL Trees
题目连接:
(http://codeforces.com/gym/100341)
题意
avl树是每棵子树的左右子树高度之差小于等于1,给你节点个数和树高,问有多少种树
题解:
很轻松地我们写出dp:
dp[h][n]表示树高h+1,n个节点的答案
(dp[h][n]=sum_{i=0}^{n-1}dp[h-1][i]*(dp[h-1][n-1-i]+2*dp[h-2][n-1-i]))
一眼看出这是n次ntt,最暴力3h次变换,优化一下2h次,前h/2由于有效的dp值很小,可以暴力算,那么就是h次;
假设我们学过信号与系统,我们知道(DP[h]=DP[h-1] imes(DP[h-1]+2*DP[h-2]) imesdelta[1]),那么我们只要3次变换就可以完成;
我们观察dp[0],只有dp[0][0]=1,那么(dp[0]=delta[0],DP[0]={1,1,1,1,..})
同理只有dp[1][1]=1,那么(dp[1]=delta[1],DP[1]={{g}^{0},{g}^{1},{g}^{2},..})
前两次变换就可以直接赋值,我们观察最后一次反变换,由于我们只要求x[n]的值,我们不需要把整个序列变换回来
根据变换式$$x[n]=frac{1}{N}sum_{k=0}{N-1}X[k]*g{-kn}$$
我们可以O(n)求出x[n]
那么这道题我们可以用O(hn)的算法完美解决,不用ntt变换
代码
//#include <bits/stdc++.h>
#include <stdio.h>
#include <iostream>
#include <string.h>
#include <math.h>
#include <stdlib.h>
#include <limits.h>
#include <algorithm>
#include <queue>
#include <vector>
#include <set>
#include <map>
#include <stack>
#include <bitset>
#include <string>
#include <time.h>
using namespace std;
long double esp=1e-11;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#define fi first
#define se second
#define all(a) (a).begin(),(a).end()
#define cle(a) while(!a.empty())a.pop()
#define mem(p,c) memset(p,c,sizeof(p))
#define mp(A, B) make_pair(A, B)
#define pb push_back
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
typedef long long int LL;
const long double PI = acos((long double)-1);
const LL INF=0x3f3f3f3fll;
const int MOD =1000000007ll;
const int maxn=100100;
const int NUM = 1<<17;
int wn[NUM];
// p | deg | g 长度为2^k,且N|(p-1),p=c*(1<<k)+1,g为原根,g^phi(p)=1 %p的最小g
// 469762049 26 3
// 998244353 23 3
// 1004535809 21 3
// 1107296257 24 10
// 10000093151233 26 5
// 1000000523862017 26 3
// 1000000000949747713 26 2
LL mu(LL a,LL b,LL P)
{
LL ans=1;
while(b)
{
if(b&1)
ans=ans*a%P;
a=a*a%P;
b>>=1;
}
return ans;
}
void GetWn(int G,int P,int len)
{
wn[0] = 1, wn[1] = mu(G, (P - 1) / len, P);
for(int i = 2; i < len; i++)
wn[i] = 1LL * wn[i - 1] * wn[1] % P;
}
int dp[17][1<<17];
int main()
{
//freopen("in.txt", "r", stdin);
freopen("avl.in", "r", stdin);
freopen("avl.out", "w", stdout);
//::iterator iter; %I64d
//for(int x=1;x<=n;x++)
//for(int y=1;y<=n;y++)
//scanf("%d",&a);
//printf("%d
",ans);
int n,h;
scanf("%d%d",&n,&h);
if(n>=1<<(h+1))
{
printf("%d
",0);
return 0;
}
LL N=1ll<<(h+1),g=10,P=786433ll;
GetWn(g,P,N);
for(int x=0;x<N;x++)
dp[0][x]=1,dp[1][x]=wn[x];
for(int x=2;x<=h+1;x++)
for(int y=0;y<N;y++)
dp[x][y]=1ll*dp[x-1][y]*(dp[x-1][y]+2ll*dp[x-2][y])%P*wn[y]%P;
int ans=0,w=mu(wn[n],P-2,P);
int t=1;
for(int x=0;x<N;x++)
ans=(ans+1ll*t*dp[h+1][x])%P,t=1ll*t*w%P;
ans=1ll*ans*mu(N,P-2,P)%P;
printf("%d
",ans);
return 0;
}