Saturday, June 11, 2016

Count of Range Sum -- LeetCode 327

[Question]

Given an integer array nums, return the number of range sums that lie in [lower, upper] inclusive.
Range sum S(i, j) is defined as the sum of the elements in nums between indices i and j (i ≤ j), inclusive.

Note:
A naive O(n^2) algorithm is trivial. You MUST do better than that.

Example:
Given nums = [-2, 5, -1], lower = -2, upper = 2,

Return 3.
The three ranges are [0, 0], [2, 2], and [0, 2], and their respective sums are -2, -1, and 2.

[Analysis]
This is another Binary Indexed Tree (Fenwick tree) problem (check this blog post).

This problem is not straightforward until we observe that the range sum S(i, j) = S(0, j) - S(0, i-1). Let S[i] be the prefix sum nums[0] + ... + nums[i], with S[-1] = 0. The question then becomes: find all pairs i <= j such that lower <= S[j] - S[i-1] <= upper, which is equivalent to S[i-1] + lower <= S[j] <= S[i-1] + upper.

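Process the indices from right to left. Suppose every S[j] with j >= i has already been inserted into the Binary Indexed Tree. Then, before inserting S[i-1], we count how many inserted S[j] lie between S[i-1] + lower and S[i-1] + upper. Notes:
(1) After coordinate compression, the value space of the tree must cover S[i], S[i] + lower - 1 and S[i] + upper for every i. It must also cover lower - 1 and upper for the empty prefix S[-1] = 0. The "- 1" lets two prefix counts bound the range inclusively on both ends.
(2) Only the S[i] values themselves are actually inserted into the tree.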

The overall time complexity is O(N log N).

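Update #1: An augmented BST can also solve this problem in O(N log N) time, provided the tree stays balanced. A plain, unbalanced BST degrades to O(N^2) in the worst case, e.g. when the prefix sums are monotonic.

Each BST node records how many stored values are smaller than its own value, i.e. the size of its left subtree. With these counts the tree can answer lower_bound() and upper_bound() queries (the position of a given value among the stored values).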

Update #2: Using Merge Sort is also a good solution.

[Solution]
// 64-bit signed integer used for prefix sums, which can exceed the int range.
using LL = long long;

// Binary Indexed (Fenwick) tree over 1-based positions 1..n, holding int counts.
// Supports point add and prefix-sum queries, each in O(log n).
class BITree {
    vector<int> nodes;  // nodes[0] is never used; nodes[1..n] hold partial sums

    // Lowest set bit of x: the number of positions aggregated by nodes[x].
    static int lowbit(int x) { return x & -x; }

    // Largest valid position n.
    int last() const { return static_cast<int>(nodes.size()) - 1; }

public:
    // Creates a tree for positions 1..n with every count set to zero.
    BITree(int n) : nodes(n + 1, 0) {}

    // Adds val at 1-based position i.
    // Positions <= 0 are ignored: lowbit(0) == 0, so the update loop would never
    // advance. Positions > n fall outside the loop bound and are ignored too.
    void add(int i, int val) {
        if (i <= 0) return;
        for (const int n = last(); i <= n; i += lowbit(i))
            nodes[i] += val;
    }

    // Returns the sum of positions 1..i.
    // i <= 0 yields 0; i > n is clamped to n (the total) instead of reading out of bounds.
    int sum(int i) const {
        if (i > last()) i = last();
        int total = 0;
        for (; i > 0; i -= lowbit(i))
            total += nodes[i];
        return total;
    }
};

class Solution {
public:
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        
        vector<LL> sums(nums.size()*3, 0);
        LL sum=0;
        for (int i=0; i< nums.size(); i++) {
            sum += nums[i];
            sums[3*i] = sum;
            sums[3*i+1] = sum+lower-1;
            sums[3*i+2] = sum+upper;
        }
        sums.push_back(upper);
sums.push_back(lower - 1);
        sort(sums.begin(), sums.end());
        
        // get value distribution and the inverted index        
        auto end = unique(sums.begin(), sums.end() );
        auto it = sums.begin();
        unordered_map<LL, int> index;
        for (int i=1; it!=end; i++, it++ ) 
            index[*it] = i;
        
        // use BIT to sort it out  
        BITree tree(index.size());
        int rslt=0;
        for (int i= nums.size()-1; i>=0; i--) {
            tree.add(index[sum],1);
            sum -= nums[i];
            rslt += tree.sum(index[upper+sum]) - tree.sum(index[lower+sum-1]);
        }

        return rslt;
    }
};

//
// -- Using BST --
//
class Solution {
    // BST node stored in a contiguous pool. The pool owns every node, so nothing
    // leaks and no manual new/delete is needed. Equal values go to the right subtree.
    struct Node {
        long long val;
        long long smaller;  // number of nodes in the left subtree (values < val)
        int left, right;    // pool indices of the children, -1 when absent
    };

    // Inserts val into the BST held in `tree` (root at index 0 when non-empty).
    // Iterative, so a degenerate (list-shaped) tree cannot overflow the stack.
    static void insert(vector<Node>& tree, long long val) {
        const int fresh = static_cast<int>(tree.size());
        tree.push_back(Node{val, 0, -1, -1});
        int cur = 0;  // when the tree was empty, fresh == 0 and the loop is skipped
        while (cur != fresh) {
            Node& node = tree[cur];
            if (val < node.val) {
                ++node.smaller;
                if (node.left < 0) node.left = fresh;
                cur = node.left;
            } else {
                if (node.right < 0) node.right = fresh;
                cur = node.right;
            }
        }
    }

    // Number of stored values strictly less than val (a lower_bound position).
    static long long countLess(const vector<Node>& tree, long long val) {
        long long count = 0;
        int cur = tree.empty() ? -1 : 0;
        while (cur >= 0) {
            const Node& node = tree[cur];
            if (val < node.val) {
                cur = node.left;
            } else {
                // Left subtree is entirely < val; the node itself counts only if < val.
                count += node.smaller + (val == node.val ? 0 : 1);
                cur = node.right;
            }
        }
        return count;
    }

    // Number of stored values less than or equal to val (an upper_bound position).
    static long long countLessEqual(const vector<Node>& tree, long long val) {
        long long count = 0;
        int cur = tree.empty() ? -1 : 0;
        while (cur >= 0) {
            const Node& node = tree[cur];
            if (val < node.val) {
                cur = node.left;
            } else {
                count += node.smaller + 1;  // left subtree plus the node itself
                cur = node.right;
            }
        }
        return count;
    }

public:
    // Counts the ranges [i, j] (i <= j) whose sum lies in [lower, upper].
    // Walks the prefix sums from right to left. It counts stored later prefix sums
    // in [S + lower, S + upper] and then inserts S.
    // Note: the BST is not self-balancing, so this is O(n log n) on average but
    // O(n^2) in the worst case (e.g. monotonic prefix sums).
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        if (nums.empty() || lower > upper) return 0;

        // long long, not long: long is only 32 bits on some platforms (LLP64).
        vector<long long> prefix(nums.size() + 1, 0);
        for (size_t i = 0; i < nums.size(); ++i)
            prefix[i + 1] = prefix[i] + nums[i];

        vector<Node> tree;
        tree.reserve(prefix.size());
        long long res = 0;
        for (int i = static_cast<int>(prefix.size()) - 1; i >= 0; --i) {
            res += countLessEqual(tree, prefix[i] + upper) - countLess(tree, prefix[i] + lower);
            insert(tree, prefix[i]);
        }
        return static_cast<int>(res);
    }
};

//
//-- Using Merge Sort ---
//
class Solution {
    // Sorts sums[lo, hi) in place and returns the number of pairs lo <= k < j < hi
    // with lower <= sums[j] - sums[k] <= upper (positions before sorting).
    static long long mergeCount(vector<long long>& sums, int lo, int hi,
                                long long lower, long long upper) {
        if (hi - lo <= 1) return 0;
        const int mid = lo + (hi - lo) / 2;
        long long cnt = mergeCount(sums, lo, mid, lower, upper)
                      + mergeCount(sums, mid, hi, lower, upper);

        // Both halves are sorted. For increasing sums[k] in the left half, the
        // window [first, last) of right-half values in [sums[k] + lower,
        // sums[k] + upper] only moves forward.
        int first = mid, last = mid;
        for (int k = lo; k < mid; ++k) {
            while (first < hi && sums[first] < sums[k] + lower) ++first;
            while (last < hi && sums[last] <= sums[k] + upper) ++last;
            cnt += last - first;
        }

        // Merge in place, keeping full 64-bit values (the former vector<int>
        // buffer truncated sums outside the int range).
        inplace_merge(sums.begin() + lo, sums.begin() + mid, sums.begin() + hi);
        return cnt;
    }

public:
    // Counts the ranges [i, j] (i <= j) whose sum lies in [lower, upper], using a
    // merge sort over the prefix sums. O(n log n) time, O(n) extra space.
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        if (lower > upper) return 0;  // empty interval; avoids negative counts

        // long long, not long: long is only 32 bits on some platforms (LLP64).
        vector<long long> subsum(nums.size() + 1, 0);
        for (size_t i = 0; i < nums.size(); ++i)
            subsum[i + 1] = subsum[i] + nums[i];

        return static_cast<int>(
            mergeCount(subsum, 0, static_cast<int>(subsum.size()), lower, upper));
    }
};

No comments:

Post a Comment