首页 > 代码库 > Codeforces Round #129 (Div. 1)E. Little Elephant and Strings

Codeforces Round #129 (Div. 1)E. Little Elephant and Strings

Codeforces Round #129 (Div. 1)E. Little Elephant and Strings

题意:给出n个字符串,问每个字符串有多少个子串(不同位置,相同的子串视作不同)至少出现在这n个字符串中的k个当中。

解法:这题学到了一个SAM的新技能,对于这多个串,建SAM的时候,不是把它们连在一起,建立SAM,而是先给它们建立Trie树,然后广搜这棵Trie树,对于Trie树上的V节点,在建SAM的时候,它应该接在Trie树上他的父亲节点后面,我们用TtoM[U]表示Trie树上的U节点映射到SAM上的标号。这样建立SAM的优点是,我找任何一个字符串的任何一个前缀,它匹配的的SAM上的节点的代表串必然是这个前缀。我们先记住这个东西,怎么用等会儿看。我们要求的是每个字符串的所有子串至少出现在K个字符串中,那么我们先看看所有的子串中,有哪些子串是出现在了k个字符串中,表达在SAM上就是有哪些节点被K个字符串匹配到过。我们用cnt[u]表示u节点被几个字符串匹配过。我们每次拿出一个字符串,它能给一些节点的cnt[]值贡献1,这些节点,就是这个字符串的每个前缀在sam中的节点到根的链的并集,这个用LCA求就好了。统计完cnt[]之后,看每个节点的cnt值是否大于等于k,是的话,这个节点u上就有val[u]-val[fa[u]]个子串是被k个以上字符串包含的,用add[u]表示这个值。最后,算每个字符串的答案的时候,就是这个字符串的每个前缀映射到SAM上的节点到根的链上的add之和。

代码:

#include<stdio.h>
#include<algorithm>
#include<string.h>
#include<vector>
#include<queue>
#define ll __int64
using namespace std ;

const int maxn = 111111 ;
const int N = maxn << 1;
struct Edge {
    int to , next ;
} edge[N] ;
int head[N] , tot , f[N<<1] ;
void new_edge ( int a , int b ) {
    edge[tot].to = b ;
    edge[tot].next = head[a] ;
    head[a] = tot ++ ;
}

char s[maxn] , s1[maxn] ; int l[maxn] , len ;
int TtoM[maxn<<1] ;
struct LCA {
    int dp[22][N<<1] ;
    int to[N] , tim[N] ;
    int tot , n ;
    int MIN ( int a , int b ) {
        return tim[a] < tim[b] ? a : b ;
    }
    void init () {
        tot = 0 ;
        n = 0 ;
    }
    void dfs ( int u , int fa ) {
        tim[u] = ++ tot ;
        for ( int i = head[u] ; i != -1 ; i = edge[i].next ) {
            int v = edge[i].to ;
            if ( v == fa ) continue ;
            dfs ( v , u ) ;
            dp[0][++n] = u ;
        }
        dp[0][++n] = u ;
        to[u] = n ;
    }
    void rmq () {
        for ( int i = 1 ; i <= 20 ; i ++ ) {
            for ( int j = 1 ; j + (1<<i) - 1 <= n ; j ++ ) {
                dp[i][j] = MIN ( dp[i-1][j] , dp[i-1][j+(1<<i-1)] ) ;
            }
        }
    }
    int query ( int a , int b ) {
        a = to[a] , b = to[b] ;
        if ( a > b ) swap ( a , b ) ;
        int k = b - a + 1 ;
        return MIN ( dp[f[k]][a] , dp[f[k]][b-(1<<f[k])+1] ) ;
    }
} lca ;
namespace SAM  {
    int fa[N] , val[N] , c[26][N] ;
    int cnt[N] ; int tot , last ;
    int ws[N] , wv[N] ;
    ll add[N] ;
    vector<int> vec[N] ;
    void init () ;
    void solve ( int , int ) ;
    inline int new_node ( int _val ) {
        val[++tot] = _val ;
        for ( int i = 0 ; i < 26 ; i ++ ) c[i][tot] = 0 ;
        cnt[tot] = fa[tot] = add[tot] = 0 ;
        vec[tot].clear () ;
        return tot ;
    }
    int ADD ( int k , int p ) {
        int i ;
        int np = new_node ( val[p] + 1 ) ;
        while ( p && !c[k][p] ) c[k][p] = np , p = fa[p] ;
        if ( !p ) fa[np] = 1 ;
        else {
            int q = c[k][p] ;
            if ( val[q] == val[p] + 1 ) fa[np] = q ;
            else {
                int nq = new_node ( val[p] + 1 ) ;
                for ( i = 0 ; i < 26 ; i ++ )
                    c[i][nq] = c[i][q] ;
                fa[nq] = fa[q] ;
                fa[q] = fa[np] = nq ;
                while ( p && c[k][p] == q ) c[k][p] = nq , p = fa[p] ;
            }
        }
        return np ;
    }
    void SORT () {
        for ( int i = 0 ; i < maxn ; i ++ ) wv[i] = 0 ;
        for ( int i = 1 ; i <= tot ; i ++ ) wv[val[i]] ++ ;
        for ( int i = 1 ; i < maxn ; i ++ ) wv[i] += wv[i-1] ;
        for ( int i = 1 ; i <= tot ; i ++ ) ws[wv[val[i]]--] = i ;
    }
}
namespace Trie {
    int c[26][maxn] , tot ;
    int new_node () {
        for ( int i = 0 ; i < 26 ; i ++ )
            c[i][tot] = 0 ;
        return tot ++ ;
    }
    void init () {
        tot = 0 ;
        new_node () ;
    }
    void insert ( int n ) {
        for ( int i = 1 ; i <= n ; i ++ ) {
            int now = 0 ;
            for ( int j = l[i] ; j < l[i+1] ; j ++ ) {
                int k = s[j] - 'a' ;
                if ( !c[k][now] ) c[k][now] = new_node () ;
                now = c[k][now] ;
            }
        }
    }
}
queue<int> Q ;
void SAM::init () {
    tot = 0 ;
    TtoM[0] = new_node ( 0 ) ;
    Q.push ( 0 ) ;
#define v Trie::c[k][u]
    while ( !Q.empty () ) {
        int u = Q.front () ; Q.pop () ;
        for ( int k = 0 ; k < 26 ; k ++ )
            if ( v ){
                TtoM[v]=ADD(k,TtoM[u]) ;
                Q.push ( v ) ;
            }
    }
}

