首页 > 代码库 > bzoj4503 两个串

bzoj4503 两个串

传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=4503

【题解】

设主串为 a、模式串为 b（字母映射为 1~26），定义以位置 i 开头的匹配函数 $f(i) = \sum_{j=0}^{|b|-1} (a_{i+j} - b_j)^2 \, b_j$。

展开 f：$f(i) = \sum_j a_{i+j}^2 b_j - 2\sum_j a_{i+j} b_j^2 + \sum_j b_j^3$。把 b 翻转后，前两项都是卷积，可以用 FFT 求出；最后一项与 i 无关，是常数。

对于 t[i]=='?'，令 b[i]=0。由于每一项都非负，f(i)=0 当且仅当位置 i 处能够匹配，因此直接 FFT 即可。

我记得有道类似的题：两个串都含通配符，只需把匹配函数改成 $f(i) = \sum_j (a_{i+j} - b_j)^2 \, a_{i+j} \, b_j$ 即可。

技术分享
# include <math.h>
# include <stdio.h>
# include <string.h>
# include <algorithm>
// # include <bits/stdc++.h>

using namespace std;

// Short integer/floating aliases used by the template.
using ll = long long;
using ld = long double;
using ull = unsigned long long;

// Capacity of the file-level arrays (input buffer, codes, FFT buffers).
const int M = 500010;
// Not referenced elsewhere in this file; kept from the contest template.
const int mod = 1000000007;
// pi, computed once at startup for the FFT twiddle factors.
const double pi = acos(-1.0);

// Template shorthand macros (not expanded anywhere in this file).
// NOTE(review): `register` is removed in C++17, so expanding RG would not compile there.
# define RG register
# define ST static

// Minimal complex number (x = real part, y = imaginary part) used by the FFT.
// The default constructor zero-initializes, so a default-constructed value is
// never indeterminate. All members are constexpr, so global arrays of cp are
// constant-initialized.
struct cp {
    double x, y;
    constexpr cp() : x(0), y(0) {}
    constexpr cp(double re, double im) : x(re), y(im) {}
    friend constexpr cp operator +(const cp& a, const cp& b) {
        return cp(a.x + b.x, a.y + b.y);
    }
    friend constexpr cp operator -(const cp& a, const cp& b) {
        return cp(a.x - b.x, a.y - b.y);
    }
    // (a.x + i a.y)(b.x + i b.y) = (a.x b.x - a.y b.y) + i (a.x b.y + a.y b.x)
    friend constexpr cp operator *(const cp& a, const cp& b) {
        return cp(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
    }
};

// Input buffer, reused for reading both strings.
char str[M];
// Integer codes of the main string (length n1) and of the pattern (length n2).
int S[M];
int T[M];
int n1;
int n2;

namespace FFT {
    const int M = 6e5 + 10;
    cp w[2][M]; 
    int n, lst[M];
    inline void init(int _n) {
        n = 1;
        while(n < _n) n <<= 1;
        for (int i=0; i<n; ++i) w[0][i] = cp(cos(pi*2/n*i), sin(pi*2/n*i)), w[1][i] = cp(w[0][i].x, -w[0][i].y); 
        int len = 1;
        while((1<<len) < n) len ++; 
        for (int i=0; i<n; ++i) {
            int t = 0;
            for (int j=0; j<len; ++j) if(i&(1<<j)) t |= (1<<(len-j-1));
            lst[i] = t;
        }
    }
    inline void DFT(cp *a, int op) {
        cp *o = w[op];
        for (int i=0; i<n; ++i) if(i < lst[i]) swap(a[i], a[lst[i]]);
        for (int len=2; len<=n; len<<=1) {
            int m = len>>1;
            for (cp *p = a; p != a+n; p+=len) {
                 for (int k=0; k<m; ++k) {
                      cp t = o[n/len*k] * p[k+m];
                     p[k+m] = p[k] - t;
                     p[k] = p[k] + t; 
                      
                }
            }
        }
        if(op == 1) {
            for (int i=0; i<n; ++i) a[i].x = a[i].x/(double)n, a[i].y = a[i].y/(double)n;
        }
    }
}

// FFT work buffers: a and b hold the two operands, and c accumulates the
// combined product spectrum before the inverse transform.
cp a[M];
cp b[M];
cp c[M];
// Sum of T[j]^3, the window-independent term of the match score.
double cnt = 0;
// Match positions, stored 1-based as ans[1..ansn].
int ans[M];
int ansn = 0;

// Reads the main string S and the pattern T ('?' is a wildcard), then prints the
// number of positions where T matches S, followed by each 0-based start position.
//
// For the window starting at i the score is
//   f(i) = sum_j (S[i+j] - T[j])^2 * T[j]
//        = sum_j S[i+j]^2 T[j] - 2 sum_j S[i+j] T[j]^2 + sum_j T[j]^3.
// Wildcards have code 0, so they contribute nothing. Every term is >= 0, so
// f(i) == 0 exactly when the window matches.
int main() {
    // Read S; letters map to 1..26. Stop if input is missing: with an empty
    // pattern the result loop below would read c[-1].
    if (scanf("%s", str) != 1) return 0;
    n1 = strlen(str);
    for (int i = 0; i < n1; ++i) S[i] = str[i] - 'a' + 1;

    // Read T; the wildcard '?' maps to 0 so it never contributes to the score.
    if (scanf("%s", str) != 1) return 0;
    n2 = strlen(str);
    for (int i = 0; i < n2; ++i) {
        if (str[i] == '?') T[i] = 0;
        else T[i] = str[i] - 'a' + 1;
    }

    int m = n1 + n2 - 1;  // length of the full linear convolution
    FFT::init(m);

    // Term 1: sum_j S[i+j]^2 * T[j], computed as a convolution with T reversed.
    for (int i = 0; i < FFT::n; ++i) a[i].x = a[i].y = b[i].x = b[i].y = 0;
    for (int i = 0; i < n1; ++i) a[i].x = S[i] * S[i];
    for (int i = 0; i < n2; ++i) b[n2 - 1 - i].x = T[i];
    FFT::DFT(a, 0); FFT::DFT(b, 0);
    for (int i = 0; i < FFT::n; ++i) c[i] = a[i] * b[i];

    // Term 2: -2 * sum_j S[i+j] * T[j]^2. Term 3 (sum_j T[j]^3) is a constant.
    cnt = 0;
    for (int i = 0; i < FFT::n; ++i) a[i].x = a[i].y = b[i].x = b[i].y = 0;
    for (int i = 0; i < n1; ++i) a[i].x = S[i] * 2;
    for (int i = 0; i < n2; ++i) {
        b[n2 - 1 - i].x = T[i] * T[i];
        cnt = cnt + T[i] * T[i] * T[i];
    }
    FFT::DFT(a, 0); FFT::DFT(b, 0);
    for (int i = 0; i < FFT::n; ++i) c[i] = c[i] - a[i] * b[i];
    FFT::DFT(c, 1);

    // Window i corresponds to convolution index n2-1+i. Compare in floating
    // point: on long inputs the score can exceed INT_MAX, and converting such a
    // double to int is undefined behavior.
    for (int i = 0; i <= n1 - n2; ++i)
        if (fabs(c[n2 - 1 + i].x + cnt) < 0.5) ans[++ansn] = i;

    printf("%d\n", ansn);
    for (int i = 1; i <= ansn; ++i) printf("%d\n", ans[i]);
    return 0;
}
View Code

 

bzoj4503 两个串