POJ 2817 WordStack 解题报告
题目链接:http://acm.pku.edu.cn/JudgeOnline/problem?id=2817
题目主要用到的算法:状态空间dp
题目描述:
这个题的意思是第一行给出case数N (1 <= N <= 10),然后给出N个单词,每个一行,当输入不是正整数的时候结束。每个单词最多10个字母。
Sample Input
5
abc
bcd
cde
aaa
bfcde
0
要求的是,按任意顺序排列这些单词,可以在单词前面加任意个空格,使得相邻的单词上下对应的字母数目最多,并输出最多是多少。
Sample Output
8
比如sample里面的8,是这样得来的:
aaa
abc
bcd
cde
bfcde
注意只有相邻单词的字母上下对应才算对应。
我AC的代码:

AC的代码
1
2
#include <cstdio>
3
#include <cmath>
4
#include <cstring>
5
6
char word[11][15];
7
int mm[11][11];
8
int len[11];
9
int n;
10
int dp[2000][11];
11
12
void Cal( int a, int b ) //word[a]和word[b]有多少公共字母,存在mm[a][b]里,test OK
13

{
14
int max = 0;
15
for ( int i = 0; i < len[a]; ++i )
16
{
17
int result = 0;
18
for ( int j = 0; j < len[b]; ++j )
19
{
20
if ( i+j<len[a] && word[a][i+j] == word[b][j] ) ++result;
21
}
22
max >?= result;
23
}
24
for ( int i = 1; i < len[b]; ++i )
25
{
26
int result = 0;
27
for ( int j = 0; j < len[a]; ++j )
28
{
29
if ( i+j<len[b] && word[b][i+j] == word[a][j] ) ++result;
30
}
31
max >?= result;
32
}
33
mm[a][b] = mm[b][a] = max;
34
}
35
36
int Fun( int state, int last )
37

{
38
int max = 0;
39
int state_t = state;
40
state_t &= (~(1<<(last-1)));
41
if ( state_t == 0 ) return 0;
42
if ( dp[state][last] ) return dp[state][last];
43
44
for ( int i = 0 ; i < n; ++i )
45
{
46
int result = 0;
47
int tmp = state_t&(~(1<<i));
48
if ( tmp != state_t ) result = Fun( state_t, i+1 )+mm[i+1][last];
49
max >?= result;
50
}
51
52
dp[state][last] = max;
53
return max;
54
}
55
56
int main()
57

{
58
while ( scanf( "%d", &n ), n>0 )
59
{
60
int max = 0;
61
int power = (int)pow( 2, n );
62
63
for ( int i = 0; i < power; ++i )
64
{
65
for ( int j = 0; j <= n; ++j )
66
{
67
dp[i][j] = 0;
68
}
69
}
70
71
for ( int i = 1; i <= n; ++i )
72
{
73
scanf( "%s", word[i] );
74
len[i] = strlen( word[i] );
75
}
76
77
for ( int i = 1; i <= n; ++i )
78
{
79
for ( int j = i+1; j <= n; ++j )
80
{
81
Cal( i, j ); //Cal函数算word两两之间的最大公共字母数存到mm里
82
}
83
}
84
85
for ( int i = 1; i <= n; ++i )
86
{
87
int tt = Fun( power-1, i );//Fun函数用来求dp[][]的值,也就是全部单词都放满,最后一行是i的情况下总的最大公共字母数
88
max >?= tt;
89
}
90
printf( "%d\n", max );
91
}
92
return 0;
93
}
思路:分析这个题目,我们应该很容易就看出,结果只与单词的排列顺序有关,至于在每个单词前面加几个空格对结果是不影响的。所以我们可以首先统计这n个单词两两之间最多能有多少个对应的字母,放到数组mm[11][11]中(因为题目说n最大是10)。数组名取得不太好哈:-P.比如sample里面的五个单词计算出的mm数组应该是这样的
0 0 0 0 0 0
0 0 2 1 1 1
0 2 0 2 0 2
0 1 2 0 0 3
0 1 0 0 0 0
0 1 2 3 0 0
mm[i][j] 是 word[i]和word[j]前面加任意个空格能对应的最多字母数。我的代码里是用一个Cal函数实现的这个功能。其中mm数组是全局变量。
1
void Cal( int a, int b ) //word[a]和word[b]有多少公共字母,存在mm[a][b]里,test OK
2

