Sol
yzhx点分治讲课的一道简单练习题
这道题就是要求出树上有多少条路径的长度是3的倍数
引理:一个有$n$个点的树一共有$n×(n-1)$条路径
注意到每次选的两个点可以是同一个点,所以长度$3$的$0$倍也算作合法的路径
那么一共就有$n*n=n^2$条路径
点分治统计答案的时候可以直接把路径长模3,把余数分别为$1,2,0$的路径条数分别记作$sum[0],sum[1],sum[2]$
- 在同一棵子树内余数为$0$的路径显然可以直接两两合并成一条新的余数仍未的$0$路径,所以$sum[0]$对答案的实际贡献是$sum[0] * sum[0]$
- 长度对3取余为$1$和$2$的两条路径可以拼在一起组成一条新的长度为$3$的倍数的路径,那么无论是先选余数为$2$的还是为$1$的都合法,这两种路径之间可以互相两两选择,所以它们对答案的贡献是:$sum[1] * sum[2] * 2$
然后就做完了
Code
#include <bits/stdc++.h>
using namespace std;
const int SIZE = 2e4 + 5;
int n, num, tsiz, root, ans;
int head[SIZE], son[SIZE], siz[SIZE], vis[SIZE], sum[SIZE];
struct node {
int to, v, nxt;
} edge[SIZE << 1];
inline int read()
{
char ch = getchar();
int f = 1, x = 0;
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + ch - '0'; ch = getchar(); }
return x * f;
}
inline void addEdge(int u, int v, int d)
{
edge[++ num].to = v, edge[num].v = d, edge[num].nxt = head[u];
head[u] = num;
}
inline int gcd(int a, int b)
{
if (!b) return a;
return gcd(b, a % b);
}
inline void findRoot(int u, int fa)
{
son[u] = 0, siz[u] = 1;
for (int i = head[u]; i; i = edge[i].nxt) {
int v = edge[i].to;
if (v == fa || vis[v]) continue;
findRoot(v, u);
siz[u] += siz[v], son[u] = std::max(son[u], siz[v]);
}
son[u] = std::max(son[u], tsiz - siz[u]);
if (son[root] > son[u]) root = u;
}
inline void dfs(int u, int fa, int dis)
{
++ sum[(dis % 3)];
for (register int i = head[u]; i; i = edge[i].nxt) {
int v = edge[i].to;
if (vis[v] || v == fa) continue;
dfs(v, u, dis + edge[i].v);
}
}
inline int calc(int u, int val)
{
sum[0] = sum[1] = sum[2] = 0;
dfs(u, 0, val);
return sum[1] * sum[2] * 2 + sum[0] * sum[0];
}
inline void merge(int u)
{
ans += calc(u, 0);
vis[u] = 1;
for (register int i = head[u], v; i; i = edge[i].nxt) {
v = edge[i].to;
if (vis[v]) continue;
ans -= calc(v, edge[i].v);
root = 0, son[0] = n, tsiz = siz[v];
findRoot(v, u);
merge(root);
}
}
int main()
{
// freopen("std.in", "r", stdin);
// freopen("my.out", "w", stdout);
n = read();
for (int i = 1, u, v, d; i < n; ++ i) {
u = read(), v = read(), d = read();
addEdge(u, v, d); addEdge(v, u, d);
}
son[0] = tsiz = n, root = 0;
findRoot(1, 0);
merge(root);
int com = gcd(n * n, ans);
// printf("com: %d ans:%d", com, ans);
printf("%d/%d\n", ans / com, n * n / com);
return 0;
}