Friday, September 26, 2014

Fastest bit counting

The fastest bit-counting algorithm I know of is the one invented by folks at Stanford University. It has no loops or branches, so it always takes the same number of operations: O(1).

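int bitcount(int n)
{
    /* Each 2-bit field now holds the number of 1-bits in that pair. */
    n = n - ((n >> 1) & 0x55555555);
    /* Each 4-bit field now holds the number of 1-bits in that nibble. */
    n = (n & 0x33333333) + ((n >> 2) & 0x33333333);

    /* Merge nibbles into bytes, then sum the four bytes into the top byte. */
    return (((n + (n >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
}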

For example, this is what I get after compiling it to 32-bit x86 assembly (AT&T syntax). Note that there is not a single jump:

    .cfi_startproc
    movl    4(%esp), %eax
    movl    %eax, %edx
    sarl    %edx
    andl    $1431655765, %edx
    subl    %edx, %eax
    movl    %eax, %edx
    sarl    $2, %eax
    andl    $858993459, %edx
    andl    $858993459, %eax
    addl    %edx, %eax
    movl    %eax, %edx
    sarl    $4, %edx
    addl    %edx, %eax
    andl    $252645135, %eax
    imull   $16843009, %eax, %eax
    sarl    $24, %eax
    ret
    .cfi_endproc
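The listing uses arithmetic shifts (sarl) because n is declared as a signed int. The masks throw away the sign-extended bits, so the count still comes out right here, but right-shifting a negative int is implementation-defined in C, and the final multiplication can overflow a signed int, which is undefined behavior. Below is a minimal sketch of the same steps on an unsigned 32-bit type, which avoids both problems. The names popcount32 and popcount32_naive and the use of <stdint.h> are my own choices for illustration, not part of the original routine.

```c
#include <stdint.h>

/* Same three steps as bitcount(), but on an unsigned 32-bit value. */
uint32_t popcount32(uint32_t n)
{
    n = n - ((n >> 1) & 0x55555555u);                 /* counts per 2 bits */
    n = (n & 0x33333333u) + ((n >> 2) & 0x33333333u); /* counts per 4 bits */
    n = (n + (n >> 4)) & 0x0F0F0F0Fu;                 /* counts per 8 bits */
    /* The multiplication adds the four byte counts into the top byte. */
    return (uint32_t)(n * 0x01010101u) >> 24;
}

/* Plain bit-by-bit loop, handy as a reference when testing. */
uint32_t popcount32_naive(uint32_t n)
{
    uint32_t count = 0;
    while (n != 0) {
        count += n & 1u;
        n >>= 1;
    }
    return count;
}
```

Comparing the two functions on a handful of values (0, 13, 0xFFFFFFFF, 0x80000000) is a quick way to convince yourself the trick works for negative-looking inputs as well.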


But how does the algorithm actually work?

OK, let's try a simple case. Assume n is a 4-bit number whose bits are a, b, c and d (from most to least significant), where each of them can only be 0 or 1. The decimal value of n is 8*a + 4*b + 2*c + d, and the total number of 1-bits is a + b + c + d.

For example, for n = 0b1101 we have a=1, b=1, c=0, d=1, so the decimal value is 8*1 + 4*1 + 2*0 + 1 = 13 and the total number of 1-bits is a + b + c + d = 1 + 1 + 0 + 1 = 3. So far so good?
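To make the notation concrete, here is a small C sketch that pulls a, b, c and d out of a 4-bit value and adds them up. The function name bit_sum4 is made up for this illustration.

```c
/* Count the 1-bits of a 4-bit value by extracting each bit.
   bit_sum4 is a hypothetical helper name. */
unsigned bit_sum4(unsigned n)
{
    unsigned a = (n >> 3) & 1u;  /* weight 8 */
    unsigned b = (n >> 2) & 1u;  /* weight 4 */
    unsigned c = (n >> 1) & 1u;  /* weight 2 */
    unsigned d = n & 1u;         /* weight 1 */

    return a + b + c + d;
}
```

This takes one shift and one mask per bit, which is exactly the work the trick in this post avoids.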

So counting the 1-bits is as simple as computing a + b + c + d. But wait... n itself is not a + b + c + d, it is 8*a + 4*b + 2*c + d. OK, let's conquer this step by step.

If we shift n one bit to the right, the least significant bit is gone and the other weights are divided by two, so n becomes 4*a + 2*b + c, right? Now subtract this from the original n:

n          = 8*a + 4*b + 2*c + d
n>>1       = 4*a + 2*b +   c
-------------------------------- -
n - (n>>1) = 4*a + 2*b +   c + d

That's a good direction! Written as a 4-bit nibble, n>>1 is 0 + 4*a + 2*b + c. If we want to subtract only 4*a and c, we have to mask out 2*b. Masking with binary 0101 (0x5) leaves just 4*a + c, and subtracting that gives a and b the same weight, and likewise c and d:

n                  = 8*a + 4*b + 2*c + d
(n>>1) & 0x5       = 4*a +   0 +   c
-------------------------------------- -
n - ((n>>1) & 0x5) = 4*a + 4*b +   c + d
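Here is a small C sketch of this masked subtraction, in case you want to try it yourself. The function names are made up for illustration. One thing to watch out for when typing the expression into C: - binds tighter than &, so the parentheses around (n >> 1) & 0x5 are required.

```c
/* Masked subtraction for a 4-bit value: turns 8a + 4b + 2c + d
   into 4(a + b) + (c + d). Hypothetical helper name. */
unsigned pair_counts4(unsigned n)
{
    return n - ((n >> 1) & 0x5u);
}

/* What C computes when the inner parentheses are left out:
   n - (n >> 1) & 0x5 is parsed as (n - (n >> 1)) & 0x5. */
unsigned pair_counts4_misparsed(unsigned n)
{
    return (n - (n >> 1)) & 0x5u;
}
```

For n = 13 the first function gives 9 (0b1001: a + b = 2 in the upper pair, c + d = 1 in the lower pair), while the second one gives 5.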

Now store this result back into n, so from now on n is 4*a + 4*b + c + d = 4*(a + b) + (c + d). Each of the two sums is at most 2, so each one fits in its own 2-bit field.
To get c + d only, we mask n with 0b11, i.e. n & 0x3.
If we shift n right by two bits, c + d falls off the end and 4*(a + b) becomes a + b!
To make sure we keep only the lowest two bits (a + b), we mask that with 0x3 as well:

n & 0x3        = c + d
(n >> 2) & 0x3 = a + b
------------------------------- +
bits           = a + b + c + d

Nice! Now we have the expression a + b + c + d.

So, for 4-bit bit counting, we can use these expressions:

n = n - ((n >> 1) & 0x5)
bits = (n & 0x3) + ((n >> 2) & 0x3)

Check: as in the example above, take n = 13 (0b1101), so a=1, b=1, c=0, d=1.

n = 0b1101 - (0b0110 & 0b0101) = 0b1101 - 0b0100 = 13 - 4 = 9 = 0b1001
Then the next step:
bits = (0b1001 & 0b0011) + ((0b1001 >> 2) & 0b0011) = 0b0001 + (0b0010 & 0b0011) = 1 + 2 = 3 !!
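One worked example is not really a proof, but with only 16 possible inputs we can simply try them all. The following is a minimal sketch; bitcount4 and bitcount4_agrees_with_loop are names I made up for it.

```c
/* 4-bit bit count using the two expressions above.
   Only the low four bits of n are considered. */
unsigned bitcount4(unsigned n)
{
    n &= 0xFu;
    n = n - ((n >> 1) & 0x5u);
    return (n & 0x3u) + ((n >> 2) & 0x3u);
}

/* Returns 1 if bitcount4 agrees with a plain bit-by-bit loop
   for every 4-bit value, 0 otherwise. */
int bitcount4_agrees_with_loop(void)
{
    unsigned n;

    for (n = 0; n < 16; n++) {
        unsigned expected = 0;
        unsigned bits = n;

        while (bits != 0) {
            expected += bits & 1u;
            bits >>= 1;
        }
        if (bitcount4(n) != expected)
            return 0;
    }
    return 1;
}
```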

For 32 bits the idea is the same, except that we have more steps to do.

Say n is represented by the set of coefficients {a[0], a[1], ..., a[31]}, with a[0] the most significant bit, so n = a[0]*2^31 + a[1]*2^30 + ... + a[31]*2^0

The mask should be 32-bit, so instead of 0x5, we use 0x55555555 = 0b0101010101...0101

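n                         = 2^31*a[0] + 2^30*a[1] + ... + 2^1*a[30] + 2^0*a[31]
(n>>1) & 0x55555555       = 2^30*a[0] + 2^28*a[2] + ... + 2^2*a[28] + 2^0*a[30]
-------------------------------------------------------------------------------- -
n - ((n>>1) & 0x55555555) = (2^31 - 2^30)*a[0] + 2^30*a[1] + ... + (2^1 - 2^0)*a[30] + 2^0*a[31]


Let's review binary arithmetic.

2^3 - 2^2 = 8 - 4 = 4 = 2^2
2^16 - 2^15 = 65536 - 32768 = 32768 = 2^15
or in general: 2^(k+1) - 2^k = 2^k

So the result is:

n - ((n>>1) & 0x55555555) = 2^30*(a[0] + a[1]) + 2^28*(a[2] + a[3]) + ... + 2^2*(a[28] + a[29]) + 2^0*(a[30] + a[31])

Store this as the new n. Just as in the 4-bit case, every 2-bit field now holds the number of 1-bits of its pair.

n>>2 = 2^28*(a[0] + a[1]) + 2^26*(a[2] + a[3]) + ... + 2^0*(a[28] + a[29])

Masking both n and n>>2 with 0x33333333 and adding them puts the count of every group of four bits into its own nibble:

(n & 0x33333333) + ((n>>2) & 0x33333333) = 2^28*(a[0] + ... + a[3]) + 2^24*(a[4] + ... + a[7]) + ... + 2^0*(a[28] + ... + a[31])

The rest is manipulation to add these up to a[0] + a[1] + ... + a[31]: (n + (n>>4)) & 0x0F0F0F0F merges neighbouring nibbles so that each byte holds the count of its eight bits, multiplying by 0x01010101 accumulates the four byte counts in the top byte, and >> 24 moves that byte down to become the result.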
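To watch the counts grow from pairs to nibbles to bytes, the steps can be split into separate functions. This is only a sketch for experimenting: the function names are hypothetical, and I use uint32_t from <stdint.h> so that the shifts and the multiplication are well defined.

```c
#include <stdint.h>

/* Every 2-bit field holds the number of 1-bits of that pair. */
uint32_t count_pairs(uint32_t n)
{
    return n - ((n >> 1) & 0x55555555u);
}

/* Every 4-bit field holds the number of 1-bits of that nibble. */
uint32_t count_nibbles(uint32_t n)
{
    n = count_pairs(n);
    return (n & 0x33333333u) + ((n >> 2) & 0x33333333u);
}

/* Every byte holds the number of 1-bits of that byte. */
uint32_t count_bytes(uint32_t n)
{
    n = count_nibbles(n);
    return (n + (n >> 4)) & 0x0F0F0F0Fu;
}

/* Multiplying by 0x01010101 adds the four byte counts into the
   top byte; the shift brings that byte down. */
uint32_t count_total(uint32_t n)
{
    return (uint32_t)(count_bytes(n) * 0x01010101u) >> 24;
}
```

For n = 0xFFFFFFFF the stages are 0xAAAAAAAA (every pair holds 2), 0x44444444 (every nibble holds 4), 0x08080808 (every byte holds 8) and finally 32.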

Here is a variant, this one invented by folks at MIT:

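int bitcount(unsigned int n)
{
    unsigned int tmp;

    /* Each 3-bit (octal) field of tmp now holds the number of 1-bits in that field. */
    tmp = n - ((n >> 1) & 033333333333) - ((n >> 2) & 011111111111);

    /* Add neighbouring fields into 6-bit fields; % 63 then sums those fields. */
    return ((tmp + (tmp >> 3)) & 0030707070707) % 63;
}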

It uses a similar method with a different grouping: the bits are counted in 3-bit fields instead of 2-bit fields, which is why the constants 0111..., 0333... and 00307... are written in octal. We could write them in hex, but then they are harder to remember.
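For the curious, this is what the hex spelling looks like. It is a sketch under the assumption of a 32-bit input; the function names are made up, and the hex constants are just the octal masks converted by hand. The % 63 works because the masked value is a number in base 64 whose digits are the 6-bit counts: since 64 leaves remainder 1 when divided by 63, the value modulo 63 equals the sum of its digits, and that sum is at most 32.

```c
#include <stdint.h>

/* The MIT variant with the original octal masks. */
uint32_t popcount_octal(uint32_t n)
{
    uint32_t tmp = n - ((n >> 1) & 033333333333u)
                     - ((n >> 2) & 011111111111u);

    return ((tmp + (tmp >> 3)) & 030707070707u) % 63u;
}

/* The same routine with the masks written in hex:
   033333333333 = 0xDB6DB6DB
   011111111111 = 0x49249249
   030707070707 = 0xC71C71C7 */
uint32_t popcount_hex(uint32_t n)
{
    uint32_t tmp = n - ((n >> 1) & 0xDB6DB6DBu)
                     - ((n >> 2) & 0x49249249u);

    return ((tmp + (tmp >> 3)) & 0xC71C71C7u) % 63u;
}
```

The octal masks show the repeating 3-bit pattern at a glance (3 = 011, 1 = 001, 7 = 111), while the hex ones hide it.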




