题意
Description
再过三个多月就是圣诞节了,小R 想送小Y 一棵圣诞树作为节日礼物。因为他想让这棵圣诞树越大越好,所以当然是买不到能够让他满意的树的,因此他打算自己把这棵树拼出来。
现在,小R 开始画这棵树的设计图纸了。因为这棵树实在太大,所以他采用了一种比较方便的方法。首先他定义了m+ 1 棵树T0 到Tm。最开始他只画好了T0 的图纸:就只有一个点,编号为0。
接着,对于每一棵树Ti,他在第Tai 棵树的第ci 个点和第Tbi 棵树的第di 个点之间连上了一条长度为li 的边。在Ti 中,他保持Tai 中的所有节点编号不变,然后如果Tai 中有s 个节点,他会把Tbi 中的所有节点的编号加上s。
终于,他画好了所有的树。现在他定义一颗大小为n 的树的美观度为树中任意两个点对的距离和,其中d(i; j) 为这棵树中i 到j 的最短距离。
为了方便小R 选择等究竟拼哪一棵树,你可以分别告诉他T1 到Tm 的美观度吗?答案可能很大,请对10^9 + 7 取模后输出。
Input
第一行输入一个正整数T 表示数据组数。每组数据的第一行是一个整数m,接下来m 行每行五个整数ai, bi, ci, di, li,保证0 <= ai, bi < i, 0<= li<= 10^9,ci, di 存在。
Output
对于每组询问输出m 行。第i 行输出Ti 的权值
Sample Input
1
2
0 0 0 0 2
1 1 0 0 4
Sample Output
2
28
Data Constraint
对于30% 的数据,m <= 8
对于60% 的数据,m <= 16
对于100% 的数据,1 <= m<= 60,T<= 100
Solution
这题首先你得把题读懂。。。
读懂以后会发现树(T_i) 是(T_{ai})的第(ci)个点与(T_{bi})的第(di)个点连一条长度为(li)的边形成的树。树的结点个数是指数级增长的,所以暴力肯定不行。。。
设(f_i)表示(T_i)的答案,(getdis(u,x,y))表示(T_u)中第(x)个结点与第(y)个结点的距离,(getall(u,x))表示(T_u)中第(x)个结点到其他所有点的距离之和,(size_u)表示(T_u)结点个数,容易得到下面的式子:
- (size_0=1,size_i=size_{ai}+size_{bi} (i > 0))
- (f_i=f_{ai}+f_{bi}+getall(a_i,c_i)*size_{bi}+getall(b_i,d_i)*size_{ai}+size_{ai}*size_{bi}*l_i)
还有(getall)和(getdis),这两个函数也可以通过分类讨论递归定义,按照递归的定义搜索求出函数值,就能get到不错的成绩。
正解其实就是加上个记忆化,把搜过的(getdis)和(getall)记录下来,下一次要调用时直接使用,时间复杂度(O()玄())?
记忆化可以用map或者哈希,这里我图方便用map了,其实hash应该更快一点。
由于m范围很小,所以就能过了。
Code
#include <map>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;
const int N = 67, mo = 1e9 + 7;
const ll P = 1e9 + 7;
int T, n;
ll a[N], b[N], c[N], d[N], l[N], f[N], size[N], sz[N];
struct note1 { int a, b, c; };
int operator<(note1 x, note1 y) { return x.a == y.a ? (x.b == y.b ? x.c < y.c : x.b < y.b) : x.a < y.a; }
struct note2 { int a, b; };
int operator<(note2 x, note2 y) { return x.a == y.a ? x.b < y.b : x.a < y.a; }
map<note1, ll> h1;
map<note2, ll> h2;
ll getdis(ll x, ll u, ll v)
{
if (!x || u == v) return 0;
ll ret = h1[(note1){x, u, v}];
if (ret > 0) return ret;
if (u < size[a[x]] && v < size[a[x]]) ret = getdis(a[x], u, v);
else if (u >= size[a[x]] && v >= size[a[x]]) ret = getdis(b[x], u - size[a[x]], v - size[a[x]]);
else if (u < size[a[x]] && v >= size[a[x]]) ret = (getdis(a[x], u, c[x]) + getdis(b[x], v - size[a[x]], d[x]) + l[x]) % P;
else ret = (getdis(b[x], u - size[a[x]], d[x]) + getdis(a[x], v, c[x]) + l[x]) % P;
h1[(note1){x, u, v}] = ret;
return ret;
}
ll getall(ll x, ll u)
{
if (!x) return 0;
ll ret = h2[(note2){x, u}];
if (ret > 0) return ret;
if (u < size[a[x]]) ret = (getall(a[x], u) + (getdis(a[x], u, c[x]) + l[x]) * sz[b[x]] % P + getall(b[x], d[x])) % P;
else ret = (getall(b[x], u - size[a[x]]) + (getdis(b[x], u - size[a[x]], d[x]) + l[x]) * sz[a[x]] % P + getall(a[x], c[x])) % P;
h2[(note2){x, u}] = ret;
return ret;
}
int main()
{
scanf("%d", &T);
while (T--)
{
h1.clear(), h2.clear();
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%lld%lld%lld%lld%lld", a + i, b + i, c + i, d + i, l + i);
size[0] = sz[0] = 1;
for (int i = 1; i <= n; i++) size[i] = size[a[i]] + size[b[i]], sz[i] = size[i] % P;
for (int i = 1; i <= n; i++) f[i] = (f[a[i]] + f[b[i]] + getall(a[i], c[i]) * sz[b[i]] % P + getall(b[i], d[i]) * sz[a[i]] % P + sz[a[i]] * sz[b[i]] % P * l[i] % P) % P, printf("%lld
", f[i]);
}
return 0;
}