https://loj.ac/p/575
弱化版:AT4541 Permutation
前缀和优化的DP没有什么前途,我们考虑容斥
先忽略所有的 “>”(全部强制满足),把剩下的每个 “<” 是否满足条件看作是 $0/1$
那么要求的就是 $111\ldots111$(所有 “<” 都满足)
容斥一下就是
$$
\begin{aligned}
111\ldots111 &= 111\ldots11? - 111\ldots110 \\
&= 111\ldots11? - 111\ldots1?0 + 111\ldots100
\end{aligned}
$$
(其中 “?” 表示该位置不作限制)
以此类推
设 $dp[i]$ 表示钦定长度为 $i$ 的前缀合法的排列数
$$dp[i]=\sum_{j=0}^{i-1}[st[j]=='<'] \ \ dp[j]\times (-1)^{cnt[i-1]-cnt[j]} \times \binom{i}{i-j}$$
(其中 $j=0$ 总是允许转移)
$cnt[i]$ 表示前 $i$ 个有几个是 “<”
显然可以轻易用分治 NTT 优化成 $O(n\log^2 n)$
code:
#include<bits/stdc++.h>
#define N 400050
#define poly vector<int>
#define mod 998244353
using namespace std;
/// Modular addition: returns x + y with a single conditional subtraction of
/// `mod`, so the result is in [0, mod) whenever both operands are.
int add(int x, int y) {
    const int sum = x + y;
    return sum >= mod ? sum - mod : sum;
}
/// Modular subtraction: returns x - y, adding `mod` once if the difference is
/// negative, so the result is in [0, mod) whenever both operands are.
int sub(int x, int y) {
    const int diff = x - y;
    return diff < 0 ? diff + mod : diff;
}
/// Modular multiplication: the product is formed in 64-bit arithmetic so it
/// cannot overflow before being reduced modulo `mod`.
int mul(int x, int y) {
    const long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product % mod);
}
/// Binary exponentiation: returns x^y modulo `mod`.
/// The exponent is consumed bit by bit from the least significant end;
/// y == 0 yields 1.
int qpow(int x, int y) {
    int result = 1;
    while (y) {
        if (y & 1) result = mul(result, x);  // this bit contributes x^(2^k)
        x = mul(x, x);
        y >>= 1;
    }
    return result;
}
// Base from which ntt() derives its roots of unity: G^((mod-1)/len) for the
// forward transform (3 is a primitive root modulo 998244353).
const int G = 3;
// Modular inverse of G, computed as G^(mod-2) by qpow during static
// initialization; ntt() uses it in place of G for the inverse transform.
const int Ginv = qpow(G, mod - 2);
// Bit-reversal permutation table; ntt() rewrites entries 1..n-1 on every call
// and relies on rev[0] keeping its zero-initialized value.
int rev[N << 1];
void ntt(int *a, int n, int o) {
for(int i = 1; i < n; i ++) rev[i] = (rev[i >> 1] >> 1) | ((n >> 1) * (i & 1));
for(int i = 1; i < n; i ++) if(i > rev[i]) swap(a[i], a[rev[i]]);
//for(int i = 0; i < n; i ++) printf("%d ", a[i]); printf("\n");
for(int len = 2; len <= n; len <<= 1) {
int w0 = qpow((o == 1)? G : Ginv, (mod - 1) / len);
for(int j = 0; j < n; j += len) {
int wn = 1;
for(int k = j; k < j + (len >> 1); k ++, wn = mul(wn, w0)) {
int X = a[k], Y = mul(wn, a[k + (len >> 1)]);
a[k] = add(X, Y), a[k + (len >> 1)] = sub(X, Y);
}
}
}
int ninv = qpow(n, mod - 2);
if(o == -1)
for(int i = 0; i < n; i ++) a[i] = mul(a[i], ninv);
//for(int i = 0; i < n; i ++) printf("%d ", a[i]); printf("\n\n");
}
// Container size as a signed int, so size arithmetic and comparisons against
// int loop counters stay signed.
#define sz(A) ((int)A.size())
// Scratch buffers for polynomial multiplication. operator* copies its operands
// in, transforms them in place, and zeroes the used range before returning.
int a[N << 1], b[N << 1];
/// Polynomial product modulo `mod`, computed with the NTT.
///
/// Returns a polynomial with sz(A) + sz(B) - 1 coefficients, or an empty
/// polynomial when either operand is empty.
///
/// Uses the global scratch buffers `a` and `b`. It relies on both being
/// all-zero on entry (only the operand coefficients are written before the
/// transform) and restores that state before returning. The transform length
/// must not exceed the size of those buffers.
poly operator * (const poly& A, const poly& B) {
    // An empty operand would make the result size zero or negative; the
    // product of nothing is the empty polynomial.
    if (A.empty() || B.empty()) return poly();

    const int total = sz(A) + sz(B) - 1;
    // Smallest power of two that holds every coefficient of the product, so
    // the cyclic convolution does not wrap around.
    int len = 1;
    while (len < total) len <<= 1;

    for (int i = 0; i < sz(A); i++) a[i] = A[i];
    for (int i = 0; i < sz(B); i++) b[i] = B[i];

    ntt(a, len, 1);
    ntt(b, len, 1);
    for (int i = 0; i < len; i++) a[i] = mul(a[i], b[i]);
    ntt(a, len, -1);

    poly C(a, a + total);

    // Leave the scratch buffers zeroed for the next call; only [0, len) was touched.
    for (int i = 0; i < len; i++) a[i] = b[i] = 0;
    return C;
}
// fac[i] = i! modulo `mod`, ifac[i] = modular inverse of i!; init(n) fills
// both for i in [0, n].
int fac[N], ifac[N];
/// Fills fac[0..n] with factorials and ifac[0..n] with their modular inverses.
/// Only one modular exponentiation is needed: ifac[n] is computed directly and
/// the rest follow from ifac[k - 1] = ifac[k] * k.
void init(int n) {
    fac[0] = 1;
    for (int k = 0; k < n; ++k) fac[k + 1] = mul(fac[k], k + 1);

    ifac[n] = qpow(fac[n], mod - 2);
    for (int k = n; k > 0; --k) ifac[k - 1] = mul(ifac[k], k);
}
// Input pattern, stored 1-based: main() reads it into st + 1, so st[0] stays '\0'.
char st[N];
// n   - pattern length plus one (set in main).
// f   - DP values: main() seeds f[0] = 1, cdq() fills the rest, and main()
//       prints f[n] * n!.
// cnt - cnt[i] = number of '<' among st[1..i].
int n, f[N], cnt[N];
void cdq(int l, int r) {
if(l == r) {
if(l) {
if(cnt[l - 1] & 1) f[l] = (mod - f[l]) % mod;
}
return ;
}
int mid = (l + r) >> 1;
cdq(l, mid);
poly a, b;
for(int i = l; i <= mid; i ++) {
int o = f[i];
if(cnt[i] & 1) o = (mod - o) % mod;
if(st[i] == '>') o = 0;
a.push_back(o);
}
for(int i = l; i <= r; i ++) b.push_back(ifac[i - l]);
// printf("--- %d %d %d\n", l, mid, r);
// for(int i = 0; i < sz(a); i ++) printf("%d ", a[i]); printf("\n");
// for(int i = 0; i < sz(a); i ++) printf("%d ", b[i]); printf("\n");
poly c = a * b;
// for(int i = 0; i < sz(c); i ++) printf("%d ", c[i]); printf("\n\n");
for(int i = mid + 1; i <= r; i ++) f[i] = add(f[i], c[i - l]);
cdq(mid + 1, r);
}
/// Reads the comparison pattern from stdin, runs the divide-and-conquer DP and
/// prints f[n] * n! modulo `mod`, where n is the pattern length plus one.
int main() {
    // Bound the read so an over-long token cannot overflow the fixed buffers.
    // 349523 is the longest pattern for which the largest product formed by
    // cdq(0, n) -- n + n/2 + 1 coefficients with n = length + 1 -- stays
    // strictly below 2^19 = 524288, the largest power-of-two transform that
    // fits the 2N-element NTT scratch arrays. It is also well inside st[N].
    // A longer token is truncated to this width.
    if (scanf("%349523s", st + 1) != 1) st[1] = '\0';  // no token: treat as empty pattern

    n = static_cast<int>(strlen(st + 1)) + 1;
    init(n);

    // cnt[i] = number of '<' among the first i pattern characters.
    for (int i = 1; i < n; i++) cnt[i] = cnt[i - 1] + (st[i] == '<');

    f[0] = 1;
    cdq(0, n);

    // The DP works with values divided by factorials; scale back by n!.
    f[n] = mul(f[n], fac[n]);
    printf("%d", f[n]);
    return 0;
}