magyarsort/simd-sort/avx2-partition.cpp

198 lines
6.0 KiB
C++

#include <x86intrin.h>
#include <cstdint>
namespace qs {
namespace avx2 {
/// Expands an 8-bit lane bitmask into a per-lane vector mask.
///
/// Bit i of `bm` controls 32-bit lane i of the result: the lane is all ones
/// when the bit is set and all zeros when it is clear.
__m256i FORCE_INLINE bitmask_to_bytemask_epi32(uint8_t bm) {
    // One distinct bit per lane: lane i carries (1 << i).
    const __m256i lane_bit = _mm256_setr_epi32(1 << 0, 1 << 1, 1 << 2, 1 << 3,
                                               1 << 4, 1 << 5, 1 << 6, 1 << 7);
    // Every lane sees the whole bitmask and keeps only its own bit.
    const __m256i broadcast = _mm256_set1_epi32(bm);
    const __m256i selected = _mm256_and_si256(lane_bit, broadcast);
    // A lane whose bit survived the AND compares equal and becomes all ones.
    return _mm256_cmpeq_epi32(lane_bit, selected);
}
/// Pairs up the set bits of two lane masks and builds the permutation tables
/// needed to exchange the paired lanes.
///
/// The set bits of `a` and `b` are matched in ascending bit order until one of
/// the masks runs out of bits.
///
/// @param a          in: lanes of the first register to exchange (non-zero);
///                   out: only the bits that found a partner in `b`.
/// @param b          in: lanes of the second register to exchange (non-zero);
///                   out: only the bits that found a partner in `a`.
/// @param rem_a      out: bits of the original `a` left without a partner.
/// @param rem_b      out: bits of the original `b` left without a partner.
/// @param shuffle_a  out: for every paired lane `idx_b`, holds the partner
///                   index `idx_a`.
/// @param shuffle_b  out: for every paired lane `idx_a`, holds the partner
///                   index `idx_b`.
///
/// Lanes of the shuffle tables that were not paired hold 0; they are meant to
/// be masked out by the caller (see swap_epi32).
void FORCE_INLINE align_masks(uint8_t& a, uint8_t& b, uint8_t& rem_a, uint8_t& rem_b, __m256i& shuffle_a, __m256i& shuffle_b) {
    assert(a != 0);
    assert(b != 0);

    uint8_t tmpA = a;
    uint8_t tmpB = b;

    // Value-initialised: the loop below fills only the paired lanes, yet the
    // whole 32 bytes are loaded afterwards. Without the initialiser those
    // loads would read indeterminate values.
    alignas(32) uint32_t tmpshufa[8] = {};
    alignas(32) uint32_t tmpshufb[8] = {};

    while (tmpA != 0 && tmpB != 0) {
        // Lowest still-unpaired lane on each side.
        const int idx_a = __builtin_ctz(tmpA);
        const int idx_b = __builtin_ctz(tmpB);

        // Clear the lowest set bit of each working mask.
        tmpA &= tmpA - 1;
        tmpB &= tmpB - 1;

        tmpshufb[idx_a] = idx_b;
        tmpshufa[idx_b] = idx_a;
    }

    // tmpA/tmpB now hold the unpaired bits; XOR strips them from the inputs.
    a = a ^ tmpA;
    b = b ^ tmpB;

    assert(a != 0);
    assert(b != 0);
    assert(_mm_popcnt_u64(a) == _mm_popcnt_u64(b));

    rem_a = tmpA;
    rem_b = tmpB;

    shuffle_a = _mm256_load_si256(reinterpret_cast<const __m256i*>(tmpshufa));
    shuffle_b = _mm256_load_si256(reinterpret_cast<const __m256i*>(tmpshufb));
}
/// Bitwise select between two vectors.
///
/// Every result bit is taken from `a` where the corresponding bit of `mask`
/// is 1 and from `b` where it is 0.
__m256i FORCE_INLINE merge(const __m256i mask, const __m256i a, const __m256i b) {
    const __m256i kept_from_a = _mm256_and_si256(mask, a);
    const __m256i kept_from_b = _mm256_andnot_si256(mask, b); // (~mask) & b
    return _mm256_or_si256(kept_from_a, kept_from_b);
}
/// Exchanges selected 32-bit lanes between two registers.
///
/// Lanes of `a` flagged in `mask_a` are overwritten with lanes of `b` picked
/// by `shuffle_b`; lanes of `b` flagged in `mask_b` are overwritten with lanes
/// of `a` picked by `shuffle_a`. All other lanes keep their value. With masks
/// and shuffle tables produced by align_masks this swaps the paired lanes.
void FORCE_INLINE swap_epi32(
    __m256i& a, __m256i& b,
    uint8_t mask_a, const __m256i shuffle_a,
    uint8_t mask_b, const __m256i shuffle_b) {
    // Per-lane select masks derived from the 8-bit lane masks.
    const __m256i select_a = bitmask_to_bytemask_epi32(mask_a);
    const __m256i select_b = bitmask_to_bytemask_epi32(mask_b);

    // Both permutes read the incoming registers before either is overwritten.
    const __m256i incoming_for_a = _mm256_permutevar8x32_epi32(b, shuffle_b);
    const __m256i incoming_for_b = _mm256_permutevar8x32_epi32(a, shuffle_a);

    a = merge(select_a, incoming_for_a, a);
    b = merge(select_b, incoming_for_b, b);
}
// Evaluates to true when `vec` has no bit set: _mm256_testz_si256(v, v)
// returns 1 exactly when (v AND v) is all zeros.
// `vec` is expanded twice, so pass an expression without side effects.
#define _mm256_iszero(vec) (_mm256_testz_si256(vec, vec) != 0)
/// Partitions `array[left..right]` (inclusive bounds) around the pivot `pv`,
/// processing 8 elements at a time from both ends and handing the remaining
/// middle part to scalar_partition_epi32 (defined outside this block).
///
/// Vector phase: a block `L` is read at `left` and a block `R` ends at `right`.
/// Elements of `L` that are not less than the pivot are exchanged with
/// elements of `R` that are less than the pivot. A block is written back and
/// its index advanced once none of its elements are misplaced any more.
///
/// @param array  data to partition in place.
/// @param pv     pivot value; elements are compared as unsigned 32-bit values.
/// @param left   in: first index of the range; updated as blocks are finished
///               and passed on to scalar_partition_epi32.
/// @param right  in: last index of the range; updated likewise.
///
/// NOTE(review): the final values of `left`/`right` depend on the contract of
/// scalar_partition_epi32, which is not visible here — confirm with callers.
void FORCE_INLINE partition_epi32(uint32_t* array, uint32_t pv, int& left, int& right) {
    const int N = 8; // the number of items in a register (256/32)

    __m256i L;
    __m256i R;
    uint8_t maskL = 0; // bit i set: lane i of L is >= pivot and still has to move right
    uint8_t maskR = 0; // bit i set: lane i of R is <  pivot and still has to move left

    // _mm256_cmpgt_epi32 compares *signed* integers, but the data is unsigned
    // (and the scalar tail below compares unsigned). Flipping the sign bit of
    // both operands maps unsigned order onto signed order, so both phases
    // agree for values >= 2^31. Only the compare operands are flipped; L and R
    // keep the original data.
    const __m256i sign_flip = _mm256_set1_epi32(INT32_MIN);
    const __m256i pivot = _mm256_xor_si256(_mm256_set1_epi32(static_cast<int>(pv)), sign_flip);
    const __m256i all_ones = _mm256_set1_epi32(-1);

    int origL = left;
    int origR = right;

    while (true) {

        if (maskL == 0) {
            // Skip left blocks that are already entirely less than the pivot.
            while (true) {
                // Stop while there is still room between the two cursors.
                if (right - (left + N) + 1 < 2*N) {
                    goto end;
                }

                L = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(array + left));
                const __m256i bytemask = _mm256_cmpgt_epi32(pivot, _mm256_xor_si256(L, sign_flip));

                // testc is 1 when every lane's sign bit is set, i.e. all lanes < pivot.
                if (_mm256_testc_ps(_mm256_castsi256_ps(bytemask), _mm256_castsi256_ps(all_ones))) {
                    left += N;
                } else {
                    // Inverted: mark the lanes that are NOT less than the pivot.
                    maskL = static_cast<uint8_t>(~_mm256_movemask_ps(_mm256_castsi256_ps(bytemask)));
                    break;
                }
            }
        }

        if (maskR == 0) {
            // Skip right blocks that contain no element less than the pivot.
            while (true) {
                if ((right - N) - left + 1 < 2*N) {
                    goto end;
                }

                R = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(array + right - N + 1));
                const __m256i bytemask = _mm256_cmpgt_epi32(pivot, _mm256_xor_si256(R, sign_flip));

                if (_mm256_iszero(bytemask)) {
                    right -= N;
                } else {
                    maskR = static_cast<uint8_t>(_mm256_movemask_ps(_mm256_castsi256_ps(bytemask)));
                    break;
                }
            }
        }

