Link

中文题面链接

Sol

先来考虑一个简单版本的问题:我们只有一个北国银行。答案是:

$$ \sum\limits_{(x,fa)\in E} w \cdot min(size_{x},S - size_{x}) $$

证明:

不妨设树根为$1$。在上面的表达式中,$size_{x}$ 表示的是在 $x$ 子树中居民的数量,$S$表示树中居民的总数量。

考虑如何计算所有居民接收到一枚钟离盾的总时间:对于每一条边,它的贡献为 $w$ $*$ 通过这条边的居民的数量的总和,即对于每一条边 $(x, fa, w)$,我们有两种选择:一是选择子树内的所有居民通过该边(此时北国银行在子树外),故此部分贡献为 $w \cdot size_{x}$;二是选择 $x$ 的子树外的所有点上的居民通过该边进入 $x$ 的子树领取钟离盾(此时北国银行在子树内)。两种方式取 $min$ 即可。

考虑最终北国银行放置的位置,会有一条没有居民过的边。那么可以枚举该边并将此边断开,使原树分成两个树,再通过上面的方法对两棵树的答案独立计算。时间复杂度为 $O(n^2)$。

定义 $in_v$ 为 DFS 时进入$v$ 点的时间戳,$out_v$ 为 DFS 时离开 $v$ 点的时间戳。
假设只放置一个北国银行,我们就应该找到对于所有边 $e(v, to, w)$ :$\sum\limits_{e\in E} w \cdot min(size_{to}, S - size_{to})$ 的总和。考虑 $size_{to}$ 和 $S - size_{to}$ 什么时候会算到贡献里:$w \cdot min(size_{to}, S - size_{to})$ 为 $w \cdot size_{to}$ 当且仅当 $size_{to} \leq \frac{S}{2}$, 为 $w \cdot (S - size_{to})$ 当且仅当 $size_{to} > \frac{S}{2}$。

回到 $O(n^2)$ 的解法。我们在 DFS 时在树中移去边 $(v, to, w)$ 。现在有两棵子树,大小分别为:$X = size_{to}, Y = size_{1} - size_{to}$。 考虑分别在 $X, Y$ 两棵树上解决问题。

第一部分先计算 $to$ 子树中的点的答案: $w \cdot min(size_{to}, X - size_{to})$ 的值的和。我们可以计算出 $w \cdot size_{to} (\forall size_{to} \leq \frac{X}{2}) + w \cdot X (\forall size_{to} > \frac{X}{2}) - w \cdot size_{to} (\forall size_{to} > \frac{X}{2})$。

剩余的部分就是要求出区间 $[l,r]$ 内所有数字 $\leq K$ 的和。可以把区间 $[in_{to}, out_{to}]$ 内的点的 $w$ 和 $w \cdot size_{to}$ 丢进两个树状数组,分别统计两类的答案即可。

$to$ 的子树以外的部分做法也很类似。注意要将这些操作在区间 $[1, in_{to} - 1]$ 和 $[out_{to} + 1, n]$ 和子树 $Y$ 上完成。除了从根节点开始的链到 $v$ ( 包括 $to$ 的子树中的节点)。在这条链上 $size_u$ 单调递减,所以用双指针并且记录前缀和。可以减去我们要为 $\leq \frac{Y}{2}$ 计算的部分。同样对 $size_{1} - size_{u}$ 做这样的步骤。最后加上我们需要的 $> \frac{Y}{2}$ 的部分,这部分在这条链上递增。

时间复杂度和空间复杂度均为 $O(n \log n)$ 。

Code

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
const int SIZE = 5e4 + 5, SIZE_N = 2.5e5 + 5;
const LL inf = 9e18;

int n, rt, num, tim, tot;
int head[SIZE], a[SIZE], dep[SIZE], f[SIZE][18], dfn[SIZE], id[SIZE], lg[SIZE], fx[SIZE], s[SIZE], b[SIZE];
LL S, siz[SIZE], y[SIZE], z[SIZE], ans[SIZE], c[SIZE], C[SIZE];

struct node {
    int to, v, nxt;
} edge[SIZE << 1];

struct ask {
    int l, r, d; LL k, s;
    inline bool operator< (const ask &a) const {
        return k < a.k;
    }
} q[SIZE * 3];

void addEdge(int u, int v, int d) {
    edge[++ num] = (node) {v, d, head[u]}, head[u] = num;
}

namespace GTR {
    const int bufl = 1 << 15;
    char buf[bufl], *s = buf, *t = buf;
    inline int fetch() {
        if (s == t) { t = (s = buf) + fread(buf, 1, bufl, stdin); if (s == t) return EOF; }
        return *s++;
    }
    inline int read() {
        int a = 0, b = 1, c = fetch();
        while (c < 48 || c > 57) b ^= c == '-', c = fetch();
        while (c >= 48 && c <= 57) a = (a << 1) + (a << 3) + c - 48, c = fetch();
        return b ? a : -a;
    }
} using GTR::read;

