Link

Sol

yzhx点分治讲课的一道简单练习题

这道题就是要求出树上有多少条路径的长度是3的倍数

引理:一个有$n$个点的树一共有$n×(n-1)$条路径

注意到每次选的两个点可以是同一个点,所以长度$3$的$0$倍也算作合法的路径

那么一共就有$n*n=n^2$条路径

点分治统计答案的时候可以直接把路径长模3,把余数分别为$1,2,0$的路径条数分别记作$sum[0],sum[1],sum[2]$

  1. 在同一棵子树内余数为$0$的路径显然可以直接两两合并成一条新的余数仍未的$0$路径,所以$sum[0]$对答案的实际贡献是$sum[0] * sum[0]$
  2. 长度对3取余为$1$和$2$的两条路径可以拼在一起组成一条新的长度为$3$的倍数的路径,那么无论是先选余数为$2$的还是为$1$的都合法,这两种路径之间可以互相两两选择,所以它们对答案的贡献是:$sum[1] * sum[2] * 2$

然后就做完了

Code

#include <bits/stdc++.h>
using namespace std;

const int SIZE = 2e4 + 5;

int n, num, tsiz, root, ans;
int head[SIZE], son[SIZE], siz[SIZE], vis[SIZE], sum[SIZE];

struct node {
    int to, v, nxt;
} edge[SIZE << 1];

inline int read()
{
    char ch = getchar();
    int f = 1, x = 0;
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + ch - '0'; ch = getchar(); }
    return x * f;
}

inline void addEdge(int u, int v, int d)
{
    edge[++ num].to = v, edge[num].v = d, edge[num].nxt = head[u];
    head[u] = num;
}

inline int gcd(int a, int b) 
{
    if (!b) return a;
    return gcd(b, a % b);
}

inline void findRoot(int u, int fa)
{
    son[u] = 0, siz[u] = 1;
    for (int i = head[u]; i; i = edge[i].nxt) {
        int v = edge[i].to;
        if (v == fa || vis[v]) continue;
        findRoot(v, u);
        siz[u] += siz[v], son[u] = std::max(son[u], siz[v]);
    }
    son[u] = std::max(son[u], tsiz - siz[u]);
    if (son[root] > son[u]) root = u;
}

inline void dfs(int u, int fa, int dis)
{
    ++ sum[(dis % 3)];
    for (register int i = head[u]; i; i = edge[i].nxt) {
        int v = edge[i].to;
        if (vis[v] || v == fa) continue;
        dfs(v, u, dis + edge[i].v);
    }
}

inline int calc(int u, int val)
{
    sum[0] = sum[1] = sum[2] = 0;
    dfs(u, 0, val);
    return sum[1] * sum[2] * 2 + sum[0] * sum[0];
}

inline void merge(int u)
{
    ans += calc(u, 0);
    vis[u] = 1;
    for (register int i = head[u], v; i; i = edge[i].nxt) {
        v = edge[i].to;
        if (vis[v]) continue;
        ans -= calc(v, edge[i].v);
        root = 0, son[0] = n, tsiz = siz[v];
        findRoot(v, u);
        merge(root);
    }
}

int main()
{
    // freopen("std.in", "r", stdin);
    // freopen("my.out", "w", stdout);
    n = read();
    for (int i = 1, u, v, d; i < n; ++ i) {
        u = read(), v = read(), d = read();
        addEdge(u, v, d); addEdge(v, u, d);
    }
    son[0] = tsiz = n, root = 0;
    findRoot(1, 0);
    merge(root);
    int com = gcd(n * n, ans);
    // printf("com: %d ans:%d", com, ans);
    printf("%d/%d\n", ans / com, n * n / com);
    return 0;
}
最后修改:2021 年 09 月 04 日
如果觉得我的文章对你有用,请随意赞赏