dp-and-backtracking

DP

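Rewrite a recursive algorithm as a non-recursive one that systematically records the answers to its subproblems in a table. A technique that works this way is called dynamic programming.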

int Fib(int N) {
    if (N <= 1) {
        return 1;
    }

    return Fib(N - 1) + Fib(N - 2);
}
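The recursive version above recomputes the same subproblems many times. As a minimal sketch of the table-based rewrite described earlier, here is a bottom-up version in Java. The class name `FibTable` and method name `fib` are made up for illustration, and the code keeps the convention used above that Fib(0) = Fib(1) = 1.

```java
class FibTable {
    // Bottom-up Fibonacci: each entry is computed once and stored in a table.
    // Convention (taken from the recursive version): fib(0) = fib(1) = 1.
    static long fib(int n) {
        if (n <= 1) {
            return 1;
        }
        long[] table = new long[n + 1];
        table[0] = 1;
        table[1] = 1;
        for (int i = 2; i <= n; i++) {
            table[i] = table[i - 1] + table[i - 2];
        }
        return table[n];
    }

    public static void main(String[] args) {
        System.out.println(fib(10)); // prints 89
    }
}
```

This takes O(n) time instead of exponential time; only the last two entries are actually needed, so the array could be replaced by two variables.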

0/1 knapsack problem: each item has exactly two possibilities, picked or not picked.

The most naive solution:

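// Globals assumed: n items, weights w[], values v[], knapsack capacity W.
// rec(i, j): best value obtainable from items i..n-1 with j capacity left.
int rec(int i, int j) {
    int res;
    if (i == n) {
        // no items left
        res = 0;
    } else if (j < w[i]) {
        // item i does not fit
        res = rec(i + 1, j);
    } else {
        // try both skipping and picking item i
        res = max(rec(i + 1, j), rec(i + 1, j - w[i]) + v[i]);
    }

    return res;
}

void solve() {
    // the answer is rec(0, W)
    rec(0, W);
}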

Let's try recording each result the first time it is computed:

#include <algorithm>  // max
#include <cstring>    // memset

// Globals assumed: n, w[], v[], capacity W, and the bounds MAX_N, MAX_W.
int dp[MAX_N + 1][MAX_W + 1];

int rec(int i, int j) {
    if (dp[i][j] >= 0) {
        // already computed: reuse the stored result
        return dp[i][j];
    }

    int res;
    if (i == n) {
        // no items left
        res = 0;
    } else if (j < w[i]) {
        // item i does not fit
        res = rec(i + 1, j);
    } else {
        // try both skipping and picking item i
        res = max(rec(i + 1, j), rec(i + 1, j - w[i]) + v[i]);
    }

    // store the result in the table
    return dp[i][j] = res;
}

void solve() {
    // -1 marks "not computed yet"
    memset(dp, -1, sizeof(dp));
    rec(0, W);
}

In fact, instead of writing a recursive function we can evaluate the recurrence directly; a simple double loop also solves the problem:

Code:

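// dp[i][j]: best value using the first i items with capacity j.
// dp[0][j] must be 0 (true for a zero-initialized global array);
// the answer is dp[n][W].
void solve() {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j <= W; j++) {
            // j kg of capacity is available; this item weighs w[i] kg
            if (j < w[i]) {
                dp[i + 1][j] = dp[i][j];
            } else {
                dp[i + 1][j] = max(dp[i][j], dp[i][j - w[i]] + v[i]);
            }
        }
    }
}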
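To make the double-loop recurrence concrete, here is a self-contained Java sketch of the same table fill. The class name `Knapsack01`, the method name `maxValue`, and the sample data are illustrative assumptions, not part of the original snippets; `capacity` plays the role of the knapsack limit.

```java
class Knapsack01 {
    // dp[i][j] = best total value using only the first i items with capacity j.
    static int maxValue(int[] w, int[] v, int capacity) {
        int n = w.length;
        int[][] dp = new int[n + 1][capacity + 1]; // row 0 is all zeros
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= capacity; j++) {
                if (j < w[i]) {
                    // item i does not fit into capacity j
                    dp[i + 1][j] = dp[i][j];
                } else {
                    // better of skipping item i and picking it
                    dp[i + 1][j] = Math.max(dp[i][j], dp[i][j - w[i]] + v[i]);
                }
            }
        }
        return dp[n][capacity];
    }

    public static void main(String[] args) {
        int[] w = {2, 1, 3, 2}; // hypothetical weights
        int[] v = {3, 2, 4, 2}; // hypothetical values
        System.out.println(maxValue(w, v, 5)); // prints 7
    }
}
```

The table has (n + 1) × (capacity + 1) entries and each is filled in constant time, so the running time is O(n · capacity).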

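Problems similar to the knapsack problem include:

For each number, we can choose to pick it or not. Suppose:

dp[i][j] := whether some of the first i numbers can be chosen so that they sum to j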

Then there are two cases:

dp[i][j] = dp[i-1][j];          // don't pick the number nums[i]
dp[i][j] = dp[i-1][j-nums[i]];  // pick the number nums[i]

So the final recurrence is:

dp[0][0] = true;  // the empty selection sums to 0
dp[i][j] = dp[i-1][j] || (j >= nums[i] && dp[i-1][j-nums[i]]);
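The pick-or-skip recurrence above can be sketched in Java as follows. The class name `SubsetSum` and method name `canSum` are made up for illustration, the input is assumed to contain only non-negative integers, and the i-th number is read from `nums[i - 1]` because Java arrays are 0-indexed.

```java
class SubsetSum {
    // dp[i][j] = true if some of the first i numbers sum to exactly j.
    // Assumes every element of nums and the target are non-negative.
    static boolean canSum(int[] nums, int target) {
        int n = nums.length;
        boolean[][] dp = new boolean[n + 1][target + 1];
        dp[0][0] = true; // the empty selection sums to 0
        for (int i = 1; i <= n; i++) {
            int x = nums[i - 1]; // the i-th number
            for (int j = 0; j <= target; j++) {
                // skip x, or pick x when it does not exceed j
                dp[i][j] = dp[i - 1][j] || (j >= x && dp[i - 1][j - x]);
            }
        }
        return dp[n][target];
    }

    public static void main(String[] args) {
        System.out.println(canSum(new int[]{1, 5, 11, 5}, 11)); // prints true
        System.out.println(canSum(new int[]{1, 2, 5}, 4));      // prints false
    }
}
```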

Picking with unlimited repetition:

① Naive solution

public int combinationSum4(int[] nums, int target) {
    if (target == 0) {
        return 1;
    }

    int res = 0;
    for (int i = 0; i < nums.length; i++) {
        if (target >= nums[i]) {
            res += combinationSum4(nums, target - nums[i]);
        }
    }
    return res;
}

② Memoized search

private int[] dp;

public int combinationSum4(int[] nums, int target) {
    dp = new int[target + 1];

    Arrays.fill(dp, -1);
    dp[0] = 1;

    return helper(nums, target);
}

private int helper(int[] nums, int target) {
    if (dp[target] != -1) {
        return dp[target];
    }

    int res = 0;
    for (int i = 0; i < nums.length; i++) {
        if (target >= nums[i]) {
            res += helper(nums, target - nums[i]);
        }
    }
    dp[target] = res;
    return res;
}

③ Bottom-up recurrence:

public int combinationSum4(int[] nums, int target) {
    int[] comb = new int[target + 1];
    comb[0] = 1;
    for (int i = 1; i < comb.length; i++) {
        for (int j = 0; j < nums.length; j++) {
            if (i - nums[j] >= 0) {
                comb[i] += comb[i - nums[j]];
            }
        }
    }
    return comb[target];
}

Maximum contiguous subarray problem, where L[i] is the largest sum of a subarray ending at index i:

L[i] = max(L[i-1] + A[i], A[i])
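As a minimal sketch of this recurrence in Java: since L[i] depends only on L[i-1], a single running variable is enough. The class name `MaxSubarray` and method name `maxSubArray` are illustrative assumptions, and the input array is assumed to be non-empty.

```java
class MaxSubarray {
    // endingHere plays the role of L[i]: the best sum of a subarray ending at i.
    // Assumes a.length >= 1.
    static int maxSubArray(int[] a) {
        int endingHere = a[0];
        int best = a[0];
        for (int i = 1; i < a.length; i++) {
            // either extend the previous subarray or start fresh at a[i]
            endingHere = Math.max(endingHere + a[i], a[i]);
            best = Math.max(best, endingHere);
        }
        return best;
    }

    public static void main(String[] args) {
        int[] a = {-2, 1, -3, 4, -1, 2, 1, -5, 4};
        System.out.println(maxSubArray(a)); // prints 6, from the subarray [4, -1, 2, 1]
    }
}
```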

For the maximum product subarray, multiplying from back to front exceeds the time limit:

class Solution {
    public int maxProduct(int[] nums) {
        if (nums.length == 0) {
            return 0;
        } else if (nums.length == 1) {
            return nums[0];
        }

        int[] dp = new int[nums.length];
        dp[0] = nums[0];

        for (int i = 1; i < nums.length; i++) {
            if (dp[i - 1] != 0) {
                dp[i] = dp[i - 1] * nums[i];
            } else {
                dp[i] = nums[i];
            }
        }

        int max = dp[0];
        for (int i = 1; i < nums.length; i++) {
            if (dp[i] > max) {
                max = dp[i];
            }

            for (int j = i - 1; j >= 0; j--) {
                if (dp[j] == 0) {
                    break;
                }

                int product = dp[i] / dp[j];
                if (product > max) {
                    max = product;
                }
            }
        }

        return max;
    }
}

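Below is an O(n) algorithm posted by another user.

The overall DP recurrence is (imax and imin are swapped first when A[i] < 0):

imax = max(A[i], imax * A[i]);
imin = min(A[i], imin * A[i]);
result = max(result, imax);

The code is as follows: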

int maxProduct(int A[], int n) {
    // store the result that is the max we have found so far
    int r = A[0];

    // imax/imin stores the max/min product of
    // subarray that ends with the current number A[i]
    for (int i = 1, imax = r, imin = r; i < n; i++) {
        // multiplied by a negative makes big number smaller, small number bigger
        // so we redefine the extremums by swapping them
        // example:
        //   min: 3,  max: 6,  nums[i]: -2
        //   min: -12, max: -6
        //
        //   min: -6, max: -3, nums[i]: -2
        //   min: 6,  max: 12
        //
        //   min: -3, max: 6,  nums[i]: -2
        //   min: -12, max: 6
        //
        // No matter which case, it always should swap the min and max
        if (A[i] < 0)
            swap(imax, imin);

        // max/min product for the current number is either the current number itself
        // or the max/min by the previous number times the current one
        imax = max(A[i], imax * A[i]);
        imin = min(A[i], imin * A[i]);

        // the newly computed max value is a candidate for our global result
        r = max(r, imax);
    }
    return r;
}

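Once the houses are arranged in a circle, how do we solve the problem? Following an answer from another user, we can break the circle: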

class Solution {
    public int rob(int[] nums) {
        // 1 -> 2 -> 3 -> 4 -> 5 -> 1
        //
        // two cases can break the circle:
        //
        // 1 -> 2 -> 3 -> 4
        // 2 -> 3 -> 4 -> 5
        //
        if (nums == null || nums.length == 0) {
            return 0;
        } else if (nums.length == 1) {
            return nums[0];
        } else if (nums.length == 2) {
            return Math.max(nums[0], nums[1]);
        }

        return Math.max(rob(nums, 0, nums.length - 2), rob(nums, 1, nums.length - 1));
    }

    private int rob(int[] nums, int left, int right) {
        int n = right - left + 1;
        int[] dp = new int[n];

        dp[0] = nums[left];
        dp[1] = Math.max(nums[left], nums[left + 1]);

        for (int i = 2; i < n; i++) {
            int robMoney = dp[i - 2] + nums[left + i];
            int notRobMoney = dp[i - 1];
            dp[i] = Math.max(robMoney, notRobMoney);
        }

        return dp[n - 1];
    }
}

What is DP:

Wikipedia definition: “method for solving complex problems by breaking them down into simpler subproblems”


Steps:

  1. Define subproblems
  2. Write down the recurrence that relates subproblems
  3. Recognize and solve the base cases

1-dimensional DP:

Problem: given n, find the number of different ways to write n as the sum of 1, 3, 4.
Example: for n = 5, the answer is 6

5 = 1 + 1 + 1 + 1 + 1
  = 1 + 1 + 3
  = 1 + 3 + 1
  = 3 + 1 + 1
  = 1 + 4
  = 4 + 1

  • Define subproblems: let D_n be the number of ways to write n as the sum of 1, 3, 4
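A possible continuation of this example, offered as a sketch: the last term of any such sum is 1, 3, or 4, which suggests the recurrence D_n = D_{n-1} + D_{n-3} + D_{n-4}, with base case D_0 = 1 (the empty sum) and D_n = 0 for n < 0. The class name `SumOf134` and method name `countWays` below are made up for illustration.

```java
class SumOf134 {
    // d[i] = number of ordered ways to write i as a sum of 1, 3 and 4.
    static long countWays(int n) {
        long[] d = new long[n + 1];
        d[0] = 1; // base case: the empty sum
        for (int i = 1; i <= n; i++) {
            d[i] = d[i - 1];                 // last term is 1
            if (i >= 3) d[i] += d[i - 3];    // last term is 3
            if (i >= 4) d[i] += d[i - 4];    // last term is 4
        }
        return d[n];
    }

    public static void main(String[] args) {
        System.out.println(countWays(5)); // prints 6, matching the listing above
    }
}
```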

A general approach to permutation-style problems

A general approach to backtracking questions in Java (Subsets, Permutations, Combination Sum, Palindrome Partitioning)

This structure might apply to many other backtracking questions, but here I am just going to demonstrate Subsets, Permutations, and Combination Sum.

Subsets :

public List<List<Integer>> subsets(int[] nums) {
    List<List<Integer>> list = new ArrayList<>();
    Arrays.sort(nums);
    backtrack(list, new ArrayList<>(), nums, 0);
    return list;
}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int[] nums, int start) {
    list.add(new ArrayList<>(tempList));
    for (int i = start; i < nums.length; i++) {
        tempList.add(nums[i]);
        backtrack(list, tempList, nums, i + 1);
        tempList.remove(tempList.size() - 1);
    }
}

Subsets II (contains duplicates):

public List<List<Integer>> subsetsWithDup(int[] nums) {
    List<List<Integer>> list = new ArrayList<>();
    Arrays.sort(nums);
    backtrack(list, new ArrayList<>(), nums, 0);
    return list;
}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int[] nums, int start) {
    list.add(new ArrayList<>(tempList));
    for (int i = start; i < nums.length; i++) {
        if (i > start && nums[i] == nums[i - 1]) continue; // skip duplicates
        tempList.add(nums[i]);
        backtrack(list, tempList, nums, i + 1);
        tempList.remove(tempList.size() - 1);
    }
}

Permutations (全排列):

public List<List<Integer>> permute(int[] nums) {
    List<List<Integer>> list = new ArrayList<>();
    // Arrays.sort(nums); // not necessary
    backtrack(list, new ArrayList<>(), nums);
    return list;
}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int[] nums) {
    if (tempList.size() == nums.length) {
        list.add(new ArrayList<>(tempList));
    } else {
        for (int i = 0; i < nums.length; i++) {
            if (tempList.contains(nums[i])) continue; // element already exists, skip
            tempList.add(nums[i]);
            backtrack(list, tempList, nums);
            tempList.remove(tempList.size() - 1);
        }
    }
}

Permutations II (contains duplicates):

public List<List<Integer>> permuteUnique(int[] nums) {
    List<List<Integer>> list = new ArrayList<>();
    Arrays.sort(nums);
    backtrack(list, new ArrayList<>(), nums, new boolean[nums.length]);
    return list;
}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int[] nums, boolean[] used) {
    if (tempList.size() == nums.length) {
        list.add(new ArrayList<>(tempList));
    } else {
        for (int i = 0; i < nums.length; i++) {
            if (used[i] || i > 0 && nums[i] == nums[i - 1] && !used[i - 1]) continue;
            used[i] = true;
            tempList.add(nums[i]);
            backtrack(list, tempList, nums, used);
            used[i] = false;
            tempList.remove(tempList.size() - 1);
        }
    }
}

Combination Sum:

public List<List<Integer>> combinationSum(int[] nums, int target) {
    List<List<Integer>> list = new ArrayList<>();
    Arrays.sort(nums);
    backtrack(list, new ArrayList<>(), nums, target, 0);
    return list;
}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int[] nums, int remain, int start) {
    if (remain < 0) return;
    else if (remain == 0) list.add(new ArrayList<>(tempList));
    else {
        for (int i = start; i < nums.length; i++) {
            tempList.add(nums[i]);
            backtrack(list, tempList, nums, remain - nums[i], i); // not i + 1 because we can reuse same elements
            tempList.remove(tempList.size() - 1);
        }
    }
}

Combination Sum II (can’t reuse same element):

public List<List<Integer>> combinationSum2(int[] nums, int target) {
    List<List<Integer>> list = new ArrayList<>();
    Arrays.sort(nums);
    backtrack(list, new ArrayList<>(), nums, target, 0);
    return list;
}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int[] nums, int remain, int start) {
    if (remain < 0) return;
    else if (remain == 0) list.add(new ArrayList<>(tempList));
    else {
        for (int i = start; i < nums.length; i++) {
            if (i > start && nums[i] == nums[i - 1]) continue; // skip duplicates
            tempList.add(nums[i]);
            backtrack(list, tempList, nums, remain - nums[i], i + 1);
            tempList.remove(tempList.size() - 1);
        }
    }
}

Palindrome Partitioning:

public List<List<String>> partition(String s) {
    List<List<String>> list = new ArrayList<>();
    backtrack(list, new ArrayList<>(), s, 0);
    return list;
}

public void backtrack(List<List<String>> list, List<String> tempList, String s, int start) {
    if (start == s.length())
        list.add(new ArrayList<>(tempList));
    else {
        for (int i = start; i < s.length(); i++) {
            if (isPalindrome(s, start, i)) {
                tempList.add(s.substring(start, i + 1));
                backtrack(list, tempList, s, i + 1);
                tempList.remove(tempList.size() - 1);
            }
        }
    }
}

public boolean isPalindrome(String s, int low, int high) {
    while (low < high)
        if (s.charAt(low++) != s.charAt(high--)) return false;
    return true;
}