namespace fenwick {
#define lowbit(x) (x & (-x))
    void modify(LL *bit, int pos, LL val) {
        for ( ; pos <= n; pos += lowbit(pos)) bit[pos] += val;
    }
    LL query(LL *bit, int pos) {
        LL ans = 0;
        for ( ; pos; pos -= lowbit(pos)) ans += bit[pos];
        return ans;
    }
    LL query(LL *bit, int l, int r) {
        return l > r ? 0ll : query(bit, r) - query(bit, l - 1);
    }
} ;

void dfs(int u, int fa) {
    siz[u] = a[u], s[u] = 1, dfn[u] = ++ tim;
    for (int i = head[u], v; i; i = edge[i].nxt) {
        v = edge[i].to;
        if (v == fa) continue;
        dep[v] = dep[u] + 1, f[v][0] = u, b[v] = edge[i].v;
        for (int j = 1; j < 18; ++ j) f[v][j] = f[f[v][j - 1]][j - 1];
        dfs(v, u);
        siz[u] += siz[v], s[u] += s[v];
    }
}

void dfs1(int u, int fa) {
    for (int i = head[u], v; i; i = edge[i].nxt) {
        v = edge[i].to;
        if (v == fa) continue;
        y[v] = y[u] + b[v], z[v] = z[u] + 1ll * b[v] * siz[v];
        dfs1(v, u); id[v] = ++ tot;
        q[tot] = (ask) {dfn[v] + 1, dfn[v] + s[v] - 1, tot, siz[v], 0ll};
        q[++ tot] = (ask) {1, dfn[v] - 1, tot, S - siz[v], 0ll};
        q[++ tot] = (ask) {dfn[v] + s[v], n, tot, S - siz[v], 0ll};
        int x = v; LL t = (S - siz[v]) >> 1;
        for (int j = lg[dep[v] - 1]; ~j; -- j) {
            if (f[x][j] && siz[f[x][j]] <= t) x = f[x][j];
        }
        ans[v] = (z[f[x][0]] << 1) - z[u] - y[f[x][0]] * (S - siz[v]);
        x = v; t = ((S - siz[v]) >> 1) + siz[v];
        for (int j = lg[dep[v] - 1]; ~j; -- j) {
            if (f[x][j] && siz[f[x][j]] <= t) x = f[x][j];
        }
        int fa = f[x][0];
        ans[v] += (z[u] - y[u] * siz[v]) - ((z[fa] - y[fa] * siz[v]) << 1) + y[fa] * (S - siz[v]);
    } 
}

int main() {
    int ti = read();
    while (ti --) {
        n = read(), rt = (n + 1) >> 1, S = num = tim = tot = 0, dep[rt] = 1;
        for (int i = 1; i <= n; ++ i) {
            a[i] = read(), fx[i] = i, S += a[i];
            head[i] = b[i] = y[i] = z[i] = c[i] = C[i] = 0;
        }
        for (int i = 2, u, v, d; i <= n; ++ i) {
            u = read(), v = read(), d = read();
            addEdge(u, v, d), addEdge(v, u, d);
            lg[i] = lg[i >> 1] + 1;
        }
        dfs(rt, 0), dfs1(rt, 0);
        std::sort(q + 1, q + tot + 1);
        std::sort(fx + 1, fx + n + 1, [] (int a, int b) { return siz[a] < siz[b]; });
        int t = 1;
        for (int i = 1; i <= tot; ++ i) {
            while (t <= n && siz[fx[t]] <= (q[i].k >> 1)) {
                fenwick::modify(c, dfn[fx[t]], 1ll * b[fx[t]] * siz[fx[t]]);
                ++ t;
            }
            q[i].s = fenwick::query(c, q[i].l, q[i].r);
        }
        t = n;
        for (int i = 1; i <= n; ++ i) c[i] = 0;
        for (int i = tot; i; -- i) {
            while (t && siz[fx[t]] > (q[i].k >> 1)) {
                fenwick::modify(c, dfn[fx[t]], 1ll * b[fx[t]] * siz[fx[t]]);
                fenwick::modify(C, dfn[fx[t]], 1ll * b[fx[t]]);
                -- t;
            }
            q[i].s += fenwick::query(C, q[i].l, q[i].r) * q[i].k - fenwick::query(c, q[i].l, q[i].r);
        }
        std::sort(q + 1, q + tot + 1, [] (ask a, ask b) { return a.d < b.d; });
        LL res = inf;
        for (int i = 1; i <= n; ++ i) {
            ans[i] += q[id[i]].s + q[id[i] + 1].s + q[id[i] + 2].s;
            if (i != rt && ans[i] < res) res = ans[i];
        }
        printf("%lld\n", res);
        for (int i = 1; i <= n; ++ i)
            for (int j = 0; 1 << j < dep[i]; ++ j) f[i][j] = 0;
    }
    return 0;
}
最后修改:2021 年 09 月 16 日
如果觉得我的文章对你有用,请随意赞赏