int cmp ( int a , int b ) {
    return lca.tim[a] < lca.tim[b] ;
}
int sta[maxn] ;
void SAM::solve ( int n , int k ) {
    SORT () ;
    for ( int i = 2 ; i <= tot ; i ++ ) {
        new_edge ( fa[i] , i ) ;
    }
    lca.dfs ( 1 , 0 ) ; lca.rmq () ;
    for ( int i = 1 ; i <= n ; i ++ ) {
        int u = 0 ;
        int top = 0 ;
        for ( int j = l[i] ; j < l[i+1] ; j ++ ) {
            int k = s[j] - 'a' ;
            u = v ;
            sta[++top] = TtoM[u];
            cnt[TtoM[u]] ++ ;
        }
        sort ( sta + 1 , sta + top + 1 , cmp ) ;
        for ( int j = 2 ; j <= top ; j ++ ) {
            int w = lca.query ( sta[j-1] , sta[j] ) ;
            cnt[w] -- ;
        }
    }
    for ( int i = tot ; i >= 1 ; i -- ) {
        int p = ws[i] ;
        cnt[fa[p]] += cnt[p] ;
        if ( cnt[p] >= k ) add[p] = val[p] - val[fa[p]] ;
    }
    for ( int i = 1 ; i <= tot ; i ++ ) {
        int u = ws[i] ;
        for ( int j = head[u] ; j != -1 ; j = edge[j].next ) {
            int to = edge[j].to ;
            add[to] += add[u] ;
        }
    }
    for ( int i = 1 ; i <= n ; i ++ ) {
        int u = 0 ; ll ans = 0 ;
        for ( int j = l[i] ; j < l[i+1] ; j ++ ) {
            int k = s[j] - 'a' ;
            u = v ;
            ans += add[TtoM[u]] ;
        }
        printf ( "%I64d " , ans ) ;
    }
    puts ( "" ) ;
}
#undef v

void init () {
    tot = 0 ;
    memset ( head , -1 , sizeof ( head ) ) ;
    lca.init () ;
    Trie::init () ;
}

int main () {
    f[0] = -1 ;
    for ( int i = 1 ; i < maxn << 2 ; i ++ )
        f[i] = f[i>>1] + 1 ;
    int n , k ;
    while ( scanf ( "%d%d" , &n , &k ) != EOF ) {
        init () ;
        len = 0 ;
        for ( int i = 1 ; i <= n ; i ++ ) {
            scanf ( "%s" , s1 ) ;
            int k = strlen ( s1 ) ;
            l[i] = len ;
            for ( int j = 0 ; j < k ; j ++ )
                s[len++] = s1[j] ;
        }
        l[n+1] = len ;
        Trie::insert (n) ;
        SAM::init () ;
        SAM::solve ( n , k ) ;
    }
    return 0 ;
}
/*
3 2
abc
bc
ab
3 2
abc
ac
ab
2 2
abc
bc
1 1
bc
2 2
ab
b
4 4
abab
baba
aaabbbababa
abababababa
2 2
abab
baba
2 2
aba
bab
2 2
ab
ba
*/