Wednesday, July 25, 2018

Least significant bit set and log2()

For my skiplist implementation I had code like this:

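#include <cstdint>
#include <ctime>    // time()
#include <random>

// Nested in SkipList: each call is one fair coin flip.
class SkipList::RandomGenerator{
public:
 bool operator()(){
  return distr_(gen_);
 }

private:
 std::mt19937   gen_{ (uint32_t) time(nullptr) };   // seeded from the clock
 std::bernoulli_distribution distr_{ 0.5 };        // true with probability 0.5
};

SkipList::RandomGenerator SkipList::rand_;

//...

// Grow the tower one level per successful coin flip, up to height_.
// Note: rand_() is called once per level.
auto SkipList::getRandomHeight_() -> int{
 int h = 1;
 while( h < height_ && rand_() )
  h++;

 return h;
}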

The problem here is that rand_() is called in the loop, once per level. I measured it, and it turned out that this can be much faster with a single call to the random generator.

In order to do that, I need to find the "least significant bit set", which for non-zero values is the same as "count trailing zeroes".

For example, suppose we have the value 100:

Binary:  0 1 1 0 0 1 0 0
Pos:     7 6 5 4 3 2 1 0

In this case, because the lowest set bit is the third one (position 2), the function should return 2.
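With such a function, the height can come from one random number. A minimal sketch, assuming a stand-alone class with the hypothetical name RandomHeight instead of the real SkipList members: each bit of one 32-bit std::mt19937 draw acts as a fair coin, so counting its trailing zero bits gives the same geometric distribution as calling the bernoulli_distribution once per level.

```cpp
#include <cstdint>
#include <random>

// Hypothetical stand-alone version of SkipList::getRandomHeight_():
// one 32-bit draw replaces the per-level coin flips.
class RandomHeight {
public:
  explicit RandomHeight(std::uint32_t seed) : gen_{seed} {}

  // Returns a height in [1, max_height]. Below the cap, P(h) = 1/2^h,
  // the same as flipping one fair coin per level.
  int operator()(int max_height) {
    // The single random call: mt19937 yields 32 random bits.
    std::uint32_t r = static_cast<std::uint32_t>(gen_());
    int h = 1;
    // Count trailing zero bits of r, capped at max_height.
    // Any lsb() implementation could replace this loop.
    while (h < max_height && (r & 1u) == 0) {
      r >>= 1;
      ++h;
    }
    return h;
  }

private:
  std::mt19937 gen_;
};
```

The loop only inspects bits of the value that was already drawn, so the generator is called once no matter how tall the tower gets. If max_height is larger than 33, an all-zero word (probability 2^-32) jumps straight to max_height, which is rarely a concern for a skip list.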

Naive implementation


#include <cstdint>

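// Index of the least significant set bit of an unsigned value,
// or 0 if x == 0 (the same answer as for x == 1).
template<typename T>
int lsb(T const x){
 for(uint8_t i = 0; i < sizeof(T) * 8; ++i)
  if (x & (T{1} << i) )   // T{1}: a plain 1 is an int, and 1 << 32 is undefined
   return i;

 return 0;
}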

It is naive, but note that it is portable and works for all unsigned types.

However, its running time is O(n), where n is the number of bits. Surely we can do better... can we?

Pre-computed table

We could precompute a lookup table and read the answer from it.

For 32-bit integers, the table would need 2^32 one-byte entries, i.e. 4 GB. Not very practical. Instead, we can do it byte by byte with a 256-entry table:

#include <cstdint>

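struct Lookup8{
 constexpr Lookup8(){
  // Fill all 256 entries; a uint8_t counter with "i < 0xFF" would skip 255.
  for(int i = 0; i < 256; ++i)
   data[i] = calc(static_cast<uint8_t>(i));
 }

 constexpr uint8_t operator[](uint8_t const x) const{
  return data[x];
 }

private:
 // Index of the lowest set bit of a single byte (0 for 0).
 constexpr static uint8_t calc(uint8_t const x){
  for(uint8_t i = 0; i < 8; ++i)
   if (x & (1 << i) )
    return i;

  return 0;
 }

private:
 uint8_t data[256] = {};   // one entry per byte value 0..255
};

constexpr Lookup8 l8;

// Scan the value one byte at a time, starting from the lowest byte,
// and look up the first non-zero byte in the table.
template<typename T>
int lsb(T const x){
 int j = 0;   // bit offset of the current byte
 for(uint8_t i = 0; i < sizeof(T); ++i){
  uint8_t const byte = ( x >> j ) & 0xFF;

  if (byte)
   return l8[byte] + j;

  j += 8;
 }

 return 0;
}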

This is great, but is there an even better solution? After some googling, I found this:

Log2 + bitmask

#include <cstdint>

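constexpr int log2(uint64_t x);   // defined in the De Bruijn section below

// Keep only the lowest set bit: ~x + 1 is -x in two's complement.
constexpr uint64_t lsb_(uint64_t const x){
 return x & ( ~x + 1 );
}

// The position of a single set bit is its base-2 logarithm.
constexpr int lsb(uint64_t const x){
 return log2(lsb_(x));
}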

Wow! That's really short.
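To see why x & ( ~x + 1 ) keeps only the lowest set bit, take 100 again, this time as a 32-bit value. The sketch uses a hypothetical name, isolate_lowest_bit(), for the same expression as lsb_(): adding 1 to ~x turns the trailing ones of ~x back into zeros and carries into the position of x's lowest set bit, while every bit above it stays inverted, so the AND leaves exactly that one bit.

```cpp
#include <cstdint>

// Hypothetical name for the same expression as lsb_() above.
constexpr std::uint32_t isolate_lowest_bit(std::uint32_t x) {
  return x & (~x + 1u);
}

//      x        = ...0110 0100   (100)
//      ~x       = ...1001 1011
//      ~x + 1   = ...1001 1100
//      x & that = ...0000 0100   (4 = 2^2, so log2 gives 2)
static_assert(isolate_lowest_bit(100u) == 4u, "lowest set bit of 100 is 2^2");
static_assert(isolate_lowest_bit(0u) == 0u, "zero has no set bit");
```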

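...but is log2() O(1)?

It turns out there is a way to do it in O(1). The algorithm uses something called a De Bruijn sequence: the input is multiplied by a constant chosen so that the top bits of the product are different for every possible answer, and a small table maps those bits back to the bit position.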

De Bruijn sequence

#include <cstdint>

constexpr int log2(uint64_t n){
 constexpr uint8_t tab64[64] = {
   0, 58,  1, 59, 47, 53,  2, 60,
  39, 48, 27, 54, 33, 42,  3, 61,
  51, 37, 40, 49, 18, 28, 20, 55,
  30, 34, 11, 43, 14, 22,  4, 62,
  57, 46, 52, 38, 26, 32, 41, 50,
  36, 17, 19, 29, 10, 13, 21, 56,
  45, 25, 31, 35, 16,  9, 12, 44,
  24, 15,  8, 23,  7,  6,  5, 63
 };

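 // Smear the highest set bit downwards: n becomes 2^(k+1) - 1,
 // where k = floor(log2(n)).
 n |= n >>  1;
 n |= n >>  2;
 n |= n >>  4;
 n |= n >>  8;
 n |= n >> 16;
 n |= n >> 32;

 // The top 6 bits of the product are different for each k,
 // so they can index the 64-entry table.
 auto const pos = (n  * 0x03F6EAF2CD271461) >> 58;

 return tab64[pos];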
}


constexpr int log2(uint32_t n){
 constexpr uint8_t tab32[32] = {
   0,  9,  1, 10, 13, 21,  2, 29,
  11, 14, 16, 18, 22, 25,  3, 30,
   8, 12, 20, 28, 15, 17, 24,  7,
  19, 27, 23,  6, 26,  5,  4, 31
 };

 // Same smearing as above, for 32 bits.
 n |= n >>  1;
 n |= n >>  2;
 n |= n >>  4;
 n |= n >>  8;
 n |= n >> 16;

 // The top 5 bits of the product index the 32-entry table.
 auto const pos = (n * 0x07C4ACDD ) >> 27;

 return tab32[pos];
}

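// Isolate the lowest set bit (~x + 1 is -x in two's complement).
template<typename T>
constexpr T lsb_(T const x){
 return x & ( ~x + 1 );
}

// Only uint32_t and uint64_t have a matching log2() overload above.
template<typename T>
constexpr int lsb(T const x){
 return log2(lsb_(x));
}
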
Wait, the code complexity is not O(1)!

If you look at the "n |= n >> 1" lines, they are nothing but an unrolled loop: one shift for each halving of the word width.

So the complexity is at best O(log n), where n is the number of bits.
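Note, though, that lsb() only ever passes log2() a value with a single bit set, so for lsb() the smearing step can be skipped: multiplying 2^k by a De Bruijn constant is just a left shift by k, and the top bits of the result are unique for each k. A minimal 32-bit sketch (C++14 or later), with hypothetical names, using the commonly used De Bruijn constant 0x077CB531 and building the table at compile time instead of hard-coding it:

```cpp
#include <cstdint>

// Hypothetical names; a sketch of count-trailing-zeros for 32-bit values.
// 0x077CB531 is a De Bruijn sequence B(2,5): every 5-bit window is unique.
constexpr std::uint32_t kDeBruijn32 = 0x077CB531u;

// Table built at compile time: for each bit position k, the top 5 bits
// of (kDeBruijn32 << k) are used as the index that stores k.
struct DeBruijnTable32 {
  std::uint8_t pos[32] = {};

  constexpr DeBruijnTable32() {
    for (int k = 0; k < 32; ++k)
      pos[static_cast<std::uint32_t>(kDeBruijn32 << k) >> 27] =
          static_cast<std::uint8_t>(k);
  }
};

constexpr DeBruijnTable32 kTable32{};

// Index of the least significant set bit; returns 0 for x == 0,
// like the naive version.
constexpr int lsb_debruijn(std::uint32_t const x) {
  // x & (~x + 1) keeps only the lowest set bit, i.e. 2^k, and
  // 2^k * kDeBruijn32 equals kDeBruijn32 << k (mod 2^32).
  return kTable32.pos[static_cast<std::uint32_t>((x & (~x + 1u)) * kDeBruijn32) >> 27];
}

static_assert(lsb_debruijn(100u) == 2, "100 = 0b1100100");
static_assert(lsb_debruijn(0x80000000u) == 31, "highest bit");
```

This is one multiply, one shift and one table read regardless of which bit is set. A 64-bit version would work the same way with a 64-bit De Bruijn constant, a 64-entry table and a shift of 58.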

Time for testing

For testing I used something like this code:

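#include <cstdint>
#include <iostream>

// lsb() is one of the implementations above.
int main(){
 int x = 0;

 for(uint64_t i = 0; i <= 100000000; ++i)
  x += lsb(i);

 // Print the sum so the results are used and the work is not discarded.
 std::cout << x << '\n';
}
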
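Then I ran it under time. Here are the results:

Method    | Time
Naive     | 0.452 sec
Lookup    | 0.550 sec
De Bruijn | 2.394 sec

The De Bruijn solution is the slowest, and the naive solution is the fastest.

Part of the reason is probably the test input: among consecutive integers, half are odd and another quarter have their lowest set bit at position 1, so the naive loop usually stops after one or two iterations.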
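One way to check how much the input distribution matters is to time each lsb() on inputs whose lowest set bit is spread evenly over all 64 positions, rather than on consecutive integers (which are mostly odd or have bit 1 set). A sketch with hypothetical helper names, not measured here:

```cpp
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical helper: n inputs whose lowest set bit position is
// uniformly distributed over 0..63.
std::vector<std::uint64_t> make_inputs(std::size_t n, std::uint64_t seed) {
  std::mt19937_64 gen{seed};
  std::uniform_int_distribution<int> shift{0, 63};
  std::vector<std::uint64_t> v(n);
  for (auto& x : v)
    x = (gen() | 1u) << shift(gen);   // lowest set bit is exactly at the shift
  return v;
}

// Hypothetical helper: runs f over the inputs and returns the elapsed
// seconds; the sum of results goes to `sink` so the calls are not dropped.
template <typename F>
double time_lsb(F f, std::vector<std::uint64_t> const& v, long long& sink) {
  auto const start = std::chrono::steady_clock::now();
  long long sum = 0;
  for (auto x : v) sum += f(x);
  auto const stop = std::chrono::steady_clock::now();
  sink = sum;
  return std::chrono::duration<double>(stop - start).count();
}
```

For example, build inputs = make_inputs(10000000, 1) once and call time_lsb([](std::uint64_t x){ return lsb(x); }, inputs, sink) for each version of lsb(). Reading the inputs from memory adds its own cost, so compare the versions against each other rather than against the numbers above.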

Bonus

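  • ffs() is a POSIX function declared in "strings.h" (glibc also makes it visible through "string.h" / "cstring"). It returns a 1-based position, and 0 when no bit is set. GCC additionally has built-ins such as __builtin_ffs() and __builtin_ctz().
  • On Linux, glibc also provides ffsl() and ffsll() for wider types, and the kernel has its own ffs() / __ffs() helpers.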
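For comparison, a small sketch of these library routes, assuming a POSIX system with GCC or Clang: ffs() (1-based, so subtract 1), the GCC built-in __builtin_ctzll() (undefined for 0), and, if C++20 is available, std::countr_zero() from <bit> (returns the bit width for 0). The wrapper names are hypothetical.

```cpp
#include <cstdint>
#include <strings.h>  // POSIX ffs()
#if __cplusplus >= 202002L
#include <bit>        // std::countr_zero (C++20)
#endif

// Hypothetical wrappers: each returns the 0-based index of the lowest
// set bit for non-zero x.

// ffs() returns a 1-based position and 0 when no bit is set,
// so this wrapper returns -1 for x == 0.
inline int lsb_ffs(unsigned int const x) {
  return ffs(static_cast<int>(x)) - 1;
}

#if defined(__GNUC__)
// GCC/Clang built-in; the result is undefined for x == 0.
inline int lsb_builtin(unsigned long long const x) {
  return __builtin_ctzll(x);
}
#endif

#if __cplusplus >= 202002L
// Defined for every input: std::countr_zero of a zero uint64_t is 64.
inline int lsb_std(std::uint64_t const x) {
  return std::countr_zero(x);
}
#endif
```

The built-ins and std::countr_zero usually compile down to a single instruction on common CPUs, which makes them a useful baseline for the hand-written versions above.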