diff --git a/makefile b/makefile index c8b24a0..5b1832d 100644 --- a/makefile +++ b/makefile @@ -4,5 +4,14 @@ debug: test.cpp magyarsort.h release: test.cpp magyarsort.h g++ test.cpp -std=c++14 -O2 -o test.out +release3: test.cpp magyarsort.h + g++ test.cpp -std=c++14 -O3 -o test.out + +clang_release: test.cpp magyarsort.h + clang++ test.cpp -std=c++14 -O2 -o test.out + +clang_release3: test.cpp magyarsort.h + clang++ test.cpp -std=c++14 -O3 -o test.out + clean: test.out rm test.out diff --git a/ska_sort.hpp b/ska_sort.hpp new file mode 100644 index 0000000..81a9ef2 --- /dev/null +++ b/ska_sort.hpp @@ -0,0 +1,1445 @@ +// Copyright Malte Skarupke 2016. +// Distributed under the Boost Software License, Version 1.0. +// (See http://www.boost.org/LICENSE_1_0.txt) + +#pragma once + +#include +#include +#include +#include +#include + +namespace detail +{ +template +void counting_sort_impl(It begin, It end, OutIt out_begin, ExtractKey && extract_key) +{ + count_type counts[256] = {}; + for (It it = begin; it != end; ++it) + { + ++counts[extract_key(*it)]; + } + count_type total = 0; + for (count_type & count : counts) + { + count_type old_count = count; + count = total; + total += old_count; + } + for (; begin != end; ++begin) + { + std::uint8_t key = extract_key(*begin); + out_begin[counts[key]++] = std::move(*begin); + } +} +template +void counting_sort_impl(It begin, It end, OutIt out_begin, ExtractKey && extract_key) +{ + counting_sort_impl(begin, end, out_begin, extract_key); +} +inline bool to_unsigned_or_bool(bool b) +{ + return b; +} +inline unsigned char to_unsigned_or_bool(unsigned char c) +{ + return c; +} +inline unsigned char to_unsigned_or_bool(signed char c) +{ + return static_cast(c) + 128; +} +inline unsigned char to_unsigned_or_bool(char c) +{ + return static_cast(c); +} +inline std::uint16_t to_unsigned_or_bool(char16_t c) +{ + return static_cast(c); +} +inline std::uint32_t to_unsigned_or_bool(char32_t c) +{ + return static_cast(c); +} +inline std::uint32_t to_unsigned_or_bool(wchar_t c) +{ + return static_cast(c); +} +inline unsigned short to_unsigned_or_bool(short i) +{ + return static_cast(i) + static_cast(1 << (sizeof(short) * 8 - 1)); +} +inline unsigned short to_unsigned_or_bool(unsigned short i) +{ + return i; +} +inline unsigned int to_unsigned_or_bool(int i) +{ + return static_cast(i) + static_cast(1 << (sizeof(int) * 8 - 1)); +} +inline unsigned int to_unsigned_or_bool(unsigned int i) +{ + return i; +} +inline unsigned long to_unsigned_or_bool(long l) +{ + return static_cast(l) + static_cast(1l << (sizeof(long) * 8 - 1)); +} +inline unsigned long to_unsigned_or_bool(unsigned long l) +{ + return l; +} +inline unsigned long long to_unsigned_or_bool(long long l) +{ + return static_cast(l) + static_cast(1ll << (sizeof(long long) * 8 - 1)); +} +inline unsigned long long to_unsigned_or_bool(unsigned long long l) +{ + return l; +} +inline std::uint32_t to_unsigned_or_bool(float f) +{ + union + { + float f; + std::uint32_t u; + } as_union = { f }; + std::uint32_t sign_bit = -std::int32_t(as_union.u >> 31); + return as_union.u ^ (sign_bit | 0x80000000); +} +inline std::uint64_t to_unsigned_or_bool(double f) +{ + union + { + double d; + std::uint64_t u; + } as_union = { f }; + std::uint64_t sign_bit = -std::int64_t(as_union.u >> 63); + return as_union.u ^ (sign_bit | 0x8000000000000000); +} +template +inline size_t to_unsigned_or_bool(T * ptr) +{ + return reinterpret_cast(ptr); +} + +template +struct SizedRadixSorter; + +template<> +struct SizedRadixSorter<1> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + counting_sort_impl(begin, end, buffer_begin, [&](auto && o) + { + return to_unsigned_or_bool(extract_key(o)); + }); + return true; + } + + static constexpr size_t pass_count = 2; +}; +template<> +struct SizedRadixSorter<2> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + std::ptrdiff_t num_elements = end - begin; + if (num_elements <= (1ll << 32)) + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + else + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + } + + template + static bool sort_inline(It begin, It end, OutIt out_begin, OutIt out_end, ExtractKey && extract_key) + { + count_type counts0[256] = {}; + count_type counts1[256] = {}; + + for (It it = begin; it != end; ++it) + { + uint16_t key = to_unsigned_or_bool(extract_key(*it)); + ++counts0[key & 0xff]; + ++counts1[(key >> 8) & 0xff]; + } + count_type total0 = 0; + count_type total1 = 0; + for (int i = 0; i < 256; ++i) + { + count_type old_count0 = counts0[i]; + count_type old_count1 = counts1[i]; + counts0[i] = total0; + counts1[i] = total1; + total0 += old_count0; + total1 += old_count1; + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)); + out_begin[counts0[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 8; + begin[counts1[key]++] = std::move(*it); + } + return false; + } + + static constexpr size_t pass_count = 3; +}; +template<> +struct SizedRadixSorter<4> +{ + + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + std::ptrdiff_t num_elements = end - begin; + if (num_elements <= (1ll << 32)) + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + else + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + } + template + static bool sort_inline(It begin, It end, OutIt out_begin, OutIt out_end, ExtractKey && extract_key) + { + count_type counts0[256] = {}; + count_type counts1[256] = {}; + count_type counts2[256] = {}; + count_type counts3[256] = {}; + + for (It it = begin; it != end; ++it) + { + uint32_t key = to_unsigned_or_bool(extract_key(*it)); + ++counts0[key & 0xff]; + ++counts1[(key >> 8) & 0xff]; + ++counts2[(key >> 16) & 0xff]; + ++counts3[(key >> 24) & 0xff]; + } + count_type total0 = 0; + count_type total1 = 0; + count_type total2 = 0; + count_type total3 = 0; + for (int i = 0; i < 256; ++i) + { + count_type old_count0 = counts0[i]; + count_type old_count1 = counts1[i]; + count_type old_count2 = counts2[i]; + count_type old_count3 = counts3[i]; + counts0[i] = total0; + counts1[i] = total1; + counts2[i] = total2; + counts3[i] = total3; + total0 += old_count0; + total1 += old_count1; + total2 += old_count2; + total3 += old_count3; + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)); + out_begin[counts0[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 8; + begin[counts1[key]++] = std::move(*it); + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 16; + out_begin[counts2[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 24; + begin[counts3[key]++] = std::move(*it); + } + return false; + } + + static constexpr size_t pass_count = 5; +}; +template<> +struct SizedRadixSorter<8> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + std::ptrdiff_t num_elements = end - begin; + if (num_elements <= (1ll << 32)) + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + else + return sort_inline(begin, end, buffer_begin, buffer_begin + num_elements, extract_key); + } + template + static bool sort_inline(It begin, It end, OutIt out_begin, OutIt out_end, ExtractKey && extract_key) + { + count_type counts0[256] = {}; + count_type counts1[256] = {}; + count_type counts2[256] = {}; + count_type counts3[256] = {}; + count_type counts4[256] = {}; + count_type counts5[256] = {}; + count_type counts6[256] = {}; + count_type counts7[256] = {}; + + for (It it = begin; it != end; ++it) + { + uint64_t key = to_unsigned_or_bool(extract_key(*it)); + ++counts0[key & 0xff]; + ++counts1[(key >> 8) & 0xff]; + ++counts2[(key >> 16) & 0xff]; + ++counts3[(key >> 24) & 0xff]; + ++counts4[(key >> 32) & 0xff]; + ++counts5[(key >> 40) & 0xff]; + ++counts6[(key >> 48) & 0xff]; + ++counts7[(key >> 56) & 0xff]; + } + count_type total0 = 0; + count_type total1 = 0; + count_type total2 = 0; + count_type total3 = 0; + count_type total4 = 0; + count_type total5 = 0; + count_type total6 = 0; + count_type total7 = 0; + for (int i = 0; i < 256; ++i) + { + count_type old_count0 = counts0[i]; + count_type old_count1 = counts1[i]; + count_type old_count2 = counts2[i]; + count_type old_count3 = counts3[i]; + count_type old_count4 = counts4[i]; + count_type old_count5 = counts5[i]; + count_type old_count6 = counts6[i]; + count_type old_count7 = counts7[i]; + counts0[i] = total0; + counts1[i] = total1; + counts2[i] = total2; + counts3[i] = total3; + counts4[i] = total4; + counts5[i] = total5; + counts6[i] = total6; + counts7[i] = total7; + total0 += old_count0; + total1 += old_count1; + total2 += old_count2; + total3 += old_count3; + total4 += old_count4; + total5 += old_count5; + total6 += old_count6; + total7 += old_count7; + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)); + out_begin[counts0[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 8; + begin[counts1[key]++] = std::move(*it); + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 16; + out_begin[counts2[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 24; + begin[counts3[key]++] = std::move(*it); + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 32; + out_begin[counts4[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 40; + begin[counts5[key]++] = std::move(*it); + } + for (It it = begin; it != end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 48; + out_begin[counts6[key]++] = std::move(*it); + } + for (OutIt it = out_begin; it != out_end; ++it) + { + std::uint8_t key = to_unsigned_or_bool(extract_key(*it)) >> 56; + begin[counts7[key]++] = std::move(*it); + } + return false; + } + + static constexpr size_t pass_count = 9; +}; + +template +struct RadixSorter; +template<> +struct RadixSorter +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + size_t false_count = 0; + for (It it = begin; it != end; ++it) + { + if (!extract_key(*it)) + ++false_count; + } + size_t true_position = false_count; + false_count = 0; + for (; begin != end; ++begin) + { + if (extract_key(*begin)) + buffer_begin[true_position++] = std::move(*begin); + else + buffer_begin[false_count++] = std::move(*begin); + } + return true; + } + + static constexpr size_t pass_count = 2; +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template<> +struct RadixSorter : SizedRadixSorter +{ +}; +template +struct RadixSorter> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + bool first_result = RadixSorter::sort(begin, end, buffer_begin, [&](auto && o) + { + return extract_key(o).second; + }); + auto extract_first = [&](auto && o) + { + return extract_key(o).first; + }; + + if (first_result) + { + return !RadixSorter::sort(buffer_begin, buffer_begin + (end - begin), begin, extract_first); + } + else + { + return RadixSorter::sort(begin, end, buffer_begin, extract_first); + } + } + + static constexpr size_t pass_count = RadixSorter::pass_count + RadixSorter::pass_count; +}; +template +struct RadixSorter &> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + bool first_result = RadixSorter::sort(begin, end, buffer_begin, [&](auto && o) -> const V & + { + return extract_key(o).second; + }); + auto extract_first = [&](auto && o) -> const K & + { + return extract_key(o).first; + }; + + if (first_result) + { + return !RadixSorter::sort(buffer_begin, buffer_begin + (end - begin), begin, extract_first); + } + else + { + return RadixSorter::sort(begin, end, buffer_begin, extract_first); + } + } + + static constexpr size_t pass_count = RadixSorter::pass_count + RadixSorter::pass_count; +}; +template +struct TupleRadixSorter +{ + using NextSorter = TupleRadixSorter; + using ThisSorter = RadixSorter::type>; + + template + static bool sort(It begin, It end, OutIt out_begin, OutIt out_end, ExtractKey && extract_key) + { + bool which = NextSorter::sort(begin, end, out_begin, out_end, extract_key); + auto extract_i = [&](auto && o) + { + return std::get(extract_key(o)); + }; + if (which) + return !ThisSorter::sort(out_begin, out_end, begin, extract_i); + else + return ThisSorter::sort(begin, end, out_begin, extract_i); + } + + static constexpr size_t pass_count = ThisSorter::pass_count + NextSorter::pass_count; +}; +template +struct TupleRadixSorter +{ + using NextSorter = TupleRadixSorter; + using ThisSorter = RadixSorter::type>; + + template + static bool sort(It begin, It end, OutIt out_begin, OutIt out_end, ExtractKey && extract_key) + { + bool which = NextSorter::sort(begin, end, out_begin, out_end, extract_key); + auto extract_i = [&](auto && o) -> decltype(auto) + { + return std::get(extract_key(o)); + }; + if (which) + return !ThisSorter::sort(out_begin, out_end, begin, extract_i); + else + return ThisSorter::sort(begin, end, out_begin, extract_i); + } + + static constexpr size_t pass_count = ThisSorter::pass_count + NextSorter::pass_count; +}; +template +struct TupleRadixSorter +{ + template + static bool sort(It, It, OutIt, OutIt, ExtractKey &&) + { + return false; + } + + static constexpr size_t pass_count = 0; +}; +template +struct TupleRadixSorter +{ + template + static bool sort(It, It, OutIt, OutIt, ExtractKey &&) + { + return false; + } + + static constexpr size_t pass_count = 0; +}; + +template +struct RadixSorter> +{ + using SorterImpl = TupleRadixSorter<0, sizeof...(Args), std::tuple>; + + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + return SorterImpl::sort(begin, end, buffer_begin, buffer_begin + (end - begin), extract_key); + } + + static constexpr size_t pass_count = SorterImpl::pass_count; +}; + +template +struct RadixSorter &> +{ + using SorterImpl = TupleRadixSorter<0, sizeof...(Args), const std::tuple &>; + + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + return SorterImpl::sort(begin, end, buffer_begin, buffer_begin + (end - begin), extract_key); + } + + static constexpr size_t pass_count = SorterImpl::pass_count; +}; + +template +struct RadixSorter> +{ + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + auto buffer_end = buffer_begin + (end - begin); + bool which = false; + for (size_t i = S; i > 0; --i) + { + auto extract_i = [&, i = i - 1](auto && o) + { + return extract_key(o)[i]; + }; + if (which) + which = !RadixSorter::sort(buffer_begin, buffer_end, begin, extract_i); + else + which = RadixSorter::sort(begin, end, buffer_begin, extract_i); + } + return which; + } + + static constexpr size_t pass_count = RadixSorter::pass_count * S; +}; + +template +struct RadixSorter : RadixSorter +{ +}; +template +struct RadixSorter : RadixSorter +{ +}; +template +struct RadixSorter : RadixSorter +{ +}; +template +struct RadixSorter : RadixSorter +{ +}; +template +struct RadixSorter : RadixSorter +{ +}; +// these structs serve two purposes +// 1. they serve as illustration for how to implement the to_radix_sort_key function +// 2. they help produce better error messages. with these overloads you get the +// error message "no matching function for call to to_radix_sort(your_type)" +// without these examples, you'd get the error message "to_radix_sort_key was +// not declared in this scope" which is a much less useful error message +struct ExampleStructA { int i; }; +struct ExampleStructB { float f; }; +inline int to_radix_sort_key(ExampleStructA a) { return a.i; } +inline float to_radix_sort_key(ExampleStructB b) { return b.f; } +template +struct FallbackRadixSorter : RadixSorter()))> +{ + using base = RadixSorter()))>; + + template + static bool sort(It begin, It end, OutIt buffer_begin, ExtractKey && extract_key) + { + return base::sort(begin, end, buffer_begin, [&](auto && a) -> decltype(auto) + { + return to_radix_sort_key(extract_key(a)); + }); + } +}; + +template +struct nested_void +{ + using type = void; +}; + +template +using void_t = typename nested_void::type; + +template +struct has_subscript_operator_impl +{ + template()[0])> + static std::true_type test(int); + template + static std::false_type test(...); + + using type = decltype(test(0)); +}; + +template +using has_subscript_operator = typename has_subscript_operator_impl::type; + + +template +struct FallbackRadixSorter()))>> + : RadixSorter()))> +{ +}; + +template +struct RadixSorter : FallbackRadixSorter +{ +}; + +template +size_t radix_sort_pass_count = RadixSorter::pass_count; + +template +inline void unroll_loop_four_times(It begin, size_t iteration_count, Func && to_call) +{ + size_t loop_count = iteration_count / 4; + size_t remainder_count = iteration_count - loop_count * 4; + for (; loop_count > 0; --loop_count) + { + to_call(begin); + ++begin; + to_call(begin); + ++begin; + to_call(begin); + ++begin; + to_call(begin); + ++begin; + } + switch(remainder_count) + { + case 3: + to_call(begin); + ++begin; + case 2: + to_call(begin); + ++begin; + case 1: + to_call(begin); + } +} + +template +inline It custom_std_partition(It begin, It end, F && func) +{ + for (;; ++begin) + { + if (begin == end) + return end; + if (!func(*begin)) + break; + } + It it = begin; + for(++it; it != end; ++it) + { + if (!func(*it)) + continue; + + std::iter_swap(begin, it); + ++begin; + } + return begin; +} + +struct PartitionInfo +{ + PartitionInfo() + : count(0) + { + } + + union + { + size_t count; + size_t offset; + }; + size_t next_offset; +}; + +template +struct UnsignedForSize; +template<> +struct UnsignedForSize<1> +{ + typedef uint8_t type; +}; +template<> +struct UnsignedForSize<2> +{ + typedef uint16_t type; +}; +template<> +struct UnsignedForSize<4> +{ + typedef uint32_t type; +}; +template<> +struct UnsignedForSize<8> +{ + typedef uint64_t type; +}; +template +struct SubKey; +template +struct SizedSubKey +{ + template + static auto sub_key(T && value, void *) + { + return to_unsigned_or_bool(value); + } + + typedef SubKey next; + + using sub_key_type = typename UnsignedForSize::type; +}; +template +struct SubKey : SubKey +{ +}; +template +struct SubKey : SubKey +{ +}; +template +struct SubKey : SubKey +{ +}; +template +struct SubKey : SubKey +{ +}; +template +struct SubKey : SubKey +{ +}; +template +struct FallbackSubKey + : SubKey()))> +{ + using base = SubKey()))>; + + template + static decltype(auto) sub_key(U && value, void * data) + { + return base::sub_key(to_radix_sort_key(value), data); + } +}; +template +struct FallbackSubKey()))>> + : SubKey()))> +{ +}; +template +struct SubKey : FallbackSubKey +{ +}; +template<> +struct SubKey +{ + template + static bool sub_key(T && value, void *) + { + return value; + } + + typedef SubKey next; + + using sub_key_type = bool; +}; +template<> +struct SubKey; +template<> +struct SubKey : SizedSubKey +{ +}; +template<> +struct SubKey : SizedSubKey +{ +}; +template<> +struct SubKey : SizedSubKey +{ +}; +template<> +struct SubKey : SizedSubKey +{ +}; +template<> +struct SubKey : SizedSubKey +{ +}; +template +struct SubKey : SizedSubKey +{ +}; +template +struct PairSecondSubKey : Current +{ + static decltype(auto) sub_key(const std::pair & value, void * sort_data) + { + return Current::sub_key(value.second, sort_data); + } + + using next = typename std::conditional, typename Current::next>::value, SubKey, PairSecondSubKey>::type; +}; +template +struct PairFirstSubKey : Current +{ + static decltype(auto) sub_key(const std::pair & value, void * sort_data) + { + return Current::sub_key(value.first, sort_data); + } + + using next = typename std::conditional, typename Current::next>::value, PairSecondSubKey>, PairFirstSubKey>::type; +}; +template +struct SubKey> : PairFirstSubKey> +{ +}; +template +struct TypeAt : TypeAt +{ +}; +template +struct TypeAt<0, First, More...> +{ + typedef First type; +}; + +template +struct TupleSubKey; + +template +struct NextTupleSubKey +{ + using type = TupleSubKey; +}; +template +struct NextTupleSubKey, First, Second, More...> +{ + using type = TupleSubKey, Second, More...>; +}; +template +struct NextTupleSubKey, First> +{ + using type = SubKey; +}; + +template +struct TupleSubKey : Current +{ + template + static decltype(auto) sub_key(const Tuple & value, void * sort_data) + { + return Current::sub_key(std::get(value), sort_data); + } + + using next = typename NextTupleSubKey::type; +}; +template +struct TupleSubKey : Current +{ + template + static decltype(auto) sub_key(const Tuple & value, void * sort_data) + { + return Current::sub_key(std::get(value), sort_data); + } + + using next = typename NextTupleSubKey::type; +}; +template +struct SubKey> : TupleSubKey<0, SubKey, First, More...> +{ +}; + +struct BaseListSortData +{ + size_t current_index; + size_t recursion_limit; + void * next_sort_data; +}; +template +struct ListSortData : BaseListSortData +{ + void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *); +}; + +template +struct ListElementSubKey : SubKey()[0])>::type> +{ + using base = SubKey()[0])>::type>; + + using next = ListElementSubKey; + + template + static decltype(auto) sub_key(U && value, void * sort_data) + { + BaseListSortData * list_sort_data = static_cast(sort_data); + const T & list = CurrentSubKey::sub_key(value, list_sort_data->next_sort_data); + return base::sub_key(list[list_sort_data->current_index], list_sort_data->next_sort_data); + } +}; + +template +struct ListSubKey +{ + using next = SubKey; + + using sub_key_type = T; + + static const T & sub_key(const T & value, void *) + { + return value; + } +}; + +template +struct FallbackSubKey::value>::type> : ListSubKey +{ +}; + +template +inline void StdSortFallback(It begin, It end, ExtractKey & extract_key) +{ + std::sort(begin, end, [&](auto && l, auto && r){ return extract_key(l) < extract_key(r); }); +} + +template +inline bool StdSortIfLessThanThreshold(It begin, It end, std::ptrdiff_t num_elements, ExtractKey & extract_key) +{ + if (num_elements <= 1) + return true; + if (num_elements >= StdSortThreshold) + return false; + StdSortFallback(begin, end, extract_key); + return true; +} + +template +struct InplaceSorter; + +template +struct UnsignedInplaceSorter +{ + static constexpr size_t ShiftAmount = (((NumBytes - 1) - Offset) * 8); + template + inline static uint8_t current_byte(T && elem, void * sort_data) + { + return CurrentSubKey::sub_key(elem, sort_data) >> ShiftAmount; + } + template + static void sort(It begin, It end, std::ptrdiff_t num_elements, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * sort_data) + { + if (num_elements < AmericanFlagSortThreshold) + american_flag_sort(begin, end, extract_key, next_sort, sort_data); + else + ska_byte_sort(begin, end, extract_key, next_sort, sort_data); + } + + template + static void american_flag_sort(It begin, It end, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * sort_data) + { + PartitionInfo partitions[256]; + for (It it = begin; it != end; ++it) + { + ++partitions[current_byte(extract_key(*it), sort_data)].count; + } + size_t total = 0; + uint8_t remaining_partitions[256]; + int num_partitions = 0; + for (int i = 0; i < 256; ++i) + { + size_t count = partitions[i].count; + if (!count) + continue; + partitions[i].offset = total; + total += count; + partitions[i].next_offset = total; + remaining_partitions[num_partitions] = i; + ++num_partitions; + } + if (num_partitions > 1) + { + uint8_t * current_block_ptr = remaining_partitions; + PartitionInfo * current_block = partitions + *current_block_ptr; + uint8_t * last_block = remaining_partitions + num_partitions - 1; + It it = begin; + It block_end = begin + current_block->next_offset; + It last_element = end - 1; + for (;;) + { + PartitionInfo * block = partitions + current_byte(extract_key(*it), sort_data); + if (block == current_block) + { + ++it; + if (it == last_element) + break; + else if (it == block_end) + { + for (;;) + { + ++current_block_ptr; + if (current_block_ptr == last_block) + goto recurse; + current_block = partitions + *current_block_ptr; + if (current_block->offset != current_block->next_offset) + break; + } + + it = begin + current_block->offset; + block_end = begin + current_block->next_offset; + } + } + else + { + size_t offset = block->offset++; + std::iter_swap(it, begin + offset); + } + } + } + recurse: + if (Offset + 1 != NumBytes || next_sort) + { + size_t start_offset = 0; + It partition_begin = begin; + for (uint8_t * it = remaining_partitions, * end = remaining_partitions + num_partitions; it != end; ++it) + { + size_t end_offset = partitions[*it].next_offset; + It partition_end = begin + end_offset; + std::ptrdiff_t num_elements = end_offset - start_offset; + if (!StdSortIfLessThanThreshold(partition_begin, partition_end, num_elements, extract_key)) + { + UnsignedInplaceSorter::sort(partition_begin, partition_end, num_elements, extract_key, next_sort, sort_data); + } + start_offset = end_offset; + partition_begin = partition_end; + } + } + } + + template + static void ska_byte_sort(It begin, It end, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * sort_data) + { + PartitionInfo partitions[256]; + for (It it = begin; it != end; ++it) + { + ++partitions[current_byte(extract_key(*it), sort_data)].count; + } + uint8_t remaining_partitions[256]; + size_t total = 0; + int num_partitions = 0; + for (int i = 0; i < 256; ++i) + { + size_t count = partitions[i].count; + if (count) + { + partitions[i].offset = total; + total += count; + remaining_partitions[num_partitions] = i; + ++num_partitions; + } + partitions[i].next_offset = total; + } + for (uint8_t * last_remaining = remaining_partitions + num_partitions, * end_partition = remaining_partitions + 1; last_remaining > end_partition;) + { + last_remaining = custom_std_partition(remaining_partitions, last_remaining, [&](uint8_t partition) + { + size_t & begin_offset = partitions[partition].offset; + size_t & end_offset = partitions[partition].next_offset; + if (begin_offset == end_offset) + return false; + + unroll_loop_four_times(begin + begin_offset, end_offset - begin_offset, [partitions = partitions, begin, &extract_key, sort_data](It it) + { + uint8_t this_partition = current_byte(extract_key(*it), sort_data); + size_t offset = partitions[this_partition].offset++; + std::iter_swap(it, begin + offset); + }); + return begin_offset != end_offset; + }); + } + if (Offset + 1 != NumBytes || next_sort) + { + for (uint8_t * it = remaining_partitions + num_partitions; it != remaining_partitions; --it) + { + uint8_t partition = it[-1]; + size_t start_offset = (partition == 0 ? 0 : partitions[partition - 1].next_offset); + size_t end_offset = partitions[partition].next_offset; + It partition_begin = begin + start_offset; + It partition_end = begin + end_offset; + std::ptrdiff_t num_elements = end_offset - start_offset; + if (!StdSortIfLessThanThreshold(partition_begin, partition_end, num_elements, extract_key)) + { + UnsignedInplaceSorter::sort(partition_begin, partition_end, num_elements, extract_key, next_sort, sort_data); + } + } + } + } +}; + +template +struct UnsignedInplaceSorter +{ + template + inline static void sort(It begin, It end, std::ptrdiff_t num_elements, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * next_sort_data) + { + next_sort(begin, end, num_elements, extract_key, next_sort_data); + } +}; + +template +size_t CommonPrefix(It begin, It end, size_t start_index, ExtractKey && extract_key, ElementKey && element_key) +{ + const auto & largest_match_list = extract_key(*begin); + size_t largest_match = largest_match_list.size(); + if (largest_match == start_index) + return start_index; + for (++begin; begin != end; ++begin) + { + const auto & current_list = extract_key(*begin); + size_t current_size = current_list.size(); + if (current_size < largest_match) + { + largest_match = current_size; + if (largest_match == start_index) + return start_index; + } + if (element_key(largest_match_list[start_index]) != element_key(current_list[start_index])) + return start_index; + for (size_t i = start_index + 1; i < largest_match; ++i) + { + if (element_key(largest_match_list[i]) != element_key(current_list[i])) + { + largest_match = i; + break; + } + } + } + return largest_match; +} + +template +struct ListInplaceSorter +{ + using ElementSubKey = ListElementSubKey; + template + static void sort(It begin, It end, ExtractKey & extract_key, ListSortData * sort_data) + { + size_t current_index = sort_data->current_index; + void * next_sort_data = sort_data->next_sort_data; + auto current_key = [&](auto && elem) -> decltype(auto) + { + return CurrentSubKey::sub_key(extract_key(elem), next_sort_data); + }; + auto element_key = [&](auto && elem) -> decltype(auto) + { + return ElementSubKey::base::sub_key(elem, sort_data); + }; + sort_data->current_index = current_index = CommonPrefix(begin, end, current_index, current_key, element_key); + It end_of_shorter_ones = std::partition(begin, end, [&](auto && elem) + { + return current_key(elem).size() <= current_index; + }); + std::ptrdiff_t num_shorter_ones = end_of_shorter_ones - begin; + if (sort_data->next_sort && !StdSortIfLessThanThreshold(begin, end_of_shorter_ones, num_shorter_ones, extract_key)) + { + sort_data->next_sort(begin, end_of_shorter_ones, num_shorter_ones, extract_key, next_sort_data); + } + std::ptrdiff_t num_elements = end - end_of_shorter_ones; + if (!StdSortIfLessThanThreshold(end_of_shorter_ones, end, num_elements, extract_key)) + { + void (*sort_next_element)(It, It, std::ptrdiff_t, ExtractKey &, void *) = static_cast(&sort_from_recursion); + InplaceSorter::sort(end_of_shorter_ones, end, num_elements, extract_key, sort_next_element, sort_data); + } + } + + template + static void sort_from_recursion(It begin, It end, std::ptrdiff_t, ExtractKey & extract_key, void * next_sort_data) + { + ListSortData offset = *static_cast *>(next_sort_data); + ++offset.current_index; + --offset.recursion_limit; + if (offset.recursion_limit == 0) + { + StdSortFallback(begin, end, extract_key); + } + else + { + sort(begin, end, extract_key, &offset); + } + } + + + template + static void sort(It begin, It end, std::ptrdiff_t, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * next_sort_data) + { + ListSortData offset; + offset.current_index = 0; + offset.recursion_limit = 16; + offset.next_sort = next_sort; + offset.next_sort_data = next_sort_data; + sort(begin, end, extract_key, &offset); + } +}; + +template +struct InplaceSorter +{ + template + static void sort(It begin, It end, std::ptrdiff_t, ExtractKey & extract_key, void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *), void * sort_data) + { + It middle = std::partition(begin, end, [&](auto && a){ return !CurrentSubKey::sub_key(extract_key(a), sort_data); }); + if (next_sort) + { + next_sort(begin, middle, middle - begin, extract_key, sort_data); + next_sort(middle, end, end - middle, extract_key, sort_data); + } + } +}; + +template +struct InplaceSorter : UnsignedInplaceSorter +{ +}; +template +struct InplaceSorter : UnsignedInplaceSorter +{ +}; +template +struct InplaceSorter : UnsignedInplaceSorter +{ +}; +template +struct InplaceSorter : UnsignedInplaceSorter +{ +}; +template +struct FallbackInplaceSorter; + +template +struct InplaceSorter : FallbackInplaceSorter +{ +}; + +template +struct FallbackInplaceSorter::value>::type> + : ListInplaceSorter +{ +}; + +template +struct SortStarter; +template +struct SortStarter> +{ + template + static void sort(It, It, std::ptrdiff_t, ExtractKey &, void *) + { + } +}; + +template +struct SortStarter +{ + template + static void sort(It begin, It end, std::ptrdiff_t num_elements, ExtractKey & extract_key, void * next_sort_data = nullptr) + { + if (StdSortIfLessThanThreshold(begin, end, num_elements, extract_key)) + return; + + void (*next_sort)(It, It, std::ptrdiff_t, ExtractKey &, void *) = static_cast(&SortStarter::sort); + if (next_sort == static_cast(&SortStarter>::sort)) + next_sort = nullptr; + InplaceSorter::sort(begin, end, num_elements, extract_key, next_sort, next_sort_data); + } +}; + +template +void inplace_radix_sort(It begin, It end, ExtractKey & extract_key) +{ + using SubKey = SubKey; + SortStarter::sort(begin, end, end - begin, extract_key); +} + +struct IdentityFunctor +{ + template + decltype(auto) operator()(T && i) const + { + return std::forward(i); + } +}; +} + +template +static void ska_sort(It begin, It end, ExtractKey && extract_key) +{ + detail::inplace_radix_sort<128, 1024>(begin, end, extract_key); +} + +template +static void ska_sort(It begin, It end) +{ + ska_sort(begin, end, detail::IdentityFunctor()); +} + +template +bool ska_sort_copy(It begin, It end, OutIt buffer_begin, ExtractKey && key) +{ + std::ptrdiff_t num_elements = end - begin; + if (num_elements < 128 || detail::radix_sort_pass_count::type> >= 8) + { + ska_sort(begin, end, key); + return false; + } + else + return detail::RadixSorter::type>::sort(begin, end, buffer_begin, key); +} +template +bool ska_sort_copy(It begin, It end, OutIt buffer_begin) +{ + return ska_sort_copy(begin, end, buffer_begin, detail::IdentityFunctor()); +} diff --git a/test.cpp b/test.cpp index 84ad84d..c1880a9 100644 --- a/test.cpp +++ b/test.cpp @@ -6,6 +6,7 @@ // #define CREEL // Overwrites TEST_LEN to 16 and sets MAGYAR_SORT_NIBBLE! // Number of input elements to generate - unused when CREEL is defined! +//#define SORT_WIDTH 200000000 #define SORT_WIDTH 40000 // Uncomment this to use nibbles as digits and not bytes - CREEL defines this anyways //#define MAGYAR_SORT_NIBBLE @@ -13,6 +14,8 @@ // Uncomment if you want to see output before / after sorts (debugging for example) //#define PRINT_OUTPUT +//#define SKA_SORT + /* Includes */ #include @@ -24,6 +27,10 @@ #include // std::sort #include "magyarsort.h" +#ifdef SKA_SORT +#include "ska_sort.hpp" +#endif // SKA_SORT + /* Input generation and prerequisites */ #ifdef CREEL @@ -97,9 +104,16 @@ int main() { MagyarSort::debugArr(arr1, in1.size()); #endif // PRINT_OUTPUT - /* std::sort */ auto stdBegin = std::chrono::high_resolution_clock::now(); +#ifndef SKA_SORT + /* std::sort */ std::sort(std::begin(in2), std::end(in2)); +#else // SKA_SORT + /* Ska-sort */ + //ska_sort(std::begin(in2), std::end(in2)); + std::vector buffer(in2.size()); + if (ska_sort_copy(std::begin(in2), std::end(in2), std::begin(buffer))) in2.swap(buffer); +#endif // SKA_SORT auto stdEnd = std::chrono::high_resolution_clock::now(); #ifdef PRINT_OUTPUT