题目大意为计算 SSS 串的子串 sss 串和 TTT 中的前缀 ttt 串拼起来是回文串的种数的总贡献,要求 sss 串长度大于 ttt,s+ts+ts+t 是回文串,sss 和 ttt 串非空,拼起来的串视作不同种当且仅当 sss 或 ttt 串的位置不同。
我们把 sss 串分成 s1+s2s_1+s_2s1+s2 两个部分,其中 s1=reverse(t)s_1=reverse(t)s1=reverse(t),s2=palindromes_2=palindromes2=palindrome。显而易见的,我们可以把整个 SSS 串倒过来,计算它的每个后缀与 TTT 串的最长公共前缀,也就是扩展 KMPKMPKMP 做的事(我这里将 SSS 串和 TTT 串按 T+′∣′+ST+'|'+ST+′∣′+S 连接,这样求出 exexex 数组,就求出了 SSS 串的每个后缀在 TTT 串中的最长公共前缀,而 ′∣′'|'′∣′ 则用来避免他们计算错误)。
这样我们就得到了每个 s1s_1s1 串,此时的 sss 串变为 s2+s1s_2+s_1s2+s1,s1=ts_1=ts1=t,s2=palindromes_2=palindromes2=palindrome。我们只要再求出倒置后的 SSS 串中的回文串,然后差分贡献即可(每个回文串对它的右半径都有贡献,用差分才能在线性时间内求出)。
接下来就是求 SSS 串中的回文串,然后添加贡献,再遍历 SSS 串。累加差分值,当前字符为字母时计算对于当前的 SSS 串的后缀的贡献(为什么当前字符为字母才算,因为当前字符为回文串的右边界时,才计算贡献,而在实际的 SSS 串中,右边界只会是字母)。
toltoltol 为以当前字符为右边界的回文串总数,exexex 为 SSS 串的每个后缀与 TTT 串的最长公共前缀,显而易见的 exexex 的每个长度都能产生一个贡献,所以总贡献就是 ex[i+1]∗tolex[i+1]*tolex[i+1]∗tol(为什么是 i+1i+1i+1 ,因为是当前 iii 视作回文串的一部分,所以起始位置为 i+1i+1i+1 的后缀才用来当 s1s_1s1 串)。
挺好的一道题,ICPC的题目质量还是不错的
#include<bits/stdc++.h>
using namespace std;
#define endl '\n'
#define IOS ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
typedef long long LL;
const int maxn = 2e6 + 5;
char s[maxn], t[maxn];
int r[maxn], ex[maxn], sum[maxn];
int mid, rb, slen, tlen, len, cnt;
inline void init(){
tlen = strlen(t);
cnt = 1;
s[0] = '~';
s[1] = '|';
for(int i = tlen - 1; ~i; i--){
s[++cnt] = t[i];
s[++cnt] = '|';
}
for(int t = 1; t <= cnt; t++){
if(t < rb) r[t] = min(r[(mid<<1) - t], rb - t);
else r[t] = 1;
while(s[t - r[t]] == s[t + r[t]]) r[t]++;
if(r[t] + t > rb) rb = r[t] + t, mid = t;
sum[t]++, sum[t + r[t]]--;
}
}
inline void getex(){
int n = len;
for(int i = 1, l = 0, r = 0; i < n; i++){
if(i <= r) ex[i] = min(r - i + 1, ex[i - l]);
while(i + ex[i] < n && s[ex[i]] == s[i + ex[i]]) ex[i]++;
if(i + ex[i] - 1 > r) l = i, r = i + ex[i] - 1;
}
}
int main(){
IOS;
cin >> t;
init();
cin >> s;
slen = strlen(s);
s[slen++] = '|';
len = slen + tlen;
for(int i = 0; slen < len; i++, slen++) s[slen] = t[tlen - i - 1];
getex();
slen = len - tlen - 1;
LL tol = 0, ans = 0;
for(int i = 1; i <= cnt; i++){
tol += sum[i];
if(i % 2) continue;
ans += ex[slen + 1 + i / 2] * tol;
}
cout << ans << endl;
}