0
点赞
收藏
分享

微信扫一扫

loj #575. 「LibreOJ NOI Round #2」不等关系

吃面多放酱 2022-02-14 阅读 50

https://loj.ac/p/575

弱化版:AT4541 Permutation

前缀和优化的DP没有什么前途,我们考虑容斥

先忽略所有的 “$>$”(视为全部强制满足),把剩下每个 “$<$” 是否满足条件看作 $0/1$(满足为 $1$,不满足为 $0$)。

那么要求的就是 $111\ldots111$(所有 “$<$” 都满足)。

容斥一下就是
$$
111\ldots111 = 111\ldots11? - 111\ldots110 = 111\ldots11? - 111\ldots1?0 + 111\ldots100
$$
其中 $?$ 表示该位置不作限制(满足与否均可),利用 $111\ldots1?0 = 111\ldots110 + 111\ldots100$ 逐位展开。
以此类推

$dp[i]$ 表示钦定长度为 $i$ 的前缀合法的排列数

$$
dp[i]=\sum_{j=0}^{i-1}\big[\,j=0 \ \lor\ st[j]=\texttt{<}\,\big]\; dp[j]\times(-1)^{cnt[i-1]-cnt[j]}\times\binom{i}{i-j}
$$

$cnt[i]$ 表示前 $i$ 个符号中有几个是 “$<$”

显然可以用分治 NTT(CDQ 分治 + NTT)优化到 $O(n\log^2 n)$。

code:

#include<bits/stdc++.h>
#define N 400050
#define poly vector<int>
#define mod 998244353
using namespace std;
// Modular addition. Inputs are expected in [0, mod), so the sum is below
// 2 * mod (fits in int) and one conditional subtraction normalises it.
int add(int x, int y) {
    const int sum = x + y;
    return sum >= mod ? sum - mod : sum;
}
// Modular subtraction. Inputs are expected in [0, mod); a negative
// difference is wrapped back into range with a single addition.
int sub(int x, int y) {
    const int diff = x - y;
    return diff < 0 ? diff + mod : diff;
}
// Modular multiplication; widens to 64 bits so the product cannot overflow.
int mul(int x, int y) {
    const long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product % mod);
}
// Binary exponentiation: returns x^y mod `mod` by square-and-multiply.
int qpow(int x, int y) {
    int result = 1;
    while (y != 0) {
        if (y & 1) result = mul(result, x);
        // Advance to the next bit: halve the exponent, square the base.
        y >>= 1;
        x = mul(x, x);
    }
    return result;
}

// NTT root for mod = 998244353 = 119 * 2^23 + 1: G drives the forward
// transform and its modular inverse Ginv drives the inverse transform.
const int G = 3, Ginv = qpow(G, mod - 2);
// Bit-reversal permutation table, rebuilt by ntt() for each transform length.
int rev[N << 1];
/**
 * In-place iterative number-theoretic transform over Z_mod.
 *
 * @param a  buffer of at least n values; overwritten with the transform.
 * @param n  transform length; must be a power of two dividing (mod - 1).
 * @param o  1 selects the forward transform (root G). -1 selects the inverse
 *           transform (root Ginv, followed by scaling every entry by n^{-1}).
 *           Any other value uses the inverse root without scaling.
 */
void ntt(int *a, int n, int o) {
    // Bit-reversal permutation for this length; rev[0] is always 0.
    rev[0] = 0;
    for (int i = 1; i < n; i ++) rev[i] = (rev[i >> 1] >> 1) | ((n >> 1) * (i & 1));
    for (int i = 1; i < n; i ++) if (i > rev[i]) swap(a[i], a[rev[i]]);

    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        // Principal len-th root of unity (its inverse when o != 1).
        const int w0 = qpow((o == 1) ? G : Ginv, (mod - 1) / len);
        for (int j = 0; j < n; j += len) {
            int wn = 1;
            for (int k = j; k < j + half; k ++, wn = mul(wn, w0)) {
                const int X = a[k], Y = mul(wn, a[k + half]);
                a[k] = add(X, Y);
                a[k + half] = sub(X, Y);
            }
        }
    }

    // Only the inverse transform needs n^{-1}; avoid the exponentiation otherwise.
    if (o == -1) {
        const int ninv = qpow(n, mod - 2);
        for (int i = 0; i < n; i ++) a[i] = mul(a[i], ninv);
    }
}
// Signed size of a container. The argument is parenthesised so that
// expressions such as sz(*p) expand to ((int)(*p).size()) as intended.
#define sz(A) ((int)(A).size())
// Global scratch buffers for operator*; operator* relies on them being zero
// outside the copied inputs and clears what it used before returning.
int a[N << 1], b[N << 1];
/**
 * Linear convolution of A and B modulo `mod`, computed with NTT.
 *
 * Uses the global scratch buffers a/b, which must be all-zero on entry. The
 * touched prefix [0, len) is zeroed again before returning, preserving that
 * invariant for the next call.
 *
 * @return C with C[k] = sum_{i+j=k} A[i]*B[j] mod `mod`, of size
 *         sz(A) + sz(B) - 1; an empty vector if either input is empty.
 */