{
3
int max = 0;
4
//以word[a]作为基准,依次在word[b]前面添加1~len[a]-1这么多空格,看最多能有多少个字母匹配
5
for ( int i = 0; i < len[a]; ++i )
6
{
7
int result = 0;
8
for ( int j = 0; j < len[b]; ++j )
9
{
10
if ( i+j<len[a] && word[a][i+j] == word[b][j] ) ++result;
11
}
12
max >?= result;//">?="这个操作符的意思是二者中取最大的赋给max,可能在vc6.0里不支持
13
}
14
//和上面的代码类似,word[b]不动,在word[a]前面添加空格。
15
for ( int i = 1; i < len[b]; ++i )
16
{
17
int result = 0;
18
for ( int j = 0; j < len[a]; ++j )
19
{
20
if ( i+j<len[b] && word[b][i+j] == word[a][j] ) ++result;
21
}
22
max >?= result;
23
}
24
mm[a][b] = mm[b][a] = max;//这个矩阵是对称的
25
}
对这个题目,不知大家是不是想到了全排列的方法呢?就是把这n个单词全部的可能放置的顺序都排列出来,把上下相邻的公共字母数相加,哪种排列得到的结果最大,就是最终结果。
首先肯定这个做法是绝对正确的,而且也可以用这个做法AC这个题目(见附代码)。但是,这个做法效率比较低,用了600ms多才跑完数据。而且这种做法没有学到我们要学习的内容,状态空间dp。
接下来给出的是最关键的部分,状态空间dp部分的代码,只有短短的14行,用心理解便不困难

Fun函数
1
int dp[2000][11];//其实2^10=1024就够用了,不需要2000这么多,呵呵
2
int Fun( int state, int last )
3

