https://loj.ac/p/575
弱化版:AT4541 Permutation
前缀和优化的DP没有什么前途,我们考虑容斥
先忽略所有的 “>”(全部强制满足),把剩下的每个 “<” 是否满足条件看作一个 $0/1$ 位
那么要求的就是 $111\dots111$(即所有 “<” 都满足)
容斥一下就是(其中 $?$ 表示该位置不加任何限制,$0$ 表示该 “<” 不满足,即被换成 “>”):
$$
\begin{aligned}
111\dots111 &= 111\dots11? - 111\dots110 \\
&= 111\dots11? - 111\dots1?0 + 111\dots100
\end{aligned}
$$
 以此类推
设 $dp[i]$ 表示钦定长度为 $i$ 的前缀合法的排列数(带容斥系数),$dp[0]=1$,答案即 $dp[n]$
$$dp[i]=\sum_{j=0}^{i-1}\big[\,j=0 \lor st[j]=\texttt{<}\,\big]\; dp[j]\times(-1)^{cnt[i-1]-cnt[j]}\times\binom{i}{i-j}$$
其中 $j$ 是上一个断点,$(j,i]$ 这一段被钦定为连续下降;段内原本的 “<” 共有 $cnt[i-1]-cnt[j]$ 个,都被换成了 “>”,因此带上对应的符号;$\binom{i}{i-j}$ 表示从前 $i$ 个数中选出放进这一段的数。
$cnt[i]$ 表示 $st$ 的前 $i$ 个字符中有几个是 “<”
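把 $\binom{i}{i-j}=\frac{i!}{j!\,(i-j)!}$ 拆开,令 $f[i]=dp[i]/i!$,并把符号拆成 $(-1)^{cnt[i-1]}\cdot(-1)^{cnt[j]}$,转移就变成 $f$ 与 $\frac{1}{k!}$ 的卷积形式,显然可以用分治NTT(CDQ)优化成 $O(n\log^2 n)$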
code:
#include<bits/stdc++.h>
#define N 400050
#define poly vector<int>
#define mod 998244353
using namespace std;
// Modular addition; both operands are expected to already lie in [0, mod).
int add(int x, int y) {
    const int s = x + y;
    return s >= mod ? s - mod : s;
}
// Modular subtraction; both operands are expected to already lie in [0, mod).
int sub(int x, int y) {
    const int d = x - y;
    return d < 0 ? d + mod : d;
}
// Modular multiplication; the product is formed in 64 bits so it cannot overflow.
int mul(int x, int y) {
    const long long prod = static_cast<long long>(x) * y;
    return static_cast<int>(prod % mod);
}
// Computes x^y modulo `mod` by binary exponentiation.
// y must be non-negative (a negative y never reaches 0 under >>=).
int qpow(int x, int y) {
    int result = 1;
    int base = x;
    while (y != 0) {
        if (y & 1) result = mul(result, base);
        y >>= 1;
        base = mul(base, base);
    }
    return result;
}
// Primitive root of the NTT modulus 998244353 and its modular inverse;
// the forward transform uses G, the inverse transform uses Ginv.
const int G{3};
const int Ginv{qpow(G, mod - 2)};
int rev[N << 1];

// In-place iterative number-theoretic transform of a[0, n) modulo `mod`.
//   n : transform length. The bit-reversal table below assumes n is a power of
//       two, and rev[] only holds N << 1 entries.
//   o : 1 selects the forward transform (root G). Any other value uses the
//       inverse root Ginv; o == -1 additionally scales every entry by n^{-1},
//       so ntt(a, n, 1) followed by ntt(a, n, -1) restores a.
void ntt(int *a, int n, int o) {
    // Bit-reversal permutation (rev[0] is never written and stays 0).
    const int topBit = n >> 1;
    for (int i = 1; i < n; ++i)
        rev[i] = (rev[i >> 1] >> 1) | (topBit * (i & 1));
    for (int i = 1; i < n; ++i)
        if (rev[i] < i) swap(a[i], a[rev[i]]);

    const int root = (o == 1) ? G : Ginv;
    for (int span = 2; span <= n; span <<= 1) {
        const int half = span >> 1;
        // Principal span-th root of unity for this butterfly level.
        const int step = qpow(root, (mod - 1) / span);
        for (int base = 0; base < n; base += span) {
            int w = 1;
            for (int k = 0; k < half; ++k) {
                int *lo = a + base + k;
                int *hi = lo + half;
                const int u = *lo;
                const int v = mul(w, *hi);
                *lo = add(u, v);
                *hi = sub(u, v);
                w = mul(w, step);
            }
        }
    }

    if (o == -1) {
        const int ninv = qpow(n, mod - 2);
        for (int i = 0; i < n; ++i) a[i] = mul(a[i], ninv);
    }
}
#define sz(A) ((int)(A).size())
// Scratch buffers for operator*; they must be all-zero between calls.
int a[N << 1], b[N << 1];

// Multiplies two polynomials (coefficient vectors modulo `mod`) via NTT.
// Returns the coefficient vector of A*B, which has sz(A) + sz(B) - 1 entries,
// or an empty vector if either operand is empty.
// Uses the global scratch buffers a/b and clears every entry it touched before
// returning, so the all-zero invariant holds for the next call.
poly operator * (const poly& A, const poly& B) {
    if (A.empty() || B.empty()) return poly();
    const int res = sz(A) + sz(B) - 1;

    // Smallest power of two >= res: a cyclic convolution of that length
    // cannot wrap around, so it yields the exact linear product.
    int len = 1;
    while (len < res) len <<= 1;
    assert(len <= (N << 1));  // a, b and rev[] hold N << 1 entries

    for (int i = 0; i < sz(A); i ++) a[i] = A[i];
    for (int i = 0; i < sz(B); i ++) b[i] = B[i];
    ntt(a, len, 1), ntt(b, len, 1);
    for (int i = 0; i < len; i ++) a[i] = mul(a[i], b[i]);
    ntt(a, len, -1);

    poly C(a, a + res);
    // Only [0, len) was written; restore the zero invariant there.
    fill(a, a + len, 0), fill(b, b + len, 0);
    return C;
}
int fac[N], ifac[N];

// Fills fac[0..n] with k! and ifac[0..n] with (k!)^{-1} modulo `mod`.
// Requires n < N.
void init(int n) {
    fac[0] = 1;
    for (int k = 1; k <= n; ++k) fac[k] = mul(fac[k - 1], k);
    // A single modular inversion, then walk down: 1/(k-1)! = k * 1/k!.
    ifac[n] = qpow(fac[n], mod - 2);
    for (int k = n; k > 0; --k) ifac[k - 1] = mul(ifac[k], k);
}
char st[N];
int n, f[N], cnt[N];
void cdq(int l, int r) {
    if(l == r) {
        if(l) {
            if(cnt[l - 1] & 1) f[l] = (mod - f[l]) % mod;
        }
        return ;
    }
    int mid = (l + r) >> 1;
    cdq(l, mid);
    poly a, b;
    for(int i = l; i <= mid; i ++) {
        int o = f[i];
        if(cnt[i] & 1) o = (mod - o) % mod;
        if(st[i] == '>') o = 0;
        a.push_back(o);
    }
    
    for(int i = l; i <= r; i ++) b.push_back(ifac[i - l]);
  //  printf("--- %d %d %d\n", l, mid, r);
    // for(int i = 0; i < sz(a); i ++) printf("%d ", a[i]); printf("\n");
    // for(int i = 0; i < sz(a); i ++) printf("%d ", b[i]); printf("\n");
    poly c = a * b;
   // for(int i = 0; i < sz(c); i ++) printf("%d ", c[i]); printf("\n\n");
    for(int i = mid + 1; i <= r; i ++) f[i] = add(f[i], c[i - l]);
    cdq(mid + 1, r);
}
// Reads the '<'/'>' relation string (stored at st[1..]) and prints the number
// of permutations of n = length + 1 elements that satisfy it, modulo 998244353.
int main() {
    // Missing input is treated as an empty relation string (answer 1).
    if (scanf("%s", st + 1) != 1) st[1] = '\0';
    n = strlen(st + 1) + 1;
    init(n);
    // cnt[i] = number of '<' among st[1..i].
    for (int i = 1; i < n; i ++) cnt[i] = cnt[i - 1] + (st[i] == '<');
    f[0] = 1;
    cdq(0, n);
    // f[n] holds dp[n] / n!; scale back to the actual count.
    printf("%d\n", mul(f[n], fac[n]));
    return 0;
}










