Is `__shfl_sync` broken for 64-bit?

1

I have implemented a warp-wide and block-wide reduction using shuffle instructions. Everything works fine when I use 32-bit types, but for 64-bit I always get 0 as a result. To my knowledge, shuffling supports 64-bit arguments. What am I missing?

#include <stdio.h>

template<typename T>
inline __device__ T warpRegSumTest(T val) {
  T result = val;
  static constexpr unsigned mask = 0xffffffff;
#pragma unroll
  for (int delta = 16; delta > 0; delta /= 2) {
    result = result + __shfl_down_sync(mask, result, delta);
  }
  return result;
}

template<int numWarpsInBlock, typename T>
inline __device__ T blockRegSumTest(T val) {
  __shared__ T part[numWarpsInBlock];
  T warppart = warpRegSumTest(val);
  if (threadIdx.x % 32 == 0) {
    part[threadIdx.x / 32] = warppart;
  }
  __syncthreads();
  if (threadIdx.x < 32) {
    int tid = threadIdx.x;
    T solution = warpRegSumTest(tid < numWarpsInBlock ? part[tid] : T(0));
    __syncwarp();
    part[0] = solution;
  }

  __syncthreads();
  T result = part[0];
  __syncthreads();
  return result;
}

__global__ void testKernel() {
  float float_result = blockRegSumTest<256 / 32>(float(threadIdx.x));
  if (threadIdx.x == 0) {
    printf("Float sum: %f\n", float_result);
  }
  double double_result = blockRegSumTest<256 / 32>(double(threadIdx.x));
  if (threadIdx.x == 0) {
    printf("Double sum: %f\n", double_result);
  }
  int int_result = blockRegSumTest<256 / 32>(int(threadIdx.x));
  if (threadIdx.x == 0) {
    printf("Int sum: %d\n", int_result);
  }
  long long longlong_result = blockRegSumTest<256 / 32>(long long(threadIdx.x));
  if (threadIdx.x == 0) {
    printf("Long long sum: %lld\n", longlong_result);
  }
}

int main()
{
  testKernel << <1, 256 >> > ();
}

I am compiling this with compute_70,sm_70 and running on GTX 2070 SUPER. It outputs:

Float sum: 32640.000000
Double sum: 0.000000
Int sum: 32640
Long long sum: 0

I expected to see 32640 (the sum 0+1+2+...+255) in all 4 cases.

cuda
asked on Stack Overflow Feb 18, 2020 by CygnusX1

1 Answer

3

You've got an error here:

part[0] = solution;

it should be:

if (!threadIdx.x) part[0] = solution;

You only want thread 0 to execute that line.

answered on Stack Overflow Feb 18, 2020 by Robert Crovella

User contributions licensed under CC BY-SA 3.0