首页 > 代码库 > Codeforces Round #244 (Div. 2)D (后缀自动机)

Codeforces Round #244 (Div. 2)D (后缀自动机)

Codeforces Round #244 (Div. 2)D (后缀自动机)

(标号为0的节点一定是null节点,无论如何都不能拿来用,切记切记,以后不能再错了)

这题用后缀自动机的话,对后缀自动机的很多性质有足够深刻的理解。没想过后缀数组怎么做,因为不高兴敲。。。。

题意:给出两个长度均不超过5000的字符串s1,s2,求这两个串中,都只出现一次的最短公共子串。

解题思路:求的是公共子串,然后对出现的次数又有限制,第一想法就是后缀自动机啊,后缀自动机处理子串出现次数再合适不过了。做法是这样的,先建立s1的sam,用拓扑dp,求出每个节点的代表串出现的次数。目的是什么呢?其实我是想求ok[i][j],表示s1[i] ~ s1[j]的这个子串是否只出现了一次。现在我们求出了代表串的出现次数了,怎么求这个ok[i][j]呢?拿s1在建立好的自动机上匹配,当前匹配到了s1[i],记录temp表示当前匹配的最长长度,now表示当前匹配在哪个节点。这里有一个跟AC自动机很相似的性质,匹配到了now,则一定能匹配fa[now]。那么就顺着now往上走,一直找到第一个出现次数大于1的节点p,那么以i为结尾,长度为val[p]+1到temp的子串在s1里面肯定都只出现一次了。把这个记录到ok数组里。    第二步是对s2处理了,还是一样的过程,建立sam,求出每个点的代表串出现的次数,即cnt[]数组。   第三步就要拿s1在s2的sam上进行匹配了,匹配过程类似于前面处理s1的ok数组,找出当前匹配的最长长度temp,匹配到的节点now,顺着now往上,找到第一个cnt大于1的节点p,在s2里面,以当前匹配上的子串的结尾为结尾的长度为val[p] + 1到temp的子,串必然只在s2里出现过一次。然后就枚举j,从val[p] + 1到temp,如果在s1里面,以i为结尾,长度为j的子串只出现1次(即ok[i-j+1][i] == 1),那么这个j就有可能成为答案,用其更新ans即可。

代码:

#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std ;

const int maxn = 5001 ;
bool ok[maxn][maxn] ;
int ans = 111111 ;

struct SAM  {
    int fa[maxn<<1] , val[maxn<<1] , c[26][maxn<<1] ;
    int cnt[maxn<<1] ; int tot , last ;
    int ws[maxn<<1] , wv[maxn<<1] ;

    inline int new_node ( int _val ) {
        val[++tot] = _val ;
        for ( int i = 0 ; i < 26 ; i ++ ) c[i][tot] = 0 ;
        cnt[tot] = fa[tot] = 0 ;
        return tot ;
    }

    void add ( int k ) {
        int p = last , i ;
        int np = new_node ( val[p] + 1 ) ;
        while ( p && !c[k][p] ) c[k][p] = np , p = fa[p] ;
        if ( !p ) fa[np] = 1 ;
        else {
            int q = c[k][p] ;
            if ( val[q] == val[p] + 1 ) fa[np] = q ;
            else {
                int nq = new_node ( val[p] + 1 ) ;
                for ( i = 0 ; i < 26 ; i ++ )
                    c[i][nq] = c[i][q] ;
                fa[nq] = fa[q] ;
                fa[q] = fa[np] = nq ;
                while ( p && c[k][p] == q ) c[k][p] = nq , p = fa[p] ;
            }
        }
        last = np ;
    }

    void init () {
        tot = 0 ;
        last = new_node ( 0 ) ;
    }

    void SORT () {
        for ( int i = 0 ; i < maxn ; i ++ ) wv[i] = 0 ;
        for ( int i = 1 ; i <= tot ; i ++ ) wv[val[i]] ++ ;
        for ( int i = 1 ; i < maxn ; i ++ ) wv[i] += wv[i-1] ;
        for ( int i = 1 ; i <= tot ; i ++ ) ws[wv[val[i]]--] = i ;
    }

    void get_cnt ( char *s , int n ) {
        SORT () ;
        int now = 1 , i ;
        memset ( cnt , 0 , sizeof ( cnt ) ) ;
        for ( i = 1 ; i <= n ; i ++ ) {
            int k = s[i] - ‘a‘ ;
            now = c[k][now] ;
            cnt[now] ++ ;
        }
        for ( i = tot ; i >= 1 ; i -- ) {
            now = ws[i] ;
            cnt[fa[now]] += cnt[now] ;
        }
    }

    void gao ( char *s , int n ) {
        get_cnt ( s , n ) ;
        int now = 1 , i , j ;
        for ( i = 1 ; i <= n ; i ++ ) {
            int k = s[i] - ‘a‘ ;
            now = c[k][now] ;
            int p = now ;
            while ( fa[p] && cnt[p] == 1 ) p = fa[p] ;
            for ( j = 1 ; j <= i - val[p] ; j ++ )
                ok[j][i] = 1 ;
        }
    }

    void work ( char *s , int n ) {
        int temp = 0 , now = 1 , i , j ;
        for ( i = 1 ; i <= n ; i ++ ) {
            int k = s[i] - ‘a‘ ;
            if ( c[k][now] ) {
                temp ++ ; now = c[k][now] ;
                int p = now ;
                while ( fa[p] && cnt[p] == 1 ) p = fa[p] ;
                for ( j = val[p] + 1 ; j <= temp ; j ++ )
                    if ( ok[i-j+1][i] ) {
                        ans = min ( ans , j ) ;
                        break ;
                    }
            }
            else {
                while ( now && !c[k][now] ) now = fa[now] ;
                if ( !now ) now = 1 , temp = 0 ;
                else {
                    temp = val[now] + 1 ;
                    now = c[k][now] ;
                    int p = now ;
                    while ( fa[p] && cnt[p] == 1 ) p = fa[p] ;
                    for ( j = val[p] + 1 ; j <= temp ; j ++ )
                        if ( ok[i-j+1][i] ) {
                            ans = min ( ans , j ) ;
                            break ;
                        }
                }
            }
        }
    }

} ac ;
char s1[maxn] , s2[maxn] ;

int main () {
    scanf ( "%s" , s1 + 1 ) ;
    ac.init () ;
    int n = strlen ( s1 + 1 ) , i , j ;
    for ( i = 1 ; i <= n ; i ++ )
        ac.add ( s1[i] - ‘a‘ ) ;
    ac.gao ( s1 , n ) ;
    scanf ( "%s" , s2 + 1 ) ;
    ac.init () ;
    int m= strlen ( s2 + 1 ) ;
    for ( i = 1 ; i <= m ; i ++ )
        ac.add ( s2[i] - ‘a‘ ) ;
    ac.get_cnt ( s2 , m ) ;
    ac.work ( s1 , n ) ;
    if ( ans == 111111 ) puts ( "-1" ) ;
    else printf ( "%d\n" , ans ) ;
    return 0 ;
}