sol

  • 说在前头

考虑直接维护每个叶子作为当前子树的根的概率.

设当前叶子u作为子树根的值的概率为f(u)

发现a, b 2个子树合并时, 子树a的某一叶子u为根的概率:

f'(u) = \sum_{v\in \text{subtree}_b} [V_v < V_u] p_x\cdot f(v)f(u) +\sum_{v\in \text{subtree}_b} [V_v > V_u] (1-p_x)\cdot f(v)f(u)

那么这个转移就非常类似于启发式合并或线段树合并.

  • 但是启发式合并不能够直接维护, 每次合并需要遍历两个子树的叶子, 复杂度退化为O(n^2)

故, 考虑线段树合并.

如何合并?

个人感觉类似于cdq分治

当合并u, v时递归二者的左子树的时候

只需要计算rs(u)\to ls(v), \ rs(v) \to ls(u)即可.

a \to b” 表示a子树对b子树的影响!

举个例子, 对于rs(v)ls(u)做贡献时

由于rs(v)维护的每个叶子的权值都大于ls(u)

那么这个影响根据已有的式子:

\sum_{u\in \text{subtree}_{rs(v)}} f(u)\cdot (1-p_x)

我们要做的, 就是”区间乘”, 乘上这个式子.

但是注意, 这样不是完全的区间乘

而是对于所有上式(对叶子的贡献)先累加, 再相乘.(考虑开头给出的合并时计算u的概率的公式).

Code

调了有点久, 还是不够熟练

而且一定要注意, 由于线段树合并的节点是共用的(类似主席树).

那么对一个点做修改或各种操作一定要加新点!!

%:pragma GCC optimize(2)
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

const int N = 3e5 + 50;
const int mod = 998244353;
# define pb push_back

int ans, n, x, a[N], d[N], vis[N];
int Ecnt, first[N], nex[N * 2], arr[N * 2];
vector<int> vec;

namespace {
    inline void read(int &x) {
        x = 0; char c = getchar();
        for(; !isdigit(c); c = getchar());
        for(;  isdigit(c); c = getchar())
            x = x * 10 + c - '0';
    }
    inline int mul(int a, int b) {
        return (LL) a * b % mod;
    }
    inline int add(int a, int b) {
        return a + b < mod ? a + b : a + b - mod;
    }
}

inline void Add(int u, int v) {
    nex[++Ecnt] = first[u], first[u] = Ecnt, arr[Ecnt] = v;
}
namespace Seg {
    const int N = :: N * 50;

    int sz, tot, px, rt[N], ls[N], rs[N], mt[N], sum[N];

    void init() { sum[0] = 1; }
// 设为 sum[0] = 1 是方便sum[rt] = mul(sum[ls[rt]], sum[rs[rt]]);不为0
    inline int new_node() {
        return mt[++tot] = 1, tot;
    }
    inline void pushdown(int u) {
        if(mt[u] == 1) return ;
        int x = mt[u], l = ls[u], r = rs[u];
        sum[l] = mul(sum[l], x), mt[l] = mul(mt[l], x);
        sum[r] = mul(sum[r], x), mt[r] = mul(mt[r], x);
        mt[u] = 1;
    }
    void modify(int &rt, int l, int r, int pos) {
        if(!rt) rt = new_node();
        if(l == r)
            return (void) (sum[rt] = 1);
        int mid = l + r >> 1;
        if(pos <= mid) modify(ls[rt], l, mid, pos);
        else modify(rs[rt], mid + 1, r, pos);
        sum[rt] = mul(sum[ls[rt]], sum[rs[rt]]);
    }
    int merge(int u, int v, int lp = 0, int rp = 0) {
        if(!u && !v) return 0;
        int rt = new_node();
        if(!v || !u) {
            sum[rt] = mul(sum[u | v], !v ? lp : rp);
            mt[rt] = mul(mt[u | v], !v ? lp : rp);
            ls[rt] = ls[u | v]; // 注意 ls, rs也要复制
            rs[rt] = rs[u | v];
            return rt;
        }
        pushdown(u), pushdown(v);
        ls[rt] = merge(ls[u], ls[v], add(lp, mul(sum[rs[v]], mod + 1 - px)), add(rp, mul(sum[rs[u]], mod + 1 - px)));
        rs[rt] = merge(rs[u], rs[v], add(lp, mul(sum[ls[v]], px)), add(rp, mul(sum[ls[u]], px)));
        sum[rt] = add(sum[ls[rt]], sum[rs[rt]]);
        return rt;
    }
}
using namespace Seg;

inline int sqr(int x) {
    return mul(x, x);
}
void dfs(int u, int fa = 0) {
    for(int i = first[u], v; i; i = nex[i]) {
        v = arr[i];
        if(v == fa) continue;
        dfs(v, u);
        if(!rt[u]) rt[u] = rt[v];
        else {
            px = a[u];
            rt[u] = merge(rt[u], rt[v]);
        }
    }
}
void Dfs(int u, int l = 1, int r = sz) {
    if(!u) return ;
    if(l == r) {
        int res = mul(l, mul(sqr(sum[u]), vec[l]));
        ans = add(ans, res);
        return ;
    }
    int mid = l + r >> 1;
    pushdown(u);
    Dfs(ls[u], l, mid);
    Dfs(rs[u], mid + 1, r);
}

int main() {
    read(n);
    for(int i = 1; i <= n; ++i) {
        read(x);
        ++d[i], ++d[x];
        if(x) Add(i, x), Add(x, i);
    }
    for(int i = 2; i <= n; ++i)
        if(d[i] == 1) vis[i] = 1;
    for(int i = 1; i <= n; ++i)
        read(a[i]);
    vec.pb(-1);
    // 离散化
    for(int i = 1; i <= n; ++i) {
        if(vis[i]) vec.pb(a[i]);
        else a[i] = mul(a[i], 796898467);
    }
    sort(vec.begin(), vec.end());
    vec.erase(unique(vec.begin(), vec.end()), vec.end());
    sz = vec.size() - 1;
    init();
    // 线段树合并
    for(int i = 1; i <= n; ++i) {
        if(!vis[i]) continue;
        int loc = lower_bound(vec.begin(), vec.end(), a[i]) - vec.begin();
        modify(rt[i], 1, sz, loc);
    }
    sum[0] = 0;
// 变成0表示sum[0] = 0, 即该点代表区间的概率值之和为0
    dfs(1);
    Dfs(rt[1]);
    printf("%d\n", ans);
    return 0;
}