题意:
给定多个字符串,是按照字典顺序排列的。一个字符串如果能够可以通过改变一个字母,删除一个字母,增加一个字母变成后面的某一个字符串,
那么称这两个字符串之间存在一个阶梯,问最多有多少个阶梯。
思路:
http://www.cnblogs.com/staginner/archive/2011/11/30/2269222.html
这一题的难点在于题目数据量很大,O(n^2)的算法铁定会超时。另类的解题思路是:构造一个hash表,把给定的字符串都存进去。
然后试图对每个字符串进行删除,插入,修改等操作,查询经过变换后的字符串是否在hash表中,并且这个变换后的字符串要在源字符串的后面。
然后记忆化搜索得出最长的阶梯。
#include <cstdio> #include <cstdlib> #include <cstring> const int MAXN = 25010; const int HASH = 1000010; int n, head[HASH], next[MAXN], f[MAXN]; char b[MAXN][20], temp[20]; int hash(const char *s) { int v = 0, seed = 131; while (*s) v = v * seed + *(s++); return (v & 0x7fffffff) % HASH; } void insert(int s) { int h = hash(b[s]); next[s] = head[h]; head[h] = s; } int search(const char *s) { int i, h = hash(s); for (i = head[h]; i != -1; i = next[i]) if (!strcmp(b[i], s)) break; return i; } void add(const char *s, int p, int d) { int i = 0, j = 0; while (i < p) temp[j++] = s[i++]; temp[j++] = 'a' + d; while (s[i]) temp[j++] = s[i++]; temp[j] = '\0'; } void del(const char *s, int p) { int i = 0, j = 0; while (i < p) temp[j++] = s[i++]; ++i; while (s[i]) temp[j++] = s[i++]; temp[j] = '\0'; } void change(const char *s, int p, int d) { strcpy(temp, s); temp[p] = 'a' + d; } int dp(int s) { if (f[s] != -1) return f[s]; int ans = 0; int len = strlen(b[s]); for (int p = 0; p <= len; ++p) { for (int d = 0; d < 26; ++d) { add(b[s], p, d); int v = search(temp); if (v != -1 && strcmp(b[s], temp) < 0) { int t = dp(v); if (ans < t + 1) ans = t + 1; } } } for (int p = 0; p < len; ++p) { del(b[s], p); int v = search(temp); if (v != -1 && strcmp(b[s], temp) < 0) { int t = dp(v); if (ans < t + 1) ans = t + 1; } } for (int p = 0; p < len; ++p) { for (int d = 0; d < 26; ++d) { change(b[s], p, d); int v = search(temp); if (v != -1 && strcmp(b[s], temp) < 0) { int t = dp(v); if (ans < t + 1) ans = t + 1; } } } return f[s] = ans; } void solve() { memset(f, -1, sizeof(f)); int ans = 0; for (int i = 0; i < n; ++i) { int t = dp(i); if (ans < t) ans = t; } printf("%d\n", ans + 1); } void init() { n = 0; memset(head, -1, sizeof(head)); while (scanf("%s", b[n]) == 1) { insert(n), ++n; } } int main() { init(); solve(); return 0; }