{
4
int max = 0;
5
int state_t = state;
6
state_t &= (~(1<<(last-1)));
7
if ( state_t == 0 ) return 0;
8
if ( dp[state][last] ) return dp[state][last];
9
for ( int i = 0 ; i < n; ++i )
10
{
11
if ( state_t&(~(1<<i)) != state_t ) max >?= Fun( state_t, i+1 )+mm[i+1][last];
12
}
13
dp[state][last] = max;
14
return max;
15
}
我先来解释思路再解释这段代码。
其中主函数对Fun函数的调用部分代码如下:
1 int max = 0;
2 int power = (int)pow( 2, n );
3
4 for ( int i = 1; i <= n; ++i )
5 {
6 int tt = Fun( power-1, i );//Fun函数用来求dp[][]的值,也就是全部单词都放满,最后一行是i的情况下总的最大公共字母数
7 max >?= tt;
8 }
9 printf( "%d\n", max );
int Fun( int state, int last )函数的作用如注释所说,Fun函数的第一个参数state表示单词填入的状态,在这里,power-1表示的是所有的单词都已经放到排列里面的状态(我在后面会解释为什么)。第二个参数last表示最后添入的单词是哪个。这段代码的意思是:全部单词都填入,且最后填入的单词是word[i]的时候,总的最多的公共字母数是多少,当i取遍1~n的时候,就包括了全部的情况。取最大,就是我们要的结果。比如sample中,power-1 = 31(31的二进制码是11111)假设i=3,Fun(31,3)就表示下面的情况
word[x]
word[y]
word[z]
word[a]
word[3]
也就是说,x y z a是1 2 4 5的任意排列,或者说全部排列,然后最后一行必须是word[i],也就是word[3],取总公共字母数最多的值作为Fun函数的返回值。
现在我们回头看Fun函数是如何实现的。
1
2 int Fun( int state, int last ) //第一个参数state表示状态,也就是各个单词的取用情况,是用二进制的位运算解决的。因为每个单词只有两种状态,被用了和没被用,那我们就用1表示取了这个单词,0表示没取。每一位表示一个单词的取用情况,第i位表示第i-1个单词是否被取。比如10011表示word[5]word[2]word[1]被取,而word[4]word[3]未被取。
3 {
4 int max = 0;
5 int state_t = (state & (~(1<<(last-1))); //state_t是state去掉last的状态。
6 if ( state_t == 0 ) return 0; //如果state去掉last以后是0,也就是说,前面一个单词都没取,那last就是第一个单词,没有和它相邻的单词,返回0
7 if ( dp[state][last] ) return dp[state][last]; //记忆化搜索
8
9 for ( int i = 0 ; i < n; ++i ) //让n个单词都和last相邻一次,取最大值
10 {
11 if ( state_t&(~(1<<i)) != state_t ) max >?= Fun( state_t, i+1 )+mm[i+1][last];//如果state_t的i这位是1,也就是说,state_t包括word[i+1],那么递归调用Fun函数得到state_t包括的所有单词都作为末一行的最大公共字母数,此时word[i+1]作为word[last]的相邻行,所以再加上mm[i+1][last],取最大值存起来就可以了
12 }
13
14 dp[state][last] = max;
15 return max;
16 }
我觉得状态dp首先是个搜索,然后加上dp提高效率。顾名思义这个dp保存的是这种状态的最优结果。这是我的第一道状态dp题,这类题也仅仅做过两道而已,有什么不对和可以优化的地方请大侠们批评指正。
附:全排列的慢速代码
1
//枚举水过去的代码
2
#include <iostream>
3
#include <vector>
4
#include <algorithm>
5
using std::vector;
6
using std::string;
7
8
int dp[10][10] =
{0};
9
10
void F( string a, string b, int aa, int bb )
11

{
12
int max = 0;
13
string tmp;
14
for ( int i = 0; i != a.size(); ++i )
15
{
16
int result = 0;
17
for ( int j = 0; j != b.size() ; ++j )
18
{
19
if ( i+j < a.size() && b[j] == a[i+j] ) ++result;
20
}
21
if ( result > max ) max = result;
22
}
23
tmp = a;
24
a = b;
25
b = tmp;
26
for ( int i = 0; i != a.size(); ++i )
27
{
28
int result = 0;
29
for ( int j = 0; j != b.size(); ++j )
30
{
31
if ( i+j < a.size() && b[j] == a[i+j] ) ++result;
32
}
33
if ( result > max ) max = result;
34
}
35
dp[aa][bb] = dp[bb][aa] = max;
36
}
37
38
vector<string> word;
39
vector<int> id;
40
41
int main()
42

{
43
int n;
44
while ( scanf( "%d", &n )!=EOF && n > 0 )
45
{
46
int max = 0;
47
int result = 0;
48
for ( int i = 0; i < n; ++i )
49
{
50
string tmp;
51
std::cin>>tmp;
52
word.push_back(tmp);
53
id.push_back( i );
54
}
55
sort( word.begin(), word.end() );
56
57
for ( int i = 0; i < n; ++i )
58
{
59
for ( int j = i+1; j < n; ++j )
60
{
61
F( word[i], word[j], i, j );
62
}
63
}
64
65
for ( int i = 1; i < n; ++i )
66
{
67
result += dp[i][i-1];
68
}
69
if ( result > max ) max = result;
70
71
while ( next_permutation( id.begin(), id.end() ) )
72
{
73
result = 0;
74
for ( int i = 1; i < n; ++i )
75
{
76
result += dp[id[i]][id[i-1]];
77
}
78
if ( result > max ) max = result;
79
}
80
printf( "%d\n",max );
81
word.clear();
82
id.clear();
83
}
84
85
return 0;
86
}