This documentation is automatically generated by online-judge-tools/verification-helper
$N$ 頂点からなる木について $\sum_{i = 1}^{N - 1} \sum_{j = i + 1}^{N} \text{dist}(i, j)$ を求める方法。
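木dpを考える。頂点 $v$ を根とした部分木について以下の三つの情報を持つ。
- $A_{v}$ : 部分木内の頂点対 $(i, j)$ についての $\text{dist}(i, j)$ の総和
- $B_{v}$ : 部分木内の各頂点 $x$ についての $\text{dist}(v, x)$ の総和
- $K_{v}$ : 部分木に含まれる頂点の個数
$K_{v}$ は子の $K_{x}$ を足し合わせるだけなので明らか。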
$v$ に $x$ をマージするとき、 $A_{v}$ への寄与は $A_{x}$ 自身に加えて $B_{v} \times K_{x} + K_{v} \times (K_{x} + B_{x})$ 。 $B_{v}$ への寄与は $B_{x} + K_{x}$ 。ここで $B_{v}, K_{v}$ は $x$ をマージする直前の値である。
この問題では木を圧縮しているので、辺の長さが $1$ とは限らず、いくらかの係数をかけ合わせた値が寄与しているところがある。
#define PROBLEM "https://atcoder.jp/contests/abc359/tasks/abc359_g"
#include "../../Src/Template/IOSetting.hpp"
#include "../../Src/Graph/Tree/AuxiliaryTree.hpp"
#include <iostream>
#include <vector>
using namespace zawa;
// Reads a tree with N vertices (N - 1 edges given as 1-indexed endpoint pairs)
// followed by N values A_1..A_N, and prints the sum of dist(i, j) over all
// vertex pairs whose values are equal.
//
// For every value c the vertices carrying c are compressed into an auxiliary
// tree, and a tree DP over that compressed tree accumulates the pairwise
// distances. Edges of the compressed tree have length
// AT.parentEdgeLength(child), which is not necessarily 1.
int main() {
    SetFastIO();
    int N;
    std::cin >> N;
    std::vector<std::vector<int>> g(N);
    for (int i{1} ; i < N ; i++) {
        int u, v;
        std::cin >> u >> v;
        u--; v--;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    std::vector<int> A(N);
    // B[c] lists the vertices whose (0-indexed) value is c.
    std::vector<std::vector<int>> B(N);
    for (int i{} ; i < N ; i++) {
        std::cin >> A[i];
        A[i]--;
        B[A[i]].push_back(i);
    }
    AuxiliaryTree AT(g);
    long long ans{};
    // Scratch arrays reused for every value; the entries touched by one
    // auxiliary tree are reset to zero before the next value is processed.
    std::vector<int> size(N);      // size[v]: value-c vertices merged into v so far
    std::vector<long long> dp(N);  // dp[v]: sum of distances from v to those vertices
    for (int c{} ; c < N ; c++) {
        if (B[c].empty()) {
            continue;
        }
        const int r{static_cast<int>(AT.construct(B[c]))};
        // Returns the sum of distances over pairs of value-c vertices inside
        // the part of the auxiliary tree hanging below v.
        auto dfs{[&](auto dfs, int v, int p) -> long long {
            if (A[v] == c) size[v]++;
            long long res{};
            for (auto x : AT[v]) {
                if (static_cast<int>(x) == p) continue;
                res += dfs(dfs, x, v);
                // Widen before multiplying: parentEdgeLength() returns a 32-bit
                // unsigned value, and `int * u32` would be evaluated modulo 2^32,
                // which can wrap when both the count and the edge length are
                // around 1e5.
                const long long len{static_cast<long long>(AT.parentEdgeLength(x))};
                // Sum of distances from v to the value-c vertices below x.
                const long long lifted{dp[x] + len * size[x]};
                // Pairs with one endpoint already merged into v and the other below x.
                res += dp[v] * size[x] + size[v] * lifted;
                dp[v] += lifted;
                size[v] += size[x];
            }
            return res;
        }};
        ans += dfs(dfs, r, -1);
        for (auto v : AT.current()) {
            size[v] = 0;
            dp[v] = 0;
        }
    }
    std::cout << ans << '\n';
}
#line 1 "Test/AtCoder/abc359_g.test.cpp"
#define PROBLEM "https://atcoder.jp/contests/abc359/tasks/abc359_g"
#line 2 "Src/Template/IOSetting.hpp"
#line 2 "Src/Template/TypeAlias.hpp"
#include <cstdint>
#include <cstddef>
// Short aliases for fixed-width integer types and std::size_t.
// The other zawa headers in this file refer to them (e.g. u32, usize).
namespace zawa {
// Signed integers.
using i16 = std::int16_t;
using i32 = std::int32_t;
using i64 = std::int64_t;
// NOTE: __int128_t is a compiler extension (GCC/Clang), not standard C++.
using i128 = __int128_t;
// Unsigned integers.
using u8 = std::uint8_t;
using u16 = std::uint16_t;
using u32 = std::uint32_t;
using u64 = std::uint64_t;
// Size/index type of the standard containers.
using usize = std::size_t;
} // namespace zawa
#line 4 "Src/Template/IOSetting.hpp"
#include <iostream>
#include <iomanip>
namespace zawa {
// Speeds up iostream-based I/O: turns off synchronisation between the C++
// standard streams and C stdio, and unties std::cin from std::cout so reads
// no longer flush pending output.
//
// Declared `inline` because the definition lives in a header; without it,
// including the header from two translation units yields duplicate
// definitions. sync_with_stdio is called directly as a static member instead
// of through the pointer returned by tie(), which is null once the stream has
// already been untied (e.g. on a second call).
inline void SetFastIO() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}
// Switches std::cout to fixed-point notation with `dig` digits after the
// decimal point for all subsequent floating-point output.
//
// Declared `inline` because the definition lives in a header; without it,
// including the header from two translation units yields duplicate
// definitions.
inline void SetPrecision(u32 dig) {
    // std::setprecision takes an int; make the conversion explicit.
    std::cout << std::fixed << std::setprecision(static_cast<int>(dig));
}
} // namespace zawa
#line 2 "Src/Graph/Tree/AuxiliaryTree.hpp"
#line 2 "Src/Graph/Tree/LowestCommonAncestor.hpp"
#line 2 "Src/Algebra/Monoid/ChminMonoid.hpp"
#line 4 "Src/Algebra/Monoid/ChminMonoid.hpp"
#include <algorithm>
#include <optional>
namespace zawa {

// Element of the "keep the smaller priority" monoid: a payload `value`
// tagged with an optional `priority`. An element that has no priority acts
// as +infinity and therefore never compares less than anything.
template <class T, class U>
class ChminMonoidData {
public:
    // Infinite element with a value-initialised payload.
    ChminMonoidData() = default;

    // Infinite element carrying `value`.
    ChminMonoidData(const U& value) : value_{value} {}

    // Finite element with the given priority and payload.
    ChminMonoidData(const T& priority, const U& value)
        : priority_{priority}, value_{value} {}

    // True when no priority is stored.
    constexpr bool infty() const noexcept {
        return priority_ == std::nullopt;
    }

    // Stored priority; must only be called on a finite element.
    constexpr const T& priority() const noexcept {
        return priority_.value();
    }

    // Stored payload.
    constexpr const U& value() const noexcept {
        return value_;
    }

    // Strict ordering by priority, with infinite elements placed last;
    // two infinite elements are not less than each other.
    friend constexpr bool operator<(const ChminMonoidData& l, const ChminMonoidData& r) {
        return !l.infty() && (r.infty() || l.priority() < r.priority());
    }

private:
    std::optional<T> priority_{};
    U value_{};
};

// Monoid whose operation keeps the operand with the smaller priority and
// whose identity is the infinite element.
template <class T, class U>
struct ChminMonoid {
    using Element = ChminMonoidData<T, U>;

    static Element identity() noexcept {
        return Element();
    }

    // Ties are broken in favour of the left operand `l`.
    static Element operation(const Element& l, const Element& r) noexcept {
        if (r < l) {
            return r;
        }
        return l;
    }
};

} // namespace zawa
#line 2 "Src/DataStructure/SparseTable/SparseTable.hpp"
#line 4 "Src/DataStructure/SparseTable/SparseTable.hpp"
#include <vector>
#include <cassert>
#include <ostream>
namespace zawa {
template <class Structure>
class SparseTable {
private:
using Value = typename Structure::Element;
std::vector<u32> L;
std::vector<std::vector<Value>> dat;
public:
SparseTable() : L{}, dat{} {}
SparseTable(const std::vector<Value>& a) : L(a.size() + 1), dat{} {
for (u32 i{1} ; i < L.size() ; i++) {
L[i] = L[i - 1] + (i >> (L[i - 1] + 1));
}
dat.resize(L.back() + 1);
dat[0] = a;
for (u32 i{1}, len{2} ; i < dat.size() ; i++, len <<= 1) {
dat[i] = dat[i - 1];
for (u32 j{} ; j + len - 1 < dat[i].size() ; j++) {
dat[i][j] = Structure::operation(dat[i - 1][j], dat[i - 1][j + (len >> 1)]);
}
}
}
Value product(u32 l, u32 r) const {
assert(l <= r);
assert(l < dat[0].size());
assert(r <= dat[0].size());
u32 now{L[r - l]};
return Structure::operation(dat[now][l], dat[now][r - (1 << now)]);
}
friend std::ostream& operator<<(std::ostream& os, const SparseTable<Structure>& spt) {
for (u32 i{}, len{1} ; i < spt.dat.size() ; i++, len <<= 1) {
os << "length = " << len << '\n';
for (u32 j{} ; j + len - 1 < spt.dat[i].size() ; j++) {
os << spt.dat[i][j] << (j + len == spt.dat[i].size() ? '\n' : ' ');
}
}
return os;
}
};
} // namespace zawa
#line 6 "Src/Graph/Tree/LowestCommonAncestor.hpp"
#line 9 "Src/Graph/Tree/LowestCommonAncestor.hpp"
namespace zawa {
template <class V>
class LowestCommonAncestor {
private:
using Monoid = ChminMonoid<u32, V>;
public:
LowestCommonAncestor() = default;
LowestCommonAncestor(const std::vector<std::vector<V>>& tree, V r = V{})
: n_{tree.size()}, depth_(tree.size()), L_(tree.size()), R_(tree.size()), st_{} {
std::vector<typename Monoid::Element> init;
init.reserve(2 * size());
auto dfs{[&](auto dfs, V v, V p) -> void {
depth_[v] = (p == INVALID ? 0u : depth_[p] + 1);
L_[v] = (u32)init.size();
for (auto x : tree[v]) {
if (x == p) {
continue;
}
init.emplace_back(depth_[v], v);
dfs(dfs, x, v);
}
R_[v] = (u32)init.size();
}};
dfs(dfs, r, INVALID);
st_ = SparseTable<Monoid>(init);
}
V operator()(V u, V v) const {
assert(verify(u));
assert(verify(v));
if (L_[u] > L_[v]) {
std::swap(u, v);
}
return u == v ? u : st_.product(L_[u], R_[v]).value();
}
V lca(V u, V v) const {
return (*this)(u, v);
}
inline u32 depth(V v) const noexcept {
assert(verify(v));
return depth_[v];
}
u32 distance(V u, V v) const {
assert(verify(u));
assert(verify(v));
return depth(u) + depth(v) - 2u * depth((*this)(u, v));
}
bool isAncestor(V p, V v) const {
assert(verify(p));
assert(verify(v));
return L_[p] <= L_[v] and R_[v] <= R_[p];
}
protected:
u32 left(V v) const noexcept {
return L_[v];
}
inline usize size() const {
return n_;
}
inline bool verify(V v) const {
return v < (V)size();
}
private:
static constexpr V INVALID{static_cast<V>(-1)};
usize n_{};
std::vector<u32> depth_, L_, R_;
SparseTable<Monoid> st_;
};
} // namespace zawa
#line 4 "Src/Graph/Tree/AuxiliaryTree.hpp"
#line 6 "Src/Graph/Tree/AuxiliaryTree.hpp"
namespace zawa {
// Builds, for a chosen set of vertices of a rooted tree, a compressed tree
// made of those vertices plus the lowest common ancestors that build()
// inserts. Each edge of the compressed tree joins a vertex to an ancestor of
// it in the original tree, and its length is the depth difference between
// the two. One object can be reused: construct() first clears what the
// previous call left behind.
template <class V>
class AuxiliaryTree : public LowestCommonAncestor<V> {
public:
using Super = LowestCommonAncestor<V>;
AuxiliaryTree() = default;
// Preprocesses the original tree T rooted at r for LCA queries and
// allocates the per-vertex storage (adjacency, parent-edge length, flag).
AuxiliaryTree(const std::vector<std::vector<V>>& T, V r = 0u)
: Super{ T, r }, T_(T.size()), dist_(T.size()), used_(T.size()) {}
// Builds the compressed tree for the non-empty vertex list vs, discarding
// the previously built one. Returns the vertex left at the bottom of the
// build stack, i.e. the one vertex that receives no parent edge.
V construct(const std::vector<V>& vs) {
assert(vs.size());
clear();
vs_ = vs;
return build();
}
// Neighbours of v in the current compressed tree. Edges are stored in both
// directions, so the list of a non-root vertex also contains its parent.
const std::vector<V>& operator[](V v) const {
assert(Super::verify(v));
return T_[v];
}
// Whether v belongs to the current compressed tree.
inline bool contains(V v) const {
assert(Super::verify(v));
return used_[v];
}
// Length of the compressed edge from v to its parent, measured in edges of
// the original tree; 0 for a vertex that received no parent edge.
inline u32 parentEdgeLength(V v) const {
assert(contains(v));
return dist_[v];
}
// Copy of the vertex list of the current compressed tree: the given
// vertices (sorted, adjacent duplicates removed) followed by the LCAs that
// build() appended.
std::vector<V> current() const {
return vs_;
}
private:
// Adjacency lists of the compressed tree, indexed by original vertex id.
std::vector<std::vector<V>> T_{};
// Vertices of the current compressed tree.
std::vector<V> vs_{};
// dist_[v]: length of the edge from v to its parent in the compressed tree.
std::vector<u32> dist_{};
// used_[v]: true while v is part of the current compressed tree.
std::vector<bool> used_{};
// Links v below p; p must be strictly shallower than v in the original tree.
void addEdge(V p, V v) {
assert(Super::depth(p) < Super::depth(v));
T_[p].push_back(v);
T_[v].push_back(p);
dist_[v] = Super::depth(v) - Super::depth(p);
}
// Stack-based construction over vs_ sorted by DFS entry index.
V build() {
std::sort(vs_.begin(), vs_.end(), [&](V u, V v) -> bool {
return Super::left(u) < Super::left(v);
});
vs_.erase(std::unique(vs_.begin(), vs_.end()), vs_.end());
// k is fixed before the loop: LCAs appended to vs_ below must not be
// visited by the loop itself.
usize k{vs_.size()};
// The stack holds a chain of vertices, shallowest at the bottom, whose
// parent edges have not been emitted yet.
std::vector<V> stack;
stack.reserve(2u * vs_.size());
stack.emplace_back(vs_[0]);
for (usize i{} ; i + 1 < k ; i++) {
if (!Super::isAncestor(vs_[i], vs_[i + 1])) {
// The next vertex leaves the subtree of vs_[i]: close the chain
// below their LCA w, emitting one edge per popped vertex.
V w{Super::lca(vs_[i], vs_[i + 1])};
V l{stack.back()};
stack.pop_back();
while (stack.size() and LowestCommonAncestor<V>::depth(w) < LowestCommonAncestor<V>::depth(stack.back())) {
addEdge(stack.back(), l);
l = stack.back();
stack.pop_back();
}
// w joins the tree as a new vertex unless it already sits on top of
// the stack.
if (stack.empty() or stack.back() != w) {
stack.emplace_back(w);
vs_.emplace_back(w);
}
addEdge(w, l);
}
stack.emplace_back(vs_[i + 1]);
}
// Emit the edges of the chain that is still open.
while (stack.size() > 1u) {
V l{stack.back()};
stack.pop_back();
addEdge(stack.back(), l);
}
for (V v : vs_) {
used_[v] = true;
}
return stack.back();
}
// Resets the per-vertex state of every vertex of the previous compressed
// tree; only the vertices listed in vs_ are touched.
void clear() {
for (V v : vs_) {
T_[v].clear();
used_[v] = false;
dist_[v] = 0u;
}
vs_.clear();
}
};
} // namespace zawa
#line 5 "Test/AtCoder/abc359_g.test.cpp"
#line 8 "Test/AtCoder/abc359_g.test.cpp"
using namespace zawa;
// Reads a tree with N vertices (N - 1 edges given as 1-indexed endpoint pairs)
// followed by N values A_1..A_N, and prints the sum of dist(i, j) over all
// vertex pairs whose values are equal.
//
// For every value c the vertices carrying c are compressed into an auxiliary
// tree, and a tree DP over that compressed tree accumulates the pairwise
// distances. Edges of the compressed tree have length
// AT.parentEdgeLength(child), which is not necessarily 1.
int main() {
    SetFastIO();
    int N;
    std::cin >> N;
    std::vector<std::vector<int>> g(N);
    for (int i{1} ; i < N ; i++) {
        int u, v;
        std::cin >> u >> v;
        u--; v--;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    std::vector<int> A(N);
    // B[c] lists the vertices whose (0-indexed) value is c.
    std::vector<std::vector<int>> B(N);
    for (int i{} ; i < N ; i++) {
        std::cin >> A[i];
        A[i]--;
        B[A[i]].push_back(i);
    }
    AuxiliaryTree AT(g);
    long long ans{};
    // Scratch arrays reused for every value; the entries touched by one
    // auxiliary tree are reset to zero before the next value is processed.
    std::vector<int> size(N);      // size[v]: value-c vertices merged into v so far
    std::vector<long long> dp(N);  // dp[v]: sum of distances from v to those vertices
    for (int c{} ; c < N ; c++) {
        if (B[c].empty()) {
            continue;
        }
        const int r{static_cast<int>(AT.construct(B[c]))};
        // Returns the sum of distances over pairs of value-c vertices inside
        // the part of the auxiliary tree hanging below v.
        auto dfs{[&](auto dfs, int v, int p) -> long long {
            if (A[v] == c) size[v]++;
            long long res{};
            for (auto x : AT[v]) {
                if (static_cast<int>(x) == p) continue;
                res += dfs(dfs, x, v);
                // Widen before multiplying: parentEdgeLength() returns a 32-bit
                // unsigned value, and `int * u32` would be evaluated modulo 2^32,
                // which can wrap when both the count and the edge length are
                // around 1e5.
                const long long len{static_cast<long long>(AT.parentEdgeLength(x))};
                // Sum of distances from v to the value-c vertices below x.
                const long long lifted{dp[x] + len * size[x]};
                // Pairs with one endpoint already merged into v and the other below x.
                res += dp[v] * size[x] + size[v] * lifted;
                dp[v] += lifted;
                size[v] += size[x];
            }
            return res;
        }};
        ans += dfs(dfs, r, -1);
        for (auto v : AT.current()) {
            size[v] = 0;
            dp[v] = 0;
        }
    }
    std::cout << ans << '\n';
}