Convert __m256i register to uint64_t bitmask such that each byte's value is the position of a set bit in the output

Basically I have an __m256i variable where each byte holds a bit position that needs to be set in a uint64_t. Note all byte values will be < 64. For example, if the bytes were all 3 except for a single 7, the result would be (1 << 3) | (1 << 7) = 0x88.

I'm at somewhat of a loss for how to do this even remotely efficiently.

One option I was considering: in certain circumstances there are a lot of duplicates among the bytes, so something along these lines might work:

__m256i indexes = foo();

// _mm256_extract_epi8 needs a compile-time lane index, so spill the bytes
// to memory once and pick them out with a variable index instead.
alignas(32) uint8_t index_arr[32];
_mm256_store_si256((__m256i *)index_arr, indexes);

uint64_t result         = 0;
uint32_t aggregate_mask = ~0u;
do {
    uint32_t idx = index_arr[_tzcnt_u32(aggregate_mask)];

    uint32_t idx_mask =
        _mm256_movemask_epi8(_mm256_cmpeq_epi8(indexes, _mm256_set1_epi8((char)idx)));
    aggregate_mask ^= idx_mask; // clear every lane holding this value
    result |= 1ULL << idx;
} while (aggregate_mask);

With enough duplicates I believe this could be somewhat efficient, but I can't guarantee there will always be enough duplicates to make it faster than just iterating through the bytes and setting each bit sequentially.

My goal is to find something that will ALWAYS be faster than what feels like the worst case:

__m256i indexes = foo();
alignas(32) uint8_t index_arr[32];
_mm256_store_si256((__m256i *)index_arr, indexes);

uint64_t result = 0;
for (uint32_t i = 0; i < 32; ++i) {
    result |= 1ULL << index_arr[i];
}

If possible I am looking for a solution that can run on Skylake (without AVX512). If AVX512 is necessary (I was thinking there might be something semi-efficient with grouping then using _mm256_shldv_epi16), something is better than nothing :)
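(For reference, here is a minimal AVX512F sketch of the straightforward widen-and-shift route, rather than the shldv grouping idea; the function name and the store-then-reload step are illustrative assumptions:)

    #include <immintrin.h>
    #include <cstdint>

    // Hypothetical AVX512F helper: widen 8 bytes at a time to 64-bit
    // shift counts, variable-shift a 1 into place, then OR-reduce.
    static inline uint64_t bits_from_32_avx512(__m256i indexes) {
        alignas(32) uint8_t b[32];
        _mm256_store_si256((__m256i *)b, indexes);

        const __m512i one = _mm512_set1_epi64(1);
        __m512i acc = _mm512_setzero_si512();
        for (int i = 0; i < 32; i += 8) {
            // Zero-extend the next 8 bytes to 8 x uint64.
            __m512i idx = _mm512_cvtepu8_epi64(_mm_loadl_epi64((const __m128i *)(b + i)));
            acc = _mm512_or_si512(acc, _mm512_sllv_epi64(one, idx));
        }
        // _mm512_reduce_or_epi64 expands to a short shuffle/OR sequence.
        return (uint64_t)_mm512_reduce_or_epi64(acc);
    }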

This is what I am thinking. Going from epi32:

    // 32 bit: assumes `indexes` holds 8 zero-extended uint32 positions.
    // Low half of each result qword: 1 << idx, which vpsllvd zeroes out
    // whenever idx >= 32.
    __m256i lo_shifts = _mm256_sllv_epi32(_mm256_set1_epi32(1), indexes);
    // High half: 1 << (idx - 32); for idx < 32 the count wraps past 31
    // and the lane is zeroed.
    __m256i t0 = _mm256_sub_epi32(indexes, _mm256_set1_epi32(32));
    __m256i hi_shifts = _mm256_sllv_epi32(_mm256_set1_epi32(1), t0);
    // Broadcast each odd 32-bit element over its even neighbor;
    // the shuffle immediate is only 8 bits: _MM_SHUFFLE(3,3,1,1) = 0xF5.
    __m256i lo_shifts_lo = _mm256_shuffle_epi32(lo_shifts, 0xF5);
    __m256i hi_shifts_lo = _mm256_shuffle_epi32(hi_shifts, 0xF5);

    __m256i hi_shifts_hi0 = _mm256_slli_epi64(hi_shifts, 32);
    __m256i hi_shifts_hi1 = _mm256_slli_epi64(hi_shifts_lo, 32);
    __m256i all_hi_shifts = _mm256_or_si256(hi_shifts_hi0, hi_shifts_hi1);

    __m256i all_lo_shifts_garbage = _mm256_or_si256(lo_shifts_lo, lo_shifts);
    __m256i all_lo_shifts = _mm256_and_si256(all_lo_shifts_garbage, _mm256_set1_epi64x(0xffffffff));

    __m256i all_shifts = _mm256_or_si256(all_lo_shifts, all_hi_shifts);

or going from epi64:

    // 64 bit: assumes each 64-bit lane packs two zero-extended uint32 positions.
    __m256i indexes0 = _mm256_and_si256(indexes, _mm256_set1_epi64x(0xffffffff));
    // vpsllvq consumes the full 64-bit count, so clear the high half with a
    // logical shift; a shuffle would leave garbage in the upper 32 bits.
    __m256i indexes1 = _mm256_srli_epi64(indexes, 32);

    __m256i shifts0 = _mm256_sllv_epi64(_mm256_set1_epi64x(1), indexes0);
    __m256i shifts1 = _mm256_sllv_epi64(_mm256_set1_epi64x(1), indexes1);

    __m256i all_shifts = _mm256_or_si256(shifts0, shifts1);

My guess is the epi64 version is faster.
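For what it's worth, here is a minimal sketch of how the epi64 variant might be driven over all 32 bytes, assuming the same headers as the sketch above; the widening via _mm256_cvtepu8_epi32, the final horizontal OR, and the function name are illustrative assumptions, since the snippet only covers indices that are already widened:

    // Hypothetical driver for the epi64 idea: 8 bytes per iteration,
    // two 32-bit indices per 64-bit lane. AVX2 only.
    static inline uint64_t bits_from_32_avx2(__m256i indexes) {
        alignas(32) uint8_t b[32];
        _mm256_store_si256((__m256i *)b, indexes);

        const __m256i one = _mm256_set1_epi64x(1);
        __m256i acc = _mm256_setzero_si256();
        for (int i = 0; i < 32; i += 8) {
            // Zero-extend the next 8 bytes to 8 x uint32.
            __m256i idx = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(b + i)));
            __m256i lo  = _mm256_and_si256(idx, _mm256_set1_epi64x(0xffffffff));
            __m256i hi  = _mm256_srli_epi64(idx, 32);
            acc = _mm256_or_si256(acc, _mm256_sllv_epi64(one, lo));
            acc = _mm256_or_si256(acc, _mm256_sllv_epi64(one, hi));
        }
        // OR the four 64-bit lanes down to a scalar.
        __m128i r = _mm_or_si128(_mm256_castsi256_si128(acc),
                                 _mm256_extracti128_si256(acc, 1));
        r = _mm_or_si128(r, _mm_unpackhi_epi64(r, r));
        return (uint64_t)_mm_cvtsi128_si64(r);
    }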

c++
simd
avx
micro-optimization
avx2
asked on Stack Overflow Sep 5, 2020 by Noah • edited Sep 5, 2020 by Noah

1 Answer

The key ingredient is _mm256_sllv_epi64 to shift bits within 64-bit lanes, using runtime-variable shift distances.
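As a quick illustration of those semantics (a toy example, not taken from the code below): each 64-bit lane of the first operand is shifted left by the count in the matching lane of the second operand.

// Lanes become 1<<0, 1<<15, 1<<33, 1<<63, in a single instruction.
const __m256i counts = _mm256_setr_epi64x( 0, 15, 33, 63 );
const __m256i bits = _mm256_sllv_epi64( _mm256_set1_epi64x( 1 ), counts );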

The code requires C++17; only tested in VC++ 2019.

Not sure it's going to be significantly faster than scalar code, though. The majority of the instructions have 1-cycle latency, but there are too many of them for my taste: VC++ produced about 35 on the critical path.

#include <array>
#include <cinttypes>
#include <cstdio>
#include <immintrin.h>

// Move a single bit within 64-bit lanes
template<int index>
inline __m256i moveBit( __m256i position )
{
    static_assert( index >= 0 && index < 8 );

    // Extract index-th byte from the operand
    if constexpr( 7 == index )
    {
        // Most significant byte only needs 1 instruction to shift into position
        position = _mm256_srli_epi64( position, 64 - 8 );
    }
    else
    {
        if constexpr( index > 0 )
        {
            // Shift the operand by `index` bytes to the right.
            // On many CPUs, _mm256_srli_si256 is slightly faster than _mm256_srli_epi64
            position = _mm256_srli_si256( position, index );
        }
        const __m256i lowByte = _mm256_set1_epi64x( 0xFF );
        position = _mm256_and_si256( position, lowByte );
    }
    const __m256i one = _mm256_set1_epi64x( 1 );
    return _mm256_sllv_epi64( one, position );
}

inline uint64_t setBitsAvx2( __m256i positions )
{
    // Process each of the 8 bytes within 64-bit lanes
    const __m256i r0 = moveBit<0>( positions );
    const __m256i r1 = moveBit<1>( positions );
    const __m256i r2 = moveBit<2>( positions );
    const __m256i r3 = moveBit<3>( positions );
    const __m256i r4 = moveBit<4>( positions );
    const __m256i r5 = moveBit<5>( positions );
    const __m256i r6 = moveBit<6>( positions );
    const __m256i r7 = moveBit<7>( positions );
    // The vpor instruction is very fast, with 1 cycle latency;
    // however, since modern CPUs can issue and dispatch multiple instructions per cycle,
    // it still makes sense to try reducing dependencies.
    const __m256i r01 = _mm256_or_si256( r0, r1 );
    const __m256i r23 = _mm256_or_si256( r2, r3 );
    const __m256i r45 = _mm256_or_si256( r4, r5 );
    const __m256i r67 = _mm256_or_si256( r6, r7 );
    const __m256i r0123 = _mm256_or_si256( r01, r23 );
    const __m256i r4567 = _mm256_or_si256( r45, r67 );
    const __m256i result = _mm256_or_si256( r0123, r4567 );

    // Reduce 4 8-byte values to scalar
    const __m128i res16 = _mm_or_si128( _mm256_castsi256_si128( result ), _mm256_extracti128_si256( result, 1 ) );
    const __m128i res8 = _mm_or_si128( res16, _mm_unpackhi_epi64( res16, res16 ) );
    return (uint64_t)_mm_cvtsi128_si64( res8 );
}

inline uint64_t setBitsScalar( __m256i positions )
{
    alignas( 32 ) std::array<uint8_t, 32> index_arr;
    _mm256_store_si256( ( __m256i * )index_arr.data(), positions );

    uint64_t result = 0;
    for( uint32_t i = 0; i < 32; i++ )
        result |= ( ( 1ull ) << index_arr[ i ] );
    return result;
}

static void testShuffleBits()
{
    const __m128i src16 = _mm_setr_epi8( 0, 0, 0, 0, 1, 4, 5, 10, 11, 12, 13, 14, 15, 16, 17, 31 );
    const __m256i src32 = _mm256_setr_m128i( src16, _mm_setzero_si128() );
    // Both lines should print 8003fc33: bits 0, 1, 4, 5, 10-17, and 31 set.
    printf( "AVX2: %" PRIx64 "\n", setBitsAvx2( src32 ) );
    printf( "Scalar: %" PRIx64 "\n", setBitsScalar( src32 ) );
}
answered on Stack Overflow Sep 6, 2020 by Soonts • edited Sep 6, 2020 by Soonts

User contributions licensed under CC BY-SA 3.0