Well, this problem has an O(n^3) solution similar to 3Sum. That is, fix two elements nums[i] and nums[j] (i < j) and search the remaining array for two elements that sum to target - nums[i] - nums[j]. Both i and j have O(n) possible values. Searching the remaining array for two such elements takes O(n) using two pointers (left and right), just as 3Sum fixes one element and searches for the other two. So the total time complexity is O(n^3).
The code below should be self-explanatory.
// O(n^3) 4Sum: sort, fix nums[i] and nums[j] (i < j), then close in on the
// remaining two elements with a left/right two-pointer scan, as in 3Sum.
//
// nums   - input values; sorted in place as a side effect.
// target - the sum each returned quadruplet must reach.
// Returns every unique quadruplet {a, b, c, d} with a + b + c + d == target;
// each quadruplet is sorted ascending.
vector<vector<int> > fourSum(vector<int>& nums, int target) {
    sort(nums.begin(), nums.end());
    vector<vector<int> > res;
    const int n = static_cast<int>(nums.size());
    for (int i = 0; i < n - 3; i++) {
        for (int j = i + 1; j < n - 2; j++) {
            int left = j + 1, right = n - 1;
            while (left < right) {
                // Sum in 64 bits: four ints can overflow int (undefined
                // behavior; in practice it wraps and yields bogus matches).
                const long long temp = static_cast<long long>(nums[i]) + nums[j] +
                                       nums[left] + nums[right];
                if (temp == target) {
                    const int lo = nums[left], hi = nums[right];
                    res.push_back({nums[i], nums[j], lo, hi});
                    // Step past every copy of the matched values so the same
                    // (c, d) pair is not reported twice for this (a, b).
                    while (left < right && nums[left] == lo) left++;
                    while (left < right && nums[right] == hi) right--;
                }
                else if (temp < target) left++;
                else right--;
            }
            // Advance j to the last copy of nums[j]; the loop's j++ then moves
            // to the next distinct second value.
            while (j + 1 < n - 2 && nums[j + 1] == nums[j]) j++;
        }
        // Same de-duplication for the first value.
        while (i + 1 < n - 3 && nums[i + 1] == nums[i]) i++;
    }
    return res;
}
In fact, there is also an O(n^2 log n) solution that first stores the sums of all pairs of elements in nums. Solution 1, solution 2 and solution 3 take this approach; you may refer to them. Note, however, that all of these solutions run slower on the OJ than the O(n^3) solution above :)
Personally, I like solution 3 and have rewritten it below.
// Hash-map 4Sum: index every pair (k, l), k < l, by nums[k] + nums[l]. Then,
// for each distinct leading pair (nums[i], nums[j]), look up the complementary
// sum among the indexed pairs that start after j.
//
// nums   - input values; sorted in place as a side effect.
// target - the sum each returned quadruplet must reach.
// Returns every unique quadruplet {a, b, c, d} with a + b + c + d == target;
// each quadruplet is sorted ascending.
vector<vector<int> > fourSum(vector<int>& nums, int target) {
    sort(nums.begin(), nums.end());
    const int n = static_cast<int>(nums.size());
    // 64-bit keys: a pair sum, and especially target - a - b, can overflow int.
    unordered_map<long long, vector<pair<int, int> > > mp;
    for (int i = 0; i < n; i++)
        for (int j = i + 1; j < n; j++)
            mp[static_cast<long long>(nums[i]) + nums[j]].push_back(make_pair(i, j));
    vector<vector<int> > res;
    for (int i = 0; i < n - 3; i++) {
        if (i && nums[i] == nums[i - 1]) continue;  // distinct first value only
        for (int j = i + 1; j < n - 2; j++) {
            if (j > i + 1 && nums[j] == nums[j - 1]) continue;  // distinct second value only
            const long long remain = static_cast<long long>(target) - nums[i] - nums[j];
            const auto found = mp.find(remain);
            if (found == mp.end()) continue;
            const vector<pair<int, int> >& pairs = found->second;
            // Pairs were appended with k ascending, so those with k > j form a
            // suffix; jump straight to it instead of scanning from the start.
            auto itr = partition_point(pairs.begin(), pairs.end(),
                                       [j](const pair<int, int>& p) { return p.first <= j; });
            for (; itr != pairs.end(); ++itr) {
                const int k = itr->first, l = itr->second;
                vector<int> ans{nums[i], nums[j], nums[k], nums[l]};
                // For a fixed (i, j), nums[l] is determined by nums[k], which is
                // non-decreasing in k, so equal answers are adjacent: comparing
                // with the last stored result is enough to de-duplicate.
                if (res.empty() || ans != res.back())
                    res.push_back(ans);
            }
        }
    }
    return res;
}