Issue |
127054
|
Summary |
x86/avx512 vperm optimizations can produce incorrect permutations
|
Labels |
new issue
|
Assignees |
|
Reporter |
jlainema
|
// this code should not report mismatches when run on a cpu with AVX512F support:
#include <cstdint>
#include <cstdlib>
#include <cstdio>
#include <immintrin.h>
alignas(64) static const __m512i swizzles_512[10] = {
_mm512_set_epi32( 0, 1, 2, 3, 4, 5, 6, 7,16,17,18,19,20,21,22,23),
_mm512_set_epi32( 8, 9,10,11,12,13,14,15,24,25,26,27,28,29,30,31),
_mm512_set_epi32( 0, 1, 2, 3,16,17,18,19, 8, 9,10,11,24,25,26,27),
_mm512_set_epi32( 4, 5, 6, 7,20,21,22,23,12,13,14,15,28,29,30,31),
_mm512_set_epi32( 0, 1,16,17, 4, 5,20,21, 8, 9,24,25,12,13,28,29),
_mm512_set_epi32( 2, 3,18,19, 6, 7,22,23,10,11,26,27,14,15,30,31),
_mm512_set_epi32( 0,16, 2,18, 4,20, 6,22, 8,24,10,26,12,28,14,30),
_mm512_set_epi32( 1,17, 3,19, 5,21, 7,23, 9,25,11,27,13,29,15,31),
_mm512_set_epi32( 0,16, 1,17, 2,18, 3,19, 4,20, 5,21, 6,22, 7,23),
_mm512_set_epi32( 8,24, 9,25,10,26,11,27,12,28,13,29,14,30,15,31)
};
alignas(64) uint32_t swizzles[10][16] = {
{ 0, 1, 2, 3, 4, 5, 6, 7,16,17,18,19,20,21,22,23},{ 8, 9,10,11,12,13,14,15,24,25,26,27,28,29,30,31},
{ 0, 1, 2, 3,16,17,18,19, 8, 9,10,11,24,25,26,27},{ 4, 5, 6, 7,20,21,22,23,12,13,14,15,28,29,30,31},
{ 0, 1,16,17, 4, 5,20,21, 8, 9,24,25,12,13,28,29},{ 2, 3,18,19, 6, 7,22,23,10,11,26,27,14,15,30,31},
{ 0,16, 2,18, 4,20, 6,22, 8,24,10,26,12,28,14,30},{ 1,17, 3,19, 5,21, 7,23, 9,25,11,27,13,29,15,31},
{ 0,16, 1,17, 2,18, 3,19, 4,20, 5,21, 6,22, 7,23},{ 8,24, 9,25,10,26,11,27,12,28,13,29,14,30,15,31}};
void reorder_gather512(uint32_t *_a) {
__m512i a,b;
a = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[0]), _a, 4);
b = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[1]), _a, 4);
_mm512_store_epi32(_a,_mm512_min_epi32(a,b));
_mm512_store_epi32(_a+16,_mm512_max_epi32(a,b));
a = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[2]), _a, 4);
b = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[3]), _a, 4);
_mm512_store_epi32(_a,_mm512_min_epi32(a,b));
_mm512_store_epi32(_a+16,_mm512_max_epi32(a,b));
a = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[4]), _a, 4);
b = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[5]), _a, 4);
_mm512_store_epi32(_a,_mm512_min_epi32(a,b));
_mm512_store_epi32(_a+16,_mm512_max_epi32(a,b));
a = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[6]), _a, 4);
b = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[7]), _a, 4);
_mm512_store_epi32(_a,_mm512_min_epi32(a,b));
_mm512_store_epi32(_a+16,_mm512_max_epi32(a,b));
a = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[8]), _a, 4);
b = _mm512_i32gather_epi32(_mm512_load_epi32(swizzles[9]), _a, 4);
_mm512_store_epi32(_a,a);
_mm512_store_epi32(_a+16,b);
}
#define PERM_ACCESS(n) swizzles_512[n]
// #define PERM_ACCESS(n) _mm512_load_epi32(swizzles[n])
void reorder_perm512(uint32_t *_a) {
__m512i a = _mm512_load_epi32(_a), b = _mm512_load_epi32(_a+16);
__m512i x,y;
x = _mm512_permutex2var_epi32(a,PERM_ACCESS(0),b);
y = _mm512_permutex2var_epi32(a,PERM_ACCESS(1),b);
a = _mm512_min_epi32(x,y);
b = _mm512_max_epi32(x,y);
x = _mm512_permutex2var_epi32(a,PERM_ACCESS(2),b);
y = _mm512_permutex2var_epi32(a,PERM_ACCESS(3),b);
a = _mm512_min_epi32(x,y);
b = _mm512_max_epi32(x,y);
x = _mm512_permutex2var_epi32(a,PERM_ACCESS(4),b);
y = _mm512_permutex2var_epi32(a,PERM_ACCESS(5),b);
a = _mm512_min_epi32(x,y);
b = _mm512_max_epi32(x,y);
x = _mm512_permutex2var_epi32(a,PERM_ACCESS(6),b);
y = _mm512_permutex2var_epi32(a,PERM_ACCESS(7),b);
a = _mm512_min_epi32(x,y);
b = _mm512_max_epi32(x,y);
_mm512_store_epi32(_a,_mm512_permutex2var_epi32(a,PERM_ACCESS(8),b));
_mm512_store_epi32(_a+16,_mm512_permutex2var_epi32(a,PERM_ACCESS(9),b));
}
int main() {
alignas(64) uint32_t aval[32], bval[32];
for(unsigned i=0; i<32; i++) aval[i] = bval[i] = rand();
reorder_perm512(aval);
reorder_gather512(bval);
for (unsigned i=0; i<10; i++) {
__m512i dx = _mm512_sub_epi32(swizzles_512[i], _mm512_load_epi32(swizzles[i]));
if (_mm512_reduce_add_epi32(dx))
printf("index mismatch for register %d\n", i);
}
for (unsigned i=0; i<32; i++) {
if (aval[i] != bval[i]) {
printf("gather/permute mismatch %08x != %08x (@%d)\n", aval[i], bval[i], i);
}
}
return 0;
}
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs