Issue 127054
Summary x86/avx512 vperm optimizations can produce incorrect permutations
Labels new issue
Assignees
Reporter jlainema
    // This code should not report any mismatches when run on a CPU with AVX512F support:

#include <cstdint>
#include <cstdlib>
#include <cstdio>
#include <immintrin.h>

alignas(64) static const __m512i swizzles_512[10] = {
  _mm512_set_epi32( 0, 1, 2, 3, 4, 5, 6, 7,16,17,18,19,20,21,22,23),
  _mm512_set_epi32( 8, 9,10,11,12,13,14,15,24,25,26,27,28,29,30,31),
  _mm512_set_epi32( 0, 1, 2, 3,16,17,18,19, 8, 9,10,11,24,25,26,27),
  _mm512_set_epi32( 4, 5, 6, 7,20,21,22,23,12,13,14,15,28,29,30,31),
  _mm512_set_epi32( 0, 1,16,17, 4, 5,20,21, 8, 9,24,25,12,13,28,29),
  _mm512_set_epi32( 2, 3,18,19, 6, 7,22,23,10,11,26,27,14,15,30,31),
  _mm512_set_epi32( 0,16, 2,18, 4,20, 6,22, 8,24,10,26,12,28,14,30),
  _mm512_set_epi32( 1,17, 3,19, 5,21, 7,23, 9,25,11,27,13,29,15,31),
  _mm512_set_epi32( 0,16, 1,17, 2,18, 3,19, 4,20, 5,21, 6,22, 7,23),
  _mm512_set_epi32( 8,24, 9,25,10,26,11,27,12,28,13,29,14,30,15,31)
};

// Scalar index table, one row per index vector, each row listed in lane
// order (swizzles[n][0] is lane 0). reorder_gather512 uses each row as the
// 32-bit gather index vector into its 32-element buffer, and main() compares
// every row against swizzles_512[n]. Rows 2k and 2k+1 together cover each
// of 0..31 exactly once.
alignas(64) uint32_t swizzles[10][16] = {
  /* 0 */ { 0,  1,  2,  3,  4,  5,  6,  7, 16, 17, 18, 19, 20, 21, 22, 23},
  /* 1 */ { 8,  9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31},
  /* 2 */ { 0,  1,  2,  3, 16, 17, 18, 19,  8,  9, 10, 11, 24, 25, 26, 27},
  /* 3 */ { 4,  5,  6,  7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31},
  /* 4 */ { 0,  1, 16, 17,  4,  5, 20, 21,  8,  9, 24, 25, 12, 13, 28, 29},
  /* 5 */ { 2,  3, 18, 19,  6,  7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31},
  /* 6 */ { 0, 16,  2, 18,  4, 20,  6, 22,  8, 24, 10, 26, 12, 28, 14, 30},
  /* 7 */ { 1, 17,  3, 19,  5, 21,  7, 23,  9, 25, 11, 27, 13, 29, 15, 31},
  /* 8 */ { 0, 16,  1, 17,  2, 18,  3, 19,  4, 20,  5, 21,  6, 22,  7, 23},
  /* 9 */ { 8, 24,  9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31},
};

// Memory-based variant of the index network: each round gathers two 16-lane
// vectors out of the 32-element buffer using an index pair from swizzles[][],
// then writes their lane-wise minimum to _a[0..15] and maximum to
// _a[16..31]. A fifth round only reshuffles (no min/max).
// main() compares its output against reorder_perm512's.
//   _a: 32 uint32_t values, rewritten in place. Must be 64-byte aligned
//       (aligned 512-bit stores are used). Values are compared as signed
//       32-bit integers.
void reorder_gather512(uint32_t *_a) {
  for (unsigned round = 0; round < 4; ++round) {
    const __m512i lo_idx = _mm512_load_epi32(swizzles[2 * round]);
    const __m512i hi_idx = _mm512_load_epi32(swizzles[2 * round + 1]);
    const __m512i lhs = _mm512_i32gather_epi32(lo_idx, _a, 4);
    const __m512i rhs = _mm512_i32gather_epi32(hi_idx, _a, 4);
    _mm512_store_epi32(_a, _mm512_min_epi32(lhs, rhs));
    _mm512_store_epi32(_a + 16, _mm512_max_epi32(lhs, rhs));
  }

  // Final round: pure permutation of the buffer, no compare/exchange.
  const __m512i first = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[8]), _a, 4);
  const __m512i second = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[9]), _a, 4);
  _mm512_store_epi32(_a, first);
  _mm512_store_epi32(_a + 16, second);
}

