ABC220F - Distance Sums 2

考え方

AtCoder Beginner Contest 220 - YouTube
頂点1からの答えがわかっていれば,頂点1に隣接する頂点$v$の答えは「(頂点1の答え) - ($v$の部分木の頂点数) + (N - $v$の部分木の頂点数)」で求められる.これは,
  • $v$から$v$の部分木の各頂点への距離は,頂点1から$v$の部分木の各頂点への距離よりも1小さく,
  • $v$から$v$の部分木「以外」の各頂点への距離は,頂点1から$v$の部分木「以外」の各頂点への距離よりも1大きく
なるからである.

つまり,ある頂点の答えがわかっていれば,その子ノードの答えを計算できる.木を深さ方向に見ていけば良いので,DFSで計算できる.

よって,

  • 部分木の頂点数
  • 頂点1の答え
を求める問題に帰着された.

ある頂点を根とする部分木の頂点数は,自身とその子ノードの部分木の頂点数をすべて足し合わせたものだから,再帰関数で求められる.

頂点1の答えはBFSで求められる.あるいは,あるノードに関する答えがわかっていれば,その親ノードの答えは,「(子ノードの答え) + (子ノードを根とする部分木の頂点数)」となる.これは,親ノードから子ノードを根とする部分木の各点への距離は,子ノードからの距離に比べて+1されるからである.


回答例(再帰関数)

import sys
sys.setrecursionlimit(10 ** 6)

N = int(input())
G = [[] for _ in range(N)]
for _ in range(N - 1):
    u, v = map(lambda x: int(x) - 1, input().split())
    G[u].append(v)
    G[v].append(u)

def dfs_calc(cur, par):
    subtree_size[cur] = 1
    dist_sum = 0
    for chi in G[cur]:
        if chi == par:
            continue
        dist_sum += dfs_calc(chi, cur)
        dist_sum += subtree_size[chi]
        subtree_size[cur] += subtree_size[chi]
    return dist_sum

subtree_size = [0] * N
ans = [0] * N
ans[0] = dfs_calc(0, -1)

def dfs_solve(cur, par):
    for chi in G[cur]:
        if chi == par:
            continue
        ans[chi] = ans[cur] - subtree_size[chi] + (N - subtree_size[chi])
        dfs_solve(chi, cur)
        
dfs_solve(0, -1)
print(*ans, sep = '\n')

回答例(スタック)

Pythonで再帰は遅い.スタックでDFSする方が早い(コードは長くなるが).
【参考】非再帰 Euler Tour を Python でやる - Qiita

スタックで頂点をめぐるとDFS(オイラーツアー順)になる(※キューで頂点をめぐるとBFSになる).再帰と同じことをするには,帰りがけに親頂点に対する処理を行えばよい.

1つ目のDFSで処理済みフラグ(seen[cur] = True)の位置に注意.~vにフラグを立ててはいけない.

N = int(input())
G = [[] for _ in range(N)]
for _ in range(N - 1):
    u, v = map(lambda x: int(x) - 1, input().split())
    G[u].append(v)
    G[v].append(u)

dist_sum = 0
seen = [False] * N
par = [-1] * N
subtree_size = [1] * N
stack = [~0, 0]
while stack:
    cur = stack.pop()
    if cur >= 0:
        seen[cur] = True
        for chi in G[cur]:
            if seen[chi]:
                continue
            stack.append(~chi)
            stack.append(chi)
            par[chi] = cur
    else:
        if ~cur == 0:
            break
        subtree_size[par[~cur]] += subtree_size[~cur]
        dist_sum += subtree_size[~cur]

ans = [0] * N
ans[0] = dist_sum
seen = [False] * N
stack = [0]
while stack:
    cur = stack.pop()
    seen[cur] = True
    for chi in G[cur]:
        if seen[chi]:
            continue
        stack.append(chi)
        ans[chi] = ans[cur] - subtree_size[chi] + (N - subtree_size[chi])

print(*ans, sep = '\n')