考え方:公式解説の方法
公式解説の方法(Editorial - Exawizards Programming Contest 2021(AtCoder Beginner Contest 222))を考える.まず,各辺を何回通るか求める.頂点$A_{i}$と頂点$A_{i+1}$を結ぶ辺を通る回数を$C_{i}$とする.
すると,
\begin{aligned}
& R - B = K \\
&
\begin{cases}
\, \displaystyle R = \sum_{i\in S_{R}} C_{i} \\[1.5em]
\, \displaystyle B = \sum_{i\in S_{B}} C_{i} & (S_{B} = \{1,2,...,N-1\} \setminus S_{R})
\end{cases}
\end{aligned}
を満たす,$S_{R} \subset \{1,2,...,N-1\}$の選び方が何通りか求める問題になる.& R - B = K \\
&
\begin{cases}
\, \displaystyle R = \sum_{i\in S_{R}} C_{i} \\[1.5em]
\, \displaystyle B = \sum_{i\in S_{B}} C_{i} & (S_{B} = \{1,2,...,N-1\} \setminus S_{R})
\end{cases}
\end{aligned}
さらに,
\begin{aligned}
R - B = 2R - \sum_{i=1}^{N-1} C_{i}
\end{aligned}
であるから,R - B = 2R - \sum_{i=1}^{N-1} C_{i}
\end{aligned}
\begin{aligned}
R = \frac{\displaystyle K + \sum_{i=1}^{N-1} C_{i}}{2}
\end{aligned}
を満たす,$S_{R} \subset \{1,2,...,N-1\}$の選び方が何通りか求める問題になる.R = \frac{\displaystyle K + \sum_{i=1}^{N-1} C_{i}}{2}
\end{aligned}
ただし,$R$は非負整数であることに注意する.
ここで,dp[i][j]=
「$\displaystyle \sum_{k\in S_{R}} C_{k} = j$となる$S_{R}\subset \{1,2,..., i \}$の選び方」とすれば,
\begin{aligned}
\mathrm{dp}[i + 1][j]
&=\mathrm{dp}[i][j] + \mathrm{dp}[i][j - 2C_{i+1}]
\end{aligned}
となる.\mathrm{dp}[i + 1][j]
&=\mathrm{dp}[i][j] + \mathrm{dp}[i][j - 2C_{i+1}]
\end{aligned}
回答例
さらに,配列再利用をすると,実行時間が1/3くらいになる.また,
MOD
を定義せずに,M
を上書きして使うと,メチャクチャ遅くなる(なんで?)・
from collections import deque N, M, K = map(int, input().split()) A = list(map(int, input().split())) MOD = 998244353 INF = 1 << 60 G = [[] for _ in range(N)] for i in range(N - 1): U, V = map(int, input().split()) U -= 1 V -= 1 G[U].append((V, i)) G[V].append((U, i)) C = [0] * (N - 1) for a1, a2 in zip(A, A[1:]): a1 -= 1 a2 -= 1 dist = [INF] * N dist[a1] = 0 prev = [[] for _ in range(N)] que = deque([a1]) while que: cur = que.popleft() for chi, i in G[cur]: if dist[cur] + 1 < dist[chi]: dist[chi] = dist[cur] + 1 prev[chi] = (cur, i) que.append(chi) cur = a2 while cur != a1: par, i = prev[cur] C[i] += 1 cur = par tmp = K + sum(C) if tmp % 2 == 1 or tmp < 0: exit(print(0)) R = tmp // 2 dp = [[0] * (R + 1) for _ in range(N)] dp[0][0] = 1 for i in range(N - 1): for j in range(R + 1): dp[i + 1][j] = dp[i][j] if j - C[i] >= 0: dp[i + 1][j] += dp[i][j - C[i]] dp[i + 1][j] %= MOD print(dp[N - 1][R])
【参考】Submission #26450962 - Exawizards Programming Contest 2021(AtCoder Beginner Contest 222)