diff --git a/media/libaom/config/generic/config/aom_config.asm b/media/libaom/config/generic/config/aom_config.asm index a28ad482eafb..f8db77d95e4e 100644 --- a/media/libaom/config/generic/config/aom_config.asm +++ b/media/libaom/config/generic/config/aom_config.asm @@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0 CONFIG_GCC equ 1 CONFIG_GCOV equ 0 CONFIG_GPROF equ 0 +CONFIG_HIGHWAY equ 0 CONFIG_INSPECTION equ 0 CONFIG_INTERNAL_STATS equ 0 CONFIG_INTER_STATS_ONLY equ 0 diff --git a/media/libaom/config/generic/config/aom_config.h b/media/libaom/config/generic/config/aom_config.h index 61b49dbc66a6..c2c0e39faf57 100644 --- a/media/libaom/config/generic/config/aom_config.h +++ b/media/libaom/config/generic/config/aom_config.h @@ -42,6 +42,7 @@ #define CONFIG_GCC 1 #define CONFIG_GCOV 0 #define CONFIG_GPROF 0 +#define CONFIG_HIGHWAY 0 #define CONFIG_INSPECTION 0 #define CONFIG_INTERNAL_STATS 0 #define CONFIG_INTER_STATS_ONLY 0 diff --git a/media/libaom/config/linux/arm/config/aom_config.asm b/media/libaom/config/linux/arm/config/aom_config.asm index 9fd3159dba6d..fff68325f78a 100644 --- a/media/libaom/config/linux/arm/config/aom_config.asm +++ b/media/libaom/config/linux/arm/config/aom_config.asm @@ -40,6 +40,7 @@ .equ CONFIG_GCC, 1 .equ CONFIG_GCOV, 0 .equ CONFIG_GPROF, 0 +.equ CONFIG_HIGHWAY, 0 .equ CONFIG_INSPECTION, 0 .equ CONFIG_INTERNAL_STATS, 0 .equ CONFIG_INTER_STATS_ONLY, 0 diff --git a/media/libaom/config/linux/arm/config/aom_config.h b/media/libaom/config/linux/arm/config/aom_config.h index 15350b976cfb..478f88991d2b 100644 --- a/media/libaom/config/linux/arm/config/aom_config.h +++ b/media/libaom/config/linux/arm/config/aom_config.h @@ -42,6 +42,7 @@ #define CONFIG_GCC 1 #define CONFIG_GCOV 0 #define CONFIG_GPROF 0 +#define CONFIG_HIGHWAY 0 #define CONFIG_INSPECTION 0 #define CONFIG_INTERNAL_STATS 0 #define CONFIG_INTER_STATS_ONLY 0 diff --git a/media/libaom/config/linux/ia32/config/aom_config.asm b/media/libaom/config/linux/ia32/config/aom_config.asm index 0f2be2761ba4..f81176b9cd4e 100644 --- a/media/libaom/config/linux/ia32/config/aom_config.asm +++ b/media/libaom/config/linux/ia32/config/aom_config.asm @@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0 CONFIG_GCC equ 1 CONFIG_GCOV equ 0 CONFIG_GPROF equ 0 +CONFIG_HIGHWAY equ 0 CONFIG_INSPECTION equ 0 CONFIG_INTERNAL_STATS equ 0 CONFIG_INTER_STATS_ONLY equ 0 diff --git a/media/libaom/config/linux/ia32/config/aom_config.h b/media/libaom/config/linux/ia32/config/aom_config.h index 89ff5324d067..ff1abe142382 100644 --- a/media/libaom/config/linux/ia32/config/aom_config.h +++ b/media/libaom/config/linux/ia32/config/aom_config.h @@ -42,6 +42,7 @@ #define CONFIG_GCC 1 #define CONFIG_GCOV 0 #define CONFIG_GPROF 0 +#define CONFIG_HIGHWAY 0 #define CONFIG_INSPECTION 0 #define CONFIG_INTERNAL_STATS 0 #define CONFIG_INTER_STATS_ONLY 0 diff --git a/media/libaom/config/linux/x64/config/aom_config.asm b/media/libaom/config/linux/x64/config/aom_config.asm index 3091f2ae3233..f71cb89c5f53 100644 --- a/media/libaom/config/linux/x64/config/aom_config.asm +++ b/media/libaom/config/linux/x64/config/aom_config.asm @@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0 CONFIG_GCC equ 1 CONFIG_GCOV equ 0 CONFIG_GPROF equ 0 +CONFIG_HIGHWAY equ 0 CONFIG_INSPECTION equ 0 CONFIG_INTERNAL_STATS equ 0 CONFIG_INTER_STATS_ONLY equ 0 diff --git a/media/libaom/config/linux/x64/config/aom_config.h b/media/libaom/config/linux/x64/config/aom_config.h index a86fe4a4eac4..fa583c15eaf4 100644 --- a/media/libaom/config/linux/x64/config/aom_config.h +++ 
b/media/libaom/config/linux/x64/config/aom_config.h @@ -42,6 +42,7 @@ #define CONFIG_GCC 1 #define CONFIG_GCOV 0 #define CONFIG_GPROF 0 +#define CONFIG_HIGHWAY 0 #define CONFIG_INSPECTION 0 #define CONFIG_INTERNAL_STATS 0 #define CONFIG_INTER_STATS_ONLY 0 diff --git a/media/libaom/config/mac/arm64/config/aom_config.asm b/media/libaom/config/mac/arm64/config/aom_config.asm index c8f503e63601..479a7552008f 100644 --- a/media/libaom/config/mac/arm64/config/aom_config.asm +++ b/media/libaom/config/mac/arm64/config/aom_config.asm @@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0 CONFIG_GCC equ 1 CONFIG_GCOV equ 0 CONFIG_GPROF equ 0 +CONFIG_HIGHWAY equ 0 CONFIG_INSPECTION equ 0 CONFIG_INTERNAL_STATS equ 0 CONFIG_INTER_STATS_ONLY equ 0 diff --git a/media/libaom/config/mac/arm64/config/aom_config.h b/media/libaom/config/mac/arm64/config/aom_config.h index c501b7ef2586..5a0dd6246a55 100644 --- a/media/libaom/config/mac/arm64/config/aom_config.h +++ b/media/libaom/config/mac/arm64/config/aom_config.h @@ -42,6 +42,7 @@ #define CONFIG_GCC 1 #define CONFIG_GCOV 0 #define CONFIG_GPROF 0 +#define CONFIG_HIGHWAY 0 #define CONFIG_INSPECTION 0 #define CONFIG_INTERNAL_STATS 0 #define CONFIG_INTER_STATS_ONLY 0 diff --git a/media/libaom/config/mac/x64/config/aom_config.asm b/media/libaom/config/mac/x64/config/aom_config.asm index 3091f2ae3233..f71cb89c5f53 100644 --- a/media/libaom/config/mac/x64/config/aom_config.asm +++ b/media/libaom/config/mac/x64/config/aom_config.asm @@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0 CONFIG_GCC equ 1 CONFIG_GCOV equ 0 CONFIG_GPROF equ 0 +CONFIG_HIGHWAY equ 0 CONFIG_INSPECTION equ 0 CONFIG_INTERNAL_STATS equ 0 CONFIG_INTER_STATS_ONLY equ 0 diff --git a/media/libaom/config/mac/x64/config/aom_config.h b/media/libaom/config/mac/x64/config/aom_config.h index a86fe4a4eac4..fa583c15eaf4 100644 --- a/media/libaom/config/mac/x64/config/aom_config.h +++ b/media/libaom/config/mac/x64/config/aom_config.h @@ -42,6 +42,7 @@ #define CONFIG_GCC 1 #define CONFIG_GCOV 0 #define CONFIG_GPROF 0 +#define CONFIG_HIGHWAY 0 #define CONFIG_INSPECTION 0 #define CONFIG_INTERNAL_STATS 0 #define CONFIG_INTER_STATS_ONLY 0 diff --git a/media/libaom/config/win/ia32/config/aom_config.asm b/media/libaom/config/win/ia32/config/aom_config.asm index c30f034e5502..e6c9eaa37c5f 100644 --- a/media/libaom/config/win/ia32/config/aom_config.asm +++ b/media/libaom/config/win/ia32/config/aom_config.asm @@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0 CONFIG_GCC equ 1 CONFIG_GCOV equ 0 CONFIG_GPROF equ 0 +CONFIG_HIGHWAY equ 0 CONFIG_INSPECTION equ 0 CONFIG_INTERNAL_STATS equ 0 CONFIG_INTER_STATS_ONLY equ 0 diff --git a/media/libaom/config/win/ia32/config/aom_config.h b/media/libaom/config/win/ia32/config/aom_config.h index bb73fdd79fc6..5ced843f1f31 100644 --- a/media/libaom/config/win/ia32/config/aom_config.h +++ b/media/libaom/config/win/ia32/config/aom_config.h @@ -42,6 +42,7 @@ #define CONFIG_GCC 1 #define CONFIG_GCOV 0 #define CONFIG_GPROF 0 +#define CONFIG_HIGHWAY 0 #define CONFIG_INSPECTION 0 #define CONFIG_INTERNAL_STATS 0 #define CONFIG_INTER_STATS_ONLY 0 diff --git a/media/libaom/config/win/x64/config/aom_config.asm b/media/libaom/config/win/x64/config/aom_config.asm index 3091f2ae3233..f71cb89c5f53 100644 --- a/media/libaom/config/win/x64/config/aom_config.asm +++ b/media/libaom/config/win/x64/config/aom_config.asm @@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0 CONFIG_GCC equ 1 CONFIG_GCOV equ 0 CONFIG_GPROF equ 0 +CONFIG_HIGHWAY equ 0 CONFIG_INSPECTION equ 0 CONFIG_INTERNAL_STATS equ 0 CONFIG_INTER_STATS_ONLY equ 0 diff --git 
a/media/libaom/config/win/x64/config/aom_config.h b/media/libaom/config/win/x64/config/aom_config.h index a86fe4a4eac4..fa583c15eaf4 100644 --- a/media/libaom/config/win/x64/config/aom_config.h +++ b/media/libaom/config/win/x64/config/aom_config.h @@ -42,6 +42,7 @@ #define CONFIG_GCC 1 #define CONFIG_GCOV 0 #define CONFIG_GPROF 0 +#define CONFIG_HIGHWAY 0 #define CONFIG_INSPECTION 0 #define CONFIG_INTERNAL_STATS 0 #define CONFIG_INTER_STATS_ONLY 0 diff --git a/media/libaom/moz.yaml b/media/libaom/moz.yaml index e6beabca41fa..db10744bed94 100644 --- a/media/libaom/moz.yaml +++ b/media/libaom/moz.yaml @@ -20,11 +20,11 @@ origin: # Human-readable identifier for this version/release # Generally "version NNN", "tag SSS", "bookmark SSS" - release: 4e3595a426bacb022e8152540a32753c43822f54 (Thu Mar 27 13:41:21 2025 -0700). + release: 719f60edc51b6141a2434bf1b5110c2fb075b246 (Fri Apr 25 19:13:37 2025 -0700). # Revision to pull in # Must be a long or short commit SHA (long preferred) - revision: 4e3595a426bacb022e8152540a32753c43822f54 + revision: 719f60edc51b6141a2434bf1b5110c2fb075b246 # The package's license, where possible using the mnemonic from # https://spdx.org/licenses/ diff --git a/third_party/aom/AUTHORS b/third_party/aom/AUTHORS index 84c63b2f1766..a12d6b2554c7 100644 --- a/third_party/aom/AUTHORS +++ b/third_party/aom/AUTHORS @@ -32,6 +32,7 @@ Arild Fuldseth Aron Rosenberg Arpad Panyik Arun Singh Negi +Athulya Raj Raji Mohini Attila Nagy Balaji Anandapadmanaban Bohan Li diff --git a/third_party/aom/CHANGELOG b/third_party/aom/CHANGELOG index fce8dc94ac9a..76d870632d6b 100644 --- a/third_party/aom/CHANGELOG +++ b/third_party/aom/CHANGELOG @@ -1,3 +1,19 @@ +2025-04-11 v3.12.1 + This release includes several bug fixes. This release is ABI + compatible with the last release. See + https://aomedia.googlesource.com/aom/+log/v3.12.0..v3.12.1 for all the + commits in this release. + + - Bug Fixes + * b:396169342: Assertion + `av1_is_subpelmv_in_range(&ms_params.mv_limits, start_mv)' failed. + * b:401671154: typo in void init_src_params(...) + * Coverity defect 323670: Uninitialized scalar variable in + encode_with_and_without_superres() + * cmake: bump minimum version to 3.16 + * cfl_ppc: fix subtract_average_vsx + * Fix an incorrect index in av1_highbd_pixel_proj_error_neon + 2025-02-10 v3.12.0 This release includes new codec interfaces, compression efficiency and perceptual improvements, speedup and memory optimizations, and bug diff --git a/third_party/aom/CMakeLists.txt b/third_party/aom/CMakeLists.txt index 047b258a27c3..bd96b609740a 100644 --- a/third_party/aom/CMakeLists.txt +++ b/third_party/aom/CMakeLists.txt @@ -59,7 +59,7 @@ endif() # # The VERSION number in project() should be updated when these variables are. 
set(LT_CURRENT 15) -set(LT_REVISION 0) +set(LT_REVISION 1) set(LT_AGE 12) math(EXPR SO_VERSION "${LT_CURRENT} - ${LT_AGE}") set(SO_FILE_VERSION "${SO_VERSION}.${LT_AGE}.${LT_REVISION}") @@ -270,6 +270,43 @@ add_rtcd_build_step("${AOM_ROOT}/av1/common/av1_rtcd_defs.pl" add_library(aom_rtcd OBJECT ${AOM_RTCD_SOURCES}) add_dependencies(aom_rtcd aom_version) +if(CONFIG_HIGHWAY) + list(APPEND AOM_HIGHWAY_SOURCES + "${AOM_ROOT}/third_party/highway/hwy/abort.h" + "${AOM_ROOT}/third_party/highway/hwy/aligned_allocator.h" + "${AOM_ROOT}/third_party/highway/hwy/base.h" + "${AOM_ROOT}/third_party/highway/hwy/cache_control.h" + "${AOM_ROOT}/third_party/highway/hwy/per_target.h" + "${AOM_ROOT}/third_party/highway/hwy/print.h" + "${AOM_ROOT}/third_party/highway/hwy/foreach_target.h" + "${AOM_ROOT}/third_party/highway/hwy/highway_export.h" + "${AOM_ROOT}/third_party/highway/hwy/highway.h" + "${AOM_ROOT}/third_party/highway/hwy/print-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/timer-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/detect_compiler_arch.h" + "${AOM_ROOT}/third_party/highway/hwy/detect_targets.h" + "${AOM_ROOT}/third_party/highway/hwy/targets.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/arm_neon-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/arm_sve-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/emu128-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/generic_ops-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/scalar-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/set_macros-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/shared-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/x86_128-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/x86_256-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/x86_512-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/ops/x86_avx3-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/contrib/algo/copy-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/contrib/algo/find-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/contrib/algo/transform-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/contrib/dot/dot-inl.h" + "${AOM_ROOT}/third_party/highway/hwy/contrib/image/image.h" + "${AOM_ROOT}/third_party/highway/hwy/contrib/math/math-inl.h") + add_library(aom_hwy OBJECT ${AOM_HIGHWAY_SOURCES}) + set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} aom_hwy) +endif() + if(ENABLE_EXAMPLES) add_library(aom_encoder_stats OBJECT ${AOM_ENCODER_STATS_SOURCES}) set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} aom_encoder_stats) diff --git a/third_party/aom/aom_dsp/aom_dsp.cmake b/third_party/aom/aom_dsp/aom_dsp.cmake index 33d97135374d..9ceb10990d0b 100644 --- a/third_party/aom/aom_dsp/aom_dsp.cmake +++ b/third_party/aom/aom_dsp/aom_dsp.cmake @@ -181,6 +181,11 @@ if(CONFIG_AV1_ENCODER) "${AOM_ROOT}/aom_dsp/variance.c" "${AOM_ROOT}/aom_dsp/variance.h") + if(CONFIG_HIGHWAY) + list(APPEND AOM_DSP_ENCODER_SOURCES "${AOM_ROOT}/aom_dsp/reduce_sum_hwy.h" + "${AOM_ROOT}/aom_dsp/sad_hwy.h") + endif() + # Flow estimation library and grain/noise table/model. 
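CONFIG_HIGHWAY is wired in here as an optional object library (aom_hwy) plus extra encoder sources; nothing is compiled differently unless the flag is switched on at CMake time, and the checked-in Mozilla configs earlier in the patch simply record the default of 0. For orientation, the usual consumption pattern on the C/C++ side looks roughly like the sketch below; the function names are placeholders for illustration, not real libaom symbols.

// Sketch only: gating one implementation behind CONFIG_HIGHWAY.
#include <cstdint>

#ifndef CONFIG_HIGHWAY
#define CONFIG_HIGHWAY 0  // mirrors the default in aom_config_defaults.cmake
#endif

// Provided by a Highway-compiled translation unit when the flag is on
// (hypothetical name).
unsigned int my_sad_highway(const uint8_t *src, const uint8_t *ref, int n);

static unsigned int my_sad_legacy(const uint8_t *src, const uint8_t *ref,
                                  int n) {
  unsigned int sad = 0;
  for (int i = 0; i < n; ++i)
    sad += (src[i] > ref[i]) ? src[i] - ref[i] : ref[i] - src[i];
  return sad;
}

unsigned int my_sad(const uint8_t *src, const uint8_t *ref, int n) {
#if CONFIG_HIGHWAY
  return my_sad_highway(src, ref, n);  // Highway kernel replaces the intrinsics
#else
  return my_sad_legacy(src, ref, n);
#endif
}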
if(NOT CONFIG_REALTIME_ONLY) list(APPEND AOM_DSP_ENCODER_SOURCES @@ -259,6 +264,11 @@ if(CONFIG_AV1_ENCODER) "${AOM_ROOT}/aom_dsp/x86/blk_sse_sum_avx2.c" "${AOM_ROOT}/aom_dsp/x86/sum_squares_avx2.c") + if(CONFIG_HIGHWAY) + list(APPEND AOM_DSP_ENCODER_INTRIN_AVX2 + "${AOM_ROOT}/aom_dsp/x86/sad_hwy_avx2.cc") + endif() + list(APPEND AOM_DSP_ENCODER_INTRIN_AVX "${AOM_ROOT}/aom_dsp/x86/aom_quantize_avx.c") diff --git a/third_party/aom/aom_dsp/reduce_sum_hwy.h b/third_party/aom/aom_dsp/reduce_sum_hwy.h new file mode 100644 index 000000000000..9f4c00545c0f --- /dev/null +++ b/third_party/aom/aom_dsp/reduce_sum_hwy.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2025, Alliance for Open Media. All rights reserved. + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ +#ifndef AOM_AOM_DSP_REDUCE_SUM_HWY_H_ +#define AOM_AOM_DSP_REDUCE_SUM_HWY_H_ + +#include +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); + +namespace { +namespace HWY_NAMESPACE { + +namespace hn = hwy::HWY_NAMESPACE; + +template +struct BlockReduceTraits; + +template <> +struct BlockReduceTraits<1> { + template + HWY_ATTR HWY_INLINE static hn::VFromD ReduceSum(D d, hn::VFromD v) { + (void)d; + return v; + } +}; + +template +struct BlockReduceTraits { + static_assert(NumBlocks > 1, + "Primary template BlockReduceTraits assumes NumBlocks > 1"); + static_assert((NumBlocks & (NumBlocks - 1)) == 0, + "BlockReduceTraits requires NumBlocks to be a power of 2."); + + template + HWY_ATTR HWY_INLINE static hn::VFromD> ReduceSum( + D d, hn::VFromD v) { + (void)d; + constexpr hn::Half half_d; + auto v_half = hn::Add(hn::LowerHalf(half_d, v), hn::UpperHalf(half_d, v)); + return BlockReduceTraits::ReduceSum(half_d, v_half); + } +}; + +// ReduceSum across blocks. +// For example, with a 4-block vector with 16 lanes of uint32_t: +// [a3 b3 c3 d3 a2 b2 c2 d2 a1 b1 c1 d1 a0 b0 c0 d0] +// returns a vector with 4 lanes: +// [a3+a2+a1+a0 b3+b2+b1+b0 c3+c2+c1+c0 d3+d2+d1+d0] +template +HWY_ATTR HWY_INLINE hn::Vec> BlockReduceSum( + D int_tag, hn::VFromD v) { + return BlockReduceTraits::ReduceSum(int_tag, v); +} + +} // namespace HWY_NAMESPACE +} // namespace + +HWY_AFTER_NAMESPACE(); + +#endif // AOM_AOM_DSP_REDUCE_SUM_HWY_H_ diff --git a/third_party/aom/aom_dsp/sad_hwy.h b/third_party/aom/aom_dsp/sad_hwy.h new file mode 100644 index 000000000000..b14242591bd5 --- /dev/null +++ b/third_party/aom/aom_dsp/sad_hwy.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2025, Alliance for Open Media. All rights reserved. + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
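The new reduce_sum_hwy.h above folds a multi-block vector in half repeatedly until a single block of per-lane sums remains (note that some template angle brackets were lost in this copy of the patch, so the header is best read from upstream). A scalar model of the same halving idea, assuming a power-of-two number of blocks as the real template enforces, is sketched here; it is an illustration of the recursion, not the SIMD code itself.

// Scalar model of BlockReduceSum: repeatedly add the upper half of the
// "vector" onto the lower half until one block of lanes is left.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<uint32_t> BlockReduceSumModel(std::vector<uint32_t> lanes,
                                                 size_t lanes_per_block) {
  assert(lanes.size() % lanes_per_block == 0);
  size_t n = lanes.size();
  // Number of blocks is assumed to be a power of two.
  while (n > lanes_per_block) {
    n /= 2;
    for (size_t i = 0; i < n; ++i) lanes[i] += lanes[i + n];
  }
  lanes.resize(lanes_per_block);
  return lanes;
}

int main() {
  // 16 lanes = 4 blocks of 4, matching the example in the header's comment:
  // each output lane is the sum of the same lane position across all blocks.
  std::vector<uint32_t> v = { 1,   2,    3,    4,    10,   20,   30,   40,
                              100, 200,  300,  400,  1000, 2000, 3000, 4000 };
  auto r = BlockReduceSumModel(v, 4);
  for (uint32_t x : r) std::printf("%u ", x);  // prints: 1111 2222 3333 4444
  return 0;
}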
+ */ +#ifndef AOM_AOM_DSP_SAD_HWY_H_ +#define AOM_AOM_DSP_SAD_HWY_H_ + +#include "aom_dsp/reduce_sum_hwy.h" +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); + +namespace { +namespace HWY_NAMESPACE { + +namespace hn = hwy::HWY_NAMESPACE; + +template +HWY_MAYBE_UNUSED unsigned int SumOfAbsoluteDiff( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred = nullptr) { + constexpr hn::CappedTag pixel_tag; + constexpr hn::Repartition intermediate_sum_tag; + const int vw = hn::Lanes(pixel_tag); + auto sum_sad = hn::Zero(intermediate_sum_tag); + const bool is_sad_avg = second_pred != nullptr; + for (int i = 0; i < h; ++i) { + for (int j = 0; j < BlockWidth; j += vw) { + auto src_vec = hn::LoadU(pixel_tag, &src_ptr[j]); + auto ref_vec = hn::LoadU(pixel_tag, &ref_ptr[j]); + if (is_sad_avg) { + auto sec_pred_vec = hn::LoadU(pixel_tag, &second_pred[j]); + ref_vec = hn::AverageRound(ref_vec, sec_pred_vec); + } + auto sad = hn::SumsOf8AbsDiff(src_vec, ref_vec); + sum_sad = hn::Add(sum_sad, sad); + } + src_ptr += src_stride; + ref_ptr += ref_stride; + if (is_sad_avg) { + second_pred += BlockWidth; + } + } + return static_cast( + hn::ReduceSum(intermediate_sum_tag, sum_sad)); +} + +} // namespace HWY_NAMESPACE +} // namespace + +#define FSAD(w, h, suffix) \ + extern "C" unsigned int aom_sad##w##x##h##_##suffix( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride); \ + HWY_ATTR unsigned int aom_sad##w##x##h##_##suffix( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride) { \ + return HWY_NAMESPACE::SumOfAbsoluteDiff(src_ptr, src_stride, ref_ptr, \ + ref_stride, h); \ + } + +#define FOR_EACH_SAD_BLOCK_SIZE(X, suffix) \ + X(128, 128, suffix) \ + X(128, 64, suffix) \ + X(64, 128, suffix) \ + X(64, 64, suffix) \ + X(64, 32, suffix) + +HWY_AFTER_NAMESPACE(); + +#endif // AOM_AOM_DSP_SAD_HWY_H_ diff --git a/third_party/aom/aom_dsp/x86/sad_avx2.c b/third_party/aom/aom_dsp/x86/sad_avx2.c index f19ff0577442..ee0238038354 100644 --- a/third_party/aom/aom_dsp/x86/sad_avx2.c +++ b/third_party/aom/aom_dsp/x86/sad_avx2.c @@ -101,11 +101,17 @@ static inline unsigned int sad32xh_avx2(const uint8_t *src_ptr, int src_stride, h / 2); \ } +#if CONFIG_HIGHWAY +#define FSAD64 \ + FSADS64_H(64) \ + FSADS64_H(32) +#else #define FSAD64 \ FSAD64_H(64) \ FSAD64_H(32) \ FSADS64_H(64) \ FSADS64_H(32) +#endif #define FSAD32 \ FSAD32_H(64) \ diff --git a/third_party/aom/aom_dsp/x86/sad_hwy_avx2.cc b/third_party/aom/aom_dsp/x86/sad_hwy_avx2.cc new file mode 100644 index 000000000000..2df2646f9d35 --- /dev/null +++ b/third_party/aom/aom_dsp/x86/sad_hwy_avx2.cc @@ -0,0 +1,17 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved. + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
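For reference, the quantity the new Highway kernel computes is a plain sum of absolute differences over a BlockWidth x h window, with an optional rounded average against a second predictor (the avg-SAD case handled by hn::AverageRound). A scalar sketch using placeholder names rather than the real aom_sad*_avx2 entry points:

#include <cstdint>
#include <cstdlib>

// Scalar model of SumOfAbsoluteDiff() in sad_hwy.h. When second_pred is
// non-null, the reference pixel is first replaced by the rounded average of
// ref and second_pred, matching the vector code.
template <int BlockWidth>
unsigned int SadModel(const uint8_t *src, int src_stride, const uint8_t *ref,
                      int ref_stride, int h,
                      const uint8_t *second_pred = nullptr) {
  unsigned int sum = 0;
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < BlockWidth; ++j) {
      int r = ref[j];
      if (second_pred != nullptr) r = (r + second_pred[j] + 1) >> 1;
      sum += static_cast<unsigned int>(std::abs(src[j] - r));
    }
    src += src_stride;
    ref += ref_stride;
    if (second_pred != nullptr) second_pred += BlockWidth;  // packed w x h buffer
  }
  return sum;
}

The FSAD macro then stamps out one extern "C" wrapper per block size, and FOR_EACH_SAD_BLOCK_SIZE(FSAD, avx2) expands it for the 128x128 through 64x32 sizes, which is why the corresponding hand-written AVX2 versions are compiled out under CONFIG_HIGHWAY (see the #if !CONFIG_HIGHWAY guards that follow).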
+ */ + +#define HWY_BASELINE_TARGETS HWY_AVX2 +#define HWY_BROKEN_32BIT 0 + +#include "aom_dsp/sad_hwy.h" + +FOR_EACH_SAD_BLOCK_SIZE(FSAD, avx2) diff --git a/third_party/aom/aom_dsp/x86/sad_impl_avx2.c b/third_party/aom/aom_dsp/x86/sad_impl_avx2.c index 0d1b5ab8765e..2c6fa2450b33 100644 --- a/third_party/aom/aom_dsp/x86/sad_impl_avx2.c +++ b/third_party/aom/aom_dsp/x86/sad_impl_avx2.c @@ -56,6 +56,7 @@ static unsigned int sad64x64(const uint8_t *src_ptr, int src_stride, return sum; } +#if !CONFIG_HIGHWAY unsigned int aom_sad128x64_avx2(const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride) { unsigned int half_width = 64; @@ -83,6 +84,7 @@ unsigned int aom_sad128x128_avx2(const uint8_t *src_ptr, int src_stride, sum += aom_sad128x64_avx2(src_ptr, src_stride, ref_ptr, ref_stride); return sum; } +#endif unsigned int aom_sad_skip_128x64_avx2(const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride) { diff --git a/third_party/aom/aom_dsp/x86/synonyms.h b/third_party/aom/aom_dsp/x86/synonyms.h index bbaa0a0c4818..0f829821a996 100644 --- a/third_party/aom/aom_dsp/x86/synonyms.h +++ b/third_party/aom/aom_dsp/x86/synonyms.h @@ -46,16 +46,6 @@ static inline __m128i xx_loadu_128(const void *a) { return _mm_loadu_si128((const __m128i *)a); } -// _mm_loadu_si64 has been introduced in GCC 9, reimplement the function -// manually on older compilers. -#if !defined(__clang__) && __GNUC_MAJOR__ < 9 -static inline __m128i xx_loadu_2x64(const void *hi, const void *lo) { - __m64 hi_, lo_; - memcpy(&hi_, hi, sizeof(hi_)); - memcpy(&lo_, lo, sizeof(lo_)); - return _mm_set_epi64(hi_, lo_); -} -#else // Load 64 bits from each of hi and low, and pack into an SSE register // Since directly loading as `int64_t`s and using _mm_set_epi64 may violate // the strict aliasing rule, this takes a different approach @@ -63,7 +53,6 @@ static inline __m128i xx_loadu_2x64(const void *hi, const void *lo) { return _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)lo), _mm_loadl_epi64((const __m128i *)hi)); } -#endif static inline void xx_storel_32(void *const a, const __m128i v) { const int val = _mm_cvtsi128_si32(v); diff --git a/third_party/aom/aom_dsp/x86/synonyms_avx2.h b/third_party/aom/aom_dsp/x86/synonyms_avx2.h index 5b8a79f8c445..20e6a4b23a04 100644 --- a/third_party/aom/aom_dsp/x86/synonyms_avx2.h +++ b/third_party/aom/aom_dsp/x86/synonyms_avx2.h @@ -76,26 +76,11 @@ static inline __m256i yy_loadu_4x64(const void *e3, const void *e2, return yy_set_m128i(_mm_castpd_si128(v23), _mm_castpd_si128(v01)); } -#define GCC_VERSION (__GNUC__ * 10000 \ - + __GNUC_MINOR__ * 100 \ - + __GNUC_PATCHLEVEL__) - -// _mm256_loadu2_m128i has been introduced in GCC 10.1 -#if !defined(__clang__) && GCC_VERSION < 101000 -static inline __m256i yy_loadu2_128(const void *hi, const void *lo) { - __m128i mhi = _mm_loadu_si128((const __m128i *)(hi)); - __m128i mlo = _mm_loadu_si128((const __m128i *)(lo)); - return _mm256_set_m128i(mhi, mlo); -} -#else static inline __m256i yy_loadu2_128(const void *hi, const void *lo) { __m128i mhi = _mm_loadu_si128((const __m128i *)(hi)); __m128i mlo = _mm_loadu_si128((const __m128i *)(lo)); return yy_set_m128i(mhi, mlo); } -#endif - -#undef GCC_VERSION static inline void yy_storeu2_128(void *hi, void *lo, const __m256i a) { _mm_storeu_si128((__m128i *)hi, _mm256_extracti128_si256(a, 1)); diff --git a/third_party/aom/aom_ports/x86.h b/third_party/aom/aom_ports/x86.h index 3d27a2e83ab6..742c8f369a33 100644 --- a/third_party/aom/aom_ports/x86.h +++ 
b/third_party/aom/aom_ports/x86.h @@ -171,6 +171,19 @@ static inline uint64_t xgetbv(void) { #define BIT(n) (1u << (n)) #endif +#define MMX_BITS BIT(23) +#define SSE_BITS BIT(25) +#define SSE2_BITS BIT(26) +#define SSE3_BITS BIT(0) +#define SSSE3_BITS BIT(9) +#define SSE4_1_BITS BIT(19) +// Bits 27 (OSXSAVE) & 28 (256-bit AVX) +#define AVX_BITS (BIT(27) | BIT(28)) +#define AVX2_BITS BIT(5) + +#define FEATURE_SET(reg, feature) \ + (((reg) & (feature##_BITS)) == (feature##_BITS)) + static inline int x86_simd_caps(void) { unsigned int flags = 0; unsigned int mask = ~0u; @@ -179,11 +192,9 @@ static inline int x86_simd_caps(void) { /* See if the CPU capabilities are being overridden by the environment */ env = getenv("AOM_SIMD_CAPS"); - if (env && *env) return (int)strtol(env, NULL, 0); env = getenv("AOM_SIMD_CAPS_MASK"); - if (env && *env) mask = (unsigned int)strtoul(env, NULL, 0); /* Ensure that the CPUID instruction supports extended features */ @@ -194,37 +205,26 @@ static inline int x86_simd_caps(void) { /* Get the standard feature flags */ cpuid(1, 0, reg_eax, reg_ebx, reg_ecx, reg_edx); - if (reg_edx & BIT(23)) flags |= HAS_MMX; - - if (reg_edx & BIT(25)) flags |= HAS_SSE; /* aka xmm */ - - if (reg_edx & BIT(26)) flags |= HAS_SSE2; /* aka wmt */ - - if (reg_ecx & BIT(0)) flags |= HAS_SSE3; - - if (reg_ecx & BIT(9)) flags |= HAS_SSSE3; - - if (reg_ecx & BIT(19)) flags |= HAS_SSE4_1; - - if (reg_ecx & BIT(20)) flags |= HAS_SSE4_2; + flags |= FEATURE_SET(reg_edx, MMX) ? HAS_MMX : 0; + flags |= FEATURE_SET(reg_edx, SSE) ? HAS_SSE : 0; + flags |= FEATURE_SET(reg_edx, SSE2) ? HAS_SSE2 : 0; + flags |= FEATURE_SET(reg_ecx, SSE3) ? HAS_SSE3 : 0; + flags |= FEATURE_SET(reg_ecx, SSSE3) ? HAS_SSSE3 : 0; + flags |= FEATURE_SET(reg_ecx, SSE4_1) ? HAS_SSE4_1 : 0; // bits 27 (OSXSAVE) & 28 (256-bit AVX) - if ((reg_ecx & (BIT(27) | BIT(28))) == (BIT(27) | BIT(28))) { + if (FEATURE_SET(reg_ecx, AVX)) { // Check for OS-support of YMM state. Necessary for AVX and AVX2. if ((xgetbv() & 0x6) == 0x6) { flags |= HAS_AVX; - if (max_cpuid_val >= 7) { /* Get the leaf 7 feature flags. Needed to check for AVX2 support */ cpuid(7, 0, reg_eax, reg_ebx, reg_ecx, reg_edx); - - if (reg_ebx & BIT(5)) flags |= HAS_AVX2; + flags |= FEATURE_SET(reg_ebx, AVX2) ? HAS_AVX2 : 0; } } } - (void)reg_eax; // Avoid compiler warning on unused-but-set variable. - return flags & mask; } diff --git a/third_party/aom/av1/av1_cx_iface.c b/third_party/aom/av1/av1_cx_iface.c index c412e50180b7..9a423dc4fa9c 100644 --- a/third_party/aom/av1/av1_cx_iface.c +++ b/third_party/aom/av1/av1_cx_iface.c @@ -3883,6 +3883,38 @@ static aom_codec_err_t ctrl_set_svc_params(aom_codec_alg_priv_t *ctx, ppi->number_temporal_layers = params->number_temporal_layers; cpi->svc.number_spatial_layers = params->number_spatial_layers; cpi->svc.number_temporal_layers = params->number_temporal_layers; + // Sequence parameters (operating_points_cnt_minus_1, operating_point_idc[]) + // need to be updated if the number of layers have changed. + // Force a keyframe here and update the two relevant sequence parameters. 
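The aom_ports/x86.h rewrite above replaces a chain of if (reg & BIT(n)) checks with named *_BITS masks and a FEATURE_SET(reg, feature) helper; AVX is the one multi-bit case, since both the OSXSAVE and 256-bit AVX bits must be set before xgetbv() is consulted. A stripped-down, standalone sketch of the pattern, using made-up register values rather than real cpuid output:

#include <cstdio>

#define BIT(n) (1u << (n))

// Same shape as the new macros: a feature "passes" only if every bit in its
// mask is set in the relevant cpuid register.
#define SSE4_1_BITS BIT(19)
#define AVX_BITS (BIT(27) | BIT(28))  // OSXSAVE and 256-bit AVX together
#define FEATURE_SET(reg, feature) \
  (((reg) & (feature##_BITS)) == (feature##_BITS))

int main() {
  // Hypothetical ECX contents from cpuid leaf 1 (not read from the CPU here).
  unsigned int reg_ecx = BIT(19) | BIT(28);  // SSE4.1 set; AVX bit without OSXSAVE
  std::printf("SSE4.1: %d\n", FEATURE_SET(reg_ecx, SSE4_1));  // 1
  std::printf("AVX:    %d\n", FEATURE_SET(reg_ecx, AVX));     // 0: needs both bits
  return 0;
}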
+ if (cpi->svc.prev_number_temporal_layers && + cpi->svc.prev_number_spatial_layers && + (cpi->svc.number_temporal_layers != + cpi->svc.prev_number_temporal_layers || + cpi->svc.number_spatial_layers != cpi->svc.prev_number_spatial_layers)) { + SequenceHeader *const seq_params = &ppi->seq_params; + seq_params->operating_points_cnt_minus_1 = + ppi->number_spatial_layers * ppi->number_temporal_layers - 1; + ctx->next_frame_flags |= AOM_EFLAG_FORCE_KF; + av1_set_svc_seq_params(ppi); + // Check for valid values for the spatial/temporal_layer_id here, since + // there has been a dynamic change in the number_spatial/temporal_layers, + // and if the ctrl_set_layer_id is not used after this call, the + // previous (last_encoded) values of spatial/temporal_layer_id will be used, + // which may be invalid. + cpi->svc.spatial_layer_id = AOMMAX( + 0, + AOMMIN(cpi->svc.spatial_layer_id, cpi->svc.number_spatial_layers - 1)); + cpi->svc.temporal_layer_id = + AOMMAX(0, AOMMIN(cpi->svc.temporal_layer_id, + cpi->svc.number_temporal_layers - 1)); + cpi->common.spatial_layer_id = + AOMMAX(0, AOMMIN(cpi->common.spatial_layer_id, + cpi->svc.number_spatial_layers - 1)); + cpi->common.temporal_layer_id = + AOMMAX(0, AOMMIN(cpi->common.temporal_layer_id, + cpi->svc.number_temporal_layers - 1)); + } + if (ppi->number_spatial_layers > 1 || ppi->number_temporal_layers > 1) { unsigned int sl, tl; ctx->ppi->use_svc = 1; diff --git a/third_party/aom/av1/common/ppc/cfl_ppc.c b/third_party/aom/av1/common/ppc/cfl_ppc.c index 36defe04ec0b..c2a25c9298f8 100644 --- a/third_party/aom/av1/common/ppc/cfl_ppc.c +++ b/third_party/aom/av1/common/ppc/cfl_ppc.c @@ -19,7 +19,6 @@ #define OFF_1 16 #define OFF_2 32 #define OFF_3 48 -#define CFL_BUF_LINE_BYTES 64 #define CFL_LINE_1 64 #define CFL_LINE_2 128 #define CFL_LINE_3 192 @@ -35,8 +34,6 @@ typedef vector unsigned long long uint64x2_t; // NOLINT(runtime/int) static inline void subtract_average_vsx(const uint16_t *src_ptr, int16_t *dst, int width, int height, int round_offset, int num_pel_log2) { - // int16_t *dst = dst_ptr; - const int16_t *dst_end = dst + height * CFL_BUF_LINE; const int16_t *sum_buf = (const int16_t *)src_ptr; const int16_t *end = sum_buf + height * CFL_BUF_LINE; const uint32x4_t div_shift = vec_splats((uint32_t)num_pel_log2); @@ -63,7 +60,8 @@ static inline void subtract_average_vsx(const uint16_t *src_ptr, int16_t *dst, sum_32x4_1 = vec_sum4s(vec_vsx_ld(OFF_3 + CFL_LINE_1, sum_buf), sum_32x4_1); } - } while ((sum_buf += (CFL_BUF_LINE * 2)) < end); + sum_buf += CFL_BUF_LINE * 2; + } while (sum_buf < end); int32x4_t sum_32x4 = vec_add(sum_32x4_0, sum_32x4_1); const int32x4_t perm_64 = vec_perm(sum_32x4, sum_32x4, mask_64); @@ -72,41 +70,44 @@ static inline void subtract_average_vsx(const uint16_t *src_ptr, int16_t *dst, sum_32x4 = vec_add(sum_32x4, perm_32); const int32x4_t avg = vec_sr(sum_32x4, div_shift); const int16x8_t vec_avg = vec_pack(avg, avg); + const int16_t *src = (const int16_t *)src_ptr; do { - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0, dst), vec_avg), OFF_0, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_1, dst), vec_avg), - OFF_0 + CFL_BUF_LINE_BYTES, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_2, dst), vec_avg), + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0, src), vec_avg), OFF_0, dst); + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_1, src), vec_avg), + OFF_0 + CFL_LINE_1, dst); + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_2, src), vec_avg), OFF_0 + CFL_LINE_2, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_3, dst), vec_avg), 
+ vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_3, src), vec_avg), OFF_0 + CFL_LINE_3, dst); if (width >= 16) { - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1, dst), vec_avg), OFF_1, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_1, dst), vec_avg), + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1, src), vec_avg), OFF_1, dst); + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_1, src), vec_avg), OFF_1 + CFL_LINE_1, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_2, dst), vec_avg), + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_2, src), vec_avg), OFF_1 + CFL_LINE_2, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_3, dst), vec_avg), + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_3, src), vec_avg), OFF_1 + CFL_LINE_3, dst); } if (width == 32) { - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2, dst), vec_avg), OFF_2, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_1, dst), vec_avg), + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2, src), vec_avg), OFF_2, dst); + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_1, src), vec_avg), OFF_2 + CFL_LINE_1, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_2, dst), vec_avg), + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_2, src), vec_avg), OFF_2 + CFL_LINE_2, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_3, dst), vec_avg), + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_3, src), vec_avg), OFF_2 + CFL_LINE_3, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3, dst), vec_avg), OFF_3, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_1, dst), vec_avg), + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3, src), vec_avg), OFF_3, dst); + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_1, src), vec_avg), OFF_3 + CFL_LINE_1, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_2, dst), vec_avg), + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_2, src), vec_avg), OFF_3 + CFL_LINE_2, dst); - vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_3, dst), vec_avg), + vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_3, src), vec_avg), OFF_3 + CFL_LINE_3, dst); } - } while ((dst += CFL_BUF_LINE * 4) < dst_end); + src += CFL_BUF_LINE * 4; + dst += CFL_BUF_LINE * 4; + } while (src < end); } // Declare wrappers for VSX sizes diff --git a/third_party/aom/av1/encoder/compound_type.c b/third_party/aom/av1/encoder/compound_type.c index 0b33ab5ba15a..95c5c235d6d5 100644 --- a/third_party/aom/av1/encoder/compound_type.c +++ b/third_party/aom/av1/encoder/compound_type.c @@ -235,12 +235,6 @@ static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x, model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N, &rate, &dist); - // int rate2; - // int64_t dist2; - // model_rd_with_curvfit(cpi, x, bsize, 0, sse, N, &rate2, &dist2); - // printf("sse %"PRId64": leagacy: %d %"PRId64", curvfit %d %"PRId64"\n", - // sse, rate, dist, rate2, dist2); dist = dist2; - // rate = rate2; rate += x->mode_costs.wedge_idx_cost[bsize][wedge_index]; rd = RDCOST(x->rdmult, rate, dist); diff --git a/third_party/aom/av1/encoder/encodeframe.c b/third_party/aom/av1/encoder/encodeframe.c index 64968ab89024..6213864c5cee 100644 --- a/third_party/aom/av1/encoder/encodeframe.c +++ b/third_party/aom/av1/encoder/encodeframe.c @@ -2060,6 +2060,8 @@ static inline void encode_frame_internal(AV1_COMP *cpi) { start_timing(cpi, av1_setup_motion_field_time); #endif av1_calculate_ref_frame_side(cm); + + features->allow_ref_frame_mvs &= !cpi->sf.hl_sf.disable_ref_frame_mvs; if (features->allow_ref_frame_mvs) av1_setup_motion_field(cm); #if CONFIG_COLLECT_COMPONENT_TIMING end_timing(cpi, 
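The cfl_ppc.c fix above boils down to: the averaging pass must keep reading the 16-bit sums from the source buffer while writing the average-subtracted values to the destination, instead of reloading from dst as the old code did. In scalar form the whole operation is roughly the following (a model of the CfL subtract-average step, not the VSX code):

#include <cstdint>

// src holds width x height sums laid out on a fixed stride (CFL_BUF_LINE in
// libaom); dst receives value - average. round_offset and num_pel_log2
// implement the usual rounded integer division by the pixel count.
static void SubtractAverageModel(const uint16_t *src, int16_t *dst, int width,
                                 int height, int stride, int round_offset,
                                 int num_pel_log2) {
  int64_t sum = round_offset;
  for (int r = 0; r < height; ++r)
    for (int c = 0; c < width; ++c) sum += src[r * stride + c];
  const int avg = static_cast<int>(sum >> num_pel_log2);
  for (int r = 0; r < height; ++r)
    for (int c = 0; c < width; ++c)
      dst[r * stride + c] = static_cast<int16_t>(src[r * stride + c] - avg);
}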
av1_setup_motion_field_time); diff --git a/third_party/aom/av1/encoder/encoder.c b/third_party/aom/av1/encoder/encoder.c index 574e48358b8e..ca1f32dc7732 100644 --- a/third_party/aom/av1/encoder/encoder.c +++ b/third_party/aom/av1/encoder/encoder.c @@ -487,6 +487,32 @@ static void set_bitstream_level_tier(AV1_PRIMARY *const ppi, int width, } } +void av1_set_svc_seq_params(AV1_PRIMARY *const ppi) { + SequenceHeader *const seq = &ppi->seq_params; + if (seq->operating_points_cnt_minus_1 == 0) { + seq->operating_point_idc[0] = 0; + seq->has_nonzero_operating_point_idc = false; + } else { + // Set operating_point_idc[] such that the i=0 point corresponds to the + // highest quality operating point (all layers), and subsequent + // operarting points (i > 0) are lower quality corresponding to + // skip decoding enhancement layers (temporal first). + int i = 0; + assert(seq->operating_points_cnt_minus_1 == + (int)(ppi->number_spatial_layers * ppi->number_temporal_layers - 1)); + for (unsigned int sl = 0; sl < ppi->number_spatial_layers; sl++) { + for (unsigned int tl = 0; tl < ppi->number_temporal_layers; tl++) { + seq->operating_point_idc[i] = + (~(~0u << (ppi->number_spatial_layers - sl)) << 8) | + ~(~0u << (ppi->number_temporal_layers - tl)); + assert(seq->operating_point_idc[i] != 0); + i++; + } + } + seq->has_nonzero_operating_point_idc = true; + } +} + static void init_seq_coding_tools(AV1_PRIMARY *const ppi, const AV1EncoderConfig *oxcf, int disable_frame_id_numbers) { @@ -551,29 +577,7 @@ static void init_seq_coding_tools(AV1_PRIMARY *const ppi, set_bitstream_level_tier(ppi, frm_dim_cfg->width, frm_dim_cfg->height, oxcf->input_cfg.init_framerate); - - if (seq->operating_points_cnt_minus_1 == 0) { - seq->operating_point_idc[0] = 0; - seq->has_nonzero_operating_point_idc = false; - } else { - // Set operating_point_idc[] such that the i=0 point corresponds to the - // highest quality operating point (all layers), and subsequent - // operarting points (i > 0) are lower quality corresponding to - // skip decoding enhancement layers (temporal first). - int i = 0; - assert(seq->operating_points_cnt_minus_1 == - (int)(ppi->number_spatial_layers * ppi->number_temporal_layers - 1)); - for (unsigned int sl = 0; sl < ppi->number_spatial_layers; sl++) { - for (unsigned int tl = 0; tl < ppi->number_temporal_layers; tl++) { - seq->operating_point_idc[i] = - (~(~0u << (ppi->number_spatial_layers - sl)) << 8) | - ~(~0u << (ppi->number_temporal_layers - tl)); - assert(seq->operating_point_idc[i] != 0); - i++; - } - } - seq->has_nonzero_operating_point_idc = true; - } + av1_set_svc_seq_params(ppi); } static void init_config_sequence(struct AV1_PRIMARY *ppi, @@ -770,6 +774,11 @@ void av1_change_config_seq(struct AV1_PRIMARY *ppi, // Init sequence level coding tools // This should not be called after the first key frame. + // Note that for SVC encoding the sequence parameters + // (operating_points_cnt_minus_1, operating_point_idc[], + // has_nonzero_operating_point_idc) should be updated whenever the + // number of layers is changed. This is done in the + // ctrl_set_svc_params(). 
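Since av1_set_svc_seq_params() is now also reachable on a mid-stream layer change, it is worth spelling out what the bit twiddling produces: operating_point_idc packs a temporal-layer mask in bits 0..7 and a spatial-layer mask in bits 8..15, with point 0 covering every layer. A small worked example for 2 spatial x 3 temporal layers:

#include <cstdio>

int main() {
  const unsigned spatial_layers = 2, temporal_layers = 3;
  // operating_points_cnt_minus_1 = 2 * 3 - 1 = 5, i.e. six operating points.
  int i = 0;
  for (unsigned sl = 0; sl < spatial_layers; ++sl) {
    for (unsigned tl = 0; tl < temporal_layers; ++tl) {
      const unsigned idc = (~(~0u << (spatial_layers - sl)) << 8) |
                           ~(~0u << (temporal_layers - tl));
      // Prints 0x307 0x303 0x301 0x107 0x103 0x101: point 0 decodes all
      // layers, later points progressively drop temporal then spatial layers.
      std::printf("operating_point_idc[%d] = 0x%03x\n", i++, idc);
    }
  }
  return 0;
}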
if (!ppi->seq_params_locked) { seq_params->operating_points_cnt_minus_1 = (ppi->number_spatial_layers > 1 || ppi->number_temporal_layers > 1) @@ -2276,7 +2285,12 @@ void av1_set_frame_size(AV1_COMP *cpi, int width, int height) { if (av1_is_scaled(sf)) aom_extend_frame_borders(&buf->buf, num_planes); } } - if (!frame_is_intra_only(cm) && !has_valid_ref_frame) { + // For 1 pass CBR mode: we can skip this check for spatial enhancement + // layer if the target_bandwidth is zero, since it will be dropped. + const bool dropped_frame = + has_no_stats_stage(cpi) && cpi->oxcf.rc_cfg.mode == AOM_CBR && + cpi->svc.spatial_layer_id > 0 && cpi->oxcf.rc_cfg.target_bandwidth == 0; + if (!frame_is_intra_only(cm) && !has_valid_ref_frame && !dropped_frame) { aom_internal_error( cm->error, AOM_CODEC_CORRUPT_FRAME, "Can't find at least one reference frame with valid size"); @@ -2990,10 +3004,6 @@ static int encode_with_recode_loop(AV1_COMP *cpi, size_t *size, uint8_t *dest, av1_set_variance_partition_thresholds(cpi, q, 0); - // printf("Frame %d/%d: q = %d, frame_type = %d superres_denom = %d\n", - // cm->current_frame.frame_number, cm->show_frame, q, - // cm->current_frame.frame_type, cm->superres_scale_denominator); - if (loop_count == 0) { av1_setup_frame(cpi); } else if (get_primary_ref_frame_buf(cm) == NULL) { @@ -4010,8 +4020,10 @@ static int encode_frame_to_data_rate(AV1_COMP *cpi, size_t *size, uint8_t *dest, cpi->frames_since_last_update = 1; } - if (cpi->svc.spatial_layer_id == cpi->svc.number_spatial_layers - 1) + if (cpi->svc.spatial_layer_id == cpi->svc.number_spatial_layers - 1) { cpi->svc.prev_number_spatial_layers = cpi->svc.number_spatial_layers; + } + cpi->svc.prev_number_temporal_layers = cpi->svc.number_temporal_layers; // Clear the one shot update flags for segmentation map and mode/ref loop // filter deltas. diff --git a/third_party/aom/av1/encoder/encoder.h b/third_party/aom/av1/encoder/encoder.h index 5c20b5549167..c8c50ed6d211 100644 --- a/third_party/aom/av1/encoder/encoder.h +++ b/third_party/aom/av1/encoder/encoder.h @@ -2675,7 +2675,11 @@ typedef struct AV1_PRIMARY { /*! * Sequence parameters have been transmitted already and locked * or not. Once locked av1_change_config cannot change the seq - * parameters. + * parameters. Note that for SVC encoding the sequence parameters + * (operating_points_cnt_minus_1, operating_point_idc[], + * has_nonzero_operating_point_idc) should be updated whenever the + * number of layers is changed. This is done in the + * ctrl_set_svc_params(). 
*/ int seq_params_locked; @@ -3905,6 +3909,8 @@ void av1_set_screen_content_options(struct AV1_COMP *cpi, void av1_update_frame_size(AV1_COMP *cpi); +void av1_set_svc_seq_params(AV1_PRIMARY *const ppi); + typedef struct { int pyr_level; int disp_order; diff --git a/third_party/aom/av1/encoder/encoder_utils.c b/third_party/aom/av1/encoder/encoder_utils.c index 0f807b670c32..d6ae4e8961f4 100644 --- a/third_party/aom/av1/encoder/encoder_utils.c +++ b/third_party/aom/av1/encoder/encoder_utils.c @@ -518,8 +518,6 @@ static void process_tpl_stats_frame(AV1_COMP *cpi) { const int gfu_boost = get_gfu_boost_from_r0_lap( min_boost_factor, MAX_GFUBOOST_FACTOR, cpi->rd.r0, cpi->ppi->p_rc.num_stats_required_for_gfu_boost); - // printf("old boost %d new boost %d\n", cpi->rc.gfu_boost, - // gfu_boost); cpi->ppi->p_rc.gfu_boost = combine_prior_with_tpl_boost( min_boost_factor, MAX_BOOST_COMBINE_FACTOR, cpi->ppi->p_rc.gfu_boost, gfu_boost, @@ -840,10 +838,12 @@ BLOCK_SIZE av1_select_sb_size(const AV1EncoderConfig *const oxcf, int width, } assert(oxcf->tool_cfg.superblock_size == AOM_SUPERBLOCK_SIZE_DYNAMIC); - if (number_spatial_layers > 1 || - oxcf->resize_cfg.resize_mode != RESIZE_NONE) { - // Use the configured size (top resolution) for spatial layers or - // on resize. + if (number_spatial_layers > 1) { + // For spatial layers better selection may be done given the resolutions + // used across the layers, but for now use 64x64 for spatial layers. + return BLOCK_64X64; + } else if (oxcf->resize_cfg.resize_mode != RESIZE_NONE) { + // Use the configured size (top resolution) for resize. return AOMMIN(oxcf->frm_dim_cfg.width, oxcf->frm_dim_cfg.height) > 720 ? BLOCK_128X128 : BLOCK_64X64; diff --git a/third_party/aom/av1/encoder/global_motion.c b/third_party/aom/av1/encoder/global_motion.c index 18ea46fa2362..bf38b32bef17 100644 --- a/third_party/aom/av1/encoder/global_motion.c +++ b/third_party/aom/av1/encoder/global_motion.c @@ -30,8 +30,9 @@ // Border over which to compute the global motion #define ERRORADV_BORDER 0 -int av1_is_enough_erroradvantage(double best_erroradvantage, int params_cost) { - return best_erroradvantage < erroradv_tr && +int av1_is_enough_erroradvantage(double best_erroradvantage, int params_cost, + double gm_erroradv_tr) { + return best_erroradvantage < gm_erroradv_tr && best_erroradvantage * params_cost < erroradv_prod_tr; } @@ -364,7 +365,8 @@ int64_t av1_refine_integerized_param( WarpedMotionParams *wm, TransformationType wmtype, int use_hbd, int bd, uint8_t *ref, int r_width, int r_height, int r_stride, uint8_t *dst, int d_width, int d_height, int d_stride, int n_refinements, - int64_t ref_frame_error, uint8_t *segment_map, int segment_map_stride) { + int64_t ref_frame_error, uint8_t *segment_map, int segment_map_stride, + double gm_erroradv_tr) { static const int max_trans_model_params[TRANS_TYPES] = { 0, 2, 4, 6 }; const int border = ERRORADV_BORDER; int i = 0, p; @@ -383,7 +385,8 @@ int64_t av1_refine_integerized_param( // Compute the maximum error value that will be accepted, so that // get_warp_error can terminate early if it proves the model will not // be accepted. 
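With the threshold now passed in as gm_erroradv_tr, the acceptance test for a global-motion model reads: the warped error must be a small enough fraction of the reference error, and that fraction times the parameter signalling cost must also stay under erroradv_prod_tr. A hedged numeric sketch follows; the two-level threshold array and the speed feature that selects between its entries appear in the hunks after this one.

#include <cstdio>

// Mirrors av1_is_enough_erroradvantage(); constants as in global_motion.h.
static const double erroradv_tr[2] = { 0.65, 0.2 };
static const double erroradv_prod_tr = 20000;

static int IsEnoughErrorAdvantage(double best_erroradvantage, int params_cost,
                                  double gm_erroradv_tr) {
  return best_erroradvantage < gm_erroradv_tr &&
         best_erroradvantage * params_cost < erroradv_prod_tr;
}

int main() {
  // A model that reduces the error to 35% of the reference error and costs
  // 40 bits to signal: accepted at the default level, rejected at level 1.
  std::printf("%d\n", IsEnoughErrorAdvantage(0.35, 40, erroradv_tr[0]));  // 1
  std::printf("%d\n", IsEnoughErrorAdvantage(0.35, 40, erroradv_tr[1]));  // 0
  return 0;
}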
- int64_t selection_threshold = (int64_t)lrint(ref_frame_error * erroradv_tr); + int64_t selection_threshold = + (int64_t)lrint(ref_frame_error * gm_erroradv_tr); return get_warp_error(wm, use_hbd, bd, ref, r_width, r_height, r_stride, dst + border * d_stride + border, d_stride, border, border, d_width - 2 * border, d_height - 2 * border, diff --git a/third_party/aom/av1/encoder/global_motion.h b/third_party/aom/av1/encoder/global_motion.h index 4d8c8481987c..97feff41f757 100644 --- a/third_party/aom/av1/encoder/global_motion.h +++ b/third_party/aom/av1/encoder/global_motion.h @@ -77,7 +77,7 @@ void av1_convert_model_to_params(const double *params, WarpedMotionParams *model); // Criteria for accepting a global motion model -static const double erroradv_tr = 0.65; +static const double erroradv_tr[2] = { 0.65, 0.2 }; static const double erroradv_prod_tr = 20000; // Early exit threshold for global motion refinement @@ -91,7 +91,8 @@ static const double erroradv_prod_tr = 20000; // threshold even if the model is initially above the threshold static const double erroradv_early_tr = 0.70; -int av1_is_enough_erroradvantage(double best_erroradvantage, int params_cost); +int av1_is_enough_erroradvantage(double best_erroradvantage, int params_cost, + double gm_erroradv_tr); void av1_compute_feature_segmentation_map(uint8_t *segment_map, int width, int height, int *inliers, @@ -109,7 +110,8 @@ int64_t av1_refine_integerized_param( WarpedMotionParams *wm, TransformationType wmtype, int use_hbd, int bd, uint8_t *ref, int r_width, int r_height, int r_stride, uint8_t *dst, int d_width, int d_height, int d_stride, int n_refinements, - int64_t ref_frame_error, uint8_t *segment_map, int segment_map_stride); + int64_t ref_frame_error, uint8_t *segment_map, int segment_map_stride, + double gm_erroradv_tr); #ifdef __cplusplus } // extern "C" diff --git a/third_party/aom/av1/encoder/global_motion_facade.c b/third_party/aom/av1/encoder/global_motion_facade.c index 73a4e3c17fc2..df625c0fa1ee 100644 --- a/third_party/aom/av1/encoder/global_motion_facade.c +++ b/third_party/aom/av1/encoder/global_motion_facade.c @@ -91,13 +91,15 @@ static inline void compute_global_motion_for_ref_frame( GlobalMotionMethod global_motion_method = default_global_motion_method; int downsample_level = cpi->sf.gm_sf.downsample_level; int num_refinements = cpi->sf.gm_sf.num_refinement_steps; + int gm_erroradv_tr_level = cpi->sf.gm_sf.gm_erroradv_tr_level; bool mem_alloc_failed = false; + assert(gm_erroradv_tr_level < 2); // Select the best model based on fractional error reduction. 
// By initializing this to erroradv_tr, the same logic which is used to // select the best model will automatically filter out any model which // doesn't meet the required quality threshold - double best_erroradv = erroradv_tr; + double best_erroradv = erroradv_tr[gm_erroradv_tr_level]; for (TransformationType model = FIRST_GLOBAL_TRANS_TYPE; model <= LAST_GLOBAL_TRANS_TYPE; ++model) { if (!aom_compute_global_motion(model, cpi->source, ref_buf[frame], @@ -148,7 +150,8 @@ static inline void compute_global_motion_for_ref_frame( ref_buf[frame]->y_buffer, ref_buf[frame]->y_crop_width, ref_buf[frame]->y_crop_height, ref_buf[frame]->y_stride, cpi->source->y_buffer, src_width, src_height, src_stride, - num_refinements, ref_frame_error, segment_map, segment_map_w); + num_refinements, ref_frame_error, segment_map, segment_map_w, + erroradv_tr[gm_erroradv_tr_level]); // av1_refine_integerized_param() can return a simpler model type than // its input, so re-check model type here @@ -160,7 +163,8 @@ static inline void compute_global_motion_for_ref_frame( if (!av1_is_enough_erroradvantage( erroradvantage, gm_get_params_cost(&tmp_wm_params, ref_params, - cm->features.allow_high_precision_mv))) { + cm->features.allow_high_precision_mv), + erroradv_tr[gm_erroradv_tr_level])) { continue; } diff --git a/third_party/aom/av1/encoder/gop_structure.c b/third_party/aom/av1/encoder/gop_structure.c index c2395025a034..308290fee520 100644 --- a/third_party/aom/av1/encoder/gop_structure.c +++ b/third_party/aom/av1/encoder/gop_structure.c @@ -642,7 +642,7 @@ static int construct_multi_layer_gf_structure( : gf_group->is_sframe_due ? S_FRAME : INTER_FRAME; gf_group->is_sframe_due = - sframe_enabled && !(gf_group->frame_type[frame_index] == S_FRAME); + sframe_enabled && gf_group->frame_type[frame_index] != S_FRAME; gf_group->refbuf_state[frame_index] = REFBUF_UPDATE; gf_group->max_layer_depth = 1; gf_group->arf_index = frame_index; diff --git a/third_party/aom/av1/encoder/mcomp.c b/third_party/aom/av1/encoder/mcomp.c index 9819dd58d585..fdbef7cae042 100644 --- a/third_party/aom/av1/encoder/mcomp.c +++ b/third_party/aom/av1/encoder/mcomp.c @@ -102,21 +102,18 @@ void av1_make_default_fullpel_ms_params( ms_params->mv_limits = x->mv_limits; av1_set_mv_search_range(&ms_params->mv_limits, ref_mv); - if (cpi->oxcf.algo_cfg.sharpness) { + if (cpi->oxcf.algo_cfg.sharpness == 3) { int top_margin = x->e_mbd.mi_row * MI_SIZE + 8; int left_margin = x->e_mbd.mi_col * MI_SIZE + 8; - int bottom_margin = cpi->common.cur_frame->height - - mi_size_high[bsize] * MI_SIZE - top_margin + 16; - int right_margin = cpi->common.cur_frame->width - - mi_size_wide[bsize] * MI_SIZE - left_margin + 16; - if (ms_params->mv_limits.row_min < -top_margin) - ms_params->mv_limits.row_min = -top_margin; - if (ms_params->mv_limits.row_max > bottom_margin) - ms_params->mv_limits.row_max = bottom_margin; - if (ms_params->mv_limits.col_min < -left_margin) - ms_params->mv_limits.col_min = -left_margin; - if (ms_params->mv_limits.col_max > right_margin) - ms_params->mv_limits.col_max = right_margin; + int bottom_margin = + cpi->common.height - mi_size_high[bsize] * MI_SIZE - top_margin + 16; + int right_margin = + cpi->common.width - mi_size_wide[bsize] * MI_SIZE - left_margin + 16; + FullMvLimits *mv_limits = &ms_params->mv_limits; + mv_limits->row_min = AOMMAX(mv_limits->row_min, -top_margin); + mv_limits->row_max = AOMMIN(mv_limits->row_max, bottom_margin); + mv_limits->col_min = AOMMAX(mv_limits->col_min, -left_margin); + mv_limits->col_max = 
AOMMIN(mv_limits->col_max, right_margin); } // Mvcost params @@ -193,6 +190,22 @@ void av1_make_default_subpel_ms_params(SUBPEL_MOTION_SEARCH_PARAMS *ms_params, av1_set_subpel_mv_search_range(&ms_params->mv_limits, &x->mv_limits, ref_mv); + if (cpi->oxcf.algo_cfg.sharpness == 3) { + int top_margin = GET_MV_SUBPEL(x->e_mbd.mi_row * MI_SIZE + 8); + int left_margin = GET_MV_SUBPEL(x->e_mbd.mi_col * MI_SIZE + 8); + int bottom_margin = + GET_MV_SUBPEL(cpi->common.height - mi_size_high[bsize] * MI_SIZE - + x->e_mbd.mi_row * MI_SIZE + 8); + int right_margin = + GET_MV_SUBPEL(cpi->common.width - mi_size_wide[bsize] * MI_SIZE - + x->e_mbd.mi_col * MI_SIZE + 8); + SubpelMvLimits *mv_limits = &ms_params->mv_limits; + mv_limits->row_min = AOMMAX(mv_limits->row_min, -top_margin); + mv_limits->row_max = AOMMIN(mv_limits->row_max, bottom_margin); + mv_limits->col_min = AOMMAX(mv_limits->col_min, -left_margin); + mv_limits->col_max = AOMMIN(mv_limits->col_max, right_margin); + } + // Mvcost params init_mv_cost_params(&ms_params->mv_cost_params, x->mv_costs, ref_mv, x->errorperbit, x->sadperbit); @@ -230,10 +243,10 @@ void av1_set_mv_search_range(FullMvLimits *mv_limits, const MV *mv) { // Get intersection of UMV window and valid MV window to reduce # of checks // in diamond search. - if (mv_limits->col_min < col_min) mv_limits->col_min = col_min; - if (mv_limits->col_max > col_max) mv_limits->col_max = col_max; - if (mv_limits->row_min < row_min) mv_limits->row_min = row_min; - if (mv_limits->row_max > row_max) mv_limits->row_max = row_max; + mv_limits->col_min = AOMMAX(mv_limits->col_min, col_min); + mv_limits->col_max = AOMMIN(mv_limits->col_max, col_max); + mv_limits->row_min = AOMMAX(mv_limits->row_min, row_min); + mv_limits->row_max = AOMMIN(mv_limits->row_max, row_max); mv_limits->col_max = AOMMAX(mv_limits->col_min, mv_limits->col_max); mv_limits->row_max = AOMMAX(mv_limits->row_min, mv_limits->row_max); diff --git a/third_party/aom/av1/encoder/pass2_strategy.c b/third_party/aom/av1/encoder/pass2_strategy.c index bb3a7232803f..1ca1d503ed64 100644 --- a/third_party/aom/av1/encoder/pass2_strategy.c +++ b/third_party/aom/av1/encoder/pass2_strategy.c @@ -3408,9 +3408,6 @@ static void find_next_key_frame(AV1_COMP *cpi, FIRSTPASS_STATS *this_frame) { kf_bits = calculate_boost_bits( AOMMIN(rc->frames_to_key, frames_to_key_clipped) - 1, p_rc->kf_boost, AOMMIN(twopass->kf_group_bits, kf_group_bits_clipped)); - // printf("kf boost = %d kf_bits = %d kf_zeromotion_pct = %d\n", - // p_rc->kf_boost, - // kf_bits, twopass->kf_zeromotion_pct); kf_bits = adjust_boost_bits_for_target_level(cpi, rc, kf_bits, twopass->kf_group_bits, 0); diff --git a/third_party/aom/av1/encoder/pickrst.c b/third_party/aom/av1/encoder/pickrst.c index 3b6d8e41a51d..4e7030cc7673 100644 --- a/third_party/aom/av1/encoder/pickrst.c +++ b/third_party/aom/av1/encoder/pickrst.c @@ -1507,7 +1507,6 @@ static int64_t finer_search_wiener(const RestSearchCtxt *rsc, WienerInfo *plane_wiener = &rui->wiener_info; - // printf("err pre = %"PRId64"\n", err); const int start_step = 4; for (int s = start_step; s >= 1; s >>= 1) { for (int p = plane_off; p < WIENER_HALFWIN; ++p) { @@ -1593,7 +1592,6 @@ static int64_t finer_search_wiener(const RestSearchCtxt *rsc, } while (1); } } - // printf("err post = %"PRId64"\n", err); return err; } @@ -2052,6 +2050,8 @@ void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi) { min_lr_unit_size = AOMMAX(min_lr_unit_size, block_size_wide[cm->seq_params->sb_size]); + max_lr_unit_size = 
AOMMAX(min_lr_unit_size, max_lr_unit_size); + for (int plane = 0; plane < num_planes; ++plane) { cpi->pick_lr_ctxt.rusi[plane] = allocate_search_structs( cm, &cm->rst_info[plane], plane > 0, min_lr_unit_size); diff --git a/third_party/aom/av1/encoder/ratectrl.c b/third_party/aom/av1/encoder/ratectrl.c index 750f5ae7bc93..0f0316075c25 100644 --- a/third_party/aom/av1/encoder/ratectrl.c +++ b/third_party/aom/av1/encoder/ratectrl.c @@ -1448,8 +1448,6 @@ static int get_active_cq_level(const RATE_CONTROL *rc, static const double cq_adjust_threshold = 0.1; int active_cq_level = rc_cfg->cq_level; if (rc_cfg->mode == AOM_CQ || rc_cfg->mode == AOM_Q) { - // printf("Superres %d %d %d = %d\n", superres_denom, intra_only, - // rc->frames_to_key, !(intra_only && rc->frames_to_key <= 1)); if ((superres_mode == AOM_SUPERRES_QTHRESH || superres_mode == AOM_SUPERRES_AUTO) && superres_denom != SCALE_NUMERATOR) { @@ -2497,6 +2495,7 @@ void av1_rc_postencode_update_drop_frame(AV1_COMP *cpi) { if (cpi->svc.spatial_layer_id == cpi->svc.number_spatial_layers - 1) { cpi->svc.prev_number_spatial_layers = cpi->svc.number_spatial_layers; } + cpi->svc.prev_number_temporal_layers = cpi->svc.number_temporal_layers; } int av1_find_qindex(double desired_q, aom_bit_depth_t bit_depth, diff --git a/third_party/aom/av1/encoder/rd.c b/third_party/aom/av1/encoder/rd.c index ce9571a8bb1b..43e6a78fb616 100644 --- a/third_party/aom/av1/encoder/rd.c +++ b/third_party/aom/av1/encoder/rd.c @@ -649,10 +649,6 @@ void av1_fill_coeff_costs(CoeffCosts *coeff_costs, FRAME_CONTEXT *fc, av1_cost_tokens_from_cdf( br_rate, fc->coeff_br_cdf[AOMMIN(tx_size, TX_32X32)][plane][ctx], NULL); - // printf("br_rate: "); - // for(j = 0; j < BR_CDF_SIZE; j++) - // printf("%4d ", br_rate[j]); - // printf("\n"); for (i = 0; i < COEFF_BASE_RANGE; i += BR_CDF_SIZE - 1) { for (j = 0; j < BR_CDF_SIZE - 1; j++) { pcost->lps_cost[ctx][i + j] = prev_cost + br_rate[j]; @@ -660,10 +656,6 @@ void av1_fill_coeff_costs(CoeffCosts *coeff_costs, FRAME_CONTEXT *fc, prev_cost += br_rate[j]; } pcost->lps_cost[ctx][i] = prev_cost; - // printf("lps_cost: %d %d %2d : ", tx_size, plane, ctx); - // for (i = 0; i <= COEFF_BASE_RANGE; i++) - // printf("%5d ", pcost->lps_cost[ctx][i]); - // printf("\n"); } for (int ctx = 0; ctx < LEVEL_CONTEXTS; ++ctx) { pcost->lps_cost[ctx][0 + COEFF_BASE_RANGE + 1] = diff --git a/third_party/aom/av1/encoder/rdopt.c b/third_party/aom/av1/encoder/rdopt.c index 2744d29358d6..fb8ee1f05718 100644 --- a/third_party/aom/av1/encoder/rdopt.c +++ b/third_party/aom/av1/encoder/rdopt.c @@ -607,6 +607,77 @@ void av1_get_horver_correlation_full_c(const int16_t *diff, int stride, } } +static void get_variance_stats(const AV1_COMP *cpi, const MACROBLOCK *x, + int num_planes, int64_t *src_var, + int64_t *rec_var) { + const MACROBLOCKD *xd = &x->e_mbd; + const MB_MODE_INFO *mbmi = xd->mi[0]; + + DECLARE_ALIGNED(16, uint8_t, dclevel[MAX_SB_SQUARE]); + memset(dclevel, 128, sizeof(dclevel)); + int dclevel_stride = block_size_wide[mbmi->bsize]; + + *src_var = 0; + *rec_var = 0; + + for (int plane = 0; plane < num_planes; ++plane) { + if (plane && !xd->is_chroma_ref) break; + + const struct macroblock_plane *const p = &x->plane[plane]; + const struct macroblockd_plane *const pd = &xd->plane[plane]; + const BLOCK_SIZE bs = + get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y); + unsigned int sse; + + int64_t var = cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, dclevel, + dclevel_stride, &sse); + + *src_var += var; + + var = 
cpi->ppi->fn_ptr[bs].vf(pd->dst.buf, pd->dst.stride, dclevel, + dclevel_stride, &sse); + + *rec_var += var; + } + + *src_var <<= 4; + *rec_var <<= 4; +} + +static void adjust_rdcost(const AV1_COMP *cpi, const MACROBLOCK *x, + RD_STATS *rd_cost) { + if (cpi->oxcf.algo_cfg.sharpness != 3) return; + + if (frame_is_kf_gf_arf(cpi)) return; + + int64_t src_var, rec_var; + get_variance_stats(cpi, x, 1, &src_var, &rec_var); + + if (src_var <= rec_var) return; + + int64_t var_offset = src_var - rec_var; + + rd_cost->dist += var_offset; + + rd_cost->rdcost = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist); +} + +static void adjust_cost(const AV1_COMP *cpi, const MACROBLOCK *x, + int64_t *rd_cost) { + if (cpi->oxcf.algo_cfg.sharpness != 3) return; + + if (frame_is_kf_gf_arf(cpi)) return; + + int64_t src_var, rec_var; + get_variance_stats(cpi, x, 1, &src_var, &rec_var); + + if (src_var <= rec_var) return; + + int64_t var_offset = src_var - rec_var; + + *rd_cost += RDCOST(x->rdmult, 0, var_offset); +} + static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x, int64_t *sse_y) { const AV1_COMMON *cm = &cpi->common; @@ -5501,6 +5572,11 @@ static inline void search_intra_modes_in_interframe( intra_search_state, cpi, x, bsize, intra_ref_frame_cost, ctx, &intra_rd_stats_y, search_state->best_rd, &mode_cost_y, &intra_rd_y, &best_model_rd, top_intra_model_rd); + + if (intra_rd_y < INT64_MAX) { + adjust_cost(cpi, x, &intra_rd_y); + } + if (is_luma_result_valid && intra_rd_y < yrd_threshold) { is_best_y_mode_intra = 1; if (intra_rd_y < best_rd_y) { @@ -5586,6 +5662,8 @@ static inline void search_intra_modes_in_interframe( intra_rd_stats.rdcost = this_rd; + adjust_rdcost(cpi, x, &intra_rd_stats); + // Collect mode stats for multiwinner mode processing const int txfm_search_done = 1; store_winner_mode_stats( @@ -6019,6 +6097,12 @@ void av1_rd_pick_inter_mode(struct AV1_COMP *cpi, struct TileDataEnc *tile_data, args.best_pred_sse = search_state.best_pred_sse; args.skip_ifs = skip_interp_filter_search(cpi, is_single_pred); + if (!frame_is_kf_gf_arf(cpi) && cpi->oxcf.algo_cfg.sharpness == 3) { + if (ref_frame != ALTREF_FRAME && ref_frame != GOLDEN_FRAME && + ref_frame != INTRA_FRAME) + continue; + } + int64_t skip_rd[2] = { search_state.best_skip_rd[0], search_state.best_skip_rd[1] }; int64_t this_yrd = INT64_MAX; @@ -6057,6 +6141,9 @@ void av1_rd_pick_inter_mode(struct AV1_COMP *cpi, struct TileDataEnc *tile_data, ref_frame_rd[ref_frame] = this_rd; } + adjust_cost(cpi, x, &this_rd); + adjust_rdcost(cpi, x, &rd_stats); + // Did this mode help, i.e., is it the new best mode if (this_rd < search_state.best_rd) { assert(IMPLIES(comp_pred, diff --git a/third_party/aom/av1/encoder/speed_features.c b/third_party/aom/av1/encoder/speed_features.c index e7ba5f46aa92..cddedffc699f 100644 --- a/third_party/aom/av1/encoder/speed_features.c +++ b/third_party/aom/av1/encoder/speed_features.c @@ -614,6 +614,10 @@ static void set_good_speed_features_lc_dec_framesize_dependent( (update_type == LF_UPDATE || update_type == OVERLAY_UPDATE || update_type == INTNL_OVERLAY_UPDATE); if (leaf_and_overlay_frames) sf->gm_sf.gm_search_type = GM_DISABLE_SEARCH; + + sf->hl_sf.disable_ref_frame_mvs = 1; + } else if (is_608p_or_larger) { + sf->gm_sf.gm_erroradv_tr_level = 1; } } @@ -2028,6 +2032,7 @@ static inline void init_hl_sf(HIGH_LEVEL_SPEED_FEATURES *hl_sf) { hl_sf->accurate_bit_estimate = 0; hl_sf->weight_calc_level_in_tf = 0; hl_sf->allow_sub_blk_me_in_tf = 0; + hl_sf->disable_ref_frame_mvs = 0; } static inline void 
init_fp_sf(FIRST_PASS_SPEED_FEATURES *fp_sf) { @@ -2059,6 +2064,7 @@ static inline void init_gm_sf(GLOBAL_MOTION_SPEED_FEATURES *gm_sf) { gm_sf->disable_gm_search_based_on_stats = 0; gm_sf->downsample_level = 0; gm_sf->num_refinement_steps = GM_MAX_REFINEMENT_STEPS; + gm_sf->gm_erroradv_tr_level = 0; } static inline void init_part_sf(PARTITION_SPEED_FEATURES *part_sf) { @@ -2613,6 +2619,29 @@ void av1_set_speed_features_framesize_independent(AV1_COMP *cpi, int speed) { sf->rt_sf.gf_refresh_based_on_qp = 0; } +// Override some speed features for low complexity decode based on qindex. +static void set_speed_features_lc_dec_qindex_dependent( + const AV1_COMP *const cpi, SPEED_FEATURES *const sf, int speed) { + if (speed < 1 || speed > 3) return; + + const AV1_COMMON *const cm = &cpi->common; + const int short_dimension = AOMMIN(cm->width, cm->height); + const int is_720p_or_larger = AOMMIN(cm->width, cm->height) >= 720; + const FRAME_UPDATE_TYPE update_type = + get_frame_update_type(&cpi->ppi->gf_group, cpi->gf_frame_index); + const int leaf_and_overlay_frames = + (update_type == LF_UPDATE || update_type == OVERLAY_UPDATE || + update_type == INTNL_OVERLAY_UPDATE); + + if (short_dimension > 480 && short_dimension < 720) { + sf->lpf_sf.min_lr_unit_size = RESTORATION_UNITSIZE_MAX >> 1; + sf->lpf_sf.max_lr_unit_size = RESTORATION_UNITSIZE_MAX >> 1; + } else if (is_720p_or_larger && speed <= 2 && leaf_and_overlay_frames) { + sf->lpf_sf.min_lr_unit_size = RESTORATION_UNITSIZE_MAX >> 1; + sf->lpf_sf.max_lr_unit_size = RESTORATION_UNITSIZE_MAX >> 1; + } +} + // Override some speed features based on qindex void av1_set_speed_features_qindex_dependent(AV1_COMP *cpi, int speed) { AV1_COMMON *const cm = &cpi->common; @@ -2815,4 +2844,7 @@ void av1_set_speed_features_qindex_dependent(AV1_COMP *cpi, int speed) { set_subpel_search_method(&cpi->mv_search_params, cpi->oxcf.unit_test_cfg.motion_vector_unit_test, sf->mv_sf.subpel_search_method); + + if (cpi->oxcf.enable_low_complexity_decode) + set_speed_features_lc_dec_qindex_dependent(cpi, sf, speed); } diff --git a/third_party/aom/av1/encoder/speed_features.h b/third_party/aom/av1/encoder/speed_features.h index 4929076130ac..e7ac791abe2d 100644 --- a/third_party/aom/av1/encoder/speed_features.h +++ b/third_party/aom/av1/encoder/speed_features.h @@ -482,6 +482,11 @@ typedef struct HIGH_LEVEL_SPEED_FEATURES { * 1: Conditionally allow motion estimation based on 4x4 sub-blocks variance. */ int allow_sub_blk_me_in_tf; + + /*! + * Enable/disable temporal mv prediction. + */ + int disable_ref_frame_mvs; } HIGH_LEVEL_SPEED_FEATURES; /*! 
@@ -592,6 +597,10 @@ typedef struct GLOBAL_MOTION_SPEED_FEATURES { // Number of refinement steps to apply after initial model generation int num_refinement_steps; + + // Error advantage threshold level used to determine whether global motion + // compensation should be enabled + int gm_erroradv_tr_level; } GLOBAL_MOTION_SPEED_FEATURES; typedef struct PARTITION_SPEED_FEATURES { diff --git a/third_party/aom/av1/encoder/svc_layercontext.h b/third_party/aom/av1/encoder/svc_layercontext.h index cbe4304a125c..fdf17480ea37 100644 --- a/third_party/aom/av1/encoder/svc_layercontext.h +++ b/third_party/aom/av1/encoder/svc_layercontext.h @@ -93,6 +93,7 @@ typedef struct SVC { int number_spatial_layers; int number_temporal_layers; int prev_number_spatial_layers; + int prev_number_temporal_layers; int use_flexible_mode; int ksvc_fixed_mode; /*!\endcond */ diff --git a/third_party/aom/av1/encoder/temporal_filter.c b/third_party/aom/av1/encoder/temporal_filter.c index 599b0495f125..7a6c6efee307 100644 --- a/third_party/aom/av1/encoder/temporal_filter.c +++ b/third_party/aom/av1/encoder/temporal_filter.c @@ -184,6 +184,10 @@ static void tf_motion_search(AV1_COMP *cpi, MACROBLOCK *mb, mbd->plane[0].pre[0].buf = ref_frame->y_buffer + y_offset; mbd->plane[0].pre[0].stride = y_stride; mbd->plane[0].pre[0].width = ref_width; + mbd->mi_row = + mb_row * (block_size_high[block_size] / block_size_high[BLOCK_4X4]); + mbd->mi_col = + mb_col * (block_size_wide[block_size] / block_size_wide[BLOCK_4X4]); *is_dc_diff_large = 0; const SEARCH_METHODS search_method = NSTEP; diff --git a/third_party/aom/build/cmake/aom_config_defaults.cmake b/third_party/aom/build/cmake/aom_config_defaults.cmake index b78c9ec98fba..33c19ba488a4 100644 --- a/third_party/aom/build/cmake/aom_config_defaults.cmake +++ b/third_party/aom/build/cmake/aom_config_defaults.cmake @@ -182,6 +182,8 @@ set_aom_config_var(CONFIG_CWG_E050 0 set_aom_config_var(CONFIG_LIBVMAF_PSNR_PEAK 1 "Use libvmaf PSNR peak for 10- and 12-bit") +set_aom_config_var(CONFIG_HIGHWAY 0 "Use Highway for SIMD.") + # # Variables in this section control optional features of the build system. # diff --git a/third_party/aom/test/acm_random.h b/third_party/aom/test/acm_random.h index 6fb6d566ae32..67f7aa862eaf 100644 --- a/third_party/aom/test/acm_random.h +++ b/third_party/aom/test/acm_random.h @@ -40,7 +40,7 @@ class ACMRandom { int16_t Rand16Signed() { return static_cast(Rand16()); } - int16_t Rand15() { + uint16_t Rand15() { const uint32_t value = random_.Generate(testing::internal::Random::kMaxRange); // There's a bit more entropy in the upper bits of this implementation. 
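Note: the build change above only introduces the CONFIG_HIGHWAY option (default 0); the Highway sources themselves are vendored further below, and no SIMD call site is switched over in this patch. A minimal sketch of how such a config flag is conventionally consumed (editorial, hedged: the gated kernel is hypothetical and not part of this diff):

#include "config/aom_config.h"

#if CONFIG_HIGHWAY
#include "third_party/highway/hwy/base.h"
// A Highway-backed SIMD path would be compiled in here once one lands.
#else
// Default in this patch: CONFIG_HIGHWAY is 0, so the existing C/intrinsics
// code paths remain in use.
#endif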
diff --git a/third_party/aom/test/av1_fwd_txfm2d_test.cc b/third_party/aom/test/av1_fwd_txfm2d_test.cc index 95b2151727e8..8831ca8aedaf 100644 --- a/third_party/aom/test/av1_fwd_txfm2d_test.cc +++ b/third_party/aom/test/av1_fwd_txfm2d_test.cc @@ -246,7 +246,6 @@ void AV1FwdTxfm2dMatchTest(TX_SIZE tx_size, lowbd_fwd_txfm_func target_func) { memset(¶m, 0, sizeof(param)); const int rows = tx_size_high[tx_size]; const int cols = tx_size_wide[tx_size]; - // printf("%d x %d\n", cols, rows); for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) { if (libaom_test::IsTxSizeTypeValid( tx_size, static_cast(tx_type)) == false) { diff --git a/third_party/aom/test/cfl_test.cc b/third_party/aom/test/cfl_test.cc index e093c4e3541b..3f93305005b0 100644 --- a/third_party/aom/test/cfl_test.cc +++ b/third_party/aom/test/cfl_test.cc @@ -175,7 +175,7 @@ class CFLTestWithAlignedData : public CFLTest { typedef cfl_subtract_average_fn (*sub_avg_fn)(TX_SIZE tx_size); typedef std::tuple sub_avg_param; class CFLSubAvgTest : public ::testing::TestWithParam, - public CFLTestWithData { + public CFLTestWithData { public: void SetUp() override { CFLTest::init(std::get<0>(this->GetParam())); @@ -191,27 +191,31 @@ class CFLSubAvgTest : public ::testing::TestWithParam, GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CFLSubAvgTest); TEST_P(CFLSubAvgTest, SubAvgTest) { + int16_t dst[CFL_BUF_SQUARE]; + int16_t dst_ref[CFL_BUF_SQUARE]; for (int it = 0; it < NUM_ITERATIONS; it++) { randData(&ACMRandom::Rand15); - sub_avg((uint16_t *)data, data); - sub_avg_ref((uint16_t *)data_ref, data_ref); - assert_eq(data, data_ref, width, height); + sub_avg(data, dst); + sub_avg_ref(data_ref, dst_ref); + assert_eq(dst, dst_ref, width, height); } } TEST_P(CFLSubAvgTest, DISABLED_SubAvgSpeedTest) { + int16_t dst[CFL_BUF_SQUARE]; + int16_t dst_ref[CFL_BUF_SQUARE]; aom_usec_timer ref_timer; aom_usec_timer timer; randData(&ACMRandom::Rand15); aom_usec_timer_start(&ref_timer); for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { - sub_avg_ref((uint16_t *)data_ref, data_ref); + sub_avg_ref(data_ref, dst_ref); } aom_usec_timer_mark(&ref_timer); int ref_elapsed_time = (int)aom_usec_timer_elapsed(&ref_timer); aom_usec_timer_start(&timer); for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { - sub_avg((uint16_t *)data, data); + sub_avg(data, dst); } aom_usec_timer_mark(&timer); int elapsed_time = (int)aom_usec_timer_elapsed(&timer); @@ -261,13 +265,13 @@ class CFLSubsampleTest : public ::testing::TestWithParam, CFLTestWithData::randData(random); aom_usec_timer_start(&ref_timer); for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { - fun_ref(this->data_ref, CFL_BUF_LINE, sub_luma_pels); + fun_ref(this->data_ref, CFL_BUF_LINE, sub_luma_pels_ref); } aom_usec_timer_mark(&ref_timer); int ref_elapsed_time = (int)aom_usec_timer_elapsed(&ref_timer); aom_usec_timer_start(&timer); for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { - fun(this->data, CFL_BUF_LINE, sub_luma_pels_ref); + fun(this->data, CFL_BUF_LINE, sub_luma_pels); } aom_usec_timer_mark(&timer); int elapsed_time = (int)aom_usec_timer_elapsed(&timer); diff --git a/third_party/aom/test/ratectrl_rtc_test.cc b/third_party/aom/test/ratectrl_rtc_test.cc index 18d46ae947e2..8d7e64258e0f 100644 --- a/third_party/aom/test/ratectrl_rtc_test.cc +++ b/third_party/aom/test/ratectrl_rtc_test.cc @@ -98,15 +98,21 @@ class RcInterfaceTest : public ::libaom_test::EncoderTest, // Go down to 2 temporal layers. 
SetConfigSvc(3, 2); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); + frame_flags_ = AOM_EFLAG_FORCE_KF; + frame_params_.frame_type = aom::kKeyFrame; ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_)); } else if (superframe_cnt_ == 200 && layer_id_.spatial_layer_id == 0) { // Go down to 1 temporal layer. SetConfigSvc(3, 1); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); + frame_flags_ = AOM_EFLAG_FORCE_KF; + frame_params_.frame_type = aom::kKeyFrame; ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_)); } else if (superframe_cnt_ == 300 && layer_id_.spatial_layer_id == 0) { // Go back up to 3 temporal layers. SetConfigSvc(3, 3); + frame_flags_ = AOM_EFLAG_FORCE_KF; + frame_params_.frame_type = aom::kKeyFrame; encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_)); } @@ -117,11 +123,15 @@ class RcInterfaceTest : public ::libaom_test::EncoderTest, // Change to 2 spatial layers (240p, 480p). SetConfigSvc(2, 3); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); + frame_flags_ = AOM_EFLAG_FORCE_KF; + frame_params_.frame_type = aom::kKeyFrame; ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_)); } else if (superframe_cnt_ == 200 && layer_id_.spatial_layer_id == 0) { // Change to 1 spatial layer (480p). SetConfigSvc(1, 3); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); + frame_flags_ = AOM_EFLAG_FORCE_KF; + frame_params_.frame_type = aom::kKeyFrame; ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_)); } else if (superframe_cnt_ == 300 && layer_id_.spatial_layer_id == 0) { // Go back to 3 spatial layers (120p, 240p, 480p). @@ -148,6 +158,10 @@ class RcInterfaceTest : public ::libaom_test::EncoderTest, if (encoder_exit_) { return; } + int num_operating_points; + encoder->Control(AV1E_GET_NUM_OPERATING_POINTS, &num_operating_points); + ASSERT_EQ(num_operating_points, + rc_cfg_.ss_number_layers * rc_cfg_.ts_number_layers); layer_frame_cnt_++; frame_cnt_++; if (layer_id_.spatial_layer_id == rc_cfg_.ss_number_layers - 1) diff --git a/third_party/aom/test/svc_datarate_test.cc b/third_party/aom/test/svc_datarate_test.cc index ab3ffc6d26f5..8732db9c70bf 100644 --- a/third_party/aom/test/svc_datarate_test.cc +++ b/third_party/aom/test/svc_datarate_test.cc @@ -79,9 +79,12 @@ void ScaleForFrameNumber(unsigned int frame, unsigned int initial_w, class ResizingVideoSource : public ::libaom_test::DummyVideoSource { public: - explicit ResizingVideoSource(int external_resize_pattern) { + explicit ResizingVideoSource(int external_resize_pattern, int width, + int height) { external_resize_pattern_ = external_resize_pattern; - SetSize(1280, 720); + top_width_ = width; + top_height_ = height; + SetSize(top_width_, top_height_); limit_ = 300; } ~ResizingVideoSource() override = default; @@ -92,7 +95,7 @@ class ResizingVideoSource : public ::libaom_test::DummyVideoSource { unsigned int width = 0; unsigned int height = 0; libaom_test::ACMRandom rnd(libaom_test::ACMRandom::DeterministicSeed()); - ScaleForFrameNumber(frame_, 1280, 720, &width, &height, + ScaleForFrameNumber(frame_, top_width_, top_height_, &width, &height, external_resize_pattern_); SetSize(width, height); FillFrame(); @@ -104,6 +107,9 @@ class ResizingVideoSource : public ::libaom_test::DummyVideoSource { private: int external_resize_pattern_; + // top_width_/height_ is the configured resolution when codec is created. 
+ int top_width_; + int top_height_; }; class DatarateTestSVC @@ -172,6 +178,7 @@ class DatarateTestSVC use_last_as_scaled_single_ref_ = false; external_resize_dynamic_drop_layer_ = false; external_resize_pattern_ = 0; + dynamic_tl_ = false; } void PreEncodeFrameHook(::libaom_test::VideoSource *video, @@ -309,9 +316,6 @@ class DatarateTestSVC } if (layer_id_.spatial_layer_id == 0 && (video->frame() == 1 || video->frame() == 150)) { - // Set the new top width/height for external resize. - top_sl_width_ = video->img()->d_w; - top_sl_height_ = video->img()->d_h; for (int i = 0; i < 9; ++i) { bitrate_layer_[i] = svc_params_.layer_target_bitrate[i]; } @@ -345,8 +349,6 @@ class DatarateTestSVC encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); } else if (layer_id_.spatial_layer_id == 0 && (video->frame() == 50 || video->frame() == 200)) { - top_sl_width_ = video->img()->d_w; - top_sl_height_ = video->img()->d_h; if (external_resize_pattern_ == 1) { // Input size is 1/2. Change layer bitrates to set top layer to 0. // This will trigger skip encoding/dropping of top spatial layer. @@ -377,8 +379,6 @@ class DatarateTestSVC encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); } else if (layer_id_.spatial_layer_id == 0 && (video->frame() == 100 || video->frame() == 250)) { - top_sl_width_ = video->img()->d_w; - top_sl_height_ = video->img()->d_h; // Input is original size. Change layer bitrates to nonzero for all // layers. cfg_.rc_target_bitrate = @@ -395,6 +395,26 @@ class DatarateTestSVC encoder->Config(&cfg_); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); } + } else if (dynamic_tl_) { + if (video->frame() == 100) { + // Enable 3 temporal layers. + svc_params_.number_temporal_layers = 3; + number_temporal_layers_ = 3; + svc_params_.layer_target_bitrate[0] = 60 * cfg_.rc_target_bitrate / 100; + svc_params_.layer_target_bitrate[1] = 80 * cfg_.rc_target_bitrate / 100; + svc_params_.layer_target_bitrate[2] = cfg_.rc_target_bitrate; + svc_params_.framerate_factor[0] = 4; + svc_params_.framerate_factor[1] = 2; + svc_params_.framerate_factor[2] = 1; + encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); + } else if (video->frame() == 200) { + // Go back to 1 temporal layer. 
+ svc_params_.number_temporal_layers = 1; + number_temporal_layers_ = 1; + svc_params_.layer_target_bitrate[0] = cfg_.rc_target_bitrate; + svc_params_.framerate_factor[0] = 1; + encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); + } } layer_frame_cnt_++; DatarateTest::PreEncodeFrameHook(video, encoder); @@ -2853,9 +2873,47 @@ class DatarateTestSVC cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)]; cfg_.g_w = 1280; cfg_.g_h = 720; - top_sl_width_ = 1280; - top_sl_height_ = 720; - ResizingVideoSource video(1); + ResizingVideoSource video(1, 1280, 720); + ResetModel(); + external_resize_dynamic_drop_layer_ = true; + external_resize_pattern_ = 1; + number_temporal_layers_ = 3; + number_spatial_layers_ = 3; + // SL0 + const int bitrate_sl0 = 1 * cfg_.rc_target_bitrate / 8; + target_layer_bitrate_[0] = 50 * bitrate_sl0 / 100; + target_layer_bitrate_[1] = 70 * bitrate_sl0 / 100; + target_layer_bitrate_[2] = bitrate_sl0; + // SL1 + const int bitrate_sl1 = 3 * cfg_.rc_target_bitrate / 8; + target_layer_bitrate_[3] = 50 * bitrate_sl1 / 100; + target_layer_bitrate_[4] = 70 * bitrate_sl1 / 100; + target_layer_bitrate_[5] = bitrate_sl1; + // SL2 + const int bitrate_sl2 = 4 * cfg_.rc_target_bitrate / 8; + target_layer_bitrate_[6] = 50 * bitrate_sl2 / 100; + target_layer_bitrate_[7] = 70 * bitrate_sl2 / 100; + target_layer_bitrate_[8] = bitrate_sl2; + ASSERT_NO_FATAL_FAILURE(RunLoop(&video)); + } + + virtual void BasicRateTargetingSVC3TL3SLExternalResizePattern1HighResTest() { + cfg_.rc_buf_initial_sz = 500; + cfg_.rc_buf_optimal_sz = 500; + cfg_.rc_buf_sz = 1000; + cfg_.rc_dropframe_thresh = 0; + cfg_.rc_min_quantizer = 0; + cfg_.rc_max_quantizer = 63; + cfg_.rc_end_usage = AOM_CBR; + cfg_.g_lag_in_frames = 0; + cfg_.g_error_resilient = 0; + const int bitrate_array[2] = { 600, 1200 }; + cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)]; + cfg_.g_w = 1850; + cfg_.g_h = 1110; + cfg_.g_forced_max_frame_width = 1850; + cfg_.g_forced_max_frame_height = 1110; + ResizingVideoSource video(1, 1850, 1110); ResetModel(); external_resize_dynamic_drop_layer_ = true; external_resize_pattern_ = 1; @@ -2893,9 +2951,7 @@ class DatarateTestSVC cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)]; cfg_.g_w = 1280; cfg_.g_h = 720; - top_sl_width_ = 1280; - top_sl_height_ = 720; - ResizingVideoSource video(2); + ResizingVideoSource video(2, 1280, 720); ResetModel(); external_resize_dynamic_drop_layer_ = true; external_resize_pattern_ = 2; @@ -2919,6 +2975,70 @@ class DatarateTestSVC ASSERT_NO_FATAL_FAILURE(RunLoop(&video)); } + virtual void BasicRateTargetingSVC3TL3SLExternalResizePattern2HighResTest() { + cfg_.rc_buf_initial_sz = 500; + cfg_.rc_buf_optimal_sz = 500; + cfg_.rc_buf_sz = 1000; + cfg_.rc_dropframe_thresh = 0; + cfg_.rc_min_quantizer = 0; + cfg_.rc_max_quantizer = 63; + cfg_.rc_end_usage = AOM_CBR; + cfg_.g_lag_in_frames = 0; + cfg_.g_error_resilient = 0; + const int bitrate_array[2] = { 600, 1200 }; + cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)]; + cfg_.g_w = 1850; + cfg_.g_h = 1110; + cfg_.g_forced_max_frame_width = 1850; + cfg_.g_forced_max_frame_height = 1110; + ResizingVideoSource video(2, 1850, 1110); + ResetModel(); + external_resize_dynamic_drop_layer_ = true; + external_resize_pattern_ = 2; + number_temporal_layers_ = 3; + number_spatial_layers_ = 3; + // SL0 + const int bitrate_sl0 = 1 * cfg_.rc_target_bitrate / 8; + target_layer_bitrate_[0] = 50 * bitrate_sl0 / 100; + target_layer_bitrate_[1] = 70 * bitrate_sl0 / 100; + target_layer_bitrate_[2] = bitrate_sl0; + // SL1 + const int 
bitrate_sl1 = 3 * cfg_.rc_target_bitrate / 8; + target_layer_bitrate_[3] = 50 * bitrate_sl1 / 100; + target_layer_bitrate_[4] = 70 * bitrate_sl1 / 100; + target_layer_bitrate_[5] = bitrate_sl1; + // SL2 + const int bitrate_sl2 = 4 * cfg_.rc_target_bitrate / 8; + target_layer_bitrate_[6] = 50 * bitrate_sl2 / 100; + target_layer_bitrate_[7] = 70 * bitrate_sl2 / 100; + target_layer_bitrate_[8] = bitrate_sl2; + ASSERT_NO_FATAL_FAILURE(RunLoop(&video)); + } + + virtual void BasicRateTargetingSVC3TL1SLDynamicTLTest() { + cfg_.rc_buf_initial_sz = 500; + cfg_.rc_buf_optimal_sz = 500; + cfg_.rc_buf_sz = 1000; + cfg_.rc_dropframe_thresh = 0; + cfg_.rc_min_quantizer = 0; + cfg_.rc_max_quantizer = 63; + cfg_.rc_end_usage = AOM_CBR; + cfg_.g_lag_in_frames = 0; + cfg_.g_error_resilient = 0; + ::libaom_test::I420VideoSource video("niklas_640_480_30.yuv", 640, 480, 30, + 1, 0, 400); + const int bitrate_array[2] = { 600, 1200 }; + cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)]; + target_layer_bitrate_[0] = cfg_.rc_target_bitrate; + cfg_.g_w = 640; + cfg_.g_h = 480; + ResetModel(); + number_temporal_layers_ = 1; + number_spatial_layers_ = 1; + dynamic_tl_ = true; + ASSERT_NO_FATAL_FAILURE(RunLoop(&video)); + } + int layer_frame_cnt_; int superframe_cnt_; int number_temporal_layers_; @@ -2961,8 +3081,7 @@ class DatarateTestSVC bool external_resize_dynamic_drop_layer_; int bitrate_layer_[9]; int external_resize_pattern_; - int top_sl_width_; - int top_sl_height_; + bool dynamic_tl_; }; // Check basic rate targeting for CBR, for 3 temporal layers, 1 spatial. @@ -3259,7 +3378,7 @@ TEST_P(DatarateTestSVC, BasicRateTargetingRPS1TL1SLDropFrames) { // and denoiser enabled. The external resizer will resize down and back up, // setting 0/nonzero bitrate on spatial enhancement layers to disable/enable // layers. Resizing starts on first frame and the pattern is: -// 1/4 -> 1/2 -> 1 -> 1/4 -> 1/2. +// 1/4 -> 1/2 -> 1 -> 1/4 -> 1/2. Configured resolution is 1280x720. TEST_P(DatarateTestSVC, BasicRateTargetingSVC3TL3SLExternalResizePattern1) { BasicRateTargetingSVC3TL3SLExternalResizePattern1Test(); } @@ -3268,11 +3387,38 @@ TEST_P(DatarateTestSVC, BasicRateTargetingSVC3TL3SLExternalResizePattern1) { // and denoiser enabled. The external resizer will resize down and back up, // setting 0/nonzero bitrate on spatial enhancement layers to disable/enable // layers. Resizing starts on first frame and the pattern is: -// 1/2 -> 1/4 -> 1 -> 1/2 -> 1/4. +// 1/4 -> 1/2 -> 1 -> 1/4 -> 1/2. Configured resolution is 1850x1110. +TEST_P(DatarateTestSVC, + BasicRateTargetingSVC3TL3SLExternalResizePattern1HighRes) { + BasicRateTargetingSVC3TL3SLExternalResizePattern1HighResTest(); +} + +// For 1 pass CBR SVC with 3 spatial and 3 temporal layers with external resize +// and denoiser enabled. The external resizer will resize down and back up, +// setting 0/nonzero bitrate on spatial enhancement layers to disable/enable +// layers. Resizing starts on first frame and the pattern is: +// 1/2 -> 1/4 -> 1 -> 1/2 -> 1/4. Configured resolution is 1280x720. TEST_P(DatarateTestSVC, BasicRateTargetingSVC3TL3SLExternalResizePattern2) { BasicRateTargetingSVC3TL3SLExternalResizePattern2Test(); } +// For 1 pass CBR SVC with 3 spatial and 3 temporal layers with external resize +// and denoiser enabled. The external resizer will resize down and back up, +// setting 0/nonzero bitrate on spatial enhancement layers to disable/enable +// layers. Resizing starts on first frame and the pattern is: +// 1/2 -> 1/4 -> 1 -> 1/2 -> 1/4. 
Configured resolution is 1850x1110. +TEST_P(DatarateTestSVC, + BasicRateTargetingSVC3TL3SLExternalResizePattern2HighRes) { + BasicRateTargetingSVC3TL3SLExternalResizePattern2HighResTest(); +} + +// For 1 pass CBR SVC with 1 spatial and dynamic temporal layers. +// Start/initialize with 1 temporal layer and then enable 3 temporal layers +// during the sequence, and then back to 1. +TEST_P(DatarateTestSVC, BasicRateTargetingSVC3TL1SLDynamicTL) { + BasicRateTargetingSVC3TL1SLDynamicTLTest(); +} + TEST(SvcParams, BitrateOverflow) { uint8_t buf[6] = { 0 }; aom_image_t img; diff --git a/third_party/aom/third_party/highway/LICENSE-BSD3 b/third_party/aom/third_party/highway/LICENSE-BSD3 new file mode 100644 index 000000000000..51d1bd4b412d --- /dev/null +++ b/third_party/aom/third_party/highway/LICENSE-BSD3 @@ -0,0 +1,26 @@ +Copyright (c) The Highway Project Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/third_party/aom/third_party/highway/README.libaom b/third_party/aom/third_party/highway/README.libaom new file mode 100644 index 000000000000..28fdd580faf6 --- /dev/null +++ b/third_party/aom/third_party/highway/README.libaom @@ -0,0 +1,11 @@ +URL: https://github.com/google/highway + +Version: e92c12750d18c372867809b882dd3ec6874ecc73 +License: BSD-3-clause clear +License File: LICENSE-BSD3 + +Description: +Highway is a C++ library that provides portable SIMD/vector intrinsics. + +Local Changes: +Remove everything except hwy/ and LICENSE-BSD3 diff --git a/third_party/aom/third_party/highway/hwy/abort.h b/third_party/aom/third_party/highway/hwy/abort.h new file mode 100644 index 000000000000..931e9780504b --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/abort.h @@ -0,0 +1,11 @@ +// Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef HIGHWAY_HWY_ABORT_H_ +#define HIGHWAY_HWY_ABORT_H_ + +// Empty header for compatibility. +// All Abort/Warn functionalities are in base.h. 
+ +#endif // HIGHWAY_HWY_ABORT_H_ diff --git a/third_party/aom/third_party/highway/hwy/aligned_allocator.h b/third_party/aom/third_party/highway/hwy/aligned_allocator.h new file mode 100644 index 000000000000..149f18e65c6d --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/aligned_allocator.h @@ -0,0 +1,426 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ +#define HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ + +// Memory allocator with support for alignment and offsets. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/per_target.h" + +namespace hwy { + +// Minimum alignment of allocated memory for use in HWY_ASSUME_ALIGNED, which +// requires a literal. To prevent false sharing, this should be at least the +// L1 cache line size, usually 64 bytes. However, Intel's L2 prefetchers may +// access pairs of lines, and M1 L2 and POWER8 lines are also 128 bytes. +#define HWY_ALIGNMENT 128 + +template +HWY_API constexpr bool IsAligned(T* ptr, size_t align = HWY_ALIGNMENT) { + return reinterpret_cast(ptr) % align == 0; +} + +// Pointers to functions equivalent to malloc/free with an opaque void* passed +// to them. +using AllocPtr = void* (*)(void* opaque, size_t bytes); +using FreePtr = void (*)(void* opaque, void* memory); + +// Returns null or a pointer to at least `payload_size` (which can be zero) +// bytes of newly allocated memory, aligned to the larger of HWY_ALIGNMENT and +// the vector size. Calls `alloc` with the passed `opaque` pointer to obtain +// memory or malloc() if it is null. +HWY_DLLEXPORT void* AllocateAlignedBytes(size_t payload_size, + AllocPtr alloc_ptr = nullptr, + void* opaque_ptr = nullptr); + +// Frees all memory. No effect if `aligned_pointer` == nullptr, otherwise it +// must have been returned from a previous call to `AllocateAlignedBytes`. +// Calls `free_ptr` with the passed `opaque_ptr` pointer to free the memory; if +// `free_ptr` function is null, uses the default free(). +HWY_DLLEXPORT void FreeAlignedBytes(const void* aligned_pointer, + FreePtr free_ptr, void* opaque_ptr); + +// Class that deletes the aligned pointer passed to operator() calling the +// destructor before freeing the pointer. This is equivalent to the +// std::default_delete but for aligned objects. For a similar deleter equivalent +// to free() for aligned memory see AlignedFreer(). 
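The allocation entry points declared above accept optional alloc/free hooks plus an opaque pointer; passing null selects malloc/free. A minimal usage sketch (editorial, based only on the declarations visible above):

#include "third_party/highway/hwy/aligned_allocator.h"

void AlignedBytesExample() {
  // Null hooks fall back to malloc(); the result is aligned to at least
  // HWY_ALIGNMENT (128) bytes, or the vector size if that is larger.
  void* p = hwy::AllocateAlignedBytes(/*payload_size=*/1024);
  if (p != nullptr) {
    // ... use the buffer ...
    hwy::FreeAlignedBytes(p, /*free_ptr=*/nullptr, /*opaque_ptr=*/nullptr);
  }
}

AlignedDeleter below packages the same free path behind a std::unique_ptr deleter, so the MakeUniqueAligned-style helpers further down can destroy and free in one step.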
+class AlignedDeleter { + public: + AlignedDeleter() : free_(nullptr), opaque_ptr_(nullptr) {} + AlignedDeleter(FreePtr free_ptr, void* opaque_ptr) + : free_(free_ptr), opaque_ptr_(opaque_ptr) {} + + template + void operator()(T* aligned_pointer) const { + return DeleteAlignedArray(aligned_pointer, free_, opaque_ptr_, + TypedArrayDeleter); + } + + private: + template + static void TypedArrayDeleter(void* ptr, size_t size_in_bytes) { + size_t elems = size_in_bytes / sizeof(T); + for (size_t i = 0; i < elems; i++) { + // Explicitly call the destructor on each element. + (static_cast(ptr) + i)->~T(); + } + } + + // Function prototype that calls the destructor for each element in a typed + // array. TypeArrayDeleter would match this prototype. + using ArrayDeleter = void (*)(void* t_ptr, size_t t_size); + + HWY_DLLEXPORT static void DeleteAlignedArray(void* aligned_pointer, + FreePtr free_ptr, + void* opaque_ptr, + ArrayDeleter deleter); + + FreePtr free_; + void* opaque_ptr_; +}; + +// Unique pointer to T with custom aligned deleter. This can be a single +// element U or an array of element if T is a U[]. The custom aligned deleter +// will call the destructor on U or each element of a U[] in the array case. +template +using AlignedUniquePtr = std::unique_ptr; + +// Aligned memory equivalent of make_unique using the custom allocators +// alloc/free with the passed `opaque` pointer. This function calls the +// constructor with the passed Args... and calls the destructor of the object +// when the AlignedUniquePtr is destroyed. +template +AlignedUniquePtr MakeUniqueAlignedWithAlloc(AllocPtr alloc, FreePtr free, + void* opaque, Args&&... args) { + T* ptr = static_cast(AllocateAlignedBytes(sizeof(T), alloc, opaque)); + return AlignedUniquePtr(new (ptr) T(std::forward(args)...), + AlignedDeleter(free, opaque)); +} + +// Similar to MakeUniqueAlignedWithAlloc but using the default alloc/free +// functions. +template +AlignedUniquePtr MakeUniqueAligned(Args&&... args) { + T* ptr = static_cast(AllocateAlignedBytes(sizeof(T))); + return AlignedUniquePtr(new (ptr) T(std::forward(args)...), + AlignedDeleter()); +} + +template +struct AlignedAllocator { + using value_type = T; + + AlignedAllocator() = default; + + template + explicit AlignedAllocator(const AlignedAllocator&) noexcept {} + + template + value_type* allocate(V n) { + static_assert(std::is_integral::value, + "AlignedAllocator only supports integer types"); + static_assert(sizeof(V) <= sizeof(std::size_t), + "V n must be smaller or equal size_t to avoid overflow"); + return static_cast( + AllocateAlignedBytes(static_cast(n) * sizeof(value_type))); + } + + template + void deallocate(value_type* p, HWY_MAYBE_UNUSED V n) { + return FreeAlignedBytes(p, nullptr, nullptr); + } +}; + +template +constexpr bool operator==(const AlignedAllocator&, + const AlignedAllocator&) noexcept { + return true; +} + +template +constexpr bool operator!=(const AlignedAllocator&, + const AlignedAllocator&) noexcept { + return false; +} + +template +using AlignedVector = std::vector>; + +// Helpers for array allocators (avoids overflow) +namespace detail { + +// Returns x such that 1u << x == n (if n is a power of two). +static inline constexpr size_t ShiftCount(size_t n) { + return (n <= 1) ? 
0 : 1 + ShiftCount(n / 2); +} + +template +T* AllocateAlignedItems(size_t items, AllocPtr alloc_ptr, void* opaque_ptr) { + constexpr size_t kSize = sizeof(T); + + constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0; + constexpr size_t kBits = ShiftCount(kSize); + static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug"); + + const size_t bytes = kIsPow2 ? items << kBits : items * kSize; + const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize; + if (check != items) { + return nullptr; // overflowed + } + return static_cast(AllocateAlignedBytes(bytes, alloc_ptr, opaque_ptr)); +} + +} // namespace detail + +// Aligned memory equivalent of make_unique for array types using the +// custom allocators alloc/free. This function calls the constructor with the +// passed Args... on every created item. The destructor of each element will be +// called when the AlignedUniquePtr is destroyed. +template +AlignedUniquePtr MakeUniqueAlignedArrayWithAlloc( + size_t items, AllocPtr alloc, FreePtr free, void* opaque, Args&&... args) { + T* ptr = detail::AllocateAlignedItems(items, alloc, opaque); + if (ptr != nullptr) { + for (size_t i = 0; i < items; i++) { + new (ptr + i) T(std::forward(args)...); + } + } + return AlignedUniquePtr(ptr, AlignedDeleter(free, opaque)); +} + +template +AlignedUniquePtr MakeUniqueAlignedArray(size_t items, Args&&... args) { + return MakeUniqueAlignedArrayWithAlloc( + items, nullptr, nullptr, nullptr, std::forward(args)...); +} + +// Custom deleter for std::unique_ptr equivalent to using free() as a deleter +// but for aligned memory. +class AlignedFreer { + public: + // Pass address of this to ctor to skip deleting externally-owned memory. + static void DoNothing(void* /*opaque*/, void* /*aligned_pointer*/) {} + + AlignedFreer() : free_(nullptr), opaque_ptr_(nullptr) {} + AlignedFreer(FreePtr free_ptr, void* opaque_ptr) + : free_(free_ptr), opaque_ptr_(opaque_ptr) {} + + template + void operator()(T* aligned_pointer) const { + FreeAlignedBytes(aligned_pointer, free_, opaque_ptr_); + } + + private: + FreePtr free_; + void* opaque_ptr_; +}; + +// Unique pointer to single POD, or (if T is U[]) an array of POD. For non POD +// data use AlignedUniquePtr. +template +using AlignedFreeUniquePtr = std::unique_ptr; + +// Allocate an aligned and uninitialized array of POD values as a unique_ptr. +// Upon destruction of the unique_ptr the aligned array will be freed. +template +AlignedFreeUniquePtr AllocateAligned(const size_t items, AllocPtr alloc, + FreePtr free, void* opaque) { + static_assert(std::is_trivially_copyable::value, + "AllocateAligned: requires trivially copyable T"); + static_assert(std::is_trivially_destructible::value, + "AllocateAligned: requires trivially destructible T"); + return AlignedFreeUniquePtr( + detail::AllocateAlignedItems(items, alloc, opaque), + AlignedFreer(free, opaque)); +} + +// Same as previous AllocateAligned(), using default allocate/free functions. +template +AlignedFreeUniquePtr AllocateAligned(const size_t items) { + return AllocateAligned(items, nullptr, nullptr, nullptr); +} + +// A simple span containing data and size of data. +template +class Span { + public: + Span() = default; + Span(T* data, size_t size) : size_(size), data_(data) {} + template + Span(U u) : Span(u.data(), u.size()) {} + Span(std::initializer_list v) : Span(v.begin(), v.size()) {} + + // Copies the contents of the initializer list to the span. 
+ Span& operator=(std::initializer_list v) { + HWY_DASSERT(size_ == v.size()); + CopyBytes(v.begin(), data_, sizeof(T) * std::min(size_, v.size())); + return *this; + } + + // Returns the size of the contained data. + size_t size() const { return size_; } + + // Returns a pointer to the contained data. + T* data() { return data_; } + T* data() const { return data_; } + + // Returns the element at index. + T& operator[](size_t index) const { return data_[index]; } + + // Returns an iterator pointing to the first element of this span. + T* begin() { return data_; } + + // Returns a const iterator pointing to the first element of this span. + constexpr const T* cbegin() const { return data_; } + + // Returns an iterator pointing just beyond the last element at the + // end of this span. + T* end() { return data_ + size_; } + + // Returns a const iterator pointing just beyond the last element at the + // end of this span. + constexpr const T* cend() const { return data_ + size_; } + + private: + size_t size_ = 0; + T* data_ = nullptr; +}; + +// A multi dimensional array containing an aligned buffer. +// +// To maintain alignment, the innermost dimension will be padded to ensure all +// innermost arrays are aligned. +template +class AlignedNDArray { + static_assert(std::is_trivial::value, + "AlignedNDArray can only contain trivial types"); + + public: + AlignedNDArray(AlignedNDArray&& other) = default; + AlignedNDArray& operator=(AlignedNDArray&& other) = default; + + // Constructs an array of the provided shape and fills it with zeros. + explicit AlignedNDArray(std::array shape) : shape_(shape) { + sizes_ = ComputeSizes(shape_); + memory_shape_ = shape_; + // Round the innermost dimension up to the number of bytes available for + // SIMD operations on this architecture to make sure that each innermost + // array is aligned from the first element. + memory_shape_[axes - 1] = RoundUpTo(memory_shape_[axes - 1], VectorBytes()); + memory_sizes_ = ComputeSizes(memory_shape_); + buffer_ = hwy::AllocateAligned(memory_size()); + hwy::ZeroBytes(buffer_.get(), memory_size() * sizeof(T)); + } + + // Returns a span containing the innermost array at the provided indices. + Span operator[](std::array indices) { + return Span(buffer_.get() + Offset(indices), sizes_[indices.size()]); + } + + // Returns a const span containing the innermost array at the provided + // indices. + Span operator[](std::array indices) const { + return Span(buffer_.get() + Offset(indices), + sizes_[indices.size()]); + } + + // Returns the shape of the array, which might be smaller than the allocated + // buffer after padding the last axis to alignment. + const std::array& shape() const { return shape_; } + + // Returns the shape of the allocated buffer, which might be larger than the + // used size of the array after padding to alignment. + const std::array& memory_shape() const { return memory_shape_; } + + // Returns the size of the array, which might be smaller than the allocated + // buffer after padding the last axis to alignment. + size_t size() const { return sizes_[0]; } + + // Returns the size of the allocated buffer, which might be larger than the + // used size of the array after padding to alignment. + size_t memory_size() const { return memory_sizes_[0]; } + + // Returns a pointer to the allocated buffer. + T* data() { return buffer_.get(); } + + // Returns a const pointer to the buffer. + const T* data() const { return buffer_.get(); } + + // Truncates the array by updating its shape. 
+ // + // The new shape must be equal to or less than the old shape in all axes. + // + // Doesn't modify underlying memory. + void truncate(const std::array& new_shape) { +#if HWY_IS_DEBUG_BUILD + for (size_t axis_index = 0; axis_index < axes; ++axis_index) { + HWY_ASSERT(new_shape[axis_index] <= shape_[axis_index]); + } +#endif + shape_ = new_shape; + sizes_ = ComputeSizes(shape_); + } + + private: + std::array shape_; + std::array memory_shape_; + std::array sizes_; + std::array memory_sizes_; + hwy::AlignedFreeUniquePtr buffer_; + + // Computes offset in the buffer based on the provided indices. + size_t Offset(std::array indices) const { + size_t offset = 0; + size_t shape_index = 0; + for (const size_t axis_index : indices) { + offset += memory_sizes_[shape_index + 1] * axis_index; + shape_index++; + } + return offset; + } + + // Computes the sizes of all sub arrays based on the sizes of each axis. + // + // Does this by multiplying the size of each axis with the previous one in + // reverse order, starting with the conceptual axis of size 1 containing the + // actual elements in the array. + static std::array ComputeSizes( + std::array shape) { + std::array sizes; + size_t axis = shape.size(); + sizes[axis] = 1; + while (axis > 0) { + --axis; + sizes[axis] = sizes[axis + 1] * shape[axis]; + } + return sizes; + } +}; + +} // namespace hwy +#endif // HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ diff --git a/third_party/aom/third_party/highway/hwy/auto_tune.h b/third_party/aom/third_party/highway/hwy/auto_tune.h new file mode 100644 index 000000000000..c94b5eb451b8 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/auto_tune.h @@ -0,0 +1,504 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_AUTO_TUNE_H_ +#define HIGHWAY_HWY_AUTO_TUNE_H_ + +#include +#include +#include // memmove + +#include +#include + +#include "third_party/highway/hwy/aligned_allocator.h" // Span +#include "third_party/highway/hwy/base.h" // HWY_MIN +#include "third_party/highway/hwy/contrib/sort/vqsort.h" + +// Infrastructure for auto-tuning (choosing optimal parameters at runtime). + +namespace hwy { + +// O(1) storage to estimate the central tendency of hundreds of independent +// distributions (one per configuration). The number of samples per distribution +// (`kMinSamples`) varies from few to dozens. We support both by first storing +// values in a buffer, and when full, switching to online variance estimation. +// Modified from `hwy/stats.h`. +class CostDistribution { + public: + static constexpr size_t kMaxValues = 14; // for total size of 128 bytes + + void Notify(const double x) { + if (HWY_UNLIKELY(x < 0.0)) { + HWY_WARN("Ignoring negative cost %f.", x); + return; + } + + // Online phase after filling and warm-up. + if (HWY_LIKELY(IsOnline())) return OnlineNotify(x); + + // Fill phase: store up to `kMaxValues` values. 
+ values_[num_values_++] = x; + HWY_DASSERT(num_values_ <= kMaxValues); + if (HWY_UNLIKELY(num_values_ == kMaxValues)) { + WarmUpOnline(); + HWY_DASSERT(IsOnline()); + } + } + + // Returns an estimate of the true cost, mitigating the impact of noise. + // + // Background and observations from time measurements in `thread_pool.h`: + // - We aim for O(1) storage because there may be hundreds of instances. + // - The mean is biased upwards by mostly additive noise: particularly + // interruptions such as context switches, but also contention. + // - The minimum is not a robust estimator because there are also "lucky + // shots" (1.2-1.6x lower values) where interruptions or contention happen + // to be low. + // - We want to preserve information about contention and a configuration's + // sensitivity to it. Otherwise, we are optimizing for the best-case, not + // the common case. + // - It is still important to minimize the influence of outliers, such as page + // faults, which can cause multiple times larger measurements. + // - Detecting outliers based only on the initial variance is too brittle. If + // the sample is narrow, measurements will fluctuate across runs because + // too many measurements are considered outliers. This would cause the + // 'best' configuration to vary. + // + // Approach: + // - Use Winsorization to reduce the impact of outliers, while preserving + // information on the central tendency. + // - Continually update the thresholds based on the online variance, with + // exponential smoothing for stability. + // - Trim the initial sample via MAD or skewness for a robust estimate of the + // variance. + double EstimateCost() { + if (!IsOnline()) { + WarmUpOnline(); + HWY_DASSERT(IsOnline()); + } + return Mean(); + } + + // Multiplex online state into values_ to allow higher `kMaxValues`. + // Public for inspection in tests. Do not use directly. + double& M1() { return values_[0]; } // Moments for variance. + double& M2() { return values_[1]; } + double& Mean() { return values_[2]; } // Exponential smoothing. + double& Stddev() { return values_[3]; } + double& Lower() { return values_[4]; } + double& Upper() { return values_[5]; } + + private: + static double Median(double* to_sort, size_t n) { + HWY_DASSERT(n >= 2); +// F64 is supported everywhere except Armv7. +#if !HWY_ARCH_ARM_V7 + VQSort(to_sort, n, SortAscending()); +#else + // Values are known to be finite and non-negative, hence sorting as U64 is + // equivalent. + VQSort(reinterpret_cast(to_sort), n, SortAscending()); +#endif + if (n & 1) return to_sort[n / 2]; + // Even length: average of two middle elements. + return (to_sort[n / 2] + to_sort[n / 2 - 1]) * 0.5; + } + + static double MAD(const double* values, size_t n, const double median) { + double abs_dev[kMaxValues]; + for (size_t i = 0; i < n; ++i) { + abs_dev[i] = ScalarAbs(values[i] - median); + } + return Median(abs_dev, n); + } + + // If `num_values_` is large enough, sorts and discards outliers: either via + // MAD, or if too many values are equal, by trimming according to skewness. + void RemoveOutliers() { + if (num_values_ < 3) return; // Not enough to discard two. + HWY_DASSERT(num_values_ <= kMaxValues); + + // Given the noise level in `auto_tune_test`, it can happen that 1/4 of the + // sample is an outlier *in either direction*. Use median absolute + // deviation, which is robust to almost half of the sample being outliers. + const double median = Median(values_, num_values_); // sorts in-place. 
+ const double mad = MAD(values_, num_values_, median); + // At least half the sample is equal. + if (mad == 0.0) { + // Estimate skewness to decide which side to trim more. + const double skewness = + (values_[num_values_ - 1] - median) - (median - values_[0]); + + const size_t trim = HWY_MAX(num_values_ / 2, size_t{2}); + const size_t left = + HWY_MAX(skewness < 0.0 ? trim * 3 / 4 : trim / 4, size_t{1}); + num_values_ -= trim; + HWY_DASSERT(num_values_ >= 1); + memmove(values_, values_ + left, num_values_ * sizeof(values_[0])); + return; + } + + const double upper = median + 5.0 * mad; + const double lower = median - 5.0 * mad; + size_t right = num_values_ - 1; + while (values_[right] > upper) --right; + // Nonzero MAD implies no more than half are equal, so we did not advance + // beyond the median. + HWY_DASSERT(right >= num_values_ / 2); + + size_t left = 0; + while (left < right && values_[left] < lower) ++left; + HWY_DASSERT(left <= num_values_ / 2); + num_values_ = right - left + 1; + memmove(values_, values_ + left, num_values_ * sizeof(values_[0])); + } + + double SampleMean() const { + // Only called in non-online phase, but buffer might not be full. + HWY_DASSERT(!IsOnline() && 0 != num_values_ && num_values_ <= kMaxValues); + double sum = 0.0; + for (size_t i = 0; i < num_values_; ++i) { + sum += values_[i]; + } + return sum / static_cast(num_values_); + } + + // Unbiased estimator for population variance even for small `num_values_`. + double SampleVariance(double sample_mean) const { + HWY_DASSERT(sample_mean >= 0.0); // we checked costs are non-negative. + // Only called in non-online phase, but buffer might not be full. + HWY_DASSERT(!IsOnline() && 0 != num_values_ && num_values_ <= kMaxValues); + if (HWY_UNLIKELY(num_values_ == 1)) return 0.0; // prevent divide-by-zero. + double sum2 = 0.0; + for (size_t i = 0; i < num_values_; ++i) { + const double d = values_[i] - sample_mean; + sum2 += d * d; + } + return sum2 / static_cast(num_values_ - 1); + } + + bool IsOnline() const { return online_n_ > 0.0; } + + void OnlineNotify(double x) { + // Winsorize. + x = HWY_MIN(HWY_MAX(Lower(), x), Upper()); + + // Welford's online variance estimator. + // https://media.thinkbrg.com/wp-content/uploads/2020/06/19094655/720_720_McCrary_ImplementingAlgorithms_Whitepaper_20151119_WEB.pdf#page=7.09 + const double n_minus_1 = online_n_; + online_n_ += 1.0; + const double d = x - M1(); + const double d_div_n = d / online_n_; + M1() += d_div_n; + HWY_DASSERT(M1() >= Lower()); + M2() += d * n_minus_1 * d_div_n; // d^2 * (N-1)/N + // HWY_MAX avoids divide-by-zero. + const double stddev = std::sqrt(M2() / HWY_MAX(1.0, n_minus_1)); + + // Exponential smoothing. + constexpr double kNew = 0.2; // relatively fast update + constexpr double kOld = 1.0 - kNew; + Mean() = M1() * kNew + Mean() * kOld; + Stddev() = stddev * kNew + Stddev() * kOld; + + // Update thresholds from smoothed mean and stddev to enable recovering from + // a too narrow initial range due to excessive trimming. + Lower() = Mean() - 3.5 * Stddev(); + Upper() = Mean() + 3.5 * Stddev(); + } + + void WarmUpOnline() { + RemoveOutliers(); + + // Compute and copy before writing to `M1`, which overwrites `values_`! 
+ const double sample_mean = SampleMean(); + const double sample_variance = SampleVariance(sample_mean); + double copy[kMaxValues]; + hwy::CopyBytes(values_, copy, num_values_ * sizeof(values_[0])); + + M1() = M2() = 0.0; + Mean() = sample_mean; + Stddev() = std::sqrt(sample_variance); + // For single-value or all-equal sample, widen the range, else we will only + // accept the same value. + if (Stddev() == 0.0) Stddev() = Mean() / 2; + + // High tolerance because the distribution is not actually Gaussian, and + // we trimmed up to *half*, and do not want to reject too many values in + // the online phase. + Lower() = Mean() - 4.0 * Stddev(); + Upper() = Mean() + 4.0 * Stddev(); + // Feed copied values into online estimator. + for (size_t i = 0; i < num_values_; ++i) { + OnlineNotify(copy[i]); + } + HWY_DASSERT(IsOnline()); + +#if SIZE_MAX == 0xFFFFFFFFu + (void)padding_; +#endif + } + + size_t num_values_ = 0; // size of `values_` <= `kMaxValues` +#if SIZE_MAX == 0xFFFFFFFFu + uint32_t padding_ = 0; +#endif + + double online_n_ = 0.0; // number of calls to `OnlineNotify`. + + double values_[kMaxValues]; +}; +static_assert(sizeof(CostDistribution) == 128, ""); + +// Implements a counter with wrap-around, plus the ability to skip values. +// O(1) time, O(N) space via doubly-linked list of indices. +class NextWithSkip { + public: + NextWithSkip() {} + explicit NextWithSkip(size_t num) { + links_.reserve(num); + for (size_t i = 0; i < num; ++i) { + links_.emplace_back(i, num); + } + } + + size_t Next(size_t pos) { + HWY_DASSERT(pos < links_.size()); + HWY_DASSERT(!links_[pos].IsRemoved()); + return links_[pos].Next(); + } + + // Must not be called for an already skipped position. Ignores an attempt to + // skip the last remaining position. + void Skip(size_t pos) { + HWY_DASSERT(!links_[pos].IsRemoved()); // not already skipped. + const size_t prev = links_[pos].Prev(); + const size_t next = links_[pos].Next(); + if (prev == pos || next == pos) return; // last remaining position. + links_[next].SetPrev(prev); + links_[prev].SetNext(next); + links_[pos].Remove(); + } + + private: + // Combine prev/next into one array to improve locality/reduce allocations. + class Link { + // Bit-shifts avoid potentially expensive 16-bit loads. Store `next` at the + // top and `prev` at the bottom for extraction with a single shift/AND. + // There may be hundreds of configurations, so 8 bits are not enough. + static constexpr size_t kBits = 14; + static constexpr size_t kShift = 32 - kBits; + static constexpr uint32_t kMaxNum = 1u << kBits; + + public: + Link(size_t pos, size_t num) { + HWY_DASSERT(num < kMaxNum); + const size_t prev = pos == 0 ? num - 1 : pos - 1; + const size_t next = pos == num - 1 ? 
0 : pos + 1; + bits_ = + (static_cast(next) << kShift) | static_cast(prev); + HWY_DASSERT(Next() == next && Prev() == prev); + HWY_DASSERT(!IsRemoved()); + } + + bool IsRemoved() const { return (bits_ & kMaxNum) != 0; } + void Remove() { bits_ |= kMaxNum; } + + size_t Next() const { return bits_ >> kShift; } + size_t Prev() const { return bits_ & (kMaxNum - 1); } + + void SetNext(size_t next) { + HWY_DASSERT(next < kMaxNum); + bits_ &= (~0u >> kBits); // clear old next + bits_ |= static_cast(next) << kShift; + HWY_DASSERT(Next() == next); + HWY_DASSERT(!IsRemoved()); + } + void SetPrev(size_t prev) { + HWY_DASSERT(prev < kMaxNum); + bits_ &= ~(kMaxNum - 1); // clear old prev + bits_ |= static_cast(prev); + HWY_DASSERT(Prev() == prev); + HWY_DASSERT(!IsRemoved()); + } + + private: + uint32_t bits_; + }; + std::vector links_; +}; + +// State machine for choosing at runtime the lowest-cost `Config`, which is +// typically a struct containing multiple parameters. For an introduction, see +// "Auto-Tuning and Performance Portability on Heterogeneous Hardware". +// +// **Which parameters** +// Note that simple parameters such as the L2 cache size can be directly queried +// via `hwy/contrib/thread_pool/topology.h`. Difficult to predict parameters +// such as task granularity are more appropriate for auto-tuning. We also +// suggest that at least some parameters should also be 'algorithm variants' +// such as parallel vs. serial, or 2D tiling vs. 1D striping. +// +// **Search strategy** +// To guarantee the optimal result, we use exhaustive search, which is suitable +// for around 10 parameters and a few hundred combinations of 'candidate' +// configurations. +// +// **How to generate candidates** +// To keep this framework simple and generic, applications enumerate the search +// space and pass the list of all feasible candidates to `SetCandidates` before +// the first call to `NextConfig`. Applications should prune the space as much +// as possible, e.g. by upper-bounding parameters based on the known cache +// sizes, and applying constraints such as one being a multiple of another. +// +// **Usage** +// Applications typically conditionally branch to the code implementing the +// configuration returned by `NextConfig`. They measure the cost of running it +// and pass that to `NotifyCost`. Branching avoids the complexity and +// opaqueness of a JIT. The number of branches can be reduced (at the cost of +// code size) by inlining low-level decisions into larger code regions, e.g. by +// hoisting them outside hot loops. +// +// **What is cost** +// Cost is an arbitrary `uint64_t`, with lower values being better. Most +// applications will use the elapsed time. If the tasks being tuned are short, +// it is important to use a high-resolution timer such as `hwy/timer.h`. Energy +// may also be useful [https://www.osti.gov/servlets/purl/1361296]. +// +// **Online vs. offline** +// Although applications can auto-tune once, offline, it may be difficult to +// ensure the stored configuration still applies to the current circumstances. +// Thus we recommend online auto-tuning, re-discovering the configuration on +// each run. We assume the overhead of bookkeeping and measuring cost is +// negligible relative to the actual work. The cost of auto-tuning is then that +// of running sub-optimal configurations. Assuming the best configuration is +// better than baseline, and the work is performed many thousands of times, the +// cost is outweighed by the benefits. 
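+// **Usage sketch**
+// Editorial example, not part of the upstream header. It assumes the template
+// parameters (stripped in this diff) are the candidate config type and
+// kMinSamples, and that MyConfig, EnumerateFeasibleConfigs, Run and
+// ElapsedTicks are caller-provided:
+//
+//   hwy::AutoTune<MyConfig, /*kMinSamples=*/8> tuner;
+//   if (const MyConfig* best = tuner.Best()) {
+//     Run(*best);  // Tuning already finished; use the winner.
+//   } else {
+//     if (!tuner.HasCandidates()) {
+//       tuner.SetCandidates(EnumerateFeasibleConfigs());
+//     }
+//     const MyConfig& cfg = tuner.NextConfig();
+//     const uint64_t t0 = ElapsedTicks();
+//     Run(cfg);
+//     tuner.NotifyCost(ElapsedTicks() - t0);
+//   }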
+// +// **kMinSamples** +// To further reduce overhead, after `kMinSamples` rounds (= measurements of +// each configuration) we start excluding configurations from further +// measurements if they are sufficiently worse than the current best. +// `kMinSamples` can be several dozen when the tasks being tuned take a few +// microseconds. Even for longer tasks, it should be at least 2 for some noise +// tolerance. After this, there are another `kMinSamples / 2 + 1` rounds before +// declaring the winner. +template +class AutoTune { + public: + // Returns non-null best configuration if auto-tuning has already finished. + // Otherwise, callers continue calling `NextConfig` and `NotifyCost`. + // Points into `Candidates()`. + const Config* Best() const { return best_; } + + // If false, caller must call `SetCandidates` before `NextConfig`. + bool HasCandidates() const { + HWY_DASSERT(!Best()); + return !candidates_.empty(); + } + // WARNING: invalidates `Best()`, do not call if that is non-null. + void SetCandidates(std::vector candidates) { + HWY_DASSERT(!Best() && !HasCandidates()); + candidates_.swap(candidates); + HWY_DASSERT(HasCandidates()); + costs_.resize(candidates_.size()); + list_ = NextWithSkip(candidates_.size()); + } + + // Typically called after Best() is non-null to compare all candidates' costs. + Span Candidates() const { + HWY_DASSERT(HasCandidates()); + return Span(candidates_.data(), candidates_.size()); + } + Span Costs() { + return Span(costs_.data(), costs_.size()); + } + + // Returns the current `Config` to measure. + const Config& NextConfig() const { + HWY_DASSERT(!Best() && HasCandidates()); + return candidates_[config_idx_]; + } + + // O(1) except at the end of each round, which is O(N). + void NotifyCost(uint64_t cost) { + HWY_DASSERT(!Best() && HasCandidates()); + + costs_[config_idx_].Notify(static_cast(cost)); + // Save now before we update `config_idx_`. + const size_t my_idx = config_idx_; + // Only retrieve once we have enough samples, otherwise, we switch to + // online variance before the buffer is populated. + const double my_cost = rounds_complete_ >= kMinSamples + ? costs_[config_idx_].EstimateCost() + : 0.0; + + // Advance to next non-skipped config with wrap-around. This decorrelates + // measurements by not immediately re-measuring the same config. + config_idx_ = list_.Next(config_idx_); + // Might still equal `my_idx` if this is the only non-skipped config. + + // Disqualify from future `NextConfig` if cost was too far beyond the + // current best. This reduces the number of measurements, while tolerating + // noise in the first few measurements. Must happen after advancing. + if (my_cost > skip_if_above_) { + list_.Skip(my_idx); + } + + // Wrap-around indicates the round is complete. + if (HWY_UNLIKELY(config_idx_ <= my_idx)) { + ++rounds_complete_; + + // Enough samples for stable estimates: update the thresholds. + if (rounds_complete_ >= kMinSamples) { + double best_cost = HighestValue(); + size_t idx_min = 0; + for (size_t i = 0; i < candidates_.size(); ++i) { + const double estimate = costs_[i].EstimateCost(); + if (estimate < best_cost) { + best_cost = estimate; + idx_min = i; + } + } + skip_if_above_ = best_cost * 1.25; + + // After sufficient rounds, declare the winner. + if (HWY_UNLIKELY(rounds_complete_ == 3 * kMinSamples / 2 + 1)) { + best_ = &candidates_[idx_min]; + HWY_DASSERT(Best()); + } + } + } + } + + // Avoid printing during the first few rounds, because those might be noisy + // and not yet skipped. 
+ bool ShouldPrint() { return rounds_complete_ > kMinSamples; } + + private: + const Config* best_ = nullptr; + std::vector candidates_; + std::vector costs_; // one per candidate + size_t config_idx_ = 0; // [0, candidates_.size()) + NextWithSkip list_; + size_t rounds_complete_ = 0; + + double skip_if_above_ = 0.0; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_AUTO_TUNE_H_ diff --git a/third_party/aom/third_party/highway/hwy/base.h b/third_party/aom/third_party/highway/hwy/base.h new file mode 100644 index 000000000000..54b71c7e123b --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/base.h @@ -0,0 +1,3218 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_BASE_H_ +#define HIGHWAY_HWY_BASE_H_ + +// Target-independent definitions. + +// IWYU pragma: begin_exports +#include +#include +#if defined(HWY_HEADER_ONLY) +#include +#include +#endif + +#if !defined(HWY_NO_LIBCXX) +#include +#endif + +#include "third_party/highway/hwy/detect_compiler_arch.h" +#include "third_party/highway/hwy/highway_export.h" + +// API version (https://semver.org/); keep in sync with CMakeLists.txt. +#define HWY_MAJOR 1 +#define HWY_MINOR 2 +#define HWY_PATCH 0 + +// True if the Highway version >= major.minor.0. Added in 1.2.0. +#define HWY_VERSION_GE(major, minor) \ + (HWY_MAJOR > (major) || (HWY_MAJOR == (major) && HWY_MINOR >= (minor))) +// True if the Highway version < major.minor.0. Added in 1.2.0. +#define HWY_VERSION_LT(major, minor) \ + (HWY_MAJOR < (major) || (HWY_MAJOR == (major) && HWY_MINOR < (minor))) + +// "IWYU pragma: keep" does not work for these includes, so hide from the IDE. 
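The version macros above are intended for compile-time gating in user code; a brief illustrative sketch (the gated feature itself is hypothetical):

#include "third_party/highway/hwy/base.h"

// Select code paths by Highway version at compile time.
#if HWY_VERSION_GE(1, 2)
// Highway is at least 1.2.0: may rely on APIs added in 1.2.
#else
// Older Highway: fall back to a pre-1.2 code path.
#endif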
+#if !HWY_IDE + +#if !defined(HWY_NO_LIBCXX) +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include +#endif + +#if (HWY_ARCH_X86 && !defined(HWY_NO_LIBCXX)) || HWY_COMPILER_MSVC +#include +#endif + +#endif // !HWY_IDE + +#ifndef HWY_HAVE_COMPARE_HEADER // allow override +#define HWY_HAVE_COMPARE_HEADER 0 +#if defined(__has_include) // note: wrapper macro fails on Clang ~17 +#if __has_include() +#undef HWY_HAVE_COMPARE_HEADER +#define HWY_HAVE_COMPARE_HEADER 1 +#endif // __has_include +#endif // defined(__has_include) +#endif // HWY_HAVE_COMPARE_HEADER + +#ifndef HWY_HAVE_CXX20_THREE_WAY_COMPARE // allow override +#if !defined(HWY_NO_LIBCXX) && defined(__cpp_impl_three_way_comparison) && \ + __cpp_impl_three_way_comparison >= 201907L && HWY_HAVE_COMPARE_HEADER +#include +#define HWY_HAVE_CXX20_THREE_WAY_COMPARE 1 +#else +#define HWY_HAVE_CXX20_THREE_WAY_COMPARE 0 +#endif +#endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE + +// IWYU pragma: end_exports + +#if HWY_COMPILER_MSVC +#include // memcpy +#endif + +//------------------------------------------------------------------------------ +// Compiler-specific definitions + +#define HWY_STR_IMPL(macro) #macro +#define HWY_STR(macro) HWY_STR_IMPL(macro) + +#if HWY_COMPILER_MSVC + +#include + +#define HWY_FUNCTION __FUNCSIG__ // function name + template args +#define HWY_RESTRICT __restrict +#define HWY_INLINE __forceinline +#define HWY_NOINLINE __declspec(noinline) +#define HWY_FLATTEN +#define HWY_NORETURN __declspec(noreturn) +#define HWY_LIKELY(expr) (expr) +#define HWY_UNLIKELY(expr) (expr) +#define HWY_UNREACHABLE __assume(false) +#define HWY_PRAGMA(tokens) __pragma(tokens) +#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(warning(tokens)) +#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(msc) +#define HWY_MAYBE_UNUSED +#define HWY_HAS_ASSUME_ALIGNED 0 +#if (_MSC_VER >= 1700) +#define HWY_MUST_USE_RESULT _Check_return_ +#else +#define HWY_MUST_USE_RESULT +#endif + +#else + +#define HWY_FUNCTION __PRETTY_FUNCTION__ // function name + template args +#define HWY_RESTRICT __restrict__ +// force inlining without optimization enabled creates very inefficient code +// that can cause compiler timeout +#ifdef __OPTIMIZE__ +#define HWY_INLINE inline __attribute__((always_inline)) +#else +#define HWY_INLINE inline +#endif +#define HWY_NOINLINE __attribute__((noinline)) +#define HWY_FLATTEN __attribute__((flatten)) +#define HWY_NORETURN __attribute__((noreturn)) +#define HWY_LIKELY(expr) __builtin_expect(!!(expr), 1) +#define HWY_UNLIKELY(expr) __builtin_expect(!!(expr), 0) +#if HWY_COMPILER_GCC || HWY_HAS_BUILTIN(__builtin_unreachable) +#define HWY_UNREACHABLE __builtin_unreachable() +#else +#define HWY_UNREACHABLE +#endif +#define HWY_PRAGMA(tokens) _Pragma(#tokens) +#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(GCC diagnostic tokens) +#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(gcc) +// Encountered "attribute list cannot appear here" when using the C++17 +// [[maybe_unused]], so only use the old style attribute for now. +#define HWY_MAYBE_UNUSED __attribute__((unused)) +#define HWY_MUST_USE_RESULT __attribute__((warn_unused_result)) + +#endif // !HWY_COMPILER_MSVC + +//------------------------------------------------------------------------------ +// Builtin/attributes (no more #include after this point due to namespace!) + +namespace hwy { + +// Enables error-checking of format strings. 
+#if HWY_HAS_ATTRIBUTE(__format__) +#define HWY_FORMAT(idx_fmt, idx_arg) \ + __attribute__((__format__(__printf__, idx_fmt, idx_arg))) +#else +#define HWY_FORMAT(idx_fmt, idx_arg) +#endif + +// Returns a void* pointer which the compiler then assumes is N-byte aligned. +// Example: float* HWY_RESTRICT aligned = (float*)HWY_ASSUME_ALIGNED(in, 32); +// +// The assignment semantics are required by GCC/Clang. ICC provides an in-place +// __assume_aligned, whereas MSVC's __assume appears unsuitable. +#if HWY_HAS_BUILTIN(__builtin_assume_aligned) +#define HWY_ASSUME_ALIGNED(ptr, align) __builtin_assume_aligned((ptr), (align)) +#else +#define HWY_ASSUME_ALIGNED(ptr, align) (ptr) /* not supported */ +#endif + +// Returns a pointer whose type is `type` (T*), while allowing the compiler to +// assume that the untyped pointer `ptr` is aligned to a multiple of sizeof(T). +#define HWY_RCAST_ALIGNED(type, ptr) \ + reinterpret_cast( \ + HWY_ASSUME_ALIGNED((ptr), alignof(hwy::RemovePtr))) + +// Clang and GCC require attributes on each function into which SIMD intrinsics +// are inlined. Support both per-function annotation (HWY_ATTR) for lambdas and +// automatic annotation via pragmas. +#if HWY_COMPILER_ICC +// As of ICC 2021.{1-9} the pragma is neither implemented nor required. +#define HWY_PUSH_ATTRIBUTES(targets_str) +#define HWY_POP_ATTRIBUTES +#elif HWY_COMPILER_CLANG +#define HWY_PUSH_ATTRIBUTES(targets_str) \ + HWY_PRAGMA(clang attribute push(__attribute__((target(targets_str))), \ + apply_to = function)) +#define HWY_POP_ATTRIBUTES HWY_PRAGMA(clang attribute pop) +#elif HWY_COMPILER_GCC_ACTUAL +#define HWY_PUSH_ATTRIBUTES(targets_str) \ + HWY_PRAGMA(GCC push_options) HWY_PRAGMA(GCC target targets_str) +#define HWY_POP_ATTRIBUTES HWY_PRAGMA(GCC pop_options) +#else +#define HWY_PUSH_ATTRIBUTES(targets_str) +#define HWY_POP_ATTRIBUTES +#endif + +//------------------------------------------------------------------------------ +// Macros + +#define HWY_API static HWY_INLINE HWY_FLATTEN HWY_MAYBE_UNUSED + +#define HWY_CONCAT_IMPL(a, b) a##b +#define HWY_CONCAT(a, b) HWY_CONCAT_IMPL(a, b) + +#define HWY_MIN(a, b) ((a) < (b) ? (a) : (b)) +#define HWY_MAX(a, b) ((a) > (b) ? (a) : (b)) + +#if HWY_COMPILER_GCC_ACTUAL +// nielskm: GCC does not support '#pragma GCC unroll' without the factor. +#define HWY_UNROLL(factor) HWY_PRAGMA(GCC unroll factor) +#define HWY_DEFAULT_UNROLL HWY_UNROLL(4) +#elif HWY_COMPILER_CLANG || HWY_COMPILER_ICC || HWY_COMPILER_ICX +#define HWY_UNROLL(factor) HWY_PRAGMA(unroll factor) +#define HWY_DEFAULT_UNROLL HWY_UNROLL() +#else +#define HWY_UNROLL(factor) +#define HWY_DEFAULT_UNROLL +#endif + +// Tell a compiler that the expression always evaluates to true. +// The expression should be free from any side effects. +// Some older compilers may have trouble with complex expressions, therefore +// it is advisable to split multiple conditions into separate assume statements, +// and manually check the generated code. +// OK but could fail: +// HWY_ASSUME(x == 2 && y == 3); +// Better: +// HWY_ASSUME(x == 2); +// HWY_ASSUME(y == 3); +#if (HWY_CXX_LANG >= 202302L) && HWY_HAS_CPP_ATTRIBUTE(assume) +#define HWY_ASSUME(expr) [[assume(expr)]] +#elif HWY_COMPILER_MSVC || HWY_COMPILER_ICC +#define HWY_ASSUME(expr) __assume(expr) +// __builtin_assume() was added in clang 3.6. 
+#elif HWY_COMPILER_CLANG && HWY_HAS_BUILTIN(__builtin_assume) +#define HWY_ASSUME(expr) __builtin_assume(expr) +// __builtin_unreachable() was added in GCC 4.5, but __has_builtin() was added +// later, so check for the compiler version directly. +#elif HWY_COMPILER_GCC_ACTUAL >= 405 +#define HWY_ASSUME(expr) \ + ((expr) ? static_cast(0) : __builtin_unreachable()) +#else +#define HWY_ASSUME(expr) static_cast(0) +#endif + +// Compile-time fence to prevent undesirable code reordering. On Clang x86, the +// typical asm volatile("" : : : "memory") has no effect, whereas atomic fence +// does, without generating code. +#if HWY_ARCH_X86 && !defined(HWY_NO_LIBCXX) +#define HWY_FENCE std::atomic_thread_fence(std::memory_order_acq_rel) +#else +// TODO(janwas): investigate alternatives. On Arm, the above generates barriers. +#define HWY_FENCE +#endif + +// 4 instances of a given literal value, useful as input to LoadDup128. +#define HWY_REP4(literal) literal, literal, literal, literal + +//------------------------------------------------------------------------------ +// Abort / Warn + +#if defined(HWY_HEADER_ONLY) +HWY_DLLEXPORT inline void HWY_FORMAT(3, 4) + Warn(const char* file, int line, const char* format, ...) { + char buf[800]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + fprintf(stderr, "Warn at %s:%d: %s\n", file, line, buf); +} + +HWY_DLLEXPORT HWY_NORETURN inline void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...) { + char buf[800]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + fprintf(stderr, "Abort at %s:%d: %s\n", file, line, buf); + + fflush(stderr); + +// Now terminate the program: +#if HWY_ARCH_RISCV + exit(1); // trap/abort just freeze Spike. +#else + abort(); // Compile error without this due to HWY_NORETURN. +#endif +} +#else // !HWY_HEADER_ONLY +// Interfaces for custom Warn/Abort handlers. +typedef void (*WarnFunc)(const char* file, int line, const char* message); + +typedef void (*AbortFunc)(const char* file, int line, const char* message); + +// Returns current Warn() handler, or nullptr if no handler was yet registered, +// indicating Highway should print to stderr. +// DEPRECATED because this is thread-hostile and prone to misuse (modifying the +// underlying pointer through the reference). +HWY_DLLEXPORT WarnFunc& GetWarnFunc(); + +// Returns current Abort() handler, or nullptr if no handler was yet registered, +// indicating Highway should print to stderr and abort. +// DEPRECATED because this is thread-hostile and prone to misuse (modifying the +// underlying pointer through the reference). +HWY_DLLEXPORT AbortFunc& GetAbortFunc(); + +// Sets a new Warn() handler and returns the previous handler, which is nullptr +// if no previous handler was registered, and should otherwise be called from +// the new handler. Thread-safe. +HWY_DLLEXPORT WarnFunc SetWarnFunc(WarnFunc func); + +// Sets a new Abort() handler and returns the previous handler, which is nullptr +// if no previous handler was registered, and should otherwise be called from +// the new handler. If all handlers return, then Highway will terminate the app. +// Thread-safe. 
+HWY_DLLEXPORT AbortFunc SetAbortFunc(AbortFunc func); + +HWY_DLLEXPORT void HWY_FORMAT(3, 4) + Warn(const char* file, int line, const char* format, ...); + +HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...); + +#endif // HWY_HEADER_ONLY + +#define HWY_WARN(format, ...) \ + ::hwy::Warn(__FILE__, __LINE__, format, ##__VA_ARGS__) + +#define HWY_ABORT(format, ...) \ + ::hwy::Abort(__FILE__, __LINE__, format, ##__VA_ARGS__) + +// Always enabled. +#define HWY_ASSERT_M(condition, msg) \ + do { \ + if (!(condition)) { \ + HWY_ABORT("Assert %s: %s", #condition, msg); \ + } \ + } while (0) +#define HWY_ASSERT(condition) HWY_ASSERT_M(condition, "") + +#if HWY_HAS_FEATURE(memory_sanitizer) || defined(MEMORY_SANITIZER) || \ + defined(__SANITIZE_MEMORY__) +#define HWY_IS_MSAN 1 +#else +#define HWY_IS_MSAN 0 +#endif + +#if HWY_HAS_FEATURE(address_sanitizer) || defined(ADDRESS_SANITIZER) || \ + defined(__SANITIZE_ADDRESS__) +#define HWY_IS_ASAN 1 +#else +#define HWY_IS_ASAN 0 +#endif + +#if HWY_HAS_FEATURE(hwaddress_sanitizer) || defined(HWADDRESS_SANITIZER) || \ + defined(__SANITIZE_HWADDRESS__) +#define HWY_IS_HWASAN 1 +#else +#define HWY_IS_HWASAN 0 +#endif + +#if HWY_HAS_FEATURE(thread_sanitizer) || defined(THREAD_SANITIZER) || \ + defined(__SANITIZE_THREAD__) +#define HWY_IS_TSAN 1 +#else +#define HWY_IS_TSAN 0 +#endif + +#if HWY_HAS_FEATURE(undefined_behavior_sanitizer) || \ + defined(UNDEFINED_BEHAVIOR_SANITIZER) +#define HWY_IS_UBSAN 1 +#else +#define HWY_IS_UBSAN 0 +#endif + +// MSAN may cause lengthy build times or false positives e.g. in AVX3 DemoteTo. +// You can disable MSAN by adding this attribute to the function that fails. +#if HWY_IS_MSAN +#define HWY_ATTR_NO_MSAN __attribute__((no_sanitize_memory)) +#else +#define HWY_ATTR_NO_MSAN +#endif + +#if HWY_IS_ASAN || HWY_IS_HWASAN || HWY_IS_MSAN || HWY_IS_TSAN || HWY_IS_UBSAN +#define HWY_IS_SANITIZER 1 +#else +#define HWY_IS_SANITIZER 0 +#endif + +// For enabling HWY_DASSERT and shortening tests in slower debug builds +#if !defined(HWY_IS_DEBUG_BUILD) +// Clang does not define NDEBUG, but it and GCC define __OPTIMIZE__, and recent +// MSVC defines NDEBUG (if not, could instead check _DEBUG). +#if (!defined(__OPTIMIZE__) && !defined(NDEBUG)) || HWY_IS_SANITIZER || \ + defined(__clang_analyzer__) +#define HWY_IS_DEBUG_BUILD 1 +#else +#define HWY_IS_DEBUG_BUILD 0 +#endif +#endif // HWY_IS_DEBUG_BUILD + +#if HWY_IS_DEBUG_BUILD +#define HWY_DASSERT_M(condition, msg) HWY_ASSERT_M(condition, msg) +#define HWY_DASSERT(condition) HWY_ASSERT_M(condition, "") +#else +#define HWY_DASSERT_M(condition, msg) \ + do { \ + } while (0) +#define HWY_DASSERT(condition) \ + do { \ + } while (0) +#endif + +//------------------------------------------------------------------------------ +// CopyBytes / ZeroBytes + +#if HWY_COMPILER_MSVC +#pragma intrinsic(memcpy) +#pragma intrinsic(memset) +#endif + +template +HWY_API void CopyBytes(const From* HWY_RESTRICT from, To* HWY_RESTRICT to) { +#if HWY_COMPILER_MSVC + memcpy(to, from, kBytes); +#else + __builtin_memcpy(to, from, kBytes); +#endif +} + +HWY_API void CopyBytes(const void* HWY_RESTRICT from, void* HWY_RESTRICT to, + size_t num_of_bytes_to_copy) { +#if HWY_COMPILER_MSVC + memcpy(to, from, num_of_bytes_to_copy); +#else + __builtin_memcpy(to, from, num_of_bytes_to_copy); +#endif +} + +// Same as CopyBytes, but for same-sized objects; avoids a size argument. 
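A minimal sketch of the handler registration and assertion macros declared above, assuming the non-header-only build shown in this hunk; the handler body is illustrative, and only SetAbortFunc, HWY_WARN and HWY_ASSERT come from the header.

#include <cstdio>
#include <cstdlib>
#include "third_party/highway/hwy/base.h"

// Matches the AbortFunc signature: (file, line, message).
void MyAbortHandler(const char* file, int line, const char* message) {
  std::fprintf(stderr, "fatal %s:%d: %s\n", file, line, message);
  std::abort();  // if all handlers return, Highway terminates the app itself
}

int main() {
  // Returns the previous handler (nullptr if none was registered).
  hwy::AbortFunc prev = hwy::SetAbortFunc(&MyAbortHandler);
  (void)prev;
  HWY_WARN("tuning used %d rounds", 12);  // printf-style, format-checked
  HWY_ASSERT(1 + 1 == 2);  // always enabled; HWY_DASSERT only in debug builds
  return 0;
}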
+template +HWY_API void CopySameSize(const From* HWY_RESTRICT from, To* HWY_RESTRICT to) { + static_assert(sizeof(From) == sizeof(To), ""); + CopyBytes(from, to); +} + +template +HWY_API void ZeroBytes(To* to) { +#if HWY_COMPILER_MSVC + memset(to, 0, kBytes); +#else + __builtin_memset(to, 0, kBytes); +#endif +} + +HWY_API void ZeroBytes(void* to, size_t num_bytes) { +#if HWY_COMPILER_MSVC + memset(to, 0, num_bytes); +#else + __builtin_memset(to, 0, num_bytes); +#endif +} + +//------------------------------------------------------------------------------ +// kMaxVectorSize (undocumented, pending removal) + +#if HWY_ARCH_X86 +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 64; // AVX-512 +#elif HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \ + __riscv_v_intrinsic >= 11000 +// Not actually an upper bound on the size. +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 4096; +#else +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 16; +#endif + +//------------------------------------------------------------------------------ +// Alignment + +// Potentially useful for LoadDup128 and capped vectors. In other cases, arrays +// should be allocated dynamically via aligned_allocator.h because Lanes() may +// exceed the stack size. +#if HWY_ARCH_X86 +#define HWY_ALIGN_MAX alignas(64) +#elif HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \ + __riscv_v_intrinsic >= 11000 +#define HWY_ALIGN_MAX alignas(8) // only elements need be aligned +#else +#define HWY_ALIGN_MAX alignas(16) +#endif + +//------------------------------------------------------------------------------ +// Lane types + +// hwy::float16_t and hwy::bfloat16_t are forward declared here to allow +// BitCastScalar to be implemented before the implementations of the +// hwy::float16_t and hwy::bfloat16_t types +struct float16_t; +struct bfloat16_t; + +using float32_t = float; +using float64_t = double; + +#pragma pack(push, 1) + +// Aligned 128-bit type. Cannot use __int128 because clang doesn't yet align it: +// https://reviews.llvm.org/D86310 +struct alignas(16) uint128_t { + uint64_t lo; // little-endian layout + uint64_t hi; +}; + +// 64 bit key plus 64 bit value. Faster than using uint128_t when only the key +// field is to be compared (Lt128Upper instead of Lt128). +struct alignas(16) K64V64 { + uint64_t value; // little-endian layout + uint64_t key; +}; + +// 32 bit key plus 32 bit value. Allows vqsort recursions to terminate earlier +// than when considering both to be a 64-bit key. +struct alignas(8) K32V32 { + uint32_t value; // little-endian layout + uint32_t key; +}; + +#pragma pack(pop) + +static inline HWY_MAYBE_UNUSED bool operator<(const uint128_t& a, + const uint128_t& b) { + return (a.hi == b.hi) ? a.lo < b.lo : a.hi < b.hi; +} +// Required for std::greater. +static inline HWY_MAYBE_UNUSED bool operator>(const uint128_t& a, + const uint128_t& b) { + return b < a; +} +static inline HWY_MAYBE_UNUSED bool operator==(const uint128_t& a, + const uint128_t& b) { + return a.lo == b.lo && a.hi == b.hi; +} + +#if !defined(HWY_NO_LIBCXX) +static inline HWY_MAYBE_UNUSED std::ostream& operator<<(std::ostream& os, + const uint128_t& n) { + return os << "[hi=" << n.hi << ",lo=" << n.lo << "]"; +} +#endif + +static inline HWY_MAYBE_UNUSED bool operator<(const K64V64& a, + const K64V64& b) { + return a.key < b.key; +} +// Required for std::greater. 
+static inline HWY_MAYBE_UNUSED bool operator>(const K64V64& a, + const K64V64& b) { + return b < a; +} +static inline HWY_MAYBE_UNUSED bool operator==(const K64V64& a, + const K64V64& b) { + return a.key == b.key; +} + +#if !defined(HWY_NO_LIBCXX) +static inline HWY_MAYBE_UNUSED std::ostream& operator<<(std::ostream& os, + const K64V64& n) { + return os << "[k=" << n.key << ",v=" << n.value << "]"; +} +#endif + +static inline HWY_MAYBE_UNUSED bool operator<(const K32V32& a, + const K32V32& b) { + return a.key < b.key; +} +// Required for std::greater. +static inline HWY_MAYBE_UNUSED bool operator>(const K32V32& a, + const K32V32& b) { + return b < a; +} +static inline HWY_MAYBE_UNUSED bool operator==(const K32V32& a, + const K32V32& b) { + return a.key == b.key; +} + +#if !defined(HWY_NO_LIBCXX) +static inline HWY_MAYBE_UNUSED std::ostream& operator<<(std::ostream& os, + const K32V32& n) { + return os << "[k=" << n.key << ",v=" << n.value << "]"; +} +#endif + +//------------------------------------------------------------------------------ +// Controlling overload resolution (SFINAE) + +template +struct EnableIfT {}; +template <> +struct EnableIfT { + using type = void; +}; + +template +using EnableIf = typename EnableIfT::type; + +template +struct IsSameT { + enum { value = 0 }; +}; + +template +struct IsSameT { + enum { value = 1 }; +}; + +template +HWY_API constexpr bool IsSame() { + return IsSameT::value; +} + +// Returns whether T matches either of U1 or U2 +template +HWY_API constexpr bool IsSameEither() { + return IsSameT::value || IsSameT::value; +} + +template +struct IfT { + using type = Then; +}; + +template +struct IfT { + using type = Else; +}; + +template +using If = typename IfT::type; + +template +struct IsConstT { + enum { value = 0 }; +}; + +template +struct IsConstT { + enum { value = 1 }; +}; + +template +HWY_API constexpr bool IsConst() { + return IsConstT::value; +} + +template +struct RemoveConstT { + using type = T; +}; +template +struct RemoveConstT { + using type = T; +}; + +template +using RemoveConst = typename RemoveConstT::type; + +template +struct RemoveVolatileT { + using type = T; +}; +template +struct RemoveVolatileT { + using type = T; +}; + +template +using RemoveVolatile = typename RemoveVolatileT::type; + +template +struct RemoveRefT { + using type = T; +}; +template +struct RemoveRefT { + using type = T; +}; +template +struct RemoveRefT { + using type = T; +}; + +template +using RemoveRef = typename RemoveRefT::type; + +template +using RemoveCvRef = RemoveConst>>; + +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; + +template +using RemovePtr = typename RemovePtrT::type; + +// Insert into template/function arguments to enable this overload only for +// vectors of exactly, at most (LE), or more than (GT) this many bytes. +// +// As an example, checking for a total size of 16 bytes will match both +// Simd and Simd. 
+#define HWY_IF_V_SIZE(T, kN, bytes) \ + hwy::EnableIf* = nullptr +#define HWY_IF_V_SIZE_LE(T, kN, bytes) \ + hwy::EnableIf* = nullptr +#define HWY_IF_V_SIZE_GT(T, kN, bytes) \ + hwy::EnableIf<(kN * sizeof(T) > bytes)>* = nullptr + +#define HWY_IF_LANES(kN, lanes) hwy::EnableIf<(kN == lanes)>* = nullptr +#define HWY_IF_LANES_LE(kN, lanes) hwy::EnableIf<(kN <= lanes)>* = nullptr +#define HWY_IF_LANES_GT(kN, lanes) hwy::EnableIf<(kN > lanes)>* = nullptr + +#define HWY_IF_UNSIGNED(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_UNSIGNED(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_SIGNED(T) \ + hwy::EnableIf() && !hwy::IsFloat() && \ + !hwy::IsSpecialFloat()>* = nullptr +#define HWY_IF_FLOAT(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_FLOAT(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_FLOAT3264(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_FLOAT3264(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_SPECIAL_FLOAT(T) \ + hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_SPECIAL_FLOAT(T) \ + hwy::EnableIf()>* = nullptr +#define HWY_IF_FLOAT_OR_SPECIAL(T) \ + hwy::EnableIf() || hwy::IsSpecialFloat()>* = nullptr +#define HWY_IF_NOT_FLOAT_NOR_SPECIAL(T) \ + hwy::EnableIf() && !hwy::IsSpecialFloat()>* = nullptr +#define HWY_IF_INTEGER(T) hwy::EnableIf()>* = nullptr + +#define HWY_IF_T_SIZE(T, bytes) hwy::EnableIf* = nullptr +#define HWY_IF_NOT_T_SIZE(T, bytes) \ + hwy::EnableIf* = nullptr +// bit_array = 0x102 means 1 or 8 bytes. There is no NONE_OF because it sounds +// too similar. If you want the opposite of this (2 or 4 bytes), ask for those +// bits explicitly (0x14) instead of attempting to 'negate' 0x102. +#define HWY_IF_T_SIZE_ONE_OF(T, bit_array) \ + hwy::EnableIf<((size_t{1} << sizeof(T)) & (bit_array)) != 0>* = nullptr +#define HWY_IF_T_SIZE_LE(T, bytes) \ + hwy::EnableIf<(sizeof(T) <= (bytes))>* = nullptr +#define HWY_IF_T_SIZE_GT(T, bytes) \ + hwy::EnableIf<(sizeof(T) > (bytes))>* = nullptr + +#define HWY_IF_SAME(T, expected) \ + hwy::EnableIf, expected>()>* = nullptr +#define HWY_IF_NOT_SAME(T, expected) \ + hwy::EnableIf, expected>()>* = nullptr + +// One of two expected types +#define HWY_IF_SAME2(T, expected1, expected2) \ + hwy::EnableIf< \ + hwy::IsSameEither, expected1, expected2>()>* = \ + nullptr + +#define HWY_IF_U8(T) HWY_IF_SAME(T, uint8_t) +#define HWY_IF_U16(T) HWY_IF_SAME(T, uint16_t) +#define HWY_IF_U32(T) HWY_IF_SAME(T, uint32_t) +#define HWY_IF_U64(T) HWY_IF_SAME(T, uint64_t) + +#define HWY_IF_I8(T) HWY_IF_SAME(T, int8_t) +#define HWY_IF_I16(T) HWY_IF_SAME(T, int16_t) +#define HWY_IF_I32(T) HWY_IF_SAME(T, int32_t) +#define HWY_IF_I64(T) HWY_IF_SAME(T, int64_t) + +#define HWY_IF_BF16(T) HWY_IF_SAME(T, hwy::bfloat16_t) +#define HWY_IF_NOT_BF16(T) HWY_IF_NOT_SAME(T, hwy::bfloat16_t) + +#define HWY_IF_F16(T) HWY_IF_SAME(T, hwy::float16_t) +#define HWY_IF_NOT_F16(T) HWY_IF_NOT_SAME(T, hwy::float16_t) + +#define HWY_IF_F32(T) HWY_IF_SAME(T, float) +#define HWY_IF_F64(T) HWY_IF_SAME(T, double) + +// Use instead of HWY_IF_T_SIZE to avoid ambiguity with float16_t/float/double +// overloads. +#define HWY_IF_UI8(T) HWY_IF_SAME2(T, uint8_t, int8_t) +#define HWY_IF_UI16(T) HWY_IF_SAME2(T, uint16_t, int16_t) +#define HWY_IF_UI32(T) HWY_IF_SAME2(T, uint32_t, int32_t) +#define HWY_IF_UI64(T) HWY_IF_SAME2(T, uint64_t, int64_t) + +#define HWY_IF_LANES_PER_BLOCK(T, N, LANES) \ + hwy::EnableIf* = nullptr + +// Empty struct used as a size tag type. 
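A short sketch of how the HWY_IF_* constraints above are used to select overloads in user code; ToBits is a hypothetical helper, not from the patch.

#include <cstdint>
#include "third_party/highway/hwy/base.h"

// Enabled only for 4-byte unsigned integer lane types (i.e. uint32_t).
template <typename T, HWY_IF_UNSIGNED(T), HWY_IF_T_SIZE(T, 4)>
uint32_t ToBits(T v) {
  return static_cast<uint32_t>(v);
}

// Enabled only for float; avoids ambiguity with the integer overload.
template <typename T, HWY_IF_F32(T)>
uint32_t ToBits(T v) {
  return hwy::BitCastScalar<uint32_t>(v);
}

// ToBits(uint32_t{5}) resolves to the first overload, ToBits(1.5f) to the second.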
+template +struct SizeTag {}; + +template +class DeclValT { + private: + template + static URef TryAddRValRef(int); + template + static U TryAddRValRef(Arg); + + public: + using type = decltype(TryAddRValRef(0)); + enum { kDisableDeclValEvaluation = 1 }; +}; + +// hwy::DeclVal() can only be used in unevaluated contexts such as within an +// expression of a decltype specifier. + +// hwy::DeclVal() does not require that T have a public default constructor +template +HWY_API typename DeclValT::type DeclVal() noexcept { + static_assert(!DeclValT::kDisableDeclValEvaluation, + "DeclVal() cannot be used in an evaluated context"); +} + +template +struct IsArrayT { + enum { value = 0 }; +}; + +template +struct IsArrayT { + enum { value = 1 }; +}; + +template +struct IsArrayT { + enum { value = 1 }; +}; + +template +static constexpr bool IsArray() { + return IsArrayT::value; +} + +#if HWY_COMPILER_MSVC +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4180, ignored "-Wignored-qualifiers") +#endif + +template +class IsConvertibleT { + private: + template + static hwy::SizeTag<1> TestFuncWithToArg(T); + + template + static decltype(IsConvertibleT::template TestFuncWithToArg( + DeclVal())) + TryConvTest(int); + + template + static hwy::SizeTag<0> TryConvTest(Arg); + + public: + enum { + value = (IsSame>, void>() && + IsSame>, void>()) || + (!IsArray() && + (IsSame())>() || + !IsSame, RemoveConst>()) && + IsSame(0)), hwy::SizeTag<1>>()) + }; +}; + +#if HWY_COMPILER_MSVC +HWY_DIAGNOSTICS(pop) +#endif + +template +HWY_API constexpr bool IsConvertible() { + return IsConvertibleT::value; +} + +template +class IsStaticCastableT { + private: + template (DeclVal()))> + static hwy::SizeTag<1> TryStaticCastTest(int); + + template + static hwy::SizeTag<0> TryStaticCastTest(Arg); + + public: + enum { + value = IsSame(0)), hwy::SizeTag<1>>() + }; +}; + +template +static constexpr bool IsStaticCastable() { + return IsStaticCastableT::value; +} + +#define HWY_IF_CASTABLE(From, To) \ + hwy::EnableIf()>* = nullptr + +#define HWY_IF_OP_CASTABLE(op, T, Native) \ + HWY_IF_CASTABLE(decltype(DeclVal() op DeclVal()), Native) + +template +class IsAssignableT { + private: + template () = DeclVal())> + static hwy::SizeTag<1> TryAssignTest(int); + + template + static hwy::SizeTag<0> TryAssignTest(Arg); + + public: + enum { + value = IsSame(0)), hwy::SizeTag<1>>() + }; +}; + +template +static constexpr bool IsAssignable() { + return IsAssignableT::value; +} + +#define HWY_IF_ASSIGNABLE(T, From) \ + hwy::EnableIf()>* = nullptr + +// ---------------------------------------------------------------------------- +// IsSpecialFloat + +// These types are often special-cased and not supported in all ops. 
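A few compile-time checks illustrating the detection traits above; these are user-side, illustrative assertions based on the std-like semantics the traits aim for.

#include "third_party/highway/hwy/base.h"

static_assert(hwy::IsConvertible<int, double>(), "implicit int -> double");
static_assert(!hwy::IsConvertible<void*, int>(), "no implicit void* -> int");
static_assert(hwy::IsStaticCastable<double, int>(), "explicit cast is allowed");
static_assert(hwy::IsAssignable<int&, short>(), "int lvalue accepts short");
static_assert(hwy::IsSame<hwy::RemoveCvRef<const int&>, int>(), "cv/ref removed");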
+template +HWY_API constexpr bool IsSpecialFloat() { + return IsSameEither, hwy::float16_t, hwy::bfloat16_t>(); +} + +// ----------------------------------------------------------------------------- +// IsIntegerLaneType and IsInteger + +template +HWY_API constexpr bool IsIntegerLaneType() { + return false; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} + +namespace detail { + +template +static HWY_INLINE constexpr bool IsNonCvInteger() { + // NOTE: Do not add a IsNonCvInteger() specialization below as it is + // possible for IsSame() to be true when compiled with MSVC + // with the /Zc:wchar_t- option. + return IsIntegerLaneType() || IsSame() || + IsSameEither() || + IsSameEither(); +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +#if defined(__cpp_char8_t) && __cpp_char8_t >= 201811L +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +#endif +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} + +} // namespace detail + +template +HWY_API constexpr bool IsInteger() { + return detail::IsNonCvInteger>(); +} + +// ----------------------------------------------------------------------------- +// BitCastScalar + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +#define HWY_BITCASTSCALAR_CONSTEXPR constexpr +#else +#define HWY_BITCASTSCALAR_CONSTEXPR +#endif + +#if __cpp_constexpr >= 201304L +#define HWY_BITCASTSCALAR_CXX14_CONSTEXPR HWY_BITCASTSCALAR_CONSTEXPR +#else +#define HWY_BITCASTSCALAR_CXX14_CONSTEXPR +#endif + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +namespace detail { + +template +struct BitCastScalarSrcCastHelper { + static HWY_INLINE constexpr const From& CastSrcValRef(const From& val) { + return val; + } +}; + +#if HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 +// Workaround for Clang 9 constexpr __builtin_bit_cast bug +template >() && + hwy::IsInteger>()>* = nullptr> +static HWY_INLINE 
HWY_BITCASTSCALAR_CONSTEXPR To +BuiltinBitCastScalar(const From& val) { + static_assert(sizeof(To) == sizeof(From), + "sizeof(To) == sizeof(From) must be true"); + return static_cast(val); +} + +template >() && + hwy::IsInteger>())>* = nullptr> +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR To +BuiltinBitCastScalar(const From& val) { + return __builtin_bit_cast(To, val); +} +#endif // HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { + // If From is hwy::float16_t or hwy::bfloat16_t, first cast val to either + // const typename From::Native& or const uint16_t& using + // detail::BitCastScalarSrcCastHelper>::CastSrcValRef to + // allow BitCastScalar from hwy::float16_t or hwy::bfloat16_t to be constexpr + // if To is not a pointer type, union type, or a struct/class containing a + // pointer, union, or reference subobject +#if HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 + return detail::BuiltinBitCastScalar( + detail::BitCastScalarSrcCastHelper>::CastSrcValRef( + val)); +#else + return __builtin_bit_cast( + To, detail::BitCastScalarSrcCastHelper>::CastSrcValRef( + val)); +#endif +} +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { + // If To is hwy::float16_t or hwy::bfloat16_t, first do a BitCastScalar of val + // to uint16_t, and then bit cast the uint16_t value to To using To::FromBits + // as hwy::float16_t::FromBits and hwy::bfloat16_t::FromBits are guaranteed to + // be constexpr if the __builtin_bit_cast intrinsic is available. + return To::FromBits(BitCastScalar(val)); +} +#else +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { + To result; + CopySameSize(&val, &result); + return result; +} +#endif + +//------------------------------------------------------------------------------ +// F16 lane type + +#pragma pack(push, 1) + +// Compiler supports __fp16 and load/store/conversion NEON intrinsics, which are +// included in Armv8 and VFPv4 (except with MSVC). On Armv7 Clang requires +// __ARM_FP & 2 whereas Armv7 GCC requires -mfp16-format=ieee. +#if (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC) || \ + (HWY_COMPILER_CLANG && defined(__ARM_FP) && (__ARM_FP & 2)) || \ + (HWY_COMPILER_GCC_ACTUAL && defined(__ARM_FP16_FORMAT_IEEE)) +#define HWY_NEON_HAVE_F16C 1 +#else +#define HWY_NEON_HAVE_F16C 0 +#endif + +// RVV with f16 extension supports _Float16 and f16 vector ops. If set, implies +// HWY_HAVE_FLOAT16. +#if HWY_ARCH_RISCV && defined(__riscv_zvfh) && HWY_COMPILER_CLANG >= 1600 +#define HWY_RVV_HAVE_F16_VEC 1 +#else +#define HWY_RVV_HAVE_F16_VEC 0 +#endif + +// x86 compiler supports _Float16, not necessarily with operators. +// Avoid clang-cl because it lacks __extendhfsf2. +#if HWY_ARCH_X86 && defined(__SSE2__) && defined(__FLT16_MAX__) && \ + ((HWY_COMPILER_CLANG >= 1500 && !HWY_COMPILER_CLANGCL) || \ + HWY_COMPILER_GCC_ACTUAL >= 1200) +#define HWY_SSE2_HAVE_F16_TYPE 1 +#else +#define HWY_SSE2_HAVE_F16_TYPE 0 +#endif + +#ifndef HWY_HAVE_SCALAR_F16_TYPE +// Compiler supports _Float16, not necessarily with operators. +#if HWY_NEON_HAVE_F16C || HWY_RVV_HAVE_F16_VEC || HWY_SSE2_HAVE_F16_TYPE || \ + __SPIRV_DEVICE__ +#define HWY_HAVE_SCALAR_F16_TYPE 1 +#else +#define HWY_HAVE_SCALAR_F16_TYPE 0 +#endif +#endif // HWY_HAVE_SCALAR_F16_TYPE + +#ifndef HWY_HAVE_SCALAR_F16_OPERATORS +// Recent enough compiler also has operators. 
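BitCastScalar above is the scalar analogue of a vector BitCast; a minimal user-side sketch follows (the expected bit pattern is a standard IEEE-754 fact, not taken from the patch).

#include <cstdint>
#include "third_party/highway/hwy/base.h"

int main() {
  // Reinterpret the bits of a float as uint32_t without UB (uses
  // __builtin_bit_cast when available, CopySameSize otherwise).
  const uint32_t bits = hwy::BitCastScalar<uint32_t>(1.0f);
  // IEEE-754 single precision: 1.0f is 0x3F800000.
  return bits == 0x3F800000u ? 0 : 1;
}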
+#if HWY_HAVE_SCALAR_F16_TYPE && \ + (HWY_COMPILER_CLANG >= 1800 || HWY_COMPILER_GCC_ACTUAL >= 1200 || \ + (HWY_COMPILER_CLANG >= 1500 && !HWY_COMPILER_CLANGCL && \ + !defined(_WIN32)) || \ + (HWY_ARCH_ARM && \ + (HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800))) +#define HWY_HAVE_SCALAR_F16_OPERATORS 1 +#else +#define HWY_HAVE_SCALAR_F16_OPERATORS 0 +#endif +#endif // HWY_HAVE_SCALAR_F16_OPERATORS + +namespace detail { + +template , bool = IsSpecialFloat()> +struct SpecialFloatUnwrapArithOpOperandT {}; + +template +struct SpecialFloatUnwrapArithOpOperandT { + using type = T; +}; + +template +using SpecialFloatUnwrapArithOpOperand = + typename SpecialFloatUnwrapArithOpOperandT::type; + +template > +struct NativeSpecialFloatToWrapperT { + using type = T; +}; + +template +using NativeSpecialFloatToWrapper = + typename NativeSpecialFloatToWrapperT::type; + +} // namespace detail + +// Match [u]int##_t naming scheme so rvv-inl.h macros can obtain the type name +// by concatenating base type and bits. We use a wrapper class instead of a +// typedef to the native type to ensure that the same symbols, e.g. for VQSort, +// are generated regardless of F16 support; see #1684. +struct alignas(2) float16_t { +#if HWY_HAVE_SCALAR_F16_TYPE +#if HWY_RVV_HAVE_F16_VEC || HWY_SSE2_HAVE_F16_TYPE || __SPIRV_DEVICE__ + using Native = _Float16; +#elif HWY_NEON_HAVE_F16C + using Native = __fp16; +#else +#error "Logic error: condition should be 'all but NEON_HAVE_F16C'" +#endif +#elif HWY_IDE + using Native = uint16_t; +#endif // HWY_HAVE_SCALAR_F16_TYPE + + union { +#if HWY_HAVE_SCALAR_F16_TYPE || HWY_IDE + // Accessed via NativeLaneType, and used directly if + // HWY_HAVE_SCALAR_F16_OPERATORS. + Native native; +#endif + // Only accessed via NativeLaneType or U16LaneType. + uint16_t bits; + }; + + // Default init and copying. + float16_t() noexcept = default; + constexpr float16_t(const float16_t&) noexcept = default; + constexpr float16_t(float16_t&&) noexcept = default; + float16_t& operator=(const float16_t&) noexcept = default; + float16_t& operator=(float16_t&&) noexcept = default; + +#if HWY_HAVE_SCALAR_F16_TYPE + // NEON vget/set_lane intrinsics and SVE `svaddv` could use explicit + // float16_t(intrinsic()), but user code expects implicit conversions. + constexpr float16_t(Native arg) noexcept : native(arg) {} + constexpr operator Native() const noexcept { return native; } +#endif + +#if HWY_HAVE_SCALAR_F16_TYPE + static HWY_BITCASTSCALAR_CONSTEXPR float16_t FromBits(uint16_t bits) { + return float16_t(BitCastScalar(bits)); + } +#else + + private: + struct F16FromU16BitsTag {}; + constexpr float16_t(F16FromU16BitsTag /*tag*/, uint16_t u16_bits) + : bits(u16_bits) {} + + public: + static constexpr float16_t FromBits(uint16_t bits) { + return float16_t(F16FromU16BitsTag(), bits); + } +#endif + + // When backed by a native type, ensure the wrapper behaves like the native + // type by forwarding all operators. Unfortunately it seems difficult to reuse + // this code in a base class, so we repeat it in float16_t. 
+#if HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE + template , float16_t>() && + IsConvertible()>* = nullptr> + constexpr float16_t(T&& arg) noexcept + : native(static_cast(static_cast(arg))) {} + + template , float16_t>() && + !IsConvertible() && + IsStaticCastable()>* = nullptr> + explicit constexpr float16_t(T&& arg) noexcept + : native(static_cast(static_cast(arg))) {} + + // pre-decrement operator (--x) + HWY_CXX14_CONSTEXPR float16_t& operator--() noexcept { + native = static_cast(native - Native{1}); + return *this; + } + + // post-decrement operator (x--) + HWY_CXX14_CONSTEXPR float16_t operator--(int) noexcept { + float16_t result = *this; + native = static_cast(native - Native{1}); + return result; + } + + // pre-increment operator (++x) + HWY_CXX14_CONSTEXPR float16_t& operator++() noexcept { + native = static_cast(native + Native{1}); + return *this; + } + + // post-increment operator (x++) + HWY_CXX14_CONSTEXPR float16_t operator++(int) noexcept { + float16_t result = *this; + native = static_cast(native + Native{1}); + return result; + } + + constexpr float16_t operator-() const noexcept { + return float16_t(static_cast(-native)); + } + constexpr float16_t operator+() const noexcept { return *this; } + + // Reduce clutter by generating `operator+` and `operator+=` etc. Note that + // we cannot token-paste `operator` and `+`, so pass it in as `op_func`. +#define HWY_FLOAT16_BINARY_OP(op, op_func, assign_func) \ + constexpr float16_t op_func(const float16_t& rhs) const noexcept { \ + return float16_t(static_cast(native op rhs.native)); \ + } \ + template , \ + typename RawResultT = \ + decltype(DeclVal() op DeclVal()), \ + typename ResultT = \ + detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + constexpr ResultT op_func(const T& rhs) const noexcept(noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + return static_cast(native op static_cast(rhs)); \ + } \ + HWY_CXX14_CONSTEXPR hwy::float16_t& assign_func( \ + const hwy::float16_t& rhs) noexcept { \ + native = static_cast(native op rhs.native); \ + return *this; \ + } \ + template () op DeclVal()))> \ + HWY_CXX14_CONSTEXPR hwy::float16_t& assign_func(const T& rhs) noexcept( \ + noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + native = static_cast(native op rhs); \ + return *this; \ + } + + HWY_FLOAT16_BINARY_OP(+, operator+, operator+=) + HWY_FLOAT16_BINARY_OP(-, operator-, operator-=) + HWY_FLOAT16_BINARY_OP(*, operator*, operator*=) + HWY_FLOAT16_BINARY_OP(/, operator/, operator/=) +#undef HWY_FLOAT16_BINARY_OP + +#endif // HWY_HAVE_SCALAR_F16_OPERATORS +}; +static_assert(sizeof(hwy::float16_t) == 2, "Wrong size of float16_t"); + +#if HWY_HAVE_SCALAR_F16_TYPE +namespace detail { + +#if HWY_HAVE_SCALAR_F16_OPERATORS +template +struct SpecialFloatUnwrapArithOpOperandT { + using type = hwy::float16_t::Native; +}; +#endif + +template +struct NativeSpecialFloatToWrapperT { + using type = hwy::float16_t; +}; + +} // namespace detail +#endif // HWY_HAVE_SCALAR_F16_TYPE + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +namespace detail { + +template <> +struct BitCastScalarSrcCastHelper { +#if HWY_HAVE_SCALAR_F16_TYPE + static HWY_INLINE constexpr const hwy::float16_t::Native& CastSrcValRef( + const hwy::float16_t& val) { + return val.native; + } +#else + static HWY_INLINE constexpr const uint16_t& CastSrcValRef( + const hwy::float16_t& val) { + return val.bits; + } +#endif +}; + +} // namespace detail +#endif // HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC 
>= 1926 + +#if HWY_HAVE_SCALAR_F16_OPERATORS +#define HWY_F16_CONSTEXPR constexpr +#else +#define HWY_F16_CONSTEXPR HWY_BITCASTSCALAR_CXX14_CONSTEXPR +#endif // HWY_HAVE_SCALAR_F16_OPERATORS + +HWY_API HWY_F16_CONSTEXPR float F32FromF16(float16_t f16) { +#if HWY_HAVE_SCALAR_F16_OPERATORS && !HWY_IDE + return static_cast(f16); +#endif +#if !HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE + const uint16_t bits16 = BitCastScalar(f16); + const uint32_t sign = static_cast(bits16 >> 15); + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); + return sign ? -subnormal : subnormal; + } + + // Normalized, infinity or NaN: convert the representation directly + // (faster than ldexp/tables). + const uint32_t biased_exp32 = + biased_exp == 31 ? 0xFF : biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + + return BitCastScalar(bits32); +#endif // !HWY_HAVE_SCALAR_F16_OPERATORS +} + +#if HWY_IS_DEBUG_BUILD && \ + (HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926) +#if defined(__cpp_if_consteval) && __cpp_if_consteval >= 202106L +// If C++23 if !consteval support is available, only execute +// HWY_DASSERT(condition) if F16FromF32 is not called from a constant-evaluated +// context to avoid compilation errors. +#define HWY_F16_FROM_F32_DASSERT(condition) \ + do { \ + if !consteval { \ + HWY_DASSERT(condition); \ + } \ + } while (0) +#elif HWY_HAS_BUILTIN(__builtin_is_constant_evaluated) || \ + HWY_COMPILER_MSVC >= 1926 +// If the __builtin_is_constant_evaluated() intrinsic is available, +// only do HWY_DASSERT(condition) if __builtin_is_constant_evaluated() returns +// false to avoid compilation errors if F16FromF32 is called from a +// constant-evaluated context. +#define HWY_F16_FROM_F32_DASSERT(condition) \ + do { \ + if (!__builtin_is_constant_evaluated()) { \ + HWY_DASSERT(condition); \ + } \ + } while (0) +#else +// If C++23 if !consteval support is not available, +// the __builtin_is_constant_evaluated() intrinsic is not available, +// HWY_IS_DEBUG_BUILD is 1, and the __builtin_bit_cast intrinsic is available, +// do not do a HWY_DASSERT to avoid compilation errors if F16FromF32 is +// called from a constant-evaluated context. +#define HWY_F16_FROM_F32_DASSERT(condition) \ + do { \ + } while (0) +#endif // defined(__cpp_if_consteval) && __cpp_if_consteval >= 202106L +#else +// If HWY_IS_DEBUG_BUILD is 0 or the __builtin_bit_cast intrinsic is not +// available, define HWY_F16_FROM_F32_DASSERT(condition) as +// HWY_DASSERT(condition) +#define HWY_F16_FROM_F32_DASSERT(condition) HWY_DASSERT(condition) +#endif // HWY_IS_DEBUG_BUILD && (HWY_HAS_BUILTIN(__builtin_bit_cast) || + // HWY_COMPILER_MSVC >= 1926) + +HWY_API HWY_F16_CONSTEXPR float16_t F16FromF32(float f32) { +#if HWY_HAVE_SCALAR_F16_OPERATORS && !HWY_IDE + return float16_t(static_cast(f32)); +#endif +#if !HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE + const uint32_t bits32 = BitCastScalar(f32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + constexpr uint32_t kMantissaMask = 0x7FFFFF; + const uint32_t mantissa32 = bits32 & kMantissaMask; + + // Before shifting (truncation), round to nearest even to reduce bias. If + // the lowest remaining mantissa bit is odd, increase the offset. 
Example + // with the lowest remaining bit (left) and next lower two bits; the + // latter, plus two more, will be truncated. + // 0[00] + 1 = 0[01] + // 0[01] + 1 = 0[10] + // 0[10] + 1 = 0[11] (round down toward even) + // 0[11] + 1 = 1[00] (round up) + // 1[00] + 10 = 1[10] + // 1[01] + 10 = 1[11] + // 1[10] + 10 = C0[00] (round up toward even with C=1 carry out) + // 1[11] + 10 = C0[01] (round up toward even with C=1 carry out) + + // If |f32| >= 2^-24, f16_ulp_bit_idx is the index of the F32 mantissa bit + // that will be shifted down into the ULP bit of the rounded down F16 result + + // The biased F32 exponent of 2^-14 (the smallest positive normal F16 value) + // is 113, and bit 13 of the F32 mantissa will be shifted down to into the ULP + // bit of the rounded down F16 result if |f32| >= 2^14 + + // If |f32| < 2^-24, f16_ulp_bit_idx is equal to 24 as there are 24 mantissa + // bits (including the implied 1 bit) in the mantissa of a normal F32 value + // and as we want to round up the mantissa if |f32| > 2^-25 && |f32| < 2^-24 + const int32_t f16_ulp_bit_idx = + HWY_MIN(HWY_MAX(126 - static_cast(biased_exp32), 13), 24); + const uint32_t odd_bit = ((mantissa32 | 0x800000u) >> f16_ulp_bit_idx) & 1; + const uint32_t rounded = + mantissa32 + odd_bit + (uint32_t{1} << (f16_ulp_bit_idx - 1)) - 1u; + const bool carry = rounded >= (1u << 23); + + const int32_t exp = static_cast(biased_exp32) - 127 + carry; + + // Tiny or zero => zero. + if (exp < -24) { + // restore original sign + return float16_t::FromBits(static_cast(sign << 15)); + } + + // If biased_exp16 would be >= 31, first check whether the input was NaN so we + // can set the mantissa to nonzero. + const bool is_nan = (biased_exp32 == 255) && mantissa32 != 0; + const bool overflowed = exp >= 16; + const uint32_t biased_exp16 = + static_cast(HWY_MIN(HWY_MAX(0, exp + 15), 31)); + // exp = [-24, -15] => subnormal, shift the mantissa. + const uint32_t sub_exp = static_cast(HWY_MAX(-14 - exp, 0)); + HWY_F16_FROM_F32_DASSERT(sub_exp < 11); + const uint32_t shifted_mantissa = + (rounded & kMantissaMask) >> (23 - 10 + sub_exp); + const uint32_t leading = sub_exp == 0u ? 0u : (1024u >> sub_exp); + const uint32_t mantissa16 = is_nan ? 0x3FF + : overflowed ? 0u + : (leading + shifted_mantissa); + +#if HWY_IS_DEBUG_BUILD + if (exp < -14) { + HWY_F16_FROM_F32_DASSERT(biased_exp16 == 0); + HWY_F16_FROM_F32_DASSERT(sub_exp >= 1); + } else if (exp <= 15) { + HWY_F16_FROM_F32_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + HWY_F16_FROM_F32_DASSERT(sub_exp == 0); + } +#endif + + HWY_F16_FROM_F32_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_F16_FROM_F32_DASSERT(bits16 < 0x10000); + const uint16_t narrowed = static_cast(bits16); // big-endian safe + return float16_t::FromBits(narrowed); +#endif // !HWY_HAVE_SCALAR_F16_OPERATORS +} + +HWY_API HWY_F16_CONSTEXPR float16_t F16FromF64(double f64) { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return float16_t(static_cast(f64)); +#else + // The mantissa bits of f64 are first rounded using round-to-odd rounding + // to the nearest f64 value that has the lower 29 bits zeroed out to + // ensure that the result is correctly rounded to a F16. + + // The F64 round-to-odd operation below will round a normal F64 value + // (using round-to-odd rounding) to a F64 value that has 24 bits of precision. 
+ + // It is okay if the magnitude of a denormal F64 value is rounded up in the + // F64 round-to-odd step below as the magnitude of a denormal F64 value is + // much smaller than 2^(-24) (the smallest positive denormal F16 value). + + // It is also okay if bit 29 of a NaN F64 value is changed by the F64 + // round-to-odd step below as the lower 13 bits of a F32 NaN value are usually + // discarded or ignored by the conversion of a F32 NaN value to a F16. + + // If f64 is a NaN value, the result of the F64 round-to-odd step will be a + // NaN value as the result of the F64 round-to-odd step will have at least one + // mantissa bit if f64 is a NaN value. + + // The F64 round-to-odd step will ensure that the F64 to F32 conversion is + // exact if the magnitude of the rounded F64 value (using round-to-odd + // rounding) is between 2^(-126) (the smallest normal F32 value) and + // HighestValue() (the largest finite F32 value) + + // It is okay if the F64 to F32 conversion is inexact for F64 values that have + // a magnitude that is less than 2^(-126) as the magnitude of a denormal F32 + // value is much smaller than 2^(-24) (the smallest positive denormal F16 + // value). + + return F16FromF32( + static_cast(BitCastScalar(static_cast( + (BitCastScalar(f64) & 0xFFFFFFFFE0000000ULL) | + ((BitCastScalar(f64) + 0x000000001FFFFFFFULL) & + 0x0000000020000000ULL))))); +#endif +} + +// More convenient to define outside float16_t because these may use +// F32FromF16, which is defined after the struct. +HWY_F16_CONSTEXPR inline bool operator==(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native == rhs.native; +#else + return F32FromF16(lhs) == F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator!=(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native != rhs.native; +#else + return F32FromF16(lhs) != F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator<(float16_t lhs, float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native < rhs.native; +#else + return F32FromF16(lhs) < F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator<=(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native <= rhs.native; +#else + return F32FromF16(lhs) <= F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator>(float16_t lhs, float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native > rhs.native; +#else + return F32FromF16(lhs) > F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator>=(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native >= rhs.native; +#else + return F32FromF16(lhs) >= F32FromF16(rhs); +#endif +} +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_F16_CONSTEXPR inline std::partial_ordering operator<=>( + float16_t lhs, float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native <=> rhs.native; +#else + return F32FromF16(lhs) <=> F32FromF16(rhs); +#endif +} +#endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE + +//------------------------------------------------------------------------------ +// BF16 lane type + +// Compiler supports ACLE __bf16, not necessarily with operators. 
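The F16 helpers above convert exactly for values representable in half precision and round to nearest-even otherwise; a small sketch (the expected 0x3C00 bit pattern is the standard binary16 encoding of 1.0, not taken from the patch):

#include <cstdint>
#include "third_party/highway/hwy/base.h"

int main() {
  const hwy::float16_t one = hwy::F16FromF32(1.0f);
  // binary16 encoding of 1.0 is 0x3C00.
  if (hwy::BitCastScalar<uint16_t>(one) != 0x3C00u) return 1;
  // Round trip is exact for values representable in half precision.
  if (hwy::F32FromF16(hwy::F16FromF32(0.5f)) != 0.5f) return 2;
  // Comparison operators work with or without native __fp16/_Float16 support.
  return (one > hwy::F16FromF32(0.5f)) ? 0 : 3;
}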
+ +// Disable the __bf16 type on AArch64 with GCC 13 or earlier as there is a bug +// in GCC 13 and earlier that sometimes causes BF16 constant values to be +// incorrectly loaded on AArch64, and this GCC bug on AArch64 is +// described at https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111867. + +#if HWY_ARCH_ARM_A64 && \ + (HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400) +#define HWY_ARM_HAVE_SCALAR_BF16_TYPE 1 +#else +#define HWY_ARM_HAVE_SCALAR_BF16_TYPE 0 +#endif + +// x86 compiler supports __bf16, not necessarily with operators. +#ifndef HWY_SSE2_HAVE_SCALAR_BF16_TYPE +#if HWY_ARCH_X86 && defined(__SSE2__) && \ + ((HWY_COMPILER_CLANG >= 1700 && !HWY_COMPILER_CLANGCL) || \ + HWY_COMPILER_GCC_ACTUAL >= 1300) +#define HWY_SSE2_HAVE_SCALAR_BF16_TYPE 1 +#else +#define HWY_SSE2_HAVE_SCALAR_BF16_TYPE 0 +#endif +#endif // HWY_SSE2_HAVE_SCALAR_BF16_TYPE + +// Compiler supports __bf16, not necessarily with operators. +#if HWY_ARM_HAVE_SCALAR_BF16_TYPE || HWY_SSE2_HAVE_SCALAR_BF16_TYPE +#define HWY_HAVE_SCALAR_BF16_TYPE 1 +#else +#define HWY_HAVE_SCALAR_BF16_TYPE 0 +#endif + +#ifndef HWY_HAVE_SCALAR_BF16_OPERATORS +// Recent enough compiler also has operators. aarch64 clang 18 hits internal +// compiler errors on bf16 ToString, hence only enable on GCC for now. +#if HWY_HAVE_SCALAR_BF16_TYPE && (HWY_COMPILER_GCC_ACTUAL >= 1300) +#define HWY_HAVE_SCALAR_BF16_OPERATORS 1 +#else +#define HWY_HAVE_SCALAR_BF16_OPERATORS 0 +#endif +#endif // HWY_HAVE_SCALAR_BF16_OPERATORS + +#if HWY_HAVE_SCALAR_BF16_OPERATORS +#define HWY_BF16_CONSTEXPR constexpr +#else +#define HWY_BF16_CONSTEXPR HWY_BITCASTSCALAR_CONSTEXPR +#endif + +struct alignas(2) bfloat16_t { +#if HWY_HAVE_SCALAR_BF16_TYPE + using Native = __bf16; +#elif HWY_IDE + using Native = uint16_t; +#endif + + union { +#if HWY_HAVE_SCALAR_BF16_TYPE || HWY_IDE + // Accessed via NativeLaneType, and used directly if + // HWY_HAVE_SCALAR_BF16_OPERATORS. + Native native; +#endif + // Only accessed via NativeLaneType or U16LaneType. + uint16_t bits; + }; + + // Default init and copying + bfloat16_t() noexcept = default; + constexpr bfloat16_t(bfloat16_t&&) noexcept = default; + constexpr bfloat16_t(const bfloat16_t&) noexcept = default; + bfloat16_t& operator=(bfloat16_t&& arg) noexcept = default; + bfloat16_t& operator=(const bfloat16_t& arg) noexcept = default; + +// Only enable implicit conversions if we have a native type. +#if HWY_HAVE_SCALAR_BF16_TYPE || HWY_IDE + constexpr bfloat16_t(Native arg) noexcept : native(arg) {} + constexpr operator Native() const noexcept { return native; } +#endif + +#if HWY_HAVE_SCALAR_BF16_TYPE + static HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t FromBits(uint16_t bits) { + return bfloat16_t(BitCastScalar(bits)); + } +#else + + private: + struct BF16FromU16BitsTag {}; + constexpr bfloat16_t(BF16FromU16BitsTag /*tag*/, uint16_t u16_bits) + : bits(u16_bits) {} + + public: + static constexpr bfloat16_t FromBits(uint16_t bits) { + return bfloat16_t(BF16FromU16BitsTag(), bits); + } +#endif + + // When backed by a native type, ensure the wrapper behaves like the native + // type by forwarding all operators. Unfortunately it seems difficult to reuse + // this code in a base class, so we repeat it in float16_t. 
+#if HWY_HAVE_SCALAR_BF16_OPERATORS || HWY_IDE + template , Native>() && + !IsSame, bfloat16_t>() && + IsConvertible()>* = nullptr> + constexpr bfloat16_t(T&& arg) noexcept( + noexcept(static_cast(DeclVal()))) + : native(static_cast(static_cast(arg))) {} + + template , Native>() && + !IsSame, bfloat16_t>() && + !IsConvertible() && + IsStaticCastable()>* = nullptr> + explicit constexpr bfloat16_t(T&& arg) noexcept( + noexcept(static_cast(DeclVal()))) + : native(static_cast(static_cast(arg))) {} + + HWY_CXX14_CONSTEXPR bfloat16_t& operator=(Native arg) noexcept { + native = arg; + return *this; + } + + // pre-decrement operator (--x) + HWY_CXX14_CONSTEXPR bfloat16_t& operator--() noexcept { + native = static_cast(native - Native{1}); + return *this; + } + + // post-decrement operator (x--) + HWY_CXX14_CONSTEXPR bfloat16_t operator--(int) noexcept { + bfloat16_t result = *this; + native = static_cast(native - Native{1}); + return result; + } + + // pre-increment operator (++x) + HWY_CXX14_CONSTEXPR bfloat16_t& operator++() noexcept { + native = static_cast(native + Native{1}); + return *this; + } + + // post-increment operator (x++) + HWY_CXX14_CONSTEXPR bfloat16_t operator++(int) noexcept { + bfloat16_t result = *this; + native = static_cast(native + Native{1}); + return result; + } + + constexpr bfloat16_t operator-() const noexcept { + return bfloat16_t(static_cast(-native)); + } + constexpr bfloat16_t operator+() const noexcept { return *this; } + + // Reduce clutter by generating `operator+` and `operator+=` etc. Note that + // we cannot token-paste `operator` and `+`, so pass it in as `op_func`. +#define HWY_BFLOAT16_BINARY_OP(op, op_func, assign_func) \ + constexpr bfloat16_t op_func(const bfloat16_t& rhs) const noexcept { \ + return bfloat16_t(static_cast(native op rhs.native)); \ + } \ + template , \ + typename RawResultT = \ + decltype(DeclVal() op DeclVal()), \ + typename ResultT = \ + detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + constexpr ResultT op_func(const T& rhs) const noexcept(noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + return static_cast(native op static_cast(rhs)); \ + } \ + HWY_CXX14_CONSTEXPR hwy::bfloat16_t& assign_func( \ + const hwy::bfloat16_t& rhs) noexcept { \ + native = static_cast(native op rhs.native); \ + return *this; \ + } \ + template () op DeclVal()))> \ + HWY_CXX14_CONSTEXPR hwy::bfloat16_t& assign_func(const T& rhs) noexcept( \ + noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + native = static_cast(native op rhs); \ + return *this; \ + } + HWY_BFLOAT16_BINARY_OP(+, operator+, operator+=) + HWY_BFLOAT16_BINARY_OP(-, operator-, operator-=) + HWY_BFLOAT16_BINARY_OP(*, operator*, operator*=) + HWY_BFLOAT16_BINARY_OP(/, operator/, operator/=) +#undef HWY_BFLOAT16_BINARY_OP + +#endif // HWY_HAVE_SCALAR_BF16_OPERATORS +}; +static_assert(sizeof(hwy::bfloat16_t) == 2, "Wrong size of bfloat16_t"); + +#pragma pack(pop) + +#if HWY_HAVE_SCALAR_BF16_TYPE +namespace detail { + +#if HWY_HAVE_SCALAR_BF16_OPERATORS +template +struct SpecialFloatUnwrapArithOpOperandT { + using type = hwy::bfloat16_t::Native; +}; +#endif + +template +struct NativeSpecialFloatToWrapperT { + using type = hwy::bfloat16_t; +}; + +} // namespace detail +#endif // HWY_HAVE_SCALAR_BF16_TYPE + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +namespace detail { + +template <> +struct BitCastScalarSrcCastHelper { +#if HWY_HAVE_SCALAR_BF16_TYPE + static HWY_INLINE constexpr const hwy::bfloat16_t::Native& CastSrcValRef( + 
const hwy::bfloat16_t& val) { + return val.native; + } +#else + static HWY_INLINE constexpr const uint16_t& CastSrcValRef( + const hwy::bfloat16_t& val) { + return val.bits; + } +#endif +}; + +} // namespace detail +#endif // HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 + +HWY_API HWY_BF16_CONSTEXPR float F32FromBF16(bfloat16_t bf) { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return static_cast(bf); +#else + return BitCastScalar(static_cast( + static_cast(BitCastScalar(bf)) << 16)); +#endif +} + +namespace detail { + +// Returns the increment to add to the bits of a finite F32 value to round a +// finite F32 to the nearest BF16 value +static HWY_INLINE HWY_MAYBE_UNUSED constexpr uint32_t F32BitsToBF16RoundIncr( + const uint32_t f32_bits) { + return static_cast(((f32_bits & 0x7FFFFFFFu) < 0x7F800000u) + ? (0x7FFFu + ((f32_bits >> 16) & 1u)) + : 0u); +} + +// Converts f32_bits (which is the bits of a F32 value) to BF16 bits, +// rounded to the nearest F16 value +static HWY_INLINE HWY_MAYBE_UNUSED constexpr uint16_t F32BitsToBF16Bits( + const uint32_t f32_bits) { + // Round f32_bits to the nearest BF16 by first adding + // F32BitsToBF16RoundIncr(f32_bits) to f32_bits and then right shifting + // f32_bits + F32BitsToBF16RoundIncr(f32_bits) by 16 + + // If f32_bits is the bit representation of a NaN F32 value, make sure that + // bit 6 of the BF16 result is set to convert SNaN F32 values to QNaN BF16 + // values and to prevent NaN F32 values from being converted to an infinite + // BF16 value + return static_cast( + ((f32_bits + F32BitsToBF16RoundIncr(f32_bits)) >> 16) | + (static_cast((f32_bits & 0x7FFFFFFFu) > 0x7F800000u) << 6)); +} + +} // namespace detail + +HWY_API HWY_BF16_CONSTEXPR bfloat16_t BF16FromF32(float f) { + // The rounding mode is not specified in the C++ standard, so ignore + // `HWY_HAVE_SCALAR_BF16_OPERATORS` and only use our round to nearest. + return bfloat16_t::FromBits( + detail::F32BitsToBF16Bits(BitCastScalar(f))); +} + +HWY_API HWY_BF16_CONSTEXPR bfloat16_t BF16FromF64(double f64) { + // The mantissa bits of f64 are first rounded using round-to-odd rounding + // to the nearest f64 value that has the lower 38 bits zeroed out to + // ensure that the result is correctly rounded to a BF16. + + // The F64 round-to-odd operation below will round a normal F64 value + // (using round-to-odd rounding) to a F64 value that has 15 bits of precision. + + // It is okay if the magnitude of a denormal F64 value is rounded up in the + // F64 round-to-odd step below as the magnitude of a denormal F64 value is + // much smaller than 2^(-133) (the smallest positive denormal BF16 value). + + // It is also okay if bit 38 of a NaN F64 value is changed by the F64 + // round-to-odd step below as the lower 16 bits of a F32 NaN value are usually + // discarded or ignored by the conversion of a F32 NaN value to a BF16. + + // If f64 is a NaN value, the result of the F64 round-to-odd step will be a + // NaN value as the result of the F64 round-to-odd step will have at least one + // mantissa bit if f64 is a NaN value. + + // The F64 round-to-odd step below will ensure that the F64 to F32 conversion + // is exact if the magnitude of the rounded F64 value (using round-to-odd + // rounding) is between 2^(-135) (one-fourth of the smallest positive denormal + // BF16 value) and HighestValue() (the largest finite F32 value). 
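As a rough standalone sketch of the round-to-odd masking described in the comments above (the helper name RoundF64BitsToOdd is chosen here purely for illustration; BF16FromF64 performs the same computation inline on the result of BitCastScalar):

#include <cstdint>

// Round-to-odd on the raw bits of an IEEE-754 double: keep the top 26 bits
// unchanged and OR a "sticky" 1 into bit 38 whenever any of the discarded
// low 38 mantissa bits is nonzero.
inline uint64_t RoundF64BitsToOdd(uint64_t bits) {
  return (bits & 0xFFFFFFC000000000ULL) |
         ((bits + 0x0000003FFFFFFFFFULL) & 0x0000004000000000ULL);
}

Because only bit 38 can change, the information needed for the final BF16 rounding decision survives the intermediate F64-to-F32 conversion.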
+ + // If |f64| is less than 2^(-135), the magnitude of the result of the F64 to + // F32 conversion is guaranteed to be less than or equal to 2^(-135), which + // ensures that the F32 to BF16 conversion is correctly rounded, even if the + // conversion of a rounded F64 value whose magnitude is less than 2^(-135) + // to a F32 is inexact. + + return BF16FromF32( + static_cast(BitCastScalar(static_cast( + (BitCastScalar(f64) & 0xFFFFFFC000000000ULL) | + ((BitCastScalar(f64) + 0x0000003FFFFFFFFFULL) & + 0x0000004000000000ULL))))); +} + +// More convenient to define outside bfloat16_t because these may use +// F32FromBF16, which is defined after the struct. + +HWY_BF16_CONSTEXPR inline bool operator==(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native == rhs.native; +#else + return F32FromBF16(lhs) == F32FromBF16(rhs); +#endif +} + +HWY_BF16_CONSTEXPR inline bool operator!=(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native != rhs.native; +#else + return F32FromBF16(lhs) != F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator<(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native < rhs.native; +#else + return F32FromBF16(lhs) < F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator<=(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native <= rhs.native; +#else + return F32FromBF16(lhs) <= F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator>(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native > rhs.native; +#else + return F32FromBF16(lhs) > F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator>=(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native >= rhs.native; +#else + return F32FromBF16(lhs) >= F32FromBF16(rhs); +#endif +} +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_BF16_CONSTEXPR inline std::partial_ordering operator<=>( + bfloat16_t lhs, bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native <=> rhs.native; +#else + return F32FromBF16(lhs) <=> F32FromBF16(rhs); +#endif +} +#endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE + +//------------------------------------------------------------------------------ +// Type relations + +namespace detail { + +template +struct Relations; +template <> +struct Relations { + using Unsigned = uint8_t; + using Signed = int8_t; + using Wide = uint16_t; + enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint8_t; + using Signed = int8_t; + using Wide = int16_t; + enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Float = float16_t; + using Wide = uint32_t; + using Narrow = uint8_t; + enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Float = float16_t; + using Wide = int32_t; + using Narrow = int8_t; + enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = uint64_t; + using Narrow = uint16_t; + enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint32_t; + 
using Signed = int32_t; + using Float = float; + using Wide = int64_t; + using Narrow = int16_t; + enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Wide = uint128_t; + using Narrow = uint32_t; + enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Narrow = int32_t; + enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint128_t; + using Narrow = uint64_t; + enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Float = float16_t; + using Wide = float; + enum { is_signed = 1, is_float = 1, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Wide = float; + enum { is_signed = 1, is_float = 1, is_bf16 = 1 }; +}; +template <> +struct Relations { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = double; + using Narrow = float16_t; + enum { is_signed = 1, is_float = 1, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Narrow = float; + enum { is_signed = 1, is_float = 1, is_bf16 = 0 }; +}; + +template +struct TypeFromSize; +template <> +struct TypeFromSize<1> { + using Unsigned = uint8_t; + using Signed = int8_t; +}; +template <> +struct TypeFromSize<2> { + using Unsigned = uint16_t; + using Signed = int16_t; + using Float = float16_t; +}; +template <> +struct TypeFromSize<4> { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; +}; +template <> +struct TypeFromSize<8> { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; +}; +template <> +struct TypeFromSize<16> { + using Unsigned = uint128_t; +}; + +} // namespace detail + +// Aliases for types of a different category, but the same size. +template +using MakeUnsigned = typename detail::Relations::Unsigned; +template +using MakeSigned = typename detail::Relations::Signed; +template +using MakeFloat = typename detail::Relations::Float; + +// Aliases for types of the same category, but different size. +template +using MakeWide = typename detail::Relations::Wide; +template +using MakeNarrow = typename detail::Relations::Narrow; + +// Obtain type from its size [bytes]. +template +using UnsignedFromSize = typename detail::TypeFromSize::Unsigned; +template +using SignedFromSize = typename detail::TypeFromSize::Signed; +template +using FloatFromSize = typename detail::TypeFromSize::Float; + +// Avoid confusion with SizeTag where the parameter is a lane size. +using UnsignedTag = SizeTag<0>; +using SignedTag = SizeTag<0x100>; // integer +using FloatTag = SizeTag<0x200>; +using SpecialTag = SizeTag<0x300>; + +template > +constexpr auto TypeTag() + -> hwy::SizeTag<((R::is_signed + R::is_float + R::is_bf16) << 8)> { + return hwy::SizeTag<((R::is_signed + R::is_float + R::is_bf16) << 8)>(); +} + +// For when we only want to distinguish FloatTag from everything else. +using NonFloatTag = SizeTag<0x400>; + +template > +constexpr auto IsFloatTag() -> hwy::SizeTag<(R::is_float ? 0x200 : 0x400)> { + return hwy::SizeTag<(R::is_float ? 
0x200 : 0x400)>(); +} + +//------------------------------------------------------------------------------ +// Type traits + +template +HWY_API constexpr bool IsFloat3264() { + return IsSameEither, float, double>(); +} + +template +HWY_API constexpr bool IsFloat() { + // Cannot use T(1.25) != T(1) for float16_t, which can only be converted to or + // from a float, not compared. Include float16_t in case HWY_HAVE_FLOAT16=1. + return IsSame, float16_t>() || IsFloat3264(); +} + +template +HWY_API constexpr bool IsSigned() { + return static_cast(0) > static_cast(-1); +} +template <> +constexpr bool IsSigned() { + return true; +} +template <> +constexpr bool IsSigned() { + return true; +} +template <> +constexpr bool IsSigned() { + return false; +} +template <> +constexpr bool IsSigned() { + return false; +} +template <> +constexpr bool IsSigned() { + return false; +} + +template () && !IsIntegerLaneType()> +struct MakeLaneTypeIfIntegerT { + using type = T; +}; + +template +struct MakeLaneTypeIfIntegerT { + using type = hwy::If(), SignedFromSize, + UnsignedFromSize>; +}; + +template +using MakeLaneTypeIfInteger = typename MakeLaneTypeIfIntegerT::type; + +// Largest/smallest representable integer values. +template +HWY_API constexpr T LimitsMax() { + static_assert(IsInteger(), "Only for integer types"); + using TU = UnsignedFromSize; + return static_cast(IsSigned() ? (static_cast(~TU(0)) >> 1) + : static_cast(~TU(0))); +} +template +HWY_API constexpr T LimitsMin() { + static_assert(IsInteger(), "Only for integer types"); + return IsSigned() ? static_cast(-1) - LimitsMax() + : static_cast(0); +} + +// Largest/smallest representable value (integer or float). This naming avoids +// confusion with numeric_limits::min() (the smallest positive value). +// Cannot be constexpr because we use CopySameSize for [b]float16_t. +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T LowestValue() { + return LimitsMin(); +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t LowestValue() { + return bfloat16_t::FromBits(uint16_t{0xFF7Fu}); // -1.1111111 x 2^127 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t LowestValue() { + return float16_t::FromBits(uint16_t{0xFBFFu}); // -1.1111111111 x 2^15 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float LowestValue() { + return -3.402823466e+38F; +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double LowestValue() { + return -1.7976931348623158e+308; +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T HighestValue() { + return LimitsMax(); +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t HighestValue() { + return bfloat16_t::FromBits(uint16_t{0x7F7Fu}); // 1.1111111 x 2^127 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t HighestValue() { + return float16_t::FromBits(uint16_t{0x7BFFu}); // 1.1111111111 x 2^15 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float HighestValue() { + return 3.402823466e+38F; +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double HighestValue() { + return 1.7976931348623158e+308; +} + +// Difference between 1.0 and the next representable value. Equal to +// 1 / (1ULL << MantissaBits()), but hard-coding ensures precision. 
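A couple of standalone compile-time checks (using <limits> only for cross-checking) illustrate the bit tricks behind the LimitsMax and LimitsMin definitions above:

#include <cstdint>
#include <limits>

// Signed maximum: the all-ones pattern of the same-sized unsigned type,
// shifted right by one.
static_assert(static_cast<int32_t>(~uint32_t{0} >> 1) ==
                  std::numeric_limits<int32_t>::max(),
              "LimitsMax for int32_t");
// Signed minimum: -1 - LimitsMax, which avoids signed overflow.
static_assert(std::numeric_limits<int64_t>::min() ==
                  static_cast<int64_t>(-1) -
                      std::numeric_limits<int64_t>::max(),
              "LimitsMin for int64_t");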
+template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T Epsilon() { + return 1; +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t Epsilon() { + return bfloat16_t::FromBits(uint16_t{0x3C00u}); // 0.0078125 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t Epsilon() { + return float16_t::FromBits(uint16_t{0x1400u}); // 0.0009765625 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float Epsilon() { + return 1.192092896e-7f; +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double Epsilon() { + return 2.2204460492503131e-16; +} + +// Returns width in bits of the mantissa field in IEEE binary16/32/64. +template +constexpr int MantissaBits() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +constexpr int MantissaBits() { + return 7; +} +template <> +constexpr int MantissaBits() { + return 10; +} +template <> +constexpr int MantissaBits() { + return 23; +} +template <> +constexpr int MantissaBits() { + return 52; +} + +// Returns the (left-shifted by one bit) IEEE binary16/32/64 representation with +// the largest possible (biased) exponent field. Used by IsInf. +template +constexpr MakeSigned MaxExponentTimes2() { + return -(MakeSigned{1} << (MantissaBits() + 1)); +} + +// Returns bitmask of the sign bit in IEEE binary16/32/64. +template +constexpr MakeUnsigned SignMask() { + return MakeUnsigned{1} << (sizeof(T) * 8 - 1); +} + +// Returns bitmask of the exponent field in IEEE binary16/32/64. +template +constexpr MakeUnsigned ExponentMask() { + return (~(MakeUnsigned{1} << MantissaBits()) + 1) & + static_cast>(~SignMask()); +} + +// Returns bitmask of the mantissa field in IEEE binary16/32/64. +template +constexpr MakeUnsigned MantissaMask() { + return (MakeUnsigned{1} << MantissaBits()) - 1; +} + +// Returns 1 << mantissa_bits as a floating-point number. All integers whose +// absolute value are less than this can be represented exactly. +template +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T MantissaEnd() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t MantissaEnd() { + return bfloat16_t::FromBits(uint16_t{0x4300u}); // 1.0 x 2^7 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t MantissaEnd() { + return float16_t::FromBits(uint16_t{0x6400u}); // 1.0 x 2^10 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float MantissaEnd() { + return 8388608.0f; // 1 << 23 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double MantissaEnd() { + // floating point literal with p52 requires C++17. + return 4503599627370496.0; // 1 << 52 +} + +// Returns width in bits of the exponent field in IEEE binary16/32/64. +template +constexpr int ExponentBits() { + // Exponent := remaining bits after deducting sign and mantissa. + return 8 * sizeof(T) - 1 - MantissaBits(); +} + +// Returns largest value of the biased exponent field in IEEE binary16/32/64, +// right-shifted so that the LSB is bit zero. Example: 0xFF for float. +// This is expressed as a signed integer for more efficient comparison. 
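For binary32 the mask helpers above work out to the familiar constants; a minimal standalone check (the hexadecimal values are simply the IEEE-754 single-precision layout, not taken from this header):

#include <cstdint>

static_assert((uint32_t{1} << 31) == 0x80000000u, "SignMask for float");
static_assert(((~(uint32_t{1} << 23) + 1u) & 0x7FFFFFFFu) == 0x7F800000u,
              "ExponentMask for float");
static_assert(((uint32_t{1} << 23) - 1u) == 0x007FFFFFu,
              "MantissaMask for float");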
+template +constexpr MakeSigned MaxExponentField() { + return (MakeSigned{1} << ExponentBits()) - 1; +} + +namespace detail { + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CONSTEXPR T +NegativeInfOrLowestValue(hwy::FloatTag /* tag */) { + return BitCastScalar( + static_cast>(SignMask() | ExponentMask())); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CONSTEXPR T +NegativeInfOrLowestValue(hwy::NonFloatTag /* tag */) { + return LowestValue(); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CONSTEXPR T +PositiveInfOrHighestValue(hwy::FloatTag /* tag */) { + return BitCastScalar(ExponentMask()); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CONSTEXPR T +PositiveInfOrHighestValue(hwy::NonFloatTag /* tag */) { + return HighestValue(); +} + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T NegativeInfOrLowestValue() { + return detail::NegativeInfOrLowestValue(IsFloatTag()); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T PositiveInfOrHighestValue() { + return detail::PositiveInfOrHighestValue(IsFloatTag()); +} + +//------------------------------------------------------------------------------ +// Additional F16/BF16 operators + +#if HWY_HAVE_SCALAR_F16_OPERATORS || HWY_HAVE_SCALAR_BF16_OPERATORS + +#define HWY_RHS_SPECIAL_FLOAT_ARITH_OP(op, op_func, T2) \ + template < \ + typename T1, \ + hwy::EnableIf>() || \ + hwy::IsFloat3264>()>* = nullptr, \ + typename RawResultT = decltype(DeclVal() op DeclVal()), \ + typename ResultT = detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + static HWY_INLINE constexpr ResultT op_func(T1 a, T2 b) noexcept { \ + return static_cast(a op b.native); \ + } + +#define HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(op, assign_op, T2) \ + template >() || \ + hwy::IsFloat3264>()>* = nullptr, \ + typename ResultT = \ + decltype(DeclVal() assign_op DeclVal())> \ + static HWY_INLINE constexpr ResultT operator assign_op(T1& a, \ + T2 b) noexcept { \ + return (a assign_op b.native); \ + } + +#define HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(op, op_func, T1) \ + HWY_RHS_SPECIAL_FLOAT_ARITH_OP(op, op_func, T1) \ + template < \ + typename T2, \ + hwy::EnableIf>() || \ + hwy::IsFloat3264>()>* = nullptr, \ + typename RawResultT = decltype(DeclVal() op DeclVal()), \ + typename ResultT = detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + static HWY_INLINE constexpr ResultT op_func(T1 a, T2 b) noexcept { \ + return static_cast(a.native op b); \ + } + +#if HWY_HAVE_SCALAR_F16_OPERATORS +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(+, operator+, float16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(-, operator-, float16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(*, operator*, float16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(/, operator/, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(+, +=, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(-, -=, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(*, *=, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(/, /=, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(==, operator==, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(!=, operator!=, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<, operator<, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=, operator<=, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>, operator>, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>=, operator>=, float16_t) +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE 
+HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=>, operator<=>, float16_t) +#endif +#endif // HWY_HAVE_SCALAR_F16_OPERATORS + +#if HWY_HAVE_SCALAR_BF16_OPERATORS +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(+, operator+, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(-, operator-, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(*, operator*, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(/, operator/, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(+, +=, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(-, -=, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(*, *=, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(/, /=, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(==, operator==, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(!=, operator!=, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<, operator<, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=, operator<=, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>, operator>, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>=, operator>=, bfloat16_t) +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=>, operator<=>, bfloat16_t) +#endif +#endif // HWY_HAVE_SCALAR_BF16_OPERATORS + +#undef HWY_RHS_SPECIAL_FLOAT_ARITH_OP +#undef HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP +#undef HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP + +#endif // HWY_HAVE_SCALAR_F16_OPERATORS || HWY_HAVE_SCALAR_BF16_OPERATORS + +//------------------------------------------------------------------------------ +// Type conversions (after IsSpecialFloat) + +HWY_API float F32FromF16Mem(const void* ptr) { + float16_t f16; + CopyBytes<2>(HWY_ASSUME_ALIGNED(ptr, 2), &f16); + return F32FromF16(f16); +} + +HWY_API float F32FromBF16Mem(const void* ptr) { + bfloat16_t bf; + CopyBytes<2>(HWY_ASSUME_ALIGNED(ptr, 2), &bf); + return F32FromBF16(bf); +} + +#if HWY_HAVE_SCALAR_F16_OPERATORS +#define HWY_BF16_TO_F16_CONSTEXPR HWY_BF16_CONSTEXPR +#else +#define HWY_BF16_TO_F16_CONSTEXPR HWY_F16_CONSTEXPR +#endif + +// For casting from TFrom to TTo +template +HWY_API constexpr TTo ConvertScalarTo(const TFrom in) { + return static_cast(in); +} +template +HWY_API constexpr TTo ConvertScalarTo(const TFrom in) { + return F16FromF32(static_cast(in)); +} +template +HWY_API HWY_BF16_TO_F16_CONSTEXPR TTo +ConvertScalarTo(const hwy::bfloat16_t in) { + return F16FromF32(F32FromBF16(in)); +} +template +HWY_API HWY_F16_CONSTEXPR TTo ConvertScalarTo(const double in) { + return F16FromF64(in); +} +template +HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(const TFrom in) { + return BF16FromF32(static_cast(in)); +} +template +HWY_API HWY_BF16_TO_F16_CONSTEXPR TTo ConvertScalarTo(const hwy::float16_t in) { + return BF16FromF32(F32FromF16(in)); +} +template +HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(const double in) { + return BF16FromF64(in); +} +template +HWY_API HWY_F16_CONSTEXPR TTo ConvertScalarTo(const TFrom in) { + return static_cast(F32FromF16(in)); +} +template +HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(TFrom in) { + return static_cast(F32FromBF16(in)); +} +// Same: return unchanged +template +HWY_API constexpr TTo ConvertScalarTo(TTo in) { + return in; +} + +//------------------------------------------------------------------------------ +// Helper functions + +template +constexpr inline T1 DivCeil(T1 a, T2 b) { +#if HWY_CXX_LANG >= 201703L + HWY_DASSERT(b != 0); +#endif + return (a + b - 1) / b; +} + +// Works for any non-zero `align`; if a power of two, compiler emits ADD+AND. 
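DivCeil, and the RoundUpTo helper that follows, are simple enough to sanity-check with the same arithmetic they use; for example, packing 10 elements into groups of 4 needs 3 groups and 12 padded slots:

static_assert((10 + 4 - 1) / 4 == 3, "DivCeil(10, 4) == 3");
static_assert(((10 + 4 - 1) / 4) * 4 == 12, "RoundUpTo(10, 4) == 12");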
+constexpr inline size_t RoundUpTo(size_t what, size_t align) { + return DivCeil(what, align) * align; +} + +// Works for any `align`; if a power of two, compiler emits AND. +constexpr inline size_t RoundDownTo(size_t what, size_t align) { + return what - (what % align); +} + +namespace detail { + +// T is unsigned or T is signed and (val >> shift_amt) is an arithmetic right +// shift +template +static HWY_INLINE constexpr T ScalarShr(hwy::UnsignedTag /*type_tag*/, T val, + int shift_amt) { + return static_cast(val >> shift_amt); +} + +// T is signed and (val >> shift_amt) is a non-arithmetic right shift +template +static HWY_INLINE constexpr T ScalarShr(hwy::SignedTag /*type_tag*/, T val, + int shift_amt) { + using TU = MakeUnsigned>; + return static_cast( + (val < 0) ? static_cast( + ~(static_cast(~static_cast(val)) >> shift_amt)) + : static_cast(static_cast(val) >> shift_amt)); +} + +} // namespace detail + +// If T is an signed integer type, ScalarShr is guaranteed to perform an +// arithmetic right shift + +// Otherwise, if T is an unsigned integer type, ScalarShr is guaranteed to +// perform a logical right shift +template )> +HWY_API constexpr RemoveCvRef ScalarShr(T val, int shift_amt) { + using NonCvRefT = RemoveCvRef; + return detail::ScalarShr( + hwy::SizeTag<((IsSigned() && + (LimitsMin() >> (sizeof(T) * 8 - 1)) != + static_cast(-1)) + ? 0x100 + : 0)>(), + static_cast(val), shift_amt); +} + +// Undefined results for x == 0. +HWY_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x) { + HWY_DASSERT(x != 0); +#if HWY_COMPILER_MSVC + unsigned long index; // NOLINT + _BitScanForward(&index, x); + return index; +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_ctz(x)); +#endif // HWY_COMPILER_MSVC +} + +HWY_API size_t Num0BitsBelowLS1Bit_Nonzero64(const uint64_t x) { + HWY_DASSERT(x != 0); +#if HWY_COMPILER_MSVC +#if HWY_ARCH_X86_64 + unsigned long index; // NOLINT + _BitScanForward64(&index, x); + return index; +#else // HWY_ARCH_X86_64 + // _BitScanForward64 not available + uint32_t lsb = static_cast(x & 0xFFFFFFFF); + unsigned long index; // NOLINT + if (lsb == 0) { + uint32_t msb = static_cast(x >> 32u); + _BitScanForward(&index, msb); + return 32 + index; + } else { + _BitScanForward(&index, lsb); + return index; + } +#endif // HWY_ARCH_X86_64 +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_ctzll(x)); +#endif // HWY_COMPILER_MSVC +} + +// Undefined results for x == 0. 
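The SignedTag branch of ScalarShr above deserves a worked example. A standalone version of the same complement-shift-complement trick (ArithShr32 is a name chosen here for illustration only):

#include <cstdint>

// Arithmetic right shift without relying on the implementation-defined
// behavior of >> on negative values; assumes two's complement, which is
// universal in practice and required since C++20.
constexpr int32_t ArithShr32(int32_t val, int shift) {
  return (val < 0)
             ? static_cast<int32_t>(~(~static_cast<uint32_t>(val) >> shift))
             : static_cast<int32_t>(static_cast<uint32_t>(val) >> shift);
}
static_assert(ArithShr32(-7, 1) == -4, "rounds toward negative infinity");
static_assert(ArithShr32(-1, 31) == -1, "sign bit is replicated");
static_assert(ArithShr32(6, 1) == 3, "non-negative values shift logically");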
+HWY_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x) { + HWY_DASSERT(x != 0); +#if HWY_COMPILER_MSVC + unsigned long index; // NOLINT + _BitScanReverse(&index, x); + return 31 - index; +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_clz(x)); +#endif // HWY_COMPILER_MSVC +} + +HWY_API size_t Num0BitsAboveMS1Bit_Nonzero64(const uint64_t x) { + HWY_DASSERT(x != 0); +#if HWY_COMPILER_MSVC +#if HWY_ARCH_X86_64 + unsigned long index; // NOLINT + _BitScanReverse64(&index, x); + return 63 - index; +#else // HWY_ARCH_X86_64 + // _BitScanReverse64 not available + const uint32_t msb = static_cast(x >> 32u); + unsigned long index; // NOLINT + if (msb == 0) { + const uint32_t lsb = static_cast(x & 0xFFFFFFFF); + _BitScanReverse(&index, lsb); + return 63 - index; + } else { + _BitScanReverse(&index, msb); + return 31 - index; + } +#endif // HWY_ARCH_X86_64 +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_clzll(x)); +#endif // HWY_COMPILER_MSVC +} + +template ), + HWY_IF_T_SIZE_ONE_OF(RemoveCvRef, (1 << 1) | (1 << 2) | (1 << 4))> +HWY_API size_t PopCount(T x) { + uint32_t u32_x = static_cast( + static_cast)>>(x)); + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + return static_cast(__builtin_popcountl(u32_x)); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_32 && defined(__AVX__) + return static_cast(_mm_popcnt_u32(u32_x)); +#else + u32_x -= ((u32_x >> 1) & 0x55555555u); + u32_x = (((u32_x >> 2) & 0x33333333u) + (u32_x & 0x33333333u)); + u32_x = (((u32_x >> 4) + u32_x) & 0x0F0F0F0Fu); + u32_x += (u32_x >> 8); + u32_x += (u32_x >> 16); + return static_cast(u32_x & 0x3Fu); +#endif +} + +template ), + HWY_IF_T_SIZE(RemoveCvRef, 8)> +HWY_API size_t PopCount(T x) { + uint64_t u64_x = static_cast( + static_cast)>>(x)); + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + return static_cast(__builtin_popcountll(u64_x)); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 && defined(__AVX__) + return _mm_popcnt_u64(u64_x); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_32 && defined(__AVX__) + return _mm_popcnt_u32(static_cast(u64_x & 0xFFFFFFFFu)) + + _mm_popcnt_u32(static_cast(u64_x >> 32)); +#else + u64_x -= ((u64_x >> 1) & 0x5555555555555555ULL); + u64_x = (((u64_x >> 2) & 0x3333333333333333ULL) + + (u64_x & 0x3333333333333333ULL)); + u64_x = (((u64_x >> 4) + u64_x) & 0x0F0F0F0F0F0F0F0FULL); + u64_x += (u64_x >> 8); + u64_x += (u64_x >> 16); + u64_x += (u64_x >> 32); + return static_cast(u64_x & 0x7Fu); +#endif +} + +// Skip HWY_API due to GCC "function not considered for inlining". Previously +// such errors were caused by underlying type mismatches, but it's not clear +// what is still mismatched despite all the casts. +template +/*HWY_API*/ constexpr size_t FloorLog2(TI x) { + return x == TI{1} + ? 0 + : static_cast(FloorLog2(static_cast(x >> 1)) + 1); +} + +template +/*HWY_API*/ constexpr size_t CeilLog2(TI x) { + return x == TI{1} + ? 0 + : static_cast(FloorLog2(static_cast(x - 1)) + 1); +} + +template +HWY_INLINE constexpr T AddWithWraparound(T t, T2 increment) { + return t + static_cast(increment); +} + +template +HWY_INLINE constexpr T AddWithWraparound(T t, T2 increment) { + return ConvertScalarTo(ConvertScalarTo(t) + + ConvertScalarTo(increment)); +} + +template +HWY_INLINE constexpr T AddWithWraparound(T t, T2 n) { + using TU = MakeUnsigned; + // Sub-int types would promote to int, not unsigned, which would trigger + // warnings, so first promote to the largest unsigned type. 
Due to + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87519, which affected GCC 8 + // until fixed in 9.3, we use built-in types rather than uint64_t. + return static_cast(static_cast( + static_cast(static_cast(t) + + static_cast(n)) & + uint64_t{hwy::LimitsMax()})); +} + +#if HWY_COMPILER_MSVC && HWY_ARCH_X86_64 +#pragma intrinsic(_mul128) +#pragma intrinsic(_umul128) +#endif + +// 64 x 64 = 128 bit multiplication +HWY_API uint64_t Mul128(uint64_t a, uint64_t b, uint64_t* HWY_RESTRICT upper) { +#if defined(__SIZEOF_INT128__) + __uint128_t product = (__uint128_t)a * (__uint128_t)b; + *upper = (uint64_t)(product >> 64); + return (uint64_t)(product & 0xFFFFFFFFFFFFFFFFULL); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 + return _umul128(a, b, upper); +#else + constexpr uint64_t kLo32 = 0xFFFFFFFFU; + const uint64_t lo_lo = (a & kLo32) * (b & kLo32); + const uint64_t hi_lo = (a >> 32) * (b & kLo32); + const uint64_t lo_hi = (a & kLo32) * (b >> 32); + const uint64_t hi_hi = (a >> 32) * (b >> 32); + const uint64_t t = (lo_lo >> 32) + (hi_lo & kLo32) + lo_hi; + *upper = (hi_lo >> 32) + (t >> 32) + hi_hi; + return (t << 32) | (lo_lo & kLo32); +#endif +} + +HWY_API int64_t Mul128(int64_t a, int64_t b, int64_t* HWY_RESTRICT upper) { +#if defined(__SIZEOF_INT128__) + __int128_t product = (__int128_t)a * (__int128_t)b; + *upper = (int64_t)(product >> 64); + return (int64_t)(product & 0xFFFFFFFFFFFFFFFFULL); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 + return _mul128(a, b, upper); +#else + uint64_t unsigned_upper; + const int64_t lower = static_cast(Mul128( + static_cast(a), static_cast(b), &unsigned_upper)); + *upper = static_cast( + unsigned_upper - + (static_cast(ScalarShr(a, 63)) & static_cast(b)) - + (static_cast(ScalarShr(b, 63)) & static_cast(a))); + return lower; +#endif +} + +// Precomputation for fast n / divisor and n % divisor, where n is a variable +// and divisor is unchanging but unknown at compile-time. +class Divisor { + public: + explicit Divisor(uint32_t divisor) : divisor_(divisor) { + if (divisor <= 1) return; + + const uint32_t len = + static_cast(31 - Num0BitsAboveMS1Bit_Nonzero32(divisor - 1)); + const uint64_t u_hi = (2ULL << len) - divisor; + const uint32_t q = Truncate((u_hi << 32) / divisor); + + mul_ = q + 1; + shift1_ = 1; + shift2_ = len; + } + + uint32_t GetDivisor() const { return divisor_; } + + // Returns n / divisor_. + uint32_t Divide(uint32_t n) const { + const uint64_t mul = mul_; + const uint32_t t = Truncate((mul * n) >> 32); + return (t + ((n - t) >> shift1_)) >> shift2_; + } + + // Returns n % divisor_. + uint32_t Remainder(uint32_t n) const { return n - (Divide(n) * divisor_); } + + private: + static uint32_t Truncate(uint64_t x) { + return static_cast(x & 0xFFFFFFFFu); + } + + uint32_t divisor_; + uint32_t mul_ = 1; + uint32_t shift1_ = 0; + uint32_t shift2_ = 0; +}; + +#ifndef HWY_HAVE_DIV128 // allow override +// Exclude clang-cl because it calls __divti3 from clang_rt.builtins-x86_64, +// which is not linked in. +#if (HWY_COMPILER_MSVC >= 1920 && HWY_ARCH_X86_64) || \ + (defined(__SIZEOF_INT128__) && !HWY_COMPILER_CLANGCL) +#define HWY_HAVE_DIV128 1 +#else +#define HWY_HAVE_DIV128 0 +#endif +#endif // HWY_HAVE_DIV128 + +// Divisor64 can precompute the multiplicative inverse. +#if HWY_HAVE_DIV128 + +#if HWY_COMPILER_MSVC >= 1920 && HWY_ARCH_X86_64 +#pragma intrinsic(_udiv128) +#pragma intrinsic(__umulh) +#endif + +// As above, but for 64-bit divisors: more expensive to compute and initialize. 
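A hypothetical use of the Divisor class above (DemoDivisor is an illustrative name; the include path is assumed to match this vendored copy): precompute once for a divisor known only at runtime, then replace each division and modulo in a hot loop with the multiply-and-shift sequence.

#include <cassert>
#include <cstdint>

#include "third_party/highway/hwy/base.h"

void DemoDivisor() {
  const hwy::Divisor div(7u);  // divisor is typically known only at runtime
  for (uint32_t n = 0; n < 1000; ++n) {
    assert(div.Divide(n) == n / 7u);
    assert(div.Remainder(n) == n % 7u);
  }
}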
+class Divisor64 { + public: + explicit Divisor64(uint64_t divisor) : divisor_(divisor) { + if (divisor <= 1) return; + + const uint64_t len = + static_cast(63 - Num0BitsAboveMS1Bit_Nonzero64(divisor - 1)); + const uint64_t u_hi = (2ULL << len) - divisor; + const uint64_t q = Div128(u_hi, divisor); + + mul_ = q + 1; + shift1_ = 1; + shift2_ = len; + } + + uint64_t GetDivisor() const { return divisor_; } + + // Returns n / divisor_. + uint64_t Divide(uint64_t n) const { + const uint64_t t = MulHigh(mul_, n); + return (t + ((n - t) >> shift1_)) >> shift2_; + } + + // Returns n % divisor_. + uint64_t Remainder(uint64_t n) const { return n - (Divide(n) * divisor_); } + + private: + uint64_t divisor_; + + static uint64_t Div128(uint64_t hi, uint64_t div) { +#if HWY_COMPILER_MSVC >= 1920 && HWY_ARCH_X86_64 + unsigned __int64 remainder; // unused + return _udiv128(hi, uint64_t{0}, div, &remainder); +#else + using u128 = unsigned __int128; + const u128 hi128 = static_cast(hi) << 64; + return static_cast(hi128 / static_cast(div)); +#endif + } + + static uint64_t MulHigh(uint64_t a, uint64_t b) { +#if HWY_COMPILER_MSVC >= 1920 && HWY_ARCH_X86_64 + return __umulh(a, b); +#else + using u128 = unsigned __int128; + const u128 a128 = static_cast(a); + const u128 b128 = static_cast(b); + return static_cast((a128 * b128) >> 64); +#endif + } + + uint64_t mul_ = 1; + uint64_t shift1_ = 0; + uint64_t shift2_ = 0; +}; +#else +// No Div128 available, use built-in 64-bit division on each call. +class Divisor64 { + public: + explicit Divisor64(uint64_t divisor) : divisor_(divisor) {} + + uint64_t GetDivisor() const { return divisor_; } + + uint64_t Divide(uint64_t n) const { return n / divisor_; } + uint64_t Remainder(uint64_t n) const { return n % divisor_; } + + private: + uint64_t divisor_; +}; +#endif // HWY_HAVE_DIV128 + +namespace detail { + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T ScalarAbs(hwy::FloatTag /*tag*/, + T val) { + using TU = MakeUnsigned; + return BitCastScalar( + static_cast(BitCastScalar(val) & (~SignMask()))); +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T +ScalarAbs(hwy::SpecialTag /*tag*/, T val) { + return ScalarAbs(hwy::FloatTag(), val); +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T +ScalarAbs(hwy::SignedTag /*tag*/, T val) { + using TU = MakeUnsigned; + return (val < T{0}) ? 
static_cast(TU{0} - static_cast(val)) : val; +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T +ScalarAbs(hwy::UnsignedTag /*tag*/, T val) { + return val; +} + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR RemoveCvRef ScalarAbs(T val) { + using TVal = MakeLaneTypeIfInteger< + detail::NativeSpecialFloatToWrapper>>; + return detail::ScalarAbs(hwy::TypeTag(), static_cast(val)); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsNaN(T val) { + using TF = detail::NativeSpecialFloatToWrapper>; + using TU = MakeUnsigned; + return (BitCastScalar(ScalarAbs(val)) > ExponentMask()); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsInf(T val) { + using TF = detail::NativeSpecialFloatToWrapper>; + using TU = MakeUnsigned; + return static_cast(BitCastScalar(static_cast(val)) << 1) == + static_cast(MaxExponentTimes2()); +} + +namespace detail { + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite( + hwy::FloatTag /*tag*/, T val) { + using TU = MakeUnsigned; + return (BitCastScalar(hwy::ScalarAbs(val)) < ExponentMask()); +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite( + hwy::NonFloatTag /*tag*/, T /*val*/) { + // Integer values are always finite + return true; +} + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite(T val) { + using TVal = MakeLaneTypeIfInteger< + detail::NativeSpecialFloatToWrapper>>; + return detail::ScalarIsFinite(hwy::IsFloatTag(), + static_cast(val)); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR RemoveCvRef ScalarCopySign(T magn, + T sign) { + using TF = RemoveCvRef>>; + using TU = MakeUnsigned; + return BitCastScalar(static_cast( + (BitCastScalar(static_cast(magn)) & (~SignMask())) | + (BitCastScalar(static_cast(sign)) & SignMask()))); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarSignBit(T val) { + using TVal = MakeLaneTypeIfInteger< + detail::NativeSpecialFloatToWrapper>>; + using TU = MakeUnsigned; + return ((BitCastScalar(static_cast(val)) & SignMask()) != 0); +} + +// Prevents the compiler from eliding the computations that led to "output". +#if HWY_ARCH_PPC && (HWY_COMPILER_GCC || HWY_COMPILER_CLANG) && \ + !defined(_SOFT_FLOAT) +// Workaround to avoid test failures on PPC if compiled with Clang +template +HWY_API void PreventElision(T&& output) { + asm volatile("" : "+f"(output)::"memory"); +} +template +HWY_API void PreventElision(T&& output) { + asm volatile("" : "+d"(output)::"memory"); +} +template +HWY_API void PreventElision(T&& output) { + asm volatile("" : "+r"(output)::"memory"); +} +#else +template +HWY_API void PreventElision(T&& output) { +#if HWY_COMPILER_MSVC + // MSVC does not support inline assembly anymore (and never supported GCC's + // RTL constraints). Self-assignment with #pragma optimize("off") might be + // expected to prevent elision, but it does not with MSVC 2015. Type-punning + // with volatile pointers generates inefficient code on MSVC 2017. + static std::atomic> sink; + sink.store(output, std::memory_order_relaxed); +#else + // Works by indicating to the compiler that "output" is being read and + // modified. The +r constraint avoids unnecessary writes to memory, but only + // works for built-in types (typically FuncOutput). 
+ asm volatile("" : "+r"(output) : : "memory"); +#endif +} +#endif + +} // namespace hwy + +#endif // HIGHWAY_HWY_BASE_H_ diff --git a/third_party/aom/third_party/highway/hwy/bit_set.h b/third_party/aom/third_party/highway/hwy/bit_set.h new file mode 100644 index 000000000000..b74741652775 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/bit_set.h @@ -0,0 +1,158 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_BIT_SET_H_ +#define HIGHWAY_HWY_BIT_SET_H_ + +// BitSet with fast Foreach for up to 64 and 4096 members. + +#include + +#include "third_party/highway/hwy/base.h" + +namespace hwy { + +// 64-bit specialization of std::bitset, which lacks Foreach. +class BitSet64 { + public: + // No harm if `i` is already set. + void Set(size_t i) { + HWY_DASSERT(i < 64); + bits_ |= (1ULL << i); + HWY_DASSERT(Get(i)); + } + + // Equivalent to Set(i) for i in [0, 64) where (bits >> i) & 1. This does + // not clear any existing bits. + void SetNonzeroBitsFrom64(uint64_t bits) { bits_ |= bits; } + + void Clear(size_t i) { + HWY_DASSERT(i < 64); + bits_ &= ~(1ULL << i); + } + + bool Get(size_t i) const { + HWY_DASSERT(i < 64); + return (bits_ & (1ULL << i)) != 0; + } + + // Returns true if any Get(i) would return true for i in [0, 64). + bool Any() const { return bits_ != 0; } + + // Returns lowest i such that Get(i). Caller must ensure Any() beforehand! + size_t First() const { + HWY_DASSERT(Any()); + return Num0BitsBelowLS1Bit_Nonzero64(bits_); + } + + // Returns uint64_t(Get(i)) << i for i in [0, 64). + uint64_t Get64() const { return bits_; } + + // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify + // the set, but the current Foreach call is unaffected. + template + void Foreach(const Func& func) const { + uint64_t remaining_bits = bits_; + while (remaining_bits != 0) { + const size_t i = Num0BitsBelowLS1Bit_Nonzero64(remaining_bits); + remaining_bits &= remaining_bits - 1; // clear LSB + func(i); + } + } + + size_t Count() const { return PopCount(bits_); } + + private: + uint64_t bits_ = 0; +}; + +// Two-level bitset for up to kMaxSize <= 4096 values. +template +class BitSet4096 { + public: + // No harm if `i` is already set. + void Set(size_t i) { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + bits_[idx].Set(mod); + nonzero_.Set(idx); + HWY_DASSERT(Get(i)); + } + + // Equivalent to Set(i) for i in [0, 64) where (bits >> i) & 1. This does + // not clear any existing bits. 
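A hypothetical caller of the BitSet64 class above (DemoBitSet64 is an illustrative name; the include path is assumed to match this vendored copy). Foreach visits only the members that are set, in ascending order, instead of scanning all 64 positions:

#include <cstddef>
#include <cstdio>

#include "third_party/highway/hwy/bit_set.h"

void DemoBitSet64() {
  hwy::BitSet64 set;
  set.Set(3);
  set.Set(40);
  // Prints "member 3" then "member 40".
  set.Foreach([](size_t i) { std::printf("member %zu\n", i); });
  // At this point set.Count() == 2 and set.First() == 3.
}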
+ void SetNonzeroBitsFrom64(uint64_t bits) { + bits_[0].SetNonzeroBitsFrom64(bits); + if (bits) nonzero_.Set(0); + } + + void Clear(size_t i) { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + bits_[idx].Clear(mod); + if (!bits_[idx].Any()) { + nonzero_.Clear(idx); + } + HWY_DASSERT(!Get(i)); + } + + bool Get(size_t i) const { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + return bits_[idx].Get(mod); + } + + // Returns true if any Get(i) would return true for i in [0, 64). + bool Any() const { return nonzero_.Any(); } + + // Returns lowest i such that Get(i). Caller must ensure Any() beforehand! + size_t First() const { + HWY_DASSERT(Any()); + const size_t idx = nonzero_.First(); + return idx * 64 + bits_[idx].First(); + } + + // Returns uint64_t(Get(i)) << i for i in [0, 64). + uint64_t Get64() const { return bits_[0].Get64(); } + + // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify + // the set, but the current Foreach call is only affected if changing one of + // the not yet visited BitSet64 for which Any() is true. + template + void Foreach(const Func& func) const { + nonzero_.Foreach([&func, this](size_t idx) { + bits_[idx].Foreach([idx, &func](size_t mod) { func(idx * 64 + mod); }); + }); + } + + size_t Count() const { + size_t total = 0; + nonzero_.Foreach( + [&total, this](size_t idx) { total += bits_[idx].Count(); }); + return total; + } + + private: + static_assert(kMaxSize <= 64 * 64, "One BitSet64 insufficient"); + BitSet64 nonzero_; + BitSet64 bits_[kMaxSize / 64]; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_BIT_SET_H_ diff --git a/third_party/aom/third_party/highway/hwy/cache_control.h b/third_party/aom/third_party/highway/hwy/cache_control.h new file mode 100644 index 000000000000..b3bf5a8323e4 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/cache_control.h @@ -0,0 +1,126 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CACHE_CONTROL_H_ +#define HIGHWAY_HWY_CACHE_CONTROL_H_ + +#include "third_party/highway/hwy/base.h" + +// Requires SSE2; fails to compile on 32-bit Clang 7 (see +// https://github.com/gperftools/gperftools/issues/946). +#if !defined(__SSE2__) || (HWY_COMPILER_CLANG && HWY_ARCH_X86_32) +#undef HWY_DISABLE_CACHE_CONTROL +#define HWY_DISABLE_CACHE_CONTROL +#endif + +#ifndef HWY_DISABLE_CACHE_CONTROL +// intrin.h is sufficient on MSVC and already included by base.h. +#if HWY_ARCH_X86 && !HWY_COMPILER_MSVC +#include // SSE2 +#include // _mm_prefetch +#elif HWY_ARCH_ARM_A64 +#include +#endif +#endif // HWY_DISABLE_CACHE_CONTROL + +namespace hwy { + +// Even if N*sizeof(T) is smaller, Stream may write a multiple of this size. +#define HWY_STREAM_MULTIPLE 16 + +// The following functions may also require an attribute. 
+#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) && !HWY_COMPILER_MSVC +#define HWY_ATTR_CACHE __attribute__((target("sse2"))) +#else +#define HWY_ATTR_CACHE +#endif + +// Windows.h #defines this, which causes infinite recursion. Temporarily +// undefine to avoid conflict with our function. +// TODO(janwas): remove when this function is removed. +#pragma push_macro("LoadFence") +#undef LoadFence + +// Delays subsequent loads until prior loads are visible. Beware of potentially +// differing behavior across architectures and vendors: on Intel but not +// AMD CPUs, also serves as a full fence (waits for all prior instructions to +// complete). +HWY_INLINE HWY_ATTR_CACHE void LoadFence() { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_lfence(); +#endif +} + +// TODO(janwas): remove when this function is removed. (See above.) +#pragma pop_macro("LoadFence") + +// Ensures values written by previous `Stream` calls are visible on the current +// core. This is NOT sufficient for synchronizing across cores; when `Stream` +// outputs are to be consumed by other core(s), the producer must publish +// availability (e.g. via mutex or atomic_flag) after `FlushStream`. +HWY_INLINE HWY_ATTR_CACHE void FlushStream() { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_sfence(); +#endif +} + +// Optionally begins loading the cache line containing "p" to reduce latency of +// subsequent actual loads. +template +HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T* p) { + (void)p; +#ifndef HWY_DISABLE_CACHE_CONTROL +#if HWY_ARCH_X86 + _mm_prefetch(reinterpret_cast(p), _MM_HINT_T0); +#elif HWY_COMPILER_GCC // includes clang + // Hint=0 (NTA) behavior differs, but skipping outer caches is probably not + // desirable, so use the default 3 (keep in caches). + __builtin_prefetch(p, /*write=*/0, /*hint=*/3); +#endif +#endif // HWY_DISABLE_CACHE_CONTROL +} + +// Invalidates and flushes the cache line containing "p", if possible. +HWY_INLINE HWY_ATTR_CACHE void FlushCacheline(const void* p) { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_clflush(p); +#else + (void)p; +#endif +} + +// Hints that we are inside a spin loop and potentially reduces power +// consumption and coherency traffic. For example, x86 avoids multiple +// outstanding load requests, which reduces the memory order violation penalty +// when exiting the loop. +HWY_INLINE HWY_ATTR_CACHE void Pause() { +#ifndef HWY_DISABLE_CACHE_CONTROL +#if HWY_ARCH_X86 + _mm_pause(); +#elif HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG + // This is documented in ACLE and the YIELD instruction is also available in + // Armv7, but the intrinsic is broken for Armv7 clang, hence A64 only. + __yield(); +#elif HWY_ARCH_ARM && HWY_COMPILER_GCC // includes clang + __asm__ volatile("yield" ::: "memory"); +#elif HWY_ARCH_PPC && HWY_COMPILER_GCC // includes clang + __asm__ volatile("or 27,27,27" ::: "memory"); +#endif +#endif // HWY_DISABLE_CACHE_CONTROL +} + +} // namespace hwy + +#endif // HIGHWAY_HWY_CACHE_CONTROL_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/algo/copy-inl.h b/third_party/aom/third_party/highway/hwy/contrib/algo/copy-inl.h new file mode 100644 index 000000000000..a4411d8f62dc --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/algo/copy-inl.h @@ -0,0 +1,145 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#endif + +#include +#include + +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// These functions avoid having to write a loop plus remainder handling in the +// (unfortunately still common) case where arrays are not aligned/padded. If the +// inputs are known to be aligned/padded, it is more efficient to write a single +// loop using Load(). We do not provide a CopyAlignedPadded because it +// would be more verbose than such a loop. + +// Fills `to`[0, `count`) with `value`. +template > +void Fill(D d, T value, size_t count, T* HWY_RESTRICT to) { + const size_t N = Lanes(d); + const Vec v = Set(d, value); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + StoreU(v, d, to + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + SafeFillN(remaining, value, d, to + idx); +} + +// Copies `from`[0, `count`) to `to`, which must not overlap `from`. +template > +void Copy(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, from + idx); + StoreU(v, d, to + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + SafeCopyN(remaining, d, from + idx, to + idx); +} + +// For idx in [0, count) in ascending order, appends `from[idx]` to `to` if the +// corresponding mask element of `func(d, v)` is true. Returns the STL-style end +// of the newly written elements in `to`. +// +// `func` is either a functor with a templated operator()(d, v) returning a +// mask, or a generic lambda if using C++14. Due to apparent limitations of +// Clang on Windows, it is currently necessary to add HWY_ATTR before the +// opening { of the lambda to avoid errors about "function .. requires target". +// +// NOTE: this is only supported for 16-, 32- or 64-bit types. +// NOTE: Func may be called a second time for elements it has already seen, but +// these elements will not be written to `to` again. +template > +T* CopyIf(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, from + idx); + to += CompressBlendedStore(v, func(d, v), d, to); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return to; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. 
+ const CappedTag d1; + for (; idx < count; ++idx) { + using V1 = Vec; + // Workaround for -Waggressive-loop-optimizations on GCC 8 + // (iteration 2305843009213693951 invokes undefined behavior for T=i64) + const uintptr_t addr = reinterpret_cast(from); + const T* HWY_RESTRICT from_idx = + reinterpret_cast(addr + (idx * sizeof(T))); + const V1 v = LoadU(d1, from_idx); + // Avoid storing to `to` unless we know it should be kept - otherwise, we + // might overrun the end if it was allocated for the exact count. + if (CountTrue(d1, func(d1, v)) == 0) continue; + StoreU(v, d1, to); + to += 1; + } +#else + // Start index of the last unaligned whole vector, ending at the array end. + const size_t last = count - N; + // Number of elements before `from` or already written. + const size_t invalid = idx - last; + HWY_DASSERT(0 != invalid && invalid < N); + const Mask mask = Not(FirstN(d, invalid)); + const Vec v = MaskedLoad(mask, d, from + last); + to += CompressBlendedStore(v, And(mask, func(d, v)), d, to); +#endif + return to; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/algo/find-inl.h b/third_party/aom/third_party/highway/hwy/contrib/algo/find-inl.h new file mode 100644 index 000000000000..e1c2c1f0f2f0 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/algo/find-inl.h @@ -0,0 +1,113 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#endif + +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Returns index of the first element equal to `value` in `in[0, count)`, or +// `count` if not found. +template > +size_t Find(D d, T value, const T* HWY_RESTRICT in, size_t count) { + const size_t N = Lanes(d); + const Vec broadcasted = Set(d, value); + + size_t i = 0; + if (count >= N) { + for (; i <= count - N; i += N) { + const intptr_t pos = FindFirstTrue(d, Eq(broadcasted, LoadU(d, in + i))); + if (pos >= 0) return i + static_cast(pos); + } + } + + if (i != count) { +#if HWY_MEM_OPS_MIGHT_FAULT + // Scan single elements. + const CappedTag d1; + using V1 = Vec; + const V1 broadcasted1 = Set(d1, GetLane(broadcasted)); + for (; i < count; ++i) { + if (AllTrue(d1, Eq(broadcasted1, LoadU(d1, in + i)))) { + return i; + } + } +#else + const size_t remaining = count - i; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, in + i); + // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. 
+ const intptr_t pos = FindFirstTrue(d, And(Eq(broadcasted, v), mask)); + if (pos >= 0) return i + static_cast(pos); +#endif // HWY_MEM_OPS_MIGHT_FAULT + } + + return count; // not found +} + +// Returns index of the first element in `in[0, count)` for which `func(d, vec)` +// returns true, otherwise `count`. +template > +size_t FindIf(D d, const T* HWY_RESTRICT in, size_t count, const Func& func) { + const size_t N = Lanes(d); + + size_t i = 0; + if (count >= N) { + for (; i <= count - N; i += N) { + const intptr_t pos = FindFirstTrue(d, func(d, LoadU(d, in + i))); + if (pos >= 0) return i + static_cast(pos); + } + } + + if (i != count) { +#if HWY_MEM_OPS_MIGHT_FAULT + // Scan single elements. + const CappedTag d1; + for (; i < count; ++i) { + if (AllTrue(d1, func(d1, LoadU(d1, in + i)))) { + return i; + } + } +#else + const size_t remaining = count - i; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, in + i); + // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. + const intptr_t pos = FindFirstTrue(d, And(func(d, v), mask)); + if (pos >= 0) return i + static_cast(pos); +#endif // HWY_MEM_OPS_MIGHT_FAULT + } + + return count; // not found +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/algo/transform-inl.h b/third_party/aom/third_party/highway/hwy/contrib/algo/transform-inl.h new file mode 100644 index 000000000000..9310a32c0ddf --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/algo/transform-inl.h @@ -0,0 +1,228 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#endif + +#include + +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// These functions avoid having to write a loop plus remainder handling in the +// (unfortunately still common) case where arrays are not aligned/padded. If the +// inputs are known to be aligned/padded, it is more efficient to write a single +// loop using Load(). We do not provide a TransformAlignedPadded because it +// would be more verbose than such a loop. +// +// Func is either a functor with a templated operator()(d, v[, v1[, v2]]), or a +// generic lambda if using C++14. The d argument is the same as was passed to +// the Generate etc. functions. Due to apparent limitations of Clang, it is +// currently necessary to add HWY_ATTR before the opening { of the lambda to +// avoid errors about "always_inline function .. 
requires target". +// +// We do not check HWY_MEM_OPS_MIGHT_FAULT because LoadN/StoreN do not fault. + +// Fills `out[0, count)` with the vectors returned by `func(d, index_vec)`, +// where `index_vec` is `Vec>`. On the first call to `func`, +// the value of its lane i is i, and increases by `Lanes(d)` after every call. +// Note that some of these indices may be `>= count`, but the elements that +// `func` returns in those lanes will not be written to `out`. +template > +void Generate(D d, T* HWY_RESTRICT out, size_t count, const Func& func) { + const RebindToUnsigned du; + using TU = TFromD; + const size_t N = Lanes(d); + + size_t idx = 0; + Vec vidx = Iota(du, 0); + if (count >= N) { + for (; idx <= count - N; idx += N) { + StoreU(func(d, vidx), d, out + idx); + vidx = Add(vidx, Set(du, static_cast(N))); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + StoreN(func(d, vidx), d, out + idx, remaining); +} + +// Calls `func(d, v)` for each input vector; out of bound lanes with index i >= +// `count` are instead taken from `no[i % Lanes(d)]`. +template > +void Foreach(D d, const T* HWY_RESTRICT in, const size_t count, const Vec no, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, in + idx); + func(d, v); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadNOr(no, d, in + idx, remaining); + func(d, v); +} + +// Replaces `inout[idx]` with `func(d, inout[idx])`. Example usage: multiplying +// array elements by a constant. +template > +void Transform(D d, T* HWY_RESTRICT inout, size_t count, const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, inout + idx); + StoreU(func(d, v), d, inout + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadN(d, inout + idx, remaining); + StoreN(func(d, v), d, inout + idx, remaining); +} + +// Replaces `inout[idx]` with `func(d, inout[idx], in1[idx])`. Example usage: +// multiplying array elements by those of another array. +template > +void Transform1(D d, T* HWY_RESTRICT inout, size_t count, + const T* HWY_RESTRICT in1, const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, inout + idx); + const Vec v1 = LoadU(d, in1 + idx); + StoreU(func(d, v, v1), d, inout + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadN(d, inout + idx, remaining); + const Vec v1 = LoadN(d, in1 + idx, remaining); + StoreN(func(d, v, v1), d, inout + idx, remaining); +} + +// Replaces `inout[idx]` with `func(d, inout[idx], in1[idx], in2[idx])`. Example +// usage: FMA of elements from three arrays, stored into the first array. 
+template > +void Transform2(D d, T* HWY_RESTRICT inout, size_t count, + const T* HWY_RESTRICT in1, const T* HWY_RESTRICT in2, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, inout + idx); + const Vec v1 = LoadU(d, in1 + idx); + const Vec v2 = LoadU(d, in2 + idx); + StoreU(func(d, v, v1, v2), d, inout + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadN(d, inout + idx, remaining); + const Vec v1 = LoadN(d, in1 + idx, remaining); + const Vec v2 = LoadN(d, in2 + idx, remaining); + StoreN(func(d, v, v1, v2), d, inout + idx, remaining); +} + +template > +void Replace(D d, T* HWY_RESTRICT inout, size_t count, T new_t, T old_t) { + const size_t N = Lanes(d); + const Vec old_v = Set(d, old_t); + const Vec new_v = Set(d, new_t); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + Vec v = LoadU(d, inout + idx); + StoreU(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadN(d, inout + idx, remaining); + StoreN(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx, remaining); +} + +template > +void ReplaceIf(D d, T* HWY_RESTRICT inout, size_t count, T new_t, + const Func& func) { + const size_t N = Lanes(d); + const Vec new_v = Set(d, new_t); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + Vec v = LoadU(d, inout + idx); + StoreU(IfThenElse(func(d, v), new_v, v), d, inout + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadN(d, inout + idx, remaining); + StoreN(IfThenElse(func(d, v), new_v, v), d, inout + idx, remaining); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h b/third_party/aom/third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h new file mode 100644 index 000000000000..0b3902e0ec3a --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h @@ -0,0 +1,2851 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "third_party/highway/hwy/base.h" + +// Per-target include guard +// clang-format off +#if defined(HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +// clang-format on +#ifdef HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ +#endif + +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// The entry points are class templates specialized below for each number of +// bits. Each provides Pack and Unpack member functions which load (Pack) or +// store (Unpack) B raw vectors, and store (Pack) or load (Unpack) a number of +// packed vectors equal to kBits. B denotes the bits per lane: 8 for Pack8, 16 +// for Pack16, 32 for Pack32 which is also the upper bound for kBits. +template // <= 8 +struct Pack8 {}; +template // <= 16 +struct Pack16 {}; + +template <> +struct Pack8<1> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + // 16-bit shifts avoid masking (bits will not cross 8-bit lanes). + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 packed = + Xor3(Or(ShiftLeft<7>(raw7), ShiftLeft<6>(raw6)), + Xor3(ShiftLeft<5>(raw5), ShiftLeft<4>(raw4), ShiftLeft<3>(raw3)), + Xor3(ShiftLeft<2>(raw2), ShiftLeft<1>(raw1), raw0)); + StoreU(BitCast(d8, packed), d8, packed_out); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0101u); // LSB in each byte + + const VU16 packed = BitCast(d16, LoadU(d8, packed_in)); + + const VU16 raw0 = And(packed, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(ShiftRight<1>(packed), mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(ShiftRight<2>(packed), mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(ShiftRight<3>(packed), mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(ShiftRight<4>(packed), mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(ShiftRight<5>(packed), mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<6>(packed), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 raw7 = And(ShiftRight<7>(packed), mask); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<1> + +template <> +struct Pack8<2> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + // 16-bit shifts avoid masking (bits will not cross 8-bit lanes). 
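+    // (Each raw byte holds only 2 value bits, so shifting by 2/4/6 cannot push
+    // them into the neighboring byte; 8-bit shifts would require an extra mask
+    // on targets without native 8-bit vector shifts such as x86, hence the
+    // 16-bit reinterpretation.)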
+ const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 packed0 = Xor3(ShiftLeft<6>(raw6), ShiftLeft<4>(raw4), + Or(ShiftLeft<2>(raw2), raw0)); + const VU16 packed1 = Xor3(ShiftLeft<6>(raw7), ShiftLeft<4>(raw5), + Or(ShiftLeft<2>(raw3), raw1)); + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0303u); // Lowest 2 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(ShiftRight<2>(packed0), mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(ShiftRight<2>(packed1), mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(ShiftRight<4>(packed0), mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(ShiftRight<4>(packed1), mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<6>(packed0), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 raw7 = And(ShiftRight<6>(packed1), mask); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<2> + +template <> +struct Pack8<3> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + // The upper two bits of these three will be filled with packed3 (6 bits). 
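+    // Layout: packed0..2 hold raw0..2 in bits [0,3) and raw4..6 in bits [3,6);
+    // packed3 = raw3 | raw7<<3 (6 bits) is then split two bits at a time into
+    // the top two bits of packed0..2 via the 0xC0C0 mask below.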
+ VU16 packed0 = Or(ShiftLeft<3>(raw4), raw0); + VU16 packed1 = Or(ShiftLeft<3>(raw5), raw1); + VU16 packed2 = Or(ShiftLeft<3>(raw6), raw2); + const VU16 packed3 = Or(ShiftLeft<3>(raw7), raw3); + + const VU16 hi2 = Set(d16, 0xC0C0u); + packed0 = OrAnd(packed0, ShiftLeft<2>(packed3), hi2); + packed1 = OrAnd(packed1, ShiftLeft<4>(packed3), hi2); + packed2 = OrAnd(packed2, ShiftLeft<6>(packed3), hi2); + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0707u); // Lowest 3 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw4 = And(ShiftRight<3>(packed0), mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(ShiftRight<3>(packed1), mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<3>(packed2), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + // raw73 is the concatenation of the upper two bits in packed0..2. + const VU16 hi2 = Set(d16, 0xC0C0u); + const VU16 raw73 = Xor3(ShiftRight<6>(And(packed2, hi2)), // + ShiftRight<4>(And(packed1, hi2)), + ShiftRight<2>(And(packed0, hi2))); + + const VU16 raw3 = And(mask, raw73); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw7 = And(mask, ShiftRight<3>(raw73)); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<3> + +template <> +struct Pack8<4> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + // 16-bit shifts avoid masking (bits will not cross 8-bit lanes). 
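+    // Layout: each packed byte holds two 4-bit values, raw0/1/4/5 in the low
+    // nibble and raw2/3/6/7 in the high nibble, so 8 raw vectors fit in 4.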
+ const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 packed0 = Or(ShiftLeft<4>(raw2), raw0); + const VU16 packed1 = Or(ShiftLeft<4>(raw3), raw1); + const VU16 packed2 = Or(ShiftLeft<4>(raw6), raw4); + const VU16 packed3 = Or(ShiftLeft<4>(raw7), raw5); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0F0Fu); // Lowest 4 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(ShiftRight<4>(packed0), mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(ShiftRight<4>(packed1), mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(packed2, mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(packed3, mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<4>(packed2), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 raw7 = And(ShiftRight<4>(packed3), mask); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<4> + +template <> +struct Pack8<5> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + // Fill upper three bits with upper bits from raw4..7. 
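+    // Layout: packed0..3 keep raw0..3 in their low 5 bits and the top 3 bits
+    // of raw4..7 in their high 3 bits; the remaining low 2 bits of raw4..7
+    // are gathered below into packed4, four pairs per byte.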
+ const VU16 hi3 = Set(d16, 0xE0E0u); + const VU16 packed0 = OrAnd(raw0, ShiftLeft<3>(raw4), hi3); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<3>(raw5), hi3); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<3>(raw6), hi3); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<3>(raw7), hi3); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + + // Combine lower two bits of raw4..7 into packed4. + const VU16 lo2 = Set(d16, 0x0303u); + const VU16 packed4 = Or(And(raw4, lo2), Xor3(ShiftLeft<2>(And(raw5, lo2)), + ShiftLeft<4>(And(raw6, lo2)), + ShiftLeft<6>(And(raw7, lo2)))); + StoreU(BitCast(d8, packed4), d8, packed_out + 4 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + const VU16 packed4 = BitCast(d16, LoadU(d8, packed_in + 4 * N8)); + + const VU16 mask = Set(d16, 0x1F1Fu); // Lowest 5 bits per byte + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(packed3, mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + // The upper bits are the top 3 bits shifted right by three. + const VU16 top4 = ShiftRight<3>(AndNot(mask, packed0)); + const VU16 top5 = ShiftRight<3>(AndNot(mask, packed1)); + const VU16 top6 = ShiftRight<3>(AndNot(mask, packed2)); + const VU16 top7 = ShiftRight<3>(AndNot(mask, packed3)); + + // Insert the lower 2 bits, which were concatenated into a byte. + const VU16 lo2 = Set(d16, 0x0303u); + const VU16 raw4 = OrAnd(top4, lo2, packed4); + const VU16 raw5 = OrAnd(top5, lo2, ShiftRight<2>(packed4)); + const VU16 raw6 = OrAnd(top6, lo2, ShiftRight<4>(packed4)); + const VU16 raw7 = OrAnd(top7, lo2, ShiftRight<6>(packed4)); + + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<5> + +template <> +struct Pack8<6> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 hi2 = Set(d16, 0xC0C0u); + // Each triplet of these stores raw3/raw7 (6 bits) in the upper 2 bits. 
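+    // Layout: packed0..2 hold raw0..2 in their low 6 bits plus two bits of
+    // raw3 in their top 2 bits; packed3..5 do the same for raw4..6 and raw7.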
+ const VU16 packed0 = OrAnd(raw0, ShiftLeft<2>(raw3), hi2); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<4>(raw3), hi2); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<6>(raw3), hi2); + const VU16 packed3 = OrAnd(raw4, ShiftLeft<2>(raw7), hi2); + const VU16 packed4 = OrAnd(raw5, ShiftLeft<4>(raw7), hi2); + const VU16 packed5 = OrAnd(raw6, ShiftLeft<6>(raw7), hi2); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + StoreU(BitCast(d8, packed4), d8, packed_out + 4 * N8); + StoreU(BitCast(d8, packed5), d8, packed_out + 5 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x3F3Fu); // Lowest 6 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + const VU16 packed4 = BitCast(d16, LoadU(d8, packed_in + 4 * N8)); + const VU16 packed5 = BitCast(d16, LoadU(d8, packed_in + 5 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw4 = And(packed3, mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(packed4, mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(packed5, mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + // raw3/7 are the concatenation of the upper two bits in packed0..2. + const VU16 raw3 = Xor3(ShiftRight<6>(AndNot(mask, packed2)), + ShiftRight<4>(AndNot(mask, packed1)), + ShiftRight<2>(AndNot(mask, packed0))); + const VU16 raw7 = Xor3(ShiftRight<6>(AndNot(mask, packed5)), + ShiftRight<4>(AndNot(mask, packed4)), + ShiftRight<2>(AndNot(mask, packed3))); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<6> + +template <> +struct Pack8<7> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + // Inserted into top bit of packed0..6. 
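+    // (raw7 has 7 value bits; each one is moved into the MSB of one of
+    // packed0..6 via the shifts by 1..7 and the 0x8080 mask below.)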
+ const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 hi1 = Set(d16, 0x8080u); + const VU16 packed0 = OrAnd(raw0, Add(raw7, raw7), hi1); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<2>(raw7), hi1); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<3>(raw7), hi1); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<4>(raw7), hi1); + const VU16 packed4 = OrAnd(raw4, ShiftLeft<5>(raw7), hi1); + const VU16 packed5 = OrAnd(raw5, ShiftLeft<6>(raw7), hi1); + const VU16 packed6 = OrAnd(raw6, ShiftLeft<7>(raw7), hi1); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + StoreU(BitCast(d8, packed4), d8, packed_out + 4 * N8); + StoreU(BitCast(d8, packed5), d8, packed_out + 5 * N8); + StoreU(BitCast(d8, packed6), d8, packed_out + 6 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + const VU16 packed4 = BitCast(d16, LoadU(d8, packed_in + 4 * N8)); + const VU16 packed5 = BitCast(d16, LoadU(d8, packed_in + 5 * N8)); + const VU16 packed6 = BitCast(d16, LoadU(d8, packed_in + 6 * N8)); + + const VU16 mask = Set(d16, 0x7F7Fu); // Lowest 7 bits per byte + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(packed3, mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(packed4, mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(packed5, mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(packed6, mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 p0 = Xor3(ShiftRight<7>(AndNot(mask, packed6)), + ShiftRight<6>(AndNot(mask, packed5)), + ShiftRight<5>(AndNot(mask, packed4))); + const VU16 p1 = Xor3(ShiftRight<4>(AndNot(mask, packed3)), + ShiftRight<3>(AndNot(mask, packed2)), + ShiftRight<2>(AndNot(mask, packed1))); + const VU16 raw7 = Xor3(ShiftRight<1>(AndNot(mask, packed0)), p0, p1); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<7> + +template <> +struct Pack8<8> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + using VU8 = Vec; + const size_t N8 = Lanes(d8); + const VU8 raw0 = LoadU(d8, raw + 0 * N8); + const VU8 raw1 = LoadU(d8, raw + 1 * N8); + const VU8 raw2 = LoadU(d8, raw + 2 * N8); + const VU8 raw3 = LoadU(d8, raw + 3 * N8); + const VU8 raw4 = LoadU(d8, raw + 4 * N8); + const VU8 raw5 = LoadU(d8, raw + 5 * N8); + const VU8 raw6 = LoadU(d8, raw + 6 * N8); + const VU8 raw7 = LoadU(d8, raw + 7 * N8); + + StoreU(raw0, d8, packed_out + 0 * N8); + StoreU(raw1, d8, packed_out + 1 * N8); + StoreU(raw2, d8, packed_out + 2 * N8); + StoreU(raw3, d8, packed_out + 3 * N8); + StoreU(raw4, d8, packed_out + 4 * N8); + StoreU(raw5, d8, packed_out + 5 * N8); + StoreU(raw6, d8, 
packed_out + 6 * N8); + StoreU(raw7, d8, packed_out + 7 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + using VU8 = Vec; + const size_t N8 = Lanes(d8); + const VU8 raw0 = LoadU(d8, packed_in + 0 * N8); + const VU8 raw1 = LoadU(d8, packed_in + 1 * N8); + const VU8 raw2 = LoadU(d8, packed_in + 2 * N8); + const VU8 raw3 = LoadU(d8, packed_in + 3 * N8); + const VU8 raw4 = LoadU(d8, packed_in + 4 * N8); + const VU8 raw5 = LoadU(d8, packed_in + 5 * N8); + const VU8 raw6 = LoadU(d8, packed_in + 6 * N8); + const VU8 raw7 = LoadU(d8, packed_in + 7 * N8); + + StoreU(raw0, d8, raw + 0 * N8); + StoreU(raw1, d8, raw + 1 * N8); + StoreU(raw2, d8, raw + 2 * N8); + StoreU(raw3, d8, raw + 3 * N8); + StoreU(raw4, d8, raw + 4 * N8); + StoreU(raw5, d8, raw + 5 * N8); + StoreU(raw6, d8, raw + 6 * N8); + StoreU(raw7, d8, raw + 7 * N8); + } +}; // Pack8<8> + +template <> +struct Pack16<1> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + const VU16 p0 = Xor3(ShiftLeft<2>(raw2), Add(raw1, raw1), raw0); + const VU16 p1 = + Xor3(ShiftLeft<5>(raw5), ShiftLeft<4>(raw4), ShiftLeft<3>(raw3)); + const VU16 p2 = + Xor3(ShiftLeft<8>(raw8), ShiftLeft<7>(raw7), ShiftLeft<6>(raw6)); + const VU16 p3 = + Xor3(ShiftLeft<0xB>(rawB), ShiftLeft<0xA>(rawA), ShiftLeft<9>(raw9)); + const VU16 p4 = + Xor3(ShiftLeft<0xE>(rawE), ShiftLeft<0xD>(rawD), ShiftLeft<0xC>(rawC)); + const VU16 packed = + Or(Xor3(ShiftLeft<0xF>(rawF), p0, p1), Xor3(p2, p3, p4)); + StoreU(packed, d, packed_out); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 1u); // Lowest bit + + const VU16 packed = LoadU(d, packed_in); + + const VU16 raw0 = And(packed, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(ShiftRight<1>(packed), mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(ShiftRight<2>(packed), mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(ShiftRight<3>(packed), mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(ShiftRight<4>(packed), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<5>(packed), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<6>(packed), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(ShiftRight<7>(packed), mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(ShiftRight<8>(packed), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<9>(packed), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<0xA>(packed), 
mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<0xB>(packed), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<0xC>(packed), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<0xD>(packed), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<0xE>(packed), mask); + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<0xF>(packed); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<1> + +template <> +struct Pack16<2> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + VU16 packed0 = Xor3(ShiftLeft<4>(raw4), ShiftLeft<2>(raw2), raw0); + VU16 packed1 = Xor3(ShiftLeft<4>(raw5), ShiftLeft<2>(raw3), raw1); + packed0 = Xor3(packed0, ShiftLeft<8>(raw8), ShiftLeft<6>(raw6)); + packed1 = Xor3(packed1, ShiftLeft<8>(raw9), ShiftLeft<6>(raw7)); + + packed0 = Xor3(packed0, ShiftLeft<12>(rawC), ShiftLeft<10>(rawA)); + packed1 = Xor3(packed1, ShiftLeft<12>(rawD), ShiftLeft<10>(rawB)); + + packed0 = Or(packed0, ShiftLeft<14>(rawE)); + packed1 = Or(packed1, ShiftLeft<14>(rawF)); + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0x3u); // Lowest 2 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(ShiftRight<2>(packed0), mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(ShiftRight<2>(packed1), mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(ShiftRight<4>(packed0), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<4>(packed1), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<6>(packed0), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(ShiftRight<6>(packed1), mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(ShiftRight<8>(packed0), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<8>(packed1), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<0xA>(packed0), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<0xA>(packed1), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<0xC>(packed0), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<0xC>(packed1), mask); + StoreU(rawD, d, 
raw + 0xD * N); + + const VU16 rawE = ShiftRight<0xE>(packed0); + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<0xE>(packed1); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<2> + +template <> +struct Pack16<3> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // We can fit 15 raw vectors in three packed vectors (five each). + VU16 packed0 = Xor3(ShiftLeft<6>(raw6), ShiftLeft<3>(raw3), raw0); + VU16 packed1 = Xor3(ShiftLeft<6>(raw7), ShiftLeft<3>(raw4), raw1); + VU16 packed2 = Xor3(ShiftLeft<6>(raw8), ShiftLeft<3>(raw5), raw2); + + // rawF will be scattered into the upper bit of these three. + packed0 = Xor3(packed0, ShiftLeft<12>(rawC), ShiftLeft<9>(raw9)); + packed1 = Xor3(packed1, ShiftLeft<12>(rawD), ShiftLeft<9>(rawA)); + packed2 = Xor3(packed2, ShiftLeft<12>(rawE), ShiftLeft<9>(rawB)); + + const VU16 hi1 = Set(d, 0x8000u); + packed0 = Or(packed0, ShiftLeft<15>(rawF)); // MSB only, no mask + packed1 = OrAnd(packed1, ShiftLeft<14>(rawF), hi1); + packed2 = OrAnd(packed2, ShiftLeft<13>(rawF), hi1); + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0x7u); // Lowest 3 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + + const VU16 raw0 = And(mask, packed0); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(mask, packed1); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(mask, packed2); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(mask, ShiftRight<3>(packed0)); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(mask, ShiftRight<3>(packed1)); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(mask, ShiftRight<3>(packed2)); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(mask, ShiftRight<6>(packed0)); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(mask, ShiftRight<6>(packed1)); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(mask, ShiftRight<6>(packed2)); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(mask, ShiftRight<9>(packed0)); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(mask, ShiftRight<9>(packed1)); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(mask, ShiftRight<9>(packed2)); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(mask, ShiftRight<12>(packed0)); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(mask, ShiftRight<12>(packed1)); + 
StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(mask, ShiftRight<12>(packed2)); + StoreU(rawE, d, raw + 0xE * N); + + // rawF is the concatenation of the upper bit of packed0..2. + const VU16 down0 = ShiftRight<15>(packed0); + const VU16 down1 = ShiftRight<15>(packed1); + const VU16 down2 = ShiftRight<15>(packed2); + const VU16 rawF = Xor3(ShiftLeft<2>(down2), Add(down1, down1), down0); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<3> + +template <> +struct Pack16<4> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + VU16 packed0 = Xor3(ShiftLeft<8>(raw4), ShiftLeft<4>(raw2), raw0); + VU16 packed1 = Xor3(ShiftLeft<8>(raw5), ShiftLeft<4>(raw3), raw1); + packed0 = Or(packed0, ShiftLeft<12>(raw6)); + packed1 = Or(packed1, ShiftLeft<12>(raw7)); + VU16 packed2 = Xor3(ShiftLeft<8>(rawC), ShiftLeft<4>(rawA), raw8); + VU16 packed3 = Xor3(ShiftLeft<8>(rawD), ShiftLeft<4>(rawB), raw9); + packed2 = Or(packed2, ShiftLeft<12>(rawE)); + packed3 = Or(packed3, ShiftLeft<12>(rawF)); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0xFu); // Lowest 4 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + const VU16 packed3 = LoadU(d, packed_in + 3 * N); + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(ShiftRight<4>(packed0), mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(ShiftRight<4>(packed1), mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(ShiftRight<8>(packed0), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<8>(packed1), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = ShiftRight<12>(packed0); // no mask required + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = ShiftRight<12>(packed1); // no mask required + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed2, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed3, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<4>(packed2), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<4>(packed3), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<8>(packed2), mask); + StoreU(rawC, d, raw + 0xC * N); + + 
const VU16 rawD = And(ShiftRight<8>(packed3), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = ShiftRight<12>(packed2); // no mask required + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<12>(packed3); // no mask required + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<4> + +template <> +struct Pack16<5> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // We can fit 15 raw vectors in five packed vectors (three each). + VU16 packed0 = Xor3(ShiftLeft<10>(rawA), ShiftLeft<5>(raw5), raw0); + VU16 packed1 = Xor3(ShiftLeft<10>(rawB), ShiftLeft<5>(raw6), raw1); + VU16 packed2 = Xor3(ShiftLeft<10>(rawC), ShiftLeft<5>(raw7), raw2); + VU16 packed3 = Xor3(ShiftLeft<10>(rawD), ShiftLeft<5>(raw8), raw3); + VU16 packed4 = Xor3(ShiftLeft<10>(rawE), ShiftLeft<5>(raw9), raw4); + + // rawF will be scattered into the upper bits of these five. + const VU16 hi1 = Set(d, 0x8000u); + packed0 = Or(packed0, ShiftLeft<15>(rawF)); // MSB only, no mask + packed1 = OrAnd(packed1, ShiftLeft<14>(rawF), hi1); + packed2 = OrAnd(packed2, ShiftLeft<13>(rawF), hi1); + packed3 = OrAnd(packed3, ShiftLeft<12>(rawF), hi1); + packed4 = OrAnd(packed4, ShiftLeft<11>(rawF), hi1); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + const VU16 packed3 = LoadU(d, packed_in + 3 * N); + const VU16 packed4 = LoadU(d, packed_in + 4 * N); + + const VU16 mask = Set(d, 0x1Fu); // Lowest 5 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<5>(packed0), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<5>(packed1), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(ShiftRight<5>(packed2), mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(ShiftRight<5>(packed3), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<5>(packed4), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = 
And(ShiftRight<10>(packed0), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<10>(packed1), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<10>(packed2), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<10>(packed3), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<10>(packed4), mask); + StoreU(rawE, d, raw + 0xE * N); + + // rawF is the concatenation of the lower bit of packed0..4. + const VU16 down0 = ShiftRight<15>(packed0); + const VU16 down1 = ShiftRight<15>(packed1); + const VU16 hi1 = Set(d, 0x8000u); + const VU16 p0 = + Xor3(ShiftRight<13>(And(packed2, hi1)), Add(down1, down1), down0); + const VU16 rawF = Xor3(ShiftRight<11>(And(packed4, hi1)), + ShiftRight<12>(And(packed3, hi1)), p0); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<5> + +template <> +struct Pack16<6> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + const VU16 packed3 = Or(ShiftLeft<6>(raw7), raw3); + const VU16 packed7 = Or(ShiftLeft<6>(rawF), rawB); + // Three vectors, two 6-bit raw each; packed3 (12 bits) is spread over the + // four remainder bits at the top of each vector. 
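+    // (packed3's three nibbles land in bits [12,16) of packed0, packed1 and
+    // packed2: bits [0,4) via ShiftLeft<12>, [4,8) via ShiftLeft<8> and
+    // [8,12) via ShiftLeft<4>, masked with 0xF000.)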
+ const VU16 packed0 = Xor3(ShiftLeft<12>(packed3), ShiftLeft<6>(raw4), raw0); + VU16 packed1 = Or(ShiftLeft<6>(raw5), raw1); + VU16 packed2 = Or(ShiftLeft<6>(raw6), raw2); + const VU16 packed4 = Xor3(ShiftLeft<12>(packed7), ShiftLeft<6>(rawC), raw8); + VU16 packed5 = Or(ShiftLeft<6>(rawD), raw9); + VU16 packed6 = Or(ShiftLeft<6>(rawE), rawA); + + const VU16 hi4 = Set(d, 0xF000u); + packed1 = OrAnd(packed1, ShiftLeft<8>(packed3), hi4); + packed2 = OrAnd(packed2, ShiftLeft<4>(packed3), hi4); + packed5 = OrAnd(packed5, ShiftLeft<8>(packed7), hi4); + packed6 = OrAnd(packed6, ShiftLeft<4>(packed7), hi4); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed4, d, packed_out + 3 * N); + StoreU(packed5, d, packed_out + 4 * N); + StoreU(packed6, d, packed_out + 5 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0x3Fu); // Lowest 6 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + const VU16 packed4 = LoadU(d, packed_in + 3 * N); + const VU16 packed5 = LoadU(d, packed_in + 4 * N); + const VU16 packed6 = LoadU(d, packed_in + 5 * N); + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw4 = And(ShiftRight<6>(packed0), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<6>(packed1), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<6>(packed2), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw8 = And(packed4, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed5, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(packed6, mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawC = And(ShiftRight<6>(packed4), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<6>(packed5), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<6>(packed6), mask); + StoreU(rawE, d, raw + 0xE * N); + + // packed3 is the concatenation of the four upper bits in packed0..2. 
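+    // (Each top nibble is shifted back to its original position: bits [0,4)
+    // from packed0, [4,8) from packed1, [8,12) from packed2; packed7 is
+    // rebuilt the same way from packed4..6. raw3/rawB are then the low 6 bits
+    // and raw7/rawF the upper 6 bits of the reassembled values.)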
+ const VU16 down0 = ShiftRight<12>(packed0); + const VU16 down4 = ShiftRight<12>(packed4); + const VU16 hi4 = Set(d, 0xF000u); + const VU16 packed3 = Xor3(ShiftRight<4>(And(packed2, hi4)), + ShiftRight<8>(And(packed1, hi4)), down0); + const VU16 packed7 = Xor3(ShiftRight<4>(And(packed6, hi4)), + ShiftRight<8>(And(packed5, hi4)), down4); + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 rawB = And(packed7, mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 raw7 = ShiftRight<6>(packed3); // upper bits already zero + StoreU(raw7, d, raw + 7 * N); + + const VU16 rawF = ShiftRight<6>(packed7); // upper bits already zero + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<6> + +template <> +struct Pack16<7> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + const VU16 packed7 = Or(ShiftLeft<7>(rawF), raw7); + // Seven vectors, two 7-bit raw each; packed7 (14 bits) is spread over the + // two remainder bits at the top of each vector. 
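+    // (packed7 = raw7 | rawF<<7; its seven 2-bit pieces go into bits [14,16)
+    // of packed0..6 via ShiftLeft<14>, then ShiftLeft<12>..<2> with 0xC000.)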
+ const VU16 packed0 = Xor3(ShiftLeft<14>(packed7), ShiftLeft<7>(raw8), raw0); + VU16 packed1 = Or(ShiftLeft<7>(raw9), raw1); + VU16 packed2 = Or(ShiftLeft<7>(rawA), raw2); + VU16 packed3 = Or(ShiftLeft<7>(rawB), raw3); + VU16 packed4 = Or(ShiftLeft<7>(rawC), raw4); + VU16 packed5 = Or(ShiftLeft<7>(rawD), raw5); + VU16 packed6 = Or(ShiftLeft<7>(rawE), raw6); + + const VU16 hi2 = Set(d, 0xC000u); + packed1 = OrAnd(packed1, ShiftLeft<12>(packed7), hi2); + packed2 = OrAnd(packed2, ShiftLeft<10>(packed7), hi2); + packed3 = OrAnd(packed3, ShiftLeft<8>(packed7), hi2); + packed4 = OrAnd(packed4, ShiftLeft<6>(packed7), hi2); + packed5 = OrAnd(packed5, ShiftLeft<4>(packed7), hi2); + packed6 = OrAnd(packed6, ShiftLeft<2>(packed7), hi2); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + + const VU16 mask = Set(d, 0x7Fu); // Lowest 7 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw8 = And(ShiftRight<7>(packed0), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<7>(packed1), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<7>(packed2), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<7>(packed3), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<7>(packed4), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<7>(packed5), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<7>(packed6), mask); + StoreU(rawE, d, raw + 0xE * N); + + // packed7 is the concatenation of the two upper bits in packed0..6. 
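+    // (Each 2-bit piece is shifted back down: bits [0,2) from packed0 through
+    // bits [12,14) from packed6; raw7 is then the low 7 bits of packed7 and
+    // rawF the upper 7 bits.)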
+ const VU16 down0 = ShiftRight<14>(packed0); + const VU16 hi2 = Set(d, 0xC000u); + const VU16 p0 = Xor3(ShiftRight<12>(And(packed1, hi2)), + ShiftRight<10>(And(packed2, hi2)), down0); + const VU16 p1 = Xor3(ShiftRight<8>(And(packed3, hi2)), // + ShiftRight<6>(And(packed4, hi2)), + ShiftRight<4>(And(packed5, hi2))); + const VU16 packed7 = Xor3(ShiftRight<2>(And(packed6, hi2)), p1, p0); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 rawF = ShiftRight<7>(packed7); // upper bits already zero + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<7> + +template <> +struct Pack16<8> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // This is equivalent to ConcatEven with 8-bit lanes, but much more + // efficient on RVV and slightly less efficient on SVE2. + const VU16 packed0 = Or(ShiftLeft<8>(raw2), raw0); + const VU16 packed1 = Or(ShiftLeft<8>(raw3), raw1); + const VU16 packed2 = Or(ShiftLeft<8>(raw6), raw4); + const VU16 packed3 = Or(ShiftLeft<8>(raw7), raw5); + const VU16 packed4 = Or(ShiftLeft<8>(rawA), raw8); + const VU16 packed5 = Or(ShiftLeft<8>(rawB), raw9); + const VU16 packed6 = Or(ShiftLeft<8>(rawE), rawC); + const VU16 packed7 = Or(ShiftLeft<8>(rawF), rawD); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 mask = Set(d, 0xFFu); // Lowest 8 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = ShiftRight<8>(packed0); // upper bits already zero + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = ShiftRight<8>(packed1); // upper bits already zero + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed2, mask); + StoreU(raw4, d, raw + 4 * 
N); + + const VU16 raw5 = And(packed3, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = ShiftRight<8>(packed2); // upper bits already zero + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = ShiftRight<8>(packed3); // upper bits already zero + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed4, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed5, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = ShiftRight<8>(packed4); // upper bits already zero + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = ShiftRight<8>(packed5); // upper bits already zero + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(packed6, mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(packed7, mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = ShiftRight<8>(packed6); // upper bits already zero + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<8>(packed7); // upper bits already zero + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<8> + +template <> +struct Pack16<9> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + // 8 vectors, each with 9+7 bits; top 2 bits are concatenated into packed8. + const VU16 packed0 = Or(ShiftLeft<9>(raw8), raw0); + const VU16 packed1 = Or(ShiftLeft<9>(raw9), raw1); + const VU16 packed2 = Or(ShiftLeft<9>(rawA), raw2); + const VU16 packed3 = Or(ShiftLeft<9>(rawB), raw3); + const VU16 packed4 = Or(ShiftLeft<9>(rawC), raw4); + const VU16 packed5 = Or(ShiftLeft<9>(rawD), raw5); + const VU16 packed6 = Or(ShiftLeft<9>(rawE), raw6); + const VU16 packed7 = Or(ShiftLeft<9>(rawF), raw7); + + // We could shift down, OR and shift up, but two shifts are typically more + // expensive than AND, shift into position, and OR (which can be further + // reduced via Xor3). 
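+    // Layout sketch (illustrative): packed0..7 each hold one full 9-bit value
+    // (bits 0..8) plus the low 7 bits of raw8..F in their upper 7 bits. The
+    // two remaining bits of raw8..F (mask 0x180) are gathered into packed8,
+    // two bits per value: raw8 -> bits 0..1, raw9 -> bits 2..3, ..., rawF ->
+    // bits 14..15.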
+ const VU16 mid2 = Set(d, 0x180u); // top 2 in lower 9 + const VU16 part8 = ShiftRight<7>(And(raw8, mid2)); + const VU16 part9 = ShiftRight<5>(And(raw9, mid2)); + const VU16 partA = ShiftRight<3>(And(rawA, mid2)); + const VU16 partB = ShiftRight<1>(And(rawB, mid2)); + const VU16 partC = ShiftLeft<1>(And(rawC, mid2)); + const VU16 partD = ShiftLeft<3>(And(rawD, mid2)); + const VU16 partE = ShiftLeft<5>(And(rawE, mid2)); + const VU16 partF = ShiftLeft<7>(And(rawF, mid2)); + const VU16 packed8 = Xor3(Xor3(part8, part9, partA), + Xor3(partB, partC, partD), Or(partE, partF)); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + + const VU16 mask = Set(d, 0x1FFu); // Lowest 9 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 mid2 = Set(d, 0x180u); // top 2 in lower 9 + const VU16 raw8 = + OrAnd(ShiftRight<9>(packed0), ShiftLeft<7>(packed8), mid2); + const VU16 raw9 = + OrAnd(ShiftRight<9>(packed1), ShiftLeft<5>(packed8), mid2); + const VU16 rawA = + OrAnd(ShiftRight<9>(packed2), ShiftLeft<3>(packed8), mid2); + const VU16 rawB = + OrAnd(ShiftRight<9>(packed3), ShiftLeft<1>(packed8), mid2); + const VU16 rawC = + OrAnd(ShiftRight<9>(packed4), ShiftRight<1>(packed8), mid2); + const VU16 rawD = + OrAnd(ShiftRight<9>(packed5), ShiftRight<3>(packed8), mid2); + const VU16 rawE = + OrAnd(ShiftRight<9>(packed6), ShiftRight<5>(packed8), mid2); + const VU16 rawF = + OrAnd(ShiftRight<9>(packed7), ShiftRight<7>(packed8), mid2); + + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<9> + +template <> +struct Pack16<10> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const 
size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 8 vectors, each with 10+6 bits; top 4 bits are concatenated into + // packed8 and packed9. + const VU16 packed0 = Or(ShiftLeft<10>(raw8), raw0); + const VU16 packed1 = Or(ShiftLeft<10>(raw9), raw1); + const VU16 packed2 = Or(ShiftLeft<10>(rawA), raw2); + const VU16 packed3 = Or(ShiftLeft<10>(rawB), raw3); + const VU16 packed4 = Or(ShiftLeft<10>(rawC), raw4); + const VU16 packed5 = Or(ShiftLeft<10>(rawD), raw5); + const VU16 packed6 = Or(ShiftLeft<10>(rawE), raw6); + const VU16 packed7 = Or(ShiftLeft<10>(rawF), raw7); + + // We could shift down, OR and shift up, but two shifts are typically more + // expensive than AND, shift into position, and OR (which can be further + // reduced via Xor3). + const VU16 mid4 = Set(d, 0x3C0u); // top 4 in lower 10 + const VU16 part8 = ShiftRight<6>(And(raw8, mid4)); + const VU16 part9 = ShiftRight<2>(And(raw9, mid4)); + const VU16 partA = ShiftLeft<2>(And(rawA, mid4)); + const VU16 partB = ShiftLeft<6>(And(rawB, mid4)); + const VU16 partC = ShiftRight<6>(And(rawC, mid4)); + const VU16 partD = ShiftRight<2>(And(rawD, mid4)); + const VU16 partE = ShiftLeft<2>(And(rawE, mid4)); + const VU16 partF = ShiftLeft<6>(And(rawF, mid4)); + const VU16 packed8 = Or(Xor3(part8, part9, partA), partB); + const VU16 packed9 = Or(Xor3(partC, partD, partE), partF); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + + const VU16 mask = Set(d, 0x3FFu); // Lowest 10 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const 
VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 mid4 = Set(d, 0x3C0u); // top 4 in lower 10 + const VU16 raw8 = + OrAnd(ShiftRight<10>(packed0), ShiftLeft<6>(packed8), mid4); + const VU16 raw9 = + OrAnd(ShiftRight<10>(packed1), ShiftLeft<2>(packed8), mid4); + const VU16 rawA = + OrAnd(ShiftRight<10>(packed2), ShiftRight<2>(packed8), mid4); + const VU16 rawB = + OrAnd(ShiftRight<10>(packed3), ShiftRight<6>(packed8), mid4); + const VU16 rawC = + OrAnd(ShiftRight<10>(packed4), ShiftLeft<6>(packed9), mid4); + const VU16 rawD = + OrAnd(ShiftRight<10>(packed5), ShiftLeft<2>(packed9), mid4); + const VU16 rawE = + OrAnd(ShiftRight<10>(packed6), ShiftRight<2>(packed9), mid4); + const VU16 rawF = + OrAnd(ShiftRight<10>(packed7), ShiftRight<6>(packed9), mid4); + + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<10> + +template <> +struct Pack16<11> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // It is not obvious what the optimal partitioning looks like. To reduce the + // number of constants, we want to minimize the number of distinct bit + // lengths. 11+5 also requires 6-bit remnants with 4-bit leftovers. + // 8+3 seems better: it is easier to scatter 3 bits into the MSBs. + const VU16 lo8 = Set(d, 0xFFu); + + // Lower 8 bits of all raw + const VU16 packed0 = OrAnd(ShiftLeft<8>(raw1), raw0, lo8); + const VU16 packed1 = OrAnd(ShiftLeft<8>(raw3), raw2, lo8); + const VU16 packed2 = OrAnd(ShiftLeft<8>(raw5), raw4, lo8); + const VU16 packed3 = OrAnd(ShiftLeft<8>(raw7), raw6, lo8); + const VU16 packed4 = OrAnd(ShiftLeft<8>(raw9), raw8, lo8); + const VU16 packed5 = OrAnd(ShiftLeft<8>(rawB), rawA, lo8); + const VU16 packed6 = OrAnd(ShiftLeft<8>(rawD), rawC, lo8); + const VU16 packed7 = OrAnd(ShiftLeft<8>(rawF), rawE, lo8); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + + // Three vectors, five 3bit remnants each, plus one 3bit in their MSB. 
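+    // Accounting check (illustrative): 16 values x 11 bits = 176 bits per lane
+    // position = 11 packed vectors. packed0..7 carry the low bytes of two
+    // values each; packed8..A each carry the top 3 bits of five values in
+    // consecutive 3-bit fields (bits 0..14) plus one of rawF's top bits in
+    // their MSB, i.e. 15 * 3 + 3 = 48 = 3 * 16 remnant bits.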
+ const VU16 top0 = ShiftRight<8>(raw0); + const VU16 top1 = ShiftRight<8>(raw1); + const VU16 top2 = ShiftRight<8>(raw2); + // Insert top raw bits into 3-bit groups within packed8..A. Moving the + // mask along avoids masking each of raw0..E and enables OrAnd. + VU16 next = Set(d, 0x38u); // 0x7 << 3 + VU16 packed8 = OrAnd(top0, ShiftRight<5>(raw3), next); + VU16 packed9 = OrAnd(top1, ShiftRight<5>(raw4), next); + VU16 packedA = OrAnd(top2, ShiftRight<5>(raw5), next); + next = ShiftLeft<3>(next); + packed8 = OrAnd(packed8, ShiftRight<2>(raw6), next); + packed9 = OrAnd(packed9, ShiftRight<2>(raw7), next); + packedA = OrAnd(packedA, ShiftRight<2>(raw8), next); + next = ShiftLeft<3>(next); + packed8 = OrAnd(packed8, Add(raw9, raw9), next); + packed9 = OrAnd(packed9, Add(rawA, rawA), next); + packedA = OrAnd(packedA, Add(rawB, rawB), next); + next = ShiftLeft<3>(next); + packed8 = OrAnd(packed8, ShiftLeft<4>(rawC), next); + packed9 = OrAnd(packed9, ShiftLeft<4>(rawD), next); + packedA = OrAnd(packedA, ShiftLeft<4>(rawE), next); + + // Scatter upper 3 bits of rawF into the upper bits. + next = ShiftLeft<3>(next); // = 0x8000u + packed8 = OrAnd(packed8, ShiftLeft<7>(rawF), next); + packed9 = OrAnd(packed9, ShiftLeft<6>(rawF), next); + packedA = OrAnd(packedA, ShiftLeft<5>(rawF), next); + + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + + const VU16 mask = Set(d, 0xFFu); // Lowest 8 bits + + const VU16 down0 = And(packed0, mask); + const VU16 down1 = ShiftRight<8>(packed0); + const VU16 down2 = And(packed1, mask); + const VU16 down3 = ShiftRight<8>(packed1); + const VU16 down4 = And(packed2, mask); + const VU16 down5 = ShiftRight<8>(packed2); + const VU16 down6 = And(packed3, mask); + const VU16 down7 = ShiftRight<8>(packed3); + const VU16 down8 = And(packed4, mask); + const VU16 down9 = ShiftRight<8>(packed4); + const VU16 downA = And(packed5, mask); + const VU16 downB = ShiftRight<8>(packed5); + const VU16 downC = And(packed6, mask); + const VU16 downD = ShiftRight<8>(packed6); + const VU16 downE = And(packed7, mask); + const VU16 downF = ShiftRight<8>(packed7); + + // Three bits from packed8..A, eight bits from down0..F. 
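+    // For example (illustrative): raw0 = down0 | ((packed8 << 8) & 0x700),
+    // i.e. bits 0..2 of packed8 become bits 8..10 of raw0. raw3 and raw6 reuse
+    // packed8 via left shifts of 5 and 2, raw9 and rawC via right shifts of 1
+    // and 4, all masked into the same 0x700 window.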
+ const VU16 hi3 = Set(d, 0x700u); + const VU16 raw0 = OrAnd(down0, ShiftLeft<8>(packed8), hi3); + const VU16 raw1 = OrAnd(down1, ShiftLeft<8>(packed9), hi3); + const VU16 raw2 = OrAnd(down2, ShiftLeft<8>(packedA), hi3); + + const VU16 raw3 = OrAnd(down3, ShiftLeft<5>(packed8), hi3); + const VU16 raw4 = OrAnd(down4, ShiftLeft<5>(packed9), hi3); + const VU16 raw5 = OrAnd(down5, ShiftLeft<5>(packedA), hi3); + + const VU16 raw6 = OrAnd(down6, ShiftLeft<2>(packed8), hi3); + const VU16 raw7 = OrAnd(down7, ShiftLeft<2>(packed9), hi3); + const VU16 raw8 = OrAnd(down8, ShiftLeft<2>(packedA), hi3); + + const VU16 raw9 = OrAnd(down9, ShiftRight<1>(packed8), hi3); + const VU16 rawA = OrAnd(downA, ShiftRight<1>(packed9), hi3); + const VU16 rawB = OrAnd(downB, ShiftRight<1>(packedA), hi3); + + const VU16 rawC = OrAnd(downC, ShiftRight<4>(packed8), hi3); + const VU16 rawD = OrAnd(downD, ShiftRight<4>(packed9), hi3); + const VU16 rawE = OrAnd(downE, ShiftRight<4>(packedA), hi3); + + // Shift MSB into the top 3-of-11 and mask. + const VU16 rawF = Or(downF, Xor3(And(ShiftRight<7>(packed8), hi3), + And(ShiftRight<6>(packed9), hi3), + And(ShiftRight<5>(packedA), hi3))); + + StoreU(raw0, d, raw + 0 * N); + StoreU(raw1, d, raw + 1 * N); + StoreU(raw2, d, raw + 2 * N); + StoreU(raw3, d, raw + 3 * N); + StoreU(raw4, d, raw + 4 * N); + StoreU(raw5, d, raw + 5 * N); + StoreU(raw6, d, raw + 6 * N); + StoreU(raw7, d, raw + 7 * N); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<11> + +template <> +struct Pack16<12> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 8 vectors, each with 12+4 bits; top 8 bits are concatenated into + // packed8 to packedB. + const VU16 packed0 = Or(ShiftLeft<12>(raw8), raw0); + const VU16 packed1 = Or(ShiftLeft<12>(raw9), raw1); + const VU16 packed2 = Or(ShiftLeft<12>(rawA), raw2); + const VU16 packed3 = Or(ShiftLeft<12>(rawB), raw3); + const VU16 packed4 = Or(ShiftLeft<12>(rawC), raw4); + const VU16 packed5 = Or(ShiftLeft<12>(rawD), raw5); + const VU16 packed6 = Or(ShiftLeft<12>(rawE), raw6); + const VU16 packed7 = Or(ShiftLeft<12>(rawF), raw7); + + // Masking after shifting left enables OrAnd. 
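+    // Sketch (illustrative): bits 0..3 of raw8..F already sit in bits 12..15
+    // of packed0..7. packed8 = (raw8 >> 4) | ((raw9 << 4) & 0xFF00) then holds
+    // the remaining upper 8 bits of raw8 (low byte) and raw9 (high byte), and
+    // packed9..B do the same for rawA..F. Total: 16 * 12 = 192 bits = 12
+    // packed vectors per lane position.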
+ const VU16 hi8 = Set(d, 0xFF00u); + const VU16 packed8 = OrAnd(ShiftRight<4>(raw8), ShiftLeft<4>(raw9), hi8); + const VU16 packed9 = OrAnd(ShiftRight<4>(rawA), ShiftLeft<4>(rawB), hi8); + const VU16 packedA = OrAnd(ShiftRight<4>(rawC), ShiftLeft<4>(rawD), hi8); + const VU16 packedB = OrAnd(ShiftRight<4>(rawE), ShiftLeft<4>(rawF), hi8); + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + + const VU16 mask = Set(d, 0xFFFu); // Lowest 12 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 mid8 = Set(d, 0xFF0u); // upper 8 in lower 12 + const VU16 raw8 = + OrAnd(ShiftRight<12>(packed0), ShiftLeft<4>(packed8), mid8); + const VU16 raw9 = + OrAnd(ShiftRight<12>(packed1), ShiftRight<4>(packed8), mid8); + const VU16 rawA = + OrAnd(ShiftRight<12>(packed2), ShiftLeft<4>(packed9), mid8); + const VU16 rawB = + OrAnd(ShiftRight<12>(packed3), ShiftRight<4>(packed9), mid8); + const VU16 rawC = + OrAnd(ShiftRight<12>(packed4), ShiftLeft<4>(packedA), mid8); + const VU16 rawD = + OrAnd(ShiftRight<12>(packed5), ShiftRight<4>(packedA), mid8); + const VU16 rawE = + OrAnd(ShiftRight<12>(packed6), ShiftLeft<4>(packedB), mid8); + const VU16 rawF = + OrAnd(ShiftRight<12>(packed7), ShiftRight<4>(packedB), mid8); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<12> + +template <> +struct Pack16<13> { + template + HWY_INLINE void Pack(D d, const uint16_t* 
HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // As with 11 bits, it is not obvious what the optimal partitioning looks + // like. We similarly go with an 8+5 split. + const VU16 lo8 = Set(d, 0xFFu); + + // Lower 8 bits of all raw + const VU16 packed0 = OrAnd(ShiftLeft<8>(raw1), raw0, lo8); + const VU16 packed1 = OrAnd(ShiftLeft<8>(raw3), raw2, lo8); + const VU16 packed2 = OrAnd(ShiftLeft<8>(raw5), raw4, lo8); + const VU16 packed3 = OrAnd(ShiftLeft<8>(raw7), raw6, lo8); + const VU16 packed4 = OrAnd(ShiftLeft<8>(raw9), raw8, lo8); + const VU16 packed5 = OrAnd(ShiftLeft<8>(rawB), rawA, lo8); + const VU16 packed6 = OrAnd(ShiftLeft<8>(rawD), rawC, lo8); + const VU16 packed7 = OrAnd(ShiftLeft<8>(rawF), rawE, lo8); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + + // Five vectors, three 5bit remnants each, plus one 5bit in their MSB. + const VU16 top0 = ShiftRight<8>(raw0); + const VU16 top1 = ShiftRight<8>(raw1); + const VU16 top2 = ShiftRight<8>(raw2); + const VU16 top3 = ShiftRight<8>(raw3); + const VU16 top4 = ShiftRight<8>(raw4); + + // Insert top raw bits into 5-bit groups within packed8..C. Moving the + // mask along avoids masking each of raw0..E and enables OrAnd. + VU16 next = Set(d, 0x3E0u); // 0x1F << 5 + VU16 packed8 = OrAnd(top0, ShiftRight<3>(raw5), next); + VU16 packed9 = OrAnd(top1, ShiftRight<3>(raw6), next); + VU16 packedA = OrAnd(top2, ShiftRight<3>(raw7), next); + VU16 packedB = OrAnd(top3, ShiftRight<3>(raw8), next); + VU16 packedC = OrAnd(top4, ShiftRight<3>(raw9), next); + next = ShiftLeft<5>(next); + packed8 = OrAnd(packed8, ShiftLeft<2>(rawA), next); + packed9 = OrAnd(packed9, ShiftLeft<2>(rawB), next); + packedA = OrAnd(packedA, ShiftLeft<2>(rawC), next); + packedB = OrAnd(packedB, ShiftLeft<2>(rawD), next); + packedC = OrAnd(packedC, ShiftLeft<2>(rawE), next); + + // Scatter upper 5 bits of rawF into the upper bits. 
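+    // At this point next is 0x7C00 (illustrative note): only the MSB may
+    // survive the merge below, because ShiftLeft<7>(rawF) etc. carry garbage
+    // below bit 15, so the mask must become exactly 0x8000; in 16-bit lanes,
+    // 0x7C00 << 5 wraps to 0x8000.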
+    next = ShiftLeft<5>(next);  // = 0x8000u
+    packed8 = OrAnd(packed8, ShiftLeft<7>(rawF), next);
+    packed9 = OrAnd(packed9, ShiftLeft<6>(rawF), next);
+    packedA = OrAnd(packedA, ShiftLeft<5>(rawF), next);
+    packedB = OrAnd(packedB, ShiftLeft<4>(rawF), next);
+    packedC = OrAnd(packedC, ShiftLeft<3>(rawF), next);
+
+    StoreU(packed8, d, packed_out + 8 * N);
+    StoreU(packed9, d, packed_out + 9 * N);
+    StoreU(packedA, d, packed_out + 0xA * N);
+    StoreU(packedB, d, packed_out + 0xB * N);
+    StoreU(packedC, d, packed_out + 0xC * N);
+  }
+
+  template <class D>
+  HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in,
+                         uint16_t* HWY_RESTRICT raw) const {
+    using VU16 = Vec<decltype(d)>;
+    const size_t N = Lanes(d);
+
+    const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N));
+    const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N));
+    const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N));
+    const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N));
+    const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N));
+    const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N));
+    const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N));
+    const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N));
+    const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N));
+    const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N));
+    const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N));
+    const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N));
+    const VU16 packedC = BitCast(d, LoadU(d, packed_in + 0xC * N));
+
+    const VU16 mask = Set(d, 0xFFu);  // Lowest 8 bits
+
+    const VU16 down0 = And(packed0, mask);
+    const VU16 down1 = ShiftRight<8>(packed0);
+    const VU16 down2 = And(packed1, mask);
+    const VU16 down3 = ShiftRight<8>(packed1);
+    const VU16 down4 = And(packed2, mask);
+    const VU16 down5 = ShiftRight<8>(packed2);
+    const VU16 down6 = And(packed3, mask);
+    const VU16 down7 = ShiftRight<8>(packed3);
+    const VU16 down8 = And(packed4, mask);
+    const VU16 down9 = ShiftRight<8>(packed4);
+    const VU16 downA = And(packed5, mask);
+    const VU16 downB = ShiftRight<8>(packed5);
+    const VU16 downC = And(packed6, mask);
+    const VU16 downD = ShiftRight<8>(packed6);
+    const VU16 downE = And(packed7, mask);
+    const VU16 downF = ShiftRight<8>(packed7);
+
+    // Upper five bits from packed8..C, eight bits from down0..F.
+    const VU16 hi5 = Set(d, 0x1F00u);
+    const VU16 raw0 = OrAnd(down0, ShiftLeft<8>(packed8), hi5);
+    const VU16 raw1 = OrAnd(down1, ShiftLeft<8>(packed9), hi5);
+    const VU16 raw2 = OrAnd(down2, ShiftLeft<8>(packedA), hi5);
+    const VU16 raw3 = OrAnd(down3, ShiftLeft<8>(packedB), hi5);
+    const VU16 raw4 = OrAnd(down4, ShiftLeft<8>(packedC), hi5);
+
+    const VU16 raw5 = OrAnd(down5, ShiftLeft<3>(packed8), hi5);
+    const VU16 raw6 = OrAnd(down6, ShiftLeft<3>(packed9), hi5);
+    const VU16 raw7 = OrAnd(down7, ShiftLeft<3>(packedA), hi5);
+    const VU16 raw8 = OrAnd(down8, ShiftLeft<3>(packedB), hi5);
+    const VU16 raw9 = OrAnd(down9, ShiftLeft<3>(packedC), hi5);
+
+    const VU16 rawA = OrAnd(downA, ShiftRight<2>(packed8), hi5);
+    const VU16 rawB = OrAnd(downB, ShiftRight<2>(packed9), hi5);
+    const VU16 rawC = OrAnd(downC, ShiftRight<2>(packedA), hi5);
+    const VU16 rawD = OrAnd(downD, ShiftRight<2>(packedB), hi5);
+    const VU16 rawE = OrAnd(downE, ShiftRight<2>(packedC), hi5);
+
+    // Shift MSB into the top 5-of-13 and mask.
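+    // rawF (illustrative): its low 8 bits are downF; its top 5 bits come from
+    // the MSBs of packed8..C, moved into bits 8..12 by right shifts of 7..3.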
+ const VU16 p0 = Xor3(And(ShiftRight<7>(packed8), hi5), // + And(ShiftRight<6>(packed9), hi5), + And(ShiftRight<5>(packedA), hi5)); + const VU16 p1 = Xor3(And(ShiftRight<4>(packedB), hi5), + And(ShiftRight<3>(packedC), hi5), downF); + const VU16 rawF = Or(p0, p1); + + StoreU(raw0, d, raw + 0 * N); + StoreU(raw1, d, raw + 1 * N); + StoreU(raw2, d, raw + 2 * N); + StoreU(raw3, d, raw + 3 * N); + StoreU(raw4, d, raw + 4 * N); + StoreU(raw5, d, raw + 5 * N); + StoreU(raw6, d, raw + 6 * N); + StoreU(raw7, d, raw + 7 * N); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<13> + +template <> +struct Pack16<14> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 14 vectors, each with 14+2 bits; two raw vectors are scattered + // across the upper 2 bits. 
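+    // Accounting check (illustrative): 16 values x 14 bits = 224 bits = 14
+    // packed vectors per lane position. The seven 2-bit pairs of rawE land in
+    // the top two bits of packed0..6, those of rawF in packed7..D.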
+ const VU16 hi2 = Set(d, 0xC000u); + const VU16 packed0 = Or(raw0, ShiftLeft<14>(rawE)); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<12>(rawE), hi2); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<10>(rawE), hi2); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<8>(rawE), hi2); + const VU16 packed4 = OrAnd(raw4, ShiftLeft<6>(rawE), hi2); + const VU16 packed5 = OrAnd(raw5, ShiftLeft<4>(rawE), hi2); + const VU16 packed6 = OrAnd(raw6, ShiftLeft<2>(rawE), hi2); + const VU16 packed7 = Or(raw7, ShiftLeft<14>(rawF)); + const VU16 packed8 = OrAnd(raw8, ShiftLeft<12>(rawF), hi2); + const VU16 packed9 = OrAnd(raw9, ShiftLeft<10>(rawF), hi2); + const VU16 packedA = OrAnd(rawA, ShiftLeft<8>(rawF), hi2); + const VU16 packedB = OrAnd(rawB, ShiftLeft<6>(rawF), hi2); + const VU16 packedC = OrAnd(rawC, ShiftLeft<4>(rawF), hi2); + const VU16 packedD = OrAnd(rawD, ShiftLeft<2>(rawF), hi2); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + StoreU(packedC, d, packed_out + 0xC * N); + StoreU(packedD, d, packed_out + 0xD * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 packedC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + const VU16 packedD = BitCast(d, LoadU(d, packed_in + 0xD * N)); + + const VU16 mask = Set(d, 0x3FFFu); // Lowest 14 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed8, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed9, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(packedA, mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(packedB, mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(packedC, mask); + 
StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(packedD, mask); + StoreU(rawD, d, raw + 0xD * N); + + // rawE is the concatenation of the top two bits in packed0..6. + const VU16 E0 = Xor3(ShiftRight<14>(packed0), // + ShiftRight<12>(AndNot(mask, packed1)), + ShiftRight<10>(AndNot(mask, packed2))); + const VU16 E1 = Xor3(ShiftRight<8>(AndNot(mask, packed3)), + ShiftRight<6>(AndNot(mask, packed4)), + ShiftRight<4>(AndNot(mask, packed5))); + const VU16 rawE = Xor3(ShiftRight<2>(AndNot(mask, packed6)), E0, E1); + const VU16 F0 = Xor3(ShiftRight<14>(AndNot(mask, packed7)), + ShiftRight<12>(AndNot(mask, packed8)), + ShiftRight<10>(AndNot(mask, packed9))); + const VU16 F1 = Xor3(ShiftRight<8>(AndNot(mask, packedA)), + ShiftRight<6>(AndNot(mask, packedB)), + ShiftRight<4>(AndNot(mask, packedC))); + const VU16 rawF = Xor3(ShiftRight<2>(AndNot(mask, packedD)), F0, F1); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<14> + +template <> +struct Pack16<15> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 15 vectors, each with 15+1 bits; one packed vector is scattered + // across the upper bit. 
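+    // Concretely (illustrative): it is rawF, the last raw vector, whose 15
+    // bits are spread one per MSB across packed0..E; 16 * 15 = 240 bits = 15
+    // packed vectors per lane position.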
+ const VU16 hi1 = Set(d, 0x8000u); + const VU16 packed0 = Or(raw0, ShiftLeft<15>(rawF)); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<14>(rawF), hi1); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<13>(rawF), hi1); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<12>(rawF), hi1); + const VU16 packed4 = OrAnd(raw4, ShiftLeft<11>(rawF), hi1); + const VU16 packed5 = OrAnd(raw5, ShiftLeft<10>(rawF), hi1); + const VU16 packed6 = OrAnd(raw6, ShiftLeft<9>(rawF), hi1); + const VU16 packed7 = OrAnd(raw7, ShiftLeft<8>(rawF), hi1); + const VU16 packed8 = OrAnd(raw8, ShiftLeft<7>(rawF), hi1); + const VU16 packed9 = OrAnd(raw9, ShiftLeft<6>(rawF), hi1); + const VU16 packedA = OrAnd(rawA, ShiftLeft<5>(rawF), hi1); + const VU16 packedB = OrAnd(rawB, ShiftLeft<4>(rawF), hi1); + const VU16 packedC = OrAnd(rawC, ShiftLeft<3>(rawF), hi1); + const VU16 packedD = OrAnd(rawD, ShiftLeft<2>(rawF), hi1); + const VU16 packedE = OrAnd(rawE, ShiftLeft<1>(rawF), hi1); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + StoreU(packedC, d, packed_out + 0xC * N); + StoreU(packedD, d, packed_out + 0xD * N); + StoreU(packedE, d, packed_out + 0xE * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 packedC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + const VU16 packedD = BitCast(d, LoadU(d, packed_in + 0xD * N)); + const VU16 packedE = BitCast(d, LoadU(d, packed_in + 0xE * N)); + + const VU16 mask = Set(d, 0x7FFFu); // Lowest 15 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed8, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed9, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA 
= And(packedA, mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(packedB, mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(packedC, mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(packedD, mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(packedE, mask); + StoreU(rawE, d, raw + 0xE * N); + + // rawF is the concatenation of the top bit in packed0..E. + const VU16 F0 = Xor3(ShiftRight<15>(packed0), // + ShiftRight<14>(AndNot(mask, packed1)), + ShiftRight<13>(AndNot(mask, packed2))); + const VU16 F1 = Xor3(ShiftRight<12>(AndNot(mask, packed3)), + ShiftRight<11>(AndNot(mask, packed4)), + ShiftRight<10>(AndNot(mask, packed5))); + const VU16 F2 = Xor3(ShiftRight<9>(AndNot(mask, packed6)), + ShiftRight<8>(AndNot(mask, packed7)), + ShiftRight<7>(AndNot(mask, packed8))); + const VU16 F3 = Xor3(ShiftRight<6>(AndNot(mask, packed9)), + ShiftRight<5>(AndNot(mask, packedA)), + ShiftRight<4>(AndNot(mask, packedB))); + const VU16 F4 = Xor3(ShiftRight<3>(AndNot(mask, packedC)), + ShiftRight<2>(AndNot(mask, packedD)), + ShiftRight<1>(AndNot(mask, packedE))); + const VU16 rawF = Xor3(F0, F1, Xor3(F2, F3, F4)); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<15> + +template <> +struct Pack16<16> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + StoreU(raw0, d, packed_out + 0 * N); + StoreU(raw1, d, packed_out + 1 * N); + StoreU(raw2, d, packed_out + 2 * N); + StoreU(raw3, d, packed_out + 3 * N); + StoreU(raw4, d, packed_out + 4 * N); + StoreU(raw5, d, packed_out + 5 * N); + StoreU(raw6, d, packed_out + 6 * N); + StoreU(raw7, d, packed_out + 7 * N); + StoreU(raw8, d, packed_out + 8 * N); + StoreU(raw9, d, packed_out + 9 * N); + StoreU(rawA, d, packed_out + 0xA * N); + StoreU(rawB, d, packed_out + 0xB * N); + StoreU(rawC, d, packed_out + 0xC * N); + StoreU(rawD, d, packed_out + 0xD * N); + StoreU(rawE, d, packed_out + 0xE * N); + StoreU(rawF, d, packed_out + 0xF * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 raw0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 raw1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 raw2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 raw3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 raw4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 raw5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 raw6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 raw7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 raw8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 raw9 = BitCast(d, LoadU(d, packed_in + 
9 * N)); + const VU16 rawA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 rawB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 rawC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + const VU16 rawD = BitCast(d, LoadU(d, packed_in + 0xD * N)); + const VU16 rawE = BitCast(d, LoadU(d, packed_in + 0xE * N)); + const VU16 rawF = BitCast(d, LoadU(d, packed_in + 0xF * N)); + + StoreU(raw0, d, raw + 0 * N); + StoreU(raw1, d, raw + 1 * N); + StoreU(raw2, d, raw + 2 * N); + StoreU(raw3, d, raw + 3 * N); + StoreU(raw4, d, raw + 4 * N); + StoreU(raw5, d, raw + 5 * N); + StoreU(raw6, d, raw + 6 * N); + StoreU(raw7, d, raw + 7 * N); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<16> + +// The supported packing types for 32/64 bits. +enum BlockPackingType { + // Simple fixed bit-packing. + kBitPacked, + // Bit packing after subtracting a `frame of reference` value from input. + kFoRBitPacked, +}; + +namespace detail { + +// Generates the implementation for bit-packing/un-packing `T` type numbers +// where each number takes `kBits` bits. +// `S` is the remainder bits left from the previous bit-packed block. +// `kLoadPos` is the offset from which the next vector block should be loaded. +// `kStorePos` is the offset into which the next vector block should be stored. +// `BlockPackingType` is the type of packing/unpacking for this block. +template +struct BitPackUnroller { + static constexpr size_t B = sizeof(T) * 8; + + template + static inline void Pack(D d, const T* HWY_RESTRICT raw, + T* HWY_RESTRICT packed_out, const V& mask, + const V& frame_of_reference, V& in, V& out) { + // Avoid compilation errors and unnecessary template instantiation if + // compiling in C++11 or C++14 mode + using NextUnroller = BitPackUnroller< + T, kBits, ((S <= B) ? (S + ((S < B) ? kBits : 0)) : (S % B)), + kLoadPos + static_cast(S < B), + kStorePos + static_cast(S > B), block_packing_type>; + + (void)raw; + (void)mask; + (void)in; + + const size_t N = Lanes(d); + HWY_IF_CONSTEXPR(S >= B) { + StoreU(out, d, packed_out + kStorePos * N); + HWY_IF_CONSTEXPR(S == B) { return; } + HWY_IF_CONSTEXPR(S != B) { + constexpr size_t shr_amount = (kBits - S % B) % B; + out = ShiftRight(in); + // NextUnroller is a typedef for + // Unroller if S > B is true + return NextUnroller::Pack(d, raw, packed_out, mask, frame_of_reference, + in, out); + } + } + HWY_IF_CONSTEXPR(S < B) { + HWY_IF_CONSTEXPR(block_packing_type == BlockPackingType::kBitPacked) { + in = LoadU(d, raw + kLoadPos * N); + } + HWY_IF_CONSTEXPR(block_packing_type == BlockPackingType::kFoRBitPacked) { + in = Sub(LoadU(d, raw + kLoadPos * N), frame_of_reference); + } + // Optimize for the case when `S` is zero. + // We can skip `Or` + ShiftLeft` to align `in`. 
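+      // Illustrative: with S remnant bits already in `out`, the fresh `in`
+      // contributes its kBits starting at bit S; when S == 0 the output word
+      // is still empty, so `in` can simply be copied.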
+ HWY_IF_CONSTEXPR(S == 0) { out = in; } + HWY_IF_CONSTEXPR(S != 0) { out = Or(out, ShiftLeft(in)); } + // NextUnroller is a typedef for + // Unroller if S < B is true + return NextUnroller::Pack(d, raw, packed_out, mask, frame_of_reference, + in, out); + } + } + + template + static inline void Unpack(D d, const T* HWY_RESTRICT packed_in, + T* HWY_RESTRICT raw, const V& mask, + const V& frame_of_reference, V& in, V& out) { + // Avoid compilation errors and unnecessary template instantiation if + // compiling in C++11 or C++14 mode + using NextUnroller = BitPackUnroller< + T, kBits, ((S <= B) ? (S + ((S < B) ? kBits : 0)) : (S % B)), + kLoadPos + static_cast(S > B), + kStorePos + static_cast(S < B), block_packing_type>; + + (void)packed_in; + (void)mask; + (void)in; + + const size_t N = Lanes(d); + HWY_IF_CONSTEXPR(S >= B) { + HWY_IF_CONSTEXPR(S == B) { + V bitpacked_output = out; + HWY_IF_CONSTEXPR(block_packing_type == + BlockPackingType::kFoRBitPacked) { + bitpacked_output = Add(bitpacked_output, frame_of_reference); + } + StoreU(bitpacked_output, d, raw + kStorePos * N); + return; + } + HWY_IF_CONSTEXPR(S != B) { + in = LoadU(d, packed_in + kLoadPos * N); + constexpr size_t shl_amount = (kBits - S % B) % B; + out = And(Or(out, ShiftLeft(in)), mask); + // NextUnroller is a typedef for + // Unroller if S > B is true + return NextUnroller::Unpack(d, packed_in, raw, mask, frame_of_reference, + in, out); + } + } + HWY_IF_CONSTEXPR(S < B) { + V bitpacked_output = out; + HWY_IF_CONSTEXPR(block_packing_type == BlockPackingType::kFoRBitPacked) { + bitpacked_output = Add(bitpacked_output, frame_of_reference); + } + StoreU(bitpacked_output, d, raw + kStorePos * N); + HWY_IF_CONSTEXPR(S + kBits < B) { + // Optimize for the case when `S` is zero. + // We can skip the `ShiftRight` to align `in`. + HWY_IF_CONSTEXPR(S == 0) { out = And(in, mask); } + HWY_IF_CONSTEXPR(S != 0) { out = And(ShiftRight(in), mask); } + } + HWY_IF_CONSTEXPR(S + kBits >= B) { out = ShiftRight(in); } + // NextUnroller is a typedef for + // Unroller if S < B is true + return NextUnroller::Unpack(d, packed_in, raw, mask, frame_of_reference, + in, out); + } + } +}; + +// Computes the highest power of two that divides `kBits`. 
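+// Example (illustrative): for kBits = 12 and uint32_t, NumLoops() = 4, so each
+// of the 4 loop iterations packs 32 / 4 = 8 raw vectors into 12 / 4 = 3 packed
+// vectors; 4 * 8 * 12 bits = 4 * 3 * 32 bits, so the accounting matches.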
+template +constexpr size_t NumLoops() { + return (kBits & ~(kBits - 1)); +} + +template +constexpr size_t PackedIncr() { + return kBits / NumLoops(); +} + +template +constexpr size_t UnpackedIncr() { + return (sizeof(T) * 8) / NumLoops(); +} + +template +constexpr uint32_t MaskBits32() { + return static_cast((1ull << kBits) - 1); +} + +template +constexpr uint64_t MaskBits64() { + return (uint64_t{1} << kBits) - 1; +} +template <> +constexpr uint64_t MaskBits64<64>() { + return ~uint64_t{0}; +} + +} // namespace detail + +template // <= 32 +struct Pack32 { + template + HWY_INLINE void Pack(D d, const uint32_t* HWY_RESTRICT raw, + uint32_t* HWY_RESTRICT packed_out, + const uint32_t frame_of_reference_value = 0) const { + using V = VFromD; + const V mask = Set(d, detail::MaskBits32()); + const V frame_of_reference = Set(d, frame_of_reference_value); + for (size_t i = 0; i < detail::NumLoops(); ++i) { + V in = Zero(d); + V out = Zero(d); + detail::BitPackUnroller::Pack(d, raw, packed_out, + mask, + frame_of_reference, in, + out); + raw += detail::UnpackedIncr() * Lanes(d); + packed_out += detail::PackedIncr() * Lanes(d); + } + } + + template + HWY_INLINE void Unpack(D d, const uint32_t* HWY_RESTRICT packed_in, + uint32_t* HWY_RESTRICT raw, + const uint32_t frame_of_reference_value = 0) const { + using V = VFromD; + const V mask = Set(d, detail::MaskBits32()); + const V frame_of_reference = Set(d, frame_of_reference_value); + for (size_t i = 0; i < detail::NumLoops(); ++i) { + V in = LoadU(d, packed_in + 0 * Lanes(d)); + V out = And(in, mask); + detail::BitPackUnroller::Unpack(d, packed_in, raw, + mask, + frame_of_reference, + in, out); + raw += detail::UnpackedIncr() * Lanes(d); + packed_in += detail::PackedIncr() * Lanes(d); + } + } +}; + +template // <= 64 +struct Pack64 { + template + HWY_INLINE void Pack(D d, const uint64_t* HWY_RESTRICT raw, + uint64_t* HWY_RESTRICT packed_out, + const uint64_t frame_of_reference_value = 0) const { + using V = VFromD; + const V mask = Set(d, detail::MaskBits64()); + const V frame_of_reference = Set(d, frame_of_reference_value); + for (size_t i = 0; i < detail::NumLoops(); ++i) { + V in = Zero(d); + V out = Zero(d); + detail::BitPackUnroller::Pack(d, raw, packed_out, + mask, + frame_of_reference, in, + out); + raw += detail::UnpackedIncr() * Lanes(d); + packed_out += detail::PackedIncr() * Lanes(d); + } + } + + template + HWY_INLINE void Unpack(D d, const uint64_t* HWY_RESTRICT packed_in, + uint64_t* HWY_RESTRICT raw, + const uint64_t frame_of_reference_value = 0) const { + using V = VFromD; + const V mask = Set(d, detail::MaskBits64()); + const V frame_of_reference = Set(d, frame_of_reference_value); + for (size_t i = 0; i < detail::NumLoops(); ++i) { + V in = LoadU(d, packed_in + 0 * Lanes(d)); + V out = And(in, mask); + detail::BitPackUnroller::Unpack(d, packed_in, raw, + mask, + frame_of_reference, + in, out); + raw += detail::UnpackedIncr() * Lanes(d); + packed_in += detail::PackedIncr() * Lanes(d); + } + } +}; + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/dot/dot-inl.h b/third_party/aom/third_party/highway/hwy/contrib/dot/dot-inl.h new file mode 100644 index 000000000000..4349629f6b5c --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/dot/dot-inl.h @@ -0,0 +1,460 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// clang-format off +#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +// clang-format on +#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#endif + +#include +#include + +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// NOTE: the D argument describes the inputs, not the output, because both +// f32/f32, bf16/bf16, and f32/bf16 inputs accumulate to f32. +struct Dot { + // Specify zero or more of these, ORed together, as the kAssumptions template + // argument to Compute. Each one may improve performance or reduce code size, + // at the cost of additional requirements on the arguments. + enum Assumptions { + // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T). + kAtLeastOneVector = 1, + // num_elements is divisible by N (a power of two, so this can be used if + // the problem size is known to be a power of two >= HWY_MAX_BYTES / + // sizeof(T)). + kMultipleOfVector = 2, + // RoundUpTo(num_elements, N) elements are accessible; their value does not + // matter (will be treated as if they were zero). + kPaddedToVector = 4, + }; + + // Returns sum{pa[i] * pb[i]} for floating-point inputs, including float16_t + // and double if HWY_HAVE_FLOAT16/64. Aligning the + // pointers to a multiple of N elements is helpful but not required. + template > + static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa, + const T* const HWY_RESTRICT pb, + const size_t num_elements) { + static_assert(IsFloat(), "MulAdd requires float type"); + using V = decltype(Zero(d)); + + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + // Only 2x unroll to avoid excessive code size. + T sum0 = ConvertScalarTo(0); + T sum1 = ConvertScalarTo(0); + for (; i + 2 <= num_elements; i += 2) { + // For reasons unknown, fp16 += does not compile on clang (Arm). + sum0 = ConvertScalarTo(sum0 + pa[i + 0] * pb[i + 0]); + sum1 = ConvertScalarTo(sum1 + pa[i + 1] * pb[i + 1]); + } + if (i < num_elements) { + sum1 = ConvertScalarTo(sum1 + pa[i] * pb[i]); + } + return ConvertScalarTo(sum0 + sum1); + } + + // Compiler doesn't make independent sum* accumulators, so unroll manually. + // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive + // for unaligned inputs (each unaligned pointer halves the throughput + // because it occupies both L1 load ports for a cycle). 
We cannot have + // arrays of vectors on RVV/SVE, so always unroll 4x. + V sum0 = Zero(d); + V sum1 = Zero(d); + V sum2 = Zero(d); + V sum3 = Zero(d); + + // Main loop: unrolled + for (; i + 4 * N <= num_elements; /* i += 4 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = MulAdd(a0, b0, sum0); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum1 = MulAdd(a1, b1, sum1); + const auto a2 = LoadU(d, pa + i); + const auto b2 = LoadU(d, pb + i); + i += N; + sum2 = MulAdd(a2, b2, sum2); + const auto a3 = LoadU(d, pa + i); + const auto b3 = LoadU(d, pb + i); + i += N; + sum3 = MulAdd(a3, b3, sum3); + } + + // Up to 3 iterations of whole vectors + for (; i + N <= num_elements; i += N) { + const auto a = LoadU(d, pa + i); + const auto b = LoadU(d, pb + i); + sum0 = MulAdd(a, b, sum0); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(d, remaining); + const auto a = LoadU(d, pa + i); + const auto b = LoadU(d, pb + i); + sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1); + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(d, N - remaining); + const auto a = LoadU(d, pa + i); // always unaligned + const auto b = LoadU(d, pb + i); + sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return ReduceSum(d, sum0); + } + + // f32 * bf16 + template + static HWY_INLINE float Compute(const DF df, + const float* const HWY_RESTRICT pa, + const hwy::bfloat16_t* const HWY_RESTRICT pb, + const size_t num_elements) { +#if HWY_TARGET == HWY_SCALAR + const Rebind dbf; +#else + const Repartition dbf; + using VBF = decltype(Zero(dbf)); +#endif + const Half dbfh; + using VF = decltype(Zero(df)); + + HWY_LANES_CONSTEXPR size_t NF = Lanes(df); + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < NF)) { + // Only 2x unroll to avoid excessive code size. + float sum0 = 0.0f; + float sum1 = 0.0f; + size_t i = 0; + for (; i + 2 <= num_elements; i += 2) { + sum0 += pa[i + 0] * ConvertScalarTo(pb[i + 0]); + sum1 += pa[i + 1] * ConvertScalarTo(pb[i + 1]); + } + for (; i < num_elements; ++i) { + sum1 += pa[i] * ConvertScalarTo(pb[i]); + } + return sum0 + sum1; + } + + // Compiler doesn't make independent sum* accumulators, so unroll manually. + // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive + // for unaligned inputs (each unaligned pointer halves the throughput + // because it occupies both L1 load ports for a cycle). We cannot have + // arrays of vectors on RVV/SVE, so always unroll 4x. 
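+    // Note (illustrative): on non-scalar targets one bf16 vector spans 2 * NF
+    // floats, so each bf16 load below feeds two FMAs via PromoteLowerTo /
+    // PromoteUpperTo, which is why `i` advances by NF twice per bf16 load.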
+ VF sum0 = Zero(df); + VF sum1 = Zero(df); + VF sum2 = Zero(df); + VF sum3 = Zero(df); + + size_t i = 0; + +#if HWY_TARGET != HWY_SCALAR // PromoteUpperTo supported + // Main loop: unrolled + for (; i + 4 * NF <= num_elements; /* i += 4 * N */) { // incr in loop + const VF a0 = LoadU(df, pa + i); + const VBF b0 = LoadU(dbf, pb + i); + i += NF; + sum0 = MulAdd(a0, PromoteLowerTo(df, b0), sum0); + const VF a1 = LoadU(df, pa + i); + i += NF; + sum1 = MulAdd(a1, PromoteUpperTo(df, b0), sum1); + const VF a2 = LoadU(df, pa + i); + const VBF b2 = LoadU(dbf, pb + i); + i += NF; + sum2 = MulAdd(a2, PromoteLowerTo(df, b2), sum2); + const VF a3 = LoadU(df, pa + i); + i += NF; + sum3 = MulAdd(a3, PromoteUpperTo(df, b2), sum3); + } +#endif // HWY_TARGET == HWY_SCALAR + + // Up to 3 iterations of whole vectors + for (; i + NF <= num_elements; i += NF) { + const VF a = LoadU(df, pa + i); + const VF b = PromoteTo(df, LoadU(dbfh, pb + i)); + sum0 = MulAdd(a, b, sum0); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(df, remaining); + const VF a = LoadU(df, pa + i); + const VF b = PromoteTo(df, LoadU(dbfh, pb + i)); + sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1); + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= NF); + i += remaining - NF; + const auto skip = FirstN(df, NF - remaining); + const VF a = LoadU(df, pa + i); // always unaligned + const VF b = PromoteTo(df, LoadU(dbfh, pb + i)); + sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return ReduceSum(df, sum0); + } + + // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a + // multiple of N elements is helpful but not required. + template + static HWY_INLINE float Compute(const D d, + const bfloat16_t* const HWY_RESTRICT pa, + const bfloat16_t* const HWY_RESTRICT pb, + const size_t num_elements) { + const RebindToUnsigned du16; + const Repartition df32; + + using V = decltype(Zero(df32)); + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + float sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for.. + float sum1 = 0.0f; // this unlikely(?) case. + for (; i + 2 <= num_elements; i += 2) { + sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]); + sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]); + } + if (i < num_elements) { + sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); + } + return sum0 + sum1; + } + + // See comment in the other Compute() overload. Unroll 2x, but we need + // twice as many sums for ReorderWidenMulAccumulate. 
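// ReorderWidenMulAccumulate widens the bf16 products and spreads them across
// two f32 accumulators (the sum0/sum1 and sum2/sum3 pairs below), so the
// 2x-unrolled loop carries four partial sums that are only combined in the
// final reduction tree.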
+ V sum0 = Zero(df32); + V sum1 = Zero(df32); + V sum2 = Zero(df32); + V sum3 = Zero(df32); + + // Main loop: unrolled + for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3); + } + + // Possibly one more iteration of whole vectors + if (i + N <= num_elements) { + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(du16, remaining); + const auto va = LoadU(d, pa + i); + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3); + + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(du16, N - remaining); + const auto va = LoadU(d, pa + i); // always unaligned + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return ReduceSum(df32, sum0); + } + + // Returns sum{i32(pa[i]) * i32(pb[i])} for i16 inputs. Aligning the pointers + // to a multiple of N elements is helpful but not required. + template + static HWY_INLINE int32_t Compute(const D d, + const int16_t* const HWY_RESTRICT pa, + const int16_t* const HWY_RESTRICT pb, + const size_t num_elements) { + const RebindToUnsigned du16; + const RepartitionToWide di32; + + using VI32 = Vec; + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + int32_t sum0 = 0; // Only 2x unroll to avoid excessive code size for.. + int32_t sum1 = 0; // this unlikely(?) case. + for (; i + 2 <= num_elements; i += 2) { + sum0 += int32_t{pa[i + 0]} * int32_t{pb[i + 0]}; + sum1 += int32_t{pa[i + 1]} * int32_t{pb[i + 1]}; + } + if (i < num_elements) { + sum1 += int32_t{pa[i]} * int32_t{pb[i]}; + } + return sum0 + sum1; + } + + // See comment in the other Compute() overload. Unroll 2x, but we need + // twice as many sums for ReorderWidenMulAccumulate. 
+ VI32 sum0 = Zero(di32); + VI32 sum1 = Zero(di32); + VI32 sum2 = Zero(di32); + VI32 sum3 = Zero(di32); + + // Main loop: unrolled + for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(di32, a0, b0, sum0, sum1); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum2 = ReorderWidenMulAccumulate(di32, a1, b1, sum2, sum3); + } + + // Possibly one more iteration of whole vectors + if (i + N <= num_elements) { + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(di32, a0, b0, sum0, sum1); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(du16, remaining); + const auto va = LoadU(d, pa + i); + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(di32, a16, b16, sum2, sum3); + + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(du16, N - remaining); + const auto va = LoadU(d, pa + i); // always unaligned + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(di32, a16, b16, sum2, sum3); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return ReduceSum(di32, sum0); + } +}; + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/image/image.h b/third_party/aom/third_party/highway/hwy/contrib/image/image.h new file mode 100644 index 000000000000..7316c762de3e --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/image/image.h @@ -0,0 +1,467 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ +#define HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ + +// SIMD/multicore-friendly planar image representation with row accessors. 
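Before the image utilities' declarations begin, here is a minimal, hedged sketch of calling the Dot helper defined earlier in this patch from a single-target translation unit. The wrapper name DotF32 is illustrative, the include paths follow the vendored layout used in this patch, and dynamic dispatch via foreach_target.h is omitted:

    #include "third_party/highway/hwy/highway.h"
    #include "third_party/highway/hwy/contrib/dot/dot-inl.h"

    namespace hn = hwy::HWY_NAMESPACE;

    // No guarantees about n or padding, so kAssumptions = 0 and the scalar or
    // masked tail paths above may run. Passing e.g. Dot::kPaddedToVector is
    // appropriate only when RoundUpTo(n, Lanes(d)) elements are readable.
    float DotF32(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
                 size_t n) {
      const hn::ScalableTag<float> d;
      return hn::Dot::Compute<0>(d, a, b, n);
    }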
+ +#include + +#include // std::move + +#include "third_party/highway/hwy/aligned_allocator.h" +#include "third_party/highway/hwy/base.h" + +namespace hwy { + +// Type-independent parts of Image<> - reduces code duplication and facilitates +// moving member function implementations to cc file. +struct HWY_CONTRIB_DLLEXPORT ImageBase { + // Returns required alignment in bytes for externally allocated memory. + static size_t VectorSize(); + + // Returns distance [bytes] between the start of two consecutive rows, a + // multiple of VectorSize but NOT kAlias (see implementation). + static size_t BytesPerRow(size_t xsize, size_t sizeof_t); + + // No allocation (for output params or unused images) + ImageBase() + : xsize_(0), + ysize_(0), + bytes_per_row_(0), + bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) {} + + // Allocates memory (this is the common case) + ImageBase(size_t xsize, size_t ysize, size_t sizeof_t); + + // References but does not take ownership of external memory. Useful for + // interoperability with other libraries. `aligned` must be aligned to a + // multiple of VectorSize() and `bytes_per_row` must also be a multiple of + // VectorSize() or preferably equal to BytesPerRow(). + ImageBase(size_t xsize, size_t ysize, size_t bytes_per_row, void* aligned); + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo() instead. + ImageBase(const ImageBase& other) = delete; + ImageBase& operator=(const ImageBase& other) = delete; + + // Move constructor (required for returning Image from function) + ImageBase(ImageBase&& other) noexcept = default; + + // Move assignment (required for std::vector) + ImageBase& operator=(ImageBase&& other) noexcept = default; + + void Swap(ImageBase& other); + + // Useful for pre-allocating image with some padding for alignment purposes + // and later reporting the actual valid dimensions. Caller is responsible + // for ensuring xsize/ysize are <= the original dimensions. + void ShrinkTo(const size_t xsize, const size_t ysize) { + xsize_ = static_cast(xsize); + ysize_ = static_cast(ysize); + // NOTE: we can't recompute bytes_per_row for more compact storage and + // better locality because that would invalidate the image contents. + } + + // How many pixels. + HWY_INLINE size_t xsize() const { return xsize_; } + HWY_INLINE size_t ysize() const { return ysize_; } + + // NOTE: do not use this for copying rows - the valid xsize may be much less. + HWY_INLINE size_t bytes_per_row() const { return bytes_per_row_; } + + // Raw access to byte contents, for interfacing with other libraries. + // Unsigned char instead of char to avoid surprises (sign extension). + HWY_INLINE uint8_t* bytes() { + void* p = bytes_.get(); + return static_cast(HWY_ASSUME_ALIGNED(p, 64)); + } + HWY_INLINE const uint8_t* bytes() const { + const void* p = bytes_.get(); + return static_cast(HWY_ASSUME_ALIGNED(p, 64)); + } + + protected: + // Returns pointer to the start of a row. + HWY_INLINE void* VoidRow(const size_t y) const { +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + if (y >= ysize_) { + HWY_ABORT("Row(%d) >= %u\n", static_cast(y), ysize_); + } +#endif + + void* row = bytes_.get() + y * bytes_per_row_; + return HWY_ASSUME_ALIGNED(row, 64); + } + + enum class Padding { + // Allow Load(d, row + x) for x = 0; x < xsize(); x += Lanes(d). Default. + kRoundUp, + // Allow LoadU(d, row + x) for x <= xsize() - 1. This requires an extra + // vector to be initialized. 
If done by default, this would suppress + // legitimate msan warnings. We therefore require users to explicitly call + // InitializePadding before using unaligned loads (e.g. convolution). + kUnaligned + }; + + // Initializes the minimum bytes required to suppress msan warnings from + // legitimate (according to Padding mode) vector loads/stores on the right + // border, where some lanes are uninitialized and assumed to be unused. + void InitializePadding(size_t sizeof_t, Padding padding); + + // (Members are non-const to enable assignment during move-assignment.) + uint32_t xsize_; // In valid pixels, not including any padding. + uint32_t ysize_; + size_t bytes_per_row_; // Includes padding. + AlignedFreeUniquePtr bytes_; +}; + +// Single channel, aligned rows separated by padding. T must be POD. +// +// 'Single channel' (one 2D array per channel) simplifies vectorization +// (repeating the same operation on multiple adjacent components) without the +// complexity of a hybrid layout (8 R, 8 G, 8 B, ...). In particular, clients +// can easily iterate over all components in a row and Image requires no +// knowledge of the pixel format beyond the component type "T". +// +// 'Aligned' means each row is aligned to the L1 cache line size. This prevents +// false sharing between two threads operating on adjacent rows. +// +// 'Padding' is still relevant because vectors could potentially be larger than +// a cache line. By rounding up row sizes to the vector size, we allow +// reading/writing ALIGNED vectors whose first lane is a valid sample. This +// avoids needing a separate loop to handle remaining unaligned lanes. +// +// This image layout could also be achieved with a vector and a row accessor +// function, but a class wrapper with support for "deleter" allows wrapping +// existing memory allocated by clients without copying the pixels. It also +// provides convenient accessors for xsize/ysize, which shortens function +// argument lists. Supports move-construction so it can be stored in containers. +template +class Image : public ImageBase { + public: + using T = ComponentType; + + Image() = default; + Image(const size_t xsize, const size_t ysize) + : ImageBase(xsize, ysize, sizeof(T)) {} + Image(const size_t xsize, const size_t ysize, size_t bytes_per_row, + void* aligned) + : ImageBase(xsize, ysize, bytes_per_row, aligned) {} + + void InitializePaddingForUnalignedAccesses() { + InitializePadding(sizeof(T), Padding::kUnaligned); + } + + HWY_INLINE const T* ConstRow(const size_t y) const { + return static_cast(VoidRow(y)); + } + HWY_INLINE const T* ConstRow(const size_t y) { + return static_cast(VoidRow(y)); + } + + // Returns pointer to non-const. This allows passing const Image* parameters + // when the callee is only supposed to fill the pixels, as opposed to + // allocating or resizing the image. + HWY_INLINE T* MutableRow(const size_t y) const { + return static_cast(VoidRow(y)); + } + HWY_INLINE T* MutableRow(const size_t y) { + return static_cast(VoidRow(y)); + } + + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must + // NOT be used to determine xsize. + HWY_INLINE intptr_t PixelsPerRow() const { + return static_cast(bytes_per_row_ / sizeof(T)); + } +}; + +using ImageF = Image; + +// A bundle of 3 same-sized images. To fill an existing Image3 using +// single-channel producers, we also need access to each const Image*. 
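For orientation, a hedged sketch of allocating and row-wise filling the single-plane Image defined above; FillGradient is an illustrative name, and the full-vector stores rely on the row padding described in the layout comment. The Image3 description continues after the sketch.

    #include "third_party/highway/hwy/highway.h"
    #include "third_party/highway/hwy/contrib/image/image.h"

    namespace hn = hwy::HWY_NAMESPACE;

    void FillGradient(hwy::ImageF& img) {
      const hn::ScalableTag<float> d;
      for (size_t y = 0; y < img.ysize(); ++y) {
        float* HWY_RESTRICT row = img.MutableRow(y);
        // Rows are rounded up to the vector size, so vector stores starting at
        // a valid x (stepping by Lanes) stay inside the allocation even when
        // xsize() is not a multiple of Lanes(d).
        for (size_t x = 0; x < img.xsize(); x += hn::Lanes(d)) {
          hn::Store(hn::Iota(d, static_cast<float>(x)), d, row + x);
        }
      }
    }

Per the Padding notes above, InitializePaddingForUnalignedAccesses() is only needed when unaligned loads may read past xsize(), e.g. in convolution kernels.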
Const +// prevents breaking the same-size invariant, while still allowing pixels to be +// changed via MutableRow. +template +class Image3 { + public: + using T = ComponentType; + using ImageT = Image; + static constexpr size_t kNumPlanes = 3; + + Image3() : planes_{ImageT(), ImageT(), ImageT()} {} + + Image3(const size_t xsize, const size_t ysize) + : planes_{ImageT(xsize, ysize), ImageT(xsize, ysize), + ImageT(xsize, ysize)} {} + + Image3(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + } + + Image3(ImageT&& plane0, ImageT&& plane1, ImageT&& plane2) { + if (!SameSize(plane0, plane1) || !SameSize(plane0, plane2)) { + HWY_ABORT( + "Not same size: %d x %d, %d x %d, %d x %d\n", + static_cast(plane0.xsize()), static_cast(plane0.ysize()), + static_cast(plane1.xsize()), static_cast(plane1.ysize()), + static_cast(plane2.xsize()), static_cast(plane2.ysize())); + } + planes_[0] = std::move(plane0); + planes_[1] = std::move(plane1); + planes_[2] = std::move(plane2); + } + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo instead. + Image3(const Image3& other) = delete; + Image3& operator=(const Image3& other) = delete; + + Image3& operator=(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + return *this; + } + + HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) const { + return static_cast(VoidPlaneRow(c, y)); + } + HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) { + return static_cast(VoidPlaneRow(c, y)); + } + + HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) const { + return static_cast(VoidPlaneRow(c, y)); + } + HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) { + return static_cast(VoidPlaneRow(c, y)); + } + + HWY_INLINE const ImageT& Plane(size_t idx) const { return planes_[idx]; } + + void Swap(Image3& other) { + for (size_t c = 0; c < 3; ++c) { + other.planes_[c].Swap(planes_[c]); + } + } + + void ShrinkTo(const size_t xsize, const size_t ysize) { + for (ImageT& plane : planes_) { + plane.ShrinkTo(xsize, ysize); + } + } + + // Sizes of all three images are guaranteed to be equal. + HWY_INLINE size_t xsize() const { return planes_[0].xsize(); } + HWY_INLINE size_t ysize() const { return planes_[0].ysize(); } + // Returns offset [bytes] from one row to the next row of the same plane. + // WARNING: this must NOT be used to determine xsize, nor for copying rows - + // the valid xsize may be much less. + HWY_INLINE size_t bytes_per_row() const { return planes_[0].bytes_per_row(); } + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must NOT be used + // to determine xsize. + HWY_INLINE intptr_t PixelsPerRow() const { return planes_[0].PixelsPerRow(); } + + private: + // Returns pointer to the start of a row. + HWY_INLINE void* VoidPlaneRow(const size_t c, const size_t y) const { +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + if (c >= kNumPlanes || y >= ysize()) { + HWY_ABORT("PlaneRow(%d, %d) >= %d\n", static_cast(c), + static_cast(y), static_cast(ysize())); + } +#endif + // Use the first plane's stride because the compiler might not realize they + // are all equal. Thus we only need a single multiplication for all planes. 
+ const size_t row_offset = y * planes_[0].bytes_per_row(); + const void* row = planes_[c].bytes() + row_offset; + return static_cast( + HWY_ASSUME_ALIGNED(row, HWY_ALIGNMENT)); + } + + private: + ImageT planes_[kNumPlanes]; +}; + +using Image3F = Image3; + +// Rectangular region in image(s). Factoring this out of Image instead of +// shifting the pointer by x0/y0 allows this to apply to multiple images with +// different resolutions. Can compare size via SameSize(rect1, rect2). +class Rect { + public: + // Most windows are xsize_max * ysize_max, except those on the borders where + // begin + size_max > end. + constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize_max, + size_t ysize_max, size_t xend, size_t yend) + : x0_(xbegin), + y0_(ybegin), + xsize_(ClampedSize(xbegin, xsize_max, xend)), + ysize_(ClampedSize(ybegin, ysize_max, yend)) {} + + // Construct with origin and known size (typically from another Rect). + constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize, size_t ysize) + : x0_(xbegin), y0_(ybegin), xsize_(xsize), ysize_(ysize) {} + + // Construct a rect that covers a whole image. + template + explicit Rect(const Image& image) + : Rect(0, 0, image.xsize(), image.ysize()) {} + + Rect() : Rect(0, 0, 0, 0) {} + + Rect(const Rect&) = default; + Rect& operator=(const Rect&) = default; + + Rect Subrect(size_t xbegin, size_t ybegin, size_t xsize_max, + size_t ysize_max) { + return Rect(x0_ + xbegin, y0_ + ybegin, xsize_max, ysize_max, x0_ + xsize_, + y0_ + ysize_); + } + + template + const T* ConstRow(const Image* image, size_t y) const { + return image->ConstRow(y + y0_) + x0_; + } + + template + T* MutableRow(const Image* image, size_t y) const { + return image->MutableRow(y + y0_) + x0_; + } + + template + const T* ConstPlaneRow(const Image3& image, size_t c, size_t y) const { + return image.ConstPlaneRow(c, y + y0_) + x0_; + } + + template + T* MutablePlaneRow(Image3* image, const size_t c, size_t y) const { + return image->MutablePlaneRow(c, y + y0_) + x0_; + } + + // Returns true if this Rect fully resides in the given image. ImageT could be + // Image or Image3; however if ImageT is Rect, results are nonsensical. + template + bool IsInside(const ImageT& image) const { + return (x0_ + xsize_ <= image.xsize()) && (y0_ + ysize_ <= image.ysize()); + } + + size_t x0() const { return x0_; } + size_t y0() const { return y0_; } + size_t xsize() const { return xsize_; } + size_t ysize() const { return ysize_; } + + private: + // Returns size_max, or whatever is left in [begin, end). + static constexpr size_t ClampedSize(size_t begin, size_t size_max, + size_t end) { + return (begin + size_max <= end) ? size_max + : (end > begin ? end - begin : 0); + } + + size_t x0_; + size_t y0_; + + size_t xsize_; + size_t ysize_; +}; + +// Works for any image-like input type(s). +template +HWY_MAYBE_UNUSED bool SameSize(const Image1& image1, const Image2& image2) { + return image1.xsize() == image2.xsize() && image1.ysize() == image2.ysize(); +} + +// Mirrors out of bounds coordinates and returns valid coordinates unchanged. +// We assume the radius (distance outside the image) is small compared to the +// image size, otherwise this might not terminate. +// The mirror is outside the last column (border pixel is also replicated). 
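// For example, with xsize = 5 the loop below maps: -2 -> 1, -1 -> 0, 0 -> 0,
// 4 -> 4, 5 -> 4, 6 -> 3 (in-range coordinates are returned unchanged).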
+static HWY_INLINE HWY_MAYBE_UNUSED size_t Mirror(int64_t x, + const int64_t xsize) { + HWY_DASSERT(xsize != 0); + + // TODO(janwas): replace with branchless version + while (x < 0 || x >= xsize) { + if (x < 0) { + x = -x - 1; + } else { + x = 2 * xsize - 1 - x; + } + } + return static_cast(x); +} + +// Wrap modes for ensuring X/Y coordinates are in the valid range [0, size): + +// Mirrors (repeating the edge pixel once). Useful for convolutions. +struct WrapMirror { + HWY_INLINE size_t operator()(const int64_t coord, const size_t size) const { + return Mirror(coord, static_cast(size)); + } +}; + +// Returns the same coordinate, for when we know "coord" is already valid (e.g. +// interior of an image). +struct WrapUnchanged { + HWY_INLINE size_t operator()(const int64_t coord, size_t /*size*/) const { + return static_cast(coord); + } +}; + +// Similar to Wrap* but for row pointers (reduces Row() multiplications). + +class WrapRowMirror { + public: + template + WrapRowMirror(const View& image, size_t ysize) + : first_row_(image.ConstRow(0)), last_row_(image.ConstRow(ysize - 1)) {} + + const float* operator()(const float* const HWY_RESTRICT row, + const int64_t stride) const { + if (row < first_row_) { + const int64_t num_before = first_row_ - row; + // Mirrored; one row before => row 0, two before = row 1, ... + return first_row_ + num_before - stride; + } + if (row > last_row_) { + const int64_t num_after = row - last_row_; + // Mirrored; one row after => last row, two after = last - 1, ... + return last_row_ - num_after + stride; + } + return row; + } + + private: + const float* const HWY_RESTRICT first_row_; + const float* const HWY_RESTRICT last_row_; +}; + +struct WrapRowUnchanged { + HWY_INLINE const float* operator()(const float* const HWY_RESTRICT row, + int64_t /*stride*/) const { + return row; + } +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/math/math-inl.h b/third_party/aom/third_party/highway/hwy/contrib/math/math-inl.h new file mode 100644 index 000000000000..5bb536d9f827 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/math/math-inl.h @@ -0,0 +1,1752 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard (still compiled once per target) +#if defined(HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#endif + +#include +#include + +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +/** + * Highway SIMD version of std::acos(x). 
+ * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: [-1, +1] + * @return arc cosine of 'x' + */ +template +HWY_INLINE V Acos(D d, V x); +template +HWY_NOINLINE V CallAcos(const D d, VecArg x) { + return Acos(d, x); +} + +/** + * Highway SIMD version of std::acosh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[1, +FLT_MAX], float64[1, +DBL_MAX] + * @return hyperbolic arc cosine of 'x' + */ +template +HWY_INLINE V Acosh(D d, V x); +template +HWY_NOINLINE V CallAcosh(const D d, VecArg x) { + return Acosh(d, x); +} + +/** + * Highway SIMD version of std::asin(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: [-1, +1] + * @return arc sine of 'x' + */ +template +HWY_INLINE V Asin(D d, V x); +template +HWY_NOINLINE V CallAsin(const D d, VecArg x) { + return Asin(d, x); +} + +/** + * Highway SIMD version of std::asinh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hyperbolic arc sine of 'x' + */ +template +HWY_INLINE V Asinh(D d, V x); +template +HWY_NOINLINE V CallAsinh(const D d, VecArg x) { + return Asinh(d, x); +} + +/** + * Highway SIMD version of std::atan(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return arc tangent of 'x' + */ +template +HWY_INLINE V Atan(D d, V x); +template +HWY_NOINLINE V CallAtan(const D d, VecArg x) { + return Atan(d, x); +} + +/** + * Highway SIMD version of std::atanh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: (-1, +1) + * @return hyperbolic arc tangent of 'x' + */ +template +HWY_INLINE V Atanh(D d, V x); +template +HWY_NOINLINE V CallAtanh(const D d, VecArg x) { + return Atanh(d, x); +} + +// Atan2 was added later and some users may be implementing it themselves, so +// notify them that this version of Highway defines it already. +#ifndef HWY_HAVE_ATAN2 +#define HWY_HAVE_ATAN2 1 +#endif + +/** + * Highway SIMD version of std::atan2(x). + * + * Valid Lane Types: float32, float64 + * Correctly handles negative zero, infinities, and NaN. + * @return atan2 of 'y', 'x' + */ +template , class M = MFromD, + typename T = TFromD> +HWY_INLINE V Atan2(const D d, V y, V x) { + const V kHalf = Set(d, static_cast(+0.5)); + const V kPi = Set(d, static_cast(+3.14159265358979323846264)); + const V kPi2 = Mul(kPi, kHalf); + + const V k0 = Zero(d); + const M y_0 = Eq(y, k0); + const M x_0 = Eq(x, k0); + const M x_neg = Lt(x, k0); + const M y_inf = IsInf(y); + const M x_inf = IsInf(x); + const M nan = Or(IsNaN(y), IsNaN(x)); + + const V if_xneg_pi = IfThenElseZero(x_neg, kPi); + // x= +inf: pi/4; -inf: 3*pi/4; else: pi/2 + const V if_yinf = Mul(kHalf, IfThenElse(x_inf, Add(kPi2, if_xneg_pi), kPi)); + + V t = Atan(d, Div(y, x)); + // Disambiguate between quadrants 1/3 and 2/4 by adding (Q2: Pi; Q3: -Pi). + t = Add(t, CopySignToAbs(if_xneg_pi, y)); + // Special cases for 0 and infinity: + t = IfThenElse(x_inf, if_xneg_pi, t); + t = IfThenElse(x_0, kPi2, t); + t = IfThenElse(y_inf, if_yinf, t); + t = IfThenElse(y_0, if_xneg_pi, t); + // Any input NaN => NaN, otherwise fix sign. + return IfThenElse(nan, NaN(d), CopySign(t, y)); +} +template +HWY_NOINLINE V CallAtan2(const D d, VecArg y, VecArg x) { + return Atan2(d, y, x); +} + +/** + * Highway SIMD version of std::cos(x). 
+ * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: [-39000, +39000] + * @return cosine of 'x' + */ +template +HWY_INLINE V Cos(D d, V x); +template +HWY_NOINLINE V CallCos(const D d, VecArg x) { + return Cos(d, x); +} + +/** + * Highway SIMD version of std::exp(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 1 + * Valid Range: float32[-FLT_MAX, +104], float64[-DBL_MAX, +706] + * @return e^x + */ +template +HWY_INLINE V Exp(D d, V x); +template +HWY_NOINLINE V CallExp(const D d, VecArg x) { + return Exp(d, x); +} + +/** + * Highway SIMD version of std::exp2(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32[-FLT_MAX, +128], float64[-DBL_MAX, +1024] + * @return 2^x + */ +template +HWY_INLINE V Exp2(D d, V x); +template +HWY_NOINLINE V CallExp2(const D d, VecArg x) { + return Exp2(d, x); +} + +/** + * Highway SIMD version of std::expm1(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +104], float64[-DBL_MAX, +706] + * @return e^x - 1 + */ +template +HWY_INLINE V Expm1(D d, V x); +template +HWY_NOINLINE V CallExpm1(const D d, VecArg x) { + return Expm1(d, x); +} + +/** + * Highway SIMD version of std::log(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return natural logarithm of 'x' + */ +template +HWY_INLINE V Log(D d, V x); +template +HWY_NOINLINE V CallLog(const D d, VecArg x) { + return Log(d, x); +} + +/** + * Highway SIMD version of std::log10(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return base 10 logarithm of 'x' + */ +template +HWY_INLINE V Log10(D d, V x); +template +HWY_NOINLINE V CallLog10(const D d, VecArg x) { + return Log10(d, x); +} + +/** + * Highway SIMD version of std::log1p(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32[0, +FLT_MAX], float64[0, +DBL_MAX] + * @return log(1 + x) + */ +template +HWY_INLINE V Log1p(D d, V x); +template +HWY_NOINLINE V CallLog1p(const D d, VecArg x) { + return Log1p(d, x); +} + +/** + * Highway SIMD version of std::log2(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return base 2 logarithm of 'x' + */ +template +HWY_INLINE V Log2(D d, V x); +template +HWY_NOINLINE V CallLog2(const D d, VecArg x) { + return Log2(d, x); +} + +/** + * Highway SIMD version of std::sin(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: [-39000, +39000] + * @return sine of 'x' + */ +template +HWY_INLINE V Sin(D d, V x); +template +HWY_NOINLINE V CallSin(const D d, VecArg x) { + return Sin(d, x); +} + +/** + * Highway SIMD version of std::sinh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-88.7228, +88.7228], float64[-709, +709] + * @return hyperbolic sine of 'x' + */ +template +HWY_INLINE V Sinh(D d, V x); +template +HWY_NOINLINE V CallSinh(const D d, VecArg x) { + return Sinh(d, x); +} + +/** + * Highway SIMD version of std::tanh(x). 
+ * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hyperbolic tangent of 'x' + */ +template +HWY_INLINE V Tanh(D d, V x); +template +HWY_NOINLINE V CallTanh(const D d, VecArg x) { + return Tanh(d, x); +} + +/** + * Highway SIMD version of SinCos. + * Compute the sine and cosine at the same time + * The performance should be around the same as calling Sin. + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 1 + * Valid Range: [-39000, +39000] + * @return sine and cosine of 'x' + */ +template +HWY_INLINE void SinCos(D d, V x, V& s, V& c); +template +HWY_NOINLINE void CallSinCos(const D d, VecArg x, V& s, V& c) { + SinCos(d, x, s, c); +} + +/** + * Highway SIMD version of Hypot + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hypotenuse of a and b + */ +template +HWY_INLINE V Hypot(D d, V a, V b); +template +HWY_NOINLINE V CallHypot(const D d, VecArg a, VecArg b) { + return Hypot(d, a, b); +} + +//////////////////////////////////////////////////////////////////////////////// +// Implementation +//////////////////////////////////////////////////////////////////////////////// +namespace impl { + +// Estrin's Scheme is a faster method for evaluating large polynomials on +// super scalar architectures. It works by factoring the Horner's Method +// polynomial into power of two sub-trees that can be evaluated in parallel. +// Wikipedia Link: https://en.wikipedia.org/wiki/Estrin%27s_scheme +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1) { + return MulAdd(c1, x, c0); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2) { + T x2 = Mul(x, x); + return MulAdd(x2, c2, MulAdd(c1, x, c0)); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3) { + T x2 = Mul(x, x); + return MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, c4, MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(c5, x, c4), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(x2, c6, MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, c8, + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(c9, x, c8), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, 
c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(x2, c10, MulAdd(c9, x, c8)), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8)), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd( + x8, MulAdd(x4, c12, MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(c13, x, c12), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(x2, c14, MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, c16, + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16, T c17) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, MulAdd(c17, x, c16), + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), 
MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16, T c17, + T c18) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, MulAdd(x2, c18, MulAdd(c17, x, c16)), + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} + +template +struct AsinImpl {}; +template +struct AtanImpl {}; +template +struct CosSinImpl {}; +template +struct ExpImpl {}; +template +struct LogImpl {}; +template +struct SinCosImpl {}; + +template <> +struct AsinImpl { + // Polynomial approximation for asin(x) over the range [0, 0.5). + template + HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) { + const auto k0 = Set(d, +0.1666677296f); + const auto k1 = Set(d, +0.07495029271f); + const auto k2 = Set(d, +0.04547423869f); + const auto k3 = Set(d, +0.02424046025f); + const auto k4 = Set(d, +0.04197454825f); + + return Estrin(x2, k0, k1, k2, k3, k4); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct AsinImpl { + // Polynomial approximation for asin(x) over the range [0, 0.5). + template + HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) { + const auto k0 = Set(d, +0.1666666666666497543); + const auto k1 = Set(d, +0.07500000000378581611); + const auto k2 = Set(d, +0.04464285681377102438); + const auto k3 = Set(d, +0.03038195928038132237); + const auto k4 = Set(d, +0.02237176181932048341); + const auto k5 = Set(d, +0.01735956991223614604); + const auto k6 = Set(d, +0.01388715184501609218); + const auto k7 = Set(d, +0.01215360525577377331); + const auto k8 = Set(d, +0.006606077476277170610); + const auto k9 = Set(d, +0.01929045477267910674); + const auto k10 = Set(d, -0.01581918243329996643); + const auto k11 = Set(d, +0.03161587650653934628); + + return Estrin(x2, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11); + } +}; + +#endif + +template <> +struct AtanImpl { + // Polynomial approximation for atan(x) over the range [0, 1.0). + template + HWY_INLINE V AtanPoly(D d, V x) { + const auto k0 = Set(d, -0.333331018686294555664062f); + const auto k1 = Set(d, +0.199926957488059997558594f); + const auto k2 = Set(d, -0.142027363181114196777344f); + const auto k3 = Set(d, +0.106347933411598205566406f); + const auto k4 = Set(d, -0.0748900920152664184570312f); + const auto k5 = Set(d, +0.0425049886107444763183594f); + const auto k6 = Set(d, -0.0159569028764963150024414f); + const auto k7 = Set(d, +0.00282363896258175373077393f); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7), Mul(y, x), x); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct AtanImpl { + // Polynomial approximation for atan(x) over the range [0, 1.0). 
+ template + HWY_INLINE V AtanPoly(D d, V x) { + const auto k0 = Set(d, -0.333333333333311110369124); + const auto k1 = Set(d, +0.199999999996591265594148); + const auto k2 = Set(d, -0.14285714266771329383765); + const auto k3 = Set(d, +0.111111105648261418443745); + const auto k4 = Set(d, -0.090908995008245008229153); + const auto k5 = Set(d, +0.0769219538311769618355029); + const auto k6 = Set(d, -0.0666573579361080525984562); + const auto k7 = Set(d, +0.0587666392926673580854313); + const auto k8 = Set(d, -0.0523674852303482457616113); + const auto k9 = Set(d, +0.0466667150077840625632675); + const auto k10 = Set(d, -0.0407629191276836500001934); + const auto k11 = Set(d, +0.0337852580001353069993897); + const auto k12 = Set(d, -0.0254517624932312641616861); + const auto k13 = Set(d, +0.016599329773529201970117); + const auto k14 = Set(d, -0.00889896195887655491740809); + const auto k15 = Set(d, +0.00370026744188713119232403); + const auto k16 = Set(d, -0.00110611831486672482563471); + const auto k17 = Set(d, +0.000209850076645816976906797); + const auto k18 = Set(d, -1.88796008463073496563746e-5); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, + k12, k13, k14, k15, k16, k17, k18), + Mul(y, x), x); + } +}; + +#endif + +template <> +struct CosSinImpl { + // Rounds float toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return ConvertTo(Rebind(), x); + } + + template + HWY_INLINE V Poly(D d, V x) { + const auto k0 = Set(d, -1.66666597127914428710938e-1f); + const auto k1 = Set(d, +8.33307858556509017944336e-3f); + const auto k2 = Set(d, -1.981069071916863322258e-4f); + const auto k3 = Set(d, +2.6083159809786593541503e-6f); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3), Mul(y, x), x); + } + + template + HWY_INLINE V CosReduce(D d, V x, VI32 q) { + // kHalfPiPart0f + kHalfPiPart1f + kHalfPiPart2f + kHalfPiPart3f ~= -pi/2 + const V kHalfPiPart0f = Set(d, -0.5f * 3.140625f); + const V kHalfPiPart1f = Set(d, -0.5f * 0.0009670257568359375f); + const V kHalfPiPart2f = Set(d, -0.5f * 6.2771141529083251953e-7f); + const V kHalfPiPart3f = Set(d, -0.5f * 1.2154201256553420762e-10f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kHalfPiPart0f, x); + x = MulAdd(qf, kHalfPiPart1f, x); + x = MulAdd(qf, kHalfPiPart2f, x); + x = MulAdd(qf, kHalfPiPart3f, x); + return x; + } + + template + HWY_INLINE V SinReduce(D d, V x, VI32 q) { + // kPiPart0f + kPiPart1f + kPiPart2f + kPiPart3f ~= -pi + const V kPiPart0f = Set(d, -3.140625f); + const V kPiPart1f = Set(d, -0.0009670257568359375f); + const V kPiPart2f = Set(d, -6.2771141529083251953e-7f); + const V kPiPart3f = Set(d, -1.2154201256553420762e-10f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kPiPart0f, x); + x = MulAdd(qf, kPiPart1f, x); + x = MulAdd(qf, kPiPart2f, x); + x = MulAdd(qf, kPiPart3f, x); + return x; + } + + // (q & 2) == 0 ? -0.0 : +0.0 + template + HWY_INLINE Vec> CosSignFromQuadrant(D d, VI32 q) { + const VI32 kTwo = Set(Rebind(), 2); + return BitCast(d, ShiftLeft<30>(AndNot(q, kTwo))); + } + + // ((q & 1) ? -0.0 : +0.0) + template + HWY_INLINE Vec> SinSignFromQuadrant(D d, VI32 q) { + const VI32 kOne = Set(Rebind(), 1); + return BitCast(d, ShiftLeft<31>(And(q, kOne))); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct CosSinImpl { + // Rounds double toward zero and returns as int32_t. 
+ template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return DemoteTo(Rebind(), x); + } + + template + HWY_INLINE V Poly(D d, V x) { + const auto k0 = Set(d, -0.166666666666666657414808); + const auto k1 = Set(d, +0.00833333333333332974823815); + const auto k2 = Set(d, -0.000198412698412696162806809); + const auto k3 = Set(d, +2.75573192239198747630416e-6); + const auto k4 = Set(d, -2.50521083763502045810755e-8); + const auto k5 = Set(d, +1.60590430605664501629054e-10); + const auto k6 = Set(d, -7.64712219118158833288484e-13); + const auto k7 = Set(d, +2.81009972710863200091251e-15); + const auto k8 = Set(d, -7.97255955009037868891952e-18); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8), Mul(y, x), x); + } + + template + HWY_INLINE V CosReduce(D d, V x, VI32 q) { + // kHalfPiPart0d + kHalfPiPart1d + kHalfPiPart2d + kHalfPiPart3d ~= -pi/2 + const V kHalfPiPart0d = Set(d, -0.5 * 3.1415926218032836914); + const V kHalfPiPart1d = Set(d, -0.5 * 3.1786509424591713469e-8); + const V kHalfPiPart2d = Set(d, -0.5 * 1.2246467864107188502e-16); + const V kHalfPiPart3d = Set(d, -0.5 * 1.2736634327021899816e-24); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kHalfPiPart0d, x); + x = MulAdd(qf, kHalfPiPart1d, x); + x = MulAdd(qf, kHalfPiPart2d, x); + x = MulAdd(qf, kHalfPiPart3d, x); + return x; + } + + template + HWY_INLINE V SinReduce(D d, V x, VI32 q) { + // kPiPart0d + kPiPart1d + kPiPart2d + kPiPart3d ~= -pi + const V kPiPart0d = Set(d, -3.1415926218032836914); + const V kPiPart1d = Set(d, -3.1786509424591713469e-8); + const V kPiPart2d = Set(d, -1.2246467864107188502e-16); + const V kPiPart3d = Set(d, -1.2736634327021899816e-24); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kPiPart0d, x); + x = MulAdd(qf, kPiPart1d, x); + x = MulAdd(qf, kPiPart2d, x); + x = MulAdd(qf, kPiPart3d, x); + return x; + } + + // (q & 2) == 0 ? -0.0 : +0.0 + template + HWY_INLINE Vec> CosSignFromQuadrant(D d, VI32 q) { + const VI32 kTwo = Set(Rebind(), 2); + return BitCast( + d, ShiftLeft<62>(PromoteTo(Rebind(), AndNot(q, kTwo)))); + } + + // ((q & 1) ? -0.0 : +0.0) + template + HWY_INLINE Vec> SinSignFromQuadrant(D d, VI32 q) { + const VI32 kOne = Set(Rebind(), 1); + return BitCast( + d, ShiftLeft<63>(PromoteTo(Rebind(), And(q, kOne)))); + } +}; + +#endif + +template <> +struct ExpImpl { + // Rounds float toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return ConvertTo(Rebind(), x); + } + + // Rounds float to nearest int32_t + template + HWY_INLINE Vec> ToNearestInt32(D /*unused*/, V x) { + return NearestInt(x); + } + + template + HWY_INLINE V ExpPoly(D d, V x) { + const auto k0 = Set(d, +0.5f); + const auto k1 = Set(d, +0.166666671633720397949219f); + const auto k2 = Set(d, +0.0416664853692054748535156f); + const auto k3 = Set(d, +0.00833336077630519866943359f); + const auto k4 = Set(d, +0.00139304355252534151077271f); + const auto k5 = Set(d, +0.000198527617612853646278381f); + + return MulAdd(Estrin(x, k0, k1, k2, k3, k4, k5), Mul(x, x), x); + } + + // Computes 2^x, where x is an integer. + template + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const VI32 kOffset = Set(di32, 0x7F); + return BitCast(d, ShiftLeft<23>(Add(x, kOffset))); + } + + // Sets the exponent of 'x' to 2^e. 
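// Splitting e into two halves (e.g. e = 7 uses 2^3 * 2^4) keeps each Pow2I
// input small; presumably this avoids over/underflowing the biased exponent
// when |e| is near the edge of the representable range.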
+ template + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kLn2Part0f + kLn2Part1f ~= -ln(2) + const V kLn2Part0f = Set(d, -0.693145751953125f); + const V kLn2Part1f = Set(d, -1.428606765330187045e-6f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kLn2Part0f, x); + x = MulAdd(qf, kLn2Part1f, x); + return x; + } + + template + HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) { + const V x_frac = Sub(x, ConvertTo(d, q)); + return MulAdd(x_frac, Set(d, 0.193147182464599609375f), + Mul(x_frac, Set(d, 0.5f))); + } +}; + +template <> +struct LogImpl { + template + HWY_INLINE Vec> Log2p1NoSubnormal(D /*d*/, V x) { + const Rebind di32; + const Rebind du32; + const auto kBias = Set(di32, 0x7F); + return Sub(BitCast(di32, ShiftRight<23>(BitCast(du32, x))), kBias); + } + + // Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)]. + template + HWY_INLINE V LogPoly(D d, V x) { + const V k0 = Set(d, 0.66666662693f); + const V k1 = Set(d, 0.40000972152f); + const V k2 = Set(d, 0.28498786688f); + const V k3 = Set(d, 0.24279078841f); + + const V x2 = Mul(x, x); + const V x4 = Mul(x2, x2); + return MulAdd(MulAdd(k2, x4, k0), x2, Mul(MulAdd(k3, x4, k1), x4)); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template <> +struct ExpImpl { + // Rounds double toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return DemoteTo(Rebind(), x); + } + + // Rounds double to nearest int32_t + template + HWY_INLINE Vec> ToNearestInt32(D /*unused*/, V x) { + return DemoteToNearestInt(Rebind(), x); + } + + template + HWY_INLINE V ExpPoly(D d, V x) { + const auto k0 = Set(d, +0.5); + const auto k1 = Set(d, +0.166666666666666851703837); + const auto k2 = Set(d, +0.0416666666666665047591422); + const auto k3 = Set(d, +0.00833333333331652721664984); + const auto k4 = Set(d, +0.00138888888889774492207962); + const auto k5 = Set(d, +0.000198412698960509205564975); + const auto k6 = Set(d, +2.4801587159235472998791e-5); + const auto k7 = Set(d, +2.75572362911928827629423e-6); + const auto k8 = Set(d, +2.75573911234900471893338e-7); + const auto k9 = Set(d, +2.51112930892876518610661e-8); + const auto k10 = Set(d, +2.08860621107283687536341e-9); + + return MulAdd(Estrin(x, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10), + Mul(x, x), x); + } + + // Computes 2^x, where x is an integer. + template + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const Rebind di64; + const VI32 kOffset = Set(di32, 0x3FF); + return BitCast(d, ShiftLeft<52>(PromoteTo(di64, Add(x, kOffset)))); + } + + // Sets the exponent of 'x' to 2^e. + template + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kLn2Part0d + kLn2Part1d ~= -ln(2) + const V kLn2Part0d = Set(d, -0.6931471805596629565116018); + const V kLn2Part1d = Set(d, -0.28235290563031577122588448175e-12); + + // Extended precision modular arithmetic. 
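// i.e. x - q*ln(2) is evaluated as x + q*(-ln2_hi) + q*(-ln2_lo), so the
// rounding error of representing ln(2) as a single constant does not
// contaminate the reduced argument fed to the polynomial.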
+ const V qf = PromoteTo(d, q); + x = MulAdd(qf, kLn2Part0d, x); + x = MulAdd(qf, kLn2Part1d, x); + return x; + } + + template + HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) { + const V x_frac = Sub(x, PromoteTo(d, q)); + return MulAdd(x_frac, Set(d, 0.1931471805599453139823396), + Mul(x_frac, Set(d, 0.5))); + } +}; + +template <> +struct LogImpl { + template + HWY_INLINE Vec> Log2p1NoSubnormal(D /*d*/, V x) { + const Rebind di64; + const Rebind du64; + return Sub(BitCast(di64, ShiftRight<52>(BitCast(du64, x))), + Set(di64, 0x3FF)); + } + + // Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)]. + template + HWY_INLINE V LogPoly(D d, V x) { + const V k0 = Set(d, 0.6666666666666735130); + const V k1 = Set(d, 0.3999999999940941908); + const V k2 = Set(d, 0.2857142874366239149); + const V k3 = Set(d, 0.2222219843214978396); + const V k4 = Set(d, 0.1818357216161805012); + const V k5 = Set(d, 0.1531383769920937332); + const V k6 = Set(d, 0.1479819860511658591); + + const V x2 = Mul(x, x); + const V x4 = Mul(x2, x2); + return MulAdd(MulAdd(MulAdd(MulAdd(k6, x4, k4), x4, k2), x4, k0), x2, + (Mul(MulAdd(MulAdd(k5, x4, k3), x4, k1), x4))); + } +}; + +#endif + +template +HWY_INLINE V Log(const D d, V x) { + // http://git.musl-libc.org/cgit/musl/tree/src/math/log.c for more info. + using T = TFromD; + impl::LogImpl impl; + + constexpr bool kIsF32 = (sizeof(T) == 4); + + // Float Constants + const V kLn2Hi = Set(d, kIsF32 ? static_cast(0.69313812256f) + : static_cast(0.693147180369123816490)); + const V kLn2Lo = Set(d, kIsF32 ? static_cast(9.0580006145e-6f) + : static_cast(1.90821492927058770002e-10)); + const V kOne = Set(d, static_cast(+1.0)); + const V kMinNormal = Set(d, kIsF32 ? static_cast(1.175494351e-38f) + : static_cast(2.2250738585072014e-308)); + const V kScale = Set(d, kIsF32 ? static_cast(3.355443200e+7f) + : static_cast(1.8014398509481984e+16)); + + // Integer Constants + using TI = MakeSigned; + const Rebind di; + using VI = decltype(Zero(di)); + const VI kLowerBits = Set(di, kIsF32 ? static_cast(0x00000000L) + : static_cast(0xFFFFFFFFLL)); + const VI kMagic = Set(di, kIsF32 ? static_cast(0x3F3504F3L) + : static_cast(0x3FE6A09E00000000LL)); + const VI kExpMask = Set(di, kIsF32 ? static_cast(0x3F800000L) + : static_cast(0x3FF0000000000000LL)); + const VI kExpScale = + Set(di, kIsF32 ? static_cast(-25) : static_cast(-54)); + const VI kManMask = Set(di, kIsF32 ? static_cast(0x7FFFFFL) + : static_cast(0xFFFFF00000000LL)); + + // Scale up 'x' so that it is no longer denormalized. + VI exp_bits; + V exp; + if (kAllowSubnormals == true) { + const auto is_denormal = Lt(x, kMinNormal); + x = IfThenElse(is_denormal, Mul(x, kScale), x); + + // Compute the new exponent. + exp_bits = Add(BitCast(di, x), Sub(kExpMask, kMagic)); + const VI exp_scale = + BitCast(di, IfThenElseZero(is_denormal, BitCast(d, kExpScale))); + exp = ConvertTo( + d, Add(exp_scale, impl.Log2p1NoSubnormal(d, BitCast(d, exp_bits)))); + } else { + // Compute the new exponent. + exp_bits = Add(BitCast(di, x), Sub(kExpMask, kMagic)); + exp = ConvertTo(d, impl.Log2p1NoSubnormal(d, BitCast(d, exp_bits))); + } + + // Renormalize. + const V y = Or(And(x, BitCast(d, kLowerBits)), + BitCast(d, Add(And(exp_bits, kManMask), kMagic))); + + // Approximate and reconstruct. 
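// With z = (y - 1) / (y + 1), the return expression below rearranges to
// log(x) ~= exp * (ln2_hi + ln2_lo) + (y - 1) - z * ((y - 1) - LogPoly(z)).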
+ const V ym1 = Sub(y, kOne); + const V z = Div(ym1, Add(y, kOne)); + + return MulSub( + exp, kLn2Hi, + Sub(MulSub(z, Sub(ym1, impl.LogPoly(d, z)), Mul(exp, kLn2Lo)), ym1)); +} + +// SinCos +// Based on "sse_mathfun.h", by Julien Pommier +// http://gruntthepeon.free.fr/ssemath/ + +// Third degree poly +template +HWY_INLINE void SinCos3(D d, TFromD dp1, TFromD dp2, TFromD dp3, V x, + V& s, V& c) { + using T = TFromD; + using TI = MakeSigned; + using DI = Rebind; + const DI di; + using VI = decltype(Zero(di)); + using M = Mask; + + static constexpr size_t bits = sizeof(TI) * 8; + const VI sign_mask = SignBit(di); + const VI ci_0 = Zero(di); + const VI ci_1 = Set(di, 1); + const VI ci_2 = Set(di, 2); + const VI ci_4 = Set(di, 4); + const V cos_p0 = Set(d, ConvertScalarTo(2.443315711809948E-005)); + const V cos_p1 = Set(d, ConvertScalarTo(-1.388731625493765E-003)); + const V cos_p2 = Set(d, ConvertScalarTo(4.166664568298827E-002)); + const V sin_p0 = Set(d, ConvertScalarTo(-1.9515295891E-4)); + const V sin_p1 = Set(d, ConvertScalarTo(8.3321608736E-3)); + const V sin_p2 = Set(d, ConvertScalarTo(-1.6666654611E-1)); + const V FOPI = Set(d, ConvertScalarTo(1.27323954473516)); // 4 / M_PI + const V DP1 = Set(d, dp1); + const V DP2 = Set(d, dp2); + const V DP3 = Set(d, dp3); + + V xmm1, xmm2, sign_bit_sin, y; + VI imm0, imm2, imm4; + + sign_bit_sin = x; + x = Abs(x); + + /* extract the sign bit (upper one) */ + sign_bit_sin = And(sign_bit_sin, BitCast(d, sign_mask)); + + /* scale by 4/Pi */ + y = Mul(x, FOPI); + + /* store the integer part of y in imm2 */ + imm2 = ConvertTo(di, y); + + /* j=(j+1) & (~1) (see the cephes sources) */ + imm2 = Add(imm2, ci_1); + imm2 = AndNot(ci_1, imm2); + + y = ConvertTo(d, imm2); + imm4 = imm2; + + /* get the swap sign flag for the sine */ + imm0 = And(imm2, ci_4); + imm0 = ShiftLeft(imm0); + + V swap_sign_bit_sin = BitCast(d, imm0); + + /* get the polynomial selection mask for the sine*/ + imm2 = And(imm2, ci_2); + M poly_mask = RebindMask(d, Eq(imm2, ci_0)); + + /* The magic pass: "Extended precision modular arithmetic" + x = ((x - y * DP1) - y * DP2) - y * DP3; */ + x = MulAdd(y, DP1, x); + x = MulAdd(y, DP2, x); + x = MulAdd(y, DP3, x); + + imm4 = Sub(imm4, ci_2); + imm4 = AndNot(imm4, ci_4); + imm4 = ShiftLeft(imm4); + + V sign_bit_cos = BitCast(d, imm4); + + sign_bit_sin = Xor(sign_bit_sin, swap_sign_bit_sin); + + /* Evaluate the first polynomial (0 <= x <= Pi/4) */ + V z = Mul(x, x); + + y = MulAdd(cos_p0, z, cos_p1); + y = MulAdd(y, z, cos_p2); + y = Mul(y, z); + y = Mul(y, z); + y = NegMulAdd(z, Set(d, 0.5f), y); + y = Add(y, Set(d, 1)); + + /* Evaluate the second polynomial (Pi/4 <= x <= 0) */ + V y2 = MulAdd(sin_p0, z, sin_p1); + y2 = MulAdd(y2, z, sin_p2); + y2 = Mul(y2, z); + y2 = MulAdd(y2, x, x); + + /* select the correct result from the two polynomials */ + xmm1 = IfThenElse(poly_mask, y2, y); + xmm2 = IfThenElse(poly_mask, y, y2); + + /* update the sign */ + s = Xor(xmm1, sign_bit_sin); + c = Xor(xmm2, sign_bit_cos); +} + +// Sixth degree poly +template +HWY_INLINE void SinCos6(D d, TFromD dp1, TFromD dp2, TFromD dp3, V x, + V& s, V& c) { + using T = TFromD; + using TI = MakeSigned; + using DI = Rebind; + const DI di; + using VI = decltype(Zero(di)); + using M = Mask; + + static constexpr size_t bits = sizeof(TI) * 8; + const VI sign_mask = SignBit(di); + const VI ci_0 = Zero(di); + const VI ci_1 = Set(di, 1); + const VI ci_2 = Set(di, 2); + const VI ci_4 = Set(di, 4); + const V cos_p0 = Set(d, ConvertScalarTo(-1.13585365213876817300E-11)); + const V 
cos_p1 = Set(d, ConvertScalarTo(2.08757008419747316778E-9)); + const V cos_p2 = Set(d, ConvertScalarTo(-2.75573141792967388112E-7)); + const V cos_p3 = Set(d, ConvertScalarTo(2.48015872888517045348E-5)); + const V cos_p4 = Set(d, ConvertScalarTo(-1.38888888888730564116E-3)); + const V cos_p5 = Set(d, ConvertScalarTo(4.16666666666665929218E-2)); + const V sin_p0 = Set(d, ConvertScalarTo(1.58962301576546568060E-10)); + const V sin_p1 = Set(d, ConvertScalarTo(-2.50507477628578072866E-8)); + const V sin_p2 = Set(d, ConvertScalarTo(2.75573136213857245213E-6)); + const V sin_p3 = Set(d, ConvertScalarTo(-1.98412698295895385996E-4)); + const V sin_p4 = Set(d, ConvertScalarTo(8.33333333332211858878E-3)); + const V sin_p5 = Set(d, ConvertScalarTo(-1.66666666666666307295E-1)); + const V FOPI = // 4 / M_PI + Set(d, ConvertScalarTo(1.2732395447351626861510701069801148)); + const V DP1 = Set(d, dp1); + const V DP2 = Set(d, dp2); + const V DP3 = Set(d, dp3); + + V xmm1, xmm2, sign_bit_sin, y; + VI imm0, imm2, imm4; + + sign_bit_sin = x; + x = Abs(x); + + /* extract the sign bit (upper one) */ + sign_bit_sin = And(sign_bit_sin, BitCast(d, sign_mask)); + + /* scale by 4/Pi */ + y = Mul(x, FOPI); + + /* store the integer part of y in imm2 */ + imm2 = ConvertTo(di, y); + + /* j=(j+1) & (~1) (see the cephes sources) */ + imm2 = Add(imm2, ci_1); + imm2 = AndNot(ci_1, imm2); + + y = ConvertTo(d, imm2); + imm4 = imm2; + + /* get the swap sign flag for the sine */ + imm0 = And(imm2, ci_4); + imm0 = ShiftLeft(imm0); + + V swap_sign_bit_sin = BitCast(d, imm0); + + /* get the polynomial selection mask for the sine*/ + imm2 = And(imm2, ci_2); + M poly_mask = RebindMask(d, Eq(imm2, ci_0)); + + /* The magic pass: "Extended precision modular arithmetic" + x = ((x - y * DP1) - y * DP2) - y * DP3; */ + x = MulAdd(y, DP1, x); + x = MulAdd(y, DP2, x); + x = MulAdd(y, DP3, x); + + imm4 = Sub(imm4, ci_2); + imm4 = AndNot(imm4, ci_4); + imm4 = ShiftLeft(imm4); + + V sign_bit_cos = BitCast(d, imm4); + sign_bit_sin = Xor(sign_bit_sin, swap_sign_bit_sin); + + /* Evaluate the first polynomial (0 <= x <= Pi/4) */ + V z = Mul(x, x); + + y = MulAdd(cos_p0, z, cos_p1); + y = MulAdd(y, z, cos_p2); + y = MulAdd(y, z, cos_p3); + y = MulAdd(y, z, cos_p4); + y = MulAdd(y, z, cos_p5); + y = Mul(y, z); + y = Mul(y, z); + y = NegMulAdd(z, Set(d, 0.5f), y); + y = Add(y, Set(d, 1.0f)); + + /* Evaluate the second polynomial (Pi/4 <= x <= 0) */ + V y2 = MulAdd(sin_p0, z, sin_p1); + y2 = MulAdd(y2, z, sin_p2); + y2 = MulAdd(y2, z, sin_p3); + y2 = MulAdd(y2, z, sin_p4); + y2 = MulAdd(y2, z, sin_p5); + y2 = Mul(y2, z); + y2 = MulAdd(y2, x, x); + + /* select the correct result from the two polynomials */ + xmm1 = IfThenElse(poly_mask, y2, y); + xmm2 = IfThenElse(poly_mask, y, y2); + + /* update the sign */ + s = Xor(xmm1, sign_bit_sin); + c = Xor(xmm2, sign_bit_cos); +} + +template <> +struct SinCosImpl { + template + HWY_INLINE void SinCos(D d, V x, V& s, V& c) { + SinCos3(d, -0.78515625f, -2.4187564849853515625e-4f, + -3.77489497744594108e-8f, x, s, c); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template <> +struct SinCosImpl { + template + HWY_INLINE void SinCos(D d, V x, V& s, V& c) { + SinCos6(d, -7.85398125648498535156E-1, -3.77489470793079817668E-8, + -2.69515142907905952645E-15, x, s, c); + } +}; +#endif + +} // namespace impl + +template +HWY_INLINE V Acos(const D d, V x) { + using T = TFromD; + + const V kZero = Zero(d); + const V kHalf = Set(d, static_cast(+0.5)); + const V kPi = Set(d, static_cast(+3.14159265358979323846264)); 
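// Illustrative sketch (not part of the patch): this Acos is built on two
// identities: acos(x) = pi/2 - asin(x) for |x| < 0.5, and
// acos(|x|) = 2*asin(sqrt((1 - |x|)/2)) otherwise, reflected about pi for
// negative x. Scalar equivalent; our names, std::asin stands in for AsinPoly.
#include <cmath>
double ScalarAcosSketch(double x) {
  const double kPi = 3.14159265358979323846264;
  const double ax = std::fabs(x);
  if (ax < 0.5) {
    return kPi / 2 - std::asin(x);  // small |x|
  }
  const double r = 2.0 * std::asin(std::sqrt((1.0 - ax) * 0.5));
  return (x >= 0.0) ? r : kPi - r;  // reflect for x < 0
}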
+ const V kPiOverTwo = Set(d, static_cast(+1.57079632679489661923132169)); + + const V sign_x = And(SignBit(d), x); + const V abs_x = Xor(x, sign_x); + const auto mask = Lt(abs_x, kHalf); + const V yy = + IfThenElse(mask, Mul(abs_x, abs_x), NegMulAdd(abs_x, kHalf, kHalf)); + const V y = IfThenElse(mask, abs_x, Sqrt(yy)); + + impl::AsinImpl impl; + const V t = Mul(impl.AsinPoly(d, yy, y), Mul(y, yy)); + + const V t_plus_y = Add(t, y); + const V z = + IfThenElse(mask, Sub(kPiOverTwo, Add(Xor(y, sign_x), Xor(t, sign_x))), + Add(t_plus_y, t_plus_y)); + return IfThenElse(Or(mask, Ge(x, kZero)), z, Sub(kPi, z)); +} + +template +HWY_INLINE V Acosh(const D d, V x) { + using T = TFromD; + + const V kLarge = Set(d, static_cast(268435456.0)); + const V kLog2 = Set(d, static_cast(0.693147180559945286227)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const auto is_x_large = Gt(x, kLarge); + const auto is_x_gt_2 = Gt(x, kTwo); + + const V x_minus_1 = Sub(x, kOne); + const V y0 = MulSub(kTwo, x, Div(kOne, Add(Sqrt(MulSub(x, x, kOne)), x))); + const V y1 = + Add(Sqrt(MulAdd(x_minus_1, kTwo, Mul(x_minus_1, x_minus_1))), x_minus_1); + const V y2 = + IfThenElse(is_x_gt_2, IfThenElse(is_x_large, x, y0), Add(y1, kOne)); + const V z = impl::Log(d, y2); + + const auto is_pole = Eq(y2, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y2), kOne); + return Add(IfThenElse(is_x_gt_2, z, + IfThenElse(is_pole, y1, Div(Mul(z, y1), divisor))), + IfThenElseZero(is_x_large, kLog2)); +} + +template +HWY_INLINE V Asin(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kTwo = Set(d, static_cast(+2.0)); + const V kPiOverTwo = Set(d, static_cast(+1.57079632679489661923132169)); + + const V sign_x = And(SignBit(d), x); + const V abs_x = Xor(x, sign_x); + const auto mask = Lt(abs_x, kHalf); + const V yy = + IfThenElse(mask, Mul(abs_x, abs_x), NegMulAdd(abs_x, kHalf, kHalf)); + const V y = IfThenElse(mask, abs_x, Sqrt(yy)); + + impl::AsinImpl impl; + const V z0 = MulAdd(impl.AsinPoly(d, yy, y), Mul(yy, y), y); + const V z1 = NegMulAdd(z0, kTwo, kPiOverTwo); + return Or(IfThenElse(mask, z0, z1), sign_x); +} + +template +HWY_INLINE V Asinh(const D d, V x) { + using T = TFromD; + + const V kSmall = Set(d, static_cast(1.0 / 268435456.0)); + const V kLarge = Set(d, static_cast(268435456.0)); + const V kLog2 = Set(d, static_cast(0.693147180559945286227)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const V sign_x = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign_x); + + const auto is_x_large = Gt(abs_x, kLarge); + const auto is_x_lt_2 = Lt(abs_x, kTwo); + + const V x2 = Mul(x, x); + const V sqrt_x2_plus_1 = Sqrt(Add(x2, kOne)); + + const V y0 = MulAdd(abs_x, kTwo, Div(kOne, Add(sqrt_x2_plus_1, abs_x))); + const V y1 = Add(Div(x2, Add(sqrt_x2_plus_1, kOne)), abs_x); + const V y2 = + IfThenElse(is_x_lt_2, Add(y1, kOne), IfThenElse(is_x_large, abs_x, y0)); + const V z = impl::Log(d, y2); + + const auto is_pole = Eq(y2, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y2), kOne); + const auto large = IfThenElse(is_pole, y1, Div(Mul(z, y1), divisor)); + const V y = IfThenElse(Lt(abs_x, kSmall), x, large); + return Or(Add(IfThenElse(is_x_lt_2, y, z), IfThenElseZero(is_x_large, kLog2)), + sign_x); +} + +template +HWY_INLINE V Atan(const D d, V x) { + using T = TFromD; + + const V kOne = Set(d, static_cast(+1.0)); + const V kPiOverTwo = Set(d, 
static_cast(+1.57079632679489661923132169)); + + const V sign = And(SignBit(d), x); + const V abs_x = Xor(x, sign); + const auto mask = Gt(abs_x, kOne); + + impl::AtanImpl impl; + const auto divisor = IfThenElse(mask, abs_x, kOne); + const V y = impl.AtanPoly(d, IfThenElse(mask, Div(kOne, divisor), abs_x)); + return Or(IfThenElse(mask, Sub(kPiOverTwo, y), y), sign); +} + +template +HWY_INLINE V Atanh(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kOne = Set(d, static_cast(+1.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + return Mul(Log1p(d, Div(Add(abs_x, abs_x), Sub(kOne, abs_x))), + Xor(kHalf, sign)); +} + +template +HWY_INLINE V Cos(const D d, V x) { + using T = TFromD; + impl::CosSinImpl impl; + + // Float Constants + const V kOneOverPi = Set(d, static_cast(0.31830988618379067153)); + + // Integer Constants + const Rebind di32; + using VI32 = decltype(Zero(di32)); + const VI32 kOne = Set(di32, 1); + + const V y = Abs(x); // cos(x) == cos(|x|) + + // Compute the quadrant, q = int(|x| / pi) * 2 + 1 + const VI32 q = Add(ShiftLeft<1>(impl.ToInt32(d, Mul(y, kOneOverPi))), kOne); + + // Reduce range, apply sign, and approximate. + return impl.Poly( + d, Xor(impl.CosReduce(d, y, q), impl.CosSignFromQuadrant(d, q))); +} + +template +HWY_INLINE V Exp(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? -104.0 : -1000.0))); + const V kNegZero = Set(d, static_cast(-0.0)); + const V kOne = Set(d, static_cast(+1.0)); + const V kOneOverLog2 = Set(d, static_cast(+1.442695040888963407359924681)); + + impl::ExpImpl impl; + + // q = static_cast((x / log(2)) + ((x < 0) ? -0.5 : +0.5)) + const auto q = + impl.ToInt32(d, MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero)))); + + // Reduce, approximate, and then reconstruct. + const V y = impl.LoadExpShortRange( + d, Add(impl.ExpPoly(d, impl.ExpReduce(d, x, q)), kOne), q); + return IfThenElseZero(Ge(x, kLowerBound), y); +} + +template +HWY_INLINE V Exp2(const D d, V x) { + using T = TFromD; + + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? -150.0 : -1075.0))); + const V kOne = Set(d, static_cast(+1.0)); + + impl::ExpImpl impl; + + // q = static_cast(std::lrint(x)) + const auto q = impl.ToNearestInt32(d, x); + + // Reduce, approximate, and then reconstruct. + const V y = impl.LoadExpShortRange( + d, Add(impl.ExpPoly(d, impl.Exp2Reduce(d, x, q)), kOne), q); + return IfThenElseZero(Ge(x, kLowerBound), y); +} + +template +HWY_INLINE V Expm1(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? -104.0 : -1000.0))); + const V kLn2Over2 = Set(d, static_cast(+0.346573590279972654708616)); + const V kNegOne = Set(d, static_cast(-1.0)); + const V kNegZero = Set(d, static_cast(-0.0)); + const V kOne = Set(d, static_cast(+1.0)); + const V kOneOverLog2 = Set(d, static_cast(+1.442695040888963407359924681)); + + impl::ExpImpl impl; + + // q = static_cast((x / log(2)) + ((x < 0) ? -0.5 : +0.5)) + const auto q = + impl.ToInt32(d, MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero)))); + + // Reduce, approximate, and then reconstruct. 
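// Illustrative note (not part of the patch): in Exp and Expm1 above, q is
// obtained with MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero))). The
// And(x, kNegZero) keeps only the sign bit of x, and Or-ing it into 0.5 yields
// copysign(0.5, x), so the truncating ToInt32 rounds x/log(2) to the nearest
// integer. Scalar sketch of the same trick:
#include <cmath>
#include <cstdint>
int32_t RoundToNearestViaTruncate(double v) {
  const double half_with_sign_of_v = std::copysign(0.5, v);
  return static_cast<int32_t>(v + half_with_sign_of_v);  // truncates toward zero
}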
+ const V y = impl.ExpPoly(d, impl.ExpReduce(d, x, q)); + const V z = IfThenElse(Lt(Abs(x), kLn2Over2), y, + Sub(impl.LoadExpShortRange(d, Add(y, kOne), q), kOne)); + return IfThenElse(Lt(x, kLowerBound), kNegOne, z); +} + +template +HWY_INLINE V Log(const D d, V x) { + return impl::Log(d, x); +} + +template +HWY_INLINE V Log10(const D d, V x) { + using T = TFromD; + return Mul(Log(d, x), Set(d, static_cast(0.4342944819032518276511))); +} + +template +HWY_INLINE V Log1p(const D d, V x) { + using T = TFromD; + const V kOne = Set(d, static_cast(+1.0)); + + const V y = Add(x, kOne); + const auto is_pole = Eq(y, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y), kOne); + const auto non_pole = + Mul(impl::Log(d, y), Div(x, divisor)); + return IfThenElse(is_pole, x, non_pole); +} + +template +HWY_INLINE V Log2(const D d, V x) { + using T = TFromD; + return Mul(Log(d, x), Set(d, static_cast(1.44269504088896340735992))); +} + +template +HWY_INLINE V Sin(const D d, V x) { + using T = TFromD; + impl::CosSinImpl impl; + + // Float Constants + const V kOneOverPi = Set(d, static_cast(0.31830988618379067153)); + const V kHalf = Set(d, static_cast(0.5)); + + // Integer Constants + const Rebind di32; + using VI32 = decltype(Zero(di32)); + + const V abs_x = Abs(x); + const V sign_x = Xor(abs_x, x); + + // Compute the quadrant, q = int((|x| / pi) + 0.5) + const VI32 q = impl.ToInt32(d, MulAdd(abs_x, kOneOverPi, kHalf)); + + // Reduce range, apply sign, and approximate. + return impl.Poly(d, Xor(impl.SinReduce(d, abs_x, q), + Xor(impl.SinSignFromQuadrant(d, q), sign_x))); +} + +template +HWY_INLINE V Sinh(const D d, V x) { + using T = TFromD; + const V kHalf = Set(d, static_cast(+0.5)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + const V y = Expm1(d, abs_x); + const V z = Mul(Div(Add(y, kTwo), Add(y, kOne)), Mul(y, kHalf)); + return Xor(z, sign); // Reapply the sign bit +} + +template +HWY_INLINE V Tanh(const D d, V x) { + using T = TFromD; + const V kLimit = Set(d, static_cast(18.714973875)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + const V y = Expm1(d, Mul(abs_x, kTwo)); + const V z = IfThenElse(Gt(abs_x, kLimit), kOne, Div(y, Add(y, kTwo))); + return Xor(z, sign); // Reapply the sign bit +} + +template +HWY_INLINE void SinCos(const D d, V x, V& s, V& c) { + using T = TFromD; + impl::SinCosImpl impl; + impl.SinCos(d, x, s, c); +} + +template +HWY_INLINE V Hypot(const D d, V a, V b) { + using T = TFromD; + using TI = MakeSigned; + const RebindToUnsigned du; + const RebindToSigned di; + using VI = VFromD; + + constexpr int kMaxBiasedExp = static_cast(MaxExponentField()); + static_assert(kMaxBiasedExp > 0, "kMaxBiasedExp > 0 must be true"); + + constexpr int kNumOfMantBits = MantissaBits(); + static_assert(kNumOfMantBits > 0, "kNumOfMantBits > 0 must be true"); + + constexpr int kExpBias = kMaxBiasedExp / 2; + + static_assert( + static_cast(kExpBias) + static_cast(kNumOfMantBits) < + static_cast(kMaxBiasedExp), + "kExpBias + kNumOfMantBits < kMaxBiasedExp must be true"); + + // kMinValToSquareBiasedExp is the smallest biased exponent such that + // pow(pow(2, kMinValToSquareBiasedExp - kExpBias) * x, 2) is either a normal + // floating-point value or infinity if x is a non-zero, non-NaN value + constexpr int 
kMinValToSquareBiasedExp = (kExpBias / 2) + kNumOfMantBits;
+  static_assert(kMinValToSquareBiasedExp < kExpBias,
+                "kMinValToSquareBiasedExp < kExpBias must be true");
+
+  // kMaxValToSquareBiasedExp is the largest biased exponent such that
+  // pow(pow(2, kMaxValToSquareBiasedExp - kExpBias) * x, 2) * 2 is guaranteed
+  // to be a finite value if x is a finite value
+  constexpr int kMaxValToSquareBiasedExp = kExpBias + ((kExpBias / 2) - 1);
+  static_assert(kMaxValToSquareBiasedExp > kExpBias,
+                "kMaxValToSquareBiasedExp > kExpBias must be true");
+  static_assert(kMaxValToSquareBiasedExp < kMaxBiasedExp,
+                "kMaxValToSquareBiasedExp < kMaxBiasedExp must be true");
+
+#if HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128 || \
+    HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15
+  using TExpSatSub = MakeUnsigned;
+  using TExpMinMax = TI;
+#else
+  using TExpSatSub = uint16_t;
+  using TExpMinMax = int16_t;
+#endif
+
+  const Repartition d_exp_sat_sub;
+  const Repartition d_exp_min_max;
+
+  const V abs_a = Abs(a);
+  const V abs_b = Abs(b);
+
+  const MFromD either_inf = Or(IsInf(a), IsInf(b));
+
+  const VI zero = Zero(di);
+
+  // exp_a[i] is the biased exponent of abs_a[i]
+  const VI exp_a = BitCast(di, ShiftRight(BitCast(du, abs_a)));
+
+  // exp_b[i] is the biased exponent of abs_b[i]
+  const VI exp_b = BitCast(di, ShiftRight(BitCast(du, abs_b)));
+
+  // max_exp[i] is equal to HWY_MAX(exp_a[i], exp_b[i])
+
+  // If abs_a[i] and abs_b[i] are both NaN values, max_exp[i] will be equal to
+  // the biased exponent of the larger value. Otherwise, if either abs_a[i] or
+  // abs_b[i] is NaN, max_exp[i] will be equal to kMaxBiasedExp.
+  const VI max_exp = BitCast(
+      di, Max(BitCast(d_exp_min_max, exp_a), BitCast(d_exp_min_max, exp_b)));
+
+  // If either abs_a[i] or abs_b[i] is zero, min_exp[i] is equal to max_exp[i].
+  // Otherwise, if abs_a[i] and abs_b[i] are both nonzero, min_exp[i] is equal
+  // to HWY_MIN(exp_a[i], exp_b[i]).
+  const VI min_exp = IfThenElse(
+      Or(Eq(BitCast(di, abs_a), zero), Eq(BitCast(di, abs_b), zero)), max_exp,
+      BitCast(di, Min(BitCast(d_exp_min_max, exp_a),
+                      BitCast(d_exp_min_max, exp_b))));
+
+  // scl_pow2[i] is the power of 2 to scale abs_a[i] and abs_b[i] by
+
+  // abs_a[i] and abs_b[i] should be scaled by a factor that is greater than
+  // zero but less than or equal to
+  // pow(2, kMaxValToSquareBiasedExp - max_exp[i]) to ensure that the
+  // multiplications or addition operations do not overflow if
+  // std::hypot(abs_a[i], abs_b[i]) is finite
+
+  // If either abs_a[i] or abs_b[i] is a positive value that is less than
+  // pow(2, kMinValToSquareBiasedExp - kExpBias), then scaling up abs_a[i] and
+  // abs_b[i] by pow(2, kMinValToSquareBiasedExp - min_exp[i]) will ensure that
+  // the multiplications and additions result in normal floating point values,
+  // infinities, or NaNs.
+
+  // If HWY_MAX(kMinValToSquareBiasedExp - min_exp[i], 0) is greater than
+  // kMaxValToSquareBiasedExp - max_exp[i], scale abs_a[i] and abs_b[i] up by
+  // pow(2, kMaxValToSquareBiasedExp - max_exp[i]) to ensure that the
+  // multiplication and addition operations result in a finite result if
+  // std::hypot(abs_a[i], abs_b[i]) is finite.
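// Illustrative sketch (not part of the patch): a simplified scalar version of
// the scaling described above: pull out a power of two based on the larger
// exponent, square and add the scaled operands, then restore the scale. The
// vector code below refines this with a min_exp-based scale (so that squaring
// small inputs stays in the normal range) and handles inf/NaN explicitly.
#include <algorithm>
#include <cmath>
double ScalarHypotSketch(double a, double b) {
  a = std::fabs(a);
  b = std::fabs(b);
  if (a == 0.0) return b;
  if (b == 0.0) return a;
  int ea, eb;
  (void)std::frexp(a, &ea);
  (void)std::frexp(b, &eb);
  const int e = std::max(ea, eb);       // exponent of the larger operand
  const double sa = std::ldexp(a, -e);  // scaled so squaring cannot overflow
  const double sb = std::ldexp(b, -e);
  return std::ldexp(std::sqrt(sa * sa + sb * sb), e);  // undo the scaling
}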
+ + const VI scl_pow2 = BitCast( + di, + Min(BitCast(d_exp_min_max, + SaturatedSub(BitCast(d_exp_sat_sub, + Set(di, static_cast( + kMinValToSquareBiasedExp))), + BitCast(d_exp_sat_sub, min_exp))), + BitCast(d_exp_min_max, + Sub(Set(di, static_cast(kMaxValToSquareBiasedExp)), + max_exp)))); + + const VI exp_bias = Set(di, static_cast(kExpBias)); + + const V ab_scl_factor = + BitCast(d, ShiftLeft(Add(exp_bias, scl_pow2))); + const V hypot_scl_factor = + BitCast(d, ShiftLeft(Sub(exp_bias, scl_pow2))); + + const V scl_a = Mul(abs_a, ab_scl_factor); + const V scl_b = Mul(abs_b, ab_scl_factor); + + const V scl_hypot = Sqrt(MulAdd(scl_a, scl_a, Mul(scl_b, scl_b))); + // std::hypot returns inf if one input is +/- inf, even if the other is NaN. + return IfThenElse(either_inf, Inf(d), Mul(scl_hypot, hypot_scl_factor)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/matvec/matvec-inl.h b/third_party/aom/third_party/highway/hwy/contrib/matvec/matvec-inl.h new file mode 100644 index 000000000000..32cd8df1473f --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/matvec/matvec-inl.h @@ -0,0 +1,451 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard (still compiled once per target) +#if defined(HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_ +#endif + +#include + +#include "third_party/highway/hwy/cache_control.h" +#include "third_party/highway/hwy/contrib/thread_pool/thread_pool.h" +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +TA AddScalar(TA a, TB b) { + return ConvertScalarTo(ConvertScalarTo(a) + + ConvertScalarTo(b)); +} + +template +HWY_NOINLINE void MatVecAddImpl(const T* HWY_RESTRICT mat, + const T* HWY_RESTRICT vec, + const T* HWY_RESTRICT add, T* HWY_RESTRICT out, + hwy::ThreadPool& pool) { + (void)add; + + // Process multiple rows at a time so that we write multiples of a cache line + // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little + // parallelization potential. + constexpr size_t kChunkSize = 64 / sizeof(T); + const uint64_t num_chunks = static_cast(kOuter / kChunkSize); + + const ScalableTag d; + const size_t N = Lanes(d); + // Required for Stream loop, otherwise we might have partial vectors. + HWY_DASSERT(kChunkSize >= N); + pool.Run(0, num_chunks, + [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR { + // MSVC workaround: duplicate to ensure constexpr. + constexpr size_t kChunkSize = 64 / sizeof(T); + // Software write-combining to avoid cache pollution from out. 
+ // Although `out` may be used later, keeping it out of the cache + // now and avoiding RFOs is a consistent 5% overall win. + HWY_ALIGN T buf[kChunkSize]; + + // Only handle entire chunks here because the Stream is not masked. + // Remaining rows are handled after the pool.Run. + const size_t begin = static_cast(chunk * kChunkSize); + for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) { + auto sum0 = Zero(d); + auto sum1 = Zero(d); + // 4x unrolling barely helps SKX but likely helps Arm V2. + auto sum2 = Zero(d); + auto sum3 = Zero(d); + + const T* HWY_RESTRICT row = &mat[(begin + idx_row) * kInner]; + size_t i = 0; + // No clear win from prefetching from the next 1..3 rows. + // clflush &row[i] is slow, clflushopt less so but not helping. + HWY_UNROLL(1) + for (; i + 4 * N <= kInner; i += 4 * N) { + const auto a0 = LoadU(d, row + i + 0 * N); + const auto v0 = LoadU(d, vec + i + 0 * N); + sum0 = MulAdd(a0, v0, sum0); + + const auto a1 = LoadU(d, row + i + 1 * N); + const auto v1 = LoadU(d, vec + i + 1 * N); + sum1 = MulAdd(a1, v1, sum1); + + const auto a2 = LoadU(d, row + i + 2 * N); + const auto v2 = LoadU(d, vec + i + 2 * N); + sum2 = MulAdd(a2, v2, sum2); + + const auto a3 = LoadU(d, row + i + 3 * N); + const auto v3 = LoadU(d, vec + i + 3 * N); + sum3 = MulAdd(a3, v3, sum3); + } + // Last entire vectors + for (; i + N <= kInner; i += N) { + const auto a0 = LoadU(d, row + i); + const auto v0 = LoadU(d, vec + i); + sum0 = MulAdd(a0, v0, sum0); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const auto a0 = LoadN(d, row + i, remainder); + const auto v0 = LoadN(d, vec + i, remainder); + sum1 = MulAdd(a0, v0, sum1); + } + // Reduction tree: sum of all accumulators, then their lanes + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum1); + sum0 = Add(sum0, sum2); + buf[idx_row] = ReduceSum(d, sum0); + HWY_IF_CONSTEXPR(kAdd) { + buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]); + } + } // idx_row + HWY_UNROLL(4) // 1..4 iterations + for (size_t i = 0; i != kChunkSize; i += N) { + Stream(Load(d, buf + i), d, out + begin + i); + } + }); + hwy::FlushStream(); + + // Handle remainder rows which are not a multiple of the chunk size. + for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) { + auto sum0 = Zero(d); + + const T* HWY_RESTRICT row = &mat[r * kInner]; + size_t i = 0; + HWY_UNROLL(1) + for (; i + N <= kInner; i += N) { + const auto a0 = LoadU(d, row + i); + const auto v0 = LoadU(d, vec + i); + sum0 = MulAdd(a0, v0, sum0); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const auto a0 = LoadN(d, row + i, remainder); + const auto v0 = LoadN(d, vec + i, remainder); + sum0 = MulAdd(a0, v0, sum0); + } + out[r] = ReduceSum(d, sum0); + HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); } + } // r +} + +// Multiplies mat with vec, adds add and puts the result in out. +// +// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at +// index i * kInner + j. +// +// vec is a (kInner,)-shaped array. +// +// add is a (kOuter,)-shaped array. +// +// out is a (kOuter,)-shaped array that will set to mat @ vec + add. +template +HWY_NOINLINE void MatVecAdd(const T* HWY_RESTRICT mat, + const T* HWY_RESTRICT vec, + const T* HWY_RESTRICT add, T* HWY_RESTRICT out, + hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, add, out, pool); +} + +// Multiplies mat with vec and puts the result in out. +// +// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at +// index i * kInner + j. 
+// +// vec is a (kInner,)-shaped array. +// +// out is a (kOuter,)-shaped array that will set to mat @ vec. +template +HWY_NOINLINE void MatVec(const T* HWY_RESTRICT mat, const T* HWY_RESTRICT vec, + T* HWY_RESTRICT out, hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, /*add=*/nullptr, out, pool); +} + +// This target lacks too many ops required in our implementation, use +// HWY_EMU128 instead. +#if HWY_TARGET != HWY_SCALAR + +// Specialization for bf16 matrix, which halves memory bandwidth requirements. +template +HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat, + const float* HWY_RESTRICT vec, + const float* HWY_RESTRICT add, + float* HWY_RESTRICT out, + hwy::ThreadPool& pool) { + // Process multiple rows at a time so that we write multiples of a cache line + // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little + // parallelization potential. + constexpr size_t kChunkSize = 64 / sizeof(float); + const uint64_t num_chunks = static_cast(kOuter / kChunkSize); + + const ScalableTag d; + const Repartition d16; + // In the remainder loop, we only process a single f32 vector, so load half + // vectors of bf16 to avoid overrun. + const Half d16h; + using V = Vec; + using V16 = Vec; + using V16H = Vec; + const size_t N = Lanes(d); + // Required for Stream loop, otherwise we might have partial vectors. + HWY_DASSERT(kChunkSize >= N); + pool.Run(0, num_chunks, + [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR { + // MSVC workaround: duplicate to ensure constexpr. + constexpr size_t kChunkSize = 64 / sizeof(float); + // Software write-combining to avoid cache pollution from out. + // Although `out` may be used later, keeping it out of the cache + // now and avoiding RFOs is a consistent 5% overall win. + HWY_ALIGN float buf[kChunkSize]; + + // Only handle entire chunks here because the Stream is not masked. + // Remaining rows are handled after the pool.Run. + const size_t begin = static_cast(chunk * kChunkSize); + for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) { + auto sum0 = Zero(d); + auto sum1 = Zero(d); + // 4x unrolling barely helps SKX but likely helps Arm V2. + auto sum2 = Zero(d); + auto sum3 = Zero(d); + + const hwy::bfloat16_t* HWY_RESTRICT row = + &mat[(begin + idx_row) * kInner]; + size_t i = 0; + // No clear win from prefetching from the next 1..3 rows. + // clflush &row[i] is slow, clflushopt less so but not helping. 
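// Illustrative usage sketch (not part of the patch): how a caller might invoke
// the MatVec/MatVecAdd wrappers defined above, assuming the row/column counts
// are the first two template arguments (as the Impl functions suggest) and
// that hwy::ThreadPool is constructed with a worker-thread count. The
// dimensions and pool size here are made up for the example.
void ExampleMatVecUsage(const float* HWY_RESTRICT mat,   // kOuter x kInner, row-major
                        const float* HWY_RESTRICT vec,   // kInner elements
                        const float* HWY_RESTRICT bias,  // kOuter elements
                        float* HWY_RESTRICT out) {       // kOuter elements
  constexpr size_t kOuter = 256;
  constexpr size_t kInner = 512;
  hwy::ThreadPool pool(4);                                // 4 worker threads
  MatVecAdd<kOuter, kInner>(mat, vec, bias, out, pool);   // out = mat @ vec + bias
  MatVec<kOuter, kInner>(mat, vec, out, pool);            // out = mat @ vec
}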
+ HWY_UNROLL(1) + for (; i + 4 * N <= kInner; i += 4 * N) { + const V16 b0 = LoadU(d16, row + i + 0 * N); + const V a0 = PromoteLowerTo(d, b0); + const V a1 = PromoteUpperTo(d, b0); + + const V16 b1 = LoadU(d16, row + i + 2 * N); + const V a2 = PromoteLowerTo(d, b1); + const V a3 = PromoteUpperTo(d, b1); + + const V v0 = LoadU(d, vec + i + 0 * N); + sum0 = MulAdd(a0, v0, sum0); + + const V v1 = LoadU(d, vec + i + 1 * N); + sum1 = MulAdd(a1, v1, sum1); + + const V v2 = LoadU(d, vec + i + 2 * N); + sum2 = MulAdd(a2, v2, sum2); + + const V v3 = LoadU(d, vec + i + 3 * N); + sum3 = MulAdd(a3, v3, sum3); + } + // Last entire vectors + for (; i + N <= kInner; i += N) { + const V16H b0 = LoadU(d16h, row + i); + const V a0 = PromoteTo(d, b0); + const V v0 = LoadU(d, vec + i); + sum0 = MulAdd(a0, v0, sum0); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const V16H b0 = LoadN(d16h, row + i, remainder); + const V a0 = PromoteTo(d, b0); + const V v0 = LoadN(d, vec + i, remainder); + sum1 = MulAdd(a0, v0, sum1); + } + // Reduction tree: sum of all accumulators, then their lanes + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum1); + sum0 = Add(sum0, sum2); + buf[idx_row] = ReduceSum(d, sum0); + HWY_IF_CONSTEXPR(kAdd) { + buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]); + } + } // idx_row + HWY_UNROLL(4) // 1..4 iterations + for (size_t i = 0; i != kChunkSize; i += N) { + Stream(Load(d, buf + i), d, out + begin + i); + } + }); + hwy::FlushStream(); + + // Handle remainder rows which are not a multiple of the chunk size. + for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) { + auto sum0 = Zero(d); + + const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner]; + size_t i = 0; + HWY_UNROLL(1) + for (; i + N <= kInner; i += N) { + const V16H b0 = LoadU(d16h, row + i); + const V a0 = PromoteTo(d, b0); + const V v0 = LoadU(d, vec + i); + sum0 = MulAdd(a0, v0, sum0); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const V16H b0 = LoadN(d16h, row + i, remainder); + const V a0 = PromoteTo(d, b0); + const V v0 = LoadN(d, vec + i, remainder); + sum0 = MulAdd(a0, v0, sum0); + } + out[r] = ReduceSum(d, sum0); + HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); } + } // r +} + +template +HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat, + const float* HWY_RESTRICT vec, + const float* HWY_RESTRICT add, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, add, out, pool); +} + +template +HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat, + const float* HWY_RESTRICT vec, float* HWY_RESTRICT out, + hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, /*add=*/nullptr, out, pool); +} + +// Both mat and vec are bf16. +template +HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat, + const hwy::bfloat16_t* HWY_RESTRICT vec, + const hwy::bfloat16_t* HWY_RESTRICT add, + float* HWY_RESTRICT out, + hwy::ThreadPool& pool) { + // Process multiple rows at a time so that we write multiples of a cache line + // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little + // parallelization potential. + constexpr size_t kChunkSize = 64 / sizeof(bfloat16_t); + const uint64_t num_chunks = static_cast(kOuter / kChunkSize); + + const ScalableTag df; + const Repartition d16; + using V16 = Vec; + const size_t N = Lanes(d16); + // Required for Stream loop, otherwise we might have partial vectors. 
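// Illustrative sketch (not part of the patch): each row handled by the loop
// below is a bf16 dot product accumulated in f32. ReorderWidenMulAccumulate
// widens pairs of bf16 lanes, multiplies, and adds into two f32 accumulators
// whose lane order does not matter because they are combined with ReduceSum at
// the end. Scalar equivalent, assuming hwy::F32FromBF16 for the widening:
float ScalarBF16DotSketch(const hwy::bfloat16_t* HWY_RESTRICT row,
                          const hwy::bfloat16_t* HWY_RESTRICT vec,
                          size_t num) {
  float sum = 0.0f;
  for (size_t i = 0; i < num; ++i) {
    sum += hwy::F32FromBF16(row[i]) * hwy::F32FromBF16(vec[i]);
  }
  return sum;
}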
+ HWY_DASSERT(kChunkSize >= N); + pool.Run(0, num_chunks, + [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR { + // MSVC workaround: duplicate to ensure constexpr. + constexpr size_t kChunkSize = 64 / sizeof(bfloat16_t); + // Software write-combining to avoid cache pollution from out. + // Although `out` may be used later, keeping it out of the cache + // now and avoiding RFOs is a consistent 5% overall win. + HWY_ALIGN float buf[kChunkSize]; + + // Only handle entire chunks here because the Stream is not masked. + // Remaining rows are handled after the pool.Run. + const size_t begin = static_cast(chunk * kChunkSize); + for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) { + auto sum0 = Zero(df); + auto sum1 = Zero(df); + auto sum2 = Zero(df); + auto sum3 = Zero(df); + + const hwy::bfloat16_t* HWY_RESTRICT row = + &mat[(begin + idx_row) * kInner]; + size_t i = 0; + // No clear win from prefetching from the next 1..3 rows. + // clflush &row[i] is slow, clflushopt less so but not helping. + HWY_UNROLL(1) + for (; i + 2 * N <= kInner; i += 2 * N) { + const V16 b0 = LoadU(d16, row + i + 0 * N); + const V16 b1 = LoadU(d16, row + i + 1 * N); + const V16 v0 = LoadU(d16, vec + i + 0 * N); + const V16 v1 = LoadU(d16, vec + i + 1 * N); + sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1); + sum2 = ReorderWidenMulAccumulate(df, b1, v1, sum2, sum3); + } + // Last entire vector + for (; i + N <= kInner; i += N) { + const V16 b0 = LoadU(d16, row + i); + const V16 v0 = LoadU(d16, vec + i); + sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const V16 b0 = LoadN(d16, row + i, remainder); + const V16 v0 = LoadN(d16, vec + i, remainder); + sum2 = ReorderWidenMulAccumulate(df, b0, v0, sum2, sum3); + } + // Reduction tree: sum of all accumulators, then their lanes + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + buf[idx_row] = ReduceSum(df, sum0); + HWY_IF_CONSTEXPR(kAdd) { + buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]); + } + } // idx_row + HWY_UNROLL(4) // 1..4 iterations + for (size_t i = 0; i != kChunkSize; i += N / 2) { + Stream(Load(df, buf + i), df, out + begin + i); + } + }); + hwy::FlushStream(); + + // Handle remainder rows which are not a multiple of the chunk size. 
+ for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) { + auto sum0 = Zero(df); + auto sum1 = Zero(df); + + const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner]; + size_t i = 0; + HWY_UNROLL(1) + for (; i + N <= kInner; i += N) { + const V16 b0 = LoadU(d16, row + i); + const V16 v0 = LoadU(d16, vec + i); + sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const V16 b0 = LoadN(d16, row + i, remainder); + const V16 v0 = LoadN(d16, vec + i, remainder); + sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1); + } + out[r] = ReduceSum(df, Add(sum0, sum1)); + HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); } + } // r +} + +template +HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat, + const hwy::bfloat16_t* HWY_RESTRICT vec, + const hwy::bfloat16_t* HWY_RESTRICT add, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, add, out, pool); +} + +template +HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat, + const hwy::bfloat16_t* HWY_RESTRICT vec, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, /*add=*/nullptr, out, pool); +} + +#endif // HWY_TARGET != HWY_SCALAR + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/random/random-inl.h b/third_party/aom/third_party/highway/hwy/contrib/random/random-inl.h new file mode 100644 index 000000000000..b96ef8a4ecc8 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/random/random-inl.h @@ -0,0 +1,384 @@ +/* + * Original implementation written in 2019 + * by David Blackman and Sebastiano Vigna (vigna@acm.org) + * Available at https://prng.di.unimi.it/ with creative commons license: + * To the extent possible under law, the author has dedicated all copyright + * and related and neighboring rights to this software to the public domain + * worldwide. This software is distributed without any warranty. + * See . + * + * This implementation is a Vector port of the original implementation + * written by Marco Barbone (m.barbone19@imperial.ac.uk). + * I take no credit for the original implementation. + * The code is provided as is and the original license applies. 
+ */ + +#if defined(HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_ +#undef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_ +#else +#define HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_ +#endif + +#include +#include +#include + +#include "third_party/highway/hwy/aligned_allocator.h" +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); // required if not using HWY_ATTR + +namespace hwy { + +namespace HWY_NAMESPACE { // required: unique per target +namespace internal { + +namespace { +#if HWY_HAVE_FLOAT64 +// C++ < 17 does not support hexfloat +#if __cpp_hex_float > 201603L +constexpr double kMulConst = 0x1.0p-53; +#else +constexpr double kMulConst = + 0.00000000000000011102230246251565404236316680908203125; +#endif // __cpp_hex_float + +#endif // HWY_HAVE_FLOAT64 + +constexpr std::uint64_t kJump[] = {0x180ec6d33cfd0aba, 0xd5a61266f0c9392c, + 0xa9582618e03fc9aa, 0x39abdc4529b1661c}; + +constexpr std::uint64_t kLongJump[] = {0x76e15d3efefdcbbf, 0xc5004e441c522fb3, + 0x77710069854ee241, 0x39109bb02acbe635}; +} // namespace + +class SplitMix64 { + public: + constexpr explicit SplitMix64(const std::uint64_t state) noexcept + : state_(state) {} + + HWY_CXX14_CONSTEXPR std::uint64_t operator()() { + std::uint64_t z = (state_ += 0x9e3779b97f4a7c15); + z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9; + z = (z ^ (z >> 27)) * 0x94d049bb133111eb; + return z ^ (z >> 31); + } + + private: + std::uint64_t state_; +}; + +class Xoshiro { + public: + HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed) noexcept + : state_{} { + SplitMix64 splitMix64{seed}; + for (auto &element : state_) { + element = splitMix64(); + } + } + + HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed, + const std::uint64_t thread_id) noexcept + : Xoshiro(seed) { + for (auto i = UINT64_C(0); i < thread_id; ++i) { + Jump(); + } + } + + HWY_CXX14_CONSTEXPR std::uint64_t operator()() noexcept { return Next(); } + +#if HWY_HAVE_FLOAT64 + HWY_CXX14_CONSTEXPR double Uniform() noexcept { + return static_cast(Next() >> 11) * kMulConst; + } +#endif + + HWY_CXX14_CONSTEXPR std::array GetState() const { + return {state_[0], state_[1], state_[2], state_[3]}; + } + + HWY_CXX17_CONSTEXPR void SetState( + std::array state) noexcept { + state_[0] = state[0]; + state_[1] = state[1]; + state_[2] = state[2]; + state_[3] = state[3]; + } + + static constexpr std::uint64_t StateSize() noexcept { return 4; } + + /* This is the jump function for the generator. It is equivalent to 2^128 + * calls to next(); it can be used to generate 2^128 non-overlapping + * subsequences for parallel computations. */ + HWY_CXX14_CONSTEXPR void Jump() noexcept { Jump(kJump); } + + /* This is the long-jump function for the generator. It is equivalent to 2^192 + * calls to next(); it can be used to generate 2^64 starting points, from each + * of which jump() will generate 2^64 non-overlapping subsequences for + * parallel distributed computations. 
*/ + HWY_CXX14_CONSTEXPR void LongJump() noexcept { Jump(kLongJump); } + + private: + std::uint64_t state_[4]; + + static constexpr std::uint64_t Rotl(const std::uint64_t x, int k) noexcept { + return (x << k) | (x >> (64 - k)); + } + + HWY_CXX14_CONSTEXPR std::uint64_t Next() noexcept { + const std::uint64_t result = Rotl(state_[0] + state_[3], 23) + state_[0]; + const std::uint64_t t = state_[1] << 17; + + state_[2] ^= state_[0]; + state_[3] ^= state_[1]; + state_[1] ^= state_[2]; + state_[0] ^= state_[3]; + + state_[2] ^= t; + + state_[3] = Rotl(state_[3], 45); + + return result; + } + + HWY_CXX14_CONSTEXPR void Jump(const std::uint64_t (&jumpArray)[4]) noexcept { + std::uint64_t s0 = 0; + std::uint64_t s1 = 0; + std::uint64_t s2 = 0; + std::uint64_t s3 = 0; + + for (const std::uint64_t i : jumpArray) + for (std::uint_fast8_t b = 0; b < 64; b++) { + if (i & std::uint64_t{1UL} << b) { + s0 ^= state_[0]; + s1 ^= state_[1]; + s2 ^= state_[2]; + s3 ^= state_[3]; + } + Next(); + } + + state_[0] = s0; + state_[1] = s1; + state_[2] = s2; + state_[3] = s3; + } +}; + +} // namespace internal + +class VectorXoshiro { + private: + using VU64 = Vec>; + using StateType = AlignedNDArray; +#if HWY_HAVE_FLOAT64 + using VF64 = Vec>; +#endif + public: + explicit VectorXoshiro(const std::uint64_t seed, + const std::uint64_t threadNumber = 0) + : state_{{internal::Xoshiro::StateSize(), + Lanes(ScalableTag{})}}, + streams{state_.shape().back()} { + internal::Xoshiro xoshiro{seed}; + + for (std::uint64_t i = 0; i < threadNumber; ++i) { + xoshiro.LongJump(); + } + + for (size_t i = 0UL; i < streams; ++i) { + const auto state = xoshiro.GetState(); + for (size_t j = 0UL; j < internal::Xoshiro::StateSize(); ++j) { + state_[{j}][i] = state[j]; + } + xoshiro.Jump(); + } + } + + HWY_INLINE VU64 operator()() noexcept { return Next(); } + + AlignedVector operator()(const std::size_t n) { + AlignedVector result(n); + const ScalableTag tag{}; + auto s0 = Load(tag, state_[{0}].data()); + auto s1 = Load(tag, state_[{1}].data()); + auto s2 = Load(tag, state_[{2}].data()); + auto s3 = Load(tag, state_[{3}].data()); + for (std::uint64_t i = 0; i < n; i += Lanes(tag)) { + const auto next = Update(s0, s1, s2, s3); + Store(next, tag, result.data() + i); + } + Store(s0, tag, state_[{0}].data()); + Store(s1, tag, state_[{1}].data()); + Store(s2, tag, state_[{2}].data()); + Store(s3, tag, state_[{3}].data()); + return result; + } + + template + std::array operator()() noexcept { + alignas(HWY_ALIGNMENT) std::array result; + const ScalableTag tag{}; + auto s0 = Load(tag, state_[{0}].data()); + auto s1 = Load(tag, state_[{1}].data()); + auto s2 = Load(tag, state_[{2}].data()); + auto s3 = Load(tag, state_[{3}].data()); + for (std::uint64_t i = 0; i < N; i += Lanes(tag)) { + const auto next = Update(s0, s1, s2, s3); + Store(next, tag, result.data() + i); + } + Store(s0, tag, state_[{0}].data()); + Store(s1, tag, state_[{1}].data()); + Store(s2, tag, state_[{2}].data()); + Store(s3, tag, state_[{3}].data()); + return result; + } + + std::uint64_t StateSize() const noexcept { + return streams * internal::Xoshiro::StateSize(); + } + + const StateType &GetState() const { return state_; } + +#if HWY_HAVE_FLOAT64 + + HWY_INLINE VF64 Uniform() noexcept { + const ScalableTag real_tag{}; + const auto MUL_VALUE = Set(real_tag, internal::kMulConst); + const auto bits = ShiftRight<11>(Next()); + const auto real = ConvertTo(real_tag, bits); + return Mul(real, MUL_VALUE); + } + + AlignedVector Uniform(const std::size_t n) { + AlignedVector 
result(n); + const ScalableTag tag{}; + const ScalableTag real_tag{}; + const auto MUL_VALUE = Set(real_tag, internal::kMulConst); + + auto s0 = Load(tag, state_[{0}].data()); + auto s1 = Load(tag, state_[{1}].data()); + auto s2 = Load(tag, state_[{2}].data()); + auto s3 = Load(tag, state_[{3}].data()); + + for (std::uint64_t i = 0; i < n; i += Lanes(real_tag)) { + const auto next = Update(s0, s1, s2, s3); + const auto bits = ShiftRight<11>(next); + const auto real = ConvertTo(real_tag, bits); + const auto uniform = Mul(real, MUL_VALUE); + Store(uniform, real_tag, result.data() + i); + } + + Store(s0, tag, state_[{0}].data()); + Store(s1, tag, state_[{1}].data()); + Store(s2, tag, state_[{2}].data()); + Store(s3, tag, state_[{3}].data()); + return result; + } + + template + std::array Uniform() noexcept { + alignas(HWY_ALIGNMENT) std::array result; + const ScalableTag tag{}; + const ScalableTag real_tag{}; + const auto MUL_VALUE = Set(real_tag, internal::kMulConst); + + auto s0 = Load(tag, state_[{0}].data()); + auto s1 = Load(tag, state_[{1}].data()); + auto s2 = Load(tag, state_[{2}].data()); + auto s3 = Load(tag, state_[{3}].data()); + + for (std::uint64_t i = 0; i < N; i += Lanes(real_tag)) { + const auto next = Update(s0, s1, s2, s3); + const auto bits = ShiftRight<11>(next); + const auto real = ConvertTo(real_tag, bits); + const auto uniform = Mul(real, MUL_VALUE); + Store(uniform, real_tag, result.data() + i); + } + + Store(s0, tag, state_[{0}].data()); + Store(s1, tag, state_[{1}].data()); + Store(s2, tag, state_[{2}].data()); + Store(s3, tag, state_[{3}].data()); + return result; + } + +#endif + + private: + StateType state_; + const std::uint64_t streams; + + HWY_INLINE static VU64 Update(VU64 &s0, VU64 &s1, VU64 &s2, + VU64 &s3) noexcept { + const auto result = Add(RotateRight<41>(Add(s0, s3)), s0); + const auto t = ShiftLeft<17>(s1); + s2 = Xor(s2, s0); + s3 = Xor(s3, s1); + s1 = Xor(s1, s2); + s0 = Xor(s0, s3); + s2 = Xor(s2, t); + s3 = RotateRight<19>(s3); + return result; + } + + HWY_INLINE VU64 Next() noexcept { + const ScalableTag tag{}; + auto s0 = Load(tag, state_[{0}].data()); + auto s1 = Load(tag, state_[{1}].data()); + auto s2 = Load(tag, state_[{2}].data()); + auto s3 = Load(tag, state_[{3}].data()); + auto result = Update(s0, s1, s2, s3); + Store(s0, tag, state_[{0}].data()); + Store(s1, tag, state_[{1}].data()); + Store(s2, tag, state_[{2}].data()); + Store(s3, tag, state_[{3}].data()); + return result; + } +}; + +template +class CachedXoshiro { + public: + using result_type = std::uint64_t; + + static constexpr result_type(min)() { + return (std::numeric_limits::min)(); + } + + static constexpr result_type(max)() { + return (std::numeric_limits::max)(); + } + + explicit CachedXoshiro(const result_type seed, + const result_type threadNumber = 0) + : generator_{seed, threadNumber}, + cache_{generator_.operator()()}, + index_{0} {} + + result_type operator()() noexcept { + if (HWY_UNLIKELY(index_ == size)) { + cache_ = std::move(generator_.operator()()); + index_ = 0; + } + return cache_[index_++]; + } + + private: + VectorXoshiro generator_; + alignas(HWY_ALIGNMENT) std::array cache_; + std::size_t index_; + + static_assert((size & (size - 1)) == 0 && size != 0, + "only power of 2 are supported"); +}; + +} // namespace HWY_NAMESPACE +} // namespace hwy + +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ \ No newline at end of file diff --git a/third_party/aom/third_party/highway/hwy/contrib/sort/BUILD 
b/third_party/aom/third_party/highway/hwy/contrib/sort/BUILD new file mode 100644 index 000000000000..9dd625f5a3e2 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/sort/BUILD @@ -0,0 +1,265 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +# Unused on Bazel builds, where this is not defined/known; Copybara replaces +# usages with an empty list. +COMPAT = [ + "//buildenv/target:non_prod", # includes mobile/vendor. +] + +cc_library( + name = "intel", + # hdrs = select({ + # "//third_party/bazel_platforms/cpu:x86_64": [ + # "avx512-16bit-common.h", + # "avx512-16bit-qsort.hpp", + # "avx512-32bit-qsort.hpp", + # "avx512-64bit-common.h", + # "avx512-64bit-qsort.hpp", + # "avx512-common-qsort.h", + # ], + # "//conditions:default": [], + # }), + compatible_with = [], +) + +cc_library( + name = "vxsort", + srcs = [ + # "vxsort/isa_detection.cpp", + # "vxsort/isa_detection_msvc.cpp", + # "vxsort/isa_detection_sane.cpp", + # "vxsort/machine_traits.avx2.cpp", + # "vxsort/smallsort/avx2_load_mask_tables.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.double.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.float.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.int32_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.int64_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.uint32_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.uint64_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.double.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.float.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.int32_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.int64_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.uint32_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.uint64_t.generated.cpp", + # "vxsort/vxsort_stats.cpp", + ], + hdrs = [ + # "vxsort/alignment.h", + # "vxsort/defs.h", + # "vxsort/isa_detection.h", + # "vxsort/machine_traits.avx2.h", + # "vxsort/machine_traits.avx512.h", + # "vxsort/machine_traits.h", + # "vxsort/packer.h", + # "vxsort/smallsort/bitonic_sort.AVX2.double.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX2.float.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX2.int32_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX2.int64_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX2.uint32_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX2.uint64_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.double.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.float.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.int32_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.int64_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.uint32_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.uint64_t.generated.h", + # "vxsort/smallsort/bitonic_sort.h", + # "vxsort/vxsort.h", + # "vxsort/vxsort_stats.h", + ], + compatible_with = [], + textual_hdrs = [ + # "vxsort/vxsort_targets_disable.h", + # "vxsort/vxsort_targets_enable_avx2.h", + # "vxsort/vxsort_targets_enable_avx512.h", + ], +) + +VQSORT_SRCS = [ + "vqsort.cc", + # Split into separate files to reduce MSVC build time. 
+ "vqsort_128a.cc", + "vqsort_128d.cc", + "vqsort_f16a.cc", + "vqsort_f16d.cc", + "vqsort_f32a.cc", + "vqsort_f32d.cc", + "vqsort_f64a.cc", + "vqsort_f64d.cc", + "vqsort_i16a.cc", + "vqsort_i16d.cc", + "vqsort_i32a.cc", + "vqsort_i32d.cc", + "vqsort_i64a.cc", + "vqsort_i64d.cc", + "vqsort_kv64a.cc", + "vqsort_kv64d.cc", + "vqsort_kv128a.cc", + "vqsort_kv128d.cc", + "vqsort_u16a.cc", + "vqsort_u16d.cc", + "vqsort_u32a.cc", + "vqsort_u32d.cc", + "vqsort_u64a.cc", + "vqsort_u64d.cc", +] + +VQSORT_TEXTUAL_HDRS = [ + "shared-inl.h", + "sorting_networks-inl.h", + "traits-inl.h", + "traits128-inl.h", + "vqsort-inl.h", + # Placeholder for internal instrumentation. Do not remove. +] + +cc_library( + name = "vqsort", + srcs = VQSORT_SRCS, + hdrs = [ + "order.h", # part of public interface, included by vqsort.h + "vqsort.h", # public interface + ], + compatible_with = [], + local_defines = ["hwy_contrib_EXPORTS"], + textual_hdrs = VQSORT_TEXTUAL_HDRS, + deps = [ + ":intel", # required if HAVE_INTEL + ":vxsort", # required if HAVE_VXSORT + "//:algo", + "//:hwy", + ], +) + +# ----------------------------------------------------------------------------- +# Internal-only targets + +# Same as vqsort, but add HWY_COMPILE_ALL_ATTAINABLE to ensure we cover all +# targets. Do not enable this in the main vqsort because it increases +# compile times. +cc_library( + name = "vqsort_for_test", + srcs = VQSORT_SRCS, + hdrs = [ + "order.h", # part of public interface, included by vqsort.h + "vqsort.h", # public interface + ], + compatible_with = [], + local_defines = [ + "hwy_contrib_EXPORTS", + # Build for all targets because sort_test will dynamic-dispatch to all. + "HWY_COMPILE_ALL_ATTAINABLE", + ], + textual_hdrs = VQSORT_TEXTUAL_HDRS, + deps = [ + "//:algo", + "//:hwy", + ], +) + +cc_library( + name = "helpers", + testonly = 1, + textual_hdrs = [ + "algo-inl.h", + "result-inl.h", + ], + deps = [ + ":vqsort", + "//:nanobenchmark", + # Required for HAVE_PDQSORT, but that is unused and this is + # unavailable to Bazel builds, hence commented out. + # "//third_party/boost/allowed", + # Avoid ips4o and thus TBB to work around hwloc build failure. + ], +) + +cc_binary( + name = "print_network", + testonly = 1, + srcs = ["print_network.cc"], + deps = [ + ":helpers", + ":vqsort", + "//:hwy", + ], +) + +TEST_MAIN = select({ + "//:compiler_msvc": [], + "//conditions:default": ["@com_google_googletest//:gtest_main"], +}) + +cc_test( + name = "sort_unit_test", + size = "small", + srcs = ["sort_unit_test.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":helpers", + ":vqsort_for_test", + "//:hwy", + "//:hwy_test_util", + ] + TEST_MAIN, +) + +cc_test( + name = "sort_test", + size = "medium", + timeout = "long", + srcs = ["sort_test.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":helpers", + ":vqsort_for_test", + "//:hwy", + "//:hwy_test_util", + "//:thread_pool", + "//:topology", + ] + TEST_MAIN, +) + +cc_test( + name = "bench_sort", + size = "medium", + srcs = ["bench_sort.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + # for test_suite. 
+ tags = ["hwy_ops_test"], + deps = [ + ":helpers", + ":vqsort", + "//:hwy", + "//:hwy_test_util", + "//:nanobenchmark", + "//:thread_pool", + ] + TEST_MAIN, +) + +cc_binary( + name = "bench_parallel", + testonly = 1, + srcs = ["bench_parallel.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + deps = [ + ":helpers", + ":vqsort", + "//:hwy", + "//:hwy_test_util", + "//:nanobenchmark", + ] + TEST_MAIN, +) diff --git a/third_party/aom/third_party/highway/hwy/contrib/sort/algo-inl.h b/third_party/aom/third_party/highway/hwy/contrib/sort/algo-inl.h new file mode 100644 index 000000000000..1530c5c57d74 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/sort/algo-inl.h @@ -0,0 +1,620 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +#include +#include + +#include // std::sort +#include // std::less, std::greater +#include + +#include "third_party/highway/hwy/contrib/sort/vqsort.h" +#include "third_party/highway/hwy/highway.h" +#include "third_party/highway/hwy/print.h" + +// Third-party algorithms +#define HAVE_AVX2SORT 0 +#define HAVE_IPS4O 0 +// When enabling, consider changing max_threads (required for Table 1a) +#define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1) +#define HAVE_PDQSORT 0 +#define HAVE_SORT512 0 +#define HAVE_VXSORT 0 +#if HWY_ARCH_X86 +#define HAVE_INTEL 0 +#else +#define HAVE_INTEL 0 +#endif + +#if HAVE_PARALLEL_IPS4O +#include // NOLINT +#endif + +#if HAVE_AVX2SORT +HWY_PUSH_ATTRIBUTES("avx2,avx") +#include "avx2sort.h" //NOLINT +HWY_POP_ATTRIBUTES +#endif +#if HAVE_IPS4O || HAVE_PARALLEL_IPS4O +#include "third_party/ips4o/include/ips4o.hpp" +#include "third_party/ips4o/include/ips4o/thread_pool.hpp" +#endif +#if HAVE_PDQSORT +#include "third_party/boost/allowed/sort/sort.hpp" +#endif +#if HAVE_SORT512 +#include "sort512.h" //NOLINT +#endif + +// vxsort is difficult to compile for multiple targets because it also uses +// .cpp files, and we'd also have to #undef its include guards. Instead, compile +// only for AVX2 or AVX3 depending on this macro. 
+#define VXSORT_AVX3 1 +#if HAVE_VXSORT +// inlined from vxsort_targets_enable_avx512 (must close before end of header) +#ifdef __GNUC__ +#ifdef __clang__ +#if VXSORT_AVX3 +#pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \ + apply_to = any(function)) +#else +#pragma clang attribute push(__attribute__((target("avx2"))), \ + apply_to = any(function)) +#endif // VXSORT_AVX3 + +#else +#pragma GCC push_options +#if VXSORT_AVX3 +#pragma GCC target("avx512f,avx512dq") +#else +#pragma GCC target("avx2") +#endif // VXSORT_AVX3 +#endif +#endif + +#if VXSORT_AVX3 +#include "vxsort/machine_traits.avx512.h" +#else +#include "vxsort/machine_traits.avx2.h" +#endif // VXSORT_AVX3 +#include "vxsort/vxsort.h" +#ifdef __GNUC__ +#ifdef __clang__ +#pragma clang attribute pop +#else +#pragma GCC pop_options +#endif +#endif +#endif // HAVE_VXSORT + +namespace hwy { + +enum class Dist { kUniform8, kUniform16, kUniform32 }; + +static inline std::vector AllDist() { + // Also include lower-entropy distributions to test MaybePartitionTwoValue. + return {Dist::kUniform8, /*Dist::kUniform16,*/ Dist::kUniform32}; +} + +static inline const char* DistName(Dist dist) { + switch (dist) { + case Dist::kUniform8: + return "uniform8"; + case Dist::kUniform16: + return "uniform16"; + case Dist::kUniform32: + return "uniform32"; + } + return "unreachable"; +} + +template +class InputStats { + public: + void Notify(T value) { + min_ = HWY_MIN(min_, value); + max_ = HWY_MAX(max_, value); + // Converting to integer would truncate floats, multiplying to save digits + // risks overflow especially when casting, so instead take the sum of the + // bit representations as the checksum. + uint64_t bits = 0; + static_assert(sizeof(T) <= 8, "Expected a built-in type"); + CopyBytes(&value, &bits); // not same size + sum_ += bits; + count_ += 1; + } + + bool operator==(const InputStats& other) const { + char type_name[100]; + detail::TypeName(hwy::detail::MakeTypeInfo(), 1, type_name); + + if (count_ != other.count_) { + HWY_ABORT("Sort %s: count %d vs %d\n", type_name, + static_cast(count_), static_cast(other.count_)); + } + + if (min_ != other.min_ || max_ != other.max_) { + HWY_ABORT("Sort %s: minmax %f/%f vs %f/%f\n", type_name, + static_cast(min_), static_cast(max_), + static_cast(other.min_), + static_cast(other.max_)); + } + + // Sum helps detect duplicated/lost values + if (sum_ != other.sum_) { + HWY_ABORT("Sort %s: Sum mismatch %g %g; min %g max %g\n", type_name, + static_cast(sum_), static_cast(other.sum_), + static_cast(min_), static_cast(max_)); + } + + return true; + } + + private: + T min_ = hwy::HighestValue(); + T max_ = hwy::LowestValue(); + uint64_t sum_ = 0; + size_t count_ = 0; +}; + +enum class Algo { +#if HAVE_INTEL + kIntel, +#endif +#if HAVE_AVX2SORT + kSEA, +#endif +#if HAVE_IPS4O + kIPS4O, +#endif +#if HAVE_PARALLEL_IPS4O + kParallelIPS4O, +#endif +#if HAVE_PDQSORT + kPDQ, +#endif +#if HAVE_SORT512 + kSort512, +#endif +#if HAVE_VXSORT + kVXSort, +#endif + kStdSort, + kStdSelect, + kStdPartialSort, + kVQSort, + kVQPartialSort, + kVQSelect, + kHeapSort, + kHeapPartialSort, + kHeapSelect, +}; + +static inline bool IsVQ(Algo algo) { + switch (algo) { + case Algo::kVQSort: + case Algo::kVQPartialSort: + case Algo::kVQSelect: + return true; + default: + return false; + } +} + +static inline bool IsSelect(Algo algo) { + switch (algo) { + case Algo::kStdSelect: + case Algo::kVQSelect: + case Algo::kHeapSelect: + return true; + default: + return false; + } +} + +static inline bool IsPartialSort(Algo 
algo) { + switch (algo) { + case Algo::kStdPartialSort: + case Algo::kVQPartialSort: + case Algo::kHeapPartialSort: + return true; + default: + return false; + } +} + +static inline Algo ReferenceAlgoFor(Algo algo) { + if (IsPartialSort(algo)) return Algo::kStdPartialSort; +#if HAVE_PDQSORT + return Algo::kPDQ; +#else + return Algo::kStdSort; +#endif +} + +static inline const char* AlgoName(Algo algo) { + switch (algo) { +#if HAVE_INTEL + case Algo::kIntel: + return "intel"; +#endif +#if HAVE_AVX2SORT + case Algo::kSEA: + return "sea"; +#endif +#if HAVE_IPS4O + case Algo::kIPS4O: + return "ips4o"; +#endif +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + return "par_ips4o"; +#endif +#if HAVE_PDQSORT + case Algo::kPDQ: + return "pdq"; +#endif +#if HAVE_SORT512 + case Algo::kSort512: + return "sort512"; +#endif +#if HAVE_VXSORT + case Algo::kVXSort: + return "vxsort"; +#endif + case Algo::kStdSort: + return "std"; + case Algo::kStdPartialSort: + return "std_partial"; + case Algo::kStdSelect: + return "std_select"; + case Algo::kVQSort: + return "vq"; + case Algo::kVQPartialSort: + return "vq_partial"; + case Algo::kVQSelect: + return "vq_select"; + case Algo::kHeapSort: + return "heap"; + case Algo::kHeapPartialSort: + return "heap_partial"; + case Algo::kHeapSelect: + return "heap_select"; + } + return "unreachable"; +} + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +// Per-target +// clang-format off +#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#endif +// clang-format on + +#include "third_party/highway/hwy/aligned_allocator.h" +#include "third_party/highway/hwy/contrib/sort/traits-inl.h" +#include "third_party/highway/hwy/contrib/sort/traits128-inl.h" +#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h" // HeapSort + +HWY_BEFORE_NAMESPACE(); + +// Requires target pragma set by HWY_BEFORE_NAMESPACE +#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3 +// #include "avx512-16bit-qsort.hpp" // requires AVX512-VBMI2 +#include "avx512-32bit-qsort.hpp" +#include "avx512-64bit-qsort.hpp" +#endif + +namespace hwy { +namespace HWY_NAMESPACE { + +#if HAVE_INTEL || HAVE_VXSORT // only supports ascending order +template +using OtherOrder = detail::OrderAscending; +#else +template +using OtherOrder = detail::OrderDescending; +#endif + +class Xorshift128Plus { + static HWY_INLINE uint64_t SplitMix64(uint64_t z) { + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBull; + return z ^ (z >> 31); + } + + public: + // Generates two vectors of 64-bit seeds via SplitMix64 and stores into + // `seeds`. Generating these afresh in each ChoosePivot is too expensive. + template + static void GenerateSeeds(DU64 du64, TFromD* HWY_RESTRICT seeds) { + seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull); + for (size_t i = 1; i < 2 * Lanes(du64); ++i) { + seeds[i] = SplitMix64(seeds[i - 1]); + } + } + + // Need to pass in the state because vector cannot be class members. 
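// Each 64-bit lane advances its own xorshift128+ stream (shift constants 23, 18, 5 as below); the returned bits are the sum s0 + s1, which is the xorshift128+ output function.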
+ template + static VU64 RandomBits(VU64& state0, VU64& state1) { + VU64 s1 = state0; + VU64 s0 = state1; + const VU64 bits = Add(s1, s0); + state0 = s0; + s1 = Xor(s1, ShiftLeft<23>(s1)); + state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0)))); + return bits; + } +}; + +template +Vec RandomValues(D d, VU64& s0, VU64& s1, const VU64 mask) { + const VU64 bits = Xorshift128Plus::RandomBits(s0, s1); + return BitCast(d, And(bits, mask)); +} + +// It is important to avoid denormals, which are flushed to zero by SIMD but not +// scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD. +template +Vec RandomValues(DF df, VU64& s0, VU64& s1, const VU64 mask) { + using TF = TFromD; + const RebindToUnsigned du; + using VU = Vec; + + const VU64 bits64 = And(Xorshift128Plus::RandomBits(s0, s1), mask); + +#if HWY_TARGET == HWY_SCALAR // Cannot repartition u64 to smaller types + using TU = MakeUnsigned; + const VU bits = Set(du, static_cast(GetLane(bits64) & LimitsMax())); +#else + const VU bits = BitCast(du, bits64); +#endif + // Avoid NaN/denormal by only generating values in [1, 2), i.e. random + // mantissas with the exponent taken from the representation of 1.0. + const VU k1 = BitCast(du, Set(df, TF{1.0})); + const VU mantissa_mask = Set(du, MantissaMask()); + const VU representation = OrAnd(k1, bits, mantissa_mask); + return BitCast(df, representation); +} + +template +Vec MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) { + switch (sizeof_t) { + case 2: + return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull + : 0xFFFFFFFFFFFFFFFFull); + case 4: + return Set(du64, (dist == Dist::kUniform8) ? 0x000000FF000000FFull + : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull + : 0xFFFFFFFFFFFFFFFFull); + case 8: + return Set(du64, (dist == Dist::kUniform8) ? 0x00000000000000FFull + : (dist == Dist::kUniform16) ? 0x000000000000FFFFull + : 0x00000000FFFFFFFFull); + default: + HWY_ABORT("Logic error"); + return Zero(du64); + } +} + +template +InputStats GenerateInput(const Dist dist, T* v, size_t num_lanes) { + SortTag du64; + using VU64 = Vec; + const size_t N64 = Lanes(du64); + auto seeds = hwy::AllocateAligned(2 * N64); + Xorshift128Plus::GenerateSeeds(du64, seeds.get()); + VU64 s0 = Load(du64, seeds.get()); + VU64 s1 = Load(du64, seeds.get() + N64); + +#if HWY_TARGET == HWY_SCALAR + const Sisd d; +#else + const Repartition d; +#endif + using V = Vec; + const size_t N = Lanes(d); + const VU64 mask = MaskForDist(du64, dist, sizeof(T)); + auto buf = hwy::AllocateAligned(N); + + size_t i = 0; + for (; i + N <= num_lanes; i += N) { + const V values = RandomValues(d, s0, s1, mask); + StoreU(values, d, v + i); + } + if (i < num_lanes) { + const V values = RandomValues(d, s0, s1, mask); + StoreU(values, d, buf.get()); + CopyBytes(buf.get(), v + i, (num_lanes - i) * sizeof(T)); + } + + InputStats input_stats; + for (size_t i = 0; i < num_lanes; ++i) { + input_stats.Notify(v[i]); + } + return input_stats; +} + +struct SharedState { +#if HAVE_PARALLEL_IPS4O + const unsigned max_threads = hwy::LimitsMax(); // 16 for Table 1a + ips4o::StdThreadPool pool{static_cast( + HWY_MIN(max_threads, std::thread::hardware_concurrency() / 2))}; +#endif +}; + +// Adapters from Run's num_keys to vqsort-inl.h num_lanes. 
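// st.LanesPerKey() is 1 for built-in key types and 2 for 128-bit keys, so the adapters below scale key counts to lane counts before calling HeapSort / HeapPartialSort / HeapSelect, which operate on lanes.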
+template +void CallHeapSort(KeyType* keys, const size_t num_keys, Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + return detail::HeapSort(st, reinterpret_cast(keys), + num_keys * st.LanesPerKey()); +} +template +void CallHeapPartialSort(KeyType* keys, const size_t num_keys, + const size_t k_keys, Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + detail::HeapPartialSort(st, reinterpret_cast(keys), + num_keys * st.LanesPerKey(), + k_keys * st.LanesPerKey()); +} +template +void CallHeapSelect(KeyType* keys, const size_t num_keys, const size_t k_keys, + Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + detail::HeapSelect(st, reinterpret_cast(keys), + num_keys * st.LanesPerKey(), k_keys * st.LanesPerKey()); +} + +template +void Run(Algo algo, KeyType* inout, size_t num_keys, SharedState& shared, + size_t /*thread*/, size_t k_keys, Order) { + const std::less less; + const std::greater greater; + + constexpr bool kAscending = Order::IsAscending(); + +#if !HAVE_PARALLEL_IPS4O + (void)shared; +#endif + + switch (algo) { +#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3 + case Algo::kIntel: + return avx512_qsort(inout, static_cast(num_keys)); +#endif + +#if HAVE_AVX2SORT + case Algo::kSEA: + return avx2::quicksort(inout, static_cast(num_keys)); +#endif + +#if HAVE_IPS4O + case Algo::kIPS4O: + if (kAscending) { + return ips4o::sort(inout, inout + num_keys, less); + } else { + return ips4o::sort(inout, inout + num_keys, greater); + } +#endif + +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + if (kAscending) { + return ips4o::parallel::sort(inout, inout + num_keys, less, + shared.pool); + } else { + return ips4o::parallel::sort(inout, inout + num_keys, greater, + shared.pool); + } +#endif + +#if HAVE_SORT512 + case Algo::kSort512: + HWY_ABORT("not supported"); + // return Sort512::Sort(inout, num_keys); +#endif + +#if HAVE_PDQSORT + case Algo::kPDQ: + if (kAscending) { + return boost::sort::pdqsort_branchless(inout, inout + num_keys, less); + } else { + return boost::sort::pdqsort_branchless(inout, inout + num_keys, + greater); + } +#endif + +#if HAVE_VXSORT + case Algo::kVXSort: { +#if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \ + (!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2) + HWY_WARN("Do not call for target %s\n", hwy::TargetName(HWY_TARGET)); + return; +#else +#if VXSORT_AVX3 + vxsort::vxsort vx; +#else + vxsort::vxsort vx; +#endif + if (kAscending) { + return vx.sort(inout, inout + num_keys - 1); + } else { + HWY_WARN("Skipping VX - does not support descending order\n"); + return; + } +#endif // enabled for this target + } +#endif // HAVE_VXSORT + + case Algo::kStdSort: + if (kAscending) { + return std::sort(inout, inout + num_keys, less); + } else { + return std::sort(inout, inout + num_keys, greater); + } + case Algo::kStdPartialSort: + if (kAscending) { + return std::partial_sort(inout, inout + k_keys, inout + num_keys, less); + } else { + return std::partial_sort(inout, inout + k_keys, inout + num_keys, + greater); + } + case Algo::kStdSelect: + if (kAscending) { + return std::nth_element(inout, inout + k_keys, inout + num_keys, less); + } else { + return std::nth_element(inout, inout + k_keys, inout + num_keys, + greater); + } + + case Algo::kVQSort: + return VQSort(inout, num_keys, Order()); + case Algo::kVQPartialSort: + return VQPartialSort(inout, num_keys, k_keys, Order()); + case Algo::kVQSelect: + return VQSelect(inout, num_keys, k_keys, Order()); + + case 
Algo::kHeapSort: + return CallHeapSort(inout, num_keys, Order()); + case Algo::kHeapPartialSort: + return CallHeapPartialSort(inout, num_keys, k_keys, Order()); + case Algo::kHeapSelect: + return CallHeapSelect(inout, num_keys, k_keys, Order()); + + default: + HWY_ABORT("Not implemented"); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE diff --git a/third_party/aom/third_party/highway/hwy/contrib/sort/order.h b/third_party/aom/third_party/highway/hwy/contrib/sort/order.h new file mode 100644 index 000000000000..5aa60b5e77f2 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/sort/order.h @@ -0,0 +1,34 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tag arguments that determine the sort order. Used by both vqsort.h and the +// VQSortStatic in vqsort-inl.h. Moved to a separate header so that the latter +// can be used without pulling in the dllimport statements in vqsort.h. + +#ifndef HIGHWAY_HWY_CONTRIB_SORT_ORDER_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_ORDER_H_ + +namespace hwy { + +struct SortAscending { + static constexpr bool IsAscending() { return true; } +}; +struct SortDescending { + static constexpr bool IsAscending() { return false; } +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_ORDER_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/sort/result-inl.h b/third_party/aom/third_party/highway/hwy/contrib/sort/result-inl.h new file mode 100644 index 000000000000..5f6d2ca8ecb8 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/sort/result-inl.h @@ -0,0 +1,291 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
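// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the upstream sources: direct use of the
// public entry points that the Run() dispatcher above forwards to for
// Algo::kVQSort / kVQPartialSort / kVQSelect, with the SortAscending tag from
// order.h. The function name is hypothetical.
#include <stddef.h>
#include <stdint.h>

#include "third_party/highway/hwy/aligned_allocator.h"
#include "third_party/highway/hwy/contrib/sort/vqsort.h"

void ExampleVQSortU64(size_t n) {
  auto keys = hwy::AllocateAligned<uint64_t>(n);  // aligned storage, or nullptr
  if (!keys) return;
  for (size_t i = 0; i < n; ++i) keys[i] = static_cast<uint64_t>(n - i);
  hwy::VQSort(keys.get(), n, hwy::SortAscending());  // full in-place sort
  // VQPartialSort sorts only the first k keys; VQSelect is the nth_element
  // analogue (see the kStdPartialSort / kStdSelect cases above):
  // hwy::VQPartialSort(keys.get(), n, /*k=*/n / 2, hwy::SortAscending());
  // hwy::VQSelect(keys.get(), n, /*k=*/n / 2, hwy::SortAscending());
}
// ---------------------------------------------------------------------------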
+ +#include "third_party/highway/hwy/contrib/sort/algo-inl.h" + +// Normal include guard for non-SIMD parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ + +#include +#include +#include + +#include // std::sort +#include +#include + +#include "third_party/highway/hwy/aligned_allocator.h" +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/contrib/sort/order.h" +#include "third_party/highway/hwy/per_target.h" // DispatchedTarget +#include "third_party/highway/hwy/targets.h" // TargetName + +namespace hwy { + +// Returns trimmed mean (we don't want to run an out-of-L3-cache sort often +// enough for the mode to be reliable). +static inline double SummarizeMeasurements(std::vector& seconds) { + std::sort(seconds.begin(), seconds.end()); + double sum = 0; + int count = 0; + const size_t num = seconds.size(); + for (size_t i = num / 4; i < num / 2; ++i) { + sum += seconds[i]; + count += 1; + } + return sum / count; +} + +struct SortResult { + SortResult() {} + SortResult(const Algo algo, Dist dist, size_t num_keys, size_t num_threads, + double sec, size_t sizeof_key, const char* key_name) + : target(DispatchedTarget()), + algo(algo), + dist(dist), + num_keys(num_keys), + num_threads(num_threads), + sec(sec), + sizeof_key(sizeof_key), + key_name(key_name) {} + + void Print() const { + const double bytes = static_cast(num_keys) * + static_cast(num_threads) * + static_cast(sizeof_key); + printf("%10s: %12s: %7s: %9s: %05g %4.0f MB/s (%2zu threads)\n", + hwy::TargetName(target), AlgoName(algo), key_name.c_str(), + DistName(dist), static_cast(num_keys), bytes * 1E-6 / sec, + num_threads); + } + + int64_t target; + Algo algo; + Dist dist; + size_t num_keys = 0; + size_t num_threads = 0; + double sec = 0.0; + size_t sizeof_key = 0; + std::string key_name; +}; + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Copies the input, and compares results to that of a reference algorithm. +template +class ReferenceSortVerifier { + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + using Order = typename Traits::Order; + static constexpr bool kAscending = Order::IsAscending(); + static constexpr size_t kLPK = Traits().LanesPerKey(); + + public: + ReferenceSortVerifier(const LaneType* in_lanes, size_t num_lanes) { + num_lanes_ = num_lanes; + num_keys_ = num_lanes / kLPK; + in_lanes_ = hwy::AllocateAligned(num_lanes); + HWY_ASSERT(in_lanes_); + CopyBytes(in_lanes, in_lanes_.get(), num_lanes * sizeof(LaneType)); + } + + // For full sorts, k_keys == num_keys. + void operator()(Algo algo, const LaneType* out_lanes, size_t k_keys) { + SharedState shared; + const Traits st; + const CappedTag d; + + HWY_ASSERT(hwy::IsAligned(in_lanes_.get(), sizeof(KeyType))); + KeyType* in_keys = HWY_RCAST_ALIGNED(KeyType*, in_lanes_.get()); + + char caption[10]; + const char* algo_type = IsPartialSort(algo) ? "PartialSort" : "Sort"; + + HWY_ASSERT(k_keys <= num_keys_); + Run(ReferenceAlgoFor(algo), in_keys, num_keys_, shared, /*thread=*/0, + k_keys, Order()); + + if (IsSelect(algo)) { + // Print lanes centered around k_keys. 
+ if (VQSORT_PRINT >= 3) { + const size_t begin_lane = k_keys < 3 ? 0 : (k_keys - 3) * kLPK; + const size_t end_lane = HWY_MIN(num_lanes_, (k_keys + 3) * kLPK); + fprintf(stderr, "\nExpected:\n"); + for (size_t i = begin_lane; i < end_lane; i += kLPK) { + snprintf(caption, sizeof(caption), "%4zu ", i / kLPK); + Print(d, caption, st.SetKey(d, &in_lanes_[i])); + } + fprintf(stderr, "\n\nActual:\n"); + for (size_t i = begin_lane; i < end_lane; i += kLPK) { + snprintf(caption, sizeof(caption), "%4zu ", i / kLPK); + Print(d, caption, st.SetKey(d, &out_lanes[i])); + } + fprintf(stderr, "\n\n"); + } + + // At k_keys: should be equivalent, i.e. neither a < b nor b < a. + // SortOrderVerifier will also check the ordering of the rest of the keys. + const size_t k = k_keys * kLPK; + if (st.Compare1(&in_lanes_[k], &out_lanes[k]) || + st.Compare1(&out_lanes[k], &in_lanes_[k])) { + Print(d, "Expected", st.SetKey(d, &in_lanes_[k])); + Print(d, " Actual", st.SetKey(d, &out_lanes[k])); + HWY_ABORT("Select %s asc=%d: mismatch at k_keys=%zu, num_keys=%zu\n", + st.KeyString(), kAscending, k_keys, num_keys_); + } + } else { + if (VQSORT_PRINT >= 3) { + const size_t lanes_to_print = HWY_MIN(40, k_keys * kLPK); + fprintf(stderr, "\nExpected:\n"); + for (size_t i = 0; i < lanes_to_print; i += kLPK) { + snprintf(caption, sizeof(caption), "%4zu ", i / kLPK); + Print(d, caption, st.SetKey(d, &in_lanes_[i])); + } + fprintf(stderr, "\n\nActual:\n"); + for (size_t i = 0; i < lanes_to_print; i += kLPK) { + snprintf(caption, sizeof(caption), "%4zu ", i / kLPK); + Print(d, caption, st.SetKey(d, &out_lanes[i])); + } + fprintf(stderr, "\n\n"); + } + + // Full or partial sort: all elements up to k_keys are equivalent to the + // reference sort. SortOrderVerifier also checks the output's ordering. + for (size_t i = 0; i < k_keys * kLPK; i += kLPK) { + // All up to k_keys should be equivalent, i.e. neither a < b nor b < a. + if (st.Compare1(&in_lanes_[i], &out_lanes[i]) || + st.Compare1(&out_lanes[i], &in_lanes_[i])) { + Print(d, "Expected", st.SetKey(d, &in_lanes_[i])); + Print(d, " Actual", st.SetKey(d, &out_lanes[i])); + HWY_ABORT("%s %s asc=%d: mismatch at %zu, k_keys=%zu, num_keys=%zu\n", + algo_type, st.KeyString(), kAscending, i / kLPK, k_keys, + num_keys_); + } + } + } + } + + private: + hwy::AlignedFreeUniquePtr in_lanes_; + size_t num_lanes_; + size_t num_keys_; +}; + +// Faster than ReferenceSortVerifier, for use in bench_sort. Only verifies +// order, without running a slow reference sorter. This means it can't verify +// Select places the correct key at `k_keys`, nor that input and output keys are +// the same. +template +class SortOrderVerifier { + using LaneType = typename Traits::LaneType; + using Order = typename Traits::Order; + static constexpr bool kAscending = Order::IsAscending(); + static constexpr size_t kLPK = Traits().LanesPerKey(); + + public: + void operator()(Algo algo, const InputStats& input_stats, + const LaneType* output, size_t num_keys, size_t k_keys) { + if (IsSelect(algo)) { + CheckSelectOrder(input_stats, output, num_keys, k_keys); + } else { + CheckSortedOrder(algo, input_stats, output, num_keys, k_keys); + } + } + + private: + // For full or partial sorts: ensures keys are in sorted order. 
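// "In order" means no later key compares before an earlier one among the first k_keys (equal keys are accepted); the InputStats comparison below additionally confirms that no keys were lost or duplicated (count, min/max and bit-pattern checksum must all match).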
+ void CheckSortedOrder(const Algo algo, + const InputStats& input_stats, + const LaneType* output, const size_t num_keys, + const size_t k_keys) { + const Traits st; + const CappedTag d; + const size_t num_lanes = num_keys * kLPK; + const size_t k = k_keys * kLPK; + const char* algo_type = IsPartialSort(algo) ? "PartialSort" : "Sort"; + + InputStats output_stats; + // Even for partial sorts, loop over all keys to verify none disappeared. + for (size_t i = 0; i < num_lanes - kLPK; i += kLPK) { + output_stats.Notify(output[i]); + if (kLPK == 2) output_stats.Notify(output[i + 1]); + + // Only check the first k_keys (== num_keys for a full sort). + // Reverse order instead of checking !Compare1 so we accept equal keys. + if (i < k - kLPK && st.Compare1(output + i + kLPK, output + i)) { + Print(d, " cur", st.SetKey(d, &output[i])); + Print(d, "next", st.SetKey(d, &output[i + kLPK])); + HWY_ABORT( + "%s %s asc=%d: wrong order at %zu, k_keys=%zu, num_keys=%zu\n", + algo_type, st.KeyString(), kAscending, i / kLPK, k_keys, num_keys); + } + } + output_stats.Notify(output[num_lanes - kLPK]); + if (kLPK == 2) output_stats.Notify(output[num_lanes - kLPK + 1]); + + HWY_ASSERT(input_stats == output_stats); + } + + // Ensures keys below index k_keys are less, and all above are greater. + void CheckSelectOrder(const InputStats& input_stats, + const LaneType* output, const size_t num_keys, + const size_t k_keys) { + const Traits st; + const CappedTag d; + const size_t num_lanes = num_keys * kLPK; + const size_t k = k_keys * kLPK; + + InputStats output_stats; + for (size_t i = 0; i < num_lanes - kLPK; i += kLPK) { + output_stats.Notify(output[i]); + if (kLPK == 2) output_stats.Notify(output[i + 1]); + // Reverse order instead of checking !Compare1 so we accept equal keys. + if (i < k ? st.Compare1(output + k, output + i) + : st.Compare1(output + i, output + k)) { + Print(d, "cur", st.SetKey(d, &output[i])); + Print(d, "kth", st.SetKey(d, &output[k])); + HWY_ABORT( + "Select %s asc=%d: wrong order at %zu, k_keys=%zu, num_keys=%zu\n", + st.KeyString(), kAscending, i / kLPK, k_keys, num_keys); + } + } + output_stats.Notify(output[num_lanes - kLPK]); + if (kLPK == 2) output_stats.Notify(output[num_lanes - kLPK + 1]); + + HWY_ASSERT(input_stats == output_stats); + } +}; + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE diff --git a/third_party/aom/third_party/highway/hwy/contrib/sort/shared-inl.h b/third_party/aom/third_party/highway/hwy/contrib/sort/shared-inl.h new file mode 100644 index 000000000000..e534d3ba9253 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/sort/shared-inl.h @@ -0,0 +1,181 @@ +// Copyright 2021 Google LLC +// Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
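// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the upstream sources: the trimmed mean
// computed by SummarizeMeasurements in result-inl.h above. Only the second
// quartile of the sorted timings is averaged, so one cold-cache outlier or one
// implausibly fast run cannot skew the summary. The logic is re-implemented
// locally so the sketch stands alone.
#include <cstddef>
#include <cstdio>

#include <algorithm>
#include <vector>

int main() {
  std::vector<double> seconds{0.90, 1.10, 5.00, 1.00, 0.95, 30.0, 1.05, 1.20};
  std::sort(seconds.begin(), seconds.end());
  // Sorted: 0.90 0.95 1.00 1.05 1.10 1.20 5.00 30.0; only indices [2, 4) are
  // averaged: (1.00 + 1.05) / 2 = 1.025, unaffected by the 5.00 and 30.0 runs.
  double sum = 0.0;
  int count = 0;
  const size_t num = seconds.size();
  for (size_t i = num / 4; i < num / 2; ++i) {
    sum += seconds[i];
    count += 1;
  }
  std::printf("trimmed mean = %.3f s\n", sum / count);
  return 0;
}
// ---------------------------------------------------------------------------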
+ +// Definitions shared between vqsort-inl and sorting_networks-inl. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ + +#include "third_party/highway/hwy/base.h" + +namespace hwy { + +// Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028 +static HWY_INLINE uint64_t RandomBits(uint64_t* HWY_RESTRICT state) { + const uint64_t a = state[0]; + const uint64_t b = state[1]; + const uint64_t w = state[2] + 1; + const uint64_t next = a ^ w; + state[0] = (b + (b << 3)) ^ (b >> 11); + const uint64_t rot = (b << 24) | (b >> 40); + state[1] = rot + next; + state[2] = w; + return next; +} + +// Internal constants - these are to avoid magic numbers/literals and cannot be +// changed without also changing the associated code. +struct SortConstants { + // SortingNetwork reshapes its input into a matrix. This is the maximum number + // of *lanes* per vector. Must be at least 8 because SortSamples assumes the + // sorting network can handle 128 bytes with 8 rows, so 16 bytes per vector, + // which means 8 lanes for 16-bit types. +#if HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD + static constexpr size_t kMaxCols = 8; // avoid build timeout/stack overflow +#else + static constexpr size_t kMaxCols = 16; // enough for u32 in 512-bit vector +#endif + + // 16 rows is a compromise between using the 32 AVX-512/SVE/RVV registers, + // fitting within 16 AVX2 registers with only a few spills, keeping BaseCase + // code size reasonable, and minimizing the extra logN factor for larger + // networks (for which only loose upper bounds on size are known). + static constexpr size_t kMaxRows = 16; + + // Template argument ensures there is no actual division instruction. + template + static constexpr HWY_INLINE size_t BaseCaseNumLanes(size_t N) { + // We use 8, 8x2, 8x4, and 16x{4..} networks, in units of keys. For N/kLPK + // < 4, we cannot use the 16-row networks. + return (((N / kLPK) >= 4) ? kMaxRows : 8) * HWY_MIN(N, kMaxCols); + } + + // Unrolling is important (pipelining and amortizing branch mispredictions); + // 2x is sufficient to reach full memory bandwidth on SKX in Partition, but + // somewhat slower for sorting than 4x. + // + // To change, must also update left + 3 * N etc. in the loop. + static constexpr size_t kPartitionUnroll = 4; + + // Chunk := group of keys loaded for sampling a pivot. Matches the typical + // cache line size of 64 bytes to get maximum benefit per L2 miss. Sort() + // ensures vectors are no larger than that, so this can be independent of the + // vector size and thus constexpr. + static constexpr HWY_INLINE size_t LanesPerChunk(size_t sizeof_t) { + return 64 / sizeof_t; + } + + template + static constexpr HWY_INLINE size_t SampleLanes() { + return 2 * LanesPerChunk(sizeof(T)); // Stored samples + } + + static constexpr HWY_INLINE size_t PartitionBufNum(size_t N) { + // The main loop reads kPartitionUnroll vectors, and first loads from + // both left and right beforehand, so it requires 2 * kPartitionUnroll + // vectors. To handle amounts between that and BaseCaseNumLanes(), we + // partition up 3 * kPartitionUnroll + 1 vectors into a two-part buffer. + return 2 * (3 * kPartitionUnroll + 1) * N; + } + + // Max across the three buffer usages. + template + static constexpr HWY_INLINE size_t BufNum(size_t N) { + // BaseCase may write one padding vector, and SortSamples uses the space + // after samples as the buffer. 
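// The first operand therefore includes an extra N lanes of padding. With the default kPartitionUnroll of 4, PartitionBufNum's 2 * 13 * N = 26 * N lanes is usually the larger of the two.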
+ return HWY_MAX(SampleLanes() + BaseCaseNumLanes(N) + N, + PartitionBufNum(N)); + } + + // Translates vector_size to lanes and returns size in bytes. + template + static constexpr HWY_INLINE size_t BufBytes(size_t vector_size) { + return BufNum(vector_size / sizeof(T)) * sizeof(T); + } + + // Returns max for any type. + template + static constexpr HWY_INLINE size_t MaxBufBytes(size_t vector_size) { + // If 2 lanes per key, it's a 128-bit key with u64 lanes. + return kLPK == 2 ? BufBytes(vector_size) + : HWY_MAX((BufBytes(vector_size)), + HWY_MAX((BufBytes(vector_size)), + (BufBytes(vector_size)))); + } +}; + +static_assert(SortConstants::MaxBufBytes<1>(64) <= 1664, "Unexpectedly high"); +static_assert(SortConstants::MaxBufBytes<2>(64) <= 1664, "Unexpectedly high"); + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ + +// Per-target +// clang-format off +#if defined(HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE) == defined(HWY_TARGET_TOGGLE) // NOLINT +// clang-format on +#ifdef HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#endif + +#include "third_party/highway/hwy/highway.h" + +// vqsort isn't available on HWY_SCALAR, and builds time out on MSVC opt and +// Armv7 debug, and Armv8 GCC 11 asan hits an internal compiler error likely +// due to https://gcc.gnu.org/bugzilla/show_bug.cgi?id=97696. Armv8 Clang +// hwasan/msan/tsan/asan also fail to build SVE (b/335157772). RVV currently +// has a compiler issue. +#undef VQSORT_ENABLED +#undef VQSORT_COMPILER_COMPATIBLE + +#if (HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD) || \ + (HWY_ARCH_ARM_V7 && HWY_IS_DEBUG_BUILD) || \ + (HWY_ARCH_ARM_A64 && HWY_COMPILER_GCC_ACTUAL && HWY_IS_ASAN) || \ + (HWY_ARCH_RISCV) +#define VQSORT_COMPILER_COMPATIBLE 0 +#else +#define VQSORT_COMPILER_COMPATIBLE 1 +#endif + +#if (HWY_TARGET == HWY_SCALAR) || !VQSORT_COMPILER_COMPATIBLE +#define VQSORT_ENABLED 0 +#else +#define VQSORT_ENABLED 1 +#endif + +namespace hwy { +namespace HWY_NAMESPACE { + +// Default tag / vector width selector. +#if HWY_TARGET == HWY_RVV +// Use LMUL = 1/2; for SEW=64 this ends up emulated via VSETVLI. +template +using SortTag = ScalableTag; +#else +template +using SortTag = ScalableTag; +#endif + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE diff --git a/third_party/aom/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h b/third_party/aom/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h new file mode 100644 index 000000000000..2158e7ea993c --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h @@ -0,0 +1,902 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
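// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the upstream sources: the buffer sizing
// from SortConstants in shared-inl.h above, worked through for 64-byte
// (512-bit) vectors. PartitionBufNum dominates for every lane type, and
// 2 * (3 * kPartitionUnroll + 1) * N = 26 * N lanes is 26 * 64 bytes
// regardless of sizeof(T) - exactly the 1664-byte bound asserted above.
#include <cstddef>
#include <cstdint>

constexpr std::size_t kUnroll = 4;  // mirrors SortConstants::kPartitionUnroll
constexpr std::size_t PartitionLanes(std::size_t N) {
  return 2 * (3 * kUnroll + 1) * N;
}
// u32 keys, assuming kMaxCols == 16: N = 64 / 4 = 16 lanes; samples (32) +
// base case (256) + padding (16) = 304 lanes, but partitioning needs
// 26 * 16 = 416 lanes = 1664 bytes.
static_assert(PartitionLanes(16) * sizeof(uint32_t) == 1664, "u32 buffer");
// u16: 26 * 32 lanes * 2 bytes; u64: 26 * 8 lanes * 8 bytes - both 1664 bytes.
static_assert(PartitionLanes(32) * 2 == 1664 && PartitionLanes(8) * 8 == 1664,
              "u16/u64 buffers");
// ---------------------------------------------------------------------------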
+ +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#endif + +#include "third_party/highway/hwy/contrib/sort/shared-inl.h" // SortConstants +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +#if VQSORT_ENABLED + +using Constants = hwy::SortConstants; + +// ------------------------------ SharedTraits + +// Code shared between all traits. It's unclear whether these can profitably be +// specialized for Lane vs Block, or optimized like SortPairsDistance1 using +// Compare/DupOdd. +template +struct SharedTraits : public Base { + using SharedTraitsForSortingNetwork = + SharedTraits; + + // Conditionally swaps lane 0 with 2, 1 with 3 etc. + template + HWY_INLINE Vec SortPairsDistance2(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->SwapAdjacentPairs(d, v); + base->Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 8 keys. + template + HWY_INLINE Vec SortPairsReverse8(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys8(d, v); + base->Sort2(d, v, swapped); + return base->OddEvenQuads(d, swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 8 keys. + template + HWY_INLINE Vec SortPairsReverse16(D d, Vec v) const { + const Base* base = static_cast(this); + static_assert(Constants::kMaxCols <= 16, "Need actual Reverse16"); + Vec swapped = base->ReverseKeys(d, v); + base->Sort2(d, v, swapped); + return ConcatUpperLower(d, swapped, v); // 8 = half of the vector + } +}; + +// ------------------------------ Sorting network + +// Sorting networks for independent columns in 2, 4 and 8 vectors from +// https://bertdobbelaere.github.io/sorting_networks.html. + +template > +HWY_INLINE void Sort2(D d, Traits st, V& v0, V& v1) { + st.Sort2(d, v0, v1); +} + +template > +HWY_INLINE void Sort4(D d, Traits st, V& v0, V& v1, V& v2, V& v3) { + st.Sort2(d, v0, v2); + st.Sort2(d, v1, v3); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v1, v2); +} + +template > +HWY_INLINE void Sort8(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7) { + st.Sort2(d, v0, v2); + st.Sort2(d, v1, v3); + st.Sort2(d, v4, v6); + st.Sort2(d, v5, v7); + + st.Sort2(d, v0, v4); + st.Sort2(d, v1, v5); + st.Sort2(d, v2, v6); + st.Sort2(d, v3, v7); + + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + + st.Sort2(d, v2, v4); + st.Sort2(d, v3, v5); + + st.Sort2(d, v1, v4); + st.Sort2(d, v3, v6); + + st.Sort2(d, v1, v2); + st.Sort2(d, v3, v4); + st.Sort2(d, v5, v6); +} + +// (Green's irregular) sorting network for independent columns in 16 vectors. 
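// This is Green's 60-comparator construction, the best known size for 16 inputs; each st.Sort2 below is a vectorized compare-exchange, so all Lanes(d) columns are sorted simultaneously.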
+template > +HWY_INLINE void Sort16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + st.Sort2(d, v0, v2); + st.Sort2(d, v1, v3); + st.Sort2(d, v4, v6); + st.Sort2(d, v5, v7); + st.Sort2(d, v8, va); + st.Sort2(d, v9, vb); + st.Sort2(d, vc, ve); + st.Sort2(d, vd, vf); + st.Sort2(d, v0, v4); + st.Sort2(d, v1, v5); + st.Sort2(d, v2, v6); + st.Sort2(d, v3, v7); + st.Sort2(d, v8, vc); + st.Sort2(d, v9, vd); + st.Sort2(d, va, ve); + st.Sort2(d, vb, vf); + st.Sort2(d, v0, v8); + st.Sort2(d, v1, v9); + st.Sort2(d, v2, va); + st.Sort2(d, v3, vb); + st.Sort2(d, v4, vc); + st.Sort2(d, v5, vd); + st.Sort2(d, v6, ve); + st.Sort2(d, v7, vf); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v3, vc); + st.Sort2(d, v7, vb); + st.Sort2(d, vd, ve); + st.Sort2(d, v4, v8); + st.Sort2(d, v1, v2); + st.Sort2(d, v1, v4); + st.Sort2(d, v7, vd); + st.Sort2(d, v2, v8); + st.Sort2(d, vb, ve); + st.Sort2(d, v2, v4); + st.Sort2(d, v5, v6); + st.Sort2(d, v9, va); + st.Sort2(d, vb, vd); + st.Sort2(d, v3, v8); + st.Sort2(d, v7, vc); + st.Sort2(d, v3, v5); + st.Sort2(d, v6, v8); + st.Sort2(d, v7, v9); + st.Sort2(d, va, vc); + st.Sort2(d, v3, v4); + st.Sort2(d, v5, v6); + st.Sort2(d, v7, v8); + st.Sort2(d, v9, va); + st.Sort2(d, vb, vc); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); +} + +// ------------------------------ Merging networks + +// Blacher's hybrid bitonic/odd-even networks, generated by print_network.cc. +// For acceptable performance, these must be inlined, otherwise vectors are +// loaded from the stack. The kKeysPerVector allows calling from generic code +// but skipping the functions when vectors have too few lanes for +// st.SortPairsDistance1 to compile. `if constexpr` in the caller would also +// work, but is not available in C++11. We write out the (unused) argument types +// rather than `...` because GCC 9 (but not 10) fails to compile with `...`. 
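// The empty overloads below are selected when kKeysPerVector is too small for the merge step; the HWY_IF_LANES_GT overloads further down provide the actual implementations.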
+ +template +HWY_INLINE void Merge8x2(D, Traits, V, V, V, V, V, V, V, V) {} +template +HWY_INLINE void Merge8x4(D, Traits, V, V, V, V, V, V, V, V) {} + +template +HWY_INLINE void Merge16x2(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V, + V, V) {} +template +HWY_INLINE void Merge16x4(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V, + V, V) {} +template +HWY_INLINE void Merge16x8(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V, + V, V) {} +template +HWY_INLINE void Merge16x16(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V, + V, V) {} + +template , + HWY_IF_LANES_GT(kKeysPerVector, 1)> +HWY_INLINE void Merge8x2(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7) { + v7 = st.ReverseKeys2(d, v7); + v6 = st.ReverseKeys2(d, v6); + v5 = st.ReverseKeys2(d, v5); + v4 = st.ReverseKeys2(d, v4); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + + v3 = st.ReverseKeys2(d, v3); + v2 = st.ReverseKeys2(d, v2); + v7 = st.ReverseKeys2(d, v7); + v6 = st.ReverseKeys2(d, v6); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + + v1 = st.ReverseKeys2(d, v1); + v3 = st.ReverseKeys2(d, v3); + v5 = st.ReverseKeys2(d, v5); + v7 = st.ReverseKeys2(d, v7); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); +} + +template , + HWY_IF_LANES_GT(kKeysPerVector, 2)> +HWY_INLINE void Merge8x4(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7) { + v7 = st.ReverseKeys4(d, v7); + v6 = st.ReverseKeys4(d, v6); + v5 = st.ReverseKeys4(d, v5); + v4 = st.ReverseKeys4(d, v4); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + + v3 = st.ReverseKeys4(d, v3); + v2 = st.ReverseKeys4(d, v2); + v7 = st.ReverseKeys4(d, v7); + v6 = st.ReverseKeys4(d, v6); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + + v1 = st.ReverseKeys4(d, v1); + v3 = st.ReverseKeys4(d, v3); + v5 = st.ReverseKeys4(d, v5); + v7 = st.ReverseKeys4(d, v7); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + + v0 = st.SortPairsReverse4(d, v0); + v1 = st.SortPairsReverse4(d, v1); + v2 = st.SortPairsReverse4(d, v2); + v3 = st.SortPairsReverse4(d, v3); + v4 = st.SortPairsReverse4(d, v4); + v5 = st.SortPairsReverse4(d, v5); + v6 = st.SortPairsReverse4(d, v6); + v7 = st.SortPairsReverse4(d, v7); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); +} + +// Only used by the now-deprecated SortingNetwork(). 
+template , + HWY_IF_LANES_GT(kKeysPerVector, 1)> +HWY_INLINE void Merge16x2(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, + V& vc, V& vd, V& ve, V& vf) { + vf = st.ReverseKeys2(d, vf); + ve = st.ReverseKeys2(d, ve); + vd = st.ReverseKeys2(d, vd); + vc = st.ReverseKeys2(d, vc); + vb = st.ReverseKeys2(d, vb); + va = st.ReverseKeys2(d, va); + v9 = st.ReverseKeys2(d, v9); + v8 = st.ReverseKeys2(d, v8); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + + v7 = st.ReverseKeys2(d, v7); + v6 = st.ReverseKeys2(d, v6); + v5 = st.ReverseKeys2(d, v5); + v4 = st.ReverseKeys2(d, v4); + vf = st.ReverseKeys2(d, vf); + ve = st.ReverseKeys2(d, ve); + vd = st.ReverseKeys2(d, vd); + vc = st.ReverseKeys2(d, vc); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + st.Sort2(d, v8, vf); + st.Sort2(d, v9, ve); + st.Sort2(d, va, vd); + st.Sort2(d, vb, vc); + + v3 = st.ReverseKeys2(d, v3); + v2 = st.ReverseKeys2(d, v2); + v7 = st.ReverseKeys2(d, v7); + v6 = st.ReverseKeys2(d, v6); + vb = st.ReverseKeys2(d, vb); + va = st.ReverseKeys2(d, va); + vf = st.ReverseKeys2(d, vf); + ve = st.ReverseKeys2(d, ve); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + + v1 = st.ReverseKeys2(d, v1); + v3 = st.ReverseKeys2(d, v3); + v5 = st.ReverseKeys2(d, v5); + v7 = st.ReverseKeys2(d, v7); + v9 = st.ReverseKeys2(d, v9); + vb = st.ReverseKeys2(d, vb); + vd = st.ReverseKeys2(d, vd); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +template , + HWY_IF_LANES_GT(kKeysPerVector, 2)> +HWY_INLINE void Merge16x4(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, + V& vc, V& vd, V& ve, V& vf) { + vf = st.ReverseKeys4(d, vf); + ve = st.ReverseKeys4(d, ve); + vd = st.ReverseKeys4(d, vd); + vc = st.ReverseKeys4(d, vc); + vb = st.ReverseKeys4(d, vb); + va = st.ReverseKeys4(d, va); + v9 = st.ReverseKeys4(d, v9); + v8 = st.ReverseKeys4(d, v8); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + + v7 = st.ReverseKeys4(d, v7); + v6 = st.ReverseKeys4(d, v6); + v5 = st.ReverseKeys4(d, v5); + v4 = st.ReverseKeys4(d, v4); + vf = st.ReverseKeys4(d, vf); + ve = st.ReverseKeys4(d, ve); + vd = st.ReverseKeys4(d, vd); + vc = st.ReverseKeys4(d, vc); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + st.Sort2(d, v8, vf); + 
st.Sort2(d, v9, ve); + st.Sort2(d, va, vd); + st.Sort2(d, vb, vc); + + v3 = st.ReverseKeys4(d, v3); + v2 = st.ReverseKeys4(d, v2); + v7 = st.ReverseKeys4(d, v7); + v6 = st.ReverseKeys4(d, v6); + vb = st.ReverseKeys4(d, vb); + va = st.ReverseKeys4(d, va); + vf = st.ReverseKeys4(d, vf); + ve = st.ReverseKeys4(d, ve); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + + v1 = st.ReverseKeys4(d, v1); + v3 = st.ReverseKeys4(d, v3); + v5 = st.ReverseKeys4(d, v5); + v7 = st.ReverseKeys4(d, v7); + v9 = st.ReverseKeys4(d, v9); + vb = st.ReverseKeys4(d, vb); + vd = st.ReverseKeys4(d, vd); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + + v0 = st.SortPairsReverse4(d, v0); + v1 = st.SortPairsReverse4(d, v1); + v2 = st.SortPairsReverse4(d, v2); + v3 = st.SortPairsReverse4(d, v3); + v4 = st.SortPairsReverse4(d, v4); + v5 = st.SortPairsReverse4(d, v5); + v6 = st.SortPairsReverse4(d, v6); + v7 = st.SortPairsReverse4(d, v7); + v8 = st.SortPairsReverse4(d, v8); + v9 = st.SortPairsReverse4(d, v9); + va = st.SortPairsReverse4(d, va); + vb = st.SortPairsReverse4(d, vb); + vc = st.SortPairsReverse4(d, vc); + vd = st.SortPairsReverse4(d, vd); + ve = st.SortPairsReverse4(d, ve); + vf = st.SortPairsReverse4(d, vf); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +template , + HWY_IF_LANES_GT(kKeysPerVector, 4)> +HWY_INLINE void Merge16x8(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, + V& vc, V& vd, V& ve, V& vf) { + vf = st.ReverseKeys8(d, vf); + ve = st.ReverseKeys8(d, ve); + vd = st.ReverseKeys8(d, vd); + vc = st.ReverseKeys8(d, vc); + vb = st.ReverseKeys8(d, vb); + va = st.ReverseKeys8(d, va); + v9 = st.ReverseKeys8(d, v9); + v8 = st.ReverseKeys8(d, v8); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + + v7 = st.ReverseKeys8(d, v7); + v6 = st.ReverseKeys8(d, v6); + v5 = st.ReverseKeys8(d, v5); + v4 = st.ReverseKeys8(d, v4); + vf = st.ReverseKeys8(d, vf); + ve = st.ReverseKeys8(d, ve); + vd = st.ReverseKeys8(d, vd); + vc = st.ReverseKeys8(d, vc); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + st.Sort2(d, v8, vf); + st.Sort2(d, v9, ve); + st.Sort2(d, va, vd); + st.Sort2(d, vb, vc); + + v3 = st.ReverseKeys8(d, v3); + v2 = st.ReverseKeys8(d, v2); + v7 = st.ReverseKeys8(d, v7); + v6 = st.ReverseKeys8(d, v6); + vb = st.ReverseKeys8(d, vb); + va = st.ReverseKeys8(d, va); + vf = st.ReverseKeys8(d, vf); + ve = st.ReverseKeys8(d, ve); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); 
+ st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + + v1 = st.ReverseKeys8(d, v1); + v3 = st.ReverseKeys8(d, v3); + v5 = st.ReverseKeys8(d, v5); + v7 = st.ReverseKeys8(d, v7); + v9 = st.ReverseKeys8(d, v9); + vb = st.ReverseKeys8(d, vb); + vd = st.ReverseKeys8(d, vd); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + + v0 = st.SortPairsReverse8(d, v0); + v1 = st.SortPairsReverse8(d, v1); + v2 = st.SortPairsReverse8(d, v2); + v3 = st.SortPairsReverse8(d, v3); + v4 = st.SortPairsReverse8(d, v4); + v5 = st.SortPairsReverse8(d, v5); + v6 = st.SortPairsReverse8(d, v6); + v7 = st.SortPairsReverse8(d, v7); + v8 = st.SortPairsReverse8(d, v8); + v9 = st.SortPairsReverse8(d, v9); + va = st.SortPairsReverse8(d, va); + vb = st.SortPairsReverse8(d, vb); + vc = st.SortPairsReverse8(d, vc); + vd = st.SortPairsReverse8(d, vd); + ve = st.SortPairsReverse8(d, ve); + vf = st.SortPairsReverse8(d, vf); + + v0 = st.SortPairsDistance2(d, v0); + v1 = st.SortPairsDistance2(d, v1); + v2 = st.SortPairsDistance2(d, v2); + v3 = st.SortPairsDistance2(d, v3); + v4 = st.SortPairsDistance2(d, v4); + v5 = st.SortPairsDistance2(d, v5); + v6 = st.SortPairsDistance2(d, v6); + v7 = st.SortPairsDistance2(d, v7); + v8 = st.SortPairsDistance2(d, v8); + v9 = st.SortPairsDistance2(d, v9); + va = st.SortPairsDistance2(d, va); + vb = st.SortPairsDistance2(d, vb); + vc = st.SortPairsDistance2(d, vc); + vd = st.SortPairsDistance2(d, vd); + ve = st.SortPairsDistance2(d, ve); + vf = st.SortPairsDistance2(d, vf); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +// Unused on MSVC, see below +#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + +template , + HWY_IF_LANES_GT(kKeysPerVector, 8)> +HWY_INLINE void Merge16x16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, + V& vc, V& vd, V& ve, V& vf) { + vf = st.ReverseKeys16(d, vf); + ve = st.ReverseKeys16(d, ve); + vd = st.ReverseKeys16(d, vd); + vc = st.ReverseKeys16(d, vc); + vb = st.ReverseKeys16(d, vb); + va = st.ReverseKeys16(d, va); + v9 = st.ReverseKeys16(d, v9); + v8 = st.ReverseKeys16(d, v8); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + + v7 = st.ReverseKeys16(d, v7); + v6 = st.ReverseKeys16(d, v6); + v5 = st.ReverseKeys16(d, v5); + v4 = st.ReverseKeys16(d, v4); + vf = st.ReverseKeys16(d, vf); + ve = st.ReverseKeys16(d, ve); + vd = st.ReverseKeys16(d, vd); + vc = st.ReverseKeys16(d, vc); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + st.Sort2(d, v8, vf); + st.Sort2(d, v9, ve); + st.Sort2(d, va, vd); + st.Sort2(d, vb, vc); + + v3 = st.ReverseKeys16(d, v3); + v2 = st.ReverseKeys16(d, v2); + v7 = st.ReverseKeys16(d, v7); + v6 = 
st.ReverseKeys16(d, v6); + vb = st.ReverseKeys16(d, vb); + va = st.ReverseKeys16(d, va); + vf = st.ReverseKeys16(d, vf); + ve = st.ReverseKeys16(d, ve); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + + v1 = st.ReverseKeys16(d, v1); + v3 = st.ReverseKeys16(d, v3); + v5 = st.ReverseKeys16(d, v5); + v7 = st.ReverseKeys16(d, v7); + v9 = st.ReverseKeys16(d, v9); + vb = st.ReverseKeys16(d, vb); + vd = st.ReverseKeys16(d, vd); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + + v0 = st.SortPairsReverse16(d, v0); + v1 = st.SortPairsReverse16(d, v1); + v2 = st.SortPairsReverse16(d, v2); + v3 = st.SortPairsReverse16(d, v3); + v4 = st.SortPairsReverse16(d, v4); + v5 = st.SortPairsReverse16(d, v5); + v6 = st.SortPairsReverse16(d, v6); + v7 = st.SortPairsReverse16(d, v7); + v8 = st.SortPairsReverse16(d, v8); + v9 = st.SortPairsReverse16(d, v9); + va = st.SortPairsReverse16(d, va); + vb = st.SortPairsReverse16(d, vb); + vc = st.SortPairsReverse16(d, vc); + vd = st.SortPairsReverse16(d, vd); + ve = st.SortPairsReverse16(d, ve); + vf = st.SortPairsReverse16(d, vf); + + v0 = st.SortPairsDistance4(d, v0); + v1 = st.SortPairsDistance4(d, v1); + v2 = st.SortPairsDistance4(d, v2); + v3 = st.SortPairsDistance4(d, v3); + v4 = st.SortPairsDistance4(d, v4); + v5 = st.SortPairsDistance4(d, v5); + v6 = st.SortPairsDistance4(d, v6); + v7 = st.SortPairsDistance4(d, v7); + v8 = st.SortPairsDistance4(d, v8); + v9 = st.SortPairsDistance4(d, v9); + va = st.SortPairsDistance4(d, va); + vb = st.SortPairsDistance4(d, vb); + vc = st.SortPairsDistance4(d, vc); + vd = st.SortPairsDistance4(d, vd); + ve = st.SortPairsDistance4(d, ve); + vf = st.SortPairsDistance4(d, vf); + + v0 = st.SortPairsDistance2(d, v0); + v1 = st.SortPairsDistance2(d, v1); + v2 = st.SortPairsDistance2(d, v2); + v3 = st.SortPairsDistance2(d, v3); + v4 = st.SortPairsDistance2(d, v4); + v5 = st.SortPairsDistance2(d, v5); + v6 = st.SortPairsDistance2(d, v6); + v7 = st.SortPairsDistance2(d, v7); + v8 = st.SortPairsDistance2(d, v8); + v9 = st.SortPairsDistance2(d, v9); + va = st.SortPairsDistance2(d, va); + vb = st.SortPairsDistance2(d, vb); + vc = st.SortPairsDistance2(d, vc); + vd = st.SortPairsDistance2(d, vd); + ve = st.SortPairsDistance2(d, ve); + vf = st.SortPairsDistance2(d, vf); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +#endif // !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + +// Reshapes `buf` into a matrix, sorts columns independently, and then merges +// into a sorted 1D array without transposing. +// +// DEPRECATED, use BaseCase() instead. 
+template +HWY_INLINE void SortingNetwork(Traits st, size_t cols, V& v0, V& v1, V& v2, + V& v3, V& v4, V& v5, V& v6, V& v7, V& v8, V& v9, + V& va, V& vb, V& vc, V& vd, V& ve, V& vf) { + // traits*-inl assume 'full' vectors (but still capped to kMaxCols). + const CappedTag d; + + HWY_DASSERT(cols <= Constants::kMaxCols); + + // The network width depends on the number of keys, not lanes. + constexpr size_t kLanesPerKey = st.LanesPerKey(); + const size_t keys = cols / kLanesPerKey; + constexpr size_t kMaxKeys = MaxLanes(d) / kLanesPerKey; + + Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf); + + // Checking MaxLanes avoids generating HWY_ASSERT code for the unreachable + // code paths: if MaxLanes < 2, then keys <= cols < 2. + if (HWY_LIKELY(keys >= 2 && kMaxKeys >= 2)) { + Merge16x2(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); + + if (HWY_LIKELY(keys >= 4 && kMaxKeys >= 4)) { + Merge16x4(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); + + if (HWY_LIKELY(keys >= 8 && kMaxKeys >= 8)) { + Merge16x8(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, + vb, vc, vd, ve, vf); + + // Avoids build timeout. Must match #if condition in kMaxCols. +#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + if (HWY_LIKELY(keys >= 16 && kMaxKeys >= 16)) { + Merge16x16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + va, vb, vc, vd, ve, vf); + + static_assert(Constants::kMaxCols <= 16, "Add more branches"); + } +#endif + } + } + } +} + +// As above, but loads from/stores to `buf`. This ensures full vectors are +// aligned, and enables loads/stores without bounds checks. +// +// DEPRECATED, use BaseCase() instead. +template +HWY_NOINLINE void SortingNetwork(Traits st, T* HWY_RESTRICT buf, size_t cols) { + // traits*-inl assume 'full' vectors (but still capped to kMaxCols). + // However, for smaller arrays and sub-maximal `cols` we have overlapping + // loads where only the lowest `cols` are valid, and we skip Merge16 etc. + const CappedTag d; + using V = decltype(Zero(d)); + + HWY_DASSERT(cols <= Constants::kMaxCols); + + // These are aligned iff cols == Lanes(d). We prefer unaligned/non-constexpr + // offsets to duplicating this code for every value of cols. 
+ static_assert(Constants::kMaxRows == 16, "Update loads/stores/args"); + V v0 = LoadU(d, buf + 0x0 * cols); + V v1 = LoadU(d, buf + 0x1 * cols); + V v2 = LoadU(d, buf + 0x2 * cols); + V v3 = LoadU(d, buf + 0x3 * cols); + V v4 = LoadU(d, buf + 0x4 * cols); + V v5 = LoadU(d, buf + 0x5 * cols); + V v6 = LoadU(d, buf + 0x6 * cols); + V v7 = LoadU(d, buf + 0x7 * cols); + V v8 = LoadU(d, buf + 0x8 * cols); + V v9 = LoadU(d, buf + 0x9 * cols); + V va = LoadU(d, buf + 0xa * cols); + V vb = LoadU(d, buf + 0xb * cols); + V vc = LoadU(d, buf + 0xc * cols); + V vd = LoadU(d, buf + 0xd * cols); + V ve = LoadU(d, buf + 0xe * cols); + V vf = LoadU(d, buf + 0xf * cols); + + SortingNetwork(st, cols, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, + vd, ve, vf); + + StoreU(v0, d, buf + 0x0 * cols); + StoreU(v1, d, buf + 0x1 * cols); + StoreU(v2, d, buf + 0x2 * cols); + StoreU(v3, d, buf + 0x3 * cols); + StoreU(v4, d, buf + 0x4 * cols); + StoreU(v5, d, buf + 0x5 * cols); + StoreU(v6, d, buf + 0x6 * cols); + StoreU(v7, d, buf + 0x7 * cols); + StoreU(v8, d, buf + 0x8 * cols); + StoreU(v9, d, buf + 0x9 * cols); + StoreU(va, d, buf + 0xa * cols); + StoreU(vb, d, buf + 0xb * cols); + StoreU(vc, d, buf + 0xc * cols); + StoreU(vd, d, buf + 0xd * cols); + StoreU(ve, d, buf + 0xe * cols); + StoreU(vf, d, buf + 0xf * cols); +} + +#else +template +struct SharedTraits : public Base {}; +#endif // VQSORT_ENABLED + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE diff --git a/third_party/aom/third_party/highway/hwy/contrib/sort/traits-inl.h b/third_party/aom/third_party/highway/hwy/contrib/sort/traits-inl.h new file mode 100644 index 000000000000..efa410c81d35 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/sort/traits-inl.h @@ -0,0 +1,618 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#endif + +#include +#include + +#include "third_party/highway/hwy/contrib/sort/order.h" // SortDescending +#include "third_party/highway/hwy/contrib/sort/shared-inl.h" // SortConstants +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +// Base class of both KeyLane variants +template +struct KeyLaneBase { + static constexpr bool Is128() { return false; } + constexpr size_t LanesPerKey() const { return 1; } + + // What type bench_sort should allocate for generating inputs. + using LaneType = LaneTypeArg; + // What type to pass to VQSort. + using KeyType = KeyTypeArg; + + const char* KeyString() const { + return IsSame() ? 
"f16" + : IsSame() ? "f32" + : IsSame() ? "f64" + : IsSame() ? "i16" + : IsSame() ? "i32" + : IsSame() ? "i64" + : IsSame() ? "u32" + : IsSame() ? "u32" + : IsSame() ? "u64" + : IsSame() ? "k+v=64" + : "?"; + } +}; + +// Wrapper functions so we can specialize for floats - infinity trumps +// HighestValue (the normal value with the largest magnitude). Must be outside +// Order* classes to enable SFINAE. + +template +Vec LargestSortValue(D d) { + return Inf(d); +} +template +Vec LargestSortValue(D d) { + return Set(d, hwy::HighestValue>()); +} + +template +Vec SmallestSortValue(D d) { + return Neg(Inf(d)); +} +template +Vec SmallestSortValue(D d) { + return Set(d, hwy::LowestValue>()); +} + +// Returns the next distinct larger value unless already +inf. +template +Vec LargerSortValue(D d, Vec v) { + HWY_DASSERT(AllFalse(d, IsNaN(v))); // we replaced all NaN with LastValue. + using T = TFromD; + const RebindToUnsigned du; + using VU = Vec; + using TU = TFromD; + + const VU vu = BitCast(du, Abs(v)); + + // The direction depends on the original sign. Integer comparison is cheaper + // than float comparison and treats -0 as 0 (so we return +epsilon). + const Mask was_pos = Le(BitCast(du, v), SignBit(du)); + // If positive, add 1, else -1. + const VU add = IfThenElse(was_pos, Set(du, 1u), Set(du, LimitsMax())); + // Prev/next integer is the prev/next value, even if mantissa under/overflows. + v = BitCast(d, Add(vu, add)); + // But we may have overflowed into inf or NaN; replace with inf if positive, + // but the largest (later negated!) value if the input was -inf. + const Mask was_pos_f = RebindMask(d, was_pos); + v = IfThenElse(IsFinite(v), v, + IfThenElse(was_pos_f, Inf(d), Set(d, HighestValue()))); + // Restore the original sign - not via CopySignToAbs because we used a mask. + return IfThenElse(was_pos_f, v, Neg(v)); +} + +// Returns the next distinct smaller value unless already -inf. +template +Vec SmallerSortValue(D d, Vec v) { + HWY_DASSERT(AllFalse(d, IsNaN(v))); // we replaced all NaN with LastValue. + using T = TFromD; + const RebindToUnsigned du; + using VU = Vec; + using TU = TFromD; + + const VU vu = BitCast(du, Abs(v)); + + // The direction depends on the original sign. Float comparison because we + // want to treat 0 as -0 so we return -epsilon. + const Mask was_pos = Gt(v, Zero(d)); + // If positive, add -1, else 1. + const VU add = + IfThenElse(RebindMask(du, was_pos), Set(du, LimitsMax()), Set(du, 1)); + // Prev/next integer is the prev/next value, even if mantissa under/overflows. + v = BitCast(d, Add(vu, add)); + // But we may have overflowed into inf or NaN; replace with +inf (which will + // later be negated) if negative, but the largest value if the input was +inf. + v = IfThenElse(IsFinite(v), v, + IfThenElse(was_pos, Set(d, HighestValue()), Inf(d))); + // Restore the original sign - not via CopySignToAbs because we used a mask. + return IfThenElse(was_pos, v, Neg(v)); +} + +template +Vec LargerSortValue(D d, Vec v) { + return Add(v, Set(d, TFromD{1})); +} + +template +Vec SmallerSortValue(D d, Vec v) { + return Sub(v, Set(d, TFromD{1})); +} + +// Highway does not provide a lane type for 128-bit keys, so we use uint64_t +// along with an abstraction layer for single-lane vs. lane-pair, which is +// independent of the order. 
+template +struct KeyLane : public KeyLaneBase { + // For HeapSort + HWY_INLINE void Swap(LaneType* a, LaneType* b) const { + const LaneType temp = *a; + *a = *b; + *b = temp; + } + + template + HWY_INLINE V CompressKeys(V keys, M mask) const { + return CompressNot(keys, mask); + } + + // Broadcasts one key into a vector + template + HWY_INLINE Vec SetKey(D d, const LaneType* key) const { + return Set(d, *key); + } + + template + HWY_INLINE Mask EqualKeys(D /*tag*/, Vec a, Vec b) const { + return Eq(a, b); + } + + template + HWY_INLINE Mask NotEqualKeys(D /*tag*/, Vec a, Vec b) const { + return Ne(a, b); + } + + // For keys=lanes, any difference counts. + template + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned du; + return AllTrue(du, Eq(BitCast(du, diff), Zero(du))); + } + + HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const { + return *a == *b; + } + + template + HWY_INLINE Vec ReverseKeys(D d, Vec v) const { + return Reverse(d, v); + } + + template + HWY_INLINE Vec ReverseKeys2(D d, Vec v) const { + return Reverse2(d, v); + } + + template + HWY_INLINE Vec ReverseKeys4(D d, Vec v) const { + return Reverse4(d, v); + } + + template + HWY_INLINE Vec ReverseKeys8(D d, Vec v) const { + return Reverse8(d, v); + } + + template + HWY_INLINE Vec ReverseKeys16(D d, Vec v) const { + static_assert(SortConstants::kMaxCols <= 16, "Assumes u32x16 = 512 bit"); + return ReverseKeys(d, v); + } + + template + HWY_INLINE V OddEvenKeys(const V odd, const V even) const { + return OddEven(odd, even); + } + + template + HWY_INLINE Vec SwapAdjacentPairs(D d, const Vec v) const { + const Repartition du32; + return BitCast(d, Shuffle2301(BitCast(du32, v))); + } + template + HWY_INLINE Vec SwapAdjacentPairs(D /* tag */, const Vec v) const { + return Shuffle1032(v); + } + template + HWY_INLINE Vec SwapAdjacentPairs(D /* tag */, const Vec v) const { + return SwapAdjacentBlocks(v); + } + + template + HWY_INLINE Vec SwapAdjacentQuads(D d, const Vec v) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide> dw; +#endif + return BitCast(d, SwapAdjacentPairs(dw, BitCast(dw, v))); + } + template + HWY_INLINE Vec SwapAdjacentQuads(D d, const Vec v) const { + // Assumes max vector size = 512 + return ConcatLowerUpper(d, v, v); + } + + template + HWY_INLINE Vec OddEvenPairs(D d, const Vec odd, + const Vec even) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide> dw; +#endif + return BitCast(d, OddEven(BitCast(dw, odd), BitCast(dw, even))); + } + template + HWY_INLINE Vec OddEvenPairs(D /* tag */, Vec odd, Vec even) const { + return OddEvenBlocks(odd, even); + } + + template + HWY_INLINE Vec OddEvenQuads(D d, Vec odd, Vec even) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide> dw; +#endif + return BitCast(d, OddEvenPairs(dw, BitCast(dw, odd), BitCast(dw, even))); + } + template + HWY_INLINE Vec OddEvenQuads(D d, Vec odd, Vec even) const { + return ConcatUpperLower(d, odd, even); + } +}; + +// Anything order-related depends on the key traits *and* the order (see +// FirstOfLanes). We cannot implement just one Compare function because Lt128 +// only compiles if the lane type is u64. Thus we need either overloaded +// functions with a tag type, class specializations, or separate classes. 
+// We avoid overloaded functions because we want all functions to be callable +// from a SortTraits without per-function wrappers. Specializing would work, but +// we are anyway going to specialize at a higher level. +template +struct OrderAscending : public KeyLane { + // False indicates the entire key (i.e. lane) should be compared. KV stands + // for key-value. + static constexpr bool IsKV() { return false; } + + using Order = SortAscending; + using OrderForSortingNetwork = OrderAscending; + + HWY_INLINE bool Compare1(const T* a, const T* b) const { return *a < *b; } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(a, b); + } + + // Two halves of Sort2, used in ScanMinMax. + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return SmallestSortValue(d); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return LargestSortValue(d); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return SmallerSortValue(d, v); + } +}; + +template +struct OrderDescending : public KeyLane { + // False indicates the entire key (i.e. lane) should be compared. KV stands + // for key-value. + static constexpr bool IsKV() { return false; } + + using Order = SortDescending; + using OrderForSortingNetwork = OrderDescending; + + HWY_INLINE bool Compare1(const T* a, const T* b) const { return *b < *a; } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(b, a); + } + + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return LargestSortValue(d); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return SmallestSortValue(d); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return LargerSortValue(d, v); + } +}; + +struct KeyValue64 : public KeyLane { + // True indicates only part of the key (i.e. lane) should be compared. KV + // stands for key-value. + static constexpr bool IsKV() { return true; } + + template + HWY_INLINE Mask EqualKeys(D /*tag*/, Vec a, Vec b) const { + return Eq(ShiftRight<32>(a), ShiftRight<32>(b)); + } + + template + HWY_INLINE Mask NotEqualKeys(D /*tag*/, Vec a, Vec b) const { + return Ne(ShiftRight<32>(a), ShiftRight<32>(b)); + } + + HWY_INLINE bool Equal1(const uint64_t* a, const uint64_t* b) const { + return (*a >> 32) == (*b >> 32); + } + + // Only count differences in the actual key, not the value. 
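KeyValue64 above packs a 32-bit key into the upper half of each 64-bit lane and a 32-bit value into the lower half, so comparisons shift the value bits away first; the NoKeyDifference overload that follows applies the same key-only rule vector-wide. A scalar sketch of the packing, with hypothetical helper names (not part of the patch):

    #include <cstdint>

    // Key in the upper 32 bits, value in the lower 32 bits.
    static inline uint64_t MakeKV64(uint32_t key, uint32_t value) {
      return (uint64_t{key} << 32) | value;
    }

    // Comparisons ignore the value half, matching the EqualKeys/Equal1
    // definitions above.
    static inline bool KV64KeysEqual(uint64_t a, uint64_t b) {
      return (a >> 32) == (b >> 32);
    }
    static inline bool KV64KeyLess(uint64_t a, uint64_t b) {
      return (a >> 32) < (b >> 32);
    }

Because the sort is not required to be stable, the vector code's First/Last can still use full 64-bit Min/Max: ties on the key are broken by the value, which is permitted.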
+ template + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned du; + const Vec zero = Zero(du); + const Vec keys = ShiftRight<32>(diff); // clear values + return AllTrue(du, Eq(BitCast(du, keys), zero)); + } +}; + +struct OrderAscendingKV64 : public KeyValue64 { + using Order = SortAscending; + using OrderForSortingNetwork = OrderAscending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return (*a >> 32) < (*b >> 32); + } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(ShiftRight<32>(a), ShiftRight<32>(b)); + } + + // Not required to be stable (preserving the order of equivalent keys), so + // we can include the value in the comparison. + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + // Same as for regular lanes. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue>()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue>()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return Sub(v, Set(d, uint64_t{1} << 32)); + } +}; + +struct OrderDescendingKV64 : public KeyValue64 { + using Order = SortDescending; + using OrderForSortingNetwork = OrderDescending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return (*b >> 32) < (*a >> 32); + } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(ShiftRight<32>(b), ShiftRight<32>(a)); + } + + // Not required to be stable (preserving the order of equivalent keys), so + // we can include the value in the comparison. + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue>()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue>()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return Add(v, Set(d, uint64_t{1} << 32)); + } +}; + +// Shared code that depends on Order. +template +struct TraitsLane : public Base { + using TraitsForSortingNetwork = + TraitsLane; + + // For each lane i: replaces a[i] with the first and b[i] with the second + // according to Base. + // Corresponds to a conditional swap, which is one "node" of a sorting + // network. Min/Max are cheaper than compare + blend at least for integers. + template + HWY_INLINE void Sort2(D d, Vec& a, Vec& b) const { + const Base* base = static_cast(this); + + const Vec a_copy = a; + // Prior to AVX3, there is no native 64-bit Min/Max, so they compile to 4 + // instructions. We can reduce it to a compare + 2 IfThenElse. 
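// In scalar terms, the compare + blend path below computes
//   cmp = Compare(a, b);  a = cmp ? a : b;  b = cmp ? b : a_copy;
// so a single comparison drives both outputs; the First/Last (Min/Max) path
// after the #if block is used on all other targets and key sizes.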
+#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 + if (sizeof(TFromD) == 8) { + const Mask cmp = base->Compare(d, a, b); + a = IfThenElse(cmp, a, b); + b = IfThenElse(cmp, b, a_copy); + return; + } +#endif + a = base->First(d, a, b); + b = base->Last(d, a_copy, b); + } + + // Conditionally swaps even-numbered lanes with their odd-numbered neighbor. + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + // Further to the above optimization, Sort2+OddEvenKeys compile to four + // instructions; we can save one by combining two blends. +#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 + const Vec cmp = VecFromMask(d, base->Compare(d, v, swapped)); + return IfVecThenElse(DupOdd(cmp), swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); +#endif + } + + // (See above - we use Sort2 for non-64-bit types.) + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 4 keys. + template + HWY_INLINE Vec SortPairsReverse4(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys4(d, v); + Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); + } + + // Conditionally swaps lane 0 with 4, 1 with 5 etc. + template + HWY_INLINE Vec SortPairsDistance4(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->SwapAdjacentQuads(d, v); + // Only used in Merge16, so this will not be used on AVX2 (which only has 4 + // u64 lanes), so skip the above optimization for 64-bit AVX2. + Sort2(d, v, swapped); + return base->OddEvenQuads(d, swapped, v); + } +}; + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE diff --git a/third_party/aom/third_party/highway/hwy/contrib/sort/traits128-inl.h b/third_party/aom/third_party/highway/hwy/contrib/sort/traits128-inl.h new file mode 100644 index 000000000000..404f9a936a2d --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/sort/traits128-inl.h @@ -0,0 +1,549 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
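The TraitsLane layers defined just above (at the end of traits-inl.h) have simple scalar counterparts; an illustrative sketch with hypothetical names, not part of either header:

    #include <algorithm>
    #include <cstddef>

    // SortPairsDistance1: conditionally swap each even-indexed key with its
    // odd-indexed neighbor -- one layer of the sorting network.
    static void SortPairsDistance1Scalar(int* v, size_t n) {
      for (size_t i = 0; i + 1 < n; i += 2) {
        if (v[i + 1] < v[i]) std::swap(v[i], v[i + 1]);
      }
    }

    // SortPairsReverse4: within each group of four, compare-exchange the
    // outer pair (0,3) and the inner pair (1,2).
    static void SortPairsReverse4Scalar(int* v, size_t n) {
      for (size_t g = 0; g + 4 <= n; g += 4) {
        if (v[g + 3] < v[g + 0]) std::swap(v[g + 0], v[g + 3]);
        if (v[g + 2] < v[g + 1]) std::swap(v[g + 1], v[g + 2]);
      }
    }

The vector versions achieve the same effect with a reverse/swap permutation, a Sort2, and a blend, so no per-lane branching is needed.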
+ +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#endif + +#include +#include + +#include "third_party/highway/hwy/contrib/sort/order.h" // SortDescending +#include "third_party/highway/hwy/contrib/sort/shared-inl.h" +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +// Also used by HeapSort, so do not require VQSORT_ENABLED. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +// Highway does not provide a lane type for 128-bit keys, so we use uint64_t +// along with an abstraction layer for single-lane vs. lane-pair, which is +// independent of the order. +struct KeyAny128 { + static constexpr bool Is128() { return true; } + constexpr size_t LanesPerKey() const { return 2; } + + // What type bench_sort should allocate for generating inputs. + using LaneType = uint64_t; + // KeyType and KeyString are defined by derived classes. + + HWY_INLINE void Swap(LaneType* a, LaneType* b) const { + const FixedTag d; + const auto temp = LoadU(d, a); + StoreU(LoadU(d, b), d, a); + StoreU(temp, d, b); + } + + template + HWY_INLINE V CompressKeys(V keys, M mask) const { + return CompressBlocksNot(keys, mask); + } + + template + HWY_INLINE Vec SetKey(D d, const TFromD* key) const { + return LoadDup128(d, key); + } + + template + HWY_INLINE Vec ReverseKeys(D d, Vec v) const { + return ReverseBlocks(d, v); + } + + template + HWY_INLINE Vec ReverseKeys2(D /* tag */, const Vec v) const { + HWY_DASSERT(Lanes(D()) >= 4); // at least 2 keys + return SwapAdjacentBlocks(v); + } + + // Only called for 4 keys because we do not support >512-bit vectors. + template + HWY_INLINE Vec ReverseKeys4(D d, const Vec v) const { + HWY_DASSERT(Lanes(D()) == 8); // exactly 4 keys: the 512-bit limit + return ReverseKeys(d, v); + } + + // Only called for 4 keys because we do not support >512-bit vectors. + template + HWY_INLINE Vec OddEvenPairs(D d, const Vec odd, + const Vec even) const { + HWY_DASSERT(Lanes(D()) == 8); // exactly 4 keys: the 512-bit limit + return ConcatUpperLower(d, odd, even); + } + + template + HWY_INLINE V OddEvenKeys(const V odd, const V even) const { + return OddEvenBlocks(odd, even); + } + + template + HWY_INLINE Vec ReverseKeys8(D, Vec) const { + HWY_ASSERT(0); // not supported: would require 1024-bit vectors + } + + template + HWY_INLINE Vec ReverseKeys16(D, Vec) const { + HWY_ASSERT(0); // not supported: would require 2048-bit vectors + } + + // This is only called for 8/16 col networks (not supported). + template + HWY_INLINE Vec SwapAdjacentPairs(D, Vec) const { + HWY_ASSERT(0); + } + + // This is only called for 16 col networks (not supported). + template + HWY_INLINE Vec SwapAdjacentQuads(D, Vec) const { + HWY_ASSERT(0); + } + + // This is only called for 8 col networks (not supported). + template + HWY_INLINE Vec OddEvenQuads(D, Vec, Vec) const { + HWY_ASSERT(0); + } +}; + +// Base class shared between OrderAscending128, OrderDescending128. +struct Key128 : public KeyAny128 { + // False indicates the entire key should be compared. KV means key-value. + static constexpr bool IsKV() { return false; } + + // What type to pass to VQSort. 
+ using KeyType = hwy::uint128_t; + + const char* KeyString() const { return "U128"; } + + template + HWY_INLINE Mask EqualKeys(D d, Vec a, Vec b) const { + return Eq128(d, a, b); + } + + template + HWY_INLINE Mask NotEqualKeys(D d, Vec a, Vec b) const { + return Ne128(d, a, b); + } + + // For keys=entire 128 bits, any difference counts. + template + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned du; + return AllTrue(du, Eq(BitCast(du, diff), Zero(du))); + } + + HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const { + return a[0] == b[0] && a[1] == b[1]; + } + + // Returns vector with only the top half of each block valid. This allows + // fusing the "replicate upper to lower half" step with a subsequent permute. + template + HWY_INLINE HWY_MAYBE_UNUSED Vec CompareTop(D d, Vec a, Vec b) const { + const Mask eqHL = Eq(a, b); + const Vec ltHL = VecFromMask(d, Order().CompareLanes(a, b)); +#if HWY_TARGET <= HWY_AVX2 // slightly faster + const Vec ltLX = ShiftLeftLanes<1>(ltHL); + return OrAnd(ltHL, VecFromMask(d, eqHL), ltLX); +#else + return IfThenElse(eqHL, DupEven(ltHL), ltHL); +#endif + } +}; + +// Anything order-related depends on the key traits *and* the order (see +// FirstOfLanes). We cannot implement just one Compare function because Lt128 +// only compiles if the lane type is u64. Thus we need either overloaded +// functions with a tag type, class specializations, or separate classes. +// We avoid overloaded functions because we want all functions to be callable +// from a SortTraits without per-function wrappers. Specializing would work, but +// we are anyway going to specialize at a higher level. +struct OrderAscending128 : public Key128 { + using Order = SortAscending; + using OrderForSortingNetwork = OrderAscending128; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return (a[1] == b[1]) ? a[0] < b[0] : a[1] < b[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128(d, a, b); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Min128(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Max128(d, a, b); + } + + // FirstOfLanes/LastOfLanes are implemented in Traits128. + + // Same as for regular lanes because 128-bit keys are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k0 = Zero(d); + const Vec k1 = OddEven(k0, Set(d, uint64_t{1})); + const Mask borrow = Eq(v, k0); // don't-care, lo == 0 + // lo == 0? 1 : 0, 0 + const Vec adjust = ShiftLeftLanes<1>(IfThenElseZero(borrow, k1)); + return Sub(Sub(v, k1), adjust); + } + + // 'Private', used by base class Key128::CompareTop. + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(a, b); + } +}; + +struct OrderDescending128 : public Key128 { + using Order = SortDescending; + using OrderForSortingNetwork = OrderDescending128; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return (a[1] == b[1]) ? 
b[0] < a[0] : b[1] < a[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128(d, b, a); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Max128(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Min128(d, a, b); + } + + // FirstOfLanes/LastOfLanes are implemented in Traits128. + + // Same as for regular lanes because 128-bit keys are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k1 = OddEven(Zero(d), Set(d, uint64_t{1})); + const Vec added = Add(v, k1); + const Mask overflowed = Lt(added, v); // false, overflowed + // overflowed? 1 : 0, 0 + const Vec adjust = ShiftLeftLanes<1>(IfThenElseZero(overflowed, k1)); + return Add(added, adjust); + } + + // 'Private', used by base class Key128::CompareTop. + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(b, a); + } +}; + +// Base class shared between OrderAscendingKV128, OrderDescendingKV128. +struct KeyValue128 : public KeyAny128 { + // True indicates only part of the key (the more significant lane) should be + // compared. KV stands for key-value. + static constexpr bool IsKV() { return true; } + + // What type to pass to VQSort. + using KeyType = K64V64; + + const char* KeyString() const { return "k+v=128"; } + + template + HWY_INLINE Mask EqualKeys(D d, Vec a, Vec b) const { + return Eq128Upper(d, a, b); + } + + template + HWY_INLINE Mask NotEqualKeys(D d, Vec a, Vec b) const { + return Ne128Upper(d, a, b); + } + + HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const { + return a[1] == b[1]; + } + + // Only count differences in the actual key, not the value. + template + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned du; + const Vec zero = Zero(du); + const Vec keys = OddEven(diff, zero); // clear values + return AllTrue(du, Eq(BitCast(du, keys), zero)); + } + + // Returns vector with only the top half of each block valid. This allows + // fusing the "replicate upper to lower half" step with a subsequent permute. + template + HWY_INLINE HWY_MAYBE_UNUSED Vec CompareTop(D d, Vec a, Vec b) const { + // Only the upper lane of each block is a key, and only that lane is + // required to be valid, so comparing all lanes is sufficient. + return VecFromMask(d, Order().CompareLanes(a, b)); + } +}; + +struct OrderAscendingKV128 : public KeyValue128 { + using Order = SortAscending; + using OrderForSortingNetwork = OrderAscending128; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return a[1] < b[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128Upper(d, a, b); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Min128Upper(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Max128Upper(d, a, b); + } + + // FirstOfLanes/LastOfLanes are implemented in Traits128. + + // Same as for regular lanes because 128-bit keys are u64. 
+ template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k1 = OddEven(Set(d, uint64_t{1}), Zero(d)); + return Sub(v, k1); + } + + // 'Private', used by base class KeyValue128::CompareTop. + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(a, b); + } +}; + +struct OrderDescendingKV128 : public KeyValue128 { + using Order = SortDescending; + using OrderForSortingNetwork = OrderDescending128; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return b[1] < a[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128Upper(d, b, a); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Max128Upper(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Min128Upper(d, a, b); + } + + // FirstOfLanes/LastOfLanes are implemented in Traits128. + + // Same as for regular lanes because 128-bit keys are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k1 = OddEven(Set(d, uint64_t{1}), Zero(d)); + return Add(v, k1); + } + + // 'Private', used by base class KeyValue128::CompareTop. + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(b, a); + } +}; + +// We want to swap 2 u128, i.e. 4 u64 lanes, based on the 0 or FF..FF mask in +// the most-significant of those lanes (the result of CompareTop), so +// replicate it 4x. Only called for >= 256-bit vectors. + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V ReplicateTop4x(V v) { + return V{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))}; +} +#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_TARGET <= HWY_AVX2 + +template +HWY_INLINE V ReplicateTop4x(V v) { + return V{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))}; +} + +#else // HWY_TARGET > HWY_AVX2 + +template +HWY_INLINE V ReplicateTop4x(V v) { +#if HWY_TARGET == HWY_SVE_256 + return svdup_lane_u64(v, 3); +#else + const ScalableTag d; + HWY_DASSERT(Lanes(d) == 4 || Lanes(d) == 8); // for table below + HWY_ALIGN static constexpr uint64_t kIndices[8] = {3, 3, 3, 3, 7, 7, 7, 7}; + return TableLookupLanes(v, SetTableIndices(d, kIndices)); +#endif +} + +#endif // HWY_TARGET <= HWY_AVX2 + +// Shared code that depends on Order. 
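The Key128/Order*128 classes above treat each key as a pair of uint64_t lanes, with lane 1 as the most-significant half, compared lexicographically; PrevValue is a 128-bit decrement (or increment, for descending order) with a carry between the lanes. A scalar model with hypothetical names, not part of the patch; the Traits128 "shared code" below builds Sort2 and the network layers on top of these comparisons:

    #include <cstdint>

    struct U128 {
      uint64_t lo;  // lane 0
      uint64_t hi;  // lane 1: most significant, compared first
    };

    // Matches OrderAscending128::Compare1.
    static inline bool Less128(const U128& a, const U128& b) {
      return (a.hi == b.hi) ? (a.lo < b.lo) : (a.hi < b.hi);
    }

    // Matches OrderAscending128::PrevValue: subtract one, borrowing from the
    // upper lane when the lower lane wraps around.
    static inline U128 Prev128(U128 v) {
      if (v.lo == 0) --v.hi;
      --v.lo;
      return v;
    }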
+template +struct Traits128 : public Base { + using TraitsForSortingNetwork = + Traits128; + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT buf) const { + const Base* base = static_cast(this); + const size_t N = Lanes(d); + Store(v, d, buf); + v = base->SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) { + v = base->First(d, v, base->SetKey(d, buf + i)); + } + return v; + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT buf) const { + const Base* base = static_cast(this); + const size_t N = Lanes(d); + Store(v, d, buf); + v = base->SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) { + v = base->Last(d, v, base->SetKey(d, buf + i)); + } + return v; + } + + template + HWY_INLINE void Sort2(D d, Vec& a, Vec& b) const { + const Base* base = static_cast(this); + + const Vec a_copy = a; + const auto lt = base->Compare(d, a, b); + a = IfThenElse(lt, a, b); + b = IfThenElse(lt, b, a_copy); + } + + // Conditionally swaps even-numbered keys with their odd-numbered neighbor. + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + HWY_DASSERT(Lanes(d) >= 4); // required by ReplicateTop4x + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + const Vec cmpHx = base->template CompareTop(d, v, swapped); + return IfVecThenElse(ReplicateTop4x(cmpHx), swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of four 128-bit + // keys, which implies 512-bit vectors (we do not support more than that). + template + HWY_INLINE Vec SortPairsReverse4(D d, Vec v) const { + HWY_DASSERT(Lanes(d) == 8); // For TableLookupLanes below + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys4(d, v); + + const Vec cmpHx = base->template CompareTop(d, v, swapped); + // Similar to ReplicateTop4x, we want to gang together 2 comparison results + // (4 lanes). They are not contiguous, so use permute to replicate 4x. + HWY_ALIGN uint64_t kIndices[8] = {7, 7, 5, 5, 5, 5, 7, 7}; + const Vec select = TableLookupLanes(cmpHx, SetTableIndices(d, kIndices)); + return IfVecThenElse(select, swapped, v); + } + + // Conditionally swaps lane 0 with 4, 1 with 5 etc. + template + HWY_INLINE Vec SortPairsDistance4(D, Vec) const { + // Only used by Merge16, which would require 2048 bit vectors (unsupported). + HWY_ASSERT(0); + } +}; + +#endif // HWY_TARGET != HWY_SCALAR + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE diff --git a/third_party/aom/third_party/highway/hwy/contrib/sort/vqsort-inl.h b/third_party/aom/third_party/highway/hwy/contrib/sort/vqsort-inl.h new file mode 100644 index 000000000000..5eaf4d56f8fb --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/sort/vqsort-inl.h @@ -0,0 +1,2210 @@ +// Copyright 2021 Google LLC +// Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ + +// unconditional #include so we can use if(VQSORT_PRINT), which unlike #if does +// not interfere with code-folding. +#include +#include // clock + +// IWYU pragma: begin_exports +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/contrib/sort/order.h" // SortAscending +// IWYU pragma: end_exports + +#include "third_party/highway/hwy/cache_control.h" // Prefetch +#include "third_party/highway/hwy/print.h" // unconditional, see above. + +// If 1, VQSortStatic can be called without including vqsort.h, and we avoid +// any DLLEXPORT. This simplifies integration into other build systems, but +// decreases the security of random seeds. +#ifndef VQSORT_ONLY_STATIC +#define VQSORT_ONLY_STATIC 0 +#endif + +// Verbosity: 0 for none, 1 for brief per-sort, 2+ for more details. +#ifndef VQSORT_PRINT +#define VQSORT_PRINT 0 +#endif + +#if !VQSORT_ONLY_STATIC +#include "third_party/highway/hwy/contrib/sort/vqsort.h" // Fill16BytesSecure +#endif + +namespace hwy { +namespace detail { + +HWY_INLINE void Fill16BytesStatic(void* bytes) { +#if !VQSORT_ONLY_STATIC + if (Fill16BytesSecure(bytes)) return; +#endif + + uint64_t* words = reinterpret_cast(bytes); + + // Static-only, or Fill16BytesSecure failed. Get some entropy from the + // stack/code location, and the clock() timer. + uint64_t** seed_stack = &words; + void (*seed_code)(void*) = &Fill16BytesStatic; + const uintptr_t bits_stack = reinterpret_cast(seed_stack); + const uintptr_t bits_code = reinterpret_cast(seed_code); + const uint64_t bits_time = static_cast(clock()); + words[0] = bits_stack ^ bits_time ^ 0xFEDCBA98; // "Nothing up my sleeve" + words[1] = bits_code ^ bits_time ^ 0x01234567; // constants. +} + +HWY_INLINE uint64_t* GetGeneratorStateStatic() { + thread_local uint64_t state[3] = {0}; + // This is a counter; zero indicates not yet initialized. + if (HWY_UNLIKELY(state[2] == 0)) { + Fill16BytesStatic(state); + state[2] = 1; + } + return state; +} + +} // namespace detail +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#endif + +#if VQSORT_PRINT +#include "third_party/highway/hwy/print-inl.h" +#endif + +#include "third_party/highway/hwy/contrib/algo/copy-inl.h" +#include "third_party/highway/hwy/contrib/sort/shared-inl.h" +#include "third_party/highway/hwy/contrib/sort/sorting_networks-inl.h" +#include "third_party/highway/hwy/contrib/sort/traits-inl.h" +#include "third_party/highway/hwy/contrib/sort/traits128-inl.h" +// Placeholder for internal instrumentation. Do not remove. 
+#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +using Constants = hwy::SortConstants; + +// Wrapper avoids #if in user code (interferes with code folding) +template +HWY_INLINE void MaybePrintVector(D d, const char* label, Vec v, + size_t start = 0, size_t max_lanes = 16) { +#if VQSORT_PRINT >= 2 // Print is only defined #if + Print(d, label, v, start, max_lanes); +#else + (void)d; + (void)label; + (void)v; + (void)start; + (void)max_lanes; +#endif +} + +// ------------------------------ HeapSort + +template +void SiftDown(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes, + size_t start) { + constexpr size_t N1 = st.LanesPerKey(); + const FixedTag d; + + while (start < num_lanes) { + const size_t left = 2 * start + N1; + const size_t right = 2 * start + 2 * N1; + if (left >= num_lanes) break; + size_t idx_larger = start; + const auto key_j = st.SetKey(d, lanes + start); + if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, lanes + left)))) { + idx_larger = left; + } + if (right < num_lanes && + AllTrue(d, st.Compare(d, st.SetKey(d, lanes + idx_larger), + st.SetKey(d, lanes + right)))) { + idx_larger = right; + } + if (idx_larger == start) break; + st.Swap(lanes + start, lanes + idx_larger); + start = idx_larger; + } +} + +// Heapsort: O(1) space, O(N*logN) worst-case comparisons. +// Based on LLVM sanitizer_common.h, licensed under Apache-2.0. +template +void HeapSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes) { + constexpr size_t N1 = st.LanesPerKey(); + HWY_DASSERT(num_lanes % N1 == 0); + if (num_lanes == N1) return; + + // Build heap. + for (size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) { + SiftDown(st, lanes, num_lanes, i); + } + + for (size_t i = num_lanes - N1; i != 0; i -= N1) { +// Workaround for -Waggressive-loop-optimizations warning that might be emitted +// by GCC +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4756, + ignored "-Waggressive-loop-optimizations") +#endif + // Swap root with last + st.Swap(lanes + 0, lanes + i); + +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS(pop) +#endif + + // Sift down the new root. + SiftDown(st, lanes, i, 0); + } +} + +template +void HeapSelect(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes, + const size_t k_lanes) { + constexpr size_t N1 = st.LanesPerKey(); + const size_t k = k_lanes + N1; + HWY_DASSERT(num_lanes % N1 == 0); + if (num_lanes == N1) return; + + const FixedTag d; + + // Build heap. + for (size_t i = ((k - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) { + SiftDown(st, lanes, k, i); + } + + for (size_t i = k; i <= num_lanes - N1; i += N1) { + if (AllTrue(d, st.Compare(d, st.SetKey(d, lanes + i), + st.SetKey(d, lanes + 0)))) { +// Workaround for -Waggressive-loop-optimizations warning that might be emitted +// by GCC +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4756, + ignored "-Waggressive-loop-optimizations") +#endif + + // Swap root with last + st.Swap(lanes + 0, lanes + i); + +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS(pop) +#endif + + // Sift down the new root. 
+ SiftDown(st, lanes, k, 0); + } + } + + st.Swap(lanes + 0, lanes + k - N1); +} + +template +void HeapPartialSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes, + const size_t k_lanes) { + HeapSelect(st, lanes, num_lanes, k_lanes); + HeapSort(st, lanes, k_lanes); +} + +#if VQSORT_ENABLED || HWY_IDE + +// ------------------------------ BaseCase + +// Special cases where `num_lanes` is in the specified range (inclusive). +template +HWY_INLINE void Sort2To2(Traits st, T* HWY_RESTRICT keys, size_t num_lanes, + T* HWY_RESTRICT /* buf */) { + constexpr size_t kLPK = st.LanesPerKey(); + const size_t num_keys = num_lanes / kLPK; + HWY_DASSERT(num_keys == 2); + HWY_ASSUME(num_keys == 2); + + // One key per vector, required to avoid reading past the end of `keys`. + const CappedTag d; + using V = Vec; + + V v0 = LoadU(d, keys + 0x0 * kLPK); + V v1 = LoadU(d, keys + 0x1 * kLPK); + + Sort2(d, st, v0, v1); + + StoreU(v0, d, keys + 0x0 * kLPK); + StoreU(v1, d, keys + 0x1 * kLPK); +} + +template +HWY_INLINE void Sort3To4(Traits st, T* HWY_RESTRICT keys, size_t num_lanes, + T* HWY_RESTRICT buf) { + constexpr size_t kLPK = st.LanesPerKey(); + const size_t num_keys = num_lanes / kLPK; + HWY_DASSERT(3 <= num_keys && num_keys <= 4); + HWY_ASSUME(num_keys >= 3); + HWY_ASSUME(num_keys <= 4); // reduces branches + + // One key per vector, required to avoid reading past the end of `keys`. + const CappedTag d; + using V = Vec; + + // If num_keys == 3, initialize padding for the last sorting network element + // so that it does not influence the other elements. + Store(st.LastValue(d), d, buf); + + // Points to a valid key, or padding. This avoids special-casing + // HWY_MEM_OPS_MIGHT_FAULT because there is only a single key per vector. + T* in_out3 = num_keys == 3 ? buf : keys + 0x3 * kLPK; + + V v0 = LoadU(d, keys + 0x0 * kLPK); + V v1 = LoadU(d, keys + 0x1 * kLPK); + V v2 = LoadU(d, keys + 0x2 * kLPK); + V v3 = LoadU(d, in_out3); + + Sort4(d, st, v0, v1, v2, v3); + + StoreU(v0, d, keys + 0x0 * kLPK); + StoreU(v1, d, keys + 0x1 * kLPK); + StoreU(v2, d, keys + 0x2 * kLPK); + StoreU(v3, d, in_out3); +} + +#if HWY_MEM_OPS_MIGHT_FAULT + +template > +HWY_INLINE void CopyHalfToPaddedBuf(D d, Traits st, T* HWY_RESTRICT keys, + size_t num_lanes, T* HWY_RESTRICT buf) { + constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow; + // Must cap for correctness: we will load up to the last valid lane, so + // Lanes(dmax) must not exceed `num_lanes` (known to be at least kMinLanes). + const CappedTag dmax; + const size_t Nmax = Lanes(dmax); + HWY_DASSERT(Nmax < num_lanes); + HWY_ASSUME(Nmax <= kMinLanes); + + // Fill with padding - last in sort order, not copied to keys. + const Vec kPadding = st.LastValue(dmax); + + // Rounding down allows aligned stores, which are typically faster. + size_t i = num_lanes & ~(Nmax - 1); + HWY_ASSUME(i != 0); // because Nmax <= num_lanes; avoids branch + do { + Store(kPadding, dmax, buf + i); + i += Nmax; + // Initialize enough for the last vector even if Nmax > kLanesPerRow. + } while (i < (kRows - 1) * kLanesPerRow + Lanes(d)); + + // Ensure buf contains all we will read, and perhaps more before. 
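// The loop below copies the tail of `keys` backwards into `buf` one whole
// (unaligned) vector at a time, at matching offsets, until all lanes at or
// above kRows / 2 * kLanesPerRow are covered; it may also copy a few lanes
// before that. Lanes at or past num_lanes were already filled with padding
// above.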
+ ptrdiff_t end = static_cast(num_lanes); + do { + end -= static_cast(Nmax); + StoreU(LoadU(dmax, keys + end), dmax, buf + end); + } while (end > static_cast(kRows / 2 * kLanesPerRow)); +} + +#endif // HWY_MEM_OPS_MIGHT_FAULT + +template +HWY_NOINLINE void Sort8Rows(Traits st, T* HWY_RESTRICT keys, size_t num_lanes, + T* HWY_RESTRICT buf) { + // kKeysPerRow <= 4 because 8 64-bit keys implies 512-bit vectors, which + // are likely slower than 16x4, so 8x4 is the largest we handle here. + static_assert(kKeysPerRow <= 4, ""); + + constexpr size_t kLPK = st.LanesPerKey(); + + // We reshape the 1D keys into kRows x kKeysPerRow. + constexpr size_t kRows = 8; + constexpr size_t kLanesPerRow = kKeysPerRow * kLPK; + constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow; + HWY_DASSERT(kMinLanes < num_lanes && num_lanes <= kRows * kLanesPerRow); + + const CappedTag d; + using V = Vec; + V v4, v5, v6, v7; + + // At least half the kRows are valid, otherwise a different function would + // have been called to handle this num_lanes. + V v0 = LoadU(d, keys + 0x0 * kLanesPerRow); + V v1 = LoadU(d, keys + 0x1 * kLanesPerRow); + V v2 = LoadU(d, keys + 0x2 * kLanesPerRow); + V v3 = LoadU(d, keys + 0x3 * kLanesPerRow); +#if HWY_MEM_OPS_MIGHT_FAULT + CopyHalfToPaddedBuf(d, st, keys, num_lanes, buf); + v4 = LoadU(d, buf + 0x4 * kLanesPerRow); + v5 = LoadU(d, buf + 0x5 * kLanesPerRow); + v6 = LoadU(d, buf + 0x6 * kLanesPerRow); + v7 = LoadU(d, buf + 0x7 * kLanesPerRow); +#endif // HWY_MEM_OPS_MIGHT_FAULT +#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE + (void)buf; + const V vnum_lanes = Set(d, ConvertScalarTo(num_lanes)); + // First offset where not all vector are guaranteed valid. + const V kIota = Iota(d, static_cast(kMinLanes)); + const V k1 = Set(d, static_cast(kLanesPerRow)); + const V k2 = Add(k1, k1); + + using M = Mask; + const M m4 = Gt(vnum_lanes, kIota); + const M m5 = Gt(vnum_lanes, Add(kIota, k1)); + const M m6 = Gt(vnum_lanes, Add(kIota, k2)); + const M m7 = Gt(vnum_lanes, Add(kIota, Add(k2, k1))); + + const V kPadding = st.LastValue(d); // Not copied to keys. + v4 = MaskedLoadOr(kPadding, m4, d, keys + 0x4 * kLanesPerRow); + v5 = MaskedLoadOr(kPadding, m5, d, keys + 0x5 * kLanesPerRow); + v6 = MaskedLoadOr(kPadding, m6, d, keys + 0x6 * kLanesPerRow); + v7 = MaskedLoadOr(kPadding, m7, d, keys + 0x7 * kLanesPerRow); +#endif // !HWY_MEM_OPS_MIGHT_FAULT + + Sort8(d, st, v0, v1, v2, v3, v4, v5, v6, v7); + + // Merge8x2 is a no-op if kKeysPerRow < 2 etc. + Merge8x2(d, st, v0, v1, v2, v3, v4, v5, v6, v7); + Merge8x4(d, st, v0, v1, v2, v3, v4, v5, v6, v7); + + StoreU(v0, d, keys + 0x0 * kLanesPerRow); + StoreU(v1, d, keys + 0x1 * kLanesPerRow); + StoreU(v2, d, keys + 0x2 * kLanesPerRow); + StoreU(v3, d, keys + 0x3 * kLanesPerRow); + +#if HWY_MEM_OPS_MIGHT_FAULT + // Store remaining vectors into buf and safely copy them into keys. + StoreU(v4, d, buf + 0x4 * kLanesPerRow); + StoreU(v5, d, buf + 0x5 * kLanesPerRow); + StoreU(v6, d, buf + 0x6 * kLanesPerRow); + StoreU(v7, d, buf + 0x7 * kLanesPerRow); + + const ScalableTag dmax; + const size_t Nmax = Lanes(dmax); + + // The first half of vectors have already been stored unconditionally into + // `keys`, so we do not copy them. 
+ size_t i = kMinLanes; + HWY_UNROLL(1) + for (; i + Nmax <= num_lanes; i += Nmax) { + StoreU(LoadU(dmax, buf + i), dmax, keys + i); + } + + // Last iteration: copy partial vector + const size_t remaining = num_lanes - i; + HWY_ASSUME(remaining < 256); // helps FirstN + SafeCopyN(remaining, dmax, buf + i, keys + i); +#endif // HWY_MEM_OPS_MIGHT_FAULT +#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE + BlendedStore(v4, m4, d, keys + 0x4 * kLanesPerRow); + BlendedStore(v5, m5, d, keys + 0x5 * kLanesPerRow); + BlendedStore(v6, m6, d, keys + 0x6 * kLanesPerRow); + BlendedStore(v7, m7, d, keys + 0x7 * kLanesPerRow); +#endif // !HWY_MEM_OPS_MIGHT_FAULT +} + +template +HWY_NOINLINE void Sort16Rows(Traits st, T* HWY_RESTRICT keys, size_t num_lanes, + T* HWY_RESTRICT buf) { + static_assert(kKeysPerRow <= SortConstants::kMaxCols, ""); + + constexpr size_t kLPK = st.LanesPerKey(); + + // We reshape the 1D keys into kRows x kKeysPerRow. + constexpr size_t kRows = 16; + constexpr size_t kLanesPerRow = kKeysPerRow * kLPK; + constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow; + HWY_DASSERT(kMinLanes < num_lanes && num_lanes <= kRows * kLanesPerRow); + + const CappedTag d; + using V = Vec; + V v8, v9, va, vb, vc, vd, ve, vf; + + // At least half the kRows are valid, otherwise a different function would + // have been called to handle this num_lanes. + V v0 = LoadU(d, keys + 0x0 * kLanesPerRow); + V v1 = LoadU(d, keys + 0x1 * kLanesPerRow); + V v2 = LoadU(d, keys + 0x2 * kLanesPerRow); + V v3 = LoadU(d, keys + 0x3 * kLanesPerRow); + V v4 = LoadU(d, keys + 0x4 * kLanesPerRow); + V v5 = LoadU(d, keys + 0x5 * kLanesPerRow); + V v6 = LoadU(d, keys + 0x6 * kLanesPerRow); + V v7 = LoadU(d, keys + 0x7 * kLanesPerRow); +#if HWY_MEM_OPS_MIGHT_FAULT + CopyHalfToPaddedBuf(d, st, keys, num_lanes, buf); + v8 = LoadU(d, buf + 0x8 * kLanesPerRow); + v9 = LoadU(d, buf + 0x9 * kLanesPerRow); + va = LoadU(d, buf + 0xa * kLanesPerRow); + vb = LoadU(d, buf + 0xb * kLanesPerRow); + vc = LoadU(d, buf + 0xc * kLanesPerRow); + vd = LoadU(d, buf + 0xd * kLanesPerRow); + ve = LoadU(d, buf + 0xe * kLanesPerRow); + vf = LoadU(d, buf + 0xf * kLanesPerRow); +#endif // HWY_MEM_OPS_MIGHT_FAULT +#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE + (void)buf; + const V vnum_lanes = Set(d, ConvertScalarTo(num_lanes)); + // First offset where not all vector are guaranteed valid. + const V kIota = Iota(d, static_cast(kMinLanes)); + const V k1 = Set(d, static_cast(kLanesPerRow)); + const V k2 = Add(k1, k1); + const V k4 = Add(k2, k2); + const V k8 = Add(k4, k4); + + using M = Mask; + const M m8 = Gt(vnum_lanes, kIota); + const M m9 = Gt(vnum_lanes, Add(kIota, k1)); + const M ma = Gt(vnum_lanes, Add(kIota, k2)); + const M mb = Gt(vnum_lanes, Add(kIota, Sub(k4, k1))); + const M mc = Gt(vnum_lanes, Add(kIota, k4)); + const M md = Gt(vnum_lanes, Add(kIota, Add(k4, k1))); + const M me = Gt(vnum_lanes, Add(kIota, Add(k4, k2))); + const M mf = Gt(vnum_lanes, Add(kIota, Sub(k8, k1))); + + const V kPadding = st.LastValue(d); // Not copied to keys. 
+ v8 = MaskedLoadOr(kPadding, m8, d, keys + 0x8 * kLanesPerRow); + v9 = MaskedLoadOr(kPadding, m9, d, keys + 0x9 * kLanesPerRow); + va = MaskedLoadOr(kPadding, ma, d, keys + 0xa * kLanesPerRow); + vb = MaskedLoadOr(kPadding, mb, d, keys + 0xb * kLanesPerRow); + vc = MaskedLoadOr(kPadding, mc, d, keys + 0xc * kLanesPerRow); + vd = MaskedLoadOr(kPadding, md, d, keys + 0xd * kLanesPerRow); + ve = MaskedLoadOr(kPadding, me, d, keys + 0xe * kLanesPerRow); + vf = MaskedLoadOr(kPadding, mf, d, keys + 0xf * kLanesPerRow); +#endif // !HWY_MEM_OPS_MIGHT_FAULT + + Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf); + + // Merge16x4 is a no-op if kKeysPerRow < 4 etc. + Merge16x2(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); + Merge16x4(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); + Merge16x8(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); +#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + Merge16x16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); +#endif + + StoreU(v0, d, keys + 0x0 * kLanesPerRow); + StoreU(v1, d, keys + 0x1 * kLanesPerRow); + StoreU(v2, d, keys + 0x2 * kLanesPerRow); + StoreU(v3, d, keys + 0x3 * kLanesPerRow); + StoreU(v4, d, keys + 0x4 * kLanesPerRow); + StoreU(v5, d, keys + 0x5 * kLanesPerRow); + StoreU(v6, d, keys + 0x6 * kLanesPerRow); + StoreU(v7, d, keys + 0x7 * kLanesPerRow); + +#if HWY_MEM_OPS_MIGHT_FAULT + // Store remaining vectors into buf and safely copy them into keys. + StoreU(v8, d, buf + 0x8 * kLanesPerRow); + StoreU(v9, d, buf + 0x9 * kLanesPerRow); + StoreU(va, d, buf + 0xa * kLanesPerRow); + StoreU(vb, d, buf + 0xb * kLanesPerRow); + StoreU(vc, d, buf + 0xc * kLanesPerRow); + StoreU(vd, d, buf + 0xd * kLanesPerRow); + StoreU(ve, d, buf + 0xe * kLanesPerRow); + StoreU(vf, d, buf + 0xf * kLanesPerRow); + + const ScalableTag dmax; + const size_t Nmax = Lanes(dmax); + + // The first half of vectors have already been stored unconditionally into + // `keys`, so we do not copy them. + size_t i = kMinLanes; + HWY_UNROLL(1) + for (; i + Nmax <= num_lanes; i += Nmax) { + StoreU(LoadU(dmax, buf + i), dmax, keys + i); + } + + // Last iteration: copy partial vector + const size_t remaining = num_lanes - i; + HWY_ASSUME(remaining < 256); // helps FirstN + SafeCopyN(remaining, dmax, buf + i, keys + i); +#endif // HWY_MEM_OPS_MIGHT_FAULT +#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE + BlendedStore(v8, m8, d, keys + 0x8 * kLanesPerRow); + BlendedStore(v9, m9, d, keys + 0x9 * kLanesPerRow); + BlendedStore(va, ma, d, keys + 0xa * kLanesPerRow); + BlendedStore(vb, mb, d, keys + 0xb * kLanesPerRow); + BlendedStore(vc, mc, d, keys + 0xc * kLanesPerRow); + BlendedStore(vd, md, d, keys + 0xd * kLanesPerRow); + BlendedStore(ve, me, d, keys + 0xe * kLanesPerRow); + BlendedStore(vf, mf, d, keys + 0xf * kLanesPerRow); +#endif // !HWY_MEM_OPS_MIGHT_FAULT +} + +// Sorts `keys` within the range [0, num_lanes) via sorting network. +// Reshapes into a matrix, sorts columns independently, and then merges +// into a sorted 1D array without transposing. +// +// `TraitsKV` is SharedTraits>. This abstraction layer bridges +// differences in sort order and single-lane vs 128-bit keys. For key-value +// types, items with the same key are not equivalent. Our sorting network +// does not preserve order, thus we prevent mixing padding into the items by +// comparing all the item bits, including the value (see *ForSortingNetwork). +// +// See M. 
Blacher's thesis: https://github.com/mark-blacher/masterthesis +template +HWY_NOINLINE void BaseCase(D d, TraitsKV, T* HWY_RESTRICT keys, + size_t num_lanes, T* buf) { + using Traits = typename TraitsKV::SharedTraitsForSortingNetwork; + Traits st; + constexpr size_t kLPK = st.LanesPerKey(); + HWY_DASSERT(num_lanes <= Constants::BaseCaseNumLanes(Lanes(d))); + const size_t num_keys = num_lanes / kLPK; + + // Can be zero when called through HandleSpecialCases, but also 1 (in which + // case the array is already sorted). Also ensures num_lanes - 1 != 0. + if (HWY_UNLIKELY(num_keys <= 1)) return; + + const size_t ceil_log2 = + 32 - Num0BitsAboveMS1Bit_Nonzero32(static_cast(num_keys - 1)); + + // Checking kMaxKeysPerVector avoids generating unreachable codepaths. + constexpr size_t kMaxKeysPerVector = MaxLanes(d) / kLPK; + + using FuncPtr = decltype(&Sort2To2); + const FuncPtr funcs[9] = { + /* <= 1 */ nullptr, // We ensured num_keys > 1. + /* <= 2 */ &Sort2To2, + /* <= 4 */ &Sort3To4, + /* <= 8 */ &Sort8Rows<1, Traits, T>, // 1 key per row + /* <= 16 */ kMaxKeysPerVector >= 2 ? &Sort8Rows<2, Traits, T> : nullptr, + /* <= 32 */ kMaxKeysPerVector >= 4 ? &Sort8Rows<4, Traits, T> : nullptr, + /* <= 64 */ kMaxKeysPerVector >= 4 ? &Sort16Rows<4, Traits, T> : nullptr, + /* <= 128 */ kMaxKeysPerVector >= 8 ? &Sort16Rows<8, Traits, T> : nullptr, +#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + /* <= 256 */ kMaxKeysPerVector >= 16 ? &Sort16Rows<16, Traits, T> + : nullptr, +#endif + }; + funcs[ceil_log2](st, keys, num_lanes, buf); +} + +// ------------------------------ Partition + +// Partitions O(1) of the *rightmost* keys, at least `N`, until a multiple of +// kUnroll*N remains, or all keys if there are too few for that. +// +// Returns how many remain to partition at the *start* of `keys`, sets `bufL` to +// the number of keys for the left partition written to `buf`, and `writeR` to +// the start of the finished right partition at the end of `keys`. +template +HWY_INLINE size_t PartitionRightmost(D d, Traits st, T* const keys, + const size_t num, const Vec pivot, + size_t& bufL, size_t& writeR, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + HWY_DASSERT(num > 2 * N); // BaseCase handles smaller arrays + + constexpr size_t kUnroll = Constants::kPartitionUnroll; + size_t num_here; // how many to process here + size_t num_main; // how many for main Partition loop (return value) + { + // The main Partition loop increments by kUnroll * N, so at least handle + // the remainders here. + const size_t remainder = num & (kUnroll * N - 1); + // Ensure we handle at least one vector to prevent overruns (see below), but + // still leave a multiple of kUnroll * N. + const size_t min = remainder + (remainder < N ? kUnroll * N : 0); + // Do not exceed the input size. + num_here = HWY_MIN(min, num); + num_main = num - num_here; + // Before the main Partition loop we load two blocks; if not enough left for + // that, handle everything here. + if (num_main < 2 * kUnroll * N) { + num_here = num; + num_main = 0; + } + } + + // Note that `StoreLeftRight` uses `CompressBlendedStore`, which may load and + // store a whole vector starting at `writeR`, and thus overrun `keys`. To + // prevent this, we partition at least `N` of the rightmost `keys` so that + // `StoreLeftRight` will be able to safely blend into them. + HWY_DASSERT(num_here >= N); + + // We cannot use `CompressBlendedStore` for the same reason, so we instead + // write the right-of-partition keys into a buffer in ascending order. 
+ // `min` may be up to (kUnroll + 1) * N, hence `num_here` could be as much as + // (3 * kUnroll + 1) * N, and they might all fall on one side of the pivot. + const size_t max_buf = (3 * kUnroll + 1) * N; + HWY_DASSERT(num_here <= max_buf); + + const T* pReadR = keys + num; // pre-decremented by N + + bufL = 0; + size_t bufR = max_buf; // starting position, not the actual count. + + size_t i = 0; + // For whole vectors, we can LoadU. + for (; i <= num_here - N; i += N) { + pReadR -= N; + HWY_DASSERT(pReadR >= keys); + const Vec v = LoadU(d, pReadR); + + const Mask comp = st.Compare(d, pivot, v); + const size_t numL = CompressStore(v, Not(comp), d, buf + bufL); + bufL += numL; + (void)CompressStore(v, comp, d, buf + bufR); + bufR += (N - numL); + } + + // Last iteration: avoid reading past the end. + const size_t remaining = num_here - i; + if (HWY_LIKELY(remaining != 0)) { + const Mask mask = FirstN(d, remaining); + pReadR -= remaining; + HWY_DASSERT(pReadR >= keys); + const Vec v = LoadN(d, pReadR, remaining); + + const Mask comp = st.Compare(d, pivot, v); + const size_t numL = CompressStore(v, AndNot(comp, mask), d, buf + bufL); + bufL += numL; + (void)CompressStore(v, comp, d, buf + bufR); + bufR += (remaining - numL); + } + + const size_t numWrittenR = bufR - max_buf; + // MSan seems not to understand CompressStore. + detail::MaybeUnpoison(buf, bufL); + detail::MaybeUnpoison(buf + max_buf, numWrittenR); + + // Overwrite already-read end of keys with bufR. + writeR = num - numWrittenR; + hwy::CopyBytes(buf + max_buf, keys + writeR, numWrittenR * sizeof(T)); + // Ensure we finished reading/writing all we wanted + HWY_DASSERT(pReadR == keys + num_main); + HWY_DASSERT(bufL + numWrittenR == num_here); + return num_main; +} + +// Note: we could track the OrXor of v and pivot to see if the entire left +// partition is equal, but that happens rarely and thus is a net loss. +template +HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec v, + const Vec pivot, T* HWY_RESTRICT keys, + size_t& writeL, size_t& remaining) { + const size_t N = Lanes(d); + + const Mask comp = st.Compare(d, pivot, v); + + // Otherwise StoreU/CompressStore overwrites right keys. + HWY_DASSERT(remaining >= 2 * N); + + remaining -= N; + if (hwy::HWY_NAMESPACE::CompressIsPartition::value || + (HWY_MAX_BYTES == 16 && st.Is128())) { + // Non-native Compress (e.g. AVX2): we are able to partition a vector using + // a single Compress+two StoreU instead of two Compress[Blended]Store. The + // latter are more expensive. Because we store entire vectors, the contents + // between the updated writeL and writeR are ignored and will be overwritten + // by subsequent calls. This works because writeL and writeR are at least + // two vectors apart. + const Vec lr = st.CompressKeys(v, comp); + const size_t num_left = N - CountTrue(d, comp); + StoreU(lr, d, keys + writeL); + // Now write the right-side elements (if any), such that the previous writeR + // is one past the end of the newly written right elements, then advance. + StoreU(lr, d, keys + remaining + writeL); + writeL += num_left; + } else { + // Native Compress[Store] (e.g. AVX3), which only keep the left or right + // side, not both, hence we require two calls. 
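+    // As in the branch above, exactly N lanes leave the unpartitioned middle:
+    // num_left keys are appended at writeL and the other N - num_left keys
+    // fill the lanes just below the previous writeR, so writeL + remaining
+    // continues to equal the implicit writeR.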
+ const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL); + writeL += num_left; + + (void)CompressBlendedStore(v, comp, d, keys + remaining + writeL); + } +} + +template +HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec v0, + const Vec v1, const Vec v2, + const Vec v3, const Vec pivot, + T* HWY_RESTRICT keys, size_t& writeL, + size_t& remaining) { + StoreLeftRight(d, st, v0, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v1, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v2, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v3, pivot, keys, writeL, remaining); +} + +// For the last two vectors, we cannot use StoreLeftRight because it might +// overwrite prior right-side keys. Instead write R and append L into `buf`. +template +HWY_INLINE void StoreRightAndBuf(D d, Traits st, const Vec v, + const Vec pivot, T* HWY_RESTRICT keys, + size_t& writeR, T* HWY_RESTRICT buf, + size_t& bufL) { + const size_t N = Lanes(d); + const Mask comp = st.Compare(d, pivot, v); + const size_t numL = CompressStore(v, Not(comp), d, buf + bufL); + const size_t numR = N - numL; + bufL += numL; + writeR -= numR; + StoreN(Compress(v, comp), d, keys + writeR, numR); +} + +// Moves "<= pivot" keys to the front, and others to the back. pivot is +// broadcasted. Returns the index of the first key in the right partition. +// +// Time-critical, but aligned loads do not seem to be worthwhile because we +// are not bottlenecked by load ports. +template +HWY_INLINE size_t Partition(D d, Traits st, T* const keys, const size_t num, + const Vec pivot, T* HWY_RESTRICT buf) { + using V = decltype(Zero(d)); + const size_t N = Lanes(d); + + size_t bufL, writeR; + const size_t num_main = + PartitionRightmost(d, st, keys, num, pivot, bufL, writeR, buf); + HWY_DASSERT(num_main <= num && writeR <= num); + HWY_DASSERT(bufL <= Constants::PartitionBufNum(N)); + HWY_DASSERT(num_main + bufL == writeR); + + if (VQSORT_PRINT >= 3) { + fprintf(stderr, " num_main %zu bufL %zu writeR %zu\n", num_main, bufL, + writeR); + } + + constexpr size_t kUnroll = Constants::kPartitionUnroll; + + // Partition splits the vector into 3 sections, left to right: Elements + // smaller or equal to the pivot, unpartitioned elements and elements larger + // than the pivot. To write elements unconditionally on the loop body without + // overwriting existing data, we maintain two regions of the loop where all + // elements have been copied elsewhere (e.g. vector registers.). I call these + // bufferL and bufferR, for left and right respectively. + // + // These regions are tracked by the indices (writeL, writeR, left, right) as + // presented in the diagram below. + // + // writeL writeR + // \/ \/ + // | <= pivot | bufferL | unpartitioned | bufferR | > pivot | + // \/ \/ \/ + // readL readR num + // + // In the main loop body below we choose a side, load some elements out of the + // vector and move either `readL` or `readR`. Next we call into StoreLeftRight + // to partition the data, and the partitioned elements will be written either + // to writeR or writeL and the corresponding index will be moved accordingly. + // + // Note that writeR is not explicitly tracked as an optimization for platforms + // with conditional operations. Instead we track writeL and the number of + // not yet written elements (`remaining`). 
From the diagram above we can see + // that: + // writeR - writeL = remaining => writeR = remaining + writeL + // + // Tracking `remaining` is advantageous because each iteration reduces the + // number of unpartitioned elements by a fixed amount, so we can compute + // `remaining` without data dependencies. + size_t writeL = 0; + size_t remaining = writeR - writeL; + + const T* readL = keys; + const T* readR = keys + num_main; + // Cannot load if there were fewer than 2 * kUnroll * N. + if (HWY_LIKELY(num_main != 0)) { + HWY_DASSERT(num_main >= 2 * kUnroll * N); + HWY_DASSERT((num_main & (kUnroll * N - 1)) == 0); + + // Make space for writing in-place by reading from readL/readR. + const V vL0 = LoadU(d, readL + 0 * N); + const V vL1 = LoadU(d, readL + 1 * N); + const V vL2 = LoadU(d, readL + 2 * N); + const V vL3 = LoadU(d, readL + 3 * N); + readL += kUnroll * N; + readR -= kUnroll * N; + const V vR0 = LoadU(d, readR + 0 * N); + const V vR1 = LoadU(d, readR + 1 * N); + const V vR2 = LoadU(d, readR + 2 * N); + const V vR3 = LoadU(d, readR + 3 * N); + + // readL/readR changed above, so check again before the loop. + while (readL != readR) { + V v0, v1, v2, v3; + + // Data-dependent but branching is faster than forcing branch-free. + const size_t capacityL = + static_cast((readL - keys) - static_cast(writeL)); + HWY_DASSERT(capacityL <= num_main); // >= 0 + // Load data from the end of the vector with less data (front or back). + // The next paragraphs explain how this works. + // + // let block_size = (kUnroll * N) + // On the loop prelude we load block_size elements from the front of the + // vector and an additional block_size elements from the back. On each + // iteration k elements are written to the front of the vector and + // (block_size - k) to the back. + // + // This creates a loop invariant where the capacity on the front + // (capacityL) and on the back (capacityR) always add to 2 * block_size. + // In other words: + // capacityL + capacityR = 2 * block_size + // capacityR = 2 * block_size - capacityL + // + // This means that: + // capacityL > capacityR <=> + // capacityL > 2 * block_size - capacityL <=> + // 2 * capacityL > 2 * block_size <=> + // capacityL > block_size + if (capacityL > kUnroll * N) { // equivalent to capacityL > capacityR. + readR -= kUnroll * N; + v0 = LoadU(d, readR + 0 * N); + v1 = LoadU(d, readR + 1 * N); + v2 = LoadU(d, readR + 2 * N); + v3 = LoadU(d, readR + 3 * N); + hwy::Prefetch(readR - 3 * kUnroll * N); + } else { + v0 = LoadU(d, readL + 0 * N); + v1 = LoadU(d, readL + 1 * N); + v2 = LoadU(d, readL + 2 * N); + v3 = LoadU(d, readL + 3 * N); + readL += kUnroll * N; + hwy::Prefetch(readL + 3 * kUnroll * N); + } + + StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, remaining); + } + + // Now finish writing the saved vectors to the middle. + StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, remaining); + + StoreLeftRight(d, st, vR0, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, vR1, pivot, keys, writeL, remaining); + + // Switch back to updating writeR for clarity. The middle is missing vR2/3 + // and what is in the buffer. + HWY_DASSERT(remaining == bufL + 2 * N); + writeR = writeL + remaining; + // Switch to StoreRightAndBuf for the last two vectors because + // StoreLeftRight may overwrite prior keys. 
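+    // (At this point the gap [writeL, writeR) holds exactly bufL + 2 * N
+    // lanes: the two pending vectors vR2/vR3 handled below, plus the bufL
+    // left-side keys that PartitionRightmost stored to `buf` (and any left
+    // keys of vR2/vR3 appended below), all of which are copied back after
+    // this block.)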
+ StoreRightAndBuf(d, st, vR2, pivot, keys, writeR, buf, bufL); + StoreRightAndBuf(d, st, vR3, pivot, keys, writeR, buf, bufL); + HWY_DASSERT(writeR <= num); // >= 0 + HWY_DASSERT(bufL <= Constants::PartitionBufNum(N)); + } + + // We have partitioned [0, num) into [0, writeL) and [writeR, num). + // Now insert left keys from `buf` to empty space starting at writeL. + HWY_DASSERT(writeL + bufL == writeR); + CopyBytes(buf, keys + writeL, bufL * sizeof(T)); + + return writeL + bufL; +} + +// Returns true and partitions if [keys, keys + num) contains only {valueL, +// valueR}. Otherwise, sets third to the first differing value; keys may have +// been reordered and a regular Partition is still necessary. +// Called from two locations, hence NOINLINE. +template +HWY_NOINLINE bool MaybePartitionTwoValue(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, const Vec valueL, + const Vec valueR, Vec& third, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + // No guarantee that num >= N because this is called for subarrays! + + size_t i = 0; + size_t writeL = 0; + + // As long as all lanes are equal to L or R, we can overwrite with valueL. + // This is faster than first counting, then backtracking to fill L and R. + if (num >= N) { + for (; i <= num - N; i += N) { + const Vec v = LoadU(d, keys + i); + // It is not clear how to apply OrXor here - that can check if *both* + // comparisons are true, but here we want *either*. Comparing the unsigned + // min of differences to zero works, but is expensive for u64 prior to + // AVX3. + const Mask eqL = st.EqualKeys(d, v, valueL); + const Mask eqR = st.EqualKeys(d, v, valueR); + // At least one other value present; will require a regular partition. + // On AVX-512, Or + AllTrue are folded into a single kortest if we are + // careful with the FindKnownFirstTrue argument, see below. + if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) { + // If we repeat Or(eqL, eqR) here, the compiler will hoist it into the + // loop, which is a pessimization because this if-true branch is cold. + // We can defeat this via Not(Xor), which is equivalent because eqL and + // eqR cannot be true at the same time. Can we elide the additional Not? + // FindFirstFalse instructions are generally unavailable, but we can + // fuse Not and Xor/Or into one ExclusiveNeither. + const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR)); + third = st.SetKey(d, keys + i + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at vec %zu; writeL %zu\n", i, + writeL); + } + // 'Undo' what we did by filling the remainder of what we read with R. + if (i >= N) { + for (; writeL <= i - N; writeL += N) { + StoreU(valueR, d, keys + writeL); + } + } + StoreN(valueR, d, keys + writeL, i - writeL); + return false; + } + StoreU(valueL, d, keys + writeL); + writeL += CountTrue(d, eqL); + } + } + + // Final vector, masked comparison (no effect if i == num) + const size_t remaining = num - i; + SafeCopyN(remaining, d, keys + i, buf); + const Vec v = Load(d, buf); + const Mask valid = FirstN(d, remaining); + const Mask eqL = And(st.EqualKeys(d, v, valueL), valid); + const Mask eqR = st.EqualKeys(d, v, valueR); + // Invalid lanes are considered equal. + const Mask eq = Or(Or(eqL, eqR), Not(valid)); + // At least one other value present; will require a regular partition. 
+ if (HWY_UNLIKELY(!AllTrue(d, eq))) { + const size_t lane = FindKnownFirstTrue(d, Not(eq)); + third = st.SetKey(d, keys + i + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at partial vec %zu; writeL %zu\n", i, + writeL); + } + // 'Undo' what we did by filling the remainder of what we read with R. + if (i >= N) { + for (; writeL <= i - N; writeL += N) { + StoreU(valueR, d, keys + writeL); + } + } + StoreN(valueR, d, keys + writeL, i - writeL); + return false; + } + StoreN(valueL, d, keys + writeL, remaining); + writeL += CountTrue(d, eqL); + + // Fill right side + i = writeL; + if (num >= N) { + for (; i <= num - N; i += N) { + StoreU(valueR, d, keys + i); + } + } + StoreN(valueR, d, keys + i, num - i); + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Successful MaybePartitionTwoValue\n"); + } + return true; +} + +// Same as above, except that the pivot equals valueR, so scan right to left. +template +HWY_INLINE bool MaybePartitionTwoValueR(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, const Vec valueL, + const Vec valueR, Vec& third, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + + HWY_DASSERT(num >= N); + size_t pos = num - N; // current read/write position + size_t countR = 0; // number of valueR found + + // For whole vectors, in descending address order: as long as all lanes are + // equal to L or R, overwrite with valueR. This is faster than counting, then + // filling both L and R. Loop terminates after unsigned wraparound. + for (; pos < num; pos -= N) { + const Vec v = LoadU(d, keys + pos); + // It is not clear how to apply OrXor here - that can check if *both* + // comparisons are true, but here we want *either*. Comparing the unsigned + // min of differences to zero works, but is expensive for u64 prior to AVX3. + const Mask eqL = st.EqualKeys(d, v, valueL); + const Mask eqR = st.EqualKeys(d, v, valueR); + // If there is a third value, stop and undo what we've done. On AVX-512, + // Or + AllTrue are folded into a single kortest, but only if we are + // careful with the FindKnownFirstTrue argument - see prior comment on that. + if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) { + const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR)); + third = st.SetKey(d, keys + pos + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at vec %zu; countR %zu\n", pos, + countR); + MaybePrintVector(d, "third", third, 0, st.LanesPerKey()); + } + pos += N; // rewind: we haven't yet committed changes in this iteration. + // We have filled [pos, num) with R, but only countR of them should have + // been written. Rewrite [pos, num - countR) to L. + HWY_DASSERT(countR <= num - pos); + const size_t endL = num - countR; + if (endL >= N) { + for (; pos <= endL - N; pos += N) { + StoreU(valueL, d, keys + pos); + } + } + StoreN(valueL, d, keys + pos, endL - pos); + return false; + } + StoreU(valueR, d, keys + pos); + countR += CountTrue(d, eqR); + } + + // Final partial (or empty) vector, masked comparison. + const size_t remaining = pos + N; + HWY_DASSERT(remaining <= N); + const Vec v = LoadU(d, keys); // Safe because num >= N. + const Mask valid = FirstN(d, remaining); + const Mask eqL = st.EqualKeys(d, v, valueL); + const Mask eqR = And(st.EqualKeys(d, v, valueR), valid); + // Invalid lanes are considered equal. + const Mask eq = Or(Or(eqL, eqR), Not(valid)); + // At least one other value present; will require a regular partition. 
+ if (HWY_UNLIKELY(!AllTrue(d, eq))) { + const size_t lane = FindKnownFirstTrue(d, Not(eq)); + third = st.SetKey(d, keys + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at partial vec %zu; writeR %zu\n", pos, + countR); + MaybePrintVector(d, "third", third, 0, st.LanesPerKey()); + } + pos += N; // rewind: we haven't yet committed changes in this iteration. + // We have filled [pos, num) with R, but only countR of them should have + // been written. Rewrite [pos, num - countR) to L. + HWY_DASSERT(countR <= num - pos); + const size_t endL = num - countR; + if (endL >= N) { + for (; pos <= endL - N; pos += N) { + StoreU(valueL, d, keys + pos); + } + } + StoreN(valueL, d, keys + pos, endL - pos); + return false; + } + const size_t lastR = CountTrue(d, eqR); + countR += lastR; + + // First finish writing valueR - [0, N) lanes were not yet written. + StoreU(valueR, d, keys); // Safe because num >= N. + + // Fill left side (ascending order for clarity) + const size_t endL = num - countR; + size_t i = 0; + if (endL >= N) { + for (; i <= endL - N; i += N) { + StoreU(valueL, d, keys + i); + } + } + Store(valueL, d, buf); + SafeCopyN(endL - i, d, buf, keys + i); // avoids ASan overrun + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, + "MaybePartitionTwoValueR countR %zu pos %zu i %zu endL %zu\n", + countR, pos, i, endL); + } + + return true; +} + +// `idx_second` is `first_mismatch` from `AllEqual` and thus the index of the +// second key. This is the first path into `MaybePartitionTwoValue`, called +// when all samples are equal. Returns false if there are at least a third +// value and sets `third`. Otherwise, partitions the array and returns true. +template +HWY_INLINE bool PartitionIfTwoKeys(D d, Traits st, const Vec pivot, + T* HWY_RESTRICT keys, size_t num, + const size_t idx_second, const Vec second, + Vec& third, T* HWY_RESTRICT buf) { + // True if second comes before pivot. + const bool is_pivotR = AllFalse(d, st.Compare(d, pivot, second)); + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "Samples all equal, diff at %zu, isPivotR %d\n", idx_second, + is_pivotR); + } + HWY_DASSERT(AllFalse(d, st.EqualKeys(d, second, pivot))); + + // If pivot is R, we scan backwards over the entire array. Otherwise, + // we already scanned up to idx_second and can leave those in place. + return is_pivotR ? MaybePartitionTwoValueR(d, st, keys, num, second, pivot, + third, buf) + : MaybePartitionTwoValue(d, st, keys + idx_second, + num - idx_second, pivot, second, + third, buf); +} + +// Second path into `MaybePartitionTwoValue`, called when not all samples are +// equal. `samples` is sorted. +template +HWY_INLINE bool PartitionIfTwoSamples(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = Constants::SampleLanes(); + constexpr size_t N1 = st.LanesPerKey(); + const Vec valueL = st.SetKey(d, samples); + const Vec valueR = st.SetKey(d, samples + kSampleLanes - N1); + HWY_DASSERT(AllTrue(d, st.Compare(d, valueL, valueR))); + HWY_DASSERT(AllFalse(d, st.EqualKeys(d, valueL, valueR))); + const Vec prev = st.PrevValue(d, valueR); + // If the sample has more than two values, then the keys have at least that + // many, and thus this special case is inapplicable. + if (HWY_UNLIKELY(!AllTrue(d, st.EqualKeys(d, valueL, prev)))) { + return false; + } + + // Must not overwrite samples because if this returns false, caller wants to + // read the original samples again. 
+ T* HWY_RESTRICT buf = samples + kSampleLanes; + Vec third; // unused + return MaybePartitionTwoValue(d, st, keys, num, valueL, valueR, third, buf); +} + +// ------------------------------ Pivot sampling + +template +HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) { + const DFromV d; + // Slightly faster for 128-bit, apparently because not serially dependent. + if (st.Is128()) { + // Median = XOR-sum 'minus' the first and last. Calling First twice is + // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR. + const V sum = Xor(Xor(v0, v1), v2); + const V first = st.First(d, st.First(d, v0, v1), v2); + const V last = st.Last(d, st.Last(d, v0, v1), v2); + return Xor(Xor(sum, first), last); + } + st.Sort2(d, v0, v2); + v1 = st.Last(d, v0, v1); + v1 = st.First(d, v1, v2); + return v1; +} + +// Returns slightly biased random index of a chunk in [0, num_chunks). +// See https://www.pcg-random.org/posts/bounded-rands.html. +HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) { + const uint64_t chunk_index = (static_cast(bits) * num_chunks) >> 32; + HWY_DASSERT(chunk_index < num_chunks); + return static_cast(chunk_index); +} + +// Writes samples from `keys[0, num)` into `buf`. +template +HWY_INLINE void DrawSamples(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT buf, uint64_t* HWY_RESTRICT state) { + using V = decltype(Zero(d)); + const size_t N = Lanes(d); + + // Power of two + constexpr size_t kLanesPerChunk = Constants::LanesPerChunk(sizeof(T)); + + // Align start of keys to chunks. We have at least 2 chunks (x 64 bytes) + // because the base case handles anything up to 8 vectors (x 16 bytes). + HWY_DASSERT(num >= Constants::SampleLanes()); + const size_t misalign = + (reinterpret_cast(keys) / sizeof(T)) & (kLanesPerChunk - 1); + if (misalign != 0) { + const size_t consume = kLanesPerChunk - misalign; + keys += consume; + num -= consume; + } + + // Generate enough random bits for 6 uint32 + uint32_t bits[6]; + for (size_t i = 0; i < 6; i += 2) { + const uint64_t bits64 = RandomBits(state); + CopyBytes<8>(&bits64, bits + i); + } + + const size_t num_chunks64 = num / kLanesPerChunk; + // Clamp to uint32 for RandomChunkIndex + const uint32_t num_chunks = + static_cast(HWY_MIN(num_chunks64, 0xFFFFFFFFull)); + + const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) * kLanesPerChunk; + const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) * kLanesPerChunk; + const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) * kLanesPerChunk; + const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) * kLanesPerChunk; + const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) * kLanesPerChunk; + const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) * kLanesPerChunk; + for (size_t i = 0; i < kLanesPerChunk; i += N) { + const V v0 = Load(d, keys + offset0 + i); + const V v1 = Load(d, keys + offset1 + i); + const V v2 = Load(d, keys + offset2 + i); + const V medians0 = MedianOf3(st, v0, v1, v2); + Store(medians0, d, buf + i); + + const V v3 = Load(d, keys + offset3 + i); + const V v4 = Load(d, keys + offset4 + i); + const V v5 = Load(d, keys + offset5 + i); + const V medians1 = MedianOf3(st, v3, v4, v5); + Store(medians1, d, buf + i + kLanesPerChunk); + } +} + +template +V OrXor(const V o, const V x1, const V x2) { + return Or(o, Xor(x1, x2)); // TERNLOG on AVX3 +} + +// For detecting inputs where (almost) all keys are equal. 
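+// Scalar sketch of the OrXor reduction below (illustrative only): the
+// accumulator can only remain zero if every sample is bit-identical to
+// `first`, so a single test after the loop suffices.
+//
+//   uint64_t diff = 0;
+//   for (size_t i = 0; i < n; ++i) diff |= samples[i] ^ first;
+//   return diff == 0;  // all samples equal?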
+template +HWY_INLINE bool UnsortedSampleEqual(D d, Traits st, + const TFromD* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = Constants::SampleLanes>(); + const size_t N = Lanes(d); + // Both are powers of two, so there will be no remainders. + HWY_DASSERT(N < kSampleLanes); + using V = Vec; + + const V first = st.SetKey(d, samples); + + if (!hwy::IsFloat>()) { + // OR of XOR-difference may be faster than comparison. + V diff = Zero(d); + for (size_t i = 0; i < kSampleLanes; i += N) { + const V v = Load(d, samples + i); + diff = OrXor(diff, first, v); + } + return st.NoKeyDifference(d, diff); + } else { + // Disable the OrXor optimization for floats because OrXor will not treat + // subnormals the same as actual comparisons, leading to logic errors for + // 2-value cases. + for (size_t i = 0; i < kSampleLanes; i += N) { + const V v = Load(d, samples + i); + if (!AllTrue(d, st.EqualKeys(d, v, first))) { + return false; + } + } + return true; + } +} + +template +HWY_INLINE void SortSamples(D d, Traits st, T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + constexpr size_t kSampleLanes = Constants::SampleLanes(); + // Network must be large enough to sort two chunks. + HWY_DASSERT(Constants::BaseCaseNumLanes(N) >= kSampleLanes); + + BaseCase(d, st, buf, kSampleLanes, buf + kSampleLanes); + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Samples:\n"); + for (size_t i = 0; i < kSampleLanes; i += N) { + MaybePrintVector(d, "", Load(d, buf + i), 0, N); + } + } +} + +// ------------------------------ Pivot selection + +enum class PivotResult { + kDone, // stop without partitioning (all equal, or two-value partition) + kNormal, // partition and recurse left and right + kIsFirst, // partition but skip left recursion + kWasLast, // partition but skip right recursion +}; + +HWY_INLINE const char* PivotResultString(PivotResult result) { + switch (result) { + case PivotResult::kDone: + return "done"; + case PivotResult::kNormal: + return "normal"; + case PivotResult::kIsFirst: + return "first"; + case PivotResult::kWasLast: + return "last"; + } + return "unknown"; +} + +// (Could vectorize, but only 0.2% of total time) +template +HWY_INLINE size_t PivotRank(Traits st, const T* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = Constants::SampleLanes(); + constexpr size_t N1 = st.LanesPerKey(); + + constexpr size_t kRankMid = kSampleLanes / 2; + static_assert(kRankMid % N1 == 0, "Mid is not an aligned key"); + + // Find the previous value not equal to the median. + size_t rank_prev = kRankMid - N1; + for (; st.Equal1(samples + rank_prev, samples + kRankMid); rank_prev -= N1) { + // All previous samples are equal to the median. + if (rank_prev == 0) return 0; + } + + size_t rank_next = rank_prev + N1; + for (; st.Equal1(samples + rank_next, samples + kRankMid); rank_next += N1) { + // The median is also the largest sample. If it is also the largest key, + // we'd end up with an empty right partition, so choose the previous key. + if (rank_next == kSampleLanes - N1) return rank_prev; + } + + // If we choose the median as pivot, the ratio of keys ending in the left + // partition will likely be rank_next/kSampleLanes (if the sample is + // representative). This is because equal-to-pivot values also land in the + // left - it's infeasible to do an in-place vectorized 3-way partition. + // Check whether prev would lead to a more balanced partition. 
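+  // Worked example (hypothetical sample): with N1 = 1, kSampleLanes = 32 and
+  // thus kRankMid = 16, suppose samples[10..20] all equal the median. The
+  // loops above yield rank_prev = 9 and rank_next = 21, so excess_if_median =
+  // 5 and excess_if_prev = 7, and the median (rank 16) is kept as the pivot.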
+ const size_t excess_if_median = rank_next - kRankMid; + const size_t excess_if_prev = kRankMid - rank_prev; + return excess_if_median < excess_if_prev ? kRankMid : rank_prev; +} + +// Returns pivot chosen from `samples`. It will never be the largest key +// (thus the right partition will never be empty). +template +HWY_INLINE Vec ChoosePivotByRank(D d, Traits st, + const T* HWY_RESTRICT samples) { + const size_t pivot_rank = PivotRank(st, samples); + const Vec pivot = st.SetKey(d, samples + pivot_rank); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, " Pivot rank %3zu\n", pivot_rank); + HWY_ALIGN T pivot_lanes[MaxLanes(d)]; + Store(pivot, d, pivot_lanes); + using Key = typename Traits::KeyType; + Key key; + CopyBytes(pivot_lanes, &key); + PrintValue(key); + } + // Verify pivot is not equal to the last sample. + constexpr size_t kSampleLanes = Constants::SampleLanes(); + constexpr size_t N1 = st.LanesPerKey(); + const Vec last = st.SetKey(d, samples + kSampleLanes - N1); + const bool all_neq = AllTrue(d, st.NotEqualKeys(d, pivot, last)); + (void)all_neq; + HWY_DASSERT(all_neq); + return pivot; +} + +// Returns true if all keys equal `pivot`, otherwise returns false and sets +// `*first_mismatch' to the index of the first differing key. +template +HWY_INLINE bool AllEqual(D d, Traits st, const Vec pivot, + const T* HWY_RESTRICT keys, size_t num, + size_t* HWY_RESTRICT first_mismatch) { + const size_t N = Lanes(d); + // Ensures we can use overlapping loads for the tail; see HandleSpecialCases. + HWY_DASSERT(num >= N); + const Vec zero = Zero(d); + + // Vector-align keys + i. + const size_t misalign = + (reinterpret_cast(keys) / sizeof(T)) & (N - 1); + HWY_DASSERT(misalign % st.LanesPerKey() == 0); + const size_t consume = N - misalign; + { + const Vec v = LoadU(d, keys); + // Only check masked lanes; consider others to be equal. + const Mask diff = And(FirstN(d, consume), st.NotEqualKeys(d, v, pivot)); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = lane; + return false; + } + } + size_t i = consume; + HWY_DASSERT(((reinterpret_cast(keys + i) / sizeof(T)) & (N - 1)) == + 0); + + // Disable the OrXor optimization for floats because OrXor will not treat + // subnormals the same as actual comparisons, leading to logic errors for + // 2-value cases. + if (!hwy::IsFloat()) { + // Sticky bits registering any difference between `keys` and the first key. + // We use vector XOR because it may be cheaper than comparisons, especially + // for 128-bit. 2x unrolled for more ILP. + Vec diff0 = zero; + Vec diff1 = zero; + + // We want to stop once a difference has been found, but without slowing + // down the loop by comparing during each iteration. The compromise is to + // compare after a 'group', which consists of kLoops times two vectors. + constexpr size_t kLoops = 8; + const size_t lanes_per_group = kLoops * 2 * N; + + if (num >= lanes_per_group) { + for (; i <= num - lanes_per_group; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec v0 = Load(d, keys + i + loop * 2 * N); + const Vec v1 = Load(d, keys + i + loop * 2 * N + N); + diff0 = OrXor(diff0, v0, pivot); + diff1 = OrXor(diff1, v1, pivot); + } + + // If there was a difference in the entire group: + if (HWY_UNLIKELY(!st.NoKeyDifference(d, Or(diff0, diff1)))) { + // .. then loop until the first one, with termination guarantee. 
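+          // (The mismatch is known to lie within the group just scanned,
+          // i.e. in [i, i + lanes_per_group), because earlier groups passed
+          // this check, so the open-ended loop below stops after at most
+          // kLoops * 2 iterations.)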
+ for (;; i += N) { + const Vec v = Load(d, keys + i); + const Mask diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + } + } + } + } + } // !hwy::IsFloat() + + // Whole vectors, no unrolling, compare directly + for (; i <= num - N; i += N) { + const Vec v = Load(d, keys + i); + const Mask diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + } + // Always re-check the last (unaligned) vector to reduce branching. + i = num - N; + const Vec v = LoadU(d, keys + i); + const Mask diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "All keys equal\n"); + } + return true; // all equal +} + +// Called from 'two locations', but only one is active (IsKV is constexpr). +template +HWY_INLINE bool ExistsAnyBefore(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, const Vec pivot) { + const size_t N = Lanes(d); + HWY_DASSERT(num >= N); // See HandleSpecialCases + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Scanning for before\n"); + } + + size_t i = 0; + + constexpr size_t kLoops = 16; + const size_t lanes_per_group = kLoops * N; + + Vec first = pivot; + + // Whole group, unrolled + if (num >= lanes_per_group) { + for (; i <= num - lanes_per_group; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec curr = LoadU(d, keys + i + loop * N); + first = st.First(d, first, curr); + } + + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, first, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at end of group %zu\n", + i + lanes_per_group); + } + return true; + } + } + } + // Whole vectors, no unrolling + for (; i <= num - N; i += N) { + const Vec curr = LoadU(d, keys + i); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at %zu\n", i); + } + return true; + } + } + // If there are remainders, re-check the last whole vector. + if (HWY_LIKELY(i != num)) { + const Vec curr = LoadU(d, keys + num - N); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at last %zu\n", num - N); + } + return true; + } + } + + return false; // pivot is the first +} + +// Called from 'two locations', but only one is active (IsKV is constexpr). 
+template +HWY_INLINE bool ExistsAnyAfter(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, const Vec pivot) { + const size_t N = Lanes(d); + HWY_DASSERT(num >= N); // See HandleSpecialCases + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Scanning for after\n"); + } + + size_t i = 0; + + constexpr size_t kLoops = 16; + const size_t lanes_per_group = kLoops * N; + + Vec last = pivot; + + // Whole group, unrolled + if (num >= lanes_per_group) { + for (; i + lanes_per_group <= num; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec curr = LoadU(d, keys + i + loop * N); + last = st.Last(d, last, curr); + } + + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, last)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at end of group %zu\n", + i + lanes_per_group); + } + return true; + } + } + } + // Whole vectors, no unrolling + for (; i <= num - N; i += N) { + const Vec curr = LoadU(d, keys + i); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at %zu\n", i); + } + return true; + } + } + // If there are remainders, re-check the last whole vector. + if (HWY_LIKELY(i != num)) { + const Vec curr = LoadU(d, keys + num - N); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at last %zu\n", num - N); + } + return true; + } + } + + return false; // pivot is the last +} + +// Returns pivot chosen from `keys[0, num)`. It will never be the largest key +// (thus the right partition will never be empty). +template +HWY_INLINE Vec ChoosePivotForEqualSamples(D d, Traits st, + T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT samples, + Vec second, Vec third, + PivotResult& result) { + const Vec pivot = st.SetKey(d, samples); // the single unique sample + + // Early out for mostly-0 arrays, where pivot is often FirstValue. + if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.FirstValue(d))))) { + result = PivotResult::kIsFirst; + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Pivot equals first possible value\n"); + } + return pivot; + } + if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.LastValue(d))))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Pivot equals last possible value\n"); + } + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // If key-value, we didn't run PartitionIfTwo* and thus `third` is unknown and + // cannot be used. + if (st.IsKV()) { + // If true, pivot is either middle or last. + const bool before = !AllFalse(d, st.Compare(d, second, pivot)); + if (HWY_UNLIKELY(before)) { + // Not last, so middle. + if (HWY_UNLIKELY(ExistsAnyAfter(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + + // We didn't find anything after pivot, so it is the last. Because keys + // equal to the pivot go to the left partition, the right partition would + // be empty and Partition will not have changed anything. Instead use the + // previous value in sort order, which is not necessarily an actual key. + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // Otherwise, pivot is first or middle. Rule out it being first: + if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + // It is first: fall through to shared code below. + } else { + // Check if pivot is between two known values. 
If so, it is not the first + // nor the last and we can avoid scanning. + st.Sort2(d, second, third); + HWY_DASSERT(AllTrue(d, st.Compare(d, second, third))); + const bool before = !AllFalse(d, st.Compare(d, second, pivot)); + const bool after = !AllFalse(d, st.Compare(d, pivot, third)); + // Only reached if there are three keys, which means pivot is either first, + // last, or in between. Thus there is another key that comes before or + // after. + HWY_DASSERT(before || after); + if (HWY_UNLIKELY(before)) { + // Neither first nor last. + if (HWY_UNLIKELY(after || ExistsAnyAfter(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + + // We didn't find anything after pivot, so it is the last. Because keys + // equal to the pivot go to the left partition, the right partition would + // be empty and Partition will not have changed anything. Instead use the + // previous value in sort order, which is not necessarily an actual key. + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // Has after, and we found one before: in the middle. + if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + } + + // Pivot is first. We could consider a special partition mode that only + // reads from and writes to the right side, and later fills in the left + // side, which we know is equal to the pivot. However, that leads to more + // cache misses if the array is large, and doesn't save much, hence is a + // net loss. + result = PivotResult::kIsFirst; + return pivot; +} + +// ------------------------------ Quicksort recursion + +enum class RecurseMode { + kSort, // Sort mode. + kSelect, // Select mode. + // The element pointed at by nth is changed to whatever element + // would occur in that position if [first, last) were sorted. All of + // the elements before this new nth element are less than or equal + // to the elements after the new nth element. + kLooseSelect, // Loose select mode. + // The first n elements will contain the n smallest elements in + // unspecified order +}; + +template +HWY_NOINLINE void PrintMinMax(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 2) { + const size_t N = Lanes(d); + if (num < N) return; + + Vec first = st.LastValue(d); + Vec last = st.FirstValue(d); + + size_t i = 0; + for (; i <= num - N; i += N) { + const Vec v = LoadU(d, keys + i); + first = st.First(d, v, first); + last = st.Last(d, v, last); + } + if (HWY_LIKELY(i != num)) { + HWY_DASSERT(num >= N); // See HandleSpecialCases + const Vec v = LoadU(d, keys + num - N); + first = st.First(d, v, first); + last = st.Last(d, v, last); + } + + first = st.FirstOfLanes(d, first, buf); + last = st.LastOfLanes(d, last, buf); + MaybePrintVector(d, "first", first, 0, st.LanesPerKey()); + MaybePrintVector(d, "last", last, 0, st.LanesPerKey()); + } +} + +template +HWY_NOINLINE void Recurse(D d, Traits st, T* HWY_RESTRICT keys, + const size_t num, T* HWY_RESTRICT buf, + uint64_t* HWY_RESTRICT state, + const size_t remaining_levels, const size_t k = 0) { + HWY_DASSERT(num != 0); + + const size_t N = Lanes(d); + constexpr size_t kLPK = st.LanesPerKey(); + if (HWY_UNLIKELY(num <= Constants::BaseCaseNumLanes(N))) { + BaseCase(d, st, keys, num, buf); + return; + } + + // Move after BaseCase so we skip printing for small subarrays. 
+ if (VQSORT_PRINT >= 1) { + fprintf(stderr, "\n\n=== Recurse depth=%zu len=%zu k=%zu\n", + remaining_levels, num, k); + PrintMinMax(d, st, keys, num, buf); + } + + DrawSamples(d, st, keys, num, buf, state); + + Vec pivot; + PivotResult result = PivotResult::kNormal; + if (HWY_UNLIKELY(UnsortedSampleEqual(d, st, buf))) { + pivot = st.SetKey(d, buf); + size_t idx_second = 0; + if (HWY_UNLIKELY(AllEqual(d, st, pivot, keys, num, &idx_second))) { + return; + } + HWY_DASSERT(idx_second % st.LanesPerKey() == 0); + // Must capture the value before PartitionIfTwoKeys may overwrite it. + const Vec second = st.SetKey(d, keys + idx_second); + MaybePrintVector(d, "pivot", pivot, 0, st.LanesPerKey()); + MaybePrintVector(d, "second", second, 0, st.LanesPerKey()); + + Vec third = Zero(d); + // Not supported for key-value types because two 'keys' may be equivalent + // but not interchangeable (their values may differ). + if (HWY_UNLIKELY(!st.IsKV() && + PartitionIfTwoKeys(d, st, pivot, keys, num, idx_second, + second, third, buf))) { + return; // Done, skip recursion because each side has all-equal keys. + } + + // We can no longer start scanning from idx_second because + // PartitionIfTwoKeys may have reordered keys. + pivot = ChoosePivotForEqualSamples(d, st, keys, num, buf, second, third, + result); + // If kNormal, `pivot` is very common but not the first/last. It is + // tempting to do a 3-way partition (to avoid moving the =pivot keys a + // second time), but that is a net loss due to the extra comparisons. + } else { + SortSamples(d, st, buf); + + // Not supported for key-value types because two 'keys' may be equivalent + // but not interchangeable (their values may differ). + if (HWY_UNLIKELY(!st.IsKV() && + PartitionIfTwoSamples(d, st, keys, num, buf))) { + return; + } + + pivot = ChoosePivotByRank(d, st, buf); + } + + // Too many recursions. This is unlikely to happen because we select pivots + // from large (though still O(1)) samples. + if (HWY_UNLIKELY(remaining_levels == 0)) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "HeapSort reached, size=%zu\n", num); + } + HeapSort(st, keys, num); // Slow but N*logN. + return; + } + + const size_t bound = Partition(d, st, keys, num, pivot, buf); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "bound %zu num %zu result %s\n", bound, num, + PivotResultString(result)); + } + // The left partition is not empty because the pivot is usually one of the + // keys. Exception: if kWasLast, we set pivot to PrevValue(pivot), but we + // still have at least one value <= pivot because AllEqual ruled out the case + // of only one unique value. Note that for floating-point, PrevValue can + // return the same value (for -inf inputs), but that would just mean the + // pivot is again one of the keys. + using Order = typename Traits::Order; + (void)Order::IsAscending(); + HWY_DASSERT_M(bound != 0, + (Order::IsAscending() ? "Ascending" : "Descending")); + // ChoosePivot* ensure pivot != last, so the right partition is never empty + // except in the rare case of the pivot matching the last-in-sort-order value, + // which implies we anyway skip the right partition due to kWasLast. 
+ HWY_DASSERT(bound != num || result == PivotResult::kWasLast); + + HWY_IF_CONSTEXPR(mode == RecurseMode::kSelect) { + if (HWY_LIKELY(result != PivotResult::kIsFirst) && k < bound) { + Recurse(d, st, keys, bound, buf, state, + remaining_levels - 1, k); + } else if (HWY_LIKELY(result != PivotResult::kWasLast) && k >= bound) { + Recurse(d, st, keys + bound, num - bound, buf, + state, remaining_levels - 1, k - bound); + } + } + HWY_IF_CONSTEXPR(mode == RecurseMode::kSort) { + if (HWY_LIKELY(result != PivotResult::kIsFirst)) { + Recurse(d, st, keys, bound, buf, state, + remaining_levels - 1); + } + if (HWY_LIKELY(result != PivotResult::kWasLast)) { + Recurse(d, st, keys + bound, num - bound, buf, state, + remaining_levels - 1); + } + } +} + +// Returns true if sorting is finished. +template +HWY_INLINE bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + constexpr size_t kLPK = st.LanesPerKey(); + const size_t base_case_num = Constants::BaseCaseNumLanes(N); + + // Recurse will also check this, but doing so here first avoids setting up + // the random generator state. + if (HWY_UNLIKELY(num <= base_case_num)) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "Special-casing small, %zu lanes\n", num); + } + BaseCase(d, st, keys, num, buf); + return true; + } + + // 128-bit keys require vectors with at least two u64 lanes, which is always + // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the + // hardware vector width is less than 128bit / fraction. + const bool partial_128 = !IsFull(d) && N < 2 && st.Is128(); + // Partition assumes its input is at least two vectors. If vectors are huge, + // base_case_num may actually be smaller. If so, which is only possible on + // RVV, pass a capped or partial d (LMUL < 1). Use HWY_MAX_BYTES instead of + // HWY_LANES to account for the largest possible LMUL. + constexpr bool kPotentiallyHuge = + HWY_MAX_BYTES / sizeof(T) > Constants::kMaxRows * Constants::kMaxCols; + const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num); + if (partial_128 || huge_vec) { + if (VQSORT_PRINT >= 1) { + HWY_WARN("using slow HeapSort: partial %d huge %d\n", partial_128, + huge_vec); + } + HeapSort(st, keys, num); + return true; + } + + // We could also check for already sorted/reverse/equal, but that's probably + // counterproductive if vqsort is used as a base case. + + return false; // not finished sorting +} + +#endif // VQSORT_ENABLED + +template +HWY_INLINE size_t CountAndReplaceNaN(D d, Traits st, T* HWY_RESTRICT keys, + size_t num) { + const size_t N = Lanes(d); + // Will be sorted to the back of the array. + const Vec sentinel = st.LastValue(d); + size_t num_nan = 0; + size_t i = 0; + if (num >= N) { + for (; i <= num - N; i += N) { + const Mask is_nan = IsNaN(LoadU(d, keys + i)); + BlendedStore(sentinel, is_nan, d, keys + i); + num_nan += CountTrue(d, is_nan); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < N); + const Vec v = LoadN(d, keys + i, remaining); + const Mask is_nan = IsNaN(v); + StoreN(IfThenElse(is_nan, sentinel, v), d, keys + i, remaining); + num_nan += CountTrue(d, is_nan); + return num_nan; +} + +// IsNaN is not implemented for non-float, so skip it. +template +HWY_INLINE size_t CountAndReplaceNaN(D, Traits, T* HWY_RESTRICT, size_t) { + return 0; +} + +} // namespace detail + +// Old interface with user-specified buffer, retained for compatibility. Called +// by the newer overload below. 
`buf` must be vector-aligned and hold at least +// SortConstants::BufBytes(HWY_MAX_BYTES, st.LanesPerKey()). +template +void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num, + T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "=============== Sort %s num=%zu, vec bytes=%zu\n", + st.KeyString(), num, sizeof(T) * Lanes(d)); + } + +#if HWY_MAX_BYTES > 64 + // sorting_networks-inl and traits assume no more than 512 bit vectors. + if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) { + return Sort(CappedTag(), st, keys, num, buf); + } +#endif // HWY_MAX_BYTES > 64 + + const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num); + +#if VQSORT_ENABLED || HWY_IDE + if (!detail::HandleSpecialCases(d, st, keys, num, buf)) { + uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic(); + // Introspection: switch to worst-case N*logN heapsort after this many. + // Should never be reached, so computing log2 exactly does not help. + const size_t max_levels = 50; + detail::Recurse(d, st, keys, num, buf, state, + max_levels); + } +#else // !VQSORT_ENABLED + (void)d; + (void)buf; + if (VQSORT_PRINT >= 1) { + HWY_WARN("using slow HeapSort because vqsort disabled\n"); + } + detail::HeapSort(st, keys, num); +#endif // VQSORT_ENABLED + + if (num_nan != 0) { + Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan); + } +} + +template +void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, size_t num, size_t k, + T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, + "=============== PartialSort %s num=%zu, k=%zu vec bytes=%zu\n", + st.KeyString(), num, k, sizeof(T) * Lanes(d)); + } + HWY_DASSERT(k <= num); + +#if HWY_MAX_BYTES > 64 + // sorting_networks-inl and traits assume no more than 512 bit vectors. + if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) { + return PartialSort(CappedTag(), st, keys, num, k, buf); + } +#endif // HWY_MAX_BYTES > 64 + + const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num); + +#if VQSORT_ENABLED || HWY_IDE + if (!detail::HandleSpecialCases(d, st, keys, num, buf)) { // TODO + uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic(); + // Introspection: switch to worst-case N*logN heapsort after this many. + // Should never be reached, so computing log2 exactly does not help. + const size_t max_levels = 50; + // TODO: optimize to use kLooseSelect + detail::Recurse(d, st, keys, num, buf, state, + max_levels, k); + detail::Recurse(d, st, keys, k, buf, state, + max_levels); + } +#else // !VQSORT_ENABLED + (void)d; + (void)buf; + if (VQSORT_PRINT >= 1) { + HWY_WARN("using slow HeapSort because vqsort disabled\n"); + } + detail::HeapPartialSort(st, keys, num, k); +#endif // VQSORT_ENABLED + + if (num_nan != 0) { + Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan); + } +} + +template +void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num, + const size_t k, T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "=============== Select %s num=%zu, k=%zu vec bytes=%zu\n", + st.KeyString(), num, k, sizeof(T) * Lanes(d)); + } + HWY_DASSERT(k < num); + +#if HWY_MAX_BYTES > 64 + // sorting_networks-inl and traits assume no more than 512 bit vectors. 
+ if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) { + return Select(CappedTag(), st, keys, num, k, buf); + } +#endif // HWY_MAX_BYTES > 64 + + const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num); + +#if VQSORT_ENABLED || HWY_IDE + if (!detail::HandleSpecialCases(d, st, keys, num, buf)) { // TODO + uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic(); + // Introspection: switch to worst-case N*logN heapsort after this many. + // Should never be reached, so computing log2 exactly does not help. + const size_t max_levels = 50; + detail::Recurse(d, st, keys, num, buf, state, + max_levels, k); + } +#else // !VQSORT_ENABLED + (void)d; + (void)buf; + if (VQSORT_PRINT >= 1) { + HWY_WARN("using slow HeapSort because vqsort disabled\n"); + } + detail::HeapSelect(st, keys, num, k); +#endif // VQSORT_ENABLED + + if (num_nan != 0) { + Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan); + } +} + +// Sorts `keys[0..num-1]` according to the order defined by `st.Compare`. +// In-place i.e. O(1) additional storage. Worst-case N*logN comparisons. +// Non-stable (order of equal keys may change), except for the common case where +// the upper bits of T are the key, and the lower bits are a sequential or at +// least unique ID. Any NaN will be moved to the back of the array and replaced +// with the canonical NaN(d). +// There is no upper limit on `num`, but note that pivots may be chosen by +// sampling only from the first 256 GiB. +// +// `d` is typically SortTag (chooses between full and partial vectors). +// `st` is SharedTraits>. This abstraction layer bridges +// differences in sort order and single-lane vs 128-bit keys. +// `num` is in units of `T`, not keys! +template +HWY_API void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num) { + constexpr size_t kLPK = st.LanesPerKey(); + HWY_ALIGN T buf[SortConstants::BufBytes(HWY_MAX_BYTES) / sizeof(T)]; + Sort(d, st, keys, num, buf); +} + +// Rearranges elements such that the range [0, k) contains the sorted first `k` +// elements in the range [0, n) ordered by `st.Compare`. See also the comment +// for `Sort()`; note that `num` and `k` are in units of `T`, not keys! +template +HWY_API void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num, + const size_t k) { + constexpr size_t kLPK = st.LanesPerKey(); + HWY_ALIGN T buf[SortConstants::BufBytes(HWY_MAX_BYTES) / sizeof(T)]; + PartialSort(d, st, keys, num, k, buf); +} + +// Reorders `keys[0..num-1]` such that `keys+k` is the k-th element if keys was +// sorted by `st.Compare`, and all of the elements before it are ordered +// by `st.Compare` relative to `keys[k]`. See also the comment for `Sort()`; +// note that `num` and `k` are in units of `T`, not keys! +template +HWY_API void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num, + const size_t k) { + constexpr size_t kLPK = st.LanesPerKey(); + HWY_ALIGN T buf[SortConstants::BufBytes(HWY_MAX_BYTES) / sizeof(T)]; + Select(d, st, keys, num, k, buf); +} + +// Translates Key and Order (SortAscending or SortDescending) to SharedTraits. +namespace detail { + +// Primary template for built-in key types = lane type. +template +struct KeyAdapter { + template + using Traits = TraitsLane< + hwy::If, OrderDescending>>; +}; + +template <> +struct KeyAdapter { + template + using Traits = TraitsLane< + hwy::If>; +}; + +// 128-bit keys require 128-bit SIMD. 
+#if HWY_TARGET != HWY_SCALAR + +template <> +struct KeyAdapter { + template + using Traits = Traits128< + hwy::If>; +}; + +template <> +struct KeyAdapter { + template + using Traits = Traits128< + hwy::If>; +}; + +#endif // HWY_TARGET != HWY_SCALAR + +template +using MakeTraits = + SharedTraits::template Traits>; + +} // namespace detail + +// Simpler interface matching VQSort(), but without dynamic dispatch. Uses the +// instructions available in the current target (HWY_NAMESPACE). Supported key +// types: 16-64 bit unsigned/signed/floating-point (but float16/64 only #if +// HWY_HAVE_FLOAT16/64), uint128_t, K64V64, K32V32. Note that `num`, and for +// VQPartialSortStatic/VQSelectStatic also `k`, are in units of *keys*, whereas +// for all functions above this point, they are in units of `T`. Order is either +// SortAscending or SortDescending. +template +void VQSortStatic(Key* HWY_RESTRICT keys, const size_t num_keys, Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + const SortTag d; + Sort(d, st, reinterpret_cast(keys), num_keys * st.LanesPerKey()); +} + +template +void VQPartialSortStatic(Key* HWY_RESTRICT keys, const size_t num_keys, + const size_t k_keys, Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + const SortTag d; + PartialSort(d, st, reinterpret_cast(keys), + num_keys * st.LanesPerKey(), k_keys * st.LanesPerKey()); +} + +template +void VQSelectStatic(Key* HWY_RESTRICT keys, const size_t num_keys, + const size_t k_keys, Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + const SortTag d; + Select(d, st, reinterpret_cast(keys), num_keys * st.LanesPerKey(), + k_keys * st.LanesPerKey()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE diff --git a/third_party/aom/third_party/highway/hwy/contrib/sort/vqsort.h b/third_party/aom/third_party/highway/hwy/contrib/sort/vqsort.h new file mode 100644 index 000000000000..2f0d0f616dfa --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/sort/vqsort.h @@ -0,0 +1,303 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Interface to vectorized quicksort with dynamic dispatch. For static dispatch +// without any DLLEXPORT, avoid including this header and instead define +// VQSORT_ONLY_STATIC, then call VQSortStatic* in vqsort-inl.h. +// +// Blog post: https://tinyurl.com/vqsort-blog +// Paper with measurements: https://arxiv.org/abs/2205.05982 +// +// To ensure the overhead of using wide vectors (e.g. AVX2 or AVX-512) is +// worthwhile, we recommend using this code for sorting arrays whose size is at +// least 100 KiB. See the README for details. 
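+//
+// Static-dispatch sketch of the VQSORT_ONLY_STATIC route mentioned above
+// (the include path follows this vendored copy; the exact namespace spelling
+// is an assumption, see the README for the authoritative form):
+//
+//   #define VQSORT_ONLY_STATIC 1
+//   #include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+//
+//   void SortU64(uint64_t* keys, size_t n) {
+//     hwy::HWY_NAMESPACE::VQSortStatic(keys, n, hwy::SortAscending());
+//   }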
+ +#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ + +// IWYU pragma: begin_exports +#include + +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/contrib/sort/order.h" // SortAscending +// IWYU pragma: end_exports + +namespace hwy { + +// Vectorized Quicksort: sorts keys[0, n). Does not preserve the ordering of +// equivalent keys (defined as: neither greater nor less than another). +// Dispatches to the best available instruction set. Does not allocate memory. +// Uses about 1.2 KiB stack plus an internal 3-word TLS cache for random state. +HWY_CONTRIB_DLLEXPORT void VQSort(uint16_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint16_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint32_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint32_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint64_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint64_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(int16_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(int16_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(int32_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(int32_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(int64_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(int64_t* HWY_RESTRICT keys, size_t n, + SortDescending); + +// These two must only be called if hwy::HaveFloat16() is true. +HWY_CONTRIB_DLLEXPORT void VQSort(float16_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(float16_t* HWY_RESTRICT keys, size_t n, + SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQSort(float* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(float* HWY_RESTRICT keys, size_t n, + SortDescending); + +// These two must only be called if hwy::HaveFloat64() is true. +HWY_CONTRIB_DLLEXPORT void VQSort(double* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(double* HWY_RESTRICT keys, size_t n, + SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQSort(K32V32* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(K32V32* HWY_RESTRICT keys, size_t n, + SortDescending); + +// 128-bit types: `n` is still in units of the 128-bit keys. +HWY_CONTRIB_DLLEXPORT void VQSort(uint128_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint128_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(K64V64* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(K64V64* HWY_RESTRICT keys, size_t n, + SortDescending); + +// Vectorized partial Quicksort: +// Rearranges elements such that the range [0, k) contains the sorted first k +// elements in the range [0, n). Does not preserve the ordering of equivalent +// keys (defined as: neither greater nor less than another). +// Dispatches to the best available instruction set. Does not allocate memory. +// Uses about 1.2 KiB stack plus an internal 3-word TLS cache for random state. 
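+// (The post-condition matches std::partial_sort: the order of the remaining
+// range [k, n) is unspecified.)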
+HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// These two must only be called if hwy::HaveFloat16() is true. +HWY_CONTRIB_DLLEXPORT void VQPartialSort(float16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(float16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQPartialSort(float* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(float* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// These two must only be called if hwy::HaveFloat64() is true. +HWY_CONTRIB_DLLEXPORT void VQPartialSort(double* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(double* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQPartialSort(K32V32* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(K32V32* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// 128-bit types: `n` and `k` are still in units of the 128-bit keys. +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint128_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint128_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(K64V64* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(K64V64* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// Vectorized Quickselect: +// rearranges elements in [0, n) such that: +// The element pointed at by kth is changed to whatever element would occur in +// that position if [0, n) were sorted. All of the elements before this new kth +// element are less than or equal to the elements after the new kth element. 
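In contrast to the partial sort above, the selection overloads that follow only place the k-th element in its sorted position and partition the rest around it. A sketch for a median query (hypothetical caller; assumes n != 0):

uint32_t Median(uint32_t* keys, size_t n) {
  hwy::VQSelect(keys, n, /*k=*/n / 2, hwy::SortAscending());
  return keys[n / 2];  // neighbors are merely partitioned, not sorted
}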
+HWY_CONTRIB_DLLEXPORT void VQSelect(uint16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// These two must only be called if hwy::HaveFloat16() is true. +HWY_CONTRIB_DLLEXPORT void VQSelect(float16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(float16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQSelect(float* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(float* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// These two must only be called if hwy::HaveFloat64() is true. +HWY_CONTRIB_DLLEXPORT void VQSelect(double* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(double* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQSelect(K32V32* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(K32V32* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// 128-bit types: `n` and `k` are still in units of the 128-bit keys. +HWY_CONTRIB_DLLEXPORT void VQSelect(uint128_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint128_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(K64V64* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(K64V64* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// User-level caching is no longer required, so this class is no longer +// beneficial. We recommend using the simpler VQSort() interface instead, and +// retain this class only for compatibility. It now just calls VQSort. 
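Because the class that follows now simply forwards to VQSort, existing call sites keep working; a sketch of the equivalence (hypothetical caller):

void SortBoth(int32_t* keys, size_t n) {
  hwy::Sorter sorter;                           // retained only for compatibility
  sorter(keys, n, hwy::SortDescending());       // forwards to VQSort
  hwy::VQSort(keys, n, hwy::SortDescending());  // recommended direct call
}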
+class HWY_CONTRIB_DLLEXPORT Sorter { + public: + Sorter(); + ~Sorter() { Delete(); } + + // Move-only + Sorter(const Sorter&) = delete; + Sorter& operator=(const Sorter&) = delete; + Sorter(Sorter&& /*other*/) {} + Sorter& operator=(Sorter&& /*other*/) { return *this; } + + void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + // These two must only be called if hwy::HaveFloat16() is true. + void operator()(float16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(float16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(float* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(float* HWY_RESTRICT keys, size_t n, SortDescending) const; + + // These two must only be called if hwy::HaveFloat64() is true. + void operator()(double* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(double* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortDescending) const; + + // Unused + static void Fill24Bytes(const void*, size_t, void*); + static bool HaveFloat64(); // Can also use hwy::HaveFloat64 directly. + + private: + void Delete(); + + template + T* Get() const { + return unused_; + } + +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wunused-private-field") +#endif + void* unused_ = nullptr; +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS(pop) +#endif +}; + +// Used by vqsort-inl.h unless VQSORT_ONLY_STATIC. +HWY_CONTRIB_DLLEXPORT bool Fill16BytesSecure(void* bytes); + +// Unused, only provided for binary compatibility. 
+HWY_CONTRIB_DLLEXPORT uint64_t* GetGeneratorState(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/thread_pool/futex.h b/third_party/aom/third_party/highway/hwy/contrib/thread_pool/futex.h new file mode 100644 index 000000000000..740cbd23fc04 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/thread_pool/futex.h @@ -0,0 +1,247 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_FUTEX_H_ +#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_FUTEX_H_ + +// Keyed event (futex): kernel queue of blocked threads, identified by the +// address of an atomic u32 called `current` within the same process (do NOT +// use with shared-memory mappings). +// +// Futex equivalents: https://outerproduct.net/futex-dictionary.html; we +// support Linux/Emscripten/Apple/Windows and C++20 std::atomic::wait, plus a +// NanoSleep fallback. + +#include + +#include +#include // INT_MAX + +#include "third_party/highway/hwy/base.h" + +#if HWY_ARCH_WASM +#include +#include // INFINITY + +#elif HWY_OS_LINUX +#include // IWYU pragma: keep +#include // FUTEX_* +#include +#include // SYS_* +#include +// Android may not declare these: +#ifndef SYS_futex +#ifdef SYS_futex_time64 // 32-bit with 64-bit time_t +#define SYS_futex SYS_futex_time64 +#else +#define SYS_futex __NR_futex +#endif // SYS_futex_time64 +#endif // SYS_futex +#ifndef FUTEX_WAIT_PRIVATE +#define FUTEX_WAIT_PRIVATE (FUTEX_WAIT | 128) +#endif +#ifndef FUTEX_WAKE_PRIVATE +#define FUTEX_WAKE_PRIVATE (FUTEX_WAKE | 128) +#endif + +#elif HWY_OS_APPLE && !defined(HWY_DISABLE_FUTEX) +// These are private APIs, so add an opt-out. +extern "C" { +int __ulock_wait(uint32_t op, void* address, uint64_t val, uint32_t max_us); +int __ulock_wake(uint32_t op, void* address, uint64_t zero); +} // extern "C" +#define UL_COMPARE_AND_WAIT 1 +#define ULF_WAKE_ALL 0x00000100 + +#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX) +// WakeByAddressAll requires Windows 8, so add an opt-out. +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif // WIN32_LEAN_AND_MEAN +#include +#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL +#pragma comment(lib, "synchronization.lib") +#endif + +#elif HWY_CXX_LANG < 202002L // NOT C++20, which has native support +#define HWY_FUTEX_SLEEP +#endif + +namespace hwy { + +// Attempts to pause for the specified nanoseconds, though the resolution is +// closer to 0.1 microseconds. Returns false if no wait happened. Thread-safe. +static inline bool NanoSleep(uint64_t ns) { +#if HWY_OS_WIN + static thread_local HANDLE hTimer = nullptr; + if (HWY_UNLIKELY(hTimer == nullptr)) { + // Must be manual reset: auto-reset would immediately signal after the next + // SetWaitableTimer. 
+ hTimer = CreateWaitableTimer(nullptr, TRUE, nullptr); + if (hTimer == nullptr) return false; + } + + // Negative means relative, in units of 100 ns. + LARGE_INTEGER time; + time.QuadPart = -static_cast(ns / 100); + const LONG period = 0; // signal once + if (!SetWaitableTimer(hTimer, &time, period, nullptr, nullptr, FALSE)) { + return false; + } + + (void)WaitForSingleObject(hTimer, INFINITE); + return true; +#else + timespec duration; + duration.tv_sec = static_cast(ns / 1000000000); + duration.tv_nsec = static_cast(ns % 1000000000); + timespec remainder; + // Repeat if interrupted by a signal. Note that the remainder may be rounded + // up, which could cause an infinite loop if continually interrupted. Using + // clock_nanosleep would work, but we'd have to get the current time. We + // assume durations are short, and instead just cap the number of retries. + for (int rep = 0; rep < 3; ++rep) { + if (nanosleep(&duration, &remainder) == 0 || errno != EINTR) break; + duration = remainder; + } + return true; +#endif +} + +// Waits until `current != prev` and returns the new value. May return +// immediately if `current` already changed, or after blocking and waking. +static inline uint32_t BlockUntilDifferent( + const uint32_t prev, const std::atomic& current) { + const auto acq = std::memory_order_acquire; + +#if HWY_ARCH_WASM + // It is always safe to cast to void. + volatile void* address = + const_cast(static_cast(¤t)); + const double max_ms = INFINITY; + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + const int ret = emscripten_futex_wait(address, prev, max_ms); + HWY_DASSERT(ret >= 0); + (void)ret; + } + +#elif HWY_OS_LINUX + // Safe to cast because std::atomic is a standard layout type. + const uint32_t* address = reinterpret_cast(¤t); + // _PRIVATE requires this only be used in the same process, and avoids + // virtual->physical lookups and atomic reference counting. + const int op = FUTEX_WAIT_PRIVATE; + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + // timeout=null may prevent interrupts via signal. No lvalue because + // the timespec type is only standardized since C++17 or C11. + const auto ret = syscall(SYS_futex, address, op, prev, nullptr, nullptr, 0); + if (ret == -1) { + HWY_DASSERT(errno == EAGAIN); // otherwise an actual error + } + } + +#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX) + // It is always safe to cast to void. + volatile void* address = + const_cast(static_cast(¤t)); + // API is not const-correct, but only loads from the pointer. + PVOID pprev = const_cast(static_cast(&prev)); + const DWORD max_ms = INFINITE; + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + const BOOL ok = WaitOnAddress(address, pprev, sizeof(prev), max_ms); + HWY_DASSERT(ok); + (void)ok; + } + +#elif HWY_OS_APPLE && !defined(HWY_DISABLE_FUTEX) + // It is always safe to cast to void. + void* address = const_cast(static_cast(¤t)); + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + __ulock_wait(UL_COMPARE_AND_WAIT, address, prev, 0); + } + +#elif defined(HWY_FUTEX_SLEEP) + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + NanoSleep(2000); + } + +#elif HWY_CXX_LANG >= 202002L + current.wait(prev, acq); // No spurious wakeup. 
+ const uint32_t next = current.load(acq); + HWY_DASSERT(next != prev); + return next; + +#else +#error "Logic error, should have reached HWY_FUTEX_SLEEP" +#endif // HWY_OS_* +} // BlockUntilDifferent + +// Wakes all threads, if any, that are waiting because they called +// `BlockUntilDifferent` with the same `current`. +static inline void WakeAll(std::atomic& current) { +#if HWY_ARCH_WASM + // It is always safe to cast to void. + volatile void* address = static_cast(¤t); + const int max_to_wake = INT_MAX; // actually signed + const int ret = emscripten_futex_wake(address, max_to_wake); + HWY_DASSERT(ret >= 0); + (void)ret; + +#elif HWY_OS_LINUX + // Safe to cast because std::atomic is a standard layout type. + uint32_t* address = reinterpret_cast(¤t); + const int max_to_wake = INT_MAX; // actually signed + const auto ret = syscall(SYS_futex, address, FUTEX_WAKE_PRIVATE, max_to_wake, + nullptr, nullptr, 0); + HWY_DASSERT(ret >= 0); // number woken + (void)ret; + +#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX) + // It is always safe to cast to void. + void* address = static_cast(¤t); + WakeByAddressAll(address); + +#elif HWY_OS_APPLE && !defined(HWY_DISABLE_FUTEX) + // It is always safe to cast to void. + void* address = static_cast(¤t); + __ulock_wake(UL_COMPARE_AND_WAIT | ULF_WAKE_ALL, address, 0); + +#elif defined(HWY_FUTEX_SLEEP) + // NanoSleep loop does not require wakeup. + (void)current; +#elif HWY_CXX_LANG >= 202002L + current.notify_all(); + +#else +#error "Logic error, should have reached HWY_FUTEX_SLEEP" +#endif +} // WakeAll + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_FUTEX_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/thread_pool/spin.h b/third_party/aom/third_party/highway/hwy/contrib/thread_pool/spin.h new file mode 100644 index 000000000000..57973a7610c6 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/thread_pool/spin.h @@ -0,0 +1,328 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_ +#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_ + +// Relatively power-efficient spin lock for low-latency synchronization. + +#include + +#include + +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/cache_control.h" // Pause + +#ifndef HWY_ENABLE_MONITORX // allow override +// Clang 3.9 suffices for mwaitx, but the target pragma requires 9.0. 
+#if HWY_ARCH_X86 && ((HWY_COMPILER_CLANG >= 900) || \ + (HWY_COMPILER_GCC_ACTUAL >= 502) || defined(__MWAITX__)) +#define HWY_ENABLE_MONITORX 1 +#else +#define HWY_ENABLE_MONITORX 0 +#endif +#endif // HWY_ENABLE_MONITORX + +#ifndef HWY_ENABLE_UMONITOR // allow override +#if HWY_ARCH_X86 && ((HWY_COMPILER_CLANG >= 900) || \ + (HWY_COMPILER_GCC_ACTUAL >= 901) || defined(__WAITPKG__)) +#define HWY_ENABLE_UMONITOR 1 +#else +#define HWY_ENABLE_UMONITOR 0 +#endif +#endif // HWY_ENABLE_UMONITOR + +// Inline assembly is preferred because it allows inlining of `UntilDifferent` +// etc, but we also support intrinsics for MSVC. +#ifndef HWY_ENABLE_SPIN_ASM // allow override +#if (HWY_COMPILER_CLANG || HWY_COMPILER_GCC) && HWY_ARCH_X86_64 +#define HWY_ENABLE_SPIN_ASM 1 +#else +#define HWY_ENABLE_SPIN_ASM 0 +#endif +#endif // HWY_ENABLE_SPIN_ASM + +#if HWY_ENABLE_MONITORX || HWY_ENABLE_UMONITOR +#if HWY_ENABLE_SPIN_ASM +#define HWY_INLINE_SPIN HWY_INLINE // can inline functions with inline assembly +#else +// Intrinsics require attributes, which prevent inlining. +#define HWY_INLINE_SPIN +#include +#endif // HWY_ENABLE_SPIN_ASM + +#include "third_party/highway/hwy/x86_cpuid.h" +#endif // HWY_ENABLE_MONITORX || HWY_ENABLE_UMONITOR + +namespace hwy { + +// Returned by `UntilDifferent` in a single register. +struct SpinResult { + // We also use u32 because that is all that futex.h supports. + uint32_t current; + // Number of retries before returning, useful for checking that the + // monitor/wait did not just return immediately. + uint32_t reps; +}; + +// User-space monitor/wait are supported on Zen2+ AMD and SPR+ Intel. Spin waits +// are rarely called from SIMD code, hence we do not integrate this into +// `HWY_TARGET` and its runtime dispatch mechanism. Returned by `Type()`, also +// used by callers to set the `disabled` argument for `DetectSpin`. +enum class SpinType : uint8_t { + kMonitorX = 1, // AMD + kUMonitor, // Intel + kPause, + kSentinel // for iterating over all enumerators. Must be last. +}; + +// For printing which is in use. +static inline const char* ToString(SpinType type) { + switch (type) { + case SpinType::kMonitorX: + return "MonitorX_C1"; + case SpinType::kUMonitor: + return "UMonitor_C0.2"; + case SpinType::kPause: + return "Pause"; + case SpinType::kSentinel: + return nullptr; + default: + HWY_UNREACHABLE; + } +} + +// Indirect function calls turn out to be too expensive because this is called +// multiple times per ThreadPool barrier. We will instead inline the spin and +// barrier using policy classes. This one is always available; use it as a +// reference for the interface. Note that Pause varies across CPUs: it can be +// a no-op, or wait 140 cycles. +struct SpinPause { + SpinType Type() const { return SpinType::kPause; } + + // Spins until `watched != prev` and returns the new value, similar to + // `BlockUntilDifferent` in `futex.h`. + HWY_INLINE SpinResult UntilDifferent( + const uint32_t prev, const std::atomic& watched) const { + for (uint32_t reps = 0;; ++reps) { + const uint32_t current = watched.load(std::memory_order_acquire); + if (current != prev) return SpinResult{current, reps}; + hwy::Pause(); + } + } + + // Returns number of retries until `watched == expected`. 
+ HWY_INLINE size_t UntilEqual(const uint32_t expected, + const std::atomic& watched) const { + for (size_t reps = 0;; ++reps) { + const uint32_t current = watched.load(std::memory_order_acquire); + if (current == expected) return reps; + hwy::Pause(); + } + } +}; + +#if HWY_ENABLE_MONITORX || HWY_IDE +#if !HWY_ENABLE_SPIN_ASM +HWY_PUSH_ATTRIBUTES("mwaitx") +#endif + +// AMD's user-mode monitor/wait (Zen2+). +class SpinMonitorX { + public: + SpinType Type() const { return SpinType::kMonitorX; } + + HWY_INLINE_SPIN SpinResult UntilDifferent( + const uint32_t prev, const std::atomic& watched) const { + for (uint32_t reps = 0;; ++reps) { + uint32_t current = watched.load(std::memory_order_acquire); + if (current != prev) return SpinResult{current, reps}; + Monitor(&watched); + // Double-checked 'lock' to avoid missed events: + current = watched.load(std::memory_order_acquire); + if (current != prev) return SpinResult{current, reps}; + Wait(); + } + } + + HWY_INLINE_SPIN size_t UntilEqual( + const uint32_t expected, const std::atomic& watched) const { + for (size_t reps = 0;; ++reps) { + uint32_t current = watched.load(std::memory_order_acquire); + if (current == expected) return reps; + Monitor(&watched); + // Double-checked 'lock' to avoid missed events: + current = watched.load(std::memory_order_acquire); + if (current == expected) return reps; + Wait(); + } + } + + private: + static HWY_INLINE void Monitor(const void* addr) { + // No extensions/hints currently defined. +#if HWY_ENABLE_SPIN_ASM + asm volatile("monitorx" ::"a"(addr), "c"(0), "d"(0)); +#else + _mm_monitorx(const_cast(addr), 0, 0); +#endif + } + + static HWY_INLINE void Wait() { +#if HWY_ENABLE_SPIN_ASM + // EBX=0 cycles means no timeout/infinite. + asm volatile("mwaitx" ::"a"(kHints), "b"(0), "c"(kExtensions)); +#else + _mm_mwaitx(kExtensions, kHints, /*cycles=*/0); +#endif + } + + // 0xF would be C0. Its wakeup latency is less than 0.1 us shorter, and + // package power is sometimes actually higher than with Pause. The + // difference in spurious wakeups is minor. + static constexpr unsigned kHints = 0x0; // C1: a bit deeper than C0 + // No timeout required, we assume the mwaitx does not miss stores, see + // https://www.usenix.org/system/files/usenixsecurity23-zhang-ruiyi.pdf.] + static constexpr unsigned kExtensions = 0; +}; + +#if !HWY_ENABLE_SPIN_ASM +HWY_POP_ATTRIBUTES +#endif +#endif // HWY_ENABLE_MONITORX + +#if HWY_ENABLE_UMONITOR || HWY_IDE +#if !HWY_ENABLE_SPIN_ASM +HWY_PUSH_ATTRIBUTES("waitpkg") +#endif + +// Intel's user-mode monitor/wait (SPR+). 
+class SpinUMonitor { + public: + SpinType Type() const { return SpinType::kUMonitor; } + + HWY_INLINE_SPIN SpinResult UntilDifferent( + const uint32_t prev, const std::atomic& watched) const { + for (uint32_t reps = 0;; ++reps) { + uint32_t current = watched.load(std::memory_order_acquire); + if (current != prev) return SpinResult{current, reps}; + Monitor(&watched); + // Double-checked 'lock' to avoid missed events: + current = watched.load(std::memory_order_acquire); + if (current != prev) return SpinResult{current, reps}; + Wait(); + } + } + + HWY_INLINE_SPIN size_t UntilEqual( + const uint32_t expected, const std::atomic& watched) const { + for (size_t reps = 0;; ++reps) { + uint32_t current = watched.load(std::memory_order_acquire); + if (current == expected) return reps; + Monitor(&watched); + // Double-checked 'lock' to avoid missed events: + current = watched.load(std::memory_order_acquire); + if (current == expected) return reps; + Wait(); + } + } + + private: + static HWY_INLINE void Monitor(const void* addr) { +#if HWY_ENABLE_SPIN_ASM + asm volatile("umonitor %%rcx" ::"c"(addr)); +#else + _umonitor(const_cast(addr)); +#endif + } + + static HWY_INLINE void Wait() { +#if HWY_ENABLE_SPIN_ASM + asm volatile("umwait %%ecx" ::"c"(kControl), "d"(kDeadline >> 32), + "a"(kDeadline & 0xFFFFFFFFu)); +#else + _umwait(kControl, kDeadline); +#endif + } + + // 1 would be C0.1. C0.2 has 20x fewer spurious wakeups and additional 4% + // package power savings vs Pause on SPR. It comes at the cost of + // 0.4-0.6us higher wake latency, but the total is comparable to Zen4. + static constexpr unsigned kControl = 0; // C0.2 for deeper sleep + static constexpr uint64_t kDeadline = ~uint64_t{0}; // no timeout, see above +}; + +#if !HWY_ENABLE_SPIN_ASM +HWY_POP_ATTRIBUTES +#endif +#endif // HWY_ENABLE_UMONITOR + +// TODO(janwas): add WFE on Arm. May wake at 10 kHz, but still worthwhile. + +// Returns the best-available type whose bit in `disabled` is not set. Example: +// to disable kUMonitor, pass `1 << static_cast(SpinType::kUMonitor)`. +// Ignores `disabled` for `kPause` if it is the only supported and enabled type. +// Somewhat expensive, typically called during initialization. +static inline SpinType DetectSpin(int disabled = 0) { + const auto HWY_MAYBE_UNUSED enabled = [disabled](SpinType type) { + return (disabled & (1 << static_cast(type))) == 0; + }; + +#if HWY_ENABLE_MONITORX + if (enabled(SpinType::kMonitorX) && x86::IsAMD()) { + uint32_t abcd[4]; + x86::Cpuid(0x80000001U, 0, abcd); + if (x86::IsBitSet(abcd[2], 29)) return SpinType::kMonitorX; + } +#endif // HWY_ENABLE_MONITORX + +#if HWY_ENABLE_UMONITOR + if (enabled(SpinType::kUMonitor) && x86::MaxLevel() >= 7) { + uint32_t abcd[4]; + x86::Cpuid(7, 0, abcd); + if (x86::IsBitSet(abcd[2], 5)) return SpinType::kUMonitor; + } +#endif // HWY_ENABLE_UMONITOR + + if (!enabled(SpinType::kPause)) { + HWY_WARN("Ignoring attempt to disable Pause, it is the only option left."); + } + return SpinType::kPause; +} + +// Calls `func(spin)` for the given `spin_type`. 
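A sketch of using the detection and a policy class directly; the pool code further below composes them via CallWithSpin instead. The ready flag and its producer thread are hypothetical:

#include <stdio.h>
#include <atomic>

std::atomic<uint32_t> ready{0};  // set to 1 by some other thread

void WaitForReady() {
  printf("spin: %s\n", hwy::ToString(hwy::DetectSpin()));
  const hwy::SpinPause spin;  // always available; the monitor variants need CPU support
  const size_t reps = spin.UntilEqual(/*expected=*/1, ready);
  (void)reps;  // number of polls before the flag flipped
}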
+template +HWY_INLINE void CallWithSpin(SpinType spin_type, Func&& func) { + switch (spin_type) { +#if HWY_ENABLE_MONITORX + case SpinType::kMonitorX: + func(SpinMonitorX()); + break; +#endif +#if HWY_ENABLE_UMONITOR + case SpinType::kUMonitor: + func(SpinUMonitor()); + break; +#endif + case SpinType::kPause: + default: + func(SpinPause()); + break; + } +} + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/thread_pool/thread_pool.h b/third_party/aom/third_party/highway/hwy/contrib/thread_pool/thread_pool.h new file mode 100644 index 000000000000..d3516174ae07 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/thread_pool/thread_pool.h @@ -0,0 +1,1287 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Modified from BSD-licensed code +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// See https://github.com/libjxl/libjxl/blob/main/LICENSE. + +#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_THREAD_POOL_H_ +#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_THREAD_POOL_H_ + +#include +#include +#include // snprintf + +#include +#include +#include +#include // NOLINT +#include + +#include "third_party/highway/hwy/detect_compiler_arch.h" +#if HWY_OS_FREEBSD +#include +#endif + +#include "third_party/highway/hwy/aligned_allocator.h" // HWY_ALIGNMENT +#include "third_party/highway/hwy/auto_tune.h" +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/cache_control.h" // Pause +#include "third_party/highway/hwy/contrib/thread_pool/futex.h" +#include "third_party/highway/hwy/contrib/thread_pool/spin.h" +#include "third_party/highway/hwy/contrib/thread_pool/topology.h" +#include "third_party/highway/hwy/stats.h" +#include "third_party/highway/hwy/timer.h" + +// Define to HWY_NOINLINE to see profiles of `WorkerRun*` and waits. +#define HWY_POOL_PROFILE + +namespace hwy { + +// Sets the name of the current thread to the format string `format`, which must +// include %d for `thread`. Currently only implemented for pthreads (*nix and +// OSX); Windows involves throwing an exception. +static inline void SetThreadName(const char* format, int thread) { + char buf[16] = {}; // Linux limit, including \0 + const int chars_written = snprintf(buf, sizeof(buf), format, thread); + HWY_ASSERT(0 < chars_written && + chars_written <= static_cast(sizeof(buf) - 1)); + +#if HWY_OS_LINUX && (!defined(__ANDROID__) || __ANDROID_API__ >= 19) + HWY_ASSERT(0 == pthread_setname_np(pthread_self(), buf)); +#elif HWY_OS_FREEBSD + HWY_ASSERT(0 == pthread_set_name_np(pthread_self(), buf)); +#elif HWY_OS_APPLE + // Different interface: single argument, current thread only. + HWY_ASSERT(0 == pthread_setname_np(buf)); +#endif +} + +// Whether workers should block or spin. 
+enum class PoolWaitMode : uint8_t { kBlock = 1, kSpin }; + +namespace pool { + +#ifndef HWY_POOL_VERBOSITY +#define HWY_POOL_VERBOSITY 0 +#endif + +static constexpr int kVerbosity = HWY_POOL_VERBOSITY; + +// Some CPUs already have more than this many threads, but rather than one +// large pool, we assume applications create multiple pools, ideally per +// cluster (cores sharing a cache), because this improves locality and barrier +// latency. In that case, this is a generous upper bound. +static constexpr size_t kMaxThreads = 63; + +// Generates a random permutation of [0, size). O(1) storage. +class ShuffledIota { + public: + ShuffledIota() : coprime_(1) {} // for Worker + explicit ShuffledIota(uint32_t coprime) : coprime_(coprime) {} + + // Returns the next after `current`, using an LCG-like generator. + uint32_t Next(uint32_t current, const Divisor64& divisor) const { + HWY_DASSERT(current < divisor.GetDivisor()); + // (coprime * i + current) % size, see https://lemire.me/blog/2017/09/18/. + return static_cast(divisor.Remainder(current + coprime_)); + } + + // Returns true if a and b have no common denominator except 1. Based on + // binary GCD. Assumes a and b are nonzero. Also used in tests. + static bool CoprimeNonzero(uint32_t a, uint32_t b) { + const size_t trailing_a = Num0BitsBelowLS1Bit_Nonzero32(a); + const size_t trailing_b = Num0BitsBelowLS1Bit_Nonzero32(b); + // If both have at least one trailing zero, they are both divisible by 2. + if (HWY_MIN(trailing_a, trailing_b) != 0) return false; + + // If one of them has a trailing zero, shift it out. + a >>= trailing_a; + b >>= trailing_b; + + for (;;) { + // Swap such that a >= b. + const uint32_t tmp_a = a; + a = HWY_MAX(tmp_a, b); + b = HWY_MIN(tmp_a, b); + + // When the smaller number is 1, they were coprime. + if (b == 1) return true; + + a -= b; + // a == b means there was a common factor, so not coprime. + if (a == 0) return false; + a >>= Num0BitsBelowLS1Bit_Nonzero32(a); + } + } + + // Returns another coprime >= `start`, or 1 for small `size`. + // Used to seed independent ShuffledIota instances. + static uint32_t FindAnotherCoprime(uint32_t size, uint32_t start) { + if (size <= 2) { + return 1; + } + + // Avoids even x for even sizes, which are sure to be rejected. + const uint32_t inc = (size & 1) ? 1 : 2; + + for (uint32_t x = start | 1; x < start + size * 16; x += inc) { + if (CoprimeNonzero(x, static_cast(size))) { + return x; + } + } + + HWY_UNREACHABLE; + } + + uint32_t coprime_; +}; + +// 'Policies' suitable for various worker counts and locality. To define a +// new class, add an enum and update `ToString` plus `FunctorAddWait`. The +// enumerators must be contiguous so we can iterate over them. +enum class WaitType : uint8_t { + kBlock, + kSpin1, + kSpinSeparate, + kSentinel // Must be last. +}; +enum class BarrierType : uint8_t { + kOrdered, + kCounter1, + kCounter2, + kCounter4, + kGroup2, + kGroup4, + kSentinel // Must be last. +}; + +// For printing which is in use. 
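A sketch of the permutation described above, assuming hwy::Divisor64 is constructed from the worker count as in the ThreadPool constructor further below; with 8 workers, the coprime stride visits every index exactly once:

#include <stdio.h>

void PrintVictimOrder() {
  const hwy::Divisor64 div(8);  // 8 workers
  const uint32_t coprime =
      hwy::pool::ShuffledIota::FindAnotherCoprime(8, /*start=*/257);
  const hwy::pool::ShuffledIota iota(coprime);
  uint32_t current = 0;
  for (int i = 0; i < 8; ++i) {
    printf("%u ", current);  // prints a permutation of 0..7
    current = iota.Next(current, div);
  }
}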
+static inline const char* ToString(WaitType type) { + switch (type) { + case WaitType::kBlock: + return "Block"; + case WaitType::kSpin1: + return "Single"; + case WaitType::kSpinSeparate: + return "Separate"; + case WaitType::kSentinel: + return nullptr; + default: + HWY_UNREACHABLE; + } +} + +static inline const char* ToString(BarrierType type) { + switch (type) { + case BarrierType::kOrdered: + return "Ordered"; + case BarrierType::kCounter1: + return "Counter1"; + case BarrierType::kCounter2: + return "Counter2"; + case BarrierType::kCounter4: + return "Counter4"; + case BarrierType::kGroup2: + return "Group2"; + case BarrierType::kGroup4: + return "Group4"; + case BarrierType::kSentinel: + return nullptr; + default: + HWY_UNREACHABLE; + } +} + +// We want predictable struct/class sizes so we can reason about cache lines. +#pragma pack(push, 1) + +// Parameters governing the main and worker thread behavior. Can be updated at +// runtime via `SetWaitMode`. Both have copies which are carefully synchronized +// (two-phase barrier). 64-bit allows adding fields (e.g. for load-balancing) +// without having to bit-pack members, and is fine because this is only moved +// with relaxed stores, hence we do not have to fit it in the 32 futex bits. +class Config { // 8 bytes + public: + static std::vector AllCandidates(PoolWaitMode wait_mode, + size_t num_threads) { + std::vector spin_types(size_t{1}, DetectSpin()); + // Monitor-based spin may be slower, so also try Pause. + if (spin_types[0] != SpinType::kPause) { + spin_types.push_back(SpinType::kPause); + } + + std::vector wait_types; + if (wait_mode == PoolWaitMode::kSpin) { + // All except `kBlock`. + for (size_t wait = 0;; ++wait) { + const WaitType wait_type = static_cast(wait); + if (wait_type == WaitType::kSentinel) break; + if (wait_type != WaitType::kBlock) wait_types.push_back(wait_type); + } + } else { + wait_types.push_back(WaitType::kBlock); + } + + std::vector barrier_types; + // Note that casting an integer is UB if there is no matching enumerator, + // but we define a sentinel to prevent this. + for (size_t barrier = 0;; ++barrier) { + const BarrierType barrier_type = static_cast(barrier); + if (barrier_type == BarrierType::kSentinel) break; + // If <= 2 workers, group size of 4 is the same as 2. + if (num_threads <= 1 && barrier_type == BarrierType::kCounter4) continue; + if (num_threads <= 1 && barrier_type == BarrierType::kGroup4) continue; + barrier_types.push_back(barrier_type); + } + + std::vector candidates; + candidates.reserve(50); + for (const SpinType spin_type : spin_types) { + for (const WaitType wait_type : wait_types) { + for (const BarrierType barrier_type : barrier_types) { + candidates.emplace_back(spin_type, wait_type, barrier_type); + } + } + } + return candidates; + } + + std::string ToString() const { + char buf[128]; + snprintf(buf, sizeof(buf), "%14s %9s %9s", hwy::ToString(spin_type), + pool::ToString(wait_type), pool::ToString(barrier_type)); + return buf; + } + + Config() {} + Config(SpinType spin_type, WaitType wait_type, BarrierType barrier_type) + : spin_type(spin_type), + wait_type(wait_type), + barrier_type(barrier_type), + exit(false) {} + + SpinType spin_type; + WaitType wait_type; + BarrierType barrier_type; + bool exit; + uint32_t reserved = 0; +}; +static_assert(sizeof(Config) == 8, ""); + +// Per-worker state used by both main and worker threads. `ThreadFunc` +// (threads) and `ThreadPool` (main) have a few additional members of their own. 
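A sketch listing the candidate configurations the auto-tuner would consider for a hypothetical pool with 7 worker threads in spin mode, using the helpers defined above:

#include <stdio.h>

void PrintCandidates() {
  const auto candidates = hwy::pool::Config::AllCandidates(
      hwy::PoolWaitMode::kSpin, /*num_threads=*/7);
  for (const hwy::pool::Config& c : candidates) {
    printf("%s\n", c.ToString().c_str());
  }
}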
+class alignas(HWY_ALIGNMENT) Worker { // HWY_ALIGNMENT bytes + static constexpr size_t kMaxVictims = 4; + + static constexpr auto kAcq = std::memory_order_acquire; + static constexpr auto kRel = std::memory_order_release; + + public: + Worker(const size_t worker, const size_t num_threads, + const Divisor64& div_workers) + : worker_(worker), num_threads_(num_threads), workers_(this - worker) { + (void)padding_; + + HWY_DASSERT(IsAligned(this, HWY_ALIGNMENT)); + HWY_DASSERT(worker <= num_threads); + const size_t num_workers = static_cast(div_workers.GetDivisor()); + num_victims_ = static_cast(HWY_MIN(kMaxVictims, num_workers)); + + // Increase gap between coprimes to reduce collisions. + const uint32_t coprime = ShuffledIota::FindAnotherCoprime( + static_cast(num_workers), + static_cast((worker + 1) * 257 + worker * 13)); + const ShuffledIota shuffled_iota(coprime); + + // To simplify `WorkerRun`, this worker is the first to 'steal' from. + victims_[0] = static_cast(worker); + for (uint32_t i = 1; i < num_victims_; ++i) { + victims_[i] = shuffled_iota.Next(victims_[i - 1], div_workers); + HWY_DASSERT(victims_[i] != worker); + } + } + + // Placement-newed by `WorkerLifecycle`, we do not expect any copying. + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + size_t Index() const { return worker_; } + Worker* AllWorkers() { return workers_; } + const Worker* AllWorkers() const { return workers_; } + size_t NumThreads() const { return num_threads_; } + + // ------------------------ Per-worker storage for `SendConfig` + + Config LatchedConfig() const { return latched_; } + // For workers, but no harm if also called by main thread. + void LatchConfig(Config copy) { latched_ = copy; } + + // ------------------------ Task assignment + + // Called from the main thread. + void SetRange(const uint64_t begin, const uint64_t end) { + my_begin_.store(begin, kRel); + my_end_.store(end, kRel); + } + + uint64_t MyEnd() const { return my_end_.load(kAcq); } + + Span Victims() const { + return hwy::Span(victims_.data(), + static_cast(num_victims_)); + } + + // Returns the next task to execute. If >= MyEnd(), it must be skipped. + uint64_t WorkerReserveTask() { + // TODO(janwas): replace with cooperative work-stealing. + return my_begin_.fetch_add(1, std::memory_order_relaxed); + } + + // ------------------------ Waiter: Threads wait for tasks + + // WARNING: some `Wait*` do not set this for all Worker instances. For + // example, `WaitType::kBlock` only uses the first worker's `Waiter` because + // one futex can wake multiple waiters. Hence we never load this directly + // without going through `Wait*` policy classes, and must ensure all threads + // use the same wait mode. + + const std::atomic& Waiter() const { return wait_epoch_; } + std::atomic& MutableWaiter() { return wait_epoch_; } // futex + void StoreWaiter(uint32_t epoch) { wait_epoch_.store(epoch, kRel); } + + // ------------------------ Barrier: Main thread waits for workers + + const std::atomic& Barrier() const { return barrier_epoch_; } + std::atomic& MutableBarrier() { return barrier_epoch_; } + void StoreBarrier(uint32_t epoch) { barrier_epoch_.store(epoch, kRel); } + + private: + // Atomics first because arm7 clang otherwise makes them unaligned. + + // Set by `SetRange`: + alignas(8) std::atomic my_begin_; + alignas(8) std::atomic my_end_; + + // Use u32 to match futex.h. 
+ alignas(4) std::atomic wait_epoch_{0}; + alignas(4) std::atomic barrier_epoch_{0}; // is reset + + uint32_t num_victims_; // <= kPoolMaxVictims + std::array victims_; + + Config latched_; + + const size_t worker_; + const size_t num_threads_; + Worker* const workers_; + + uint8_t padding_[HWY_ALIGNMENT - 64 - sizeof(victims_)]; +}; +static_assert(sizeof(Worker) == HWY_ALIGNMENT, ""); + +#pragma pack(pop) + +// Creates/destroys `Worker` using preallocated storage. See comment at +// `ThreadPool::worker_bytes_` for why we do not dynamically allocate. +class WorkerLifecycle { // 0 bytes + public: + // Placement new for `Worker` into `storage` because its ctor requires + // the worker index. Returns array of all workers. + static Worker* Init(uint8_t* storage, size_t num_threads, + const Divisor64& div_workers) { + Worker* workers = new (storage) Worker(0, num_threads, div_workers); + for (size_t worker = 1; worker <= num_threads; ++worker) { + new (Addr(storage, worker)) Worker(worker, num_threads, div_workers); + // Ensure pointer arithmetic is the same (will be used in Destroy). + HWY_DASSERT(reinterpret_cast(workers + worker) == + reinterpret_cast(Addr(storage, worker))); + } + + // Publish non-atomic stores in `workers`. + std::atomic_thread_fence(std::memory_order_release); + + return workers; + } + + static void Destroy(Worker* workers, size_t num_threads) { + for (size_t worker = 0; worker <= num_threads; ++worker) { + workers[worker].~Worker(); + } + } + + private: + static uint8_t* Addr(uint8_t* storage, size_t worker) { + return storage + worker * sizeof(Worker); + } +}; + +#pragma pack(push, 1) +// Stores arguments to `Run`: the function and range of task indices. Set by +// the main thread, read by workers including the main thread. +class alignas(8) Tasks { + static constexpr auto kAcq = std::memory_order_acquire; + + // Signature of the (internal) function called from workers(s) for each + // `task` in the [`begin`, `end`) passed to Run(). Closures (lambdas) do not + // receive the first argument, which points to the lambda object. + typedef void (*RunFunc)(const void* opaque, uint64_t task, size_t worker); + + public: + Tasks() { HWY_DASSERT(IsAligned(this, 8)); } + + template + void Set(uint64_t begin, uint64_t end, const Closure& closure) { + constexpr auto kRel = std::memory_order_release; + // `TestTasks` and `SetWaitMode` call this with `begin == end`. + HWY_DASSERT(begin <= end); + begin_.store(begin, kRel); + end_.store(end, kRel); + func_.store(static_cast(&CallClosure), kRel); + opaque_.store(reinterpret_cast(&closure), kRel); + } + + // Assigns workers their share of `[begin, end)`. Called from the main + // thread; workers are initializing or spinning for a command. + static void DivideRangeAmongWorkers(const uint64_t begin, const uint64_t end, + const Divisor64& div_workers, + Worker* workers) { + const size_t num_workers = static_cast(div_workers.GetDivisor()); + HWY_DASSERT(num_workers > 1); // Else Run() runs on the main thread. + HWY_DASSERT(begin <= end); + const size_t num_tasks = static_cast(end - begin); + + // Assigning all remainders to the last worker causes imbalance. We instead + // give one more to each worker whose index is less. This may be zero when + // called from `TestTasks`. 
+ const size_t min_tasks = static_cast(div_workers.Divide(num_tasks)); + const size_t remainder = + static_cast(div_workers.Remainder(num_tasks)); + + uint64_t my_begin = begin; + for (size_t worker = 0; worker < num_workers; ++worker) { + const uint64_t my_end = my_begin + min_tasks + (worker < remainder); + workers[worker].SetRange(my_begin, my_end); + my_begin = my_end; + } + HWY_DASSERT(my_begin == end); + } + + // Runs the worker's assigned range of tasks, plus work stealing if needed. + HWY_POOL_PROFILE void WorkerRun(Worker* worker) const { + if (NumTasks() > worker->NumThreads() + 1) { + WorkerRunWithStealing(worker); + } else { + WorkerRunSingle(worker->Index()); + } + } + + private: + // Special case for <= 1 task per worker, where stealing is unnecessary. + void WorkerRunSingle(size_t worker) const { + const uint64_t begin = begin_.load(kAcq); + const uint64_t end = end_.load(kAcq); + HWY_DASSERT(begin <= end); + + const uint64_t task = begin + worker; + // We might still have more workers than tasks, so check first. + if (HWY_LIKELY(task < end)) { + const void* opaque = Opaque(); + const RunFunc func = Func(); + func(opaque, task, worker); + } + } + + // Must be called for each `worker` in [0, num_workers). + // + // A prior version of this code attempted to assign only as much work as a + // worker will actually use. As with OpenMP's 'guided' strategy, we assigned + // remaining/(k*num_threads) in each iteration. Although the worst-case + // imbalance is bounded, this required several rounds of work allocation, and + // the atomic counter did not scale to > 30 threads. + // + // We now use work stealing instead, where already-finished workers look for + // and perform work from others, as if they were that worker. This deals with + // imbalances as they arise, but care is required to reduce contention. We + // randomize the order in which threads choose victims to steal from. + HWY_POOL_PROFILE void WorkerRunWithStealing(Worker* worker) const { + Worker* workers = worker->AllWorkers(); + const size_t index = worker->Index(); + const RunFunc func = Func(); + const void* opaque = Opaque(); + + // For each worker in random order, starting with our own, attempt to do + // all their work. + for (uint32_t victim : worker->Victims()) { + Worker* other_worker = workers + victim; + + // Until all of other_worker's work is done: + const uint64_t other_end = other_worker->MyEnd(); + for (;;) { + // The worker that first sets `task` to `other_end` exits this loop. + // After that, `task` can be incremented up to `num_workers - 1` times, + // once per other worker. + const uint64_t task = other_worker->WorkerReserveTask(); + if (HWY_UNLIKELY(task >= other_end)) { + hwy::Pause(); // Reduce coherency traffic while stealing. + break; + } + // Pass the index we are actually running on; this is important + // because it is the TLS index for user code. + func(opaque, task, index); + } + } + } + + size_t NumTasks() const { + return static_cast(end_.load(kAcq) - begin_.load(kAcq)); + } + + const void* Opaque() const { return opaque_.load(kAcq); } + RunFunc Func() const { return func_.load(kAcq); } + + // Calls closure(task, worker). Signature must match `RunFunc`. 
+ template <class Closure> + static void CallClosure(const void* opaque, uint64_t task, size_t worker) { + (*reinterpret_cast<const Closure*>(opaque))(task, worker); + } + + std::atomic<uint64_t> begin_; + std::atomic<uint64_t> end_; + std::atomic<RunFunc> func_; + std::atomic<const void*> opaque_; +}; +static_assert(sizeof(Tasks) == 16 + 2 * sizeof(void*), ""); +#pragma pack(pop) + +// ------------------------------ Threads wait, main wakes them + +// Considerations: +// - uint32_t storage per `Worker` so we can use `futex.h`. +// - avoid atomic read-modify-write. These are implemented on x86 using a LOCK +// prefix, which interferes with other cores' cache-coherency transactions +// and drains our core's store buffer. We use only store-release and +// load-acquire. Although expressed using `std::atomic`, these are normal +// loads/stores in the strong x86 memory model. +// - prefer to avoid resetting the state. "Sense-reversing" (flipping a flag) +// would work, but we prefer an 'epoch' counter because it is more useful +// and easier to understand/debug, and as fast. + +// Both the main thread and each worker maintain their own counters, which are +// implicitly synchronized by the barrier. To wake, the main thread does a +// store-release, and each worker does a load-acquire. The policy classes differ +// in whether they block or spin (with pause/monitor to reduce power), and +// whether workers check their own counter or a shared one. +// +// All methods are const because they only use storage in `Worker`, and we +// prefer to pass const-references to empty classes to enable type deduction. + +// Futex: blocking reduces apparent CPU usage, but has higher wake latency. +struct WaitBlock { + WaitType Type() const { return WaitType::kBlock; } + + // Wakes all workers by storing the current `epoch`. + void WakeWorkers(Worker* workers, const uint32_t epoch) const { + HWY_DASSERT(epoch != 0); + workers[0].StoreWaiter(epoch); + WakeAll(workers[0].MutableWaiter()); // futex: expensive syscall + } + + // Waits until `WakeWorkers(_, epoch)` has been called. + template <class Spin> + void UntilWoken(const Worker* worker, const Spin& /*spin*/, + const uint32_t epoch) const { + BlockUntilDifferent(epoch - 1, worker->AllWorkers()->Waiter()); + } +}; + +// Single u32: single store by the main thread. All worker threads poll this +// one cache line and thus have it in a shared state, which means the store +// will invalidate each of them, leading to more transactions than SpinSeparate. +struct WaitSpin1 { + WaitType Type() const { return WaitType::kSpin1; } + + void WakeWorkers(Worker* workers, const uint32_t epoch) const { + workers[0].StoreWaiter(epoch); + } + + template <class Spin> + void UntilWoken(const Worker* worker, const Spin& spin, + const uint32_t epoch) const { + (void)spin.UntilEqual(epoch, worker->AllWorkers()->Waiter()); + // TODO: store reps in stats. + } +}; + +// Separate u32 per thread: more stores for the main thread, but each worker +// only polls its own cache line, leading to fewer cache-coherency transactions. +struct WaitSpinSeparate { + WaitType Type() const { return WaitType::kSpinSeparate; } + + void WakeWorkers(Worker* workers, const uint32_t epoch) const { + for (size_t thread = 0; thread < workers->NumThreads(); ++thread) { + workers[thread].StoreWaiter(epoch); + } + } + + template <class Spin> + void UntilWoken(const Worker* worker, const Spin& spin, + const uint32_t epoch) const { + (void)spin.UntilEqual(epoch, worker->Waiter()); + // TODO: store reps in stats. + } +}; + +// ------------------------------ Barrier: Main thread waits for workers + +// Single atomic counter.
TODO: remove if not competitive? +template +class BarrierCounter { + static_assert(kShards == 1 || kShards == 2 || kShards == 4, ""); // pow2 + + public: + BarrierType Type() const { + return kShards == 1 ? BarrierType::kCounter1 + : kShards == 2 ? BarrierType::kCounter2 + : BarrierType::kCounter4; + } + + void Reset(Worker* workers) const { + for (size_t i = 0; i < kShards; ++i) { + // Use last worker(s) to avoid contention with other stores to the Worker. + // Note that there are kMaxThreads + 1 workers, hence i == 0 is the last. + workers[kMaxThreads - i].StoreBarrier(0); + } + } + + template + void WorkerReached(Worker* worker, const Spin& /*spin*/, + uint32_t /*epoch*/) const { + const size_t shard = worker->Index() & (kShards - 1); + const auto kAcqRel = std::memory_order_acq_rel; + worker->AllWorkers()[kMaxThreads - shard].MutableBarrier().fetch_add( + 1, kAcqRel); + } + + template + void UntilReached(size_t num_threads, const Worker* workers, const Spin& spin, + uint32_t /*epoch*/) const { + HWY_IF_CONSTEXPR(kShards == 1) { + (void)spin.UntilEqual(static_cast(num_threads), + workers[kMaxThreads].Barrier()); + } + HWY_IF_CONSTEXPR(kShards == 2) { + const auto kAcq = std::memory_order_acquire; + for (;;) { + hwy::Pause(); + const uint64_t sum = workers[kMaxThreads - 0].Barrier().load(kAcq) + + workers[kMaxThreads - 1].Barrier().load(kAcq); + if (sum == num_threads) break; + } + } + HWY_IF_CONSTEXPR(kShards == 4) { + const auto kAcq = std::memory_order_acquire; + for (;;) { + hwy::Pause(); + const uint64_t sum = workers[kMaxThreads - 0].Barrier().load(kAcq) + + workers[kMaxThreads - 1].Barrier().load(kAcq) + + workers[kMaxThreads - 2].Barrier().load(kAcq) + + workers[kMaxThreads - 3].Barrier().load(kAcq); + if (sum == num_threads) break; + } + } + } +}; + +// As with the wait, a store-release of the same local epoch counter serves as a +// "have arrived" flag that does not require resetting. + +// Main thread loops over each worker. +class BarrierOrdered { + public: + BarrierType Type() const { return BarrierType::kOrdered; } + + void Reset(Worker* /*workers*/) const {} + + template + void WorkerReached(Worker* worker, const Spin&, uint32_t epoch) const { + worker->StoreBarrier(epoch); + } + + template + void UntilReached(size_t num_threads, const Worker* workers, const Spin& spin, + uint32_t epoch) const { + for (size_t i = 0; i < num_threads; ++i) { + (void)spin.UntilEqual(epoch, workers[i].Barrier()); + } + } +}; + +// Leader threads wait for others in the group, main thread loops over leaders. +template +class BarrierGroup { + public: + BarrierType Type() const { + return kGroupSize == 2 ? BarrierType::kGroup2 : BarrierType::kGroup4; + } + + void Reset(Worker* /*workers*/) const {} + + template + void WorkerReached(Worker* worker, const Spin& spin, uint32_t epoch) const { + const size_t thread = worker->Index(); + // Leaders wait for all others in their group before marking themselves. 
+ if (thread % kGroupSize == 0) { + for (size_t i = thread + 1; + i < HWY_MIN(thread + kGroupSize, worker->NumThreads()); ++i) { + (void)spin.UntilEqual(epoch, worker->AllWorkers()[i].Barrier()); + } + } + worker->StoreBarrier(epoch); + } + + template + void UntilReached(size_t num_threads, const Worker* workers, const Spin& spin, + uint32_t epoch) const { + for (size_t i = 0; i < num_threads; i += kGroupSize) { + (void)spin.UntilEqual(epoch, workers[i].Barrier()); + } + } +}; + +// ------------------------------ Inlining policy classes + +// We want to inline the various spin/wait/barrier policy classes into larger +// code sections because both the main and worker threads use two or three of +// them at a time, and we do not want separate branches around each. +// +// We generate code for three combinations of the enums, hence implement +// composable adapters that 'add' `Wait` and `Barrier` arguments. `spin.h` +// provides a `CallWithSpin`, hence it is the outermost. C++11 lacks generic +// lambdas, so we implement these as classes. +template +class FunctorAddWait { + public: + FunctorAddWait(WaitType wait_type, Func&& func) + : func_(std::forward(func)), wait_type_(wait_type) {} + + template + HWY_INLINE void operator()(const Spin& spin) { + switch (wait_type_) { + case WaitType::kBlock: + return func_(spin, WaitBlock()); + case WaitType::kSpin1: + return func_(spin, WaitSpin1()); + case WaitType::kSpinSeparate: + return func_(spin, WaitSpinSeparate()); + default: + HWY_UNREACHABLE; + } + } + + private: + Func&& func_; + WaitType wait_type_; +}; + +template +class FunctorAddBarrier { + public: + FunctorAddBarrier(BarrierType barrier_type, Func&& func) + : func_(std::forward(func)), barrier_type_(barrier_type) {} + + template + HWY_INLINE void operator()(const Wait& wait) { + switch (barrier_type_) { + case BarrierType::kOrdered: + return func_(wait, BarrierOrdered()); + case BarrierType::kCounter1: + return func_(wait, BarrierCounter<1>()); + case BarrierType::kCounter2: + return func_(wait, BarrierCounter<2>()); + case BarrierType::kCounter4: + return func_(wait, BarrierCounter<4>()); + case BarrierType::kGroup2: + return func_(wait, BarrierGroup<2>()); + case BarrierType::kGroup4: + return func_(wait, BarrierGroup<4>()); + default: + HWY_UNREACHABLE; + } + } + template + HWY_INLINE void operator()(const Spin& spin, const Wait& wait) { + switch (barrier_type_) { + case BarrierType::kOrdered: + return func_(spin, wait, BarrierOrdered()); + case BarrierType::kCounter1: + return func_(spin, wait, BarrierCounter<1>()); + case BarrierType::kCounter2: + return func_(spin, wait, BarrierCounter<2>()); + case BarrierType::kCounter4: + return func_(spin, wait, BarrierCounter<4>()); + case BarrierType::kGroup2: + return func_(spin, wait, BarrierGroup<2>()); + case BarrierType::kGroup4: + return func_(spin, wait, BarrierGroup<4>()); + default: + HWY_UNREACHABLE; + } + } + + private: + Func&& func_; + BarrierType barrier_type_; +}; + +// Calls unrolled code selected by all 3 enums. +template +HWY_INLINE void CallWithConfig(const Config& config, Func&& func) { + CallWithSpin( + config.spin_type, + FunctorAddWait>( + config.wait_type, FunctorAddBarrier(config.barrier_type, + std::forward(func)))); +} + +// For `WorkerAdapter`, `Spin` and `Wait`. +template +HWY_INLINE void CallWithSpinWait(const Config& config, Func&& func) { + CallWithSpin( + config.spin_type, + FunctorAddWait(config.wait_type, std::forward(func))); +} + +// For `WorkerAdapter`, only `Spin` and `Barrier`. 
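For clarity, a sketch of the kind of functor CallWithConfig above expects: its operator() receives one concrete object per policy, so each combination is inlined without per-call branches. PrintPolicies is hypothetical, not part of the pool:

#include <stdio.h>

struct PrintPolicies {
  template <class Spin, class Wait, class Barrier>
  void operator()(const Spin& spin, const Wait& wait,
                  const Barrier& barrier) const {
    printf("%s %s %s\n", hwy::ToString(spin.Type()),
           hwy::pool::ToString(wait.Type()),
           hwy::pool::ToString(barrier.Type()));
  }
};
// Usage, given a pool::Config named config:
//   hwy::pool::CallWithConfig(config, PrintPolicies());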
+template +HWY_INLINE void CallWithSpinBarrier(const Config& config, Func&& func) { + CallWithSpin( + config.spin_type, + FunctorAddBarrier(config.barrier_type, std::forward(func))); +} + +// ------------------------------ Adapters + +// Logic of the main and worker threads, again packaged as classes because +// C++11 lacks generic lambdas, called by `CallWith*`. + +class MainAdapter { + public: + MainAdapter(Worker* main, const Tasks* tasks) : main_(main), tasks_(tasks) {} + + void SetEpoch(uint32_t epoch) { epoch_ = epoch; } + + template + HWY_POOL_PROFILE void operator()(const Spin& spin, const Wait& wait, + const Barrier& barrier) const { + Worker* workers = main_->AllWorkers(); + const size_t num_threads = main_->NumThreads(); + barrier.Reset(workers); + + wait.WakeWorkers(workers, epoch_); + // Threads might still be starting up and wake up late, but we wait for + // them at the barrier below. + + // Also perform work on the main thread before the barrier. + tasks_->WorkerRun(main_); + + // Waits until all *threads* (not the main thread, because it already knows + // it is here) called `WorkerReached`. All `barrier` types use spinning. + + barrier.UntilReached(num_threads, workers, spin, epoch_); + + // Threads may already be waiting `UntilWoken`, which serves as the + // 'release' phase of the barrier. + } + + private: + Worker* const main_; + const Tasks* const tasks_; + uint32_t epoch_; +}; + +class WorkerAdapter { + public: + explicit WorkerAdapter(Worker* worker) : worker_(worker) {} + + void SetEpoch(uint32_t epoch) { epoch_ = epoch; } + + // Split into separate wait/barrier functions because `ThreadFunc` latches + // the config in between them. + template + void operator()(const Spin& spin, const Wait& wait) const { + wait.UntilWoken(worker_, spin, epoch_); + } + + template + void operator()(const Spin& spin, const Barrier& barrier) const { + barrier.WorkerReached(worker_, spin, epoch_); + } + + private: + Worker* const worker_; + uint32_t epoch_; +}; + +// Could also be a lambda in ThreadPool ctor, but this allows annotating with +// `HWY_POOL_PROFILE` so we can more easily inspect the generated code. +class ThreadFunc { + public: + ThreadFunc(Worker* worker, Tasks* tasks, Config config) + : worker_(worker), + tasks_(tasks), + config_(config), + worker_adapter_(worker_) { + worker->LatchConfig(config); + } + + HWY_POOL_PROFILE void operator()() { + SetThreadName("worker%03zu", static_cast(worker_->Index())); + + // Ensure main thread's writes are visible (synchronizes with fence in + // `WorkerLifecycle::Init`). + std::atomic_thread_fence(std::memory_order_acquire); + + // Initialization must match pre-increment in `MainAdapter::SetEpoch`. + // Loop termination is triggered by `~ThreadPool`. + for (uint32_t epoch = 1;; ++epoch) { + worker_adapter_.SetEpoch(epoch); + CallWithSpinWait(config_, worker_adapter_); + + // Must happen before `WorkerRun` because `SendConfig` writes it there. + config_ = worker_->LatchedConfig(); + + tasks_->WorkerRun(worker_); + + // Notify barrier after `WorkerRun`. + CallWithSpinBarrier(config_, worker_adapter_); + + // Check after notifying the barrier, otherwise the main thread deadlocks. 
+ if (HWY_UNLIKELY(config_.exit)) break; + } + } + + private: + Worker* const worker_; + Tasks* const tasks_; + + Config config_; + WorkerAdapter worker_adapter_; +}; + +} // namespace pool + +// Highly efficient parallel-for, intended for workloads with thousands of +// fork-join regions which consist of calling tasks[t](i) for a few hundred i, +// using dozens of threads. +// +// To reduce scheduling overhead, we assume that tasks are statically known and +// that threads do not schedule new work themselves. This allows us to avoid +// queues and only store a counter plus the current task. The latter is a +// pointer to a lambda function, without the allocation/indirection required for +// std::function. +// +// To reduce fork/join latency, we choose an efficient barrier, optionally +// enable spin-waits via SetWaitMode, and avoid any mutex/lock. We largely even +// avoid atomic RMW operations (LOCK prefix): currently for the wait and +// barrier, in future hopefully also for work stealing. +// +// To eliminate false sharing and enable reasoning about cache line traffic, the +// class is aligned and holds all worker state. +// +// For load-balancing, we use work stealing in random order. +class alignas(HWY_ALIGNMENT) ThreadPool { + public: + // This typically includes hyperthreads, hence it is a loose upper bound. + // -1 because these are in addition to the main thread. + static size_t MaxThreads() { + LogicalProcessorSet lps; + // This is OS dependent, but more accurate if available because it takes + // into account restrictions set by cgroups or numactl/taskset. + if (GetThreadAffinity(lps)) { + return lps.Count() - 1; + } + return static_cast(std::thread::hardware_concurrency() - 1); + } + + // `num_threads` is the number of *additional* threads to spawn, which should + // not exceed `MaxThreads()`. Note that the main thread also performs work. + explicit ThreadPool(size_t num_threads) + : have_timer_stop_(platform::HaveTimerStop(cpu100_)), + num_threads_(ClampedNumThreads(num_threads)), + div_workers_(num_threads_ + 1), + workers_(pool::WorkerLifecycle::Init(worker_bytes_, num_threads_, + div_workers_)), + main_adapter_(workers_ + num_threads_, &tasks_) { + // Leaves the default wait mode as `kBlock`, which means futex, because + // spinning only makes sense when threads are pinned and wake latency is + // important, so it must explicitly be requested by calling `SetWaitMode`. + for (PoolWaitMode mode : {PoolWaitMode::kSpin, PoolWaitMode::kBlock}) { + wait_mode_ = mode; // for AutoTuner + AutoTuner().SetCandidates( + pool::Config::AllCandidates(mode, num_threads_)); + } + config_ = AutoTuner().Candidates()[0]; + + threads_.reserve(num_threads_); + for (size_t thread = 0; thread < num_threads_; ++thread) { + threads_.emplace_back( + pool::ThreadFunc(workers_ + thread, &tasks_, config_)); + } + + // No barrier is required here because wakeup works regardless of the + // relative order of wake and wait. + } + + // Waits for all threads to exit. + ~ThreadPool() { + // There is no portable way to request threads to exit like `ExitThread` on + // Windows, otherwise we could call that from `Run`. Instead, we must cause + // the thread to wake up and exit. We can use the same `SendConfig` + // mechanism as `SetWaitMode`. 
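// Illustrative usage of this class (a sketch; `SumSquares` is hypothetical):
// each `Run` is one fork-join region, and `worker` is in [0, NumWorkers()),
// so it can index per-worker accumulators. Real code would usually make each
// task a *range* of elements so the per-task work amortizes the call overhead
// (see the comment on `Run`).
//
//   double SumSquares(const double* x, size_t n) {
//     hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
//     // Optional: spin-waits lower fork-join latency if threads are pinned.
//     // pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
//     std::vector<double> partial(pool.NumWorkers(), 0.0);
//     pool.Run(0, n, [&](uint64_t i, size_t worker) {
//       partial[worker] += x[i] * x[i];
//     });
//     double total = 0.0;
//     for (double p : partial) total += p;
//     return total;
//   }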
+ pool::Config copy = config_; + copy.exit = true; + SendConfig(copy); + + for (std::thread& thread : threads_) { + HWY_DASSERT(thread.joinable()); + thread.join(); + } + + pool::WorkerLifecycle::Destroy(workers_, num_threads_); + } + + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator&(const ThreadPool&) = delete; + + // Returns number of Worker, i.e., one more than the largest `worker` + // argument. Useful for callers that want to allocate thread-local storage. + size_t NumWorkers() const { + return static_cast(div_workers_.GetDivisor()); + } + + // `mode` defaults to `kBlock`, which means futex. Switching to `kSpin` + // reduces fork-join overhead especially when there are many calls to `Run`, + // but wastes power when waiting over long intervals. Must not be called + // concurrently with any `Run`, because this uses the same waiter/barrier. + void SetWaitMode(PoolWaitMode mode) { + wait_mode_ = mode; + SendConfig(AutoTuneComplete() ? *AutoTuner().Best() + : AutoTuner().NextConfig()); + } + + // For printing which are in use. + pool::Config config() const { return config_; } + + bool AutoTuneComplete() const { return AutoTuner().Best(); } + Span AutoTuneCosts() { return AutoTuner().Costs(); } + + // parallel-for: Runs `closure(task, worker)` on workers for every `task` in + // `[begin, end)`. Note that the unit of work should be large enough to + // amortize the function call overhead, but small enough that each worker + // processes a few tasks. Thus each `task` is usually a loop. + // + // Not thread-safe - concurrent parallel-for in the same `ThreadPool` are + // forbidden unless `NumWorkers() == 1` or `end <= begin + 1`. + template + void Run(uint64_t begin, uint64_t end, const Closure& closure) { + const size_t num_tasks = static_cast(end - begin); + const size_t num_workers = NumWorkers(); + + // If zero or one task, or no extra threads, run on the main thread without + // setting any member variables, because we may be re-entering Run. + if (HWY_UNLIKELY(num_tasks <= 1 || num_workers == 1)) { + for (uint64_t task = begin; task < end; ++task) { + closure(task, /*worker=*/0); + } + return; + } + + SetBusy(); + tasks_.Set(begin, end, closure); + + // More than one task per worker: use work stealing. + if (HWY_LIKELY(num_tasks > num_workers)) { + pool::Tasks::DivideRangeAmongWorkers(begin, end, div_workers_, workers_); + } + + main_adapter_.SetEpoch(++epoch_); + + AutoTuneT& auto_tuner = AutoTuner(); + if (HWY_LIKELY(auto_tuner.Best())) { + CallWithConfig(config_, main_adapter_); + ClearBusy(); + } else { + const uint64_t t0 = timer::Start(); + CallWithConfig(config_, main_adapter_); + const uint64_t t1 = have_timer_stop_ ? timer::Stop() : timer::Start(); + auto_tuner.NotifyCost(t1 - t0); + ClearBusy(); // before `SendConfig` + if (auto_tuner.Best()) { // just finished + HWY_IF_CONSTEXPR(pool::kVerbosity >= 1) { + const size_t idx_best = static_cast( + auto_tuner.Best() - auto_tuner.Candidates().data()); + HWY_DASSERT(idx_best < auto_tuner.Costs().size()); + auto& AT = auto_tuner.Costs()[idx_best]; + const double best_cost = AT.EstimateCost(); + HWY_DASSERT(best_cost > 0.0); // will divide by this below + + Stats s_ratio; + for (size_t i = 0; i < auto_tuner.Costs().size(); ++i) { + if (i == idx_best) continue; + const double cost = auto_tuner.Costs()[i].EstimateCost(); + s_ratio.Notify(static_cast(cost / best_cost)); + } + + fprintf(stderr, " %s %5.0f +/- %4.0f. 
Gain %.2fx [%.2fx, %.2fx]\n", + auto_tuner.Best()->ToString().c_str(), best_cost, AT.Stddev(), + s_ratio.GeometricMean(), s_ratio.Min(), s_ratio.Max()); + } + SendConfig(*auto_tuner.Best()); + } else { + HWY_IF_CONSTEXPR(pool::kVerbosity >= 2) { + fprintf(stderr, " %s %5lu\n", config_.ToString().c_str(), t1 - t0); + } + SendConfig(auto_tuner.NextConfig()); + } + } + } + + // Can pass this as init_closure when no initialization is needed. + // DEPRECATED, better to call the Run() overload without the init_closure arg. + static bool NoInit(size_t /*num_threads*/) { return true; } // DEPRECATED + + // DEPRECATED equivalent of NumWorkers. Note that this is not the same as the + // ctor argument because num_threads = 0 has the same effect as 1. + size_t NumThreads() const { return NumWorkers(); } // DEPRECATED + + // DEPRECATED prior interface with 32-bit tasks and first calling + // `init_closure(num_threads)`. Instead, perform any init before this, calling + // NumWorkers() for an upper bound on the worker index, then call the other + // overload of Run(). + template + bool Run(uint64_t begin, uint64_t end, const InitClosure& init_closure, + const RunClosure& run_closure) { + if (!init_closure(NumThreads())) return false; + Run(begin, end, run_closure); + return true; + } + + private: + // Used to initialize ThreadPool::num_threads_ from its ctor argument. + static size_t ClampedNumThreads(size_t num_threads) { + // Upper bound is required for `worker_bytes_`. + if (HWY_UNLIKELY(num_threads > pool::kMaxThreads)) { + HWY_WARN("ThreadPool: clamping num_threads %zu to %zu.", num_threads, + pool::kMaxThreads); + num_threads = pool::kMaxThreads; + } + return num_threads; + } + + // Debug-only re-entrancy detection. + void SetBusy() { HWY_DASSERT(!busy_.test_and_set()); } + void ClearBusy() { HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) busy_.clear(); } + + // Two-phase barrier protocol for sending `copy` to workers, similar to the + // 'quiescent state' used in RCU. + // + // Phase 1: + // - Main wakes threads using the old config. + // - Threads latch `copy` during `WorkerRun`. + // - Threads notify a barrier and wait for the next wake using the old config. + // + // Phase 2: + // - Main wakes threads still using the old config. + // - Threads switch their config to their latched `copy`. + // - Threads notify a barrier and wait, BOTH with the new config. + // - Main thread switches to `copy` for the next wake. + HWY_NOINLINE void SendConfig(pool::Config copy) { + if (NumWorkers() == 1) { + config_ = copy; + return; + } + + SetBusy(); + + const auto closure = [this, copy](uint64_t task, size_t worker) { + (void)task; + HWY_DASSERT(task == worker); // one task per worker + workers_[worker].LatchConfig(copy); + }; + tasks_.Set(0, NumWorkers(), closure); + // Same config as workers are *currently* using. + main_adapter_.SetEpoch(++epoch_); + CallWithConfig(config_, main_adapter_); + // All workers have latched `copy` and are waiting with the old config. + + // No-op task; will not be called because begin == end. + tasks_.Set(0, 0, [](uint64_t /*task*/, size_t /*worker*/) {}); + // Threads are waiting using the old config, but will switch after waking, + // which means we must already use the new barrier. + pool::Config new_barrier = config_; + new_barrier.barrier_type = copy.barrier_type; + main_adapter_.SetEpoch(++epoch_); + CallWithConfig(new_barrier, main_adapter_); + // All have woken and are, or will be, waiting per the *new* config. 
Now we + // can entirely switch the main thread's config for the next wake. + config_ = copy; + + ClearBusy(); + } + + using AutoTuneT = AutoTune; + AutoTuneT& AutoTuner() { + static_assert(static_cast(PoolWaitMode::kBlock) == 1, ""); + return auto_tune_[static_cast(wait_mode_) - 1]; + } + const AutoTuneT& AutoTuner() const { + return auto_tune_[static_cast(wait_mode_) - 1]; + } + + char cpu100_[100]; + const bool have_timer_stop_; + const size_t num_threads_; // not including main thread + const Divisor64 div_workers_; + pool::Worker* const workers_; // points into `worker_bytes_` + + pool::MainAdapter main_adapter_; + + // The only mutable state: + pool::Tasks tasks_; // written by `Run` and read by workers. + pool::Config config_; // for use by the next `Run`. Updated via `SendConfig`. + uint32_t epoch_ = 0; // passed to `MainAdapter`. + + // In debug builds, detects if functions are re-entered. + std::atomic_flag busy_ = ATOMIC_FLAG_INIT; + + // Unmodified after ctor, but cannot be const because we call thread::join(). + std::vector threads_; + + PoolWaitMode wait_mode_; + AutoTuneT auto_tune_[2]; // accessed via `AutoTuner` + + // Last because it is large. Store inside `ThreadPool` so that callers can + // bind it to the NUMA node's memory. Not stored inside `WorkerLifecycle` + // because that class would be initialized after `workers_`. + alignas(HWY_ALIGNMENT) uint8_t + worker_bytes_[sizeof(pool::Worker) * (pool::kMaxThreads + 1)]; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_THREAD_POOL_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/thread_pool/topology.h b/third_party/aom/third_party/highway/hwy/contrib/thread_pool/topology.h new file mode 100644 index 000000000000..dec8bbc998ea --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/thread_pool/topology.h @@ -0,0 +1,141 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_TOPOLOGY_H_ +#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_TOPOLOGY_H_ + +// OS-specific functions for processor topology and thread affinity. + +#include + +#include + +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/bit_set.h" + +namespace hwy { + +// Returns false if std::thread should not be used. +HWY_CONTRIB_DLLEXPORT bool HaveThreadingSupport(); + +// Upper bound on logical processors, including hyperthreads. +static constexpr size_t kMaxLogicalProcessors = 1024; // matches glibc + +// Set used by Get/SetThreadAffinity. +using LogicalProcessorSet = BitSet4096; + +// Returns false, or sets `lps` to all logical processors which are online and +// available to the current thread. +HWY_CONTRIB_DLLEXPORT bool GetThreadAffinity(LogicalProcessorSet& lps); + +// Ensures the current thread can only run on the logical processors in `lps`. 
+// Returns false if not supported (in particular on Apple), or if the +// intersection between `lps` and `GetThreadAffinity` is the empty set. +HWY_CONTRIB_DLLEXPORT bool SetThreadAffinity(const LogicalProcessorSet& lps); + +// Returns false, or ensures the current thread will only run on `lp`, which +// must not exceed `TotalLogicalProcessors`. Note that this merely calls +// `SetThreadAffinity`, see the comment there. +static inline bool PinThreadToLogicalProcessor(size_t lp) { + LogicalProcessorSet lps; + lps.Set(lp); + return SetThreadAffinity(lps); +} + +// Returns 1 if unknown, otherwise the total number of logical processors +// provided by the hardware clamped to `kMaxLogicalProcessors`. +// These processors are not necessarily all usable; you can determine which are +// via GetThreadAffinity(). +HWY_CONTRIB_DLLEXPORT size_t TotalLogicalProcessors(); + +struct Topology { + // Caller must check packages.empty(); if so, do not use any fields. + HWY_CONTRIB_DLLEXPORT Topology(); + + // Clique of cores with lower latency to each other. On Apple M1 these are + // four cores sharing an L2. On Zen4 these 'CCX' are up to eight cores sharing + // an L3 and a memory controller, or for Zen4c up to 16 and half the L3 size. + struct Cluster { + LogicalProcessorSet lps; + uint64_t private_kib = 0; // 0 if unknown + uint64_t shared_kib = 0; // 0 if unknown + uint64_t reserved1 = 0; + uint64_t reserved2 = 0; + uint64_t reserved3 = 0; + }; + + struct Core { + LogicalProcessorSet lps; + uint64_t reserved = 0; + }; + + struct Package { + std::vector clusters; + std::vector cores; + }; + + std::vector packages; + + // Several hundred instances, so prefer a compact representation. +#pragma pack(push, 1) + struct LP { + uint16_t cluster = 0; // < packages[package].clusters.size() + uint16_t core = 0; // < packages[package].cores.size() + uint8_t package = 0; // < packages.size() + uint8_t smt = 0; // < packages[package].cores[core].lps.Count() + uint8_t node = 0; + + uint8_t reserved = 0; + }; +#pragma pack(pop) + std::vector lps; // size() == TotalLogicalProcessors(). +}; + +#pragma pack(push, 1) +// Cache parameters. Note the overlap with `HWY_ALIGNMENT`, which is intended +// but not guaranteed to be an upper bound for L1/L2 line sizes, and +// `Topology::Cluster::private_kib/shared_kib`, which are intended but not +// guaranteed to be the L2/L3 sizes. Getting the exact parameters, including the +// ways of associativity, can be useful for modeling cache conflicts. +// +// Uses packed fields so the array of `Cache` fits in a typical cache line. +struct Cache { + // Arbitrary upper bound for sanity checking. + static constexpr uint16_t kMaxAssociativity = 128; + + // Zero if the level does not exist; *per-core* portion for shared caches. + uint32_t size_kib = 0; + // Also per-core portion, computed as number of lines / associativity. + uint32_t sets = 0; + uint16_t bytes_per_line = 0; + uint16_t associativity = 0; // number of ways + uint16_t cores_sharing = 0; // usually 1 for L1 + uint16_t reserved = 0; +}; +static_assert(sizeof(Cache) == 16, "Unexpected size"); +#pragma pack(pop) + +// Returns null if unknown, otherwise pointer to an array of `Cache` instances, +// where entry 0 is reserved, entry 1 describes the L1 data cache, entry 2 +// describes the (possibly unified or shared) L2, and entry 3 describes the L3 +// if its `size_kib != 0`. +// +// Initializes on-demand, which has some overhead for thread safety, hence +// callers should cache the result. 
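// Illustrative use of this API (the helpers `NumPhysicalCores` and
// `L1LineBytes` are hypothetical): count physical cores via `Topology` and
// read the L1 line size via `DataCaches`, with fallbacks when either is
// unknown. Real code should cache both results, since constructing `Topology`
// queries the OS.
//
//   size_t NumPhysicalCores() {
//     const hwy::Topology topology;
//     if (topology.packages.empty()) return hwy::TotalLogicalProcessors();
//     size_t cores = 0;
//     for (const hwy::Topology::Package& p : topology.packages) {
//       cores += p.cores.size();
//     }
//     return cores;
//   }
//
//   size_t L1LineBytes() {
//     const hwy::Cache* caches = hwy::DataCaches();  // null if unknown
//     return caches ? caches[1].bytes_per_line : HWY_ALIGNMENT;  // entry 1 = L1D
//   }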
+HWY_CONTRIB_DLLEXPORT const Cache* DataCaches(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_TOPOLOGY_H_ diff --git a/third_party/aom/third_party/highway/hwy/contrib/unroller/unroller-inl.h b/third_party/aom/third_party/highway/hwy/contrib/unroller/unroller-inl.h new file mode 100644 index 000000000000..7008e7ef4135 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/contrib/unroller/unroller-inl.h @@ -0,0 +1,473 @@ +// Copyright 2023 Matthew Kolbe +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ +#endif + +#include // std::abs + +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace hn = hwy::HWY_NAMESPACE; + +template +struct UnrollerUnit { + static constexpr size_t kMaxTSize = HWY_MAX(sizeof(IN_T), sizeof(OUT_T)); + using LargerT = SignedFromSize; // only the size matters. + + DERIVED* me() { return static_cast(this); } + + static constexpr size_t MaxUnitLanes() { + return HWY_MAX_LANES_D(hn::ScalableTag); + } + static size_t ActualLanes() { return Lanes(hn::ScalableTag()); } + + using LargerD = hn::CappedTag; + using IT = hn::Rebind; + using OT = hn::Rebind; + IT d_in; + OT d_out; + using Y_VEC = hn::Vec; + using X_VEC = hn::Vec; + + Y_VEC Func(const ptrdiff_t idx, const X_VEC x, const Y_VEC y) { + return me()->Func(idx, x, y); + } + + X_VEC X0Init() { return me()->X0InitImpl(); } + + X_VEC X0InitImpl() { return hn::Zero(d_in); } + + Y_VEC YInit() { return me()->YInitImpl(); } + + Y_VEC YInitImpl() { return hn::Zero(d_out); } + + X_VEC Load(const ptrdiff_t idx, const IN_T* from) { + return me()->LoadImpl(idx, from); + } + + X_VEC LoadImpl(const ptrdiff_t idx, const IN_T* from) { + return hn::LoadU(d_in, from + idx); + } + + // MaskLoad can take in either a positive or negative number for `places`. if + // the number is positive, then it loads the top `places` values, and if it's + // negative, it loads the bottom |places| values. 
example: places = 3 + // | o | o | o | x | x | x | x | x | + // example places = -3 + // | x | x | x | x | x | o | o | o | + X_VEC MaskLoad(const ptrdiff_t idx, const IN_T* from, + const ptrdiff_t places) { + return me()->MaskLoadImpl(idx, from, places); + } + + X_VEC MaskLoadImpl(const ptrdiff_t idx, const IN_T* from, + const ptrdiff_t places) { + auto mask = hn::FirstN(d_in, static_cast(places)); + auto maskneg = hn::Not(hn::FirstN( + d_in, + static_cast(places + static_cast(ActualLanes())))); + if (places < 0) mask = maskneg; + + return hn::MaskedLoad(mask, d_in, from + idx); + } + + bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { + return me()->StoreAndShortCircuitImpl(idx, to, x); + } + + bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { + hn::StoreU(x, d_out, to + idx); + return true; + } + + ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, + ptrdiff_t const places) { + return me()->MaskStoreImpl(idx, to, x, places); + } + + ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, + const ptrdiff_t places) { + auto mask = hn::FirstN(d_out, static_cast(places)); + auto maskneg = hn::Not(hn::FirstN( + d_out, + static_cast(places + static_cast(ActualLanes())))); + if (places < 0) mask = maskneg; + + hn::BlendedStore(x, mask, d_out, to + idx); + return std::abs(places); + } + + ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); } + + ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) { + // default does nothing + (void)x; + (void)to; + return 0; + } + + void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { + me()->ReduceImpl(x0, x1, x2, y); + } + + void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { + // default does nothing + (void)x0; + (void)x1; + (void)x2; + (void)y; + } +}; + +template +struct UnrollerUnit2D { + DERIVED* me() { return static_cast(this); } + + static constexpr size_t kMaxTSize = + HWY_MAX(sizeof(IN0_T), HWY_MAX(sizeof(IN1_T), sizeof(OUT_T))); + using LargerT = SignedFromSize; // only the size matters. + + static constexpr size_t MaxUnitLanes() { + return HWY_MAX_LANES_D(hn::ScalableTag); + } + static size_t ActualLanes() { return Lanes(hn::ScalableTag()); } + + using LargerD = hn::CappedTag; + + using I0T = hn::Rebind; + using I1T = hn::Rebind; + using OT = hn::Rebind; + I0T d_in0; + I1T d_in1; + OT d_out; + using Y_VEC = hn::Vec; + using X0_VEC = hn::Vec; + using X1_VEC = hn::Vec; + + hn::Vec Func(const ptrdiff_t idx, const hn::Vec x0, + const hn::Vec x1, const Y_VEC y) { + return me()->Func(idx, x0, x1, y); + } + + X0_VEC X0Init() { return me()->X0InitImpl(); } + + X0_VEC X0InitImpl() { return hn::Zero(d_in0); } + + X1_VEC X1Init() { return me()->X1InitImpl(); } + + X1_VEC X1InitImpl() { return hn::Zero(d_in1); } + + Y_VEC YInit() { return me()->YInitImpl(); } + + Y_VEC YInitImpl() { return hn::Zero(d_out); } + + X0_VEC Load0(const ptrdiff_t idx, const IN0_T* from) { + return me()->Load0Impl(idx, from); + } + + X0_VEC Load0Impl(const ptrdiff_t idx, const IN0_T* from) { + return hn::LoadU(d_in0, from + idx); + } + + X1_VEC Load1(const ptrdiff_t idx, const IN1_T* from) { + return me()->Load1Impl(idx, from); + } + + X1_VEC Load1Impl(const ptrdiff_t idx, const IN1_T* from) { + return hn::LoadU(d_in1, from + idx); + } + + // maskload can take in either a positive or negative number for `places`. 
if + // the number is positive, then it loads the top `places` values, and if it's + // negative, it loads the bottom |places| values. example: places = 3 + // | o | o | o | x | x | x | x | x | + // example places = -3 + // | x | x | x | x | x | o | o | o | + X0_VEC MaskLoad0(const ptrdiff_t idx, const IN0_T* from, + const ptrdiff_t places) { + return me()->MaskLoad0Impl(idx, from, places); + } + + X0_VEC MaskLoad0Impl(const ptrdiff_t idx, const IN0_T* from, + const ptrdiff_t places) { + auto mask = hn::FirstN(d_in0, static_cast(places)); + auto maskneg = hn::Not(hn::FirstN( + d_in0, + static_cast(places + static_cast(ActualLanes())))); + if (places < 0) mask = maskneg; + + return hn::MaskedLoad(mask, d_in0, from + idx); + } + + hn::Vec MaskLoad1(const ptrdiff_t idx, const IN1_T* from, + const ptrdiff_t places) { + return me()->MaskLoad1Impl(idx, from, places); + } + + hn::Vec MaskLoad1Impl(const ptrdiff_t idx, const IN1_T* from, + const ptrdiff_t places) { + auto mask = hn::FirstN(d_in1, static_cast(places)); + auto maskneg = hn::Not(hn::FirstN( + d_in1, + static_cast(places + static_cast(ActualLanes())))); + if (places < 0) mask = maskneg; + + return hn::MaskedLoad(mask, d_in1, from + idx); + } + + // store returns a bool that is `false` when + bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { + return me()->StoreAndShortCircuitImpl(idx, to, x); + } + + bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { + hn::StoreU(x, d_out, to + idx); + return true; + } + + ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, + const ptrdiff_t places) { + return me()->MaskStoreImpl(idx, to, x, places); + } + + ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, + const ptrdiff_t places) { + auto mask = hn::FirstN(d_out, static_cast(places)); + auto maskneg = hn::Not(hn::FirstN( + d_out, + static_cast(places + static_cast(ActualLanes())))); + if (places < 0) mask = maskneg; + + hn::BlendedStore(x, mask, d_out, to + idx); + return std::abs(places); + } + + ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); } + + ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) { + // default does nothing + (void)x; + (void)to; + return 0; + } + + void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { + me()->ReduceImpl(x0, x1, x2, y); + } + + void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { + // default does nothing + (void)x0; + (void)x1; + (void)x2; + (void)y; + } +}; + +template +inline void Unroller(FUNC& f, const IN_T* HWY_RESTRICT x, OUT_T* HWY_RESTRICT y, + const ptrdiff_t n) { + auto xx = f.X0Init(); + auto yy = f.YInit(); + ptrdiff_t i = 0; + +#if HWY_MEM_OPS_MIGHT_FAULT + constexpr auto lane_sz = + static_cast(RemoveRef::MaxUnitLanes()); + if (n < lane_sz) { + const DFromV d; + // this may not fit on the stack for HWY_RVV, but we do not reach this code + // there + HWY_ALIGN IN_T xtmp[static_cast(lane_sz)]; + HWY_ALIGN OUT_T ytmp[static_cast(lane_sz)]; + + CopyBytes(x, xtmp, static_cast(n) * sizeof(IN_T)); + xx = f.MaskLoad(0, xtmp, n); + yy = f.Func(0, xx, yy); + Store(Zero(d), d, ytmp); + i += f.MaskStore(0, ytmp, yy, n); + i += f.Reduce(yy, ytmp); + CopyBytes(ytmp, y, static_cast(i) * sizeof(OUT_T)); + return; + } +#endif + + const ptrdiff_t actual_lanes = + static_cast(RemoveRef::ActualLanes()); + if (n > 4 * actual_lanes) { + auto xx1 = f.X0Init(); + auto yy1 = f.YInit(); + auto xx2 = f.X0Init(); + auto yy2 = f.YInit(); + auto xx3 = f.X0Init(); + auto yy3 
= f.YInit(); + + while (i + 4 * actual_lanes - 1 < n) { + xx = f.Load(i, x); + i += actual_lanes; + xx1 = f.Load(i, x); + i += actual_lanes; + xx2 = f.Load(i, x); + i += actual_lanes; + xx3 = f.Load(i, x); + i -= 3 * actual_lanes; + + yy = f.Func(i, xx, yy); + yy1 = f.Func(i + actual_lanes, xx1, yy1); + yy2 = f.Func(i + 2 * actual_lanes, xx2, yy2); + yy3 = f.Func(i + 3 * actual_lanes, xx3, yy3); + + if (!f.StoreAndShortCircuit(i, y, yy)) return; + i += actual_lanes; + if (!f.StoreAndShortCircuit(i, y, yy1)) return; + i += actual_lanes; + if (!f.StoreAndShortCircuit(i, y, yy2)) return; + i += actual_lanes; + if (!f.StoreAndShortCircuit(i, y, yy3)) return; + i += actual_lanes; + } + + f.Reduce(yy3, yy2, yy1, &yy); + } + + while (i + actual_lanes - 1 < n) { + xx = f.Load(i, x); + yy = f.Func(i, xx, yy); + if (!f.StoreAndShortCircuit(i, y, yy)) return; + i += actual_lanes; + } + + if (i != n) { + xx = f.MaskLoad(n - actual_lanes, x, i - n); + yy = f.Func(n - actual_lanes, xx, yy); + f.MaskStore(n - actual_lanes, y, yy, i - n); + } + + f.Reduce(yy, y); +} + +template +inline void Unroller(FUNC& HWY_RESTRICT f, IN0_T* HWY_RESTRICT x0, + IN1_T* HWY_RESTRICT x1, OUT_T* HWY_RESTRICT y, + const ptrdiff_t n) { + const ptrdiff_t lane_sz = + static_cast(RemoveRef::ActualLanes()); + + auto xx00 = f.X0Init(); + auto xx10 = f.X1Init(); + auto yy = f.YInit(); + + ptrdiff_t i = 0; + +#if HWY_MEM_OPS_MIGHT_FAULT + if (n < lane_sz) { + const DFromV d; + // this may not fit on the stack for HWY_RVV, but we do not reach this code + // there + constexpr auto max_lane_sz = + static_cast(RemoveRef::MaxUnitLanes()); + HWY_ALIGN IN0_T xtmp0[static_cast(max_lane_sz)]; + HWY_ALIGN IN1_T xtmp1[static_cast(max_lane_sz)]; + HWY_ALIGN OUT_T ytmp[static_cast(max_lane_sz)]; + + CopyBytes(x0, xtmp0, static_cast(n) * sizeof(IN0_T)); + CopyBytes(x1, xtmp1, static_cast(n) * sizeof(IN1_T)); + xx00 = f.MaskLoad0(0, xtmp0, n); + xx10 = f.MaskLoad1(0, xtmp1, n); + yy = f.Func(0, xx00, xx10, yy); + Store(Zero(d), d, ytmp); + i += f.MaskStore(0, ytmp, yy, n); + i += f.Reduce(yy, ytmp); + CopyBytes(ytmp, y, static_cast(i) * sizeof(OUT_T)); + return; + } +#endif + + if (n > 4 * lane_sz) { + auto xx01 = f.X0Init(); + auto xx11 = f.X1Init(); + auto yy1 = f.YInit(); + auto xx02 = f.X0Init(); + auto xx12 = f.X1Init(); + auto yy2 = f.YInit(); + auto xx03 = f.X0Init(); + auto xx13 = f.X1Init(); + auto yy3 = f.YInit(); + + while (i + 4 * lane_sz - 1 < n) { + xx00 = f.Load0(i, x0); + xx10 = f.Load1(i, x1); + i += lane_sz; + xx01 = f.Load0(i, x0); + xx11 = f.Load1(i, x1); + i += lane_sz; + xx02 = f.Load0(i, x0); + xx12 = f.Load1(i, x1); + i += lane_sz; + xx03 = f.Load0(i, x0); + xx13 = f.Load1(i, x1); + i -= 3 * lane_sz; + + yy = f.Func(i, xx00, xx10, yy); + yy1 = f.Func(i + lane_sz, xx01, xx11, yy1); + yy2 = f.Func(i + 2 * lane_sz, xx02, xx12, yy2); + yy3 = f.Func(i + 3 * lane_sz, xx03, xx13, yy3); + + if (!f.StoreAndShortCircuit(i, y, yy)) return; + i += lane_sz; + if (!f.StoreAndShortCircuit(i, y, yy1)) return; + i += lane_sz; + if (!f.StoreAndShortCircuit(i, y, yy2)) return; + i += lane_sz; + if (!f.StoreAndShortCircuit(i, y, yy3)) return; + i += lane_sz; + } + + f.Reduce(yy3, yy2, yy1, &yy); + } + + while (i + lane_sz - 1 < n) { + xx00 = f.Load0(i, x0); + xx10 = f.Load1(i, x1); + yy = f.Func(i, xx00, xx10, yy); + if (!f.StoreAndShortCircuit(i, y, yy)) return; + i += lane_sz; + } + + if (i != n) { + xx00 = f.MaskLoad0(n - lane_sz, x0, i - n); + xx10 = f.MaskLoad1(n - lane_sz, x1, i - n); + yy = f.Func(n - lane_sz, xx00, xx10, yy); + 
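// Illustrative usage of the units above (a sketch; `MulUnit` is hypothetical,
// the base-class argument order <Derived, In0, In1, Out> is assumed, and the
// code is assumed to live in the same per-target namespace so that `hn::` and
// the unroller templates are visible): an element-wise product
// out[i] = a[i] * b[i].
//
//   struct MulUnit : UnrollerUnit2D<MulUnit, float, float, float> {
//     Y_VEC Func(ptrdiff_t /*idx*/, X0_VEC x0, X1_VEC x1, Y_VEC /*y*/) {
//       return hn::Mul(x0, x1);
//     }
//   };
//
//   // float a[1000], b[1000], out[1000]; ...
//   // MulUnit unit;
//   // Unroller(unit, a, b, out, 1000);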
f.MaskStore(n - lane_sz, y, yy, i - n); + } + + f.Reduce(yy, y); +} + +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ diff --git a/third_party/aom/third_party/highway/hwy/detect_compiler_arch.h b/third_party/aom/third_party/highway/hwy/detect_compiler_arch.h new file mode 100644 index 000000000000..9d4d56b0a053 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/detect_compiler_arch.h @@ -0,0 +1,395 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ +#define HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ + +// Detects compiler and arch from predefined macros. Zero dependencies for +// inclusion by foreach_target.h. + +// Add to #if conditions to prevent IDE from graying out code. +#if (defined __CDT_PARSER__) || (defined __INTELLISENSE__) || \ + (defined Q_CREATOR_RUN) || (defined __CLANGD__) || \ + (defined GROK_ELLIPSIS_BUILD) +#define HWY_IDE 1 +#else +#define HWY_IDE 0 +#endif + +//------------------------------------------------------------------------------ +// Compiler + +// Actual MSVC, not clang-cl, which defines _MSC_VER but doesn't behave like +// MSVC in other aspects (e.g. HWY_DIAGNOSTICS). +#if defined(_MSC_VER) && !defined(__clang__) +#define HWY_COMPILER_MSVC _MSC_VER +#else +#define HWY_COMPILER_MSVC 0 +#endif + +#if defined(_MSC_VER) && defined(__clang__) +#define HWY_COMPILER_CLANGCL _MSC_VER +#else +#define HWY_COMPILER_CLANGCL 0 +#endif + +#ifdef __INTEL_COMPILER +#define HWY_COMPILER_ICC __INTEL_COMPILER +#else +#define HWY_COMPILER_ICC 0 +#endif + +#ifdef __INTEL_LLVM_COMPILER +#define HWY_COMPILER_ICX __INTEL_LLVM_COMPILER +#else +#define HWY_COMPILER_ICX 0 +#endif + +// HWY_COMPILER_GCC is a generic macro for all compilers implementing the GNU +// compiler extensions (eg. Clang, Intel...) +#ifdef __GNUC__ +#define HWY_COMPILER_GCC (__GNUC__ * 100 + __GNUC_MINOR__) +#else +#define HWY_COMPILER_GCC 0 +#endif + +// Clang or clang-cl, not GCC. +#ifdef __clang__ +// In case of Apple LLVM (whose version number is unrelated to that of LLVM) or +// an invalid version number, deduce it from the presence of warnings. +// Originally based on +// https://github.com/simd-everywhere/simde/blob/47d6e603de9d04ee05cdfbc57cf282a02be1bf2a/simde/simde-detect-clang.h#L59. +// Please send updates below to them as well, thanks! +#if defined(__apple_build_version__) || __clang_major__ >= 999 +#if __has_warning("-Woverriding-option") +#define HWY_COMPILER_CLANG 1801 +// No new warnings in 17.0, and Apple LLVM 15.3, which should be 1600, already +// has the unsafe_buffer_usage attribute, so we instead check for new builtins. 
+#elif __has_builtin(__builtin_nondeterministic_value) +#define HWY_COMPILER_CLANG 1700 +#elif __has_attribute(nouwtable) // no new warnings in 16.0 +#define HWY_COMPILER_CLANG 1600 +#elif __has_warning("-Warray-parameter") +#define HWY_COMPILER_CLANG 1500 +#elif __has_warning("-Wbitwise-instead-of-logical") +#define HWY_COMPILER_CLANG 1400 +#elif __has_warning("-Wreserved-identifier") +#define HWY_COMPILER_CLANG 1300 +#elif __has_warning("-Wformat-insufficient-args") +#define HWY_COMPILER_CLANG 1200 +#elif __has_warning("-Wimplicit-const-int-float-conversion") +#define HWY_COMPILER_CLANG 1100 +#elif __has_warning("-Wmisleading-indentation") +#define HWY_COMPILER_CLANG 1000 +#elif defined(__FILE_NAME__) +#define HWY_COMPILER_CLANG 900 +#elif __has_warning("-Wextra-semi-stmt") || \ + __has_builtin(__builtin_rotateleft32) +#define HWY_COMPILER_CLANG 800 +// For reasons unknown, XCode 10.3 (Apple LLVM version 10.0.1) is apparently +// based on Clang 7, but does not support the warning we test. +// See https://en.wikipedia.org/wiki/Xcode#Toolchain_versions and +// https://trac.macports.org/wiki/XcodeVersionInfo. +#elif __has_warning("-Wc++98-compat-extra-semi") || \ + (defined(__apple_build_version__) && __apple_build_version__ >= 10010000) +#define HWY_COMPILER_CLANG 700 +#else // Anything older than 7.0 is not recommended for Highway. +#define HWY_COMPILER_CLANG 600 +#endif // __has_warning chain +#define HWY_COMPILER3_CLANG (HWY_COMPILER_CLANG * 100) +#else // use normal version +#define HWY_COMPILER_CLANG (__clang_major__ * 100 + __clang_minor__) +#define HWY_COMPILER3_CLANG \ + (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) +#endif +#else // Not clang +#define HWY_COMPILER_CLANG 0 +#define HWY_COMPILER3_CLANG 0 +#endif + +#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG && !HWY_COMPILER_ICC && \ + !HWY_COMPILER_ICX +#define HWY_COMPILER_GCC_ACTUAL HWY_COMPILER_GCC +#else +#define HWY_COMPILER_GCC_ACTUAL 0 +#endif + +// More than one may be nonzero, but we want at least one. +#if 0 == (HWY_COMPILER_MSVC + HWY_COMPILER_CLANGCL + HWY_COMPILER_ICC + \ + HWY_COMPILER_ICX + HWY_COMPILER_GCC + HWY_COMPILER_CLANG) +#error "Unsupported compiler" +#endif + +// We should only detect one of these (only clang/clangcl/icx overlap) +#if 1 < (!!HWY_COMPILER_MSVC + (!!HWY_COMPILER_ICC & !HWY_COMPILER_ICX) + \ + !!HWY_COMPILER_GCC_ACTUAL + \ + !!(HWY_COMPILER_ICX | HWY_COMPILER_CLANGCL | HWY_COMPILER_CLANG)) +#error "Detected multiple compilers" +#endif + +//------------------------------------------------------------------------------ +// Compiler features and C++ version + +#ifdef __has_builtin +#define HWY_HAS_BUILTIN(name) __has_builtin(name) +#else +#define HWY_HAS_BUILTIN(name) 0 +#endif + +#ifdef __has_attribute +#define HWY_HAS_ATTRIBUTE(name) __has_attribute(name) +#else +#define HWY_HAS_ATTRIBUTE(name) 0 +#endif + +#ifdef __has_cpp_attribute +#define HWY_HAS_CPP_ATTRIBUTE(name) __has_cpp_attribute(name) +#else +#define HWY_HAS_CPP_ATTRIBUTE(name) 0 +#endif + +#ifdef __has_feature +#define HWY_HAS_FEATURE(name) __has_feature(name) +#else +#define HWY_HAS_FEATURE(name) 0 +#endif + +// NOTE: clang ~17 does not correctly handle wrapping __has_include in a macro. 
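// Illustrative use of the feature-test wrappers above (the helper macro
// MY_LIKELY is hypothetical, not part of Highway): because the wrappers
// expand to 0 on compilers lacking __has_builtin, the check below is safe on
// every supported compiler.
//
//   #if HWY_HAS_BUILTIN(__builtin_expect)
//   #define MY_LIKELY(expr) __builtin_expect(!!(expr), 1)
//   #else
//   #define MY_LIKELY(expr) (expr)
//   #endif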
+ +#if HWY_COMPILER_MSVC && defined(_MSVC_LANG) && _MSVC_LANG > __cplusplus +#define HWY_CXX_LANG _MSVC_LANG +#else +#define HWY_CXX_LANG __cplusplus +#endif + +#if defined(__cpp_constexpr) && __cpp_constexpr >= 201603L +#define HWY_CXX17_CONSTEXPR constexpr +#else +#define HWY_CXX17_CONSTEXPR +#endif + +#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304L +#define HWY_CXX14_CONSTEXPR constexpr +#else +#define HWY_CXX14_CONSTEXPR +#endif + +#if HWY_CXX_LANG >= 201703L +#define HWY_IF_CONSTEXPR if constexpr +#else +#define HWY_IF_CONSTEXPR if +#endif + +//------------------------------------------------------------------------------ +// Architecture + +#if defined(__i386__) || defined(_M_IX86) +#define HWY_ARCH_X86_32 1 +#else +#define HWY_ARCH_X86_32 0 +#endif + +#if defined(__x86_64__) || defined(_M_X64) +#define HWY_ARCH_X86_64 1 +#else +#define HWY_ARCH_X86_64 0 +#endif + +#if HWY_ARCH_X86_32 && HWY_ARCH_X86_64 +#error "Cannot have both x86-32 and x86-64" +#endif + +#if HWY_ARCH_X86_32 || HWY_ARCH_X86_64 +#define HWY_ARCH_X86 1 +#else +#define HWY_ARCH_X86 0 +#endif + +#if defined(__powerpc64__) || defined(_M_PPC) || defined(__powerpc__) +#define HWY_ARCH_PPC 1 +#else +#define HWY_ARCH_PPC 0 +#endif + +#if defined(__powerpc64__) || (HWY_ARCH_PPC && defined(__64BIT__)) +#define HWY_ARCH_PPC_64 1 +#else +#define HWY_ARCH_PPC_64 0 +#endif + +// aarch32 is currently not supported; please raise an issue if you want it. +#if defined(__ARM_ARCH_ISA_A64) || defined(__aarch64__) || defined(_M_ARM64) +#define HWY_ARCH_ARM_A64 1 +#else +#define HWY_ARCH_ARM_A64 0 +#endif + +#if (defined(__ARM_ARCH) && __ARM_ARCH == 7) || (defined(_M_ARM) && _M_ARM == 7) +#define HWY_ARCH_ARM_V7 1 +#else +#define HWY_ARCH_ARM_V7 0 +#endif + +#if HWY_ARCH_ARM_A64 && HWY_ARCH_ARM_V7 +#error "Cannot have both A64 and V7" +#endif + +// Any *supported* version of Arm, i.e. 7 or later +#if HWY_ARCH_ARM_A64 || HWY_ARCH_ARM_V7 +#define HWY_ARCH_ARM 1 +#else +#define HWY_ARCH_ARM 0 +#endif + +// Older than Armv7 (e.g. armel aka Armv5) => we do not support SIMD. +#if (defined(__arm__) || defined(_M_ARM)) && !HWY_ARCH_ARM +#define HWY_ARCH_ARM_OLD 1 +#else +#define HWY_ARCH_ARM_OLD 0 +#endif + +#if defined(__EMSCRIPTEN__) || defined(__wasm__) || defined(__WASM__) +#define HWY_ARCH_WASM 1 +#else +#define HWY_ARCH_WASM 0 +#endif + +#ifdef __riscv +#define HWY_ARCH_RISCV 1 +#else +#define HWY_ARCH_RISCV 0 +#endif +// DEPRECATED names; please use HWY_ARCH_RISCV instead. 
+#define HWY_ARCH_RVV HWY_ARCH_RISCV + +#if HWY_ARCH_RISCV && defined(__riscv_xlen) + +#if __riscv_xlen == 32 +#define HWY_ARCH_RISCV_32 1 +#else +#define HWY_ARCH_RISCV_32 0 +#endif + +#if __riscv_xlen == 64 +#define HWY_ARCH_RISCV_64 1 +#else +#define HWY_ARCH_RISCV_64 0 +#endif + +#else // !HWY_ARCH_RISCV || !defined(__riscv_xlen) +#define HWY_ARCH_RISCV_32 0 +#define HWY_ARCH_RISCV_64 0 +#endif // HWY_ARCH_RISCV && defined(__riscv_xlen) + +#if HWY_ARCH_RISCV_32 && HWY_ARCH_RISCV_64 +#error "Cannot have both RISCV_32 and RISCV_64" +#endif + +#if defined(__s390x__) +#define HWY_ARCH_S390X 1 +#else +#define HWY_ARCH_S390X 0 +#endif + +#if defined(__loongarch64__) || defined(__loongarch64) || \ + (defined(__loongarch_grlen) && __loongarch_grlen == 64) +#define HWY_ARCH_LOONGARCH_64 1 +#else +#define HWY_ARCH_LOONGARCH_64 0 +#endif + +#if defined(__loongarch__) && !HWY_ARCH_LOONGARCH_64 +#define HWY_ARCH_LOONGARCH_32 1 +#else +#define HWY_ARCH_LOONGARCH_32 0 +#endif + +#if HWY_ARCH_LOONGARCH_64 || HWY_ARCH_LOONGARCH_32 +#define HWY_ARCH_LOONGARCH 1 +#else +#define HWY_ARCH_LOONGARCH 0 +#endif + +// It is an error to detect multiple architectures at the same time, but OK to +// detect none of the above. +#if (HWY_ARCH_X86 + HWY_ARCH_PPC + HWY_ARCH_ARM + HWY_ARCH_ARM_OLD + \ + HWY_ARCH_WASM + HWY_ARCH_RISCV + HWY_ARCH_S390X + HWY_ARCH_LOONGARCH) > 1 +#error "Must not detect more than one architecture" +#endif + +//------------------------------------------------------------------------------ +// Operating system + +#if defined(_WIN32) || defined(_WIN64) +#define HWY_OS_WIN 1 +#else +#define HWY_OS_WIN 0 +#endif + +#if defined(linux) || defined(__linux__) +#define HWY_OS_LINUX 1 +#else +#define HWY_OS_LINUX 0 +#endif + +// iOS or Mac +#if defined(__APPLE__) +#define HWY_OS_APPLE 1 +#else +#define HWY_OS_APPLE 0 +#endif + +#if defined(__FreeBSD__) +#define HWY_OS_FREEBSD 1 +#else +#define HWY_OS_FREEBSD 0 +#endif + +// It is an error to detect multiple OSes at the same time, but OK to +// detect none of the above. 
+#if (HWY_OS_WIN + HWY_OS_LINUX + HWY_OS_APPLE + HWY_OS_FREEBSD) > 1 +#error "Must not detect more than one OS" +#endif + +//------------------------------------------------------------------------------ +// Endianness + +#if HWY_COMPILER_MSVC +#if HWY_ARCH_PPC && defined(_XBOX_VER) && _XBOX_VER >= 200 +// XBox 360 is big-endian +#define HWY_IS_LITTLE_ENDIAN 0 +#define HWY_IS_BIG_ENDIAN 1 +#else +// All other targets supported by MSVC are little-endian +#define HWY_IS_LITTLE_ENDIAN 1 +#define HWY_IS_BIG_ENDIAN 0 +#endif // HWY_ARCH_PPC && defined(_XBOX_VER) && _XBOX_VER >= 200 +#elif defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && \ + __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define HWY_IS_LITTLE_ENDIAN 1 +#define HWY_IS_BIG_ENDIAN 0 +#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \ + __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define HWY_IS_LITTLE_ENDIAN 0 +#define HWY_IS_BIG_ENDIAN 1 +#else +#error "Unable to detect endianness or unsupported byte order" +#endif + +#if (HWY_IS_LITTLE_ENDIAN + HWY_IS_BIG_ENDIAN) != 1 +#error "Must only detect one byte order" +#endif + +#endif // HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ diff --git a/third_party/aom/third_party/highway/hwy/detect_targets.h b/third_party/aom/third_party/highway/hwy/detect_targets.h new file mode 100644 index 000000000000..491f3ee8f9dc --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/detect_targets.h @@ -0,0 +1,930 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_DETECT_TARGETS_H_ +#define HIGHWAY_HWY_DETECT_TARGETS_H_ + +// Defines targets and chooses which to enable. + +#include "third_party/highway/hwy/detect_compiler_arch.h" + +//------------------------------------------------------------------------------ +// Optional configuration + +// See g3doc/quick_reference.md for documentation of these macros. + +// Uncomment to override the default baseline determined from predefined macros: +// #define HWY_BASELINE_TARGETS (HWY_SSE4 | HWY_SCALAR) + +// Uncomment to override the default blocklist: +// #define HWY_BROKEN_TARGETS HWY_AVX3 + +// Uncomment to definitely avoid generating those target(s): +// #define HWY_DISABLED_TARGETS HWY_SSE4 + +// Uncomment to avoid emitting BMI/BMI2/FMA instructions (allows generating +// AVX2 target for VMs which support AVX2 but not the other instruction sets) +// #define HWY_DISABLE_BMI2_FMA + +// Uncomment to enable these on MSVC even if the predefined macros are not set. +// #define HWY_WANT_SSE2 1 +// #define HWY_WANT_SSSE3 1 +// #define HWY_WANT_SSE4 1 + +//------------------------------------------------------------------------------ +// Targets + +// Unique bit value for each target. A lower value is "better" (e.g. more lanes) +// than a higher value within the same group/platform - see HWY_STATIC_TARGET. +// +// All values are unconditionally defined so we can test HWY_TARGETS without +// first checking the HWY_ARCH_*. 
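// For example, user code can test a target bit unconditionally, even when
// compiling for a different architecture, because the macro is always defined
// (it is simply absent from HWY_TARGETS there):
//
//   #if HWY_TARGETS & HWY_AVX2   // never an "undefined macro" error
//   // ... AVX2-specific tuning ...
//   #endif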
+// +// The C99 preprocessor evaluates #if expressions using intmax_t types. This +// holds at least 64 bits in practice (verified 2022-07-18 via Godbolt on +// 32-bit clang/GCC/MSVC compilers for x86/Arm7/AArch32/RISC-V/WASM). We now +// avoid overflow when computing HWY_TARGETS (subtracting one instead of +// left-shifting 2^62), but still do not use bit 63 because it is the sign bit. + +// --------------------------- x86: 15 targets (+ one fallback) +// Bits 0..2 reserved (3 targets) +#define HWY_AVX10_2_512 (1LL << 3) // AVX10.2 with 512-bit vectors +#define HWY_AVX3_SPR (1LL << 4) +#define HWY_AVX10_2 (1LL << 5) // AVX10.2 with 256-bit vectors +// Currently `HWY_AVX3_DL` plus `AVX512BF16` and a special case for +// `CompressStore` (10x as fast, still useful on Zen5). We may later also use +// `VPCONFLICT`. Note that `VP2INTERSECT` is available in Zen5. +#define HWY_AVX3_ZEN4 (1LL << 6) // see HWY_WANT_AVX3_ZEN4 below + +// Currently satisfiable by Ice Lake (`VNNI`, `VPCLMULQDQ`, `VPOPCNTDQ`, +// `VBMI`, `VBMI2`, `VAES`, `BITALG`, `GFNI`). +#define HWY_AVX3_DL (1LL << 7) +#define HWY_AVX3 (1LL << 8) // HWY_AVX2 plus AVX-512F/BW/CD/DQ/VL +#define HWY_AVX2 (1LL << 9) // HWY_SSE4 plus BMI2 + F16 + FMA +// Bit 10: reserved +#define HWY_SSE4 (1LL << 11) // SSE4.2 plus AES + CLMUL +#define HWY_SSSE3 (1LL << 12) // S-SSE3 +// Bit 13: reserved for SSE3 +#define HWY_SSE2 (1LL << 14) +// The highest bit in the HWY_TARGETS mask that a x86 target can have. Used for +// dynamic dispatch. All x86 target bits must be lower or equal to +// (1 << HWY_HIGHEST_TARGET_BIT_X86) and they can only use +// HWY_MAX_DYNAMIC_TARGETS in total. +#define HWY_HIGHEST_TARGET_BIT_X86 14 + +// --------------------------- Arm: 15 targets (+ one fallback) +// Bits 15..17 reserved (3 targets) +#define HWY_SVE2_128 (1LL << 18) // specialized (e.g. Neoverse V2/N2/N3) +#define HWY_SVE_256 (1LL << 19) // specialized (Neoverse V1) +// Bits 20-22 reserved for later SVE (3 targets) +#define HWY_SVE2 (1LL << 23) +#define HWY_SVE (1LL << 24) +// Bit 25 reserved for NEON +#define HWY_NEON_BF16 (1LL << 26) // fp16/dot/bf16 (e.g. 
Neoverse V2/N2/N3) +// Bit 27 reserved for NEON +#define HWY_NEON (1LL << 28) // Implies support for AES +#define HWY_NEON_WITHOUT_AES (1LL << 29) +#define HWY_HIGHEST_TARGET_BIT_ARM 29 + +#define HWY_ALL_NEON (HWY_NEON_WITHOUT_AES | HWY_NEON | HWY_NEON_BF16) +#define HWY_ALL_SVE (HWY_SVE | HWY_SVE2 | HWY_SVE_256 | HWY_SVE2_128) + +// --------------------------- RISC-V: 9 targets (+ one fallback) +// Bits 30..36 reserved (7 targets) +#define HWY_RVV (1LL << 37) +// Bit 38 reserved +#define HWY_HIGHEST_TARGET_BIT_RVV 38 + +// --------------------------- LoongArch: 3 targets (+ one fallback) +// Bits 39 reserved (1 target) +#define HWY_LASX (1LL << 40) +#define HWY_LSX (1LL << 41) +#define HWY_HIGHEST_TARGET_BIT_LOONGARCH 41 + +// --------------------------- Future expansion: 1 target +// Bits 42 reserved + +// --------------------------- IBM Power/ZSeries: 9 targets (+ one fallback) +// Bits 43..46 reserved (4 targets) +#define HWY_PPC10 (1LL << 47) // v3.1 +#define HWY_PPC9 (1LL << 48) // v3.0 +#define HWY_PPC8 (1LL << 49) // v2.07 +#define HWY_Z15 (1LL << 50) // Z15 +#define HWY_Z14 (1LL << 51) // Z14 +#define HWY_HIGHEST_TARGET_BIT_PPC 51 + +#define HWY_ALL_PPC (HWY_PPC8 | HWY_PPC9 | HWY_PPC10) + +// --------------------------- WebAssembly: 9 targets (+ one fallback) +// Bits 52..57 reserved (6 targets) +#define HWY_WASM_EMU256 (1LL << 58) // Experimental +#define HWY_WASM (1LL << 59) +// Bits 60 reserved +#define HWY_HIGHEST_TARGET_BIT_WASM 60 + +// --------------------------- Emulation: 2 targets + +#define HWY_EMU128 (1LL << 61) +// We do not add/left-shift, so this will not overflow to a negative number. +#define HWY_SCALAR (1LL << 62) +#define HWY_HIGHEST_TARGET_BIT_SCALAR 62 + +// Do not use bit 63 - would be confusing to have negative numbers. + +//------------------------------------------------------------------------------ +// Set default blocklists + +// Disabled means excluded from enabled at user's request. A separate config +// macro allows disabling without deactivating the blocklist below. +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS 0 +#endif + +// Broken means excluded from enabled due to known compiler issues. We define +// separate HWY_BROKEN_* and then OR them together (more than one might apply). + +#ifndef HWY_BROKEN_CLANG6 // allow override +// x86 clang-6: we saw multiple AVX2/3 compile errors and in one case invalid +// SSE4 codegen (possibly only for msan), so disable all those targets. +#if HWY_ARCH_X86 && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_BROKEN_CLANG6 (HWY_SSE4 | (HWY_SSE4 - 1)) +// This entails a major speed reduction, so warn unless the user explicitly +// opts in to scalar-only. +#if !defined(HWY_COMPILE_ONLY_SCALAR) +#pragma message("x86 Clang <= 6: define HWY_COMPILE_ONLY_SCALAR or upgrade.") +#endif + +#else +#define HWY_BROKEN_CLANG6 0 +#endif +#endif // HWY_BROKEN_CLANG6 + +#ifndef HWY_BROKEN_32BIT // allow override +// 32-bit may fail to compile AVX2/3. 
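// Illustrative: because a lower bit denotes a "better" target within a
// platform, the expression `t | (t - 1)` used by the HWY_BROKEN_* masks here
// covers target `t` plus every better target. For example:
//
//   // HWY_SSE4 | (HWY_SSE4 - 1) includes HWY_AVX2, HWY_AVX3_SPR, ...:
//   static_assert(((HWY_SSE4 | (HWY_SSE4 - 1)) & HWY_AVX2) != 0, "");
//   static_assert(((HWY_SSE4 | (HWY_SSE4 - 1)) & HWY_AVX3_SPR) != 0, "");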
+#if HWY_ARCH_X86_32 +// GCC-13 is ok with AVX2: +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1300) +#define HWY_BROKEN_32BIT (HWY_AVX3 | (HWY_AVX3 - 1)) +#else +#define HWY_BROKEN_32BIT (HWY_AVX2 | (HWY_AVX2 - 1)) +#endif +#else +#define HWY_BROKEN_32BIT 0 +#endif +#endif // HWY_BROKEN_32BIT + +#ifndef HWY_BROKEN_MSVC // allow override +// MSVC AVX3 support is buggy: https://github.com/Mysticial/Flops/issues/16 +#if HWY_COMPILER_MSVC != 0 +#define HWY_BROKEN_MSVC (HWY_AVX3 | (HWY_AVX3 - 1)) +#else +#define HWY_BROKEN_MSVC 0 +#endif +#endif // HWY_BROKEN_MSVC + +#ifndef HWY_BROKEN_AVX3_DL_ZEN4 // allow override +// AVX3_DL and AVX3_ZEN4 require clang >= 7 (ensured above), gcc >= 8.1 or ICC +// 2021. +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 801) || \ + (HWY_COMPILER_ICC && HWY_COMPILER_ICC < 2021) +#define HWY_BROKEN_AVX3_DL_ZEN4 (HWY_AVX3_DL | HWY_AVX3_ZEN4) +#else +#define HWY_BROKEN_AVX3_DL_ZEN4 0 +#endif +#endif // HWY_BROKEN_AVX3_DL_ZEN4 + +#ifndef HWY_BROKEN_AVX3_SPR // allow override +// AVX3_SPR requires clang >= 14, gcc >= 12, or ICC 2021. +#if (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1400) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200) || \ + (HWY_COMPILER_ICC && HWY_COMPILER_ICC < 2021) +#define HWY_BROKEN_AVX3_SPR (HWY_AVX3_SPR) +#else +#define HWY_BROKEN_AVX3_SPR 0 +#endif +#endif // HWY_BROKEN_AVX3_SPR + +#ifndef HWY_BROKEN_ARM7_BIG_ENDIAN // allow override +// armv7be has not been tested and is not yet supported. +#if HWY_ARCH_ARM_V7 && HWY_IS_BIG_ENDIAN +#define HWY_BROKEN_ARM7_BIG_ENDIAN HWY_ALL_NEON +#else +#define HWY_BROKEN_ARM7_BIG_ENDIAN 0 +#endif +#endif // HWY_BROKEN_ARM7_BIG_ENDIAN + +#ifdef __ARM_NEON_FP +#define HWY_HAVE_NEON_FP __ARM_NEON_FP +#else +#define HWY_HAVE_NEON_FP 0 +#endif + +#ifndef HWY_BROKEN_ARM7_WITHOUT_VFP4 // allow override +// armv7-a without a detected vfpv4 is not supported +// (for example Cortex-A8, Cortex-A9) +// vfpv4 always have neon half-float _and_ FMA. +#if HWY_ARCH_ARM_V7 && (__ARM_ARCH_PROFILE == 'A') && \ + !defined(__ARM_VFPV4__) && \ + !((HWY_HAVE_NEON_FP & 0x2 /* half-float */) && (__ARM_FEATURE_FMA == 1)) +#define HWY_BROKEN_ARM7_WITHOUT_VFP4 HWY_ALL_NEON +#else +#define HWY_BROKEN_ARM7_WITHOUT_VFP4 0 +#endif +#endif // HWY_BROKEN_ARM7_WITHOUT_VFP4 + +#ifndef HWY_BROKEN_NEON_BF16 // allow override +// HWY_NEON_BF16 requires recent compilers. +#if (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1700) || \ + (HWY_COMPILER_GCC_ACTUAL != 0 && HWY_COMPILER_GCC_ACTUAL < 1302) +#define HWY_BROKEN_NEON_BF16 (HWY_NEON_BF16) +#else +#define HWY_BROKEN_NEON_BF16 0 +#endif +#endif // HWY_BROKEN_NEON_BF16 + +// SVE[2] require recent clang or gcc versions. + +#ifndef HWY_BROKEN_SVE // allow override +// GCC 10+. Clang 19 still has many test failures for SVE. No Apple CPU (at +// least up to and including M4 and A18) has SVE. +#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2000) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \ + HWY_OS_APPLE +#define HWY_BROKEN_SVE (HWY_SVE | HWY_SVE_256) +#else +#define HWY_BROKEN_SVE 0 +#endif +#endif // HWY_BROKEN_SVE + +#ifndef HWY_BROKEN_SVE2 // allow override +// Clang 19 still has many test failures for SVE2. 
+#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2000) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \ + HWY_OS_APPLE +#define HWY_BROKEN_SVE2 (HWY_SVE2 | HWY_SVE2_128) +#else +#define HWY_BROKEN_SVE2 0 +#endif +#endif // HWY_BROKEN_SVE2 + +#ifndef HWY_BROKEN_PPC10 // allow override +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1100) +// GCC 10 supports the -mcpu=power10 option but does not support the PPC10 +// vector intrinsics +#define HWY_BROKEN_PPC10 (HWY_PPC10) +#elif HWY_ARCH_PPC && HWY_IS_BIG_ENDIAN && \ + ((HWY_COMPILER3_CLANG && HWY_COMPILER3_CLANG < 160001) || \ + (HWY_COMPILER_GCC_ACTUAL >= 1200 && HWY_COMPILER_GCC_ACTUAL <= 1203) || \ + (HWY_COMPILER_GCC_ACTUAL >= 1300 && HWY_COMPILER_GCC_ACTUAL <= 1301)) +// GCC 12.0 through 12.3 and GCC 13.0 through 13.1 have a compiler bug where the +// vsldoi instruction is sometimes incorrectly optimized out (and this causes +// some of the Highway unit tests to fail on big-endian PPC10). Details about +// this compiler bug can be found at +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109069, and this bug will be +// fixed in the upcoming GCC 12.4 and 13.2 releases. + +// Clang 16.0.0 and earlier (but not Clang 16.0.1 and later) have a compiler +// bug in the LLVM DAGCombiner that causes a zero-extend followed by an +// element insert into a vector, followed by a vector shuffle to be incorrectly +// optimized on big-endian PPC (and which caused some of the Highway unit tests +// to fail on big-endian PPC10). + +// Details about this bug, which has already been fixed in Clang 16.0.1 and +// later, can be found at https://github.com/llvm/llvm-project/issues/61315. +#define HWY_BROKEN_PPC10 (HWY_PPC10) +#else +#define HWY_BROKEN_PPC10 0 +#endif +#endif // HWY_BROKEN_PPC10 + +#ifndef HWY_BROKEN_PPC_32BIT // allow override +// PPC8/PPC9/PPC10 targets may fail to compile on 32-bit PowerPC +#if HWY_ARCH_PPC && !HWY_ARCH_PPC_64 +#define HWY_BROKEN_PPC_32BIT (HWY_PPC8 | HWY_PPC9 | HWY_PPC10) +#else +#define HWY_BROKEN_PPC_32BIT 0 +#endif +#endif // HWY_BROKEN_PPC_32BIT + +#ifndef HWY_BROKEN_RVV // allow override +// HWY_RVV fails to compile with GCC < 13 or Clang < 16. +#if HWY_ARCH_RISCV && \ + ((HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1600) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1300)) +#define HWY_BROKEN_RVV (HWY_RVV) +#else +#define HWY_BROKEN_RVV 0 +#endif +#endif // HWY_BROKEN_RVV + +#ifndef HWY_BROKEN_LOONGARCH // allow override +// HWY_LSX/HWY_LASX require GCC 14 or Clang 18. +#if HWY_ARCH_LOONGARCH && \ + ((HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1800) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400)) +#define HWY_BROKEN_LOONGARCH (HWY_LSX | HWY_LASX) +#else +#define HWY_BROKEN_LOONGARCH 0 +#endif +#endif // HWY_BROKEN_LOONGARCH + +#ifndef HWY_BROKEN_Z14 // allow override +#if HWY_ARCH_S390X +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1900 +// Clang 18 and earlier have bugs with some ZVector intrinsics +#define HWY_BROKEN_Z14 (HWY_Z14 | HWY_Z15) +#elif HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 +// Z15 target requires GCC 9 or later +#define HWY_BROKEN_Z14 (HWY_Z15) +#else +#define HWY_BROKEN_Z14 0 +#endif +#else // !HWY_ARCH_S390X +#define HWY_BROKEN_Z14 0 +#endif // HWY_ARCH_S390X +#endif // HWY_BROKEN_Z14 + +// Allow the user to override this without any guarantee of success. 
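// Illustrative: the per-compiler HWY_BROKEN_* masks above are ORed into
// HWY_BROKEN_TARGETS just below, and HWY_ENABLED (defined after it) clears
// those bits along with HWY_DISABLED_TARGETS. For example, with only
// HWY_SSSE3 disabled and only HWY_AVX3 broken:
//
//      HWY_ENABLED(HWY_SSSE3 | HWY_SSE4 | HWY_AVX3)
//   == (HWY_SSSE3 | HWY_SSE4 | HWY_AVX3) & ~(HWY_SSSE3 | HWY_AVX3)
//   == HWY_SSE4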
+#ifndef HWY_BROKEN_TARGETS + +#define HWY_BROKEN_TARGETS \ + (HWY_BROKEN_CLANG6 | HWY_BROKEN_32BIT | HWY_BROKEN_MSVC | \ + HWY_BROKEN_AVX3_DL_ZEN4 | HWY_BROKEN_AVX3_SPR | \ + HWY_BROKEN_ARM7_BIG_ENDIAN | HWY_BROKEN_ARM7_WITHOUT_VFP4 | \ + HWY_BROKEN_NEON_BF16 | HWY_BROKEN_SVE | HWY_BROKEN_SVE2 | \ + HWY_BROKEN_PPC10 | HWY_BROKEN_PPC_32BIT | HWY_BROKEN_RVV | \ + HWY_BROKEN_LOONGARCH | HWY_BROKEN_Z14) + +#endif // HWY_BROKEN_TARGETS + +// Enabled means not disabled nor blocklisted. +#define HWY_ENABLED(targets) \ + ((targets) & ~((HWY_DISABLED_TARGETS) | (HWY_BROKEN_TARGETS))) + +// Opt-out for EMU128 (affected by a GCC bug on multiple arches, fixed in 12.3: +// see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=106322). An issue still +// remains with 13.2, see #1683. This is separate from HWY_BROKEN_TARGETS +// because it affects the fallback target, which must always be enabled. If 1, +// we instead choose HWY_SCALAR even without HWY_COMPILE_ONLY_SCALAR being set. +#if !defined(HWY_BROKEN_EMU128) // allow overriding +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400) || \ + defined(HWY_NO_LIBCXX) +#define HWY_BROKEN_EMU128 1 +#else +#define HWY_BROKEN_EMU128 0 +#endif +#endif // HWY_BROKEN_EMU128 + +//------------------------------------------------------------------------------ +// Detect baseline targets using predefined macros + +// Baseline means the targets for which the compiler is allowed to generate +// instructions, implying the target CPU would have to support them. This does +// not take the blocklist into account. + +#if defined(HWY_COMPILE_ONLY_SCALAR) || HWY_BROKEN_EMU128 +#define HWY_BASELINE_SCALAR HWY_SCALAR +#else +#define HWY_BASELINE_SCALAR HWY_EMU128 +#endif + +// Also check HWY_ARCH to ensure that simulating unknown platforms ends up with +// HWY_TARGET == HWY_BASELINE_SCALAR. + +#if HWY_ARCH_WASM && defined(__wasm_simd128__) +#if defined(HWY_WANT_WASM2) +#define HWY_BASELINE_WASM HWY_WASM_EMU256 +#else +#define HWY_BASELINE_WASM HWY_WASM +#endif // HWY_WANT_WASM2 +#else +#define HWY_BASELINE_WASM 0 +#endif + +// GCC or Clang. +#if HWY_ARCH_PPC && HWY_COMPILER_GCC && defined(__ALTIVEC__) && \ + defined(__VSX__) && defined(__POWER8_VECTOR__) && \ + (defined(__CRYPTO__) || defined(HWY_DISABLE_PPC8_CRYPTO)) +#define HWY_BASELINE_PPC8 HWY_PPC8 +#else +#define HWY_BASELINE_PPC8 0 +#endif + +#if HWY_BASELINE_PPC8 != 0 && defined(__POWER9_VECTOR__) +#define HWY_BASELINE_PPC9 HWY_PPC9 +#else +#define HWY_BASELINE_PPC9 0 +#endif + +#if HWY_BASELINE_PPC9 != 0 && \ + (defined(_ARCH_PWR10) || defined(__POWER10_VECTOR__)) +#define HWY_BASELINE_PPC10 HWY_PPC10 +#else +#define HWY_BASELINE_PPC10 0 +#endif + +#if HWY_ARCH_S390X && defined(__VEC__) && defined(__ARCH__) && __ARCH__ >= 12 +#define HWY_BASELINE_Z14 HWY_Z14 +#else +#define HWY_BASELINE_Z14 0 +#endif + +#if HWY_BASELINE_Z14 && __ARCH__ >= 13 +#define HWY_BASELINE_Z15 HWY_Z15 +#else +#define HWY_BASELINE_Z15 0 +#endif + +#define HWY_BASELINE_SVE2 0 +#define HWY_BASELINE_SVE 0 +#define HWY_BASELINE_NEON 0 + +#if HWY_ARCH_ARM + +// Also check compiler version as done for HWY_ATTAINABLE_SVE2 because the +// static target (influenced here) must be one of the attainable targets. +#if defined(__ARM_FEATURE_SVE2) && \ + (HWY_COMPILER_CLANG >= 1400 || HWY_COMPILER_GCC_ACTUAL >= 1200) +#undef HWY_BASELINE_SVE2 // was 0, will be re-defined +// If user specified -msve-vector-bits=128, they assert the vector length is +// 128 bits and we should use the HWY_SVE2_128 (more efficient for some ops). 
+#if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 128 +#define HWY_BASELINE_SVE2 HWY_SVE2_128 +// Otherwise we're not sure what the vector length will be. The baseline must be +// unconditionally valid, so we can only assume HWY_SVE2. However, when running +// on a CPU with 128-bit vectors, user code that supports dynamic dispatch will +// still benefit from HWY_SVE2_128 because we add it to HWY_ATTAINABLE_TARGETS. +#else +#define HWY_BASELINE_SVE2 HWY_SVE2 +#endif // __ARM_FEATURE_SVE_BITS +#endif // __ARM_FEATURE_SVE2 + +#if defined(__ARM_FEATURE_SVE) && \ + (HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800) +#undef HWY_BASELINE_SVE // was 0, will be re-defined +// See above. If user-specified vector length matches our optimization, use it. +#if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 256 +#define HWY_BASELINE_SVE HWY_SVE_256 +#else +#define HWY_BASELINE_SVE HWY_SVE +#endif // __ARM_FEATURE_SVE_BITS +#endif // __ARM_FEATURE_SVE + +// GCC 4.5.4 only defines __ARM_NEON__; 5.4 defines both. +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#undef HWY_BASELINE_NEON +#if defined(__ARM_FEATURE_AES) && \ + defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \ + defined(__ARM_FEATURE_DOTPROD) && \ + defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#define HWY_BASELINE_NEON HWY_ALL_NEON +#elif defined(__ARM_FEATURE_AES) +#define HWY_BASELINE_NEON (HWY_NEON_WITHOUT_AES | HWY_NEON) +#else +#define HWY_BASELINE_NEON (HWY_NEON_WITHOUT_AES) +#endif // __ARM_FEATURE* +#endif // __ARM_NEON + +#endif // HWY_ARCH_ARM + +// Special handling for MSVC because it has fewer predefined macros: +#if HWY_COMPILER_MSVC + +#if HWY_ARCH_X86_32 +#if _M_IX86_FP >= 2 +#define HWY_CHECK_SSE2 1 +#else +#define HWY_CHECK_SSE2 0 +#endif +#elif HWY_ARCH_X86_64 +#define HWY_CHECK_SSE2 1 +#else +#define HWY_CHECK_SSE2 0 +#endif + +// 1) We can only be sure SSSE3/SSE4 are enabled if AVX is: +// https://stackoverflow.com/questions/18563978/. +#if defined(__AVX__) +#define HWY_CHECK_SSSE3 1 +#define HWY_CHECK_SSE4 1 +#else +#define HWY_CHECK_SSSE3 0 +#define HWY_CHECK_SSE4 0 +#endif + +// 2) Cannot check for PCLMUL/AES and BMI2/FMA/F16C individually; we assume +// PCLMUL/AES are available if SSE4 is, and BMI2/FMA/F16C if AVX2 is. +#define HWY_CHECK_PCLMUL_AES 1 +#define HWY_CHECK_BMI2_FMA 1 +#define HWY_CHECK_F16C 1 + +#else // non-MSVC + +#if defined(__SSE2__) +#define HWY_CHECK_SSE2 1 +#else +#define HWY_CHECK_SSE2 0 +#endif + +#if defined(__SSSE3__) +#define HWY_CHECK_SSSE3 1 +#else +#define HWY_CHECK_SSSE3 0 +#endif + +#if defined(__SSE4_1__) && defined(__SSE4_2__) +#define HWY_CHECK_SSE4 1 +#else +#define HWY_CHECK_SSE4 0 +#endif + +// If these are disabled, they should not gate the availability of SSE4/AVX2. 
+#if defined(HWY_DISABLE_PCLMUL_AES) || (defined(__PCLMUL__) && defined(__AES__)) +#define HWY_CHECK_PCLMUL_AES 1 +#else +#define HWY_CHECK_PCLMUL_AES 0 +#endif + +#if defined(HWY_DISABLE_BMI2_FMA) || (defined(__BMI2__) && defined(__FMA__)) +#define HWY_CHECK_BMI2_FMA 1 +#else +#define HWY_CHECK_BMI2_FMA 0 +#endif + +#if defined(HWY_DISABLE_F16C) || defined(__F16C__) +#define HWY_CHECK_F16C 1 +#else +#define HWY_CHECK_F16C 0 +#endif + +#endif // non-MSVC + +#if HWY_ARCH_X86 && \ + ((defined(HWY_WANT_SSE2) && HWY_WANT_SSE2) || HWY_CHECK_SSE2) +#define HWY_BASELINE_SSE2 HWY_SSE2 +#else +#define HWY_BASELINE_SSE2 0 +#endif + +#if HWY_ARCH_X86 && \ + ((defined(HWY_WANT_SSSE3) && HWY_WANT_SSSE3) || HWY_CHECK_SSSE3) +#define HWY_BASELINE_SSSE3 HWY_SSSE3 +#else +#define HWY_BASELINE_SSSE3 0 +#endif + +#if HWY_ARCH_X86 && ((defined(HWY_WANT_SSE4) && HWY_WANT_SSE4) || \ + (HWY_CHECK_SSE4 && HWY_CHECK_PCLMUL_AES)) +#define HWY_BASELINE_SSE4 HWY_SSE4 +#else +#define HWY_BASELINE_SSE4 0 +#endif + +#if HWY_BASELINE_SSE4 != 0 && HWY_CHECK_BMI2_FMA && HWY_CHECK_F16C && \ + defined(__AVX2__) +#define HWY_BASELINE_AVX2 HWY_AVX2 +#else +#define HWY_BASELINE_AVX2 0 +#endif + +// Require everything in AVX2 plus AVX-512 flags (also set by MSVC) +#if HWY_BASELINE_AVX2 != 0 && defined(__AVX512F__) && defined(__AVX512BW__) && \ + defined(__AVX512DQ__) && defined(__AVX512VL__) && \ + ((!HWY_COMPILER_GCC_ACTUAL && !HWY_COMPILER_CLANG) || \ + HWY_COMPILER_GCC_ACTUAL < 1400 || HWY_COMPILER_CLANG < 1800 || \ + defined(__EVEX512__)) +#define HWY_BASELINE_AVX3 HWY_AVX3 +#else +#define HWY_BASELINE_AVX3 0 +#endif + +// TODO(janwas): not yet known whether these will be set by MSVC +#if HWY_BASELINE_AVX3 != 0 && defined(__AVX512VNNI__) && defined(__VAES__) && \ + defined(__VPCLMULQDQ__) && defined(__AVX512VBMI__) && \ + defined(__AVX512VBMI2__) && defined(__AVX512VPOPCNTDQ__) && \ + defined(__AVX512BITALG__) +#define HWY_BASELINE_AVX3_DL HWY_AVX3_DL +#else +#define HWY_BASELINE_AVX3_DL 0 +#endif + +// The ZEN4-optimized AVX3 target is numerically lower than AVX3_DL and is thus +// considered better. Do not enable it unless the user explicitly requests it - +// we do not want to choose the ZEN4 path on Intel because it could be slower. +#if defined(HWY_WANT_AVX3_ZEN4) && HWY_BASELINE_AVX3_DL != 0 +#define HWY_BASELINE_AVX3_ZEN4 HWY_AVX3_ZEN4 +#else +#define HWY_BASELINE_AVX3_ZEN4 0 +#endif + +#if HWY_BASELINE_AVX2 != 0 && defined(__AVX10_2__) +#define HWY_BASELINE_AVX10_2 HWY_AVX10_2 +#else +#define HWY_BASELINE_AVX10_2 0 +#endif + +#if HWY_BASELINE_AVX3_DL != 0 && defined(__AVX512BF16__) && \ + defined(__AVX512FP16__) +#define HWY_BASELINE_AVX3_SPR HWY_AVX3_SPR +#else +#define HWY_BASELINE_AVX3_SPR 0 +#endif + +#if HWY_BASELINE_AVX3_SPR != 0 && defined(__AVX10_2_512__) +#define HWY_BASELINE_AVX10_2_512 HWY_AVX10_2_512 +#else +#define HWY_BASELINE_AVX10_2_512 0 +#endif + +// RVV requires intrinsics 0.11 or later, see #1156. +#if HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \ + __riscv_v_intrinsic >= 11000 +#define HWY_BASELINE_RVV HWY_RVV +#else +#define HWY_BASELINE_RVV 0 +#endif + +#if HWY_ARCH_LOONGARCH && defined(__loongarch_sx) && defined(__loongarch_asx) +#define HWY_BASELINE_LOONGARCH (HWY_LSX | HWY_LASX) +#elif HWY_ARCH_LOONGARCH && defined(__loongarch_sx) +#define HWY_BASELINE_LOONGARCH (HWY_LSX) +#else +#define HWY_BASELINE_LOONGARCH 0 +#endif + +// Allow the user to override this without any guarantee of success. 
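The HWY_WANT_* and HWY_DISABLE_* macros checked in this section are likewise intended to be set by the user; an illustrative (unverified) combination:

// Do not let missing -maes/-mpclmul or -mf16c flags keep SSE4/AVX2 out of the
// baseline; the HWY_CHECK_PCLMUL_AES and HWY_CHECK_F16C tests above then pass.
#define HWY_DISABLE_PCLMUL_AES 1
#define HWY_DISABLE_F16C 1
#include "third_party/highway/hwy/highway.h"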
+#ifndef HWY_BASELINE_TARGETS +#define HWY_BASELINE_TARGETS \ + (HWY_BASELINE_SCALAR | HWY_BASELINE_WASM | HWY_BASELINE_PPC8 | \ + HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10 | HWY_BASELINE_Z14 | \ + HWY_BASELINE_Z15 | HWY_BASELINE_SVE2 | HWY_BASELINE_SVE | \ + HWY_BASELINE_NEON | HWY_BASELINE_SSE2 | HWY_BASELINE_SSSE3 | \ + HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | HWY_BASELINE_AVX3 | \ + HWY_BASELINE_AVX3_DL | HWY_BASELINE_AVX3_ZEN4 | HWY_BASELINE_AVX10_2 | \ + HWY_BASELINE_AVX3_SPR | HWY_BASELINE_AVX10_2_512 | HWY_BASELINE_RVV | \ + HWY_BASELINE_LOONGARCH) +#endif // HWY_BASELINE_TARGETS + +//------------------------------------------------------------------------------ +// Choose target for static dispatch + +#define HWY_ENABLED_BASELINE HWY_ENABLED(HWY_BASELINE_TARGETS) +#if HWY_ENABLED_BASELINE == 0 +#error "At least one baseline target must be defined and enabled" +#endif + +// Best baseline, used for static dispatch. This is the least-significant 1-bit +// within HWY_ENABLED_BASELINE and lower bit values imply "better". +#define HWY_STATIC_TARGET (HWY_ENABLED_BASELINE & -HWY_ENABLED_BASELINE) + +// Start by assuming static dispatch. If we later use dynamic dispatch, this +// will be defined to other targets during the multiple-inclusion, and finally +// return to the initial value. Defining this outside begin/end_target ensures +// inl headers successfully compile by themselves (required by Bazel). +#define HWY_TARGET HWY_STATIC_TARGET + +//------------------------------------------------------------------------------ +// Choose targets for dynamic dispatch according to one of four policies + +// TODO: remove once HWY_LSX is actually supported +#if HWY_ARCH_LOONGARCH && !defined(HWY_COMPILE_ONLY_SCALAR) && \ + !defined(HWY_COMPILE_ONLY_EMU128) +#undef HWY_COMPILE_ONLY_STATIC +#define HWY_COMPILE_ONLY_EMU128 +#endif + +#if 1 < (defined(HWY_COMPILE_ONLY_SCALAR) + defined(HWY_COMPILE_ONLY_EMU128) + \ + defined(HWY_COMPILE_ONLY_STATIC)) +#error "Can only define one of HWY_COMPILE_ONLY_{SCALAR|EMU128|STATIC} - bug?" +#endif +// Defining one of HWY_COMPILE_ONLY_* will trump HWY_COMPILE_ALL_ATTAINABLE. + +#ifndef HWY_HAVE_ASM_HWCAP // allow override +#ifdef TOOLCHAIN_MISS_ASM_HWCAP_H +#define HWY_HAVE_ASM_HWCAP 0 // CMake failed to find the header +#elif defined(__has_include) // note: wrapper macro fails on Clang ~17 +// clang-format off +#if __has_include() +// clang-format on +#define HWY_HAVE_ASM_HWCAP 1 // header present +#else +#define HWY_HAVE_ASM_HWCAP 0 // header not present +#endif // __has_include +#else // compiler lacks __has_include +#define HWY_HAVE_ASM_HWCAP 0 +#endif +#endif // HWY_HAVE_ASM_HWCAP + +#ifndef HWY_HAVE_AUXV // allow override +#ifdef TOOLCHAIN_MISS_SYS_AUXV_H +#define HWY_HAVE_AUXV 0 // CMake failed to find the header +// glibc 2.16 added auxv, but checking for that requires features.h, and we do +// not want to include system headers here. Instead check for the header +// directly, which has been supported at least since GCC 5.4 and Clang 3. 
+#elif defined(__has_include) // note: wrapper macro fails on Clang ~17 +// clang-format off +#if __has_include() +// clang-format on +#define HWY_HAVE_AUXV 1 // header present +#else +#define HWY_HAVE_AUXV 0 // header not present +#endif // __has_include +#else // compiler lacks __has_include +#define HWY_HAVE_AUXV 0 +#endif +#endif // HWY_HAVE_AUXV + +#ifndef HWY_HAVE_RUNTIME_DISPATCH_RVV // allow override +// The riscv_vector.h in Clang 16-18 requires compiler flags, and 19 still has +// some missing intrinsics, see +// https://github.com/llvm/llvm-project/issues/56592. GCC 13.3 also has an +// #error check, whereas 14.1 fails with "argument type 'vuint16m8_t' requires +// the V ISA extension": https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115325. +#if HWY_ARCH_RISCV && HWY_COMPILER_CLANG >= 1900 && 0 +#define HWY_HAVE_RUNTIME_DISPATCH_RVV 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH_RVV 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH_RVV + +#ifndef HWY_HAVE_RUNTIME_DISPATCH_APPLE // allow override +#if HWY_ARCH_ARM_A64 && HWY_OS_APPLE && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1700) +#define HWY_HAVE_RUNTIME_DISPATCH_APPLE 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH_APPLE 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH_APPLE + +#ifndef HWY_HAVE_RUNTIME_DISPATCH_LINUX // allow override +#if (HWY_ARCH_ARM || HWY_ARCH_PPC || HWY_ARCH_S390X) && HWY_OS_LINUX && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1700) && HWY_HAVE_AUXV +#define HWY_HAVE_RUNTIME_DISPATCH_LINUX 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH_LINUX 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH_LINUX + +// Allow opting out, and without a guarantee of success, opting-in. +#ifndef HWY_HAVE_RUNTIME_DISPATCH +// Clang, GCC and MSVC allow OS-independent runtime dispatch on x86. +#if HWY_ARCH_X86 || HWY_HAVE_RUNTIME_DISPATCH_RVV || \ + HWY_HAVE_RUNTIME_DISPATCH_APPLE || HWY_HAVE_RUNTIME_DISPATCH_LINUX +#define HWY_HAVE_RUNTIME_DISPATCH 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH + +#if HWY_ARCH_ARM_A64 && HWY_HAVE_RUNTIME_DISPATCH +#define HWY_ATTAINABLE_NEON HWY_ALL_NEON +#elif HWY_ARCH_ARM // static dispatch, or HWY_ARCH_ARM_V7 +#define HWY_ATTAINABLE_NEON (HWY_BASELINE_NEON) +#else +#define HWY_ATTAINABLE_NEON 0 +#endif + +#if HWY_ARCH_ARM_A64 && \ + (HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800) && \ + (HWY_HAVE_RUNTIME_DISPATCH || \ + (HWY_ENABLED_BASELINE & (HWY_SVE | HWY_SVE_256))) +#define HWY_ATTAINABLE_SVE (HWY_SVE | HWY_SVE_256) +#else +#define HWY_ATTAINABLE_SVE 0 +#endif + +#if HWY_ARCH_ARM_A64 && \ + (HWY_COMPILER_CLANG >= 1400 || HWY_COMPILER_GCC_ACTUAL >= 1200) && \ + (HWY_HAVE_RUNTIME_DISPATCH || \ + (HWY_ENABLED_BASELINE & (HWY_SVE2 | HWY_SVE2_128))) +#define HWY_ATTAINABLE_SVE2 (HWY_SVE2 | HWY_SVE2_128) +#else +#define HWY_ATTAINABLE_SVE2 0 +#endif + +#if HWY_ARCH_PPC && defined(__ALTIVEC__) && \ + (!HWY_COMPILER_CLANG || HWY_BASELINE_PPC8 != 0) + +#if (HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10) && \ + !defined(HWY_SKIP_NON_BEST_BASELINE) +// On POWER with -m flags, we get compile errors (#1707) for targets older than +// the baseline specified via -m, so only generate the static target and better. +// Note that some Linux distros actually do set POWER9 as the baseline. +// This works by skipping case 3 below, so case 4 is reached. 
+#define HWY_SKIP_NON_BEST_BASELINE +#endif + +#define HWY_ATTAINABLE_PPC (HWY_PPC8 | HWY_PPC9 | HWY_PPC10) + +#else +#define HWY_ATTAINABLE_PPC 0 +#endif + +#if HWY_ARCH_S390X && HWY_BASELINE_Z14 != 0 +#define HWY_ATTAINABLE_S390X (HWY_Z14 | HWY_Z15) +#else +#define HWY_ATTAINABLE_S390X 0 +#endif + +#if HWY_ARCH_RISCV && HWY_HAVE_RUNTIME_DISPATCH +#define HWY_ATTAINABLE_RISCV HWY_RVV +#else +#define HWY_ATTAINABLE_RISCV HWY_BASELINE_RVV +#endif + +#if HWY_ARCH_LOONGARCH && HWY_HAVE_RUNTIME_DISPATCH +#define HWY_ATTAINABLE_LOONGARCH (HWY_LSX | HWY_LASX) +#else +#define HWY_ATTAINABLE_LOONGARCH HWY_BASELINE_LOONGARCH +#endif + +#ifndef HWY_ATTAINABLE_TARGETS_X86 // allow override +#if HWY_COMPILER_MSVC && defined(HWY_SLOW_MSVC) +// Fewer targets for faster builds. +#define HWY_ATTAINABLE_TARGETS_X86 \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_STATIC_TARGET | HWY_AVX2) +#else // !HWY_COMPILER_MSVC +#define HWY_ATTAINABLE_TARGETS_X86 \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | \ + HWY_AVX2 | HWY_AVX3 | HWY_AVX3_DL | HWY_AVX3_ZEN4 | \ + HWY_AVX3_SPR) +#endif // !HWY_COMPILER_MSVC +#endif // HWY_ATTAINABLE_TARGETS_X86 + +// Attainable means enabled and the compiler allows intrinsics (even when not +// allowed to auto-vectorize). Used in 3 and 4. +#if HWY_ARCH_X86 +#define HWY_ATTAINABLE_TARGETS HWY_ATTAINABLE_TARGETS_X86 +#elif HWY_ARCH_ARM +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_NEON | HWY_ATTAINABLE_SVE | \ + HWY_ATTAINABLE_SVE2) +#elif HWY_ARCH_PPC +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_PPC) +#elif HWY_ARCH_S390X +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_S390X) +#elif HWY_ARCH_RISCV +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_RISCV) +#elif HWY_ARCH_LOONGARCH +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_LOONGARCH) +#else +#define HWY_ATTAINABLE_TARGETS (HWY_ENABLED_BASELINE) +#endif // HWY_ARCH_* + +// 1) For older compilers: avoid SIMD intrinsics, but still support all ops. +#if defined(HWY_COMPILE_ONLY_EMU128) && !HWY_BROKEN_EMU128 +#undef HWY_STATIC_TARGET +#define HWY_STATIC_TARGET HWY_EMU128 // override baseline +#define HWY_TARGETS HWY_EMU128 + +// 1b) HWY_SCALAR is less capable than HWY_EMU128 (which supports all ops), but +// we currently still support it for backwards compatibility. +#elif defined(HWY_COMPILE_ONLY_SCALAR) || \ + (defined(HWY_COMPILE_ONLY_EMU128) && HWY_BROKEN_EMU128) +#undef HWY_STATIC_TARGET +#define HWY_STATIC_TARGET HWY_SCALAR // override baseline +#define HWY_TARGETS HWY_SCALAR + +// 2) For forcing static dispatch without code changes (removing HWY_EXPORT) +#elif defined(HWY_COMPILE_ONLY_STATIC) +#define HWY_TARGETS HWY_STATIC_TARGET + +// 3) For tests: include all attainable targets (in particular: scalar) +#elif (defined(HWY_COMPILE_ALL_ATTAINABLE) || defined(HWY_IS_TEST)) && \ + !defined(HWY_SKIP_NON_BEST_BASELINE) +#define HWY_TARGETS HWY_ATTAINABLE_TARGETS + +// 4) Default: attainable WITHOUT non-best baseline. This reduces code size by +// excluding superseded targets, in particular scalar. Note: HWY_STATIC_TARGET +// may be 2^62 (HWY_SCALAR), so we must not left-shift/add it. Subtracting one +// sets all lower bits (better targets), then we also include the static target. 
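To make this bit arithmetic concrete before the definition that follows, here is a small standalone program. It uses only macros defined in this header plus hwy::SupportedTargets() and hwy::TargetName() from hwy/targets.h (included by highway.h); target values are whatever the build selects:

#include <cstdint>
#include <cstdio>

#include "third_party/highway/hwy/highway.h"  // also pulls in hwy/targets.h

int main() {
  // HWY_STATIC_TARGET is the least-significant set bit of the enabled
  // baseline, isolated via x & -x as defined earlier in this header.
  static_assert((HWY_ENABLED_BASELINE & -HWY_ENABLED_BASELINE) ==
                    HWY_STATIC_TARGET,
                "least-significant bit selects the static target");
  // (x - 1) | x keeps x and additionally sets every lower (better) bit: the
  // mask used by the default policy defined just below.
  const int64_t static_and_better =
      (HWY_STATIC_TARGET - 1LL) | HWY_STATIC_TARGET;
  std::printf("static target: %s, static-and-better mask: %llx\n",
              hwy::TargetName(HWY_STATIC_TARGET),
              static_cast<unsigned long long>(static_and_better));
  // Targets this CPU supports at runtime; the dispatch machinery uses this to
  // pick the best compiled target.
  for (int64_t bits = hwy::SupportedTargets(); bits != 0; bits &= bits - 1) {
    std::printf("supported: %s\n", hwy::TargetName(bits & -bits));
  }
  return 0;
}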
+#else +#define HWY_TARGETS \ + (HWY_ATTAINABLE_TARGETS & ((HWY_STATIC_TARGET - 1LL) | HWY_STATIC_TARGET)) + +#endif // target policy + +// HWY_ONCE and the multiple-inclusion mechanism rely on HWY_STATIC_TARGET being +// one of the dynamic targets. This also implies HWY_TARGETS != 0 and +// (HWY_TARGETS & HWY_ENABLED_BASELINE) != 0. +#if (HWY_TARGETS & HWY_STATIC_TARGET) == 0 +#error "Logic error: best baseline should be included in dynamic targets" +#endif + +#endif // HIGHWAY_HWY_DETECT_TARGETS_H_ diff --git a/third_party/aom/third_party/highway/hwy/examples/skeleton-inl.h b/third_party/aom/third_party/highway/hwy/examples/skeleton-inl.h new file mode 100644 index 000000000000..227ef462e52b --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/examples/skeleton-inl.h @@ -0,0 +1,64 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Demo of functions that might be called from multiple SIMD modules (either +// other -inl.h files, or a .cc file between begin/end_target-inl). This is +// optional - all SIMD code can reside in .cc files. However, this allows +// splitting code into different files while still inlining instead of requiring +// calling through function pointers. + +// Per-target include guard. This is only required when using dynamic dispatch, +// i.e. including foreach_target.h. For static dispatch, a normal include +// guard would be fine because the header is only compiled once. +#if defined(HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#undef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#else +#define HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#endif + +// It is fine to #include normal or *-inl headers. +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace skeleton { +namespace HWY_NAMESPACE { + +// Highway ops reside here; ADL does not find templates nor builtins. +namespace hn = hwy::HWY_NAMESPACE; + +// Example of a type-agnostic (caller-specified lane type) and width-agnostic +// (uses best available instruction set) function in a header. +// +// Computes x[i] = mul_array[i] * x_array[i] + add_array[i] for i < size. 
+template <class D, typename T> +HWY_MAYBE_UNUSED void MulAddLoop(const D d, const T* HWY_RESTRICT mul_array, + const T* HWY_RESTRICT add_array, + const size_t size, T* HWY_RESTRICT x_array) { + for (size_t i = 0; i < size; i += hn::Lanes(d)) { + const auto mul = hn::Load(d, mul_array + i); + const auto add = hn::Load(d, add_array + i); + auto x = hn::Load(d, x_array + i); + x = hn::MulAdd(mul, x, add); + hn::Store(x, d, x_array + i); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace skeleton +HWY_AFTER_NAMESPACE(); + +#endif // include guard diff --git a/third_party/aom/third_party/highway/hwy/examples/skeleton.h b/third_party/aom/third_party/highway/hwy/examples/skeleton.h new file mode 100644 index 000000000000..55e15a49dc35 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/examples/skeleton.h @@ -0,0 +1,38 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Demo interface to target-specific code in skeleton.cc + +// Normal header with include guard and namespace. +#ifndef HIGHWAY_HWY_EXAMPLES_SKELETON_H_ +#define HIGHWAY_HWY_EXAMPLES_SKELETON_H_ + +// Platform-specific definitions used for declaring an interface, independent of +// the SIMD instruction set. +#include "third_party/highway/hwy/base.h" // HWY_RESTRICT + +namespace skeleton { + +// Computes base-2 logarithm by converting to float. Supports dynamic dispatch. +HWY_DLLEXPORT void CallFloorLog2(const uint8_t* HWY_RESTRICT in, size_t count, + uint8_t* HWY_RESTRICT out); + +// Same, but uses HWY_DYNAMIC_POINTER to save a function pointer and call it. +HWY_DLLEXPORT void SavedCallFloorLog2(const uint8_t* HWY_RESTRICT in, + size_t count, uint8_t* HWY_RESTRICT out); + +} // namespace skeleton + +#endif // HIGHWAY_HWY_EXAMPLES_SKELETON_H_ diff --git a/third_party/aom/third_party/highway/hwy/foreach_target.h b/third_party/aom/third_party/highway/hwy/foreach_target.h new file mode 100644 index 000000000000..33faf8507d2b --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/foreach_target.h @@ -0,0 +1,421 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_FOREACH_TARGET_H_ +#define HIGHWAY_HWY_FOREACH_TARGET_H_ + +// Re-includes the translation unit zero or more times to compile for any +// targets except HWY_STATIC_TARGET.
Defines unique HWY_TARGET each time so that +// highway.h defines the corresponding macro/namespace. + +#include "third_party/highway/hwy/detect_targets.h" + +// *_inl.h may include other headers, which requires include guards to prevent +// repeated inclusion. The guards must be reset after compiling each target, so +// the header is again visible. This is done by flipping HWY_TARGET_TOGGLE, +// defining it if undefined and vice versa. This macro is initially undefined +// so that IDEs don't gray out the contents of each header. +#ifdef HWY_TARGET_TOGGLE +#error "This macro must not be defined outside foreach_target.h" +#endif + +#ifdef HWY_HIGHWAY_INCLUDED // highway.h include guard +// Trigger fixup at the bottom of this header. +#define HWY_ALREADY_INCLUDED + +// The next highway.h must re-include set_macros-inl.h because the first +// highway.h chose the static target instead of what we will set below. +#undef HWY_SET_MACROS_PER_TARGET +#endif + +// Disable HWY_EXPORT in user code until we have generated all targets. Note +// that a subsequent highway.h will not override this definition. +#undef HWY_ONCE +#define HWY_ONCE (0 || HWY_IDE) + +// Avoid warnings on #include HWY_TARGET_INCLUDE by hiding them from the IDE; +// also skip if only 1 target defined (no re-inclusion will be necessary). +#if !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET) + +#if !defined(HWY_TARGET_INCLUDE) +#error ">1 target enabled => define HWY_TARGET_INCLUDE before foreach_target.h" +#endif + +// ------------------------------ HWY_ARCH_X86 + +#if (HWY_TARGETS & HWY_SSE2) && (HWY_STATIC_TARGET != HWY_SSE2) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSE2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SSSE3) && (HWY_STATIC_TARGET != HWY_SSSE3) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSSE3 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SSE4) && (HWY_STATIC_TARGET != HWY_SSE4) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSE4 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX2) && (HWY_STATIC_TARGET != HWY_AVX2) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3) && (HWY_STATIC_TARGET != HWY_AVX3) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3_DL) && (HWY_STATIC_TARGET != HWY_AVX3_DL) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3_DL +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3_ZEN4) && (HWY_STATIC_TARGET != HWY_AVX3_ZEN4) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3_ZEN4 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3_SPR) && (HWY_STATIC_TARGET != HWY_AVX3_SPR) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3_SPR +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + 
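Each of these blocks re-includes the translation unit named by HWY_TARGET_INCLUDE with a different HWY_TARGET. A hedged sketch of such a translation unit, following the shape of skeleton-inl.h above (file, namespace and function names are hypothetical):

// my_module.cc -- hypothetical translation unit using dynamic dispatch.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "my_module.cc"  // this very file, re-included per target
#include "third_party/highway/hwy/foreach_target.h"  // must precede highway.h
#include "third_party/highway/hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace my_project {
namespace HWY_NAMESPACE {  // N_SSE4, N_AVX2, ... on each re-inclusion
namespace hn = hwy::HWY_NAMESPACE;

void ScaleInPlace(float* HWY_RESTRICT data, size_t count, float factor) {
  const hn::ScalableTag<float> d;
  const auto vfactor = hn::Set(d, factor);
  for (size_t i = 0; i < count; i += hn::Lanes(d)) {  // count: multiple of Lanes(d)
    hn::Store(hn::Mul(hn::Load(d, data + i), vfactor), d, data + i);
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace my_project
HWY_AFTER_NAMESPACE();

#if HWY_ONCE  // true only when the static target is compiled, after the re-includes
namespace my_project {

HWY_EXPORT(ScaleInPlace);  // defines the table indexed by the chosen target

void CallScaleInPlace(float* data, size_t count, float factor) {
  HWY_DYNAMIC_DISPATCH(ScaleInPlace)(data, count, factor);
}

// Variant in the spirit of skeleton.h's SavedCallFloorLog2: resolve the chosen
// target first so the cached pointer is the real target, not the first-call stub.
void SavedCallScaleInPlace(float* data, size_t count, float factor) {
  hwy::GetChosenTarget().Update(hwy::SupportedTargets());
  static const auto ptr = HWY_DYNAMIC_POINTER(ScaleInPlace);
  (*ptr)(data, count, factor);
}

}  // namespace my_project
#endif  // HWY_ONCE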
+#if (HWY_TARGETS & HWY_AVX10_2) && (HWY_STATIC_TARGET != HWY_AVX10_2) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX10_2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX10_2_512) && (HWY_STATIC_TARGET != HWY_AVX10_2_512) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX10_2_512 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_ARM + +#if (HWY_TARGETS & HWY_NEON_WITHOUT_AES) && \ + (HWY_STATIC_TARGET != HWY_NEON_WITHOUT_AES) +#undef HWY_TARGET +#define HWY_TARGET HWY_NEON_WITHOUT_AES +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_NEON) && (HWY_STATIC_TARGET != HWY_NEON) +#undef HWY_TARGET +#define HWY_TARGET HWY_NEON +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_NEON_BF16) && (HWY_STATIC_TARGET != HWY_NEON_BF16) +#undef HWY_TARGET +#define HWY_TARGET HWY_NEON_BF16 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE) && (HWY_STATIC_TARGET != HWY_SVE) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE2) && (HWY_STATIC_TARGET != HWY_SVE2) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE_256) && (HWY_STATIC_TARGET != HWY_SVE_256) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE_256 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE2_128) && (HWY_STATIC_TARGET != HWY_SVE2_128) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE2_128 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_WASM + +#if (HWY_TARGETS & HWY_WASM_EMU256) && (HWY_STATIC_TARGET != HWY_WASM_EMU256) +#undef HWY_TARGET +#define HWY_TARGET HWY_WASM_EMU256 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_WASM) && (HWY_STATIC_TARGET != HWY_WASM) +#undef HWY_TARGET +#define HWY_TARGET HWY_WASM +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_PPC + +#if (HWY_TARGETS & HWY_PPC8) && (HWY_STATIC_TARGET != HWY_PPC8) +#undef HWY_TARGET +#define HWY_TARGET HWY_PPC8 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_PPC9) && (HWY_STATIC_TARGET != HWY_PPC9) +#undef HWY_TARGET +#define HWY_TARGET HWY_PPC9 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_PPC10) && (HWY_STATIC_TARGET != HWY_PPC10) +#undef 
HWY_TARGET +#define HWY_TARGET HWY_PPC10 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_S390X + +#if (HWY_TARGETS & HWY_Z14) && (HWY_STATIC_TARGET != HWY_Z14) +#undef HWY_TARGET +#define HWY_TARGET HWY_Z14 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_Z15) && (HWY_STATIC_TARGET != HWY_Z15) +#undef HWY_TARGET +#define HWY_TARGET HWY_Z15 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_RISCV + +#if (HWY_TARGETS & HWY_RVV) && (HWY_STATIC_TARGET != HWY_RVV) +#undef HWY_TARGET +#define HWY_TARGET HWY_RVV +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_LOONGARCH + +#if (HWY_TARGETS & HWY_LSX) && (HWY_STATIC_TARGET != HWY_LSX) +#undef HWY_TARGET +#define HWY_TARGET HWY_LSX +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_LASX) && (HWY_STATIC_TARGET != HWY_LASX) +#undef HWY_TARGET +#define HWY_TARGET HWY_LASX +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ Scalar + +#if (HWY_TARGETS & HWY_EMU128) && (HWY_STATIC_TARGET != HWY_EMU128) +#undef HWY_TARGET +#define HWY_TARGET HWY_EMU128 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SCALAR) && (HWY_STATIC_TARGET != HWY_SCALAR) +#undef HWY_TARGET +#define HWY_TARGET HWY_SCALAR +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#endif // !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET) + +// Now that all but the static target have been generated, re-enable HWY_EXPORT. +#undef HWY_ONCE +#define HWY_ONCE 1 + +// If we re-include once per enabled target, the translation unit's +// implementation would have to be skipped via #if to avoid redefining symbols. +// We instead skip the re-include for HWY_STATIC_TARGET, and generate its +// implementation when resuming compilation of the translation unit. +#undef HWY_TARGET +#define HWY_TARGET HWY_STATIC_TARGET + +#ifdef HWY_ALREADY_INCLUDED +// Revert the previous toggle to prevent redefinitions for the static target. +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif + +// Force re-inclusion of set_macros-inl.h now that HWY_TARGET is restored. +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif +#endif + +#endif // HIGHWAY_HWY_FOREACH_TARGET_H_ diff --git a/third_party/aom/third_party/highway/hwy/highway.h b/third_party/aom/third_party/highway/hwy/highway.h new file mode 100644 index 000000000000..a50d9a271f9d --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/highway.h @@ -0,0 +1,642 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Main header required before using vector types. + +// IWYU pragma: begin_exports +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/detect_compiler_arch.h" +#include "third_party/highway/hwy/detect_targets.h" +#include "third_party/highway/hwy/highway_export.h" +#include "third_party/highway/hwy/targets.h" +// IWYU pragma: end_exports + +#if HWY_CXX_LANG < 201703L +#define HWY_DISPATCH_MAP 1 +#else +#define HWY_DISPATCH_MAP 0 +#endif + +// This include guard is checked by foreach_target, so avoid the usual _H_ +// suffix to prevent copybara from renaming it. NOTE: ops/*-inl.h are included +// after/outside this include guard. +#ifndef HWY_HIGHWAY_INCLUDED +#define HWY_HIGHWAY_INCLUDED + +namespace hwy { + +//------------------------------------------------------------------------------ +// Shorthand for tags (defined in shared-inl.h) used to select overloads. +// Note that ScalableTag is preferred over HWY_FULL, and CappedTag over +// HWY_CAPPED(T, N). + +// HWY_FULL(T[,LMUL=1]) is a native vector/group. LMUL is the number of +// registers in the group, and is ignored on targets that do not support groups. +#define HWY_FULL1(T) hwy::HWY_NAMESPACE::ScalableTag +#define HWY_FULL2(T, LMUL) \ + hwy::HWY_NAMESPACE::ScalableTag +#define HWY_3TH_ARG(arg1, arg2, arg3, ...) arg3 +// Workaround for MSVC grouping __VA_ARGS__ into a single argument +#define HWY_FULL_RECOMPOSER(args_with_paren) HWY_3TH_ARG args_with_paren +// Trailing comma avoids -pedantic false alarm +#define HWY_CHOOSE_FULL(...) \ + HWY_FULL_RECOMPOSER((__VA_ARGS__, HWY_FULL2, HWY_FULL1, )) +#define HWY_FULL(...) HWY_CHOOSE_FULL(__VA_ARGS__())(__VA_ARGS__) + +// Vector of up to MAX_N lanes. It's better to use full vectors where possible. +#define HWY_CAPPED(T, MAX_N) hwy::HWY_NAMESPACE::CappedTag + +//------------------------------------------------------------------------------ +// Export user functions for static/dynamic dispatch + +// Evaluates to 0 inside a translation unit if it is generating anything but the +// static target (the last one if multiple targets are enabled). Used to prevent +// redefinitions of HWY_EXPORT. Unless foreach_target.h is included, we only +// compile once anyway, so this is 1 unless it is or has been included. +#ifndef HWY_ONCE +#define HWY_ONCE 1 +#endif + +// HWY_STATIC_DISPATCH(FUNC_NAME) is the namespace-qualified FUNC_NAME for +// HWY_STATIC_TARGET (the only defined namespace unless HWY_TARGET_INCLUDE is +// defined), and can be used to deduce the return type of Choose*. 
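Putting the tag shorthands and HWY_STATIC_DISPATCH (whose per-target expansions are defined just below) together, a static-dispatch-only sketch with invented namespace and function names:

#include <cstdio>
#include "third_party/highway/hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace demo {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

void PrintLaneCounts() {
  const hn::ScalableTag<float> df;    // preferred over HWY_FULL(float)
  const hn::CappedTag<float, 4> df4;  // at most 4 lanes, preferred over HWY_CAPPED(float, 4)
  std::printf("full: %zu lanes, capped: %zu lanes\n", hn::Lanes(df),
              hn::Lanes(df4));
}

}  // namespace HWY_NAMESPACE
}  // namespace demo
HWY_AFTER_NAMESPACE();

int main() {
  // With static dispatch, the target namespace is spelled via the macro, e.g.
  // demo::N_AVX2::PrintLaneCounts on an AVX2-baseline build.
  demo::HWY_STATIC_DISPATCH(PrintLaneCounts)();
  return 0;
}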
+#if HWY_STATIC_TARGET == HWY_SCALAR +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SCALAR::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_EMU128 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_EMU128::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_WASM +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_WASM_EMU256 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM_EMU256::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_Z14 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_Z14::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_Z15 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_Z15::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_PPC8 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC8::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_PPC9 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC9::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_PPC10 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC10::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_LSX +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_LSX::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_LASX +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_LASX::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_RVV +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_RVV::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_NEON_WITHOUT_AES +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON_WITHOUT_AES::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_NEON +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_NEON_BF16 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON_BF16::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE2 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE2::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE_256 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE_256::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE2_128 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE2_128::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SSE2 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSE2::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SSSE3 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSSE3::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SSE4 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSE4::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX2 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX2::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX3 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX3_DL +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_DL::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX3_ZEN4 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_ZEN4::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX10_2 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX10_2::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX3_SPR +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_SPR::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX10_2_512 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX10_2_512::FUNC_NAME +#endif + +// HWY_CHOOSE_*(FUNC_NAME) expands to the function pointer for that target or +// nullptr is that target was not compiled. +#if HWY_TARGETS & HWY_EMU128 +#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_EMU128::FUNC_NAME +#elif HWY_TARGETS & HWY_SCALAR +#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_SCALAR::FUNC_NAME +#else +// When HWY_SCALAR/HWY_EMU128 are not present and other targets were disabled at +// runtime, fall back to the baseline with HWY_STATIC_DISPATCH(). 
+#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME) +#endif + +#if HWY_TARGETS & HWY_WASM +#define HWY_CHOOSE_WASM(FUNC_NAME) &N_WASM::FUNC_NAME +#else +#define HWY_CHOOSE_WASM(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_WASM_EMU256 +#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) &N_WASM_EMU256::FUNC_NAME +#else +#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_Z14 +#define HWY_CHOOSE_Z14(FUNC_NAME) &N_Z14::FUNC_NAME +#else +#define HWY_CHOOSE_Z14(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_Z15 +#define HWY_CHOOSE_Z15(FUNC_NAME) &N_Z15::FUNC_NAME +#else +#define HWY_CHOOSE_Z15(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_PPC8 +#define HWY_CHOOSE_PPC8(FUNC_NAME) &N_PPC8::FUNC_NAME +#else +#define HWY_CHOOSE_PPC8(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_PPC9 +#define HWY_CHOOSE_PPC9(FUNC_NAME) &N_PPC9::FUNC_NAME +#else +#define HWY_CHOOSE_PPC9(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_LSX +#define HWY_CHOOSE_LSX(FUNC_NAME) &N_LSX::FUNC_NAME +#else +#define HWY_CHOOSE_LSX(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_LASX +#define HWY_CHOOSE_LASX(FUNC_NAME) &N_LASX::FUNC_NAME +#else +#define HWY_CHOOSE_LASX(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_PPC10 +#define HWY_CHOOSE_PPC10(FUNC_NAME) &N_PPC10::FUNC_NAME +#else +#define HWY_CHOOSE_PPC10(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_RVV +#define HWY_CHOOSE_RVV(FUNC_NAME) &N_RVV::FUNC_NAME +#else +#define HWY_CHOOSE_RVV(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_NEON_WITHOUT_AES +#define HWY_CHOOSE_NEON_WITHOUT_AES(FUNC_NAME) &N_NEON_WITHOUT_AES::FUNC_NAME +#else +#define HWY_CHOOSE_NEON_WITHOUT_AES(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_NEON +#define HWY_CHOOSE_NEON(FUNC_NAME) &N_NEON::FUNC_NAME +#else +#define HWY_CHOOSE_NEON(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_NEON_BF16 +#define HWY_CHOOSE_NEON_BF16(FUNC_NAME) &N_NEON_BF16::FUNC_NAME +#else +#define HWY_CHOOSE_NEON_BF16(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE +#define HWY_CHOOSE_SVE(FUNC_NAME) &N_SVE::FUNC_NAME +#else +#define HWY_CHOOSE_SVE(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE2 +#define HWY_CHOOSE_SVE2(FUNC_NAME) &N_SVE2::FUNC_NAME +#else +#define HWY_CHOOSE_SVE2(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE_256 +#define HWY_CHOOSE_SVE_256(FUNC_NAME) &N_SVE_256::FUNC_NAME +#else +#define HWY_CHOOSE_SVE_256(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE2_128 +#define HWY_CHOOSE_SVE2_128(FUNC_NAME) &N_SVE2_128::FUNC_NAME +#else +#define HWY_CHOOSE_SVE2_128(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SSE2 +#define HWY_CHOOSE_SSE2(FUNC_NAME) &N_SSE2::FUNC_NAME +#else +#define HWY_CHOOSE_SSE2(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SSSE3 +#define HWY_CHOOSE_SSSE3(FUNC_NAME) &N_SSSE3::FUNC_NAME +#else +#define HWY_CHOOSE_SSSE3(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SSE4 +#define HWY_CHOOSE_SSE4(FUNC_NAME) &N_SSE4::FUNC_NAME +#else +#define HWY_CHOOSE_SSE4(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX2 +#define HWY_CHOOSE_AVX2(FUNC_NAME) &N_AVX2::FUNC_NAME +#else +#define HWY_CHOOSE_AVX2(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX3 +#define HWY_CHOOSE_AVX3(FUNC_NAME) &N_AVX3::FUNC_NAME +#else +#define HWY_CHOOSE_AVX3(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX3_DL +#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) &N_AVX3_DL::FUNC_NAME +#else +#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX3_ZEN4 
+#define HWY_CHOOSE_AVX3_ZEN4(FUNC_NAME) &N_AVX3_ZEN4::FUNC_NAME +#else +#define HWY_CHOOSE_AVX3_ZEN4(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX10_2 +#define HWY_CHOOSE_AVX10_2(FUNC_NAME) &N_AVX10_2::FUNC_NAME +#else +#define HWY_CHOOSE_AVX10_2(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX3_SPR +#define HWY_CHOOSE_AVX3_SPR(FUNC_NAME) &N_AVX3_SPR::FUNC_NAME +#else +#define HWY_CHOOSE_AVX3_SPR(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX10_2_512 +#define HWY_CHOOSE_AVX10_2_512(FUNC_NAME) &N_AVX10_2_512::FUNC_NAME +#else +#define HWY_CHOOSE_AVX10_2_512(FUNC_NAME) nullptr +#endif + +// MSVC 2017 workaround: the non-type template parameter to ChooseAndCall +// apparently cannot be an array. Use a function pointer instead, which has the +// disadvantage that we call the static (not best) target on the first call to +// any HWY_DYNAMIC_DISPATCH. +#if (HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1915) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700) +#define HWY_DISPATCH_WORKAROUND 1 +#else +#define HWY_DISPATCH_WORKAROUND 0 +#endif + +#if HWY_DISPATCH_MAP +struct AllExports { + template + static const FuncPtr*& GetRefToExportsPtr() { + static const FuncPtr* s_exports = nullptr; + return s_exports; + } +}; +#endif + +// Provides a static member function which is what is called during the first +// HWY_DYNAMIC_DISPATCH, where GetIndex is still zero, and instantiations of +// this function are the first entry in the tables created by HWY_EXPORT[_T]. +template +struct FunctionCache { + public: + typedef RetType(FuncType)(Args...); + using FuncPtr = FuncType*; + + // A template function that when instantiated has the same signature as the + // function being called. This function initializes the bit array of targets + // supported by the current CPU and then calls the appropriate entry within + // the HWY_EXPORT table. Subsequent calls via HWY_DYNAMIC_DISPATCH to any + // exported functions, even those defined by different translation units, + // will dispatch directly to the best available target. +#if HWY_DISPATCH_MAP + template + static RetType ChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + + const FuncPtr* table = AllExports::template GetRefToExportsPtr< + FuncPtr, RemoveCvRef, kHash>(); + HWY_ASSERT(table); + + return (table[chosen_target.GetIndex()])(args...); + } + +#if !HWY_DISPATCH_WORKAROUND + template + static RetType TableChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + return (table[chosen_target.GetIndex()])(args...); + } +#endif // !HWY_DISPATCH_WORKAROUND + +#else // !HWY_DISPATCH_MAP: zero-overhead, but requires C++17 + template + static RetType ChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + return (table[chosen_target.GetIndex()])(args...); + } +#endif // HWY_DISPATCH_MAP +}; + +// Used to deduce the template parameters RetType and Args from a function. +template +FunctionCache DeduceFunctionCache(RetType (*)(Args...)) { + return FunctionCache(); +} + +#define HWY_DISPATCH_TABLE(FUNC_NAME) \ + HWY_CONCAT(FUNC_NAME, HighwayDispatchTable) + +// HWY_EXPORT(FUNC_NAME); expands to a static array that is used by +// HWY_DYNAMIC_DISPATCH() to call the appropriate function at runtime. 
+// After being exported, it can be called from other parts of the same source +// file using HWY_DYNAMIC_DISPATCH(), in particular from a function wrapper +// like in the following example: +// +// #include "third_party/highway/hwy/highway.h" +// HWY_BEFORE_NAMESPACE(); +// namespace skeleton { +// namespace HWY_NAMESPACE { +// +// void MyFunction(int a, char b, const char* c) { ... } +// +// // NOLINTNEXTLINE(google-readability-namespace-comments) +// } // namespace HWY_NAMESPACE +// } // namespace skeleton +// HWY_AFTER_NAMESPACE(); +// +// namespace skeleton { +// HWY_EXPORT(MyFunction); // Defines the dispatch table in this scope. +// +// void MyFunction(int a, char b, const char* c) { +// return HWY_DYNAMIC_DISPATCH(MyFunction)(a, b, c); +// } +// } // namespace skeleton +// +// For templated code with a single type parameter, instead use HWY_EXPORT_T and +// its HWY_DYNAMIC_DISPATCH_T counterpart: +// +// template +// void MyFunctionCaller(T ...) { +// // First argument to both HWY_EXPORT_T and HWY_DYNAMIC_DISPATCH_T is an +// // arbitrary table name; you must provide the same name for each call. +// // It is fine to have multiple HWY_EXPORT_T in a function, but a 64-bit +// // FNV hash collision among *any* table names will trigger HWY_ABORT. +// HWY_EXPORT_T(Table1, MyFunction) +// HWY_DYNAMIC_DISPATCH_T(Table1)(a, b, c); +// } +// +// Note that HWY_EXPORT_T must be invoked inside a template (in the above +// example: `MyFunctionCaller`), so that a separate table will be created for +// each template instantiation. For convenience, we also provide a macro that +// combines both steps and avoids the need to pick a table name: +// +// template +// void MyFunctionCaller(T ...) { +// // Table name is automatically chosen. Note that this variant must be +// // called in statement context; it is not a valid expression. +// HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(MyFunction)(a, b, c); +// } + +// Simplified version for IDE or the dynamic dispatch case with only one target. +#if HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) + +// We use a table to provide the same compile error conditions as with the +// non-simplified case, but the table only has a single entry. +#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \ + HWY_MAYBE_UNUSED static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const \ + HWY_DISPATCH_TABLE(TABLE_NAME)[1] = {&HWY_STATIC_DISPATCH(FUNC_NAME)} + +// Use the table, not just STATIC_DISPATCH as in DYNAMIC_DISPATCH, because +// TABLE_NAME might not match the function name. +#define HWY_DYNAMIC_POINTER_T(TABLE_NAME) (HWY_DISPATCH_TABLE(TABLE_NAME)[0]) +#define HWY_DYNAMIC_DISPATCH_T(TABLE_NAME) \ + (*(HWY_DYNAMIC_POINTER_T(TABLE_NAME))) + +#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME) +#define HWY_DYNAMIC_POINTER(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME) +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) HWY_STATIC_DISPATCH(FUNC_NAME) + +#else // not simplified: full table + +// Pre-C++17 workaround: non-type template arguments must have linkage, which +// means we cannot pass &table as a template argument to ChooseAndCall. +// ChooseAndCall must find a way to access the table in order to dispatch to the +// chosen target: +// 0) Skipping this by dispatching to the static target would be surprising to +// users and may have serious performance implications. +// 1) An extra function parameter would be unacceptable because it changes the +// user-visible function signature. 
+// 2) Declaring a table, then defining a pointer to it would work, but requires +// an additional DECLARE step outside the function so that the pointer has +// linkage, which breaks existing code. +// 3) We instead associate the function with the table using an instance of an +// unnamed struct and the hash of the table name as the key. Because +// ChooseAndCall has the type information, it can then cast to the function +// pointer type. However, we cannot simply pass the name as a template +// argument to ChooseAndCall because this requires char*, which hits the same +// linkage problem. We instead hash the table name, which assumes the +// function names do not have collisions. +#if HWY_DISPATCH_MAP + +static constexpr uint64_t FNV(const char* name) { + return *name ? static_cast(static_cast(*name)) ^ + (0x100000001b3ULL * FNV(name + 1)) + : 0xcbf29ce484222325ULL; +} + +template +struct AddExport { + template + AddExport(ExportsKey /*exports_key*/, const char* table_name, + const FuncPtr* table) { + using FuncCache = decltype(DeduceFunctionCache(hwy::DeclVal())); + static_assert( + hwy::IsSame, typename FuncCache::FuncPtr>(), + "FuncPtr should be same type as FuncCache::FuncPtr"); + + const FuncPtr*& exports_ptr = AllExports::template GetRefToExportsPtr< + RemoveCvRef, RemoveCvRef, kHash>(); + if (exports_ptr && exports_ptr != table) { + HWY_ABORT("Hash collision for %s, rename the function\n", table_name); + } else { + exports_ptr = table; + } + } +}; + +// Dynamic dispatch: defines table of function pointers. This must be invoked +// from inside the function template that calls the template we are exporting. +// TABLE_NAME must match the one passed to HWY_DYNAMIC_DISPATCH_T. This +// argument allows multiple exports within one function. +#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \ + static const struct { \ + } HWY_CONCAT(TABLE_NAME, HighwayDispatchExportsKey) = {}; \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + TABLE_NAME)[static_cast(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the appropriate function. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \ + template ChooseAndCall, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + }; \ + HWY_MAYBE_UNUSED static hwy::AddExport HWY_CONCAT( \ + HighwayAddTable, __LINE__)( \ + HWY_CONCAT(TABLE_NAME, HighwayDispatchExportsKey), #TABLE_NAME, \ + HWY_DISPATCH_TABLE(TABLE_NAME)) + +// For non-template functions. Not necessarily invoked within a function, hence +// we derive the string and variable names from FUNC_NAME, not HWY_FUNCTION. +#if HWY_DISPATCH_WORKAROUND +#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME) +#else +#define HWY_EXPORT(FUNC_NAME) \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + FUNC_NAME)[static_cast(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the appropriate function. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \ + template TableChooseAndCall, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + } +#endif // HWY_DISPATCH_WORKAROUND + +#else // !HWY_DISPATCH_MAP + +// Zero-overhead, but requires C++17 for non-type template arguments without +// linkage, because HWY_EXPORT_T tables are local static variables. 
+#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + TABLE_NAME)[static_cast(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the appropriate function. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \ + template ChooseAndCall, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + } + +#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME) + +#endif // HWY_DISPATCH_MAP + +// HWY_DISPATCH_MAP only affects how tables are created, not their usage. + +// Evaluates to the function pointer for the chosen target. +#define HWY_DYNAMIC_POINTER(FUNC_NAME) \ + (HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::GetChosenTarget().GetIndex()]) + +// Calls the function pointer for the chosen target. +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) (*(HWY_DYNAMIC_POINTER(FUNC_NAME))) + +// Same as DISPATCH, but provide a different arg name to clarify usage. +#define HWY_DYNAMIC_DISPATCH_T(TABLE_NAME) HWY_DYNAMIC_DISPATCH(TABLE_NAME) +#define HWY_DYNAMIC_POINTER_T(TABLE_NAME) HWY_DYNAMIC_POINTER(TABLE_NAME) + +#endif // HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) + +// Returns the name of an anonymous dispatch table that is only shared with +// macro invocations coming from the same source line. +#define HWY_DISPATCH_TABLE_T() HWY_CONCAT(HighwayDispatchTableT, __LINE__) + +// For templated code, combines export and dispatch using an anonymous table. +#define HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC_NAME) \ + HWY_EXPORT_T(HWY_DISPATCH_TABLE_T(), FUNC_NAME); \ + HWY_DYNAMIC_DISPATCH_T(HWY_DISPATCH_TABLE_T()) + +// DEPRECATED names; please use HWY_HAVE_* instead. +#define HWY_CAP_INTEGER64 HWY_HAVE_INTEGER64 +#define HWY_CAP_FLOAT16 HWY_HAVE_FLOAT16 +#define HWY_CAP_FLOAT64 HWY_HAVE_FLOAT64 + +} // namespace hwy + +#endif // HWY_HIGHWAY_INCLUDED + +//------------------------------------------------------------------------------ + +// NOTE: the following definitions and ops/*.h depend on HWY_TARGET, so we want +// to include them once per target, which is ensured by the toggle check. +// Because ops/*.h are included under it, they do not need their own guard. +#if defined(HWY_HIGHWAY_PER_TARGET) == defined(HWY_TARGET_TOGGLE) +#ifdef HWY_HIGHWAY_PER_TARGET +#undef HWY_HIGHWAY_PER_TARGET +#else +#define HWY_HIGHWAY_PER_TARGET +#endif + +// These define ops inside namespace hwy::HWY_NAMESPACE. 
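A hedged sketch of the templated-dispatch pattern documented above, continuing the hypothetical my_module.cc from the earlier sketch (the single template parameter is the lane type; all names are invented):

// (continuing the hypothetical my_module.cc)
HWY_BEFORE_NAMESPACE();
namespace my_project {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

template <typename T>
void FillImpl(T* HWY_RESTRICT out, size_t count, T value) {
  const hn::ScalableTag<T> d;
  const auto v = hn::Set(d, value);
  for (size_t i = 0; i < count; i += hn::Lanes(d)) {
    hn::Store(v, d, out + i);  // count assumed to be a multiple of Lanes(d)
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace my_project
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace my_project {

template <typename T>
void Fill(T* HWY_RESTRICT out, size_t count, T value) {
  // One dispatch table per instantiation of Fill<T>; the table name is
  // arbitrary but must match between the two macros.
  HWY_EXPORT_T(FillTable, FillImpl<T>);
  HWY_DYNAMIC_DISPATCH_T(FillTable)(out, count, value);
}

}  // namespace my_project
#endif  // HWY_ONCE

The hn:: operations used in such code come from the per-target ops headers included just below.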
+#if HWY_TARGET == HWY_SSE2 || HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 +#include "third_party/highway/hwy/ops/x86_128-inl.h" +#elif HWY_TARGET == HWY_AVX2 +#include "third_party/highway/hwy/ops/x86_256-inl.h" +#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL || \ + HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX10_2 || \ + HWY_TARGET == HWY_AVX3_SPR || HWY_TARGET == HWY_AVX10_2_512 +#include "third_party/highway/hwy/ops/x86_avx3-inl.h" +#elif HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 || \ + (HWY_TARGET & HWY_ALL_PPC) +#include "third_party/highway/hwy/ops/ppc_vsx-inl.h" +#elif HWY_TARGET & HWY_ALL_NEON +#include "third_party/highway/hwy/ops/arm_neon-inl.h" +#elif HWY_TARGET & HWY_ALL_SVE +#include "third_party/highway/hwy/ops/arm_sve-inl.h" +#elif HWY_TARGET == HWY_WASM_EMU256 +#include "third_party/highway/hwy/ops/wasm_256-inl.h" +#elif HWY_TARGET == HWY_WASM +#include "third_party/highway/hwy/ops/wasm_128-inl.h" +#elif HWY_TARGET == HWY_RVV +#include "third_party/highway/hwy/ops/rvv-inl.h" +#elif HWY_TARGET == HWY_EMU128 +#include "third_party/highway/hwy/ops/emu128-inl.h" +#elif HWY_TARGET == HWY_SCALAR +#include "third_party/highway/hwy/ops/scalar-inl.h" +#elif HWY_TARGET == HWY_LSX || HWY_TARGET == HWY_LASX +#include "third_party/highway/hwy/ops/loongarch_lsx-inl.h" +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +#include "third_party/highway/hwy/ops/generic_ops-inl.h" + +#endif // HWY_HIGHWAY_PER_TARGET diff --git a/third_party/aom/third_party/highway/hwy/highway_export.h b/third_party/aom/third_party/highway/hwy/highway_export.h new file mode 100644 index 000000000000..30edc17d0132 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/highway_export.h @@ -0,0 +1,74 @@ +// Pseudo-generated file to handle both cmake & bazel build system. 
+ +// Initial generation done using cmake code: +// include(GenerateExportHeader) +// generate_export_header(hwy EXPORT_MACRO_NAME HWY_DLLEXPORT EXPORT_FILE_NAME +// hwy/highway_export.h) +// code reformatted using clang-format --style=Google + +#ifndef HWY_DLLEXPORT_H +#define HWY_DLLEXPORT_H + +#if !defined(HWY_SHARED_DEFINE) +#define HWY_DLLEXPORT +#define HWY_CONTRIB_DLLEXPORT +#define HWY_TEST_DLLEXPORT +#else // !HWY_SHARED_DEFINE + +#ifndef HWY_DLLEXPORT +#if defined(hwy_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_DLLEXPORT __declspec(dllexport) +#else +#define HWY_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_DLLEXPORT __declspec(dllimport) +#else +#define HWY_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_EXPORTS) +#endif // HWY_DLLEXPORT + +#ifndef HWY_CONTRIB_DLLEXPORT +#if defined(hwy_contrib_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_CONTRIB_DLLEXPORT __declspec(dllexport) +#else +#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_contrib_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_CONTRIB_DLLEXPORT __declspec(dllimport) +#else +#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_contrib_EXPORTS) +#endif // HWY_CONTRIB_DLLEXPORT + +#ifndef HWY_TEST_DLLEXPORT +#if defined(hwy_test_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_TEST_DLLEXPORT __declspec(dllexport) +#else +#define HWY_TEST_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_test_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_TEST_DLLEXPORT __declspec(dllimport) +#else +#define HWY_TEST_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_test_EXPORTS) +#endif // HWY_TEST_DLLEXPORT + +#endif // !HWY_SHARED_DEFINE + +#endif /* HWY_DLLEXPORT_H */ diff --git a/third_party/aom/third_party/highway/hwy/hwy.version b/third_party/aom/third_party/highway/hwy/hwy.version new file mode 100644 index 000000000000..9ff6be6a2d72 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/hwy.version @@ -0,0 +1,19 @@ +HWY_0 { + global: + extern "C++" { + *hwy::*; + }; + + local: + # Hide all the std namespace symbols. std namespace is explicitly marked + # as visibility(default) and header-only functions or methods (such as those + # from templates) should be exposed in shared libraries as weak symbols but + # this is only needed when we expose those types in the shared library API + # in any way. We don't use C++ std types in the API and we also don't + # support exceptions in the library. + # See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=36022 for a discussion + # about this. + extern "C++" { + *std::*; + }; +}; diff --git a/third_party/aom/third_party/highway/hwy/nanobenchmark.h b/third_party/aom/third_party/highway/hwy/nanobenchmark.h new file mode 100644 index 000000000000..9001767a51b5 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/nanobenchmark.h @@ -0,0 +1,153 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_NANOBENCHMARK_H_ +#define HIGHWAY_HWY_NANOBENCHMARK_H_ + +// Benchmarks functions of a single integer argument with realistic branch +// prediction hit rates. Uses a robust estimator to summarize the measurements. +// The precision is about 0.2%. +// +// Examples: see nanobenchmark_test.cc. +// +// Background: Microbenchmarks such as http://github.com/google/benchmark +// can measure elapsed times on the order of a microsecond. Shorter functions +// are typically measured by repeating them thousands of times and dividing +// the total elapsed time by this count. Unfortunately, repetition (especially +// with the same input parameter!) influences the runtime. In time-critical +// code, it is reasonable to expect warm instruction/data caches and TLBs, +// but a perfect record of which branches will be taken is unrealistic. +// Unless the application also repeatedly invokes the measured function with +// the same parameter, the benchmark is measuring something very different - +// a best-case result, almost as if the parameter were made a compile-time +// constant. This may lead to erroneous conclusions about branch-heavy +// algorithms outperforming branch-free alternatives. +// +// Our approach differs in three ways. Adding fences to the timer functions +// reduces variability due to instruction reordering, improving the timer +// resolution to about 40 CPU cycles. However, shorter functions must still +// be invoked repeatedly. For more realistic branch prediction performance, +// we vary the input parameter according to a user-specified distribution. +// Thus, instead of VaryInputs(Measure(Repeat(func))), we change the +// loop nesting to Measure(Repeat(VaryInputs(func))). We also estimate the +// central tendency of the measurement samples with the "half sample mode", +// which is more robust to outliers and skewed data than the mean or median. + +#include +#include + +#include "third_party/highway/hwy/highway_export.h" +#include "third_party/highway/hwy/timer.h" // IWYU pragma: export + +namespace hwy { + +// Returns 1, but without the compiler knowing what the value is. This prevents +// optimizing out code. +HWY_DLLEXPORT int Unpredictable1(); + +// Input influencing the function being measured (e.g. number of bytes to copy). +using FuncInput = size_t; + +// "Proof of work" returned by Func to ensure the compiler does not elide it. +using FuncOutput = uint64_t; + +// Function to measure: either 1) a captureless lambda or function with two +// arguments or 2) a lambda with capture, in which case the first argument +// is reserved for use by MeasureClosure. +using Func = FuncOutput (*)(const void*, FuncInput); + +// Internal parameters that determine precision/resolution/measuring time. +struct Params { + // Best-case precision, expressed as a divisor of the timer resolution. + // Larger => more calls to Func and higher precision. + size_t precision_divisor = 1024; + + // Ratio between full and subset input distribution sizes. 
Cannot be less + // than 2; larger values increase measurement time but more faithfully + // model the given input distribution. + size_t subset_ratio = 2; + + // Together with the estimated Func duration, determines how many times to + // call Func before checking the sample variability. Larger values increase + // measurement time, memory/cache use and precision. + double seconds_per_eval = 4E-3; + + // The minimum number of samples before estimating the central tendency. + size_t min_samples_per_eval = 7; + + // The mode is better than median for estimating the central tendency of + // skewed/fat-tailed distributions, but it requires sufficient samples + // relative to the width of half-ranges. + size_t min_mode_samples = 64; + + // Maximum permissible variability (= median absolute deviation / center). + double target_rel_mad = 0.002; + + // Abort after this many evals without reaching target_rel_mad. This + // prevents infinite loops. + size_t max_evals = 9; + + // Whether to print additional statistics to stdout. + bool verbose = true; +}; + +// Measurement result for each unique input. +struct Result { + FuncInput input; + + // Robust estimate (mode or median) of duration. + float ticks; + + // Measure of variability (median absolute deviation relative to "ticks"). + float variability; +}; + +// Precisely measures the number of ticks elapsed when calling "func" with the +// given inputs, shuffled to ensure realistic branch prediction hit rates. +// +// "func" returns a 'proof of work' to ensure its computations are not elided. +// "arg" is passed to Func, or reserved for internal use by MeasureClosure. +// "inputs" is an array of "num_inputs" (not necessarily unique) arguments to +// "func". The values should be chosen to maximize coverage of "func". This +// represents a distribution, so a value's frequency should reflect its +// probability in the real application. Order does not matter; for example, a +// uniform distribution over [0, 4) could be represented as {3,0,2,1}. +// Returns how many Result were written to "results": one per unique input, or +// zero if the measurement failed (an error message goes to stderr). +HWY_DLLEXPORT size_t Measure(Func func, const uint8_t* arg, + const FuncInput* inputs, size_t num_inputs, + Result* results, const Params& p = Params()); + +// Calls operator() of the given closure (lambda function). +template +static FuncOutput CallClosure(const Closure* f, const FuncInput input) { + return (*f)(input); +} + +// Same as Measure, except "closure" is typically a lambda function of +// FuncInput -> FuncOutput with a capture list. 
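For reference, a minimal sketch of driving Measure(); the Div and MeasureDiv names are illustrative (the pattern mirrors nanobenchmark_test.cc) and the parameter choices are arbitrary.

#include <stdio.h>

#include "third_party/highway/hwy/nanobenchmark.h"

// Function under test; returns a "proof of work" so it cannot be elided.
hwy::FuncOutput Div(const void* /*arg*/, hwy::FuncInput in) {
  return 0xFFFFF / in;
}

void MeasureDiv() {
  const hwy::FuncInput inputs[] = {10, 100, 1000};  // the input distribution
  hwy::Result results[3];
  hwy::Params params;
  params.max_evals = 4;  // trade some precision for shorter measurement time
  const size_t num_results =
      hwy::Measure(&Div, /*arg=*/nullptr, inputs, 3, results, params);
  for (size_t i = 0; i < num_results; ++i) {
    printf("%5d: %6.2f ticks; MAD=%4.2f%%\n",
           static_cast<int>(results[i].input),
           static_cast<double>(results[i].ticks),
           static_cast<double>(results[i].variability) * 100.0);
  }
}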
+template +static inline size_t MeasureClosure(const Closure& closure, + const FuncInput* inputs, + const size_t num_inputs, Result* results, + const Params& p = Params()) { + return Measure(reinterpret_cast(&CallClosure), + reinterpret_cast(&closure), inputs, num_inputs, + results, p); +} + +} // namespace hwy + +#endif // HIGHWAY_HWY_NANOBENCHMARK_H_ diff --git a/third_party/aom/third_party/highway/hwy/ops/arm_neon-inl.h b/third_party/aom/third_party/highway/hwy/ops/arm_neon-inl.h new file mode 100644 index 000000000000..f7e587eb3f4d --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/arm_neon-inl.h @@ -0,0 +1,10469 @@ +// Copyright 2019 Google LLC +// Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit Arm NEON vectors and operations. +// External include guard in highway.h - see comment there. + +// Arm NEON intrinsics are documented at: +// https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon] + +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/ops/shared-inl.h" + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +#include // NOLINT(build/include_order) +HWY_DIAGNOSTICS(pop) + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace detail { // for code folding and Raw128 + +// Macros used to define single and double function calls for multiple types +// for full and half vectors. These macros are undefined at the end of the file. + +// HWY_NEON_BUILD_TPL_* is the template<...> prefix to the function. +#define HWY_NEON_BUILD_TPL_1 +#define HWY_NEON_BUILD_TPL_2 +#define HWY_NEON_BUILD_TPL_3 + +// HWY_NEON_BUILD_RET_* is return type; type arg is without _t suffix so we can +// extend it to int32x4x2_t packs. +#define HWY_NEON_BUILD_RET_1(type, size) Vec128 +#define HWY_NEON_BUILD_RET_2(type, size) Vec128 +#define HWY_NEON_BUILD_RET_3(type, size) Vec128 + +// HWY_NEON_BUILD_PARAM_* is the list of parameters the function receives. +#define HWY_NEON_BUILD_PARAM_1(type, size) const Vec128 a +#define HWY_NEON_BUILD_PARAM_2(type, size) \ + const Vec128 a, const Vec128 b +#define HWY_NEON_BUILD_PARAM_3(type, size) \ + const Vec128 a, const Vec128 b, \ + const Vec128 c + +// HWY_NEON_BUILD_ARG_* is the list of arguments passed to the underlying +// function. +#define HWY_NEON_BUILD_ARG_1 a.raw +#define HWY_NEON_BUILD_ARG_2 a.raw, b.raw +#define HWY_NEON_BUILD_ARG_3 a.raw, b.raw, c.raw + +// We use HWY_NEON_EVAL(func, ...) to delay the evaluation of func until after +// the __VA_ARGS__ have been expanded. This allows "func" to be a macro on +// itself like with some of the library "functions" such as vshlq_u8. For +// example, HWY_NEON_EVAL(vshlq_u8, MY_PARAMS) where MY_PARAMS is defined as +// "a, b" (without the quotes) will end up expanding "vshlq_u8(a, b)" if needed. 
+// Directly writing vshlq_u8(MY_PARAMS) would fail since vshlq_u8() macro +// expects two arguments. +#define HWY_NEON_EVAL(func, ...) func(__VA_ARGS__) + +// Main macro definition that defines a single function for the given type and +// size of vector, using the underlying (prefix##infix##suffix) function and +// the template, return type, parameters and arguments defined by the "args" +// parameters passed here (see HWY_NEON_BUILD_* macros defined before). +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_CONCAT(HWY_NEON_BUILD_TPL_, args) \ + HWY_API HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size) \ + name(HWY_CONCAT(HWY_NEON_BUILD_PARAM_, args)(type, size)) { \ + return HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size)( \ + HWY_NEON_EVAL(prefix##infix##suffix, HWY_NEON_BUILD_ARG_##args)); \ + } + +// The HWY_NEON_DEF_FUNCTION_* macros define all the variants of a function +// called "name" using the set of neon functions starting with the given +// "prefix" for all the variants of certain types, as specified next to each +// macro. For example, the prefix "vsub" can be used to define the operator- +// using args=2. + +// uint8_t +#define HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 8, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 4, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 2, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 1, name, prefix, infix, u8, args) + +// int8_t +#define HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 8, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 4, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 2, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 1, name, prefix, infix, s8, args) + +// uint16_t +#define HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 4, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 2, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 1, name, prefix, infix, u16, args) + +// int16_t +#define HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 4, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 2, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 1, name, prefix, infix, s16, args) + +// uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 2, name, prefix, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 1, name, prefix, infix, u32, args) + +// int32_t +#define HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 2, name, prefix, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 1, name, prefix, infix, s32, args) + +// uint64_t +#define HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) + +// int64_t +#define 
HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) + +// Clang 17 crashes with bf16, see github.com/llvm/llvm-project/issues/64179. +#undef HWY_NEON_HAVE_BFLOAT16 +#if HWY_HAVE_SCALAR_BF16_TYPE && \ + ((HWY_TARGET == HWY_NEON_BF16 && \ + (!HWY_COMPILER_CLANG || HWY_COMPILER_CLANG >= 1800)) || \ + defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)) +#define HWY_NEON_HAVE_BFLOAT16 1 +#else +#define HWY_NEON_HAVE_BFLOAT16 0 +#endif + +// HWY_NEON_HAVE_F32_TO_BF16C is defined if NEON vcvt_bf16_f32 and +// vbfdot_f32 are available, even if the __bf16 type is disabled due to +// GCC/Clang bugs. +#undef HWY_NEON_HAVE_F32_TO_BF16C +#if HWY_NEON_HAVE_BFLOAT16 || HWY_TARGET == HWY_NEON_BF16 || \ + (defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \ + (HWY_COMPILER_GCC_ACTUAL >= 1000 || HWY_COMPILER_CLANG >= 1100)) +#define HWY_NEON_HAVE_F32_TO_BF16C 1 +#else +#define HWY_NEON_HAVE_F32_TO_BF16C 0 +#endif + +// bfloat16_t +#if HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(bfloat16, 8, name, prefix##q, infix, bf16, args) \ + HWY_NEON_DEF_FUNCTION(bfloat16, 4, name, prefix, infix, bf16, args) \ + HWY_NEON_DEF_FUNCTION(bfloat16, 2, name, prefix, infix, bf16, args) \ + HWY_NEON_DEF_FUNCTION(bfloat16, 1, name, prefix, infix, bf16, args) +#else +#define HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) +#endif + +// Used for conversion instructions if HWY_NEON_HAVE_F16C. +#define HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(name, prefix, infix, \ + args) \ + HWY_NEON_DEF_FUNCTION(float16, 8, name, prefix##q, infix, f16, args) \ + HWY_NEON_DEF_FUNCTION(float16, 4, name, prefix, infix, f16, args) \ + HWY_NEON_DEF_FUNCTION(float16, 2, name, prefix, infix, f16, args) \ + HWY_NEON_DEF_FUNCTION(float16, 1, name, prefix, infix, f16, args) + +// float16_t +#if HWY_HAVE_FLOAT16 +#define HWY_NEON_DEF_FUNCTION_FLOAT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(name, prefix, infix, args) +#else +#define HWY_NEON_DEF_FUNCTION_FLOAT_16(name, prefix, infix, args) +#endif + +// Enable generic functions for whichever of (f16, bf16) are not supported. 
+#if !HWY_HAVE_FLOAT16 && !HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_NEON_IF_NOT_EMULATED_D(D) HWY_IF_NOT_SPECIAL_FLOAT_D(D) +#elif !HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_IF_EMULATED_D(D) HWY_IF_F16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_F16_D(D) +#define HWY_NEON_IF_NOT_EMULATED_D(D) HWY_IF_NOT_F16_D(D) +#elif HWY_HAVE_FLOAT16 && !HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_NEON_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D) +#elif HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_BFLOAT16 +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the D template +// argument +#define HWY_NEON_IF_EMULATED_D(D) hwy::EnableIf()>* = nullptr +#define HWY_GENERIC_IF_EMULATED_D(D) \ + hwy::EnableIf()>* = nullptr +#define HWY_NEON_IF_NOT_EMULATED_D(D) hwy::EnableIf* = nullptr +#else +#error "Logic error, handled all four cases" +#endif + +// float +#define HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float32, 4, name, prefix##q, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 2, name, prefix, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 1, name, prefix, infix, f32, args) + +// double +#if HWY_HAVE_FLOAT64 +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float64, 2, name, prefix##q, infix, f64, args) \ + HWY_NEON_DEF_FUNCTION(float64, 1, name, prefix, infix, f64, args) +#else +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) +#endif + +// Helper macros to define for more than one type. 
+// uint8_t, uint16_t and uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) + +// int8_t, int16_t and int32_t +#define HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) + +// uint8_t, uint16_t, uint32_t and uint64_t +#define HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) + +// int8_t, int16_t, int32_t and int64_t +#define HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) + +// All int*_t and uint*_t up to 64 +#define HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) + +#define HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) + +#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) + +// All previous types. +#define HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) + +#define HWY_NEON_DEF_FUNCTION_UI_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) + +#define HWY_NEON_DEF_FUNCTION_UIF_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UI_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) + +#define HWY_NEON_DEF_FUNCTION_UIF_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) + +// For vzip1/2 +#define HWY_NEON_DEF_FUNCTION_FULL_UI_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \ + HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) +#define HWY_NEON_DEF_FUNCTION_FULL_UIF_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FULL_UI_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float64, 2, name, prefix##q, infix, f64, args) + +// For eor3q, which is only defined for full vectors. 
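To make the macro layering concrete: an invocation such as HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator+, vadd, _, 2), used later in this file, ultimately emits one thin wrapper per type/size combination. The expansion for the full u8 vector looks roughly as follows (template parameters per the upstream source; shown for illustration only).

// HWY_NEON_DEF_FUNCTION(uint8, 16, operator+, vaddq, _, u8, 2) becomes,
// approximately:
HWY_API Vec128<uint8_t, 16> operator+(const Vec128<uint8_t, 16> a,
                                      const Vec128<uint8_t, 16> b) {
  return Vec128<uint8_t, 16>(vaddq_u8(a.raw, b.raw));
}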
+#define HWY_NEON_DEF_FUNCTION_FULL_UI(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION_FULL_UI_64(name, prefix, infix, args) +// Emulation of some intrinsics on armv7. +#if HWY_ARCH_ARM_V7 +#define vuzp1_s8(x, y) vuzp_s8(x, y).val[0] +#define vuzp1_u8(x, y) vuzp_u8(x, y).val[0] +#define vuzp1_s16(x, y) vuzp_s16(x, y).val[0] +#define vuzp1_u16(x, y) vuzp_u16(x, y).val[0] +#define vuzp1_s32(x, y) vuzp_s32(x, y).val[0] +#define vuzp1_u32(x, y) vuzp_u32(x, y).val[0] +#define vuzp1_f32(x, y) vuzp_f32(x, y).val[0] +#define vuzp1q_s8(x, y) vuzpq_s8(x, y).val[0] +#define vuzp1q_u8(x, y) vuzpq_u8(x, y).val[0] +#define vuzp1q_s16(x, y) vuzpq_s16(x, y).val[0] +#define vuzp1q_u16(x, y) vuzpq_u16(x, y).val[0] +#define vuzp1q_s32(x, y) vuzpq_s32(x, y).val[0] +#define vuzp1q_u32(x, y) vuzpq_u32(x, y).val[0] +#define vuzp1q_f32(x, y) vuzpq_f32(x, y).val[0] +#define vuzp2_s8(x, y) vuzp_s8(x, y).val[1] +#define vuzp2_u8(x, y) vuzp_u8(x, y).val[1] +#define vuzp2_s16(x, y) vuzp_s16(x, y).val[1] +#define vuzp2_u16(x, y) vuzp_u16(x, y).val[1] +#define vuzp2_s32(x, y) vuzp_s32(x, y).val[1] +#define vuzp2_u32(x, y) vuzp_u32(x, y).val[1] +#define vuzp2_f32(x, y) vuzp_f32(x, y).val[1] +#define vuzp2q_s8(x, y) vuzpq_s8(x, y).val[1] +#define vuzp2q_u8(x, y) vuzpq_u8(x, y).val[1] +#define vuzp2q_s16(x, y) vuzpq_s16(x, y).val[1] +#define vuzp2q_u16(x, y) vuzpq_u16(x, y).val[1] +#define vuzp2q_s32(x, y) vuzpq_s32(x, y).val[1] +#define vuzp2q_u32(x, y) vuzpq_u32(x, y).val[1] +#define vuzp2q_f32(x, y) vuzpq_f32(x, y).val[1] +#define vzip1_s8(x, y) vzip_s8(x, y).val[0] +#define vzip1_u8(x, y) vzip_u8(x, y).val[0] +#define vzip1_s16(x, y) vzip_s16(x, y).val[0] +#define vzip1_u16(x, y) vzip_u16(x, y).val[0] +#define vzip1_f32(x, y) vzip_f32(x, y).val[0] +#define vzip1_u32(x, y) vzip_u32(x, y).val[0] +#define vzip1_s32(x, y) vzip_s32(x, y).val[0] +#define vzip1q_s8(x, y) vzipq_s8(x, y).val[0] +#define vzip1q_u8(x, y) vzipq_u8(x, y).val[0] +#define vzip1q_s16(x, y) vzipq_s16(x, y).val[0] +#define vzip1q_u16(x, y) vzipq_u16(x, y).val[0] +#define vzip1q_s32(x, y) vzipq_s32(x, y).val[0] +#define vzip1q_u32(x, y) vzipq_u32(x, y).val[0] +#define vzip1q_f32(x, y) vzipq_f32(x, y).val[0] +#define vzip2_s8(x, y) vzip_s8(x, y).val[1] +#define vzip2_u8(x, y) vzip_u8(x, y).val[1] +#define vzip2_s16(x, y) vzip_s16(x, y).val[1] +#define vzip2_u16(x, y) vzip_u16(x, y).val[1] +#define vzip2_s32(x, y) vzip_s32(x, y).val[1] +#define vzip2_u32(x, y) vzip_u32(x, y).val[1] +#define vzip2_f32(x, y) vzip_f32(x, y).val[1] +#define vzip2q_s8(x, y) vzipq_s8(x, y).val[1] +#define vzip2q_u8(x, y) vzipq_u8(x, y).val[1] +#define vzip2q_s16(x, y) vzipq_s16(x, y).val[1] +#define vzip2q_u16(x, y) vzipq_u16(x, y).val[1] +#define vzip2q_s32(x, y) vzipq_s32(x, y).val[1] +#define vzip2q_u32(x, y) vzipq_u32(x, y).val[1] +#define vzip2q_f32(x, y) vzipq_f32(x, y).val[1] +#endif + +// Wrappers over uint8x16x2_t etc. so we can define StoreInterleaved2 +// overloads for all vector types, even those (bfloat16_t) where the +// underlying vector is the same as others (uint16_t). 
+template +struct Tuple2; +template +struct Tuple3; +template +struct Tuple4; + +template <> +struct Tuple2 { + uint8x16x2_t raw; +}; +template +struct Tuple2 { + uint8x8x2_t raw; +}; +template <> +struct Tuple2 { + int8x16x2_t raw; +}; +template +struct Tuple2 { + int8x8x2_t raw; +}; +template <> +struct Tuple2 { + uint16x8x2_t raw; +}; +template +struct Tuple2 { + uint16x4x2_t raw; +}; +template <> +struct Tuple2 { + int16x8x2_t raw; +}; +template +struct Tuple2 { + int16x4x2_t raw; +}; +template <> +struct Tuple2 { + uint32x4x2_t raw; +}; +template +struct Tuple2 { + uint32x2x2_t raw; +}; +template <> +struct Tuple2 { + int32x4x2_t raw; +}; +template +struct Tuple2 { + int32x2x2_t raw; +}; +template <> +struct Tuple2 { + uint64x2x2_t raw; +}; +template +struct Tuple2 { + uint64x1x2_t raw; +}; +template <> +struct Tuple2 { + int64x2x2_t raw; +}; +template +struct Tuple2 { + int64x1x2_t raw; +}; + +template <> +struct Tuple2 { + float32x4x2_t raw; +}; +template +struct Tuple2 { + float32x2x2_t raw; +}; +#if HWY_HAVE_FLOAT64 +template <> +struct Tuple2 { + float64x2x2_t raw; +}; +template +struct Tuple2 { + float64x1x2_t raw; +}; +#endif // HWY_HAVE_FLOAT64 + +template <> +struct Tuple3 { + uint8x16x3_t raw; +}; +template +struct Tuple3 { + uint8x8x3_t raw; +}; +template <> +struct Tuple3 { + int8x16x3_t raw; +}; +template +struct Tuple3 { + int8x8x3_t raw; +}; +template <> +struct Tuple3 { + uint16x8x3_t raw; +}; +template +struct Tuple3 { + uint16x4x3_t raw; +}; +template <> +struct Tuple3 { + int16x8x3_t raw; +}; +template +struct Tuple3 { + int16x4x3_t raw; +}; +template <> +struct Tuple3 { + uint32x4x3_t raw; +}; +template +struct Tuple3 { + uint32x2x3_t raw; +}; +template <> +struct Tuple3 { + int32x4x3_t raw; +}; +template +struct Tuple3 { + int32x2x3_t raw; +}; +template <> +struct Tuple3 { + uint64x2x3_t raw; +}; +template +struct Tuple3 { + uint64x1x3_t raw; +}; +template <> +struct Tuple3 { + int64x2x3_t raw; +}; +template +struct Tuple3 { + int64x1x3_t raw; +}; + +template <> +struct Tuple3 { + float32x4x3_t raw; +}; +template +struct Tuple3 { + float32x2x3_t raw; +}; +#if HWY_HAVE_FLOAT64 +template <> +struct Tuple3 { + float64x2x3_t raw; +}; +template +struct Tuple3 { + float64x1x3_t raw; +}; +#endif // HWY_HAVE_FLOAT64 + +template <> +struct Tuple4 { + uint8x16x4_t raw; +}; +template +struct Tuple4 { + uint8x8x4_t raw; +}; +template <> +struct Tuple4 { + int8x16x4_t raw; +}; +template +struct Tuple4 { + int8x8x4_t raw; +}; +template <> +struct Tuple4 { + uint16x8x4_t raw; +}; +template +struct Tuple4 { + uint16x4x4_t raw; +}; +template <> +struct Tuple4 { + int16x8x4_t raw; +}; +template +struct Tuple4 { + int16x4x4_t raw; +}; +template <> +struct Tuple4 { + uint32x4x4_t raw; +}; +template +struct Tuple4 { + uint32x2x4_t raw; +}; +template <> +struct Tuple4 { + int32x4x4_t raw; +}; +template +struct Tuple4 { + int32x2x4_t raw; +}; +template <> +struct Tuple4 { + uint64x2x4_t raw; +}; +template +struct Tuple4 { + uint64x1x4_t raw; +}; +template <> +struct Tuple4 { + int64x2x4_t raw; +}; +template +struct Tuple4 { + int64x1x4_t raw; +}; + +template <> +struct Tuple4 { + float32x4x4_t raw; +}; +template +struct Tuple4 { + float32x2x4_t raw; +}; +#if HWY_HAVE_FLOAT64 +template <> +struct Tuple4 { + float64x2x4_t raw; +}; +template +struct Tuple4 { + float64x1x4_t raw; +}; +#endif // HWY_HAVE_FLOAT64 + +template +struct Raw128; + +template <> +struct Raw128 { + using type = uint8x16_t; +}; +template +struct Raw128 { + using type = uint8x8_t; +}; + +template <> +struct Raw128 { + 
using type = uint16x8_t; +}; +template +struct Raw128 { + using type = uint16x4_t; +}; + +template <> +struct Raw128 { + using type = uint32x4_t; +}; +template +struct Raw128 { + using type = uint32x2_t; +}; + +template <> +struct Raw128 { + using type = uint64x2_t; +}; +template <> +struct Raw128 { + using type = uint64x1_t; +}; + +template <> +struct Raw128 { + using type = int8x16_t; +}; +template +struct Raw128 { + using type = int8x8_t; +}; + +template <> +struct Raw128 { + using type = int16x8_t; +}; +template +struct Raw128 { + using type = int16x4_t; +}; + +template <> +struct Raw128 { + using type = int32x4_t; +}; +template +struct Raw128 { + using type = int32x2_t; +}; + +template <> +struct Raw128 { + using type = int64x2_t; +}; +template <> +struct Raw128 { + using type = int64x1_t; +}; + +template <> +struct Raw128 { + using type = float32x4_t; +}; +template +struct Raw128 { + using type = float32x2_t; +}; + +#if HWY_HAVE_FLOAT64 +template <> +struct Raw128 { + using type = float64x2_t; +}; +template <> +struct Raw128 { + using type = float64x1_t; +}; +#endif // HWY_HAVE_FLOAT64 + +#if HWY_NEON_HAVE_F16C + +template <> +struct Tuple2 { + float16x8x2_t raw; +}; +template +struct Tuple2 { + float16x4x2_t raw; +}; + +template <> +struct Tuple3 { + float16x8x3_t raw; +}; +template +struct Tuple3 { + float16x4x3_t raw; +}; + +template <> +struct Tuple4 { + float16x8x4_t raw; +}; +template +struct Tuple4 { + float16x4x4_t raw; +}; + +template <> +struct Raw128 { + using type = float16x8_t; +}; +template +struct Raw128 { + using type = float16x4_t; +}; + +#else // !HWY_NEON_HAVE_F16C + +template +struct Tuple2 : public Tuple2 {}; +template +struct Tuple3 : public Tuple3 {}; +template +struct Tuple4 : public Tuple4 {}; +template +struct Raw128 : public Raw128 {}; + +#endif // HWY_NEON_HAVE_F16C + +#if HWY_NEON_HAVE_BFLOAT16 + +template <> +struct Tuple2 { + bfloat16x8x2_t raw; +}; +template +struct Tuple2 { + bfloat16x4x2_t raw; +}; + +template <> +struct Tuple3 { + bfloat16x8x3_t raw; +}; +template +struct Tuple3 { + bfloat16x4x3_t raw; +}; + +template <> +struct Tuple4 { + bfloat16x8x4_t raw; +}; +template +struct Tuple4 { + bfloat16x4x4_t raw; +}; + +template <> +struct Raw128 { + using type = bfloat16x8_t; +}; +template +struct Raw128 { + using type = bfloat16x4_t; +}; + +#else // !HWY_NEON_HAVE_BFLOAT16 + +template +struct Tuple2 : public Tuple2 {}; +template +struct Tuple3 : public Tuple3 {}; +template +struct Tuple4 : public Tuple4 {}; +template +struct Raw128 : public Raw128 {}; + +#endif // HWY_NEON_HAVE_BFLOAT16 + +} // namespace detail + +template +class Vec128 { + public: + using Raw = typename detail::Raw128::type; + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + HWY_INLINE Vec128() {} + Vec128(const Vec128&) = default; + Vec128& operator=(const Vec128&) = default; + HWY_INLINE explicit Vec128(const Raw raw) : raw(raw) {} + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +template +using Vec16 = Vec128; + +// FF..FF or 0. +template +class Mask128 { + public: + // Arm C Language Extensions return and expect unsigned type. + using Raw = typename detail::Raw128, N>::type; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = N; // only for DFromM + + HWY_INLINE Mask128() {} + Mask128(const Mask128&) = default; + Mask128& operator=(const Mask128&) = default; + HWY_INLINE explicit Mask128(const Raw raw) : raw(raw) {} + + Raw raw; +}; + +template +using Mask64 = Mask128; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ Set + +namespace detail { +// We want to route any combination of N/kPow2 to the intrinsics depending on +// whether the requested size is <= 64 bits or 128. HWY_NEON_BUILD_TPL is +// unconditional and currently does not accept inputs (such as whether the +// vector is 64 or 128-bit). Thus we are not able to use HWY_IF_V_SIZE_D for +// SFINAE. We instead define a private NativeSet which receives a Simd<> whose +// kPow2 has already been folded into its N. +#define HWY_NEON_BUILD_TPL_HWY_SET +#define HWY_NEON_BUILD_RET_HWY_SET(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_SET(type, size) \ + Simd /* tag */, type##_t t +#define HWY_NEON_BUILD_ARG_HWY_SET t + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(NativeSet, vdup, _n_, HWY_SET) +#if !HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_F16C +HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(NativeSet, vdup, _n_, HWY_SET) +#endif +HWY_NEON_DEF_FUNCTION_BFLOAT_16(NativeSet, vdup, _n_, HWY_SET) + +template +HWY_API Vec128, MaxLanes(D())> NativeSet(D d, TFromD t) { + const uint16_t tu = BitCastScalar(t); + return Vec128, d.MaxLanes()>(Set(RebindToUnsigned(), tu).raw); +} + +#undef HWY_NEON_BUILD_TPL_HWY_SET +#undef HWY_NEON_BUILD_RET_HWY_SET +#undef HWY_NEON_BUILD_PARAM_HWY_SET +#undef HWY_NEON_BUILD_ARG_HWY_SET + +} // namespace detail + +// Full vector. Cannot yet use VFromD because that is defined in terms of Set. +// Do not use a typename T = TFromD argument because T will be deduced from +// the actual argument type, which can differ from TFromD. +template +HWY_INLINE Vec128> Set(D /* tag */, T t) { + return detail::NativeSet(Full128>(), static_cast>(t)); +} + +// Partial vector: create 64-bit and return wrapper. +template +HWY_API Vec128, MaxLanes(D())> Set(D /* tag */, T t) { + const Full64> dfull; + return Vec128, MaxLanes(D())>( + detail::NativeSet(dfull, static_cast>(t)).raw); +} + +template +using VFromD = decltype(Set(D(), TFromD())); + +template +HWY_API VFromD Zero(D d) { + // Default ctor also works for bfloat16_t and float16_t. 
+ return Set(d, TFromD{}); +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") +#endif + +template +HWY_API VFromD Undefined(D /*tag*/) { +#if HWY_HAS_BUILTIN(__builtin_nondeterministic_value) + return VFromD{__builtin_nondeterministic_value(Zero(D()).raw)}; +#else + VFromD v; + return v; +#endif +} + +HWY_DIAGNOSTICS(pop) + +#if !HWY_COMPILER_GCC && !HWY_COMPILER_CLANGCL +namespace detail { + +#pragma pack(push, 1) + +template +struct alignas(8) Vec64ValsWrapper { + static_assert(sizeof(T) >= 1, "sizeof(T) >= 1 must be true"); + static_assert(sizeof(T) <= 8, "sizeof(T) <= 8 must be true"); + T vals[8 / sizeof(T)]; +}; + +#pragma pack(pop) + +} // namespace detail +#endif // !HWY_COMPILER_GCC && !HWY_COMPILER_CLANGCL + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD /*t8*/, TFromD /*t9*/, + TFromD /*t10*/, TFromD /*t11*/, + TFromD /*t12*/, TFromD /*t13*/, + TFromD /*t14*/, TFromD /*t15*/) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int8_t GccI8RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccI8RawVectType raw = { + static_cast(t0), static_cast(t1), static_cast(t2), + static_cast(t3), static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + return ResizeBitCast( + d, Set(Full64(), + BitCastScalar(detail::Vec64ValsWrapper>{ + {t0, t1, t2, t3, t4, t5, t6, t7}}))); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int16_t GccI16RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccI16RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + return ResizeBitCast( + d, Set(Full64(), + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1, t2, t3}}))); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD /*t2*/, TFromD /*t3*/) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int32_t GccI32RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccI32RawVectType raw = {static_cast(t0), + static_cast(t1)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + return ResizeBitCast(d, + Set(Full64(), + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1}}))); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD /*t2*/, TFromD /*t3*/) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef float GccF32RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccF32RawVectType raw = {t0, t1}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + return ResizeBitCast(d, + Set(Full64(), + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1}}))); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD /*t1*/) { + return Set(d, t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int8_t 
GccI8RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI8RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), + static_cast(t8), static_cast(t9), + static_cast(t10), static_cast(t11), + static_cast(t12), static_cast(t13), + static_cast(t14), static_cast(t15)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, + Dup128VecFromValues(dh, t8, t9, t10, t11, t12, t13, t14, t15, + t8, t9, t10, t11, t12, t13, t14, t15), + Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, + t2, t3, t4, t5, t6, t7)); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int16_t GccI16RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI16RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, Dup128VecFromValues(dh, t4, t5, t6, t7, t4, t5, t6, t7), + Dup128VecFromValues(dh, t0, t1, t2, t3, t0, t1, t2, t3)); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int32_t GccI32RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI32RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, Dup128VecFromValues(dh, t2, t3, t2, t3), + Dup128VecFromValues(dh, t0, t1, t0, t1)); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccF32RawVectType raw = {t0, t1, t2, t3}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, Dup128VecFromValues(dh, t2, t3, t2, t3), + Dup128VecFromValues(dh, t0, t1, t0, t1)); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int64_t GccI64RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI64RawVectType raw = {static_cast(t0), + static_cast(t1)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, Set(dh, t1), Set(dh, t0)); +#endif +} + +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccF64RawVectType raw = {t0, t1}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, Set(dh, t1), Set(dh, t0)); +#endif +} +#endif + +// Generic for all vector lengths +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + 
BitCastScalar(t6), BitCastScalar(t7))); +} + +#if (HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL) && HWY_NEON_HAVE_F16C +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/) { + typedef __fp16 GccF16RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccF16RawVectType raw = { + static_cast<__fp16>(t0), static_cast<__fp16>(t1), static_cast<__fp16>(t2), + static_cast<__fp16>(t3)}; + return VFromD(reinterpret_cast::Raw>(raw)); +} +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + typedef __fp16 GccF16RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccF16RawVectType raw = { + static_cast<__fp16>(t0), static_cast<__fp16>(t1), static_cast<__fp16>(t2), + static_cast<__fp16>(t3), static_cast<__fp16>(t4), static_cast<__fp16>(t5), + static_cast<__fp16>(t6), static_cast<__fp16>(t7)}; + return VFromD(reinterpret_cast::Raw>(raw)); +} +#else +// Generic for all vector lengths if MSVC or !HWY_NEON_HAVE_F16C +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} +#endif // (HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL) && HWY_NEON_HAVE_F16C + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues( + d, TFromD{0}, TFromD{1}, TFromD{2}, TFromD{3}, TFromD{4}, + TFromD{5}, TFromD{6}, TFromD{7}, TFromD{8}, TFromD{9}, + TFromD{10}, TFromD{11}, TFromD{12}, TFromD{13}, TFromD{14}, + TFromD{15}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues(d, TFromD{0}, TFromD{1}, TFromD{2}, + TFromD{3}, TFromD{4}, TFromD{5}, + TFromD{6}, TFromD{7}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + const RebindToUnsigned du; + return BitCast(d, Dup128VecFromValues(du, uint16_t{0}, uint16_t{0x3C00}, + uint16_t{0x4000}, uint16_t{0x4200}, + uint16_t{0x4400}, uint16_t{0x4500}, + uint16_t{0x4600}, uint16_t{0x4700})); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues(d, TFromD{0}, TFromD{1}, TFromD{2}, + TFromD{3}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues(d, TFromD{0}, TFromD{1}); +} + +#if HWY_COMPILER_MSVC +template +static HWY_INLINE V MaskOutIota(V v) { + constexpr size_t kVecSizeInBytes = HWY_MAX_LANES_V(V) * sizeof(TFromV); + constexpr uint64_t kU64MaskOutMask = + hwy::LimitsMax>(); + + const DFromV d; + const Repartition du8; + using VU8 = VFromD; + const auto mask_out_mask = + BitCast(d, VU8(vreinterpret_u8_u64(vdup_n_u64(kU64MaskOutMask)))); + return v & mask_out_mask; +} +template +static HWY_INLINE V MaskOutIota(V v) { + return v; +} +#endif + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + const auto result_iota = + detail::Iota0(d) + Set(d, static_cast>(first)); +#if HWY_COMPILER_MSVC + return detail::MaskOutIota(result_iota); +#else + return result_iota; +#endif +} + +// ------------------------------ Combine + +// Full result +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_u8(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 
lo) { + return Vec128(vcombine_u16(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_u32(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_u64(lo.raw, hi.raw)); +} + +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s8(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s16(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s32(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s64(lo.raw, hi.raw)); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Combine(D, Vec64 hi, Vec64 lo) { + return Vec128(vcombine_f16(lo.raw, hi.raw)); +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API VFromD Combine(D, Vec64 hi, Vec64 lo) { + return VFromD(vcombine_bf16(lo.raw, hi.raw)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +template , HWY_NEON_IF_EMULATED_D(D)> +HWY_API VFromD Combine(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, Combine(du, BitCast(duh, hi), BitCast(duh, lo))); +} + +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, Vec64 lo) { + return Vec128(vcombine_f32(lo.raw, hi.raw)); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_f64(lo.raw, hi.raw)); +} +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ BitCast + +namespace detail { + +// Converts from Vec128 to Vec128 using the +// vreinterpret*_u8_*() set of functions. +#define HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#define HWY_NEON_BUILD_RET_HWY_CAST_TO_U8(type, size) \ + Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128 v +#define HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 v.raw + +// Special case of u8 to u8 since vreinterpret*_u8_u8 is obviously not defined. 
+template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return v; +} + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(BitCastToByte, vreinterpret, _u8_, + HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_BFLOAT_16(BitCastToByte, vreinterpret, _u8_, + HWY_CAST_TO_U8) + +HWY_NEON_DEF_FUNCTION_INTS(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_16(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_32(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_64(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) + +#if !HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_F16C +HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(BitCastToByte, vreinterpret, _u8_, + HWY_CAST_TO_U8) +#else +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return BitCastToByte(Vec128(v.raw)); +} +#endif // HWY_NEON_HAVE_F16C +#endif // !HWY_HAVE_FLOAT16 + +#if !HWY_NEON_HAVE_BFLOAT16 +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return BitCastToByte(Vec128(v.raw)); +} +#endif // !HWY_NEON_HAVE_BFLOAT16 + +#undef HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_RET_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, VFromD v) { + return v; +} + +// 64-bit or less: + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_s8_u8(v.raw)); +} +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_u16_u8(v.raw)); +} +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_s16_u8(v.raw)); +} +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_u32_u8(v.raw)); +} +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_s32_u8(v.raw)); +} + +template +HWY_INLINE Vec64 BitCastFromByte(D /* tag */, Vec64 v) { + return Vec64(vreinterpret_u64_u8(v.raw)); +} +template +HWY_INLINE Vec64 BitCastFromByte(D /* tag */, Vec64 v) { + return Vec64(vreinterpret_s64_u8(v.raw)); +} + +// Cannot use HWY_NEON_IF_EMULATED_D due to the extra HWY_NEON_HAVE_F16C. 
+template +HWY_INLINE VFromD BitCastFromByte(D, VFromD> v) { +#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_F16C + return VFromD(vreinterpret_f16_u8(v.raw)); +#else + const RebindToUnsigned du; + return VFromD(BitCastFromByte(du, v).raw); +#endif +} + +template +HWY_INLINE VFromD BitCastFromByte(D, VFromD> v) { +#if HWY_NEON_HAVE_BFLOAT16 + return VFromD(vreinterpret_bf16_u8(v.raw)); +#else + const RebindToUnsigned du; + return VFromD(BitCastFromByte(du, v).raw); +#endif +} + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_f32_u8(v.raw)); +} + +#if HWY_HAVE_FLOAT64 +template +HWY_INLINE Vec64 BitCastFromByte(D /* tag */, Vec64 v) { + return Vec64(vreinterpret_f64_u8(v.raw)); +} +#endif // HWY_HAVE_FLOAT64 + +// 128-bit full: + +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_s8_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_u16_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_s16_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_u32_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_s32_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_u64_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_s64_u8(v.raw)); +} + +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_f32_u8(v.raw)); +} + +#if HWY_HAVE_FLOAT64 +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_f64_u8(v.raw)); +} +#endif // HWY_HAVE_FLOAT64 + +// Cannot use HWY_NEON_IF_EMULATED_D due to the extra HWY_NEON_HAVE_F16C. 
+template +HWY_INLINE VFromD BitCastFromByte(D, Vec128 v) { +#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_F16C + return VFromD(vreinterpretq_f16_u8(v.raw)); +#else + return VFromD(BitCastFromByte(RebindToUnsigned(), v).raw); +#endif +} + +template +HWY_INLINE VFromD BitCastFromByte(D, Vec128 v) { +#if HWY_NEON_HAVE_BFLOAT16 + return VFromD(vreinterpretq_bf16_u8(v.raw)); +#else + return VFromD(BitCastFromByte(RebindToUnsigned(), v).raw); +#endif +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, + Vec128().MaxLanes()> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ ResizeBitCast + +// <= 8 byte vector to <= 8 byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Repartition du8; + return BitCast(d, VFromD{detail::BitCastToByte(v).raw}); +} + +// 16-byte vector to 16-byte vector: same as BitCast +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, v); +} + +// 16-byte vector to <= 8-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const DFromV d_from; + const Half dh_from; + return ResizeBitCast(d, LowerHalf(dh_from, v)); +} + +// <= 8-bit vector to 16-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Full64> d_full64_from; + const Full128> d_full128_from; + return BitCast(d, Combine(d_full128_from, Zero(d_full64_from), + ResizeBitCast(d_full64_from, v))); +} + +// ------------------------------ GetLane + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_GET template +#define HWY_NEON_BUILD_RET_HWY_GET(type, size) type##_t +#define HWY_NEON_BUILD_PARAM_HWY_GET(type, size) Vec128 v +#define HWY_NEON_BUILD_ARG_HWY_GET v.raw, kLane + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(GetLane, vget, _lane_, HWY_GET) +HWY_NEON_DEF_FUNCTION_BFLOAT_16(GetLane, vget, _lane_, HWY_GET) + +template )> +static HWY_INLINE HWY_MAYBE_UNUSED TFromV GetLane(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCastScalar>(GetLane(BitCast(du, v))); +} + +#undef HWY_NEON_BUILD_TPL_HWY_GET +#undef HWY_NEON_BUILD_RET_HWY_GET +#undef HWY_NEON_BUILD_PARAM_HWY_GET +#undef HWY_NEON_BUILD_ARG_HWY_GET + +} // namespace detail + +template +HWY_API TFromV GetLane(const V v) { + return detail::GetLane<0>(v); +} + +// ------------------------------ ExtractLane + +// Requires one overload per vector length because GetLane<3> is a compile error +// if v is a uint32x2_t. 
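Putting the pieces defined so far together (Set, Iota, BitCast, GetLane and the ExtractLane overloads that follow), a short usage sketch; it assumes compilation inside a per-target translation unit that includes this header, and the Demo name is illustrative.

namespace hn = hwy::HWY_NAMESPACE;

uint32_t Demo() {
  const hn::Full128<uint32_t> d;                  // 4 x uint32_t
  auto v = hn::Iota(d, 1);                        // {1, 2, 3, 4}
  v = v + hn::Set(d, 10u);                        // {11, 12, 13, 14}
  const hn::Repartition<uint8_t, decltype(d)> d8;
  const auto bytes = hn::BitCast(d8, v);          // same bits, viewed as u8x16
  (void)bytes;
  // First lane via GetLane, arbitrary runtime index via ExtractLane:
  return hn::GetLane(v) + hn::ExtractLane(v, 2);  // 11 + 13
}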
+template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return detail::GetLane<0>(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + case 4: + return detail::GetLane<4>(v); + case 5: + return detail::GetLane<5>(v); + case 6: + return detail::GetLane<6>(v); + case 7: + return detail::GetLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + case 4: + return detail::GetLane<4>(v); + case 5: + return detail::GetLane<5>(v); + case 6: + return detail::GetLane<6>(v); + case 7: + return detail::GetLane<7>(v); + case 8: + return detail::GetLane<8>(v); + case 9: + return detail::GetLane<9>(v); + case 10: + return detail::GetLane<10>(v); + case 11: + return detail::GetLane<11>(v); + case 12: + return detail::GetLane<12>(v); + case 13: + return detail::GetLane<13>(v); + case 14: + return detail::GetLane<14>(v); + case 15: + return detail::GetLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_INSERT template +#define HWY_NEON_BUILD_RET_HWY_INSERT(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_INSERT(type, size) \ + Vec128 v, type##_t t +#define HWY_NEON_BUILD_ARG_HWY_INSERT t, v.raw, kLane + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(InsertLane, vset, _lane_, HWY_INSERT) +HWY_NEON_DEF_FUNCTION_BFLOAT_16(InsertLane, vset, _lane_, HWY_INSERT) + +#undef HWY_NEON_BUILD_TPL_HWY_INSERT +#undef HWY_NEON_BUILD_RET_HWY_INSERT +#undef HWY_NEON_BUILD_PARAM_HWY_INSERT +#undef HWY_NEON_BUILD_ARG_HWY_INSERT + +template , HWY_NEON_IF_EMULATED_D(D)> +HWY_API V InsertLane(const V v, TFromD t) { + const D d; + const RebindToUnsigned du; + const uint16_t tu = BitCastScalar(t); + return BitCast(d, InsertLane(BitCast(du, v), tu)); +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error. 
+ +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator+, vadd, _, 2) + +// ------------------------------ Subtraction +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator-, vsub, _, 2) + +// ------------------------------ SumsOf8 + +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v.raw)))); +} +HWY_API Vec64 SumsOf8(const Vec64 v) { + return Vec64(vpaddl_u32(vpaddl_u16(vpaddl_u8(v.raw)))); +} +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128(vpaddlq_s32(vpaddlq_s16(vpaddlq_s8(v.raw)))); +} +HWY_API Vec64 SumsOf8(const Vec64 v) { + return 
Vec64(vpaddl_s32(vpaddl_s16(vpaddl_s8(v.raw)))); +} + +// ------------------------------ SumsOf2 +namespace detail { + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_s8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_s8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_u8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_u8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_s16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_s16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_u16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_u16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_s32(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_s32(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_u32(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_u32(v.raw)); +} + +} // namespace detail + +// ------------------------------ SaturatedAdd + +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB +#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U64_SATURATED_ADDSUB +#endif + +// Returns a + b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INTS_UINTS(SaturatedAdd, vqadd, _, 2) + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INTS_UINTS(SaturatedSub, vqsub, _, 2) + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +HWY_NEON_DEF_FUNCTION_UI_8_16_32(AverageRound, vrhadd, _, 2) + +// ------------------------------ Neg + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Neg, vneg, _, 1) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Neg, vneg, _, 1) // i64 implemented below + +#if !HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Neg(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask()))); +} +#endif // !HWY_HAVE_FLOAT16 + +// There is no vneg for bf16, but we can cast to f16 (emulated or native). 
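// Illustration (scalar sketch, not Highway code): for bfloat16, as for any
// IEEE-style format, negation is just a sign-bit flip. That is why the bf16
// overload below can BitCast to f16, reuse the (native or emulated) f16 Neg,
// and cast back.
#include <stdint.h>
static inline uint16_t NegBF16Bits(uint16_t bits) {
  return (uint16_t)(bits ^ 0x8000u);  // flip bit 15, the bfloat16 sign bit
}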
+template +HWY_API Vec128 Neg(const Vec128 v) { + const DFromV d; + const Rebind df16; + return BitCast(d, Neg(BitCast(df16, v))); +} + +HWY_API Vec64 Neg(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vneg_s64(v.raw)); +#else + return Zero(DFromV()) - v; +#endif +} + +HWY_API Vec128 Neg(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vnegq_s64(v.raw)); +#else + return Zero(DFromV()) - v; +#endif +} + +// ------------------------------ SaturatedNeg +#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32 +#undef HWY_NATIVE_SATURATED_NEG_8_16_32 +#else +#define HWY_NATIVE_SATURATED_NEG_8_16_32 +#endif + +HWY_NEON_DEF_FUNCTION_INT_8_16_32(SaturatedNeg, vqneg, _, 1) + +#if HWY_ARCH_ARM_A64 +#ifdef HWY_NATIVE_SATURATED_NEG_64 +#undef HWY_NATIVE_SATURATED_NEG_64 +#else +#define HWY_NATIVE_SATURATED_NEG_64 +#endif + +HWY_API Vec64 SaturatedNeg(const Vec64 v) { + return Vec64(vqneg_s64(v.raw)); +} + +HWY_API Vec128 SaturatedNeg(const Vec128 v) { + return Vec128(vqnegq_s64(v.raw)); +} +#endif + +// ------------------------------ ShiftLeft + +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +// Customize HWY_NEON_DEF_FUNCTION to special-case count=0 (not supported). +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + template \ + HWY_API Vec128 name(const Vec128 v) { \ + return kBits == 0 ? v \ + : Vec128(HWY_NEON_EVAL( \ + prefix##infix##suffix, v.raw, HWY_MAX(1, kBits))); \ + } + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(ShiftLeft, vshl, _n_, ignored) + +HWY_NEON_DEF_FUNCTION_UINTS(ShiftRight, vshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_INTS(ShiftRight, vshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_UINTS(RoundingShiftRight, vrshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_INTS(RoundingShiftRight, vrshr, _n_, ignored) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// NOTE: vxarq_u64 can be applied to uint64_t, but we do not yet have a +// mechanism for checking for extensions to Armv8. 
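// Illustration (scalar sketch, not Highway code): the RotateRight above,
// composed from the two shifts, for 32-bit lanes. kBits == 0 is special-cased
// because a shift by the full lane width is undefined behaviour in C++.
#include <stdint.h>
template <int kBits>
static inline uint32_t RotateRight32(uint32_t x) {
  static_assert(0 <= kBits && kBits < 32, "Invalid shift count");
  if (kBits == 0) return x;
  return (x >> kBits) | (x << ((32 - kBits) & 31));  // & 31: no shift-by-32
}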
+ +// ------------------------------ Shl + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_u8(v.raw, vreinterpretq_s8_u8(bits.raw))); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_u8(v.raw, vreinterpret_s8_u8(bits.raw))); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_u16(v.raw, vreinterpretq_s16_u16(bits.raw))); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_u16(v.raw, vreinterpret_s16_u16(bits.raw))); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_u32(v.raw, vreinterpretq_s32_u32(bits.raw))); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_u32(v.raw, vreinterpret_s32_u32(bits.raw))); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_u64(v.raw, vreinterpretq_s64_u64(bits.raw))); +} +HWY_API Vec64 operator<<(Vec64 v, Vec64 bits) { + return Vec64(vshl_u64(v.raw, vreinterpret_s64_u64(bits.raw))); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s8(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s8(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s16(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s16(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s32(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s32(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s64(v.raw, bits.raw)); +} +HWY_API Vec64 operator<<(Vec64 v, Vec64 bits) { + return Vec64(vshl_s64(v.raw, bits.raw)); +} + +// ------------------------------ Shr (Neg) + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + const RebindToSigned> di; + const int8x16_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshlq_u8(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int8x8_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshl_u8(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + const RebindToSigned> di; + const int16x8_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshlq_u16(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int16x4_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshl_u16(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + const RebindToSigned> di; + const int32x4_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshlq_u32(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int32x2_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshl_u32(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + const RebindToSigned> di; + const int64x2_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshlq_u64(v.raw, neg_bits)); +} +HWY_API Vec64 operator>>(Vec64 v, Vec64 bits) { + const RebindToSigned> di; + const int64x1_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec64(vshl_u64(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s8(v.raw, Neg(bits).raw)); +} 
+template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s8(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s16(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s16(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s32(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s32(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s64(v.raw, Neg(bits).raw)); +} +HWY_API Vec64 operator>>(Vec64 v, Vec64 bits) { + return Vec64(vshl_s64(v.raw, Neg(bits).raw)); +} + +// ------------------------------ RoundingShr (Neg) + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + const RebindToSigned> di; + const int8x16_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u8(v.raw, neg_bits)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int8x8_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshl_u8(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int16x8_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u16(v.raw, neg_bits)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int16x4_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshl_u16(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int32x4_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u32(v.raw, neg_bits)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int32x2_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshl_u32(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int64x2_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u64(v.raw, neg_bits)); +} +HWY_API Vec64 RoundingShr(Vec64 v, Vec64 bits) { + const RebindToSigned> di; + const int64x1_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec64(vrshl_u64(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s8(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128(vrshl_s8(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s16(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128(vrshl_s16(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s32(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128(vrshl_s32(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s64(v.raw, Neg(bits).raw)); +} +HWY_API Vec64 RoundingShr(Vec64 v, Vec64 bits) { + return Vec64(vrshl_s64(v.raw, Neg(bits).raw)); +} + +// ------------------------------ ShiftLeftSame (Shl) + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, int bits) { + return v << Set(DFromV(), static_cast(bits)); +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, int bits) { + return v >> Set(DFromV(), static_cast(bits)); +} + +// 
------------------------------ RoundingShiftRightSame (RoundingShr) + +template +HWY_API Vec128 RoundingShiftRightSame(const Vec128 v, int bits) { + return RoundingShr(v, Set(DFromV(), static_cast(bits))); +} + +// ------------------------------ Int/float multiplication + +// Per-target flag to prevent generic_ops-inl.h from defining 8-bit operator*. +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif + +// All except ui64 +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator*, vmul, _, 2) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator*, vmul, _, 2) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator*, vmul, _, 2) + +// ------------------------------ Integer multiplication + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int16x8_t rlo = vmull_s8(vget_low_s8(a.raw), vget_low_s8(b.raw)); +#if HWY_ARCH_ARM_A64 + int16x8_t rhi = vmull_high_s8(a.raw, b.raw); +#else + int16x8_t rhi = vmull_s8(vget_high_s8(a.raw), vget_high_s8(b.raw)); +#endif + return Vec128( + vuzp2q_s8(vreinterpretq_s8_s16(rlo), vreinterpretq_s8_s16(rhi))); +} +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint16x8_t rlo = vmull_u8(vget_low_u8(a.raw), vget_low_u8(b.raw)); +#if HWY_ARCH_ARM_A64 + uint16x8_t rhi = vmull_high_u8(a.raw, b.raw); +#else + uint16x8_t rhi = vmull_u8(vget_high_u8(a.raw), vget_high_u8(b.raw)); +#endif + return Vec128( + vuzp2q_u8(vreinterpretq_u8_u16(rlo), vreinterpretq_u8_u16(rhi))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int8x16_t hi_lo = vreinterpretq_s8_s16(vmull_s8(a.raw, b.raw)); + return Vec128(vget_low_s8(vuzp2q_s8(hi_lo, hi_lo))); +} +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint8x16_t hi_lo = vreinterpretq_u8_u16(vmull_u8(a.raw, b.raw)); + return Vec128(vget_low_u8(vuzp2q_u8(hi_lo, hi_lo))); +} + +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int32x4_t rlo = vmull_s16(vget_low_s16(a.raw), vget_low_s16(b.raw)); +#if HWY_ARCH_ARM_A64 + int32x4_t rhi = vmull_high_s16(a.raw, b.raw); +#else + int32x4_t rhi = vmull_s16(vget_high_s16(a.raw), vget_high_s16(b.raw)); +#endif + return Vec128( + vuzp2q_s16(vreinterpretq_s16_s32(rlo), vreinterpretq_s16_s32(rhi))); +} +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint32x4_t rlo = vmull_u16(vget_low_u16(a.raw), vget_low_u16(b.raw)); +#if HWY_ARCH_ARM_A64 + uint32x4_t rhi = vmull_high_u16(a.raw, b.raw); +#else + uint32x4_t rhi = vmull_u16(vget_high_u16(a.raw), vget_high_u16(b.raw)); +#endif + return Vec128( + vuzp2q_u16(vreinterpretq_u16_u32(rlo), vreinterpretq_u16_u32(rhi))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int16x8_t hi_lo = vreinterpretq_s16_s32(vmull_s16(a.raw, b.raw)); + return Vec128(vget_low_s16(vuzp2q_s16(hi_lo, hi_lo))); +} +template +HWY_API Vec128 MulHigh(Vec128 a, + Vec128 b) { + uint16x8_t hi_lo = vreinterpretq_u16_u32(vmull_u16(a.raw, b.raw)); + return Vec128(vget_low_u16(vuzp2q_u16(hi_lo, hi_lo))); +} + +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int64x2_t rlo = vmull_s32(vget_low_s32(a.raw), vget_low_s32(b.raw)); +#if HWY_ARCH_ARM_A64 + int64x2_t rhi = vmull_high_s32(a.raw, b.raw); +#else + int64x2_t rhi = vmull_s32(vget_high_s32(a.raw), vget_high_s32(b.raw)); +#endif + return Vec128( + vuzp2q_s32(vreinterpretq_s32_s64(rlo), vreinterpretq_s32_s64(rhi))); +} +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint64x2_t rlo = vmull_u32(vget_low_u32(a.raw), vget_low_u32(b.raw)); +#if HWY_ARCH_ARM_A64 + uint64x2_t rhi = vmull_high_u32(a.raw, b.raw); +#else + uint64x2_t rhi = 
vmull_u32(vget_high_u32(a.raw), vget_high_u32(b.raw)); +#endif + return Vec128( + vuzp2q_u32(vreinterpretq_u32_u64(rlo), vreinterpretq_u32_u64(rhi))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int32x4_t hi_lo = vreinterpretq_s32_s64(vmull_s32(a.raw, b.raw)); + return Vec128(vget_low_s32(vuzp2q_s32(hi_lo, hi_lo))); +} +template +HWY_API Vec128 MulHigh(Vec128 a, + Vec128 b) { + uint32x4_t hi_lo = vreinterpretq_u32_u64(vmull_u32(a.raw, b.raw)); + return Vec128(vget_low_u32(vuzp2q_u32(hi_lo, hi_lo))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi_0; + T hi_1; + + Mul128(GetLane(a), GetLane(b), &hi_0); + Mul128(detail::GetLane<1>(a), detail::GetLane<1>(b), &hi_1); + + return Dup128VecFromValues(Full128(), hi_0, hi_1); +} + +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Set(Full64(), hi); +} + +HWY_API Vec128 MulFixedPoint15(Vec128 a, Vec128 b) { + return Vec128(vqrdmulhq_s16(a.raw, b.raw)); +} +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + return Vec128(vqrdmulh_s16(a.raw, b.raw)); +} + +// ------------------------------ Floating-point division + +// Emulate missing intrinsic +#if HWY_HAVE_FLOAT64 && HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +HWY_INLINE float64x1_t vrecpe_f64(float64x1_t raw) { + const CappedTag d; + const Twice dt; + using VT = VFromD; + return LowerHalf(d, VT(vrecpeq_f64(Combine(dt, v, v).raw))).raw; +} +#endif + +// Approximate reciprocal +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(ApproximateReciprocal, vrecpe, _, 1) + +#if HWY_HAVE_FLOAT64 +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator/, vdiv, _, 2) +#else // !HWY_HAVE_FLOAT64 +namespace detail { +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(ReciprocalNewtonRaphsonStep, vrecps, _, 2) +} // namespace detail + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + auto x = ApproximateReciprocal(b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + return a * x; +} +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ Absolute value of difference. + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(AbsDiff, vabd, _, 2) +HWY_NEON_DEF_FUNCTION_UI_8_16_32(AbsDiff, vabd, _, 2) // no UI64 + +#ifdef HWY_NATIVE_INTEGER_ABS_DIFF +#undef HWY_NATIVE_INTEGER_ABS_DIFF +#else +#define HWY_NATIVE_INTEGER_ABS_DIFF +#endif + +// ------------------------------ Integer multiply-add + +// Per-target flag to prevent generic_ops-inl.h from defining int MulAdd. +#ifdef HWY_NATIVE_INT_FMA +#undef HWY_NATIVE_INT_FMA +#else +#define HWY_NATIVE_INT_FMA +#endif + +// Wrappers for changing argument order to what intrinsics expect. 
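// Illustration (scalar sketch, not Highway code): NEON's multiply-accumulate
// takes the accumulator first, i.e. vmla(add, mul, x) == add + mul * x,
// whereas Highway's MulAdd(mul, x, add) == mul * x + add, so the wrappers
// below only reorder the arguments.
static inline float VmlaScalar(float add, float mul, float x) {
  return add + mul * x;  // models the vmla/vfma argument order
}
static inline float MulAddScalar(float mul, float x, float add) {
  return VmlaScalar(add, mul, x);  // models the public MulAdd signature
}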
+namespace detail { +// All except ui64 +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(MulAdd, vmla, _, 3) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(MulAdd, vmla, _, 3) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(NegMulAdd, vmls, _, 3) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(NegMulAdd, vmls, _, 3) +} // namespace detail + +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return detail::MulAdd(add, mul, x); +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return detail::NegMulAdd(add, mul, x); +} + +// 64-bit integer +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Add(Mul(mul, x), add); +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Sub(add, Mul(mul, x)); +} + +// ------------------------------ Floating-point multiply-add variants + +namespace detail { + +#if HWY_NATIVE_FMA +// Wrappers for changing argument order to what intrinsics expect. +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(MulAdd, vfma, _, 3) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(NegMulAdd, vfms, _, 3) +#else +// Emulate. Matches intrinsics arg order. +template +HWY_API Vec128 MulAdd(Vec128 add, Vec128 mul, + Vec128 x) { + return mul * x + add; +} + +template +HWY_API Vec128 NegMulAdd(Vec128 add, Vec128 mul, + Vec128 x) { + return add - mul * x; +} + +#endif // HWY_NATIVE_FMA +} // namespace detail + +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return detail::MulAdd(add, mul, x); +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return detail::NegMulAdd(add, mul, x); +} + +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return MulAdd(mul, x, Neg(sub)); +} + +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Neg(MulAdd(mul, x, sub)); +} + +// ------------------------------ Floating-point square root (IfThenZeroElse) + +// Emulate missing intrinsic +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 490 +HWY_INLINE float64x1_t vrsqrte_f64(float64x1_t raw) { + const CappedTag d; + const Twice dt; + using VT = VFromD; + const VFromD v(raw); + return LowerHalf(d, VT(vrsqrteq_f64(Combine(dt, v, v).raw))).raw; +} +#endif + +// Approximate reciprocal square root +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(ApproximateReciprocalSqrt, vrsqrte, _, 1) + +#if HWY_HAVE_FLOAT64 +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +// Full precision square root +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Sqrt, vsqrt, _, 1) +#else // !HWY_HAVE_FLOAT64 +namespace detail { +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(ReciprocalSqrtStep, vrsqrts, _, 2) +} // namespace detail + +template +HWY_API Vec128 Sqrt(const Vec128 v) { + auto recip = ApproximateReciprocalSqrt(v); + + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + + const auto root = v * recip; + return IfThenZeroElse(v == Zero(Simd()), root); +} +#endif // HWY_HAVE_FLOAT64 + +// ================================================== LOGICAL + +// ------------------------------ Not + +// There is no 64-bit vmvn, so cast instead of using HWY_NEON_DEF_FUNCTION. 
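// Illustration (sketch assuming <arm_neon.h>, not Highway code): there is no
// vmvnq_u64, but bitwise NOT does not care about lane size, so the 64-bit Not
// below reinterprets to bytes, inverts them, and reinterprets back.
#include <arm_neon.h>
static inline uint64x2_t NotU64(uint64x2_t v) {
  return vreinterpretq_u64_u8(vmvnq_u8(vreinterpretq_u8_u64(v)));
}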
+template +HWY_API Vec128 Not(const Vec128 v) { + const DFromV d; + const Repartition d8; + return BitCast(d, Vec128(vmvnq_u8(BitCast(d8, v).raw))); +} +template +HWY_API Vec128 Not(const Vec128 v) { + const DFromV d; + const Repartition d8; + using V8 = decltype(Zero(d8)); + return BitCast(d, V8(vmvn_u8(BitCast(d8, v).raw))); +} + +// ------------------------------ And +HWY_NEON_DEF_FUNCTION_INTS_UINTS(And, vand, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 And(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) & BitCast(du, b)); +} + +// ------------------------------ AndNot + +namespace detail { +// reversed_andnot returns a & ~b. +HWY_NEON_DEF_FUNCTION_INTS_UINTS(reversed_andnot, vbic, _, 2) +} // namespace detail + +// Returns ~not_mask & mask. +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return detail::reversed_andnot(mask, not_mask); +} + +// Uses the u32/64 defined above. +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + const DFromV d; + const RebindToUnsigned du; + VFromD ret = + detail::reversed_andnot(BitCast(du, mask), BitCast(du, not_mask)); + return BitCast(d, ret); +} + +// ------------------------------ Or + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Or, vorr, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 Or(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) | BitCast(du, b)); +} + +// ------------------------------ Xor + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Xor, veor, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 Xor(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) ^ BitCast(du, b)); +} + +// ------------------------------ Xor3 +#if HWY_ARCH_ARM_A64 && defined(__ARM_FEATURE_SHA3) +HWY_NEON_DEF_FUNCTION_FULL_UI(Xor3, veor3, _, 3) + +// Half vectors are not natively supported. Two Xor are likely more efficient +// than Combine to 128-bit. 
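// Illustration (sketch, not Highway code): veor3q_* (SHA-3 extension) XORs
// three full 128-bit vectors in one instruction but has no 64-bit (non-q)
// form, so the half-vector overload below chains two ordinary XORs instead of
// widening to 128 bits and back.
#include <arm_neon.h>
static inline uint32x2_t Xor3Half(uint32x2_t a, uint32x2_t b, uint32x2_t c) {
  return veor_u32(a, veor_u32(b, c));  // two veor instead of Combine + veor3q
}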
+template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { + return Xor(x1, Xor(x2, x3)); +} + +template +HWY_API Vec128 Xor3(const Vec128 x1, const Vec128 x2, + const Vec128 x3) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3))); +} + +#else +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { + return Xor(x1, Xor(x2, x3)); +} +#endif + +// ------------------------------ Or3 +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ BitwiseIfThenElse + +#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return IfVecThenElse(mask, yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ I64/U64 AbsDiff + +template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Max(a, b) - Min(a, b); +} + +template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Or(SaturatedSub(a, b), SaturatedSub(b, a)); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, Vec128 v) { + const Full128 d8; + return Vec128(vcntq_u8(BitCast(d8, v).raw)); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128 v) { + const Simd d8; + return Vec128(vcnt_u8(BitCast(d8, v).raw)); +} + +// NEON lacks popcount for lane sizes > 1, so take pairwise sums of the bytes. 
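// Illustration (scalar model, not Highway code): vcnt gives per-byte
// popcounts and vpaddl sums adjacent pairs while widening (u8 -> u16 -> u32
// -> u64), which is how the size-2/4/8 overloads below accumulate one count
// per lane.
#include <stdint.h>
static inline uint64_t PopCount64ViaBytes(uint64_t x) {
  uint64_t total = 0;
  for (int i = 0; i < 8; ++i) {
    uint8_t byte = (uint8_t)(x >> (8 * i));
    uint32_t count = 0;  // per-byte popcount (what vcnt computes per byte)
    while (byte) {
      count += byte & 1u;
      byte = (uint8_t)(byte >> 1);
    }
    total += count;  // the vpaddl chain accumulates these into the u64 lane
  }
  return total;
}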
+template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u8(bytes)); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u8(bytes)); +} + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u16(vpaddlq_u8(bytes))); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u16(vpaddl_u8(bytes))); +} + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(bytes)))); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u32(vpaddl_u16(vpaddl_u8(bytes)))); +} + +} // namespace detail + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +// ================================================== SIGN + +// ------------------------------ Abs +// i64 is implemented after BroadcastSignBit. +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Abs, vabs, _, 1) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Abs, vabs, _, 1) + +// ------------------------------ SaturatedAbs +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +HWY_NEON_DEF_FUNCTION_INT_8_16_32(SaturatedAbs, vqabs, _, 1) + +// ------------------------------ CopySign +template +HWY_API Vec128 CopySign(Vec128 magn, Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API Vec128 CopySignToAbs(Vec128 abs, Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} + +// ================================================== MASK + +// ------------------------------ To/from vector + +// Mask and Vec have the same representation (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const Simd, N, 0> du; + return Mask128(BitCast(du, v).raw); +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API VFromD VecFromMask(D d, const MFromD m) { + // Raw type of masks is unsigned. 
+ const RebindToUnsigned du; + return BitCast(d, VFromD(m.raw)); +} + +// ------------------------------ RebindMask (MaskFromVec) + +template +HWY_API MFromD RebindMask(DTo /* tag */, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD(m.raw); +} + +// ------------------------------ IfThenElse + +#define HWY_NEON_BUILD_TPL_HWY_IF +#define HWY_NEON_BUILD_RET_HWY_IF(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_IF(type, size) \ + const Mask128 mask, const Vec128 yes, \ + const Vec128 no +#define HWY_NEON_BUILD_ARG_HWY_IF mask.raw, yes.raw, no.raw + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF) + +#if HWY_HAVE_FLOAT16 +#define HWY_NEON_IF_EMULATED_IF_THEN_ELSE(V) HWY_IF_BF16(TFromV) +#else +#define HWY_NEON_IF_EMULATED_IF_THEN_ELSE(V) HWY_IF_SPECIAL_FLOAT_V(V) +#endif + +template +HWY_API V IfThenElse(MFromD> mask, V yes, V no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +#undef HWY_NEON_IF_EMULATED_IF_THEN_ELSE +#undef HWY_NEON_BUILD_TPL_HWY_IF +#undef HWY_NEON_BUILD_RET_HWY_IF +#undef HWY_NEON_BUILD_PARAM_HWY_IF +#undef HWY_NEON_BUILD_ARG_HWY_IF + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, IfThenElseZero(RebindMask(du, mask), BitCast(du, yes))); +} + +// mask ? 0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, IfThenZeroElse(RebindMask(du, mask), BitCast(du, no))); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + Mask128 m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(DFromM(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
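// Illustration (scalar model, not Highway code): because a true lane is all
// ones and a false lane is all zeros, masks share the vector representation
// and vbsl's bitwise select implements IfThenElse exactly. Per-lane model:
#include <stdint.h>
static inline uint32_t BitSelect(uint32_t mask, uint32_t yes, uint32_t no) {
  return (mask & yes) | (~mask & no);  // what vbslq_* computes per bit
}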
+ +// ------------------------------ Shuffle2301 (for i64 compares) + +// Swap 32-bit halves in 64-bits +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_u32(v.raw)); +} +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_s32(v.raw)); +} +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_f32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_u32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_s32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_f32(v.raw)); +} + +#define HWY_NEON_BUILD_TPL_HWY_COMPARE +#define HWY_NEON_BUILD_RET_HWY_COMPARE(type, size) Mask128 +#define HWY_NEON_BUILD_PARAM_HWY_COMPARE(type, size) \ + const Vec128 a, const Vec128 b +#define HWY_NEON_BUILD_ARG_HWY_COMPARE a.raw, b.raw + +// ------------------------------ Equality +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator==, vceq, _, HWY_COMPARE) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator==, vceq, _, HWY_COMPARE) +#else +// No 64-bit comparisons on armv7: emulate them below, after Shuffle2301. +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator==, vceq, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator==, vceq, _, HWY_COMPARE) +#endif + +// ------------------------------ Strict inequality (signed, float) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator<, vclt, _, HWY_COMPARE) +#else +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator<, vclt, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator<, vclt, _, HWY_COMPARE) +#endif +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<, vclt, _, HWY_COMPARE) + +// ------------------------------ Weak inequality (float) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator<=, vcle, _, HWY_COMPARE) +#else +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator<=, vcle, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator<=, vcle, _, HWY_COMPARE) +#endif +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<=, vcle, _, HWY_COMPARE) + +#undef HWY_NEON_BUILD_TPL_HWY_COMPARE +#undef HWY_NEON_BUILD_RET_HWY_COMPARE +#undef HWY_NEON_BUILD_PARAM_HWY_COMPARE +#undef HWY_NEON_BUILD_ARG_HWY_COMPARE + +// ------------------------------ Armv7 i64 compare (Shuffle2301, Eq) + +#if HWY_ARCH_ARM_V7 + +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + const int64x2_t sub = vqsubq_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec128(sub))); +} +HWY_API Mask128 operator<(const Vec64 a, + const Vec64 b) { + const int64x1_t sub = vqsub_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec64(sub))); +} + +template +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = AndNot(a, b) | AndNot(a ^ b, a - b); + return MaskFromVec(BitCast(du, BroadcastSignBit(BitCast(di, msb)))); +} + +template +HWY_API Mask128 operator<=(const Vec128 a, + const Vec128 b) { + return Not(b < a); +} + 
+template +HWY_API Mask128 operator<=(const Vec128 a, + const Vec128 b) { + return Not(b < a); +} + +#endif + +// ------------------------------ operator!= (operator==) + +// Customize HWY_NEON_DEF_FUNCTION to call 2 functions. +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +// This cannot have _any_ template argument (in x86_128 we can at least have N +// as an argument), otherwise it is not more specialized than rewritten +// operator== in C++20, leading to compile errors. +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_API Mask128 name(Vec128 a, \ + Vec128 b) { \ + return Not(a == b); \ + } + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator!=, ignored, ignored, ignored) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return operator<(b, a); +} +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return operator<=(b, a); +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API MFromD FirstN(D d, size_t num) { + const RebindToSigned di; // Signed comparisons are cheaper. + using TI = TFromD; + return RebindMask(d, detail::Iota0(di) < Set(di, static_cast(num))); +} + +// ------------------------------ TestBit (Eq) + +#define HWY_NEON_BUILD_TPL_HWY_TESTBIT +#define HWY_NEON_BUILD_RET_HWY_TESTBIT(type, size) Mask128 +#define HWY_NEON_BUILD_PARAM_HWY_TESTBIT(type, size) \ + Vec128 v, Vec128 bit +#define HWY_NEON_BUILD_ARG_HWY_TESTBIT v.raw, bit.raw + +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(TestBit, vtst, _, HWY_TESTBIT) +#else +// No 64-bit versions on armv7 +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) + +template +HWY_API Mask128 TestBit(Vec128 v, + Vec128 bit) { + return (v & bit) == bit; +} +template +HWY_API Mask128 TestBit(Vec128 v, + Vec128 bit) { + return (v & bit) == bit; +} + +#endif +#undef HWY_NEON_BUILD_TPL_HWY_TESTBIT +#undef HWY_NEON_BUILD_RET_HWY_TESTBIT +#undef HWY_NEON_BUILD_PARAM_HWY_TESTBIT +#undef HWY_NEON_BUILD_ARG_HWY_TESTBIT + +// ------------------------------ Abs i64 (IfThenElse, BroadcastSignBit) +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vabsq_s64(v.raw)); +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} +HWY_API Vec64 Abs(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vabs_s64(v.raw)); +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +HWY_API Vec128 SaturatedAbs(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vqabsq_s64(v.raw)); +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), SaturatedSub(zero, v), v); +#endif +} +HWY_API Vec64 SaturatedAbs(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vqabs_s64(v.raw)); +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), SaturatedSub(zero, v), v); +#endif +} + +// ------------------------------ Min (IfThenElse, BroadcastSignBit) + +// Unsigned +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Min, vmin, _, 2) + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, b, a); +#else + const DFromV du; + const RebindToSigned di; + return BitCast(du, BitCast(di, a) - BitCast(di, SaturatedSub(a, b))); 
+#endif +} + +// Signed +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Min, vmin, _, 2) + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, b, a); +#else + const Vec128 sign = SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), a, b); +#endif +} + +// Float: IEEE minimumNumber on v8 +#if HWY_ARCH_ARM_A64 + +HWY_NEON_DEF_FUNCTION_FLOAT_16_32(Min, vminnm, _, 2) + +// GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic, so define +// in terms of the 128-bit intrinsic. +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +namespace detail { + +template +HWY_INLINE V F64Vec64Min(V a, V b) { + const DFromV d; + const Twice dt; + return LowerHalf(d, Min(ZeroExtendVector(dt, a), ZeroExtendVector(dt, b))); +} + +} // namespace detail +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + +HWY_API Vec64 Min(Vec64 a, Vec64 b) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + return detail::F64Vec64Min(a, b); +#else + return Vec64(vminnm_f64(a.raw, b.raw)); +#endif +} + +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128(vminnmq_f64(a.raw, b.raw)); +} + +#else +// Armv7: NaN if any is NaN. +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vmin, _, 2) +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ Max (IfThenElse, BroadcastSignBit) + +// Unsigned (no u64) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Max, vmax, _, 2) + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, a, b); +#else + const DFromV du; + const RebindToSigned di; + return BitCast(du, BitCast(di, b) + BitCast(di, SaturatedSub(a, b))); +#endif +} + +// Signed (no i64) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Max, vmax, _, 2) + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, a, b); +#else + const Vec128 sign = SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), b, a); +#endif +} + +// Float: IEEE minimumNumber on v8 +#if HWY_ARCH_ARM_A64 + +HWY_NEON_DEF_FUNCTION_FLOAT_16_32(Max, vmaxnm, _, 2) + +// GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic, so define +// in terms of the 128-bit intrinsic. +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +namespace detail { + +template +HWY_INLINE V F64Vec64Max(V a, V b) { + const DFromV d; + const Twice dt; + return LowerHalf(d, Max(ZeroExtendVector(dt, a), ZeroExtendVector(dt, b))); +} + +} // namespace detail +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + +HWY_API Vec64 Max(Vec64 a, Vec64 b) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + return detail::F64Vec64Max(a, b); +#else + return Vec64(vmaxnm_f64(a.raw, b.raw)); +#endif +} + +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128(vmaxnmq_f64(a.raw, b.raw)); +} + +#else +// Armv7: NaN if any is NaN. 
+HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmax, _, 2) +#endif // HWY_ARCH_ARM_A64 + +// ================================================== MEMORY + +// ------------------------------ Load 128 + +template +HWY_API Vec128 LoadU(D /* tag */, + const uint8_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u8(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const uint16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u16(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const uint32_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u32(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const uint64_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u64(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const int8_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s8(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const int16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s16(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const int32_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s32(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const int64_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s64(unaligned)); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 LoadU(D /* tag */, + const float16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f16(detail::NativeLanePointer(unaligned))); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 LoadU(D /* tag */, + const bfloat16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_bf16(detail::NativeLanePointer(unaligned))); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 LoadU(D /* tag */, const float* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f32(unaligned)); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API Vec128 LoadU(D /* tag */, + const double* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f64(unaligned)); +} +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ Load 64 + +template +HWY_API Vec64 LoadU(D /* tag */, const uint8_t* HWY_RESTRICT p) { + return Vec64(vld1_u8(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const uint16_t* HWY_RESTRICT p) { + return Vec64(vld1_u16(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const uint32_t* HWY_RESTRICT p) { + return Vec64(vld1_u32(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const uint64_t* HWY_RESTRICT p) { + return Vec64(vld1_u64(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const int8_t* HWY_RESTRICT p) { + return Vec64(vld1_s8(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const int16_t* HWY_RESTRICT p) { + return Vec64(vld1_s16(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const int32_t* HWY_RESTRICT p) { + return Vec64(vld1_s32(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const int64_t* HWY_RESTRICT p) { + return Vec64(vld1_s64(p)); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec64 LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { + return Vec64(vld1_f16(detail::NativeLanePointer(p))); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec64 LoadU(D /* tag */, const bfloat16_t* HWY_RESTRICT p) { + return Vec64(vld1_bf16(detail::NativeLanePointer(p))); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec64 LoadU(D /* tag */, const float* HWY_RESTRICT p) { + return Vec64(vld1_f32(p)); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API Vec64 LoadU(D /* tag */, const double* HWY_RESTRICT p) { + return Vec64(vld1_f64(p)); +} +#endif // HWY_HAVE_FLOAT64 + +// 
------------------------------ Load 32 + +// Actual 32-bit broadcast load - used to implement the other lane types +// because reinterpret_cast of the pointer leads to incorrect codegen on GCC. +template +HWY_API Vec32 LoadU(D /*tag*/, const uint32_t* HWY_RESTRICT p) { + return Vec32(vld1_dup_u32(p)); +} +template +HWY_API Vec32 LoadU(D /*tag*/, const int32_t* HWY_RESTRICT p) { + return Vec32(vld1_dup_s32(p)); +} +template +HWY_API Vec32 LoadU(D /*tag*/, const float* HWY_RESTRICT p) { + return Vec32(vld1_dup_f32(p)); +} + +// {u,i}{8,16} +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const Repartition d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const Repartition d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const Repartition d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +// ------------------------------ Load 16 + +// Actual 16-bit broadcast load - used to implement the other lane types +// because reinterpret_cast of the pointer leads to incorrect codegen on GCC. +template +HWY_API VFromD LoadU(D /* tag */, const uint16_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_u16(p)); +} +template +HWY_API VFromD LoadU(D /* tag */, const int16_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_s16(p)); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_f16(detail::NativeLanePointer(p))); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API VFromD LoadU(D /* tag */, const bfloat16_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_bf16(detail::NativeLanePointer(p))); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +// 8-bit x2 +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const Repartition d16; + uint16_t buf; + CopyBytes<2>(p, &buf); + return BitCast(d, LoadU(d16, &buf)); +} + +// ------------------------------ Load 8 +template +HWY_API VFromD LoadU(D /* tag */, const uint8_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_u8(p)); +} +template +HWY_API VFromD LoadU(D /* tag */, const int8_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_s8(p)); +} + +// ------------------------------ Load misc + +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadU(du, detail::U16LanePointer(p))); +} + +// On Arm, Load is the same as LoadU. +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT aligned) { + return IfThenElse(m, Load(d, aligned), v); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. 
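// Illustration (sketch assuming <arm_neon.h>, not Highway code): the partial
// Load 32/16/8 helpers above first memcpy the bytes into a correctly typed
// scalar (CopyBytes) instead of reinterpret_casting the pointer, then use a
// dup-load of that scalar. Sketch for a 4-byte load; LoadFourBytes is a
// hypothetical helper.
#include <arm_neon.h>
#include <stdint.h>
#include <string.h>
static inline uint8x8_t LoadFourBytes(const uint8_t* p) {
  uint32_t buf;
  memcpy(&buf, p, 4);  // safe unaligned, type-correct read
  return vreinterpret_u8_u32(vld1_dup_u32(&buf));  // data in the low 4 bytes
}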
+template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// ------------------------------ Store 128 + +template +HWY_API void StoreU(Vec128 v, D /* tag */, + uint8_t* HWY_RESTRICT unaligned) { + vst1q_u8(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + uint16_t* HWY_RESTRICT unaligned) { + vst1q_u16(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + uint32_t* HWY_RESTRICT unaligned) { + vst1q_u32(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + uint64_t* HWY_RESTRICT unaligned) { + vst1q_u64(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + int8_t* HWY_RESTRICT unaligned) { + vst1q_s8(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + int16_t* HWY_RESTRICT unaligned) { + vst1q_s16(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + int32_t* HWY_RESTRICT unaligned) { + vst1q_s32(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + int64_t* HWY_RESTRICT unaligned) { + vst1q_s64(unaligned, v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec128 v, D /* tag */, + float16_t* HWY_RESTRICT unaligned) { + vst1q_f16(detail::NativeLanePointer(unaligned), v.raw); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec128 v, D /* tag */, + bfloat16_t* HWY_RESTRICT unaligned) { + vst1q_bf16(detail::NativeLanePointer(unaligned), v.raw); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec128 v, D /* tag */, + float* HWY_RESTRICT unaligned) { + vst1q_f32(unaligned, v.raw); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API void StoreU(Vec128 v, D /* tag */, + double* HWY_RESTRICT unaligned) { + vst1q_f64(unaligned, v.raw); +} +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ Store 64 + +template +HWY_API void StoreU(Vec64 v, D /* tag */, uint8_t* HWY_RESTRICT p) { + vst1_u8(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, uint16_t* HWY_RESTRICT p) { + vst1_u16(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, uint32_t* HWY_RESTRICT p) { + vst1_u32(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, uint64_t* HWY_RESTRICT p) { + vst1_u64(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, int8_t* HWY_RESTRICT p) { + vst1_s8(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, int16_t* HWY_RESTRICT p) { + vst1_s16(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, int32_t* HWY_RESTRICT p) { + vst1_s32(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, int64_t* HWY_RESTRICT p) { + vst1_s64(p, v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec64 v, D /* tag */, + float16_t* HWY_RESTRICT p) { + vst1_f16(detail::NativeLanePointer(p), v.raw); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec64 v, D /* tag */, + bfloat16_t* HWY_RESTRICT p) { + vst1_bf16(detail::NativeLanePointer(p), v.raw); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec64 v, D /* tag */, float* HWY_RESTRICT p) { + vst1_f32(p, v.raw); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API void StoreU(Vec64 v, D /* tag */, double* HWY_RESTRICT p) { + vst1_f64(p, v.raw); +} +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ Store 32 + +template +HWY_API void StoreU(Vec32 v, D, uint32_t* HWY_RESTRICT p) { + vst1_lane_u32(p, v.raw, 0); +} 
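The Load 32 overloads above and the Store 32 overloads just below handle {u,i}{8,16} lane types by copying four bytes through a uint32_t buffer (CopyBytes) rather than reinterpret_cast-ing the pointer, which the in-tree comments note leads to incorrect codegen on GCC. A standalone sketch of that pattern with plain NEON intrinsics follows; it is illustrative only, not part of the patch, and the helper names are made up.

#include <arm_neon.h>
#include <cstdint>
#include <cstring>

// Load four uint8_t lanes without dereferencing p as a uint32_t*.
static inline uint8x8_t LoadFourU8(const uint8_t* p) {
  uint32_t buf;
  std::memcpy(&buf, p, 4);                    // the CopyBytes<4> step
  const uint32x2_t v32 = vld1_dup_u32(&buf);  // actual 32-bit broadcast load
  return vreinterpret_u8_u32(v32);            // bitcast back to u8 lanes
}

// Store the low four uint8_t lanes, mirroring the {u,i}{8,16} StoreU path.
static inline void StoreFourU8(uint8x8_t v, uint8_t* p) {
  const uint32_t buf = vget_lane_u32(vreinterpret_u32_u8(v), 0);  // GetLane
  std::memcpy(p, &buf, 4);
}

The same shape with a uint16_t buffer covers the two-lane 8-bit case (Load 16 / Store 16).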
+template +HWY_API void StoreU(Vec32 v, D, int32_t* HWY_RESTRICT p) { + vst1_lane_s32(p, v.raw, 0); +} +template +HWY_API void StoreU(Vec32 v, D, float* HWY_RESTRICT p) { + vst1_lane_f32(p, v.raw, 0); +} + +// {u,i}{8,16} +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Repartition d32; + uint32_t buf = GetLane(BitCast(d32, v)); + CopyBytes<4>(&buf, p); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Repartition d32; + uint32_t buf = GetLane(BitCast(d32, v)); + CopyBytes<4>(&buf, p); +} +#endif +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Repartition d32; + uint32_t buf = GetLane(BitCast(d32, v)); + CopyBytes<4>(&buf, p); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +// ------------------------------ Store 16 + +template +HWY_API void StoreU(Vec16 v, D, uint16_t* HWY_RESTRICT p) { + vst1_lane_u16(p, v.raw, 0); +} +template +HWY_API void StoreU(Vec16 v, D, int16_t* HWY_RESTRICT p) { + vst1_lane_s16(p, v.raw, 0); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec16 v, D, float16_t* HWY_RESTRICT p) { + vst1_lane_f16(detail::NativeLanePointer(p), v.raw, 0); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec16 v, D, bfloat16_t* HWY_RESTRICT p) { + vst1_lane_bf16(detail::NativeLanePointer(p), v.raw, 0); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const Repartition d16; + const uint16_t buf = GetLane(BitCast(d16, v)); + CopyBytes<2>(&buf, p); +} + +// ------------------------------ Store 8 + +template +HWY_API void StoreU(Vec128 v, D, uint8_t* HWY_RESTRICT p) { + vst1_lane_u8(p, v.raw, 0); +} +template +HWY_API void StoreU(Vec128 v, D, int8_t* HWY_RESTRICT p) { + vst1_lane_s8(p, v.raw, 0); +} + +// ------------------------------ Store misc + +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return StoreU(BitCast(du, v), du, detail::U16LanePointer(p)); +} + +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") +#endif + +// On Arm, Store is the same as StoreU. +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + StoreU(v, d, aligned); +} + +HWY_DIAGNOSTICS(pop) + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + // Treat as unsigned so that we correctly support float16. + const RebindToUnsigned du; + const auto blended = + IfThenElse(RebindMask(du, m), BitCast(du, v), BitCast(du, LoadU(d, p))); + StoreU(BitCast(d, blended), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. 
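NEON has no true non-temporal store, so the Stream implementation that follows issues a write prefetch hint with low temporal locality and then performs an ordinary store. A standalone sketch of the same idea, not part of the patch and with an illustrative helper name:

#include <arm_neon.h>
#include <cstdint>

// Hint that the line will be written and not reused soon, then store
// normally; on compilers without the builtin this degrades to a plain store.
static inline void StreamU32(uint32x4_t v, uint32_t* aligned) {
#if defined(__GNUC__) || defined(__clang__)
  __builtin_prefetch(aligned, /*rw=*/1, /*locality=*/0);
#endif
  vst1q_u32(aligned, v);
}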
+ +template +HWY_API void Stream(const VFromD v, D d, TFromD* HWY_RESTRICT aligned) { +#if HWY_ARCH_ARM_A64 +#if HWY_COMPILER_GCC + __builtin_prefetch(aligned, 1, 0); +#elif HWY_COMPILER_MSVC + __prefetch2(aligned, 0x11); +#endif +#endif + Store(v, d, aligned); +} + +// ================================================== CONVERT + +// ------------------------------ ConvertTo + +#if HWY_ARCH_ARM_A64 && HWY_HAVE_FLOAT16 + +// TODO(janwas): use macro generator instead of handwritten +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f16_s16(v.raw)); +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD(vcvt_f16_s16(v.raw)); +} + +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f16_u16(v.raw)); +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD(vcvt_f16_u16(v.raw)); +} + +#endif // HWY_ARCH_ARM_A64 && HWY_HAVE_FLOAT16 + +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f32_s32(v.raw)); +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD(vcvt_f32_s32(v.raw)); +} + +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f32_u32(v.raw)); +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD(vcvt_f32_u32(v.raw)); +} + +#if HWY_HAVE_FLOAT64 + +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f64_s64(v.raw)); +} +template +HWY_API Vec64 ConvertTo(D /* tag */, Vec64 v) { +// GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic. +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + return Set(Full64(), static_cast(GetLane(v))); +#else + return Vec64(vcvt_f64_s64(v.raw)); +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +} + +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f64_u64(v.raw)); +} +template +HWY_API Vec64 ConvertTo(D /* tag */, Vec64 v) { + // GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic. +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + return Set(Full64(), static_cast(GetLane(v))); +#else + return Vec64(vcvt_f64_u64(v.raw)); +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +} + +#endif // HWY_HAVE_FLOAT64 + +namespace detail { +// Truncates (rounds toward zero). +template +HWY_INLINE Vec128 ConvertFToI(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an int32_t. + + int32x4_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzs %0.4s, %1.4s" +#else + "vcvt.s32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_s32_f32(v.raw)); +#endif +} +template +HWY_INLINE VFromD ConvertFToI(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an int32_t. 
+ + int32x2_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzs %0.2s, %1.2s" +#else + "vcvt.s32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return VFromD(raw_result); +#else + return VFromD(vcvt_s32_f32(v.raw)); +#endif +} +template +HWY_INLINE Vec128 ConvertFToU(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an uint32_t. + + uint32x4_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzu %0.4s, %1.4s" +#else + "vcvt.u32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_u32_f32(v.raw)); +#endif +} +template +HWY_INLINE VFromD ConvertFToU(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an uint32_t. + + uint32x2_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzu %0.2s, %1.2s" +#else + "vcvt.u32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return VFromD(raw_result); +#else + return VFromD(vcvt_u32_f32(v.raw)); +#endif +} + +#if HWY_HAVE_FLOAT64 + +// Truncates (rounds toward zero). +template +HWY_INLINE Vec128 ConvertFToI(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int64_t. + int64x2_t raw_result; + __asm__("fcvtzs %0.2d, %1.2d" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_s64_f64(v.raw)); +#endif +} +template +HWY_INLINE Vec64 ConvertFToI(D /* tag */, Vec64 v) { +#if HWY_ARCH_ARM_A64 && \ + ((HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200)) + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int64_t. + // If compiling for AArch64 NEON with GCC 6 or earlier, use inline assembly to + // work around the missing vcvt_s64_f64 intrinsic. + int64x1_t raw_result; + __asm__("fcvtzs %d0, %d1" : "=w"(raw_result) : "w"(v.raw)); + return Vec64(raw_result); +#else + return Vec64(vcvt_s64_f64(v.raw)); +#endif +} +template +HWY_INLINE Vec128 ConvertFToU(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint64_t. + uint64x2_t raw_result; + __asm__("fcvtzu %0.2d, %1.2d" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_u64_f64(v.raw)); +#endif +} +template +HWY_INLINE Vec64 ConvertFToU(D /* tag */, Vec64 v) { +#if HWY_ARCH_ARM_A64 && \ + ((HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200)) + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint64_t. 
+ + // Inline assembly is also used if compiling for AArch64 NEON with GCC 6 or + // earlier to work around the issue of the missing vcvt_u64_f64 intrinsic. + uint64x1_t raw_result; + __asm__("fcvtzu %d0, %d1" : "=w"(raw_result) : "w"(v.raw)); + return Vec64(raw_result); +#else + return Vec64(vcvt_u64_f64(v.raw)); +#endif +} + +#endif // HWY_HAVE_FLOAT64 + +#if HWY_ARCH_ARM_A64 && HWY_HAVE_FLOAT16 + +// Truncates (rounds toward zero). +template +HWY_INLINE Vec128 ConvertFToI(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int16_t. + int16x8_t raw_result; + __asm__("fcvtzs %0.8h, %1.8h" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_s16_f16(v.raw)); +#endif +} +template +HWY_INLINE VFromD ConvertFToI(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int16_t. + int16x4_t raw_result; + __asm__("fcvtzs %0.4h, %1.4h" : "=w"(raw_result) : "w"(v.raw)); + return VFromD(raw_result); +#else + return VFromD(vcvt_s16_f16(v.raw)); +#endif +} + +template +HWY_INLINE Vec128 ConvertFToU(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint16_t. + uint16x8_t raw_result; + __asm__("fcvtzu %0.8h, %1.8h" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_u16_f16(v.raw)); +#endif +} +template +HWY_INLINE VFromD ConvertFToU(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint16_t. + uint16x4_t raw_result; + __asm__("fcvtzu %0.4h, %1.4h" : "=w"(raw_result) : "w"(v.raw)); + return VFromD(raw_result); +#else + return VFromD(vcvt_u16_f16(v.raw)); +#endif +} + +#endif // HWY_ARCH_ARM_A64 && HWY_HAVE_FLOAT16 +} // namespace detail + +template +HWY_API VFromD ConvertTo(D di, VFromD> v) { + return detail::ConvertFToI(di, v); +} + +template +HWY_API VFromD ConvertTo(D du, VFromD> v) { + return detail::ConvertFToU(du, v); +} + +// ------------------------------ PromoteTo (ConvertTo) + +// Unsigned: zero-extend to full vector. 
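The PromoteTo overloads that follow are thin wrappers over the NEON widening moves: vmovl_u* zero-extends, vmovl_s* sign-extends, and promotions that skip a width (u8 to u32, or 8/16-bit to 64-bit) simply chain two widenings. A standalone sketch, not part of the patch, with made-up helper names:

#include <arm_neon.h>
#include <cstdint>

// u8 -> u16: zero-extend eight lanes.
static inline uint16x8_t WidenU8(uint8x8_t v) { return vmovl_u8(v); }

// i8 -> i16: sign-extend eight lanes.
static inline int16x8_t WidenI8(int8x8_t v) { return vmovl_s8(v); }

// u8 -> u32: chain two widenings, keeping the low four u16 lanes.
static inline uint32x4_t WidenU8ToU32(uint8x8_t v) {
  const uint16x8_t u16 = vmovl_u8(v);
  return vmovl_u16(vget_low_u16(u16));
}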
+template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_u8(v.raw)); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec32 v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128(vmovl_u16(vget_low_u16(a))); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_u16(v.raw)); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_u32(v.raw)); +} +template +HWY_API Vec128 PromoteTo(D d, Vec64 v) { + return BitCast(d, Vec128(vmovl_u8(v.raw))); +} +template +HWY_API Vec128 PromoteTo(D d, Vec32 v) { + uint16x8_t a = vmovl_u8(v.raw); + return BitCast(d, Vec128(vmovl_u16(vget_low_u16(a)))); +} +template +HWY_API Vec128 PromoteTo(D d, Vec64 v) { + return BitCast(d, Vec128(vmovl_u16(v.raw))); +} +template +HWY_API Vec128 PromoteTo(D d, Vec64 v) { + return BitCast(d, Vec128(vmovl_u32(v.raw))); +} + +// Unsigned: zero-extend to half vector. +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_u16(vmovl_u8(v.raw))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_u32(vmovl_u16(vget_low_u16(vmovl_u8(v.raw))))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_u32(vmovl_u16(v.raw))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_u64(vmovl_u32(v.raw))); +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + using VU16 = VFromD>; + return BitCast(d, VU16(vget_low_u16(vmovl_u8(v.raw)))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + const uint32x4_t u32 = vmovl_u16(vget_low_u16(vmovl_u8(v.raw))); + return VFromD(vget_low_s32(vreinterpretq_s32_u32(u32))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_s32(vreinterpretq_s32_u32(vmovl_u16(v.raw)))); +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + using DU = RebindToUnsigned; + return BitCast(d, VFromD(vget_low_u64(vmovl_u32(v.raw)))); +} + +// U8/U16 to U64/I64: First, zero-extend to U32, and then zero-extend to +// TFromD +template +HWY_API VFromD PromoteTo(D d, V v) { + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +} + +// Signed: replicate sign bit to full vector. +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_s8(v.raw)); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec32 v) { + int16x8_t a = vmovl_s8(v.raw); + return Vec128(vmovl_s16(vget_low_s16(a))); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_s16(v.raw)); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_s32(v.raw)); +} + +// Signed: replicate sign bit to half vector. +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_s16(vmovl_s8(v.raw))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_s32(vmovl_s16(vget_low_s16(vmovl_s8(v.raw))))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_s32(vmovl_s16(v.raw))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_s64(vmovl_s32(v.raw))); +} + +// I8/I16 to I64: First, promote to I32, and then promote to I64 +template +HWY_API VFromD PromoteTo(D d, V v) { + const Rebind di32; + return PromoteTo(d, PromoteTo(di32, v)); +} + +#if HWY_NEON_HAVE_F16C + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. 
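When the target provides the half/single conversion instructions (HWY_NEON_HAVE_F16C), the float16 promote and demote below are a single VCVT each; the per-target flag above keeps generic_ops-inl.h from defining an emulated version on top of them. A standalone sketch of the native path, not part of the patch, guarded for AArch64 where these conversions are baseline, with made-up helper names:

#include <arm_neon.h>

#if defined(__aarch64__)
// f16 -> f32: widen four half-precision lanes to single precision.
static inline float32x4_t PromoteF16(float16x4_t v) { return vcvt_f32_f16(v); }

// f32 -> f16: the matching demotion, rounded per the current FP mode
// (nearest-even by default).
static inline float16x4_t DemoteF32(float32x4_t v) { return vcvt_f16_f32(v); }
#endif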
+#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vcvt_f32_f16(v.raw)); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_f32(vcvt_f32_f16(v.raw))); +} + +#endif // HWY_NEON_HAVE_F16C + +#if HWY_HAVE_FLOAT64 + +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vcvt_f64_f32(v.raw)); +} + +template +HWY_API Vec64 PromoteTo(D /* tag */, Vec32 v) { + return Vec64(vget_low_f64(vcvt_f64_f32(v.raw))); +} + +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + const int64x2_t i64 = vmovl_s32(v.raw); + return Vec128(vcvtq_f64_s64(i64)); +} + +template +HWY_API Vec64 PromoteTo(D d, Vec32 v) { + return ConvertTo(d, Vec64(vget_low_s64(vmovl_s32(v.raw)))); +} + +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + const uint64x2_t u64 = vmovl_u32(v.raw); + return Vec128(vcvtq_f64_u64(u64)); +} + +template +HWY_API Vec64 PromoteTo(D d, Vec32 v) { + return ConvertTo(d, Vec64(vget_low_u64(vmovl_u32(v.raw)))); +} + +template +HWY_API VFromD PromoteTo(D d64, VFromD> v) { + const RebindToFloat df64; + return ConvertTo(d64, PromoteTo(df64, v)); +} + +#else // !HWY_HAVE_FLOAT64 + +template +HWY_API VFromD PromoteTo(D di64, VFromD> v) { + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + Min(SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{157}))), + BitCast(du32_as_du8, Set(du32, uint32_t{32})))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + const auto f32_to_i32_result = ConvertTo(di32, adj_v); + const auto lo64_or_mask = PromoteTo( + di64, + BitCast(du32, VecFromMask(di32, Eq(f32_to_i32_result, + Set(di32, LimitsMax()))))); + + return Or(PromoteTo(di64, BitCast(di32, f32_to_i32_result)) + << PromoteTo(di64, exponent_adj), + lo64_or_mask); +} + +template +HWY_API VFromD PromoteTo(D du64, VFromD> v) { + const Rebind du32; + const RebindToFloat df32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + Min(SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{158}))), + BitCast(du32_as_du8, Set(du32, uint32_t{32})))); + + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + const auto f32_to_u32_result = ConvertTo(du32, adj_v); + const auto lo32_or_mask = PromoteTo( + du64, + VecFromMask(du32, f32_to_u32_result == Set(du32, LimitsMax()))); + + return Or(PromoteTo(du64, f32_to_u32_result) << PromoteTo(du64, exponent_adj), + lo32_or_mask); +} + +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD PromoteInRangeTo(D d64, VFromD> v) { + const Rebind>, decltype(d64)> d32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + constexpr uint32_t kExpAdjDecr = + 0xFFFFFF9Du + static_cast(!IsSigned>()); + + const auto exponent_adj = BitCast( + du32, SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, kExpAdjDecr)))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + return PromoteTo(d64, ConvertTo(d32, adj_v)) << 
PromoteTo(d64, exponent_adj); +} + +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "third_party/highway/hwy/ops/inside-inl.h" + +// ------------------------------ PromoteUpperTo + +#if HWY_ARCH_ARM_A64 + +// Per-target flag to prevent generic_ops-inl.h from defining PromoteUpperTo. +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// Unsigned: zero-extend to full vector. +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_u8(v.raw)); +} +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_u16(v.raw)); +} +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_u32(v.raw)); +} +template +HWY_API Vec128 PromoteUpperTo(D d, Vec128 v) { + return BitCast(d, Vec128(vmovl_high_u8(v.raw))); +} +template +HWY_API Vec128 PromoteUpperTo(D d, Vec128 v) { + return BitCast(d, Vec128(vmovl_high_u16(v.raw))); +} +template +HWY_API Vec128 PromoteUpperTo(D d, Vec128 v) { + return BitCast(d, Vec128(vmovl_high_u32(v.raw))); +} + +// Signed: replicate sign bit to full vector. +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_s8(v.raw)); +} +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_s16(v.raw)); +} +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_s32(v.raw)); +} + +#if HWY_NEON_HAVE_F16C + +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vcvt_high_f32_f16(v.raw)); +} + +#endif // HWY_NEON_HAVE_F16C + +template +HWY_API VFromD PromoteUpperTo(D df32, VFromD> v) { + const Repartition du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteUpperTo(di32, BitCast(du16, v)))); +} + +#if HWY_HAVE_FLOAT64 + +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vcvt_high_f64_f32(v.raw)); +} + +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + const int64x2_t i64 = vmovl_high_s32(v.raw); + return Vec128(vcvtq_f64_s64(i64)); +} + +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + const uint64x2_t u64 = vmovl_high_u32(v.raw); + return Vec128(vcvtq_f64_u64(u64)); +} + +#endif // HWY_HAVE_FLOAT64 + +template +HWY_API VFromD PromoteUpperTo(D d64, Vec128 v) { +#if HWY_HAVE_FLOAT64 + const RebindToFloat df64; + return ConvertTo(d64, PromoteUpperTo(df64, v)); +#else + const Rebind dh; + return PromoteTo(d, UpperHalf(dh, v)); +#endif +} + +// Generic version for <=64 bit input/output (_high is only for full vectors). 
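The vmovl_high_* and vcvt_high_* forms used above accept only full 128-bit inputs, so the generic overload that follows handles <=64-bit vectors by taking the upper half first and then promoting it. A standalone sketch of both shapes, not part of the patch; the helper names and the vext-based upper-half extraction are illustrative:

#include <arm_neon.h>
#include <cstdint>

#if defined(__aarch64__)
// Full 128-bit input: widen lanes 8..15 with a single instruction.
static inline uint16x8_t PromoteUpper(uint8x16_t v) {
  return vmovl_high_u8(v);
}
#endif

// 64-bit input: no vmovl_high_* variant exists, so rotate the upper half
// (lanes 4..7) down and then widen, keeping four lanes.
static inline uint16x4_t PromoteUpper(uint8x8_t v) {
  const uint8x8_t upper = vext_u8(v, v, 4);
  return vget_low_u16(vmovl_u8(upper));
}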
+template +HWY_API VFromD PromoteUpperTo(D d, V v) { + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ DemoteTo (ConvertTo) + +// From full vector to half or quarter +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovun_s32(v.raw)); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_s32(v.raw)); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec128 v) { + const uint16x4_t a = vqmovun_s32(v.raw); + return Vec32(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovun_s16(v.raw)); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec128 v) { + const int16x4_t a = vqmovn_s32(v.raw); + return Vec32(vqmovn_s16(vcombine_s16(a, a))); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_s16(v.raw)); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_u32(v.raw)); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec128 v) { + const uint16x4_t a = vqmovn_u32(v.raw); + return Vec32(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_u16(v.raw)); +} + +// From half vector to partial half +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovun_s32(vcombine_s32(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovn_s32(vcombine_s32(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const uint16x4_t a = vqmovun_s32(vcombine_s32(v.raw, v.raw)); + return VFromD(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovun_s16(vcombine_s16(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const int16x4_t a = vqmovn_s32(vcombine_s32(v.raw, v.raw)); + return VFromD(vqmovn_s16(vcombine_s16(a, a))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovn_s16(vcombine_s16(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovn_u32(vcombine_u32(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const uint16x4_t a = vqmovn_u32(vcombine_u32(v.raw, v.raw)); + return VFromD(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovn_u16(vcombine_u16(v.raw, v.raw))); +} + +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_s64(v.raw)); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovun_s64(v.raw)); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_u64(v.raw)); +} +template +HWY_API VFromD DemoteTo(D d, Vec128 v) { + const Rebind di32; + return DemoteTo(d, DemoteTo(di32, v)); +} +template +HWY_API VFromD DemoteTo(D d, Vec128 v) { + const Rebind du32; + return DemoteTo(d, DemoteTo(du32, v)); +} +template +HWY_API VFromD DemoteTo(D d, Vec128 v) { + const Rebind du32; + return DemoteTo(d, DemoteTo(du32, v)); +} + +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { + return Vec32(vqmovn_s64(vcombine_s64(v.raw, v.raw))); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { + return Vec32(vqmovun_s64(vcombine_s64(v.raw, v.raw))); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { + return 
Vec32(vqmovn_u64(vcombine_u64(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D d, Vec64 v) { + const Rebind di32; + return DemoteTo(d, DemoteTo(di32, v)); +} +template +HWY_API VFromD DemoteTo(D d, Vec64 v) { + const Rebind du32; + return DemoteTo(d, DemoteTo(du32, v)); +} +template +HWY_API VFromD DemoteTo(D d, Vec64 v) { + const Rebind du32; + return DemoteTo(d, DemoteTo(du32, v)); +} + +#if HWY_NEON_HAVE_F16C + +// We already toggled HWY_NATIVE_F16C above. + +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64{vcvt_f16_f32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vcvt_f16_f32(vcombine_f32(v.raw, v.raw))); +} + +#endif // HWY_NEON_HAVE_F16C + +#if HWY_NEON_HAVE_F32_TO_BF16C +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +namespace detail { +#if HWY_NEON_HAVE_BFLOAT16 +// If HWY_NEON_HAVE_BFLOAT16 is true, detail::Vec128::type is +// bfloat16x4_t or bfloat16x8_t. +static HWY_INLINE bfloat16x4_t BitCastFromRawNeonBF16(bfloat16x4_t raw) { + return raw; +} +#else +// If HWY_NEON_HAVE_F32_TO_BF16C && !HWY_NEON_HAVE_BFLOAT16 is true, +// detail::Vec128::type is uint16x4_t or uint16x8_t vector to +// work around compiler bugs that are there with GCC 13 or earlier or Clang 16 +// or earlier on AArch64. + +// The bfloat16x4_t vector returned by vcvt_bf16_f32 needs to be bitcasted to +// an uint16x4_t vector if HWY_NEON_HAVE_F32_TO_BF16C && +// !HWY_NEON_HAVE_BFLOAT16 is true. +static HWY_INLINE uint16x4_t BitCastFromRawNeonBF16(bfloat16x4_t raw) { + return vreinterpret_u16_bf16(raw); +} +#endif +} // namespace detail + +template +HWY_API VFromD DemoteTo(D /*dbf16*/, VFromD> v) { + return VFromD(detail::BitCastFromRawNeonBF16(vcvt_bf16_f32(v.raw))); +} +template +HWY_API VFromD DemoteTo(D /*dbf16*/, VFromD> v) { + return VFromD(detail::BitCastFromRawNeonBF16( + vcvt_bf16_f32(vcombine_f32(v.raw, v.raw)))); +} +#endif // HWY_NEON_HAVE_F32_TO_BF16C + +#if HWY_HAVE_FLOAT64 + +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vcvt_f32_f64(v.raw)); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { + return Vec32(vcvt_f32_f64(vcombine_f64(v.raw, v.raw))); +} + +template +HWY_API VFromD DemoteTo(D d32, VFromD> v) { + const Rebind>, D> d64; + return DemoteTo(d32, ConvertTo(d64, v)); +} + +#endif // HWY_HAVE_FLOAT64 + +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind di64; + const RebindToUnsigned du64; + +#if HWY_ARCH_ARM_A64 + const RebindToFloat df64; + + const auto k2p64_63 = Set(df64, 27670116110564327424.0); + const auto f64_hi52 = + Xor(BitCast(df64, ShiftRight<12>(BitCast(du64, v))), k2p64_63) - k2p64_63; + const auto f64_lo12 = + ConvertTo(df64, And(BitCast(du64, v), Set(du64, uint64_t{0x00000FFF}))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + const auto f64_bits_decrement = + And(ShiftRight<63>(BitCast(du64, Xor(f64_sum, f64_carry))), + f64_sum_is_inexact); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - f64_bits_decrement, f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +#else + const RebindToUnsigned du32; + const auto hi23 = TruncateTo(du32, ShiftRight<41>(BitCast(du64, v))); + const auto mid23 = And(TruncateTo(du32, ShiftRight<18>(BitCast(du64, v))), + Set(du32, 
uint32_t{0x007FFFFFu})); + const auto lo18 = + And(TruncateTo(du32, BitCast(du64, v)), Set(du32, uint32_t{0x0003FFFFu})); + + const auto k2p41_f32 = Set(df32, 2199023255552.0f); + const auto k2p64_63_f32 = Set(df32, 27670116110564327424.0f); + + const auto hi23_f32 = + BitCast(df32, Xor(hi23, BitCast(du32, k2p64_63_f32))) - k2p64_63_f32; + const auto mid23_f32 = + BitCast(df32, Or(mid23, BitCast(du32, k2p41_f32))) - k2p41_f32; + const auto lo18_f32 = ConvertTo(df32, lo18); + + const auto s_hi46 = hi23_f32 + mid23_f32; + const auto c_hi46 = (hi23_f32 - s_hi46) + mid23_f32; + + auto s_lo = c_hi46 + lo18_f32; + const auto c_lo = (c_hi46 - s_lo) + lo18_f32; + + const auto s_lo_inexact_mask = + VecFromMask(du32, RebindMask(du32, c_lo != Zero(df32))); + const auto s_lo_mag_adj = ShiftRight<31>( + And(s_lo_inexact_mask, Xor(BitCast(du32, s_lo), BitCast(du32, c_lo)))); + + s_lo = BitCast(df32, BitCast(du32, s_lo) - s_lo_mag_adj); + s_lo = + BitCast(df32, Or(BitCast(du32, s_lo), ShiftRight<31>(s_lo_inexact_mask))); + return s_hi46 + s_lo; +#endif +} + +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { +#if HWY_ARCH_ARM_A64 + const Rebind du64; + const RebindToFloat df64; + + const auto k2p64 = Set(df64, 18446744073709551616.0); + const auto f64_hi52 = Or(BitCast(df64, ShiftRight<12>(v)), k2p64) - k2p64; + const auto f64_lo12 = + ConvertTo(df64, And(v, Set(du64, uint64_t{0x00000FFF}))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - ShiftRight<63>(BitCast(du64, f64_carry)), + f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +#else + const RebindToUnsigned du32; + + const auto hi23 = TruncateTo(du32, ShiftRight<41>(v)); + const auto mid23 = And(TruncateTo(du32, ShiftRight<18>(v)), + Set(du32, uint32_t{0x007FFFFFu})); + const auto lo18 = And(TruncateTo(du32, v), Set(du32, uint32_t{0x0003FFFFu})); + + const auto k2p41_f32 = Set(df32, 2199023255552.0f); + const auto k2p64_f32 = Set(df32, 18446744073709551616.0f); + + const auto hi23_f32 = + BitCast(df32, Or(hi23, BitCast(du32, k2p64_f32))) - k2p64_f32; + const auto mid23_f32 = + BitCast(df32, Or(mid23, BitCast(du32, k2p41_f32))) - k2p41_f32; + const auto lo18_f32 = ConvertTo(df32, lo18); + + const auto s_hi46 = hi23_f32 + mid23_f32; + const auto c_hi46 = (hi23_f32 - s_hi46) + mid23_f32; + + auto s_lo = c_hi46 + lo18_f32; + const auto c_lo = (c_hi46 - s_lo) + lo18_f32; + + const auto s_lo_inexact_mask = + VecFromMask(du32, RebindMask(du32, c_lo != Zero(df32))); + const auto s_lo_mag_adj = ShiftRight<31>( + And(s_lo_inexact_mask, Xor(BitCast(du32, s_lo), BitCast(du32, c_lo)))); + + s_lo = BitCast(df32, BitCast(du32, s_lo) - s_lo_mag_adj); + s_lo = + BitCast(df32, Or(BitCast(du32, s_lo), ShiftRight<31>(s_lo_inexact_mask))); + return s_hi46 + s_lo; +#endif +} + +HWY_API Vec32 U8FromU32(Vec128 v) { + const uint8x16_t org_v = detail::BitCastToByte(v).raw; + const uint8x16_t w = vuzp1q_u8(org_v, org_v); + return Vec32(vget_low_u8(vuzp1q_u8(w, w))); +} +template +HWY_API Vec128 U8FromU32(Vec128 v) { + const uint8x8_t org_v = detail::BitCastToByte(v).raw; + const uint8x8_t w = vuzp1_u8(org_v, org_v); + return Vec128(vuzp1_u8(w, w)); +} + +// ------------------------------ Round (IfThenElse, mask, logical) + +#if HWY_ARCH_ARM_A64 +// Toward nearest integer +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Round, vrndn, _, 
1) + +// Toward zero, aka truncate +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Trunc, vrnd, _, 1) + +// Toward +infinity, aka ceiling +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Ceil, vrndp, _, 1) + +// Toward -infinity, aka floor +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Floor, vrndm, _, 1) +#else + +// ------------------------------ Trunc + +// Armv7 only supports truncation to integer. We can either convert back to +// float (3 floating-point and 2 logic operations) or manipulate the binary32 +// representation, clearing the lowest 23-exp mantissa bits. This requires 9 +// integer operations and 3 constants, which is likely more expensive. + +namespace detail { + +// The original value is already the desired result if NaN or the magnitude is +// large (i.e. the value is already an integer). +template +HWY_INLINE Mask128 UseInt(const Vec128 v) { + return Abs(v) < Set(Simd(), MantissaEnd()); +} + +} // namespace detail + +template +HWY_API Vec128 Trunc(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), int_f, v); +} + +template +HWY_API Vec128 Round(const Vec128 v) { + const DFromV df; + + // Armv7 also lacks a native NearestInt, but we can instead rely on rounding + // (we assume the current mode is nearest-even) after addition with a large + // value such that no mantissa bits remain. We may need a compiler flag for + // precise floating-point to prevent this from being "optimized" out. + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +template +HWY_API Vec128 Ceil(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +template +HWY_API Vec128 Floor(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. 
+ const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +#endif + +// ------------------------------ CeilInt/FloorInt +#if HWY_ARCH_ARM_A64 + +#ifdef HWY_NATIVE_CEIL_FLOOR_INT +#undef HWY_NATIVE_CEIL_FLOOR_INT +#else +#define HWY_NATIVE_CEIL_FLOOR_INT +#endif + +#if HWY_HAVE_FLOAT16 +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtpq_s16_f16(v.raw)); +} + +template +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtp_s16_f16(v.raw)); +} + +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtmq_s16_f16(v.raw)); +} + +template +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtm_s16_f16(v.raw)); +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtpq_s32_f32(v.raw)); +} + +template +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtp_s32_f32(v.raw)); +} + +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtpq_s64_f64(v.raw)); +} + +template +HWY_API Vec128 CeilInt(const Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 610 + // Workaround for missing vcvtp_s64_f64 intrinsic + const DFromV d; + const RebindToSigned di; + const Twice dt; + return LowerHalf(di, CeilInt(Combine(dt, v, v))); +#else + return Vec128(vcvtp_s64_f64(v.raw)); +#endif +} + +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtmq_s32_f32(v.raw)); +} + +template +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtm_s32_f32(v.raw)); +} + +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtmq_s64_f64(v.raw)); +} + +template +HWY_API Vec128 FloorInt(const Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 610 + // Workaround for missing vcvtm_s64_f64 intrinsic + const DFromV d; + const RebindToSigned di; + const Twice dt; + return LowerHalf(di, FloorInt(Combine(dt, v, v))); +#else + return Vec128(vcvtm_s64_f64(v.raw)); +#endif +} + +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ NearestInt (Round) + +#if HWY_HAVE_FLOAT16 +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtnq_s16_f16(v.raw)); +} +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtn_s16_f16(v.raw)); +} +#endif + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtnq_s32_f32(v.raw)); +} +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtn_s32_f32(v.raw)); +} + +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtnq_s64_f64(v.raw)); +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 610 + // Workaround for missing vcvtn_s64_f64 intrinsic + const DFromV d; + const RebindToSigned di; + const Twice dt; + return LowerHalf(di, NearestInt(Combine(dt, v, v))); +#else + return Vec128(vcvtn_s64_f64(v.raw)); +#endif +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + return DemoteTo(di32, NearestInt(v)); +} + +#else + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + const RebindToSigned> di; + return ConvertTo(di, Round(v)); +} + +#endif + +// ------------------------------ Floating-point classification + +#if !HWY_COMPILER_CLANG || HWY_COMPILER_CLANG > 1801 || HWY_ARCH_ARM_V7 +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return v != v; +} +#else +// Clang up to 18.1 generates less efficient code than the expected FCMEQ, see +// https://github.com/numpy/numpy/issues/27313 and +// 
https://github.com/numpy/numpy/pull/22954/files and +// https://github.com/llvm/llvm-project/issues/59855 + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.8h, %1.8h, %1.8h" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.4h, %1.4h, %1.4h" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.4s, %1.4s, %1.4s" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.2s, %1.2s, %1.2s" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} + +#if HWY_HAVE_FLOAT64 +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.2d, %1.2d, %1.2d" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %d0, %d1, %d1" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_COMPILER_CLANG + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +// <= 64 bit: just return different type +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128(v.raw); +} + +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_u8(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_u16(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_u32(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_u64(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_s8(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_s16(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_s32(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_s64(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_f32(v.raw)); +} +#if HWY_HAVE_FLOAT16 +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_f16(v.raw)); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_bf16(v.raw)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +#if HWY_HAVE_FLOAT64 +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_f64(v.raw)); +} +#endif // HWY_HAVE_FLOAT64 + +template ), HWY_IF_V_SIZE_V(V, 16)> +HWY_API VFromD>> LowerHalf(V v) { + const Full128 du; + const Half> dh; + return BitCast(dh, LowerHalf(BitCast(du, v))); +} + +template +HWY_API VFromD LowerHalf(DH /* tag */, VFromD> v) { + return LowerHalf(v); +} + +// ------------------------------ CombineShiftRightBytes + +// 128-bit +template > +HWY_API Vec128 CombineShiftRightBytes(D d, Vec128 hi, Vec128 lo) { + static_assert(0 < kBytes && kBytes < 16, "kBytes must be in [1, 15]"); + const Repartition d8; + uint8x16_t v8 = vextq_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); + return BitCast(d, Vec128(v8)); +} + +// 64-bit +template > +HWY_API Vec64 CombineShiftRightBytes(D d, Vec64 hi, Vec64 lo) { + static_assert(0 < kBytes && kBytes < 8, "kBytes must be in [1, 7]"); + const Repartition d8; + uint8x8_t v8 = vext_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); + 
return BitCast(d, VFromD(v8)); +} + +// <= 32-bit defined after ShiftLeftBytes. + +// ------------------------------ Shift vector by constant #bytes + +namespace detail { + +// Partially specialize because kBytes = 0 and >= size are compile errors; +// callers replace the latter with 0xFF for easier specialization. +template +struct ShiftLeftBytesT { + // Full + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + const Full128 d; + return CombineShiftRightBytes<16 - kBytes>(d, v, Zero(d)); + } + + // Partial + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + // Expand to 64-bit so we only use the native EXT instruction. + const Full64 d64; + const auto zero64 = Zero(d64); + const decltype(zero64) v64(v.raw); + return Vec128( + CombineShiftRightBytes<8 - kBytes>(d64, v64, zero64).raw); + } +}; +template <> +struct ShiftLeftBytesT<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return v; + } +}; +template <> +struct ShiftLeftBytesT<0xFF> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return Xor(v, v); + } +}; + +template +struct ShiftRightBytesT { + template + HWY_INLINE Vec128 operator()(Vec128 v) { + const DFromV d; + // For < 64-bit vectors, zero undefined lanes so we shift in zeros. + if (d.MaxBytes() < 8) { + constexpr size_t kReg = d.MaxBytes() == 16 ? 16 : 8; + const Simd dreg; + v = Vec128( + IfThenElseZero(FirstN(dreg, N), VFromD(v.raw)).raw); + } + return CombineShiftRightBytes(d, Zero(d), v); + } +}; +template <> +struct ShiftRightBytesT<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return v; + } +}; +template <> +struct ShiftRightBytesT<0xFF> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return Xor(v, v); + } +}; + +} // namespace detail + +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + return detail::ShiftLeftBytesT<(kBytes >= d.MaxBytes() ? 0xFF : kBytes)>()(v); +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +template +HWY_API VFromD ShiftLeftLanes(D d, VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes)>(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + return detail::ShiftRightBytesT<(kBytes >= d.MaxBytes() ? 
0xFF : kBytes)>()( + v); +} + +template +HWY_API VFromD ShiftRightLanes(D d, VFromD v) { + const Repartition d8; + return BitCast( + d, ShiftRightBytes)>(d8, BitCast(d8, v))); +} + +// Calls ShiftLeftBytes +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + const Full64 d_full8; + const Repartition, decltype(d_full8)> d_full; + using V64 = VFromD; + const V64 hi64(BitCast(d8, hi).raw); + // Move into most-significant bytes + const V64 lo64 = ShiftLeftBytes<8 - kSize>(V64(BitCast(d8, lo).raw)); + const V64 r = CombineShiftRightBytes<8 - kSize + kBytes>(d_full8, hi64, lo64); + // After casting to full 64-bit vector of correct type, shrink to 32-bit + return VFromD(BitCast(d_full, r).raw); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_u8(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_u16(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_u32(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_u64(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_s8(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_s16(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_s32(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_s64(v.raw)); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_f16(v.raw)); +} +#endif +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_bf16(v.raw)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_f32(v.raw)); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_f64(v.raw)); +} +#endif // HWY_HAVE_FLOAT64 + +template +HWY_API VFromD UpperHalf(D dh, VFromD> v) { + const RebindToUnsigned> du; + const Half duh; + return BitCast(dh, UpperHalf(duh, BitCast(du, v))); +} + +// Partial +template +HWY_API VFromD UpperHalf(DH dh, VFromD> v) { + const Twice d; + const RebindToUnsigned du; + const VFromD upper = + ShiftRightBytes(du, BitCast(du, v)); + return VFromD(BitCast(d, upper).raw); +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(Vec128 v) { + return v; +} + +#if HWY_ARCH_ARM_A64 +// Unsigned +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return Vec128(vdupq_laneq_u8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, 
"Invalid lane"); + return Vec128(vdupq_laneq_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_u64(v.raw, kLane)); +} + +// Signed +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return Vec128(vdupq_laneq_s8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_s64(v.raw, kLane)); +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_f16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f16(v.raw, kLane)); +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_bf16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_bf16(v.raw, kLane)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_f32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_f64(v.raw, kLane)); +} + +#else // !HWY_ARCH_ARM_A64 +// No vdupq_laneq_* on armv7: use vgetq_lane_* + vdupq_n_*. 
+ +// Unsigned +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return Vec128(vdupq_n_u8(vgetq_lane_u8(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_u16(vgetq_lane_u16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_u32(vgetq_lane_u32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_n_u64(vgetq_lane_u64(v.raw, kLane))); +} + +// Signed +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return Vec128(vdupq_n_s8(vgetq_lane_s8(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_s16(vgetq_lane_s16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_s32(vgetq_lane_s32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_n_s64(vgetq_lane_s64(v.raw, kLane))); +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_f16(vgetq_lane_f16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f16(v.raw, kLane)); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_bf16(vgetq_lane_bf16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_bf16(v.raw, kLane)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_f32(vgetq_lane_f32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f32(v.raw, kLane)); +} + +#endif // HWY_ARCH_ARM_A64 + +template ), + HWY_IF_LANES_GT_D(DFromV, 1)> +HWY_API V Broadcast(V v) { + const DFromV d; + const RebindToUnsigned 
du; + return BitCast(d, Broadcast(BitCast(du, v))); +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices128 { + typename detail::Raw128::type raw; +}; + +namespace detail { + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + return Iota(d8, 0); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + return Zero(d8); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + return Load(d8, kByteOffsets); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + return Load(d8, kByteOffsets); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}; + return Load(d8, kByteOffsets); +} + +} // namespace detail + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + (void)d; + return Indices128, MaxLanes(D())>{BitCast(d, vec).raw}; +} + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + const Repartition d8; + using V8 = VFromD; + + // Broadcast each lane index to all bytes of T and shift to bytes + const V8 lane_indices = TableLookupBytes( + BitCast(d8, vec), detail::IndicesFromVecBroadcastLaneBytes(d)); + constexpr int kIndexShiftAmt = static_cast(FloorLog2(sizeof(T))); + const V8 byte_indices = ShiftLeft(lane_indices); + const V8 sum = Add(byte_indices, detail::IndicesFromVecByteOffsets(d)); + return Indices128, MaxLanes(D())>{BitCast(d, sum).raw}; +} + +template +HWY_API Indices128, MaxLanes(D())> SetTableIndices(D d, + const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const DFromV d; + const RebindToSigned di; + 
return BitCast( + d, TableLookupBytes(BitCast(di, v), BitCast(di, Vec128{idx.raw}))); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Twice dt; +// TableLookupLanes currently requires table and index vectors to be the same +// size, though a half-length index vector would be sufficient here. +#if HWY_IS_MSAN + const Vec128 idx_vec{idx.raw}; + const Indices128 idx2{Combine(dt, idx_vec, idx_vec).raw}; +#else + // We only keep LowerHalf of the result, which is valid in idx. + const Indices128 idx2{idx.raw}; +#endif + return LowerHalf(d, TableLookupLanes(Combine(dt, b, a), idx2)); +} + +template +HWY_API Vec64 TwoTablesLookupLanes(Vec64 a, Vec64 b, + Indices128 idx) { + const DFromV d; + const Repartition du8; + const auto a_u8 = BitCast(du8, a); + const auto b_u8 = BitCast(du8, b); + const auto idx_u8 = BitCast(du8, Vec64{idx.raw}); + +#if HWY_ARCH_ARM_A64 + const Twice dt_u8; + return BitCast( + d, Vec64{vqtbl1_u8(Combine(dt_u8, b_u8, a_u8).raw, idx_u8.raw)}); +#else + detail::Tuple2 tup = {{{a_u8.raw, b_u8.raw}}}; + return BitCast(d, Vec64{vtbl2_u8(tup.raw, idx_u8.raw)}); +#endif +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Repartition du8; + const auto a_u8 = BitCast(du8, a); + const auto b_u8 = BitCast(du8, b); + const auto idx_u8 = BitCast(du8, Vec128{idx.raw}); + +#if HWY_ARCH_ARM_A64 + detail::Tuple2 tup = {{{a_u8.raw, b_u8.raw}}}; + return BitCast(d, Vec128{vqtbl2q_u8(tup.raw, idx_u8.raw)}); +#else + const Half dh; + const Repartition dh_u8; + const auto a_lo_u8 = LowerHalf(dh_u8, a_u8); + const auto a_hi_u8 = UpperHalf(dh_u8, a_u8); + const auto b_lo_u8 = LowerHalf(dh_u8, b_u8); + const auto b_hi_u8 = UpperHalf(dh_u8, b_u8); + const auto idx_lo_u8 = LowerHalf(dh_u8, idx_u8); + const auto idx_hi_u8 = UpperHalf(dh_u8, idx_u8); + + detail::Tuple4 tup = { + {{a_lo_u8.raw, a_hi_u8.raw, b_lo_u8.raw, b_hi_u8.raw}}}; + const auto lo_result = + BitCast(dh, Vec64{vtbl4_u8(tup.raw, idx_lo_u8.raw)}); + const auto hi_result = + BitCast(dh, Vec64{vtbl4_u8(tup.raw, idx_hi_u8.raw)}); + return Combine(d, hi_result, lo_result); +#endif +} + +// ------------------------------ Reverse2 (CombineShiftRightBytes) + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. 
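+// [Editorial note, not part of the upstream header] The flag below follows
+// the per-target toggle idiom: it is defined when this target supplies a
+// native implementation, which (per the comment above) keeps the generic
+// fallback header from emitting its own 8-bit Reverse2/4/8. A small usage
+// sketch of the ops being advertised, assuming hwy/highway.h is included and
+// this runs inside HWY_NAMESPACE:
+#if 0  // illustration only
+void Reverse8BitSketch() {
+  const Full64<uint8_t> d;          // 8 lanes of u8
+  const auto v = Iota(d, 0);        // 0 1 2 3 4 5 6 7
+  const auto r2 = Reverse2(d, v);   // 1 0 3 2 5 4 7 6 (swap adjacent pairs)
+  const auto r4 = Reverse4(d, v);   // 3 2 1 0 7 6 5 4 (reverse within quads)
+  const auto r8 = Reverse8(d, v);   // 7 6 5 4 3 2 1 0 (reverse within groups of 8)
+  (void)r2; (void)r4; (void)r8;
+}
+#endif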
+#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev16_u8(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 Reverse2(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev16q_u8(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev32_u16(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 Reverse2(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev32q_u16(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev64_u32(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 Reverse2(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u32(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + return CombineShiftRightBytes<8>(d, v, v); +} + +// ------------------------------ Reverse4 (Reverse2) + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev32_u8(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 Reverse4(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev32q_u8(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev64_u16(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 Reverse4(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u16(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + const RepartitionToWide> duw; + return BitCast(d, Reverse2(duw, BitCast(duw, Reverse2(d, v)))); +} + +template +HWY_API VFromD Reverse4(D /* tag */, VFromD) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 (Reverse2, Reverse4) + +template +HWY_API VFromD Reverse8(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev64_u8(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 Reverse8(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u8(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse8(D d, VFromD v) { + const Repartition du64; + return BitCast(d, Reverse2(du64, BitCast(du64, Reverse4(d, v)))); +} + +template +HWY_API VFromD Reverse8(D, VFromD) { + HWY_ASSERT(0); // don't have 8 lanes if larger than 16-bit +} + +// ------------------------------ Reverse (Reverse2, Reverse4, Reverse8) + +template , HWY_IF_LANES_D(D, 1)> +HWY_API Vec128 Reverse(D /* tag */, Vec128 v) { + return v; +} + +template , HWY_IF_LANES_D(D, 2)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + return Reverse2(d, v); +} + +template , HWY_IF_LANES_D(D, 4)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + return Reverse4(d, v); +} + +template , HWY_IF_LANES_D(D, 8)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + return Reverse8(d, v); +} + +template , HWY_IF_LANES_D(D, 16)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + const Repartition du64; + return BitCast(d, Reverse2(du64, BitCast(du64, Reverse8(d, v)))); +} + +// ------------------------------ ReverseBits + +#if HWY_ARCH_ARM_A64 + +#ifdef HWY_NATIVE_REVERSE_BITS_UI8 +#undef 
HWY_NATIVE_REVERSE_BITS_UI8 +#else +#define HWY_NATIVE_REVERSE_BITS_UI8 +#endif + +HWY_NEON_DEF_FUNCTION_INT_8(ReverseBits, vrbit, _, 1) +HWY_NEON_DEF_FUNCTION_UINT_8(ReverseBits, vrbit, _, 1) + +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ Other shuffles (TableLookupBytes) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(Vec128 v) { + return CombineShiftRightBytes<8>(DFromV(), v, v); +} +template +HWY_API Vec128 Shuffle01(Vec128 v) { + return CombineShiftRightBytes<8>(DFromV(), v, v); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(Vec128 v) { + return CombineShiftRightBytes<4>(DFromV(), v, v); +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(Vec128 v) { + return CombineShiftRightBytes<12>(DFromV(), v, v); +} + +// Reverse +template +HWY_API Vec128 Shuffle0123(Vec128 v) { + return Reverse4(DFromV(), v); +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(InterleaveLower, vzip1, _, 2) +#if HWY_ARCH_ARM_A64 +// N=1 makes no sense (in that case, there would be no upper/lower). +HWY_NEON_DEF_FUNCTION_FULL_UIF_64(InterleaveLower, vzip1, _, 2) +#else +// Emulated version for Armv7. +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + const DFromV d; + return CombineShiftRightBytes<8>(d, b, Shuffle01(a)); +} +#endif + +#if !HWY_HAVE_FLOAT16 +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, InterleaveLower(BitCast(du, a), BitCast(du, b))); +} +#endif // !HWY_HAVE_FLOAT16 + +// < 64 bit parts +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128(InterleaveLower(Vec64(a.raw), Vec64(b.raw)).raw); +} + +// Additional overload for the optional Simd<> tag. +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. +namespace detail { +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(InterleaveUpper, vzip2, _, 2) + +#if HWY_ARCH_ARM_A64 +// N=1 makes no sense (in that case, there would be no upper/lower). +HWY_NEON_DEF_FUNCTION_FULL_UIF_64(InterleaveUpper, vzip2, _, 2) +#else +// Emulated version for Armv7. 
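+// [Editorial sketch, not part of the upstream header] For full vectors
+// a = {a0,a1,a2,a3} and b = {b0,b1,b2,b3} (lane 0 least-significant),
+// InterleaveLower returns {a0,b0,a1,b1} and InterleaveUpper {a2,b2,a3,b3}.
+// Armv7 NEON has no 64-bit zip, hence the CombineShiftRightBytes-based
+// emulation below. Hedged usage sketch of the public ops, assuming
+// hwy/highway.h and HWY_NAMESPACE:
+#if 0  // illustration only
+void InterleaveSketch() {
+  const Full128<uint32_t> d;
+  const auto a = Iota(d, 0);                 // {0,1,2,3}
+  const auto b = Iota(d, 4);                 // {4,5,6,7}
+  const auto lo = InterleaveLower(d, a, b);  // {0,4,1,5}
+  const auto hi = InterleaveUpper(d, a, b);  // {2,6,3,7}
+  (void)lo; (void)hi;
+}
+#endif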
+template +HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { + const DFromV d; + return CombineShiftRightBytes<8>(d, Shuffle01(b), a); +} +#endif +} // namespace detail + +// Full register +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half d2; + const VFromD a2(UpperHalf(d2, a).raw); + const VFromD b2(UpperHalf(d2, b).raw); + return InterleaveLower(d, a2, b2); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ------------------------------ Per4LaneBlockShuffle +namespace detail { + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + +#ifdef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#undef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#else +#define HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#endif + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t /*x3*/, + const uint32_t /*x2*/, + const uint32_t x1, + const uint32_t x0) { + typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(8))); + const GccU32RawVectType raw = {x0, x1}; + return ResizeBitCast(d, Vec64(reinterpret_cast(raw))); +} + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(16))); + const GccU32RawVectType raw = {x0, x1, x2, x3}; + return ResizeBitCast(d, Vec128(reinterpret_cast(raw))); +} +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG + +template , 4)> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x88> /*idx_3210_tag*/, + hwy::SizeTag /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + const auto evens = BitCast(dw, ConcatEven(d, v, v)); + return BitCast(d, InterleaveLower(dw, evens, evens)); +} + +template , 4)> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xDD> /*idx_3210_tag*/, + hwy::SizeTag /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + const auto odds = BitCast(dw, ConcatOdd(d, v, v)); + return BitCast(d, InterleaveLower(dw, odds, odds)); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xFA> /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag<8> /*vect_size_tag*/, V v) { + const DFromV d; + return InterleaveUpper(d, v, v); +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + using TU = UnsignedFromSize; + const Repartition du; + return BitCast(d, BitCast(du, v) << Set( + du, static_cast(amt * sizeof(TFromV) * 8))); +} + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Repartition du8; + const auto idx = + Iota(du8, static_cast(size_t{0} - amt * 
sizeof(TFromV))); + return BitCast(d, TableLookupBytesOr0(BitCast(du8, v), idx)); +} + +} // namespace detail + +template +HWY_API VFromD SlideUpLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + case 8: + return ShiftLeftLanes<8>(d, v); + case 9: + return ShiftLeftLanes<9>(d, v); + case 10: + return ShiftLeftLanes<10>(d, v); + case 11: + return ShiftLeftLanes<11>(d, v); + case 12: + return ShiftLeftLanes<12>(d, v); + case 13: + return ShiftLeftLanes<13>(d, v); + case 14: + return ShiftLeftLanes<14>(d, v); + case 15: + return ShiftLeftLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + using TU = UnsignedFromSize; + const Repartition du; + return BitCast(d, + BitCast(du, v) << Set( + du, static_cast(TU{0} - amt * sizeof(TFromV) * 8))); +} + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition di8; + auto idx = Iota(di8, static_cast(amt * sizeof(TFromV))); + idx = Or(idx, VecFromMask(di8, idx > Set(di8, int8_t{15}))); + return BitCast(d, TableLookupBytesOr0(BitCast(di8, v), idx)); +} + +} // namespace detail + +template +HWY_API VFromD SlideDownLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} 
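+// [Editorial usage sketch, not part of the upstream header] Unlike
+// ShiftLeftLanes/ShiftRightLanes, the slide amount here is a runtime value;
+// vacated lanes are zero-filled. Assumes hwy/highway.h and HWY_NAMESPACE:
+#if 0  // illustration only
+void SlideSketch(size_t amt) {  // e.g. amt = 1
+  const Full128<uint32_t> d;
+  const auto v = Iota(d, 1);                  // {1,2,3,4}
+  const auto up = SlideUpLanes(d, v, amt);    // amt=1: {0,1,2,3}
+  const auto dn = SlideDownLanes(d, v, amt);  // amt=1: {2,3,4,0}
+  (void)up; (void)dn;
+}
+#endif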
+ +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + case 8: + return ShiftRightLanes<8>(d, v); + case 9: + return ShiftRightLanes<9>(d, v); + case 10: + return ShiftRightLanes<10>(d, v); + case 11: + return ShiftRightLanes<11>(d, v); + case 12: + return ShiftRightLanes<12>(d, v); + case 13: + return ShiftRightLanes<13>(d, v); + case 14: + return ShiftRightLanes<14>(d, v); + case 15: + return ShiftRightLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +// ------------------------------- WidenHighMulAdd + +#ifdef HWY_NATIVE_WIDEN_HIGH_MUL_ADD +#undef HWY_NATIVE_WIDEN_HIGH_MUL_ADD +#else +#define HWY_NATIVE_WIDEN_HIGH_MUL_ADD +#endif + +namespace detail { + +template, + HWY_IF_LANES_GT_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_u32(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_u32(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_LE_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulResult = Vec128(vmull_u32(mul.raw, x.raw)); + return UpperHalf(d, mulResult) + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_s32(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_s32(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_LE_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulResult = Vec128(vmull_s32(mul.raw, x.raw)); + return UpperHalf(d, mulResult) + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_s16(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + 
vmlal_s16(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s16(mul.raw, x.raw)); + Vec64 hi = UpperHalf(d, widen); + return hi + add; +} + +template, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s16(mul.raw, x.raw)); + Vec32 hi = UpperHalf(d, Vec64(vget_high_s32(widen.raw))); + return hi + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_u16(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_u16(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u16(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, widen); + return hi + add; +} + +template> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u16(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, Vec64(vget_high_u32(widen.raw))); + return hi + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_u8(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_u8(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u8(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, widen); + return hi + add; +} + +template), class DN = RepartitionToNarrow, + HWY_IF_LANES_LE_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u8(mul.raw, x.raw)); + const Twice d16F; + VFromD hi = UpperHalf(d, VFromD(vget_high_u16(widen.raw))); + return hi + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_s8(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_s8(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s8(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, widen); + return hi + add; +} + +template, + HWY_IF_LANES_LE_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s8(mul.raw, x.raw)); + const Twice d16F; + VFromD hi = UpperHalf(d, VFromD(vget_high_s16(widen.raw))); + return hi + add; +} + +#if 0 +#if HWY_HAVE_FLOAT16 +template> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vfmlalq_high_f16(add.raw, mul.raw, x.raw)); +} + +template> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec64(vfmlal_high_f16(add.raw, mul.raw, x.raw)); +} + +template> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + return MulAdd(add, PromoteUpperTo(d, mul), PromoteUpperTo(d, x)); +} +#endif +#endif + +} // namespace detail + +// 
------------------------------- WidenMulAdd + +#ifdef HWY_NATIVE_WIDEN_MUL_ADD +#undef HWY_NATIVE_WIDEN_MUL_ADD +#else +#define HWY_NATIVE_WIDEN_MUL_ADD +#endif + +namespace detail { + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec128(vmlal_u8(add.raw, mul.raw, x.raw)); +} + +template >, D>> +HWY_API VFromD WidenMulAdd(D d, VFromD mul, VFromD x, + VFromD add) { + return MulAdd(add, PromoteTo(d, mul), PromoteTo(d, x)); +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vmlal_s8(add.raw, mul.raw, x.raw)); +} + +template >, D>> +HWY_API VFromD WidenMulAdd(D d, VFromD mul, VFromD x, + VFromD add) { + return MulAdd(add, PromoteTo(d, mul), PromoteTo(d, x)); +} + +template>, D>, + HWY_IF_LANES_GT_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec128(vmlal_s16(add.raw, mul.raw, x.raw)); +} + +template>, D>, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_s16(mul.raw, x.raw)); + const VFromD mul10 = LowerHalf(mulRs); + return add + mul10; +} + +template>, D>, + HWY_IF_LANES_D(D, 1)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec64 mulRs = LowerHalf(Vec128(vmull_s16(mul.raw, x.raw))); + const Vec32 mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec128(vmlal_u16(add.raw, mul.raw, x.raw)); +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_u16(mul.raw, x.raw)); + const Vec64 mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec64 mulRs = + LowerHalf(Vec128(vmull_u16(mul.raw, x.raw))); + const Vec32 mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vmlal_s32(add.raw, mul.raw, x.raw)); +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_s32(mul.raw, x.raw)); + const VFromD mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vmlal_u32(add.raw, mul.raw, x.raw)); +} + +template>, D>, + HWY_IF_LANES_D(DN, 1)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_u32(mul.raw, x.raw)); + const VFromD mul10(LowerHalf(mulRs)); + return add + mul10; +} + +#if 0 +#if HWY_HAVE_FLOAT16 +template, + HWY_IF_LANES_D(D, 4)> +HWY_API VFromD WidenLowMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vfmlalq_low_f16(add.raw, mul.raw, x.raw)); +} + +template, + HWY_IF_LANES_D(DN, 4)> +HWY_API VFromD WidenLowMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec64(vfmlal_low_f16(add.raw, mul.raw, x.raw)); +} + +template> +HWY_API VFromD WidenLowMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + return MulAdd(add, PromoteLowerTo(d, mul), PromoteLowerTo(d, x)); +} +#endif +#endif + +} // namespace detail + +// ------------------------------ WidenMulAccumulate + +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#undef 
HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#else +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#endif + +template), class DN = RepartitionToNarrow> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = detail::WidenHighMulAdd(d, mul, x, high); + return detail::WidenMulAdd(d, LowerHalf(mul), LowerHalf(x), low); +} + +#if 0 +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#else +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#endif + +#if HWY_HAVE_FLOAT16 + +template> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = detail::WidenHighMulAdd(d, mul, x, high); + return detail::WidenLowMulAdd(d, mul, x, low); +} + +#endif +#endif + +// ------------------------------ SatWidenMulAccumFixedPoint + +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#endif + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD(vqdmlal_s16(sum.raw, a.raw, b.raw)); +} + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Full128> di32_full; + const Rebind di16_full64; + return ResizeBitCast( + di32, SatWidenMulAccumFixedPoint(di32_full, ResizeBitCast(di16_full64, a), + ResizeBitCast(di16_full64, b), + ResizeBitCast(di32_full, sum))); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +#if HWY_NEON_HAVE_F32_TO_BF16C + +#ifdef HWY_NATIVE_MUL_EVEN_BF16 +#undef HWY_NATIVE_MUL_EVEN_BF16 +#else +#define HWY_NATIVE_MUL_EVEN_BF16 +#endif + +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + +namespace detail { +#if HWY_NEON_HAVE_BFLOAT16 +// If HWY_NEON_HAVE_BFLOAT16 is true, detail::Vec128::type is +// bfloat16x4_t or bfloat16x8_t. +static HWY_INLINE bfloat16x4_t BitCastToRawNeonBF16(bfloat16x4_t raw) { + return raw; +} +static HWY_INLINE bfloat16x8_t BitCastToRawNeonBF16(bfloat16x8_t raw) { + return raw; +} +#else +// If HWY_NEON_HAVE_F32_TO_BF16C && !HWY_NEON_HAVE_BFLOAT16 is true, +// detail::Vec128::type is uint16x4_t or uint16x8_t vector to +// work around compiler bugs that are there with GCC 13 or earlier or Clang 16 +// or earlier on AArch64. 
+ +// The uint16x4_t or uint16x8_t vector neets to be bitcasted to a bfloat16x4_t +// or a bfloat16x8_t vector for the vbfdot_f32 and vbfdotq_f32 intrinsics if +// HWY_NEON_HAVE_F32_TO_BF16C && !HWY_NEON_HAVE_BFLOAT16 is true +static HWY_INLINE bfloat16x4_t BitCastToRawNeonBF16(uint16x4_t raw) { + return vreinterpret_bf16_u16(raw); +} +static HWY_INLINE bfloat16x8_t BitCastToRawNeonBF16(uint16x8_t raw) { + return vreinterpretq_bf16_u16(raw); +} +#endif +} // namespace detail + +template +HWY_API Vec128 MulEvenAdd(D /*d32*/, Vec128 a, + Vec128 b, const Vec128 c) { + return Vec128(vbfmlalbq_f32(c.raw, detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +template +HWY_API Vec128 MulOddAdd(D /*d32*/, Vec128 a, + Vec128 b, const Vec128 c) { + return Vec128(vbfmlaltq_f32(c.raw, detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +template +HWY_API Vec128 ReorderWidenMulAccumulate(D /*d32*/, Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& /*sum1*/) { + return Vec128(vbfdotq_f32(sum0.raw, + detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +// There is no non-q version of these instructions. +template +HWY_API VFromD MulEvenAdd(D d32, VFromD> a, + VFromD> b, + const VFromD c) { + const Full128 d32f; + const Full128 d16f; + return ResizeBitCast( + d32, MulEvenAdd(d32f, ResizeBitCast(d16f, a), ResizeBitCast(d16f, b), + ResizeBitCast(d32f, c))); +} + +template +HWY_API VFromD MulOddAdd(D d32, VFromD> a, + VFromD> b, + const VFromD c) { + const Full128 d32f; + const Full128 d16f; + return ResizeBitCast( + d32, MulOddAdd(d32f, ResizeBitCast(d16f, a), ResizeBitCast(d16f, b), + ResizeBitCast(d32f, c))); +} + +template +HWY_API VFromD ReorderWidenMulAccumulate( + D /*d32*/, VFromD> a, + VFromD> b, const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD(vbfdot_f32(sum0.raw, detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +#endif // HWY_NEON_HAVE_F32_TO_BF16C + +template +HWY_API Vec128 ReorderWidenMulAccumulate(D /*d32*/, Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { +#if HWY_ARCH_ARM_A64 + sum1 = Vec128(vmlal_high_s16(sum1.raw, a.raw, b.raw)); +#else + const Full64 dh; + sum1 = Vec128( + vmlal_s16(sum1.raw, UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); +#endif + return Vec128( + vmlal_s16(sum0.raw, LowerHalf(a).raw, LowerHalf(b).raw)); +} + +template +HWY_API Vec64 ReorderWidenMulAccumulate(D d32, Vec64 a, + Vec64 b, + const Vec64 sum0, + Vec64& sum1) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. 
+ const Vec128 mul_3210(vmull_s16(a.raw, b.raw)); + const Vec64 mul_32 = UpperHalf(d32, mul_3210); + sum1 += mul_32; + return sum0 + LowerHalf(mul_3210); +} + +template +HWY_API Vec32 ReorderWidenMulAccumulate(D d32, Vec32 a, + Vec32 b, + const Vec32 sum0, + Vec32& sum1) { + const Vec128 mul_xx10(vmull_s16(a.raw, b.raw)); + const Vec64 mul_10(LowerHalf(mul_xx10)); + const Vec32 mul0 = LowerHalf(d32, mul_10); + const Vec32 mul1 = UpperHalf(d32, mul_10); + sum1 += mul1; + return sum0 + mul0; +} + +template +HWY_API Vec128 ReorderWidenMulAccumulate(D /*d32*/, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { +#if HWY_ARCH_ARM_A64 + sum1 = Vec128(vmlal_high_u16(sum1.raw, a.raw, b.raw)); +#else + const Full64 dh; + sum1 = Vec128( + vmlal_u16(sum1.raw, UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); +#endif + return Vec128( + vmlal_u16(sum0.raw, LowerHalf(a).raw, LowerHalf(b).raw)); +} + +template +HWY_API Vec64 ReorderWidenMulAccumulate(D d32, Vec64 a, + Vec64 b, + const Vec64 sum0, + Vec64& sum1) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. + const Vec128 mul_3210(vmull_u16(a.raw, b.raw)); + const Vec64 mul_32 = UpperHalf(d32, mul_3210); + sum1 += mul_32; + return sum0 + LowerHalf(mul_3210); +} + +template +HWY_API Vec32 ReorderWidenMulAccumulate(D du32, Vec32 a, + Vec32 b, + const Vec32 sum0, + Vec32& sum1) { + const Vec128 mul_xx10(vmull_u16(a.raw, b.raw)); + const Vec64 mul_10(LowerHalf(mul_xx10)); + const Vec32 mul0 = LowerHalf(du32, mul_10); + const Vec32 mul1 = UpperHalf(du32, mul_10); + sum1 += mul1; + return sum0 + mul0; +} + +// ------------------------------ Combine partial (InterleaveLower) +// < 64bit input, <= 64 bit result +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { + // First double N (only lower halves will be used). + const VFromD hi2(hi.raw); + const VFromD lo2(lo.raw); + // Repartition to two unsigned lanes (each the size of the valid input). + const Simd, 2, 0> du; + return BitCast(d, InterleaveLower(BitCast(du, lo2), BitCast(du, hi2))); +} + +// ------------------------------ RearrangeToOddPlusEven (Combine) + +template +HWY_API Vec128 RearrangeToOddPlusEven(Vec128 sum0, + Vec128 sum1) { +#if HWY_NEON_HAVE_BFLOAT16 + (void)sum1; // unused by bf16 ReorderWidenMulAccumulate + return sum0; +#else + return Add(sum0, sum1); +#endif +} + +HWY_API Vec128 RearrangeToOddPlusEven(Vec128 sum0, + Vec128 sum1) { +// vmlal_s16 multiplied the lower half into sum0 and upper into sum1. +#if HWY_ARCH_ARM_A64 // pairwise sum is available and what we want + return Vec128(vpaddq_s32(sum0.raw, sum1.raw)); +#else + const Full128 d; + const Half d64; + const Vec64 hi( + vpadd_s32(LowerHalf(d64, sum1).raw, UpperHalf(d64, sum1).raw)); + const Vec64 lo( + vpadd_s32(LowerHalf(d64, sum0).raw, UpperHalf(d64, sum0).raw)); + return Combine(Full128(), hi, lo); +#endif +} + +HWY_API Vec64 RearrangeToOddPlusEven(Vec64 sum0, + Vec64 sum1) { + // vmlal_s16 multiplied the lower half into sum0 and upper into sum1. + return Vec64(vpadd_s32(sum0.raw, sum1.raw)); +} + +HWY_API Vec32 RearrangeToOddPlusEven(Vec32 sum0, + Vec32 sum1) { + // Only one widened sum per register, so add them for sum of odd and even. + return sum0 + sum1; +} + +HWY_API Vec128 RearrangeToOddPlusEven(Vec128 sum0, + Vec128 sum1) { +// vmlal_s16 multiplied the lower half into sum0 and upper into sum1. 
+#if HWY_ARCH_ARM_A64 // pairwise sum is available and what we want + return Vec128(vpaddq_u32(sum0.raw, sum1.raw)); +#else + const Full128 d; + const Half d64; + const Vec64 hi( + vpadd_u32(LowerHalf(d64, sum1).raw, UpperHalf(d64, sum1).raw)); + const Vec64 lo( + vpadd_u32(LowerHalf(d64, sum0).raw, UpperHalf(d64, sum0).raw)); + return Combine(Full128(), hi, lo); +#endif +} + +HWY_API Vec64 RearrangeToOddPlusEven(Vec64 sum0, + Vec64 sum1) { + // vmlal_u16 multiplied the lower half into sum0 and upper into sum1. + return Vec64(vpadd_u32(sum0.raw, sum1.raw)); +} + +HWY_API Vec32 RearrangeToOddPlusEven(Vec32 sum0, + Vec32 sum1) { + // Only one widened sum per register, so add them for sum of odd and even. + return sum0 + sum1; +} + +// ------------------------------ SumOfMulQuadAccumulate + +#if HWY_TARGET == HWY_NEON_BF16 + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD(vdot_s32(sum.raw, a.raw, b.raw)); +} + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD(vdotq_s32(sum.raw, a.raw, b.raw)); +} + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD(vdot_u32(sum.raw, a.raw, b.raw)); +} + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD(vdotq_u32(sum.raw, a.raw, b.raw)); +} + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 di32, VFromD> a_u, + VFromD> b_i, VFromD sum) { + // TODO: use vusdot[q]_s32 on NEON targets that require support for NEON I8MM + + const RebindToUnsigned du32; + const Repartition du8; + + const auto b_u = BitCast(du8, b_i); + const auto result_sum0 = + SumOfMulQuadAccumulate(du32, a_u, b_u, BitCast(du32, sum)); + const auto result_sum1 = ShiftLeft<8>( + SumOfMulQuadAccumulate(du32, a_u, ShiftRight<7>(b_u), Zero(du32))); + + return BitCast(di32, Sub(result_sum0, result_sum1)); +} + +#endif // HWY_TARGET == HWY_NEON_BF16 + +// ------------------------------ WidenMulPairwiseAdd + +#if HWY_NEON_HAVE_F32_TO_BF16C + +template +HWY_API Vec128 WidenMulPairwiseAdd(DF df, Vec128 a, + Vec128 b) { + return Vec128(vbfdotq_f32(Zero(df).raw, + detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +template +HWY_API VFromD WidenMulPairwiseAdd(DF df, + VFromD> a, + VFromD> b) { + return VFromD(vbfdot_f32(Zero(df).raw, + detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +#else +template +HWY_API VFromD WidenMulPairwiseAdd(DF df, + VFromD> a, + VFromD> b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} +#endif // HWY_NEON_HAVE_F32_TO_BF16C + +template +HWY_API Vec128 WidenMulPairwiseAdd(D /*d32*/, Vec128 a, + Vec128 b) { + Vec128 sum1; +#if HWY_ARCH_ARM_A64 + sum1 = Vec128(vmull_high_s16(a.raw, b.raw)); +#else + const Full64 dh; + sum1 = Vec128(vmull_s16(UpperHalf(dh, a).raw, UpperHalf(dh, 
b).raw)); +#endif + Vec128 sum0 = + Vec128(vmull_s16(LowerHalf(a).raw, LowerHalf(b).raw)); + return RearrangeToOddPlusEven(sum0, sum1); +} + +template +HWY_API Vec64 WidenMulPairwiseAdd(D d32, Vec64 a, + Vec64 b) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. + const Vec128 mul_3210(vmull_s16(a.raw, b.raw)); + const Vec64 mul0 = LowerHalf(mul_3210); + const Vec64 mul1 = UpperHalf(d32, mul_3210); + return RearrangeToOddPlusEven(mul0, mul1); +} + +template +HWY_API Vec32 WidenMulPairwiseAdd(D d32, Vec32 a, + Vec32 b) { + const Vec128 mul_xx10(vmull_s16(a.raw, b.raw)); + const Vec64 mul_10(LowerHalf(mul_xx10)); + const Vec32 mul0 = LowerHalf(d32, mul_10); + const Vec32 mul1 = UpperHalf(d32, mul_10); + return RearrangeToOddPlusEven(mul0, mul1); +} + +template +HWY_API Vec128 WidenMulPairwiseAdd(D /*d32*/, Vec128 a, + Vec128 b) { + Vec128 sum1; +#if HWY_ARCH_ARM_A64 + sum1 = Vec128(vmull_high_u16(a.raw, b.raw)); +#else + const Full64 dh; + sum1 = + Vec128(vmull_u16(UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); +#endif + Vec128 sum0 = + Vec128(vmull_u16(LowerHalf(a).raw, LowerHalf(b).raw)); + return RearrangeToOddPlusEven(sum0, sum1); +} + +template +HWY_API Vec64 WidenMulPairwiseAdd(D d32, Vec64 a, + Vec64 b) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. + const Vec128 mul_3210(vmull_u16(a.raw, b.raw)); + const Vec64 mul0 = LowerHalf(mul_3210); + const Vec64 mul1 = UpperHalf(d32, mul_3210); + return RearrangeToOddPlusEven(mul0, mul1); +} + +template +HWY_API Vec32 WidenMulPairwiseAdd(D d32, Vec32 a, + Vec32 b) { + const Vec128 mul_xx10(vmull_u16(a.raw, b.raw)); + const Vec64 mul_10(LowerHalf(mul_xx10)); + const Vec32 mul0 = LowerHalf(d32, mul_10); + const Vec32 mul1 = UpperHalf(d32, mul_10); + return RearrangeToOddPlusEven(mul0, mul1); +} + +// ------------------------------ ZeroExtendVector (Combine) + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + return Combine(d, Zero(Half()), lo); +} + +// ------------------------------ ConcatLowerLower + +// 64 or 128-bit input: just interleave +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + // Treat half-width input as a single lane and interleave them. + const Repartition, decltype(d)> du; + return BitCast(d, InterleaveLower(BitCast(du, lo), BitCast(du, hi))); +} + +namespace detail { +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(InterleaveEven, vtrn1, _, 2) +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(InterleaveOdd, vtrn2, _, 2) +#else + +// vtrn returns a struct with even and odd result. +#define HWY_NEON_BUILD_TPL_HWY_TRN +#define HWY_NEON_BUILD_RET_HWY_TRN(type, size) type##x##size##x2_t +// Pass raw args so we can accept uint16x2 args, for which there is no +// corresponding uint16x2x2 return type. +#define HWY_NEON_BUILD_PARAM_HWY_TRN(TYPE, size) \ + Raw128::type a, Raw128::type b +#define HWY_NEON_BUILD_ARG_HWY_TRN a, b + +// Cannot use UINT8 etc. type macros because the x2_t tuples are only defined +// for full and half vectors. 
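+// [Editorial sketch, not part of the upstream header] What the vtrn/vtrnq
+// intrinsics used below return: a two-element struct whose val[0] holds the
+// even-lane transpose and val[1] the odd-lane transpose. For
+// a = {a0,a1,a2,a3} and b = {b0,b1,b2,b3}:
+//   val[0] = {a0,b0,a2,b2}, val[1] = {a1,b1,a3,b3}.
+// Raw NEON illustration, assuming <arm_neon.h>; the function name is made up.
+#if 0  // illustration only
+static inline uint32x4_t EvenLanesSketch(uint32x4_t a, uint32x4_t b) {
+  const uint32x4x2_t both = vtrnq_u32(a, b);  // both results in one struct
+  return both.val[0];                         // {a0,b0,a2,b2}
+}
+#endif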
+HWY_NEON_DEF_FUNCTION(uint8, 16, InterleaveEvenOdd, vtrnq, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint8, 8, InterleaveEvenOdd, vtrn, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 8, InterleaveEvenOdd, vtrnq, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 4, InterleaveEvenOdd, vtrn, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 4, InterleaveEvenOdd, vtrnq, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 2, InterleaveEvenOdd, vtrn, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 16, InterleaveEvenOdd, vtrnq, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 8, InterleaveEvenOdd, vtrn, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 8, InterleaveEvenOdd, vtrnq, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 4, InterleaveEvenOdd, vtrn, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 4, InterleaveEvenOdd, vtrnq, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 2, InterleaveEvenOdd, vtrn, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 4, InterleaveEvenOdd, vtrnq, _, f32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 2, InterleaveEvenOdd, vtrn, _, f32, HWY_TRN) + +#undef HWY_NEON_BUILD_TPL_HWY_TRN +#undef HWY_NEON_BUILD_RET_HWY_TRN +#undef HWY_NEON_BUILD_PARAM_HWY_TRN +#undef HWY_NEON_BUILD_ARG_HWY_TRN + +#endif // HWY_ARCH_ARM_A64 +} // namespace detail + +// <= 32-bit input/output +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + // Treat half-width input as two lanes and take every second one. + const Repartition, decltype(d)> du; +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveEven(BitCast(du, lo), BitCast(du, hi))); +#else + using VU = VFromD; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[0])); +#endif +} + +// ------------------------------ ConcatUpperUpper + +// 64 or 128-bit input: just interleave +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + // Treat half-width input as a single lane and interleave them. + const Repartition, decltype(d)> du; + return BitCast(d, InterleaveUpper(du, BitCast(du, lo), BitCast(du, hi))); +} + +// <= 32-bit input/output +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + // Treat half-width input as two lanes and take every second one. + const Repartition, decltype(d)> du; +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveOdd(BitCast(du, lo), BitCast(du, hi))); +#else + using VU = VFromD; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[1])); +#endif +} + +// ------------------------------ ConcatLowerUpper (ShiftLeftBytes) + +// 64 or 128-bit input: extract from concatenated +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + return CombineShiftRightBytes(d, hi, lo); +} + +// <= 32-bit input/output +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + const Repartition d8; + const Full64 d8x8; + const Full64> d64; + using V8x8 = VFromD; + const V8x8 hi8x8(BitCast(d8, hi).raw); + // Move into most-significant bytes + const V8x8 lo8x8 = ShiftLeftBytes<8 - kSize>(V8x8(BitCast(d8, lo).raw)); + const V8x8 r = CombineShiftRightBytes<8 - kSize / 2>(d8x8, hi8x8, lo8x8); + // Back to original lane type, then shrink N. + return VFromD(BitCast(d64, r).raw); +} + +// ------------------------------ ConcatUpperLower + +// Works for all N. 
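+// [Editorial sketch, not part of the upstream header] How the Concat* family
+// composes halves, with hi = {h0,h1,h2,h3}, lo = {l0,l1,l2,l3} and lane 0
+// least-significant:
+//   ConcatLowerLower(d, hi, lo) -> {l0,l1,h0,h1}
+//   ConcatUpperUpper(d, hi, lo) -> {l2,l3,h2,h3}
+//   ConcatLowerUpper(d, hi, lo) -> {l2,l3,h0,h1}
+//   ConcatUpperLower(d, hi, lo) -> {l0,l1,h2,h3}
+// Hedged usage sketch, assuming hwy/highway.h and HWY_NAMESPACE:
+#if 0  // illustration only
+void ConcatSketch() {
+  const Full128<uint32_t> d;
+  const auto hi = Iota(d, 4);                  // {4,5,6,7}
+  const auto lo = Iota(d, 0);                  // {0,1,2,3}
+  const auto r = ConcatUpperLower(d, hi, lo);  // {0,1,6,7}
+  (void)r;
+}
+#endif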
+template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatOdd (InterleaveUpper) + +namespace detail { +// There is no vuzpq_u64. +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(ConcatEven, vuzp1, _, 2) +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(ConcatOdd, vuzp2, _, 2) + +#if !HWY_HAVE_FLOAT16 +template +HWY_INLINE Vec128 ConcatEven(Vec128 hi, + Vec128 lo) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ConcatEven(BitCast(du, hi), BitCast(du, lo))); +} +template +HWY_INLINE Vec128 ConcatOdd(Vec128 hi, + Vec128 lo) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ConcatOdd(BitCast(du, hi), BitCast(du, lo))); +} +#endif // !HWY_HAVE_FLOAT16 +} // namespace detail + +// Full/half vector +template +HWY_API VFromD ConcatOdd(D /* tag */, VFromD hi, VFromD lo) { + return detail::ConcatOdd(lo, hi); +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatOdd(D d, Vec32 hi, Vec32 lo) { + const Twice d2; + const Repartition dw2; + const VFromD hi2(hi.raw); + const VFromD lo2(lo.raw); + const VFromD Hx1Lx1 = BitCast(dw2, ConcatOdd(d2, hi2, lo2)); + // Compact into two pairs of u8, skipping the invalid x lanes. Could also use + // vcopy_lane_u16, but that's A64-only. + return Vec32(BitCast(d2, ConcatEven(dw2, Hx1Lx1, Hx1Lx1)).raw); +} + +// Any type x2 +template > +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// Full/half vector +template +HWY_API VFromD ConcatEven(D /* tag */, VFromD hi, VFromD lo) { + return detail::ConcatEven(lo, hi); +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatEven(D d, Vec32 hi, Vec32 lo) { + const Twice d2; + const Repartition dw2; + const VFromD hi2(hi.raw); + const VFromD lo2(lo.raw); + const VFromD Hx0Lx0 = BitCast(dw2, ConcatEven(d2, hi2, lo2)); + // Compact into two pairs of u8, skipping the invalid x lanes. Could also use + // vcopy_lane_u16, but that's A64-only. + return Vec32(BitCast(d2, ConcatEven(dw2, Hx0Lx0, Hx0Lx0)).raw); +} + +// Any type x2 +template > +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveEven(v, v); +#else + return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[0]); +#endif +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveOdd(v, v); +#else + return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[1]); +#endif +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) static constexpr uint8_t kBytes[16] = { + ((0 / sizeof(T)) & 1) ? 0 : 0xFF, ((1 / sizeof(T)) & 1) ? 0 : 0xFF, + ((2 / sizeof(T)) & 1) ? 0 : 0xFF, ((3 / sizeof(T)) & 1) ? 0 : 0xFF, + ((4 / sizeof(T)) & 1) ? 0 : 0xFF, ((5 / sizeof(T)) & 1) ? 0 : 0xFF, + ((6 / sizeof(T)) & 1) ? 0 : 0xFF, ((7 / sizeof(T)) & 1) ? 0 : 0xFF, + ((8 / sizeof(T)) & 1) ? 
0 : 0xFF, ((9 / sizeof(T)) & 1) ? 0 : 0xFF, + ((10 / sizeof(T)) & 1) ? 0 : 0xFF, ((11 / sizeof(T)) & 1) ? 0 : 0xFF, + ((12 / sizeof(T)) & 1) ? 0 : 0xFF, ((13 / sizeof(T)) & 1) ? 0 : 0xFF, + ((14 / sizeof(T)) & 1) ? 0 : 0xFF, ((15 / sizeof(T)) & 1) ? 0 : 0xFF, + }; + const auto vec = BitCast(d, Load(d8, kBytes)); + return IfThenElse(MaskFromVec(vec), b, a); +} + +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveEven(a, b); +#else + return VFromD(detail::InterleaveEvenOdd(a.raw, b.raw).val[0]); +#endif +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveOdd(a, b); +#else + return VFromD(detail::InterleaveEvenOdd(a.raw, b.raw).val[1]); +#endif +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template > +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template > +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ ReverseBlocks +// Single block: no change +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; +} + +// ------------------------------ ReorderDemote2To (OddEven) + +#if HWY_NEON_HAVE_F32_TO_BF16C +template +HWY_API VFromD ReorderDemote2To(D dbf16, VFromD> a, + VFromD> b) { + const Half dh_bf16; + return Combine(dbf16, DemoteTo(dh_bf16, b), DemoteTo(dh_bf16, a)); +} +#endif // HWY_NEON_HAVE_F32_TO_BF16C + +template +HWY_API Vec128 ReorderDemote2To(D d32, Vec128 a, + Vec128 b) { + const Vec64 a32(vqmovn_s64(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d32; + return Vec128(vqmovn_high_s64(a32.raw, b.raw)); +#else + const Vec64 b32(vqmovn_s64(b.raw)); + return Combine(d32, b32, a32); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d32, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d32, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d32, Vec128 a, + Vec128 b) { + const Vec64 a32(vqmovun_s64(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d32; + return Vec128(vqmovun_high_s64(a32.raw, b.raw)); +#else + const Vec64 b32(vqmovun_s64(b.raw)); + return Combine(d32, b32, a32); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d32, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d32, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d32, Vec128 a, + Vec128 b) { + const Vec64 a32(vqmovn_u64(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d32; + return Vec128(vqmovn_high_u64(a32.raw, b.raw)); +#else + const Vec64 b32(vqmovn_u64(b.raw)); + return Combine(d32, b32, a32); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d32, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d32, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d16, Vec128 a, + Vec128 b) { + const Vec64 a16(vqmovn_s32(a.raw)); +#if HWY_ARCH_ARM_A64 + 
(void)d16; + return Vec128(vqmovn_high_s32(a16.raw, b.raw)); +#else + const Vec64 b16(vqmovn_s32(b.raw)); + return Combine(d16, b16, a16); +#endif +} + +template +HWY_API Vec64 ReorderDemote2To(D /*d16*/, Vec64 a, + Vec64 b) { + const Full128 d32; + const Vec128 ab = Combine(d32, b, a); + return Vec64(vqmovn_s32(ab.raw)); +} + +template +HWY_API Vec32 ReorderDemote2To(D /*d16*/, Vec32 a, + Vec32 b) { + const Full128 d32; + const Vec64 ab(vzip1_s32(a.raw, b.raw)); + return Vec32(vqmovn_s32(Combine(d32, ab, ab).raw)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d16, Vec128 a, + Vec128 b) { + const Vec64 a16(vqmovun_s32(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d16; + return Vec128(vqmovun_high_s32(a16.raw, b.raw)); +#else + const Vec64 b16(vqmovun_s32(b.raw)); + return Combine(d16, b16, a16); +#endif +} + +template +HWY_API Vec64 ReorderDemote2To(D /*d16*/, Vec64 a, + Vec64 b) { + const Full128 d32; + const Vec128 ab = Combine(d32, b, a); + return Vec64(vqmovun_s32(ab.raw)); +} + +template +HWY_API Vec32 ReorderDemote2To(D /*d16*/, Vec32 a, + Vec32 b) { + const Full128 d32; + const Vec64 ab(vzip1_s32(a.raw, b.raw)); + return Vec32(vqmovun_s32(Combine(d32, ab, ab).raw)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d16, Vec128 a, + Vec128 b) { + const Vec64 a16(vqmovn_u32(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d16; + return Vec128(vqmovn_high_u32(a16.raw, b.raw)); +#else + const Vec64 b16(vqmovn_u32(b.raw)); + return Combine(d16, b16, a16); +#endif +} + +template +HWY_API Vec64 ReorderDemote2To(D /*d16*/, Vec64 a, + Vec64 b) { + const Full128 d32; + const Vec128 ab = Combine(d32, b, a); + return Vec64(vqmovn_u32(ab.raw)); +} + +template +HWY_API Vec32 ReorderDemote2To(D /*d16*/, Vec32 a, + Vec32 b) { + const Full128 d32; + const Vec64 ab(vzip1_u32(a.raw, b.raw)); + return Vec32(vqmovn_u32(Combine(d32, ab, ab).raw)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d8, Vec128 a, + Vec128 b) { + const Vec64 a8(vqmovn_s16(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d8; + return Vec128(vqmovn_high_s16(a8.raw, b.raw)); +#else + const Vec64 b8(vqmovn_s16(b.raw)); + return Combine(d8, b8, a8); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d8, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d8, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d8, Vec128 a, + Vec128 b) { + const Vec64 a8(vqmovun_s16(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d8; + return Vec128(vqmovun_high_s16(a8.raw, b.raw)); +#else + const Vec64 b8(vqmovun_s16(b.raw)); + return Combine(d8, b8, a8); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d8, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d8, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d8, Vec128 a, + Vec128 b) { + const Vec64 a8(vqmovn_u16(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d8; + return Vec128(vqmovn_high_u16(a8.raw, b.raw)); +#else + const Vec64 b8(vqmovn_u16(b.raw)); + return Combine(d8, b8, a8); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d8, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d8, Combine(dt, b, a)); +} + +template ), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} + +#if HWY_NEON_HAVE_F32_TO_BF16C +template +HWY_API VFromD OrderedDemote2To(D dbf16, VFromD> a, + VFromD> b) { + return ReorderDemote2To(dbf16, a, b); +} +#endif // HWY_NEON_HAVE_F32_TO_BF16C + +// ================================================== CRYPTO + +// (aarch64 or 
Arm7) and (__ARM_FEATURE_AES or HWY_HAVE_RUNTIME_DISPATCH). +// Otherwise, rely on generic_ops-inl.h to emulate AESRound / CLMul*. +#if HWY_TARGET != HWY_NEON_WITHOUT_AES + +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + // NOTE: it is important that AESE and AESMC be consecutive instructions so + // they can be fused. AESE includes AddRoundKey, which is a different ordering + // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual + // round key (the compiler will hopefully optimize this for multiple rounds). + return Vec128(vaesmcq_u8(vaeseq_u8(state.raw, vdupq_n_u8(0)))) ^ + round_key; +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128(vaeseq_u8(state.raw, vdupq_n_u8(0))) ^ round_key; +} + +HWY_API Vec128 AESInvMixColumns(Vec128 state) { + return Vec128{vaesimcq_u8(state.raw)}; +} + +HWY_API Vec128 AESRoundInv(Vec128 state, + Vec128 round_key) { + // NOTE: it is important that AESD and AESIMC be consecutive instructions so + // they can be fused. AESD includes AddRoundKey, which is a different ordering + // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual + // round key (the compiler will hopefully optimize this for multiple rounds). + return Vec128(vaesimcq_u8(vaesdq_u8(state.raw, vdupq_n_u8(0)))) ^ + round_key; +} + +HWY_API Vec128 AESLastRoundInv(Vec128 state, + Vec128 round_key) { + return Vec128(vaesdq_u8(state.raw, vdupq_n_u8(0))) ^ round_key; +} + +HWY_API Vec128 CLMulLower(Vec128 a, Vec128 b) { + return Vec128((uint64x2_t)vmull_p64(GetLane(a), GetLane(b))); +} + +HWY_API Vec128 CLMulUpper(Vec128 a, Vec128 b) { + return Vec128( + (uint64x2_t)vmull_high_p64((poly64x2_t)a.raw, (poly64x2_t)b.raw)); +} + +#endif // HWY_TARGET != HWY_NEON_WITHOUT_AES + +// ================================================== MISC + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ Truncations + +template , typename TFrom, + HWY_IF_UNSIGNED(TFrom), HWY_IF_UNSIGNED(TTo), + hwy::EnableIf<(sizeof(TTo) < sizeof(TFrom))>* = nullptr> +HWY_API Vec128 TruncateTo(DTo /* tag */, Vec128 v) { + const Repartition> d; + return Vec128{BitCast(d, v).raw}; +} + +template +HWY_API Vec16 TruncateTo(D /* tag */, Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + const auto v4 = detail::ConcatEven(v3, v3); + return LowerHalf(LowerHalf(LowerHalf(v4))); +} + +template +HWY_API Vec32 TruncateTo(D /* tag */, Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + return LowerHalf(LowerHalf(v3)); +} + +template +HWY_API Vec64 TruncateTo(D /* tag */, Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + return LowerHalf(LowerHalf(v3)); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const 
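// The AES ops above are available whenever the target is not
// HWY_NEON_WITHOUT_AES (otherwise generic_ops-inl.h emulates them). A sketch
// of encrypting one block with AES-128 using the AES-NI-style round semantics
// Highway adopts; Aes128EncryptBlock is a hypothetical helper and the round
// keys are assumed to be already expanded (11 x 16 bytes).
#include <cstdint>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

void Aes128EncryptBlock(const uint8_t* HWY_RESTRICT round_keys,
                        const uint8_t* HWY_RESTRICT in,
                        uint8_t* HWY_RESTRICT out) {
  const hn::Full128<uint8_t> d;
  // Initial AddRoundKey, then 9 full rounds, then the final round without
  // MixColumns.
  auto state = hn::Xor(hn::LoadU(d, in), hn::LoadU(d, round_keys));
  for (int r = 1; r <= 9; ++r) {
    state = hn::AESRound(state, hn::LoadU(d, round_keys + 16 * r));
  }
  state = hn::AESLastRound(state, hn::LoadU(d, round_keys + 16 * 10));
  hn::StoreU(state, d, out);
}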
auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +// ------------------------------ MulEven (ConcatEven) + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + int8x16_t a_packed = ConcatEven(d, a, a).raw; + int8x16_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_s8(vget_low_s8(a_packed), vget_low_s8(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + uint8x16_t a_packed = ConcatEven(d, a, a).raw; + uint8x16_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_u8(vget_low_u8(a_packed), vget_low_u8(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + int16x8_t a_packed = ConcatEven(d, a, a).raw; + int16x8_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_s16(vget_low_s16(a_packed), vget_low_s16(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + uint16x8_t a_packed = ConcatEven(d, a, a).raw; + uint16x8_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_u16(vget_low_u16(a_packed), vget_low_u16(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + int32x4_t a_packed = ConcatEven(d, a, a).raw; + int32x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_s32(vget_low_s32(a_packed), vget_low_s32(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + uint32x4_t a_packed = ConcatEven(d, a, a).raw; + uint32x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_u32(vget_low_u32(a_packed), vget_low_u32(b_packed))); +} + +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + int8x8_t a_packed = ConcatEven(d, a, a).raw; + int8x8_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_s16(vmull_s8(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + uint8x8_t a_packed = ConcatEven(d, a, a).raw; + uint8x8_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_u16(vmull_u8(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + int16x4_t a_packed = ConcatEven(d, a, a).raw; + int16x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_s32(vmull_s16(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + uint16x4_t a_packed = ConcatEven(d, a, a).raw; + uint16x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_u32(vmull_u16(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + int32x2_t a_packed = ConcatEven(d, a, a).raw; + int32x2_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_s64(vmull_s32(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + uint32x2_t a_packed = ConcatEven(d, a, a).raw; + uint32x2_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_u64(vmull_u32(a_packed, b_packed))); +} + +template +HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { + T hi; + T lo = Mul128(GetLane(a), GetLane(b), &hi); + return Dup128VecFromValues(Full128(), lo, hi); +} + +// Multiplies odd lanes (1, 3 ..) 
and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + int8x16_t a_packed = ConcatOdd(d, a, a).raw; + int8x16_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_s8(vget_low_s8(a_packed), vget_low_s8(b_packed))); +} +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + uint8x16_t a_packed = ConcatOdd(d, a, a).raw; + uint8x16_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_u8(vget_low_u8(a_packed), vget_low_u8(b_packed))); +} +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + int16x8_t a_packed = ConcatOdd(d, a, a).raw; + int16x8_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_s16(vget_low_s16(a_packed), vget_low_s16(b_packed))); +} +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + uint16x8_t a_packed = ConcatOdd(d, a, a).raw; + uint16x8_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_u16(vget_low_u16(a_packed), vget_low_u16(b_packed))); +} +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + int32x4_t a_packed = ConcatOdd(d, a, a).raw; + int32x4_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_s32(vget_low_s32(a_packed), vget_low_s32(b_packed))); +} +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + uint32x4_t a_packed = ConcatOdd(d, a, a).raw; + uint32x4_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_u32(vget_low_u32(a_packed), vget_low_u32(b_packed))); +} + +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + int8x8_t a_packed = ConcatOdd(d, a, a).raw; + int8x8_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_s16(vmull_s8(a_packed, b_packed))); +} +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + uint8x8_t a_packed = ConcatOdd(d, a, a).raw; + uint8x8_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_u16(vmull_u8(a_packed, b_packed))); +} +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + int16x4_t a_packed = ConcatOdd(d, a, a).raw; + int16x4_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_s32(vmull_s16(a_packed, b_packed))); +} +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + uint16x4_t a_packed = ConcatOdd(d, a, a).raw; + uint16x4_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_u32(vmull_u16(a_packed, b_packed))); +} +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + int32x2_t a_packed = ConcatOdd(d, a, a).raw; + int32x2_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_s64(vmull_s32(a_packed, b_packed))); +} +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + uint32x2_t a_packed = ConcatOdd(d, a, a).raw; + uint32x2_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_u64(vmull_u32(a_packed, b_packed))); +} + +template +HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { + T hi; + T lo = Mul128(detail::GetLane<1>(a), detail::GetLane<1>(b), &hi); + return Dup128VecFromValues(Full128(), lo, hi); +} + +// ------------------------------ TableLookupBytes (Combine, LowerHalf) + +// Both full +template +HWY_API Vec128 TableLookupBytes(Vec128 bytes, Vec128 from) { + const DFromV d; + const Repartition d8; +#if HWY_ARCH_ARM_A64 + return BitCast(d, Vec128(vqtbl1q_u8(BitCast(d8, bytes).raw, + BitCast(d8, from).raw))); +#else + uint8x16_t table0 = BitCast(d8, bytes).raw; + uint8x8x2_t table; + table.val[0] = 
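// MulEven/MulOdd above wrap the widening vmull multiplies: they multiply the
// even (respectively odd) lanes and return the full double-width products.
// Caller-side sketch, assuming the public Highway API and a 128-bit target.
#include <cstdint>
#include <cstdio>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

int main() {
  const hn::FixedTag<uint32_t, 4> d32;
  const hn::Repartition<uint64_t, decltype(d32)> d64;  // two u64 lanes

  const uint32_t a_in[4] = {0xFFFFFFFFu, 2, 3, 4};
  const uint32_t b_in[4] = {0xFFFFFFFFu, 20, 5, 40};
  const auto a = hn::LoadU(d32, a_in);
  const auto b = hn::LoadU(d32, b_in);

  uint64_t even[2], odd[2];
  // Lanes 0 and 2: {0xFFFFFFFF * 0xFFFFFFFF, 3 * 5} without truncation.
  hn::StoreU(hn::MulEven(a, b), d64, even);
  // Lanes 1 and 3: {2 * 20, 4 * 40}.
  hn::StoreU(hn::MulOdd(a, b), d64, odd);

  printf("%llx %llu | %llu %llu\n", (unsigned long long)even[0],
         (unsigned long long)even[1], (unsigned long long)odd[0],
         (unsigned long long)odd[1]);
  return 0;
}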
vget_low_u8(table0); + table.val[1] = vget_high_u8(table0); + uint8x16_t idx = BitCast(d8, from).raw; + uint8x8_t low = vtbl2_u8(table, vget_low_u8(idx)); + uint8x8_t hi = vtbl2_u8(table, vget_high_u8(idx)); + return BitCast(d, Vec128(vcombine_u8(low, hi))); +#endif +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec128 bytes, Vec128 from) { + const Full128 d_full; + const Vec64 from64(from.raw); + const auto idx_full = Combine(d_full, from64, from64); + const auto out_full = TableLookupBytes(bytes, idx_full); + return Vec128(LowerHalf(Half(), out_full).raw); +} + +// Partial table vector +template +HWY_API Vec128 TableLookupBytes(Vec128 bytes, Vec128 from) { + const Full128 d_full; + return TableLookupBytes(Combine(d_full, bytes, bytes), from); +} + +// Partial both +template +HWY_API Vec128 TableLookupBytes(Vec128 bytes, + Vec128 from) { + const DFromV d; + const Simd d_idx; + const Repartition d_idx8; + // uint8x8 + const auto bytes8 = BitCast(Repartition(), bytes); + const auto from8 = BitCast(d_idx8, from); + const VFromD v8(vtbl1_u8(bytes8.raw, from8.raw)); + return BitCast(d_idx, v8); +} + +// For all vector widths; Arm anyway zeroes if >= 0x10. +template +HWY_API VI TableLookupBytesOr0(V bytes, VI from) { + return TableLookupBytes(bytes, from); +} + +// ---------------------------- AESKeyGenAssist (AESLastRound, TableLookupBytes) + +#if HWY_TARGET != HWY_NEON_WITHOUT_AES +template +HWY_API Vec128 AESKeyGenAssist(Vec128 v) { + alignas(16) static constexpr uint8_t kRconXorMask[16] = { + 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0}; + alignas(16) static constexpr uint8_t kRotWordShuffle[16] = { + 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12}; + const DFromV d; + const Repartition du32; + const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); + const auto sub_word_result = AESLastRound(w13, Load(d, kRconXorMask)); + return TableLookupBytes(sub_word_result, Load(d, kRotWordShuffle)); +} +#endif // HWY_TARGET != HWY_NEON_WITHOUT_AES + +// ------------------------------ Scatter in generic_ops-inl.h +// ------------------------------ Gather in generic_ops-inl.h + +// ------------------------------ Reductions + +// On Armv8 we define ReduceSum and generic_ops defines SumOfLanes via Set. +#if HWY_ARCH_ARM_A64 + +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +// TODO(janwas): use normal HWY_NEON_DEF, then FULL type list. +#define HWY_NEON_DEF_REDUCTION(type, size, name, prefix, infix, suffix) \ + template \ + HWY_API type##_t name(D /* tag */, Vec128 v) { \ + return HWY_NEON_EVAL(prefix##infix##suffix, v.raw); \ + } + +// Excludes u64/s64 (missing minv/maxv) and f16 (missing addv). 
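// TableLookupBytes above maps to vqtbl1q_u8 on AArch64 and a vtbl2_u8 pair on
// Armv7; each index selects a byte of `bytes`. A sketch that reverses the
// bytes of a 128-bit vector, assuming the public Highway API; values are
// illustrative.
#include <cstdint>
#include <cstdio>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

int main() {
  const hn::Full128<uint8_t> d8;
  const auto bytes = hn::Iota(d8, 0);  // {0, 1, 2, ..., 15}

  uint8_t idx_in[16];
  for (int i = 0; i < 16; ++i) idx_in[i] = static_cast<uint8_t>(15 - i);
  const auto idx = hn::LoadU(d8, idx_in);

  uint8_t out[16];
  hn::StoreU(hn::TableLookupBytes(bytes, idx), d8, out);
  for (int i = 0; i < 16; ++i) printf("%d ", out[i]);  // 15 14 ... 1 0
  printf("\n");
  return 0;
}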
+#define HWY_NEON_DEF_REDUCTION_CORE_TYPES(name, prefix) \ + HWY_NEON_DEF_REDUCTION(uint8, 8, name, prefix, _, u8) \ + HWY_NEON_DEF_REDUCTION(uint8, 16, name, prefix##q, _, u8) \ + HWY_NEON_DEF_REDUCTION(uint16, 4, name, prefix, _, u16) \ + HWY_NEON_DEF_REDUCTION(uint16, 8, name, prefix##q, _, u16) \ + HWY_NEON_DEF_REDUCTION(uint32, 2, name, prefix, _, u32) \ + HWY_NEON_DEF_REDUCTION(uint32, 4, name, prefix##q, _, u32) \ + HWY_NEON_DEF_REDUCTION(int8, 8, name, prefix, _, s8) \ + HWY_NEON_DEF_REDUCTION(int8, 16, name, prefix##q, _, s8) \ + HWY_NEON_DEF_REDUCTION(int16, 4, name, prefix, _, s16) \ + HWY_NEON_DEF_REDUCTION(int16, 8, name, prefix##q, _, s16) \ + HWY_NEON_DEF_REDUCTION(int32, 2, name, prefix, _, s32) \ + HWY_NEON_DEF_REDUCTION(int32, 4, name, prefix##q, _, s32) \ + HWY_NEON_DEF_REDUCTION(float32, 2, name, prefix, _, f32) \ + HWY_NEON_DEF_REDUCTION(float32, 4, name, prefix##q, _, f32) \ + HWY_NEON_DEF_REDUCTION(float64, 2, name, prefix##q, _, f64) + +// Different interface than HWY_NEON_DEF_FUNCTION_FULL_UI_64. +#define HWY_NEON_DEF_REDUCTION_UI64(name, prefix) \ + HWY_NEON_DEF_REDUCTION(uint64, 2, name, prefix##q, _, u64) \ + HWY_NEON_DEF_REDUCTION(int64, 2, name, prefix##q, _, s64) + +#if HWY_HAVE_FLOAT16 +#define HWY_NEON_DEF_REDUCTION_F16(name, prefix) \ + HWY_NEON_DEF_REDUCTION(float16, 4, name, prefix, _, f16) \ + HWY_NEON_DEF_REDUCTION(float16, 8, name, prefix##q, _, f16) +#else +#define HWY_NEON_DEF_REDUCTION_F16(name, prefix) +#endif + +HWY_NEON_DEF_REDUCTION_CORE_TYPES(ReduceMin, vminv) +HWY_NEON_DEF_REDUCTION_CORE_TYPES(ReduceMax, vmaxv) +HWY_NEON_DEF_REDUCTION_F16(ReduceMin, vminv) +HWY_NEON_DEF_REDUCTION_F16(ReduceMax, vmaxv) + +HWY_NEON_DEF_REDUCTION_CORE_TYPES(ReduceSum, vaddv) +HWY_NEON_DEF_REDUCTION_UI64(ReduceSum, vaddv) + +// Emulate missing UI64 and partial N=2. +template +HWY_API TFromD ReduceSum(D /* tag */, VFromD v10) { + return GetLane(v10) + ExtractLane(v10, 1); +} + +template +HWY_API TFromD ReduceMin(D /* tag */, VFromD v10) { + return HWY_MIN(GetLane(v10), ExtractLane(v10, 1)); +} + +template +HWY_API TFromD ReduceMax(D /* tag */, VFromD v10) { + return HWY_MAX(GetLane(v10), ExtractLane(v10, 1)); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API float16_t ReduceMin(D d, VFromD v10) { + return GetLane(Min(v10, Reverse2(d, v10))); +} + +template +HWY_API float16_t ReduceMax(D d, VFromD v10) { + return GetLane(Max(v10, Reverse2(d, v10))); +} + +template +HWY_API float16_t ReduceSum(D /* tag */, VFromD v) { + const float16x4_t x2 = vpadd_f16(v.raw, v.raw); + return GetLane(VFromD(vpadd_f16(x2, x2))); +} +template +HWY_API float16_t ReduceSum(D d, VFromD v) { + const Half dh; + return ReduceSum(dh, LowerHalf(dh, VFromD(vpaddq_f16(v.raw, v.raw)))); +} +#endif // HWY_HAVE_FLOAT16 + +#undef HWY_NEON_DEF_REDUCTION_CORE_TYPES +#undef HWY_NEON_DEF_REDUCTION_F16 +#undef HWY_NEON_DEF_REDUCTION_UI64 +#undef HWY_NEON_DEF_REDUCTION + +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); +} + +// On Armv7 we define SumOfLanes and generic_ops defines ReduceSum via GetLane. +#else // !HWY_ARCH_ARM_A64 + +// Armv7 lacks N=2 and 8-bit x4, so enable generic versions of those. 
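// The reduction wrappers above lower ReduceSum/ReduceMin/ReduceMax directly to
// vaddv/vminv/vmaxv on AArch64; SumOfLanes etc. broadcast the scalar result
// back into a vector. Caller-side sketch, assuming the public Highway API.
#include <cstdint>
#include <cstdio>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

int main() {
  const hn::FixedTag<int32_t, 4> d;
  const int32_t in[4] = {7, -2, 40, 3};
  const auto v = hn::LoadU(d, in);

  const int32_t sum = hn::ReduceSum(d, v);  // 48
  const int32_t lo = hn::ReduceMin(d, v);   // -2
  const int32_t hi = hn::ReduceMax(d, v);   // 40
  // SumOfLanes returns the sum broadcast into every lane.
  const int32_t also_sum = hn::GetLane(hn::SumOfLanes(d, v));

  printf("%d %d %d %d\n", sum, lo, hi, also_sum);
  return 0;
}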
+#undef HWY_IF_SUM_OF_LANES_D +#define HWY_IF_SUM_OF_LANES_D(D) \ + hwy::EnableIf<(HWY_MAX_LANES_D(D) == 2) || \ + (sizeof(TFromD) == 1 && HWY_MAX_LANES_D(D) == 4)>* = \ + nullptr +#undef HWY_IF_MINMAX_OF_LANES_D +#define HWY_IF_MINMAX_OF_LANES_D(D) \ + hwy::EnableIf<(HWY_MAX_LANES_D(D) == 2) || \ + (sizeof(TFromD) == 1 && HWY_MAX_LANES_D(D) == 4)>* = \ + nullptr + +// For arm7, we implement reductions using a series of pairwise operations. This +// produces the full vector result, so we express Reduce* in terms of *OfLanes. +#define HWY_NEON_BUILD_TYPE_T(type, size) type##x##size##_t +#define HWY_NEON_DEF_PAIRWISE_REDUCTION(type, size, name, prefix, suffix) \ + template \ + HWY_API Vec128 name##OfLanes(D /* d */, \ + Vec128 v) { \ + HWY_NEON_BUILD_TYPE_T(type, size) tmp = prefix##_##suffix(v.raw, v.raw); \ + if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + return Vec128(tmp); \ + } + +// For the wide versions, the pairwise operations produce a half-length vector. +// We produce that `tmp` and then Combine. +#define HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(type, size, half, name, prefix, \ + suffix) \ + template \ + HWY_API Vec128 name##OfLanes(D /* d */, \ + Vec128 v) { \ + HWY_NEON_BUILD_TYPE_T(type, half) tmp; \ + tmp = prefix##_##suffix(vget_high_##suffix(v.raw), \ + vget_low_##suffix(v.raw)); \ + if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + if ((size / 8) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + return Vec128(vcombine_##suffix(tmp, tmp)); \ + } + +#define HWY_NEON_DEF_PAIRWISE_REDUCTIONS(name, prefix) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(uint32, 2, name, prefix, u32) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(uint16, 4, name, prefix, u16) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(uint8, 8, name, prefix, u8) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(int32, 2, name, prefix, s32) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(int16, 4, name, prefix, s16) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(int8, 8, name, prefix, s8) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(float32, 2, name, prefix, f32) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint32, 4, 2, name, prefix, u32) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint16, 8, 4, name, prefix, u16) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint8, 16, 8, name, prefix, u8) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int32, 4, 2, name, prefix, s32) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int16, 8, 4, name, prefix, s16) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int8, 16, 8, name, prefix, s8) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(float32, 4, 2, name, prefix, f32) + +HWY_NEON_DEF_PAIRWISE_REDUCTIONS(Sum, vpadd) +HWY_NEON_DEF_PAIRWISE_REDUCTIONS(Min, vpmin) +HWY_NEON_DEF_PAIRWISE_REDUCTIONS(Max, vpmax) + +#undef HWY_NEON_DEF_PAIRWISE_REDUCTIONS +#undef HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION +#undef HWY_NEON_DEF_PAIRWISE_REDUCTION +#undef HWY_NEON_BUILD_TYPE_T + +// GetLane(SumsOf4(v)) is more efficient on ArmV7 NEON than the default +// N=4 I8/U8 ReduceSum implementation in generic_ops-inl.h +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +template +HWY_API TFromD ReduceSum(D /*d*/, VFromD v) { + return static_cast>(GetLane(SumsOf4(v))); +} + +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// Helper function to set 64 bits and potentially return a smaller vector. 
The +// overload is required to call the q vs non-q intrinsics. Note that 8-bit +// LoadMaskBits only requires 16 bits, but 64 avoids casting. +template +HWY_INLINE VFromD Set64(D /* tag */, uint64_t mask_bits) { + const auto v64 = Vec64(vdup_n_u64(mask_bits)); + return VFromD(BitCast(Full64>(), v64).raw); +} +template +HWY_INLINE Vec128 Set64(Full128 d, uint64_t mask_bits) { + return BitCast(d, Vec128(vdupq_n_u64(mask_bits))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const auto vmask_bits = Set64(du, mask_bits); + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) static constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vmask_bits, Load(du, kRep8)); + + alignas(16) static constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(d.MaxLanes() + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Returns mask[i]? 0xF : 0 in each nibble. This is more efficient than +// BitsFromMask for use in (partial) CountTrue, FindFirstTrue and AllFalse. +template +HWY_INLINE uint64_t NibblesFromMask(D d, MFromD mask) { + const Full128 du16; + const Vec128 vu16 = BitCast(du16, VecFromMask(d, mask)); + const Vec64 nib(vshrn_n_u16(vu16.raw, 4)); + return GetLane(BitCast(Full64(), nib)); +} + +template +HWY_INLINE uint64_t NibblesFromMask(D d, MFromD mask) { + // There is no vshrn_n_u16 for uint16x4, so zero-extend. + const Twice d2; + const VFromD v128 = ZeroExtendVector(d2, VecFromMask(d, mask)); + // No need to mask, upper half is zero thanks to ZeroExtendVector. 
+ return NibblesFromMask(d2, MaskFromVec(v128)); +} + +template +HWY_INLINE uint64_t NibblesFromMask(D d, MFromD mask) { + const Mask64> mask64(mask.raw); + const uint64_t nib = NibblesFromMask(Full64>(), mask64); + // Clear nibbles from upper half of 64-bits + return nib & ((1ull << (d.MaxBytes() * 4)) - 1); +} + +// Returns the lowest N for the BitsFromMask result. +template +constexpr uint64_t OnlyActive(D d, uint64_t bits) { + return (d.MaxBytes() >= 8) ? bits : (bits & ((1ull << d.MaxLanes()) - 1)); +} + +} // namespace detail + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + alignas(16) static constexpr uint8_t kSliceLanes[16] = { + 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, + }; + const RebindToUnsigned du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); + +#if HWY_ARCH_ARM_A64 + // Can't vaddv - we need two separate bytes (16 bits). + const uint8x8_t x2 = vget_low_u8(vpaddq_u8(values.raw, values.raw)); + const uint8x8_t x4 = vpadd_u8(x2, x2); + const uint8x8_t x8 = vpadd_u8(x4, x4); + return vget_lane_u64(vreinterpret_u64_u8(x8), 0) & 0xFFFF; +#else + // Don't have vpaddq, so keep doubling lane size. + const uint16x8_t x2 = vpaddlq_u8(values.raw); + const uint32x4_t x4 = vpaddlq_u16(x2); + const uint64x2_t x8 = vpaddlq_u32(x4); + return (vgetq_lane_u64(x8, 1) << 8) | vgetq_lane_u64(x8, 0); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) static constexpr uint8_t kSliceLanes[8] = {1, 2, 4, 8, + 0x10, 0x20, 0x40, 0x80}; + const RebindToUnsigned du; + using VU = VFromD; + const VU slice(Load(Full64(), kSliceLanes).raw); + const VU values = BitCast(du, VecFromMask(d, mask)) & slice; + +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddv_u8(values.raw)); +#else + const uint16x4_t x2 = vpaddl_u8(values.raw); + const uint32x2_t x4 = vpaddl_u16(x2); + const uint64x1_t x8 = vpaddl_u32(x4); + return detail::OnlyActive(d, vget_lane_u64(x8, 0)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + alignas(16) static constexpr uint16_t kSliceLanes[8] = { + 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80}; + const RebindToUnsigned du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddvq_u16(values.raw)); +#else + const uint32x4_t x2 = vpaddlq_u16(values.raw); + const uint64x2_t x4 = vpaddlq_u32(x2); + return detail::OnlyActive(d, vgetq_lane_u64(x4, 0) + vgetq_lane_u64(x4, 1)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. 
+ alignas(8) static constexpr uint16_t kSliceLanes[4] = {1, 2, 4, 8}; + const RebindToUnsigned du; + using VU = VFromD; + const VU slice(Load(Full64(), kSliceLanes).raw); + const VU values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddv_u16(values.raw)); +#else + const uint32x2_t x2 = vpaddl_u16(values.raw); + const uint64x1_t x4 = vpaddl_u32(x2); + return detail::OnlyActive(d, vget_lane_u64(x4, 0)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + alignas(16) static constexpr uint32_t kSliceLanes[4] = {1, 2, 4, 8}; + const RebindToUnsigned du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddvq_u32(values.raw)); +#else + const uint64x2_t x2 = vpaddlq_u32(values.raw); + return detail::OnlyActive(d, vgetq_lane_u64(x2, 0) + vgetq_lane_u64(x2, 1)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) static constexpr uint32_t kSliceLanes[2] = {1, 2}; + const RebindToUnsigned du; + using VU = VFromD; + const VU slice(Load(Full64(), kSliceLanes).raw); + const VU values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddv_u32(values.raw)); +#else + const uint64x1_t x2 = vpaddl_u32(values.raw); + return detail::OnlyActive(d, vget_lane_u64(x2, 0)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + alignas(16) static constexpr uint64_t kSliceLanes[2] = {1, 2}; + const RebindToUnsigned du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddvq_u64(values.raw)); +#else + return detail::OnlyActive( + d, vgetq_lane_u64(values.raw, 0) + vgetq_lane_u64(values.raw, 1)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToUnsigned du; + const Vec64 values = BitCast(du, VecFromMask(d, mask)) & Set(du, 1); + return vget_lane_u64(values.raw, 0); +} + +namespace detail { + +// Returns number of lanes whose mask is set. +// +// Masks are either FF..FF or 0. Unfortunately there is no reduce-sub op +// ("vsubv"). ANDing with 1 would work but requires a constant. Negating also +// changes each lane to 1 (if mask set) or 0. +// NOTE: PopCount also operates on vectors, so we still have to do horizontal +// sums separately. We specialize CountTrue for full vectors (negating instead +// of PopCount because it avoids an extra shift), and use PopCount of +// NibblesFromMask for partial vectors. 
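// LoadMaskBits, Dup128MaskFromMaskBits and this revision's BitsFromMask(d, m)
// convert between packed bit arrays and lane masks (bit i governs lane i).
// Round-trip sketch, assuming the public Highway API; LoadMaskBits requires at
// least 8 readable bytes, hence the padded array.
#include <cstdint>
#include <cstdio>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

int main() {
  const hn::FixedTag<int32_t, 4> d;

  const uint8_t bits_in[8] = {0b0101, 0, 0, 0, 0, 0, 0, 0};  // lanes 0 and 2
  const auto m = hn::LoadMaskBits(d, bits_in);
  printf("%llx\n", (unsigned long long)hn::BitsFromMask(d, m));  // 5

  // The same mask built directly from an integer.
  const auto m2 = hn::Dup128MaskFromMaskBits(d, 0b0101u);
  printf("%d\n", static_cast<int>(hn::CountTrue(d, m2)));  // 2
  return 0;
}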
+ +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<1> /*tag*/, Mask128 mask) { + const Full128 di; + const int8x16_t ones = + vnegq_s8(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s8(ones)); +#else + const int16x8_t x2 = vpaddlq_s8(ones); + const int32x4_t x4 = vpaddlq_s16(x2); + const int64x2_t x8 = vpaddlq_s32(x4); + return static_cast(vgetq_lane_s64(x8, 0) + vgetq_lane_s64(x8, 1)); +#endif +} +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<2> /*tag*/, Mask128 mask) { + const Full128 di; + const int16x8_t ones = + vnegq_s16(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s16(ones)); +#else + const int32x4_t x2 = vpaddlq_s16(ones); + const int64x2_t x4 = vpaddlq_s32(x2); + return static_cast(vgetq_lane_s64(x4, 0) + vgetq_lane_s64(x4, 1)); +#endif +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, Mask128 mask) { + const Full128 di; + const int32x4_t ones = + vnegq_s32(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s32(ones)); +#else + const int64x2_t x2 = vpaddlq_s32(ones); + return static_cast(vgetq_lane_s64(x2, 0) + vgetq_lane_s64(x2, 1)); +#endif +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, Mask128 mask) { +#if HWY_ARCH_ARM_A64 + const Full128 di; + const int64x2_t ones = + vnegq_s64(BitCast(di, VecFromMask(Full128(), mask)).raw); + return static_cast(vaddvq_s64(ones)); +#else + const Full128 du; + const auto mask_u = VecFromMask(du, RebindMask(du, mask)); + const uint64x2_t ones = vshrq_n_u64(mask_u.raw, 63); + return static_cast(vgetq_lane_u64(ones, 0) + vgetq_lane_u64(ones, 1)); +#endif +} + +} // namespace detail + +// Full +template > +HWY_API size_t CountTrue(D /* tag */, Mask128 mask) { + return detail::CountTrue(hwy::SizeTag(), mask); +} + +// Partial +template +HWY_API size_t CountTrue(D d, MFromD mask) { + constexpr int kDiv = 4 * sizeof(TFromD); + return PopCount(detail::NibblesFromMask(d, mask)) / kDiv; +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + constexpr size_t kDiv = 4 * sizeof(TFromD); + return Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv; +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + if (nib == 0) return -1; + constexpr size_t kDiv = 4 * sizeof(TFromD); + return static_cast(Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv); +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + constexpr size_t kDiv = 4 * sizeof(TFromD); + return (63 - Num0BitsAboveMS1Bit_Nonzero64(nib)) / kDiv; +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + if (nib == 0) return -1; + constexpr size_t kDiv = 4 * sizeof(TFromD); + return static_cast((63 - Num0BitsAboveMS1Bit_Nonzero64(nib)) / + kDiv); +} + +// `p` points to at least 8 writable bytes. 
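// The mask queries above (CountTrue, FindFirstTrue, AllTrue/AllFalse) are all
// derived from the 4-bits-per-lane NibblesFromMask trick; StoreMaskBits,
// defined next, writes the packed bits to memory (at least 8 writable bytes).
// Usage sketch, assuming the public Highway API; values are illustrative.
#include <cstdint>
#include <cstdio>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

int main() {
  const hn::FixedTag<int32_t, 4> d;
  const int32_t in[4] = {5, -1, 9, -3};
  const auto v = hn::LoadU(d, in);
  const auto neg = hn::Lt(v, hn::Zero(d));  // {0, 1, 0, 1}

  printf("count=%d first=%d all=%d none=%d\n",
         static_cast<int>(hn::CountTrue(d, neg)),      // 2
         static_cast<int>(hn::FindFirstTrue(d, neg)),  // 1 (-1 if none set)
         hn::AllTrue(d, neg) ? 1 : 0,                  // 0
         hn::AllFalse(d, neg) ? 1 : 0);                // 0

  uint8_t bits[8] = {0};
  hn::StoreMaskBits(d, neg, bits);
  printf("bits=%d\n", bits[0]);  // 0b1010 == 10
  return 0;
}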
+template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + const uint64_t mask_bits = BitsFromMask(d, mask); + const size_t kNumBytes = (d.MaxLanes() + 7) / 8; + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +template +HWY_API bool AllFalse(D d, MFromD m) { + return detail::NibblesFromMask(d, m) == 0; +} + +// Full +template > +HWY_API bool AllTrue(D d, Mask128 m) { + return detail::NibblesFromMask(d, m) == ~0ull; +} +// Partial +template +HWY_API bool AllTrue(D d, MFromD m) { + return detail::NibblesFromMask(d, m) == (1ull << (d.MaxBytes() * 4)) - 1; +} + +// ------------------------------ Compress + +template +struct CompressIsPartition { + enum { value = (sizeof(T) != 1) }; +}; + +namespace detail { + +// Load 8 bytes, replicate into upper half so ZipLower can use the lower half. +template +HWY_INLINE Vec128 Load8Bytes(D /*tag*/, const uint8_t* bytes) { + return Vec128(vreinterpretq_u8_u64( + vld1q_dup_u64(HWY_RCAST_ALIGNED(const uint64_t*, bytes)))); +} + +// Load 8 bytes and return half-reg with N <= 8 bytes. +template +HWY_INLINE VFromD Load8Bytes(D d, const uint8_t* bytes) { + return Load(d, bytes); +} + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<2> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Repartition d8; + const Simd du; + + // NEON does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
+ alignas(16) static constexpr uint8_t table[256 * 8] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 
12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 
12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<2> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Repartition d8; + const Simd du; + + // NEON does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
+ alignas(16) static constexpr uint8_t table[256 * 8] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 
8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 
0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<4> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<4> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +#if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<8> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<8> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +#endif + +// Helper function called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. +template +HWY_INLINE Vec128 Compress(Vec128 v, uint64_t mask_bits) { + const auto idx = + detail::IdxFromBits(hwy::SizeTag(), mask_bits); + using D = DFromV; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_INLINE Vec128 CompressNot(Vec128 v, uint64_t mask_bits) { + const auto idx = + detail::IdxFromNotBits(hwy::SizeTag(), mask_bits); + using D = DFromV; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
+ const DFromV d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return detail::Compress(v, BitsFromMask(d, mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const DFromV d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::Compress(v, BitsFromMask(d, Not(mask))); + } + return detail::CompressNot(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits + +template +HWY_INLINE Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, mask); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; // so we can support fp16/bf16 + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + const MFromD store_mask = RebindMask(d, FirstN(du, count)); + const VFromD compressed = + detail::Compress(BitCast(du, v), mask_bits); + BlendedStore(BitCast(d, compressed), store_mask, d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (d.MaxLanes() + 7) / 8; + CopyBytes(bits, &mask_bits); + if (d.MaxLanes() < 8) { + mask_bits &= (1ull << d.MaxLanes()) - 1; + } + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. 
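Illustrative aside, not part of the upstream patch: a typical use of the CompressStore family just defined is stream compaction (filtering an array by a predicate). A minimal single-target sketch using Highway's public API, omitting the usual foreach_target dispatch machinery and tail handling; the function name CopyPositive is hypothetical:

#include <cstddef>
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Copies elements of `in` that are > 0 to `out`, returning the count written.
// Note CompressStore may overwrite up to Lanes(d) elements at the destination,
// so `out` must have room for a full vector past the last kept element.
size_t CopyPositive(const float* in, size_t num, float* out) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  size_t written = 0;
  for (size_t i = 0; i + N <= num; i += N) {
    const auto v = hn::LoadU(d, in + i);
    const auto keep = hn::Gt(v, hn::Zero(d));
    written += hn::CompressStore(v, keep, d, out + written);
  }
  // Remainder lanes (num % N) would be handled with masked loads/stores.
  return written;
}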
+#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +namespace detail { + +#define HWY_NEON_BUILD_TPL_HWY_LOAD_INT +#define HWY_NEON_BUILD_ARG_HWY_LOAD_INT from + +#if HWY_ARCH_ARM_A64 +#define HWY_IF_LOAD_INT(D) \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D) +#define HWY_NEON_DEF_FUNCTION_LOAD_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) +#else +// Exclude 64x2 and f64x1, which are only supported on aarch64; also exclude any +// emulated types. +#define HWY_IF_LOAD_INT(D) \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D), \ + hwy::EnableIf<(HWY_MAX_LANES_D(D) == 1 || sizeof(TFromD) < 8)>* = \ + nullptr +#define HWY_NEON_DEF_FUNCTION_LOAD_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) +#endif // HWY_ARCH_ARM_A64 + +// Must return raw tuple because Tuple2 lack a ctor, and we cannot use +// brace-initialization in HWY_NEON_DEF_FUNCTION because some functions return +// void. +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple2().raw) +// Tuple tag arg allows overloading (cannot just overload on return type) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const NativeLaneType*from, Tuple2 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved2, vld2, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT + +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple3().raw) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const NativeLaneType*from, Tuple3 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved3, vld3, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT + +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple4().raw) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const NativeLaneType*from, Tuple4 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved4, vld4, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT + +#undef HWY_NEON_DEF_FUNCTION_LOAD_INT +#undef HWY_NEON_BUILD_TPL_HWY_LOAD_INT +#undef HWY_NEON_BUILD_ARG_HWY_LOAD_INT + +} // namespace detail + +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + auto raw = detail::LoadInterleaved2(detail::NativeLanePointer(unaligned), + detail::Tuple2()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); +} + +// <= 32 bits: avoid loading more than N bytes by copying to buffer +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + // The smallest vector registers are 64-bits and we want space for two. 
+ alignas(16) T buf[2 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved2(detail::NativeLanePointer(buf), + detail::Tuple2()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved2(D d, T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1) { + const Half dh; + VFromD v00, v10, v01, v11; + LoadInterleaved2(dh, detail::NativeLanePointer(unaligned), v00, v10); + LoadInterleaved2(dh, detail::NativeLanePointer(unaligned + 2), v01, v11); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ LoadInterleaved3 + +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + auto raw = detail::LoadInterleaved3(detail::NativeLanePointer(unaligned), + detail::Tuple3()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); + v2 = VFromD(raw.val[2]); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + // The smallest vector registers are 64-bits and we want space for three. + alignas(16) T buf[3 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved3(detail::NativeLanePointer(buf), + detail::Tuple3()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); + v2 = VFromD(raw.val[2]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, Vec128& v2) { + const Half dh; + VFromD v00, v10, v20, v01, v11, v21; + LoadInterleaved3(dh, detail::NativeLanePointer(unaligned), v00, v10, v20); + LoadInterleaved3(dh, detail::NativeLanePointer(unaligned + 3), v01, v11, v21); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); + v2 = Combine(d, v21, v20); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ LoadInterleaved4 + +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + auto raw = detail::LoadInterleaved4(detail::NativeLanePointer(unaligned), + detail::Tuple4()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); + v2 = VFromD(raw.val[2]); + v3 = VFromD(raw.val[3]); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + alignas(16) T buf[4 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved4(detail::NativeLanePointer(buf), + detail::Tuple4()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); + v2 = VFromD(raw.val[2]); + v3 = VFromD(raw.val[3]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, Vec128& v2, + Vec128& v3) { + const Half dh; + VFromD v00, v10, v20, v30, v01, v11, v21, v31; + LoadInterleaved4(dh, detail::NativeLanePointer(unaligned), v00, v10, v20, + v30); + LoadInterleaved4(dh, detail::NativeLanePointer(unaligned + 4), v01, v11, v21, + v31); + v0 = Combine(d, v01, v00); + v1 = 
Combine(d, v11, v10); + v2 = Combine(d, v21, v20); + v3 = Combine(d, v31, v30); +} +#endif // HWY_ARCH_ARM_V7 + +#undef HWY_IF_LOAD_INT + +// ------------------------------ StoreInterleaved2 + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_STORE_INT +#define HWY_NEON_BUILD_RET_HWY_STORE_INT(type, size) void +#define HWY_NEON_BUILD_ARG_HWY_STORE_INT to, tup.raw + +#if HWY_ARCH_ARM_A64 +#define HWY_IF_STORE_INT(D) \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D) +#define HWY_NEON_DEF_FUNCTION_STORE_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) +#else +// Exclude 64x2 and f64x1, which are only supported on aarch64; also exclude any +// emulated types. +#define HWY_IF_STORE_INT(D) \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D), \ + hwy::EnableIf<(HWY_MAX_LANES_D(D) == 1 || sizeof(TFromD) < 8)>* = \ + nullptr +#define HWY_NEON_DEF_FUNCTION_STORE_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) +#endif // HWY_ARCH_ARM_A64 + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple2 tup, NativeLaneType*to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved2, vst2, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple3 tup, NativeLaneType*to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved3, vst3, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple4 tup, NativeLaneType*to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved4, vst4, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#undef HWY_NEON_DEF_FUNCTION_STORE_INT +#undef HWY_NEON_BUILD_TPL_HWY_STORE_INT +#undef HWY_NEON_BUILD_RET_HWY_STORE_INT +#undef HWY_NEON_BUILD_ARG_HWY_STORE_INT +} // namespace detail + +template > +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + detail::Tuple2 tup = {{{v0.raw, v1.raw}}}; + detail::StoreInterleaved2(tup, detail::NativeLanePointer(unaligned)); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template > +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[2 * 8 / sizeof(T)]; + detail::Tuple2 tup = {{{v0.raw, v1.raw}}}; + detail::StoreInterleaved2(tup, detail::NativeLanePointer(buf)); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved2(Vec128 v0, Vec128 v1, D d, + T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved2(LowerHalf(dh, v0), LowerHalf(dh, v1), dh, + detail::NativeLanePointer(unaligned)); + StoreInterleaved2(UpperHalf(dh, v0), UpperHalf(dh, v1), dh, + detail::NativeLanePointer(unaligned + 2)); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ StoreInterleaved3 + +template > +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + T* HWY_RESTRICT unaligned) { + detail::Tuple3 tup = {{{v0.raw, v1.raw, v2.raw}}}; + 
detail::StoreInterleaved3(tup, detail::NativeLanePointer(unaligned)); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template > +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[3 * 8 / sizeof(T)]; + detail::Tuple3 tup = {{{v0.raw, v1.raw, v2.raw}}}; + detail::StoreInterleaved3(tup, detail::NativeLanePointer(buf)); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved3(Vec128 v0, Vec128 v1, Vec128 v2, D d, + T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved3(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), dh, + detail::NativeLanePointer(unaligned)); + StoreInterleaved3(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), dh, + detail::NativeLanePointer(unaligned + 3)); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ StoreInterleaved4 + +template > +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, T* HWY_RESTRICT unaligned) { + detail::Tuple4 tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; + detail::StoreInterleaved4(tup, detail::NativeLanePointer(unaligned)); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template > +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, T* HWY_RESTRICT unaligned) { + alignas(16) T buf[4 * 8 / sizeof(T)]; + detail::Tuple4 tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; + detail::StoreInterleaved4(tup, detail::NativeLanePointer(buf)); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved4(Vec128 v0, Vec128 v1, Vec128 v2, + Vec128 v3, D d, T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved4(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), + LowerHalf(dh, v3), dh, + detail::NativeLanePointer(unaligned)); + StoreInterleaved4(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), + UpperHalf(dh, v3), dh, + detail::NativeLanePointer(unaligned + 4)); +} +#endif // HWY_ARCH_ARM_V7 + +#undef HWY_IF_STORE_INT + +// Fall back on generic Load/StoreInterleaved[234] for any emulated types. +// Requires HWY_GENERIC_IF_EMULATED_D mirrors HWY_NEON_IF_EMULATED_D. 
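Illustrative aside, not part of the upstream patch: the Load/StoreInterleaved3 wrappers above map to NEON vld3/vst3 and are typically used to de-interleave and re-interleave packed pixel data. A minimal single-target sketch using the public API; the function name ScaleGreen and the halve-green operation are hypothetical, and tail handling is omitted:

#include <cstddef>
#include <cstdint>
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Halves the green channel of packed RGB data (3 bytes per pixel).
void ScaleGreen(const uint8_t* rgb, size_t num_pixels, uint8_t* out) {
  const hn::ScalableTag<uint8_t> d;
  const size_t N = hn::Lanes(d);
  for (size_t i = 0; i + N <= num_pixels; i += N) {
    hn::VFromD<decltype(d)> r, g, b;
    hn::LoadInterleaved3(d, rgb + 3 * i, r, g, b);     // de-interleave RGBRGB...
    g = hn::AverageRound(g, hn::Zero(d));              // (g + 1) / 2
    hn::StoreInterleaved3(r, g, b, d, out + 3 * i);    // re-interleave
  }
  // Pixels in num_pixels % N would be processed separately.
}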
+ +// ------------------------------ Additional mask logical operations +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const FixedTag d; + const auto vmask = VecFromMask(d, mask); + return MaskFromVec(Or(vmask, InterleaveLower(vmask, vmask))); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Simd d; + const auto vmask = VecFromMask(d, mask); + const auto neg_vmask = + ResizeBitCast(d, Neg(ResizeBitCast(Full64(), vmask))); + return MaskFromVec(Or(vmask, neg_vmask)); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Full128 d; + const Repartition di64; + + auto vmask = BitCast(di64, VecFromMask(d, mask)); + vmask = Or(vmask, Neg(vmask)); + + // Copy the sign bit of the first int64_t lane to the second int64_t lane + const auto vmask2 = BroadcastSignBit(InterleaveLower(Zero(di64), vmask)); + return MaskFromVec(BitCast(d, Or(vmask, vmask2))); +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const FixedTag d; + const RebindToSigned di; + + const auto vmask = BitCast(di, VecFromMask(d, mask)); + const auto zero = Zero(di); + const auto vmask2 = VecFromMask(di, InterleaveLower(zero, vmask) == zero); + return MaskFromVec(BitCast(d, And(vmask, vmask2))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Simd d; + const RebindToSigned di; + + const auto vmask = ResizeBitCast(Full64(), VecFromMask(d, mask)); + const auto only_first_vmask = + BitCast(d, Neg(ResizeBitCast(di, And(vmask, Neg(vmask))))); + return MaskFromVec(only_first_vmask); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Full128 d; + const RebindToSigned di; + const Repartition di64; + + const auto zero = Zero(di64); + const auto vmask = BitCast(di64, VecFromMask(d, mask)); + const auto vmask2 = VecFromMask(di64, InterleaveLower(zero, vmask) == zero); + const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); + return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2)))); +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 /*mask*/) { + const FixedTag d; + const RebindToSigned di; + using TI = MakeSigned; + + return RebindMask(d, MaskFromVec(Set(di, TI(-1)))); +} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + const Simd d; + return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask)))); +} + +// ------------------------------ Lt128 + +template +HWY_INLINE MFromD Lt128(D d, VFromD a, VFromD b) { + static_assert(IsSame, uint64_t>(), "T must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const MFromD eqHL = Eq(a, b); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. IfThenElse compiles to the + // same code as OrAnd(). 
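Illustrative aside, not part of the upstream patch: the truth table above reduces to out = cH | (=H & cL), i.e. the 128-bit comparison is decided by the upper halves unless they are equal. A scalar equivalent, with each 128-bit operand given as (hi, lo) u64 pairs (the helper name is hypothetical):

#include <cstdint>

static bool Lt128Scalar(uint64_t a_hi, uint64_t a_lo,
                        uint64_t b_hi, uint64_t b_lo) {
  const bool eqH = (a_hi == b_hi);  // "=H"
  const bool ltH = (a_hi < b_hi);   // "cH"
  const bool ltL = (a_lo < b_lo);   // "cL"
  return ltH || (eqH && ltL);       // out = cH | (=H & cL)
}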
+ const VFromD ltLx = DupEven(ltHL); + const VFromD outHx = IfThenElse(eqHL, ltLx, ltHL); + return MaskFromVec(DupOdd(outHx)); +} + +template +HWY_INLINE MFromD Lt128Upper(D d, VFromD a, VFromD b) { + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + return MaskFromVec(InterleaveUpper(d, ltHL, ltHL)); +} + +// ------------------------------ Eq128 + +template +HWY_INLINE MFromD Eq128(D d, VFromD a, VFromD b) { + static_assert(IsSame, uint64_t>(), "T must be u64"); + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(And(Reverse2(d, eqHL), eqHL)); +} + +template +HWY_INLINE MFromD Eq128Upper(D d, VFromD a, VFromD b) { + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(InterleaveUpper(d, eqHL, eqHL)); +} + +// ------------------------------ Ne128 + +template +HWY_INLINE MFromD Ne128(D d, VFromD a, VFromD b) { + static_assert(IsSame, uint64_t>(), "T must be u64"); + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(Or(Reverse2(d, neHL), neHL)); +} + +template +HWY_INLINE MFromD Ne128Upper(D d, VFromD a, VFromD b) { + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(InterleaveUpper(d, neHL, neHL)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. +template +HWY_INLINE VFromD Min128(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template +HWY_INLINE VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex + +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +HWY_NEON_DEF_FUNCTION_INT_8_16_32(LeadingZeroCount, vclz, _, 1) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(LeadingZeroCount, vclz, _, 1) + +template )> +HWY_API V LeadingZeroCount(V v) { + const DFromV d; + const RebindToUnsigned du; + const Repartition du32; + + const auto v_k32 = BitCast(du32, Set(du, 32)); + const auto v_u32_lzcnt = LeadingZeroCount(BitCast(du32, v)) + v_k32; + const auto v_u32_lo_lzcnt = + And(v_u32_lzcnt, BitCast(du32, Set(du, 0xFFFFFFFFu))); + const auto v_u32_hi_lzcnt = + BitCast(du32, ShiftRight<32>(BitCast(du, v_u32_lzcnt))); + + return BitCast( + d, IfThenElse(v_u32_hi_lzcnt == v_k32, v_u32_lo_lzcnt, v_u32_hi_lzcnt)); +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + using T = TFromD; + return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v)); +} + +template +HWY_API V TrailingZeroCount(V v) { + return LeadingZeroCount(ReverseBits(v)); +} + +template +HWY_API V TrailingZeroCount(V v) { + const DFromV d; + const Repartition du8; + return LeadingZeroCount( + ReverseLaneBytes(BitCast(d, ReverseBits(BitCast(du8, v))))); +} + +namespace detail { // for code folding +#if HWY_ARCH_ARM_V7 +#undef vuzp1_s8 +#undef vuzp1_u8 +#undef vuzp1_s16 +#undef vuzp1_u16 +#undef vuzp1_s32 +#undef vuzp1_u32 +#undef vuzp1_f32 +#undef vuzp1q_s8 +#undef vuzp1q_u8 +#undef vuzp1q_s16 +#undef vuzp1q_u16 +#undef vuzp1q_s32 +#undef vuzp1q_u32 +#undef vuzp1q_f32 +#undef vuzp2_s8 +#undef vuzp2_u8 +#undef vuzp2_s16 +#undef vuzp2_u16 +#undef vuzp2_s32 +#undef vuzp2_u32 +#undef vuzp2_f32 +#undef vuzp2q_s8 
+#undef vuzp2q_u8 +#undef vuzp2q_s16 +#undef vuzp2q_u16 +#undef vuzp2q_s32 +#undef vuzp2q_u32 +#undef vuzp2q_f32 +#undef vzip1_s8 +#undef vzip1_u8 +#undef vzip1_s16 +#undef vzip1_u16 +#undef vzip1_s32 +#undef vzip1_u32 +#undef vzip1_f32 +#undef vzip1q_s8 +#undef vzip1q_u8 +#undef vzip1q_s16 +#undef vzip1q_u16 +#undef vzip1q_s32 +#undef vzip1q_u32 +#undef vzip1q_f32 +#undef vzip2_s8 +#undef vzip2_u8 +#undef vzip2_s16 +#undef vzip2_u16 +#undef vzip2_s32 +#undef vzip2_u32 +#undef vzip2_f32 +#undef vzip2q_s8 +#undef vzip2q_u8 +#undef vzip2q_s16 +#undef vzip2q_u16 +#undef vzip2q_s32 +#undef vzip2q_u32 +#undef vzip2q_f32 +#endif + +#undef HWY_NEON_BUILD_ARG_1 +#undef HWY_NEON_BUILD_ARG_2 +#undef HWY_NEON_BUILD_ARG_3 +#undef HWY_NEON_BUILD_PARAM_1 +#undef HWY_NEON_BUILD_PARAM_2 +#undef HWY_NEON_BUILD_PARAM_3 +#undef HWY_NEON_BUILD_RET_1 +#undef HWY_NEON_BUILD_RET_2 +#undef HWY_NEON_BUILD_RET_3 +#undef HWY_NEON_BUILD_TPL_1 +#undef HWY_NEON_BUILD_TPL_2 +#undef HWY_NEON_BUILD_TPL_3 +#undef HWY_NEON_DEF_FUNCTION +#undef HWY_NEON_DEF_FUNCTION_ALL_FLOATS +#undef HWY_NEON_DEF_FUNCTION_ALL_TYPES +#undef HWY_NEON_DEF_FUNCTION_BFLOAT_16 +#undef HWY_NEON_DEF_FUNCTION_FLOAT_16 +#undef HWY_NEON_DEF_FUNCTION_FLOAT_16_32 +#undef HWY_NEON_DEF_FUNCTION_FLOAT_32 +#undef HWY_NEON_DEF_FUNCTION_FLOAT_64 +#undef HWY_NEON_DEF_FUNCTION_FULL_UI +#undef HWY_NEON_DEF_FUNCTION_FULL_UI_64 +#undef HWY_NEON_DEF_FUNCTION_FULL_UIF_64 +#undef HWY_NEON_DEF_FUNCTION_INT_16 +#undef HWY_NEON_DEF_FUNCTION_INT_32 +#undef HWY_NEON_DEF_FUNCTION_INT_64 +#undef HWY_NEON_DEF_FUNCTION_INT_8 +#undef HWY_NEON_DEF_FUNCTION_INT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_INTS +#undef HWY_NEON_DEF_FUNCTION_INTS_UINTS +#undef HWY_NEON_DEF_FUNCTION_UI_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_UIF_64 +#undef HWY_NEON_DEF_FUNCTION_UIF_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_UINT_16 +#undef HWY_NEON_DEF_FUNCTION_UINT_32 +#undef HWY_NEON_DEF_FUNCTION_UINT_64 +#undef HWY_NEON_DEF_FUNCTION_UINT_8 +#undef HWY_NEON_DEF_FUNCTION_UINT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_UINTS +#undef HWY_NEON_EVAL +#undef HWY_NEON_IF_EMULATED_D +#undef HWY_NEON_IF_NOT_EMULATED_D +} // namespace detail + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/aom/third_party/highway/hwy/ops/arm_sve-inl.h b/third_party/aom/third_party/highway/hwy/ops/arm_sve-inl.h new file mode 100644 index 000000000000..87f0e4999693 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/arm_sve-inl.h @@ -0,0 +1,7009 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Arm SVE[2] vectors (length not known at compile time). +// External include guard in highway.h - see comment there. + +#include + +#include "third_party/highway/hwy/ops/shared-inl.h" + +// Arm C215 declares that SVE vector lengths will always be a power of two. +// We default to relying on this, which makes some operations more efficient. 
+// You can still opt into fixups by setting this to 0 (unsupported). +#ifndef HWY_SVE_IS_POW2 +#define HWY_SVE_IS_POW2 1 +#endif + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +#define HWY_SVE_HAVE_2 1 +#else +#define HWY_SVE_HAVE_2 0 +#endif + +// If 1, both __bf16 and a limited set of *_bf16 SVE intrinsics are available: +// create/get/set/dup, ld/st, sel, rev, trn, uzp, zip. +#if HWY_ARM_HAVE_SCALAR_BF16_TYPE && defined(__ARM_FEATURE_SVE_BF16) +#define HWY_SVE_HAVE_BF16_FEATURE 1 +#else +#define HWY_SVE_HAVE_BF16_FEATURE 0 +#endif + +// HWY_SVE_HAVE_BF16_VEC is defined to 1 if the SVE svbfloat16_t vector type +// is supported, even if HWY_SVE_HAVE_BF16_FEATURE (= intrinsics) is 0. +#if HWY_SVE_HAVE_BF16_FEATURE || \ + (HWY_COMPILER_CLANG >= 1200 && defined(__ARM_FEATURE_SVE_BF16)) || \ + HWY_COMPILER_GCC_ACTUAL >= 1000 +#define HWY_SVE_HAVE_BF16_VEC 1 +#else +#define HWY_SVE_HAVE_BF16_VEC 0 +#endif + +// HWY_SVE_HAVE_F32_TO_BF16C is defined to 1 if the SVE svcvt_bf16_f32_x +// and svcvtnt_bf16_f32_x intrinsics are available, even if the __bf16 type +// is disabled +#if HWY_SVE_HAVE_BF16_VEC && defined(__ARM_FEATURE_SVE_BF16) +#define HWY_SVE_HAVE_F32_TO_BF16C 1 +#else +#define HWY_SVE_HAVE_F32_TO_BF16C 0 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct DFromV_t {}; // specialized in macros +template +using DFromV = typename DFromV_t>::type; + +template +using TFromV = TFromD>; + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. Although +// harder to read and debug, writing everything manually is too bulky. + +namespace detail { // for code folding + +// Args: BASE, CHAR, BITS, HALF, NAME, OP + +// Unsigned: +#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 64, 32, NAME, OP) + +// Signed: +#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, 32, NAME, OP) + +// Float: +#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 16, 16, NAME, OP) +#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 64, 32, NAME, OP) + +#define HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP) \ + X_MACRO(bfloat, bf, 16, 16, NAME, OP) + +#if HWY_SVE_HAVE_BF16_FEATURE +#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP) +// We have both f16 and bf16, so nothing is emulated. 
+ +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the D template +// argument +#define HWY_SVE_IF_EMULATED_D(D) hwy::EnableIf()>* = nullptr +#define HWY_GENERIC_IF_EMULATED_D(D) \ + hwy::EnableIf()>* = nullptr +#define HWY_SVE_IF_NOT_EMULATED_D(D) hwy::EnableIf* = nullptr +#else +#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) +#define HWY_SVE_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_SVE_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D) +#endif // HWY_SVE_HAVE_BF16_FEATURE + +// For all element sizes: +#define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + +// HWY_SVE_FOREACH_F does not include HWY_SVE_FOREACH_BF16 because SVE lacks +// bf16 overloads for some intrinsics (especially less-common arithmetic). +// However, this does include f16 because SVE supports it unconditionally. +#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) + +// Commonly used type categories for a given element size: +#define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) + +// Commonly used type categories: +#define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) + +// Assemble types for use in x-macros +#define HWY_SVE_T(BASE, BITS) BASE##BITS##_t +#define HWY_SVE_D(BASE, BITS, N, POW2) Simd +#define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t +#define HWY_SVE_TUPLE(BASE, BITS, MUL) sv##BASE##BITS##x##MUL##_t + +} // namespace detail + +#define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <> \ + struct DFromV_t { \ + using type = ScalableTag; \ + }; + +HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) +#endif +#undef 
HWY_SPECIALIZE + +// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX +// instructions, and we anyway only use it when the predicate is ptrue. + +// vector = f(vector), e.g. Not +#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } +#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } +#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) a) { \ + return sv##OP##_##CHAR##BITS##_m(no, m, a); \ + } +#define HWY_SVE_RETV_ARGMV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(m, v); \ + } +#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \ + return sv##OP##_##CHAR##BITS##_z(m, a); \ + } + +// vector = f(vector, scalar), e.g. detail::AddN +#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +// vector = f(vector, vector), e.g. Add +#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } +// All-true mask +#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ + } +// User-specified mask. Mask=false value is undefined and must be set by caller +// because SVE instructions take it from one of the two inputs, whereas +// AVX-512, RVV and Highway allow a third argument. +#define HWY_SVE_RETV_ARGMVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(m, a, b); \ + } +// User-specified mask. Mask=false value is zero. 
+#define HWY_SVE_RETV_ARGMVV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_z(m, a, b); \ + } + +#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ + HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS(a, b, c); \ + } +#define HWY_SVE_RETV_ARGMVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ + HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS##_x(m, a, b, c); \ + } +#define HWY_SVE_RETV_ARGMVVV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \ + HWY_SVE_V(BASE, BITS) add) { \ + return sv##OP##_##CHAR##BITS##_z(m, x, mul, add); \ + } + +// ------------------------------ Lanes + +namespace detail { + +// Returns actual lanes of a hardware vector without rounding to a power of two. +template +HWY_INLINE size_t AllHardwareLanes() { + return svcntb_pat(SV_ALL); +} +template +HWY_INLINE size_t AllHardwareLanes() { + return svcnth_pat(SV_ALL); +} +template +HWY_INLINE size_t AllHardwareLanes() { + return svcntw_pat(SV_ALL); +} +template +HWY_INLINE size_t AllHardwareLanes() { + return svcntd_pat(SV_ALL); +} + +// All-true mask from a macro + +#if HWY_SVE_IS_POW2 +#define HWY_SVE_ALL_PTRUE(BITS) svptrue_b##BITS() +#define HWY_SVE_PTRUE(BITS) svptrue_b##BITS() +#else +#define HWY_SVE_ALL_PTRUE(BITS) svptrue_pat_b##BITS(SV_ALL) +#define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2) +#endif // HWY_SVE_IS_POW2 + +} // namespace detail + +#if HWY_HAVE_SCALABLE + +// Returns actual number of lanes after capping by N and shifting. May return 0 +// (e.g. for "1/8th" of a u32x4 - would be 1 for 1/8th of u32x8). +template +HWY_API size_t Lanes(Simd d) { + const size_t actual = detail::AllHardwareLanes(); + constexpr size_t kMaxLanes = MaxLanes(d); + constexpr int kClampedPow2 = HWY_MIN(kPow2, 0); + // Common case of full vectors: avoid any extra instructions. + if (detail::IsFull(d)) return actual; + return HWY_MIN(detail::ScaleByPower(actual, kClampedPow2), kMaxLanes); +} + +#endif // HWY_HAVE_SCALABLE + +// ================================================== MASK INIT + +// One mask bit per byte; only the one belonging to the lowest byte is valid. + +// ------------------------------ FirstN +#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, size_t count) { \ + const size_t limit = detail::IsFull(d) ? 
count : HWY_MIN(Lanes(d), count); \ + return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast(limit)); \ + } +HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_FIRSTN, FirstN, whilelt) +#endif + +template +svbool_t FirstN(D /* tag */, size_t count) { + return FirstN(RebindToUnsigned(), count); +} + +#undef HWY_SVE_FIRSTN + +template +using MFromD = svbool_t; + +namespace detail { + +#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_PTRUE(BITS); \ + } \ + template \ + HWY_API svbool_t All##NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_ALL_PTRUE(BITS); \ + } + +HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) // return all-true +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) +#undef HWY_SVE_WRAP_PTRUE + +HWY_API svbool_t PFalse() { return svpfalse_b(); } + +// Returns all-true if d is HWY_FULL or FirstN(N) after capping N. +// +// This is used in functions that load/store memory; other functions (e.g. +// arithmetic) can ignore d and use PTrue instead. +template +svbool_t MakeMask(D d) { + return IsFull(d) ? PTrue(d) : FirstN(d, Lanes(d)); +} + +} // namespace detail + +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +template +HWY_API svbool_t MaskFalse(const D /*d*/) { + return detail::PFalse(); +} + +// ================================================== INIT + +// ------------------------------ Set +// vector = f(d, scalar), e.g. Set +#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) arg) { \ + return sv##OP##_##CHAR##BITS(arg); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n) +#if HWY_SVE_HAVE_BF16_FEATURE // for if-elif chain +HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, dup_n) +#elif HWY_SVE_HAVE_BF16_VEC +// Required for Zero and VFromD +template +HWY_API svbfloat16_t Set(D d, bfloat16_t arg) { + return svreinterpret_bf16_u16( + Set(RebindToUnsigned(), BitCastScalar(arg))); +} +#else // neither bf16 feature nor vector: emulate with u16 +// Required for Zero and VFromD +template +HWY_API svuint16_t Set(D d, bfloat16_t arg) { + const RebindToUnsigned du; + return Set(du, BitCastScalar(arg)); +} +#endif // HWY_SVE_HAVE_BF16_FEATURE +#undef HWY_SVE_SET + +template +using VFromD = decltype(Set(D(), TFromD())); + +using VBF16 = VFromD>; + +// ------------------------------ MaskedSetOr/MaskedSet + +#define HWY_SVE_MASKED_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_T(BASE, BITS) op) { \ + return sv##OP##_##CHAR##BITS##_m(no, m, op); \ + } + +HWY_SVE_FOREACH(HWY_SVE_MASKED_SET_OR, MaskedSetOr, dup_n) +#undef HWY_SVE_MASKED_SET_OR + +#define HWY_SVE_MASKED_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + svbool_t m, HWY_SVE_T(BASE, BITS) op) { \ + return sv##OP##_##CHAR##BITS##_z(m, op); \ + } + +HWY_SVE_FOREACH(HWY_SVE_MASKED_SET, MaskedSet, dup_n) +#undef HWY_SVE_MASKED_SET + +// ------------------------------ Zero + +template +VFromD Zero(D d) { + // Cast to support bfloat16_t. 
+ const RebindToUnsigned du; + return BitCast(d, Set(du, 0)); +} + +// ------------------------------ BitCast + +namespace detail { + +// u8: no change +#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } + +// All other types +#define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u8_##CHAR##BITS(v); \ + } \ + template \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \ + return sv##OP##_##CHAR##BITS##_u8(v); \ + } + +// U08 is special-cased, hence do not use FOREACH. +HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _) +HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret) + +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CAST, _, reinterpret) +#else // !(HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC) +template )> +HWY_INLINE svuint8_t BitCastToByte(V v) { + const RebindToUnsigned> du; + return BitCastToByte(BitCast(du, v)); +} + +template +HWY_INLINE VFromD BitCastFromByte(D d, svuint8_t v) { + const RebindToUnsigned du; + return BitCastFromByte(du, v); +} +#endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC + +#undef HWY_SVE_CAST_NOP +#undef HWY_SVE_CAST + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Undefined + +#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return sv##OP##_##CHAR##BITS(); \ + } + +HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_UNDEFINED, Undefined, undef) +#endif + +template +VFromD Undefined(D d) { + const RebindToUnsigned du; + return BitCast(d, Undefined(du)); +} + +// ------------------------------ Tuple + +// tuples = f(d, v..), e.g. 
Create2 +#define HWY_SVE_CREATE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 2) \ + NAME##2(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1) { \ + return sv##OP##2_##CHAR##BITS(v0, v1); \ + } \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 3) NAME##3( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v0, \ + HWY_SVE_V(BASE, BITS) v1, HWY_SVE_V(BASE, BITS) v2) { \ + return sv##OP##3_##CHAR##BITS(v0, v1, v2); \ + } \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 4) \ + NAME##4(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3) { \ + return sv##OP##4_##CHAR##BITS(v0, v1, v2, v3); \ + } + +HWY_SVE_FOREACH(HWY_SVE_CREATE, Create, create) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CREATE, Create, create) +#endif +#undef HWY_SVE_CREATE + +template +using Vec2 = decltype(Create2(D(), Zero(D()), Zero(D()))); +template +using Vec3 = decltype(Create3(D(), Zero(D()), Zero(D()), Zero(D()))); +template +using Vec4 = decltype(Create4(D(), Zero(D()), Zero(D()), Zero(D()), Zero(D()))); + +#define HWY_SVE_GET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple) { \ + return sv##OP##2_##CHAR##BITS(tuple, kIndex); \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple) { \ + return sv##OP##3_##CHAR##BITS(tuple, kIndex); \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple) { \ + return sv##OP##4_##CHAR##BITS(tuple, kIndex); \ + } + +HWY_SVE_FOREACH(HWY_SVE_GET, Get, get) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_GET, Get, get) +#endif +#undef HWY_SVE_GET + +#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 2) \ + NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(BASE, BITS) vec) { \ + return sv##OP##2_##CHAR##BITS(tuple, kIndex, vec); \ + } \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 3) \ + NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple, HWY_SVE_V(BASE, BITS) vec) { \ + return sv##OP##3_##CHAR##BITS(tuple, kIndex, vec); \ + } \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 4) \ + NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple, HWY_SVE_V(BASE, BITS) vec) { \ + return sv##OP##4_##CHAR##BITS(tuple, kIndex, vec); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SET, Set, set) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_SET, Set, set) +#endif +#undef HWY_SVE_SET + +// ------------------------------ ResizeBitCast + +// Same as BitCast on SVE +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, v); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API svint8_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return svdupq_n_s8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, + t14, t15); +} + +template +HWY_API svuint8_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + 
TFromD t13, TFromD t14, + TFromD t15) { + return svdupq_n_u8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, + t14, t15); +} + +template +HWY_API svint16_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return svdupq_n_s16(t0, t1, t2, t3, t4, t5, t6, t7); +} + +template +HWY_API svuint16_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return svdupq_n_u16(t0, t1, t2, t3, t4, t5, t6, t7); +} + +template +HWY_API svfloat16_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, + TFromD t4, TFromD t5, + TFromD t6, TFromD t7) { + return svdupq_n_f16(t0, t1, t2, t3, t4, t5, t6, t7); +} + +template +HWY_API VBF16 Dup128VecFromValues(D d, TFromD t0, TFromD t1, TFromD t2, + TFromD t3, TFromD t4, TFromD t5, + TFromD t6, TFromD t7) { +#if HWY_SVE_HAVE_BF16_FEATURE + (void)d; + return svdupq_n_bf16(t0, t1, t2, t3, t4, t5, t6, t7); +#else + const RebindToUnsigned du; + return BitCast( + d, Dup128VecFromValues( + du, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +#endif +} + +template +HWY_API svint32_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return svdupq_n_s32(t0, t1, t2, t3); +} + +template +HWY_API svuint32_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return svdupq_n_u32(t0, t1, t2, t3); +} + +template +HWY_API svfloat32_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return svdupq_n_f32(t0, t1, t2, t3); +} + +template +HWY_API svint64_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return svdupq_n_s64(t0, t1); +} + +template +HWY_API svuint64_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return svdupq_n_u64(t0, t1); +} + +template +HWY_API svfloat64_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return svdupq_n_f64(t0, t1); +} + +// ------------------------------ GetLane + +namespace detail { +#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_T(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLaneM, lasta) +HWY_SVE_FOREACH(HWY_SVE_GET_LANE, ExtractLastMatchingLaneM, lastb) +#undef HWY_SVE_GET_LANE +} // namespace detail + +template +HWY_API TFromV GetLane(V v) { + return detail::GetLaneM(v, detail::PFalse()); +} + +// ================================================== LOGICAL + +// detail::*N() functions accept a scalar argument to avoid extra Set(). 
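Illustrative aside, not part of the upstream patch: to make the X-macro scheme concrete, this is roughly what one generated overload looks like. Applying the HWY_SVE_RETV_ARGPVV generator to the u32 row of HWY_SVE_FOREACH_U32 with the (And, and) arguments used in the logical section below, and assuming HWY_SVE_IS_POW2 so that HWY_SVE_PTRUE(32) is svptrue_b32(), the preprocessor emits approximately:

// Approximate expansion of HWY_SVE_RETV_ARGPVV(uint, u, 32, 16, And, and):
HWY_API svuint32_t And(svuint32_t a, svuint32_t b) {
  // _x form: inactive lanes are don't-care, which avoids an extra MOVPRFX.
  return svand_u32_x(svptrue_b32(), a, b);
}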
+ +// ------------------------------ Not +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not ) // NOLINT + +// ------------------------------ And + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and) + +template +HWY_API V And(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, And(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Or + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, OrN, orr_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr) + +template +HWY_API V Or(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Or(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ MaskedOr +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedOr, orr) + +// ------------------------------ Xor + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor) + +template +HWY_API V Xor(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Xor(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ AndNot + +namespace detail { +#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n) +#undef HWY_SVE_RETV_ARGPVN_SWAP +} // namespace detail + +#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic) +#undef HWY_SVE_RETV_ARGPVV_SWAP + +template +HWY_API V AndNot(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Xor3 + +#if HWY_SVE_HAVE_2 + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV, Xor3, eor3) + +template +HWY_API V Xor3(const V x1, const V x2, const V x3) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3))); +} + +#else +template +HWY_API V Xor3(V x1, V x2, V x3) { + return Xor(x1, Xor(x2, x3)); +} +#endif + +// ------------------------------ Or3 +template +HWY_API V Or3(V o1, V o2, V o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +// Need to return original type instead of unsigned. 
+#define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return BitCast(DFromV(), \ + sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt) +#undef HWY_SVE_POPCNT + +// ================================================== SIGN + +// ------------------------------ Neg +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg) + +HWY_API VBF16 Neg(VBF16 v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask()))); +} + +// ------------------------------ SaturatedNeg +#if HWY_SVE_HAVE_2 +#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32 +#undef HWY_NATIVE_SATURATED_NEG_8_16_32 +#else +#define HWY_NATIVE_SATURATED_NEG_8_16_32 +#endif + +#ifdef HWY_NATIVE_SATURATED_NEG_64 +#undef HWY_NATIVE_SATURATED_NEG_64 +#else +#define HWY_NATIVE_SATURATED_NEG_64 +#endif + +HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedNeg, qneg) +#endif // HWY_SVE_HAVE_2 + +// ================================================== ARITHMETIC + +// Per-target flags to prevent generic_ops-inl.h defining Add etc. +#ifdef HWY_NATIVE_OPERATOR_REPLACEMENTS +#undef HWY_NATIVE_OPERATOR_REPLACEMENTS +#else +#define HWY_NATIVE_OPERATOR_REPLACEMENTS +#endif + +// ------------------------------ Add + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n) +} // namespace detail + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add) + +// ------------------------------ Sub + +namespace detail { +// Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg. +#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_z(pg, a, b); \ + } + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n) +#undef HWY_SVE_RETV_ARGPVN_MASK +} // namespace detail + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub) + +// ------------------------------ SumsOf8 +HWY_API svuint64_t SumsOf8(const svuint8_t v) { + const ScalableTag du32; + const ScalableTag du64; + const svbool_t pg = detail::PTrue(du64); + + const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1); + // Compute pairwise sum of u32 and extend to u64. 
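Illustrative aside, not part of the upstream patch: the pairwise-sum step below, in scalar form. After the dot product, each u64 lane holds two u32 partial sums s0 (low half) and s1 (high half); shifting right by 32 isolates s1, zero-extending the low half isolates s0, and adding the two yields s0 + s1. The helper name is hypothetical.

#include <cstdint>

static uint64_t PairwiseSumU32(uint64_t lane) {
  const uint64_t hi = lane >> 32;          // s1, like svlsr_n_u64_x(..., 32)
  const uint64_t lo = lane & 0xFFFFFFFFu;  // s0, like svextw_u64_x
  return hi + lo;                          // s0 + s1; cannot overflow u64 here
}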
+ +#if HWY_SVE_HAVE_2 + return svadalp_u64_x(pg, Zero(du64), sums_of_4); +#else + const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32); + // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended) + const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4)); + return Add(hi, lo); +#endif +} + +HWY_API svint64_t SumsOf8(const svint8_t v) { + const ScalableTag di32; + const ScalableTag di64; + const svbool_t pg = detail::PTrue(di64); + + const svint32_t sums_of_4 = svdot_n_s32(Zero(di32), v, 1); +#if HWY_SVE_HAVE_2 + return svadalp_s64_x(pg, Zero(di64), sums_of_4); +#else + const svint64_t hi = svasr_n_s64_x(pg, BitCast(di64, sums_of_4), 32); + // Isolate the lower 32 bits (to be added to the upper 32 and sign-extended) + const svint64_t lo = svextw_s64_x(pg, BitCast(di64, sums_of_4)); + return Add(hi, lo); +#endif +} + +// ------------------------------ SumsOf2 +#if HWY_SVE_HAVE_2 +namespace detail { + +HWY_INLINE svint16_t SumsOf2(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) { + const ScalableTag di16; + const svbool_t pg = detail::PTrue(di16); + return svadalp_s16_x(pg, Zero(di16), v); +} + +HWY_INLINE svuint16_t SumsOf2(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) { + const ScalableTag du16; + const svbool_t pg = detail::PTrue(du16); + return svadalp_u16_x(pg, Zero(du16), v); +} + +HWY_INLINE svint32_t SumsOf2(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) { + const ScalableTag di32; + const svbool_t pg = detail::PTrue(di32); + return svadalp_s32_x(pg, Zero(di32), v); +} + +HWY_INLINE svuint32_t SumsOf2(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) { + const ScalableTag du32; + const svbool_t pg = detail::PTrue(du32); + return svadalp_u32_x(pg, Zero(du32), v); +} + +HWY_INLINE svint64_t SumsOf2(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, svint32_t v) { + const ScalableTag di64; + const svbool_t pg = detail::PTrue(di64); + return svadalp_s64_x(pg, Zero(di64), v); +} + +HWY_INLINE svuint64_t SumsOf2(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, svuint32_t v) { + const ScalableTag du64; + const svbool_t pg = detail::PTrue(du64); + return svadalp_u64_x(pg, Zero(du64), v); +} + +} // namespace detail +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ SumsOf4 +namespace detail { + +HWY_INLINE svint32_t SumsOf4(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) { + return svdot_n_s32(Zero(ScalableTag()), v, 1); +} + +HWY_INLINE svuint32_t SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) { + return svdot_n_u32(Zero(ScalableTag()), v, 1); +} + +HWY_INLINE svint64_t SumsOf4(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) { + return svdot_n_s64(Zero(ScalableTag()), v, 1); +} + +HWY_INLINE svuint64_t SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) { + return svdot_n_u64(Zero(ScalableTag()), v, 1); +} + +} // namespace detail + +// ------------------------------ SaturatedAdd + +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else 
+#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB +#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U64_SATURATED_ADDSUB +#endif + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd) + +// ------------------------------ SaturatedSub + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub) + +// ------------------------------ AbsDiff +#ifdef HWY_NATIVE_INTEGER_ABS_DIFF +#undef HWY_NATIVE_INTEGER_ABS_DIFF +#else +#define HWY_NATIVE_INTEGER_ABS_DIFF +#endif + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, AbsDiff, abd) + +// ------------------------------ ShiftLeft[Same] + +#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits); \ + } \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME##Same(HWY_SVE_V(BASE, BITS) v, int bits) { \ + return sv##OP##_##CHAR##BITS##_x( \ + HWY_SVE_PTRUE(BITS), v, static_cast(bits)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n) + +// ------------------------------ ShiftRight[Same] + +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n) + +#undef HWY_SVE_SHIFT_N + +// ------------------------------ MaskedShift[Left/Right] + +#define HWY_SVE_SHIFT_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ + auto shifts = static_cast(kBits); \ + return sv##OP##_##CHAR##BITS##_z(m, v, shifts); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_Z, MaskedShiftLeft, lsl_n) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_Z, MaskedShiftRight, asr_n) +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_Z, MaskedShiftRight, lsr_n) + +#undef HWY_SVE_SHIFT_Z + +// ------------------------------ MaskedShiftRightOr + +#define HWY_SVE_SHIFT_OR(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ + auto shifts = static_cast(kBits); \ + return svsel##_##CHAR##BITS(m, sv##OP##_##CHAR##BITS##_z(m, v, shifts), \ + no); \ + } +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, asr_n) +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, lsr_n) + +#undef HWY_SVE_SHIFT_OR + +// ------------------------------ RotateRight + +#if HWY_SVE_HAVE_2 + +#define HWY_SVE_ROTATE_RIGHT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + if (kBits == 0) return v; \ + return sv##OP##_##CHAR##BITS(v, Zero(DFromV()), \ + HWY_MAX(kBits, 1)); \ + } + +HWY_SVE_FOREACH_U(HWY_SVE_ROTATE_RIGHT_N, RotateRight, xar_n) +HWY_SVE_FOREACH_I(HWY_SVE_ROTATE_RIGHT_N, RotateRight, xar_n) + +#undef HWY_SVE_ROTATE_RIGHT_N + +#else // !HWY_SVE_HAVE_2 +template +HWY_API V RotateRight(const V v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} +#endif + +// ------------------------------ Shl, Shr + +#define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ + const RebindToUnsigned> du; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ + BitCast(du, bits)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl) + 
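+// Editorial note (illustrative, not part of the upstream file): HWY_SVE_SHIFT
+// takes a per-lane count vector of the same lane type; the BitCast to unsigned
+// matches svlsl/svlsr/svasr, which expect unsigned shift amounts. Sketch:
+//   const ScalableTag<int32_t> d;
+//   const auto v = Iota(d, 1);            // 1, 2, 3, ...
+//   const auto counts = Set(d, 3);        // shift every lane left by 3
+//   const auto shifted = Shl(v, counts);  // 8, 16, 24, ...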
+HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr) + +#undef HWY_SVE_SHIFT + +// ------------------------------ RoundingShiftLeft[Same]/RoundingShr + +#if HWY_SVE_HAVE_2 + +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +#define HWY_SVE_ROUNDING_SHR_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + HWY_IF_CONSTEXPR(kBits == 0) { return v; } \ + \ + return sv##OP##_##CHAR##BITS##_x( \ + HWY_SVE_PTRUE(BITS), v, static_cast(HWY_MAX(kBits, 1))); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ROUNDING_SHR_N, RoundingShiftRight, rshr_n) + +#undef HWY_SVE_ROUNDING_SHR_N + +#define HWY_SVE_ROUNDING_SHR(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ + const RebindToSigned> di; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ + Neg(BitCast(di, bits))); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ROUNDING_SHR, RoundingShr, rshl) + +#undef HWY_SVE_ROUNDING_SHR + +template +HWY_API V RoundingShiftRightSame(V v, int bits) { + const DFromV d; + using T = TFromD; + return RoundingShr(v, Set(d, static_cast(bits))); +} + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ BroadcastSignBit (ShiftRight) +template +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight) * 8 - 1>(v); +} + +// ------------------------------ Abs (ShiftRight, Add, Xor, AndN) + +// Workaround for incorrect results with `svabs`. +#if HWY_COMPILER_CLANG +template +HWY_API V Abs(V v) { + const V sign = BroadcastSignBit(v); + return Xor(Add(v, sign), sign); +} + +template +HWY_NOINLINE V Abs(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = MakeUnsigned>; + return BitCast( + d, detail::AndN(BitCast(du, v), static_cast(~SignMask()))); +} + +#else +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs) +#endif + +// ------------------------------ SaturatedAbs +#if HWY_SVE_HAVE_2 +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs) +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ MaskedAbsOr +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_M, MaskedAbsOr, abs) + +// ------------------------------ MaskedAbs +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_Z, MaskedAbs, abs) + +// ------------------------------ Mul + +// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Mul, mul) + +// ------------------------------ MulHigh +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) + +// ------------------------------ MulFixedPoint15 +HWY_API svint16_t MulFixedPoint15(svint16_t a, svint16_t b) { +#if HWY_SVE_HAVE_2 + return svqrdmulh_s16(a, b); +#else + const DFromV d; + const RebindToUnsigned du; + + const svuint16_t lo = BitCast(du, Mul(a, b)); + const svint16_t hi = MulHigh(a, b); + // We want (lo + 0x4000) >> 15, but that can overflow, and if it does we must + // carry that into the result. Instead isolate the top two bits because only + // they can influence the result. + const svuint16_t lo_top2 = ShiftRight<14>(lo); + // Bits 11: add 2, 10: add 1, 01: add 1, 00: add 0. 
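+  // Editorial note: the desired carry is (lo + 0x4000) >> 15, which depends
+  // only on lo's top two bits. Mapping lo_top2 = 0,1,2,3 through
+  // (lo_top2 + 1) >> 1 gives 0,1,1,2, the same carry into 2*hi, without the
+  // 16-bit overflow that computing lo + 0x4000 directly could cause.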
+ const svuint16_t rounding = ShiftRight<1>(detail::AddN(lo_top2, 1)); + return Add(Add(hi, hi), BitCast(d, rounding)); +#endif +} + +// ------------------------------ Div +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, Div, div) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPVV, Div, div) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div) + +// ------------------------------ ApproximateReciprocal +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe) + +// ------------------------------ Sqrt +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt) + +// ------------------------------ MaskedSqrt +#ifdef HWY_NATIVE_MASKED_SQRT +#undef HWY_NATIVE_MASKED_SQRT +#else +#define HWY_NATIVE_MASKED_SQRT +#endif + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_Z, MaskedSqrt, sqrt) + +// ------------------------------ ApproximateReciprocalSqrt +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte) + +// ------------------------------ MulAdd + +// Per-target flag to prevent generic_ops-inl.h from defining int MulAdd. +#ifdef HWY_NATIVE_INT_FMA +#undef HWY_NATIVE_INT_FMA +#else +#define HWY_NATIVE_INT_FMA +#endif + +#define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \ + HWY_SVE_V(BASE, BITS) add) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \ + } + +HWY_SVE_FOREACH(HWY_SVE_FMA, MulAdd, mad) + +// ------------------------------ NegMulAdd +HWY_SVE_FOREACH(HWY_SVE_FMA, NegMulAdd, msb) + +// ------------------------------ MulSub +HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb) + +// ------------------------------ NegMulSub +HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad) + +#undef HWY_SVE_FMA + +// ------------------------------ Round etc. + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz) + +// ================================================== MASK + +// ------------------------------ RebindMask +template +HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) { + return mask; +} + +// ------------------------------ Mask logical + +HWY_API svbool_t Not(svbool_t m) { + // We don't know the lane type, so assume 8-bit. For larger types, this will + // de-canonicalize the predicate, i.e. set bits to 1 even though they do not + // correspond to the lowest byte in the lane. Arm says such bits are ignored. + return svnot_b_z(HWY_SVE_PTRUE(8), m); +} +HWY_API svbool_t And(svbool_t a, svbool_t b) { + return svand_b_z(b, b, a); // same order as AndNot for consistency +} +HWY_API svbool_t AndNot(svbool_t a, svbool_t b) { + return svbic_b_z(b, b, a); // reversed order like NEON +} +HWY_API svbool_t Or(svbool_t a, svbool_t b) { + return svsel_b(a, a, b); // a ? true : b +} +HWY_API svbool_t Xor(svbool_t a, svbool_t b) { + return svsel_b(a, svnand_b_z(a, a, b), b); // a ? !(a & b) : b. +} + +HWY_API svbool_t ExclusiveNeither(svbool_t a, svbool_t b) { + return svnor_b_z(HWY_SVE_PTRUE(8), a, b); // !a && !b, undefined if a && b. 
+} + +// ------------------------------ CountTrue + +#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \ + return sv##OP##_b##BITS(detail::MakeMask(d), m); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp) +#undef HWY_SVE_COUNT_TRUE + +// For 16-bit Compress: full vector, not limited to SV_POW2. +namespace detail { + +#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \ + return sv##OP##_b##BITS(svptrue_b##BITS(), m); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp) +#undef HWY_SVE_COUNT_TRUE_FULL + +} // namespace detail + +// ------------------------------ AllFalse +template +HWY_API bool AllFalse(D d, svbool_t m) { + return !svptest_any(detail::MakeMask(d), m); +} + +// ------------------------------ AllTrue +template +HWY_API bool AllTrue(D d, svbool_t m) { + return CountTrue(d, m) == Lanes(d); +} + +// ------------------------------ FindFirstTrue +template +HWY_API intptr_t FindFirstTrue(D d, svbool_t m) { + return AllFalse(d, m) ? intptr_t{-1} + : static_cast( + CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m))); +} + +// ------------------------------ FindKnownFirstTrue +template +HWY_API size_t FindKnownFirstTrue(D d, svbool_t m) { + return CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m)); +} + +// ------------------------------ IfThenElse +#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \ + return sv##OP##_##CHAR##BITS(m, yes, no); \ + } + +HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) +HWY_SVE_FOREACH_BF16(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) +#undef HWY_SVE_IF_THEN_ELSE + +template , HWY_SVE_IF_EMULATED_D(D)> +HWY_API V IfThenElse(const svbool_t mask, V yes, V no) { + const RebindToUnsigned du; + return BitCast( + D(), IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +// ------------------------------ IfThenElseZero + +template , HWY_SVE_IF_NOT_EMULATED_D(D)> +HWY_API V IfThenElseZero(const svbool_t mask, const V yes) { + return IfThenElse(mask, yes, Zero(D())); +} + +template , HWY_SVE_IF_EMULATED_D(D)> +HWY_API V IfThenElseZero(const svbool_t mask, V yes) { + const RebindToUnsigned du; + return BitCast(D(), IfThenElseZero(RebindMask(du, mask), BitCast(du, yes))); +} + +// ------------------------------ IfThenZeroElse + +template , HWY_SVE_IF_NOT_EMULATED_D(D)> +HWY_API V IfThenZeroElse(const svbool_t mask, const V no) { + return IfThenElse(mask, Zero(D()), no); +} + +template , HWY_SVE_IF_EMULATED_D(D)> +HWY_API V IfThenZeroElse(const svbool_t mask, V no) { + const RebindToUnsigned du; + return BitCast(D(), IfThenZeroElse(RebindMask(du, mask), BitCast(du, no))); +} + +// ------------------------------ Additional mask logical operations +HWY_API svbool_t SetBeforeFirst(svbool_t m) { + // We don't know the lane type, so assume 8-bit. For larger types, this will + // de-canonicalize the predicate, i.e. set bits to 1 even though they do not + // correspond to the lowest byte in the lane. Arm says such bits are ignored. + return svbrkb_b_z(HWY_SVE_PTRUE(8), m); +} + +HWY_API svbool_t SetAtOrBeforeFirst(svbool_t m) { + // We don't know the lane type, so assume 8-bit. For larger types, this will + // de-canonicalize the predicate, i.e. 
set bits to 1 even though they do not + // correspond to the lowest byte in the lane. Arm says such bits are ignored. + return svbrka_b_z(HWY_SVE_PTRUE(8), m); +} + +HWY_API svbool_t SetOnlyFirst(svbool_t m) { return svbrka_b_z(m, m); } + +HWY_API svbool_t SetAtOrAfterFirst(svbool_t m) { + return Not(SetBeforeFirst(m)); +} + +// ------------------------------ PromoteMaskTo + +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +template ) * 2)> +HWY_API svbool_t PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svunpklo_b(m); +} + +template ) * 2)> +HWY_API svbool_t PromoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) { + using TFrom = TFromD; + using TWFrom = MakeWide>; + static_assert(sizeof(TWFrom) > sizeof(TFrom), + "sizeof(TWFrom) > sizeof(TFrom) must be true"); + + const Rebind dw_from; + return PromoteMaskTo(d_to, dw_from, PromoteMaskTo(dw_from, d_from, m)); +} + +// ------------------------------ DemoteMaskTo + +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif + +template +HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svuzp1_b8(m, m); +} + +template +HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svuzp1_b16(m, m); +} + +template +HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svuzp1_b32(m, m); +} + +template ) / 4)> +HWY_API svbool_t DemoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) { + using TFrom = TFromD; + using TNFrom = MakeNarrow>; + static_assert(sizeof(TNFrom) < sizeof(TFrom), + "sizeof(TNFrom) < sizeof(TFrom) must be true"); + + const Rebind dn_from; + return DemoteMaskTo(d_to, dn_from, DemoteMaskTo(dn_from, d_from, m)); +} + +// ------------------------------ LowerHalfOfMask +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template +HWY_API svbool_t LowerHalfOfMask(D /*d*/, svbool_t m) { + return m; +} + +// ------------------------------ MaskedAddOr etc. 
(IfThenElse) + +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMin, min) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMax, max) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedAdd, add) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedSub, sub) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV, MaskedSqrt, sqrt) +#if HWY_SVE_HAVE_2 +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatAdd, qadd) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatSub, qsub) +#endif +} // namespace detail + +template +HWY_API V MaskedMinOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedMin(m, a, b), no); +} + +template +HWY_API V MaskedMaxOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedMax(m, a, b), no); +} + +template +HWY_API V MaskedAddOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedAdd(m, a, b), no); +} + +template +HWY_API V MaskedSubOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedSub(m, a, b), no); +} + +template +HWY_API V MaskedMulOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedMul(m, a, b), no); +} + +template , hwy::float16_t>() ? (1 << 2) : 0) | + (1 << 4) | (1 << 8))> +HWY_API V MaskedDivOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedDiv(m, a, b), no); +} + +// I8/U8/I16/U16 MaskedDivOr is implemented after I8/U8/I16/U16 Div + +#if HWY_SVE_HAVE_2 +template +HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedSatAdd(m, a, b), no); +} + +template +HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedSatSub(m, a, b), no); +} +#else +template +HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedAdd(a, b), no); +} + +template +HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedSub(a, b), no); +} +#endif + +// ------------------------------ MaskedMulAddOr +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV, MaskedMulAdd, mad) +} + +// Per-target flag to prevent generic_ops-inl.h from defining int +// MaskedMulAddOr. +#ifdef HWY_NATIVE_MASKED_INT_FMA +#undef HWY_NATIVE_MASKED_INT_FMA +#else +#define HWY_NATIVE_MASKED_INT_FMA +#endif + +template +HWY_API V MaskedMulAddOr(V no, M m, V mul, V x, V add) { + return IfThenElse(m, detail::MaskedMulAdd(m, mul, x, add), no); +} + +template +HWY_API V MaskedSqrtOr(V no, M m, V v) { + return IfThenElse(m, detail::MaskedSqrt(m, v), no); +} + +// ================================================== REDUCE + +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +// These return T, suitable for ReduceSum. +namespace detail { +#define HWY_SVE_REDUCE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ + /* The intrinsic returns [u]int64_t; truncate to T so we can broadcast. 
*/ \ + using T = HWY_SVE_T(BASE, BITS); \ + using TU = MakeUnsigned; \ + constexpr uint64_t kMask = LimitsMax(); \ + return static_cast(static_cast( \ + static_cast(sv##OP##_##CHAR##BITS(pg, v)) & kMask)); \ + } + +#define HWY_SVE_REDUCE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(pg, v); \ + } + +// TODO: Remove SumOfLanesM in favor of using MaskedReduceSum +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv) +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv) + +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanesM, minv) +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanesM, maxv) +// NaN if all are +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanesM, minnmv) +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanesM, maxnmv) + +#undef HWY_SVE_REDUCE +#undef HWY_SVE_REDUCE_ADD +} // namespace detail + +// detail::SumOfLanesM, detail::MinOfLanesM, and detail::MaxOfLanesM is more +// efficient for N=4 I8/U8 reductions on SVE than the default implementations +// of the N=4 I8/U8 ReduceSum/ReduceMin/ReduceMax operations in +// generic_ops-inl.h +#undef HWY_IF_REDUCE_D +#define HWY_IF_REDUCE_D(D) hwy::EnableIf* = nullptr + +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +#ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#undef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#else +#define HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#endif + +template +HWY_API TFromD ReduceSum(D d, VFromD v) { + return detail::SumOfLanesM(detail::MakeMask(d), v); +} + +template +HWY_API TFromD ReduceMin(D d, VFromD v) { + return detail::MinOfLanesM(detail::MakeMask(d), v); +} + +template +HWY_API TFromD ReduceMax(D d, VFromD v) { + return detail::MaxOfLanesM(detail::MakeMask(d), v); +} + +#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR +#undef HWY_NATIVE_MASKED_REDUCE_SCALAR +#else +#define HWY_NATIVE_MASKED_REDUCE_SCALAR +#endif + +template +HWY_API TFromD MaskedReduceSum(D /*d*/, M m, VFromD v) { + return detail::SumOfLanesM(m, v); +} +template +HWY_API TFromD MaskedReduceMin(D /*d*/, M m, VFromD v) { + return detail::MinOfLanesM(m, v); +} +template +HWY_API TFromD MaskedReduceMax(D /*d*/, M m, VFromD v) { + return detail::MaxOfLanesM(m, v); +} + +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); +} + +// ------------------------------ MaskedAdd etc. 
(IfThenElse) + +#ifdef HWY_NATIVE_ZERO_MASKED_ARITH +#undef HWY_NATIVE_ZERO_MASKED_ARITH +#else +#define HWY_NATIVE_ZERO_MASKED_ARITH +#endif + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedMax, max) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedAdd, add) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedSub, sub) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedMul, mul) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV_Z, MaskedMulAdd, mad) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV_Z, MaskedNegMulAdd, msb) + +// I8/U8/I16/U16 MaskedDiv is implemented after I8/U8/I16/U16 Div + +#if HWY_SVE_HAVE_2 +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedSaturatedAdd, qadd) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedSaturatedSub, qsub) +#else +template +HWY_API V MaskedSaturatedAdd(M m, V a, V b) { + return IfThenElseZero(m, SaturatedAdd(a, b)); +} + +template +HWY_API V MaskedSaturatedSub(M m, V a, V b) { + return IfThenElseZero(m, SaturatedSub(a, b)); +} +#endif + +template , HWY_IF_I16_D(D)> +HWY_API V MaskedMulFixedPoint15(M m, V a, V b) { + return IfThenElseZero(m, MulFixedPoint15(a, b)); +} + +template >> +HWY_API VFromD MaskedWidenMulPairwiseAdd(D d32, M m, V16 a, V16 b) { + return IfThenElseZero(m, WidenMulPairwiseAdd(d32, a, b)); +} + +template +HWY_API VFromD MaskedWidenMulPairwiseAdd(DF df, M m, VBF a, VBF b) { + return IfThenElseZero(m, WidenMulPairwiseAdd(df, a, b)); +} + +// ================================================== COMPARE + +// mask = f(vector, vector) +#define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ + } + +// ------------------------------ Eq +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n) +} // namespace detail + +// ------------------------------ Ne +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n) +} // namespace detail + +// ------------------------------ Lt +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n) +} // namespace detail + +// ------------------------------ Le +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Le, cmple) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LeN, cmple_n) +} // namespace detail + +// ------------------------------ Gt/Ge (swapped order) +template +HWY_API svbool_t Gt(const V a, const V b) { + return Lt(b, a); +} +template +HWY_API svbool_t Ge(const V a, const V b) { + return Le(b, a); +} +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GeN, cmpge_n) +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GtN, cmpgt_n) +} // namespace detail + +#undef HWY_SVE_COMPARE +#undef HWY_SVE_COMPARE_N + +// ------------------------------ TestBit +template +HWY_API svbool_t TestBit(const V a, const V bit) { + return detail::NeN(And(a, bit), 0); +} + +// ------------------------------ Min/Max (Lt, IfThenElse) + +HWY_SVE_FOREACH_U(HWY_SVE_RETV_ARGPVV, Min, min) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm) + 
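+// Editorial note: the float Max above (and float Min in the non-workaround
+// path below) use the FMAXNM/FMINNM forms, which follow IEEE maxNum/minNum:
+// if exactly one operand is a quiet NaN, the numeric operand is returned.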
+// Workaround for incorrect results with `svmin`. +#if HWY_COMPILER_CLANG +template +HWY_API V Min(V a, V b) { + return IfThenElse(Lt(a, b), a, b); +} +template +HWY_API V Min(V a, V b) { + return IfThenElse(Lt(a, b), a, b); +} + +#else +HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPVV, Min, min) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm) +#endif + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n) +} // namespace detail + +// ================================================== SWIZZLE + +// ------------------------------ ConcatEven/ConcatOdd + +// WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the +// full vector length, not rounded down to a power of two as we require). +namespace detail { + +#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi); \ + } +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, uzp1) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, uzp2) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, + uzp1) +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, + uzp2) +#endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenBlocks, uzp1q) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, uzp2q) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, + ConcatEvenBlocks, uzp1q) +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, + uzp2q) +#endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +#endif // defined(__ARM_FEATURE_SVE_MATMUL_FP64) +#undef HWY_SVE_CONCAT_EVERY_SECOND + +// Used to slide up / shift whole register left; mask indicates which range +// to take from lo, and the rest is filled from hi starting at its lowest. 
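+// For example (editorial illustration): with mask = FirstN(d, 3) and lanes
+// lo = {l0,l1,l2,l3,...}, hi = {h0,h1,...}, the result is {l0,l1,l2,h0,h1,...}.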
+#define HWY_SVE_SPLICE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, lo, hi); \ + } +HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice) +#if HWY_SVE_HAVE_BF16_FEATURE +HWY_SVE_FOREACH_BF16(HWY_SVE_SPLICE, Splice, splice) +#else +template )> +HWY_INLINE V Splice(V hi, V lo, svbool_t mask) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Splice(BitCast(du, hi), BitCast(du, lo), mask)); +} +#endif // HWY_SVE_HAVE_BF16_FEATURE +#undef HWY_SVE_SPLICE + +} // namespace detail + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { +#if HWY_SVE_IS_POW2 + if (detail::IsFull(d)) return detail::ConcatOddFull(hi, lo); +#endif + const VFromD hi_odd = detail::ConcatOddFull(hi, hi); + const VFromD lo_odd = detail::ConcatOddFull(lo, lo); + return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_SVE_IS_POW2 + if (detail::IsFull(d)) return detail::ConcatEvenFull(hi, lo); +#endif + const VFromD hi_odd = detail::ConcatEvenFull(hi, hi); + const VFromD lo_odd = detail::ConcatEvenFull(lo, lo); + return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); +} + +HWY_API svuint8_t U8FromU32(const svuint32_t v) { + const DFromV du32; + const RepartitionToNarrow du16; + const RepartitionToNarrow du8; + + const svuint16_t cast16 = BitCast(du16, v); + const svuint16_t x2 = svuzp1_u16(cast16, cast16); + const svuint8_t cast8 = BitCast(du8, x2); + return svuzp1_u8(cast8, cast8); +} + +// ================================================== MASK + +// ------------------------------ MaskFromVec (Ne) +template +HWY_API svbool_t MaskFromVec(const V v) { + using T = TFromV; + return detail::NeN(v, ConvertScalarTo(0)); +} + +// ------------------------------ VecFromMask +template +HWY_API VFromD VecFromMask(const D d, svbool_t mask) { + const RebindToSigned di; + // This generates MOV imm, whereas svdup_n_s8_z generates MOV scalar, which + // requires an extra instruction plus M0 pipeline. + return BitCast(d, IfThenElseZero(mask, Set(di, -1))); +} + +// ------------------------------ BitsFromMask (AndN, Shl, ReduceSum, GetLane +// ConcatEvenFull, U8FromU32) + +namespace detail { + +// For each mask lane (governing lane type T), store 1 or 0 in BYTE lanes. +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + return svdup_n_u8_z(m, 1); +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag d8; + const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1)); + return detail::ConcatEvenFull(b16, b16); // lower half +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + return U8FromU32(svdup_n_u32_z(m, 1)); +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag d32; + const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1)); + return U8FromU32(detail::ConcatEvenFull(b64, b64)); // lower half +} + +// Compacts groups of 8 u8 into 8 contiguous bits in a 64-bit lane. +HWY_INLINE svuint64_t BitsFromBool(svuint8_t x) { + const ScalableTag d8; + const ScalableTag d16; + const ScalableTag d32; + const ScalableTag d64; + // TODO(janwas): could use SVE2 BDEP, but it's optional. 
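+  // Editorial note: each byte of x holds 0 or 1. ORing x with copies shifted
+  // right by 7 (within u16), 14 (within u32) and 28 (within u64) packs 2, 4
+  // and finally 8 of those bits into the low byte of each u64 lane, so bit i
+  // of that byte is the bool that came from byte i.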
+ x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x)))); + x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x)))); + x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x)))); + return BitCast(d64, x); +} + +} // namespace detail + +// BitsFromMask is required if `HWY_MAX_BYTES <= 64`, which is true for the +// fixed-size SVE targets. +#if HWY_TARGET == HWY_SVE2_128 || HWY_TARGET == HWY_SVE_256 +template +HWY_API uint64_t BitsFromMask(D d, svbool_t mask) { + const Repartition du64; + svuint64_t bits_in_u64 = detail::BitsFromBool(detail::BoolFromMask(mask)); + + constexpr size_t N = MaxLanes(d); + static_assert(N < 64, "SVE2_128 and SVE_256 are only 128 or 256 bits"); + const uint64_t valid = (1ull << N) - 1; + HWY_IF_CONSTEXPR(N <= 8) { + // Upper bits are undefined even if N == 8, hence mask. + return GetLane(bits_in_u64) & valid; + } + + // Up to 8 of the least-significant bits of each u64 lane are valid. + bits_in_u64 = detail::AndN(bits_in_u64, 0xFF); + + // 128-bit vector: only two u64, so avoid ReduceSum. + HWY_IF_CONSTEXPR(HWY_TARGET == HWY_SVE2_128) { + alignas(16) uint64_t lanes[2]; + Store(bits_in_u64, du64, lanes); + // lanes[0] is always valid because we know N > 8, but lanes[1] might + // not be - we may mask it out below. + const uint64_t result = lanes[0] + (lanes[1] << 8); + // 8-bit lanes, no further masking + HWY_IF_CONSTEXPR(N == 16) return result; + return result & valid; + } + + // Shift the 8-bit groups into place in each u64 lane. + alignas(32) uint64_t kShifts[4] = {0 * 8, 1 * 8, 2 * 8, 3 * 8}; + bits_in_u64 = Shl(bits_in_u64, Load(du64, kShifts)); + return ReduceSum(du64, bits_in_u64) & valid; +} + +#endif // HWY_TARGET == HWY_SVE2_128 || HWY_TARGET == HWY_SVE_256 + +// ------------------------------ IsNegative (Lt) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +template +HWY_API svbool_t IsNegative(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + + return detail::LtN(BitCast(di, v), static_cast(0)); +} + +// ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse) + +#if HWY_SVE_HAVE_2 + +#define HWY_SVE_IF_VEC(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) yes, \ + HWY_SVE_V(BASE, BITS) no) { \ + return sv##OP##_##CHAR##BITS(yes, no, mask); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_IF_VEC, IfVecThenElse, bsl) +#undef HWY_SVE_IF_VEC + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, IfVecThenElse(BitCast(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +#else + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ BitwiseIfThenElse + +#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return IfVecThenElse(mask, yes, no); +} + +// ------------------------------ CopySign (BitwiseIfThenElse) +template +HWY_API V CopySign(const V magn, const V sign) { + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API V CopySignToAbs(const V abs, const V sign) { +#if HWY_SVE_HAVE_2 // CopySign is more efficient than OrAnd + return 
CopySign(abs, sign); +#else + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +#endif +} + +// ------------------------------ Floating-point classification (Ne) + +template +HWY_API svbool_t IsNaN(const V v) { + return Ne(v, v); // could also use cmpuo +} + +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. +// We use a fused Set/comparison for IsFinite. +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +template +HWY_API svbool_t IsInf(const V v) { + using T = TFromV; + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + + // 'Shift left' to clear the sign bit + const VFromD vu = BitCast(du, v); + const VFromD v2 = Add(vu, vu); + // Check for exponent=max and mantissa=0. + const VFromD max2 = Set(di, hwy::MaxExponentTimes2()); + return RebindMask(d, Eq(v2, BitCast(du, max2))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API svbool_t IsFinite(const V v) { + using T = TFromV; + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField())); +} + +// ------------------------------ MulByPow2/MulByFloorPow2 + +#define HWY_SVE_MUL_BY_POW2(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(int, BITS) exp) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, exp); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_MUL_BY_POW2, MulByPow2, scale) + +#undef HWY_SVE_MUL_BY_POW2 + +// ------------------------------ MaskedEq etc. 
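+// Editorial note: these comparisons are governed by m, so lanes where m is
+// false always compare as false. Illustrative sketch:
+//   const ScalableTag<int32_t> d;
+//   const auto a = Iota(d, 0);   // 0, 1, 2, ...
+//   const auto b = Set(d, 1);
+//   const svbool_t lt = MaskedLt(FirstN(d, 2), a, b);  // true only in lane 0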
+#ifdef HWY_NATIVE_MASKED_COMP +#undef HWY_NATIVE_MASKED_COMP +#else +#define HWY_NATIVE_MASKED_COMP +#endif + +// mask = f(mask, vector, vector) +#define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, \ + HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(m, a, b); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq) +HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne) +HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt) +HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple) + +#undef HWY_SVE_COMPARE_Z + +template > +HWY_API MFromD MaskedGt(M m, V a, V b) { + // Swap args to reverse comparison + return MaskedLt(m, b, a); +} + +template > +HWY_API MFromD MaskedGe(M m, V a, V b) { + // Swap args to reverse comparison + return MaskedLe(m, b, a); +} + +template > +HWY_API MFromD MaskedIsNaN(const M m, const V v) { + return MaskedNe(m, v, v); +} + +// ================================================== MEMORY + +// ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream + +#define HWY_SVE_MEM(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + LoadU(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return svld1_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(p)); \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + MaskedLoad(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return svld1_##CHAR##BITS(m, detail::NativeLanePointer(p)); \ + } \ + template \ + HWY_API void StoreU(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + svst1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), v); \ + } \ + template \ + HWY_API void Stream(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + svstnt1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), \ + v); \ + } \ + template \ + HWY_API void BlendedStore(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + svst1_##CHAR##BITS(m, detail::NativeLanePointer(p), v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_MEM, _, _) +HWY_SVE_FOREACH_BF16(HWY_SVE_MEM, _, _) + +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadU(du, detail::U16LanePointer(p))); +} + +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + StoreU(BitCast(du, v), du, detail::U16LanePointer(p)); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, + MaskedLoad(RebindMask(du, m), du, detail::U16LanePointer(p))); +} + +// MaskedLoadOr is generic and does not require emulation. + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + BlendedStore(BitCast(du, v), RebindMask(du, m), du, + detail::U16LanePointer(p)); +} + +#undef HWY_SVE_MEM + +#if HWY_TARGET != HWY_SVE2_128 +namespace detail { +#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + /* All-true predicate to load all 128 bits. 
*/ \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), \ + detail::NativeLanePointer(p)); \ + } + +HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq) + +template +HWY_API VFromD LoadDupFull128(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadDupFull128(du, detail::U16LanePointer(p))); +} + +} // namespace detail +#endif // HWY_TARGET != HWY_SVE2_128 + +#if HWY_TARGET == HWY_SVE2_128 +// On the HWY_SVE2_128 target, LoadDup128 is the same as LoadU since vectors +// cannot exceed 16 bytes on the HWY_SVE2_128 target. +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} +#else // HWY_TARGET != HWY_SVE2_128 +// If D().MaxBytes() <= 16 is true, simply do a LoadU operation. +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// If D().MaxBytes() > 16 is true, need to load the vector using ld1rq +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return detail::LoadDupFull128(d, p); +} + +#endif // HWY_TARGET != HWY_SVE2_128 + +// Truncate to smaller size and store +#ifdef HWY_NATIVE_STORE_TRUNCATED +#undef HWY_NATIVE_STORE_TRUNCATED +#else +#define HWY_NATIVE_STORE_TRUNCATED +#endif + +#define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + const HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \ + sv##OP##_##CHAR##BITS(detail::PTrue(d), detail::NativeLanePointer(p), v); \ + } + +#define HWY_SVE_STORE_TRUNCATED_BYTE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 8) +#define HWY_SVE_STORE_TRUNCATED_HALF(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 16) +#define HWY_SVE_STORE_TRUNCATED_WORD(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 32) + +HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b) +HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b) +HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b) +HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h) +HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h) +HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, TruncateStore, st1w) + +#undef HWY_SVE_STORE_TRUNCATED + +// ------------------------------ Load/Store + +// SVE only requires lane alignment, not natural alignment of the entire +// vector, so Load/Store are the same as LoadU/StoreU. +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} + +template +HWY_API void Store(const V v, D d, TFromD* HWY_RESTRICT p) { + StoreU(v, d, p); +} + +// ------------------------------ MaskedLoadOr + +// SVE MaskedLoad hard-codes zero, so this requires an extra blend. 
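+// (A predicated svld1 zeroes inactive lanes, so the fallback value is blended
+// in afterwards with IfThenElse.)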
+template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElse(m, MaskedLoad(m, d, p), v); +} + +// ------------------------------ ScatterOffset/Index + +#ifdef HWY_NATIVE_SCATTER +#undef HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, offset, \ + v); \ + } + +#define HWY_SVE_MASKED_SCATTER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ + HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) indices) { \ + sv##OP##_s##BITS##index_##CHAR##BITS(m, base, indices, v); \ + } + +HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_SCATTER_INDEX, MaskedScatterIndex, + st1_scatter) +#undef HWY_SVE_SCATTER_OFFSET +#undef HWY_SVE_MASKED_SCATTER_INDEX + +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT p, + VFromD> indices) { + MaskedScatterIndex(v, detail::MakeMask(d), d, p, indices); +} + +// ------------------------------ GatherOffset/Index + +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, \ + offset); \ + } +#define HWY_SVE_MASKED_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) indices) { \ + const RebindToSigned di; \ + (void)di; /* for HWY_DASSERT */ \ + HWY_DASSERT(AllFalse(di, Lt(indices, Zero(di)))); \ + return sv##OP##_s##BITS##index_##CHAR##BITS(m, base, indices); \ + } + +HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_GATHER_INDEX, MaskedGatherIndex, + ld1_gather) +#undef HWY_SVE_GATHER_OFFSET +#undef HWY_SVE_MASKED_GATHER_INDEX + +template +HWY_API VFromD MaskedGatherIndexOr(VFromD no, svbool_t m, D d, + const TFromD* HWY_RESTRICT p, + VFromD> indices) { + return IfThenElse(m, MaskedGatherIndex(m, d, p, indices), no); +} + +template +HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT p, + VFromD> indices) { + return MaskedGatherIndex(detail::MakeMask(d), d, p, indices); +} + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. 
+#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +#define HWY_SVE_LOAD2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1) { \ + const HWY_SVE_TUPLE(BASE, BITS, 2) tuple = sv##OP##_##CHAR##BITS( \ + detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ + v0 = svget2(tuple, 0); \ + v1 = svget2(tuple, 1); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD2, LoadInterleaved2, ld2) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD2, LoadInterleaved2, ld2) + +#undef HWY_SVE_LOAD2 + +// ------------------------------ LoadInterleaved3 + +#define HWY_SVE_LOAD3(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ + HWY_SVE_V(BASE, BITS) & v2) { \ + const HWY_SVE_TUPLE(BASE, BITS, 3) tuple = sv##OP##_##CHAR##BITS( \ + detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ + v0 = svget3(tuple, 0); \ + v1 = svget3(tuple, 1); \ + v2 = svget3(tuple, 2); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD3, LoadInterleaved3, ld3) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD3, LoadInterleaved3, ld3) + +#undef HWY_SVE_LOAD3 + +// ------------------------------ LoadInterleaved4 + +#define HWY_SVE_LOAD4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ + HWY_SVE_V(BASE, BITS) & v2, HWY_SVE_V(BASE, BITS) & v3) { \ + const HWY_SVE_TUPLE(BASE, BITS, 4) tuple = sv##OP##_##CHAR##BITS( \ + detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ + v0 = svget4(tuple, 0); \ + v1 = svget4(tuple, 1); \ + v2 = svget4(tuple, 2); \ + v3 = svget4(tuple, 3); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD4, LoadInterleaved4, ld4) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD4, LoadInterleaved4, ld4) + +#undef HWY_SVE_LOAD4 + +// ------------------------------ StoreInterleaved2 + +#define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(unaligned), \ + Create2(d, v0, v1)); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE2, StoreInterleaved2, st2) +HWY_SVE_FOREACH_BF16(HWY_SVE_STORE2, StoreInterleaved2, st2) + +#undef HWY_SVE_STORE2 + +// ------------------------------ StoreInterleaved3 + +#define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(unaligned), \ + Create3(d, v0, v1, v2)); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE3, StoreInterleaved3, st3) +HWY_SVE_FOREACH_BF16(HWY_SVE_STORE3, StoreInterleaved3, st3) + +#undef HWY_SVE_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_SVE_STORE4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, 
HWY_SVE_V(BASE, BITS) v3, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(unaligned), \ + Create4(d, v0, v1, v2, v3)); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE4, StoreInterleaved4, st4) +HWY_SVE_FOREACH_BF16(HWY_SVE_STORE4, StoreInterleaved4, st4) + +#undef HWY_SVE_STORE4 + +// Fall back on generic Load/StoreInterleaved[234] for any emulated types. +// Requires HWY_GENERIC_IF_EMULATED_D mirrors HWY_SVE_IF_EMULATED_D. + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +// Same sign +#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* tag */, HWY_SVE_V(BASE, HALF) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) +HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) +HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) + +// 2x +template +HWY_API svuint32_t PromoteTo(Simd dto, svuint8_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} +template +HWY_API svint32_t PromoteTo(Simd dto, svint8_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} +template +HWY_API svuint64_t PromoteTo(Simd dto, svuint16_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} +template +HWY_API svint64_t PromoteTo(Simd dto, svint16_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} + +// 3x +template +HWY_API svuint64_t PromoteTo(Simd dto, svuint8_t vfrom) { + const RepartitionToNarrow d4; + const RepartitionToNarrow d2; + return PromoteTo(dto, PromoteTo(d4, PromoteTo(d2, vfrom))); +} +template +HWY_API svint64_t PromoteTo(Simd dto, svint8_t vfrom) { + const RepartitionToNarrow d4; + const RepartitionToNarrow d2; + return PromoteTo(dto, PromoteTo(d4, PromoteTo(d2, vfrom))); +} + +// Sign change +template ), sizeof(TFromV))> +HWY_API VFromD PromoteTo(D di, V v) { + const RebindToUnsigned du; + return BitCast(di, PromoteTo(du, v)); +} + +// ------------------------------ PromoteTo F + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +// Unlike Highway's ZipLower, this returns the same type. +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLowerSame, zip1) +} // namespace detail + +template +HWY_API svfloat32_t PromoteTo(Simd /* d */, + const svfloat16_t v) { + // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so + // first replicate each lane once. + const svfloat16_t vv = detail::ZipLowerSame(v, v); + return svcvt_f32_f16_x(detail::PTrue(Simd()), vv); +} + +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svfloat16_t v) { + // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so + // first replicate each lane once. 
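+  // Editorial note: for this 4x promotion the replication happens twice (here
+  // and inside the svcvt call), so each source lane fills all four f16 slots
+  // of its 64-bit container and is the element svcvt_f64_f16 converts.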
+ const svfloat16_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_f16_x(detail::PTrue(Simd()), + detail::ZipLowerSame(vv, vv)); +} + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svfloat32_t v) { + const svfloat32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_f32_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svint32_t v) { + const svint32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_s32_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svuint32_t v) { + const svuint32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_u32_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svint64_t PromoteTo(Simd /* d */, + const svfloat32_t v) { + const svfloat32_t vv = detail::ZipLowerSame(v, v); + return svcvt_s64_f32_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svuint64_t PromoteTo(Simd /* d */, + const svfloat32_t v) { + const svfloat32_t vv = detail::ZipLowerSame(v, v); + return svcvt_u64_f32_x(detail::PTrue(Simd()), vv); +} + +// ------------------------------ PromoteUpperTo + +namespace detail { +HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) +HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) +HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) +#undef HWY_SVE_PROMOTE_TO +} // namespace detail + +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// Unsigned->Unsigned or Signed->Signed +template , typename TV = TFromV, + hwy::EnableIf() && IsInteger() && + (IsSigned() == IsSigned())>* = nullptr> +HWY_API VFromD PromoteUpperTo(D d, V v) { + if (detail::IsFull(d)) { + return detail::PromoteUpperTo(d, v); + } + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +// Differing signs or either is float +template , typename TV = TFromV, + hwy::EnableIf() || !IsInteger() || + (IsSigned() != IsSigned())>* = nullptr> +HWY_API VFromD PromoteUpperTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +// ------------------------------ DemoteTo U + +namespace detail { + +// Saturates unsigned vectors to half/quarter-width TN. +template +VU SaturateU(VU v) { + return detail::MinN(v, static_cast>(LimitsMax())); +} + +// Saturates unsigned vectors to half/quarter-width TN. +template +VI SaturateI(VI v) { + return detail::MinN(detail::MaxN(v, LimitsMin()), LimitsMax()); +} + +} // namespace detail + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint16_t v) { +#if HWY_SVE_HAVE_2 + const svuint8_t vn = BitCast(dn, svqxtunb_s16(v)); +#else + const DFromV di; + const RebindToUnsigned du; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint16_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and halve the width. + const svuint8_t vn = BitCast(dn, detail::SaturateU(clamped)); +#endif + return svuzp1_u8(vn, vn); +} + +template +HWY_API svuint16_t DemoteTo(Simd dn, const svint32_t v) { +#if HWY_SVE_HAVE_2 + const svuint16_t vn = BitCast(dn, svqxtunb_s32(v)); +#else + const DFromV di; + const RebindToUnsigned du; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. 
+ const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and halve the width. + const svuint16_t vn = BitCast(dn, detail::SaturateU(clamped)); +#endif + return svuzp1_u16(vn, vn); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint32_t v) { + const DFromV di; + const RebindToUnsigned du; + const RepartitionToNarrow d2; +#if HWY_SVE_HAVE_2 + const svuint16_t cast16 = BitCast(d2, svqxtnb_u16(svqxtunb_s32(v))); +#else + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and quarter the width. + const svuint16_t cast16 = BitCast(d2, detail::SaturateU(clamped)); +#endif + const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16)); + return svuzp1_u8(x2, x2); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svuint16_t v) { +#if HWY_SVE_HAVE_2 + const svuint8_t vn = BitCast(dn, svqxtnb_u16(v)); +#else + using TN = TFromD; + const svuint8_t vn = BitCast(dn, detail::SaturateU(v)); +#endif + return svuzp1_u8(vn, vn); +} + +template +HWY_API svuint16_t DemoteTo(Simd dn, const svuint32_t v) { +#if HWY_SVE_HAVE_2 + const svuint16_t vn = BitCast(dn, svqxtnb_u32(v)); +#else + using TN = TFromD; + const svuint16_t vn = BitCast(dn, detail::SaturateU(v)); +#endif + return svuzp1_u16(vn, vn); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svuint32_t v) { + using TN = TFromD; + return U8FromU32(detail::SaturateU(v)); +} + +// ------------------------------ Truncations + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + const svuint8_t v2 = svuzp1_u8(v1, v1); + const svuint8_t v3 = svuzp1_u8(v2, v2); + return svuzp1_u8(v3, v3); +} + +template +HWY_API svuint16_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint16_t v1 = BitCast(d, v); + const svuint16_t v2 = svuzp1_u16(v1, v1); + return svuzp1_u16(v2, v2); +} + +template +HWY_API svuint32_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint32_t v1 = BitCast(d, v); + return svuzp1_u32(v1, v1); +} + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint32_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + const svuint8_t v2 = svuzp1_u8(v1, v1); + return svuzp1_u8(v2, v2); +} + +template +HWY_API svuint16_t TruncateTo(Simd /* tag */, + const svuint32_t v) { + const DFromV d; + const svuint16_t v1 = BitCast(d, v); + return svuzp1_u16(v1, v1); +} + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint16_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + return svuzp1_u8(v1, v1); +} + +// ------------------------------ DemoteTo I + +template +HWY_API svint8_t DemoteTo(Simd dn, const svint16_t v) { +#if HWY_SVE_HAVE_2 + const svint8_t vn = BitCast(dn, svqxtnb_s16(v)); +#else + using TN = TFromD; + const svint8_t vn = BitCast(dn, detail::SaturateI(v)); +#endif + return svuzp1_s8(vn, vn); +} + +template +HWY_API svint16_t DemoteTo(Simd dn, const svint32_t v) { +#if HWY_SVE_HAVE_2 + const svint16_t vn = BitCast(dn, svqxtnb_s32(v)); +#else + using TN = TFromD; + const svint16_t vn = BitCast(dn, detail::SaturateI(v)); +#endif + return svuzp1_s16(vn, vn); +} + +template +HWY_API svint8_t DemoteTo(Simd dn, const svint32_t v) { + const RepartitionToWide d2; +#if HWY_SVE_HAVE_2 + const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v))); +#else + using TN = TFromD; + 
const svint16_t cast16 = BitCast(d2, detail::SaturateI(v)); +#endif + const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16)); + return BitCast(dn, svuzp1_s8(v2, v2)); +} + +// ------------------------------ I64/U64 DemoteTo + +template +HWY_API svint32_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; + const RebindToUnsigned dn_u; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtnb_s64(v)); +#else + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateI(v)); +#endif + return BitCast(dn, TruncateTo(dn_u, vn)); +} + +template +HWY_API svint16_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; + const RebindToUnsigned dn_u; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtnb_s32(svqxtnb_s64(v))); +#else + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateI(v)); +#endif + return BitCast(dn, TruncateTo(dn_u, vn)); +} + +template +HWY_API svint8_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; + const RebindToUnsigned dn_u; + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateI(v)); + return BitCast(dn, TruncateTo(dn_u, vn)); +} + +template +HWY_API svuint32_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtunb_s64(v)); +#else + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); + // Saturate to unsigned-max + const svuint64_t vn = detail::SaturateU(clamped); +#endif + return TruncateTo(dn, vn); +} + +template +HWY_API svuint16_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtnb_u32(svqxtunb_s64(v))); +#else + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); + // Saturate to unsigned-max + const svuint64_t vn = detail::SaturateU(clamped); +#endif + return TruncateTo(dn, vn); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. 
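  // Per lane this amounts to (uint8_t)HWY_MIN((uint64_t)HWY_MAX(v, 0), 255);
  // TruncateTo below then keeps only the low byte of each 64-bit lane.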
+ const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); + // Saturate to unsigned-max + const svuint64_t vn = detail::SaturateU(clamped); + return TruncateTo(dn, vn); +} + +template +HWY_API svuint32_t DemoteTo(Simd dn, const svuint64_t v) { + const Rebind du64; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtnb_u64(v)); +#else + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateU(v)); +#endif + return TruncateTo(dn, vn); +} + +template +HWY_API svuint16_t DemoteTo(Simd dn, const svuint64_t v) { + const Rebind du64; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtnb_u32(svqxtnb_u64(v))); +#else + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateU(v)); +#endif + return TruncateTo(dn, vn); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svuint64_t v) { + const Rebind du64; + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateU(v)); + return TruncateTo(dn, vn); +} + +// ------------------------------ Unsigned to signed demotions + +// Disable the default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h on SVE/SVE2 as the SVE/SVE2 targets have +// target-specific implementations of the unsigned to signed DemoteTo and +// ReorderDemote2To ops + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the V template +// argument +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ + hwy::EnableIf()>* = nullptr + +template ) - 1)> +HWY_API VFromD DemoteTo(D dn, V v) { + const RebindToUnsigned dn_u; + return BitCast(dn, TruncateTo(dn_u, detail::SaturateU>(v))); +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo + +// Signed to signed PromoteEvenTo: 1 instruction instead of 2 in generic-inl.h. +// Might as well also enable unsigned to unsigned, though it is just an And. +namespace detail { +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, extb) +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, exth) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, extw) +} // namespace detail + +#include "third_party/highway/hwy/ops/inside-inl.h" + +// ------------------------------ DemoteTo F + +// We already toggled HWY_NATIVE_F16C above. 
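// The float demotions below share one pattern: the predicated svcvt writes
// each result into the lower (even) half of its wide source lane, and
// ConcatEvenFull then packs those even lanes into the lower half.
// Illustrative caller-side usage, as a sketch via Highway's public API (the
// pointer `from` is hypothetical; hn is the usual hwy::HWY_NAMESPACE alias):
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<float> df32;
//   const hn::Rebind<hwy::float16_t, decltype(df32)> df16;
//   const auto f16 = hn::DemoteTo(df16, hn::Load(df32, from));  // f32 -> f16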
+ +template +HWY_API svfloat16_t DemoteTo(Simd d, const svfloat32_t v) { + const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +template +HWY_API svfloat16_t DemoteTo(Simd d, const svfloat64_t v) { + const svfloat16_t in_lo16 = svcvt_f16_f64_x(detail::PTrue(d), v); + const svfloat16_t in_even = detail::ConcatEvenFull(in_lo16, in_lo16); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +#if !HWY_SVE_HAVE_F32_TO_BF16C +namespace detail { + +// Round a F32 value to the nearest BF16 value, with the result returned as the +// rounded F32 value bitcasted to an U32 + +// RoundF32ForDemoteToBF16 also converts NaN values to QNaN values to prevent +// NaN F32 values from being converted to an infinity +HWY_INLINE svuint32_t RoundF32ForDemoteToBF16(svfloat32_t v) { + const DFromV df32; + const RebindToUnsigned du32; + + const auto is_non_nan = Eq(v, v); + const auto bits32 = BitCast(du32, v); + + const auto round_incr = + detail::AddN(detail::AndN(ShiftRight<16>(bits32), 1u), 0x7FFFu); + return MaskedAddOr(detail::OrN(bits32, 0x00400000u), is_non_nan, bits32, + round_incr); +} + +} // namespace detail +#endif // !HWY_SVE_HAVE_F32_TO_BF16C + +template +HWY_API VBF16 DemoteTo(Simd dbf16, svfloat32_t v) { +#if HWY_SVE_HAVE_F32_TO_BF16C + const VBF16 in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), v); + return detail::ConcatEvenFull(in_even, in_even); +#else + const svuint16_t in_odd = + BitCast(ScalableTag(), detail::RoundF32ForDemoteToBF16(v)); + return BitCast(dbf16, detail::ConcatOddFull(in_odd, in_odd)); // lower half +#endif +} + +template +HWY_API svfloat32_t DemoteTo(Simd d, const svfloat64_t v) { + const svfloat32_t in_even = svcvt_f32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svint32_t DemoteTo(Simd d, const svfloat64_t v) { + const svint32_t in_even = svcvt_s32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svuint32_t DemoteTo(Simd d, const svfloat64_t v) { + const svuint32_t in_even = svcvt_u32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svfloat32_t DemoteTo(Simd d, const svint64_t v) { + const svfloat32_t in_even = svcvt_f32_s64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svfloat32_t DemoteTo(Simd d, const svuint64_t v) { + const svfloat32_t in_even = svcvt_f32_u64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +// ------------------------------ ConvertTo F + +#define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP) \ + /* Float from signed */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* Float from unsigned */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(uint, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_u##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* Signed from 
float, rounding toward zero */ \ + template \ + HWY_API HWY_SVE_V(int, BITS) \ + NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* Unsigned from float, rounding toward zero */ \ + template \ + HWY_API HWY_SVE_V(uint, BITS) \ + NAME(HWY_SVE_D(uint, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt) +#undef HWY_SVE_CONVERT + +// ------------------------------ MaskedConvertTo F + +#define HWY_SVE_MASKED_CONVERT_TO_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \ + /* Float from signed */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(int, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_s##BITS##_z(m, v); \ + } \ + /* Float from unsigned */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(uint, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_u##BITS##_z(m, v); \ + } \ + /* Signed from float, rounding toward zero */ \ + template \ + HWY_API HWY_SVE_V(int, BITS) \ + NAME(svbool_t m, HWY_SVE_D(int, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_s##BITS##_##CHAR##BITS##_z(m, v); \ + } \ + /* Unsigned from float, rounding toward zero */ \ + template \ + HWY_API HWY_SVE_V(uint, BITS) \ + NAME(svbool_t m, HWY_SVE_D(uint, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u##BITS##_##CHAR##BITS##_z(m, v); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_MASKED_CONVERT_TO_OR_ZERO, MaskedConvertTo, cvt) +#undef HWY_SVE_MASKED_CONVERT_TO_OR_ZERO + +// ------------------------------ NearestInt (Round, ConvertTo) +template >> +HWY_API VFromD NearestInt(VF v) { + // No single instruction, round then truncate. + return ConvertTo(DI(), Round(v)); +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + // No single instruction, round then demote. + return DemoteTo(di32, Round(v)); +} + +// ------------------------------ Iota (AddN, ConvertTo) + +#define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, T2 first) { \ + return sv##OP##_##CHAR##BITS( \ + ConvertScalarTo(first), 1); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index) +#undef HWY_SVE_IOTA + +template , typename T2, HWY_IF_FLOAT(T)> +HWY_API VFromD Iota(const D d, T2 first) { + const RebindToSigned di; + const T first_f = ConvertScalarTo(first); + const VFromD iota_f = ConvertTo(d, Iota(di, 0)); + return detail::AddN(iota_f, first_f); +} + +// ================================================== LANE ACCESS + +// ------------------------------ ExtractLane (GetLaneM, FirstN) +template +HWY_API TFromV ExtractLane(V v, size_t i) { + return detail::GetLaneM(v, FirstN(DFromV(), i)); +} + +// ------------------------------ InsertLane (IfThenElse, EqN) +template +HWY_API V InsertLane(const V v, size_t i, T t) { + static_assert(sizeof(TFromV) == sizeof(T), "Lane size mismatch"); + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + const svbool_t is_i = detail::EqN(Iota(di, 0), static_cast(i)); + // The actual type may be int16_t for special floats; copy, not cast. 
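  // hwy::CopySameSize is a bit-exact copy, so float16_t/bfloat16_t payloads
  // survive even when the lane type used here is a same-sized integer.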
+ TFromV t_bits; + hwy::CopySameSize(&t, &t_bits); + return IfThenElse(RebindMask(d, is_i), Set(d, t_bits), v); +} + +// ------------------------------ GetExponent + +#if HWY_SVE_HAVE_2 || HWY_IDE +#ifdef HWY_NATIVE_GET_EXPONENT +#undef HWY_NATIVE_GET_EXPONENT +#else +#define HWY_NATIVE_GET_EXPONENT +#endif + +namespace detail { +#define HWY_SVE_GET_EXP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(int, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } +HWY_SVE_FOREACH_F(HWY_SVE_GET_EXP, GetExponent, logb) +#undef HWY_SVE_GET_EXP +} // namespace detail + +template +HWY_API V GetExponent(V v) { + const DFromV d; + const RebindToSigned di; + const VFromD exponent_int = detail::GetExponent(v); + // convert integer to original type + return ConvertTo(d, exponent_int); +} +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ InterleaveLower + +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); +#if HWY_TARGET == HWY_SVE2_128 + (void)d; + return detail::ZipLowerSame(a, b); +#else + // Move lower halves of blocks to lower half of vector. + const Repartition d64; + const auto a64 = BitCast(d64, a); + const auto b64 = BitCast(d64, b); + const auto a_blocks = detail::ConcatEvenFull(a64, a64); // lower half + const auto b_blocks = detail::ConcatEvenFull(b64, b64); + return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); +#endif +} + +template +HWY_API V InterleaveLower(const V a, const V b) { + return InterleaveLower(DFromV(), a, b); +} + +// ------------------------------ InterleaveUpper + +// Only use zip2 if vector are a powers of two, otherwise getting the actual +// "upper half" requires MaskUpperHalf. +namespace detail { +// Unlike Highway's ZipUpper, this returns the same type. +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipUpperSame, zip2) +} // namespace detail + +// Full vector: guaranteed to have at least one block +template , + hwy::EnableIf* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { +#if HWY_TARGET == HWY_SVE2_128 + (void)d; + return detail::ZipUpperSame(a, b); +#else + // Move upper halves of blocks to lower half of vector. 
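  // Viewing each 128-bit block as two u64 halves, ConcatOddFull keeps the odd
  // (upper) halves; zipping those then yields, per block for u32:
  // a=[a0 a1 a2 a3], b=[b0 b1 b2 b3] -> [a2 b2 a3 b3].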
+ const Repartition d64; + const auto a64 = BitCast(d64, a); + const auto b64 = BitCast(d64, b); + const auto a_blocks = detail::ConcatOddFull(a64, a64); // lower half + const auto b_blocks = detail::ConcatOddFull(b64, b64); + return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); +#endif +} + +// Capped/fraction: need runtime check +template , + hwy::EnableIf* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { + // Less than one block: treat as capped + if (Lanes(d) * sizeof(TFromD) < 16) { + const Half d2; + return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); + } + return InterleaveUpper(DFromV(), a, b); +} + +// ------------------------------ InterleaveWholeLower +#ifdef HWY_NATIVE_INTERLEAVE_WHOLE +#undef HWY_NATIVE_INTERLEAVE_WHOLE +#else +#define HWY_NATIVE_INTERLEAVE_WHOLE +#endif + +template +HWY_API VFromD InterleaveWholeLower(D /*d*/, VFromD a, VFromD b) { + return detail::ZipLowerSame(a, b); +} + +// ------------------------------ InterleaveWholeUpper + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + if (HWY_SVE_IS_POW2 && detail::IsFull(d)) { + return detail::ZipUpperSame(a, b); + } + + const Half d2; + return InterleaveWholeLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); +} + +// ------------------------------ Per4LaneBlockShuffle + +namespace detail { + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x88> /*idx_3210_tag*/, + hwy::SizeTag /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + const auto evens = BitCast(dw, ConcatEvenFull(v, v)); + return BitCast(d, ZipLowerSame(evens, evens)); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xDD> /*idx_3210_tag*/, + hwy::SizeTag /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + const auto odds = BitCast(dw, ConcatOddFull(v, v)); + return BitCast(d, ZipLowerSame(odds, odds)); +} + +} // namespace detail + +// ================================================== COMBINE + +namespace detail { + +#if (HWY_TARGET == HWY_SVE_256 && HWY_HAVE_CONSTEXPR_LANES) || HWY_IDE +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 32: + return svptrue_pat_b8(SV_VL16); + case 16: + return svptrue_pat_b8(SV_VL8); + case 8: + return svptrue_pat_b8(SV_VL4); + case 4: + return svptrue_pat_b8(SV_VL2); + default: + return svptrue_pat_b8(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 16: + return svptrue_pat_b16(SV_VL8); + case 8: + return svptrue_pat_b16(SV_VL4); + case 4: + return svptrue_pat_b16(SV_VL2); + default: + return svptrue_pat_b16(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 8: + return svptrue_pat_b32(SV_VL4); + case 4: + return svptrue_pat_b32(SV_VL2); + default: + return svptrue_pat_b32(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 4: + return svptrue_pat_b64(SV_VL2); + default: + return svptrue_pat_b64(SV_VL1); + } +} +#endif +#if (HWY_TARGET == HWY_SVE2_128 && HWY_HAVE_CONSTEXPR_LANES) || HWY_IDE +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 16: + return svptrue_pat_b8(SV_VL8); + case 8: + return svptrue_pat_b8(SV_VL4); + case 4: + return svptrue_pat_b8(SV_VL2); + case 2: + case 1: + default: + return svptrue_pat_b8(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch 
(MaxLanes(d)) { + case 8: + return svptrue_pat_b16(SV_VL4); + case 4: + return svptrue_pat_b16(SV_VL2); + case 2: + case 1: + default: + return svptrue_pat_b16(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + return svptrue_pat_b32(MaxLanes(d) == 4 ? SV_VL2 : SV_VL1); +} +template +svbool_t MaskLowerHalf(D /*d*/) { + return svptrue_pat_b64(SV_VL1); +} +#endif // HWY_TARGET == HWY_SVE2_128 +#if (HWY_TARGET != HWY_SVE_256 && HWY_TARGET != HWY_SVE2_128) || \ + !HWY_HAVE_CONSTEXPR_LANES +template +svbool_t MaskLowerHalf(D d) { + return FirstN(d, Lanes(d) / 2); +} +#endif + +template +svbool_t MaskUpperHalf(D d) { + // TODO(janwas): WHILEGE on SVE2 + if (HWY_SVE_IS_POW2 && IsFull(d)) { + return Not(MaskLowerHalf(d)); + } + + // For Splice to work as intended, make sure bits above Lanes(d) are zero. + return AndNot(MaskLowerHalf(d), detail::MakeMask(d)); +} + +// Right-shift vector pair by constexpr; can be used to slide down (=N) or up +// (=Lanes()-N). +#define HWY_SVE_EXT(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi, kIndex); \ + } +HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext) +#undef HWY_SVE_EXT + +} // namespace detail + +// ------------------------------ ConcatUpperLower +template +HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) { + return IfThenElse(detail::MaskLowerHalf(d), lo, hi); +} + +// ------------------------------ ConcatLowerLower +template +HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) { + if (detail::IsFull(d)) { +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 + return detail::ConcatEvenBlocks(hi, lo); +#endif +#if HWY_TARGET == HWY_SVE2_128 + const Repartition du64; + const auto lo64 = BitCast(du64, lo); + return BitCast(d, InterleaveLower(du64, lo64, BitCast(du64, hi))); +#endif + } + return detail::Splice(hi, lo, detail::MaskLowerHalf(d)); +} + +// ------------------------------ ConcatLowerUpper +template +HWY_API V ConcatLowerUpper(const D d, const V hi, const V lo) { +#if HWY_HAVE_CONSTEXPR_LANES + if (detail::IsFull(d)) { + return detail::Ext(hi, lo); + } +#endif + return detail::Splice(hi, lo, detail::MaskUpperHalf(d)); +} + +// ------------------------------ ConcatUpperUpper +template +HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) { + if (detail::IsFull(d)) { +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 + return detail::ConcatOddBlocks(hi, lo); +#endif +#if HWY_TARGET == HWY_SVE2_128 + const Repartition du64; + const auto lo64 = BitCast(du64, lo); + return BitCast(d, InterleaveUpper(du64, lo64, BitCast(du64, hi))); +#endif + } + const svbool_t mask_upper = detail::MaskUpperHalf(d); + const V lo_upper = detail::Splice(lo, lo, mask_upper); + return IfThenElse(mask_upper, hi, lo_upper); +} + +// ------------------------------ Combine +template +HWY_API VFromD Combine(const D d, const V2 hi, const V2 lo) { + return ConcatLowerLower(d, hi, lo); +} + +// ------------------------------ ZeroExtendVector +template +HWY_API V ZeroExtendVector(const D d, const V lo) { + return Combine(d, Zero(Half()), lo); +} + +// ------------------------------ Lower/UpperHalf + +template +HWY_API V LowerHalf(D2 /* tag */, const V v) { + return v; +} + +template +HWY_API V LowerHalf(const V v) { + return v; +} + +template +HWY_API V UpperHalf(const DH dh, const V v) { + const Twice d; + // Cast so that we support bfloat16_t. 
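  // The Ext/Splice helpers are not instantiated for bfloat16_t, hence the
  // detour through the unsigned-integer view before casting back to d.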
+ const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); +#if HWY_HAVE_CONSTEXPR_LANES + return BitCast(d, detail::Ext(vu, vu)); +#else + const MFromD mask = detail::MaskUpperHalf(du); + return BitCast(d, detail::Splice(vu, vu, mask)); +#endif +} + +// ================================================== SWIZZLE + +// ------------------------------ DupEven + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1) +} // namespace detail + +template +HWY_API V DupEven(const V v) { + return detail::InterleaveEven(v, v); +} + +// ------------------------------ DupOdd + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2) +} // namespace detail + +template +HWY_API V DupOdd(const V v) { + return detail::InterleaveOdd(v, v); +} + +// ------------------------------ OddEven + +#if HWY_SVE_HAVE_2 + +#define HWY_SVE_ODD_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) odd, HWY_SVE_V(BASE, BITS) even) { \ + return sv##OP##_##CHAR##BITS(even, odd, /*xor=*/0); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ODD_EVEN, OddEven, eortb_n) +#undef HWY_SVE_ODD_EVEN + +template +HWY_API V OddEven(const V odd, const V even) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, OddEven(BitCast(du, odd), BitCast(du, even))); +} + +#else + +template +HWY_API V OddEven(const V odd, const V even) { + const auto odd_in_even = detail::Ext<1>(odd, odd); + return detail::InterleaveEven(even, odd_in_even); +} + +#endif // HWY_TARGET + +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return detail::InterleaveEven(a, b); +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return detail::InterleaveOdd(a, b); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API V OddEvenBlocks(const V odd, const V even) { + const DFromV d; +#if HWY_TARGET == HWY_SVE_256 + return ConcatUpperLower(d, odd, even); +#elif HWY_TARGET == HWY_SVE2_128 + (void)odd; + (void)d; + return even; +#else + const RebindToUnsigned du; + using TU = TFromD; + constexpr size_t kShift = CeilLog2(16 / sizeof(TU)); + const auto idx_block = ShiftRight(Iota(du, 0)); + const auto lsb = detail::AndN(idx_block, static_cast(1)); + const svbool_t is_even = detail::EqN(lsb, static_cast(0)); + return IfThenElse(is_even, even, odd); +#endif +} + +// ------------------------------ TableLookupLanes + +template +HWY_API VFromD> IndicesFromVec(D d, VI vec) { + using TI = TFromV; + static_assert(sizeof(TFromD) == sizeof(TI), "Index/lane size mismatch"); + const RebindToUnsigned du; + const auto indices = BitCast(du, vec); +#if HWY_IS_DEBUG_BUILD + using TU = MakeUnsigned; + const size_t twice_max_lanes = Lanes(d) * 2; + HWY_DASSERT(AllTrue( + du, Eq(indices, + detail::AndN(indices, static_cast(twice_max_lanes - 1))))); +#else + (void)d; +#endif + return indices; +} + +template +HWY_API VFromD> SetTableIndices(D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + return IndicesFromVec(d, LoadU(Rebind(), idx)); +} + +#define HWY_SVE_TABLE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \ + return sv##OP##_##CHAR##BITS(v, idx); \ + } + +HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC 
+HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_TABLE, TableLookupLanes, tbl) +#endif +#undef HWY_SVE_TABLE + +#if HWY_SVE_HAVE_2 +namespace detail { +#define HWY_SVE_TABLE2(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(uint, BITS) idx) { \ + return sv##OP##_##CHAR##BITS(tuple, idx); \ + } + +HWY_SVE_FOREACH(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, tbl2) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, + tbl2) +#endif +#undef HWY_SVE_TABLE +} // namespace detail +#endif // HWY_SVE_HAVE_2 + +template +HWY_API VFromD TwoTablesLookupLanes(D d, VFromD a, VFromD b, + VFromD> idx) { + // SVE2 has an instruction for this, but it only works for full 2^n vectors. +#if HWY_SVE_HAVE_2 && HWY_SVE_IS_POW2 + if (detail::IsFull(d)) { + return detail::NativeTwoTableLookupLanes(Create2(d, a, b), idx); + } +#endif + const RebindToUnsigned du; + using TU = TFromD; + + const size_t num_of_lanes = Lanes(d); + const auto idx_mod = detail::AndN(idx, static_cast(num_of_lanes - 1)); + const auto sel_a_mask = Eq(idx, idx_mod); + + const auto a_lookup_result = TableLookupLanes(a, idx_mod); + const auto b_lookup_result = TableLookupLanes(b, idx_mod); + return IfThenElse(sel_a_mask, a_lookup_result, b_lookup_result); +} + +template +HWY_API V TwoTablesLookupLanes(V a, V b, + VFromD>> idx) { + const DFromV d; + return TwoTablesLookupLanes(d, a, b, idx); +} + +// ------------------------------ SlideUpLanes (FirstN) +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { + return detail::Splice(v, Zero(d), FirstN(d, amt)); +} + +// ------------------------------ Slide1Up + +#ifdef HWY_NATIVE_SLIDE1_UP_DOWN +#undef HWY_NATIVE_SLIDE1_UP_DOWN +#else +#define HWY_NATIVE_SLIDE1_UP_DOWN +#endif + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + return SlideUpLanes(d, v, 1); +} + +// ------------------------------ SlideDownLanes (TableLookupLanes) +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + const RebindToUnsigned du; + using TU = TFromD; + const auto idx = Iota(du, static_cast(amt)); + return IfThenElseZero(FirstN(d, Lanes(d) - amt), TableLookupLanes(v, idx)); +} + +// ------------------------------ Slide1Down +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + return SlideDownLanes(d, v, 1); +} + +// ------------------------------ SwapAdjacentBlocks (TableLookupLanes) + +namespace detail { + +template +constexpr size_t LanesPerBlock(Simd d) { + // We might have a capped vector smaller than a block, so honor that. 
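  // e.g. uint32_t: a full 128-bit block holds 4 lanes, but a
  // Simd<uint32_t, 2, 0> descriptor caps MaxLanes at 2.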
+ return HWY_MIN(16 / sizeof(T), MaxLanes(d)); +} + +} // namespace detail + +template +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV d; +#if HWY_TARGET == HWY_SVE_256 + return ConcatLowerUpper(d, v, v); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + return v; +#else + const RebindToUnsigned du; + constexpr auto kLanesPerBlock = + static_cast>(detail::LanesPerBlock(d)); + const VFromD idx = detail::XorN(Iota(du, 0), kLanesPerBlock); + return TableLookupLanes(v, idx); +#endif +} + +// ------------------------------ InterleaveEvenBlocks +// (ConcatLowerLower, SlideUpLanes, OddEvenBlocks) + +template > +HWY_API V InterleaveEvenBlocks(D d, V a, V b) { +#if HWY_TARGET == HWY_SVE_256 + return ConcatLowerLower(d, b, a); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + (void)b; + return a; +#else + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + return OddEvenBlocks(SlideUpLanes(d, b, kLanesPerBlock), a); +#endif +} + +// ------------------------------ InterleaveOddBlocks +// (ConcatUpperUpper, SlideDownLanes, OddEvenBlocks) + +template > +HWY_API V InterleaveOddBlocks(D d, V a, V b) { +#if HWY_TARGET == HWY_SVE_256 + return ConcatUpperUpper(d, b, a); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + (void)b; + return a; +#else + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + return OddEvenBlocks(b, SlideDownLanes(d, a, kLanesPerBlock)); +#endif +} + +// ------------------------------ Reverse + +namespace detail { + +#define HWY_SVE_REVERSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_REVERSE, ReverseFull, rev) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_REVERSE, ReverseFull, rev) +#endif +#undef HWY_SVE_REVERSE + +} // namespace detail + +template +HWY_API V Reverse(D d, V v) { + using T = TFromD; + const auto reversed = detail::ReverseFull(v); + if (HWY_SVE_IS_POW2 && detail::IsFull(d)) return reversed; + // Shift right to remove extra (non-pow2 and remainder) lanes. + // TODO(janwas): on SVE2, use WHILEGE. + // Avoids FirstN truncating to the return vector size. Must also avoid Not + // because that is limited to SV_POW2. + const ScalableTag dfull; + const svbool_t all_true = detail::AllPTrue(dfull); + const size_t all_lanes = detail::AllHardwareLanes(); + const size_t want_lanes = Lanes(d); + HWY_DASSERT(want_lanes <= all_lanes); + const svbool_t mask = + svnot_b_z(all_true, FirstN(dfull, all_lanes - want_lanes)); + return detail::Splice(reversed, reversed, mask); +} + +// ------------------------------ Reverse2 + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. 
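// The #ifdef/#undef/#else/#define sequence toggles the flag once per
// per-target re-inclusion; generic_ops-inl.h compares it against
// HWY_TARGET_TOGGLE to decide whether its portable fallback is still needed.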
+#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevb_u16_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevh_u32_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevw_u64_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { // 3210 +#if HWY_TARGET == HWY_SVE2_128 + if (detail::IsFull(d)) { + return detail::Ext<1>(v, v); + } +#endif + (void)d; + const auto odd_in_even = detail::Ext<1>(v, v); // x321 + return detail::InterleaveEven(odd_in_even, v); // 2301 +} + +// ------------------------------ Reverse4 (TableLookupLanes) + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWideX2 du32; + return BitCast(d, svrevb_u32_x(detail::PTrue(d), BitCast(du32, v))); +} + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWideX2 du64; + return BitCast(d, svrevh_u64_x(detail::PTrue(d), BitCast(du64, v))); +} + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + if (HWY_TARGET == HWY_SVE2_128 && detail::IsFull(d)) { + return detail::ReverseFull(v); + } + // TODO(janwas): is this approach faster than Shuffle0123? + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 3); + return TableLookupLanes(v, idx); +} + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + if (HWY_TARGET == HWY_SVE_256 && detail::IsFull(d)) { + return detail::ReverseFull(v); + } + // TODO(janwas): is this approach faster than Shuffle0123? 
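  // XOR-ing the lane index with 3 reverses each aligned group of 4 lanes,
  // e.g. indices 0 1 2 3 4 5 6 7 -> 3 2 1 0 7 6 5 4.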
+ const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 3); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Reverse8 (TableLookupLanes) + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const Repartition du64; + return BitCast(d, svrevb_u64_x(detail::PTrue(d), BitCast(du64, v))); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 7); + return TableLookupLanes(v, idx); +} + +// ------------------------------- ReverseBits + +#ifdef HWY_NATIVE_REVERSE_BITS_UI8 +#undef HWY_NATIVE_REVERSE_BITS_UI8 +#else +#define HWY_NATIVE_REVERSE_BITS_UI8 +#endif + +#ifdef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#undef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#else +#define HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#endif + +#define HWY_SVE_REVERSE_BITS(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + const DFromV d; \ + return sv##OP##_##CHAR##BITS##_x(detail::PTrue(d), v); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_REVERSE_BITS, ReverseBits, rbit) +#undef HWY_SVE_REVERSE_BITS + +// ------------------------------ Block insert/extract/broadcast ops +#if HWY_TARGET != HWY_SVE2_128 + +#ifdef HWY_NATIVE_BLK_INSERT_EXTRACT +#undef HWY_NATIVE_BLK_INSERT_EXTRACT +#else +#define HWY_NATIVE_BLK_INSERT_EXTRACT +#endif + +template +HWY_API V InsertBlock(V v, V blk_to_insert) { + const DFromV d; + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + +#if HWY_TARGET == HWY_SVE_256 + return (kBlockIdx == 0) ? ConcatUpperLower(d, v, blk_to_insert) + : ConcatLowerLower(d, blk_to_insert, v); +#else + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + + constexpr size_t kBlockOffset = + static_cast(kBlockIdx) * kLanesPerBlock; + const auto splice_mask = FirstN(d, kBlockOffset); + const auto sel_lo_mask = FirstN(d, kBlockOffset + kLanesPerBlock); + + const auto splice_result = detail::Splice(blk_to_insert, v, splice_mask); + return IfThenElse(sel_lo_mask, splice_result, v); +#endif +} + +template +HWY_API V ExtractBlock(V v) { + const DFromV d; + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + + if (kBlockIdx == 0) return v; + +#if HWY_TARGET == HWY_SVE_256 + return UpperHalf(Half(), v); +#else + const RebindToUnsigned du; + using TU = TFromD; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + constexpr size_t kBlockOffset = + static_cast(kBlockIdx) * kLanesPerBlock; + const auto splice_mask = + RebindMask(d, detail::LtN(Iota(du, static_cast(0u - kBlockOffset)), + static_cast(kLanesPerBlock))); + return detail::Splice(v, v, splice_mask); +#endif +} + +template +HWY_API V BroadcastBlock(V v) { + const DFromV d; + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + + const RebindToUnsigned du; // for bfloat16_t + using VU = VFromD; + const VU vu = BitCast(du, v); + +#if HWY_TARGET == HWY_SVE_256 + return BitCast(d, (kBlockIdx == 0) ? 
ConcatLowerLower(du, vu, vu) + : ConcatUpperUpper(du, vu, vu)); +#else + using TU = TFromD; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + constexpr size_t kBlockOffset = + static_cast(kBlockIdx) * kLanesPerBlock; + + const VU idx = detail::AddN( + detail::AndN(Iota(du, TU{0}), static_cast(kLanesPerBlock - 1)), + static_cast(kBlockOffset)); + return BitCast(d, TableLookupLanes(vu, idx)); +#endif +} + +#endif // HWY_TARGET != HWY_SVE2_128 + +// ------------------------------ Compress (PromoteTo) + +template +struct CompressIsPartition { +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 + // Optimization for 64-bit lanes (could also be applied to 32-bit, but that + // requires a larger table). + enum { value = (sizeof(T) == 8) }; +#else + enum { value = 0 }; +#endif // HWY_TARGET == HWY_SVE_256 +}; + +#define HWY_SVE_COMPRESS(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, v); \ + } + +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +HWY_SVE_FOREACH_UI32(HWY_SVE_COMPRESS, Compress, compact) +HWY_SVE_FOREACH_F32(HWY_SVE_COMPRESS, Compress, compact) +#else +HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact) +#endif +#undef HWY_SVE_COMPRESS + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template +HWY_API V Compress(V v, svbool_t mask) { + const DFromV d; + const RebindToUnsigned du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // See CompressIsPartition. + alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintCompress64x4Tables + 0, 1, 2, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 2, 0, 1, 3, 0, 2, + 1, 3, 1, 2, 0, 3, 0, 1, 2, 3, 3, 0, 1, 2, 0, 3, 1, 2, 1, 3, 0, 2, + 0, 1, 3, 2, 2, 3, 0, 1, 0, 2, 3, 1, 1, 2, 3, 0, 0, 1, 2, 3}; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +} + +#endif // HWY_TARGET == HWY_SVE_256 +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE +template +HWY_API V Compress(V v, svbool_t mask) { + // If mask == 10: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 + // swaps upper/lower (the lower half is set to the upper half, and the + // remaining upper half is filled from the lower half of the second v), and + // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot keep 10 + // unchanged and map everything else to 00. 
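  // With mask = (m1 m0) and maskLL = (m0 m0), AndNot(maskLL, mask) =
  // (m1 & !m0, 0): 10 stays 10, while 00/01/11 all become 00, so only the
  // 10 case actually swaps via Splice.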
+ const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane + return detail::Splice(v, v, AndNot(maskLL, mask)); +} + +#endif // HWY_TARGET == HWY_SVE2_128 + +template +HWY_API V Compress(V v, svbool_t mask16) { + static_assert(!IsSame(), "Must use overload"); + const DFromV d16; + + // Promote vector and mask to 32-bit + const RepartitionToWide dw; + const auto v32L = PromoteTo(dw, v); + const auto v32H = detail::PromoteUpperTo(dw, v); + const svbool_t mask32L = svunpklo_b(mask16); + const svbool_t mask32H = svunpkhi_b(mask16); + + const auto compressedL = Compress(v32L, mask32L); + const auto compressedH = Compress(v32H, mask32H); + + // Demote to 16-bit (already in range) - separately so we can splice + const V evenL = BitCast(d16, compressedL); + const V evenH = BitCast(d16, compressedH); + const V v16L = detail::ConcatEvenFull(evenL, evenL); // lower half + const V v16H = detail::ConcatEvenFull(evenH, evenH); + + // We need to combine two vectors of non-constexpr length, so the only option + // is Splice, which requires us to synthesize a mask. NOTE: this function uses + // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt. + const size_t countL = detail::CountTrueFull(dw, mask32L); + const auto compressed_maskL = FirstN(d16, countL); + return detail::Splice(v16H, v16L, compressed_maskL); +} + +// Must treat float16_t as integers so we can ConcatEven. +HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) { + const DFromV df; + const RebindToSigned di; + return BitCast(df, Compress(BitCast(di, v), mask16)); +} + +// ------------------------------ CompressNot + +// 2 or 4 bytes +template +HWY_API V CompressNot(V v, const svbool_t mask) { + return Compress(v, Not(mask)); +} + +template +HWY_API V CompressNot(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE + // If mask == 01: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 + // swaps upper/lower (the lower half is set to the upper half, and the + // remaining upper half is filled from the lower half of the second v), and + // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot map + // 01 to 10, and everything else to 00. + const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane + return detail::Splice(v, v, AndNot(mask, maskLL)); +#endif +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE + const DFromV d; + const RebindToUnsigned du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // See CompressIsPartition. 
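  // With bits = (4, 8, 16, 32), offset = 4*m0 + 8*m1 + 16*m2 + 32*m3, i.e.
  // 4 * (mask as a 4-bit number), so each of the 16 possible masks selects
  // one 4-entry row of lane indices in the table below.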
+ alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintCompressNot64x4Tables + 0, 1, 2, 3, 1, 2, 3, 0, 0, 2, 3, 1, 2, 3, 0, 1, 0, 1, 3, 2, 1, 3, + 0, 2, 0, 3, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 0, 3, 0, 2, 1, 3, + 2, 0, 1, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#endif // HWY_TARGET == HWY_SVE_256 + + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API svuint64_t CompressBlocksNot(svuint64_t v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 + (void)mask; + return v; +#endif +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE + uint64_t bits = 0; // predicate reg is 32-bit + CopyBytes<4>(&mask, &bits); // not same size - 64-bit more efficient + // Concatenate LSB for upper and lower blocks, pre-scale by 4 for table idx. + const size_t offset = ((bits & 1) ? 4u : 0u) + ((bits & 0x10000) ? 8u : 0u); + // See CompressIsPartition. Manually generated; flip halves if mask = [0, 1]. + alignas(16) static constexpr uint64_t table[4 * 4] = {0, 1, 2, 3, 2, 3, 0, 1, + 0, 1, 2, 3, 0, 1, 2, 3}; + const ScalableTag d; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#endif + + return CompressNot(v, mask); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(const V v, const svbool_t mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(const V v, const svbool_t mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, mask); + const svbool_t store_mask = FirstN(d, count); + BlendedStore(Compress(v, mask), store_mask, d, unaligned); + return count; +} + +// ================================================== MASK (2) + +// ------------------------------ FindKnownLastTrue +template +HWY_API size_t FindKnownLastTrue(D d, svbool_t m) { + const RebindToUnsigned du; + return static_cast(detail::ExtractLastMatchingLaneM( + Iota(du, 0), And(m, detail::MakeMask(d)))); +} + +// ------------------------------ FindLastTrue +template +HWY_API intptr_t FindLastTrue(D d, svbool_t m) { + return AllFalse(d, m) ? intptr_t{-1} + : static_cast(FindKnownLastTrue(d, m)); +} + +// ================================================== BLOCKWISE + +// ------------------------------ CombineShiftRightBytes + +// Prevent accidentally using these for 128-bit vectors - should not be +// necessary. +#if HWY_TARGET != HWY_SVE2_128 +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. 
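// OffsetsOf128BitBlocks computes i & ~(kLanesPerBlock - 1), rounding every
// lane index down to the first lane of its 128-bit block; TableLookupBytes
// adds this to the byte indices so each block's lookup stays within its own
// block, matching x86/NEON semantics.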
+template +HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned>; + return detail::AndNotN(static_cast(LanesPerBlock(d) - 1), iota0); +} + +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint8_t idx_mod = + svdupq_n_u8(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, + 6 % kLanesPerBlock, 7 % kLanesPerBlock, 8 % kLanesPerBlock, + 9 % kLanesPerBlock, 10 % kLanesPerBlock, 11 % kLanesPerBlock, + 12 % kLanesPerBlock, 13 % kLanesPerBlock, 14 % kLanesPerBlock, + 15 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint16_t idx_mod = + svdupq_n_u16(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, + 6 % kLanesPerBlock, 7 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint32_t idx_mod = + svdupq_n_u32(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint64_t idx_mod = + svdupq_n_u64(0 % kLanesPerBlock, 1 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} + +} // namespace detail +#endif // HWY_TARGET != HWY_SVE2_128 + +template > +HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) { + const Repartition d8; + const auto hi8 = BitCast(d8, hi); + const auto lo8 = BitCast(d8, lo); +#if HWY_TARGET == HWY_SVE2_128 + return BitCast(d, detail::Ext(hi8, lo8)); +#else + const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes)); + const auto lo_down = detail::Ext(lo8, lo8); + const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); + return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); +#endif +} + +// ------------------------------ Shuffle2301 +template +HWY_API V Shuffle2301(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return Reverse2(d, v); +} + +// ------------------------------ Shuffle2103 +template +HWY_API V Shuffle2103(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8)); +} + +// ------------------------------ Shuffle0321 +template +HWY_API V Shuffle0321(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8)); +} + +// ------------------------------ Shuffle1032 +template +HWY_API V Shuffle1032(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); +} + +// ------------------------------ Shuffle01 +template +HWY_API V Shuffle01(const V v) { + const DFromV d; + const 
Repartition d8; + static_assert(sizeof(TFromD) == 8, "Defined for 64-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); +} + +// ------------------------------ Shuffle0123 +template +HWY_API V Shuffle0123(const V v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template > +HWY_API V ReverseBlocks(D d, V v) { +#if HWY_TARGET == HWY_SVE_256 + if (detail::IsFull(d)) { + return SwapAdjacentBlocks(v); + } else if (detail::IsFull(Twice())) { + return v; + } +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + return v; +#endif + const Repartition du64; + return BitCast(d, Shuffle01(Reverse(du64, BitCast(du64, v)))); +} + +// ------------------------------ TableLookupBytes + +template +HWY_API VI TableLookupBytes(const V v, const VI idx) { + const DFromV d; + const Repartition du8; +#if HWY_TARGET == HWY_SVE2_128 + return BitCast(d, TableLookupLanes(BitCast(du8, v), BitCast(du8, idx))); +#else + const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0)); + const auto idx8 = Add(BitCast(du8, idx), offsets128); + return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8)); +#endif +} + +template +HWY_API VI TableLookupBytesOr0(const V v, const VI idx) { + const DFromV d; + // Mask size must match vector type, so cast everything to this type. + const Repartition di8; + + auto idx8 = BitCast(di8, idx); + const auto msb = detail::LtN(idx8, 0); + + const auto lookup = TableLookupBytes(BitCast(di8, v), idx8); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Broadcast + +#ifdef HWY_NATIVE_BROADCASTLANE +#undef HWY_NATIVE_BROADCASTLANE +#else +#define HWY_NATIVE_BROADCASTLANE +#endif + +namespace detail { +#define HWY_SVE_BROADCAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_INLINE HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v, kLane); \ + } + +HWY_SVE_FOREACH(HWY_SVE_BROADCAST, BroadcastLane, dup_lane) +#undef HWY_SVE_BROADCAST +} // namespace detail + +template +HWY_API V Broadcast(const V v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane"); +#if HWY_TARGET == HWY_SVE2_128 + return detail::BroadcastLane(v); +#else + auto idx = detail::OffsetsOf128BitBlocks(du, Iota(du, 0)); + if (kLane != 0) { + idx = detail::AddN(idx, kLane); + } + return TableLookupLanes(v, idx); +#endif +} + +template +HWY_API V BroadcastLane(const V v) { + static_assert(0 <= kLane && kLane < HWY_MAX_LANES_V(V), "Invalid lane"); + return detail::BroadcastLane(v); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API V ShiftLeftLanes(D d, const V v) { + const auto zero = Zero(d); + const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes)); +#if HWY_TARGET == HWY_SVE2_128 + return shifted; +#else + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + return IfThenElse(detail::FirstNPerBlock(d), zero, shifted); +#endif +} + +template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightLanes +template > +HWY_API V ShiftRightLanes(D d, V v) { + // For capped/fractional vectors, clear upper lanes so we shift in zeros. 
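  // Otherwise Ext/Splice below would pull in whatever the hardware register
  // holds beyond Lanes(d), e.g. a d capped to 4 u32 lanes on a 256-bit
  // machine leaves lanes 4..7 holding unrelated data.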
+ if (!detail::IsFull(d)) { + v = IfThenElseZero(detail::MakeMask(d), v); + } + +#if HWY_TARGET == HWY_SVE2_128 + return detail::Ext(Zero(d), v); +#else + const auto shifted = detail::Ext(v, v); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + const svbool_t mask = detail::FirstNPerBlock(d); + return IfThenElseZero(mask, shifted); +#endif +} + +// ------------------------------ ShiftLeftBytes + +template > +HWY_API V ShiftLeftBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftLeftLanes(BitCast(d8, v))); +} + +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template > +HWY_API V ShiftRightBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftRightLanes(d8, BitCast(d8, v))); +} + +// ------------------------------ ZipLower + +template >> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveLower(dn, a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(const V a, const V b) { + return BitCast(DW(), InterleaveLower(D(), a, b)); +} + +// ------------------------------ ZipUpper +template >> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveUpper(dn, a, b)); +} + +// ================================================== Ops with dependencies + +// ------------------------------ AddSub (Reverse2) + +// NOTE: svcadd_f*_x(HWY_SVE_PTRUE(BITS), a, b, 90) computes a[i] - b[i + 1] in +// the even lanes and a[i] + b[i - 1] in the odd lanes. + +#define HWY_SVE_ADDSUB_F(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + const DFromV d; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, Reverse2(d, b), \ + 90); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_ADDSUB_F, AddSub, cadd) + +#undef HWY_SVE_ADDSUB_F + +// NOTE: svcadd_s*(a, b, 90) and svcadd_u*(a, b, 90) compute a[i] - b[i + 1] in +// the even lanes and a[i] + b[i - 1] in the odd lanes. 
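// Worked example for two lanes: a = [a0 a1], b = [b0 b1]. Reverse2(b) =
// [b1 b0], and cadd #90 then yields [a0 - b0, a1 + b1], which is exactly
// AddSub's subtract-in-even / add-in-odd contract.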
+ +#if HWY_SVE_HAVE_2 +#define HWY_SVE_ADDSUB_UI(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + const DFromV d; \ + return sv##OP##_##CHAR##BITS(a, Reverse2(d, b), 90); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ADDSUB_UI, AddSub, cadd) + +#undef HWY_SVE_ADDSUB_UI + +// Disable the default implementation of AddSub in generic_ops-inl.h on SVE2 +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), \ + hwy::EnableIf()>* = nullptr + +#else // !HWY_SVE_HAVE_2 + +// Disable the default implementation of AddSub in generic_ops-inl.h for +// floating-point vectors on SVE, but enable the default implementation of +// AddSub in generic_ops-inl.h for integer vectors on SVE that do not support +// SVE2 +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ MulAddSub (AddSub) + +template , 1), HWY_IF_FLOAT_V(V)> +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + using T = TFromV; + + const DFromV d; + const T neg_zero = ConvertScalarTo(-0.0f); + + return MulAdd(mul, x, AddSub(Set(d, neg_zero), sub_or_add)); +} + +#if HWY_SVE_HAVE_2 + +// Disable the default implementation of MulAddSub in generic_ops-inl.h on SVE2 +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), \ + hwy::EnableIf()>* = nullptr + +template , 1), + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + const DFromV d; + return MulAdd(mul, x, AddSub(Zero(d), sub_or_add)); +} + +#else // !HWY_SVE_HAVE_2 + +// Disable the default implementation of MulAddSub in generic_ops-inl.h for +// floating-point vectors on SVE, but enable the default implementation of +// AddSub in generic_ops-inl.h for integer vectors on SVE targets that do not +// support SVE2 +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ PromoteTo bfloat16 (ZipLower) +template +HWY_API svfloat32_t PromoteTo(Simd df32, VBF16 v) { + const ScalableTag du16; + return BitCast(df32, detail::ZipLowerSame(svdup_n_u16(0), BitCast(du16, v))); +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo (ConcatOddFull) + +namespace detail { + +// Signed to signed PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<2> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint8_t v) { + return svextb_s16_x(detail::PTrue(d_to), BitCast(d_to, v)); +} + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint16_t v) { + return svexth_s32_x(detail::PTrue(d_to), BitCast(d_to, v)); +} + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint32_t v) { + return svextw_s64_x(detail::PTrue(d_to), BitCast(d_to, v)); +} + +// F16->F32 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat16_t v) { + const Repartition d_from; + return svcvt_f32_f16_x(detail::PTrue(d_from), v); +} + +// F32->F64 PromoteEvenTo +template +HWY_INLINE VFromD 
PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat32_t v) { + const Repartition d_from; + return svcvt_f64_f32_x(detail::PTrue(d_from), v); +} + +// I32->F64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint32_t v) { + const Repartition d_from; + return svcvt_f64_s32_x(detail::PTrue(d_from), v); +} + +// U32->F64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, + svuint32_t v) { + const Repartition d_from; + return svcvt_f64_u32_x(detail::PTrue(d_from), v); +} + +// F32->I64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat32_t v) { + const Repartition d_from; + return svcvt_s64_f32_x(detail::PTrue(d_from), v); +} + +// F32->U64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat32_t v) { + const Repartition d_from; + return svcvt_u64_f32_x(detail::PTrue(d_from), v); +} + +// F16->F32 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag to_type_tag, + hwy::SizeTag<4> to_lane_size_tag, + hwy::FloatTag from_type_tag, D d_to, + svfloat16_t v) { + return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, + DupOdd(v)); +} + +// I32/U32/F32->F64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag to_type_tag, + hwy::SizeTag<8> to_lane_size_tag, + FromTypeTag from_type_tag, D d_to, V v) { + return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, + DupOdd(v)); +} + +// F32->I64/U64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(ToTypeTag to_type_tag, + hwy::SizeTag<8> to_lane_size_tag, + hwy::FloatTag from_type_tag, D d_to, + svfloat32_t v) { + return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, + DupOdd(v)); +} + +} // namespace detail + +// ------------------------------ ReorderDemote2To (OddEven) + +template +HWY_API VBF16 ReorderDemote2To(Simd dbf16, svfloat32_t a, + svfloat32_t b) { +#if HWY_SVE_HAVE_F32_TO_BF16C + const VBF16 b_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), b); + return svcvtnt_bf16_f32_x(b_in_even, detail::PTrue(dbf16), a); +#else + (void)dbf16; + const auto a_in_odd = + BitCast(ScalableTag(), detail::RoundF32ForDemoteToBF16(a)); + const auto b_in_odd = + BitCast(ScalableTag(), detail::RoundF32ForDemoteToBF16(b)); + return BitCast(dbf16, detail::InterleaveOdd(b_in_odd, a_in_odd)); +#endif +} + +template +HWY_API svint16_t ReorderDemote2To(Simd d16, svint32_t a, + svint32_t b) { +#if HWY_SVE_HAVE_2 + (void)d16; + const svint16_t a_in_even = svqxtnb_s32(a); + return svqxtnt_s32(a_in_even, b); +#else + const svint16_t a16 = BitCast(d16, detail::SaturateI(a)); + const svint16_t b16 = BitCast(d16, detail::SaturateI(b)); + return detail::InterleaveEven(a16, b16); +#endif +} + +template +HWY_API svuint16_t ReorderDemote2To(Simd d16, svint32_t a, + svint32_t b) { +#if HWY_SVE_HAVE_2 + (void)d16; + const svuint16_t a_in_even = svqxtunb_s32(a); + return svqxtunt_s32(a_in_even, b); +#else + const Repartition du32; + const svuint32_t clamped_a = BitCast(du32, detail::MaxN(a, 0)); + const 
svuint32_t clamped_b = BitCast(du32, detail::MaxN(b, 0)); + const svuint16_t a16 = BitCast(d16, detail::SaturateU(clamped_a)); + const svuint16_t b16 = BitCast(d16, detail::SaturateU(clamped_b)); + return detail::InterleaveEven(a16, b16); +#endif +} + +template +HWY_API svuint16_t ReorderDemote2To(Simd d16, svuint32_t a, + svuint32_t b) { +#if HWY_SVE_HAVE_2 + (void)d16; + const svuint16_t a_in_even = svqxtnb_u32(a); + return svqxtnt_u32(a_in_even, b); +#else + const svuint16_t a16 = BitCast(d16, detail::SaturateU(a)); + const svuint16_t b16 = BitCast(d16, detail::SaturateU(b)); + return detail::InterleaveEven(a16, b16); +#endif +} + +template +HWY_API svint8_t ReorderDemote2To(Simd d8, svint16_t a, + svint16_t b) { +#if HWY_SVE_HAVE_2 + (void)d8; + const svint8_t a_in_even = svqxtnb_s16(a); + return svqxtnt_s16(a_in_even, b); +#else + const svint8_t a8 = BitCast(d8, detail::SaturateI(a)); + const svint8_t b8 = BitCast(d8, detail::SaturateI(b)); + return detail::InterleaveEven(a8, b8); +#endif +} + +template +HWY_API svuint8_t ReorderDemote2To(Simd d8, svint16_t a, + svint16_t b) { +#if HWY_SVE_HAVE_2 + (void)d8; + const svuint8_t a_in_even = svqxtunb_s16(a); + return svqxtunt_s16(a_in_even, b); +#else + const Repartition du16; + const svuint16_t clamped_a = BitCast(du16, detail::MaxN(a, 0)); + const svuint16_t clamped_b = BitCast(du16, detail::MaxN(b, 0)); + const svuint8_t a8 = BitCast(d8, detail::SaturateU(clamped_a)); + const svuint8_t b8 = BitCast(d8, detail::SaturateU(clamped_b)); + return detail::InterleaveEven(a8, b8); +#endif +} + +template +HWY_API svuint8_t ReorderDemote2To(Simd d8, svuint16_t a, + svuint16_t b) { +#if HWY_SVE_HAVE_2 + (void)d8; + const svuint8_t a_in_even = svqxtnb_u16(a); + return svqxtnt_u16(a_in_even, b); +#else + const svuint8_t a8 = BitCast(d8, detail::SaturateU(a)); + const svuint8_t b8 = BitCast(d8, detail::SaturateU(b)); + return detail::InterleaveEven(a8, b8); +#endif +} + +template +HWY_API svint32_t ReorderDemote2To(Simd d32, svint64_t a, + svint64_t b) { +#if HWY_SVE_HAVE_2 + (void)d32; + const svint32_t a_in_even = svqxtnb_s64(a); + return svqxtnt_s64(a_in_even, b); +#else + const svint32_t a32 = BitCast(d32, detail::SaturateI(a)); + const svint32_t b32 = BitCast(d32, detail::SaturateI(b)); + return detail::InterleaveEven(a32, b32); +#endif +} + +template +HWY_API svuint32_t ReorderDemote2To(Simd d32, svint64_t a, + svint64_t b) { +#if HWY_SVE_HAVE_2 + (void)d32; + const svuint32_t a_in_even = svqxtunb_s64(a); + return svqxtunt_s64(a_in_even, b); +#else + const Repartition du64; + const svuint64_t clamped_a = BitCast(du64, detail::MaxN(a, 0)); + const svuint64_t clamped_b = BitCast(du64, detail::MaxN(b, 0)); + const svuint32_t a32 = BitCast(d32, detail::SaturateU(clamped_a)); + const svuint32_t b32 = BitCast(d32, detail::SaturateU(clamped_b)); + return detail::InterleaveEven(a32, b32); +#endif +} + +template +HWY_API svuint32_t ReorderDemote2To(Simd d32, svuint64_t a, + svuint64_t b) { +#if HWY_SVE_HAVE_2 + (void)d32; + const svuint32_t a_in_even = svqxtnb_u64(a); + return svqxtnt_u64(a_in_even, b); +#else + const svuint32_t a32 = BitCast(d32, detail::SaturateU(a)); + const svuint32_t b32 = BitCast(d32, detail::SaturateU(b)); + return detail::InterleaveEven(a32, b32); +#endif +} + +template ) / 2)> +HWY_API VFromD ReorderDemote2To(D dn, V a, V b) { + const auto clamped_a = BitCast(dn, detail::SaturateU>(a)); + const auto clamped_b = BitCast(dn, detail::SaturateU>(b)); + return detail::InterleaveEven(clamped_a, clamped_b); +} + +template ), + 
HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2)> +HWY_API VFromD OrderedDemote2To(D dn, V a, V b) { + const Half dnh; + const auto demoted_a = DemoteTo(dnh, a); + const auto demoted_b = DemoteTo(dnh, b); + return Combine(dn, demoted_b, demoted_a); +} + +template +HWY_API VBF16 OrderedDemote2To(Simd dbf16, svfloat32_t a, + svfloat32_t b) { +#if HWY_SVE_HAVE_F32_TO_BF16C + (void)dbf16; + const VBF16 a_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), a); + const VBF16 b_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), b); + return ConcatEven(dbf16, b_in_even, a_in_even); +#else + const RebindToUnsigned du16; + const svuint16_t a_in_odd = BitCast(du16, detail::RoundF32ForDemoteToBF16(a)); + const svuint16_t b_in_odd = BitCast(du16, detail::RoundF32ForDemoteToBF16(b)); + return BitCast(dbf16, ConcatOdd(du16, b_in_odd, a_in_odd)); // lower half +#endif +} + +// ------------------------------ I8/U8/I16/U16 Div + +template +HWY_API V Div(V a, V b) { + const DFromV d; + const Half dh; + const RepartitionToWide dw; + + const auto q_lo = + Div(PromoteTo(dw, LowerHalf(dh, a)), PromoteTo(dw, LowerHalf(dh, b))); + const auto q_hi = Div(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b)); + + return OrderedDemote2To(d, q_lo, q_hi); +} + +// ------------------------------ I8/U8/I16/U16 MaskedDivOr +template +HWY_API V MaskedDivOr(V no, M m, V a, V b) { + return IfThenElse(m, Div(a, b), no); +} + +template +HWY_API V MaskedDiv(M m, V a, V b) { + return IfThenElseZero(m, Div(a, b)); +} + +// ------------------------------ Mod (Div, NegMulAdd) +template +HWY_API V Mod(V a, V b) { + return NegMulAdd(Div(a, b), b, a); +} + +// ------------------------------ MaskedModOr (Mod) +template +HWY_API V MaskedModOr(V no, M m, V a, V b) { + return IfThenElse(m, Mod(a, b), no); +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned>(), "Only works for signed/float"); + return IfThenElse(IsNegative(v), yes, no); +} +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#else +#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#endif + +#define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_m(v, IsNegative(mask), v); \ + } + +HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg) + +#undef HWY_SVE_NEG_IF + +// ------------------------------ AverageRound (ShiftRight) + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +#if HWY_SVE_HAVE_2 +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) +#else +template +HWY_API V AverageRound(const V a, const V b) { + return Sub(Or(a, b), ShiftRight<1>(Xor(a, b))); +} +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ LoadMaskBits (TestBit) + +// `p` points to at least 8 readable bytes, not all of which need be valid. 
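// A minimal caller-side sketch (illustrative only; the helper name
// MakeFirstLaneMask is not part of this patch). It shows one way to honor the
// "at least 8 readable bytes" guarantee stated above even when far fewer mask
// bits are meaningful, assuming the usual HWY_NAMESPACE context of this file.
template <class D>
svbool_t MakeFirstLaneMask(D d) {
  // 8 zero-padded bytes so the widest overload below may safely read them all.
  alignas(8) uint8_t bits[8] = {0x01};  // bit 0 set => only lane 0 active
  return LoadMaskBits(d, bits);
}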
+template +HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { +#if HWY_COMPILER_CLANG >= 1901 || HWY_COMPILER_GCC_ACTUAL >= 1200 + typedef svbool_t UnalignedSveMaskT + __attribute__((__aligned__(1), __may_alias__)); + (void)d; + return *reinterpret_cast(bits); +#else + // TODO(janwas): with SVE2.1, load to vector, then PMOV + const RebindToUnsigned du; + const svuint8_t iota = Iota(du, 0); + + // Load correct number of bytes (bits/8) with 7 zeros after each. + const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits)); + // Replicate bytes 8x such that each byte contains the bit that governs it. + const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota)); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(rep8, bit); +#endif +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + const Repartition du8; + + // There may be up to 128 bits; avoid reading past the end. + const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits); + + // Replicate bytes 16x such that each lane contains the bit that governs it. + const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0))); + + const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(BitCast(du, rep16), bit); +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + const Repartition du8; + + // Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable, + // so we can skip computing the actual length (Lanes(du)+7)/8. + const svuint8_t bytes = svld1(FirstN(du8, 8), bits); + + // Replicate bytes 32x such that each lane contains the bit that governs it. + const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0))); + + // 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 .. + const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7)); + + return TestBit(BitCast(du, rep32), bit); +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + + // Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane. + // The "at least 8 byte" guarantee in quick_reference ensures this is safe. + uint32_t mask_bits; + CopyBytes<4>(bits, &mask_bits); // copy from bytes + const auto vbits = Set(du, mask_bits); + + // 2 ^ {0,1, .., 31}, will not have more lanes than that. 
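// For example (illustration): with four u64 lanes and mask_bits = 0b1010,
// every lane of vbits holds 0b1010, `bit` below is {1, 2, 4, 8}, and TestBit
// therefore activates exactly lanes 1 and 3.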
+ const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0)); + + return TestBit(vbits, bit); +} + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + + // Replicate the lower 8 bits of mask_bits to each u8 lane + const svuint8_t bytes = BitCast(du, Set(du, static_cast(mask_bits))); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(bytes, bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du16; + + // Replicate the lower 16 bits of mask_bits to each u16 lane of a u16 vector, + // and then bitcast the replicated mask_bits to a u8 vector + const svuint8_t bytes = + BitCast(du, Set(du16, static_cast(mask_bits))); + // Replicate bytes 8x such that each byte contains the bit that governs it. + const svuint8_t rep8 = svtbl_u8(bytes, ShiftRight<3>(Iota(du, 0))); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(rep8, bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du8; + + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + + // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits + const svuint8_t bytes = Set(du8, static_cast(mask_bits)); + + const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(BitCast(du, bytes), bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du8; + + constexpr size_t kN = MaxLanes(d); + if (kN < 4) mask_bits &= (1u << kN) - 1; + + // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits + const svuint8_t bytes = Set(du8, static_cast(mask_bits)); + + const svuint32_t bit = svdupq_n_u32(1, 2, 4, 8); + return TestBit(BitCast(du, bytes), bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du8; + + if (MaxLanes(d) < 2) mask_bits &= 1u; + + // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits + const svuint8_t bytes = Set(du8, static_cast(mask_bits)); + + const svuint64_t bit = svdupq_n_u64(1, 2); + return TestBit(BitCast(du, bytes), bit); +} + +// ------------------------------ StoreMaskBits (BitsFromMask) + +// `p` points to at least 8 writable bytes. +// TODO(janwas): with SVE2.1, use PMOV to store to vector, then StoreU +template +HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) { +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 + constexpr size_t N = MaxLanes(d); + const uint64_t bits64 = BitsFromMask(d, m); + HWY_IF_CONSTEXPR(N < 8) { + // BitsFromMask guarantees upper bits are zero, hence no masking. + bits[0] = static_cast(bits64); + } + else { + static_assert(N % 8 == 0, "N is pow2 >= 8, hence divisible"); + static_assert(HWY_IS_LITTLE_ENDIAN, ""); + hwy::CopyBytes(&bits64, bits); + } + constexpr size_t num_bytes = hwy::DivCeil(N, size_t{8}); + return num_bytes; +#else + svuint64_t bits_in_u64 = detail::BitsFromBool(detail::BoolFromMask(m)); + + const size_t num_bits = Lanes(d); + const size_t num_bytes = hwy::DivCeil(num_bits, size_t{8}); + + // Truncate each u64 to 8 bits and store to u8. 
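// Note: svst1b keeps only the least-significant byte of each u64 lane, and the
// FirstN predicate limits the store to num_bytes bytes, so lane i of
// bits_in_u64 supplies byte i of the output bitfield and at most
// DivCeil(num_bits, 8) bytes are written.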
+ svst1b_u64(FirstN(ScalableTag(), num_bytes), bits, bits_in_u64); + + // Non-full byte, need to clear the undefined upper bits. Can happen for + // capped/fractional vectors or large T and small hardware vectors. + if (num_bits < 8) { + const int mask = static_cast((1ull << num_bits) - 1); + bits[0] = static_cast(bits[0] & mask); + } + // Else: we wrote full bytes because num_bits is a power of two >= 8. + + return num_bytes; +#endif // HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +} + +// ------------------------------ CompressBits (LoadMaskBits) +template +HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV(), bits)); +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ Expand (StoreMaskBits) + +#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +namespace detail { + +HWY_INLINE svuint8_t IndicesForExpandFromBits(uint64_t mask_bits) { + const CappedTag du8; + alignas(16) static constexpr uint8_t table[8 * 256] = { + // PrintExpand8x8Tables + 128, 128, 128, 128, 128, 128, 128, 128, // + 0, 128, 128, 128, 128, 128, 128, 128, // + 128, 0, 128, 128, 128, 128, 128, 128, // + 0, 1, 128, 128, 128, 128, 128, 128, // + 128, 128, 0, 128, 128, 128, 128, 128, // + 0, 128, 1, 128, 128, 128, 128, 128, // + 128, 0, 1, 128, 128, 128, 128, 128, // + 0, 1, 2, 128, 128, 128, 128, 128, // + 128, 128, 128, 0, 128, 128, 128, 128, // + 0, 128, 128, 1, 128, 128, 128, 128, // + 128, 0, 128, 1, 128, 128, 128, 128, // + 0, 1, 128, 2, 128, 128, 128, 128, // + 128, 128, 0, 1, 128, 128, 128, 128, // + 0, 128, 1, 2, 128, 128, 128, 128, // + 128, 0, 1, 2, 128, 128, 128, 128, // + 0, 1, 2, 3, 128, 128, 128, 128, // + 128, 128, 128, 128, 0, 128, 128, 128, // + 0, 128, 128, 128, 1, 128, 128, 128, // + 128, 0, 128, 128, 1, 128, 128, 128, // + 0, 1, 128, 128, 2, 128, 128, 128, // + 128, 128, 0, 128, 1, 128, 128, 128, // + 0, 128, 1, 128, 2, 128, 128, 128, // + 128, 0, 1, 128, 2, 128, 128, 128, // + 0, 1, 2, 128, 3, 128, 128, 128, // + 128, 128, 128, 0, 1, 128, 128, 128, // + 0, 128, 128, 1, 2, 128, 128, 128, // + 128, 0, 128, 1, 2, 128, 128, 128, // + 0, 1, 128, 2, 3, 128, 128, 128, // + 128, 128, 0, 1, 2, 128, 128, 128, // + 0, 128, 1, 2, 3, 128, 128, 128, // + 128, 0, 1, 2, 3, 128, 128, 128, // + 0, 1, 2, 3, 4, 128, 128, 128, // + 128, 128, 128, 128, 128, 0, 128, 128, // + 0, 128, 128, 128, 128, 1, 128, 128, // + 128, 0, 128, 128, 128, 1, 128, 128, // + 0, 1, 128, 128, 128, 2, 128, 128, // + 128, 128, 0, 128, 128, 1, 128, 128, // + 0, 128, 1, 128, 128, 2, 128, 128, // + 128, 0, 1, 128, 128, 2, 128, 128, // + 0, 1, 2, 128, 128, 3, 128, 128, // + 128, 128, 128, 0, 128, 1, 128, 128, // + 0, 128, 128, 1, 128, 2, 128, 128, // + 128, 0, 128, 1, 128, 2, 128, 128, // + 0, 1, 128, 2, 128, 3, 128, 128, // + 128, 128, 0, 1, 128, 2, 128, 128, // + 0, 128, 1, 2, 128, 3, 128, 128, // + 128, 0, 1, 2, 128, 3, 128, 128, // + 0, 1, 2, 3, 128, 4, 128, 128, // + 128, 128, 128, 128, 0, 1, 128, 128, // + 0, 128, 128, 128, 1, 2, 128, 128, // + 128, 0, 128, 128, 1, 2, 128, 128, // + 0, 1, 128, 128, 2, 3, 128, 128, // + 128, 128, 0, 128, 1, 2, 128, 128, // + 0, 128, 1, 128, 2, 3, 128, 128, // + 128, 0, 1, 128, 2, 3, 128, 128, // + 0, 1, 2, 128, 3, 4, 128, 128, // + 128, 128, 128, 0, 1, 2, 
128, 128, // + 0, 128, 128, 1, 2, 3, 128, 128, // + 128, 0, 128, 1, 2, 3, 128, 128, // + 0, 1, 128, 2, 3, 4, 128, 128, // + 128, 128, 0, 1, 2, 3, 128, 128, // + 0, 128, 1, 2, 3, 4, 128, 128, // + 128, 0, 1, 2, 3, 4, 128, 128, // + 0, 1, 2, 3, 4, 5, 128, 128, // + 128, 128, 128, 128, 128, 128, 0, 128, // + 0, 128, 128, 128, 128, 128, 1, 128, // + 128, 0, 128, 128, 128, 128, 1, 128, // + 0, 1, 128, 128, 128, 128, 2, 128, // + 128, 128, 0, 128, 128, 128, 1, 128, // + 0, 128, 1, 128, 128, 128, 2, 128, // + 128, 0, 1, 128, 128, 128, 2, 128, // + 0, 1, 2, 128, 128, 128, 3, 128, // + 128, 128, 128, 0, 128, 128, 1, 128, // + 0, 128, 128, 1, 128, 128, 2, 128, // + 128, 0, 128, 1, 128, 128, 2, 128, // + 0, 1, 128, 2, 128, 128, 3, 128, // + 128, 128, 0, 1, 128, 128, 2, 128, // + 0, 128, 1, 2, 128, 128, 3, 128, // + 128, 0, 1, 2, 128, 128, 3, 128, // + 0, 1, 2, 3, 128, 128, 4, 128, // + 128, 128, 128, 128, 0, 128, 1, 128, // + 0, 128, 128, 128, 1, 128, 2, 128, // + 128, 0, 128, 128, 1, 128, 2, 128, // + 0, 1, 128, 128, 2, 128, 3, 128, // + 128, 128, 0, 128, 1, 128, 2, 128, // + 0, 128, 1, 128, 2, 128, 3, 128, // + 128, 0, 1, 128, 2, 128, 3, 128, // + 0, 1, 2, 128, 3, 128, 4, 128, // + 128, 128, 128, 0, 1, 128, 2, 128, // + 0, 128, 128, 1, 2, 128, 3, 128, // + 128, 0, 128, 1, 2, 128, 3, 128, // + 0, 1, 128, 2, 3, 128, 4, 128, // + 128, 128, 0, 1, 2, 128, 3, 128, // + 0, 128, 1, 2, 3, 128, 4, 128, // + 128, 0, 1, 2, 3, 128, 4, 128, // + 0, 1, 2, 3, 4, 128, 5, 128, // + 128, 128, 128, 128, 128, 0, 1, 128, // + 0, 128, 128, 128, 128, 1, 2, 128, // + 128, 0, 128, 128, 128, 1, 2, 128, // + 0, 1, 128, 128, 128, 2, 3, 128, // + 128, 128, 0, 128, 128, 1, 2, 128, // + 0, 128, 1, 128, 128, 2, 3, 128, // + 128, 0, 1, 128, 128, 2, 3, 128, // + 0, 1, 2, 128, 128, 3, 4, 128, // + 128, 128, 128, 0, 128, 1, 2, 128, // + 0, 128, 128, 1, 128, 2, 3, 128, // + 128, 0, 128, 1, 128, 2, 3, 128, // + 0, 1, 128, 2, 128, 3, 4, 128, // + 128, 128, 0, 1, 128, 2, 3, 128, // + 0, 128, 1, 2, 128, 3, 4, 128, // + 128, 0, 1, 2, 128, 3, 4, 128, // + 0, 1, 2, 3, 128, 4, 5, 128, // + 128, 128, 128, 128, 0, 1, 2, 128, // + 0, 128, 128, 128, 1, 2, 3, 128, // + 128, 0, 128, 128, 1, 2, 3, 128, // + 0, 1, 128, 128, 2, 3, 4, 128, // + 128, 128, 0, 128, 1, 2, 3, 128, // + 0, 128, 1, 128, 2, 3, 4, 128, // + 128, 0, 1, 128, 2, 3, 4, 128, // + 0, 1, 2, 128, 3, 4, 5, 128, // + 128, 128, 128, 0, 1, 2, 3, 128, // + 0, 128, 128, 1, 2, 3, 4, 128, // + 128, 0, 128, 1, 2, 3, 4, 128, // + 0, 1, 128, 2, 3, 4, 5, 128, // + 128, 128, 0, 1, 2, 3, 4, 128, // + 0, 128, 1, 2, 3, 4, 5, 128, // + 128, 0, 1, 2, 3, 4, 5, 128, // + 0, 1, 2, 3, 4, 5, 6, 128, // + 128, 128, 128, 128, 128, 128, 128, 0, // + 0, 128, 128, 128, 128, 128, 128, 1, // + 128, 0, 128, 128, 128, 128, 128, 1, // + 0, 1, 128, 128, 128, 128, 128, 2, // + 128, 128, 0, 128, 128, 128, 128, 1, // + 0, 128, 1, 128, 128, 128, 128, 2, // + 128, 0, 1, 128, 128, 128, 128, 2, // + 0, 1, 2, 128, 128, 128, 128, 3, // + 128, 128, 128, 0, 128, 128, 128, 1, // + 0, 128, 128, 1, 128, 128, 128, 2, // + 128, 0, 128, 1, 128, 128, 128, 2, // + 0, 1, 128, 2, 128, 128, 128, 3, // + 128, 128, 0, 1, 128, 128, 128, 2, // + 0, 128, 1, 2, 128, 128, 128, 3, // + 128, 0, 1, 2, 128, 128, 128, 3, // + 0, 1, 2, 3, 128, 128, 128, 4, // + 128, 128, 128, 128, 0, 128, 128, 1, // + 0, 128, 128, 128, 1, 128, 128, 2, // + 128, 0, 128, 128, 1, 128, 128, 2, // + 0, 1, 128, 128, 2, 128, 128, 3, // + 128, 128, 0, 128, 1, 128, 128, 2, // + 0, 128, 1, 128, 2, 128, 128, 3, // + 128, 0, 1, 128, 2, 128, 128, 3, // + 0, 1, 2, 128, 3, 128, 128, 4, 
// + 128, 128, 128, 0, 1, 128, 128, 2, // + 0, 128, 128, 1, 2, 128, 128, 3, // + 128, 0, 128, 1, 2, 128, 128, 3, // + 0, 1, 128, 2, 3, 128, 128, 4, // + 128, 128, 0, 1, 2, 128, 128, 3, // + 0, 128, 1, 2, 3, 128, 128, 4, // + 128, 0, 1, 2, 3, 128, 128, 4, // + 0, 1, 2, 3, 4, 128, 128, 5, // + 128, 128, 128, 128, 128, 0, 128, 1, // + 0, 128, 128, 128, 128, 1, 128, 2, // + 128, 0, 128, 128, 128, 1, 128, 2, // + 0, 1, 128, 128, 128, 2, 128, 3, // + 128, 128, 0, 128, 128, 1, 128, 2, // + 0, 128, 1, 128, 128, 2, 128, 3, // + 128, 0, 1, 128, 128, 2, 128, 3, // + 0, 1, 2, 128, 128, 3, 128, 4, // + 128, 128, 128, 0, 128, 1, 128, 2, // + 0, 128, 128, 1, 128, 2, 128, 3, // + 128, 0, 128, 1, 128, 2, 128, 3, // + 0, 1, 128, 2, 128, 3, 128, 4, // + 128, 128, 0, 1, 128, 2, 128, 3, // + 0, 128, 1, 2, 128, 3, 128, 4, // + 128, 0, 1, 2, 128, 3, 128, 4, // + 0, 1, 2, 3, 128, 4, 128, 5, // + 128, 128, 128, 128, 0, 1, 128, 2, // + 0, 128, 128, 128, 1, 2, 128, 3, // + 128, 0, 128, 128, 1, 2, 128, 3, // + 0, 1, 128, 128, 2, 3, 128, 4, // + 128, 128, 0, 128, 1, 2, 128, 3, // + 0, 128, 1, 128, 2, 3, 128, 4, // + 128, 0, 1, 128, 2, 3, 128, 4, // + 0, 1, 2, 128, 3, 4, 128, 5, // + 128, 128, 128, 0, 1, 2, 128, 3, // + 0, 128, 128, 1, 2, 3, 128, 4, // + 128, 0, 128, 1, 2, 3, 128, 4, // + 0, 1, 128, 2, 3, 4, 128, 5, // + 128, 128, 0, 1, 2, 3, 128, 4, // + 0, 128, 1, 2, 3, 4, 128, 5, // + 128, 0, 1, 2, 3, 4, 128, 5, // + 0, 1, 2, 3, 4, 5, 128, 6, // + 128, 128, 128, 128, 128, 128, 0, 1, // + 0, 128, 128, 128, 128, 128, 1, 2, // + 128, 0, 128, 128, 128, 128, 1, 2, // + 0, 1, 128, 128, 128, 128, 2, 3, // + 128, 128, 0, 128, 128, 128, 1, 2, // + 0, 128, 1, 128, 128, 128, 2, 3, // + 128, 0, 1, 128, 128, 128, 2, 3, // + 0, 1, 2, 128, 128, 128, 3, 4, // + 128, 128, 128, 0, 128, 128, 1, 2, // + 0, 128, 128, 1, 128, 128, 2, 3, // + 128, 0, 128, 1, 128, 128, 2, 3, // + 0, 1, 128, 2, 128, 128, 3, 4, // + 128, 128, 0, 1, 128, 128, 2, 3, // + 0, 128, 1, 2, 128, 128, 3, 4, // + 128, 0, 1, 2, 128, 128, 3, 4, // + 0, 1, 2, 3, 128, 128, 4, 5, // + 128, 128, 128, 128, 0, 128, 1, 2, // + 0, 128, 128, 128, 1, 128, 2, 3, // + 128, 0, 128, 128, 1, 128, 2, 3, // + 0, 1, 128, 128, 2, 128, 3, 4, // + 128, 128, 0, 128, 1, 128, 2, 3, // + 0, 128, 1, 128, 2, 128, 3, 4, // + 128, 0, 1, 128, 2, 128, 3, 4, // + 0, 1, 2, 128, 3, 128, 4, 5, // + 128, 128, 128, 0, 1, 128, 2, 3, // + 0, 128, 128, 1, 2, 128, 3, 4, // + 128, 0, 128, 1, 2, 128, 3, 4, // + 0, 1, 128, 2, 3, 128, 4, 5, // + 128, 128, 0, 1, 2, 128, 3, 4, // + 0, 128, 1, 2, 3, 128, 4, 5, // + 128, 0, 1, 2, 3, 128, 4, 5, // + 0, 1, 2, 3, 4, 128, 5, 6, // + 128, 128, 128, 128, 128, 0, 1, 2, // + 0, 128, 128, 128, 128, 1, 2, 3, // + 128, 0, 128, 128, 128, 1, 2, 3, // + 0, 1, 128, 128, 128, 2, 3, 4, // + 128, 128, 0, 128, 128, 1, 2, 3, // + 0, 128, 1, 128, 128, 2, 3, 4, // + 128, 0, 1, 128, 128, 2, 3, 4, // + 0, 1, 2, 128, 128, 3, 4, 5, // + 128, 128, 128, 0, 128, 1, 2, 3, // + 0, 128, 128, 1, 128, 2, 3, 4, // + 128, 0, 128, 1, 128, 2, 3, 4, // + 0, 1, 128, 2, 128, 3, 4, 5, // + 128, 128, 0, 1, 128, 2, 3, 4, // + 0, 128, 1, 2, 128, 3, 4, 5, // + 128, 0, 1, 2, 128, 3, 4, 5, // + 0, 1, 2, 3, 128, 4, 5, 6, // + 128, 128, 128, 128, 0, 1, 2, 3, // + 0, 128, 128, 128, 1, 2, 3, 4, // + 128, 0, 128, 128, 1, 2, 3, 4, // + 0, 1, 128, 128, 2, 3, 4, 5, // + 128, 128, 0, 128, 1, 2, 3, 4, // + 0, 128, 1, 128, 2, 3, 4, 5, // + 128, 0, 1, 128, 2, 3, 4, 5, // + 0, 1, 2, 128, 3, 4, 5, 6, // + 128, 128, 128, 0, 1, 2, 3, 4, // + 0, 128, 128, 1, 2, 3, 4, 5, // + 128, 0, 128, 1, 2, 3, 4, 5, // + 0, 1, 128, 2, 3, 4, 5, 
6, // + 128, 128, 0, 1, 2, 3, 4, 5, // + 0, 128, 1, 2, 3, 4, 5, 6, // + 128, 0, 1, 2, 3, 4, 5, 6, // + 0, 1, 2, 3, 4, 5, 6, 7}; + return Load(du8, table + mask_bits * 8); +} + +template +HWY_INLINE svuint8_t LaneIndicesFromByteIndices(D, svuint8_t idx) { + return idx; +} +template , HWY_IF_NOT_T_SIZE_D(D, 1)> +HWY_INLINE VFromD LaneIndicesFromByteIndices(D, svuint8_t idx) { + return PromoteTo(DU(), idx); +} + +// General case when we don't know the vector size, 8 elements at a time. +template +HWY_INLINE V ExpandLoop(V v, svbool_t mask) { + const DFromV d; + using T = TFromV; + uint8_t mask_bytes[256 / 8]; + StoreMaskBits(d, mask, mask_bytes); + + // ShiftLeftLanes is expensive, so we're probably better off storing to memory + // and loading the final result. + alignas(16) T out[2 * MaxLanes(d)]; + + svbool_t next = svpfalse_b(); + size_t input_consumed = 0; + const V iota = Iota(d, 0); + for (size_t i = 0; i < Lanes(d); i += 8) { + uint64_t mask_bits = mask_bytes[i / 8]; + + // We want to skip past the v lanes already consumed. There is no + // instruction for variable-shift-reg, but we can splice. + const V vH = detail::Splice(v, v, next); + input_consumed += PopCount(mask_bits); + next = detail::GeN(iota, ConvertScalarTo(input_consumed)); + + const auto idx = detail::LaneIndicesFromByteIndices( + d, detail::IndicesForExpandFromBits(mask_bits)); + const V expand = TableLookupLanes(vH, idx); + StoreU(expand, d, out + i); + } + return LoadU(d, out); +} + +} // namespace detail + +template +HWY_API V Expand(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE + const DFromV d; + uint8_t mask_bytes[256 / 8]; + StoreMaskBits(d, mask, mask_bytes); + const uint64_t maskL = mask_bytes[0]; + const uint64_t maskH = mask_bytes[1]; + + // We want to skip past the v bytes already consumed by expandL. There is no + // instruction for shift-reg by variable bytes, but we can splice. Instead of + // GeN, Not(FirstN()) would also work. + using T = TFromV; + const T countL = static_cast(PopCount(maskL)); + const V vH = detail::Splice(v, v, detail::GeN(Iota(d, 0), countL)); + + const svuint8_t idxL = detail::IndicesForExpandFromBits(maskL); + const svuint8_t idxH = detail::IndicesForExpandFromBits(maskH); + return Combine(d, TableLookupLanes(vH, idxH), TableLookupLanes(v, idxL)); +#else + return detail::ExpandLoop(v, mask); +#endif +} + +template +HWY_API V Expand(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE // 16x8 + const DFromV d; + const RebindToUnsigned du16; + const Rebind du8; + // Convert mask into bitfield via horizontal sum (faster than ORV) of 8 bits. + // Pre-multiply by N so we can use it as an offset for Load. + const svuint16_t bits = Shl(Set(du16, 1), Iota(du16, 3)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // Storing as 8-bit reduces table size from 4 KiB to 2 KiB. We cannot apply + // the nibble trick used below because not all indices fit within one lane. 
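// The offset computed above selects an 8-byte row of the table below; for
// example, if only mask lanes 0 and 2 are active, the masked sum is
// 8 + 32 = 40 = 8 * 0b101, i.e. the byte offset of row 0b101.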
+ alignas(16) static constexpr uint8_t table[8 * 256] = { + // PrintExpand16x8LaneTables + 255, 255, 255, 255, 255, 255, 255, 255, // + 0, 255, 255, 255, 255, 255, 255, 255, // + 255, 0, 255, 255, 255, 255, 255, 255, // + 0, 1, 255, 255, 255, 255, 255, 255, // + 255, 255, 0, 255, 255, 255, 255, 255, // + 0, 255, 1, 255, 255, 255, 255, 255, // + 255, 0, 1, 255, 255, 255, 255, 255, // + 0, 1, 2, 255, 255, 255, 255, 255, // + 255, 255, 255, 0, 255, 255, 255, 255, // + 0, 255, 255, 1, 255, 255, 255, 255, // + 255, 0, 255, 1, 255, 255, 255, 255, // + 0, 1, 255, 2, 255, 255, 255, 255, // + 255, 255, 0, 1, 255, 255, 255, 255, // + 0, 255, 1, 2, 255, 255, 255, 255, // + 255, 0, 1, 2, 255, 255, 255, 255, // + 0, 1, 2, 3, 255, 255, 255, 255, // + 255, 255, 255, 255, 0, 255, 255, 255, // + 0, 255, 255, 255, 1, 255, 255, 255, // + 255, 0, 255, 255, 1, 255, 255, 255, // + 0, 1, 255, 255, 2, 255, 255, 255, // + 255, 255, 0, 255, 1, 255, 255, 255, // + 0, 255, 1, 255, 2, 255, 255, 255, // + 255, 0, 1, 255, 2, 255, 255, 255, // + 0, 1, 2, 255, 3, 255, 255, 255, // + 255, 255, 255, 0, 1, 255, 255, 255, // + 0, 255, 255, 1, 2, 255, 255, 255, // + 255, 0, 255, 1, 2, 255, 255, 255, // + 0, 1, 255, 2, 3, 255, 255, 255, // + 255, 255, 0, 1, 2, 255, 255, 255, // + 0, 255, 1, 2, 3, 255, 255, 255, // + 255, 0, 1, 2, 3, 255, 255, 255, // + 0, 1, 2, 3, 4, 255, 255, 255, // + 255, 255, 255, 255, 255, 0, 255, 255, // + 0, 255, 255, 255, 255, 1, 255, 255, // + 255, 0, 255, 255, 255, 1, 255, 255, // + 0, 1, 255, 255, 255, 2, 255, 255, // + 255, 255, 0, 255, 255, 1, 255, 255, // + 0, 255, 1, 255, 255, 2, 255, 255, // + 255, 0, 1, 255, 255, 2, 255, 255, // + 0, 1, 2, 255, 255, 3, 255, 255, // + 255, 255, 255, 0, 255, 1, 255, 255, // + 0, 255, 255, 1, 255, 2, 255, 255, // + 255, 0, 255, 1, 255, 2, 255, 255, // + 0, 1, 255, 2, 255, 3, 255, 255, // + 255, 255, 0, 1, 255, 2, 255, 255, // + 0, 255, 1, 2, 255, 3, 255, 255, // + 255, 0, 1, 2, 255, 3, 255, 255, // + 0, 1, 2, 3, 255, 4, 255, 255, // + 255, 255, 255, 255, 0, 1, 255, 255, // + 0, 255, 255, 255, 1, 2, 255, 255, // + 255, 0, 255, 255, 1, 2, 255, 255, // + 0, 1, 255, 255, 2, 3, 255, 255, // + 255, 255, 0, 255, 1, 2, 255, 255, // + 0, 255, 1, 255, 2, 3, 255, 255, // + 255, 0, 1, 255, 2, 3, 255, 255, // + 0, 1, 2, 255, 3, 4, 255, 255, // + 255, 255, 255, 0, 1, 2, 255, 255, // + 0, 255, 255, 1, 2, 3, 255, 255, // + 255, 0, 255, 1, 2, 3, 255, 255, // + 0, 1, 255, 2, 3, 4, 255, 255, // + 255, 255, 0, 1, 2, 3, 255, 255, // + 0, 255, 1, 2, 3, 4, 255, 255, // + 255, 0, 1, 2, 3, 4, 255, 255, // + 0, 1, 2, 3, 4, 5, 255, 255, // + 255, 255, 255, 255, 255, 255, 0, 255, // + 0, 255, 255, 255, 255, 255, 1, 255, // + 255, 0, 255, 255, 255, 255, 1, 255, // + 0, 1, 255, 255, 255, 255, 2, 255, // + 255, 255, 0, 255, 255, 255, 1, 255, // + 0, 255, 1, 255, 255, 255, 2, 255, // + 255, 0, 1, 255, 255, 255, 2, 255, // + 0, 1, 2, 255, 255, 255, 3, 255, // + 255, 255, 255, 0, 255, 255, 1, 255, // + 0, 255, 255, 1, 255, 255, 2, 255, // + 255, 0, 255, 1, 255, 255, 2, 255, // + 0, 1, 255, 2, 255, 255, 3, 255, // + 255, 255, 0, 1, 255, 255, 2, 255, // + 0, 255, 1, 2, 255, 255, 3, 255, // + 255, 0, 1, 2, 255, 255, 3, 255, // + 0, 1, 2, 3, 255, 255, 4, 255, // + 255, 255, 255, 255, 0, 255, 1, 255, // + 0, 255, 255, 255, 1, 255, 2, 255, // + 255, 0, 255, 255, 1, 255, 2, 255, // + 0, 1, 255, 255, 2, 255, 3, 255, // + 255, 255, 0, 255, 1, 255, 2, 255, // + 0, 255, 1, 255, 2, 255, 3, 255, // + 255, 0, 1, 255, 2, 255, 3, 255, // + 0, 1, 2, 255, 3, 255, 4, 255, // + 255, 255, 255, 0, 1, 255, 2, 255, // + 
0, 255, 255, 1, 2, 255, 3, 255, // + 255, 0, 255, 1, 2, 255, 3, 255, // + 0, 1, 255, 2, 3, 255, 4, 255, // + 255, 255, 0, 1, 2, 255, 3, 255, // + 0, 255, 1, 2, 3, 255, 4, 255, // + 255, 0, 1, 2, 3, 255, 4, 255, // + 0, 1, 2, 3, 4, 255, 5, 255, // + 255, 255, 255, 255, 255, 0, 1, 255, // + 0, 255, 255, 255, 255, 1, 2, 255, // + 255, 0, 255, 255, 255, 1, 2, 255, // + 0, 1, 255, 255, 255, 2, 3, 255, // + 255, 255, 0, 255, 255, 1, 2, 255, // + 0, 255, 1, 255, 255, 2, 3, 255, // + 255, 0, 1, 255, 255, 2, 3, 255, // + 0, 1, 2, 255, 255, 3, 4, 255, // + 255, 255, 255, 0, 255, 1, 2, 255, // + 0, 255, 255, 1, 255, 2, 3, 255, // + 255, 0, 255, 1, 255, 2, 3, 255, // + 0, 1, 255, 2, 255, 3, 4, 255, // + 255, 255, 0, 1, 255, 2, 3, 255, // + 0, 255, 1, 2, 255, 3, 4, 255, // + 255, 0, 1, 2, 255, 3, 4, 255, // + 0, 1, 2, 3, 255, 4, 5, 255, // + 255, 255, 255, 255, 0, 1, 2, 255, // + 0, 255, 255, 255, 1, 2, 3, 255, // + 255, 0, 255, 255, 1, 2, 3, 255, // + 0, 1, 255, 255, 2, 3, 4, 255, // + 255, 255, 0, 255, 1, 2, 3, 255, // + 0, 255, 1, 255, 2, 3, 4, 255, // + 255, 0, 1, 255, 2, 3, 4, 255, // + 0, 1, 2, 255, 3, 4, 5, 255, // + 255, 255, 255, 0, 1, 2, 3, 255, // + 0, 255, 255, 1, 2, 3, 4, 255, // + 255, 0, 255, 1, 2, 3, 4, 255, // + 0, 1, 255, 2, 3, 4, 5, 255, // + 255, 255, 0, 1, 2, 3, 4, 255, // + 0, 255, 1, 2, 3, 4, 5, 255, // + 255, 0, 1, 2, 3, 4, 5, 255, // + 0, 1, 2, 3, 4, 5, 6, 255, // + 255, 255, 255, 255, 255, 255, 255, 0, // + 0, 255, 255, 255, 255, 255, 255, 1, // + 255, 0, 255, 255, 255, 255, 255, 1, // + 0, 1, 255, 255, 255, 255, 255, 2, // + 255, 255, 0, 255, 255, 255, 255, 1, // + 0, 255, 1, 255, 255, 255, 255, 2, // + 255, 0, 1, 255, 255, 255, 255, 2, // + 0, 1, 2, 255, 255, 255, 255, 3, // + 255, 255, 255, 0, 255, 255, 255, 1, // + 0, 255, 255, 1, 255, 255, 255, 2, // + 255, 0, 255, 1, 255, 255, 255, 2, // + 0, 1, 255, 2, 255, 255, 255, 3, // + 255, 255, 0, 1, 255, 255, 255, 2, // + 0, 255, 1, 2, 255, 255, 255, 3, // + 255, 0, 1, 2, 255, 255, 255, 3, // + 0, 1, 2, 3, 255, 255, 255, 4, // + 255, 255, 255, 255, 0, 255, 255, 1, // + 0, 255, 255, 255, 1, 255, 255, 2, // + 255, 0, 255, 255, 1, 255, 255, 2, // + 0, 1, 255, 255, 2, 255, 255, 3, // + 255, 255, 0, 255, 1, 255, 255, 2, // + 0, 255, 1, 255, 2, 255, 255, 3, // + 255, 0, 1, 255, 2, 255, 255, 3, // + 0, 1, 2, 255, 3, 255, 255, 4, // + 255, 255, 255, 0, 1, 255, 255, 2, // + 0, 255, 255, 1, 2, 255, 255, 3, // + 255, 0, 255, 1, 2, 255, 255, 3, // + 0, 1, 255, 2, 3, 255, 255, 4, // + 255, 255, 0, 1, 2, 255, 255, 3, // + 0, 255, 1, 2, 3, 255, 255, 4, // + 255, 0, 1, 2, 3, 255, 255, 4, // + 0, 1, 2, 3, 4, 255, 255, 5, // + 255, 255, 255, 255, 255, 0, 255, 1, // + 0, 255, 255, 255, 255, 1, 255, 2, // + 255, 0, 255, 255, 255, 1, 255, 2, // + 0, 1, 255, 255, 255, 2, 255, 3, // + 255, 255, 0, 255, 255, 1, 255, 2, // + 0, 255, 1, 255, 255, 2, 255, 3, // + 255, 0, 1, 255, 255, 2, 255, 3, // + 0, 1, 2, 255, 255, 3, 255, 4, // + 255, 255, 255, 0, 255, 1, 255, 2, // + 0, 255, 255, 1, 255, 2, 255, 3, // + 255, 0, 255, 1, 255, 2, 255, 3, // + 0, 1, 255, 2, 255, 3, 255, 4, // + 255, 255, 0, 1, 255, 2, 255, 3, // + 0, 255, 1, 2, 255, 3, 255, 4, // + 255, 0, 1, 2, 255, 3, 255, 4, // + 0, 1, 2, 3, 255, 4, 255, 5, // + 255, 255, 255, 255, 0, 1, 255, 2, // + 0, 255, 255, 255, 1, 2, 255, 3, // + 255, 0, 255, 255, 1, 2, 255, 3, // + 0, 1, 255, 255, 2, 3, 255, 4, // + 255, 255, 0, 255, 1, 2, 255, 3, // + 0, 255, 1, 255, 2, 3, 255, 4, // + 255, 0, 1, 255, 2, 3, 255, 4, // + 0, 1, 2, 255, 3, 4, 255, 5, // + 255, 255, 255, 0, 1, 2, 255, 3, // + 0, 255, 255, 1, 2, 
3, 255, 4, // + 255, 0, 255, 1, 2, 3, 255, 4, // + 0, 1, 255, 2, 3, 4, 255, 5, // + 255, 255, 0, 1, 2, 3, 255, 4, // + 0, 255, 1, 2, 3, 4, 255, 5, // + 255, 0, 1, 2, 3, 4, 255, 5, // + 0, 1, 2, 3, 4, 5, 255, 6, // + 255, 255, 255, 255, 255, 255, 0, 1, // + 0, 255, 255, 255, 255, 255, 1, 2, // + 255, 0, 255, 255, 255, 255, 1, 2, // + 0, 1, 255, 255, 255, 255, 2, 3, // + 255, 255, 0, 255, 255, 255, 1, 2, // + 0, 255, 1, 255, 255, 255, 2, 3, // + 255, 0, 1, 255, 255, 255, 2, 3, // + 0, 1, 2, 255, 255, 255, 3, 4, // + 255, 255, 255, 0, 255, 255, 1, 2, // + 0, 255, 255, 1, 255, 255, 2, 3, // + 255, 0, 255, 1, 255, 255, 2, 3, // + 0, 1, 255, 2, 255, 255, 3, 4, // + 255, 255, 0, 1, 255, 255, 2, 3, // + 0, 255, 1, 2, 255, 255, 3, 4, // + 255, 0, 1, 2, 255, 255, 3, 4, // + 0, 1, 2, 3, 255, 255, 4, 5, // + 255, 255, 255, 255, 0, 255, 1, 2, // + 0, 255, 255, 255, 1, 255, 2, 3, // + 255, 0, 255, 255, 1, 255, 2, 3, // + 0, 1, 255, 255, 2, 255, 3, 4, // + 255, 255, 0, 255, 1, 255, 2, 3, // + 0, 255, 1, 255, 2, 255, 3, 4, // + 255, 0, 1, 255, 2, 255, 3, 4, // + 0, 1, 2, 255, 3, 255, 4, 5, // + 255, 255, 255, 0, 1, 255, 2, 3, // + 0, 255, 255, 1, 2, 255, 3, 4, // + 255, 0, 255, 1, 2, 255, 3, 4, // + 0, 1, 255, 2, 3, 255, 4, 5, // + 255, 255, 0, 1, 2, 255, 3, 4, // + 0, 255, 1, 2, 3, 255, 4, 5, // + 255, 0, 1, 2, 3, 255, 4, 5, // + 0, 1, 2, 3, 4, 255, 5, 6, // + 255, 255, 255, 255, 255, 0, 1, 2, // + 0, 255, 255, 255, 255, 1, 2, 3, // + 255, 0, 255, 255, 255, 1, 2, 3, // + 0, 1, 255, 255, 255, 2, 3, 4, // + 255, 255, 0, 255, 255, 1, 2, 3, // + 0, 255, 1, 255, 255, 2, 3, 4, // + 255, 0, 1, 255, 255, 2, 3, 4, // + 0, 1, 2, 255, 255, 3, 4, 5, // + 255, 255, 255, 0, 255, 1, 2, 3, // + 0, 255, 255, 1, 255, 2, 3, 4, // + 255, 0, 255, 1, 255, 2, 3, 4, // + 0, 1, 255, 2, 255, 3, 4, 5, // + 255, 255, 0, 1, 255, 2, 3, 4, // + 0, 255, 1, 2, 255, 3, 4, 5, // + 255, 0, 1, 2, 255, 3, 4, 5, // + 0, 1, 2, 3, 255, 4, 5, 6, // + 255, 255, 255, 255, 0, 1, 2, 3, // + 0, 255, 255, 255, 1, 2, 3, 4, // + 255, 0, 255, 255, 1, 2, 3, 4, // + 0, 1, 255, 255, 2, 3, 4, 5, // + 255, 255, 0, 255, 1, 2, 3, 4, // + 0, 255, 1, 255, 2, 3, 4, 5, // + 255, 0, 1, 255, 2, 3, 4, 5, // + 0, 1, 2, 255, 3, 4, 5, 6, // + 255, 255, 255, 0, 1, 2, 3, 4, // + 0, 255, 255, 1, 2, 3, 4, 5, // + 255, 0, 255, 1, 2, 3, 4, 5, // + 0, 1, 255, 2, 3, 4, 5, 6, // + 255, 255, 0, 1, 2, 3, 4, 5, // + 0, 255, 1, 2, 3, 4, 5, 6, // + 255, 0, 1, 2, 3, 4, 5, 6, // + 0, 1, 2, 3, 4, 5, 6, 7}; + const svuint16_t indices = PromoteTo(du16, Load(du8, table + offset)); + return TableLookupLanes(v, indices); // already zeros mask=false lanes +#else + return detail::ExpandLoop(v, mask); +#endif +} + +template +HWY_API V Expand(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE // 32x8 + const DFromV d; + const RebindToUnsigned du32; + // Convert mask into bitfield via horizontal sum (faster than ORV). + const svuint32_t bits = Shl(Set(du32, 1), Iota(du32, 0)); + const size_t code = detail::SumOfLanesM(mask, bits); + + alignas(16) constexpr uint32_t packed_array[256] = { + // PrintExpand32x8. 
+ 0xffffffff, 0xfffffff0, 0xffffff0f, 0xffffff10, 0xfffff0ff, 0xfffff1f0, + 0xfffff10f, 0xfffff210, 0xffff0fff, 0xffff1ff0, 0xffff1f0f, 0xffff2f10, + 0xffff10ff, 0xffff21f0, 0xffff210f, 0xffff3210, 0xfff0ffff, 0xfff1fff0, + 0xfff1ff0f, 0xfff2ff10, 0xfff1f0ff, 0xfff2f1f0, 0xfff2f10f, 0xfff3f210, + 0xfff10fff, 0xfff21ff0, 0xfff21f0f, 0xfff32f10, 0xfff210ff, 0xfff321f0, + 0xfff3210f, 0xfff43210, 0xff0fffff, 0xff1ffff0, 0xff1fff0f, 0xff2fff10, + 0xff1ff0ff, 0xff2ff1f0, 0xff2ff10f, 0xff3ff210, 0xff1f0fff, 0xff2f1ff0, + 0xff2f1f0f, 0xff3f2f10, 0xff2f10ff, 0xff3f21f0, 0xff3f210f, 0xff4f3210, + 0xff10ffff, 0xff21fff0, 0xff21ff0f, 0xff32ff10, 0xff21f0ff, 0xff32f1f0, + 0xff32f10f, 0xff43f210, 0xff210fff, 0xff321ff0, 0xff321f0f, 0xff432f10, + 0xff3210ff, 0xff4321f0, 0xff43210f, 0xff543210, 0xf0ffffff, 0xf1fffff0, + 0xf1ffff0f, 0xf2ffff10, 0xf1fff0ff, 0xf2fff1f0, 0xf2fff10f, 0xf3fff210, + 0xf1ff0fff, 0xf2ff1ff0, 0xf2ff1f0f, 0xf3ff2f10, 0xf2ff10ff, 0xf3ff21f0, + 0xf3ff210f, 0xf4ff3210, 0xf1f0ffff, 0xf2f1fff0, 0xf2f1ff0f, 0xf3f2ff10, + 0xf2f1f0ff, 0xf3f2f1f0, 0xf3f2f10f, 0xf4f3f210, 0xf2f10fff, 0xf3f21ff0, + 0xf3f21f0f, 0xf4f32f10, 0xf3f210ff, 0xf4f321f0, 0xf4f3210f, 0xf5f43210, + 0xf10fffff, 0xf21ffff0, 0xf21fff0f, 0xf32fff10, 0xf21ff0ff, 0xf32ff1f0, + 0xf32ff10f, 0xf43ff210, 0xf21f0fff, 0xf32f1ff0, 0xf32f1f0f, 0xf43f2f10, + 0xf32f10ff, 0xf43f21f0, 0xf43f210f, 0xf54f3210, 0xf210ffff, 0xf321fff0, + 0xf321ff0f, 0xf432ff10, 0xf321f0ff, 0xf432f1f0, 0xf432f10f, 0xf543f210, + 0xf3210fff, 0xf4321ff0, 0xf4321f0f, 0xf5432f10, 0xf43210ff, 0xf54321f0, + 0xf543210f, 0xf6543210, 0x0fffffff, 0x1ffffff0, 0x1fffff0f, 0x2fffff10, + 0x1ffff0ff, 0x2ffff1f0, 0x2ffff10f, 0x3ffff210, 0x1fff0fff, 0x2fff1ff0, + 0x2fff1f0f, 0x3fff2f10, 0x2fff10ff, 0x3fff21f0, 0x3fff210f, 0x4fff3210, + 0x1ff0ffff, 0x2ff1fff0, 0x2ff1ff0f, 0x3ff2ff10, 0x2ff1f0ff, 0x3ff2f1f0, + 0x3ff2f10f, 0x4ff3f210, 0x2ff10fff, 0x3ff21ff0, 0x3ff21f0f, 0x4ff32f10, + 0x3ff210ff, 0x4ff321f0, 0x4ff3210f, 0x5ff43210, 0x1f0fffff, 0x2f1ffff0, + 0x2f1fff0f, 0x3f2fff10, 0x2f1ff0ff, 0x3f2ff1f0, 0x3f2ff10f, 0x4f3ff210, + 0x2f1f0fff, 0x3f2f1ff0, 0x3f2f1f0f, 0x4f3f2f10, 0x3f2f10ff, 0x4f3f21f0, + 0x4f3f210f, 0x5f4f3210, 0x2f10ffff, 0x3f21fff0, 0x3f21ff0f, 0x4f32ff10, + 0x3f21f0ff, 0x4f32f1f0, 0x4f32f10f, 0x5f43f210, 0x3f210fff, 0x4f321ff0, + 0x4f321f0f, 0x5f432f10, 0x4f3210ff, 0x5f4321f0, 0x5f43210f, 0x6f543210, + 0x10ffffff, 0x21fffff0, 0x21ffff0f, 0x32ffff10, 0x21fff0ff, 0x32fff1f0, + 0x32fff10f, 0x43fff210, 0x21ff0fff, 0x32ff1ff0, 0x32ff1f0f, 0x43ff2f10, + 0x32ff10ff, 0x43ff21f0, 0x43ff210f, 0x54ff3210, 0x21f0ffff, 0x32f1fff0, + 0x32f1ff0f, 0x43f2ff10, 0x32f1f0ff, 0x43f2f1f0, 0x43f2f10f, 0x54f3f210, + 0x32f10fff, 0x43f21ff0, 0x43f21f0f, 0x54f32f10, 0x43f210ff, 0x54f321f0, + 0x54f3210f, 0x65f43210, 0x210fffff, 0x321ffff0, 0x321fff0f, 0x432fff10, + 0x321ff0ff, 0x432ff1f0, 0x432ff10f, 0x543ff210, 0x321f0fff, 0x432f1ff0, + 0x432f1f0f, 0x543f2f10, 0x432f10ff, 0x543f21f0, 0x543f210f, 0x654f3210, + 0x3210ffff, 0x4321fff0, 0x4321ff0f, 0x5432ff10, 0x4321f0ff, 0x5432f1f0, + 0x5432f10f, 0x6543f210, 0x43210fff, 0x54321ff0, 0x54321f0f, 0x65432f10, + 0x543210ff, 0x654321f0, 0x6543210f, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down and mask with 0xF because + // svtbl zeros outputs if the index is out of bounds. 
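// Worked example: code = 0x0F (lanes 0..3 active) selects packed_array[15] =
// 0xffff3210, so lanes 0..3 receive indices 0..3 while lanes 4..7 receive 0xF,
// which is out of range for 8 lanes and is therefore zeroed by svtbl.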
+ const svuint32_t packed = Set(du32, packed_array[code]); + const svuint32_t indices = detail::AndN(Shr(packed, svindex_u32(0, 4)), 0xF); + return TableLookupLanes(v, indices); // already zeros mask=false lanes +#elif HWY_TARGET == HWY_SVE2_128 // 32x4 + const DFromV d; + const RebindToUnsigned du32; + // Convert mask into bitfield via horizontal sum (faster than ORV). + const svuint32_t bits = Shl(Set(du32, 1), Iota(du32, 0)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + alignas(16) constexpr uint32_t packed_array[16] = { + // PrintExpand64x4Nibble - same for 32x4. + 0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0, + 0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10, + 0x000010ff, 0x000021f0, 0x0000210f, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down and mask with 0xF because + // svtbl zeros outputs if the index is out of bounds. + const svuint32_t packed = Set(du32, packed_array[offset]); + const svuint32_t indices = detail::AndN(Shr(packed, svindex_u32(0, 4)), 0xF); + return TableLookupLanes(v, indices); // already zeros mask=false lanes +#else + return detail::ExpandLoop(v, mask); +#endif +} + +template +HWY_API V Expand(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE // 64x4 + const DFromV d; + const RebindToUnsigned du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintExpand64x4Tables - small enough to store uncompressed. + 255, 255, 255, 255, 0, 255, 255, 255, 255, 0, 255, 255, 0, 1, 255, 255, + 255, 255, 0, 255, 0, 255, 1, 255, 255, 0, 1, 255, 0, 1, 2, 255, + 255, 255, 255, 0, 0, 255, 255, 1, 255, 0, 255, 1, 0, 1, 255, 2, + 255, 255, 0, 1, 0, 255, 1, 2, 255, 0, 1, 2, 0, 1, 2, 3}; + // This already zeros mask=false lanes. + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#elif HWY_TARGET == HWY_SVE2_128 // 64x2 + // Same as Compress, just zero out the mask=false lanes. 
+ return IfThenElseZero(mask, Compress(v, mask)); +#else + return detail::ExpandLoop(v, mask); +#endif +} + +// ------------------------------ LoadExpand + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} + +// ------------------------------ MulEven (InterleaveEven) + +#if HWY_SVE_HAVE_2 +namespace detail { +#define HWY_SVE_MUL_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, HALF) a, HWY_SVE_V(BASE, HALF) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +HWY_SVE_FOREACH_UI16(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) +HWY_SVE_FOREACH_UI32(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) +HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) +HWY_SVE_FOREACH_UI16(HWY_SVE_MUL_EVEN, MulOddNative, mullt) +HWY_SVE_FOREACH_UI32(HWY_SVE_MUL_EVEN, MulOddNative, mullt) +HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulOddNative, mullt) +#undef HWY_SVE_MUL_EVEN +} // namespace detail +#endif + +template >, + HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))> +HWY_API VFromD MulEven(const V a, const V b) { +#if HWY_SVE_HAVE_2 + return BitCast(DW(), detail::MulEvenNative(a, b)); +#else + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return BitCast(DW(), detail::InterleaveEven(lo, hi)); +#endif +} + +template >, + HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))> +HWY_API VFromD MulOdd(const V a, const V b) { +#if HWY_SVE_HAVE_2 + return BitCast(DW(), detail::MulOddNative(a, b)); +#else + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return BitCast(DW(), detail::InterleaveOdd(lo, hi)); +#endif +} + +HWY_API svint64_t MulEven(const svint64_t a, const svint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveEven(lo, hi); +} + +HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveEven(lo, hi); +} + +HWY_API svint64_t MulOdd(const svint64_t a, const svint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveOdd(lo, hi); +} + +HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveOdd(lo, hi); +} + +// ------------------------------ PairwiseAdd/PairwiseSub +#if HWY_TARGET != HWY_SCALAR +#if HWY_SVE_HAVE_2 || HWY_IDE + +#ifdef HWY_NATIVE_PAIRWISE_ADD +#undef HWY_NATIVE_PAIRWISE_ADD +#else +#define HWY_NATIVE_PAIRWISE_ADD +#endif + +namespace detail { +#define HWY_SVE_SV_PAIRWISE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, HWY_SVE_V(BASE, BITS) a, \ + HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_m(HWY_SVE_PTRUE(BITS), a, b); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SV_PAIRWISE_ADD, PairwiseAdd, addp) +#undef HWY_SVE_SV_PAIRWISE_ADD +} // namespace detail + +// Pairwise add returning interleaved output of a and b +template +HWY_API V PairwiseAdd(D d, V a, V b) { + return detail::PairwiseAdd(d, a, b); +} + +#endif // HWY_SVE_HAVE_2 +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ WidenMulPairwiseAdd + +template +HWY_API svfloat32_t WidenMulPairwiseAdd(Simd df, VBF16 a, + VBF16 b) { +#if HWY_SVE_HAVE_F32_TO_BF16C + const svfloat32_t even = svbfmlalb_f32(Zero(df), a, b); + return svbfmlalt_f32(even, a, b); +#else + 
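// Fallback: widen the even and odd bf16 lanes separately and combine with FMA,
// so f32 lane i of the result holds a[2i] * b[2i] + a[2i+1] * b[2i+1].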
return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +#endif // HWY_SVE_HAVE_BF16_FEATURE +} + +template +HWY_API svint32_t WidenMulPairwiseAdd(Simd d32, svint16_t a, + svint16_t b) { +#if HWY_SVE_HAVE_2 + (void)d32; + return svmlalt_s32(svmullb_s32(a, b), a, b); +#else + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), + Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); +#endif +} + +template +HWY_API svuint32_t WidenMulPairwiseAdd(Simd d32, + svuint16_t a, svuint16_t b) { +#if HWY_SVE_HAVE_2 + (void)d32; + return svmlalt_u32(svmullb_u32(a, b), a, b); +#else + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), + Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); +#endif +} + +// ------------------------------ SatWidenMulPairwiseAccumulate +#if HWY_SVE_HAVE_2 +#define HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) dw, HWY_SVE_V(BASE, HALF) a, \ + HWY_SVE_V(BASE, HALF) b, HWY_SVE_V(BASE, BITS) sum) { \ + auto product = svmlalt_##CHAR##BITS(svmullb_##CHAR##BITS(a, b), a, b); \ + const auto mul_overflow = IfThenElseZero( \ + Eq(product, Set(dw, LimitsMin())), Set(dw, -1)); \ + return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)), \ + Add(product, mul_overflow)); \ + } +HWY_SVE_FOREACH_UI16(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2, + SatWidenMulPairwiseAccumulate, _) +HWY_SVE_FOREACH_UI32(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2, + SatWidenMulPairwiseAccumulate, _) +HWY_SVE_FOREACH_UI64(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2, + SatWidenMulPairwiseAccumulate, _) + +#undef HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2 +#endif + +// ------------------------------ SatWidenMulAccumFixedPoint + +#if HWY_SVE_HAVE_2 + +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#endif + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return svqdmlalb_s32(sum, detail::ZipLowerSame(a, a), + detail::ZipLowerSame(b, b)); +} + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +#if HWY_SVE_HAVE_BF16_FEATURE + +// NOTE: we currently do not use SVE BFDOT for bf16 ReorderWidenMulAccumulate +// because, apparently unlike NEON, it uses round to odd unless the additional +// FEAT_EBF16 feature is available and enabled. +#ifdef HWY_NATIVE_MUL_EVEN_BF16 +#undef HWY_NATIVE_MUL_EVEN_BF16 +#else +#define HWY_NATIVE_MUL_EVEN_BF16 +#endif + +template +HWY_API svfloat32_t MulEvenAdd(Simd /* d */, VBF16 a, VBF16 b, + const svfloat32_t c) { + return svbfmlalb_f32(c, a, b); +} + +template +HWY_API svfloat32_t MulOddAdd(Simd /* d */, VBF16 a, VBF16 b, + const svfloat32_t c) { + return svbfmlalt_f32(c, a, b); +} + +#endif // HWY_SVE_HAVE_BF16_FEATURE + +template +HWY_API svint32_t ReorderWidenMulAccumulate(Simd d32, + svint16_t a, svint16_t b, + const svint32_t sum0, + svint32_t& sum1) { +#if HWY_SVE_HAVE_2 + (void)d32; + sum1 = svmlalt_s32(sum1, a, b); + return svmlalb_s32(sum0, a, b); +#else + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo by using PromoteEvenTo. 
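// Here, products of even-indexed i16 lanes accumulate into sum0 and products
// of odd-indexed lanes into sum1; after RearrangeToOddPlusEven (below) adds
// the two, lane i holds the accumulated a[2i]*b[2i] + a[2i+1]*b[2i+1].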
+ sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1); + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0); +#endif +} + +template +HWY_API svuint32_t ReorderWidenMulAccumulate(Simd d32, + svuint16_t a, svuint16_t b, + const svuint32_t sum0, + svuint32_t& sum1) { +#if HWY_SVE_HAVE_2 + (void)d32; + sum1 = svmlalt_u32(sum1, a, b); + return svmlalb_u32(sum0, a, b); +#else + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo by using PromoteEvenTo. + sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1); + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0); +#endif +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // sum0 is the sum of bottom/even lanes and sum1 of top/odd lanes. + return Add(sum0, sum1); +} + +// ------------------------------ SumOfMulQuadAccumulate + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, svint8_t a, + svint8_t b, svint32_t sum) { + return svdot_s32(sum, a, b); +} + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DU32 /*du32*/, svuint8_t a, + svuint8_t b, svuint32_t sum) { + return svdot_u32(sum, a, b); +} + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u, + svint8_t b_i, svint32_t sum) { + // TODO: use svusdot_u32 on SVE targets that require support for both SVE2 + // and SVE I8MM. + + const RebindToUnsigned du32; + const Repartition du8; + + const auto b_u = BitCast(du8, b_i); + const auto result_sum0 = svdot_u32(BitCast(du32, sum), a_u, b_u); + const auto result_sum1 = + ShiftLeft<8>(svdot_u32(Zero(du32), a_u, ShiftRight<7>(b_u))); + + return BitCast(di32, Sub(result_sum0, result_sum1)); +} + +#ifdef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI64 /*di64*/, svint16_t a, + svint16_t b, svint64_t sum) { + return svdot_s64(sum, a, b); +} + +#ifdef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DU64 /*du64*/, svuint16_t a, + svuint16_t b, svuint64_t sum) { + return svdot_u64(sum, a, b); +} + +// ------------------------------ MulComplex* / MaskedMulComplex* + +// Per-target flag to prevent generic_ops-inl.h from defining MulComplex*. 
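// As with the other per-target flags in this header, the definedness of
// HWY_NATIVE_CPLX is toggled rather than simply set: foreach_target re-includes
// the ops headers once per target, and generic_ops-inl.h compares the flag
// against HWY_TARGET_TOGGLE to decide whether its generic fallback is still
// required for the target being compiled.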
+#ifdef HWY_NATIVE_CPLX +#undef HWY_NATIVE_CPLX +#else +#define HWY_NATIVE_CPLX +#endif + +template )> +HWY_API V ComplexConj(V a) { + return OddEven(Neg(a), a); +} + +namespace detail { +#define HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, ROT) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME##ROT(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ + HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b, c, ROT); \ + } \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME##Z##ROT(svbool_t m, HWY_SVE_V(BASE, BITS) a, \ + HWY_SVE_V(BASE, BITS) b, HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS##_z(m, a, b, c, ROT); \ + } + +#define HWY_SVE_CPLX_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 0) \ + HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 90) \ + HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 180) \ + HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 270) + +// Only SVE2 has complex multiply add for integer types +// and these do not include masked variants +HWY_SVE_FOREACH_F(HWY_SVE_CPLX_FMA, ComplexMulAdd, cmla) +#undef HWY_SVE_CPLX_FMA +#undef HWY_SVE_CPLX_FMA_ROT +} // namespace detail + +template +HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) { + const V t = detail::ComplexMulAddZ0(mask, c, b, a); + return detail::ComplexMulAddZ270(mask, t, b, a); +} + +template +HWY_API V MaskedMulComplexConj(M mask, V a, V b) { + return MaskedMulComplexConjAdd(mask, a, b, Zero(DFromV())); +} + +template +HWY_API V MulComplexAdd(V a, V b, V c) { + return detail::ComplexMulAdd90(detail::ComplexMulAdd0(c, a, b), a, b); +} + +template +HWY_API V MulComplex(V a, V b) { + return MulComplexAdd(a, b, Zero(DFromV())); +} + +template +HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) { + return IfThenElse(mask, MulComplex(a, b), no); +} + +template +HWY_API V MulComplexConjAdd(V a, V b, V c) { + return detail::ComplexMulAdd270(detail::ComplexMulAdd0(c, b, a), b, a); +} + +template +HWY_API V MulComplexConj(V a, V b) { + return MulComplexConjAdd(a, b, Zero(DFromV())); +} + +// TODO SVE2 does have intrinsics for integers but not masked variants +template +HWY_API V MulComplex(V a, V b) { + // a = u + iv, b = x + iy + const auto u = DupEven(a); + const auto v = DupOdd(a); + const auto x = DupEven(b); + const auto y = DupOdd(b); + + return OddEven(MulAdd(u, y, Mul(v, x)), Sub(Mul(u, x), Mul(v, y))); +} + +template +HWY_API V MulComplexConj(V a, V b) { + // a = u + iv, b = x + iy + const auto u = DupEven(a); + const auto v = DupOdd(a); + const auto x = DupEven(b); + const auto y = DupOdd(b); + + return OddEven(Sub(Mul(v, x), Mul(u, y)), MulAdd(u, x, Mul(v, y))); +} + +template +HWY_API V MulComplexAdd(V a, V b, V c) { + return Add(MulComplex(a, b), c); +} + +template +HWY_API V MulComplexConjAdd(V a, V b, V c) { + return Add(MulComplexConj(a, b), c); +} + +template +HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) { + return IfThenElseZero(mask, MulComplexConjAdd(a, b, c)); +} + +template +HWY_API V MaskedMulComplexConj(M mask, V a, V b) { + return IfThenElseZero(mask, MulComplexConj(a, b)); +} + +template +HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) { + return IfThenElse(mask, MulComplex(a, b), no); +} + +// ------------------------------ AESRound / CLMul + +// Static dispatch with -march=armv8-a+sve2+aes, or dynamic dispatch WITHOUT a +// baseline, in which case we check for AES support at runtime. 
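// Note for the implementations below: Arm's AESE applies AddRoundKey before
// SubBytes/ShiftRows, whereas Highway's AESRound follows the x86 AESENC order
// with AddRoundKey last; hence svaese is given an all-zero key and the actual
// round_key is XOR-ed in afterwards.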
+#if defined(__ARM_FEATURE_SVE2_AES) || \ + (HWY_SVE_HAVE_2 && HWY_HAVE_RUNTIME_DISPATCH && HWY_BASELINE_SVE2 == 0) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) { + // It is not clear whether E and MC fuse like they did on NEON. + return Xor(svaesmc_u8(svaese_u8(state, svdup_n_u8(0))), round_key); +} + +HWY_API svuint8_t AESLastRound(svuint8_t state, svuint8_t round_key) { + return Xor(svaese_u8(state, svdup_n_u8(0)), round_key); +} + +HWY_API svuint8_t AESInvMixColumns(svuint8_t state) { + return svaesimc_u8(state); +} + +HWY_API svuint8_t AESRoundInv(svuint8_t state, svuint8_t round_key) { + return Xor(svaesimc_u8(svaesd_u8(state, svdup_n_u8(0))), round_key); +} + +HWY_API svuint8_t AESLastRoundInv(svuint8_t state, svuint8_t round_key) { + return Xor(svaesd_u8(state, svdup_n_u8(0)), round_key); +} + +template +HWY_API svuint8_t AESKeyGenAssist(svuint8_t v) { + alignas(16) static constexpr uint8_t kRconXorMask[16] = { + 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0}; + alignas(16) static constexpr uint8_t kRotWordShuffle[16] = { + 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12}; + const DFromV d; + const Repartition du32; + const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); + const auto sub_word_result = AESLastRound(w13, LoadDup128(d, kRconXorMask)); + return TableLookupBytes(sub_word_result, LoadDup128(d, kRotWordShuffle)); +} + +HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) { + return svpmullb_pair(a, b); +} + +HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) { + return svpmullt_pair(a, b); +} + +#endif // __ARM_FEATURE_SVE2_AES + +// ------------------------------ Lt128 + +namespace detail { +#define HWY_SVE_DUP(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, svbool_t m) { \ + return sv##OP##_b##BITS(m, m); \ + } + +HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupEvenB, trn1) // actually for bool +HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupOddB, trn2) // actually for bool +#undef HWY_SVE_DUP + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template +HWY_INLINE svuint64_t Lt128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t eqHx = Eq(a, b); // only odd lanes used + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t ltHL = VecFromMask(d, Lt(a, b)); + // Move into upper lane: ltL if the upper half is equal, otherwise ltH. + // Requires an extra IfThenElse because INSR, EXT, TRN2 are unpredicated. + const svuint64_t ltHx = IfThenElse(eqHx, DupEven(ltHL), ltHL); + // Duplicate upper lane into lower. + return DupOdd(ltHx); +} +#endif +} // namespace detail + +template +HWY_INLINE svbool_t Lt128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Lt128Vec(d, a, b)); +#else + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t eqHx = Eq(a, b); // only odd lanes used + const svbool_t ltHL = Lt(a, b); + // Move into upper lane: ltL if the upper half is equal, otherwise ltH. + const svbool_t ltHx = svsel_b(eqHx, detail::DupEvenB(d, ltHL), ltHL); + // Duplicate upper lane into lower. 
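// Each 128-bit key occupies an adjacent (low, high) pair of u64 lanes: the
// pair is "less than" iff hi_a < hi_b, or hi_a == hi_b and lo_a < lo_b. The
// DupOddB below broadcasts that per-pair verdict to both lanes so the mask can
// select whole 128-bit keys.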
+ return detail::DupOddB(d, ltHx); +#endif // HWY_TARGET != HWY_SVE_256 +} + +// ------------------------------ Lt128Upper + +template +HWY_INLINE svbool_t Lt128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t ltHL = Lt(a, b); + return detail::DupOddB(d, ltHL); +} + +// ------------------------------ Eq128, Ne128 + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +namespace detail { + +template +HWY_INLINE svuint64_t Eq128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t eqHL = VecFromMask(d, Eq(a, b)); + // Duplicate upper and lower. + const svuint64_t eqHH = DupOdd(eqHL); + const svuint64_t eqLL = DupEven(eqHL); + return And(eqLL, eqHH); +} + +template +HWY_INLINE svuint64_t Ne128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t neHL = VecFromMask(d, Ne(a, b)); + // Duplicate upper and lower. + const svuint64_t neHH = DupOdd(neHL); + const svuint64_t neLL = DupEven(neHL); + return Or(neLL, neHH); +} + +} // namespace detail +#endif + +template +HWY_INLINE svbool_t Eq128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Eq128Vec(d, a, b)); +#else + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t eqHL = Eq(a, b); + const svbool_t eqHH = detail::DupOddB(d, eqHL); + const svbool_t eqLL = detail::DupEvenB(d, eqHL); + return And(eqLL, eqHH); +#endif // HWY_TARGET != HWY_SVE_256 +} + +template +HWY_INLINE svbool_t Ne128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Ne128Vec(d, a, b)); +#else + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t neHL = Ne(a, b); + const svbool_t neHH = detail::DupOddB(d, neHL); + const svbool_t neLL = detail::DupEvenB(d, neHL); + return Or(neLL, neHH); +#endif // HWY_TARGET != HWY_SVE_256 +} + +// ------------------------------ Eq128Upper, Ne128Upper + +template +HWY_INLINE svbool_t Eq128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t eqHL = Eq(a, b); + return detail::DupOddB(d, eqHL); +} + +template +HWY_INLINE svbool_t Ne128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t neHL = Ne(a, b); + return detail::DupOddB(d, neHL); +} + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_INLINE svuint64_t Min128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +#else + return IfThenElse(Lt128(d, a, b), a, b); +#endif +} + +template +HWY_INLINE svuint64_t Max128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +#else + return IfThenElse(Lt128(d, b, a), a, b); +#endif +} + +template +HWY_INLINE svuint64_t Min128Upper(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE svuint64_t Max128Upper(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// -------------------- 
LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex + +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +#define HWY_SVE_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + const DFromV d; \ + return BitCast(d, sv##OP##_##CHAR##BITS##_x(detail::PTrue(d), v)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_LEADING_ZERO_COUNT, LeadingZeroCount, clz) +#undef HWY_SVE_LEADING_ZERO_COUNT + +template +HWY_API V TrailingZeroCount(V v) { + return LeadingZeroCount(ReverseBits(v)); +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + using T = TFromD; + return BitCast(d, Sub(Set(d, T{sizeof(T) * 8 - 1}), LeadingZeroCount(v))); +} + +#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#endif + +#define HWY_SVE_MASKED_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ + const DFromV d; \ + return BitCast(d, sv##OP##_##CHAR##BITS##_z(m, v)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT, MaskedLeadingZeroCount, + clz) +#undef HWY_SVE_LEADING_ZERO_COUNT + +// ================================================== END MACROS +#undef HWY_SVE_ALL_PTRUE +#undef HWY_SVE_D +#undef HWY_SVE_FOREACH +#undef HWY_SVE_FOREACH_BF16 +#undef HWY_SVE_FOREACH_BF16_UNCONDITIONAL +#undef HWY_SVE_FOREACH_F +#undef HWY_SVE_FOREACH_F16 +#undef HWY_SVE_FOREACH_F32 +#undef HWY_SVE_FOREACH_F3264 +#undef HWY_SVE_FOREACH_F64 +#undef HWY_SVE_FOREACH_I +#undef HWY_SVE_FOREACH_I08 +#undef HWY_SVE_FOREACH_I16 +#undef HWY_SVE_FOREACH_I32 +#undef HWY_SVE_FOREACH_I64 +#undef HWY_SVE_FOREACH_IF +#undef HWY_SVE_FOREACH_U +#undef HWY_SVE_FOREACH_U08 +#undef HWY_SVE_FOREACH_U16 +#undef HWY_SVE_FOREACH_U32 +#undef HWY_SVE_FOREACH_U64 +#undef HWY_SVE_FOREACH_UI +#undef HWY_SVE_FOREACH_UI08 +#undef HWY_SVE_FOREACH_UI16 +#undef HWY_SVE_FOREACH_UI32 +#undef HWY_SVE_FOREACH_UI64 +#undef HWY_SVE_FOREACH_UIF3264 +#undef HWY_SVE_HAVE_2 +#undef HWY_SVE_IF_EMULATED_D +#undef HWY_SVE_IF_NOT_EMULATED_D +#undef HWY_SVE_PTRUE +#undef HWY_SVE_RETV_ARGMVV +#undef HWY_SVE_RETV_ARGMVV_Z +#undef HWY_SVE_RETV_ARGMV_Z +#undef HWY_SVE_RETV_ARGMV +#undef HWY_SVE_RETV_ARGMVV_Z +#undef HWY_SVE_RETV_ARGPV +#undef HWY_SVE_RETV_ARGPVN +#undef HWY_SVE_RETV_ARGPVV +#undef HWY_SVE_RETV_ARGV +#undef HWY_SVE_RETV_ARGVN +#undef HWY_SVE_RETV_ARGMV_M +#undef HWY_SVE_RETV_ARGVV +#undef HWY_SVE_RETV_ARGVVV +#undef HWY_SVE_RETV_ARGMVVV_Z +#undef HWY_SVE_RETV_ARGMVVV +#undef HWY_SVE_T +#undef HWY_SVE_UNDEFINED +#undef HWY_SVE_V + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/aom/third_party/highway/hwy/ops/emu128-inl.h b/third_party/aom/third_party/highway/hwy/ops/emu128-inl.h new file mode 100644 index 000000000000..7d54b79f5770 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/emu128-inl.h @@ -0,0 +1,2985 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-element vectors and operations. +// External include guard in highway.h - see comment there. + +#include "third_party/highway/hwy/base.h" + +#ifndef HWY_NO_LIBCXX +#include // sqrtf +#endif + +#include "third_party/highway/hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +using Full128 = Simd; + +// (Wrapper class required for overloading comparison operators.) +template +struct Vec128 { + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + HWY_INLINE Vec128() = default; + Vec128(const Vec128&) = default; + Vec128& operator=(const Vec128&) = default; + + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + // Behave like wasm128 (vectors can always hold 128 bits). generic_ops-inl.h + // relies on this for LoadInterleaved*. CAVEAT: this method of padding + // prevents using range for, especially in SumOfLanes, where it would be + // incorrect. Moving padding to another field would require handling the case + // where N = 16 / sizeof(T) (i.e. there is no padding), which is also awkward. + T raw[16 / sizeof(T)] = {}; +}; + +// 0 or FF..FF, same size as Vec128. +template +struct Mask128 { + using Raw = hwy::MakeUnsigned; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = N; // only for DFromM + + static HWY_INLINE Raw FromBool(bool b) { + return b ? static_cast(~Raw{0}) : 0; + } + + // Must match the size of Vec128. + Raw bits[16 / sizeof(T)] = {}; +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ Zero + +// Use HWY_MAX_LANES_D here because VFromD is defined in terms of Zero. 
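// Illustrative sketch, not part of the upstream Highway sources: why the
// 16-byte padding described above rules out range-based for loops. Only the
// first N lanes belong to the vector; iterating over the whole raw[] array
// would also fold in the padding lanes. Names below are illustrative only.
#include <cstddef>

template <typename T, size_t N>
struct PaddedVec {
  T raw[16 / sizeof(T)] = {};  // always 128 bits of storage, even if N is less
};

template <typename T, size_t N>
T MinOfLanes(const PaddedVec<T, N>& v) {
  T m = v.raw[0];
  // Not `for (T x : v.raw)`: that would also consider the padding lanes.
  for (size_t i = 1; i < N; ++i) {
    if (v.raw[i] < m) m = v.raw[i];
  }
  return m;
}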
+template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + Vec128, HWY_MAX_LANES_D(D)> v; // zero-initialized + return v; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ BitCast + +template +HWY_API VFromD BitCast(D /* tag */, VFrom v) { + VFromD to; + CopySameSize(&v.raw, &to.raw); + return to; +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D d, VFrom v) { + using DFrom = DFromV; + using TFrom = TFromD; + using TTo = TFromD; + + constexpr size_t kFromByteLen = sizeof(TFrom) * HWY_MAX_LANES_D(DFrom); + constexpr size_t kToByteLen = sizeof(TTo) * HWY_MAX_LANES_D(D); + constexpr size_t kCopyByteLen = HWY_MIN(kFromByteLen, kToByteLen); + + VFromD to = Zero(d); + CopyBytes(&v.raw, &to.raw); + return to; +} + +namespace detail { + +// ResizeBitCast on the HWY_EMU128 target has zero-extending semantics if +// VFromD is a larger vector than FromV +template +HWY_INLINE VFromD ZeroExtendResizeBitCast(FromSizeTag /* from_size_tag */, + ToSizeTag /* to_size_tag */, + DTo d_to, DFrom /* d_from */, + VFromD v) { + return ResizeBitCast(d_to, v); +} + +} // namespace detail + +// ------------------------------ Set +template +HWY_API VFromD Set(D d, const T2 t) { + VFromD v; + for (size_t i = 0; i < MaxLanes(d); ++i) { + v.raw[i] = ConvertScalarTo>(t); + } + return v; +} + +// ------------------------------ Undefined +template +HWY_API VFromD Undefined(D d) { + return Zero(d); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + result.raw[2] = t2; + result.raw[3] = t3; + result.raw[4] = t4; + result.raw[5] = t5; + result.raw[6] = t6; + result.raw[7] = t7; + result.raw[8] = t8; + result.raw[9] = t9; + result.raw[10] = t10; + result.raw[11] = t11; + result.raw[12] = t12; + result.raw[13] = t13; + result.raw[14] = t14; + result.raw[15] = t15; + return result; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + result.raw[2] = t2; + result.raw[3] = t3; + result.raw[4] = t4; + result.raw[5] = t5; + result.raw[6] = t6; + result.raw[7] = t7; + return result; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + result.raw[2] = t2; + result.raw[3] = t3; + return result; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + return result; +} + +// ------------------------------ Iota + +template , typename T2> +HWY_API VFromD Iota(D d, T2 first) { + VFromD v; + for (size_t i = 0; i < MaxLanes(d); ++i) { + v.raw[i] = AddWithWraparound(static_cast(first), i); + } + return v; +} + +// ================================================== LOGICAL + +// ------------------------------ Not +template +HWY_API Vec128 Not(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + VFromD vu = BitCast(du, v); + for (size_t i = 0; i < N; ++i) { + vu.raw[i] = static_cast(~vu.raw[i]); + } + return BitCast(d, vu); +} + +// 
------------------------------ And +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] &= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator&(Vec128 a, Vec128 b) { + return And(a, b); +} + +// ------------------------------ AndNot +template +HWY_API Vec128 AndNot(Vec128 a, Vec128 b) { + return And(Not(a), b); +} + +// ------------------------------ Or +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] |= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator|(Vec128 a, Vec128 b) { + return Or(a, b); +} + +// ------------------------------ Xor +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] ^= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator^(Vec128 a, Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ Xor3 +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +// ------------------------------ CopySign +template +HWY_API Vec128 CopySign(Vec128 magn, Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API Vec128 CopySignToAbs(Vec128 abs, Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ BroadcastSignBit +template +HWY_API Vec128 BroadcastSignBit(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = ScalarShr(v.raw[i], sizeof(T) * 8 - 1); + } + return v; +} + +// ------------------------------ Mask + +// v must be 0 or FF..FF. +template +HWY_API Mask128 MaskFromVec(Vec128 v) { + Mask128 mask; + CopySameSize(&v.raw, &mask.bits); + return mask; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API MFromD RebindMask(DTo /* tag */, MFrom mask) { + MFromD to; + CopySameSize(&mask.bits, &to.bits); + return to; +} + +template +VFromD VecFromMask(D /* tag */, MFromD mask) { + VFromD v; + CopySameSize(&mask.bits, &v.raw); + return v; +} + +template +uint64_t BitsFromMask(D d, MFromD mask) { + uint64_t bits = 0; + for (size_t i = 0; i < Lanes(d); ++i) { + bits |= mask.bits[i] ? (1ull << i) : 0; + } + return bits; +} + +template +HWY_API MFromD FirstN(D d, size_t n) { + MFromD m; + for (size_t i = 0; i < MaxLanes(d); ++i) { + m.bits[i] = MFromD::FromBool(i < n); + } + return m; +} + +// Returns mask ? yes : no. 
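// Illustrative sketch, not part of the upstream Highway sources: the select
// below is the classic bitwise blend (mask & yes) | (~mask & no). It is exact
// because every mask lane is either all-zero or all-one bits (see FromBool
// above).
#include <cstdint>

static constexpr uint32_t SelectBits(uint32_t mask, uint32_t yes, uint32_t no) {
  return (mask & yes) | (~mask & no);
}
static_assert(SelectBits(0xFFFFFFFFu, 1u, 2u) == 1u, "set mask picks yes");
static_assert(SelectBits(0x00000000u, 1u, 2u) == 2u, "clear mask picks no");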
+template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const DFromV d; + return IfVecThenElse(VecFromMask(d, mask), yes, no); +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + const DFromV d; + return IfVecThenElse(VecFromMask(d, mask), yes, Zero(d)); +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + const DFromV d; + return IfVecThenElse(VecFromMask(d, mask), Zero(d), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + const DFromV d; + const RebindToSigned di; + const auto vi = BitCast(di, v); + + for (size_t i = 0; i < N; ++i) { + v.raw[i] = vi.raw[i] < 0 ? yes.raw[i] : no.raw[i]; + } + return v; +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(Mask128 m) { + const Simd d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask128 And(Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template +HWY_API Vec128 ShiftLeft(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + using TU = hwy::MakeUnsigned; + for (size_t i = 0; i < N; ++i) { + const TU raw_u = static_cast(v.raw[i]); + const auto shifted = raw_u << kBits; // separate line to avoid MSVC warning + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 ShiftRight(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). 
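  // For example, an int8_t lane holding -2 shifted right by 1 yields -1:
  // copies of the sign bit are shifted in, as on real SIMD targets.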
+ for (size_t i = 0; i < N; ++i) { + v.raw[i] = ScalarShr(v.raw[i], kBits); + } + + return v; +} + +// ------------------------------ RotateRight (ShiftRight) +template +HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(Vec128 v, int bits) { + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast>(v.raw[i]) << bits; + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, int bits) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = ScalarShr(v.raw[i], bits); + } + + return v; +} + +// ------------------------------ Shl + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast>(v.raw[i]) + << bits.raw[i]; + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = ScalarShr(v.raw[i], static_cast(bits.raw[i])); + } + + return v; +} + +// ================================================== ARITHMETIC + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Add(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + const uint64_t a64 = static_cast(a.raw[i]); + const uint64_t b64 = static_cast(b.raw[i]); + a.raw[i] = static_cast((a64 + b64) & static_cast(~T(0))); + } + return a; +} +template +HWY_INLINE Vec128 Sub(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + const uint64_t a64 = static_cast(a.raw[i]); + const uint64_t b64 = static_cast(b.raw[i]); + a.raw[i] = static_cast((a64 - b64) & static_cast(~T(0))); + } + return a; +} + +template +HWY_INLINE Vec128 Add(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] += b.raw[i]; + } + return a; +} + +template +HWY_INLINE Vec128 Sub(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] -= b.raw[i]; + } + return a; +} + +} // namespace detail + +template +HWY_API Vec128 operator-(Vec128 a, Vec128 b) { + return detail::Sub(hwy::IsFloatTag(), a, b); +} +template +HWY_API Vec128 operator+(Vec128 a, Vec128 b) { + return detail::Add(hwy::IsFloatTag(), a, b); +} + +// ------------------------------ SumsOf8 + +template +HWY_API Vec128 SumsOf8(Vec128 v) { + Vec128 sums; + for (size_t i = 0; i < N; ++i) { + sums.raw[i / 8] += v.raw[i]; + } + return sums; +} + +template +HWY_API Vec128 SumsOf8(Vec128 v) { + Vec128 sums; + for (size_t i = 0; i < N; ++i) { + sums.raw[i / 8] += v.raw[i]; + } + return sums; +} + +// ------------------------------ SaturatedAdd +template +HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { + using TW = MakeSigned>; + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(HWY_MIN( + HWY_MAX(hwy::LowestValue(), static_cast(a.raw[i]) + b.raw[i]), + hwy::HighestValue())); + } + return a; +} + +// ------------------------------ SaturatedSub +template +HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { + using TW = MakeSigned>; + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(HWY_MIN( + HWY_MAX(hwy::LowestValue(), 
static_cast(a.raw[i]) - b.raw[i]), + hwy::HighestValue())); + } + return a; +} + +// ------------------------------ AverageRound + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +template +HWY_API Vec128 AverageRound(Vec128 a, Vec128 b) { + for (size_t i = 0; i < N; ++i) { + const T a_val = a.raw[i]; + const T b_val = b.raw[i]; + a.raw[i] = static_cast((a_val | b_val) - ScalarShr(a_val ^ b_val, 1)); + } + return a; +} + +// ------------------------------ Abs + +template +HWY_API Vec128 Abs(Vec128 a) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = ScalarAbs(a.raw[i]); + } + return a; +} + +// ------------------------------ Min/Max + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Min(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); + } + return a; +} +template +HWY_INLINE Vec128 Max(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); + } + return a; +} + +template +HWY_INLINE Vec128 Min(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + if (ScalarIsNaN(a.raw[i])) { + a.raw[i] = b.raw[i]; + } else if (ScalarIsNaN(b.raw[i])) { + // no change + } else { + a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); + } + } + return a; +} +template +HWY_INLINE Vec128 Max(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + if (ScalarIsNaN(a.raw[i])) { + a.raw[i] = b.raw[i]; + } else if (ScalarIsNaN(b.raw[i])) { + // no change + } else { + a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); + } + } + return a; +} + +} // namespace detail + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return detail::Min(hwy::IsFloatTag(), a, b); +} + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return detail::Max(hwy::IsFloatTag(), a, b); +} + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_API Vec128 Neg(hwy::NonFloatTag /*tag*/, Vec128 v) { + const DFromV d; + return Zero(d) - v; +} + +template +HWY_API Vec128 Neg(hwy::FloatTag /*tag*/, Vec128 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +template +HWY_API Vec128 Neg(hwy::SpecialTag /*tag*/, Vec128 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +} // namespace detail + +template +HWY_API Vec128 Neg(Vec128 v) { + return detail::Neg(hwy::IsFloatTag(), v); +} + +// ------------------------------ Mul/Div + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Mul(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] *= b.raw[i]; + } + return a; +} + +template +HWY_INLINE Vec128 Mul(SignedTag /*tag*/, Vec128 a, Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(static_cast(a.raw[i]) * + static_cast(b.raw[i])); + } + return a; +} + +template +HWY_INLINE Vec128 Mul(UnsignedTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(static_cast(a.raw[i]) * + static_cast(b.raw[i])); + } + return a; +} + +} // namespace detail + +// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. 
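// Illustrative sketch, not part of the upstream Highway sources: why
// detail::Mul above widens integer lanes to uint64_t. Signed overflow is
// undefined behavior in C++, whereas unsigned arithmetic wraps, so the
// multiply is performed in uint64_t and the result cast back, yielding the
// usual two's-complement wraparound per lane.
#include <cstdint>

static int32_t WrapMulI32(int32_t a, int32_t b) {
  // The product wraps modulo 2^64; the final cast reduces it modulo 2^32.
  return static_cast<int32_t>(static_cast<uint64_t>(a) *
                              static_cast<uint64_t>(b));
}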
+#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return detail::Mul(hwy::TypeTag(), a, b); +} + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = (b.raw[i] == T{0}) ? 0 : a.raw[i] / b.raw[i]; + } + return a; +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + using TW = MakeWide; + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast( + (static_cast(a.raw[i]) * static_cast(b.raw[i])) >> + (sizeof(T) * 8)); + } + return a; +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Set(Full64(), hi); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi_0; + T hi_1; + + Mul128(GetLane(a), GetLane(b), &hi_0); + Mul128(ExtractLane(a, 1), ExtractLane(b, 1), &hi_1); + + return Dup128VecFromValues(Full128(), hi_0, hi_1); +} + +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast((a.raw[i] * b.raw[i] + 16384) >> 15); + } + return a; +} + +// Multiplies even lanes (0, 2, ..) and returns the double-wide result. +template +HWY_API Vec128, (N + 1) / 2> MulEven(Vec128 a, + Vec128 b) { + using TW = MakeWide; + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const TW a_wide = a.raw[i]; + mul.raw[i / 2] = static_cast(a_wide * b.raw[i]); + } + return mul; +} + +// Multiplies odd lanes (1, 3, ..) and returns the double-wide result. +template +HWY_API Vec128, (N + 1) / 2> MulOdd(Vec128 a, + Vec128 b) { + using TW = MakeWide; + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const TW a_wide = a.raw[i + 1]; + mul.raw[i / 2] = static_cast(a_wide * b.raw[i + 1]); + } + return mul; +} + +template +HWY_API Vec128 ApproximateReciprocal(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The result is arbitrary. + v.raw[i] = (ScalarAbs(v.raw[i]) == 0.0f) ? 0.0f : 1.0f / v.raw[i]; + } + return v; +} + +// generic_ops takes care of integer T. 
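// Illustrative sketch, not part of the upstream Highway sources:
// MulFixedPoint15 above multiplies two Q15 fixed-point fractions (int16_t
// scaled by 2^15) and rounds the product back to Q15; the +16384 term adds
// half an output step before the truncating shift.
#include <cstdint>

static constexpr int16_t MulQ15(int16_t a, int16_t b) {
  return static_cast<int16_t>((int32_t{a} * b + 16384) >> 15);
}
static_assert(MulQ15(16384, 16384) == 8192, "0.5 * 0.5 == 0.25 in Q15");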
+template +HWY_API Vec128 AbsDiff(Vec128 a, Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return mul * x + add; +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return add - mul * x; +} + +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return mul * x - sub; +} + +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + const float half = v.raw[i] * 0.5f; + // Initial guess based on log2(f) + v.raw[i] = BitCastScalar(static_cast( + 0x5F3759DF - (BitCastScalar(v.raw[i]) >> 1))); + // One Newton-Raphson iteration + v.raw[i] = v.raw[i] * (1.5f - (half * v.raw[i] * v.raw[i])); + } + return v; +} + +namespace detail { + +static HWY_INLINE float ScalarSqrt(float v) { +#if defined(HWY_NO_LIBCXX) +#if HWY_COMPILER_GCC_ACTUAL + return __builtin_sqrt(v); +#else + uint32_t bits = BitCastScalar(v); + // Coarse approximation, letting the exponent LSB leak into the mantissa + bits = (1 << 29) + (bits >> 1) - (1 << 22); + return BitCastScalar(bits); +#endif // !HWY_COMPILER_GCC_ACTUAL +#else + return sqrtf(v); +#endif // !HWY_NO_LIBCXX +} +static HWY_INLINE double ScalarSqrt(double v) { +#if defined(HWY_NO_LIBCXX) +#if HWY_COMPILER_GCC_ACTUAL + return __builtin_sqrt(v); +#else + uint64_t bits = BitCastScalar(v); + // Coarse approximation, letting the exponent LSB leak into the mantissa + bits = (1ULL << 61) + (bits >> 1) - (1ULL << 51); + return BitCastScalar(bits); +#endif // !HWY_COMPILER_GCC_ACTUAL +#else + return sqrt(v); +#endif // HWY_NO_LIBCXX +} + +} // namespace detail + +template +HWY_API Vec128 Sqrt(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = detail::ScalarSqrt(v.raw[i]); + } + return v; +} + +// ------------------------------ Floating-point rounding + +template +HWY_API Vec128 Round(Vec128 v) { + using TI = MakeSigned; + const T k0 = ConvertScalarTo(0); + const Vec128 a = Abs(v); + for (size_t i = 0; i < N; ++i) { + if (!(a.raw[i] < MantissaEnd())) { // Huge or NaN + continue; + } + const T bias = ConvertScalarTo(v.raw[i] < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw[i] + bias); + if (rounded == 0) { + v.raw[i] = v.raw[i] < 0 ? ConvertScalarTo(-0) : k0; + continue; + } + const T rounded_f = ConvertScalarTo(rounded); + // Round to even + if ((rounded & 1) && + ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { + v.raw[i] = ConvertScalarTo(rounded - (v.raw[i] < k0 ? -1 : 1)); + continue; + } + v.raw[i] = rounded_f; + } + return v; +} + +// Round-to-nearest even. +template +HWY_API Vec128, N> NearestInt(Vec128 v) { + using TI = MakeSigned; + const T k0 = ConvertScalarTo(0); + + const Vec128 abs = Abs(v); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + const bool signbit = ScalarSignBit(v.raw[i]); + + if (!(abs.raw[i] < MantissaEnd())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs.raw[i] <= ConvertScalarTo(LimitsMax()))) { + ret.raw[i] = signbit ? LimitsMin() : LimitsMax(); + continue; + } + ret.raw[i] = static_cast(v.raw[i]); + continue; + } + const T bias = ConvertScalarTo(v.raw[i] < k0 ? 
-0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw[i] + bias); + if (rounded == 0) { + ret.raw[i] = 0; + continue; + } + const T rounded_f = ConvertScalarTo(rounded); + // Round to even + if ((rounded & 1) && + ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { + ret.raw[i] = rounded - (signbit ? -1 : 1); + continue; + } + ret.raw[i] = rounded; + } + return ret; +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 /*di32*/, + VFromD> v) { + using T = double; + using TI = int32_t; + const T k0 = ConvertScalarTo(0); + + constexpr size_t N = HWY_MAX_LANES_D(DI32); + + const VFromD> abs = Abs(v); + VFromD ret; + for (size_t i = 0; i < N; ++i) { + const bool signbit = ScalarSignBit(v.raw[i]); + + // Check if too large to cast or NaN + if (!(abs.raw[i] <= ConvertScalarTo(LimitsMax()))) { + ret.raw[i] = signbit ? LimitsMin() : LimitsMax(); + continue; + } + + const T bias = ConvertScalarTo(v.raw[i] < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw[i] + bias); + if (rounded == 0) { + ret.raw[i] = 0; + continue; + } + const T rounded_f = ConvertScalarTo(rounded); + // Round to even + if ((rounded & 1) && + ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { + ret.raw[i] = rounded - (signbit ? -1 : 1); + continue; + } + ret.raw[i] = rounded; + } + return ret; +} + +template +HWY_API Vec128 Trunc(Vec128 v) { + using TI = MakeSigned; + const Vec128 abs = Abs(v); + for (size_t i = 0; i < N; ++i) { + if (!(abs.raw[i] <= MantissaEnd())) { // Huge or NaN + continue; + } + const TI truncated = static_cast(v.raw[i]); + if (truncated == 0) { + v.raw[i] = v.raw[i] < 0 ? -T{0} : T{0}; + continue; + } + v.raw[i] = static_cast(truncated); + } + return v; +} + +// Toward +infinity, aka ceiling +template +Vec128 Ceil(Vec128 v) { + constexpr int kMantissaBits = MantissaBits(); + using Bits = MakeUnsigned; + const Bits kExponentMask = MaxExponentField(); + const Bits kMantissaMask = MantissaMask(); + const Bits kBias = kExponentMask / 2; + + for (size_t i = 0; i < N; ++i) { + const bool positive = v.raw[i] > Float(0.0); + + Bits bits = BitCastScalar(v.raw[i]); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) continue; + // |v| <= 1 => 0 or 1. + if (exponent < 0) { + v.raw[i] = positive ? Float{1} : Float{-0.0}; + continue; + } + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) continue; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + v.raw[i] = BitCastScalar(bits); + } + return v; +} + +// Toward -infinity, aka floor +template +Vec128 Floor(Vec128 v) { + constexpr int kMantissaBits = MantissaBits(); + using Bits = MakeUnsigned; + const Bits kExponentMask = MaxExponentField(); + const Bits kMantissaMask = MantissaMask(); + const Bits kBias = kExponentMask / 2; + + for (size_t i = 0; i < N; ++i) { + const bool negative = v.raw[i] < Float(0.0); + + Bits bits = BitCastScalar(v.raw[i]); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) continue; + // |v| <= 1 => -1 or 0. + if (exponent < 0) { + v.raw[i] = negative ? 
Float(-1.0) : Float(0.0); + continue; + } + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) continue; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + v.raw[i] = BitCastScalar(bits); + } + return v; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask128 IsNaN(Vec128 v) { + Mask128 ret; + for (size_t i = 0; i < N; ++i) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. + ret.bits[i] = Mask128::FromBool(ScalarIsNaN(v.raw[i])); + } + return ret; +} + +// ================================================== COMPARE + +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] == b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] != b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] < b.raw[i]); + } + return m; +} +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] > b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] <= b.raw[i]); + } + return m; +} +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] >= b.raw[i]); + } + return m; +} + +// ------------------------------ Lt128 + +// Only makes sense for full vectors of u64. +template +HWY_API MFromD Lt128(D /* tag */, Vec128 a, Vec128 b) { + const bool lt = + (a.raw[1] < b.raw[1]) || (a.raw[1] == b.raw[1] && a.raw[0] < b.raw[0]); + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); + return ret; +} + +template +HWY_API MFromD Lt128Upper(D /* tag */, Vec128 a, + Vec128 b) { + const bool lt = a.raw[1] < b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); + return ret; +} + +// ------------------------------ Eq128 + +// Only makes sense for full vectors of u64. 
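// Illustrative sketch, not part of the upstream Highway sources: the 128-bit
// compares in this section treat lane 1 as the high half and lane 0 as the
// low half, i.e.  a < b   iff  aH < bH  or (aH == bH and aL < bL)
//                 a == b  iff  aH == bH and aL == bL
#include <cstdint>

struct U128 {
  uint64_t lo, hi;
};
static constexpr bool Lt128Scalar(U128 a, U128 b) {
  return (a.hi < b.hi) || (a.hi == b.hi && a.lo < b.lo);
}
static constexpr bool Eq128Scalar(U128 a, U128 b) {
  return a.hi == b.hi && a.lo == b.lo;
}
static_assert(Lt128Scalar(U128{~0ull, 0}, U128{0, 1}), "high half dominates");
static_assert(!Eq128Scalar(U128{1, 0}, U128{0, 0}), "low halves differ");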
+template +HWY_API MFromD Eq128(D /* tag */, Vec128 a, Vec128 b) { + const bool eq = a.raw[1] == b.raw[1] && a.raw[0] == b.raw[0]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); + return ret; +} + +template +HWY_API Mask128 Ne128(D /* tag */, Vec128 a, + Vec128 b) { + const bool ne = a.raw[1] != b.raw[1] || a.raw[0] != b.raw[0]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); + return ret; +} + +template +HWY_API MFromD Eq128Upper(D /* tag */, Vec128 a, + Vec128 b) { + const bool eq = a.raw[1] == b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); + return ret; +} + +template +HWY_API MFromD Ne128Upper(D /* tag */, Vec128 a, + Vec128 b) { + const bool ne = a.raw[1] != b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); + return ret; +} + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_API VFromD Min128(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_API VFromD Max128(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template +HWY_API VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_API VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT aligned) { + VFromD v; + CopyBytes(aligned, v.raw); // copy from array + return v; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElse(m, LoadU(d, p), v); +} + +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. 
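// Illustrative sketch, not part of the upstream Highway sources: the lane-wise
// semantics of MaskedLoadOr above. Lanes whose mask is set take the value
// loaded from memory; the others keep the fallback `no`. (As written above,
// the emulation loads the full vector and then blends, so the whole vector's
// memory must be readable even where the mask is clear.)
#include <cstddef>
#include <cstdint>

static void MaskedLoadOrScalar(int32_t* out, const bool* mask,
                               const int32_t* no, const int32_t* p, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = mask[i] ? p[i] : no[i];
  }
}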
+template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t max_lanes_to_load) { + VFromD v = Zero(d); + const size_t N = Lanes(d); + const size_t num_of_lanes_to_load = HWY_MIN(max_lanes_to_load, N); + CopyBytes(p, v.raw, num_of_lanes_to_load * sizeof(TFromD)); + return v; +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t max_lanes_to_load) { + VFromD v = no; + const size_t N = Lanes(d); + const size_t num_of_lanes_to_load = HWY_MIN(max_lanes_to_load, N); + CopyBytes(p, v.raw, num_of_lanes_to_load * sizeof(TFromD)); + return v; +} + +// ------------------------------ Store + +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + CopyBytes(v.raw, aligned); // copy to array +} + +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (m.bits[i]) p[i] = v.raw[i]; + } +} + +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +template +HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const size_t N = Lanes(d); + const size_t num_of_lanes_to_store = HWY_MIN(max_lanes_to_store, N); + CopyBytes(v.raw, p, num_of_lanes_to_store * sizeof(TFromD)); +} + +// ================================================== COMBINE + +template +HWY_API Vec128 LowerHalf(Vec128 v) { + Vec128 ret; + CopyBytes(v.raw, ret.raw); + return ret; +} + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return LowerHalf(v); +} + +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + VFromD ret; + CopyBytes(&v.raw[MaxLanes(d)], ret.raw); + return ret; +} + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> v) { + const Half dh; + VFromD ret; // zero-initialized + CopyBytes(v.raw, ret.raw); + return ret; +} + +template >> +HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { + const Half dh; + VFromD ret; + CopyBytes(lo_half.raw, &ret.raw[0]); + CopyBytes(hi_half.raw, &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(lo.raw, &ret.raw[0]); + CopyBytes(hi.raw, &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(&lo.raw[MaxLanes(dh)], &ret.raw[0]); + CopyBytes(&hi.raw[MaxLanes(dh)], &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(&lo.raw[MaxLanes(dh)], &ret.raw[0]); + CopyBytes(hi.raw, &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(lo.raw, &ret.raw[0]); + CopyBytes(&hi.raw[MaxLanes(dh)], &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[i] = lo.raw[2 * i]; + } + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[MaxLanes(dh) + i] = hi.raw[2 * i]; + } + return ret; +} + +// 2023-11-23: 
workaround for incorrect codegen (reduction_test fails for +// SumsOf2 because PromoteOddTo, which uses ConcatOdd, returns zero). +#if HWY_ARCH_RISCV && HWY_TARGET == HWY_EMU128 && HWY_COMPILER_CLANG +#define HWY_EMU128_CONCAT_INLINE HWY_NOINLINE +#else +#define HWY_EMU128_CONCAT_INLINE HWY_API +#endif + +template +HWY_EMU128_CONCAT_INLINE VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[i] = lo.raw[2 * i + 1]; + } + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[MaxLanes(dh) + i] = hi.raw[2 * i + 1]; + } + return ret; +} + +// ------------------------------ CombineShiftRightBytes +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + VFromD ret; + const uint8_t* HWY_RESTRICT lo8 = + reinterpret_cast(lo.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + CopyBytes(lo8 + kBytes, ret8); + CopyBytes(hi.raw, ret8 + d.MaxBytes() - kBytes); + return ret; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + VFromD ret; + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + ZeroBytes(ret8); + CopyBytes(v.raw, ret8 + kBytes); + return ret; +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API VFromD ShiftLeftLanes(D d, VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + VFromD ret; + const uint8_t* HWY_RESTRICT v8 = + reinterpret_cast(v.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + CopyBytes(v8 + kBytes, ret8); + ZeroBytes(ret8 + d.MaxBytes() - kBytes); + return ret; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API VFromD ShiftRightLanes(D d, VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ Tuples, PromoteEvenTo/PromoteOddTo +#include "third_party/highway/hwy/ops/inside-inl.h" + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +// We implement those here because scalar code is likely faster than emulation +// via shuffles. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +// Same for Load/StoreInterleaved of special floats. 
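// Illustrative sketch, not part of the upstream Highway sources: a byte-wise
// model of CombineShiftRightBytes<k> above. Conceptually the 32-byte
// concatenation hi:lo is formed and the 16 bytes starting k bytes into lo are
// returned (the same window PALIGNR extracts on x86). Names are illustrative.
#include <cstddef>
#include <cstdint>
#include <cstring>

static void CombineShiftRightBytesScalar(const uint8_t hi[16],
                                         const uint8_t lo[16], size_t k,
                                         uint8_t out[16]) {
  uint8_t both[32];
  std::memcpy(both, lo, 16);       // low half first
  std::memcpy(both + 16, hi, 16);  // then the high half
  std::memcpy(out, both + k, 16);  // 16-byte window starting at offset k
}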
+#ifdef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#endif + +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + alignas(16) T buf0[MaxLanes(d)]; + alignas(16) T buf1[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); +} + +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + alignas(16) T buf0[MaxLanes(d)]; + alignas(16) T buf1[MaxLanes(d)]; + alignas(16) T buf2[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); +} + +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + alignas(16) T buf0[MaxLanes(d)]; + alignas(16) T buf1[MaxLanes(d)]; + alignas(16) T buf2[MaxLanes(d)]; + alignas(16) T buf3[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + buf3[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); + v3 = Load(d, buf3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + TFromD* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + } +} + +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + } +} + +template +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, + TFromD* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + *unaligned++ = v3.raw[i]; + } +} + +// ------------------------------ Stream +template +HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ Scatter in generic_ops-inl.h +// ------------------------------ Gather in generic_ops-inl.h + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). + +namespace detail { + +template +HWY_INLINE ToT CastValueForF2IConv(FromT val) { + // Prevent ubsan errors when converting float to narrower integer + + using FromTU = MakeUnsigned; + using ToTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. 
+ const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. + + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? static_cast(val) + : static_cast(static_cast(LimitsMax()) + + static_cast(ScalarSignBit(val))); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(ToTypeTag /* to_type_tag */, FromT val) { + return ConvertScalarTo(val); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(hwy::SignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(hwy::UnsignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); +} +// If val is within the range of ToT, CastValueForInRangeF2IConv(val) +// returns static_cast(val) +// +// Otherwise, CastValueForInRangeF2IConv(val) returns an +// implementation-defined result if val is not within the range of ToT. +template +HWY_INLINE ToT CastValueForInRangeF2IConv(FromT val) { + // Prevent ubsan errors when converting float to narrower integer + + using FromTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. + + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? static_cast(val) + : static_cast(LimitsMin()); +} + +} // namespace detail + +template +HWY_API VFromD PromoteTo(DTo d, Vec128 from) { + static_assert(sizeof(TFromD) > sizeof(TFrom), "Not promoting"); + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); ++i) { + // For bits Y > X, floatX->floatY and intX->intY are always representable. + ret.raw[i] = detail::CastValueForPromoteTo>( + hwy::TypeTag>(), from.raw[i]); + } + return ret; +} + +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD PromoteInRangeTo(D64 d64, VFromD> v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d64); ++i) { + ret.raw[i] = detail::CastValueForInRangeF2IConv>(v.raw[i]); + } + return ret; +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(TFrom) is here, +// so we overload for TFrom=double and ToT={float,int32_t}. 
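// Illustrative sketch, not part of the upstream Highway sources:
// CastValueForF2IConv above avoids undefined behavior in float->int
// conversion by checking the exponent field and saturating out-of-range
// inputs. A simpler scalar equivalent for float -> int32_t using explicit
// bounds (NaN is mapped to the minimum here, one of several reasonable
// choices):
#include <cstdint>
#include <limits>

static int32_t SaturatingF2I32(float val) {
  // 2147483648.0f is the smallest positive float not representable in
  // int32_t; -2147483648.0f is exactly INT32_MIN and still in range.
  if (!(val >= -2147483648.0f)) return std::numeric_limits<int32_t>::min();
  if (!(val < 2147483648.0f)) return std::numeric_limits<int32_t>::max();
  return static_cast<int32_t>(val);  // now in range; the cast truncates
}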
+template +HWY_API VFromD DemoteTo(D d, VFromD> from) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); ++i) { + // Prevent ubsan errors when converting float to narrower integer/float + if (ScalarIsInf(from.raw[i]) || + ScalarAbs(from.raw[i]) > static_cast(HighestValue())) { + ret.raw[i] = ScalarSignBit(from.raw[i]) ? LowestValue() + : HighestValue(); + continue; + } + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} +template +HWY_API VFromD DemoteTo(D d, VFromD> from) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); ++i) { + // Prevent ubsan errors when converting double to narrower integer/int32_t + ret.raw[i] = detail::CastValueForF2IConv>(from.raw[i]); + } + return ret; +} + +template )> +HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { + using TTo = TFromD; + static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + + VFromD ret; + for (size_t i = 0; i < N; ++i) { + // Int to int: choose closest value in ToT to `from` (avoids UB) + from.raw[i] = + HWY_MIN(HWY_MAX(LimitsMin(), from.raw[i]), LimitsMax()); + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +// Disable the default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h on EMU128 as the EMU128 target has +// target-specific implementations of the unsigned to signed DemoteTo and +// ReorderDemote2To ops + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the V template +// argument +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ + hwy::EnableIf()>* = nullptr + +template +HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { + using TTo = TFromD; + static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + + const auto max = static_cast>(LimitsMax()); + + VFromD ret; + for (size_t i = 0; i < N; ++i) { + // Int to int: choose closest value in ToT to `from` (avoids UB) + ret.raw[i] = static_cast(HWY_MIN(from.raw[i], max)); + } + return ret; +} + +template +HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { + using TTo = TFromD; + static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + + VFromD ret; + for (size_t i = 0; i < N; ++i) { + // int64_t/uint64_t to float: okay to cast to float as an int64_t/uint64_t + // value is always within the range of a float + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +template +HWY_API VFromD ReorderDemote2To(DBF16 dbf16, VF32 a, VF32 b) { + const Repartition du32; + const VFromD b_in_lower = ShiftRight<16>(BitCast(du32, b)); + // Avoid OddEven - we want the upper half of `a` even on big-endian systems. 
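  // (BF16 is the upper 16 bits of a binary32, so `a` keeps its high halves in
  //  place while `b` contributes its high halves shifted into the low 16 bits
  //  of each 32-bit lane.)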
+ const VFromD a_mask = Set(du32, 0xFFFF0000); + return BitCast(dbf16, IfVecThenElse(a_mask, BitCast(du32, a), b_in_lower)); +} + +template ), class V, + HWY_IF_SIGNED_V(V), HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const RepartitionToWide dw; + const size_t NW = Lanes(dw); + using TN = TFromD; + const TN min = LimitsMin(); + const TN max = LimitsMax(); + VFromD ret; + for (size_t i = 0; i < NW; ++i) { + ret.raw[i] = static_cast(HWY_MIN(HWY_MAX(min, a.raw[i]), max)); + } + for (size_t i = 0; i < NW; ++i) { + ret.raw[NW + i] = static_cast(HWY_MIN(HWY_MAX(min, b.raw[i]), max)); + } + return ret; +} + +template ) * 2), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const RepartitionToWide dw; + const size_t NW = Lanes(dw); + using TN = TFromD; + using TN_U = MakeUnsigned; + const TN_U max = static_cast(LimitsMax()); + VFromD ret; + for (size_t i = 0; i < NW; ++i) { + ret.raw[i] = static_cast(HWY_MIN(a.raw[i], max)); + } + for (size_t i = 0; i < NW; ++i) { + ret.raw[NW + i] = static_cast(HWY_MIN(b.raw[i], max)); + } + return ret; +} + +template ), class V, + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { + return ReorderDemote2To(dn, a, b); +} + +template ), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { + const size_t NW = Lanes(dn) / 2; + using TN = TFromD; + VFromD ret; + for (size_t i = 0; i < NW; ++i) { + ret.raw[i] = ConvertScalarTo(a.raw[i]); + } + for (size_t i = 0; i < NW; ++i) { + ret.raw[NW + i] = ConvertScalarTo(b.raw[i]); + } + return ret; +} + +namespace detail { + +HWY_INLINE void StoreU16ToF16(const uint16_t val, + hwy::float16_t* HWY_RESTRICT to) { + CopySameSize(&val, to); +} + +HWY_INLINE uint16_t U16FromF16(const hwy::float16_t* HWY_RESTRICT from) { + uint16_t bits16; + CopySameSize(from, &bits16); + return bits16; +} + +} // namespace detail + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = F32FromBF16(v.raw[i]); + } + return ret; +} + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = BF16FromF32(v.raw[i]); + } + return ret; +} + +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD DemoteInRangeTo(D32 d32, VFromD> v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d32); ++i) { + ret.raw[i] = detail::CastValueForInRangeF2IConv>(v.raw[i]); + } + return ret; +} + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_API VFromD ConvertTo(hwy::FloatTag /*tag*/, DTo /*tag*/, + Vec128 from) { + using ToT = TFromD; + static_assert(sizeof(ToT) == sizeof(TFrom), "Should have same size"); + VFromD ret; + constexpr size_t N = HWY_MAX_LANES_D(DTo); + + for (size_t i = 0; i < N; ++i) { + // float## -> int##: return closest representable value + ret.raw[i] = CastValueForF2IConv(from.raw[i]); + } + return ret; +} + +template +HWY_API VFromD ConvertTo(hwy::NonFloatTag /*tag*/, DTo /* 
tag */, + Vec128 from) { + using ToT = TFromD; + static_assert(sizeof(ToT) == sizeof(TFrom), "Should have same size"); + VFromD ret; + constexpr size_t N = HWY_MAX_LANES_D(DTo); + for (size_t i = 0; i < N; ++i) { + // int## -> float##: no check needed + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +} // namespace detail + +template +HWY_API VFromD ConvertTo(DTo d, Vec128 from) { + return detail::ConvertTo(hwy::IsFloatTag(), d, from); +} + +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + +template +HWY_API VFromD ConvertInRangeTo(DI di, VFromD> v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(di); i++) { + ret.raw[i] = detail::CastValueForInRangeF2IConv>(v.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 U8FromU32(Vec128 v) { + return DemoteTo(Simd(), v); +} + +// ------------------------------ Truncations + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); + } + return ret; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFFFFFFu); + } + return ret; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); + } + return ret; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +#ifdef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#undef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#else +#define HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#endif + +template ) * 2), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedTruncate2To(DN dn, V a, V b) { + const RepartitionToWide dw; + const size_t NW = Lanes(dw); + using TW = TFromD; + using TN = TFromD; + VFromD ret; + constexpr TW max_val{LimitsMax()}; + + for (size_t i = 0; i < NW; ++i) { + ret.raw[i] = static_cast(a.raw[i] & max_val); + } + for (size_t i = 0; i < NW; ++i) { + ret.raw[NW + i] = static_cast(b.raw[i] & max_val); + } + return ret; +} + +// ================================================== SWIZZLE + +template +HWY_API T GetLane(Vec128 v) { + return v.raw[0]; +} + +template +HWY_API Vec128 InsertLane(Vec128 v, size_t i, T t) { + v.raw[i] = t; + return v; +} + +template +HWY_API T ExtractLane(Vec128 v, size_t i) { + return v.raw[i]; +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + for (size_t i = 0; i < N; i += 2) { + v.raw[i + 1] = v.raw[i]; + } + return v; +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + for (size_t i = 0; i < N; i += 2) { + v.raw[i] = v.raw[i + 1]; + } + return v; +} + +template +HWY_API Vec128 OddEven(Vec128 odd, Vec128 even) { + for (size_t i = 0; i < N; i += 2) { + odd.raw[i] = even.raw[i]; + } + return odd; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + constexpr size_t N = HWY_MAX_LANES_D(D); + for (size_t i = 1; i < N; i += 2) { + a.raw[i] = b.raw[i - 1]; + } 
+ return a; +} + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + constexpr size_t N = HWY_MAX_LANES_D(D); + for (size_t i = 1; i < N; i += 2) { + b.raw[i - 1] = a.raw[i]; + } + return b; +} + +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template > +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template > +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices128 { + MakeSigned raw[N]; +}; + +template +HWY_API Indices128, N> IndicesFromVec(D d, Vec128 vec) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index/lane size must match"); + Indices128, N> ret; + CopyBytes(vec.raw, ret.raw); + return ret; +} + +template +HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( + D d, const TI* idx) { + return IndicesFromVec(d, LoadU(Rebind(), idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = v.raw[idx.raw[i]]; + } + return ret; +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + using TI = MakeSigned; + Vec128 ret; + constexpr TI kVecLaneIdxMask = static_cast(N - 1); + for (size_t i = 0; i < N; ++i) { + const auto src_idx = idx.raw[i]; + const auto masked_src_lane_idx = src_idx & kVecLaneIdxMask; + ret.raw[i] = (src_idx < static_cast(N)) ? a.raw[masked_src_lane_idx] + : b.raw[masked_src_lane_idx]; + } + return ret; +} + +// ------------------------------ ReverseBlocks +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; // Single block: no change +} + +// ------------------------------ Reverse + +template +HWY_API VFromD Reverse(D d, VFromD v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); ++i) { + ret.raw[i] = v.raw[MaxLanes(d) - 1 - i]; + } + return ret; +} + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. 
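+// Reverse2/Reverse4/Reverse8 below reverse lanes within groups of 2/4/8 lanes;
+// e.g. input lanes {0,1,2,3} yield {1,0,3,2} for Reverse2 and {3,2,1,0} for
+// Reverse4.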
+#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); i += 2) { + ret.raw[i + 0] = v.raw[i + 1]; + ret.raw[i + 1] = v.raw[i + 0]; + } + return ret; +} + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); i += 4) { + ret.raw[i + 0] = v.raw[i + 3]; + ret.raw[i + 1] = v.raw[i + 2]; + ret.raw[i + 2] = v.raw[i + 1]; + ret.raw[i + 3] = v.raw[i + 0]; + } + return ret; +} + +template +HWY_API VFromD Reverse8(D d, VFromD v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); i += 8) { + ret.raw[i + 0] = v.raw[i + 7]; + ret.raw[i + 1] = v.raw[i + 6]; + ret.raw[i + 2] = v.raw[i + 5]; + ret.raw[i + 3] = v.raw[i + 4]; + ret.raw[i + 4] = v.raw[i + 3]; + ret.raw[i + 5] = v.raw[i + 2]; + ret.raw[i + 6] = v.raw[i + 1]; + ret.raw[i + 7] = v.raw[i + 0]; + } + return ret; +} + +// ------------------------------ SlideUpLanes + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { + VFromD ret = Zero(d); + constexpr size_t N = HWY_MAX_LANES_D(D); + const size_t clamped_amt = HWY_MIN(amt, N); + CopyBytes(v.raw, ret.raw + clamped_amt, + (N - clamped_amt) * sizeof(TFromD)); + return ret; +} + +// ------------------------------ SlideDownLanes + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + VFromD ret = Zero(d); + constexpr size_t N = HWY_MAX_LANES_D(D); + const size_t clamped_amt = HWY_MIN(amt, N); + CopyBytes(v.raw + clamped_amt, ret.raw, + (N - clamped_amt) * sizeof(TFromD)); + return ret; +} + +// ================================================== BLOCKWISE + +// ------------------------------ Shuffle* + +// Swap 32-bit halves in 64-bit halves. 
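+// For 32-bit lanes {0,1,2,3} this yields {1,0,3,2}. The digits in the Shuffle*
+// names list the source lane indices from the highest result lane downwards,
+// as can be read off the per-lane assignments below.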
+template +HWY_API Vec128 Shuffle2301(Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Reverse2(DFromV(), v); +} + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit"); + Vec128 ret; + ret.raw[3] = v.raw[1]; + ret.raw[2] = v.raw[0]; + ret.raw[1] = v.raw[3]; + ret.raw[0] = v.raw[2]; + return ret; +} +template +HWY_API Vec128 Shuffle01(Vec128 v) { + static_assert(sizeof(T) == 8, "Only for 64-bit"); + return Reverse2(DFromV(), v); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(Vec128 v) { + Vec128 ret; + ret.raw[3] = v.raw[0]; + ret.raw[2] = v.raw[3]; + ret.raw[1] = v.raw[2]; + ret.raw[0] = v.raw[1]; + return ret; +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(Vec128 v) { + Vec128 ret; + ret.raw[3] = v.raw[2]; + ret.raw[2] = v.raw[1]; + ret.raw[1] = v.raw[0]; + ret.raw[0] = v.raw[3]; + return ret; +} + +template +HWY_API Vec128 Shuffle0123(Vec128 v) { + return Reverse4(DFromV(), v); +} + +// ------------------------------ Broadcast +template +HWY_API Vec128 Broadcast(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[kLane]; + } + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template +HWY_API Vec128 TableLookupBytes(Vec128 v, + Vec128 indices) { + const uint8_t* HWY_RESTRICT v_bytes = + reinterpret_cast(v.raw); + const uint8_t* HWY_RESTRICT idx_bytes = + reinterpret_cast(indices.raw); + Vec128 ret; + uint8_t* HWY_RESTRICT ret_bytes = + reinterpret_cast(ret.raw); + for (size_t i = 0; i < NI * sizeof(TI); ++i) { + const size_t idx = idx_bytes[i]; + // Avoid out of bounds reads. + ret_bytes[i] = idx < sizeof(T) * N ? v_bytes[idx] : 0; + } + return ret; +} + +template +HWY_API Vec128 TableLookupBytesOr0(Vec128 v, + Vec128 indices) { + // Same as TableLookupBytes, which already returns 0 if out of bounds. + return TableLookupBytes(v, indices); +} + +// ------------------------------ InterleaveLower/InterleaveUpper + +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + Vec128 ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[2 * i + 0] = a.raw[i]; + ret.raw[2 * i + 1] = b.raw[i]; + } + return ret; +} + +// Additional overload for the optional tag. +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half dh; + VFromD ret; + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[2 * i + 0] = a.raw[MaxLanes(dh) + i]; + ret.raw[2 * i + 1] = b.raw[MaxLanes(dh) + i]; + } + return ret; +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== MASK + +template +HWY_API bool AllFalse(D d, MFromD mask) { + typename MFromD::Raw or_sum = 0; + for (size_t i = 0; i < MaxLanes(d); ++i) { + or_sum |= mask.bits[i]; + } + return or_sum == 0; +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + constexpr uint64_t kAll = LimitsMax::Raw>(); + uint64_t and_sum = kAll; + for (size_t i = 0; i < MaxLanes(d); ++i) { + and_sum &= mask.bits[i]; + } + return and_sum == kAll; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + MFromD m; + for (size_t i = 0; i < MaxLanes(d); ++i) { + const size_t bit = size_t{1} << (i & 7); + const size_t idx_byte = i >> 3; + m.bits[i] = MFromD::FromBool((bits[idx_byte] & bit) != 0); + } + return m; +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + MFromD m; + for (size_t i = 0; i < MaxLanes(d); ++i) { + m.bits[i] = MFromD::FromBool(((mask_bits >> i) & 1u) != 0); + } + return m; +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + bits[0] = 0; + if (MaxLanes(d) > 8) bits[1] = 0; // MaxLanes(d) <= 16, so max two bytes + for (size_t i = 0; i < MaxLanes(d); ++i) { + const size_t bit = size_t{1} << (i & 7); + const size_t idx_byte = i >> 3; + if (mask.bits[i]) { + bits[idx_byte] = static_cast(bits[idx_byte] | bit); + } + } + return MaxLanes(d) > 8 ? 2 : 1; +} + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + size_t count = 0; + for (size_t i = 0; i < MaxLanes(d); ++i) { + count += mask.bits[i] != 0; + } + return count; +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask.bits[i] != 0) return i; + } + HWY_DASSERT(false); + return 0; +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask.bits[i] != 0) return static_cast(i); + } + return intptr_t{-1}; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + for (intptr_t i = static_cast(MaxLanes(d) - 1); i >= 0; i--) { + if (mask.bits[i] != 0) return static_cast(i); + } + HWY_DASSERT(false); + return 0; +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + for (intptr_t i = static_cast(MaxLanes(d) - 1); i >= 0; i--) { + if (mask.bits[i] != 0) return i; + } + return intptr_t{-1}; +} + +// ------------------------------ Compress + +template +struct CompressIsPartition { + enum { value = (sizeof(T) != 1) }; +}; + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + size_t count = 0; + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + for (size_t i = 0; i < N; ++i) { + if (!mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + HWY_DASSERT(count == N); + return ret; +} + +// ------------------------------ Expand + +// Could also just allow generic_ops-inl.h to implement these, but use our +// simple implementation below to ensure the test is correct. 
+#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +template +HWY_API Vec128 Expand(Vec128 v, const Mask128 mask) { + size_t in_pos = 0; + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[i] = v.raw[in_pos++]; + } else { + ret.raw[i] = ConvertScalarTo(0); + } + } + return ret; +} + +// ------------------------------ LoadExpand + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + size_t in_pos = 0; + VFromD ret; + for (size_t i = 0; i < Lanes(d); ++i) { + if (mask.bits[i]) { + ret.raw[i] = unaligned[in_pos++]; + } else { + ret.raw[i] = TFromD(); // zero, also works for float16_t + } + } + return ret; +} + +// ------------------------------ CompressNot +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + size_t count = 0; + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + if (!mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + HWY_DASSERT(count == N); + return ret; +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(Simd(), bits)); +} + +// ------------------------------ CompressStore + +// generic_ops-inl defines the 8-bit versions. +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { + size_t count = 0; + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask.bits[i]) { + unaligned[count++] = v.raw[i]; + } + } + return count; +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, mask, d, unaligned); +} + +// ------------------------------ CompressBitsStore +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + const MFromD mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ Additional mask logical operations +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} + +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + using TU = hwy::MakeUnsigned; + + Mask128 result; + TU result_lane_mask{0}; + for (size_t i = 0; i < N; i++) { + result_lane_mask = static_cast(result_lane_mask | mask.bits[i]); + result.bits[i] = result_lane_mask; + } + return result; +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + using TU = hwy::MakeUnsigned; + using TI = hwy::MakeSigned; + + Mask128 result; + TU result_lane_mask = static_cast(~TU{0}); + for (size_t i = 0; i < N; i++) { + const auto curr_lane_mask_bits = mask.bits[i]; + result.bits[i] = static_cast(curr_lane_mask_bits & result_lane_mask); + result_lane_mask = + static_cast(result_lane_mask & + static_cast(-static_cast(mask.bits[i] == 0))); + } + return result; +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + using TU = hwy::MakeUnsigned; + using TI = hwy::MakeSigned; + + Mask128 result; + TU result_lane_mask = static_cast(~TU{0}); + for (size_t i = 0; 
i < N; i++) { + result.bits[i] = result_lane_mask; + result_lane_mask = + static_cast(result_lane_mask & + static_cast(-static_cast(mask.bits[i] == 0))); + } + return result; +} + +// ------------------------------ WidenMulPairwiseAdd + +template +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} + +template +HWY_API VFromD WidenMulPairwiseAdd(D d32, V16 a, V16 b) { + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), + Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API VFromD ReorderWidenMulAccumulate(D d32, V16 a, V16 b, + const VFromD sum0, + VFromD& sum1) { + sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1); + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API VW RearrangeToOddPlusEven(VW sum0, VW sum1) { + return Add(sum0, sum1); +} + +// ================================================== REDUCTIONS + +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceSum(D d, VFromD v) { + T sum = T{0}; + for (size_t i = 0; i < MaxLanes(d); ++i) { + sum += v.raw[i]; + } + return sum; +} + +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMin(D d, VFromD v) { + T min = PositiveInfOrHighestValue(); + for (size_t i = 0; i < MaxLanes(d); ++i) { + min = HWY_MIN(min, v.raw[i]); + } + return min; +} +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMax(D d, VFromD v) { + T max = NegativeInfOrLowestValue(); + for (size_t i = 0; i < MaxLanes(d); ++i) { + max = HWY_MAX(max, v.raw[i]); + } + return max; +} + +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); +} + +// ================================================== OPS WITH DEPENDENCIES + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128(), mul); +} + +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + const Half> d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128(), mul); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/aom/third_party/highway/hwy/ops/generic_ops-inl.h b/third_party/aom/third_party/highway/hwy/ops/generic_ops-inl.h new file mode 100644 index 000000000000..d8bc111e3cc3 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/generic_ops-inl.h @@ -0,0 +1,8165 @@ +// Copyright 2021 Google LLC +// Copyright 2023,2024 Arm Limited and/or +// its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Target-independent types/functions defined after target-specific ops. + +// The "include guards" in this file that check HWY_TARGET_TOGGLE serve to skip +// the generic implementation here if native ops are already defined. + +#include "third_party/highway/hwy/base.h" + +// Define detail::Shuffle1230 etc, but only when viewing the current header; +// normally this is included via highway.h, which includes ops/*.h. +#if HWY_IDE && !defined(HWY_HIGHWAY_INCLUDED) +#include "third_party/highway/hwy/detect_targets.h" +#include "third_party/highway/hwy/ops/emu128-inl.h" +#endif // HWY_IDE + +// Relies on the external include guard in highway.h. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// The lane type of a vector type, e.g. float for Vec>. +template +using LaneType = decltype(GetLane(V())); + +// Vector type, e.g. Vec128 for CappedTag. Useful as the return +// type of functions that do not take a vector argument, or as an argument type +// if the function only has a template argument for D, or for explicit type +// names instead of auto. This may be a built-in type. +template +using Vec = decltype(Zero(D())); + +// Mask type. Useful as the return type of functions that do not take a mask +// argument, or as an argument type if the function only has a template argument +// for D, or for explicit type names instead of auto. +template +using Mask = decltype(MaskFromVec(Zero(D()))); + +// Returns the closest value to v within [lo, hi]. +template +HWY_API V Clamp(const V v, const V lo, const V hi) { + return Min(Max(lo, v), hi); +} + +// CombineShiftRightBytes (and -Lanes) are not available for the scalar target, +// and RVV has its own implementation of -Lanes. +#if (HWY_TARGET != HWY_SCALAR && HWY_TARGET != HWY_RVV) || HWY_IDE + +template +HWY_API VFromD CombineShiftRightLanes(D d, VFromD hi, VFromD lo) { + constexpr size_t kBytes = kLanes * sizeof(TFromD); + static_assert(kBytes < 16, "Shift count is per-block"); + return CombineShiftRightBytes(d, hi, lo); +} + +#endif + +// Returns lanes with the most significant bit set and all other bits zero. +template +HWY_API Vec SignBit(D d) { + const RebindToUnsigned du; + return BitCast(d, Set(du, SignMask>())); +} + +// Returns quiet NaN. +template +HWY_API Vec NaN(D d) { + const RebindToSigned di; + // LimitsMax sets all exponent and mantissa bits to 1. The exponent plus + // mantissa MSB (to indicate quiet) would be sufficient. + return BitCast(d, Set(di, LimitsMax>())); +} + +// Returns positive infinity. 
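+// MaxExponentTimes2() is the all-ones exponent field shifted up by one bit, so
+// halving it yields the canonical +inf bit pattern (sign 0, exponent all-ones,
+// mantissa 0).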
+template +HWY_API Vec Inf(D d) { + const RebindToUnsigned du; + using T = TFromD; + using TU = TFromD; + const TU max_x2 = static_cast(MaxExponentTimes2()); + return BitCast(d, Set(du, max_x2 >> 1)); +} + +// ------------------------------ MaskedSetOr/MaskedSet + +template , typename D = DFromV, + typename M = MFromD> +HWY_API V MaskedSetOr(V no, M m, T a) { + D d; + return IfThenElse(m, Set(d, a), no); +} + +template , typename M = MFromD, + typename T = TFromD> +HWY_API V MaskedSet(D d, M m, T a) { + return IfThenElseZero(m, Set(d, a)); +} + +// ------------------------------ ZeroExtendResizeBitCast + +// The implementation of detail::ZeroExtendResizeBitCast for the HWY_EMU128 +// target is in emu128-inl.h, and the implementation of +// detail::ZeroExtendResizeBitCast for the HWY_SCALAR target is in scalar-inl.h +#if HWY_TARGET != HWY_EMU128 && HWY_TARGET != HWY_SCALAR +namespace detail { + +#if HWY_HAVE_SCALABLE +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag /* from_size_tag */, + hwy::SizeTag /* to_size_tag */, DTo d_to, DFrom d_from, + VFromD v) { + const Repartition d_to_u8; + const auto resized = ResizeBitCast(d_to_u8, v); + // Zero the upper bytes which were not present/valid in d_from. + const size_t num_bytes = Lanes(Repartition()); + return BitCast(d_to, IfThenElseZero(FirstN(d_to_u8, num_bytes), resized)); +} +#else // target that uses fixed-size vectors +// Truncating or same-size resizing cast: same as ResizeBitCast +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag /* from_size_tag */, + hwy::SizeTag /* to_size_tag */, DTo d_to, DFrom /*d_from*/, + VFromD v) { + return ResizeBitCast(d_to, v); +} + +// Resizing cast to vector that has twice the number of lanes of the source +// vector +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag /* from_size_tag */, + hwy::SizeTag /* to_size_tag */, DTo d_to, DFrom d_from, + VFromD v) { + const Twice dt_from; + return BitCast(d_to, ZeroExtendVector(dt_from, v)); +} + +// Resizing cast to vector that has more than twice the number of lanes of the +// source vector +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag /* from_size_tag */, + hwy::SizeTag /* to_size_tag */, DTo d_to, DFrom /*d_from*/, + VFromD v) { + using TFrom = TFromD; + constexpr size_t kNumOfFromLanes = kFromVectSize / sizeof(TFrom); + const Repartition d_resize_to; + return BitCast(d_to, IfThenElseZero(FirstN(d_resize_to, kNumOfFromLanes), + ResizeBitCast(d_resize_to, v))); +} +#endif // HWY_HAVE_SCALABLE + +} // namespace detail +#endif // HWY_TARGET != HWY_EMU128 && HWY_TARGET != HWY_SCALAR + +template +HWY_API VFromD ZeroExtendResizeBitCast(DTo d_to, DFrom d_from, + VFromD v) { + return detail::ZeroExtendResizeBitCast(hwy::SizeTag(), + hwy::SizeTag(), d_to, + d_from, v); +} + +// ------------------------------ SafeFillN + +template > +HWY_API void SafeFillN(const size_t num, const T value, D d, + T* HWY_RESTRICT to) { +#if HWY_MEM_OPS_MIGHT_FAULT + (void)d; + for (size_t i = 0; i < num; ++i) { + to[i] = value; + } +#else + BlendedStore(Set(d, value), FirstN(d, num), d, to); +#endif +} + +// ------------------------------ SafeCopyN + +template > +HWY_API void SafeCopyN(const size_t num, D d, const T* HWY_RESTRICT from, + T* HWY_RESTRICT to) { +#if HWY_MEM_OPS_MIGHT_FAULT + (void)d; + for (size_t i = 0; i < num; ++i) { + to[i] = from[i]; + } +#else + const Mask mask = FirstN(d, num); + BlendedStore(MaskedLoad(mask, d, from), mask, d, to); +#endif +} + +// ------------------------------ 
IsNegative +#if (defined(HWY_NATIVE_IS_NEGATIVE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +template +HWY_API Mask> IsNegative(V v) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MaskFromVec(BroadcastSignBit(BitCast(di, v)))); +} + +#endif // HWY_NATIVE_IS_NEGATIVE + +// ------------------------------ MaskFalse +#if (defined(HWY_NATIVE_MASK_FALSE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +template +HWY_API Mask MaskFalse(D d) { + return MaskFromVec(Zero(d)); +} + +#endif // HWY_NATIVE_MASK_FALSE + +// ------------------------------ IfNegativeThenElseZero +#if (defined(HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#undef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#else +#define HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#endif + +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + return IfThenElseZero(IsNegative(v), yes); +} + +#endif // HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO + +// ------------------------------ IfNegativeThenZeroElse +#if (defined(HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#undef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#else +#define HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#endif + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + return IfThenZeroElse(IsNegative(v), no); +} + +#endif // HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE + +// ------------------------------ ZeroIfNegative (IfNegativeThenZeroElse) + +// ZeroIfNegative is generic for all vector lengths +template +HWY_API V ZeroIfNegative(V v) { + return IfNegativeThenZeroElse(v, v); +} + +// ------------------------------ BitwiseIfThenElse +#if (defined(HWY_NATIVE_BITWISE_IF_THEN_ELSE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +#endif // HWY_NATIVE_BITWISE_IF_THEN_ELSE + +// ------------------------------ PromoteMaskTo + +#if (defined(HWY_NATIVE_PROMOTE_MASK_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +template +HWY_API Mask PromoteMaskTo(DTo d_to, DFrom d_from, Mask m) { + static_assert( + sizeof(TFromD) > sizeof(TFromD), + "sizeof(TFromD) must be greater than sizeof(TFromD)"); + static_assert( + IsSame, Mask, DTo>>>(), + "Mask must be the same type as Mask, DTo>>"); + + const RebindToSigned di_to; + const RebindToSigned di_from; + + return MaskFromVec(BitCast( + d_to, PromoteTo(di_to, BitCast(di_from, VecFromMask(d_from, m))))); +} + +#endif // HWY_NATIVE_PROMOTE_MASK_TO + +// ------------------------------ DemoteMaskTo + +#if (defined(HWY_NATIVE_DEMOTE_MASK_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif + +template +HWY_API Mask DemoteMaskTo(DTo d_to, DFrom d_from, Mask m) { + static_assert(sizeof(TFromD) < sizeof(TFromD), + "sizeof(TFromD) must be less than sizeof(TFromD)"); + static_assert( + IsSame, Mask, DTo>>>(), + "Mask must be the same type as Mask, DTo>>"); + + const RebindToSigned di_to; + const RebindToSigned di_from; + + return 
MaskFromVec( + BitCast(d_to, DemoteTo(di_to, BitCast(di_from, VecFromMask(d_from, m))))); +} + +#endif // HWY_NATIVE_DEMOTE_MASK_TO + +// ------------------------------ InsertIntoUpper +#if (defined(HWY_NATIVE_LOAD_HIGHER) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOAD_HIGHER +#undef HWY_NATIVE_LOAD_HIGHER +#else +#define HWY_NATIVE_LOAD_HIGHER +#endif +template (), HWY_IF_LANES_GT_D(D, 1), + HWY_IF_POW2_GT_D(D, -3)> +HWY_API V InsertIntoUpper(D d, T* p, V a) { + Half dh; + const VFromD b = LoadU(dh, p); + return Combine(d, b, LowerHalf(a)); +} +#endif // HWY_NATIVE_LOAD_HIGHER + +// ------------------------------ CombineMasks + +#if (defined(HWY_NATIVE_COMBINE_MASKS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_COMBINE_MASKS +#undef HWY_NATIVE_COMBINE_MASKS +#else +#define HWY_NATIVE_COMBINE_MASKS +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API Mask CombineMasks(D d, Mask> hi, Mask> lo) { + const Half dh; + return MaskFromVec(Combine(d, VecFromMask(dh, hi), VecFromMask(dh, lo))); +} +#endif + +#endif // HWY_NATIVE_COMBINE_MASKS + +// ------------------------------ LowerHalfOfMask + +#if (defined(HWY_NATIVE_LOWER_HALF_OF_MASK) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template +HWY_API Mask LowerHalfOfMask(D d, Mask> m) { + const Twice dt; + return MaskFromVec(LowerHalf(d, VecFromMask(dt, m))); +} + +#endif // HWY_NATIVE_LOWER_HALF_OF_MASK + +// ------------------------------ UpperHalfOfMask + +#if (defined(HWY_NATIVE_UPPER_HALF_OF_MASK) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_UPPER_HALF_OF_MASK +#undef HWY_NATIVE_UPPER_HALF_OF_MASK +#else +#define HWY_NATIVE_UPPER_HALF_OF_MASK +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API Mask UpperHalfOfMask(D d, Mask> m) { + const Twice dt; + return MaskFromVec(UpperHalf(d, VecFromMask(dt, m))); +} +#endif + +#endif // HWY_NATIVE_UPPER_HALF_OF_MASK + +// ------------------------------ OrderedDemote2MasksTo + +#if (defined(HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#undef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#else +#define HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API Mask OrderedDemote2MasksTo(DTo d_to, DFrom d_from, Mask a, + Mask b) { + static_assert( + sizeof(TFromD) == sizeof(TFromD) / 2, + "sizeof(TFromD) must be equal to sizeof(TFromD) / 2"); + static_assert(IsSame, Mask, DFrom>>>(), + "Mask must be the same type as " + "Mask, DFrom>>>()"); + + const RebindToSigned di_from; + const RebindToSigned di_to; + + const auto va = BitCast(di_from, VecFromMask(d_from, a)); + const auto vb = BitCast(di_from, VecFromMask(d_from, b)); + return MaskFromVec(BitCast(d_to, OrderedDemote2To(di_to, va, vb))); +} +#endif + +#endif // HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO + +// ------------------------------ RotateLeft +template +HWY_API V RotateLeft(V v) { + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + + constexpr int kRotateRightAmt = + (kBits == 0) ? 
0 : static_cast(kSizeInBits) - kBits; + return RotateRight(v); +} + +// ------------------------------ InterleaveWholeLower/InterleaveWholeUpper +#if (defined(HWY_NATIVE_INTERLEAVE_WHOLE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INTERLEAVE_WHOLE +#undef HWY_NATIVE_INTERLEAVE_WHOLE +#else +#define HWY_NATIVE_INTERLEAVE_WHOLE +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + // InterleaveWholeLower(d, a, b) is equivalent to InterleaveLower(a, b) if + // D().MaxBytes() <= 16 is true + return InterleaveLower(d, a, b); +} +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + // InterleaveWholeUpper(d, a, b) is equivalent to InterleaveUpper(a, b) if + // D().MaxBytes() <= 16 is true + return InterleaveUpper(d, a, b); +} + +// InterleaveWholeLower/InterleaveWholeUpper for 32-byte vectors on AVX2/AVX3 +// is implemented in x86_256-inl.h. + +// InterleaveWholeLower/InterleaveWholeUpper for 64-byte vectors on AVX3 is +// implemented in x86_512-inl.h. + +// InterleaveWholeLower/InterleaveWholeUpper for 32-byte vectors on WASM_EMU256 +// is implemented in wasm_256-inl.h. +#endif // HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_INTERLEAVE_WHOLE + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +// The InterleaveWholeLower without the optional D parameter is generic for all +// vector lengths. +template +HWY_API V InterleaveWholeLower(V a, V b) { + return InterleaveWholeLower(DFromV(), a, b); +} +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ InterleaveEven + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +// InterleaveEven without the optional D parameter is generic for all vector +// lengths +template +HWY_API V InterleaveEven(V a, V b) { + return InterleaveEven(DFromV(), a, b); +} +#endif + +// ------------------------------ MinMagnitude/MaxMagnitude + +#if (defined(HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#undef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#else +#define HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#endif + +template +HWY_API V MinMagnitude(V a, V b) { + const V abs_a = Abs(a); + const V abs_b = Abs(b); + const V min = Min(IfThenElse(Eq(abs_a, abs_b), a, b), b); + return IfThenElse(Lt(abs_a, abs_b), a, min); +} + +template +HWY_API V MaxMagnitude(V a, V b) { + const V abs_a = Abs(a); + const V abs_b = Abs(b); + // This lvalue appears to be necessary to avoid a clang bug on SVE. 
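+ // When |a| == |b|, both selects fall through to Max(a, b), so ties between
+ // nonzero x and -x return the positive operand.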
+ const V max = Max(IfThenElse(Eq(abs_a, abs_b), b, a), a); + return IfThenElse(Lt(abs_a, abs_b), b, max); +} + +#endif // HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE + +template +HWY_API V MinMagnitude(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const auto abs_a = BitCast(du, Abs(a)); + const auto abs_b = BitCast(du, Abs(b)); + return IfThenElse(RebindMask(d, Lt(abs_a, abs_b)), a, + Min(IfThenElse(RebindMask(d, Eq(abs_a, abs_b)), a, b), b)); +} + +template +HWY_API V MaxMagnitude(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const auto abs_a = BitCast(du, Abs(a)); + const auto abs_b = BitCast(du, Abs(b)); + return IfThenElse(RebindMask(d, Lt(abs_a, abs_b)), b, + Max(IfThenElse(RebindMask(d, Eq(abs_a, abs_b)), b, a), a)); +} + +template +HWY_API V MinMagnitude(V a, V b) { + return Min(a, b); +} + +template +HWY_API V MaxMagnitude(V a, V b) { + return Max(a, b); +} + +// ------------------------------ AddSub + +template , 1)> +HWY_API V AddSub(V a, V b) { + // AddSub(a, b) for a one-lane vector is equivalent to Sub(a, b) + return Sub(a, b); +} + +// AddSub for F32x2, F32x4, and F64x2 vectors is implemented in x86_128-inl.h on +// SSSE3/SSE4/AVX2/AVX3 + +// AddSub for F32x8 and F64x4 vectors is implemented in x86_256-inl.h on +// AVX2/AVX3 + +// AddSub for F16/F32/F64 vectors on SVE is implemented in arm_sve-inl.h + +// AddSub for integer vectors on SVE2 is implemented in arm_sve-inl.h +template +HWY_API V AddSub(V a, V b) { + using D = DFromV; + using T = TFromD; + using TNegate = If(), MakeSigned, T>; + + const D d; + const Rebind d_negate; + + // Negate the even lanes of b + const auto negated_even_b = OddEven(b, BitCast(d, Neg(BitCast(d_negate, b)))); + + return Add(a, negated_even_b); +} + +// ------------------------------ MaskedAddOr etc. 
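+// Each Masked*Or op below computes IfThenElse(m, Op(a, b), no): the operation
+// in lanes where m is set, and the pass-through value `no` elsewhere.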
+#if (defined(HWY_NATIVE_MASKED_ARITH) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +template +HWY_API V MaskedMinOr(V no, M m, V a, V b) { + return IfThenElse(m, Min(a, b), no); +} + +template +HWY_API V MaskedMaxOr(V no, M m, V a, V b) { + return IfThenElse(m, Max(a, b), no); +} + +template +HWY_API V MaskedAddOr(V no, M m, V a, V b) { + return IfThenElse(m, Add(a, b), no); +} + +template +HWY_API V MaskedSubOr(V no, M m, V a, V b) { + return IfThenElse(m, Sub(a, b), no); +} + +template +HWY_API V MaskedMulOr(V no, M m, V a, V b) { + return IfThenElse(m, Mul(a, b), no); +} + +template +HWY_API V MaskedDivOr(V no, M m, V a, V b) { + return IfThenElse(m, Div(a, b), no); +} + +template +HWY_API V MaskedModOr(V no, M m, V a, V b) { + return IfThenElse(m, Mod(a, b), no); +} + +template +HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedAdd(a, b), no); +} + +template +HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedSub(a, b), no); +} +#endif // HWY_NATIVE_MASKED_ARITH + +#if (defined(HWY_NATIVE_ZERO_MASKED_ARITH) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ZERO_MASKED_ARITH +#undef HWY_NATIVE_ZERO_MASKED_ARITH +#else +#define HWY_NATIVE_ZERO_MASKED_ARITH +#endif + +template +HWY_API V MaskedMax(M m, V a, V b) { + return IfThenElseZero(m, (Max(a, b))); +} + +template +HWY_API V MaskedAdd(M m, V a, V b) { + return IfThenElseZero(m, Add(a, b)); +} + +template +HWY_API V MaskedSub(M m, V a, V b) { + return IfThenElseZero(m, Sub(a, b)); +} + +template +HWY_API V MaskedMul(M m, V a, V b) { + return IfThenElseZero(m, Mul(a, b)); +} + +template +HWY_API V MaskedDiv(M m, V a, V b) { + return IfThenElseZero(m, Div(a, b)); +} + +template +HWY_API V MaskedSaturatedAdd(M m, V a, V b) { + return IfThenElseZero(m, SaturatedAdd(a, b)); +} + +template +HWY_API V MaskedSaturatedSub(M m, V a, V b) { + return IfThenElseZero(m, SaturatedSub(a, b)); +} + +template , HWY_IF_I16_D(D)> +HWY_API V MaskedMulFixedPoint15(M m, V a, V b) { + return IfThenElseZero(m, MulFixedPoint15(a, b)); +} + +template +HWY_API V MaskedMulAdd(M m, V mul, V x, V add) { + return IfThenElseZero(m, MulAdd(mul, x, add)); +} + +template +HWY_API V MaskedNegMulAdd(M m, V mul, V x, V add) { + return IfThenElseZero(m, NegMulAdd(mul, x, add)); +} + +template >> +HWY_API VFromD MaskedWidenMulPairwiseAdd(D d32, M m, V16 a, V16 b) { + return IfThenElseZero(m, WidenMulPairwiseAdd(d32, a, b)); +} + +template +HWY_API VFromD MaskedWidenMulPairwiseAdd(DF df, M m, VBF a, VBF b) { + return IfThenElseZero(m, WidenMulPairwiseAdd(df, a, b)); +} +#endif // HWY_NATIVE_ZERO_MASKED_ARITH + +// ------------------------------ MaskedShift +template +HWY_API V MaskedShiftLeft(M m, V a) { + return IfThenElseZero(m, ShiftLeft(a)); +} + +template +HWY_API V MaskedShiftRight(M m, V a) { + return IfThenElseZero(m, ShiftRight(a)); +} + +template +HWY_API V MaskedShiftRightOr(V no, M m, V a) { + return IfThenElse(m, ShiftRight(a), no); +} + +template +HWY_API V MaskedShrOr(V no, M m, V a, V shifts) { + return IfThenElse(m, Shr(a, shifts), no); +} + +// ------------------------------ MaskedEq etc. 
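+// The masked comparisons below And the ordinary comparison result with m, so
+// lanes outside the mask always compare false.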
+#if (defined(HWY_NATIVE_MASKED_COMP) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_COMP +#undef HWY_NATIVE_MASKED_COMP +#else +#define HWY_NATIVE_MASKED_COMP +#endif + +template +HWY_API auto MaskedEq(M m, V a, V b) -> decltype(a == b) { + return And(m, Eq(a, b)); +} + +template +HWY_API auto MaskedNe(M m, V a, V b) -> decltype(a == b) { + return And(m, Ne(a, b)); +} + +template +HWY_API auto MaskedLt(M m, V a, V b) -> decltype(a == b) { + return And(m, Lt(a, b)); +} + +template +HWY_API auto MaskedGt(M m, V a, V b) -> decltype(a == b) { + return And(m, Gt(a, b)); +} + +template +HWY_API auto MaskedLe(M m, V a, V b) -> decltype(a == b) { + return And(m, Le(a, b)); +} + +template +HWY_API auto MaskedGe(M m, V a, V b) -> decltype(a == b) { + return And(m, Ge(a, b)); +} + +template > +HWY_API MFromD MaskedIsNaN(const M m, const V v) { + return And(m, IsNaN(v)); +} +#endif // HWY_NATIVE_MASKED_COMP + +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#else +#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#endif + +template +HWY_API V IfNegativeThenNegOrUndefIfZero(V mask, V v) { +#if HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE + // MaskedSubOr is more efficient than IfNegativeThenElse on RVV/SVE + const auto zero = Zero(DFromV()); + return MaskedSubOr(v, Lt(mask, zero), zero, v); +#else + return IfNegativeThenElse(mask, Neg(v), v); +#endif +} + +#endif // HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG + +template +HWY_API V IfNegativeThenNegOrUndefIfZero(V mask, V v) { + return CopySign(v, Xor(mask, v)); +} + +// ------------------------------ SaturatedNeg + +#if (defined(HWY_NATIVE_SATURATED_NEG_8_16_32) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32 +#undef HWY_NATIVE_SATURATED_NEG_8_16_32 +#else +#define HWY_NATIVE_SATURATED_NEG_8_16_32 +#endif + +template +HWY_API V SaturatedNeg(V v) { + const DFromV d; + return SaturatedSub(Zero(d), v); +} + +template )> +HWY_API V SaturatedNeg(V v) { + const DFromV d; + +#if HWY_TARGET == HWY_RVV || HWY_TARGET_IS_PPC || HWY_TARGET_IS_SVE || \ + HWY_TARGET_IS_NEON + // RVV/PPC/SVE/NEON have native I32 SaturatedSub instructions + return SaturatedSub(Zero(d), v); +#else + // ~v[i] - ((v[i] > LimitsMin()) ? -1 : 0) is equivalent to + // (v[i] > LimitsMin) ? (-v[i]) : LimitsMax() since + // -v[i] == ~v[i] + 1 == ~v[i] - (-1) and + // ~LimitsMin() == LimitsMax(). 
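+ // Concretely: v == LimitsMin() gives ~v - 0 == LimitsMax(); any other v
+ // gives ~v - (-1) == -v.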
+ return Sub(Not(v), VecFromMask(d, Gt(v, Set(d, LimitsMin())))); +#endif +} +#endif // HWY_NATIVE_SATURATED_NEG_8_16_32 + +#if (defined(HWY_NATIVE_SATURATED_NEG_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SATURATED_NEG_64 +#undef HWY_NATIVE_SATURATED_NEG_64 +#else +#define HWY_NATIVE_SATURATED_NEG_64 +#endif + +template )> +HWY_API V SaturatedNeg(V v) { +#if HWY_TARGET == HWY_RVV || HWY_TARGET_IS_SVE || HWY_TARGET_IS_NEON + // RVV/SVE/NEON have native I64 SaturatedSub instructions + const DFromV d; + return SaturatedSub(Zero(d), v); +#else + const auto neg_v = Neg(v); + return Add(neg_v, BroadcastSignBit(And(v, neg_v))); +#endif +} +#endif // HWY_NATIVE_SATURATED_NEG_64 + +// ------------------------------ SaturatedAbs + +#if (defined(HWY_NATIVE_SATURATED_ABS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +template +HWY_API V SaturatedAbs(V v) { + return Max(v, SaturatedNeg(v)); +} + +#endif + +// ------------------------------ MaskedAbsOr +template +HWY_API V MaskedAbsOr(V no, M m, V v) { + return IfThenElse(m, Abs(v), no); +} + +// ------------------------------ MaskedAbs +template +HWY_API V MaskedAbs(M m, V v) { + return IfThenElseZero(m, Abs(v)); +} + +// ------------------------------ Reductions + +// Targets follow one of two strategies. If HWY_NATIVE_REDUCE_SCALAR is toggled, +// they (RVV/SVE/Armv8/Emu128) implement ReduceSum and SumOfLanes via Set. +// Otherwise, they (Armv7/PPC/scalar/WASM/x86) define zero to most of the +// SumOfLanes overloads. For the latter group, we here define the remaining +// overloads, plus ReduceSum which uses them plus GetLane. +#if (defined(HWY_NATIVE_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +namespace detail { + +// Allows reusing the same shuffle code for SumOfLanes/MinOfLanes/MaxOfLanes. +struct AddFunc { + template + V operator()(V a, V b) const { + return Add(a, b); + } +}; + +struct MinFunc { + template + V operator()(V a, V b) const { + return Min(a, b); + } +}; + +struct MaxFunc { + template + V operator()(V a, V b) const { + return Max(a, b); + } +}; + +// No-op for vectors of at most one block. +template +HWY_INLINE VFromD ReduceAcrossBlocks(D, Func, VFromD v) { + return v; +} + +// Reduces a lane with its counterpart in other block(s). Shared by AVX2 and +// WASM_EMU256. AVX3 has its own overload. +template +HWY_INLINE VFromD ReduceAcrossBlocks(D /*d*/, Func f, VFromD v) { + return f(v, SwapAdjacentBlocks(v)); +} + +// These return the reduction result broadcasted across all lanes. They assume +// the caller has already reduced across blocks. + +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v10) { + return f(v10, Reverse2(d, v10)); +} + +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v3210) { + const VFromD v0123 = Reverse4(d, v3210); + const VFromD v03_12_12_03 = f(v3210, v0123); + const VFromD v12_03_03_12 = Reverse2(d, v03_12_12_03); + return f(v03_12_12_03, v12_03_03_12); +} + +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v76543210) { + // The upper half is reversed from the lower half; omit for brevity. 
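+ // After f(v, Reverse8(v)), lane i holds f(v[i], v[7 - i]); the subsequent
+ // Reverse4 and Reverse2 steps fold these partial results so that every lane
+ // ends up with the reduction over all 8 lanes.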
+ const VFromD v34_25_16_07 = f(v76543210, Reverse8(d, v76543210)); + const VFromD v0347_1625_1625_0347 = + f(v34_25_16_07, Reverse4(d, v34_25_16_07)); + return f(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); +} + +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v) { + const RepartitionToWide dw; + using VW = VFromD; + const VW vw = BitCast(dw, v); + // f is commutative, so no need to adapt for HWY_IS_LITTLE_ENDIAN. + const VW even = And(vw, Set(dw, 0xFF)); + const VW odd = ShiftRight<8>(vw); + const VW reduced = ReduceWithinBlocks(dw, f, f(even, odd)); +#if HWY_IS_LITTLE_ENDIAN + return DupEven(BitCast(d, reduced)); +#else + return DupOdd(BitCast(d, reduced)); +#endif +} + +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v) { + const RepartitionToWide dw; + using VW = VFromD; + const VW vw = BitCast(dw, v); + // Sign-extend + // f is commutative, so no need to adapt for HWY_IS_LITTLE_ENDIAN. + const VW even = ShiftRight<8>(ShiftLeft<8>(vw)); + const VW odd = ShiftRight<8>(vw); + const VW reduced = ReduceWithinBlocks(dw, f, f(even, odd)); +#if HWY_IS_LITTLE_ENDIAN + return DupEven(BitCast(d, reduced)); +#else + return DupOdd(BitCast(d, reduced)); +#endif +} + +} // namespace detail + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + const detail::AddFunc f; + v = detail::ReduceAcrossBlocks(d, f, v); + return detail::ReduceWithinBlocks(d, f, v); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + const detail::MinFunc f; + v = detail::ReduceAcrossBlocks(d, f, v); + return detail::ReduceWithinBlocks(d, f, v); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + const detail::MaxFunc f; + v = detail::ReduceAcrossBlocks(d, f, v); + return detail::ReduceWithinBlocks(d, f, v); +} + +template +HWY_API TFromD ReduceSum(D d, VFromD v) { + return GetLane(SumOfLanes(d, v)); +} +template +HWY_API TFromD ReduceMin(D d, VFromD v) { + return GetLane(MinOfLanes(d, v)); +} +template +HWY_API TFromD ReduceMax(D d, VFromD v) { + return GetLane(MaxOfLanes(d, v)); +} + +#endif // HWY_NATIVE_REDUCE_SCALAR + +// Corner cases for both generic and native implementations: +// N=1 (native covers N=2 e.g. for u64x2 and even u32x2 on Arm) +template +HWY_API TFromD ReduceSum(D /*d*/, VFromD v) { + return GetLane(v); +} +template +HWY_API TFromD ReduceMin(D /*d*/, VFromD v) { + return GetLane(v); +} +template +HWY_API TFromD ReduceMax(D /*d*/, VFromD v) { + return GetLane(v); +} + +template +HWY_API VFromD SumOfLanes(D /* tag */, VFromD v) { + return v; +} +template +HWY_API VFromD MinOfLanes(D /* tag */, VFromD v) { + return v; +} +template +HWY_API VFromD MaxOfLanes(D /* tag */, VFromD v) { + return v; +} + +// N=4 for 8-bit is still less than the minimum native size. 
+ +// ARMv7 NEON/PPC/RVV/SVE have target-specific implementations of the N=4 I8/U8 +// ReduceSum operations +#if (defined(HWY_NATIVE_REDUCE_SUM_4_UI8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif +template +HWY_API TFromD ReduceSum(D d, VFromD v) { + const Twice> dw; + return static_cast>(ReduceSum(dw, PromoteTo(dw, v))); +} +#endif // HWY_NATIVE_REDUCE_SUM_4_UI8 + +// RVV/SVE have target-specific implementations of the N=4 I8/U8 +// ReduceMin/ReduceMax operations +#if (defined(HWY_NATIVE_REDUCE_MINMAX_4_UI8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#undef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#else +#define HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#endif +template +HWY_API TFromD ReduceMin(D d, VFromD v) { + const Twice> dw; + return static_cast>(ReduceMin(dw, PromoteTo(dw, v))); +} +template +HWY_API TFromD ReduceMax(D d, VFromD v) { + const Twice> dw; + return static_cast>(ReduceMax(dw, PromoteTo(dw, v))); +} +#endif // HWY_NATIVE_REDUCE_MINMAX_4_UI8 + +#if (defined(HWY_NATIVE_MASKED_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR +#undef HWY_NATIVE_MASKED_REDUCE_SCALAR +#else +#define HWY_NATIVE_MASKED_REDUCE_SCALAR +#endif + +template +HWY_API TFromD MaskedReduceSum(D d, M m, VFromD v) { + return ReduceSum(d, IfThenElseZero(m, v)); +} +template +HWY_API TFromD MaskedReduceMin(D d, M m, VFromD v) { + return ReduceMin( + d, IfThenElse(m, v, Set(d, hwy::PositiveInfOrHighestValue>()))); +} +template +HWY_API TFromD MaskedReduceMax(D d, M m, VFromD v) { + return ReduceMax( + d, IfThenElse(m, v, Set(d, hwy::NegativeInfOrLowestValue>()))); +} + +#endif // HWY_NATIVE_MASKED_REDUCE_SCALAR + +// ------------------------------ IsEitherNaN +#if (defined(HWY_NATIVE_IS_EITHER_NAN) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IS_EITHER_NAN +#undef HWY_NATIVE_IS_EITHER_NAN +#else +#define HWY_NATIVE_IS_EITHER_NAN +#endif + +template +HWY_API MFromD> IsEitherNaN(V a, V b) { + return Or(IsNaN(a), IsNaN(b)); +} + +#endif // HWY_NATIVE_IS_EITHER_NAN + +// ------------------------------ IsInf, IsFinite + +// AVX3 has target-specific implementations of these. +#if (defined(HWY_NATIVE_ISINF) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +template > +HWY_API MFromD IsInf(const V v) { + using T = TFromD; + const D d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask( + d, + Eq(Add(vu, vu), + Set(du, static_cast>(hwy::MaxExponentTimes2())))); +} + +// Returns whether normal/subnormal/zero. +template > +HWY_API MFromD IsFinite(const V v) { + using T = TFromD; + const D d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); +// 'Shift left' to clear the sign bit. MSVC seems to generate incorrect code +// for AVX2 if we instead add vu + vu. +#if HWY_COMPILER_MSVC + const VFromD shl = ShiftLeft<1>(vu); +#else + const VFromD shl = Add(vu, vu); +#endif + + // Then shift right so we can compare with the max exponent (cannot compare + // with MaxExponentTimes2 directly because it is negative and non-negative + // floats would be greater). 
+ const VFromD exp = + BitCast(di, ShiftRight() + 1>(shl)); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +#endif // HWY_NATIVE_ISINF + +// ------------------------------ CeilInt/FloorInt +#if (defined(HWY_NATIVE_CEIL_FLOOR_INT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_CEIL_FLOOR_INT +#undef HWY_NATIVE_CEIL_FLOOR_INT +#else +#define HWY_NATIVE_CEIL_FLOOR_INT +#endif + +template +HWY_API VFromD>> CeilInt(V v) { + const DFromV d; + const RebindToSigned di; + return ConvertTo(di, Ceil(v)); +} + +template +HWY_API VFromD>> FloorInt(V v) { + const DFromV d; + const RebindToSigned di; + return ConvertTo(di, Floor(v)); +} + +#endif // HWY_NATIVE_CEIL_FLOOR_INT + +// ------------------------------ MulByPow2/MulByFloorPow2 + +#if (defined(HWY_NATIVE_MUL_BY_POW2) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MUL_BY_POW2 +#undef HWY_NATIVE_MUL_BY_POW2 +#else +#define HWY_NATIVE_MUL_BY_POW2 +#endif + +template +HWY_API V MulByPow2(V v, VFromD>> exp) { + const DFromV df; + const RebindToUnsigned du; + const RebindToSigned di; + + using TF = TFromD; + using TI = TFromD; + using TU = TFromD; + + using VF = VFromD; + using VI = VFromD; + + constexpr TI kMaxBiasedExp = MaxExponentField(); + static_assert(kMaxBiasedExp > 0, "kMaxBiasedExp > 0 must be true"); + + constexpr TI kExpBias = static_cast(kMaxBiasedExp >> 1); + static_assert(kExpBias > 0, "kExpBias > 0 must be true"); + static_assert(kExpBias <= LimitsMax() / 3, + "kExpBias <= LimitsMax() / 3 must be true"); + +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + using TExpMinMax = If<(sizeof(TI) <= 4), TI, int32_t>; +#elif (HWY_TARGET >= HWY_SSSE3 && HWY_TARGET <= HWY_SSE2) || \ + HWY_TARGET == HWY_WASM || HWY_TARGET == HWY_WASM_EMU256 + using TExpMinMax = int16_t; +#else + using TExpMinMax = TI; +#endif + +#if HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SCALAR + using TExpSatSub = TU; +#elif HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \ + HWY_TARGET == HWY_WASM_EMU256 + using TExpSatSub = If<(sizeof(TF) == 4), uint8_t, uint16_t>; +#elif HWY_TARGET_IS_PPC + using TExpSatSub = If<(sizeof(TF) >= 4), uint32_t, TU>; +#else + using TExpSatSub = If<(sizeof(TF) == 4), uint8_t, TU>; +#endif + + static_assert(kExpBias <= static_cast(LimitsMax() / 3), + "kExpBias <= LimitsMax() / 3 must be true"); + + const Repartition d_exp_min_max; + const Repartition d_sat_exp_sub; + + constexpr int kNumOfExpBits = ExponentBits(); + constexpr int kNumOfMantBits = MantissaBits(); + + // The sign bit of BitCastScalar(a[i]) >> kNumOfMantBits can be zeroed out + // using SaturatedSub if kZeroOutSignUsingSatSub is true. + + // If kZeroOutSignUsingSatSub is true, then val_for_exp_sub will be bitcasted + // to a vector that has a smaller lane size than TU for the SaturatedSub + // operation below. + constexpr bool kZeroOutSignUsingSatSub = + ((sizeof(TExpSatSub) * 8) == static_cast(kNumOfExpBits)); + + // If kZeroOutSignUsingSatSub is true, then the upper + // (sizeof(TU) - sizeof(TExpSatSub)) * 8 bits of kExpDecrBy1Bits will be all + // ones and the lower sizeof(TExpSatSub) * 8 bits of kExpDecrBy1Bits will be + // equal to 1. + + // Otherwise, if kZeroOutSignUsingSatSub is false, kExpDecrBy1Bits will be + // equal to 1. 
+ constexpr TU kExpDecrBy1Bits = static_cast( + TU{1} - (static_cast(kZeroOutSignUsingSatSub) << kNumOfExpBits)); + + VF val_for_exp_sub = v; + HWY_IF_CONSTEXPR(!kZeroOutSignUsingSatSub) { + // If kZeroOutSignUsingSatSub is not true, zero out the sign bit of + // val_for_exp_sub[i] using Abs + val_for_exp_sub = Abs(val_for_exp_sub); + } + + // min_exp1_plus_min_exp2[i] is the smallest exponent such that + // min_exp1_plus_min_exp2[i] >= 2 - kExpBias * 2 and + // std::ldexp(v[i], min_exp1_plus_min_exp2[i]) is a normal floating-point + // number if v[i] is a normal number + const VI min_exp1_plus_min_exp2 = BitCast( + di, + Max(BitCast( + d_exp_min_max, + Neg(BitCast( + di, + SaturatedSub( + BitCast(d_sat_exp_sub, ShiftRight( + BitCast(du, val_for_exp_sub))), + BitCast(d_sat_exp_sub, Set(du, kExpDecrBy1Bits)))))), + BitCast(d_exp_min_max, + Set(di, static_cast(2 - kExpBias - kExpBias))))); + + const VI clamped_exp = + Max(Min(exp, Set(di, static_cast(kExpBias * 3))), + Add(min_exp1_plus_min_exp2, Set(di, static_cast(1 - kExpBias)))); + + const VI exp1_plus_exp2 = BitCast( + di, Max(Min(BitCast(d_exp_min_max, + Sub(clamped_exp, ShiftRight<2>(clamped_exp))), + BitCast(d_exp_min_max, + Set(di, static_cast(kExpBias + kExpBias)))), + BitCast(d_exp_min_max, min_exp1_plus_min_exp2))); + + const VI exp1 = ShiftRight<1>(exp1_plus_exp2); + const VI exp2 = Sub(exp1_plus_exp2, exp1); + const VI exp3 = Sub(clamped_exp, exp1_plus_exp2); + + const VI exp_bias = Set(di, kExpBias); + + const VF factor1 = + BitCast(df, ShiftLeft(Add(exp1, exp_bias))); + const VF factor2 = + BitCast(df, ShiftLeft(Add(exp2, exp_bias))); + const VF factor3 = + BitCast(df, ShiftLeft(Add(exp3, exp_bias))); + + return Mul(Mul(Mul(v, factor1), factor2), factor3); +} + +template +HWY_API V MulByFloorPow2(V v, V exp) { + const DFromV df; + + // MulByFloorPow2 special cases: + // MulByFloorPow2(v, NaN) => NaN + // MulByFloorPow2(0, inf) => NaN + // MulByFloorPow2(inf, -inf) => NaN + // MulByFloorPow2(-inf, -inf) => NaN + const auto is_special_case_with_nan_result = + Or(IsNaN(exp), + And(Eq(Abs(v), IfNegativeThenElseZero(exp, Inf(df))), IsInf(exp))); + + return IfThenElse(is_special_case_with_nan_result, NaN(df), + MulByPow2(v, FloorInt(exp))); +} + +#endif // HWY_NATIVE_MUL_BY_POW2 + +// ------------------------------ GetBiasedExponent +#if (defined(HWY_NATIVE_GET_BIASED_EXPONENT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_GET_BIASED_EXPONENT +#undef HWY_NATIVE_GET_BIASED_EXPONENT +#else +#define HWY_NATIVE_GET_BIASED_EXPONENT +#endif + +template +HWY_API VFromD>> GetBiasedExponent(V v) { + using T = TFromV; + + const DFromV d; + const RebindToUnsigned du; + + constexpr int kNumOfMantBits = MantissaBits(); + return ShiftRight(BitCast(du, Abs(v))); +} + +#endif + +// ------------------------------ GetExponent + +#if (defined(HWY_NATIVE_GET_EXPONENT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_GET_EXPONENT +#undef HWY_NATIVE_GET_EXPONENT +#else +#define HWY_NATIVE_GET_EXPONENT +#endif + +template +HWY_API V GetExponent(V v) { + const DFromV d; + using T = TFromV; + const RebindToSigned di; + + const auto exponent_offset = Set(di, MaxExponentField() >> 1); + + // extract exponent bits as integer + const auto encoded_exponent = GetBiasedExponent(v); + const auto exponent_int = Sub(BitCast(di, encoded_exponent), exponent_offset); + + // convert integer to original type + return ConvertTo(d, exponent_int); +} + +#endif // HWY_NATIVE_GET_EXPONENT +// ------------------------------ LoadInterleaved2 + +#if HWY_IDE || \ + 
(defined(HWY_NATIVE_LOAD_STORE_INTERLEAVED) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template +HWY_API void LoadInterleaved2(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + const VFromD A = LoadU(d, unaligned); // v1[1] v0[1] v1[0] v0[0] + const VFromD B = LoadU(d, unaligned + Lanes(d)); + v0 = ConcatEven(d, B, A); + v1 = ConcatOdd(d, B, A); +} + +template +HWY_API void LoadInterleaved2(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +// ------------------------------ LoadInterleaved3 (CombineShiftRightBytes) + +namespace detail { + +#if HWY_IDE +template +HWY_INLINE V ShuffleTwo1230(V a, V /* b */) { + return a; +} +template +HWY_INLINE V ShuffleTwo2301(V a, V /* b */) { + return a; +} +template +HWY_INLINE V ShuffleTwo3012(V a, V /* b */) { + return a; +} +#endif // HWY_IDE + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void LoadTransposedBlocks3(D d, + const TFromD* HWY_RESTRICT unaligned, + VFromD& A, VFromD& B, + VFromD& C) { + constexpr size_t kN = MaxLanes(d); + A = LoadU(d, unaligned + 0 * kN); + B = LoadU(d, unaligned + 1 * kN); + C = LoadU(d, unaligned + 2 * kN); +} + +} // namespace detail + +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + using V = VFromD; + using VU = VFromD; + // Compact notation so these fit on one line: 12 := v1[2]. + V A; // 05 24 14 04 23 13 03 22 12 02 21 11 01 20 10 00 + V B; // 1a 0a 29 19 09 28 18 08 27 17 07 26 16 06 25 15 + V C; // 2f 1f 0f 2e 1e 0e 2d 1d 0d 2c 1c 0c 2b 1b 0b 2a + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. 
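+  // For example, v0's 16 lanes are gathered from byte offsets 0,3,6,9,12,15
+  // of A, then 2,5,8,11,14 of B, then 1,4,7,10,13 of C. Each index table
+  // below zeroes the positions filled from the other two vectors (Z = 0x80),
+  // so the three TableLookupBytesOr0 results are disjoint and can be merged
+  // with Xor3.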
+ constexpr uint8_t Z = 0x80; + const VU idx_v0A = + Dup128VecFromValues(du, 0, 3, 6, 9, 12, 15, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU idx_v0B = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, 2, 5, 8, 11, 14, Z, Z, Z, Z, Z); + const VU idx_v0C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 1, 4, 7, 10, 13); + const VU idx_v1A = + Dup128VecFromValues(du, 1, 4, 7, 10, 13, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU idx_v1B = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 0, 3, 6, 9, 12, 15, Z, Z, Z, Z, Z); + const VU idx_v1C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 2, 5, 8, 11, 14); + const VU idx_v2A = + Dup128VecFromValues(du, 2, 5, 8, 11, 14, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU idx_v2B = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 1, 4, 7, 10, 13, Z, Z, Z, Z, Z, Z); + const VU idx_v2C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 0, 3, 6, 9, 12, 15); + const V v0L = BitCast(d, TableLookupBytesOr0(A, idx_v0A)); + const V v0M = BitCast(d, TableLookupBytesOr0(B, idx_v0B)); + const V v0U = BitCast(d, TableLookupBytesOr0(C, idx_v0C)); + const V v1L = BitCast(d, TableLookupBytesOr0(A, idx_v1A)); + const V v1M = BitCast(d, TableLookupBytesOr0(B, idx_v1B)); + const V v1U = BitCast(d, TableLookupBytesOr0(C, idx_v1C)); + const V v2L = BitCast(d, TableLookupBytesOr0(A, idx_v2A)); + const V v2M = BitCast(d, TableLookupBytesOr0(B, idx_v2B)); + const V v2U = BitCast(d, TableLookupBytesOr0(C, idx_v2C)); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +// 8-bit lanes x8 +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + using V = VFromD; + using VU = VFromD; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. 
+ constexpr uint8_t Z = 0x80; + const VU idx_v0A = + Dup128VecFromValues(du, 0, 3, 6, Z, Z, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v0B = + Dup128VecFromValues(du, Z, Z, Z, 1, 4, 7, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v0C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v1A = + Dup128VecFromValues(du, 1, 4, 7, Z, Z, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v1B = + Dup128VecFromValues(du, Z, Z, Z, 2, 5, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v1C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 0, 3, 6, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v2A = + Dup128VecFromValues(du, 2, 5, Z, Z, Z, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v2B = + Dup128VecFromValues(du, Z, Z, 0, 3, 6, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v2C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 1, 4, 7, 0, 0, 0, 0, 0, 0, 0, 0); + const V v0L = BitCast(d, TableLookupBytesOr0(A, idx_v0A)); + const V v0M = BitCast(d, TableLookupBytesOr0(B, idx_v0B)); + const V v0U = BitCast(d, TableLookupBytesOr0(C, idx_v0C)); + const V v1L = BitCast(d, TableLookupBytesOr0(A, idx_v1A)); + const V v1M = BitCast(d, TableLookupBytesOr0(B, idx_v1B)); + const V v1U = BitCast(d, TableLookupBytesOr0(C, idx_v1C)); + const V v2L = BitCast(d, TableLookupBytesOr0(A, idx_v2A)); + const V v2M = BitCast(d, TableLookupBytesOr0(B, idx_v2B)); + const V v2U = BitCast(d, TableLookupBytesOr0(C, idx_v2C)); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +// 16-bit lanes x8 +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + const Repartition du8; + using V = VFromD; + using VU8 = VFromD; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. Same as above, + // but each element of the array contains a byte index for a byte of a lane. 
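+  // For example, lane 0 of v0 is bytes 0x00-0x01 of A, lane 1 is bytes
+  // 0x06-0x07 and lane 2 is bytes 0x0C-0x0D, which is why idx_v0A below
+  // lists the byte pairs {0x00,0x01}, {0x06,0x07}, {0x0C,0x0D}.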
+ constexpr uint8_t Z = 0x80; + const VU8 idx_v0A = Dup128VecFromValues(du8, 0x00, 0x01, 0x06, 0x07, 0x0C, + 0x0D, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU8 idx_v0B = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, 0x02, 0x03, + 0x08, 0x09, 0x0E, 0x0F, Z, Z, Z, Z); + const VU8 idx_v0C = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, + Z, 0x04, 0x05, 0x0A, 0x0B); + const VU8 idx_v1A = Dup128VecFromValues(du8, 0x02, 0x03, 0x08, 0x09, 0x0E, + 0x0F, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU8 idx_v1B = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, 0x04, 0x05, + 0x0A, 0x0B, Z, Z, Z, Z, Z, Z); + const VU8 idx_v1C = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, + 0x00, 0x01, 0x06, 0x07, 0x0C, 0x0D); + const VU8 idx_v2A = Dup128VecFromValues(du8, 0x04, 0x05, 0x0A, 0x0B, Z, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU8 idx_v2B = Dup128VecFromValues(du8, Z, Z, Z, Z, 0x00, 0x01, 0x06, + 0x07, 0x0C, 0x0D, Z, Z, Z, Z, Z, Z); + const VU8 idx_v2C = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, + 0x02, 0x03, 0x08, 0x09, 0x0E, 0x0F); + const V v0L = TableLookupBytesOr0(A, BitCast(d, idx_v0A)); + const V v0M = TableLookupBytesOr0(B, BitCast(d, idx_v0B)); + const V v0U = TableLookupBytesOr0(C, BitCast(d, idx_v0C)); + const V v1L = TableLookupBytesOr0(A, BitCast(d, idx_v1A)); + const V v1M = TableLookupBytesOr0(B, BitCast(d, idx_v1B)); + const V v1U = TableLookupBytesOr0(C, BitCast(d, idx_v1C)); + const V v2L = TableLookupBytesOr0(A, BitCast(d, idx_v2A)); + const V v2M = TableLookupBytesOr0(B, BitCast(d, idx_v2B)); + const V v2U = TableLookupBytesOr0(C, BitCast(d, idx_v2C)); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + using V = VFromD; + V A; // v0[1] v2[0] v1[0] v0[0] + V B; // v1[2] v0[2] v2[1] v1[1] + V C; // v2[3] v1[3] v0[3] v2[2] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + + const V vxx_02_03_xx = OddEven(C, B); + v0 = detail::ShuffleTwo1230(A, vxx_02_03_xx); + + // Shuffle2301 takes the upper/lower halves of the output from one input, so + // we cannot just combine 13 and 10 with 12 and 11 (similar to v0/v2). Use + // OddEven because it may have higher throughput than Shuffle. + const V vxx_xx_10_11 = OddEven(A, B); + const V v12_13_xx_xx = OddEven(B, C); + v1 = detail::ShuffleTwo2301(vxx_xx_10_11, v12_13_xx_xx); + + const V vxx_20_21_xx = OddEven(B, A); + v2 = detail::ShuffleTwo3012(vxx_20_21_xx, C); +} + +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + VFromD A; // v1[0] v0[0] + VFromD B; // v0[1] v2[0] + VFromD C; // v2[1] v1[1] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + v0 = OddEven(B, A); + v1 = CombineShiftRightBytes)>(d, C, A); + v2 = OddEven(C, B); +} + +template , HWY_IF_LANES_D(D, 1)> +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +// ------------------------------ LoadInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. 
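+// For these vector sizes the blocks are simply four consecutive unaligned
+// loads: with kN lanes per vector, vA..vD hold interleaved elements
+// [0, kN), [kN, 2*kN), [2*kN, 3*kN) and [3*kN, 4*kN) respectively.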
+template +HWY_INLINE void LoadTransposedBlocks4(D d, + const TFromD* HWY_RESTRICT unaligned, + VFromD& vA, VFromD& vB, + VFromD& vC, VFromD& vD) { + constexpr size_t kN = MaxLanes(d); + vA = LoadU(d, unaligned + 0 * kN); + vB = LoadU(d, unaligned + 1 * kN); + vC = LoadU(d, unaligned + 2 * kN); + vD = LoadU(d, unaligned + 3 * kN); +} + +} // namespace detail + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + const Repartition d64; + using V64 = VFromD; + using V = VFromD; + // 16 lanes per block; the lowest four blocks are at the bottom of vA..vD. + // Here int[i] means the four interleaved values of the i-th 4-tuple and + // int[3..0] indicates four consecutive 4-tuples (0 = least-significant). + V vA; // int[13..10] int[3..0] + V vB; // int[17..14] int[7..4] + V vC; // int[1b..18] int[b..8] + V vD; // int[1f..1c] int[f..c] + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + + // For brevity, the comments only list the lower block (upper = lower + 0x10) + const V v5140 = InterleaveLower(d, vA, vB); // int[5,1,4,0] + const V vd9c8 = InterleaveLower(d, vC, vD); // int[d,9,c,8] + const V v7362 = InterleaveUpper(d, vA, vB); // int[7,3,6,2] + const V vfbea = InterleaveUpper(d, vC, vD); // int[f,b,e,a] + + const V v6420 = InterleaveLower(d, v5140, v7362); // int[6,4,2,0] + const V veca8 = InterleaveLower(d, vd9c8, vfbea); // int[e,c,a,8] + const V v7531 = InterleaveUpper(d, v5140, v7362); // int[7,5,3,1] + const V vfdb9 = InterleaveUpper(d, vd9c8, vfbea); // int[f,d,b,9] + + const V64 v10L = BitCast(d64, InterleaveLower(d, v6420, v7531)); // v10[7..0] + const V64 v10U = BitCast(d64, InterleaveLower(d, veca8, vfdb9)); // v10[f..8] + const V64 v32L = BitCast(d64, InterleaveUpper(d, v6420, v7531)); // v32[7..0] + const V64 v32U = BitCast(d64, InterleaveUpper(d, veca8, vfdb9)); // v32[f..8] + + v0 = BitCast(d, InterleaveLower(d64, v10L, v10U)); + v1 = BitCast(d, InterleaveUpper(d64, v10L, v10U)); + v2 = BitCast(d, InterleaveLower(d64, v32L, v32U)); + v3 = BitCast(d, InterleaveUpper(d64, v32L, v32U)); +} + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + // In the last step, we interleave by half of the block size, which is usually + // 8 bytes but half that for 8-bit x8 vectors. + using TW = hwy::UnsignedFromSize; + const Repartition dw; + using VW = VFromD; + + // (Comments are for 256-bit vectors.) + // 8 lanes per block; the lowest four blocks are at the bottom of vA..vD. 
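+  // Worked example, assuming 128-bit vectors of uint16_t (one block): vA..vD
+  // then hold the 4-tuples {1,0}, {3,2}, {5,4} and {7,6}, and the two
+  // interleave rounds plus the final TW-wide interleave below restore v0..v3
+  // to lane order.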
+ VFromD vA; // v3210[9]v3210[8] v3210[1]v3210[0] + VFromD vB; // v3210[b]v3210[a] v3210[3]v3210[2] + VFromD vC; // v3210[d]v3210[c] v3210[5]v3210[4] + VFromD vD; // v3210[f]v3210[e] v3210[7]v3210[6] + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + + const VFromD va820 = InterleaveLower(d, vA, vB); // v3210[a,8] v3210[2,0] + const VFromD vec64 = InterleaveLower(d, vC, vD); // v3210[e,c] v3210[6,4] + const VFromD vb931 = InterleaveUpper(d, vA, vB); // v3210[b,9] v3210[3,1] + const VFromD vfd75 = InterleaveUpper(d, vC, vD); // v3210[f,d] v3210[7,5] + + const VW v10_b830 = // v10[b..8] v10[3..0] + BitCast(dw, InterleaveLower(d, va820, vb931)); + const VW v10_fc74 = // v10[f..c] v10[7..4] + BitCast(dw, InterleaveLower(d, vec64, vfd75)); + const VW v32_b830 = // v32[b..8] v32[3..0] + BitCast(dw, InterleaveUpper(d, va820, vb931)); + const VW v32_fc74 = // v32[f..c] v32[7..4] + BitCast(dw, InterleaveUpper(d, vec64, vfd75)); + + v0 = BitCast(d, InterleaveLower(dw, v10_b830, v10_fc74)); + v1 = BitCast(d, InterleaveUpper(dw, v10_b830, v10_fc74)); + v2 = BitCast(d, InterleaveLower(dw, v32_b830, v32_fc74)); + v3 = BitCast(d, InterleaveUpper(dw, v32_b830, v32_fc74)); +} + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + using V = VFromD; + V vA; // v3210[4] v3210[0] + V vB; // v3210[5] v3210[1] + V vC; // v3210[6] v3210[2] + V vD; // v3210[7] v3210[3] + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + const V v10e = InterleaveLower(d, vA, vC); // v1[6,4] v0[6,4] v1[2,0] v0[2,0] + const V v10o = InterleaveLower(d, vB, vD); // v1[7,5] v0[7,5] v1[3,1] v0[3,1] + const V v32e = InterleaveUpper(d, vA, vC); // v3[6,4] v2[6,4] v3[2,0] v2[2,0] + const V v32o = InterleaveUpper(d, vB, vD); // v3[7,5] v2[7,5] v3[3,1] v2[3,1] + + v0 = InterleaveLower(d, v10e, v10o); + v1 = InterleaveUpper(d, v10e, v10o); + v2 = InterleaveLower(d, v32e, v32o); + v3 = InterleaveUpper(d, v32e, v32o); +} + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + VFromD vA, vB, vC, vD; + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + v0 = InterleaveLower(d, vA, vC); + v1 = InterleaveUpper(d, vA, vC); + v2 = InterleaveLower(d, vB, vD); + v3 = InterleaveUpper(d, vB, vD); +} + +// Any T x1 +template , HWY_IF_LANES_D(D, 1)> +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void StoreTransposedBlocks2(VFromD A, VFromD B, D d, + TFromD* HWY_RESTRICT unaligned) { + constexpr size_t kN = MaxLanes(d); + StoreU(A, d, unaligned + 0 * kN); + StoreU(B, d, unaligned + 1 * kN); +} + +} // namespace detail + +// >= 128 bit vector +template +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + TFromD* HWY_RESTRICT unaligned) { + const auto v10L = InterleaveLower(d, v0, v1); // .. v1[0] v0[0] + const auto v10U = InterleaveUpper(d, v0, v1); // .. 
v1[kN/2] v0[kN/2] + detail::StoreTransposedBlocks2(v10L, v10U, d, unaligned); +} + +// <= 64 bits +template +HWY_API void StoreInterleaved2(V part0, V part1, D d, + TFromD* HWY_RESTRICT unaligned) { + const Twice d2; + const auto v0 = ZeroExtendVector(d2, part0); + const auto v1 = ZeroExtendVector(d2, part1); + const auto v10 = InterleaveLower(d2, v0, v1); + StoreU(v10, d2, unaligned); +} + +// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, +// TableLookupBytes) + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void StoreTransposedBlocks3(VFromD A, VFromD B, VFromD C, + D d, TFromD* HWY_RESTRICT unaligned) { + constexpr size_t kN = MaxLanes(d); + StoreU(A, d, unaligned + 0 * kN); + StoreU(B, d, unaligned + 1 * kN); + StoreU(C, d, unaligned + 2 * kN); +} + +} // namespace detail + +// >= 128-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + using TU = TFromD; + using VU = VFromD; + const VU k5 = Set(du, TU{5}); + const VU k6 = Set(du, TU{6}); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v0[5], v2[4],v1[4],v0[4] .. v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. + const VFromD shuf_A0 = + Dup128VecFromValues(du, 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, 3, + 0x80, 0x80, 4, 0x80, 0x80, 5); + // Cannot reuse shuf_A0 because it contains 5. + const VFromD shuf_A1 = + Dup128VecFromValues(du, 0x80, 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, + 3, 0x80, 0x80, 4, 0x80, 0x80); + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + // cannot reuse shuf_A0 (has 5) + const VU shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const VU vA0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const VU vA1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const VU vA2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const VFromD A = BitCast(d, vA0 | vA1 | vA2); + + // B: v1[10],v0[10], v2[9],v1[9],v0[9] .. , v2[6],v1[6],v0[6], v2[5],v1[5] + const VU shuf_B0 = shuf_A2 + k6; // .A..9..8..7..6.. + const VU shuf_B1 = shuf_A0 + k5; // A..9..8..7..6..5 + const VU shuf_B2 = shuf_A1 + k5; // ..9..8..7..6..5. + const VU vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B = BitCast(d, vB0 | vB1 | vB2); + + // C: v2[15],v1[15],v0[15], v2[11],v1[11],v0[11], v2[10] + const VU shuf_C0 = shuf_B2 + k6; // ..F..E..D..C..B. + const VU shuf_C1 = shuf_B0 + k5; // .F..E..D..C..B.. 
+ const VU shuf_C2 = shuf_B1 + k5; // F..E..D..C..B..A + const VU vC0 = TableLookupBytesOr0(v0, shuf_C0); + const VU vC1 = TableLookupBytesOr0(v1, shuf_C1); + const VU vC2 = TableLookupBytesOr0(v2, shuf_C2); + const VFromD C = BitCast(d, vC0 | vC1 | vC2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const Repartition du8; + using VU8 = VFromD; + const VU8 k2 = Set(du8, uint8_t{2 * sizeof(TFromD)}); + const VU8 k3 = Set(du8, uint8_t{3 * sizeof(TFromD)}); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. Note that these are byte + // indices for 16-bit lanes. + const VFromD shuf_A1 = + Dup128VecFromValues(du8, 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, 2, 3, + 0x80, 0x80, 0x80, 0x80, 4, 5); + const VFromD shuf_A2 = + Dup128VecFromValues(du8, 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, + 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80); + + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU8 shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + + const VU8 A0 = TableLookupBytesOr0(v0, shuf_A0); + const VU8 A1 = TableLookupBytesOr0(v1, shuf_A1); + const VU8 A2 = TableLookupBytesOr0(v2, shuf_A2); + const VFromD A = BitCast(d, A0 | A1 | A2); + + // B: v0[5] v2[4],v1[4],v0[4], v2[3],v1[3],v0[3], v2[2] + const VU8 shuf_B0 = shuf_A1 + k3; // 5..4..3. + const VU8 shuf_B1 = shuf_A2 + k3; // ..4..3.. + const VU8 shuf_B2 = shuf_A0 + k2; // .4..3..2 + const VU8 vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU8 vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU8 vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B = BitCast(d, vB0 | vB1 | vB2); + + // C: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const VU8 shuf_C0 = shuf_B1 + k3; // ..7..6.. + const VU8 shuf_C1 = shuf_B2 + k3; // .7..6..5 + const VU8 shuf_C2 = shuf_B0 + k2; // 7..6..5. + const VU8 vC0 = TableLookupBytesOr0(v0, shuf_C0); + const VU8 vC1 = TableLookupBytesOr0(v1, shuf_C1); + const VU8 vC2 = TableLookupBytesOr0(v2, shuf_C2); + const VFromD C = BitCast(d, vC0 | vC1 | vC2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 32-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const RepartitionToWide dw; + + const VFromD v10_v00 = InterleaveLower(d, v0, v1); + const VFromD v01_v20 = OddEven(v0, v2); + // A: v0[1], v2[0],v1[0],v0[0] (<- lane 0) + const VFromD A = BitCast( + d, InterleaveLower(dw, BitCast(dw, v10_v00), BitCast(dw, v01_v20))); + + const VFromD v1_321 = ShiftRightLanes<1>(d, v1); + const VFromD v0_32 = ShiftRightLanes<2>(d, v0); + const VFromD v21_v11 = OddEven(v2, v1_321); + const VFromD v12_v02 = OddEven(v1_321, v0_32); + // B: v1[2],v0[2], v2[1],v1[1] + const VFromD B = BitCast( + d, InterleaveLower(dw, BitCast(dw, v21_v11), BitCast(dw, v12_v02))); + + // Notation refers to the upper 2 lanes of the vector for InterleaveUpper. 
+ const VFromD v23_v13 = OddEven(v2, v1_321); + const VFromD v03_v22 = OddEven(v0, v2); + // C: v2[3],v1[3],v0[3], v2[2] + const VFromD C = BitCast( + d, InterleaveUpper(dw, BitCast(dw, v03_v22), BitCast(dw, v23_v13))); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const VFromD A = InterleaveLower(d, v0, v1); + const VFromD B = OddEven(v0, v2); + const VFromD C = InterleaveUpper(d, v1, v2); + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// 64-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and first result. + constexpr size_t kFullN = 16 / sizeof(TFromD); + const Full128 du; + using VU = VFromD; + const Full128> d_full; + const VU k5 = Set(du, uint8_t{5}); + const VU k6 = Set(du, uint8_t{6}); + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU shuf_A0 = Load(du, tbl_v0); + const VU shuf_A1 = Load(du, tbl_v1); // cannot reuse shuf_A0 (5 in MSB) + const VU shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const VU A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const VU A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const VU A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + StoreU(A, d_full, unaligned + 0 * kFullN); + + // Second (HALF) vector: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const VU shuf_B0 = shuf_A2 + k6; // ..7..6.. + const VU shuf_B1 = shuf_A0 + k5; // .7..6..5 + const VU shuf_B2 = shuf_A1 + k5; // 7..6..5. + const VU vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B{BitCast(d_full, vB0 | vB1 | vB2).raw}; + StoreU(B, d, unaligned + 1 * kFullN); +} + +// 64-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D dh, + TFromD* HWY_RESTRICT unaligned) { + const Twice d_full; + const Full128 du8; + using VU8 = VFromD; + const VU8 k2 = Set(du8, uint8_t{2 * sizeof(TFromD)}); + const VU8 k3 = Set(du8, uint8_t{3 * sizeof(TFromD)}); + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave part (v0,v1,v2) to full (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. 
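+  // For example, tbl_v1 below places v1's lane 0 (bytes 0-1) at output bytes
+  // 2-3, lane 1 at bytes 8-9 and lane 2 at bytes 14-15; the 0x80 entries zero
+  // every other output byte so the three lookups can simply be OR-ed.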
+ alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, + 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + + // The interleaved vectors will be named A, B; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU8 shuf_A1 = Load(du8, tbl_v1); // 2..1..0. + // .2..1..0 + const VU8 shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + const VU8 shuf_A2 = Load(du8, tbl_v2); // ..1..0.. + + const VU8 A0 = TableLookupBytesOr0(v0, shuf_A0); + const VU8 A1 = TableLookupBytesOr0(v1, shuf_A1); + const VU8 A2 = TableLookupBytesOr0(v2, shuf_A2); + const VFromD A = BitCast(d_full, A0 | A1 | A2); + StoreU(A, d_full, unaligned); + + // Second (HALF) vector: v2[3],v1[3],v0[3], v2[2] + const VU8 shuf_B0 = shuf_A1 + k3; // ..3. + const VU8 shuf_B1 = shuf_A2 + k3; // .3.. + const VU8 shuf_B2 = shuf_A0 + k2; // 3..2 + const VU8 vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU8 vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU8 vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B = BitCast(d_full, vB0 | vB1 | vB2); + StoreU(VFromD{B.raw}, dh, unaligned + MaxLanes(d_full)); +} + +// 64-bit vector, 32-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + // (same code as 128-bit vector, 64-bit lanes) + const VFromD v10_v00 = InterleaveLower(d, v0, v1); + const VFromD v01_v20 = OddEven(v0, v2); + const VFromD v21_v11 = InterleaveUpper(d, v1, v2); + constexpr size_t kN = MaxLanes(d); + StoreU(v10_v00, d, unaligned + 0 * kN); + StoreU(v01_v20, d, unaligned + 1 * kN); + StoreU(v21_v11, d, unaligned + 2 * kN); +} + +// 64-bit lanes are handled by the N=1 case below. + +// <= 32-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. + const Full128 du; + using VU = VFromD; + const Full128> d_full; + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, + 0x80, 3, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU shuf_A0 = Load(du, tbl_v0); + const VU shuf_A1 = CombineShiftRightBytes<15>(du, shuf_A0, shuf_A0); + const VU shuf_A2 = CombineShiftRightBytes<14>(du, shuf_A0, shuf_A0); + const VU A0 = TableLookupBytesOr0(v0, shuf_A0); // ......3..2..1..0 + const VU A1 = TableLookupBytesOr0(v1, shuf_A1); // .....3..2..1..0. + const VU A2 = TableLookupBytesOr0(v2, shuf_A2); // ....3..2..1..0.. + const VFromD A = BitCast(d_full, A0 | A1 | A2); + alignas(16) TFromD buf[MaxLanes(d_full)]; + StoreU(A, d_full, buf); + CopyBytes(buf, unaligned); +} + +// 32-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. 
+ const Full128 du8; + using VU8 = VFromD; + const Full128> d_full; + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU8 shuf_A2 = Load(du8, tbl_v2); // ..1..0.. + const VU8 shuf_A1 = + CombineShiftRightBytes<2>(du8, shuf_A2, shuf_A2); // ...1..0. + const VU8 shuf_A0 = + CombineShiftRightBytes<4>(du8, shuf_A2, shuf_A2); // ....1..0 + const VU8 A0 = TableLookupBytesOr0(v0, shuf_A0); // ..1..0 + const VU8 A1 = TableLookupBytesOr0(v1, shuf_A1); // .1..0. + const VU8 A2 = TableLookupBytesOr0(v2, shuf_A2); // 1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + alignas(16) TFromD buf[MaxLanes(d_full)]; + StoreU(A, d_full, buf); + CopyBytes(buf, unaligned); +} + +// Single-element vector, any lane size: just store directly +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +// ------------------------------ StoreInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void StoreTransposedBlocks4(VFromD vA, VFromD vB, VFromD vC, + VFromD vD, D d, + TFromD* HWY_RESTRICT unaligned) { + constexpr size_t kN = MaxLanes(d); + StoreU(vA, d, unaligned + 0 * kN); + StoreU(vB, d, unaligned + 1 * kN); + StoreU(vC, d, unaligned + 2 * kN); + StoreU(vD, d, unaligned + 3 * kN); +} + +} // namespace detail + +// >= 128-bit vector, 8..32-bit lanes +template +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, + TFromD* HWY_RESTRICT unaligned) { + const RepartitionToWide dw; + const auto v10L = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32L = ZipLower(dw, v2, v3); + const auto v10U = ZipUpper(dw, v0, v1); + const auto v32U = ZipUpper(dw, v2, v3); + // The interleaved vectors are vA, vB, vC, vD. + const VFromD vA = BitCast(d, InterleaveLower(dw, v10L, v32L)); // 3210 + const VFromD vB = BitCast(d, InterleaveUpper(dw, v10L, v32L)); + const VFromD vC = BitCast(d, InterleaveLower(dw, v10U, v32U)); + const VFromD vD = BitCast(d, InterleaveUpper(dw, v10U, v32U)); + detail::StoreTransposedBlocks4(vA, vB, vC, vD, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, + TFromD* HWY_RESTRICT unaligned) { + // The interleaved vectors are vA, vB, vC, vD. + const VFromD vA = InterleaveLower(d, v0, v1); // v1[0] v0[0] + const VFromD vB = InterleaveLower(d, v2, v3); + const VFromD vC = InterleaveUpper(d, v0, v1); + const VFromD vD = InterleaveUpper(d, v2, v3); + detail::StoreTransposedBlocks4(vA, vB, vC, vD, d, unaligned); +} + +// 64-bit vector, 8..32-bit lanes +template +HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, + VFromD part2, VFromD part3, D /* tag */, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. 
+ const Full128> d_full; + const RepartitionToWide dw; + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + const VFromD v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto A = BitCast(d_full, InterleaveLower(dw, v10, v32)); + const auto B = BitCast(d_full, InterleaveUpper(dw, v10, v32)); + StoreU(A, d_full, unaligned); + StoreU(B, d_full, unaligned + MaxLanes(d_full)); +} + +// 64-bit vector, 64-bit lane +template +HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, + VFromD part2, VFromD part3, D /* tag */, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128> d_full; + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + const VFromD v3{part3.raw}; + const auto A = InterleaveLower(d_full, v0, v1); // v1[0] v0[0] + const auto B = InterleaveLower(d_full, v2, v3); + StoreU(A, d_full, unaligned); + StoreU(B, d_full, unaligned + MaxLanes(d_full)); +} + +// <= 32-bit vectors +template +HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, + VFromD part2, VFromD part3, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128> d_full; + const RepartitionToWide dw; + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + const VFromD v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto v3210 = BitCast(d_full, InterleaveLower(dw, v10, v32)); + alignas(16) TFromD buf[MaxLanes(d_full)]; + StoreU(v3210, d_full, buf); + CopyBytes(buf, unaligned); +} + +#endif // HWY_NATIVE_LOAD_STORE_INTERLEAVED + +// ------------------------------ PairwiseAdd/PairwiseSub +#if (defined(HWY_NATIVE_PAIRWISE_ADD) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PAIRWISE_ADD +#undef HWY_NATIVE_PAIRWISE_ADD +#else +#define HWY_NATIVE_PAIRWISE_ADD +#endif + +template (), HWY_IF_LANES_GT_D(D, 1)> +HWY_API V PairwiseAdd(D d, V a, V b) { + return Add(InterleaveEven(d, a, b), InterleaveOdd(d, a, b)); +} + +#endif + +#if (defined(HWY_NATIVE_PAIRWISE_SUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PAIRWISE_SUB +#undef HWY_NATIVE_PAIRWISE_SUB +#else +#define HWY_NATIVE_PAIRWISE_SUB +#endif + +template (), HWY_IF_LANES_GT_D(D, 1)> +HWY_API V PairwiseSub(D d, V a, V b) { + return Sub(InterleaveOdd(d, a, b), InterleaveEven(d, a, b)); +} + +#endif + +// Load/StoreInterleaved for special floats. Requires HWY_GENERIC_IF_EMULATED_D +// is defined such that it is true only for types that actually require these +// generic implementations. 
+#if HWY_IDE || (defined(HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED) == \ + defined(HWY_TARGET_TOGGLE) && \ + defined(HWY_GENERIC_IF_EMULATED_D)) +#ifdef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#endif +#if HWY_IDE +#define HWY_GENERIC_IF_EMULATED_D(D) int +#endif + +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + const RebindToUnsigned du; + VFromD vu0, vu1; + LoadInterleaved2(du, detail::U16LanePointer(unaligned), vu0, vu1); + v0 = BitCast(d, vu0); + v1 = BitCast(d, vu1); +} + +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + VFromD vu0, vu1, vu2; + LoadInterleaved3(du, detail::U16LanePointer(unaligned), vu0, vu1, vu2); + v0 = BitCast(d, vu0); + v1 = BitCast(d, vu1); + v2 = BitCast(d, vu2); +} + +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + const RebindToUnsigned du; + VFromD vu0, vu1, vu2, vu3; + LoadInterleaved4(du, detail::U16LanePointer(unaligned), vu0, vu1, vu2, vu3); + v0 = BitCast(d, vu0); + v1 = BitCast(d, vu1); + v2 = BitCast(d, vu2); + v3 = BitCast(d, vu3); +} + +template > +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + StoreInterleaved2(BitCast(du, v0), BitCast(du, v1), du, + detail::U16LanePointer(unaligned)); +} + +template > +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + StoreInterleaved3(BitCast(du, v0), BitCast(du, v1), BitCast(du, v2), du, + detail::U16LanePointer(unaligned)); +} + +template > +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + StoreInterleaved4(BitCast(du, v0), BitCast(du, v1), BitCast(du, v2), + BitCast(du, v3), du, detail::U16LanePointer(unaligned)); +} + +#endif // HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED + +// ------------------------------ LoadN + +#if (defined(HWY_NATIVE_LOAD_N) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +#if HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE +namespace detail { + +template +HWY_INLINE VFromD LoadNResizeBitCast(DTo d_to, DFrom d_from, + VFromD v) { +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3/SSE4, the LoadU operation will zero out any lanes of v.raw + // past the first (lowest-index) Lanes(d_from) lanes of v.raw if + // sizeof(decltype(v.raw)) > d_from.MaxBytes() is true + (void)d_from; + return ResizeBitCast(d_to, v); +#else + // On other targets such as PPC/NEON, the contents of any lanes past the first + // (lowest-index) Lanes(d_from) lanes of v.raw might be non-zero if + // sizeof(decltype(v.raw)) > d_from.MaxBytes() is true. + return ZeroExtendResizeBitCast(d_to, d_from, v); +#endif +} + +} // namespace detail + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + return (num_lanes > 0) ? LoadU(d, p) : Zero(d); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + return (num_lanes > 0) ? 
LoadU(d, p) : no; +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 1> d1; + + if (num_lanes >= 2) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 1> d1; + + if (num_lanes >= 2) return LoadU(d, p); + if (num_lanes == 0) return no; + return InterleaveLower(ResizeBitCast(d, LoadU(d1, p)), no); +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 2> d2; + const Half d1; + + if (num_lanes >= 4) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + + // Two or three lanes. + const VFromD v_lo = detail::LoadNResizeBitCast(d, d2, LoadU(d2, p)); + return (num_lanes == 2) ? v_lo : InsertLane(v_lo, 2, p[2]); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 2> d2; + + if (num_lanes >= 4) return LoadU(d, p); + if (num_lanes == 0) return no; + if (num_lanes == 1) return InsertLane(no, 0, p[0]); + + // Two or three lanes. + const VFromD v_lo = + ConcatUpperLower(d, no, ResizeBitCast(d, LoadU(d2, p))); + return (num_lanes == 2) ? v_lo : InsertLane(v_lo, 2, p[2]); +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 4> d4; + const Half d2; + const Half d1; + + if (num_lanes >= 8) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + + const size_t leading_len = num_lanes & 4; + VFromD v_trailing = Zero(d4); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + detail::LoadNResizeBitCast(d2, d1, LoadU(d1, p + leading_len + 2)), + v_trailing_lo2); + } else { + v_trailing = detail::LoadNResizeBitCast(d4, d2, v_trailing_lo2); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = detail::LoadNResizeBitCast(d4, d1, LoadU(d1, p + leading_len)); + } + + if (leading_len != 0) { + return Combine(d, v_trailing, LoadU(d4, p)); + } else { + return detail::LoadNResizeBitCast(d, d4, v_trailing); + } +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 4> d4; + const Half d2; + const Half d1; + + if (num_lanes >= 8) return LoadU(d, p); + if (num_lanes == 0) return no; + if (num_lanes == 1) return InsertLane(no, 0, p[0]); + + const size_t leading_len = num_lanes & 4; + VFromD v_trailing = ResizeBitCast(d4, no); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + InterleaveLower(ResizeBitCast(d2, LoadU(d1, p + leading_len + 2)), + ResizeBitCast(d2, no)), + v_trailing_lo2); + } else { + v_trailing = ConcatUpperLower(d4, ResizeBitCast(d4, no), + ResizeBitCast(d4, v_trailing_lo2)); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = InsertLane(ResizeBitCast(d4, no), 0, p[leading_len]); + } + + if (leading_len != 0) { + return Combine(d, v_trailing, LoadU(d4, p)); + } else { + return ConcatUpperLower(d, no, ResizeBitCast(d, v_trailing)); + } +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const 
FixedTag, 8> d8; + const Half d4; + const Half d2; + const Half d1; + + if (num_lanes >= 16) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + + const size_t leading_len = num_lanes & 12; + VFromD v_trailing = Zero(d4); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + detail::LoadNResizeBitCast(d2, d1, LoadU(d1, p + leading_len + 2)), + v_trailing_lo2); + } else { + v_trailing = detail::LoadNResizeBitCast(d4, d2, v_trailing_lo2); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = detail::LoadNResizeBitCast(d4, d1, LoadU(d1, p + leading_len)); + } + + if (leading_len != 0) { + if (leading_len >= 8) { + const VFromD v_hi7 = + ((leading_len & 4) != 0) + ? Combine(d8, v_trailing, LoadU(d4, p + 8)) + : detail::LoadNResizeBitCast(d8, d4, v_trailing); + return Combine(d, v_hi7, LoadU(d8, p)); + } else { + return detail::LoadNResizeBitCast(d, d8, + Combine(d8, v_trailing, LoadU(d4, p))); + } + } else { + return detail::LoadNResizeBitCast(d, d4, v_trailing); + } +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 8> d8; + const Half d4; + const Half d2; + const Half d1; + + if (num_lanes >= 16) return LoadU(d, p); + if (num_lanes == 0) return no; + if (num_lanes == 1) return InsertLane(no, 0, p[0]); + + const size_t leading_len = num_lanes & 12; + VFromD v_trailing = ResizeBitCast(d4, no); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + InterleaveLower(ResizeBitCast(d2, LoadU(d1, p + leading_len + 2)), + ResizeBitCast(d2, no)), + v_trailing_lo2); + } else { + v_trailing = ConcatUpperLower(d4, ResizeBitCast(d4, no), + ResizeBitCast(d4, v_trailing_lo2)); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = InsertLane(ResizeBitCast(d4, no), 0, p[leading_len]); + } + + if (leading_len != 0) { + if (leading_len >= 8) { + const VFromD v_hi7 = + ((leading_len & 4) != 0) + ? Combine(d8, v_trailing, LoadU(d4, p + 8)) + : ConcatUpperLower(d8, ResizeBitCast(d8, no), + ResizeBitCast(d8, v_trailing)); + return Combine(d, v_hi7, LoadU(d8, p)); + } else { + return ConcatUpperLower( + d, ResizeBitCast(d, no), + ResizeBitCast(d, Combine(d8, v_trailing, LoadU(d4, p)))); + } + } else { + const Repartition du32; + // lowest 4 bytes from v_trailing, next 4 from no. 
+ const VFromD lo8 = + InterleaveLower(ResizeBitCast(du32, v_trailing), BitCast(du32, no)); + return ConcatUpperLower(d, ResizeBitCast(d, no), ResizeBitCast(d, lo8)); + } +} + +#if HWY_MAX_BYTES >= 32 + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if (num_lanes >= Lanes(d)) return LoadU(d, p); + + const Half dh; + const size_t half_N = Lanes(dh); + if (num_lanes <= half_N) { + return ZeroExtendVector(d, LoadN(dh, p, num_lanes)); + } else { + const VFromD v_lo = LoadU(dh, p); + const VFromD v_hi = LoadN(dh, p + half_N, num_lanes - half_N); + return Combine(d, v_hi, v_lo); + } +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if (num_lanes >= Lanes(d)) return LoadU(d, p); + + const Half dh; + const size_t half_N = Lanes(dh); + const VFromD no_h = LowerHalf(no); + if (num_lanes <= half_N) { + return ConcatUpperLower(d, no, + ResizeBitCast(d, LoadNOr(no_h, dh, p, num_lanes))); + } else { + const VFromD v_lo = LoadU(dh, p); + const VFromD v_hi = + LoadNOr(no_h, dh, p + half_N, num_lanes - half_N); + return Combine(d, v_hi, v_lo); + } +} + +#endif // HWY_MAX_BYTES >= 32 + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast(d, LoadN(du, detail::U16LanePointer(p), num_lanes)); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast( + d, LoadNOr(BitCast(du, no), du, detail::U16LanePointer(p), num_lanes)); +} + +#else // !HWY_MEM_OPS_MIGHT_FAULT || HWY_HAVE_SCALABLE + +// For SVE and non-sanitizer AVX-512; RVV has its own specialization. +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { +#if HWY_MEM_OPS_MIGHT_FAULT + if (num_lanes <= 0) return Zero(d); +#endif + + return MaskedLoad(FirstN(d, num_lanes), d, p); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { +#if HWY_MEM_OPS_MIGHT_FAULT + if (num_lanes <= 0) return no; +#endif + + return MaskedLoadOr(no, FirstN(d, num_lanes), d, p); +} + +#endif // HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE +#endif // HWY_NATIVE_LOAD_N + +// ------------------------------ StoreN +#if (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +#if HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE +namespace detail { + +template +HWY_INLINE VFromD StoreNGetUpperHalf(DH dh, VFromD> v) { + constexpr size_t kMinShrVectBytes = HWY_TARGET_IS_NEON ? 
8 : 16; + const FixedTag d_shift; + return ResizeBitCast( + dh, ShiftRightBytes(d_shift, ResizeBitCast(d_shift, v))); +} + +template +HWY_INLINE VFromD StoreNGetUpperHalf(DH dh, VFromD> v) { + return UpperHalf(dh, v); +} + +} // namespace detail + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store > 0) { + StoreU(v, d, p); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store > 1) { + StoreU(v, d, p); + } else if (max_lanes_to_store == 1) { + const FixedTag, 1> d1; + StoreU(LowerHalf(d1, v), d1, p); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const FixedTag, 2> d2; + const Half d1; + + if (max_lanes_to_store > 1) { + if (max_lanes_to_store >= 4) { + StoreU(v, d, p); + } else { + StoreU(ResizeBitCast(d2, v), d2, p); + if (max_lanes_to_store == 3) { + StoreU(ResizeBitCast(d1, detail::StoreNGetUpperHalf(d2, v)), d1, p + 2); + } + } + } else if (max_lanes_to_store == 1) { + StoreU(ResizeBitCast(d1, v), d1, p); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const FixedTag, 4> d4; + const Half d2; + const Half d1; + + if (max_lanes_to_store <= 1) { + if (max_lanes_to_store == 1) { + StoreU(ResizeBitCast(d1, v), d1, p); + } + } else if (max_lanes_to_store >= 8) { + StoreU(v, d, p); + } else if (max_lanes_to_store >= 4) { + StoreU(LowerHalf(d4, v), d4, p); + StoreN(detail::StoreNGetUpperHalf(d4, v), d4, p + 4, + max_lanes_to_store - 4); + } else { + StoreN(LowerHalf(d4, v), d4, p, max_lanes_to_store); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const FixedTag, 8> d8; + const Half d4; + const Half d2; + const Half d1; + + if (max_lanes_to_store <= 1) { + if (max_lanes_to_store == 1) { + StoreU(ResizeBitCast(d1, v), d1, p); + } + } else if (max_lanes_to_store >= 16) { + StoreU(v, d, p); + } else if (max_lanes_to_store >= 8) { + StoreU(LowerHalf(d8, v), d8, p); + StoreN(detail::StoreNGetUpperHalf(d8, v), d8, p + 8, + max_lanes_to_store - 8); + } else { + StoreN(LowerHalf(d8, v), d8, p, max_lanes_to_store); + } +} + +#if HWY_MAX_BYTES >= 32 +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const size_t N = Lanes(d); + if (max_lanes_to_store >= N) { + StoreU(v, d, p); + return; + } + + const Half dh; + const size_t half_N = Lanes(dh); + if (max_lanes_to_store <= half_N) { + StoreN(LowerHalf(dh, v), dh, p, max_lanes_to_store); + } else { + StoreU(LowerHalf(dh, v), dh, p); + StoreN(UpperHalf(dh, v), dh, p + half_N, max_lanes_to_store - half_N); + } +} +#endif // HWY_MAX_BYTES >= 32 + +#else // !HWY_MEM_OPS_MIGHT_FAULT || HWY_HAVE_SCALABLE +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const size_t N = Lanes(d); + const size_t clamped_max_lanes_to_store = HWY_MIN(max_lanes_to_store, N); +#if HWY_MEM_OPS_MIGHT_FAULT + if (clamped_max_lanes_to_store == 0) return; +#endif + + BlendedStore(v, FirstN(d, clamped_max_lanes_to_store), d, p); + + detail::MaybeUnpoison(p, clamped_max_lanes_to_store); +} +#endif // HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE + +#endif // (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ TruncateStore +#if (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE)) +#ifdef 
HWY_NATIVE_STORE_TRUNCATED +#undef HWY_NATIVE_STORE_TRUNCATED +#else +#define HWY_NATIVE_STORE_TRUNCATED +#endif + +template +HWY_API void TruncateStore(VFromD v, const D /*d*/, T* HWY_RESTRICT p) { + using DTo = Rebind; + DTo dsmall; + StoreU(TruncateTo(dsmall, v), dsmall, p); +} + +#endif // (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ Scatter + +#if (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SCATTER +#undef HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +template > +HWY_API void ScatterOffset(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> offset) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + HWY_ALIGN TI offset_lanes[MaxLanes(d)]; + Store(offset, di, offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < MaxLanes(d); ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template > +HWY_API void ScatterIndex(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + HWY_ALIGN TI index_lanes[MaxLanes(d)]; + Store(index, di, index_lanes); + + for (size_t i = 0; i < MaxLanes(d); ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +template > +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D d, + T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + HWY_ALIGN TI index_lanes[MaxLanes(d)]; + Store(index, di, index_lanes); + + HWY_ALIGN TI mask_lanes[MaxLanes(di)]; + Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); + + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask_lanes[i]) base[index_lanes[i]] = lanes[i]; + } +} + +template > +HWY_API void ScatterIndexN(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_store) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (i < max_lanes_to_store) base[ExtractLane(index, i)] = ExtractLane(v, i); + } +} +#else +template > +HWY_API void ScatterIndexN(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_store) { + MaskedScatterIndex(v, FirstN(d, max_lanes_to_store), d, base, index); +} +#endif // (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ Gather + +#if (defined(HWY_NATIVE_GATHER) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +template > +HWY_API VFromD GatherOffset(D d, const T* HWY_RESTRICT base, + VFromD> offset) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI offset_lanes[MaxLanes(d)]; + Store(offset, di, offset_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(offset_lanes[i] >= 0); + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template > +HWY_API 
VFromD GatherIndex(D d, const T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI index_lanes[MaxLanes(d)]; + Store(index, di, index_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(index_lanes[i] >= 0); + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +template > +HWY_API VFromD MaskedGatherIndex(MFromD m, D d, + const T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI index_lanes[MaxLanes(di)]; + Store(index, di, index_lanes); + + HWY_ALIGN TI mask_lanes[MaxLanes(di)]; + Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(index_lanes[i] >= 0); + lanes[i] = mask_lanes[i] ? base[index_lanes[i]] : T{0}; + } + return Load(d, lanes); +} + +template > +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D d, + const T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI index_lanes[MaxLanes(di)]; + Store(index, di, index_lanes); + + HWY_ALIGN TI mask_lanes[MaxLanes(di)]; + Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); + + HWY_ALIGN T no_lanes[MaxLanes(d)]; + Store(no, d, no_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(index_lanes[i] >= 0); + lanes[i] = mask_lanes[i] ? base[index_lanes[i]] : no_lanes[i]; + } + return Load(d, lanes); +} + +template > +HWY_API VFromD GatherIndexN(D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + return GatherIndexNOr(Zero(d), d, base, index, max_lanes_to_load); +} + +template > +HWY_API VFromD GatherIndexNOr(VFromD no, D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + VFromD v = no; + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (i < max_lanes_to_load) + v = InsertLane(v, i, base[ExtractLane(index, i)]); + } + return v; +} +#else +template > +HWY_API VFromD GatherIndexN(D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + return MaskedGatherIndex(FirstN(d, max_lanes_to_load), d, base, index); +} +template > +HWY_API VFromD GatherIndexNOr(VFromD no, D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + return MaskedGatherIndexOr(no, FirstN(d, max_lanes_to_load), d, base, index); +} +#endif // (defined(HWY_NATIVE_GATHER) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ Integer AbsDiff and SumsOf8AbsDiff + +#if (defined(HWY_NATIVE_INTEGER_ABS_DIFF) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INTEGER_ABS_DIFF +#undef HWY_NATIVE_INTEGER_ABS_DIFF +#else +#define HWY_NATIVE_INTEGER_ABS_DIFF +#endif + +template +HWY_API V AbsDiff(V a, V b) { + return Sub(Max(a, b), Min(a, b)); +} + +#endif // HWY_NATIVE_INTEGER_ABS_DIFF + +#if (defined(HWY_NATIVE_SUMS_OF_8_ABS_DIFF) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#endif + +template ), + HWY_IF_V_SIZE_GT_D(DFromV, 
(HWY_TARGET == HWY_SCALAR ? 0 : 4))> +HWY_API Vec>> SumsOf8AbsDiff(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWideX3 dw; + + return BitCast(dw, SumsOf8(BitCast(du, AbsDiff(a, b)))); +} + +#endif // HWY_NATIVE_SUMS_OF_8_ABS_DIFF + +// ------------------------------ SaturatedAdd/SaturatedSub for UI32/UI64 + +#if (defined(HWY_NATIVE_I32_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, diff); +} + +#endif // HWY_NATIVE_I32_SATURATED_ADDSUB + +#if (defined(HWY_NATIVE_I64_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, diff); +} + +#endif // HWY_NATIVE_I64_SATURATED_ADDSUB + +#if (defined(HWY_NATIVE_U32_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + return Add(a, Min(b, Not(a))); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + return Sub(a, Min(a, b)); +} + +#endif // HWY_NATIVE_U32_SATURATED_ADDSUB + +#if (defined(HWY_NATIVE_U64_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB +#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U64_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + return Add(a, Min(b, Not(a))); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + return Sub(a, Min(a, b)); +} + +#endif // HWY_NATIVE_U64_SATURATED_ADDSUB + +// ------------------------------ Unsigned to signed demotions + +template , DN>>, + hwy::EnableIf<(sizeof(TFromD) < sizeof(TFromV))>* = nullptr, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_D(DFromV))> +HWY_API VFromD DemoteTo(DN dn, V v) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned dn_u; + + // First, do a signed to signed demotion. This will convert any values + // that are greater than hwy::HighestValue>>() to a + // negative value. 
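// Illustrative standalone sketch of the branch-free signed saturation used by
// the I32/I64 SaturatedAdd fallbacks above, written for a single int32_t:
// overflow occurred iff the operands share a sign and the sum's sign differs,
// and the saturated result is LimitsMax xor-ed with the first operand's
// broadcast sign bit (i.e. INT32_MAX or INT32_MIN).
#include <cstdint>

int32_t SaturatedAddI32_Reference(int32_t a, int32_t b) {
  const uint32_t ua = static_cast<uint32_t>(a);
  const uint32_t ub = static_cast<uint32_t>(b);
  const uint32_t sum = ua + ub;                         // wrapping add
  const uint32_t overflow = ~(ua ^ ub) & (ua ^ sum);    // sign bit set on overflow
  const uint32_t saturated = (ua >> 31) + 0x7FFFFFFFu;  // MIN if a < 0, else MAX
  return static_cast<int32_t>((overflow & 0x80000000u) ? saturated : sum);
}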
+ const auto i2i_demote_result = DemoteTo(dn, BitCast(di, v)); + + // Second, convert any negative values to hwy::HighestValue>() + // using an unsigned Min operation. + const auto max_signed_val = Set(dn, hwy::HighestValue>()); + + return BitCast( + dn, Min(BitCast(dn_u, i2i_demote_result), BitCast(dn_u, max_signed_val))); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template , DN>>, + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_D(DFromV))> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned dn_u; + + // First, do a signed to signed demotion. This will convert any values + // that are greater than hwy::HighestValue>>() to a + // negative value. + const auto i2i_demote_result = + ReorderDemote2To(dn, BitCast(di, a), BitCast(di, b)); + + // Second, convert any negative values to hwy::HighestValue>() + // using an unsigned Min operation. + const auto max_signed_val = Set(dn, hwy::HighestValue>()); + + return BitCast( + dn, Min(BitCast(dn_u, i2i_demote_result), BitCast(dn_u, max_signed_val))); +} +#endif + +// ------------------------------ PromoteLowerTo + +// There is no codegen advantage for a native version of this. It is provided +// only for convenience. +template +HWY_API VFromD PromoteLowerTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const Rebind, decltype(d)> dh; + return PromoteTo(d, LowerHalf(dh, v)); +} + +// ------------------------------ PromoteUpperTo + +#if (defined(HWY_NATIVE_PROMOTE_UPPER_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// This requires UpperHalf. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +template +HWY_API VFromD PromoteUpperTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). 
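// Illustrative standalone sketch of the two-step unsigned-to-signed demotion
// above, for a single uint16_t -> int8_t lane: the signed saturating demote
// (step 1) turns inputs with the sign bit set into negative values, and an
// unsigned Min (step 2) clamps exactly those back to the signed maximum.
#include <algorithm>
#include <cstdint>

int8_t DemoteU16ToI8_Reference(uint16_t v) {
  // Step 1: I16 -> I8 saturating demote of the same bits.
  const int16_t as_i16 = static_cast<int16_t>(v);
  const int8_t i2i = static_cast<int8_t>(
      std::min<int16_t>(std::max<int16_t>(as_i16, -128), 127));
  // Step 2: unsigned Min with 0x7F fixes up inputs that came out negative.
  return static_cast<int8_t>(
      std::min<uint8_t>(static_cast<uint8_t>(i2i), uint8_t{0x7F}));
}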
+ const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +#endif // HWY_TARGET != HWY_SCALAR +#endif // HWY_NATIVE_PROMOTE_UPPER_TO + +// ------------------------------ float16_t <-> float + +#if (defined(HWY_NATIVE_F16C) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const RebindToSigned di32; + const RebindToUnsigned du32; + const Rebind du16; + using VU32 = VFromD; + + const VU32 bits16 = PromoteTo(du32, BitCast(du16, v)); + const VU32 sign = ShiftRight<15>(bits16); + const VU32 biased_exp = And(ShiftRight<10>(bits16), Set(du32, 0x1F)); + const VU32 mantissa = And(bits16, Set(du32, 0x3FF)); + const VU32 subnormal = + BitCast(du32, Mul(ConvertTo(df32, BitCast(di32, mantissa)), + Set(df32, 1.0f / 16384 / 1024))); + + const VU32 biased_exp32 = Add(biased_exp, Set(du32, 127 - 15)); + const VU32 mantissa32 = ShiftLeft<23 - 10>(mantissa); + const VU32 normal = Or(ShiftLeft<23>(biased_exp32), mantissa32); + const VU32 bits32 = IfThenElse(Eq(biased_exp, Zero(du32)), subnormal, normal); + return BitCast(df32, Or(ShiftLeft<31>(sign), bits32)); +} + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToSigned di16; + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + + // There are 23 fractional bits (plus the implied 1 bit) in the mantissa of + // a F32, and there are 10 fractional bits (plus the implied 1 bit) in the + // mantissa of a F16 + + // We want the unbiased exponent of round_incr[i] to be at least (-14) + 13 as + // 2^(-14) is the smallest positive normal F16 value and as we want 13 + // mantissa bits (including the implicit 1 bit) to the left of the + // F32 mantissa bits in rounded_val[i] since 23 - 10 is equal to 13 + + // The biased exponent of round_incr[i] needs to be at least 126 as + // (-14) + 13 + 127 is equal to 126 + + // We also want to biased exponent of round_incr[i] to be less than or equal + // to 255 (which is equal to MaxExponentField()) + + // The biased F32 exponent of round_incr is equal to + // HWY_MAX(HWY_MIN(((exp_bits[i] >> 23) & 255) + 13, 255), 126) + + // hi9_bits[i] is equal to the upper 9 bits of v[i] + const auto hi9_bits = ShiftRight<23>(BitCast(du32, v)); + + const auto k13 = Set(du32, uint32_t{13u}); + + // Minimum biased F32 exponent of round_incr + const auto k126 = Set(du32, uint32_t{126u}); + + // round_incr_hi9_bits[i] is equivalent to + // (hi9_bits[i] & 0x100) | + // HWY_MAX(HWY_MIN((hi9_bits[i] & 0xFF) + 13, 255), 126) + +#if HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128 + const auto k255 = Set(du32, uint32_t{255u}); + const auto round_incr_hi9_bits = BitwiseIfThenElse( + k255, Max(Min(Add(And(hi9_bits, k255), k13), k255), k126), hi9_bits); +#else + // On targets other than SCALAR and EMU128, the exponent bits of hi9_bits can + // be incremented by 13 and clamped to the [13, 255] range without overflowing + // into the sign bit of hi9_bits by using U8 SaturatedAdd as there are 8 + // exponent bits in an F32 + + // U8 Max can be used on targets other than SCALAR and EMU128 to clamp + // ((hi9_bits & 0xFF) + 13) to the [126, 255] range without affecting the sign + // bit + + const Repartition du32_as_u8; + const auto round_incr_hi9_bits = BitCast( + du32, + Max(SaturatedAdd(BitCast(du32_as_u8, hi9_bits), BitCast(du32_as_u8, k13)), + BitCast(du32_as_u8, k126))); +#endif + + // (round_incr_hi9_bits >> 8) is equal to (hi9_bits >> 8), and + // 
(round_incr_hi9_bits & 0xFF) is equal to + // HWY_MAX(HWY_MIN((round_incr_hi9_bits & 0xFF) + 13, 255), 126) + + const auto round_incr = BitCast(df32, ShiftLeft<23>(round_incr_hi9_bits)); + + // Add round_incr[i] to v[i] to round the mantissa to the nearest F16 mantissa + // and to move the fractional bits of the resulting non-NaN mantissa down to + // the lower 10 bits of rounded_val if (v[i] + round_incr[i]) is a non-NaN + // value + const auto rounded_val = Add(v, round_incr); + + // rounded_val_bits is the bits of rounded_val as a U32 + const auto rounded_val_bits = BitCast(du32, rounded_val); + + // rounded_val[i] is known to have the same biased exponent as round_incr[i] + // as |round_incr[i]| > 2^12*|v[i]| is true if round_incr[i] is a finite + // value, round_incr[i] and v[i] both have the same sign, and |round_incr[i]| + // is either a power of 2 that is greater than or equal to 2^-1 or infinity. + + // If rounded_val[i] is a finite F32 value, then + // (rounded_val_bits[i] & 0x00000FFF) is the bit representation of the + // rounded mantissa of rounded_val[i] as a UQ2.10 fixed point number that is + // in the range [0, 2]. + + // In other words, (rounded_val_bits[i] & 0x00000FFF) is between 0 and 0x0800, + // with (rounded_val_bits[i] & 0x000003FF) being the fractional bits of the + // resulting F16 mantissa, if rounded_v[i] is a finite F32 value. + + // (rounded_val_bits[i] & 0x007FF000) == 0 is guaranteed to be true if + // rounded_val[i] is a non-NaN value + + // The biased exponent of rounded_val[i] is guaranteed to be at least 126 as + // the biased exponent of round_incr[i] is at least 126 and as both v[i] and + // round_incr[i] have the same sign bit + + // The ULP of a F32 value with a biased exponent of 126 is equal to + // 2^(126 - 127 - 23), which is equal to 2^(-24) (which is also the ULP of a + // F16 value with a biased exponent of 0 or 1 as (1 - 15 - 10) is equal to + // -24) + + // The biased exponent (before subtracting by 126) needs to be clamped to the + // [126, 157] range as 126 + 31 is equal to 157 and as 31 is the largest + // biased exponent of a F16. 
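// Illustrative standalone sketch of the bit manipulation performed by the
// portable F16 -> F32 PromoteTo above: rebias the exponent by (127 - 15),
// widen the mantissa from 10 to 23 fraction bits, and scale subnormal inputs
// by 2^-24 via an ordinary float multiply. (Inf/NaN inputs, biased_exp == 31,
// are not special-cased in this sketch.)
#include <cstdint>
#include <cstring>

float F16BitsToF32_Reference(uint16_t bits16) {
  const uint32_t sign = bits16 >> 15;
  const uint32_t biased_exp = (bits16 >> 10) & 0x1F;
  const uint32_t mantissa = bits16 & 0x3FF;
  uint32_t bits32;
  if (biased_exp == 0) {  // zero or subnormal: value is mantissa * 2^-24
    const float subnormal = static_cast<float>(mantissa) * (1.0f / 16384 / 1024);
    std::memcpy(&bits32, &subnormal, sizeof(bits32));
  } else {  // normal value
    bits32 = ((biased_exp + 127 - 15) << 23) | (mantissa << (23 - 10));
  }
  bits32 |= sign << 31;
  float result;
  std::memcpy(&result, &bits32, sizeof(result));
  return result;
}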
+ + // The biased exponent of the resulting F16 value is equal to + // HWY_MIN((round_incr_hi9_bits[i] & 0xFF) + + // ((rounded_val_bits[i] >> 10) & 0xFF), 157) - 126 + +#if HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128 + const auto k157Shl10 = Set(du32, static_cast(uint32_t{157u} << 10)); + auto f16_exp_bits = + Min(Add(ShiftLeft<10>(And(round_incr_hi9_bits, k255)), + And(rounded_val_bits, + Set(du32, static_cast(uint32_t{0xFFu} << 10)))), + k157Shl10); + const auto f16_result_is_inf_mask = + RebindMask(df32, Eq(f16_exp_bits, k157Shl10)); +#else + const auto k157 = Set(du32, uint32_t{157}); + auto f16_exp_bits = BitCast( + du32, + Min(SaturatedAdd(BitCast(du32_as_u8, round_incr_hi9_bits), + BitCast(du32_as_u8, ShiftRight<10>(rounded_val_bits))), + BitCast(du32_as_u8, k157))); + const auto f16_result_is_inf_mask = RebindMask(df32, Eq(f16_exp_bits, k157)); + f16_exp_bits = ShiftLeft<10>(f16_exp_bits); +#endif + + f16_exp_bits = + Sub(f16_exp_bits, Set(du32, static_cast(uint32_t{126u} << 10))); + + const auto f16_unmasked_mant_bits = + BitCast(di32, Or(IfThenZeroElse(f16_result_is_inf_mask, rounded_val), + VecFromMask(df32, IsNaN(rounded_val)))); + + const auto f16_exp_mant_bits = + OrAnd(BitCast(di32, f16_exp_bits), f16_unmasked_mant_bits, + Set(di32, int32_t{0x03FF})); + + // f16_bits_as_i32 is the F16 bits sign-extended to an I32 (with the upper 17 + // bits of f16_bits_as_i32[i] set to the sign bit of rounded_val[i]) to allow + // efficient truncation of the F16 bits to an I16 using an I32->I16 DemoteTo + // operation + const auto f16_bits_as_i32 = + OrAnd(f16_exp_mant_bits, ShiftRight<16>(BitCast(di32, rounded_val_bits)), + Set(di32, static_cast(0xFFFF8000u))); + return BitCast(df16, DemoteTo(di16, f16_bits_as_i32)); +} + +#endif // HWY_NATIVE_F16C + +// ------------------------------ F64->F16 DemoteTo +#if (defined(HWY_NATIVE_DEMOTE_F64_TO_F16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const Rebind df64; + const Rebind du64; + const Rebind df32; + + // The mantissa bits of v[i] are first rounded using round-to-odd rounding to + // the nearest F64 value that has the lower 29 bits zeroed out to ensure that + // the result is correctly rounded to a F16. 
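// Illustrative standalone sketch of the round-to-odd step described above and
// implemented just below: the 29 low mantissa bits that cannot survive the
// later F64 -> F32 demotion are cleared, and their "sticky" OR is folded into
// the lowest kept bit so that the two-step F64 -> F32 -> F16 demotion cannot
// double-round.
#include <cstdint>
#include <cstring>

double RoundToOddForF16_Reference(double v) {
  uint64_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  const uint64_t low29 = bits & 0x1FFFFFFFu;
  bits &= 0xFFFFFFFFE0000000u;          // drop the 29 low mantissa bits
  if (low29 != 0) bits |= 0x20000000u;  // sticky bit forces the lowest kept bit on
  double rounded;
  std::memcpy(&rounded, &bits, sizeof(rounded));
  return rounded;
}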
+ + const auto vf64_rounded = OrAnd( + And(v, + BitCast(df64, Set(du64, static_cast(0xFFFFFFFFE0000000u)))), + BitCast(df64, Add(BitCast(du64, v), + Set(du64, static_cast(0x000000001FFFFFFFu)))), + BitCast(df64, Set(du64, static_cast(0x0000000020000000ULL)))); + + return DemoteTo(df16, DemoteTo(df32, vf64_rounded)); +} +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_NATIVE_DEMOTE_F64_TO_F16 + +// ------------------------------ F16->F64 PromoteTo +#if (defined(HWY_NATIVE_PROMOTE_F16_TO_F64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { + return PromoteTo(df64, PromoteTo(Rebind(), v)); +} +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_NATIVE_PROMOTE_F16_TO_F64 + +// ------------------------------ F32 to BF16 DemoteTo +#if (defined(HWY_NATIVE_DEMOTE_F32_TO_BF16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +namespace detail { + +// Round a F32 value to the nearest BF16 value, with the result returned as the +// rounded F32 value bitcasted to an U32 + +// RoundF32ForDemoteToBF16 also converts NaN values to QNaN values to prevent +// NaN F32 values from being converted to an infinity +template )> +HWY_INLINE VFromD>> RoundF32ForDemoteToBF16(V v) { + const DFromV d; + const RebindToUnsigned du32; + + const auto is_non_nan = Not(IsNaN(v)); + const auto bits32 = BitCast(du32, v); + + const auto round_incr = + Add(And(ShiftRight<16>(bits32), Set(du32, uint32_t{1})), + Set(du32, uint32_t{0x7FFFu})); + return MaskedAddOr(Or(bits32, Set(du32, uint32_t{0x00400000u})), + RebindMask(du32, is_non_nan), bits32, round_incr); +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D dbf16, VFromD> v) { + const RebindToUnsigned du16; + const Twice dt_u16; + + const auto rounded_bits = BitCast(dt_u16, detail::RoundF32ForDemoteToBF16(v)); +#if HWY_IS_LITTLE_ENDIAN + return BitCast( + dbf16, LowerHalf(du16, ConcatOdd(dt_u16, rounded_bits, rounded_bits))); +#else + return BitCast( + dbf16, LowerHalf(du16, ConcatEven(dt_u16, rounded_bits, rounded_bits))); +#endif +} + +template +HWY_API VFromD OrderedDemote2To(D dbf16, VFromD> a, + VFromD> b) { + const RebindToUnsigned du16; + + const auto rounded_a_bits32 = + BitCast(du16, detail::RoundF32ForDemoteToBF16(a)); + const auto rounded_b_bits32 = + BitCast(du16, detail::RoundF32ForDemoteToBF16(b)); +#if HWY_IS_LITTLE_ENDIAN + return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, rounded_b_bits32), + BitCast(du16, rounded_a_bits32))); +#else + return BitCast(dbf16, ConcatEven(du16, BitCast(du16, rounded_b_bits32), + BitCast(du16, rounded_a_bits32))); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D dbf16, VFromD> a, + VFromD> b) { + const RebindToUnsigned du16; + +#if HWY_IS_LITTLE_ENDIAN + const auto a_in_odd = detail::RoundF32ForDemoteToBF16(a); + const auto b_in_even = ShiftRight<16>(detail::RoundF32ForDemoteToBF16(b)); +#else + const auto a_in_odd = ShiftRight<16>(detail::RoundF32ForDemoteToBF16(a)); + const auto b_in_even = detail::RoundF32ForDemoteToBF16(b); +#endif + + return BitCast(dbf16, + OddEven(BitCast(du16, a_in_odd), BitCast(du16, b_in_even))); +} + +#endif // HWY_NATIVE_DEMOTE_F32_TO_BF16 + +// ------------------------------ PromoteInRangeTo +#if (defined(HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO) == \ + defined(HWY_TARGET_TOGGLE)) 
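// Illustrative standalone sketch of the F32 -> BF16 rounding performed by
// RoundF32ForDemoteToBF16 above, for a single lane: add 0x7FFF plus the lowest
// kept bit (round to nearest, ties to even), except for NaN, which is quieted
// instead so that rounding cannot turn it into an infinity.
#include <cmath>
#include <cstdint>
#include <cstring>

uint16_t F32ToBF16Bits_Reference(float v) {
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  if (std::isnan(v)) {
    bits |= 0x00400000u;  // set the quiet bit; keep the upper 16 bits otherwise
  } else {
    bits += ((bits >> 16) & 1u) + 0x7FFFu;  // round to nearest, ties to even
  }
  return static_cast<uint16_t>(bits >> 16);  // BF16 is the upper half of an F32
}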
+#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +#if HWY_HAVE_INTEGER64 +template +HWY_API VFromD PromoteInRangeTo(D64 d64, VFromD> v) { + return PromoteTo(d64, v); +} +#endif + +#endif // HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO + +// ------------------------------ ConvertInRangeTo +#if (defined(HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + +template +HWY_API VFromD ConvertInRangeTo(DI di, VFromD> v) { + return ConvertTo(di, v); +} + +#endif // HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO + +// ------------------------------ DemoteInRangeTo +#if (defined(HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD DemoteInRangeTo(D32 d32, VFromD> v) { + return DemoteTo(d32, v); +} +#endif + +#endif // HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO + +// ------------------------------ PromoteInRangeLowerTo/PromoteInRangeUpperTo + +template )> +HWY_API VFromD PromoteInRangeLowerTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, LowerHalf(dh, v)); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API VFromD PromoteInRangeUpperTo(D d, V v) { +#if (HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_EMU128 || \ + (HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64)) + // On targets that provide target-specific implementations of F32->UI64 + // PromoteInRangeTo, promote the upper half of v using PromoteInRangeTo + + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, UpperHalf(dh, v)); +#else + // Otherwise, on targets where F32->UI64 PromoteInRangeTo is simply a wrapper + // around F32->UI64 PromoteTo, promote the upper half of v to TFromD using + // PromoteUpperTo + return PromoteUpperTo(d, v); +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ PromoteInRangeEvenTo/PromoteInRangeOddTo + +template )> +HWY_API VFromD PromoteInRangeEvenTo(D d, V v) { +#if HWY_TARGET == HWY_SCALAR + return PromoteInRangeTo(d, v); +#elif (HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_EMU128 || \ + (HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64)) + // On targets that provide target-specific implementations of F32->UI64 + // PromoteInRangeTo, promote the even lanes of v using PromoteInRangeTo + + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). 
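// Illustrative standalone sketch of the generic SumsOf2 fallback defined
// further below (Add(PromoteEvenTo(...), PromoteOddTo(...))): adjacent lane
// pairs are widened and summed, e.g. u8 lanes produce u16 sums.
#include <cstddef>
#include <cstdint>

void SumsOf2_Reference(const uint8_t* v, uint16_t* sums, size_t num_lanes) {
  for (size_t i = 0; i + 1 < num_lanes; i += 2) {
    sums[i / 2] = static_cast<uint16_t>(v[i]) + static_cast<uint16_t>(v[i + 1]);
  }
}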
+ const DFromV d_from; + const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, LowerHalf(dh, ConcatEven(d_from, v, v))); +#else + // Otherwise, on targets where F32->UI64 PromoteInRangeTo is simply a wrapper + // around F32->UI64 PromoteTo, promote the even lanes of v to TFromD using + // PromoteEvenTo + return PromoteEvenTo(d, v); +#endif // HWY_TARGET == HWY_SCALAR +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API VFromD PromoteInRangeOddTo(D d, V v) { +#if (HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_EMU128 || \ + (HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64)) + // On targets that provide target-specific implementations of F32->UI64 + // PromoteInRangeTo, promote the odd lanes of v using PromoteInRangeTo + + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const DFromV d_from; + const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, LowerHalf(dh, ConcatOdd(d_from, v, v))); +#else + // Otherwise, on targets where F32->UI64 PromoteInRangeTo is simply a wrapper + // around F32->UI64 PromoteTo, promote the odd lanes of v to TFromD using + // PromoteOddTo + return PromoteOddTo(d, v); +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ SumsOf2 + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +namespace detail { + +template +HWY_INLINE VFromD>> SumsOf2( + TypeTag /*type_tag*/, hwy::SizeTag /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + return Add(PromoteEvenTo(dw, v), PromoteOddTo(dw, v)); +} + +} // namespace detail + +template +HWY_API VFromD>> SumsOf2(V v) { + return detail::SumsOf2(hwy::TypeTag>(), + hwy::SizeTag)>(), v); +} +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ SumsOf4 + +namespace detail { + +template +HWY_INLINE VFromD>> SumsOf4( + TypeTag /*type_tag*/, hwy::SizeTag /*lane_size_tag*/, V v) { + using hwy::HWY_NAMESPACE::SumsOf2; + return SumsOf2(SumsOf2(v)); +} + +} // namespace detail + +template +HWY_API VFromD>> SumsOf4(V v) { + return detail::SumsOf4(hwy::TypeTag>(), + hwy::SizeTag)>(), v); +} + +// ------------------------------ OrderedTruncate2To + +#if HWY_IDE || \ + (defined(HWY_NATIVE_ORDERED_TRUNCATE_2_TO) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#undef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#else +#define HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#endif + +// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template ) * 2), + HWY_IF_LANES_D(DFromV>, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedTruncate2To(DN dn, V a, V b) { + return ConcatEven(dn, BitCast(dn, b), BitCast(dn, a)); +} +#endif // HWY_TARGET != HWY_SCALAR +#endif // HWY_NATIVE_ORDERED_TRUNCATE_2_TO + +// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex + +#if (defined(HWY_NATIVE_LEADING_ZERO_COUNT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +namespace detail { + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const RebindToFloat df; +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE2 + const RebindToSigned di; + const Repartition di16; + + // On SSE2/SSSE3/SSE4/AVX2, do an int32_t to float conversion, followed + // by a unsigned right shift of the uint32_t bit representation of the + // floating point values by 23, followed by an int16_t Min + // 
operation as we are only interested in the biased exponent that would + // result from a uint32_t to float conversion. + + // An int32_t to float vector conversion is also much more efficient on + // SSE2/SSSE3/SSE4/AVX2 than an uint32_t vector to float vector conversion + // as an uint32_t vector to float vector conversion on SSE2/SSSE3/SSE4/AVX2 + // requires multiple instructions whereas an int32_t to float vector + // conversion can be carried out using a single instruction on + // SSE2/SSSE3/SSE4/AVX2. + + const auto f32_bits = BitCast(d, ConvertTo(df, BitCast(di, v))); + return BitCast(d, Min(BitCast(di16, ShiftRight<23>(f32_bits)), + BitCast(di16, Set(d, 158)))); +#else + const auto f32_bits = BitCast(d, ConvertTo(df, v)); + return BitCast(d, ShiftRight<23>(f32_bits)); +#endif +} + +template )> +HWY_INLINE V I32RangeU32ToF32BiasedExp(V v) { + // I32RangeU32ToF32BiasedExp is similar to UIntToF32BiasedExp, but + // I32RangeU32ToF32BiasedExp assumes that v[i] is between 0 and 2147483647. + const DFromV d; + const RebindToFloat df; +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE2 + const RebindToSigned d_src; +#else + const RebindToUnsigned d_src; +#endif + const auto f32_bits = BitCast(d, ConvertTo(df, BitCast(d_src, v))); + return ShiftRight<23>(f32_bits); +} + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Rebind du32; + const auto f32_biased_exp_as_u32 = + I32RangeU32ToF32BiasedExp(PromoteTo(du32, v)); + return TruncateTo(d, f32_biased_exp_as_u32); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Half dh; + const Rebind du32; + + const auto lo_u32 = PromoteTo(du32, LowerHalf(dh, v)); + const auto hi_u32 = PromoteTo(du32, UpperHalf(dh, v)); + + const auto lo_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(lo_u32); + const auto hi_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(hi_u32); +#if HWY_TARGET <= HWY_SSE2 + const RebindToSigned di32; + const RebindToSigned di; + return BitCast(d, + OrderedDemote2To(di, BitCast(di32, lo_f32_biased_exp_as_u32), + BitCast(di32, hi_f32_biased_exp_as_u32))); +#else + return OrderedTruncate2To(d, lo_f32_biased_exp_as_u32, + hi_f32_biased_exp_as_u32); +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Rebind du32; + const auto f32_biased_exp_as_u32 = + I32RangeU32ToF32BiasedExp(PromoteTo(du32, v)); + return U8FromU32(f32_biased_exp_as_u32); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Half dh; + const Rebind du32; + const Repartition du16; + + const auto lo_u32 = PromoteTo(du32, LowerHalf(dh, v)); + const auto hi_u32 = PromoteTo(du32, UpperHalf(dh, v)); + + const auto lo_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(lo_u32); + const auto hi_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(hi_u32); + +#if HWY_TARGET <= HWY_SSE2 + const RebindToSigned di32; + const RebindToSigned di16; + const auto f32_biased_exp_as_i16 = + OrderedDemote2To(di16, BitCast(di32, lo_f32_biased_exp_as_u32), + BitCast(di32, hi_f32_biased_exp_as_u32)); + return DemoteTo(d, f32_biased_exp_as_i16); +#else + const auto f32_biased_exp_as_u16 = OrderedTruncate2To( + du16, lo_f32_biased_exp_as_u32, hi_f32_biased_exp_as_u32); + return TruncateTo(d, f32_biased_exp_as_u16); +#endif +} + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Half dh; + const Half dq; + const Rebind du32; + const Repartition du16; + + 
const auto lo_half = LowerHalf(dh, v); + const auto hi_half = UpperHalf(dh, v); + + const auto u32_q0 = PromoteTo(du32, LowerHalf(dq, lo_half)); + const auto u32_q1 = PromoteTo(du32, UpperHalf(dq, lo_half)); + const auto u32_q2 = PromoteTo(du32, LowerHalf(dq, hi_half)); + const auto u32_q3 = PromoteTo(du32, UpperHalf(dq, hi_half)); + + const auto f32_biased_exp_as_u32_q0 = I32RangeU32ToF32BiasedExp(u32_q0); + const auto f32_biased_exp_as_u32_q1 = I32RangeU32ToF32BiasedExp(u32_q1); + const auto f32_biased_exp_as_u32_q2 = I32RangeU32ToF32BiasedExp(u32_q2); + const auto f32_biased_exp_as_u32_q3 = I32RangeU32ToF32BiasedExp(u32_q3); + +#if HWY_TARGET <= HWY_SSE2 + const RebindToSigned di32; + const RebindToSigned di16; + + const auto lo_f32_biased_exp_as_i16 = + OrderedDemote2To(di16, BitCast(di32, f32_biased_exp_as_u32_q0), + BitCast(di32, f32_biased_exp_as_u32_q1)); + const auto hi_f32_biased_exp_as_i16 = + OrderedDemote2To(di16, BitCast(di32, f32_biased_exp_as_u32_q2), + BitCast(di32, f32_biased_exp_as_u32_q3)); + return OrderedDemote2To(d, lo_f32_biased_exp_as_i16, + hi_f32_biased_exp_as_i16); +#else + const auto lo_f32_biased_exp_as_u16 = OrderedTruncate2To( + du16, f32_biased_exp_as_u32_q0, f32_biased_exp_as_u32_q1); + const auto hi_f32_biased_exp_as_u16 = OrderedTruncate2To( + du16, f32_biased_exp_as_u32_q2, f32_biased_exp_as_u32_q3); + return OrderedTruncate2To(d, lo_f32_biased_exp_as_u16, + hi_f32_biased_exp_as_u16); +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +#if HWY_TARGET == HWY_SCALAR +template +using F32ExpLzcntMinMaxRepartition = RebindToUnsigned; +#elif HWY_TARGET >= HWY_SSSE3 && HWY_TARGET <= HWY_SSE2 +template +using F32ExpLzcntMinMaxRepartition = Repartition; +#else +template +using F32ExpLzcntMinMaxRepartition = + Repartition), 4)>, D>; +#endif + +template +using F32ExpLzcntMinMaxCmpV = VFromD>>; + +template +HWY_INLINE F32ExpLzcntMinMaxCmpV F32ExpLzcntMinMaxBitCast(V v) { + const DFromV d; + const F32ExpLzcntMinMaxRepartition d2; + return BitCast(d2, v); +} + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { +#if HWY_TARGET == HWY_SCALAR + const uint64_t u64_val = GetLane(v); + const float f32_val = static_cast(u64_val); + const uint32_t f32_bits = BitCastScalar(f32_val); + return Set(d, static_cast(f32_bits >> 23)); +#else + const Repartition du32; + const auto f32_biased_exp = UIntToF32BiasedExp(du32, BitCast(du32, v)); + const auto f32_biased_exp_adj = + IfThenZeroElse(Eq(f32_biased_exp, Zero(du32)), + BitCast(du32, Set(d, 0x0000002000000000u))); + const auto adj_f32_biased_exp = Add(f32_biased_exp, f32_biased_exp_adj); + + return ShiftRight<32>(BitCast( + d, Max(F32ExpLzcntMinMaxBitCast(adj_f32_biased_exp), + F32ExpLzcntMinMaxBitCast(Reverse2(du32, adj_f32_biased_exp))))); +#endif +} + +template +HWY_INLINE V UIntToF32BiasedExp(V v) { + const DFromV d; + return UIntToF32BiasedExp(d, v); +} + +template +HWY_INLINE V NormalizeForUIntTruncConvToF32(V v) { + return v; +} + +template +HWY_INLINE V NormalizeForUIntTruncConvToF32(V v) { + // If v[i] >= 16777216 is true, make sure that the bit at + // HighestSetBitIndex(v[i]) - 24 is zeroed out to ensure that any inexact + // conversion to single-precision floating point is rounded down. + + // This zeroing-out can be accomplished through the AndNot operation below. 
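// Illustrative standalone sketch of the LeadingZeroCount strategy used
// throughout this section: convert to float, read off the biased exponent
// (127 + HighestSetBitIndex), and subtract. The AndNot normalization mentioned
// above keeps an inexact u32 -> f32 conversion from rounding up into the next
// exponent for inputs with more than 24 significant bits.
#include <cstdint>
#include <cstring>

uint32_t LeadingZeroCount32_Reference(uint32_t v) {
  if (v == 0) return 32;
  v &= ~(v >> 24);  // same effect as AndNot(ShiftRight<24>(v), v) below
  const float f = static_cast<float>(v);
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t biased_exp = bits >> 23;  // 127 + index of the highest set bit
  return 158u - biased_exp;                // 31 - (biased_exp - 127)
}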
+ return AndNot(ShiftRight<24>(v), v); +} + +} // namespace detail + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + + const auto f32_biased_exp = detail::UIntToF32BiasedExp( + detail::NormalizeForUIntTruncConvToF32(BitCast(du, v))); + return BitCast(d, Sub(f32_biased_exp, Set(du, TU{127}))); +} + +template +HWY_API V LeadingZeroCount(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + + constexpr TU kNumOfBitsInT{sizeof(TU) * 8}; + const auto f32_biased_exp = detail::UIntToF32BiasedExp( + detail::NormalizeForUIntTruncConvToF32(BitCast(du, v))); + const auto lz_count = Sub(Set(du, TU{kNumOfBitsInT + 126}), f32_biased_exp); + + return BitCast(d, + Min(detail::F32ExpLzcntMinMaxBitCast(lz_count), + detail::F32ExpLzcntMinMaxBitCast(Set(du, kNumOfBitsInT)))); +} + +template +HWY_API V TrailingZeroCount(V v) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + using TU = TFromD; + + const auto vi = BitCast(di, v); + const auto lowest_bit = BitCast(du, And(vi, Neg(vi))); + + constexpr TU kNumOfBitsInT{sizeof(TU) * 8}; + const auto f32_biased_exp = detail::UIntToF32BiasedExp(lowest_bit); + const auto tz_count = Sub(f32_biased_exp, Set(du, TU{127})); + + return BitCast(d, + Min(detail::F32ExpLzcntMinMaxBitCast(tz_count), + detail::F32ExpLzcntMinMaxBitCast(Set(du, kNumOfBitsInT)))); +} +#endif // HWY_NATIVE_LEADING_ZERO_COUNT + +// ------------------------------ MaskedLeadingZeroCount +#if (defined(HWY_NATIVE_MASKED_LEADING_ZERO_COUNT) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#endif + +template +HWY_API V MaskedLeadingZeroCount(M m, V v) { + return IfThenElseZero(m, LeadingZeroCount(v)); +} +#endif // HWY_NATIVE_MASKED_LEADING_ZERO_COUNT + +// ------------------------------ AESRound + +// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +// Define for white-box testing, even if native instructions are available. +namespace detail { + +// Constant-time: computes inverse in GF(2^4) based on "Accelerating AES with +// Vector Permute Instructions" and the accompanying assembly language +// implementation: https://crypto.stanford.edu/vpaes/vpaes.tgz. See also Botan: +// https://botan.randombit.net/doxygen/aes__vperm_8cpp_source.html . +// +// A brute-force 256 byte table lookup can also be made constant-time, and +// possibly competitive on NEON, but this is more performance-portable +// especially for x86 and large vectors. + +template // u8 +HWY_INLINE V SubBytesMulInverseAndAffineLookup(V state, V affine_tblL, + V affine_tblU) { + const DFromV du; + const auto mask = Set(du, uint8_t{0xF}); + + // Change polynomial basis to GF(2^4) + { + const VFromD basisL = + Dup128VecFromValues(du, 0x00, 0x70, 0x2A, 0x5A, 0x98, 0xE8, 0xB2, 0xC2, + 0x08, 0x78, 0x22, 0x52, 0x90, 0xE0, 0xBA, 0xCA); + const VFromD basisU = + Dup128VecFromValues(du, 0x00, 0x4D, 0x7C, 0x31, 0x7D, 0x30, 0x01, 0x4C, + 0x81, 0xCC, 0xFD, 0xB0, 0xFC, 0xB1, 0x80, 0xCD); + const auto sL = And(state, mask); + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto gf4L = TableLookupBytes(basisL, sL); + const auto gf4U = TableLookupBytes(basisU, sU); + state = Xor(gf4L, gf4U); + } + + // Inversion in GF(2^4). 
Elements 0 represent "infinity" (division by 0) and + // cause TableLookupBytesOr0 to return 0. + const VFromD zetaInv = Dup128VecFromValues( + du, 0x80, 7, 11, 15, 6, 10, 4, 1, 9, 8, 5, 2, 12, 14, 13, 3); + const VFromD tbl = Dup128VecFromValues( + du, 0x80, 1, 8, 13, 15, 6, 5, 14, 2, 12, 11, 10, 9, 3, 7, 4); + const auto sL = And(state, mask); // L=low nibble, U=upper + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto sX = Xor(sU, sL); + const auto invL = TableLookupBytes(zetaInv, sL); + const auto invU = TableLookupBytes(tbl, sU); + const auto invX = TableLookupBytes(tbl, sX); + const auto outL = Xor(sX, TableLookupBytesOr0(tbl, Xor(invL, invU))); + const auto outU = Xor(sU, TableLookupBytesOr0(tbl, Xor(invL, invX))); + + const auto affL = TableLookupBytesOr0(affine_tblL, outL); + const auto affU = TableLookupBytesOr0(affine_tblU, outU); + return Xor(affL, affU); +} + +template // u8 +HWY_INLINE V SubBytes(V state) { + const DFromV du; + // Linear skew (cannot bake 0x63 bias into the table because out* indices + // may have the infinity flag set). + const VFromD affineL = + Dup128VecFromValues(du, 0x00, 0xC7, 0xBD, 0x6F, 0x17, 0x6D, 0xD2, 0xD0, + 0x78, 0xA8, 0x02, 0xC5, 0x7A, 0xBF, 0xAA, 0x15); + const VFromD affineU = + Dup128VecFromValues(du, 0x00, 0x6A, 0xBB, 0x5F, 0xA5, 0x74, 0xE4, 0xCF, + 0xFA, 0x35, 0x2B, 0x41, 0xD1, 0x90, 0x1E, 0x8E); + return Xor(SubBytesMulInverseAndAffineLookup(state, affineL, affineU), + Set(du, uint8_t{0x63})); +} + +template // u8 +HWY_INLINE V InvSubBytes(V state) { + const DFromV du; + const VFromD gF2P4InvToGF2P8InvL = + Dup128VecFromValues(du, 0x00, 0x40, 0xF9, 0x7E, 0x53, 0xEA, 0x87, 0x13, + 0x2D, 0x3E, 0x94, 0xD4, 0xB9, 0x6D, 0xAA, 0xC7); + const VFromD gF2P4InvToGF2P8InvU = + Dup128VecFromValues(du, 0x00, 0x1D, 0x44, 0x93, 0x0F, 0x56, 0xD7, 0x12, + 0x9C, 0x8E, 0xC5, 0xD8, 0x59, 0x81, 0x4B, 0xCA); + + // Apply the inverse affine transformation + const auto b = Xor(Xor3(Or(ShiftLeft<1>(state), ShiftRight<7>(state)), + Or(ShiftLeft<3>(state), ShiftRight<5>(state)), + Or(ShiftLeft<6>(state), ShiftRight<2>(state))), + Set(du, uint8_t{0x05})); + + // The GF(2^8) multiplicative inverse is computed as follows: + // - Changing the polynomial basis to GF(2^4) + // - Computing the GF(2^4) multiplicative inverse + // - Converting the GF(2^4) multiplicative inverse to the GF(2^8) + // multiplicative inverse through table lookups using the + // kGF2P4InvToGF2P8InvL and kGF2P4InvToGF2P8InvU tables + return SubBytesMulInverseAndAffineLookup(b, gF2P4InvToGF2P8InvL, + gF2P4InvToGF2P8InvU); +} + +} // namespace detail + +#endif // HWY_TARGET != HWY_SCALAR + +#if (defined(HWY_NATIVE_AES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +namespace detail { + +template // u8 +HWY_INLINE V ShiftRows(const V state) { + const DFromV du; + // transposed: state is column major + const VFromD shift_row = Dup128VecFromValues( + du, 0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11); + return TableLookupBytes(state, shift_row); +} + +template // u8 +HWY_INLINE V InvShiftRows(const V state) { + const DFromV du; + // transposed: state is column major + const VFromD shift_row = Dup128VecFromValues( + du, 0, 13, 10, 7, 4, 1, 14, 11, 8, 5, 2, 15, 12, 9, 6, 3); + return TableLookupBytes(state, shift_row); +} + +template // u8 +HWY_INLINE V GF2P8Mod11BMulBy2(V 
v) { + const DFromV du; + const RebindToSigned di; // can only do signed comparisons + const auto msb = Lt(BitCast(di, v), Zero(di)); + const auto overflow = BitCast(du, IfThenElseZero(msb, Set(di, int8_t{0x1B}))); + return Xor(Add(v, v), overflow); // = v*2 in GF(2^8). +} + +template // u8 +HWY_INLINE V MixColumns(const V state) { + const DFromV du; + // For each column, the rows are the sum of GF(2^8) matrix multiplication by: + // 2 3 1 1 // Let s := state*1, d := state*2, t := state*3. + // 1 2 3 1 // d are on diagonal, no permutation needed. + // 1 1 2 3 // t1230 indicates column indices of threes for the 4 rows. + // 3 1 1 2 // We also need to compute s2301 and s3012 (=1230 o 2301). + const VFromD v2301 = Dup128VecFromValues( + du, 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13); + const VFromD v1230 = Dup128VecFromValues( + du, 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12); + const auto d = GF2P8Mod11BMulBy2(state); // = state*2 in GF(2^8). + const auto s2301 = TableLookupBytes(state, v2301); + const auto d_s2301 = Xor(d, s2301); + const auto t_s2301 = Xor(state, d_s2301); // t(s*3) = XOR-sum {s, d(s*2)} + const auto t1230_s3012 = TableLookupBytes(t_s2301, v1230); + return Xor(d_s2301, t1230_s3012); // XOR-sum of 4 terms +} + +template // u8 +HWY_INLINE V InvMixColumns(const V state) { + const DFromV du; + // For each column, the rows are the sum of GF(2^8) matrix multiplication by: + // 14 11 13 9 + // 9 14 11 13 + // 13 9 14 11 + // 11 13 9 14 + const VFromD v2301 = Dup128VecFromValues( + du, 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13); + const VFromD v1230 = Dup128VecFromValues( + du, 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12); + + const auto sx2 = GF2P8Mod11BMulBy2(state); /* = state*2 in GF(2^8) */ + const auto sx4 = GF2P8Mod11BMulBy2(sx2); /* = state*4 in GF(2^8) */ + const auto sx8 = GF2P8Mod11BMulBy2(sx4); /* = state*8 in GF(2^8) */ + const auto sx9 = Xor(sx8, state); /* = state*9 in GF(2^8) */ + const auto sx11 = Xor(sx9, sx2); /* = state*11 in GF(2^8) */ + const auto sx13 = Xor(sx9, sx4); /* = state*13 in GF(2^8) */ + const auto sx14 = Xor3(sx8, sx4, sx2); /* = state*14 in GF(2^8) */ + + const auto sx13_0123_sx9_1230 = Xor(sx13, TableLookupBytes(sx9, v1230)); + const auto sx14_0123_sx11_1230 = Xor(sx14, TableLookupBytes(sx11, v1230)); + const auto sx13_2301_sx9_3012 = TableLookupBytes(sx13_0123_sx9_1230, v2301); + return Xor(sx14_0123_sx11_1230, sx13_2301_sx9_3012); +} + +} // namespace detail + +template // u8 +HWY_API V AESRound(V state, const V round_key) { + // Intel docs swap the first two steps, but it does not matter because + // ShiftRows is a permutation and SubBytes is independent of lane index. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = detail::MixColumns(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template // u8 +HWY_API V AESLastRound(V state, const V round_key) { + // LIke AESRound, but without MixColumns. 
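// Illustrative standalone sketch of GF2P8Mod11BMulBy2 above ("xtime" in AES
// terms), for a single byte: multiply by 2 in GF(2^8) modulo the AES
// polynomial 0x11B, i.e. shift left and xor 0x1B when the top bit overflows.
#include <cstdint>

uint8_t GF2P8MulBy2_Reference(uint8_t v) {
  const uint8_t overflow = (v & 0x80u) ? uint8_t{0x1B} : uint8_t{0x00};
  return static_cast<uint8_t>((v << 1) ^ overflow);
}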
+ state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template +HWY_API V AESInvMixColumns(V state) { + return detail::InvMixColumns(state); +} + +template // u8 +HWY_API V AESRoundInv(V state, const V round_key) { + state = detail::InvSubBytes(state); + state = detail::InvShiftRows(state); + state = detail::InvMixColumns(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template // u8 +HWY_API V AESLastRoundInv(V state, const V round_key) { + // Like AESRoundInv, but without InvMixColumns. + state = detail::InvSubBytes(state); + state = detail::InvShiftRows(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template )> +HWY_API V AESKeyGenAssist(V v) { + const DFromV d; + const V rconXorMask = Dup128VecFromValues(d, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, + 0, 0, kRcon, 0, 0, 0); + const V rotWordShuffle = Dup128VecFromValues(d, 4, 5, 6, 7, 5, 6, 7, 4, 12, + 13, 14, 15, 13, 14, 15, 12); + const auto sub_word_result = detail::SubBytes(v); + const auto rot_word_result = + TableLookupBytes(sub_word_result, rotWordShuffle); + return Xor(rot_word_result, rconXorMask); +} + +// Constant-time implementation inspired by +// https://www.bearssl.org/constanttime.html, but about half the cost because we +// use 64x64 multiplies and 128-bit XORs. +template +HWY_API V CLMulLower(V a, V b) { + const DFromV d; + static_assert(IsSame, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulEven(a0, b0), MulEven(a1, b3)); + auto m1 = Xor(MulEven(a0, b1), MulEven(a1, b0)); + auto m2 = Xor(MulEven(a0, b2), MulEven(a1, b1)); + auto m3 = Xor(MulEven(a0, b3), MulEven(a1, b2)); + m0 = Xor(m0, Xor(MulEven(a2, b2), MulEven(a3, b1))); + m1 = Xor(m1, Xor(MulEven(a2, b3), MulEven(a3, b2))); + m2 = Xor(m2, Xor(MulEven(a2, b0), MulEven(a3, b3))); + m3 = Xor(m3, Xor(MulEven(a2, b1), MulEven(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} + +template +HWY_API V CLMulUpper(V a, V b) { + const DFromV d; + static_assert(IsSame, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulOdd(a0, b0), MulOdd(a1, b3)); + auto m1 = Xor(MulOdd(a0, b1), MulOdd(a1, b0)); + auto m2 = Xor(MulOdd(a0, b2), MulOdd(a1, b1)); + auto m3 = Xor(MulOdd(a0, b3), MulOdd(a1, b2)); + m0 = Xor(m0, Xor(MulOdd(a2, b2), MulOdd(a3, b1))); + m1 = Xor(m1, Xor(MulOdd(a2, b3), MulOdd(a3, b2))); + m2 = Xor(m2, Xor(MulOdd(a2, b0), MulOdd(a3, b3))); + m3 = Xor(m3, Xor(MulOdd(a2, b1), MulOdd(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} + +#endif // HWY_NATIVE_AES +#endif // HWY_TARGET != HWY_SCALAR + 
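// Illustrative standalone sketch of carry-less (XOR) multiplication, the
// primitive that CLMulLower/CLMulUpper above emulate with masked 64x64
// multiplies; for brevity this returns only the low 64 bits of the product.
#include <cstdint>

uint64_t CLMulLow64_Reference(uint64_t a, uint64_t b) {
  uint64_t result = 0;
  for (int i = 0; i < 64; ++i) {
    if ((b >> i) & 1u) result ^= a << i;  // XOR instead of add: no carries
  }
  return result;
}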
+// ------------------------------ PopulationCount + +#if (defined(HWY_NATIVE_POPCNT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +// This overload requires vectors to be at least 16 bytes, which is the case +// for LMUL >= 2. +#undef HWY_IF_POPCNT +#if HWY_TARGET == HWY_RVV +#define HWY_IF_POPCNT(D) \ + hwy::EnableIf= 1 && D().MaxLanes() >= 16>* = nullptr +#else +// Other targets only have these two overloads which are mutually exclusive, so +// no further conditions are required. +#define HWY_IF_POPCNT(D) void* = nullptr +#endif // HWY_TARGET == HWY_RVV + +template , HWY_IF_U8_D(D), + HWY_IF_V_SIZE_GT_D(D, 8), HWY_IF_POPCNT(D)> +HWY_API V PopulationCount(V v) { + const D d; + const V lookup = + Dup128VecFromValues(d, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4); + const auto lo = And(v, Set(d, uint8_t{0xF})); + const auto hi = ShiftRight<4>(v); + return Add(TableLookupBytes(lookup, hi), TableLookupBytes(lookup, lo)); +} + +// RVV has a specialization that avoids the Set(). +#if HWY_TARGET != HWY_RVV +// Slower fallback for capped vectors. +template , HWY_IF_U8_D(D), + HWY_IF_V_SIZE_LE_D(D, 8)> +HWY_API V PopulationCount(V v) { + const D d; + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + const V k33 = Set(d, uint8_t{0x33}); + v = Sub(v, And(ShiftRight<1>(v), Set(d, uint8_t{0x55}))); + v = Add(And(ShiftRight<2>(v), k33), And(v, k33)); + return And(Add(v, ShiftRight<4>(v)), Set(d, uint8_t{0x0F})); +} +#endif // HWY_TARGET != HWY_RVV + +template , HWY_IF_U16_D(D)> +HWY_API V PopulationCount(V v) { + const D d; + const Repartition d8; + const auto vals = BitCast(d, PopulationCount(BitCast(d8, v))); + return Add(ShiftRight<8>(vals), And(vals, Set(d, uint16_t{0xFF}))); +} + +template , HWY_IF_U32_D(D)> +HWY_API V PopulationCount(V v) { + const D d; + Repartition d16; + auto vals = BitCast(d, PopulationCount(BitCast(d16, v))); + return Add(ShiftRight<16>(vals), And(vals, Set(d, uint32_t{0xFF}))); +} + +#if HWY_HAVE_INTEGER64 +template , HWY_IF_U64_D(D)> +HWY_API V PopulationCount(V v) { + const D d; + Repartition d32; + auto vals = BitCast(d, PopulationCount(BitCast(d32, v))); + return Add(ShiftRight<32>(vals), And(vals, Set(d, 0xFFULL))); +} +#endif + +#endif // HWY_NATIVE_POPCNT + +// ------------------------------ 8-bit multiplication + +#if (defined(HWY_NATIVE_MUL_8) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif + +// 8 bit and fits in wider reg: promote +template +HWY_API V operator*(const V a, const V b) { + const DFromV d; + const Rebind>, decltype(d)> dw; + const RebindToUnsigned du; // TruncateTo result + const RebindToUnsigned dwu; // TruncateTo input + const VFromD mul = PromoteTo(dw, a) * PromoteTo(dw, b); + // TruncateTo is cheaper than ConcatEven. 
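// Illustrative standalone sketch of the bit-slicing PopulationCount fallback
// above (see the cited https://arxiv.org/pdf/1611.07612.pdf, Figure 3),
// applied to a single byte: sum bits in 2-, then 4-, then 8-bit groups.
#include <cstdint>

uint8_t PopCount8_Reference(uint8_t v) {
  v = static_cast<uint8_t>(v - ((v >> 1) & 0x55u));            // 2-bit sums
  v = static_cast<uint8_t>(((v >> 2) & 0x33u) + (v & 0x33u));  // 4-bit sums
  return static_cast<uint8_t>((v + (v >> 4)) & 0x0Fu);         // final sum
}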
+ return BitCast(d, TruncateTo(du, BitCast(dwu, mul))); +} + +// 8 bit full reg: promote halves +template +HWY_API V operator*(const V a, const V b) { + const DFromV d; + const Half dh; + const Twice> dw; + const VFromD a0 = PromoteTo(dw, LowerHalf(dh, a)); + const VFromD a1 = PromoteTo(dw, UpperHalf(dh, a)); + const VFromD b0 = PromoteTo(dw, LowerHalf(dh, b)); + const VFromD b1 = PromoteTo(dw, UpperHalf(dh, b)); + const VFromD m0 = a0 * b0; + const VFromD m1 = a1 * b1; + return ConcatEven(d, BitCast(d, m1), BitCast(d, m0)); +} + +#endif // HWY_NATIVE_MUL_8 + +// ------------------------------ 64-bit multiplication + +#if (defined(HWY_NATIVE_MUL_64) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +// Single-lane i64 or u64 +template +HWY_API V operator*(V x, V y) { + const DFromV d; + using T = TFromD; + using TU = MakeUnsigned; + const TU xu = static_cast(GetLane(x)); + const TU yu = static_cast(GetLane(y)); + return Set(d, static_cast(xu * yu)); +} + +template , HWY_IF_U64_D(D64), + HWY_IF_V_SIZE_GT_D(D64, 8)> +HWY_API V operator*(V x, V y) { + RepartitionToNarrow d32; + auto x32 = BitCast(d32, x); + auto y32 = BitCast(d32, y); + auto lolo = BitCast(d32, MulEven(x32, y32)); + auto lohi = BitCast(d32, MulEven(x32, BitCast(d32, ShiftRight<32>(y)))); + auto hilo = BitCast(d32, MulEven(BitCast(d32, ShiftRight<32>(x)), y32)); + auto hi = BitCast(d32, ShiftLeft<32>(BitCast(D64{}, lohi + hilo))); + return BitCast(D64{}, lolo + hi); +} +template , HWY_IF_I64_D(DI64), + HWY_IF_V_SIZE_GT_D(DI64, 8)> +HWY_API V operator*(V x, V y) { + RebindToUnsigned du64; + return BitCast(DI64{}, BitCast(du64, x) * BitCast(du64, y)); +} + +#endif // HWY_NATIVE_MUL_64 + +// ------------------------------ MulRound +template +HWY_API V MulRound(V a, V b) { + return Round(Mul(a, b)); +} + +// ------------------------------ MulAdd / NegMulAdd + +#if (defined(HWY_NATIVE_INT_FMA) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INT_FMA +#undef HWY_NATIVE_INT_FMA +#else +#define HWY_NATIVE_INT_FMA +#endif + +#ifdef HWY_NATIVE_INT_FMSUB +#undef HWY_NATIVE_INT_FMSUB +#else +#define HWY_NATIVE_INT_FMSUB +#endif + +template +HWY_API V MulAdd(V mul, V x, V add) { + return Add(Mul(mul, x), add); +} + +template +HWY_API V NegMulAdd(V mul, V x, V add) { + return Sub(add, Mul(mul, x)); +} + +template +HWY_API V MulSub(V mul, V x, V sub) { + return Sub(Mul(mul, x), sub); +} +#endif // HWY_NATIVE_INT_FMA +// ------------------------------ MulComplex* / MaskedMulComplex* + +#if (defined(HWY_NATIVE_CPLX) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_CPLX +#undef HWY_NATIVE_CPLX +#else +#define HWY_NATIVE_CPLX +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +template )> +HWY_API V ComplexConj(V a) { + return OddEven(Neg(a), a); +} + +template +HWY_API V MulComplex(V a, V b) { + // a = u + iv, b = x + iy + const auto u = DupEven(a); + const auto v = DupOdd(a); + const auto x = DupEven(b); + const auto y = DupOdd(b); + + return OddEven(MulAdd(u, y, Mul(v, x)), Sub(Mul(u, x), Mul(v, y))); +} + +template +HWY_API V MulComplexConj(V a, V b) { + // a = u + iv, b = x + iy + const auto u = DupEven(a); + const auto v = DupOdd(a); + const auto x = DupEven(b); + const auto y = DupOdd(b); + + return OddEven(Sub(Mul(v, x), Mul(u, y)), MulAdd(u, x, Mul(v, y))); +} + +template +HWY_API V MulComplexAdd(V a, V b, V c) { + return Add(MulComplex(a, b), c); +} + +template +HWY_API V MulComplexConjAdd(V a, V b, V c) { + return Add(MulComplexConj(a, b), c); 
+} + +template +HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) { + return IfThenElseZero(mask, MulComplexConjAdd(a, b, c)); +} + +template +HWY_API V MaskedMulComplexConj(M mask, V a, V b) { + return IfThenElseZero(mask, MulComplexConj(a, b)); +} + +template +HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) { + return IfThenElse(mask, MulComplex(a, b), no); +} +#endif // HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_CPLX + +// ------------------------------ MaskedMulAddOr +#if (defined(HWY_NATIVE_MASKED_INT_FMA) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_INT_FMA +#undef HWY_NATIVE_MASKED_INT_FMA +#else +#define HWY_NATIVE_MASKED_INT_FMA +#endif + +template +HWY_API V MaskedMulAddOr(V no, M m, V mul, V x, V add) { + return IfThenElse(m, MulAdd(mul, x, add), no); +} + +#endif // HWY_NATIVE_MASKED_INT_FMA + +// ------------------------------ Integer MulSub / NegMulSub +#if (defined(HWY_NATIVE_INT_FMSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INT_FMSUB +#undef HWY_NATIVE_INT_FMSUB +#else +#define HWY_NATIVE_INT_FMSUB +#endif + +template +HWY_API V MulSub(V mul, V x, V sub) { + const DFromV d; + const RebindToSigned di; + return MulAdd(mul, x, BitCast(d, Neg(BitCast(di, sub)))); +} + +#endif // HWY_NATIVE_INT_FMSUB + +template +HWY_API V NegMulSub(V mul, V x, V sub) { + const DFromV d; + const RebindToSigned di; + + return BitCast(d, Neg(BitCast(di, MulAdd(mul, x, sub)))); +} + +// ------------------------------ MulAddSub + +// MulAddSub(mul, x, sub_or_add) for a 1-lane vector is equivalent to +// MulSub(mul, x, sub_or_add) +template , 1)> +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + return MulSub(mul, x, sub_or_add); +} + +// MulAddSub for F16/F32/F64 vectors with 2 or more lanes on +// SSSE3/SSE4/AVX2/AVX3 is implemented in x86_128-inl.h, x86_256-inl.h, and +// x86_512-inl.h + +// MulAddSub for F16/F32/F64 vectors on SVE is implemented in arm_sve-inl.h + +// MulAddSub for integer vectors on SVE2 is implemented in arm_sve-inl.h +template +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + using D = DFromV; + using T = TFromD; + using TNegate = If(), MakeSigned, T>; + + const D d; + const Rebind d_negate; + + const auto add = + OddEven(sub_or_add, BitCast(d, Neg(BitCast(d_negate, sub_or_add)))); + return MulAdd(mul, x, add); +} +// ------------------------------ MulSubAdd + +template +HWY_API V MulSubAdd(V mul, V x, V sub_or_add) { + using D = DFromV; + using T = TFromD; + using TNegate = If(), MakeSigned, T>; + + const D d; + const Rebind d_negate; + + return MulAddSub(mul, x, BitCast(d, Neg(BitCast(d_negate, sub_or_add)))); +} + +// ------------------------------ MaskedConvertTo +template +HWY_API VFromD MaskedConvertTo(M m, D d, V v) { + return IfThenElseZero(m, ConvertTo(d, v)); +} + +// ------------------------------ Integer division +#if (defined(HWY_NATIVE_INT_DIV) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +namespace detail { + +// DemoteInRangeTo, PromoteInRangeTo, and ConvertInRangeTo are okay to use in +// the implementation of detail::IntDiv in generic_ops-inl.h as the current +// implementations of DemoteInRangeTo, PromoteInRangeTo, and ConvertInRangeTo +// will convert values that are outside of the range of TFromD by either +// saturation, truncation, or converting values that are outside of the +// destination range to LimitsMin>() (which is equal to +// static_cast>(LimitsMax>() + 1)) + +template ))> +HWY_INLINE Vec IntDivConvFloatToInt(D di, V 
vf) { + return ConvertInRangeTo(di, vf); +} + +template ))> +HWY_INLINE Vec IntDivConvIntToFloat(D df, V vi) { + return ConvertTo(df, vi); +} + +#if !HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template )> +HWY_INLINE Vec IntDivConvFloatToInt(D df, V vi) { + return PromoteInRangeTo(df, vi); +} + +// If !HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 is true, then UI64->F32 +// IntDivConvIntToFloat(df, vi) returns an approximation of +// static_cast(v[i]) that is within 4 ULP of static_cast(v[i]) +template )> +HWY_INLINE Vec IntDivConvIntToFloat(D df32, V vi) { + const Twice dt_f32; + + auto vf32 = + ConvertTo(dt_f32, BitCast(RebindToSigned(), vi)); + +#if HWY_IS_LITTLE_ENDIAN + const auto lo_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); + auto hi_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); +#else + const auto lo_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); + auto hi_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); +#endif + + const RebindToSigned di32; + + hi_f32 = + Add(hi_f32, And(BitCast(df32, BroadcastSignBit(BitCast(di32, lo_f32))), + Set(df32, 1.0f))); + return hwy::HWY_NAMESPACE::MulAdd(hi_f32, Set(df32, 4294967296.0f), lo_f32); +} + +template )> +HWY_INLINE Vec IntDivConvIntToFloat(D df32, V vu) { + const Twice dt_f32; + + auto vf32 = + ConvertTo(dt_f32, BitCast(RebindToUnsigned(), vu)); + +#if HWY_IS_LITTLE_ENDIAN + const auto lo_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); + const auto hi_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); +#else + const auto lo_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); + const auto hi_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); +#endif + + return hwy::HWY_NAMESPACE::MulAdd(hi_f32, Set(df32, 4294967296.0f), lo_f32); +} +#endif // !HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template , kOrigLaneSize)> +HWY_INLINE V IntDivUsingFloatDiv(V a, V b) { + const DFromV d; + const RebindToFloat df; + + // If kOrigLaneSize < sizeof(T) is true, then a[i] and b[i] are both in the + // [LimitsMin>(), + // LimitsMax>()] range. + + // floor(|a[i] / b[i]|) <= |flt_q| < floor(|a[i] / b[i]|) + 1 is also + // guaranteed to be true if MakeFloat has at least kOrigLaneSize*8 + 1 + // mantissa bits (including the implied one bit), where flt_q is equal to + // static_cast>(a[i]) / static_cast>(b[i]), + // even in the case where the magnitude of an inexact floating point division + // result is rounded up. + + // In other words, floor(flt_q) < flt_q < ceil(flt_q) is guaranteed to be true + // if (a[i] % b[i]) != 0 is true and MakeFloat has at least + // kOrigLaneSize*8 + 1 mantissa bits (including the implied one bit), even in + // the case where the magnitude of an inexact floating point division result + // is rounded up. + + // It is okay to do conversions from MakeFloat> to TFromV using + // ConvertInRangeTo if sizeof(TFromV) > kOrigLaneSize as the result of the + // floating point division is always greater than LimitsMin>() and + // less than LimitsMax>() if sizeof(TFromV) > kOrigLaneSize and + // b[i] != 0. 
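  // Worked example (illustrative): with kOrigLaneSize == 1 the requirement is
  // only 1*8 + 1 = 9 mantissa bits, which even float16 (11 bits including the
  // implied one) satisfies, so the float quotient can never land on the wrong
  // side of an integer. E.g. a[i] = -127, b[i] = 3 gives a quotient of
  // -42.333..., and truncation toward zero yields -42, the exact result of
  // -127 / 3.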
+ +#if HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64 + // On Armv7, do division by multiplying by the ApproximateReciprocal + // to avoid unnecessary overhead as F32 Div refines the approximate + // reciprocal using 4 Newton-Raphson iterations + + const RebindToSigned di; + const RebindToUnsigned du; + + const auto flt_b = ConvertTo(df, b); + auto flt_recip_b = ApproximateReciprocal(flt_b); + if (kOrigLaneSize > 1) { + flt_recip_b = + Mul(flt_recip_b, ReciprocalNewtonRaphsonStep(flt_recip_b, flt_b)); + } + + auto q0 = ConvertInRangeTo(d, Mul(ConvertTo(df, a), flt_recip_b)); + const auto r0 = BitCast(di, hwy::HWY_NAMESPACE::NegMulAdd(q0, b, a)); + + auto r1 = r0; + + // Need to negate r1[i] if a[i] < 0 is true + if (IsSigned>()) { + r1 = IfNegativeThenNegOrUndefIfZero(BitCast(di, a), r1); + } + + // r1[i] is now equal to (a[i] < 0) ? (-r0[i]) : r0[i] + + auto abs_b = BitCast(du, b); + if (IsSigned>()) { + abs_b = BitCast(du, Abs(BitCast(di, abs_b))); + } + + // If (r1[i] < 0 || r1[i] >= abs_b[i]) is true, then set q1[i] to -1. + // Otherwise, set q1[i] to 0. + + // (r1[i] < 0 || r1[i] >= abs_b[i]) can be carried out using a single unsigned + // comparison as static_cast(r1[i]) >= TU(LimitsMax() + 1) >= abs_b[i] + // will be true if r1[i] < 0 is true. + auto q1 = BitCast(di, VecFromMask(du, Ge(BitCast(du, r1), abs_b))); + + // q1[i] is now equal to (r1[i] < 0 || r1[i] >= abs_b[i]) ? -1 : 0 + + // Need to negate q1[i] if r0[i] and b[i] do not have the same sign + auto q1_negate_mask = r0; + if (IsSigned>()) { + q1_negate_mask = Xor(q1_negate_mask, BitCast(di, b)); + } + q1 = IfNegativeThenElse(q1_negate_mask, Neg(q1), q1); + + // q1[i] is now equal to (r1[i] < 0 || r1[i] >= abs_b[i]) ? + // (((r0[i] ^ b[i]) < 0) ? 1 : -1) + + // Need to subtract q1[i] from q0[i] to get the final result + return Sub(q0, BitCast(d, q1)); +#else + // On targets other than Armv7 NEON, use F16 or F32 division as most targets + // other than Armv7 NEON have native F32 divide instructions + return ConvertInRangeTo(d, Div(ConvertTo(df, a), ConvertTo(df, b))); +#endif +} + +template , kOrigLaneSize), + HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 4) | (1 << 8))> +HWY_INLINE V IntDivUsingFloatDiv(V a, V b) { + // If kOrigLaneSize == sizeof(T) is true, at least two reciprocal + // multiplication steps are needed as the mantissa of MakeFloat has fewer + // than kOrigLaneSize*8 + 1 bits + + using T = TFromV; + +#if HWY_HAVE_FLOAT64 + using TF = MakeFloat; +#else + using TF = float; +#endif + + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + const Rebind df; + + if (!IsSigned()) { + // If T is unsigned, set a[i] to (a[i] >= b[i] ? 
1 : 0) and set b[i] to 1 if + // b[i] > LimitsMax>() is true + + const auto one = Set(di, MakeSigned{1}); + a = BitCast( + d, IfNegativeThenElse(BitCast(di, b), + IfThenElseZero(RebindMask(di, Ge(a, b)), one), + BitCast(di, a))); + b = BitCast(d, IfNegativeThenElse(BitCast(di, b), one, BitCast(di, b))); + } + + // LimitsMin() <= b[i] <= LimitsMax>() is now true + + const auto flt_b = IntDivConvIntToFloat(df, b); + +#if HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64 + auto flt_recip_b = ApproximateReciprocal(flt_b); + flt_recip_b = + Mul(flt_recip_b, ReciprocalNewtonRaphsonStep(flt_recip_b, flt_b)); +#else + const auto flt_recip_b = Div(Set(df, TF(1.0)), flt_b); +#endif + + // It is okay if the conversion of a[i] * flt_recip_b[i] to T using + // IntDivConvFloatToInt returns incorrect results in any lanes where b[i] == 0 + // as the result of IntDivUsingFloatDiv(a, b) is implementation-defined in any + // lanes where b[i] == 0. + + // If ScalarAbs(b[i]) == 1 is true, then it is possible for + // a[i] * flt_recip_b[i] to be rounded up to a value that is outside of the + // range of T. If a[i] * flt_recip_b[i] is outside of the range of T, + // IntDivConvFloatToInt will convert any values that are out of the range of T + // by either saturation, truncation, or wrapping around to LimitsMin(). + + // It is okay if the conversion of a[i] * flt_recip_b[i] to T using + // IntDivConvFloatToInt wraps around if ScalarAbs(b[i]) == 1 as r0 will have + // the correct sign if ScalarAbs(b[i]) == 1, even in the cases where the + // conversion of a[i] * flt_recip_b[i] to T using IntDivConvFloatToInt is + // truncated or wraps around. + + // If ScalarAbs(b[i]) >= 2 is true, a[i] * flt_recip_b[i] will be within the + // range of T, even in the cases where the conversion of a[i] to TF is + // rounded up or the result of multiplying a[i] by flt_recip_b[i] is rounded + // up. + + // ScalarAbs(r0[i]) will also always be less than (LimitsMax() / 2) if + // b[i] != 0, even in the cases where the conversion of a[i] * flt_recip_b[i] + // to T using IntDivConvFloatToInt is truncated or is wrapped around. + + auto q0 = + IntDivConvFloatToInt(d, Mul(IntDivConvIntToFloat(df, a), flt_recip_b)); + const auto r0 = BitCast(di, hwy::HWY_NAMESPACE::NegMulAdd(q0, b, a)); + + // If b[i] != 0 is true, r0[i] * flt_recip_b[i] is always within the range of + // T, even in the cases where the conversion of r0[i] to TF is rounded up or + // the multiplication of r0[i] by flt_recip_b[i] is rounded up. + + auto q1 = + IntDivConvFloatToInt(di, Mul(IntDivConvIntToFloat(df, r0), flt_recip_b)); + const auto r1 = hwy::HWY_NAMESPACE::NegMulAdd(q1, BitCast(di, b), r0); + + auto r3 = r1; + +#if !HWY_HAVE_FLOAT64 + // Need two additional reciprocal multiplication steps for I64/U64 vectors if + // HWY_HAVE_FLOAT64 is 0 + if (sizeof(T) == 8) { + const auto q2 = IntDivConvFloatToInt( + di, Mul(IntDivConvIntToFloat(df, r1), flt_recip_b)); + const auto r2 = hwy::HWY_NAMESPACE::NegMulAdd(q2, BitCast(di, b), r1); + + const auto q3 = IntDivConvFloatToInt( + di, Mul(IntDivConvIntToFloat(df, r2), flt_recip_b)); + r3 = hwy::HWY_NAMESPACE::NegMulAdd(q3, BitCast(di, b), r2); + + q0 = Add(q0, BitCast(d, q2)); + q1 = Add(q1, q3); + } +#endif // !HWY_HAVE_FLOAT64 + + auto r4 = r3; + + // Need to negate r4[i] if a[i] < 0 is true + if (IsSigned>()) { + r4 = IfNegativeThenNegOrUndefIfZero(BitCast(di, a), r4); + } + + // r4[i] is now equal to (a[i] < 0) ? 
(-r3[i]) : r3[i] + + auto abs_b = BitCast(du, b); + if (IsSigned>()) { + abs_b = BitCast(du, Abs(BitCast(di, abs_b))); + } + + // If (r4[i] < 0 || r4[i] >= abs_b[i]) is true, then set q4[i] to -1. + // Otherwise, set r4[i] to 0. + + // (r4[i] < 0 || r4[i] >= abs_b[i]) can be carried out using a single unsigned + // comparison as static_cast(r4[i]) >= TU(LimitsMax() + 1) >= abs_b[i] + // will be true if r4[i] < 0 is true. + auto q4 = BitCast(di, VecFromMask(du, Ge(BitCast(du, r4), abs_b))); + + // q4[i] is now equal to (r4[i] < 0 || r4[i] >= abs_b[i]) ? -1 : 0 + + // Need to negate q4[i] if r3[i] and b[i] do not have the same sign + auto q4_negate_mask = r3; + if (IsSigned>()) { + q4_negate_mask = Xor(q4_negate_mask, BitCast(di, b)); + } + q4 = IfNegativeThenElse(q4_negate_mask, Neg(q4), q4); + + // q4[i] is now equal to (r4[i] < 0 || r4[i] >= abs_b[i]) ? + // (((r3[i] ^ b[i]) < 0) ? 1 : -1) + + // The final result is equal to q0[i] + q1[i] - q4[i] + return Sub(Add(q0, BitCast(d, q1)), BitCast(d, q4)); +} + +template ) == 1) ? 4 : 2))> +HWY_INLINE V IntDiv(V a, V b) { + using T = TFromV; + + // If HWY_HAVE_FLOAT16 is 0, need to promote I8 to I32 and U8 to U32 + using TW = MakeWide< + If<(!HWY_HAVE_FLOAT16 && sizeof(TFromV) == 1), MakeWide, T>>; + + const DFromV d; + const Rebind dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3/SSE4/AVX2/AVX3, promote to and from MakeSigned to avoid + // unnecessary overhead + const RebindToSigned dw_i; + + // On SSE2/SSSE3/SSE4/AVX2/AVX3, demote to MakeSigned if + // kOrigLaneSize < sizeof(T) to avoid unnecessary overhead + const If<(kOrigLaneSize < sizeof(T)), RebindToSigned, + decltype(d)> + d_demote_to; +#else + // On other targets, promote to TW and demote to T + const decltype(dw) dw_i; + const decltype(d) d_demote_to; +#endif + + return BitCast( + d, DemoteTo(d_demote_to, IntDivUsingFloatDiv( + PromoteTo(dw_i, a), PromoteTo(dw_i, b)))); +} + +template +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3/SSE4/AVX2/AVX3, promote to and from MakeSigned to avoid + // unnecessary overhead + const RebindToSigned dw_i; + + // On SSE2/SSSE3/SSE4/AVX2/AVX3, demote to MakeSigned> if + // kOrigLaneSize < sizeof(TFromV) to avoid unnecessary overhead + const If<(kOrigLaneSize < sizeof(TFromV)), RebindToSigned, + decltype(d)> + d_demote_to; +#else + // On other targets, promote to MakeWide> and demote to TFromV + const decltype(dw) dw_i; + const decltype(d) d_demote_to; +#endif + + return BitCast(d, OrderedDemote2To( + d_demote_to, + IntDivUsingFloatDiv( + PromoteLowerTo(dw_i, a), PromoteLowerTo(dw_i, b)), + IntDivUsingFloatDiv( + PromoteUpperTo(dw_i, a), PromoteUpperTo(dw_i, b)))); +} + +#if !HWY_HAVE_FLOAT16 +template ), + HWY_IF_V_SIZE_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const Rebind>, decltype(d)> dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3, demote from int16_t to TFromV to avoid unnecessary + // overhead + const RebindToSigned dw_i; +#else + // On other targets, demote from MakeWide> to TFromV + const decltype(dw) dw_i; +#endif + + return DemoteTo(d, + BitCast(dw_i, IntDiv<1>(PromoteTo(dw, a), PromoteTo(dw, b)))); +} +template ), + HWY_IF_V_SIZE_GT_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3, demote from int16_t to TFromV to avoid unnecessary + // overhead + const RebindToSigned dw_i; +#else + // On other targets, demote 
from MakeWide> to TFromV + const decltype(dw) dw_i; +#endif + + return OrderedDemote2To( + d, BitCast(dw_i, IntDiv<1>(PromoteLowerTo(dw, a), PromoteLowerTo(dw, b))), + BitCast(dw_i, IntDiv<1>(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b)))); +} +#endif // !HWY_HAVE_FLOAT16 + +template +HWY_INLINE V IntDiv(V a, V b) { + return IntDivUsingFloatDiv(a, b); +} + +#if HWY_HAVE_FLOAT64 +template ), + HWY_IF_V_SIZE_LE_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const Rebind df64; + + // It is okay to demote the F64 Div result to int32_t or uint32_t using + // DemoteInRangeTo as static_cast(a[i]) / static_cast(b[i]) + // will always be within the range of TFromV if b[i] != 0 and + // sizeof(TFromV) <= 4. + + return DemoteInRangeTo(d, Div(PromoteTo(df64, a), PromoteTo(df64, b))); +} +template ), + HWY_IF_V_SIZE_GT_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const Half dh; + const Repartition df64; + + // It is okay to demote the F64 Div result to int32_t or uint32_t using + // DemoteInRangeTo as static_cast(a[i]) / static_cast(b[i]) + // will always be within the range of TFromV if b[i] != 0 and + // sizeof(TFromV) <= 4. + + const VFromD div1 = + Div(PromoteUpperTo(df64, a), PromoteUpperTo(df64, b)); + const VFromD div0 = + Div(PromoteLowerTo(df64, a), PromoteLowerTo(df64, b)); + return Combine(d, DemoteInRangeTo(dh, div1), DemoteInRangeTo(dh, div0)); +} +#endif // HWY_HAVE_FLOAT64 + +template +HWY_INLINE V IntMod(V a, V b) { + return hwy::HWY_NAMESPACE::NegMulAdd(IntDiv(a, b), b, a); +} + +#if HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \ + HWY_TARGET == HWY_WASM_EMU256 +template ), + HWY_IF_V_SIZE_LE_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntMod(V a, V b) { + const DFromV d; + const Rebind>, decltype(d)> dw; + return DemoteTo(d, IntMod(PromoteTo(dw, a), PromoteTo(dw, b))); +} + +template ), + HWY_IF_V_SIZE_GT_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntMod(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + return OrderedDemote2To( + d, IntMod(PromoteLowerTo(dw, a), PromoteLowerTo(dw, b)), + IntMod(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b))); +} +#endif // HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || HWY_TARGET == + // HWY_WASM_EMU256 + +} // namespace detail + +#if HWY_TARGET == HWY_SCALAR + +template +HWY_API Vec1 operator/(Vec1 a, Vec1 b) { + return detail::IntDiv(a, b); +} +template +HWY_API Vec1 operator%(Vec1 a, Vec1 b) { + return detail::IntMod(a, b); +} + +#else // HWY_TARGET != HWY_SCALAR + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + return detail::IntDiv(a, b); +} + +template +HWY_API Vec128 operator%(Vec128 a, Vec128 b) { + return detail::IntMod(a, b); +} + +#if HWY_CAP_GE256 +template +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return detail::IntDiv(a, b); +} +template +HWY_API Vec256 operator%(Vec256 a, Vec256 b) { + return detail::IntMod(a, b); +} +#endif + +#if HWY_CAP_GE512 +template +HWY_API Vec512 operator/(Vec512 a, Vec512 b) { + return detail::IntDiv(a, b); +} +template +HWY_API Vec512 operator%(Vec512 a, Vec512 b) { + return detail::IntMod(a, b); +} +#endif + +#endif // HWY_TARGET == HWY_SCALAR + +#endif // HWY_NATIVE_INT_DIV + +// ------------------------------ AverageRound + +#if (defined(HWY_NATIVE_AVERAGE_ROUND_UI32) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +template )> +HWY_API V AverageRound(V a, V b) { + return Sub(Or(a, b), 
ShiftRight<1>(Xor(a, b))); +} + +#endif // HWY_NATIVE_AVERAGE_ROUND_UI64 + +#if (defined(HWY_NATIVE_AVERAGE_ROUND_UI64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +#if HWY_HAVE_INTEGER64 +template )> +HWY_API V AverageRound(V a, V b) { + return Sub(Or(a, b), ShiftRight<1>(Xor(a, b))); +} +#endif + +#endif // HWY_NATIVE_AVERAGE_ROUND_UI64 + +// ------------------------------ RoundingShiftRight (AverageRound) + +#if (defined(HWY_NATIVE_ROUNDING_SHR) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +template +HWY_API V RoundingShiftRight(V v) { + const DFromV d; + using T = TFromD; + + static_assert( + 0 <= kShiftAmt && kShiftAmt <= static_cast(sizeof(T) * 8 - 1), + "kShiftAmt is out of range"); + + constexpr int kScaleDownShrAmt = HWY_MAX(kShiftAmt - 1, 0); + + auto scaled_down_v = v; + HWY_IF_CONSTEXPR(kScaleDownShrAmt > 0) { + scaled_down_v = ShiftRight(v); + } + + HWY_IF_CONSTEXPR(kShiftAmt == 0) { return scaled_down_v; } + + return AverageRound(scaled_down_v, Zero(d)); +} + +template +HWY_API V RoundingShiftRightSame(V v, int shift_amt) { + const DFromV d; + using T = TFromD; + + const int shift_amt_is_zero_mask = -static_cast(shift_amt == 0); + + const auto scaled_down_v = ShiftRightSame( + v, static_cast(static_cast(shift_amt) + + static_cast(~shift_amt_is_zero_mask))); + + return AverageRound( + scaled_down_v, + And(scaled_down_v, Set(d, static_cast(shift_amt_is_zero_mask)))); +} + +template +HWY_API V RoundingShr(V v, V amt) { + const DFromV d; + const RebindToUnsigned du; + using T = TFromD; + using TU = MakeUnsigned; + + const auto unsigned_amt = BitCast(du, amt); + const auto scale_down_shr_amt = + BitCast(d, SaturatedSub(unsigned_amt, Set(du, TU{1}))); + + const auto scaled_down_v = Shr(v, scale_down_shr_amt); + return AverageRound(scaled_down_v, + IfThenElseZero(Eq(amt, Zero(d)), scaled_down_v)); +} + +#endif // HWY_NATIVE_ROUNDING_SHR + +// ------------------------------ MulEvenAdd (PromoteEvenTo) + +// SVE with bf16 and NEON with bf16 override this. +#if (defined(HWY_NATIVE_MUL_EVEN_BF16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MUL_EVEN_BF16 +#undef HWY_NATIVE_MUL_EVEN_BF16 +#else +#define HWY_NATIVE_MUL_EVEN_BF16 +#endif + +template >> +HWY_API VFromD MulEvenAdd(DF df, VBF a, VBF b, VFromD c) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), c); +} + +template >> +HWY_API VFromD MulOddAdd(DF df, VBF a, VBF b, VFromD c) { + return MulAdd(PromoteOddTo(df, a), PromoteOddTo(df, b), c); +} + +#endif // HWY_NATIVE_MUL_EVEN_BF16 + +// ------------------------------ ReorderWidenMulAccumulate (MulEvenAdd) + +// AVX3_SPR/ZEN4, and NEON with bf16 but not(!) SVE override this. +#if (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF df, VBF a, VBF b, + VFromD sum0, + VFromD& sum1) { + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo by using PromoteEvenTo. 
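  // Sketch of the lane bookkeeping (illustrative): products of the
  // even-indexed bf16 lanes are accumulated into sum0 and those of the
  // odd-indexed lanes into sum1. Callers that only need the total can simply
  // add the two accumulators; otherwise they recombine them afterwards (e.g.
  // via RearrangeToOddPlusEven), which is why the per-lane order here does
  // not matter.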
+ sum1 = MulOddAdd(df, a, b, sum1); + return MulEvenAdd(df, a, b, sum0); +} + +#endif // HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 + +// ------------------------------ WidenMulAccumulate + +#if (defined(HWY_NATIVE_WIDEN_MUL_ACCUMULATE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#else +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#endif + +template), + class DN = RepartitionToNarrow> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = MulAdd(PromoteUpperTo(d, mul), PromoteUpperTo(d, x), high); + return MulAdd(PromoteLowerTo(d, mul), PromoteLowerTo(d, x), low); +} + +#endif // HWY_NATIVE_WIDEN_MUL_ACCUMULATE + +#if 0 +#if (defined(HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#else +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#endif + +#if HWY_HAVE_FLOAT16 + +template> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = MulAdd(PromoteUpperTo(d, mul), PromoteUpperTo(d, x), high); + return MulAdd(PromoteLowerTo(d, mul), PromoteLowerTo(d, x), low); +} + +#endif // HWY_HAVE_FLOAT16 + +#endif // HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#endif // #if 0 + +// ------------------------------ SatWidenMulPairwiseAdd + +#if (defined(HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#undef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#else +#define HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#endif + +template >, HWY_IF_I16_D(DI16), + HWY_IF_U8_D(DFromV), HWY_IF_I8_D(DFromV), + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(VI8)), + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(VU8_2))> +HWY_API Vec SatWidenMulPairwiseAdd(DI16 di16, VU8 a, VI8 b) { + const RebindToUnsigned du16; + + const auto a0 = BitCast(di16, PromoteEvenTo(du16, a)); + const auto b0 = PromoteEvenTo(di16, b); + + const auto a1 = BitCast(di16, PromoteOddTo(du16, a)); + const auto b1 = PromoteOddTo(di16, b); + + return SaturatedAdd(Mul(a0, b0), Mul(a1, b1)); +} + +#endif + +// ------------------------------ SatWidenMulPairwiseAccumulate + +#if (defined(HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#else +#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#endif + +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 di32, VFromD> a, + VFromD> b, VFromD sum) { + // WidenMulPairwiseAdd(di32, a, b) is okay here as + // a[0]*b[0]+a[1]*b[1] is between -2147418112 and 2147483648 and as + // a[0]*b[0]+a[1]*b[1] can only overflow an int32_t if + // a[0], b[0], a[1], and b[1] are all equal to -32768. 
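  // Concrete arithmetic (illustrative): the only case that overflows is
  // (-32768)*(-32768) + (-32768)*(-32768) = 2^31, which wraps to
  // LimitsMin<int32_t>() inside WidenMulPairwiseAdd. The code below detects
  // exactly that wrapped value and splits the required correction between the
  // two operands of SaturatedAdd so that neither intermediate wraps again;
  // the final result is sum + 2^31, saturated to 2147483647 when it exceeds
  // the int32_t range.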
+ + const auto product = WidenMulPairwiseAdd(di32, a, b); + + const auto mul_overflow = + VecFromMask(di32, Eq(product, Set(di32, LimitsMin()))); + + return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)), + Add(product, mul_overflow)); +} + +#endif // HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM + +// ------------------------------ SatWidenMulAccumFixedPoint + +#if (defined(HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#endif + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Repartition dt_i16; + + const auto vt_a = ResizeBitCast(dt_i16, a); + const auto vt_b = ResizeBitCast(dt_i16, b); + + const auto dup_a = InterleaveWholeLower(dt_i16, vt_a, vt_a); + const auto dup_b = InterleaveWholeLower(dt_i16, vt_b, vt_b); + + return SatWidenMulPairwiseAccumulate(di32, dup_a, dup_b, sum); +} + +#endif // HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT + +// ------------------------------ MaskedSqrt + +#if (defined(HWY_NATIVE_MASKED_SQRT) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_MASKED_SQRT +#undef HWY_NATIVE_MASKED_SQRT +#else +#define HWY_NATIVE_MASKED_SQRT +#endif +template +HWY_API V MaskedSqrt(M m, V v) { + return IfThenElseZero(m, Sqrt(v)); +} + +template +HWY_API V MaskedSqrtOr(V no, M m, V v) { + return IfThenElse(m, Sqrt(v), no); +} +#endif + +// ------------------------------ SumOfMulQuadAccumulate + +#if (defined(HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Repartition di16; + + const auto a0 = PromoteEvenTo(di16, a); + const auto b0 = PromoteEvenTo(di16, b); + + const auto a1 = PromoteOddTo(di16, a); + const auto b1 = PromoteOddTo(di16, b); + + return Add(sum, Add(WidenMulPairwiseAdd(di32, a0, b0), + WidenMulPairwiseAdd(di32, a1, b1))); +} + +#endif + +#if (defined(HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 du32, VFromD> a, + VFromD> b, VFromD sum) { + const Repartition du16; + const RebindToSigned di16; + const RebindToSigned di32; + + const auto lo8_mask = Set(di16, int16_t{0x00FF}); + const auto a0 = And(BitCast(di16, a), lo8_mask); + const auto b0 = And(BitCast(di16, b), lo8_mask); + + const auto a1 = BitCast(di16, ShiftRight<8>(BitCast(du16, a))); + const auto b1 = BitCast(di16, ShiftRight<8>(BitCast(du16, b))); + + return Add(sum, Add(BitCast(du32, WidenMulPairwiseAdd(di32, a0, b0)), + BitCast(du32, WidenMulPairwiseAdd(di32, a1, b1)))); +} + +#endif + +#if (defined(HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 di32, VFromD> a_u, + VFromD> b_i, VFromD sum) { + const Repartition di16; + const RebindToUnsigned du16; + 
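  // Lane bookkeeping sketch (illustrative): the And with 0x00FF zero-extends
  // the low byte of every 16-bit lane of a_u, and the ShiftLeft<8> followed
  // by an arithmetic ShiftRight<8> sign-extends the low byte of b_i; the
  // shifts just below do the same for the high bytes. Each
  // WidenMulPairwiseAdd then sums two of the four per-lane products, so the
  // two calls together add all four u8*i8 products into every int32 lane.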
+ const auto a0 = And(BitCast(di16, a_u), Set(di16, int16_t{0x00FF})); + const auto b0 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, b_i))); + + const auto a1 = BitCast(di16, ShiftRight<8>(BitCast(du16, a_u))); + const auto b1 = ShiftRight<8>(BitCast(di16, b_i)); + + // NOTE: SatWidenMulPairwiseAdd(di16, a_u, b_i) cannot be used in + // SumOfMulQuadAccumulate as it is possible for + // a_u[0]*b_i[0]+a_u[1]*b_i[1] to overflow an int16_t if a_u[0], b_i[0], + // a_u[1], and b_i[1] are all non-zero and b_i[0] and b_i[1] have the same + // sign. + + return Add(sum, Add(WidenMulPairwiseAdd(di32, a0, b0), + WidenMulPairwiseAdd(di32, a1, b1))); +} + +#endif + +#if (defined(HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#endif + +#if HWY_HAVE_INTEGER64 +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI64 di64, VFromD> a, + VFromD> b, VFromD sum) { + const Repartition di32; + + // WidenMulPairwiseAdd(di32, a, b) is okay here as + // a[0]*b[0]+a[1]*b[1] is between -2147418112 and 2147483648 and as + // a[0]*b[0]+a[1]*b[1] can only overflow an int32_t if + // a[0], b[0], a[1], and b[1] are all equal to -32768. + + const auto i32_pairwise_sum = WidenMulPairwiseAdd(di32, a, b); + const auto i32_pairwise_sum_overflow = + VecFromMask(di32, Eq(i32_pairwise_sum, Set(di32, LimitsMin()))); + + // The upper 32 bits of sum0 and sum1 need to be zeroed out in the case of + // overflow. + const auto hi32_mask = Set(di64, static_cast(~int64_t{0xFFFFFFFF})); + const auto p0_zero_out_mask = + ShiftLeft<32>(BitCast(di64, i32_pairwise_sum_overflow)); + const auto p1_zero_out_mask = + And(BitCast(di64, i32_pairwise_sum_overflow), hi32_mask); + + const auto p0 = + AndNot(p0_zero_out_mask, + ShiftRight<32>(ShiftLeft<32>(BitCast(di64, i32_pairwise_sum)))); + const auto p1 = + AndNot(p1_zero_out_mask, ShiftRight<32>(BitCast(di64, i32_pairwise_sum))); + + return Add(sum, Add(p0, p1)); +} +#endif // HWY_HAVE_INTEGER64 +#endif // HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE + +#if (defined(HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#endif + +#if HWY_HAVE_INTEGER64 +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU64 du64, VFromD> a, + VFromD> b, VFromD sum) { + const auto u32_even_prod = MulEven(a, b); + const auto u32_odd_prod = MulOdd(a, b); + + const auto p0 = Add(PromoteEvenTo(du64, u32_even_prod), + PromoteEvenTo(du64, u32_odd_prod)); + const auto p1 = + Add(PromoteOddTo(du64, u32_even_prod), PromoteOddTo(du64, u32_odd_prod)); + + return Add(sum, Add(p0, p1)); +} +#endif // HWY_HAVE_INTEGER64 +#endif // HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE + +// ------------------------------ F64 ApproximateReciprocal + +#if (defined(HWY_NATIVE_F64_APPROX_RECIP) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +#if HWY_HAVE_FLOAT64 +template )> +HWY_API V ApproximateReciprocal(V v) { + const DFromV d; + return Div(Set(d, 1.0), v); +} +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_NATIVE_F64_APPROX_RECIP + +// ------------------------------ MaskedApproximateReciprocal +template +HWY_API V MaskedApproximateReciprocal(M m, V v) { + return 
IfThenElseZero(m, ApproximateReciprocal(v)); +} + +// ------------------------------ F64 ApproximateReciprocalSqrt + +#if (defined(HWY_NATIVE_F64_APPROX_RSQRT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +#if HWY_HAVE_FLOAT64 +template )> +HWY_API V ApproximateReciprocalSqrt(V v) { + const DFromV d; + const RebindToUnsigned du; + const auto half = Mul(v, Set(d, 0.5)); + // Initial guess based on log2(f) + const auto guess = BitCast(d, Sub(Set(du, uint64_t{0x5FE6EB50C7B537A9u}), + ShiftRight<1>(BitCast(du, v)))); + // One Newton-Raphson iteration + return Mul(guess, NegMulAdd(Mul(half, guess), guess, Set(d, 1.5))); +} +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_NATIVE_F64_APPROX_RSQRT + +// ------------------------------ MaskedApproximateReciprocalSqrt +template +HWY_API V MaskedApproximateReciprocalSqrt(M m, V v) { + return IfThenElseZero(m, ApproximateReciprocalSqrt(v)); +} + +// ------------------------------ Compress* + +#if (defined(HWY_NATIVE_COMPRESS8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +template +HWY_API size_t CompressBitsStore(V v, const uint8_t* HWY_RESTRICT bits, D d, + T* unaligned) { + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + const Simd d8; + T* HWY_RESTRICT pos = unaligned; + + HWY_ALIGN constexpr T table[2048] = { + 0, 1, 2, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 1, 0, 2, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 2, 0, 1, 3, 4, 5, 6, 7, /**/ 0, 2, 1, 3, 4, 5, 6, 7, // + 1, 2, 0, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 3, 0, 1, 2, 4, 5, 6, 7, /**/ 0, 3, 1, 2, 4, 5, 6, 7, // + 1, 3, 0, 2, 4, 5, 6, 7, /**/ 0, 1, 3, 2, 4, 5, 6, 7, // + 2, 3, 0, 1, 4, 5, 6, 7, /**/ 0, 2, 3, 1, 4, 5, 6, 7, // + 1, 2, 3, 0, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 4, 0, 1, 2, 3, 5, 6, 7, /**/ 0, 4, 1, 2, 3, 5, 6, 7, // + 1, 4, 0, 2, 3, 5, 6, 7, /**/ 0, 1, 4, 2, 3, 5, 6, 7, // + 2, 4, 0, 1, 3, 5, 6, 7, /**/ 0, 2, 4, 1, 3, 5, 6, 7, // + 1, 2, 4, 0, 3, 5, 6, 7, /**/ 0, 1, 2, 4, 3, 5, 6, 7, // + 3, 4, 0, 1, 2, 5, 6, 7, /**/ 0, 3, 4, 1, 2, 5, 6, 7, // + 1, 3, 4, 0, 2, 5, 6, 7, /**/ 0, 1, 3, 4, 2, 5, 6, 7, // + 2, 3, 4, 0, 1, 5, 6, 7, /**/ 0, 2, 3, 4, 1, 5, 6, 7, // + 1, 2, 3, 4, 0, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 5, 0, 1, 2, 3, 4, 6, 7, /**/ 0, 5, 1, 2, 3, 4, 6, 7, // + 1, 5, 0, 2, 3, 4, 6, 7, /**/ 0, 1, 5, 2, 3, 4, 6, 7, // + 2, 5, 0, 1, 3, 4, 6, 7, /**/ 0, 2, 5, 1, 3, 4, 6, 7, // + 1, 2, 5, 0, 3, 4, 6, 7, /**/ 0, 1, 2, 5, 3, 4, 6, 7, // + 3, 5, 0, 1, 2, 4, 6, 7, /**/ 0, 3, 5, 1, 2, 4, 6, 7, // + 1, 3, 5, 0, 2, 4, 6, 7, /**/ 0, 1, 3, 5, 2, 4, 6, 7, // + 2, 3, 5, 0, 1, 4, 6, 7, /**/ 0, 2, 3, 5, 1, 4, 6, 7, // + 1, 2, 3, 5, 0, 4, 6, 7, /**/ 0, 1, 2, 3, 5, 4, 6, 7, // + 4, 5, 0, 1, 2, 3, 6, 7, /**/ 0, 4, 5, 1, 2, 3, 6, 7, // + 1, 4, 5, 0, 2, 3, 6, 7, /**/ 0, 1, 4, 5, 2, 3, 6, 7, // + 2, 4, 5, 0, 1, 3, 6, 7, /**/ 0, 2, 4, 5, 1, 3, 6, 7, // + 1, 2, 4, 5, 0, 3, 6, 7, /**/ 0, 1, 2, 4, 5, 3, 6, 7, // + 3, 4, 5, 0, 1, 2, 6, 7, /**/ 0, 3, 4, 5, 1, 2, 6, 7, // + 1, 3, 4, 5, 0, 2, 6, 7, /**/ 0, 1, 3, 4, 5, 2, 6, 7, // + 2, 3, 4, 5, 0, 1, 6, 7, /**/ 0, 2, 3, 4, 5, 1, 6, 7, // + 1, 2, 3, 4, 5, 0, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 6, 0, 1, 2, 3, 4, 5, 7, /**/ 0, 6, 1, 2, 3, 4, 5, 7, // + 1, 6, 0, 2, 3, 4, 5, 7, /**/ 0, 1, 6, 2, 3, 4, 5, 7, // + 2, 6, 0, 1, 3, 4, 5, 7, /**/ 0, 2, 6, 1, 3, 4, 5, 7, // + 1, 2, 6, 0, 3, 4, 5, 7, /**/ 0, 1, 2, 6, 
3, 4, 5, 7, // + 3, 6, 0, 1, 2, 4, 5, 7, /**/ 0, 3, 6, 1, 2, 4, 5, 7, // + 1, 3, 6, 0, 2, 4, 5, 7, /**/ 0, 1, 3, 6, 2, 4, 5, 7, // + 2, 3, 6, 0, 1, 4, 5, 7, /**/ 0, 2, 3, 6, 1, 4, 5, 7, // + 1, 2, 3, 6, 0, 4, 5, 7, /**/ 0, 1, 2, 3, 6, 4, 5, 7, // + 4, 6, 0, 1, 2, 3, 5, 7, /**/ 0, 4, 6, 1, 2, 3, 5, 7, // + 1, 4, 6, 0, 2, 3, 5, 7, /**/ 0, 1, 4, 6, 2, 3, 5, 7, // + 2, 4, 6, 0, 1, 3, 5, 7, /**/ 0, 2, 4, 6, 1, 3, 5, 7, // + 1, 2, 4, 6, 0, 3, 5, 7, /**/ 0, 1, 2, 4, 6, 3, 5, 7, // + 3, 4, 6, 0, 1, 2, 5, 7, /**/ 0, 3, 4, 6, 1, 2, 5, 7, // + 1, 3, 4, 6, 0, 2, 5, 7, /**/ 0, 1, 3, 4, 6, 2, 5, 7, // + 2, 3, 4, 6, 0, 1, 5, 7, /**/ 0, 2, 3, 4, 6, 1, 5, 7, // + 1, 2, 3, 4, 6, 0, 5, 7, /**/ 0, 1, 2, 3, 4, 6, 5, 7, // + 5, 6, 0, 1, 2, 3, 4, 7, /**/ 0, 5, 6, 1, 2, 3, 4, 7, // + 1, 5, 6, 0, 2, 3, 4, 7, /**/ 0, 1, 5, 6, 2, 3, 4, 7, // + 2, 5, 6, 0, 1, 3, 4, 7, /**/ 0, 2, 5, 6, 1, 3, 4, 7, // + 1, 2, 5, 6, 0, 3, 4, 7, /**/ 0, 1, 2, 5, 6, 3, 4, 7, // + 3, 5, 6, 0, 1, 2, 4, 7, /**/ 0, 3, 5, 6, 1, 2, 4, 7, // + 1, 3, 5, 6, 0, 2, 4, 7, /**/ 0, 1, 3, 5, 6, 2, 4, 7, // + 2, 3, 5, 6, 0, 1, 4, 7, /**/ 0, 2, 3, 5, 6, 1, 4, 7, // + 1, 2, 3, 5, 6, 0, 4, 7, /**/ 0, 1, 2, 3, 5, 6, 4, 7, // + 4, 5, 6, 0, 1, 2, 3, 7, /**/ 0, 4, 5, 6, 1, 2, 3, 7, // + 1, 4, 5, 6, 0, 2, 3, 7, /**/ 0, 1, 4, 5, 6, 2, 3, 7, // + 2, 4, 5, 6, 0, 1, 3, 7, /**/ 0, 2, 4, 5, 6, 1, 3, 7, // + 1, 2, 4, 5, 6, 0, 3, 7, /**/ 0, 1, 2, 4, 5, 6, 3, 7, // + 3, 4, 5, 6, 0, 1, 2, 7, /**/ 0, 3, 4, 5, 6, 1, 2, 7, // + 1, 3, 4, 5, 6, 0, 2, 7, /**/ 0, 1, 3, 4, 5, 6, 2, 7, // + 2, 3, 4, 5, 6, 0, 1, 7, /**/ 0, 2, 3, 4, 5, 6, 1, 7, // + 1, 2, 3, 4, 5, 6, 0, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 7, 0, 1, 2, 3, 4, 5, 6, /**/ 0, 7, 1, 2, 3, 4, 5, 6, // + 1, 7, 0, 2, 3, 4, 5, 6, /**/ 0, 1, 7, 2, 3, 4, 5, 6, // + 2, 7, 0, 1, 3, 4, 5, 6, /**/ 0, 2, 7, 1, 3, 4, 5, 6, // + 1, 2, 7, 0, 3, 4, 5, 6, /**/ 0, 1, 2, 7, 3, 4, 5, 6, // + 3, 7, 0, 1, 2, 4, 5, 6, /**/ 0, 3, 7, 1, 2, 4, 5, 6, // + 1, 3, 7, 0, 2, 4, 5, 6, /**/ 0, 1, 3, 7, 2, 4, 5, 6, // + 2, 3, 7, 0, 1, 4, 5, 6, /**/ 0, 2, 3, 7, 1, 4, 5, 6, // + 1, 2, 3, 7, 0, 4, 5, 6, /**/ 0, 1, 2, 3, 7, 4, 5, 6, // + 4, 7, 0, 1, 2, 3, 5, 6, /**/ 0, 4, 7, 1, 2, 3, 5, 6, // + 1, 4, 7, 0, 2, 3, 5, 6, /**/ 0, 1, 4, 7, 2, 3, 5, 6, // + 2, 4, 7, 0, 1, 3, 5, 6, /**/ 0, 2, 4, 7, 1, 3, 5, 6, // + 1, 2, 4, 7, 0, 3, 5, 6, /**/ 0, 1, 2, 4, 7, 3, 5, 6, // + 3, 4, 7, 0, 1, 2, 5, 6, /**/ 0, 3, 4, 7, 1, 2, 5, 6, // + 1, 3, 4, 7, 0, 2, 5, 6, /**/ 0, 1, 3, 4, 7, 2, 5, 6, // + 2, 3, 4, 7, 0, 1, 5, 6, /**/ 0, 2, 3, 4, 7, 1, 5, 6, // + 1, 2, 3, 4, 7, 0, 5, 6, /**/ 0, 1, 2, 3, 4, 7, 5, 6, // + 5, 7, 0, 1, 2, 3, 4, 6, /**/ 0, 5, 7, 1, 2, 3, 4, 6, // + 1, 5, 7, 0, 2, 3, 4, 6, /**/ 0, 1, 5, 7, 2, 3, 4, 6, // + 2, 5, 7, 0, 1, 3, 4, 6, /**/ 0, 2, 5, 7, 1, 3, 4, 6, // + 1, 2, 5, 7, 0, 3, 4, 6, /**/ 0, 1, 2, 5, 7, 3, 4, 6, // + 3, 5, 7, 0, 1, 2, 4, 6, /**/ 0, 3, 5, 7, 1, 2, 4, 6, // + 1, 3, 5, 7, 0, 2, 4, 6, /**/ 0, 1, 3, 5, 7, 2, 4, 6, // + 2, 3, 5, 7, 0, 1, 4, 6, /**/ 0, 2, 3, 5, 7, 1, 4, 6, // + 1, 2, 3, 5, 7, 0, 4, 6, /**/ 0, 1, 2, 3, 5, 7, 4, 6, // + 4, 5, 7, 0, 1, 2, 3, 6, /**/ 0, 4, 5, 7, 1, 2, 3, 6, // + 1, 4, 5, 7, 0, 2, 3, 6, /**/ 0, 1, 4, 5, 7, 2, 3, 6, // + 2, 4, 5, 7, 0, 1, 3, 6, /**/ 0, 2, 4, 5, 7, 1, 3, 6, // + 1, 2, 4, 5, 7, 0, 3, 6, /**/ 0, 1, 2, 4, 5, 7, 3, 6, // + 3, 4, 5, 7, 0, 1, 2, 6, /**/ 0, 3, 4, 5, 7, 1, 2, 6, // + 1, 3, 4, 5, 7, 0, 2, 6, /**/ 0, 1, 3, 4, 5, 7, 2, 6, // + 2, 3, 4, 5, 7, 0, 1, 6, /**/ 0, 2, 3, 4, 5, 7, 1, 6, // + 1, 2, 3, 4, 5, 7, 0, 6, /**/ 0, 1, 2, 3, 4, 5, 7, 6, // + 6, 7, 0, 1, 2, 3, 4, 5, /**/ 0, 6, 7, 1, 2, 3, 4, 5, // + 
1, 6, 7, 0, 2, 3, 4, 5, /**/ 0, 1, 6, 7, 2, 3, 4, 5, // + 2, 6, 7, 0, 1, 3, 4, 5, /**/ 0, 2, 6, 7, 1, 3, 4, 5, // + 1, 2, 6, 7, 0, 3, 4, 5, /**/ 0, 1, 2, 6, 7, 3, 4, 5, // + 3, 6, 7, 0, 1, 2, 4, 5, /**/ 0, 3, 6, 7, 1, 2, 4, 5, // + 1, 3, 6, 7, 0, 2, 4, 5, /**/ 0, 1, 3, 6, 7, 2, 4, 5, // + 2, 3, 6, 7, 0, 1, 4, 5, /**/ 0, 2, 3, 6, 7, 1, 4, 5, // + 1, 2, 3, 6, 7, 0, 4, 5, /**/ 0, 1, 2, 3, 6, 7, 4, 5, // + 4, 6, 7, 0, 1, 2, 3, 5, /**/ 0, 4, 6, 7, 1, 2, 3, 5, // + 1, 4, 6, 7, 0, 2, 3, 5, /**/ 0, 1, 4, 6, 7, 2, 3, 5, // + 2, 4, 6, 7, 0, 1, 3, 5, /**/ 0, 2, 4, 6, 7, 1, 3, 5, // + 1, 2, 4, 6, 7, 0, 3, 5, /**/ 0, 1, 2, 4, 6, 7, 3, 5, // + 3, 4, 6, 7, 0, 1, 2, 5, /**/ 0, 3, 4, 6, 7, 1, 2, 5, // + 1, 3, 4, 6, 7, 0, 2, 5, /**/ 0, 1, 3, 4, 6, 7, 2, 5, // + 2, 3, 4, 6, 7, 0, 1, 5, /**/ 0, 2, 3, 4, 6, 7, 1, 5, // + 1, 2, 3, 4, 6, 7, 0, 5, /**/ 0, 1, 2, 3, 4, 6, 7, 5, // + 5, 6, 7, 0, 1, 2, 3, 4, /**/ 0, 5, 6, 7, 1, 2, 3, 4, // + 1, 5, 6, 7, 0, 2, 3, 4, /**/ 0, 1, 5, 6, 7, 2, 3, 4, // + 2, 5, 6, 7, 0, 1, 3, 4, /**/ 0, 2, 5, 6, 7, 1, 3, 4, // + 1, 2, 5, 6, 7, 0, 3, 4, /**/ 0, 1, 2, 5, 6, 7, 3, 4, // + 3, 5, 6, 7, 0, 1, 2, 4, /**/ 0, 3, 5, 6, 7, 1, 2, 4, // + 1, 3, 5, 6, 7, 0, 2, 4, /**/ 0, 1, 3, 5, 6, 7, 2, 4, // + 2, 3, 5, 6, 7, 0, 1, 4, /**/ 0, 2, 3, 5, 6, 7, 1, 4, // + 1, 2, 3, 5, 6, 7, 0, 4, /**/ 0, 1, 2, 3, 5, 6, 7, 4, // + 4, 5, 6, 7, 0, 1, 2, 3, /**/ 0, 4, 5, 6, 7, 1, 2, 3, // + 1, 4, 5, 6, 7, 0, 2, 3, /**/ 0, 1, 4, 5, 6, 7, 2, 3, // + 2, 4, 5, 6, 7, 0, 1, 3, /**/ 0, 2, 4, 5, 6, 7, 1, 3, // + 1, 2, 4, 5, 6, 7, 0, 3, /**/ 0, 1, 2, 4, 5, 6, 7, 3, // + 3, 4, 5, 6, 7, 0, 1, 2, /**/ 0, 3, 4, 5, 6, 7, 1, 2, // + 1, 3, 4, 5, 6, 7, 0, 2, /**/ 0, 1, 3, 4, 5, 6, 7, 2, // + 2, 3, 4, 5, 6, 7, 0, 1, /**/ 0, 2, 3, 4, 5, 6, 7, 1, // + 1, 2, 3, 4, 5, 6, 7, 0, /**/ 0, 1, 2, 3, 4, 5, 6, 7}; + + for (size_t i = 0; i < Lanes(d); i += 8) { + // Each byte worth of bits is the index of one of 256 8-byte ranges, and its + // population count determines how far to advance the write position. + const size_t bits8 = bits[i / 8]; + const auto indices = Load(d8, table + bits8 * 8); + const auto compressed = TableLookupBytes(LoadU(d8, lanes + i), indices); + StoreU(compressed, d8, pos); + pos += PopCount(bits8); + } + return static_cast(pos - unaligned); +} + +template +HWY_API size_t CompressStore(V v, M mask, D d, T* HWY_RESTRICT unaligned) { + uint8_t bits[HWY_MAX(size_t{8}, MaxLanes(d) / 8)]; + (void)StoreMaskBits(d, mask, bits); + return CompressBitsStore(v, bits, d, unaligned); +} + +template +HWY_API size_t CompressBlendedStore(V v, M mask, D d, + T* HWY_RESTRICT unaligned) { + HWY_ALIGN T buf[MaxLanes(d)]; + const size_t bytes = CompressStore(v, mask, d, buf); + BlendedStore(Load(d, buf), FirstN(d, bytes), d, unaligned); + return bytes; +} + +// For reasons unknown, HWY_IF_T_SIZE_V is a compile error in SVE. 
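// Scalar reference sketch (illustrative only; ScalarCompressBits8 is a
// hypothetical helper, not part of the public API): models one iteration of
// the CompressBitsStore loop above. Each of the 8 mask bits selects whether
// the corresponding lane is copied, and the write position advances by the
// population count, which is exactly what the 256-entry byte-shuffle table
// achieves without branches.
inline size_t ScalarCompressBits8(const uint8_t lanes[8], uint8_t bits8,
                                  uint8_t* HWY_RESTRICT out) {
  size_t pos = 0;
  for (size_t i = 0; i < 8; ++i) {
    if ((bits8 >> i) & 1u) out[pos++] = lanes[i];  // bit i maps to lane i
  }
  return pos;  // equals PopCount(bits8)
}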
+template , HWY_IF_T_SIZE(T, 1)> +HWY_API V Compress(V v, const M mask) { + const DFromV d; + HWY_ALIGN T lanes[MaxLanes(d)]; + (void)CompressStore(v, mask, d, lanes); + return Load(d, lanes); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + const DFromV d; + HWY_ALIGN T lanes[MaxLanes(d)]; + (void)CompressBitsStore(v, bits, d, lanes); + return Load(d, lanes); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API V CompressNot(V v, M mask) { + return Compress(v, Not(mask)); +} + +#endif // HWY_NATIVE_COMPRESS8 + +// ------------------------------ Expand + +// Note that this generic implementation assumes <= 128 bit fixed vectors; +// the SVE and RVV targets provide their own native implementations. +#if (defined(HWY_NATIVE_EXPAND) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE +#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +namespace detail { + +template +HWY_INLINE Vec128 IndicesForExpandFromBits(uint64_t mask_bits) { + static_assert(N <= 8, "Should only be called for half-vectors"); + const Simd du8; + HWY_DASSERT(mask_bits < 0x100); + alignas(16) static constexpr uint8_t table[2048] = { + // PrintExpand8x8Tables + 128, 128, 128, 128, 128, 128, 128, 128, // + 0, 128, 128, 128, 128, 128, 128, 128, // + 128, 0, 128, 128, 128, 128, 128, 128, // + 0, 1, 128, 128, 128, 128, 128, 128, // + 128, 128, 0, 128, 128, 128, 128, 128, // + 0, 128, 1, 128, 128, 128, 128, 128, // + 128, 0, 1, 128, 128, 128, 128, 128, // + 0, 1, 2, 128, 128, 128, 128, 128, // + 128, 128, 128, 0, 128, 128, 128, 128, // + 0, 128, 128, 1, 128, 128, 128, 128, // + 128, 0, 128, 1, 128, 128, 128, 128, // + 0, 1, 128, 2, 128, 128, 128, 128, // + 128, 128, 0, 1, 128, 128, 128, 128, // + 0, 128, 1, 2, 128, 128, 128, 128, // + 128, 0, 1, 2, 128, 128, 128, 128, // + 0, 1, 2, 3, 128, 128, 128, 128, // + 128, 128, 128, 128, 0, 128, 128, 128, // + 0, 128, 128, 128, 1, 128, 128, 128, // + 128, 0, 128, 128, 1, 128, 128, 128, // + 0, 1, 128, 128, 2, 128, 128, 128, // + 128, 128, 0, 128, 1, 128, 128, 128, // + 0, 128, 1, 128, 2, 128, 128, 128, // + 128, 0, 1, 128, 2, 128, 128, 128, // + 0, 1, 2, 128, 3, 128, 128, 128, // + 128, 128, 128, 0, 1, 128, 128, 128, // + 0, 128, 128, 1, 2, 128, 128, 128, // + 128, 0, 128, 1, 2, 128, 128, 128, // + 0, 1, 128, 2, 3, 128, 128, 128, // + 128, 128, 0, 1, 2, 128, 128, 128, // + 0, 128, 1, 2, 3, 128, 128, 128, // + 128, 0, 1, 2, 3, 128, 128, 128, // + 0, 1, 2, 3, 4, 128, 128, 128, // + 128, 128, 128, 128, 128, 0, 128, 128, // + 0, 128, 128, 128, 128, 1, 128, 128, // + 128, 0, 128, 128, 128, 1, 128, 128, // + 0, 1, 128, 128, 128, 2, 128, 128, // + 128, 128, 0, 128, 128, 1, 128, 128, // + 0, 128, 1, 128, 128, 2, 128, 128, // + 128, 0, 1, 128, 128, 2, 128, 128, // + 0, 1, 2, 128, 128, 3, 128, 128, // + 128, 128, 128, 0, 128, 1, 128, 128, // + 0, 128, 128, 1, 128, 2, 128, 128, // + 128, 0, 128, 1, 128, 2, 128, 128, // + 0, 1, 128, 2, 128, 3, 128, 128, // + 128, 128, 0, 1, 128, 2, 128, 128, // + 0, 128, 1, 2, 128, 3, 128, 128, // + 128, 0, 1, 2, 128, 3, 128, 128, // + 0, 1, 2, 3, 128, 4, 128, 128, // + 128, 128, 128, 128, 0, 1, 128, 128, // + 0, 128, 128, 128, 1, 2, 128, 128, // + 128, 0, 128, 128, 1, 2, 128, 128, // + 0, 1, 128, 128, 2, 3, 128, 128, // + 128, 128, 0, 128, 1, 2, 128, 128, // + 0, 128, 1, 128, 2, 3, 128, 128, // + 128, 0, 1, 128, 2, 3, 128, 128, // + 0, 1, 2, 128, 3, 4, 128, 128, // + 128, 128, 128, 0, 1, 2, 128, 128, // + 0, 128, 128, 1, 2, 3, 128, 128, // + 128, 0, 128, 1, 2, 3, 128, 128, // + 0, 1, 
128, 2, 3, 4, 128, 128, // + 128, 128, 0, 1, 2, 3, 128, 128, // + 0, 128, 1, 2, 3, 4, 128, 128, // + 128, 0, 1, 2, 3, 4, 128, 128, // + 0, 1, 2, 3, 4, 5, 128, 128, // + 128, 128, 128, 128, 128, 128, 0, 128, // + 0, 128, 128, 128, 128, 128, 1, 128, // + 128, 0, 128, 128, 128, 128, 1, 128, // + 0, 1, 128, 128, 128, 128, 2, 128, // + 128, 128, 0, 128, 128, 128, 1, 128, // + 0, 128, 1, 128, 128, 128, 2, 128, // + 128, 0, 1, 128, 128, 128, 2, 128, // + 0, 1, 2, 128, 128, 128, 3, 128, // + 128, 128, 128, 0, 128, 128, 1, 128, // + 0, 128, 128, 1, 128, 128, 2, 128, // + 128, 0, 128, 1, 128, 128, 2, 128, // + 0, 1, 128, 2, 128, 128, 3, 128, // + 128, 128, 0, 1, 128, 128, 2, 128, // + 0, 128, 1, 2, 128, 128, 3, 128, // + 128, 0, 1, 2, 128, 128, 3, 128, // + 0, 1, 2, 3, 128, 128, 4, 128, // + 128, 128, 128, 128, 0, 128, 1, 128, // + 0, 128, 128, 128, 1, 128, 2, 128, // + 128, 0, 128, 128, 1, 128, 2, 128, // + 0, 1, 128, 128, 2, 128, 3, 128, // + 128, 128, 0, 128, 1, 128, 2, 128, // + 0, 128, 1, 128, 2, 128, 3, 128, // + 128, 0, 1, 128, 2, 128, 3, 128, // + 0, 1, 2, 128, 3, 128, 4, 128, // + 128, 128, 128, 0, 1, 128, 2, 128, // + 0, 128, 128, 1, 2, 128, 3, 128, // + 128, 0, 128, 1, 2, 128, 3, 128, // + 0, 1, 128, 2, 3, 128, 4, 128, // + 128, 128, 0, 1, 2, 128, 3, 128, // + 0, 128, 1, 2, 3, 128, 4, 128, // + 128, 0, 1, 2, 3, 128, 4, 128, // + 0, 1, 2, 3, 4, 128, 5, 128, // + 128, 128, 128, 128, 128, 0, 1, 128, // + 0, 128, 128, 128, 128, 1, 2, 128, // + 128, 0, 128, 128, 128, 1, 2, 128, // + 0, 1, 128, 128, 128, 2, 3, 128, // + 128, 128, 0, 128, 128, 1, 2, 128, // + 0, 128, 1, 128, 128, 2, 3, 128, // + 128, 0, 1, 128, 128, 2, 3, 128, // + 0, 1, 2, 128, 128, 3, 4, 128, // + 128, 128, 128, 0, 128, 1, 2, 128, // + 0, 128, 128, 1, 128, 2, 3, 128, // + 128, 0, 128, 1, 128, 2, 3, 128, // + 0, 1, 128, 2, 128, 3, 4, 128, // + 128, 128, 0, 1, 128, 2, 3, 128, // + 0, 128, 1, 2, 128, 3, 4, 128, // + 128, 0, 1, 2, 128, 3, 4, 128, // + 0, 1, 2, 3, 128, 4, 5, 128, // + 128, 128, 128, 128, 0, 1, 2, 128, // + 0, 128, 128, 128, 1, 2, 3, 128, // + 128, 0, 128, 128, 1, 2, 3, 128, // + 0, 1, 128, 128, 2, 3, 4, 128, // + 128, 128, 0, 128, 1, 2, 3, 128, // + 0, 128, 1, 128, 2, 3, 4, 128, // + 128, 0, 1, 128, 2, 3, 4, 128, // + 0, 1, 2, 128, 3, 4, 5, 128, // + 128, 128, 128, 0, 1, 2, 3, 128, // + 0, 128, 128, 1, 2, 3, 4, 128, // + 128, 0, 128, 1, 2, 3, 4, 128, // + 0, 1, 128, 2, 3, 4, 5, 128, // + 128, 128, 0, 1, 2, 3, 4, 128, // + 0, 128, 1, 2, 3, 4, 5, 128, // + 128, 0, 1, 2, 3, 4, 5, 128, // + 0, 1, 2, 3, 4, 5, 6, 128, // + 128, 128, 128, 128, 128, 128, 128, 0, // + 0, 128, 128, 128, 128, 128, 128, 1, // + 128, 0, 128, 128, 128, 128, 128, 1, // + 0, 1, 128, 128, 128, 128, 128, 2, // + 128, 128, 0, 128, 128, 128, 128, 1, // + 0, 128, 1, 128, 128, 128, 128, 2, // + 128, 0, 1, 128, 128, 128, 128, 2, // + 0, 1, 2, 128, 128, 128, 128, 3, // + 128, 128, 128, 0, 128, 128, 128, 1, // + 0, 128, 128, 1, 128, 128, 128, 2, // + 128, 0, 128, 1, 128, 128, 128, 2, // + 0, 1, 128, 2, 128, 128, 128, 3, // + 128, 128, 0, 1, 128, 128, 128, 2, // + 0, 128, 1, 2, 128, 128, 128, 3, // + 128, 0, 1, 2, 128, 128, 128, 3, // + 0, 1, 2, 3, 128, 128, 128, 4, // + 128, 128, 128, 128, 0, 128, 128, 1, // + 0, 128, 128, 128, 1, 128, 128, 2, // + 128, 0, 128, 128, 1, 128, 128, 2, // + 0, 1, 128, 128, 2, 128, 128, 3, // + 128, 128, 0, 128, 1, 128, 128, 2, // + 0, 128, 1, 128, 2, 128, 128, 3, // + 128, 0, 1, 128, 2, 128, 128, 3, // + 0, 1, 2, 128, 3, 128, 128, 4, // + 128, 128, 128, 0, 1, 128, 128, 2, // + 0, 128, 128, 1, 2, 128, 128, 3, // + 128, 0, 128, 
1, 2, 128, 128, 3, // + 0, 1, 128, 2, 3, 128, 128, 4, // + 128, 128, 0, 1, 2, 128, 128, 3, // + 0, 128, 1, 2, 3, 128, 128, 4, // + 128, 0, 1, 2, 3, 128, 128, 4, // + 0, 1, 2, 3, 4, 128, 128, 5, // + 128, 128, 128, 128, 128, 0, 128, 1, // + 0, 128, 128, 128, 128, 1, 128, 2, // + 128, 0, 128, 128, 128, 1, 128, 2, // + 0, 1, 128, 128, 128, 2, 128, 3, // + 128, 128, 0, 128, 128, 1, 128, 2, // + 0, 128, 1, 128, 128, 2, 128, 3, // + 128, 0, 1, 128, 128, 2, 128, 3, // + 0, 1, 2, 128, 128, 3, 128, 4, // + 128, 128, 128, 0, 128, 1, 128, 2, // + 0, 128, 128, 1, 128, 2, 128, 3, // + 128, 0, 128, 1, 128, 2, 128, 3, // + 0, 1, 128, 2, 128, 3, 128, 4, // + 128, 128, 0, 1, 128, 2, 128, 3, // + 0, 128, 1, 2, 128, 3, 128, 4, // + 128, 0, 1, 2, 128, 3, 128, 4, // + 0, 1, 2, 3, 128, 4, 128, 5, // + 128, 128, 128, 128, 0, 1, 128, 2, // + 0, 128, 128, 128, 1, 2, 128, 3, // + 128, 0, 128, 128, 1, 2, 128, 3, // + 0, 1, 128, 128, 2, 3, 128, 4, // + 128, 128, 0, 128, 1, 2, 128, 3, // + 0, 128, 1, 128, 2, 3, 128, 4, // + 128, 0, 1, 128, 2, 3, 128, 4, // + 0, 1, 2, 128, 3, 4, 128, 5, // + 128, 128, 128, 0, 1, 2, 128, 3, // + 0, 128, 128, 1, 2, 3, 128, 4, // + 128, 0, 128, 1, 2, 3, 128, 4, // + 0, 1, 128, 2, 3, 4, 128, 5, // + 128, 128, 0, 1, 2, 3, 128, 4, // + 0, 128, 1, 2, 3, 4, 128, 5, // + 128, 0, 1, 2, 3, 4, 128, 5, // + 0, 1, 2, 3, 4, 5, 128, 6, // + 128, 128, 128, 128, 128, 128, 0, 1, // + 0, 128, 128, 128, 128, 128, 1, 2, // + 128, 0, 128, 128, 128, 128, 1, 2, // + 0, 1, 128, 128, 128, 128, 2, 3, // + 128, 128, 0, 128, 128, 128, 1, 2, // + 0, 128, 1, 128, 128, 128, 2, 3, // + 128, 0, 1, 128, 128, 128, 2, 3, // + 0, 1, 2, 128, 128, 128, 3, 4, // + 128, 128, 128, 0, 128, 128, 1, 2, // + 0, 128, 128, 1, 128, 128, 2, 3, // + 128, 0, 128, 1, 128, 128, 2, 3, // + 0, 1, 128, 2, 128, 128, 3, 4, // + 128, 128, 0, 1, 128, 128, 2, 3, // + 0, 128, 1, 2, 128, 128, 3, 4, // + 128, 0, 1, 2, 128, 128, 3, 4, // + 0, 1, 2, 3, 128, 128, 4, 5, // + 128, 128, 128, 128, 0, 128, 1, 2, // + 0, 128, 128, 128, 1, 128, 2, 3, // + 128, 0, 128, 128, 1, 128, 2, 3, // + 0, 1, 128, 128, 2, 128, 3, 4, // + 128, 128, 0, 128, 1, 128, 2, 3, // + 0, 128, 1, 128, 2, 128, 3, 4, // + 128, 0, 1, 128, 2, 128, 3, 4, // + 0, 1, 2, 128, 3, 128, 4, 5, // + 128, 128, 128, 0, 1, 128, 2, 3, // + 0, 128, 128, 1, 2, 128, 3, 4, // + 128, 0, 128, 1, 2, 128, 3, 4, // + 0, 1, 128, 2, 3, 128, 4, 5, // + 128, 128, 0, 1, 2, 128, 3, 4, // + 0, 128, 1, 2, 3, 128, 4, 5, // + 128, 0, 1, 2, 3, 128, 4, 5, // + 0, 1, 2, 3, 4, 128, 5, 6, // + 128, 128, 128, 128, 128, 0, 1, 2, // + 0, 128, 128, 128, 128, 1, 2, 3, // + 128, 0, 128, 128, 128, 1, 2, 3, // + 0, 1, 128, 128, 128, 2, 3, 4, // + 128, 128, 0, 128, 128, 1, 2, 3, // + 0, 128, 1, 128, 128, 2, 3, 4, // + 128, 0, 1, 128, 128, 2, 3, 4, // + 0, 1, 2, 128, 128, 3, 4, 5, // + 128, 128, 128, 0, 128, 1, 2, 3, // + 0, 128, 128, 1, 128, 2, 3, 4, // + 128, 0, 128, 1, 128, 2, 3, 4, // + 0, 1, 128, 2, 128, 3, 4, 5, // + 128, 128, 0, 1, 128, 2, 3, 4, // + 0, 128, 1, 2, 128, 3, 4, 5, // + 128, 0, 1, 2, 128, 3, 4, 5, // + 0, 1, 2, 3, 128, 4, 5, 6, // + 128, 128, 128, 128, 0, 1, 2, 3, // + 0, 128, 128, 128, 1, 2, 3, 4, // + 128, 0, 128, 128, 1, 2, 3, 4, // + 0, 1, 128, 128, 2, 3, 4, 5, // + 128, 128, 0, 128, 1, 2, 3, 4, // + 0, 128, 1, 128, 2, 3, 4, 5, // + 128, 0, 1, 128, 2, 3, 4, 5, // + 0, 1, 2, 128, 3, 4, 5, 6, // + 128, 128, 128, 0, 1, 2, 3, 4, // + 0, 128, 128, 1, 2, 3, 4, 5, // + 128, 0, 128, 1, 2, 3, 4, 5, // + 0, 1, 128, 2, 3, 4, 5, 6, // + 128, 128, 0, 1, 2, 3, 4, 5, // + 0, 128, 1, 2, 3, 4, 5, 6, // + 128, 0, 1, 2, 3, 4, 5, 
6, // + 0, 1, 2, 3, 4, 5, 6, 7}; + return LoadU(du8, table + mask_bits * 8); +} + +} // namespace detail + +// Half vector of bytes: one table lookup +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + + const uint64_t mask_bits = BitsFromMask(d, mask); + const Vec128 indices = + detail::IndicesForExpandFromBits(mask_bits); + return BitCast(d, TableLookupBytesOr0(v, indices)); +} + +// Full vector of bytes: two table lookups +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const Full128 d; + const RebindToUnsigned du; + const Half duh; + const Vec128 vu = BitCast(du, v); + + const uint64_t mask_bits = BitsFromMask(d, mask); + const uint64_t maskL = mask_bits & 0xFF; + const uint64_t maskH = mask_bits >> 8; + + // We want to skip past the v bytes already consumed by idxL. There is no + // instruction for shift-reg by variable bytes. Storing v itself would work + // but would involve a store-load forwarding stall. We instead shuffle using + // loaded indices. + // TODO: MultiRotateRight would also help, but if we have that, we probably + // also have native 8-bit Expand? + alignas(16) static constexpr uint8_t iota[32] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128}; + const VFromD shift = LoadU(du, iota + PopCount(maskL)); + const VFromD vL = LowerHalf(duh, vu); + const VFromD vH = + LowerHalf(duh, TableLookupBytesOr0(vu, shift)); + + const VFromD idxL = detail::IndicesForExpandFromBits<8>(maskL); + const VFromD idxH = detail::IndicesForExpandFromBits<8>(maskH); + + const VFromD expandL = TableLookupBytesOr0(vL, idxL); + const VFromD expandH = TableLookupBytesOr0(vH, idxH); + return BitCast(d, Combine(du, expandH, expandL)); +} + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + const RebindToUnsigned du; + + const Rebind du8; + const uint64_t mask_bits = BitsFromMask(d, mask); + + // Storing as 8-bit reduces table size from 4 KiB to 2 KiB. We cannot apply + // the nibble trick used below because not all indices fit within one lane. 
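  // Layout sketch (illustrative): each row of the table below stores, per
  // destination u16 lane, the byte offset of the corresponding source lane's
  // first byte (0, 2, 4, ...), with 128 marking a lane that must become zero.
  // The interleave-and-add step after the table turns each offset into the
  // byte pair {offset, offset + 1} expected by TableLookupBytesOr0.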
+ alignas(16) static constexpr uint8_t table[2048] = { + // PrintExpand16x8ByteTables + 128, 128, 128, 128, 128, 128, 128, 128, // + 0, 128, 128, 128, 128, 128, 128, 128, // + 128, 0, 128, 128, 128, 128, 128, 128, // + 0, 2, 128, 128, 128, 128, 128, 128, // + 128, 128, 0, 128, 128, 128, 128, 128, // + 0, 128, 2, 128, 128, 128, 128, 128, // + 128, 0, 2, 128, 128, 128, 128, 128, // + 0, 2, 4, 128, 128, 128, 128, 128, // + 128, 128, 128, 0, 128, 128, 128, 128, // + 0, 128, 128, 2, 128, 128, 128, 128, // + 128, 0, 128, 2, 128, 128, 128, 128, // + 0, 2, 128, 4, 128, 128, 128, 128, // + 128, 128, 0, 2, 128, 128, 128, 128, // + 0, 128, 2, 4, 128, 128, 128, 128, // + 128, 0, 2, 4, 128, 128, 128, 128, // + 0, 2, 4, 6, 128, 128, 128, 128, // + 128, 128, 128, 128, 0, 128, 128, 128, // + 0, 128, 128, 128, 2, 128, 128, 128, // + 128, 0, 128, 128, 2, 128, 128, 128, // + 0, 2, 128, 128, 4, 128, 128, 128, // + 128, 128, 0, 128, 2, 128, 128, 128, // + 0, 128, 2, 128, 4, 128, 128, 128, // + 128, 0, 2, 128, 4, 128, 128, 128, // + 0, 2, 4, 128, 6, 128, 128, 128, // + 128, 128, 128, 0, 2, 128, 128, 128, // + 0, 128, 128, 2, 4, 128, 128, 128, // + 128, 0, 128, 2, 4, 128, 128, 128, // + 0, 2, 128, 4, 6, 128, 128, 128, // + 128, 128, 0, 2, 4, 128, 128, 128, // + 0, 128, 2, 4, 6, 128, 128, 128, // + 128, 0, 2, 4, 6, 128, 128, 128, // + 0, 2, 4, 6, 8, 128, 128, 128, // + 128, 128, 128, 128, 128, 0, 128, 128, // + 0, 128, 128, 128, 128, 2, 128, 128, // + 128, 0, 128, 128, 128, 2, 128, 128, // + 0, 2, 128, 128, 128, 4, 128, 128, // + 128, 128, 0, 128, 128, 2, 128, 128, // + 0, 128, 2, 128, 128, 4, 128, 128, // + 128, 0, 2, 128, 128, 4, 128, 128, // + 0, 2, 4, 128, 128, 6, 128, 128, // + 128, 128, 128, 0, 128, 2, 128, 128, // + 0, 128, 128, 2, 128, 4, 128, 128, // + 128, 0, 128, 2, 128, 4, 128, 128, // + 0, 2, 128, 4, 128, 6, 128, 128, // + 128, 128, 0, 2, 128, 4, 128, 128, // + 0, 128, 2, 4, 128, 6, 128, 128, // + 128, 0, 2, 4, 128, 6, 128, 128, // + 0, 2, 4, 6, 128, 8, 128, 128, // + 128, 128, 128, 128, 0, 2, 128, 128, // + 0, 128, 128, 128, 2, 4, 128, 128, // + 128, 0, 128, 128, 2, 4, 128, 128, // + 0, 2, 128, 128, 4, 6, 128, 128, // + 128, 128, 0, 128, 2, 4, 128, 128, // + 0, 128, 2, 128, 4, 6, 128, 128, // + 128, 0, 2, 128, 4, 6, 128, 128, // + 0, 2, 4, 128, 6, 8, 128, 128, // + 128, 128, 128, 0, 2, 4, 128, 128, // + 0, 128, 128, 2, 4, 6, 128, 128, // + 128, 0, 128, 2, 4, 6, 128, 128, // + 0, 2, 128, 4, 6, 8, 128, 128, // + 128, 128, 0, 2, 4, 6, 128, 128, // + 0, 128, 2, 4, 6, 8, 128, 128, // + 128, 0, 2, 4, 6, 8, 128, 128, // + 0, 2, 4, 6, 8, 10, 128, 128, // + 128, 128, 128, 128, 128, 128, 0, 128, // + 0, 128, 128, 128, 128, 128, 2, 128, // + 128, 0, 128, 128, 128, 128, 2, 128, // + 0, 2, 128, 128, 128, 128, 4, 128, // + 128, 128, 0, 128, 128, 128, 2, 128, // + 0, 128, 2, 128, 128, 128, 4, 128, // + 128, 0, 2, 128, 128, 128, 4, 128, // + 0, 2, 4, 128, 128, 128, 6, 128, // + 128, 128, 128, 0, 128, 128, 2, 128, // + 0, 128, 128, 2, 128, 128, 4, 128, // + 128, 0, 128, 2, 128, 128, 4, 128, // + 0, 2, 128, 4, 128, 128, 6, 128, // + 128, 128, 0, 2, 128, 128, 4, 128, // + 0, 128, 2, 4, 128, 128, 6, 128, // + 128, 0, 2, 4, 128, 128, 6, 128, // + 0, 2, 4, 6, 128, 128, 8, 128, // + 128, 128, 128, 128, 0, 128, 2, 128, // + 0, 128, 128, 128, 2, 128, 4, 128, // + 128, 0, 128, 128, 2, 128, 4, 128, // + 0, 2, 128, 128, 4, 128, 6, 128, // + 128, 128, 0, 128, 2, 128, 4, 128, // + 0, 128, 2, 128, 4, 128, 6, 128, // + 128, 0, 2, 128, 4, 128, 6, 128, // + 0, 2, 4, 128, 6, 128, 8, 128, // + 128, 128, 128, 0, 2, 128, 4, 128, // + 0, 
128, 128, 2, 4, 128, 6, 128, // + 128, 0, 128, 2, 4, 128, 6, 128, // + 0, 2, 128, 4, 6, 128, 8, 128, // + 128, 128, 0, 2, 4, 128, 6, 128, // + 0, 128, 2, 4, 6, 128, 8, 128, // + 128, 0, 2, 4, 6, 128, 8, 128, // + 0, 2, 4, 6, 8, 128, 10, 128, // + 128, 128, 128, 128, 128, 0, 2, 128, // + 0, 128, 128, 128, 128, 2, 4, 128, // + 128, 0, 128, 128, 128, 2, 4, 128, // + 0, 2, 128, 128, 128, 4, 6, 128, // + 128, 128, 0, 128, 128, 2, 4, 128, // + 0, 128, 2, 128, 128, 4, 6, 128, // + 128, 0, 2, 128, 128, 4, 6, 128, // + 0, 2, 4, 128, 128, 6, 8, 128, // + 128, 128, 128, 0, 128, 2, 4, 128, // + 0, 128, 128, 2, 128, 4, 6, 128, // + 128, 0, 128, 2, 128, 4, 6, 128, // + 0, 2, 128, 4, 128, 6, 8, 128, // + 128, 128, 0, 2, 128, 4, 6, 128, // + 0, 128, 2, 4, 128, 6, 8, 128, // + 128, 0, 2, 4, 128, 6, 8, 128, // + 0, 2, 4, 6, 128, 8, 10, 128, // + 128, 128, 128, 128, 0, 2, 4, 128, // + 0, 128, 128, 128, 2, 4, 6, 128, // + 128, 0, 128, 128, 2, 4, 6, 128, // + 0, 2, 128, 128, 4, 6, 8, 128, // + 128, 128, 0, 128, 2, 4, 6, 128, // + 0, 128, 2, 128, 4, 6, 8, 128, // + 128, 0, 2, 128, 4, 6, 8, 128, // + 0, 2, 4, 128, 6, 8, 10, 128, // + 128, 128, 128, 0, 2, 4, 6, 128, // + 0, 128, 128, 2, 4, 6, 8, 128, // + 128, 0, 128, 2, 4, 6, 8, 128, // + 0, 2, 128, 4, 6, 8, 10, 128, // + 128, 128, 0, 2, 4, 6, 8, 128, // + 0, 128, 2, 4, 6, 8, 10, 128, // + 128, 0, 2, 4, 6, 8, 10, 128, // + 0, 2, 4, 6, 8, 10, 12, 128, // + 128, 128, 128, 128, 128, 128, 128, 0, // + 0, 128, 128, 128, 128, 128, 128, 2, // + 128, 0, 128, 128, 128, 128, 128, 2, // + 0, 2, 128, 128, 128, 128, 128, 4, // + 128, 128, 0, 128, 128, 128, 128, 2, // + 0, 128, 2, 128, 128, 128, 128, 4, // + 128, 0, 2, 128, 128, 128, 128, 4, // + 0, 2, 4, 128, 128, 128, 128, 6, // + 128, 128, 128, 0, 128, 128, 128, 2, // + 0, 128, 128, 2, 128, 128, 128, 4, // + 128, 0, 128, 2, 128, 128, 128, 4, // + 0, 2, 128, 4, 128, 128, 128, 6, // + 128, 128, 0, 2, 128, 128, 128, 4, // + 0, 128, 2, 4, 128, 128, 128, 6, // + 128, 0, 2, 4, 128, 128, 128, 6, // + 0, 2, 4, 6, 128, 128, 128, 8, // + 128, 128, 128, 128, 0, 128, 128, 2, // + 0, 128, 128, 128, 2, 128, 128, 4, // + 128, 0, 128, 128, 2, 128, 128, 4, // + 0, 2, 128, 128, 4, 128, 128, 6, // + 128, 128, 0, 128, 2, 128, 128, 4, // + 0, 128, 2, 128, 4, 128, 128, 6, // + 128, 0, 2, 128, 4, 128, 128, 6, // + 0, 2, 4, 128, 6, 128, 128, 8, // + 128, 128, 128, 0, 2, 128, 128, 4, // + 0, 128, 128, 2, 4, 128, 128, 6, // + 128, 0, 128, 2, 4, 128, 128, 6, // + 0, 2, 128, 4, 6, 128, 128, 8, // + 128, 128, 0, 2, 4, 128, 128, 6, // + 0, 128, 2, 4, 6, 128, 128, 8, // + 128, 0, 2, 4, 6, 128, 128, 8, // + 0, 2, 4, 6, 8, 128, 128, 10, // + 128, 128, 128, 128, 128, 0, 128, 2, // + 0, 128, 128, 128, 128, 2, 128, 4, // + 128, 0, 128, 128, 128, 2, 128, 4, // + 0, 2, 128, 128, 128, 4, 128, 6, // + 128, 128, 0, 128, 128, 2, 128, 4, // + 0, 128, 2, 128, 128, 4, 128, 6, // + 128, 0, 2, 128, 128, 4, 128, 6, // + 0, 2, 4, 128, 128, 6, 128, 8, // + 128, 128, 128, 0, 128, 2, 128, 4, // + 0, 128, 128, 2, 128, 4, 128, 6, // + 128, 0, 128, 2, 128, 4, 128, 6, // + 0, 2, 128, 4, 128, 6, 128, 8, // + 128, 128, 0, 2, 128, 4, 128, 6, // + 0, 128, 2, 4, 128, 6, 128, 8, // + 128, 0, 2, 4, 128, 6, 128, 8, // + 0, 2, 4, 6, 128, 8, 128, 10, // + 128, 128, 128, 128, 0, 2, 128, 4, // + 0, 128, 128, 128, 2, 4, 128, 6, // + 128, 0, 128, 128, 2, 4, 128, 6, // + 0, 2, 128, 128, 4, 6, 128, 8, // + 128, 128, 0, 128, 2, 4, 128, 6, // + 0, 128, 2, 128, 4, 6, 128, 8, // + 128, 0, 2, 128, 4, 6, 128, 8, // + 0, 2, 4, 128, 6, 8, 128, 10, // + 128, 128, 128, 0, 2, 4, 128, 6, // + 0, 128, 128, 
2, 4, 6, 128, 8, // + 128, 0, 128, 2, 4, 6, 128, 8, // + 0, 2, 128, 4, 6, 8, 128, 10, // + 128, 128, 0, 2, 4, 6, 128, 8, // + 0, 128, 2, 4, 6, 8, 128, 10, // + 128, 0, 2, 4, 6, 8, 128, 10, // + 0, 2, 4, 6, 8, 10, 128, 12, // + 128, 128, 128, 128, 128, 128, 0, 2, // + 0, 128, 128, 128, 128, 128, 2, 4, // + 128, 0, 128, 128, 128, 128, 2, 4, // + 0, 2, 128, 128, 128, 128, 4, 6, // + 128, 128, 0, 128, 128, 128, 2, 4, // + 0, 128, 2, 128, 128, 128, 4, 6, // + 128, 0, 2, 128, 128, 128, 4, 6, // + 0, 2, 4, 128, 128, 128, 6, 8, // + 128, 128, 128, 0, 128, 128, 2, 4, // + 0, 128, 128, 2, 128, 128, 4, 6, // + 128, 0, 128, 2, 128, 128, 4, 6, // + 0, 2, 128, 4, 128, 128, 6, 8, // + 128, 128, 0, 2, 128, 128, 4, 6, // + 0, 128, 2, 4, 128, 128, 6, 8, // + 128, 0, 2, 4, 128, 128, 6, 8, // + 0, 2, 4, 6, 128, 128, 8, 10, // + 128, 128, 128, 128, 0, 128, 2, 4, // + 0, 128, 128, 128, 2, 128, 4, 6, // + 128, 0, 128, 128, 2, 128, 4, 6, // + 0, 2, 128, 128, 4, 128, 6, 8, // + 128, 128, 0, 128, 2, 128, 4, 6, // + 0, 128, 2, 128, 4, 128, 6, 8, // + 128, 0, 2, 128, 4, 128, 6, 8, // + 0, 2, 4, 128, 6, 128, 8, 10, // + 128, 128, 128, 0, 2, 128, 4, 6, // + 0, 128, 128, 2, 4, 128, 6, 8, // + 128, 0, 128, 2, 4, 128, 6, 8, // + 0, 2, 128, 4, 6, 128, 8, 10, // + 128, 128, 0, 2, 4, 128, 6, 8, // + 0, 128, 2, 4, 6, 128, 8, 10, // + 128, 0, 2, 4, 6, 128, 8, 10, // + 0, 2, 4, 6, 8, 128, 10, 12, // + 128, 128, 128, 128, 128, 0, 2, 4, // + 0, 128, 128, 128, 128, 2, 4, 6, // + 128, 0, 128, 128, 128, 2, 4, 6, // + 0, 2, 128, 128, 128, 4, 6, 8, // + 128, 128, 0, 128, 128, 2, 4, 6, // + 0, 128, 2, 128, 128, 4, 6, 8, // + 128, 0, 2, 128, 128, 4, 6, 8, // + 0, 2, 4, 128, 128, 6, 8, 10, // + 128, 128, 128, 0, 128, 2, 4, 6, // + 0, 128, 128, 2, 128, 4, 6, 8, // + 128, 0, 128, 2, 128, 4, 6, 8, // + 0, 2, 128, 4, 128, 6, 8, 10, // + 128, 128, 0, 2, 128, 4, 6, 8, // + 0, 128, 2, 4, 128, 6, 8, 10, // + 128, 0, 2, 4, 128, 6, 8, 10, // + 0, 2, 4, 6, 128, 8, 10, 12, // + 128, 128, 128, 128, 0, 2, 4, 6, // + 0, 128, 128, 128, 2, 4, 6, 8, // + 128, 0, 128, 128, 2, 4, 6, 8, // + 0, 2, 128, 128, 4, 6, 8, 10, // + 128, 128, 0, 128, 2, 4, 6, 8, // + 0, 128, 2, 128, 4, 6, 8, 10, // + 128, 0, 2, 128, 4, 6, 8, 10, // + 0, 2, 4, 128, 6, 8, 10, 12, // + 128, 128, 128, 0, 2, 4, 6, 8, // + 0, 128, 128, 2, 4, 6, 8, 10, // + 128, 0, 128, 2, 4, 6, 8, 10, // + 0, 2, 128, 4, 6, 8, 10, 12, // + 128, 128, 0, 2, 4, 6, 8, 10, // + 0, 128, 2, 4, 6, 8, 10, 12, // + 128, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14}; + // Extend to double length because InterleaveLower will only use the (valid) + // lower half, and we want N u16. + const Twice du8x2; + const Vec128 indices8 = + ZeroExtendVector(du8x2, Load(du8, table + mask_bits * 8)); + const Vec128 indices16 = + BitCast(du, InterleaveLower(du8x2, indices8, indices8)); + // TableLookupBytesOr0 operates on bytes. To convert u16 lane indices to byte + // indices, add 0 to even and 1 to odd byte lanes. + const Vec128 byte_indices = Add( + indices16, + Set(du, static_cast(HWY_IS_LITTLE_ENDIAN ? 0x0100 : 0x0001))); + return BitCast(d, TableLookupBytesOr0(v, byte_indices)); +} + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, mask); + + alignas(16) static constexpr uint32_t packed_array[16] = { + // PrintExpand64x4Nibble - same for 32x4. 
+ 0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0, + 0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10, + 0x000010ff, 0x000021f0, 0x0000210f, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2). + const Vec128 packed = Set(du, packed_array[mask_bits]); + alignas(16) static constexpr uint32_t shifts[4] = {0, 4, 8, 12}; + Vec128 indices = packed >> Load(du, shifts); + // AVX2 _mm256_permutexvar_epi32 will ignore upper bits, but IndicesFromVec + // checks bounds, so clear the upper bits. + indices = And(indices, Set(du, N - 1)); + const Vec128 expand = + TableLookupLanes(BitCast(du, v), IndicesFromVec(du, indices)); + // TableLookupLanes cannot also zero masked-off lanes, so do that now. + return IfThenElseZero(mask, BitCast(d, expand)); +} + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + // Same as Compress, just zero out the mask=false lanes. + return IfThenElseZero(mask, Compress(v, mask)); +} + +// For single-element vectors, this is at least as fast as native. +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + return IfThenElseZero(mask, v); +} + +// ------------------------------ LoadExpand +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} + +#endif // HWY_NATIVE_EXPAND + +// ------------------------------ TwoTablesLookupLanes + +template +using IndicesFromD = decltype(IndicesFromVec(D(), Zero(RebindToUnsigned()))); + +// RVV/SVE have their own implementations of +// TwoTablesLookupLanes(D d, VFromD a, VFromD b, IndicesFromD idx) +#if HWY_TARGET != HWY_RVV && !HWY_TARGET_IS_SVE +template +HWY_API VFromD TwoTablesLookupLanes(D /*d*/, VFromD a, VFromD b, + IndicesFromD idx) { + return TwoTablesLookupLanes(a, b, idx); +} +#endif + +// ------------------------------ Reverse2, Reverse4, Reverse8 (8-bit) + +#if (defined(HWY_NATIVE_REVERSE2_8) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE +#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +#undef HWY_PREFER_ROTATE +// Platforms on which RotateRight is likely faster than TableLookupBytes. +// RVV and SVE anyway have their own implementation of this. +#if HWY_TARGET == HWY_SSE2 || HWY_TARGET <= HWY_AVX3 || \ + HWY_TARGET == HWY_WASM || HWY_TARGET == HWY_PPC8 +#define HWY_PREFER_ROTATE 1 +#else +#define HWY_PREFER_ROTATE 0 +#endif + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + // Exclude AVX3 because its 16-bit RotateRight is actually 3 instructions. 
+#if HWY_PREFER_ROTATE && HWY_TARGET > HWY_AVX3 + const Repartition du16; + return BitCast(d, RotateRight<8>(BitCast(du16, v))); +#else + const VFromD shuffle = Dup128VecFromValues(d, 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, + 11, 10, 13, 12, 15, 14); + return TableLookupBytes(v, shuffle); +#endif +} + +template +HWY_API VFromD Reverse4(D d, VFromD v) { +#if HWY_PREFER_ROTATE + const Repartition du16; + return BitCast(d, Reverse2(du16, BitCast(du16, Reverse2(d, v)))); +#else + const Repartition du8; + const VFromD shuffle = Dup128VecFromValues( + du8, 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12); + return TableLookupBytes(v, BitCast(d, shuffle)); +#endif +} + +template +HWY_API VFromD Reverse8(D d, VFromD v) { +#if HWY_PREFER_ROTATE + const Repartition du32; + return BitCast(d, Reverse2(du32, BitCast(du32, Reverse4(d, v)))); +#else + const Repartition du8; + const VFromD shuffle = Dup128VecFromValues( + du8, 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8); + return TableLookupBytes(v, BitCast(d, shuffle)); +#endif +} + +#endif // HWY_NATIVE_REVERSE2_8 + +// ------------------------------ ReverseLaneBytes + +#if (defined(HWY_NATIVE_REVERSE_LANE_BYTES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REVERSE_LANE_BYTES +#undef HWY_NATIVE_REVERSE_LANE_BYTES +#else +#define HWY_NATIVE_REVERSE_LANE_BYTES +#endif + +template +HWY_API V ReverseLaneBytes(V v) { + const DFromV d; + const Repartition du8; + return BitCast(d, Reverse2(du8, BitCast(du8, v))); +} + +template +HWY_API V ReverseLaneBytes(V v) { + const DFromV d; + const Repartition du8; + return BitCast(d, Reverse4(du8, BitCast(du8, v))); +} + +template +HWY_API V ReverseLaneBytes(V v) { + const DFromV d; + const Repartition du8; + return BitCast(d, Reverse8(du8, BitCast(du8, v))); +} + +#endif // HWY_NATIVE_REVERSE_LANE_BYTES + +// ------------------------------ ReverseBits + +// On these targets, we emulate 8-bit shifts using 16-bit shifts and therefore +// require at least two lanes to BitCast to 16-bit. We avoid Highway's 8-bit +// shifts because those would add extra masking already taken care of by +// UI8ReverseBitsStep. Note that AVX3_DL/AVX3_ZEN4 support GFNI and use it to +// implement ReverseBits, so this code is not used there. 
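Aside (not part of the patch): the masked-shift scheme described in the comment above, and applied per 8-bit lane by the UI8ReverseBitsStep chain that follows, is the classic three-step reversal with (shift, mask) pairs (1, 0x55), (2, 0x33), (4, 0x0F). A minimal standalone scalar sketch; the function name and test value are illustrative only:

  // Standalone scalar sketch (not patch content).
  #include <cstdint>
  #include <cstdio>

  static uint8_t ReverseBits8(uint8_t x) {
    x = static_cast<uint8_t>(((x >> 1) & 0x55) | ((x & 0x55) << 1));  // swap adjacent bits
    x = static_cast<uint8_t>(((x >> 2) & 0x33) | ((x & 0x33) << 2));  // swap 2-bit pairs
    x = static_cast<uint8_t>(((x >> 4) & 0x0F) | ((x & 0x0F) << 4));  // swap nibbles
    return x;
  }

  int main() {
    std::printf("%02X\n", ReverseBits8(0xB1));  // 10110001 -> 10001101 = 8D
    return 0;
  }

On targets without native 8-bit shifts, the vector code below performs the same steps through 16-bit shifts, with the masking in UI8ReverseBitsStep discarding the bits that would spill across byte boundaries.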
+#undef HWY_REVERSE_BITS_MIN_BYTES +#if ((HWY_TARGET >= HWY_AVX3 && HWY_TARGET <= HWY_SSE2) || \ + HWY_TARGET == HWY_WASM || HWY_TARGET == HWY_WASM_EMU256) +#define HWY_REVERSE_BITS_MIN_BYTES 2 +#else +#define HWY_REVERSE_BITS_MIN_BYTES 1 +#endif + +#if (defined(HWY_NATIVE_REVERSE_BITS_UI8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REVERSE_BITS_UI8 +#undef HWY_NATIVE_REVERSE_BITS_UI8 +#else +#define HWY_NATIVE_REVERSE_BITS_UI8 +#endif + +namespace detail { + +template , HWY_REVERSE_BITS_MIN_BYTES - 1)> +HWY_INLINE V UI8ReverseBitsStep(V v) { + const DFromV d; + const RebindToUnsigned du; +#if HWY_REVERSE_BITS_MIN_BYTES == 2 + const Repartition d_shift; +#else + const RebindToUnsigned d_shift; +#endif + + const auto v_to_shift = BitCast(d_shift, v); + const auto shl_result = BitCast(d, ShiftLeft(v_to_shift)); + const auto shr_result = BitCast(d, ShiftRight(v_to_shift)); + const auto shr_result_mask = + BitCast(d, Set(du, static_cast(kShrResultMask))); + return Or(And(shr_result, shr_result_mask), + AndNot(shr_result_mask, shl_result)); +} + +#if HWY_REVERSE_BITS_MIN_BYTES == 2 +template , 1)> +HWY_INLINE V UI8ReverseBitsStep(V v) { + return V{UI8ReverseBitsStep(Vec128{v.raw}) + .raw}; +} +#endif + +} // namespace detail + +template +HWY_API V ReverseBits(V v) { + auto result = detail::UI8ReverseBitsStep<1, 0x55>(v); + result = detail::UI8ReverseBitsStep<2, 0x33>(result); + result = detail::UI8ReverseBitsStep<4, 0x0F>(result); + return result; +} + +#endif // HWY_NATIVE_REVERSE_BITS_UI8 + +#if (defined(HWY_NATIVE_REVERSE_BITS_UI16_32_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#undef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#else +#define HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#endif + +template +HWY_API V ReverseBits(V v) { + const DFromV d; + const Repartition du8; + return ReverseLaneBytes(BitCast(d, ReverseBits(BitCast(du8, v)))); +} +#endif // HWY_NATIVE_REVERSE_BITS_UI16_32_64 + +// ------------------------------ Per4LaneBlockShuffle + +#if (defined(HWY_NATIVE_PER4LANEBLKSHUF_DUP32) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#undef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#else +#define HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +namespace detail { + +template +HWY_INLINE Vec Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { +#if HWY_TARGET == HWY_RVV + constexpr int kPow2 = d.Pow2(); + constexpr int kLoadPow2 = HWY_MAX(kPow2, -1); + const ScalableTag d_load; +#else + constexpr size_t kMaxBytes = d.MaxBytes(); +#if HWY_TARGET_IS_NEON + constexpr size_t kMinLanesToLoad = 2; +#else + constexpr size_t kMinLanesToLoad = 4; +#endif + constexpr size_t kNumToLoad = + HWY_MAX(kMaxBytes / sizeof(uint32_t), kMinLanesToLoad); + const CappedTag d_load; +#endif + return ResizeBitCast(d, Dup128VecFromValues(d_load, x0, x1, x2, x3)); +} + +} // namespace detail +#endif + +#endif // HWY_NATIVE_PER4LANEBLKSHUF_DUP32 + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +namespace detail { + +template +HWY_INLINE V Per2LaneBlockShuffle(hwy::SizeTag<0> /*idx_10_tag*/, V v) { + return DupEven(v); +} + +template +HWY_INLINE V Per2LaneBlockShuffle(hwy::SizeTag<1> /*idx_10_tag*/, V v) { + const DFromV d; + return Reverse2(d, v); +} + +template +HWY_INLINE V Per2LaneBlockShuffle(hwy::SizeTag<2> /*idx_10_tag*/, V v) { + return v; +} + +template +HWY_INLINE V Per2LaneBlockShuffle(hwy::SizeTag<3> /*idx_10_tag*/, V v) { + return DupOdd(v); +} + +HWY_INLINE 
uint32_t U8x4Per4LaneBlkIndices(const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { +#if HWY_IS_LITTLE_ENDIAN + return static_cast((idx3 << 24) | (idx2 << 16) | (idx1 << 8) | + idx0); +#else + return static_cast(idx3 | (idx2 << 8) | (idx1 << 16) | + (idx0 << 24)); +#endif +} + +template +HWY_INLINE Vec TblLookupPer4LaneBlkU8IdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { +#if HWY_TARGET == HWY_RVV + const AdjustSimdTagToMinVecPow2> du32; +#else + const Repartition du32; +#endif + + return ResizeBitCast( + d, Set(du32, U8x4Per4LaneBlkIndices(idx3, idx2, idx1, idx0))); +} + +#if HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_EMU128 +#define HWY_PER_4_BLK_TBL_LOOKUP_LANES_ENABLE(D) void* = nullptr +#else +#define HWY_PER_4_BLK_TBL_LOOKUP_LANES_ENABLE(D) HWY_IF_T_SIZE_D(D, 8) + +template +HWY_INLINE V Per4LaneBlkShufDoTblLookup(V v, V idx) { + const DFromV d; + const Repartition du8; + return BitCast(d, TableLookupBytes(BitCast(du8, v), BitCast(du8, idx))); +} + +template +HWY_INLINE Vec TblLookupPer4LaneBlkShufIdx(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const Repartition du32; + const uint32_t idx3210 = U8x4Per4LaneBlkIndices(idx3, idx2, idx1, idx0); + const auto v_byte_idx = Per4LaneBlkShufDupSet4xU32( + du32, static_cast(idx3210 + 0x0C0C0C0C), + static_cast(idx3210 + 0x08080808), + static_cast(idx3210 + 0x04040404), + static_cast(idx3210)); + return ResizeBitCast(d, v_byte_idx); +} + +template +HWY_INLINE Vec TblLookupPer4LaneBlkShufIdx(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const Repartition du32; +#if HWY_IS_LITTLE_ENDIAN + const uint32_t idx10 = static_cast((idx1 << 16) | idx0); + const uint32_t idx32 = static_cast((idx3 << 16) | idx2); + constexpr uint32_t kLaneByteOffsets{0x01000100}; +#else + const uint32_t idx10 = static_cast(idx1 | (idx0 << 16)); + const uint32_t idx32 = static_cast(idx3 | (idx2 << 16)); + constexpr uint32_t kLaneByteOffsets{0x00010001}; +#endif + constexpr uint32_t kHiLaneByteOffsets{kLaneByteOffsets + 0x08080808u}; + + const auto v_byte_idx = Per4LaneBlkShufDupSet4xU32( + du32, static_cast(idx32 * 0x0202u + kHiLaneByteOffsets), + static_cast(idx10 * 0x0202u + kHiLaneByteOffsets), + static_cast(idx32 * 0x0202u + kLaneByteOffsets), + static_cast(idx10 * 0x0202u + kLaneByteOffsets)); + return ResizeBitCast(d, v_byte_idx); +} + +template +HWY_INLINE Vec TblLookupPer4LaneBlkShufIdx(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const Repartition du32; +#if HWY_IS_LITTLE_ENDIAN + constexpr uint32_t kLaneByteOffsets{0x03020100}; +#else + constexpr uint32_t kLaneByteOffsets{0x00010203}; +#endif + + const auto v_byte_idx = Per4LaneBlkShufDupSet4xU32( + du32, static_cast(idx3 * 0x04040404u + kLaneByteOffsets), + static_cast(idx2 * 0x04040404u + kLaneByteOffsets), + static_cast(idx1 * 0x04040404u + kLaneByteOffsets), + static_cast(idx0 * 0x04040404u + kLaneByteOffsets)); + return ResizeBitCast(d, v_byte_idx); +} +#endif + +template +HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + return TblLookupPer4LaneBlkU8IdxInBlk(d, idx3, idx2, idx1, idx0); +} + +#if HWY_TARGET == HWY_RVV +template +HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t 
idx1, + const uint32_t idx0) { + const Rebind du8; + return PromoteTo(d, + TblLookupPer4LaneBlkU8IdxInBlk(du8, idx3, idx2, idx1, idx0)); +} +#else +template +HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const uint16_t u16_idx0 = static_cast(idx0); + const uint16_t u16_idx1 = static_cast(idx1); + const uint16_t u16_idx2 = static_cast(idx2); + const uint16_t u16_idx3 = static_cast(idx3); +#if HWY_TARGET_IS_NEON + constexpr size_t kMinLanesToLoad = 4; +#else + constexpr size_t kMinLanesToLoad = 8; +#endif + constexpr size_t kNumToLoad = HWY_MAX(HWY_MAX_LANES_D(D), kMinLanesToLoad); + const CappedTag d_load; + return ResizeBitCast( + d, Dup128VecFromValues(d_load, u16_idx0, u16_idx1, u16_idx2, u16_idx3, + u16_idx0, u16_idx1, u16_idx2, u16_idx3)); +} + +template +HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + return Per4LaneBlkShufDupSet4xU32(d, idx3, idx2, idx1, idx0); +} + +template +HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const RebindToUnsigned du; + const Rebind du32; + return BitCast(d, PromoteTo(du, Per4LaneBlkShufDupSet4xU32(du32, idx3, idx2, + idx1, idx0))); +} +#endif + +template +HWY_INLINE IndicesFromD TblLookupPer4LaneBlkShufIdx(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const RebindToUnsigned du; + using TU = TFromD; + auto idx_in_blk = TblLookupPer4LaneBlkIdxInBlk(du, idx3, idx2, idx1, idx0); + + constexpr size_t kN = HWY_MAX_LANES_D(D); + if (kN < 4) { + idx_in_blk = And(idx_in_blk, Set(du, static_cast(kN - 1))); + } + +#if HWY_TARGET == HWY_RVV + const auto blk_offsets = AndS(Iota0(du), static_cast(~TU{3})); +#else + const auto blk_offsets = + And(Iota(du, TU{0}), Set(du, static_cast(~TU{3}))); +#endif + return IndicesFromVec(d, Add(idx_in_blk, blk_offsets)); +} + +template )> +HWY_INLINE V Per4LaneBlkShufDoTblLookup(V v, IndicesFromD> idx) { + return TableLookupLanes(v, idx); +} + +#undef HWY_PER_4_BLK_TBL_LOOKUP_LANES_ENABLE + +template +HWY_INLINE V TblLookupPer4LaneBlkShuf(V v, size_t idx3210) { + const DFromV d; + const uint32_t idx3 = static_cast((idx3210 >> 6) & 3); + const uint32_t idx2 = static_cast((idx3210 >> 4) & 3); + const uint32_t idx1 = static_cast((idx3210 >> 2) & 3); + const uint32_t idx0 = static_cast(idx3210 & 3); + const auto idx = TblLookupPer4LaneBlkShufIdx(d, idx3, idx2, idx1, idx0); + return Per4LaneBlkShufDoTblLookup(v, idx); +} + +// The detail::Per4LaneBlockShuffle overloads that have the extra lane_size_tag +// and vect_size_tag parameters are only called for vectors that have at +// least 4 lanes (or scalable vectors that might possibly have 4 or more lanes) +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + return TblLookupPer4LaneBlkShuf(v, kIdx3210); +} + +#if HWY_HAVE_FLOAT64 +template +HWY_INLINE VFromD>> Per4LaneBlockShufCastToWide( + hwy::FloatTag /* type_tag */, hwy::SizeTag<4> /* lane_size_tag */, V v) { + const DFromV d; + const RepartitionToWide dw; + return BitCast(dw, v); +} +#endif + +template +HWY_INLINE VFromD>>> +Per4LaneBlockShufCastToWide(hwy::FloatTag /* type_tag */, + hwy::SizeTag /* lane_size_tag */, V v) { + const DFromV d; + const RebindToUnsigned du; + const 
RepartitionToWide dw; + return BitCast(dw, v); +} + +template +HWY_INLINE VFromD>> Per4LaneBlockShufCastToWide( + hwy::NonFloatTag /* type_tag */, + hwy::SizeTag /* lane_size_tag */, V v) { + const DFromV d; + const RepartitionToWide dw; + return BitCast(dw, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x1B> /*idx_3210_tag*/, V v) { + const DFromV d; + return Reverse4(d, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x44> /*idx_3210_tag*/, V v) { + const DFromV d; + const auto vw = Per4LaneBlockShufCastToWide( + hwy::IsFloatTag>(), hwy::SizeTag)>(), v); + return BitCast(d, DupEven(vw)); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x4E> /*idx_3210_tag*/, V v) { + const DFromV d; + const auto vw = Per4LaneBlockShufCastToWide( + hwy::IsFloatTag>(), hwy::SizeTag)>(), v); + const DFromV dw; + return BitCast(d, Reverse2(dw, vw)); +} + +#if HWY_MAX_BYTES >= 32 +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x4E> /*idx_3210_tag*/, V v) { + return SwapAdjacentBlocks(v); +} +#endif + +template , 4), + HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2))> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x50> /*idx_3210_tag*/, V v) { + const DFromV d; + return InterleaveLower(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x50> /*idx_3210_tag*/, V v) { + const DFromV d; + return InterleaveLower(d, v, v); +} + +template , 4)> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x88> /*idx_3210_tag*/, V v) { + const DFromV d; + return ConcatEven(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xA0> /*idx_3210_tag*/, V v) { + return DupEven(v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xB1> /*idx_3210_tag*/, V v) { + const DFromV d; + return Reverse2(d, v); +} + +template , 4)> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xDD> /*idx_3210_tag*/, V v) { + const DFromV d; + return ConcatOdd(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xE4> /*idx_3210_tag*/, V v) { + return v; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xEE> /*idx_3210_tag*/, V v) { + const DFromV d; + const auto vw = Per4LaneBlockShufCastToWide( + hwy::IsFloatTag>(), hwy::SizeTag)>(), v); + return BitCast(d, DupOdd(vw)); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xF5> /*idx_3210_tag*/, V v) { + return DupOdd(v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xFA> /*idx_3210_tag*/, V v) { + const DFromV d; + return InterleaveUpper(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag idx_3210_tag, V v) { + const DFromV d; + return Per4LaneBlockShuffle(idx_3210_tag, hwy::SizeTag)>(), + hwy::SizeTag(), v); +} + +} // namespace detail +#endif // HWY_TARGET != HWY_SCALAR + +template , 1)> +HWY_API V Per4LaneBlockShuffle(V v) { + static_assert(kIdx0 <= 3, "kIdx0 <= 3 must be true"); + static_assert(kIdx1 <= 3, "kIdx1 <= 3 must be true"); + static_assert(kIdx2 <= 3, "kIdx2 <= 3 must be true"); + static_assert(kIdx3 <= 3, "kIdx3 <= 3 must be true"); + + return v; +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template , 2)> +HWY_API V Per4LaneBlockShuffle(V v) { + static_assert(kIdx0 <= 3, "kIdx0 <= 3 must be true"); + static_assert(kIdx1 <= 3, "kIdx1 <= 3 must be true"); + static_assert(kIdx2 <= 3, "kIdx2 <= 3 must be true"); + static_assert(kIdx3 <= 3, "kIdx3 <= 3 must be true"); + + constexpr bool isReverse2 = (kIdx0 == 1 || kIdx1 == 0) && (kIdx0 != kIdx1); + constexpr size_t kPer2BlkIdx0 = (kIdx0 <= 1) 
? kIdx0 : (isReverse2 ? 1 : 0); + constexpr size_t kPer2BlkIdx1 = (kIdx1 <= 1) ? kIdx1 : (isReverse2 ? 0 : 1); + + constexpr size_t kIdx10 = (kPer2BlkIdx1 << 1) | kPer2BlkIdx0; + static_assert(kIdx10 <= 3, "kIdx10 <= 3 must be true"); + return detail::Per2LaneBlockShuffle(hwy::SizeTag(), v); +} + +template , 2)> +HWY_API V Per4LaneBlockShuffle(V v) { + static_assert(kIdx0 <= 3, "kIdx0 <= 3 must be true"); + static_assert(kIdx1 <= 3, "kIdx1 <= 3 must be true"); + static_assert(kIdx2 <= 3, "kIdx2 <= 3 must be true"); + static_assert(kIdx3 <= 3, "kIdx3 <= 3 must be true"); + + constexpr size_t kIdx3210 = + (kIdx3 << 6) | (kIdx2 << 4) | (kIdx1 << 2) | kIdx0; + return detail::Per4LaneBlockShuffle(hwy::SizeTag(), v); +} +#endif + +// ------------------------------ PairwiseAdd128/PairwiseSub128 +// (Per4LaneBlockShuffle) +#if (defined(HWY_NATIVE_PAIRWISE_ADD_128) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PAIRWISE_ADD_128 +#undef HWY_NATIVE_PAIRWISE_ADD_128 +#else +#define HWY_NATIVE_PAIRWISE_ADD_128 +#endif + +namespace detail { + +// detail::BlockwiseConcatOddEven(d, v) returns the even lanes of each block of +// v followed by the odd lanes of v +#if HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec BlockwiseConcatOddEven(D d, + Vec v) { +#if HWY_TARGET == HWY_RVV + const ScalableTag du64; +#else + const Repartition> du64; +#endif + + const Repartition, decltype(du64)> d_concat; + const auto v_to_concat = ResizeBitCast(d_concat, v); + + const auto evens = ConcatEven(d, v_to_concat, v_to_concat); + const auto odds = ConcatOdd(d, v_to_concat, v_to_concat); + return ResizeBitCast( + d, InterleaveWholeLower(BitCast(du64, evens), BitCast(du64, odds))); +} + +#else // !(HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV) + +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec BlockwiseConcatOddEven(D d, + Vec v) { +#if HWY_TARGET == HWY_SSE2 + const RebindToUnsigned du; + const RebindToSigned> dw; + + const auto vu = BitCast(du, v); + return BitCast( + d, OrderedDemote2To(du, PromoteEvenTo(dw, vu), PromoteOddTo(dw, vu))); +#else + const Repartition du8; + const auto idx = + BitCast(d, Dup128VecFromValues(du8, 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, + 9, 11, 13, 15)); + return TableLookupBytes(v, idx); +#endif +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec BlockwiseConcatOddEven(D d, + Vec v) { +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned di; + const RepartitionToWide dw; + const auto vi = BitCast(di, v); + return BitCast( + d, OrderedDemote2To(di, PromoteEvenTo(dw, vi), PromoteOddTo(dw, vi))); +#else + const Repartition du8; + const auto idx = BitCast(d, Dup128VecFromValues(du8, 0, 1, 4, 5, 8, 9, 12, 13, + 2, 3, 6, 7, 10, 11, 14, 15)); + return TableLookupBytes(v, idx); +#endif +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec BlockwiseConcatOddEven(D /*d*/, + Vec v) { + return Per4LaneBlockShuffle<3, 1, 2, 0>(v); +} +#endif // HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV + +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec BlockwiseConcatOddEven(D /*d*/, + Vec v) { + return v; +} + +} // namespace detail + +// Pairwise add with output in 128 bit blocks of a and b. +template +HWY_API Vec PairwiseAdd128(D d, Vec a, Vec b) { + return detail::BlockwiseConcatOddEven(d, PairwiseAdd(d, a, b)); +} + +// Pairwise sub with output in 128 bit blocks of a and b. 
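Aside (not part of the patch; the PairwiseSub128 definition announced by the comment above continues right after this note): as I read PairwiseAdd128 together with BlockwiseConcatOddEven, each 128-bit block of the result carries the adjacent-pair sums of a followed by those of b. A scalar reference for one block of four int32 lanes; the array names are illustrative only:

  // Scalar reference model of the assumed per-block layout (not patch content).
  #include <cstdint>
  #include <cstdio>

  int main() {
    const int32_t a[4] = {1, 2, 3, 4};
    const int32_t b[4] = {10, 20, 30, 40};
    int32_t out[4];
    out[0] = a[0] + a[1];  // pair sums of a fill the lower half of the block
    out[1] = a[2] + a[3];
    out[2] = b[0] + b[1];  // pair sums of b fill the upper half
    out[3] = b[2] + b[3];
    std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // 3 7 30 70
    return 0;
  }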
+template +HWY_API Vec PairwiseSub128(D d, Vec a, Vec b) { + return detail::BlockwiseConcatOddEven(d, PairwiseSub(d, a, b)); +} + +#endif + +// ------------------------------ Blocks + +template +HWY_API size_t Blocks(D d) { + return (d.MaxBytes() <= 16) ? 1 : ((Lanes(d) * sizeof(TFromD) + 15) / 16); +} + +// ------------------------------ Block insert/extract/broadcast ops +#if (defined(HWY_NATIVE_BLK_INSERT_EXTRACT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_BLK_INSERT_EXTRACT +#undef HWY_NATIVE_BLK_INSERT_EXTRACT +#else +#define HWY_NATIVE_BLK_INSERT_EXTRACT +#endif + +template +HWY_API V InsertBlock(V /*v*/, V blk_to_insert) { + static_assert(kBlockIdx == 0, "Invalid block index"); + return blk_to_insert; +} + +template +HWY_API V ExtractBlock(V v) { + static_assert(kBlockIdx == 0, "Invalid block index"); + return v; +} + +template +HWY_API V BroadcastBlock(V v) { + static_assert(kBlockIdx == 0, "Invalid block index"); + return v; +} + +#endif // HWY_NATIVE_BLK_INSERT_EXTRACT + +// ------------------------------ BroadcastLane +#if (defined(HWY_NATIVE_BROADCASTLANE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_BROADCASTLANE +#undef HWY_NATIVE_BROADCASTLANE +#else +#define HWY_NATIVE_BROADCASTLANE +#endif + +template +HWY_API V BroadcastLane(V v) { + return Broadcast(v); +} + +#endif // HWY_NATIVE_BROADCASTLANE + +// ------------------------------ Slide1Up and Slide1Down +#if (defined(HWY_NATIVE_SLIDE1_UP_DOWN) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SLIDE1_UP_DOWN +#undef HWY_NATIVE_SLIDE1_UP_DOWN +#else +#define HWY_NATIVE_SLIDE1_UP_DOWN +#endif + +template +HWY_API VFromD Slide1Up(D d, VFromD /*v*/) { + return Zero(d); +} +template +HWY_API VFromD Slide1Down(D d, VFromD /*v*/) { + return Zero(d); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + return ShiftLeftLanes<1>(d, v); +} +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + return ShiftRightLanes<1>(d, v); +} +#endif // HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_SLIDE1_UP_DOWN + +// ------------------------------ SlideUpBlocks + +template +HWY_API VFromD SlideUpBlocks(D /*d*/, VFromD v) { + static_assert(kBlocks == 0, "kBlocks == 0 must be true"); + return v; +} + +#if HWY_HAVE_SCALABLE || HWY_TARGET == HWY_SVE_256 +template +HWY_API VFromD SlideUpBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && static_cast(kBlocks) < d.MaxBlocks(), + "kBlocks must be between 0 and d.MaxBlocks() - 1"); + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + return SlideUpLanes(d, v, static_cast(kBlocks) * kLanesPerBlock); +} +#endif + +// ------------------------------ SlideDownBlocks + +template +HWY_API VFromD SlideDownBlocks(D /*d*/, VFromD v) { + static_assert(kBlocks == 0, "kBlocks == 0 must be true"); + return v; +} + +#if HWY_HAVE_SCALABLE || HWY_TARGET == HWY_SVE_256 +template +HWY_API VFromD SlideDownBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && static_cast(kBlocks) < d.MaxBlocks(), + "kBlocks must be between 0 and d.MaxBlocks() - 1"); + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + return SlideDownLanes(d, v, static_cast(kBlocks) * kLanesPerBlock); +} +#endif + +// ------------------------------ Slide mask up/down +#if (defined(HWY_NATIVE_SLIDE_MASK) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_SLIDE_MASK +#undef HWY_NATIVE_SLIDE_MASK +#else +#define HWY_NATIVE_SLIDE_MASK +#endif + +template +HWY_API Mask SlideMask1Up(D d, Mask m) { + return MaskFromVec(Slide1Up(d, VecFromMask(d, m))); +} + +template +HWY_API Mask 
SlideMask1Down(D d, Mask m) { + return MaskFromVec(Slide1Down(d, VecFromMask(d, m))); +} + +template +HWY_API Mask SlideMaskUpLanes(D d, Mask m, size_t amt) { + return MaskFromVec(SlideUpLanes(d, VecFromMask(d, m), amt)); +} + +template +HWY_API Mask SlideMaskDownLanes(D d, Mask m, size_t amt) { + return MaskFromVec(SlideDownLanes(d, VecFromMask(d, m), amt)); +} + +#endif // HWY_NATIVE_SLIDE_MASK + +// ------------------------------ SumsOfAdjQuadAbsDiff + +#if (defined(HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API Vec>> SumsOfAdjQuadAbsDiff(V8 a, V8 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + using D8 = DFromV; + const D8 d8; + const RebindToUnsigned du8; + const RepartitionToWide d16; + const RepartitionToWide du16; + + // Ensure that a is resized to a vector that has at least + // HWY_MAX(Lanes(d8), size_t{8} << kAOffset) lanes for the interleave and + // CombineShiftRightBytes operations below. +#if HWY_TARGET == HWY_RVV + // On RVV targets, need to ensure that d8_interleave.Pow2() >= 0 is true + // to ensure that Lanes(d8_interleave) >= 16 is true. + + // Lanes(d8_interleave) >= Lanes(d8) is guaranteed to be true on RVV + // targets as d8_interleave.Pow2() >= d8.Pow2() is true. + constexpr int kInterleavePow2 = HWY_MAX(d8.Pow2(), 0); + const ScalableTag, kInterleavePow2> d8_interleave; +#elif HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE + // On SVE targets, Lanes(d8_interleave) >= 16 and + // Lanes(d8_interleave) >= Lanes(d8) are both already true as d8 is a SIMD + // tag for a full u8/i8 vector on SVE. + const D8 d8_interleave; +#else + // On targets that use non-scalable vector types, Lanes(d8_interleave) is + // equal to HWY_MAX(Lanes(d8), size_t{8} << kAOffset). + constexpr size_t kInterleaveLanes = + HWY_MAX(HWY_MAX_LANES_D(D8), size_t{8} << kAOffset); + const FixedTag, kInterleaveLanes> d8_interleave; +#endif + + // The ResizeBitCast operation below will resize a to a vector that has + // at least HWY_MAX(Lanes(d8), size_t{8} << kAOffset) lanes for the + // InterleaveLower, InterleaveUpper, and CombineShiftRightBytes operations + // below. 
+ const auto a_to_interleave = ResizeBitCast(d8_interleave, a); + + const auto a_interleaved_lo = + InterleaveLower(d8_interleave, a_to_interleave, a_to_interleave); + const auto a_interleaved_hi = + InterleaveUpper(d8_interleave, a_to_interleave, a_to_interleave); + + /* a01: { a[kAOffset*4+0], a[kAOffset*4+1], a[kAOffset*4+1], a[kAOffset*4+2], + a[kAOffset*4+2], a[kAOffset*4+3], a[kAOffset*4+3], a[kAOffset*4+4], + a[kAOffset*4+4], a[kAOffset*4+5], a[kAOffset*4+5], a[kAOffset*4+6], + a[kAOffset*4+6], a[kAOffset*4+7], a[kAOffset*4+7], a[kAOffset*4+8] } + */ + /* a23: { a[kAOffset*4+2], a[kAOffset*4+3], a[kAOffset*4+3], a[kAOffset*4+4], + a[kAOffset*4+4], a[kAOffset*4+5], a[kAOffset*4+5], a[kAOffset*4+6], + a[kAOffset*4+6], a[kAOffset*4+7], a[kAOffset*4+7], a[kAOffset*4+8], + a[kAOffset*4+8], a[kAOffset*4+9], a[kAOffset*4+9], a[kAOffset*4+10] + } */ + + // a01 and a23 are resized back to V8 as only the first Lanes(d8) lanes of + // the CombineShiftRightBytes are needed for the subsequent AbsDiff operations + // and as a01 and a23 need to be the same vector type as b01 and b23 for the + // AbsDiff operations below. + const V8 a01 = + ResizeBitCast(d8, CombineShiftRightBytes( + d8_interleave, a_interleaved_hi, a_interleaved_lo)); + const V8 a23 = + ResizeBitCast(d8, CombineShiftRightBytes( + d8_interleave, a_interleaved_hi, a_interleaved_lo)); + + /* b01: { b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1], + b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1], + b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1], + b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1] } + */ + /* b23: { b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3], + b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3], + b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3], + b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3] } + */ + const V8 b01 = BitCast(d8, Broadcast(BitCast(d16, b))); + const V8 b23 = BitCast(d8, Broadcast(BitCast(d16, b))); + + const VFromD absdiff_sum_01 = + SumsOf2(BitCast(du8, AbsDiff(a01, b01))); + const VFromD absdiff_sum_23 = + SumsOf2(BitCast(du8, AbsDiff(a23, b23))); + return BitCast(d16, Add(absdiff_sum_01, absdiff_sum_23)); +} +#endif // HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if (defined(HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API Vec>> SumsOfShuffledQuadAbsDiff(V8 a, + V8 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + +#if HWY_TARGET == HWY_RVV + // On RVV, ensure that both vA and vB have a LMUL of at least 1/2 so that + // both vA and vB can be bitcasted to a u32 vector. 
+ const detail::AdjustSimdTagToMinVecPow2< + RepartitionToWideX2>> + d32; + const RepartitionToNarrow d16; + const RepartitionToNarrow d8; + + const auto vA = ResizeBitCast(d8, a); + const auto vB = ResizeBitCast(d8, b); +#else + const DFromV d8; + const RepartitionToWide d16; + const RepartitionToWide d32; + + const auto vA = a; + const auto vB = b; +#endif + + const RebindToUnsigned du8; + + const auto a_shuf = + Per4LaneBlockShuffle(BitCast(d32, vA)); + /* a0123_2345: { a_shuf[0], a_shuf[1], a_shuf[2], a_shuf[3], + a_shuf[2], a_shuf[3], a_shuf[4], a_shuf[5], + a_shuf[8], a_shuf[9], a_shuf[10], a_shuf[11], + a_shuf[10], a_shuf[11], a_shuf[12], a_shuf[13] } */ + /* a1234_3456: { a_shuf[1], a_shuf[2], a_shuf[3], a_shuf[4], + a_shuf[3], a_shuf[4], a_shuf[5], a_shuf[6], + a_shuf[9], a_shuf[10], a_shuf[11], a_shuf[12], + a_shuf[11], a_shuf[12], a_shuf[13], a_shuf[14] } */ +#if HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE + // On RVV/SVE targets, use Slide1Up/Slide1Down instead of + // ShiftLeftBytes/ShiftRightBytes to avoid unnecessary zeroing out of any + // lanes that are shifted into an adjacent 16-byte block as any lanes that are + // shifted into an adjacent 16-byte block by Slide1Up/Slide1Down will be + // replaced by the OddEven operation. + const auto a_0123_2345 = BitCast( + d8, OddEven(BitCast(d32, Slide1Up(d16, BitCast(d16, a_shuf))), a_shuf)); + const auto a_1234_3456 = + BitCast(d8, OddEven(BitCast(d32, Slide1Up(d8, BitCast(d8, a_shuf))), + BitCast(d32, Slide1Down(d8, BitCast(d8, a_shuf))))); +#else + const auto a_0123_2345 = + BitCast(d8, OddEven(ShiftLeftBytes<2>(d32, a_shuf), a_shuf)); + const auto a_1234_3456 = BitCast( + d8, + OddEven(ShiftLeftBytes<1>(d32, a_shuf), ShiftRightBytes<1>(d32, a_shuf))); +#endif + + auto even_sums = SumsOf4(BitCast(du8, AbsDiff(a_0123_2345, vB))); + auto odd_sums = SumsOf4(BitCast(du8, AbsDiff(a_1234_3456, vB))); + +#if HWY_IS_LITTLE_ENDIAN + odd_sums = ShiftLeft<16>(odd_sums); +#else + even_sums = ShiftLeft<16>(even_sums); +#endif + + const auto sums = OddEven(BitCast(d16, odd_sums), BitCast(d16, even_sums)); + +#if HWY_TARGET == HWY_RVV + return ResizeBitCast(RepartitionToWide>(), sums); +#else + return sums; +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF + +// ------------------------------ BitShuffle (Rol) +#if (defined(HWY_NATIVE_BITSHUFFLE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE +#endif + +#if HWY_HAVE_INTEGER64 && HWY_TARGET != HWY_SCALAR +template ), HWY_IF_UI8(TFromV)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Repartition du8; + +#if HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \ + HWY_TARGET == HWY_WASM_EMU256 + const Repartition d_idx_shr; +#else + const Repartition d_idx_shr; +#endif + +#if HWY_IS_LITTLE_ENDIAN + constexpr uint64_t kExtractedBitsMask = + static_cast(0x8040201008040201u); +#else + constexpr uint64_t kExtractedBitsMask = + static_cast(0x0102040810204080u); +#endif + + const auto k7 = Set(du8, uint8_t{0x07}); + + auto unmasked_byte_idx = BitCast(du8, ShiftRight<3>(BitCast(d_idx_shr, idx))); +#if HWY_IS_BIG_ENDIAN + // Need to invert the lower 3 bits of unmasked_byte_idx[i] on big-endian + // targets + unmasked_byte_idx = Xor(unmasked_byte_idx, k7); +#endif // HWY_IS_BIG_ENDIAN + + const auto byte_idx = BitwiseIfThenElse( + k7, unmasked_byte_idx, + BitCast(du8, Dup128VecFromValues(du64, uint64_t{0}, + 
uint64_t{0x0808080808080808u}))); + // We want to shift right by idx & 7 to extract the desired bit in `bytes`, + // and left by iota & 7 to put it in the correct output bit. To correctly + // handle shift counts from -7 to 7, we rotate. + const auto rotate_left_bits = Sub(Iota(du8, uint8_t{0}), BitCast(du8, idx)); + + const auto extracted_bits = + And(Rol(TableLookupBytes(v, byte_idx), rotate_left_bits), + BitCast(du8, Set(du64, kExtractedBitsMask))); + // Combine bit-sliced (one bit per byte) into one 64-bit sum. + return BitCast(d64, SumsOf8(extracted_bits)); +} +#endif // HWY_HAVE_INTEGER64 && HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_BITSHUFFLE + +template +HWY_API V MaskedOr(M m, V a, V b) { + return IfThenElseZero(m, Or(a, b)); +} +// ------------------------------ AllBits1/AllBits0 +#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ALLONES +#undef HWY_NATIVE_ALLONES +#else +#define HWY_NATIVE_ALLONES +#endif + +template > +HWY_API bool AllBits1(D d, V v) { + const RebindToUnsigned du; + using TU = TFromD; + return AllTrue(du, Eq(BitCast(du, v), Set(du, hwy::HighestValue()))); +} +#endif // HWY_NATIVE_ALLONES + +#if (defined(HWY_NATIVE_ALLZEROS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ALLZEROS +#undef HWY_NATIVE_ALLZEROS +#else +#define HWY_NATIVE_ALLZEROS +#endif + +template > +HWY_API bool AllBits0(D d, V v) { + return AllTrue(d, Eq(v, Zero(d))); +} +#endif // HWY_NATIVE_ALLZEROS + +// ------------------------------ MultiRotateRight +#if (defined(HWY_NATIVE_MULTIROTATERIGHT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MULTIROTATERIGHT +#undef HWY_NATIVE_MULTIROTATERIGHT +#else +#define HWY_NATIVE_MULTIROTATERIGHT +#endif + +template ), HWY_IF_UI8(TFromV), + class VI_2 = VFromD, DFromV>>, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(VI_2)), + HWY_IF_V_SIZE_V(V, 8)> +HWY_API V MultiRotateRight(V v, VI idx) { + const DFromV d64; + const Twice dt64; + const Repartition du8; + const Repartition dt_u8; + const Repartition dt_u16; + const auto k7 = Set(du8, uint8_t{0x07}); + const auto k63 = Set(du8, uint8_t{0x3F}); + + const auto masked_idx = And(k63, BitCast(du8, idx)); + + auto byte_idx = ShiftRight<3>(masked_idx); +#if HWY_IS_LITTLE_ENDIAN + const auto hi_byte_idx = Add(byte_idx, Set(du8, uint8_t{1})); +#else + byte_idx = Xor(byte_idx, k7); + const auto hi_byte_idx = Add(byte_idx, k7); +#endif + + const auto idx_shift = And(k7, masked_idx); + + // Calculate even lanes + const auto even_src = DupEven(ResizeBitCast(dt64, v)); + // Expand indexes to pull out 16 bit segments of idx and idx + 1 +#if HWY_IS_LITTLE_ENDIAN + const auto even_idx = InterleaveLower(ResizeBitCast(dt_u8, byte_idx), + ResizeBitCast(dt_u8, hi_byte_idx)); +#else + const auto even_idx = InterleaveLower(ResizeBitCast(dt_u8, hi_byte_idx), + ResizeBitCast(dt_u8, byte_idx)); +#endif + // TableLookupBytes indexes select from within a 16 byte block + const auto even_segments = TableLookupBytes(even_src, even_idx); + // Extract unaligned bytes from 16 bit segments + const auto even_idx_shift = PromoteTo(dt_u16, idx_shift); + const auto extracted_even_bytes = + Shr(BitCast(dt_u16, even_segments), even_idx_shift); + + // Extract the even bytes of each 128 bit block and pack into lower 64 bits +#if HWY_IS_LITTLE_ENDIAN + const auto even_lanes = BitCast( + dt64, + ConcatEven(dt_u8, Zero(dt_u8), BitCast(dt_u8, extracted_even_bytes))); +#else + const auto even_lanes = BitCast( + dt64, + ConcatOdd(dt_u8, Zero(dt_u8), BitCast(dt_u8, extracted_even_bytes))); +#endif + + return 
LowerHalf(d64, even_lanes); +} + +template ), HWY_IF_UI8(TFromV), + class VI_2 = VFromD, DFromV>>, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(VI_2)), + HWY_IF_V_SIZE_GT_V(V, 8)> +HWY_API V MultiRotateRight(V v, VI idx) { + const DFromV d64; + const Repartition du8; + const Repartition du16; + const auto k7 = Set(du8, uint8_t{0x07}); + const auto k63 = Set(du8, uint8_t{0x3F}); + + const auto masked_idx = And(k63, BitCast(du8, idx)); + + auto byte_idx = ShiftRight<3>(masked_idx); +#if HWY_IS_LITTLE_ENDIAN + const auto hi_byte_idx = Add(byte_idx, Set(du8, uint8_t{1})); +#else + byte_idx = Xor(byte_idx, k7); + const auto hi_byte_idx = Add(byte_idx, k7); +#endif + + const auto idx_shift = And(k7, masked_idx); + + // Calculate even lanes + const auto even_src = DupEven(v); + // Expand indexes to pull out 16 bit segments of idx and idx + 1 +#if HWY_IS_LITTLE_ENDIAN + const auto even_idx = InterleaveLower(byte_idx, hi_byte_idx); +#else + const auto even_idx = InterleaveLower(hi_byte_idx, byte_idx); +#endif + // TableLookupBytes indexes select from within a 16 byte block + const auto even_segments = TableLookupBytes(even_src, even_idx); + // Extract unaligned bytes from 16 bit segments +#if HWY_IS_LITTLE_ENDIAN + const auto even_idx_shift = ZipLower(idx_shift, Zero(du8)); +#else + const auto even_idx_shift = ZipLower(Zero(du8), idx_shift); +#endif + const auto extracted_even_bytes = + Shr(BitCast(du16, even_segments), even_idx_shift); + + // Calculate odd lanes + const auto odd_src = DupOdd(v); + // Expand indexes to pull out 16 bit segments of idx and idx + 1 +#if HWY_IS_LITTLE_ENDIAN + const auto odd_idx = InterleaveUpper(du8, byte_idx, hi_byte_idx); +#else + const auto odd_idx = InterleaveUpper(du8, hi_byte_idx, byte_idx); +#endif + // TableLookupBytes indexes select from within a 16 byte block + const auto odd_segments = TableLookupBytes(odd_src, odd_idx); + // Extract unaligned bytes from 16 bit segments +#if HWY_IS_LITTLE_ENDIAN + const auto odd_idx_shift = ZipUpper(du16, idx_shift, Zero(du8)); +#else + const auto odd_idx_shift = ZipUpper(du16, Zero(du8), idx_shift); +#endif + const auto extracted_odd_bytes = + Shr(BitCast(du16, odd_segments), odd_idx_shift); + + // Extract the even bytes of each 128 bit block and pack into lower 64 bits +#if HWY_IS_LITTLE_ENDIAN + const auto even_lanes = BitCast( + d64, ConcatEven(du8, Zero(du8), BitCast(du8, extracted_even_bytes))); + const auto odd_lanes = BitCast( + d64, ConcatEven(du8, Zero(du8), BitCast(du8, extracted_odd_bytes))); +#else + const auto even_lanes = BitCast( + d64, ConcatOdd(du8, Zero(du8), BitCast(du8, extracted_even_bytes))); + const auto odd_lanes = BitCast( + d64, ConcatOdd(du8, Zero(du8), BitCast(du8, extracted_odd_bytes))); +#endif + // Interleave at 64 bit level + return InterleaveWholeLower(even_lanes, odd_lanes); +} + +#if HWY_TARGET == HWY_RVV + +// MultiRotateRight for LMUL=1/2 case on RVV +template ), HWY_IF_UI8(TFromV), + class VI_2 = VFromD, DFromV>>, + HWY_IF_POW2_LE_D(DFromV, 0), + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(VI_2) / 2)> +HWY_API V MultiRotateRight(V v, VI idx) { + return MultiRotateRight(v, ResizeBitCast(Twice>(), idx)); +} + +#endif + +#endif + +// ================================================== Operator wrapper + +// SVE* and RVV currently cannot define operators and have already defined +// (only) the corresponding functions such as Add. 
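Aside (not part of the patch; the actual wrappers follow in the HWY_NATIVE_OPERATOR_REPLACEMENTS block below): the point of these wrappers is that generic code can always call the named functions, which forward to operators on targets that define them, while SVE/RVV supply the named functions directly. A minimal standalone illustration of the pattern; the names are illustrative only:

  // Standalone sketch of the wrapper pattern (not patch content).
  #include <cstdio>

  template <class V>
  V Add(V a, V b) {
    return a + b;  // targets without operator+ would define Add natively instead
  }

  template <class V>
  V SumOfThree(V a, V b, V c) {
    return Add(Add(a, b), c);  // portable call site: never spells out '+'
  }

  int main() {
    std::printf("%d\n", SumOfThree(1, 2, 3));  // 6
    return 0;
  }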
+#if (defined(HWY_NATIVE_OPERATOR_REPLACEMENTS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_OPERATOR_REPLACEMENTS +#undef HWY_NATIVE_OPERATOR_REPLACEMENTS +#else +#define HWY_NATIVE_OPERATOR_REPLACEMENTS +#endif + +template +HWY_API V Add(V a, V b) { + return a + b; +} +template +HWY_API V Sub(V a, V b) { + return a - b; +} + +template +HWY_API V Mul(V a, V b) { + return a * b; +} +template +HWY_API V Div(V a, V b) { + return a / b; +} +template +HWY_API V Mod(V a, V b) { + return a % b; +} + +template +V Shl(V a, V b) { + return a << b; +} +template +V Shr(V a, V b) { + return a >> b; +} + +template +HWY_API auto Eq(V a, V b) -> decltype(a == b) { + return a == b; +} +template +HWY_API auto Ne(V a, V b) -> decltype(a == b) { + return a != b; +} +template +HWY_API auto Lt(V a, V b) -> decltype(a == b) { + return a < b; +} + +template +HWY_API auto Gt(V a, V b) -> decltype(a == b) { + return a > b; +} +template +HWY_API auto Ge(V a, V b) -> decltype(a == b) { + return a >= b; +} + +template +HWY_API auto Le(V a, V b) -> decltype(a == b) { + return a <= b; +} + +#endif // HWY_NATIVE_OPERATOR_REPLACEMENTS + +#undef HWY_GENERIC_IF_EMULATED_D + +// TODO: remove once callers are updated. +// SVE and RVV do not support DFromM because their masks are loosely typed. +#if HWY_MAX_BYTES <= 64 && !HWY_TARGET_IS_SVE && HWY_TARGET != HWY_RVV +namespace detail { +template +uint64_t BitsFromMask(M m) { + const DFromM d; + return ::hwy::HWY_NAMESPACE::BitsFromMask(d, m); +} +} // namespace detail +#endif // !HWY_HAVE_SCALABLE && HWY_MAX_BYTES <= 64 + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/aom/third_party/highway/hwy/ops/inside-inl.h b/third_party/aom/third_party/highway/hwy/ops/inside-inl.h new file mode 100644 index 000000000000..be0ff46e3853 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/inside-inl.h @@ -0,0 +1,691 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Must be included inside an existing include guard, with the following ops +// already defined: BitCast, And, Set, ShiftLeft, ShiftRight, PromoteLowerTo, +// ConcatEven, ConcatOdd, plus the optional detail::PromoteEvenTo and +// detail::PromoteOddTo (if implemented in the target-specific header). + +// This is normally set by set_macros-inl.h before this header is included; +// if not, we are viewing this header standalone. Reduce IDE errors by: +#if !defined(HWY_NAMESPACE) +// 1) Defining HWY_IDE so we get syntax highlighting rather than all-gray text. +#include "third_party/highway/hwy/ops/shared-inl.h" +// 2) Entering the HWY_NAMESPACE to make definitions from shared-inl.h visible. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +#define HWY_INSIDE_END_NAMESPACE +// 3) Providing a dummy VFromD (usually done by the target-specific header). 
+template +using VFromD = int; +template +using TFromV = int; +template +struct DFromV {}; +#endif + +// ------------------------------ Vec/Create/Get/Set2..4 + +// On SVE and RVV, Vec2..4 are aliases to built-in types. Also exclude the +// fixed-size SVE targets. +#if HWY_IDE || (!HWY_HAVE_SCALABLE && !HWY_TARGET_IS_SVE) + +// NOTE: these are used inside arm_neon-inl.h, hence they cannot be defined in +// generic_ops-inl.h, which is included after that. +template +struct Vec2 { + VFromD v0; + VFromD v1; +}; + +template +struct Vec3 { + VFromD v0; + VFromD v1; + VFromD v2; +}; + +template +struct Vec4 { + VFromD v0; + VFromD v1; + VFromD v2; + VFromD v3; +}; + +// D arg is unused but allows deducing D. +template +HWY_API Vec2 Create2(D /* tag */, VFromD v0, VFromD v1) { + return Vec2{v0, v1}; +} + +template +HWY_API Vec3 Create3(D /* tag */, VFromD v0, VFromD v1, VFromD v2) { + return Vec3{v0, v1, v2}; +} + +template +HWY_API Vec4 Create4(D /* tag */, VFromD v0, VFromD v1, VFromD v2, + VFromD v3) { + return Vec4{v0, v1, v2, v3}; +} + +template +HWY_API VFromD Get2(Vec2 tuple) { + static_assert(kIndex < 2, "Tuple index out of bounds"); + return kIndex == 0 ? tuple.v0 : tuple.v1; +} + +template +HWY_API VFromD Get3(Vec3 tuple) { + static_assert(kIndex < 3, "Tuple index out of bounds"); + return kIndex == 0 ? tuple.v0 : kIndex == 1 ? tuple.v1 : tuple.v2; +} + +template +HWY_API VFromD Get4(Vec4 tuple) { + static_assert(kIndex < 4, "Tuple index out of bounds"); + return kIndex == 0 ? tuple.v0 + : kIndex == 1 ? tuple.v1 + : kIndex == 2 ? tuple.v2 + : tuple.v3; +} + +template +HWY_API Vec2 Set2(Vec2 tuple, VFromD val) { + static_assert(kIndex < 2, "Tuple index out of bounds"); + if (kIndex == 0) { + tuple.v0 = val; + } else { + tuple.v1 = val; + } + return tuple; +} + +template +HWY_API Vec3 Set3(Vec3 tuple, VFromD val) { + static_assert(kIndex < 3, "Tuple index out of bounds"); + if (kIndex == 0) { + tuple.v0 = val; + } else if (kIndex == 1) { + tuple.v1 = val; + } else { + tuple.v2 = val; + } + return tuple; +} + +template +HWY_API Vec4 Set4(Vec4 tuple, VFromD val) { + static_assert(kIndex < 4, "Tuple index out of bounds"); + if (kIndex == 0) { + tuple.v0 = val; + } else if (kIndex == 1) { + tuple.v1 = val; + } else if (kIndex == 2) { + tuple.v2 = val; + } else { + tuple.v3 = val; + } + return tuple; +} + +#endif // !HWY_HAVE_SCALABLE || HWY_IDE + +// ------------------------------ Rol/Ror (And, Or, Neg, Shl, Shr) +#if (defined(HWY_NATIVE_ROL_ROR_8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_8 +#undef HWY_NATIVE_ROL_ROR_8 +#else +#define HWY_NATIVE_ROL_ROR_8 +#endif + +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint8_t{7}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint8_t{7}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +#endif // HWY_NATIVE_ROL_ROR_8 + +#if (defined(HWY_NATIVE_ROL_ROR_16) == defined(HWY_TARGET_TOGGLE)) +#ifdef 
HWY_NATIVE_ROL_ROR_16 +#undef HWY_NATIVE_ROL_ROR_16 +#else +#define HWY_NATIVE_ROL_ROR_16 +#endif + +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint16_t{15}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint16_t{15}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +#endif // HWY_NATIVE_ROL_ROR_16 + +#if (defined(HWY_NATIVE_ROL_ROR_32_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_32_64 +#undef HWY_NATIVE_ROL_ROR_32_64 +#else +#define HWY_NATIVE_ROL_ROR_32_64 +#endif + +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint32_t{31}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint32_t{31}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +#if HWY_HAVE_INTEGER64 +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint64_t{63}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint64_t{63}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} +#endif // HWY_HAVE_INTEGER64 + +#endif // HWY_NATIVE_ROL_ROR_32_64 + +// ------------------------------ RotateLeftSame/RotateRightSame + +#if (defined(HWY_NATIVE_ROL_ROR_SAME_8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_SAME_8 +#undef HWY_NATIVE_ROL_ROR_SAME_8 +#else +#define HWY_NATIVE_ROL_ROR_SAME_8 +#endif + +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 7; + const int shr_amt = static_cast((0u - static_cast(bits)) & 7u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const 
RebindToUnsigned du; + + const int shr_amt = bits & 7; + const int shl_amt = static_cast((0u - static_cast(bits)) & 7u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +#endif // HWY_NATIVE_ROL_ROR_SAME_8 + +#if (defined(HWY_NATIVE_ROL_ROR_SAME_16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_SAME_16 +#undef HWY_NATIVE_ROL_ROR_SAME_16 +#else +#define HWY_NATIVE_ROL_ROR_SAME_16 +#endif + +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 15; + const int shr_amt = + static_cast((0u - static_cast(bits)) & 15u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 15; + const int shl_amt = + static_cast((0u - static_cast(bits)) & 15u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} +#endif // HWY_NATIVE_ROL_ROR_SAME_16 + +#if (defined(HWY_NATIVE_ROL_ROR_SAME_32_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64 +#undef HWY_NATIVE_ROL_ROR_SAME_32_64 +#else +#define HWY_NATIVE_ROL_ROR_SAME_32_64 +#endif + +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 31; + const int shr_amt = + static_cast((0u - static_cast(bits)) & 31u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 31; + const int shl_amt = + static_cast((0u - static_cast(bits)) & 31u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +#if HWY_HAVE_INTEGER64 +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 63; + const int shr_amt = + static_cast((0u - static_cast(bits)) & 63u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 63; + const int shl_amt = + static_cast((0u - static_cast(bits)) & 63u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} +#endif // HWY_HAVE_INTEGER64 + +#endif // HWY_NATIVE_ROL_ROR_SAME_32_64 + +// ------------------------------ PromoteEvenTo/PromoteOddTo + +// These are used by target-specific headers for ReorderWidenMulAccumulate etc. + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +namespace detail { + +// Tag dispatch is used in detail::PromoteEvenTo and detail::PromoteOddTo as +// there are target-specific specializations for some of the +// detail::PromoteEvenTo and detail::PromoteOddTo cases on +// SVE/PPC/SSE2/SSSE3/SSE4/AVX2. + +// All targets except HWY_SCALAR use the implementations of +// detail::PromoteEvenTo and detail::PromoteOddTo in generic_ops-inl.h for at +// least some of the PromoteEvenTo and PromoteOddTo cases. 
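Aside (not part of the patch): the signed overloads that follow recover a narrow lane from the bitcast wide lane with a shift pair, moving the wanted half into the top bits and then arithmetic-shifting back down to sign-extend, while the unsigned overloads mask or logical-shift instead. A standalone little-endian scalar sketch for one int32 lane holding two int16 lanes; the variable names are illustrative, and it assumes two's-complement narrowing and arithmetic right shift, as on mainstream compilers:

  // Standalone scalar sketch (not patch content).
  #include <cstdint>
  #include <cstdio>
  #include <cstring>

  int main() {
    const int16_t lanes[2] = {-3, 7};  // lane 0 (even) and lane 1 (odd)
    uint32_t wide;
    std::memcpy(&wide, lanes, sizeof(wide));  // the "BitCast" to the wide type

    // PromoteEvenTo (signed): move the even lane to the top, then shift back
    // down with sign extension (ShiftLeft + ShiftRight in the vector code).
    const int32_t promoted_even = static_cast<int32_t>(wide << 16) >> 16;
    // PromoteOddTo (signed): the odd lane already occupies the upper half.
    const int32_t promoted_odd = static_cast<int32_t>(wide) >> 16;

    std::printf("%d %d\n", promoted_even, promoted_odd);  // -3 7
    return 0;
  }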
+ +// Signed to signed PromoteEvenTo/PromoteOddTo +template +HWY_INLINE VFromD PromoteEvenTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_TARGET_IS_SVE + // The intrinsic expects the wide lane type. + return NativePromoteEvenTo(BitCast(d_to, v)); +#else +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, need to shift each lane of the bitcasted + // vector left by kToLaneSize * 4 bits to get the bits of the even + // source lanes into the upper kToLaneSize * 4 bits of even_in_hi. + const auto even_in_hi = ShiftLeft(BitCast(d_to, v)); +#else + // On big-endian targets, the bits of the even source lanes are already + // in the upper kToLaneSize * 4 bits of the lanes of the bitcasted + // vector. + const auto even_in_hi = BitCast(d_to, v); +#endif + + // Right-shift even_in_hi by kToLaneSize * 4 bits + return ShiftRight(even_in_hi); +#endif // HWY_TARGET_IS_SVE +} + +// Unsigned to unsigned PromoteEvenTo/PromoteOddTo +template +HWY_INLINE VFromD PromoteEvenTo( + hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_TARGET_IS_SVE + // The intrinsic expects the wide lane type. + return NativePromoteEvenTo(BitCast(d_to, v)); +#else +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, the bits of the even source lanes are already + // in the lower kToLaneSize * 4 bits of the lanes of the bitcasted vector. + + // Simply need to zero out the upper bits of each lane of the bitcasted + // vector. + return And(BitCast(d_to, v), + Set(d_to, static_cast>(LimitsMax>()))); +#else + // On big-endian targets, need to shift each lane of the bitcasted vector + // right by kToLaneSize * 4 bits to get the bits of the even source lanes into + // the lower kToLaneSize * 4 bits of the result. + + // The right shift below will zero out the upper kToLaneSize * 4 bits of the + // result. + return ShiftRight(BitCast(d_to, v)); +#endif +#endif // HWY_TARGET_IS_SVE +} + +template +HWY_INLINE VFromD PromoteOddTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, the bits of the odd source lanes are already in + // the upper kToLaneSize * 4 bits of the lanes of the bitcasted vector. + const auto odd_in_hi = BitCast(d_to, v); +#else + // On big-endian targets, need to shift each lane of the bitcasted vector + // left by kToLaneSize * 4 bits to get the bits of the odd source lanes into + // the upper kToLaneSize * 4 bits of odd_in_hi. + const auto odd_in_hi = ShiftLeft(BitCast(d_to, v)); +#endif + + // Right-shift odd_in_hi by kToLaneSize * 4 bits + return ShiftRight(odd_in_hi); +} + +template +HWY_INLINE VFromD PromoteOddTo( + hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, need to shift each lane of the bitcasted vector + // right by kToLaneSize * 4 bits to get the bits of the odd source lanes into + // the lower kToLaneSize * 4 bits of the result. + + // The right shift below will zero out the upper kToLaneSize * 4 bits of the + // result. + return ShiftRight(BitCast(d_to, v)); +#else + // On big-endian targets, the bits of the even source lanes are already + // in the lower kToLaneSize * 4 bits of the lanes of the bitcasted vector. 
+ + // Simply need to zero out the upper bits of each lane of the bitcasted + // vector. + return And(BitCast(d_to, v), + Set(d_to, static_cast>(LimitsMax>()))); +#endif +} + +// Unsigned to signed: Same as unsigned->unsigned PromoteEvenTo/PromoteOddTo +// followed by BitCast to signed +template +HWY_INLINE VFromD PromoteEvenTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { + const RebindToUnsigned du_to; + return BitCast(d_to, + PromoteEvenTo(hwy::UnsignedTag(), hwy::SizeTag(), + hwy::UnsignedTag(), du_to, v)); +} + +template +HWY_INLINE VFromD PromoteOddTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { + const RebindToUnsigned du_to; + return BitCast(d_to, + PromoteOddTo(hwy::UnsignedTag(), hwy::SizeTag(), + hwy::UnsignedTag(), du_to, v)); +} + +// BF16->F32 PromoteEvenTo + +// NOTE: It is possible for FromTypeTag to be hwy::SignedTag or hwy::UnsignedTag +// instead of hwy::FloatTag on targets that use scalable vectors. + +// VBF16 is considered to be a bfloat16_t vector if TFromV is the same +// type as TFromV>> + +// The BF16->F32 PromoteEvenTo overload is only enabled if VBF16 is considered +// to be a bfloat16_t vector. +template >, + hwy::EnableIf, TFromV>()>* = nullptr> +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, DF32 d_to, + VBF16 v) { + const RebindToUnsigned du_to; +#if HWY_IS_LITTLE_ENDIAN + // On little-endian platforms, need to shift left each lane of the bitcasted + // vector by 16 bits. + return BitCast(d_to, ShiftLeft<16>(BitCast(du_to, v))); +#else + // On big-endian platforms, the even lanes of the source vector are already + // in the upper 16 bits of the lanes of the bitcasted vector. + + // Need to simply zero out the lower 16 bits of each lane of the bitcasted + // vector. + return BitCast(d_to, + And(BitCast(du_to, v), Set(du_to, uint32_t{0xFFFF0000u}))); +#endif +} + +// BF16->F32 PromoteOddTo + +// NOTE: It is possible for FromTypeTag to be hwy::SignedTag or hwy::UnsignedTag +// instead of hwy::FloatTag on targets that use scalable vectors. + +// VBF16 is considered to be a bfloat16_t vector if TFromV is the same +// type as TFromV>> + +// The BF16->F32 PromoteEvenTo overload is only enabled if VBF16 is considered +// to be a bfloat16_t vector. +template >, + hwy::EnableIf, TFromV>()>* = nullptr> +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, DF32 d_to, + VBF16 v) { + const RebindToUnsigned du_to; +#if HWY_IS_LITTLE_ENDIAN + // On little-endian platforms, the odd lanes of the source vector are already + // in the upper 16 bits of the lanes of the bitcasted vector. + + // Need to simply zero out the lower 16 bits of each lane of the bitcasted + // vector. + return BitCast(d_to, + And(BitCast(du_to, v), Set(du_to, uint32_t{0xFFFF0000u}))); +#else + // On big-endian platforms, need to shift left each lane of the bitcasted + // vector by 16 bits. 
+ return BitCast(d_to, ShiftLeft<16>(BitCast(du_to, v))); +#endif +} + +// Default PromoteEvenTo/PromoteOddTo implementations +template +HWY_INLINE VFromD PromoteEvenTo( + ToTypeTag /*to_type_tag*/, hwy::SizeTag /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + return PromoteLowerTo(d_to, v); +} + +template +HWY_INLINE VFromD PromoteEvenTo( + ToTypeTag /*to_type_tag*/, hwy::SizeTag /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const DFromV d; + return PromoteLowerTo(d_to, ConcatEven(d, v, v)); +} + +template +HWY_INLINE VFromD PromoteOddTo( + ToTypeTag /*to_type_tag*/, hwy::SizeTag /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const DFromV d; + return PromoteLowerTo(d_to, ConcatOdd(d, v, v)); +} + +} // namespace detail + +template )), + class V2 = VFromD, D>>, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(V2))> +HWY_API VFromD PromoteEvenTo(D d, V v) { + return detail::PromoteEvenTo(hwy::TypeTag>(), + hwy::SizeTag)>(), + hwy::TypeTag>(), d, v); +} + +template )), + class V2 = VFromD, D>>, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(V2))> +HWY_API VFromD PromoteOddTo(D d, V v) { + return detail::PromoteOddTo(hwy::TypeTag>(), + hwy::SizeTag)>(), + hwy::TypeTag>(), d, v); +} +#endif // HWY_TARGET != HWY_SCALAR + +#ifdef HWY_INSIDE_END_NAMESPACE +#undef HWY_INSIDE_END_NAMESPACE +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); +#endif diff --git a/third_party/aom/third_party/highway/hwy/ops/loongarch_lsx-inl.h b/third_party/aom/third_party/highway/hwy/ops/loongarch_lsx-inl.h new file mode 100644 index 000000000000..035e38b97847 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/loongarch_lsx-inl.h @@ -0,0 +1,16 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TODO: fill \ No newline at end of file diff --git a/third_party/aom/third_party/highway/hwy/ops/ppc_vsx-inl.h b/third_party/aom/third_party/highway/hwy/ops/ppc_vsx-inl.h new file mode 100644 index 000000000000..02de0175a087 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/ppc_vsx-inl.h @@ -0,0 +1,7409 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit vectors for VSX/Z14 +// External include guard in highway.h - see comment there. 
+ +#if HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 +#define HWY_S390X_HAVE_Z14 1 +#else +#define HWY_S390X_HAVE_Z14 0 +#endif + +#pragma push_macro("vector") +#pragma push_macro("pixel") +#pragma push_macro("bool") + +#undef vector +#undef pixel +#undef bool + +#if HWY_S390X_HAVE_Z14 +#include +#else +#include +#endif + +#pragma pop_macro("vector") +#pragma pop_macro("pixel") +#pragma pop_macro("bool") + +#include "third_party/highway/hwy/ops/shared-inl.h" + +// clang's altivec.h gates some intrinsics behind #ifdef __POWER10_VECTOR__, and +// some GCC do the same for _ARCH_PWR10. +// This means we can only use POWER10-specific intrinsics in static dispatch +// mode (where the -mpower10-vector compiler flag is passed). Same for PPC9. +// On other compilers, the usual target check is sufficient. +#if !HWY_S390X_HAVE_Z14 && HWY_TARGET <= HWY_PPC9 && \ + (defined(_ARCH_PWR9) || defined(__POWER9_VECTOR__)) +#define HWY_PPC_HAVE_9 1 +#else +#define HWY_PPC_HAVE_9 0 +#endif + +#if !HWY_S390X_HAVE_Z14 && HWY_TARGET <= HWY_PPC10 && \ + (defined(_ARCH_PWR10) || defined(__POWER10_VECTOR__)) +#define HWY_PPC_HAVE_10 1 +#else +#define HWY_PPC_HAVE_10 0 +#endif + +#if HWY_S390X_HAVE_Z14 && HWY_TARGET <= HWY_Z15 && __ARCH__ >= 13 +#define HWY_S390X_HAVE_Z15 1 +#else +#define HWY_S390X_HAVE_Z15 0 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template +struct Raw128; + +// Each Raw128 specialization defines the following typedefs: +// - type: +// the backing Altivec/VSX raw vector type of the Vec128 type +// - RawBoolVec: +// the backing Altivec/VSX raw __bool vector type of the Mask128 type +// - RawT: +// the lane type for intrinsics, in particular vec_splat +// - AlignedRawVec: +// the 128-bit GCC/Clang vector type for aligned loads/stores +// - UnalignedRawVec: +// the 128-bit GCC/Clang vector type for unaligned loads/stores +#define HWY_VSX_RAW128(LANE_TYPE, RAW_VECT_LANE_TYPE, RAW_BOOL_VECT_LANE_TYPE) \ + template <> \ + struct Raw128 { \ + using type = __vector RAW_VECT_LANE_TYPE; \ + using RawBoolVec = __vector __bool RAW_BOOL_VECT_LANE_TYPE; \ + using RawT = RAW_VECT_LANE_TYPE; \ + typedef LANE_TYPE AlignedRawVec \ + __attribute__((__vector_size__(16), __aligned__(16), __may_alias__)); \ + typedef LANE_TYPE UnalignedRawVec __attribute__(( \ + __vector_size__(16), __aligned__(alignof(LANE_TYPE)), __may_alias__)); \ + }; + +HWY_VSX_RAW128(int8_t, signed char, char) +HWY_VSX_RAW128(uint8_t, unsigned char, char) +HWY_VSX_RAW128(int16_t, signed short, short) // NOLINT(runtime/int) +HWY_VSX_RAW128(uint16_t, unsigned short, short) // NOLINT(runtime/int) +HWY_VSX_RAW128(int32_t, signed int, int) +HWY_VSX_RAW128(uint32_t, unsigned int, int) +HWY_VSX_RAW128(int64_t, signed long long, long long) // NOLINT(runtime/int) +HWY_VSX_RAW128(uint64_t, unsigned long long, long long) // NOLINT(runtime/int) +HWY_VSX_RAW128(float, float, int) +HWY_VSX_RAW128(double, double, long long) // NOLINT(runtime/int) + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +#undef HWY_VSX_RAW128 + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +template +using Vec16 = Vec128; + +// FF..FF or 0. +template +struct Mask128 { + typename detail::Raw128::RawBoolVec raw; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = N; // only for DFromM +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ Zero + +// Returns an all-zero vector/part. +template > +HWY_API Vec128 Zero(D /* tag */) { + // There is no vec_splats for 64-bit, so we cannot rely on casting the 0 + // argument in order to select the correct overload. We instead cast the + // return vector type; see also the comment in BitCast. + return Vec128{ + reinterpret_cast::type>(vec_splats(0))}; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ BitCast + +template +HWY_API VFromD BitCast(D /*d*/, + Vec128().MaxLanes()> v) { + // C-style casts are not sufficient when compiling with + // -fno-lax-vector-conversions, which will be the future default in Clang, + // but reinterpret_cast is. + return VFromD{ + reinterpret_cast>::type>(v.raw)}; +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D /*d*/, FromV v) { + // C-style casts are not sufficient when compiling with + // -fno-lax-vector-conversions, which will be the future default in Clang, + // but reinterpret_cast is. + return VFromD{ + reinterpret_cast>::type>(v.raw)}; +} + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". +template )> +HWY_API VFromD Set(D /* tag */, TFromD t) { + using RawLane = typename detail::Raw128>::RawT; + return VFromD{vec_splats(static_cast(t))}; +} + +template )> +HWY_API VFromD Set(D d, TFromD t) { + const RebindToUnsigned du; + return BitCast(d, Set(du, BitCastScalar>(t))); +} + +// Returns a vector with uninitialized elements. +template +HWY_API VFromD Undefined(D d) { +#if HWY_COMPILER_GCC_ACTUAL + // Suppressing maybe-uninitialized both here and at the caller does not work, + // so initialize. + return Zero(d); +#elif HWY_HAS_BUILTIN(__builtin_nondeterministic_value) + return VFromD{__builtin_nondeterministic_value(Zero(d).raw)}; +#else + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + typename detail::Raw128>::type raw; + return VFromD{raw}; + HWY_DIAGNOSTICS(pop) +#endif +} + +// ------------------------------ GetLane + +// Gets the single value stored in a vector/part. 
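// Before GetLane below, a per-lane scalar model of the BitCast defined above
// (a sketch only; the function name and the sample constant are
// illustrative): the vector op reinterprets the 128 bits without any value
// conversion, and memcpy is the scalar equivalent that, like the
// reinterpret_cast used above, stays clear of strict-aliasing problems.
#include <cstdint>
#include <cstring>

static float BitCastLaneExample(uint32_t bits) {
  float f;
  static_assert(sizeof(f) == sizeof(bits), "lane sizes must match");
  std::memcpy(&f, &bits, sizeof(f));  // e.g. 0x3F800000u -> 1.0f
  return f;
}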
+ +template +HWY_API T GetLane(Vec128 v) { + return static_cast(v.raw[0]); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + const typename detail::Raw128>::type raw = { + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15}; + return VFromD{raw}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const typename detail::Raw128>::type raw = {t0, t1, t2, t3, + t4, t5, t6, t7}; + return VFromD{raw}; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToUnsigned du; + return BitCast( + d, Dup128VecFromValues( + du, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + const typename detail::Raw128>::type raw = {t0, t1, t2, t3}; + return VFromD{raw}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + const typename detail::Raw128>::type raw = {t0, t1}; + return VFromD{raw}; +} + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU{BitCast(du, a).raw & BitCast(du, b).raw}); +#else + return BitCast(d, VU{vec_and(BitCast(du, a).raw, BitCast(du, b).raw)}); +#endif +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
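// A scalar model of this argument convention (a sketch; the name is
// illustrative): the first operand is the one that gets complemented. Note
// that vec_andc(a, b) computes a & ~b, i.e. it complements its *second*
// operand, which is why the implementation below passes (mask, not_mask) in
// swapped order.
#include <cstdint>

static uint32_t AndNotScalar(uint32_t not_mask, uint32_t mask) {
  return static_cast<uint32_t>(~not_mask & mask);
}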
+template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast( + d, VU{vec_andc(BitCast(du, mask).raw, BitCast(du, not_mask).raw)}); +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU{BitCast(du, a).raw | BitCast(du, b).raw}); +#else + return BitCast(d, VU{vec_or(BitCast(du, a).raw, BitCast(du, b).raw)}); +#endif +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU{BitCast(du, a).raw ^ BitCast(du, b).raw}); +#else + return BitCast(d, VU{vec_xor(BitCast(du, a).raw, BitCast(du, b).raw)}); +#endif +} + +// ------------------------------ Not +template +HWY_API Vec128 Not(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{vec_nor(BitCast(du, v).raw, BitCast(du, v).raw)}); +} + +// ------------------------------ IsConstantRawAltivecVect +namespace detail { + +template +static HWY_INLINE bool IsConstantRawAltivecVect( + hwy::SizeTag<1> /* lane_size_tag */, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]) && + __builtin_constant_p(v[8]) && __builtin_constant_p(v[9]) && + __builtin_constant_p(v[10]) && __builtin_constant_p(v[11]) && + __builtin_constant_p(v[12]) && __builtin_constant_p(v[13]) && + __builtin_constant_p(v[14]) && __builtin_constant_p(v[15]); +} + +template +static HWY_INLINE bool IsConstantRawAltivecVect( + hwy::SizeTag<2> /* lane_size_tag */, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]); +} + +template +static HWY_INLINE bool IsConstantRawAltivecVect( + hwy::SizeTag<4> /* lane_size_tag */, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]); +} + +template +static HWY_INLINE bool IsConstantRawAltivecVect( + hwy::SizeTag<8> /* lane_size_tag */, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]); +} + +template +static HWY_INLINE bool IsConstantRawAltivecVect(RawV v) { + return IsConstantRawAltivecVect(hwy::SizeTag(), v); +} + +} // namespace detail + +// ------------------------------ TernaryLogic +#if HWY_PPC_HAVE_10 +namespace detail { + +// NOTE: the kTernLogOp bits of the PPC10 TernaryLogic operation are in reverse +// order of the kTernLogOp bits of AVX3 +// _mm_ternarylogic_epi64(a, b, c, kTernLogOp) +template +HWY_INLINE V TernaryLogic(V a, V b, V c) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const auto a_raw = BitCast(du, a).raw; + const auto b_raw = BitCast(du, b).raw; + const auto c_raw = BitCast(du, c).raw; + +#if HWY_COMPILER_GCC_ACTUAL + // Use inline assembly on GCC to work around GCC compiler bug + typename detail::Raw128>::type raw_ternlog_result; + __asm__("xxeval %x0,%x1,%x2,%x3,%4" + : "=wa"(raw_ternlog_result) + : "wa"(a_raw), "wa"(b_raw), "wa"(c_raw), + 
"n"(static_cast(kTernLogOp)) + :); +#else + const auto raw_ternlog_result = + vec_ternarylogic(a_raw, b_raw, c_raw, kTernLogOp); +#endif + + return BitCast(d, VU{raw_ternlog_result}); +} + +} // namespace detail +#endif // HWY_PPC_HAVE_10 + +// ------------------------------ Xor3 +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { +#if HWY_PPC_HAVE_10 +#if defined(__OPTIMIZE__) + if (static_cast(detail::IsConstantRawAltivecVect(x1.raw)) + + static_cast(detail::IsConstantRawAltivecVect(x2.raw)) + + static_cast(detail::IsConstantRawAltivecVect(x3.raw)) >= + 2) { + return Xor(x1, Xor(x2, x3)); + } else // NOLINT +#endif + { + return detail::TernaryLogic<0x69>(x1, x2, x3); + } +#else + return Xor(x1, Xor(x2, x3)); +#endif +} + +// ------------------------------ Or3 +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { +#if HWY_PPC_HAVE_10 +#if defined(__OPTIMIZE__) + if (static_cast(detail::IsConstantRawAltivecVect(o1.raw)) + + static_cast(detail::IsConstantRawAltivecVect(o2.raw)) + + static_cast(detail::IsConstantRawAltivecVect(o3.raw)) >= + 2) { + return Or(o1, Or(o2, o3)); + } else // NOLINT +#endif + { + return detail::TernaryLogic<0x7F>(o1, o2, o3); + } +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { +#if HWY_PPC_HAVE_10 +#if defined(__OPTIMIZE__) + if (detail::IsConstantRawAltivecVect(a1.raw) && + detail::IsConstantRawAltivecVect(a2.raw)) { + return Or(o, And(a1, a2)); + } else // NOLINT +#endif + { + return detail::TernaryLogic<0x1F>(o, a1, a2); + } +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{vec_sel(BitCast(du, no).raw, BitCast(du, yes).raw, + BitCast(du, mask).raw)}); +} + +// ------------------------------ BitwiseIfThenElse + +#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return IfVecThenElse(mask, yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(Vec128 a, Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(Vec128 a, Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(Vec128 a, Vec128 b) { + return Xor(a, b); +} + +// ================================================== SIGN + +// ------------------------------ Neg + +template +HWY_API Vec128 Neg(Vec128 v) { + // If T is an signed integer type, use Zero(d) - v instead of vec_neg to + // avoid undefined behavior in the case where v[i] == LimitsMin() + const DFromV d; + return Zero(d) - v; +} + +template +HWY_API Vec128 Neg(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return Xor(v, SignBit(DFromV())); +#else + return Vec128{vec_neg(v.raw)}; +#endif +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +// ------------------------------ Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(Vec128 v) { + // If T is a signed integer type, use Max(v, Neg(v)) instead of vec_abs to + // avoid undefined behavior in the case where v[i] == LimitsMin(). 
+ return Max(v, Neg(v)); +} + +template +HWY_API Vec128 Abs(Vec128 v) { + return Vec128{vec_abs(v.raw)}; +} + +// ------------------------------ CopySign + +#if HWY_S390X_HAVE_Z14 +template +HWY_API V CopySign(const V magn, const V sign) { + static_assert(IsFloat>(), "Only makes sense for floating-point"); + + const DFromV d; + const auto msb = SignBit(d); + + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + return BitwiseIfThenElse(msb, sign, magn); +} +#else // VSX +template +HWY_API Vec128 CopySign(Vec128 magn, + Vec128 sign) { + // Work around compiler bugs that are there with vec_cpsgn on older versions + // of GCC/Clang +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200 + return Vec128{__builtin_vec_copysign(magn.raw, sign.raw)}; +#elif HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 && \ + HWY_HAS_BUILTIN(__builtin_vsx_xvcpsgnsp) + return Vec128{__builtin_vsx_xvcpsgnsp(magn.raw, sign.raw)}; +#else + return Vec128{vec_cpsgn(sign.raw, magn.raw)}; +#endif +} + +template +HWY_API Vec128 CopySign(Vec128 magn, + Vec128 sign) { + // Work around compiler bugs that are there with vec_cpsgn on older versions + // of GCC/Clang +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200 + return Vec128{__builtin_vec_copysign(magn.raw, sign.raw)}; +#elif HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 && \ + HWY_HAS_BUILTIN(__builtin_vsx_xvcpsgndp) + return Vec128{__builtin_vsx_xvcpsgndp(magn.raw, sign.raw)}; +#else + return Vec128{vec_cpsgn(sign.raw, magn.raw)}; +#endif +} +#endif // HWY_S390X_HAVE_Z14 + +template +HWY_API Vec128 CopySignToAbs(Vec128 abs, Vec128 sign) { + // PPC8 can also handle abs < 0, so no extra action needed. + static_assert(IsFloat(), "Only makes sense for floating-point"); + return CopySign(abs, sign); +} + +// ================================================== MEMORY (1) + +// Note: type punning is safe because the types are tagged with may_alias. +// (https://godbolt.org/z/fqrWjfjsP) + +// ------------------------------ Load + +template > +HWY_API Vec128 Load(D /* tag */, const T* HWY_RESTRICT aligned) { +// Suppress the ignoring attributes warning that is generated by +// HWY_RCAST_ALIGNED(const LoadRaw*, aligned) with GCC +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4649, ignored "-Wignored-attributes") +#endif + + using LoadRaw = typename detail::Raw128::AlignedRawVec; + const LoadRaw* HWY_RESTRICT p = HWY_RCAST_ALIGNED(const LoadRaw*, aligned); + using ResultRaw = typename detail::Raw128::type; + return Vec128{reinterpret_cast(*p)}; + +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(pop) +#endif +} + +// Any <= 64 bit +template > +HWY_API VFromD Load(D d, const T* HWY_RESTRICT p) { + using BitsT = UnsignedFromSize; + + BitsT bits; + const Repartition d_bits; + CopyBytes(p, &bits); + return BitCast(d, Set(d_bits, bits)); +} + +// ================================================== MASK + +// ------------------------------ Mask + +// Mask and Vec are both backed by vector types (true = FF..FF). 
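// A scalar model of this mask representation (a true lane is all-ones, a
// false lane all-zero) and of the bitwise select that IfThenElse below maps
// to vec_sel (a sketch; the name is illustrative). The same blend, with the
// sign-bit pattern as the mask, is what the Z14 CopySign above builds on.
#include <cstdint>

static uint32_t IfThenElseScalarLane(uint32_t mask, uint32_t yes,
                                     uint32_t no) {
  // mask is assumed to be either 0 or 0xFFFFFFFFu, as comparisons produce.
  return (yes & mask) | (no & ~mask);
}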
+template +HWY_API Mask128 MaskFromVec(Vec128 v) { + using Raw = typename detail::Raw128::RawBoolVec; + return Mask128{reinterpret_cast(v.raw)}; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API Vec128 VecFromMask(Mask128 v) { + return Vec128{ + reinterpret_cast::type>(v.raw)}; +} + +template +HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { + return VFromD{ + reinterpret_cast>::type>(v.raw)}; +} + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{vec_sel( + BitCast(du, no).raw, BitCast(du, yes).raw, mask.raw)}); +} + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(Mask128 m) { + return Mask128{vec_nor(m.raw, m.raw)}; +} + +template +HWY_API Mask128 And(Mask128 a, Mask128 b) { +#if HWY_S390X_HAVE_Z14 + return Mask128{a.raw & b.raw}; +#else + return Mask128{vec_and(a.raw, b.raw)}; +#endif +} + +template +HWY_API Mask128 AndNot(Mask128 a, Mask128 b) { + return Mask128{vec_andc(b.raw, a.raw)}; +} + +template +HWY_API Mask128 Or(Mask128 a, Mask128 b) { +#if HWY_S390X_HAVE_Z14 + return Mask128{a.raw | b.raw}; +#else + return Mask128{vec_or(a.raw, b.raw)}; +#endif +} + +template +HWY_API Mask128 Xor(Mask128 a, Mask128 b) { +#if HWY_S390X_HAVE_Z14 + return Mask128{a.raw ^ b.raw}; +#else + return Mask128{vec_xor(a.raw, b.raw)}; +#endif +} + +template +HWY_API Mask128 ExclusiveNeither(Mask128 a, Mask128 b) { + return Mask128{vec_nor(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(Vec128 v, const int bits) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + +#if HWY_S390X_HAVE_Z14 + return BitCast(d, + VFromD{BitCast(du, v).raw + << Set(du, static_cast(bits)).raw}); +#else + // Do an unsigned vec_sl operation to avoid undefined behavior + return BitCast( + d, VFromD{ + vec_sl(BitCast(du, v).raw, Set(du, static_cast(bits)).raw)}); +#endif +} + +// ------------------------------ ShiftRightSame + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + using TU = typename detail::Raw128>::RawT; +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw >> vec_splats(static_cast(bits))}; +#else + return Vec128{vec_sr(v.raw, vec_splats(static_cast(bits)))}; +#endif +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { +#if HWY_S390X_HAVE_Z14 + using TI = typename detail::Raw128::RawT; + return Vec128{v.raw >> vec_splats(static_cast(bits))}; +#else + using TU = typename detail::Raw128>::RawT; + return Vec128{vec_sra(v.raw, vec_splats(static_cast(bits)))}; +#endif +} + +// ------------------------------ ShiftLeft + +template +HWY_API Vec128 ShiftLeft(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return ShiftLeftSame(v, kBits); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec128 ShiftRight(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return ShiftRightSame(v, kBits); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec128 BroadcastSignBit(Vec128 v) { + return ShiftRightSame(v, static_cast(sizeof(T) * 8 - 1)); 
+} + +// ================================================== SWIZZLE (1) + +// ------------------------------ TableLookupBytes +template +HWY_API Vec128 TableLookupBytes(Vec128 bytes, + Vec128 from) { + const Repartition> du8_from; + return Vec128{reinterpret_cast::type>( + vec_perm(bytes.raw, bytes.raw, BitCast(du8_from, from).raw))}; +} + +// ------------------------------ TableLookupBytesOr0 +// For all vector widths; Altivec/VSX needs zero out +template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { + const DFromV di; + Repartition di8; + const VI zeroOutMask = BitCast(di, BroadcastSignBit(BitCast(di8, from))); + return AndNot(zeroOutMask, TableLookupBytes(bytes, from)); +} + +// ------------------------------ Reverse +#if HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL && \ + HWY_COMPILER_GCC_ACTUAL < 900 +// Workaround for missing vec_reve on Z14 with GCC 8 or earlier +template , HWY_IF_LANES_GT_D(D, 1), + HWY_IF_T_SIZE_D(D, 1)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, + 5, 4, 3, 2, 1, 0))); +} + +template , HWY_IF_LANES_GT_D(D, 1), + HWY_IF_T_SIZE_D(D, 2)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, + 4, 5, 2, 3, 0, 1))); +} + +template , HWY_IF_LANES_GT_D(D, 1), + HWY_IF_T_SIZE_D(D, 4)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, + 6, 7, 0, 1, 2, 3))); +} + +template , HWY_IF_LANES_GT_D(D, 1), + HWY_IF_T_SIZE_D(D, 8)> +HWY_API Vec128 Reverse(D /* tag */, Vec128 v) { + return Vec128{vec_sld(v.raw, v.raw, 8)}; +} +#else +template , HWY_IF_LANES_GT_D(D, 1)> +HWY_API Vec128 Reverse(D /* tag */, Vec128 v) { + return Vec128{vec_reve(v.raw)}; +} +#endif + +// ------------------------------ Shuffles (Reverse) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + const __vector unsigned char kShuffle = {4, 5, 6, 7, 0, 1, 2, 3, + 12, 13, 14, 15, 8, 9, 10, 11}; + return Vec128{vec_perm(v.raw, v.raw, kShuffle)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. As with +// Intel's shuffle* intrinsics and InterleaveLower, the lower half of the output +// comes from the first argument. 
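// The byte tables in the detail::ShuffleTwo* helpers below rely on the
// vec_perm selector convention; a scalar model of that convention (a sketch
// that ignores the little-endian register-byte-order subtleties the tables
// already account for): selector bytes 0..15 pick from the first input,
// 16..31 from the second, and only the low 5 bits of each selector byte are
// used.
#include <cstdint>

static void PermBytesScalar(const uint8_t a[16], const uint8_t b[16],
                            const uint8_t sel[16], uint8_t out[16]) {
  for (int i = 0; i < 16; ++i) {
    const uint8_t k = static_cast<uint8_t>(sel[i] & 31u);
    out[i] = (k < 16) ? a[k] : b[k - 16];
  }
}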
+namespace detail { + +template +HWY_API Vec32 ShuffleTwo2301(Vec32 a, Vec32 b) { + const __vector unsigned char kShuffle16 = {1, 0, 19, 18}; + return Vec32{vec_perm(a.raw, b.raw, kShuffle16)}; +} +template +HWY_API Vec64 ShuffleTwo2301(Vec64 a, Vec64 b) { + const __vector unsigned char kShuffle = {2, 3, 0, 1, 22, 23, 20, 21}; + return Vec64{vec_perm(a.raw, b.raw, kShuffle)}; +} +template +HWY_API Vec128 ShuffleTwo2301(Vec128 a, Vec128 b) { + const __vector unsigned char kShuffle = {4, 5, 6, 7, 0, 1, 2, 3, + 28, 29, 30, 31, 24, 25, 26, 27}; + return Vec128{vec_perm(a.raw, b.raw, kShuffle)}; +} + +template +HWY_API Vec32 ShuffleTwo1230(Vec32 a, Vec32 b) { + const __vector unsigned char kShuffle = {0, 3, 18, 17}; + return Vec32{vec_perm(a.raw, b.raw, kShuffle)}; +} +template +HWY_API Vec64 ShuffleTwo1230(Vec64 a, Vec64 b) { + const __vector unsigned char kShuffle = {0, 1, 6, 7, 20, 21, 18, 19}; + return Vec64{vec_perm(a.raw, b.raw, kShuffle)}; +} +template +HWY_API Vec128 ShuffleTwo1230(Vec128 a, Vec128 b) { + const __vector unsigned char kShuffle = {0, 1, 2, 3, 12, 13, 14, 15, + 24, 25, 26, 27, 20, 21, 22, 23}; + return Vec128{vec_perm(a.raw, b.raw, kShuffle)}; +} + +template +HWY_API Vec32 ShuffleTwo3012(Vec32 a, Vec32 b) { + const __vector unsigned char kShuffle = {2, 1, 16, 19}; + return Vec32{vec_perm(a.raw, b.raw, kShuffle)}; +} +template +HWY_API Vec64 ShuffleTwo3012(Vec64 a, Vec64 b) { + const __vector unsigned char kShuffle = {4, 5, 2, 3, 16, 17, 22, 23}; + return Vec64{vec_perm(a.raw, b.raw, kShuffle)}; +} +template +HWY_API Vec128 ShuffleTwo3012(Vec128 a, Vec128 b) { + const __vector unsigned char kShuffle = {8, 9, 10, 11, 4, 5, 6, 7, + 16, 17, 18, 19, 28, 29, 30, 31}; + return Vec128{vec_perm(a.raw, b.raw, kShuffle)}; +} + +} // namespace detail + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(Vec128 v) { + const Full128 d; + const Full128 du64; + return BitCast(d, Reverse(du64, BitCast(du64, v))); +} +template +HWY_API Vec128 Shuffle01(Vec128 v) { + return Reverse(Full128(), v); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(Vec128 v) { +#if HWY_IS_LITTLE_ENDIAN + return Vec128{vec_sld(v.raw, v.raw, 12)}; +#else + return Vec128{vec_sld(v.raw, v.raw, 4)}; +#endif +} +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(Vec128 v) { +#if HWY_IS_LITTLE_ENDIAN + return Vec128{vec_sld(v.raw, v.raw, 4)}; +#else + return Vec128{vec_sld(v.raw, v.raw, 12)}; +#endif +} + +template +HWY_API Vec128 Shuffle0123(Vec128 v) { + return Reverse(Full128(), v); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API MFromD RebindMask(DTo /*dto*/, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{vec_cmpeq(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// This cannot have T as a template argument, otherwise it is not more +// specialized than rewritten operator== in C++20, leading to compile +// errors: https://gcc.godbolt.org/z/xsrPhPvPT. 
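// A brief aside before the operator!= overloads described in the note above:
// a scalar model of the Shuffle lane notation used earlier in this section
// (a sketch, assuming four u32 lanes with lane 0 least-significant). The
// digits name the result's lanes from most- to least-significant, so
// Shuffle0321 is a rotate right by one lane and Shuffle2103 the matching
// rotate left.
#include <cstdint>

static void Shuffle0321Scalar(const uint32_t in[4], uint32_t out[4]) {
  out[0] = in[1];
  out[1] = in[2];
  out[2] = in[3];
  out[3] = in[0];  // previous least-significant lane is now most-significant
}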
+template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +// ------------------------------ Strict inequality + +template +HWY_INLINE Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{vec_cmpgt(a.raw, b.raw)}; +} + +// ------------------------------ Weak inequality + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return Mask128{vec_cmpge(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return Not(b > a); +} + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + return b > a; +} + +template +HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { + return b >= a; +} + +// ================================================== MEMORY (2) + +// ------------------------------ Load +template > +HWY_API Vec128 LoadU(D /* tag */, const T* HWY_RESTRICT p) { + using LoadRaw = typename detail::Raw128::UnalignedRawVec; + const LoadRaw* HWY_RESTRICT praw = reinterpret_cast(p); + using ResultRaw = typename detail::Raw128::type; + return Vec128{reinterpret_cast(*praw)}; +} + +// For < 128 bit, LoadU == Load. +template > +HWY_API VFromD LoadU(D d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template > +HWY_API VFromD LoadDup128(D d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +#if (HWY_PPC_HAVE_9 && HWY_ARCH_PPC_64) || HWY_S390X_HAVE_Z14 +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +template > +HWY_API VFromD LoadN(D d, const T* HWY_RESTRICT p, + size_t max_lanes_to_load) { +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(max_lanes_to_load) && max_lanes_to_load == 0) { + return Zero(d); + } + + if (__builtin_constant_p(max_lanes_to_load >= HWY_MAX_LANES_D(D)) && + max_lanes_to_load >= HWY_MAX_LANES_D(D)) { + return LoadU(d, p); + } +#endif + + const size_t num_of_bytes_to_load = + HWY_MIN(max_lanes_to_load, HWY_MAX_LANES_D(D)) * sizeof(TFromD); + const Repartition du8; +#if HWY_S390X_HAVE_Z14 + return (num_of_bytes_to_load > 0) + ? 
BitCast(d, VFromD{vec_load_len( + const_cast( + reinterpret_cast(p)), + static_cast(num_of_bytes_to_load - 1))}) + : Zero(d); +#else + return BitCast( + d, + VFromD{vec_xl_len( + const_cast(reinterpret_cast(p)), + num_of_bytes_to_load)}); +#endif +} + +template > +HWY_API VFromD LoadNOr(VFromD no, D d, const T* HWY_RESTRICT p, + size_t max_lanes_to_load) { +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(max_lanes_to_load) && max_lanes_to_load == 0) { + return no; + } + + if (__builtin_constant_p(max_lanes_to_load >= HWY_MAX_LANES_D(D)) && + max_lanes_to_load >= HWY_MAX_LANES_D(D)) { + return LoadU(d, p); + } +#endif + + return IfThenElse(FirstN(d, max_lanes_to_load), + LoadN(d, p, max_lanes_to_load), no); +} + +#endif // HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14 + +// Returns a vector with lane i=[0, N) set to "first" + i. +namespace detail { + +template +HWY_INLINE VFromD Iota0(D d) { + constexpr __vector unsigned char kU8Iota0 = {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15}; + return BitCast(d, VFromD>{kU8Iota0}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + constexpr __vector unsigned short kU16Iota0 = {0, 1, 2, 3, 4, 5, 6, 7}; + return BitCast(d, VFromD>{kU16Iota0}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + constexpr __vector unsigned int kU32Iota0 = {0, 1, 2, 3}; + return BitCast(d, VFromD>{kU32Iota0}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + constexpr __vector unsigned long long kU64Iota0 = {0, 1}; + return BitCast(d, VFromD>{kU64Iota0}); +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + constexpr __vector float kF32Iota0 = {0.0f, 1.0f, 2.0f, 3.0f}; + return VFromD{kF32Iota0}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + constexpr __vector double kF64Iota0 = {0.0, 1.0}; + return VFromD{kF64Iota0}; +} + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + return detail::Iota0(d) + Set(d, static_cast>(first)); +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API MFromD FirstN(D d, size_t num) { + const RebindToUnsigned du; + using TU = TFromD; + return RebindMask(d, Iota(du, 0) < Set(du, static_cast(num))); +} + +// ------------------------------ MaskedLoad +template > +HWY_API VFromD MaskedLoad(MFromD m, D d, const T* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +// ------------------------------ MaskedLoadOr +template > +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const T* HWY_RESTRICT p) { + return IfThenElse(m, LoadU(d, p), v); +} + +// ------------------------------ Store + +template > +HWY_API void Store(Vec128 v, D /* tag */, T* HWY_RESTRICT aligned) { +// Suppress the ignoring attributes warning that is generated by +// HWY_RCAST_ALIGNED(StoreRaw*, aligned) with GCC +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4649, ignored "-Wignored-attributes") +#endif + + using StoreRaw = typename detail::Raw128::AlignedRawVec; + *HWY_RCAST_ALIGNED(StoreRaw*, aligned) = reinterpret_cast(v.raw); + +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(pop) +#endif +} + +template > +HWY_API void StoreU(Vec128 v, D /* tag */, T* HWY_RESTRICT p) { + using StoreRaw = typename detail::Raw128::UnalignedRawVec; + *reinterpret_cast(p) = reinterpret_cast(v.raw); +} + +template > +HWY_API void Store(VFromD v, D d, T* HWY_RESTRICT p) { + using BitsT = UnsignedFromSize; + + const Repartition d_bits; + const BitsT bits = GetLane(BitCast(d_bits, v)); + CopyBytes(&bits, p); +} + +// For < 128 bit, StoreU == Store. 
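// Before the small-vector StoreU below, a scalar model of the LoadN contract
// implemented above with vec_xl_len/vec_load_len (a sketch; names are
// illustrative): at most min(max_lanes_to_load, N) lanes are read from p, and
// any remaining lanes of the result are zero, so reading near the end of a
// buffer never touches bytes past the requested length.
#include <cstddef>

template <typename T, size_t N>
static void LoadNScalar(const T* p, size_t max_lanes_to_load, T out[N]) {
  const size_t n = max_lanes_to_load < N ? max_lanes_to_load : N;
  for (size_t i = 0; i < n; ++i) out[i] = p[i];
  for (size_t i = n; i < N; ++i) out[i] = T{0};
}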
+template > +HWY_API void StoreU(VFromD v, D d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +#if (HWY_PPC_HAVE_9 && HWY_ARCH_PPC_64) || HWY_S390X_HAVE_Z14 + +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(max_lanes_to_store) && max_lanes_to_store == 0) { + return; + } + + if (__builtin_constant_p(max_lanes_to_store >= HWY_MAX_LANES_D(D)) && + max_lanes_to_store >= HWY_MAX_LANES_D(D)) { + StoreU(v, d, p); + return; + } +#endif + + const size_t num_of_bytes_to_store = + HWY_MIN(max_lanes_to_store, HWY_MAX_LANES_D(D)) * sizeof(TFromD); + const Repartition du8; +#if HWY_S390X_HAVE_Z14 + if (num_of_bytes_to_store > 0) { + vec_store_len(BitCast(du8, v).raw, reinterpret_cast(p), + static_cast(num_of_bytes_to_store - 1)); + } +#else + vec_xst_len(BitCast(du8, v).raw, reinterpret_cast(p), + num_of_bytes_to_store); +#endif +} +#endif + +// ------------------------------ BlendedStore + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const VFromD old = LoadU(d, p); + StoreU(IfThenElse(RebindMask(d, m), v, old), d, p); +} + +// ================================================== ARITHMETIC + +namespace detail { +// If TFromD is an integer type, detail::RebindToUnsignedIfNotFloat +// rebinds D to MakeUnsigned>. + +// Otherwise, if TFromD is a floating-point type (including F16 and BF16), +// detail::RebindToUnsignedIfNotFloat is the same as D. +template +using RebindToUnsignedIfNotFloat = + hwy::If<(!hwy::IsFloat>() && !hwy::IsSpecialFloat>()), + RebindToUnsigned, D>; +} // namespace detail + +// ------------------------------ Addition + +template +HWY_API Vec128 operator+(Vec128 a, Vec128 b) { + const DFromV d; + const detail::RebindToUnsignedIfNotFloat d_arith; + + // If T is an integer type, do an unsigned vec_add to avoid undefined behavior +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VFromD{BitCast(d_arith, a).raw + + BitCast(d_arith, b).raw}); +#else + return BitCast(d, VFromD{vec_add( + BitCast(d_arith, a).raw, BitCast(d_arith, b).raw)}); +#endif +} + +// ------------------------------ Subtraction + +template +HWY_API Vec128 operator-(Vec128 a, Vec128 b) { + const DFromV d; + const detail::RebindToUnsignedIfNotFloat d_arith; + + // If T is an integer type, do an unsigned vec_sub to avoid undefined behavior +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VFromD{BitCast(d_arith, a).raw - + BitCast(d_arith, b).raw}); +#else + return BitCast(d, VFromD{vec_sub( + BitCast(d_arith, a).raw, BitCast(d_arith, b).raw)}); +#endif +} + +// ------------------------------ SumsOf8 +template )> +HWY_API VFromD>> SumsOf8(V v) { + return SumsOf2(SumsOf4(v)); +} + +template )> +HWY_API VFromD>> SumsOf8(V v) { +#if HWY_S390X_HAVE_Z14 + const DFromV di8; + const RebindToUnsigned du8; + const RepartitionToWideX3 di64; + + return BitCast(di64, SumsOf8(BitCast(du8, Xor(v, SignBit(di8))))) + + Set(di64, int64_t{-1024}); +#else + return SumsOf2(SumsOf4(v)); +#endif +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. 
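// A scalar model of the unsigned Z14 fallback below, Add(a, Min(b, Not(a)))
// (a sketch for u8; the name is illustrative): since ~a equals 255 - a,
// clamping b to ~a leaves exactly the headroom remaining in a, so the sum
// never wraps and lands on 255 whenever a + b would exceed it.
#include <cstdint>

static uint8_t SaturatedAddU8Scalar(uint8_t a, uint8_t b) {
  const uint8_t headroom = static_cast<uint8_t>(~a);  // 255 - a
  const uint8_t clamped = (b < headroom) ? b : headroom;
  return static_cast<uint8_t>(a + clamped);  // equals min(a + b, 255)
}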
+ +#if HWY_S390X_HAVE_Z14 +// Z14/Z15/Z16 does not have I8/U8/I16/U16 SaturatedAdd instructions unlike most +// other integer SIMD instruction sets + +template +HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { + return Add(a, Min(b, Not(a))); +} + +template +HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); + const auto overflow_result = Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +#else // VSX + +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +template +HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { + return Vec128{vec_adds(a.raw, b.raw)}; +} +#endif // HWY_S390X_HAVE_Z14 + +#if HWY_PPC_HAVE_10 + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = + BroadcastSignBit(detail::TernaryLogic<0x42>(a, b, sum)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +#endif // HWY_PPC_HAVE_10 + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +#if HWY_S390X_HAVE_Z14 +// Z14/Z15/Z16 does not have I8/U8/I16/U16 SaturatedSub instructions unlike most +// other integer SIMD instruction sets + +template +HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { + return Sub(a, Min(a, b)); +} + +template +HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); + const auto overflow_result = Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, diff); +} + +#else // VSX + +template +HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { + return Vec128{vec_subs(a.raw, b.raw)}; +} +#endif // HWY_S390X_HAVE_Z14 + +#if HWY_PPC_HAVE_10 + +template )> +HWY_API V SaturatedSub(V a, V b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = + BroadcastSignBit(detail::TernaryLogic<0x18>(a, b, diff)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, diff); +} + +#endif // HWY_PPC_HAVE_10 + +// ------------------------------ AverageRound + +// Returns (a + b + 1) / 2 + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#if HWY_S390X_HAVE_Z14 +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +#define HWY_PPC_IF_AVERAGE_ROUND_T(T) void* = nullptr +#else // !HWY_S390X_HAVE_Z14 +#define HWY_PPC_IF_AVERAGE_ROUND_T(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4)) +#endif // HWY_S390X_HAVE_Z14 + +template +HWY_API Vec128 AverageRound(Vec128 a, Vec128 b) { + return Vec128{vec_avg(a.raw, b.raw)}; +} + +#undef HWY_PPC_IF_AVERAGE_ROUND_T + +// ------------------------------ Multiplication + +// Per-target flags to prevent 
generic_ops-inl.h defining 8/64-bit operator*. +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + const DFromV d; + const detail::RebindToUnsignedIfNotFloat d_arith; + + // If T is an integer type, do an unsigned vec_mul to avoid undefined behavior +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VFromD{BitCast(d_arith, a).raw * + BitCast(d_arith, b).raw}); +#else + return BitCast(d, VFromD{vec_mul( + BitCast(d_arith, a).raw, BitCast(d_arith, b).raw)}); +#endif +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. + +#if HWY_S390X_HAVE_Z14 +#define HWY_PPC_IF_MULHIGH_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4)) +#define HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH(T) \ + hwy::EnableIf()>* = nullptr +#elif HWY_PPC_HAVE_10 +#define HWY_PPC_IF_MULHIGH_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8)) +#define HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2)) +#else +#define HWY_PPC_IF_MULHIGH_USING_VEC_MULH(T) \ + hwy::EnableIf()>* = nullptr +#define HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4)) +#endif + +#if HWY_S390X_HAVE_Z14 || HWY_PPC_HAVE_10 +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + return Vec128{vec_mulh(a.raw, b.raw)}; +} +#endif + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + const auto p_even = MulEven(a, b); + +#if HWY_IS_LITTLE_ENDIAN + const auto p_even_full = ResizeBitCast(Full128(), p_even); + return Vec128{ + vec_sld(p_even_full.raw, p_even_full.raw, 16 - sizeof(T))}; +#else + const DFromV d; + return ResizeBitCast(d, p_even); +#endif +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + const DFromV d; + + const auto p_even = BitCast(d, MulEven(a, b)); + const auto p_odd = BitCast(d, MulOdd(a, b)); + +#if HWY_IS_LITTLE_ENDIAN + return InterleaveOdd(d, p_even, p_odd); +#else + return InterleaveEven(d, p_even, p_odd); +#endif +} + +#if !HWY_PPC_HAVE_10 +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T p_hi; + Mul128(GetLane(a), GetLane(b), &p_hi); + return Set(Full64(), p_hi); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + const DFromV d; + const Half dh; + return Combine(d, MulHigh(UpperHalf(dh, a), UpperHalf(dh, b)), + MulHigh(LowerHalf(dh, a), LowerHalf(dh, b))); +} +#endif // !HWY_PPC_HAVE_10 + +#undef HWY_PPC_IF_MULHIGH_USING_VEC_MULH +#undef HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH + +// Multiplies even lanes (0, 2, ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +template +HWY_API Vec128, (N + 1) / 2> MulEven(Vec128 a, + Vec128 b) { + return Vec128, (N + 1) / 2>{vec_mule(a.raw, b.raw)}; +} + +// Multiplies odd lanes (1, 3, ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. 
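// Before MulOdd below, a scalar model of MulHigh for 16-bit lanes (a sketch;
// the name is illustrative): the upper half of the widened product, which is
// what the MulEven/MulOdd + interleave path above assembles per lane.
#include <cstdint>

static int16_t MulHigh16Scalar(int16_t a, int16_t b) {
  const int32_t wide = static_cast<int32_t>(a) * static_cast<int32_t>(b);
  return static_cast<int16_t>(wide >> 16);  // upper 16 bits of the product
}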
+template +HWY_API Vec128, (N + 1) / 2> MulOdd(Vec128 a, + Vec128 b) { + return Vec128, (N + 1) / 2>{vec_mulo(a.raw, b.raw)}; +} + +// ------------------------------ Rol/Ror + +#ifdef HWY_NATIVE_ROL_ROR_8 +#undef HWY_NATIVE_ROL_ROR_8 +#else +#define HWY_NATIVE_ROL_ROR_8 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_16 +#undef HWY_NATIVE_ROL_ROR_16 +#else +#define HWY_NATIVE_ROL_ROR_16 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_32_64 +#undef HWY_NATIVE_ROL_ROR_32_64 +#else +#define HWY_NATIVE_ROL_ROR_32_64 +#endif + +template +HWY_API Vec128 Rol(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{vec_rl(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToSigned di; + return Rol(a, BitCast(d, Neg(BitCast(di, b)))); +} + +// ------------------------------ RotateRight +template +HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + + return (kBits == 0) + ? v + : Rol(v, Set(d, static_cast(static_cast(kSizeInBits) - + kBits))); +} + +// ------------------------------ RotateLeftSame/RotateRightSame +#ifdef HWY_NATIVE_ROL_ROR_SAME_8 +#undef HWY_NATIVE_ROL_ROR_SAME_8 +#else +#define HWY_NATIVE_ROL_ROR_SAME_8 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_SAME_16 +#undef HWY_NATIVE_ROL_ROR_SAME_16 +#else +#define HWY_NATIVE_ROL_ROR_SAME_16 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64 +#undef HWY_NATIVE_ROL_ROR_SAME_32_64 +#else +#define HWY_NATIVE_ROL_ROR_SAME_32_64 +#endif + +template +HWY_API Vec128 RotateLeftSame(Vec128 v, int bits) { + const DFromV d; + return Rol(v, Set(d, static_cast(static_cast(bits)))); +} + +template +HWY_API Vec128 RotateRightSame(Vec128 v, int bits) { + const DFromV d; + return Rol(v, Set(d, static_cast(0u - static_cast(bits)))); +} + +// ------------------------------ IfNegativeThenElse + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + + const DFromV d; +#if HWY_PPC_HAVE_10 + const RebindToUnsigned du; + return BitCast( + d, VFromD{vec_blendv( + BitCast(du, no).raw, BitCast(du, yes).raw, BitCast(du, v).raw)}); +#else + const RebindToSigned di; + return IfVecThenElse(BitCast(d, BroadcastSignBit(BitCast(di, v))), yes, no); +#endif +} + +#if HWY_PPC_HAVE_10 +#ifdef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#undef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#else +#define HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#endif + +#ifdef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#undef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#else +#define HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#endif + +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + const DFromV d; + return IfNegativeThenElse(v, yes, Zero(d)); +} + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + const DFromV d; + return IfNegativeThenElse(v, Zero(d), no); +} +#endif + +// generic_ops takes care of integer T. 
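// Looking back at the rotate section above, a scalar model of the identities
// it uses (a sketch for 32-bit lanes; names are illustrative): a rotate is
// the OR of two complementary shifts, and Ror(a, b) reduces to Rol with the
// negated count because rotating right by n equals rotating left by
// (32 - n) mod 32 -- which is also how RotateRight<kBits> picks its vec_rl
// argument.
#include <cstdint>

static uint32_t Rotl32Scalar(uint32_t x, uint32_t n) {
  // Masking both shift amounts keeps them in [0, 31] and avoids UB for n = 0.
  return (x << (n & 31u)) | (x >> ((32u - n) & 31u));
}

static uint32_t Rotr32Scalar(uint32_t x, uint32_t n) {
  return Rotl32Scalar(x, (32u - n) & 31u);
}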
+template +HWY_API Vec128 AbsDiff(Vec128 a, Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Vec128{vec_madd(mul.raw, x.raw, add.raw)}; +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + // NOTE: the vec_nmsub operation below computes -(mul * x - add), + // which is equivalent to add - mul * x in the round-to-nearest + // and round-towards-zero rounding modes + return Vec128{vec_nmsub(mul.raw, x.raw, add.raw)}; +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Vec128{vec_msub(mul.raw, x.raw, sub.raw)}; +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + // NOTE: The vec_nmadd operation below computes -(mul * x + sub), + // which is equivalent to -mul * x - sub in the round-to-nearest + // and round-towards-zero rounding modes + return Vec128{vec_nmadd(mul.raw, x.raw, sub.raw)}; +} + +// ------------------------------ Floating-point div +// Approximate reciprocal + +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { +#if HWY_S390X_HAVE_Z14 + return Vec128{a.raw / b.raw}; +#else + return Vec128{vec_div(a.raw, b.raw)}; +#endif +} + +template +HWY_API Vec128 ApproximateReciprocal(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + return Set(d, T(1.0)) / v; +#else + return Vec128{vec_re(v.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +#if HWY_S390X_HAVE_Z14 +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + const auto half = v * Set(d, 0.5f); + // Initial guess based on log2(f) + const auto guess = BitCast( + d, Set(du, uint32_t{0x5F3759DFu}) - ShiftRight<1>(BitCast(du, v))); + // One Newton-Raphson iteration + return guess * NegMulAdd(half * guess, guess, Set(d, 1.5f)); +} +#else // VSX + +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + return Vec128{vec_rsqrte(v.raw)}; +} +#endif // HWY_S390X_HAVE_Z14 + +// Full precision square root +template +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{vec_sqrt(v.raw)}; +} + +// ------------------------------ GetBiasedExponent + +#if HWY_PPC_HAVE_9 + +#ifdef HWY_NATIVE_GET_BIASED_EXPONENT +#undef HWY_NATIVE_GET_BIASED_EXPONENT +#else +#define HWY_NATIVE_GET_BIASED_EXPONENT +#endif + +template +HWY_API VFromD>> GetBiasedExponent(V v) { + return VFromD>>{vec_extract_exp(v.raw)}; +} + +#endif // HWY_PPC_HAVE_9 + +// ------------------------------ Min (Gt, IfThenElse) + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{vec_min(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{vec_max(a.raw, b.raw)}; +} + +// ------------------------------- Integer AbsDiff for PPC9/PPC10 + +#if HWY_PPC_HAVE_9 +#ifdef HWY_NATIVE_INTEGER_ABS_DIFF +#undef HWY_NATIVE_INTEGER_ABS_DIFF +#else +#define HWY_NATIVE_INTEGER_ABS_DIFF +#endif + +template +HWY_API V AbsDiff(const V a, 
const V b) { + return V{vec_absd(a.raw, b.raw)}; +} + +template )> +HWY_API V AbsDiff(const V a, const V b) { + return Sub(Max(a, b), Min(a, b)); +} + +template +HWY_API V AbsDiff(const V a, const V b) { + return Sub(Max(a, b), Min(a, b)); +} + +#endif // HWY_PPC_HAVE_9 + +// ------------------------------ Integer Div for PPC10 +#if HWY_PPC_HAVE_10 +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for I32 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I32 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed int raw_result; + __asm__("vdivsw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for U32 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U32 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned int raw_result; + __asm__("vdivuw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for I64 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I64 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed long long raw_result; + __asm__("vdivsd %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for U64 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U64 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned long long raw_result; + __asm__("vdivud %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + return OrderedDemote2To(d, PromoteLowerTo(dw, a) / PromoteLowerTo(dw, b), + PromoteUpperTo(dw, a) / PromoteUpperTo(dw, b)); +} + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + const DFromV d; + const Rebind, decltype(d)> dw; + return DemoteTo(d, PromoteTo(dw, a) / PromoteTo(dw, b)); +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for I32 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I32 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed int raw_result; + __asm__("vmodsw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template 
+HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for U32 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U32 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned int raw_result; + __asm__("vmoduw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for I64 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I64 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed long long raw_result; + __asm__("vmodsd %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for U64 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U64 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned long long raw_result; + __asm__("vmodud %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + return OrderedDemote2To(d, PromoteLowerTo(dw, a) % PromoteLowerTo(dw, b), + PromoteUpperTo(dw, a) % PromoteUpperTo(dw, b)); +} + +template +HWY_API Vec128 operator%(Vec128 a, Vec128 b) { + const DFromV d; + const Rebind, decltype(d)> dw; + return DemoteTo(d, PromoteTo(dw, a) % PromoteTo(dw, b)); +} +#endif + +// ================================================== MEMORY (3) + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + __builtin_prefetch(aligned, 1, 0); + Store(v, d, aligned); +} + +// ------------------------------ Scatter in generic_ops-inl.h +// ------------------------------ Gather in generic_ops-inl.h + +// ================================================== SWIZZLE (2) + +// ------------------------------ LowerHalf + +// Returns upper/lower half of a vector. 
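// [Editorial aside: illustrative only, not part of this patch] The inline
// assembly Div/Mod paths above exist because C-level integer division has
// undefined behavior for b == 0 and for LimitsMin / -1, so the compiler is
// never handed a division it may assume cannot see those inputs. A scalar
// sketch of the guards a portable fallback would need (hypothetical helper;
// the zero result is an arbitrary illustrative choice):
#include <cstdint>
#include <limits>
static inline int32_t GuardedDivI32(int32_t a, int32_t b) {
  if (b == 0) return 0;
  if (a == std::numeric_limits<int32_t>::min() && b == -1) return a;
  return a / b;
}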
+template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{v.raw}; +} +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128{v.raw}; +} + +// ------------------------------ ShiftLeftBytes + +// NOTE: The ShiftLeftBytes operation moves the elements of v to the right +// by kBytes bytes and zeroes out the first kBytes bytes of v on both +// little-endian and big-endian PPC targets +// (same behavior as the HWY_EMU128 ShiftLeftBytes operation on both +// little-endian and big-endian targets) + +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + if (kBytes == 0) return v; + const auto zeros = Zero(d); +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_sld(v.raw, zeros.raw, kBytes)}; +#else + return VFromD{vec_sld(zeros.raw, v.raw, (-kBytes) & 15)}; +#endif +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +// NOTE: The ShiftLeftLanes operation moves the elements of v to the right +// by kLanes lanes and zeroes out the first kLanes lanes of v on both +// little-endian and big-endian PPC targets +// (same behavior as the HWY_EMU128 ShiftLeftLanes operation on both +// little-endian and big-endian targets) + +template > +HWY_API VFromD ShiftLeftLanes(D d, VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes + +// NOTE: The ShiftRightBytes operation moves the elements of v to the left +// by kBytes bytes and zeroes out the last kBytes bytes of v on both +// little-endian and big-endian PPC targets +// (same behavior as the HWY_EMU128 ShiftRightBytes operation on both +// little-endian and big-endian targets) + +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + if (kBytes == 0) return v; + + // For partial vectors, clear upper lanes so we shift in zeros. 
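  // [Editorial note, not part of this patch] vec_sld always operates on the
  // full 16-byte register. For a partial vector (d.MaxBytes() < 16) the bytes
  // beyond the active lanes hold arbitrary data, so without the
  // FirstN/IfThenElseZero masking below those stale bytes would be shifted
  // into the live lanes instead of zeros.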
+ if (d.MaxBytes() != 16) { + const Full128> dfull; + VFromD vfull{v.raw}; + v = VFromD{IfThenElseZero(FirstN(dfull, MaxLanes(d)), vfull).raw}; + } + + const auto zeros = Zero(d); +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_sld(zeros.raw, v.raw, (-kBytes) & 15)}; +#else + return VFromD{vec_sld(v.raw, zeros.raw, kBytes)}; +#endif +} + +// ------------------------------ ShiftRightLanes + +// NOTE: The ShiftRightLanes operation moves the elements of v to the left +// by kLanes lanes and zeroes out the last kLanes lanes of v on both +// little-endian and big-endian PPC targets +// (same behavior as the HWY_EMU128 ShiftRightLanes operation on both +// little-endian and big-endian targets) + +template +HWY_API VFromD ShiftRightLanes(D d, VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + return LowerHalf(d, ShiftRightBytes(Twice(), v)); +} + +// ------------------------------ ExtractLane +template +HWY_API T ExtractLane(Vec128 v, size_t i) { + return static_cast(v.raw[i]); +} + +// ------------------------------ InsertLane +template +HWY_API Vec128 InsertLane(Vec128 v, size_t i, T t) { +#if HWY_IS_LITTLE_ENDIAN + typename detail::Raw128::type raw_result = v.raw; + raw_result[i] = BitCastScalar::RawT>(t); + return Vec128{raw_result}; +#else + // On ppc64be without this, mul_test fails, but swizzle_test passes. + DFromV d; + alignas(16) T lanes[16 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +#endif +} + +// ------------------------------ CombineShiftRightBytes + +// NOTE: The CombineShiftRightBytes operation below moves the elements of lo to +// the left by kBytes bytes and moves the elements of hi right by (d.MaxBytes() +// - kBytes) bytes on both little-endian and big-endian PPC targets. + +template > +HWY_API Vec128 CombineShiftRightBytes(D /*d*/, Vec128 hi, Vec128 lo) { + constexpr size_t kSize = 16; + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); +#if HWY_IS_LITTLE_ENDIAN + return Vec128{vec_sld(hi.raw, lo.raw, (-kBytes) & 15)}; +#else + return Vec128{vec_sld(lo.raw, hi.raw, kBytes)}; +#endif +} + +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + using V8 = Vec128; + const DFromV dfull8; + const Repartition, decltype(dfull8)> dfull; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(dfull8, hi8, lo8); + return VFromD{BitCast(dfull, r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{vec_splat(v.raw, kLane)}; +} + +// ------------------------------ TableLookupLanes (Shuffle01) + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
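// [Editorial aside: illustrative only, not part of this patch] A scalar model
// of the CombineShiftRightBytes contract above for a full 16-byte block:
// result byte i is byte (i + kBytes) of the concatenation lo|hi, with lo
// supplying the lower-indexed bytes. Requires 0 < kBytes < 16.
#include <cstddef>
#include <cstdint>
static inline void CombineShiftRightBytesRef(const uint8_t lo[16],
                                             const uint8_t hi[16],
                                             size_t kBytes, uint8_t out[16]) {
  uint8_t concat[32];
  for (size_t i = 0; i < 16; ++i) concat[i] = lo[i];
  for (size_t i = 0; i < 16; ++i) concat[16 + i] = hi[i];
  for (size_t i = 0; i < 16; ++i) out[i] = concat[kBytes + i];
}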
+template +struct Indices128 { + __vector unsigned char raw; +}; + +namespace detail { + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + return Iota(d8, 0); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + constexpr __vector unsigned char kBroadcastLaneBytes = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; +#else + constexpr __vector unsigned char kBroadcastLaneBytes = { + 1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}; +#endif + return VFromD{kBroadcastLaneBytes}; +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + constexpr __vector unsigned char kBroadcastLaneBytes = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; +#else + constexpr __vector unsigned char kBroadcastLaneBytes = { + 3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15}; +#endif + return VFromD{kBroadcastLaneBytes}; +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + constexpr __vector unsigned char kBroadcastLaneBytes = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; +#else + constexpr __vector unsigned char kBroadcastLaneBytes = { + 7, 7, 7, 7, 7, 7, 7, 7, 15, 15, 15, 15, 15, 15, 15, 15}; +#endif + return VFromD{kBroadcastLaneBytes}; +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + return Zero(d8); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + constexpr __vector unsigned char kByteOffsets = {0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1}; + return VFromD{kByteOffsets}; +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + constexpr __vector unsigned char kByteOffsets = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + return VFromD{kByteOffsets}; +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + constexpr __vector unsigned char kByteOffsets = {0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7}; + return VFromD{kByteOffsets}; +} + +} // namespace detail + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + const Repartition d8; + return Indices128, MaxLanes(D())>{BitCast(d8, vec).raw}; +} + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + const Repartition d8; + using V8 = VFromD; + + // Broadcast each lane index to all bytes of T and shift to bytes + const V8 lane_indices = TableLookupBytes( + BitCast(d8, vec), detail::IndicesFromVecBroadcastLaneBytes(d)); + constexpr int kIndexShiftAmt = static_cast(FloorLog2(sizeof(T))); + const V8 byte_indices = ShiftLeft(lane_indices); + const V8 sum = Add(byte_indices, detail::IndicesFromVecByteOffsets(d)); + return Indices128, 
MaxLanes(D())>{sum.raw}; +} + +template +HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( + D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const DFromV d; + const Repartition d8; + return BitCast(d, TableLookupBytes(v, VFromD{idx.raw})); +} + +// Single lane: no change +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 /* idx */) { + return v; +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Twice dt; + const Repartition dt_u8; +// TableLookupLanes currently requires table and index vectors to be the same +// size, though a half-length index vector would be sufficient here. +#if HWY_IS_MSAN + const Vec128 idx_vec{idx.raw}; + const Indices128 idx2{Combine(dt, idx_vec, idx_vec).raw}; +#else + // We only keep LowerHalf of the result, which is valid in idx. + const Indices128 idx2{idx.raw}; +#endif + return LowerHalf( + d, TableLookupBytes(Combine(dt, b, a), + BitCast(dt, VFromD{idx2.raw}))); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + return Vec128{vec_perm(a.raw, b.raw, idx.raw)}; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301) + +// Single lane: no change +template , HWY_IF_LANES_D(D, 1)> +HWY_API Vec128 Reverse(D /* tag */, Vec128 v) { + return v; +} + +// 32-bit x2: shuffle +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec64 Reverse(D /* tag */, Vec64 v) { + return Vec64{Shuffle2301(Vec128{v.raw}).raw}; +} + +// 16-bit x4: shuffle +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec64 Reverse(D /* tag */, Vec64 v) { + const __vector unsigned char kShuffle = {6, 7, 4, 5, 2, 3, 0, 1, + 14, 15, 12, 13, 10, 11, 8, 9}; + return Vec64{vec_perm(v.raw, v.raw, kShuffle)}; +} + +// 16-bit x2: rotate bytes +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec32 Reverse(D d, Vec32 v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +// ------------------------------- ReverseLaneBytes + +#if (HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14) && \ + ((!HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL >= 710) || \ + (HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL >= 900) || \ + HWY_COMPILER_CLANG >= 400) + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit ReverseLaneBytes. +#ifdef HWY_NATIVE_REVERSE_LANE_BYTES +#undef HWY_NATIVE_REVERSE_LANE_BYTES +#else +#define HWY_NATIVE_REVERSE_LANE_BYTES +#endif + +template +HWY_API V ReverseLaneBytes(V v) { + return V{vec_revb(v.raw)}; +} + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. 
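// [Editorial aside: illustrative only, not part of this patch]
// ReverseLaneBytes above is a per-lane byte swap (and the 8-bit Reverse2/4/8
// below reuse it on wider lanes). A scalar equivalent for one 32-bit lane:
#include <cstdint>
static inline uint32_t ReverseLaneBytes32(uint32_t v) {
  return (v >> 24) | ((v >> 8) & 0x0000FF00u) | ((v << 8) & 0x00FF0000u) |
         (v << 24);
}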
+#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API VFromD Reverse2(D d, VFromD v) { + const Repartition du16; + return BitCast(d, ReverseLaneBytes(BitCast(du16, v))); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API VFromD Reverse4(D d, VFromD v) { + const Repartition du32; + return BitCast(d, ReverseLaneBytes(BitCast(du32, v))); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API VFromD Reverse8(D d, VFromD v) { + const Repartition du64; + return BitCast(d, ReverseLaneBytes(BitCast(du64, v))); +} + +#endif // HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14 + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec16 Reverse(D d, Vec16 v) { + return Reverse2(d, v); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 Reverse(D d, Vec32 v) { + return Reverse4(d, v); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec64 Reverse(D d, Vec64 v) { + return Reverse8(d, v); +} + +// ------------------------------ Reverse2 + +// Single lane: no change +template , HWY_IF_LANES_D(D, 1)> +HWY_API Vec128 Reverse2(D /* tag */, Vec128 v) { + return v; +} + +template , HWY_IF_T_SIZE(T, 2)> +HWY_API VFromD Reverse2(D d, VFromD v) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template , HWY_IF_T_SIZE(T, 4)> +HWY_API VFromD Reverse2(D d, VFromD v) { + const Repartition du64; + return BitCast(d, RotateRight<32>(BitCast(du64, v))); +} + +template , HWY_IF_T_SIZE(T, 8)> +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API VFromD Reverse4(D /*d*/, VFromD v) { + const __vector unsigned char kShuffle = {6, 7, 4, 5, 2, 3, 0, 1, + 14, 15, 12, 13, 10, 11, 8, 9}; + return VFromD{vec_perm(v.raw, v.raw, kShuffle)}; +} + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + return Reverse(d, v); +} + +template +HWY_API VFromD Reverse4(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, VFromD v) { + return Reverse(d, v); +} + +template +HWY_API VFromD Reverse8(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 8 lanes if larger than 16-bit +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{vec_mergeh(a.raw, b.raw)}; +} + +// Additional overload for the optional tag +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// Full +template > +HWY_API Vec128 InterleaveUpper(D /* tag */, Vec128 a, Vec128 b) { + return Vec128{vec_mergel(a.raw, b.raw)}; +} + +// Partial +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half d2; + return InterleaveLower(d, VFromD{UpperHalf(d2, a).raw}, + VFromD{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
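// [Editorial aside: illustrative only, not part of this patch] ZipLower below
// is the interleave reinterpreted as double-width lanes. A scalar model for
// u8 -> u16 on a little-endian target (big-endian swaps which operand ends up
// in the low byte, as the PromoteTo zero-extension later in this file shows):
#include <cstdint>
static inline uint16_t ZipPairU8(uint8_t a, uint8_t b) {
  // The lane from 'a' supplies the low byte, the lane from 'b' the high byte.
  return static_cast<uint16_t>(a | (b << 8));
}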
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ------------------------------ Per4LaneBlkShufDupSet4xU32 + +// Used by hwy/ops/generic_ops-inl.h to implement Per4LaneBlockShuffle +namespace detail { + +#ifdef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#undef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#else +#define HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#endif + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + const __vector unsigned int raw = {x0, x1, x2, x3}; + return ResizeBitCast(d, Vec128{raw}); +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + using VU8 = VFromD; + const auto v_shift_amt = + BitCast(Full128(), + Set(Full128(), + static_cast(amt * sizeof(TFromD) * 8))); + +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU8{vec_srb(BitCast(du8, v).raw, v_shift_amt.raw)}); +#else // VSX +#if HWY_IS_LITTLE_ENDIAN + return BitCast(d, VU8{vec_slo(BitCast(du8, v).raw, v_shift_amt.raw)}); +#else + return BitCast(d, VU8{vec_sro(BitCast(du8, v).raw, v_shift_amt.raw)}); +#endif // HWY_IS_LITTLE_ENDIAN +#endif // HWY_S390X_HAVE_Z14 +} + +// ------------------------------ SlideDownLanes + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + using TU = UnsignedFromSize; + const Repartition du; + const auto v_shift_amt = + Set(du, static_cast(amt * sizeof(TFromD) * 8)); + +#if HWY_IS_LITTLE_ENDIAN + return BitCast(d, BitCast(du, v) >> v_shift_amt); +#else + return BitCast(d, BitCast(du, v) << v_shift_amt); +#endif +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + using VU8 = VFromD; + const auto v_shift_amt = + BitCast(Full128(), + Set(Full128(), + static_cast(amt * sizeof(TFromD) * 8))); + +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU8{vec_slb(BitCast(du8, v).raw, v_shift_amt.raw)}); +#else // VSX +#if HWY_IS_LITTLE_ENDIAN + return BitCast(d, VU8{vec_sro(BitCast(du8, v).raw, v_shift_amt.raw)}); +#else + return BitCast(d, VU8{vec_slo(BitCast(du8, v).raw, v_shift_amt.raw)}); +#endif // HWY_IS_LITTLE_ENDIAN +#endif // HWY_S390X_HAVE_Z14 +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template >> +HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { + const Half dh; + // Treat half-width input as one lane, and expand to two lanes. 
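  // [Editorial note, not part of this patch] Viewing each 64-bit half as a
  // single u64 lane lets InterleaveLower(lo, hi) place lo_half in lane 0 and
  // hi_half in lane 1 of the full vector, which is the Combine contract
  // (lo in the lower half, hi in the upper half).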
+ using VU = Vec128, 2>; + using Raw = typename detail::Raw128>::type; + const VU lo{reinterpret_cast(lo_half.raw)}; + const VU hi{reinterpret_cast(hi_half.raw)}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const Half dh; + return IfThenElseZero(FirstN(d, MaxLanes(dh)), VFromD{lo.raw}); +} + +// ------------------------------ Concat full (InterleaveLower) + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template > +HWY_API Vec128 ConcatLowerLower(D d, Vec128 hi, Vec128 lo) { + const Repartition d64; + return BitCast(d, InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template > +HWY_API Vec128 ConcatUpperUpper(D d, Vec128 hi, Vec128 lo) { + const Repartition d64; + return BitCast(d, InterleaveUpper(d64, BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template > +HWY_API Vec128 ConcatLowerUpper(D d, Vec128 hi, Vec128 lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template > +HWY_API Vec128 ConcatUpperLower(D /*d*/, Vec128 hi, Vec128 lo) { + const __vector unsigned char kShuffle = {0, 1, 2, 3, 4, 5, 6, 7, + 24, 25, 26, 27, 28, 29, 30, 31}; + return Vec128{vec_perm(lo.raw, hi.raw, kShuffle)}; +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ TruncateTo + +template = sizeof(TFromD) * 2)>* = nullptr, + HWY_IF_LANES_D(D, 1)> +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + using Raw = typename detail::Raw128>::type; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{reinterpret_cast(v.raw)}; +#else + return VFromD{reinterpret_cast( + vec_sld(v.raw, v.raw, sizeof(FromT) - sizeof(TFromD)))}; +#endif +} + +namespace detail { + +template ) * 2), HWY_IF_LANES_GT_D(D, 1)> +HWY_API VFromD Truncate2To( + D /* tag */, Vec128().MaxLanes()> lo, + Vec128().MaxLanes()> hi) { + return VFromD{vec_pack(lo.raw, hi.raw)}; +} + +} // namespace detail + +template ) * 2), HWY_IF_LANES_GT_D(D, 1)> +HWY_API VFromD TruncateTo(D /* d */, + Vec128().MaxLanes()> v) { + return VFromD{vec_pack(v.raw, v.raw)}; +} + +template = sizeof(TFromD) * 4)>* = nullptr, + HWY_IF_LANES_GT_D(D, 1)> +HWY_API VFromD TruncateTo(D d, + Vec128().MaxLanes()> v) { + const Rebind, decltype(d)> d2; + return TruncateTo(d, TruncateTo(d2, v)); +} + +// ------------------------------ ConcatOdd (TruncateTo) + +// 8-bit full +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { + const Repartition dw; + const RebindToUnsigned du; +#if HWY_IS_LITTLE_ENDIAN + // Right-shift 8 bits per u16 so we can pack. 
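  // [Editorial note, not part of this patch] On little-endian targets the
  // odd-indexed u8 lane is the upper byte of each u16, so shifting every u16
  // right by 8 moves the odd bytes into the low-byte position that
  // Truncate2To (vec_pack) keeps; under big-endian element ordering the odd
  // lane already occupies that position, hence no shift in the #else branch.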
+ const Vec128 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<8>(BitCast(dw, lo)); +#else + const Vec128 uH = BitCast(dw, hi); + const Vec128 uL = BitCast(dw, lo); +#endif + return BitCast(d, detail::Truncate2To(du, uL, uH)); +} + +// 8-bit x8 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec64 ConcatOdd(D /*d*/, Vec64 hi, Vec64 lo) { + // Don't care about upper half, no need to zero. + const __vector unsigned char kCompactOddU8 = {1, 3, 5, 7, 17, 19, 21, 23}; + return Vec64{vec_perm(lo.raw, hi.raw, kCompactOddU8)}; +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatOdd(D /*d*/, Vec32 hi, Vec32 lo) { + // Don't care about upper half, no need to zero. + const __vector unsigned char kCompactOddU8 = {1, 3, 17, 19}; + return Vec32{vec_perm(lo.raw, hi.raw, kCompactOddU8)}; +} + +// 16-bit full +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { + const Repartition dw; + const RebindToUnsigned du; +#if HWY_IS_LITTLE_ENDIAN + const Vec128 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<16>(BitCast(dw, lo)); +#else + const Vec128 uH = BitCast(dw, hi); + const Vec128 uL = BitCast(dw, lo); +#endif + return BitCast(d, detail::Truncate2To(du, uL, uH)); +} + +// 16-bit x4 +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec64 ConcatOdd(D /*d*/, Vec64 hi, Vec64 lo) { + // Don't care about upper half, no need to zero. + const __vector unsigned char kCompactOddU16 = {2, 3, 6, 7, 18, 19, 22, 23}; + return Vec64{vec_perm(lo.raw, hi.raw, kCompactOddU16)}; +} + +// 32-bit full +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { +#if HWY_IS_LITTLE_ENDIAN + (void)d; + const __vector unsigned char kShuffle = {4, 5, 6, 7, 12, 13, 14, 15, + 20, 21, 22, 23, 28, 29, 30, 31}; + return Vec128{vec_perm(lo.raw, hi.raw, kShuffle)}; +#else + const RebindToUnsigned du; + const Repartition dw; + return BitCast(d, detail::Truncate2To(du, BitCast(dw, lo), BitCast(dw, hi))); +#endif +} + +// Any type x2 +template , HWY_IF_LANES_D(D, 2)> +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (TruncateTo) + +// 8-bit full +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { + const Repartition dw; + const RebindToUnsigned du; +#if HWY_IS_LITTLE_ENDIAN + const Vec128 uH = BitCast(dw, hi); + const Vec128 uL = BitCast(dw, lo); +#else + // Right-shift 8 bits per u16 so we can pack. + const Vec128 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<8>(BitCast(dw, lo)); +#endif + return BitCast(d, detail::Truncate2To(du, uL, uH)); +} + +// 8-bit x8 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec64 ConcatEven(D /*d*/, Vec64 hi, Vec64 lo) { + // Don't care about upper half, no need to zero. + const __vector unsigned char kCompactEvenU8 = {0, 2, 4, 6, 16, 18, 20, 22}; + return Vec64{vec_perm(lo.raw, hi.raw, kCompactEvenU8)}; +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatEven(D /*d*/, Vec32 hi, Vec32 lo) { + // Don't care about upper half, no need to zero. + const __vector unsigned char kCompactEvenU8 = {0, 2, 16, 18}; + return Vec32{vec_perm(lo.raw, hi.raw, kCompactEvenU8)}; +} + +// 16-bit full +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { + // Isolate lower 16 bits per u32 so we can pack. 
+ const Repartition dw; + const RebindToUnsigned du; +#if HWY_IS_LITTLE_ENDIAN + const Vec128 uH = BitCast(dw, hi); + const Vec128 uL = BitCast(dw, lo); +#else + const Vec128 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<16>(BitCast(dw, lo)); +#endif + return BitCast(d, detail::Truncate2To(du, uL, uH)); +} + +// 16-bit x4 +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec64 ConcatEven(D /*d*/, Vec64 hi, Vec64 lo) { + // Don't care about upper half, no need to zero. + const __vector unsigned char kCompactEvenU16 = {0, 1, 4, 5, 16, 17, 20, 21}; + return Vec64{vec_perm(lo.raw, hi.raw, kCompactEvenU16)}; +} + +// 32-bit full +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { +#if HWY_IS_LITTLE_ENDIAN + const Repartition dw; + const RebindToUnsigned du; + return BitCast(d, detail::Truncate2To(du, BitCast(dw, lo), BitCast(dw, hi))); +#else + (void)d; + constexpr __vector unsigned char kShuffle = {0, 1, 2, 3, 8, 9, 10, 11, + 16, 17, 18, 19, 24, 25, 26, 27}; + return Vec128{vec_perm(lo.raw, hi.raw, kShuffle)}; +#endif +} + +// Any T x2 +template , HWY_IF_LANES_D(D, 2)> +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ OrderedTruncate2To (ConcatEven, ConcatOdd) +#ifdef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#undef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#else +#define HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#endif + +template ) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedTruncate2To(D d, V a, V b) { +#if HWY_IS_LITTLE_ENDIAN + return ConcatEven(d, BitCast(d, b), BitCast(d, a)); +#else + return ConcatOdd(d, BitCast(d, b), BitCast(d, a)); +#endif +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return v; +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + const DFromV d; + const Repartition du8; + constexpr __vector unsigned char kShuffle = {0, 0, 2, 2, 4, 4, 6, 6, + 8, 8, 10, 10, 12, 12, 14, 14}; + return TableLookupBytes(v, BitCast(d, VFromD{kShuffle})); +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + const DFromV d; + const Repartition du8; + constexpr __vector unsigned char kShuffle = {0, 1, 0, 1, 4, 5, 4, 5, + 8, 9, 8, 9, 12, 13, 12, 13}; + return TableLookupBytes(v, BitCast(d, VFromD{kShuffle})); +} + +template +HWY_API Vec128 DupEven(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 0, 1, 2, 3, 0, 1, 2, 3, 8, 9, 10, + 11, 8, 9, 10, 11))); +#else + return Vec128{vec_mergee(v.raw, v.raw)}; +#endif +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + const DFromV d; + const Repartition du8; + constexpr __vector unsigned char kShuffle = {1, 1, 3, 3, 5, 5, 7, 7, + 9, 9, 11, 11, 13, 13, 15, 15}; + return TableLookupBytes(v, BitCast(d, VFromD{kShuffle})); +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + const DFromV d; + const Repartition du8; + constexpr __vector unsigned char kShuffle = {2, 3, 2, 3, 6, 7, 6, 7, + 10, 11, 10, 11, 14, 15, 14, 15}; + return TableLookupBytes(v, BitCast(d, VFromD{kShuffle})); +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, + 
15, 12, 13, 14, 15))); +#else + return Vec128{vec_mergeo(v.raw, v.raw)}; +#endif +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_INLINE Vec128 OddEven(Vec128 a, Vec128 b) { + const DFromV d; + const __vector unsigned char mask = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfVecThenElse(BitCast(d, Vec128{mask}), b, a); +} + +template +HWY_INLINE Vec128 OddEven(Vec128 a, Vec128 b) { + const DFromV d; + const __vector unsigned char mask = {0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, + 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0}; + return IfVecThenElse(BitCast(d, Vec128{mask}), b, a); +} + +template +HWY_INLINE Vec128 OddEven(Vec128 a, Vec128 b) { + const DFromV d; + const __vector unsigned char mask = {0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, + 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0}; + return IfVecThenElse(BitCast(d, Vec128{mask}), b, a); +} + +template +HWY_INLINE Vec128 OddEven(Vec128 a, Vec128 b) { + // Same as ConcatUpperLower for full vectors; do not call that because this + // is more efficient for 64x1 vectors. + const DFromV d; + const __vector unsigned char mask = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, 0, 0, 0, 0}; + return IfVecThenElse(BitCast(d, Vec128{mask}), b, a); +} + +// ------------------------------ InterleaveEven + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 0, 16, 2, 18, 4, 20, 6, 22, 8, 24, + 10, 26, 12, 28, 14, 30) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{Dup128VecFromValues(Full128(), 0, 1, + 16, 17, 4, 5, 20, 21, 8, + 9, 24, 25, 12, 13, 28, 29) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { +#if HWY_S390X_HAVE_Z14 + const Full128> d_full; + const Indices128> idx{Dup128VecFromValues(Full128(), 0, 1, + 2, 3, 16, 17, 18, 19, 8, + 9, 10, 11, 24, 25, 26, 27) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +#else + (void)d; + return VFromD{vec_mergee(a.raw, b.raw)}; +#endif +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveOdd + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 1, 17, 3, 19, 5, 21, 7, 23, 9, 25, + 11, 27, 13, 29, 15, 31) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 2, 3, 18, 19, 6, 7, 22, 23, 10, + 11, 26, 27, 14, 15, 30, 31) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { +#if HWY_S390X_HAVE_Z14 + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 4, 5, 6, 7, 20, 21, 22, 23, 12, + 13, 14, 15, 
28, 29, 30, 31) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +#else + (void)d; + return VFromD{vec_mergeo(a.raw, b.raw)}; +#endif +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template > +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template > +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ MulFixedPoint15 (OddEven) + +#if HWY_S390X_HAVE_Z14 +HWY_API Vec16 MulFixedPoint15(Vec16 a, Vec16 b) { + const DFromV di16; + const RepartitionToWide di32; + + const auto round_up_incr = Set(di32, 0x4000); + const auto i32_product = MulEven(a, b) + round_up_incr; + + return ResizeBitCast(di16, ShiftLeft<1>(i32_product)); +} +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + const DFromV di16; + const RepartitionToWide di32; + + const auto round_up_incr = Set(di32, 0x4000); + const auto even_product = MulEven(a, b) + round_up_incr; + const auto odd_product = MulOdd(a, b) + round_up_incr; + + return OddEven(BitCast(di16, ShiftRight<15>(odd_product)), + BitCast(di16, ShiftLeft<1>(even_product))); +} +#else +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + const Vec128 zero = Zero(Full128()); + return Vec128{vec_mradds(a.raw, b.raw, zero.raw)}; +} +#endif + +// ------------------------------ Shl + +namespace detail { +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw << bits.raw}; +#else + return Vec128{vec_sl(v.raw, bits.raw)}; +#endif +} + +// Signed left shift is the same as unsigned. 
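// [Editorial aside: illustrative only, not part of this patch] A scalar
// reference for the MulFixedPoint15 overloads above: a Q1.15 fixed-point
// multiply with rounding, implementation-defined only when both inputs are
// -32768 (the one case whose result does not fit in int16_t).
#include <cstdint>
static inline int16_t MulFixedPoint15Ref(int16_t a, int16_t b) {
  return static_cast<int16_t>((int32_t{a} * int32_t{b} + (1 << 14)) >> 15);
}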
+template +HWY_API Vec128 Shl(hwy::SignedTag /*tag*/, Vec128 v, + Vec128 bits) { + const DFromV di; + const RebindToUnsigned du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr + +namespace detail { +template +HWY_API Vec128 Shr(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw >> bits.raw}; +#else + return Vec128{vec_sr(v.raw, bits.raw)}; +#endif +} + +template +HWY_API Vec128 Shr(hwy::SignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw >> bits.raw}; +#else + const DFromV di; + const RebindToUnsigned du; + return Vec128{vec_sra(v.raw, BitCast(du, bits).raw)}; +#endif +} + +} // namespace detail + +template +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + return detail::Shr(hwy::TypeTag(), v, bits); +} + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +template +HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { +#if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) + using V64 = typename detail::Raw128::type; + const V64 mul128_result = reinterpret_cast(vec_mule(a.raw, b.raw)); +#if HWY_IS_LITTLE_ENDIAN + return Vec128{mul128_result}; +#else + // Need to swap the two halves of mul128_result on big-endian targets as + // the upper 64 bits of the product are in lane 0 of mul128_result and + // the lower 64 bits of the product are in lane 1 of mul128_result + return Vec128{vec_sld(mul128_result, mul128_result, 8)}; +#endif +#else + alignas(16) T mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128(), mul); +#endif +} + +template +HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { +#if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) + using V64 = typename detail::Raw128::type; + const V64 mul128_result = reinterpret_cast(vec_mulo(a.raw, b.raw)); +#if HWY_IS_LITTLE_ENDIAN + return Vec128{mul128_result}; +#else + // Need to swap the two halves of mul128_result on big-endian targets as + // the upper 64 bits of the product are in lane 0 of mul128_result and + // the lower 64 bits of the product are in lane 1 of mul128_result + return Vec128{vec_sld(mul128_result, mul128_result, 8)}; +#endif +#else + alignas(16) T mul[2]; + const Full64 d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128(), mul); +#endif +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "third_party/highway/hwy/ops/inside-inl.h" + +// ------------------------------ WidenMulPairwiseAdd + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} + +// Even if N=1, the input is always at least 2 lanes, hence vec_msum is safe. +template >> +HWY_API VFromD WidenMulPairwiseAdd(D32 d32, V16 a, V16 b) { +#if HWY_S390X_HAVE_Z14 + (void)d32; + return MulEven(a, b) + MulOdd(a, b); +#else + return VFromD{vec_msum(a.raw, b.raw, Zero(d32).raw)}; +#endif +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +// Even if N=1, the input is always at least 2 lanes, hence vec_msum is safe. 
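// [Editorial aside: illustrative only, not part of this patch] The 64x64-bit
// MulEven/MulOdd above return the full 128-bit product, low half in lane 0
// and high half in lane 1. A scalar sketch, assuming the compiler provides
// __int128 as GCC and Clang do on these targets:
#include <cstdint>
#if defined(__SIZEOF_INT128__)
static inline void Mul64x64Ref(uint64_t a, uint64_t b, uint64_t* lo,
                               uint64_t* hi) {
  const unsigned __int128 p = static_cast<unsigned __int128>(a) * b;
  *lo = static_cast<uint64_t>(p);
  *hi = static_cast<uint64_t>(p >> 64);
}
#endif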
+template >> +HWY_API VFromD ReorderWidenMulAccumulate(D32 /*d32*/, V16 a, V16 b, + VFromD sum0, + VFromD& /*sum1*/) { +#if HWY_S390X_HAVE_Z14 + return MulEven(a, b) + MulOdd(a, b) + sum0; +#else + return VFromD{vec_msum(a.raw, b.raw, sum0.raw)}; +#endif +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec128 RearrangeToOddPlusEven(Vec128 sum0, + Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +template +HWY_API Vec128 RearrangeToOddPlusEven( + Vec128 sum0, Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); +} + +// ------------------------------ SatWidenMulPairwiseAccumulate +#if !HWY_S390X_HAVE_Z14 + +#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#else +#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#endif + +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{vec_msums(a.raw, b.raw, sum.raw)}; +} + +#endif // !HWY_S390X_HAVE_Z14 + +// ------------------------------ SumOfMulQuadAccumulate +#if !HWY_S390X_HAVE_Z14 + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{vec_msum(a.raw, b.raw, sum.raw)}; +} + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD{vec_msum(b_i.raw, a_u.raw, sum.raw)}; +} + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Repartition du8; + + const auto result_sum_0 = + SumOfMulQuadAccumulate(di32, BitCast(du8, a), b, sum); + const auto result_sum_1 = ShiftLeft<8>(SumsOf4(And(b, BroadcastSignBit(a)))); + return result_sum_0 - result_sum_1; +} + +#endif // !HWY_S390X_HAVE_Z14 + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned to signed/unsigned: zero-extend. +template +HWY_API VFromD PromoteTo(D /* d */, + Vec128().MaxLanes()> v) { + // First pretend the input has twice the lanes - the upper half will be + // ignored by ZipLower. + const Rebind> d2; + const VFromD twice{v.raw}; + // Then cast to narrow as expected by ZipLower, in case the sign of FromT + // differs from that of D. + const RepartitionToNarrow dn; + +#if HWY_IS_LITTLE_ENDIAN + return ZipLower(BitCast(dn, twice), Zero(dn)); +#else + return ZipLower(Zero(dn), BitCast(dn, twice)); +#endif +} + +// Signed: replicate sign bit. +template +HWY_API VFromD PromoteTo(D /* d */, + Vec128().MaxLanes()> v) { + using Raw = typename detail::Raw128>::type; + return VFromD{reinterpret_cast(vec_unpackh(v.raw))}; +} + +// 8-bit to 32-bit: First, promote to 16-bit, and then convert to 32-bit. 
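// [Editorial aside: illustrative only, not part of this patch] A scalar model
// of the i16 pairwise widening multiply-accumulate that WidenMulPairwiseAdd
// and ReorderWidenMulAccumulate above map to vec_msum: each i32 output lane
// is the running sum plus the two widened products of its i16 lane pair.
#include <cstdint>
static inline int32_t WidenMulPairAccum(int16_t a0, int16_t a1, int16_t b0,
                                        int16_t b1, int32_t sum) {
  return sum + int32_t{a0} * int32_t{b0} + int32_t{a1} * int32_t{b1};
}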
+template +HWY_API VFromD PromoteTo(D d32, + Vec128().MaxLanes()> v) { + const DFromV d8; + const Rebind, decltype(d8)> d16; + return PromoteTo(d32, PromoteTo(d16, v)); +} + +// 8-bit or 16-bit to 64-bit: First, promote to MakeWide, and then +// convert to 64-bit. +template +HWY_API VFromD PromoteTo(D d64, + Vec128().MaxLanes()> v) { + const Rebind, decltype(d64)> dw; + return PromoteTo(d64, PromoteTo(dw, v)); +} + +#if HWY_PPC_HAVE_9 + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, VFromD> v) { + return VFromD{vec_extract_fp32_from_shorth(v.raw)}; +} + +#endif // HWY_PPC_HAVE_9 + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + const __vector float raw_v = InterleaveLower(v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#elif HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL && \ + HWY_COMPILER_GCC_ACTUAL < 1000 + // Workaround for compiler errors with GCC 9 or earlier on Z14 + return VFromD{__builtin_s390_vflls(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +} + +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { +#if HWY_S390X_HAVE_Z14 + const RebindToSigned di64; + return ConvertTo(df64, PromoteTo(di64, v)); +#else // VSX + (void)df64; + const __vector signed int raw_v = InterleaveLower(v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +#endif // HWY_S390X_HAVE_Z14 +} + +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du64; + return ConvertTo(df64, PromoteTo(du64, v)); +#else // VSX + (void)df64; + const __vector unsigned int raw_v = InterleaveLower(v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +#endif // HWY_S390X_HAVE_Z14 +} + +#if !HWY_S390X_HAVE_Z14 +namespace detail { + +template +static HWY_INLINE V VsxF2INormalizeSrcVals(V v) { +#if !defined(HWY_DISABLE_PPC_VSX_QEMU_F2I_WORKAROUND) + // Workaround for QEMU 7/8 VSX float to int conversion bug + return IfThenElseZero(v == v, v); +#else + return v; +#endif +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED VFromD>> +VsxXvcvspsxds(VF32 vf32) { + using VI64 = VFromD>>; +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1500) || \ + HWY_HAS_BUILTIN(__builtin_vsx_xvcvspsxds) + // Use __builtin_vsx_xvcvspsxds if it is available (which is the case with + // GCC 4.8 through GCC 14 or Clang 13 or later on PPC8/PPC9/PPC10) + return VI64{__builtin_vsx_xvcvspsxds(vf32.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL >= 1500 && HWY_IS_LITTLE_ENDIAN + // On little-endian PPC8/PPC9/PPC10 with GCC 15 or later, use the F32->I64 + // vec_signedo intrinsic as the __builtin_vsx_xvcvspsxds intrinsic has been + // removed from GCC in GCC 15 + return VI64{vec_signedo(vf32.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL >= 1500 && HWY_IS_BIG_ENDIAN + // On big-endian PPC8/PPC9/PPC10 with GCC 15 or later, use the F32->I64 + // vec_signede intrinsic as the __builtin_vsx_xvcvspsxds intrinsic has been + // removed from GCC in GCC 15 + return VI64{vec_signede(vf32.raw)}; +#else + // Inline assembly fallback for older versions of Clang that do not have the 
+ // __builtin_vsx_xvcvspsxds intrinsic + __vector signed long long raw_result; + __asm__("xvcvspsxds %x0, %x1" : "=wa"(raw_result) : "wa"(vf32.raw) :); + return VI64{raw_result}; +#endif +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED VFromD>> +VsxXvcvspuxds(VF32 vf32) { + using VU64 = VFromD>>; +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1500) || \ + HWY_HAS_BUILTIN(__builtin_vsx_xvcvspuxds) + // Use __builtin_vsx_xvcvspuxds if it is available (which is the case with + // GCC 4.8 through GCC 14 or Clang 13 or later on PPC8/PPC9/PPC10) + return VU64{reinterpret_cast<__vector unsigned long long>( + __builtin_vsx_xvcvspuxds(vf32.raw))}; +#elif HWY_COMPILER_GCC_ACTUAL >= 1500 && HWY_IS_LITTLE_ENDIAN + // On little-endian PPC8/PPC9/PPC10 with GCC 15 or later, use the F32->U64 + // vec_unsignedo intrinsic as the __builtin_vsx_xvcvspuxds intrinsic has been + // removed from GCC in GCC 15 + return VU64{vec_unsignedo(vf32.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL >= 1500 && HWY_IS_BIG_ENDIAN + // On big-endian PPC8/PPC9/PPC10 with GCC 15 or later, use the F32->U64 + // vec_unsignedo intrinsic as the __builtin_vsx_xvcvspuxds intrinsic has been + // removed from GCC in GCC 15 + return VU64{vec_unsignede(vf32.raw)}; +#else + // Inline assembly fallback for older versions of Clang that do not have the + // __builtin_vsx_xvcvspuxds intrinsic + __vector unsigned long long raw_result; + __asm__("xvcvspuxds %x0, %x1" : "=wa"(raw_result) : "wa"(vf32.raw) :); + return VU64{raw_result}; +#endif +} + +} // namespace detail +#endif // !HWY_S390X_HAVE_Z14 + +template +HWY_API VFromD PromoteTo(D di64, VFromD> v) { +#if !HWY_S390X_HAVE_Z14 + const Repartition dt_f32; + const auto vt_f32 = ResizeBitCast(dt_f32, v); + return detail::VsxXvcvspsxds( + detail::VsxF2INormalizeSrcVals(InterleaveLower(vt_f32, vt_f32))); +#else + const RebindToFloat df64; + return ConvertTo(di64, PromoteTo(df64, v)); +#endif +} + +template +HWY_API VFromD PromoteTo(D du64, VFromD> v) { +#if !HWY_S390X_HAVE_Z14 + const Repartition dt_f32; + const auto vt_f32 = ResizeBitCast(dt_f32, v); + return detail::VsxXvcvspuxds( + detail::VsxF2INormalizeSrcVals(InterleaveLower(vt_f32, vt_f32))); +#else + const RebindToFloat df64; + return ConvertTo(du64, PromoteTo(df64, v)); +#endif +} + +// ------------------------------ PromoteUpperTo + +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// Unsigned to signed/unsigned: zero-extend. +template +HWY_API VFromD PromoteUpperTo(D d, Vec128 v) { + const RebindToUnsigned du; + const RepartitionToNarrow dn; + +#if HWY_IS_LITTLE_ENDIAN + return BitCast(d, ZipUpper(du, v, Zero(dn))); +#else + return BitCast(d, ZipUpper(du, Zero(dn), v)); +#endif +} + +// Signed: replicate sign bit. 
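// [Editorial aside: illustrative only, not part of this patch] The
// VsxF2INormalizeSrcVals helper above relies on NaN being the only value for
// which v == v is false; a scalar equivalent of that pre-conversion clean-up:
static inline float NormalizeForF2I(float v) {
  // Replace NaN with 0 so the later float->int conversion is unaffected by
  // the QEMU 7/8 VSX conversion bug mentioned above.
  return (v == v) ? v : 0.0f;
}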
+template +HWY_API VFromD PromoteUpperTo(D /* d */, Vec128 v) { + using Raw = typename detail::Raw128>::type; + return VFromD{reinterpret_cast(vec_unpackl(v.raw))}; +} + +// F16 to F32 +template +HWY_API VFromD PromoteUpperTo(D df32, Vec128 v) { +#if HWY_PPC_HAVE_9 + (void)df32; + return VFromD{vec_extract_fp32_from_shortl(v.raw)}; +#else + const Rebind dh; + return PromoteTo(df32, UpperHalf(dh, v)); +#endif +} + +// BF16 to F32 +template +HWY_API VFromD PromoteUpperTo(D df32, Vec128 v) { + const Repartition du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteUpperTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteUpperTo(D /*tag*/, Vec128 v) { + const __vector float raw_v = InterleaveUpper(Full128(), v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#elif HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL && \ + HWY_COMPILER_GCC_ACTUAL < 1000 + // Workaround for compiler error with GCC 9 or earlier on Z14 + return VFromD{__builtin_s390_vflls(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +} + +template +HWY_API VFromD PromoteUpperTo(D df64, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToSigned di64; + return ConvertTo(df64, PromoteUpperTo(di64, v)); +#else // VSX + (void)df64; + const __vector signed int raw_v = + InterleaveUpper(Full128(), v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +#endif // HWY_S390X_HAVE_Z14 +} + +template +HWY_API VFromD PromoteUpperTo(D df64, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du64; + return ConvertTo(df64, PromoteUpperTo(du64, v)); +#else // VSX + (void)df64; + const __vector unsigned int raw_v = + InterleaveUpper(Full128(), v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +#endif // HWY_S390X_HAVE_Z14 +} + +template +HWY_API VFromD PromoteUpperTo(D di64, Vec128 v) { +#if !HWY_S390X_HAVE_Z14 + (void)di64; + return detail::VsxXvcvspsxds( + detail::VsxF2INormalizeSrcVals(InterleaveUpper(Full128(), v, v))); +#else + const RebindToFloat df64; + return ConvertTo(di64, PromoteUpperTo(df64, v)); +#endif +} + +template +HWY_API VFromD PromoteUpperTo(D du64, Vec128 v) { +#if !HWY_S390X_HAVE_Z14 + (void)du64; + return detail::VsxXvcvspuxds( + detail::VsxF2INormalizeSrcVals(InterleaveUpper(Full128(), v, v))); +#else + const RebindToFloat df64; + return ConvertTo(du64, PromoteUpperTo(df64, v)); +#endif +} + +// Generic version for <=64 bit input/output +template +HWY_API VFromD PromoteUpperTo(D d, V v) { + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo + +namespace detail { + +// Signed to Signed PromoteEvenTo/PromoteOddTo for PPC9/PPC10 +#if HWY_PPC_HAVE_9 && \ + (HWY_COMPILER_GCC_ACTUAL >= 1200 || HWY_COMPILER_CLANG >= 1200) + +#if HWY_IS_LITTLE_ENDIAN +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signexti(v.raw)}; +} +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signextll(v.raw)}; +} +#else +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::SignedTag 
/*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signexti(v.raw)}; +} +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signextll(v.raw)}; +} +#endif // HWY_IS_LITTLE_ENDIAN + +#endif // HWY_PPC_HAVE_9 + +// I32/U32/F32->F64 PromoteEvenTo +#if HWY_S390X_HAVE_Z14 +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_doublee(v.raw)}; +} +template )> +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const Rebind>, decltype(d_to)> dw; + return ConvertTo(d_to, PromoteEvenTo(dw, v)); +} +#else // VSX +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_doublee(v.raw)}; +} +#endif // HWY_S390X_HAVE_Z14 + +// F32->I64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // VsxXvcvspsxds expects the source values to be in the odd lanes on + // little-endian PPC, and the Shuffle2103 operation below will shift the even + // lanes of normalized_v into the odd lanes. + return VsxXvcvspsxds(Shuffle2103(normalized_v)); +#else + // VsxXvcvspsxds expects the source values to be in the even lanes on + // big-endian PPC. + return VsxXvcvspsxds(normalized_v); +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteEvenTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +// F32->U64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // VsxXvcvspuxds expects the source values to be in the odd lanes + // on little-endian PPC, and the Shuffle2103 operation below will shift the + // even lanes of normalized_v into the odd lanes. + return VsxXvcvspuxds(Shuffle2103(normalized_v)); +#else + // VsxXvcvspuxds expects the source values to be in the even lanes + // on big-endian PPC. 
+ return VsxXvcvspuxds(normalized_v); +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteEvenTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +// I32/U32/F32->F64 PromoteOddTo +#if HWY_S390X_HAVE_Z14 +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { + return PromoteEvenTo(hwy::FloatTag(), hwy::SizeTag<8>(), hwy::FloatTag(), + d_to, V{vec_sld(v.raw, v.raw, 4)}); +} +template )> +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const Rebind>, decltype(d_to)> dw; + return ConvertTo(d_to, PromoteOddTo(dw, v)); +} +#else +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_doubleo(v.raw)}; +} +#endif + +// F32->I64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // VsxXvcvspsxds expects the source values to be in the odd lanes + // on little-endian PPC + return VsxXvcvspsxds(normalized_v); +#else + // VsxXvcvspsxds expects the source values to be in the even lanes + // on big-endian PPC, and the Shuffle0321 operation below will shift the odd + // lanes of normalized_v into the even lanes. + return VsxXvcvspsxds(Shuffle0321(normalized_v)); +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteOddTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +// F32->U64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // VsxXvcvspuxds expects the source values to be in the odd lanes + // on little-endian PPC + return VsxXvcvspuxds(normalized_v); +#else + // VsxXvcvspuxds expects the source values to be in the even lanes + // on big-endian PPC, and the Shuffle0321 operation below will shift the odd + // lanes of normalized_v into the even lanes. 
+ return VsxXvcvspuxds(Shuffle0321(normalized_v)); +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteOddTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +} // namespace detail + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template ) * 2)> +HWY_API VFromD DemoteTo(D /* tag */, + Vec128().MaxLanes()> v) { + return VFromD{vec_packsu(v.raw, v.raw)}; +} + +template ) * 2)> +HWY_API VFromD DemoteTo(D /* tag */, + Vec128().MaxLanes()> v) { + return VFromD{vec_packs(v.raw, v.raw)}; +} + +template ) * 2)> +HWY_API VFromD DemoteTo(D /* tag */, + Vec128().MaxLanes()> v) { + return VFromD{vec_packs(v.raw, v.raw)}; +} + +template = sizeof(TFromD) * 4)>* = nullptr> +HWY_API VFromD DemoteTo(D d, + Vec128().MaxLanes()> v) { + const Rebind, D> d2; + return DemoteTo(d, DemoteTo(d2, v)); +} + +template = sizeof(TFromD) * 4)>* = nullptr> +HWY_API VFromD DemoteTo(D d, + Vec128().MaxLanes()> v) { + const Rebind, D> d2; + return DemoteTo(d, DemoteTo(d2, v)); +} + +template = sizeof(TFromD) * 4)>* = nullptr> +HWY_API VFromD DemoteTo(D d, + Vec128().MaxLanes()> v) { + const Rebind>, D> d2; + return DemoteTo(d, DemoteTo(d2, v)); +} + +#if HWY_PPC_HAVE_9 && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvsphp)) + +// We already toggled HWY_NATIVE_F16C above. + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { +// Avoid vec_pack_to_short_fp32 on Clang because its implementation is buggy. +#if HWY_COMPILER_GCC_ACTUAL + (void)df16; + return VFromD{vec_pack_to_short_fp32(v.raw, v.raw)}; +#elif HWY_HAS_BUILTIN(__builtin_vsx_xvcvsphp) + // Work around bug in the clang implementation of vec_pack_to_short_fp32 + // by using the __builtin_vsx_xvcvsphp builtin on PPC9/PPC10 targets + // if the __builtin_vsx_xvcvsphp intrinsic is available + const RebindToUnsigned du16; + const Rebind du; + const VFromD bits16{ + reinterpret_cast<__vector unsigned int>(__builtin_vsx_xvcvsphp(v.raw))}; + return BitCast(df16, TruncateTo(du16, bits16)); +#else +#error "Only define the function if we have a native implementation" +#endif +} + +#endif // HWY_PPC_HAVE_9 + +#if HWY_PPC_HAVE_9 + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +namespace detail { + +// On big-endian PPC9, VsxXscvdphp converts vf64[0] to a F16, returned as an U64 +// vector with the resulting F16 bits in the lower 16 bits of U64 lane 0 + +// On little-endian PPC9, VsxXscvdphp converts vf64[1] to a F16, returned as +// an U64 vector with the resulting F16 bits in the lower 16 bits of U64 lane 1 +static HWY_INLINE Vec128 VsxXscvdphp(Vec128 vf64) { + // Inline assembly is needed for the PPC9 xscvdphp instruction as there is + // currently no intrinsic available for the PPC9 xscvdphp instruction + __vector unsigned long long raw_result; + __asm__("xscvdphp %x0, %x1" : "=wa"(raw_result) : "wa"(vf64.raw)); + return Vec128{raw_result}; +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToUnsigned du16; + const Rebind du64; + + const Full128 df64_full; +#if HWY_IS_LITTLE_ENDIAN + const auto bits16_as_u64 = + UpperHalf(du64, detail::VsxXscvdphp(Combine(df64_full, v, v))); +#else + const auto bits16_as_u64 = + LowerHalf(du64, detail::VsxXscvdphp(ResizeBitCast(df64_full, v))); +#endif + + return BitCast(df16, TruncateTo(du16, bits16_as_u64)); +} + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToUnsigned du16; + 
const Rebind du64; + const Rebind df64; + +#if HWY_IS_LITTLE_ENDIAN + const auto bits64_as_u64_0 = detail::VsxXscvdphp(InterleaveLower(df64, v, v)); + const auto bits64_as_u64_1 = detail::VsxXscvdphp(v); + const auto bits64_as_u64 = + InterleaveUpper(du64, bits64_as_u64_0, bits64_as_u64_1); +#else + const auto bits64_as_u64_0 = detail::VsxXscvdphp(v); + const auto bits64_as_u64_1 = detail::VsxXscvdphp(InterleaveUpper(df64, v, v)); + const auto bits64_as_u64 = + InterleaveLower(du64, bits64_as_u64_0, bits64_as_u64_1); +#endif + + return BitCast(df16, TruncateTo(du16, bits64_as_u64)); +} + +#elif HWY_S390X_HAVE_Z14 + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +namespace detail { + +template +static HWY_INLINE VFromD DemoteToF32WithRoundToOdd( + DF32 df32, VFromD> v) { + const Twice dt_f32; + + __vector float raw_f32_in_even; + __asm__("vledb %0,%1,0,3" : "=v"(raw_f32_in_even) : "v"(v.raw)); + + const VFromD f32_in_even{raw_f32_in_even}; + return LowerHalf(df32, ConcatEven(dt_f32, f32_in_even, f32_in_even)); +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const Rebind df32; + return DemoteTo(df16, detail::DemoteToF32WithRoundToOdd(df32, v)); +} + +#endif // HWY_PPC_HAVE_9 + +#if HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +namespace detail { + +// VsxXvcvspbf16 converts a F32 vector to a BF16 vector, bitcasted to an U32 +// vector with the resulting BF16 bits in the lower 16 bits of each U32 lane +template +static HWY_INLINE VFromD> VsxXvcvspbf16( + D dbf16, VFromD> v) { + const Rebind du32; + const Repartition du32_as_du8; + + using VU32 = __vector unsigned int; + + // Even though the __builtin_vsx_xvcvspbf16 builtin performs a F32 to BF16 + // conversion, the __builtin_vsx_xvcvspbf16 intrinsic expects a + // __vector unsigned char argument (at least as of GCC 13 and Clang 17) + return VFromD>{reinterpret_cast( + __builtin_vsx_xvcvspbf16(BitCast(du32_as_du8, v).raw))}; +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D dbf16, VFromD> v) { + const RebindToUnsigned du16; + return BitCast(dbf16, TruncateTo(du16, detail::VsxXvcvspbf16(dbf16, v))); +} + +#endif // HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) + +// Specializations for partial vectors because vec_packs sets lanes above 2*N. 
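// Setting aside rounding, the F32 -> BF16 demotion above keeps the upper half
// of each binary32 bit pattern (1 sign, 8 exponent, 7 mantissa bits), the
// mirror image of the ShiftLeft<16> used for the BF16 -> F32 promotion
// earlier. A scalar sketch of that bit-level relationship (illustrative; the
// xvcvspbf16 instruction may additionally round, which a plain truncation
// does not model):
static inline HWY_MAYBE_UNUSED uint16_t F32BitsToBF16Bits(uint32_t f32_bits) {
  return static_cast<uint16_t>(f32_bits >> 16);
}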
+template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Twice dn_full; + const Repartition du32_full; + + const VFromD v_full{vec_packs(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN /*dn*/, V a, V b) { + return VFromD{vec_packs(a.raw, b.raw)}; +} + +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Twice dn_full; + const Repartition du32_full; + + const VFromD v_full{vec_packsu(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN /*dn*/, V a, V b) { + return VFromD{vec_packsu(a.raw, b.raw)}; +} + +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Twice dn_full; + const Repartition du32_full; + + const VFromD v_full{vec_packs(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN /*dn*/, V a, V b) { + return VFromD{vec_packs(a.raw, b.raw)}; +} + +#if HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) +template ), + HWY_IF_LANES_D(D, HWY_MAX_LANES_V(V) * 2)> +HWY_API VFromD ReorderDemote2To(D dbf16, V a, V b) { + const RebindToUnsigned du16; + const Half dh_bf16; + return BitCast(dbf16, + OrderedTruncate2To(du16, detail::VsxXvcvspbf16(dh_bf16, a), + detail::VsxXvcvspbf16(dh_bf16, b))); +} +#endif + +template ), class V, + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} + +#if HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) +template ), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} +#endif + +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { +#if HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL && \ + HWY_COMPILER_GCC_ACTUAL < 1000 + // Workaround for compiler error with GCC 9 or earlier on Z14 + return Vec32{__builtin_s390_vflrd(v.raw, 0, 0)}; +#else + return Vec32{vec_floate(v.raw)}; +#endif +} + +template +HWY_API Vec64 DemoteTo(D d, Vec128 v) { +#if HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL && \ + HWY_COMPILER_GCC_ACTUAL < 1000 + // Workaround for compiler error with GCC 9 or earlier on Z14 + const Vec128 f64_to_f32{__builtin_s390_vflrd(v.raw, 0, 0)}; +#elif HWY_S390X_HAVE_Z14 || HWY_IS_LITTLE_ENDIAN + const Vec128 f64_to_f32{vec_floate(v.raw)}; +#else + const Vec128 f64_to_f32{vec_floato(v.raw)}; +#endif + +#if HWY_S390X_HAVE_Z14 + const Twice dt; + return LowerHalf(d, ConcatEven(dt, f64_to_f32, f64_to_f32)); +#else + const RebindToUnsigned du; + const Rebind du64; + return Vec64{ + BitCast(d, TruncateTo(du, 
BitCast(du64, f64_to_f32))).raw}; +#endif +} + +template +HWY_API Vec32 DemoteTo(D di32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind di64; + return DemoteTo(di32, ConvertTo(di64, v)); +#else + (void)di32; + return Vec32{vec_signede(detail::VsxF2INormalizeSrcVals(v).raw)}; +#endif +} + +template +HWY_API Vec64 DemoteTo(D di32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind di64; + return DemoteTo(di32, ConvertTo(di64, v)); +#else + (void)di32; + +#if HWY_IS_LITTLE_ENDIAN + const Vec128 f64_to_i32{ + vec_signede(detail::VsxF2INormalizeSrcVals(v).raw)}; +#else + const Vec128 f64_to_i32{ + vec_signedo(detail::VsxF2INormalizeSrcVals(v).raw)}; +#endif + + const Rebind di64; + const Vec128 vi64 = BitCast(di64, f64_to_i32); + return Vec64{vec_pack(vi64.raw, vi64.raw)}; +#endif +} + +template +HWY_API Vec32 DemoteTo(D du32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind du64; + return DemoteTo(du32, ConvertTo(du64, v)); +#else + (void)du32; + return Vec32{vec_unsignede(detail::VsxF2INormalizeSrcVals(v).raw)}; +#endif +} + +template +HWY_API Vec64 DemoteTo(D du32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind du64; + return DemoteTo(du32, ConvertTo(du64, v)); +#else + (void)du32; +#if HWY_IS_LITTLE_ENDIAN + const Vec128 f64_to_u32{ + vec_unsignede(detail::VsxF2INormalizeSrcVals(v).raw)}; +#else + const Vec128 f64_to_u32{ + vec_unsignedo(detail::VsxF2INormalizeSrcVals(v).raw)}; +#endif + + const Rebind du64; + const Vec128 vu64 = BitCast(du64, f64_to_u32); + return Vec64{vec_pack(vu64.raw, vu64.raw)}; +#endif +} + +#if HWY_S390X_HAVE_Z14 +namespace detail { + +template )> +HWY_INLINE VFromD>> ConvToF64WithRoundToOdd(V v) { + __vector double raw_result; + // Use inline assembly to do a round-to-odd I64->F64 conversion on Z14 + __asm__("vcdgb %0,%1,0,3" : "=v"(raw_result) : "v"(v.raw)); + return VFromD>>{raw_result}; +} + +template )> +HWY_INLINE VFromD>> ConvToF64WithRoundToOdd(V v) { + __vector double raw_result; + // Use inline assembly to do a round-to-odd U64->F64 conversion on Z14 + __asm__("vcdlgb %0,%1,0,3" : "=v"(raw_result) : "v"(v.raw)); + return VFromD>>{raw_result}; +} + +} // namespace detail +#endif // HWY_S390X_HAVE_Z14 + +template +HWY_API Vec32 DemoteTo(D df32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX + (void)df32; + return Vec32{vec_floate(v.raw)}; +#endif +} + +template +HWY_API Vec64 DemoteTo(D df32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX +#if HWY_IS_LITTLE_ENDIAN + const Vec128 i64_to_f32{vec_floate(v.raw)}; +#else + const Vec128 i64_to_f32{vec_floato(v.raw)}; +#endif + + const RebindToUnsigned du32; + const Rebind du64; + return Vec64{ + BitCast(df32, TruncateTo(du32, BitCast(du64, i64_to_f32))).raw}; +#endif +} + +template +HWY_API Vec32 DemoteTo(D df32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX + (void)df32; + return Vec32{vec_floate(v.raw)}; +#endif +} + +template +HWY_API Vec64 DemoteTo(D df32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX +#if HWY_IS_LITTLE_ENDIAN + const Vec128 u64_to_f32{vec_floate(v.raw)}; +#else + const Vec128 u64_to_f32{vec_floato(v.raw)}; +#endif + + const RebindToUnsigned du; + const Rebind du64; + return Vec64{ + BitCast(df32, TruncateTo(du, BitCast(du64, u64_to_f32))).raw}; +#endif +} + +// For already range-limited input [0, 255]. 
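// The U8FromU32 helper below assumes its inputs are already in [0, 255], so
// the two TruncateTo steps reduce to discarding the upper 24 bits of each
// lane rather than saturating. A scalar sketch of that contract (the helper
// name is illustrative):
static inline HWY_MAYBE_UNUSED uint8_t U8FromU32Lane(uint32_t v) {
  // No clamping: out-of-range values would simply wrap modulo 256.
  return static_cast<uint8_t>(v);
}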
+template +HWY_API Vec128 U8FromU32(Vec128 v) { + const Rebind> du16; + const Rebind du8; + return TruncateTo(du8, TruncateTo(du16, v)); +} +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +// Note: altivec.h vec_ct* currently contain C casts which triggers +// -Wdeprecate-lax-vec-conv-all warnings, so disable them. + +#if HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 +template +HWY_API VFromD ConvertTo(D df32, + Vec128().MaxLanes()> v) { + const Rebind df64; + return DemoteTo(df32, PromoteTo(df64, v)); +} +template +HWY_API VFromD ConvertTo(D df32, Vec128 v) { + const RepartitionToWide df64; + +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000 + // Workaround for compiler error with GCC 9 or earlier on Z14 + const VFromD vf32_lo{ + __builtin_s390_vflrd(PromoteLowerTo(df64, v).raw, 0, 0)}; + const VFromD vf32_hi{ + __builtin_s390_vflrd(PromoteUpperTo(df64, v).raw, 0, 0)}; +#else + const VFromD vf32_lo{vec_floate(PromoteLowerTo(df64, v).raw)}; + const VFromD vf32_hi{vec_floate(PromoteUpperTo(df64, v).raw)}; +#endif + return ConcatEven(df32, vf32_hi, vf32_lo); +} +#else // Z15 or PPC +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { + HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") +#endif +#if HWY_S390X_HAVE_Z15 + return VFromD{vec_float(v.raw)}; +#else + return VFromD{vec_ctf(v.raw, 0)}; +#endif + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_TARGET == HWY_Z14 + +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { + return VFromD{vec_double(v.raw)}; +} + +// Truncates (rounds toward zero). +#if HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 +template +HWY_API VFromD ConvertTo(D di32, + Vec128().MaxLanes()> v) { + const Rebind di64; + return DemoteTo(di32, PromoteTo(di64, v)); +} +template +HWY_API VFromD ConvertTo(D di32, + Vec128().MaxLanes()> v) { + const RepartitionToWide di64; + return OrderedDemote2To(di32, PromoteLowerTo(di64, v), + PromoteUpperTo(di64, v)); +} +#else // Z15 or PPC +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { +#if defined(__OPTIMIZE__) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr int32_t kMinI32 = LimitsMin(); + constexpr int32_t kMaxI32 = LimitsMax(); + return Dup128VecFromValues( + D(), + (v.raw[0] >= -2147483648.0f) + ? ((v.raw[0] < 2147483648.0f) ? static_cast(v.raw[0]) + : kMaxI32) + : ((v.raw[0] < 0) ? kMinI32 : 0), + (v.raw[1] >= -2147483648.0f) + ? ((v.raw[1] < 2147483648.0f) ? static_cast(v.raw[1]) + : kMaxI32) + : ((v.raw[1] < 0) ? kMinI32 : 0), + (v.raw[2] >= -2147483648.0f) + ? ((v.raw[2] < 2147483648.0f) ? static_cast(v.raw[2]) + : kMaxI32) + : ((v.raw[2] < 0) ? kMinI32 : 0), + (v.raw[3] >= -2147483648.0f) + ? ((v.raw[3] < 2147483648.0f) ? static_cast(v.raw[3]) + : kMaxI32) + : ((v.raw[3] < 0) ? 
kMinI32 : 0)); + } +#endif + +#if HWY_S390X_HAVE_Z15 + // Use inline assembly on Z15 to avoid undefined behavior if v[i] is not in + // the range of an int32_t + __vector signed int raw_result; + __asm__("vcfeb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") +#endif + return VFromD{vec_cts(v.raw, 0)}; + HWY_DIAGNOSTICS(pop) +#endif // HWY_S390X_HAVE_Z15 +} +#endif // HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 + +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { +#if defined(__OPTIMIZE__) && (!HWY_COMPILER_CLANG || !HWY_S390X_HAVE_Z14) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr int64_t kMinI64 = LimitsMin(); + constexpr int64_t kMaxI64 = LimitsMax(); + return Dup128VecFromValues(D(), + (v.raw[0] >= -9223372036854775808.0) + ? ((v.raw[0] < 9223372036854775808.0) + ? static_cast(v.raw[0]) + : kMaxI64) + : ((v.raw[0] < 0) ? kMinI64 : 0LL), + (v.raw[1] >= -9223372036854775808.0) + ? ((v.raw[1] < 9223372036854775808.0) + ? static_cast(v.raw[1]) + : kMaxI64) + : ((v.raw[1] < 0) ? kMinI64 : 0LL)); + } +#endif + + // Use inline assembly to avoid undefined behavior if v[i] is not within the + // range of an int64_t + __vector signed long long raw_result; +#if HWY_S390X_HAVE_Z14 + __asm__("vcgdb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); +#else + __asm__("xvcvdpsxds %x0,%x1" + : "=wa"(raw_result) + : "wa"(detail::VsxF2INormalizeSrcVals(v).raw)); +#endif + return VFromD{raw_result}; +} + +#if HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 +template +HWY_API VFromD ConvertTo(D du32, + Vec128().MaxLanes()> v) { + const Rebind du64; + return DemoteTo(du32, PromoteTo(du64, v)); +} +template +HWY_API VFromD ConvertTo(D du32, + Vec128().MaxLanes()> v) { + const RepartitionToWide du64; + return OrderedDemote2To(du32, PromoteLowerTo(du64, v), + PromoteUpperTo(du64, v)); +} +#else // Z15 or VSX +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { +#if defined(__OPTIMIZE__) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr uint32_t kMaxU32 = LimitsMax(); + return Dup128VecFromValues( + D(), + (v.raw[0] >= 0.0f) + ? ((v.raw[0] < 4294967296.0f) ? static_cast(v.raw[0]) + : kMaxU32) + : 0, + (v.raw[1] >= 0.0f) + ? ((v.raw[1] < 4294967296.0f) ? static_cast(v.raw[1]) + : kMaxU32) + : 0, + (v.raw[2] >= 0.0f) + ? ((v.raw[2] < 4294967296.0f) ? static_cast(v.raw[2]) + : kMaxU32) + : 0, + (v.raw[3] >= 0.0f) + ? ((v.raw[3] < 4294967296.0f) ? 
static_cast(v.raw[3]) + : kMaxU32) + : 0); + } +#endif + +#if HWY_S390X_HAVE_Z15 + // Use inline assembly on Z15 to avoid undefined behavior if v[i] is not in + // the range of an uint32_t + __vector unsigned int raw_result; + __asm__("vclfeb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else // VSX + HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") +#endif + VFromD result{vec_ctu(v.raw, 0)}; + HWY_DIAGNOSTICS(pop) + return result; +#endif // HWY_S390X_HAVE_Z15 +} +#endif // HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 + +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { + HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") +#endif + +#if defined(__OPTIMIZE__) && (!HWY_COMPILER_CLANG || !HWY_S390X_HAVE_Z14) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr uint64_t kMaxU64 = LimitsMax(); + return Dup128VecFromValues( + D(), + (v.raw[0] >= 0.0) ? ((v.raw[0] < 18446744073709551616.0) + ? static_cast(v.raw[0]) + : kMaxU64) + : 0, + (v.raw[1] >= 0.0) ? ((v.raw[1] < 18446744073709551616.0) + ? static_cast(v.raw[1]) + : kMaxU64) + : 0); + } +#endif + + // Use inline assembly to avoid undefined behavior if v[i] is not within the + // range of an uint64_t + __vector unsigned long long raw_result; +#if HWY_S390X_HAVE_Z14 + __asm__("vclgdb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); +#else // VSX + __asm__("xvcvdpuxds %x0,%x1" + : "=wa"(raw_result) + : "wa"(detail::VsxF2INormalizeSrcVals(v).raw)); +#endif + return VFromD{raw_result}; +} + +// ------------------------------ Floating-point rounding (ConvertTo) + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(Vec128 v) { + return Vec128{vec_round(v.raw)}; +} + +template +HWY_API Vec128 Round(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return Vec128{vec_round(v.raw)}; +#else + return Vec128{vec_rint(v.raw)}; +#endif +} + +template +HWY_API Vec128, N> NearestInt(Vec128 v) { + const DFromV d; + const RebindToSigned di; + return ConvertTo(di, Round(v)); +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + return DemoteTo(di32, Round(v)); +} + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(Vec128 v) { + return Vec128{vec_trunc(v.raw)}; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(Vec128 v) { + return Vec128{vec_ceil(v.raw)}; +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(Vec128 v) { + return Vec128{vec_floor(v.raw)}; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask128 IsNaN(Vec128 v) { + static_assert(IsFloat(), "Only for float"); + return v != v; +} + +template +HWY_API Mask128 IsInf(Vec128 v) { + static_assert(IsFloat(), "Only for float"); + using TU = MakeUnsigned; + const DFromV d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask( + d, + Eq(Add(vu, vu), Set(du, static_cast(hwy::MaxExponentTimes2())))); +} + +// Returns whether normal/subnormal/zero. 
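// IsInf above and IsFinite below share one trick: adding the bit pattern to
// itself discards the sign bit, leaving the biased exponent in the top bits
// where it can be compared against the doubled all-ones exponent field
// (MaxExponentTimes2). A scalar sketch for binary32, where that constant is
// 0xFF000000:
static inline HWY_MAYBE_UNUSED bool IsInfBits32(uint32_t bits) {
  return (bits + bits) == 0xFF000000u;  // exponent all-ones, mantissa zero
}
static inline HWY_MAYBE_UNUSED bool IsFiniteBits32(uint32_t bits) {
  return (bits + bits) < 0xFF000000u;  // exponent below all-ones
}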
+template +HWY_API Mask128 IsFinite(Vec128 v) { + static_assert(IsFloat(), "Only for float"); + using TU = MakeUnsigned; + const DFromV d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent(hwy::MaxExponentTimes2())))); +} + +// ================================================== CRYPTO + +#if !HWY_S390X_HAVE_Z14 && !defined(HWY_DISABLE_PPC8_CRYPTO) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +namespace detail { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1600 +using CipherTag = Full128; +#else +using CipherTag = Full128; +#endif // !HWY_COMPILER_CLANG +using CipherVec = VFromD; +} // namespace detail + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + const detail::CipherTag dc; + const Full128 du8; +#if HWY_IS_LITTLE_ENDIAN + return Reverse(du8, + BitCast(du8, detail::CipherVec{vec_cipher_be( + BitCast(dc, Reverse(du8, state)).raw, + BitCast(dc, Reverse(du8, round_key)).raw)})); +#else + return BitCast(du8, detail::CipherVec{vec_cipher_be( + BitCast(dc, state).raw, BitCast(dc, round_key).raw)}); +#endif +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + const detail::CipherTag dc; + const Full128 du8; +#if HWY_IS_LITTLE_ENDIAN + return Reverse(du8, + BitCast(du8, detail::CipherVec{vec_cipherlast_be( + BitCast(dc, Reverse(du8, state)).raw, + BitCast(dc, Reverse(du8, round_key)).raw)})); +#else + return BitCast(du8, detail::CipherVec{vec_cipherlast_be( + BitCast(dc, state).raw, BitCast(dc, round_key).raw)}); +#endif +} + +HWY_API Vec128 AESRoundInv(Vec128 state, + Vec128 round_key) { + const detail::CipherTag dc; + const Full128 du8; +#if HWY_IS_LITTLE_ENDIAN + return Xor(Reverse(du8, BitCast(du8, detail::CipherVec{vec_ncipher_be( + BitCast(dc, Reverse(du8, state)).raw, + Zero(dc).raw)})), + round_key); +#else + return Xor(BitCast(du8, detail::CipherVec{vec_ncipher_be( + BitCast(dc, state).raw, Zero(dc).raw)}), + round_key); +#endif +} + +HWY_API Vec128 AESLastRoundInv(Vec128 state, + Vec128 round_key) { + const detail::CipherTag dc; + const Full128 du8; +#if HWY_IS_LITTLE_ENDIAN + return Reverse(du8, + BitCast(du8, detail::CipherVec{vec_ncipherlast_be( + BitCast(dc, Reverse(du8, state)).raw, + BitCast(dc, Reverse(du8, round_key)).raw)})); +#else + return BitCast(du8, detail::CipherVec{vec_ncipherlast_be( + BitCast(dc, state).raw, BitCast(dc, round_key).raw)}); +#endif +} + +HWY_API Vec128 AESInvMixColumns(Vec128 state) { + const Full128 du8; + const auto zero = Zero(du8); + + // PPC8/PPC9/PPC10 does not have a single instruction for the AES + // InvMixColumns operation like ARM Crypto, SVE2 Crypto, or AES-NI do. + + // The AESInvMixColumns operation can be carried out on PPC8/PPC9/PPC10 + // by doing an AESLastRound operation with a zero round_key followed by an + // AESRoundInv operation with a zero round_key. 
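  // Informally, with a zero round key:
  //   AESLastRound(s, 0) = ShiftRows(SubBytes(s))
  //   AESRoundInv(t, 0)  = InvMixColumns(InvShiftRows(InvSubBytes(t)))
  // SubBytes is a per-byte substitution, so it commutes with the ShiftRows
  // byte permutation; the forward and inverse steps cancel and the
  // composition below leaves only InvMixColumns(state).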
+ return AESRoundInv(AESLastRound(state, zero), zero); +} + +template +HWY_API Vec128 AESKeyGenAssist(Vec128 v) { + constexpr __vector unsigned char kRconXorMask = {0, 0, 0, 0, kRcon, 0, 0, 0, + 0, 0, 0, 0, kRcon, 0, 0, 0}; + constexpr __vector unsigned char kRotWordShuffle = { + 4, 5, 6, 7, 5, 6, 7, 4, 12, 13, 14, 15, 13, 14, 15, 12}; + const detail::CipherTag dc; + const Full128 du8; + const auto sub_word_result = + BitCast(du8, detail::CipherVec{vec_sbox_be(BitCast(dc, v).raw)}); + const auto rot_word_result = + TableLookupBytes(sub_word_result, Vec128{kRotWordShuffle}); + return Xor(rot_word_result, Vec128{kRconXorMask}); +} + +template +HWY_API Vec128 CLMulLower(Vec128 a, + Vec128 b) { + // NOTE: Lane 1 of both a and b need to be zeroed out for the + // vec_pmsum_be operation below as the vec_pmsum_be operation + // does a carryless multiplication of each 64-bit half and then + // adds the two halves using an bitwise XOR operation. + + const DFromV d; + const auto zero = Zero(d); + + using VU64 = __vector unsigned long long; + const VU64 pmsum_result = reinterpret_cast( + vec_pmsum_be(InterleaveLower(a, zero).raw, InterleaveLower(b, zero).raw)); + +#if HWY_IS_LITTLE_ENDIAN + return Vec128{pmsum_result}; +#else + // Need to swap the two halves of pmsum_result on big-endian targets as + // the upper 64 bits of the carryless multiplication result are in lane 0 of + // pmsum_result and the lower 64 bits of the carryless multiplication result + // are in lane 1 of mul128_result + return Vec128{vec_sld(pmsum_result, pmsum_result, 8)}; +#endif +} + +template +HWY_API Vec128 CLMulUpper(Vec128 a, + Vec128 b) { + // NOTE: Lane 0 of both a and b need to be zeroed out for the + // vec_pmsum_be operation below as the vec_pmsum_be operation + // does a carryless multiplication of each 64-bit half and then + // adds the two halves using an bitwise XOR operation. + + const DFromV d; + const auto zero = Zero(d); + + using VU64 = __vector unsigned long long; + const VU64 pmsum_result = reinterpret_cast( + vec_pmsum_be(vec_mergel(zero.raw, a.raw), vec_mergel(zero.raw, b.raw))); + +#if HWY_IS_LITTLE_ENDIAN + return Vec128{pmsum_result}; +#else + // Need to swap the two halves of pmsum_result on big-endian targets as + // the upper 64 bits of the carryless multiplication result are in lane 0 of + // pmsum_result and the lower 64 bits of the carryless multiplication result + // are in lane 1 of mul128_result + return Vec128{vec_sld(pmsum_result, pmsum_result, 8)}; +#endif +} + +#endif // !defined(HWY_DISABLE_PPC8_CRYPTO) + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +template +HWY_INLINE MFromD LoadMaskBits128(D /*d*/, uint64_t mask_bits) { +#if HWY_PPC_HAVE_10 + const Vec128 mask_vec{vec_genbm(mask_bits)}; + +#if HWY_IS_LITTLE_ENDIAN + return MFromD{MaskFromVec(mask_vec).raw}; +#else + return MFromD{MaskFromVec(Reverse(Full128(), mask_vec)).raw}; +#endif // HWY_IS_LITTLE_ENDIAN + +#else // PPC9 or earlier + const Full128 du8; + const Full128 du16; + const Vec128 vbits = + BitCast(du8, Set(du16, static_cast(mask_bits))); + + // Replicate bytes 8x such that each byte contains the bit that governs it. 
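  // Worked example (little-endian): mask_bits = 0x0104 selects lanes 2 and 8.
  // vbits holds the byte pattern {0x04, 0x01, 0x04, 0x01, ...}; kRep8 copies
  // byte 0 across lanes 0..7 and byte 1 across lanes 8..15, giving
  // {0x04 x8, 0x01 x8}. TestBit against {1, 2, 4, ..., 128} is then true only
  // for lane 2 (bit 2 of 0x04) and lane 8 (bit 0 of 0x01).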
+#if HWY_IS_LITTLE_ENDIAN + const __vector unsigned char kRep8 = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; +#else + const __vector unsigned char kRep8 = {1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0}; +#endif // HWY_IS_LITTLE_ENDIAN + + const Vec128 rep8{vec_perm(vbits.raw, vbits.raw, kRep8)}; + const __vector unsigned char kBit = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return MFromD{TestBit(rep8, Vec128{kBit}).raw}; +#endif // HWY_PPC_HAVE_10 +} + +template +HWY_INLINE MFromD LoadMaskBits128(D /*d*/, uint64_t mask_bits) { +#if HWY_PPC_HAVE_10 + const Vec128 mask_vec{vec_genhm(mask_bits)}; + +#if HWY_IS_LITTLE_ENDIAN + return MFromD{MaskFromVec(mask_vec).raw}; +#else + return MFromD{MaskFromVec(Reverse(Full128(), mask_vec)).raw}; +#endif // HWY_IS_LITTLE_ENDIAN + +#else // PPC9 or earlier + const __vector unsigned short kBit = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = + Set(Full128(), static_cast(mask_bits)); + return MFromD{TestBit(vmask_bits, Vec128{kBit}).raw}; +#endif // HWY_PPC_HAVE_10 +} + +template +HWY_INLINE MFromD LoadMaskBits128(D /*d*/, uint64_t mask_bits) { +#if HWY_PPC_HAVE_10 + const Vec128 mask_vec{vec_genwm(mask_bits)}; + +#if HWY_IS_LITTLE_ENDIAN + return MFromD{MaskFromVec(mask_vec).raw}; +#else + return MFromD{MaskFromVec(Reverse(Full128(), mask_vec)).raw}; +#endif // HWY_IS_LITTLE_ENDIAN + +#else // PPC9 or earlier + const __vector unsigned int kBit = {1, 2, 4, 8}; + const auto vmask_bits = + Set(Full128(), static_cast(mask_bits)); + return MFromD{TestBit(vmask_bits, Vec128{kBit}).raw}; +#endif // HWY_PPC_HAVE_10 +} + +template +HWY_INLINE MFromD LoadMaskBits128(D /*d*/, uint64_t mask_bits) { +#if HWY_PPC_HAVE_10 + const Vec128 mask_vec{vec_gendm(mask_bits)}; + +#if HWY_IS_LITTLE_ENDIAN + return MFromD{MaskFromVec(mask_vec).raw}; +#else + return MFromD{MaskFromVec(Reverse(Full128(), mask_vec)).raw}; +#endif // HWY_IS_LITTLE_ENDIAN + +#else // PPC9 or earlier + const __vector unsigned long long kBit = {1, 2}; + const auto vmask_bits = + Set(Full128(), static_cast(mask_bits)); + return MFromD{TestBit(vmask_bits, Vec128{kBit}).raw}; +#endif // HWY_PPC_HAVE_10 +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + // If there are 8 or fewer lanes, simply convert bits[0] to a uint64_t + uint64_t mask_bits = bits[0]; + + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + + return detail::LoadMaskBits128(d, mask_bits); +} + +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + // First, copy the mask bits to a uint16_t as there as there are at most + // 16 lanes in a vector. + + // Copying the mask bits to a uint16_t first will also ensure that the + // mask bits are loaded into the lower 16 bits on big-endian PPC targets. + uint16_t u16_mask_bits; + CopyBytes(bits, &u16_mask_bits); + +#if HWY_IS_LITTLE_ENDIAN + return detail::LoadMaskBits128(d, u16_mask_bits); +#else + // On big-endian targets, u16_mask_bits need to be byte swapped as bits + // contains the mask bits in little-endian byte order + + // GCC/Clang will optimize the load of u16_mask_bits and byte swap to a + // single lhbrx instruction on big-endian PPC targets when optimizations + // are enabled. 
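  // Worked example: bits[0] = 0x05 and bits[1] = 0x01 encode lanes 0, 2 and 8,
  // i.e. the desired 16-bit value 0x0105. A big-endian CopyBytes load of the
  // bytes {0x05, 0x01} yields u16_mask_bits = 0x0501, hence the swap below.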
+#if HWY_HAS_BUILTIN(__builtin_bswap16) + return detail::LoadMaskBits128(d, __builtin_bswap16(u16_mask_bits)); +#else + return detail::LoadMaskBits128( + d, static_cast((u16_mask_bits << 8) | (u16_mask_bits >> 8))); +#endif +#endif +} + +template +struct CompressIsPartition { + // generic_ops-inl does not guarantee IsPartition for 8-bit. + enum { value = (sizeof(T) != 1) }; +}; + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + return detail::LoadMaskBits128(d, mask_bits); +} + +// ------------------------------ StoreMaskBits + +namespace detail { + +// Returns the lowest N of the mask bits. +template +constexpr uint64_t OnlyActive(D d, uint64_t mask_bits) { + return (d.MaxBytes() == 16) ? mask_bits + : mask_bits & ((1ull << d.MaxLanes()) - 1); +} + +#if !HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN +// fallback for missing vec_extractm +template +HWY_INLINE uint64_t ExtractSignBits(Vec128 sign_bits, + __vector unsigned char bit_shuffle) { + // clang POWER8 and 9 targets appear to differ in their return type of + // vec_vbpermq: unsigned or signed, so cast to avoid a warning. + using VU64 = detail::Raw128::type; +#if HWY_S390X_HAVE_Z14 + const Vec128 extracted{ + reinterpret_cast(vec_bperm_u128(sign_bits.raw, bit_shuffle))}; +#else + const Vec128 extracted{ + reinterpret_cast(vec_vbpermq(sign_bits.raw, bit_shuffle))}; +#endif + return extracted.raw[HWY_IS_LITTLE_ENDIAN]; +} + +#endif // !HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN + +} // namespace detail + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const Repartition du8; + const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); + +#if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN + return detail::OnlyActive(d, + static_cast(vec_extractm(sign_bits.raw))); +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + const __vector unsigned char kBitShuffle = {120, 112, 104, 96, 88, 80, 72, 64, + 56, 48, 40, 32, 24, 16, 8, 0}; + return detail::OnlyActive(d, detail::ExtractSignBits(sign_bits, kBitShuffle)); +#endif // HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToUnsigned du; + + const Repartition du8; + const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); + +#if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN + return detail::OnlyActive( + d, static_cast(vec_extractm(BitCast(du, sign_bits).raw))); +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + (void)du; +#if HWY_IS_LITTLE_ENDIAN + const __vector unsigned char kBitShuffle = { + 112, 96, 80, 64, 48, 32, 16, 0, 128, 128, 128, 128, 128, 128, 128, 128}; +#else + const __vector unsigned char kBitShuffle = { + 128, 128, 128, 128, 128, 128, 128, 128, 112, 96, 80, 64, 48, 32, 16, 0}; +#endif + return detail::OnlyActive(d, detail::ExtractSignBits(sign_bits, kBitShuffle)); +#endif // HWY_PPC_HAVE_10 +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToUnsigned du; + + const Repartition du8; + const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); + +#if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN + return detail::OnlyActive( + d, static_cast(vec_extractm(BitCast(du, sign_bits).raw))); +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + (void)du; +#if HWY_IS_LITTLE_ENDIAN + const __vector unsigned char kBitShuffle = {96, 64, 32, 0, 128, 128, + 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128}; +#else + const __vector 
unsigned char kBitShuffle = {128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, + 96, 64, 32, 0}; +#endif + return detail::OnlyActive(d, detail::ExtractSignBits(sign_bits, kBitShuffle)); +#endif // HWY_PPC_HAVE_10 +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToUnsigned du; + + const Repartition du8; + const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); + +#if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN + return detail::OnlyActive( + d, static_cast(vec_extractm(BitCast(du, sign_bits).raw))); +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + (void)du; +#if HWY_IS_LITTLE_ENDIAN + const __vector unsigned char kBitShuffle = {64, 0, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128}; +#else + const __vector unsigned char kBitShuffle = {128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, + 128, 128, 64, 0}; +#endif + return detail::OnlyActive(d, detail::ExtractSignBits(sign_bits, kBitShuffle)); +#endif // HWY_PPC_HAVE_10 +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + // For vectors with 8 or fewer lanes, simply cast the result of BitsFromMask + // to an uint8_t and store the result in bits[0]. + bits[0] = static_cast(BitsFromMask(d, mask)); + return sizeof(uint8_t); +} + +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + const auto mask_bits = BitsFromMask(d, mask); + + // First convert mask_bits to a uint16_t as we only want to store + // the lower 16 bits of mask_bits as there are 16 lanes in mask. + + // Converting mask_bits to a uint16_t first will also ensure that + // the lower 16 bits of mask_bits are stored instead of the upper 16 bits + // of mask_bits on big-endian PPC targets. +#if HWY_IS_LITTLE_ENDIAN + const uint16_t u16_mask_bits = static_cast(mask_bits); +#else + // On big-endian targets, the bytes of mask_bits need to be swapped + // as StoreMaskBits expects the mask bits to be stored in little-endian + // byte order. 
+ + // GCC will also optimize the byte swap and CopyBytes operations below + // to a single sthbrx instruction when optimizations are enabled on + // big-endian PPC targets +#if HWY_HAS_BUILTIN(__builtin_bswap16) + const uint16_t u16_mask_bits = + __builtin_bswap16(static_cast(mask_bits)); +#else + const uint16_t u16_mask_bits = static_cast( + (mask_bits << 8) | (static_cast(mask_bits) >> 8)); +#endif +#endif + + CopyBytes(&u16_mask_bits, bits); + return sizeof(uint16_t); +} + +// ------------------------------ Mask testing + +template +HWY_API bool AllFalse(D d, MFromD mask) { + const RebindToUnsigned du; + return static_cast( + vec_all_eq(VecFromMask(du, RebindMask(du, mask)).raw, Zero(du).raw)); +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + const RebindToUnsigned du; + using TU = TFromD; + return static_cast(vec_all_eq(VecFromMask(du, RebindMask(du, mask)).raw, + Set(du, hwy::LimitsMax()).raw)); +} + +template +HWY_API bool AllFalse(D d, MFromD mask) { + const Full128> d_full; + constexpr size_t kN = MaxLanes(d); + return AllFalse(d_full, + And(MFromD{mask.raw}, FirstN(d_full, kN))); +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + const Full128> d_full; + constexpr size_t kN = MaxLanes(d); + return AllTrue( + d_full, Or(MFromD{mask.raw}, Not(FirstN(d_full, kN)))); +} + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + return PopCount(BitsFromMask(d, mask)); +} + +#if HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) +namespace detail { + +template +static HWY_INLINE size_t VsxCntlzLsbb(V v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200 && \ + HWY_IS_LITTLE_ENDIAN + // Use inline assembly to work around bug in GCC 11 and earlier on + // little-endian PPC9 + int idx; + __asm__("vctzlsbb %0,%1" : "=r"(idx) : "v"(v.raw)); + return static_cast(idx); +#else + return static_cast(vec_cntlz_lsbb(v.raw)); +#endif +} + +template +static HWY_INLINE size_t VsxCnttzLsbb(V v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200 && \ + HWY_IS_LITTLE_ENDIAN + // Use inline assembly to work around bug in GCC 11 and earlier on + // little-endian PPC9 + int idx; + __asm__("vclzlsbb %0,%1" : "=r"(idx) : "v"(v.raw)); + return static_cast(idx); +#else + return static_cast(vec_cnttz_lsbb(v.raw)); +#endif +} + +} // namespace detail +#endif + +template > +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { +// For little-endian PPC10, BitsFromMask is already efficient. +#if HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + if (detail::IsFull(d)) { + const Repartition d8; + const auto bytes = BitCast(d8, VecFromMask(d, mask)); + return detail::VsxCntlzLsbb(bytes) / sizeof(T); + } +#endif // HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + return Num0BitsBelowLS1Bit_Nonzero64(BitsFromMask(d, mask)); +} + +template > +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { +// For little-endian PPC10, BitsFromMask is already efficient. +#if HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + constexpr size_t kN = 16 / sizeof(T); + if (detail::IsFull(d)) { + const Repartition d8; + const auto bytes = BitCast(d8, VecFromMask(d, mask)); + const size_t idx = detail::VsxCntlzLsbb(bytes) / sizeof(T); + return idx == kN ? -1 : static_cast(idx); + } +#endif // HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + const uint64_t mask_bits = BitsFromMask(d, mask); + return mask_bits ? 
intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; +} + +template > +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { +// For little-endian PPC10, BitsFromMask is already efficient. +#if HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + if (detail::IsFull(d)) { + const Repartition d8; + const auto bytes = BitCast(d8, VecFromMask(d, mask)); + const size_t idx = detail::VsxCnttzLsbb(bytes) / sizeof(T); + return 16 / sizeof(T) - 1 - idx; + } +#endif // HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + return 63 - Num0BitsAboveMS1Bit_Nonzero64(BitsFromMask(d, mask)); +} + +template > +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { +// For little-endian PPC10, BitsFromMask is already efficient. +#if HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + constexpr size_t kN = 16 / sizeof(T); + if (detail::IsFull(d)) { + const Repartition d8; + const auto bytes = BitCast(d8, VecFromMask(d, mask)); + const size_t idx = detail::VsxCnttzLsbb(bytes) / sizeof(T); + return idx == kN ? -1 : static_cast(kN - 1 - idx); + } +#endif // HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + const uint64_t mask_bits = BitsFromMask(d, mask); + return mask_bits ? intptr_t(63 - Num0BitsAboveMS1Bit_Nonzero64(mask_bits)) + : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +#if HWY_PPC_HAVE_10 +template +HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { + constexpr unsigned kGenPcvmMode = + (kIsCompress ? 1u : 0u) | (HWY_IS_LITTLE_ENDIAN ? 2u : 0u); + + // Inline assembly is used instead of the vec_genpcvm intrinsic to work around + // compiler bugs on little-endian PPC10 + typename detail::Raw128>::type idx; + __asm__("xxgenpcvbm %x0, %1, %2" + : "=wa"(idx) + : "v"(mask.raw), "i"(kGenPcvmMode)); + return VFromD{idx}; +} +template +HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { + constexpr unsigned kGenPcvmMode = + (kIsCompress ? 1u : 0u) | (HWY_IS_LITTLE_ENDIAN ? 2u : 0u); + + // Inline assembly is used instead of the vec_genpcvm intrinsic to work around + // compiler bugs on little-endian PPC10 + typename detail::Raw128>::type idx; + __asm__("xxgenpcvhm %x0, %1, %2" + : "=wa"(idx) + : "v"(mask.raw), "i"(kGenPcvmMode)); + return VFromD{idx}; +} +template +HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { + constexpr unsigned kGenPcvmMode = + (kIsCompress ? 1u : 0u) | (HWY_IS_LITTLE_ENDIAN ? 2u : 0u); + + // Inline assembly is used instead of the vec_genpcvm intrinsic to work around + // compiler bugs on little-endian PPC10 + typename detail::Raw128>::type idx; + __asm__("xxgenpcvwm %x0, %1, %2" + : "=wa"(idx) + : "v"(mask.raw), "i"(kGenPcvmMode)); + return VFromD{idx}; +} +#endif + +// Also works for N < 8 because the first 16 4-tuples only reference bytes 0-6. +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Twice d8t; + const RebindToUnsigned du; + + // To reduce cache footprint, store lane indices and convert to byte indices + // (2*lane + 0..1), with the doubling baked into the table. It's not clear + // that the additional cost of unpacking nibbles is worthwhile. 
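  // Each table row holds the doubled lane index 2*lane per output position;
  // the ZipLower + kPairIndexIncrement step after the load duplicates every
  // entry and bumps the second copy, so an entry 2*k expands to the byte pair
  // {2*k, 2*k + 1} in memory order on both endiannesses. For example,
  // mask_bits = 0x05 (lanes 0 and 2 active) selects the row
  // {0, 4, 2, 6, 8, 10, 12, 14}, which expands to byte indices
  // {0, 1, 4, 5, 2, 3, 6, 7, ...}.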
+ alignas(16) static constexpr uint8_t table[2048] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 
12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 
12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const VFromD byte_idx{Load(d8, table + mask_bits * 8).raw}; + const VFromD pairs = ZipLower(byte_idx, byte_idx); + constexpr uint16_t kPairIndexIncrement = + HWY_IS_LITTLE_ENDIAN ? 0x0100 : 0x0001; + + return BitCast(d, pairs + Set(du, kPairIndexIncrement)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Twice d8t; + const RebindToUnsigned du; + + // To reduce cache footprint, store lane indices and convert to byte indices + // (2*lane + 0..1), with the doubling baked into the table. It's not clear + // that the additional cost of unpacking nibbles is worthwhile. 
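// Worked example (illustrative, not part of the upstream comment): each 8-byte
// row already stores doubled lane indices, so a row byte of 6 denotes u16 lane
// 3. ZipLower duplicates it into the byte pair {6, 6}; adding
// kPairIndexIncrement (0x0100 on little-endian, 0x0001 on big-endian) turns
// that pair into the byte indices {6, 7}, i.e. both bytes of lane 3, which is
// exactly the form TableLookupBytes expects.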
+ alignas(16) static constexpr uint8_t table[2048] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 
14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 
2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const VFromD byte_idx{Load(d8, table + mask_bits * 8).raw}; + const VFromD pairs = ZipLower(byte_idx, byte_idx); + constexpr uint16_t kPairIndexIncrement = + HWY_IS_LITTLE_ENDIAN ? 0x0100 : 0x0001; + + return BitCast(d, pairs + Set(du, kPairIndexIncrement)); +} + +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[256] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
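// Worked example (illustration): for mask_bits = 1 (only lane 0 true), the
// second row below is {4,5,6,7, 8,9,10,11, 12,13,14,15, 0,1,2,3}, so
// CompressNot keeps lanes 1..3 first and moves lane 0 to the back.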
+ alignas(16) static constexpr uint8_t u8_indices[256] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_API Vec128 CompressBits(Vec128 v, uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +template +HWY_API Vec128 CompressNotBits(Vec128 v, uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromNotBits128(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
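// Explanatory note (illustration): maskL = DupEven(m) broadcasts mask[0] and
// maskH = DupOdd(m) broadcasts mask[1], so AndNot(maskL, maskH) is all-ones
// exactly when mask[1] && !mask[0]; only in that case does IfVecThenElse
// select the swapped Shuffle01(v).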
+ const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +#if HWY_PPC_HAVE_10 +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +// General case, 1 byte +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return TableLookupBytes( + v, detail::CompressOrExpandIndicesFromMask(d, mask)); +} +#endif + +// General case, 2 or 4 bytes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return detail::CompressBits(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressNot + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +#if HWY_PPC_HAVE_10 +// General case, 1 byte +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + return TableLookupBytes( + v, detail::CompressOrExpandIndicesFromMask(d, Not(mask))); +} +#endif + +// General case, 2 or 4 bytes +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::CompressBits(v, BitsFromMask(d, Not(mask))); + } + return detail::CompressNotBits(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +#if HWY_PPC_HAVE_10 +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + const DFromV d; + return Compress(v, LoadMaskBits(d, bits)); +} +#endif + +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + // As there are at most 8 lanes in v if sizeof(TFromD) > 1, simply + // convert bits[0] to a uint64_t + uint64_t mask_bits = bits[0]; + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::CompressBits(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +#if HWY_PPC_HAVE_10 +template +HWY_API size_t CompressStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, m); + const auto indices = detail::CompressOrExpandIndicesFromMask(d, m); + const auto compressed = TableLookupBytes(v, indices); + StoreU(compressed, d, unaligned); + return count; +} +#endif + +template +HWY_API size_t CompressStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, m); + HWY_DASSERT(mask_bits < (1ull << MaxLanes(d))); + const size_t count = PopCount(mask_bits); + + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + return count; +} + +#if HWY_PPC_HAVE_10 +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* 
HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, m); + const auto indices = detail::CompressOrExpandIndicesFromMask(d, m); + const auto compressed = TableLookupBytes(v, indices); + StoreN(compressed, d, unaligned, count); + return count; +} +#endif + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, m); + HWY_DASSERT(mask_bits < (1ull << MaxLanes(d))); + const size_t count = PopCount(mask_bits); + + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +#if (HWY_PPC_HAVE_9 && HWY_ARCH_PPC_64) || HWY_S390X_HAVE_Z14 + StoreN(compressed, d, unaligned, count); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + return count; +} + +#if HWY_PPC_HAVE_10 +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} +#endif + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + // As there are at most 8 lanes in v if sizeof(TFromD) > 1, simply + // convert bits[0] to a uint64_t + uint64_t mask_bits = bits[0]; + constexpr size_t kN = MaxLanes(d); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + const size_t count = PopCount(mask_bits); + + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + + return count; +} + +// ------------------------------ Expand +#if HWY_PPC_HAVE_10 +#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + const auto idx = detail::CompressOrExpandIndicesFromMask(d, mask); + return IfThenElseZero(mask, TableLookupBytes(v, idx)); +} + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + // Same as Compress, just zero out the mask=false lanes. + return IfThenElseZero(mask, Compress(v, mask)); +} + +// For single-element vectors, this is at least as fast as native. +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + return IfThenElseZero(mask, v); +} + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} +#endif // HWY_PPC_HAVE_10 + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. 
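For orientation, the table-driven CompressBits/CompressStore paths above reduce to a simple scalar semantic: lanes whose mask bit is set are packed to the front in their original order, and the number of set bits is the count CompressStore returns. A minimal sketch of that semantic for 16-bit lanes (the helper name is hypothetical, not a Highway API):

#include <cstddef>
#include <cstdint>

// Scalar model of Compress/CompressStore for 16-bit lanes.
size_t CompressBitsScalar(const uint16_t* in, size_t num_lanes,
                          uint64_t mask_bits, uint16_t* out) {
  size_t count = 0;
  for (size_t i = 0; i < num_lanes; ++i) {
    if ((mask_bits >> i) & 1) {
      out[count++] = in[i];  // selected lanes keep their relative order
    }
  }
  return count;  // equals PopCount(mask_bits), as returned by CompressStore
}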
+ +// ------------------------------ Additional mask logical operations +namespace detail { + +#if HWY_IS_LITTLE_ENDIAN +template +HWY_INLINE V Per64BitBlkRevLanesOnBe(V v) { + return v; +} +template +HWY_INLINE V Per128BitBlkRevLanesOnBe(V v) { + return v; +} +#else +template +HWY_INLINE V Per64BitBlkRevLanesOnBe(V v) { + const DFromV d; + return Reverse8(d, v); +} +template +HWY_INLINE V Per64BitBlkRevLanesOnBe(V v) { + const DFromV d; + return Reverse4(d, v); +} +template +HWY_INLINE V Per64BitBlkRevLanesOnBe(V v) { + const DFromV d; + return Reverse2(d, v); +} +template +HWY_INLINE V Per64BitBlkRevLanesOnBe(V v) { + return v; +} +template +HWY_INLINE V Per128BitBlkRevLanesOnBe(V v) { + const DFromV d; + return Reverse(d, v); +} +#endif + +template +HWY_INLINE V I128Subtract(V a, V b) { +#if HWY_S390X_HAVE_Z14 +#if HWY_COMPILER_CLANG + // Workaround for bug in vec_sub_u128 in Clang vecintrin.h + typedef __uint128_t VU128 __attribute__((__vector_size__(16))); + const V diff_i128{reinterpret_cast>::type>( + reinterpret_cast(a.raw) - reinterpret_cast(b.raw))}; +#else // !HWY_COMPILER_CLANG + const V diff_i128{reinterpret_cast>::type>( + vec_sub_u128(reinterpret_cast<__vector unsigned char>(a.raw), + reinterpret_cast<__vector unsigned char>(b.raw)))}; +#endif // HWY_COMPILER_CLANG +#elif defined(__SIZEOF_INT128__) + using VU128 = __vector unsigned __int128; + const V diff_i128{reinterpret_cast>::type>( + vec_sub(reinterpret_cast(a.raw), reinterpret_cast(b.raw)))}; +#else + const DFromV d; + const Repartition du64; + + const auto u64_a = BitCast(du64, a); + const auto u64_b = BitCast(du64, b); + + const auto diff_u64 = u64_a - u64_b; + const auto borrow_u64 = VecFromMask(du64, u64_a < u64_b); + +#if HWY_IS_LITTLE_ENDIAN + const auto borrow_u64_shifted = ShiftLeftBytes<8>(du64, borrow_u64); +#else + const auto borrow_u64_shifted = ShiftRightBytes<8>(du64, borrow_u64); +#endif + + const auto diff_i128 = BitCast(d, diff_u64 + borrow_u64_shifted); +#endif + + return diff_i128; +} + +} // namespace detail + +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const FixedTag d; + const auto vmask = VecFromMask(d, mask); + return MaskFromVec(Or(vmask, InterleaveLower(vmask, vmask))); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Simd d; + const Full64 d_full64; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_le64 = + BitCast(Full64(), + detail::Per64BitBlkRevLanesOnBe(ResizeBitCast(d_full64, vmask))); + const auto neg_vmask_le64 = Neg(vmask_le64); + const auto neg_vmask = ResizeBitCast( + d, detail::Per64BitBlkRevLanesOnBe(BitCast(d_full64, neg_vmask_le64))); + + return MaskFromVec(Or(vmask, neg_vmask)); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Full128 d; + auto vmask = VecFromMask(d, mask); + + const auto vmask_le128 = detail::Per128BitBlkRevLanesOnBe(vmask); + const auto neg_vmask_le128 = detail::I128Subtract(Zero(d), vmask_le128); + const auto neg_vmask = detail::Per128BitBlkRevLanesOnBe(neg_vmask_le128); + + return MaskFromVec(BitCast(d, Or(vmask, neg_vmask))); +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const FixedTag d; + const RebindToSigned di; + + const auto vmask = BitCast(di, VecFromMask(d, mask)); + const auto zero = Zero(di); + const 
auto vmask2 = VecFromMask(di, InterleaveLower(zero, vmask) == zero); + return MaskFromVec(BitCast(d, And(vmask, vmask2))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Simd d; + const Full64 d_full64; + const RebindToSigned di; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_le64 = + BitCast(Full64(), + detail::Per64BitBlkRevLanesOnBe(ResizeBitCast(d_full64, vmask))); + const auto neg_vmask_le64 = Neg(vmask_le64); + const auto neg_vmask = ResizeBitCast( + d, detail::Per64BitBlkRevLanesOnBe(BitCast(d_full64, neg_vmask_le64))); + + const auto first_vmask = BitCast(di, And(vmask, neg_vmask)); + return MaskFromVec(BitCast(d, Or(first_vmask, Neg(first_vmask)))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Full128 d; + const RebindToSigned di; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_le128 = detail::Per128BitBlkRevLanesOnBe(vmask); + const auto neg_vmask_le128 = detail::I128Subtract(Zero(d), vmask_le128); + const auto neg_vmask = detail::Per128BitBlkRevLanesOnBe(neg_vmask_le128); + + return MaskFromVec(BitCast(d, Neg(BitCast(di, And(vmask, neg_vmask))))); +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 /*mask*/) { + const FixedTag d; + const RebindToSigned di; + using TI = MakeSigned; + + return RebindMask(d, MaskFromVec(Set(di, TI(-1)))); +} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + const Simd d; + return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask)))); +} + +// ------------------------------ SumsOf2 and SumsOf4 +namespace detail { + +#if !HWY_S390X_HAVE_Z14 +// Casts nominally int32_t result to D. +template +HWY_INLINE VFromD AltivecVsum4sbs(D d, __vector signed char a, + __vector signed int b) { + const Repartition di32; +#ifdef __OPTIMIZE__ + if (IsConstantRawAltivecVect(a) && IsConstantRawAltivecVect(b)) { + const int64_t sum0 = + static_cast(a[0]) + static_cast(a[1]) + + static_cast(a[2]) + static_cast(a[3]) + + static_cast(b[0]); + const int64_t sum1 = + static_cast(a[4]) + static_cast(a[5]) + + static_cast(a[6]) + static_cast(a[7]) + + static_cast(b[1]); + const int64_t sum2 = + static_cast(a[8]) + static_cast(a[9]) + + static_cast(a[10]) + static_cast(a[11]) + + static_cast(b[2]); + const int64_t sum3 = + static_cast(a[12]) + static_cast(a[13]) + + static_cast(a[14]) + static_cast(a[15]) + + static_cast(b[3]); + const int32_t sign0 = static_cast(sum0 >> 63); + const int32_t sign1 = static_cast(sum1 >> 63); + const int32_t sign2 = static_cast(sum2 >> 63); + const int32_t sign3 = static_cast(sum3 >> 63); + using Raw = typename detail::Raw128::type; + return BitCast( + d, + VFromD{Raw{ + (sign0 == (sum0 >> 31)) ? static_cast(sum0) + : static_cast(sign0 ^ 0x7FFFFFFF), + (sign1 == (sum1 >> 31)) ? static_cast(sum1) + : static_cast(sign1 ^ 0x7FFFFFFF), + (sign2 == (sum2 >> 31)) ? static_cast(sum2) + : static_cast(sign2 ^ 0x7FFFFFFF), + (sign3 == (sum3 >> 31)) + ? static_cast(sum3) + : static_cast(sign3 ^ 0x7FFFFFFF)}}); + } else // NOLINT +#endif + { + return BitCast(d, VFromD{vec_vsum4sbs(a, b)}); + } +} + +// Casts nominally uint32_t result to D. 
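// (Explanatory note: vec_vsum4ubs adds each group of four u8 lanes of a to the
// corresponding u32 lane of b with unsigned saturation at 0xFFFFFFFF; the
// __OPTIMIZE__ branch below merely constant-folds that computation.)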
+template +HWY_INLINE VFromD AltivecVsum4ubs(D d, __vector unsigned char a, + __vector unsigned int b) { + const Repartition du32; +#ifdef __OPTIMIZE__ + if (IsConstantRawAltivecVect(a) && IsConstantRawAltivecVect(b)) { + const uint64_t sum0 = + static_cast(a[0]) + static_cast(a[1]) + + static_cast(a[2]) + static_cast(a[3]) + + static_cast(b[0]); + const uint64_t sum1 = + static_cast(a[4]) + static_cast(a[5]) + + static_cast(a[6]) + static_cast(a[7]) + + static_cast(b[1]); + const uint64_t sum2 = + static_cast(a[8]) + static_cast(a[9]) + + static_cast(a[10]) + static_cast(a[11]) + + static_cast(b[2]); + const uint64_t sum3 = + static_cast(a[12]) + static_cast(a[13]) + + static_cast(a[14]) + static_cast(a[15]) + + static_cast(b[3]); + return BitCast( + d, + VFromD{(__vector unsigned int){ + static_cast(sum0 <= 0xFFFFFFFFu ? sum0 : 0xFFFFFFFFu), + static_cast(sum1 <= 0xFFFFFFFFu ? sum1 : 0xFFFFFFFFu), + static_cast(sum2 <= 0xFFFFFFFFu ? sum2 : 0xFFFFFFFFu), + static_cast(sum3 <= 0xFFFFFFFFu ? sum3 + : 0xFFFFFFFFu)}}); + } else // NOLINT +#endif + { + return BitCast(d, VFromD{vec_vsum4ubs(a, b)}); + } +} + +// Casts nominally int32_t result to D. +template +HWY_INLINE VFromD AltivecVsum2sws(D d, __vector signed int a, + __vector signed int b) { + const Repartition di32; +#ifdef __OPTIMIZE__ + const Repartition du64; + constexpr int kDestLaneOffset = HWY_IS_BIG_ENDIAN; + if (IsConstantRawAltivecVect(a) && __builtin_constant_p(b[kDestLaneOffset]) && + __builtin_constant_p(b[kDestLaneOffset + 2])) { + const int64_t sum0 = static_cast(a[0]) + + static_cast(a[1]) + + static_cast(b[kDestLaneOffset]); + const int64_t sum1 = static_cast(a[2]) + + static_cast(a[3]) + + static_cast(b[kDestLaneOffset + 2]); + const int32_t sign0 = static_cast(sum0 >> 63); + const int32_t sign1 = static_cast(sum1 >> 63); + return BitCast(d, VFromD{(__vector unsigned long long){ + (sign0 == (sum0 >> 31)) + ? static_cast(sum0) + : static_cast(sign0 ^ 0x7FFFFFFF), + (sign1 == (sum1 >> 31)) + ? static_cast(sum1) + : static_cast(sign1 ^ 0x7FFFFFFF)}}); + } else // NOLINT +#endif + { + __vector signed int sum; + + // Inline assembly is used for vsum2sws to avoid unnecessary shuffling + // on little-endian PowerPC targets as the result of the vsum2sws + // instruction will already be in the correct lanes on little-endian + // PowerPC targets. + __asm__("vsum2sws %0,%1,%2" : "=v"(sum) : "v"(a), "v"(b)); + + return BitCast(d, VFromD{sum}); + } +} + +// Casts nominally int32_t result to D. +template +HWY_INLINE VFromD AltivecVsum4shs(D d, __vector signed short a, + __vector signed int b) { + const Repartition di32; +#ifdef __OPTIMIZE__ + if (IsConstantRawAltivecVect(a) && IsConstantRawAltivecVect(b)) { + const int64_t sum0 = static_cast(a[0]) + + static_cast(a[1]) + + static_cast(b[0]); + const int64_t sum1 = static_cast(a[2]) + + static_cast(a[3]) + + static_cast(b[1]); + const int64_t sum2 = static_cast(a[4]) + + static_cast(a[5]) + + static_cast(b[2]); + const int64_t sum3 = static_cast(a[6]) + + static_cast(a[7]) + + static_cast(b[3]); + const int32_t sign0 = static_cast(sum0 >> 63); + const int32_t sign1 = static_cast(sum1 >> 63); + const int32_t sign2 = static_cast(sum2 >> 63); + const int32_t sign3 = static_cast(sum3 >> 63); + using Raw = typename detail::Raw128::type; + return BitCast( + d, + VFromD{Raw{ + (sign0 == (sum0 >> 31)) ? static_cast(sum0) + : static_cast(sign0 ^ 0x7FFFFFFF), + (sign1 == (sum1 >> 31)) ? static_cast(sum1) + : static_cast(sign1 ^ 0x7FFFFFFF), + (sign2 == (sum2 >> 31)) ? 
static_cast(sum2) + : static_cast(sign2 ^ 0x7FFFFFFF), + (sign3 == (sum3 >> 31)) + ? static_cast(sum3) + : static_cast(sign3 ^ 0x7FFFFFFF)}}); + } else // NOLINT +#endif + { + return BitCast(d, VFromD{vec_vsum4shs(a, b)}); + } +} + +// Casts nominally int32_t result to D. +template +HWY_INLINE VFromD AltivecVsumsws(D d, __vector signed int a, + __vector signed int b) { + const Repartition di32; +#ifdef __OPTIMIZE__ + constexpr int kDestLaneOffset = HWY_IS_LITTLE_ENDIAN ? 0 : 3; + if (IsConstantRawAltivecVect(a) && __builtin_constant_p(b[kDestLaneOffset])) { + const int64_t sum = + static_cast(a[0]) + static_cast(a[1]) + + static_cast(a[2]) + static_cast(a[3]) + + static_cast(b[kDestLaneOffset]); + const int32_t sign = static_cast(sum >> 63); +#if HWY_IS_LITTLE_ENDIAN + return BitCast( + d, VFromD{(__vector signed int){ + (sign == (sum >> 31)) ? static_cast(sum) + : static_cast(sign ^ 0x7FFFFFFF), + 0, 0, 0}}); +#else + return BitCast(d, VFromD{(__vector signed int){ + 0, 0, 0, + (sign == (sum >> 31)) + ? static_cast(sum) + : static_cast(sign ^ 0x7FFFFFFF)}}); +#endif + } else // NOLINT +#endif + { + __vector signed int sum; + + // Inline assembly is used for vsumsws to avoid unnecessary shuffling + // on little-endian PowerPC targets as the result of the vsumsws + // instruction will already be in the correct lanes on little-endian + // PowerPC targets. + __asm__("vsumsws %0,%1,%2" : "=v"(sum) : "v"(a), "v"(b)); + + return BitCast(d, VFromD{sum}); + } +} + +template +HWY_INLINE Vec128 AltivecU16SumsOf2(Vec128 v) { + const RebindToSigned> di16; + const RepartitionToWide di32; + return AltivecVsum4shs(di32, Xor(BitCast(di16, v), Set(di16, -32768)).raw, + Set(di32, 65536).raw); +} +#endif // !HWY_S390X_HAVE_Z14 + +// U16->U32 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_S390X_HAVE_Z14 + return VFromD{vec_sum4(v.raw, Zero(d).raw)}; +#else + return BitCast(dw, AltivecU16SumsOf2(v)); +#endif +} + +// I16->I32 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du; + return BitCast(dw, SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<2>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw, int32_t{-65536}); +#else + return AltivecVsum4shs(dw, v.raw, Zero(dw).raw); +#endif +} + +#if HWY_S390X_HAVE_Z14 +// U32->U64 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + return VFromD{vec_sum2(v.raw, Zero(d).raw)}; +} + +// I32->I64 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + const RebindToUnsigned du; + + return BitCast(dw, SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<4>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw, int64_t{-4294967296LL}); +} +#endif + +// U8->U32 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWideX2 dw2; + +#if HWY_S390X_HAVE_Z14 + return VFromD{vec_sum4(v.raw, Zero(d).raw)}; +#else + return AltivecVsum4ubs(dw2, v.raw, Zero(dw2).raw); +#endif +} + +// I8->I32 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::SignedTag /*type_tag*/, 
hwy::SizeTag<1> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWideX2 dw2; + +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du; + return BitCast(dw2, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw2, int32_t{-512}); +#else + return AltivecVsum4sbs(dw2, v.raw, Zero(dw2).raw); +#endif +} + +// U16->U64 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + const RepartitionToWide dw2; + +#if HWY_S390X_HAVE_Z14 + return VFromD{vec_sum2(v.raw, Zero(d).raw)}; +#else + const RebindToSigned dw_i; + return AltivecVsum2sws(dw2, BitCast(dw_i, SumsOf2(v)).raw, Zero(dw_i).raw); +#endif +} + +// I16->I64 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + const RepartitionToWide dw2; + +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du; + return BitCast(dw2, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<2>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw2, int64_t{-131072}); +#else // VSX + const auto sums_of_4_in_lo32 = + AltivecVsum2sws(dw, SumsOf2(v).raw, Zero(dw).raw); + +#if HWY_IS_LITTLE_ENDIAN + return PromoteEvenTo(dw2, sums_of_4_in_lo32); +#else + return PromoteOddTo(dw2, sums_of_4_in_lo32); +#endif // HWY_IS_LITTLE_ENDIAN +#endif // HWY_S390X_HAVE_Z14 +} + +} // namespace detail + +// ------------------------------ SumOfLanes + +// We define SumOfLanes for 8/16-bit types (and I32/U32/I64/U64 on Z14/Z15/Z16); +// enable generic for the rest. +#undef HWY_IF_SUM_OF_LANES_D +#if HWY_S390X_HAVE_Z14 +#define HWY_IF_SUM_OF_LANES_D(D) HWY_IF_LANES_GT_D(D, 1), HWY_IF_FLOAT3264_D(D) +#else +#define HWY_IF_SUM_OF_LANES_D(D) \ + HWY_IF_LANES_GT_D(D, 1), HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8)) +#endif + +#if HWY_S390X_HAVE_Z14 +namespace detail { + +#if HWY_COMPILER_CLANG && HWY_HAS_BUILTIN(__builtin_s390_vsumqf) && \ + HWY_HAS_BUILTIN(__builtin_s390_vsumqg) +// Workaround for bug in vec_sum_u128 in Clang vecintrin.h +template +HWY_INLINE Vec128 SumOfU32OrU64LanesAsU128(Vec128 v) { + typedef __uint128_t VU128 __attribute__((__vector_size__(16))); + const DFromV d; + const RebindToUnsigned du; + const VU128 sum = {__builtin_s390_vsumqf(BitCast(du, v).raw, Zero(du).raw)}; + return Vec128{reinterpret_cast::type>(sum)}; +} +template +HWY_INLINE Vec128 SumOfU32OrU64LanesAsU128(Vec128 v) { + typedef __uint128_t VU128 __attribute__((__vector_size__(16))); + const DFromV d; + const RebindToUnsigned du; + const VU128 sum = {__builtin_s390_vsumqg(BitCast(du, v).raw, Zero(du).raw)}; + return Vec128{reinterpret_cast::type>(sum)}; +} +#else +template +HWY_INLINE Vec128 SumOfU32OrU64LanesAsU128(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, Vec128{vec_sum_u128(BitCast(du, v).raw, Zero(du).raw)}); +} +#endif + +} // namespace detail + +template +HWY_API VFromD SumOfLanes(D /*d64*/, VFromD v) { + return Broadcast<1>(detail::SumOfU32OrU64LanesAsU128(v)); +} +#endif + +template +HWY_API Vec32 SumOfLanes(D du16, Vec32 v) { + constexpr int kSumLaneIdx = HWY_IS_BIG_ENDIAN; + return Broadcast( + BitCast(du16, detail::SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<2>(), v))); +} + +template +HWY_API Vec64 SumOfLanes(D du16, Vec64 v) { + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 
0 : 3; + return Broadcast( + BitCast(du16, detail::SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<2>(), v))); +} + +template +HWY_API Vec128 SumOfLanes(D du16, Vec128 v) { + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 7; +#if HWY_S390X_HAVE_Z14 + return Broadcast( + BitCast(du16, detail::SumOfU32OrU64LanesAsU128(detail::SumsOf4( + hwy::UnsignedTag(), hwy::SizeTag<2>(), v)))); +#else // VSX + const auto zero = Zero(Full128()); + return Broadcast( + detail::AltivecVsumsws(du16, detail::AltivecU16SumsOf2(v).raw, zero.raw)); +#endif +} + +template +HWY_API Vec32 SumOfLanes(D di16, Vec32 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du16; + return BitCast(di16, SumOfLanes(du16, BitCast(du16, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_BIG_ENDIAN; + return Broadcast( + BitCast(di16, detail::SumsOf2(hwy::SignedTag(), hwy::SizeTag<2>(), v))); +#endif +} + +template +HWY_API Vec64 SumOfLanes(D di16, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du16; + return BitCast(di16, SumOfLanes(du16, BitCast(du16, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; + return Broadcast( + BitCast(di16, detail::SumsOf4(hwy::SignedTag(), hwy::SizeTag<2>(), v))); +#endif +} + +template +HWY_API Vec128 SumOfLanes(D di16, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du16; + return BitCast(di16, SumOfLanes(du16, BitCast(du16, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 7; + const Full128 di32; + const auto zero = Zero(di32); + return Broadcast(detail::AltivecVsumsws( + di16, detail::AltivecVsum4shs(di32, v.raw, zero.raw).raw, zero.raw)); +#endif +} + +template +HWY_API Vec32 SumOfLanes(D du8, Vec32 v) { + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; + return Broadcast( + BitCast(du8, detail::SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v))); +} + +template +HWY_API Vec16 SumOfLanes(D du8, Vec16 v) { + const Twice dt_u8; + return LowerHalf(du8, SumOfLanes(dt_u8, Combine(dt_u8, Zero(du8), v))); +} + +template +HWY_API Vec64 SumOfLanes(D du8, Vec64 v) { + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 7; + return Broadcast(BitCast(du8, SumsOf8(v))); +} + +template +HWY_API Vec128 SumOfLanes(D du8, Vec128 v) { + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 15; + +#if HWY_S390X_HAVE_Z14 + return Broadcast( + BitCast(du8, detail::SumOfU32OrU64LanesAsU128(detail::SumsOf4( + hwy::UnsignedTag(), hwy::SizeTag<1>(), v)))); +#else + const Full128 du32; + const RebindToSigned di32; + const Vec128 zero = Zero(du32); + return Broadcast(detail::AltivecVsumsws( + du8, detail::AltivecVsum4ubs(di32, v.raw, zero.raw).raw, + BitCast(di32, zero).raw)); +#endif +} + +template +HWY_API Vec32 SumOfLanes(D di8, Vec32 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du8; + return BitCast(di8, SumOfLanes(du8, BitCast(du8, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; + return Broadcast( + BitCast(di8, detail::SumsOf4(hwy::SignedTag(), hwy::SizeTag<1>(), v))); +#endif +} + +template +HWY_API Vec16 SumOfLanes(D di8, Vec16 v) { + const Twice dt_i8; + return LowerHalf(di8, SumOfLanes(dt_i8, Combine(dt_i8, Zero(di8), v))); +} + +template +HWY_API Vec64 SumOfLanes(D di8, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du8; + return BitCast(di8, SumOfLanes(du8, BitCast(du8, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 
0 : 7; + return Broadcast(BitCast(di8, SumsOf8(v))); +#endif +} + +template +HWY_API Vec128 SumOfLanes(D di8, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du8; + return BitCast(di8, SumOfLanes(du8, BitCast(du8, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 15; + const Full128 di32; + const Vec128 zero = Zero(di32); + return Broadcast(detail::AltivecVsumsws( + di8, detail::AltivecVsum4sbs(di32, v.raw, zero.raw).raw, zero.raw)); +#endif +} + +#if HWY_S390X_HAVE_Z14 +template +HWY_API VFromD SumOfLanes(D d32, VFromD v) { + const RebindToUnsigned du32; + return Broadcast<1>( + BitCast(d32, detail::SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<4>(), + BitCast(du32, v)))); +} + +template +HWY_API VFromD SumOfLanes(D /*d32*/, VFromD v) { + return Broadcast<3>(detail::SumOfU32OrU64LanesAsU128(v)); +} +#endif + +// generic_ops defines MinOfLanes and MaxOfLanes. + +// ------------------------------ ReduceSum for N=4 I8/U8 + +// GetLane(SumsOf4(v)) is more efficient on PPC/Z14 than the default N=4 +// I8/U8 ReduceSum implementation in generic_ops-inl.h +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +template +HWY_API TFromD ReduceSum(D /*d*/, VFromD v) { + return static_cast>(GetLane(SumsOf4(v))); +} + +// ------------------------------ BitShuffle + +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE +#endif + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Repartition du8; + + const Full128> d_full_u64; + const Full128> d_full_u8; + + using RawVU64 = __vector unsigned long long; + +#if HWY_PPC_HAVE_9 + +#if HWY_IS_LITTLE_ENDIAN + (void)d_full_u64; + auto bit_idx = ResizeBitCast(d_full_u8, idx); +#else + auto bit_idx = + BitCast(d_full_u8, ReverseLaneBytes(ResizeBitCast(d_full_u64, idx))); +#endif + + bit_idx = Xor(bit_idx, Set(d_full_u8, uint8_t{0x3F})); + + return BitCast(d64, VFromD{reinterpret_cast( + vec_bperm(BitCast(du64, v).raw, bit_idx.raw))}); +#else // !HWY_PPC_HAVE_9 + +#if HWY_IS_LITTLE_ENDIAN + const auto bit_idx_xor_mask = BitCast( + d_full_u8, Dup128VecFromValues(d_full_u64, uint64_t{0x7F7F7F7F7F7F7F7Fu}, + uint64_t{0x3F3F3F3F3F3F3F3Fu})); + const auto bit_idx = Xor(ResizeBitCast(d_full_u8, idx), bit_idx_xor_mask); + constexpr int kBitShufResultByteShrAmt = 8; +#else + const auto bit_idx_xor_mask = BitCast( + d_full_u8, Dup128VecFromValues(d_full_u64, uint64_t{0x3F3F3F3F3F3F3F3Fu}, + uint64_t{0x7F7F7F7F7F7F7F7Fu})); + const auto bit_idx = + Xor(BitCast(d_full_u8, ReverseLaneBytes(ResizeBitCast(d_full_u64, idx))), + bit_idx_xor_mask); + constexpr int kBitShufResultByteShrAmt = 6; +#endif + +#if HWY_S390X_HAVE_Z14 + const VFromD bit_shuf_result{reinterpret_cast( + vec_bperm_u128(BitCast(du8, v).raw, bit_idx.raw))}; +#elif defined(__SIZEOF_INT128__) + using RawVU128 = __vector unsigned __int128; + const VFromD bit_shuf_result{reinterpret_cast( + vec_vbpermq(reinterpret_cast(v.raw), bit_idx.raw))}; +#else + using RawVU128 = __vector unsigned char; + const VFromD bit_shuf_result{reinterpret_cast( + vec_vbpermq(reinterpret_cast(v.raw), bit_idx.raw))}; +#endif + + return ResizeBitCast( + d64, PromoteTo(d_full_u64, + ResizeBitCast( + Rebind(), + CombineShiftRightBytes( + d_full_u64, bit_shuf_result, bit_shuf_result)))); +#endif // HWY_PPC_HAVE_9 +} + +// ------------------------------ Lt128 + +namespace detail 
{ + +// Returns vector-mask for Lt128. +template > +HWY_INLINE V Lt128Vec(D d, V a, V b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); +#if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) + (void)d; + using VU64 = __vector unsigned long long; + using VU128 = __vector unsigned __int128; +#if HWY_IS_LITTLE_ENDIAN + const VU128 a_u128 = reinterpret_cast(a.raw); + const VU128 b_u128 = reinterpret_cast(b.raw); +#else + // NOTE: Need to swap the halves of both a and b on big-endian targets + // as the upper 64 bits of a and b are in lane 1 and the lower 64 bits + // of a and b are in lane 0 whereas the vec_cmplt operation below expects + // the upper 64 bits in lane 0 and the lower 64 bits in lane 1 on + // big-endian PPC targets. + const VU128 a_u128 = reinterpret_cast(vec_sld(a.raw, a.raw, 8)); + const VU128 b_u128 = reinterpret_cast(vec_sld(b.raw, b.raw, 8)); +#endif + return V{reinterpret_cast(vec_cmplt(a_u128, b_u128))}; +#else // !HWY_PPC_HAVE_10 + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const auto eqHL = Eq(a, b); + const V ltHL = VecFromMask(d, Lt(a, b)); + const V ltLX = ShiftLeftLanes<1>(ltHL); + const V vecHx = IfThenElse(eqHL, ltLX, ltHL); + return InterleaveUpper(d, vecHx, vecHx); +#endif +} + +// Returns vector-mask for Eq128. +template > +HWY_INLINE V Eq128Vec(D d, V a, V b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); +#if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) + (void)d; + using VU64 = __vector unsigned long long; + using VU128 = __vector unsigned __int128; + return V{reinterpret_cast(vec_cmpeq(reinterpret_cast(a.raw), + reinterpret_cast(b.raw)))}; +#else + const auto eqHL = VecFromMask(d, Eq(a, b)); + const auto eqLH = Reverse2(d, eqHL); + return And(eqHL, eqLH); +#endif +} + +template > +HWY_INLINE V Ne128Vec(D d, V a, V b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); +#if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) + (void)d; + using VU64 = __vector unsigned long long; + using VU128 = __vector unsigned __int128; + return V{reinterpret_cast(vec_cmpne(reinterpret_cast(a.raw), + reinterpret_cast(b.raw)))}; +#else + const auto neHL = VecFromMask(d, Ne(a, b)); + const auto neLH = Reverse2(d, neHL); + return Or(neHL, neLH); +#endif +} + +template > +HWY_INLINE V Lt128UpperVec(D d, V a, V b) { + const V ltHL = VecFromMask(d, Lt(a, b)); + return InterleaveUpper(d, ltHL, ltHL); +} + +template > +HWY_INLINE V Eq128UpperVec(D d, V a, V b) { + const V eqHL = VecFromMask(d, Eq(a, b)); + return InterleaveUpper(d, eqHL, eqHL); +} + +template > +HWY_INLINE V Ne128UpperVec(D d, V a, V b) { + const V neHL = VecFromMask(d, Ne(a, b)); + return InterleaveUpper(d, neHL, neHL); +} + +} // namespace detail + +template > +HWY_API MFromD Lt128(D d, V a, V b) { + return MaskFromVec(detail::Lt128Vec(d, a, b)); +} + +template > +HWY_API MFromD Eq128(D d, V a, V b) { + return MaskFromVec(detail::Eq128Vec(d, a, b)); +} + +template > +HWY_API MFromD Ne128(D d, V a, V b) { + return MaskFromVec(detail::Ne128Vec(d, a, b)); +} + +template > +HWY_API MFromD Lt128Upper(D d, V a, V b) { + return MaskFromVec(detail::Lt128UpperVec(d, a, b)); +} + +template > +HWY_API MFromD Eq128Upper(D d, V a, V b) { + return MaskFromVec(detail::Eq128UpperVec(d, a, b)); +} + +template > +HWY_API MFromD 
Ne128Upper(D d, V a, V b) { + return MaskFromVec(detail::Ne128UpperVec(d, a, b)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Avoids the extra MaskFromVec in Lt128. +template > +HWY_API V Min128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +} + +template > +HWY_API V Max128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +} + +template > +HWY_API V Min128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, a, b), a, b); +} + +template > +HWY_API V Max128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, b, a), a, b); +} + +// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex + +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +template +HWY_API V LeadingZeroCount(V v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const RebindToUnsigned du; + +#if HWY_COMPILER_GCC_ACTUAL && defined(__OPTIMIZE__) + // Work around for GCC compiler bug in vec_cnttz on Z14/Z15 if v[i] is a + // constant + __asm__("" : "+v"(v.raw)); +#endif + + return BitCast(d, VFromD{vec_cntlz(BitCast(du, v).raw)}); +#else + return V{vec_cntlz(v.raw)}; +#endif +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + using T = TFromD; + return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v)); +} + +#if HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14 +template +HWY_API V TrailingZeroCount(V v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + return V{vec_vctz(v.raw)}; +#else +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const RebindToUnsigned du; + +#if HWY_COMPILER_GCC_ACTUAL && defined(__OPTIMIZE__) + // Work around for GCC compiler bug in vec_cnttz on Z14/Z15 if v[i] is a + // constant + __asm__("" : "+v"(v.raw)); +#endif + + return BitCast(d, VFromD{vec_cnttz(BitCast(du, v).raw)}); +#else + return V{vec_cnttz(v.raw)}; +#endif // HWY_S390X_HAVE_Z14 +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +} +#else +template +HWY_API V TrailingZeroCount(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + + const auto vi = BitCast(di, v); + const auto lowest_bit = And(vi, Neg(vi)); + constexpr TI kNumOfBitsInT{sizeof(TI) * 8}; + const auto bit_idx = HighestSetBitIndex(lowest_bit); + return BitCast(d, IfThenElse(MaskFromVec(BroadcastSignBit(bit_idx)), + Set(di, kNumOfBitsInT), bit_idx)); +} +#endif + +#undef HWY_PPC_HAVE_9 +#undef HWY_PPC_HAVE_10 +#undef HWY_S390X_HAVE_Z14 +#undef HWY_S390X_HAVE_Z15 + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/aom/third_party/highway/hwy/ops/rvv-inl.h b/third_party/aom/third_party/highway/hwy/ops/rvv-inl.h new file mode 100644 index 000000000000..752c87de6ecd --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/rvv-inl.h @@ -0,0 +1,6568 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RISC-V V vectors (length not known at compile time). +// External include guard in highway.h - see comment there. + +#include + +#include "third_party/highway/hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Support for vfloat16m*_t and PromoteTo/DemoteTo. +#ifdef __riscv_zvfhmin +#define HWY_RVV_HAVE_F16C 1 +#else +#define HWY_RVV_HAVE_F16C 0 +#endif + +template +struct DFromV_t {}; // specialized in macros +template +using DFromV = typename DFromV_t>::type; + +template +using TFromV = TFromD>; + +template +constexpr size_t MLenFromD(Simd /* tag */) { + // Returns divisor = type bits / LMUL. Folding *8 into the ScaleByPower + // argument enables fractional LMUL < 1. Limit to 64 because that is the + // largest value for which vbool##_t are defined. + return HWY_MIN(64, sizeof(T) * 8 * 8 / detail::ScaleByPower(8, kPow2)); +} + +namespace detail { + +template +class AdjustSimdTagToMinVecPow2_t {}; + +template +class AdjustSimdTagToMinVecPow2_t> { + private: + using D = Simd; + static constexpr int kMinVecPow2 = + -3 + static_cast(FloorLog2(sizeof(T))); + static constexpr size_t kNumMaxLanes = HWY_MAX_LANES_D(D); + static constexpr int kNewPow2 = HWY_MAX(kPow2, kMinVecPow2); + static constexpr size_t kNewN = D::template NewN(); + + public: + using type = Simd; +}; + +template +using AdjustSimdTagToMinVecPow2 = + typename AdjustSimdTagToMinVecPow2_t>::type; + +} // namespace detail + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. Although +// harder to read and debug, writing everything manually is too bulky. + +namespace detail { // for code folding + +// For all mask sizes MLEN: (1/Nth of a register, one bit per lane) +// The first three arguments are arbitrary SEW, LMUL, SHIFT such that +// SEW >> SHIFT = MLEN. +#define HWY_RVV_FOREACH_B(X_MACRO, NAME, OP) \ + X_MACRO(64, 0, 64, NAME, OP) \ + X_MACRO(32, 0, 32, NAME, OP) \ + X_MACRO(16, 0, 16, NAME, OP) \ + X_MACRO(8, 0, 8, NAME, OP) \ + X_MACRO(8, 1, 4, NAME, OP) \ + X_MACRO(8, 2, 2, NAME, OP) \ + X_MACRO(8, 3, 1, NAME, OP) + +// For given SEW, iterate over one of LMULS: _TRUNC, _EXT, _ALL. This allows +// reusing type lists such as HWY_RVV_FOREACH_U for _ALL (the usual case) or +// _EXT (for Combine). To achieve this, we HWY_CONCAT with the LMULS suffix. +// +// Precompute SEW/LMUL => MLEN to allow token-pasting the result. For the same +// reason, also pass the double-width and half SEW and LMUL (suffixed D and H, +// respectively). "__" means there is no corresponding LMUL (e.g. LMULD for m8). 
+// Args: BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP + +// LMULS = _TRUNC: truncatable (not the smallest LMUL) +#define HWY_RVV_FOREACH_08_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_08_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _DEMOTE: can demote from SEW*LMUL to SEWH*LMULH. 
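// Example (illustration): in the 16-bit list below, the row
// (16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8) describes SEW=16 vectors with
// LMUL=m2; their double-width counterpart is SEW=32/LMUL=m4, the demote
// target is SEW=8/LMUL=m1, and MLEN = SEW/LMUL = 16/2 = 8, i.e. masks are
// vbool8_t.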
+#define HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _LE2: <= 2 +#define HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) + +#define HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) + +// LMULS = _EXT: not the largest LMUL +#define HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, 
m2, 2, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) + +// LMULS = _ALL (2^MinPow2() <= LMUL <= 8) +#define HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// 'Virtual' LMUL. This upholds the Highway guarantee that vectors are at least +// 128 bit and LowerHalf is defined whenever there are at least 2 lanes, even +// though RISC-V LMUL must be at least SEW/64 (notice that this rules out +// LMUL=1/2 for SEW=64). To bridge the gap, we add overloads for kPow2 equal to +// one less than should be supported, with all other parameters (vector type +// etc.) unchanged. For D with the lowest kPow2 ('virtual LMUL'), Lanes() +// returns half of what it usually would. +// +// Notice that we can only add overloads whenever there is a D argument: those +// are unique with respect to non-virtual-LMUL overloads because their kPow2 +// template argument differs. Otherwise, there is no actual vuint64mf2_t, and +// defining another overload with the same LMUL would be an error. Thus we have +// a separate _VIRT category for HWY_RVV_FOREACH*, and the common case is +// _ALL_VIRT (meaning the regular LMUL plus the VIRT overloads), used in most +// functions that take a D. 
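// Illustrative example (not from the upstream file), using the usual Highway
// tag aliases: for u16 the smallest real LMUL is mf4 (kPow2 = -2), so the
// _VIRT row below adds a kPow2 = -3 overload set that reuses vuint16mf4_t
// while halving the reported lane count:
//   ScalableTag<uint16_t, -2> d_real;  // real LMUL 1/4  -> vuint16mf4_t
//   ScalableTag<uint16_t, -3> d_virt;  // 'virtual' LMUL -> also vuint16mf4_t
//   // VFromD of both tags is the same vector type, but
//   // Lanes(d_virt) == Lanes(d_real) / 2, which is what lets Half<> /
//   // LowerHalf descend one step below the hardware minimum.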
+ +#define HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -3, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -2, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, -1, /*MLEN=*/64, NAME, OP) + +// ALL + VIRT +#define HWY_RVV_FOREACH_08_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// LE2 + VIRT +#define HWY_RVV_FOREACH_08_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// GET/SET + VIRT +#define HWY_RVV_FOREACH_08_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_16_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) + +#define HWY_RVV_FOREACH_32_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) + +#define HWY_RVV_FOREACH_64_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// For the smallest LMUL for each SEW, similar to the LowerHalf operator, we +// provide the Get and Set operator that returns the same vector type. 
+#define HWY_RVV_FOREACH_08_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_16_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_32_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_64_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) + +// EXT + VIRT +#define HWY_RVV_FOREACH_08_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// DEMOTE + VIRT +#define HWY_RVV_FOREACH_08_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// SEW for unsigned: +#define HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, uint, u, NAME, OP) + +// SEW for signed: +#define HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, int, i, NAME, OP) + +// SEW for float: + +// Used for conversion instructions if HWY_RVV_HAVE_F16C. 
+#define HWY_RVV_FOREACH_F16_UNCONDITIONAL(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, float, f, NAME, OP) + +#if HWY_HAVE_FLOAT16 +// Full support for f16 in all ops +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F16_UNCONDITIONAL(X_MACRO, NAME, OP, LMULS) +// Only BF16 is emulated. +#define HWY_RVV_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_RVV_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D) +#else +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) +#define HWY_RVV_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_RVV_IF_NOT_EMULATED_D(D) HWY_IF_NOT_SPECIAL_FLOAT_D(D) +#endif +#define HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, float, f, NAME, OP) +#define HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, float, f, NAME, OP) + +// Commonly used type/SEW groups: +#define HWY_RVV_FOREACH_UI08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) + +// For all combinations of SEW: +#define HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) + +// Commonly used type categories: +#define HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) + +// Assemble types for use in x-macros +#define HWY_RVV_T(BASE, SEW) BASE##SEW##_t +#define 
HWY_RVV_D(BASE, SEW, N, SHIFT) Simd +#define HWY_RVV_V(BASE, SEW, LMUL) v##BASE##SEW##LMUL##_t +#define HWY_RVV_TUP(BASE, SEW, LMUL, TUP) v##BASE##SEW##LMUL##x##TUP##_t +#define HWY_RVV_M(MLEN) vbool##MLEN##_t + +} // namespace detail + +// Until we have full intrinsic support for fractional LMUL, mixed-precision +// code can use LMUL 1..8 (adequate unless they need many registers). +#define HWY_SPECIALIZE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <> \ + struct DFromV_t { \ + using Lane = HWY_RVV_T(BASE, SEW); \ + using type = ScalableTag; \ + }; + +HWY_RVV_FOREACH(HWY_SPECIALIZE, _, _, _ALL) +#undef HWY_SPECIALIZE + +// ------------------------------ Lanes + +// WARNING: we want to query VLMAX/sizeof(T), but this may actually change VL! + +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD +// HWY_RVV_CAPPED_LANES_SPECIAL_CASES provides some additional optimizations +// to CappedLanes in non-debug builds +#define HWY_RVV_CAPPED_LANES_SPECIAL_CASES(BASE, SEW, LMUL) \ + if (__builtin_constant_p(cap >= kMaxLanes) && (cap >= kMaxLanes)) { \ + /* If cap is known to be greater than or equal to MaxLanes(d), */ \ + /* HWY_MIN(cap, Lanes(d)) will be equal to Lanes(d) */ \ + return Lanes(d); \ + } \ + \ + if ((__builtin_constant_p((cap & (cap - 1)) == 0) && \ + ((cap & (cap - 1)) == 0)) || \ + (__builtin_constant_p(cap <= HWY_MAX(kMinLanesPerFullVec, 4)) && \ + (cap <= HWY_MAX(kMinLanesPerFullVec, 4)))) { \ + /* If cap is known to be a power of 2, then */ \ + /* vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the same */ \ + /* result as HWY_MIN(cap, Lanes(d)) as kMaxLanes is a power of 2 and */ \ + /* as (cap > VLMAX && cap < 2 * VLMAX) can only be true if cap is not a */ \ + /* power of 2 since VLMAX is always a power of 2 */ \ + \ + /* If cap is known to be less than or equal to 4, then */ \ + /* vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the same */ \ + /* result as HWY_MIN(cap, Lanes(d)) as HWY_MIN(cap, kMaxLanes) <= 4 is */ \ + /* true if cap <= 4 and as vsetvl(HWY_MIN(cap, kMaxLanes)) is */ \ + /* guaranteed to return the same result as HWY_MIN(cap, Lanes(d)) */ \ + /* if HWY_MIN(cap, kMaxLanes) <= 4 is true */ \ + \ + /* If cap is known to be less than or equal to kMinLanesPerFullVec, */ \ + /* then vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the */ \ + /* same result as HWY_MIN(cap, Lanes(d)) as */ \ + /* HWY_MIN(cap, kMaxLanes) <= kMinLanesPerFullVec is true if */ \ + /* cap <= kMinLanesPerFullVec is true */ \ + \ + /* If cap <= HWY_MAX(kMinLanesPerFullVec, 4) is true, then either */ \ + /* cap <= 4 or cap <= kMinLanesPerFullVec must be true */ \ + \ + /* If cap <= HWY_MAX(kMinLanesPerFullVec, 4) is known to be true, */ \ + /* then vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the */ \ + /* same result as HWY_MIN(cap, Lanes(d)) */ \ + \ + /* If no cap, avoid the HWY_MIN. */ \ + return detail::IsFull(d) \ + ? __riscv_vsetvl_e##SEW##LMUL(cap) \ + : __riscv_vsetvl_e##SEW##LMUL(HWY_MIN(cap, kMaxLanes)); \ + } +#else +#define HWY_RVV_CAPPED_LANES_SPECIAL_CASES(BASE, SEW, LMUL) +#endif + +#define HWY_RVV_LANES(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + constexpr size_t kFull = HWY_LANES(HWY_RVV_T(BASE, SEW)); \ + constexpr size_t kCap = MaxLanes(d); \ + /* If no cap, avoid generating a constant by using VLMAX. */ \ + return N == kFull ? 
__riscv_vsetvlmax_e##SEW##LMUL() \ + : __riscv_vsetvl_e##SEW##LMUL(kCap); \ + } \ + template \ + HWY_API size_t Capped##NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, size_t cap) { \ + /* NOTE: Section 6.3 of the RVV specification, which can be found at */ \ + /* https://github.com/riscv/riscv-v-spec/blob/master/v-spec.adoc, */ \ + /* allows vsetvl to return a result less than Lanes(d) but greater than */ \ + /* or equal to ((cap + 1) / 2) if */ \ + /* (Lanes(d) > 2 && cap > HWY_MAX(Lanes(d), 4) && cap < (2 * Lanes(d))) */ \ + /* is true */ \ + \ + /* VLMAX is the number of lanes in a vector of type */ \ + /* VFromD, which is returned by */ \ + /* Lanes(DFromV>()) */ \ + \ + /* VLMAX is guaranteed to be a power of 2 under Section 2 of the RVV */ \ + /* specification */ \ + \ + /* The VLMAX of a vector of type VFromD is at least 2 as */ \ + /* the HWY_RVV target requires support for the RVV Zvl128b extension, */ \ + /* which guarantees that vectors with LMUL=1 are at least 16 bytes */ \ + \ + /* If VLMAX == 2 is true, then vsetvl(cap) is equal to HWY_MIN(cap, 2) */ \ + /* as cap == 3 is the only value such that */ \ + /* (cap > VLMAX && cap < 2 * VLMAX) if VLMAX == 2 and as */ \ + /* ((3 + 1) / 2) is equal to 2 */ \ + \ + /* If cap <= 4 is true, then vsetvl(cap) must be equal to */ \ + /* HWY_MIN(cap, VLMAX) as cap <= VLMAX is true if VLMAX >= 4 is true */ \ + /* and as vsetvl(cap) is guaranteed to be equal to HWY_MIN(cap, VLMAX) */ \ + /* if VLMAX == 2 */ \ + \ + /* We want CappedLanes(d, cap) to return Lanes(d) if cap > Lanes(d) as */ \ + /* LoadN(d, p, cap) expects to load exactly HWY_MIN(cap, Lanes(d)) */ \ + /* lanes and StoreN(v, d, p, cap) expects to store exactly */ \ + /* HWY_MIN(cap, Lanes(d)) lanes, even in the case where vsetvl returns */ \ + /* a result that is less than HWY_MIN(cap, Lanes(d)) */ \ + \ + /* kMinLanesPerFullVec is the minimum value of VLMAX for a vector of */ \ + /* type VFromD */ \ + constexpr size_t kMinLanesPerFullVec = \ + detail::ScaleByPower(16 / (SEW / 8), SHIFT); \ + /* kMaxLanes is the maximum number of lanes returned by Lanes(d) */ \ + constexpr size_t kMaxLanes = MaxLanes(d); \ + \ + HWY_RVV_CAPPED_LANES_SPECIAL_CASES(BASE, SEW, LMUL) \ + \ + if (kMaxLanes <= HWY_MAX(kMinLanesPerFullVec, 4)) { \ + /* If kMaxLanes <= kMinLanesPerFullVec is true, then */ \ + /* vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return */ \ + /* HWY_MIN(cap, Lanes(d)) as */ \ + /* HWY_MIN(cap, kMaxLanes) <= kMaxLanes <= VLMAX is true if */ \ + /* kMaxLanes <= kMinLanesPerFullVec is true */ \ + \ + /* If kMaxLanes <= 4 is true, then vsetvl(HWY_MIN(cap, kMaxLanes)) is */ \ + /* guaranteed to return the same result as HWY_MIN(cap, Lanes(d)) as */ \ + /* HWY_MIN(cap, kMaxLanes) <= 4 is true if kMaxLanes <= 4 is true */ \ + \ + /* If kMaxLanes <= HWY_MAX(kMinLanesPerFullVec, 4) is true, then */ \ + /* either kMaxLanes <= 4 or kMaxLanes <= kMinLanesPerFullVec must be */ \ + /* true */ \ + \ + return __riscv_vsetvl_e##SEW##LMUL(HWY_MIN(cap, kMaxLanes)); \ + } else { \ + /* If kMaxLanes > HWY_MAX(kMinLanesPerFullVec, 4) is true, need to */ \ + /* obtain the actual number of lanes using Lanes(d) and clamp cap to */ \ + /* the result of Lanes(d) */ \ + const size_t actual = Lanes(d); \ + return HWY_MIN(actual, cap); \ + } \ + } + +#define HWY_RVV_LANES_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + constexpr size_t kCap = MaxLanes(d); \ + /* In case of virtual LMUL (intrinsics do 
not provide "uint16mf8_t") */ \ + /* vsetvl may or may not be correct, so do it ourselves. */ \ + const size_t actual = \ + detail::ScaleByPower(__riscv_vlenb() / (SEW / 8), SHIFT); \ + return HWY_MIN(actual, kCap); \ + } \ + template \ + HWY_API size_t Capped##NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, size_t cap) { \ + /* In case of virtual LMUL (intrinsics do not provide "uint16mf8_t") */ \ + /* vsetvl may or may not be correct, so do it ourselves. */ \ + const size_t actual = \ + detail::ScaleByPower(__riscv_vlenb() / (SEW / 8), SHIFT); \ + /* If no cap, avoid an extra HWY_MIN. */ \ + return detail::IsFull(d) ? HWY_MIN(actual, cap) \ + : HWY_MIN(HWY_MIN(actual, cap), MaxLanes(d)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_LANES, Lanes, setvlmax_e, _ALL) +HWY_RVV_FOREACH(HWY_RVV_LANES_VIRT, Lanes, lenb, _VIRT) +#undef HWY_RVV_LANES +#undef HWY_RVV_LANES_VIRT +#undef HWY_RVV_CAPPED_LANES_SPECIAL_CASES + +template +HWY_API size_t Lanes(D /* tag*/) { + return Lanes(RebindToUnsigned()); +} + +template +HWY_API size_t CappedLanes(D /* tag*/, size_t cap) { + return CappedLanes(RebindToUnsigned(), cap); +} + +// ------------------------------ Common x-macros + +// Last argument to most intrinsics. Use when the op has no d arg of its own, +// which means there is no user-specified cap. +#define HWY_RVV_AVL(SEW, SHIFT) \ + Lanes(ScalableTag()) + +// vector = f(vector), e.g. Not +#define HWY_RVV_RETV_ARGV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, scalar), e.g. detail::AddS +#define HWY_RVV_RETV_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, vector), e.g. Add +#define HWY_RVV_RETV_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(a, b, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, mask, vector, vector), e.g. 
MaskedAddOr +#define HWY_RVV_RETV_ARGMVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) no, HWY_RVV_M(MLEN) m, \ + HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_mu(m, no, a, b, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(mask) +#define HWY_RVV_RETM_ARGM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) m) { \ + return __riscv_vm##OP##_m_b##MLEN(m, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// ================================================== INIT + +// ------------------------------ Set + +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_T(BASE, SEW) arg) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(arg, Lanes(d)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SET, Set, mv_v_x, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_SET, Set, fmv_v_f, _ALL_VIRT) +#undef HWY_RVV_SET + +// Treat bfloat16_t as int16_t (using the previously defined Set overloads); +// required for Zero and VFromD. +template +decltype(Set(RebindToSigned(), 0)) Set(D d, hwy::bfloat16_t arg) { + return Set(RebindToSigned(), BitCastScalar(arg)); +} +#if !HWY_HAVE_FLOAT16 // Otherwise already defined above. +// WARNING: returns a different type than emulated bfloat16_t so that we can +// implement PromoteTo overloads for both bfloat16_t and float16_t, and also +// provide a Neg(hwy::float16_t) overload that coexists with Neg(int16_t). +template +decltype(Set(RebindToUnsigned(), 0)) Set(D d, hwy::float16_t arg) { + return Set(RebindToUnsigned(), BitCastScalar(arg)); +} +#endif + +template +using VFromD = decltype(Set(D(), TFromD())); + +// ------------------------------ Zero + +template +HWY_API VFromD Zero(D d) { + // Cast to support bfloat16_t. + const RebindToUnsigned du; + return BitCast(d, Set(du, 0)); +} + +// ------------------------------ Undefined + +// RVV vundefined is 'poisoned' such that even XORing a _variable_ initialized +// by it gives unpredictable results. It should only be used for maskoff, so +// keep it internal. For the Highway op, just use Zero (single instruction). +namespace detail { +#define HWY_RVV_UNDEFINED(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) /* tag */) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(); /* no AVL */ \ + } + +HWY_RVV_FOREACH(HWY_RVV_UNDEFINED, Undefined, undefined, _ALL) +#undef HWY_RVV_UNDEFINED +} // namespace detail + +template +HWY_API VFromD Undefined(D d) { + return Zero(d); +} + +// ------------------------------ BitCast + +namespace detail { + +// Halves LMUL. (Use LMUL arg for the source so we can use _TRUNC.) +#define HWY_RVV_TRUNC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH( \ + v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_TRUNC, Trunc, lmul_trunc, _TRUNC) +#undef HWY_RVV_TRUNC + +// Doubles LMUL to `d2` (the arg is only necessary for _VIRT). 
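// Illustrative note (not from the upstream file; the d2 name below is
// hypothetical): Trunc above and Ext below only re-label the register group,
// so no data movement is generated:
//   detail::Trunc(vuint8m2_t v)    -> vuint8m1_t, the lower register of the
//                                     m2 group
//   detail::Ext(d2, vuint8m1_t v)  -> vuint8m2_t whose lower register is v
//                                     (d2 is the doubled tag, e.g.
//                                     ScalableTag<uint8_t, 1>)
// The virtual-LMUL BitCast overloads further below rely on these to move
// between an element vector at LMUL and its byte image at LMULH.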
+#define HWY_RVV_EXT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMULD) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT + 1) /* d2 */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULD( \ + v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT, Ext, lmul_ext, _EXT) +#undef HWY_RVV_EXT + +// For virtual LMUL e.g. 'uint32mf4_t', the return type should be mf2, which is +// the same as the actual input type. +#define HWY_RVV_EXT_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT + 1) /* d2 */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v; \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT_VIRT, Ext, lmul_ext, _VIRT) +#undef HWY_RVV_EXT_VIRT + +template +VFromD Ext(D d, VFromD> v) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, Ext(du, BitCast(duh, v))); +} + +// For BitCastToByte, the D arg is only to prevent duplicate definitions caused +// by _ALL_VIRT. + +// There is no reinterpret from u8 <-> u8, so just return. +#define HWY_RVV_CAST_U8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + vuint8##LMUL##_t v) { \ + return v; \ + } \ + template \ + HWY_API vuint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v; \ + } + +// For i8, need a single reinterpret (HWY_RVV_CAST_IF does two). +#define HWY_RVV_CAST_I8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + vint8##LMUL##_t v) { \ + return __riscv_vreinterpret_v_i8##LMUL##_u8##LMUL(v); \ + } \ + template \ + HWY_API vint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return __riscv_vreinterpret_v_u8##LMUL##_i8##LMUL(v); \ + } + +// Separate u/i because clang only provides signed <-> unsigned reinterpret for +// the same SEW. +#define HWY_RVV_CAST_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return __riscv_v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v); \ + } + +// Signed/Float: first cast to/from unsigned +#define HWY_RVV_CAST_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + __riscv_v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return __riscv_v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + __riscv_v##OP##_v_u8##LMUL##_u##SEW##LMUL(v)); \ + } + +// Additional versions for virtual LMUL using LMULH for byte vectors. 
+#define HWY_RVV_CAST_VIRT_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMULH##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return detail::Trunc(__riscv_v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMULH##_t v) { \ + HWY_RVV_D(uint, 8, N, SHIFT + 1) d2; \ + const vuint8##LMUL##_t v2 = detail::Ext(d2, v); \ + return __riscv_v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v2); \ + } + +// Signed/Float: first cast to/from unsigned +#define HWY_RVV_CAST_VIRT_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMULH##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return detail::Trunc(__riscv_v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + __riscv_v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v))); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMULH##_t v) { \ + HWY_RVV_D(uint, 8, N, SHIFT + 1) d2; \ + const vuint8##LMUL##_t v2 = detail::Ext(d2, v); \ + return __riscv_v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + __riscv_v##OP##_v_u8##LMUL##_u##SEW##LMUL(v2)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_CAST_U8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I08(HWY_RVV_CAST_I8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U163264(HWY_RVV_CAST_U, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I163264(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U163264(HWY_RVV_CAST_VIRT_U, _, reinterpret, _VIRT) +HWY_RVV_FOREACH_I163264(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) +#if HWY_HAVE_FLOAT16 // HWY_RVV_FOREACH_F already covered float16_ +#elif HWY_RVV_HAVE_F16C // zvfhmin provides reinterpret* intrinsics: +HWY_RVV_FOREACH_F16_UNCONDITIONAL(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_F16_UNCONDITIONAL(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) +#else +template +HWY_INLINE VFromD> BitCastFromByte( + D /* d */, VFromD> v) { + return BitCastFromByte(RebindToUnsigned(), v); +} +#endif + +#undef HWY_RVV_CAST_U8 +#undef HWY_RVV_CAST_I8 +#undef HWY_RVV_CAST_U +#undef HWY_RVV_CAST_IF +#undef HWY_RVV_CAST_VIRT_U +#undef HWY_RVV_CAST_VIRT_IF + +template +HWY_INLINE VFromD> BitCastFromByte( + D d, VFromD> v) { + return BitCastFromByte(RebindToSigned(), v); +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(d, v)); +} + +// ------------------------------ Iota + +namespace detail { + +#define HWY_RVV_IOTA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(Lanes(d)); \ + } + +// For i8 lanes, this may well wrap around. Unsigned only is less error-prone. +HWY_RVV_FOREACH_U(HWY_RVV_IOTA, Iota0, id_v, _ALL_VIRT) +#undef HWY_RVV_IOTA + +// Used by Expand. 
+#define HWY_RVV_MASKED_IOTA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_M(MLEN) mask) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(mask, Lanes(d)); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_MASKED_IOTA, MaskedIota, iota_m, _ALL_VIRT) +#undef HWY_RVV_MASKED_IOTA + +} // namespace detail + +// ================================================== LOGICAL + +// ------------------------------ Not + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGV, Not, not, _ALL) + +template +HWY_API V Not(const V v) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Not(BitCast(DU(), v))); +} + +// ------------------------------ And + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AndS, and_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, And, and, _ALL) + +template +HWY_API V And(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), And(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Or + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Or, or, _ALL) + +template +HWY_API V Or(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Or(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Xor + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, XorS, xor_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Xor, xor, _ALL) + +template +HWY_API V Xor(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Xor(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ AndNot +template +HWY_API V AndNot(const V not_a, const V b) { + return And(Not(not_a), b); +} + +// ------------------------------ Xor3 +template +HWY_API V Xor3(V x1, V x2, V x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 +template +HWY_API V Or3(V o1, V o2, V o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ CopySign + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, CopySign, fsgnj, _ALL) + +template +HWY_API V CopySignToAbs(const V abs, const V sign) { + // RVV can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +} + +// ================================================== ARITHMETIC + +// Per-target flags to prevent generic_ops-inl.h defining Add etc. 
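// Worked example (illustrative, not from the upstream file): RVV has no
// bitwise instructions on float registers, so the float overloads above go
// through the unsigned domain, e.g.
//   And(vfloat32m1_t a, vfloat32m1_t b)
//     == BitCast(f32, And(BitCast(u32, a), BitCast(u32, b)))
// which still compiles to a single vand.vv because the BitCasts are free.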
+#ifdef HWY_NATIVE_OPERATOR_REPLACEMENTS +#undef HWY_NATIVE_OPERATOR_REPLACEMENTS +#else +#define HWY_NATIVE_OPERATOR_REPLACEMENTS +#endif + +// ------------------------------ Add + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AddS, add_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, AddS, fadd_vf, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, ReverseSubS, rsub_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, ReverseSubS, frsub_vf, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Add, add, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Add, fadd, _ALL) + +// ------------------------------ Sub +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, SubS, sub_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Sub, sub, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Sub, fsub, _ALL) + +// ------------------------------ Neg (ReverseSubS, Xor) + +template +HWY_API V Neg(const V v) { + return detail::ReverseSubS(v, 0); +} + +// vector = f(vector), but argument is repeated +#define HWY_RVV_RETV_ARGV2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, v, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Neg, fsgnjn, _ALL) + +#if !HWY_HAVE_FLOAT16 + +template )> // hwy::float16_t +HWY_API V Neg(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask()))); +} + +#endif // !HWY_HAVE_FLOAT16 + +// ------------------------------ SaturatedAdd + +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB +#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U64_SATURATED_ADDSUB +#endif + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd, _ALL) + +// ------------------------------ SaturatedSub + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) + +// ------------------------------ AverageRound + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +// Define this to opt-out of the default behavior, which is AVOID on certain +// compiler versions. You can define only this to use VXRM, or define both this +// and HWY_RVV_AVOID_VXRM to always avoid VXRM. +#ifndef HWY_RVV_CHOOSE_VXRM + +// Assume that GCC-13 defaults to 'avoid VXRM'. Tested with GCC 13.1.0. +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400 +#define HWY_RVV_AVOID_VXRM +// Clang 16 with __riscv_v_intrinsic == 11000 may either require VXRM or avoid. +// Assume earlier versions avoid. 
+#elif HWY_COMPILER_CLANG && \ + (HWY_COMPILER_CLANG < 1600 || __riscv_v_intrinsic < 11000) +#define HWY_RVV_AVOID_VXRM +#endif + +#endif // HWY_RVV_CHOOSE_VXRM + +// Adding __RISCV_VXRM_* was a backwards-incompatible change and it is not clear +// how to detect whether it is supported or required. #ifdef __RISCV_VXRM_RDN +// does not work because it seems to be a compiler built-in, but neither does +// __has_builtin(__RISCV_VXRM_RDN). The intrinsics version was also not updated, +// so we require a macro to opt out of the new intrinsics. +#ifdef HWY_RVV_AVOID_VXRM +#define HWY_RVV_INSERT_VXRM(vxrm, avl) avl +#define __RISCV_VXRM_RNU +#define __RISCV_VXRM_RDN +#else // default: use new vxrm arguments +#define HWY_RVV_INSERT_VXRM(vxrm, avl) vxrm, avl +#endif + +// Extra rounding mode = up argument. +#define HWY_RVV_RETV_AVERAGE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL( \ + a, b, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_I(HWY_RVV_RETV_AVERAGE, AverageRound, aadd, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETV_AVERAGE, AverageRound, aaddu, _ALL) + +#undef HWY_RVV_RETV_AVERAGE + +// ------------------------------ ShiftLeft[Same] + +// Intrinsics do not define .vi forms, so use .vx instead. +#define HWY_RVV_SHIFT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL(v, kBits, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Same(HWY_RVV_V(BASE, SEW, LMUL) v, int bits) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL(v, static_cast(bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SHIFT, ShiftLeft, sll, _ALL) + +// ------------------------------ ShiftRight[Same] + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT, ShiftRight, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT, ShiftRight, sra, _ALL) + +#undef HWY_RVV_SHIFT + +// ------------------------------ RoundingShiftRight[Same] + +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +// Intrinsics do not define .vi forms, so use .vx instead. 
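// Illustrative expansion (not from the upstream file): with the newer
// intrinsics, AverageRound for u32/LMUL=1 becomes
//   __riscv_vaaddu_vv_u32m1(a, b, __RISCV_VXRM_RNU, avl);
// whereas under HWY_RVV_AVOID_VXRM, HWY_RVV_INSERT_VXRM discards the
// rounding-mode argument and the call degrades to the older three-argument
// form __riscv_vaaddu_vv_u32m1(a, b, avl).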
+#define HWY_RVV_ROUNDING_SHR(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL( \ + v, kBits, \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Same(HWY_RVV_V(BASE, SEW, LMUL) v, int bits) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL( \ + v, static_cast(bits), \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_ROUNDING_SHR, RoundingShiftRight, ssrl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_ROUNDING_SHR, RoundingShiftRight, ssra, _ALL) + +#undef HWY_RVV_ROUNDING_SHR + +// ------------------------------ SumsOf8 (ShiftRight, Add) +template )> +HWY_API VFromD>> SumsOf8(const VU8 v) { + const DFromV du8; + const RepartitionToWide du16; + const RepartitionToWide du32; + const RepartitionToWide du64; + using VU16 = VFromD; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = detail::AndS(BitCast(du16, v), 0xFF); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return detail::AndS(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), 0xFFFFull); +} + +template )> +HWY_API VFromD>> SumsOf8(const VI8 v) { + const DFromV di8; + const RepartitionToWide di16; + const RepartitionToWide di32; + const RepartitionToWide di64; + const RebindToUnsigned du32; + const RebindToUnsigned du64; + using VI16 = VFromD; + + const VI16 vFDB97531 = ShiftRight<8>(BitCast(di16, v)); + const VI16 vECA86420 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, v))); + const VI16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VI16 sDC_zz_98_zz_54_zz_10_zz = + BitCast(di16, ShiftLeft<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VI16 sFC_xx_B8_xx_74_xx_30_xx = + Add(sFE_DC_BA_98_76_54_32_10, sDC_zz_98_zz_54_zz_10_zz); + const VI16 sB8_xx_zz_zz_30_xx_zz_zz = + BitCast(di16, ShiftLeft<32>(BitCast(du64, sFC_xx_B8_xx_74_xx_30_xx))); + const VI16 sF8_xx_xx_xx_70_xx_xx_xx = + Add(sFC_xx_B8_xx_74_xx_30_xx, sB8_xx_zz_zz_30_xx_zz_zz); + return ShiftRight<48>(BitCast(di64, sF8_xx_xx_xx_70_xx_xx_xx)); +} + +// ------------------------------ RotateRight +template +HWY_API V RotateRight(const V v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// ------------------------------ Shl +#define HWY_RVV_SHIFT_VV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, bits, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shl, sll, _ALL) + +#define HWY_RVV_SHIFT_II(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API 
HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + const HWY_RVV_D(uint, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT) du; \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, BitCast(du, bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shl, sll, _ALL) + +// ------------------------------ Shr + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shr, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shr, sra, _ALL) + +#undef HWY_RVV_SHIFT_II +#undef HWY_RVV_SHIFT_VV + +// ------------------------------ RoundingShr +#define HWY_RVV_ROUNDING_SHR_VV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL( \ + v, bits, \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_ROUNDING_SHR_VV, RoundingShr, ssrl, _ALL) + +#define HWY_RVV_ROUNDING_SHR_II(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + const HWY_RVV_D(uint, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT) du; \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL( \ + v, BitCast(du, bits), \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_I(HWY_RVV_ROUNDING_SHR_II, RoundingShr, ssra, _ALL) + +#undef HWY_RVV_ROUNDING_SHR_VV +#undef HWY_RVV_ROUNDING_SHR_II + +// ------------------------------ Min + +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVS, MinS, minu_vx, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVS, MinS, min_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, MinS, fmin_vf, _ALL) + +} // namespace detail + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Min, minu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Min, min, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Min, fmin, _ALL) + +// ------------------------------ Max + +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVS, MaxS, maxu_vx, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVS, MaxS, max_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, MaxS, fmax_vf, _ALL) + +} // namespace detail + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Max, maxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Max, max, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Max, fmax, _ALL) + +// ------------------------------ Mul + +// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Mul, mul, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Mul, fmul, _ALL) + +// ------------------------------ MulHigh + +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) + +// ------------------------------ MulFixedPoint15 + +// Extra rounding mode = up argument. 
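// Illustrative note (not from the upstream file): the .vv shift intrinsics
// type the shift-count vector as unsigned, which is why the _II variant above
// BitCasts 'bits' to the same-width unsigned type, e.g.
//   Shl(vint32m1_t v, vint32m1_t bits)
//     -> __riscv_vsll_vv_i32m1(v, BitCast(du32, bits), avl)
// with du32 being the matching unsigned tag declared inside the macro.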
+#define HWY_RVV_MUL15(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL( \ + a, b, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_I16(HWY_RVV_MUL15, MulFixedPoint15, smul, _ALL) + +#undef HWY_RVV_MUL15 + +// ------------------------------ Div +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Div, divu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Div, div, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Div, fdiv, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Mod, remu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Mod, rem, _ALL) + +// ------------------------------ MaskedAddOr etc. + +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedMinOr, minu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedMinOr, min, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedMinOr, fmin, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedMaxOr, maxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedMaxOr, max, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedMaxOr, fmax, _ALL) + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGMVV, MaskedAddOr, add, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedAddOr, fadd, _ALL) + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGMVV, MaskedSubOr, sub, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedSubOr, fsub, _ALL) + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGMVV, MaskedMulOr, mul, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedMulOr, fmul, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedDivOr, divu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedDivOr, div, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedDivOr, fdiv, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedModOr, remu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedModOr, rem, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedSatAddOr, saddu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedSatAddOr, sadd, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedSatSubOr, ssubu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedSatSubOr, ssub, _ALL) + +// ------------------------------ ApproximateReciprocal +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, ApproximateReciprocal, frec7, _ALL) + +// ------------------------------ Sqrt +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, Sqrt, fsqrt, _ALL) + +// ------------------------------ ApproximateReciprocalSqrt +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, ApproximateReciprocalSqrt, frsqrt7, _ALL) + +// ------------------------------ MulAdd + +// Per-target flag to prevent generic_ops-inl.h from defining int MulAdd. +#ifdef HWY_NATIVE_INT_FMA +#undef HWY_NATIVE_INT_FMA +#else +#define HWY_NATIVE_INT_FMA +#endif + +// Note: op is still named vv, not vvv. 
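// Illustrative mapping (not from the upstream file): the accumulate-form
// intrinsics take the addend in vd, so MulAdd(mul, x, add) defined below maps
// to
//   __riscv_vfmacc_vv_f32m1(add, mul, x, avl)   // add + mul * x
// and NegMulAdd uses vnmsac / vfnmsac, i.e. add - mul * x.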
+#define HWY_RVV_FMA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) mul, HWY_RVV_V(BASE, SEW, LMUL) x, \ + HWY_RVV_V(BASE, SEW, LMUL) add) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(add, mul, x, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_FMA, MulAdd, macc, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulAdd, fmacc, _ALL) + +// ------------------------------ NegMulAdd +HWY_RVV_FOREACH_UI(HWY_RVV_FMA, NegMulAdd, nmsac, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulAdd, fnmsac, _ALL) + +// ------------------------------ MulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulSub, fmsac, _ALL) + +// ------------------------------ NegMulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc, _ALL) + +#undef HWY_RVV_FMA + +// ================================================== COMPARE + +// ------------------------------ MClear + +// mask = f() +#define HWY_RVV_RETM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME##MLEN() { \ + return __riscv_vm##OP##_m_b##MLEN(HWY_RVV_AVL(SEW, SHIFT)); \ + } + +namespace detail { +HWY_RVV_FOREACH_B(HWY_RVV_RETM, MClear, clr) // with ##MLEN suffix +} // namespace detail + +#undef HWY_RVV_RETM + +// Comparisons set a mask bit to 1 if the condition is true, else 0. The XX in +// vboolXX_t is a power of two divisor for vector bits. SEW=8 / LMUL=1 = 1/8th +// of all bits; SEW=8 / LMUL=4 = half of all bits. + +// mask = f(vector, vector) +#define HWY_RVV_RETM_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_b##MLEN( \ + a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(mask, vector, vector) +#define HWY_RVV_RETM_ARGMVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) a, \ + HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_b##MLEN##_mu( \ + m, detail::MClear##MLEN(), a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(vector, scalar) +#define HWY_RVV_RETM_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL##_b##MLEN( \ + a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +#ifdef HWY_NATIVE_MASKED_COMP +#undef HWY_NATIVE_MASKED_COMP +#else +#define HWY_NATIVE_MASKED_COMP +#endif + +// ------------------------------ Eq +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Eq, mseq, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Eq, mfeq, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGMVV, MaskedEq, mseq, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedEq, mfeq, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, EqS, mseq_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, EqS, mfeq_vf, _ALL) +} // namespace detail + +// ------------------------------ Ne +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Ne, msne, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Ne, mfne, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGMVV, MaskedNe, msne, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedNe, mfne, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, NeS, msne_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, NeS, mfne_vf, _ALL) +} // namespace detail + +// 
------------------------------ Lt +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Lt, msltu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Lt, mslt, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Lt, mflt, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGMVV, MaskedLt, msltu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGMVV, MaskedLt, mslt, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedLt, mflt, _ALL) + +namespace detail { +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVS, LtS, mslt_vx, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVS, LtS, msltu_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, LtS, mflt_vf, _ALL) +} // namespace detail + +// ------------------------------ Le +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Le, msleu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Le, msle, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Le, mfle, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGMVV, MaskedLe, msleu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGMVV, MaskedLe, msle, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedLe, mfle, _ALL) + +template +using MFromD = decltype(Eq(Zero(D()), Zero(D()))); + +template > +HWY_API MFromD MaskedIsNaN(const M m, const V v) { + return MaskedNe(m, v, v); +} + +#undef HWY_RVV_RETM_ARGMVV +#undef HWY_RVV_RETM_ARGVV +#undef HWY_RVV_RETM_ARGVS + +// ------------------------------ Gt/Ge (Lt, Le) + +// Swap args to reverse comparisons: +template +HWY_API auto Gt(const V a, const V b) -> decltype(Lt(a, b)) { + return Lt(b, a); +} + +template +HWY_API auto Ge(const V a, const V b) -> decltype(Le(a, b)) { + return Le(b, a); +} + +template > +HWY_API MFromD MaskedGt(M m, V a, V b) { + return MaskedLt(m, b, a); +} + +template > +HWY_API MFromD MaskedGe(M m, V a, V b) { + return MaskedLe(m, b, a); +} + +// ------------------------------ TestBit +template +HWY_API auto TestBit(const V a, const V bit) -> decltype(Eq(a, bit)) { + return detail::NeS(And(a, bit), 0); +} + +// ------------------------------ Not +// NOLINTNEXTLINE +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, Not, not ) + +// ------------------------------ And + +// mask = f(mask_a, mask_b) (note arg2,arg1 order!) 
+#define HWY_RVV_RETM_ARGMM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) a, HWY_RVV_M(MLEN) b) { \ + return __riscv_vm##OP##_mm_b##MLEN(b, a, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, And, and) + +// ------------------------------ AndNot +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, AndNot, andn) + +// ------------------------------ Or +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Or, or) + +// ------------------------------ Xor +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Xor, xor) + +// ------------------------------ ExclusiveNeither +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, ExclusiveNeither, xnor) + +#undef HWY_RVV_RETM_ARGMM + +// ------------------------------ IfThenElse + +#define HWY_RVV_IF_THEN_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) yes, \ + HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return __riscv_v##OP##_vvm_##CHAR##SEW##LMUL(no, yes, m, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_IF_THEN_ELSE, IfThenElse, merge, _ALL) + +#undef HWY_RVV_IF_THEN_ELSE + +// ------------------------------ IfThenElseZero +template +HWY_API V IfThenElseZero(const M mask, const V yes) { + return IfThenElse(mask, yes, Zero(DFromV())); +} + +// ------------------------------ IfThenZeroElse + +#define HWY_RVV_IF_THEN_ZERO_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(no, 0, m, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, merge_vxm, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, fmerge_vfm, _ALL) + +#undef HWY_RVV_IF_THEN_ZERO_ELSE + +// ------------------------------ MaskFromVec +template +HWY_API MFromD> MaskFromVec(const V v) { + return detail::NeS(v, 0); +} + +// ------------------------------ IsNegative (MFromD) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +// Generic for all vector lengths +template +HWY_API MFromD> IsNegative(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + + return detail::LtS(BitCast(di, v), static_cast(0)); +} + +// ------------------------------ MaskFalse + +// For mask ops including vmclr, elements past VL are tail-agnostic and cannot +// be relied upon, so define a variant of the generic_ops-inl implementation of +// MaskFalse that ensures all bits are zero as required by mask_test. +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +template +HWY_API MFromD MaskFalse(D d) { + const DFromV> d_full; + return MaskFromVec(Zero(d_full)); +} + +// ------------------------------ RebindMask +template +HWY_API MFromD RebindMask(const D /*d*/, const MFrom mask) { + // No need to check lane size/LMUL are the same: if not, casting MFrom to + // MFromD would fail. + return mask; +} + +// ------------------------------ VecFromMask + +// Returns mask ? ~0 : 0. No longer use sub.vx(Zero(), 1, mask) because per the +// default mask-agnostic policy, the result of inactive lanes may also be ~0. 
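// Illustrative usage sketch for the mask ops in this section (IsNegative,
// CountTrue, IfThenZeroElse), placed ahead of the VecFromMask definition
// below. Disabled with #if 0; it assumes only the public Highway API
// (hwy/highway.h), and the function name and multiple-of-Lanes assumption
// are hypothetical.
#if 0
#include <stddef.h>
#include <stdint.h>
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Zeroes negative lanes in place and returns how many lanes were negative.
size_t ClampNegativesToZero(int32_t* data, size_t n) {
  const hn::ScalableTag<int32_t> d;
  size_t num_negative = 0;
  for (size_t i = 0; i < n; i += hn::Lanes(d)) {
    const auto v = hn::Load(d, data + i);
    const auto m = hn::IsNegative(v);            // mask: lane < 0
    num_negative += hn::CountTrue(d, m);
    hn::Store(hn::IfThenZeroElse(m, v), d, data + i);  // negative -> 0
  }
  return num_negative;
}
#endif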
+#define HWY_RVV_VEC_FROM_MASK(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_M(MLEN) m) { \ + /* MaskFalse requires we set all lanes for capped d and virtual LMUL. */ \ + const DFromV> d_full; \ + const RebindToSigned di; \ + using TI = TFromD; \ + return BitCast(d_full, __riscv_v##OP##_i##SEW##LMUL(Zero(di), TI{-1}, m, \ + Lanes(d_full))); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_VEC_FROM_MASK, VecFromMask, merge_vxm, _ALL_VIRT) + +#undef HWY_RVV_VEC_FROM_MASK + +template +HWY_API VFromD VecFromMask(const D d, MFromD mask) { + return BitCast(d, VecFromMask(RebindToUnsigned(), mask)); +} + +// ------------------------------ IfVecThenElse (MaskFromVec) +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ BroadcastSignBit +template +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight) * 8 - 1>(v); +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned>(), "Only works for signed/float"); + return IfThenElse(IsNegative(v), yes, no); +} + +// ------------------------------ FindFirstTrue + +#define HWY_RVV_FIND_FIRST_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API intptr_t FindFirstTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return __riscv_vfirst_m_b##MLEN(m, Lanes(d)); \ + } \ + template \ + HWY_API size_t FindKnownFirstTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return static_cast(__riscv_vfirst_m_b##MLEN(m, Lanes(d))); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_FIND_FIRST_TRUE, , _) +#undef HWY_RVV_FIND_FIRST_TRUE + +// ------------------------------ AllFalse +template +HWY_API bool AllFalse(D d, MFromD m) { + return FindFirstTrue(d, m) < 0; +} + +// ------------------------------ AllTrue + +#define HWY_RVV_ALL_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API bool AllTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return AllFalse(d, __riscv_vmnot_m_b##MLEN(m, Lanes(d))); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_ALL_TRUE, _, _) +#undef HWY_RVV_ALL_TRUE + +// ------------------------------ CountTrue + +#define HWY_RVV_COUNT_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t CountTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return __riscv_vcpop_m_b##MLEN(m, Lanes(d)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_COUNT_TRUE, _, _) +#undef HWY_RVV_COUNT_TRUE + +// ------------------------------ PromoteMaskTo + +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +template )), + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return m; +} + +// ------------------------------ DemoteMaskTo + +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif + +template ) - 1), + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return m; +} + +// ================================================== MEMORY + +// ------------------------------ Load + +#define HWY_RVV_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, 
LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(p), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_LOAD, Load, le, _ALL_VIRT) +#undef HWY_RVV_LOAD + +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, Load(du, detail::U16LanePointer(p))); +} + +// ------------------------------ LoadU +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + // RVV only requires element alignment, not vector alignment. + return Load(d, p); +} + +// ------------------------------ MaskedLoad + +#define HWY_RVV_MASKED_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_mu( \ + m, Zero(d), detail::NativeLanePointer(p), Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Or(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_mu( \ + m, v, detail::NativeLanePointer(p), Lanes(d)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_MASKED_LOAD, MaskedLoad, le, _ALL_VIRT) +#undef HWY_RVV_MASKED_LOAD + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, + MaskedLoad(RebindMask(du, m), du, detail::U16LanePointer(p))); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD no, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, MaskedLoadOr(BitCast(du, no), RebindMask(du, m), du, + detail::U16LanePointer(p))); +} + +// ------------------------------ LoadN + +// Native with avl is faster than the generic_ops using FirstN. 
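// Illustrative remainder-loop usage of LoadN/StoreN (disabled with #if 0;
// assumes only the public Highway API from hwy/highway.h, and the function
// and buffer names are hypothetical). With native avl support, the final
// partial vector needs no scalar tail and never reads or writes past the
// end of the buffers.
#if 0
#include <stddef.h>
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// out[i] = 2 * in[i] for i < count, for any count (not just multiples of N).
void DoubleArray(const float* in, float* out, size_t count) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  const auto two = hn::Set(d, 2.0f);
  size_t i = 0;
  for (; i + N <= count; i += N) {
    hn::Store(hn::Mul(hn::Load(d, in + i), two), d, out + i);
  }
  if (i != count) {
    const size_t remaining = count - i;
    const auto v = hn::LoadN(d, in + i, remaining);  // missing lanes are zero
    hn::StoreN(hn::Mul(v, two), d, out + i, remaining);
  }
}
#endif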
+#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +#define HWY_RVV_LOADN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p, size_t num_lanes) { \ + /* Use a tail-undisturbed load in LoadN as the tail-undisturbed load */ \ + /* operation below will leave any lanes past the first */ \ + /* (lowest-indexed) HWY_MIN(num_lanes, Lanes(d)) lanes unchanged */ \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_tu( \ + Zero(d), detail::NativeLanePointer(p), CappedLanes(d, num_lanes)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME##Or( \ + HWY_RVV_V(BASE, SEW, LMUL) no, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p, size_t num_lanes) { \ + /* Use a tail-undisturbed load in LoadNOr as the tail-undisturbed load */ \ + /* operation below will set any lanes past the first */ \ + /* (lowest-indexed) HWY_MIN(num_lanes, Lanes(d)) lanes to the */ \ + /* corresponding lanes in no */ \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_tu( \ + no, detail::NativeLanePointer(p), CappedLanes(d, num_lanes)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_LOADN, LoadN, le, _ALL_VIRT) +#undef HWY_RVV_LOADN + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast(d, LoadN(du, detail::U16LanePointer(p), num_lanes)); +} +template +HWY_API VFromD LoadNOr(VFromD v, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast( + d, LoadNOr(BitCast(du, v), du, detail::U16LanePointer(p), num_lanes)); +} + +// ------------------------------ Store + +#define HWY_RVV_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(p), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_STORE, Store, se, _ALL_VIRT) +#undef HWY_RVV_STORE + +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + Store(BitCast(du, v), du, detail::U16LanePointer(p)); +} + +// ------------------------------ BlendedStore + +#define HWY_RVV_BLENDED_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_m( \ + m, detail::NativeLanePointer(p), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_BLENDED_STORE, BlendedStore, se, _ALL_VIRT) +#undef HWY_RVV_BLENDED_STORE + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + BlendedStore(BitCast(du, v), RebindMask(du, m), du, + detail::U16LanePointer(p)); +} + +// ------------------------------ StoreN + +namespace detail { + +#define HWY_RVV_STOREN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(size_t count, HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + 
detail::NativeLanePointer(p), v, count); \ + } +HWY_RVV_FOREACH(HWY_RVV_STOREN, StoreN, se, _ALL_VIRT) +#undef HWY_RVV_STOREN + +template +HWY_API void StoreN(size_t count, VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + StoreN(count, BitCast(du, v), du, detail::U16LanePointer(p)); +} + +} // namespace detail + +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +template +HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, + size_t max_lanes_to_store) { + // NOTE: Need to clamp max_lanes_to_store to Lanes(d), even if + // MaxLanes(d) >= MaxLanes(DFromV>()) is true, as it is possible for + // detail::StoreN(max_lanes_to_store, v, d, p) to store fewer than + // Lanes(DFromV>()) lanes to p if + // max_lanes_to_store > Lanes(DFromV>()) and + // max_lanes_to_store < 2 * Lanes(DFromV>()) are both true. + + // Also need to make sure that no more than Lanes(d) lanes are stored to p + // if Lanes(d) < Lanes(DFromV>()) is true, which is possible if + // MaxLanes(d) < MaxLanes(DFromV>()) or + // d.Pow2() < DFromV>().Pow2() is true. + detail::StoreN(CappedLanes(d, max_lanes_to_store), v, d, p); +} + +// ------------------------------ StoreU +template +HWY_API void StoreU(const V v, D d, TFromD* HWY_RESTRICT p) { + // RVV only requires element alignment, not vector alignment. + Store(v, d, p); +} + +// ------------------------------ Stream +template +HWY_API void Stream(const V v, D d, T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ ScatterOffset + +#ifdef HWY_NATIVE_SCATTER +#undef HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +#define HWY_RVV_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + const RebindToUnsigned du; \ + return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(base), BitCast(du, offset), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_SCATTER, ScatterOffset, sux, _ALL_VIRT) +#undef HWY_RVV_SCATTER + +// ------------------------------ ScatterIndex +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, + VFromD> indices) { + constexpr size_t kBits = CeilLog2(sizeof(TFromD)); + return ScatterOffset(v, d, base, ShiftLeft(indices)); +} + +// ------------------------------ MaskedScatterIndex + +#define HWY_RVV_MASKED_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) indices) { \ + const RebindToUnsigned du; \ + constexpr size_t kBits = CeilLog2(sizeof(TFromD)); \ + return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL##_m( \ + m, detail::NativeLanePointer(base), \ + ShiftLeft(BitCast(du, indices)), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_MASKED_SCATTER, MaskedScatterIndex, sux, _ALL_VIRT) +#undef HWY_RVV_MASKED_SCATTER + +// ------------------------------ GatherOffset + +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +#define HWY_RVV_GATHER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) 
d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + const RebindToUnsigned du; \ + return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(base), BitCast(du, offset), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_GATHER, GatherOffset, lux, _ALL_VIRT) +#undef HWY_RVV_GATHER + +// ------------------------------ GatherIndex + +template +HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT base, + const VFromD> index) { + constexpr size_t kBits = CeilLog2(sizeof(TFromD)); + return GatherOffset(d, base, ShiftLeft(index)); +} + +// ------------------------------ MaskedGatherIndexOr + +#define HWY_RVV_MASKED_GATHER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) no, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) indices) { \ + const RebindToUnsigned du; \ + const RebindToSigned di; \ + (void)di; /* for HWY_DASSERT */ \ + constexpr size_t kBits = CeilLog2(SEW / 8); \ + HWY_DASSERT(AllFalse(di, Lt(indices, Zero(di)))); \ + return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL##_mu( \ + m, no, detail::NativeLanePointer(base), \ + ShiftLeft(BitCast(du, indices)), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_MASKED_GATHER, MaskedGatherIndexOr, lux, _ALL_VIRT) +#undef HWY_RVV_MASKED_GATHER + +template +HWY_API VFromD MaskedGatherIndex(MFromD m, D d, const TFromD* base, + VFromD> indices) { + return MaskedGatherIndexOr(Zero(d), m, d, base, indices); +} + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +// SEW is for the input. +#define HWY_RVV_PROMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##CHAR##SEWD##LMULD(v, Lanes(d)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_PROMOTE, PromoteTo, zext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_U16(HWY_RVV_PROMOTE, PromoteTo, zext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_PROMOTE, PromoteTo, zext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I08(HWY_RVV_PROMOTE, PromoteTo, sext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I16(HWY_RVV_PROMOTE, PromoteTo, sext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_PROMOTE, PromoteTo, sext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_F32(HWY_RVV_PROMOTE, PromoteTo, fwcvt_f_f_v_, _EXT_VIRT) + +#if HWY_HAVE_FLOAT16 || HWY_RVV_HAVE_F16C + +HWY_RVV_FOREACH_F16_UNCONDITIONAL(HWY_RVV_PROMOTE, PromoteTo, fwcvt_f_f_v_, + _EXT_VIRT) + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif +#endif // HWY_HAVE_FLOAT16 || HWY_RVV_HAVE_F16C + +#undef HWY_RVV_PROMOTE + +// The above X-macro cannot handle 4x promotion nor type switching. +// TODO(janwas): use BASE2 arg to allow the latter. 
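// Illustrative PromoteTo usage (disabled with #if 0; assumes only the public
// Highway API from hwy/highway.h, with hypothetical function and buffer
// names). Callers pass a target tag that is a Rebind of the source type so
// lane counts match; whether the widening is 2x, 4x or 8x is an
// implementation detail of the macros in this section.
#if 0
#include <stddef.h>
#include <stdint.h>
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Converts unsigned bytes to f32, e.g. for pixel processing. Assumes n is a
// multiple of Lanes(df).
void BytesToFloat(const uint8_t* bytes, float* out, size_t n) {
  const hn::ScalableTag<float> df;                  // f32 lanes
  const hn::Rebind<uint8_t, decltype(df)> d8;       // same lane count, u8
  const hn::RebindToUnsigned<decltype(df)> du32;    // same lane count, u32
  for (size_t i = 0; i < n; i += hn::Lanes(df)) {
    const auto v8 = hn::Load(d8, bytes + i);
    const auto v32 = hn::PromoteTo(du32, v8);       // u8 -> u32 (4x widening)
    hn::Store(hn::ConvertTo(df, v32), df, out + i);
  }
}
#endif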
+#define HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, LMUL, LMUL_IN, \ + SHIFT, ADD) \ + template \ + HWY_API HWY_RVV_V(BASE, BITS, LMUL) \ + PromoteTo(HWY_RVV_D(BASE, BITS, N, SHIFT + ADD) d, \ + HWY_RVV_V(BASE_IN, BITS_IN, LMUL_IN) v) { \ + return __riscv_v##OP##CHAR##BITS##LMUL(v, Lanes(d)); \ + } + +#define HWY_RVV_PROMOTE_X2(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -2, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, m1, 0, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m2, 1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m4, 2, 1) + +#define HWY_RVV_PROMOTE_X4(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf4, -2, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, mf2, -1, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m1, 0, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m2, 1, 2) + +#define HWY_RVV_PROMOTE_X4_FROM_U8(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, mf2, mf8, -3, 2) \ + HWY_RVV_PROMOTE_X4(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) + +#define HWY_RVV_PROMOTE_X8(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf8, -3, 3) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, mf4, -2, 3) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, mf2, -1, 3) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m1, 0, 3) + +HWY_RVV_PROMOTE_X8(zext_vf8_, uint, u, 64, uint, 8) +HWY_RVV_PROMOTE_X8(sext_vf8_, int, i, 64, int, 8) + +HWY_RVV_PROMOTE_X4_FROM_U8(zext_vf4_, uint, u, 32, uint, 8) +HWY_RVV_PROMOTE_X4_FROM_U8(sext_vf4_, int, i, 32, int, 8) +HWY_RVV_PROMOTE_X4(zext_vf4_, uint, u, 64, uint, 16) +HWY_RVV_PROMOTE_X4(sext_vf4_, int, i, 64, int, 16) + +// i32 to f64 +HWY_RVV_PROMOTE_X2(fwcvt_f_x_v_, float, f, 64, int, 32) + +// u32 to f64 +HWY_RVV_PROMOTE_X2(fwcvt_f_xu_v_, float, f, 64, uint, 32) + +// f32 to i64 +HWY_RVV_PROMOTE_X2(fwcvt_rtz_x_f_v_, int, i, 64, float, 32) + +// f32 to u64 +HWY_RVV_PROMOTE_X2(fwcvt_rtz_xu_f_v_, uint, u, 64, float, 32) + +#undef HWY_RVV_PROMOTE_X8 +#undef HWY_RVV_PROMOTE_X4_FROM_U8 +#undef HWY_RVV_PROMOTE_X4 +#undef HWY_RVV_PROMOTE_X2 +#undef HWY_RVV_PROMOTE + +// I16->I64 or U16->U64 PromoteTo with virtual LMUL +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + return PromoteTo(ScalableTag(), v); +} + +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + return PromoteTo(ScalableTag(), v); +} + +// Unsigned to signed: cast for unsigned promote. 
+template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + const RebindToSigned di32; + const Rebind du16; + return BitCast(d, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ DemoteTo U + +// SEW is for the source so we can use _DEMOTE_VIRT. +#define HWY_RVV_DEMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##CHAR##SEWH##LMULH( \ + v, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); \ + } + +// Unsigned -> unsigned +HWY_RVV_FOREACH_U16(HWY_RVV_DEMOTE, DemoteTo, nclipu_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_DEMOTE, DemoteTo, nclipu_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_U64(HWY_RVV_DEMOTE, DemoteTo, nclipu_wx_, _DEMOTE_VIRT) + +// SEW is for the source so we can use _DEMOTE_VIRT. +#define HWY_RVV_DEMOTE_I_TO_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(uint, SEWH, LMULH) NAME( \ + HWY_RVV_D(uint, SEWH, N, SHIFT - 1) dn, HWY_RVV_V(int, SEW, LMUL) v) { \ + const HWY_RVV_D(uint, SEW, N, SHIFT) du; \ + /* First clamp negative numbers to zero to match x86 packus. 
*/ \ + return DemoteTo(dn, BitCast(du, detail::MaxS(v, 0))); \ + } +HWY_RVV_FOREACH_I64(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +#undef HWY_RVV_DEMOTE_I_TO_U + +template +HWY_API vuint8mf8_t DemoteTo(Simd d, const vint32mf2_t v) { + return __riscv_vnclipu_wx_u8mf8( + DemoteTo(Simd(), v), 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8mf4_t DemoteTo(Simd d, const vint32m1_t v) { + return __riscv_vnclipu_wx_u8mf4( + DemoteTo(Simd(), v), 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8mf2_t DemoteTo(Simd d, const vint32m2_t v) { + return __riscv_vnclipu_wx_u8mf2( + DemoteTo(Simd(), v), 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8m1_t DemoteTo(Simd d, const vint32m4_t v) { + return __riscv_vnclipu_wx_u8m1( + DemoteTo(Simd(), v), 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8m2_t DemoteTo(Simd d, const vint32m8_t v) { + return __riscv_vnclipu_wx_u8m2( + DemoteTo(Simd(), v), 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} + +template +HWY_API vuint8mf8_t DemoteTo(Simd d, const vuint32mf2_t v) { + return __riscv_vnclipu_wx_u8mf8( + DemoteTo(Simd(), v), 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8mf4_t DemoteTo(Simd d, const vuint32m1_t v) { + return __riscv_vnclipu_wx_u8mf4( + DemoteTo(Simd(), v), 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8mf2_t DemoteTo(Simd d, const vuint32m2_t v) { + return __riscv_vnclipu_wx_u8mf2( + DemoteTo(Simd(), v), 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8m1_t DemoteTo(Simd d, const vuint32m4_t v) { + return __riscv_vnclipu_wx_u8m1( + DemoteTo(Simd(), v), 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8m2_t DemoteTo(Simd d, const vuint32m8_t v) { + return __riscv_vnclipu_wx_u8m2( + DemoteTo(Simd(), v), 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +HWY_API vuint8mf8_t U8FromU32(const vuint32mf2_t v) { + const size_t avl = Lanes(ScalableTag()); + return __riscv_vnclipu_wx_u8mf8( + __riscv_vnclipu_wx_u16mf4(v, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)), + 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} +HWY_API vuint8mf4_t U8FromU32(const vuint32m1_t v) { + const size_t avl = Lanes(ScalableTag()); + return __riscv_vnclipu_wx_u8mf4( + __riscv_vnclipu_wx_u16mf2(v, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)), + 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} +HWY_API vuint8mf2_t U8FromU32(const vuint32m2_t v) { + const size_t avl = Lanes(ScalableTag()); + return __riscv_vnclipu_wx_u8mf2( + __riscv_vnclipu_wx_u16m1(v, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)), + 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} +HWY_API vuint8m1_t U8FromU32(const vuint32m4_t v) { + const size_t avl = Lanes(ScalableTag()); + return __riscv_vnclipu_wx_u8m1( + __riscv_vnclipu_wx_u16m2(v, 0, + 
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)), + 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} +HWY_API vuint8m2_t U8FromU32(const vuint32m8_t v) { + const size_t avl = Lanes(ScalableTag()); + return __riscv_vnclipu_wx_u8m2( + __riscv_vnclipu_wx_u16m4(v, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)), + 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +// ------------------------------ Truncations + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint32mf2_t v2 = __riscv_vnclipu_wx_u32mf2( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + const vuint16mf4_t v3 = __riscv_vnclipu_wx_u16mf4( + v2, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf8(v3, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint32m1_t v2 = __riscv_vnclipu_wx_u32m1( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + const vuint16mf2_t v3 = __riscv_vnclipu_wx_u16mf2( + v2, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf4(v3, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint32m2_t v2 = __riscv_vnclipu_wx_u32m2( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + const vuint16m1_t v3 = __riscv_vnclipu_wx_u16m1( + v2, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf2(v3, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint32m4_t v2 = __riscv_vnclipu_wx_u32m4( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + const vuint16m2_t v3 = __riscv_vnclipu_wx_u16m2( + v2, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8m1(v3, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = __riscv_vand(v, 0xFFFF, avl); + const vuint32mf2_t v2 = __riscv_vnclipu_wx_u32mf2( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u16mf4(v2, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = __riscv_vand(v, 0xFFFF, avl); + const vuint32mf2_t v2 = __riscv_vnclipu_wx_u32mf2( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u16mf4(v2, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = __riscv_vand(v, 0xFFFF, avl); + const vuint32m1_t v2 = __riscv_vnclipu_wx_u32m1( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u16mf2(v2, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = __riscv_vand(v, 0xFFFF, avl); + const vuint32m2_t v2 = __riscv_vnclipu_wx_u32m2( + v1, 0, 
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u16m1(v2, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = __riscv_vand(v, 0xFFFF, avl); + const vuint32m4_t v2 = __riscv_vnclipu_wx_u32m4( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u16m2(v2, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint32mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl); + return __riscv_vnclipu_wx_u32mf2(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint32mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl); + return __riscv_vnclipu_wx_u32mf2(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint32m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl); + return __riscv_vnclipu_wx_u32m1(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint32m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl); + return __riscv_vnclipu_wx_u32m2(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint32m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl); + return __riscv_vnclipu_wx_u32m4(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint16mf4_t v2 = __riscv_vnclipu_wx_u16mf4( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf8(v2, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m1_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint16mf2_t v2 = __riscv_vnclipu_wx_u16mf2( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf4(v2, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m2_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint16m1_t v2 = __riscv_vnclipu_wx_u16m1( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf2(v2, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m4_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint16m2_t v2 = __riscv_vnclipu_wx_u16m2( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8m1(v2, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m8_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint16m4_t v2 = __riscv_vnclipu_wx_u16m4( + v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8m2(v2, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template 
+HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16mf4(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16mf4(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m1_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16mf2(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m2_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16m1(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m4_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16m2(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m8_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16m4(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16mf4_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8mf8(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16mf2_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8mf4(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m1_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8mf2(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m2_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8m1(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m4_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8m2(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m8_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8m4(v1, 0, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +// ------------------------------ DemoteTo I + +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE, DemoteTo, nclip_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE, DemoteTo, nclip_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I64(HWY_RVV_DEMOTE, DemoteTo, nclip_wx_, _DEMOTE_VIRT) + +template +HWY_API vint8mf8_t DemoteTo(Simd d, const vint32mf2_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8mf4_t DemoteTo(Simd d, const vint32m1_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8mf2_t DemoteTo(Simd d, const vint32m2_t v) { + 
return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8m1_t DemoteTo(Simd d, const vint32m4_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8m2_t DemoteTo(Simd d, const vint32m8_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +#undef HWY_RVV_DEMOTE + +// ------------------------------ DemoteTo F + +// SEW is for the source so we can use _DEMOTE_VIRT. +#define HWY_RVV_DEMOTE_F(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##SEWH##LMULH(v, Lanes(d)); \ + } + +#if HWY_HAVE_FLOAT16 || HWY_RVV_HAVE_F16C +HWY_RVV_FOREACH_F32(HWY_RVV_DEMOTE_F, DemoteTo, fncvt_f_f_w_f, _DEMOTE_VIRT) +#endif +HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteTo, fncvt_f_f_w_f, _DEMOTE_VIRT) + +namespace detail { +HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteToF32WithRoundToOdd, + fncvt_rod_f_f_w_f, _DEMOTE_VIRT) +} // namespace detail + +#undef HWY_RVV_DEMOTE_F + +// TODO(janwas): add BASE2 arg to allow generating this via DEMOTE_F. +template +HWY_API vint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return __riscv_vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return __riscv_vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32m1_t DemoteTo(Simd d, const vfloat64m2_t v) { + return __riscv_vfncvt_rtz_x_f_w_i32m1(v, Lanes(d)); +} +template +HWY_API vint32m2_t DemoteTo(Simd d, const vfloat64m4_t v) { + return __riscv_vfncvt_rtz_x_f_w_i32m2(v, Lanes(d)); +} +template +HWY_API vint32m4_t DemoteTo(Simd d, const vfloat64m8_t v) { + return __riscv_vfncvt_rtz_x_f_w_i32m4(v, Lanes(d)); +} + +template +HWY_API vuint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return __riscv_vfncvt_rtz_xu_f_w_u32mf2(v, Lanes(d)); +} +template +HWY_API vuint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return __riscv_vfncvt_rtz_xu_f_w_u32mf2(v, Lanes(d)); +} +template +HWY_API vuint32m1_t DemoteTo(Simd d, const vfloat64m2_t v) { + return __riscv_vfncvt_rtz_xu_f_w_u32m1(v, Lanes(d)); +} +template +HWY_API vuint32m2_t DemoteTo(Simd d, const vfloat64m4_t v) { + return __riscv_vfncvt_rtz_xu_f_w_u32m2(v, Lanes(d)); +} +template +HWY_API vuint32m4_t DemoteTo(Simd d, const vfloat64m8_t v) { + return __riscv_vfncvt_rtz_xu_f_w_u32m4(v, Lanes(d)); +} + +template +HWY_API vfloat32mf2_t DemoteTo(Simd d, const vint64m1_t v) { + return __riscv_vfncvt_f_x_w_f32mf2(v, Lanes(d)); +} +template +HWY_API vfloat32mf2_t DemoteTo(Simd d, const vint64m1_t v) { + return __riscv_vfncvt_f_x_w_f32mf2(v, Lanes(d)); +} +template +HWY_API vfloat32m1_t DemoteTo(Simd d, const vint64m2_t v) { + return __riscv_vfncvt_f_x_w_f32m1(v, Lanes(d)); +} +template +HWY_API vfloat32m2_t DemoteTo(Simd d, const vint64m4_t v) { + return __riscv_vfncvt_f_x_w_f32m2(v, Lanes(d)); +} +template +HWY_API vfloat32m4_t DemoteTo(Simd d, const vint64m8_t v) { + return __riscv_vfncvt_f_x_w_f32m4(v, Lanes(d)); +} + +template +HWY_API vfloat32mf2_t DemoteTo(Simd d, const vuint64m1_t v) { + return __riscv_vfncvt_f_xu_w_f32mf2(v, Lanes(d)); +} +template +HWY_API vfloat32mf2_t DemoteTo(Simd d, const vuint64m1_t v) { + return __riscv_vfncvt_f_xu_w_f32mf2(v, Lanes(d)); +} +template 
+HWY_API vfloat32m1_t DemoteTo(Simd d, const vuint64m2_t v) { + return __riscv_vfncvt_f_xu_w_f32m1(v, Lanes(d)); +} +template +HWY_API vfloat32m2_t DemoteTo(Simd d, const vuint64m4_t v) { + return __riscv_vfncvt_f_xu_w_f32m2(v, Lanes(d)); +} +template +HWY_API vfloat32m4_t DemoteTo(Simd d, const vuint64m8_t v) { + return __riscv_vfncvt_f_xu_w_f32m4(v, Lanes(d)); +} + +// Narrows f32 bits to bf16 using round to even. +// SEW is for the source so we can use _DEMOTE_VIRT. +#ifdef HWY_RVV_AVOID_VXRM +#define HWY_RVV_DEMOTE_16_NEAREST_EVEN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, \ + LMULD, LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + const auto round = \ + detail::AddS(detail::AndS(ShiftRight<16>(v), 1u), 0x7FFFu); \ + v = Add(v, round); \ + /* The default rounding mode appears to be RNU=0, which adds the LSB. */ \ + /* Prevent further rounding by clearing the bits we want to truncate. */ \ + v = detail::AndS(v, 0xFFFF0000u); \ + return __riscv_v##OP##CHAR##SEWH##LMULH(v, 16, Lanes(d)); \ + } + +#else +#define HWY_RVV_DEMOTE_16_NEAREST_EVEN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, \ + LMULD, LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##CHAR##SEWH##LMULH( \ + v, 16, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNE, Lanes(d))); \ + } +#endif // HWY_RVV_AVOID_VXRM +namespace detail { +HWY_RVV_FOREACH_U32(HWY_RVV_DEMOTE_16_NEAREST_EVEN, DemoteTo16NearestEven, + nclipu_wx_, _DEMOTE_VIRT) +} +#undef HWY_RVV_DEMOTE_16_NEAREST_EVEN + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +template +HWY_API VFromD DemoteTo(DBF16 d, VFromD> v) { + const DFromV df; + const RebindToUnsigned du32; + const RebindToUnsigned du16; + // Consider an f32 mantissa with the upper 7 bits set, followed by a 1-bit + // and at least one other bit set. This will round to 0 and increment the + // exponent. If the exponent was already 0xFF (NaN), then the result is -inf; + // there no wraparound because nclipu saturates. Note that in this case, the + // input cannot have been inf because its mantissa bits are zero. To avoid + // converting NaN to inf, we canonicalize the NaN to prevent the rounding. + const decltype(v) canonicalized = + IfThenElse(Eq(v, v), v, BitCast(df, Set(du32, 0x7F800000))); + return BitCast( + d, detail::DemoteTo16NearestEven(du16, BitCast(du32, canonicalized))); +} + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const Rebind df32; + return DemoteTo(df16, detail::DemoteToF32WithRoundToOdd(df32, v)); +} + +// ------------------------------ ConvertTo F + +#define HWY_RVV_CONVERT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(int, SEW, LMUL) v) { \ + return __riscv_vfcvt_f_x_v_f##SEW##LMUL(v, Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(uint, SEW, LMUL) v) { \ + return __riscv_vfcvt_f_xu_v_f##SEW##LMUL(v, Lanes(d)); \ + } \ + /* Truncates (rounds toward zero). 
*/ \ + template \ + HWY_API HWY_RVV_V(int, SEW, LMUL) ConvertTo(HWY_RVV_D(int, SEW, N, SHIFT) d, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_rtz_x_f_v_i##SEW##LMUL(v, Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(uint, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(uint, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_rtz_xu_f_v_u##SEW##LMUL(v, Lanes(d)); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_CONVERT, _, _, _ALL_VIRT) +#undef HWY_RVV_CONVERT + +// Uses default rounding mode. Must be separate because there is no D arg. +#define HWY_RVV_NEAREST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) NearestInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_x_f_v_i##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_NEAREST, _, _, _ALL) +#undef HWY_RVV_NEAREST + +template +HWY_API vint32mf2_t DemoteToNearestInt(Simd d, + const vfloat64m1_t v) { + return __riscv_vfncvt_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32mf2_t DemoteToNearestInt(Simd d, + const vfloat64m1_t v) { + return __riscv_vfncvt_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32m1_t DemoteToNearestInt(Simd d, + const vfloat64m2_t v) { + return __riscv_vfncvt_x_f_w_i32m1(v, Lanes(d)); +} +template +HWY_API vint32m2_t DemoteToNearestInt(Simd d, + const vfloat64m4_t v) { + return __riscv_vfncvt_x_f_w_i32m2(v, Lanes(d)); +} +template +HWY_API vint32m4_t DemoteToNearestInt(Simd d, + const vfloat64m8_t v) { + return __riscv_vfncvt_x_f_w_i32m4(v, Lanes(d)); +} + +// ================================================== COMBINE + +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. +template +HWY_INLINE size_t LanesPerBlock(Simd d) { + // kMinVecBytes is the minimum size of VFromD in bytes + constexpr size_t kMinVecBytes = + ScaleByPower(16, HWY_MAX(HWY_MIN(kPow2, 3), -3)); + // kMinVecLanes is the minimum number of lanes in VFromD + constexpr size_t kMinVecLanes = (kMinVecBytes + sizeof(T) - 1) / sizeof(T); + // kMaxLpb is the maximum number of lanes per block + constexpr size_t kMaxLpb = HWY_MIN(16 / sizeof(T), MaxLanes(d)); + + // If kMaxLpb <= kMinVecLanes is true, then kMaxLpb <= Lanes(d) is true + if (kMaxLpb <= kMinVecLanes) return kMaxLpb; + + // Fractional LMUL: Lanes(d) may be smaller than kMaxLpb, so honor that. 
+ const size_t lanes_per_vec = Lanes(d); + return HWY_MIN(lanes_per_vec, kMaxLpb); +} + +template +HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned>; + return AndS(iota0, static_cast(~(LanesPerBlock(d) - 1))); +} + +template +HWY_INLINE MFromD FirstNPerBlock(D /* tag */) { + const RebindToUnsigned du; + const RebindToSigned di; + using TU = TFromD; + const auto idx_mod = AndS(Iota0(du), static_cast(LanesPerBlock(du) - 1)); + return LtS(BitCast(di, idx_mod), static_cast>(kLanes)); +} + +#define HWY_RVV_SLIDE_UP(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dst, HWY_RVV_V(BASE, SEW, LMUL) src, \ + size_t lanes) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL(dst, src, lanes, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +#define HWY_RVV_SLIDE_DOWN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) src, size_t lanes) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL(src, lanes, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_SLIDE_UP, SlideUp, slideup, _ALL) +HWY_RVV_FOREACH(HWY_RVV_SLIDE_DOWN, SlideDown, slidedown, _ALL) + +#undef HWY_RVV_SLIDE_UP +#undef HWY_RVV_SLIDE_DOWN + +#define HWY_RVV_GET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH( \ + v, kIndex); /* no AVL */ \ + } +#define HWY_RVV_GET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ + HWY_IF_CONSTEXPR(kIndex == 0) { return Trunc(v); } \ + HWY_IF_CONSTEXPR(kIndex != 0) { \ + return Trunc(SlideDown( \ + v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \ + SHIFT - 1){}))); \ + } \ + } +#define HWY_RVV_GET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ + HWY_IF_CONSTEXPR(kIndex == 0) { return v; } \ + HWY_IF_CONSTEXPR(kIndex != 0) { \ + return SlideDown( \ + v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \ + SHIFT){}) / \ + 2); \ + } \ + } +HWY_RVV_FOREACH(HWY_RVV_GET, Get, get, _GET_SET) +HWY_RVV_FOREACH(HWY_RVV_GET_VIRT, Get, get, _GET_SET_VIRT) +HWY_RVV_FOREACH(HWY_RVV_GET_SMALLEST, Get, get, _GET_SET_SMALLEST) +#undef HWY_RVV_GET +#undef HWY_RVV_GET_VIRT +#undef HWY_RVV_GET_SMALLEST + +template +static HWY_INLINE HWY_MAYBE_UNUSED VFromD>> +Get(D d, VFromD v) { + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); + HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) { return Get(v); } + HWY_IF_CONSTEXPR(kIndex != 0 && !detail::IsFull(d)) { + const AdjustSimdTagToMinVecPow2> dh; + const size_t slide_down_amt = + (dh.Pow2() < DFromV().Pow2()) ? 
Lanes(dh) : (Lanes(d) / 2); + return ResizeBitCast(dh, SlideDown(v, slide_down_amt)); + } +} + +#define HWY_RVV_PARTIAL_VEC_SET_HALF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v, \ + size_t half_N) { \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ + const DFromV d; \ + HWY_IF_CONSTEXPR(kIndex == 0) { \ + return __riscv_v##OP##_v_v_##CHAR##SEW##LMUL##_tu(dest, Ext(d, v), \ + half_N); \ + } \ + HWY_IF_CONSTEXPR(kIndex != 0) { return SlideUp(dest, Ext(d, v), half_N); } \ + } +#define HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST( \ + BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v, \ + size_t half_N) { \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ + HWY_IF_CONSTEXPR(kIndex == 0) { \ + return __riscv_v##OP##_v_v_##CHAR##SEW##LMUL##_tu(dest, v, half_N); \ + } \ + HWY_IF_CONSTEXPR(kIndex != 0) { return SlideUp(dest, v, half_N); } \ + } +HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF, PartialVecSetHalf, mv, _GET_SET) +HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF, PartialVecSetHalf, mv, + _GET_SET_VIRT) +HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST, PartialVecSetHalf, mv, + _GET_SET_SMALLEST) +#undef HWY_RVV_PARTIAL_VEC_SET_HALF +#undef HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST + +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) dest, \ + HWY_RVV_V(BASE, SEW, LMULH) v) { \ + HWY_IF_CONSTEXPR(detail::IsFull(d)) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMULH##_##CHAR##SEW##LMUL( \ + dest, kIndex, v); /* no AVL */ \ + } \ + HWY_IF_CONSTEXPR(!detail::IsFull(d)) { \ + const Half dh; \ + return PartialVecSetHalf(dest, v, Lanes(dh)); \ + } \ + } +#define HWY_RVV_SET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) dest, \ + HWY_RVV_V(BASE, SEW, LMULH) v) { \ + const Half dh; \ + return PartialVecSetHalf(dest, v, Lanes(dh)); \ + } +#define HWY_RVV_SET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) dest, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return PartialVecSetHalf(dest, v, Lanes(d) / 2); \ + } +#define HWY_RVV_SET_SMALLEST_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT - 1) d, \ + HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return PartialVecSetHalf(dest, v, Lanes(d) / 2); \ + } +HWY_RVV_FOREACH(HWY_RVV_SET, Set, set, _GET_SET) +HWY_RVV_FOREACH(HWY_RVV_SET_VIRT, Set, set, _GET_SET_VIRT) +HWY_RVV_FOREACH(HWY_RVV_SET_SMALLEST, Set, set, _GET_SET_SMALLEST) +HWY_RVV_FOREACH_UI163264(HWY_RVV_SET_SMALLEST_VIRT, Set, set, _GET_SET_SMALLEST) +HWY_RVV_FOREACH_F(HWY_RVV_SET_SMALLEST_VIRT, Set, set, _GET_SET_SMALLEST) +#undef HWY_RVV_SET +#undef HWY_RVV_SET_VIRT +#undef HWY_RVV_SET_SMALLEST +#undef HWY_RVV_SET_SMALLEST_VIRT + +template +static 
HWY_INLINE HWY_MAYBE_UNUSED VFromD Set( + D d, VFromD dest, VFromD>> v) { + const RebindToUnsigned du; + return BitCast( + d, Set(du, BitCast(du, dest), + BitCast(RebindToUnsigned>(), v))); +} + +} // namespace detail + +// ------------------------------ SlideUpLanes +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { + return detail::SlideUp(Zero(d), v, amt); +} + +// ------------------------------ SlideDownLanes +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + v = detail::SlideDown(v, amt); + // Zero out upper lanes if v is a partial vector + if (MaxLanes(d) < MaxLanes(DFromV())) { + v = detail::SlideUp(v, Zero(d), Lanes(d) - amt); + } + return v; +} + +// ------------------------------ ConcatUpperLower +template +HWY_API V ConcatUpperLower(D d, const V hi, const V lo) { + const auto lo_lower = detail::Get<0>(d, lo); + return detail::Set<0>(d, hi, lo_lower); +} + +// ------------------------------ ConcatLowerLower +template +HWY_API V ConcatLowerLower(D d, const V hi, const V lo) { + const auto hi_lower = detail::Get<0>(d, hi); + return detail::Set<1>(d, lo, hi_lower); +} + +// ------------------------------ ConcatUpperUpper +template +HWY_API V ConcatUpperUpper(D d, const V hi, const V lo) { + const auto lo_upper = detail::Get<1>(d, lo); + return detail::Set<0>(d, hi, lo_upper); +} + +// ------------------------------ ConcatLowerUpper +template +HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) { + const auto lo_upper = detail::Get<1>(d, lo); + const auto hi_lower = detail::Get<0>(d, hi); + return detail::Set<1>(d, ResizeBitCast(d, lo_upper), hi_lower); +} + +// ------------------------------ Combine +template +HWY_API VFromD Combine(D2 d2, const V hi, const V lo) { + return detail::Set<1>(d2, ResizeBitCast(d2, lo), hi); +} + +// ------------------------------ ZeroExtendVector +template +HWY_API VFromD ZeroExtendVector(D2 d2, const V lo) { + return Combine(d2, Xor(lo, lo), lo); +} + +// ------------------------------ Lower/UpperHalf + +namespace detail { + +// RVV may only support LMUL >= SEW/64; returns whether that holds for D. Note +// that SEW = sizeof(T)*8 and LMUL = 1 << d.Pow2(). Add 3 to Pow2 to avoid +// negative shift counts. +template +constexpr bool IsSupportedLMUL(D d) { + return (size_t{1} << (d.Pow2() + 3)) >= sizeof(TFromD); +} + +} // namespace detail + +// If IsSupportedLMUL, just 'truncate' i.e. halve LMUL. +template * = nullptr> +HWY_API VFromD LowerHalf(const DH /* tag */, const VFromD> v) { + return detail::Trunc(v); +} + +// Otherwise, there is no corresponding intrinsic type (e.g. vuint64mf2_t), and +// the hardware may set "vill" if we attempt such an LMUL. However, the V +// extension on application processors requires Zvl128b, i.e. VLEN >= 128, so it +// still makes sense to have half of an SEW=64 vector. We instead just return +// the vector, and rely on the kPow2 in DH to halve the return value of Lanes(). +template * = nullptr> +HWY_API V LowerHalf(const DH /* tag */, const V v) { + return v; +} + +// Same, but without D arg +template +HWY_API VFromD>> LowerHalf(const V v) { + return LowerHalf(Half>(), v); +} + +template +HWY_API VFromD UpperHalf(const DH /*d2*/, const VFromD> v) { + const Twice d; + return detail::Get<1>(d, v); +} + +// ================================================== SWIZZLE + +namespace detail { +// Special instruction for 1 lane is presumably faster? 
+#define HWY_RVV_SLIDE1(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(v, 0, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SLIDE1, Slide1Up, slide1up_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_SLIDE1, Slide1Up, fslide1up_vf, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_SLIDE1, Slide1Down, slide1down_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_SLIDE1, Slide1Down, fslide1down_vf, _ALL) +#undef HWY_RVV_SLIDE1 +} // namespace detail + +// ------------------------------ Slide1Up and Slide1Down +#ifdef HWY_NATIVE_SLIDE1_UP_DOWN +#undef HWY_NATIVE_SLIDE1_UP_DOWN +#else +#define HWY_NATIVE_SLIDE1_UP_DOWN +#endif + +template +HWY_API VFromD Slide1Up(D /*d*/, VFromD v) { + return detail::Slide1Up(v); +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + v = detail::Slide1Down(v); + // Zero out upper lanes if v is a partial vector + if (MaxLanes(d) < MaxLanes(DFromV())) { + v = detail::SlideUp(v, Zero(d), Lanes(d) - 1); + } + return v; +} + +// ------------------------------ GetLane + +#define HWY_RVV_GET_LANE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_T(BASE, SEW) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_s_##CHAR##SEW##LMUL##_##CHAR##SEW(v); /* no AVL */ \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_GET_LANE, GetLane, mv_x, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_GET_LANE, GetLane, fmv_f, _ALL) +#undef HWY_RVV_GET_LANE + +// ------------------------------ ExtractLane +template +HWY_API TFromV ExtractLane(const V v, size_t i) { + return GetLane(detail::SlideDown(v, i)); +} + +// ------------------------------ Additional mask logical operations + +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, SetOnlyFirst, sof) +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, SetBeforeFirst, sbf) +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, SetAtOrBeforeFirst, sif) + +#define HWY_RVV_SET_AT_OR_AFTER_FIRST(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) SetAtOrAfterFirst(HWY_RVV_M(MLEN) m) { \ + return Not(SetBeforeFirst(m)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_SET_AT_OR_AFTER_FIRST, _, _) +#undef HWY_RVV_SET_AT_OR_AFTER_FIRST + +// ------------------------------ InsertLane + +// T template arg because TFromV might not match the hwy::float16_t argument. +template +HWY_API V InsertLane(const V v, size_t i, T t) { + const Rebind> d; + const RebindToUnsigned du; // Iota0 is unsigned only + using TU = TFromD; + const auto is_i = detail::EqS(detail::Iota0(du), static_cast(i)); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +// For 8-bit lanes, Iota0 might overflow. +template +HWY_API V InsertLane(const V v, size_t i, T t) { + const Rebind> d; + const auto zero = Zero(d); + const auto one = Set(d, 1); + const auto ge_i = Eq(detail::SlideUp(zero, one, i), one); + const auto is_i = SetOnlyFirst(ge_i); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +// ------------------------------ OddEven + +namespace detail { + +// Faster version using a wide constant instead of Iota0 + AndS. +template +HWY_INLINE MFromD IsEven(D d) { + const RebindToUnsigned du; + const RepartitionToWide duw; + return RebindMask(d, detail::NeS(BitCast(du, Set(duw, 1)), 0u)); +} + +template +HWY_INLINE MFromD IsEven(D d) { + const RebindToUnsigned du; // Iota0 is unsigned only + return detail::EqS(detail::AndS(detail::Iota0(du), 1), 0); +} + +// Also provide the negated form because there is no native CompressNot. 
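The wide-constant IsEven above, and the negated IsOdd that follows, rely on the lane pattern produced by reinterpreting Set(duw, 1): each wide element contributes a 1 in its low narrow lane and a 0 in its high narrow lane. A scalar sketch of that reinterpretation, assuming a little-endian host as on RISC-V (illustrative only):

#include <cstdint>
#include <cstring>

int main() {
  const uint16_t wide = 1;       // one element of Set(duw, 1)
  uint8_t lanes[2];
  std::memcpy(lanes, &wide, 2);  // BitCast to two u8 lanes
  // lanes[0] (even index) == 1 and lanes[1] (odd index) == 0, so
  // NeS(lane, 0) marks even lanes and EqS(lane, 0) marks odd lanes.
  return (lanes[0] == 1 && lanes[1] == 0) ? 0 : 1;
}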
+template +HWY_INLINE MFromD IsOdd(D d) { + const RebindToUnsigned du; + const RepartitionToWide duw; + return RebindMask(d, detail::EqS(BitCast(du, Set(duw, 1)), 0u)); +} + +template +HWY_INLINE MFromD IsOdd(D d) { + const RebindToUnsigned du; // Iota0 is unsigned only + return detail::NeS(detail::AndS(detail::Iota0(du), 1), 0); +} + +} // namespace detail + +template +HWY_API V OddEven(const V a, const V b) { + return IfThenElse(detail::IsEven(DFromV()), b, a); +} + +// ------------------------------ DupEven (OddEven) +template +HWY_API V DupEven(const V v) { + const V up = detail::Slide1Up(v); + return OddEven(up, v); +} + +// ------------------------------ DupOdd (OddEven) +template +HWY_API V DupOdd(const V v) { + const V down = detail::Slide1Down(v); + return OddEven(v, down); +} + +// ------------------------------ InterleaveEven (OddEven) +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return OddEven(detail::Slide1Up(b), a); +} + +// ------------------------------ InterleaveOdd (OddEven) +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return OddEven(b, detail::Slide1Down(a)); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API V OddEvenBlocks(const V a, const V b) { + const RebindToUnsigned> du; // Iota0 is unsigned only + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromV)); + const auto idx_block = ShiftRight(detail::Iota0(du)); + const auto is_even = detail::EqS(detail::AndS(idx_block, 1), 0); + return IfThenElse(is_even, b, a); +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV d; + const size_t lpb = detail::LanesPerBlock(d); + const V down = detail::SlideDown(v, lpb); + const V up = detail::SlideUp(v, v, lpb); + return OddEvenBlocks(up, down); +} + +// ------------------------------ InterleaveEvenBlocks +// (SlideUpLanes, OddEvenBlocks) + +template > +HWY_API V InterleaveEvenBlocks(D d, V a, V b) { + const size_t lpb = detail::LanesPerBlock(d); + return OddEvenBlocks(SlideUpLanes(d, b, lpb), a); +} + +// ------------------------------ InterleaveOddBlocks +// (SlideDownLanes, OddEvenBlocks) + +template > +HWY_API V InterleaveOddBlocks(D d, V a, V b) { + const size_t lpb = detail::LanesPerBlock(d); + return OddEvenBlocks(b, SlideDownLanes(d, a, lpb)); +} + +// ------------------------------ TableLookupLanes + +template +HWY_API VFromD> IndicesFromVec(D d, VI vec) { + static_assert(sizeof(TFromD) == sizeof(TFromV), "Index != lane"); + const RebindToUnsigned du; // instead of : avoids unused d. + const auto indices = BitCast(du, vec); +#if HWY_IS_DEBUG_BUILD + using TU = TFromD; + const size_t twice_num_of_lanes = Lanes(d) * 2; + HWY_DASSERT(AllTrue( + du, Eq(indices, + detail::AndS(indices, static_cast(twice_num_of_lanes - 1))))); +#endif + return indices; +} + +template +HWY_API VFromD> SetTableIndices(D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + return IndicesFromVec(d, LoadU(Rebind(), idx)); +} + +#define HWY_RVV_TABLE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEW, LMUL) idx) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, idx, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// TableLookupLanes is supported for all types, but beware that indices are +// likely to wrap around for 8-bit lanes. 
When using TableLookupLanes inside +// this file, ensure that it is safe or use TableLookupLanes16 instead. +HWY_RVV_FOREACH(HWY_RVV_TABLE, TableLookupLanes, rgather, _ALL) +#undef HWY_RVV_TABLE + +namespace detail { + +#define HWY_RVV_TABLE16(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEWD, LMULD) idx) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, idx, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI08(HWY_RVV_TABLE16, TableLookupLanes16, rgatherei16, _EXT) +#undef HWY_RVV_TABLE16 + +// Used by Expand. +#define HWY_RVV_MASKED_TABLE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) mask, HWY_RVV_V(BASE, SEW, LMUL) maskedoff, \ + HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEW, LMUL) idx) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_mu(mask, maskedoff, v, idx, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_MASKED_TABLE, MaskedTableLookupLanes, rgather, _ALL) +#undef HWY_RVV_MASKED_TABLE + +#define HWY_RVV_MASKED_TABLE16(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) mask, HWY_RVV_V(BASE, SEW, LMUL) maskedoff, \ + HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEWD, LMULD) idx) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_mu(mask, maskedoff, v, idx, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI08(HWY_RVV_MASKED_TABLE16, MaskedTableLookupLanes16, + rgatherei16, _EXT) +#undef HWY_RVV_MASKED_TABLE16 + +} // namespace detail + +// ------------------------------ Reverse (TableLookupLanes) +template +HWY_API VFromD Reverse(D d, VFromD v) { + const Rebind du16; + const size_t N = Lanes(d); + const auto idx = + detail::ReverseSubS(detail::Iota0(du16), static_cast(N - 1)); + return detail::TableLookupLanes16(v, idx); +} + +template +HWY_API VFromD Reverse(D d, VFromD v) { + const Half dh; + const Rebind du16; + const size_t half_n = Lanes(dh); + const auto idx = detail::ReverseSubS(detail::Iota0(du16), + static_cast(half_n - 1)); + const auto reversed_lo = detail::TableLookupLanes16(LowerHalf(dh, v), idx); + const auto reversed_hi = detail::TableLookupLanes16(UpperHalf(dh, v), idx); + return Combine(d, reversed_lo, reversed_hi); +} + +template +HWY_API VFromD Reverse(D /* tag */, VFromD v) { + const RebindToUnsigned du; + using TU = TFromD; + const size_t N = Lanes(du); + const auto idx = + detail::ReverseSubS(detail::Iota0(du), static_cast(N - 1)); + return TableLookupLanes(v, idx); +} + +// ------------------------------ ResizeBitCast + +// Extends or truncates a vector to match the given d. +namespace detail { + +template +HWY_INLINE VFromD ChangeLMUL(D /* d */, VFromD v) { + return v; +} + +// Sanity check: when calling ChangeLMUL, the caller (ResizeBitCast) already +// BitCast to the same lane type. Note that V may use the native lane type for +// f16, so convert D to that before checking. 
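For orientation, a small usage sketch of ResizeBitCast, whose ChangeLMUL helpers around this point step LMUL up or down until it matches the target tag; the HWY_RVV_IF_SAME_T_DV guard mentioned above follows. This is ordinary static-dispatch Highway usage, and the include path, tag choices and function name are assumptions for illustration, not part of this patch:

#include <cstdint>
#include "third_party/highway/hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Reinterpret the bytes of a u64 vector (LMUL=1) as u8 at half LMUL; the
// result keeps the lower half of the bytes, with no value conversion.
hn::VFromD<hn::ScalableTag<uint8_t, -1>> LowBytesAtHalfLMUL(
    hn::VFromD<hn::ScalableTag<uint64_t>> v) {
  const hn::ScalableTag<uint8_t, -1> d8h;
  return hn::ResizeBitCast(d8h, v);
}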
+#define HWY_RVV_IF_SAME_T_DV(D, V) \ + hwy::EnableIf>, TFromV>()>* = nullptr + +// LMUL of VFromD < LMUL of V: need to truncate v +template >, DFromV().Pow2() - 1)> +HWY_INLINE VFromD ChangeLMUL(D d, V v) { + const DFromV d_from; + const Half dh_from; + static_assert( + DFromV>().Pow2() < DFromV().Pow2(), + "The LMUL of VFromD must be less than the LMUL of V"); + static_assert( + DFromV>().Pow2() <= DFromV>().Pow2(), + "The LMUL of VFromD must be less than or equal to the LMUL of " + "VFromD"); + return ChangeLMUL(d, Trunc(v)); +} + +// LMUL of VFromD > LMUL of V: need to extend v +template >, DFromV().Pow2())> +HWY_INLINE VFromD ChangeLMUL(D d, V v) { + const DFromV d_from; + const Twice dt_from; + static_assert(DFromV>().Pow2() > DFromV().Pow2(), + "The LMUL of VFromD must be greater than " + "the LMUL of V"); + static_assert( + DFromV>().Pow2() >= DFromV>().Pow2(), + "The LMUL of VFromD must be greater than or equal to the LMUL of " + "VFromD"); + return ChangeLMUL(d, Ext(dt_from, v)); +} + +#undef HWY_RVV_IF_SAME_T_DV + +} // namespace detail + +template +HWY_API VFromD ResizeBitCast(DTo /*dto*/, VFrom v) { + const DFromV d_from; + const Repartition du8_from; + const DFromV> d_to; + const Repartition du8_to; + return BitCast(d_to, detail::ChangeLMUL(du8_to, BitCast(du8_from, v))); +} + +// ------------------------------ Reverse2 (RotateRight, OddEven) + +// Per-target flags to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. +#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +// Shifting and adding requires fewer instructions than blending, but casting to +// u32 only works for LMUL in [1/2, 8]. + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const detail::AdjustSimdTagToMinVecPow2> du16; + return ResizeBitCast(d, RotateRight<8>(ResizeBitCast(du16, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const detail::AdjustSimdTagToMinVecPow2> du32; + return ResizeBitCast(d, RotateRight<16>(ResizeBitCast(du32, v))); +} + +// Shifting and adding requires fewer instructions than blending, but casting to +// u64 does not work for LMUL < 1. 
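The rotate form of Reverse2 is easiest to see in scalar terms: two u32 lanes packed into one u64 (little-endian) swap places under a 32-bit rotate, which is what the overload below does through a u64 view. An illustrative sketch, not from the imported sources:

#include <cstdint>

constexpr uint64_t Reverse2Pair(uint64_t pair) {
  return (pair >> 32) | (pair << 32);  // RotateRight<32> on the packed pair
}
static_assert(Reverse2Pair(0x0000000200000001ull) == 0x0000000100000002ull,
              "lanes {1, 2} become {2, 1}");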
+template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast(d, RotateRight<32>(ResizeBitCast(du64, v))); +} + +template , HWY_IF_T_SIZE_D(D, 8)> +HWY_API V Reverse2(D /* tag */, const V v) { + const V up = detail::Slide1Up(v); + const V down = detail::Slide1Down(v); + return OddEven(up, down); +} + +// ------------------------------ Reverse4 (TableLookupLanes) + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const detail::AdjustSimdTagToMinVecPow2> du16; + return ResizeBitCast(d, Reverse2(du16, ResizeBitCast(du16, Reverse2(d, v)))); +} + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorS(detail::Iota0(du), 3); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ Reverse8 (TableLookupLanes) + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const detail::AdjustSimdTagToMinVecPow2> du32; + return ResizeBitCast(d, Reverse2(du32, ResizeBitCast(du32, Reverse4(d, v)))); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorS(detail::Iota0(du), 7); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template > +HWY_API V ReverseBlocks(D d, V v) { + const detail::AdjustSimdTagToMinVecPow2> du64; + const size_t N = Lanes(du64); + const auto rev = + detail::ReverseSubS(detail::Iota0(du64), static_cast(N - 1)); + // Swap lo/hi u64 within each block + const auto idx = detail::XorS(rev, 1); + return ResizeBitCast(d, TableLookupLanes(ResizeBitCast(du64, v), idx)); +} + +// ------------------------------ Compress + +// RVV supports all lane types natively. +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +template +struct CompressIsPartition { + enum { value = 0 }; +}; + +#define HWY_RVV_COMPRESS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) mask) { \ + return __riscv_v##OP##_vm_##CHAR##SEW##LMUL(v, mask, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_COMPRESS, Compress, compress, _ALL) +#undef HWY_RVV_COMPRESS + +// ------------------------------ Expand + +#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +// >= 2-byte lanes: idx lanes will not overflow. +template +HWY_API V Expand(V v, const M mask) { + const DFromV d; + const RebindToUnsigned du; + const auto idx = detail::MaskedIota(du, RebindMask(du, mask)); + const V zero = Zero(d); + return detail::MaskedTableLookupLanes(mask, zero, v, idx); +} + +// 1-byte lanes, LMUL < 8: promote idx to u16. +template , + HWY_IF_POW2_LE_D(D, 2)> +HWY_API V Expand(V v, const M mask) { + const D d; + const Rebind du16; + const auto idx = detail::MaskedIota(du16, RebindMask(du16, mask)); + const V zero = Zero(d); + return detail::MaskedTableLookupLanes16(mask, zero, v, idx); +} + +// 1-byte lanes, max LMUL: unroll 2x. +template , + HWY_IF_POW2_GT_D(DFromV, 2)> +HWY_API V Expand(V v, const M mask) { + const D d; + const Half dh; + const auto v0 = LowerHalf(dh, v); + // TODO(janwas): skip vec<->mask if we can cast masks. 
+ const V vmask = VecFromMask(d, mask); + const auto m0 = MaskFromVec(LowerHalf(dh, vmask)); + + // Cannot just use UpperHalf, must shift by the number of inputs consumed. + const size_t count = CountTrue(dh, m0); + const auto v1 = detail::Trunc(detail::SlideDown(v, count)); + const auto m1 = MaskFromVec(UpperHalf(dh, vmask)); + return Combine(d, Expand(v1, m1), Expand(v0, m0)); +} + +// ------------------------------ LoadExpand +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} + +// ------------------------------ CompressNot +template +HWY_API V CompressNot(V v, const M mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +template +HWY_API V CompressBlocksNot(V v, const M mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(const V v, const M mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(const V v, const M mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, mask); + StoreN(Compress(v, mask), d, unaligned, count); + return count; +} + +// ================================================== COMPARE (2) + +// ------------------------------ FindLastTrue + +template +HWY_API intptr_t FindLastTrue(D d, MFromD m) { + const RebindToSigned di; + const intptr_t fft_rev_idx = + FindFirstTrue(d, MaskFromVec(Reverse(di, VecFromMask(di, m)))); + return (fft_rev_idx >= 0) + ? (static_cast(Lanes(d) - 1) - fft_rev_idx) + : intptr_t{-1}; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD m) { + const RebindToSigned di; + const size_t fft_rev_idx = + FindKnownFirstTrue(d, MaskFromVec(Reverse(di, VecFromMask(di, m)))); + return Lanes(d) - 1 - fft_rev_idx; +} + +// ------------------------------ ConcatOdd (Compress) + +namespace detail { + +#define HWY_RVV_NARROW(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEWD, LMULD) v) { \ + return __riscv_v##OP##_wx_##CHAR##SEW##LMUL(v, kShift, \ + HWY_RVV_AVL(SEWD, SHIFT + 1)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_NARROW, Narrow, nsrl, _EXT) +HWY_RVV_FOREACH_U16(HWY_RVV_NARROW, Narrow, nsrl, _EXT) +HWY_RVV_FOREACH_U32(HWY_RVV_NARROW, Narrow, nsrl, _EXT) +#undef HWY_RVV_NARROW + +} // namespace detail + +// Casting to wider and narrowing is the fastest for < 64-bit lanes. +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + constexpr size_t kBits = sizeof(TFromD) * 8; + const Twice dt; + const RepartitionToWide> dtuw; + const VFromD hl = BitCast(dtuw, Combine(dt, hi, lo)); + return BitCast(d, detail::Narrow(hl)); +} + +// 64-bit: Combine+Compress. +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const Twice dt; + const VFromD hl = Combine(dt, hi, lo); + return LowerHalf(d, Compress(hl, detail::IsOdd(dt))); +} + +// Any type, max LMUL: Compress both, then Combine. 
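In scalar terms, the widen+narrow form of ConcatOdd above works because each wide element of the combined vector holds an (even, odd) pair of narrow lanes, and a narrowing shift by the narrow lane width keeps the odd one; the Compress-based overload for the maximum LMUL follows below. A sketch for u8 lanes on a little-endian host, illustrative only:

#include <cstdint>

constexpr uint8_t OddByteOfPair(uint16_t pair) {
  return static_cast<uint8_t>(pair >> 8);  // narrowing shift right by 8
}
static_assert(OddByteOfPair(0xBBAA) == 0xBB,
              "pair {0xAA even, 0xBB odd} keeps 0xBB");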
+template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const Half dh; + const MFromD is_odd = detail::IsOdd(d); + const VFromD hi_odd = Compress(hi, is_odd); + const VFromD lo_odd = Compress(lo, is_odd); + return Combine(d, LowerHalf(dh, hi_odd), LowerHalf(dh, lo_odd)); +} + +// ------------------------------ ConcatEven (Compress) + +// Casting to wider and narrowing is the fastest for < 64-bit lanes. +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Twice dt; + const RepartitionToWide> dtuw; + const VFromD hl = BitCast(dtuw, Combine(dt, hi, lo)); + return BitCast(d, detail::Narrow<0>(hl)); +} + +// 64-bit: Combine+Compress. +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Twice dt; + const VFromD hl = Combine(dt, hi, lo); + return LowerHalf(d, Compress(hl, detail::IsEven(dt))); +} + +// Any type, max LMUL: Compress both, then Combine. +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Half dh; + const MFromD is_even = detail::IsEven(d); + const VFromD hi_even = Compress(hi, is_even); + const VFromD lo_even = Compress(lo, is_even); + return Combine(d, LowerHalf(dh, hi_even), LowerHalf(dh, lo_even)); +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "third_party/highway/hwy/ops/inside-inl.h" + +// ================================================== BLOCKWISE + +// ------------------------------ CombineShiftRightBytes +template > +HWY_API V CombineShiftRightBytes(const D d, const V hi, V lo) { + const Repartition d8; + const auto hi8 = BitCast(d8, hi); + const auto lo8 = BitCast(d8, lo); + const auto hi_up = detail::SlideUp(hi8, hi8, 16 - kBytes); + const auto lo_down = detail::SlideDown(lo8, kBytes); + const auto is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); + return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); +} + +// ------------------------------ CombineShiftRightLanes +template > +HWY_API V CombineShiftRightLanes(const D d, const V hi, V lo) { + constexpr size_t kLanesUp = 16 / sizeof(TFromV) - kLanes; + const auto hi_up = detail::SlideUp(hi, hi, kLanesUp); + const auto lo_down = detail::SlideDown(lo, kLanes); + const auto is_lo = detail::FirstNPerBlock(d); + return IfThenElse(is_lo, lo_down, hi_up); +} + +// ------------------------------ Shuffle2301 (ShiftLeft) +template +HWY_API V Shuffle2301(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const Repartition du64; + const auto v64 = BitCast(du64, v); + return BitCast(d, Or(ShiftRight<32>(v64), ShiftLeft<32>(v64))); +} + +// ------------------------------ Shuffle2103 +template +HWY_API V Shuffle2103(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<3>(d, v, v); +} + +// ------------------------------ Shuffle0321 +template +HWY_API V Shuffle0321(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<1>(d, v, v); +} + +// ------------------------------ Shuffle1032 +template +HWY_API V Shuffle1032(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<2>(d, v, v); +} + +// ------------------------------ Shuffle01 +template +HWY_API V Shuffle01(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 8, "Defined for 64-bit types"); + return CombineShiftRightLanes<1>(d, v, v); +} + +// ------------------------------ Shuffle0123 +template +HWY_API 
V Shuffle0123(const V v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ TableLookupBytes + +template +HWY_API VI TableLookupBytes(const VT vt, const VI vi) { + const DFromV dt; // T=table, I=index. + const DFromV di; + const Repartition dt8; + const Repartition di8; + // Required for producing half-vectors with table lookups from a full vector. + // If we instead run at the LMUL of the index vector, lookups into the table + // would be truncated. Thus we run at the larger of the two LMULs and truncate + // the result vector to the original index LMUL. + constexpr int kPow2T = dt8.Pow2(); + constexpr int kPow2I = di8.Pow2(); + const Simd dm8; // m=max + const auto vmt = detail::ChangeLMUL(dm8, BitCast(dt8, vt)); + const auto vmi = detail::ChangeLMUL(dm8, BitCast(di8, vi)); + auto offsets = detail::OffsetsOf128BitBlocks(dm8, detail::Iota0(dm8)); + // If the table is shorter, wrap around offsets so they do not reference + // undefined lanes in the newly extended vmt. + if (kPow2T < kPow2I) { + offsets = detail::AndS(offsets, static_cast(Lanes(dt8) - 1)); + } + const auto out = TableLookupLanes(vmt, Add(vmi, offsets)); + return BitCast(di, detail::ChangeLMUL(di8, out)); +} + +template +HWY_API VI TableLookupBytesOr0(const VT vt, const VI idx) { + const DFromV di; + const Repartition di8; + const auto idx8 = BitCast(di8, idx); + const auto lookup = TableLookupBytes(vt, idx8); + return BitCast(di, IfThenZeroElse(detail::LtS(idx8, 0), lookup)); +} + +// ------------------------------ TwoTablesLookupLanes + +// WARNING: 8-bit lanes may lead to unexpected results because idx is the same +// size and may overflow. +template +HWY_API VFromD TwoTablesLookupLanes(D d, VFromD a, VFromD b, + VFromD> idx) { + const Twice dt; + const RebindToUnsigned dt_u; + const auto combined_tbl = Combine(dt, b, a); + const auto combined_idx = Combine(dt_u, idx, idx); + return LowerHalf(d, TableLookupLanes(combined_tbl, combined_idx)); +} + +template +HWY_API VFromD TwoTablesLookupLanes(D d, VFromD a, VFromD b, + VFromD> idx) { + const RebindToUnsigned du; + using TU = TFromD; + + const size_t num_of_lanes = Lanes(d); + const auto idx_mod = detail::AndS(idx, static_cast(num_of_lanes - 1)); + const auto sel_a_mask = Ne(idx, idx_mod); // FALSE if a + + const auto a_lookup_result = TableLookupLanes(a, idx_mod); + return detail::MaskedTableLookupLanes(sel_a_mask, a_lookup_result, b, + idx_mod); +} + +template +HWY_API V TwoTablesLookupLanes(V a, V b, + VFromD>> idx) { + const DFromV d; + return TwoTablesLookupLanes(d, a, b, idx); +} + +// ------------------------------ Broadcast + +// 8-bit requires 16-bit tables. +template , HWY_IF_T_SIZE_D(D, 1), + HWY_IF_POW2_LE_D(D, 2)> +HWY_API V Broadcast(const V v) { + const D d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + + const Rebind du16; + VFromD idx = + detail::OffsetsOf128BitBlocks(d, detail::Iota0(du16)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + return detail::TableLookupLanes16(v, idx); +} + +// 8-bit and max LMUL: split into halves. 
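The index split in TwoTablesLookupLanes above can be summarized with plain integer arithmetic: masking the index with N-1 gives the element within a table, and any difference from the original index means the lookup points past the first table, i.e. into b. An illustrative helper, assuming the per-table lane count is a power of two; the split-into-halves Broadcast overload announced above follows:

#include <cstddef>

struct TwoTableSel {
  bool from_b;
  std::size_t elem;
};
constexpr TwoTableSel SelectFrom(std::size_t idx, std::size_t n) {
  return TwoTableSel{(idx & (n - 1)) != idx, idx & (n - 1)};
}
static_assert(SelectFrom(6, 4).from_b && SelectFrom(6, 4).elem == 2,
              "index 6 with 4 lanes per table selects b[2]");
static_assert(!SelectFrom(1, 4).from_b && SelectFrom(1, 4).elem == 1,
              "index 1 stays within a");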
+template , HWY_IF_T_SIZE_D(D, 1), + HWY_IF_POW2_GT_D(D, 2)> +HWY_API V Broadcast(const V v) { + const D d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + + const Half dh; + using VH = VFromD; + const Rebind du16; + VFromD idx = + detail::OffsetsOf128BitBlocks(d, detail::Iota0(du16)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + const VH lo = detail::TableLookupLanes16(LowerHalf(dh, v), idx); + const VH hi = detail::TableLookupLanes16(UpperHalf(dh, v), idx); + return Combine(d, hi, lo); +} + +template , + HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 2) | (1 << 4) | (1 << 8))> +HWY_API V Broadcast(const V v) { + const D d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + + const RebindToUnsigned du; + auto idx = detail::OffsetsOf128BitBlocks(d, detail::Iota0(du)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + return TableLookupLanes(v, idx); +} + +// ------------------------------ BroadcastLane +#ifdef HWY_NATIVE_BROADCASTLANE +#undef HWY_NATIVE_BROADCASTLANE +#else +#define HWY_NATIVE_BROADCASTLANE +#endif + +namespace detail { + +#define HWY_RVV_BROADCAST_LANE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, size_t idx) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL(v, idx, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_BROADCAST_LANE, BroadcastLane, rgather, _ALL) +#undef HWY_RVV_BROADCAST_LANE + +} // namespace detail + +template +HWY_API V BroadcastLane(V v) { + static_assert(0 <= kLane && kLane < HWY_MAX_LANES_V(V), "Invalid lane"); + return detail::BroadcastLane(v, static_cast(kLane)); +} + +// ------------------------------ InsertBlock +#ifdef HWY_NATIVE_BLK_INSERT_EXTRACT +#undef HWY_NATIVE_BLK_INSERT_EXTRACT +#else +#define HWY_NATIVE_BLK_INSERT_EXTRACT +#endif + +template +HWY_API V InsertBlock(V v, VFromD>> blk_to_insert) { + const DFromV d; + using TU = If<(sizeof(TFromV) == 1 && DFromV().Pow2() >= -2), uint16_t, + MakeUnsigned>>; + using TIdx = If; + + const Repartition du; + const Rebind d_idx; + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + constexpr size_t kMaxLanesPerBlock = 16 / sizeof(TU); + + constexpr size_t kBlkByteOffset = + static_cast(kBlockIdx) * kMaxLanesPerBlock; + const auto vu = BitCast(du, v); + const auto vblk = ResizeBitCast(du, blk_to_insert); + const auto vblk_shifted = detail::SlideUp(vblk, vblk, kBlkByteOffset); + const auto insert_mask = RebindMask( + du, detail::LtS(detail::SubS(detail::Iota0(d_idx), + static_cast(kBlkByteOffset)), + static_cast(kMaxLanesPerBlock))); + + return BitCast(d, IfThenElse(insert_mask, vblk_shifted, vu)); +} + +// ------------------------------ BroadcastBlock +template , -3)> +HWY_API V BroadcastBlock(V v) { + const DFromV d; + const Repartition du8; + const Rebind du16; + + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + + const auto idx = detail::AddS(detail::AndS(detail::Iota0(du16), uint16_t{15}), + static_cast(kBlockIdx * 16)); + return BitCast(d, detail::TableLookupLanes16(BitCast(du8, v), idx)); +} + +template , -3)> +HWY_API V BroadcastBlock(V v) { + const DFromV d; + using TU = If) == 1, uint16_t, MakeUnsigned>>; + const Repartition du; + + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + constexpr size_t kMaxLanesPerBlock = 16 / sizeof(TU); + + const auto idx = detail::AddS( + detail::AndS(detail::Iota0(du), 
static_cast(kMaxLanesPerBlock - 1)), + static_cast(static_cast(kBlockIdx) * kMaxLanesPerBlock)); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ ExtractBlock +template +HWY_API VFromD>> ExtractBlock(V v) { + const DFromV d; + const BlockDFromD d_block; + + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + constexpr size_t kMaxLanesPerBlock = 16 / sizeof(TFromD); + constexpr size_t kBlkByteOffset = + static_cast(kBlockIdx) * kMaxLanesPerBlock; + + return ResizeBitCast(d_block, detail::SlideDown(v, kBlkByteOffset)); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API V ShiftLeftLanes(const D d, const V v) { + const RebindToSigned di; + const RebindToUnsigned du; + using TI = TFromD; + const auto shifted = detail::SlideUp(v, v, kLanes); + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + const auto idx_mod = + detail::AndS(BitCast(di, detail::Iota0(du)), + static_cast(detail::LanesPerBlock(di) - 1)); + const auto clear = detail::LtS(idx_mod, static_cast(kLanes)); + return IfThenZeroElse(clear, shifted); +} + +template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API VFromD ShiftLeftBytes(D d, const VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftLanes(BitCast(d8, v))); +} + +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftRightLanes +template >> +HWY_API V ShiftRightLanes(const Simd d, V v) { + const RebindToSigned di; + const RebindToUnsigned du; + using TI = TFromD; + // For partial vectors, clear upper lanes so we shift in zeros. + if (N <= 16 / sizeof(T)) { + v = detail::SlideUp(v, Zero(d), N); + } + + const auto shifted = detail::SlideDown(v, kLanes); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + const size_t lpb = detail::LanesPerBlock(di); + const auto idx_mod = + detail::AndS(BitCast(di, detail::Iota0(du)), static_cast(lpb - 1)); + const auto keep = detail::LtS(idx_mod, static_cast(lpb - kLanes)); + return IfThenElseZero(keep, shifted); +} + +// ------------------------------ ShiftRightBytes +template > +HWY_API V ShiftRightBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftRightLanes(d8, BitCast(d8, v))); +} + +// ------------------------------ InterleaveWholeLower +#ifdef HWY_NATIVE_INTERLEAVE_WHOLE +#undef HWY_NATIVE_INTERLEAVE_WHOLE +#else +#define HWY_NATIVE_INTERLEAVE_WHOLE +#endif + +namespace detail { +// Returns double-length vector with interleaved lanes. +template +HWY_API VFromD InterleaveWhole(D d, VFromD> a, VFromD> b) { + const RebindToUnsigned du; + using TW = MakeWide>; + const Rebind> dw; + const Half duh; // cast inputs to unsigned so we zero-extend + + const VFromD aw = PromoteTo(dw, BitCast(duh, a)); + const VFromD bw = PromoteTo(dw, BitCast(duh, b)); + return BitCast(d, Or(aw, BitCast(dw, detail::Slide1Up(BitCast(du, bw))))); +} +// 64-bit: cannot PromoteTo, but can Ext. 
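The PromoteTo-based InterleaveWhole overload above amounts to zero-extending a into the even narrow lanes, placing b one narrow lane up into the odd lanes, and OR-ing the two. In scalar form for one u8 pair on a little-endian host (illustrative only; the 64-bit overload noted above follows):

#include <cstdint>

constexpr uint16_t InterleavePair(uint8_t a, uint8_t b) {
  // Zero-extended 'a' occupies the low byte (even lane); 'b' shifted up by one
  // narrow lane occupies the high byte (odd lane); Or merges them.
  return static_cast<uint16_t>(static_cast<uint16_t>(a) |
                               (static_cast<uint16_t>(b) << 8));
}
static_assert(InterleavePair(0x11, 0x22) == 0x2211, "lanes become {a, b}");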
+template +HWY_API VFromD InterleaveWhole(D d, VFromD> a, VFromD> b) { + const RebindToUnsigned du; + const auto idx = ShiftRight<1>(detail::Iota0(du)); + return OddEven(TableLookupLanes(detail::Ext(d, b), idx), + TableLookupLanes(detail::Ext(d, a), idx)); +} +template +HWY_API VFromD InterleaveWhole(D d, VFromD> a, VFromD> b) { + const Half dh; + const Half dq; + const VFromD i0 = + InterleaveWhole(dh, LowerHalf(dq, a), LowerHalf(dq, b)); + const VFromD i1 = + InterleaveWhole(dh, UpperHalf(dq, a), UpperHalf(dq, b)); + return Combine(d, i1, i0); +} + +} // namespace detail + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + const detail::AdjustSimdTagToMinVecPow2> dw; + const RepartitionToNarrow du_src; + + const VFromD aw = + ResizeBitCast(d, PromoteLowerTo(dw, ResizeBitCast(du_src, a))); + const VFromD bw = + ResizeBitCast(d, PromoteLowerTo(dw, ResizeBitCast(du_src, b))); + return Or(aw, detail::Slide1Up(bw)); +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + const auto idx = ShiftRight<1>(detail::Iota0(du)); + return OddEven(TableLookupLanes(b, idx), TableLookupLanes(a, idx)); +} + +// ------------------------------ InterleaveWholeUpper + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + // Use Lanes(d) / 2 instead of Lanes(Half()) as Lanes(Half()) can only + // be called if (d.Pow2() >= -2 && d.Pow2() == DFromV>().Pow2()) is + // true and and as the results of InterleaveWholeUpper are + // implementation-defined if Lanes(d) is less than 2. + const size_t half_N = Lanes(d) / 2; + return InterleaveWholeLower(d, detail::SlideDown(a, half_N), + detail::SlideDown(b, half_N)); +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + // Use Lanes(d) / 2 instead of Lanes(Half()) as Lanes(Half()) can only + // be called if (d.Pow2() >= -2 && d.Pow2() == DFromV>().Pow2()) is + // true and as the results of InterleaveWholeUpper are implementation-defined + // if Lanes(d) is less than 2. + const size_t half_N = Lanes(d) / 2; + const RebindToUnsigned du; + const auto idx = detail::AddS(ShiftRight<1>(detail::Iota0(du)), + static_cast(half_N)); + return OddEven(TableLookupLanes(b, idx), TableLookupLanes(a, idx)); +} + +// ------------------------------ InterleaveLower (InterleaveWholeLower) + +namespace detail { + +// Definitely at least 128 bit: match x86 semantics (independent blocks). Using +// InterleaveWhole and 64-bit Compress avoids 8-bit overflow. +template +HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + const Twice dt; + const RebindToUnsigned dt_u; + const VFromD interleaved = detail::InterleaveWhole(dt, a, b); + // Keep only even 128-bit blocks. This is faster than u64 ConcatEven + // because we only have a single vector. + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromD)); + const VFromD idx_block = + ShiftRight(detail::Iota0(dt_u)); + const MFromD is_even = + detail::EqS(detail::AndS(idx_block, 1), 0); + return BitCast(d, LowerHalf(Compress(BitCast(dt_u, interleaved), is_even))); +} +template +HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) { + const Half dh; + const VFromD i0 = + InterleaveLowerBlocks(dh, LowerHalf(dh, a), LowerHalf(dh, b)); + const VFromD i1 = + InterleaveLowerBlocks(dh, UpperHalf(dh, a), UpperHalf(dh, b)); + return Combine(d, i1, i0); +} + +// As above, for the upper half of blocks. 
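InterleaveLowerBlocks above keeps only the even 128-bit blocks of the doubled interleave result; the block parity of a lane index is just a shift and a mask. An illustrative check for u32 lanes, where a 16-byte block holds four lanes; the odd-block counterpart announced above follows:

#include <cstddef>

constexpr bool IsEvenBlock32(std::size_t lane) {
  return ((lane >> 2) & 1) == 0;  // block index = lane / 4 for u32 lanes
}
static_assert(IsEvenBlock32(3) && !IsEvenBlock32(4) && IsEvenBlock32(8),
              "blocks alternate every four u32 lanes");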
+template +HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + const Twice dt; + const RebindToUnsigned dt_u; + const VFromD interleaved = detail::InterleaveWhole(dt, a, b); + // Keep only odd 128-bit blocks. This is faster than u64 ConcatEven + // because we only have a single vector. + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromD)); + const VFromD idx_block = + ShiftRight(detail::Iota0(dt_u)); + const MFromD is_odd = + detail::EqS(detail::AndS(idx_block, 1), 1); + return BitCast(d, LowerHalf(Compress(BitCast(dt_u, interleaved), is_odd))); +} +template +HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) { + const Half dh; + const VFromD i0 = + InterleaveUpperBlocks(dh, LowerHalf(dh, a), LowerHalf(dh, b)); + const VFromD i1 = + InterleaveUpperBlocks(dh, UpperHalf(dh, a), UpperHalf(dh, b)); + return Combine(d, i1, i0); +} + +// RVV vectors are at least 128 bit when there is no fractional LMUL nor cap. +// Used by functions with per-block behavior such as InterleaveLower. +template +constexpr bool IsGE128(Simd /* d */) { + return N * sizeof(T) >= 16 && kPow2 >= 0; +} + +// Definitely less than 128-bit only if there is a small cap; fractional LMUL +// might not be enough if vectors are large. +template +constexpr bool IsLT128(Simd /* d */) { + return N * sizeof(T) < 16; +} + +} // namespace detail + +#define HWY_RVV_IF_GE128_D(D) hwy::EnableIf* = nullptr +#define HWY_RVV_IF_LT128_D(D) hwy::EnableIf* = nullptr +#define HWY_RVV_IF_CAN128_D(D) \ + hwy::EnableIf* = nullptr + +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + return detail::InterleaveLowerBlocks(d, a, b); +} + +// Single block: interleave without extra Compress. +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return InterleaveWholeLower(d, a, b); +} + +// Could be either; branch at runtime. +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + if (Lanes(d) * sizeof(TFromD) <= 16) { + return InterleaveWholeLower(d, a, b); + } + // Fractional LMUL: use LMUL=1 to ensure we can cast to u64. + const ScalableTag, HWY_MAX(d.Pow2(), 0)> d1; + return ResizeBitCast(d, detail::InterleaveLowerBlocks( + d1, ResizeBitCast(d1, a), ResizeBitCast(d1, b))); +} + +template +HWY_API V InterleaveLower(const V a, const V b) { + return InterleaveLower(DFromV(), a, b); +} + +// ------------------------------ InterleaveUpper (Compress) + +template +HWY_API V InterleaveUpper(D d, const V a, const V b) { + return detail::InterleaveUpperBlocks(d, a, b); +} + +// Single block: interleave without extra Compress. +template +HWY_API V InterleaveUpper(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return InterleaveWholeUpper(d, a, b); +} + +// Could be either; branch at runtime. +template +HWY_API V InterleaveUpper(D d, const V a, const V b) { + if (Lanes(d) * sizeof(TFromD) <= 16) { + return InterleaveWholeUpper(d, a, b); + } + // Fractional LMUL: use LMUL=1 to ensure we can cast to u64. 
+ const ScalableTag, HWY_MAX(d.Pow2(), 0)> d1; + return ResizeBitCast(d, detail::InterleaveUpperBlocks( + d1, ResizeBitCast(d1, a), ResizeBitCast(d1, b))); +} + +// ------------------------------ ZipLower + +template >> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveLower(dn, a, b)); +} + +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} + +// ------------------------------ ZipUpper +template +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveUpper(dn, a, b)); +} + +// ================================================== REDUCE + +// We have ReduceSum, generic_ops-inl.h defines SumOfLanes via Set. +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +// scalar = f(vector, zero_m1) +#define HWY_RVV_REDUCE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_T(BASE, SEW) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_V(BASE, SEW, m1) v0) { \ + return GetLane(__riscv_v##OP##_vs_##CHAR##SEW##LMUL##_##CHAR##SEW##m1( \ + v, v0, Lanes(d))); \ + } + +// detail::RedSum, detail::RedMin, and detail::RedMax is more efficient +// for N=4 I8/U8 reductions on RVV than the default implementations of the +// the N=4 I8/U8 ReduceSum/ReduceMin/ReduceMax operations in generic_ops-inl.h +#undef HWY_IF_REDUCE_D +#define HWY_IF_REDUCE_D(D) hwy::EnableIf* = nullptr + +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +#ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#undef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#else +#define HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#endif + +// ------------------------------ ReduceSum + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_REDUCE, RedSum, redsum, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedSum, fredusum, _ALL_VIRT) +} // namespace detail + +template +HWY_API TFromD ReduceSum(D d, const VFromD v) { + const auto v0 = Zero(ScalableTag>()); // always m1 + return detail::RedSum(d, v, v0); +} + +// ------------------------------ ReduceMin +namespace detail { +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMin, redminu, _ALL_VIRT) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMin, redmin, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMin, fredmin, _ALL_VIRT) +} // namespace detail + +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMin(D d, const VFromD v) { + const ScalableTag d1; // always m1 + return detail::RedMin(d, v, Set(d1, HighestValue())); +} + +// ------------------------------ ReduceMax +namespace detail { +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMax, redmaxu, _ALL_VIRT) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMax, redmax, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMax, fredmax, _ALL_VIRT) +} // namespace detail + +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMax(D d, const VFromD v) { + const ScalableTag d1; // always m1 + return detail::RedMax(d, v, Set(d1, LowestValue())); +} + +#undef HWY_RVV_REDUCE + +// TODO: add MaskedReduceSum/Min/Max + +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD 
MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); +} + +// ================================================== Ops with dependencies + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +// Requires Clang 16+, GCC 14+; otherwise emulated in generic_ops-inl.h. +#if HWY_HAVE_TUPLE + +#define HWY_RVV_GET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##2(HWY_RVV_TUP(BASE, SEW, LMUL, 2) tup) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##x2_##CHAR##SEW##LMUL(tup, \ + kIndex); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##3(HWY_RVV_TUP(BASE, SEW, LMUL, 3) tup) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##x3_##CHAR##SEW##LMUL(tup, \ + kIndex); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##4(HWY_RVV_TUP(BASE, SEW, LMUL, 4) tup) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##x4_##CHAR##SEW##LMUL(tup, \ + kIndex); \ + } + +HWY_RVV_FOREACH(HWY_RVV_GET, Get, get, _LE2) +#undef HWY_RVV_GET + +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 2) NAME##2( \ + HWY_RVV_TUP(BASE, SEW, LMUL, 2) tup, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMUL##x2( \ + tup, kIndex, v); \ + } \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 3) NAME##3( \ + HWY_RVV_TUP(BASE, SEW, LMUL, 3) tup, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMUL##x3( \ + tup, kIndex, v); \ + } \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 4) NAME##4( \ + HWY_RVV_TUP(BASE, SEW, LMUL, 4) tup, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMUL##x4( \ + tup, kIndex, v); \ + } + +HWY_RVV_FOREACH(HWY_RVV_SET, Set, set, _LE2) +#undef HWY_RVV_SET + +// RVV does not provide vcreate, so implement using Set. 
+#define HWY_RVV_CREATE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 2) \ + NAME##2(HWY_RVV_D(BASE, SEW, N, SHIFT) /*d*/, \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1) { \ + HWY_RVV_TUP(BASE, SEW, LMUL, 2) tup{}; \ + tup = Set2<0>(tup, v0); \ + tup = Set2<1>(tup, v1); \ + return tup; \ + } \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 3) NAME##3( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /*d*/, HWY_RVV_V(BASE, SEW, LMUL) v0, \ + HWY_RVV_V(BASE, SEW, LMUL) v1, HWY_RVV_V(BASE, SEW, LMUL) v2) { \ + HWY_RVV_TUP(BASE, SEW, LMUL, 3) tup{}; \ + tup = Set3<0>(tup, v0); \ + tup = Set3<1>(tup, v1); \ + tup = Set3<2>(tup, v2); \ + return tup; \ + } \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 4) \ + NAME##4(HWY_RVV_D(BASE, SEW, N, SHIFT) /*d*/, \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_V(BASE, SEW, LMUL) v3) { \ + HWY_RVV_TUP(BASE, SEW, LMUL, 4) tup{}; \ + tup = Set4<0>(tup, v0); \ + tup = Set4<1>(tup, v1); \ + tup = Set4<2>(tup, v2); \ + tup = Set4<3>(tup, v3); \ + return tup; \ + } + +HWY_RVV_FOREACH(HWY_RVV_CREATE, Create, xx, _LE2_VIRT) +#undef HWY_RVV_CREATE + +template +using Vec2 = decltype(Create2(D(), Zero(D()), Zero(D()))); +template +using Vec3 = decltype(Create3(D(), Zero(D()), Zero(D()), Zero(D()))); +template +using Vec4 = decltype(Create4(D(), Zero(D()), Zero(D()), Zero(D()), Zero(D()))); + +#define HWY_RVV_LOAD2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, \ + HWY_RVV_V(BASE, SEW, LMUL) & v1) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 2) tup = \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x2(unaligned, Lanes(d)); \ + v0 = Get2<0>(tup); \ + v1 = Get2<1>(tup); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD2, LoadInterleaved2, lseg2, _LE2_VIRT) +#undef HWY_RVV_LOAD2 + +// ------------------------------ LoadInterleaved3 + +#define HWY_RVV_LOAD3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, \ + HWY_RVV_V(BASE, SEW, LMUL) & v1, \ + HWY_RVV_V(BASE, SEW, LMUL) & v2) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 3) tup = \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x3(unaligned, Lanes(d)); \ + v0 = Get3<0>(tup); \ + v1 = Get3<1>(tup); \ + v2 = Get3<2>(tup); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. 
+HWY_RVV_FOREACH(HWY_RVV_LOAD3, LoadInterleaved3, lseg3, _LE2_VIRT) +#undef HWY_RVV_LOAD3 + +// ------------------------------ LoadInterleaved4 + +#define HWY_RVV_LOAD4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, HWY_RVV_V(BASE, SEW, LMUL) & v1, \ + HWY_RVV_V(BASE, SEW, LMUL) & v2, HWY_RVV_V(BASE, SEW, LMUL) & v3) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 4) tup = \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x4(unaligned, Lanes(d)); \ + v0 = Get4<0>(tup); \ + v1 = Get4<1>(tup); \ + v2 = Get4<2>(tup); \ + v3 = Get4<3>(tup); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD4, LoadInterleaved4, lseg4, _LE2_VIRT) +#undef HWY_RVV_LOAD4 + +// ------------------------------ StoreInterleaved2 + +#define HWY_RVV_STORE2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v0, \ + HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 2) tup = Create2(d, v0, v1); \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x2(unaligned, tup, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE2, StoreInterleaved2, sseg2, _LE2_VIRT) +#undef HWY_RVV_STORE2 + +// ------------------------------ StoreInterleaved3 + +#define HWY_RVV_STORE3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 3) tup = Create3(d, v0, v1, v2); \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x3(unaligned, tup, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE3, StoreInterleaved3, sseg3, _LE2_VIRT) +#undef HWY_RVV_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_RVV_STORE4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_V(BASE, SEW, LMUL) v3, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 4) tup = Create4(d, v0, v1, v2, v3); \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x4(unaligned, tup, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. 
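Before the remaining StoreInterleaved4 instantiation, a usage sketch of the StoreInterleaved2 wrapper defined above, in static-dispatch Highway style; the function name, include path and buffer layout are assumptions for illustration, not part of this patch:

#include <cstddef>
#include <cstdint>
#include "third_party/highway/hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Interleave two mono sample planes into one packed stereo buffer:
// out = L0 R0 L1 R1 ...
void InterleaveStereo(const int16_t* HWY_RESTRICT left,
                      const int16_t* HWY_RESTRICT right,
                      std::size_t num_samples, int16_t* HWY_RESTRICT out) {
  const hn::ScalableTag<int16_t> d;
  const std::size_t N = hn::Lanes(d);
  std::size_t i = 0;
  for (; i + N <= num_samples; i += N) {
    hn::StoreInterleaved2(hn::LoadU(d, left + i), hn::LoadU(d, right + i), d,
                          out + 2 * i);
  }
  // Remainder samples are left to a scalar tail, omitted here.
}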
+HWY_RVV_FOREACH(HWY_RVV_STORE4, StoreInterleaved4, sseg4, _LE2_VIRT) +#undef HWY_RVV_STORE4 + +#else // !HWY_HAVE_TUPLE + +template , HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + const VFromD A = LoadU(d, unaligned); // v1[1] v0[1] v1[0] v0[0] + const VFromD B = LoadU(d, unaligned + Lanes(d)); + v0 = ConcatEven(d, B, A); + v1 = ConcatOdd(d, B, A); +} + +namespace detail { +#define HWY_RVV_LOAD_STRIDED(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p, size_t stride) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + p, static_cast(stride), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_LOAD_STRIDED, LoadStrided, lse, _ALL_VIRT) +#undef HWY_RVV_LOAD_STRIDED +} // namespace detail + +template , HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + // Offsets are bytes, and this is not documented. + v0 = detail::LoadStrided(d, unaligned + 0, 3 * sizeof(T)); + v1 = detail::LoadStrided(d, unaligned + 1, 3 * sizeof(T)); + v2 = detail::LoadStrided(d, unaligned + 2, 3 * sizeof(T)); +} + +template , HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + // Offsets are bytes, and this is not documented. + v0 = detail::LoadStrided(d, unaligned + 0, 4 * sizeof(T)); + v1 = detail::LoadStrided(d, unaligned + 1, 4 * sizeof(T)); + v2 = detail::LoadStrided(d, unaligned + 2, 4 * sizeof(T)); + v3 = detail::LoadStrided(d, unaligned + 3, 4 * sizeof(T)); +} + +// Not 64-bit / max LMUL: interleave via promote, slide, OddEven. +template , HWY_IF_NOT_T_SIZE_D(D, 8), + HWY_IF_POW2_LE_D(D, 2), HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + const Twice> duw; + const Twice dt; + // Interleave with zero by promoting to wider (unsigned) type. + const VFromD w0 = BitCast(dt, PromoteTo(duw, BitCast(du, v0))); + const VFromD w1 = BitCast(dt, PromoteTo(duw, BitCast(du, v1))); + // OR second vector into the zero-valued lanes (faster than OddEven). 
+ StoreU(Or(w0, detail::Slide1Up(w1)), dt, unaligned); +} + +// Can promote, max LMUL: two half-length +template , HWY_IF_NOT_T_SIZE_D(D, 8), + HWY_IF_POW2_GT_D(D, 2), HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved2(LowerHalf(dh, v0), LowerHalf(dh, v1), d, unaligned); + StoreInterleaved2(UpperHalf(dh, v0), UpperHalf(dh, v1), d, + unaligned + Lanes(d)); +} + +namespace detail { +#define HWY_RVV_STORE_STRIDED(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p, size_t stride) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + p, static_cast(stride), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_STORE_STRIDED, StoreStrided, sse, _ALL_VIRT) +#undef HWY_RVV_STORE_STRIDED +} // namespace detail + +// 64-bit: strided +template , HWY_IF_T_SIZE_D(D, 8), + HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + // Offsets are bytes, and this is not documented. + detail::StoreStrided(v0, d, unaligned + 0, 2 * sizeof(T)); + detail::StoreStrided(v1, d, unaligned + 1, 2 * sizeof(T)); +} + +template , HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + T* HWY_RESTRICT unaligned) { + // Offsets are bytes, and this is not documented. + detail::StoreStrided(v0, d, unaligned + 0, 3 * sizeof(T)); + detail::StoreStrided(v1, d, unaligned + 1, 3 * sizeof(T)); + detail::StoreStrided(v2, d, unaligned + 2, 3 * sizeof(T)); +} + +template , HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, T* HWY_RESTRICT unaligned) { + // Offsets are bytes, and this is not documented. + detail::StoreStrided(v0, d, unaligned + 0, 4 * sizeof(T)); + detail::StoreStrided(v1, d, unaligned + 1, 4 * sizeof(T)); + detail::StoreStrided(v2, d, unaligned + 2, 4 * sizeof(T)); + detail::StoreStrided(v3, d, unaligned + 3, 4 * sizeof(T)); +} + +#endif // HWY_HAVE_TUPLE + +// Rely on generic Load/StoreInterleaved[234] for any emulated types. +// Requires HWY_GENERIC_IF_EMULATED_D mirrors HWY_RVV_IF_EMULATED_D. 
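To round off the interleaved load/store section, a usage sketch of LoadInterleaved3 in ordinary Highway client code (static dispatch; the helper name, include path and RGB layout are illustrative assumptions, not part of this patch):

#include <cstddef>
#include <cstdint>
#include "third_party/highway/hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// De-interleave packed R0 G0 B0 R1 G1 B1 ... into three planes.
void SplitRGB(const uint8_t* HWY_RESTRICT rgb, std::size_t num_pixels,
              uint8_t* HWY_RESTRICT r, uint8_t* HWY_RESTRICT g,
              uint8_t* HWY_RESTRICT b) {
  const hn::ScalableTag<uint8_t> d;
  const std::size_t N = hn::Lanes(d);
  std::size_t i = 0;
  for (; i + N <= num_pixels; i += N) {
    hn::VFromD<decltype(d)> v0, v1, v2;
    hn::LoadInterleaved3(d, rgb + 3 * i, v0, v1, v2);
    hn::StoreU(v0, d, r + i);
    hn::StoreU(v1, d, g + i);
    hn::StoreU(v2, d, b + i);
  }
  // Remainder pixels are handled by a scalar tail, omitted here.
}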
+ +// ------------------------------ Dup128VecFromValues (ResizeBitCast) + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD /*t1*/) { + return Set(d, t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { + const auto even_lanes = Set(d, t0); +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(BitCastScalar(t0) == + BitCastScalar(t1)) && + (BitCastScalar(t0) == BitCastScalar(t1))) { + return even_lanes; + } +#endif + + const auto odd_lanes = Set(d, t1); + return OddEven(odd_lanes, even_lanes); +} + +namespace detail { + +#pragma pack(push, 1) + +template +struct alignas(8) Vec64ValsWrapper { + static_assert(sizeof(T) >= 1, "sizeof(T) >= 1 must be true"); + static_assert(sizeof(T) <= 8, "sizeof(T) <= 8 must be true"); + T vals[8 / sizeof(T)]; +}; + +#pragma pack(pop) + +} // namespace detail + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + const detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast( + d, Dup128VecFromValues( + du64, + BitCastScalar(detail::Vec64ValsWrapper>{ + {t0, t1, t2, t3, t4, t5, t6, t7}}), + BitCastScalar(detail::Vec64ValsWrapper>{ + {t8, t9, t10, t11, t12, t13, t14, t15}}))); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast( + d, Dup128VecFromValues( + du64, + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1, t2, t3}}), + BitCastScalar( + detail::Vec64ValsWrapper>{{t4, t5, t6, t7}}))); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + const detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast( + d, + Dup128VecFromValues(du64, + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1}}), + BitCastScalar( + detail::Vec64ValsWrapper>{{t2, t3}}))); +} + +// ------------------------------ PopulationCount (ShiftRight) + +// Handles LMUL < 2 or capped vectors, which generic_ops-inl cannot. +template , HWY_IF_U8_D(D), + hwy::EnableIf* = nullptr> +HWY_API V PopulationCount(V v) { + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + v = Sub(v, detail::AndS(ShiftRight<1>(v), 0x55)); + v = Add(detail::AndS(ShiftRight<2>(v), 0x33), detail::AndS(v, 0x33)); + return detail::AndS(Add(v, ShiftRight<4>(v)), 0x0F); +} + +// ------------------------------ LoadDup128 + +template +HWY_API VFromD LoadDup128(D d, const TFromD* const HWY_RESTRICT p) { + const RebindToUnsigned du; + + // Make sure that no more than 16 bytes are loaded from p + constexpr int kLoadPow2 = d.Pow2(); + constexpr size_t kMaxLanesToLoad = + HWY_MIN(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD)); + constexpr size_t kLoadN = D::template NewN(); + const Simd, kLoadN, kLoadPow2> d_load; + static_assert(d_load.MaxBytes() <= 16, + "d_load.MaxBytes() <= 16 must be true"); + static_assert((d.MaxBytes() < 16) || (d_load.MaxBytes() == 16), + "d_load.MaxBytes() == 16 must be true if d.MaxBytes() >= 16 is " + "true"); + static_assert((d.MaxBytes() >= 16) || (d_load.MaxBytes() == d.MaxBytes()), + "d_load.MaxBytes() == d.MaxBytes() must be true if " + "d.MaxBytes() < 16 is true"); + + const VFromD loaded = Load(d_load, p); + if (d.MaxBytes() <= 16) return loaded; + + // idx must be unsigned for TableLookupLanes. 
+ using TU = TFromD; + const TU mask = static_cast(detail::LanesPerBlock(d) - 1); + // Broadcast the first block. + const VFromD> idx = detail::AndS(detail::Iota0(du), mask); + // Safe even for 8-bit lanes because indices never exceed 15. + return TableLookupLanes(loaded, idx); +} + +// ------------------------------ LoadMaskBits + +// Support all combinations of T and SHIFT(LMUL) without explicit overloads for +// each. First overload for MLEN=1..64. +namespace detail { + +// Maps D to MLEN (wrapped in SizeTag), such that #mask_bits = VLEN/MLEN. MLEN +// increases with lane size and decreases for increasing LMUL. Cap at 64, the +// largest supported by HWY_RVV_FOREACH_B (and intrinsics), for virtual LMUL +// e.g. vuint16mf8_t: (8*2 << 3) == 128. +template +using MaskTag = hwy::SizeTag), -D().Pow2()))>; + +#define HWY_RVV_LOAD_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_INLINE HWY_RVV_M(MLEN) \ + NAME(hwy::SizeTag /* tag */, const uint8_t* bits, size_t N) { \ + return __riscv_v##OP##_v_b##MLEN(bits, N); \ + } +HWY_RVV_FOREACH_B(HWY_RVV_LOAD_MASK_BITS, LoadMaskBits, lm) +#undef HWY_RVV_LOAD_MASK_BITS +} // namespace detail + +template > +HWY_API auto LoadMaskBits(D d, const uint8_t* bits) + -> decltype(detail::LoadMaskBits(MT(), bits, Lanes(d))) { + return detail::LoadMaskBits(MT(), bits, Lanes(d)); +} + +// ------------------------------ StoreMaskBits +#define HWY_RVV_STORE_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(D d, HWY_RVV_M(MLEN) m, uint8_t* bits) { \ + const size_t N = Lanes(d); \ + __riscv_v##OP##_v_b##MLEN(bits, m, N); \ + /* Non-full byte, need to clear the undefined upper bits. */ \ + /* Use MaxLanes and sizeof(T) to move some checks to compile-time. */ \ + constexpr bool kLessThan8 = \ + detail::ScaleByPower(16 / sizeof(TFromD), d.Pow2()) < 8; \ + if (MaxLanes(d) < 8 || (kLessThan8 && N < 8)) { \ + const int mask = (1 << N) - 1; \ + bits[0] = static_cast(bits[0] & mask); \ + } \ + return (N + 7) / 8; \ + } +HWY_RVV_FOREACH_B(HWY_RVV_STORE_MASK_BITS, StoreMaskBits, sm) +#undef HWY_RVV_STORE_MASK_BITS + +// ------------------------------ CompressBits, CompressBitsStore (LoadMaskBits) + +template +HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV(), bits)); +} + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ FirstN (Iota0, Lt, RebindMask, SlideUp) + +// NOTE: do not use this as a building block within rvv-inl - it is likely more +// efficient to use avl or detail::SlideUp. + +// Disallow for 8-bit because Iota is likely to overflow. +template +HWY_API MFromD FirstN(const D d, const size_t n) { + const RebindToUnsigned du; + using TU = TFromD; + return RebindMask(d, detail::LtS(detail::Iota0(du), static_cast(n))); +} + +template +HWY_API MFromD FirstN(const D d, const size_t n) { + const auto zero = Zero(d); + const auto one = Set(d, 1); + return Eq(detail::SlideUp(one, zero, n), one); +} + +// ------------------------------ LowerHalfOfMask/UpperHalfOfMask + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +// Target-specific implementations of LowerHalfOfMask, UpperHalfOfMask, +// CombineMasks, OrderedDemote2MasksTo, and Dup128MaskFromMaskBits are possible +// on RVV if the __riscv_vreinterpret_v_b*_u8m1 and +// __riscv_vreinterpret_v_u8m1_b* intrinsics are available. 
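// For illustration (not from the header): after reinterpreting a vboolN_t mask
// as vuint8m1_t, the mask bit for lane i lives in bit (i & 7) of byte (i >> 3).
// Whole-lane mask shifts therefore decompose into a bit shift plus a byte
// slide on that u8 vector, which is how UpperHalfOfMask and CombineMasks below
// move masks by half-vector amounts.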
+ +// The __riscv_vreinterpret_v_b*_u8m1 and __riscv_vreinterpret_v_u8m1_b* +// intrinsics available with Clang 17 and later and GCC 14 and later. + +namespace detail { + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool1_t m) { + return __riscv_vreinterpret_v_b1_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool2_t m) { + return __riscv_vreinterpret_v_b2_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool4_t m) { + return __riscv_vreinterpret_v_b4_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool8_t m) { + return __riscv_vreinterpret_v_b8_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool16_t m) { + return __riscv_vreinterpret_v_b16_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool32_t m) { + return __riscv_vreinterpret_v_b32_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool64_t m) { + return __riscv_vreinterpret_v_b64_u8m1(m); +} + +template , vbool1_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b1(v); +} + +template , vbool2_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b2(v); +} + +template , vbool4_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b4(v); +} + +template , vbool8_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b8(v); +} + +template , vbool16_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b16(v); +} + +template , vbool32_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b32(v); +} + +template , vbool64_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b64(v); +} + +} // namespace detail + +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template +HWY_API MFromD LowerHalfOfMask(D d, MFromD> m) { + return detail::U8MaskBitsVecToMask(d, detail::MaskToU8MaskBitsVec(m)); +} + +#ifdef HWY_NATIVE_UPPER_HALF_OF_MASK +#undef HWY_NATIVE_UPPER_HALF_OF_MASK +#else +#define HWY_NATIVE_UPPER_HALF_OF_MASK +#endif + +template +HWY_API MFromD UpperHalfOfMask(D d, MFromD> m) { + const size_t N = Lanes(d); + + vuint8m1_t mask_bits = detail::MaskToU8MaskBitsVec(m); + mask_bits = ShiftRightSame(mask_bits, static_cast(N & 7)); + if (HWY_MAX_LANES_D(D) >= 8) { + mask_bits = SlideDownLanes(ScalableTag(), mask_bits, N / 8); + } + + return detail::U8MaskBitsVecToMask(d, mask_bits); +} + +// ------------------------------ CombineMasks + +#ifdef HWY_NATIVE_COMBINE_MASKS +#undef HWY_NATIVE_COMBINE_MASKS +#else +#define HWY_NATIVE_COMBINE_MASKS +#endif + +template +HWY_API MFromD CombineMasks(D d, MFromD> hi, MFromD> lo) { + const Half dh; + const size_t half_N = Lanes(dh); + + const auto ext_lo_mask = + And(detail::U8MaskBitsVecToMask(d, detail::MaskToU8MaskBitsVec(lo)), + FirstN(d, half_N)); + vuint8m1_t hi_mask_bits = detail::MaskToU8MaskBitsVec(hi); + hi_mask_bits = ShiftLeftSame(hi_mask_bits, static_cast(half_N & 7)); + if (HWY_MAX_LANES_D(D) >= 8) { + hi_mask_bits = + SlideUpLanes(ScalableTag(), hi_mask_bits, half_N / 8); + } + + return Or(ext_lo_mask, detail::U8MaskBitsVecToMask(d, hi_mask_bits)); +} + +// ------------------------------ 
OrderedDemote2MasksTo + +#ifdef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#undef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#else +#define HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#endif + +template ) / 2), + class DTo_2 = Repartition, DFrom>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD OrderedDemote2MasksTo(DTo d_to, DFrom /*d_from*/, + MFromD a, MFromD b) { + return CombineMasks(d_to, b, a); +} + +#endif // HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +// ------------------------------ Dup128MaskFromMaskBits + +namespace detail { +// Even though this is only used after checking if (kN < X), this helper +// function prevents "shift count exceeded" errors. +template +constexpr unsigned MaxMaskBits() { + return (1u << kN) - 1; +} +template +constexpr unsigned MaxMaskBits() { + return ~0u; +} + +template +constexpr int SufficientPow2ForMask() { + return HWY_MAX( + D().Pow2() - 3 - static_cast(FloorLog2(sizeof(TFromD))), -3); +} +} // namespace detail + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + return detail::U8MaskBitsVecToMask( + d, Set(ScalableTag(), static_cast(mask_bits))); +#else + const RebindToUnsigned du8; + const detail::AdjustSimdTagToMinVecPow2> + du64; + + const auto bytes = ResizeBitCast( + du8, detail::AndS( + ResizeBitCast(du64, Set(du8, static_cast(mask_bits))), + uint64_t{0x8040201008040201u})); + return detail::NeS(bytes, uint8_t{0}); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + const ScalableTag()> du16; + // There are exactly 16 mask bits for 128 vector bits of 8-bit lanes. + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL( + ScalableTag(), + BitCast(du8, Set(du16, static_cast(mask_bits))))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. + const RebindToUnsigned du8; + const Repartition du16; + const detail::AdjustSimdTagToMinVecPow2> + du64; + + // Replicate the lower 16 bits of mask_bits to each u16 lane of a u16 vector, + // and then bitcast the replicated mask_bits to a u8 vector + const auto bytes = BitCast(du8, Set(du16, static_cast(mask_bits))); + // Replicate bytes 8x such that each byte contains the bit that governs it. + const auto rep8 = TableLookupLanes(bytes, ShiftRight<3>(detail::Iota0(du8))); + + const auto masked_out_rep8 = ResizeBitCast( + du8, + detail::AndS(ResizeBitCast(du64, rep8), uint64_t{0x8040201008040201u})); + return detail::NeS(masked_out_rep8, uint8_t{0}); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + // There are exactly 8 mask bits for 128 vector bits of 16-bit lanes. + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), + Set(du8, static_cast(mask_bits)))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. 
+ const RebindToUnsigned du; + const VFromD bits = + Shl(Set(du, uint16_t{1}), detail::AndS(detail::Iota0(du), 7)); + return TestBit(Set(du, static_cast(mask_bits)), bits); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 4) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), + Set(du8, static_cast(mask_bits * 0x11)))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. + const RebindToUnsigned du; + const VFromD bits = Dup128VecFromValues(du, 1, 2, 4, 8); + return TestBit(Set(du, static_cast(mask_bits)), bits); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 2) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), + Set(du8, static_cast(mask_bits * 0x55)))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. + const RebindToUnsigned du; + const VFromD bits = Dup128VecFromValues(du, 1, 2); + return TestBit(Set(du, static_cast(mask_bits)), bits); +#endif +} + +// ------------------------------ Abs (Max, Neg) + +template +HWY_API V Abs(const V v) { + return Max(v, Neg(v)); +} + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Abs, fsgnjx, _ALL) + +#undef HWY_RVV_RETV_ARGV2 + +// ------------------------------ AbsDiff (Abs, Sub) +template +HWY_API V AbsDiff(const V a, const V b) { + return Abs(Sub(a, b)); +} + +// ------------------------------ Round (NearestInt, ConvertTo, CopySign) + +// IEEE-754 roundToIntegralTiesToEven returns floating-point, but we do not have +// a dedicated instruction for that. Rounding to integer and converting back to +// float is correct except when the input magnitude is large, in which case the +// input was already an integer (because mantissa >> exponent is zero). 
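// A rough scalar model of that strategy (illustrative only; assumes float and
// <cmath>, neither of which the header itself relies on):
//
//   float RoundRef(float v) {
//     if (!(fabsf(v) < 8388608.0f)) return v;  // |v| >= 2^23 or NaN: integral
//     return copysignf(nearbyintf(v), v);      // round, keep the sign of +/-0
//   }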
+ +namespace detail { +enum RoundingModes { kNear, kTrunc, kDown, kUp }; + +template +HWY_INLINE auto UseInt(const V v) -> decltype(MaskFromVec(v)) { + return detail::LtS(Abs(v), MantissaEnd>()); +} + +} // namespace detail + +template +HWY_API V Round(const V v) { + const DFromV df; + + const auto integer = NearestInt(v); // round using current mode + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Trunc (ConvertTo) +template +HWY_API V Trunc(const V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Ceil +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) +namespace detail { +#define HWY_RVV_CEIL_INT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) CeilInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_x_f_v_i##SEW##LMUL##_rm(v, __RISCV_FRM_RUP, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_CEIL_INT, _, _, _ALL) +#undef HWY_RVV_CEIL_INT + +} // namespace detail + +template +HWY_API V Ceil(const V v) { + const DFromV df; + + const auto integer = detail::CeilInt(v); + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +#else // GCC 13 or earlier or Clang 16 or earlier + +template +HWY_API V Ceil(const V v) { + const DFromV df; + const RebindToSigned di; + + using T = TFromD; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto pos1 = + IfThenElseZero(Lt(int_f, v), Set(df, ConvertScalarTo(1.0))); + + return IfThenElse(detail::UseInt(v), Add(int_f, pos1), v); +} + +#endif // (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || + // (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) + +// ------------------------------ Floor +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) +namespace detail { +#define HWY_RVV_FLOOR_INT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) FloorInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_x_f_v_i##SEW##LMUL##_rm(v, __RISCV_FRM_RDN, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_FLOOR_INT, _, _, _ALL) +#undef HWY_RVV_FLOOR_INT + +} // namespace detail + +template +HWY_API V Floor(const V v) { + const DFromV df; + + const auto integer = detail::FloorInt(v); + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +#else // GCC 13 or earlier or Clang 16 or earlier + +template +HWY_API V Floor(const V v) { + const DFromV df; + const RebindToSigned di; + + using T = TFromD; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. 
+ const auto neg1 = + IfThenElseZero(Gt(int_f, v), Set(df, ConvertScalarTo(-1.0))); + + return IfThenElse(detail::UseInt(v), Add(int_f, neg1), v); +} + +#endif // (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || + // (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) + +// ------------------------------ Floating-point classification (Ne) + +// vfclass does not help because it would require 3 instructions (to AND and +// then compare the bits), whereas these are just 1-3 integer instructions. + +template +HWY_API MFromD> IsNaN(const V v) { + return Ne(v, v); +} + +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. +// We use a fused Set/comparison for IsFinite. +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +template > +HWY_API MFromD IsInf(const V v) { + const D d; + const RebindToSigned di; + using T = TFromD; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, detail::EqS(Add(vi, vi), hwy::MaxExponentTimes2())); +} + +// Returns whether normal/subnormal/zero. +template > +HWY_API MFromD IsFinite(const V v) { + const D d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + using T = TFromD; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, detail::LtS(exp, hwy::MaxExponentField())); +} + +// ------------------------------ Iota (ConvertTo) + +template +HWY_API VFromD Iota(const D d, T2 first) { + return detail::AddS(detail::Iota0(d), static_cast>(first)); +} + +template +HWY_API VFromD Iota(const D d, T2 first) { + const RebindToUnsigned du; + return detail::AddS(BitCast(d, detail::Iota0(du)), + static_cast>(first)); +} + +template +HWY_API VFromD Iota(const D d, T2 first) { + const RebindToUnsigned du; + const RebindToSigned di; + return detail::AddS(ConvertTo(d, BitCast(di, detail::Iota0(du))), + ConvertScalarTo>(first)); +} + +// ------------------------------ BitShuffle (PromoteTo, Rol, SumsOf8) + +// Native implementation required to avoid 8-bit wraparound on long vectors. +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE +#endif + +// Cannot handle LMUL=8 because we promote indices. +template ), class D64 = DFromV, + HWY_IF_UI64_D(D64), HWY_IF_POW2_LE_D(D64, 2)> +HWY_API V64 BitShuffle(V64 values, VI idx) { + const RebindToUnsigned du64; + const Repartition du8; + const Rebind du16; + using VU8 = VFromD; + using VU16 = VFromD; + // For each 16-bit (to avoid wraparound for long vectors) index of an output + // byte: offset of the u64 lane to which it belongs. + const VU16 byte_offsets = + detail::AndS(detail::Iota0(du16), static_cast(~7u)); + // idx is for a bit; shifting makes that bytes. Promote so we can add + // byte_offsets, then we have the u8 lane index within the whole vector. + const VU16 idx16 = + Add(byte_offsets, PromoteTo(du16, ShiftRight<3>(BitCast(du8, idx)))); + const VU8 bytes = detail::TableLookupLanes16(BitCast(du8, values), idx16); + + // We want to shift right by idx & 7 to extract the desired bit in `bytes`, + // and left by iota & 7 to put it in the correct output bit. 
To correctly + // handle shift counts from -7 to 7, we rotate (unfortunately not natively + // supported on RVV). + const VU8 rotate_left_bits = Sub(detail::Iota0(du8), BitCast(du8, idx)); + const VU8 extracted_bits_mask = + BitCast(du8, Set(du64, static_cast(0x8040201008040201u))); + const VU8 extracted_bits = + And(Rol(bytes, rotate_left_bits), extracted_bits_mask); + // Combine bit-sliced (one bit per byte) into one 64-bit sum. + return BitCast(D64(), SumsOf8(extracted_bits)); +} + +template ), class D64 = DFromV, + HWY_IF_UI64_D(D64), HWY_IF_POW2_GT_D(D64, 2)> +HWY_API V64 BitShuffle(V64 values, VI idx) { + const Half dh; + const Half> dih; + using V64H = VFromD; + const V64H r0 = BitShuffle(LowerHalf(dh, values), LowerHalf(dih, idx)); + const V64H r1 = BitShuffle(UpperHalf(dh, values), UpperHalf(dih, idx)); + return Combine(D64(), r1, r0); +} + +// ------------------------------ MulEven/Odd (Mul, OddEven) + +template , class DW = RepartitionToWide> +HWY_API VFromD MulEven(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return BitCast(DW(), OddEven(detail::Slide1Up(hi), lo)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD MulOdd(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return BitCast(DW(), OddEven(hi, detail::Slide1Down(lo))); +} + +// There is no 64x64 vwmul. +template +HWY_INLINE V MulEven(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return OddEven(detail::Slide1Up(hi), lo); +} + +template +HWY_INLINE V MulOdd(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return OddEven(hi, detail::Slide1Down(lo)); +} + +// ------------------------------ ReorderDemote2To (OddEven, Combine) + +template +HWY_API VFromD ReorderDemote2To(D dbf16, VFromD> a, + VFromD> b) { + const RebindToUnsigned du16; + const Half du16_half; + const RebindToUnsigned> du32; + const VFromD a_in_even = PromoteTo( + du32, detail::DemoteTo16NearestEven(du16_half, BitCast(du32, a))); + const VFromD b_in_even = PromoteTo( + du32, detail::DemoteTo16NearestEven(du16_half, BitCast(du32, b))); + // Equivalent to InterleaveEven, but because the upper 16 bits are zero, we + // can OR instead of OddEven. + const VFromD a_in_odd = + detail::Slide1Up(BitCast(du16, a_in_even)); + return BitCast(dbf16, Or(a_in_odd, BitCast(du16, b_in_even))); +} + +// If LMUL is not the max, Combine first to avoid another DemoteTo. +template ), + HWY_IF_POW2_LE_D(DN, 2), class V, HWY_IF_SIGNED_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Rebind, DN> dt; + const VFromD ab = Combine(dt, b, a); + return DemoteTo(dn, ab); +} + +template ) * 2), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Rebind, DN> dt; + const VFromD ab = Combine(dt, b, a); + return DemoteTo(dn, ab); +} + +// Max LMUL: must DemoteTo first, then Combine. 
+template ), + HWY_IF_POW2_GT_D(DN, 2), class V, HWY_IF_SIGNED_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Half dnh; + const VFromD demoted_a = DemoteTo(dnh, a); + const VFromD demoted_b = DemoteTo(dnh, b); + return Combine(dn, demoted_b, demoted_a); +} + +template ) * 2), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Half dnh; + const VFromD demoted_a = DemoteTo(dnh, a); + const VFromD demoted_b = DemoteTo(dnh, b); + return Combine(dn, demoted_b, demoted_a); +} + +// If LMUL is not the max, Combine first to avoid another DemoteTo. +template ), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { + const Rebind, DN> dt; + const VFromD ab = Combine(dt, b, a); + return DemoteTo(dn, ab); +} + +// Max LMUL: must DemoteTo first, then Combine. +template ), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { + const Half dnh; + const RebindToUnsigned dn_u; + const RebindToUnsigned dnh_u; + const auto demoted_a = BitCast(dnh_u, DemoteTo(dnh, a)); + const auto demoted_b = BitCast(dnh_u, DemoteTo(dnh, b)); + return BitCast(dn, Combine(dn_u, demoted_b, demoted_a)); +} + +template ), class V, + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { + return ReorderDemote2To(dn, a, b); +} + +// ------------------------------ WidenMulPairwiseAdd + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + const VFromD ae = PromoteEvenTo(df, a); + const VFromD be = PromoteEvenTo(df, b); + const VFromD ao = PromoteOddTo(df, a); + const VFromD bo = PromoteOddTo(df, b); + return MulAdd(ae, be, Mul(ao, bo)); +} + +template >> +HWY_API VFromD WidenMulPairwiseAdd(D d32, V16 a, V16 b) { + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), + Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +namespace detail { + +#define HWY_RVV_WIDEN_MACC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEWD, LMULD) sum, \ + HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##CHAR##SEWD##LMULD(sum, a, b, Lanes(d)); \ + } + +HWY_RVV_FOREACH_I16(HWY_RVV_WIDEN_MACC, WidenMulAcc, wmacc_vv_, _EXT_VIRT) +HWY_RVV_FOREACH_U16(HWY_RVV_WIDEN_MACC, WidenMulAcc, wmaccu_vv_, _EXT_VIRT) +#undef HWY_RVV_WIDEN_MACC + +// If LMUL is not the max, we can WidenMul first (3 instructions). +template , + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateI16(D32 d32, VFromD a, + VFromD b, const V32 sum0, + V32& sum1) { + const Twice d32t; + using V32T = VFromD; + V32T sum = Combine(d32t, sum1, sum0); + sum = detail::WidenMulAcc(d32t, sum, a, b); + sum1 = UpperHalf(d32, sum); + return LowerHalf(d32, sum); +} + +// Max LMUL: must LowerHalf first (4 instructions). 
+template , + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateI16(D32 d32, VFromD a, + VFromD b, const V32 sum0, + V32& sum1) { + const Half d16h; + using V16H = VFromD; + const V16H a0 = LowerHalf(d16h, a); + const V16H a1 = UpperHalf(d16h, a); + const V16H b0 = LowerHalf(d16h, b); + const V16H b1 = UpperHalf(d16h, b); + sum1 = detail::WidenMulAcc(d32, sum1, a1, b1); + return detail::WidenMulAcc(d32, sum0, a0, b0); +} + +// If LMUL is not the max, we can WidenMul first (3 instructions). +template , + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateU16(D32 d32, VFromD a, + VFromD b, const V32 sum0, + V32& sum1) { + const Twice d32t; + using V32T = VFromD; + V32T sum = Combine(d32t, sum1, sum0); + sum = detail::WidenMulAcc(d32t, sum, a, b); + sum1 = UpperHalf(d32, sum); + return LowerHalf(d32, sum); +} + +// Max LMUL: must LowerHalf first (4 instructions). +template , + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateU16(D32 d32, VFromD a, + VFromD b, const V32 sum0, + V32& sum1) { + const Half d16h; + using V16H = VFromD; + const V16H a0 = LowerHalf(d16h, a); + const V16H a1 = UpperHalf(d16h, a); + const V16H b0 = LowerHalf(d16h, b); + const V16H b1 = UpperHalf(d16h, b); + sum1 = detail::WidenMulAcc(d32, sum1, a1, b1); + return detail::WidenMulAcc(d32, sum0, a0, b0); +} + +} // namespace detail + +template +HWY_API VW ReorderWidenMulAccumulate(D d32, VN a, VN b, const VW sum0, + VW& sum1) { + return detail::ReorderWidenMulAccumulateI16(d32, a, b, sum0, sum1); +} + +template +HWY_API VW ReorderWidenMulAccumulate(D d32, VN a, VN b, const VW sum0, + VW& sum1) { + return detail::ReorderWidenMulAccumulateU16(d32, a, b, sum0, sum1); +} + +// ------------------------------ RearrangeToOddPlusEven + +template // vint32_t* +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // vwmacc doubles LMUL, so we require a pairwise sum here. This op is + // expected to be less frequent than ReorderWidenMulAccumulate, hence it's + // preferable to do the extra work here rather than do manual odd/even + // extraction there. + const DFromV di32; + const RebindToUnsigned du32; + const Twice di32x2; + const RepartitionToWide di64x2; + const RebindToUnsigned du64x2; + const auto combined = BitCast(di64x2, Combine(di32x2, sum1, sum0)); + // Isolate odd/even int32 in int64 lanes. + const auto even = ShiftRight<32>(ShiftLeft<32>(combined)); // sign extend + const auto odd = ShiftRight<32>(combined); + return BitCast(di32, TruncateTo(du32, BitCast(du64x2, Add(even, odd)))); +} + +// For max LMUL, we cannot Combine again and instead manually unroll. +HWY_API vint32m8_t RearrangeToOddPlusEven(vint32m8_t sum0, vint32m8_t sum1) { + const DFromV d; + const Half dh; + const vint32m4_t lo = + RearrangeToOddPlusEven(LowerHalf(sum0), UpperHalf(dh, sum0)); + const vint32m4_t hi = + RearrangeToOddPlusEven(LowerHalf(sum1), UpperHalf(dh, sum1)); + return Combine(d, hi, lo); +} + +template // vuint32_t* +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // vwmacc doubles LMUL, so we require a pairwise sum here. This op is + // expected to be less frequent than ReorderWidenMulAccumulate, hence it's + // preferable to do the extra work here rather than do manual odd/even + // extraction there. + const DFromV du32; + const Twice du32x2; + const RepartitionToWide du64x2; + const auto combined = BitCast(du64x2, Combine(du32x2, sum1, sum0)); + // Isolate odd/even int32 in int64 lanes. 
+ const auto even = detail::AndS(combined, uint64_t{0xFFFFFFFFu}); + const auto odd = ShiftRight<32>(combined); + return TruncateTo(du32, Add(even, odd)); +} + +// For max LMUL, we cannot Combine again and instead manually unroll. +HWY_API vuint32m8_t RearrangeToOddPlusEven(vuint32m8_t sum0, vuint32m8_t sum1) { + const DFromV d; + const Half dh; + const vuint32m4_t lo = + RearrangeToOddPlusEven(LowerHalf(sum0), UpperHalf(dh, sum0)); + const vuint32m4_t hi = + RearrangeToOddPlusEven(LowerHalf(sum1), UpperHalf(dh, sum1)); + return Combine(d, hi, lo); +} + +template // vfloat* +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); // invariant already holds +} + +// ------------------------------ Lt128 +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Lt128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + // The subsequent computations are performed using e8mf8 (8-bit elements with + // a fractional LMUL of 1/8) for the following reasons: + // 1. It is correct for the possible input vector types e64m<1,2,4,8>. This is + // because the resulting mask can occupy at most 1/8 of a full vector when + // using e64m8. + // 2. It can be more efficient than using a full vector or a vector group. + // + // The algorithm computes the result as follows: + // 1. Compute cH | (=H & cL) in the high bits, where cH and cL represent the + // comparison results for the high and low 64-bit elements, respectively. + // 2. Shift the result right by 1 to duplicate the comparison results for the + // low bits. + // 3. Obtain the final result by performing a bitwise OR on the high and low + // bits. + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t ltHL0 = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Lt(a, b))); + const vuint8mf8_t eqHL0 = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Eq(a, b))); + const vuint8mf8_t ltLx0 = Add(ltHL0, ltHL0); + const vuint8mf8_t resultHx = detail::AndS(OrAnd(ltHL0, ltLx0, eqHL0), 0xaa); + const vuint8mf8_t resultxL = ShiftRight<1>(resultHx); + const vuint8mf8_t result = Or(resultHx, resultxL); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, detail::ChangeLMUL(du8m1, result)); +} + +#else + +template +HWY_INLINE MFromD Lt128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + // Truth table of Eq and Compare for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // Shift leftward so L can influence H. + const VFromD ltLx = detail::Slide1Up(ltHL); + const VFromD vecHx = OrAnd(ltHL, eqHL, ltLx); + // Replicate H to its neighbor. 
+ return MaskFromVec(OddEven(vecHx, detail::Slide1Down(vecHx))); +} + +#endif // HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +// ------------------------------ Lt128Upper +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Lt128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t ltHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Lt(a, b))); + const vuint8mf8_t ltHx = detail::AndS(ltHL, 0xaa); + const vuint8mf8_t ltxL = ShiftRight<1>(ltHx); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, + detail::ChangeLMUL(du8m1, Or(ltHx, ltxL))); +} + +#else + +template +HWY_INLINE MFromD Lt128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + const VFromD down = detail::Slide1Down(ltHL); + // b(267743505): Clang compiler bug, workaround is DoNotOptimize + asm volatile("" : : "r,m"(GetLane(down)) : "memory"); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(ltHL, down)); +} + +#endif // HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +// ------------------------------ Eq128 +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Eq128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t eqHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Eq(a, b))); + const vuint8mf8_t eqxH = ShiftRight<1>(eqHL); + const vuint8mf8_t result0L = detail::AndS(And(eqHL, eqxH), 0x55); + const vuint8mf8_t resultH0 = Add(result0L, result0L); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(du8m1, Or(result0L, resultH0))); +} + +#else + +template +HWY_INLINE MFromD Eq128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + const VFromD eqLH = Reverse2(d, eqHL); + const VFromD eq = And(eqHL, eqLH); + // b(267743505): Clang compiler bug, workaround is DoNotOptimize + asm volatile("" : : "r,m"(GetLane(eq)) : "memory"); + return MaskFromVec(eq); +} + +#endif + +// ------------------------------ Eq128Upper +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Eq128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t eqHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Eq(a, b))); + const vuint8mf8_t eqHx = detail::AndS(eqHL, 0xaa); + const vuint8mf8_t eqxL = ShiftRight<1>(eqHx); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, + detail::ChangeLMUL(du8m1, Or(eqHx, eqxL))); +} + +#else + +template +HWY_INLINE MFromD Eq128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + // Replicate H to its neighbor. 
+ return MaskFromVec(OddEven(eqHL, detail::Slide1Down(eqHL))); +} + +#endif + +// ------------------------------ Ne128 +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Ne128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t neHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Ne(a, b))); + const vuint8mf8_t nexH = ShiftRight<1>(neHL); + const vuint8mf8_t result0L = detail::AndS(Or(neHL, nexH), 0x55); + const vuint8mf8_t resultH0 = Add(result0L, result0L); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(du8m1, Or(result0L, resultH0))); +} + +#else + +template +HWY_INLINE MFromD Ne128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const VFromD neHL = VecFromMask(d, Ne(a, b)); + const VFromD neLH = Reverse2(d, neHL); + // b(267743505): Clang compiler bug, workaround is DoNotOptimize + asm volatile("" : : "r,m"(GetLane(neLH)) : "memory"); + return MaskFromVec(Or(neHL, neLH)); +} + +#endif + +// ------------------------------ Ne128Upper +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Ne128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t neHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Ne(a, b))); + const vuint8mf8_t neHx = detail::AndS(neHL, 0xaa); + const vuint8mf8_t nexL = ShiftRight<1>(neHx); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, + detail::ChangeLMUL(du8m1, Or(neHx, nexL))); +} + +#else + +template +HWY_INLINE MFromD Ne128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const VFromD neHL = VecFromMask(d, Ne(a, b)); + const VFromD down = detail::Slide1Down(neHL); + // b(267743505): Clang compiler bug, workaround is DoNotOptimize + asm volatile("" : : "r,m"(GetLane(down)) : "memory"); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(neHL, down)); +} + +#endif + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_INLINE VFromD Min128(D /* tag */, const VFromD a, const VFromD b) { + const VFromD aXH = detail::Slide1Down(a); + const VFromD bXH = detail::Slide1Down(b); + const VFromD minHL = Min(a, b); + const MFromD ltXH = Lt(aXH, bXH); + const MFromD eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD lo = IfThenElse(ltXH, a, b); + // The upper lane is just minHL; if they are equal, we also need to use the + // actual min of the lower lanes. + return OddEven(minHL, IfThenElse(eqXH, minHL, lo)); +} + +template +HWY_INLINE VFromD Max128(D /* tag */, const VFromD a, const VFromD b) { + const VFromD aXH = detail::Slide1Down(a); + const VFromD bXH = detail::Slide1Down(b); + const VFromD maxHL = Max(a, b); + const MFromD ltXH = Lt(aXH, bXH); + const MFromD eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD lo = IfThenElse(ltXH, b, a); + // The upper lane is just maxHL; if they are equal, we also need to use the + // actual min of the lower lanes. 
+ return OddEven(maxHL, IfThenElse(eqXH, maxHL, lo)); +} + +template +HWY_INLINE VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== END MACROS +#undef HWY_RVV_AVL +#undef HWY_RVV_D +#undef HWY_RVV_FOREACH +#undef HWY_RVV_FOREACH_08_ALL +#undef HWY_RVV_FOREACH_08_ALL_VIRT +#undef HWY_RVV_FOREACH_08_DEMOTE +#undef HWY_RVV_FOREACH_08_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_08_EXT +#undef HWY_RVV_FOREACH_08_EXT_VIRT +#undef HWY_RVV_FOREACH_08_TRUNC +#undef HWY_RVV_FOREACH_08_VIRT +#undef HWY_RVV_FOREACH_16_ALL +#undef HWY_RVV_FOREACH_16_ALL_VIRT +#undef HWY_RVV_FOREACH_16_DEMOTE +#undef HWY_RVV_FOREACH_16_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_16_EXT +#undef HWY_RVV_FOREACH_16_EXT_VIRT +#undef HWY_RVV_FOREACH_16_TRUNC +#undef HWY_RVV_FOREACH_16_VIRT +#undef HWY_RVV_FOREACH_32_ALL +#undef HWY_RVV_FOREACH_32_ALL_VIRT +#undef HWY_RVV_FOREACH_32_DEMOTE +#undef HWY_RVV_FOREACH_32_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_32_EXT +#undef HWY_RVV_FOREACH_32_EXT_VIRT +#undef HWY_RVV_FOREACH_32_TRUNC +#undef HWY_RVV_FOREACH_32_VIRT +#undef HWY_RVV_FOREACH_64_ALL +#undef HWY_RVV_FOREACH_64_ALL_VIRT +#undef HWY_RVV_FOREACH_64_DEMOTE +#undef HWY_RVV_FOREACH_64_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_64_EXT +#undef HWY_RVV_FOREACH_64_EXT_VIRT +#undef HWY_RVV_FOREACH_64_TRUNC +#undef HWY_RVV_FOREACH_64_VIRT +#undef HWY_RVV_FOREACH_B +#undef HWY_RVV_FOREACH_F +#undef HWY_RVV_FOREACH_F16 +#undef HWY_RVV_FOREACH_F32 +#undef HWY_RVV_FOREACH_F3264 +#undef HWY_RVV_FOREACH_F64 +#undef HWY_RVV_FOREACH_I +#undef HWY_RVV_FOREACH_I08 +#undef HWY_RVV_FOREACH_I16 +#undef HWY_RVV_FOREACH_I163264 +#undef HWY_RVV_FOREACH_I32 +#undef HWY_RVV_FOREACH_I64 +#undef HWY_RVV_FOREACH_U +#undef HWY_RVV_FOREACH_U08 +#undef HWY_RVV_FOREACH_U16 +#undef HWY_RVV_FOREACH_U163264 +#undef HWY_RVV_FOREACH_U32 +#undef HWY_RVV_FOREACH_U64 +#undef HWY_RVV_FOREACH_UI +#undef HWY_RVV_FOREACH_UI08 +#undef HWY_RVV_FOREACH_UI16 +#undef HWY_RVV_FOREACH_UI163264 +#undef HWY_RVV_FOREACH_UI32 +#undef HWY_RVV_FOREACH_UI3264 +#undef HWY_RVV_FOREACH_UI64 +#undef HWY_RVV_IF_EMULATED_D +#undef HWY_RVV_IF_CAN128_D +#undef HWY_RVV_IF_GE128_D +#undef HWY_RVV_IF_LT128_D +#undef HWY_RVV_INSERT_VXRM +#undef HWY_RVV_M +#undef HWY_RVV_RETM_ARGM +#undef HWY_RVV_RETV_ARGMVV +#undef HWY_RVV_RETV_ARGV +#undef HWY_RVV_RETV_ARGVS +#undef HWY_RVV_RETV_ARGVV +#undef HWY_RVV_T +#undef HWY_RVV_V +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/aom/third_party/highway/hwy/ops/scalar-inl.h b/third_party/aom/third_party/highway/hwy/ops/scalar-inl.h new file mode 100644 index 000000000000..ca59e8016986 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/scalar-inl.h @@ -0,0 +1,2170 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-element vectors and operations. +// External include guard in highway.h - see comment there. + +#include +#ifndef HWY_NO_LIBCXX +#include // sqrtf +#endif + +#include "third_party/highway/hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Single instruction, single data. +template +using Sisd = Simd; + +// (Wrapper class required for overloading comparison operators.) +template +struct Vec1 { + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 1; // only for DFromV + + HWY_INLINE Vec1() = default; + Vec1(const Vec1&) = default; + Vec1& operator=(const Vec1&) = default; + HWY_INLINE explicit Vec1(const T t) : raw(t) {} + + HWY_INLINE Vec1& operator*=(const Vec1 other) { + return *this = (*this * other); + } + HWY_INLINE Vec1& operator/=(const Vec1 other) { + return *this = (*this / other); + } + HWY_INLINE Vec1& operator+=(const Vec1 other) { + return *this = (*this + other); + } + HWY_INLINE Vec1& operator-=(const Vec1 other) { + return *this = (*this - other); + } + HWY_INLINE Vec1& operator%=(const Vec1 other) { + return *this = (*this % other); + } + HWY_INLINE Vec1& operator&=(const Vec1 other) { + return *this = (*this & other); + } + HWY_INLINE Vec1& operator|=(const Vec1 other) { + return *this = (*this | other); + } + HWY_INLINE Vec1& operator^=(const Vec1 other) { + return *this = (*this ^ other); + } + + T raw; +}; + +// 0 or FF..FF, same size as Vec1. +template +struct Mask1 { + using Raw = hwy::MakeUnsigned; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = 1; // only for DFromM + + static HWY_INLINE Mask1 FromBool(bool b) { + Mask1 mask; + mask.bits = b ? 
static_cast(~Raw{0}) : 0; + return mask; + } + + Raw bits; +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +template , typename TFrom> +HWY_API Vec1 BitCast(DTo /* tag */, Vec1 v) { + static_assert(sizeof(TTo) <= sizeof(TFrom), "Promoting is undefined"); + TTo to; + CopyBytes(&v.raw, &to); // not same size - ok to shrink + return Vec1(to); +} + +// ------------------------------ Zero + +template > +HWY_API Vec1 Zero(D /* tag */) { + return Vec1(ConvertScalarTo(0)); +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set +template , typename T2> +HWY_API Vec1 Set(D /* tag */, const T2 t) { + return Vec1(static_cast(t)); +} + +// ------------------------------ Undefined +template > +HWY_API Vec1 Undefined(D d) { + return Zero(d); +} + +// ------------------------------ Iota +template , typename T2> +HWY_API Vec1 Iota(const D /* tag */, const T2 first) { + return Vec1(static_cast(first)); +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D /* tag */, FromV v) { + using TFrom = TFromV; + using TTo = TFromD; + constexpr size_t kCopyLen = HWY_MIN(sizeof(TFrom), sizeof(TTo)); + TTo to{}; + CopyBytes(&v.raw, &to); + return VFromD(to); +} + +namespace detail { + +// ResizeBitCast on the HWY_SCALAR target has zero-extending semantics if +// sizeof(TFromD) is greater than sizeof(TFromV) +template +HWY_INLINE VFromD ZeroExtendResizeBitCast(FromSizeTag /* from_size_tag */, + ToSizeTag /* to_size_tag */, + DTo d_to, DFrom /*d_from*/, + VFromD v) { + return ResizeBitCast(d_to, v); +} + +} // namespace detail + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/, + TFromD /*t2*/, TFromD /*t3*/, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/, + TFromD /*t8*/, TFromD /*t9*/, + TFromD /*t10*/, TFromD /*t11*/, + TFromD /*t12*/, TFromD /*t13*/, + TFromD /*t14*/, TFromD /*t15*/) { + return VFromD(t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/, + TFromD /*t2*/, TFromD /*t3*/, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/) { + return VFromD(t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/, + TFromD /*t2*/, TFromD /*t3*/) { + return VFromD(t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/) { + return VFromD(t0); +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec1 Not(const Vec1 v) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(static_cast(~BitCast(du, v).raw))); +} + +// ------------------------------ And + +template +HWY_API Vec1 And(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw & BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator&(const Vec1 a, const Vec1 b) { + return And(a, b); +} + +// ------------------------------ AndNot + +template +HWY_API Vec1 AndNot(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(static_cast(~BitCast(du, a).raw & + BitCast(du, b).raw))); +} + +// ------------------------------ Or + +template +HWY_API Vec1 Or(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return 
BitCast(Sisd(), Vec1(BitCast(du, a).raw | BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator|(const Vec1 a, const Vec1 b) { + return Or(a, b); +} + +// ------------------------------ Xor + +template +HWY_API Vec1 Xor(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw ^ BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator^(const Vec1 a, const Vec1 b) { + return Xor(a, b); +} + +// ------------------------------ Xor3 + +template +HWY_API Vec1 Xor3(Vec1 x1, Vec1 x2, Vec1 x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 + +template +HWY_API Vec1 Or3(Vec1 o1, Vec1 o2, Vec1 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec1 OrAnd(const Vec1 o, const Vec1 a1, const Vec1 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ Mask + +template , typename TFrom> +HWY_API Mask1 RebindMask(DTo /*tag*/, Mask1 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask1{m.bits}; +} + +// v must be 0 or FF..FF. +template +HWY_API Mask1 MaskFromVec(const Vec1 v) { + Mask1 mask; + CopySameSize(&v, &mask); + return mask; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template > +Vec1 VecFromMask(D /* tag */, const Mask1 mask) { + Vec1 v; + CopySameSize(&mask, &v); + return v; +} + +template +uint64_t BitsFromMask(D, MFromD mask) { + return mask.bits ? 1 : 0; +} + +template > +HWY_API Mask1 FirstN(D /*tag*/, size_t n) { + return Mask1::FromBool(n != 0); +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec1 IfVecThenElse(Vec1 mask, Vec1 yes, Vec1 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ CopySign +template +HWY_API Vec1 CopySign(const Vec1 magn, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API Vec1 CopySignToAbs(const Vec1 abs, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const Sisd d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ BroadcastSignBit +template +HWY_API Vec1 BroadcastSignBit(const Vec1 v) { + return Vec1(ScalarShr(v.raw, sizeof(T) * 8 - 1)); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +template +HWY_API Vec1 PopulationCount(Vec1 v) { + return Vec1(static_cast(PopCount(v.raw))); +} + +// ------------------------------ IfThenElse + +// Returns mask ? yes : no. +template +HWY_API Vec1 IfThenElse(const Mask1 mask, const Vec1 yes, + const Vec1 no) { + return mask.bits ? yes : no; +} + +template +HWY_API Vec1 IfThenElseZero(const Mask1 mask, const Vec1 yes) { + return mask.bits ? yes : Vec1(ConvertScalarTo(0)); +} + +template +HWY_API Vec1 IfThenZeroElse(const Mask1 mask, const Vec1 no) { + return mask.bits ? Vec1(ConvertScalarTo(0)) : no; +} + +template +HWY_API Vec1 IfNegativeThenElse(Vec1 v, Vec1 yes, Vec1 no) { + const DFromV d; + const RebindToSigned di; + const auto vi = BitCast(di, v); + + return vi.raw < 0 ? 
yes : no; +} + +// ------------------------------ Mask logical + +template +HWY_API Mask1 Not(const Mask1 m) { + return MaskFromVec(Not(VecFromMask(Sisd(), m))); +} + +template +HWY_API Mask1 And(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 AndNot(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Or(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Xor(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 ExclusiveNeither(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +template +HWY_API Mask1 SetAtOrAfterFirst(Mask1 mask) { + return mask; +} + +template +HWY_API Mask1 SetBeforeFirst(Mask1 mask) { + return Not(mask); +} + +template +HWY_API Mask1 SetOnlyFirst(Mask1 mask) { + return mask; +} + +template +HWY_API Mask1 SetAtOrBeforeFirst(Mask1 /*mask*/) { + return Mask1::FromBool(true); +} + +// ------------------------------ LowerHalfOfMask + +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template +HWY_API MFromD LowerHalfOfMask(D /*d*/, MFromD m) { + return m; +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template +HWY_API Vec1 ShiftLeft(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return Vec1( + static_cast(static_cast>(v.raw) << kBits)); +} + +template +HWY_API Vec1 ShiftRight(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return Vec1(ScalarShr(v.raw, kBits)); +} + +// ------------------------------ RotateRight (ShiftRight) +template +HWY_API Vec1 RotateRight(const Vec1 v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// ------------------------------ ShiftLeftSame (BroadcastSignBit) + +template +HWY_API Vec1 ShiftLeftSame(const Vec1 v, int bits) { + return Vec1( + static_cast(static_cast>(v.raw) << bits)); +} + +template +HWY_API Vec1 ShiftRightSame(const Vec1 v, int bits) { + return Vec1(ScalarShr(v.raw, bits)); +} + +// ------------------------------ Shl + +// Single-lane => same as ShiftLeftSame except for the argument type. 
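// For example (illustrative): Vec1<uint32_t>(1) << Vec1<uint32_t>(4) yields
// Vec1<uint32_t>(16), exactly as ShiftLeftSame(v, 4) would.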
+template +HWY_API Vec1 operator<<(const Vec1 v, const Vec1 bits) { + return ShiftLeftSame(v, static_cast(bits.raw)); +} + +template +HWY_API Vec1 operator>>(const Vec1 v, const Vec1 bits) { + return ShiftRightSame(v, static_cast(bits.raw)); +} + +// ================================================== ARITHMETIC + +template +HWY_API Vec1 operator+(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 + b64) & static_cast(~T(0)))); +} +HWY_API Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} +HWY_API Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} + +template +HWY_API Vec1 operator-(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 - b64) & static_cast(~T(0)))); +} +HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} +HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} + +// ------------------------------ SumsOf8 + +HWY_API Vec1 SumsOf8(const Vec1 v) { + return Vec1(v.raw); +} +HWY_API Vec1 SumsOf8(const Vec1 v) { + return Vec1(v.raw); +} + +// ------------------------------ SumsOf2 + +template +HWY_API Vec1> SumsOf2(const Vec1 v) { + const DFromV d; + const Rebind, decltype(d)> dw; + return PromoteTo(dw, v); +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 255))); +} +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast( + HWY_MIN(HWY_MAX(0, static_cast(a.raw) + b.raw), 65535))); +} + +// Signed +HWY_API Vec1 SaturatedAdd(const Vec1 a, const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw + b.raw), 127))); +} +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast( + HWY_MIN(HWY_MAX(-32768, static_cast(a.raw) + b.raw), 32767))); +} + +// ------------------------------ Saturating subtraction + +// Returns a - b clamped to the destination range. 
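// For example (illustrative): SaturatedSub(Vec1<uint8_t>(3), Vec1<uint8_t>(10))
// is 0 rather than the wrapped 249, and
// SaturatedSub(Vec1<int8_t>(-100), Vec1<int8_t>(100)) is -128 rather than 56.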
+ +// Unsigned +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 255))); +} +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast( + HWY_MIN(HWY_MAX(0, static_cast(a.raw) - b.raw), 65535))); +} + +// Signed +HWY_API Vec1 SaturatedSub(const Vec1 a, const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw - b.raw), 127))); +} +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast( + HWY_MIN(HWY_MAX(-32768, static_cast(a.raw) - b.raw), 32767))); +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +template +HWY_API Vec1 AverageRound(const Vec1 a, const Vec1 b) { + const T a_val = a.raw; + const T b_val = b.raw; + return Vec1(static_cast((a_val | b_val) - ScalarShr(a_val ^ b_val, 1))); +} + +// ------------------------------ Absolute value + +template +HWY_API Vec1 Abs(const Vec1 a) { + return Vec1(ScalarAbs(a.raw)); +} + +// ------------------------------ Min/Max + +// may be unavailable, so implement our own. + +template +HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { + if (ScalarIsNaN(a.raw)) return b; + if (ScalarIsNaN(b.raw)) return a; + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +template +HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { + if (ScalarIsNaN(a.raw)) return b; + if (ScalarIsNaN(b.raw)) return a; + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +// ------------------------------ Floating-point negate + +template +HWY_API Vec1 Neg(const Vec1 v) { + return Xor(v, SignBit(Sisd())); +} + +template +HWY_API Vec1 Neg(const Vec1 v) { + return Zero(Sisd()) - v; +} + +// ------------------------------ mul/div + +// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(double{a.raw} * b.raw)); +} + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(static_cast(a.raw) * + static_cast(b.raw))); +} + +template +HWY_API Vec1 operator/(const Vec1 a, const Vec1 b) { + return Vec1(a.raw / b.raw); +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +template +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + using TW = MakeWide; + return Vec1(static_cast( + (static_cast(a.raw) * static_cast(b.raw)) >> (sizeof(T) * 8))); +} +template +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + T hi; + Mul128(a.raw, b.raw, &hi); + return Vec1(hi); +} + +HWY_API Vec1 MulFixedPoint15(Vec1 a, Vec1 b) { + return Vec1(static_cast((a.raw * b.raw + 16384) >> 15)); +} + +// Multiplies even lanes (0, 2 ..) and returns the double-wide result. 
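// For example (illustrative): MulEven(Vec1<uint16_t>(300), Vec1<uint16_t>(400))
// is Vec1<uint32_t>(120000), a product that would not fit in the 16-bit lane.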
+template +HWY_API Vec1> MulEven(const Vec1 a, const Vec1 b) { + using TW = MakeWide; + const TW a_wide = a.raw; + return Vec1(static_cast(a_wide * b.raw)); +} + +template +HWY_API Vec1> MulOdd(const Vec1, const Vec1) { + static_assert(sizeof(T) == 0, "There are no odd lanes"); +} + +// Approximate reciprocal +HWY_API Vec1 ApproximateReciprocal(const Vec1 v) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The return value is arbitrary. + if (v.raw == 0.0f) return Vec1(0.0f); + return Vec1(1.0f / v.raw); +} + +// generic_ops takes care of integer T. +template +HWY_API Vec1 AbsDiff(const Vec1 a, const Vec1 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec1 MulAdd(const Vec1 mul, const Vec1 x, const Vec1 add) { + return mul * x + add; +} + +template +HWY_API Vec1 NegMulAdd(const Vec1 mul, const Vec1 x, + const Vec1 add) { + return add - mul * x; +} + +template +HWY_API Vec1 MulSub(const Vec1 mul, const Vec1 x, const Vec1 sub) { + return mul * x - sub; +} + +template +HWY_API Vec1 NegMulSub(const Vec1 mul, const Vec1 x, + const Vec1 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Approximate reciprocal square root +HWY_API Vec1 ApproximateReciprocalSqrt(const Vec1 v) { + float f = v.raw; + const float half = f * 0.5f; + uint32_t bits; + CopySameSize(&f, &bits); + // Initial guess based on log2(f) + bits = 0x5F3759DF - (bits >> 1); + CopySameSize(&bits, &f); + // One Newton-Raphson iteration + return Vec1(f * (1.5f - (half * f * f))); +} + +// Square root +HWY_API Vec1 Sqrt(Vec1 v) { +#if defined(HWY_NO_LIBCXX) +#if HWY_COMPILER_GCC_ACTUAL + return Vec1(__builtin_sqrt(v.raw)); +#else + uint32_t bits; + CopyBytes(&v, &bits); + // Coarse approximation, letting the exponent LSB leak into the mantissa + bits = (1 << 29) + (bits >> 1) - (1 << 22); + CopyBytes(&bits, &v); + return v; +#endif // !HWY_COMPILER_GCC_ACTUAL +#else + return Vec1(sqrtf(v.raw)); +#endif // !HWY_NO_LIBCXX +} +HWY_API Vec1 Sqrt(Vec1 v) { +#if defined(HWY_NO_LIBCXX) +#if HWY_COMPILER_GCC_ACTUAL + return Vec1(__builtin_sqrt(v.raw)); +#else + uint64_t bits; + CopyBytes(&v, &bits); + // Coarse approximation, letting the exponent LSB leak into the mantissa + bits = (1ULL << 61) + (bits >> 1) - (1ULL << 51); + CopyBytes(&bits, &v); + return v; +#endif // !HWY_COMPILER_GCC_ACTUAL +#else + return Vec1(sqrt(v.raw)); +#endif // HWY_NO_LIBCXX +} + +// ------------------------------ Floating-point rounding + +template +HWY_API Vec1 Round(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw < MantissaEnd())) { // Huge or NaN + return v; + } + const T k0 = ConvertScalarTo(0); + const T bias = ConvertScalarTo(v.raw < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw + bias); + if (rounded == 0) return CopySignToAbs(Vec1(k0), v); + TI offset = 0; + // Round to even + if ((rounded & 1) && ScalarAbs(ConvertScalarTo(rounded) - v.raw) == + ConvertScalarTo(0.5)) { + offset = v.raw < k0 ? -1 : 1; + } + return Vec1(ConvertScalarTo(rounded - offset)); +} + +// Round-to-nearest even. 
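// The ApproximateReciprocalSqrt above is the classic bit-trick "fast inverse
// square root": reinterpret the float's bits, form an initial guess by
// adjusting the exponent with a magic constant, then refine with one
// Newton-Raphson step. Standalone sketch for clarity (not the patch's code):
#include <cstdint>
#include <cstring>
static float FastRsqrtSketch(float x) {
  const float half = 0.5f * x;
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));  // safe type pun
  bits = 0x5F3759DF - (bits >> 1);       // initial guess for 1/sqrt(x)
  std::memcpy(&x, &bits, sizeof(bits));
  return x * (1.5f - half * x * x);      // one Newton-Raphson iteration
}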
+template +HWY_API Vec1> NearestInt(const Vec1 v) { + using TI = MakeSigned; + + const T abs = Abs(v).raw; + const bool is_sign = ScalarSignBit(v.raw); + + if (!(abs < MantissaEnd())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs <= ConvertScalarTo(LimitsMax()))) { + return Vec1(is_sign ? LimitsMin() : LimitsMax()); + } + return Vec1(ConvertScalarTo(v.raw)); + } + const T bias = + ConvertScalarTo(v.raw < ConvertScalarTo(0.0) ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw + bias); + if (rounded == 0) return Vec1(0); + TI offset = 0; + // Round to even + if ((rounded & 1) && ScalarAbs(ConvertScalarTo(rounded) - v.raw) == + ConvertScalarTo(0.5)) { + offset = is_sign ? -1 : 1; + } + return Vec1(rounded - offset); +} + +// Round-to-nearest even. +template +HWY_API VFromD DemoteToNearestInt(DI32 /*di32*/, const Vec1 v) { + using T = double; + using TI = int32_t; + + const T abs = Abs(v).raw; + const bool is_sign = ScalarSignBit(v.raw); + + // Check if too large to cast or NaN + if (!(abs <= ConvertScalarTo(LimitsMax()))) { + return Vec1(is_sign ? LimitsMin() : LimitsMax()); + } + + const T bias = + ConvertScalarTo(v.raw < ConvertScalarTo(0.0) ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw + bias); + if (rounded == 0) return Vec1(0); + TI offset = 0; + // Round to even + if ((rounded & 1) && ScalarAbs(ConvertScalarTo(rounded) - v.raw) == + ConvertScalarTo(0.5)) { + offset = is_sign ? -1 : 1; + } + return Vec1(rounded - offset); +} + +template +HWY_API Vec1 Trunc(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw <= MantissaEnd())) { // Huge or NaN + return v; + } + const TI truncated = ConvertScalarTo(v.raw); + if (truncated == 0) return CopySignToAbs(Vec1(0), v); + return Vec1(ConvertScalarTo(truncated)); +} + +template +V Ceiling(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool positive = f > Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => 0 or 1. + if (exponent < 0) return positive ? V(1) : V(-0.0); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +template +V Floor(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool negative = f < Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => -1 or 0. + if (exponent < 0) return V(negative ? 
Float(-1.0) : Float(0.0)); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +// Toward +infinity, aka ceiling +HWY_API Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} +HWY_API Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} + +// Toward -infinity, aka floor +HWY_API Vec1 Floor(const Vec1 v) { + return Floor(v); +} +HWY_API Vec1 Floor(const Vec1 v) { + return Floor(v); +} + +// ================================================== COMPARE + +template +HWY_API Mask1 operator==(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw == b.raw); +} + +template +HWY_API Mask1 operator!=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw != b.raw); +} + +template +HWY_API Mask1 TestBit(const Vec1 v, const Vec1 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask1 operator<(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw < b.raw); +} +template +HWY_API Mask1 operator>(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw > b.raw); +} + +template +HWY_API Mask1 operator<=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw <= b.raw); +} +template +HWY_API Mask1 operator>=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw >= b.raw); +} + +// ------------------------------ Floating-point classification (==) + +template +HWY_API Mask1 IsNaN(const Vec1 v) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. + return Mask1::FromBool(ScalarIsNaN(v.raw)); +} + +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +HWY_API Mask1 IsInf(const Vec1 v) { + const Sisd d; + const RebindToUnsigned du; + const Vec1 vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFF000000u)); +} +HWY_API Mask1 IsInf(const Vec1 v) { + const Sisd d; + const RebindToUnsigned du; + const Vec1 vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFFE0000000000000ull)); +} + +HWY_API Mask1 IsFinite(const Vec1 v) { + const Vec1 vu = BitCast(Sisd(), v); + // Shift left to clear the sign bit, check whether exponent != max value. + return Mask1::FromBool((vu.raw << 1) < 0xFF000000u); +} +HWY_API Mask1 IsFinite(const Vec1 v) { + const Vec1 vu = BitCast(Sisd(), v); + // Shift left to clear the sign bit, check whether exponent != max value. 
+ return Mask1::FromBool((vu.raw << 1) < 0xFFE0000000000000ull); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template > +HWY_API Vec1 Load(D /* tag */, const T* HWY_RESTRICT aligned) { + T t; + CopySameSize(aligned, &t); + return Vec1(t); +} + +template > +HWY_API Vec1 MaskedLoad(Mask1 m, D d, const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template > +HWY_API Vec1 MaskedLoadOr(Vec1 v, Mask1 m, D d, + const T* HWY_RESTRICT aligned) { + return IfThenElse(m, Load(d, aligned), v); +} + +template > +HWY_API Vec1 LoadU(D d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. +template > +HWY_API Vec1 LoadDup128(D d, const T* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +template > +HWY_API VFromD LoadN(D d, const T* HWY_RESTRICT p, + size_t max_lanes_to_load) { + return (max_lanes_to_load > 0) ? Load(d, p) : Zero(d); +} + +template > +HWY_API VFromD LoadNOr(VFromD no, D d, const T* HWY_RESTRICT p, + size_t max_lanes_to_load) { + return (max_lanes_to_load > 0) ? Load(d, p) : no; +} + +// ------------------------------ Store + +template > +HWY_API void Store(const Vec1 v, D /* tag */, T* HWY_RESTRICT aligned) { + CopySameSize(&v.raw, aligned); +} + +template > +HWY_API void StoreU(const Vec1 v, D d, T* HWY_RESTRICT p) { + return Store(v, d, p); +} + +template > +HWY_API void BlendedStore(const Vec1 v, Mask1 m, D d, T* HWY_RESTRICT p) { + if (!m.bits) return; + StoreU(v, d, p); +} + +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store > 0) { + Store(v, d, p); + } +} + +// ------------------------------ Tuples +#include "third_party/highway/hwy/ops/inside-inl.h" + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining StoreInterleaved2. 
+#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, Vec1& v0, + Vec1& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, Vec1& v0, + Vec1& v1, Vec1& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, Vec1& v0, + Vec1& v1, Vec1& v2, Vec1& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template > +HWY_API void StoreInterleaved2(const Vec1 v0, const Vec1 v1, D d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); +} + +template > +HWY_API void StoreInterleaved3(const Vec1 v0, const Vec1 v1, + const Vec1 v2, D d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +template > +HWY_API void StoreInterleaved4(const Vec1 v0, const Vec1 v1, + const Vec1 v2, const Vec1 v3, D d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); + StoreU(v3, d, unaligned + 3); +} + +// ------------------------------ Stream + +template > +HWY_API void Stream(const Vec1 v, D d, T* HWY_RESTRICT aligned) { + return Store(v, d, aligned); +} + +// ------------------------------ Scatter + +#ifdef HWY_NATIVE_SCATTER +#undef HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +template , typename TI> +HWY_API void ScatterOffset(Vec1 v, D d, T* base, Vec1 offset) { + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + const intptr_t addr = + reinterpret_cast(base) + static_cast(offset.raw); + Store(v, d, reinterpret_cast(addr)); +} + +template , typename TI> +HWY_API void ScatterIndex(Vec1 v, D d, T* HWY_RESTRICT base, + Vec1 index) { + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + Store(v, d, base + index.raw); +} + +template , typename TI> +HWY_API void MaskedScatterIndex(Vec1 v, Mask1 m, D d, + T* HWY_RESTRICT base, Vec1 index) { + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + if (m.bits) Store(v, d, base + index.raw); +} + +// ------------------------------ Gather + +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +template > +HWY_API Vec1 GatherOffset(D d, const T* base, Vec1> offset) { + HWY_DASSERT(offset.raw >= 0); + const intptr_t addr = + reinterpret_cast(base) + static_cast(offset.raw); + return Load(d, reinterpret_cast(addr)); +} + +template > +HWY_API Vec1 GatherIndex(D d, const T* HWY_RESTRICT base, + Vec1> index) { + HWY_DASSERT(index.raw >= 0); + return Load(d, base + index.raw); +} + +template > +HWY_API Vec1 MaskedGatherIndex(Mask1 m, D d, const T* HWY_RESTRICT base, + Vec1> index) { + HWY_DASSERT(index.raw >= 0); + return MaskedLoad(m, d, base + index.raw); +} + +template > +HWY_API Vec1 MaskedGatherIndexOr(Vec1 no, Mask1 m, D d, + const T* HWY_RESTRICT base, + Vec1> index) { + HWY_DASSERT(index.raw >= 0); + return MaskedLoadOr(no, m, d, base + index.raw); +} + +// ================================================== CONVERT + 
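// Illustrative single-lane sketch of the gather/scatter semantics defined
// above: with exactly one lane, GatherIndex degenerates to an indexed load and
// MaskedGatherIndexOr to a conditional load (names below are not Highway's).
#include <cstdint>
static int32_t GatherIndexSketch(const int32_t* base, int64_t index) {
  return base[index];
}
static int32_t MaskedGatherIndexOrSketch(int32_t no, bool m,
                                         const int32_t* base, int64_t index) {
  return m ? base[index] : no;  // a masked-off lane keeps the fallback value
}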
+// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). + +namespace detail { + +template +HWY_INLINE ToT CastValueForF2IConv(FromT val) { + // Prevent ubsan errors when converting float to narrower integer + + using FromTU = MakeUnsigned; + using ToTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. + + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? static_cast(val) + : static_cast(static_cast(LimitsMax()) + + static_cast(ScalarSignBit(val))); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(ToTypeTag /* to_type_tag */, FromT val) { + return ConvertScalarTo(val); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(hwy::SignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(hwy::UnsignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); +} + +// If val is within the range of ToT, CastValueForInRangeF2IConv(val) +// returns static_cast(val) +// +// Otherwise, CastValueForInRangeF2IConv(val) returns an +// implementation-defined result if val is not within the range of ToT. +template +HWY_INLINE ToT CastValueForInRangeF2IConv(FromT val) { + // Prevent ubsan errors when converting float to narrower integer + + using FromTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. + + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? 
static_cast(val) + : static_cast(LimitsMin()); +} + +} // namespace detail + +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + +template , typename TFrom> +HWY_API Vec1 PromoteTo(DTo /* tag */, Vec1 from) { + static_assert(sizeof(TTo) > sizeof(TFrom), "Not promoting"); + // For bits Y > X, floatX->floatY and intX->intY are always representable. + return Vec1( + detail::CastValueForPromoteTo(hwy::TypeTag(), from.raw)); +} + +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD PromoteInRangeTo(DTo /* tag */, Vec1 from) { + using TTo = TFromD; + return Vec1(detail::CastValueForInRangeF2IConv(from.raw)); +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(TFrom) is here, +// so we overload for TFrom=double and TTo={float,int32_t}. +template +HWY_API Vec1 DemoteTo(D /* tag */, Vec1 from) { + // Prevent ubsan errors when converting float to narrower integer/float + if (IsInf(from).bits || + Abs(from).raw > static_cast(HighestValue())) { + return Vec1(ScalarSignBit(from.raw) ? LowestValue() + : HighestValue()); + } + return Vec1(static_cast(from.raw)); +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec1 from) { + // Prevent ubsan errors when converting int32_t to narrower integer/int32_t + return Vec1>(detail::CastValueForF2IConv>(from.raw)); +} + +template , typename TFrom, + HWY_IF_SIGNED(TFrom), HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD)> +HWY_API Vec1 DemoteTo(DTo /* tag */, Vec1 from) { + static_assert(!IsFloat(), "TFrom=double are handled above"); + static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + + // Int to int: choose closest value in TTo to `from` (avoids UB) + from.raw = HWY_MIN(HWY_MAX(LimitsMin(), from.raw), LimitsMax()); + return Vec1(static_cast(from.raw)); +} + +// Disable the default unsigned to signed DemoteTo implementation in +// generic_ops-inl.h on SCALAR as the SCALAR target has a target-specific +// implementation of the unsigned to signed DemoteTo op and as ReorderDemote2To +// is not supported on the SCALAR target + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the V template +// argument +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ + hwy::EnableIf()>* = nullptr + +template , typename TFrom, + HWY_IF_UNSIGNED(TFrom), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(DTo)> +HWY_API Vec1 DemoteTo(DTo /* tag */, Vec1 from) { + static_assert(!IsFloat(), "TFrom=double are handled above"); + static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + + const auto max = static_cast>(LimitsMax()); + + // Int to int: choose closest value in TTo to `from` (avoids UB) + return Vec1(static_cast(HWY_MIN(from.raw, max))); +} + +template , typename TFrom, + HWY_IF_UI64(TFrom), HWY_IF_F32_D(DTo)> +HWY_API Vec1 DemoteTo(DTo /* tag */, Vec1 from) { + // int64_t/uint64_t to float: simply cast to TTo + return Vec1(static_cast(from.raw)); +} + +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD DemoteInRangeTo(D32 /*d32*/, + VFromD> v) { + using TTo = TFromD; + return 
Vec1(detail::CastValueForInRangeF2IConv(v.raw)); +} + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions; +// use this scalar version to verify the vector implementation. +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +template +HWY_API Vec1 PromoteTo(D /* tag */, const Vec1 v) { + return Vec1(F32FromF16(v.raw)); +} + +template +HWY_API Vec1 PromoteTo(D d, const Vec1 v) { + return Set(d, F32FromBF16(v.raw)); +} + +template +HWY_API VFromD PromoteEvenTo(DTo d_to, Vec1 v) { + return PromoteTo(d_to, v); +} + +template +HWY_API Vec1 DemoteTo(D /* tag */, const Vec1 v) { + return Vec1(F16FromF32(v.raw)); +} + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +template +HWY_API Vec1 DemoteTo(D d, const Vec1 v) { + return Set(d, BF16FromF32(v.raw)); +} + +template , typename TFrom, + HWY_IF_FLOAT(TFrom)> +HWY_API Vec1 ConvertTo(DTo /* tag */, Vec1 from) { + static_assert(sizeof(TTo) == sizeof(TFrom), "Should have same size"); + // float## -> int##: return closest representable value. + return Vec1(detail::CastValueForF2IConv(from.raw)); +} + +template , typename TFrom, + HWY_IF_NOT_FLOAT(TFrom)> +HWY_API Vec1 ConvertTo(DTo /* tag */, Vec1 from) { + static_assert(sizeof(TTo) == sizeof(TFrom), "Should have same size"); + // int## -> float##: no check needed + return Vec1(static_cast(from.raw)); +} + +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, VFromD> v) { + using TTo = TFromD; + return VFromD(detail::CastValueForInRangeF2IConv(v.raw)); +} + +HWY_API Vec1 U8FromU32(const Vec1 v) { + return DemoteTo(Sisd(), v); +} + +// ------------------------------ TruncateTo + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFF)}; +} + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFFFFFFu)}; +} + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFF)}; +} + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +// ================================================== COMBINE +// UpperHalf, ZeroExtendVector, Combine, Concat* are unsupported. + +template +HWY_API Vec1 LowerHalf(Vec1 v) { + return v; +} + +template > +HWY_API Vec1 LowerHalf(D /* tag */, Vec1 v) { + return v; +} + +// ================================================== SWIZZLE + +template +HWY_API T GetLane(const Vec1 v) { + return v.raw; +} + +template +HWY_API T ExtractLane(const Vec1 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return v.raw; +} + +template +HWY_API Vec1 InsertLane(Vec1 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + v.raw = t; + return v; +} + +template +HWY_API Vec1 DupEven(Vec1 v) { + return v; +} +// DupOdd is unsupported. 
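// Illustrative sketch of the bfloat16 <-> float relationship behind the
// BF16FromF32/F32FromBF16 helpers used earlier in this section: bfloat16 is
// the upper 16 bits of an IEEE-754 binary32. This sketch truncates; the real
// helpers may round. Names below are placeholders.
#include <cstdint>
#include <cstring>
static uint16_t BF16BitsFromF32Sketch(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);  // sign, exponent, top 7 mantissa bits
}
static float F32FromBF16BitsSketch(uint16_t b) {
  const uint32_t bits = uint32_t{b} << 16;   // lower mantissa bits become zero
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}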
+ +template +HWY_API Vec1 OddEven(Vec1 /* odd */, Vec1 even) { + return even; +} + +template +HWY_API Vec1 OddEvenBlocks(Vec1 /* odd */, Vec1 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec1 SwapAdjacentBlocks(Vec1 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template > +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template > +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices1 { + MakeSigned raw; +}; + +template , typename TI> +HWY_API Indices1 IndicesFromVec(D, Vec1 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size"); + HWY_DASSERT(vec.raw <= 1); + return Indices1{static_cast>(vec.raw)}; +} + +template , typename TI> +HWY_API Indices1 SetTableIndices(D d, const TI* idx) { + return IndicesFromVec(d, LoadU(Sisd(), idx)); +} + +template +HWY_API Vec1 TableLookupLanes(const Vec1 v, const Indices1 /* idx */) { + return v; +} + +template +HWY_API Vec1 TwoTablesLookupLanes(const Vec1 a, const Vec1 b, + const Indices1 idx) { + return (idx.raw == 0) ? a : b; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template > +HWY_API Vec1 ReverseBlocks(D /* tag */, const Vec1 v) { + return v; +} + +// ------------------------------ Reverse + +template > +HWY_API Vec1 Reverse(D /* tag */, const Vec1 v) { + return v; +} + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. +#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +// Must not be called: +template > +HWY_API Vec1 Reverse2(D /* tag */, const Vec1 v) { + return v; +} + +template > +HWY_API Vec1 Reverse4(D /* tag */, const Vec1 v) { + return v; +} + +template > +HWY_API Vec1 Reverse8(D /* tag */, const Vec1 v) { + return v; +} + +// ------------------------------ ReverseLaneBytes + +#ifdef HWY_NATIVE_REVERSE_LANE_BYTES +#undef HWY_NATIVE_REVERSE_LANE_BYTES +#else +#define HWY_NATIVE_REVERSE_LANE_BYTES +#endif + +HWY_API Vec1 ReverseLaneBytes(Vec1 v) { + const uint32_t val{v.raw}; + return Vec1( + static_cast(((val << 8) & 0xFF00u) | ((val >> 8) & 0x00FFu))); +} + +HWY_API Vec1 ReverseLaneBytes(Vec1 v) { + const uint32_t val = v.raw; + return Vec1(static_cast( + ((val << 24) & 0xFF000000u) | ((val << 8) & 0x00FF0000u) | + ((val >> 8) & 0x0000FF00u) | ((val >> 24) & 0x000000FFu))); +} + +HWY_API Vec1 ReverseLaneBytes(Vec1 v) { + const uint64_t val = v.raw; + return Vec1(static_cast( + ((val << 56) & 0xFF00000000000000u) | + ((val << 40) & 0x00FF000000000000u) | + ((val << 24) & 0x0000FF0000000000u) | ((val << 8) & 0x000000FF00000000u) | + ((val >> 8) & 0x00000000FF000000u) | ((val >> 24) & 0x0000000000FF0000u) | + ((val >> 40) & 0x000000000000FF00u) | + ((val >> 56) & 0x00000000000000FFu))); +} + +template +HWY_API V ReverseLaneBytes(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ReverseLaneBytes(BitCast(du, v))); +} + +// ------------------------------ ReverseBits +#ifdef HWY_NATIVE_REVERSE_BITS_UI8 +#undef HWY_NATIVE_REVERSE_BITS_UI8 +#else +#define HWY_NATIVE_REVERSE_BITS_UI8 +#endif + +#ifdef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#undef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#else +#define HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#endif + +namespace detail { + 
+template +HWY_INLINE T ReverseBitsOfEachByte(T val) { + using TU = MakeUnsigned; + constexpr TU kMaxUnsignedVal{LimitsMax()}; + constexpr TU kShrMask1 = + static_cast(0x5555555555555555u & kMaxUnsignedVal); + constexpr TU kShrMask2 = + static_cast(0x3333333333333333u & kMaxUnsignedVal); + constexpr TU kShrMask3 = + static_cast(0x0F0F0F0F0F0F0F0Fu & kMaxUnsignedVal); + + constexpr TU kShlMask1 = static_cast(~kShrMask1); + constexpr TU kShlMask2 = static_cast(~kShrMask2); + constexpr TU kShlMask3 = static_cast(~kShrMask3); + + TU result = static_cast(val); + result = static_cast(((result << 1) & kShlMask1) | + ((result >> 1) & kShrMask1)); + result = static_cast(((result << 2) & kShlMask2) | + ((result >> 2) & kShrMask2)); + result = static_cast(((result << 4) & kShlMask3) | + ((result >> 4) & kShrMask3)); + return static_cast(result); +} + +} // namespace detail + +template +HWY_API V ReverseBits(V v) { + return V(detail::ReverseBitsOfEachByte(v.raw)); +} + +template +HWY_API V ReverseBits(V v) { + return ReverseLaneBytes(V(detail::ReverseBitsOfEachByte(v.raw))); +} + +template +HWY_API V ReverseBits(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ReverseBits(BitCast(du, v))); +} + +// ------------------------------ SlideUpLanes + +template +HWY_API VFromD SlideUpLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +// ------------------------------ SlideDownLanes + +template +HWY_API VFromD SlideDownLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +// ================================================== BLOCKWISE +// Shift*Bytes, CombineShiftRightBytes, Interleave*, Shuffle* are unsupported. + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec1 Broadcast(const Vec1 v) { + static_assert(kLane == 0, "Scalar only has one lane"); + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template +HWY_API Vec1 TableLookupBytes(const Vec1 in, const Vec1 indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); // copy to bytes + CopyBytes(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +template +HWY_API Vec1 TableLookupBytesOr0(const Vec1 in, const Vec1 indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); // copy to bytes + CopyBytes(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = idx_bytes[i] & 0x80 ? 
0 : in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +// ------------------------------ ZipLower + +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1(static_cast((uint32_t{b.raw} << 8) + a.raw)); +} +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1((uint32_t{b.raw} << 16) + a.raw); +} +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1((uint64_t{b.raw} << 32) + a.raw); +} +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1(static_cast((int32_t{b.raw} << 8) + a.raw)); +} +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1((int32_t{b.raw} << 16) + a.raw); +} +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1((int64_t{b.raw} << 32) + a.raw); +} + +template , typename TN = MakeNarrow> +HWY_API Vec1 ZipLower(DW /* tag */, Vec1 a, Vec1 b) { + return Vec1(static_cast((TW{b.raw} << (sizeof(TN) * 8)) + a.raw)); +} + +// ================================================== MASK + +template > +HWY_API bool AllFalse(D /* tag */, const Mask1 mask) { + return mask.bits == 0; +} + +template > +HWY_API bool AllTrue(D /* tag */, const Mask1 mask) { + return mask.bits != 0; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template > +HWY_API Mask1 LoadMaskBits(D /* tag */, const uint8_t* HWY_RESTRICT bits) { + return Mask1::FromBool((bits[0] & 1) != 0); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D /*d*/, unsigned mask_bits) { + return MFromD::FromBool((mask_bits & 1) != 0); +} + +// `p` points to at least 8 writable bytes. +template > +HWY_API size_t StoreMaskBits(D d, const Mask1 mask, uint8_t* bits) { + *bits = AllTrue(d, mask); + return 1; +} + +template > +HWY_API size_t CountTrue(D /* tag */, const Mask1 mask) { + return mask.bits == 0 ? 0 : 1; +} + +template > +HWY_API intptr_t FindFirstTrue(D /* tag */, const Mask1 mask) { + return mask.bits == 0 ? -1 : 0; +} + +template > +HWY_API size_t FindKnownFirstTrue(D /* tag */, const Mask1 /* m */) { + return 0; // There is only one lane and we know it is true. +} + +template > +HWY_API intptr_t FindLastTrue(D /* tag */, const Mask1 mask) { + return mask.bits == 0 ? -1 : 0; +} + +template > +HWY_API size_t FindKnownLastTrue(D /* tag */, const Mask1 /* m */) { + return 0; // There is only one lane and we know it is true. +} + +// ------------------------------ Compress, CompressBits + +template +struct CompressIsPartition { + enum { value = 1 }; +}; + +template +HWY_API Vec1 Compress(Vec1 v, const Mask1 /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +template +HWY_API Vec1 CompressNot(Vec1 v, const Mask1 /* mask */) { + // A single lane is already partitioned by definition. 
+ return v; +} + +// ------------------------------ CompressStore +template > +HWY_API size_t CompressStore(Vec1 v, const Mask1 mask, D d, + T* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template > +HWY_API size_t CompressBlendedStore(Vec1 v, const Mask1 mask, D d, + T* HWY_RESTRICT unaligned) { + if (!mask.bits) return 0; + StoreU(v, d, unaligned); + return 1; +} + +// ------------------------------ CompressBits +template +HWY_API Vec1 CompressBits(Vec1 v, const uint8_t* HWY_RESTRICT /*bits*/) { + return v; +} + +// ------------------------------ CompressBitsStore +template > +HWY_API size_t CompressBitsStore(Vec1 v, const uint8_t* HWY_RESTRICT bits, + D d, T* HWY_RESTRICT unaligned) { + const Mask1 mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ Expand + +// generic_ops-inl.h requires Vec64/128, so implement [Load]Expand here. +#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +template +HWY_API Vec1 Expand(Vec1 v, const Mask1 mask) { + return IfThenElseZero(mask, v); +} + +// ------------------------------ LoadExpand +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return MaskedLoad(mask, d, unaligned); +} + +// ------------------------------ WidenMulPairwiseAdd + +template +HWY_API Vec1 WidenMulPairwiseAdd(D32 /* tag */, Vec1 a, + Vec1 b) { + return Vec1(F32FromBF16(a.raw)) * Vec1(F32FromBF16(b.raw)); +} + +template +HWY_API Vec1 WidenMulPairwiseAdd(D32 /* tag */, Vec1 a, + Vec1 b) { + return Vec1(a.raw * b.raw); +} + +// ------------------------------ SatWidenMulAccumFixedPoint +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#endif + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + // Multiplying static_cast(a.raw) by static_cast(b.raw) + // followed by an addition of the product is okay as + // (a.raw * b.raw * 2) is between -2147418112 and 2147483648 and as + // a.raw * b.raw * 2 can only overflow an int32_t if both a.raw and b.raw are + // equal to -32768. + + const VFromD product(static_cast(a.raw) * + static_cast(b.raw)); + const VFromD product2 = Add(product, product); + + const auto mul_overflow = + VecFromMask(di32, Eq(product2, Set(di32, LimitsMin()))); + + return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)), + Add(product2, mul_overflow)); +} + +// ------------------------------ SatWidenMulPairwiseAdd + +#ifdef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#undef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#else +#define HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#endif + +template +HWY_API Vec1 SatWidenMulPairwiseAdd(DI16 /* tag */, Vec1 a, + Vec1 b) { + // Saturation of a.raw * b.raw is not needed on the HWY_SCALAR target as the + // input vectors only have 1 lane on the HWY_SCALAR target and as + // a.raw * b.raw is between -32640 and 32385, which is already within the + // range of an int16_t. + + // On other targets, a saturated addition of a[0]*b[0] + a[1]*b[1] is needed + // as it is possible for the addition of a[0]*b[0] + a[1]*b[1] to overflow if + // a[0], a[1], b[0], and b[1] are all non-zero and b[0] and b[1] both have the + // same sign. 
+ + return Vec1(static_cast(a.raw) * + static_cast(b.raw)); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + +template +HWY_API Vec1 ReorderWidenMulAccumulate(D32 /* tag */, Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return MulAdd(Vec1(F32FromBF16(a.raw)), + Vec1(F32FromBF16(b.raw)), sum0); +} + +template +HWY_API Vec1 ReorderWidenMulAccumulate(D32 /* tag */, Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return Vec1(a.raw * b.raw + sum0.raw); +} + +template +HWY_API Vec1 ReorderWidenMulAccumulate(DU32 /* tag */, + Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return Vec1(static_cast(a.raw) * b.raw + sum0.raw); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec1 RearrangeToOddPlusEven(Vec1 sum0, Vec1 /* sum1 */) { + return sum0; // invariant already holds +} + +// ================================================== REDUCTIONS + +// Nothing native, generic_ops-inl defines SumOfLanes and ReduceSum. + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/aom/third_party/highway/hwy/ops/set_macros-inl.h b/third_party/aom/third_party/highway/hwy/ops/set_macros-inl.h new file mode 100644 index 000000000000..2cadeb8ad766 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/set_macros-inl.h @@ -0,0 +1,821 @@ +// Copyright 2020 Google LLC +// Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Sets macros based on HWY_TARGET. + +// This include guard is toggled by foreach_target, so avoid the usual _H_ +// suffix to prevent copybara from renaming it. 
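// Illustrative sketch of the toggled include guard described above: each time
// foreach_target.h re-includes the header, the guard macro is flipped, so the
// equality test passes on every inclusion rather than only the first. Macro
// names below are placeholders, not Highway's.
#if defined(MY_PER_TARGET_GUARD) == defined(MY_TARGET_TOGGLE)
#ifdef MY_PER_TARGET_GUARD
#undef MY_PER_TARGET_GUARD
#else
#define MY_PER_TARGET_GUARD
#endif
// ...per-target definitions would go here...
#endif  // toggle of MY_PER_TARGET_GUARD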
+#if defined(HWY_SET_MACROS_PER_TARGET) == defined(HWY_TARGET_TOGGLE) +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif + +#endif // HWY_SET_MACROS_PER_TARGET + +#include "third_party/highway/hwy/detect_compiler_arch.h" // IWYU: export +#include "third_party/highway/hwy/detect_targets.h" // IWYU: export + +#undef HWY_NAMESPACE +#undef HWY_ALIGN +#undef HWY_MAX_BYTES +#undef HWY_LANES + +#undef HWY_HAVE_SCALABLE +#undef HWY_HAVE_TUPLE +#undef HWY_HAVE_INTEGER64 +#undef HWY_HAVE_FLOAT16 +#undef HWY_HAVE_FLOAT64 +#undef HWY_MEM_OPS_MIGHT_FAULT +#undef HWY_NATIVE_FMA +#undef HWY_NATIVE_DOT_BF16 +#undef HWY_CAP_GE256 +#undef HWY_CAP_GE512 + +#undef HWY_TARGET_IS_SVE +#if HWY_TARGET & HWY_ALL_SVE +#define HWY_TARGET_IS_SVE 1 +#else +#define HWY_TARGET_IS_SVE 0 +#endif + +#undef HWY_TARGET_IS_NEON +#if HWY_TARGET & HWY_ALL_NEON +#define HWY_TARGET_IS_NEON 1 +#else +#define HWY_TARGET_IS_NEON 0 +#endif + +#undef HWY_TARGET_IS_PPC +#if HWY_TARGET & HWY_ALL_PPC +#define HWY_TARGET_IS_PPC 1 +#else +#define HWY_TARGET_IS_PPC 0 +#endif + +#undef HWY_TARGET_IS_AVX10_2 +#if HWY_TARGET == HWY_AVX10_2 || HWY_TARGET == HWY_AVX10_2_512 +#define HWY_TARGET_IS_AVX10_2 1 +#else +#define HWY_TARGET_IS_AVX10_2 0 +#endif + +// Supported on all targets except RVV (requires GCC 14 or upcoming Clang) +#if HWY_TARGET == HWY_RVV && \ + ((HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1700)) +#define HWY_HAVE_TUPLE 0 +#else +#define HWY_HAVE_TUPLE 1 +#endif + +// For internal use (clamping/validating N for Simd<>) +#undef HWY_MAX_N +#if HWY_TARGET == HWY_SCALAR +#define HWY_MAX_N 1 +#else +#define HWY_MAX_N 65536 +#endif + +// For internal use (clamping kPow2 for Simd<>) +#undef HWY_MAX_POW2 +// For HWY_TARGET == HWY_RVV, LMUL <= 8. Even on other targets, we want to +// support say Rebind> d; whose kPow2 is also 3. +// However, those other targets do not actually support multiple vectors, and +// thus Lanes(d) must not exceed Lanes(ScalableTag()). +#define HWY_MAX_POW2 3 + +// User-visible. Loose lower bound that guarantees HWY_MAX_BYTES >> +// (-HWY_MIN_POW2) <= 1. Useful for terminating compile-time recursions. +#undef HWY_MIN_POW2 +#if HWY_TARGET == HWY_RVV +#define HWY_MIN_POW2 -16 +#else +// Tighter bound for other targets, whose vectors are smaller, to potentially +// save compile time. +#define HWY_MIN_POW2 -8 +#endif // HWY_TARGET == HWY_RVV + +#undef HWY_TARGET_STR + +#if defined(HWY_DISABLE_PCLMUL_AES) +#define HWY_TARGET_STR_PCLMUL_AES "" +#else +#define HWY_TARGET_STR_PCLMUL_AES ",pclmul,aes" +#endif + +#if defined(HWY_DISABLE_BMI2_FMA) +#define HWY_TARGET_STR_BMI2_FMA "" +#else +#define HWY_TARGET_STR_BMI2_FMA ",bmi,bmi2,fma" +#endif + +#if defined(HWY_DISABLE_F16C) +#define HWY_TARGET_STR_F16C "" +#else +#define HWY_TARGET_STR_F16C ",f16c" +#endif + +#define HWY_TARGET_STR_SSE2 "sse2" + +#define HWY_TARGET_STR_SSSE3 "sse2,ssse3" + +#define HWY_TARGET_STR_SSE4 \ + HWY_TARGET_STR_SSSE3 ",sse4.1,sse4.2" HWY_TARGET_STR_PCLMUL_AES +// Include previous targets, which are the half-vectors of the next target. 
+#define HWY_TARGET_STR_AVX2 \ + HWY_TARGET_STR_SSE4 ",avx,avx2" HWY_TARGET_STR_BMI2_FMA HWY_TARGET_STR_F16C + +#if HWY_COMPILER_GCC_ACTUAL >= 1400 || HWY_COMPILER_CLANG >= 1800 +#define HWY_TARGET_STR_AVX3_VL512 ",evex512" +#else +#define HWY_TARGET_STR_AVX3_VL512 +#endif + +#define HWY_TARGET_STR_AVX3_256 \ + HWY_TARGET_STR_AVX2 \ + ",avx512f,avx512cd,avx512vl,avx512dq,avx512bw" HWY_TARGET_STR_AVX3_VL512 + +#define HWY_TARGET_STR_AVX3 HWY_TARGET_STR_AVX3_256 HWY_TARGET_STR_AVX3_VL512 + +#define HWY_TARGET_STR_AVX3_DL_256 \ + HWY_TARGET_STR_AVX3_256 \ + ",vpclmulqdq,avx512vbmi,avx512vbmi2,vaes,avx512vnni,avx512bitalg," \ + "avx512vpopcntdq,gfni" + +#define HWY_TARGET_STR_AVX3_DL \ + HWY_TARGET_STR_AVX3_DL_256 HWY_TARGET_STR_AVX3_VL512 + +// Force-disable for compilers that do not properly support avx512bf16. +#if !defined(HWY_AVX3_DISABLE_AVX512BF16) && \ + (HWY_COMPILER_CLANGCL || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 900)) +#define HWY_AVX3_DISABLE_AVX512BF16 +#endif + +#if !defined(HWY_AVX3_DISABLE_AVX512BF16) +#define HWY_TARGET_STR_AVX3_ZEN4_256 HWY_TARGET_STR_AVX3_DL ",avx512bf16" +#else +#define HWY_TARGET_STR_AVX3_ZEN4_256 HWY_TARGET_STR_AVX3_DL +#endif + +#define HWY_TARGET_STR_AVX3_ZEN4 \ + HWY_TARGET_STR_AVX3_ZEN4_256 HWY_TARGET_STR_AVX3_VL512 + +#define HWY_TARGET_STR_AVX3_SPR_256 HWY_TARGET_STR_AVX3_ZEN4 ",avx512fp16" + +#define HWY_TARGET_STR_AVX3_SPR \ + HWY_TARGET_STR_AVX3_SPR_256 HWY_TARGET_STR_AVX3_VL512 + +#if HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2000 +#define HWY_TARGET_STR_AVX10_2 \ + HWY_TARGET_STR_AVX3_SPR_256 ",no-evex512,avx10.2-256" +#define HWY_TARGET_STR_AVX10_2_512 \ + HWY_TARGET_STR_AVX3_SPR ",avx10.2-256,avx10.2-512" +#else +#define HWY_TARGET_STR_AVX10_2 HWY_TARGET_STR_AVX3_SPR_256 ",no-evex512" +#define HWY_TARGET_STR_AVX10_2_512 HWY_TARGET_STR_AVX3_SPR +#endif + +#if defined(HWY_DISABLE_PPC8_CRYPTO) +#define HWY_TARGET_STR_PPC8_CRYPTO "" +#else +#define HWY_TARGET_STR_PPC8_CRYPTO ",crypto" +#endif + +#define HWY_TARGET_STR_PPC8 \ + "altivec,vsx,power8-vector" HWY_TARGET_STR_PPC8_CRYPTO +#define HWY_TARGET_STR_PPC9 HWY_TARGET_STR_PPC8 ",power9-vector" + +#if HWY_COMPILER_CLANG +#define HWY_TARGET_STR_PPC10 HWY_TARGET_STR_PPC9 ",power10-vector" +#else +// See #1707 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=102059#c35. +// When the baseline is PPC 8 or 9, inlining functions such as PreventElision +// into PPC10 code fails because PPC10 defaults to no-htm and is thus worse than +// the baseline, which has htm. We cannot have pragma target on functions +// outside HWY_NAMESPACE such as those in base.h. It would be possible for users +// to set -mno-htm globally, but we can also work around this at the library +// level by claiming that PPC10 still has HTM, thus avoiding the mismatch. This +// seems to be safe because HTM uses builtins rather than modifying codegen, see +// https://gcc.gnu.org/legacy-ml/gcc-patches/2013-07/msg00167.html. +#define HWY_TARGET_STR_PPC10 HWY_TARGET_STR_PPC9 ",cpu=power10,htm" +#endif + +#define HWY_TARGET_STR_Z14 "arch=z14" +#define HWY_TARGET_STR_Z15 "arch=z15" + +// Before include guard so we redefine HWY_TARGET_STR on each include, +// governed by the current HWY_TARGET. 
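// Illustrative sketch of how target strings like the ones composed above are
// built and consumed: adjacent string literals are concatenated during
// compilation into one comma-separated feature list, which can then be applied
// per function via a target attribute. The real plumbing is
// HWY_BEFORE_NAMESPACE/HWY_ATTR; the macro and function below are hypothetical.
#define SKETCH_STR_SSE4 "sse2,ssse3" ",sse4.1,sse4.2"  // -> "sse2,ssse3,sse4.1,sse4.2"
#if defined(__GNUC__) || defined(__clang__)
__attribute__((target(SKETCH_STR_SSE4)))
#endif
static int SumSketch(const int* v, int n) {
  int sum = 0;
  for (int i = 0; i < n; ++i) sum += v[i];  // may be auto-vectorized for SSE4
  return sum;
}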
+ +//----------------------------------------------------------------------------- +// SSE2 +#if HWY_TARGET == HWY_SSE2 + +#define HWY_NAMESPACE N_SSE2 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_SSE2 +//----------------------------------------------------------------------------- +// SSSE3 +#elif HWY_TARGET == HWY_SSSE3 + +#define HWY_NAMESPACE N_SSSE3 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_SSSE3 + +//----------------------------------------------------------------------------- +// SSE4 +#elif HWY_TARGET == HWY_SSE4 + +#define HWY_NAMESPACE N_SSE4 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_SSE4 + +//----------------------------------------------------------------------------- +// AVX2 +#elif HWY_TARGET == HWY_AVX2 + +#define HWY_NAMESPACE N_AVX2 +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_LANES(T) (32 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 + +#ifdef HWY_DISABLE_BMI2_FMA +#define HWY_NATIVE_FMA 0 +#else +#define HWY_NATIVE_FMA 1 +#endif +#define HWY_NATIVE_DOT_BF16 0 + +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_AVX2 + +//----------------------------------------------------------------------------- +// AVX3[_DL]/AVX10 +#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL || \ + HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX3_SPR || \ + HWY_TARGET == HWY_AVX10_2 || HWY_TARGET == HWY_AVX10_2_512 + +#if HWY_TARGET == HWY_AVX10_2 +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_LANES(T) (32 / sizeof(T)) +#else +#define HWY_ALIGN alignas(64) +#define HWY_MAX_BYTES 64 +#define HWY_LANES(T) (64 / sizeof(T)) +#endif + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#if HWY_TARGET <= HWY_AVX10_2 && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1901) && \ + HWY_HAVE_SCALAR_F16_TYPE +#define HWY_HAVE_FLOAT16 1 +#else +#define HWY_HAVE_FLOAT16 0 +#endif +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#if (HWY_TARGET <= HWY_AVX3_ZEN4) && !defined(HWY_AVX3_DISABLE_AVX512BF16) +#define HWY_NATIVE_DOT_BF16 1 +#else +#define HWY_NATIVE_DOT_BF16 0 +#endif +#define HWY_CAP_GE256 1 + +#if HWY_MAX_BYTES >= 64 +#define HWY_CAP_GE512 1 +#else +#define HWY_CAP_GE512 0 +#endif + +#if HWY_TARGET == HWY_AVX3 + +#define HWY_NAMESPACE N_AVX3 +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3 + 
+#elif HWY_TARGET == HWY_AVX3_DL + +#define HWY_NAMESPACE N_AVX3_DL +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3_DL + +#elif HWY_TARGET == HWY_AVX3_ZEN4 + +#define HWY_NAMESPACE N_AVX3_ZEN4 +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3_ZEN4 + +#elif HWY_TARGET == HWY_AVX3_SPR + +#define HWY_NAMESPACE N_AVX3_SPR +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3_SPR + +#elif HWY_TARGET == HWY_AVX10_2 + +#define HWY_NAMESPACE N_AVX10_2 +#define HWY_TARGET_STR HWY_TARGET_STR_AVX10_2 + +#elif HWY_TARGET == HWY_AVX10_2_512 + +#define HWY_NAMESPACE N_AVX10_2_512 +#define HWY_TARGET_STR HWY_TARGET_STR_AVX10_2_512 + +#else +#error "Logic error" +#endif // HWY_TARGET + +//----------------------------------------------------------------------------- +// PPC8, PPC9, PPC10 +#elif HWY_TARGET_IS_PPC + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_TARGET == HWY_PPC8 + +#define HWY_NAMESPACE N_PPC8 +#define HWY_TARGET_STR HWY_TARGET_STR_PPC8 + +#elif HWY_TARGET == HWY_PPC9 + +#define HWY_NAMESPACE N_PPC9 +#define HWY_TARGET_STR HWY_TARGET_STR_PPC9 + +#elif HWY_TARGET == HWY_PPC10 + +#define HWY_NAMESPACE N_PPC10 +#define HWY_TARGET_STR HWY_TARGET_STR_PPC10 + +#else +#error "Logic error" +#endif // HWY_TARGET + +//----------------------------------------------------------------------------- +// Z14, Z15 +#elif HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_TARGET == HWY_Z14 + +#define HWY_NAMESPACE N_Z14 +#define HWY_TARGET_STR HWY_TARGET_STR_Z14 + +#elif HWY_TARGET == HWY_Z15 + +#define HWY_NAMESPACE N_Z15 +#define HWY_TARGET_STR HWY_TARGET_STR_Z15 + +#else +#error "Logic error" +#endif // HWY_TARGET == HWY_Z15 + +//----------------------------------------------------------------------------- +// NEON +#elif HWY_TARGET_IS_NEON + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || HWY_TARGET == HWY_NEON_BF16 +#define HWY_HAVE_FLOAT16 1 +#else +#define HWY_HAVE_FLOAT16 0 +#endif + +#if HWY_ARCH_ARM_A64 +#define HWY_HAVE_FLOAT64 1 +#else +#define HWY_HAVE_FLOAT64 0 +#endif + +#define HWY_MEM_OPS_MIGHT_FAULT 1 + +#if defined(__ARM_FEATURE_FMA) || defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +#define HWY_NATIVE_FMA 1 +#else +#define HWY_NATIVE_FMA 0 +#endif +#if HWY_NEON_HAVE_F32_TO_BF16C || HWY_TARGET == HWY_NEON_BF16 +#define HWY_NATIVE_DOT_BF16 1 +#else +#define HWY_NATIVE_DOT_BF16 0 +#endif + +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_TARGET == HWY_NEON_WITHOUT_AES +#define HWY_NAMESPACE N_NEON_WITHOUT_AES +#elif HWY_TARGET == HWY_NEON +#define HWY_NAMESPACE N_NEON +#elif HWY_TARGET == HWY_NEON_BF16 +#define HWY_NAMESPACE N_NEON_BF16 +#else +#error "Logic error, missing case" +#endif // HWY_TARGET + +// Can use pragmas instead of -march compiler flag +#if 
HWY_HAVE_RUNTIME_DISPATCH +#if HWY_ARCH_ARM_V7 + +// The __attribute__((target(+neon-vfpv4)) was introduced in gcc >= 8. +#if HWY_COMPILER_GCC_ACTUAL >= 800 +#define HWY_TARGET_STR "+neon-vfpv4" +#else // GCC < 7 +// Do not define HWY_TARGET_STR (no pragma). +#endif // HWY_COMPILER_GCC_ACTUAL + +#else // !HWY_ARCH_ARM_V7 + +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1300) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1300) +// GCC 12 or earlier and Clang 12 or earlier require +crypto be added to the +// target string to enable AArch64 AES intrinsics +#define HWY_TARGET_STR_NEON "+crypto" +#else +#define HWY_TARGET_STR_NEON "+aes" +#endif + +// Clang >= 16 requires +fullfp16 instead of fp16, but Apple Clang 15 = 1600 +// fails to parse unless the string starts with armv8, whereas 1700 refuses it. +#if HWY_COMPILER_CLANG >= 1700 +#define HWY_TARGET_STR_FP16 "+fullfp16" +#elif HWY_COMPILER_CLANG >= 1600 && defined(__apple_build_version__) +#define HWY_TARGET_STR_FP16 "armv8.4-a+fullfp16" +#else +#define HWY_TARGET_STR_FP16 "+fp16" +#endif + +#if HWY_TARGET == HWY_NEON_WITHOUT_AES +// Do not define HWY_TARGET_STR (no pragma). +#elif HWY_TARGET == HWY_NEON +#define HWY_TARGET_STR HWY_TARGET_STR_NEON +#elif HWY_TARGET == HWY_NEON_BF16 +#define HWY_TARGET_STR HWY_TARGET_STR_FP16 "+bf16+dotprod" HWY_TARGET_STR_NEON +#else +#error "Logic error, missing case" +#endif // HWY_TARGET + +#endif // !HWY_ARCH_ARM_V7 +#else // !HWY_HAVE_RUNTIME_DISPATCH +// HWY_TARGET_STR remains undefined +#endif + +//----------------------------------------------------------------------------- +// SVE[2] +#elif HWY_TARGET_IS_SVE + +// SVE only requires lane alignment, not natural alignment of the entire vector. +#define HWY_ALIGN alignas(8) + +// Value ensures MaxLanes() is the tightest possible upper bound to reduce +// overallocation. +#define HWY_LANES(T) ((HWY_MAX_BYTES) / sizeof(T)) + +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#if HWY_SVE_HAVE_BF16_FEATURE +#define HWY_NATIVE_DOT_BF16 1 +#else +#define HWY_NATIVE_DOT_BF16 0 +#endif +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_TARGET == HWY_SVE2 +#define HWY_NAMESPACE N_SVE2 +#define HWY_MAX_BYTES 256 +#define HWY_HAVE_SCALABLE 1 +#elif HWY_TARGET == HWY_SVE_256 +#define HWY_NAMESPACE N_SVE_256 +#define HWY_MAX_BYTES 32 +#define HWY_HAVE_SCALABLE 0 +#elif HWY_TARGET == HWY_SVE2_128 +#define HWY_NAMESPACE N_SVE2_128 +#define HWY_MAX_BYTES 16 +#define HWY_HAVE_SCALABLE 0 +#else +#define HWY_NAMESPACE N_SVE +#define HWY_MAX_BYTES 256 +#define HWY_HAVE_SCALABLE 1 +#endif + +// Can use pragmas instead of -march compiler flag +#if HWY_HAVE_RUNTIME_DISPATCH +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +// Static dispatch with -march=armv8-a+sve2+aes, or no baseline, hence dynamic +// dispatch, which checks for AES support at runtime. 
+#if defined(__ARM_FEATURE_SVE2_AES) || (HWY_BASELINE_SVE2 == 0) +#define HWY_TARGET_STR "+sve2+sve2-aes,+sve" +#else // SVE2 without AES +#define HWY_TARGET_STR "+sve2,+sve" +#endif +#else // not SVE2 target +#define HWY_TARGET_STR "+sve" +#endif +#else // !HWY_HAVE_RUNTIME_DISPATCH +// HWY_TARGET_STR remains undefined +#endif + +//----------------------------------------------------------------------------- +// WASM +#elif HWY_TARGET == HWY_WASM + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_WASM + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// WASM_EMU256 +#elif HWY_TARGET == HWY_WASM_EMU256 + +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_LANES(T) (32 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_WASM_EMU256 + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// RVV +#elif HWY_TARGET == HWY_RVV + +// RVV only requires lane alignment, not natural alignment of the entire vector, +// and the compiler already aligns builtin types, so nothing to do here. +#define HWY_ALIGN + +// The spec requires VLEN <= 2^16 bits, so the limit is 2^16 bytes (LMUL=8). +#define HWY_MAX_BYTES 65536 + +// = HWY_MAX_BYTES divided by max LMUL=8 because MaxLanes includes the actual +// LMUL. This is the tightest possible upper bound. +#define HWY_LANES(T) (8192 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 1 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_RVV_HAVE_F16_VEC +#define HWY_HAVE_FLOAT16 1 +#else +#define HWY_HAVE_FLOAT16 0 +#endif + +#define HWY_NAMESPACE N_RVV + +#if HWY_COMPILER_CLANG >= 1900 +// https://github.com/riscv/riscv-v-spec/blob/master/v-spec.adoc#181-zvl-minimum-vector-length-standard-extensions +#define HWY_TARGET_STR "Zvl128b,Zve64d" +#else +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. +#endif + +//----------------------------------------------------------------------------- +// LSX/LASX +#elif HWY_TARGET == HWY_LSX || HWY_TARGET == HWY_LASX + +#if HWY_TARGET == HWY_LSX +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#else +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#endif + +#define HWY_LANES(T) (HWY_MAX_BYTES / sizeof(T)) + +// TODO: check flag values +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_TARGET == HWY_LSX +#define HWY_NAMESPACE N_LSX +#else +#define HWY_NAMESPACE N_LASX +#endif + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. 
+ +//----------------------------------------------------------------------------- +// EMU128 +#elif HWY_TARGET == HWY_EMU128 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_EMU128 + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +//----------------------------------------------------------------------------- +// SCALAR +#elif HWY_TARGET == HWY_SCALAR + +#define HWY_ALIGN +#define HWY_MAX_BYTES 8 +#define HWY_LANES(T) 1 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_SCALAR + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +//----------------------------------------------------------------------------- + +// Sanity check: if we have f16 vector support, then base.h should also be +// using a built-in type for f16 scalars. +#if HWY_HAVE_FLOAT16 && !HWY_HAVE_SCALAR_F16_TYPE +#error "Logic error: f16 vectors but no scalars" +#endif + +// Override this to 1 in asan/msan builds, which will still fault. +#if HWY_IS_ASAN || HWY_IS_MSAN +#undef HWY_MEM_OPS_MIGHT_FAULT +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#endif + +// Clang <9 requires this be invoked at file scope, before any namespace. +#undef HWY_BEFORE_NAMESPACE +#if defined(HWY_TARGET_STR) +#define HWY_BEFORE_NAMESPACE() \ + HWY_PUSH_ATTRIBUTES(HWY_TARGET_STR) \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_BEFORE_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +// Clang <9 requires any namespaces be closed before this macro. +#undef HWY_AFTER_NAMESPACE +#if defined(HWY_TARGET_STR) +#define HWY_AFTER_NAMESPACE() \ + HWY_POP_ATTRIBUTES \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_AFTER_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +#undef HWY_ATTR +#if defined(HWY_TARGET_STR) && HWY_HAS_ATTRIBUTE(target) +#define HWY_ATTR __attribute__((target(HWY_TARGET_STR))) +#else +#define HWY_ATTR +#endif diff --git a/third_party/aom/third_party/highway/hwy/ops/shared-inl.h b/third_party/aom/third_party/highway/hwy/ops/shared-inl.h new file mode 100644 index 000000000000..95e339995455 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/shared-inl.h @@ -0,0 +1,720 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target definitions shared by ops/*.h and user code. + +// IWYU pragma: begin_exports +// Export does not seem to be recursive, so re-export these (also in base.h) +#include + +#include "third_party/highway/hwy/base.h" +// "IWYU pragma: keep" does not work for this include, so hide it from the IDE. +#if !HWY_IDE +#include +#endif + +#include "third_party/highway/hwy/detect_compiler_arch.h" +#include "third_party/highway/hwy/detect_targets.h" + +// Separate header because foreach_target.h re-enables its include guard. +#include "third_party/highway/hwy/ops/set_macros-inl.h" + +// IWYU pragma: end_exports + +#if HWY_IS_MSAN +#include +#endif + +// We are covered by the highway.h include guard, but generic_ops-inl.h +// includes this again #if HWY_IDE. +// clang-format off +#if defined(HIGHWAY_HWY_OPS_SHARED_TOGGLE) == defined(HWY_TARGET_TOGGLE) // NOLINT +// clang-format on +#ifdef HIGHWAY_HWY_OPS_SHARED_TOGGLE +#undef HIGHWAY_HWY_OPS_SHARED_TOGGLE +#else +#define HIGHWAY_HWY_OPS_SHARED_TOGGLE +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// NOTE: GCC generates incorrect code for vector arguments to non-inlined +// functions in two situations: +// - on Windows and GCC 10.3, passing by value crashes due to unaligned loads: +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412. +// - on aarch64 and GCC 9.3.0 or 11.2.1, passing by value causes many (but not +// all) tests to fail. +// +// We therefore pass by const& only on GCC and (Windows or aarch64). This alias +// must be used for all vector/mask parameters of functions marked HWY_NOINLINE, +// and possibly also other functions that are not inlined. +// +// Even better is to avoid passing vector arguments to non-inlined functions, +// because the SVE and RISC-V ABIs are still works in progress and may lead to +// incorrect codegen. +#if HWY_COMPILER_GCC_ACTUAL && (HWY_OS_WIN || HWY_ARCH_ARM_A64) +template +using VecArg = const V&; +#else +template +using VecArg = V; +#endif + +namespace detail { + +template +struct NativeLaneTypeT { + using type = T; +}; +template <> +struct NativeLaneTypeT { +#if HWY_HAVE_SCALAR_F16_TYPE + using type = hwy::float16_t::Native; +#else + using type = uint16_t; +#endif +}; +template <> +struct NativeLaneTypeT { +#if HWY_HAVE_SCALAR_BF16_TYPE + using type = hwy::bfloat16_t::Native; +#else + using type = uint16_t; +#endif +}; + +// The type expected by intrinsics for the given Highway lane type T. This +// usually matches T, but differs for our wrapper types [b]float16_t. Use this +// only when defining intrinsic wrappers, and NOT for casting, which is UB. +template +using NativeLaneType = typename NativeLaneTypeT::type; + +// Returns the same pointer after changing type to NativeLaneType. Use this only +// for wrapper functions that call intrinsics (e.g. load/store) where some of +// the overloads expect _Float16* or __bf16* arguments. For non-special floats, +// this returns the same pointer and type. +// +// This makes use of the fact that a wrapper struct is pointer-interconvertible +// with its first member (a union), thus also with the union members. Do NOT +// call both this and U16LanePointer on the same object - they access different +// union members, and this is not guaranteed to be safe. 
+template +HWY_INLINE T* NativeLanePointer(T* p) { + return p; +} +template >, + HWY_IF_F16(T)> +HWY_INLINE constexpr If(), const NT*, NT*> NativeLanePointer(T* p) { +#if HWY_HAVE_SCALAR_F16_TYPE + return &p->native; +#else + return &p->bits; +#endif +} +template >, + HWY_IF_BF16(T)> +HWY_INLINE constexpr If(), const NT*, NT*> NativeLanePointer(T* p) { +#if HWY_HAVE_SCALAR_BF16_TYPE + return &p->native; +#else + return &p->bits; +#endif +} + +// Returns a pointer to the u16 member of our [b]float16_t wrapper structs. +// Use this in Highway targets that lack __bf16 intrinsics; for storing to +// memory, we BitCast vectors to u16 and write to the pointer returned here. +// Do NOT call both this and U16LanePointer on the same object - they access +// different union members, and this is not guaranteed to be safe. +template +HWY_INLINE If(), const uint16_t*, uint16_t*> U16LanePointer(T* p) { + return &p->bits; +} + +// Returns N * 2^pow2. N is the number of lanes in a full vector and pow2 the +// desired fraction or multiple of it, see Simd<>. `pow2` is most often in +// [-3, 3] but can also be lower for user-specified fractions. +constexpr size_t ScaleByPower(size_t N, int pow2) { + return pow2 >= 0 ? (N << pow2) : (N >> (-pow2)); +} + +template +HWY_INLINE void MaybePoison(T* HWY_RESTRICT unaligned, size_t count) { +#if HWY_IS_MSAN + __msan_poison(unaligned, count * sizeof(T)); +#else + (void)unaligned; + (void)count; +#endif +} + +template +HWY_INLINE void MaybeUnpoison(T* HWY_RESTRICT unaligned, size_t count) { + // Workaround for MSAN not marking compressstore as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#else + (void)unaligned; + (void)count; +#endif +} + +} // namespace detail + +// Highway operations are implemented as overloaded functions selected using a +// zero-sized tag type D := Simd. T denotes the lane type. +// +// N defines how many lanes are in a 'full' vector, typically equal to +// HWY_LANES(T) (which is the actual count on targets with vectors of known +// size, and an upper bound in case of scalable vectors), otherwise a +// user-specified limit at most that large. +// +// 2^kPow2 is a _subsequently_ applied scaling factor that indicates the +// desired fraction of a 'full' vector: 0 means full, -1 means half; 1,2,3 +// means two/four/eight full vectors ganged together. The largest supported +// kPow2 is `HWY_MAX_POW2` and the aliases below take care of clamping +// user-specified values to that. Note that `Simd` and `Simd` +// have the same `MaxLanes` and `Lanes`. +// +// We can theoretically keep halving Lanes(), but recursive instantiations of +// kPow2 - 1 will eventually fail e.g. because -64 is not a valid shift count. +// Users must terminate such compile-time recursions at or above HWY_MIN_POW2. +// +// WARNING: do not use N directly because it may be a special representation of +// a fractional MaxLanes. This arises when we Rebind Simd to +// Simd. RVV requires that the last argument (kPow2) be two, +// but we want MaxLanes to be the same in both cases. Hence ?? is a +// fixed-point encoding of 1/4. +// +// Instead of referring to Simd<> directly, users create D via aliases: +// - ScalableTag for a full vector; +// - ScalableTag() for a fraction/group, where `kPow2` is +// interpreted as `HWY_MIN(kPow2, HWY_MAX_POW2)`; +// - CappedTag for a vector with up to kLimit lanes; or +// - FixedTag for a vector with exactly kNumLanes lanes. 
+// +// Instead of N, use Lanes(D()) for the actual number of lanes at runtime and +// D().MaxLanes() for a constexpr upper bound. Both are powers of two. +template +struct Simd { + constexpr Simd() = default; + using T = Lane; + + private: + static_assert(sizeof(Lane) <= 8, "Lanes are up to 64-bit"); + static_assert(IsSame>(), + "Lane must not be a reference type, const-qualified type, or " + "volatile-qualified type"); + static_assert(IsIntegerLaneType() || IsFloat() || + IsSpecialFloat(), + "IsIntegerLaneType(), IsFloat(), or IsSpecialFloat() " + "must be true"); + // 20 bits are sufficient for any HWY_MAX_BYTES. This is the 'normal' value of + // N when kFrac == 0, otherwise it is one (see FracN). + static constexpr size_t kWhole = N & 0xFFFFF; + // Fractional part is in the bits above kWhole. + static constexpr int kFrac = static_cast(N >> 20); + // Can be 8x larger because kPow2 may be as low as -3 (Rebind of a larger + // type to u8 results in fractions). + static_assert(kWhole <= 8 * HWY_MAX_N && kFrac <= 3, "Out of range"); + static_assert(kFrac == 0 || kWhole == 1, "If frac, whole must be 1"); + static_assert((kWhole & (kWhole - 1)) == 0 && kWhole != 0, "Not 2^x"); + // Important to check this here because kPow2 <= -64 causes confusing + // compile errors (invalid shift count). + static_assert(kPow2 >= HWY_MIN_POW2, "Forgot kPow2 recursion terminator?"); + // However, do NOT verify kPow2 <= HWY_MAX_POW2 - users should be able to + // Rebind> in order to discover that its + // kPow2 is out of bounds. + + public: + // Upper bound on the number of lanes (tight if !HWY_HAVE_SCALABLE). In the + // common case, N == kWhole, but if kFrac is nonzero, we deduct it from kPow2. + // E.g. Rebind> is Simd. + // The resulting number of lanes is still 1 because this N represents 1/4 + // (the ratio of the sizes). Note that RVV requires kPow2 to be the ratio of + // the sizes so that the correct LMUL overloads are chosen, even if N is + // small enough that it would fit in an LMUL=1 vector. + // + // Cannot be an enum because GCC warns when using enums and non-enums in the + // same expression. Cannot be a static constexpr function (MSVC limitation). + // Rounded up to one so this is a valid array length. + // + // Do not use this directly - only 'public' so it is visible from the accessor + // macro required by MSVC. + static constexpr size_t kPrivateLanes = + HWY_MAX(size_t{1}, detail::ScaleByPower(kWhole, kPow2 - kFrac)); + // Do not use this directly - only 'public' so it is visible from the accessor + // macro required by MSVC. + static constexpr int kPrivatePow2 = kPow2; + + constexpr size_t MaxLanes() const { return kPrivateLanes; } + constexpr size_t MaxBytes() const { return kPrivateLanes * sizeof(Lane); } + constexpr size_t MaxBlocks() const { return (MaxBytes() + 15) / 16; } + // For SFINAE (HWY_IF_POW2_GT_D). + constexpr int Pow2() const { return kPow2; } + + // ------------------------------ Changing lane type or count + // Do not use any of these directly. Anything used from member typedefs cannot + // be made private, but functions only used within other functions can. + + // Returns number of NewT lanes that fit within MaxBytes(). + template + static constexpr size_t RepartitionLanes() { + // Round up to correctly handle larger NewT. + return (kPrivateLanes * sizeof(T) + sizeof(NewT) - 1) / sizeof(NewT); + } + + // Returns the new kPow2 required for lanes of type NewT. + template + static constexpr int RebindPow2() { + return kPow2 + + ((sizeof(NewT) >= sizeof(T)) + ? 
static_cast(CeilLog2(sizeof(NewT) / sizeof(T))) + : -static_cast(CeilLog2(sizeof(T) / sizeof(NewT)))); + } + + private: + // Returns 0 or whole NewN such that kNewMaxLanes = NewN * 2^kNewPow2. + template + static constexpr size_t WholeN() { + return detail::ScaleByPower(kNewMaxLanes, -kNewPow2); + } + + // Returns fractional NewN such that kNewMaxLanes = NewN * 2^kNewPow2. + template + static constexpr size_t FracN() { + // Only reached if kNewPow2 > CeilLog2(kNewMaxLanes) >= 0 (else WholeN + // would not have been zero), but clamp to zero to avoid warnings. kFrac is + // the difference, stored in the upper bits of N, and we also set kWhole = + // 1 so that the new kPrivateLanes = kNewMaxLanes. + static_assert(HWY_MAX_N <= (size_t{1} << 20), "Change bit shift"); + return static_cast( + 1 + (HWY_MAX(0, kNewPow2 - static_cast(CeilLog2(kNewMaxLanes))) + << 20)); + } + + public: + // Returns (whole or fractional) NewN, see above. + template + static constexpr size_t NewN() { + // We require a fraction if inverting kNewPow2 results in 0. + return WholeN() == 0 + ? FracN() + : WholeN(); + } + + // PromoteTo/DemoteTo() with another lane type, but same number of lanes. + template + using Rebind = + Simd(), kPrivateLanes>(), RebindPow2()>; + + // Change lane type while keeping the same vector size, e.g. for MulEven. + template + using Repartition = + Simd()>(), kPow2>; + + // Half the lanes while keeping the same lane type, e.g. for LowerHalf. + using Half = Simd; + + // Twice the lanes while keeping the same lane type, e.g. for Combine. + using Twice = Simd; +}; + +namespace detail { + +template +constexpr bool IsFull(Simd /* d */) { + return N == HWY_LANES(T) && kPow2 == 0; +} + +// Struct wrappers enable validation of arguments via static_assert. +template +struct ClampNAndPow2 { + using type = Simd; +}; + +template +struct ScalableTagChecker { + using type = typename ClampNAndPow2::type; +}; + +template +struct CappedTagChecker { + static_assert(kLimit != 0, "Does not make sense to have zero lanes"); + // Safely handle non-power-of-two inputs by rounding down, which is allowed by + // CappedTag. Otherwise, Simd would static_assert. + static constexpr size_t kLimitPow2 = size_t{1} << hwy::FloorLog2(kLimit); + static constexpr size_t N = HWY_MIN(kLimitPow2, HWY_LANES(T)); + using type = typename ClampNAndPow2::type; +}; + +template +struct FixedTagChecker { + static_assert(kNumLanes != 0, "Does not make sense to have zero lanes"); + static_assert(kNumLanes <= HWY_LANES(T), "Too many lanes"); + using type = Simd; +}; + +} // namespace detail + +// ------------------------------ Aliases for Simd<> + +// Tag describing a full vector (kPow2 == 0: the most common usage, e.g. 1D +// loops where the application does not care about the vector size) or a +// fraction/multiple of one. Fractions (kPow2 < 0) are useful for arguments or +// return values of type promotion and demotion. User-specified kPow2 is +// interpreted as `HWY_MIN(kPow2, HWY_MAX_POW2)`. +template +using ScalableTag = typename detail::ScalableTagChecker::type; + +// Tag describing a vector with *up to* kLimit active lanes, even on targets +// with scalable vectors and HWY_SCALAR. The runtime lane count `Lanes(tag)` may +// be less than kLimit, and is 1 on HWY_SCALAR. This alias is typically used for +// 1D loops with a relatively low application-defined upper bound, e.g. for 8x8 +// DCTs. However, it is better if data structures are designed to be +// vector-length-agnostic (e.g. 
a hybrid SoA where there are chunks of `M >= +// MaxLanes(d)` DC components followed by M AC1, .., and M AC63; this would +// enable vector-length-agnostic loops using ScalableTag). User-specified kPow2 +// is interpreted as `HWY_MIN(kPow2, HWY_MAX_POW2)`. +template +using CappedTag = typename detail::CappedTagChecker::type; + +#if !HWY_HAVE_SCALABLE +// If the vector size is known, and the app knows it does not want more than +// kLimit lanes, then capping can be beneficial. For example, AVX-512 has lower +// IPC and potentially higher costs for unaligned load/store vs. 256-bit AVX2. +template +using CappedTagIfFixed = CappedTag; +#else // HWY_HAVE_SCALABLE +// .. whereas on RVV/SVE, the cost of clamping Lanes() may exceed the benefit. +template +using CappedTagIfFixed = ScalableTag; +#endif + +// Alias for a tag describing a vector with *exactly* kNumLanes active lanes, +// even on targets with scalable vectors. Requires `kNumLanes` to be a power of +// two not exceeding `HWY_LANES(T)`. +// +// NOTE: if the application does not need to support HWY_SCALAR (+), use this +// instead of CappedTag to emphasize that there will be exactly kNumLanes lanes. +// This is useful for data structures that rely on exactly 128-bit SIMD, but +// these are discouraged because they cannot benefit from wider vectors. +// Instead, applications would ideally define a larger problem size and loop +// over it with the (unknown size) vectors from ScalableTag. +// +// + e.g. if the baseline is known to support SIMD, or the application requires +// ops such as TableLookupBytes not supported by HWY_SCALAR. +template +using FixedTag = typename detail::FixedTagChecker::type; + +// Convenience form for fixed sizes. +template +using Full16 = Simd; + +template +using Full32 = Simd; + +template +using Full64 = Simd; + +template +using Full128 = Simd; + +// ------------------------------ Accessors for Simd<> + +// Lane type. +template +using TFromD = typename D::T; + +// Upper bound on the number of lanes, typically used for SFINAE conditions and +// to allocate storage for targets with known vector sizes. Note: this may be a +// loose bound, instead use Lanes() as the actual size for AllocateAligned. +// MSVC workaround: use static constant directly instead of a function. +#define HWY_MAX_LANES_D(D) D::kPrivateLanes + +// Same as D().Pow2(), but this is too complex for SFINAE with MSVC, so we use a +// static constant directly. +#define HWY_POW2_D(D) D::kPrivatePow2 + +// Non-macro form of HWY_MAX_LANES_D in case that is preferable. WARNING: the +// macro form may be required for MSVC, which has limitations on deducing +// arguments. +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(D) { + return HWY_MAX_LANES_D(D); +} + +#undef HWY_HAVE_CONSTEXPR_LANES +#undef HWY_LANES_CONSTEXPR + +#if HWY_HAVE_SCALABLE +#define HWY_HAVE_CONSTEXPR_LANES 0 +#define HWY_LANES_CONSTEXPR +#else + +// We want Lanes() to be constexpr where possible, so that compilers are able to +// precompute offsets. However, user code must not depend on the constexpr, +// because that will fail for RISC-V V and Arm SVE. To achieve both, we mark it +// as non-constexpr in debug builds, but not sanitizers, because we typically +// want them to see the same code. +#if HWY_IS_DEBUG_BUILD && !HWY_IS_SANITIZER +#define HWY_HAVE_CONSTEXPR_LANES 0 +#define HWY_LANES_CONSTEXPR +#else +#define HWY_HAVE_CONSTEXPR_LANES 1 +#define HWY_LANES_CONSTEXPR constexpr +#endif + +// Returns actual vector length, used when advancing loop counters. 
The +// non-constexpr implementations are defined in their target's header. For a +// guaranteed-constexpr upper bound, use `MaxLanes(d)`. +template +HWY_INLINE HWY_MAYBE_UNUSED HWY_LANES_CONSTEXPR size_t Lanes(D) { + return HWY_MAX_LANES_D(D); +} + +#endif // !HWY_HAVE_SCALABLE + +// Tag for the same number of lanes as D, but with the LaneType T. +template +using Rebind = typename D::template Rebind; + +template +using RebindToSigned = Rebind>, D>; +template +using RebindToUnsigned = Rebind>, D>; +template +using RebindToFloat = Rebind>, D>; + +// Tag for the same total size as D, but with the LaneType T. +template +using Repartition = typename D::template Repartition; + +template +using RepartitionToWide = Repartition>, D>; +template +using RepartitionToNarrow = Repartition>, D>; + +// Shorthand for applying RepartitionToWide twice (for 8/16-bit types). +template +using RepartitionToWideX2 = RepartitionToWide>; +// Shorthand for applying RepartitionToWide three times (for 8-bit types). +template +using RepartitionToWideX3 = RepartitionToWide>; + +// Tag for the same lane type as D, but half the lanes. +template +using Half = typename D::Half; + +// Tag for the same lane type as D, but twice the lanes. +template +using Twice = typename D::Twice; + +// Tag for a 16-byte block with the same lane type as D +#if HWY_HAVE_SCALABLE +namespace detail { + +template +class BlockDFromD_t {}; + +template +class BlockDFromD_t> { + using D = Simd; + static constexpr int kNewPow2 = HWY_MIN(kPow2, 0); + static constexpr size_t kMaxLpb = HWY_MIN(16 / sizeof(T), HWY_MAX_LANES_D(D)); + static constexpr size_t kNewN = D::template NewN(); + + public: + using type = Simd; +}; + +} // namespace detail + +template +using BlockDFromD = typename detail::BlockDFromD_t>::type; +#else +template +using BlockDFromD = + Simd, HWY_MIN(16 / sizeof(TFromD), HWY_MAX_LANES_D(D)), 0>; +#endif + +// Returns whether `ptr` is a multiple of `Lanes(d)` elements. +template +HWY_API bool IsAligned(D d, T* ptr) { + const size_t N = Lanes(d); + return reinterpret_cast(ptr) % (N * sizeof(T)) == 0; +} + +// ------------------------------ Choosing overloads (SFINAE) + +// Same as base.h macros but with a Simd argument instead of T. 
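
As a hypothetical sketch (not part of the patch) of how these macros select overloads: the D-suffixed forms constrain a function template purely via its tag parameter, so the f32 and integer cases below never conflict. It assumes the per-target ops (Set, Mul, Add) and VFromD are in scope, i.e. inside HWY_NAMESPACE after including highway.h.

// Selected when TFromD<D> is float (uses HWY_IF_F32_D, defined below).
template <class D, HWY_IF_F32_D(D)>
HWY_API VFromD<D> TimesTwo(D d, VFromD<D> v) {
  return Mul(v, Set(d, 2.0f));
}
// Selected for any non-float lane type; doubling via Add works for all widths.
template <class D, HWY_IF_NOT_FLOAT_D(D)>
HWY_API VFromD<D> TimesTwo(D /*d*/, VFromD<D> v) {
  return Add(v, v);
}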
+#define HWY_IF_UNSIGNED_D(D) HWY_IF_UNSIGNED(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_UNSIGNED_D(D) \ + HWY_IF_NOT_UNSIGNED(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_SIGNED_D(D) HWY_IF_SIGNED(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_FLOAT_D(D) HWY_IF_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_FLOAT_D(D) HWY_IF_NOT_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_FLOAT3264_D(D) HWY_IF_FLOAT3264(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_FLOAT3264_D(D) \ + HWY_IF_NOT_FLOAT3264(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_SPECIAL_FLOAT_D(D) \ + HWY_IF_SPECIAL_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_SPECIAL_FLOAT_D(D) \ + HWY_IF_NOT_SPECIAL_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_FLOAT_OR_SPECIAL_D(D) \ + HWY_IF_FLOAT_OR_SPECIAL(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D) \ + HWY_IF_NOT_FLOAT_NOR_SPECIAL(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_T_SIZE_D(D, bytes) \ + HWY_IF_T_SIZE(hwy::HWY_NAMESPACE::TFromD, bytes) +#define HWY_IF_NOT_T_SIZE_D(D, bytes) \ + HWY_IF_NOT_T_SIZE(hwy::HWY_NAMESPACE::TFromD, bytes) +#define HWY_IF_T_SIZE_ONE_OF_D(D, bit_array) \ + HWY_IF_T_SIZE_ONE_OF(hwy::HWY_NAMESPACE::TFromD, bit_array) +#define HWY_IF_T_SIZE_LE_D(D, bytes) \ + HWY_IF_T_SIZE_LE(hwy::HWY_NAMESPACE::TFromD, bytes) +#define HWY_IF_T_SIZE_GT_D(D, bytes) \ + HWY_IF_T_SIZE_GT(hwy::HWY_NAMESPACE::TFromD, bytes) + +#define HWY_IF_LANES_D(D, lanes) HWY_IF_LANES(HWY_MAX_LANES_D(D), lanes) +#define HWY_IF_LANES_LE_D(D, lanes) HWY_IF_LANES_LE(HWY_MAX_LANES_D(D), lanes) +#define HWY_IF_LANES_GT_D(D, lanes) HWY_IF_LANES_GT(HWY_MAX_LANES_D(D), lanes) +#define HWY_IF_LANES_PER_BLOCK_D(D, lanes) \ + HWY_IF_LANES_PER_BLOCK(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), \ + lanes) + +#if HWY_COMPILER_MSVC +#define HWY_IF_POW2_LE_D(D, pow2) \ + hwy::EnableIf* = nullptr +#define HWY_IF_POW2_GT_D(D, pow2) \ + hwy::EnableIf<(HWY_POW2_D(D) > pow2)>* = nullptr +#else +#define HWY_IF_POW2_LE_D(D, pow2) hwy::EnableIf* = nullptr +#define HWY_IF_POW2_GT_D(D, pow2) hwy::EnableIf<(D().Pow2() > pow2)>* = nullptr +#endif // HWY_COMPILER_MSVC + +#define HWY_IF_U8_D(D) HWY_IF_U8(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_U16_D(D) HWY_IF_U16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_U32_D(D) HWY_IF_U32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_U64_D(D) HWY_IF_U64(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_I8_D(D) HWY_IF_I8(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_I16_D(D) HWY_IF_I16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_I32_D(D) HWY_IF_I32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_I64_D(D) HWY_IF_I64(hwy::HWY_NAMESPACE::TFromD) + +// Use instead of HWY_IF_T_SIZE_D to avoid ambiguity with float16_t/float/double +// overloads. 
+#define HWY_IF_UI8_D(D) HWY_IF_UI8(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_UI16_D(D) HWY_IF_UI16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_UI32_D(D) HWY_IF_UI32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_UI64_D(D) HWY_IF_UI64(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_BF16_D(D) HWY_IF_BF16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_BF16_D(D) HWY_IF_NOT_BF16(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_F16_D(D) HWY_IF_F16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_F16_D(D) HWY_IF_NOT_F16(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_F32_D(D) HWY_IF_F32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_F64_D(D) HWY_IF_F64(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_V_SIZE_D(D) \ + (HWY_MAX_LANES_D(D) * sizeof(hwy::HWY_NAMESPACE::TFromD)) +#define HWY_IF_V_SIZE_D(D, bytes) \ + HWY_IF_V_SIZE(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), bytes) +#define HWY_IF_V_SIZE_LE_D(D, bytes) \ + HWY_IF_V_SIZE_LE(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), bytes) +#define HWY_IF_V_SIZE_GT_D(D, bytes) \ + HWY_IF_V_SIZE_GT(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), bytes) + +// Same, but with a vector argument. ops/*-inl.h define their own TFromV. +#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_NOT_UNSIGNED_V(V) \ + HWY_IF_NOT_UNSIGNED(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_NOT_FLOAT_V(V) HWY_IF_NOT_FLOAT(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_FLOAT3264_V(V) HWY_IF_FLOAT3264(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_SPECIAL_FLOAT_V(V) \ + HWY_IF_SPECIAL_FLOAT(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_FLOAT_OR_SPECIAL_V(V) \ + HWY_IF_FLOAT_OR_SPECIAL(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) \ + HWY_IF_NOT_FLOAT_NOR_SPECIAL(hwy::HWY_NAMESPACE::TFromV) + +#define HWY_IF_T_SIZE_V(V, bytes) \ + HWY_IF_T_SIZE(hwy::HWY_NAMESPACE::TFromV, bytes) +#define HWY_IF_NOT_T_SIZE_V(V, bytes) \ + HWY_IF_NOT_T_SIZE(hwy::HWY_NAMESPACE::TFromV, bytes) +#define HWY_IF_T_SIZE_ONE_OF_V(V, bit_array) \ + HWY_IF_T_SIZE_ONE_OF(hwy::HWY_NAMESPACE::TFromV, bit_array) + +#define HWY_MAX_LANES_V(V) HWY_MAX_LANES_D(hwy::HWY_NAMESPACE::DFromV) +#define HWY_IF_V_SIZE_V(V, bytes) \ + HWY_IF_V_SIZE(hwy::HWY_NAMESPACE::TFromV, HWY_MAX_LANES_V(V), bytes) +#define HWY_IF_V_SIZE_LE_V(V, bytes) \ + HWY_IF_V_SIZE_LE(hwy::HWY_NAMESPACE::TFromV, HWY_MAX_LANES_V(V), bytes) +#define HWY_IF_V_SIZE_GT_V(V, bytes) \ + HWY_IF_V_SIZE_GT(hwy::HWY_NAMESPACE::TFromV, HWY_MAX_LANES_V(V), bytes) + +// Use in implementations of ReduceSum etc. to avoid conflicts with the N=1 and +// N=4 8-bit specializations in generic_ops-inl. 
+#undef HWY_IF_REDUCE_D +#define HWY_IF_REDUCE_D(D) \ + hwy::EnableIf) != 1)>* = nullptr + +#undef HWY_IF_SUM_OF_LANES_D +#define HWY_IF_SUM_OF_LANES_D(D) HWY_IF_LANES_GT_D(D, 1) + +#undef HWY_IF_MINMAX_OF_LANES_D +#define HWY_IF_MINMAX_OF_LANES_D(D) HWY_IF_LANES_GT_D(D, 1) + +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) HWY_IF_LANES_GT_D(hwy::HWY_NAMESPACE::DFromV, 1) + +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(hwy::HWY_NAMESPACE::DFromV, 1) + +#undef HWY_IF_PAIRWISE_ADD_128_D +#define HWY_IF_PAIRWISE_ADD_128_D(D) HWY_IF_V_SIZE_GT_D(D, 8) + +#undef HWY_IF_PAIRWISE_SUB_128_D +#define HWY_IF_PAIRWISE_SUB_128_D(D) HWY_IF_V_SIZE_GT_D(D, 8) + +// HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V is used to disable the default +// implementation of unsigned to signed DemoteTo/ReorderDemote2To in +// generic_ops-inl.h for at least some of the unsigned to signed demotions on +// SCALAR/EMU128/SSE2/SSSE3/SSE4/AVX2/SVE/SVE2 + +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) void* = nullptr + +// Old names (deprecated) +#define HWY_IF_LANE_SIZE_D(D, bytes) HWY_IF_T_SIZE_D(D, bytes) +#define HWY_IF_NOT_LANE_SIZE_D(D, bytes) HWY_IF_NOT_T_SIZE_D(D, bytes) + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_OPS_SHARED_TOGGLE diff --git a/third_party/aom/third_party/highway/hwy/ops/wasm_128-inl.h b/third_party/aom/third_party/highway/hwy/ops/wasm_128-inl.h new file mode 100644 index 000000000000..207b57c99449 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/wasm_128-inl.h @@ -0,0 +1,5983 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit WASM vectors and operations. +// External include guard in highway.h - see comment there. 
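
The shared-inl.h above supplies the user-facing tag machinery (ScalableTag, Lanes, Rebind and friends), while per-target ops headers such as the following wasm_128-inl.h supply the actual operations. A minimal, hypothetical usage sketch (not part of the patch), using the vendored include path and assuming `size` is a multiple of the lane count:

#include <stddef.h>

#include "third_party/highway/hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

// out[i] = x[i] + y[i]; Load/Add/Store come from the per-target ops header.
void AddArrays(const float* HWY_RESTRICT x, const float* HWY_RESTRICT y,
               float* HWY_RESTRICT out, size_t size) {
  const ScalableTag<float> d;  // full native vector of f32 for this target
  const size_t N = Lanes(d);   // runtime lane count (constexpr when fixed-size)
  for (size_t i = 0; i < size; i += N) {
    Store(Add(Load(d, x + i), Load(d, y + i)), d, out + i);
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

Including highway.h alone yields static dispatch for the baseline target; the same function body can be compiled once per enabled target via foreach_target.h and invoked through dynamic dispatch.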
+ +#include + +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/ops/shared-inl.h" + +#ifdef HWY_WASM_OLD_NAMES +#define wasm_i8x16_shuffle wasm_v8x16_shuffle +#define wasm_i16x8_shuffle wasm_v16x8_shuffle +#define wasm_i32x4_shuffle wasm_v32x4_shuffle +#define wasm_i64x2_shuffle wasm_v64x2_shuffle +#define wasm_u16x8_extend_low_u8x16 wasm_i16x8_widen_low_u8x16 +#define wasm_u32x4_extend_low_u16x8 wasm_i32x4_widen_low_u16x8 +#define wasm_i32x4_extend_low_i16x8 wasm_i32x4_widen_low_i16x8 +#define wasm_i16x8_extend_low_i8x16 wasm_i16x8_widen_low_i8x16 +#define wasm_u32x4_extend_high_u16x8 wasm_i32x4_widen_high_u16x8 +#define wasm_i32x4_extend_high_i16x8 wasm_i32x4_widen_high_i16x8 +#define wasm_i32x4_trunc_sat_f32x4 wasm_i32x4_trunc_saturate_f32x4 +#define wasm_i62x2_trunc_sat_f64x2 wasm_i64x2_trunc_saturate_f64x2 +#define wasm_u8x16_add_sat wasm_u8x16_add_saturate +#define wasm_u8x16_sub_sat wasm_u8x16_sub_saturate +#define wasm_u16x8_add_sat wasm_u16x8_add_saturate +#define wasm_u16x8_sub_sat wasm_u16x8_sub_saturate +#define wasm_i8x16_add_sat wasm_i8x16_add_saturate +#define wasm_i8x16_sub_sat wasm_i8x16_sub_saturate +#define wasm_i16x8_add_sat wasm_i16x8_add_saturate +#define wasm_i16x8_sub_sat wasm_i16x8_sub_saturate +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +#if HWY_TARGET == HWY_WASM_EMU256 +template +using Full256 = Simd; +#endif + +namespace detail { + +template +struct Raw128 { + using type = __v128_u; +}; +template <> +struct Raw128 { + using type = __f32x4; +}; +template <> +struct Raw128 { + using type = __f64x2; +}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +template +using Vec16 = Vec128; + +// FF..FF or 0. +template +struct Mask128 { + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = N; // only for DFromM + + typename detail::Raw128::type raw; +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ Zero + +// Use HWY_MAX_LANES_D here because VFromD is defined in terms of Zero. 
+template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{wasm_i32x4_splat(0)}; +} +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{wasm_f32x4_splat(0.0f)}; +} +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{wasm_f64x2_splat(0.0)}; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __v128_u BitCastToInteger(__v128_u v) { return v; } +HWY_INLINE __v128_u BitCastToInteger(__f32x4 v) { + return static_cast<__v128_u>(v); +} +HWY_INLINE __v128_u BitCastToInteger(__f64x2 v) { + return static_cast<__v128_u>(v); +} + +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger128 { + HWY_INLINE __v128_u operator()(__v128_u v) { return v; } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __f32x4 operator()(__v128_u v) { return static_cast<__f32x4>(v); } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __f64x2 operator()(__v128_u v) { return static_cast<__f64x2>(v); } +}; + +template +HWY_INLINE VFromD BitCastFromByte(D d, Vec128 v) { + return VFromD{BitCastFromInteger128>()(v.raw)}; +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, + Vec128().MaxLanes()> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Repartition du8_to; + return BitCast(d, VFromD{detail::BitCastToInteger(v.raw)}); +} + +// ------------------------------ Set + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_i8x16_splat(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_i16x8_splat(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_i32x4_splat(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_i64x2_splat(static_cast(t))}; +} + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_i16x8_splat(BitCastScalar(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_f32x4_splat(t)}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_f64x2_splat(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// For all vector sizes. +template +HWY_API VFromD Undefined(D d) { + return Zero(d); +} + +HWY_DIAGNOSTICS(pop) + +// For all vector sizes. 
+template , typename T2> +HWY_API VFromD Iota(D d, const T2 first) { + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + lanes[i] = AddWithWraparound(static_cast(first), i); + } + return Load(d, lanes); +} + +// ------------------------------ Dup128VecFromValues +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{wasm_i8x16_make(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, + t11, t12, t13, t14, t15)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{wasm_u8x16_make(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, + t11, t12, t13, t14, t15)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{wasm_i16x8_make(t0, t1, t2, t3, t4, t5, t6, t7)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{wasm_u16x8_make(t0, t1, t2, t3, t4, t5, t6, t7)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{wasm_i32x4_make(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{wasm_u32x4_make(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{wasm_f32x4_make(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{wasm_i64x2_make(t0, t1)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{wasm_u64x2_make(t0, t1)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{wasm_f64x2_make(t0, t1)}; +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_add(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return 
Vec128{wasm_i16x8_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_add(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f64x2_add(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_sub(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_sub(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f64x2_sub(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_add_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_add_sat(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add_sat(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
+ +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_sub_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_sub_sat(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_sub_sat(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_avgr(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_avgr(a.raw, b.raw)}; +} + +template +HWY_API V AverageRound(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const V sign_bit = SignBit(d); + return Xor(BitCast(d, AverageRound(BitCast(du, Xor(a, sign_bit)), + BitCast(du, Xor(b, sign_bit)))), + sign_bit); +} + +// ------------------------------ Absolute value + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i8x16_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i16x8_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i32x4_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i64x2_abs(v.raw)}; +} + +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_f32x4_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_f64x2_abs(v.raw)}; +} + +// ------------------------------ Shift lanes by constant #bits + +// Unsigned +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i64x2_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u32x4_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u64x2_shr(v.raw, kBits)}; +} + +// Signed +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i64x2_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i32x4_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i64x2_shr(v.raw, kBits)}; +} + +// 8-bit +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? 
(v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + + if (kBits == 0) return v; + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +// After https://reviews.llvm.org/D108415 shift argument became unsigned. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unsigned +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u16x8_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u32x4_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u64x2_shr(v.raw, bits)}; +} + +// Signed +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shr(v.raw, bits)}; +} + +// 8-bit +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, + const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. 
+ const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, 0xFF >> bits); +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ignore Wsign-conversion +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Minimum + +// Unsigned +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u8x16_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u16x8_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u32x4_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + // Avoid wasm_u64x2_extract_lane - not all implementations have it yet. + const uint64_t a0 = static_cast(wasm_i64x2_extract_lane(a.raw, 0)); + const uint64_t b0 = static_cast(wasm_i64x2_extract_lane(b.raw, 0)); + const uint64_t a1 = static_cast(wasm_i64x2_extract_lane(a.raw, 1)); + const uint64_t b1 = static_cast(wasm_i64x2_extract_lane(b.raw, 1)); + alignas(16) uint64_t min[2] = {HWY_MIN(a0, b0), HWY_MIN(a1, b1)}; + return Vec128{wasm_v128_load(min)}; +} + +// Signed +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i8x16_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i16x8_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i32x4_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + alignas(16) int64_t min[4]; + min[0] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + min[1] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); + return Vec128{wasm_v128_load(min)}; +} + +// Float +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + // Equivalent to a < b ? a : b (taking into account our swapped arg order, + // so that Min(NaN, x) is x to match x86). + return Vec128{wasm_f32x4_pmin(b.raw, a.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + // Equivalent to a < b ? a : b (taking into account our swapped arg order, + // so that Min(NaN, x) is x to match x86). + return Vec128{wasm_f64x2_pmin(b.raw, a.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u8x16_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u16x8_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u32x4_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + // Avoid wasm_u64x2_extract_lane - not all implementations have it yet. 
+ const uint64_t a0 = static_cast(wasm_i64x2_extract_lane(a.raw, 0)); + const uint64_t b0 = static_cast(wasm_i64x2_extract_lane(b.raw, 0)); + const uint64_t a1 = static_cast(wasm_i64x2_extract_lane(a.raw, 1)); + const uint64_t b1 = static_cast(wasm_i64x2_extract_lane(b.raw, 1)); + alignas(16) uint64_t max[2] = {HWY_MAX(a0, b0), HWY_MAX(a1, b1)}; + return Vec128{wasm_v128_load(max)}; +} + +// Signed +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i8x16_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i16x8_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i32x4_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + alignas(16) int64_t max[2]; + max[0] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + max[1] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); + return Vec128{wasm_v128_load(max)}; +} + +// Float +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + // Equivalent to b < a ? a : b (taking into account our swapped arg order, + // so that Max(NaN, x) is x to match x86). + return Vec128{wasm_f32x4_pmax(b.raw, a.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + // Equivalent to b < a ? a : b (taking into account our swapped arg order, + // so that Max(NaN, x) is x to match x86). + return Vec128{wasm_f64x2_pmax(b.raw, a.raw)}; +} + +// ------------------------------ Integer multiplication + +// Unsigned +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_u16x8_extmul_low_u8x16(a.raw, b.raw); + const auto h = wasm_u16x8_extmul_high_u8x16(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i8x16_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_i16x8_extmul_low_i8x16(a.raw, b.raw); + const auto h = wasm_i16x8_extmul_high_i8x16(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i8x16_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_u32x4_extmul_low_u16x8(a.raw, b.raw); + const auto h = wasm_u32x4_extmul_high_u16x8(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{ + wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_i32x4_extmul_low_i16x8(a.raw, b.raw); + const auto h = wasm_i32x4_extmul_high_i16x8(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? 
+ return Vec128{ + wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_u64x2_extmul_low_u32x4(a.raw, b.raw); + const auto h = wasm_u64x2_extmul_high_u32x4(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i32x4_shuffle(l, h, 1, 3, 5, 7)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_i64x2_extmul_low_i32x4(a.raw, b.raw); + const auto h = wasm_i64x2_extmul_high_i32x4(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i32x4_shuffle(l, h, 1, 3, 5, 7)}; +} + +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + return Vec128{wasm_i16x8_q15mulr_sat(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and returns the double-width result. +template +HWY_API Vec128, (N + 1) / 2> MulEven(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + constexpr int kSrcBits = sizeof(T) * 8; + + const auto ae = + ShiftRight(ShiftLeft(ResizeBitCast(dw, a))); + const auto be = + ShiftRight(ShiftLeft(ResizeBitCast(dw, b))); + return ae * be; +} +template +HWY_API Vec128, (N + 1) / 2> MulEven(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + const auto kEvenMask = Set(dw, LimitsMax()); + + const auto ae = And(ResizeBitCast(dw, a), kEvenMask); + const auto be = And(ResizeBitCast(dw, b), kEvenMask); + return ae * be; +} +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + const auto ae = ShiftRight<32>(ShiftLeft<32>(ResizeBitCast(dw, a))).raw; + const auto be = ShiftRight<32>(ShiftLeft<32>(ResizeBitCast(dw, b))).raw; + return Vec128{wasm_i64x2_mul(ae, be)}; +} +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec128{wasm_i64x2_mul(ae, be)}; +} + +// Multiplies odd lanes (1, 3 ..) and returns the double-width result. 
+template +HWY_API Vec128, (N + 1) / 2> MulOdd(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + constexpr int kSrcBits = sizeof(T) * 8; + + const auto ao = ShiftRight(BitCast(dw, a)); + const auto bo = ShiftRight(BitCast(dw, b)); + return ao * bo; +} +template +HWY_API Vec128, (N + 1) / 2> MulOdd(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + + const auto ao = ShiftRight<32>(BitCast(dw, a)); + const auto bo = ShiftRight<32>(BitCast(dw, b)); + return Vec128, (N + 1) / 2>{wasm_i64x2_mul(ao.raw, bo.raw)}; +} + +// ------------------------------ Negate + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i8x16_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i16x8_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i32x4_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i64x2_neg(v.raw)}; +} + +// ------------------------------ Floating-point mul / div + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{wasm_f32x4_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{wasm_f64x2_mul(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_div(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f64x2_div(a.raw, b.raw)}; +} + +template )> +HWY_API V ApproximateReciprocal(const V v) { + return Set(DFromV(), 1.0f) / v; +} + +// Integer overload defined in generic_ops-inl.h. +template +HWY_API Vec128 AbsDiff(const Vec128 a, const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return mul * x + add; +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return add - mul * x; +} + +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return mul * x - sub; +} + +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{wasm_f32x4_sqrt(v.raw)}; +} +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{wasm_f64x2_sqrt(v.raw)}; +} + +// Approximate reciprocal square root +template )> +HWY_API V ApproximateReciprocalSqrt(V v) { + // TODO(eustas): find cheaper a way to calculate this. 
+ return Set(DFromV(), static_cast>(1.0)) / Sqrt(v); +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{wasm_f32x4_nearest(v.raw)}; +} +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{wasm_f64x2_nearest(v.raw)}; +} + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{wasm_f32x4_trunc(v.raw)}; +} +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{wasm_f64x2_trunc(v.raw)}; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{wasm_f32x4_ceil(v.raw)}; +} +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{wasm_f64x2_ceil(v.raw)}; +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{wasm_f32x4_floor(v.raw)}; +} +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{wasm_f64x2_floor(v.raw)}; +} + +// ------------------------------ Floating-point classification +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return v != v; +} + +template +HWY_API Mask128 IsInf(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vu, vu), Set(du, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +// Mask and Vec are the same (true = FF..FF). 
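Because a true lane is all-one bits and a false lane is all-zero bits, the same value can serve as both a mask and a vector, and selection becomes a plain bitwise blend. A scalar sketch of that convention for one 32-bit lane (illustration only; the vector code uses wasm_v128_bitselect):

#include <cassert>
#include <cstdint>

// One lane of a mask: all bits set when true, all clear when false.
static uint32_t MaskLane(bool condition) {
  return condition ? ~uint32_t{0} : uint32_t{0};
}

// Bitwise select, the scalar analogue of IfThenElse(mask, yes, no).
static uint32_t IfThenElseLane(uint32_t mask, uint32_t yes, uint32_t no) {
  return (mask & yes) | (~mask & no);
}

int main() {
  const uint32_t m_true = MaskLane(3 < 5);
  const uint32_t m_false = MaskLane(5 < 3);
  assert(IfThenElseLane(m_true, 111u, 222u) == 111u);
  assert(IfThenElseLane(m_false, 111u, 222u) == 222u);
  return 0;
}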
+template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API MFromD RebindMask(DTo /* tag */, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_eq(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_eq(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f64x2_eq(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// Unsigned +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ne(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ne(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f64x2_ne(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_gt(a.raw, b.raw)}; +} +template 
+HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_gt(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u8x16_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u16x8_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition d32; + const auto a32 = BitCast(d32, a); + const auto b32 = BitCast(d32, b); + // If the upper halves are not equal, this is the answer. + const auto m_gt = a32 > b32; + + // Otherwise, the lower half decides. + const auto m_eq = a32 == b32; + const auto lo_in_hi = wasm_i32x4_shuffle(m_gt.raw, m_gt.raw, 0, 0, 2, 2); + const auto lo_gt = And(m_eq, MaskFromVec(VFromD{lo_in_hi})); + + const auto gt = Or(lo_gt, m_gt); + // Copy result in upper 32 bits to lower 32 bits. + return Mask128{wasm_i32x4_shuffle(gt.raw, gt.raw, 1, 1, 3, 3)}; +} + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f64x2_gt(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { + return operator>(b, a); +} + +// ------------------------------ Weak inequality + +// Float >= +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f64x2_ge(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ge(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u8x16_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u16x8_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u32x4_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Not(b > a); +} + +template +HWY_API Mask128 operator<=(const Vec128 a, const Vec128 b) { + return operator>=(b, a); +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API MFromD FirstN(D d, size_t num) { + const RebindToSigned di; // Signed comparisons may be cheaper. + using TI = TFromD; + return RebindMask(d, Iota(di, 0) < Set(di, static_cast(num))); +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec128 Not(Vec128 v) { + return Vec128{wasm_v128_not(v.raw)}; +} + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_and(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
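To spell out the contract before the implementation: AndNot complements its first argument, which is why the body below passes the operands to wasm_v128_andnot in swapped order. A scalar sketch of the semantics (illustration only):

#include <cassert>
#include <cstdint>

// Scalar model of AndNot(not_mask, mask): the FIRST argument is complemented.
static uint32_t AndNotLane(uint32_t not_mask, uint32_t mask) {
  return static_cast<uint32_t>(~not_mask) & mask;
}

int main() {
  // Clear the low byte of 0xFFFF by passing the bits to remove first.
  assert(AndNotLane(0x00FFu, 0xFFFFu) == 0xFF00u);
  return 0;
}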
+template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + return Vec128{wasm_v128_andnot(mask.raw, not_mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_or(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_xor(a.raw, b.raw)}; +} + +// ------------------------------ Xor3 + +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 + +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ BroadcastSignBit (compare) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; + return VecFromMask(d, v < Zero(d)); +} + +// ------------------------------ Mask + +template +HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { + return VFromD{v.raw}; +} + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{wasm_v128_bitselect(yes.raw, no.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + const DFromM d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) + +// The x86 multiply-by-Pow2() trick will not work because WASM saturates +// float->int correctly to 2^31-1 (not 2^31). Because WASM's shifts take a +// scalar count operand, per-lane shift instructions would require extract_lane +// for each lane, and hoping that shuffle is correctly mapped to a native +// instruction. Using non-vector shifts would incur a store-load forwarding +// stall when loading the result vector. We instead test bits of the shift +// count to "predicate" a shift of the entire vector by a constant. + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<5>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. 
+ test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + const RebindToUnsigned du; + using TU = MakeUnsigned; + alignas(16) TU lanes[2] = {}; + alignas(16) TU bits_lanes[2] = {}; + Store(BitCast(du, v), du, lanes); + Store(BitCast(du, bits), du, bits_lanes); + lanes[0] <<= (bits_lanes[0] & 63); + lanes[1] <<= (bits_lanes[1] & 63); + return BitCast(d, Load(du, lanes)); +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<5>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. 
+ test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + alignas(16) T lanes[2] = {}; + alignas(16) T bits_lanes[2] = {}; + Store(v, d, lanes); + Store(bits, d, bits_lanes); + lanes[0] >>= (bits_lanes[0] & 63); + lanes[1] >>= (bits_lanes[1] & 63); + return Load(d, lanes); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template > +HWY_API Vec128 Load(D /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128{wasm_v128_load(aligned)}; +} + +// Partial +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + VFromD v; + CopyBytes(p, &v); + return v; +} + +// LoadU == Load. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. 
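To expand on that remark: LoadDup128 exists so that code written for wider targets can broadcast one 128-bit block across the whole vector; on a 128-bit target such as WASM there is nothing to duplicate, so it reduces to a plain unaligned load. A scalar sketch of the broadcast behaviour for a hypothetical 8-lane (256-bit) uint32_t vector, modeled as an array purely for illustration:

#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical 8-lane vector modeled as an array (not the Highway type).
using Vec256U32 = std::array<uint32_t, 8>;

// Broadcast one 128-bit block (4 x uint32_t) into both halves of the vector.
static Vec256U32 LoadDup128Ref(const uint32_t* p) {
  Vec256U32 v{};
  for (size_t i = 0; i < 8; ++i) v[i] = p[i % 4];
  return v;
}

int main() {
  const uint32_t block[4] = {1, 2, 3, 4};
  const Vec256U32 v = LoadDup128Ref(block);
  assert(v[0] == 1 && v[4] == 1 && v[7] == 4);
  return 0;
}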
+template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +template > +HWY_API VFromD MaskedLoad(MFromD m, D d, const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template > +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const T* HWY_RESTRICT aligned) { + return IfThenElse(m, Load(d, aligned), v); +} + +// ------------------------------ Store + +namespace detail { + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i8x16_extract_lane(v.raw, kLane)); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + const int16_t lane = wasm_i16x8_extract_lane(v.raw, kLane); + return static_cast(lane); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + const uint16_t bits = ExtractLane(BitCast(du, v)); + return BitCastScalar(bits); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i32x4_extract_lane(v.raw, kLane)); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i64x2_extract_lane(v.raw, kLane)); +} + +template +HWY_INLINE float ExtractLane(const Vec128 v) { + return wasm_f32x4_extract_lane(v.raw, kLane); +} +template +HWY_INLINE double ExtractLane(const Vec128 v) { + return wasm_f64x2_extract_lane(v.raw, kLane); +} + +} // namespace detail + +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// Partial +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { + CopyBytes(&v, p); +} + +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { + *p = detail::ExtractLane<0>(v); +} + +// StoreU == Store. +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template +HWY_API void Stream(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// ------------------------------ Scatter in generic_ops-inl.h +// ------------------------------ Gather in generic_ops-inl.h + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane + +// One overload per vector length just in case *_extract_lane raise compile +// errors if their argument is out of bounds (even if that would never be +// reached at runtime). 
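The overloads that follow share one pattern: when the index is a compile-time constant (detected via __builtin_constant_p), they dispatch to the intrinsic's immediate-lane form through a switch; otherwise they spill the vector to an aligned stack array and index it. A scalar stand-in for the fallback path, assuming uint32_t lanes (hypothetical helper, illustration only):

#include <cassert>
#include <cstdint>
#include <cstring>

// Fallback used when the lane index is not a compile-time constant: store all
// lanes to an aligned buffer and read the requested element.
static uint32_t ExtractLaneRef(const uint32_t (&lanes_in)[4], size_t i) {
  alignas(16) uint32_t lanes[4];
  std::memcpy(lanes, lanes_in, sizeof(lanes));  // stands in for Store(v, d, lanes)
  return lanes[i];
}

int main() {
  const uint32_t v[4] = {10, 20, 30, 40};
  assert(ExtractLaneRef(v, 2) == 30u);
  return 0;
}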
+template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return detail::ExtractLane<0>(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ GetLane +template +HWY_API T GetLane(const Vec128 v) { + return detail::ExtractLane<0>(v); +} + +// ------------------------------ InsertLane + +namespace detail { + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i8x16_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i16x8_replace_lane(v.raw, kLane, BitCastScalar(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i32x4_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + 
wasm_i64x2_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{wasm_f32x4_replace_lane(v.raw, kLane, t)}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + return Vec128{wasm_f64x2_replace_lane(v.raw, kLane, t)}; +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls wasm_f64x2_replace_lane. + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ LowerHalf + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{v.raw}; +} 
+template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128{v.raw}; +} + +// ------------------------------ ShiftLeftBytes + +// 0x01..0F, kBytes = 1 => 0x02..0F00 +template +HWY_API VFromD ShiftLeftBytes(D /* tag */, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + switch (kBytes) { + case 0: + return v; + + case 1: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14)}; + + case 2: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13)}; + + case 3: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 0, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12)}; + + case 4: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11)}; + + case 5: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10)}; + + case 6: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)}; + + case 7: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 0, 1, 2, 3, 4, 5, 6, 7, 8)}; + + case 8: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 0, 1, 2, 3, 4, 5, 6, 7)}; + + case 9: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 0, 1, 2, 3, 4, 5, 6)}; + + case 10: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 0, 1, 2, 3, 4, 5)}; + + case 11: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 0, 1, 2, 3, 4)}; + + case 12: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 0, 1, 2, 3)}; + + case 13: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 0, 1, 2)}; + + case 14: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 0, + 1)}; + + case 15: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 0)}; + } + return VFromD{zero}; +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API VFromD ShiftLeftLanes(D d, const VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +namespace detail { + +// Helper function allows zeroing invalid lanes in caller. 
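As that helper comment says, ShrBytes only shifts whole 16-byte vectors; for partial vectors, the caller (ShiftRightBytes further below) first zeroes the lanes beyond the valid prefix so that zeros, not stale bytes, are shifted in. A byte-array sketch of that two-step idea for an 8-byte partial vector (illustration only):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Shift a full 16-byte block right by `shift` bytes, filling with zeros.
static void ShrBytesRef(uint8_t bytes[16], size_t shift) {
  for (size_t i = 0; i < 16; ++i) {
    bytes[i] = (i + shift < 16) ? bytes[i + shift] : uint8_t{0};
  }
}

int main() {
  // 8-byte "partial vector" in the low half; the upper half holds garbage.
  uint8_t v[16] = {1, 2, 3, 4, 5, 6, 7, 8,
                   0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
  // Step 1: zero the invalid upper lanes so they cannot be shifted in.
  for (size_t i = 8; i < 16; ++i) v[i] = 0;
  // Step 2: the full-width byte shift now behaves like an 8-byte shift.
  ShrBytesRef(v, 2);
  assert(v[0] == 3 && v[5] == 8 && v[6] == 0 && v[7] == 0);
  return 0;
}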
+template +HWY_API __i8x16 ShrBytes(const Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + + switch (kBytes) { + case 0: + return v.raw; + + case 1: + return wasm_i8x16_shuffle(v.raw, zero, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16); + + case 2: + return wasm_i8x16_shuffle(v.raw, zero, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16); + + case 3: + return wasm_i8x16_shuffle(v.raw, zero, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16, 16); + + case 4: + return wasm_i8x16_shuffle(v.raw, zero, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 16, 16, 16); + + case 5: + return wasm_i8x16_shuffle(v.raw, zero, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 16, 16, 16, 16); + + case 6: + return wasm_i8x16_shuffle(v.raw, zero, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16); + + case 7: + return wasm_i8x16_shuffle(v.raw, zero, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16, 16); + + case 8: + return wasm_i8x16_shuffle(v.raw, zero, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 9: + return wasm_i8x16_shuffle(v.raw, zero, 9, 10, 11, 12, 13, 14, 15, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 10: + return wasm_i8x16_shuffle(v.raw, zero, 10, 11, 12, 13, 14, 15, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 11: + return wasm_i8x16_shuffle(v.raw, zero, 11, 12, 13, 14, 15, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 12: + return wasm_i8x16_shuffle(v.raw, zero, 12, 13, 14, 15, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 13: + return wasm_i8x16_shuffle(v.raw, zero, 13, 14, 15, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 14: + return wasm_i8x16_shuffle(v.raw, zero, 14, 15, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 15: + return wasm_i8x16_shuffle(v.raw, zero, 15, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + case 16: + return zero; + } +} + +} // namespace detail + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + // For partial vectors, clear upper lanes so we shift in zeros. 
+ if (d.MaxBytes() != 16) { + const Full128> dfull; + const VFromD vfull{v.raw}; + v = VFromD{IfThenElseZero(FirstN(dfull, MaxLanes(d)), vfull).raw}; + } + return VFromD{detail::ShrBytes(v)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API VFromD ShiftRightLanes(D d, const VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +template > +HWY_API Vec64 UpperHalf(D /* tag */, const Vec128 v) { + return Vec64{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} + +// Partial +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + return LowerHalf(d, ShiftRightBytes(Twice(), v)); +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API Vec128 CombineShiftRightBytes(D /* tag */, Vec128 hi, + Vec128 lo) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + switch (kBytes) { + case 0: + return lo; + + case 1: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16)}; + + case 2: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, 17)}; + + case 3: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18)}; + + case 4: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19)}; + + case 5: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20)}; + + case 6: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21)}; + + case 7: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22)}; + + case 8: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23)}; + + case 9: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24)}; + + case 10: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25)}; + + case 11: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26)}; + + case 12: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27)}; + + case 13: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28)}; + + case 14: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29)}; + + case 15: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30)}; + } + return hi; +} + +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + using V8 = Vec128; + const DFromV dfull8; + const Repartition, decltype(dfull8)> dfull; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(dfull8, hi8, lo8); + return VFromD{BitCast(dfull, 
r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_i8x16_shuffle( + v.raw, v.raw, kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane, + kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_i16x8_shuffle(v.raw, v.raw, kLane, kLane, kLane, + kLane, kLane, kLane, kLane, kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{ + wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, kLane, kLane)}; +} + +// ------------------------------ TableLookupBytes + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// lane indices in [0, 16). +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + return Vec128{wasm_i8x16_swizzle(bytes.raw, from.raw)}; +} + +template +HWY_API Vec128 TableLookupBytesOr0(const Vec128 bytes, + const Vec128 from) { + const DFromV d; + // Mask size must match vector type, so cast everything to this type. + Repartition di8; + Repartition> d_bytes8; + const auto msb = BitCast(di8, from) < Zero(di8); + const auto lookup = + TableLookupBytes(BitCast(d_bytes8, bytes), BitCast(di8, from)); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. 
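Before the two-vector variants in detail, a concrete reading of the single-vector shuffle notation introduced above: the digits in the name list, from most-significant to least-significant result lane, which source lane lands there. A scalar sketch over a 4-lane uint32_t array (illustration only):

#include <array>
#include <cassert>
#include <cstdint>

using Lanes4 = std::array<uint32_t, 4>;  // index 0 = least-significant lane

// Shuffle2301: swap 32-bit lanes within each 64-bit half.
static Lanes4 Shuffle2301Ref(const Lanes4& v) {
  return {v[1], v[0], v[3], v[2]};
}

// Shuffle0321: rotate one lane right; old lane 0 becomes most-significant.
static Lanes4 Shuffle0321Ref(const Lanes4& v) {
  return {v[1], v[2], v[3], v[0]};
}

int main() {
  const Lanes4 v = {0, 1, 2, 3};  // lane i holds the value i
  assert((Shuffle2301Ref(v) == Lanes4{1, 0, 3, 2}));
  assert((Shuffle0321Ref(v) == Lanes4{1, 2, 3, 0}));
  return 0;
}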
+namespace detail { + +template +HWY_API Vec128 ShuffleTwo2301(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 1, 0, 3 + 16, 2 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 ShuffleTwo2301(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 1, 0, 3 + 8, 2 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 ShuffleTwo2301(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 1, 0, 3 + 4, 2 + 4)}; +} + +template +HWY_API Vec128 ShuffleTwo1230(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 0, 3, 2 + 16, 1 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 ShuffleTwo1230(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 0, 3, 2 + 8, 1 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 ShuffleTwo1230(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 3, 2 + 4, 1 + 4)}; +} + +template +HWY_API Vec128 ShuffleTwo3012(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 2, 1, 0 + 16, 3 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 ShuffleTwo3012(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 2, 1, 0 + 8, 3 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 ShuffleTwo3012(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 1, 0 + 4, 3 + 4)}; +} + +} // namespace detail + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle01(const Vec128 v) { + static_assert(sizeof(T) == 8, "Only for 64-bit lanes"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} +template +HWY_API Vec128 Shuffle1032(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} + +// Reverse +template +HWY_API Vec128 Shuffle0123(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. 
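The Indices128 type defined next is produced by SetTableIndices/IndicesFromVec and consumed by TableLookupLanes, which permutes lanes by index. A scalar model of the permutation itself, assuming 4 uint32_t lanes and in-range indices (illustration only):

#include <array>
#include <cassert>
#include <cstdint>

using Lanes4 = std::array<uint32_t, 4>;

// result[i] = v[idx[i]]. Indices must be below the lane count; the debug
// assert in IndicesFromVec permits up to twice that, to accommodate the
// two-table variant.
static Lanes4 TableLookupLanesRef(const Lanes4& v,
                                  const std::array<int, 4>& idx) {
  Lanes4 r{};
  for (size_t i = 0; i < 4; ++i) r[i] = v[static_cast<size_t>(idx[i])];
  return r;
}

int main() {
  const Lanes4 v = {10, 20, 30, 40};
  assert((TableLookupLanesRef(v, {3, 3, 0, 1}) == Lanes4{40, 40, 10, 20}));
  return 0;
}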
+template +struct Indices128 { + __v128_u raw; +}; + +namespace detail { + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + return Iota(d8, 0); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + return Zero(d8); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + return Load(d8, kByteOffsets); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + return Load(d8, kByteOffsets); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}; + return Load(d8, kByteOffsets); +} + +} // namespace detail + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + (void)d; + return Indices128, MaxLanes(D())>{vec.raw}; +} + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + const Repartition d8; + using V8 = VFromD; + + // Broadcast each lane index to all bytes of T and shift to bytes + const V8 lane_indices = TableLookupBytes( + BitCast(d8, vec), detail::IndicesFromVecBroadcastLaneBytes(d)); + constexpr int kIndexShiftAmt = static_cast(FloorLog2(sizeof(T))); + const V8 byte_indices = ShiftLeft(lane_indices); + const V8 sum = Add(byte_indices, detail::IndicesFromVecByteOffsets(d)); + return Indices128, MaxLanes(D())>{BitCast(d, sum).raw}; +} + +template +HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( + D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + using TI = MakeSigned; + const DFromV d; + const Rebind di; + return BitCast(d, TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + 
const Twice dt; +// TableLookupLanes currently requires table and index vectors to be the same +// size, though a half-length index vector would be sufficient here. +#if HWY_IS_MSAN + const Vec128 idx_vec{idx.raw}; + const Indices128 idx2{Combine(dt, idx_vec, idx_vec).raw}; +#else + // We only keep LowerHalf of the result, which is valid in idx. + const Indices128 idx2{idx.raw}; +#endif + return LowerHalf(d, TableLookupLanes(Combine(dt, b, a), idx2)); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Repartition du8; + + const VFromD byte_idx{idx.raw}; + const auto byte_idx_mod = byte_idx & Set(du8, uint8_t{0x0F}); + // If ANDing did not change the index, it is for the lower half. + const auto is_lo = (byte_idx == byte_idx_mod); + + return BitCast(d, IfThenElse(is_lo, TableLookupBytes(a, byte_idx_mod), + TableLookupBytes(b, byte_idx_mod))); +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) + +// Single lane: no change +template , HWY_IF_LANES_D(D, 1)> +HWY_API Vec128 Reverse(D /* tag */, Vec128 v) { + return v; +} + +// 32-bit x2: shuffle +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec64 Reverse(D /* tag */, const Vec64 v) { + return Vec64{Shuffle2301(Vec128{v.raw}).raw}; +} + +// 64-bit x2: shuffle +template , HWY_IF_T_SIZE(T, 8)> +HWY_API Vec128 Reverse(D /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// 32-bit x2: shuffle +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 Reverse(D /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API VFromD Reverse(D d, const VFromD v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + static constexpr int kN = 16 + Lanes(d); + return VFromD{wasm_i8x16_shuffle( + v.raw, v.raw, + // kN is adjusted to ensure we have valid indices for all lengths. 
+ kN - 1, kN - 2, kN - 3, kN - 4, kN - 5, kN - 6, kN - 7, kN - 8, kN - 9, + kN - 10, kN - 11, kN - 12, kN - 13, kN - 14, kN - 15, kN - 16)}; +} + +// ------------------------------ Reverse2 + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RepartitionToWide> dw; + return BitCast(d, RotateRight<16>(BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D /* tag */, const VFromD v) { + return Shuffle2301(v); +} + +template +HWY_API VFromD Reverse2(D /* tag */, const VFromD v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + return VFromD{wasm_i16x8_shuffle(v.raw, v.raw, 3, 2, 1, 0, 7, 6, 5, 4)}; +} + +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + return Shuffle0123(v); +} + +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + return Reverse(d, v); +} + +template +HWY_API VFromD Reverse8(D /* tag */, const VFromD) { + HWY_ASSERT(0); // don't have 8 lanes for > 16-bit lanes +} + +// ------------------------------ InterleaveLower + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, InterleaveLower(BitCast(du, a), BitCast(du, b))); +} + +// Additional overload for the optional tag (all vector lengths). +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. 
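A scalar reading of the interleave semantics implemented above (InterleaveLower) and in the detail::InterleaveUpper overloads that follow: lanes are taken alternately from the lower, respectively upper, halves of the two inputs. Sketch for 4 uint32_t lanes (illustration only):

#include <array>
#include <cassert>
#include <cstdint>

using Lanes4 = std::array<uint32_t, 4>;

// Alternate lanes from the lower halves of a and b: a0 b0 a1 b1.
static Lanes4 InterleaveLowerRef(const Lanes4& a, const Lanes4& b) {
  return {a[0], b[0], a[1], b[1]};
}

// Alternate lanes from the upper halves of a and b: a2 b2 a3 b3.
static Lanes4 InterleaveUpperRef(const Lanes4& a, const Lanes4& b) {
  return {a[2], b[2], a[3], b[3]};
}

int main() {
  const Lanes4 a = {0, 1, 2, 3}, b = {100, 101, 102, 103};
  assert((InterleaveLowerRef(a, b) == Lanes4{0, 100, 1, 101}));
  assert((InterleaveUpperRef(a, b) == Lanes4{2, 102, 3, 103}));
  return 0;
}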
+namespace detail { + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, + 26, 11, 27, 12, 28, 13, 29, 14, + 30, 15, 31)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, + 26, 11, 27, 12, 28, 13, 29, 14, + 30, 15, 31)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +} // namespace detail + +// Full +template > +HWY_API Vec128 InterleaveUpper(D /* tag */, Vec128 a, Vec128 b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half d2; + return InterleaveLower(d, VFromD{UpperHalf(d2, a).raw}, + VFromD{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
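As that comment explains, Zip* returns the interleaved lanes reinterpreted as double-width integers. A scalar sketch of one resulting lane for uint8_t inputs zipped into uint16_t, assuming little-endian lane layout as on WASM (illustration only):

#include <cassert>
#include <cstdint>

// One double-width lane of ZipLower: the byte from 'a' lands in the low half,
// the byte from 'b' in the high half.
static uint16_t ZipLanesRef(uint8_t a, uint8_t b) {
  return static_cast<uint16_t>(a | (uint16_t{b} << 8));
}

int main() {
  assert(ZipLanesRef(0x34, 0x12) == 0x1234);
  return 0;
}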
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ------------------------------ Per4LaneBlockShuffle +namespace detail { + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + return V{wasm_i8x16_shuffle(v.raw, v.raw, kIdx0, kIdx1, kIdx2, kIdx3, + kIdx0 + 4, kIdx1 + 4, kIdx2 + 4, kIdx3 + 4, + kIdx0 + 8, kIdx1 + 8, kIdx2 + 8, kIdx3 + 8, + kIdx0 + 12, kIdx1 + 12, kIdx2 + 12, kIdx3 + 12)}; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + return V{wasm_i16x8_shuffle(v.raw, v.raw, kIdx0, kIdx1, kIdx2, kIdx3, + kIdx0 + 4, kIdx1 + 4, kIdx2 + 4, kIdx3 + 4)}; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + return V{wasm_i32x4_shuffle(v.raw, v.raw, kIdx0, kIdx1, kIdx2, kIdx3)}; +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Full64 du64; + const auto vu64 = ResizeBitCast(du64, v); + return ResizeBitCast( + d, ShiftLeftSame(vu64, static_cast(amt * sizeof(TFromV) * 8))); +} + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Repartition du8; + const auto idx = + Iota(du8, static_cast(size_t{0} - amt * sizeof(TFromV))); + return BitCast(d, TableLookupBytesOr0(BitCast(du8, v), idx)); +} + +} // namespace detail + +template +HWY_API VFromD SlideUpLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + 
return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + case 8: + return ShiftLeftLanes<8>(d, v); + case 9: + return ShiftLeftLanes<9>(d, v); + case 10: + return ShiftLeftLanes<10>(d, v); + case 11: + return ShiftLeftLanes<11>(d, v); + case 12: + return ShiftLeftLanes<12>(d, v); + case 13: + return ShiftLeftLanes<13>(d, v); + case 14: + return ShiftLeftLanes<14>(d, v); + case 15: + return ShiftLeftLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition, decltype(d)> dv; + return BitCast(d, + ShiftRightSame(BitCast(dv, v), + static_cast(amt * sizeof(TFromV) * 8))); +} + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition di8; + auto idx = Iota(di8, static_cast(amt * sizeof(TFromV))); + idx = Or(idx, VecFromMask(di8, idx > Set(di8, int8_t{15}))); + return BitCast(d, TableLookupBytesOr0(BitCast(di8, v), idx)); +} + +} // namespace detail + +template +HWY_API VFromD SlideDownLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return 
ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + case 8: + return ShiftRightLanes<8>(d, v); + case 9: + return ShiftRightLanes<9>(d, v); + case 10: + return ShiftRightLanes<10>(d, v); + case 11: + return ShiftRightLanes<11>(d, v); + case 12: + return ShiftRightLanes<12>(d, v); + case 13: + return ShiftRightLanes<13>(d, v); + case 14: + return ShiftRightLanes<14>(d, v); + case 15: + return ShiftRightLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template >> +HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { + const Half dh; + const RebindToUnsigned duh; + // Treat half-width input as one lane, and expand to two lanes. + using VU = Vec128, 2>; + const VU lo{BitCast(duh, lo_half).raw}; + const VU hi{BitCast(duh, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (IfThenElseZero) +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const Half dh; + return IfThenElseZero(FirstN(d, MaxLanes(dh)), VFromD{lo.raw}); +} + +// ------------------------------ ConcatLowerLower +template > +HWY_API Vec128 ConcatLowerLower(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 0, 2)}; +} + +// ------------------------------ ConcatUpperUpper +template > +HWY_API Vec128 ConcatUpperUpper(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 1, 3)}; +} + +// ------------------------------ ConcatLowerUpper +template > +HWY_API Vec128 ConcatLowerUpper(D d, Vec128 hi, Vec128 lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// ------------------------------ ConcatUpperLower +template > +HWY_API Vec128 ConcatUpperLower(D d, Vec128 hi, Vec128 lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatLowerUpper(D d, const VFromD hi, + const VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 ConcatOdd(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} + +// 8-bit x8 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec64 ConcatOdd(D /* tag */, Vec64 hi, Vec64 lo) { + // Don't care about upper half. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 17, 19, 21, + 23, 1, 3, 5, 7, 17, 19, 21, 23)}; +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatOdd(D /* tag */, Vec32 hi, Vec32 lo) { + // Don't care about upper 3/4. 
+ return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 17, 19, 1, 3, 17, + 19, 1, 3, 17, 19, 1, 3, 17, 19)}; +} + +// 16-bit full +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 ConcatOdd(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 9, 11, 13, 15)}; +} + +// 16-bit x4 +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec64 ConcatOdd(D /* tag */, Vec64 hi, Vec64 lo) { + // Don't care about upper half. + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 1, 3, 9, 11, 1, 3, 9, 11)}; +} + +// 32-bit full +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 ConcatOdd(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i32x4_shuffle(lo.raw, hi.raw, 1, 3, 5, 7)}; +} + +// Any T x2 +template , HWY_IF_LANES_D(D, 2)> +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 ConcatEven(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30)}; +} + +// 8-bit x8 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec64 ConcatEven(D /* tag */, Vec64 hi, Vec64 lo) { + // Don't care about upper half. + return Vec64{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 16, 18, 20, 22, + 0, 2, 4, 6, 16, 18, 20, 22)}; +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatEven(D /* tag */, Vec32 hi, Vec32 lo) { + // Don't care about upper 3/4. + return Vec32{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 16, 18, 0, 2, 16, 18, + 0, 2, 16, 18, 0, 2, 16, 18)}; +} + +// 16-bit full +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 ConcatEven(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 8, 10, 12, 14)}; +} + +// 16-bit x4 +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec64 ConcatEven(D /* tag */, Vec64 hi, Vec64 lo) { + // Don't care about upper half. 
+ return Vec64{wasm_i16x8_shuffle(lo.raw, hi.raw, 0, 2, 8, 10, 0, 2, 8, 10)}; +} + +// 32-bit full +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 ConcatEven(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i32x4_shuffle(lo.raw, hi.raw, 0, 2, 4, 6)}; +} + +// Any T x2 +template , HWY_IF_LANES_D(D, 2)> +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{wasm_i8x16_shuffle(v.raw, v.raw, 0, 0, 2, 2, 4, 4, 6, 6, + 8, 8, 10, 10, 12, 12, 14, 14)}; +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{wasm_i16x8_shuffle(v.raw, v.raw, 0, 0, 2, 2, 4, 4, 6, 6)}; +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 0, 0, 2, 2)}; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{wasm_i8x16_shuffle(v.raw, v.raw, 1, 1, 3, 3, 5, 5, 7, 7, + 9, 9, 11, 11, 13, 13, 15, 15)}; +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{wasm_i16x8_shuffle(v.raw, v.raw, 1, 1, 3, 3, 5, 5, 7, 7)}; +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 1, 3, 3)}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven + +namespace detail { + +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<1> /* tag */, const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) static constexpr uint8_t mask[16] = { + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<2> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 8, 1, 10, 3, 12, 5, 14, 7)}; +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<4> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<8> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 2, 1)}; +} + +} // namespace detail + +template +HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { + return detail::OddEven(hwy::SizeTag(), a, b); +} +template +HWY_API Vec128 OddEven(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} + +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i8x16_shuffle(a.raw, b.raw, 0, 16, 2, 18, 4, 20, 6, 22, + 8, 24, 10, 26, 12, 28, 14, 30)}; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 2, 10, 4, 12, 6, 14)}; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 2, 6)}; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i8x16_shuffle(a.raw, b.raw, 1, 17, 3, 19, 
5, 21, 7, 23, + 9, 25, 11, 27, 13, 29, 15, 31)}; +} + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i16x8_shuffle(a.raw, b.raw, 1, 9, 3, 11, 5, 13, 7, 15)}; +} + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i32x4_shuffle(a.raw, b.raw, 1, 5, 3, 7)}; +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ ReverseBlocks +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; // Single block: no change +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u32x4_extend_low_u16x8(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u64x2_extend_low_u32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u32x4_extend_low_u16x8(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u64x2_extend_low_u32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} + +// U8/U16 to U64/I64: First, zero-extend to U32, and then zero-extend to +// TFromD +template +HWY_API VFromD PromoteTo(D d, V v) { + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +} + +// Signed: replicate sign bit. 
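The signed promotions that follow differ from the unsigned ones above only in replicating the sign bit instead of zero-filling, and the two-step U8/U16 to U64 path above simply chains two widenings. A scalar reminder of the distinction (plain C++, illustrative values):

#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t u = 0x90;                      // 144
  const int8_t s = static_cast<int8_t>(0x90);  // -112
  const uint64_t zero_extended = u;            // 0x0000000000000090
  const int64_t sign_extended = s;             // 0xFFFFFFFFFFFFFF90
  printf("%016llx %016llx\n",
         static_cast<unsigned long long>(zero_extended),
         static_cast<unsigned long long>(sign_extended));
  return 0;
}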
+template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i16x8_extend_low_i8x16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i32x4_extend_low_i16x8(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i64x2_extend_low_i32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{ + wasm_i32x4_extend_low_i16x8(wasm_i16x8_extend_low_i8x16(v.raw))}; +} + +// I8/I16 to I64: First, promote to I32, and then promote to I64 +template +HWY_API VFromD PromoteTo(D d, V v) { + const Rebind di32; + return PromoteTo(d, PromoteTo(di32, v)); +} + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f64x2_convert_low_i32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f64x2_convert_low_u32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f64x2_promote_low_f32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D di64, VFromD> v) { + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + Min(SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{157}))), + BitCast(du32_as_du8, Set(du32, uint32_t{32})))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + const auto f32_to_i32_result = ConvertTo(di32, adj_v); + const auto lo64_or_mask = PromoteTo( + di64, + BitCast(du32, VecFromMask(di32, Eq(f32_to_i32_result, + Set(di32, LimitsMax()))))); + + return Or(PromoteTo(di64, BitCast(di32, f32_to_i32_result)) + << PromoteTo(di64, exponent_adj), + lo64_or_mask); +} + +template +HWY_API VFromD PromoteTo(D du64, VFromD> v) { + const Rebind du32; + const RebindToFloat df32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + Min(SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{158}))), + BitCast(du32_as_du8, Set(du32, uint32_t{32})))); + + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + const auto f32_to_u32_result = ConvertTo(du32, adj_v); + const auto lo32_or_mask = PromoteTo( + du64, + VecFromMask(du32, f32_to_u32_result == Set(du32, LimitsMax()))); + + return Or(PromoteTo(du64, f32_to_u32_result) << PromoteTo(du64, exponent_adj), + lo32_or_mask); +} + +// ------------------------------ PromoteUpperTo + +// Per-target flag to prevent generic_ops-inl.h from defining PromoteUpperTo. +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// Unsigned: zero-extend. 
+template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u16x8_extend_high_u8x16(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u32x4_extend_high_u16x8(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u64x2_extend_high_u32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u16x8_extend_high_u8x16(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u32x4_extend_high_u16x8(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u64x2_extend_high_u32x4(v.raw)}; +} + +// Signed: replicate sign bit. +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_i16x8_extend_high_i8x16(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_i32x4_extend_high_i16x8(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_i64x2_extend_high_i32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteUpperTo(D df32, VFromD> v) { + const Rebind dh; + return PromoteTo(df32, UpperHalf(dh, v)); +} + +template +HWY_API VFromD PromoteUpperTo(D df32, VFromD> v) { + const Repartition du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteUpperTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteUpperTo(D dd, VFromD> v) { + // There is no wasm_f64x2_convert_high_i32x4. + return PromoteTo(dd, UpperHalf(Rebind(), v)); +} + +template +HWY_API VFromD PromoteUpperTo(D dd, VFromD> v) { + // There is no wasm_f64x2_convert_high_u32x4. + return PromoteTo(dd, UpperHalf(Rebind(), v)); +} + +template +HWY_API VFromD PromoteUpperTo(D dd, VFromD> v) { + // There is no wasm_f64x2_promote_high_f32x4. + return PromoteTo(dd, UpperHalf(Rebind(), v)); +} + +template +HWY_API VFromD PromoteUpperTo(D d64, VFromD> v) { + return PromoteTo(d64, UpperHalf(Rebind(), v)); +} + +// Generic version for <=64 bit input/output (_high is only for full vectors). 
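The generic fallback below follows directly from the comment above: for vectors of at most 64 bits there is no _high instruction, so promoting the upper half reduces to taking the upper half and promoting it. A scalar sketch of that identity for u8 to u16 (illustrative only):

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  const std::array<uint8_t, 8> v{1, 2, 3, 4, 5, 6, 7, 8};
  // PromoteUpperTo == PromoteTo(UpperHalf(v)): take lanes 4..7, widen each.
  std::array<uint16_t, 4> promoted;
  for (size_t i = 0; i < 4; ++i) {
    promoted[i] = v[4 + i];  // zero-extend to 16 bits
  }
  printf("%u %u %u %u\n", promoted[0], promoted[1], promoted[2], promoted[3]);
  // prints: 5 6 7 8
  return 0;
}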
+template +HWY_API VFromD PromoteUpperTo(D d, V v) { + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "third_party/highway/hwy/ops/inside-inl.h" + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return VFromD{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return VFromD{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const DFromV du32; + const RebindToSigned di32; + return DemoteTo(dn, BitCast(di32, Min(v, Set(du32, 0x7FFFFFFF)))); +} + +template +HWY_API VFromD DemoteTo(D du8, VFromD> v) { + const DFromV du16; + const RebindToSigned di16; + return DemoteTo(du8, BitCast(di16, Min(v, Set(du16, 0x7FFF)))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u32x4_trunc_sat_f64x2_zero(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f32x4_demote_f64x2_zero(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind df64; + const RebindToUnsigned du64; + const RebindToSigned di32; + const RebindToUnsigned du32; + + const auto k2p64_63 = Set(df64, 27670116110564327424.0); + const auto f64_hi52 = + Xor(BitCast(df64, ShiftRight<12>(BitCast(du64, v))), k2p64_63) - k2p64_63; + const auto f64_lo12 = + PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)), + Set(du32, uint32_t{0x00000FFF})))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + const auto f64_bits_decrement = + And(ShiftRight<63>(BitCast(du64, Xor(f64_sum, f64_carry))), + f64_sum_is_inexact); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - f64_bits_decrement, f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +} +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind df64; + const RebindToUnsigned du64; + const RebindToSigned di32; + const RebindToUnsigned du32; + + const auto k2p64 = Set(df64, 18446744073709551616.0); + const auto f64_hi52 = Or(BitCast(df64, ShiftRight<12>(v)), k2p64) - k2p64; + const auto f64_lo12 = + PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)), + Set(du32, uint32_t{0x00000FFF})))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); 
+ + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - ShiftRight<63>(BitCast(du64, f64_carry)), + f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +} + +// Specializations for partial vectors because i16x8_narrow_i32x4 sets lanes +// above 2*N. +template +HWY_API Vec32 ReorderDemote2To(D dn, Vec32 a, + Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API Vec64 ReorderDemote2To(D dn, Vec64 a, + Vec64 b) { + const Twice dn_full; + const Repartition du32_full; + + const Vec128 v_full{wasm_i16x8_narrow_i32x4(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template +HWY_API Vec128 ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return Vec128{wasm_i16x8_narrow_i32x4(a.raw, b.raw)}; +} + +template +HWY_API Vec32 ReorderDemote2To(D dn, Vec32 a, + Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API Vec64 ReorderDemote2To(D dn, Vec64 a, + Vec64 b) { + const Twice dn_full; + const Repartition du32_full; + + const Vec128 v_full{wasm_u16x8_narrow_i32x4(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template +HWY_API Vec128 ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return Vec128{wasm_u16x8_narrow_i32x4(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV du32; + const RebindToSigned di32; + const auto max_i32 = Set(du32, 0x7FFFFFFFu); + + const auto clamped_a = BitCast(di32, Min(a, max_i32)); + const auto clamped_b = BitCast(di32, Min(b, max_i32)); + return ReorderDemote2To(dn, clamped_a, clamped_b); +} +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +// Specializations for partial vectors because i8x16_narrow_i16x8 sets lanes +// above 2*N. 
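The 8-bit specializations below mirror the 16-bit ones above: two wider vectors are narrowed with saturation and packed side by side. A scalar sketch of a saturating i16 to i8 narrow, assuming (as the full-vector overloads here do) that a's lanes land in the lower half and b's in the upper (illustrative helper names):

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>

// Saturate an i16 value to the i8 range.
int8_t SatNarrowI16ToI8(int16_t x) {
  return static_cast<int8_t>(
      std::min<int16_t>(127, std::max<int16_t>(-128, x)));
}

int main() {
  const std::array<int16_t, 4> a{1, 300, -5, -300};
  const std::array<int16_t, 4> b{7, 8, 9, 10};
  std::array<int8_t, 8> out;
  for (size_t i = 0; i < 4; ++i) {
    out[i] = SatNarrowI16ToI8(a[i]);      // lower half from a
    out[4 + i] = SatNarrowI16ToI8(b[i]);  // upper half from b
  }
  for (int8_t x : out) printf("%d ", x);  // 1 127 -5 -128 7 8 9 10
  printf("\n");
  return 0;
}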
+template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API Vec64 ReorderDemote2To(D dn, Vec64 a, + Vec64 b) { + const Twice dn_full; + const Repartition du32_full; + + const Vec128 v_full{wasm_i8x16_narrow_i16x8(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template +HWY_API Vec128 ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_narrow_i16x8(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API Vec64 ReorderDemote2To(D dn, Vec64 a, + Vec64 b) { + const Twice dn_full; + const Repartition du32_full; + + const Vec128 v_full{wasm_u8x16_narrow_i16x8(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template +HWY_API Vec128 ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return Vec128{wasm_u8x16_narrow_i16x8(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV du16; + const RebindToSigned di16; + const auto max_i16 = Set(du16, 0x7FFFu); + + const auto clamped_a = BitCast(di16, Min(a, max_i16)); + const auto clamped_b = BitCast(di16, Min(b, max_i16)); + return ReorderDemote2To(dn, clamped_a, clamped_b); +} +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +// For already range-limited input [0, 255]. +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{ + wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +// ------------------------------ Truncations + +template +HWY_API VFromD TruncateTo(DTo /* tag */, Vec128 v) { + // BitCast requires the same size; DTo might be u8x1 and v u16x1. 
+ const Repartition, DFromV> dto; + return VFromD{BitCast(dto, v).raw}; +} + +template +HWY_API Vec16 TruncateTo(D /* tag */, Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + const auto v2 = ConcatEven(d, v1, v1); + const auto v4 = ConcatEven(d, v2, v2); + return LowerHalf(LowerHalf(LowerHalf(ConcatEven(d, v4, v4)))); +} + +template +HWY_API Vec32 TruncateTo(D /* tag */, Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + const auto v2 = ConcatEven(d, v1, v1); + return LowerHalf(LowerHalf(ConcatEven(d, v2, v2))); +} + +template +HWY_API Vec64 TruncateTo(D /* tag */, Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + const auto v3 = ConcatEven(d, v2, v2); + return VFromD{v3.raw}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + return VFromD{v2.raw}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + return VFromD{v2.raw}; +} + +// ------------------------------ Demotions to/from i64 + +namespace detail { +template +HWY_INLINE VFromD> DemoteFromU64MaskOutResult( + D /*dn*/, VFromD> v) { + return v; +} + +template +HWY_INLINE VFromD> DemoteFromU64MaskOutResult( + D /*dn*/, VFromD> v) { + const DFromV du64; + return And(v, + Set(du64, static_cast(hwy::HighestValue>()))); +} + +template +HWY_INLINE VFromD> DemoteFromU64Saturate( + D dn, VFromD> v) { + const Rebind du64; + const RebindToSigned di64; + constexpr int kShiftAmt = static_cast(sizeof(TFromD) * 8) - + static_cast(hwy::IsSigned>()); + + const auto too_big = BitCast( + du64, VecFromMask( + di64, Gt(BitCast(di64, ShiftRight(v)), Zero(di64)))); + return DemoteFromU64MaskOutResult(dn, Or(v, too_big)); +} + +template +HWY_INLINE VFromD ReorderDemote2From64To32Combine(D dn, V a, V b) { + return ConcatEven(dn, BitCast(dn, b), BitCast(dn, a)); +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const DFromV di64; + const RebindToUnsigned du64; + const RebindToUnsigned dn_u; + + // Negative values are saturated by first saturating their bitwise inverse + // and then inverting the saturation result + const auto invert_mask = BitCast(du64, BroadcastSignBit(v)); + const auto saturated_vals = Xor( + invert_mask, + detail::DemoteFromU64Saturate(dn, Xor(invert_mask, BitCast(du64, v)))); + return BitCast(dn, TruncateTo(dn_u, saturated_vals)); +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const DFromV di64; + const RebindToUnsigned du64; + + const auto non_neg_vals = BitCast(du64, AndNot(BroadcastSignBit(v), v)); + return TruncateTo(dn, detail::DemoteFromU64Saturate(dn, non_neg_vals)); +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + return TruncateTo(dn, detail::DemoteFromU64Saturate(dn, v)); +} + +template )> +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV di64; + 
const RebindToUnsigned du64; + const Half dnh; + + // Negative values are saturated by first saturating their bitwise inverse + // and then inverting the saturation result + const auto invert_mask_a = BitCast(du64, BroadcastSignBit(a)); + const auto invert_mask_b = BitCast(du64, BroadcastSignBit(b)); + const auto saturated_a = Xor( + invert_mask_a, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_a, BitCast(du64, a)))); + const auto saturated_b = Xor( + invert_mask_b, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_b, BitCast(du64, b)))); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV di64; + const RebindToUnsigned du64; + const Half dnh; + + const auto saturated_a = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(a), a))); + const auto saturated_b = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(b), b))); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const Half dnh; + + const auto saturated_a = detail::DemoteFromU64Saturate(dnh, a); + const auto saturated_b = detail::DemoteFromU64Saturate(dnh, b); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} + +template ), class V, + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} + +// ------------------------------ ConvertTo + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f32x4_convert_i32x4(v.raw)}; +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f32x4_convert_u32x4(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D dd, VFromD> v) { + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const Repartition d32; + const Repartition d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! +} + +namespace detail { +template +HWY_INLINE VFromD>> U64ToF64VecFast(VW w) { + const DFromV d64; + const RebindToFloat dd; + const auto cnst2_52_dbl = Set(dd, 0x0010000000000000); // 2^52 + return BitCast(dd, Or(w, BitCast(d64, cnst2_52_dbl))) - cnst2_52_dbl; +} +} // namespace detail + +template +HWY_API VFromD ConvertTo(D dd, VFromD> v) { + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned d64; + using VU = VFromD; + + const VU msk_lo = Set(d64, 0xFFFFFFFF); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest/highest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + const auto v_lo_dbl = detail::U64ToF64VecFast(v_lo); + return MulAdd(cnst2_32_dbl, detail::U64ToF64VecFast(v_hi), v_lo_dbl); +} + +// Truncates (rounds toward zero). 
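Two notes on the conversions around this point: the u64 to f64 path above relies on the classic 2^52 bias trick (OR the 52 low bits into the mantissa of 2^52, reinterpret, subtract 2^52), while the overloads below go the other way and truncate toward zero. A scalar sketch of the bias trick, assuming IEEE-754 binary64 (not Highway code):

#include <cstdint>
#include <cstring>
#include <cstdio>

// For w < 2^52: OR w into the bit pattern of 2^52 (0x4330000000000000),
// reinterpret as double, then subtract 2^52. The result is exactly w.
double U64LowBitsToF64(uint64_t w) {
  const uint64_t bits = 0x4330000000000000ULL | w;  // valid for w < 2^52
  double d;
  std::memcpy(&d, &bits, sizeof(d));
  return d - 4503599627370496.0;  // 2^52
}

int main() {
  printf("%.1f\n", U64LowBitsToF64(123456789ULL));  // 123456789.0
  return 0;
}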
+template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i32x4_trunc_sat_f32x4(v.raw)}; +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u32x4_trunc_sat_f32x4(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(DI di, VFromD> v) { + using VI = VFromD; + using MI = MFromD; + const RebindToUnsigned du; + using VU = VFromD; + const Repartition du16; + const VI k1075 = Set(di, 1075); // biased exponent of 2^52 + + // Exponent indicates whether the number can be represented as int64_t. + const VU biased_exp = ShiftRight<52>(BitCast(du, v)) & Set(du, 0x7FF); + const MI in_range = BitCast(di, biased_exp) < Set(di, 1086); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate). + // Use 16-bit saturated unsigned subtraction to compute shift_mnt and + // shift_int since biased_exp[i] is a non-negative integer that is less than + // or equal to 2047. + // The upper 48 bits of both shift_mnt and shift_int are guaranteed to be + // zero as the upper 48 bits of both k1075 and biased_exp are zero. + + const VU shift_mnt = BitCast( + du, SaturatedSub(BitCast(du16, k1075), BitCast(du16, biased_exp))); + const VU shift_int = BitCast( + du, SaturatedSub(BitCast(du16, biased_exp), BitCast(du16, k1075))); + const VU mantissa = BitCast(du, v) & Set(du, (1ULL << 52) - 1); + // Include implicit 1-bit + VU int53 = (mantissa | Set(du, 1ULL << 52)) >> shift_mnt; + // WASM clamps shift count; zero if greater. + const MI tiny = BitCast(di, shift_mnt) > Set(di, 63); + int53 = IfThenZeroElse(RebindMask(du, tiny), int53); + + // For inputs larger than 2^53 - 1, insert zeros at the bottom. + // For inputs less than 2^63, the implicit 1-bit is guaranteed not to be + // shifted out of the left shift result below as shift_int[i] <= 10 is true + // for any inputs that are less than 2^63. + const VU shifted = int53 << shift_int; + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax()) - sign_mask; + const VI magnitude = IfThenElse(in_range, BitCast(di, shifted), limit); + + // If the input was negative, negate the integer (two's complement). + return (magnitude ^ sign_mask) - sign_mask; +} + +template +HWY_API VFromD ConvertTo(DU du, VFromD> v) { + const RebindToSigned di; + using MI = MFromD; + using VU = VFromD; + const Repartition du16; + const VU k1075 = Set(du, 1075); /* biased exponent of 2^52 */ + + const auto non_neg_v = ZeroIfNegative(v); + + // Exponent indicates whether the number can be represented as int64_t. + const VU biased_exp = ShiftRight<52>(BitCast(du, non_neg_v)); + const VU out_of_range = + BitCast(du, VecFromMask(di, BitCast(di, biased_exp) > Set(di, 1086))); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + + // Use 16-bit saturated unsigned subtraction to compute shift_mnt and + // shift_int since biased_exp[i] is a non-negative integer that is less than + // or equal to 2047. 
+ + // 16-bit saturated unsigned subtraction is also more efficient than a + // 64-bit subtraction followed by a 64-bit signed Max operation on + // WASM. + + // The upper 48 bits of both shift_mnt and shift_int are guaranteed to be + // zero as the upper 48 bits of both k1075 and biased_exp are zero. + + const VU shift_mnt = BitCast( + du, SaturatedSub(BitCast(du16, k1075), BitCast(du16, biased_exp))); + const VU shift_int = BitCast( + du, SaturatedSub(BitCast(du16, biased_exp), BitCast(du16, k1075))); + const VU mantissa = BitCast(du, non_neg_v) & Set(du, (1ULL << 52) - 1); + // Include implicit 1-bit. + VU int53 = (mantissa | Set(du, 1ULL << 52)) >> shift_mnt; + // WASM clamps shift count; zero if greater. + const MI tiny = BitCast(di, shift_mnt) > Set(di, 63); + int53 = IfThenZeroElse(RebindMask(du, tiny), int53); + + // For inputs larger than 2^53 - 1, insert zeros at the bottom. + + // For inputs less than 2^64, the implicit 1-bit is guaranteed not to be + // shifted out of the left shift result below as shift_int[i] <= 11 is true + // for any inputs that are less than 2^64. + + const VU shifted = int53 << shift_int; + return (shifted | out_of_range); +} + +// ------------------------------ NearestInt (Round) +template +HWY_API Vec128, N> NearestInt(const Vec128 v) { + return ConvertTo(RebindToSigned>(), Round(v)); +} + +// ------------------------------ DemoteToNearestInt (Round) +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + // No single instruction, round then demote. + return DemoteTo(di32, Round(v)); +} + +// ================================================== MISC + +// ------------------------------ SumsOf8 (ShiftRight, Add) +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + const DFromV du8; + const RepartitionToWide du16; + const RepartitionToWide du32; + const RepartitionToWide du64; + using VU16 = VFromD; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = And(BitCast(du16, v), Set(du16, 0xFF)); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return And(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), Set(du64, 0xFFFF)); +} + +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + const DFromV di8; + const RepartitionToWide di16; + const RepartitionToWide di32; + const RepartitionToWide di64; + const RebindToUnsigned du32; + const RebindToUnsigned du64; + using VI16 = VFromD; + + const VI16 vFDB97531 = ShiftRight<8>(BitCast(di16, v)); + const VI16 vECA86420 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, v))); + const VI16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VI16 sDC_zz_98_zz_54_zz_10_zz = + BitCast(di16, ShiftLeft<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VI16 sFC_xx_B8_xx_74_xx_30_xx = + Add(sFE_DC_BA_98_76_54_32_10, sDC_zz_98_zz_54_zz_10_zz); + const VI16 sB8_xx_zz_zz_30_xx_zz_zz = + BitCast(di16, ShiftLeft<32>(BitCast(du64, sFC_xx_B8_xx_74_xx_30_xx))); + const VI16 sF8_xx_xx_xx_70_xx_xx_xx = + Add(sFC_xx_B8_xx_74_xx_30_xx, sB8_xx_zz_zz_30_xx_zz_zz); + return ShiftRight<48>(BitCast(di64, sF8_xx_xx_xx_70_xx_xx_xx)); +} + +// 
------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const VFromD vbits{wasm_i32x4_splat(static_cast(bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) static constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(16) static constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask( + d, TestBit(Set(du, static_cast(bits)), Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + return RebindMask( + d, TestBit(Set(du, static_cast(bits)), Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(MaxLanes(d) + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Returns the lowest N bits for the BitsFromMask result. +template +constexpr uint64_t OnlyActive(D d, uint64_t bits) { + return (d.MaxBytes() == 16) ? bits : bits & ((1ull << d.MaxLanes()) - 1); +} + +} // namespace detail + +template +HWY_API uint64_t BitsFromMask(D /*d*/, const MFromD mask) { + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, mask.raw); + + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + const uint64_t lo = ((lanes[0] * kMagic) >> 56); + const uint64_t hi = ((lanes[1] * kMagic) >> 48) & 0xFF00; + return hi + lo; // exactly 16 bits, no OnlyActive required +} + +template +HWY_API uint64_t BitsFromMask(D /*d*/, const MFromD mask) { + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + const uint64_t bytes = + static_cast(wasm_i64x2_extract_lane(mask.raw, 0)); + return (bytes * kMagic) >> 56; // exactly 8 bits, no OnlyActive required +} + +// 32-bit or less: need masking +template +HWY_API uint64_t BitsFromMask(D d, const MFromD mask) { + uint64_t bytes = static_cast(wasm_i64x2_extract_lane(mask.raw, 0)); + // Clear potentially undefined bytes. + bytes &= (1ULL << (Lanes(d) * 8)) - 1; + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + return detail::OnlyActive(d, (bytes * kMagic) >> 56); +} + +template +HWY_API uint64_t BitsFromMask(D /*d*/, const MFromD mask) { + // Remove useless lower half of each u16 while preserving the sign bit. 
+ const Rebind d8; + using M8 = MFromD; + const __i16x8 zero = wasm_i16x8_splat(0); + const M8 mask8{wasm_i8x16_narrow_i16x8(mask.raw, zero)}; + return detail::OnlyActive(d8, BitsFromMask(d8, mask8)); +} + +template +HWY_API uint64_t BitsFromMask(D d, const MFromD mask) { + const __i32x4 mask_i = static_cast<__i32x4>(mask.raw); + const __i32x4 slice = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint32_t lanes[4]; + wasm_v128_store(lanes, sliced_mask); + return detail::OnlyActive(d, lanes[0] | lanes[1] | lanes[2] | lanes[3]); +} + +template +HWY_API uint64_t BitsFromMask(D d, const MFromD mask) { + const __i64x2 mask_i = static_cast<__i64x2>(mask.raw); + const __i64x2 slice = wasm_i64x2_make(1, 2); + const __i64x2 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, sliced_mask); + return detail::OnlyActive(d, lanes[0] | lanes[1]); +} + +namespace detail { + +// Returns 0xFF for bytes with index >= N, otherwise 0. +template +constexpr __i8x16 BytesAbove() { + return /**/ + (N == 0) ? wasm_i32x4_make(-1, -1, -1, -1) + : (N == 4) ? wasm_i32x4_make(0, -1, -1, -1) + : (N == 8) ? wasm_i32x4_make(0, 0, -1, -1) + : (N == 12) ? wasm_i32x4_make(0, 0, 0, -1) + : (N == 16) ? wasm_i32x4_make(0, 0, 0, 0) + : (N == 2) ? wasm_i16x8_make(0, -1, -1, -1, -1, -1, -1, -1) + : (N == 6) ? wasm_i16x8_make(0, 0, 0, -1, -1, -1, -1, -1) + : (N == 10) ? wasm_i16x8_make(0, 0, 0, 0, 0, -1, -1, -1) + : (N == 14) ? wasm_i16x8_make(0, 0, 0, 0, 0, 0, 0, -1) + : (N == 1) ? wasm_i8x16_make(0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1) + : (N == 3) ? wasm_i8x16_make(0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 5) ? wasm_i8x16_make(0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 7) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, + -1, -1, -1) + : (N == 9) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, + -1, -1, -1) + : (N == 11) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1) + : (N == 13) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1) + : wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, const MFromD mask, uint8_t* bits) { + const uint64_t mask_bits = BitsFromMask(d, mask); + const size_t kNumBytes = (d.MaxLanes() + 7) / 8; + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +template +HWY_API size_t CountTrue(D d, const MFromD m) { + return PopCount(BitsFromMask(d, m)); +} +template +HWY_API size_t CountTrue(D d, const MFromD m) { + return PopCount(BitsFromMask(d, m)); +} +template +HWY_API size_t CountTrue(D /*d*/, const MFromD m) { + const __i32x4 var_shift = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 shifted_bits = wasm_v128_and(m.raw, var_shift); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, shifted_bits); + return PopCount(lanes[0] | lanes[1]); +} +template +HWY_API size_t CountTrue(D /*d*/, const MFromD m) { + alignas(16) int64_t lanes[2]; + wasm_v128_store(lanes, m.raw); + return static_cast(-(lanes[0] + lanes[1])); +} + +// Partial +template , HWY_IF_V_SIZE_LE_D(D, 8)> +HWY_API size_t CountTrue(D d, MFromD m) { + // Ensure all undefined bytes are 0. 
+ const MFromD mask{detail::BytesAbove()}; + const Full128 dfull; + return CountTrue(dfull, Mask128{AndNot(mask, m).raw}); +} + +// Full vector +template +HWY_API bool AllFalse(D d, const MFromD m) { + const auto v8 = BitCast(Full128(), VecFromMask(d, m)); + return !wasm_v128_any_true(v8.raw); +} + +// Full vector +namespace detail { +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + return wasm_i8x16_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + return wasm_i16x8_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + return wasm_i32x4_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + return wasm_i64x2_all_true(m.raw); +} + +} // namespace detail + +template > +HWY_API bool AllTrue(D /* tag */, const Mask128 m) { + return detail::AllTrue(hwy::SizeTag(), m); +} + +// Partial vectors + +template , HWY_IF_V_SIZE_LE_D(D, 8)> +HWY_API bool AllFalse(D d, const MFromD m) { + // Ensure all undefined bytes are 0. + const MFromD mask{detail::BytesAbove()}; + return AllFalse(Full128(), Mask128{AndNot(mask, m).raw}); +} + +template , HWY_IF_V_SIZE_LE_D(D, 8)> +HWY_API bool AllTrue(D d, const MFromD m) { + // Ensure all undefined bytes are FF. + const MFromD mask{detail::BytesAbove()}; + return AllTrue(Full128(), Mask128{Or(mask, m).raw}); +} + +template +HWY_API size_t FindKnownFirstTrue(D d, const MFromD mask) { + const uint32_t bits = static_cast(BitsFromMask(d, mask)); + return Num0BitsBelowLS1Bit_Nonzero32(bits); +} + +template +HWY_API intptr_t FindFirstTrue(D d, const MFromD mask) { + const uint32_t bits = static_cast(BitsFromMask(d, mask)); + return bits ? static_cast(Num0BitsBelowLS1Bit_Nonzero32(bits)) : -1; +} + +template +HWY_API size_t FindKnownLastTrue(D d, const MFromD mask) { + const uint32_t bits = static_cast(BitsFromMask(d, mask)); + return 31 - Num0BitsAboveMS1Bit_Nonzero32(bits); +} + +template +HWY_API intptr_t FindLastTrue(D d, const MFromD mask) { + const uint32_t bits = static_cast(BitsFromMask(d, mask)); + return bits + ? (31 - static_cast(Num0BitsAboveMS1Bit_Nonzero32(bits))) + : -1; +} + +// ------------------------------ Compress + +namespace detail { + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Rebind d8; + const Simd du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. 
+ alignas(16) static constexpr uint8_t table[256 * 8] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 
12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 
12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Rebind d8; + const Simd du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. 
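// A minimal scalar sketch of the expansion described above, assuming
// little-endian byte order as on WASM; ExpandRowToByteIndices is a
// hypothetical name for illustration, not a Highway API. The table stores
// already-doubled lane indices (2*lane), and ZipLower plus Set(du, 0x0100)
// turns each entry k into the byte pair {k, k + 1}.
#include <cstddef>
#include <cstdint>
inline void ExpandRowToByteIndices(const uint8_t row[8], uint8_t bytes[16]) {
  for (size_t i = 0; i < 8; ++i) {
    // ZipLower(k, k) yields the u16 lane (k << 8) | k; adding 0x0100 leaves
    // the low byte at k and raises the high byte to k + 1.
    const uint16_t pair =
        static_cast<uint16_t>((row[i] << 8) | row[i]) + 0x0100;
    bytes[2 * i + 0] = static_cast<uint8_t>(pair & 0xFF);  // = row[i]
    bytes[2 * i + 1] = static_cast<uint8_t>(pair >> 8);    // = row[i] + 1
  }
}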
+ alignas(16) static constexpr uint8_t table[256 * 8] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 
8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 
0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +// Helper functions called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. + +template +HWY_INLINE Vec128 Compress(Vec128 v, const uint64_t mask_bits) { + const auto idx = detail::IdxFromBits(mask_bits); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_INLINE Vec128 CompressNot(Vec128 v, const uint64_t mask_bits) { + const auto idx = detail::IdxFromNotBits(mask_bits); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +} // namespace detail + +template +struct CompressIsPartition { +#if HWY_TARGET == HWY_WASM_EMU256 + enum { value = 0 }; +#else + enum { value = (sizeof(T) != 1) }; +#endif +}; + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
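// A scalar sketch of the 2-lane decision implemented just below (hypothetical
// helper, not a Highway API): the only mask that changes lane order is
// {mask[0] = 0, mask[1] = 1}, which must move lane 1 down to slot 0; every
// other combination keeps the original order. This is exactly
// swap = AndNot(maskL, maskH).
inline bool NeedSwapForCompress2(bool mask0, bool mask1) {
  return !mask0 && mask1;
}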
+ const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return detail::Compress(v, BitsFromMask(d, mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::Compress(v, BitsFromMask(d, Not(mask))); + } + return detail::CompressNot(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, mask); + const auto c = detail::Compress(v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; // so we can support fp16/bf16 + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + const VFromD compressed = + detail::Compress(BitCast(du, v), mask_bits); + const MFromD store_mask = RebindMask(d, FirstN(du, count)); + BlendedStore(BitCast(d, compressed), store_mask, d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kN = MaxLanes(d); + CopyBytes<(kN + 7) / 8>(bits, &mask_bits); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + const auto c = detail::Compress(v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. 
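// Scalar reference for the Compress/CompressStore contract defined above
// (hypothetical helper, not a Highway API): lanes whose mask bit is set are
// packed to the front in their original order, and the number of selected
// lanes (PopCount of the mask bits) is returned, matching CompressStore and
// CompressBitsStore.
#include <cstddef>
#include <cstdint>
template <typename T, size_t N>
inline size_t CompressStoreScalar(const T (&lanes)[N], uint64_t mask_bits,
                                  T* out) {
  size_t count = 0;
  for (size_t i = 0; i < N; ++i) {
    if ((mask_bits >> i) & 1) out[count++] = lanes[i];
  }
  return count;
}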
+ +// ------------------------------ Additional mask logical operations +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const FixedTag d; + const auto vmask = VecFromMask(d, mask); + return MaskFromVec(Or(vmask, InterleaveLower(vmask, vmask))); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Simd d; + const auto vmask = VecFromMask(d, mask); + const auto neg_vmask = + ResizeBitCast(d, Neg(ResizeBitCast(Full64(), vmask))); + return MaskFromVec(Or(vmask, neg_vmask)); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Full128 d; + const Repartition di64; + + auto vmask = BitCast(di64, VecFromMask(d, mask)); + vmask = Or(vmask, Neg(vmask)); + + // Copy the sign bit of the first int64_t lane to the second int64_t lane + const auto vmask2 = BroadcastSignBit(InterleaveLower(Zero(di64), vmask)); + return MaskFromVec(BitCast(d, Or(vmask, vmask2))); +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const FixedTag d; + const RebindToSigned di; + + const auto vmask = BitCast(di, VecFromMask(d, mask)); + const auto zero = Zero(di); + const auto vmask2 = VecFromMask(di, InterleaveLower(zero, vmask) == zero); + return MaskFromVec(BitCast(d, And(vmask, vmask2))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Simd d; + const RebindToSigned di; + + const auto vmask = ResizeBitCast(Full64(), VecFromMask(d, mask)); + const auto only_first_vmask = + BitCast(d, Neg(ResizeBitCast(di, And(vmask, Neg(vmask))))); + return MaskFromVec(only_first_vmask); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Full128 d; + const RebindToSigned di; + const Repartition di64; + + const auto zero = Zero(di64); + const auto vmask = BitCast(di64, VecFromMask(d, mask)); + const auto vmask2 = VecFromMask(di64, InterleaveLower(zero, vmask) == zero); + const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); + return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2)))); +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 /*mask*/) { + const FixedTag d; + const RebindToSigned di; + using TI = MakeSigned; + + return RebindMask(d, MaskFromVec(Set(di, TI(-1)))); +} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + const Simd d; + return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask)))); +} + +// ------------------------------ MulEven/Odd (Load) + +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + mul[0] = Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 0)), + static_cast(wasm_i64x2_extract_lane(b.raw, 0)), &mul[1]); + return Load(Full128(), mul); +} + +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + mul[0] = Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 1)), + static_cast(wasm_i64x2_extract_lane(b.raw, 1)), &mul[1]); + return Load(Full128(), mul); +} + +// ------------------------------ I64/U64 MulHigh (GetLane) +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Set(Full64(), hi); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi_0; + T hi_1; + Mul128(GetLane(a), GetLane(b), &hi_0); + Mul128(detail::ExtractLane<1>(a), detail::ExtractLane<1>(b), &hi_1); + return 
Dup128VecFromValues(Full128(), hi_0, hi_1); +} + +// ------------------------------ WidenMulPairwiseAdd (MulAdd, PromoteEvenTo) + +// Generic for all vector lengths. +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} + +// Even if N=1, the input is always at least 2 lanes, hence i32x4_dot_i16x8 is +// safe. +template >> +HWY_API VFromD WidenMulPairwiseAdd(D32 /* tag */, V16 a, V16 b) { + return VFromD{wasm_i32x4_dot_i16x8(a.raw, b.raw)}; +} + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DU32 du32, VU16 a, VU16 b) { + return MulAdd(PromoteEvenTo(du32, a), PromoteEvenTo(du32, b), + Mul(PromoteOddTo(du32, a), PromoteOddTo(du32, b))); +} + +// ------------------------------ ReorderWidenMulAccumulate + +template >> +HWY_API VFromD ReorderWidenMulAccumulate(D32 d32, V16 a, V16 b, + const VFromD sum0, + VFromD& /*sum1*/) { + return sum0 + WidenMulPairwiseAdd(d32, a, b); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec128 RearrangeToOddPlusEven( + const Vec128 sum0, const Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +template +HWY_API Vec128 RearrangeToOddPlusEven( + const Vec128 sum0, const Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +template +HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, + const Vec128 sum1) { + return Add(sum0, sum1); +} + +// ------------------------------ Reductions + +// Nothing native, generic_ops-inl defines SumOfLanes and ReduceSum. + +// ------------------------------ Lt128 + +template +HWY_INLINE MFromD Lt128(D d, VFromD a, VFromD b) { + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const MFromD eqHL = Eq(a, b); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. IfThenElse compiles to the + // same code as OrAnd(). 
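// Scalar reference for the truth table above (hypothetical helper, not a
// Highway API): treating each 128-bit block as the (hi, lo) pair of its u64
// lanes, the predicate is cH | (=H & cL), i.e. ordinary unsigned 128-bit
// less-than.
#include <cstdint>
inline bool Lt128Scalar(uint64_t a_hi, uint64_t a_lo, uint64_t b_hi,
                        uint64_t b_lo) {
  return (a_hi < b_hi) || (a_hi == b_hi && a_lo < b_lo);
}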
+ const VFromD ltLx = DupEven(ltHL); + const VFromD outHx = IfThenElse(eqHL, ltLx, ltHL); + return MaskFromVec(DupOdd(outHx)); +} + +template +HWY_INLINE MFromD Lt128Upper(D d, VFromD a, VFromD b) { + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + return MaskFromVec(InterleaveUpper(d, ltHL, ltHL)); +} + +// ------------------------------ Eq128 + +template +HWY_INLINE MFromD Eq128(D d, VFromD a, VFromD b) { + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(And(Reverse2(d, eqHL), eqHL)); +} + +template +HWY_INLINE MFromD Eq128Upper(D d, VFromD a, VFromD b) { + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(InterleaveUpper(d, eqHL, eqHL)); +} + +// ------------------------------ Ne128 + +template +HWY_INLINE MFromD Ne128(D d, VFromD a, VFromD b) { + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(Or(Reverse2(d, neHL), neHL)); +} + +template +HWY_INLINE MFromD Ne128Upper(D d, VFromD a, VFromD b) { + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(InterleaveUpper(d, neHL, neHL)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. +template +HWY_INLINE VFromD Min128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template +HWY_INLINE VFromD Min128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/aom/third_party/highway/hwy/ops/wasm_256-inl.h b/third_party/aom/third_party/highway/hwy/ops/wasm_256-inl.h new file mode 100644 index 000000000000..e81f33f3ab76 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/wasm_256-inl.h @@ -0,0 +1,2519 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit WASM vectors and operations. Experimental. +// External include guard in highway.h - see comment there. + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "third_party/highway/hwy/ops/wasm_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +class Vec256 { + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator%=(const Vec256 other) { + return *this = (*this % other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Vec128 v0; + Vec128 v1; +}; + +template +struct Mask256 { + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromM + + Mask128 m0; + Mask128 m1; +}; + +// ------------------------------ Zero + +// Avoid VFromD here because it is defined in terms of Zero. +template +HWY_API Vec256> Zero(D d) { + const Half dh; + Vec256> ret; + ret.v0 = ret.v1 = Zero(dh); + return ret; +} + +// ------------------------------ BitCast +template +HWY_API VFromD BitCast(D d, Vec256 v) { + const Half dh; + VFromD ret; + ret.v0 = BitCast(dh, v.v0); + ret.v1 = BitCast(dh, v.v1); + return ret; +} + +// ------------------------------ ResizeBitCast + +// 32-byte vector to 32-byte vector: Same as BitCast +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, v); +} + +// <= 16-byte vector to 32-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Half dh; + VFromD ret; + ret.v0 = ResizeBitCast(dh, v); + ret.v1 = Zero(dh); + return ret; +} + +// 32-byte vector to <= 16-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return ResizeBitCast(d, v.v0); +} + +// ------------------------------ Set +template +HWY_API VFromD Set(D d, const T2 t) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Set(dh, static_cast>(t)); + return ret; +} + +// Undefined, Iota defined in wasm_128. 
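// The 256-bit type in this header is simply a pair of 128-bit halves, and most
// lanewise ops forward to both halves unchanged. A minimal sketch of that
// pattern with a hypothetical pair type (not Highway's actual classes); only
// crosswise ops such as ConcatOdd, Reverse or TableLookupLanes further below
// need to mix the two halves.
template <class Half>
struct PairVec {
  Half lo, hi;  // corresponds to v0 (lower lanes) and v1 (upper lanes)
};
template <class Half, class Op>
PairVec<Half> ForwardToHalves(PairVec<Half> a, PairVec<Half> b, Op op) {
  return PairVec<Half>{op(a.lo, b.lo), op(a.hi, b.hi)};
}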
+ +// ------------------------------ Dup128VecFromValues +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7, t8, + t9, t10, t11, t12, t13, t14, t15); + return ret; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7); + return ret; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3); + return ret; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1); + return ret; +} + +// ================================================== ARITHMETIC + +template +HWY_API Vec256 operator+(Vec256 a, const Vec256 b) { + a.v0 += b.v0; + a.v1 += b.v1; + return a; +} + +template +HWY_API Vec256 operator-(Vec256 a, const Vec256 b) { + a.v0 -= b.v0; + a.v1 -= b.v1; + return a; +} + +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(const Vec256 v) { + Vec256 ret; + ret.v0 = SumsOf8(v.v0); + ret.v1 = SumsOf8(v.v1); + return ret; +} + +HWY_API Vec256 SumsOf8(const Vec256 v) { + Vec256 ret; + ret.v0 = SumsOf8(v.v0); + ret.v1 = SumsOf8(v.v1); + return ret; +} + +template +HWY_API Vec256 SaturatedAdd(Vec256 a, const Vec256 b) { + a.v0 = SaturatedAdd(a.v0, b.v0); + a.v1 = SaturatedAdd(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 SaturatedSub(Vec256 a, const Vec256 b) { + a.v0 = SaturatedSub(a.v0, b.v0); + a.v1 = SaturatedSub(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 AverageRound(Vec256 a, const Vec256 b) { + a.v0 = AverageRound(a.v0, b.v0); + a.v1 = AverageRound(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Abs(Vec256 v) { + v.v0 = Abs(v.v0); + v.v1 = Abs(v.v1); + return v; +} + +// ------------------------------ Shift lanes by constant #bits + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + v.v0 = ShiftLeft(v.v0); + v.v1 = ShiftLeft(v.v1); + return v; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + v.v0 = ShiftRight(v.v0); + v.v1 = ShiftRight(v.v1); + return v; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec256 RotateRight(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +template +HWY_API Vec256 ShiftLeftSame(Vec256 v, const int bits) { + v.v0 = ShiftLeftSame(v.v0, bits); + v.v1 = ShiftLeftSame(v.v1, bits); + return v; +} + +template +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + v.v0 = ShiftRightSame(v.v0, bits); + v.v1 = ShiftRightSame(v.v1, bits); + return v; +} + +// ------------------------------ Min, Max +template +HWY_API Vec256 Min(Vec256 a, const Vec256 b) { + a.v0 = Min(a.v0, b.v0); + a.v1 = Min(a.v1, b.v1); + return a; +} + +template 
+HWY_API Vec256 Max(Vec256 a, const Vec256 b) { + a.v0 = Max(a.v0, b.v0); + a.v1 = Max(a.v1, b.v1); + return a; +} +// ------------------------------ Integer multiplication + +template +HWY_API Vec256 operator*(Vec256 a, const Vec256 b) { + a.v0 *= b.v0; + a.v1 *= b.v1; + return a; +} + +template +HWY_API Vec256 MulHigh(Vec256 a, const Vec256 b) { + a.v0 = MulHigh(a.v0, b.v0); + a.v1 = MulHigh(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 MulFixedPoint15(Vec256 a, const Vec256 b) { + a.v0 = MulFixedPoint15(a.v0, b.v0); + a.v1 = MulFixedPoint15(a.v1, b.v1); + return a; +} + +// Cannot use MakeWide because that returns uint128_t for uint64_t, but we want +// uint64_t. +template +HWY_API Vec256> MulEven(Vec256 a, const Vec256 b) { + Vec256> ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} +template +HWY_API Vec256 MulEven(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} + +template +HWY_API Vec256> MulOdd(Vec256 a, const Vec256 b) { + Vec256> ret; + ret.v0 = MulOdd(a.v0, b.v0); + ret.v1 = MulOdd(a.v1, b.v1); + return ret; +} +template +HWY_API Vec256 MulOdd(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulOdd(a.v0, b.v0); + ret.v1 = MulOdd(a.v1, b.v1); + return ret; +} + +// ------------------------------ Negate +template +HWY_API Vec256 Neg(Vec256 v) { + v.v0 = Neg(v.v0); + v.v1 = Neg(v.v1); + return v; +} + +// ------------------------------ AbsDiff +// generic_ops takes care of integer T. +template +HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point division +// generic_ops takes care of integer T. +template +HWY_API Vec256 operator/(Vec256 a, const Vec256 b) { + a.v0 /= b.v0; + a.v1 /= b.v1; + return a; +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, Vec256 add) { + mul.v0 = MulAdd(mul.v0, x.v0, add.v0); + mul.v1 = MulAdd(mul.v1, x.v1, add.v1); + return mul; +} + +template +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, Vec256 add) { + mul.v0 = NegMulAdd(mul.v0, x.v0, add.v0); + mul.v1 = NegMulAdd(mul.v1, x.v1, add.v1); + return mul; +} + +template +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, Vec256 sub) { + mul.v0 = MulSub(mul.v0, x.v0, sub.v0); + mul.v1 = MulSub(mul.v1, x.v1, sub.v1); + return mul; +} + +template +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, Vec256 sub) { + mul.v0 = NegMulSub(mul.v0, x.v0, sub.v0); + mul.v1 = NegMulSub(mul.v1, x.v1, sub.v1); + return mul; +} + +// ------------------------------ Floating-point square root + +template +HWY_API Vec256 Sqrt(Vec256 v) { + v.v0 = Sqrt(v.v0); + v.v1 = Sqrt(v.v1); + return v; +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +template +HWY_API Vec256 Round(Vec256 v) { + v.v0 = Round(v.v0); + v.v1 = Round(v.v1); + return v; +} + +// Toward zero, aka truncate +template +HWY_API Vec256 Trunc(Vec256 v) { + v.v0 = Trunc(v.v0); + v.v1 = Trunc(v.v1); + return v; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec256 Ceil(Vec256 v) { + v.v0 = Ceil(v.v0); + v.v1 = Ceil(v.v1); + return v; +} + +// Toward -infinity, aka floor +template +HWY_API Vec256 Floor(Vec256 v) { + v.v0 = Floor(v.v0); + v.v1 = Floor(v.v1); + return v; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask256 IsNaN(const Vec256 v) { + return v != v; +} + +template 
+HWY_API Mask256 IsInf(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vu, vu), Set(du, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask256 IsFinite(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template > +HWY_API MFromD RebindMask(DTo /*tag*/, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return MFromD{Mask128{m.m0.raw}, Mask128{m.m1.raw}}; +} + +template +HWY_API Mask256 TestBit(Vec256 v, Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask256 operator==(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator==(a.v0, b.v0); + m.m1 = operator==(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator!=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator!=(a.v0, b.v0); + m.m1 = operator!=(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator<(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator<(a.v0, b.v0); + m.m1 = operator<(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator>(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator>(a.v0, b.v0); + m.m1 = operator>(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator<=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator<=(a.v0, b.v0); + m.m1 = operator<=(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator>=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator>=(a.v0, b.v0); + m.m1 = operator>=(a.v1, b.v1); + return m; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API MFromD FirstN(const D d, size_t num) { + const RebindToSigned di; // Signed comparisons may be cheaper. 
+ using TI = TFromD; + return RebindMask(d, Iota(di, 0) < Set(di, static_cast(num))); +} + +// ================================================== LOGICAL + +template +HWY_API Vec256 Not(Vec256 v) { + v.v0 = Not(v.v0); + v.v1 = Not(v.v1); + return v; +} + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + a.v0 = And(a.v0, b.v0); + a.v1 = And(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + not_mask.v0 = AndNot(not_mask.v0, mask.v0); + not_mask.v1 = AndNot(not_mask.v1, mask.v1); + return not_mask; +} + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + a.v0 = Or(a.v0, b.v0); + a.v1 = Or(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + a.v0 = Xor(a.v0, b.v0); + a.v1 = Xor(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Xor3(Vec256 x1, Vec256 x2, Vec256 x3) { + return Xor(x1, Xor(x2, x3)); +} + +template +HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { + return Or(o1, Or(o2, o3)); +} + +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { + return Or(o, And(a1, a2)); +} + +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign +template +HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + Mask256 m; + m.m0 = MaskFromVec(v.v0); + m.m1 = MaskFromVec(v.v1); + return m; +} + +template > +HWY_API Vec256 VecFromMask(D d, Mask256 m) { + const Half dh; + Vec256 v; + v.v0 = VecFromMask(dh, m.m0); + v.v1 = VecFromMask(dh, m.m1); + return v; +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD m) { + const Half dh; + const uint64_t lo = BitsFromMask(dh, m.m0); + const uint64_t hi = BitsFromMask(dh, m.m1); + return (hi << Lanes(dh)) | lo; +} + +// mask ? yes : no +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + yes.v0 = IfThenElse(mask.m0, yes.v0, no.v0); + yes.v1 = IfThenElse(mask.m1, yes.v1, no.v1); + return yes; +} + +// mask ? yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + v.v0 = IfNegativeThenElse(v.v0, yes.v0, no.v0); + v.v1 = IfNegativeThenElse(v.v1, yes.v1, no.v1); + return v; +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + return MaskFromVec(Not(VecFromMask(Full256(), m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) +template +HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { + v.v0 = operator<<(v.v0, bits.v0); + v.v1 = operator<<(v.v1, bits.v1); + return v; +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) +template +HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { + v.v0 = operator>>(v.v0, bits.v0); + v.v1 = operator>>(v.v1, bits.v1); + return v; +} + +// ------------------------------ BroadcastSignBit (compare, VecFromMask) + +template +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight(v); +} +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + const DFromV d; + return VecFromMask(d, v < Zero(d)); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT aligned) { + const Half dh; + VFromD ret; + ret.v0 = Load(dh, aligned); + ret.v1 = Load(dh, aligned + Lanes(dh)); + return ret; +} + +template > +HWY_API Vec256 MaskedLoad(Mask256 m, D d, const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template > +HWY_API Vec256 MaskedLoadOr(Vec256 v, Mask256 m, D d, + const T* HWY_RESTRICT aligned) { + return IfThenElse(m, Load(d, aligned), v); +} + +// LoadU == Load. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Load(dh, p); + return ret; +} + +// ------------------------------ Store + +template > +HWY_API void Store(Vec256 v, D d, T* HWY_RESTRICT aligned) { + const Half dh; + Store(v.v0, dh, aligned); + Store(v.v1, dh, aligned + Lanes(dh)); +} + +// StoreU == Store. +template > +HWY_API void StoreU(Vec256 v, D d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template > +HWY_API void BlendedStore(Vec256 v, Mask256 m, D d, T* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Stream +template > +HWY_API void Stream(Vec256 v, D d, T* HWY_RESTRICT aligned) { + // Same as aligned stores. 
+ Store(v, d, aligned); +} + +// ------------------------------ Scatter, Gather defined in wasm_128 + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane +template +HWY_API T ExtractLane(const Vec256 v, size_t i) { + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane +template +HWY_API Vec256 InsertLane(const Vec256 v, size_t i, T t) { + DFromV d; + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ ExtractBlock +template +HWY_API Vec128 ExtractBlock(Vec256 v) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + return (kBlockIdx == 0) ? v.v0 : v.v1; +} + +// ------------------------------ InsertBlock +template +HWY_API Vec256 InsertBlock(Vec256 v, Vec128 blk_to_insert) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + Vec256 result; + if (kBlockIdx == 0) { + result.v0 = blk_to_insert; + result.v1 = v.v1; + } else { + result.v0 = v.v0; + result.v1 = blk_to_insert; + } + return result; +} + +// ------------------------------ BroadcastBlock +template +HWY_API Vec256 BroadcastBlock(Vec256 v) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + Vec256 result; + result.v0 = result.v1 = (kBlockIdx == 0 ? v.v0 : v.v1); + return result; +} + +// ------------------------------ LowerHalf + +template > +HWY_API Vec128 LowerHalf(D /* tag */, Vec256 v) { + return v.v0; +} + +template +HWY_API Vec128 LowerHalf(Vec256 v) { + return v.v0; +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec256 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ShiftLeftBytes + +template > +HWY_API Vec256 ShiftLeftBytes(D d, Vec256 v) { + const Half dh; + v.v0 = ShiftLeftBytes(dh, v.v0); + v.v1 = ShiftLeftBytes(dh, v.v1); + return v; +} + +template +HWY_API Vec256 ShiftLeftBytes(Vec256 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API Vec256 ShiftLeftLanes(D d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template > +HWY_API Vec256 ShiftRightBytes(D d, Vec256 v) { + const Half dh; + v.v0 = ShiftRightBytes(dh, v.v0); + v.v1 = ShiftRightBytes(dh, v.v1); + return v; +} + +// ------------------------------ ShiftRightLanes +template > +HWY_API Vec256 ShiftRightLanes(D d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) +template > +HWY_API Vec128 UpperHalf(D /* tag */, const Vec256 v) { + return v.v1; +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API Vec256 CombineShiftRightBytes(D d, Vec256 hi, Vec256 lo) { + const Half dh; + hi.v0 = CombineShiftRightBytes(dh, hi.v0, lo.v0); + hi.v1 = CombineShiftRightBytes(dh, hi.v1, lo.v1); + return hi; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + Vec256 ret; + ret.v0 = Broadcast(v.v0); + ret.v1 = Broadcast(v.v1); + return ret; +} + +template +HWY_API Vec256 BroadcastLane(const Vec256 v) { + constexpr int 
kLanesPerBlock = static_cast(16 / sizeof(T)); + static_assert(0 <= kLane && kLane < kLanesPerBlock * 2, "Invalid lane"); + constexpr int kLaneInBlkIdx = kLane & (kLanesPerBlock - 1); + Vec256 ret; + ret.v0 = ret.v1 = + Broadcast(kLane >= kLanesPerBlock ? v.v1 : v.v0); + return ret; +} + +// ------------------------------ TableLookupBytes + +// Both full +template +HWY_API Vec256 TableLookupBytes(const Vec256 bytes, Vec256 from) { + from.v0 = TableLookupBytes(bytes.v0, from.v0); + from.v1 = TableLookupBytes(bytes.v1, from.v1); + return from; +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec256 bytes, + const Vec128 from) { + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(Full256(), Vec128{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128{LowerHalf(Full128(), tbl_full).raw}; +} + +// Partial table vector +template +HWY_API Vec256 TableLookupBytes(Vec128 bytes, const Vec256 from) { + // First expand to full 128, then 256. + const auto bytes_256 = ZeroExtendVector(Full256(), Vec128{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// Partial both are handled by wasm_128. + +template +HWY_API VI TableLookupBytesOr0(V bytes, VI from) { + // wasm out-of-bounds policy already zeros, so TableLookupBytes is fine. + return TableLookupBytes(bytes, from); +} + +// ------------------------------ Hard-coded shuffles + +template +HWY_API Vec256 Shuffle01(Vec256 v) { + v.v0 = Shuffle01(v.v0); + v.v1 = Shuffle01(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle2301(Vec256 v) { + v.v0 = Shuffle2301(v.v0); + v.v1 = Shuffle2301(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle1032(Vec256 v) { + v.v0 = Shuffle1032(v.v0); + v.v1 = Shuffle1032(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle0321(Vec256 v) { + v.v0 = Shuffle0321(v.v0); + v.v1 = Shuffle0321(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle2103(Vec256 v) { + v.v0 = Shuffle2103(v.v0); + v.v1 = Shuffle2103(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle0123(Vec256 v) { + v.v0 = Shuffle0123(v.v0); + v.v1 = Shuffle0123(v.v1); + return v; +} + +// Used by generic_ops-inl.h +namespace detail { + +template +HWY_API Vec256 ShuffleTwo2301(Vec256 a, const Vec256 b) { + a.v0 = ShuffleTwo2301(a.v0, b.v0); + a.v1 = ShuffleTwo2301(a.v1, b.v1); + return a; +} +template +HWY_API Vec256 ShuffleTwo1230(Vec256 a, const Vec256 b) { + a.v0 = ShuffleTwo1230(a.v0, b.v0); + a.v1 = ShuffleTwo1230(a.v1, b.v1); + return a; +} +template +HWY_API Vec256 ShuffleTwo3012(Vec256 a, const Vec256 b) { + a.v0 = ShuffleTwo3012(a.v0, b.v0); + a.v1 = ShuffleTwo3012(a.v1, b.v1); + return a; +} + +} // namespace detail + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. 
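// Scalar reference for the lane permute defined below (hypothetical helper,
// not a Highway API): TableLookupLanes computes out[i] = v[idx[i]] with
// idx[i] in [0, lanes), and TwoTablesLookupLanes extends the index range to
// [0, 2 * lanes), reading from the concatenation of two vectors.
#include <cstddef>
#include <cstdint>
template <typename T, size_t N>
inline void TableLookupLanesScalar(const T (&v)[N], const int32_t (&idx)[N],
                                   T (&out)[N]) {
  for (size_t i = 0; i < N; ++i) {
    out[i] = v[static_cast<size_t>(idx[i])];
  }
}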
+template +struct Indices256 { + __v128_u i0; + __v128_u i1; +}; + +template , typename TI> +HWY_API Indices256 IndicesFromVec(D /* tag */, Vec256 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); + Indices256 ret; + ret.i0 = vec.v0.raw; + ret.i1 = vec.v1.raw; + return ret; +} + +template +HWY_API Indices256> SetTableIndices(D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec256 TableLookupLanes(const Vec256 v, Indices256 idx) { + const DFromV d; + const Half dh; + const auto idx_i0 = IndicesFromVec(dh, Vec128{idx.i0}); + const auto idx_i1 = IndicesFromVec(dh, Vec128{idx.i1}); + + Vec256 result; + result.v0 = TwoTablesLookupLanes(v.v0, v.v1, idx_i0); + result.v1 = TwoTablesLookupLanes(v.v0, v.v1, idx_i1); + return result; +} + +template +HWY_API Vec256 TableLookupLanesOr0(Vec256 v, Indices256 idx) { + // The out of bounds behavior will already zero lanes. + return TableLookupLanesOr0(v, idx); +} + +template +HWY_API Vec256 TwoTablesLookupLanes(const Vec256 a, const Vec256 b, + Indices256 idx) { + const DFromV d; + const Half dh; + const RebindToUnsigned du; + using TU = MakeUnsigned; + constexpr size_t kLanesPerVect = 32 / sizeof(TU); + + Vec256 vi; + vi.v0 = Vec128{idx.i0}; + vi.v1 = Vec128{idx.i1}; + const auto vmod = vi & Set(du, TU{kLanesPerVect - 1}); + const auto is_lo = RebindMask(d, vi == vmod); + + const auto idx_i0 = IndicesFromVec(dh, vmod.v0); + const auto idx_i1 = IndicesFromVec(dh, vmod.v1); + + Vec256 result_lo; + Vec256 result_hi; + result_lo.v0 = TwoTablesLookupLanes(a.v0, a.v1, idx_i0); + result_lo.v1 = TwoTablesLookupLanes(a.v0, a.v1, idx_i1); + result_hi.v0 = TwoTablesLookupLanes(b.v0, b.v1, idx_i0); + result_hi.v1 = TwoTablesLookupLanes(b.v0, b.v1, idx_i1); + return IfThenElse(is_lo, result_lo, result_hi); +} + +// ------------------------------ Reverse +template > +HWY_API Vec256 Reverse(D d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v1 = Reverse(dh, v.v0); // note reversed v1 member order + ret.v0 = Reverse(dh, v.v1); + return ret; +} + +// ------------------------------ Reverse2 +template > +HWY_API Vec256 Reverse2(D d, Vec256 v) { + const Half dh; + v.v0 = Reverse2(dh, v.v0); + v.v1 = Reverse2(dh, v.v1); + return v; +} + +// ------------------------------ Reverse4 + +// Each block has only 2 lanes, so swap blocks and their lanes. +template , HWY_IF_T_SIZE(T, 8)> +HWY_API Vec256 Reverse4(D d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = Reverse2(dh, v.v1); // swapped + ret.v1 = Reverse2(dh, v.v0); + return ret; +} + +template , HWY_IF_NOT_T_SIZE(T, 8)> +HWY_API Vec256 Reverse4(D d, Vec256 v) { + const Half dh; + v.v0 = Reverse4(dh, v.v0); + v.v1 = Reverse4(dh, v.v1); + return v; +} + +// ------------------------------ Reverse8 + +template , HWY_IF_T_SIZE(T, 8)> +HWY_API Vec256 Reverse8(D /* tag */, Vec256 /* v */) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// Each block has only 4 lanes, so swap blocks and their lanes. 
+template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec256 Reverse8(D d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = Reverse4(dh, v.v1); // swapped + ret.v1 = Reverse4(dh, v.v0); + return ret; +} + +template , + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))> +HWY_API Vec256 Reverse8(D d, Vec256 v) { + const Half dh; + v.v0 = Reverse8(dh, v.v0); + v.v1 = Reverse8(dh, v.v1); + return v; +} + +// ------------------------------ InterleaveLower + +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + a.v0 = InterleaveLower(a.v0, b.v0); + a.v1 = InterleaveLower(a.v1, b.v1); + return a; +} + +// wasm_128 already defines a template with D, V, V args. + +// ------------------------------ InterleaveUpper (UpperHalf) + +template > +HWY_API Vec256 InterleaveUpper(D d, Vec256 a, Vec256 b) { + const Half dh; + a.v0 = InterleaveUpper(dh, a.v0, b.v0); + a.v1 = InterleaveUpper(dh, a.v1, b.v1); + return a; +} + +// ------------------------------ InterleaveWholeLower +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const Half dh; + VFromD ret; + ret.v0 = InterleaveLower(a.v0, b.v0); + ret.v1 = InterleaveUpper(dh, a.v0, b.v0); + return ret; +} + +// ------------------------------ InterleaveWholeUpper +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const Half dh; + VFromD ret; + ret.v0 = InterleaveLower(a.v1, b.v1); + ret.v1 = InterleaveUpper(dh, a.v1, b.v1); + return ret; +} + +// ------------------------------ ZipLower/ZipUpper defined in wasm_128 + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) +template > +HWY_API Vec256 Combine(D /* d */, Vec128 hi, Vec128 lo) { + Vec256 ret; + ret.v1 = hi; + ret.v0 = lo; + return ret; +} + +// ------------------------------ ZeroExtendVector (Combine) +template > +HWY_API Vec256 ZeroExtendVector(D d, Vec128 lo) { + const Half dh; + return Combine(d, Zero(dh), lo); +} + +// ------------------------------ ZeroExtendResizeBitCast + +namespace detail { + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag /* from_size_tag */, + hwy::SizeTag<32> /* to_size_tag */, DTo d_to, DFrom d_from, + VFromD v) { + const Half dh_to; + return ZeroExtendVector(d_to, ZeroExtendResizeBitCast(dh_to, d_from, v)); +} + +} // namespace detail + +// ------------------------------ ConcatLowerLower +template > +HWY_API Vec256 ConcatLowerLower(D /* tag */, Vec256 hi, Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v0; + ret.v0 = lo.v0; + return ret; +} + +// ------------------------------ ConcatUpperUpper +template > +HWY_API Vec256 ConcatUpperUpper(D /* tag */, Vec256 hi, Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v1; + ret.v0 = lo.v1; + return ret; +} + +// ------------------------------ ConcatLowerUpper +template > +HWY_API Vec256 ConcatLowerUpper(D /* tag */, Vec256 hi, Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v0; + ret.v0 = lo.v1; + return ret; +} + +// ------------------------------ ConcatUpperLower +template > +HWY_API Vec256 ConcatUpperLower(D /* tag */, Vec256 hi, Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v1; + ret.v0 = lo.v0; + return ret; +} + +// ------------------------------ ConcatOdd +template > +HWY_API Vec256 ConcatOdd(D d, Vec256 hi, Vec256 lo) { + const Half dh; + Vec256 ret; + ret.v0 = ConcatOdd(dh, lo.v1, lo.v0); + ret.v1 = ConcatOdd(dh, hi.v1, hi.v0); + return ret; +} + +// ------------------------------ ConcatEven +template > +HWY_API Vec256 ConcatEven(D d, Vec256 hi, Vec256 lo) { + const Half dh; + Vec256 ret; + ret.v0 = 
ConcatEven(dh, lo.v1, lo.v0); + ret.v1 = ConcatEven(dh, hi.v1, hi.v0); + return ret; +} + +// ------------------------------ DupEven +template +HWY_API Vec256 DupEven(Vec256 v) { + v.v0 = DupEven(v.v0); + v.v1 = DupEven(v.v1); + return v; +} + +// ------------------------------ DupOdd +template +HWY_API Vec256 DupOdd(Vec256 v) { + v.v0 = DupOdd(v.v0); + v.v1 = DupOdd(v.v1); + return v; +} + +// ------------------------------ OddEven +template +HWY_API Vec256 OddEven(Vec256 a, const Vec256 b) { + a.v0 = OddEven(a.v0, b.v0); + a.v1 = OddEven(a.v1, b.v1); + return a; +} + +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Half dh; + a.v0 = InterleaveEven(dh, a.v0, b.v0); + a.v1 = InterleaveEven(dh, a.v1, b.v1); + return a; +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Half dh; + a.v0 = InterleaveOdd(dh, a.v0, b.v0); + a.v1 = InterleaveOdd(dh, a.v1, b.v1); + return a; +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + odd.v0 = even.v0; + return odd; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + Vec256 ret; + ret.v0 = v.v1; // swapped order + ret.v1 = v.v0; + return ret; +} + +// ------------------------------ InterleaveEvenBlocks +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveEvenBlocks(D, V a, V b) { + V ret; + ret.v0 = a.v0; + ret.v1 = b.v0; + return ret; +} +// ------------------------------ InterleaveOddBlocks +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveOddBlocks(D, V a, V b) { + V ret; + ret.v0 = a.v1; + ret.v1 = b.v1; + return ret; +} + +// ------------------------------ ReverseBlocks +template > +HWY_API Vec256 ReverseBlocks(D /* tag */, const Vec256 v) { + return SwapAdjacentBlocks(v); // 2 blocks, so Swap = Reverse +} + +// ------------------------------ Per4LaneBlockShuffle +namespace detail { + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + const Half dh; + using VH = VFromD; + + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + + V ret; + ret.v0 = VH{wasm_i8x16_shuffle( + v.v0.raw, v.v0.raw, kIdx0, kIdx1, kIdx2, kIdx3, kIdx0 + 4, kIdx1 + 4, + kIdx2 + 4, kIdx3 + 4, kIdx0 + 8, kIdx1 + 8, kIdx2 + 8, kIdx3 + 8, + kIdx0 + 12, kIdx1 + 12, kIdx2 + 12, kIdx3 + 12)}; + ret.v1 = VH{wasm_i8x16_shuffle( + v.v1.raw, v.v1.raw, kIdx0, kIdx1, kIdx2, kIdx3, kIdx0 + 4, kIdx1 + 4, + kIdx2 + 4, kIdx3 + 4, kIdx0 + 8, kIdx1 + 8, kIdx2 + 8, kIdx3 + 8, + kIdx0 + 12, kIdx1 + 12, kIdx2 + 12, kIdx3 + 12)}; + return ret; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + const Half dh; + using VH = VFromD; + + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + + V ret; + ret.v0 = VH{wasm_i16x8_shuffle(v.v0.raw, v.v0.raw, kIdx0, kIdx1, kIdx2, kIdx3, + kIdx0 + 4, kIdx1 + 4, kIdx2 + 4, kIdx3 + 
4)}; + ret.v1 = VH{wasm_i16x8_shuffle(v.v1.raw, v.v1.raw, kIdx0, kIdx1, kIdx2, kIdx3, + kIdx0 + 4, kIdx1 + 4, kIdx2 + 4, kIdx3 + 4)}; + return ret; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + const Half dh; + using VH = VFromD; + + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + + V ret; + ret.v0 = + VH{wasm_i32x4_shuffle(v.v0.raw, v.v0.raw, kIdx0, kIdx1, kIdx2, kIdx3)}; + ret.v1 = + VH{wasm_i32x4_shuffle(v.v1.raw, v.v1.raw, kIdx0, kIdx1, kIdx2, kIdx3)}; + return ret; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + const Half dh; + using VH = VFromD; + + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + + V ret; + ret.v0 = VH{wasm_i64x2_shuffle(v.v0.raw, v.v1.raw, kIdx0, kIdx1)}; + ret.v1 = VH{wasm_i64x2_shuffle(v.v0.raw, v.v1.raw, kIdx2, kIdx3)}; + return ret; +} + +} // namespace detail + +// ------------------------------ SlideUpBlocks +template +HWY_API VFromD SlideUpBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 1, + "kBlocks must be between 0 and 1"); + return (kBlocks == 1) ? ConcatLowerLower(d, v, Zero(d)) : v; +} + +// ------------------------------ SlideDownBlocks +template +HWY_API VFromD SlideDownBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 1, + "kBlocks must be between 0 and 1"); + const Half dh; + return (kBlocks == 1) ? 
ZeroExtendVector(d, UpperHalf(dh, v)) : v; +} + +// ------------------------------ SlideUpLanes + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { + const Half dh; + const RebindToUnsigned du; + const RebindToUnsigned dh_u; + const auto vu = BitCast(du, v); + VFromD ret; + +#if !HWY_IS_DEBUG_BUILD + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + if (__builtin_constant_p(amt) && amt < kLanesPerBlock) { + switch (amt * sizeof(TFromD)) { + case 0: + return v; + case 1: + ret.v0 = BitCast(dh, ShiftLeftBytes<1>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<15>(dh_u, vu.v1, vu.v0)); + return ret; + case 2: + ret.v0 = BitCast(dh, ShiftLeftBytes<2>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<14>(dh_u, vu.v1, vu.v0)); + return ret; + case 3: + ret.v0 = BitCast(dh, ShiftLeftBytes<3>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<13>(dh_u, vu.v1, vu.v0)); + return ret; + case 4: + ret.v0 = BitCast(dh, ShiftLeftBytes<4>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<12>(dh_u, vu.v1, vu.v0)); + return ret; + case 5: + ret.v0 = BitCast(dh, ShiftLeftBytes<5>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<11>(dh_u, vu.v1, vu.v0)); + return ret; + case 6: + ret.v0 = BitCast(dh, ShiftLeftBytes<6>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<10>(dh_u, vu.v1, vu.v0)); + return ret; + case 7: + ret.v0 = BitCast(dh, ShiftLeftBytes<7>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<9>(dh_u, vu.v1, vu.v0)); + return ret; + case 8: + ret.v0 = BitCast(dh, ShiftLeftBytes<8>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<8>(dh_u, vu.v1, vu.v0)); + return ret; + case 9: + ret.v0 = BitCast(dh, ShiftLeftBytes<9>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<7>(dh_u, vu.v1, vu.v0)); + return ret; + case 10: + ret.v0 = BitCast(dh, ShiftLeftBytes<10>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<6>(dh_u, vu.v1, vu.v0)); + return ret; + case 11: + ret.v0 = BitCast(dh, ShiftLeftBytes<11>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<5>(dh_u, vu.v1, vu.v0)); + return ret; + case 12: + ret.v0 = BitCast(dh, ShiftLeftBytes<12>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<4>(dh_u, vu.v1, vu.v0)); + return ret; + case 13: + ret.v0 = BitCast(dh, ShiftLeftBytes<13>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<3>(dh_u, vu.v1, vu.v0)); + return ret; + case 14: + ret.v0 = BitCast(dh, ShiftLeftBytes<14>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<2>(dh_u, vu.v1, vu.v0)); + return ret; + case 15: + ret.v0 = BitCast(dh, ShiftLeftBytes<15>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<1>(dh_u, vu.v1, vu.v0)); + return ret; + } + } + + if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { + ret.v0 = Zero(dh); + ret.v1 = SlideUpLanes(dh, LowerHalf(dh, v), amt - kLanesPerBlock); + return ret; + } +#endif + + const Repartition du8; + const RebindToSigned di8; + const Half dh_i8; + + const auto lo_byte_idx = BitCast( + di8, + Iota(du8, static_cast(size_t{0} - amt * sizeof(TFromD)))); + + const auto hi_byte_idx = + UpperHalf(dh_i8, lo_byte_idx) - Set(dh_i8, int8_t{16}); + const auto hi_sel_mask = + UpperHalf(dh_i8, lo_byte_idx) > Set(dh_i8, int8_t{15}); + + ret = BitCast(d, + TableLookupBytesOr0(ConcatLowerLower(du, vu, vu), lo_byte_idx)); + ret.v1 = + BitCast(dh, IfThenElse(hi_sel_mask, + TableLookupBytes(UpperHalf(dh_u, vu), hi_byte_idx), + BitCast(dh_i8, ret.v1))); + 
return ret; +} + +// ------------------------------ Slide1Up +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + VFromD ret; + const Half dh; + constexpr int kShrByteAmt = static_cast(16 - sizeof(TFromD)); + ret.v0 = ShiftLeftLanes<1>(dh, v.v0); + ret.v1 = CombineShiftRightBytes(dh, v.v1, v.v0); + return ret; +} + +// ------------------------------ SlideDownLanes + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + const Half dh; + const RebindToUnsigned du; + const RebindToUnsigned dh_u; + VFromD ret; + + const auto vu = BitCast(du, v); + +#if !HWY_IS_DEBUG_BUILD + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + if (__builtin_constant_p(amt) && amt < kLanesPerBlock) { + switch (amt * sizeof(TFromD)) { + case 0: + return v; + case 1: + ret.v0 = BitCast(dh, CombineShiftRightBytes<1>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<1>(dh_u, vu.v1)); + return ret; + case 2: + ret.v0 = BitCast(dh, CombineShiftRightBytes<2>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<2>(dh_u, vu.v1)); + return ret; + case 3: + ret.v0 = BitCast(dh, CombineShiftRightBytes<3>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<3>(dh_u, vu.v1)); + return ret; + case 4: + ret.v0 = BitCast(dh, CombineShiftRightBytes<4>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<4>(dh_u, vu.v1)); + return ret; + case 5: + ret.v0 = BitCast(dh, CombineShiftRightBytes<5>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<5>(dh_u, vu.v1)); + return ret; + case 6: + ret.v0 = BitCast(dh, CombineShiftRightBytes<6>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<6>(dh_u, vu.v1)); + return ret; + case 7: + ret.v0 = BitCast(dh, CombineShiftRightBytes<7>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<7>(dh_u, vu.v1)); + return ret; + case 8: + ret.v0 = BitCast(dh, CombineShiftRightBytes<8>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<8>(dh_u, vu.v1)); + return ret; + case 9: + ret.v0 = BitCast(dh, CombineShiftRightBytes<9>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<9>(dh_u, vu.v1)); + return ret; + case 10: + ret.v0 = BitCast(dh, CombineShiftRightBytes<10>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<10>(dh_u, vu.v1)); + return ret; + case 11: + ret.v0 = BitCast(dh, CombineShiftRightBytes<11>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<11>(dh_u, vu.v1)); + return ret; + case 12: + ret.v0 = BitCast(dh, CombineShiftRightBytes<12>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<12>(dh_u, vu.v1)); + return ret; + case 13: + ret.v0 = BitCast(dh, CombineShiftRightBytes<13>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<13>(dh_u, vu.v1)); + return ret; + case 14: + ret.v0 = BitCast(dh, CombineShiftRightBytes<14>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<14>(dh_u, vu.v1)); + return ret; + case 15: + ret.v0 = BitCast(dh, CombineShiftRightBytes<15>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<15>(dh_u, vu.v1)); + return ret; + } + } + + if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { + ret.v0 = SlideDownLanes(dh, UpperHalf(dh, v), amt - kLanesPerBlock); + ret.v1 = Zero(dh); + return ret; + } +#endif + + const Repartition du8; + const Half dh_u8; + + const auto lo_byte_idx = + Iota(du8, static_cast(amt * sizeof(TFromD))); + const auto u8_16 = Set(du8, uint8_t{16}); + const auto hi_byte_idx = lo_byte_idx - u8_16; + + const auto lo_sel_mask = + LowerHalf(dh_u8, lo_byte_idx) < 
LowerHalf(dh_u8, u8_16); + ret = BitCast(d, IfThenElseZero(hi_byte_idx < u8_16, + TableLookupBytes(ConcatUpperUpper(du, vu, vu), + hi_byte_idx))); + ret.v0 = + BitCast(dh, IfThenElse(lo_sel_mask, + TableLookupBytes(LowerHalf(dh_u, vu), + LowerHalf(dh_u8, lo_byte_idx)), + BitCast(dh_u8, LowerHalf(dh, ret)))); + return ret; +} + +// ------------------------------ Slide1Down +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + VFromD ret; + const Half dh; + constexpr int kShrByteAmt = static_cast(sizeof(TFromD)); + ret.v0 = CombineShiftRightBytes(dh, v.v1, v.v0); + ret.v1 = ShiftRightBytes(dh, v.v1); + return ret; +} + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +template +HWY_API VFromD PromoteTo(D d, Vec128 v) { + const Half dh; + VFromD ret; + // PromoteLowerTo is defined later in generic_ops-inl.h. + ret.v0 = PromoteTo(dh, LowerHalf(v)); + ret.v1 = PromoteUpperTo(dh, v); + return ret; +} + +// 4x promotion: 8-bit to 32-bit or 16-bit to 64-bit +template +HWY_API Vec256> PromoteTo(DW d, Vec64 v) { + const Half dh; + // 16-bit lanes for UI8->UI32, 32-bit lanes for UI16->UI64 + const Rebind, decltype(d)> d2; + const auto v_2x = PromoteTo(d2, v); + Vec256> ret; + // PromoteLowerTo is defined later in generic_ops-inl.h. + ret.v0 = PromoteTo(dh, LowerHalf(v_2x)); + ret.v1 = PromoteUpperTo(dh, v_2x); + return ret; +} + +// 8x promotion: 8-bit to 64-bit +template +HWY_API Vec256> PromoteTo(DW d, Vec32 v) { + const Half dh; + const Repartition>, decltype(dh)> d4; // 32-bit lanes + const auto v32 = PromoteTo(d4, v); + Vec256> ret; + // PromoteLowerTo is defined later in generic_ops-inl.h. + ret.v0 = PromoteTo(dh, LowerHalf(v32)); + ret.v1 = PromoteUpperTo(dh, v32); + return ret; +} + +// ------------------------------ PromoteUpperTo + +// Not native, but still define this here because wasm_128 toggles +// HWY_NATIVE_PROMOTE_UPPER_TO. +template +HWY_API VFromD PromoteUpperTo(D d, Vec256 v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type + // from v because it cannot be deduced from D (could be either bf16 or f16). 
+ const Rebind dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +// ------------------------------ DemoteTo + +template +HWY_API Vec128 DemoteTo(D /* tag */, Vec256 v) { + return Vec128{wasm_u16x8_narrow_i32x4(v.v0.raw, v.v1.raw)}; +} + +template +HWY_API Vec128 DemoteTo(D /* tag */, Vec256 v) { + return Vec128{wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw)}; +} + +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw); + return Vec64{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API Vec128 DemoteTo(D /* tag */, Vec256 v) { + return Vec128{wasm_u8x16_narrow_i16x8(v.v0.raw, v.v1.raw)}; +} + +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw); + return Vec64{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API Vec128 DemoteTo(D /* tag */, Vec256 v) { + return Vec128{wasm_i8x16_narrow_i16x8(v.v0.raw, v.v1.raw)}; +} + +template +HWY_API Vec128 DemoteTo(D di, Vec256 v) { + const Vec64 lo{wasm_i32x4_trunc_sat_f64x2_zero(v.v0.raw)}; + const Vec64 hi{wasm_i32x4_trunc_sat_f64x2_zero(v.v1.raw)}; + return Combine(di, hi, lo); +} + +template +HWY_API Vec128 DemoteTo(D di, Vec256 v) { + const Vec64 lo{wasm_u32x4_trunc_sat_f64x2_zero(v.v0.raw)}; + const Vec64 hi{wasm_u32x4_trunc_sat_f64x2_zero(v.v1.raw)}; + return Combine(di, hi, lo); +} + +template +HWY_API Vec128 DemoteTo(D df, Vec256 v) { + const Vec64 lo = DemoteTo(Full64(), v.v0); + const Vec64 hi = DemoteTo(Full64(), v.v1); + return Combine(df, hi, lo); +} + +template +HWY_API Vec128 DemoteTo(D df, Vec256 v) { + const Vec64 lo = DemoteTo(Full64(), v.v0); + const Vec64 hi = DemoteTo(Full64(), v.v1); + return Combine(df, hi, lo); +} + +template +HWY_API Vec128 DemoteTo(D d16, Vec256 v) { + const Half d16h; + const Vec64 lo = DemoteTo(d16h, v.v0); + const Vec64 hi = DemoteTo(d16h, v.v1); + return Combine(d16, hi, lo); +} + +template +HWY_API Vec128 DemoteTo(D df32, Vec256 v) { + const Half df32h; + const Vec64 lo = DemoteTo(df32h, v.v0); + const Vec64 hi = DemoteTo(df32h, v.v1); + return Combine(df32, hi, lo); +} + +// For already range-limited input [0, 255]. 
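// A scalar sketch of the saturating-narrow behavior the DemoteTo overloads
// above rely on (e.g. wasm_i16x8_narrow_i32x4 / wasm_u16x8_narrow_i32x4):
// each 32-bit lane is clamped to the target range before truncation.
// Plain C++; the helper names are illustrative, not Highway or wasm APIs.
#include <algorithm>
#include <cassert>
#include <cstdint>

static int16_t NarrowI32ToI16(int32_t x) {
  return static_cast<int16_t>(std::clamp<int32_t>(x, INT16_MIN, INT16_MAX));
}

static uint16_t NarrowI32ToU16(int32_t x) {
  return static_cast<uint16_t>(std::clamp<int32_t>(x, 0, UINT16_MAX));
}

int main() {
  assert(NarrowI32ToI16(100000) == INT16_MAX);   // saturates high
  assert(NarrowI32ToI16(-100000) == INT16_MIN);  // saturates low
  assert(NarrowI32ToU16(-5) == 0);               // unsigned clamps negatives
  assert(NarrowI32ToU16(70000) == UINT16_MAX);
  return 0;
}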
+HWY_API Vec64 U8FromU32(Vec256 v) { + const Full64 du8; + const Full256 di32; // no unsigned DemoteTo + return DemoteTo(du8, BitCast(di32, v)); +} + +// ------------------------------ Truncations + +template +HWY_API Vec32 TruncateTo(D /* tag */, Vec256 v) { + return Vec32{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 8, 16, 24, 0, + 8, 16, 24, 0, 8, 16, 24, 0, 8, 16, + 24)}; +} + +template +HWY_API Vec64 TruncateTo(D /* tag */, Vec256 v) { + return Vec64{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 8, 9, 16, + 17, 24, 25, 0, 1, 8, 9, 16, 17, 24, + 25)}; +} + +template +HWY_API Vec128 TruncateTo(D /* tag */, Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 2, 3, 8, + 9, 10, 11, 16, 17, 18, 19, 24, 25, + 26, 27)}; +} + +template +HWY_API Vec64 TruncateTo(D /* tag */, Vec256 v) { + return Vec64{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 4, 8, 12, 16, + 20, 24, 28, 0, 4, 8, 12, 16, 20, 24, + 28)}; +} + +template +HWY_API Vec128 TruncateTo(D /* tag */, Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 4, 5, 8, + 9, 12, 13, 16, 17, 20, 21, 24, 25, + 28, 29)}; +} + +template +HWY_API Vec128 TruncateTo(D /* tag */, Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 2, 4, 6, 8, + 10, 12, 14, 16, 18, 20, 22, 24, 26, + 28, 30)}; +} + +// ------------------------------ ReorderDemote2To +template ), HWY_IF_SIGNED_V(V), + HWY_IF_T_SIZE_ONE_OF_D(DN, (1 << 1) | (1 << 2) | (1 << 4)), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Half dnh; + VFromD demoted; + demoted.v0 = DemoteTo(dnh, a); + demoted.v1 = DemoteTo(dnh, b); + return demoted; +} + +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Half dnh; + VFromD demoted; + demoted.v0 = DemoteTo(dnh, a); + demoted.v1 = DemoteTo(dnh, b); + return demoted; +} + +// ------------------------------ Convert i32 <=> f32 (Round) + +template > +HWY_API Vec256 ConvertTo(DTo d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = ConvertTo(dh, v.v0); + ret.v1 = ConvertTo(dh, v.v1); + return ret; +} + +template +HWY_API Vec256> NearestInt(const Vec256 v) { + return ConvertTo(Full256>(), Round(v)); +} + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + const Half dh; + MFromD ret; + ret.m0 = LoadMaskBits(dh, bits); + // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8. + // Both halves fit in one byte's worth of mask bits. + constexpr size_t kBitsPerHalf = 16 / sizeof(TFromD); + const uint8_t bits_upper[8] = {static_cast(bits[0] >> kBitsPerHalf)}; + ret.m1 = LoadMaskBits(dh, bits_upper); + return ret; +} + +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + const Half dh; + MFromD ret; + ret.m0 = LoadMaskBits(dh, bits); + constexpr size_t kLanesPerHalf = 16 / sizeof(TFromD); + constexpr size_t kBytesPerHalf = kLanesPerHalf / 8; + static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes"); + ret.m1 = LoadMaskBits(dh, bits + kBytesPerHalf); + return ret; +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const Half dh; + MFromD ret; + ret.m0 = ret.m1 = Dup128MaskFromMaskBits(dh, mask_bits); + return ret; +} + +// ------------------------------ Mask + +// `p` points to at least 8 writable bytes. 
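// A small scalar model of the mask-bit layout assumed by LoadMaskBits above
// and StoreMaskBits below for 64-bit lanes (4 lanes per 256-bit vector): the
// low kBitsPerHalf bits of bits[0] belong to the lower 128-bit half and the
// next kBitsPerHalf bits to the upper half. Plain C++; names are illustrative.
#include <cassert>
#include <cstdint>

int main() {
  constexpr unsigned kBitsPerHalf = 2;               // 16 / sizeof(uint64_t)
  const bool lanes[4] = {true, false, false, true};  // lanes 0..3

  // Pack: lane i contributes bit i of bits0.
  uint8_t bits0 = 0;
  for (unsigned i = 0; i < 4; ++i) {
    bits0 = static_cast<uint8_t>(bits0 | (static_cast<uint8_t>(lanes[i]) << i));
  }
  assert(bits0 == 0x9);

  // Unpack as the 256-bit LoadMaskBits does: the lower half reads bits0
  // directly, the upper half reads bits0 >> kBitsPerHalf.
  const uint8_t lower_half_bits = bits0 & ((1u << kBitsPerHalf) - 1);
  const uint8_t upper_half_bits = static_cast<uint8_t>(bits0 >> kBitsPerHalf);
  assert(lower_half_bits == 0x1);  // lane 0 set
  assert(upper_half_bits == 0x2);  // lane 3 = bit 1 of the upper half
  return 0;
}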
+template , + HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8))> +HWY_API size_t StoreMaskBits(D d, const Mask256 mask, uint8_t* bits) { + const Half dh; + StoreMaskBits(dh, mask.m0, bits); + const uint8_t lo = bits[0]; + StoreMaskBits(dh, mask.m1, bits); + // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8. + // Both halves fit in one byte's worth of mask bits. + constexpr size_t kBitsPerHalf = 16 / sizeof(T); + bits[0] = static_cast(lo | (bits[0] << kBitsPerHalf)); + return (kBitsPerHalf * 2 + 7) / 8; +} + +template , + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))> +HWY_API size_t StoreMaskBits(D d, const Mask256 mask, uint8_t* bits) { + const Half dh; + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + constexpr size_t kBytesPerHalf = kLanesPerHalf / 8; + static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes"); + StoreMaskBits(dh, mask.m0, bits); + StoreMaskBits(dh, mask.m1, bits + kBytesPerHalf); + return kBytesPerHalf * 2; +} + +template > +HWY_API size_t CountTrue(D d, const Mask256 m) { + const Half dh; + return CountTrue(dh, m.m0) + CountTrue(dh, m.m1); +} + +template > +HWY_API bool AllFalse(D d, const Mask256 m) { + const Half dh; + return AllFalse(dh, m.m0) && AllFalse(dh, m.m1); +} + +template > +HWY_API bool AllTrue(D d, const Mask256 m) { + const Half dh; + return AllTrue(dh, m.m0) && AllTrue(dh, m.m1); +} + +template > +HWY_API size_t FindKnownFirstTrue(D d, const Mask256 mask) { + const Half dh; + const intptr_t lo = FindFirstTrue(dh, mask.m0); // not known + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + return lo >= 0 ? static_cast(lo) + : kLanesPerHalf + FindKnownFirstTrue(dh, mask.m1); +} + +template > +HWY_API intptr_t FindFirstTrue(D d, const Mask256 mask) { + const Half dh; + const intptr_t lo = FindFirstTrue(dh, mask.m0); + constexpr int kLanesPerHalf = 16 / sizeof(T); + if (lo >= 0) return lo; + + const intptr_t hi = FindFirstTrue(dh, mask.m1); + return hi + (hi >= 0 ? kLanesPerHalf : 0); +} + +template > +HWY_API size_t FindKnownLastTrue(D d, const Mask256 mask) { + const Half dh; + const intptr_t hi = FindLastTrue(dh, mask.m1); // not known + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + return hi >= 0 ? kLanesPerHalf + static_cast(hi) + : FindKnownLastTrue(dh, mask.m0); +} + +template > +HWY_API intptr_t FindLastTrue(D d, const Mask256 mask) { + const Half dh; + constexpr int kLanesPerHalf = 16 / sizeof(T); + const intptr_t hi = FindLastTrue(dh, mask.m1); + return hi >= 0 ? 
kLanesPerHalf + hi : FindLastTrue(dh, mask.m0); +} + +// ------------------------------ CompressStore +template > +HWY_API size_t CompressStore(Vec256 v, const Mask256 mask, D d, + T* HWY_RESTRICT unaligned) { + const Half dh; + const size_t count = CompressStore(v.v0, mask.m0, dh, unaligned); + const size_t count2 = CompressStore(v.v1, mask.m1, dh, unaligned + count); + return count + count2; +} + +// ------------------------------ CompressBlendedStore +template > +HWY_API size_t CompressBlendedStore(Vec256 v, const Mask256 m, D d, + T* HWY_RESTRICT unaligned) { + const Half dh; + const size_t count = CompressBlendedStore(v.v0, m.m0, dh, unaligned); + const size_t count2 = CompressBlendedStore(v.v1, m.m1, dh, unaligned + count); + return count + count2; +} + +// ------------------------------ CompressBitsStore + +template > +HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, + D d, T* HWY_RESTRICT unaligned) { + const Mask256 m = LoadMaskBits(d, bits); + return CompressStore(v, m, d, unaligned); +} + +// ------------------------------ Compress +template +HWY_API Vec256 Compress(const Vec256 v, const Mask256 mask) { + const DFromV d; + alignas(32) T lanes[32 / sizeof(T)] = {}; + (void)CompressStore(v, mask, d, lanes); + return Load(d, lanes); +} + +// ------------------------------ CompressNot +template +HWY_API Vec256 CompressNot(Vec256 v, const Mask256 mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + const Full128 dh; + // Because the non-selected (mask=1) blocks are undefined, we can return the + // input unless mask = 01, in which case we must bring down the upper block. + return AllTrue(dh, AndNot(mask.m1, mask.m0)) ? SwapAdjacentBlocks(v) : v; +} + +// ------------------------------ CompressBits +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + const Mask256 m = LoadMaskBits(DFromV(), bits); + return Compress(v, m); +} + +// ------------------------------ Expand +template +HWY_API Vec256 Expand(const Vec256 v, const Mask256 mask) { + Vec256 ret; + const Full256 d; + const Half dh; + alignas(32) T lanes[32 / sizeof(T)] = {}; + Store(v, d, lanes); + ret.v0 = Expand(v.v0, mask.m0); + ret.v1 = Expand(LoadU(dh, lanes + CountTrue(dh, mask.m0)), mask.m1); + return ret; +} + +// ------------------------------ LoadExpand +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. 
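// A scalar model of the Compress/CompressStore semantics implemented above:
// mask-selected lanes are packed to the front of the output and the return
// value is how many lanes were written. Plain C++; the names are illustrative
// only, not Highway APIs.
#include <cassert>
#include <cstddef>
#include <vector>

template <typename T>
size_t CompressStoreScalar(const std::vector<T>& v, const std::vector<bool>& m,
                           T* out) {
  size_t count = 0;
  for (size_t i = 0; i < v.size(); ++i) {
    if (m[i]) out[count++] = v[i];  // selected lanes are packed contiguously
  }
  return count;
}

int main() {
  const std::vector<int> v{10, 20, 30, 40};
  const std::vector<bool> m{true, false, true, true};
  int out[4] = {0, 0, 0, 0};
  const size_t n = CompressStoreScalar(v, m, out);
  assert(n == 3);
  assert(out[0] == 10 && out[1] == 30 && out[2] == 40);
  return 0;
}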
+ +namespace detail { + +// Input: +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template > +HWY_API void LoadTransposedBlocks3(D d, const T* HWY_RESTRICT unaligned, + Vec256& A, Vec256& B, Vec256& C) { + const Vec256 v10 = LoadU(d, unaligned + 0 * MaxLanes(d)); + const Vec256 v32 = LoadU(d, unaligned + 1 * MaxLanes(d)); + const Vec256 v54 = LoadU(d, unaligned + 2 * MaxLanes(d)); + + A = ConcatUpperLower(d, v32, v10); + B = ConcatLowerUpper(d, v54, v10); + C = ConcatUpperLower(d, v54, v32); +} + +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of A) +// 5 1 +// 6 2 +// 7 3 +template > +HWY_API void LoadTransposedBlocks4(D d, const T* HWY_RESTRICT unaligned, + Vec256& vA, Vec256& vB, Vec256& vC, + Vec256& vD) { + const Vec256 v10 = LoadU(d, unaligned + 0 * MaxLanes(d)); + const Vec256 v32 = LoadU(d, unaligned + 1 * MaxLanes(d)); + const Vec256 v54 = LoadU(d, unaligned + 2 * MaxLanes(d)); + const Vec256 v76 = LoadU(d, unaligned + 3 * MaxLanes(d)); + + vA = ConcatLowerLower(d, v54, v10); + vB = ConcatUpperUpper(d, v54, v10); + vC = ConcatLowerLower(d, v76, v32); + vD = ConcatUpperUpper(d, v76, v32); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. + +namespace detail { + +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template > +HWY_API void StoreTransposedBlocks2(Vec256 i, Vec256 j, D d, + T* HWY_RESTRICT unaligned) { + const Vec256 out0 = ConcatLowerLower(d, j, i); + const Vec256 out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * MaxLanes(d)); + StoreU(out1, d, unaligned + 1 * MaxLanes(d)); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template > +HWY_API void StoreTransposedBlocks3(Vec256 i, Vec256 j, Vec256 k, D d, + T* HWY_RESTRICT unaligned) { + const Vec256 out0 = ConcatLowerLower(d, j, i); + const Vec256 out1 = ConcatUpperLower(d, i, k); + const Vec256 out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * MaxLanes(d)); + StoreU(out1, d, unaligned + 1 * MaxLanes(d)); + StoreU(out2, d, unaligned + 2 * MaxLanes(d)); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template > +HWY_API void StoreTransposedBlocks4(Vec256 i, Vec256 j, Vec256 k, + Vec256 l, D d, + T* HWY_RESTRICT unaligned) { + // Write lower halves, then upper. 
+ const Vec256 out0 = ConcatLowerLower(d, j, i); + const Vec256 out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * MaxLanes(d)); + StoreU(out1, d, unaligned + 1 * MaxLanes(d)); + const Vec256 out2 = ConcatUpperUpper(d, j, i); + const Vec256 out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * MaxLanes(d)); + StoreU(out3, d, unaligned + 3 * MaxLanes(d)); +} + +} // namespace detail + +// ------------------------------ Additional mask logical operations + +template +HWY_API Mask256 SetAtOrAfterFirst(Mask256 mask) { + const Full256 d; + const Half dh; + const Repartition dh_i64; + + Mask256 result; + result.m0 = SetAtOrAfterFirst(mask.m0); + result.m1 = SetAtOrAfterFirst(mask.m1); + + // Copy the sign bit of the lower 128-bit half to the upper 128-bit half + const auto vmask_lo = BitCast(dh_i64, VecFromMask(dh, result.m0)); + result.m1 = + Or(result.m1, MaskFromVec(BitCast(dh, BroadcastSignBit(InterleaveUpper( + dh_i64, vmask_lo, vmask_lo))))); + + return result; +} + +template +HWY_API Mask256 SetBeforeFirst(Mask256 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask256 SetOnlyFirst(Mask256 mask) { + const Full256 d; + const RebindToSigned di; + const Repartition di64; + const Half dh_i64; + + const auto zero = Zero(di64); + const auto vmask = BitCast(di64, VecFromMask(d, mask)); + + const auto vmask_eq_0 = VecFromMask(di64, vmask == zero); + auto vmask2_lo = LowerHalf(dh_i64, vmask_eq_0); + auto vmask2_hi = UpperHalf(dh_i64, vmask_eq_0); + + vmask2_lo = And(vmask2_lo, InterleaveLower(vmask2_lo, vmask2_lo)); + vmask2_hi = And(ConcatLowerUpper(dh_i64, vmask2_hi, vmask2_lo), + InterleaveUpper(dh_i64, vmask2_lo, vmask2_lo)); + vmask2_lo = InterleaveLower(Set(dh_i64, int64_t{-1}), vmask2_lo); + + const auto vmask2 = Combine(di64, vmask2_hi, vmask2_lo); + const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); + return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2)))); +} + +template +HWY_API Mask256 SetAtOrBeforeFirst(Mask256 mask) { + const Full256 d; + constexpr size_t kLanesPerBlock = MaxLanes(d) / 2; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_lo = ConcatLowerLower(d, vmask, Zero(d)); + return SetBeforeFirst( + MaskFromVec(CombineShiftRightBytes<(kLanesPerBlock - 1) * sizeof(T)>( + d, vmask, vmask_lo))); +} + +// ------------------------------ WidenMulPairwiseAdd +template > +HWY_API Vec256 WidenMulPairwiseAdd(D32 d32, Vec256 a, Vec256 b) { + const Half d32h; + Vec256 result; + result.v0 = WidenMulPairwiseAdd(d32h, a.v0, b.v0); + result.v1 = WidenMulPairwiseAdd(d32h, a.v1, b.v1); + return result; +} + +// ------------------------------ ReorderWidenMulAccumulate +template > +HWY_API Vec256 ReorderWidenMulAccumulate(D32 d32, Vec256 a, + Vec256 b, Vec256 sum0, + Vec256& sum1) { + const Half d32h; + sum0.v0 = ReorderWidenMulAccumulate(d32h, a.v0, b.v0, sum0.v0, sum1.v0); + sum0.v1 = ReorderWidenMulAccumulate(d32h, a.v1, b.v1, sum0.v1, sum1.v1); + return sum0; +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec256 RearrangeToOddPlusEven(Vec256 sum0, Vec256 sum1) { + sum0.v0 = RearrangeToOddPlusEven(sum0.v0, sum1.v0); + sum0.v1 = RearrangeToOddPlusEven(sum0.v1, sum1.v1); + return sum0; +} + +// ------------------------------ Reductions in generic_ops + +// ------------------------------ Lt128 + +template > +HWY_INLINE Mask256 Lt128(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Lt128(dh, a.v0, b.v0); + ret.m1 = Lt128(dh, a.v1, b.v1); 
+ return ret; +} + +template > +HWY_INLINE Mask256 Lt128Upper(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Lt128Upper(dh, a.v0, b.v0); + ret.m1 = Lt128Upper(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Mask256 Eq128(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Eq128(dh, a.v0, b.v0); + ret.m1 = Eq128(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Mask256 Eq128Upper(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Eq128Upper(dh, a.v0, b.v0); + ret.m1 = Eq128Upper(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Mask256 Ne128(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Ne128(dh, a.v0, b.v0); + ret.m1 = Ne128(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Mask256 Ne128Upper(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Ne128Upper(dh, a.v0, b.v0); + ret.m1 = Ne128Upper(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Vec256 Min128(D d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Min128(dh, a.v0, b.v0); + ret.v1 = Min128(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Vec256 Max128(D d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Max128(dh, a.v0, b.v0); + ret.v1 = Max128(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Vec256 Min128Upper(D d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Min128Upper(dh, a.v0, b.v0); + ret.v1 = Min128Upper(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Vec256 Max128Upper(D d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Max128Upper(dh, a.v0, b.v0); + ret.v1 = Max128Upper(dh, a.v1, b.v1); + return ret; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/aom/third_party/highway/hwy/ops/x86_128-inl.h b/third_party/aom/third_party/highway/hwy/ops/x86_128-inl.h new file mode 100644 index 000000000000..db38e79cec24 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/x86_128-inl.h @@ -0,0 +1,13907 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit vectors and SSE4 instructions, plus some AVX2 and AVX512-VL +// operations when compiling for those targets. +// External include guard in highway.h - see comment there. 
+ +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_GCC_ACTUAL +#include "third_party/highway/hwy/base.h" + +// Avoid uninitialized warnings in GCC's emmintrin.h - see +// https://github.com/google/highway/issues/710 and pull/902 +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, + ignored "-Wmaybe-uninitialized") +#endif + +#include +#include +#if HWY_TARGET == HWY_SSSE3 +#include // SSSE3 +#elif HWY_TARGET <= HWY_SSE4 +#include // SSE4 +#ifndef HWY_DISABLE_PCLMUL_AES +#include // CLMUL +#endif +#endif + +#include "third_party/highway/hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +// Enable generic functions for whichever of (f16, bf16) are not supported. +#if !HWY_HAVE_FLOAT16 +#define HWY_X86_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#else +#define HWY_X86_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#endif + +#undef HWY_AVX3_HAVE_F32_TO_BF16C +#if HWY_TARGET <= HWY_AVX3_ZEN4 && !HWY_COMPILER_CLANGCL && \ + (HWY_COMPILER_GCC_ACTUAL >= 1000 || HWY_COMPILER_CLANG >= 900) && \ + !defined(HWY_AVX3_DISABLE_AVX512BF16) +#define HWY_AVX3_HAVE_F32_TO_BF16C 1 +#else +#define HWY_AVX3_HAVE_F32_TO_BF16C 0 +#endif + +#undef HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT +#if HWY_TARGET <= HWY_AVX3 && HWY_ARCH_X86_64 +#define HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT "v" +#else +#define HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT "x" +#endif + +template +struct Raw128 { + using type = __m128i; +}; +#if HWY_HAVE_FLOAT16 +template <> +struct Raw128 { + using type = __m128h; +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct Raw128 { + using type = __m128; +}; +template <> +struct Raw128 { + using type = __m128d; +}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +template +using Vec16 = Vec128; + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 + +// Template arg: sizeof(lane type) +template +struct RawMask128T {}; +template <> +struct RawMask128T<1> { + using type = __mmask16; +}; +template <> +struct RawMask128T<2> { + using type = __mmask8; +}; +template <> +struct RawMask128T<4> { + using type = __mmask8; +}; +template <> +struct RawMask128T<8> { + using type = __mmask8; +}; + +template +using RawMask128 = typename RawMask128T::type; + +#else // AVX2 or earlier + +template +using RawMask128 = typename Raw128::type; + +#endif // HWY_TARGET <= HWY_AVX3 + +} // namespace detail + +template +struct Mask128 { + using Raw = typename detail::RawMask128; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = N; // only for DFromM + +#if HWY_TARGET <= HWY_AVX3 + static Mask128 FromBits(uint64_t mask_bits) { + return Mask128{static_cast(mask_bits)}; + } +#else +// Lanes are either FF..FF or 0. +#endif + + Raw raw; +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ Zero + +// Use HWY_MAX_LANES_D here because VFromD is defined in terms of Zero. +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{_mm_setzero_si128()}; +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Zero(D /* tag */) { + return Vec128{_mm_setzero_ph()}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Zero(D /* tag */) { + return Vec128{_mm_setzero_ps()}; +} +template +HWY_API Vec128 Zero(D /* tag */) { + return Vec128{_mm_setzero_pd()}; +} +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{_mm_setzero_si128()}; +} + +// Using the existing Zero function instead of a dedicated function for +// deduction avoids having to forward-declare Vec256 here. 
+template +using VFromD = decltype(Zero(D())); + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m128i BitCastToInteger(__m128i v) { return v; } +#if HWY_HAVE_FLOAT16 +HWY_INLINE __m128i BitCastToInteger(__m128h v) { return _mm_castph_si128(v); } +#endif // HWY_HAVE_FLOAT16 +HWY_INLINE __m128i BitCastToInteger(__m128 v) { return _mm_castps_si128(v); } +HWY_INLINE __m128i BitCastToInteger(__m128d v) { return _mm_castpd_si128(v); } + +#if HWY_AVX3_HAVE_F32_TO_BF16C +HWY_INLINE __m128i BitCastToInteger(__m128bh v) { + // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to + // bit cast a __m128bh to a __m128i as there is currently no intrinsic + // available (as of GCC 13 and Clang 17) that can bit cast a __m128bh vector + // to a __m128i vector + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + // On GCC or Clang, use reinterpret_cast to bit cast a __m128bh to a __m128i + return reinterpret_cast<__m128i>(v); +#else + // On MSVC, use BitCastScalar to bit cast a __m128bh to a __m128i as MSVC does + // not allow reinterpret_cast, static_cast, or a C-style cast to be used to + // bit cast from one SSE/AVX vector type to a different SSE/AVX vector type + return BitCastScalar<__m128i>(v); +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger128 { + HWY_INLINE __m128i operator()(__m128i v) { return v; } +}; +#if HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128h operator()(__m128i v) { return _mm_castsi128_ph(v); } +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128 operator()(__m128i v) { return _mm_castsi128_ps(v); } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128d operator()(__m128i v) { return _mm_castsi128_pd(v); } +}; + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + Vec128 v) { + return VFromD{BitCastFromInteger128>()(v.raw)}; +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, + Vec128().MaxLanes()> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm_set1_epi32(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm_set1_epi64x(static_cast(t))}; // NOLINT +} +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Set(D /* tag */, float16_t t) { + return VFromD{_mm_set1_ph(t)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API VFromD Set(D /* tag */, float t) { + return VFromD{_mm_set1_ps(t)}; +} +template +HWY_API VFromD Set(D /* tag */, double t) { + return VFromD{_mm_set1_pd(t)}; +} + +// Generic for all vector lengths. 
+template +HWY_API VFromD Set(D df, TFromD t) { + const RebindToUnsigned du; + static_assert(sizeof(TFromD) == 2, "Expecting [b]f16"); + uint16_t bits; + CopyBytes<2>(&t, &bits); + return BitCast(df, Set(du, bits)); +} + +// ------------------------------ Undefined + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API VFromD Undefined(D /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return VFromD{_mm_undefined_si128()}; +} +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Undefined(D /* tag */) { + return VFromD{_mm_undefined_ph()}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API VFromD Undefined(D /* tag */) { + return VFromD{_mm_undefined_ps()}; +} +template +HWY_API VFromD Undefined(D /* tag */) { + return VFromD{_mm_undefined_pd()}; +} +template +HWY_API VFromD Undefined(D /* tag */) { + return VFromD{_mm_undefined_si128()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ GetLane + +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw) & 0xFF); +} +template +HWY_API T GetLane(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const uint16_t bits = + static_cast(_mm_cvtsi128_si32(BitCast(du, v).raw) & 0xFFFF); + return BitCastScalar(bits); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw)); +} +template +HWY_API float GetLane(const Vec128 v) { + return _mm_cvtss_f32(v.raw); +} +template +HWY_API T GetLane(const Vec128 v) { +#if HWY_ARCH_X86_32 + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + return lanes[0]; +#else + return static_cast(_mm_cvtsi128_si64(v.raw)); +#endif +} +template +HWY_API double GetLane(const Vec128 v) { + return _mm_cvtsd_f64(v.raw); +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Repartition du8; + return BitCast(d, VFromD{detail::BitCastToInteger(v.raw)}); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{_mm_setr_epi8( + static_cast(t0), static_cast(t1), static_cast(t2), + static_cast(t3), static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), static_cast(t8), + static_cast(t9), static_cast(t10), static_cast(t11), + static_cast(t12), static_cast(t13), static_cast(t14), + static_cast(t15))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{ + _mm_setr_epi16(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7))}; +} + +// Generic for all vector lengths +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD 
Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{_mm_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7)}; +} +#else +// Generic for all vector lengths if HWY_HAVE_FLOAT16 is not true +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{ + _mm_setr_epi32(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{_mm_setr_ps(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + // Need to use _mm_set_epi64x as there is no _mm_setr_epi64x intrinsic + // available + return VFromD{ + _mm_set_epi64x(static_cast(t1), static_cast(t0))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{_mm_setr_pd(t0, t1)}; +} + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD +namespace detail { + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<1> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<2> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<4> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<8> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<16> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]) && + __builtin_constant_p(v[8]) && __builtin_constant_p(v[9]) && + __builtin_constant_p(v[10]) && __builtin_constant_p(v[11]) && + __builtin_constant_p(v[12]) && __builtin_constant_p(v[13]) && + __builtin_constant_p(v[14]) && __builtin_constant_p(v[15]); +} + +#if HWY_TARGET <= HWY_AVX2 +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<32> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]) && + 
__builtin_constant_p(v[8]) && __builtin_constant_p(v[9]) && + __builtin_constant_p(v[10]) && __builtin_constant_p(v[11]) && + __builtin_constant_p(v[12]) && __builtin_constant_p(v[13]) && + __builtin_constant_p(v[14]) && __builtin_constant_p(v[15]) && + __builtin_constant_p(v[16]) && __builtin_constant_p(v[17]) && + __builtin_constant_p(v[18]) && __builtin_constant_p(v[19]) && + __builtin_constant_p(v[20]) && __builtin_constant_p(v[21]) && + __builtin_constant_p(v[22]) && __builtin_constant_p(v[23]) && + __builtin_constant_p(v[24]) && __builtin_constant_p(v[25]) && + __builtin_constant_p(v[26]) && __builtin_constant_p(v[27]) && + __builtin_constant_p(v[28]) && __builtin_constant_p(v[29]) && + __builtin_constant_p(v[30]) && __builtin_constant_p(v[31]); +} +#endif + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantX86Vec( + hwy::SizeTag num_of_lanes_tag, V v) { + using T = TFromV; +#if HWY_HAVE_FLOAT16 && HWY_HAVE_SCALAR_F16_TYPE + using F16VecLaneT = hwy::float16_t::Native; +#else + using F16VecLaneT = uint16_t; +#endif + using RawVecLaneT = If(), F16VecLaneT, + If(), uint16_t, T>>; + + // Suppress the -Wignored-attributes warning that is emitted by + // RemoveCvRef with GCC + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4649, ignored "-Wignored-attributes") + typedef RawVecLaneT GccRawVec + __attribute__((__vector_size__(sizeof(RemoveCvRef)))); + HWY_DIAGNOSTICS(pop) + + return IsConstantRawX86Vec(num_of_lanes_tag, + reinterpret_cast(v.raw)); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantX86VecForF2IConv(V v) { + constexpr size_t kNumOfLanesInRawSrcVec = + HWY_MAX(HWY_MAX_LANES_V(V), 16 / sizeof(TFromV)); + constexpr size_t kNumOfLanesInRawResultVec = + HWY_MAX(HWY_MAX_LANES_V(V), 16 / sizeof(TTo)); + constexpr size_t kNumOfLanesToCheck = + HWY_MIN(kNumOfLanesInRawSrcVec, kNumOfLanesInRawResultVec); + + return IsConstantX86Vec(hwy::SizeTag(), v); +} + +} // namespace detail +#endif // HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{ + _mm_and_si128(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{_mm_and_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{_mm_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
+template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm_andnot_si128( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); +} +template +HWY_API Vec128 AndNot(Vec128 not_mask, + Vec128 mask) { + return Vec128{_mm_andnot_ps(not_mask.raw, mask.raw)}; +} +template +HWY_API Vec128 AndNot(Vec128 not_mask, + Vec128 mask) { + return Vec128{_mm_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{ + _mm_or_si128(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{_mm_or_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{_mm_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{ + _mm_xor_si128(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{_mm_xor_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{_mm_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not +template +HWY_API Vec128 Not(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + const __m128i vu = BitCast(du, v).raw; + return BitCast(d, VU{_mm_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(d, VU{_mm_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Xor3 +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); + return BitCast(d, VU{ret}); +#else + return Xor(x1, Xor(x2, x3)); +#endif +} + +// ------------------------------ Or3 +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast( + d, VU{_mm_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ BitwiseIfThenElse +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + +#ifdef 
HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return IfVecThenElse(mask, yes, no); +} + +#endif + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Neg(hwy::FloatTag /*tag*/, const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_INLINE Vec128 Neg(hwy::SpecialTag /*tag*/, const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_INLINE Vec128 Neg(hwy::SignedTag /*tag*/, const Vec128 v) { + return Zero(DFromV()) - v; +} + +} // namespace detail + +template +HWY_INLINE Vec128 Neg(const Vec128 v) { + return detail::Neg(hwy::TypeTag(), v); +} + +// ------------------------------ Floating-point Abs +// Generic for all vector lengths +template )> +HWY_API V Abs(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + return v & BitCast(d, Set(di, static_cast(~SignMask()))); +} + +// ------------------------------ CopySign +// Generic for all vector lengths. +template +HWY_API V CopySign(const V magn, const V sign) { + static_assert(IsFloat>(), "Only makes sense for floating-point"); + + const DFromV d; + const auto msb = SignBit(d); + + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + return BitwiseIfThenElse(msb, sign, magn); +} + +// ------------------------------ CopySignToAbs +// Generic for all vector lengths. 
+template +HWY_API V CopySignToAbs(const V abs, const V sign) { + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ================================================== MASK + +#if HWY_TARGET <= HWY_AVX3 +// ------------------------------ MaskFromVec + +namespace detail { + +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<1> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<2> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<4> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<8> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} +#endif +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +// ------------------------------ MaskFalse (MFromD) + +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +// Generic for all vector lengths +template +HWY_API MFromD MaskFalse(D /*d*/) { + return MFromD{static_cast().raw)>(0)}; +} + +// ------------------------------ IsNegative (MFromD) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +// Generic for all vector lengths +template +HWY_API MFromD> IsNegative(V v) { + return MaskFromVec(v); +} + +// ------------------------------ PromoteMaskTo (MFromD) + +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +// AVX3 PromoteMaskTo is generic for all vector lengths +template )), + class DFrom_2 = Rebind, DTo>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return MFromD{static_cast().raw)>(m.raw)}; +} + +// ------------------------------ DemoteMaskTo (MFromD) + +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif + +// AVX3 DemoteMaskTo is generic for all vector lengths +template ) - 1), + class DFrom_2 = Rebind, DTo>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return MFromD{static_cast().raw)>(m.raw)}; +} + +// ------------------------------ CombineMasks (MFromD) + +#ifdef HWY_NATIVE_COMBINE_MASKS +#undef HWY_NATIVE_COMBINE_MASKS +#else +#define HWY_NATIVE_COMBINE_MASKS +#endif + +// For Clang and GCC, mask intrinsics (KORTEST) weren't added until recently. 
+#if !defined(HWY_COMPILER_HAS_MASK_INTRINSICS) +#if HWY_COMPILER_MSVC != 0 || HWY_COMPILER_GCC_ACTUAL >= 700 || \ + HWY_COMPILER_CLANG >= 800 +#define HWY_COMPILER_HAS_MASK_INTRINSICS 1 +#else +#define HWY_COMPILER_HAS_MASK_INTRINSICS 0 +#endif +#endif // HWY_COMPILER_HAS_MASK_INTRINSICS + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask8 combined_mask = _kor_mask8( + _kshiftli_mask8(static_cast<__mmask8>(hi.raw), 1), + _kand_mask8(static_cast<__mmask8>(lo.raw), static_cast<__mmask8>(1))); +#else + const auto combined_mask = + (static_cast(hi.raw) << 1) | (lo.raw & 1); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask8 combined_mask = _kor_mask8( + _kshiftli_mask8(static_cast<__mmask8>(hi.raw), 2), + _kand_mask8(static_cast<__mmask8>(lo.raw), static_cast<__mmask8>(3))); +#else + const auto combined_mask = + (static_cast(hi.raw) << 2) | (lo.raw & 3); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask8 combined_mask = _kor_mask8( + _kshiftli_mask8(static_cast<__mmask8>(hi.raw), 4), + _kand_mask8(static_cast<__mmask8>(lo.raw), static_cast<__mmask8>(15))); +#else + const auto combined_mask = + (static_cast(hi.raw) << 4) | (lo.raw & 15u); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask16 combined_mask = _mm512_kunpackb( + static_cast<__mmask16>(hi.raw), static_cast<__mmask16>(lo.raw)); +#else + const auto combined_mask = + ((static_cast(hi.raw) << 8) | (lo.raw & 0xFFu)); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +// ------------------------------ LowerHalfOfMask (MFromD) + +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +// Generic for all vector lengths +template +HWY_API MFromD LowerHalfOfMask(D d, MFromD> m) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumOfBitsInRawMask = sizeof(RawM) * 8; + + MFromD result_mask{static_cast(m.raw)}; + + if (kN < kNumOfBitsInRawMask) { + result_mask = + And(result_mask, MFromD{static_cast((1ULL << kN) - 1)}); + } + + return result_mask; +} + +// ------------------------------ UpperHalfOfMask (MFromD) + +#ifdef HWY_NATIVE_UPPER_HALF_OF_MASK +#undef HWY_NATIVE_UPPER_HALF_OF_MASK +#else +#define HWY_NATIVE_UPPER_HALF_OF_MASK +#endif + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask8(static_cast<__mmask8>(m.raw), 1); +#else + const auto shifted_mask = static_cast(m.raw) >> 1; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask8(static_cast<__mmask8>(m.raw), 2); +#else + const auto shifted_mask = static_cast(m.raw) >> 2; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = 
_kshiftri_mask8(static_cast<__mmask8>(m.raw), 4); +#else + const auto shifted_mask = static_cast(m.raw) >> 4; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask16(static_cast<__mmask16>(m.raw), 8); +#else + const auto shifted_mask = static_cast(m.raw) >> 8; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +// ------------------------------ OrderedDemote2MasksTo (MFromD, CombineMasks) + +#ifdef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#undef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#else +#define HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#endif + +// Generic for all vector lengths +template ) / 2), + class DTo_2 = Repartition, DFrom>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD OrderedDemote2MasksTo(DTo d_to, DFrom /*d_from*/, + MFromD a, MFromD b) { + using MH = MFromD>; + using RawMH = decltype(MH().raw); + + return CombineMasks(d_to, MH{static_cast(b.raw)}, + MH{static_cast(a.raw)}); +} + +// ------------------------------ Slide mask up/down +#ifdef HWY_NATIVE_SLIDE_MASK +#undef HWY_NATIVE_SLIDE_MASK +#else +#define HWY_NATIVE_SLIDE_MASK +#endif + +template +HWY_API MFromD SlideMask1Up(D d, MFromD m) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr unsigned kValidLanesMask = (1u << kN) - 1u; + +#if HWY_COMPILER_HAS_MASK_INTRINSICS + MFromD result_mask{ + static_cast(_kshiftli_mask8(static_cast<__mmask8>(m.raw), 1))}; + + if (kN < 8) { + result_mask = + And(result_mask, MFromD{static_cast(kValidLanesMask)}); + } +#else + MFromD result_mask{ + static_cast((static_cast(m.raw) << 1) & kValidLanesMask)}; +#endif + + return result_mask; +} + +template +HWY_API MFromD SlideMask1Up(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftli_mask16(static_cast<__mmask16>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) << 1)}; +#endif +} + +template +HWY_API MFromD SlideMask1Down(D d, MFromD m) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr unsigned kValidLanesMask = (1u << kN) - 1u; + +#if HWY_COMPILER_HAS_MASK_INTRINSICS + if (kN < 8) { + m = And(m, MFromD{static_cast(kValidLanesMask)}); + } + + return MFromD{ + static_cast(_kshiftri_mask8(static_cast<__mmask8>(m.raw), 1))}; +#else + return MFromD{ + static_cast((static_cast(m.raw) & kValidLanesMask) >> 1)}; +#endif +} + +template +HWY_API MFromD SlideMask1Down(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftri_mask16(static_cast<__mmask16>(m.raw), 1))}; +#else + return MFromD{ + static_cast((static_cast(m.raw) & 0xFFFFu) >> 1)}; +#endif +} + +// Generic for all vector lengths +template +HWY_API MFromD SlideMaskUpLanes(D d, MFromD m, size_t amt) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr uint64_t kValidLanesMask = + static_cast(((kN < 64) ? (1ULL << kN) : 0ULL) - 1ULL); + + return MFromD{static_cast( + (static_cast(m.raw) << (amt & 63)) & kValidLanesMask)}; +} + +// Generic for all vector lengths +template +HWY_API MFromD SlideMaskDownLanes(D d, MFromD m, size_t amt) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr uint64_t kValidLanesMask = + static_cast(((kN < 64) ? 
(1ULL << kN) : 0ULL) - 1ULL); + + return MFromD{static_cast( + (static_cast(m.raw) & kValidLanesMask) >> (amt & 63))}; +} + +// ------------------------------ VecFromMask + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi8(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi16(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi32(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi64(v.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_ph(_mm_movm_epi16(v.raw))}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_ps(_mm_movm_epi32(v.raw))}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_pd(_mm_movm_epi64(v.raw))}; +} + +// Generic for all vector lengths. +template +HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { + return VecFromMask(v); +} + +// ------------------------------ RebindMask (MaskFromVec) + +template +HWY_API MFromD RebindMask(DTo /* tag */, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; +} + +// ------------------------------ IfThenElse + +namespace detail { + +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi8(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi16(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi32(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi64(mask.raw, no.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_ph(mask.raw, no.raw, yes.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// Generic for all vector lengths. 
+template , HWY_X86_IF_EMULATED_D(D)> +HWY_API V IfThenElse(MFromD mask, V yes, V no) { + const RebindToUnsigned du; + return BitCast( + D(), IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, Vec128 no) { + return Vec128{_mm_mask_blend_ps(mask.raw, no.raw, yes.raw)}; +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_pd(mask.raw, no.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_ps(mask.raw, yes.raw)}; +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_pd(mask.raw, yes.raw)}; +} + +// Generic for all vector lengths. +template , HWY_IF_SPECIAL_FLOAT_D(D)> +HWY_API V IfThenElseZero(MFromD mask, V yes) { + const RebindToUnsigned du; + return BitCast(D(), IfThenElseZero(RebindMask(du, mask), BitCast(du, yes))); +} + +namespace detail { + +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. + return Vec128{_mm_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +// Generic for all vector lengths. 
+template , HWY_IF_SPECIAL_FLOAT_D(D)> +HWY_API V IfThenZeroElse(MFromD mask, V no) { + const RebindToUnsigned du; + return BitCast(D(), IfThenZeroElse(RebindMask(du, mask), BitCast(du, no))); +} + +// ------------------------------ Mask logical + +namespace detail { + +template +HWY_INLINE Mask128 And(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 Or(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask16(a.raw, b.raw)}; +#else + return 
Mask128{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0x3)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0x3)}; +#endif +} + +// UnmaskedNot returns ~m.raw without zeroing out any invalid bits +template +HWY_INLINE Mask128 UnmaskedNot(const Mask128 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask16>(_knot_mask16(m.raw))}; +#else + return Mask128{static_cast<__mmask16>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask128 UnmaskedNot(const Mask128 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_knot_mask8(m.raw))}; +#else + return Mask128{static_cast<__mmask8>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask128 Not(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + // sizeof(T) == 1 and N == 16: simply return ~m as all 16 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask128 Not(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + // sizeof(T) == 1 and N <= 8: need to zero out the upper bits of ~m as there + // are fewer than 16 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<1>(), m, Mask128::FromBits((1ull << N) - 1)); +} +template +HWY_INLINE Mask128 Not(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + // sizeof(T) == 2 and N == 8: simply return ~m as all 8 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask128 Not(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + // sizeof(T) == 2 and N <= 4: need to zero out the upper bits of ~m as there + // are fewer than 8 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<2>(), m, Mask128::FromBits((1ull << N) - 1)); +} +template +HWY_INLINE 
Mask128 Not(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + // sizeof(T) == 4: need to zero out the upper bits of ~m as there are at most + // 4 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<4>(), m, Mask128::FromBits((1ull << N) - 1)); +} +template +HWY_INLINE Mask128 Not(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + // sizeof(T) == 8: need to zero out the upper bits of ~m as there are at most + // 2 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<8>(), m, Mask128::FromBits((1ull << N) - 1)); +} + +} // namespace detail + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Not(const Mask128 m) { + // Flip only the valid bits + return detail::Not(hwy::SizeTag(), m); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +#else // AVX2 or below + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{v.raw}; +} + +// Generic for all vector lengths. +template +HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { + return VecFromMask(v); +} + +#if HWY_TARGET >= HWY_SSSE3 + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const auto vmask = VecFromMask(DFromV(), mask); + return Or(And(vmask, yes), AndNot(vmask, no)); +} + +#else // HWY_TARGET < HWY_SSSE3 + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, Vec128 no) { + return Vec128{_mm_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +#endif // HWY_TARGET >= HWY_SSSE3 + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + const Simd d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ ShiftLeft + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +#if HWY_TARGET <= HWY_AVX3_DL + +namespace detail { +template +HWY_API Vec128 GaloisAffine( + Vec128 v, VFromD>> matrix) { + return Vec128{_mm_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)}; +} +} // namespace detail + +#else // HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// ------------------------------ ShiftRight + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi32(v.raw, kBits)}; +} + +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// i64 is implemented after BroadcastSignBit. 
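Illustrative usage sketch (editor's addition, not part of the vendored patch): the u8 ShiftLeft/ShiftRight overloads above have no dedicated x86 instruction before AVX3_DL, so the header synthesizes them from the 16-bit shifts plus a mask (or GF2P8AFFINE where available), but callers only ever see the portable ops. A minimal single-target example, assuming the usual "hwy/highway.h" static-dispatch include pattern:

#include <cstdint>
#include <cstdio>
#include <vector>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

int main() {
  const hn::ScalableTag<uint8_t> d8;           // widest enabled u8 vector
  const size_t n = hn::Lanes(d8);
  const auto v = hn::Iota(d8, 0);              // lanes 0, 1, 2, ...
  const auto left = hn::ShiftLeft<3>(v);       // lane << 3 (emulated below AVX3_DL)
  const auto right = hn::ShiftRight<2>(left);  // logical >> 2 for unsigned lanes
  std::vector<uint8_t> out(n);
  hn::StoreU(right, d8, out.data());           // unaligned store of all lanes
  for (size_t i = 0; i < n; ++i) printf("%u ", unsigned{out[i]});
  printf("\n");
  return 0;
}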
+ +// ================================================== MEMORY (1) + +// Clang static analysis claims the memory immediately after a partial vector +// store is uninitialized, and also flags the input to partial loads (at least +// for loadl_pd) as "garbage". This is a false alarm because msan does not +// raise errors. We work around this by using CopyBytes instead of intrinsics, +// but only for the analyzer to avoid potentially bad code generation. +// Unfortunately __clang_analyzer__ was not defined for clang-tidy prior to v7. +#ifndef HWY_SAFE_PARTIAL_LOAD_STORE +#if defined(__clang_analyzer__) || \ + (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_SAFE_PARTIAL_LOAD_STORE 1 +#else +#define HWY_SAFE_PARTIAL_LOAD_STORE 0 +#endif +#endif // HWY_SAFE_PARTIAL_LOAD_STORE + +// ------------------------------ Load + +template +HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { + return VFromD{_mm_load_si128(reinterpret_cast(aligned))}; +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Load(D, const float16_t* HWY_RESTRICT aligned) { + return Vec128{_mm_load_ph(aligned)}; +} +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; + return BitCast(d, Load(du, detail::U16LanePointer(aligned))); +} +template +HWY_API Vec128 Load(D /* tag */, const float* HWY_RESTRICT aligned) { + return Vec128{_mm_load_ps(aligned)}; +} +template +HWY_API Vec128 Load(D /* tag */, const double* HWY_RESTRICT aligned) { + return Vec128{_mm_load_pd(aligned)}; +} + +template +HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_loadu_si128(reinterpret_cast(p))}; +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 LoadU(D, const float16_t* HWY_RESTRICT p) { + return Vec128{_mm_loadu_ph(p)}; +} +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. 
+template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadU(du, detail::U16LanePointer(p))); +} +template +HWY_API Vec128 LoadU(D /* tag */, const float* HWY_RESTRICT p) { + return Vec128{_mm_loadu_ps(p)}; +} +template +HWY_API Vec128 LoadU(D /* tag */, const double* HWY_RESTRICT p) { + return Vec128{_mm_loadu_pd(p)}; +} + +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = _mm_setzero_si128(); + CopyBytes<8>(p, &v); // not same size +#else + const __m128i v = _mm_loadl_epi64(reinterpret_cast(p)); +#endif + return BitCast(d, VFromD{v}); +} + +template +HWY_API Vec64 Load(D /* tag */, const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + const __m128 hi = _mm_setzero_ps(); + return Vec64{_mm_loadl_pi(hi, reinterpret_cast(p))}; +#endif +} + +template +HWY_API Vec64 Load(D /* tag */, const double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128d v = _mm_setzero_pd(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + return Vec64{_mm_load_sd(p)}; +#endif +} + +template +HWY_API Vec32 Load(D /* tag */, const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<4>(p, &v); // not same size + return Vec32{v}; +#else + return Vec32{_mm_load_ss(p)}; +#endif +} + +// Any <= 32 bit except +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + // Clang ArgumentPromotionPass seems to break this code. We can unpoison + // before SetTableIndices -> LoadU -> Load and the memory is poisoned again. + detail::MaybeUnpoison(p, Lanes(d)); + +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = Zero(Full128>()).raw; + CopyBytes(p, &v); // not same size as VFromD +#else + int32_t bits = 0; + CopyBytes(p, &bits); // not same size as VFromD + const __m128i v = _mm_cvtsi32_si128(bits); +#endif + return BitCast(d, VFromD{v}); +} + +// For < 128 bit, LoadU == Load. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// ------------------------------ Store + +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + _mm_store_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void Store(Vec128 v, D, float16_t* HWY_RESTRICT aligned) { + _mm_store_ph(aligned, v.raw); +} +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. 
+template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; + Store(BitCast(du, v), du, reinterpret_cast(aligned)); +} +template +HWY_API void Store(Vec128 v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm_store_ps(aligned, v.raw); +} +template +HWY_API void Store(Vec128 v, D /* tag */, + double* HWY_RESTRICT aligned) { + _mm_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(p), v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec128 v, D, float16_t* HWY_RESTRICT p) { + _mm_storeu_ph(p, v.raw); +} +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + StoreU(BitCast(du, v), du, reinterpret_cast(p)); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, float* HWY_RESTRICT p) { + _mm_storeu_ps(p, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, double* HWY_RESTRICT p) { + _mm_storeu_pd(p, v.raw); +} + +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + (void)d; + CopyBytes<8>(&v, p); // not same size +#else + const RebindToUnsigned du; // for float16_t + _mm_storel_epi64(reinterpret_cast<__m128i*>(p), BitCast(du, v).raw); +#endif +} +template +HWY_API void Store(Vec64 v, D /* tag */, float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pi(reinterpret_cast<__m64*>(p), v.raw); +#endif +} +template +HWY_API void Store(Vec64 v, D /* tag */, double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pd(p, v.raw); +#endif +} + +// Any <= 32 bit except +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { + CopyBytes(&v, p); // not same size +} +template +HWY_API void Store(Vec32 v, D /* tag */, float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<4>(&v, p); // not same size +#else + _mm_store_ss(p, v.raw); +#endif +} + +// For < 128 bit, StoreU == Store. 
+template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ================================================== SWIZZLE (1) + +// ------------------------------ TableLookupBytes +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const DFromV d; + const Repartition du8; + + const DFromV d_bytes; + const Repartition du8_bytes; +#if HWY_TARGET == HWY_SSE2 +#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) + typedef uint8_t GccU8RawVectType __attribute__((__vector_size__(16))); + (void)d; + (void)du8; + (void)d_bytes; + (void)du8_bytes; + return Vec128{reinterpret_cast::type>( + __builtin_shuffle(reinterpret_cast(bytes.raw), + reinterpret_cast(from.raw)))}; +#else + const Full128 du8_full; + + alignas(16) uint8_t result_bytes[16]; + alignas(16) uint8_t u8_bytes[16]; + alignas(16) uint8_t from_bytes[16]; + + Store(Vec128{BitCast(du8_bytes, bytes).raw}, du8_full, u8_bytes); + Store(Vec128{BitCast(du8, from).raw}, du8_full, from_bytes); + + for (int i = 0; i < 16; i++) { + result_bytes[i] = u8_bytes[from_bytes[i] & 15]; + } + + return BitCast(d, VFromD{Load(du8_full, result_bytes).raw}); +#endif +#else // SSSE3 or newer + return BitCast( + d, VFromD{_mm_shuffle_epi8(BitCast(du8_bytes, bytes).raw, + BitCast(du8, from).raw)}); +#endif +} + +// ------------------------------ TableLookupBytesOr0 +// For all vector widths; x86 anyway zeroes if >= 0x80 on SSSE3/SSE4/AVX2/AVX3 +template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { +#if HWY_TARGET == HWY_SSE2 + const DFromV d; + const Repartition di8; + + const auto di8_from = BitCast(di8, from); + return BitCast(d, IfThenZeroElse(di8_from < Zero(di8), + TableLookupBytes(bytes, di8_from))); +#else + return TableLookupBytes(bytes, from); +#endif +} + +// ------------------------------ Shuffles (ShiftRight, TableLookupBytes) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_epi32(v.raw, 0xB1)}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. As with +// Intel's shuffle* intrinsics and InterleaveLower, the lower half of the output +// comes from the first argument. 
+namespace detail { + +template +HWY_API Vec32 ShuffleTwo2301(const Vec32 a, const Vec32 b) { + const DFromV d; + const Twice d2; + const auto ba = Combine(d2, b, a); +#if HWY_TARGET == HWY_SSE2 + Vec32 ba_shuffled{ + _mm_shufflelo_epi16(ba.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + return BitCast(d, Or(ShiftLeft<8>(ba_shuffled), ShiftRight<8>(ba_shuffled))); +#else + const RebindToUnsigned d2_u; + const auto shuffle_idx = + BitCast(d2, Dup128VecFromValues(d2_u, 1, 0, 7, 6, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0)); + return Vec32{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec64 ShuffleTwo2301(const Vec64 a, const Vec64 b) { + const DFromV d; + const Twice d2; + const auto ba = Combine(d2, b, a); +#if HWY_TARGET == HWY_SSE2 + Vec64 ba_shuffled{ + _mm_shuffle_epi32(ba.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + return Vec64{ + _mm_shufflelo_epi16(ba_shuffled.raw, _MM_SHUFFLE(2, 3, 0, 1))}; +#else + const RebindToUnsigned d2_u; + const auto shuffle_idx = BitCast( + d2, + Dup128VecFromValues(d2_u, 0x0302, 0x0100, 0x0f0e, 0x0d0c, 0, 0, 0, 0)); + return Vec64{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec128 ShuffleTwo2301(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template +HWY_API Vec32 ShuffleTwo1230(const Vec32 a, const Vec32 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const auto zero = Zero(d); + const Rebind di16; + const Vec32 a_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(a.raw, zero.raw), _MM_SHUFFLE(3, 0, 3, 0))}; + const Vec32 b_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(b.raw, zero.raw), _MM_SHUFFLE(1, 2, 1, 2))}; + const auto ba_shuffled = Combine(di16, b_shuffled, a_shuffled); + return Vec32{_mm_packus_epi16(ba_shuffled.raw, ba_shuffled.raw)}; +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = + BitCast(d2, Dup128VecFromValues(d2_u, 0, 3, 6, 5, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0)); + return Vec32{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec64 ShuffleTwo1230(const Vec64 a, const Vec64 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const Vec32 a_shuffled{ + _mm_shufflelo_epi16(a.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + const Vec32 b_shuffled{ + _mm_shufflelo_epi16(b.raw, _MM_SHUFFLE(1, 2, 1, 2))}; + return Combine(d, b_shuffled, a_shuffled); +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = BitCast( + d2, + Dup128VecFromValues(d2_u, 0x0100, 0x0706, 0x0d0c, 0x0b0a, 0, 0, 0, 0)); + return Vec64{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec128 ShuffleTwo1230(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template +HWY_API Vec32 ShuffleTwo3012(const Vec32 a, const Vec32 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const auto zero = Zero(d); + const Rebind di16; + const Vec32 a_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(a.raw, zero.raw), _MM_SHUFFLE(1, 2, 1, 2))}; + const Vec32 b_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(b.raw, zero.raw), _MM_SHUFFLE(3, 0, 3, 0))}; + const auto ba_shuffled = Combine(di16, b_shuffled, a_shuffled); + return Vec32{_mm_packus_epi16(ba_shuffled.raw, 
ba_shuffled.raw)}; +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = + BitCast(d2, Dup128VecFromValues(d2_u, 2, 1, 4, 7, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0)); + return Vec32{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec64 ShuffleTwo3012(const Vec64 a, const Vec64 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const Vec32 a_shuffled{ + _mm_shufflelo_epi16(a.raw, _MM_SHUFFLE(1, 2, 1, 2))}; + const Vec32 b_shuffled{ + _mm_shufflelo_epi16(b.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + return Combine(d, b_shuffled, a_shuffled); +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = BitCast( + d2, + Dup128VecFromValues(d2_u, 0x0504, 0x0302, 0x0908, 0x0f0e, 0, 0, 0, 0)); + return Vec64{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec128 ShuffleTwo3012(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 1)}; +} + +// Rotate right 32 bits +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. 
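Illustrative usage sketch (editor's addition, not part of the vendored patch): on AVX3 the comparisons below return a compact mask register, whereas the SSE2..AVX2 path later in this file returns an all-ones/all-zeros vector; mask consumers such as IfThenElse and CountTrue behave identically on either representation. A minimal example, again assuming the static-dispatch include pattern:

#include <cstdint>
#include <cstdio>
#include <vector>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

int main() {
  const hn::ScalableTag<int32_t> d;
  const size_t n = hn::Lanes(d);
  const auto a = hn::Iota(d, 0);                       // 0, 1, 2, ...
  const auto limit = hn::Set(d, 4);
  const auto gt = hn::Gt(a, limit);                    // mask: a[i] > 4
  const auto clamped = hn::IfThenElse(gt, limit, a);   // min(a[i], 4)
  const size_t num_gt = hn::CountTrue(d, gt);          // lanes where a[i] > 4
  std::vector<int32_t> out(n);
  hn::StoreU(clamped, d, out.data());
  printf("%zu lanes exceeded the limit\n", num_gt);
  for (size_t i = 0; i < n; ++i) printf("%d ", out[i]);
  printf("\n");
  return 0;
}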
+ +// ------------------------------ TestBit + +namespace detail { + +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<1> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<2> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<4> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<8> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 TestBit(const Vec128 v, const Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi8_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi16_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi8_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi16_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Signed/float < +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpgt_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu64_mask(a.raw, b.raw)}; +} + +#else // AVX2 or below + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API MFromD RebindMask(DTo dto, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + const Simd d; + return MaskFromVec(BitCast(dto, VecFromMask(d, m))); +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + const DFromV d64; + const RepartitionToNarrow d32; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +#else + return Mask128{_mm_cmpeq_epi64(a.raw, b.raw)}; +#endif +} + +// Signed +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + // Same as signed ==; avoid duplicating the SSSE3 version. 
+ const DFromV d; + RebindToUnsigned du; + return RebindMask(d, BitCast(du, a) == BitCast(du, b)); +} + +// Float +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpeq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// This cannot have T as a template argument, otherwise it is not more +// specialized than rewritten operator== in C++20, leading to compile +// errors: https://gcc.godbolt.org/z/xsrPhPvPT. +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpneq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpneq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +namespace detail { + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi8(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32(a.raw, b.raw)}; +} + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, + const Vec128 a, + const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + // See https://stackoverflow.com/questions/65166174/: + const DFromV d; + const RepartitionToNarrow d32; + const Vec128 m_eq32{Eq(BitCast(d32, a), BitCast(d32, b)).raw}; + const Vec128 m_gt32{Gt(BitCast(d32, a), BitCast(d32, b)).raw}; + // If a.upper is greater, upper := true. Otherwise, if a.upper == b.upper: + // upper := b-a (unsigned comparison result of lower). Otherwise: upper := 0. + const __m128i upper = OrAnd(m_gt32, m_eq32, Sub(b, a)).raw; + // Duplicate upper to lower half. 
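+ // _MM_SHUFFLE(3, 3, 1, 1) copies the upper 32-bit half of each 64-bit lane
+ // into both halves, so each 64-bit lane of the mask is all-ones or all-zeros.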
+ return Mask128{_mm_shuffle_epi32(upper, _MM_SHUFFLE(3, 3, 1, 1))}; +#else + return Mask128{_mm_cmpgt_epi64(a.raw, b.raw)}; // SSE4.2 +#endif +} + +template +HWY_INLINE Mask128 Gt(hwy::UnsignedTag /*tag*/, Vec128 a, + Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = Set(du, (LimitsMax() >> 1) + 1); + const auto sa = BitCast(di, Xor(a, msb)); + const auto sb = BitCast(di, Xor(b, msb)); + return RebindMask(du, Gt(hwy::SignedTag(), sa, sb)); +} + +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_ps(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template +HWY_INLINE Mask128 operator>(Vec128 a, Vec128 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +namespace detail { +template +HWY_INLINE Mask128 Ge(hwy::SignedTag tag, Vec128 a, + Vec128 b) { + return Not(Gt(tag, b, a)); +} + +template +HWY_INLINE Mask128 Ge(hwy::UnsignedTag tag, Vec128 a, + Vec128 b) { + return Not(Gt(tag, b, a)); +} + +template +HWY_INLINE Mask128 Ge(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_ps(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Ge(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return detail::Ge(hwy::TypeTag(), a, b); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + return b > a; +} + +template +HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { + return b >= a; +} + +// ------------------------------ Iota (Load) + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_epi8( + static_cast(15), static_cast(14), static_cast(13), + static_cast(12), static_cast(11), static_cast(10), + static_cast(9), static_cast(8), static_cast(7), + static_cast(6), static_cast(5), static_cast(4), + static_cast(3), static_cast(2), static_cast(1), + static_cast(0))}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_epi16(int16_t{7}, int16_t{6}, int16_t{5}, int16_t{4}, + int16_t{3}, int16_t{2}, int16_t{1}, + int16_t{0})}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_ph(float16_t{7}, float16_t{6}, float16_t{5}, + float16_t{4}, float16_t{3}, float16_t{2}, + float16_t{1}, float16_t{0})}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{ + _mm_set_epi32(int32_t{3}, int32_t{2}, int32_t{1}, int32_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_epi64x(int64_t{1}, int64_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_ps(3.0f, 2.0f, 1.0f, 0.0f)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_pd(1.0, 0.0)}; +} + +#if HWY_COMPILER_MSVC +template +static HWY_INLINE V MaskOutVec128Iota(V v) { + const V mask_out_mask{_mm_set_epi32(0, 0, 0, 0xFF)}; + return v & mask_out_mask; +} +template +static HWY_INLINE V MaskOutVec128Iota(V v) { +#if HWY_TARGET <= HWY_SSE4 + return V{_mm_blend_epi16(v.raw, _mm_setzero_si128(), 0xFE)}; +#else + const V mask_out_mask{_mm_set_epi32(0, 0, 0, 0xFFFF)}; + return v & mask_out_mask; +#endif +} +template +static 
HWY_INLINE V MaskOutVec128Iota(V v) { + const DFromV d; + const Repartition df; + using VF = VFromD; + return BitCast(d, VF{_mm_move_ss(_mm_setzero_ps(), BitCast(df, v).raw)}); +} +template +static HWY_INLINE V MaskOutVec128Iota(V v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm_move_epi64(BitCast(du, v).raw)}); +} +template +static HWY_INLINE V MaskOutVec128Iota(V v) { + return v; +} +#endif + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + const auto result_iota = + detail::Iota0(d) + Set(d, ConvertScalarTo>(first)); +#if HWY_COMPILER_MSVC + return detail::MaskOutVec128Iota(result_iota); +#else + return result_iota; +#endif +} + +// ------------------------------ FirstN (Iota, Lt) + +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API M FirstN(D d, size_t num) { + constexpr size_t kN = MaxLanes(d); + // For AVX3, this ensures `num` <= 255 as required by bzhi, which only looks + // at the lower 8 bits; for AVX2 and below, this ensures `num` fits in TI. + num = HWY_MIN(num, kN); +#if HWY_TARGET <= HWY_AVX3 +#if HWY_ARCH_X86_64 + const uint64_t all = (1ull << kN) - 1; + return M::FromBits(_bzhi_u64(all, num)); +#else + const uint32_t all = static_cast((1ull << kN) - 1); + return M::FromBits(_bzhi_u32(all, static_cast(num))); +#endif // HWY_ARCH_X86_64 +#else // HWY_TARGET > HWY_AVX3 + const RebindToSigned di; // Signed comparisons are cheaper. + using TI = TFromD; + return RebindMask(d, detail::Iota0(di) < Set(di, static_cast(num))); +#endif // HWY_TARGET <= HWY_AVX3 +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_unpacklo_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_unpacklo_pd(a.raw, b.raw)}; +} + +// Generic for all vector lengths. 
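+ // Usage sketch (illustrative): with four u32 lanes,
+ //   InterleaveLower({1,2,3,4}, {5,6,7,8}) == {1,5,2,6}
+ // i.e. lanes 0 and 1 of each input are interleaved, starting with `a`.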
+template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ================================================== MEMORY (2) + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm_maskz_loadu_epi16(m.raw, p)}); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_epi64(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const float* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_ps(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const double* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_pd(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_epi8(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{ + _mm_mask_loadu_epi16(BitCast(du, v).raw, m.raw, p)}); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_epi32(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_epi64(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const float* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_ps(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const double* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_pd(v.raw, m.raw, p)}; +} + +#elif HWY_TARGET == HWY_AVX2 + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return VFromD{_mm_maskload_epi32(p_p, m.raw)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return VFromD{_mm_maskload_epi64(p_p, m.raw)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, const float* HWY_RESTRICT p) { + const RebindToSigned di; + return VFromD{_mm_maskload_ps(p, BitCast(di, VecFromMask(d, m)).raw)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, const double* HWY_RESTRICT p) { + const RebindToSigned di; + return VFromD{_mm_maskload_pd(p, BitCast(di, VecFromMask(d, m)).raw)}; +} + +// There is no maskload_epi8/16, so blend instead. +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +#else // <= SSE4 + +// Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). 
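+ // The fallback below is IfThenElseZero(m, LoadU(d, p)): an unmasked full-width
+ // load plus a blend, so all Lanes(d) elements must be readable even where the
+ // mask is false (see HWY_MEM_OPS_MIGHT_FAULT; LoadN below bounds the access).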
+template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +#endif + +// ------------------------------ MaskedLoadOr + +#if HWY_TARGET > HWY_AVX3 // else: native + +// Generic for all vector lengths. +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElse(m, LoadU(d, p), v); +} + +#endif // HWY_TARGET > HWY_AVX3 + +// ------------------------------ LoadN (InterleaveLower) + +#if HWY_TARGET <= HWY_AVX2 && !HWY_MEM_OPS_MIGHT_FAULT + +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +// Generic for all vector lengths. +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> + d_full; + return ResizeBitCast(d, MaskedLoad(FirstN(d_full, num_lanes), d_full, p)); +} + +// Generic for all vector lengths. +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> + d_full; + return ResizeBitCast(d, MaskedLoadOr(ResizeBitCast(d_full, no), + FirstN(d_full, num_lanes), d_full, p)); +} + +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +// 'Leading' means the part that fits in 32-bit lanes. With 2-byte vectors, +// there are none, so return the remainder (v_trailing). +template +HWY_INLINE VFromD AVX2UIF8Or16LoadLeadingN( + VFromD /*load_mask*/, D /*d*/, const TFromD* HWY_RESTRICT /*p*/, + VFromD v_trailing) { + return v_trailing; +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadLeadingNOr( + VFromD /*no*/, VFromD /*load_mask*/, D /*d*/, + const TFromD* HWY_RESTRICT /*p*/, VFromD v_trailing) { + return v_trailing; +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadLeadingN(VFromD load_mask, D d, + const TFromD* HWY_RESTRICT p, + VFromD v_trailing) { + using DI32 = Repartition; + const FixedTag di32_full; + + // ResizeBitCast of load_mask to di32 is okay below if + // d.MaxBytes() < di32.MaxBytes() is true as any lanes of load_mask.raw past + // the first (lowest-index) lanes of load_mask.raw will have already been + // zeroed out by FirstN. + return ResizeBitCast( + d, IfNegativeThenElse( + ResizeBitCast(di32_full, load_mask), + MaskedLoad(MaskFromVec(ResizeBitCast(di32_full, load_mask)), + di32_full, reinterpret_cast(p)), + ResizeBitCast(di32_full, v_trailing))); +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadLeadingNOr(VFromD no, + VFromD load_mask, D d, + const TFromD* HWY_RESTRICT p, + VFromD v_trailing) { + using DI32 = Repartition; + const FixedTag di32_full; + + // ResizeBitCast of load_mask to di32 is okay below if + // d.MaxBytes() < di32.MaxBytes() is true as any lanes of load_mask.raw past + // the first (lowest-index) lanes of load_mask.raw will have already been + // zeroed out by FirstN. + return ResizeBitCast( + d, IfNegativeThenElse( + ResizeBitCast(di32_full, load_mask), + MaskedLoadOr(ResizeBitCast(di32_full, no), + MaskFromVec(ResizeBitCast(di32_full, load_mask)), + di32_full, reinterpret_cast(p)), + ResizeBitCast(di32_full, v_trailing))); +} + +// Single lane: load or default value. +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingN(VFromD /*load_mask*/, D d, + const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + return (num_lanes > 0) ? 
LoadU(d, p) : Zero(d); +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingNOr( + VFromD no, VFromD /*load_mask*/, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + return (num_lanes > 0) ? LoadU(d, p) : no; +} + +// Two lanes: load 1, 2, or default. +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingN(VFromD /*load_mask*/, D d, + const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if (num_lanes > 1) { + return LoadU(d, p); + } else { + const FixedTag, 1> d1; + return (num_lanes == 1) ? ResizeBitCast(d, LoadU(d1, p)) : Zero(d); + } +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingNOr( + VFromD no, VFromD /*load_mask*/, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if (num_lanes > 1) { + return LoadU(d, p); + } else { + if (num_lanes == 0) return no; + // Load one, upper lane is default. + const FixedTag, 1> d1; + return InterleaveLower(ResizeBitCast(d, LoadU(d1, p)), no); + } +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingN(VFromD load_mask, D d, + const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const size_t trailing_n = num_lanes & 3; + if (trailing_n == 0) return Zero(d); + + VFromD v_trailing = And(load_mask, Set(d, p[num_lanes - 1])); + + if ((trailing_n & 2) != 0) { + const Repartition di16; + int16_t i16_bits; + CopyBytes(p + num_lanes - trailing_n, &i16_bits); + v_trailing = BitCast( + d, IfNegativeThenElse(BitCast(di16, load_mask), Set(di16, i16_bits), + BitCast(di16, v_trailing))); + } + + return v_trailing; +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingNOr( + VFromD no, VFromD load_mask, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const size_t trailing_n = num_lanes & 3; + if (trailing_n == 0) return no; + + VFromD v_trailing = IfVecThenElse(load_mask, Set(d, p[num_lanes - 1]), no); + + if ((trailing_n & 2) != 0) { + const Repartition di16; + int16_t i16_bits; + CopyBytes(p + num_lanes - trailing_n, &i16_bits); + v_trailing = BitCast( + d, IfNegativeThenElse(BitCast(di16, load_mask), Set(di16, i16_bits), + BitCast(di16, v_trailing))); + } + + return v_trailing; +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingN(VFromD load_mask, D d, + const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if ((num_lanes & 1) != 0) { + return And(load_mask, Set(d, p[num_lanes - 1])); + } else { + return Zero(d); + } +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingNOr( + VFromD no, VFromD load_mask, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if ((num_lanes & 1) != 0) { + return IfVecThenElse(load_mask, Set(d, p[num_lanes - 1]), no); + } else { + return no; + } +} + +} // namespace detail + +// Generic for all vector lengths. +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, size_t N) { + const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> + d_full; + + const VFromD load_mask = + ResizeBitCast(d, VecFromMask(d_full, FirstN(d_full, N))); + const size_t num_lanes = HWY_MIN(N, HWY_MAX_LANES_D(D)); + const VFromD v_trailing = + detail::AVX2UIF8Or16LoadTrailingN(load_mask, d, p, num_lanes); + +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(num_lanes < (4 / sizeof(TFromD))) && + num_lanes < (4 / sizeof(TFromD))) { + return v_trailing; + } +#endif + + return detail::AVX2UIF8Or16LoadLeadingN(load_mask, d, p, v_trailing); +} + +// Generic for all vector lengths. 
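+ // Usage sketch (illustrative): with float data[3] and a 4-lane d,
+ //   LoadN(d, data, 3)                 == {data[0], data[1], data[2], 0}
+ //   LoadNOr(Set(d, 1.0f), d, data, 3) == {data[0], data[1], data[2], 1.0f}
+ // and neither accesses memory at or beyond data + 3.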
+template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t N) { + const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> + d_full; + + const VFromD load_mask = + ResizeBitCast(d, VecFromMask(d_full, FirstN(d_full, N))); + const size_t num_lanes = HWY_MIN(N, HWY_MAX_LANES_D(D)); + const VFromD v_trailing = + detail::AVX2UIF8Or16LoadTrailingNOr(no, load_mask, d, p, num_lanes); + +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(num_lanes < (4 / sizeof(TFromD))) && + num_lanes < (4 / sizeof(TFromD))) { + return v_trailing; + } +#endif + + return detail::AVX2UIF8Or16LoadLeadingNOr(no, load_mask, d, p, v_trailing); +} + +#endif // HWY_TARGET > HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX2 && !HWY_MEM_OPS_MIGHT_FAULT + +// ------------------------------ BlendedStore + +namespace detail { + +// There is no maskload_epi8/16 with which we could safely implement +// BlendedStore. Manual blending is also unsafe because loading a full vector +// that crosses the array end causes asan faults. Resort to scalar code; the +// caller should instead use memcpy, assuming m is FirstN(d, n). +template +HWY_API void ScalarMaskedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToSigned di; // for testing mask if T=bfloat16_t. + using TI = TFromD; + alignas(16) TI buf[MaxLanes(d)]; + alignas(16) TI mask[MaxLanes(d)]; + Store(BitCast(di, v), di, buf); + Store(BitCast(di, VecFromMask(d, m)), di, mask); + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm_mask_storeu_epi8(p, m.raw, v.raw); +} +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + _mm_mask_storeu_epi16(reinterpret_cast(p), RebindMask(du, m).raw, + BitCast(du, v).raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_mask_storeu_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_mask_storeu_epi64(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D, float* HWY_RESTRICT p) { + _mm_mask_storeu_ps(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D, double* HWY_RESTRICT p) { + _mm_mask_storeu_pd(p, m.raw, v.raw); +} + +#elif HWY_TARGET == HWY_AVX2 + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + detail::ScalarMaskedStore(v, m, d, p); +} + +namespace detail { + +template +HWY_INLINE void NativeBlendedStore(V v, M m, TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_maskstore_epi32(pi, m.raw, v.raw); +} + +template +HWY_INLINE void NativeBlendedStore(V v, M m, TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_maskstore_epi64(pi, m.raw, v.raw); +} + +template +HWY_INLINE void NativeBlendedStore(V v, M m, float* HWY_RESTRICT p) { + _mm_maskstore_ps(p, m.raw, v.raw); +} + +template +HWY_INLINE void NativeBlendedStore(V v, M m, double* HWY_RESTRICT p) { + _mm_maskstore_pd(p, m.raw, v.raw); +} + +} // namespace detail + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + 
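+ // AVX2 only provides maskstore for 32/64-bit lanes (_mm_maskstore_epi32/epi64,
+ // _mm_maskstore_ps/pd); 8/16-bit lanes take the ScalarMaskedStore path above.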
const RebindToSigned di; + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (d.MaxBytes() < 16) { + const Full128> dfull; + const Mask128> mfull{m.raw}; + m = MFromD{And(mfull, FirstN(dfull, MaxLanes(d))).raw}; + } + + // Float/double require, and unsigned ints tolerate, signed int masks. + detail::NativeBlendedStore(v, RebindMask(di, m), p); +} + +#else // <= SSE4 + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + // Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). + detail::ScalarMaskedStore(v, m, d, p); +} + +#endif // SSE4 + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) 
{ + return Vec128{_mm_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ AddSub + +#if HWY_TARGET <= HWY_SSSE3 + +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) \ + HWY_IF_V_SIZE_GT_V( \ + V, ((hwy::IsFloat3264>()) ? 32 : sizeof(TFromV))) + +template +HWY_API Vec128 AddSub(Vec128 a, Vec128 b) { + return Vec128{_mm_addsub_ps(a.raw, b.raw)}; +} +HWY_API Vec128 AddSub(Vec128 a, Vec128 b) { + return Vec128{_mm_addsub_pd(a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_SSSE3 + +// ------------------------------ PairwiseAdd128/PairwiseSub128 + +// Need to use the default implementation of PairwiseAdd128/PairwiseSub128 in +// generic_ops-inl.h for U8/I8/F16/I64/U64 vectors and 64-byte vectors + +#if HWY_TARGET <= HWY_SSSE3 + +#undef HWY_IF_PAIRWISE_ADD_128_D +#undef HWY_IF_PAIRWISE_SUB_128_D +#define HWY_IF_PAIRWISE_ADD_128_D(D) \ + hwy::EnableIf<( \ + HWY_MAX_LANES_D(D) > (32 / sizeof(hwy::HWY_NAMESPACE::TFromD)) || \ + (HWY_MAX_LANES_D(D) > (8 / sizeof(hwy::HWY_NAMESPACE::TFromD)) && \ + !(hwy::IsSameEither, int16_t, \ + uint16_t>() || \ + sizeof(hwy::HWY_NAMESPACE::TFromD) == 4 || \ + hwy::IsSame, double>())))>* = nullptr +#define HWY_IF_PAIRWISE_SUB_128_D(D) HWY_IF_PAIRWISE_ADD_128_D(D) + +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_hadd_epi16(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, Neg(BitCast(di, VFromD{_mm_hsub_epi16(a.raw, b.raw)}))); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_hadd_epi32(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, Neg(BitCast(di, VFromD{_mm_hsub_epi32(a.raw, b.raw)}))); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_hadd_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + return Neg(VFromD{_mm_hsub_ps(a.raw, b.raw)}); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_hadd_pd(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + return Neg(VFromD{_mm_hsub_pd(a.raw, b.raw)}); +} + +#endif // HWY_TARGET <= HWY_SSSE3 + +// ------------------------------ SumsOf8 +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128{_mm_sad_epu8(v.raw, _mm_setzero_si128())}; +} + +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOf8(V v) { + const DFromV d; + const RebindToUnsigned du; + const Repartition di64; + + // Adjust the values of v to be in the 0..255 range by adding 128 to each lane + // of v (which is the same as an bitwise XOR of each i8 lane by 128) and then + // bitcasting the Xor result to an u8 vector. + const auto v_adj = BitCast(du, Xor(v, SignBit(d))); + + // Need to add -1024 to each i64 lane of the result of the SumsOf8(v_adj) + // operation to account for the adjustment made above. 
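+ // (Each group of 8 lanes was biased by +128, so every u64 sum carries a bias
+ // of 8 * 128 = 1024, hence the -1024 below.)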
+ return BitCast(di64, SumsOf8(v_adj)) + Set(di64, int64_t{-1024}); +} + +#ifdef HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#endif + +template +HWY_API Vec128 SumsOf8AbsDiff(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sad_epu8(a.raw, b.raw)}; +} + +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOf8AbsDiff(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWideX3 di64; + + // Adjust the values of a and b to be in the 0..255 range by adding 128 to + // each lane of a and b (which is the same as an bitwise XOR of each i8 lane + // by 128) and then bitcasting the results of the Xor operations to u8 + // vectors. + const auto i8_msb = SignBit(d); + const auto a_adj = BitCast(du, Xor(a, i8_msb)); + const auto b_adj = BitCast(du, Xor(b, i8_msb)); + + // The result of SumsOf8AbsDiff(a_adj, b_adj) can simply be bitcasted to an + // i64 vector as |(a[i] + 128) - (b[i] + 128)| == |a[i] - b[i]| is true + return BitCast(di64, SumsOf8AbsDiff(a_adj, b_adj)); +} + +// ------------------------------ SumsOf4 +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template +HWY_INLINE Vec128 SumsOf4( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, + Vec128 v) { + const DFromV d; + + // _mm_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be + // zeroed out and the sums of the 4 consecutive lanes are already in the + // even uint16_t lanes of the _mm_maskz_dbsad_epu8 result. + return Vec128{ + _mm_maskz_dbsad_epu8(static_cast<__mmask8>(0x55), v.raw, Zero(d).raw, 0)}; +} + +// detail::SumsOf4 for Vec128 on AVX3 is implemented in x86_512-inl.h + +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ SumsOfAdjQuadAbsDiff + +#if HWY_TARGET <= HWY_SSE4 +#ifdef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#endif + +template +HWY_API Vec128 SumsOfAdjQuadAbsDiff( + Vec128 a, Vec128 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + return Vec128{ + _mm_mpsadbw_epu8(a.raw, b.raw, (kAOffset << 2) | kBOffset)}; +} + +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOfAdjQuadAbsDiff(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + // Adjust the values of a and b to be in the 0..255 range by adding 128 to + // each lane of a and b (which is the same as an bitwise XOR of each i8 lane + // by 128) and then bitcasting the results of the Xor operations to u8 + // vectors. + const auto i8_msb = SignBit(d); + const auto a_adj = BitCast(du, Xor(a, i8_msb)); + const auto b_adj = BitCast(du, Xor(b, i8_msb)); + + // The result of SumsOfAdjQuadAbsDiff(a_adj, b_adj) can + // simply be bitcasted to an i16 vector as + // |(a[i] + 128) - (b[i] + 128)| == |a[i] - b[i]| is true. 
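+ // Each sum is at most 4 * 255 = 1020, so the u16 results also fit in i16 and
+ // the BitCast below does not change their value.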
+ return BitCast(dw, SumsOfAdjQuadAbsDiff(a_adj, b_adj)); +} +#endif + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if HWY_TARGET <= HWY_AVX3 +#ifdef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#endif + +template +HWY_API Vec128 SumsOfShuffledQuadAbsDiff( + Vec128 a, Vec128 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + return Vec128{ + _mm_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; +} + +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOfShuffledQuadAbsDiff(V a, + V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + // Adjust the values of a and b to be in the 0..255 range by adding 128 to + // each lane of a and b (which is the same as an bitwise XOR of each i8 lane + // by 128) and then bitcasting the results of the Xor operations to u8 + // vectors. + const auto i8_msb = SignBit(d); + const auto a_adj = BitCast(du, Xor(a, i8_msb)); + const auto b_adj = BitCast(du, Xor(b, i8_msb)); + + // The result of + // SumsOfShuffledQuadAbsDiff(a_adj, b_adj) can + // simply be bitcasted to an i16 vector as + // |(a[i] + 128) - (b[i] + 128)| == |a[i] - b[i]| is true. + return BitCast( + dw, SumsOfShuffledQuadAbsDiff(a_adj, b_adj)); +} +#endif + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi16(a.raw, b.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +template +HWY_API Vec128 SaturatedAdd(Vec128 a, + Vec128 b) { + const DFromV d; + const auto sum = a + b; + const auto overflow_mask = MaskFromVec( + Vec128{_mm_ternarylogic_epi32(a.raw, b.raw, sum.raw, 0x42)}); + const auto i32_max = Set(d, LimitsMax()); + const Vec128 overflow_result{_mm_mask_ternarylogic_epi32( + i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, sum); +} + +template +HWY_API Vec128 SaturatedAdd(Vec128 a, + Vec128 b) { + const DFromV d; + const auto sum = a + b; + const auto overflow_mask = MaskFromVec( + Vec128{_mm_ternarylogic_epi64(a.raw, b.raw, sum.raw, 0x42)}); + const auto i64_max = Set(d, LimitsMax()); + const Vec128 overflow_result{_mm_mask_ternarylogic_epi64( + i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, sum); +} +#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + +// 
------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi16(a.raw, b.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN +template +HWY_API Vec128 SaturatedSub(Vec128 a, + Vec128 b) { + const DFromV d; + const auto diff = a - b; + const auto overflow_mask = MaskFromVec( + Vec128{_mm_ternarylogic_epi32(a.raw, b.raw, diff.raw, 0x18)}); + const auto i32_max = Set(d, LimitsMax()); + const Vec128 overflow_result{_mm_mask_ternarylogic_epi32( + i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, diff); +} + +template +HWY_API Vec128 SaturatedSub(Vec128 a, + Vec128 b) { + const DFromV d; + const auto diff = a - b; + const auto overflow_mask = MaskFromVec( + Vec128{_mm_ternarylogic_epi64(a.raw, b.raw, diff.raw, 0x18)}); + const auto i64_max = Set(d, LimitsMax()); + const Vec128 overflow_result{_mm_mask_ternarylogic_epi64( + i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, diff); +} +#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + +// ------------------------------ AverageRound + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu16(a.raw, b.raw)}; +} + +// I8/I16 AverageRound is generic for all vector lengths +template +HWY_API V AverageRound(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const V sign_bit = SignBit(d); + return Xor(BitCast(d, AverageRound(BitCast(du, Xor(a, sign_bit)), + BitCast(du, Xor(b, sign_bit)))), + sign_bit); +} + +// ------------------------------ Integer multiplication + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epu16(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epi16(a.raw, b.raw)}; +} + +template , 1)> +HWY_API V MulHigh(V a, V b) { + const DFromV d; + const Full128> d_full; + return ResizeBitCast( + d, Slide1Down(d_full, ResizeBitCast(d_full, MulEven(a, b)))); +} + +// I8/U8/I32/U32 MulHigh is generic for all vector lengths >= 2 lanes +template , 1)> +HWY_API V MulHigh(V a, V b) { + const DFromV d; + + const auto p_even = BitCast(d, MulEven(a, b)); + const auto p_odd = BitCast(d, MulOdd(a, b)); + return InterleaveOdd(d, p_even, p_odd); +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. 
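+ // Usage sketch (illustrative): for u32 inputs a = {a0, a1, a2, a3} and
+ // b = {b0, b1, b2, b3}, MulEven(a, b) returns the two u64 products
+ // {a0 * b0, a2 * b2}; MulOdd (below) returns {a1 * b1, a3 * b3}.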
+template )> +HWY_API VFromD>> MulEven(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + const auto lo8_mask = Set(dw, uint16_t{0x00FF}); + return And(ResizeBitCast(dw, a), lo8_mask) * + And(ResizeBitCast(dw, b), lo8_mask); +} + +template )> +HWY_API VFromD>> MulEven(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + return ShiftRight<8>(ShiftLeft<8>(ResizeBitCast(dw, a))) * + ShiftRight<8>(ShiftLeft<8>(ResizeBitCast(dw, b))); +} + +template )> +HWY_API VFromD>> MulEven(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + const RepartitionToNarrow dw_as_d16; + + const auto lo = ResizeBitCast(dw, a * b); + const auto hi = ShiftLeft<16>(ResizeBitCast(dw, MulHigh(a, b))); + return BitCast(dw, OddEven(BitCast(dw_as_d16, hi), BitCast(dw_as_d16, lo))); +} + +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_epu32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + const DFromV d; + const RepartitionToWide dw; + const RebindToUnsigned du; + + // p[i] = (((a[i] >> 31) * (a[i] >> 31)) << 64) + + // (((a[i] >> 31) * b[i]) << 32) + + // (((b[i] >> 31) * a[i]) << 32) + + // ((a[i] & int64_t{0xFFFFFFFF}) * (b[i] & int64_t{0xFFFFFFFF})) + + // ((a[i] >> 31) * (a[i] >> 31)) << 64 does not need to be computed as the + // lower 64 bits of ((a[i] >> 31) * (a[i] >> 31)) << 64 is zero. + + // (((a[i] >> 31) * b[i]) << 32) + (((b[i] >> 31) * a[i]) << 32) == + // -((((a[i] >> 31) & b[i]) + ((b[i] >> 31) & a[i])) << 32) + + // ((a[i] & int64_t{0xFFFFFFFF}) * (b[i] & int64_t{0xFFFFFFFF})) can be + // computed using MulEven(BitCast(du, a), BitCast(du, b)) + + const auto neg_p_hi = ShiftLeft<32>( + ResizeBitCast(dw, And(ShiftRight<31>(a), b) + And(ShiftRight<31>(b), a))); + const auto p_lo = BitCast(dw, MulEven(BitCast(du, a), BitCast(du, b))); + return p_lo - neg_p_hi; +#else + return Vec128{_mm_mul_epi32(a.raw, b.raw)}; +#endif +} + +template +HWY_API VFromD>> MulOdd(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + return ShiftRight<8>(ResizeBitCast(dw, a)) * + ShiftRight<8>(ResizeBitCast(dw, b)); +} + +template )> +HWY_API VFromD>> MulOdd(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + const RebindToUnsigned dw_u; + const RepartitionToNarrow dw_as_d16; + + const auto lo = ShiftRight<16>(BitCast(dw_u, ResizeBitCast(dw, a * b))); + const auto hi = ResizeBitCast(dw, MulHigh(a, b)); + return BitCast(dw, OddEven(BitCast(dw_as_d16, hi), BitCast(dw_as_d16, lo))); +} + +template )> +HWY_API VFromD>> MulOdd(V a, V b) { + return MulEven(DupOdd(a), DupOdd(b)); +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + // Not as inefficient as it looks: _mm_mullo_epi32 has 10 cycle latency. + // 64-bit right shift would also work but also needs port 5, so no benefit. + // Notation: x=don't care, z=0. + const __m128i a_x3x1 = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x2x0 = MulEven(a, b); + const __m128i b_x3x1 = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x3x1 = + MulEven(Vec128{a_x3x1}, Vec128{b_x3x1}); + // We could _mm_slli_epi64 by 32 to get 3z1z and OR with z2z0, but generating + // the latter requires one more instruction or a constant. 
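+ // Instead, gather the low 32 bits of the 64-bit products (lanes 0,2 and 1,3)
+ // via _MM_SHUFFLE(2, 0, 2, 0), then re-interleave them into lane order 0..3.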
+ const __m128i mul_20 = + _mm_shuffle_epi32(mullo_x2x0.raw, _MM_SHUFFLE(2, 0, 2, 0)); + const __m128i mul_31 = + _mm_shuffle_epi32(mullo_x3x1.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_epi32(mul_20, mul_31)}; +#else + return Vec128{_mm_mullo_epi32(a.raw, b.raw)}; +#endif +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + // Same as unsigned; avoid duplicating the SSSE3 code. + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) * BitCast(du, b)); +} + +#if HWY_TARGET <= HWY_AVX3 +// Per-target flag to prevent generic_ops-inl.h from defining 64-bit operator*. +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +template +HWY_API Vec128 operator*(Vec128 a, + Vec128 b) { + return Vec128{_mm_mullo_epi64(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(Vec128 a, + Vec128 b) { + return Vec128{_mm_mullo_epi64(a.raw, b.raw)}; +} +#endif + +// ------------------------------ RotateRight (ShiftRight, Or) + +// U8 RotateRight implementation on AVX3_DL is now in x86_512-inl.h as U8 +// RotateRight uses detail::GaloisAffine on AVX3_DL + +#if HWY_TARGET > HWY_AVX3_DL +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + if (kBits == 0) return v; + // AVX3 does not support 8-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +} +#endif + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); + if (kBits == 0) return v; +#if HWY_TARGET <= HWY_AVX3_DL + return Vec128{_mm_shrdi_epi16(v.raw, v.raw, kBits)}; +#else + // AVX3 does not support 16-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +// I8/I16/I32/I64 RotateRight is generic for all vector lengths +template +HWY_API V RotateRight(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, RotateRight(BitCast(du, v))); +} + +// ------------------------------ Rol/Ror +#if HWY_TARGET <= HWY_AVX3_DL +#ifdef HWY_NATIVE_ROL_ROR_16 +#undef HWY_NATIVE_ROL_ROR_16 +#else +#define HWY_NATIVE_ROL_ROR_16 +#endif + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{_mm_shrdv_epi16(a.raw, a.raw, b.raw)}; +} + +// U16/I16 Rol is generic for all vector lengths on AVX3_DL +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + return Ror(a, BitCast(d, Neg(BitCast(di, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_ROL_ROR_32_64 +#undef HWY_NATIVE_ROL_ROR_32_64 +#else +#define HWY_NATIVE_ROL_ROR_32_64 +#endif + +template +HWY_API Vec128 Rol(Vec128 a, Vec128 b) { + return Vec128{_mm_rolv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{_mm_rorv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Rol(Vec128 a, Vec128 b) { + return Vec128{_mm_rolv_epi64(a.raw, b.raw)}; +} + +template +HWY_API 
Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{_mm_rorv_epi64(a.raw, b.raw)}; +} + +#endif + +// ------------------------------ RotateLeftSame/RotateRightSame + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_ROL_ROR_SAME_16 +#undef HWY_NATIVE_ROL_ROR_SAME_16 +#else +#define HWY_NATIVE_ROL_ROR_SAME_16 +#endif + +// Generic for all vector lengths +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + return Ror(v, + Set(d, static_cast>(0u - static_cast(bits)))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + return Ror(v, Set(d, static_cast>(bits))); +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64 +#undef HWY_NATIVE_ROL_ROR_SAME_32_64 +#else +#define HWY_NATIVE_ROL_ROR_SAME_32_64 +#endif + +// Generic for all vector lengths +template +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + return Rol(v, Set(d, static_cast>(static_cast(bits)))); +} + +template +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + return Ror(v, Set(d, static_cast>(static_cast(bits)))); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; + return VecFromMask(v < Zero(d)); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<15>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<31>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Vec128{_mm_srai_epi64(v.raw, 63)}; +#elif HWY_TARGET == HWY_AVX2 || HWY_TARGET == HWY_SSE4 + return VecFromMask(v < Zero(d)); +#else + // Efficient Lt() requires SSE4.2 and BLENDVPD requires SSE4.1. 32-bit shift + // avoids generating a zero. + const RepartitionToNarrow d32; + const auto sign = ShiftRight<31>(BitCast(d32, v)); + return Vec128{ + _mm_shuffle_epi32(sign.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +#endif +} + +// ------------------------------ Integer Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. 
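+ // e.g. for int8_t, Abs(-128) remains -128 because +128 is not representable;
+ // SaturatedAbs below clamps that case to +127 instead.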
+template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_COMPILER_MSVC || HWY_TARGET == HWY_SSE2 + const DFromV d; + const RebindToUnsigned du; + const auto zero = Zero(du); + const auto v_as_u8 = BitCast(du, v); + return BitCast(d, Min(v_as_u8, zero - v_as_u8)); +#else + return Vec128{_mm_abs_epi8(v.raw)}; +#endif +} + +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_TARGET == HWY_SSE2 + const auto zero = Zero(DFromV()); + return Max(v, zero - v); +#else + return Vec128{_mm_abs_epi16(v.raw)}; +#endif +} + +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_TARGET <= HWY_SSSE3 + return Vec128{_mm_abs_epi32(v.raw)}; +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{_mm_abs_epi64(v.raw)}; +} +#else +// I64 Abs is generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template )> +HWY_API V Abs(V v) { + const auto zero = Zero(DFromV()); + return IfNegativeThenElse(v, zero - v, v); +} +#endif + +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +// Generic for all vector lengths +template )> +HWY_API V SaturatedAbs(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Min(BitCast(du, v), BitCast(du, SaturatedSub(Zero(d), v)))); +} + +// Generic for all vector lengths +template )> +HWY_API V SaturatedAbs(V v) { + return Max(v, SaturatedSub(Zero(DFromV()), v)); +} + +// Generic for all vector lengths +template )> +HWY_API V SaturatedAbs(V v) { + const auto abs_v = Abs(v); + +#if HWY_TARGET <= HWY_SSE4 + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Min(BitCast(du, abs_v), + Set(du, static_cast(LimitsMax())))); +#else + return Add(abs_v, BroadcastSignBit(abs_v)); +#endif +} + +// Generic for all vector lengths +template )> +HWY_API V SaturatedAbs(V v) { + const auto abs_v = Abs(v); + return Add(abs_v, BroadcastSignBit(abs_v)); +} + +// GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512VL +// srli_epi64: the count should be unsigned int. Note that this is not the same +// as the Shift3264Count in x86_512-inl.h (GCC also requires int). +#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400) +using Shift64Count = int; +#else +// Assume documented behavior. Clang 12, GCC 14 and MSVC 14.28.29910 match this. 
+using Shift64Count = unsigned int; +#endif + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{ + _mm_srai_epi64(v.raw, static_cast(kBits))}; +#else + const DFromV di; + const RebindToUnsigned du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +// ------------------------------ IfNegativeThenElse +template +HWY_API Vec128 IfNegativeThenElse(const Vec128 v, + const Vec128 yes, + const Vec128 no) { +// int8: IfThenElse only looks at the MSB on SSE4 or newer +#if HWY_TARGET <= HWY_SSE4 + const auto mask = MaskFromVec(v); +#else + const DFromV d; + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#endif + + return IfThenElse(mask, yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + +// 16-bit: no native blendv on AVX2 or earlier, so copy sign to lower byte's +// MSB. +#if HWY_TARGET <= HWY_AVX3 + const auto mask = MaskFromVec(v); +#else + const DFromV d; + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#endif + + return IfThenElse(mask, yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + // 32/64-bit: use float IfThenElse on SSE4/AVX2, which only looks at the MSB + // on SSE4 or later. + const RebindToFloat df; + const auto mask = MaskFromVec(BitCast(df, v)); + return BitCast(d, IfThenElse(mask, BitCast(df, yes), BitCast(df, no))); +#else // SSE2, SSSE3, or AVX3 + +#if HWY_TARGET <= HWY_AVX3 + // No need to cast to float or broadcast sign bit on AVX3 as IfThenElse only + // looks at the MSB on AVX3 + (void)d; + const auto mask = MaskFromVec(v); +#else + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#endif + + return IfThenElse(mask, yes, no); +#endif +} + +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + +#ifdef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#undef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#else +#define HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#endif + +#ifdef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#undef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#else +#define HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#endif + +// SSE4/AVX2 IfNegativeThenElseZero/IfNegativeThenZeroElse is generic for all +// vector lengths +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + const DFromV d; + return IfNegativeThenElse(v, yes, Zero(d)); +} + +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + return IfThenElseZero(IsNegative(v), yes); +} + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + const DFromV d; + return IfNegativeThenElse(v, Zero(d), no); +} + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + return IfThenZeroElse(IsNegative(v), no); +} + +#endif // HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +#if HWY_TARGET <= HWY_SSSE3 + +#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#else +#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#endif + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero(Vec128 mask, + Vec128 v) { + return Vec128{_mm_sign_epi8(v.raw, 
mask.raw)}; +} + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero( + Vec128 mask, Vec128 v) { + return Vec128{_mm_sign_epi16(v.raw, mask.raw)}; +} + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero( + Vec128 mask, Vec128 v) { + return Vec128{_mm_sign_epi32(v.raw, mask.raw)}; +} + +// Generic for all vector lengths +template )> +HWY_API V IfNegativeThenNegOrUndefIfZero(V mask, V v) { +#if HWY_TARGET <= HWY_AVX3 + // MaskedSubOr is more efficient than IfNegativeThenElse on AVX3 + const DFromV d; + return MaskedSubOr(v, MaskFromVec(mask), Zero(d), v); +#else + // IfNegativeThenElse is more efficient than MaskedSubOr on SSE4/AVX2 + return IfNegativeThenElse(mask, Neg(v), v); +#endif +} + +#endif // HWY_TARGET <= HWY_SSSE3 + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi16(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi32(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi64(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi16(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi32(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi64(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. 
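+ // Shift as u16 lanes; within each 16-bit lane the upper byte picks up bits
+ // from the lower byte, so mask with (0xFF << bits) & 0xFF to clear any bits
+ // that crossed the byte boundary.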
+ const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_srli_epi16(v.raw, bits)}; + } +#endif + return Vec128{_mm_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_srli_epi32(v.raw, bits)}; + } +#endif + return Vec128{_mm_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_srli_epi64(v.raw, bits)}; + } +#endif + return Vec128{_mm_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, + const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_srai_epi16(v.raw, bits)}; + } +#endif + return Vec128{_mm_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_srai_epi32(v.raw, bits)}; + } +#endif + return Vec128{_mm_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{ + _mm_srai_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec128{_mm_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const DFromV di; + const RebindToUnsigned du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Floating-point mul / div + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator*(Vec128 a, + Vec128 b) { + return Vec128{_mm_mul_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{_mm_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_pd(a.raw, b.raw)}; +} +HWY_API Vec64 operator*(const Vec64 a, const Vec64 b) { + return Vec64{_mm_mul_sd(a.raw, b.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_MUL_BY_POW2 +#undef HWY_NATIVE_MUL_BY_POW2 +#else +#define HWY_NATIVE_MUL_BY_POW2 +#endif + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MulByFloorPow2(Vec128 a, + Vec128 b) { + return Vec128{_mm_scalef_ph(a.raw, 
b.raw)}; +} +#endif + +template +HWY_API Vec128 MulByFloorPow2(Vec128 a, + Vec128 b) { + return Vec128{_mm_scalef_ps(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulByFloorPow2(Vec128 a, + Vec128 b) { + return Vec128{_mm_scalef_pd(a.raw, b.raw)}; +} + +// MulByPow2 is generic for all vector lengths on AVX3 +template +HWY_API V MulByPow2(V v, VFromD>> exp) { + const DFromV d; + return MulByFloorPow2(v, ConvertTo(d, exp)); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_pd(a.raw, b.raw)}; +} +HWY_API Vec64 operator/(const Vec64 a, const Vec64 b) { + return Vec64{_mm_div_sd(a.raw, b.raw)}; +} + +// Approximate reciprocal +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 ApproximateReciprocal( + const Vec128 v) { + return Vec128{_mm_rcp_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ss(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +HWY_API Vec128 ApproximateReciprocal(Vec128 v) { + return Vec128{_mm_rcp14_pd(v.raw)}; +} +HWY_API Vec64 ApproximateReciprocal(Vec64 v) { + return Vec64{_mm_rcp14_sd(v.raw, v.raw)}; +} +#endif + +// Generic for all vector lengths. 
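+// AbsDiff(a, b) computes |a - b| per lane; e.g. AbsDiff(Set(d, 3.0f),
+// Set(d, 5.0f)) equals Set(d, 2.0f).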
+template +HWY_API V AbsDiff(V a, V b) { + return Abs(a - b); +} + +// ------------------------------ GetExponent + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_GET_EXPONENT +#undef HWY_NATIVE_GET_EXPONENT +#else +#define HWY_NATIVE_GET_EXPONENT +#endif + +#if HWY_HAVE_FLOAT16 +template ), HWY_IF_V_SIZE_LE_V(V, 16)> +HWY_API V GetExponent(V v) { + return V{_mm_getexp_ph(v.raw)}; +} +#endif +template ), HWY_IF_V_SIZE_LE_V(V, 16)> +HWY_API V GetExponent(V v) { + return V{_mm_getexp_ps(v.raw)}; +} +template ), HWY_IF_V_SIZE_LE_V(V, 16)> +HWY_API V GetExponent(V v) { + return V{_mm_getexp_pd(v.raw)}; +} + +#endif + +// ------------------------------ MaskedMinOr + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMaxOr + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 
MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedAddOr + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSubOr + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMulOr + +// There are no elementwise integer mask_mul. Generic for all vector lengths. 
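+// The fallback below therefore synthesizes it via IfThenElse: lanes where m is
+// true receive a * b, all other lanes receive the corresponding lane of no.
+// Illustrative usage (m supplied by the caller):
+//   const auto r = MaskedMulOr(no, m, a, b);  // r[i] = m[i] ? a[i]*b[i] : no[i]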
+template +HWY_API V MaskedMulOr(V no, M m, V a, V b) { + return IfThenElse(m, a * b, no); +} + +template +HWY_API Vec128 MaskedMulOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMulOr(Vec128 no, + Mask128 m, Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedMulOr(Vec128 no, + Mask128 m, + Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedDivOr + +template +HWY_API Vec128 MaskedDivOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedDivOr(Vec128 no, + Mask128 m, Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedDivOr(Vec128 no, + Mask128 m, + Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// Generic for all vector lengths +template +HWY_API V MaskedDivOr(V no, MFromD> m, V a, V b) { + return IfThenElse(m, Div(a, b), no); +} + +// ------------------------------ MaskedModOr +// Generic for all vector lengths +template +HWY_API V MaskedModOr(V no, MFromD> m, V a, V b) { + return IfThenElse(m, Mod(a, b), no); +} + +// ------------------------------ MaskedSatAddOr + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ MaskedSatSubOr + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Floating-point multiply-add variants + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MulAdd(Vec128 mul, + Vec128 x, + Vec128 add) { + return Vec128{_mm_fmadd_ph(mul.raw, x.raw, add.raw)}; +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, + Vec128 x, + Vec128 add) { + return Vec128{_mm_fnmadd_ph(mul.raw, x.raw, add.raw)}; +} + +template +HWY_API Vec128 MulSub(Vec128 mul, + Vec128 x, + Vec128 sub) { + return Vec128{_mm_fmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +template +HWY_API Vec128 NegMulSub(Vec128 mul, + Vec128 x, + Vec128 sub) { + return Vec128{_mm_fnmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +#endif 
// HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return mul * x + add; +#else + return Vec128{_mm_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return mul * x + add; +#else + return Vec128{_mm_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return add - mul * x; +#else + return Vec128{_mm_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return add - mul * x; +#else + return Vec128{_mm_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return mul * x - sub; +#else + return Vec128{_mm_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return mul * x - sub; +#else + return Vec128{_mm_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_SSSE3 + +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), \ + HWY_IF_T_SIZE_ONE_OF_V( \ + V, (1 << 1) | ((hwy::IsFloat>()) \ + ? 
0 \ + : ((1 << 2) | (1 << 4) | (1 << 8)))) + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MulAddSub(Vec128 mul, + Vec128 x, + Vec128 sub_or_add) { + return Vec128{_mm_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Vec128 MulAddSub(Vec128 mul, Vec128 x, + Vec128 sub_or_add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return AddSub(mul * x, sub_or_add); +#else + return Vec128{_mm_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +HWY_API Vec128 MulAddSub(Vec128 mul, Vec128 x, + Vec128 sub_or_add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return AddSub(mul * x, sub_or_add); +#else + return Vec128{_mm_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +#endif // HWY_TARGET <= HWY_SSSE3 + +// ------------------------------ Floating-point square root + +// Full precision square root +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{_mm_sqrt_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{_mm_sqrt_ps(v.raw)}; +} +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{_mm_sqrt_ss(v.raw)}; +} +template +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{_mm_sqrt_pd(v.raw)}; +} +HWY_API Vec64 Sqrt(Vec64 v) { + return Vec64{_mm_sqrt_sd(_mm_setzero_pd(), v.raw)}; +} + +// Approximate reciprocal square root +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + return Vec128{_mm_rsqrt_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + return Vec128{_mm_rsqrt_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + return Vec128{_mm_rsqrt_ss(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +HWY_API Vec64 ApproximateReciprocalSqrt(Vec64 v) { + return Vec64{_mm_rsqrt14_sd(v.raw, v.raw)}; +} +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { +#if HWY_COMPILER_MSVC + const DFromV d; + return Vec128{_mm_mask_rsqrt14_pd( + Undefined(d).raw, static_cast<__mmask8>(0xFF), v.raw)}; +#else + return Vec128{_mm_rsqrt14_pd(v.raw)}; +#endif +} +#endif + +// ------------------------------ Min (Gt, IfThenElse) + +namespace detail { + +template +HWY_INLINE HWY_MAYBE_UNUSED Vec128 MinU(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +} + +} // namespace detail + +// Unsigned +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{_mm_min_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return Vec128{ + _mm_sub_epi16(a.raw, _mm_subs_epu16(a.raw, b.raw))}; +#else + return Vec128{_mm_min_epu16(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return detail::MinU(a, b); +#else + return Vec128{_mm_min_epu32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_min_epu64(a.raw, b.raw)}; +#else + return detail::MinU(a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return 
Vec128{_mm_min_epi8(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{_mm_min_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128{_mm_min_epi32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Min(Vec128 a, + Vec128 b) { + return Vec128{_mm_min_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{_mm_min_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{_mm_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +namespace detail { +template +HWY_INLINE HWY_MAYBE_UNUSED Vec128 MaxU(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +} + +} // namespace detail + +// Unsigned +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{_mm_max_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return Vec128{ + _mm_add_epi16(a.raw, _mm_subs_epu16(b.raw, a.raw))}; +#else + return Vec128{_mm_max_epu16(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return detail::MaxU(a, b); +#else + return Vec128{_mm_max_epu32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_max_epu64(a.raw, b.raw)}; +#else + return detail::MaxU(a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128{_mm_max_epi8(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{_mm_max_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128{_mm_max_epi32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Max(Vec128 a, + Vec128 b) { + return Vec128{_mm_max_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{_mm_max_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{_mm_max_pd(a.raw, b.raw)}; +} + +// ================================================== MEMORY (3) + +// ------------------------------ Non-temporal stores + +// On clang6, we see incorrect code generated for _mm_stream_pi, so +// round even partial vectors up to 16 bytes. 
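+// Illustrative usage sketch (the pointer must be vector-aligned; because the
+// store is non-temporal, callers typically issue FlushStream() after a batch
+// of Stream calls before other threads read the data):
+//   const Full128<float> df;
+//   alignas(16) float out[4];
+//   Stream(Set(df, 1.0f), df, out);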
+template +HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; // for float16_t + _mm_stream_si128(reinterpret_cast<__m128i*>(aligned), BitCast(du, v).raw); +} +template +HWY_API void Stream(VFromD v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm_stream_ps(aligned, v.raw); +} +template +HWY_API void Stream(VFromD v, D /* tag */, double* HWY_RESTRICT aligned) { + _mm_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unfortunately the GCC/Clang intrinsics do not accept int64_t*. +using GatherIndex64 = long long int; // NOLINT(runtime/int) +static_assert(sizeof(GatherIndex64) == 8, "Must be 64-bit type"); + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_SCATTER +#undef HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +namespace detail { + +template +HWY_INLINE void NativeScatter128(VFromD v, D d, TFromD* HWY_RESTRICT base, + VI index) { + if (d.MaxBytes() == 16) { + _mm_i32scatter_epi32(base, index.raw, v.raw, kScale); + } else { + const __mmask8 mask = (1u << MaxLanes(d)) - 1; + _mm_mask_i32scatter_epi32(base, mask, index.raw, v.raw, kScale); + } +} + +template +HWY_INLINE void NativeScatter128(VFromD v, D d, TFromD* HWY_RESTRICT base, + VI index) { + if (d.MaxBytes() == 16) { + _mm_i64scatter_epi64(base, index.raw, v.raw, kScale); + } else { + const __mmask8 mask = (1u << MaxLanes(d)) - 1; + _mm_mask_i64scatter_epi64(base, mask, index.raw, v.raw, kScale); + } +} + +template +HWY_INLINE void NativeScatter128(VFromD v, D d, float* HWY_RESTRICT base, + VI index) { + if (d.MaxBytes() == 16) { + _mm_i32scatter_ps(base, index.raw, v.raw, kScale); + } else { + const __mmask8 mask = (1u << MaxLanes(d)) - 1; + _mm_mask_i32scatter_ps(base, mask, index.raw, v.raw, kScale); + } +} + +template +HWY_INLINE void NativeScatter128(VFromD v, D d, double* HWY_RESTRICT base, + VI index) { + if (d.MaxBytes() == 16) { + _mm_i64scatter_pd(base, index.raw, v.raw, kScale); + } else { + const __mmask8 mask = (1u << MaxLanes(d)) - 1; + _mm_mask_i64scatter_pd(base, mask, index.raw, v.raw, kScale); + } +} + +template +HWY_INLINE void NativeMaskedScatter128(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT base, VI index) { + // For partial vectors, ensure upper mask lanes are zero to prevent faults. + if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + _mm_mask_i32scatter_epi32(base, m.raw, index.raw, v.raw, kScale); +} + +template +HWY_INLINE void NativeMaskedScatter128(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT base, VI index) { + // For partial vectors, ensure upper mask lanes are zero to prevent faults. + if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + _mm_mask_i64scatter_epi64(base, m.raw, index.raw, v.raw, kScale); +} + +template +HWY_INLINE void NativeMaskedScatter128(VFromD v, MFromD m, D d, + float* HWY_RESTRICT base, VI index) { + // For partial vectors, ensure upper mask lanes are zero to prevent faults. + if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + _mm_mask_i32scatter_ps(base, m.raw, index.raw, v.raw, kScale); +} + +template +HWY_INLINE void NativeMaskedScatter128(VFromD v, MFromD m, D d, + double* HWY_RESTRICT base, VI index) { + // For partial vectors, ensure upper mask lanes are zero to prevent faults. 
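+  // (For a partial vector, mask bits beyond Lanes(d) are unspecified and the
+  // corresponding index lanes are garbage, so a scatter through them could
+  // write to an invalid address.)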
+ if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + _mm_mask_i64scatter_pd(base, m.raw, index.raw, v.raw, kScale); +} + +} // namespace detail + +template +HWY_API void ScatterOffset(VFromD v, D d, TFromD* HWY_RESTRICT base, + VFromD> offset) { + return detail::NativeScatter128<1>(v, d, base, offset); +} +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, + VFromD> index) { + return detail::NativeScatter128)>(v, d, base, index); +} +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT base, + VFromD> index) { + return detail::NativeMaskedScatter128)>(v, m, d, base, + index); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Gather (Load/Store) + +#if HWY_TARGET <= HWY_AVX2 + +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +namespace detail { + +template +HWY_INLINE Vec128 NativeGather128(const T* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), indices.raw, kScale)}; +} + +template +HWY_INLINE Vec128 NativeGather128(const T* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), indices.raw, kScale)}; +} + +template +HWY_INLINE Vec128 NativeGather128(const float* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i32gather_ps(base, indices.raw, kScale)}; +} + +template +HWY_INLINE Vec128 NativeGather128(const double* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i64gather_pd(base, indices.raw, kScale)}; +} + +template +HWY_INLINE Vec128 NativeMaskedGatherOr128(Vec128 no, + Mask128 m, + const T* HWY_RESTRICT base, + Vec128 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_mmask_i32gather_epi32( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; +#else + return Vec128{ + _mm_mask_i32gather_epi32(no.raw, reinterpret_cast(base), + indices.raw, m.raw, kScale)}; +#endif +} + +template +HWY_INLINE Vec128 NativeMaskedGatherOr128(Vec128 no, + Mask128 m, + const T* HWY_RESTRICT base, + Vec128 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_mmask_i64gather_epi64( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; +#else + return Vec128{_mm_mask_i64gather_epi64( + no.raw, reinterpret_cast(base), indices.raw, m.raw, + kScale)}; +#endif +} + +template +HWY_INLINE Vec128 NativeMaskedGatherOr128( + Vec128 no, Mask128 m, const float* HWY_RESTRICT base, + Vec128 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{ + _mm_mmask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; +#else + return Vec128{ + _mm_mask_i32gather_ps(no.raw, base, indices.raw, m.raw, kScale)}; +#endif +} + +template +HWY_INLINE Vec128 NativeMaskedGatherOr128( + Vec128 no, Mask128 m, const double* HWY_RESTRICT base, + Vec128 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{ + _mm_mmask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; +#else + return Vec128{ + _mm_mask_i64gather_pd(no.raw, base, indices.raw, m.raw, kScale)}; +#endif +} + +} // namespace detail + +template +HWY_API VFromD GatherOffset(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> offsets) { + return detail::NativeGather128<1>(base, offsets); +} + +template > +HWY_API VFromD GatherIndex(D /*d*/, const T* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeGather128(base, indices); +} + +template > +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D d, + const T* HWY_RESTRICT base, + VFromD> indices) { + // For 
partial vectors, ensure upper mask lanes are zero to prevent faults. + if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + + return detail::NativeMaskedGatherOr128(no, m, base, indices); +} + +// Generic for all vector lengths. +template +HWY_API VFromD MaskedGatherIndex(MFromD m, D d, + const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return MaskedGatherIndexOr(Zero(d), m, d, base, indices); +} + +#endif // HWY_TARGET <= HWY_AVX2 + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE (2) + +// ------------------------------ LowerHalf + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{v.raw}; +} +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128{v.raw}; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const RebindToUnsigned du; + return BitCast( + d, VFromD{_mm_slli_si128(BitCast(du, v).raw, kBytes)}); +} + +// Generic for all vector lengths. +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +// Generic for all vector lengths. +template +HWY_API VFromD ShiftLeftLanes(D d, const VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes)>(BitCast(d8, v))); +} + +// Generic for all vector lengths. +template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const RebindToUnsigned du; + // For partial vectors, clear upper lanes so we shift in zeros. + if (d.MaxBytes() != 16) { + const Full128> dfull; + const VFromD vfull{v.raw}; + v = VFromD{IfThenElseZero(FirstN(dfull, MaxLanes(d)), vfull).raw}; + } + return BitCast( + d, VFromD{_mm_srli_si128(BitCast(du, v).raw, kBytes)}); +} + +// ------------------------------ ShiftRightLanes +// Generic for all vector lengths. +template +HWY_API VFromD ShiftRightLanes(D d, const VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input: copy hi into lo (smaller instruction encoding than shifts). +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + const Twice> dut; + using VUT = VFromD; // for float16_t + const VUT vut = BitCast(dut, v); + return BitCast(d, LowerHalf(VUT{_mm_unpackhi_epi64(vut.raw, vut.raw)})); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64{_mm_movehl_ps(v.raw, v.raw)}; +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64{_mm_unpackhi_pd(v.raw, v.raw)}; +} + +// Partial +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + return LowerHalf(d, ShiftRightBytes(Twice(), v)); +} + +// ------------------------------ ExtractLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + const int pair = _mm_extract_epi16(v.raw, kLane / 2); + constexpr int kShift = kLane & 1 ? 
8 : 0; + return static_cast((pair >> kShift) & 0xFF); +#else + return static_cast(_mm_extract_epi8(v.raw, kLane) & 0xFF); +#endif +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); + const DFromV d; + const RebindToUnsigned du; + const uint16_t lane = static_cast( + _mm_extract_epi16(BitCast(du, v).raw, kLane) & 0xFFFF); + return BitCastScalar(lane); +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + return static_cast(_mm_cvtsi128_si32( + (kLane == 0) ? v.raw : _mm_shuffle_epi32(v.raw, kLane))); +#else + return static_cast(_mm_extract_epi32(v.raw, kLane)); +#endif +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_ARCH_X86_32 + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#elif HWY_TARGET >= HWY_SSSE3 + return static_cast( + _mm_cvtsi128_si64((kLane == 0) ? v.raw : _mm_shuffle_epi32(v.raw, 0xEE))); +#else + return static_cast(_mm_extract_epi64(v.raw, kLane)); +#endif +} + +template +HWY_INLINE float ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + return _mm_cvtss_f32((kLane == 0) ? v.raw + : _mm_shuffle_ps(v.raw, v.raw, kLane)); +#else + // Bug in the intrinsic, returns int but should be float. + const int32_t bits = _mm_extract_ps(v.raw, kLane); + return BitCastScalar(bits); +#endif +} + +// There is no extract_pd; two overloads because there is no UpperHalf for N=1. +template +HWY_INLINE double ExtractLane(const Vec64 v) { + static_assert(kLane == 0, "Lane index out of bounds"); + return GetLane(v); +} + +template +HWY_INLINE double ExtractLane(const Vec128 v) { + static_assert(kLane < 2, "Lane index out of bounds"); + const Half> dh; + return kLane == 0 ? GetLane(v) : GetLane(UpperHalf(dh, v)); +} + +} // namespace detail + +// Requires one overload per vector length because ExtractLane<3> may be a +// compile error if it calls _mm_extract_epi64. 
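+// Illustrative usage (runtime index; when i is a compile-time constant,
+// GCC/Clang resolve the switch below to a single detail::ExtractLane<i> call):
+//   const Full128<uint32_t> d;
+//   const auto v = Iota(d, 0);                  // 0, 1, 2, 3
+//   const uint32_t third = ExtractLane(v, 2);   // == 2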
+template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return GetLane(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE V InsertLaneUsingBroadcastAndBlend(V v, size_t i, TFromV t) { + const DFromV d; + +#if HWY_TARGET <= HWY_AVX3 + using RawMask = decltype(MaskFromVec(VFromD()).raw); + const auto mask = MFromD{static_cast(uint64_t{1} << i)}; +#else + const RebindToUnsigned du; + using TU = TFromD; + const auto mask = RebindMask(d, Iota(du, 0) == Set(du, static_cast(i))); +#endif + + return IfThenElse(mask, Set(d, t), v); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + return InsertLaneUsingBroadcastAndBlend(v, kLane, t); +#else + return Vec128{_mm_insert_epi8(v.raw, t, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + const DFromV d; + const RebindToUnsigned du; + const 
uint16_t bits = BitCastScalar(t); + return BitCast(d, VFromD{ + _mm_insert_epi16(BitCast(du, v).raw, bits, kLane)}); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + return InsertLaneUsingBroadcastAndBlend(v, kLane, t); +#else + const MakeSigned ti = BitCastScalar>(t); + return Vec128{_mm_insert_epi32(v.raw, ti, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 || HWY_ARCH_X86_32 + const DFromV d; + const RebindToFloat df; + const auto vt = BitCast(df, Set(d, t)); + if (kLane == 0) { + return BitCast( + d, Vec128{_mm_shuffle_pd(vt.raw, BitCast(df, v).raw, 2)}); + } + return BitCast( + d, Vec128{_mm_shuffle_pd(BitCast(df, v).raw, vt.raw, 0)}); +#else + const MakeSigned ti = BitCastScalar>(t); + return Vec128{_mm_insert_epi64(v.raw, ti, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + return InsertLaneUsingBroadcastAndBlend(v, kLane, t); +#else + return Vec128{_mm_insert_ps(v.raw, _mm_set_ss(t), kLane << 4)}; +#endif +} + +// There is no insert_pd; two overloads because there is no UpperHalf for N=1. +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane == 0, "Lane index out of bounds"); + return Set(DFromV(), t); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + const DFromV d; + const Vec128 vt = Set(d, t); + if (kLane == 0) { + return Vec128{_mm_shuffle_pd(vt.raw, v.raw, 2)}; + } + return Vec128{_mm_shuffle_pd(v.raw, vt.raw, 0)}; +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls _mm_insert_epi64. 
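+// Illustrative usage (runtime index):
+//   const Full128<int32_t> d;
+//   auto v = Zero(d);
+//   v = InsertLane(v, 1, int32_t{42});  // lane 1 becomes 42, others unchanged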
+ +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +// ------------------------------ CombineShiftRightBytes + +#if HWY_TARGET == HWY_SSE2 +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + static_assert(0 < kBytes && kBytes < 16, "kBytes invalid"); + return Or(ShiftRightBytes(d, lo), ShiftLeftBytes<16 - kBytes>(d, hi)); +} +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + + const Twice dt; + return VFromD{ShiftRightBytes(dt, Combine(dt, hi, lo)).raw}; +} +#else +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + const Repartition d8; + return BitCast(d, Vec128{_mm_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + 
static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + using V8 = Vec128; + const DFromV dfull8; + const Repartition, decltype(dfull8)> dfull; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(dfull8, hi8, lo8); + return VFromD{BitCast(dfull, r).raw}; +} +#endif + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(vu.raw, (0x55 * kLane) & 0xFF); + return BitCast(d, VU{_mm_unpacklo_epi64(lo, lo)}); + } else { + const __m128i hi = _mm_shufflehi_epi16(vu.raw, (0x55 * (kLane - 4)) & 0xFF); + return BitCast(d, VU{_mm_unpackhi_epi64(hi, hi)}); + } +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 3 * kLane)}; +} + +// ------------------------------ TableLookupLanes (Shuffle01) + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
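+// Illustrative usage sketch (lane indices, each less than Lanes(d) for a
+// single-table lookup):
+//   const Full128<uint32_t> d;
+//   const int32_t idx_lanes[4] = {3, 2, 1, 0};
+//   const auto idx = SetTableIndices(d, idx_lanes);
+//   const auto rev = TableLookupLanes(Iota(d, 0), idx);  // 3, 2, 1, 0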
+template +struct Indices128 { + __m128i raw; +}; + +template , typename TI, size_t kN, + HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE(T, 1)> +HWY_API Indices128 IndicesFromVec(D d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, kN * 2)))); +#endif + + // No change as byte indices are always used for 8-bit lane types + (void)d; + return Indices128{vec.raw}; +} + +template , typename TI, size_t kN, + HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE(T, 2)> +HWY_API Indices128 IndicesFromVec(D d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, kN * 2)))); +#endif + +#if HWY_TARGET <= HWY_AVX3 || HWY_TARGET == HWY_SSE2 + (void)d; + return Indices128{vec.raw}; +#else // SSSE3, SSE4, or AVX2 + const Repartition d8; + using V8 = VFromD; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + + // Broadcast each lane index to all 4 bytes of T + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; + const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes)); + + // Shift to bytes + const Repartition d16; + const V8 byte_indices = BitCast(d8, ShiftLeft<1>(BitCast(d16, lane_indices))); + + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; +#endif // HWY_TARGET <= HWY_AVX3 || HWY_TARGET == HWY_SSE2 +} + +template , typename TI, size_t kN, + HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE(T, 4)> +HWY_API Indices128 IndicesFromVec(D d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, kN * 2)))); +#endif + +#if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SSE2 + (void)d; + return Indices128{vec.raw}; +#else + const Repartition d8; + using V8 = VFromD; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + + // Broadcast each lane index to all 4 bytes of T + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes)); + + // Shift to bytes + const Repartition d16; + const V8 byte_indices = BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; +#endif +} + +template , typename TI, size_t kN, + HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE(T, 8)> +HWY_API Indices128 IndicesFromVec(D d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(kN * 2))))); +#else + (void)d; +#endif + + // No change - even without AVX3, we can shuffle+blend. 
+ return Indices128{vec.raw}; +} + +template +HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( + D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + return TableLookupBytes(v, Vec128{idx.raw}); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return {_mm_permutexvar_epi16(idx.raw, v.raw)}; +#elif HWY_TARGET == HWY_SSE2 +#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) + typedef uint16_t GccU16RawVectType __attribute__((__vector_size__(16))); + return Vec128{reinterpret_cast::type>( + __builtin_shuffle(reinterpret_cast(v.raw), + reinterpret_cast(idx.raw)))}; +#else + const Full128 d_full; + alignas(16) T src_lanes[8]; + alignas(16) uint16_t indices[8]; + alignas(16) T result_lanes[8]; + + Store(Vec128{v.raw}, d_full, src_lanes); + _mm_store_si128(reinterpret_cast<__m128i*>(indices), idx.raw); + + for (int i = 0; i < 8; i++) { + result_lanes[i] = src_lanes[indices[i] & 7u]; + } + + return Vec128{Load(d_full, result_lanes).raw}; +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) +#else + return TableLookupBytes(v, Vec128{idx.raw}); +#endif +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { + return {_mm_permutexvar_ph(idx.raw, v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { +#if HWY_TARGET <= HWY_AVX2 + const DFromV d; + const RebindToFloat df; + const Vec128 perm{_mm_permutevar_ps(BitCast(df, v).raw, idx.raw)}; + return BitCast(d, perm); +#elif HWY_TARGET == HWY_SSE2 +#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) + typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(16))); + return Vec128{reinterpret_cast::type>( + __builtin_shuffle(reinterpret_cast(v.raw), + reinterpret_cast(idx.raw)))}; +#else + const Full128 d_full; + alignas(16) T src_lanes[4]; + alignas(16) uint32_t indices[4]; + alignas(16) T result_lanes[4]; + + Store(Vec128{v.raw}, d_full, src_lanes); + _mm_store_si128(reinterpret_cast<__m128i*>(indices), idx.raw); + + for (int i = 0; i < 4; i++) { + result_lanes[i] = src_lanes[indices[i] & 3u]; + } + + return Vec128{Load(d_full, result_lanes).raw}; +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) +#else // SSSE3 or SSE4 + return TableLookupBytes(v, Vec128{idx.raw}); +#endif +} + +#if HWY_TARGET <= HWY_SSSE3 +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX2 + return Vec128{_mm_permutevar_ps(v.raw, idx.raw)}; +#else // SSSE3 or SSE4 + const DFromV df; + const RebindToSigned di; + return BitCast(df, + TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +#endif // HWY_TARGET <= HWY_AVX2 +} +#endif // HWY_TARGET <= HWY_SSSE3 + +// Single lane: no change +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 /* idx */) { + return v; +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const DFromV d; + Vec128 vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + // There is no _mm_permute[x]var_epi64. + vidx += vidx; // bit1 is the decider (unusual) + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_permutevar_pd(BitCast(df, v).raw, vidx.raw)}); +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. 
To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. + const RebindToSigned di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { + Vec128 vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + vidx += vidx; // bit1 is the decider (unusual) + return Vec128{_mm_permutevar_pd(v.raw, vidx.raw)}; +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. + const DFromV d; + const RebindToSigned di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128 mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301) + +// Single lane: no change +template +HWY_API VFromD Reverse(D /* tag */, VFromD v) { + return v; +} + +// 32-bit x2: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return VFromD{Shuffle2301(Vec128>{v.raw}).raw}; +} + +// 64-bit x2: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return Shuffle01(v); +} + +// 32-bit x4: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API VFromD Reverse(D d, const VFromD v) { + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + constexpr size_t kN = MaxLanes(d); + if (kN == 1) return v; + if (kN == 2) { + return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 0, 1))}); + } + if (kN == 4) { + return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3))}); + } + +#if HWY_TARGET == HWY_SSE2 + const VU rev4{ + _mm_shufflehi_epi16(_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3)), + _MM_SHUFFLE(0, 1, 2, 3))}; + return BitCast(d, VU{_mm_shuffle_epi32(rev4.raw, _MM_SHUFFLE(1, 0, 3, 2))}); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + constexpr int kN = static_cast(MaxLanes(d)); + if (kN == 1) return v; +#if HWY_TARGET <= HWY_SSSE3 + // NOTE: Lanes with negative shuffle control mask values are set to zero. + alignas(16) static constexpr int8_t kReverse[16] = { + kN - 1, kN - 2, kN - 3, kN - 4, kN - 5, kN - 6, kN - 7, kN - 8, + kN - 9, kN - 10, kN - 11, kN - 12, kN - 13, kN - 14, kN - 15, kN - 16}; + const RebindToSigned di; + const VFromD idx = Load(di, kReverse); + return VFromD{_mm_shuffle_epi8(BitCast(di, v).raw, idx.raw)}; +#else + const RepartitionToWide d16; + return BitCast(d, Reverse(d16, RotateRight<8>(BitCast(d16, v)))); +#endif +} + +// ------------------------------ Reverse2 + +// Single lane: no change +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return v; +} + +// Generic for all vector lengths (128-bit sufficient if SSE2). 
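+// Reverse2 swaps each adjacent pair of lanes, e.g. for 16-bit lanes
+// {0,1,2,3,4,5,6,7} becomes {1,0,3,2,5,4,7,6}.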
+template +HWY_API VFromD Reverse2(D d, VFromD v) { +#if HWY_TARGET <= HWY_AVX3 + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +#elif HWY_TARGET == HWY_SSE2 + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + constexpr size_t kN = MaxLanes(d); + __m128i shuf_result = _mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(2, 3, 0, 1)); + if (kN > 4) { + shuf_result = _mm_shufflehi_epi16(shuf_result, _MM_SHUFFLE(2, 3, 0, 1)); + } + return BitCast(d, VU{shuf_result}); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0302, 0x0100, 0x0706, 0x0504, 0x0B0A, 0x0908, 0x0F0E, 0x0D0C); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} + +// Generic for all vector lengths. +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return Shuffle2301(v); +} + +// Generic for all vector lengths. +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + // 4x 16-bit: a single shufflelo suffices. + constexpr size_t kN = MaxLanes(d); + if (kN <= 4) { + return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3))}); + } + +#if HWY_TARGET == HWY_SSE2 + return BitCast(d, VU{_mm_shufflehi_epi16( + _mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3)), + _MM_SHUFFLE(0, 1, 2, 3))}); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0706, 0x0504, 0x0302, 0x0100, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} + +// Generic for all vector lengths. 
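// A minimal sketch in plain SSE2 intrinsics (helper name ours, for
// exposition): the AVX3 path of the 16-bit Reverse2 above swaps adjacent
// lanes by rotating each 32-bit pair, which can be written as two shifts + OR.
#include <emmintrin.h>  // SSE2

static __m128i Reverse2U16(__m128i v) {
  // Within every u32 lane, (x >> 16) | (x << 16) swaps its two u16 halves:
  // {a,b, c,d, e,f, g,h} -> {b,a, d,c, f,e, h,g}
  return _mm_or_si128(_mm_srli_epi32(v, 16), _mm_slli_epi32(v, 16));
}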
+template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + return Shuffle0123(v); +} + +template +HWY_API VFromD Reverse4(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { +#if HWY_TARGET == HWY_SSE2 + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} + +template +HWY_API VFromD Reverse8(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 8 lanes if larger than 16-bit +} + +// ------------------------------ ReverseBits in x86_512 + +// ------------------------------ InterleaveUpper (UpperHalf) + +// Full +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_epi64(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_pd(a.raw, b.raw)}; +} + +// Partial +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half d2; + return InterleaveLower(d, VFromD{UpperHalf(d2, a).raw}, + VFromD{UpperHalf(d2, b).raw}); +} + +// -------------------------- I8/U8 Broadcast (InterleaveLower, InterleaveUpper) + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + const DFromV d; + +#if HWY_TARGET == HWY_SSE2 + const Full128 d_full; + const Vec128 v_full{v.raw}; + const auto v_interleaved = (kLane < 8) + ? InterleaveLower(d_full, v_full, v_full) + : InterleaveUpper(d_full, v_full, v_full); + return ResizeBitCast( + d, Broadcast(BitCast(Full128(), v_interleaved))); +#else + return TableLookupBytes(v, Set(d, static_cast(kLane))); +#endif +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +// Generic for all vector lengths. 
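// A minimal sketch in plain SSE2 intrinsics (helper names ours): the
// InterleaveLower/InterleaveUpper overloads above are thin wrappers over the
// unpack instructions, and ZipLower/ZipUpper below are the same interleave
// merely reinterpreted as double-width lanes.
#include <emmintrin.h>  // SSE2

static __m128i InterleaveLowerU8(__m128i a, __m128i b) {
  return _mm_unpacklo_epi8(a, b);  // {a0,b0, a1,b1, ..., a7,b7}
}
static __m128i InterleaveUpperU8(__m128i a, __m128i b) {
  return _mm_unpackhi_epi8(a, b);  // {a8,b8, ..., a15,b15}
}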
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== CONVERT (1) + +// ------------------------------ PromoteTo unsigned (TableLookupBytesOr0) +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + return VFromD{_mm_unpacklo_epi8(v.raw, zero)}; +#else + return VFromD{_mm_cvtepu8_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return VFromD{_mm_unpacklo_epi16(v.raw, _mm_setzero_si128())}; +#else + return VFromD{_mm_cvtepu16_epi32(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return VFromD{_mm_unpacklo_epi32(v.raw, _mm_setzero_si128())}; +#else + return VFromD{_mm_cvtepu32_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + const __m128i u16 = _mm_unpacklo_epi8(v.raw, zero); + return VFromD{_mm_unpacklo_epi16(u16, zero)}; +#else + return VFromD{_mm_cvtepu8_epi32(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET > HWY_SSSE3 + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +#elif HWY_TARGET == HWY_SSSE3 + alignas(16) static constexpr int8_t kShuffle[16] = { + 0, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1}; + const Repartition di8; + return TableLookupBytesOr0(v, BitCast(d, Load(di8, kShuffle))); +#else + (void)d; + return VFromD{_mm_cvtepu8_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET > HWY_SSSE3 + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +#elif HWY_TARGET == HWY_SSSE3 + alignas(16) static constexpr int8_t kShuffle[16] = { + 0, 1, -1, -1, -1, -1, -1, -1, 2, 3, -1, -1, -1, -1, -1, -1}; + const Repartition di8; + return TableLookupBytesOr0(v, BitCast(d, Load(di8, kShuffle))); +#else + (void)d; + return VFromD{_mm_cvtepu16_epi64(v.raw)}; +#endif +} + +// Unsigned to signed: same plus cast. 
+template ), sizeof(TFromV)), + HWY_IF_LANES_D(D, HWY_MAX_LANES_V(V))> +HWY_API VFromD PromoteTo(D di, V v) { + const RebindToUnsigned du; + return BitCast(di, PromoteTo(du, v)); +} + +// ------------------------------ PromoteTo signed (ShiftRight, ZipLower) + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return ShiftRight<8>(VFromD{_mm_unpacklo_epi8(v.raw, v.raw)}); +#else + return VFromD{_mm_cvtepi8_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return ShiftRight<16>(VFromD{_mm_unpacklo_epi16(v.raw, v.raw)}); +#else + return VFromD{_mm_cvtepi16_epi32(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return ShiftRight<32>(VFromD{_mm_unpacklo_epi32(v.raw, v.raw)}); +#else + return VFromD{_mm_cvtepi32_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i x2 = _mm_unpacklo_epi8(v.raw, v.raw); + const __m128i x4 = _mm_unpacklo_epi16(x2, x2); + return ShiftRight<24>(VFromD{x4}); +#else + return VFromD{_mm_cvtepi8_epi32(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const Repartition di32; + const Half dh_i32; + const VFromD x4{PromoteTo(dh_i32, v).raw}; + const VFromD s4{ + _mm_shufflelo_epi16(x4.raw, _MM_SHUFFLE(3, 3, 1, 1))}; + return ZipLower(d, x4, s4); +#else + (void)d; + return VFromD{_mm_cvtepi8_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const Repartition di32; + const Half dh_i32; + const VFromD x2{PromoteTo(dh_i32, v).raw}; + const VFromD s2{ + _mm_shufflelo_epi16(x2.raw, _MM_SHUFFLE(3, 3, 1, 1))}; + return ZipLower(d, x2, s2); +#else + (void)d; + return VFromD{_mm_cvtepi16_epi64(v.raw)}; +#endif +} + +// -------------------- PromoteTo float (ShiftLeft, IfNegativeThenElse) +#if HWY_TARGET < HWY_SSE4 && !defined(HWY_DISABLE_F16C) + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. 
+#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +// Workaround for origin tracking bug in Clang msan prior to 11.0 +// (spurious "uninitialized memory" for TestF16 with "ORIGIN: invalid") +#if HWY_IS_MSAN && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1100) +#define HWY_INLINE_F16 HWY_NOINLINE +#else +#define HWY_INLINE_F16 HWY_INLINE +#endif +template +HWY_INLINE_F16 VFromD PromoteTo(D /*tag*/, VFromD> v) { +#if HWY_HAVE_FLOAT16 + const RebindToUnsigned> du16; + return VFromD{_mm_cvtph_ps(BitCast(du16, v).raw)}; +#else + return VFromD{_mm_cvtph_ps(v.raw)}; +#endif +} + +#endif // HWY_NATIVE_F16C + +#if HWY_HAVE_FLOAT16 + +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, VFromD> v) { + return VFromD{_mm_cvtph_pd(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtps_pd(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepi32_pd(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD PromoteTo(D /*df64*/, VFromD> v) { + return VFromD{_mm_cvtepu32_pd(v.raw)}; +} +#else +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { + const Rebind di32; + const auto i32_to_f64_result = PromoteTo(df64, BitCast(di32, v)); + return i32_to_f64_result + IfNegativeThenElse(i32_to_f64_result, + Set(df64, 4294967296.0), + Zero(df64)); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Per4LaneBlockShuffle +namespace detail { + +#ifdef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#undef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#else +#define HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#endif + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + return ResizeBitCast( + d, Vec128{_mm_set_epi32( + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0))}); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag<8> /*vect_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm_shufflelo_epi16( + BitCast(du, v).raw, static_cast(kIdx3210 & 0xFF))}); +} + +#if HWY_TARGET == HWY_SSE2 +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag<16> /*vect_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + constexpr int kShuffle = static_cast(kIdx3210 & 0xFF); + return BitCast( + d, VFromD{_mm_shufflehi_epi16( + _mm_shufflelo_epi16(BitCast(du, v).raw, kShuffle), kShuffle)}); +} + +template * = nullptr> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag idx_3210_tag, + hwy::SizeTag<1> /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; + const Rebind du16; + const RebindToSigned di16; + + const auto vu16 = PromoteTo(du16, BitCast(du, v)); + const auto shuf16_result = Per4LaneBlockShuffle( + idx_3210_tag, hwy::SizeTag<2>(), hwy::SizeTag(), vu16); + return BitCast(d, DemoteTo(du, 
BitCast(di16, shuf16_result))); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag idx_3210_tag, + hwy::SizeTag<1> /*lane_size_tag*/, + hwy::SizeTag<16> /*vect_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; + const Repartition du16; + const RebindToSigned di16; + + const auto zero = Zero(d); + const auto v_lo16 = BitCast(du16, InterleaveLower(d, v, zero)); + const auto v_hi16 = BitCast(du16, InterleaveUpper(d, v, zero)); + + const auto lo_shuf_result = Per4LaneBlockShuffle( + idx_3210_tag, hwy::SizeTag<2>(), hwy::SizeTag<16>(), v_lo16); + const auto hi_shuf_result = Per4LaneBlockShuffle( + idx_3210_tag, hwy::SizeTag<2>(), hwy::SizeTag<16>(), v_hi16); + + return BitCast(d, OrderedDemote2To(du, BitCast(di16, lo_shuf_result), + BitCast(di16, hi_shuf_result))); +} +#endif + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<16> /*vect_size_tag*/, V v) { + return V{_mm_shuffle_epi32(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<16> /*vect_size_tag*/, V v) { + return V{_mm_shuffle_ps(v.raw, v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Full64 du64; + const auto vu64 = ResizeBitCast(du64, v); + return ResizeBitCast( + d, ShiftLeftSame(vu64, static_cast(amt * sizeof(TFromV) * 8))); +} + +#if HWY_TARGET <= HWY_SSSE3 +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Repartition du8; + const auto idx = + Iota(du8, static_cast(size_t{0} - amt * sizeof(TFromV))); + return BitCast(d, TableLookupBytesOr0(BitCast(du8, v), idx)); +} +#else +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Repartition di32; + const Repartition du64; + constexpr size_t kNumOfLanesPerU64 = 8 / sizeof(TFromV); + + const auto vu64 = BitCast(du64, v); + const auto v_hi = IfVecThenElse( + BitCast(du64, Set(di32, -static_cast(amt >= kNumOfLanesPerU64))), + BitCast(du64, ShiftLeftBytes<8>(du64, vu64)), vu64); + const auto v_lo = ShiftLeftBytes<8>(du64, v_hi); + + const int shl_amt = static_cast((amt * sizeof(TFromV) * 8) & 63); + return BitCast( + d, Or(ShiftLeftSame(v_hi, shl_amt), ShiftRightSame(v_lo, 64 - shl_amt))); +} +#endif + +} // namespace detail + +template +HWY_API VFromD SlideUpLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { 
+ case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + case 8: + return ShiftLeftLanes<8>(d, v); + case 9: + return ShiftLeftLanes<9>(d, v); + case 10: + return ShiftLeftLanes<10>(d, v); + case 11: + return ShiftLeftLanes<11>(d, v); + case 12: + return ShiftLeftLanes<12>(d, v); + case 13: + return ShiftLeftLanes<13>(d, v); + case 14: + return ShiftLeftLanes<14>(d, v); + case 15: + return ShiftLeftLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition, decltype(d)> dv; + return BitCast(d, + ShiftRightSame(BitCast(dv, v), + static_cast(amt * sizeof(TFromV) * 8))); +} + +#if HWY_TARGET <= HWY_SSSE3 +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition di8; + auto idx = Iota(di8, static_cast(amt * sizeof(TFromV))); + idx = Or(idx, VecFromMask(di8, idx > Set(di8, int8_t{15}))); + return BitCast(d, TableLookupBytesOr0(BitCast(di8, v), idx)); +} +#else +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition di32; + const Repartition du64; + constexpr size_t kNumOfLanesPerU64 = 8 / sizeof(TFromV); + + const auto vu64 = BitCast(du64, v); + const auto v_lo = IfVecThenElse( + BitCast(du64, Set(di32, -static_cast(amt >= kNumOfLanesPerU64))), + BitCast(du64, ShiftRightBytes<8>(du64, vu64)), vu64); + const auto v_hi = ShiftRightBytes<8>(du64, v_lo); + + const int shr_amt = static_cast((amt * sizeof(TFromV) * 8) & 63); + return BitCast( + d, Or(ShiftRightSame(v_lo, shr_amt), ShiftLeftSame(v_hi, 64 - shr_amt))); +} +#endif + +} // namespace detail + +template +HWY_API VFromD SlideDownLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if 
!HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + case 8: + return ShiftRightLanes<8>(d, v); + case 9: + return ShiftRightLanes<9>(d, v); + case 10: + return ShiftRightLanes<10>(d, v); + case 11: + return ShiftRightLanes<11>(d, v); + case 12: + return ShiftRightLanes<12>(d, v); + case 13: + return ShiftRightLanes<13>(d, v); + case 14: + return ShiftRightLanes<14>(d, v); + case 15: + return ShiftRightLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +// ================================================== MEMORY (4) + +// ------------------------------ StoreN (ExtractLane) + +#if HWY_TARGET <= HWY_AVX2 + +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +template +HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const size_t num_lanes_to_store = + HWY_MIN(max_lanes_to_store, HWY_MAX_LANES_D(D)); + +#if HWY_COMPILER_MSVC + // Work around MSVC compiler bug by using a HWY_FENCE before the BlendedStore + HWY_FENCE; +#endif + + BlendedStore(v, FirstN(d, num_lanes_to_store), d, p); + +#if HWY_COMPILER_MSVC + // Work around MSVC compiler bug by using a HWY_FENCE after the BlendedStore + HWY_FENCE; +#endif + + detail::MaybeUnpoison(p, num_lanes_to_store); +} + +#if HWY_TARGET > HWY_AVX3 +template +HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store > 0) { + StoreU(v, d, p); + } +} + +template +HWY_API void StoreN(VFromD v, D /*d*/, TFromD* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store >= 1) { + p[static_cast(max_lanes_to_store > 1)] = detail::ExtractLane<1>(v); + p[0] = GetLane(v); + } +} + +namespace detail { + +template +HWY_API void AVX2UIF8Or16StoreTrailingN(VFromD v_trailing, D /*d*/, + TFromD* HWY_RESTRICT p, + size_t num_lanes_to_store) { + // AVX2UIF8Or16StoreTrailingN should only be called for an I8/U8 vector if + // (num_lanes_to_store & 3) != 0 is true + const auto v_full128 = ResizeBitCast(Full128>(), v_trailing); + if ((num_lanes_to_store & 2) != 0) { + const uint16_t u16_bits = GetLane(BitCast(Full128(), v_full128)); + p[num_lanes_to_store - 1] = detail::ExtractLane<2>(v_full128); + CopyBytes(&u16_bits, + p + (num_lanes_to_store & ~size_t{3})); + } else { + p[num_lanes_to_store - 1] = GetLane(v_full128); + } +} + +template +HWY_API void AVX2UIF8Or16StoreTrailingN(VFromD v_trailing, D /*d*/, + TFromD* p, + size_t num_lanes_to_store) { + // AVX2UIF8Or16StoreTrailingN should only 
be called for an I16/U16/F16/BF16 + // vector if (num_lanes_to_store & 1) == 1 is true + p[num_lanes_to_store - 1] = GetLane(v_trailing); +} + +} // namespace detail + +template +HWY_API void StoreN(VFromD v, D d, TFromD* p, size_t max_lanes_to_store) { + const size_t num_lanes_to_store = + HWY_MIN(max_lanes_to_store, HWY_MAX_LANES_D(D)); + + const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> + d_full; + const RebindToUnsigned du_full; + const Repartition di32_full; + + const auto i32_store_mask = BitCast( + di32_full, VecFromMask(du_full, FirstN(du_full, num_lanes_to_store))); + const auto vi32 = ResizeBitCast(di32_full, v); + +#if HWY_COMPILER_MSVC + // Work around MSVC compiler bug by using a HWY_FENCE before the BlendedStore + HWY_FENCE; +#endif + + BlendedStore(vi32, MaskFromVec(i32_store_mask), di32_full, + reinterpret_cast(p)); + + constexpr size_t kNumOfLanesPerI32 = 4 / sizeof(TFromD); + constexpr size_t kTrailingLenMask = kNumOfLanesPerI32 - 1; + const size_t trailing_n = (num_lanes_to_store & kTrailingLenMask); + + if (trailing_n != 0) { + const VFromD v_trailing = ResizeBitCast( + d, SlideDownLanes(di32_full, vi32, + num_lanes_to_store / kNumOfLanesPerI32)); + detail::AVX2UIF8Or16StoreTrailingN(v_trailing, d, p, num_lanes_to_store); + } + +#if HWY_COMPILER_MSVC + // Work around MSVC compiler bug by using a HWY_FENCE after the BlendedStore + HWY_FENCE; +#endif + + detail::MaybeUnpoison(p, num_lanes_to_store); +} +#endif // HWY_TARGET > HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX2 + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template >> +HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { + const Half dh; + const RebindToUnsigned duh; + // Treat half-width input as one lane, and expand to two lanes. + using VU = Vec128, 2>; + const VU lo{BitCast(duh, lo_half).raw}; + const VU hi{BitCast(duh, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, VFromD{_mm_move_epi64(BitCast(duh, lo).raw)}); +} + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const Half dh; + return IfThenElseZero(FirstN(d, MaxLanes(dh)), VFromD{lo.raw}); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, ZeroExtendVector(du, BitCast(duh, lo))); +} +#endif + +// Generic for all vector lengths. 
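// A minimal sketch in plain SSE2 intrinsics (helper name ours): the full-width
// ZeroExtendVector above boils down to _mm_move_epi64, which copies the lower
// 64 bits and clears the upper 64.
#include <emmintrin.h>  // SSE2

static __m128i ZeroExtendLower64(__m128i lo) {
  return _mm_move_epi64(lo);  // {lo[63:0], 0}
}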
+template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, ZeroExtendVector(du, BitCast(duh, lo))); +} + +// ------------------------------ Concat full (InterleaveLower) + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Repartition d64; + return BitCast(d, InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Repartition d64; + return BitCast(d, InterleaveUpper(d64, BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Repartition dd; +#if HWY_TARGET >= HWY_SSSE3 + return BitCast( + d, Vec128{_mm_shuffle_pd(BitCast(dd, lo).raw, BitCast(dd, hi).raw, + _MM_SHUFFLE2(1, 0))}); +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _pd can do 3/cycle. + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +template +HWY_API Vec128 ConcatUpperLower(D d, Vec128 hi, + Vec128 lo) { +#if HWY_TARGET >= HWY_SSSE3 + (void)d; + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 2, 1, 0))}; +#else + // _mm_shuffle_ps has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + const RepartitionToWide dd; + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +template +HWY_API Vec128 ConcatUpperLower(D /* tag */, Vec128 hi, + Vec128 lo) { +#if HWY_TARGET >= HWY_SSSE3 + return Vec128{_mm_shuffle_pd(lo.raw, hi.raw, _MM_SHUFFLE2(1, 0))}; +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + return Vec128{_mm_blend_pd(hi.raw, lo.raw, 1)}; +#endif +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatLowerUpper(D d, const VFromD hi, + const VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const Repartition dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec128 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<8>(BitCast(dw, lo)); + return VFromD{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + const Repartition dw; + // Right-shift 8 bits per u16 so we can pack. 
+ const Vec64 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec64 uL = ShiftRight<8>(BitCast(dw, lo)); + return VFromD{_mm_shuffle_epi32(_mm_packus_epi16(uL.raw, uH.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +#else + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[8] = {1, 3, 5, 7}; + const VFromD shuf = BitCast(d, Load(Full64(), kCompactOddU8)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +#endif +} + +// 8-bit x4 +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + const Repartition dw; + const Twice dw_2; + // Right-shift 8 bits per u16 so we can pack. + const Vec32 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec32 uL = ShiftRight<8>(BitCast(dw, lo)); + const Vec64 uHL = Combine(dw_2, uH, uL); + return VFromD{_mm_packus_epi16(uHL.raw, uHL.raw)}; +#else + const Repartition du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[4] = {1, 3}; + const VFromD shuf = BitCast(d, Load(Full32(), kCompactOddU8)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + // Right-shift 16 bits per i32 - a *signed* shift of 0x8000xxxx returns + // 0xFFFF8000, which correctly saturates to 0x8000. + const RebindToUnsigned du; + const Repartition dw; + const Vec128 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<16>(BitCast(dw, lo)); + return BitCast(d, VFromD{_mm_packs_epi32(uL.raw, uH.raw)}); +} + +// 16-bit x4 +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + // Right-shift 16 bits per i32 - a *signed* shift of 0x8000xxxx returns + // 0xFFFF8000, which correctly saturates to 0x8000. + const Repartition dw; + const Vec64 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec64 uL = ShiftRight<16>(BitCast(dw, lo)); + return VFromD{_mm_shuffle_epi32(_mm_packs_epi32(uL.raw, uH.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +#else + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU16[8] = {2, 3, 6, 7}; + const VFromD shuf = BitCast(d, Load(Full64(), kCompactOddU16)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +#endif +} + +// 32-bit full +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(3, 1, 3, 1))}); +} + +// Any type x2 +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Repartition dw; + // Isolate lower 8 bits per u16 so we can pack. 
+ const Vec128 mask = Set(dw, 0x00FF); + const Vec128 uH = And(BitCast(dw, hi), mask); + const Vec128 uL = And(BitCast(dw, lo), mask); + return VFromD{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + const Repartition dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec64 mask = Set(dw, 0x00FF); + const Vec64 uH = And(BitCast(dw, hi), mask); + const Vec64 uL = And(BitCast(dw, lo), mask); + return VFromD{_mm_shuffle_epi32(_mm_packus_epi16(uL.raw, uH.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +#else + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[8] = {0, 2, 4, 6}; + const VFromD shuf = BitCast(d, Load(Full64(), kCompactEvenU8)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +#endif +} + +// 8-bit x4 +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + const Repartition dw; + const Twice dw_2; + // Isolate lower 8 bits per u16 so we can pack. + const Vec32 mask = Set(dw, 0x00FF); + const Vec32 uH = And(BitCast(dw, hi), mask); + const Vec32 uL = And(BitCast(dw, lo), mask); + const Vec64 uHL = Combine(dw_2, uH, uL); + return VFromD{_mm_packus_epi16(uHL.raw, uHL.raw)}; +#else + const Repartition du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[4] = {0, 2}; + const VFromD shuf = BitCast(d, Load(Full32(), kCompactEvenU8)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +#endif +} + +// 16-bit full +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET <= HWY_SSE4 + // Isolate lower 16 bits per u32 so we can pack. + const RebindToUnsigned du; // for float16_t + const Repartition dw; + const Vec128 mask = Set(dw, 0x0000FFFF); + const Vec128 uH = And(BitCast(dw, hi), mask); + const Vec128 uL = And(BitCast(dw, lo), mask); + return BitCast(d, VFromD{_mm_packus_epi32(uL.raw, uH.raw)}); +#elif HWY_TARGET == HWY_SSE2 + const Repartition dw; + return ConcatOdd(d, BitCast(d, ShiftLeft<16>(BitCast(dw, hi))), + BitCast(d, ShiftLeft<16>(BitCast(dw, lo)))); +#else + const RebindToUnsigned du; + // packs_epi32 saturates 0x8000 to 0x7FFF. Instead ConcatEven within the two + // inputs, then concatenate them. + alignas(16) + const uint16_t kCompactEvenU16[8] = {0x0100, 0x0504, 0x0908, 0x0D0C}; + const VFromD shuf = BitCast(d, Load(du, kCompactEvenU16)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return ConcatLowerLower(d, H, L); +#endif +} + +// 16-bit x4 +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + const Repartition dw; + return ConcatOdd(d, BitCast(d, ShiftLeft<16>(BitCast(dw, hi))), + BitCast(d, ShiftLeft<16>(BitCast(dw, lo)))); +#else + const Repartition du32; + // Don't care about upper half, no need to zero. 
+ alignas(16) const uint8_t kCompactEvenU16[8] = {0, 1, 4, 5}; + const VFromD shuf = BitCast(d, Load(Full64(), kCompactEvenU16)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +#endif +} + +// 32-bit full +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} +template +HWY_API VFromD ConcatEven(D /* d */, VFromD hi, VFromD lo) { + return VFromD{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; +} + +// Any T x2 +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return v; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +template +HWY_API V DupEven(V v) { + const DFromV d; + +#if HWY_TARGET <= HWY_SSSE3 + const RebindToUnsigned du; + const VFromD shuffle = Dup128VecFromValues( + du, 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14); + return TableLookupBytes(v, BitCast(d, shuffle)); +#else + const Repartition du16; + return IfVecThenElse(BitCast(d, Set(du16, uint16_t{0xFF00})), + BitCast(d, ShiftLeft<8>(BitCast(du16, v))), v); +#endif +} + +template +HWY_API Vec64 DupEven(const Vec64 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm_shufflelo_epi16( + BitCast(du, v).raw, _MM_SHUFFLE(2, 2, 0, 0))}); +} + +// Generic for all vector lengths. +template +HWY_API V DupEven(const V v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t +#if HWY_TARGET <= HWY_SSSE3 + const VFromD shuffle = Dup128VecFromValues( + du, 0x0100, 0x0100, 0x0504, 0x0504, 0x0908, 0x0908, 0x0d0c, 0x0d0c); + return TableLookupBytes(v, BitCast(d, shuffle)); +#else + return BitCast( + d, VFromD{_mm_shufflehi_epi16( + _mm_shufflelo_epi16(BitCast(du, v).raw, _MM_SHUFFLE(2, 2, 0, 0)), + _MM_SHUFFLE(2, 2, 0, 0))}); +#endif +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return v; +} + +template +HWY_API V DupOdd(V v) { + const DFromV d; + +#if HWY_TARGET <= HWY_SSSE3 + const RebindToUnsigned du; + const VFromD shuffle = Dup128VecFromValues( + du, 1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15); + return TableLookupBytes(v, BitCast(d, shuffle)); +#else + const Repartition du16; + return IfVecThenElse(BitCast(d, Set(du16, uint16_t{0x00FF})), + BitCast(d, ShiftRight<8>(BitCast(du16, v))), v); +#endif +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm_shufflelo_epi16( + BitCast(du, v).raw, _MM_SHUFFLE(3, 3, 1, 1))}); +} + +// Generic for all vector lengths. 
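// A minimal sketch in plain SSE2 intrinsics (helper names ours): for 32-bit
// lanes, DupEven/DupOdd above are single shuffles that broadcast each even
// (respectively odd) lane over its neighbor.
#include <emmintrin.h>  // SSE2

static __m128i DupEvenU32(__m128i v) {
  return _mm_shuffle_epi32(v, _MM_SHUFFLE(2, 2, 0, 0));  // {v0,v0, v2,v2}
}
static __m128i DupOddU32(__m128i v) {
  return _mm_shuffle_epi32(v, _MM_SHUFFLE(3, 3, 1, 1));  // {v1,v1, v3,v3}
}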
+template +HWY_API V DupOdd(V v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t +#if HWY_TARGET <= HWY_SSSE3 + const VFromD shuffle = Dup128VecFromValues( + du, 0x0302, 0x0302, 0x0706, 0x0706, 0x0b0a, 0x0b0a, 0x0f0e, 0x0f0e); + return TableLookupBytes(v, BitCast(d, shuffle)); +#else + return BitCast( + d, VFromD{_mm_shufflehi_epi16( + _mm_shufflelo_epi16(BitCast(du, v).raw, _MM_SHUFFLE(3, 3, 1, 1)), + _MM_SHUFFLE(3, 3, 1, 1))}); +#endif +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ TwoTablesLookupLanes (DupEven) + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Twice dt; +// TableLookupLanes currently requires table and index vectors to be the same +// size, though a half-length index vector would be sufficient here. +#if HWY_IS_MSAN + const Vec128 idx_vec{idx.raw}; + const Indices128 idx2{Combine(dt, idx_vec, idx_vec).raw}; +#else + // We only keep LowerHalf of the result, which is valid in idx. + const Indices128 idx2{idx.raw}; +#endif + return LowerHalf(d, TableLookupLanes(Combine(dt, b, a), idx2)); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec128{_mm_permutex2var_epi8(a.raw, idx.raw, b.raw)}; +#else // AVX3 or below + const DFromV d; + const Vec128 idx_vec{idx.raw}; + +#if HWY_TARGET <= HWY_SSE4 + const Repartition du16; + const auto sel_hi_mask = + MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, idx_vec)))); +#else + const RebindToSigned di; + const auto sel_hi_mask = + RebindMask(d, BitCast(di, idx_vec) > Set(di, int8_t{15})); +#endif + + const auto lo_lookup_result = TableLookupBytes(a, idx_vec); +#if HWY_TARGET <= HWY_AVX3 + const Vec128 lookup_result{_mm_mask_shuffle_epi8( + lo_lookup_result.raw, sel_hi_mask.raw, b.raw, idx_vec.raw)}; + return lookup_result; +#else + const auto hi_lookup_result = TableLookupBytes(b, idx_vec); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#endif // HWY_TARGET <= HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX3_DL +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_permutex2var_epi16(a.raw, idx.raw, b.raw)}; +#elif HWY_TARGET == HWY_SSE2 + const DFromV d; + const RebindToSigned di; + const Vec128 idx_vec{idx.raw}; + const auto sel_hi_mask = + RebindMask(d, BitCast(di, idx_vec) > Set(di, int16_t{7})); + const auto lo_lookup_result = TableLookupLanes(a, idx); + const auto hi_lookup_result = TableLookupLanes(b, idx); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#else + const DFromV d; + const Repartition du8; + return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b), + Indices128{idx.raw})); +#endif +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_permutex2var_epi32(a.raw, idx.raw, b.raw)}; +#else // AVX2 or below + const DFromV d; + +#if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SSE2 + const Vec128 idx_vec{idx.raw}; + +#if HWY_TARGET <= HWY_AVX2 + const RebindToFloat d_sel; + const auto sel_hi_mask = 
MaskFromVec(BitCast(d_sel, ShiftLeft<29>(idx_vec))); +#else + const RebindToSigned d_sel; + const auto sel_hi_mask = BitCast(d_sel, idx_vec) > Set(d_sel, int32_t{3}); +#endif + + const auto lo_lookup_result = BitCast(d_sel, TableLookupLanes(a, idx)); + const auto hi_lookup_result = BitCast(d_sel, TableLookupLanes(b, idx)); + return BitCast(d, + IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result)); +#else // SSSE3 or SSE4 + const Repartition du8; + return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b), + Indices128{idx.raw})); +#endif // HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SSE2 +#endif // HWY_TARGET <= HWY_AVX3 +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, + Vec128 b, + Indices128 idx) { + return Vec128{_mm_permutex2var_ph(a.raw, idx.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_permutex2var_ps(a.raw, idx.raw, b.raw)}; +#elif HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SSE2 + const DFromV d; + +#if HWY_TARGET <= HWY_AVX2 + const auto sel_hi_mask = + MaskFromVec(BitCast(d, ShiftLeft<29>(Vec128{idx.raw}))); +#else + const RebindToSigned di; + const auto sel_hi_mask = + RebindMask(d, Vec128{idx.raw} > Set(di, int32_t{3})); +#endif + + const auto lo_lookup_result = TableLookupLanes(a, idx); + const auto hi_lookup_result = TableLookupLanes(b, idx); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#else // SSSE3 or SSE4 + const DFromV d; + const Repartition du8; + return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b), + Indices128{idx.raw})); +#endif +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_permutex2var_epi64(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const Vec128 idx_vec{idx.raw}; + const Indices128 idx_mod{And(idx_vec, Set(d, T{1})).raw}; + +#if HWY_TARGET <= HWY_SSE4 + const RebindToFloat d_sel; + const auto sel_hi_mask = MaskFromVec(BitCast(d_sel, ShiftLeft<62>(idx_vec))); +#else // SSE2 or SSSE3 + const Repartition di32; + const RebindToSigned d_sel; + const auto sel_hi_mask = MaskFromVec( + BitCast(d_sel, VecFromMask(di32, DupEven(BitCast(di32, idx_vec)) > + Set(di32, int32_t{1})))); +#endif // HWY_TARGET <= HWY_SSE4 + + const auto lo_lookup_result = BitCast(d_sel, TableLookupLanes(a, idx_mod)); + const auto hi_lookup_result = BitCast(d_sel, TableLookupLanes(b, idx_mod)); + return BitCast(d, + IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result)); +#endif // HWY_TARGET <= HWY_AVX3 +} + +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_permutex2var_pd(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const RebindToSigned di; + const Vec128 idx_vec{idx.raw}; + const Indices128 idx_mod{And(idx_vec, Set(di, int64_t{1})).raw}; + +#if HWY_TARGET <= HWY_SSE4 + const auto sel_hi_mask = MaskFromVec(BitCast(d, ShiftLeft<62>(idx_vec))); +#else // SSE2 or SSSE3 + const Repartition di32; + const auto sel_hi_mask = + MaskFromVec(BitCast(d, VecFromMask(di32, DupEven(BitCast(di32, idx_vec)) > + Set(di32, int32_t{1})))); +#endif // HWY_TARGET <= HWY_SSE4 + + const auto lo_lookup_result = TableLookupLanes(a, idx_mod); + const auto hi_lookup_result = TableLookupLanes(b, idx_mod); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#endif // HWY_TARGET <= HWY_AVX3 +} + +// 
------------------------------ OddEven (IfThenElse) + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) static constexpr uint8_t mask[16] = { + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; +#if HWY_TARGET >= HWY_SSSE3 + const Repartition d8; + alignas(16) static constexpr uint8_t mask[16] = { + 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +#else + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm_blend_epi16( + BitCast(du, a).raw, BitCast(du, b).raw, 0x55)}); +#endif +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i odd = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128i even = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_epi32(even, odd)}; +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _ps can do 3/cycle. + const DFromV d; + const RebindToFloat df; + return BitCast(d, Vec128{_mm_blend_ps(BitCast(df, a).raw, + BitCast(df, b).raw, 5)}); +#endif +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + // Same as ConcatUpperLower for full vectors; do not call that because this + // is more efficient for 64x1 vectors. + const DFromV d; + const RebindToFloat dd; +#if HWY_TARGET >= HWY_SSSE3 + return BitCast( + d, Vec128{_mm_shuffle_pd( + BitCast(dd, b).raw, BitCast(dd, a).raw, _MM_SHUFFLE2(1, 0))}); +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, a).raw, + BitCast(dd, b).raw, 1)}); +#endif +} + +template +HWY_API Vec128 OddEven(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + // SHUFPS must fill the lower half of the output from one input, so we + // need another shuffle. Unpack avoids another immediate byte. 
+ const __m128 odd = _mm_shuffle_ps(a.raw, a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128 even = _mm_shuffle_ps(b.raw, b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_ps(even, odd)}; +#else + return Vec128{_mm_blend_ps(a.raw, b.raw, 5)}; +#endif +} + +// -------------------------- InterleaveEven + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + return ConcatEven(d, b, a); +} + +// I8/U8 InterleaveEven is generic for all vector lengths that are >= 4 bytes +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Repartition du16; + return OddEven(BitCast(d, ShiftLeft<8>(BitCast(du16, b))), a); +} + +// I16/U16 InterleaveEven is generic for all vector lengths that are >= 8 bytes +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Repartition du32; + return OddEven(BitCast(d, ShiftLeft<16>(BitCast(du32, b))), a); +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_epi32( + a.raw, static_cast<__mmask8>(0x0A), b.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_ps(a.raw, static_cast<__mmask8>(0x0A), + b.raw, b.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +#else +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const auto b2_b0_a2_a0 = ConcatEven(df, BitCast(df, b), BitCast(df, a)); + return BitCast( + d, VFromD{_mm_shuffle_ps(b2_b0_a2_a0.raw, b2_b0_a2_a0.raw, + _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// -------------------------- InterleaveOdd + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return ConcatOdd(d, b, a); +} + +// I8/U8 InterleaveOdd is generic for all vector lengths that are >= 4 bytes +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Repartition du16; + return OddEven(b, BitCast(d, ShiftRight<8>(BitCast(du16, a)))); +} + +// I16/U16 InterleaveOdd is generic for all vector lengths that are >= 8 bytes +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Repartition du32; + return OddEven(b, BitCast(d, ShiftRight<16>(BitCast(du32, a)))); +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_epi32( + b.raw, static_cast<__mmask8>(0x05), a.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_ps(b.raw, static_cast<__mmask8>(0x05), + a.raw, a.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +#else +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const auto b3_b1_a3_a1 = ConcatOdd(df, BitCast(df, b), BitCast(df, a)); + return BitCast( + d, VFromD{_mm_shuffle_ps(b3_b1_a3_a1.raw, b3_b1_a3_a1.raw, + _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V 
InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ Shl (ZipLower, Mul) + +// Use AVX2/3 variable shifts where available, otherwise multiply by powers of +// two from loading float exponents, which is considerably faster (according +// to LLVM-MCA) than scalar or testing bits: https://gcc.godbolt.org/z/9G7Y9v. + +namespace detail { + +#if HWY_TARGET == HWY_AVX2 // Unused for AVX3 - we use sllv directly +template +HWY_API V AVX2ShlU16Vec128(V v, V bits) { + const DFromV d; + const Rebind du32; + return TruncateTo(d, PromoteTo(du32, v) << PromoteTo(du32, bits)); +} +#elif HWY_TARGET > HWY_AVX2 + +template +static HWY_INLINE VFromD Pow2ConvF32ToI32( + D32 d32, VFromD> vf32) { + const RebindToSigned di32; +#if HWY_COMPILER_GCC_ACTUAL + // ConvertInRangeTo is safe with GCC due the inline assembly workaround used + // for F32->I32 ConvertInRangeTo with GCC + return BitCast(d32, ConvertInRangeTo(di32, vf32)); +#else + // Otherwise, use NearestIntInRange because we rely on the native 0x80..00 + // overflow behavior + return BitCast(d32, NearestIntInRange(di32, vf32)); +#endif +} + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. +template +HWY_INLINE Vec128> Pow2(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + const Rebind df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dw, zero, upper); + const auto f1 = ZipUpper(dw, zero, upper); + // See cvtps comment below. + const VFromD bits0 = Pow2ConvF32ToI32(dw, BitCast(df, f0)); + const VFromD bits1 = Pow2ConvF32ToI32(dw, BitCast(df, f1)); +#if HWY_TARGET <= HWY_SSE4 + return VFromD{_mm_packus_epi32(bits0.raw, bits1.raw)}; +#else + return ConcatEven(du, BitCast(du, bits1), BitCast(du, bits0)); +#endif +} + +template +HWY_INLINE Vec128, N> Pow2(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const Twice dt_u; + const RepartitionToWide dt_w; + const RebindToFloat dt_f; + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dt_w, Zero(dt_u), ResizeBitCast(dt_u, upper)); + // See cvtps comment below. + const VFromD bits0 = + Pow2ConvF32ToI32(dt_w, BitCast(dt_f, f0)); +#if HWY_TARGET <= HWY_SSE4 + return VFromD{_mm_packus_epi32(bits0.raw, bits0.raw)}; +#elif HWY_TARGET == HWY_SSSE3 + alignas(16) + const uint16_t kCompactEvenU16[8] = {0x0100, 0x0504, 0x0908, 0x0D0C}; + return TableLookupBytes(bits0, Load(du, kCompactEvenU16)); +#else + const RebindToSigned dt_i32; + const auto bits0_i32 = ShiftRight<16>(BitCast(dt_i32, ShiftLeft<16>(bits0))); + return VFromD{_mm_packs_epi32(bits0_i32.raw, bits0_i32.raw)}; +#endif +} + +// Same, for 32-bit shifts. +template +HWY_INLINE Vec128, N> Pow2(const Vec128 v) { + const DFromV d; + const RebindToFloat df; + const auto exp = ShiftLeft<23>(v); + const auto f = exp + Set(d, 0x3F800000); // 1.0f + // Do not use ConvertTo because we rely on the native 0x80..00 overflow + // behavior. 
+ return Pow2ConvF32ToI32(d, BitCast(df, f)); +} + +#endif // HWY_TARGET > HWY_AVX2 + +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_sllv_epi16(v.raw, bits.raw)}; +#elif HWY_TARGET == HWY_AVX2 + return AVX2ShlU16Vec128(v, bits); +#else + return v * Pow2(bits); +#endif +} + +#if HWY_TARGET > HWY_AVX3 +HWY_API Vec16 Shl(hwy::UnsignedTag /*tag*/, Vec16 v, + Vec16 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 bits16{_mm_cvtepu16_epi64(bits.raw)}; +#else + const auto bits16 = And(bits, Vec16{_mm_set_epi64x(0, 0xFFFF)}); +#endif + return Vec16{_mm_sll_epi16(v.raw, bits16.raw)}; +} +#endif + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V AVX2ShlU8Vec128(V v, V bits) { + const DFromV d; + const Rebind du16; + return TruncateTo(d, PromoteTo(du16, v) << PromoteTo(du16, bits)); +} +#elif HWY_TARGET <= HWY_AVX2 +template +HWY_INLINE V AVX2ShlU8Vec128(V v, V bits) { + const DFromV d; + const Rebind du32; + return TruncateTo(d, PromoteTo(du32, v) << PromoteTo(du32, bits)); +} +template +HWY_INLINE V AVX2ShlU8Vec128(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind du16; + const Rebind dh_u32; + + const VFromD lo_shl_result = + PromoteTo(dh_u32, LowerHalf(dh, v)) + << PromoteTo(dh_u32, LowerHalf(dh, bits)); + const VFromD hi_shl_result = + PromoteTo(dh_u32, UpperHalf(dh, v)) + << PromoteTo(dh_u32, UpperHalf(dh, bits)); + const VFromD u16_shl_result = ConcatEven( + du16, BitCast(du16, hi_shl_result), BitCast(du16, lo_shl_result)); + return TruncateTo(d, u16_shl_result); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// 8-bit: may use the Shl overload for uint16_t. +template +HWY_API Vec128 Shl(hwy::UnsignedTag tag, Vec128 v, + Vec128 bits) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3_DL + (void)tag; + // kMask[i] = 0xFF >> i + alignas(16) static constexpr uint8_t kMasks[16] = { + 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0x00}; + // kShl[i] = 1 << i + alignas(16) static constexpr uint8_t kShl[16] = {1, 2, 4, 8, 0x10, + 0x20, 0x40, 0x80, 0x00}; + v = And(v, TableLookupBytes(Load(Full64(), kMasks), bits)); + const VFromD mul = + TableLookupBytes(Load(Full64(), kShl), bits); + return VFromD{_mm_gf2p8mul_epi8(v.raw, mul.raw)}; +#elif HWY_TARGET <= HWY_AVX2 + (void)tag; + (void)d; + return AVX2ShlU8Vec128(v, bits); +#else + const Repartition dw; + using VW = VFromD; + const VW even_mask = Set(dw, 0x00FF); + const VW odd_mask = Set(dw, 0xFF00); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + // Shift even lanes in-place + const VW evens = Shl(tag, vw, And(bits16, even_mask)); + const VW odds = Shl(tag, And(vw, odd_mask), ShiftRight<8>(bits16)); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +#endif +} +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 bits8{_mm_cvtepu8_epi64(bits.raw)}; +#else + const Vec16 bits8 = + And(Vec16{bits.raw}, Vec16{_mm_set_epi64x(0, 0xFF)}); +#endif + return Vec128{_mm_sll_epi16(v.raw, bits8.raw)}; +} + +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET >= HWY_SSE4 + return v * Pow2(bits); +#else + return Vec128{_mm_sllv_epi32(v.raw, bits.raw)}; +#endif +} + +#if HWY_TARGET >= HWY_SSE4 +HWY_API Vec32 Shl(hwy::UnsignedTag /*tag*/, Vec32 v, + const Vec32 bits) { +#if HWY_TARGET == HWY_SSE4 + const Vec32 bits32{_mm_cvtepu32_epi64(bits.raw)}; +#else + const auto bits32 = + Combine(Full64(), Zero(Full32()), bits); +#endif + return 
Vec32{_mm_sll_epi32(v.raw, bits32.raw)}; +} +#endif + +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET >= HWY_SSE4 + const DFromV d; + // Individual shifts and combine + const Vec128 out0{_mm_sll_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128 out1{_mm_sll_epi64(v.raw, bits1)}; + return ConcatUpperLower(d, out1, out0); +#else + return Vec128{_mm_sllv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64 Shl(hwy::UnsignedTag /*tag*/, Vec64 v, + Vec64 bits) { + return Vec64{_mm_sll_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template +HWY_API Vec128 Shl(hwy::SignedTag /*tag*/, Vec128 v, + Vec128 bits) { + const DFromV di; + const RebindToUnsigned du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr (mul, mask, BroadcastSignBit) + +// Use AVX2+ variable shifts except for SSSE3/SSE4. There, we use +// widening multiplication by powers of two obtained by loading float exponents, +// followed by a constant right-shift. This is still faster than a scalar or +// bit-test approach: https://gcc.godbolt.org/z/9G7Y9v. + +#if HWY_TARGET <= HWY_AVX2 +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V AVX2ShrU8Vec128(V v, V bits) { + const DFromV d; + const Rebind du16; + const RebindToSigned di16; + return DemoteTo(d, + BitCast(di16, PromoteTo(du16, v) >> PromoteTo(du16, bits))); +} +#else // AVX2 +template +HWY_INLINE V AVX2ShrU16Vec128(V v, V bits) { + const DFromV d; + const Rebind du32; + const RebindToSigned di32; + return DemoteTo(d, + BitCast(di32, PromoteTo(du32, v) >> PromoteTo(du32, bits))); +} +template +HWY_INLINE V AVX2ShrU8Vec128(V v, V bits) { + const DFromV d; + const Rebind du32; + const RebindToSigned di32; + return DemoteTo(d, + BitCast(di32, PromoteTo(du32, v) >> PromoteTo(du32, bits))); +} +template +HWY_INLINE V AVX2ShrU8Vec128(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind di16; + const Rebind du16; + const Rebind dh_i32; + const Rebind dh_u32; + + const auto lo_shr_result = + BitCast(dh_i32, PromoteTo(dh_u32, LowerHalf(dh, v)) >> + PromoteTo(dh_u32, LowerHalf(dh, bits))); + const auto hi_shr_result = + BitCast(dh_i32, PromoteTo(dh_u32, UpperHalf(dh, v)) >> + PromoteTo(dh_u32, UpperHalf(dh, bits))); + const auto i16_shr_result = + BitCast(di16, OrderedDemote2To(du16, lo_shr_result, hi_shr_result)); + return DemoteTo(d, i16_shr_result); +} +#endif // HWY_TARGET <= HWY_AVX3 + +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX2 + +template +HWY_API Vec128 operator>>(Vec128 in, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srlv_epi16(in.raw, bits.raw)}; +#elif HWY_TARGET <= HWY_AVX2 + return detail::AVX2ShrU16Vec128(in, bits); +#else + const DFromV d; + // For bits=0, we cannot mul by 2^16, so fix the result later. + const auto out = MulHigh(in, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. 
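// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// The SSSE3/SSE4 path above computes a variable 16-bit right shift as
// MulHigh(in, Pow2(16 - bits)): (v >> s) is the upper 16 bits of
// v * 2^(16 - s), except for s == 0, which the IfThenElse below fixes up
// because 2^16 does not fit in 16 bits. Scalar model (name is illustrative):
#include <cstdint>
static inline uint16_t ScalarShrU16(uint16_t v, unsigned s) {  // s in [0, 15]
  if (s == 0) return v;                        // the "bits == 0" fix-up
  const uint32_t mul = 1u << (16 - s);         // Pow2(16 - s)
  return static_cast<uint16_t>((static_cast<uint32_t>(v) * mul) >> 16);
}
// --- end of sketch ---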
+ return IfThenElse(bits == Zero(d), in, out); +#endif +} + +#if HWY_TARGET > HWY_AVX3 +HWY_API Vec16 operator>>(const Vec16 in, + const Vec16 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 bits16{_mm_cvtepu16_epi64(bits.raw)}; +#else + const auto bits16 = And(bits, Vec16{_mm_set_epi64x(0, 0xFFFF)}); +#endif + return Vec16{_mm_srl_epi16(in.raw, bits16.raw)}; +} +#endif + +// 8-bit uses 16-bit shifts. +template +HWY_API Vec128 operator>>(Vec128 in, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX2 + return detail::AVX2ShrU8Vec128(in, bits); +#else + const DFromV d; + const Repartition dw; + using VW = VFromD; + const VW mask = Set(dw, 0x00FF); + const VW vw = BitCast(dw, in); + const VW bits16 = BitCast(dw, bits); + const VW evens = And(vw, mask) >> And(bits16, mask); + // Shift odd lanes in-place + const VW odds = vw >> ShiftRight<8>(bits16); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +#endif +} +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 in8{_mm_cvtepu8_epi16(in.raw)}; + const Vec16 bits8{_mm_cvtepu8_epi64(bits.raw)}; +#else + const Vec16 mask{_mm_set_epi64x(0, 0xFF)}; + const Vec16 in8 = And(Vec16{in.raw}, mask); + const Vec16 bits8 = And(Vec16{bits.raw}, mask); +#endif + return Vec128{_mm_srl_epi16(in8.raw, bits8.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET >= HWY_SSE4 + // 32x32 -> 64 bit mul, then shift right by 32. + const DFromV d32; + // Move odd lanes into position for the second mul. Shuffle more gracefully + // handles N=1 than repartitioning to u64 and shifting 32 bits right. + const Vec128 in31{_mm_shuffle_epi32(in.raw, 0x31)}; + // For bits=0, we cannot mul by 2^32, so fix the result later. + const auto mul = detail::Pow2(Set(d32, 32) - bits); + const auto out20 = ShiftRight<32>(MulEven(in, mul)); // z 2 z 0 + const Vec128 mul31{_mm_shuffle_epi32(mul.raw, 0x31)}; + // No need to shift right, already in the correct position. + const auto out31 = BitCast(d32, MulEven(in31, mul31)); // 3 ? 1 ? + const Vec128 out = OddEven(out31, BitCast(d32, out20)); + // Replace output with input where bits == 0. 
+ return IfThenElse(bits == Zero(d32), in, out); +#else + return Vec128{_mm_srlv_epi32(in.raw, bits.raw)}; +#endif +} + +#if HWY_TARGET >= HWY_SSE4 +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSE4 + const Vec32 bits32{_mm_cvtepu32_epi64(bits.raw)}; +#else + const auto bits32 = + Combine(Full64(), Zero(Full32()), bits); +#endif + return Vec128{_mm_srl_epi32(in.raw, bits32.raw)}; +} +#endif + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET >= HWY_SSE4 + const DFromV d; + // Individual shifts and combine + const Vec128 out0{_mm_srl_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128 out1{_mm_srl_epi64(v.raw, bits1)}; + return ConcatUpperLower(d, out1, out0); +#else + return Vec128{_mm_srlv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + return Vec64{_mm_srl_epi64(v.raw, bits.raw)}; +} + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V AVX2ShrI8Vec128(V v, V bits) { + const DFromV d; + const Rebind di16; + return DemoteTo(d, PromoteTo(di16, v) >> PromoteTo(di16, bits)); +} +#elif HWY_TARGET <= HWY_AVX2 // AVX2 +template +HWY_INLINE V AVX2ShrI16Vec128(V v, V bits) { + const DFromV d; + const Rebind di32; + return DemoteTo(d, PromoteTo(di32, v) >> PromoteTo(di32, bits)); +} +template +HWY_INLINE V AVX2ShrI8Vec128(V v, V bits) { + const DFromV d; + const Rebind di32; + return DemoteTo(d, PromoteTo(di32, v) >> PromoteTo(di32, bits)); +} +template +HWY_INLINE V AVX2ShrI8Vec128(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind di16; + const Rebind dh_i32; + + const auto lo_shr_result = PromoteTo(dh_i32, LowerHalf(dh, v)) >> + PromoteTo(dh_i32, LowerHalf(dh, bits)); + const auto hi_shr_result = PromoteTo(dh_i32, UpperHalf(dh, v)) >> + PromoteTo(dh_i32, UpperHalf(dh, bits)); + const auto i16_shr_result = + OrderedDemote2To(di16, lo_shr_result, hi_shr_result); + return DemoteTo(d, i16_shr_result); +} +#endif + +#if HWY_TARGET > HWY_AVX3 +// Also used in x86_256-inl.h. +template +HWY_INLINE V SignedShr(const DI di, const V v, const V count_i) { + const RebindToUnsigned du; + const auto count = BitCast(du, count_i); // same type as value to shift + // Clear sign and restore afterwards. This is preferable to shifting the MSB + // downwards because Shr is somewhat more expensive than Shl. 
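// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// SignedShr below derives an arithmetic right shift from a logical one:
// XOR with the broadcast sign (0 or all-ones), shift, XOR back. For negative
// v this shifts ~v (the "off by one" value) and the final XOR restores the
// one's complement, which is exactly the arithmetic-shift result. Scalar
// model for 32-bit lanes (helper name is illustrative):
#include <cstdint>
static inline int32_t ScalarSignedShr(int32_t v, unsigned count) {
  const uint32_t sign = (v < 0) ? 0xFFFFFFFFu : 0u;          // BroadcastSignBit
  const uint32_t flipped = static_cast<uint32_t>(v) ^ sign;  // ~v if negative
  return static_cast<int32_t>((flipped >> count) ^ sign);
}
// --- end of sketch ---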
+ const auto sign = BroadcastSignBit(v); + const auto abs = BitCast(du, v ^ sign); // off by one, but fixed below + return BitCast(di, abs >> count) ^ sign; +} +#endif + +} // namespace detail + +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi16(v.raw, bits.raw)}; +#elif HWY_TARGET <= HWY_AVX2 + return detail::AVX2ShrI16Vec128(v, bits); +#else + const DFromV d; + return detail::SignedShr(d, v, bits); +#endif +} + +#if HWY_TARGET > HWY_AVX3 +HWY_API Vec16 operator>>(Vec16 v, Vec16 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 bits16{_mm_cvtepu16_epi64(bits.raw)}; +#else + const auto bits16 = And(bits, Vec16{_mm_set_epi64x(0, 0xFFFF)}); +#endif + return Vec16{_mm_sra_epi16(v.raw, bits16.raw)}; +} +#endif + +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX2 + return detail::AVX2ShrI8Vec128(v, bits); +#else + const DFromV d; + return detail::SignedShr(d, v, bits); +#endif +} +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 vi16{_mm_cvtepi8_epi16(v.raw)}; + const Vec16 bits8{_mm_cvtepu8_epi64(bits.raw)}; +#else + const DFromV d; + const Rebind di16; + const Twice dt; + + const auto vi16 = ShiftRight<8>(BitCast(di16, Combine(dt, v, v))); + const Vec16 bits8 = + And(Vec16{bits.raw}, Vec16{_mm_set_epi64x(0, 0xFF)}); +#endif + return Vec128{_mm_sra_epi16(vi16.raw, bits8.raw)}; +} + +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX2 + return Vec128{_mm_srav_epi32(v.raw, bits.raw)}; +#else + const DFromV d; + return detail::SignedShr(d, v, bits); +#endif +} + +#if HWY_TARGET > HWY_AVX2 +HWY_API Vec32 operator>>(Vec32 v, Vec32 bits) { +#if HWY_TARGET == HWY_SSE4 + const Vec32 bits32{_mm_cvtepu32_epi64(bits.raw)}; +#else + const auto bits32 = Combine(Full64(), Zero(Full32()), bits); +#endif + return Vec32{_mm_sra_epi32(v.raw, bits32.raw)}; +} +#endif + +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi64(v.raw, bits.raw)}; +#else + const DFromV d; + return detail::SignedShr(d, v, bits); +#endif +} + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +namespace detail { + +template )> +static HWY_INLINE V SSE2Mul128(V a, V b, V& mulH) { + const DFromV du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need the lower 32 bits + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need + // the even (lower 64 bits of every 128-bit block) results. 
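// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// A scalar version of the Knuth double-word multiplication used by
// SSE2Mul128: one 64x64 -> 128-bit product assembled from four 32x32 -> 64
// partial products (the role MulEven plays in the vector code). Returns the
// low 64 bits and writes the high 64 bits to *hi (helper name illustrative):
#include <cstdint>
static inline uint64_t ScalarMul128(uint64_t a, uint64_t b, uint64_t* hi) {
  const uint64_t aL = a & 0xFFFFFFFFu, aH = a >> 32;
  const uint64_t bL = b & 0xFFFFFFFFu, bH = b >> 32;
  const uint64_t aLbL = aL * bL;
  const uint64_t t2 = aH * bL + (aLbL >> 32);   // cannot overflow 64 bits
  const uint64_t t = aL * bH + (t2 & 0xFFFFFFFFu);
  *hi = aH * bH + (t2 >> 32) + (t >> 32);
  return (t << 32) + (aLbL & 0xFFFFFFFFu);
}
// --- end of sketch ---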
See + // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.txt + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + mulH = MulEven(aH, bH) + w1 + k; + return ShiftLeft<32>(t) + w3; +} + +template )> +static HWY_INLINE V SSE2Mul128(V a, V b, V& mulH) { + const DFromV di64; + const RebindToUnsigned du64; + using VU64 = VFromD; + + VU64 unsigned_mulH; + const auto mulL = BitCast( + di64, SSE2Mul128(BitCast(du64, a), BitCast(du64, b), unsigned_mulH)); + mulH = BitCast(di64, unsigned_mulH) - And(BroadcastSignBit(a), b) - + And(a, BroadcastSignBit(b)); + return mulL; +} + +} // namespace detail + +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +template ), + HWY_IF_V_SIZE_GT_V(V, (HWY_ARCH_X86_64 ? 16 : 8))> +HWY_API V MulEven(V a, V b) { + V mulH; + const V mulL = detail::SSE2Mul128(a, b, mulH); + return InterleaveLower(mulL, mulH); +} + +template ), + HWY_IF_V_SIZE_GT_V(V, (HWY_ARCH_X86_64 ? 16 : 8))> +HWY_API V MulOdd(V a, V b) { + const DFromV du64; + V mulH; + const V mulL = detail::SSE2Mul128(a, b, mulH); + return InterleaveUpper(du64, mulL, mulH); +} + +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +template ), + HWY_IF_V_SIZE_GT_V(V, (HWY_ARCH_X86_64 ? 8 : 0))> +HWY_API V MulHigh(V a, V b) { + V mulH; + detail::SSE2Mul128(a, b, mulH); + return mulH; +} + +#if HWY_ARCH_X86_64 + +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + alignas(16) T mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(d, mul); +} + +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + const Half d2; + alignas(16) T mul[2]; + const T a1 = GetLane(UpperHalf(d2, a)); + const T b1 = GetLane(UpperHalf(d2, b)); + mul[0] = Mul128(a1, b1, &mul[1]); + return Load(d, mul); +} + +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Vec64{_mm_cvtsi64_si128(static_cast(hi))}; +} + +#endif // HWY_ARCH_X86_64 + +// ================================================== CONVERT (2) + +// ------------------------------ PromoteEvenTo/PromoteOddTo + +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +// I32->I64 PromoteEvenTo/PromoteOddTo + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec64 v) { + return PromoteLowerTo(d_to, v); +} + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec128 v) { + const Repartition d_from; + return PromoteLowerTo(d_to, ConcatEven(d_from, v, v)); +} + +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + V v) { + const Repartition d_from; + return PromoteLowerTo(d_to, ConcatOdd(d_from, v, v)); +} + +} // namespace detail +#endif + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "third_party/highway/hwy/ops/inside-inl.h" + +// ------------------------------ WidenMulPairwiseAdd (PromoteEvenTo) + +#if HWY_NATIVE_DOT_BF16 + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return VFromD{_mm_dpbf16_ps(Zero(df).raw, + reinterpret_cast<__m128bh>(a.raw), + 
reinterpret_cast<__m128bh>(b.raw))}; +} + +#else + +// Generic for all vector lengths. +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} + +#endif // HWY_NATIVE_DOT_BF16 + +// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. +template >> +HWY_API VFromD WidenMulPairwiseAdd(D32 /* tag */, V16 a, V16 b) { + return VFromD{_mm_madd_epi16(a.raw, b.raw)}; +} + +// Generic for all vector lengths. +template >> +HWY_API VFromD WidenMulPairwiseAdd(DU32 du32, VU16 a, VU16 b) { + const auto p_lo = a * b; + const auto p_hi = MulHigh(a, b); + + const auto p_hi1_lo0 = BitCast(du32, OddEven(p_hi, p_lo)); + const auto p_hi0_lo1 = Or(ShiftLeft<16>(BitCast(du32, p_hi)), + ShiftRight<16>(BitCast(du32, p_lo))); + return Add(BitCast(du32, p_hi1_lo0), BitCast(du32, p_hi0_lo1)); +} + +// ------------------------------ SatWidenMulPairwiseAdd + +#if HWY_TARGET <= HWY_SSSE3 + +#ifdef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#undef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#else +#define HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#endif + +// Even if N=1, the input is always at least 2 lanes, hence _mm_maddubs_epi16 +// is safe. +template +HWY_API VFromD SatWidenMulPairwiseAdd( + DI16 /* tag */, VFromD> a, + VFromD> b) { + return VFromD{_mm_maddubs_epi16(a.raw, b.raw)}; +} + +#endif + +// ------------------------------ SatWidenMulPairwiseAccumulate + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#else +#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#endif + +// Even if N=1, the I16 vectors have at least 2 lanes, hence _mm_dpwssds_epi32 +// is safe. +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm_dpwssds_epi32(sum.raw, a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ ReorderWidenMulAccumulate (PromoteEvenTo) + +#if HWY_NATIVE_DOT_BF16 + +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b, + const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD{_mm_dpbf16_ps(sum0.raw, reinterpret_cast<__m128bh>(a.raw), + reinterpret_cast<__m128bh>(b.raw))}; +} + +#endif // HWY_NATIVE_DOT_BF16 + +// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. 
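// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// Scalar reference for the pairwise widening multiply-add implemented above:
// each 32-bit output lane is a[2i]*b[2i] + a[2i+1]*b[2i+1], computed in the
// wider type (this is what _mm_madd_epi16 does for the i16 case).
#include <cstddef>
#include <cstdint>
static inline void ScalarWidenMulPairwiseAdd(const int16_t* a, const int16_t* b,
                                             int32_t* out, size_t num_out) {
  for (size_t i = 0; i < num_out; ++i) {
    out[i] = static_cast<int32_t>(a[2 * i]) * b[2 * i] +
             static_cast<int32_t>(a[2 * i + 1]) * b[2 * i + 1];
  }
}
// --- end of sketch ---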
+template >> +HWY_API VFromD ReorderWidenMulAccumulate(D32 d, V16 a, V16 b, + const VFromD sum0, + VFromD& /*sum1*/) { + (void)d; +#if HWY_TARGET <= HWY_AVX3_DL + return VFromD{_mm_dpwssd_epi32(sum0.raw, a.raw, b.raw)}; +#else + return sum0 + WidenMulPairwiseAdd(d, a, b); +#endif +} + +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DU32 d, VU16 a, VU16 b, + const VFromD sum0, + VFromD& /*sum1*/) { + (void)d; + return sum0 + WidenMulPairwiseAdd(d, a, b); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, + Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +template +HWY_API Vec128 RearrangeToOddPlusEven( + const Vec128 sum0, Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); +} + +// ------------------------------ SumOfMulQuadAccumulate +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD{_mm_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; +} + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + // TODO(janwas): AVX-VNNI-INT8 has dpbssd. + const Repartition du8; + + const auto a_u = BitCast(du8, a); + const auto result_sum_0 = SumOfMulQuadAccumulate(di32, a_u, b, sum); + const auto result_sum_1 = ShiftLeft<8>( + SumOfMulQuadAccumulate(di32, ShiftRight<7>(a_u), b, Zero(di32))); + return result_sum_0 - result_sum_1; +} + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 du32, VFromD> a, + VFromD> b, VFromD sum) { + // TODO(janwas): AVX-VNNI-INT8 has dpbuud. + const Repartition du8; + const RebindToSigned di8; + const RebindToSigned di32; + + const auto b_i = BitCast(di8, b); + const auto result_sum_0 = + SumOfMulQuadAccumulate(di32, a, b_i, BitCast(di32, sum)); + const auto result_sum_1 = ShiftLeft<8>( + SumOfMulQuadAccumulate(di32, a, BroadcastSignBit(b_i), Zero(di32))); + + return BitCast(du32, result_sum_0 - result_sum_1); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_packs_epi32(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const Rebind di32; + const auto zero_if_neg = AndNot(ShiftRight<31>(v), v); + const auto too_big = VecFromMask(di32, Gt(v, Set(di32, 0xFFFF))); + const auto clamped = Or(zero_if_neg, too_big); +#if HWY_TARGET == HWY_SSE2 + const Rebind du16; + const RebindToSigned di16; + return BitCast(du16, DemoteTo(di16, ShiftRight<16>(ShiftLeft<16>(clamped)))); +#else + const Repartition du16; + // Lower 2 bytes from each 32-bit lane; same as return type for fewer casts. 
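// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// The I8*I8 SumOfMulQuadAccumulate above only has an unsigned*signed dot
// product (dpbusd), so it reinterprets the signed a as unsigned and subtracts
// a correction of 256*b for every lane where a was negative:
// a*b == a_u*b - ((a_u >> 7) << 8) * b. Per-element scalar model:
#include <cstdint>
static inline int32_t ScalarMulI8ViaU8(int8_t a, int8_t b) {
  const uint8_t a_u = static_cast<uint8_t>(a);        // a + 256 if a < 0
  const int32_t unsigned_prod = static_cast<int32_t>(a_u) * b;
  const int32_t correction = ((a_u >> 7) << 8) * b;   // 256*b only if a < 0
  return unsigned_prod - correction;                  // equals (int)a * b
}
// --- end of sketch ---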
+ alignas(16) static constexpr uint16_t kLower2Bytes[16] = { + 0x0100, 0x0504, 0x0908, 0x0D0C, 0x8080, 0x8080, 0x8080, 0x8080}; + const auto lo2 = Load(du16, kLower2Bytes); + return VFromD{TableLookupBytes(BitCast(du16, clamped), lo2).raw}; +#endif +#else + return VFromD{_mm_packus_epi32(v.raw, v.raw)}; +#endif +} + +template +HWY_API VFromD DemoteTo(D du16, VFromD> v) { + const DFromV du32; + const RebindToSigned di32; +#if HWY_TARGET >= HWY_SSSE3 + const auto too_big = + VecFromMask(di32, Gt(BitCast(di32, ShiftRight<16>(v)), Zero(di32))); + const auto clamped = Or(BitCast(di32, v), too_big); +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned di16; + return BitCast(du16, DemoteTo(di16, ShiftRight<16>(ShiftLeft<16>(clamped)))); +#else + (void)du16; + const Repartition du16_full; + // Lower 2 bytes from each 32-bit lane; same as return type for fewer casts. + alignas(16) static constexpr uint16_t kLower2Bytes[16] = { + 0x0100, 0x0504, 0x0908, 0x0D0C, 0x8080, 0x8080, 0x8080, 0x8080}; + const auto lo2 = Load(du16_full, kLower2Bytes); + return VFromD{TableLookupBytes(BitCast(du16_full, clamped), lo2).raw}; +#endif +#else + return DemoteTo(du16, BitCast(di32, Min(v, Set(du32, 0x7FFFFFFF)))); +#endif +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return VFromD{_mm_packus_epi16(i16, i16)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_packus_epi16(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return VFromD{_mm_packs_epi16(i16, i16)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_packs_epi16(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D du8, VFromD> v) { +#if HWY_TARGET <= HWY_AVX3 + // NOTE: _mm_cvtusepi32_epi8 is a saturated conversion of 32-bit unsigned + // integers to 8-bit unsigned integers + (void)du8; + return VFromD{_mm_cvtusepi32_epi8(v.raw)}; +#else + const DFromV du32; + const RebindToSigned di32; + const auto max_i32 = Set(du32, 0x7FFFFFFFu); + +#if HWY_TARGET >= HWY_SSSE3 + // On SSE2/SSSE3, clamp u32 values to an i32 using the u8 Min operation + // as SSE2/SSSE3 can do an u8 Min operation in a single instruction. + + // The u8 Min operation below leaves the lower 24 bits of each 32-bit + // lane unchanged. + + // The u8 Min operation below will leave any values that are less than or + // equal to 0x7FFFFFFF unchanged. + + // For values that are greater than or equal to 0x80000000, the u8 Min + // operation below will force the upper 8 bits to 0x7F and leave the lower + // 24 bits unchanged. + + // An u8 Min operation is okay here as any clamped value that is greater than + // or equal to 0x80000000 will be clamped to a value between 0x7F000000 and + // 0x7FFFFFFF through the u8 Min operation below, which will then be converted + // to 0xFF through the i32->u8 demotion. 
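// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// The per-byte Min clamp described in the comment above, modeled on one u32:
// taking the byte-wise minimum against the little-endian bytes of 0x7FFFFFFF
// leaves values <= 0x7FFFFFFF unchanged and maps values >= 0x80000000 into
// [0x7F000000, 0x7FFFFFFF], all of which still saturate to 0xFF in the
// subsequent i32 -> u8 demotion.
#include <cstdint>
static inline uint32_t ScalarByteMinClampToI32Max(uint32_t v) {
  uint32_t out = 0;
  for (int byte = 0; byte < 4; ++byte) {
    const uint32_t vb = (v >> (8 * byte)) & 0xFF;
    const uint32_t cb = (0x7FFFFFFFu >> (8 * byte)) & 0xFF;  // FF FF FF 7F
    out |= (vb < cb ? vb : cb) << (8 * byte);
  }
  return out;
}
// --- end of sketch ---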
+ const Repartition du32_as_du8; + const auto clamped = BitCast( + di32, Min(BitCast(du32_as_du8, v), BitCast(du32_as_du8, max_i32))); +#else + const auto clamped = BitCast(di32, Min(v, max_i32)); +#endif + + return DemoteTo(du8, clamped); +#endif +} + +template +HWY_API VFromD DemoteTo(D du8, VFromD> v) { + const DFromV du16; + const RebindToSigned di16; + const auto max_i16 = Set(du16, 0x7FFF); + +#if HWY_TARGET >= HWY_SSSE3 + // On SSE2/SSSE3, clamp u16 values to an i16 using the u8 Min operation + // as SSE2/SSSE3 can do an u8 Min operation in a single instruction. + + // The u8 Min operation below leaves the lower 8 bits of each 16-bit + // lane unchanged. + + // The u8 Min operation below will leave any values that are less than or + // equal to 0x7FFF unchanged. + + // For values that are greater than or equal to 0x8000, the u8 Min + // operation below will force the upper 8 bits to 0x7F and leave the lower + // 8 bits unchanged. + + // An u8 Min operation is okay here as any clamped value that is greater than + // or equal to 0x8000 will be clamped to a value between 0x7F00 and + // 0x7FFF through the u8 Min operation below, which will then be converted + // to 0xFF through the i16->u8 demotion. + const Repartition du16_as_du8; + const auto clamped = BitCast( + di16, Min(BitCast(du16_as_du8, v), BitCast(du16_as_du8, max_i16))); +#else + const auto clamped = BitCast(di16, Min(v, max_i16)); +#endif + + return DemoteTo(du8, clamped); +} + +#if HWY_TARGET < HWY_SSE4 && !defined(HWY_DISABLE_F16C) + +// HWY_NATIVE_F16C was already toggled above. + +// Work around MSVC warning for _mm_cvtps_ph (8 is actually a valid immediate). +// clang-cl requires a non-empty string, so we 'ignore' the irrelevant -Wmain. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wmain") + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToUnsigned du16; + return BitCast( + df16, VFromD{_mm_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); +} + +HWY_DIAGNOSTICS(pop) + +#endif // F16C + +#if HWY_HAVE_FLOAT16 + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +template +HWY_API VFromD DemoteTo(D /*df16*/, VFromD> v) { + return VFromD{_mm_cvtpd_ph(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +// The _mm*_cvtneps_pbh and _mm*_cvtne2ps_pbh intrinsics require GCC 9 or later +// or Clang 10 or later + +// Also need GCC or Clang to bit cast the __m128bh, __m256bh, or __m512bh vector +// returned by the _mm*_cvtneps_pbh and _mm*_cvtne2ps_pbh intrinsics to a +// __m128i, __m256i, or __m512i as there are currently no intrinsics available +// (as of GCC 13 and Clang 17) to bit cast a __m128bh, __m256bh, or __m512bh +// vector to a __m128i, __m256i, or __m512i vector + +#if HWY_AVX3_HAVE_F32_TO_BF16C +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +template +HWY_API VFromD DemoteTo(D /*dbf16*/, VFromD> v) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m128i raw_result; + __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + // The _mm_cvtneps_pbh intrinsic returns a __m128bh vector that needs to be + // bit casted to a __m128i vector + return VFromD{detail::BitCastToInteger(_mm_cvtneps_pbh(v.raw))}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D /*dbf16*/, Vec128 a, + Vec128 b) { +#if HWY_COMPILER_CLANG >= 
1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m128i raw_result; + __asm__("vcvtne2ps2bf16 %2, %1, %0" + : "=v"(raw_result) + : "v"(b.raw), "v"(a.raw)); + return VFromD{raw_result}; +#else + // The _mm_cvtne2ps_pbh intrinsic returns a __m128bh vector that needs to be + // bit casted to a __m128i vector + return VFromD{detail::BitCastToInteger(_mm_cvtne2ps_pbh(b.raw, a.raw))}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec64 a, + Vec64 b) { + return VFromD{_mm_shuffle_epi32( + detail::BitCastToInteger(_mm_cvtne2ps_pbh(b.raw, a.raw)), + _MM_SHUFFLE(2, 0, 2, 0))}; +} + +template +HWY_API VFromD ReorderDemote2To(D dbf16, Vec32 a, Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dbf16, Combine(dt, b, a)); +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +// Specializations for partial vectors because packs_epi32 sets lanes above 2*N. +template +HWY_API VFromD ReorderDemote2To(D dn, Vec32 a, Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec64 a, + Vec64 b) { + return VFromD{_mm_shuffle_epi32(_mm_packs_epi32(a.raw, b.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{_mm_packs_epi32(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec32 a, Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API VFromD ReorderDemote2To(D dn, Vec64 a, Vec64 b) { +#if HWY_TARGET >= HWY_SSSE3 + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +#else + (void)dn; + return VFromD{_mm_shuffle_epi32(_mm_packus_epi32(a.raw, b.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +#endif +} +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + const Half dnh; + const auto u16_a = DemoteTo(dnh, a); + const auto u16_b = DemoteTo(dnh, b); + return Combine(dn, u16_b, u16_a); +#else + (void)dn; + return VFromD{_mm_packus_epi32(a.raw, b.raw)}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV du32; + const RebindToSigned di32; + const auto max_i32 = Set(du32, 0x7FFFFFFFu); + +#if HWY_TARGET >= HWY_SSSE3 + const Repartition du32_as_du8; + // On SSE2/SSSE3, clamp a and b using u8 Min operation + const auto clamped_a = BitCast( + di32, Min(BitCast(du32_as_du8, a), BitCast(du32_as_du8, max_i32))); + const auto clamped_b = BitCast( + di32, Min(BitCast(du32_as_du8, b), BitCast(du32_as_du8, max_i32))); +#else + const auto clamped_a = BitCast(di32, Min(a, max_i32)); + const auto clamped_b = BitCast(di32, Min(b, max_i32)); +#endif + + return ReorderDemote2To(dn, clamped_a, clamped_b); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +// Specializations for partial vectors because packs_epi32 sets lanes above 2*N. 
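// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// Scalar reference for the saturating i32 -> i16 packing (_mm_packs_epi32)
// that the ReorderDemote2To overloads below build on; the packus variants
// behave the same way but saturate to the unsigned range instead.
#include <cstdint>
static inline int16_t ScalarSatPackI32ToI16(int32_t v) {
  if (v > 32767) return 32767;
  if (v < -32768) return -32768;
  return static_cast<int16_t>(v);
}
// --- end of sketch ---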
+template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec64 a, + Vec64 b) { + return VFromD{_mm_shuffle_epi32(_mm_packs_epi16(a.raw, b.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{_mm_packs_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec64 a, + Vec64 b) { + return VFromD{_mm_shuffle_epi32(_mm_packus_epi16(a.raw, b.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{_mm_packus_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV du16; + const RebindToSigned di16; + const auto max_i16 = Set(du16, 0x7FFFu); + +#if HWY_TARGET >= HWY_SSSE3 + const Repartition du16_as_du8; + // On SSE2/SSSE3, clamp a and b using u8 Min operation + const auto clamped_a = BitCast( + di16, Min(BitCast(du16_as_du8, a), BitCast(du16_as_du8, max_i16))); + const auto clamped_b = BitCast( + di16, Min(BitCast(du16_as_du8, b), BitCast(du16_as_du8, max_i16))); +#else + const auto clamped_a = BitCast(di16, Min(a, max_i16)); + const auto clamped_b = BitCast(di16, Min(b, max_i16)); +#endif + + return ReorderDemote2To(dn, clamped_a, clamped_b); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +template ), + HWY_IF_V_SIZE_LE_D(D, 16), class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} + +#if HWY_AVX3_HAVE_F32_TO_BF16C +// F32 to BF16 OrderedDemote2To is generic for all vector lengths on targets +// that support AVX512BF16 +template +HWY_API VFromD OrderedDemote2To(D dbf16, VFromD> a, + VFromD> b) { + return ReorderDemote2To(dbf16, a, b); +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtpd_ps(v.raw)}; +} + +namespace detail { + +// Generic for all vector lengths. +template +HWY_INLINE VFromD ClampF64ToI32Max(D d, VFromD v) { + // The max can be exactly represented in binary64, so clamping beforehand + // prevents x86 conversion from raising an exception and returning 80..00. 
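// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// ClampF64ToI32Max relies on 2147483647.0 being exactly representable in
// binary64: clamping first keeps the truncating conversion in range, so the
// hardware never produces its 0x80000000 "integer indefinite" result for
// large inputs. Scalar model of the clamped demotion (both bounds handled
// explicitly here; NaN handling is omitted in this sketch):
#include <cstdint>
static inline int32_t ScalarDemoteF64ToI32(double v) {
  if (v >= 2147483647.0) return INT32_MAX;
  if (v <= -2147483648.0) return INT32_MIN;  // the vector path gets this from
                                             // the hardware's 0x80000000 result
  return static_cast<int32_t>(v);            // truncates toward zero
}
// --- end of sketch ---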
+ return Min(v, Set(d, 2147483647.0)); +} + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD +template +static constexpr HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::FloatTag /* to_type_tag */, TF from_val) { + return ConvertScalarTo(from_val); +} + +template +static HWY_BITCASTSCALAR_CONSTEXPR HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::SpecialTag /* to_type_tag */, TF from_val) { + return ConvertScalarTo(from_val); +} + +template +static HWY_BITCASTSCALAR_CXX14_CONSTEXPR HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::SignedTag /* to_type_tag */, TF from_val) { +#if HWY_HAVE_SCALAR_F16_TYPE && HWY_HAVE_SCALAR_F16_OPERATORS + using TFArith = If, hwy::bfloat16_t>(), float, + RemoveCvRef>; +#else + using TFArith = If>; +#endif + + const TFArith from_val_in_arith_type = ConvertScalarTo(from_val); + constexpr TTo kMinResultVal = LimitsMin(); + HWY_BITCASTSCALAR_CONSTEXPR const TFArith kMinOutOfRangePosVal = + ScalarAbs(ConvertScalarTo(kMinResultVal)); + + return (ScalarAbs(from_val_in_arith_type) < kMinOutOfRangePosVal) + ? ConvertScalarTo(from_val_in_arith_type) + : kMinResultVal; +} + +template +static HWY_CXX14_CONSTEXPR HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::UnsignedTag /* to_type_tag */, TF from_val) { +#if HWY_HAVE_SCALAR_F16_TYPE && HWY_HAVE_SCALAR_F16_OPERATORS + using TFArith = If, hwy::bfloat16_t>(), float, + RemoveCvRef>; +#else + using TFArith = If>; +#endif + + const TFArith from_val_in_arith_type = ConvertScalarTo(from_val); + constexpr TTo kTToMsb = static_cast(TTo{1} << (sizeof(TTo) * 8 - 1)); + constexpr const TFArith kNegOne = ConvertScalarTo(-1.0); + constexpr const TFArith kMinOutOfRangePosVal = + ConvertScalarTo(static_cast(kTToMsb) * 2.0); + + return (from_val_in_arith_type > kNegOne && + from_val_in_arith_type < kMinOutOfRangePosVal) + ? 
ConvertScalarTo(from_val_in_arith_type) + : LimitsMax(); +} + +template +static constexpr HWY_INLINE HWY_MAYBE_UNUSED TTo +X86ConvertScalarFromFloat(TF from_val) { + return X86ConvertScalarFromFloat(hwy::TypeTag>(), + from_val); +} +#endif // HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + +} // namespace detail + +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), int32_t{0}, + int32_t{0}); + } +#endif + + __m128i raw_result; + __asm__("%vcvttpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttpd_epi32(v.raw)}; +#endif +} + +// F64 to I32 DemoteTo is generic for all vector lengths +template +HWY_API VFromD DemoteTo(D di32, VFromD> v) { + const Rebind df64; + const VFromD clamped = detail::ClampF64ToI32Max(df64, v); + return DemoteInRangeTo(di32, clamped); +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), uint32_t{0}, + uint32_t{0}); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm_cvttpd_epu32(v.raw)}; +#endif +} + +// F64->U32 DemoteTo is generic for all vector lengths +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return DemoteInRangeTo(D(), ZeroIfNegative(v)); +} +#else // HWY_TARGET > HWY_AVX3 + +// F64 to U32 DemoteInRangeTo is generic for all vector lengths on +// SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD DemoteInRangeTo(D du32, VFromD> v) { + const RebindToSigned di32; + const Rebind df64; + const RebindToUnsigned du64; + + const auto k2_31 = Set(df64, 2147483648.0); + const auto v_is_ge_k2_31 = (v >= k2_31); + const auto clamped_lo31_f64 = v - IfThenElseZero(v_is_ge_k2_31, k2_31); + const auto clamped_lo31_u32 = + BitCast(du32, DemoteInRangeTo(di32, clamped_lo31_f64)); + const auto clamped_u32_msb = ShiftLeft<31>( + TruncateTo(du32, BitCast(du64, VecFromMask(df64, v_is_ge_k2_31)))); + return Or(clamped_lo31_u32, clamped_u32_msb); +} + +// F64 to U32 DemoteTo is generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD 
DemoteTo(D du32, VFromD> v) { + const Rebind df64; + const auto clamped = Min(ZeroIfNegative(v), Set(df64, 4294967295.0)); + return DemoteInRangeTo(du32, clamped); +} +#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepi64_ps(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepu64_ps(v.raw)}; +} +#else +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind df64; + const RebindToUnsigned du64; + const RebindToSigned di32; + const RebindToUnsigned du32; + + const auto k2p64_63 = Set(df64, 27670116110564327424.0); + const auto f64_hi52 = + Xor(BitCast(df64, ShiftRight<12>(BitCast(du64, v))), k2p64_63) - k2p64_63; + const auto f64_lo12 = + PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)), + Set(du32, uint32_t{0x00000FFF})))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + const auto f64_bits_decrement = + And(ShiftRight<63>(BitCast(du64, Xor(f64_sum, f64_carry))), + f64_sum_is_inexact); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - f64_bits_decrement, f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +} + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind df64; + const RebindToUnsigned du64; + const RebindToSigned di32; + const RebindToUnsigned du32; + + const auto k2p64 = Set(df64, 18446744073709551616.0); + const auto f64_hi52 = Or(BitCast(df64, ShiftRight<12>(v)), k2p64) - k2p64; + const auto f64_lo12 = + PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)), + Set(du32, uint32_t{0x00000FFF})))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - ShiftRight<63>(BitCast(du64, f64_carry)), + f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +} +#endif + +// For already range-limited input [0, 255]. +template +HWY_API Vec128 U8FromU32(const Vec128 v) { +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned> di32; + const Rebind du8; + return DemoteTo(du8, BitCast(di32, v)); +#else + const DFromV d32; + const Repartition d8; + alignas(16) static constexpr uint32_t k8From32[4] = { + 0x0C080400u, 0x0C080400u, 0x0C080400u, 0x0C080400u}; + // Also replicate bytes into all 32 bit lanes for safety. 
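// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// The SSE2..AVX2 F64 -> U32 DemoteInRangeTo above only has a float -> int32
// conversion, so inputs >= 2^31 are reduced by 2^31 before converting and the
// top bit is OR'ed back in afterwards. Scalar model (assumes v is already in
// [0, 2^32), as the "InRange" name implies):
#include <cstdint>
static inline uint32_t ScalarDemoteF64ToU32InRange(double v) {
  const bool ge_2_31 = (v >= 2147483648.0);
  const double reduced = ge_2_31 ? v - 2147483648.0 : v;   // now < 2^31
  const uint32_t lo31 = static_cast<uint32_t>(static_cast<int32_t>(reduced));
  return lo31 | (static_cast<uint32_t>(ge_2_31) << 31);
}
// --- end of sketch ---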
+ const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + return LowerHalf(LowerHalf(BitCast(d8, quad))); +#endif +} + +// ------------------------------ F32->UI64 PromoteTo +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD PromoteInRangeTo(D /*di64*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttps2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm_cvttps_epi64(v.raw)}; +#endif +} + +// Generic for all vector lengths. +template +HWY_API VFromD PromoteTo(D di64, VFromD> v) { + const Rebind df32; + const RebindToFloat df64; + // We now avoid GCC UB in PromoteInRangeTo via assembly, see #2189 and + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115115. Previously we fixed up + // the result afterwards using three instructions. Now we instead check if + // v >= 2^63, and if so replace the output with 2^63-1, which is likely more + // efficient. Note that the previous representable f32 is less than 2^63 and + // thus fits in i64. 
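// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// Scalar model of the saturating F32 -> I64 promotion described above: the
// comparison constant 9.223372e18f rounds to exactly 2^63, so lanes >= 2^63
// are replaced with INT64_MAX, and the largest float below 2^63 still fits in
// int64_t, so the in-range conversion cannot overflow upward. (Negative
// overflow and NaN are left to the hardware's 0x80..00 behavior and are not
// modeled here.)
#include <cstdint>
static inline int64_t ScalarPromoteF32ToI64(float v) {
  if (v >= 9223372036854775808.0f) return INT64_MAX;  // 2^63
  return static_cast<int64_t>(v);  // truncates toward zero for in-range input
}
// --- end of sketch ---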
+ const MFromD overflow = RebindMask( + di64, PromoteMaskTo(df64, df32, Ge(v, Set(df32, 9.223372e18f)))); + return IfThenElse(overflow, Set(di64, LimitsMax()), + PromoteInRangeTo(di64, v)); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return PromoteInRangeTo(D(), ZeroIfNegative(v)); +} +template +HWY_API VFromD PromoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttps2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm_cvttps_epu64(v.raw)}; +#endif +} +#else // AVX2 or below + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD PromoteTo(D di64, VFromD> v) { + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + Min(SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{157}))), + BitCast(du32_as_du8, Set(du32, uint32_t{32})))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + const auto f32_to_i32_result = ConvertTo(di32, adj_v); + const auto lo64_or_mask = PromoteTo( + di64, + BitCast(du32, VecFromMask(di32, Eq(f32_to_i32_result, + Set(di32, LimitsMax()))))); + + return Or(PromoteTo(di64, BitCast(di32, f32_to_i32_result)) + << PromoteTo(di64, exponent_adj), + lo64_or_mask); +} + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD PromoteInRangeTo(D d64, VFromD> v) { + const Rebind>, decltype(d64)> d32; + const RebindToSigned di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{0xFFFFFF9Du})))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + const auto f32_to_i32_result = ConvertInRangeTo(di32, adj_v); + return PromoteTo(d64, BitCast(d32, f32_to_i32_result)) + << PromoteTo(d64, exponent_adj); +} + +namespace detail { + +template +HWY_INLINE VFromD PromoteF32ToU64OverflowMaskToU64( + DU64 du64, VFromD> i32_overflow_mask) { + const Rebind di32; + const Twice dt_i32; + + const auto vt_i32_overflow_mask = ResizeBitCast(dt_i32, i32_overflow_mask); + return BitCast(du64, + InterleaveLower(vt_i32_overflow_mask, vt_i32_overflow_mask)); +} + +template +HWY_INLINE VFromD PromoteF32ToU64OverflowMaskToU64( + DU64 du64, VFromD> i32_overflow_mask) { + const RebindToSigned di64; + return BitCast(du64, PromoteTo(di64, i32_overflow_mask)); +} + +} // namespace detail + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD PromoteTo(D du64, VFromD> v) { + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto 
non_neg_v = ZeroIfNegative(v); + + const auto exponent_adj = BitCast( + du32, Min(SaturatedSub(BitCast(du32_as_du8, + ShiftRight<23>(BitCast(du32, non_neg_v))), + BitCast(du32_as_du8, Set(du32, uint32_t{157}))), + BitCast(du32_as_du8, Set(du32, uint32_t{33})))); + + const auto adj_v = + BitCast(df32, BitCast(du32, non_neg_v) - ShiftLeft<23>(exponent_adj)); + const auto f32_to_i32_result = ConvertInRangeTo(di32, adj_v); + + const auto i32_overflow_mask = BroadcastSignBit(f32_to_i32_result); + const auto overflow_result = + detail::PromoteF32ToU64OverflowMaskToU64(du64, i32_overflow_mask); + + return Or(PromoteTo(du64, BitCast(du32, f32_to_i32_result)) + << PromoteTo(du64, exponent_adj), + overflow_result); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ MulFixedPoint15 + +#if HWY_TARGET == HWY_SSE2 +HWY_API Vec128 MulFixedPoint15(const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition di32; + + auto lo_product = a * b; + auto hi_product = MulHigh(a, b); + + const VFromD i32_product_lo{ + _mm_unpacklo_epi16(lo_product.raw, hi_product.raw)}; + const VFromD i32_product_hi{ + _mm_unpackhi_epi16(lo_product.raw, hi_product.raw)}; + + const auto round_up_incr = Set(di32, 0x4000); + return ReorderDemote2To(d, ShiftRight<15>(i32_product_lo + round_up_incr), + ShiftRight<15>(i32_product_hi + round_up_incr)); +} + +template +HWY_API Vec128 MulFixedPoint15(const Vec128 a, + const Vec128 b) { + const DFromV d; + const Rebind di32; + + const auto lo_product = a * b; + const auto hi_product = MulHigh(a, b); + const VFromD i32_product{ + _mm_unpacklo_epi16(lo_product.raw, hi_product.raw)}; + + return DemoteTo(d, ShiftRight<15>(i32_product + Set(di32, 0x4000))); +} +#else +template +HWY_API Vec128 MulFixedPoint15(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhrs_epi16(a.raw, b.raw)}; +} +#endif + +// ------------------------------ Truncations + +template +HWY_API VFromD TruncateTo(DTo /* tag */, Vec128 v) { + // BitCast requires the same size; DTo might be u8x1 and v u16x1. 
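// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// Scalar reference for MulFixedPoint15 in the section above (the operation
// _mm_mulhrs_epi16 performs): a rounded Q15 fixed-point product,
// (a*b + 2^14) >> 15. The only overflowing corner case is a == b == -32768.
#include <cstdint>
static inline int16_t ScalarMulFixedPoint15(int16_t a, int16_t b) {
  const int32_t prod = static_cast<int32_t>(a) * b;
  return static_cast<int16_t>((prod + 0x4000) >> 15);
}
// --- end of sketch ---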
+ const Repartition, DFromV> dto; + return VFromD{BitCast(dto, v).raw}; +} + +template +HWY_API VFromD TruncateTo(D d, Vec128 v) { +#if HWY_TARGET == HWY_SSE2 + const Vec128 lo{v.raw}; + const Vec128 hi{_mm_unpackhi_epi64(v.raw, v.raw)}; + return Combine(d, hi, lo); +#else + const Repartition> d8; + (void)d; + alignas(16) static constexpr uint8_t kIdx[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + const Vec128 v8 = TableLookupBytes(v, Load(d8, kIdx)); + return LowerHalf(LowerHalf(LowerHalf(v8))); +#endif +} + +template +HWY_API VFromD TruncateTo(D d, Vec128 v) { +#if HWY_TARGET == HWY_SSE2 + const Vec128 lo{v.raw}; + const Vec128 hi{_mm_unpackhi_epi64(v.raw, v.raw)}; + return Combine(d, hi, lo); +#else + (void)d; + const Repartition> d16; + alignas(16) static constexpr uint16_t kIdx[8] = { + 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u}; + const Vec128 v16 = TableLookupBytes(v, Load(d16, kIdx)); + return LowerHalf(LowerHalf(v16)); +#endif +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + return VFromD{_mm_shuffle_epi32(v.raw, 0x88)}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const DFromV du32; +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned di32; + const Rebind du8; + return DemoteTo(du8, BitCast(di32, ShiftRight<24>(ShiftLeft<24>(v)))); +#else + const Repartition d; + alignas(16) static constexpr uint8_t kIdx[16] = { + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu, + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu}; + return LowerHalf(LowerHalf(TableLookupBytes(v, Load(d, kIdx)))); +#endif +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const DFromV du32; +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned di32; + const Rebind du16; + const RebindToSigned di16; + return BitCast( + du16, DemoteTo(di16, ShiftRight<16>(BitCast(di32, ShiftLeft<16>(v))))); +#else + const Repartition d; + return LowerHalf(ConcatEven(d, BitCast(d, v), BitCast(d, v))); +#endif +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const DFromV du16; +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned di16; + const Rebind du8; + const RebindToSigned di8; + return BitCast(du8, + DemoteTo(di8, ShiftRight<8>(BitCast(di16, ShiftLeft<8>(v))))); +#else + const Repartition d; + return LowerHalf(ConcatEven(d, BitCast(d, v), BitCast(d, v))); +#endif +} + +// ------------------------------ Demotions to/from i64 + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtsepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtsepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtsepi64_epi8(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtusepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* 
tag */, VFromD> v) { + return VFromD{_mm_cvtusepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtusepi64_epi8(v.raw)}; +} +#else // AVX2 or below + +// Disable the default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h for U64->I8/I16/I32 demotions on +// SSE2/SSSE3/SSE4/AVX2 as U64->I8/I16/I32 DemoteTo/ReorderDemote2To for +// SSE2/SSSE3/SSE4/AVX2 is implemented in x86_128-inl.h + +// The default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h are still used for U32->I8/I16 and +// U16->I8 demotions on SSE2/SSSE3/SSE4/AVX2 + +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) HWY_IF_NOT_T_SIZE_V(V, 8) + +namespace detail { +template +HWY_INLINE VFromD> DemoteFromU64MaskOutResult( + D /*dn*/, VFromD> v) { + return v; +} + +template +HWY_INLINE VFromD> DemoteFromU64MaskOutResult( + D /*dn*/, VFromD> v) { + const DFromV du64; + return And(v, + Set(du64, static_cast(hwy::HighestValue>()))); +} + +template +HWY_INLINE VFromD> DemoteFromU64Saturate( + D dn, VFromD> v) { + const Rebind du64; + const RebindToSigned di64; + constexpr int kShiftAmt = static_cast(sizeof(TFromD) * 8) - + static_cast(hwy::IsSigned>()); + + const auto too_big = BitCast( + du64, VecFromMask( + di64, Gt(BitCast(di64, ShiftRight(v)), Zero(di64)))); + return DemoteFromU64MaskOutResult(dn, Or(v, too_big)); +} + +template +HWY_INLINE VFromD ReorderDemote2From64To32Combine(D dn, V a, V b) { + return ConcatEven(dn, BitCast(dn, b), BitCast(dn, a)); +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const DFromV di64; + const RebindToUnsigned du64; + const RebindToUnsigned dn_u; + + // Negative values are saturated by first saturating their bitwise inverse + // and then inverting the saturation result + const auto invert_mask = BitCast(du64, BroadcastSignBit(v)); + const auto saturated_vals = Xor( + invert_mask, + detail::DemoteFromU64Saturate(dn, Xor(invert_mask, BitCast(du64, v)))); + return BitCast(dn, TruncateTo(dn_u, saturated_vals)); +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const DFromV di64; + const RebindToUnsigned du64; + + const auto non_neg_vals = BitCast(du64, AndNot(BroadcastSignBit(v), v)); + return TruncateTo(dn, detail::DemoteFromU64Saturate(dn, non_neg_vals)); +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const RebindToUnsigned dn_u; + return BitCast(dn, TruncateTo(dn_u, detail::DemoteFromU64Saturate(dn, v))); +} + +#if HWY_TARGET == HWY_SSE2 +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const Rebind di32; + return DemoteTo(dn, DemoteTo(di32, v)); +} +#endif // HWY_TARGET == HWY_SSE2 + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + return TruncateTo(dn, detail::DemoteFromU64Saturate(dn, v)); +} +#endif // HWY_TARGET <= HWY_AVX3 + +template )> +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +#if HWY_TARGET > HWY_AVX3 +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +#endif + +#if HWY_TARGET > HWY_AVX2 +template +HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV di64; + 
const RebindToUnsigned du64; + const Half dnh; + + // Negative values are saturated by first saturating their bitwise inverse + // and then inverting the saturation result + const auto invert_mask_a = BitCast(du64, BroadcastSignBit(a)); + const auto invert_mask_b = BitCast(du64, BroadcastSignBit(b)); + const auto saturated_a = Xor( + invert_mask_a, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_a, BitCast(du64, a)))); + const auto saturated_b = Xor( + invert_mask_b, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_b, BitCast(du64, b)))); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV di64; + const RebindToUnsigned du64; + const Half dnh; + + const auto saturated_a = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(a), a))); + const auto saturated_b = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(b), b))); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const Half dnh; + + const auto saturated_a = detail::DemoteFromU64Saturate(dnh, a); + const auto saturated_b = detail::DemoteFromU64Saturate(dnh, b); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} +#endif // HWY_TARGET > HWY_AVX2 + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepu16_ph(v.raw)}; +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepi16_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepi32_ps(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD ConvertTo(D /*df*/, VFromD> v) { + return VFromD{_mm_cvtepu32_ps(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*dd*/, VFromD> v) { + return VFromD{_mm_cvtepi64_pd(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*dd*/, VFromD> v) { + return VFromD{_mm_cvtepu64_pd(v.raw)}; +} +#else // AVX2 or below +// Generic for all vector lengths. +template +HWY_API VFromD ConvertTo(D df, VFromD> v) { + // Based on wim's approach (https://stackoverflow.com/questions/34066228/) + const RebindToUnsigned du32; + const RebindToSigned d32; + + const auto msk_lo = Set(du32, 0xFFFF); + const auto cnst2_16_flt = Set(df, 65536.0f); // 2^16 + + // Extract the 16 lowest/highest significant bits of v and cast to signed int + const auto v_lo = BitCast(d32, And(v, msk_lo)); + const auto v_hi = BitCast(d32, ShiftRight<16>(v)); + return MulAdd(cnst2_16_flt, ConvertTo(df, v_hi), ConvertTo(df, v_lo)); +} + +// Generic for all vector lengths. 
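// --- Illustrative scalar sketch (editorial example, not part of the patch) ---
// The SSE2..AVX2 U32 -> F32 ConvertTo above only has a signed i32 -> f32
// conversion, so it splits v into 16-bit halves (both non-negative as int32)
// and recombines them as hi * 2^16 + lo (MulAdd in the vector code). Scalar
// model of that arithmetic:
#include <cstdint>
static inline float ScalarU32ToF32(uint32_t v) {
  const int32_t lo = static_cast<int32_t>(v & 0xFFFFu);
  const int32_t hi = static_cast<int32_t>(v >> 16);
  return 65536.0f * static_cast<float>(hi) + static_cast<float>(lo);
}
// --- end of sketch ---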
+template +HWY_API VFromD ConvertTo(D dd, VFromD> v) { + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const Repartition d32; + const Repartition d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! +} + +namespace detail { +template +HWY_INLINE VFromD>> U64ToF64VecFast(VW w) { + const DFromV d64; + const RebindToFloat dd; + const auto cnst2_52_dbl = Set(dd, 0x0010000000000000); // 2^52 + return BitCast(dd, Or(w, BitCast(d64, cnst2_52_dbl))) - cnst2_52_dbl; +} +} // namespace detail + +// Generic for all vector lengths. +template +HWY_API VFromD ConvertTo(D dd, VFromD> v) { + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned d64; + using VU = VFromD; + + const VU msk_lo = Set(d64, 0xFFFFFFFF); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest/highest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + const auto v_lo_dbl = detail::U64ToF64VecFast(v_lo); + return MulAdd(cnst2_32_dbl, detail::U64ToF64VecFast(v_hi), v_lo_dbl); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// Truncates (rounds toward zero). + +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertInRangeTo(D /*di*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttph_epi16 if any values of v[i] + // are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7])); + } +#endif + + __m128i raw_result; + __asm__("vcvttph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttph_epi16(v.raw)}; +#endif +} + +// F16 to I16 ConvertTo is generic for all vector lengths +template +HWY_API VFromD ConvertTo(D di, VFromD> v) { + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". 
+ const MFromD overflow = + RebindMask(di, Ge(v, Set(df, ConvertScalarTo(32768.0f)))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); +} + +template +HWY_API VFromD ConvertInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttph_epu16 if any values of v[i] + // are not within the range of an uint16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7])); + } +#endif + + __m128i raw_result; + __asm__("vcvttph2uw {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttph_epu16(v.raw)}; +#endif +} + +// F16->U16 ConvertTo is generic for all vector lengths +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return ConvertInRangeTo(D(), ZeroIfNegative(v)); +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD ConvertInRangeTo(D /*di*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("%vcvttps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttps_epi32(v.raw)}; +#endif +} + +// F32 to I32 ConvertTo is generic for all vector lengths +template +HWY_API VFromD ConvertTo(D di, VFromD> v) { + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". 
+ const MFromD overflow = RebindMask(di, Ge(v, Set(df, 2147483648.0f))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DI(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttpd_epi64(v.raw)}; +#endif +} + +// F64 to I64 ConvertTo is generic for all vector lengths on AVX3 +template +HWY_API VFromD ConvertTo(DI di, VFromD> v) { + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". + const MFromD overflow = + RebindMask(di, Ge(v, Set(df, 9.223372036854776e18))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); +} + +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttps_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DU(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvttps2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttps_epu32(v.raw)}; +#endif +} + +// F32->U32 ConvertTo is generic for all vector lengths +template +HWY_API VFromD ConvertTo(DU /*du*/, VFromD> v) { + return ConvertInRangeTo(DU(), ZeroIfNegative(v)); +} + +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epu64 with GCC if any + // values of v[i] are not within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DU(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttpd_epu64(v.raw)}; +#endif +} + +// F64->U64 ConvertTo is 
generic for all vector lengths +template +HWY_API VFromD ConvertTo(DU /*du*/, VFromD> v) { + return ConvertInRangeTo(DU(), ZeroIfNegative(v)); +} + +#else // AVX2 or below + +namespace detail { + +template +static HWY_INLINE VFromD ConvInRangeF32ToU32( + DU32 du32, VFromD> v, VFromD& exp_diff) { + const RebindToSigned di32; + const RebindToFloat df32; + + exp_diff = Set(du32, uint32_t{158}) - ShiftRight<23>(BitCast(du32, v)); + const auto scale_down_f32_val_mask = + VecFromMask(du32, Eq(exp_diff, Zero(du32))); + + const auto v_scaled = + BitCast(df32, BitCast(du32, v) + ShiftLeft<23>(scale_down_f32_val_mask)); + const auto f32_to_u32_result = + BitCast(du32, ConvertInRangeTo(di32, v_scaled)); + + return f32_to_u32_result + And(f32_to_u32_result, scale_down_f32_val_mask); +} + +} // namespace detail + +// F32 to U32 ConvertInRangeTo is generic for all vector lengths on +// SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertInRangeTo(DU32 du32, + VFromD> v) { + VFromD exp_diff; + const auto f32_to_u32_result = detail::ConvInRangeF32ToU32(du32, v, exp_diff); + return f32_to_u32_result; +} + +// F32 to U32 ConvertTo is generic for all vector lengths on +// SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertTo(DU32 du32, VFromD> v) { + const RebindToSigned di32; + + const auto non_neg_v = ZeroIfNegative(v); + VFromD exp_diff; + const auto f32_to_u32_result = + detail::ConvInRangeF32ToU32(du32, non_neg_v, exp_diff); + + return Or(f32_to_u32_result, + BitCast(du32, BroadcastSignBit(BitCast(di32, exp_diff)))); +} + +namespace detail { + +template +HWY_API VFromD ConvAbsInRangeF64ToUI64(D64 d64, + VFromD> v, + VFromD& biased_exp) { + const RebindToSigned di64; + const RebindToUnsigned du64; + using VU64 = VFromD; + const Repartition du16; + const VU64 k1075 = Set(du64, 1075); /* biased exponent of 2^52 */ + + // Exponent indicates whether the number can be represented as int64_t. + biased_exp = BitCast(d64, ShiftRight<52>(BitCast(du64, v))); + HWY_IF_CONSTEXPR(IsSigned>()) { + biased_exp = And(biased_exp, Set(d64, TFromD{0x7FF})); + } + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + + // Use 16-bit saturated unsigned subtraction to compute shift_mnt and + // shift_int since biased_exp[i] is a non-negative integer that is less than + // or equal to 2047. + + // 16-bit saturated unsigned subtraction is also more efficient than a + // 64-bit subtraction followed by a 64-bit signed Max operation on + // SSE2/SSSE3/SSE4/AVX2. + + // The upper 48 bits of both shift_mnt and shift_int are guaranteed to be + // zero as the upper 48 bits of both k1075 and biased_exp are zero. + + const VU64 shift_mnt = BitCast( + du64, SaturatedSub(BitCast(du16, k1075), BitCast(du16, biased_exp))); + const VU64 shift_int = BitCast( + du64, SaturatedSub(BitCast(du16, biased_exp), BitCast(du16, k1075))); + const VU64 mantissa = BitCast(du64, v) & Set(du64, (1ULL << 52) - 1); + // Include implicit 1-bit. NOTE: the shift count may exceed 63; we rely on x86 + // returning zero in that case. + const VU64 int53 = (mantissa | Set(du64, 1ULL << 52)) >> shift_mnt; + + // For inputs larger than 2^53 - 1, insert zeros at the bottom. 
+ + // For inputs less than 2^64, the implicit 1-bit is guaranteed not to be + // shifted out of the left shift result below as shift_int[i] <= 11 is true + // for any inputs that are less than 2^64. + + return BitCast(d64, int53 << shift_int); +} + +} // namespace detail + +#if HWY_ARCH_X86_64 + +namespace detail { + +template +static HWY_INLINE int64_t SSE2ConvFirstF64LaneToI64(Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttsd_si64 with GCC if v[0] is + // not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (IsConstantX86Vec(hwy::SizeTag<1>(), v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return X86ConvertScalarFromFloat(raw_v[0]); + } +#endif + + int64_t result; + __asm__("%vcvttsd2si {%1, %0|%0, %1}" + : "=r"(result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return result; +#else + return _mm_cvttsd_si64(v.raw); +#endif +} + +} // namespace detail + +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, Vec64 v) { + return VFromD{_mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToI64(v))}; +} +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, Vec128 v) { + const __m128i i0 = _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToI64(v)); + const Full64 dd2; + const __m128i i1 = + _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToI64(UpperHalf(dd2, v))); + return VFromD{_mm_unpacklo_epi64(i0, i1)}; +} + +template +HWY_API VFromD ConvertTo(DI di, VFromD> v) { + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". + const MFromD overflow = + RebindMask(di, Ge(v, Set(df, 9.223372036854776e18))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); +} +#endif // HWY_ARCH_X86_64 + +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 +template +HWY_API VFromD ConvertInRangeTo(DI di, VFromD> v) { + using VI = VFromD; + + VI biased_exp; + const VI shifted = detail::ConvAbsInRangeF64ToUI64(di, v, biased_exp); + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + + // If the input was negative, negate the integer (two's complement). + return (shifted ^ sign_mask) - sign_mask; +} + +template +HWY_API VFromD ConvertTo(DI di, VFromD> v) { + using VI = VFromD; + + VI biased_exp; + const VI shifted = detail::ConvAbsInRangeF64ToUI64(di, v, biased_exp); + +#if HWY_TARGET <= HWY_SSE4 + const auto in_range = biased_exp < Set(di, 1086); +#else + const Repartition di32; + const auto in_range = MaskFromVec(BitCast( + di, + VecFromMask(di32, DupEven(BitCast(di32, biased_exp)) < Set(di32, 1086)))); +#endif + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax()) - sign_mask; + const VI magnitude = IfThenElse(in_range, shifted, limit); + + // If the input was negative, negate the integer (two's complement). 
+ return (magnitude ^ sign_mask) - sign_mask; +} +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertInRangeTo(DU du, VFromD> v) { + VFromD biased_exp; + const auto shifted = detail::ConvAbsInRangeF64ToUI64(du, v, biased_exp); + return shifted; +} + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertTo(DU du, VFromD> v) { + const RebindToSigned di; + using VU = VFromD; + + VU biased_exp; + const VU shifted = + detail::ConvAbsInRangeF64ToUI64(du, ZeroIfNegative(v), biased_exp); + + // Exponent indicates whether the number can be represented as uint64_t. +#if HWY_TARGET <= HWY_SSE4 + const VU out_of_range = + BitCast(du, VecFromMask(di, BitCast(di, biased_exp) > Set(di, 1086))); +#else + const Repartition di32; + const VU out_of_range = BitCast( + du, + VecFromMask(di32, DupEven(BitCast(di32, biased_exp)) > Set(di32, 1086))); +#endif + + return (shifted | out_of_range); +} +#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD +namespace detail { + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CXX14_CONSTEXPR TTo +X86ScalarNearestInt(TF flt_val) { +#if HWY_HAVE_SCALAR_F16_TYPE && HWY_HAVE_SCALAR_F16_OPERATORS + using TFArith = If, hwy::bfloat16_t>(), float, + RemoveCvRef>; +#else + using TFArith = If>; +#endif + + const TTo trunc_int_val = X86ConvertScalarFromFloat(flt_val); + const TFArith abs_val_diff = ScalarAbs( + ConvertScalarTo(ConvertScalarTo(flt_val) - + ConvertScalarTo(trunc_int_val))); + constexpr TFArith kHalf = ConvertScalarTo(0.5); + + const bool round_result_up = + ((trunc_int_val ^ ScalarShr(trunc_int_val, sizeof(TTo) * 8 - 1)) != + LimitsMax()) && + (abs_val_diff > kHalf || + (abs_val_diff == kHalf && (trunc_int_val & 1) != 0)); + return static_cast( + trunc_int_val + + (round_result_up ? (ScalarSignBit(flt_val) ? (-1) : 1) : 0)); +} + +} // namespace detail +#endif // HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + +// If these are in namespace detail, the x86_256/512 templates are not found. 
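// X86ScalarNearestInt above reproduces round-to-nearest with ties-to-even for
// constant inputs; a minimal scalar sketch of that rounding rule (the name is
// illustrative, assumes <cstdint> and <cmath>, and omits the overflow guard
// the real helper applies near LimitsMax):
static inline int64_t RoundHalfToEvenScalar(double x) {
  const int64_t t = static_cast<int64_t>(x);  // truncate toward zero
  const double diff = std::fabs(x - static_cast<double>(t));
  const bool round_away =
      diff > 0.5 || (diff == 0.5 && (t & 1) != 0);  // ties go to the even value
  return round_away ? t + (x < 0.0 ? -1 : 1) : t;
}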
+template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtps_epi32 with GCC if any values + // of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("%vcvtps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtps_epi32(v.raw)}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtph_epi16 if any values of v[i] + // are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7])); + } +#endif + + __m128i raw_result; + __asm__("vcvtph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtph_epi16(v.raw)}; +#endif +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_TARGET <= HWY_AVX3 + +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvtpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtpd_epi64(v.raw)}; +#endif +} + +#else // HWY_TARGET > HWY_AVX3 + +namespace detail { + +#if HWY_ARCH_X86_64 +template +static HWY_INLINE int64_t +SSE2ConvFirstF64LaneToNearestI64(Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtsd_si64 with GCC if v[0] is + // not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (IsConstantX86Vec(hwy::SizeTag<1>(), v)) { + typedef double 
GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return X86ScalarNearestInt(raw_v[0]); + } +#endif + + int64_t result; + __asm__("%vcvtsd2si {%1, %0|%0, %1}" + : "=r"(result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return result; +#else + return _mm_cvtsd_si64(v.raw); +#endif +} +#endif // HWY_ARCH_X86_64 + +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 +template +static HWY_INLINE VFromD SSE2NearestI64InRange( + DI64 di64, VFromD> v) { + const RebindToFloat df64; + const RebindToUnsigned du64; + using VI64 = VFromD; + + const auto mant_end = Set(df64, MantissaEnd()); + const auto is_small = Lt(Abs(v), mant_end); + + const auto adj_v = Max(v, Set(df64, -9223372036854775808.0)) + + IfThenElseZero(is_small, CopySignToAbs(mant_end, v)); + const auto adj_v_biased_exp = + And(BitCast(di64, ShiftRight<52>(BitCast(du64, adj_v))), + Set(di64, int64_t{0x7FF})); + + // We can simply subtract 1075 from adj_v_biased_exp[i] to get shift_int since + // adj_v_biased_exp[i] is at least 1075 + const VI64 shift_int = adj_v_biased_exp + Set(di64, int64_t{-1075}); + + const VI64 mantissa = BitCast(di64, adj_v) & Set(di64, (1LL << 52) - 1); + // Include implicit 1-bit if is_small[i] is 0. NOTE: the shift count may + // exceed 63; we rely on x86 returning zero in that case. + const VI64 int53 = mantissa | IfThenZeroElse(RebindMask(di64, is_small), + Set(di64, 1LL << 52)); + + const VI64 sign_mask = BroadcastSignBit(BitCast(di64, v)); + // If the input was negative, negate the integer (two's complement). + return ((int53 << shift_int) ^ sign_mask) - sign_mask; +} +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +} // namespace detail + +#if HWY_ARCH_X86_64 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, Vec64 v) { + return VFromD{ + _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToNearestI64(v))}; +} +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, Vec128 v) { + const __m128i i0 = + _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToNearestI64(v)); + const Full64 dd2; + const __m128i i1 = _mm_cvtsi64_si128( + detail::SSE2ConvFirstF64LaneToNearestI64(UpperHalf(dd2, v))); + return VFromD{_mm_unpacklo_epi64(i0, i1)}; +} +#endif // HWY_ARCH_X86_64 + +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 +template +static HWY_INLINE VFromD NearestIntInRange(DI di, + VFromD> v) { + return detail::SSE2NearestI64InRange(di, v); +} +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +#endif // HWY_TARGET <= HWY_AVX3 + +template +static HWY_INLINE VFromD DemoteToNearestIntInRange( + DI, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtpd_epi32 with GCC if any values + // of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DI(), detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), int32_t{0}, int32_t{0}); + } +#endif + + __m128i raw_result; + __asm__("%vcvtpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtpd_epi32(v.raw)}; +#endif +} + +// F16/F32/F64 NearestInt is generic for all vector lengths +template , class DI = 
RebindToSigned, + HWY_IF_FLOAT_D(DF), + HWY_IF_T_SIZE_ONE_OF_D(DF, (1 << 4) | (1 << 8) | + (HWY_HAVE_FLOAT16 ? (1 << 2) : 0))> +HWY_API VFromD NearestInt(const VF v) { + const DI di; + using TI = TFromD; + using TF = TFromD; + using TFArith = If>; + + constexpr TFArith kMinOutOfRangePosVal = + static_cast(-static_cast(LimitsMin())); + static_assert(kMinOutOfRangePosVal > static_cast(0.0), + "kMinOutOfRangePosVal > 0.0 must be true"); + + // See comment at the first occurrence of "IfThenElse(overflow,". + // Here we are rounding, whereas previous occurrences truncate, but there is + // no difference because the previous float value is well below the max i32. + const auto overflow = RebindMask( + di, Ge(v, Set(DF(), ConvertScalarTo(kMinOutOfRangePosVal)))); + auto result = + IfThenElse(overflow, Set(di, LimitsMax()), NearestIntInRange(di, v)); + + return result; +} + +template +HWY_API VFromD DemoteToNearestInt(DI, VFromD> v) { + const DI di; + const Rebind df64; + return DemoteToNearestIntInRange(di, Min(v, Set(df64, 2147483647.0))); +} + +// ------------------------------ Floating-point rounding (ConvertTo) + +#if HWY_TARGET >= HWY_SSSE3 + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + // Rely on rounding after addition with a large value such that no mantissa + // bits remain (assuming the current mode is nearest-even). We may need a + // compiler flag for precise floating-point to prevent "optimizing" this out. + const DFromV df; + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +namespace detail { + +// Truncating to integer and converting back to float is correct except when the +// input magnitude is large, in which case the input was already an integer +// (because mantissa >> exponent is zero). +template +HWY_INLINE Mask128 UseInt(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const DFromV d; + return Abs(v) < Set(d, MantissaEnd()); +} + +} // namespace detail + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertInRangeTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertInRangeTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +#ifdef HWY_NATIVE_CEIL_FLOOR_INT +#undef HWY_NATIVE_CEIL_FLOOR_INT +#else +#define HWY_NATIVE_CEIL_FLOOR_INT +#endif + +template +HWY_API VFromD>> CeilInt(V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. 
+ return integer - + VecFromMask(di, RebindMask(di, And(detail::UseInt(v), int_f < v))); +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertInRangeTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +template +HWY_API VFromD>> FloorInt(V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + return integer + + VecFromMask(di, RebindMask(di, And(detail::UseInt(v), int_f > v))); +} + +#else + +// Toward nearest integer, ties to even +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_roundscale_ph(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_roundscale_ph(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_roundscale_ph(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_roundscale_ph(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +#endif // !HWY_SSSE3 + +// ------------------------------ Floating-point classification + +#define HWY_X86_FPCLASS_QNAN 0x01 +#define HWY_X86_FPCLASS_POS0 0x02 +#define HWY_X86_FPCLASS_NEG0 0x04 +#define HWY_X86_FPCLASS_POS_INF 0x08 +#define HWY_X86_FPCLASS_NEG_INF 0x10 +#define HWY_X86_FPCLASS_SUBNORMAL 0x20 +#define HWY_X86_FPCLASS_NEG 0x40 +#define HWY_X86_FPCLASS_SNAN 0x80 + +#if HWY_HAVE_FLOAT16 || HWY_IDE + +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return Mask128{ + _mm_fpclass_ph_mask(v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +} + 
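// The pre-AVX3 paths below implement IsNaN/IsEitherNaN via unordered
// comparisons (_mm_cmpunord_*); a minimal scalar sketch of the same idea,
// assuming IEEE-754 semantics (i.e. no -ffast-math; names are illustrative):
static inline bool IsNaNScalar(double x) {
  return x != x;  // only NaN compares unequal to itself
}
static inline bool IsEitherNaNScalar(double a, double b) {
  return (a != a) || (b != b);  // unordered if either operand is NaN
}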
+template +HWY_API Mask128 IsEitherNaN(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)}; + HWY_DIAGNOSTICS(pop) +} + +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_ph_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} + +template +HWY_API Mask128 IsFinite(const Vec128 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask128{_mm_fpclass_ph_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} + +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Mask128 IsNaN(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{ + _mm_fpclass_ps_mask(v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +#else + return Mask128{_mm_cmpunord_ps(v.raw, v.raw)}; +#endif +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{ + _mm_fpclass_pd_mask(v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +#else + return Mask128{_mm_cmpunord_pd(v.raw, v.raw)}; +#endif +} + +#ifdef HWY_NATIVE_IS_EITHER_NAN +#undef HWY_NATIVE_IS_EITHER_NAN +#else +#define HWY_NATIVE_IS_EITHER_NAN +#endif + +template +HWY_API Mask128 IsEitherNaN(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask128{_mm_cmpunord_ps(a.raw, b.raw)}; +#endif +} + +template +HWY_API Mask128 IsEitherNaN(Vec128 a, + Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask128{_mm_cmpunord_pd(a.raw, b.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 + +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask128{_mm_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} +template +HWY_API Mask128 IsFinite(const Vec128 v) { + return Not(Mask128{_mm_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET <= HWY_SSE4 + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. 
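// The AES ops defined below each wrap a single AES-NI round. A hedged sketch
// of how they compose into AES-128 block encryption (the expanded key
// schedule `rk` with 11 round keys and the function name are assumptions, not
// part of this patch):
static inline Vec128<uint8_t> Aes128EncryptBlockSketch(
    Vec128<uint8_t> block, const Vec128<uint8_t> rk[11]) {
  Vec128<uint8_t> state = Xor(block, rk[0]);  // initial AddRoundKey
  for (int i = 1; i <= 9; ++i) {
    state = AESRound(state, rk[i]);  // SubBytes+ShiftRows+MixColumns+AddRoundKey
  }
  return AESLastRound(state, rk[10]);  // final round omits MixColumns
}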
+#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenc_si128(state.raw, round_key.raw)}; +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenclast_si128(state.raw, round_key.raw)}; +} + +HWY_API Vec128 AESInvMixColumns(Vec128 state) { + return Vec128{_mm_aesimc_si128(state.raw)}; +} + +HWY_API Vec128 AESRoundInv(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesdec_si128(state.raw, round_key.raw)}; +} + +HWY_API Vec128 AESLastRoundInv(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesdeclast_si128(state.raw, round_key.raw)}; +} + +template +HWY_API Vec128 AESKeyGenAssist(Vec128 v) { + return Vec128{_mm_aeskeygenassist_si128(v.raw, kRcon)}; +} + +template +HWY_API Vec128 CLMulLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_clmulepi64_si128(a.raw, b.raw, 0x00)}; +} + +template +HWY_API Vec128 CLMulUpper(Vec128 a, + Vec128 b) { + return Vec128{_mm_clmulepi64_si128(a.raw, b.raw, 0x11)}; +} + +#endif // !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET <= HWY_SSE4 + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +template +HWY_INLINE MFromD LoadMaskBits128(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, kN=1. + const VFromD vbits{_mm_cvtsi32_si128(static_cast(mask_bits))}; + +#if HWY_TARGET == HWY_SSE2 + // {b0, b1, ...} ===> {b0, b0, b1, b1, ...} + __m128i unpacked_vbits = _mm_unpacklo_epi8(vbits.raw, vbits.raw); + // {b0, b0, b1, b1, ...} ==> {b0, b0, b0, b0, b1, b1, b1, b1, ...} + unpacked_vbits = _mm_unpacklo_epi16(unpacked_vbits, unpacked_vbits); + // {b0, b0, b0, b0, b1, b1, b1, b1, ...} ==> + // {b0, b0, b0, b0, b0, b0, b0, b0, b1, b1, b1, b1, b1, b1, b1, b1} + const VFromD rep8{ + _mm_unpacklo_epi32(unpacked_vbits, unpacked_vbits)}; +#else + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) static constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); +#endif + const VFromD bit = Dup128VecFromValues( + du, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return RebindMask(d, TestBit(rep8, bit)); +} + +template +HWY_INLINE MFromD LoadMaskBits128(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits128(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits128(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail +#endif // HWY_TARGET > HWY_AVX3 + +// `p` points to at least 8 readable bytes, not all of which need be valid. 
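// LoadMaskBits below unpacks packed mask bits into per-lane masks; the SSE
// helpers above do this by replicating the mask byte across lanes and testing
// lane i against the constant bit 1 << (i % 8). A scalar sketch of the same
// expansion (names are illustrative, assumes <cstdint> and <cstddef>):
static inline void ExpandMaskBitsScalar(uint64_t mask_bits, bool* lanes,
                                        size_t num_lanes) {
  for (size_t i = 0; i < num_lanes; ++i) {
    lanes[i] = ((mask_bits >> i) & 1) != 0;  // bit i governs lane i
  }
}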
+template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t kN = MaxLanes(d); +#if HWY_TARGET <= HWY_AVX3 + (void)d; + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (kN + 7) / 8; + CopyBytes(bits, &mask_bits); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + return MFromD::FromBits(mask_bits); +#else + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (kN + 7) / 8; + CopyBytes(bits, &mask_bits); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + return detail::LoadMaskBits128(d, mask_bits); +#endif +} + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + +#if HWY_TARGET <= HWY_AVX3 + return MFromD::FromBits(mask_bits); +#else + return detail::LoadMaskBits128(d, mask_bits); +#endif +} + +template +struct CompressIsPartition { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 supports native compress, but a table-based approach allows + // 'partitioning' (also moving mask=false lanes to the top), which helps + // vqsort. This is only feasible for eight or less lanes, i.e. sizeof(T) == 8 + // on AVX3. For simplicity, we only use tables for 64-bit lanes (not AVX3 + // u32x8 etc.). + enum { value = (sizeof(T) == 8) }; +#else + // generic_ops-inl does not guarantee IsPartition for 8-bit. + enum { value = (sizeof(T) != 1) }; +#endif +}; + +namespace detail { + +// Returns `mask_bits` (from movemask) with the upper bits cleared, if there +// are 8 or fewer valid bits. +template +constexpr uint64_t OnlyActive(D d, uint64_t mask_bits) { + return (d.MaxBytes() >= 16) ? mask_bits + : mask_bits & ((1ull << d.MaxLanes()) - 1); +} + +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ BitsFromMask (MFromD, OnlyActive) +// Generic for all vector lengths. +template +HWY_INLINE uint64_t BitsFromMask(D d, MFromD mask) { + return detail::OnlyActive(d, mask.raw); +} + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + CopyBytes(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (kN < 8) { + const int mask_bits = (1 << kN) - 1; + bits[0] = static_cast(bits[0] & mask_bits); + } + + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint64_t mask_bits = uint64_t{mask.raw} & ((1ull << kN) - 1); + return PopCount(mask_bits); +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint32_t mask_bits = uint32_t{mask.raw} & ((1u << kN) - 1); + return Num0BitsBelowLS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint32_t mask_bits = uint32_t{mask.raw} & ((1u << kN) - 1); + return mask_bits ? 
intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint32_t mask_bits = uint32_t{mask.raw} & ((1u << kN) - 1); + return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint32_t mask_bits = uint32_t{mask.raw} & ((1u << kN) - 1); + return mask_bits ? intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits)) + : -1; +} + +template +HWY_API bool AllFalse(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint64_t mask_bits = uint64_t{mask.raw} & ((1ull << kN) - 1); + return mask_bits == 0; +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint64_t mask_bits = uint64_t{mask.raw} & ((1ull << kN) - 1); + // Cannot use _kortestc because we may have less than 8 mask bits. + return mask_bits == (1ull << kN) - 1; +} + +// ------------------------------ Compress + +// 8-16 bit Compress, CompressStore defined in x86_512 because they use Vec512. + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return Vec128{_mm_maskz_compress_ps(mask.raw, v.raw)}; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + HWY_DASSERT(mask.raw < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const DFromV d; + const Repartition d8; + const auto index = Load(d8, u8_indices + 16 * mask.raw); + return BitCast(d, TableLookupBytes(BitCast(d8, v), index)); +} + +// ------------------------------ CompressNot (Compress) + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // See CompressIsPartition, PrintCompressNot64x2NibbleTables + alignas(16) static constexpr uint64_t packed_array[16] = { + 0x00000010, 0x00000001, 0x00000010, 0x00000010}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm_permutexvar_epi64 will ignore the upper bits. 
+ const DFromV d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(16) static constexpr uint64_t shifts[2] = {0, 4}; + const auto indices = Indices128{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressStore (defined in x86_512) + +// ------------------------------ CompressBlendedStore (defined in x86_avx3) + +// ------------------------------ CompressBitsStore (defined in x86_512) + +#else // AVX2 or below + +// ------------------------------ BitsFromMask + +namespace detail { + +constexpr HWY_INLINE uint64_t U64FromInt(int mask_bits) { + return static_cast(static_cast(mask_bits)); +} + +} // namespace detail + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const auto sign_bits = BitCast(d, VecFromMask(d, mask)).raw; + return detail::OnlyActive(d, + detail::U64FromInt(_mm_movemask_epi8(sign_bits))); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const auto sign_bits = _mm_packs_epi16(mask.raw, _mm_setzero_si128()); + return detail::OnlyActive(d, + detail::U64FromInt(_mm_movemask_epi8(sign_bits))); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToFloat df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return detail::OnlyActive(d, + detail::U64FromInt(_mm_movemask_ps(sign_bits.raw))); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToFloat df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return detail::OnlyActive(d, + detail::U64FromInt(_mm_movemask_pd(sign_bits.raw))); +} + +// ------------------------------ StoreMaskBits +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + constexpr size_t kNumBytes = (MaxLanes(d) + 7) / 8; + const uint64_t mask_bits = BitsFromMask(d, mask); + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +// ------------------------------ Mask testing + +template +HWY_API bool AllFalse(D d, MFromD mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return BitsFromMask(d, mask) == 0; +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + constexpr uint64_t kAllBits = (1ull << MaxLanes(d)) - 1; + return BitsFromMask(d, mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + return PopCount(BitsFromMask(d, mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + return Num0BitsBelowLS1Bit_Nonzero32( + static_cast(BitsFromMask(d, mask))); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + return 31 - Num0BitsAboveMS1Bit_Nonzero32( + static_cast(BitsFromMask(d, mask))); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits)) + : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +// Also works for N < 8 because the first 16 4-tuples only reference bytes 0-6. 
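// Once a mask is packed into bits via movemask (BitsFromMask above), the mask
// queries reduce to plain bit manipulation; a scalar sketch using portable
// loops in place of the PopCount / Num0BitsBelowLS1Bit_Nonzero32 helpers
// (names are illustrative, assumes <cstdint> and <cstddef>):
static inline size_t CountTrueBitsScalar(uint64_t mask_bits) {
  size_t n = 0;
  for (; mask_bits != 0; mask_bits &= (mask_bits - 1)) ++n;  // clear lowest set bit
  return n;
}
static inline intptr_t FindFirstTrueBitScalar(uint64_t mask_bits) {
  if (mask_bits == 0) return -1;  // matches FindFirstTrue's "no lane" result
  intptr_t i = 0;
  while (((mask_bits >> i) & 1) == 0) ++i;  // index of the lowest set bit
  return i;
}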
+template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Twice d8t; + const RebindToUnsigned du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) static constexpr uint8_t table[2048] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 
12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 
0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const VFromD byte_idx{Load(d8, table + mask_bits * 8).raw}; + const VFromD pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Twice d8t; + const RebindToUnsigned du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
+ alignas(16) static constexpr uint8_t table[2048] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 
14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 
2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const VFromD byte_idx{Load(d8, table + mask_bits * 8).raw}; + const VFromD pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[256] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
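+ // For example, mask_bits = 1 (only lane 0 set) selects the second row of the
+ // table below, {4,5,6,7, 8,9,10,11, 12,13,14,15, 0,1,2,3}: the three lanes
+ // whose mask bit is clear come first and the single set lane is moved last.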
+ alignas(16) static constexpr uint8_t u8_indices[256] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_API Vec128 CompressBits(Vec128 v, uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +template +HWY_API Vec128 CompressNotBits(Vec128 v, uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromNotBits128(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
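+ // DupEven broadcasts mask[0] and DupOdd broadcasts mask[1] into both lanes,
+ // so swap = AndNot(maskL, maskH) is all-ones exactly when mask = {0, 1}; in
+ // that one case Shuffle01 moves the single true lane 1 down to index 0.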
+ const DFromV d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 bytes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return detail::CompressBits(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressNot + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const DFromV d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::CompressBits(v, BitsFromMask(d, Not(mask))); + } + return detail::CompressNotBits(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::CompressBits(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, m); + HWY_DASSERT(mask_bits < (1ull << MaxLanes(d))); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, m); + HWY_DASSERT(mask_bits < (1ull << MaxLanes(d))); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + BlendedStore(compressed, FirstN(d, count), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + uint64_t mask_bits = 0; + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + CopyBytes(bits, &mask_bits); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). 
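+ // As in CompressStore above, the shuffled vector is then written in full via
+ // StoreU; only the first `count` lanes of the output are meaningful, the
+ // remaining stored lanes simply hold whatever the shuffle left there.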
+ const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + + detail::MaybeUnpoison(unaligned, count); + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Expand + +// Otherwise, use the generic_ops-inl.h fallback. +#if HWY_TARGET <= HWY_AVX3 || HWY_IDE + +// The native instructions for 8/16-bit actually require VBMI2 (HWY_AVX3_DL), +// but we still want to override generic_ops-inl's table-based implementation +// whenever we have the 32-bit expand provided by AVX3. +#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE // VBMI2 + +template +HWY_INLINE Vec128 NativeExpand(Vec128 v, + Mask128 mask) { + return Vec128{_mm_maskz_expand_epi8(mask.raw, v.raw)}; +} + +template +HWY_INLINE Vec128 NativeExpand(Vec128 v, + Mask128 mask) { + return Vec128{_mm_maskz_expand_epi16(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint8_t* HWY_RESTRICT unaligned) { + return VFromD{_mm_maskz_expandloadu_epi8(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint16_t* HWY_RESTRICT unaligned) { + return VFromD{_mm_maskz_expandloadu_epi16(mask.raw, unaligned)}; +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_INLINE Vec128 NativeExpand(Vec128 v, + Mask128 mask) { + return Vec128{_mm_maskz_expand_epi32(mask.raw, v.raw)}; +} + +template +HWY_INLINE Vec128 NativeExpand(Vec128 v, + Mask128 mask) { + return Vec128{_mm_maskz_expand_epi64(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint32_t* HWY_RESTRICT unaligned) { + return VFromD{_mm_maskz_expandloadu_epi32(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint64_t* HWY_RESTRICT unaligned) { + return VFromD{_mm_maskz_expandloadu_epi64(mask.raw, unaligned)}; +} + +} // namespace detail + +// Otherwise, 8/16-bit are implemented in x86_512 using PromoteTo. 
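+// Expand is the inverse of Compress. Rough example with 4 x u32 lanes:
+//   v = {1, 2, 3, 4}, mask = {true, false, true, false}
+//   Expand(v, mask) == {1, 0, 2, 0}
+// i.e. the i-th set mask position receives the i-th input lane, and all other
+// lanes are zero.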
+#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE // VBMI2 + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + const RebindToUnsigned du; + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + const RebindToUnsigned du; + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +} + +// ------------------------------ LoadExpand + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +#else + return Expand(LoadU(d, unaligned), mask); +#endif +} + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +#else + return Expand(LoadU(d, unaligned), mask); +#endif +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ Additional mask logical operations + +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template +static HWY_INLINE uint32_t AVX3Blsi(T x) { + using TU = MakeUnsigned; + const auto u32_val = static_cast(static_cast(x)); +#if HWY_COMPILER_CLANGCL + return static_cast(u32_val & (0u - u32_val)); +#else + return static_cast(_blsi_u32(u32_val)); +#endif +} +template +static HWY_INLINE uint64_t AVX3Blsi(T x) { + const auto u64_val = static_cast(x); +#if HWY_COMPILER_CLANGCL || HWY_ARCH_X86_32 + return static_cast(u64_val & (0ULL - u64_val)); +#else + return static_cast(_blsi_u64(u64_val)); +#endif +} + +template +static HWY_INLINE uint32_t AVX3Blsmsk(T x) { + using TU = MakeUnsigned; + const auto u32_val = static_cast(static_cast(x)); +#if HWY_COMPILER_CLANGCL + return static_cast(u32_val ^ (u32_val - 1u)); +#else + return static_cast(_blsmsk_u32(u32_val)); +#endif +} +template +static HWY_INLINE uint64_t AVX3Blsmsk(T x) { + const auto u64_val = static_cast(x); +#if HWY_COMPILER_CLANGCL || HWY_ARCH_X86_32 + return static_cast(u64_val ^ (u64_val - 1ULL)); +#else + return static_cast(_blsmsk_u64(u64_val)); +#endif +} + +} // namespace detail + +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + constexpr uint32_t kActiveElemMask = (uint32_t{1} << N) - 1; + return Mask128{static_cast::Raw>( + (0u - detail::AVX3Blsi(mask.raw)) & kActiveElemMask)}; +} +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + constexpr uint32_t kActiveElemMask = (uint32_t{1} << N) - 1; + return Mask128{static_cast::Raw>( + (detail::AVX3Blsi(mask.raw) - 1u) & kActiveElemMask)}; +} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + constexpr uint32_t kActiveElemMask = (uint32_t{1} << N) - 1; + return Mask128{static_cast::Raw>( + detail::AVX3Blsmsk(mask.raw) & kActiveElemMask)}; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return Mask128{ + static_cast::Raw>(detail::AVX3Blsi(mask.raw))}; +} +#else // AVX2 or 
below +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const FixedTag d; + const auto vmask = VecFromMask(d, mask); + return MaskFromVec(Or(vmask, InterleaveLower(vmask, vmask))); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Simd d; + const auto vmask = VecFromMask(d, mask); + const auto neg_vmask = + ResizeBitCast(d, Neg(ResizeBitCast(Full64(), vmask))); + return MaskFromVec(Or(vmask, neg_vmask)); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Full128 d; + const Repartition di64; + const Repartition df32; + const Repartition di32; + using VF = VFromD; + + auto vmask = BitCast(di64, VecFromMask(d, mask)); + vmask = Or(vmask, Neg(vmask)); + + // Copy the sign bit of the first int64_t lane to the second int64_t lane + const auto vmask2 = BroadcastSignBit( + BitCast(di32, VF{_mm_shuffle_ps(Zero(df32).raw, BitCast(df32, vmask).raw, + _MM_SHUFFLE(1, 1, 0, 0))})); + return MaskFromVec(BitCast(d, Or(vmask, BitCast(di64, vmask2)))); +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const FixedTag d; + const RebindToSigned di; + + const auto vmask = BitCast(di, VecFromMask(d, mask)); + const auto zero = Zero(di); + const auto vmask2 = VecFromMask(di, InterleaveLower(zero, vmask) == zero); + return MaskFromVec(BitCast(d, And(vmask, vmask2))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Simd d; + const RebindToSigned di; + + const auto vmask = ResizeBitCast(Full64(), VecFromMask(d, mask)); + const auto only_first_vmask = + BitCast(d, Neg(ResizeBitCast(di, And(vmask, Neg(vmask))))); + return MaskFromVec(only_first_vmask); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Full128 d; + const RebindToSigned di; + const Repartition di64; + + const auto zero = Zero(di64); + const auto vmask = BitCast(di64, VecFromMask(d, mask)); + const auto vmask2 = VecFromMask(di64, InterleaveLower(zero, vmask) == zero); + const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); + return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2)))); +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 /*mask*/) { + const FixedTag d; + const RebindToSigned di; + using TI = MakeSigned; + + return RebindMask(d, MaskFromVec(Set(di, TI(-1)))); +} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + const Simd d; + return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask)))); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reductions + +// Nothing fully native, generic_ops-inl defines SumOfLanes and ReduceSum. + +// We provide specializations of u8x8 and u8x16, so exclude those. +#undef HWY_IF_SUM_OF_LANES_D +#define HWY_IF_SUM_OF_LANES_D(D) \ + HWY_IF_LANES_GT_D(D, 1), \ + hwy::EnableIf, uint8_t>() || \ + (HWY_V_SIZE_D(D) != 8 && HWY_V_SIZE_D(D) != 16)>* = \ + nullptr + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, static_cast(GetLane(SumsOf8(v)) & 0xFF)); +} +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + const Repartition d64; + VFromD sums = SumsOf8(v); + sums = SumOfLanes(d64, sums); + return Broadcast<0>(BitCast(d, sums)); +} + +#if HWY_TARGET <= HWY_SSE4 +// We provide specializations of u8x8, u8x16, and u16x8, so exclude those. 
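+// The u16x8 path below uses _mm_minpos_epu16, which returns the horizontal
+// minimum in lane 0 (and its position in lane 1); Broadcast<0> then splats it
+// to all lanes. MaxOfLanes uses the identity max(v) = M - min(M - v) with
+// M = LimitsMax, and the u8 variants first promote to u16.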
+#undef HWY_IF_MINMAX_OF_LANES_D +#define HWY_IF_MINMAX_OF_LANES_D(D) \ + HWY_IF_LANES_GT_D(D, 1), \ + hwy::EnableIf<(!hwy::IsSame, uint8_t>() || \ + ((HWY_V_SIZE_D(D) < 8) || (HWY_V_SIZE_D(D) > 16))) && \ + (!hwy::IsSame, uint16_t>() || \ + (HWY_V_SIZE_D(D) != 16))>* = nullptr + +template +HWY_API Vec128 MinOfLanes(D /* tag */, Vec128 v) { + return Broadcast<0>(Vec128{_mm_minpos_epu16(v.raw)}); +} + +template +HWY_API Vec128 MaxOfLanes(D d, Vec128 v) { + const Vec128 max = Set(d, LimitsMax()); + return max - MinOfLanes(d, max - v); +} + +template +HWY_API Vec64 MinOfLanes(D d, Vec64 v) { + const Rebind d16; + return TruncateTo(d, MinOfLanes(d16, PromoteTo(d16, v))); +} +template +HWY_API Vec128 MinOfLanes(D d, Vec128 v) { + const Half dh; + Vec64 result = + Min(MinOfLanes(dh, UpperHalf(dh, v)), MinOfLanes(dh, LowerHalf(dh, v))); + return Combine(d, result, result); +} + +template +HWY_API Vec64 MaxOfLanes(D d, Vec64 v) { + const Vec64 m(Set(d, LimitsMax())); + return m - MinOfLanes(d, m - v); +} +template +HWY_API Vec128 MaxOfLanes(D d, Vec128 v) { + const Vec128 m(Set(d, LimitsMax())); + return m - MinOfLanes(d, m - v); +} + +#endif // HWY_TARGET <= HWY_SSE4 + +// ------------------------------ BitShuffle +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE +#endif + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_LE_V(V, 16), + HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Rebind du8; + + int32_t i32_bit_shuf_result = static_cast( + static_cast(_mm_bitshuffle_epi64_mask(v.raw, idx.raw))); + + return BitCast(d64, PromoteTo(du64, VFromD{_mm_cvtsi32_si128( + i32_bit_shuf_result)})); +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ MultiRotateRight + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_MULTIROTATERIGHT +#undef HWY_NATIVE_MULTIROTATERIGHT +#else +#define HWY_NATIVE_MULTIROTATERIGHT +#endif + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_LE_V(V, 16), + HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V MultiRotateRight(V v, VI idx) { + return V{_mm_multishift_epi64_epi8(idx.raw, v.raw)}; +} + +#endif + +// ------------------------------ Lt128 + +namespace detail { + +// Returns vector-mask for Lt128. Generic for all vector lengths. +template +HWY_INLINE VFromD Lt128Vec(const D d, VFromD a, VFromD b) { + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const auto eqHL = Eq(a, b); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + const VFromD ltLX = ShiftLeftLanes<1>(ltHL); + const VFromD vecHx = IfThenElse(eqHL, ltLX, ltHL); + return InterleaveUpper(d, vecHx, vecHx); +} + +// Returns vector-mask for Eq128. Generic for all vector lengths. 
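+// In these 128-bit comparisons, each 128-bit number occupies two u64 lanes:
+// the low half at the even lane index and the high half at the odd index.
+// The helpers return a vector-mask replicated into both lanes of the block;
+// e.g. Eq128Vec below is all-ones in a block iff both of its halves compare
+// equal (And of the per-lane equality with its Reverse2).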
+template +HWY_INLINE VFromD Eq128Vec(D d, VFromD a, VFromD b) { + const auto eqHL = VecFromMask(d, Eq(a, b)); + const auto eqLH = Reverse2(d, eqHL); + return And(eqHL, eqLH); +} + +template +HWY_INLINE VFromD Ne128Vec(D d, VFromD a, VFromD b) { + const auto neHL = VecFromMask(d, Ne(a, b)); + const auto neLH = Reverse2(d, neHL); + return Or(neHL, neLH); +} + +template +HWY_INLINE VFromD Lt128UpperVec(D d, VFromD a, VFromD b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + return InterleaveUpper(d, ltHL, ltHL); +} + +template +HWY_INLINE VFromD Eq128UpperVec(D d, VFromD a, VFromD b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return InterleaveUpper(d, eqHL, eqHL); +} + +template +HWY_INLINE VFromD Ne128UpperVec(D d, VFromD a, VFromD b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return InterleaveUpper(d, neHL, neHL); +} + +} // namespace detail + +template +HWY_API MFromD Lt128(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Lt128Vec(d, a, b)); +} + +template +HWY_API MFromD Eq128(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Eq128Vec(d, a, b)); +} + +template +HWY_API MFromD Ne128(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Ne128Vec(d, a, b)); +} + +template +HWY_API MFromD Lt128Upper(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Lt128UpperVec(d, a, b)); +} + +template +HWY_API MFromD Eq128Upper(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Eq128UpperVec(d, a, b)); +} + +template +HWY_API MFromD Ne128Upper(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Ne128UpperVec(d, a, b)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Avoids the extra MaskFromVec in Lt128. +template +HWY_API VFromD Min128(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +} + +template +HWY_API VFromD Max128(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +} + +template +HWY_API VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128UpperVec(d, a, b), a, b); +} + +template +HWY_API VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128UpperVec(d, b, a), a, b); +} + +// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +template ), HWY_IF_V_SIZE_LE_D(DFromV, 16)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm_lzcnt_epi32(v.raw)}; +} + +template ), HWY_IF_V_SIZE_LE_D(DFromV, 16)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm_lzcnt_epi64(v.raw)}; +} + +// HighestSetBitIndex and TrailingZeroCount is implemented in x86_512-inl.h +// for AVX3 targets + +#endif // HWY_TARGET <= HWY_AVX3 + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#undef HWY_X86_IF_EMULATED_D + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. 
+HWY_DIAGNOSTICS(pop) diff --git a/third_party/aom/third_party/highway/hwy/ops/x86_256-inl.h b/third_party/aom/third_party/highway/hwy/ops/x86_256-inl.h new file mode 100644 index 000000000000..32df08497e30 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/x86_256-inl.h @@ -0,0 +1,8983 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit vectors and AVX2 instructions, plus some AVX512-VL operations when +// compiling for that target. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL +#include "third_party/highway/hwy/base.h" + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, + ignored "-Wmaybe-uninitialized") +#endif + +// Must come before HWY_COMPILER_CLANGCL +#include // AVX2+ + +#if HWY_COMPILER_CLANGCL +// Including should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +#include +// avxintrin defines __m256i and must come before avx2intrin. +#include +#include // _pext_u64 +#include +#include +#include + +#if HWY_TARGET <= HWY_AVX10_2 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// Must come after avx512fintrin, else will not define 512-bit intrinsics. +#include +#include +#include +#include +#include + +#endif // HWY_TARGET <= HWY_AVX10_2 + +// clang-format on +#endif // HWY_COMPILER_CLANGCL + +// For half-width vectors. Already includes base.h. +#include "third_party/highway/hwy/ops/shared-inl.h" +// Already included by shared-inl, but do it again to avoid IDE warnings. +#include "third_party/highway/hwy/ops/x86_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template +struct Raw256 { + using type = __m256i; +}; +#if HWY_HAVE_FLOAT16 +template <> +struct Raw256 { + using type = __m256h; +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct Raw256 { + using type = __m256; +}; +template <> +struct Raw256 { + using type = __m256d; +}; + +} // namespace detail + +template +class Vec256 { + using Raw = typename detail::Raw256::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator%=(const Vec256 other) { + return *this = (*this % other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 + +// Template arg: sizeof(lane type) +template +struct RawMask256T {}; +template <> +struct RawMask256T<1> { + using type = __mmask32; +}; +template <> +struct RawMask256T<2> { + using type = __mmask16; +}; +template <> +struct RawMask256T<4> { + using type = __mmask8; +}; +template <> +struct RawMask256T<8> { + using type = __mmask8; +}; + +template +using RawMask256 = typename RawMask256T::type; + +#else // AVX2 or earlier + +template +using RawMask256 = typename Raw256::type; + +#endif // HWY_TARGET <= HWY_AVX3 + +} // namespace detail + +template +struct Mask256 { + using Raw = typename detail::RawMask256; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromM + +#if HWY_TARGET <= HWY_AVX3 + static Mask256 FromBits(uint64_t mask_bits) { + return Mask256{static_cast(mask_bits)}; + } +#else +// Lanes are either FF..FF or 0. +#endif // HWY_TARGET <= HWY_AVX3 + + Raw raw; +}; + +template +using Full256 = Simd; + + +// ------------------------------ Zero + +// Cannot use VFromD here because it is defined in terms of Zero. 
+template +HWY_API Vec256> Zero(D /* tag */) { + return Vec256>{_mm256_setzero_si256()}; +} +template +HWY_API Vec256 Zero(D /* tag */) { + return Vec256{_mm256_setzero_si256()}; +} +template +HWY_API Vec256 Zero(D /* tag */) { +#if HWY_HAVE_FLOAT16 + return Vec256{_mm256_setzero_ph()}; +#else + return Vec256{_mm256_setzero_si256()}; +#endif // HWY_HAVE_FLOAT16 +} +template +HWY_API Vec256 Zero(D /* tag */) { + return Vec256{_mm256_setzero_ps()}; +} +template +HWY_API Vec256 Zero(D /* tag */) { + return Vec256{_mm256_setzero_pd()}; +} + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; } +#if HWY_HAVE_FLOAT16 +HWY_INLINE __m256i BitCastToInteger(__m256h v) { + return _mm256_castph_si256(v); +} +#endif // HWY_HAVE_FLOAT16 +HWY_INLINE __m256i BitCastToInteger(__m256 v) { return _mm256_castps_si256(v); } +HWY_INLINE __m256i BitCastToInteger(__m256d v) { + return _mm256_castpd_si256(v); +} + +#if HWY_AVX3_HAVE_F32_TO_BF16C +HWY_INLINE __m256i BitCastToInteger(__m256bh v) { + // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to + // bit cast a __m256bh to a __m256i as there is currently no intrinsic + // available (as of GCC 13 and Clang 17) that can bit cast a __m256bh vector + // to a __m256i vector + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + // On GCC or Clang, use reinterpret_cast to bit cast a __m256bh to a __m256i + return reinterpret_cast<__m256i>(v); +#else + // On MSVC, use BitCastScalar to bit cast a __m256bh to a __m256i as MSVC does + // not allow reinterpret_cast, static_cast, or a C-style cast to be used to + // bit cast from one AVX vector type to a different AVX vector type + return BitCastScalar<__m256i>(v); +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_INLINE Vec256 BitCastToByte(Vec256 v) { + return Vec256{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger256 { + HWY_INLINE __m256i operator()(__m256i v) { return v; } +}; +#if HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256h operator()(__m256i v) { return _mm256_castsi256_ph(v); } +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256 operator()(__m256i v) { return _mm256_castsi256_ps(v); } +}; +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256d operator()(__m256i v) { return _mm256_castsi256_pd(v); } +}; + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, Vec256 v) { + return VFromD{BitCastFromInteger256>()(v.raw)}; +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, Vec256 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm256_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm256_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm256_set1_epi32(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm256_set1_epi64x(static_cast(t))}; // NOLINT +} +// bfloat16_t is handled by x86_128-inl.h. 
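+// Usage sketch of the ops above (names such as Full256<uint32_t> follow the
+// definitions earlier in this header):
+//   const Full256<uint32_t> d;
+//   const auto v = Set(d, 5u);                   // eight u32 lanes, all 5
+//   const auto z = Zero(d);                      // all-zero vector
+//   const auto f = BitCast(Full256<float>(), v); // same bits, viewed as f32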
+#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 Set(D /* tag */, float16_t t) { + return Vec256{_mm256_set1_ph(t)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec256 Set(D /* tag */, float t) { + return Vec256{_mm256_set1_ps(t)}; +} +template +HWY_API Vec256 Set(D /* tag */, double t) { + return Vec256{_mm256_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API VFromD Undefined(D /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return VFromD{_mm256_undefined_si256()}; +} +template +HWY_API Vec256 Undefined(D /* tag */) { + return Vec256{_mm256_undefined_si256()}; +} +template +HWY_API Vec256 Undefined(D /* tag */) { +#if HWY_HAVE_FLOAT16 + return Vec256{_mm256_undefined_ph()}; +#else + return Vec256{_mm256_undefined_si256()}; +#endif +} +template +HWY_API Vec256 Undefined(D /* tag */) { + return Vec256{_mm256_undefined_ps()}; +} +template +HWY_API Vec256 Undefined(D /* tag */) { + return Vec256{_mm256_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ ResizeBitCast + +// 32-byte vector to 32-byte vector (or 64-byte vector to 64-byte vector on +// AVX3) +template ))> +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, v); +} + +// 32-byte vector to 16-byte vector (or 64-byte vector to 32-byte vector on +// AVX3) +template )) / 2)> +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const DFromV d_from; + const Half dh_from; + return BitCast(d, LowerHalf(dh_from, v)); +} + +// 32-byte vector (or 64-byte vector on AVX3) to <= 8-byte vector +template +HWY_API VFromD ResizeBitCast(D /*d*/, FromV v) { + return VFromD{ResizeBitCast(Full128>(), v).raw}; +} + +// <= 16-byte vector to 32-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, Vec256{_mm256_castsi128_si256( + ResizeBitCast(Full128(), v).raw)}); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{_mm256_setr_epi8( + static_cast(t0), static_cast(t1), static_cast(t2), + static_cast(t3), static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), static_cast(t8), + static_cast(t9), static_cast(t10), static_cast(t11), + static_cast(t12), static_cast(t13), static_cast(t14), + static_cast(t15), static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), static_cast(t4), + static_cast(t5), static_cast(t6), static_cast(t7), + static_cast(t8), static_cast(t9), static_cast(t10), + static_cast(t11), static_cast(t12), static_cast(t13), + static_cast(t14), static_cast(t15))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{ + _mm256_setr_epi16(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7))}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + 
TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{_mm256_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2, + t3, t4, t5, t6, t7)}; +} +#endif + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{ + _mm256_setr_epi32(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{_mm256_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{ + _mm256_setr_epi64x(static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{_mm256_setr_pd(t0, t1, t0, t1)}; +} + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_and_si256(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec256 And(Vec256 a, Vec256 b) { + return Vec256{_mm256_and_ps(a.raw, b.raw)}; +} +HWY_API Vec256 And(Vec256 a, Vec256 b) { + return Vec256{_mm256_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_andnot_si256( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); +} +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + return Vec256{_mm256_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + return Vec256{_mm256_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_or_si256(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + return Vec256{_mm256_or_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + return Vec256{_mm256_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_xor_si256(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + return Vec256{_mm256_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + return Vec256{_mm256_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not +template +HWY_API Vec256 Not(const Vec256 v) { + const DFromV d; + using TU = MakeUnsigned; +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + const __m256i vu = BitCast(RebindToUnsigned(), v).raw; + return BitCast(d, Vec256{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(d, Vec256{_mm256_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Xor3 +template +HWY_API Vec256 Xor3(Vec256 x1, Vec256 x2, Vec256 x3) { +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, x1).raw, BitCast(du, x2).raw, 
BitCast(du, x3).raw, 0x96); + return BitCast(d, VU{ret}); +#else + return Xor(x1, Xor(x2, x3)); +#endif +} + +// ------------------------------ Or3 +template +HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm256_ternarylogic_epi64(BitCast(du, mask).raw, + BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<1> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<2> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<4> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<8> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 PopulationCount(Vec256 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ================================================== MASK + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. 
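+// The hwy::SizeTag<sizeof(T)> argument selects the matching
+// _mm256_mask_blend_epi{8,16,32,64} intrinsic below; the public IfThenElse
+// overload simply dispatches on the lane size.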
+template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_blend_epi8(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_blend_epi16(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_blend_epi32(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_blend_epi64(mask.raw, no.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 IfThenElse(Mask256 mask, + Vec256 yes, + Vec256 no) { + return Vec256{_mm256_mask_blend_ph(mask.raw, no.raw, yes.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_mask_blend_ps(mask.raw, no.raw, yes.raw)}; +} +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_mask_blend_pd(mask.raw, no.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return Vec256{_mm256_maskz_mov_ps(mask.raw, yes.raw)}; +} +HWY_API Vec256 IfThenElseZero(Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
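+ // _mm256_mask_sub_epi8(no, mask, no, no) computes no - no = 0 in lanes where
+ // the mask bit is set and passes `no` through unchanged elsewhere, which is
+ // exactly the required IfThenZeroElse behavior.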
+ return Vec256{_mm256_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return Vec256{_mm256_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return Vec256{_mm256_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +// ------------------------------ Mask logical + +namespace detail { + +template +HWY_INLINE Mask256 And(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 Or(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS 
+ return Mask256{_kor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask256{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} + +// UnmaskedNot returns ~m.raw without zeroing out any invalid bits +template +HWY_INLINE Mask256 UnmaskedNot(const Mask256 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask32>(_knot_mask32(m.raw))}; +#else + return Mask256{static_cast<__mmask32>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask256 UnmaskedNot(const Mask256 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask16>(_knot_mask16(m.raw))}; +#else + return Mask256{static_cast<__mmask16>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask256 UnmaskedNot(const Mask256 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask8>(_knot_mask8(m.raw))}; +#else + return Mask256{static_cast<__mmask8>(~m.raw)}; 
+#endif +} + +template +HWY_INLINE Mask256 Not(hwy::SizeTag<1> /*tag*/, const Mask256 m) { + // sizeof(T) == 1: simply return ~m as all 32 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask256 Not(hwy::SizeTag<2> /*tag*/, const Mask256 m) { + // sizeof(T) == 2: simply return ~m as all 16 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask256 Not(hwy::SizeTag<4> /*tag*/, const Mask256 m) { + // sizeof(T) == 4: simply return ~m as all 8 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask256 Not(hwy::SizeTag<8> /*tag*/, const Mask256 m) { + // sizeof(T) == 8: need to zero out the upper 4 bits of ~m as only the lower + // 4 bits of m are valid + + // Return (~m) & 0x0F + return AndNot(hwy::SizeTag<8>(), m, Mask256::FromBits(uint64_t{0x0F})); +} + +} // namespace detail + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Not(const Mask256 m) { + // Flip only the valid bits. + return detail::Not(hwy::SizeTag(), m); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask32 combined_mask = _mm512_kunpackw( + static_cast<__mmask32>(hi.raw), static_cast<__mmask32>(lo.raw)); +#else + const auto combined_mask = + ((static_cast(hi.raw) << 16) | (lo.raw & 0xFFFFu)); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask32(static_cast<__mmask32>(m.raw), 16); +#else + const auto shifted_mask = static_cast(m.raw) >> 16; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD SlideMask1Up(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftli_mask32(static_cast<__mmask32>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) << 1)}; +#endif +} + +template +HWY_API MFromD SlideMask1Down(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftri_mask32(static_cast<__mmask32>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) >> 1)}; +#endif +} + +#else // AVX2 + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{v.raw}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{v.raw}; +} + +// ------------------------------ IfThenElse + +// mask ? 
yes : no +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + return Vec256{_mm256_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + const DFromV d; + return yes & VecFromMask(d, mask); +} + +// mask ? 0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + const DFromV d; + return AndNot(VecFromMask(d, mask), no); +} + +template +HWY_API Vec256 ZeroIfNegative(Vec256 v) { + static_assert(IsSigned(), "Only for float"); + const DFromV d; + const auto zero = Zero(d); + // AVX2 IfThenElse only looks at the MSB for 32/64-bit lanes + return IfThenElse(MaskFromVec(v), zero, v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + const Full256 d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. 
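+//
+// Illustrative sketch (example values are assumptions, not upstream code): on
+// AVX3 targets a Mask256 is a predicate register, with bit i of the __mmask
+// corresponding to lane i. Given a Vec256<int32_t> v with lanes
+// {5, -3, 7, 0, -1, 2, -9, 4},
+//   const Full256<int32_t> d;
+//   const Mask256<int32_t> m = v < Zero(d);
+// yields m.raw == 0b01010010 (0x52): bits 1, 4 and 6 are set because those
+// lanes are negative, and all other bits are zero.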
+ +template +HWY_API MFromD RebindMask(DTo /*tag*/, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<1> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<2> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<4> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<8> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask256 operator==(Vec256 a, + Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask256 operator!=(Vec256 a, + Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +#if HWY_HAVE_FLOAT16 +HWY_API Mask256 operator>=(Vec256 a, + Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpge_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpge_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpge_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpge_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpge_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpge_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpge_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpge_epu64_mask(a.raw, b.raw)}; +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + const RebindToSigned> di; + return Mask256{MaskFromVec(BitCast(di, v)).raw}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi8(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi16(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi32(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi64(v.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_castsi256_ph(_mm256_movm_epi16(v.raw))}; +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_castsi256_ps(_mm256_movm_epi32(v.raw))}; +} + +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_castsi256_pd(_mm256_movm_epi64(v.raw))}; +} + +#else // AVX2 + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
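+//
+// Illustrative sketch (example values are assumptions): without AVX-512 mask
+// registers, a Mask256 here is an ordinary 256-bit vector whose true lanes are
+// all-ones. Given a Vec256<int32_t> v with lanes {5, -3, 7, 0, ...}, the mask
+// from v < Zero(d) holds {0x00000000, 0xFFFFFFFF, 0x00000000, 0x00000000, ...},
+// which is why the mask logic and IfThenElse earlier in this header can reuse
+// plain bitwise vector ops and _mm256_blendv_*.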
+ +template +HWY_API MFromD RebindMask(DTo d_to, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + const Full256 dfrom; + return MaskFromVec(BitCast(d_to, VecFromMask(dfrom, m))); +} + +template +HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpeq_epi8(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpeq_epi16(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpeq_epi32(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpeq_epi64(a.raw, b.raw)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Not(a == b); +} +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_NEQ_OQ)}; +} +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +// Pre-9.3 GCC immintrin.h uses char, which may be unsigned, causing cmpgt_epi8 +// to perform an unsigned comparison instead of the intended signed. Workaround +// is to cast to an explicitly signed type. 
See https://godbolt.org/z/PL7Ujy +#if HWY_COMPILER_GCC_ACTUAL != 0 && HWY_COMPILER_GCC_ACTUAL < 903 +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 1 +#else +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 0 +#endif + +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { +#if HWY_AVX2_GCC_CMPGT8_WORKAROUND + using i8x32 = signed char __attribute__((__vector_size__(32))); + return Mask256{static_cast<__m256i>(reinterpret_cast(a.raw) > + reinterpret_cast(b.raw))}; +#else + return Mask256{_mm256_cmpgt_epi8(a.raw, b.raw)}; +#endif +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi16(a.raw, b.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi32(a.raw, b.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi64(a.raw, b.raw)}; +} + +template +HWY_INLINE Mask256 Gt(hwy::UnsignedTag /*tag*/, Vec256 a, Vec256 b) { + const Full256 du; + const RebindToSigned di; + const Vec256 msb = Set(du, (LimitsMax() >> 1) + 1); + return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); +} + +HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GT_OQ)}; +} + +} // namespace detail + +template +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +namespace detail { + +template +HWY_INLINE Mask256 Ge(hwy::SignedTag tag, Vec256 a, Vec256 b) { + return Not(Gt(tag, b, a)); +} + +template +HWY_INLINE Mask256 Ge(hwy::UnsignedTag tag, Vec256 a, Vec256 b) { + return Not(Gt(tag, b, a)); +} + +HWY_INLINE Mask256 Ge(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_INLINE Mask256 Ge(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GE_OQ)}; +} + +} // namespace detail + +template +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return detail::Ge(hwy::TypeTag(), a, b); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { + return b > a; +} + +template +HWY_API Mask256 operator<=(const Vec256 a, const Vec256 b) { + return b >= a; +} + +// ------------------------------ Min (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_min_epu64(a.raw, b.raw)}; +#else + const Full256 du; + const Full256 di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +#endif +} + +// Signed +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return 
Vec256{_mm256_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Min(Vec256 a, Vec256 b) { + return Vec256{_mm256_min_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_max_epu64(a.raw, b.raw)}; +#else + const Full256 du; + const Full256 di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +#endif +} + +// Signed +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Max(Vec256 a, Vec256 b) { + return Vec256{_mm256_max_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ Iota + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm256_set_epi8( + static_cast(31), static_cast(30), static_cast(29), + static_cast(28), static_cast(27), static_cast(26), + static_cast(25), static_cast(24), static_cast(23), + static_cast(22), static_cast(21), static_cast(20), + static_cast(19), static_cast(18), static_cast(17), + static_cast(16), static_cast(15), static_cast(14), + static_cast(13), static_cast(12), static_cast(11), + static_cast(10), static_cast(9), static_cast(8), + static_cast(7), static_cast(6), static_cast(5), + static_cast(4), static_cast(3), static_cast(2), + static_cast(1), static_cast(0))}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm256_set_epi16( + int16_t{15}, int16_t{14}, int16_t{13}, int16_t{12}, int16_t{11}, + int16_t{10}, int16_t{9}, int16_t{8}, int16_t{7}, int16_t{6}, int16_t{5}, + int16_t{4}, int16_t{3}, int16_t{2}, int16_t{1}, int16_t{0})}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{ + _mm256_set_ph(float16_t{15}, float16_t{14}, float16_t{13}, float16_t{12}, + float16_t{11}, float16_t{10}, float16_t{9}, float16_t{8}, + float16_t{7}, float16_t{6}, float16_t{5}, float16_t{4}, + float16_t{3}, float16_t{2}, float16_t{1}, float16_t{0})}; +} +#endif // 
HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm256_set_epi32(int32_t{7}, int32_t{6}, int32_t{5}, + int32_t{4}, int32_t{3}, int32_t{2}, + int32_t{1}, int32_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{ + _mm256_set_epi64x(int64_t{3}, int64_t{2}, int64_t{1}, int64_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{ + _mm256_set_ps(7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm256_set_pd(3.0, 2.0, 1.0, 0.0)}; +} + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + return detail::Iota0(d) + Set(d, ConvertScalarTo>(first)); +} + +// ------------------------------ FirstN (Iota, Lt) + +template > +HWY_API M FirstN(const D d, size_t n) { + constexpr size_t kN = MaxLanes(d); + // For AVX3, this ensures `num` <= 255 as required by bzhi, which only looks + // at the lower 8 bits; for AVX2 and below, this ensures `num` fits in TI. + n = HWY_MIN(n, kN); + +#if HWY_TARGET <= HWY_AVX3 +#if HWY_ARCH_X86_64 + const uint64_t all = (1ull << kN) - 1; + return M::FromBits(_bzhi_u64(all, n)); +#else + const uint32_t all = static_cast((1ull << kN) - 1); + return M::FromBits(_bzhi_u32(all, static_cast(n))); +#endif // HWY_ARCH_X86_64 +#else + const RebindToSigned di; // Signed comparisons are cheaper. + using TI = TFromD; + return RebindMask(d, detail::Iota0(di) < Set(di, static_cast(n))); +#endif +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi32(a.raw, 
b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ AddSub + +HWY_API Vec256 AddSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_addsub_ps(a.raw, b.raw)}; +} +HWY_API Vec256 AddSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_addsub_pd(a.raw, b.raw)}; +} + +// ------------------------------ PairwiseAdd128/PairwiseSub128 + +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_hadd_epi16(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, + Neg(BitCast(di, VFromD{_mm256_hsub_epi16(a.raw, b.raw)}))); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_hadd_epi32(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, + Neg(BitCast(di, VFromD{_mm256_hsub_epi32(a.raw, b.raw)}))); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_hadd_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + return Neg(VFromD{_mm256_hsub_ps(a.raw, b.raw)}); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_hadd_pd(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + return Neg(VFromD{_mm256_hsub_pd(a.raw, b.raw)}); +} + +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(Vec256 v) { + return Vec256{_mm256_sad_epu8(v.raw, _mm256_setzero_si256())}; +} + +HWY_API Vec256 SumsOf8AbsDiff(Vec256 a, Vec256 b) { + return Vec256{_mm256_sad_epu8(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf4 +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +HWY_INLINE Vec256 SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + Vec256 v) { + const DFromV d; + + // _mm256_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be + // zeroed out and the sums of the 4 consecutive lanes are already in the + // even uint16_t lanes of the _mm256_maskz_dbsad_epu8 result. 
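+  //
+  // Worked example (illustrative, not from upstream): with the second operand
+  // equal to Zero(d), every absolute difference is the byte itself, so input
+  // bytes {1, 2, 3, 4, 5, 6, 7, 8, ...} give u16 dbsad lanes
+  // {10, 10, 26, 26, ...}; the 0x5555 zeroing mask keeps only the even lanes,
+  // i.e. {10, 0, 26, 0, ...}, the desired sums of 4 consecutive u8 lanes.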
+ return Vec256{_mm256_maskz_dbsad_epu8( + static_cast<__mmask16>(0x5555), v.raw, Zero(d).raw, 0)}; +} + +// detail::SumsOf4 for Vec256 on AVX3 is implemented in x86_512-inl.h + +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ SumsOfAdjQuadAbsDiff + +template +static Vec256 SumsOfAdjQuadAbsDiff(Vec256 a, + Vec256 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + return Vec256{_mm256_mpsadbw_epu8( + a.raw, b.raw, + (kAOffset << 5) | (kBOffset << 3) | (kAOffset << 2) | kBOffset)}; +} + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if HWY_TARGET <= HWY_AVX3 +template +static Vec256 SumsOfShuffledQuadAbsDiff(Vec256 a, + Vec256 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + return Vec256{ + _mm256_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; +} +#endif + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{_mm256_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{_mm256_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{_mm256_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{_mm256_adds_epi16(a.raw, b.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + const DFromV d; + const auto sum = a + b; + const auto overflow_mask = MaskFromVec( + Vec256{_mm256_ternarylogic_epi32(a.raw, b.raw, sum.raw, 0x42)}); + const auto i32_max = Set(d, LimitsMax()); + const Vec256 overflow_result{_mm256_mask_ternarylogic_epi32( + i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, sum); +} + +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + const DFromV d; + const auto sum = a + b; + const auto overflow_mask = MaskFromVec( + Vec256{_mm256_ternarylogic_epi64(a.raw, b.raw, sum.raw, 0x42)}); + const auto i64_max = Set(d, LimitsMax()); + const Vec256 overflow_result{_mm256_mask_ternarylogic_epi64( + i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, sum); +} +#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
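+//
+// Note on the AVX3 i32/i64 paths above and below (an illustrative reading of
+// the immediates): the _mm256_ternarylogic_* constants encode the usual
+// signed-overflow tests on the sign bits. For addition, 0x42 computes
+// (a ^ sum) & ~(a ^ b): the operands agree in sign but the sum does not
+// (e.g. INT32_MAX + 1). For subtraction, 0x18 computes (a ^ b) & ~(b ^ diff):
+// the operands differ in sign and the difference takes b's sign
+// (e.g. INT32_MIN - 1). The masked 0x55 ternary op then yields LimitsMin()
+// instead of LimitsMax() for lanes where `a` is negative, selecting the
+// correct saturation bound.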
+ +// Unsigned +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_subs_epi16(a.raw, b.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + const DFromV d; + const auto diff = a - b; + const auto overflow_mask = MaskFromVec( + Vec256{_mm256_ternarylogic_epi32(a.raw, b.raw, diff.raw, 0x18)}); + const auto i32_max = Set(d, LimitsMax()); + const Vec256 overflow_result{_mm256_mask_ternarylogic_epi32( + i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, diff); +} + +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + const DFromV d; + const auto diff = a - b; + const auto overflow_mask = MaskFromVec( + Vec256{_mm256_ternarylogic_epi64(a.raw, b.raw, diff.raw, 0x18)}); + const auto i64_max = Set(d, LimitsMax()); + const Vec256 overflow_result{_mm256_mask_ternarylogic_epi64( + i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, diff); +} +#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{_mm256_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{_mm256_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec256 Abs(Vec256 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (wrong result) + const DFromV d; + const auto zero = Zero(d); + return Vec256{_mm256_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec256{_mm256_abs_epi8(v.raw)}; +#endif +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi16(v.raw)}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi32(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi64(v.raw)}; +} +#endif + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; +} +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi64(a.raw, b.raw)}; +} +#endif + +// Signed +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; +} +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi64(a.raw, b.raw)}; +} +#endif + +// Returns the upper 16 bits of a * b in each lane. 
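+// Worked examples (illustrative): for uint16_t lanes holding 40000, MulHigh of
+// a lane with itself is (40000 * 40000) >> 16 == 24414; the signed overload
+// likewise returns the high half of the 32-bit product. MulFixedPoint15 below
+// treats int16_t inputs as Q15 fractions: lanes of 16384 (0.5) multiplied by
+// themselves give 8192 (0.25), following the rounding of _mm256_mulhrs_epi16.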
+HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhi_epi16(a.raw, b.raw)}; +} + +HWY_API Vec256 MulFixedPoint15(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeft + +#if HWY_TARGET <= HWY_AVX3_DL +namespace detail { +template +HWY_API Vec256 GaloisAffine(Vec256 v, Vec256 matrix) { + return Vec256{_mm256_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)}; +} +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi64(v.raw, kBits)}; +} + +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// ------------------------------ ShiftRight + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{_mm256_srli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{_mm256_srli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{_mm256_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{_mm256_srai_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{_mm256_srai_epi32(v.raw, kBits)}; +} + +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + const Full256 d8; + // Use raw instead of BitCast to support N=1. + const Vec256 shifted{ShiftRight(Vec256{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// i64 is implemented after BroadcastSignBit. + +// ------------------------------ RotateRight + +// U8 RotateRight implementation on AVX3_DL is now in x86_512-inl.h as U8 +// RotateRight uses detail::GaloisAffine on AVX3_DL + +#if HWY_TARGET > HWY_AVX3_DL +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + if (kBits == 0) return v; + // AVX3 does not support 8-bit. 
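+  // Worked example (illustrative): for kBits == 3 and a lane value of 0xB1
+  // (0b1011'0001), the right shift gives 0b0001'0110 and the left shift by
+  // 8 - 3 = 5 gives 0b0010'0000; OR-ing them yields 0b0011'0110 == 0x36, the
+  // byte rotated right by three bits.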
+ return Or(ShiftRight(v), ShiftLeft(v)); +} +#endif + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); + if (kBits == 0) return v; +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_shrdi_epi16(v.raw, v.raw, kBits)}; +#else + // AVX3 does not support 16-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +// ------------------------------ Rol/Ror +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{_mm256_shrdv_epi16(a.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec256 Rol(Vec256 a, Vec256 b) { + return Vec256{_mm256_rolv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{_mm256_rorv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Rol(Vec256 a, Vec256 b) { + return Vec256{_mm256_rolv_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{_mm256_rorv_epi64(a.raw, b.raw)}; +} + +#endif + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + const DFromV d; + return VecFromMask(v < Zero(d)); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<15>(v); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<31>(v); +} + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{ + _mm256_srai_epi64(v.raw, static_cast(kBits))}; +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<63>(v); +} + +#else // AVX2 + +// Unlike above, this will be used to implement int64_t ShiftRight. +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + const DFromV d; + return VecFromMask(v < Zero(d)); +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 di; + const Full256 du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +} + +#endif // #if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, + Vec256 no) { + // int8: AVX2 IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + +#if HWY_TARGET <= HWY_AVX3 + const auto mask = MaskFromVec(v); +#else + // 16-bit: no native blendv on AVX2, so copy sign to lower byte's MSB. 
+ const DFromV d; + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#endif + + return IfThenElse(mask, yes, no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + +#if HWY_TARGET <= HWY_AVX3 + // No need to cast to float on AVX3 as IfThenElse only looks at the MSB on + // AVX3 + return IfThenElse(MaskFromVec(v), yes, no); +#else + const DFromV d; + const RebindToFloat df; + // 32/64-bit: use float IfThenElse, which only looks at the MSB. + const MFromD msb = MaskFromVec(BitCast(df, v)); + return BitCast(d, IfThenElse(msb, BitCast(df, yes), BitCast(df, no))); +#endif +} + +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{_mm256_sign_epi8(v.raw, mask.raw)}; +} + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{_mm256_sign_epi16(v.raw, mask.raw)}; +} + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{_mm256_sign_epi32(v.raw, mask.raw)}; +} + +// ------------------------------ ShiftLeftSame + +// Disable sign conversion warnings for GCC debug intrinsics. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi16(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi32(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi64(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi16(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi32(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi64(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_srli_epi16(v.raw, bits)}; + } +#endif + return Vec256{_mm256_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 
ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_srli_epi32(v.raw, bits)}; + } +#endif + return Vec256{_mm256_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_srli_epi64(v.raw, bits)}; + } +#endif + return Vec256{_mm256_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_srai_epi16(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_srai_epi32(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{ + _mm256_srai_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec256{_mm256_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const Full256 di; + const Full256 du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Neg (Xor, Sub) + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec256 Neg(hwy::FloatTag /*tag*/, const Vec256 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +template +HWY_INLINE Vec256 Neg(hwy::SpecialTag /*tag*/, const Vec256 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +// Not floating-point +template +HWY_INLINE Vec256 Neg(hwy::SignedTag /*tag*/, const Vec256 v) { + const DFromV d; + return Zero(d) - v; +} + +} // namespace detail + +template +HWY_API Vec256 Neg(const Vec256 v) { + return detail::Neg(hwy::TypeTag(), v); +} + +// ------------------------------ Floating-point mul / div + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_pd(a.raw, b.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MulByFloorPow2(Vec256 a, + Vec256 b) { + return Vec256{_mm256_scalef_ph(a.raw, b.raw)}; +} +#endif + +HWY_API Vec256 MulByFloorPow2(Vec256 a, Vec256 b) { + return Vec256{_mm256_scalef_ps(a.raw, b.raw)}; +} + +HWY_API Vec256 MulByFloorPow2(Vec256 a, Vec256 b) { + return Vec256{_mm256_scalef_pd(a.raw, b.raw)}; +} + 
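+
+// Illustrative note (semantics inferred from the op name and VSCALEF): each
+// lane of `a` is multiplied by two raised to the floor of the corresponding
+// lane of `b`, so a lane pair (3.0, 2.5) yields 3.0 * 2^2 == 12.0.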
+#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return Vec256{_mm256_div_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return Vec256{_mm256_div_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return Vec256{_mm256_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 ApproximateReciprocal(Vec256 v) { + return Vec256{_mm256_rcp_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec256 ApproximateReciprocal(Vec256 v) { + return Vec256{_mm256_rcp_ps(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 ApproximateReciprocal(Vec256 v) { + return Vec256{_mm256_rcp14_pd(v.raw)}; +} +#endif + +// ------------------------------ GetExponent + +#if HWY_TARGET <= HWY_AVX3 + +#if HWY_HAVE_FLOAT16 +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V GetExponent(V v) { + return V{_mm256_getexp_ph(v.raw)}; +} +#endif +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V GetExponent(V v) { + return V{_mm256_getexp_ps(v.raw)}; +} +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V GetExponent(V v) { + return V{_mm256_getexp_pd(v.raw)}; +} + +#endif + +// ------------------------------ MaskedMinOr + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMaxOr + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return 
Vec256{_mm256_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedAddOr + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSubOr + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if 
HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMulOr + +HWY_API Vec256 MaskedMulOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec256 MaskedMulOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MaskedMulOr(Vec256 no, + Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedDivOr + +HWY_API Vec256 MaskedDivOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec256 MaskedDivOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MaskedDivOr(Vec256 no, + Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSatAddOr + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ MaskedSatSubOr + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Floating-point multiply-add variants + +#if HWY_HAVE_FLOAT16 + +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{_mm256_fmadd_ph(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{_mm256_fnmadd_ph(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { + return Vec256{_mm256_fmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, + Vec256 sub) { + return Vec256{_mm256_fnmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 
MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256{_mm256_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MulAddSub(Vec256 mul, Vec256 x, + Vec256 sub_or_add) { + return Vec256{_mm256_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec256 MulAddSub(Vec256 mul, Vec256 x, + Vec256 sub_or_add) { +#ifdef HWY_DISABLE_BMI2_FMA + return AddSub(mul * x, sub_or_add); +#else + return Vec256{_mm256_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +HWY_API Vec256 MulAddSub(Vec256 mul, Vec256 x, + Vec256 sub_or_add) { +#ifdef HWY_DISABLE_BMI2_FMA + return AddSub(mul * x, sub_or_add); +#else + return Vec256{_mm256_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +// Full precision square root +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Sqrt(Vec256 v) { + return Vec256{_mm256_sqrt_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Sqrt(Vec256 v) { + return Vec256{_mm256_sqrt_ps(v.raw)}; +} +HWY_API Vec256 Sqrt(Vec256 v) { + return Vec256{_mm256_sqrt_pd(v.raw)}; +} + +// Approximate reciprocal square root +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 ApproximateReciprocalSqrt(Vec256 v) { + return Vec256{_mm256_rsqrt_ph(v.raw)}; +} +#endif +HWY_API Vec256 ApproximateReciprocalSqrt(Vec256 v) { + return Vec256{_mm256_rsqrt_ps(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 ApproximateReciprocalSqrt(Vec256 v) { +#if HWY_COMPILER_MSVC + const DFromV d; + return Vec256{_mm256_mask_rsqrt14_pd( + Undefined(d).raw, static_cast<__mmask8>(0xFF), v.raw)}; +#else + return Vec256{_mm256_rsqrt14_pd(v.raw)}; +#endif +} +#endif + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, tie to even +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Round(Vec256 v) { + return Vec256{_mm256_roundscale_ph( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Round(Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Round(Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 
Trunc(Vec256 v) { + return Vec256{ + _mm256_roundscale_ph(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Trunc(Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Trunc(Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Ceil(Vec256 v) { + return Vec256{ + _mm256_roundscale_ph(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Ceil(Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Ceil(Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Floor(Vec256 v) { + return Vec256{ + _mm256_roundscale_ph(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Floor(Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Floor(Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +// ------------------------------ Floating-point classification + +#if HWY_HAVE_FLOAT16 || HWY_IDE + +HWY_API Mask256 IsNaN(Vec256 v) { + return Mask256{_mm256_fpclass_ph_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +} + +HWY_API Mask256 IsEitherNaN(Vec256 a, + Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)}; + HWY_DIAGNOSTICS(pop) +} + +HWY_API Mask256 IsInf(Vec256 v) { + return Mask256{_mm256_fpclass_ph_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} + +HWY_API Mask256 IsFinite(Vec256 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. 
+ return Not(Mask256{_mm256_fpclass_ph_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} + +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask256 IsNaN(Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +#else + return Mask256{_mm256_cmp_ps(v.raw, v.raw, _CMP_UNORD_Q)}; +#endif +} +HWY_API Mask256 IsNaN(Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +#else + return Mask256{_mm256_cmp_pd(v.raw, v.raw, _CMP_UNORD_Q)}; +#endif +} + +HWY_API Mask256 IsEitherNaN(Vec256 a, Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_UNORD_Q)}; +#endif +} + +HWY_API Mask256 IsEitherNaN(Vec256 a, Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_UNORD_Q)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 + +HWY_API Mask256 IsInf(Vec256 v) { + return Mask256{_mm256_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} +HWY_API Mask256 IsInf(Vec256 v) { + return Mask256{_mm256_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} + +HWY_API Mask256 IsFinite(Vec256 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask256{_mm256_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} +HWY_API Mask256 IsFinite(Vec256 v) { + return Not(Mask256{_mm256_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { + return VFromD{ + _mm256_load_si256(reinterpret_cast(aligned))}; +} +// bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 Load(D /* tag */, + const float16_t* HWY_RESTRICT aligned) { + return Vec256{_mm256_load_ph(aligned)}; +} +#endif +template +HWY_API Vec256 Load(D /* tag */, const float* HWY_RESTRICT aligned) { + return Vec256{_mm256_load_ps(aligned)}; +} +template +HWY_API Vec256 Load(D /* tag */, const double* HWY_RESTRICT aligned) { + return Vec256{_mm256_load_pd(aligned)}; +} + +template +HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_loadu_si256(reinterpret_cast(p))}; +} +// bfloat16_t is handled by x86_128-inl.h. 
+#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_ph(p)}; +} +#endif +template +HWY_API Vec256 LoadU(D /* tag */, const float* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_ps(p)}; +} +template +HWY_API Vec256 LoadU(D /* tag */, const double* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_pd(p)}; +} + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_maskz_loadu_epi16(m.raw, p)}); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_maskz_loadu_epi64(m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, D /* tag */, + const float* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_ps(m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, D /* tag */, + const double* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_pd(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_mask_loadu_epi8(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{ + _mm256_mask_loadu_epi16(BitCast(du, v).raw, m.raw, p)}); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_mask_loadu_epi32(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_mask_loadu_epi64(v.raw, m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoadOr(VFromD v, Mask256 m, D /* tag */, + const float* HWY_RESTRICT p) { + return Vec256{_mm256_mask_loadu_ps(v.raw, m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoadOr(VFromD v, Mask256 m, D /* tag */, + const double* HWY_RESTRICT p) { + return Vec256{_mm256_mask_loadu_pd(v.raw, m.raw, p)}; +} + +#else // AVX2 + +// There is no maskload_epi8/16, so blend instead. 
+template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + return VFromD{_mm256_maskload_epi32(pi, m.raw)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + return VFromD{_mm256_maskload_epi64(pi, m.raw)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, D d, + const float* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec256{_mm256_maskload_ps(p, mi.raw)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, D d, + const double* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec256{_mm256_maskload_pd(p, mi.raw)}; +} + +#endif + +// ------------------------------ LoadDup128 + +// Loads 128 bit and duplicates into both 128-bit halves. This avoids the +// 3-cycle cost of moving data between 128-bit halves and avoids port 5. +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + const Full128> d128; + const RebindToUnsigned du128; + const __m128i v128 = BitCast(du128, LoadU(d128, p)).raw; +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + // Workaround for incorrect results with _mm256_broadcastsi128_si256. Note + // that MSVC also lacks _mm256_zextsi128_si256, but cast (which leaves the + // upper half undefined) is fine because we're overwriting that anyway. + // This workaround seems in turn to generate incorrect code in MSVC 2022 + // (19.31), so use broadcastsi128 there. + return BitCast(d, VFromD{_mm256_inserti128_si256( + _mm256_castsi128_si256(v128), v128, 1)}); +#else + // The preferred path. This is perhaps surprising, because vbroadcasti128 + // with xmm input has 7 cycle latency on Intel, but Clang >= 7 is able to + // pattern-match this to vbroadcastf128 with a memory operand as desired. 
+ return BitCast(d, VFromD{_mm256_broadcastsi128_si256(v128)}); +#endif +} +template +HWY_API Vec256 LoadDup128(D /* tag */, const float* HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + const Full128 d128; + const __m128 v128 = LoadU(d128, p).raw; + return Vec256{ + _mm256_insertf128_ps(_mm256_castps128_ps256(v128), v128, 1)}; +#else + return Vec256{_mm256_broadcast_ps(reinterpret_cast(p))}; +#endif +} +template +HWY_API Vec256 LoadDup128(D /* tag */, const double* HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + const Full128 d128; + const __m128d v128 = LoadU(d128, p).raw; + return Vec256{ + _mm256_insertf128_pd(_mm256_castpd128_pd256(v128), v128, 1)}; +#else + return Vec256{ + _mm256_broadcast_pd(reinterpret_cast(p))}; +#endif +} + +// ------------------------------ Store + +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + _mm256_store_si256(reinterpret_cast<__m256i*>(aligned), v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void Store(Vec256 v, D /* tag */, + float16_t* HWY_RESTRICT aligned) { + _mm256_store_ph(aligned, v.raw); +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API void Store(Vec256 v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm256_store_ps(aligned, v.raw); +} +template +HWY_API void Store(Vec256 v, D /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec256 v, D /* tag */, + float16_t* HWY_RESTRICT p) { + _mm256_storeu_ph(p, v.raw); +} +#endif +template +HWY_API void StoreU(Vec256 v, D /* tag */, float* HWY_RESTRICT p) { + _mm256_storeu_ps(p, v.raw); +} +template +HWY_API void StoreU(Vec256 v, D /* tag */, double* HWY_RESTRICT p) { + _mm256_storeu_pd(p, v.raw); +} + +// ------------------------------ BlendedStore + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm256_mask_storeu_epi8(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + _mm256_mask_storeu_epi16(reinterpret_cast(p), + RebindMask(du, m).raw, BitCast(du, v).raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm256_mask_storeu_epi32(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm256_mask_storeu_epi64(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, D /* tag */, + float* HWY_RESTRICT p) { + _mm256_mask_storeu_ps(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, D /* tag */, + double* HWY_RESTRICT p) { + _mm256_mask_storeu_pd(p, m.raw, v.raw); +} + +#else // AVX2 + +// Intel SDM says "No AC# reported for any mask bit combinations". However, AMD +// allows AC# if "Alignment checking enabled and: 256-bit memory operand not +// 32-byte aligned". Fortunately AC# is not enabled by default and requires both +// OS support (CR0) and the application to set rflags.AC. We assume these remain +// disabled because x86/x64 code and compiler output often contain misaligned +// scalar accesses, which would also fault. +// +// Caveat: these are slow on AMD Jaguar/Bulldozer. 
+ +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + // There is no maskload_epi8/16. Blending is also unsafe because loading a + // full vector that crosses the array end causes asan faults. Resort to scalar + // code; the caller should instead use memcpy, assuming m is FirstN(d, n). + const RebindToUnsigned du; + using TU = TFromD; + alignas(32) TU buf[MaxLanes(d)]; + alignas(32) TU mask[MaxLanes(d)]; + Store(BitCast(du, v), du, buf); + Store(BitCast(du, VecFromMask(d, m)), du, mask); + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm256_maskstore_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm256_maskstore_epi64(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, D d, + float* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm256_maskstore_ps(p, mi.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, D d, + double* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm256_maskstore_pd(p, mi.raw, v.raw); +} + +#endif + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; // for float16_t + _mm256_stream_si256(reinterpret_cast<__m256i*>(aligned), BitCast(du, v).raw); +} +template +HWY_API void Stream(Vec256 v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm256_stream_ps(aligned, v.raw); +} +template +HWY_API void Stream(Vec256 v, D /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_stream_pd(aligned, v.raw); +} + +// ------------------------------ ScatterOffset + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + Vec256 offset) { + _mm256_i32scatter_epi32(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + Vec256 offset) { + _mm256_i64scatter_epi64(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, float* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i32scatter_ps(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, double* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i64scatter_pd(base, offset.raw, v.raw, 1); +} + +// ------------------------------ ScatterIndex + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm256_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm256_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, float* HWY_RESTRICT base, + VFromD> index) { + _mm256_i32scatter_ps(base, index.raw, v.raw, 4); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, double* HWY_RESTRICT base, + VFromD> index) { + _mm256_i64scatter_pd(base, index.raw, v.raw, 8); +} + +// ------------------------------ MaskedScatterIndex + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm256_mask_i32scatter_epi32(base, m.raw, index.raw, v.raw, 4); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm256_mask_i64scatter_epi64(base, m.raw, index.raw, v.raw, 8); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + float* HWY_RESTRICT base, + VFromD> index) { + _mm256_mask_i32scatter_ps(base, m.raw, index.raw, v.raw, 4); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + double* HWY_RESTRICT base, + VFromD> index) { + _mm256_mask_i64scatter_pd(base, m.raw, index.raw, v.raw, 8); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Gather + +namespace detail { + +template +HWY_INLINE Vec256 NativeGather256(const T* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i32gather_epi32( + reinterpret_cast(base), indices.raw, kScale)}; +} + +template +HWY_INLINE Vec256 NativeGather256(const T* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i64gather_epi64( + reinterpret_cast(base), indices.raw, kScale)}; +} + +template +HWY_API Vec256 NativeGather256(const float* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i32gather_ps(base, indices.raw, kScale)}; +} + +template +HWY_API Vec256 NativeGather256(const double* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i64gather_pd(base, indices.raw, kScale)}; +} + +} // namespace detail + +template +HWY_API VFromD GatherOffset(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> offsets) { + return detail::NativeGather256<1>(base, offsets); +} + +template +HWY_API VFromD GatherIndex(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeGather256)>(base, indices); +} + +// ------------------------------ MaskedGatherIndexOr + +namespace 
detail { + +template +HWY_INLINE Vec256 NativeMaskedGatherOr256(Vec256 no, Mask256 m, + const T* HWY_RESTRICT base, + Vec256 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_mmask_i32gather_epi32( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; +#else + return Vec256{_mm256_mask_i32gather_epi32( + no.raw, reinterpret_cast(base), indices.raw, m.raw, + kScale)}; +#endif +} + +template +HWY_INLINE Vec256 NativeMaskedGatherOr256(Vec256 no, Mask256 m, + const T* HWY_RESTRICT base, + Vec256 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_mmask_i64gather_epi64( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; +#else + // For reasons unknown, _mm256_mask_i64gather_epi64 returns all-zeros. + const Full256 d; + const Full256 dd; + return BitCast(d, + Vec256{_mm256_mask_i64gather_pd( + BitCast(dd, no).raw, reinterpret_cast(base), + indices.raw, RebindMask(dd, m).raw, kScale)}); +#endif +} + +template +HWY_API Vec256 NativeMaskedGatherOr256(Vec256 no, + Mask256 m, + const float* HWY_RESTRICT base, + Vec256 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{ + _mm256_mmask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; +#else + return Vec256{ + _mm256_mask_i32gather_ps(no.raw, base, indices.raw, m.raw, kScale)}; +#endif +} + +template +HWY_API Vec256 NativeMaskedGatherOr256(Vec256 no, + Mask256 m, + const double* HWY_RESTRICT base, + Vec256 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{ + _mm256_mmask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; +#else + return Vec256{ + _mm256_mask_i64gather_pd(no.raw, base, indices.raw, m.raw, kScale)}; +#endif +} + +} // namespace detail + +template +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D /*d*/, + const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeMaskedGatherOr256)>(no, m, base, + indices); +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{_mm256_castsi256_si128(v.raw)}; +} +template +HWY_API Vec128 LowerHalf(D /* tag */, Vec256 v) { + return Vec128{_mm256_castsi256_si128(v.raw)}; +} +template +HWY_API Vec128 LowerHalf(D /* tag */, Vec256 v) { +#if HWY_HAVE_FLOAT16 + return Vec128{_mm256_castph256_ph128(v.raw)}; +#else + return Vec128{_mm256_castsi256_si128(v.raw)}; +#endif // HWY_HAVE_FLOAT16 +} +template +HWY_API Vec128 LowerHalf(D /* tag */, Vec256 v) { + return Vec128{_mm256_castps256_ps128(v.raw)}; +} +template +HWY_API Vec128 LowerHalf(D /* tag */, Vec256 v) { + return Vec128{_mm256_castpd256_pd128(v.raw)}; +} + +template +HWY_API Vec128 LowerHalf(Vec256 v) { + const Full128 dh; + return LowerHalf(dh, v); +} + +// ------------------------------ UpperHalf + +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + const RebindToUnsigned du; // for float16_t + const Twice dut; + return BitCast(d, VFromD{ + _mm256_extracti128_si256(BitCast(dut, v).raw, 1)}); +} +template +HWY_API VFromD UpperHalf(D /* tag */, Vec256 v) { + return VFromD{_mm256_extractf128_ps(v.raw, 1)}; +} +template +HWY_API VFromD UpperHalf(D /* tag */, Vec256 v) { + return VFromD{_mm256_extractf128_pd(v.raw, 1)}; +} + +// ------------------------------ ExtractLane (Store) +template +HWY_API T ExtractLane(const Vec256 v, size_t i) { + const DFromV d; + HWY_DASSERT(i < Lanes(d)); + +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + constexpr size_t kLanesPerBlock = 16 / 
sizeof(T); + if (__builtin_constant_p(i < kLanesPerBlock) && (i < kLanesPerBlock)) { + return ExtractLane(LowerHalf(Half(), v), i); + } +#endif + + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (Store) +template +HWY_API Vec256 InsertLane(const Vec256 v, size_t i, T t) { + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec256 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ExtractBlock (LowerHalf, UpperHalf) + +template +HWY_API Vec128 ExtractBlock(Vec256 v) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + const Half> dh; + return (kBlockIdx == 0) ? LowerHalf(dh, v) : UpperHalf(dh, v); +} + +// ------------------------------ ZeroExtendVector + +// Unfortunately the initial _mm256_castsi128_si256 intrinsic leaves the upper +// bits undefined. Although it makes sense for them to be zero (VEX encoded +// 128-bit instructions zero the upper lanes to avoid large penalties), a +// compiler could decide to optimize out code that relies on this. +// +// The newer _mm256_zextsi128_si256 intrinsic fixes this by specifying the +// zeroing, but it is not available on MSVC until 1920 nor GCC until 10.1. +// Unfortunately as of 2023-08 it still seems to cause internal compiler errors +// on MSVC, so we consider it unavailable there. +// +// Without zext we can still possibly obtain the desired code thanks to pattern +// recognition; note that the expensive insert instruction might not actually be +// generated, see https://gcc.godbolt.org/z/1MKGaP. + +#if !defined(HWY_HAVE_ZEXT) +#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 500) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1000) +#define HWY_HAVE_ZEXT 1 +#else +#define HWY_HAVE_ZEXT 0 +#endif +#endif // defined(HWY_HAVE_ZEXT) + +template +HWY_API VFromD ZeroExtendVector(D /* tag */, VFromD> lo) { +#if HWY_HAVE_ZEXT + return VFromD{_mm256_zextsi128_si256(lo.raw)}; +#elif HWY_COMPILER_MSVC + // Workaround: _mm256_inserti128_si256 does not actually zero the hi part. 
+ return VFromD{_mm256_set_m128i(_mm_setzero_si128(), lo.raw)}; +#else + return VFromD{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; +#endif +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 ZeroExtendVector(D d, Vec128 lo) { +#if HWY_HAVE_ZEXT + (void)d; + return Vec256{_mm256_zextph128_ph256(lo.raw)}; +#else + const RebindToUnsigned du; + return BitCast(d, ZeroExtendVector(du, BitCast(du, lo))); +#endif // HWY_HAVE_ZEXT +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec256 ZeroExtendVector(D /* tag */, Vec128 lo) { +#if HWY_HAVE_ZEXT + return Vec256{_mm256_zextps128_ps256(lo.raw)}; +#else + return Vec256{_mm256_insertf128_ps(_mm256_setzero_ps(), lo.raw, 0)}; +#endif +} +template +HWY_API Vec256 ZeroExtendVector(D /* tag */, Vec128 lo) { +#if HWY_HAVE_ZEXT + return Vec256{_mm256_zextpd128_pd256(lo.raw)}; +#else + return Vec256{_mm256_insertf128_pd(_mm256_setzero_pd(), lo.raw, 0)}; +#endif +} + +// ------------------------------ ZeroExtendResizeBitCast + +namespace detail { + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag<8> /* from_size_tag */, hwy::SizeTag<32> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD v) { + const Twice dt_from; + const Twice dq_from; + return BitCast(d_to, ZeroExtendVector(dq_from, ZeroExtendVector(dt_from, v))); +} + +} // namespace detail + +// ------------------------------ Combine + +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { + const RebindToUnsigned du; // for float16_t + const Half dh_u; + const auto lo256 = ZeroExtendVector(du, BitCast(dh_u, lo)); + return BitCast(d, VFromD{_mm256_inserti128_si256( + lo256.raw, BitCast(dh_u, hi).raw, 1)}); +} +template +HWY_API Vec256 Combine(D d, Vec128 hi, Vec128 lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256{_mm256_insertf128_ps(lo256.raw, hi.raw, 1)}; +} +template +HWY_API Vec256 Combine(D d, Vec128 hi, Vec128 lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256{_mm256_insertf128_pd(lo256.raw, hi.raw, 1)}; +} + +// ------------------------------ ShiftLeftBytes +template +HWY_API VFromD ShiftLeftBytes(D /* tag */, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bslli_epi128. + return VFromD{_mm256_slli_si256(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D /* tag */, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bsrli_epi128. 
+ return VFromD{_mm256_srli_si256(v.raw, kBytes)}; +} + +// ------------------------------ CombineShiftRightBytes +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + const Repartition d8; + return BitCast(d, Vec256{_mm256_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +// ------------------------------ Broadcast + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m256i lo = _mm256_shufflelo_epi16(vu.raw, (0x55 * kLane) & 0xFF); + return BitCast(d, VU{_mm256_unpacklo_epi64(lo, lo)}); + } else { + const __m256i hi = + _mm256_shufflehi_epi16(vu.raw, (0x55 * (kLane - 4)) & 0xFF); + return BitCast(d, VU{_mm256_unpackhi_epi64(hi, hi)}); + } +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; +} + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +template +HWY_API Vec256 Broadcast(Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 15 * kLane)}; +} + +// ------------------------------ Concat blocks (LowerHalf, ZeroExtendVector) + +// _mm256_broadcastsi128_si256 has 7 cycle latency on ICL. +// _mm256_permute2x128_si256 is slow on Zen1 (8 uops), so we avoid it (at no +// extra cost) for LowerLower and UpperLower. 
+ +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + const Half d2; + const RebindToUnsigned du2; // for float16_t + return BitCast( + d, VFromD{_mm256_inserti128_si256( + BitCast(du, lo).raw, BitCast(du2, LowerHalf(d2, hi)).raw, 1)}); +} +template +HWY_API Vec256 ConcatLowerLower(D d, Vec256 hi, + Vec256 lo) { + const Half d2; + return Vec256{_mm256_insertf128_ps(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} +template +HWY_API Vec256 ConcatLowerLower(D d, Vec256 hi, + Vec256 lo) { + const Half d2; + return Vec256{_mm256_insertf128_pd(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_permute2x128_si256( + BitCast(du, lo).raw, BitCast(du, hi).raw, 0x21)}); +} +template +HWY_API Vec256 ConcatLowerUpper(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x21)}; +} +template +HWY_API Vec256 ConcatLowerUpper(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x21)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_blend_epi32( + BitCast(du, hi).raw, BitCast(du, lo).raw, 0x0F)}); +} +template +HWY_API Vec256 ConcatUpperLower(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_blend_ps(hi.raw, lo.raw, 0x0F)}; +} +template +HWY_API Vec256 ConcatUpperLower(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_blend_pd(hi.raw, lo.raw, 3)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_permute2x128_si256( + BitCast(du, lo).raw, BitCast(du, hi).raw, 0x31)}); +} +template +HWY_API Vec256 ConcatUpperUpper(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x31)}; +} +template +HWY_API Vec256 ConcatUpperUpper(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x31)}; +} + +// ------------------------------ BroadcastBlock +template +HWY_API Vec256 BroadcastBlock(Vec256 v) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + const DFromV d; + return (kBlockIdx == 0) ? 
ConcatLowerLower(d, v, v) + : ConcatUpperUpper(d, v, v); +} + +// ------------------------------ BroadcastLane + +namespace detail { + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const Half> dh; + return Vec256{_mm256_broadcastb_epi8(LowerHalf(dh, v).raw)}; +} + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + const Half dh; + const RebindToUnsigned dh_u; + return BitCast(d, VFromD{_mm256_broadcastw_epi16( + BitCast(dh_u, LowerHalf(dh, v)).raw)}); +} + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const Half> dh; + return Vec256{_mm256_broadcastd_epi32(LowerHalf(dh, v).raw)}; +} + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const Half> dh; + return Vec256{_mm256_broadcastq_epi64(LowerHalf(dh, v).raw)}; +} + +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const Half> dh; + return Vec256{_mm256_broadcastss_ps(LowerHalf(dh, v).raw)}; +} + +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const Half> dh; + return Vec256{_mm256_broadcastsd_pd(LowerHalf(dh, v).raw)}; +} + +template * = nullptr, + HWY_IF_NOT_T_SIZE(T, 8)> +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag /* lane_idx_tag */, + Vec256 v) { + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + constexpr int kBlockIdx = static_cast(kLaneIdx / kLanesPerBlock); + constexpr int kLaneInBlkIdx = + static_cast(kLaneIdx) & (kLanesPerBlock - 1); + return Broadcast(BroadcastBlock(v)); +} + +template * = nullptr, + HWY_IF_UI64(T)> +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag /* lane_idx_tag */, + Vec256 v) { + static_assert(kLaneIdx <= 3, "Invalid lane"); + return Vec256{ + _mm256_permute4x64_epi64(v.raw, static_cast(0x55 * kLaneIdx))}; +} + +template * = nullptr> +HWY_INLINE Vec256 BroadcastLane( + hwy::SizeTag /* lane_idx_tag */, Vec256 v) { + static_assert(kLaneIdx <= 3, "Invalid lane"); + return Vec256{ + _mm256_permute4x64_pd(v.raw, static_cast(0x55 * kLaneIdx))}; +} + +} // namespace detail + +template +HWY_API Vec256 BroadcastLane(Vec256 v) { + static_assert(kLaneIdx >= 0, "Invalid lane"); + return detail::BroadcastLane(hwy::SizeTag(kLaneIdx)>(), + v); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec256 have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. 
+template +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0xB1)}; +} +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// Used by generic_ops-inl.h +namespace detail { + +template +HWY_API Vec256 ShuffleTwo2301(const Vec256 a, const Vec256 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} +template +HWY_API Vec256 ShuffleTwo1230(const Vec256 a, const Vec256 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} +template +HWY_API Vec256 ShuffleTwo3012(const Vec256 a, const Vec256 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + // Shorter encoding than _mm256_permute_ps. + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + // Shorter encoding than _mm256_permute_pd. + return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 5)}; +} + +// Rotate right 32 bits +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
+template +struct Indices256 { + __m256i raw; +}; + +// 8-bit lanes: indices remain unchanged +template +HWY_API Indices256> IndicesFromVec(D /* tag */, Vec256 vec) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full256 di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(2 * Lanes(di)))))); +#endif + return Indices256>{vec.raw}; +} + +// 16-bit lanes: convert indices to 32x8 unless AVX3 is available +template +HWY_API Indices256> IndicesFromVec(D /* tag */, Vec256 vec) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + const Full256 di; +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(2 * Lanes(di)))))); +#endif + +#if HWY_TARGET <= HWY_AVX3 + (void)di; + return Indices256>{vec.raw}; +#else + const Repartition d8; + using V8 = VFromD; + alignas(32) static constexpr uint8_t kByteOffsets[32] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + + // Broadcast each lane index to all 2 bytes of T + alignas(32) static constexpr uint8_t kBroadcastLaneBytes[32] = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14, + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; + const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes)); + + // Shift to bytes + const Repartition d16; + const V8 byte_indices = BitCast(d8, ShiftLeft<1>(BitCast(d16, lane_indices))); + + return Indices256>{Add(byte_indices, Load(d8, kByteOffsets)).raw}; +#endif // HWY_TARGET <= HWY_AVX3 +} + +// Native 8x32 instruction: indices remain unchanged +template +HWY_API Indices256> IndicesFromVec(D /* tag */, Vec256 vec) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full256 di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(2 * Lanes(di)))))); +#endif + return Indices256>{vec.raw}; +} + +// 64-bit lanes: convert indices to 8x32 unless AVX3 is available +template +HWY_API Indices256> IndicesFromVec(D d, Vec256 idx64) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + const Rebind di; + (void)di; // potentially unused +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllFalse(di, Lt(idx64, Zero(di))) && + AllTrue(di, Lt(idx64, Set(di, static_cast(2 * Lanes(di)))))); +#endif + +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Indices256>{idx64.raw}; +#else + const Repartition df; // 32-bit! + // Replicate 64-bit index into upper 32 bits + const Vec256 dup = + BitCast(di, Vec256{_mm256_moveldup_ps(BitCast(df, idx64).raw)}); + // For each idx64 i, idx32 are 2*i and 2*i+1. 
+ const Vec256 idx32 = dup + dup + Set(di, TI(1) << 32); + return Indices256>{idx32.raw}; +#endif +} + +template +HWY_API Indices256> SetTableIndices(D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_permutexvar_epi8(idx.raw, v.raw)}; +#else + const Vec256 idx_vec{idx.raw}; + const DFromV d; + const Repartition du16; + const auto sel_hi_mask = + MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, idx_vec)))); + + const auto a = ConcatLowerLower(d, v, v); + const auto b = ConcatUpperUpper(d, v, v); + const auto lo_lookup_result = TableLookupBytes(a, idx_vec); + +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_mask_shuffle_epi8( + lo_lookup_result.raw, sel_hi_mask.raw, b.raw, idx_vec.raw)}; +#else + const auto hi_lookup_result = TableLookupBytes(b, idx_vec); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#endif // HWY_TARGET <= HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX3_DL +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutexvar_epi16(idx.raw, v.raw)}; +#else + const DFromV d; + const Repartition du8; + return BitCast( + d, TableLookupLanes(BitCast(du8, v), Indices256{idx.raw})); +#endif +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 TableLookupLanes(Vec256 v, + Indices256 idx) { + return Vec256{_mm256_permutexvar_ph(idx.raw, v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { + return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutexvar_epi64(idx.raw, v.raw)}; +#else + return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +#endif +} + +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { + return Vec256{_mm256_permutevar8x32_ps(v.raw, idx.raw)}; +} + +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutexvar_pd(idx.raw, v.raw)}; +#else + const Full256 df; + const Full256 du; + return BitCast(df, Vec256{_mm256_permutevar8x32_epi32( + BitCast(du, v).raw, idx.raw)}); +#endif +} + +template +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_permutex2var_epi8(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const auto sel_hi_mask = + MaskFromVec(BitCast(d, ShiftLeft<2>(Vec256{idx.raw}))); + const auto lo_lookup_result = TableLookupLanes(a, idx); + const auto hi_lookup_result = TableLookupLanes(b, idx); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#endif +} + +template +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutex2var_epi16(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const Repartition du8; + return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b), + Indices256{idx.raw})); +#endif +} + +template +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutex2var_epi32(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const RebindToFloat df; + const Vec256 idx_vec{idx.raw}; + + const auto sel_hi_mask = MaskFromVec(BitCast(df, 
ShiftLeft<28>(idx_vec))); + const auto lo_lookup_result = BitCast(df, TableLookupLanes(a, idx)); + const auto hi_lookup_result = BitCast(df, TableLookupLanes(b, idx)); + return BitCast(d, + IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result)); +#endif +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, + Vec256 b, + Indices256 idx) { + return Vec256{_mm256_permutex2var_ph(a.raw, idx.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutex2var_ps(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const auto sel_hi_mask = + MaskFromVec(BitCast(d, ShiftLeft<28>(Vec256{idx.raw}))); + const auto lo_lookup_result = TableLookupLanes(a, idx); + const auto hi_lookup_result = TableLookupLanes(b, idx); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#endif +} + +template +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutex2var_epi64(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const Repartition du32; + return BitCast(d, TwoTablesLookupLanes(BitCast(du32, a), BitCast(du32, b), + Indices256{idx.raw})); +#endif +} + +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutex2var_pd(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const Repartition du32; + return BitCast(d, TwoTablesLookupLanes(BitCast(du32, a), BitCast(du32, b), + Indices256{idx.raw})); +#endif +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_permute4x64_epi64( + BitCast(du, v).raw, _MM_SHUFFLE(1, 0, 3, 2))}); +} + +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + return Vec256{_mm256_permute4x64_pd(v.raw, _MM_SHUFFLE(1, 0, 3, 2))}; +} + +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + // Assume no domain-crossing penalty between float/double (true on SKX). 
+ const DFromV d; + const RepartitionToWide dw; + return BitCast(d, SwapAdjacentBlocks(BitCast(dw, v))); +} + +// ------------------------------ InterleaveEvenBlocks (ConcatLowerLower) +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveEvenBlocks(D d, V a, V b) { + return ConcatLowerLower(d, b, a); +} + +// ------------------------------ InterleaveOddBlocks (ConcatUpperUpper) +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveOddBlocks(D d, V a, V b) { + return ConcatUpperUpper(d, b, a); +} + +// ------------------------------ Reverse (RotateRight) + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(32) static constexpr int32_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(32) static constexpr int64_t kReverse[4] = {3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) static constexpr int16_t kReverse[16] = { + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec256 idx = Load(di, kReverse); + return BitCast(d, Vec256{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + const auto rev128 = TableLookupBytes(v, shuffle); + return VFromD{ + _mm256_permute4x64_epi64(rev128.raw, _MM_SHUFFLE(1, 0, 3, 2))}; +#endif +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) static constexpr TFromD kReverse[32] = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +#else + // First reverse bytes within blocks via PSHUFB, then swap blocks. + alignas(32) static constexpr TFromD kReverse[32] = { + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + return SwapAdjacentBlocks(TableLookupBytes(v, Load(d, kReverse))); +#endif +} + +// ------------------------------ Reverse2 (in x86_128) + +// ------------------------------ Reverse4 (SwapAdjacentBlocks) + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0706, 0x0504, 0x0302, 0x0100, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908); + return BitCast(d, TableLookupBytes(v, shuffle)); +} + +// 32 bit Reverse4 defined in x86_128. + +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + // Could also use _mm256_permute4x64_epi64. 
+ return SwapAdjacentBlocks(Shuffle01(v)); +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + return Reverse(d, v); +} + +template +HWY_API VFromD Reverse8(D /* tag */, const VFromD /* v */) { + HWY_ASSERT(0); // AVX2 does not have 8 64-bit lanes +} + +// ------------------------------ ReverseBits in x86_512 + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm256_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{_mm256_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{_mm256_unpacklo_pd(a.raw, b.raw)}; +} + +// ------------------------------ InterleaveUpper + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm256_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm256_unpackhi_epi64(a.raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm256_unpackhi_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm256_unpackhi_pd(a.raw, b.raw)}; +} + +// ---------------------------- InsertBlock (ConcatLowerLower, ConcatUpperLower) +template +HWY_API Vec256 InsertBlock(Vec256 v, Vec128 blk_to_insert) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + + const DFromV d; + const auto vec_to_insert = ResizeBitCast(d, blk_to_insert); + return (kBlockIdx == 0) ? 
ConcatUpperLower(d, v, vec_to_insert) + : ConcatLowerLower(d, vec_to_insert, v); +} + +// ------------------------------ ConcatOdd + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) static constexpr uint8_t kIdx[32] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Unsigned 8-bit shift so we can pack. + const Vec256 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec256 uL = ShiftRight<8>(BitCast(dw, lo)); + const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); + return VFromD{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) static constexpr uint16_t kIdx[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Unsigned 16-bit shift so we can pack. + const Vec256 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec256 uL = ShiftRight<16>(BitCast(dw, lo)); + const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); + return BitCast(d, VFromD{_mm256_permute4x64_epi64( + u16, _MM_SHUFFLE(3, 1, 2, 0))}); +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) static constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v3131{_mm256_shuffle_ps( + BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return VFromD{_mm256_permute4x64_epi64(BitCast(du, v3131).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) static constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return VFromD{_mm256_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; +#else + const VFromD v3131{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return BitCast(d, Vec256{_mm256_permute4x64_epi64( + BitCast(du, v3131).raw, _MM_SHUFFLE(3, 1, 2, 0))}); +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v31{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 15)}; + return VFromD{ + _mm256_permute4x64_epi64(BitCast(du, v31).raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API Vec256 ConcatOdd(D d, Vec256 hi, Vec256 lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; + return Vec256{ + _mm256_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; +#else + (void)d; + const Vec256 v31{_mm256_shuffle_pd(lo.raw, hi.raw, 15)}; + return 
Vec256{ + _mm256_permute4x64_pd(v31.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ ConcatEven + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(64) static constexpr uint8_t kIdx[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec256 mask = Set(dw, 0x00FF); + const Vec256 uH = And(BitCast(dw, hi), mask); + const Vec256 uL = And(BitCast(dw, lo), mask); + const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); + return VFromD{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint16_t kIdx[16] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 16 bits per u32 so we can pack. + const Vec256 mask = Set(dw, 0x0000FFFF); + const Vec256 uH = And(BitCast(dw, hi), mask); + const Vec256 uL = And(BitCast(dw, lo), mask); + const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); + return BitCast(d, VFromD{_mm256_permute4x64_epi64( + u16, _MM_SHUFFLE(3, 1, 2, 0))}); +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v2020{_mm256_shuffle_ps( + BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return VFromD{_mm256_permute4x64_epi64(BitCast(du, v2020).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return VFromD{_mm256_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; +#else + const VFromD v2020{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return BitCast(d, Vec256{_mm256_permute4x64_epi64( + BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}); + +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v20{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 0)}; + return VFromD{ + _mm256_permute4x64_epi64(BitCast(du, v20).raw, _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +template +HWY_API Vec256 ConcatEven(D d, Vec256 hi, Vec256 lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return Vec256{ + _mm256_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; 
+#else + (void)d; + const Vec256 v20{_mm256_shuffle_pd(lo.raw, hi.raw, 0)}; + return Vec256{ + _mm256_permute4x64_pd(v20.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ InterleaveWholeLower + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(32) static constexpr uint8_t kIdx[32] = { + 0, 32, 1, 33, 2, 34, 3, 35, 4, 36, 5, 37, 6, 38, 7, 39, + 8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47}; + return VFromD{_mm256_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + return ConcatLowerLower(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +#endif +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint16_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, + 4, 20, 5, 21, 6, 22, 7, 23}; + return BitCast( + d, VFromD{_mm256_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm256_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm256_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kIdx[4] = {0, 4, 1, 5}; + return VFromD{_mm256_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kIdx[4] = {0, 4, 1, 5}; + return VFromD{_mm256_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} +#else // AVX2 +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + return ConcatLowerLower(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +} +#endif + +// ------------------------------ InterleaveWholeUpper + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(32) static constexpr uint8_t kIdx[32] = { + 16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55, + 24, 56, 25, 57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63}; + return VFromD{_mm256_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + return ConcatUpperUpper(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +#endif +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint16_t kIdx[16] = { + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; + return BitCast( + d, VFromD{_mm256_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm256_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + 
const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm256_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kIdx[4] = {2, 6, 3, 7}; + return VFromD{_mm256_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kIdx[4] = {2, 6, 3, 7}; + return VFromD{_mm256_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} +#else // AVX2 +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + return ConcatUpperUpper(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +} +#endif + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +template +HWY_API Vec256 DupEven(const Vec256 v) { + const DFromV d; + return InterleaveLower(d, v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec256 DupOdd(const Vec256 v) { + const DFromV d; + return InterleaveUpper(d, v, v); +} + +// ------------------------------ OddEven + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + const DFromV d; + const Full256 d8; + const VFromD mask = + Dup128VecFromValues(d8, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, + 0, 0xFF, 0, 0xFF, 0); + return IfThenElse(MaskFromVec(BitCast(d, mask)), b, a); +} + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_blend_epi16( + BitCast(du, a).raw, BitCast(du, b).raw, 0x55)}); +} + +#if HWY_HAVE_FLOAT16 +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{ + _mm256_mask_blend_ph(static_cast<__mmask16>(0x5555), a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x55)}; +} + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x33)}; +} + +HWY_API Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_blend_ps(a.raw, b.raw, 0x55)}; +} + +HWY_API Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_blend_pd(a.raw, b.raw, 5)}; +} + +// -------------------------- InterleaveEven + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_epi32( + a.raw, static_cast<__mmask8>(0xAA), b.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_ps(a.raw, static_cast<__mmask8>(0xAA), + b.raw, b.raw, + _MM_SHUFFLE(2, 2, 0, 0))}; +} +#else +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const VFromD b2_b0_a2_a0{_mm256_shuffle_ps( + BitCast(df, a).raw, BitCast(df, 
b).raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return BitCast( + d, VFromD{_mm256_shuffle_ps( + b2_b0_a2_a0.raw, b2_b0_a2_a0.raw, _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// I64/U64/F64 InterleaveEven is generic for vector lengths >= 32 bytes +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// -------------------------- InterleaveOdd + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_epi32( + b.raw, static_cast<__mmask8>(0x55), a.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_ps(b.raw, static_cast<__mmask8>(0x55), + a.raw, a.raw, + _MM_SHUFFLE(3, 3, 1, 1))}; +} +#else +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const VFromD b3_b1_a3_a3{_mm256_shuffle_ps( + BitCast(df, a).raw, BitCast(df, b).raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return BitCast( + d, VFromD{_mm256_shuffle_ps( + b3_b1_a3_a3.raw, b3_b1_a3_a3.raw, _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// I64/U64/F64 InterleaveOdd is generic for vector lengths >= 32 bytes +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + +// ------------------------------ OddEvenBlocks + +template +Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_blend_epi32( + BitCast(du, odd).raw, BitCast(du, even).raw, 0xFu)}); +} + +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + return Vec256{_mm256_blend_ps(odd.raw, even.raw, 0xFu)}; +} + +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + return Vec256{_mm256_blend_pd(odd.raw, even.raw, 0x3u)}; +} + +// ------------------------------ ReverseBlocks (SwapAdjacentBlocks) + +template +HWY_API VFromD ReverseBlocks(D /*d*/, VFromD v) { + return SwapAdjacentBlocks(v); +} + +// ------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template +HWY_API Vec256 TableLookupBytes(Vec256 bytes, Vec256 from) { + const DFromV d; + return BitCast(d, Vec256{_mm256_shuffle_epi8( + BitCast(Full256(), bytes).raw, + BitCast(Full256(), from).raw)}); +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec256 bytes, Vec128 from) { + const Full256 di; + const Half dih; + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(di, Vec128{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128{LowerHalf(dih, tbl_full).raw}; +} + +// Partial table vector +template +HWY_API Vec256 TableLookupBytes(Vec128 bytes, Vec256 from) { + const Full256 d; + // First expand to full 128, then 256. + const auto bytes_256 = ZeroExtendVector(d, Vec128{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// Partial both are handled by x86_128. 
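+
+// Hypothetical usage sketch (user code, not part of this header): reversing
+// the bytes within each 128-bit block of a Vec256<uint8_t> v:
+//   const Full256<uint8_t> d;
+//   alignas(32) static constexpr uint8_t kRev[32] = {
+//       15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
+//       15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
+//   const Vec256<uint8_t> rev = TableLookupBytes(v, Load(d, kRev));
+// In-range indices select bytes only within their own block, because
+// _mm256_shuffle_epi8 shuffles each 128-bit half independently.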
+ +// ------------------------------ I8/U8 Broadcast (TableLookupBytes) + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return TableLookupBytes(v, Set(Full256(), static_cast(kLane))); +} + +// ------------------------------ Per4LaneBlockShuffle + +namespace detail { + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + return BitCast(d, Vec256{_mm256_set_epi32( + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0), + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0))}); +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + return V{_mm256_shuffle_epi32(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + return V{_mm256_shuffle_ps(v.raw, v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x44> /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + return ConcatLowerLower(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xEE> /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + return ConcatUpperUpper(d, v, v); +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + return V{_mm256_permute4x64_epi64(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + return V{_mm256_permute4x64_pd(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V CombineShiftRightI32Lanes(V hi, V lo) { + const DFromV d; + const Repartition du32; + return BitCast(d, + Vec256{_mm256_alignr_epi32( + BitCast(du32, hi).raw, BitCast(du32, lo).raw, kI32Lanes)}); +} + +template +HWY_INLINE V CombineShiftRightI64Lanes(V hi, V lo) { + const DFromV d; + const Repartition du64; + return BitCast(d, + Vec256{_mm256_alignr_epi64( + BitCast(du64, hi).raw, BitCast(du64, lo).raw, kI64Lanes)}); +} + +template +HWY_INLINE V SlideUpI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + const DFromV d; + return CombineShiftRightI64Lanes<4 - kI64Lanes>(v, Zero(d)); +} +#else // AVX2 +template )> +HWY_INLINE V SlideUpI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + constexpr int kIdx0 = (-kI64Lanes) & 3; + constexpr int kIdx1 = (-kI64Lanes + 1) & 3; + constexpr int kIdx2 = (-kI64Lanes + 2) & 3; + constexpr int kIdx3 = (-kI64Lanes + 3) & 3; + constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0); + constexpr int kBlendMask = (1 << (kI64Lanes * 2)) - 1; + + const DFromV d; + return V{_mm256_blend_epi32(_mm256_permute4x64_epi64(v.raw, kIdx3210), + Zero(d).raw, kBlendMask)}; +} + +template )> +HWY_INLINE V SlideUpI64Lanes(V v) { + static_assert(0 <= kI64Lanes && 
kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + constexpr int kIdx0 = (-kI64Lanes) & 3; + constexpr int kIdx1 = (-kI64Lanes + 1) & 3; + constexpr int kIdx2 = (-kI64Lanes + 2) & 3; + constexpr int kIdx3 = (-kI64Lanes + 3) & 3; + constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0); + constexpr int kBlendMask = (1 << kI64Lanes) - 1; + + const DFromV d; + const Repartition dd; + return BitCast(d, Vec256{_mm256_blend_pd( + _mm256_permute4x64_pd(BitCast(dd, v).raw, kIdx3210), + Zero(dd).raw, kBlendMask)}); +} +#endif // HWY_TARGET <= HWY_AVX3 + +template HWY_AVX3) ? (1 << 2) : 0))> +HWY_INLINE VFromD TableLookupSlideUpLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + + const auto idx_vec = + Iota(du8, static_cast(size_t{0} - amt * sizeof(TFromD))); + const Indices256> idx{idx_vec.raw}; + +#if HWY_TARGET <= HWY_AVX3_DL + return TwoTablesLookupLanes(v, Zero(d), idx); +#else + return TableLookupLanes(v, idx); +#endif +} + +template +HWY_INLINE VFromD TableLookupSlideUpLanes(D d, VFromD v, size_t amt) { + const RebindToUnsigned du; + using TU = TFromD; + + const auto idx = Iota(du, static_cast(size_t{0} - amt)); +#if HWY_TARGET <= HWY_AVX3 + const auto masked_idx = + And(idx, Set(du, static_cast(MaxLanes(d) * 2 - 1))); + return TwoTablesLookupLanes(v, Zero(d), IndicesFromVec(d, masked_idx)); +#else + const auto masked_idx = And(idx, Set(du, static_cast(MaxLanes(d) - 1))); + return IfThenElseZero(RebindMask(d, idx == masked_idx), + TableLookupLanes(v, IndicesFromVec(d, masked_idx))); +#endif +} + +#if HWY_TARGET > HWY_AVX3 +template +HWY_INLINE VFromD TableLookupSlideUpLanes(D d, VFromD v, size_t amt) { + const RepartitionToNarrow dn; + return BitCast(d, TableLookupSlideUpLanes(dn, BitCast(dn, v), amt * 2)); +} +#endif // HWY_TARGET > HWY_AVX3 + +} // namespace detail + +template +HWY_API VFromD SlideUpBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 1, + "kBlocks must be between 0 and 1"); + return (kBlocks == 1) ? 
ConcatLowerLower(d, v, Zero(d)) : v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + if (__builtin_constant_p(amt)) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + switch (amt * sizeof(TFromD)) { + case 0: + return v; + case 1: + return CombineShiftRightBytes<15>(d, v, v_lo); + case 2: + return CombineShiftRightBytes<14>(d, v, v_lo); + case 3: + return CombineShiftRightBytes<13>(d, v, v_lo); + case 4: +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<7>(v, Zero(d)); +#else + return CombineShiftRightBytes<12>(d, v, v_lo); +#endif + case 5: + return CombineShiftRightBytes<11>(d, v, v_lo); + case 6: + return CombineShiftRightBytes<10>(d, v, v_lo); + case 7: + return CombineShiftRightBytes<9>(d, v, v_lo); + case 8: + return detail::SlideUpI64Lanes<1>(v); + case 9: + return CombineShiftRightBytes<7>(d, v, v_lo); + case 10: + return CombineShiftRightBytes<6>(d, v, v_lo); + case 11: + return CombineShiftRightBytes<5>(d, v, v_lo); + case 12: +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<5>(v, Zero(d)); +#else + return CombineShiftRightBytes<4>(d, v, v_lo); +#endif + case 13: + return CombineShiftRightBytes<3>(d, v, v_lo); + case 14: + return CombineShiftRightBytes<2>(d, v, v_lo); + case 15: + return CombineShiftRightBytes<1>(d, v, v_lo); + case 16: + return ConcatLowerLower(d, v, Zero(d)); +#if HWY_TARGET <= HWY_AVX3 + case 20: + return detail::CombineShiftRightI32Lanes<3>(v, Zero(d)); +#endif + case 24: + return detail::SlideUpI64Lanes<3>(v); +#if HWY_TARGET <= HWY_AVX3 + case 28: + return detail::CombineShiftRightI32Lanes<1>(v, Zero(d)); +#endif + } + } + + if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { + const Half dh; + return Combine(d, SlideUpLanes(dh, LowerHalf(dh, v), amt - kLanesPerBlock), + Zero(dh)); + } +#endif + + return detail::TableLookupSlideUpLanes(d, v, amt); +} + +// ------------------------------ Slide1Up + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + return CombineShiftRightBytes<15>(d, v, v_lo); +} + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + return CombineShiftRightBytes<14>(d, v, v_lo); +} + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<7>(v, Zero(d)); +#else + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + return CombineShiftRightBytes<12>(d, v, v_lo); +#endif +} + +template +HWY_API VFromD Slide1Up(D /*d*/, VFromD v) { + return detail::SlideUpI64Lanes<1>(v); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V SlideDownI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + const DFromV d; + return CombineShiftRightI64Lanes(Zero(d), v); +} +#else // AVX2 +template )> +HWY_INLINE V SlideDownI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + constexpr int kIdx1 = (kI64Lanes + 1) & 3; + constexpr int kIdx2 = (kI64Lanes + 2) & 3; + constexpr int kIdx3 = (kI64Lanes + 3) & 3; + constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kI64Lanes); + constexpr int kBlendMask = + static_cast((0xFFu << ((4 - kI64Lanes) * 2)) & 0xFFu); + + const DFromV d; + return 
V{_mm256_blend_epi32(_mm256_permute4x64_epi64(v.raw, kIdx3210), + Zero(d).raw, kBlendMask)}; +} + +template )> +HWY_INLINE V SlideDownI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + constexpr int kIdx1 = (kI64Lanes + 1) & 3; + constexpr int kIdx2 = (kI64Lanes + 2) & 3; + constexpr int kIdx3 = (kI64Lanes + 3) & 3; + constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kI64Lanes); + constexpr int kBlendMask = (0x0F << (4 - kI64Lanes)) & 0x0F; + + const DFromV d; + const Repartition dd; + return BitCast(d, Vec256{_mm256_blend_pd( + _mm256_permute4x64_pd(BitCast(dd, v).raw, kIdx3210), + Zero(dd).raw, kBlendMask)}); +} +#endif // HWY_TARGET <= HWY_AVX3 + +template HWY_AVX3) ? (1 << 2) : 0))> +HWY_INLINE VFromD TableLookupSlideDownLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + + auto idx_vec = Iota(du8, static_cast(amt * sizeof(TFromD))); + +#if HWY_TARGET <= HWY_AVX3_DL + const auto result_mask = idx_vec < Set(du8, uint8_t{32}); + return VFromD{ + _mm256_maskz_permutexvar_epi8(result_mask.raw, idx_vec.raw, v.raw)}; +#else + const RebindToSigned di8; + idx_vec = + Or(idx_vec, BitCast(du8, VecFromMask(di8, BitCast(di8, idx_vec) > + Set(di8, int8_t{31})))); + return TableLookupLanes(v, Indices256>{idx_vec.raw}); +#endif +} + +template +HWY_INLINE VFromD TableLookupSlideDownLanes(D d, VFromD v, size_t amt) { + const RebindToUnsigned du; + using TU = TFromD; + + const auto idx = Iota(du, static_cast(amt)); + const auto masked_idx = And(idx, Set(du, static_cast(MaxLanes(d) - 1))); + + return IfThenElseZero(RebindMask(d, idx == masked_idx), + TableLookupLanes(v, IndicesFromVec(d, masked_idx))); +} + +#if HWY_TARGET > HWY_AVX3 +template +HWY_INLINE VFromD TableLookupSlideDownLanes(D d, VFromD v, size_t amt) { + const RepartitionToNarrow dn; + return BitCast(d, TableLookupSlideDownLanes(dn, BitCast(dn, v), amt * 2)); +} +#endif // HWY_TARGET > HWY_AVX3 + +} // namespace detail + +template +HWY_API VFromD SlideDownBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 1, + "kBlocks must be between 0 and 1"); + const Half dh; + return (kBlocks == 1) ? 
ZeroExtendVector(d, UpperHalf(dh, v)) : v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + const Half dh; + if (__builtin_constant_p(amt)) { + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + switch (amt * sizeof(TFromD)) { + case 0: + return v; + case 1: + return CombineShiftRightBytes<1>(d, v_hi, v); + case 2: + return CombineShiftRightBytes<2>(d, v_hi, v); + case 3: + return CombineShiftRightBytes<3>(d, v_hi, v); + case 4: +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<1>(Zero(d), v); +#else + return CombineShiftRightBytes<4>(d, v_hi, v); +#endif + case 5: + return CombineShiftRightBytes<5>(d, v_hi, v); + case 6: + return CombineShiftRightBytes<6>(d, v_hi, v); + case 7: + return CombineShiftRightBytes<7>(d, v_hi, v); + case 8: + return detail::SlideDownI64Lanes<1>(v); + case 9: + return CombineShiftRightBytes<9>(d, v_hi, v); + case 10: + return CombineShiftRightBytes<10>(d, v_hi, v); + case 11: + return CombineShiftRightBytes<11>(d, v_hi, v); + case 12: +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<3>(Zero(d), v); +#else + return CombineShiftRightBytes<12>(d, v_hi, v); +#endif + case 13: + return CombineShiftRightBytes<13>(d, v_hi, v); + case 14: + return CombineShiftRightBytes<14>(d, v_hi, v); + case 15: + return CombineShiftRightBytes<15>(d, v_hi, v); + case 16: + return v_hi; +#if HWY_TARGET <= HWY_AVX3 + case 20: + return detail::CombineShiftRightI32Lanes<5>(Zero(d), v); +#endif + case 24: + return detail::SlideDownI64Lanes<3>(v); +#if HWY_TARGET <= HWY_AVX3 + case 28: + return detail::CombineShiftRightI32Lanes<7>(Zero(d), v); +#endif + } + } + + if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { + return ZeroExtendVector( + d, SlideDownLanes(dh, UpperHalf(dh, v), amt - kLanesPerBlock)); + } +#endif + + return detail::TableLookupSlideDownLanes(d, v, amt); +} + +// ------------------------------ Slide1Down + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + const Half dh; + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + return CombineShiftRightBytes<1>(d, v_hi, v); +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + const Half dh; + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + return CombineShiftRightBytes<2>(d, v_hi, v); +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<1>(Zero(d), v); +#else + const Half dh; + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + return CombineShiftRightBytes<4>(d, v_hi, v); +#endif +} + +template +HWY_API VFromD Slide1Down(D /*d*/, VFromD v) { + return detail::SlideDownI64Lanes<1>(v); +} + +// ------------------------------ Shl (Mul, ZipLower) + +namespace detail { + +#if HWY_TARGET > HWY_AVX3 && !HWY_IDE // AVX2 or older +template +HWY_INLINE V AVX2ShlU16Vec256(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind du32; + + const auto lo_shl_result = PromoteTo(du32, LowerHalf(dh, v)) + << PromoteTo(du32, LowerHalf(dh, bits)); + const auto hi_shl_result = PromoteTo(du32, UpperHalf(dh, v)) + << PromoteTo(du32, UpperHalf(dh, bits)); + return ConcatEven(d, BitCast(d, hi_shl_result), BitCast(d, lo_shl_result)); +} +#endif + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 || HWY_IDE + return Vec256{_mm256_sllv_epi16(v.raw, bits.raw)}; 
+#else + return AVX2ShlU16Vec256(v, bits); +#endif +} + +// 8-bit: may use the Shl overload for uint16_t. +HWY_API Vec256 Shl(hwy::UnsignedTag tag, Vec256 v, + Vec256 bits) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3_DL + (void)tag; + // masks[i] = 0xFF >> i + const VFromD masks = + Dup128VecFromValues(d, 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0, + 0, 0, 0, 0, 0, 0, 0); + // kShl[i] = 1 << i + const VFromD shl = Dup128VecFromValues( + d, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 0, 0, 0, 0, 0, 0, 0, 0); + v = And(v, TableLookupBytes(masks, bits)); + const VFromD mul = TableLookupBytes(shl, bits); + return VFromD{_mm256_gf2p8mul_epi8(v.raw, mul.raw)}; +#else + const Repartition dw; + using VW = VFromD; + const VW even_mask = Set(dw, 0x00FF); + const VW odd_mask = Set(dw, 0xFF00); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + // Shift even lanes in-place + const VW evens = Shl(tag, vw, And(bits16, even_mask)); + const VW odds = Shl(tag, And(vw, odd_mask), ShiftRight<8>(bits16)); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +#endif +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{_mm256_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{_mm256_sllv_epi64(v.raw, bits.raw)}; +} + +template +HWY_INLINE Vec256 Shl(hwy::SignedTag /*tag*/, Vec256 v, Vec256 bits) { + // Signed left shifts are the same as unsigned. + const Full256 di; + const Full256> du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec256 operator<<(Vec256 v, Vec256 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr (MulHigh, IfThenElse, Not) + +#if HWY_TARGET > HWY_AVX3 // AVX2 +namespace detail { + +template +HWY_INLINE V AVX2ShrU16Vec256(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind di32; + const Rebind du32; + + const auto lo_shr_result = + PromoteTo(du32, LowerHalf(dh, v)) >> PromoteTo(du32, LowerHalf(dh, bits)); + const auto hi_shr_result = + PromoteTo(du32, UpperHalf(dh, v)) >> PromoteTo(du32, UpperHalf(dh, bits)); + return OrderedDemote2To(d, BitCast(di32, lo_shr_result), + BitCast(di32, hi_shr_result)); +} + +} // namespace detail +#endif + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srlv_epi16(v.raw, bits.raw)}; +#else + return detail::AVX2ShrU16Vec256(v, bits); +#endif +} + +// 8-bit uses 16-bit shifts. 
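+// Each u16 lane holds (odd << 8) | even. The even byte is isolated with the
+// 0x00FF mask before shifting; shifting the whole 16-bit lane by the odd
+// byte's count leaves that byte's result in the upper 8 bits, so OddEven can
+// recombine the two halves without further masking.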
+HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + const DFromV d; + const RepartitionToWide dw; + using VW = VFromD; + const VW mask = Set(dw, 0x00FF); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + const VW evens = And(vw, mask) >> And(bits16, mask); + // Shift odd lanes in-place + const VW odds = vw >> ShiftRight<8>(bits16); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srlv_epi64(v.raw, bits.raw)}; +} + +#if HWY_TARGET > HWY_AVX3 // AVX2 +namespace detail { + +template +HWY_INLINE V AVX2ShrI16Vec256(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind di32; + + const auto lo_shr_result = + PromoteTo(di32, LowerHalf(dh, v)) >> PromoteTo(di32, LowerHalf(dh, bits)); + const auto hi_shr_result = + PromoteTo(di32, UpperHalf(dh, v)) >> PromoteTo(di32, UpperHalf(dh, bits)); + return OrderedDemote2To(d, lo_shr_result, hi_shr_result); +} + +} // namespace detail +#endif + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srav_epi16(v.raw, bits.raw)}; +#else + return detail::AVX2ShrI16Vec256(v, bits); +#endif +} + +// 8-bit uses 16-bit shifts. +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + const DFromV d; + const RepartitionToWide dw; + const RebindToUnsigned dw_u; + using VW = VFromD; + const VW mask = Set(dw, 0x00FF); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + const VW evens = ShiftRight<8>(ShiftLeft<8>(vw)) >> And(bits16, mask); + // Shift odd lanes in-place + const VW odds = vw >> BitCast(dw, ShiftRight<8>(BitCast(dw_u, bits16))); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srav_epi64(v.raw, bits.raw)}; +#else + const DFromV d; + return detail::SignedShr(d, v, bits); +#endif +} + +// ------------------------------ WidenMulPairwiseAdd + +#if HWY_NATIVE_DOT_BF16 + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return VFromD{_mm256_dpbf16_ps(Zero(df).raw, + reinterpret_cast<__m256bh>(a.raw), + reinterpret_cast<__m256bh>(b.raw))}; +} + +#endif // HWY_NATIVE_DOT_BF16 + +template +HWY_API VFromD WidenMulPairwiseAdd(D /*d32*/, Vec256 a, + Vec256 b) { + return VFromD{_mm256_madd_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SatWidenMulPairwiseAdd + +template +HWY_API VFromD SatWidenMulPairwiseAdd( + DI16 /* tag */, VFromD> a, + VFromD> b) { + return VFromD{_mm256_maddubs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SatWidenMulPairwiseAccumulate + +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm256_dpwssds_epi32(sum.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ ReorderWidenMulAccumulate + +#if HWY_NATIVE_DOT_BF16 +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b, + const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD{_mm256_dpbf16_ps(sum0.raw, + reinterpret_cast<__m256bh>(a.raw), + reinterpret_cast<__m256bh>(b.raw))}; +} +#endif // HWY_NATIVE_DOT_BF16 + +template +HWY_API VFromD 
ReorderWidenMulAccumulate(D d, Vec256 a, + Vec256 b, + const VFromD sum0, + VFromD& /*sum1*/) { + (void)d; +#if HWY_TARGET <= HWY_AVX3_DL + return VFromD{_mm256_dpwssd_epi32(sum0.raw, a.raw, b.raw)}; +#else + return sum0 + WidenMulPairwiseAdd(d, a, b); +#endif +} + +// ------------------------------ RearrangeToOddPlusEven +HWY_API Vec256 RearrangeToOddPlusEven(const Vec256 sum0, + Vec256 /*sum1*/) { + return sum0; // invariant already holds +} + +HWY_API Vec256 RearrangeToOddPlusEven(const Vec256 sum0, + Vec256 /*sum1*/) { + return sum0; // invariant already holds +} + +// ------------------------------ SumOfMulQuadAccumulate + +#if HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD{_mm256_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; +} + +#endif + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtps_pd(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepi32_pd(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API Vec256 PromoteTo(D /* tag */, Vec128 v) { + return Vec256{_mm256_cvtepu32_pd(v.raw)}; +} +#endif + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepu8_epi16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepu8_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepu16_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepu32_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec64 v) { + return VFromD{_mm256_cvtepu16_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec32 v) { + return VFromD{_mm256_cvtepu8_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. 
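+// Hypothetical usage sketch (user code, not part of this header; src is a
+// hypothetical const int16_t*):
+//   const Full256<int32_t> d32;
+//   const Full128<int16_t> d16;
+//   const Vec256<int32_t> wide = PromoteTo(d32, LoadU(d16, src));
+// widens 8 int16_t lanes to 8 sign-extended int32_t lanes via
+// _mm256_cvtepi16_epi32.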
+template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepi8_epi16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepi8_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepi16_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepi32_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec64 v) { + return VFromD{_mm256_cvtepi16_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec32 v) { + return VFromD{_mm256_cvtepi8_epi64(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD PromoteInRangeTo(D /*di64*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttps_epi64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +template +HWY_API VFromD PromoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an uint64_t +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttps_epu64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +// I32->I64 PromoteEvenTo/PromoteOddTo + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec256 v) { + return BitCast(d_to, OddEven(DupEven(BroadcastSignBit(v)), v)); +} + +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec256 v) { + return BitCast(d_to, OddEven(BroadcastSignBit(v), DupOdd(v))); +} + +} // namespace detail +#endif + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template 
+HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i u16 = _mm256_packus_epi32(v.raw, v.raw); + // Concatenating lower halves of both 128-bit blocks afterward is more + // efficient than an extra input with low block = high block of v. + return VFromD{_mm256_castsi256_si128(_mm256_permute4x64_epi64(u16, 0x88))}; +} + +template +HWY_API VFromD DemoteTo(D dn, Vec256 v) { + const DFromV d; + const RebindToSigned di; + return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFFFFFu)))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i i16 = _mm256_packs_epi32(v.raw, v.raw); + return VFromD{_mm256_castsi256_si128(_mm256_permute4x64_epi64(i16, 0x88))}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88); + const __m128i i16 = _mm256_castsi256_si128(i16_concat); + return VFromD{_mm_packus_epi16(i16, i16)}; +} + +template +HWY_API VFromD DemoteTo(D dn, Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + (void)dn; + return VFromD{_mm256_cvtusepi32_epi8(v.raw)}; +#else + const DFromV d; + const RebindToSigned di; + return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFFFFFu)))); +#endif +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i u8 = _mm256_packus_epi16(v.raw, v.raw); + return VFromD{_mm256_castsi256_si128(_mm256_permute4x64_epi64(u8, 0x88))}; +} + +template +HWY_API VFromD DemoteTo(D dn, Vec256 v) { + const DFromV d; + const RebindToSigned di; + return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFu)))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88); + const __m128i i16 = _mm256_castsi256_si128(i16_concat); + return VFromD{_mm_packs_epi16(i16, i16)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i i8 = _mm256_packs_epi16(v.raw, v.raw); + return VFromD{_mm256_castsi256_si128(_mm256_permute4x64_epi64(i8, 0x88))}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtsepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtsepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtsepi64_epi8(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm256_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm256_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm256_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtusepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtusepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtusepi64_epi8(v.raw)}; +} 
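+// In the zero-masked i64 -> unsigned overloads above, UnmaskedNot(MaskFromVec(v))
+// is true exactly for non-negative lanes, so negative inputs are clamped to 0
+// by the masking while vcvtusepi64 saturates the remaining lanes to the
+// unsigned maximum.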
+#endif // HWY_TARGET <= HWY_AVX3 + +#ifndef HWY_DISABLE_F16C + +// Avoid "value of intrinsic immediate argument '8' is out of range '0 - 7'". +// 8 is the correct value of _MM_FROUND_NO_EXC, which is allowed here. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wsign-conversion") + +template +HWY_API VFromD DemoteTo(D df16, Vec256 v) { + const RebindToUnsigned du16; + return BitCast( + df16, VFromD{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); +} + +HWY_DIAGNOSTICS(pop) + +#endif // HWY_DISABLE_F16C + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD DemoteTo(D /*df16*/, Vec256 v) { + return VFromD{_mm256_cvtpd_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_AVX3_HAVE_F32_TO_BF16C +template +HWY_API VFromD DemoteTo(D /*dbf16*/, Vec256 v) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m128i raw_result; + __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + // The _mm256_cvtneps_pbh intrinsic returns a __m128bh vector that needs to be + // bit casted to a __m128i vector + return VFromD{detail::BitCastToInteger(_mm256_cvtneps_pbh(v.raw))}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D /*dbf16*/, Vec256 a, + Vec256 b) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m256i raw_result; + __asm__("vcvtne2ps2bf16 %2, %1, %0" + : "=v"(raw_result) + : "v"(b.raw), "v"(a.raw)); + return VFromD{raw_result}; +#else + // The _mm256_cvtne2ps_pbh intrinsic returns a __m256bh vector that needs to + // be bit casted to a __m256i vector + return VFromD{detail::BitCastToInteger(_mm256_cvtne2ps_pbh(b.raw, a.raw))}; +#endif +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_API VFromD ReorderDemote2To(D /*d16*/, Vec256 a, + Vec256 b) { + return VFromD{_mm256_packs_epi32(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /*d16*/, Vec256 a, + Vec256 b) { + return VFromD{_mm256_packus_epi32(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { + const DFromV d; + const RebindToSigned di; + const auto max_i32 = Set(d, 0x7FFFFFFFu); + return ReorderDemote2To(dn, BitCast(di, Min(a, max_i32)), + BitCast(di, Min(b, max_i32))); +} + +template +HWY_API VFromD ReorderDemote2To(D /*d16*/, Vec256 a, + Vec256 b) { + return VFromD{_mm256_packs_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /*d16*/, Vec256 a, + Vec256 b) { + return VFromD{_mm256_packus_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { + const DFromV d; + const RebindToSigned di; + const auto max_i16 = Set(d, 0x7FFFu); + return ReorderDemote2To(dn, BitCast(di, Min(a, max_i16)), + BitCast(di, Min(b, max_i16))); +} + +#if HWY_TARGET > HWY_AVX3 +template +HWY_API Vec256 ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { + const DFromV di64; + const RebindToUnsigned du64; + const Half dnh; + const Repartition dn_f; + + // Negative values are saturated by first saturating their bitwise inverse + // and then inverting the saturation result + const auto invert_mask_a = BitCast(du64, BroadcastSignBit(a)); + const auto invert_mask_b = BitCast(du64, BroadcastSignBit(b)); + const auto saturated_a = Xor( + invert_mask_a, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_a, BitCast(du64, a)))); + const auto saturated_b = Xor( + invert_mask_b, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_b, BitCast(du64, 
b)))); + + return BitCast(dn, + Vec256{_mm256_shuffle_ps(BitCast(dn_f, saturated_a).raw, + BitCast(dn_f, saturated_b).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} + +template +HWY_API Vec256 ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { + const DFromV di64; + const RebindToUnsigned du64; + const Half dnh; + const Repartition dn_f; + + const auto saturated_a = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(a), a))); + const auto saturated_b = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(b), b))); + + return BitCast(dn, + Vec256{_mm256_shuffle_ps(BitCast(dn_f, saturated_a).raw, + BitCast(dn_f, saturated_b).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { + const Half dnh; + const Repartition dn_f; + + const auto saturated_a = detail::DemoteFromU64Saturate(dnh, a); + const auto saturated_b = detail::DemoteFromU64Saturate(dnh, b); + + return BitCast(dn, + Vec256{_mm256_shuffle_ps(BitCast(dn_f, saturated_a).raw, + BitCast(dn_f, saturated_b).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} +#endif // HWY_TARGET > HWY_AVX3 + +template ), + HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2), + HWY_IF_T_SIZE_ONE_OF_V(V, + (1 << 1) | (1 << 2) | (1 << 4) | + ((HWY_TARGET > HWY_AVX3) ? (1 << 8) : 0))> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return VFromD{_mm256_permute4x64_epi64(ReorderDemote2To(d, a, b).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const Half dnh; + return Combine(dn, DemoteTo(dnh, b), DemoteTo(dnh, a)); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const Half dnh; + return Combine(dn, DemoteTo(dnh, b), DemoteTo(dnh, a)); +} + +template ), + HWY_IF_V_SIZE_GT_D(D, 16), class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2), + HWY_IF_T_SIZE_V(V, 8)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} +#endif + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtpd_ps(v.raw)}; +} + +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm256_cvttpd_epi32(v.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if 
HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm256_cvttpd_epu32(v.raw)}; +#endif +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm256_cvtepi64_ps(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm256_cvtepu64_ps(v.raw)}; +} +#endif + +// For already range-limited input [0, 255]. +HWY_API Vec128 U8FromU32(const Vec256 v) { + const Full256 d32; + const Full64 d8; + alignas(32) static constexpr uint32_t k8From32[8] = { + 0x0C080400u, ~0u, ~0u, ~0u, ~0u, 0x0C080400u, ~0u, ~0u}; + // Place first four bytes in lo[0], remaining 4 in hi[1]. + const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + // Interleave both quadruplets - OR instead of unpack reduces port5 pressure. + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Half(), quad); + return BitCast(d8, LowerHalf(lo | hi)); +} + +// ------------------------------ Truncations + +namespace detail { + +// LO and HI each hold four indices of bytes within a 128-bit block. +template +HWY_INLINE Vec128 LookupAndConcatHalves(Vec256 v) { + const Full256 d32; + +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) static constexpr uint32_t kMap[8] = { + LO, HI, 0x10101010 + LO, 0x10101010 + HI, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(Load(d32, kMap).raw, v.raw); +#else + alignas(32) static constexpr uint32_t kMap[8] = {LO, HI, ~0u, ~0u, + ~0u, ~0u, LO, HI}; + const auto quad = TableLookupBytes(v, Load(d32, kMap)); + const auto result = _mm256_permute4x64_epi64(quad.raw, 0xCC); + // Possible alternative: + // const auto lo = LowerHalf(quad); + // const auto hi = UpperHalf(Half(), quad); + // const auto result = lo | hi; +#endif + + return Vec128{_mm256_castsi256_si128(result)}; +} + +// LO and HI each hold two indices of bytes within a 128-bit block. 
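+// For example, <0x100, 0x908> (u64 -> u16 truncation below) decodes to byte
+// offsets {0, 1} and {8, 9}, i.e. the low 16 bits of each 64-bit lane, and
+// <0x400, 0xC08> (u32 -> u8) decodes to {0, 4} and {8, 12}, the low byte of
+// each 32-bit lane.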
+template +HWY_INLINE Vec128 LookupAndConcatQuarters(Vec256 v) { + const Full256 d16; + +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) static constexpr uint16_t kMap[16] = { + LO, HI, 0x1010 + LO, 0x1010 + HI, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(Load(d16, kMap).raw, v.raw); + return LowerHalf(Vec128{_mm256_castsi256_si128(result)}); +#else + constexpr uint16_t ff = static_cast(~0u); + alignas(32) static constexpr uint16_t kMap[16] = { + LO, ff, HI, ff, ff, ff, ff, ff, ff, ff, ff, ff, LO, ff, HI, ff}; + const auto quad = TableLookupBytes(v, Load(d16, kMap)); + const auto mixed = _mm256_permute4x64_epi64(quad.raw, 0xCC); + const auto half = _mm256_castsi256_si128(mixed); + return LowerHalf(Vec128{_mm_packus_epi32(half, half)}); +#endif +} + +} // namespace detail + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const Full256 d32; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) static constexpr uint32_t kMap[8] = {0x18100800u, 0, 0, 0, + 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(Load(d32, kMap).raw, v.raw); + return LowerHalf(LowerHalf(LowerHalf(Vec256{result}))); +#else + alignas(32) static constexpr uint32_t kMap[8] = {0xFFFF0800u, ~0u, ~0u, ~0u, + 0x0800FFFFu, ~0u, ~0u, ~0u}; + const auto quad = TableLookupBytes(v, Load(d32, kMap)); + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Half(), quad); + const auto result = lo | hi; + return LowerHalf(LowerHalf(Vec128{result.raw})); +#endif +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const auto result = detail::LookupAndConcatQuarters<0x100, 0x908>(v); + return VFromD{result.raw}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const Full256 d32; + alignas(32) static constexpr uint32_t kEven[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto v32 = + TableLookupLanes(BitCast(d32, v), SetTableIndices(d32, kEven)); + return LowerHalf(Vec256{v32.raw}); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const auto full = detail::LookupAndConcatQuarters<0x400, 0xC08>(v); + return VFromD{full.raw}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const auto full = detail::LookupAndConcatHalves<0x05040100, 0x0D0C0908>(v); + return VFromD{full.raw}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const auto full = detail::LookupAndConcatHalves<0x06040200, 0x0E0C0A08>(v); + return VFromD{full.raw}; +} + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtepu16_ph(v.raw)}; +} +template +HWY_API VFromD ConvertTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtepi16_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD ConvertTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtepi32_ps(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD ConvertTo(D /*df*/, Vec256 v) { + return VFromD{_mm256_cvtepu32_ps(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*dd*/, Vec256 v) { + return VFromD{_mm256_cvtepi64_pd(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*dd*/, Vec256 v) { + return VFromD{_mm256_cvtepu64_pd(v.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3 + +// Truncates (rounds toward zero). 
+ +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi16( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]), + detail::X86ConvertScalarFromFloat(raw_v[8]), + detail::X86ConvertScalarFromFloat(raw_v[9]), + detail::X86ConvertScalarFromFloat(raw_v[10]), + detail::X86ConvertScalarFromFloat(raw_v[11]), + detail::X86ConvertScalarFromFloat(raw_v[12]), + detail::X86ConvertScalarFromFloat(raw_v[13]), + detail::X86ConvertScalarFromFloat(raw_v[14]), + detail::X86ConvertScalarFromFloat(raw_v[15]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // HWY_COMPILER_GCC_ACTUAL < 1200 + return VFromD{_mm256_cvttph_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttph_epu16 with GCC if any + // values of v[i] are not within the range of an uint16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi16( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[8])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[9])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[10])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[11])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[12])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[13])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[14])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[15])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttph2uw {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // HWY_COMPILER_GCC_ACTUAL < 1200 + return VFromD{_mm256_cvttph_epu16(v.raw)}; +#endif +} +#endif // HWY_HAVE_FLOAT16 + 
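// ---------------------------------------------------------------------------
// [Editor's illustrative sketch; not part of the upstream Highway sources or
// of this patch.] The inline-asm workarounds above and below exist because a
// float-to-integer conversion whose truncated value does not fit the
// destination type is undefined behavior in C++, and the x86 truncating
// converts (vcvttps2dq and friends) instead return the "integer indefinite"
// sentinel. A caller that cannot guarantee its inputs are in range would
// clamp first, roughly as in this hypothetical scalar helper (kept inside a
// comment so the patched header itself is unchanged):
//
//   #include <cmath>
//   #include <cstdint>
//   #include <limits>
//
//   static inline int32_t ClampedF32ToI32(float f) {
//     if (std::isnan(f)) return 0;  // choose a defined result for NaN
//     if (f >= 2147483648.0f) return std::numeric_limits<int32_t>::max();
//     if (f < -2147483648.0f) return std::numeric_limits<int32_t>::min();
//     return static_cast<int32_t>(f);  // in range: truncates toward zero
//   }
// ---------------------------------------------------------------------------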
+template +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm256_cvttps_epi32(v.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD ConvertInRangeTo(D /*di*/, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttpd_epi64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttps_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttps_epu32(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +template +HWY_API 
VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epu64 with GCC if any + // values of v[i] are not within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttpd_epu64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +#endif // HWY_TARGET <= HWY_AVX3 + +template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtps_epi32 if any values of + // v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi32(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtps_epi32(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} + +#if HWY_HAVE_FLOAT16 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*d*/, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi16(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]), + detail::X86ScalarNearestInt(raw_v[8]), + detail::X86ScalarNearestInt(raw_v[9]), + detail::X86ScalarNearestInt(raw_v[10]), + detail::X86ScalarNearestInt(raw_v[11]), + detail::X86ScalarNearestInt(raw_v[12]), + detail::X86ScalarNearestInt(raw_v[13]), + detail::X86ScalarNearestInt(raw_v[14]), + detail::X86ScalarNearestInt(raw_v[15]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtph2w {%1, 
%0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtph_epi16(v.raw)}; +#endif +} +#endif + +#if HWY_TARGET <= HWY_AVX3 +template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi64x(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtpd_epi64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +#endif // HWY_TARGET <= HWY_AVX3 + +template +static HWY_INLINE VFromD DemoteToNearestIntInRange( + DI, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvtpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtpd_epi32(v.raw)}; +#endif +} + +#ifndef HWY_DISABLE_F16C + +template +HWY_API VFromD PromoteTo(D df32, Vec128 v) { + (void)df32; +#if HWY_HAVE_FLOAT16 + const RebindToUnsigned> du16; + return VFromD{_mm256_cvtph_ps(BitCast(du16, v).raw)}; +#else + return VFromD{_mm256_cvtph_ps(v.raw)}; +#endif // HWY_HAVE_FLOAT16 +} + +#endif // HWY_DISABLE_F16C + +#if HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, Vec64 v) { + return VFromD{_mm256_cvtph_pd(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD PromoteTo(D df32, Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) + +HWY_API Vec256 AESRound(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_aesenc_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256 AESLastRound(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return 
Vec256{_mm256_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256 AESRoundInv(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_aesdec_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, AESRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRoundInv(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256 AESLastRoundInv(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_aesdeclast_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine( + d, AESLastRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRoundInv(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +template )> +HWY_API V AESInvMixColumns(V state) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3_DL + // On AVX3_DL, it is more efficient to do an InvMixColumns operation for a + // 256-bit or 512-bit vector by doing a AESLastRound operation + // (_mm256_aesenclast_epi128/_mm512_aesenclast_epi128) followed by a + // AESRoundInv operation (_mm256_aesdec_epi128/_mm512_aesdec_epi128) than to + // split the vector into 128-bit vectors, carrying out multiple + // _mm_aesimc_si128 operations, and then combining the _mm_aesimc_si128 + // results back into a 256-bit or 512-bit vector. + const auto zero = Zero(d); + return AESRoundInv(AESLastRound(state, zero), zero); +#else + const Half dh; + return Combine(d, AESInvMixColumns(UpperHalf(dh, state)), + AESInvMixColumns(LowerHalf(dh, state))); +#endif +} + +template +HWY_API Vec256 AESKeyGenAssist(Vec256 v) { + const Full256 d; +#if HWY_TARGET <= HWY_AVX3_DL + const VFromD rconXorMask = Dup128VecFromValues( + d, 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0); + const VFromD rotWordShuffle = Dup128VecFromValues( + d, 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12); + const Repartition du32; + const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); + const auto sub_word_result = AESLastRound(w13, rconXorMask); + return TableLookupBytes(sub_word_result, rotWordShuffle); +#else + const Half d2; + return Combine(d, AESKeyGenAssist(UpperHalf(d2, v)), + AESKeyGenAssist(LowerHalf(v))); +#endif +} + +HWY_API Vec256 CLMulLower(Vec256 a, Vec256 b) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x00)}; +#else + const Full256 d; + const Half d2; + return Combine(d, CLMulLower(UpperHalf(d2, a), UpperHalf(d2, b)), + CLMulLower(LowerHalf(a), LowerHalf(b))); +#endif +} + +HWY_API Vec256 CLMulUpper(Vec256 a, Vec256 b) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x11)}; +#else + const Full256 d; + const Half d2; + return Combine(d, CLMulUpper(UpperHalf(d2, a), UpperHalf(d2, b)), + CLMulUpper(LowerHalf(a), LowerHalf(b))); +#endif +} + +#endif // HWY_DISABLE_PCLMUL_AES + +// ================================================== MISC + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ LoadMaskBits + +// `p` points to at least 8 readable bytes, not all of which need be valid. 
+template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + return MFromD::FromBits(mask_bits); +} + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + + CopyBytes(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (kN < 8) { + const int mask_bits = static_cast((1ull << kN) - 1); + bits[0] = static_cast(bits[0] & mask_bits); + } + return kNumBytes; +} + +// ------------------------------ Mask testing + +template +HWY_API size_t CountTrue(D /* tag */, MFromD mask) { + return PopCount(static_cast(mask.raw)); +} + +template +HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD mask) { + return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + return mask.raw ? static_cast(FindKnownFirstTrue(d, mask)) + : intptr_t{-1}; +} + +template +HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD mask) { + return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + return mask.raw ? static_cast(FindKnownLastTrue(d, mask)) + : intptr_t{-1}; +} + +// Beware: the suffix indicates the number of mask bits, not lane size! + +namespace detail { + +template +HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { + return (uint64_t{mask.raw} & 0xF) == 0; +} + +} // namespace detail + +template +HWY_API bool AllFalse(D /* tag */, MFromD mask) { + return detail::AllFalse(hwy::SizeTag)>(), mask); +} + +namespace detail { + +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { + // Cannot use _kortestc because we have less than 8 mask bits. 
+ return mask.raw == 0xFu; +} + +} // namespace detail + +template +HWY_API bool AllTrue(D /* tag */, const MFromD mask) { + return detail::AllTrue(hwy::SizeTag)>(), mask); +} + +// ------------------------------ Compress + +// 16-bit is defined in x86_512 so we can use 512-bit vectors. + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; +} + +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + return Vec256{_mm256_maskz_compress_ps(mask.raw, v.raw)}; +} + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + // See CompressIsPartition. + alignas(16) static constexpr uint64_t packed_array[16] = { + // PrintCompress64x4NibbleTables + 0x00003210, 0x00003210, 0x00003201, 0x00003210, 0x00003102, 0x00003120, + 0x00003021, 0x00003210, 0x00002103, 0x00002130, 0x00002031, 0x00002310, + 0x00001032, 0x00001320, 0x00000321, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm256_permutexvar_epi64 will ignore the upper bits. + const DFromV d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) static constexpr uint64_t shifts[4] = {0, 4, 8, 12}; + const auto indices = Indices256{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressNot (Compress) + +// Implemented in x86_512 for lane size != 8. + +template +HWY_API Vec256 CompressNot(Vec256 v, Mask256 mask) { + // See CompressIsPartition. + alignas(16) static constexpr uint64_t packed_array[16] = { + // PrintCompressNot64x4NibbleTables + 0x00003210, 0x00000321, 0x00001320, 0x00001032, 0x00002310, 0x00002031, + 0x00002130, 0x00002103, 0x00003210, 0x00003021, 0x00003120, 0x00003102, + 0x00003210, 0x00003201, 0x00003210, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm256_permutexvar_epi64 will ignore the upper bits. + const DFromV d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(32) static constexpr uint64_t shifts[4] = {0, 4, 8, 12}; + const auto indices = Indices256{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressStore (defined in x86_512) +// ------------------------------ CompressBlendedStore (defined in x86_512) +// ------------------------------ CompressBitsStore (defined in x86_512) + +#else // AVX2 + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// 256 suffix avoids ambiguity with x86_128 without needing HWY_IF_V_SIZE. +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + const Repartition du32; + const auto vbits = BitCast(du, Set(du32, static_cast(mask_bits))); + + // Replicate bytes 8x such that each byte contains the bit that governs it. 
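// [Editor's note: illustrative clarification; not part of the upstream
// patch.] The kRep8 table below repeats each of the four mask bytes eight
// times, so vector byte i ends up holding mask byte i/8; TestBit against the
// repeating {1, 2, 4, ..., 128} pattern then selects bit i%8 of that byte,
// i.e. lane i of the result is true exactly when bit i of mask_bits is set.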
+ const Repartition du64; + alignas(32) static constexpr uint64_t kRep8[4] = { + 0x0000000000000000ull, 0x0101010101010101ull, 0x0202020202020202ull, + 0x0303030303030303ull}; + const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8))); + + const VFromD bit = Dup128VecFromValues( + du, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return RebindMask(d, TestBit(rep8, bit)); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + alignas(32) static constexpr uint16_t kBit[16] = { + 1, 2, 4, 8, 16, 32, 64, 128, + 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kBit[8] = {1, 2, 4, 8}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + return detail::LoadMaskBits256>(mask_bits); +} + +// ------------------------------ BitsFromMask + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToUnsigned d8; + const auto sign_bits = BitCast(d8, VecFromMask(d, mask)).raw; + // Prevent sign-extension of 32-bit masks because the intrinsic returns int. + return static_cast(_mm256_movemask_epi8(sign_bits)); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { +#if !defined(HWY_DISABLE_BMI2_FMA) && !defined(HWY_DISABLE_PEXT_ON_AVX2) + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + const uint64_t sign_bits8 = BitsFromMask(d8, mask8); + // Skip the bits from the lower byte of each u16 (better not to use the + // same packs_epi16 as SSE4, because that requires an extra swizzle here). + return _pext_u32(static_cast(sign_bits8), 0xAAAAAAAAu); +#else + // Slow workaround for when BMI2 is disabled + // Remove useless lower half of each u16 while preserving the sign bit. + // Bytes [0, 8) and [16, 24) have the same sign bits as the input lanes. + const auto sign_bits = _mm256_packs_epi16(mask.raw, _mm256_setzero_si256()); + // Move odd qwords (value zero) to top so they don't affect the mask value. 
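// [Editor's note: illustrative clarification; not part of the upstream
// patch.] _mm256_packs_epi16 packs within each 128-bit half, which is why
// only bytes [0, 8) and [16, 24) carry the sign bits; _MM_SHUFFLE(3, 1, 2, 0)
// below moves those two qwords into the low 128 bits so that a single
// _mm_movemask_epi8 yields the 16 mask bits.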
+ const auto compressed = _mm256_castsi256_si128( + _mm256_permute4x64_epi64(sign_bits, _MM_SHUFFLE(3, 1, 2, 0))); + return static_cast(_mm_movemask_epi8(compressed)); +#endif // HWY_ARCH_X86_64 +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToFloat df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast(_mm256_movemask_ps(sign_bits)); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToFloat df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast(_mm256_movemask_pd(sign_bits)); +} + +// ------------------------------ StoreMaskBits +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + HWY_LANES_CONSTEXPR size_t kNumBytes = (N + 7) / 8; + + const uint64_t mask_bits = BitsFromMask(d, mask); + CopyBytes(&mask_bits, bits, kNumBytes); + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Specialize for 16-bit lanes to avoid unnecessary pext. This assumes each mask +// lane is 0 or ~0. +template +HWY_API bool AllFalse(D d, MFromD mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return BitsFromMask(d8, mask8) == 0; +} + +template +HWY_API bool AllFalse(D d, MFromD mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return BitsFromMask(d, mask) == 0; +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return BitsFromMask(d8, mask8) == (1ull << 32) - 1; +} +template +HWY_API bool AllTrue(D d, MFromD mask) { + constexpr uint64_t kAllBits = (1ull << MaxLanes(d)) - 1; + return BitsFromMask(d, mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return PopCount(BitsFromMask(d8, mask8)) >> 1; +} +template +HWY_API size_t CountTrue(D d, MFromD mask) { + return PopCount(BitsFromMask(d, mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return Num0BitsBelowLS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits)) + : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +template +HWY_INLINE Vec256 IndicesFromBits256(uint64_t mask_bits) { + const Full256 d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. We instead compress each index into 4 + // bits, for a total of 1 KiB. 
+ alignas(16) static constexpr uint32_t packed_array[256] = { + // PrintCompress32x8Tables + 0x76543210, 0x76543218, 0x76543209, 0x76543298, 0x7654310a, 0x765431a8, + 0x765430a9, 0x76543a98, 0x7654210b, 0x765421b8, 0x765420b9, 0x76542b98, + 0x765410ba, 0x76541ba8, 0x76540ba9, 0x7654ba98, 0x7653210c, 0x765321c8, + 0x765320c9, 0x76532c98, 0x765310ca, 0x76531ca8, 0x76530ca9, 0x7653ca98, + 0x765210cb, 0x76521cb8, 0x76520cb9, 0x7652cb98, 0x76510cba, 0x7651cba8, + 0x7650cba9, 0x765cba98, 0x7643210d, 0x764321d8, 0x764320d9, 0x76432d98, + 0x764310da, 0x76431da8, 0x76430da9, 0x7643da98, 0x764210db, 0x76421db8, + 0x76420db9, 0x7642db98, 0x76410dba, 0x7641dba8, 0x7640dba9, 0x764dba98, + 0x763210dc, 0x76321dc8, 0x76320dc9, 0x7632dc98, 0x76310dca, 0x7631dca8, + 0x7630dca9, 0x763dca98, 0x76210dcb, 0x7621dcb8, 0x7620dcb9, 0x762dcb98, + 0x7610dcba, 0x761dcba8, 0x760dcba9, 0x76dcba98, 0x7543210e, 0x754321e8, + 0x754320e9, 0x75432e98, 0x754310ea, 0x75431ea8, 0x75430ea9, 0x7543ea98, + 0x754210eb, 0x75421eb8, 0x75420eb9, 0x7542eb98, 0x75410eba, 0x7541eba8, + 0x7540eba9, 0x754eba98, 0x753210ec, 0x75321ec8, 0x75320ec9, 0x7532ec98, + 0x75310eca, 0x7531eca8, 0x7530eca9, 0x753eca98, 0x75210ecb, 0x7521ecb8, + 0x7520ecb9, 0x752ecb98, 0x7510ecba, 0x751ecba8, 0x750ecba9, 0x75ecba98, + 0x743210ed, 0x74321ed8, 0x74320ed9, 0x7432ed98, 0x74310eda, 0x7431eda8, + 0x7430eda9, 0x743eda98, 0x74210edb, 0x7421edb8, 0x7420edb9, 0x742edb98, + 0x7410edba, 0x741edba8, 0x740edba9, 0x74edba98, 0x73210edc, 0x7321edc8, + 0x7320edc9, 0x732edc98, 0x7310edca, 0x731edca8, 0x730edca9, 0x73edca98, + 0x7210edcb, 0x721edcb8, 0x720edcb9, 0x72edcb98, 0x710edcba, 0x71edcba8, + 0x70edcba9, 0x7edcba98, 0x6543210f, 0x654321f8, 0x654320f9, 0x65432f98, + 0x654310fa, 0x65431fa8, 0x65430fa9, 0x6543fa98, 0x654210fb, 0x65421fb8, + 0x65420fb9, 0x6542fb98, 0x65410fba, 0x6541fba8, 0x6540fba9, 0x654fba98, + 0x653210fc, 0x65321fc8, 0x65320fc9, 0x6532fc98, 0x65310fca, 0x6531fca8, + 0x6530fca9, 0x653fca98, 0x65210fcb, 0x6521fcb8, 0x6520fcb9, 0x652fcb98, + 0x6510fcba, 0x651fcba8, 0x650fcba9, 0x65fcba98, 0x643210fd, 0x64321fd8, + 0x64320fd9, 0x6432fd98, 0x64310fda, 0x6431fda8, 0x6430fda9, 0x643fda98, + 0x64210fdb, 0x6421fdb8, 0x6420fdb9, 0x642fdb98, 0x6410fdba, 0x641fdba8, + 0x640fdba9, 0x64fdba98, 0x63210fdc, 0x6321fdc8, 0x6320fdc9, 0x632fdc98, + 0x6310fdca, 0x631fdca8, 0x630fdca9, 0x63fdca98, 0x6210fdcb, 0x621fdcb8, + 0x620fdcb9, 0x62fdcb98, 0x610fdcba, 0x61fdcba8, 0x60fdcba9, 0x6fdcba98, + 0x543210fe, 0x54321fe8, 0x54320fe9, 0x5432fe98, 0x54310fea, 0x5431fea8, + 0x5430fea9, 0x543fea98, 0x54210feb, 0x5421feb8, 0x5420feb9, 0x542feb98, + 0x5410feba, 0x541feba8, 0x540feba9, 0x54feba98, 0x53210fec, 0x5321fec8, + 0x5320fec9, 0x532fec98, 0x5310feca, 0x531feca8, 0x530feca9, 0x53feca98, + 0x5210fecb, 0x521fecb8, 0x520fecb9, 0x52fecb98, 0x510fecba, 0x51fecba8, + 0x50fecba9, 0x5fecba98, 0x43210fed, 0x4321fed8, 0x4320fed9, 0x432fed98, + 0x4310feda, 0x431feda8, 0x430feda9, 0x43feda98, 0x4210fedb, 0x421fedb8, + 0x420fedb9, 0x42fedb98, 0x410fedba, 0x41fedba8, 0x40fedba9, 0x4fedba98, + 0x3210fedc, 0x321fedc8, 0x320fedc9, 0x32fedc98, 0x310fedca, 0x31fedca8, + 0x30fedca9, 0x3fedca98, 0x210fedcb, 0x21fedcb8, 0x20fedcb9, 0x2fedcb98, + 0x10fedcba, 0x1fedcba8, 0x0fedcba9, 0xfedcba98}; + + // No need to mask because _mm256_permutevar8x32_epi32 ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. 
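// [Editor's note: worked example; not part of the upstream patch.] For
// mask_bits == 0x05 (lanes 0 and 2 true), packed_array[5] above is
// 0x765431a8. Reading nibbles from the low end gives 8, a, 1, 3, 4, 5, 6, 7;
// since the permute ignores bit 3 of each index, the output order is lanes
// {0, 2, 1, 3, 4, 5, 6, 7}, i.e. the selected lanes first (a partition),
// with the nibble MSB marking the first PopCount(mask_bits) lanes for the
// FirstN handling in CompressBlendedStore further below.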
+ const auto packed = Set(d32, packed_array[mask_bits]); + alignas(32) static constexpr uint32_t shifts[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template +HWY_INLINE Vec256 IndicesFromBits256(uint64_t mask_bits) { + const Full256 d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) static constexpr uint32_t u32_indices[128] = { + // PrintCompress64x4PairTables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, + 10, 11, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 4, 5, 6, 7, + 12, 13, 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 2, 3, 6, 7, + 10, 11, 12, 13, 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 6, 7, + 14, 15, 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, 2, 3, 4, 5, + 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, 4, 5, + 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, 2, 3, + 10, 11, 12, 13, 14, 15, 0, 1, 8, 9, 10, 11, 12, 13, 14, 15}; + return Load(d32, u32_indices + 8 * mask_bits); +} + +template +HWY_INLINE Vec256 IndicesFromNotBits256(uint64_t mask_bits) { + const Full256 d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. We instead compress each index into 4 + // bits, for a total of 1 KiB. + alignas(16) static constexpr uint32_t packed_array[256] = { + // PrintCompressNot32x8Tables + 0xfedcba98, 0x8fedcba9, 0x9fedcba8, 0x98fedcba, 0xafedcb98, 0xa8fedcb9, + 0xa9fedcb8, 0xa98fedcb, 0xbfedca98, 0xb8fedca9, 0xb9fedca8, 0xb98fedca, + 0xbafedc98, 0xba8fedc9, 0xba9fedc8, 0xba98fedc, 0xcfedba98, 0xc8fedba9, + 0xc9fedba8, 0xc98fedba, 0xcafedb98, 0xca8fedb9, 0xca9fedb8, 0xca98fedb, + 0xcbfeda98, 0xcb8feda9, 0xcb9feda8, 0xcb98feda, 0xcbafed98, 0xcba8fed9, + 0xcba9fed8, 0xcba98fed, 0xdfecba98, 0xd8fecba9, 0xd9fecba8, 0xd98fecba, + 0xdafecb98, 0xda8fecb9, 0xda9fecb8, 0xda98fecb, 0xdbfeca98, 0xdb8feca9, + 0xdb9feca8, 0xdb98feca, 0xdbafec98, 0xdba8fec9, 0xdba9fec8, 0xdba98fec, + 0xdcfeba98, 0xdc8feba9, 0xdc9feba8, 0xdc98feba, 0xdcafeb98, 0xdca8feb9, + 0xdca9feb8, 0xdca98feb, 0xdcbfea98, 0xdcb8fea9, 0xdcb9fea8, 0xdcb98fea, + 0xdcbafe98, 0xdcba8fe9, 0xdcba9fe8, 0xdcba98fe, 0xefdcba98, 0xe8fdcba9, + 0xe9fdcba8, 0xe98fdcba, 0xeafdcb98, 0xea8fdcb9, 0xea9fdcb8, 0xea98fdcb, + 0xebfdca98, 0xeb8fdca9, 0xeb9fdca8, 0xeb98fdca, 0xebafdc98, 0xeba8fdc9, + 0xeba9fdc8, 0xeba98fdc, 0xecfdba98, 0xec8fdba9, 0xec9fdba8, 0xec98fdba, + 0xecafdb98, 0xeca8fdb9, 0xeca9fdb8, 0xeca98fdb, 0xecbfda98, 0xecb8fda9, + 0xecb9fda8, 0xecb98fda, 0xecbafd98, 0xecba8fd9, 0xecba9fd8, 0xecba98fd, + 0xedfcba98, 0xed8fcba9, 0xed9fcba8, 0xed98fcba, 0xedafcb98, 0xeda8fcb9, + 0xeda9fcb8, 0xeda98fcb, 0xedbfca98, 0xedb8fca9, 0xedb9fca8, 0xedb98fca, + 0xedbafc98, 0xedba8fc9, 0xedba9fc8, 0xedba98fc, 0xedcfba98, 0xedc8fba9, + 0xedc9fba8, 0xedc98fba, 0xedcafb98, 0xedca8fb9, 0xedca9fb8, 0xedca98fb, + 0xedcbfa98, 0xedcb8fa9, 0xedcb9fa8, 0xedcb98fa, 0xedcbaf98, 0xedcba8f9, + 0xedcba9f8, 0xedcba98f, 0xfedcba98, 0xf8edcba9, 0xf9edcba8, 0xf98edcba, + 0xfaedcb98, 0xfa8edcb9, 0xfa9edcb8, 0xfa98edcb, 0xfbedca98, 0xfb8edca9, + 0xfb9edca8, 0xfb98edca, 0xfbaedc98, 0xfba8edc9, 0xfba9edc8, 0xfba98edc, + 0xfcedba98, 0xfc8edba9, 0xfc9edba8, 0xfc98edba, 0xfcaedb98, 0xfca8edb9, + 0xfca9edb8, 0xfca98edb, 0xfcbeda98, 0xfcb8eda9, 0xfcb9eda8, 0xfcb98eda, + 0xfcbaed98, 
0xfcba8ed9, 0xfcba9ed8, 0xfcba98ed, 0xfdecba98, 0xfd8ecba9, + 0xfd9ecba8, 0xfd98ecba, 0xfdaecb98, 0xfda8ecb9, 0xfda9ecb8, 0xfda98ecb, + 0xfdbeca98, 0xfdb8eca9, 0xfdb9eca8, 0xfdb98eca, 0xfdbaec98, 0xfdba8ec9, + 0xfdba9ec8, 0xfdba98ec, 0xfdceba98, 0xfdc8eba9, 0xfdc9eba8, 0xfdc98eba, + 0xfdcaeb98, 0xfdca8eb9, 0xfdca9eb8, 0xfdca98eb, 0xfdcbea98, 0xfdcb8ea9, + 0xfdcb9ea8, 0xfdcb98ea, 0xfdcbae98, 0xfdcba8e9, 0xfdcba9e8, 0xfdcba98e, + 0xfedcba98, 0xfe8dcba9, 0xfe9dcba8, 0xfe98dcba, 0xfeadcb98, 0xfea8dcb9, + 0xfea9dcb8, 0xfea98dcb, 0xfebdca98, 0xfeb8dca9, 0xfeb9dca8, 0xfeb98dca, + 0xfebadc98, 0xfeba8dc9, 0xfeba9dc8, 0xfeba98dc, 0xfecdba98, 0xfec8dba9, + 0xfec9dba8, 0xfec98dba, 0xfecadb98, 0xfeca8db9, 0xfeca9db8, 0xfeca98db, + 0xfecbda98, 0xfecb8da9, 0xfecb9da8, 0xfecb98da, 0xfecbad98, 0xfecba8d9, + 0xfecba9d8, 0xfecba98d, 0xfedcba98, 0xfed8cba9, 0xfed9cba8, 0xfed98cba, + 0xfedacb98, 0xfeda8cb9, 0xfeda9cb8, 0xfeda98cb, 0xfedbca98, 0xfedb8ca9, + 0xfedb9ca8, 0xfedb98ca, 0xfedbac98, 0xfedba8c9, 0xfedba9c8, 0xfedba98c, + 0xfedcba98, 0xfedc8ba9, 0xfedc9ba8, 0xfedc98ba, 0xfedcab98, 0xfedca8b9, + 0xfedca9b8, 0xfedca98b, 0xfedcba98, 0xfedcb8a9, 0xfedcb9a8, 0xfedcb98a, + 0xfedcba98, 0xfedcba89, 0xfedcba98, 0xfedcba98}; + + // No need to mask because <_mm256_permutevar8x32_epi32> ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. + const Vec256 packed = Set(d32, packed_array[mask_bits]); + alignas(32) static constexpr uint32_t shifts[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template +HWY_INLINE Vec256 IndicesFromNotBits256(uint64_t mask_bits) { + const Full256 d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) static constexpr uint32_t u32_indices[128] = { + // PrintCompressNot64x4PairTables + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, + 8, 9, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, + 8, 9, 10, 11, 14, 15, 12, 13, 10, 11, 14, 15, 8, 9, 12, 13, + 8, 9, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 8, 9, 14, 15, + 8, 9, 12, 13, 10, 11, 14, 15, 12, 13, 8, 9, 10, 11, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 8, 9, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}; + return Load(d32, u32_indices + 8 * mask_bits); +} + +template +HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { + const DFromV d; + const Repartition du32; + + HWY_DASSERT(mask_bits < (1ull << Lanes(d))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). + const Indices256 indices{IndicesFromBits256(mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. 
+template +HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + const auto vu16 = BitCast(du, v); // (required for float16_t inputs) + const Half duh; + const auto half0 = LowerHalf(duh, vu16); + const auto half1 = UpperHalf(duh, vu16); + + const uint64_t mask_bits0 = mask_bits & 0xFF; + const uint64_t mask_bits1 = mask_bits >> 8; + const auto compressed0 = detail::CompressBits(half0, mask_bits0); + const auto compressed1 = detail::CompressBits(half1, mask_bits1); + + alignas(32) uint16_t all_true[16] = {}; + // Store mask=true lanes, left to right. + const size_t num_true0 = PopCount(mask_bits0); + Store(compressed0, duh, all_true); + StoreU(compressed1, duh, all_true + num_true0); + + if (hwy::HWY_NAMESPACE::CompressIsPartition::value) { + // Store mask=false lanes, right to left. The second vector fills the upper + // half with right-aligned false lanes. The first vector is shifted + // rightwards to overwrite the true lanes of the second. + alignas(32) uint16_t all_false[16] = {}; + const size_t num_true1 = PopCount(mask_bits1); + Store(compressed1, duh, all_false + 8); + StoreU(compressed0, duh, all_false + num_true1); + + const auto mask = FirstN(du, num_true0 + num_true1); + return BitCast(d, + IfThenElse(mask, Load(du, all_true), Load(du, all_false))); + } else { + // Only care about the mask=true lanes. + return BitCast(d, Load(du, all_true)); + } +} + +template +HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { + const DFromV d; + const Repartition du32; + + HWY_DASSERT(mask_bits < (1ull << Lanes(d))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). + const Indices256 indices{IndicesFromNotBits256(mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. +template +HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { + // Compress ensures only the lower 16 bits are set, so flip those. 
+ return Compress(v, mask_bits ^ 0xFFFF); +} + +} // namespace detail + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 m) { + const DFromV d; + return detail::Compress(v, BitsFromMask(d, m)); +} + +template +HWY_API Vec256 CompressNot(Vec256 v, Mask256 m) { + const DFromV d; + return detail::CompressNot(v, BitsFromMask(d, m)); +} + +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + return CompressNot(v, mask); +} + +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + + const RebindToUnsigned du; + const Repartition du32; + HWY_DASSERT(mask_bits < (1ull << Lanes(d))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). Nibble MSB encodes FirstN. + const Vec256 idx_mask = + detail::IndicesFromBits256>(mask_bits); + // Shift nibble MSB into MSB + const Mask256 mask32 = MaskFromVec(ShiftLeft<28>(idx_mask)); + // First cast to unsigned (RebindMask cannot change lane size) + const MFromD mask_u{mask32.raw}; + const MFromD mask = RebindMask(d, mask_u); + const VFromD compressed = BitCast( + d, + TableLookupLanes(BitCast(du32, v), Indices256{idx_mask.raw})); + + BlendedStore(compressed, mask, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + const VFromD compressed = detail::Compress(v, mask_bits); + +#if HWY_MEM_OPS_MIGHT_FAULT // true if HWY_IS_MSAN + // BlendedStore tests mask for each lane, but we know that the mask is + // FirstN, so we can just copy. 
+ alignas(32) TFromD buf[16]; + Store(compressed, d, buf); + CopyBytes(buf, unaligned, count * sizeof(TFromD)); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + return count; +} + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + HWY_LANES_CONSTEXPR size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits, kNumBytes); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + const size_t count = PopCount(mask_bits); + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Dup128MaskFromMaskBits + +// Generic for all vector lengths >= 32 bytes +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const Half dh; + const auto mh = Dup128MaskFromMaskBits(dh, mask_bits); + return CombineMasks(d, mh, mh); +} + +// ------------------------------ Expand + +// Always define Expand/LoadExpand because generic_ops only does so for Vec128. + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE // VBMI2 + +HWY_INLINE Vec256 NativeExpand(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_expand_epi8(mask.raw, v.raw)}; +} + +HWY_INLINE Vec256 NativeExpand(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_expand_epi16(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint8_t* HWY_RESTRICT unaligned) { + return VFromD{_mm256_maskz_expandloadu_epi8(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint16_t* HWY_RESTRICT unaligned) { + return VFromD{_mm256_maskz_expandloadu_epi16(mask.raw, unaligned)}; +} + +#endif // HWY_TARGET <= HWY_AVX3_DL +#if HWY_TARGET <= HWY_AVX3 || HWY_IDE + +HWY_INLINE Vec256 NativeExpand(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_expand_epi32(mask.raw, v.raw)}; +} + +HWY_INLINE Vec256 NativeExpand(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_expand_epi64(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint32_t* HWY_RESTRICT unaligned) { + return VFromD{_mm256_maskz_expandloadu_epi32(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint64_t* HWY_RESTRICT unaligned) { + return VFromD{_mm256_maskz_expandloadu_epi64(mask.raw, unaligned)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +} // namespace detail + +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +#else + // LUTs are infeasible for so many mask combinations, so Combine two + // half-vector Expand. + const Half dh; + const uint64_t mask_bits = BitsFromMask(d, mask); + constexpr size_t N = 32 / sizeof(T); + const size_t countL = PopCount(mask_bits & ((1 << (N / 2)) - 1)); + const Mask128 maskL = MaskFromVec(LowerHalf(VecFromMask(d, mask))); + const Vec128 expandL = Expand(LowerHalf(v), maskL); + // We have to shift the input by a variable number of bytes, but there isn't + // a table-driven option for that until VBMI, and CPUs with that likely also + // have VBMI2 and thus native Expand. 
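// [Editor's note: illustrative example; not part of the upstream patch.]
// Suppose 5 of the 16 low-half mask bits are set: the lower-half Expand
// consumes input lanes [0, 5), so countL is 5 and the upper-half Expand must
// start reading the stored input at lanes + 5, which the unaligned LoadU
// below provides.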
+ alignas(32) T lanes[N]; + Store(v, d, lanes); + const Mask128 maskH = MaskFromVec(UpperHalf(dh, VecFromMask(d, mask))); + const Vec128 expandH = Expand(LoadU(dh, lanes + countL), maskH); + return Combine(d, expandH, expandL); +#endif +} + +// If AVX3, this is already implemented by x86_512. +#if HWY_TARGET != HWY_AVX3 + +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const Full256 d; +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + return BitCast(d, detail::NativeExpand(BitCast(du, v), RebindMask(du, mask))); +#else // AVX2 + // LUTs are infeasible for 2^16 possible masks, so splice together two + // half-vector Expand. + const Half dh; + const Mask128 maskL = MaskFromVec(LowerHalf(VecFromMask(d, mask))); + const Vec128 expandL = Expand(LowerHalf(v), maskL); + // We have to shift the input by a variable number of u16. permutevar_epi16 + // requires AVX3 and if we had that, we'd use native u32 Expand. The only + // alternative is re-loading, which incurs a store to load forwarding stall. + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + const Vec128 vH = LoadU(dh, lanes + CountTrue(dh, maskL)); + const Mask128 maskH = MaskFromVec(UpperHalf(dh, VecFromMask(d, mask))); + const Vec128 expandH = Expand(vH, maskH); + return Combine(d, expandH, expandL); +#endif // AVX2 +} + +#endif // HWY_TARGET != HWY_AVX3 + +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const Full256 d; +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +#else + const RebindToUnsigned du; + const uint64_t mask_bits = BitsFromMask(d, mask); + + alignas(16) constexpr uint32_t packed_array[256] = { + // PrintExpand32x8Nibble. 
+ 0xffffffff, 0xfffffff0, 0xffffff0f, 0xffffff10, 0xfffff0ff, 0xfffff1f0, + 0xfffff10f, 0xfffff210, 0xffff0fff, 0xffff1ff0, 0xffff1f0f, 0xffff2f10, + 0xffff10ff, 0xffff21f0, 0xffff210f, 0xffff3210, 0xfff0ffff, 0xfff1fff0, + 0xfff1ff0f, 0xfff2ff10, 0xfff1f0ff, 0xfff2f1f0, 0xfff2f10f, 0xfff3f210, + 0xfff10fff, 0xfff21ff0, 0xfff21f0f, 0xfff32f10, 0xfff210ff, 0xfff321f0, + 0xfff3210f, 0xfff43210, 0xff0fffff, 0xff1ffff0, 0xff1fff0f, 0xff2fff10, + 0xff1ff0ff, 0xff2ff1f0, 0xff2ff10f, 0xff3ff210, 0xff1f0fff, 0xff2f1ff0, + 0xff2f1f0f, 0xff3f2f10, 0xff2f10ff, 0xff3f21f0, 0xff3f210f, 0xff4f3210, + 0xff10ffff, 0xff21fff0, 0xff21ff0f, 0xff32ff10, 0xff21f0ff, 0xff32f1f0, + 0xff32f10f, 0xff43f210, 0xff210fff, 0xff321ff0, 0xff321f0f, 0xff432f10, + 0xff3210ff, 0xff4321f0, 0xff43210f, 0xff543210, 0xf0ffffff, 0xf1fffff0, + 0xf1ffff0f, 0xf2ffff10, 0xf1fff0ff, 0xf2fff1f0, 0xf2fff10f, 0xf3fff210, + 0xf1ff0fff, 0xf2ff1ff0, 0xf2ff1f0f, 0xf3ff2f10, 0xf2ff10ff, 0xf3ff21f0, + 0xf3ff210f, 0xf4ff3210, 0xf1f0ffff, 0xf2f1fff0, 0xf2f1ff0f, 0xf3f2ff10, + 0xf2f1f0ff, 0xf3f2f1f0, 0xf3f2f10f, 0xf4f3f210, 0xf2f10fff, 0xf3f21ff0, + 0xf3f21f0f, 0xf4f32f10, 0xf3f210ff, 0xf4f321f0, 0xf4f3210f, 0xf5f43210, + 0xf10fffff, 0xf21ffff0, 0xf21fff0f, 0xf32fff10, 0xf21ff0ff, 0xf32ff1f0, + 0xf32ff10f, 0xf43ff210, 0xf21f0fff, 0xf32f1ff0, 0xf32f1f0f, 0xf43f2f10, + 0xf32f10ff, 0xf43f21f0, 0xf43f210f, 0xf54f3210, 0xf210ffff, 0xf321fff0, + 0xf321ff0f, 0xf432ff10, 0xf321f0ff, 0xf432f1f0, 0xf432f10f, 0xf543f210, + 0xf3210fff, 0xf4321ff0, 0xf4321f0f, 0xf5432f10, 0xf43210ff, 0xf54321f0, + 0xf543210f, 0xf6543210, 0x0fffffff, 0x1ffffff0, 0x1fffff0f, 0x2fffff10, + 0x1ffff0ff, 0x2ffff1f0, 0x2ffff10f, 0x3ffff210, 0x1fff0fff, 0x2fff1ff0, + 0x2fff1f0f, 0x3fff2f10, 0x2fff10ff, 0x3fff21f0, 0x3fff210f, 0x4fff3210, + 0x1ff0ffff, 0x2ff1fff0, 0x2ff1ff0f, 0x3ff2ff10, 0x2ff1f0ff, 0x3ff2f1f0, + 0x3ff2f10f, 0x4ff3f210, 0x2ff10fff, 0x3ff21ff0, 0x3ff21f0f, 0x4ff32f10, + 0x3ff210ff, 0x4ff321f0, 0x4ff3210f, 0x5ff43210, 0x1f0fffff, 0x2f1ffff0, + 0x2f1fff0f, 0x3f2fff10, 0x2f1ff0ff, 0x3f2ff1f0, 0x3f2ff10f, 0x4f3ff210, + 0x2f1f0fff, 0x3f2f1ff0, 0x3f2f1f0f, 0x4f3f2f10, 0x3f2f10ff, 0x4f3f21f0, + 0x4f3f210f, 0x5f4f3210, 0x2f10ffff, 0x3f21fff0, 0x3f21ff0f, 0x4f32ff10, + 0x3f21f0ff, 0x4f32f1f0, 0x4f32f10f, 0x5f43f210, 0x3f210fff, 0x4f321ff0, + 0x4f321f0f, 0x5f432f10, 0x4f3210ff, 0x5f4321f0, 0x5f43210f, 0x6f543210, + 0x10ffffff, 0x21fffff0, 0x21ffff0f, 0x32ffff10, 0x21fff0ff, 0x32fff1f0, + 0x32fff10f, 0x43fff210, 0x21ff0fff, 0x32ff1ff0, 0x32ff1f0f, 0x43ff2f10, + 0x32ff10ff, 0x43ff21f0, 0x43ff210f, 0x54ff3210, 0x21f0ffff, 0x32f1fff0, + 0x32f1ff0f, 0x43f2ff10, 0x32f1f0ff, 0x43f2f1f0, 0x43f2f10f, 0x54f3f210, + 0x32f10fff, 0x43f21ff0, 0x43f21f0f, 0x54f32f10, 0x43f210ff, 0x54f321f0, + 0x54f3210f, 0x65f43210, 0x210fffff, 0x321ffff0, 0x321fff0f, 0x432fff10, + 0x321ff0ff, 0x432ff1f0, 0x432ff10f, 0x543ff210, 0x321f0fff, 0x432f1ff0, + 0x432f1f0f, 0x543f2f10, 0x432f10ff, 0x543f21f0, 0x543f210f, 0x654f3210, + 0x3210ffff, 0x4321fff0, 0x4321ff0f, 0x5432ff10, 0x4321f0ff, 0x5432f1f0, + 0x5432f10f, 0x6543f210, 0x43210fff, 0x54321ff0, 0x54321f0f, 0x65432f10, + 0x543210ff, 0x654321f0, 0x6543210f, 0x76543210, + }; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3). + const Vec256 packed = Set(du, packed_array[mask_bits]); + alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + // TableLookupLanes ignores upper bits; avoid bounds-check in IndicesFromVec. 
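// [Editor's note: worked example; not part of the upstream patch.] For
// mask_bits == 0x05 (output lanes 0 and 2 active), packed_array[5] above is
// 0xfffff1f0: output lane 0 takes source index 0 and output lane 2 takes
// source index 1, while the 0xf nibbles fall on lanes that the final
// IfThenElseZero clears, so their value does not matter.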
+ const Indices256 indices{(packed >> Load(du, shifts)).raw}; + const Vec256 expand = TableLookupLanes(BitCast(du, v), indices); + // TableLookupLanes cannot also zero masked-off lanes, so do that now. + return IfThenElseZero(mask, BitCast(d, expand)); +#endif +} + +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const Full256 d; +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +#else + const RebindToUnsigned du; + const uint64_t mask_bits = BitsFromMask(d, mask); + + alignas(16) constexpr uint64_t packed_array[16] = { + // PrintExpand64x4Nibble. + 0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0, + 0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10, + 0x000010ff, 0x000021f0, 0x0000210f, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2). + const Vec256 packed = Set(du, packed_array[mask_bits]); + alignas(32) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; +#if HWY_TARGET <= HWY_AVX3 // native 64-bit TableLookupLanes + // TableLookupLanes ignores upper bits; avoid bounds-check in IndicesFromVec. + const Indices256 indices{(packed >> Load(du, shifts)).raw}; +#else + // 64-bit TableLookupLanes on AVX2 requires IndicesFromVec, which checks + // bounds, so clear the upper bits. + const Vec256 masked = And(packed >> Load(du, shifts), Set(du, 3)); + const Indices256 indices = IndicesFromVec(du, masked); +#endif + const Vec256 expand = TableLookupLanes(BitCast(du, v), indices); + // TableLookupLanes cannot also zero masked-off lanes, so do that now. + return IfThenElseZero(mask, BitCast(d, expand)); +#endif +} + +// ------------------------------ LoadExpand + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +#else + return Expand(LoadU(d, unaligned), mask); +#endif +} + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +#else + return Expand(LoadU(d, unaligned), mask); +#endif +} + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. 
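For reference, every Expand overload above implements the same per-lane rule: an output lane whose mask bit is set receives the next not-yet-consumed input lane (in increasing index order), and all other output lanes are zeroed. The packed_array tables encode, for each possible mask, a 4-bit source index per output lane, with 0xF marking lanes that the trailing IfThenElseZero clears. A minimal scalar sketch of both ideas (illustrative helpers, not part of Highway; the PrintExpand64x4Nibble tool named in the comments is assumed to emit equivalent values):

#include <cstddef>
#include <cstdint>

// Scalar reference for Expand: out[i] = mask[i] ? next unused in[] lane : 0.
template <typename T, size_t N>
void ExpandScalar(const T (&in)[N], const bool (&mask)[N], T (&out)[N]) {
  size_t source = 0;  // next not-yet-consumed input lane
  for (size_t i = 0; i < N; ++i) {
    out[i] = mask[i] ? in[source++] : T{0};
  }
}

// Derives one packed_array entry for the 4-lane (64-bit) Expand above: nibble i
// holds the source index for output lane i, or 0xF if that lane is masked off.
uint64_t Expand64x4NibbleEntry(uint64_t mask_bits) {
  uint64_t packed = 0;
  size_t source = 0;
  for (size_t lane = 0; lane < 4; ++lane) {
    const uint64_t nibble =
        ((mask_bits >> lane) & 1) ? static_cast<uint64_t>(source++) : 0xF;
    packed |= nibble << (4 * lane);
  }
  return packed;  // e.g. mask_bits = 0b0101 -> 0x0000f1f0, as in the table.
}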
+ +namespace detail { +// Input: +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template +HWY_API void LoadTransposedBlocks3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& A, VFromD& B, VFromD& C) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD v10 = LoadU(d, unaligned + 0 * N); // 1 0 + const VFromD v32 = LoadU(d, unaligned + 1 * N); + const VFromD v54 = LoadU(d, unaligned + 2 * N); + + A = ConcatUpperLower(d, v32, v10); + B = ConcatLowerUpper(d, v54, v10); + C = ConcatUpperLower(d, v54, v32); +} + +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of vA) +// 5 1 +// 6 2 +// 7 3 +template +HWY_API void LoadTransposedBlocks4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& vA, VFromD& vB, VFromD& vC, + VFromD& vD) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD v10 = LoadU(d, unaligned + 0 * N); + const VFromD v32 = LoadU(d, unaligned + 1 * N); + const VFromD v54 = LoadU(d, unaligned + 2 * N); + const VFromD v76 = LoadU(d, unaligned + 3 * N); + + vA = ConcatLowerLower(d, v54, v10); + vB = ConcatUpperUpper(d, v54, v10); + vC = ConcatLowerLower(d, v76, v32); + vD = ConcatUpperUpper(d, v76, v32); +} +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. + +namespace detail { +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template +HWY_API void StoreTransposedBlocks2(VFromD i, VFromD j, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template +HWY_API void StoreTransposedBlocks3(VFromD i, VFromD j, VFromD k, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperLower(d, i, k); + const auto out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template +HWY_API void StoreTransposedBlocks4(VFromD i, VFromD j, VFromD k, + VFromD l, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + // Write lower halves, then upper. 
+ const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + const auto out2 = ConcatUpperUpper(d, j, i); + const auto out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} +} // namespace detail + +// ------------------------------ Additional mask logical operations + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API Mask256 SetAtOrAfterFirst(Mask256 mask) { + constexpr size_t N = MaxLanes(Full256()); + constexpr uint32_t kActiveElemMask = + static_cast((uint64_t{1} << N) - 1); + return Mask256{static_cast::Raw>( + (0u - detail::AVX3Blsi(mask.raw)) & kActiveElemMask)}; +} +template +HWY_API Mask256 SetBeforeFirst(Mask256 mask) { + constexpr size_t N = MaxLanes(Full256()); + constexpr uint32_t kActiveElemMask = + static_cast((uint64_t{1} << N) - 1); + return Mask256{static_cast::Raw>( + (detail::AVX3Blsi(mask.raw) - 1u) & kActiveElemMask)}; +} +template +HWY_API Mask256 SetAtOrBeforeFirst(Mask256 mask) { + constexpr size_t N = MaxLanes(Full256()); + constexpr uint32_t kActiveElemMask = + static_cast((uint64_t{1} << N) - 1); + return Mask256{static_cast::Raw>( + detail::AVX3Blsmsk(mask.raw) & kActiveElemMask)}; +} +template +HWY_API Mask256 SetOnlyFirst(Mask256 mask) { + return Mask256{ + static_cast::Raw>(detail::AVX3Blsi(mask.raw))}; +} +#else // AVX2 +template +HWY_API Mask256 SetAtOrAfterFirst(Mask256 mask) { + const Full256 d; + const Repartition di64; + const Repartition df32; + const Repartition di32; + const Half dh_i64; + const Half dh_i32; + using VF32 = VFromD; + + auto vmask = BitCast(di64, VecFromMask(d, mask)); + vmask = Or(vmask, Neg(vmask)); + + // Copy the sign bit of the even int64_t lanes to the odd int64_t lanes + const auto vmask2 = BitCast( + di32, VF32{_mm256_shuffle_ps(Zero(df32).raw, BitCast(df32, vmask).raw, + _MM_SHUFFLE(1, 1, 0, 0))}); + vmask = Or(vmask, BitCast(di64, BroadcastSignBit(vmask2))); + + // Copy the sign bit of the lower 128-bit half to the upper 128-bit half + const auto vmask3 = + BroadcastSignBit(Broadcast<3>(BitCast(dh_i32, LowerHalf(dh_i64, vmask)))); + vmask = Or(vmask, BitCast(di64, Combine(di32, vmask3, Zero(dh_i32)))); + return MaskFromVec(BitCast(d, vmask)); +} + +template +HWY_API Mask256 SetBeforeFirst(Mask256 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask256 SetOnlyFirst(Mask256 mask) { + const Full256 d; + const RebindToSigned di; + const Repartition di64; + const Half dh_i64; + + const auto zero = Zero(di64); + const auto vmask = BitCast(di64, VecFromMask(d, mask)); + + const auto vmask_eq_0 = VecFromMask(di64, vmask == zero); + auto vmask2_lo = LowerHalf(dh_i64, vmask_eq_0); + auto vmask2_hi = UpperHalf(dh_i64, vmask_eq_0); + + vmask2_lo = And(vmask2_lo, InterleaveLower(vmask2_lo, vmask2_lo)); + vmask2_hi = And(ConcatLowerUpper(dh_i64, vmask2_hi, vmask2_lo), + InterleaveUpper(dh_i64, vmask2_lo, vmask2_lo)); + vmask2_lo = InterleaveLower(Set(dh_i64, int64_t{-1}), vmask2_lo); + + const auto vmask2 = Combine(di64, vmask2_hi, vmask2_lo); + const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); + return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2)))); +} + +template +HWY_API Mask256 SetAtOrBeforeFirst(Mask256 mask) { + const Full256 d; + constexpr size_t kLanesPerBlock = MaxLanes(d) / 2; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_lo = ConcatLowerLower(d, vmask, Zero(d)); + return 
SetBeforeFirst( + MaskFromVec(CombineShiftRightBytes<(kLanesPerBlock - 1) * sizeof(T)>( + d, vmask, vmask_lo))); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reductions in generic_ops + +// ------------------------------ BitShuffle +#if HWY_TARGET <= HWY_AVX3_DL +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(V, 32), HWY_IF_V_SIZE_V(VI, 32)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Rebind du8; + + int32_t i32_bit_shuf_result = + static_cast(_mm256_bitshuffle_epi64_mask(v.raw, idx.raw)); + + return BitCast(d64, PromoteTo(du64, VFromD{_mm_cvtsi32_si128( + i32_bit_shuf_result)})); +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ MultiRotateRight + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_MULTIROTATERIGHT +#undef HWY_NATIVE_MULTIROTATERIGHT +#else +#define HWY_NATIVE_MULTIROTATERIGHT +#endif + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(V, 32), HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V MultiRotateRight(V v, VI idx) { + return V{_mm256_multishift_epi64_epi8(idx.raw, v.raw)}; +} + +#endif + +// ------------------------------ LeadingZeroCount + +#if HWY_TARGET <= HWY_AVX3 +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm256_lzcnt_epi32(v.raw)}; +} + +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm256_lzcnt_epi64(v.raw)}; +} + +namespace detail { + +template , HWY_MAX_BYTES / 4)> +static HWY_INLINE HWY_MAYBE_UNUSED V Lzcnt32ForU8OrU16OrU32(V v) { + const DFromV d; + const Rebind di32; + const Rebind du32; + + const auto v_lz_count = LeadingZeroCount(PromoteTo(du32, v)); + return DemoteTo(d, BitCast(di32, v_lz_count)); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED V Lzcnt32ForU8OrU16OrU32(V v) { + return LeadingZeroCount(v); +} + +template , HWY_MAX_BYTES / 4)> +static HWY_INLINE HWY_MAYBE_UNUSED V Lzcnt32ForU8OrU16OrU32(V v) { + const DFromV d; + const RepartitionToWide dw; + const RebindToSigned dw_i; + + const auto lo_v_lz_count = Lzcnt32ForU8OrU16OrU32(PromoteLowerTo(dw, v)); + const auto hi_v_lz_count = Lzcnt32ForU8OrU16OrU32(PromoteUpperTo(dw, v)); + return OrderedDemote2To(d, BitCast(dw_i, lo_v_lz_count), + BitCast(dw_i, hi_v_lz_count)); +} + +} // namespace detail + +template +HWY_API V LeadingZeroCount(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + + constexpr TU kNumOfBitsInT{sizeof(TU) * 8}; + const auto v_lzcnt32 = detail::Lzcnt32ForU8OrU16OrU32(BitCast(du, v)); + return BitCast(d, Min(v_lzcnt32 - Set(du, TU{32 - kNumOfBitsInT}), + Set(du, TU{kNumOfBitsInT}))); +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + return BitCast( + d, Set(du, TU{31}) - detail::Lzcnt32ForU8OrU16OrU32(BitCast(du, v))); +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + using T = TFromD; + return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v)); +} + +template +HWY_API V TrailingZeroCount(V v) { + const DFromV d; + const RebindToSigned di; + using T = TFromD; + + const auto vi = BitCast(di, v); + const auto lowest_bit = BitCast(d, And(vi, Neg(vi))); + constexpr T kNumOfBitsInT{sizeof(T) * 8}; + const auto bit_idx = HighestSetBitIndex(lowest_bit); + return IfThenElse(MaskFromVec(bit_idx), Set(d, kNumOfBitsInT), bit_idx); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // 
namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/third_party/aom/third_party/highway/hwy/ops/x86_512-inl.h b/third_party/aom/third_party/highway/hwy/ops/x86_512-inl.h new file mode 100644 index 000000000000..9fc52d23f095 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/x86_512-inl.h @@ -0,0 +1,7634 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 512-bit AVX512 vectors and operations. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL +#include "third_party/highway/hwy/base.h" + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, + ignored "-Wmaybe-uninitialized") +#endif + +#include // AVX2+ + +#if HWY_COMPILER_CLANGCL +// Including should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +// clang-format off +#include + +#include +// avxintrin defines __m256i and must come before avx2intrin. +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#if HWY_TARGET <= HWY_AVX3_DL +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// Must come after avx512fintrin, else will not define 512-bit intrinsics. +#include +#include +#include +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3_SPR +#include +#include +#endif // HWY_TARGET <= HWY_AVX3_SPR + +// clang-format on +#endif // HWY_COMPILER_CLANGCL + +// For half-width vectors. Already includes base.h and shared-inl.h. 
+#include "third_party/highway/hwy/ops/x86_256-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+namespace detail {
+
+template <typename T>
+struct Raw512 {
+  using type = __m512i;
+};
+#if HWY_HAVE_FLOAT16
+template <>
+struct Raw512<float16_t> {
+  using type = __m512h;
+};
+#endif  // HWY_HAVE_FLOAT16
+template <>
+struct Raw512<float> {
+  using type = __m512;
+};
+template <>
+struct Raw512<double> {
+  using type = __m512d;
+};
+
+// Template arg: sizeof(lane type)
+template <size_t kSize>
+struct RawMask512 {};
+template <>
+struct RawMask512<1> {
+  using type = __mmask64;
+};
+template <>
+struct RawMask512<2> {
+  using type = __mmask32;
+};
+template <>
+struct RawMask512<4> {
+  using type = __mmask16;
+};
+template <>
+struct RawMask512<8> {
+  using type = __mmask8;
+};
+
+}  // namespace detail
+
+template <typename T>
+class Vec512 {
+  using Raw = typename detail::Raw512<T>::type;
+
+ public:
+  using PrivateT = T;                                  // only for DFromV
+  static constexpr size_t kPrivateN = 64 / sizeof(T);  // only for DFromV
+
+  // Compound assignment. Only usable if there is a corresponding non-member
+  // binary operator overload. For example, only f32 and f64 support division.
+  HWY_INLINE Vec512& operator*=(const Vec512 other) {
+    return *this = (*this * other);
+  }
+  HWY_INLINE Vec512& operator/=(const Vec512 other) {
+    return *this = (*this / other);
+  }
+  HWY_INLINE Vec512& operator+=(const Vec512 other) {
+    return *this = (*this + other);
+  }
+  HWY_INLINE Vec512& operator-=(const Vec512 other) {
+    return *this = (*this - other);
+  }
+  HWY_INLINE Vec512& operator%=(const Vec512 other) {
+    return *this = (*this % other);
+  }
+  HWY_INLINE Vec512& operator&=(const Vec512 other) {
+    return *this = (*this & other);
+  }
+  HWY_INLINE Vec512& operator|=(const Vec512 other) {
+    return *this = (*this | other);
+  }
+  HWY_INLINE Vec512& operator^=(const Vec512 other) {
+    return *this = (*this ^ other);
+  }
+
+  Raw raw;
+};
+
+// Mask register: one bit per lane.
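To make "one bit per lane" concrete: the lane count of a 64-byte vector, 64 / sizeof(T), selects the raw mask width, so for example a mask over Full512<float> only ever uses the low 16 bits of its raw value. A few illustrative compile-time checks (not part of the patch):

#include <cstdint>

static_assert(64 / sizeof(uint8_t) == 64, "u8/i8 lanes -> __mmask64");
static_assert(64 / sizeof(uint16_t) == 32, "u16/i16/f16 lanes -> __mmask32");
static_assert(64 / sizeof(uint32_t) == 16, "u32/i32/f32 lanes -> __mmask16");
static_assert(64 / sizeof(uint64_t) == 8, "u64/i64/f64 lanes -> __mmask8");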
+template +struct Mask512 { + using Raw = typename detail::RawMask512::type; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = 64 / sizeof(T); // only for DFromM + + Raw raw; +}; + +template +using Full512 = Simd; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; } +#if HWY_HAVE_FLOAT16 +HWY_INLINE __m512i BitCastToInteger(__m512h v) { + return _mm512_castph_si512(v); +} +#endif // HWY_HAVE_FLOAT16 +HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); } +HWY_INLINE __m512i BitCastToInteger(__m512d v) { + return _mm512_castpd_si512(v); +} + +#if HWY_AVX3_HAVE_F32_TO_BF16C +HWY_INLINE __m512i BitCastToInteger(__m512bh v) { + // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to + // bit cast a __m512bh to a __m512i as there is currently no intrinsic + // available (as of GCC 13 and Clang 17) that can bit cast a __m512bh vector + // to a __m512i vector + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + // On GCC or Clang, use reinterpret_cast to bit cast a __m512bh to a __m512i + return reinterpret_cast<__m512i>(v); +#else + // On MSVC, use BitCastScalar to bit cast a __m512bh to a __m512i as MSVC does + // not allow reinterpret_cast, static_cast, or a C-style cast to be used to + // bit cast from one AVX vector type to a different AVX vector type + return BitCastScalar<__m512i>(v); +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_INLINE Vec512 BitCastToByte(Vec512 v) { + return Vec512{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger512 { + HWY_INLINE __m512i operator()(__m512i v) { return v; } +}; +#if HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger512 { + HWY_INLINE __m512h operator()(__m512i v) { return _mm512_castsi512_ph(v); } +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger512 { + HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); } +}; +template <> +struct BitCastFromInteger512 { + HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); } +}; + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, Vec512 v) { + return VFromD{BitCastFromInteger512>()(v.raw)}; +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, Vec512 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm512_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm512_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm512_set1_epi32(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm512_set1_epi64(static_cast(t))}; // NOLINT +} +// bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 Set(D /* tag */, float16_t t) { + return Vec512{_mm512_set1_ph(t)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec512 Set(D /* tag */, float t) { + return Vec512{_mm512_set1_ps(t)}; +} +template +HWY_API Vec512 Set(D /* tag */, double t) { + return Vec512{_mm512_set1_pd(t)}; +} + +// ------------------------------ Zero (Set) + +// GCC pre-9.1 lacked setzero, so use Set instead. 
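As a quick orientation before the remaining ops, a minimal usage sketch of the pieces defined so far. It assumes a translation unit that includes hwy/highway.h from this vendored tree and is compiled for an AVX-512 target (otherwise Vec512/Full512 are not defined); the function name is illustrative:

#include "third_party/highway/hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

// Broadcasts an integer and reinterprets the same 512 bits as float lanes.
inline Vec512<float> SplatBitsAsFloat(int32_t bits) {
  const Full512<int32_t> d32;
  const Vec512<int32_t> v = Set(d32, bits);  // all 16 lanes equal `bits`
  return BitCast(Full512<float>(), v);       // reinterpret only, no data movement
}

}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();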
+#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + +// Cannot use VFromD here because it is defined in terms of Zero. +template +HWY_API Vec512> Zero(D d) { + return Set(d, TFromD{0}); +} +// BitCast is defined below, but the Raw type is the same, so use that. +template +HWY_API Vec512 Zero(D /* tag */) { + const RebindToUnsigned du; + return Vec512{Set(du, 0).raw}; +} +template +HWY_API Vec512 Zero(D /* tag */) { + const RebindToUnsigned du; + return Vec512{Set(du, 0).raw}; +} + +#else + +template +HWY_API Vec512> Zero(D /* tag */) { + return Vec512>{_mm512_setzero_si512()}; +} +template +HWY_API Vec512 Zero(D /* tag */) { + return Vec512{_mm512_setzero_si512()}; +} +template +HWY_API Vec512 Zero(D /* tag */) { +#if HWY_HAVE_FLOAT16 + return Vec512{_mm512_setzero_ph()}; +#else + return Vec512{_mm512_setzero_si512()}; +#endif +} +template +HWY_API Vec512 Zero(D /* tag */) { + return Vec512{_mm512_setzero_ps()}; +} +template +HWY_API Vec512 Zero(D /* tag */) { + return Vec512{_mm512_setzero_pd()}; +} + +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + +// ------------------------------ Undefined + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec512> Undefined(D /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec512>{_mm512_undefined_epi32()}; +} +template +HWY_API Vec512 Undefined(D /* tag */) { + return Vec512{_mm512_undefined_epi32()}; +} +template +HWY_API Vec512 Undefined(D /* tag */) { +#if HWY_HAVE_FLOAT16 + return Vec512{_mm512_undefined_ph()}; +#else + return Vec512{_mm512_undefined_epi32()}; +#endif +} +template +HWY_API Vec512 Undefined(D /* tag */) { + return Vec512{_mm512_undefined_ps()}; +} +template +HWY_API Vec512 Undefined(D /* tag */) { + return Vec512{_mm512_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ ResizeBitCast + +// 64-byte vector to 16-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, Vec128{_mm512_castsi512_si128( + BitCast(Full512(), v).raw)}); +} + +// <= 16-byte vector to 64-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, Vec512{_mm512_castsi128_si512( + ResizeBitCast(Full128(), v).raw)}); +} + +// 32-byte vector to 64-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, Vec512{_mm512_castsi256_si512( + BitCast(Full256(), v).raw)}); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + // Missing set_epi8/16. 
+ return BroadcastBlock<0>(ResizeBitCast( + d, Dup128VecFromValues(Full128>(), t0, t1, t2, t3, t4, t5, t6, + t7, t8, t9, t10, t11, t12, t13, t14, t15))); +#else + (void)d; + // Need to use _mm512_set_epi8 as there is no _mm512_setr_epi8 intrinsic + // available + return VFromD{_mm512_set_epi8( + static_cast(t15), static_cast(t14), static_cast(t13), + static_cast(t12), static_cast(t11), static_cast(t10), + static_cast(t9), static_cast(t8), static_cast(t7), + static_cast(t6), static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), static_cast(t1), + static_cast(t0), static_cast(t15), static_cast(t14), + static_cast(t13), static_cast(t12), static_cast(t11), + static_cast(t10), static_cast(t9), static_cast(t8), + static_cast(t7), static_cast(t6), static_cast(t5), + static_cast(t4), static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), static_cast(t15), + static_cast(t14), static_cast(t13), static_cast(t12), + static_cast(t11), static_cast(t10), static_cast(t9), + static_cast(t8), static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), static_cast(t3), + static_cast(t2), static_cast(t1), static_cast(t0), + static_cast(t15), static_cast(t14), static_cast(t13), + static_cast(t12), static_cast(t11), static_cast(t10), + static_cast(t9), static_cast(t8), static_cast(t7), + static_cast(t6), static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), static_cast(t1), + static_cast(t0))}; +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + // Missing set_epi8/16. + return BroadcastBlock<0>( + ResizeBitCast(d, Dup128VecFromValues(Full128>(), t0, t1, t2, t3, + t4, t5, t6, t7))); +#else + (void)d; + // Need to use _mm512_set_epi16 as there is no _mm512_setr_epi16 intrinsic + // available + return VFromD{ + _mm512_set_epi16(static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), + static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), + static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), + static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0))}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{_mm512_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2, + t3, t4, t5, t6, t7, t0, t1, t2, t3, t4, t5, + t6, t7, t0, t1, t2, t3, t4, t5, t6, t7)}; +} +#endif + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{ + _mm512_setr_epi32(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{_mm512_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3, t0, t1, t2, + t3, t0, t1, t2, t3)}; +} + +template 
+HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{ + _mm512_setr_epi64(static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{_mm512_setr_pd(t0, t1, t0, t1, t0, t1, t0, t1)}; +} + +// ----------------------------- Iota + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D d) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + // Missing set_epi8/16. + alignas(64) static constexpr TFromD kIota[64] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}; + return Load(d, kIota); +#else + (void)d; + return VFromD{_mm512_set_epi8( + static_cast(63), static_cast(62), static_cast(61), + static_cast(60), static_cast(59), static_cast(58), + static_cast(57), static_cast(56), static_cast(55), + static_cast(54), static_cast(53), static_cast(52), + static_cast(51), static_cast(50), static_cast(49), + static_cast(48), static_cast(47), static_cast(46), + static_cast(45), static_cast(44), static_cast(43), + static_cast(42), static_cast(41), static_cast(40), + static_cast(39), static_cast(38), static_cast(37), + static_cast(36), static_cast(35), static_cast(34), + static_cast(33), static_cast(32), static_cast(31), + static_cast(30), static_cast(29), static_cast(28), + static_cast(27), static_cast(26), static_cast(25), + static_cast(24), static_cast(23), static_cast(22), + static_cast(21), static_cast(20), static_cast(19), + static_cast(18), static_cast(17), static_cast(16), + static_cast(15), static_cast(14), static_cast(13), + static_cast(12), static_cast(11), static_cast(10), + static_cast(9), static_cast(8), static_cast(7), + static_cast(6), static_cast(5), static_cast(4), + static_cast(3), static_cast(2), static_cast(1), + static_cast(0))}; +#endif +} + +template +HWY_INLINE VFromD Iota0(D d) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + // Missing set_epi8/16. 
+ alignas(64) static constexpr TFromD kIota[32] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; + return Load(d, kIota); +#else + (void)d; + return VFromD{_mm512_set_epi16( + int16_t{31}, int16_t{30}, int16_t{29}, int16_t{28}, int16_t{27}, + int16_t{26}, int16_t{25}, int16_t{24}, int16_t{23}, int16_t{22}, + int16_t{21}, int16_t{20}, int16_t{19}, int16_t{18}, int16_t{17}, + int16_t{16}, int16_t{15}, int16_t{14}, int16_t{13}, int16_t{12}, + int16_t{11}, int16_t{10}, int16_t{9}, int16_t{8}, int16_t{7}, int16_t{6}, + int16_t{5}, int16_t{4}, int16_t{3}, int16_t{2}, int16_t{1}, int16_t{0})}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm512_set_ph( + float16_t{31}, float16_t{30}, float16_t{29}, float16_t{28}, float16_t{27}, + float16_t{26}, float16_t{25}, float16_t{24}, float16_t{23}, float16_t{22}, + float16_t{21}, float16_t{20}, float16_t{19}, float16_t{18}, float16_t{17}, + float16_t{16}, float16_t{15}, float16_t{14}, float16_t{13}, float16_t{12}, + float16_t{11}, float16_t{10}, float16_t{9}, float16_t{8}, float16_t{7}, + float16_t{6}, float16_t{5}, float16_t{4}, float16_t{3}, float16_t{2}, + float16_t{1}, float16_t{0})}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm512_set_epi32( + int32_t{15}, int32_t{14}, int32_t{13}, int32_t{12}, int32_t{11}, + int32_t{10}, int32_t{9}, int32_t{8}, int32_t{7}, int32_t{6}, int32_t{5}, + int32_t{4}, int32_t{3}, int32_t{2}, int32_t{1}, int32_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm512_set_epi64(int64_t{7}, int64_t{6}, int64_t{5}, + int64_t{4}, int64_t{3}, int64_t{2}, + int64_t{1}, int64_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, + 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, + 0.0f)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0)}; +} + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + return detail::Iota0(d) + Set(d, ConvertScalarTo>(first)); +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec512 Not(const Vec512 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i vu = BitCast(du, v).raw; + return BitCast(d, VU{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)}); +} + +// ------------------------------ And + +template +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_and_si512(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_ps(a.raw, b.raw)}; +} +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
+template +HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_andnot_si512( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); +} +HWY_API Vec512 AndNot(const Vec512 not_mask, + const Vec512 mask) { + return Vec512{_mm512_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec512 AndNot(const Vec512 not_mask, + const Vec512 mask) { + return Vec512{_mm512_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_or_si512(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_xor_si512(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor3 +template +HWY_API Vec512 Xor3(Vec512 x1, Vec512 x2, Vec512 x3) { +#if !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); + return BitCast(d, VU{ret}); +#else + return Xor(x1, Xor(x2, x3)); +#endif +} + +// ------------------------------ Or3 +template +HWY_API Vec512 Or3(Vec512 o1, Vec512 o2, Vec512 o3) { +#if !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template +HWY_API Vec512 OrAnd(Vec512 o, Vec512 a1, Vec512 a2) { +#if !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec512 IfVecThenElse(Vec512 mask, Vec512 yes, Vec512 no) { +#if !HWY_IS_MSAN + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw, + BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec512 operator&(const Vec512 a, const Vec512 b) { + return And(a, b); +} + +template +HWY_API Vec512 operator|(const Vec512 a, const Vec512 b) { + return Or(a, b); +} + +template +HWY_API Vec512 operator^(const Vec512 a, const Vec512 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. 
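The 8-bit immediates fed to _mm512_ternarylogic_epi32/64 above (0x55 for Not, 0x96 for Xor3, 0xFE for Or3, 0xF8 for OrAnd, 0xCA for IfVecThenElse) are the truth table of the desired three-input function, indexed by (a << 2) | (b << 1) | c, where a/b/c are the corresponding bits of the first/second/third operand. A sketch of deriving such an immediate (the helper is illustrative, not a Highway API):

#include <cstdint>

// Builds the ternary-logic immediate for an arbitrary 3-input boolean function.
template <class F>
constexpr uint8_t TernLogImm(F f) {
  uint8_t imm = 0;
  for (int idx = 0; idx < 8; ++idx) {
    const bool a = (idx & 4) != 0;  // bit of the first operand
    const bool b = (idx & 2) != 0;  // bit of the second operand
    const bool c = (idx & 1) != 0;  // bit of the third operand
    if (f(a, b, c)) imm = static_cast<uint8_t>(imm | (1u << idx));
  }
  return imm;
}
// Expected values, matching the constants used above:
//   a ^ b ^ c   -> 0x96 (Xor3)        a | b | c -> 0xFE (Or3)
//   a | (b & c) -> 0xF8 (OrAnd)       a ? b : c -> 0xCA (IfVecThenElse)
//   !c (with all three operands the same vector) -> 0x55 (Not)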
+#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<1> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<2> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<4> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<8> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 PopulationCount(Vec512 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ================================================== MASK + +// ------------------------------ FirstN + +// Possibilities for constructing a bitmask of N ones: +// - kshift* only consider the lowest byte of the shift count, so they would +// not correctly handle large n. +// - Scalar shifts >= 64 are UB. +// - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However, +// we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds. + +#if HWY_ARCH_X86_32 +namespace detail { + +// 32 bit mask is sufficient for lane size >= 2. +template +HWY_INLINE Mask512 FirstN(size_t n) { + Mask512 m; + const uint32_t all = ~uint32_t{0}; + // BZHI only looks at the lower 8 bits of n, but it has been clamped to + // MaxLanes, which is at most 32. + m.raw = static_cast(_bzhi_u32(all, n)); + return m; +} + +#if HWY_COMPILER_MSVC >= 1920 || HWY_COMPILER_GCC_ACTUAL >= 900 || \ + HWY_COMPILER_CLANG || HWY_COMPILER_ICC +template +HWY_INLINE Mask512 FirstN(size_t n) { + uint32_t lo_mask; + uint32_t hi_mask; + uint32_t hi_mask_len; +#if HWY_COMPILER_GCC + if (__builtin_constant_p(n >= 32) && n >= 32) { + if (__builtin_constant_p(n >= 64) && n >= 64) { + hi_mask_len = 32u; + } else { + hi_mask_len = static_cast(n) - 32u; + } + lo_mask = hi_mask = 0xFFFFFFFFu; + } else // NOLINT(readability/braces) +#endif + { + const uint32_t lo_mask_len = static_cast(n); + lo_mask = _bzhi_u32(0xFFFFFFFFu, lo_mask_len); + +#if HWY_COMPILER_GCC + if (__builtin_constant_p(lo_mask_len <= 32) && lo_mask_len <= 32) { + return Mask512{static_cast<__mmask64>(lo_mask)}; + } +#endif + + _addcarry_u32(_subborrow_u32(0, lo_mask_len, 32u, &hi_mask_len), + 0xFFFFFFFFu, 0u, &hi_mask); + } + hi_mask = _bzhi_u32(hi_mask, hi_mask_len); +#if HWY_COMPILER_GCC && !HWY_COMPILER_ICC + if (__builtin_constant_p((static_cast(hi_mask) << 32) | lo_mask)) +#endif + return Mask512{static_cast<__mmask64>( + (static_cast(hi_mask) << 32) | lo_mask)}; +#if HWY_COMPILER_GCC && !HWY_COMPILER_ICC + else + return Mask512{_mm512_kunpackd(static_cast<__mmask64>(hi_mask), + static_cast<__mmask64>(lo_mask))}; +#endif +} +#else // HWY_COMPILER.. +template +HWY_INLINE Mask512 FirstN(size_t n) { + const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t{0}; + return Mask512{static_cast<__mmask64>(bits)}; +} +#endif // HWY_COMPILER.. +} // namespace detail +#endif // HWY_ARCH_X86_32 + +template +HWY_API MFromD FirstN(D d, size_t n) { + // This ensures `num` <= 255 as required by bzhi, which only looks + // at the lower 8 bits. 
+ n = HWY_MIN(n, MaxLanes(d)); + +#if HWY_ARCH_X86_64 + MFromD m; + const uint64_t all = ~uint64_t{0}; + m.raw = static_cast(_bzhi_u64(all, n)); + return m; +#else + return detail::FirstN>(n); +#endif // HWY_ARCH_X86_64 +} + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<1> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_blend_epi8(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<2> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_blend_epi16(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<4> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_blend_epi32(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<8> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_blend_epi64(mask.raw, no.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 IfThenElse(Mask512 mask, + Vec512 yes, + Vec512 no) { + return Vec512{_mm512_mask_blend_ph(mask.raw, no.raw, yes.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 IfThenElse(Mask512 mask, Vec512 yes, + Vec512 no) { + return Vec512{_mm512_mask_blend_ps(mask.raw, no.raw, yes.raw)}; +} +HWY_API Vec512 IfThenElse(Mask512 mask, Vec512 yes, + Vec512 no) { + return Vec512{_mm512_mask_blend_pd(mask.raw, no.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<1> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<2> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<4> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<8> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} +HWY_API Vec512 IfThenElseZero(Mask512 mask, Vec512 yes) { + return Vec512{_mm512_maskz_mov_ps(mask.raw, yes.raw)}; +} +HWY_API Vec512 IfThenElseZero(Mask512 mask, + Vec512 yes) { + return Vec512{_mm512_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<1> /* tag */, + const Mask512 mask, const Vec512 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
+ return Vec512{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<2> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<4> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<8> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} +HWY_API Vec512 IfThenZeroElse(Mask512 mask, Vec512 no) { + return Vec512{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} +HWY_API Vec512 IfThenZeroElse(Mask512 mask, Vec512 no) { + return Vec512{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec512 IfNegativeThenElse(Vec512 v, Vec512 yes, Vec512 no) { + static_assert(IsSigned(), "Only works for signed/float"); + // AVX3 MaskFromVec only looks at the MSB + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec512 IfNegativeThenNegOrUndefIfZero(Vec512 mask, Vec512 v) { + // AVX3 MaskFromVec only looks at the MSB + const DFromV d; + return MaskedSubOr(v, MaskFromVec(mask), Zero(d), v); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return 
Vec512{_mm512_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +HWY_API Vec512 SumsOf8(const Vec512 v) { + const Full512 d; + return Vec512{_mm512_sad_epu8(v.raw, Zero(d).raw)}; +} + +HWY_API Vec512 SumsOf8AbsDiff(Vec512 a, Vec512 b) { + return Vec512{_mm512_sad_epu8(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf4 +namespace detail { + +HWY_INLINE Vec512 SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + Vec512 v) { + const DFromV d; + + // _mm512_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be + // zeroed out and the sums of the 4 consecutive lanes are already in the + // even uint16_t lanes of the _mm512_maskz_dbsad_epu8 result. + return Vec512{_mm512_maskz_dbsad_epu8( + static_cast<__mmask32>(0x55555555), v.raw, Zero(d).raw, 0)}; +} + +// I8->I32 SumsOf4 +// Generic for all vector lengths +template +HWY_INLINE VFromD>> SumsOf4( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWideX2 di32; + + // Adjust the values of v to be in the 0..255 range by adding 128 to each lane + // of v (which is the same as an bitwise XOR of each i8 lane by 128) and then + // bitcasting the Xor result to an u8 vector. + const auto v_adj = BitCast(du, Xor(v, SignBit(d))); + + // Need to add -512 to each i32 lane of the result of the + // SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj) operation to account + // for the adjustment made above. + return BitCast(di32, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj)) + + Set(di32, int32_t{-512}); +} + +} // namespace detail + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if HWY_TARGET <= HWY_AVX3 +template +static Vec512 SumsOfShuffledQuadAbsDiff(Vec512 a, + Vec512 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + return Vec512{ + _mm512_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; +} +#endif + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { + return Vec512{_mm512_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { + return Vec512{_mm512_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { + return Vec512{_mm512_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { + return Vec512{_mm512_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
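The bias trick in the signed SumsOf4 above can be checked with a scalar model: XOR-ing an i8 lane with 0x80 adds 128, so the unsigned sum of a quad overshoots the signed sum by 4 * 128 = 512, which the final add of -512 removes (sketch only, not a Highway API):

#include <cstdint>

// Scalar model of the i8 -> i32 SumsOf4 path above.
int32_t SumsOf4OfI8(const int8_t quad[4]) {
  uint32_t biased_sum = 0;
  for (int i = 0; i < 4; ++i) {
    biased_sum += static_cast<uint8_t>(quad[i]) ^ 0x80u;  // == quad[i] + 128
  }
  return static_cast<int32_t>(biased_sum) - 512;  // remove the 4 * 128 bias
}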
+ +// Unsigned +HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { + return Vec512{_mm512_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { + return Vec512{_mm512_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { + return Vec512{_mm512_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { + return Vec512{_mm512_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec512 AverageRound(Vec512 a, Vec512 b) { + return Vec512{_mm512_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 AverageRound(Vec512 a, Vec512 b) { + return Vec512{_mm512_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec512 Abs(const Vec512 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (untested due to internal compiler error) + const DFromV d; + const auto zero = Zero(d); + return Vec512{_mm512_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec512{_mm512_abs_epi8(v.raw)}; +#endif +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi16(v.raw)}; +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi32(v.raw)}; +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi64(v.raw)}; +} + +// ------------------------------ ShiftLeft + +#if HWY_TARGET <= HWY_AVX3_DL +namespace detail { +template +HWY_API Vec512 GaloisAffine(Vec512 v, Vec512 matrix) { + return Vec512{_mm512_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)}; +} +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi64(v.raw, kBits)}; +} + +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + const DFromV d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); + return kBits == 1 + ? 
(v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// ------------------------------ ShiftRight + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi64(v.raw, kBits)}; +} + +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec512 shifted{ShiftRight(Vec512{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// ------------------------------ RotateRight + +#if HWY_TARGET > HWY_AVX3_DL +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + if (kBits == 0) return v; + // AVX3 does not support 8-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +} +#endif // HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); + if (kBits == 0) return v; +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_shrdi_epi16(v.raw, v.raw, kBits)}; +#else + // AVX3 does not support 16-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); + if (kBits == 0) return v; + return Vec512{_mm512_ror_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); + if (kBits == 0) return v; + return Vec512{_mm512_ror_epi64(v.raw, kBits)}; +} + +// ------------------------------ Rol/Ror +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API Vec512 Ror(Vec512 a, Vec512 b) { + return Vec512{_mm512_shrdv_epi16(a.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API Vec512 Rol(Vec512 a, Vec512 b) { + return Vec512{_mm512_rolv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec512 Ror(Vec512 a, Vec512 b) { + return Vec512{_mm512_rorv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec512 Rol(Vec512 a, Vec512 b) { + return Vec512{_mm512_rolv_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec512 Ror(Vec512 a, Vec512 b) { + return Vec512{_mm512_rorv_epi64(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeftSame + +// GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512 +// shift-with-immediate: the counts should all be unsigned int. Despite casting, +// we still see warnings in GCC debug builds, hence disable. 
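The XOR-and-subtract step in the signed 8-bit ShiftRight fallback above sign-extends the logically shifted value: after the unsigned shift, the original sign bit sits at position 7 - kBits, and XOR-ing with 0x80 >> kBits then subtracting that same constant propagates it into the vacated upper bits. A scalar model (illustrative only):

#include <cstdint>

// Arithmetic i8 shift built from a logical u8 shift, as in the fallback above.
int8_t ArithmeticShiftRightI8(int8_t v, int kBits) {  // 0 <= kBits < 8
  const uint8_t shifted = static_cast<uint8_t>(static_cast<uint8_t>(v) >> kBits);
  const uint8_t sign_bit = static_cast<uint8_t>(0x80 >> kBits);
  return static_cast<int8_t>(
      static_cast<uint8_t>((shifted ^ sign_bit) - sign_bit));
}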
+HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100 +using Shift16Count = int; +using Shift3264Count = int; +#elif HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400 +// GCC 11.0 requires these, prior versions used a macro+cast and don't care. +using Shift16Count = int; +using Shift3264Count = unsigned int; +#else +// Assume documented behavior. Clang 11, GCC 14 and MSVC 14.28.29910 match this. +using Shift16Count = unsigned int; +using Shift3264Count = unsigned int; +#endif + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi16(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi32(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi16(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi32(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { + const DFromV d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srli_epi16(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srli_epi32(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srli_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { + const DFromV d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & 
Set(d8, static_cast(0xFF >> bits)); +} + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srai_epi16(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srai_epi32(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srai_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Minimum + +// Unsigned +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epu64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epu64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ 
Integer multiplication + +// Unsigned +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi64(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec512 MulHigh(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 MulHigh(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhi_epi16(a.raw, b.raw)}; +} + +HWY_API Vec512 MulFixedPoint15(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec512 MulEven(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 MulEven(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ Neg (Sub) + +template +HWY_API Vec512 Neg(const Vec512 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +template +HWY_API Vec512 Neg(const Vec512 v) { + const DFromV d; + return Zero(d) - v; +} + +// ------------------------------ Floating-point mul / div + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_pd(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MulByFloorPow2(Vec512 a, + Vec512 b) { + return Vec512{_mm512_scalef_ph(a.raw, b.raw)}; +} +#endif + +HWY_API Vec512 MulByFloorPow2(Vec512 a, Vec512 b) { + return Vec512{_mm512_scalef_ps(a.raw, b.raw)}; +} + +HWY_API Vec512 MulByFloorPow2(Vec512 a, Vec512 b) { + return Vec512{_mm512_scalef_pd(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 operator/(Vec512 a, Vec512 b) { + return Vec512{_mm512_div_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 operator/(Vec512 a, Vec512 b) { + return Vec512{_mm512_div_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator/(Vec512 a, Vec512 b) { + return Vec512{_mm512_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 ApproximateReciprocal(const Vec512 v) { + return Vec512{_mm512_rcp_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 ApproximateReciprocal(const Vec512 v) { + return Vec512{_mm512_rcp14_ps(v.raw)}; +} + +HWY_API Vec512 ApproximateReciprocal(Vec512 v) { + return Vec512{_mm512_rcp14_pd(v.raw)}; +} + +// ------------------------------ GetExponent + +#if HWY_HAVE_FLOAT16 +template ), HWY_IF_V_SIZE_V(V, 64)> +HWY_API V GetExponent(V v) { + return V{_mm512_getexp_ph(v.raw)}; +} +#endif +template ), HWY_IF_V_SIZE_V(V, 64)> +HWY_API V GetExponent(V v) { + return V{_mm512_getexp_ps(v.raw)}; +} +template ), HWY_IF_V_SIZE_V(V, 64)> +HWY_API V GetExponent(V v) { + return V{_mm512_getexp_pd(v.raw)}; +} + +// ------------------------------ MaskedMinOr + +template +HWY_API Vec512 
MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMaxOr + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // 
HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedAddOr + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSubOr + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMulOr + +HWY_API Vec512 MaskedMulOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec512 MaskedMulOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MaskedMulOr(Vec512 no, + Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedDivOr + +HWY_API Vec512 MaskedDivOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec512 MaskedDivOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MaskedDivOr(Vec512 no, + Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSatAddOr + +template +HWY_API 
Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ MaskedSatSubOr + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ Floating-point multiply-add variants + +#if HWY_HAVE_FLOAT16 + +HWY_API Vec512 MulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fmadd_ph(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec512 NegMulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fnmadd_ph(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec512 MulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +HWY_API Vec512 NegMulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fnmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +// Returns mul * x + add +HWY_API Vec512 MulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec512 MulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)}; +} + +// Returns add - mul * x +HWY_API Vec512 NegMulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec512 NegMulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)}; +} + +// Returns mul * x - sub +HWY_API Vec512 MulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec512 MulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)}; +} + +// Returns -mul * x - sub +HWY_API Vec512 NegMulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec512 NegMulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MulAddSub(Vec512 mul, Vec512 x, + Vec512 sub_or_add) { + return Vec512{_mm512_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec512 MulAddSub(Vec512 mul, Vec512 x, + Vec512 sub_or_add) { + return Vec512{_mm512_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; +} + +HWY_API Vec512 MulAddSub(Vec512 mul, Vec512 x, + Vec512 sub_or_add) { + return Vec512{_mm512_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; 
+} + +// ------------------------------ Floating-point square root + +// Full precision square root +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Sqrt(const Vec512 v) { + return Vec512{_mm512_sqrt_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Sqrt(const Vec512 v) { + return Vec512{_mm512_sqrt_ps(v.raw)}; +} +HWY_API Vec512 Sqrt(const Vec512 v) { + return Vec512{_mm512_sqrt_pd(v.raw)}; +} + +// Approximate reciprocal square root +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 ApproximateReciprocalSqrt(Vec512 v) { + return Vec512{_mm512_rsqrt_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 ApproximateReciprocalSqrt(Vec512 v) { + return Vec512{_mm512_rsqrt14_ps(v.raw)}; +} + +HWY_API Vec512 ApproximateReciprocalSqrt(Vec512 v) { + return Vec512{_mm512_rsqrt14_pd(v.raw)}; +} + +// ------------------------------ Floating-point rounding + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Toward nearest integer, tie to even +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Round(Vec512 v) { + return Vec512{_mm512_roundscale_ph( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Round(Vec512 v) { + return Vec512{_mm512_roundscale_ps( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Round(Vec512 v) { + return Vec512{_mm512_roundscale_pd( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Trunc(Vec512 v) { + return Vec512{ + _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Trunc(Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Trunc(Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Ceil(Vec512 v) { + return Vec512{ + _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Ceil(Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Ceil(Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Floor(Vec512 v) { + return Vec512{ + _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Floor(Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Floor(Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== COMPARE + +// Comparisons set a mask bit to 1 if the condition is true, else 0. 
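// Illustrative sketch, not part of the patch: how callers typically consume the
// mask-returning comparisons defined in this COMPARE section, via Highway's
// public API. Assumes static dispatch with the usual foreach_target/HWY_EXPORT
// boilerplate omitted, and that `count` is a multiple of Lanes(d); the helper
// name ClampNegatives is hypothetical.
#include <cstddef>
#include <hwy/highway.h>

namespace hn = hwy::HWY_NAMESPACE;

// Clamps negative values to zero and returns how many lanes were negative.
size_t ClampNegatives(float* HWY_RESTRICT x, size_t count) {
  const hn::ScalableTag<float> d;  // Vec512<float> / Mask512<float> on AVX3
  size_t num_negative = 0;
  for (size_t i = 0; i < count; i += hn::Lanes(d)) {
    const auto v = hn::Load(d, x + i);
    const auto is_neg = hn::Lt(v, hn::Zero(d));  // one mask bit per lane
    num_negative += hn::CountTrue(d, is_neg);
    hn::Store(hn::IfThenElse(is_neg, hn::Zero(d), v), d, x + i);
  }
  return num_negative;
}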
+ +template +HWY_API MFromD RebindMask(DTo /*tag*/, Mask512 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<1> /*tag*/, Vec512 v, + Vec512 bit) { + return Mask512{_mm512_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<2> /*tag*/, Vec512 v, + Vec512 bit) { + return Mask512{_mm512_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<4> /*tag*/, Vec512 v, + Vec512 bit) { + return Mask512{_mm512_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<8> /*tag*/, Vec512 v, + Vec512 bit) { + return Mask512{_mm512_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask512 TestBit(const Vec512 v, const Vec512 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask512 operator==(Vec512 a, + Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask512 operator!=(Vec512 a, + Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +#if HWY_HAVE_FLOAT16 +HWY_API Mask512 operator>=(Vec512 a, + Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epu64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epi64_mask(a.raw, b.raw)}; +} + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask512 operator<(Vec512 a, Vec512 b) { + return b > a; +} + +template +HWY_API Mask512 operator<=(Vec512 a, Vec512 b) { + return b >= a; +} + +// ------------------------------ Mask + +template +HWY_API Mask512 MaskFromVec(Vec512 v) { + return Mask512{_mm512_movepi8_mask(v.raw)}; +} +template +HWY_API Mask512 MaskFromVec(Vec512 v) { + return Mask512{_mm512_movepi16_mask(v.raw)}; +} +template +HWY_API Mask512 MaskFromVec(Vec512 v) { + return Mask512{_mm512_movepi32_mask(v.raw)}; +} +template +HWY_API Mask512 MaskFromVec(Vec512 v) { + return Mask512{_mm512_movepi64_mask(v.raw)}; +} +template +HWY_API Mask512 MaskFromVec(Vec512 v) { + const RebindToSigned> di; + return Mask512{MaskFromVec(BitCast(di, v)).raw}; +} + +template +HWY_API Vec512 VecFromMask(Mask512 m) { + return Vec512{_mm512_movm_epi8(m.raw)}; +} +template +HWY_API Vec512 VecFromMask(Mask512 m) { + return Vec512{_mm512_movm_epi16(m.raw)}; +} +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 VecFromMask(Mask512 m) { + return Vec512{_mm512_castsi512_ph(_mm512_movm_epi16(m.raw))}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec512 VecFromMask(Mask512 m) { + return Vec512{_mm512_movm_epi32(m.raw)}; +} +template +HWY_API Vec512 VecFromMask(Mask512 m) { + return Vec512{_mm512_movm_epi64(m.raw)}; +} +template +HWY_API Vec512 VecFromMask(Mask512 m) { + const Full512 d; + const Full512> di; + return BitCast(d, VecFromMask(RebindMask(di, m))); +} + +// ------------------------------ Mask logical + +namespace detail { + +template +HWY_INLINE Mask512 Not(hwy::SizeTag<1> /*tag*/, Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask64(m.raw)}; +#else + return Mask512{~m.raw}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<2> /*tag*/, Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask32(m.raw)}; +#else + return Mask512{~m.raw}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<4> /*tag*/, Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask16(m.raw)}; +#else + return Mask512{static_cast(~m.raw & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<8> 
/*tag*/, Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask8(m.raw)}; +#else + return Mask512{static_cast(~m.raw & 0xFF)}; +#endif +} + +template +HWY_INLINE Mask512 And(hwy::SizeTag<1> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<2> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<4> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<8> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<1> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask64(a.raw, b.raw)}; +#else + return Mask512{~a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<2> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask32(a.raw, b.raw)}; +#else + return Mask512{~a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<4> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<8> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 Or(hwy::SizeTag<1> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw | b.raw}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<2> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw | b.raw}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<4> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<8> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<1> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw ^ b.raw}; +#endif +} +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<2> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw ^ b.raw}; +#endif +} +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<4> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw ^ b.raw)}; +#endif +} 
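// Illustrative sketch, not part of the patch: the #else branches in this mask
// section rely on AVX-512 mask types being plain integer bitmasks, so ordinary
// C operators are a valid fallback when the _k*_mask* intrinsics are
// unavailable. Standalone model using uint16_t as a stand-in for __mmask16
// (one bit per 32-bit lane of a 512-bit vector); names are for illustration.
#include <cstdint>

using Mask16 = uint16_t;  // stand-in for __mmask16

inline Mask16 Not16(Mask16 m) {
  // Matches the non-intrinsic Not fallback above: complement, kept to 16 bits.
  return static_cast<Mask16>(~m & 0xFFFF);
}

inline Mask16 AndNot16(Mask16 a, Mask16 b) {
  // Matches the _kandn_mask16 path: bits of b whose position is clear in a.
  return static_cast<Mask16>(~a & b);
}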
+template +HWY_INLINE Mask512 Xor(hwy::SizeTag<8> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask64(a.raw, b.raw)}; +#else + return Mask512{~(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask32(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} + +} // namespace detail + +template +HWY_API Mask512 Not(Mask512 m) { + return detail::Not(hwy::SizeTag(), m); +} + +template +HWY_API Mask512 And(Mask512 a, Mask512 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 AndNot(Mask512 a, Mask512 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 Or(Mask512 a, Mask512 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 Xor(Mask512 a, Mask512 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 ExclusiveNeither(Mask512 a, Mask512 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask64 combined_mask = _mm512_kunpackd( + static_cast<__mmask64>(hi.raw), static_cast<__mmask64>(lo.raw)); +#else + const __mmask64 combined_mask = static_cast<__mmask64>( + ((static_cast(hi.raw) << 32) | (lo.raw & 0xFFFFFFFFULL))); +#endif + + return MFromD{combined_mask}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask64(static_cast<__mmask64>(m.raw), 32); +#else + const auto shifted_mask = static_cast(m.raw) >> 32; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD SlideMask1Up(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftli_mask64(static_cast<__mmask64>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) << 1)}; +#endif +} + +template +HWY_API MFromD SlideMask1Down(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftri_mask64(static_cast<__mmask64>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) >> 1)}; +#endif +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec512 BroadcastSignBit(Vec512 v) { +#if HWY_TARGET <= HWY_AVX3_DL + const Repartition> du64; + return detail::GaloisAffine(v, Set(du64, 0x8080808080808080ull)); +#else + const DFromV d; + return 
VecFromMask(v < Zero(d)); +#endif +} + +HWY_API Vec512 BroadcastSignBit(Vec512 v) { + return ShiftRight<15>(v); +} + +HWY_API Vec512 BroadcastSignBit(Vec512 v) { + return ShiftRight<31>(v); +} + +HWY_API Vec512 BroadcastSignBit(Vec512 v) { + return ShiftRight<63>(v); +} + +// ------------------------------ Floating-point classification (Not) + +#if HWY_HAVE_FLOAT16 || HWY_IDE + +namespace detail { + +template +__mmask32 Fix_mm512_fpclass_ph_mask(__m512h v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1500 + // GCC's _mm512_cmp_ph_mask uses `__mmask8` instead of `__mmask32`, hence only + // the first 8 lanes are set. + return static_cast<__mmask32>(__builtin_ia32_fpclassph512_mask( + static_cast<__v32hf>(v), kCategories, static_cast<__mmask32>(-1))); +#else + return _mm512_fpclass_ph_mask(v, kCategories); +#endif +} + +} // namespace detail + +HWY_API Mask512 IsNaN(Vec512 v) { + constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN; + return Mask512{ + detail::Fix_mm512_fpclass_ph_mask(v.raw)}; +} + +HWY_API Mask512 IsEitherNaN(Vec512 a, + Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)}; + HWY_DIAGNOSTICS(pop) +} + +HWY_API Mask512 IsInf(Vec512 v) { + constexpr int kCategories = HWY_X86_FPCLASS_POS_INF | HWY_X86_FPCLASS_NEG_INF; + return Mask512{ + detail::Fix_mm512_fpclass_ph_mask(v.raw)}; +} + +// Returns whether normal/subnormal/zero. fpclass doesn't have a flag for +// positive, so we have to check for inf/NaN and negate. +HWY_API Mask512 IsFinite(Vec512 v) { + constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF; + return Not(Mask512{ + detail::Fix_mm512_fpclass_ph_mask(v.raw)}); +} + +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask512 IsNaN(Vec512 v) { + return Mask512{_mm512_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +} +HWY_API Mask512 IsNaN(Vec512 v) { + return Mask512{_mm512_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +} + +HWY_API Mask512 IsEitherNaN(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +} + +HWY_API Mask512 IsEitherNaN(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +} + +HWY_API Mask512 IsInf(Vec512 v) { + return Mask512{_mm512_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} +HWY_API Mask512 IsInf(Vec512 v) { + return Mask512{_mm512_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} + +// Returns whether normal/subnormal/zero. fpclass doesn't have a flag for +// positive, so we have to check for inf/NaN and negate. 
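// Illustrative sketch, not part of the patch: the vector IsFinite below follows
// the scalar identity "finite == !(NaN || +/-Inf)", because _mm512_fpclass_ps/pd
// provide categories for NaN and the two infinities but none for "finite".
#include <cmath>

inline bool IsFiniteScalar(double x) {
  return !(std::isnan(x) || std::isinf(x));  // same set as std::isfinite(x)
}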
+HWY_API Mask512 IsFinite(Vec512 v) { + return Not(Mask512{_mm512_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} +HWY_API Mask512 IsFinite(Vec512 v) { + return Not(Mask512{_mm512_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { + return VFromD{_mm512_load_si512(aligned)}; +} +// bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 Load(D /* tag */, + const float16_t* HWY_RESTRICT aligned) { + return Vec512{_mm512_load_ph(aligned)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec512 Load(D /* tag */, const float* HWY_RESTRICT aligned) { + return Vec512{_mm512_load_ps(aligned)}; +} +template +HWY_API VFromD Load(D /* tag */, const double* HWY_RESTRICT aligned) { + return VFromD{_mm512_load_pd(aligned)}; +} + +template +HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_loadu_si512(p)}; +} + +// bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { + return Vec512{_mm512_loadu_ph(p)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec512 LoadU(D /* tag */, const float* HWY_RESTRICT p) { + return Vec512{_mm512_loadu_ps(p)}; +} +template +HWY_API VFromD LoadU(D /* tag */, const double* HWY_RESTRICT p) { + return VFromD{_mm512_loadu_pd(p)}; +} + +// ------------------------------ MaskedLoad + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm512_maskz_loadu_epi16( + m.raw, reinterpret_cast(p))}); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_maskz_loadu_epi64(m.raw, p)}; +} + +template +HWY_API Vec512 MaskedLoad(Mask512 m, D /* tag */, + const float* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_ps(m.raw, p)}; +} + +template +HWY_API Vec512 MaskedLoad(Mask512 m, D /* tag */, + const double* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_pd(m.raw, p)}; +} + +// ------------------------------ MaskedLoadOr + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_mask_loadu_epi8(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast( + d, VFromD{_mm512_mask_loadu_epi16( + BitCast(du, v).raw, m.raw, reinterpret_cast(p))}); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_mask_loadu_epi32(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_mask_loadu_epi64(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD 
MaskedLoadOr(VFromD v, Mask512 m, D /* tag */, + const float* HWY_RESTRICT p) { + return VFromD{_mm512_mask_loadu_ps(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, Mask512 m, D /* tag */, + const double* HWY_RESTRICT p) { + return VFromD{_mm512_mask_loadu_pd(v.raw, m.raw, p)}; +} + +// ------------------------------ LoadDup128 + +// Loads 128 bit and duplicates into both 128-bit halves. This avoids the +// 3-cycle cost of moving data between 128-bit halves and avoids port 5. +template +HWY_API VFromD LoadDup128(D d, const TFromD* const HWY_RESTRICT p) { + const RebindToUnsigned du; + const Full128> d128; + const RebindToUnsigned du128; + return BitCast(d, VFromD{_mm512_broadcast_i32x4( + BitCast(du128, LoadU(d128, p)).raw)}); +} +template +HWY_API VFromD LoadDup128(D /* tag */, const float* HWY_RESTRICT p) { + const __m128 x4 = _mm_loadu_ps(p); + return VFromD{_mm512_broadcast_f32x4(x4)}; +} + +template +HWY_API VFromD LoadDup128(D /* tag */, const double* HWY_RESTRICT p) { + const __m128d x2 = _mm_loadu_pd(p); + return VFromD{_mm512_broadcast_f64x2(x2)}; +} + +// ------------------------------ Store + +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); +} +// bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 +template +HWY_API void Store(Vec512 v, D /* tag */, + float16_t* HWY_RESTRICT aligned) { + _mm512_store_ph(aligned, v.raw); +} +#endif +template +HWY_API void Store(Vec512 v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm512_store_ps(aligned, v.raw); +} +template +HWY_API void Store(VFromD v, D /* tag */, double* HWY_RESTRICT aligned) { + _mm512_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); +} +// bfloat16_t is handled by x86_128-inl.h. 
+#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec512 v, D /* tag */, + float16_t* HWY_RESTRICT p) { + _mm512_storeu_ph(p, v.raw); +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API void StoreU(Vec512 v, D /* tag */, float* HWY_RESTRICT p) { + _mm512_storeu_ps(p, v.raw); +} +template +HWY_API void StoreU(Vec512 v, D /* tag */, double* HWY_RESTRICT p) { + _mm512_storeu_pd(p, v.raw); +} + +// ------------------------------ BlendedStore + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm512_mask_storeu_epi8(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + _mm512_mask_storeu_epi16(reinterpret_cast(p), m.raw, + BitCast(du, v).raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm512_mask_storeu_epi32(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm512_mask_storeu_epi64(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, D /* tag */, + float* HWY_RESTRICT p) { + _mm512_mask_storeu_ps(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, D /* tag */, + double* HWY_RESTRICT p) { + _mm512_mask_storeu_pd(p, m.raw, v.raw); +} + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; // for float16_t + _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), BitCast(du, v).raw); +} +template +HWY_API void Stream(VFromD v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm512_stream_ps(aligned, v.raw); +} +template +HWY_API void Stream(VFromD v, D /* tag */, double* HWY_RESTRICT aligned) { + _mm512_stream_pd(aligned, v.raw); +} + +// ------------------------------ ScatterOffset + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> offset) { + _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> offset) { + _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, float* HWY_RESTRICT base, + Vec512 offset) { + _mm512_i32scatter_ps(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, double* HWY_RESTRICT base, + Vec512 offset) { + _mm512_i64scatter_pd(base, offset.raw, v.raw, 1); +} + +// ------------------------------ ScatterIndex + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm512_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm512_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, float* HWY_RESTRICT base, + Vec512 index) { + _mm512_i32scatter_ps(base, index.raw, v.raw, 4); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, double* HWY_RESTRICT base, + Vec512 index) { + _mm512_i64scatter_pd(base, index.raw, v.raw, 8); +} + +// ------------------------------ MaskedScatterIndex + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm512_mask_i32scatter_epi32(base, m.raw, index.raw, v.raw, 4); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm512_mask_i64scatter_epi64(base, m.raw, index.raw, v.raw, 8); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + float* HWY_RESTRICT base, + Vec512 index) { + _mm512_mask_i32scatter_ps(base, m.raw, index.raw, v.raw, 4); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + double* HWY_RESTRICT base, + Vec512 index) { + _mm512_mask_i64scatter_pd(base, m.raw, index.raw, v.raw, 8); +} + +// ------------------------------ Gather + +namespace detail { + +template +HWY_INLINE Vec512 NativeGather512(const T* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i32gather_epi32(indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeGather512(const T* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i64gather_epi64(indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeGather512(const float* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i32gather_ps(indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeGather512(const double* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i64gather_pd(indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeMaskedGatherOr512(Vec512 no, Mask512 m, + const T* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{ + _mm512_mask_i32gather_epi32(no.raw, m.raw, indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeMaskedGatherOr512(Vec512 no, Mask512 m, + const T* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{ + _mm512_mask_i64gather_epi64(no.raw, m.raw, indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeMaskedGatherOr512(Vec512 no, + Mask512 m, + const 
float* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{ + _mm512_mask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeMaskedGatherOr512( + Vec512 no, Mask512 m, const double* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{ + _mm512_mask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; +} +} // namespace detail + +template +HWY_API VFromD GatherOffset(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> offsets) { + return detail::NativeGather512<1>(base, offsets); +} + +template +HWY_API VFromD GatherIndex(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeGather512)>(base, indices); +} + +template +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D /*d*/, + const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeMaskedGatherOr512)>(no, m, base, + indices); +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{_mm512_castsi512_si256(v.raw)}; +} +template +HWY_API VFromD LowerHalf(D /* tag */, Vec512 v) { + return VFromD{_mm512_castsi512_si256(v.raw)}; +} +template +HWY_API VFromD LowerHalf(D /* tag */, Vec512 v) { +#if HWY_HAVE_FLOAT16 + return VFromD{_mm512_castph512_ph256(v.raw)}; +#else + return VFromD{_mm512_castsi512_si256(v.raw)}; +#endif // HWY_HAVE_FLOAT16 +} +template +HWY_API VFromD LowerHalf(D /* tag */, Vec512 v) { + return VFromD{_mm512_castps512_ps256(v.raw)}; +} +template +HWY_API VFromD LowerHalf(D /* tag */, Vec512 v) { + return VFromD{_mm512_castpd512_pd256(v.raw)}; +} + +template +HWY_API Vec256 LowerHalf(Vec512 v) { + const Half> dh; + return LowerHalf(dh, v); +} + +// ------------------------------ UpperHalf + +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + const RebindToUnsigned du; // for float16_t + const Twice dut; + return BitCast(d, VFromD{ + _mm512_extracti32x8_epi32(BitCast(dut, v).raw, 1)}); +} +template +HWY_API VFromD UpperHalf(D /* tag */, VFromD> v) { + return VFromD{_mm512_extractf32x8_ps(v.raw, 1)}; +} +template +HWY_API VFromD UpperHalf(D /* tag */, VFromD> v) { + return VFromD{_mm512_extractf64x4_pd(v.raw, 1)}; +} + +// ------------------------------ ExtractLane (Store) +template +HWY_API T ExtractLane(const Vec512 v, size_t i) { + const DFromV d; + HWY_DASSERT(i < Lanes(d)); + +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + if (__builtin_constant_p(i < kLanesPerBlock) && (i < kLanesPerBlock)) { + return ExtractLane(ResizeBitCast(Full128(), v), i); + } +#endif + + alignas(64) T lanes[MaxLanes(d)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ ExtractBlock +template * = nullptr> +HWY_API Vec128 ExtractBlock(Vec512 v) { + const DFromV d; + const Half dh; + return ExtractBlock(LowerHalf(dh, v)); +} + +template 1)>* = nullptr> +HWY_API Vec128 ExtractBlock(Vec512 v) { + static_assert(kBlockIdx <= 3, "Invalid block index"); + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(Full128(), + Vec128>{ + _mm512_extracti32x4_epi32(BitCast(du, v).raw, kBlockIdx)}); +} + +template 1)>* = nullptr> +HWY_API Vec128 ExtractBlock(Vec512 v) { + static_assert(kBlockIdx <= 3, "Invalid block index"); + return Vec128{_mm512_extractf32x4_ps(v.raw, kBlockIdx)}; +} + +template 1)>* = nullptr> +HWY_API Vec128 ExtractBlock(Vec512 v) { + 
static_assert(kBlockIdx <= 3, "Invalid block index"); + return Vec128{_mm512_extractf64x2_pd(v.raw, kBlockIdx)}; +} + +// ------------------------------ InsertLane (Store) +template +HWY_API Vec512 InsertLane(const Vec512 v, size_t i, T t) { + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +// ------------------------------ InsertBlock +namespace detail { + +template +HWY_INLINE Vec512 InsertBlock(hwy::SizeTag<0> /* blk_idx_tag */, Vec512 v, + Vec128 blk_to_insert) { + const DFromV d; + const auto insert_mask = FirstN(d, 16 / sizeof(T)); + return IfThenElse(insert_mask, ResizeBitCast(d, blk_to_insert), v); +} + +template +HWY_INLINE Vec512 InsertBlock(hwy::SizeTag /* blk_idx_tag */, + Vec512 v, Vec128 blk_to_insert) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + const Full128> du_blk_to_insert; + return BitCast( + d, VFromD{_mm512_inserti32x4( + BitCast(du, v).raw, BitCast(du_blk_to_insert, blk_to_insert).raw, + static_cast(kBlockIdx & 3))}); +} + +template * = nullptr> +HWY_INLINE Vec512 InsertBlock(hwy::SizeTag /* blk_idx_tag */, + Vec512 v, + Vec128 blk_to_insert) { + return Vec512{_mm512_insertf32x4(v.raw, blk_to_insert.raw, + static_cast(kBlockIdx & 3))}; +} + +template * = nullptr> +HWY_INLINE Vec512 InsertBlock(hwy::SizeTag /* blk_idx_tag */, + Vec512 v, + Vec128 blk_to_insert) { + return Vec512{_mm512_insertf64x2(v.raw, blk_to_insert.raw, + static_cast(kBlockIdx & 3))}; +} + +} // namespace detail + +template +HWY_API Vec512 InsertBlock(Vec512 v, Vec128 blk_to_insert) { + static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); + return detail::InsertBlock(hwy::SizeTag(kBlockIdx)>(), v, + blk_to_insert); +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec512 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ZeroExtendVector + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { +#if HWY_HAVE_ZEXT // See definition/comment in x86_256-inl.h. 
+ (void)d; + return VFromD{_mm512_zextsi256_si512(lo.raw)}; +#else + return VFromD{_mm512_inserti32x8(Zero(d).raw, lo.raw, 0)}; +#endif +} +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { +#if HWY_HAVE_ZEXT + (void)d; + return VFromD{_mm512_zextph256_ph512(lo.raw)}; +#else + const RebindToUnsigned du; + return BitCast(d, ZeroExtendVector(du, BitCast(du, lo))); +#endif +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { +#if HWY_HAVE_ZEXT + (void)d; + return VFromD{_mm512_zextps256_ps512(lo.raw)}; +#else + return VFromD{_mm512_insertf32x8(Zero(d).raw, lo.raw, 0)}; +#endif +} +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { +#if HWY_HAVE_ZEXT + (void)d; + return VFromD{_mm512_zextpd256_pd512(lo.raw)}; +#else + return VFromD{_mm512_insertf64x4(Zero(d).raw, lo.raw, 0)}; +#endif +} + +// ------------------------------ ZeroExtendResizeBitCast + +namespace detail { + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD v) { + const Repartition du8_from; + const auto vu8 = BitCast(du8_from, v); + const RebindToUnsigned du_to; +#if HWY_HAVE_ZEXT + return BitCast(d_to, + VFromD{_mm512_zextsi128_si512(vu8.raw)}); +#else + return BitCast(d_to, VFromD{ + _mm512_inserti32x4(Zero(du_to).raw, vu8.raw, 0)}); +#endif +} + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD v) { + const Repartition df32_from; + const auto vf32 = BitCast(df32_from, v); +#if HWY_HAVE_ZEXT + (void)d_to; + return Vec512{_mm512_zextps128_ps512(vf32.raw)}; +#else + return Vec512{_mm512_insertf32x4(Zero(d_to).raw, vf32.raw, 0)}; +#endif +} + +template +HWY_INLINE Vec512 ZeroExtendResizeBitCast( + hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD v) { + const Repartition df64_from; + const auto vf64 = BitCast(df64_from, v); +#if HWY_HAVE_ZEXT + (void)d_to; + return Vec512{_mm512_zextpd128_pd512(vf64.raw)}; +#else + return Vec512{_mm512_insertf64x2(Zero(d_to).raw, vf64.raw, 0)}; +#endif +} + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag<8> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD v) { + const Twice dt_from; + return ZeroExtendResizeBitCast(hwy::SizeTag<16>(), hwy::SizeTag<64>(), d_to, + dt_from, ZeroExtendVector(dt_from, v)); +} + +} // namespace detail + +// ------------------------------ Combine + +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { + const RebindToUnsigned du; // for float16_t + const Half duh; + const __m512i lo512 = ZeroExtendVector(du, BitCast(duh, lo)).raw; + return BitCast(d, VFromD{ + _mm512_inserti32x8(lo512, BitCast(duh, hi).raw, 1)}); +} +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { + return VFromD{_mm512_insertf32x8(ZeroExtendVector(d, lo).raw, hi.raw, 1)}; +} +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { + return VFromD{_mm512_insertf64x4(ZeroExtendVector(d, lo).raw, hi.raw, 1)}; +} + +// ------------------------------ ShiftLeftBytes +template +HWY_API VFromD ShiftLeftBytes(D /* tag */, const VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return VFromD{_mm512_bslli_epi128(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D /* tag 
*/, const VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return VFromD{_mm512_bsrli_epi128(v.raw, kBytes)}; +} + +// ------------------------------ CombineShiftRightBytes + +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + const Repartition d8; + return BitCast(d, Vec512{_mm512_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m512i lo = _mm512_shufflelo_epi16(vu.raw, (0x55 * kLane) & 0xFF); + return BitCast(d, VU{_mm512_unpacklo_epi64(lo, lo)}); + } else { + const __m512i hi = + _mm512_shufflehi_epi16(vu.raw, (0x55 * (kLane - 4)) & 0xFF); + return BitCast(d, VU{_mm512_unpackhi_epi64(hi, hi)}); + } +} + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, perm)}; +} + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane); + return Vec512{_mm512_shuffle_pd(v.raw, v.raw, perm)}; +} + +// ------------------------------ BroadcastBlock +template +HWY_API Vec512 BroadcastBlock(Vec512 v) { + static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast( + d, VFromD{_mm512_shuffle_i32x4( + BitCast(du, v).raw, BitCast(du, v).raw, 0x55 * kBlockIdx)}); +} + +template +HWY_API Vec512 BroadcastBlock(Vec512 v) { + static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); + return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, 0x55 * kBlockIdx)}; +} + +template +HWY_API Vec512 BroadcastBlock(Vec512 v) { + static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); + return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, 0x55 * kBlockIdx)}; +} + +// ------------------------------ BroadcastLane + +namespace detail { + +template +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + return Vec512{_mm512_broadcastb_epi8(ResizeBitCast(Full128(), v).raw)}; +} + +template +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm512_broadcastw_epi16( + ResizeBitCast(Full128(), v).raw)}); +} + +template +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + return Vec512{_mm512_broadcastd_epi32(ResizeBitCast(Full128(), v).raw)}; +} + +template +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + return 
Vec512{_mm512_broadcastq_epi64(ResizeBitCast(Full128(), v).raw)}; +} + +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + return Vec512{ + _mm512_broadcastss_ps(ResizeBitCast(Full128(), v).raw)}; +} + +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + return Vec512{ + _mm512_broadcastsd_pd(ResizeBitCast(Full128(), v).raw)}; +} + +template * = nullptr> +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag /* lane_idx_tag */, + Vec512 v) { + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + constexpr int kBlockIdx = static_cast(kLaneIdx / kLanesPerBlock); + constexpr int kLaneInBlkIdx = + static_cast(kLaneIdx) & (kLanesPerBlock - 1); + return Broadcast(BroadcastBlock(v)); +} + +} // namespace detail + +template +HWY_API Vec512 BroadcastLane(Vec512 v) { + static_assert(0 <= kLaneIdx, "Invalid lane"); + return detail::BroadcastLane(hwy::SizeTag(kLaneIdx)>(), + v); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec512 have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec512 Shuffle2301(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; +} +HWY_API Vec512 Shuffle2301(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +namespace detail { + +template +HWY_API Vec512 ShuffleTwo2301(const Vec512 a, const Vec512 b) { + const DFromV d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_CDAB)}); +} +template +HWY_API Vec512 ShuffleTwo1230(const Vec512 a, const Vec512 b) { + const DFromV d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_BCDA)}); +} +template +HWY_API Vec512 ShuffleTwo3012(const Vec512 a, const Vec512 b) { + const DFromV d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_DABC)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec512 Shuffle1032(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle1032(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle1032(const Vec512 v) { + // Shorter encoding than _mm512_permute_ps. + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + // Shorter encoding than _mm512_permute_pd. 
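// [Editor's note: illustrative annotation, not part of the vendored Highway
// sources.] _MM_PERM_BBBB encodes 0x55 (binary 01 in each 2-bit field), so for
// each 128-bit block the shuffle writes element 1 to the low output qword and
// element 0 to the high output qword, i.e. it swaps the two adjacent doubles.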
+ return Vec512{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)}; +} + +// Rotate right 32 bits +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; +} +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; +} +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)}; +} +// Rotate left 32 bits +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; +} +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; +} +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)}; +} + +// Reverse +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. +template +struct Indices512 { + __m512i raw; +}; + +template , typename TI> +HWY_API Indices512 IndicesFromVec(D /* tag */, Vec512 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const DFromV di; + const RebindToUnsigned du; + using TU = MakeUnsigned; + const auto vec_u = BitCast(du, vec); + HWY_DASSERT( + AllTrue(du, Lt(vec_u, Set(du, static_cast(128 / sizeof(T)))))); +#endif + return Indices512{vec.raw}; +} + +template +HWY_API Indices512> SetTableIndices(D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_permutexvar_epi8(idx.raw, v.raw)}; +#else + const DFromV d; + const Repartition du16; + const Vec512 idx_vec{idx.raw}; + + const auto bd_sel_mask = + MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, idx_vec)))); + const auto cd_sel_mask = + MaskFromVec(BitCast(d, ShiftLeft<2>(BitCast(du16, idx_vec)))); + + const Vec512 v_a{_mm512_shuffle_i32x4(v.raw, v.raw, 0x00)}; + const Vec512 v_b{_mm512_shuffle_i32x4(v.raw, v.raw, 0x55)}; + const Vec512 v_c{_mm512_shuffle_i32x4(v.raw, v.raw, 0xAA)}; + const Vec512 v_d{_mm512_shuffle_i32x4(v.raw, v.raw, 0xFF)}; + + const auto shuf_a = TableLookupBytes(v_a, idx_vec); + const auto shuf_c = TableLookupBytes(v_c, idx_vec); + const Vec512 shuf_ab{_mm512_mask_shuffle_epi8(shuf_a.raw, bd_sel_mask.raw, + v_b.raw, idx_vec.raw)}; + const Vec512 shuf_cd{_mm512_mask_shuffle_epi8(shuf_c.raw, bd_sel_mask.raw, + v_d.raw, idx_vec.raw)}; + return IfThenElse(cd_sel_mask, shuf_cd, shuf_ab); +#endif +} + +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_epi16(idx.raw, v.raw)}; +} +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 TableLookupLanes(Vec512 v, + Indices512 idx) { + return Vec512{_mm512_permutexvar_ph(idx.raw, v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_epi32(idx.raw, v.raw)}; +} + +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_epi64(idx.raw, v.raw)}; +} + +HWY_API Vec512 TableLookupLanes(Vec512 
v, Indices512 idx) { + return Vec512{_mm512_permutexvar_ps(idx.raw, v.raw)}; +} + +HWY_API Vec512 TableLookupLanes(Vec512 v, + Indices512 idx) { + return Vec512{_mm512_permutexvar_pd(idx.raw, v.raw)}; +} + +template +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_permutex2var_epi8(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const auto b_sel_mask = + MaskFromVec(BitCast(d, ShiftLeft<1>(Vec512{idx.raw}))); + return IfThenElse(b_sel_mask, TableLookupLanes(b, idx), + TableLookupLanes(a, idx)); +#endif +} + +template +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_epi16(a.raw, idx.raw, b.raw)}; +} + +template +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_epi32(a.raw, idx.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, + Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_ph(a.raw, idx.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_ps(a.raw, idx.raw, b.raw)}; +} + +template +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_epi64(a.raw, idx.raw, b.raw)}; +} + +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_pd(a.raw, idx.raw, b.raw)}; +} + +// ------------------------------ Reverse + +template +HWY_API VFromD Reverse(D d, const VFromD v) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToSigned di; + alignas(64) static constexpr int8_t kReverse[64] = { + 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, + 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec512 idx = Load(di, kReverse); + return BitCast( + d, Vec512{_mm512_permutexvar_epi8(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide d16; + return BitCast(d, Reverse(d16, RotateRight<8>(BitCast(d16, v)))); +#endif +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + const RebindToSigned di; + alignas(64) static constexpr int16_t kReverse[32] = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec512 idx = Load(di, kReverse); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(64) static constexpr int32_t kReverse[16] = { + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(64) static constexpr int64_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +// ------------------------------ Reverse2 (in x86_128) + +// ------------------------------ Reverse4 + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToSigned di; + alignas(64) static constexpr int16_t kReverse4[32] = { + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, + 19, 18, 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28}; + const Vec512 idx = Load(di, kReverse4); + return BitCast(d, Vec512{ + 
_mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +// 32 bit Reverse4 defined in x86_128. + +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + return VFromD{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} +template +HWY_API VFromD Reverse4(D /* tag */, VFromD v) { + return VFromD{_mm512_permutex_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToSigned di; + alignas(64) static constexpr int16_t kReverse8[32] = { + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}; + const Vec512 idx = Load(di, kReverse8); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToSigned di; + alignas(64) static constexpr int32_t kReverse8[16] = { + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec512 idx = Load(di, kReverse8); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi32(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + return Reverse(d, v); +} + +// ------------------------------ ReverseBits (GaloisAffine) + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_REVERSE_BITS_UI8 +#undef HWY_NATIVE_REVERSE_BITS_UI8 +#else +#define HWY_NATIVE_REVERSE_BITS_UI8 +#endif + +// Generic for all vector lengths. Must be defined after all GaloisAffine. +template +HWY_API V ReverseBits(V v) { + const Repartition> du64; + return detail::GaloisAffine(v, Set(du64, 0x8040201008040201u)); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ InterleaveLower + +template +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + return Vec512{_mm512_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm512_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + return Vec512{_mm512_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + return Vec512{_mm512_unpacklo_epi64(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + return Vec512{_mm512_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + return Vec512{_mm512_unpacklo_pd(a.raw, b.raw)}; +} + +// ------------------------------ InterleaveUpper + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm512_unpackhi_epi8(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm512_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm512_unpackhi_epi32(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm512_unpackhi_epi64(a.raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm512_unpackhi_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm512_unpackhi_pd(a.raw, b.raw)}; +} + +// 
------------------------------ Concat* halves + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_BABA)}); +} +template +HWY_API VFromD ConcatLowerLower(D /* tag */, VFromD hi, VFromD lo) { + return VFromD{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; +} +template +HWY_API Vec512 ConcatLowerLower(D /* tag */, Vec512 hi, + Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_DCDC)}); +} +template +HWY_API VFromD ConcatUpperUpper(D /* tag */, VFromD hi, VFromD lo) { + return VFromD{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} +template +HWY_API Vec512 ConcatUpperUpper(D /* tag */, Vec512 hi, + Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_BADC)}); +} +template +HWY_API VFromD ConcatLowerUpper(D /* tag */, VFromD hi, VFromD lo) { + return VFromD{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; +} +template +HWY_API Vec512 ConcatLowerUpper(D /* tag */, Vec512 hi, + Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BADC)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + // There are no imm8 blend in AVX512. Use blend16 because 32-bit masks + // are efficiently loaded from 32-bit regs. + const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF); + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm512_mask_blend_epi16( + mask, BitCast(du, hi).raw, BitCast(du, lo).raw)}); +} +template +HWY_API VFromD ConcatUpperLower(D /* tag */, VFromD hi, VFromD lo) { + const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF); + return VFromD{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)}; +} +template +HWY_API Vec512 ConcatUpperLower(D /* tag */, Vec512 hi, + Vec512 lo) { + const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F); + return Vec512{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)}; +} + +// ------------------------------ ConcatOdd + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(64) static constexpr uint8_t kIdx[64] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, + 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, + 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, + 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, + 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Right-shift 8 bits per u16 so we can pack. 
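// [Editor's note: illustrative annotation, not part of the vendored Highway
// sources.] Each u16 lane holds a byte pair; the odd-indexed u8 lane is its
// high byte. Shifting right by 8 moves that byte into the low half and zeroes
// the rest, so the following unsigned-saturating pack is exact. For example, a
// u16 lane 0xABCD (odd byte 0xAB, even byte 0xCD) becomes 0x00AB and packs to
// 0xAB.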
+ const Vec512 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec512 uL = ShiftRight<8>(BitCast(dw, lo)); + const Vec512 u8{_mm512_packus_epi16(uL.raw, uH.raw)}; + // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. + const Full512 du64; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint16_t kIdx[32] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; + return VFromD{_mm512_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return VFromD{_mm512_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; +} + +// ------------------------------ ConcatEven + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(64) static constexpr uint8_t kIdx[64] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, + 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, + 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, + 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, + 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec512 mask = Set(dw, 0x00FF); + const Vec512 uH = And(BitCast(dw, hi), mask); + const Vec512 uL = And(BitCast(dw, lo), mask); + const Vec512 u8{_mm512_packus_epi16(uL.raw, uH.raw)}; + // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. 
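// [Editor's note: illustrative annotation, not part of the vendored Highway
// sources.] The 512-bit pack instructions operate per 128-bit block, so after
// packing, qwords 0/2/4/6 hold bytes derived from `lo` and qwords 1/3/5/7 hold
// bytes derived from `hi`. The {0, 2, 4, 6, 1, 3, 5, 7} permute below gathers
// the `lo` results into the lower 256 bits and the `hi` results into the upper
// 256 bits, as ConcatEven requires.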
+ const Full512 du64; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint16_t kIdx[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + return VFromD{_mm512_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return VFromD{_mm512_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; +} + +// ------------------------------ InterleaveWholeLower + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(64) static constexpr uint8_t kIdx[64] = { + 0, 64, 1, 65, 2, 66, 3, 67, 4, 68, 5, 69, 6, 70, 7, 71, + 8, 72, 9, 73, 10, 74, 11, 75, 12, 76, 13, 77, 14, 78, 15, 79, + 16, 80, 17, 81, 18, 82, 19, 83, 20, 84, 21, 85, 22, 86, 23, 87, + 24, 88, 25, 89, 26, 90, 27, 91, 28, 92, 29, 93, 30, 94, 31, 95}; + return VFromD{_mm512_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + alignas(64) static constexpr uint64_t kIdx2[8] = {0, 1, 8, 9, 2, 3, 10, 11}; + const Repartition du64; + return VFromD{_mm512_permutex2var_epi64(InterleaveLower(a, b).raw, + Load(du64, kIdx2).raw, + InterleaveUpper(d, a, b).raw)}; +#endif +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint16_t kIdx[32] = { + 0, 32, 1, 33, 2, 34, 3, 35, 4, 36, 5, 37, 6, 38, 7, 39, + 8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47}; + return BitCast( + d, VFromD{_mm512_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, + 4, 20, 5, 21, 6, 22, 7, 23}; + return VFromD{_mm512_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, + 4, 20, 5, 21, 6, 22, 7, 23}; + return 
VFromD{_mm512_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm512_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm512_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +// ------------------------------ InterleaveWholeUpper + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(64) static constexpr uint8_t kIdx[64] = { + 32, 96, 33, 97, 34, 98, 35, 99, 36, 100, 37, 101, 38, 102, 39, 103, + 40, 104, 41, 105, 42, 106, 43, 107, 44, 108, 45, 109, 46, 110, 47, 111, + 48, 112, 49, 113, 50, 114, 51, 115, 52, 116, 53, 117, 54, 118, 55, 119, + 56, 120, 57, 121, 58, 122, 59, 123, 60, 124, 61, 125, 62, 126, 63, 127}; + return VFromD{_mm512_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + alignas(64) static constexpr uint64_t kIdx2[8] = {4, 5, 12, 13, 6, 7, 14, 15}; + const Repartition du64; + return VFromD{_mm512_permutex2var_epi64(InterleaveLower(a, b).raw, + Load(du64, kIdx2).raw, + InterleaveUpper(d, a, b).raw)}; +#endif +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint16_t kIdx[32] = { + 16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55, + 24, 56, 25, 57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63}; + return BitCast( + d, VFromD{_mm512_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; + return VFromD{_mm512_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; + return VFromD{_mm512_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm512_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm512_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec512 DupEven(Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CCAA)}; +} +HWY_API Vec512 DupEven(Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CCAA)}; +} + +template +HWY_API Vec512 DupEven(const Vec512 v) { + const DFromV d; + return InterleaveLower(d, v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec512 DupOdd(Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, 
_MM_PERM_DDBB)}; +} +HWY_API Vec512 DupOdd(Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_DDBB)}; +} + +template +HWY_API Vec512 DupOdd(const Vec512 v) { + const DFromV d; + return InterleaveUpper(d, v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_API Vec512 OddEven(const Vec512 a, const Vec512 b) { + constexpr size_t s = sizeof(T); + constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56; + return IfThenElse(Mask512{0x5555555555555555ull >> shift}, b, a); +} + +// -------------------------- InterleaveEven + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_epi32( + a.raw, static_cast<__mmask16>(0xAAAA), b.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_ps(a.raw, static_cast<__mmask16>(0xAAAA), + b.raw, b.raw, + _MM_SHUFFLE(2, 2, 0, 0))}; +} +// -------------------------- InterleaveOdd + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_epi32( + b.raw, static_cast<__mmask16>(0x5555), a.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_ps(b.raw, static_cast<__mmask16>(0x5555), + a.raw, a.raw, + _MM_SHUFFLE(3, 3, 1, 1))}; +} + +// ------------------------------ OddEvenBlocks + +template +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast( + d, VFromD{_mm512_mask_blend_epi64( + __mmask8{0x33u}, BitCast(du, odd).raw, BitCast(du, even).raw)}); +} + +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + return Vec512{ + _mm512_mask_blend_ps(__mmask16{0x0F0Fu}, odd.raw, even.raw)}; +} + +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + return Vec512{ + _mm512_mask_blend_pd(__mmask8{0x33u}, odd.raw, even.raw)}; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, v).raw, BitCast(du, v).raw, _MM_PERM_CDAB)}); +} + +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +// ------------------------------ InterleaveEvenBlocks +template +HWY_API Vec512 InterleaveEvenBlocks(Full512 d, Vec512 a, Vec512 b) { + return OddEvenBlocks(SlideUpBlocks<1>(d, b), a); +} + +// ------------------------------ InterleaveOddBlocks (ConcatUpperUpper) +template +HWY_API Vec512 InterleaveOddBlocks(Full512 d, Vec512 a, Vec512 b) { + return OddEvenBlocks(b, SlideDownBlocks<1>(d, a)); +} + +// ------------------------------ ReverseBlocks + +template +HWY_API VFromD ReverseBlocks(D d, VFromD v) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, v).raw, BitCast(du, v).raw, _MM_PERM_ABCD)}); +} +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return VFromD{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +} +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return VFromD{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_ABCD)}; +} + +// 
------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template +HWY_API Vec512 TableLookupBytes(Vec512 bytes, Vec512 indices) { + const DFromV d; + return BitCast(d, Vec512{_mm512_shuffle_epi8( + BitCast(Full512(), bytes).raw, + BitCast(Full512(), indices).raw)}); +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec512 bytes, Vec128 from) { + const Full512 d512; + const Half d256; + const Half d128; + // First expand to full 128, then 256, then 512. + const Vec128 from_full{from.raw}; + const auto from_512 = + ZeroExtendVector(d512, ZeroExtendVector(d256, from_full)); + const auto tbl_full = TableLookupBytes(bytes, from_512); + // Shrink to 256, then 128, then partial. + return Vec128{LowerHalf(d128, LowerHalf(d256, tbl_full)).raw}; +} +template +HWY_API Vec256 TableLookupBytes(Vec512 bytes, Vec256 from) { + const DFromV dih; + const Twice di; + const auto from_512 = ZeroExtendVector(di, from); + return LowerHalf(dih, TableLookupBytes(bytes, from_512)); +} + +// Partial table vector +template +HWY_API Vec512 TableLookupBytes(Vec128 bytes, Vec512 from) { + const DFromV d512; + const Half d256; + const Half d128; + // First expand to full 128, then 256, then 512. + const Vec128 bytes_full{bytes.raw}; + const auto bytes_512 = + ZeroExtendVector(d512, ZeroExtendVector(d256, bytes_full)); + return TableLookupBytes(bytes_512, from); +} +template +HWY_API Vec512 TableLookupBytes(Vec256 bytes, Vec512 from) { + const Full512 d; + return TableLookupBytes(ZeroExtendVector(d, bytes), from); +} + +// Partial both are handled by x86_128/256. + +// ------------------------------ I8/U8 Broadcast (TableLookupBytes) + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return TableLookupBytes(v, Set(Full512(), static_cast(kLane))); +} + +// ------------------------------ Per4LaneBlockShuffle + +namespace detail { + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + return BitCast(d, Vec512{_mm512_set_epi32( + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0), + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0), + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0), + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0))}); +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<64> /*vect_size_tag*/, V v) { + return V{ + _mm512_shuffle_epi32(v.raw, static_cast<_MM_PERM_ENUM>(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<64> /*vect_size_tag*/, V v) { + return V{_mm512_shuffle_ps(v.raw, v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<64> /*vect_size_tag*/, V v) { + return V{_mm512_permutex_epi64(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<64> /*vect_size_tag*/, V v) { + return V{_mm512_permutex_pd(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +template +HWY_INLINE V CombineShiftRightI32Lanes(V hi, 
V lo) { + const DFromV d; + const Repartition du32; + return BitCast(d, + Vec512{_mm512_alignr_epi32( + BitCast(du32, hi).raw, BitCast(du32, lo).raw, kI32Lanes)}); +} + +template +HWY_INLINE V CombineShiftRightI64Lanes(V hi, V lo) { + const DFromV d; + const Repartition du64; + return BitCast(d, + Vec512{_mm512_alignr_epi64( + BitCast(du64, hi).raw, BitCast(du64, lo).raw, kI64Lanes)}); +} + +template +HWY_INLINE V SlideUpI32Lanes(V v) { + static_assert(0 <= kI32Lanes && kI32Lanes <= 15, + "kI32Lanes must be between 0 and 15"); + const DFromV d; + return CombineShiftRightI32Lanes<16 - kI32Lanes>(v, Zero(d)); +} + +template +HWY_INLINE V SlideUpI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 7, + "kI64Lanes must be between 0 and 7"); + const DFromV d; + return CombineShiftRightI64Lanes<8 - kI64Lanes>(v, Zero(d)); +} + +template +HWY_INLINE VFromD TableLookupSlideUpLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + +#if HWY_TARGET <= HWY_AVX3_DL + const auto byte_idx = Iota(du8, static_cast(size_t{0} - amt)); + return TwoTablesLookupLanes(v, Zero(d), Indices512>{byte_idx.raw}); +#else + const Repartition du16; + const Repartition du64; + const auto byte_idx = Iota(du8, static_cast(size_t{0} - (amt & 15))); + const auto blk_u64_idx = + Iota(du64, static_cast(uint64_t{0} - ((amt >> 4) << 1))); + + const VFromD even_blocks{ + _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; + const VFromD odd_blocks{ + _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(3, 1, 1, 3))}; + const auto odd_sel_mask = + MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, byte_idx)))); + const auto even_blk_lookup_result = + BitCast(d, TableLookupBytes(even_blocks, byte_idx)); + const VFromD blockwise_slide_up_result{ + _mm512_mask_shuffle_epi8(even_blk_lookup_result.raw, odd_sel_mask.raw, + odd_blocks.raw, byte_idx.raw)}; + return BitCast(d, TwoTablesLookupLanes( + BitCast(du64, blockwise_slide_up_result), Zero(du64), + Indices512{blk_u64_idx.raw})); +#endif +} + +} // namespace detail + +template +HWY_API VFromD SlideUpBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 3, + "kBlocks must be between 0 and 3"); + switch (kBlocks) { + case 0: + return v; + case 1: + return detail::SlideUpI64Lanes<2>(v); + case 2: + return ConcatLowerLower(d, v, Zero(d)); + case 3: + return detail::SlideUpI64Lanes<6>(v); + } + + return v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return detail::SlideUpI32Lanes<1>(v); + case 2: + return detail::SlideUpI64Lanes<1>(v); + case 3: + return detail::SlideUpI32Lanes<3>(v); + case 4: + return detail::SlideUpI64Lanes<2>(v); + case 5: + return detail::SlideUpI32Lanes<5>(v); + case 6: + return detail::SlideUpI64Lanes<3>(v); + case 7: + return detail::SlideUpI32Lanes<7>(v); + case 8: + return ConcatLowerLower(d, v, Zero(d)); + case 9: + return detail::SlideUpI32Lanes<9>(v); + case 10: + return detail::SlideUpI64Lanes<5>(v); + case 11: + return detail::SlideUpI32Lanes<11>(v); + case 12: + return detail::SlideUpI64Lanes<6>(v); + case 13: + return detail::SlideUpI32Lanes<13>(v); + case 14: + return detail::SlideUpI64Lanes<7>(v); + case 15: + return detail::SlideUpI32Lanes<15>(v); + } + } +#endif + + return detail::TableLookupSlideUpLanes(d, v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes 
clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return detail::SlideUpI64Lanes<1>(v); + case 2: + return detail::SlideUpI64Lanes<2>(v); + case 3: + return detail::SlideUpI64Lanes<3>(v); + case 4: + return ConcatLowerLower(d, v, Zero(d)); + case 5: + return detail::SlideUpI64Lanes<5>(v); + case 6: + return detail::SlideUpI64Lanes<6>(v); + case 7: + return detail::SlideUpI64Lanes<7>(v); + } + } +#endif + + return detail::TableLookupSlideUpLanes(d, v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + if ((amt & 3) == 0) { + const Repartition du32; + return BitCast(d, SlideUpLanes(du32, BitCast(du32, v), amt >> 2)); + } else if ((amt & 1) == 0) { + const Repartition du16; + return BitCast( + d, detail::TableLookupSlideUpLanes(du16, BitCast(du16, v), amt >> 1)); + } +#if HWY_TARGET > HWY_AVX3_DL + else if (amt <= 63) { // NOLINT(readability/braces) + const Repartition du64; + const size_t blk_u64_slideup_amt = (amt >> 4) << 1; + const auto vu64 = BitCast(du64, v); + const auto v_hi = + BitCast(d, SlideUpLanes(du64, vu64, blk_u64_slideup_amt)); + const auto v_lo = + (blk_u64_slideup_amt <= 4) + ? BitCast(d, SlideUpLanes(du64, vu64, blk_u64_slideup_amt + 2)) + : Zero(d); + switch (amt & 15) { + case 1: + return CombineShiftRightBytes<15>(d, v_hi, v_lo); + case 3: + return CombineShiftRightBytes<13>(d, v_hi, v_lo); + case 5: + return CombineShiftRightBytes<11>(d, v_hi, v_lo); + case 7: + return CombineShiftRightBytes<9>(d, v_hi, v_lo); + case 9: + return CombineShiftRightBytes<7>(d, v_hi, v_lo); + case 11: + return CombineShiftRightBytes<5>(d, v_hi, v_lo); + case 13: + return CombineShiftRightBytes<3>(d, v_hi, v_lo); + case 15: + return CombineShiftRightBytes<1>(d, v_hi, v_lo); + } + } +#endif // HWY_TARGET > HWY_AVX3_DL + } +#endif + + return detail::TableLookupSlideUpLanes(d, v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt) && (amt & 1) == 0) { + const Repartition du32; + return BitCast(d, SlideUpLanes(du32, BitCast(du32, v), amt >> 1)); + } +#endif + + return detail::TableLookupSlideUpLanes(d, v, amt); +} + +// ------------------------------ Slide1Up + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { +#if HWY_TARGET <= HWY_AVX3_DL + return detail::TableLookupSlideUpLanes(d, v, 1); +#else + const auto v_lo = detail::SlideUpI64Lanes<2>(v); + return CombineShiftRightBytes<15>(d, v, v_lo); +#endif +} + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + return detail::TableLookupSlideUpLanes(d, v, 1); +} + +template +HWY_API VFromD Slide1Up(D /*d*/, VFromD v) { + return detail::SlideUpI32Lanes<1>(v); +} + +template +HWY_API VFromD Slide1Up(D /*d*/, VFromD v) { + return detail::SlideUpI64Lanes<1>(v); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +template +HWY_INLINE V SlideDownI32Lanes(V v) { + static_assert(0 <= kI32Lanes && kI32Lanes <= 15, + "kI32Lanes must be between 0 and 15"); + const DFromV d; + return CombineShiftRightI32Lanes(Zero(d), v); +} + +template +HWY_INLINE V SlideDownI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 7, + "kI64Lanes must be between 0 and 7"); + const DFromV d; + return CombineShiftRightI64Lanes(Zero(d), v); +} + +template +HWY_INLINE VFromD TableLookupSlideDownLanes(D d, VFromD v, size_t amt) { + const Repartition du8; 
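// [Editor's note: illustrative annotation, not part of the vendored Highway
// sources.] Overview of the two paths below: with AVX512_VBMI (HWY_AVX3_DL), a
// single cross-block byte permute against a zero second table slides the whole
// vector down by `amt` bytes; indices past the end read zeroes. Without VBMI,
// the slide is composed from a per-block byte shuffle plus a 64-bit lane
// permute that moves whole 16-byte blocks.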
+ +#if HWY_TARGET <= HWY_AVX3_DL + auto byte_idx = Iota(du8, static_cast(amt)); + return TwoTablesLookupLanes(v, Zero(d), Indices512>{byte_idx.raw}); +#else + const Repartition du16; + const Repartition du64; + const auto byte_idx = Iota(du8, static_cast(amt & 15)); + const auto blk_u64_idx = Iota(du64, static_cast(((amt >> 4) << 1))); + + const VFromD even_blocks{ + _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(0, 2, 2, 0))}; + const VFromD odd_blocks{ + _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; + const auto odd_sel_mask = + MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, byte_idx)))); + const VFromD even_blk_lookup_result{ + _mm512_maskz_shuffle_epi8(static_cast<__mmask64>(0x0000FFFFFFFFFFFFULL), + even_blocks.raw, byte_idx.raw)}; + const VFromD blockwise_slide_up_result{ + _mm512_mask_shuffle_epi8(even_blk_lookup_result.raw, odd_sel_mask.raw, + odd_blocks.raw, byte_idx.raw)}; + return BitCast(d, TwoTablesLookupLanes( + BitCast(du64, blockwise_slide_up_result), Zero(du64), + Indices512{blk_u64_idx.raw})); +#endif +} + +} // namespace detail + +template +HWY_API VFromD SlideDownBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 3, + "kBlocks must be between 0 and 3"); + const Half dh; + switch (kBlocks) { + case 0: + return v; + case 1: + return detail::SlideDownI64Lanes<2>(v); + case 2: + return ZeroExtendVector(d, UpperHalf(dh, v)); + case 3: + return detail::SlideDownI64Lanes<6>(v); + } + + return v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + const Half dh; + switch (amt) { + case 1: + return detail::SlideDownI32Lanes<1>(v); + case 2: + return detail::SlideDownI64Lanes<1>(v); + case 3: + return detail::SlideDownI32Lanes<3>(v); + case 4: + return detail::SlideDownI64Lanes<2>(v); + case 5: + return detail::SlideDownI32Lanes<5>(v); + case 6: + return detail::SlideDownI64Lanes<3>(v); + case 7: + return detail::SlideDownI32Lanes<7>(v); + case 8: + return ZeroExtendVector(d, UpperHalf(dh, v)); + case 9: + return detail::SlideDownI32Lanes<9>(v); + case 10: + return detail::SlideDownI64Lanes<5>(v); + case 11: + return detail::SlideDownI32Lanes<11>(v); + case 12: + return detail::SlideDownI64Lanes<6>(v); + case 13: + return detail::SlideDownI32Lanes<13>(v); + case 14: + return detail::SlideDownI64Lanes<7>(v); + case 15: + return detail::SlideDownI32Lanes<15>(v); + } + } +#endif + + return detail::TableLookupSlideDownLanes(d, v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + const Half dh; + switch (amt) { + case 0: + return v; + case 1: + return detail::SlideDownI64Lanes<1>(v); + case 2: + return detail::SlideDownI64Lanes<2>(v); + case 3: + return detail::SlideDownI64Lanes<3>(v); + case 4: + return ZeroExtendVector(d, UpperHalf(dh, v)); + case 5: + return detail::SlideDownI64Lanes<5>(v); + case 6: + return detail::SlideDownI64Lanes<6>(v); + case 7: + return detail::SlideDownI64Lanes<7>(v); + } + } +#endif + + return detail::TableLookupSlideDownLanes(d, v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + if ((amt & 3) == 0) { + const Repartition du32; + return BitCast(d, SlideDownLanes(du32, BitCast(du32, v), amt >> 2)); + } else if ((amt & 1) == 0) { + const 
Repartition du16; + return BitCast(d, detail::TableLookupSlideDownLanes( + du16, BitCast(du16, v), amt >> 1)); + } +#if HWY_TARGET > HWY_AVX3_DL + else if (amt <= 63) { // NOLINT(readability/braces) + const Repartition du64; + const size_t blk_u64_slidedown_amt = (amt >> 4) << 1; + const auto vu64 = BitCast(du64, v); + const auto v_lo = + BitCast(d, SlideDownLanes(du64, vu64, blk_u64_slidedown_amt)); + const auto v_hi = + (blk_u64_slidedown_amt <= 4) + ? BitCast(d, + SlideDownLanes(du64, vu64, blk_u64_slidedown_amt + 2)) + : Zero(d); + switch (amt & 15) { + case 1: + return CombineShiftRightBytes<1>(d, v_hi, v_lo); + case 3: + return CombineShiftRightBytes<3>(d, v_hi, v_lo); + case 5: + return CombineShiftRightBytes<5>(d, v_hi, v_lo); + case 7: + return CombineShiftRightBytes<7>(d, v_hi, v_lo); + case 9: + return CombineShiftRightBytes<9>(d, v_hi, v_lo); + case 11: + return CombineShiftRightBytes<11>(d, v_hi, v_lo); + case 13: + return CombineShiftRightBytes<13>(d, v_hi, v_lo); + case 15: + return CombineShiftRightBytes<15>(d, v_hi, v_lo); + } + } +#endif + } +#endif + + return detail::TableLookupSlideDownLanes(d, v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt) && (amt & 1) == 0) { + const Repartition du32; + return BitCast(d, SlideDownLanes(du32, BitCast(du32, v), amt >> 1)); + } +#endif + + return detail::TableLookupSlideDownLanes(d, v, amt); +} + +// ------------------------------ Slide1Down + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { +#if HWY_TARGET <= HWY_AVX3_DL + return detail::TableLookupSlideDownLanes(d, v, 1); +#else + const auto v_hi = detail::SlideDownI64Lanes<2>(v); + return CombineShiftRightBytes<1>(d, v_hi, v); +#endif +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + return detail::TableLookupSlideDownLanes(d, v, 1); +} + +template +HWY_API VFromD Slide1Down(D /*d*/, VFromD v) { + return detail::SlideDownI32Lanes<1>(v); +} + +template +HWY_API VFromD Slide1Down(D /*d*/, VFromD v) { + return detail::SlideDownI64Lanes<1>(v); +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepu8_epi16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm512_cvtepu8_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepu16_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepu32_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm512_cvtepu16_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec64 v) { + return VFromD{_mm512_cvtepu8_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. 
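// [Editor's usage sketch, not part of the vendored Highway sources; assumes
// the conventional alias namespace hn = hwy::HWY_NAMESPACE and a pointer
// const int16_t* src.] A typical call site for these widening conversions:
//   const hn::Full512<int32_t> d32;
//   const hn::Rebind<int16_t, decltype(d32)> d16;  // 256-bit, same lane count
//   const hn::Vec512<int32_t> wide = hn::PromoteTo(d32, hn::LoadU(d16, src));
// Rebind keeps the lane count equal, so all sixteen int16_t values are
// sign-extended to int32_t in one instruction (vpmovsxwd).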
+template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepi8_epi16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm512_cvtepi8_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepi16_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepi32_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm512_cvtepi16_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec64 v) { + return VFromD{_mm512_cvtepi8_epi64(v.raw)}; +} + +// Float +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { +#if HWY_HAVE_FLOAT16 + const RebindToUnsigned> du16; + return VFromD{_mm512_cvtph_ps(BitCast(du16, v).raw)}; +#else + return VFromD{_mm512_cvtph_ps(v.raw)}; +#endif // HWY_HAVE_FLOAT16 +} + +#if HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, Vec128 v) { + return VFromD{_mm512_cvtph_pd(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD PromoteTo(D df32, Vec256 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtps_pd(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepi32_pd(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepu32_pd(v.raw)}; +} + +template +HWY_API VFromD PromoteInRangeTo(D /*di64*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + 
detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epu64(v.raw)}; +#endif +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du64; + const Vec512 u16{_mm512_packus_epi32(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(du64, kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)}; + return LowerHalf(even); +} + +template +HWY_API VFromD DemoteTo(D dn, Vec512 v) { + const DFromV d; + const RebindToSigned di; + return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFFFFFu)))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du64; + const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(du64, kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)}; + return LowerHalf(even); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du32; + const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; + const Vec512 u8{_mm512_packus_epi16(i16.raw, i16.raw)}; + + const VFromD idx32 = Dup128VecFromValues(du32, 0, 4, 8, 12); + const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)}; + return LowerHalf(LowerHalf(fixed)); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtusepi32_epi8(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du64; + const Vec512 u8{_mm512_packus_epi16(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(du64, kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; + return LowerHalf(even); +} + +template +HWY_API VFromD DemoteTo(D dn, Vec512 v) { + const DFromV d; + const RebindToSigned di; + return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFu)))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du32; + const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; + const Vec512 i8{_mm512_packs_epi16(i16.raw, i16.raw)}; + + const VFromD idx32 = Dup128VecFromValues(du32, 0, 4, 8, 12); + const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)}; + return LowerHalf(LowerHalf(fixed)); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du64; + const Vec512 u8{_mm512_packs_epi16(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. 
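// [Editor's note: illustrative annotation, not part of the vendored Highway
// sources.] _mm512_packs_epi16 narrows within each 128-bit block, and both
// pack inputs are `v`, so each block ends up with its eight saturated bytes
// duplicated in both of its qwords. Keeping only qwords {0, 2, 4, 6} therefore
// yields the contiguous 256-bit result extracted by LowerHalf below.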
+ alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(du64, kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; + return LowerHalf(even); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtsepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtsepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtsepi64_epi8(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; + return VFromD{_mm512_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; + return VFromD{_mm512_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; + return VFromD{_mm512_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtusepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtusepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtusepi64_epi8(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D df16, Vec512 v) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + const RebindToUnsigned du16; + return BitCast( + df16, VFromD{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); + HWY_DIAGNOSTICS(pop) +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD DemoteTo(D /*df16*/, Vec512 v) { + return VFromD{_mm512_cvtpd_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_AVX3_HAVE_F32_TO_BF16C +template +HWY_API VFromD DemoteTo(D /*dbf16*/, Vec512 v) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m256i raw_result; + __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + // The _mm512_cvtneps_pbh intrinsic returns a __m256bh vector that needs to be + // bit casted to a __m256i vector + return VFromD{detail::BitCastToInteger(_mm512_cvtneps_pbh(v.raw))}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D /*dbf16*/, Vec512 a, + Vec512 b) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m512i raw_result; + __asm__("vcvtne2ps2bf16 %2, %1, %0" + : "=v"(raw_result) + : "v"(b.raw), "v"(a.raw)); + return VFromD{raw_result}; +#else + // The _mm512_cvtne2ps_pbh intrinsic returns a __m512bh vector that needs to + // be bit casted to a __m512i vector + return VFromD{detail::BitCastToInteger(_mm512_cvtne2ps_pbh(b.raw, a.raw))}; +#endif +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec512 a, + Vec512 b) { + return VFromD{_mm512_packs_epi32(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec512 a, + Vec512 b) { + return VFromD{_mm512_packus_epi32(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec512 a, + Vec512 b) { + const DFromV du32; + const RebindToSigned di32; + const auto max_i32 = Set(du32, 0x7FFFFFFFu); + + return 
ReorderDemote2To(dn, BitCast(di32, Min(a, max_i32)), + BitCast(di32, Min(b, max_i32))); +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec512 a, + Vec512 b) { + return VFromD{_mm512_packs_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec512 a, + Vec512 b) { + return VFromD{_mm512_packus_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec512 a, + Vec512 b) { + const DFromV du16; + const RebindToSigned di16; + const auto max_i16 = Set(du16, 0x7FFFu); + + return ReorderDemote2To(dn, BitCast(di16, Min(a, max_i16)), + BitCast(di16, Min(b, max_i16))); +} + +template ), + HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2), + HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + const Full512 du64; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(BitCast(du64, ReorderDemote2To(d, a, b)), + SetTableIndices(du64, kIdx))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtpd_ps(v.raw)}; +} + +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epi32(v.raw)}; +#endif +} + +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : 
HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epu32(v.raw)}; +#endif +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm512_cvtepi64_ps(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm512_cvtepu64_ps(v.raw)}; +} + +// For already range-limited input [0, 255]. +HWY_API Vec128 U8FromU32(const Vec512 v) { + const DFromV d32; + // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the + // lowest 4 bytes. + const VFromD v8From32 = + Dup128VecFromValues(d32, 0x0C080400u, ~0u, ~0u, ~0u); + const auto quads = TableLookupBytes(v, v8From32); + // Gather the lowest 4 bytes of 4 128-bit blocks. + const VFromD index32 = Dup128VecFromValues(d32, 0, 4, 8, 12); + const Vec512 bytes{_mm512_permutexvar_epi32(index32.raw, quads.raw)}; + return LowerHalf(LowerHalf(bytes)); +} + +// ------------------------------ Truncations + +template +HWY_API VFromD TruncateTo(D d, const Vec512 v) { +#if HWY_TARGET <= HWY_AVX3_DL + (void)d; + const Full512 d8; + const VFromD v8From64 = Dup128VecFromValues( + d8, 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56); + const Vec512 bytes{_mm512_permutexvar_epi8(v8From64.raw, v.raw)}; + return LowerHalf(LowerHalf(LowerHalf(bytes))); +#else + const Full512 d32; + alignas(64) static constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14}; + const Vec512 even{ + _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; + return TruncateTo(d, LowerHalf(even)); +#endif +} + +template +HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { + const Full512 d16; + alignas(16) static constexpr uint16_t k16From64[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + const Vec512 bytes{ + _mm512_permutexvar_epi16(LoadDup128(d16, k16From64).raw, v.raw)}; + return LowerHalf(LowerHalf(bytes)); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { + const Full512 d32; + alignas(64) static constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14}; + const Vec512 even{ + _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; + return LowerHalf(even); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { +#if HWY_TARGET <= HWY_AVX3_DL + const Full512 d8; + const VFromD v8From32 = Dup128VecFromValues( + d8, 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60); + const Vec512 bytes{_mm512_permutexvar_epi8(v8From32.raw, v.raw)}; +#else + const Full512 d32; + // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the + // lowest 4 bytes. + const VFromD v8From32 = + Dup128VecFromValues(d32, 0x0C080400u, ~0u, ~0u, ~0u); + const auto quads = TableLookupBytes(v, v8From32); + // Gather the lowest 4 bytes of 4 128-bit blocks. 
+ const VFromD index32 = Dup128VecFromValues(d32, 0, 4, 8, 12); + const Vec512 bytes{_mm512_permutexvar_epi32(index32.raw, quads.raw)}; +#endif + return LowerHalf(LowerHalf(bytes)); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { + const Full512 d16; + alignas(64) static constexpr uint16_t k16From32[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + const Vec512 bytes{ + _mm512_permutexvar_epi16(Load(d16, k16From32).raw, v.raw)}; + return LowerHalf(bytes); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { +#if HWY_TARGET <= HWY_AVX3_DL + const Full512 d8; + alignas(64) static constexpr uint8_t k8From16[64] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + const Vec512 bytes{ + _mm512_permutexvar_epi8(Load(d8, k8From16).raw, v.raw)}; +#else + const Full512 d32; + const VFromD v16From32 = Dup128VecFromValues( + d32, 0x06040200u, 0x0E0C0A08u, 0x06040200u, 0x0E0C0A08u); + const auto quads = TableLookupBytes(v, v16From32); + alignas(64) static constexpr uint32_t kIndex32[16] = { + 0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13}; + const Vec512 bytes{ + _mm512_permutexvar_epi32(Load(d32, kIndex32).raw, quads.raw)}; +#endif + return LowerHalf(bytes); +} + +// ------------------------------ Convert integer <=> floating point + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtepu16_ph(v.raw)}; +} +template +HWY_API VFromD ConvertTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtepi16_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD ConvertTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtepi32_ps(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtepi64_pd(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag*/, Vec512 v) { + return VFromD{_mm512_cvtepu32_ps(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag*/, Vec512 v) { + return VFromD{_mm512_cvtepu64_pd(v.raw)}; +} + +// Truncates (rounds toward zero). 
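A minimal caller-side sketch of how the truncating float-to-integer conversions in this header are reached through Highway's portable API, assuming static single-target dispatch; the function name, loop structure and the `hn` alias are invented for illustration, not part of the patch.

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Truncates each float toward zero into an int32_t of the same lane width.
// On AVX-512 the work is done by the 512-bit conversions defined in this
// header; the *InRange* variants below additionally require every input to be
// representable in the destination integer type.
void TruncateToInt32(const float* HWY_RESTRICT in, int32_t* HWY_RESTRICT out,
                     size_t count) {
  const hn::ScalableTag<float> df;
  const hn::RebindToSigned<decltype(df)> di;  // int32_t, same lane count
  const size_t N = hn::Lanes(df);
  for (size_t i = 0; i + N <= count; i += N) {
    hn::StoreU(hn::ConvertTo(di, hn::LoadU(df, in + i)), di, out + i);
  }
  // Handling of the remaining count % N lanes is omitted for brevity.
}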
+#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_set_epi16(detail::X86ConvertScalarFromFloat(raw_v[31]), + detail::X86ConvertScalarFromFloat(raw_v[30]), + detail::X86ConvertScalarFromFloat(raw_v[29]), + detail::X86ConvertScalarFromFloat(raw_v[28]), + detail::X86ConvertScalarFromFloat(raw_v[27]), + detail::X86ConvertScalarFromFloat(raw_v[26]), + detail::X86ConvertScalarFromFloat(raw_v[25]), + detail::X86ConvertScalarFromFloat(raw_v[24]), + detail::X86ConvertScalarFromFloat(raw_v[23]), + detail::X86ConvertScalarFromFloat(raw_v[22]), + detail::X86ConvertScalarFromFloat(raw_v[21]), + detail::X86ConvertScalarFromFloat(raw_v[20]), + detail::X86ConvertScalarFromFloat(raw_v[19]), + detail::X86ConvertScalarFromFloat(raw_v[18]), + detail::X86ConvertScalarFromFloat(raw_v[17]), + detail::X86ConvertScalarFromFloat(raw_v[16]), + detail::X86ConvertScalarFromFloat(raw_v[15]), + detail::X86ConvertScalarFromFloat(raw_v[14]), + detail::X86ConvertScalarFromFloat(raw_v[13]), + detail::X86ConvertScalarFromFloat(raw_v[12]), + detail::X86ConvertScalarFromFloat(raw_v[11]), + detail::X86ConvertScalarFromFloat(raw_v[10]), + detail::X86ConvertScalarFromFloat(raw_v[9]), + detail::X86ConvertScalarFromFloat(raw_v[8]), + detail::X86ConvertScalarFromFloat(raw_v[7]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[0]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttph_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttph_epu16 with GCC if any + // values of v[i] are not within the range of an uint16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_set_epi16( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[31])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[30])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[29])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[28])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[27])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[26])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[25])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[24])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[23])), + static_cast( + 
detail::X86ConvertScalarFromFloat(raw_v[22])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[21])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[20])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[19])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[18])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[17])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[16])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[15])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[14])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[13])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[12])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[11])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[10])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[9])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[8])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttph2uw {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttph_epu16(v.raw)}; +#endif +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi32( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]), + detail::X86ConvertScalarFromFloat(raw_v[8]), + detail::X86ConvertScalarFromFloat(raw_v[9]), + detail::X86ConvertScalarFromFloat(raw_v[10]), + detail::X86ConvertScalarFromFloat(raw_v[11]), + detail::X86ConvertScalarFromFloat(raw_v[12]), + detail::X86ConvertScalarFromFloat(raw_v[13]), + detail::X86ConvertScalarFromFloat(raw_v[14]), + detail::X86ConvertScalarFromFloat(raw_v[15]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epi32(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(D /*di*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 
&& !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttps_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi32( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[8])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[9])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[10])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[11])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[12])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[13])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[14])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[15])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epu32(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epu64 with GCC if any + // values of v[i] are not within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + 
static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttpd2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epu64(v.raw)}; +#endif +} + +template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_setr_epi32(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]), + detail::X86ScalarNearestInt(raw_v[8]), + detail::X86ScalarNearestInt(raw_v[9]), + detail::X86ScalarNearestInt(raw_v[10]), + detail::X86ScalarNearestInt(raw_v[11]), + detail::X86ScalarNearestInt(raw_v[12]), + detail::X86ScalarNearestInt(raw_v[13]), + detail::X86ScalarNearestInt(raw_v[14]), + detail::X86ScalarNearestInt(raw_v[15]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvtps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtps_epi32(v.raw)}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*d*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_set_epi16(detail::X86ScalarNearestInt(raw_v[31]), + detail::X86ScalarNearestInt(raw_v[30]), + detail::X86ScalarNearestInt(raw_v[29]), + detail::X86ScalarNearestInt(raw_v[28]), + detail::X86ScalarNearestInt(raw_v[27]), + detail::X86ScalarNearestInt(raw_v[26]), + detail::X86ScalarNearestInt(raw_v[25]), + detail::X86ScalarNearestInt(raw_v[24]), + detail::X86ScalarNearestInt(raw_v[23]), + detail::X86ScalarNearestInt(raw_v[22]), + detail::X86ScalarNearestInt(raw_v[21]), + detail::X86ScalarNearestInt(raw_v[20]), + detail::X86ScalarNearestInt(raw_v[19]), + detail::X86ScalarNearestInt(raw_v[18]), + detail::X86ScalarNearestInt(raw_v[17]), + detail::X86ScalarNearestInt(raw_v[16]), + detail::X86ScalarNearestInt(raw_v[15]), + detail::X86ScalarNearestInt(raw_v[14]), + detail::X86ScalarNearestInt(raw_v[13]), + detail::X86ScalarNearestInt(raw_v[12]), + detail::X86ScalarNearestInt(raw_v[11]), + detail::X86ScalarNearestInt(raw_v[10]), + detail::X86ScalarNearestInt(raw_v[9]), + detail::X86ScalarNearestInt(raw_v[8]), + detail::X86ScalarNearestInt(raw_v[7]), + detail::X86ScalarNearestInt(raw_v[6]), + 
detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[0]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvtph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtph_epi16(v.raw)}; +#endif +} +#endif // HWY_HAVE_FLOAT16 + +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_setr_epi64(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvtpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtpd_epi64(v.raw)}; +#endif +} + +template +static HWY_INLINE VFromD DemoteToNearestIntInRange(DI /* tag */, + Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi32(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtpd_epi32(v.raw)}; +#endif +} + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) + +HWY_API Vec512 AESRound(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_aesenc_epi128(state.raw, round_key.raw)}; +#else + const DFromV d; + const Half d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 AESLastRound(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const DFromV d; + const Half d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + 
AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 AESRoundInv(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_aesdec_epi128(state.raw, round_key.raw)}; +#else + const Full512 d; + const Half d2; + return Combine(d, AESRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRoundInv(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 AESLastRoundInv(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_aesdeclast_epi128(state.raw, round_key.raw)}; +#else + const Full512 d; + const Half d2; + return Combine( + d, AESLastRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRoundInv(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +template +HWY_API Vec512 AESKeyGenAssist(Vec512 v) { + const Full512 d; +#if HWY_TARGET <= HWY_AVX3_DL + const VFromD rconXorMask = Dup128VecFromValues( + d, 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0); + const VFromD rotWordShuffle = Dup128VecFromValues( + d, 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12); + const Repartition du32; + const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); + const auto sub_word_result = AESLastRound(w13, rconXorMask); + return TableLookupBytes(sub_word_result, rotWordShuffle); +#else + const Half d2; + return Combine(d, AESKeyGenAssist(UpperHalf(d2, v)), + AESKeyGenAssist(LowerHalf(v))); +#endif +} + +HWY_API Vec512 CLMulLower(Vec512 va, Vec512 vb) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x00)}; +#else + alignas(64) uint64_t a[8]; + alignas(64) uint64_t b[8]; + const DFromV d; + const Half> d128; + Store(va, d, a); + Store(vb, d, b); + for (size_t i = 0; i < 8; i += 2) { + const auto mul = CLMulLower(Load(d128, a + i), Load(d128, b + i)); + Store(mul, d128, a + i); + } + return Load(d, a); +#endif +} + +HWY_API Vec512 CLMulUpper(Vec512 va, Vec512 vb) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x11)}; +#else + alignas(64) uint64_t a[8]; + alignas(64) uint64_t b[8]; + const DFromV d; + const Half> d128; + Store(va, d, a); + Store(vb, d, b); + for (size_t i = 0; i < 8; i += 2) { + const auto mul = CLMulUpper(Load(d128, a + i), Load(d128, b + i)); + Store(mul, d128, a + i); + } + return Load(d, a); +#endif +} + +#endif // HWY_DISABLE_PCLMUL_AES + +// ================================================== MISC + +// ------------------------------ SumsOfAdjQuadAbsDiff (Broadcast, +// SumsOfAdjShufQuadAbsDiff) + +template +static Vec512 SumsOfAdjQuadAbsDiff(Vec512 a, + Vec512 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + + const DFromV d; + const RepartitionToWideX2 du32; + + // While AVX3 does not have a _mm512_mpsadbw_epu8 intrinsic, the + // SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on + // AVX3 using SumsOfShuffledQuadAbsDiff and U32 Broadcast. 
+ return SumsOfShuffledQuadAbsDiff( + a, BitCast(d, Broadcast(BitCast(du32, b)))); +} + +#if !HWY_IS_MSAN +// ------------------------------ I32/I64 SaturatedAdd (MaskFromVec) + +HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { + const DFromV d; + const auto sum = a + b; + const auto overflow_mask = MaskFromVec( + Vec512{_mm512_ternarylogic_epi32(a.raw, b.raw, sum.raw, 0x42)}); + const auto i32_max = Set(d, LimitsMax()); + const Vec512 overflow_result{_mm512_mask_ternarylogic_epi32( + i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, sum); +} + +HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { + const DFromV d; + const auto sum = a + b; + const auto overflow_mask = MaskFromVec( + Vec512{_mm512_ternarylogic_epi64(a.raw, b.raw, sum.raw, 0x42)}); + const auto i64_max = Set(d, LimitsMax()); + const Vec512 overflow_result{_mm512_mask_ternarylogic_epi64( + i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, sum); +} + +// ------------------------------ I32/I64 SaturatedSub (MaskFromVec) + +HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { + const DFromV d; + const auto diff = a - b; + const auto overflow_mask = MaskFromVec( + Vec512{_mm512_ternarylogic_epi32(a.raw, b.raw, diff.raw, 0x18)}); + const auto i32_max = Set(d, LimitsMax()); + const Vec512 overflow_result{_mm512_mask_ternarylogic_epi32( + i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, diff); +} + +HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { + const DFromV d; + const auto diff = a - b; + const auto overflow_mask = MaskFromVec( + Vec512{_mm512_ternarylogic_epi64(a.raw, b.raw, diff.raw, 0x18)}); + const auto i64_max = Set(d, LimitsMax()); + const Vec512 overflow_result{_mm512_mask_ternarylogic_epi64( + i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)}; + return IfThenElse(overflow_mask, overflow_result, diff); +} +#endif // !HWY_IS_MSAN + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! 
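A short caller-side sketch of the mask-testing ops defined in this section, again assuming static single-target dispatch; `FirstNegative` and the `hn` alias are names invented for this illustration.

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Returns the index of the first negative lane in one full vector starting at
// `p` (which must hold at least Lanes(d) floats), or -1 if there is none. The
// comparison yields a k-register mask on AVX-512; AllFalse and FindFirstTrue
// then map onto the kortest / bit-scan helpers defined below.
intptr_t FirstNegative(const float* HWY_RESTRICT p) {
  const hn::ScalableTag<float> d;
  const auto is_neg = hn::Lt(hn::LoadU(d, p), hn::Zero(d));
  if (hn::AllFalse(d, is_neg)) return -1;  // no negative lane
  return hn::FindFirstTrue(d, is_neg);     // index of lowest set mask bit
}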
+ +namespace detail { + +template +HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask64_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} + +} // namespace detail + +template +HWY_API bool AllFalse(D /* tag */, const MFromD mask) { + return detail::AllFalse(hwy::SizeTag)>(), mask); +} + +namespace detail { + +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask64_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFFFFFFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFull; +#endif +} + +} // namespace detail + +template +HWY_API bool AllTrue(D /* tag */, const MFromD mask) { + return detail::AllTrue(hwy::SizeTag)>(), mask); +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D /* tag */, const uint8_t* HWY_RESTRICT bits) { + MFromD mask; + CopyBytes<8 / sizeof(TFromD)>(bits, &mask.raw); + // N >= 8 (= 512 / 64), so no need to mask invalid bits. + return mask; +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D /* tag */, MFromD mask, uint8_t* bits) { + const size_t kNumBytes = 8 / sizeof(TFromD); + CopyBytes(&mask.raw, bits); + // N >= 8 (= 512 / 64), so no need to mask invalid bits. + return kNumBytes; +} + +template +HWY_API size_t CountTrue(D /* tag */, const MFromD mask) { + return PopCount(static_cast(mask.raw)); +} + +template +HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD mask) { + return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD mask) { + return Num0BitsBelowLS1Bit_Nonzero64(mask.raw); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + return mask.raw ? static_cast(FindKnownFirstTrue(d, mask)) + : intptr_t{-1}; +} + +template +HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD mask) { + return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD mask) { + return 63 - Num0BitsAboveMS1Bit_Nonzero64(mask.raw); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + return mask.raw ? 
static_cast(FindKnownLastTrue(d, mask)) + : intptr_t{-1}; +} + +// ------------------------------ Compress + +template +HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { + // See CompressIsPartition. u64 is faster than u32. + alignas(16) static constexpr uint64_t packed_array[256] = { + // From PrintCompress32x8Tables, without the FirstN extension (there is + // no benefit to including them because 64-bit CompressStore is anyway + // masked, but also no harm because TableLookupLanes ignores the MSB). + 0x76543210, 0x76543210, 0x76543201, 0x76543210, 0x76543102, 0x76543120, + 0x76543021, 0x76543210, 0x76542103, 0x76542130, 0x76542031, 0x76542310, + 0x76541032, 0x76541320, 0x76540321, 0x76543210, 0x76532104, 0x76532140, + 0x76532041, 0x76532410, 0x76531042, 0x76531420, 0x76530421, 0x76534210, + 0x76521043, 0x76521430, 0x76520431, 0x76524310, 0x76510432, 0x76514320, + 0x76504321, 0x76543210, 0x76432105, 0x76432150, 0x76432051, 0x76432510, + 0x76431052, 0x76431520, 0x76430521, 0x76435210, 0x76421053, 0x76421530, + 0x76420531, 0x76425310, 0x76410532, 0x76415320, 0x76405321, 0x76453210, + 0x76321054, 0x76321540, 0x76320541, 0x76325410, 0x76310542, 0x76315420, + 0x76305421, 0x76354210, 0x76210543, 0x76215430, 0x76205431, 0x76254310, + 0x76105432, 0x76154320, 0x76054321, 0x76543210, 0x75432106, 0x75432160, + 0x75432061, 0x75432610, 0x75431062, 0x75431620, 0x75430621, 0x75436210, + 0x75421063, 0x75421630, 0x75420631, 0x75426310, 0x75410632, 0x75416320, + 0x75406321, 0x75463210, 0x75321064, 0x75321640, 0x75320641, 0x75326410, + 0x75310642, 0x75316420, 0x75306421, 0x75364210, 0x75210643, 0x75216430, + 0x75206431, 0x75264310, 0x75106432, 0x75164320, 0x75064321, 0x75643210, + 0x74321065, 0x74321650, 0x74320651, 0x74326510, 0x74310652, 0x74316520, + 0x74306521, 0x74365210, 0x74210653, 0x74216530, 0x74206531, 0x74265310, + 0x74106532, 0x74165320, 0x74065321, 0x74653210, 0x73210654, 0x73216540, + 0x73206541, 0x73265410, 0x73106542, 0x73165420, 0x73065421, 0x73654210, + 0x72106543, 0x72165430, 0x72065431, 0x72654310, 0x71065432, 0x71654320, + 0x70654321, 0x76543210, 0x65432107, 0x65432170, 0x65432071, 0x65432710, + 0x65431072, 0x65431720, 0x65430721, 0x65437210, 0x65421073, 0x65421730, + 0x65420731, 0x65427310, 0x65410732, 0x65417320, 0x65407321, 0x65473210, + 0x65321074, 0x65321740, 0x65320741, 0x65327410, 0x65310742, 0x65317420, + 0x65307421, 0x65374210, 0x65210743, 0x65217430, 0x65207431, 0x65274310, + 0x65107432, 0x65174320, 0x65074321, 0x65743210, 0x64321075, 0x64321750, + 0x64320751, 0x64327510, 0x64310752, 0x64317520, 0x64307521, 0x64375210, + 0x64210753, 0x64217530, 0x64207531, 0x64275310, 0x64107532, 0x64175320, + 0x64075321, 0x64753210, 0x63210754, 0x63217540, 0x63207541, 0x63275410, + 0x63107542, 0x63175420, 0x63075421, 0x63754210, 0x62107543, 0x62175430, + 0x62075431, 0x62754310, 0x61075432, 0x61754320, 0x60754321, 0x67543210, + 0x54321076, 0x54321760, 0x54320761, 0x54327610, 0x54310762, 0x54317620, + 0x54307621, 0x54376210, 0x54210763, 0x54217630, 0x54207631, 0x54276310, + 0x54107632, 0x54176320, 0x54076321, 0x54763210, 0x53210764, 0x53217640, + 0x53207641, 0x53276410, 0x53107642, 0x53176420, 0x53076421, 0x53764210, + 0x52107643, 0x52176430, 0x52076431, 0x52764310, 0x51076432, 0x51764320, + 0x50764321, 0x57643210, 0x43210765, 0x43217650, 0x43207651, 0x43276510, + 0x43107652, 0x43176520, 0x43076521, 0x43765210, 0x42107653, 0x42176530, + 0x42076531, 0x42765310, 0x41076532, 0x41765320, 0x40765321, 0x47653210, + 0x32107654, 0x32176540, 0x32076541, 0x32765410, 0x31076542, 0x31765420, + 0x30765421, 
0x37654210, 0x21076543, 0x21765430, 0x20765431, 0x27654310, + 0x10765432, 0x17654320, 0x07654321, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3) - + // _mm512_permutexvar_epi64 will ignore the upper bits. + const DFromV d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) static constexpr uint64_t shifts[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + const auto indices = Indices512{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ Expand + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 +HWY_INLINE Vec512 NativeExpand(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_expand_epi8(mask.raw, v.raw)}; +} + +HWY_INLINE Vec512 NativeExpand(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_expand_epi16(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(Mask512 mask, D /* d */, + const uint8_t* HWY_RESTRICT unaligned) { + return VFromD{_mm512_maskz_expandloadu_epi8(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(Mask512 mask, D /* d */, + const uint16_t* HWY_RESTRICT unaligned) { + return VFromD{_mm512_maskz_expandloadu_epi16(mask.raw, unaligned)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +HWY_INLINE Vec512 NativeExpand(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_expand_epi32(mask.raw, v.raw)}; +} + +HWY_INLINE Vec512 NativeExpand(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_expand_epi64(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(Mask512 mask, D /* d */, + const uint32_t* HWY_RESTRICT unaligned) { + return VFromD{_mm512_maskz_expandloadu_epi32(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(Mask512 mask, D /* d */, + const uint64_t* HWY_RESTRICT unaligned) { + return VFromD{_mm512_maskz_expandloadu_epi64(mask.raw, unaligned)}; +} + +} // namespace detail + +template +HWY_API Vec512 Expand(Vec512 v, const Mask512 mask) { + const Full512 d; +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +#else + // LUTs are infeasible for 2^64 possible masks, so splice together two + // half-vector Expand. + const Full256 dh; + constexpr size_t N = MaxLanes(d); + // We have to shift the input by a variable number of u8. Shuffling requires + // VBMI2, in which case we would already have NativeExpand. We instead + // load at an offset, which may incur a store to load forwarding stall. + alignas(64) T lanes[N]; + Store(v, d, lanes); + using Bits = typename Mask256::Raw; + const Mask256 maskL{ + static_cast(mask.raw & Bits{(1ULL << (N / 2)) - 1})}; + const Mask256 maskH{static_cast(mask.raw >> (N / 2))}; + const size_t countL = CountTrue(dh, maskL); + const Vec256 expandL = Expand(LowerHalf(v), maskL); + const Vec256 expandH = Expand(LoadU(dh, lanes + countL), maskH); + return Combine(d, expandH, expandL); +#endif +} + +template +HWY_API Vec512 Expand(Vec512 v, const Mask512 mask) { + const Full512 d; + const RebindToUnsigned du; + const Vec512 vu = BitCast(du, v); +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + return BitCast(d, detail::NativeExpand(vu, RebindMask(du, mask))); +#else // AVX3 + // LUTs are infeasible for 2^32 possible masks, so splice together two + // half-vector Expand. 
+ const Full256 dh; + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + using Bits = typename Mask256::Raw; + const Mask256 maskL{ + static_cast(mask.raw & static_cast((1ULL << (N / 2)) - 1))}; + const Mask256 maskH{static_cast(mask.raw >> (N / 2))}; + // In AVX3 we can permutevar, which avoids a potential store to load + // forwarding stall vs. reloading the input. + alignas(64) uint16_t iota[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; + const Vec512 indices = LoadU(du, iota + CountTrue(dh, maskL)); + const Vec512 shifted{_mm512_permutexvar_epi16(indices.raw, vu.raw)}; + const Vec256 expandL = Expand(LowerHalf(v), maskL); + const Vec256 expandH = Expand(LowerHalf(BitCast(d, shifted)), maskH); + return Combine(d, expandH, expandL); +#endif // AVX3 +} + +template +HWY_API V Expand(V v, const M mask) { + const DFromV d; + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +} + +// For smaller vectors, it is likely more efficient to promote to 32-bit. +// This works for u8x16, u16x8, u16x16 (can be promoted to u32x16), but is +// unnecessary if HWY_AVX3_DL, which provides native instructions. +#if HWY_TARGET > HWY_AVX3_DL // no VBMI2 + +template , 16)> +HWY_API V Expand(V v, M mask) { + const DFromV d; + const RebindToUnsigned du; + const Rebind du32; + const VFromD vu = BitCast(du, v); + using M32 = MFromD; + const M32 m32{static_cast(mask.raw)}; + return BitCast(d, TruncateTo(du, Expand(PromoteTo(du32, vu), m32))); +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// ------------------------------ LoadExpand + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +#else + return Expand(LoadU(d, unaligned), mask); +#endif +} + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +} + +// ------------------------------ CompressNot + +template +HWY_API Vec512 CompressNot(Vec512 v, Mask512 mask) { + // See CompressIsPartition. u64 is faster than u32. + alignas(16) static constexpr uint64_t packed_array[256] = { + // From PrintCompressNot32x8Tables, without the FirstN extension (there is + // no benefit to including them because 64-bit CompressStore is anyway + // masked, but also no harm because TableLookupLanes ignores the MSB). 
+ 0x76543210, 0x07654321, 0x17654320, 0x10765432, 0x27654310, 0x20765431, + 0x21765430, 0x21076543, 0x37654210, 0x30765421, 0x31765420, 0x31076542, + 0x32765410, 0x32076541, 0x32176540, 0x32107654, 0x47653210, 0x40765321, + 0x41765320, 0x41076532, 0x42765310, 0x42076531, 0x42176530, 0x42107653, + 0x43765210, 0x43076521, 0x43176520, 0x43107652, 0x43276510, 0x43207651, + 0x43217650, 0x43210765, 0x57643210, 0x50764321, 0x51764320, 0x51076432, + 0x52764310, 0x52076431, 0x52176430, 0x52107643, 0x53764210, 0x53076421, + 0x53176420, 0x53107642, 0x53276410, 0x53207641, 0x53217640, 0x53210764, + 0x54763210, 0x54076321, 0x54176320, 0x54107632, 0x54276310, 0x54207631, + 0x54217630, 0x54210763, 0x54376210, 0x54307621, 0x54317620, 0x54310762, + 0x54327610, 0x54320761, 0x54321760, 0x54321076, 0x67543210, 0x60754321, + 0x61754320, 0x61075432, 0x62754310, 0x62075431, 0x62175430, 0x62107543, + 0x63754210, 0x63075421, 0x63175420, 0x63107542, 0x63275410, 0x63207541, + 0x63217540, 0x63210754, 0x64753210, 0x64075321, 0x64175320, 0x64107532, + 0x64275310, 0x64207531, 0x64217530, 0x64210753, 0x64375210, 0x64307521, + 0x64317520, 0x64310752, 0x64327510, 0x64320751, 0x64321750, 0x64321075, + 0x65743210, 0x65074321, 0x65174320, 0x65107432, 0x65274310, 0x65207431, + 0x65217430, 0x65210743, 0x65374210, 0x65307421, 0x65317420, 0x65310742, + 0x65327410, 0x65320741, 0x65321740, 0x65321074, 0x65473210, 0x65407321, + 0x65417320, 0x65410732, 0x65427310, 0x65420731, 0x65421730, 0x65421073, + 0x65437210, 0x65430721, 0x65431720, 0x65431072, 0x65432710, 0x65432071, + 0x65432170, 0x65432107, 0x76543210, 0x70654321, 0x71654320, 0x71065432, + 0x72654310, 0x72065431, 0x72165430, 0x72106543, 0x73654210, 0x73065421, + 0x73165420, 0x73106542, 0x73265410, 0x73206541, 0x73216540, 0x73210654, + 0x74653210, 0x74065321, 0x74165320, 0x74106532, 0x74265310, 0x74206531, + 0x74216530, 0x74210653, 0x74365210, 0x74306521, 0x74316520, 0x74310652, + 0x74326510, 0x74320651, 0x74321650, 0x74321065, 0x75643210, 0x75064321, + 0x75164320, 0x75106432, 0x75264310, 0x75206431, 0x75216430, 0x75210643, + 0x75364210, 0x75306421, 0x75316420, 0x75310642, 0x75326410, 0x75320641, + 0x75321640, 0x75321064, 0x75463210, 0x75406321, 0x75416320, 0x75410632, + 0x75426310, 0x75420631, 0x75421630, 0x75421063, 0x75436210, 0x75430621, + 0x75431620, 0x75431062, 0x75432610, 0x75432061, 0x75432160, 0x75432106, + 0x76543210, 0x76054321, 0x76154320, 0x76105432, 0x76254310, 0x76205431, + 0x76215430, 0x76210543, 0x76354210, 0x76305421, 0x76315420, 0x76310542, + 0x76325410, 0x76320541, 0x76321540, 0x76321054, 0x76453210, 0x76405321, + 0x76415320, 0x76410532, 0x76425310, 0x76420531, 0x76421530, 0x76421053, + 0x76435210, 0x76430521, 0x76431520, 0x76431052, 0x76432510, 0x76432051, + 0x76432150, 0x76432105, 0x76543210, 0x76504321, 0x76514320, 0x76510432, + 0x76524310, 0x76520431, 0x76521430, 0x76521043, 0x76534210, 0x76530421, + 0x76531420, 0x76531042, 0x76532410, 0x76532041, 0x76532140, 0x76532104, + 0x76543210, 0x76540321, 0x76541320, 0x76541032, 0x76542310, 0x76542031, + 0x76542130, 0x76542103, 0x76543210, 0x76543021, 0x76543120, 0x76543102, + 0x76543210, 0x76543201, 0x76543210, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3) - + // _mm512_permutexvar_epi64 will ignore the upper bits. 
+ const DFromV d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) static constexpr uint64_t shifts[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + const auto indices = Indices512{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ LoadInterleaved4 + +// Actually implemented in generic_ops, we just overload LoadTransposedBlocks4. +namespace detail { + +// Type-safe wrapper. +template <_MM_PERM_ENUM kPerm, typename T> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_shuffle_i64x2( + BitCast(du, lo).raw, BitCast(du, hi).raw, kPerm)}); +} +template <_MM_PERM_ENUM kPerm> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, kPerm)}; +} +template <_MM_PERM_ENUM kPerm> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, kPerm)}; +} + +// Input (128-bit blocks): +// 3 2 1 0 (<- first block in unaligned) +// 7 6 5 4 +// b a 9 8 +// Output: +// 9 6 3 0 (LSB of A) +// a 7 4 1 +// b 8 5 2 +template +HWY_API void LoadTransposedBlocks3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& A, VFromD& B, VFromD& C) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD v3210 = LoadU(d, unaligned + 0 * N); + const VFromD v7654 = LoadU(d, unaligned + 1 * N); + const VFromD vba98 = LoadU(d, unaligned + 2 * N); + + const VFromD v5421 = detail::Shuffle128<_MM_PERM_BACB>(v3210, v7654); + const VFromD va976 = detail::Shuffle128<_MM_PERM_CBDC>(v7654, vba98); + + A = detail::Shuffle128<_MM_PERM_CADA>(v3210, va976); + B = detail::Shuffle128<_MM_PERM_DBCA>(v5421, va976); + C = detail::Shuffle128<_MM_PERM_DADB>(v5421, vba98); +} + +// Input (128-bit blocks): +// 3 2 1 0 (<- first block in unaligned) +// 7 6 5 4 +// b a 9 8 +// f e d c +// Output: +// c 8 4 0 (LSB of A) +// d 9 5 1 +// e a 6 2 +// f b 7 3 +template +HWY_API void LoadTransposedBlocks4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& vA, VFromD& vB, VFromD& vC, + VFromD& vD) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD v3210 = LoadU(d, unaligned + 0 * N); + const VFromD v7654 = LoadU(d, unaligned + 1 * N); + const VFromD vba98 = LoadU(d, unaligned + 2 * N); + const VFromD vfedc = LoadU(d, unaligned + 3 * N); + + const VFromD v5410 = detail::Shuffle128<_MM_PERM_BABA>(v3210, v7654); + const VFromD vdc98 = detail::Shuffle128<_MM_PERM_BABA>(vba98, vfedc); + const VFromD v7632 = detail::Shuffle128<_MM_PERM_DCDC>(v3210, v7654); + const VFromD vfeba = detail::Shuffle128<_MM_PERM_DCDC>(vba98, vfedc); + vA = detail::Shuffle128<_MM_PERM_CACA>(v5410, vdc98); + vB = detail::Shuffle128<_MM_PERM_DBDB>(v5410, vdc98); + vC = detail::Shuffle128<_MM_PERM_CACA>(v7632, vfeba); + vD = detail::Shuffle128<_MM_PERM_DBDB>(v7632, vfeba); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2 + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. 
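These block-transpose overloads back the portable (Load|Store)Interleaved ops from generic_ops. A typical use, sketched here under the same single-target assumption (the function name and the brightness constant are invented for illustration), is de-interleaving and re-interleaving packed RGB bytes.

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Brightens the G channel of packed RGB bytes in place, one full vector of
// pixels per iteration. On 512-bit targets the interleaved load/store go
// through the LoadTransposedBlocks3 / StoreTransposedBlocks3 overloads in
// this header.
void BrightenGreen(uint8_t* HWY_RESTRICT rgb, size_t num_pixels) {
  const hn::ScalableTag<uint8_t> d;
  const size_t N = hn::Lanes(d);
  for (size_t i = 0; i + N <= num_pixels; i += N) {
    hn::VFromD<decltype(d)> r, g, b;
    hn::LoadInterleaved3(d, rgb + 3 * i, r, g, b);
    g = hn::SaturatedAdd(g, hn::Set(d, uint8_t{32}));  // clamps at 255
    hn::StoreInterleaved3(r, g, b, d, rgb + 3 * i);
  }
  // Handling of the remaining num_pixels % N pixels is omitted for brevity.
}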
+ +namespace detail { + +// Input (128-bit blocks): +// 6 4 2 0 (LSB of i) +// 7 5 3 1 +// Output: +// 3 2 1 0 +// 7 6 5 4 +template +HWY_API void StoreTransposedBlocks2(const VFromD i, const VFromD j, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const auto j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); + const auto j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); + const auto j1_i1_j0_i0 = + detail::Shuffle128<_MM_PERM_DBCA>(j1_j0_i1_i0, j1_j0_i1_i0); + const auto j3_i3_j2_i2 = + detail::Shuffle128<_MM_PERM_DBCA>(j3_j2_i3_i2, j3_j2_i3_i2); + StoreU(j1_i1_j0_i0, d, unaligned + 0 * N); + StoreU(j3_i3_j2_i2, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 9 6 3 0 (LSB of i) +// a 7 4 1 +// b 8 5 2 +// Output: +// 3 2 1 0 +// 7 6 5 4 +// b a 9 8 +template +HWY_API void StoreTransposedBlocks3(const VFromD i, const VFromD j, + const VFromD k, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD j2_j0_i2_i0 = detail::Shuffle128<_MM_PERM_CACA>(i, j); + const VFromD i3_i1_k2_k0 = detail::Shuffle128<_MM_PERM_DBCA>(k, i); + const VFromD j3_j1_k3_k1 = detail::Shuffle128<_MM_PERM_DBDB>(k, j); + + const VFromD out0 = // i1 k0 j0 i0 + detail::Shuffle128<_MM_PERM_CACA>(j2_j0_i2_i0, i3_i1_k2_k0); + const VFromD out1 = // j2 i2 k1 j1 + detail::Shuffle128<_MM_PERM_DBAC>(j3_j1_k3_k1, j2_j0_i2_i0); + const VFromD out2 = // k3 j3 i3 k2 + detail::Shuffle128<_MM_PERM_BDDB>(i3_i1_k2_k0, j3_j1_k3_k1); + + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// c 8 4 0 (LSB of i) +// d 9 5 1 +// e a 6 2 +// f b 7 3 +// Output: +// 3 2 1 0 +// 7 6 5 4 +// b a 9 8 +// f e d c +template +HWY_API void StoreTransposedBlocks4(const VFromD i, const VFromD j, + const VFromD k, const VFromD l, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); + const VFromD l1_l0_k1_k0 = detail::Shuffle128<_MM_PERM_BABA>(k, l); + const VFromD j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); + const VFromD l3_l2_k3_k2 = detail::Shuffle128<_MM_PERM_DCDC>(k, l); + const VFromD out0 = + detail::Shuffle128<_MM_PERM_CACA>(j1_j0_i1_i0, l1_l0_k1_k0); + const VFromD out1 = + detail::Shuffle128<_MM_PERM_DBDB>(j1_j0_i1_i0, l1_l0_k1_k0); + const VFromD out2 = + detail::Shuffle128<_MM_PERM_CACA>(j3_j2_i3_i2, l3_l2_k3_k2); + const VFromD out3 = + detail::Shuffle128<_MM_PERM_DBDB>(j3_j2_i3_i2, l3_l2_k3_k2); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ Additional mask logical operations + +template +HWY_API Mask512 SetAtOrAfterFirst(Mask512 mask) { + return Mask512{ + static_cast::Raw>(0u - detail::AVX3Blsi(mask.raw))}; +} +template +HWY_API Mask512 SetBeforeFirst(Mask512 mask) { + return Mask512{ + static_cast::Raw>(detail::AVX3Blsi(mask.raw) - 1u)}; +} +template +HWY_API Mask512 SetAtOrBeforeFirst(Mask512 mask) { + return Mask512{ + static_cast::Raw>(detail::AVX3Blsmsk(mask.raw))}; +} +template +HWY_API Mask512 SetOnlyFirst(Mask512 mask) { + return Mask512{ + static_cast::Raw>(detail::AVX3Blsi(mask.raw))}; +} + +// ------------------------------ Shl (Dup128VecFromValues) + +HWY_API Vec512 operator<<(Vec512 v, Vec512 bits) { + return Vec512{_mm512_sllv_epi16(v.raw, bits.raw)}; 
+} + +// 8-bit: may use the << overload for uint16_t. +HWY_API Vec512 operator<<(Vec512 v, Vec512 bits) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3_DL + // kMask[i] = 0xFF >> i + const VFromD masks = + Dup128VecFromValues(d, 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0, + 0, 0, 0, 0, 0, 0, 0); + // kShl[i] = 1 << i + const VFromD shl = + Dup128VecFromValues(d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0, + 0, 0, 0, 0, 0, 0, 0); + v = And(v, TableLookupBytes(masks, bits)); + const VFromD mul = TableLookupBytes(shl, bits); + return VFromD{_mm512_gf2p8mul_epi8(v.raw, mul.raw)}; +#else + const Repartition dw; + using VW = VFromD; + const VW even_mask = Set(dw, 0x00FF); + const VW odd_mask = Set(dw, 0xFF00); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + // Shift even lanes in-place + const VW evens = vw << And(bits16, even_mask); + const VW odds = And(vw, odd_mask) << ShiftRight<8>(bits16); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +#endif +} + +HWY_API Vec512 operator<<(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator<<(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_sllv_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template +HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { + const DFromV di; + const RebindToUnsigned du; + return BitCast(di, BitCast(du, v) << BitCast(du, bits)); +} + +// ------------------------------ Shr (IfVecThenElse) + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi16(v.raw, bits.raw)}; +} + +// 8-bit uses 16-bit shifts. +HWY_API Vec512 operator>>(Vec512 v, Vec512 bits) { + const DFromV d; + const RepartitionToWide dw; + using VW = VFromD; + const VW mask = Set(dw, 0x00FF); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + const VW evens = And(vw, mask) >> And(bits16, mask); + // Shift odd lanes in-place + const VW odds = vw >> ShiftRight<8>(bits16); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi64(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi16(v.raw, bits.raw)}; +} + +// 8-bit uses 16-bit shifts. 
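[Editor's sketch, not part of the patch] A scalar model of the even/odd trick used by the unsigned 8-bit shifts above; the signed overload that follows applies the same structure and additionally sign-extends the even lane. Each 16-bit lane holds two 8-bit lanes (low byte = even lane, high byte = odd lane). All names here are illustrative.

#include <cstdint>
#include <cstdio>

// Emulates an independent u8 >> for both bytes using only 16-bit operations.
static uint16_t ShrU8PairVia16(uint16_t vw, uint16_t bits16) {
  // Even lane: mask off the odd byte, then shift by the even lane's count.
  const uint16_t evens =
      static_cast<uint16_t>((vw & 0x00FF) >> (bits16 & 0x00FF));
  // Odd lane: shift the whole word by the odd lane's count; only the high
  // byte is kept, so bits spilling into the low byte do not matter.
  const uint16_t odds = static_cast<uint16_t>(vw >> (bits16 >> 8));
  return static_cast<uint16_t>((odds & 0xFF00) | (evens & 0x00FF));  // OddEven
}

int main() {  // exhaustive check of the model
  for (uint32_t v = 0; v < 0x10000; ++v) {
    for (uint32_t s = 0; s < 8; ++s) {
      const uint16_t bits16 = static_cast<uint16_t>((s << 8) | (7 - s));
      const uint16_t want = static_cast<uint16_t>(
          (((v >> 8) >> s) << 8) | ((v & 0xFF) >> (7 - s)));
      if (ShrU8PairVia16(static_cast<uint16_t>(v), bits16) != want) return 1;
    }
  }
  std::puts("even/odd model ok");
  return 0;
}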
+HWY_API Vec512 operator>>(Vec512 v, Vec512 bits) { + const DFromV d; + const RepartitionToWide dw; + const RebindToUnsigned dw_u; + using VW = VFromD; + const VW mask = Set(dw, 0x00FF); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + const VW evens = ShiftRight<8>(ShiftLeft<8>(vw)) >> And(bits16, mask); + // Shift odd lanes in-place + const VW odds = vw >> BitCast(dw, ShiftRight<8>(BitCast(dw_u, bits16))); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi64(v.raw, bits.raw)}; +} + +// ------------------------------ WidenMulPairwiseAdd + +#if HWY_NATIVE_DOT_BF16 +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return VFromD{_mm512_dpbf16_ps(Zero(df).raw, + reinterpret_cast<__m512bh>(a.raw), + reinterpret_cast<__m512bh>(b.raw))}; +} +#endif // HWY_NATIVE_DOT_BF16 + +template +HWY_API VFromD WidenMulPairwiseAdd(D /*d32*/, Vec512 a, + Vec512 b) { + return VFromD{_mm512_madd_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SatWidenMulPairwiseAdd +template +HWY_API VFromD SatWidenMulPairwiseAdd( + DI16 /* tag */, VFromD> a, + VFromD> b) { + return VFromD{_mm512_maddubs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SatWidenMulPairwiseAccumulate +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm512_dpwssds_epi32(sum.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ ReorderWidenMulAccumulate + +#if HWY_NATIVE_DOT_BF16 +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b, + const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD{_mm512_dpbf16_ps(sum0.raw, + reinterpret_cast<__m512bh>(a.raw), + reinterpret_cast<__m512bh>(b.raw))}; +} +#endif // HWY_NATIVE_DOT_BF16 + +template +HWY_API VFromD ReorderWidenMulAccumulate(D d, Vec512 a, + Vec512 b, + const VFromD sum0, + VFromD& /*sum1*/) { + (void)d; +#if HWY_TARGET <= HWY_AVX3_DL + return VFromD{_mm512_dpwssd_epi32(sum0.raw, a.raw, b.raw)}; +#else + return sum0 + WidenMulPairwiseAdd(d, a, b); +#endif +} + +HWY_API Vec512 RearrangeToOddPlusEven(const Vec512 sum0, + Vec512 /*sum1*/) { + return sum0; // invariant already holds +} + +HWY_API Vec512 RearrangeToOddPlusEven(const Vec512 sum0, + Vec512 /*sum1*/) { + return sum0; // invariant already holds +} + +// ------------------------------ SumOfMulQuadAccumulate + +#if HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD{_mm512_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; +} + +#endif + +// ------------------------------ Reductions + +namespace detail { + +// Used by generic_ops-inl +template +HWY_INLINE VFromD ReduceAcrossBlocks(D d, Func f, VFromD v) { + v = f(v, SwapAdjacentBlocks(v)); + return f(v, ReverseBlocks(d, v)); +} + +} // namespace detail + +// ------------------------------ BitShuffle +#if HWY_TARGET <= HWY_AVX3_DL +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(V, 64), HWY_IF_V_SIZE_V(VI, 64)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Rebind du8; + + const __mmask64 mmask64_bit_shuf_result = + _mm512_bitshuffle_epi64_mask(v.raw, idx.raw); + +#if 
HWY_ARCH_X86_64 + const VFromD vu8_bit_shuf_result{ + _mm_cvtsi64_si128(static_cast(mmask64_bit_shuf_result))}; +#else + const int32_t i32_lo_bit_shuf_result = + static_cast(mmask64_bit_shuf_result); + const int32_t i32_hi_bit_shuf_result = + static_cast(_kshiftri_mask64(mmask64_bit_shuf_result, 32)); + + const VFromD vu8_bit_shuf_result = ResizeBitCast( + du8, InterleaveLower( + Vec128{_mm_cvtsi32_si128(i32_lo_bit_shuf_result)}, + Vec128{_mm_cvtsi32_si128(i32_hi_bit_shuf_result)})); +#endif + + return BitCast(d64, PromoteTo(du64, vu8_bit_shuf_result)); +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ MultiRotateRight + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_MULTIROTATERIGHT +#undef HWY_NATIVE_MULTIROTATERIGHT +#else +#define HWY_NATIVE_MULTIROTATERIGHT +#endif + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(V, 64), HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V MultiRotateRight(V v, VI idx) { + return V{_mm512_multishift_epi64_epi8(idx.raw, v.raw)}; +} + +#endif + +// -------------------- LeadingZeroCount + +template ), HWY_IF_V_SIZE_V(V, 64)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm512_lzcnt_epi32(v.raw)}; +} + +template ), HWY_IF_V_SIZE_V(V, 64)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm512_lzcnt_epi64(v.raw)}; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/third_party/aom/third_party/highway/hwy/ops/x86_avx3-inl.h b/third_party/aom/third_party/highway/hwy/ops/x86_avx3-inl.h new file mode 100644 index 000000000000..80f9488c6e3d --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/ops/x86_avx3-inl.h @@ -0,0 +1,507 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// External include guard in highway.h - see comment there. + +#if HWY_TARGET == HWY_AVX10_2 +// For AVX10 targets that only support 256-bit or smaller vectors. Already +// includes base.h and shared-inl.h. +#include "third_party/highway/hwy/ops/x86_256-inl.h" +#else +// For AVX3/AVX10 targets that support 512-byte vectors. Already includes base.h +// and shared-inl.h. 
+#include "third_party/highway/hwy/ops/x86_512-inl.h" +#endif + +// AVX3/AVX10 ops that have dependencies on ops defined in x86_512-inl.h if +// HWY_MAX_BYTES >= 64 is true are defined below + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, + ignored "-Wmaybe-uninitialized") +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +#if HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ ShiftLeft + +// Generic for all vector lengths. Must be defined after all GaloisAffine. +template +HWY_API V ShiftLeft(const V v) { + const Repartition> du64; + if (kBits == 0) return v; + if (kBits == 1) return v + v; + constexpr uint64_t kMatrix = (0x0102040810204080ULL >> kBits) & + (0x0101010101010101ULL * (0xFF >> kBits)); + return detail::GaloisAffine(v, Set(du64, kMatrix)); +} + +// ------------------------------ ShiftRight + +// Generic for all vector lengths. Must be defined after all GaloisAffine. +template )> +HWY_API V ShiftRight(const V v) { + const Repartition> du64; + if (kBits == 0) return v; + constexpr uint64_t kMatrix = + (0x0102040810204080ULL << kBits) & + (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF)); + return detail::GaloisAffine(v, Set(du64, kMatrix)); +} + +// Generic for all vector lengths. Must be defined after all GaloisAffine. +template )> +HWY_API V ShiftRight(const V v) { + const Repartition> du64; + if (kBits == 0) return v; + constexpr uint64_t kShift = + (0x0102040810204080ULL << kBits) & + (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF)); + constexpr uint64_t kSign = + kBits == 0 ? 0 : (0x8080808080808080ULL >> (64 - (8 * kBits))); + return detail::GaloisAffine(v, Set(du64, kShift | kSign)); +} + +// ------------------------------ RotateRight + +// U8 RotateRight is generic for all vector lengths on AVX3_DL +template )> +HWY_API V RotateRight(V v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + + const Repartition> du64; + if (kBits == 0) return v; + + constexpr uint64_t kShrMatrix = + (0x0102040810204080ULL << kBits) & + (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF)); + constexpr int kShlBits = (-kBits) & 7; + constexpr uint64_t kShlMatrix = (0x0102040810204080ULL >> kShlBits) & + (0x0101010101010101ULL * (0xFF >> kShlBits)); + constexpr uint64_t kMatrix = kShrMatrix | kShlMatrix; + + return detail::GaloisAffine(v, Set(du64, kMatrix)); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ Compress + +#pragma push_macro("HWY_X86_SLOW_COMPRESS_STORE") + +#ifndef HWY_X86_SLOW_COMPRESS_STORE // allow override +// Slow on Zen4 and SPR, faster if we emulate via Compress(). +#if HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX3_SPR +#define HWY_X86_SLOW_COMPRESS_STORE 1 +#else +#define HWY_X86_SLOW_COMPRESS_STORE 0 +#endif +#endif // HWY_X86_SLOW_COMPRESS_STORE + +// Always implement 8-bit here even if we lack VBMI2 because we can do better +// than generic_ops (8 at a time) via the native 32-bit compress (16 at a time). 
+#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 +template +HWY_INLINE Vec128 NativeCompress(const Vec128 v, + const Mask128 mask) { + return Vec128{_mm_maskz_compress_epi8(mask.raw, v.raw)}; +} +HWY_INLINE Vec256 NativeCompress(const Vec256 v, + const Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi8(mask.raw, v.raw)}; +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE Vec512 NativeCompress(const Vec512 v, + const Mask512 mask) { + return Vec512{_mm512_maskz_compress_epi8(mask.raw, v.raw)}; +} +#endif + +template +HWY_INLINE Vec128 NativeCompress(const Vec128 v, + const Mask128 mask) { + return Vec128{_mm_maskz_compress_epi16(mask.raw, v.raw)}; +} +HWY_INLINE Vec256 NativeCompress(const Vec256 v, + const Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi16(mask.raw, v.raw)}; +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE Vec512 NativeCompress(const Vec512 v, + const Mask512 mask) { + return Vec512{_mm512_maskz_compress_epi16(mask.raw, v.raw)}; +} +#endif + +// Do not even define these to prevent accidental usage. +#if !HWY_X86_SLOW_COMPRESS_STORE + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + uint8_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + uint8_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + uint8_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} +#endif + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + uint16_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + uint16_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + uint16_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} +#endif // HWY_MAX_BYTES >= 64 + +#endif // HWY_X86_SLOW_COMPRESS_STORE + +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_INLINE Vec128 NativeCompress(Vec128 v, + Mask128 mask) { + return Vec128{_mm_maskz_compress_epi32(mask.raw, v.raw)}; +} +HWY_INLINE Vec256 NativeCompress(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; +} + +#if HWY_MAX_BYTES >= 64 +HWY_INLINE Vec512 NativeCompress(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_compress_epi32(mask.raw, v.raw)}; +} +#endif +// We use table-based compress for 64-bit lanes, see CompressIsPartition. + +// Do not even define these to prevent accidental usage. 
+#if !HWY_X86_SLOW_COMPRESS_STORE + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + uint32_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + uint32_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + uint32_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} +#endif + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + uint64_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + uint64_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + uint64_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} +#endif + +template +HWY_INLINE void NativeCompressStore(Vec128 v, Mask128 mask, + float* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + float* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + float* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); +} +#endif + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + double* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + double* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + double* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); +} +#endif + +#endif // HWY_X86_SLOW_COMPRESS_STORE + +// For u8x16 and <= u16x16 we can avoid store+load for Compress because there is +// only a single compressed vector (u32x16). Other EmuCompress are implemented +// after the EmuCompressStore they build upon. +template ), + HWY_IF_LANES_LE_D(DFromV, HWY_MAX_BYTES / 4)> +static HWY_INLINE HWY_MAYBE_UNUSED V EmuCompress(V v, MFromD> mask) { + const DFromV d; + const Rebind d32; + const VFromD v0 = PromoteTo(d32, v); + + using M32 = MFromD; + const M32 m0 = PromoteMaskTo(d32, d, mask); + return TruncateTo(d, Compress(v0, m0)); +} + +template ), + HWY_IF_LANES_LE_D(DFromV, HWY_MAX_BYTES / 4)> +static HWY_INLINE HWY_MAYBE_UNUSED V EmuCompress(V v, MFromD> mask) { + const DFromV d; + const Rebind di32; + const RebindToUnsigned du32; + + const MFromD mask32 = PromoteMaskTo(du32, d, mask); + // DemoteTo is 2 ops, but likely lower latency than TruncateTo on SKX. + // Only i32 -> u16 is supported, whereas NativeCompress expects u32. + const VFromD v32 = PromoteTo(du32, v); + return DemoteTo(d, BitCast(di32, NativeCompress(v32, mask32))); +} + +// See above - small-vector EmuCompressStore are implemented via EmuCompress. 
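[Editor's sketch, not part of the patch] The emulation below has two tiers: small vectors go through EmuCompress and a single store, while the wide-vector EmuCompressStore splits into halves and writes the upper half immediately after the CountTrue() lanes produced by the lower half. A scalar model of that split, with illustrative names:

#include <cstddef>
#include <cstdint>
#include <vector>

// Writes the lanes of v selected by mask_bits to `out`; returns the count.
template <typename T>
size_t CompressStoreModel(const std::vector<T>& v, uint64_t mask_bits, T* out) {
  size_t count = 0;
  for (size_t i = 0; i < v.size(); ++i) {
    if ((mask_bits >> i) & 1) out[count++] = v[i];
  }
  return count;
}

// Wide-vector strategy: process halves; the upper half starts exactly where
// the lower half stopped, so the result matches compressing the whole vector.
template <typename T>
size_t EmuCompressStoreModel(const std::vector<T>& v, uint64_t mask_bits,
                             T* out) {
  const size_t half = v.size() / 2;
  const std::vector<T> lo(v.begin(), v.begin() + half);
  const std::vector<T> hi(v.begin() + half, v.end());
  const size_t lo_count = CompressStoreModel(lo, mask_bits, out);  // CountTrue
  return lo_count + CompressStoreModel(hi, mask_bits >> half, out + lo_count);
}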
+template +static HWY_INLINE HWY_MAYBE_UNUSED void EmuCompressStore( + VFromD v, MFromD mask, D d, TFromD* HWY_RESTRICT unaligned) { + StoreU(EmuCompress(v, mask), d, unaligned); +} + +// Main emulation logic for wider vector, starting with EmuCompressStore because +// it is most convenient to merge pieces using memory (concatenating vectors at +// byte offsets is difficult). +template +static HWY_INLINE HWY_MAYBE_UNUSED void EmuCompressStore( + VFromD v, MFromD mask, D d, TFromD* HWY_RESTRICT unaligned) { + const Half dh; + + const MFromD m0 = LowerHalfOfMask(dh, mask); + const MFromD m1 = UpperHalfOfMask(dh, mask); + + const VFromD v0 = LowerHalf(dh, v); + const VFromD v1 = UpperHalf(dh, v); + + EmuCompressStore(v0, m0, dh, unaligned); + EmuCompressStore(v1, m1, dh, unaligned + CountTrue(dh, m0)); +} + +// Finally, the remaining EmuCompress for wide vectors, using EmuCompressStore. +template , HWY_MAX_BYTES / 4)> +static HWY_INLINE HWY_MAYBE_UNUSED V EmuCompress(V v, MFromD> mask) { + using D = DFromV; + using T = TFromD; + const D d; + + alignas(HWY_MAX_LANES_D(D) * sizeof(T)) T buf[2 * HWY_MAX_LANES_D(D)]; + EmuCompressStore(v, mask, d, buf); + return Load(d, buf); +} + +} // namespace detail + +template +HWY_API V Compress(V v, const M mask) { + const DFromV d; + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + return BitCast(d, detail::NativeCompress(BitCast(du, v), mu)); +#else + return BitCast(d, detail::EmuCompress(BitCast(du, v), mu)); +#endif +} + +template +HWY_API V Compress(V v, const M mask) { + const DFromV d; + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + return BitCast(d, detail::NativeCompress(BitCast(du, v), mu)); +} + +// ------------------------------ CompressNot + +template +HWY_API V CompressNot(V v, const M mask) { + return Compress(v, Not(mask)); +} + +// uint64_t lanes. Only implement for 256 and 512-bit vectors because this is a +// no-op for 128-bit. +template , 16)> +HWY_API V CompressBlocksNot(V v, M mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressBits +template +HWY_API V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV(), bits)); +} + +// ------------------------------ CompressStore + +// Generic for all vector lengths. 
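[Editor's sketch, not part of the patch] Typical caller-side use of CompressStore is stream filtering: compress the selected lanes, store them, and advance the output by the returned count. The helper name, buffers and the single-target dispatch below are illustrative assumptions, not part of the patch.

#include <cstddef>

#include "third_party/highway/hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;  // static dispatch, single target

// Copies elements of `in` that are < limit to `out`; returns how many.
// Note: CompressStore may pad up to a full vector, so `out` must have room
// for `count` elements even if fewer are kept.
size_t KeepBelow(const float* HWY_RESTRICT in, size_t count, float limit,
                 float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  size_t written = 0;
  for (size_t i = 0; i < count; i += N) {  // assumes count % N == 0
    const auto v = hn::LoadU(d, in + i);
    const auto keep = hn::Lt(v, hn::Set(d, limit));
    written += hn::CompressStore(v, keep, d, out + written);
  }
  return written;
}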
+ +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { +#if HWY_X86_SLOW_COMPRESS_STORE + StoreU(Compress(v, mask), d, unaligned); +#else + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + auto pu = reinterpret_cast * HWY_RESTRICT>(unaligned); + +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + detail::NativeCompressStore(BitCast(du, v), mu, pu); +#else + detail::EmuCompressStore(BitCast(du, v), mu, du, pu); +#endif +#endif // HWY_X86_SLOW_COMPRESS_STORE + const size_t count = CountTrue(d, mask); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { +#if HWY_X86_SLOW_COMPRESS_STORE + StoreU(Compress(v, mask), d, unaligned); +#else + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + using TU = TFromD; + TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + detail::NativeCompressStore(BitCast(du, v), mu, pu); +#endif // HWY_X86_SLOW_COMPRESS_STORE + const size_t count = CountTrue(d, mask); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// Additional overloads to avoid casting to uint32_t (delay?). +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { +#if HWY_X86_SLOW_COMPRESS_STORE + StoreU(Compress(v, mask), d, unaligned); +#else + (void)d; + detail::NativeCompressStore(v, mask, unaligned); +#endif // HWY_X86_SLOW_COMPRESS_STORE + const size_t count = PopCount(uint64_t{mask.raw}); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + // Native CompressStore already does the blending at no extra cost (latency + // 11, rthroughput 2 - same as compress plus store). + + HWY_IF_CONSTEXPR(HWY_MAX_LANES_D(D) < (16 / sizeof(TFromD))) { + m = And(m, FirstN(d, HWY_MAX_LANES_D(D))); + } + + HWY_IF_CONSTEXPR(!HWY_X86_SLOW_COMPRESS_STORE && + (HWY_TARGET <= HWY_AVX3_DL || sizeof(TFromD) > 2)) { + return CompressStore(v, m, d, unaligned); + } + else { + const size_t count = CountTrue(d, m); + StoreN(Compress(v, m), d, unaligned, count); + detail::MaybeUnpoison(unaligned, count); + return count; + } +} + +// ------------------------------ CompressBitsStore +// Generic for all vector lengths. +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +#pragma pop_macro("HWY_X86_SLOW_COMPRESS_STORE") + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/third_party/aom/third_party/highway/hwy/per_target.h b/third_party/aom/third_party/highway/hwy/per_target.h new file mode 100644 index 000000000000..196cd42aa063 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/per_target.h @@ -0,0 +1,49 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_PER_TARGET_H_ +#define HIGHWAY_HWY_PER_TARGET_H_ + +#include +#include + +#include "third_party/highway/hwy/highway_export.h" + +// Functions to query the capabilities of the target that will be called by +// HWY_DYNAMIC_DISPATCH, which is not necessarily the current target. + +namespace hwy { + +// Returns the HWY_TARGET which HWY_DYNAMIC_DISPATCH selected. +HWY_DLLEXPORT int64_t DispatchedTarget(); + +// Returns size in bytes of a vector, i.e. `Lanes(ScalableTag())`. +// +// Do not cache the result, which may change after calling DisableTargets, or +// if software requests a different vector size (e.g. when entering/exiting SME +// streaming mode). Instead call this right before the code that depends on the +// result, without any DisableTargets or SME transition in-between. Note that +// this involves an indirect call, so prefer not to call this frequently nor +// unnecessarily. +HWY_DLLEXPORT size_t VectorBytes(); + +// Returns whether 64-bit integers, 16/64-bit floats are a supported lane type. +HWY_DLLEXPORT bool HaveInteger64(); +HWY_DLLEXPORT bool HaveFloat16(); +HWY_DLLEXPORT bool HaveFloat64(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_PER_TARGET_H_ diff --git a/third_party/aom/third_party/highway/hwy/perf_counters.h b/third_party/aom/third_party/highway/hwy/perf_counters.h new file mode 100644 index 000000000000..2764e985bdce --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/perf_counters.h @@ -0,0 +1,156 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_PERF_COUNTERS_H_ +#define HIGHWAY_HWY_PERF_COUNTERS_H_ + +// Reads OS/CPU performance counters. + +#include + +#include "third_party/highway/hwy/base.h" // HWY_ABORT +#include "third_party/highway/hwy/bit_set.h" + +namespace hwy { +namespace platform { + +// Avoid padding in case callers such as profiler.h store many instances. +#pragma pack(push, 1) +// Provides access to CPU/OS performance counters. Each instance has space for +// multiple counter values; which counters these are may change in future. +// Although counters are per-CPU, Linux accesses them via a syscall, hence we +// use the monostate pattern to avoid callers having to pass around a pointer. +// Note that this is not thread-safe, so the static member functions should only +// be called from the main thread. +class PerfCounters { + public: + // Chosen such that this class occupies one or two cache lines. + static constexpr size_t kCapacity = 14; + + // Bit indices used to identify counters. 
The ordering is arbitrary. Some of + // these counters may be 'removed' in the sense of not being visited by + // `Foreach`, but their enumerators will remain. New counters may be appended. + enum Counter { + kRefCycles = 0, + kInstructions, + kBranches, + kBranchMispredicts, + kBusCycles, + kCacheRefs, + kCacheMisses, + kL3Loads, + kL3Stores, + kPageFaults, // SW + kMigrations // SW + }; // BitSet64 requires these values to be less than 64. + + // Strings for user-facing messages, not used in the implementation. + static inline const char* Name(Counter c) { + switch (c) { + case kRefCycles: + return "ref_cycles"; + case kInstructions: + return "instructions"; + case kBranches: + return "branches"; + case kBranchMispredicts: + return "branch_mispredicts"; + case kBusCycles: + return "bus_cycles"; + case kCacheRefs: + return "cache_refs"; + case kCacheMisses: + return "cache_misses"; + case kL3Loads: + return "l3_load"; + case kL3Stores: + return "l3_store"; + case kPageFaults: + return "page_fault"; + case kMigrations: + return "migration"; + default: + HWY_ABORT("Bug: unknown counter %d", c); + } + } + + // Returns false if counters are unavailable. Must be called at least once + // before `StartAll`; it is separate to reduce the overhead of repeatedly + // stopping/starting counters. + HWY_DLLEXPORT static bool Init(); + + // Returns false if counters are unavailable, otherwise starts them. Note that + // they default to stopped. Unless this is called, the values read may be 0. + HWY_DLLEXPORT static bool StartAll(); + + // Stops and zeros all counters. This is not necessary if users subtract the + // previous counter values, but can increase precision because floating-point + // has more precision near zero. + HWY_DLLEXPORT static void StopAllAndReset(); + + // Reads the current (extrapolated, in case of multiplexing) counter values. + HWY_DLLEXPORT PerfCounters(); + + // Returns whether any counters were successfully read. + bool AnyValid() const { return valid_.Any(); } + + // Returns whether the given counter was successfully read. + bool IsValid(Counter c) const { + const size_t bit_idx = static_cast(c); + return valid_.Get(bit_idx); + } + + // Returns the maximum extrapolation factor for any counter, which is the + // total time between `StartAll` and now or the last `StopAllAndReset`, + // divided by the time that the counter was actually running. This + // approximates the number of counter groups that the CPU multiplexes onto the + // actual counter hardware. It is only meaningful if AnyValid(). + double MaxExtrapolate() const { return max_extrapolate_; } + + // Returns the value of the given counter, or zero if it is not valid. + double Get(Counter c) const { + return IsValid(c) ? values_[IndexForCounter(c)] : 0.0; + } + + // For each valid counter in increasing numerical order, calls `visitor` with + // the value and `Counter`. + template + void Foreach(const Visitor& visitor) { + valid_.Foreach([&](size_t bit_idx) { + const Counter c = static_cast(bit_idx); + visitor(values_[IndexForCounter(c)], c); + }); + } + + private: + // Index within `values_` for a given counter. + HWY_DLLEXPORT static size_t IndexForCounter(Counter c); + + BitSet64 valid_; + double max_extrapolate_; + // Floating-point because these are extrapolated (multiplexing). It would be + // nice for this to fit in one cache line to reduce the cost of reading + // counters in profiler.h, but some of the values are too large for float and + // we want more than 8 counters. 
Ensure all values are sums, not ratios, so + // that profiler.h can add/subtract them. These are contiguous in memory, in + // the order that counters were initialized. + double values_[kCapacity]; +}; +#pragma pack(pop) + +} // namespace platform +} // namespace hwy + +#endif // HIGHWAY_HWY_PERF_COUNTERS_H_ diff --git a/third_party/aom/third_party/highway/hwy/print-inl.h b/third_party/aom/third_party/highway/hwy/print-inl.h new file mode 100644 index 000000000000..16cfa141ebbb --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/print-inl.h @@ -0,0 +1,62 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Print() function + +#include "third_party/highway/hwy/highway.h" +#include "third_party/highway/hwy/print.h" + +// Per-target include guard +#if defined(HIGHWAY_HWY_PRINT_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_PRINT_INL_H_ +#undef HIGHWAY_HWY_PRINT_INL_H_ +#else +#define HIGHWAY_HWY_PRINT_INL_H_ +#endif + +#if HWY_TARGET == HWY_RVV +#include "third_party/highway/hwy/aligned_allocator.h" +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Prints lanes around `lane`, in memory order. +template > +HWY_API void Print(const D d, const char* caption, V v, size_t lane_u = 0, + size_t max_lanes = 7) { + const size_t N = Lanes(d); + using T = TFromD; +#if HWY_TARGET == HWY_RVV + auto storage = AllocateAligned(N); + T* HWY_RESTRICT lanes = storage.get(); +#else + // This works around an SVE compile error on GCC 11 and 12. Calling + // AllocateAligned here would seem to require it be marked with HWY_ATTR. + HWY_ALIGN T lanes[MaxLanes(d)]; +#endif + Store(v, d, lanes); + + const auto info = hwy::detail::MakeTypeInfo(); + hwy::detail::PrintArray(info, caption, lanes, N, lane_u, max_lanes); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // per-target include guard diff --git a/third_party/aom/third_party/highway/hwy/print.h b/third_party/aom/third_party/highway/hwy/print.h new file mode 100644 index 000000000000..3de44726fac0 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/print.h @@ -0,0 +1,75 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HWY_PRINT_H_ +#define HWY_PRINT_H_ + +// Helpers for printing vector lanes. 
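[Editor's sketch, not part of the patch] Typical debugging use of these helpers: Print() is the per-target function defined above in print-inl.h, while PrintArray/PrintValue are the type-erased helpers declared below. Assumes static dispatch and that print-inl.h is included once after highway.h; all other names are illustrative.

#include <cstdint>

#include "third_party/highway/hwy/highway.h"
#include "third_party/highway/hwy/print-inl.h"

namespace hn = hwy::HWY_NAMESPACE;

void DebugDump() {
  const hn::ScalableTag<int32_t> d;
  const auto v = hn::Iota(d, 0);  // 0, 1, 2, ...
  hn::Print(d, "iota", v);                          // lanes around index 0
  hn::Print(d, "iota tail", v, hn::Lanes(d) - 1);   // window around last lane

  const float arr[4] = {1.5f, 2.5f, 3.5f, 4.5f};
  hwy::PrintArray(arr, 4);  // type-erased array printing
  hwy::PrintValue(42);      // single value
}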
+ +#include +#include + +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/highway_export.h" + +namespace hwy { + +namespace detail { + +// For implementing value comparisons etc. as type-erased functions to reduce +// template bloat. +struct TypeInfo { + size_t sizeof_t; + bool is_float; + bool is_signed; + bool is_bf16; +}; + +template +HWY_INLINE TypeInfo MakeTypeInfo() { + TypeInfo info; + info.sizeof_t = sizeof(T); + info.is_float = IsFloat(); + info.is_signed = IsSigned(); + info.is_bf16 = IsSame(); + return info; +} + +HWY_DLLEXPORT void TypeName(const TypeInfo& info, size_t N, char* string100); +HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr, + char* string100); + +HWY_DLLEXPORT void PrintArray(const TypeInfo& info, const char* caption, + const void* array_void, size_t N, + size_t lane_u = 0, size_t max_lanes = 7); + +} // namespace detail + +template +HWY_NOINLINE void PrintValue(T value) { + char str[100]; + detail::ToString(hwy::detail::MakeTypeInfo(), &value, str); + fprintf(stderr, "%s,", str); +} + +template +HWY_NOINLINE void PrintArray(const T* value, size_t count) { + detail::PrintArray(hwy::detail::MakeTypeInfo(), "", value, count, 0, + count); +} + +} // namespace hwy + +#endif // HWY_PRINT_H_ diff --git a/third_party/aom/third_party/highway/hwy/profiler.h b/third_party/aom/third_party/highway/hwy/profiler.h new file mode 100644 index 000000000000..a9c28136154c --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/profiler.h @@ -0,0 +1,672 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_PROFILER_H_ +#define HIGHWAY_HWY_PROFILER_H_ + +// High precision, low overhead time measurements. Returns exact call counts and +// total elapsed time for user-defined 'zones' (code regions, i.e. C++ scopes). +// +// Uses RAII to capture begin/end timestamps, with user-specified zone names: +// { PROFILER_ZONE("name"); /*code*/ } or +// the name of the current function: +// void FuncToMeasure() { PROFILER_FUNC; /*code*/ }. +// +// After all threads have exited any zones, invoke PROFILER_PRINT_RESULTS() to +// print call counts and average durations [CPU cycles] to stdout, sorted in +// descending order of total duration. +// +// The binary MUST be built with --dynamic_mode=off because we rely on the data +// segments being nearby; if not, an assertion will likely fail. + +#include "third_party/highway/hwy/base.h" + +// Configuration settings: + +// If zero, this file has no effect and no measurements will be recorded. +#ifndef PROFILER_ENABLED +#define PROFILER_ENABLED 0 +#endif + +// How many mebibytes to allocate (if PROFILER_ENABLED) per thread that +// enters at least one zone. Once this buffer is full, the thread will analyze +// and discard packets, thus temporarily adding some observer overhead. +// Each zone occupies 16 bytes. 
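// Editor's note, not part of the patch: with the 200 MiB default below and
// 16 bytes per zone (one 8-byte Packet each for entry and exit), a thread can
// record about 200 * 2^20 / 16 = 13.1 million zone invocations before the
// analyze-and-discard step described above has to run.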
+#ifndef PROFILER_THREAD_STORAGE +#define PROFILER_THREAD_STORAGE 200ULL +#endif + +#if PROFILER_ENABLED || HWY_IDE + +#include +#include +#include +#include // strcmp + +#include + +#include "third_party/highway/hwy/aligned_allocator.h" +#include "third_party/highway/hwy/cache_control.h" // FlushStream +#include "third_party/highway/hwy/contrib/sort/vqsort.h" +#include "third_party/highway/hwy/robust_statistics.h" +#include "third_party/highway/hwy/timer.h" + +#define PROFILER_PRINT_OVERHEAD 0 + +namespace hwy { + +// Upper bounds for fixed-size data structures (guarded via HWY_DASSERT): + +// How many threads can actually enter a zone (those that don't do not count). +// Memory use is about kMaxThreads * PROFILER_THREAD_STORAGE MiB. +// WARNING: a fiber library can spawn hundreds of threads. +static constexpr size_t kMaxThreads = 256; + +static constexpr size_t kMaxDepth = 64; // Maximum nesting of zones. + +static constexpr size_t kMaxZones = 256; // Total number of zones. + +#pragma pack(push, 1) + +// Represents zone entry/exit events. Stores a full-resolution timestamp plus +// an offset (representing zone name or identifying exit packets). POD. +class Packet { + public: + // If offsets do not fit, UpdateOrAdd will overrun our heap allocation + // (governed by kMaxZones). We have seen multi-megabyte offsets. + static constexpr size_t kOffsetBits = 25; + static constexpr uint64_t kOffsetBias = 1ULL << (kOffsetBits - 1); + + // We need full-resolution timestamps; at an effective rate of 4 GHz, + // this permits 1 minute zone durations (for longer durations, split into + // multiple zones). Wraparound is handled by masking. + static constexpr size_t kTimestampBits = 64 - kOffsetBits; + static constexpr uint64_t kTimestampMask = (1ULL << kTimestampBits) - 1; + + static Packet Make(const size_t biased_offset, const uint64_t timestamp) { + HWY_DASSERT(biased_offset != 0); + HWY_DASSERT(biased_offset < (1ULL << kOffsetBits)); + + Packet packet; + packet.bits_ = + (biased_offset << kTimestampBits) + (timestamp & kTimestampMask); + + HWY_DASSERT(packet.BiasedOffset() == biased_offset); + HWY_DASSERT(packet.Timestamp() == (timestamp & kTimestampMask)); + return packet; + } + + uint64_t Timestamp() const { return bits_ & kTimestampMask; } + + size_t BiasedOffset() const { + const size_t biased_offset = (bits_ >> kTimestampBits); + HWY_DASSERT(biased_offset != 0); + HWY_DASSERT(biased_offset < (1ULL << kOffsetBits)); + return biased_offset; + } + + private: + uint64_t bits_; +}; +static_assert(sizeof(Packet) == 8, "Wrong Packet size"); + +// All translation units must use the same string origin. A static member +// function ensures this without requiring a separate .cc file. +struct StringOrigin { + // Returns the address of a string literal. Assuming zone names are also + // literals and stored nearby, we can represent them as offsets from this, + // which is faster to compute than hashes or even a static index. + static const char* Get() { + // Chosen such that no zone name is a prefix nor suffix of this string + // to ensure they aren't merged. Note zone exit packets use + // `biased_offset == kOffsetBias`. + static const char* string_origin = "__#__"; + return string_origin - Packet::kOffsetBias; + } +}; + +// Representation of an active zone, stored in a stack. Used to deduct +// child duration from the parent's self time. POD. 
+struct Node { + Packet packet; + uint64_t child_total; +}; +static_assert(sizeof(Node) == 16, "Wrong Node size"); + +// Holds statistics for all zones with the same name. POD. +struct Accumulator { + static constexpr size_t kNumCallBits = 64 - Packet::kOffsetBits; + + uint64_t BiasedOffset() const { + const size_t biased_offset = u128.lo >> kNumCallBits; + HWY_DASSERT(biased_offset != 0); + HWY_DASSERT(biased_offset < (1ULL << Packet::kOffsetBits)); + return biased_offset; + } + uint64_t NumCalls() const { return u128.lo & ((1ULL << kNumCallBits) - 1); } + uint64_t Duration() const { return u128.hi; } + + void Set(uint64_t biased_offset, uint64_t num_calls, uint64_t duration) { + HWY_DASSERT(biased_offset != 0); + HWY_DASSERT(biased_offset < (1ULL << Packet::kOffsetBits)); + HWY_DASSERT(num_calls < (1ULL << kNumCallBits)); + + u128.hi = duration; + u128.lo = (biased_offset << kNumCallBits) + num_calls; + + HWY_DASSERT(BiasedOffset() == biased_offset); + HWY_DASSERT(NumCalls() == num_calls); + HWY_DASSERT(Duration() == duration); + } + + void Add(uint64_t num_calls, uint64_t duration) { + const uint64_t biased_offset = BiasedOffset(); + (void)biased_offset; + + u128.lo += num_calls; + u128.hi += duration; + + HWY_DASSERT(biased_offset == BiasedOffset()); + } + + // For fast sorting by duration, which must therefore be the hi element. + // lo holds BiasedOffset and NumCalls. + uint128_t u128; +}; +static_assert(sizeof(Accumulator) == 16, "Wrong Accumulator size"); + +template +inline T ClampedSubtract(const T minuend, const T subtrahend) { + if (subtrahend > minuend) { + return 0; + } + return minuend - subtrahend; +} + +// Per-thread call graph (stack) and Accumulator for each zone. +class Results { + public: + Results() { + ZeroBytes(nodes_, sizeof(nodes_)); + ZeroBytes(zones_, sizeof(zones_)); + } + + // Used for computing overhead when this thread encounters its first Zone. + // This has no observable effect apart from increasing "analyze_elapsed_". + uint64_t ZoneDuration(const Packet* packets) { + HWY_DASSERT(depth_ == 0); + HWY_DASSERT(num_zones_ == 0); + AnalyzePackets(packets, 2); + const uint64_t duration = zones_[0].Duration(); + zones_[0].Set(1, 0, 0); // avoids triggering biased_offset = 0 checks + HWY_DASSERT(depth_ == 0); + num_zones_ = 0; + return duration; + } + + void SetSelfOverhead(const uint64_t self_overhead) { + self_overhead_ = self_overhead; + } + + void SetChildOverhead(const uint64_t child_overhead) { + child_overhead_ = child_overhead; + } + + // Draw all required information from the packets, which can be discarded + // afterwards. Called whenever this thread's storage is full. + void AnalyzePackets(const Packet* packets, const size_t num_packets) { + const uint64_t t0 = timer::Start(); + + for (size_t i = 0; i < num_packets; ++i) { + const Packet p = packets[i]; + // Entering a zone + if (p.BiasedOffset() != Packet::kOffsetBias) { + HWY_DASSERT(depth_ < kMaxDepth); + nodes_[depth_].packet = p; + HWY_DASSERT(p.BiasedOffset() != 0); + nodes_[depth_].child_total = 0; + ++depth_; + continue; + } + + HWY_DASSERT(depth_ != 0); + const Node& node = nodes_[depth_ - 1]; + // Masking correctly handles unsigned wraparound. 
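// Editor's note, not part of the patch: kTimestampBits = 64 - 25 = 39, so
// timestamps are stored modulo 2^39 (about 137 s at 4 GHz, comfortably above
// the one-minute guidance in the Packet comment). Worked example: enter at
// 2^39 - 10, exit 15 cycles later is stored as 5; the unsigned subtraction
// below yields 2^64 - 2^39 + 15, and masking with kTimestampMask recovers the
// true duration of 15 cycles.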
+ const uint64_t duration = + (p.Timestamp() - node.packet.Timestamp()) & Packet::kTimestampMask; + const uint64_t self_duration = ClampedSubtract( + duration, self_overhead_ + child_overhead_ + node.child_total); + + UpdateOrAdd(node.packet.BiasedOffset(), 1, self_duration); + --depth_; + + // Deduct this nested node's time from its parent's self_duration. + if (depth_ != 0) { + nodes_[depth_ - 1].child_total += duration + child_overhead_; + } + } + + const uint64_t t1 = timer::Stop(); + analyze_elapsed_ += t1 - t0; + } + + // Incorporates results from another thread. Call after all threads have + // exited any zones. + void Assimilate(Results& other) { + const uint64_t t0 = timer::Start(); + HWY_DASSERT(depth_ == 0); + HWY_DASSERT(other.depth_ == 0); + + for (size_t i = 0; i < other.num_zones_; ++i) { + const Accumulator& zone = other.zones_[i]; + UpdateOrAdd(zone.BiasedOffset(), zone.NumCalls(), zone.Duration()); + } + other.num_zones_ = 0; + const uint64_t t1 = timer::Stop(); + analyze_elapsed_ += t1 - t0 + other.analyze_elapsed_; + } + + // Single-threaded. + void Print() { + const uint64_t t0 = timer::Start(); + MergeDuplicates(); + + // Sort by decreasing total (self) cost. + VQSort(&zones_[0].u128, num_zones_, SortDescending()); + + const double inv_freq = 1.0 / platform::InvariantTicksPerSecond(); + + const char* string_origin = StringOrigin::Get(); + for (size_t i = 0; i < num_zones_; ++i) { + const Accumulator& z = zones_[i]; + const size_t num_calls = z.NumCalls(); + const double duration = static_cast(z.Duration()); + printf("%-40s: %10zu x %15.0f = %9.6f\n", + string_origin + z.BiasedOffset(), num_calls, duration / num_calls, + duration * inv_freq); + } + num_zones_ = 0; + + const uint64_t t1 = timer::Stop(); + analyze_elapsed_ += t1 - t0; + printf("Total analysis [s]: %f\n", + static_cast(analyze_elapsed_) * inv_freq); + } + + private: + // Updates an existing Accumulator (uniquely identified by biased_offset) or + // adds one if this is the first time this thread analyzed that zone. + // Uses a self-organizing list data structure, which avoids dynamic memory + // allocations and is far faster than unordered_map. + void UpdateOrAdd(const size_t biased_offset, const uint64_t num_calls, + const uint64_t duration) { + HWY_DASSERT(biased_offset != 0); + HWY_DASSERT(biased_offset < (1ULL << Packet::kOffsetBits)); + + // Special case for first zone: (maybe) update, without swapping. + if (num_zones_ != 0 && zones_[0].BiasedOffset() == biased_offset) { + zones_[0].Add(num_calls, duration); + return; + } + + // Look for a zone with the same offset. + for (size_t i = 1; i < num_zones_; ++i) { + if (zones_[i].BiasedOffset() == biased_offset) { + zones_[i].Add(num_calls, duration); + // Swap with predecessor (more conservative than move to front, + // but at least as successful). + const Accumulator prev = zones_[i - 1]; + zones_[i - 1] = zones_[i]; + zones_[i] = prev; + return; + } + } + + // Not found; create a new Accumulator. + HWY_DASSERT(num_zones_ < kMaxZones); + zones_[num_zones_].Set(biased_offset, num_calls, duration); + ++num_zones_; + } + + // Each instantiation of a function template seems to get its own copy of + // __func__ and GCC doesn't merge them. An N^2 search for duplicates is + // acceptable because we only expect a few dozen zones. 
+ void MergeDuplicates() { + const char* string_origin = StringOrigin::Get(); + for (size_t i = 0; i < num_zones_; ++i) { + const size_t biased_offset = zones_[i].BiasedOffset(); + const char* name = string_origin + biased_offset; + // Separate num_calls from biased_offset so we can add them together. + uint64_t num_calls = zones_[i].NumCalls(); + + // Add any subsequent duplicates to num_calls and total_duration. + for (size_t j = i + 1; j < num_zones_;) { + if (!strcmp(name, string_origin + zones_[j].BiasedOffset())) { + num_calls += zones_[j].NumCalls(); + zones_[i].Add(0, zones_[j].Duration()); + // j was the last zone, so we are done. + if (j == num_zones_ - 1) break; + // Replace current zone with the last one, and check it next. + zones_[j] = zones_[--num_zones_]; + } else { // Name differed, try next Accumulator. + ++j; + } + } + + // Re-pack regardless of whether any duplicates were found. + zones_[i].Set(biased_offset, num_calls, zones_[i].Duration()); + } + } + + uint64_t analyze_elapsed_ = 0; + uint64_t self_overhead_ = 0; + uint64_t child_overhead_ = 0; + + size_t depth_ = 0; // Number of active zones. + size_t num_zones_ = 0; // Number of retired zones. + + alignas(HWY_ALIGNMENT) Node nodes_[kMaxDepth]; // Stack + alignas(HWY_ALIGNMENT) Accumulator zones_[kMaxZones]; // Self-organizing list +}; + +// Per-thread packet storage, dynamically allocated. +class ThreadSpecific { + static constexpr size_t kBufferCapacity = HWY_ALIGNMENT / sizeof(Packet); + + public: + // "name" is used to sanity-check offsets fit in kOffsetBits. + explicit ThreadSpecific(const char* name) + : max_packets_((PROFILER_THREAD_STORAGE << 20) / sizeof(Packet)), + packets_(AllocateAligned(max_packets_)), + num_packets_(0), + string_origin_(StringOrigin::Get()) { + // Even in optimized builds, verify that this zone's name offset fits + // within the allotted space. If not, UpdateOrAdd is likely to overrun + // zones_[]. Checking here on the cold path (only reached once per thread) + // is cheap, but it only covers one zone. + const size_t biased_offset = name - string_origin_; + HWY_ASSERT(biased_offset < (1ULL << Packet::kOffsetBits)); + } + + // Depends on Zone => defined below. + void ComputeOverhead(); + + void WriteEntry(const char* name, const uint64_t timestamp) { + HWY_DASSERT(name >= string_origin_); + const size_t biased_offset = static_cast(name - string_origin_); + Write(Packet::Make(biased_offset, timestamp)); + } + + void WriteExit(const uint64_t timestamp) { + const size_t biased_offset = Packet::kOffsetBias; + Write(Packet::Make(biased_offset, timestamp)); + } + + void AnalyzeRemainingPackets() { + // Ensures prior weakly-ordered streaming stores are globally visible. + FlushStream(); + + // Storage full => empty it. + if (num_packets_ + buffer_size_ > max_packets_) { + results_.AnalyzePackets(packets_.get(), num_packets_); + num_packets_ = 0; + } + CopyBytes(buffer_, packets_.get() + num_packets_, + buffer_size_ * sizeof(Packet)); + num_packets_ += buffer_size_; + + results_.AnalyzePackets(packets_.get(), num_packets_); + num_packets_ = 0; + } + + Results& GetResults() { return results_; } + + private: + // Overwrites "to" while attempting to bypass the cache (read-for-ownership). + // Both pointers must be aligned. 
+ static void StreamCacheLine(const uint64_t* HWY_RESTRICT from, + uint64_t* HWY_RESTRICT to) { +#if HWY_COMPILER_CLANG + for (size_t i = 0; i < HWY_ALIGNMENT / sizeof(uint64_t); ++i) { + __builtin_nontemporal_store(from[i], to + i); + } +#else + hwy::CopyBytes(from, to, HWY_ALIGNMENT); +#endif + } + + // Write packet to buffer/storage, emptying them as needed. + void Write(const Packet packet) { + // Buffer full => copy to storage. + if (buffer_size_ == kBufferCapacity) { + // Storage full => empty it. + if (num_packets_ + kBufferCapacity > max_packets_) { + results_.AnalyzePackets(packets_.get(), num_packets_); + num_packets_ = 0; + } + // This buffering halves observer overhead and decreases the overall + // runtime by about 3%. Casting is safe because the first member is u64. + StreamCacheLine( + reinterpret_cast(buffer_), + reinterpret_cast(packets_.get() + num_packets_)); + num_packets_ += kBufferCapacity; + buffer_size_ = 0; + } + buffer_[buffer_size_] = packet; + ++buffer_size_; + } + + // Write-combining buffer to avoid cache pollution. Must be the first + // non-static member to ensure cache-line alignment. + Packet buffer_[kBufferCapacity]; + size_t buffer_size_ = 0; + + const size_t max_packets_; + // Contiguous storage for zone enter/exit packets. + AlignedFreeUniquePtr packets_; + size_t num_packets_; + // Cached here because we already read this cache line on zone entry/exit. + const char* string_origin_; + Results results_; +}; + +class ThreadList { + public: + // Called from any thread. + ThreadSpecific* Add(const char* name) { + const size_t index = num_threads_.fetch_add(1, std::memory_order_relaxed); + HWY_DASSERT(index < kMaxThreads); + + ThreadSpecific* ts = MakeUniqueAligned(name).release(); + threads_[index].store(ts, std::memory_order_release); + return ts; + } + + // Single-threaded. + void PrintResults() { + const auto acq = std::memory_order_acquire; + const size_t num_threads = num_threads_.load(acq); + + ThreadSpecific* main = threads_[0].load(acq); + main->AnalyzeRemainingPackets(); + + for (size_t i = 1; i < num_threads; ++i) { + ThreadSpecific* ts = threads_[i].load(acq); + ts->AnalyzeRemainingPackets(); + main->GetResults().Assimilate(ts->GetResults()); + } + + if (num_threads != 0) { + main->GetResults().Print(); + } + } + + private: + // Owning pointers. + alignas(64) std::atomic threads_[kMaxThreads]; + std::atomic num_threads_{0}; +}; + +// RAII zone enter/exit recorder constructed by the ZONE macro; also +// responsible for initializing ThreadSpecific. +class Zone { + public: + // "name" must be a string literal (see StringOrigin::Get). + HWY_NOINLINE explicit Zone(const char* name) { + HWY_FENCE; + ThreadSpecific* HWY_RESTRICT thread_specific = StaticThreadSpecific(); + if (HWY_UNLIKELY(thread_specific == nullptr)) { + // Ensure the CPU supports our timer. + char cpu[100]; + if (!platform::HaveTimerStop(cpu)) { + HWY_ABORT("CPU %s is too old for PROFILER_ENABLED=1, exiting", cpu); + } + + thread_specific = StaticThreadSpecific() = Threads().Add(name); + // Must happen after setting StaticThreadSpecific, because ComputeOverhead + // also calls Zone(). + thread_specific->ComputeOverhead(); + } + + // (Capture timestamp ASAP, not inside WriteEntry.) 
+ HWY_FENCE; + const uint64_t timestamp = timer::Start(); + thread_specific->WriteEntry(name, timestamp); + } + + HWY_NOINLINE ~Zone() { + HWY_FENCE; + const uint64_t timestamp = timer::Stop(); + StaticThreadSpecific()->WriteExit(timestamp); + HWY_FENCE; + } + + // Call exactly once after all threads have exited all zones. + static void PrintResults() { Threads().PrintResults(); } + + private: + // Returns reference to the thread's ThreadSpecific pointer (initially null). + // Function-local static avoids needing a separate definition. + static ThreadSpecific*& StaticThreadSpecific() { + static thread_local ThreadSpecific* thread_specific; + return thread_specific; + } + + // Returns the singleton ThreadList. Non time-critical. + static ThreadList& Threads() { + static ThreadList threads_; + return threads_; + } +}; + +// Creates a zone starting from here until the end of the current scope. +// Timestamps will be recorded when entering and exiting the zone. +// "name" must be a string literal, which is ensured by merging with "". +#define PROFILER_ZONE(name) \ + HWY_FENCE; \ + const hwy::Zone zone("" name); \ + HWY_FENCE + +// Creates a zone for an entire function (when placed at its beginning). +// Shorter/more convenient than ZONE. +#define PROFILER_FUNC \ + HWY_FENCE; \ + const hwy::Zone zone(__func__); \ + HWY_FENCE + +#define PROFILER_PRINT_RESULTS hwy::Zone::PrintResults + +inline void ThreadSpecific::ComputeOverhead() { + // Delay after capturing timestamps before/after the actual zone runs. Even + // with frequency throttling disabled, this has a multimodal distribution, + // including 32, 34, 48, 52, 59, 62. + uint64_t self_overhead; + { + const size_t kNumSamples = 32; + uint32_t samples[kNumSamples]; + for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) { + const size_t kNumDurations = 1024; + uint32_t durations[kNumDurations]; + + for (size_t idx_duration = 0; idx_duration < kNumDurations; + ++idx_duration) { + { + PROFILER_ZONE("Dummy Zone (never shown)"); + } + const uint64_t duration = results_.ZoneDuration(buffer_); + buffer_size_ = 0; + durations[idx_duration] = static_cast(duration); + HWY_DASSERT(num_packets_ == 0); + } + robust_statistics::CountingSort(durations, kNumDurations); + samples[idx_sample] = robust_statistics::Mode(durations, kNumDurations); + } + // Median. + robust_statistics::CountingSort(samples, kNumSamples); + self_overhead = samples[kNumSamples / 2]; + if (PROFILER_PRINT_OVERHEAD) { + printf("Overhead: %.0f\n", static_cast(self_overhead)); + } + results_.SetSelfOverhead(self_overhead); + } + + // Delay before capturing start timestamp / after end timestamp. + const size_t kNumSamples = 32; + uint32_t samples[kNumSamples]; + for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) { + const size_t kNumDurations = 16; + uint32_t durations[kNumDurations]; + for (size_t idx_duration = 0; idx_duration < kNumDurations; + ++idx_duration) { + const size_t kReps = 10000; + // Analysis time should not be included => must fit within buffer. 
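// Editor's note, not part of the patch: kReps = 10000 zones emit 2 * 10000 =
// 20000 packets (about 160 KB), far below the default capacity of
// (200 << 20) / sizeof(Packet) = 26,214,400 packets, so Write() never has to
// analyze packets inside the timed region; the assertion below checks this.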
+ HWY_DASSERT(kReps * 2 < max_packets_); + std::atomic_thread_fence(std::memory_order_seq_cst); + const uint64_t t0 = timer::Start(); + for (size_t i = 0; i < kReps; ++i) { + PROFILER_ZONE("Dummy"); + } + FlushStream(); + const uint64_t t1 = timer::Stop(); + HWY_DASSERT(num_packets_ + buffer_size_ == kReps * 2); + buffer_size_ = 0; + num_packets_ = 0; + const uint64_t avg_duration = (t1 - t0 + kReps / 2) / kReps; + durations[idx_duration] = + static_cast(ClampedSubtract(avg_duration, self_overhead)); + } + robust_statistics::CountingSort(durations, kNumDurations); + samples[idx_sample] = robust_statistics::Mode(durations, kNumDurations); + } + robust_statistics::CountingSort(samples, kNumSamples); + const uint64_t child_overhead = samples[9 * kNumSamples / 10]; + if (PROFILER_PRINT_OVERHEAD) { + printf("Child overhead: %.0f\n", static_cast(child_overhead)); + } + results_.SetChildOverhead(child_overhead); +} + +#pragma pack(pop) + +} // namespace hwy + +#endif // PROFILER_ENABLED || HWY_IDE + +#if !PROFILER_ENABLED && !HWY_IDE +#define PROFILER_ZONE(name) +#define PROFILER_FUNC +#define PROFILER_PRINT_RESULTS() +#endif + +#endif // HIGHWAY_HWY_PROFILER_H_ diff --git a/third_party/aom/third_party/highway/hwy/robust_statistics.h b/third_party/aom/third_party/highway/hwy/robust_statistics.h new file mode 100644 index 000000000000..5391cf595154 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/robust_statistics.h @@ -0,0 +1,148 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_ROBUST_STATISTICS_H_ +#define HIGHWAY_HWY_ROBUST_STATISTICS_H_ + +#include // std::sort, std::find_if +#include +#include // std::pair +#include + +#include "third_party/highway/hwy/base.h" + +namespace hwy { +namespace robust_statistics { + +// Sorts integral values in ascending order (e.g. for Mode). About 3x faster +// than std::sort for input distributions with very few unique values. +template +void CountingSort(T* values, size_t num_values) { + // Unique values and their frequency (similar to flat_map). + using Unique = std::pair; + std::vector unique; + for (size_t i = 0; i < num_values; ++i) { + const T value = values[i]; + const auto pos = + std::find_if(unique.begin(), unique.end(), + [value](const Unique u) { return u.first == value; }); + if (pos == unique.end()) { + unique.push_back(std::make_pair(value, 1)); + } else { + ++pos->second; + } + } + + // Sort in ascending order of value (pair.first). + std::sort(unique.begin(), unique.end()); + + // Write that many copies of each unique value to the array. + T* HWY_RESTRICT p = values; + for (const auto& value_count : unique) { + std::fill(p, p + value_count.second, value_count.first); + p += value_count.second; + } + HWY_ASSERT(p == values + num_values); +} + +// @return i in [idx_begin, idx_begin + half_count) that minimizes +// sorted[i + half_count] - sorted[i]. 
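Editor's note: CountingSort above trades generality for speed when the input has very few distinct values: it builds a tiny value-to-count table and rewrites the array from it. A sketch of the same idea using std::map in place of the linear-scan flat map; names are illustrative and this is not the Highway implementation.

#include <cstddef>
#include <cstdio>
#include <map>

// Sorts `values` ascending; efficient when there are few unique values.
template <typename T>
void CountingSortFewUnique(T* values, std::size_t num) {
  std::map<T, std::size_t> counts;  // ordered by value
  for (std::size_t i = 0; i < num; ++i) ++counts[values[i]];
  std::size_t pos = 0;
  for (const auto& vc : counts) {
    for (std::size_t k = 0; k < vc.second; ++k) values[pos++] = vc.first;
  }
}

int main() {
  unsigned d[] = {52, 48, 52, 34, 48, 52, 34, 52};
  CountingSortFewUnique(d, sizeof(d) / sizeof(d[0]));
  for (unsigned v : d) printf("%u ", v);  // 34 34 48 48 52 52 52 52
  printf("\n");
}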
+template +size_t MinRange(const T* const HWY_RESTRICT sorted, const size_t idx_begin, + const size_t half_count) { + T min_range = std::numeric_limits::max(); + size_t min_idx = 0; + + for (size_t idx = idx_begin; idx < idx_begin + half_count; ++idx) { + HWY_ASSERT(sorted[idx] <= sorted[idx + half_count]); + const T range = sorted[idx + half_count] - sorted[idx]; + if (range < min_range) { + min_range = range; + min_idx = idx; + } + } + + return min_idx; +} + +// Returns an estimate of the mode by calling MinRange on successively +// halved intervals. "sorted" must be in ascending order. This is the +// Half Sample Mode estimator proposed by Bickel in "On a fast, robust +// estimator of the mode", with complexity O(N log N). The mode is less +// affected by outliers in highly-skewed distributions than the median. +// The averaging operation below assumes "T" is an unsigned integer type. +template +T ModeOfSorted(const T* const HWY_RESTRICT sorted, const size_t num_values) { + size_t idx_begin = 0; + size_t half_count = num_values / 2; + while (half_count > 1) { + idx_begin = MinRange(sorted, idx_begin, half_count); + half_count >>= 1; + } + + const T x = sorted[idx_begin + 0]; + if (half_count == 0) { + return x; + } + HWY_ASSERT(half_count == 1); + const T average = (x + sorted[idx_begin + 1] + 1) / 2; + return average; +} + +// Returns the mode. Side effect: sorts "values". +template +T Mode(T* values, const size_t num_values) { + CountingSort(values, num_values); + return ModeOfSorted(values, num_values); +} + +template +T Mode(T (&values)[N]) { + return Mode(&values[0], N); +} + +// Returns the median value. Side effect: sorts "values". +template +T Median(T* values, const size_t num_values) { + HWY_ASSERT(num_values != 0); + std::sort(values, values + num_values); + const size_t half = num_values / 2; + // Odd count: return middle + if (num_values % 2) { + return values[half]; + } + // Even count: return average of middle two. + return (values[half] + values[half - 1] + 1) / 2; +} + +// Returns a robust measure of variability. +template +T MedianAbsoluteDeviation(const T* values, const size_t num_values, + const T median) { + HWY_ASSERT(num_values != 0); + std::vector abs_deviations; + abs_deviations.reserve(num_values); + for (size_t i = 0; i < num_values; ++i) { + const int64_t abs = ScalarAbs(static_cast(values[i]) - + static_cast(median)); + abs_deviations.push_back(static_cast(abs)); + } + return Median(abs_deviations.data(), num_values); +} + +} // namespace robust_statistics +} // namespace hwy + +#endif // HIGHWAY_HWY_ROBUST_STATISTICS_H_ diff --git a/third_party/aom/third_party/highway/hwy/stats.h b/third_party/aom/third_party/highway/hwy/stats.h new file mode 100644 index 000000000000..b4b95719fbfc --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/stats.h @@ -0,0 +1,194 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
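Editor's note: the MinRange/ModeOfSorted pair above implements Bickel's half-sample mode: repeatedly keep the densest half of the sorted data (the half-width window with the smallest range) until one or two elements remain. A compact standalone sketch over a std::vector of unsigned samples; this is illustrative, not the vendored code, and it omits the empty-input check.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Half-sample mode of `v` (unsorted). Robust against outliers and skew.
uint32_t HalfSampleMode(std::vector<uint32_t> v) {
  std::sort(v.begin(), v.end());
  std::size_t begin = 0, count = v.size();
  while (count > 2) {
    const std::size_t half = count / 2;
    // Window of half+1 elements with the smallest spread.
    std::size_t best = begin;
    for (std::size_t i = begin; i + half < begin + count; ++i) {
      if (v[i + half] - v[i] < v[best + half] - v[best]) best = i;
    }
    begin = best;
    count = half + 1;  // recurse on the densest half
  }
  return count == 1 ? v[begin] : (v[begin] + v[begin + 1] + 1) / 2;
}

int main() {
  // Samples cluster near 50; the single large outlier barely moves the mode.
  std::vector<uint32_t> samples = {48, 49, 50, 50, 50, 51, 52, 900};
  printf("mode ~ %u\n", HalfSampleMode(samples));  // prints 50
}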
+ +#ifndef HIGHWAY_HWY_STATS_H_ +#define HIGHWAY_HWY_STATS_H_ + +#include +#include + +#include +#include + +#include "third_party/highway/hwy/base.h" // HWY_ASSERT + +namespace hwy { + +// Thread-compatible. +template +class Bins { + public: + Bins() { Reset(); } + + template + void Notify(T bin) { + HWY_ASSERT(T{0} <= bin && bin < static_cast(N)); + counts_[static_cast(bin)]++; + } + + void Assimilate(const Bins& other) { + for (size_t i = 0; i < N; ++i) { + counts_[i] += other.counts_[i]; + } + } + + void Print(const char* caption) const { + fprintf(stderr, "\n%s [%zu]\n", caption, N); + size_t last_nonzero = 0; + for (size_t i = N - 1; i < N; --i) { + if (counts_[i] != 0) { + last_nonzero = i; + break; + } + } + for (size_t i = 0; i <= last_nonzero; ++i) { + fprintf(stderr, " %zu\n", counts_[i]); + } + } + + void Reset() { + for (size_t i = 0; i < N; ++i) { + counts_[i] = 0; + } + } + + private: + size_t counts_[N]; +}; + +// Descriptive statistics of a variable (4 moments). Thread-compatible. +class Stats { + public: + Stats() { Reset(); } + + void Notify(const float x) { + ++n_; + + min_ = HWY_MIN(min_, x); + max_ = HWY_MAX(max_, x); + + // Logarithmic transform avoids/delays underflow and overflow. + sum_log_ += std::log(static_cast(x)); + + // Online moments. Reference: https://goo.gl/9ha694 + const double d = x - m1_; + const double d_div_n = d / static_cast(n_); + const double d2n1_div_n = d * (static_cast(n_) - 1) * d_div_n; + const int64_t n_poly = n_ * n_ - 3 * n_ + 3; + m1_ += d_div_n; + m4_ += d_div_n * (d_div_n * (d2n1_div_n * static_cast(n_poly) + 6.0 * m2_) - 4.0 * m3_); + m3_ += d_div_n * (d2n1_div_n * (static_cast(n_) - 2) - 3.0 * m2_); + m2_ += d2n1_div_n; + } + + void Assimilate(const Stats& other); + + int64_t Count() const { return n_; } + + float Min() const { return min_; } + float Max() const { return max_; } + + double GeometricMean() const { + return n_ == 0 ? 0.0 : std::exp(sum_log_ / static_cast(n_)); + } + + double Mean() const { return m1_; } + // Same as Mu2. Assumes n_ is large. + double SampleVariance() const { + return n_ == 0 ? 0.0 : m2_ / static_cast(n_); + } + // Unbiased estimator for population variance even for smaller n_. + double Variance() const { + if (n_ == 0) return 0.0; + if (n_ == 1) return m2_; + return m2_ / static_cast(n_ - 1); + } + double StandardDeviation() const { return std::sqrt(Variance()); } + // Near zero for normal distributions; if positive on a unimodal distribution, + // the right tail is fatter. Assumes n_ is large. + double SampleSkewness() const { + if (ScalarAbs(m2_) < 1E-7) return 0.0; + return m3_ * std::sqrt(static_cast(n_)) / std::pow(m2_, 1.5); + } + // Corrected for bias (same as Wikipedia and Minitab but not Excel). + double Skewness() const { + if (n_ == 0) return 0.0; + const double biased = SampleSkewness(); + const double r = (static_cast(n_) - 1.0) / static_cast(n_); + return biased * std::pow(r, 1.5); + } + // Near zero for normal distributions; smaller values indicate fewer/smaller + // outliers and larger indicates more/larger outliers. Assumes n_ is large. + double SampleKurtosis() const { + if (ScalarAbs(m2_) < 1E-7) return 0.0; + return m4_ * static_cast(n_) / (m2_ * m2_); + } + // Corrected for bias (same as Wikipedia and Minitab but not Excel). 
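Editor's note: Stats::Notify above maintains the first four moments in a single pass using online update formulas. Restricted to the mean and variance, the same single-pass pattern is Welford's algorithm; a small self-contained sketch follows (illustrative names, not the Highway class).

#include <cmath>
#include <cstdio>

// One-pass (Welford) accumulator for mean and unbiased variance.
class RunningStats {
 public:
  void Notify(double x) {
    ++n_;
    const double d = x - mean_;
    mean_ += d / n_;
    m2_ += d * (x - mean_);  // uses the *updated* mean
  }
  double Mean() const { return mean_; }
  double Variance() const { return n_ < 2 ? 0.0 : m2_ / (n_ - 1); }
  double StandardDeviation() const { return std::sqrt(Variance()); }

 private:
  long long n_ = 0;
  double mean_ = 0.0;
  double m2_ = 0.0;  // running sum of squared deviations
};

int main() {
  RunningStats s;
  for (double x : {2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0}) s.Notify(x);
  printf("mean=%g sd=%g\n", s.Mean(), s.StandardDeviation());  // 5 and ~2.14
}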
+ double Kurtosis() const { + if (n_ == 0) return 0.0; + const double biased = SampleKurtosis(); + const double r = (static_cast(n_) - 1.0) / static_cast(n_); + return biased * r * r; + } + + // Central moments, useful for "method of moments"-based parameter estimation + // of a mixture of two Gaussians. Assumes Count() != 0. + double Mu1() const { return m1_; } + double Mu2() const { return m2_ / static_cast(n_); } + double Mu3() const { return m3_ / static_cast(n_); } + double Mu4() const { return m4_ / static_cast(n_); } + + // Which statistics to EXCLUDE in ToString + enum { + kNoCount = 1, + kNoMeanSD = 2, + kNoMinMax = 4, + kNoSkewKurt = 8, + kNoGeomean = 16 + }; + std::string ToString(int exclude = 0) const; + + void Reset() { + n_ = 0; + + min_ = hwy::HighestValue(); + max_ = hwy::LowestValue(); + + sum_log_ = 0.0; + + m1_ = 0.0; + m2_ = 0.0; + m3_ = 0.0; + m4_ = 0.0; + } + + private: + int64_t n_; // signed for faster conversion + safe subtraction + + float min_; + float max_; + + double sum_log_; // for geomean + + // Moments + double m1_; + double m2_; + double m3_; + double m4_; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_STATS_H_ diff --git a/third_party/aom/third_party/highway/hwy/targets.h b/third_party/aom/third_party/highway/hwy/targets.h new file mode 100644 index 000000000000..6f34c890fe60 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/targets.h @@ -0,0 +1,365 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_TARGETS_H_ +#define HIGHWAY_HWY_TARGETS_H_ + +// Allows opting out of C++ standard library usage, which is not available in +// some Compiler Explorer environments. +#ifndef HWY_NO_LIBCXX +#include +#endif + +// For SIMD module implementations and their callers. Defines which targets to +// generate and call. + +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/detect_targets.h" +#include "third_party/highway/hwy/highway_export.h" + +#if !defined(HWY_NO_LIBCXX) +#include +#endif + +namespace hwy { + +// Returns bitfield of enabled targets that are supported on this CPU; there is +// always at least one such target, hence the return value is never 0. The +// targets returned may change after calling DisableTargets. This function is +// always defined, but the HWY_SUPPORTED_TARGETS wrapper may allow eliding +// calls to it if there is only a single target enabled. +HWY_DLLEXPORT int64_t SupportedTargets(); + +// Evaluates to a function call, or literal if there is a single target. +#if (HWY_TARGETS & (HWY_TARGETS - 1)) == 0 +#define HWY_SUPPORTED_TARGETS HWY_TARGETS +#else +#define HWY_SUPPORTED_TARGETS hwy::SupportedTargets() +#endif + +// Subsequent SupportedTargets will not return targets whose bit(s) are set in +// `disabled_targets`. Exception: if SupportedTargets would return 0, it will +// instead return HWY_STATIC_TARGET (there must always be one target to call). 
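Editor's note: the HWY_SUPPORTED_TARGETS definition above uses the classic bit trick that x & (x - 1) clears the lowest set bit, so the expression is zero exactly when at most one target bit is set and the SupportedTargets() call can be elided at compile time. A one-line illustration of the test:

#include <cstdint>
#include <cstdio>

// True iff exactly one bit of x is set (i.e. x is a power of two).
constexpr bool ExactlyOneBit(uint64_t x) { return x != 0 && (x & (x - 1)) == 0; }

int main() {
  printf("%d %d %d\n", ExactlyOneBit(0), ExactlyOneBit(8), ExactlyOneBit(12));  // 0 1 0
}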
+// +// This function is useful for disabling targets known to be buggy, or if the +// best available target is undesirable (perhaps due to throttling or memory +// bandwidth limitations). Use SetSupportedTargetsForTest instead of this +// function for iteratively enabling specific targets for testing. +HWY_DLLEXPORT void DisableTargets(int64_t disabled_targets); + +// Subsequent SupportedTargets will return the given set of targets, except +// those disabled via DisableTargets. Call with a mask of 0 to disable the mock +// and return to the normal SupportedTargets behavior. Used to run tests for +// all targets. +HWY_DLLEXPORT void SetSupportedTargetsForTest(int64_t targets); + +#ifndef HWY_NO_LIBCXX + +// Return the list of targets in HWY_TARGETS supported by the CPU as a list of +// individual HWY_* target macros such as HWY_SCALAR or HWY_NEON. This list +// is affected by the current SetSupportedTargetsForTest() mock if any. +HWY_INLINE std::vector SupportedAndGeneratedTargets() { + std::vector ret; + for (int64_t targets = SupportedTargets() & HWY_TARGETS; targets != 0; + targets = targets & (targets - 1)) { + int64_t current_target = targets & ~(targets - 1); + ret.push_back(current_target); + } + return ret; +} + +#endif // HWY_NO_LIBCXX + +static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) { + switch (target) { +#if HWY_ARCH_X86 + case HWY_SSE2: + return "SSE2"; + case HWY_SSSE3: + return "SSSE3"; + case HWY_SSE4: + return "SSE4"; + case HWY_AVX2: + return "AVX2"; + case HWY_AVX3: + return "AVX3"; + case HWY_AVX3_DL: + return "AVX3_DL"; + case HWY_AVX3_ZEN4: + return "AVX3_ZEN4"; + case HWY_AVX10_2: + return "AVX10_2"; + case HWY_AVX3_SPR: + return "AVX3_SPR"; + case HWY_AVX10_2_512: + return "AVX10_2_512"; +#endif + +#if HWY_ARCH_ARM + case HWY_SVE2_128: + return "SVE2_128"; + case HWY_SVE_256: + return "SVE_256"; + case HWY_SVE2: + return "SVE2"; + case HWY_SVE: + return "SVE"; + case HWY_NEON_BF16: + return "NEON_BF16"; + case HWY_NEON: + return "NEON"; + case HWY_NEON_WITHOUT_AES: + return "NEON_WITHOUT_AES"; +#endif + +#if HWY_ARCH_PPC + case HWY_PPC8: + return "PPC8"; + case HWY_PPC9: + return "PPC9"; + case HWY_PPC10: + return "PPC10"; +#endif + +#if HWY_ARCH_S390X + case HWY_Z14: + return "Z14"; + case HWY_Z15: + return "Z15"; +#endif + +#if HWY_ARCH_WASM + case HWY_WASM: + return "WASM"; + case HWY_WASM_EMU256: + return "WASM_EMU256"; +#endif + +#if HWY_ARCH_RISCV + case HWY_RVV: + return "RVV"; +#endif + +#if HWY_ARCH_LOONGARCH + case HWY_LSX: + return "LSX"; + case HWY_LASX: + return "LASX"; +#endif + + case HWY_EMU128: + return "EMU128"; + case HWY_SCALAR: + return "SCALAR"; + + default: + return "Unknown"; // must satisfy gtest IsValidParamName() + } +} + +// The maximum number of dynamic targets on any architecture is defined by +// HWY_MAX_DYNAMIC_TARGETS and depends on the arch. + +// For the ChosenTarget mask and index we use a different bit arrangement than +// in the HWY_TARGETS mask. Only the targets involved in the current +// architecture are used in this mask, and therefore only the least significant +// (HWY_MAX_DYNAMIC_TARGETS + 2) bits of the int64_t mask are used. The least +// significant bit is set when the mask is not initialized, the next +// HWY_MAX_DYNAMIC_TARGETS more significant bits are a range of bits from the +// HWY_TARGETS or SupportedTargets() mask for the given architecture shifted to +// that position and the next more significant bit is used for HWY_SCALAR (if +// HWY_COMPILE_ONLY_SCALAR is defined) or HWY_EMU128. 
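Editor's note: SupportedAndGeneratedTargets above walks the bitfield by repeatedly isolating the lowest set bit (targets & ~(targets - 1)) and then clearing it (targets & (targets - 1)). A standalone sketch of that loop with invented names:

#include <cstdint>
#include <cstdio>

// Visits every set bit of `mask`, lowest first, one iteration per set bit.
void ForEachSetBit(int64_t mask, void (*visit)(int64_t bit)) {
  for (int64_t m = mask; m != 0; m &= (m - 1)) {
    visit(m & ~(m - 1));  // lowest set bit of m
  }
}

int main() {
  // 0x46 has bits 1, 2 and 6 set; prints 0x2, 0x4, 0x40.
  ForEachSetBit(0x46, [](int64_t bit) {
    printf("0x%llx\n", (unsigned long long)bit);
  });
}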
Because of this we need to +// define equivalent values for HWY_TARGETS in this representation. +// This mask representation allows to use ctz() on this mask and obtain a small +// number that's used as an index of the table for dynamic dispatch. In this +// way the first entry is used when the mask is uninitialized, the following +// HWY_MAX_DYNAMIC_TARGETS are for dynamic dispatch and the last one is for +// scalar. + +// The HWY_SCALAR/HWY_EMU128 bit in the ChosenTarget mask format. +#define HWY_CHOSEN_TARGET_MASK_SCALAR (1LL << (HWY_MAX_DYNAMIC_TARGETS + 1)) + +// Converts from a HWY_TARGETS mask to a ChosenTarget mask format for the +// current architecture. +#define HWY_CHOSEN_TARGET_SHIFT(X) \ + ((((X) >> (HWY_HIGHEST_TARGET_BIT + 1 - HWY_MAX_DYNAMIC_TARGETS)) & \ + ((1LL << HWY_MAX_DYNAMIC_TARGETS) - 1)) \ + << 1) + +// The HWY_TARGETS mask in the ChosenTarget mask format. +#define HWY_CHOSEN_TARGET_MASK_TARGETS \ + (HWY_CHOSEN_TARGET_SHIFT(HWY_TARGETS) | HWY_CHOSEN_TARGET_MASK_SCALAR | 1LL) + +#if HWY_ARCH_X86 +// Maximum number of dynamic targets, changing this value is an ABI incompatible +// change +#define HWY_MAX_DYNAMIC_TARGETS 15 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_X86 +// These must match the order in which the HWY_TARGETS are defined +// starting by the least significant (HWY_HIGHEST_TARGET_BIT + 1 - +// HWY_MAX_DYNAMIC_TARGETS) bit. This list must contain exactly +// HWY_MAX_DYNAMIC_TARGETS elements and does not include SCALAR. The first entry +// corresponds to the best target. Don't include a "," at the end of the list. +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_AVX10_2_512(func_name), /* AVX10_2_512 */ \ + HWY_CHOOSE_AVX3_SPR(func_name), /* AVX3_SPR */ \ + HWY_CHOOSE_AVX10_2(func_name), /* reserved */ \ + HWY_CHOOSE_AVX3_ZEN4(func_name), /* AVX3_ZEN4 */ \ + HWY_CHOOSE_AVX3_DL(func_name), /* AVX3_DL */ \ + HWY_CHOOSE_AVX3(func_name), /* AVX3 */ \ + HWY_CHOOSE_AVX2(func_name), /* AVX2 */ \ + nullptr, /* AVX */ \ + HWY_CHOOSE_SSE4(func_name), /* SSE4 */ \ + HWY_CHOOSE_SSSE3(func_name), /* SSSE3 */ \ + nullptr, /* reserved - SSE3? */ \ + HWY_CHOOSE_SSE2(func_name) /* SSE2 */ + +#elif HWY_ARCH_ARM +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 15 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_ARM +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_SVE2_128(func_name), /* SVE2 128-bit */ \ + HWY_CHOOSE_SVE_256(func_name), /* SVE 256-bit */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_SVE2(func_name), /* SVE2 */ \ + HWY_CHOOSE_SVE(func_name), /* SVE */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_NEON_BF16(func_name), /* NEON + f16/dot/bf16 */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_NEON(func_name), /* NEON */ \ + HWY_CHOOSE_NEON_WITHOUT_AES(func_name) /* NEON without AES */ + +#elif HWY_ARCH_RISCV +// See HWY_ARCH_X86 above for details. 
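Editor's note: the ChosenTarget encoding described above packs the architecture's target bits into the low bits of a mask, keeps bit 0 as the "uninitialized" marker, and uses count-trailing-zeros of the mask as the index into the per-function pointer table. Below is a toy model of that lookup with a hypothetical 4-entry table; it is not the real HWY_EXPORT machinery and uses a portable loop instead of a ctz intrinsic.

#include <cstdint>
#include <cstdio>

// Toy dispatch table: index 0 = "uninitialized" stub, 1..2 = specializations,
// 3 = scalar fallback. The best available entry has the lowest nonzero index,
// so count-trailing-zeros of the mask picks it.
static const char* Impl(uint64_t mask) {
  static const char* table[4] = {"init-stub", "avx2", "sse4", "scalar"};
  // Portable count-trailing-zeros for a nonzero value.
  int index = 0;
  while (((mask >> index) & 1) == 0) ++index;
  return table[index];
}

int main() {
  printf("%s\n", Impl(0b0001));  // only the "uninitialized" bit -> stub
  printf("%s\n", Impl(0b1100));  // sse4 and scalar available -> sse4
  printf("%s\n", Impl(0b1010));  // avx2 and scalar available -> avx2
}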
+#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_RVV +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_RVV(func_name), /* RVV */ \ + nullptr /* reserved */ + +#elif HWY_ARCH_PPC || HWY_ARCH_S390X +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_PPC +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_PPC10(func_name), /* PPC10 */ \ + HWY_CHOOSE_PPC9(func_name), /* PPC9 */ \ + HWY_CHOOSE_PPC8(func_name), /* PPC8 */ \ + HWY_CHOOSE_Z15(func_name), /* Z15 */ \ + HWY_CHOOSE_Z14(func_name) /* Z14 */ + +#elif HWY_ARCH_WASM +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_WASM +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_WASM_EMU256(func_name), /* WASM_EMU256 */ \ + HWY_CHOOSE_WASM(func_name), /* WASM */ \ + nullptr /* reserved */ + +#elif HWY_ARCH_LOONGARCH +#define HWY_MAX_DYNAMIC_TARGETS 3 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_LOONGARCH +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + HWY_CHOOSE_LASX(func_name), /* LASX */ \ + HWY_CHOOSE_LSX(func_name) /* LSX */ + +#else +// Unknown architecture, will use HWY_SCALAR without dynamic dispatch, though +// still creating single-entry tables in HWY_EXPORT to ensure portability. +#define HWY_MAX_DYNAMIC_TARGETS 1 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_SCALAR +#endif + +// Bitfield of supported and enabled targets. The format differs from that of +// HWY_TARGETS; the lowest bit governs the first function pointer (which is +// special in that it calls FunctionCache, then Update, then dispatches to the +// actual implementation) in the tables created by HWY_EXPORT. Monostate (see +// GetChosenTarget), thread-safe except on RVV. +struct ChosenTarget { + public: + // Reset bits according to `targets` (typically the return value of + // SupportedTargets()). Postcondition: IsInitialized() == true. + void Update(int64_t targets) { + // These are `targets` shifted downwards, see above. Also include SCALAR + // (corresponds to the last entry in the function table) as fallback. + StoreMask(HWY_CHOSEN_TARGET_SHIFT(targets) | HWY_CHOSEN_TARGET_MASK_SCALAR); + } + + // Reset to the uninitialized state, so that FunctionCache will call Update + // during the next HWY_DYNAMIC_DISPATCH, and IsInitialized returns false. + void DeInit() { StoreMask(1); } + + // Whether Update was called. This indicates whether any HWY_DYNAMIC_DISPATCH + // function was called, which we check in tests. + bool IsInitialized() const { return LoadMask() != 1; } + + // Return the index in the dynamic dispatch table to be used by the current + // CPU. Note that this method must be in the header file so it uses the value + // of HWY_CHOSEN_TARGET_MASK_TARGETS defined in the translation unit that + // calls it, which may be different from others. This means we only enable + // those targets that were actually compiled in this module. 
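Editor's note: the ChosenTarget comment above describes a lazy-initialization trick: the mask starts as 1, so the first dynamic dispatch lands on a special table entry that detects the CPU, calls Update, and then forwards to the real implementation; subsequent calls bypass the stub. A minimal sketch of that pattern with invented names (Add/AddInit/AddFast), not the Highway macros:

#include <atomic>
#include <cstdio>

using AddFn = int (*)(int, int);

static int AddFast(int a, int b) { return a + b; }  // "best target"
static int AddInit(int a, int b);                   // forward declaration

static std::atomic<int> g_index{0};                 // 0 = uninitialized stub
static const AddFn g_table[2] = {AddInit, AddFast};

// Stub stored at index 0: runs once, selects the real entry, re-dispatches.
static int AddInit(int a, int b) {
  puts("detecting CPU (once)...");
  g_index.store(1, std::memory_order_release);      // pretend detection result
  return g_table[g_index.load(std::memory_order_acquire)](a, b);
}

static int Add(int a, int b) {                      // public entry point
  return g_table[g_index.load(std::memory_order_acquire)](a, b);
}

int main() {
  printf("%d\n", Add(2, 3));  // prints the detection message, then 5
  printf("%d\n", Add(4, 5));  // goes straight to AddFast
}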
+ size_t HWY_INLINE GetIndex() const { + return hwy::Num0BitsBelowLS1Bit_Nonzero64( + static_cast(LoadMask() & HWY_CHOSEN_TARGET_MASK_TARGETS)); + } + + private: +#if defined(HWY_NO_LIBCXX) + int64_t LoadMask() const { return mask_; } + void StoreMask(int64_t mask) { mask_ = mask; } + + int64_t mask_{1}; // Initialized to 1 so GetIndex() returns 0. +#else + int64_t LoadMask() const { return mask_.load(); } + void StoreMask(int64_t mask) { mask_.store(mask); } + + std::atomic mask_{1}; // Initialized to 1 so GetIndex() returns 0. +#endif // HWY_ARCH_RISCV +}; + +// For internal use (e.g. by FunctionCache and DisableTargets). +HWY_DLLEXPORT ChosenTarget& GetChosenTarget(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_TARGETS_H_ diff --git a/third_party/aom/third_party/highway/hwy/timer-inl.h b/third_party/aom/third_party/highway/hwy/timer-inl.h new file mode 100644 index 000000000000..acc5c65e4fa6 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/timer-inl.h @@ -0,0 +1,48 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// DEPRECATED, use timer.h instead. + +#include "third_party/highway/hwy/timer.h" + +#if defined(HIGHWAY_HWY_TIMER_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_TIMER_INL_H_ +#undef HIGHWAY_HWY_TIMER_INL_H_ +#else +#define HIGHWAY_HWY_TIMER_INL_H_ +#endif + +#include "third_party/highway/hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace timer { + +// Deprecated aliases so that old code still compiles. Prefer to use +// `hwy::timer::*` from timer.h because that does not require highway.h. +using Ticks = hwy::timer::Ticks; + +inline Ticks Start() { return hwy::timer::Start(); } +inline Ticks Stop() { return hwy::timer::Stop(); } + +} // namespace timer + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // per-target include guard diff --git a/third_party/aom/third_party/highway/hwy/timer.h b/third_party/aom/third_party/highway/hwy/timer.h new file mode 100644 index 000000000000..6d819c55bbff --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/timer.h @@ -0,0 +1,237 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_TIMER_H_ +#define HIGHWAY_HWY_TIMER_H_ + +// Platform-specific timer functions. 
Provides Now() and functions for +// interpreting and converting Ticks. + +#include +#include // clock_gettime + +#include "third_party/highway/hwy/base.h" + +#if defined(_WIN32) || defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif // WIN32_LEAN_AND_MEAN +#include +#endif + +#if defined(__APPLE__) +#include +#include +#endif + +#if defined(__HAIKU__) +#include +#endif + +#if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__) +#include // NOLINT __ppc_get_timebase_freq +#endif + +#if HWY_ARCH_X86 && HWY_COMPILER_MSVC +#include +#endif + +namespace hwy { +namespace platform { + +// Returns current timestamp [in seconds] relative to an unspecified origin. +// Features: monotonic (no negative elapsed time), steady (unaffected by system +// time changes), high-resolution (on the order of microseconds). +// Uses InvariantTicksPerSecond and the baseline version of timer::Start(). +HWY_DLLEXPORT double Now(); + +// Functions related to `Ticks` below. + +// Returns whether it is safe to call timer::Stop without executing an illegal +// instruction; if false, fills cpu100 (a pointer to a 100 character buffer) +// via GetCpuString(). +HWY_DLLEXPORT bool HaveTimerStop(char* cpu100); + +// Returns tick rate, useful for converting timer::Ticks to seconds. Invariant +// means the tick counter frequency is independent of CPU throttling or sleep. +// This call may be expensive, callers should cache the result. +HWY_DLLEXPORT double InvariantTicksPerSecond(); + +// Returns ticks elapsed in back to back timer calls, i.e. a function of the +// timer resolution (minimum measurable difference) and overhead. +// This call is expensive, callers should cache the result. +HWY_DLLEXPORT uint64_t TimerResolution(); + +// Returns false if no detailed description is available, otherwise fills +// `cpu100` with up to 100 characters (including \0) identifying the CPU model. +HWY_DLLEXPORT bool GetCpuString(char* cpu100); + +} // namespace platform + +struct Timestamp { + Timestamp() { t = platform::Now(); } + double t; +}; + +static inline double SecondsSince(const Timestamp& t0) { + const Timestamp t1; + return t1.t - t0.t; +} + +// Low-level Start/Stop functions, previously in timer-inl.h. + +namespace timer { + +// Ticks := platform-specific timer values (CPU cycles on x86). Must be +// unsigned to guarantee wraparound on overflow. +using Ticks = uint64_t; + +// Start/Stop return absolute timestamps and must be placed immediately before +// and after the region to measure. We provide separate Start/Stop functions +// because they use different fences. +// +// Background: RDTSC is not 'serializing'; earlier instructions may complete +// after it, and/or later instructions may complete before it. 'Fences' ensure +// regions' elapsed times are independent of such reordering. The only +// documented unprivileged serializing instruction is CPUID, which acts as a +// full fence (no reordering across it in either direction). Unfortunately +// the latency of CPUID varies wildly (perhaps made worse by not initializing +// its EAX input). Because it cannot reliably be deducted from the region's +// elapsed time, it must not be included in the region to measure (i.e. +// between the two RDTSC). +// +// The newer RDTSCP is sometimes described as serializing, but it actually +// only serves as a half-fence with release semantics. 
Although all +// instructions in the region will complete before the final timestamp is +// captured, subsequent instructions may leak into the region and increase the +// elapsed time. Inserting another fence after the final `RDTSCP` would prevent +// such reordering without affecting the measured region. +// +// Fortunately, such a fence exists. The LFENCE instruction is only documented +// to delay later loads until earlier loads are visible. However, Intel's +// reference manual says it acts as a full fence (waiting until all earlier +// instructions have completed, and delaying later instructions until it +// completes). AMD assigns the same behavior to MFENCE. +// +// We need a fence before the initial RDTSC to prevent earlier instructions +// from leaking into the region, and arguably another after RDTSC to avoid +// region instructions from completing before the timestamp is recorded. +// When surrounded by fences, the additional `RDTSCP` half-fence provides no +// benefit, so the initial timestamp can be recorded via RDTSC, which has +// lower overhead than `RDTSCP` because it does not read TSC_AUX. In summary, +// we define Start = LFENCE/RDTSC/LFENCE; Stop = RDTSCP/LFENCE. +// +// Using Start+Start leads to higher variance and overhead than Stop+Stop. +// However, Stop+Stop includes an LFENCE in the region measurements, which +// adds a delay dependent on earlier loads. The combination of Start+Stop +// is faster than Start+Start and more consistent than Stop+Stop because +// the first LFENCE already delayed subsequent loads before the measured +// region. This combination seems not to have been considered in prior work: +// http://akaros.cs.berkeley.edu/lxr/akaros/kern/arch/x86/rdtsc_test.c +// +// Note: performance counters can measure 'exact' instructions-retired or +// (unhalted) cycle counts. The RDPMC instruction is not serializing and also +// requires fences. Unfortunately, it is not accessible on all OSes and we +// prefer to avoid kernel-mode drivers. Performance counters are also affected +// by several under/over-count errata, so we use the TSC instead. + +// Returns a 64-bit timestamp in unit of 'ticks'; to convert to seconds, +// divide by InvariantTicksPerSecond. +static HWY_INLINE Ticks Start() { + Ticks t; +#if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC + // pmccntr_el0 is privileged but cntvct_el0 is accessible in Linux and QEMU. + asm volatile("mrs %0, cntvct_el0" : "=r"(t)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); + t = __rdtsc(); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + asm volatile( + "lfence\n\t" + "rdtsc\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rdx", "memory", "cc"); +#elif HWY_ARCH_RISCV + asm volatile("fence; rdtime %0" : "=r"(t)); +#elif defined(_WIN32) || defined(_WIN64) + LARGE_INTEGER counter; + (void)QueryPerformanceCounter(&counter); + t = counter.QuadPart; +#elif defined(__APPLE__) + t = mach_absolute_time(); +#elif defined(__HAIKU__) + t = system_time_nsecs(); // since boot +#else // POSIX + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + t = static_cast(ts.tv_sec * 1000000000LL + ts.tv_nsec); +#endif + return t; +} + +// WARNING: on x86, caller must check `HaveTimerStop()` before using this! 
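Editor's note: as a usage sketch of the API above, callers are expected to verify HaveTimerStop before pairing Start with Stop, and to convert the tick difference with InvariantTicksPerSecond. This assumes the exported platform functions (Now, HaveTimerStop, InvariantTicksPerSecond) are actually linked in; they are declared HWY_DLLEXPORT and defined in a .cc file that a header-only import may not build. The workload is a placeholder loop.

#include <cstdint>
#include <cstdio>

#include "third_party/highway/hwy/timer.h"

// Times a small workload in seconds using the tick-based Start/Stop pair,
// falling back to the wall-clock Now() if Stop is unsafe on this CPU.
double MeasureSeconds() {
  char cpu[100];
  if (!hwy::platform::HaveTimerStop(cpu)) {
    fprintf(stderr, "CPU %s lacks a usable stop timer; using Now().\n", cpu);
    const double t0 = hwy::platform::Now();
    volatile double sink = 0.0;
    for (int i = 0; i < 1000000; ++i) sink = sink + i;
    return hwy::platform::Now() - t0;
  }
  const uint64_t t0 = hwy::timer::Start();
  volatile double sink = 0.0;
  for (int i = 0; i < 1000000; ++i) sink = sink + i;
  const uint64_t t1 = hwy::timer::Stop();
  return static_cast<double>(t1 - t0) / hwy::platform::InvariantTicksPerSecond();
}

int main() { printf("%.6f s\n", MeasureSeconds()); }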
+static HWY_INLINE Ticks Stop() { + uint64_t t; +#if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC + // pmccntr_el0 is privileged but cntvct_el0 is accessible in Linux and QEMU. + asm volatile("mrs %0, cntvct_el0" : "=r"(t)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + unsigned aux; + t = __rdtscp(&aux); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + // Use inline asm because __rdtscp generates code to store TSC_AUX (ecx). + asm volatile( + "rdtscp\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rcx = TSC_AUX. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rcx", "rdx", "memory", "cc"); +#else + t = Start(); +#endif + return t; +} + +} // namespace timer + +} // namespace hwy + +#endif // HIGHWAY_HWY_TIMER_H_ diff --git a/third_party/aom/third_party/highway/hwy/x86_cpuid.h b/third_party/aom/third_party/highway/hwy/x86_cpuid.h new file mode 100644 index 000000000000..2fcdb3c65451 --- /dev/null +++ b/third_party/aom/third_party/highway/hwy/x86_cpuid.h @@ -0,0 +1,81 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_X86_CPUID_H_ +#define HIGHWAY_HWY_X86_CPUID_H_ + +// Wrapper for x86 CPUID intrinsics. Empty on other platforms. + +#include + +#include "third_party/highway/hwy/base.h" + +#if HWY_ARCH_X86 + +#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL +#include +#else +#include +#endif + +namespace hwy { +namespace x86 { + +// Calls CPUID instruction with eax=level and ecx=count and returns the result +// in abcd array where abcd = {eax, ebx, ecx, edx} (hence the name abcd). 
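Editor's note: as a usage sketch of the Cpuid wrapper documented here (its implementation follows below): query leaf 1 and test a feature bit, e.g. bit 28 of ECX, which is commonly documented as the AVX flag. x86-only; the include path assumes the vendored location used throughout this patch.

#include <cstdint>
#include <cstdio>

#include "third_party/highway/hwy/x86_cpuid.h"

int main() {
#if HWY_ARCH_X86
  uint32_t abcd[4];  // {eax, ebx, ecx, edx}
  hwy::x86::Cpuid(1, 0, abcd);
  // CPUID.1:ECX bit 28 is the AVX feature flag.
  printf("max level %u, AVX %s\n", hwy::x86::MaxLevel(),
         hwy::x86::IsBitSet(abcd[2], 28) ? "yes" : "no");
#else
  puts("not an x86 build");
#endif
}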
+static inline void Cpuid(const uint32_t level, const uint32_t count, + uint32_t* HWY_RESTRICT abcd) { +#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL + int regs[4]; + __cpuidex(regs, level, count); + for (int i = 0; i < 4; ++i) { + abcd[i] = regs[i]; + } +#else // HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL + uint32_t a; + uint32_t b; + uint32_t c; + uint32_t d; + __cpuid_count(level, count, a, b, c, d); + abcd[0] = a; + abcd[1] = b; + abcd[2] = c; + abcd[3] = d; +#endif // HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL +} + +static inline bool IsBitSet(const uint32_t reg, const int index) { + return (reg & (1U << index)) != 0; +} + +static inline uint32_t MaxLevel() { + uint32_t abcd[4]; + Cpuid(0, 0, abcd); + return abcd[0]; +} + +static inline bool IsAMD() { + uint32_t abcd[4]; + Cpuid(0, 0, abcd); + const uint32_t max_level = abcd[0]; + return max_level >= 1 && abcd[1] == 0x68747541 && abcd[2] == 0x444d4163 && + abcd[3] == 0x69746e65; +} + +} // namespace x86 +} // namespace hwy + +#endif // HWY_ARCH_X86 +#endif // HIGHWAY_HWY_X86_CPUID_H_ diff --git a/third_party/aom/tools/txfm_analyzer/txfm_graph.cc b/third_party/aom/tools/txfm_analyzer/txfm_graph.cc index f46cc8faa8d1..35506e798b99 100644 --- a/third_party/aom/tools/txfm_analyzer/txfm_graph.cc +++ b/third_party/aom/tools/txfm_analyzer/txfm_graph.cc @@ -400,7 +400,6 @@ void gen_type3_graph(Node *node, int stage_num, int node_num, int stage_idx, for (int nj = 0; nj < N / 2; nj += N_over_i) { int j = nj / (N_over_i); int kj = bitwise_reverse(i / 4 + j, max_bit); - // printf("kj = %d\n", kj); // I_N/2i --- 0 int offset = nj; @@ -555,14 +554,12 @@ void gen_adst_B_graph(Node *node, int stage_num, int node_num, int stage_idx, int nIn = nOut + size / 2; connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, 1, nIn, 1); - // printf("nOut: %d nIn: %d\n", nOut, nIn); } for (int ni = size / 2; ni < size; ni++) { int nOut = node_idx + ni; int nIn = nOut - size / 2; connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, -1, nIn, 1); - // printf("ndctOut: %d nIn: %d\n", nOut, nIn); } } @@ -774,8 +771,6 @@ void connect_layer_2d(Node *node, int stage_num, int node_num, int stage_idx, int nIn = node_idx + first * dct_node_num + second; int nOut = node_idx + second * dct_node_num + first; - // printf("sIn: %d nIn: %d sOut: %d nOut: %d\n", sIn, nIn, sOut, nOut); - connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0); } } @@ -791,8 +786,6 @@ void connect_layer_2d_new(Node *node, int stage_num, int node_num, int nIn = node_idx + i * dct_node_num0 + j; int nOut = node_idx + j * dct_node_num1 + i; - // printf("sIn: %d nIn: %d sOut: %d nOut: %d\n", sIn, nIn, sOut, nOut); - connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0); } } diff --git a/third_party/aom/tools/update_highway.sh b/third_party/aom/tools/update_highway.sh new file mode 100755 index 000000000000..9fb0da8118d8 --- /dev/null +++ b/third_party/aom/tools/update_highway.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +# Update third_party/highway to the latest version. + +# Usage: (under libaom root directory) +# ./tools/update_highway.sh + +set -e + +highway_dir="$(pwd)/third_party/highway" +repo_url="https://github.com/google/highway" + +git clone --depth 1 "$repo_url" "$highway_dir" + +cd "${highway_dir}" + +commit_hash=$(git rev-parse HEAD) + +# Remove everything except ./hwy +find . 
-mindepth 1 \ + -not -path "./hwy" \ + -not -path "./hwy/*" \ + -not -name "LICENSE-BSD3" \ + -delete + +# Remove tests/ directory +rm -rf hwy/tests/ + +# Remove markdown files +find . -name "*.md" -delete + +# Remove cc files since we build highway header-only +find . -name "*.cc" -delete + +# Update the include path +find ./hwy \( -name "*.c" -o -name "*.cc" -o -name "*.h" \) -print0 | \ + xargs -0 sed -i 's/#include "hwy\//#include "third_party\/highway\/hwy\//g' + +find ./hwy \( -name "*.c" -o -name "*.cc" -o -name "*.h" \) -print0 | \ + xargs -0 sed -i \ + 's/HWY_TARGET_INCLUDE "hwy\//HWY_TARGET_INCLUDE "third_party\/highway\/hwy\//g' + +cat > "${highway_dir}/README.libaom" <