4SUM implementation

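Let’s try to solve the 4SUM problem: given an array of integers, count the quadruples of elements at four distinct positions whose sum is zero. We will also take a look at 3SUM and 2SUM.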
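To pin the definition down, here is a brute-force counter that checks every quadruple of positions i < j < k < l in O(n^4) time. It is only a sketch for checking faster versions on small inputs; the class name BruteForce4Sum is hypothetical.

```java
// Hypothetical reference counter for 4SUM; O(n^4), meant for small inputs only.
public class BruteForce4Sum {

    // Counts index quadruples i < j < k < l whose values sum to zero.
    static long count(int[] a) {
        long count = 0;
        int n = a.length;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                for (int k = j + 1; k < n; k++) {
                    for (int l = k + 1; l < n; l++) {
                        // widen to long so large values cannot overflow
                        if ((long) a[i] + a[j] + a[k] + a[l] == 0) {
                            count++;
                        }
                    }
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[]{-3, -1, 0, 1, 2, 4})); // 2
        System.out.println(count(new int[]{0, 0, 0, 0, 0}));      // 5
    }
}
```

For {-3, -1, 0, 1, 2, 4} the two quadruples are (-3, -1, 0, 4) and (-3, 0, 1, 2); for five zeros every choice of four positions counts, so the result is C(5, 4) = 5.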

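I will propose two solutions: one that stores the sums of all pairs in a dictionary and runs in O(n^2) expected time, and one that walks a sorted array with four pointers and needs no dictionary.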

1. This solution uses a temporary dictionary and therefore allocates O(n^2) memory.

  • Walk the array with an index k. At the start of step k, the dictionary maps each sum a[i] + a[j] with i < j < k to the number of pairs that produce it.
  • For every l > k, look up -(a[k] + a[l]): each pair stored under that key completes a quadruple i < j < k < l that sums to zero. (O(1) expected per lookup, O(n^2) in total)
  • Then add the pairs (i, k) with i < k to the dictionary, so they can serve as left pairs for the following values of k. (O(n^2) in total)

Splitting each quadruple at its third index k keeps the count exact: every quadruple is found exactly once. Matching every two-sum with every complementary two-sum regardless of position would also count pairs that share an index, and would count the same quadruple up to three times, once for each way of splitting four indices into two pairs.

The code follows:

import java.util.HashMap;
import java.util.Map;

public long count(final int[] a) {
    // Sum of a left pair (i, j), widened to long -> number of such pairs with j < k.
    Map<Long, Integer> dictionary = new HashMap<>();
    long count = 0; // C(n, 4) can exceed Integer.MAX_VALUE

    for (int k = 0; k < a.length; k++) {
        // Each right pair (k, l) completes every stored left pair with the opposite sum.
        for (int l = k + 1; l < a.length; l++) {
            long complement = -((long) a[k] + a[l]);
            count += dictionary.getOrDefault(complement, 0);
        }
        // Register the pairs (i, k) so they can serve as left pairs for larger k.
        for (int i = 0; i < k; i++) {
            dictionary.merge((long) a[i] + a[k], 1, Integer::sum);
        }
    } // O(n^2) expected time, O(n^2) memory

    return count;
}

The loop over l is 2SUM solved with a hash lookup: for a fixed k, it finds every complementary left pair in O(n) expected time.
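A small worked example shows why the positions of the two pairs matter. For {0, 0, 0, 0} all six pairs have sum 0, yet there is only one quadruple. The sketch below (the class PairOfPairsDemo and its methods are hypothetical) compares matching every pair with every pair of opposite sum (naive) against a count that only combines a pair (i, j) with a pair (k, l) when j < k (splitAtK).

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo: naive pair-of-pairs matching vs. a position-aware count.
public class PairOfPairsDemo {

    // Multiplies the number of pairs with sum s by the number with sum -s,
    // ignoring shared indices and the order of the two pairs.
    static long naive(int[] a) {
        Map<Long, Long> pairs = new HashMap<>();
        for (int i = 0; i < a.length; i++) {
            for (int j = i + 1; j < a.length; j++) {
                pairs.merge((long) a[i] + a[j], 1L, Long::sum);
            }
        }
        long total = 0;
        for (Map.Entry<Long, Long> e : pairs.entrySet()) {
            total += e.getValue() * pairs.getOrDefault(-e.getKey(), 0L);
        }
        return total;
    }

    // Counts each quadruple i < j < k < l once, when its third index k is reached.
    static long splitAtK(int[] a) {
        Map<Long, Long> left = new HashMap<>();
        long count = 0;
        for (int k = 0; k < a.length; k++) {
            for (int l = k + 1; l < a.length; l++) {
                count += left.getOrDefault(-((long) a[k] + a[l]), 0L);
            }
            for (int i = 0; i < k; i++) {
                left.merge((long) a[i] + a[k], 1L, Long::sum);
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[] zeros = {0, 0, 0, 0};
        System.out.println(naive(zeros));    // 36
        System.out.println(splitAtK(zeros)); // 1
    }
}
```

naive reaches 36 because it multiplies the six zero-sum pairs by themselves, including pairs that share an index, while splitAtK finds the single quadruple (0, 1, 2, 3).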

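2. The second solution is fit for the case of distinct values and needs no dictionary. It sorts a copy of the array and uses 4 pointers i, j, k, l: i and j are fixed from the left by two nested loops, k starts right after j and l starts at the right end, and k and l move towards each other as long as they do not overlap. Advancing all four pointers in a single loop would take linear time, but it can step past valid quadruples, for example (-3, 0, 1, 2) in the sorted array [-3, -1, 0, 1, 2, 4]. Sweeping k and l once for every choice of i and j visits every candidate, which costs O(n^3) time and O(n) extra memory for the sorted copy.

import java.util.Arrays;

public long count(final int[] a) {
    int[] s = a.clone();
    Arrays.sort(s); // the pointer moves below rely on ascending order
    int n = s.length;
    long count = 0;

    for (int i = 0; i < n - 3; i++) {
        for (int j = i + 1; j < n - 2; j++) {
            // k and l close in on each other over s[j + 1 .. n - 1]
            int k = j + 1;
            int l = n - 1;
            while (k < l) {
                long sum = (long) s[i] + s[j] + s[k] + s[l];
                if (sum < 0) {
                    k++; // sum too small: take a larger value on the left
                } else if (sum > 0) {
                    l--; // sum too large: take a smaller value on the right
                } else {
                    count++;
                    // values are distinct, so neither s[k] nor s[l] can match again
                    k++;
                    l--;
                }
            }
        }
    }

    return count;
}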
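The restriction to distinct values comes from the step after a match: a sweep that moves both pointers one position after each match undercounts when values repeat. In the sorted slice {1, 1, 2, 2, 2} there are six pairs summing to 3, but moving both pointers after each match finds only two. Below is a sketch of a hypothetical helper, countPairs, that counts pairs summing to target in a sorted slice s[lo..hi] and steps over whole runs of equal values instead; it assumes the slice is sorted ascending.

```java
// Hypothetical helper: counts pairs k < l in the sorted slice s[lo..hi]
// with s[k] + s[l] == target, allowing repeated values.
public class PairCount {

    static long countPairs(int[] s, int lo, int hi, long target) {
        long pairs = 0;
        int k = lo;
        int l = hi;
        while (k < l) {
            long sum = (long) s[k] + s[l];
            if (sum < target) {
                k++;
            } else if (sum > target) {
                l--;
            } else if (s[k] == s[l]) {
                // every value in s[k..l] is equal: any two of them form a pair
                long m = l - k + 1;
                pairs += m * (m - 1) / 2;
                break;
            } else {
                // s[k] < s[l]: measure the run of equal values at each end
                int runK = 1;
                while (s[k + runK] == s[k]) {
                    runK++;
                }
                int runL = 1;
                while (s[l - runL] == s[l]) {
                    runL++;
                }
                pairs += (long) runK * runL; // every left copy pairs with every right copy
                k += runK;
                l -= runL;
            }
        }
        return pairs;
    }

    public static void main(String[] args) {
        int[] s = {1, 1, 2, 2, 2};
        System.out.println(countPairs(s, 0, s.length - 1, 3)); // 6
        System.out.println(countPairs(s, 0, s.length - 1, 4)); // 3
    }
}
```

In a solution that fixes i and j and sweeps the rest of a sorted array s, calling countPairs(s, j + 1, s.length - 1, -((long) s[i] + s[j])) for each (i, j) is one way to lift the distinct-values restriction; each call stays linear in the length of the slice.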

2SUM can also be solved in O(n) once the array is sorted, by moving one pointer in from the left and another in from the right until they meet.
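A minimal sketch of that two-pointer sweep, assuming we count pairs that sum to zero and that values are distinct (the class TwoSum is hypothetical). It sorts a copy first, so the whole call is O(n log n), with an O(n) sweep.

```java
import java.util.Arrays;

public class TwoSum {

    // Counts pairs i < j with a[i] + a[j] == 0; assumes distinct values.
    static int count(int[] input) {
        int[] a = input.clone();
        Arrays.sort(a); // the pointer moves rely on ascending order
        int count = 0;
        int i = 0;
        int j = a.length - 1;
        while (i < j) { // stop when the pointers meet
            long sum = (long) a[i] + a[j];
            if (sum < 0) {
                i++; // sum too small: move the left pointer right
            } else if (sum > 0) {
                j--; // sum too large: move the right pointer left
            } else {
                count++;
                i++;
                j--;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[]{4, -3, 1, -1, 3, 0})); // 2
    }
}
```

After sorting, the input becomes [-3, -1, 0, 1, 3, 4] and the sweep finds (-3, 3) and (-1, 1).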

3SUM can be solved in O(n^2) with 3 pointers on a sorted array: one fixes the first element of the triple, and the other two run the 2SUM sweep over the elements to its right.
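The same idea extended to 3SUM, again as a sketch that counts zero-sum triples of distinct values (the class ThreeSum is hypothetical): the outer loop fixes a[i], and the inner loop is the 2SUM sweep over the elements to its right, giving O(n^2) after sorting.

```java
import java.util.Arrays;

public class ThreeSum {

    // Counts triples i < j < k with a[i] + a[j] + a[k] == 0; assumes distinct values.
    static int count(int[] input) {
        int[] a = input.clone();
        Arrays.sort(a);
        int count = 0;
        for (int i = 0; i < a.length - 2; i++) {
            // 2SUM sweep over a[i + 1 .. n - 1] for the target -a[i]
            int j = i + 1;
            int k = a.length - 1;
            while (j < k) {
                long sum = (long) a[i] + a[j] + a[k];
                if (sum < 0) {
                    j++;
                } else if (sum > 0) {
                    k--;
                } else {
                    count++;
                    j++;
                    k--;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[]{3, -1, 0, -4, 2, 1, -2})); // 4
    }
}
```

The four triples found are (-4, 1, 3), (-2, -1, 3), (-2, 0, 2) and (-1, 0, 1).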
