先ORZ(Owen)一发。感觉是个很套路的题,这里给一个蒟蒻的需要特判数据的伪(nlog^2 n)算法,真正的两只(log)的还是去看标算吧(但这个好想好写跑不满啊)
首先这种树上路径统计的问题我们先套一个点分治上去就是了,然后求出分治中心连出去的每一条路径
我们还是套路地分成不同子树求解,如果我们对一条路径设一个二元组((dep,val))表示到分治中心的深度以及路径上的最大值,那么就考虑怎么把一棵子树内的路径和其它子树的合并了
考虑如果只有一个端点限制怎么做,很容易想到它们的和取值都满足单调性
那么我们直接把两边都按(dep)排序,然后直接( ext{two points})扫出一个端点的范围即可
接下来考虑两个端点怎么做,我们发现此时可以选的点形成了一个区间,既然这个区间两个端点都满足单调性,所以这个区间移动也是单调的
那么就很简单了,我们扫出每个点的合法区间之后考虑怎么统计答案
首先所有比它小的数最后的贡献都是它的权值,然后所有大于它的数的贡献即是它们本身
我们直接拿两个树状数组来维护,一个记录比它小的数的个数,另一个记录比它大的数的权值和,直接跟着指针一起添删点即可
很容易发现,这样的复杂度大部分消耗在排序上,如果每次暴力排序的话复杂度最坏会被卡到(n^2log^2 n),但是很多时候一个点的度数没有那么多,所以跑起来非常快,大致比两个(log)多一些常数吧
所以我们写完之后特判一下菊花图即可,这个更加简单,因为路径长度只有(1/2),暴力排序后枚举一下即可
CODE
#include<cstdio>
#include<cctype>
#include<algorithm>
#define RI int
#define CI const int&
#define Tp template <typename T>
using namespace std;
typedef long long LL;
const int N=100005,INF=1e9;
struct edge
{
int to,nxt,v;
}e[N<<1]; int n,head[N],rst[N],num,cnt,L,R,x,y,z; LL ans; bool flag;
class FileInputOutput
{
private:
static const int S=1<<21;
#define tc() (A==B&&(B=(A=Fin)+fread(Fin,1,S,stdin),A==B)?EOF:*A++)
char Fin[S],*A,*B;
public:
Tp inline void read(T& x)
{
x=0; char ch; while (!isdigit(ch=tc()));
while (x=(x<<3)+(x<<1)+(ch&15),isdigit(ch=tc()));
}
#undef tc
}F;
inline int min(CI a,CI b)
{
return a<b?a:b;
}
inline void addedge(CI x,CI y,CI z)
{
e[++cnt]=(edge){y,head[x],z}; head[x]=cnt;
e[++cnt]=(edge){x,head[y],z}; head[y]=cnt;
}
namespace Case1 //Point Division Solver
{
class Tree_Array
{
private:
int size[N]; LL sum[N]; bool vis[N];
public:
#define lowbit(x) x&-x
inline void add(CI x,CI y,CI opt,RI i=0)
{
for (i=x;i<=num;i+=lowbit(i)) size[i]+=opt;
for (i=x;i;i-=lowbit(i)) sum[i]+=opt*y;
}
inline int query_rk(RI x,int ret=0)
{
for (;x;x-=lowbit(x)) ret+=size[x]; return ret;
}
inline LL query_sum(RI x,LL ret=0)
{
for (;x<=num;x+=lowbit(x)) ret+=sum[x]; return ret;
}
#undef lowbit
}BIT;
#define to e[i].to
class Point_Division_Solver
{
private:
struct data
{
int dep,mxv;
friend inline bool operator < (const data& A,const data& B)
{
return A.dep<B.dep;
}
}s[N]; int size[N],tot,pl[N],pr[N]; bool vis[N];
inline int max(CI a,CI b)
{
return a>b?a:b;
}
inline void maxer(int& x,CI y)
{
if (y>x) x=y;
}
inline void DFS(CI now,CI fa,CI d,CI mx)
{
s[++tot]=(data){d,mx}; if (L<=d&&d<=R) ans+=rst[mx];
for (RI i=head[now];i;i=e[i].nxt)
if (to!=fa&&!vis[to]) DFS(to,now,d+1,max(mx,e[i].v));
}
public:
int mx[N],ots,rt;
inline void getrt(CI now,CI fa=1)
{
size[now]=1; mx[now]=0;
for (RI i=head[now];i;i=e[i].nxt) if (to!=fa&&!vis[to])
getrt(to,now),size[now]+=size[to],maxer(mx[now],size[to]);
if (maxer(mx[now],ots-size[now]),mx[now]<mx[rt]) rt=now;
}
inline void solve(CI now)
{
RI i,j; vis[now]=1; tot=0; int lst=0;
for (i=head[now];i;i=e[i].nxt) if (!vis[to])
{
DFS(to,now,1,e[i].v); sort(s+lst+2,s+tot+1); if (lst)
{
int p1=lst; for (j=lst+1;j<=tot;++j)
{
while (p1&&s[p1].dep+s[j].dep>R) --p1; pr[j]=p1;
}
int p2=1; for (j=tot;j>lst;--j)
{
while (p2<=lst&&s[p2].dep+s[j].dep<L) ++p2; pl[j]=p2;
}
p1=lst+1; p2=lst; for (j=lst+1;j<=tot;++j)
{
while (p1>pl[j]) --p1,BIT.add(s[p1].mxv,rst[s[p1].mxv],1);
while (p2>pr[j]) BIT.add(s[p2].mxv,rst[s[p2].mxv],-1),--p2;
ans+=1LL*BIT.query_rk(s[j].mxv)*rst[s[j].mxv]+BIT.query_sum(s[j].mxv+1);
}
for (j=p1;j<=p2;++j) BIT.add(s[j].mxv,rst[s[j].mxv],-1);
}
lst=tot; sort(s+1,s+lst+1);
}
for (i=head[now];i;i=e[i].nxt) if (!vis[to])
mx[rt=0]=INF,ots=size[to],getrt(to,now),solve(rt);
}
}PD;
#undef to
inline void solve(void)
{
sort(rst+1,rst+n); num=unique(rst+1,rst+n)-rst-1;
for (RI i=1;i<=cnt;++i) e[i].v=lower_bound(rst+1,rst+num+1,e[i].v)-rst;
PD.mx[PD.rt=0]=INF; PD.ots=n; PD.getrt(1); PD.solve(PD.rt);
}
};
namespace Case2 //Flower Graph Solver
{
inline void solve(void)
{
RI i; sort(rst+1,rst+n); if (L<=1&&1<=R)
for (i=1;i<n;++i) ans+=rst[i]; if (L<=2&&2<=R)
for (i=1;i<n;++i) ans+=1LL*rst[i]*(i-1);
}
};
int main()
{
//freopen("CODE.in","r",stdin); freopen("CODE.out","w",stdout);
RI i; for (F.read(n),F.read(L),F.read(R),flag=i=1;i<n;++i)
{
F.read(x); F.read(y); F.read(z);
addedge(x,y,z); rst[i]=z; if (min(x,y)!=1) flag=0;
}
if (flag) Case2::solve(); else Case1::solve();
return printf("%lld",ans<<1LL),0;
}