        assert(left <= right);
        assert(maskL != 0);
        assert(maskR != 0);

        uint8_t mL;
        uint8_t mR;
        __m256i shuffleL;
        __m256i shuffleR;

        // Pair misplaced lanes of L with misplaced lanes of R and exchange them;
        // mL/mR receive the lanes that found no partner this round.
        align_masks(maskL, maskR, mL, mR, shuffleL, shuffleR);
        swap_epi32(L, R, maskL, shuffleL, maskR, shuffleR);

        maskL = mL;
        maskR = mR;

        // A block with no misplaced lanes left is finished: write it back.
        if (maskL == 0) {
            _mm256_storeu_si256(reinterpret_cast<__m256i*>(array + left), L);
            left += N;
        }

        if (maskR == 0) {
            _mm256_storeu_si256(reinterpret_cast<__m256i*>(array + right - N + 1), R);
            right -= N;
        }

    } // while

end:

    // At most one block can still be pending; write it back (with the swaps
    // done so far) so the scalar phase sees current data.
    assert(!(maskL != 0 && maskR != 0));

    if (maskL != 0) {
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(array + left), L);
    } else if (maskR != 0) {
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(array + right - N + 1), R);
    }

    if (left < right) {
        int less = 0;
        int greater = 0;
        const int all = right - left + 1;

        for (int i=left; i <= right; i++) {
            less += int(array[i] < pv);
            greater += int(array[i] > pv);
        }

        if (all == less) {
            // all elements in range [left, right] less than pivot
            scalar_partition_epi32(array, pv, origL, left);
        } else if (all == greater) {
            // all elements in range [left, right] greater than pivot
            scalar_partition_epi32(array, pv, left, origR);
        } else {
            scalar_partition_epi32(array, pv, left, right);
        }
    }
}
} // namespace avx2
} // namespace qs