ABC221E - LEQ

各$A_{i}$について,$i < j$かつ$A_{i} \leq A_{j}$を満たす$j$の個数を求めようとして沼った($O(N^{2})$になる).

考え方

 $i < j, A_{i} \leq A_{j}$を満たす$i,j$を固定したとき,$i$と$j$の間にある$j - i - 1$コの要素を入れるか入れないかで$2^{j - i - 1}$のパターンがある.よって,
\begin{aligned}
\mathrm{ans}
&= \sum_{\{(i,j) \mid i < j, A_{i} \leq A_{j}\}} 2^{j - i - 1} \\
&=\sum_{j} 2^{j}
\textcolor{red}{\sum_{ \substack{i < j \\ (A_{i} \leq A_{j}) }} 2^{-(i + 1)}}
\end{aligned}
を計算すればよい.以下,$j$でループさせることにして,赤字の部分をどう実現するか考える.

 セグメント木を使えば「ある値以下の和」は計算できる.つまり,

\begin{aligned}
\sum_{A_{i} \leq A_{j} } f(A_{i})
= \sum_{x\in [0, A_{j} + 1)} f(x)
\end{aligned}
の部分はセグメント木$\{f(x)\}_{x}$で実現できる.

 また,$A_{j}$を$j$の小さい方から見て行き,上の和をとったあとでセグメント木の$A_{j}$の要素$f(A_{i})$に$2^{-(j + 1)}$を加算することにすれば,自動的に

\begin{aligned}
\sum_{x\in [0, A_{j} + 1)} f(x)
= \sum_{\{i \mid \textcolor{red}{i < j}, A_{i} \leq A_{j}\}} f(A_{i})
\end{aligned}
が満たされる(まだ見ていないものの寄与は加算されていないため).

 $A$の取り得る値の範囲$[1,10^{9}]$と大きいので,座標圧縮をして以上を行う.

【参考】

回答例

セグメント木は以下を使わせていただきました:

# https://judge.yosupo.jp/submission/7795
class SegTree:
    X_unit = 0
    
    def my_sum(self, x, y):
      return (x + y) % M
    X_f = my_sum

    def __init__(self, N):
        self.N = N
        self.X = [self.X_unit] * (N + N)

    def build(self, seq):
        for i, x in enumerate(seq, self.N):
            self.X[i] = x
        for i in range(self.N - 1, 0, -1):
            self.X[i] = self.X_f(self.X[i << 1], self.X[i << 1 | 1])

    def set_val(self, i, x):
        i += self.N
        self.X[i] = x
        while i > 1:
            i >>= 1
            self.X[i] = self.X_f(self.X[i << 1], self.X[i << 1 | 1])

    def fold(self, L, R):
        L += self.N
        R += self.N
        vL = self.X_unit
        vR = self.X_unit
        while L < R:
            if L & 1:
                vL = self.X_f(vL, self.X[L])
                L += 1
            if R & 1:
                R -= 1
                vR = self.X_f(self.X[R], vR)
            L >>= 1
            R >>= 1
        return self.X_f(vL, vR)
# ---
N = int(input())
A = list(map(int, input().split()))
M = 998244353
inv2 = pow(2, M - 2, M)

D = {x:i for i, x in enumerate(sorted(set(A)))}
L = len(D)
seg = SegTree(L)
seg.build([0] * L)

ans = 0
for j, a in enumerate(A):
  ans += pow(2, j, M) * seg.fold(0, D[a] + 1)
  ans %= M
  
  new = seg.X[seg.N + D[a]] + pow(inv2, j + 1, M)
  new %= M
  seg.set_val(D[a], new)
  
print(ans)