// magyarsort/simd-sort/avx2-nate-quicksort.cpp
// (source-viewer metadata: 437 lines, 14 KiB, C++)

#pragma once
#include <x86intrin.h>
#include <cstdint>
namespace nate {
// Number of 32-bit lanes in one 256-bit AVX2 register.
static const int AVX2_REGISTER_SIZE = 256 / 32;
/*
Vectorized partition algorithm by Nathan Kurz.
Everything else was borrowed from Daniel's code (avx2-altquicksort.h).
*/
// Permutation table for _mm256_permutevar8x32_epi32, indexed by an 8-bit
// comparison mask (bit k set <=> lane k compared greater than the pivot).
// Row m (entries [8*m, 8*m + 8)) lists the source lanes whose mask bit is
// clear first, then the lanes whose bit is set, each group in ascending lane
// order.  Permuting a register with row m therefore moves the "<= pivot"
// values to the low lanes and the "> pivot" values to the high lanes.
//
// Each row is 8 * 4 = 32 bytes and the table is 256-byte aligned, so every
// row can be fetched with an aligned 256-bit load.  The table is only ever
// read, hence const.
static const uint32_t reverseshufflemask[256 * 8] __attribute__((aligned(0x100))) = {
    0, 1, 2, 3, 4, 5, 6, 7, /* 0*/
    1, 2, 3, 4, 5, 6, 7, 0, /* 1*/
    0, 2, 3, 4, 5, 6, 7, 1, /* 2*/
    2, 3, 4, 5, 6, 7, 0, 1, /* 3*/
    0, 1, 3, 4, 5, 6, 7, 2, /* 4*/
    1, 3, 4, 5, 6, 7, 0, 2, /* 5*/
    0, 3, 4, 5, 6, 7, 1, 2, /* 6*/
    3, 4, 5, 6, 7, 0, 1, 2, /* 7*/
    0, 1, 2, 4, 5, 6, 7, 3, /* 8*/
    1, 2, 4, 5, 6, 7, 0, 3, /* 9*/
    0, 2, 4, 5, 6, 7, 1, 3, /* 10*/
    2, 4, 5, 6, 7, 0, 1, 3, /* 11*/
    0, 1, 4, 5, 6, 7, 2, 3, /* 12*/
    1, 4, 5, 6, 7, 0, 2, 3, /* 13*/
    0, 4, 5, 6, 7, 1, 2, 3, /* 14*/
    4, 5, 6, 7, 0, 1, 2, 3, /* 15*/
    0, 1, 2, 3, 5, 6, 7, 4, /* 16*/
    1, 2, 3, 5, 6, 7, 0, 4, /* 17*/
    0, 2, 3, 5, 6, 7, 1, 4, /* 18*/
    2, 3, 5, 6, 7, 0, 1, 4, /* 19*/
    0, 1, 3, 5, 6, 7, 2, 4, /* 20*/
    1, 3, 5, 6, 7, 0, 2, 4, /* 21*/
    0, 3, 5, 6, 7, 1, 2, 4, /* 22*/
    3, 5, 6, 7, 0, 1, 2, 4, /* 23*/
    0, 1, 2, 5, 6, 7, 3, 4, /* 24*/
    1, 2, 5, 6, 7, 0, 3, 4, /* 25*/
    0, 2, 5, 6, 7, 1, 3, 4, /* 26*/
    2, 5, 6, 7, 0, 1, 3, 4, /* 27*/
    0, 1, 5, 6, 7, 2, 3, 4, /* 28*/
    1, 5, 6, 7, 0, 2, 3, 4, /* 29*/
    0, 5, 6, 7, 1, 2, 3, 4, /* 30*/
    5, 6, 7, 0, 1, 2, 3, 4, /* 31*/
    0, 1, 2, 3, 4, 6, 7, 5, /* 32*/
    1, 2, 3, 4, 6, 7, 0, 5, /* 33*/
    0, 2, 3, 4, 6, 7, 1, 5, /* 34*/
    2, 3, 4, 6, 7, 0, 1, 5, /* 35*/
    0, 1, 3, 4, 6, 7, 2, 5, /* 36*/
    1, 3, 4, 6, 7, 0, 2, 5, /* 37*/
    0, 3, 4, 6, 7, 1, 2, 5, /* 38*/
    3, 4, 6, 7, 0, 1, 2, 5, /* 39*/
    0, 1, 2, 4, 6, 7, 3, 5, /* 40*/
    1, 2, 4, 6, 7, 0, 3, 5, /* 41*/
    0, 2, 4, 6, 7, 1, 3, 5, /* 42*/
    2, 4, 6, 7, 0, 1, 3, 5, /* 43*/
    0, 1, 4, 6, 7, 2, 3, 5, /* 44*/
    1, 4, 6, 7, 0, 2, 3, 5, /* 45*/
    0, 4, 6, 7, 1, 2, 3, 5, /* 46*/
    4, 6, 7, 0, 1, 2, 3, 5, /* 47*/
    0, 1, 2, 3, 6, 7, 4, 5, /* 48*/
    1, 2, 3, 6, 7, 0, 4, 5, /* 49*/
    0, 2, 3, 6, 7, 1, 4, 5, /* 50*/
    2, 3, 6, 7, 0, 1, 4, 5, /* 51*/
    0, 1, 3, 6, 7, 2, 4, 5, /* 52*/
    1, 3, 6, 7, 0, 2, 4, 5, /* 53*/
    0, 3, 6, 7, 1, 2, 4, 5, /* 54*/
    3, 6, 7, 0, 1, 2, 4, 5, /* 55*/
    0, 1, 2, 6, 7, 3, 4, 5, /* 56*/
    1, 2, 6, 7, 0, 3, 4, 5, /* 57*/
    0, 2, 6, 7, 1, 3, 4, 5, /* 58*/
    2, 6, 7, 0, 1, 3, 4, 5, /* 59*/
    0, 1, 6, 7, 2, 3, 4, 5, /* 60*/
    1, 6, 7, 0, 2, 3, 4, 5, /* 61*/
    0, 6, 7, 1, 2, 3, 4, 5, /* 62*/
    6, 7, 0, 1, 2, 3, 4, 5, /* 63*/
    0, 1, 2, 3, 4, 5, 7, 6, /* 64*/
    1, 2, 3, 4, 5, 7, 0, 6, /* 65*/
    0, 2, 3, 4, 5, 7, 1, 6, /* 66*/
    2, 3, 4, 5, 7, 0, 1, 6, /* 67*/
    0, 1, 3, 4, 5, 7, 2, 6, /* 68*/
    1, 3, 4, 5, 7, 0, 2, 6, /* 69*/
    0, 3, 4, 5, 7, 1, 2, 6, /* 70*/
    3, 4, 5, 7, 0, 1, 2, 6, /* 71*/
    0, 1, 2, 4, 5, 7, 3, 6, /* 72*/
    1, 2, 4, 5, 7, 0, 3, 6, /* 73*/
    0, 2, 4, 5, 7, 1, 3, 6, /* 74*/
    2, 4, 5, 7, 0, 1, 3, 6, /* 75*/
    0, 1, 4, 5, 7, 2, 3, 6, /* 76*/
    1, 4, 5, 7, 0, 2, 3, 6, /* 77*/
    0, 4, 5, 7, 1, 2, 3, 6, /* 78*/
    4, 5, 7, 0, 1, 2, 3, 6, /* 79*/
    0, 1, 2, 3, 5, 7, 4, 6, /* 80*/
    1, 2, 3, 5, 7, 0, 4, 6, /* 81*/
    0, 2, 3, 5, 7, 1, 4, 6, /* 82*/
    2, 3, 5, 7, 0, 1, 4, 6, /* 83*/
    0, 1, 3, 5, 7, 2, 4, 6, /* 84*/
    1, 3, 5, 7, 0, 2, 4, 6, /* 85*/
    0, 3, 5, 7, 1, 2, 4, 6, /* 86*/
    3, 5, 7, 0, 1, 2, 4, 6, /* 87*/
    0, 1, 2, 5, 7, 3, 4, 6, /* 88*/
    1, 2, 5, 7, 0, 3, 4, 6, /* 89*/
    0, 2, 5, 7, 1, 3, 4, 6, /* 90*/
    2, 5, 7, 0, 1, 3, 4, 6, /* 91*/
    0, 1, 5, 7, 2, 3, 4, 6, /* 92*/
    1, 5, 7, 0, 2, 3, 4, 6, /* 93*/
    0, 5, 7, 1, 2, 3, 4, 6, /* 94*/
    5, 7, 0, 1, 2, 3, 4, 6, /* 95*/
    0, 1, 2, 3, 4, 7, 5, 6, /* 96*/
    1, 2, 3, 4, 7, 0, 5, 6, /* 97*/
    0, 2, 3, 4, 7, 1, 5, 6, /* 98*/
    2, 3, 4, 7, 0, 1, 5, 6, /* 99*/
    0, 1, 3, 4, 7, 2, 5, 6, /* 100*/
    1, 3, 4, 7, 0, 2, 5, 6, /* 101*/
    0, 3, 4, 7, 1, 2, 5, 6, /* 102*/
    3, 4, 7, 0, 1, 2, 5, 6, /* 103*/
    0, 1, 2, 4, 7, 3, 5, 6, /* 104*/
    1, 2, 4, 7, 0, 3, 5, 6, /* 105*/
    0, 2, 4, 7, 1, 3, 5, 6, /* 106*/
    2, 4, 7, 0, 1, 3, 5, 6, /* 107*/
    0, 1, 4, 7, 2, 3, 5, 6, /* 108*/
    1, 4, 7, 0, 2, 3, 5, 6, /* 109*/
    0, 4, 7, 1, 2, 3, 5, 6, /* 110*/
    4, 7, 0, 1, 2, 3, 5, 6, /* 111*/
    0, 1, 2, 3, 7, 4, 5, 6, /* 112*/
    1, 2, 3, 7, 0, 4, 5, 6, /* 113*/
    0, 2, 3, 7, 1, 4, 5, 6, /* 114*/
    2, 3, 7, 0, 1, 4, 5, 6, /* 115*/
    0, 1, 3, 7, 2, 4, 5, 6, /* 116*/
    1, 3, 7, 0, 2, 4, 5, 6, /* 117*/
    0, 3, 7, 1, 2, 4, 5, 6, /* 118*/
    3, 7, 0, 1, 2, 4, 5, 6, /* 119*/
    0, 1, 2, 7, 3, 4, 5, 6, /* 120*/
    1, 2, 7, 0, 3, 4, 5, 6, /* 121*/
    0, 2, 7, 1, 3, 4, 5, 6, /* 122*/
    2, 7, 0, 1, 3, 4, 5, 6, /* 123*/
    0, 1, 7, 2, 3, 4, 5, 6, /* 124*/
    1, 7, 0, 2, 3, 4, 5, 6, /* 125*/
    0, 7, 1, 2, 3, 4, 5, 6, /* 126*/
    7, 0, 1, 2, 3, 4, 5, 6, /* 127*/
    0, 1, 2, 3, 4, 5, 6, 7, /* 128*/
    1, 2, 3, 4, 5, 6, 0, 7, /* 129*/
    0, 2, 3, 4, 5, 6, 1, 7, /* 130*/
    2, 3, 4, 5, 6, 0, 1, 7, /* 131*/
    0, 1, 3, 4, 5, 6, 2, 7, /* 132*/
    1, 3, 4, 5, 6, 0, 2, 7, /* 133*/
    0, 3, 4, 5, 6, 1, 2, 7, /* 134*/
    3, 4, 5, 6, 0, 1, 2, 7, /* 135*/
    0, 1, 2, 4, 5, 6, 3, 7, /* 136*/
    1, 2, 4, 5, 6, 0, 3, 7, /* 137*/
    0, 2, 4, 5, 6, 1, 3, 7, /* 138*/
    2, 4, 5, 6, 0, 1, 3, 7, /* 139*/
    0, 1, 4, 5, 6, 2, 3, 7, /* 140*/
    1, 4, 5, 6, 0, 2, 3, 7, /* 141*/
    0, 4, 5, 6, 1, 2, 3, 7, /* 142*/
    4, 5, 6, 0, 1, 2, 3, 7, /* 143*/
    0, 1, 2, 3, 5, 6, 4, 7, /* 144*/
    1, 2, 3, 5, 6, 0, 4, 7, /* 145*/
    0, 2, 3, 5, 6, 1, 4, 7, /* 146*/
    2, 3, 5, 6, 0, 1, 4, 7, /* 147*/
    0, 1, 3, 5, 6, 2, 4, 7, /* 148*/
    1, 3, 5, 6, 0, 2, 4, 7, /* 149*/
    0, 3, 5, 6, 1, 2, 4, 7, /* 150*/
    3, 5, 6, 0, 1, 2, 4, 7, /* 151*/
    0, 1, 2, 5, 6, 3, 4, 7, /* 152*/
    1, 2, 5, 6, 0, 3, 4, 7, /* 153*/
    0, 2, 5, 6, 1, 3, 4, 7, /* 154*/
    2, 5, 6, 0, 1, 3, 4, 7, /* 155*/
    0, 1, 5, 6, 2, 3, 4, 7, /* 156*/
    1, 5, 6, 0, 2, 3, 4, 7, /* 157*/
    0, 5, 6, 1, 2, 3, 4, 7, /* 158*/
    5, 6, 0, 1, 2, 3, 4, 7, /* 159*/
    0, 1, 2, 3, 4, 6, 5, 7, /* 160*/
    1, 2, 3, 4, 6, 0, 5, 7, /* 161*/
    0, 2, 3, 4, 6, 1, 5, 7, /* 162*/
    2, 3, 4, 6, 0, 1, 5, 7, /* 163*/
    0, 1, 3, 4, 6, 2, 5, 7, /* 164*/
    1, 3, 4, 6, 0, 2, 5, 7, /* 165*/
    0, 3, 4, 6, 1, 2, 5, 7, /* 166*/
    3, 4, 6, 0, 1, 2, 5, 7, /* 167*/
    0, 1, 2, 4, 6, 3, 5, 7, /* 168*/
    1, 2, 4, 6, 0, 3, 5, 7, /* 169*/
    0, 2, 4, 6, 1, 3, 5, 7, /* 170*/
    2, 4, 6, 0, 1, 3, 5, 7, /* 171*/
    0, 1, 4, 6, 2, 3, 5, 7, /* 172*/
    1, 4, 6, 0, 2, 3, 5, 7, /* 173*/
    0, 4, 6, 1, 2, 3, 5, 7, /* 174*/
    4, 6, 0, 1, 2, 3, 5, 7, /* 175*/
    0, 1, 2, 3, 6, 4, 5, 7, /* 176*/
    1, 2, 3, 6, 0, 4, 5, 7, /* 177*/
    0, 2, 3, 6, 1, 4, 5, 7, /* 178*/
    2, 3, 6, 0, 1, 4, 5, 7, /* 179*/
    0, 1, 3, 6, 2, 4, 5, 7, /* 180*/
    1, 3, 6, 0, 2, 4, 5, 7, /* 181*/
    0, 3, 6, 1, 2, 4, 5, 7, /* 182*/
    3, 6, 0, 1, 2, 4, 5, 7, /* 183*/
    0, 1, 2, 6, 3, 4, 5, 7, /* 184*/
    1, 2, 6, 0, 3, 4, 5, 7, /* 185*/
    0, 2, 6, 1, 3, 4, 5, 7, /* 186*/
    2, 6, 0, 1, 3, 4, 5, 7, /* 187*/
    0, 1, 6, 2, 3, 4, 5, 7, /* 188*/
    1, 6, 0, 2, 3, 4, 5, 7, /* 189*/
    0, 6, 1, 2, 3, 4, 5, 7, /* 190*/
    6, 0, 1, 2, 3, 4, 5, 7, /* 191*/
    0, 1, 2, 3, 4, 5, 6, 7, /* 192*/
    1, 2, 3, 4, 5, 0, 6, 7, /* 193*/
    0, 2, 3, 4, 5, 1, 6, 7, /* 194*/
    2, 3, 4, 5, 0, 1, 6, 7, /* 195*/
    0, 1, 3, 4, 5, 2, 6, 7, /* 196*/
    1, 3, 4, 5, 0, 2, 6, 7, /* 197*/
    0, 3, 4, 5, 1, 2, 6, 7, /* 198*/
    3, 4, 5, 0, 1, 2, 6, 7, /* 199*/
    0, 1, 2, 4, 5, 3, 6, 7, /* 200*/
    1, 2, 4, 5, 0, 3, 6, 7, /* 201*/
    0, 2, 4, 5, 1, 3, 6, 7, /* 202*/
    2, 4, 5, 0, 1, 3, 6, 7, /* 203*/
    0, 1, 4, 5, 2, 3, 6, 7, /* 204*/
    1, 4, 5, 0, 2, 3, 6, 7, /* 205*/
    0, 4, 5, 1, 2, 3, 6, 7, /* 206*/
    4, 5, 0, 1, 2, 3, 6, 7, /* 207*/
    0, 1, 2, 3, 5, 4, 6, 7, /* 208*/
    1, 2, 3, 5, 0, 4, 6, 7, /* 209*/
    0, 2, 3, 5, 1, 4, 6, 7, /* 210*/
    2, 3, 5, 0, 1, 4, 6, 7, /* 211*/
    0, 1, 3, 5, 2, 4, 6, 7, /* 212*/
    1, 3, 5, 0, 2, 4, 6, 7, /* 213*/
    0, 3, 5, 1, 2, 4, 6, 7, /* 214*/
    3, 5, 0, 1, 2, 4, 6, 7, /* 215*/
    0, 1, 2, 5, 3, 4, 6, 7, /* 216*/
    1, 2, 5, 0, 3, 4, 6, 7, /* 217*/
    0, 2, 5, 1, 3, 4, 6, 7, /* 218*/
    2, 5, 0, 1, 3, 4, 6, 7, /* 219*/
    0, 1, 5, 2, 3, 4, 6, 7, /* 220*/
    1, 5, 0, 2, 3, 4, 6, 7, /* 221*/
    0, 5, 1, 2, 3, 4, 6, 7, /* 222*/
    5, 0, 1, 2, 3, 4, 6, 7, /* 223*/
    0, 1, 2, 3, 4, 5, 6, 7, /* 224*/
    1, 2, 3, 4, 0, 5, 6, 7, /* 225*/
    0, 2, 3, 4, 1, 5, 6, 7, /* 226*/
    2, 3, 4, 0, 1, 5, 6, 7, /* 227*/
    0, 1, 3, 4, 2, 5, 6, 7, /* 228*/
    1, 3, 4, 0, 2, 5, 6, 7, /* 229*/
    0, 3, 4, 1, 2, 5, 6, 7, /* 230*/
    3, 4, 0, 1, 2, 5, 6, 7, /* 231*/
    0, 1, 2, 4, 3, 5, 6, 7, /* 232*/
    1, 2, 4, 0, 3, 5, 6, 7, /* 233*/
    0, 2, 4, 1, 3, 5, 6, 7, /* 234*/
    2, 4, 0, 1, 3, 5, 6, 7, /* 235*/
    0, 1, 4, 2, 3, 5, 6, 7, /* 236*/
    1, 4, 0, 2, 3, 5, 6, 7, /* 237*/
    0, 4, 1, 2, 3, 5, 6, 7, /* 238*/
    4, 0, 1, 2, 3, 5, 6, 7, /* 239*/
    0, 1, 2, 3, 4, 5, 6, 7, /* 240*/
    1, 2, 3, 0, 4, 5, 6, 7, /* 241*/
    0, 2, 3, 1, 4, 5, 6, 7, /* 242*/
    2, 3, 0, 1, 4, 5, 6, 7, /* 243*/
    0, 1, 3, 2, 4, 5, 6, 7, /* 244*/
    1, 3, 0, 2, 4, 5, 6, 7, /* 245*/
    0, 3, 1, 2, 4, 5, 6, 7, /* 246*/
    3, 0, 1, 2, 4, 5, 6, 7, /* 247*/
    0, 1, 2, 3, 4, 5, 6, 7, /* 248*/
    1, 2, 0, 3, 4, 5, 6, 7, /* 249*/
    0, 2, 1, 3, 4, 5, 6, 7, /* 250*/
    2, 0, 1, 3, 4, 5, 6, 7, /* 251*/
    0, 1, 2, 3, 4, 5, 6, 7, /* 252*/
    1, 0, 2, 3, 4, 5, 6, 7, /* 253*/
    0, 1, 2, 3, 4, 5, 6, 7, /* 254*/
    0, 1, 2, 3, 4, 5, 6, 7, /* 255*/
};
static uint32_t avx_pivot_on_last_value(int32_t *array, size_t length) {
if (length <= 1)
return 1;
{ // we exchange the last value for the middle value for a better pivot
int32_t ival = array[length / 2];
int32_t bval = array[length - 1];
array[length / 2] = bval;
array[length - 1] = ival;
}
int wh, rh;
int wt, rt;
rh = 0;
rt = length - 1;
wh = 0;
wt = length - 1;
__m256i h0 = _mm256_loadu_si256((__m256i*)(array + rh));
__m256i h1 = _mm256_loadu_si256((__m256i*)(array + rh + AVX2_REGISTER_SIZE));
rh += 2*AVX2_REGISTER_SIZE;
__m256i t0 = _mm256_loadu_si256((__m256i*)(array + rt - AVX2_REGISTER_SIZE));
__m256i t1 = _mm256_loadu_si256((__m256i*)(array + rt - 2*AVX2_REGISTER_SIZE));
rt -= 3*AVX2_REGISTER_SIZE;
__m256i current = _mm256_loadu_si256((__m256i*)(array + rh));
__m256i next;
rh += AVX2_REGISTER_SIZE;
const int32_t pivot = array[length - 1];
const __m256i P = _mm256_set1_epi32(pivot);
// 1. the main loop
while (wh - wt < 4*AVX2_REGISTER_SIZE) {
const bool which = (rh - wh) < (wt - rt);
// I believe that a compiler will emit branchless code
__m256i* next_ptr = (__m256i*)(which ? array + rh : array + rt);
const int adv_rh = which ? AVX2_REGISTER_SIZE : 0;
const int adv_rt = which ? 0 : AVX2_REGISTER_SIZE;
// My faith ends here
next = _mm256_loadu_si256(next_ptr);
const int pvbyte = _mm256_movemask_ps((__m256)_mm256_cmpgt_epi32(current, P));
const uint32_t cnt = 8 - _mm_popcnt_u32(pvbyte);
const __m256i shuf = _mm256_load_si256((__m256i *)(reverseshufflemask + 8 * pvbyte));
const __m256i ordered = _mm256_permutevar8x32_epi32(current, shuf);
_mm256_storeu_si256((__m256i*)(array + wh), ordered);
_mm256_storeu_si256((__m256i*)(array + wt - 8), ordered);
rh += adv_rh;
rt -= adv_rt;
wh += cnt;
wt -= 8 - cnt;
current = next;
}
// 2. partition remaining part
while (wh - wt > 4*AVX2_REGISTER_SIZE) {
const int32_t v = array[rh++];
if (v < pivot) {
array[wh++] = v;
} else {
array[wt--] = v;
}
}
// 3. partition 4 registers loaded in the beginning
static uint32_t tmp[4*AVX2_REGISTER_SIZE];
_mm256_storeu_si256((__m256i*)(tmp + 0*AVX2_REGISTER_SIZE), h0);
_mm256_storeu_si256((__m256i*)(tmp + 1*AVX2_REGISTER_SIZE), h1);
_mm256_storeu_si256((__m256i*)(tmp + 2*AVX2_REGISTER_SIZE), t0);
_mm256_storeu_si256((__m256i*)(tmp + 3*AVX2_REGISTER_SIZE), t1);
for (int i=0; i < 4*AVX2_REGISTER_SIZE; i++) {
const int32_t v = array[i];
if (v < pivot) {
array[wh++] = v;
} else {
array[wt--] = v;
}
}
{
const int32_t a = array[length - 1];
const int32_t b = array[wh];
array[length - 1] = b;
array[wh] = a;
}
return wh;
}
// Hoare-style partition of array[left .. right] (inclusive) around the value
// `pivot`; used by the scalar fallback sort.
//
// On return left > right.  Every element in [original left, right] is
// <= pivot, and every element in [left, original right] is >= pivot.
// `pivot` is expected to be the value of an element inside the range: the
// inner scans rely on meeting such an element (or a previously swapped one)
// to stop.
//
// `inline` so that including this file from several translation units does
// not produce multiple definitions.
inline void scalar_partition(int32_t* array, const int32_t pivot, int& left, int& right) {
    while (left <= right) {
        while (array[left] < pivot) {
            left += 1;
        }
        while (array[right] > pivot) {
            right -= 1;
        }
        if (left <= right) {
            // Same element type as the array, so negative values are not
            // round-tripped through an unsigned temporary.
            const int32_t t = array[left];
            array[left] = array[right];
            array[right] = t;
            left += 1;
            right -= 1;
        }
    }
}
// Fallback: sorts array[left .. right] (inclusive bounds) in ascending order
// with a scalar Hoare-partition quicksort.
//
// Only the smaller part is sorted recursively.  The loop continues with the
// larger part, so recursion depth stays O(log n) even on inputs that defeat
// the middle-element pivot.
inline void scalar_quicksort(int32_t* array, int left, int right) {
    while (left < right) {
#ifdef WITH_RUNTIME_STATS
        statistics.scalar__partition_calls += 1;
        statistics.scalar__items_processed += right - left + 1;
#endif
        int i = left;
        int j = right;
        // left + (right - left) / 2 cannot overflow, unlike (left + right) / 2.
        const int32_t pivot = array[left + (right - left) / 2];
        scalar_partition(array, pivot, i, j);
        // Now [left, j] holds values <= pivot and [i, right] values >= pivot.
        if (j - left < right - i) {
            scalar_quicksort(array, left, j);
            left = i;
        } else {
            scalar_quicksort(array, i, right);
            right = j;
        }
    }
}
// Sorts array[0 .. length) in ascending order (signed 32-bit comparison).
//
// Ranges longer than 8 * AVX2_REGISTER_SIZE are partitioned with
// avx_pivot_on_last_value; shorter ones use lomuto_partition_epi32, which is
// defined outside this file.  Both are expected to return "sep" = final pivot
// index + 1.  The sort then continues with the sep - 1 values before the
// pivot and the length - sep values starting at array + sep.  If the pivot
// lands in the last slot (sep == length), the split made no progress and the
// range is handed to scalar_quicksort.
//
// Only the smaller part is sorted recursively; the loop continues with the
// larger part, which bounds recursion depth by O(log length).
inline void avx2_pivotonlast_sort(int32_t *array, const uint32_t length) {
    int32_t *base = array;
    uint32_t n = length;
    // Ranges of 0 or 1 values are already sorted.
    while (n > 1) {
        uint32_t sep;
        if (n > 8 * AVX2_REGISTER_SIZE) {
            sep = avx_pivot_on_last_value(base, n);
        } else {
            sep = lomuto_partition_epi32((uint32_t*)base, 0, n - 1);
        }
        if (sep == n) {
            // we have an ineffective pivot. Let us give up.
            scalar_quicksort(base, 0, n - 1);
            return;
        }
        const uint32_t left_len = sep - 1;  // values before the pivot
        const uint32_t right_len = n - sep; // values after the pivot
        if (left_len < right_len) {
            avx2_pivotonlast_sort(base, left_len);
            base += sep;
            n = right_len;
        } else {
            avx2_pivotonlast_sort(base + sep, right_len);
            n = left_len;
        }
    }
}
// Sorts array[left .. right] (inclusive bounds) as signed 32-bit integers.
// This is an adapter for callers that use the (uint32_t *, left, right)
// signature.
//
// An empty or single-element range is returned untouched. This also prevents
// right < left - 1 from producing a negative length that would wrap to a huge
// uint32_t.
inline void wrapped_avx2_pivotonlast_sort(uint32_t *array, int left, int right) {
    if (right <= left)
        return;
    avx2_pivotonlast_sort((int32_t *)array + left, (uint32_t)(right - left) + 1);
}
} // namespace nate