// Selects where reorder_perm512 gets its permutation index vectors from.
// Active definition: the precomputed __m512i constants in swizzles_512.
// Alternative (commented out): an aligned load of the same row from the
// uint32_t table swizzles at each use.
// Both are meant to supply identical lane values, so switching between them
// should not change any output.
// NOTE(review): confirm the two tables agree lane for lane before blaming the
// compiler for a difference. The _mm512_set_epi32 / _mm512_setr_epi32
// argument order (highest-lane-first vs lowest-lane-first) is easy to get
// backwards.
#define PERM_ACCESS(n) swizzles_512[n]
// #define PERM_ACCESS(n) _mm512_load_epi32(swizzles[n])

// Register-only variant of the index network. The 32 values live in two
// zmm registers (lo = _a[0..15], hi = _a[16..31]). Each round permutes the
// pair with two index vectors from PERM_ACCESS, then keeps the lane-wise
// minimum in lo and maximum in hi. A final pair of permutes produces the
// stored output.
//   _a: 32 uint32_t values, rewritten in place. Must be 64-byte aligned
//       (aligned 512-bit loads/stores are used). Values are compared as
//       signed 32-bit integers.
void reorder_perm512(uint32_t *_a) {
  __m512i lo = _mm512_load_epi32(_a);
  __m512i hi = _mm512_load_epi32(_a + 16);

  for (int round = 0; round < 4; ++round) {
    const __m512i p = _mm512_permutex2var_epi32(lo, PERM_ACCESS(2 * round), hi);
    const __m512i q = _mm512_permutex2var_epi32(lo, PERM_ACCESS(2 * round + 1), hi);
    lo = _mm512_min_epi32(p, q);
    hi = _mm512_max_epi32(p, q);
  }

  // Final reshuffle straight into memory, no compare/exchange.
  _mm512_store_epi32(_a, _mm512_permutex2var_epi32(lo, PERM_ACCESS(8), hi));
  _mm512_store_epi32(_a + 16, _mm512_permutex2var_epi32(lo, PERM_ACCESS(9), hi));
}

// Cross-checks reorder_perm512 (register permutes driven by swizzles_512)
// against reorder_gather512 (memory gathers driven by swizzles). Both run on
// the same rand()-generated input, and every disagreement is printed.
// Returns EXIT_SUCCESS when both index tables and all 32 outputs agree,
// EXIT_FAILURE otherwise.
int main() {
  alignas(64) uint32_t aval[32], bval[32];

  for (unsigned i = 0; i < 32; i++) aval[i] = bval[i] = rand();

  reorder_perm512(aval);
  reorder_gather512(bval);

  bool ok = true;

  // Compare the two index tables lane by lane. A sum of lane differences is
  // not sufficient: a reordered vector (e.g. lanes reversed by mixing up
  // _mm512_set_epi32 and _mm512_setr_epi32) has differences that cancel out.
  for (unsigned i = 0; i < 10; i++) {
    const __mmask16 diff =
        _mm512_cmpneq_epi32_mask(swizzles_512[i], _mm512_load_epi32(swizzles[i]));
    if (diff) {
      printf("index mismatch for register %u (lane mask %04x)\n", i,
             (unsigned)diff);
      ok = false;
    }
  }

  for (unsigned i = 0; i < 32; i++) {
    if (aval[i] != bval[i]) {
      printf("gather/permute mismatch %08x != %08x (@%u)\n", aval[i], bval[i], i);
      ok = false;
    }
  }

  return ok ? EXIT_SUCCESS : EXIT_FAILURE;
}
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs

Reply via email to