ABC248C - Dice Sum

考え方(包除原理)

dpで解いたが,組み合わせ+包除原理でも解けそうと思いつつ実装できなかったのでメモ.
Editorial - UNIQUE VISION Programming Contest 2022(AtCoder Beginner Contest 248)

$\sum_{i=1}^{N} A_{i} = x$の場合の答え$\mathrm{ans}(x)$を求める.最終的な答えは,$\sum_{x=N}^{K} \mathrm{ans}(x)$となる.

全集合を

\begin{aligned}
S = \{ A \mid 1 \leq A_{i} \,(i = 1,..., N) , \sum_{i=1}^{N} A_{i} = x\}
\end{aligned}
とする(以下,$\sum_{i=1}^{N} A_{i} = x$を略する).

\begin{aligned}
X_{i} = \{ A=(A_{1},...,A_{N}) \mid 1 \leq A_{i} \leq M \}
\end{aligned}
とすると,
\begin{aligned}
& (X_{i})^{c} = S \setminus X_{i} \\
&= \{ A=(A_{1},...,A_{N}) \mid M + 1 \leq A_{i} \}
\end{aligned}
であり,求めたいのは
\begin{aligned}
& \Biggl| \bigcap_{i = 1}^{N} X_{i} \Biggr|
= \Biggl| \Biggl(\bigcup_{i = 1}^{N} (X_{i})^{c} \Biggr)^{c} \Biggr|
= \Biggl| S \setminus \bigcup_{i = 1}^{N} (X_{i})^{c} \Biggr|\\
&= |S| - \Biggl| \bigcup_{i = 1}^{N} (X_{i})^{c} \Biggr|
\end{aligned}
となる.ここで,
\begin{aligned}
& \Biggl| \bigcup_{i = 1}^{N} (X_{i})^{c} \Biggr| \\
&=\sum_{k = 1}^{N}
(-1)^{k + 1}
\Biggl(\sum_{1 \leq i_{1} < \cdots < i_{k} \leq N} |(X_{i_{1}})^{c} \cap \cdots \cap (X_{i_{k}})^{c}| \Biggr)
\end{aligned}
である(Inclusion–exclusion principle - Wikipedia).


よって,「$M + 1 \leq A_{i}$を満たす$i$の個数が$L$」であるような$S$の部分集合を$S_{L}$とすれば,答えは

\begin{aligned}
& \mathrm{ans}(x) = S_{0} = \Biggl| \bigcap_{i = 1}^{N} X_{i} \Biggr|
= |S| + \sum_{L=1}^{N} (-1)^{L} |S_{L}|
\end{aligned}
となる.ただし,
\begin{aligned}
&|S|
= \binom{K - 1}{N - 1} \\
&|S_{L}|
=
\begin{cases}
\, \binom{N}{L} \binom{K - ML - 1}{N - 1} & (K - ML \geq N)\\
\, 0 & (\text{otherwise})
\end{cases}
\end{aligned}
である(例えば$|S|$は,$K$個の○の間の$K-1$の隙間に$N - 1$個の棒を挟む方法の数).

回答例

from functools import reduce
N, M, K = map(int, input().split())
mod = 998244353

def ncr(n, r):
    r = min(r, n - r)
    numer = reduce(lambda x, y: x * y % mod, range(n, n - r, -1), 1)
    denom = reduce(lambda x, y: x * y % mod, range(1, r + 1), 1)
    return numer * pow(denom, mod - 2, mod) % mod

ans = 0
for x in range(N, K + 1):
    for L in range(N + 1):
        if x - M * L < N:
            break
        SL = ncr(N, L) * ncr(x - M * L - 1, N - 1)
        SL %= mod
        if L % 2:
            ans -= SL
        else:
            ans += SL
        ans %= mod

print(ans)