poly operator * (const poly& A, const poly& B) {
    // Previously an empty input led to resize(-1) (a huge allocation).
    if (A.empty() || B.empty()) return poly();

    const int need = sz(A) + sz(B) - 1;
    for (int i = 0; i < sz(A); i ++) a[i] = A[i];
    for (int i = 0; i < sz(B); i ++) b[i] = B[i];

    // The smallest power of two >= need avoids cyclic wrap-around.
    int len = 1;
    while (len < need) len <<= 1;

    ntt(a, len, 1), ntt(b, len, 1);
    for (int i = 0; i < len; i ++) a[i] = mul(a[i], b[i]);
    ntt(a, len, -1);

    poly C(a, a + need);

    // Restore the all-zero invariant. Only indices [0, len) were written;
    // the old `i <= len` loop also wrote one element past that range.
    std::fill(a, a + len, 0);
    std::fill(b, b + len, 0);
    return C;
}

// fac[k] = k! mod `mod`, ifac[k] = (k!)^{-1} mod `mod`; filled by init().
int fac[N], ifac[N];
// Precompute factorials and inverse factorials for every k in [0, n]
// (n must be below N).
void init(int n) {
    fac[0] = 1;
    for (int k = 1; k <= n; ++k) fac[k] = mul(fac[k - 1], k);
    // A single Fermat inverse at the top, then walk downwards using
    // 1/(k-1)! = k * (1/k!).
    ifac[n] = qpow(fac[n], mod - 2);
    for (int k = n; k > 0; --k) ifac[k - 1] = mul(ifac[k], k);
}
char st[N];   // pattern stored from index 1 (read by main); st[0] is left as '\0'
int n;        // permutation length: pattern length + 1 (set by main)
int f[N];     // DP table filled by cdq(); main multiplies f[n] by n! at the end
int cnt[N];   // cnt[i] = number of '<' among st[1..i]
void cdq(int l, int r) {
    if(l == r) {
        if(l) {
            if(cnt[l - 1] & 1) f[l] = (mod - f[l]) % mod;
        }
        return ;
    }

    int mid = (l + r) >> 1;
    cdq(l, mid);

    poly a, b;
    for(int i = l; i <= mid; i ++) {
        int o = f[i];
        if(cnt[i] & 1) o = (mod - o) % mod;
        if(st[i] == '>') o = 0;
        a.push_back(o);
    }
    
    for(int i = l; i <= r; i ++) b.push_back(ifac[i - l]);
  //  printf("--- %d %d %d\n", l, mid, r);
    // for(int i = 0; i < sz(a); i ++) printf("%d ", a[i]); printf("\n");
    // for(int i = 0; i < sz(a); i ++) printf("%d ", b[i]); printf("\n");
    poly c = a * b;
   // for(int i = 0; i < sz(c); i ++) printf("%d ", c[i]); printf("\n\n");
    for(int i = mid + 1; i <= r; i ++) f[i] = add(f[i], c[i - l]);

    cdq(mid + 1, r);
}
int main() {
    // freopen("a.in","r",stdin);
    // freopen("a.out","w",stdout);
    scanf("%s", st + 1);
    n = strlen(st + 1) + 1; init(n);
    for(int i = 1; i < n; i ++) cnt[i] = cnt[i - 1] + (st[i] == '<');

    f[0] = 1;
    cdq(0, n);
    
   // if(!(cnt[n - 1] & 1)) f[n] = (mod - f[n]) % mod;
    f[n]= mul(f[n], fac[n]);

    printf("%d", f[n]);
    return 0;
}
举报

相关推荐

0 条评论