[Algorithm]DP解决钢条切割问题

DP解决钢条切割问题


(原题见算法导论·动态规划)

对长度为n的钢条进行切割,对应的切割长度和价格对应如下:

int cost[] = {0, 1, 5, 8, 9, 10, 17, 17, 20, 24, 30};

比如1对应价值1,10对应价值30。即相应的下标和值的对应。现求切割所得最大效益mx。

递归算法

1
2
3
4
5
6
7
8
9
int cut_rod(int *cost,int n)
{
if(n == 0) return 0;
int limit = MIN(n,10); //分割第一条的上限
int mx = -1;
for(int i = 1;i <= limit; ++i)
mx = maxnum(mx,cost[i]+cut_rod(cost,n-i)); //取当前值于递归值的最大值
return mx;
}

由于对相同子问题的重复求解,T(n) = 2^n

递归标记数组算法(自顶而下)(DFS)

1
2
3
4
5
6
7
8
9
10
11
12
int mem_cut_rod(int *cost,int n,int *mem)   //mem数组长度为n,所有元素须在其他函数中初始化为-1
{
int mx;
if (mem[n] >= 0) return mem[n]; //对于求过的问题,直接返回存储的值
if (n == 0) mx = 0;
else mx = -1;
int limit = MIN(n,10);
for(int i = 1;i <= limit; ++i)
mx = maxnum(mx,cost[i]+mem_cut_rod(cost,n-i,mem)); //后面的内容和递归型是一样的
mem[n] = mx; //储存计算出的新值
return mx;
}

逆拓扑序DP(自底向上)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

int bottom_cut_rod(int *cost,int n)
{
int mem[MEM_LEN+1]; //MEM_LEN = n,设置标记数组
mem[0] = 0; //i,j将从1开始,这里收益是0
for(int i = 1; i <= n; ++i) //从第一个问题开始求解
{
int mx = -1;
int limit = MIN(i,10);
for(int j = 1;j <= limit; ++j)
mx = maxnum(mx,cost[j] + mem[i-j]); //求解最小的问题
mem[i] = mx;
}
return mem[n];
}

我们可以看到,2,3 的解法复杂度均为O(n^2)。

带解决方案的DFS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
typedef struct {
string path; //方案路径
bool memoried;
int value;
} MEMORY;

MEMORY *mem_pool;
string num_to_str(int num) {
char buf[120];
sprintf(buf, "%d", num);
return string(buf);
}

MEMORY DFS(int remain) {
int select, limit = MIN(remain, COST_LEN), mx = -1, cur_cost;
string cur_path, mx_path;
if (mem_pool[remain - 1].memoried) {
return mem_pool[remain - 1];
}
for (select = 1; select <= limit; ++select) {
if (select == remain) {
cur_cost = cost[remain];
cur_path = num_to_str(remain);
} else {
MEMORY upper = DFS(remain - select);
cur_cost = cost[select] + upper.value;
cur_path = num_to_str(select) + ", " + upper.path;
}
if (cur_cost > mx) {
mx = cur_cost;
mx_path = cur_path;
}
}
mem_pool[remain - 1].memoried = true;
mem_pool[remain - 1].value = mx;
mem_pool[remain - 1].path = mx_path;
return mem_pool[remain - 1];
}

int main() {
int n, i;
cin >> n;
mem_pool = new MEMORY[n];
if (!mem_pool) {
return 1;
}
for (i = 0; i < n; ++i) {
mem_pool[i].memoried = false;
}
MEMORY result = DFS(n);
cout << result.value << endl;
cout << result.path << endl;
delete[] mem_pool;
return 0;
}

其他代码

最后我们附上一份c实现的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
//2015.6.2
//copyright XJSoft

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>

typedef struct {
bool memoried;
int value;
} MEMORY;

int cost[] = {0, 1, 5, 8, 9, 10, 17, 17, 20, 24, 30};
MEMORY *mem_pool;

#define COST_LEN 10
#define MIN(a,b) ((a)<(b)?(a):(b))
int maxnum(const int v1,const int v2)
{
if (v1 > v2) return v1;
else return v2;
}

int DFS(int remain) {
int select, limit = MIN(remain, COST_LEN), mx = -1, cur_cost;
if (mem_pool[remain - 1].memoried) {
return mem_pool[remain - 1].value;
}
for (select = 1; select <= limit; ++select) {
if (select == remain) {
cur_cost = cost[remain];
} else {
cur_cost = cost[select] + DFS(remain - select);
}
if (cur_cost > mx) {
mx = cur_cost;
}
}
mem_pool[remain - 1].memoried = true;
mem_pool[remain - 1].value = mx;
return mx;
}

int DP(int n) {
int remain, select, limit, mx, cur_cost;
for (remain = 1; remain <= n; ++remain) {
mx = -1;
limit = MIN(remain, COST_LEN);
for (select = 1; select <= limit; ++select) {
if (select == remain) {
cur_cost = cost[select];
} else {
cur_cost = cost[select] + mem_pool[remain - select - 1].value;
}
if (cur_cost > mx) {
mx = cur_cost;
}
}
mem_pool[remain - 1].value = mx;
}
return mem_pool[n - 1].value;
}

int main() {
int n, i;
scanf("%d", &n);
mem_pool = (MEMORY*)malloc(n * sizeof(MEMORY));
if (!mem_pool) {
printf("Mem error!\n");
return 1;
}
for (i = 0; i < n; ++i) {
mem_pool[i].memoried = false;
}
printf("%d\n", DP(n));
free(mem_pool);
return 0;
}