Bug 1963116 - Update aom to 719f60edc51b6141a2434bf1b5110c2fb075b246 r=padenot

Differential Revision: https://phabricator.services.mozilla.com/D246962
Author: Updatebot
Date: 2025-05-12 09:56:22 +00:00
Committed by: padenot@mozilla.com
Parent: 4813ad4a4e
Commit: d74a40de37
127 changed files with 111981 additions and 194 deletions

View File

@@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0
 CONFIG_GCC equ 1
 CONFIG_GCOV equ 0
 CONFIG_GPROF equ 0
+CONFIG_HIGHWAY equ 0
 CONFIG_INSPECTION equ 0
 CONFIG_INTERNAL_STATS equ 0
 CONFIG_INTER_STATS_ONLY equ 0

View File

@@ -42,6 +42,7 @@
 #define CONFIG_GCC 1
 #define CONFIG_GCOV 0
 #define CONFIG_GPROF 0
+#define CONFIG_HIGHWAY 0
 #define CONFIG_INSPECTION 0
 #define CONFIG_INTERNAL_STATS 0
 #define CONFIG_INTER_STATS_ONLY 0

View File

@@ -40,6 +40,7 @@
 .equ CONFIG_GCC, 1
 .equ CONFIG_GCOV, 0
 .equ CONFIG_GPROF, 0
+.equ CONFIG_HIGHWAY, 0
 .equ CONFIG_INSPECTION, 0
 .equ CONFIG_INTERNAL_STATS, 0
 .equ CONFIG_INTER_STATS_ONLY, 0

View File

@@ -42,6 +42,7 @@
 #define CONFIG_GCC 1
 #define CONFIG_GCOV 0
 #define CONFIG_GPROF 0
+#define CONFIG_HIGHWAY 0
 #define CONFIG_INSPECTION 0
 #define CONFIG_INTERNAL_STATS 0
 #define CONFIG_INTER_STATS_ONLY 0

View File

@@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0
 CONFIG_GCC equ 1
 CONFIG_GCOV equ 0
 CONFIG_GPROF equ 0
+CONFIG_HIGHWAY equ 0
 CONFIG_INSPECTION equ 0
 CONFIG_INTERNAL_STATS equ 0
 CONFIG_INTER_STATS_ONLY equ 0

View File

@@ -42,6 +42,7 @@
 #define CONFIG_GCC 1
 #define CONFIG_GCOV 0
 #define CONFIG_GPROF 0
+#define CONFIG_HIGHWAY 0
 #define CONFIG_INSPECTION 0
 #define CONFIG_INTERNAL_STATS 0
 #define CONFIG_INTER_STATS_ONLY 0

View File

@@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0
 CONFIG_GCC equ 1
 CONFIG_GCOV equ 0
 CONFIG_GPROF equ 0
+CONFIG_HIGHWAY equ 0
 CONFIG_INSPECTION equ 0
 CONFIG_INTERNAL_STATS equ 0
 CONFIG_INTER_STATS_ONLY equ 0

View File

@@ -42,6 +42,7 @@
 #define CONFIG_GCC 1
 #define CONFIG_GCOV 0
 #define CONFIG_GPROF 0
+#define CONFIG_HIGHWAY 0
 #define CONFIG_INSPECTION 0
 #define CONFIG_INTERNAL_STATS 0
 #define CONFIG_INTER_STATS_ONLY 0

View File

@@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0
 CONFIG_GCC equ 1
 CONFIG_GCOV equ 0
 CONFIG_GPROF equ 0
+CONFIG_HIGHWAY equ 0
 CONFIG_INSPECTION equ 0
 CONFIG_INTERNAL_STATS equ 0
 CONFIG_INTER_STATS_ONLY equ 0

View File

@@ -42,6 +42,7 @@
 #define CONFIG_GCC 1
 #define CONFIG_GCOV 0
 #define CONFIG_GPROF 0
+#define CONFIG_HIGHWAY 0
 #define CONFIG_INSPECTION 0
 #define CONFIG_INTERNAL_STATS 0
 #define CONFIG_INTER_STATS_ONLY 0

View File

@@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0
 CONFIG_GCC equ 1
 CONFIG_GCOV equ 0
 CONFIG_GPROF equ 0
+CONFIG_HIGHWAY equ 0
 CONFIG_INSPECTION equ 0
 CONFIG_INTERNAL_STATS equ 0
 CONFIG_INTER_STATS_ONLY equ 0

View File

@@ -42,6 +42,7 @@
 #define CONFIG_GCC 1
 #define CONFIG_GCOV 0
 #define CONFIG_GPROF 0
+#define CONFIG_HIGHWAY 0
 #define CONFIG_INSPECTION 0
 #define CONFIG_INTERNAL_STATS 0
 #define CONFIG_INTER_STATS_ONLY 0

View File

@@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0
 CONFIG_GCC equ 1
 CONFIG_GCOV equ 0
 CONFIG_GPROF equ 0
+CONFIG_HIGHWAY equ 0
 CONFIG_INSPECTION equ 0
 CONFIG_INTERNAL_STATS equ 0
 CONFIG_INTER_STATS_ONLY equ 0

View File

@@ -42,6 +42,7 @@
 #define CONFIG_GCC 1
 #define CONFIG_GCOV 0
 #define CONFIG_GPROF 0
+#define CONFIG_HIGHWAY 0
 #define CONFIG_INSPECTION 0
 #define CONFIG_INTERNAL_STATS 0
 #define CONFIG_INTER_STATS_ONLY 0

View File

@@ -40,6 +40,7 @@ CONFIG_FPMT_TEST equ 0
 CONFIG_GCC equ 1
 CONFIG_GCOV equ 0
 CONFIG_GPROF equ 0
+CONFIG_HIGHWAY equ 0
 CONFIG_INSPECTION equ 0
 CONFIG_INTERNAL_STATS equ 0
 CONFIG_INTER_STATS_ONLY equ 0

View File

@@ -42,6 +42,7 @@
 #define CONFIG_GCC 1
 #define CONFIG_GCOV 0
 #define CONFIG_GPROF 0
+#define CONFIG_HIGHWAY 0
 #define CONFIG_INSPECTION 0
 #define CONFIG_INTERNAL_STATS 0
 #define CONFIG_INTER_STATS_ONLY 0

View File

@@ -20,11 +20,11 @@ origin:
 # Human-readable identifier for this version/release
 # Generally "version NNN", "tag SSS", "bookmark SSS"
-release: 4e3595a426bacb022e8152540a32753c43822f54 (Thu Mar 27 13:41:21 2025 -0700).
+release: 719f60edc51b6141a2434bf1b5110c2fb075b246 (Fri Apr 25 19:13:37 2025 -0700).
 # Revision to pull in
 # Must be a long or short commit SHA (long preferred)
-revision: 4e3595a426bacb022e8152540a32753c43822f54
+revision: 719f60edc51b6141a2434bf1b5110c2fb075b246
 # The package's license, where possible using the mnemonic from
 # https://spdx.org/licenses/

View File

@@ -32,6 +32,7 @@ Arild Fuldseth <arilfuld@cisco.com>
 Aron Rosenberg <arosenberg@logitech.com>
 Arpad Panyik <Arpad.Panyik@arm.com>
 Arun Singh Negi <arun.negi@ittiam.com>
+Athulya Raj Raji Mohini <AthulyaRaj.RajiMohini@arm.com>
 Attila Nagy <attilanagy@google.com>
 Balaji Anandapadmanaban <balaji.anandapadmanaban@arm.com>
 Bohan Li <bohanli@google.com>

View File

@@ -1,3 +1,19 @@
+2025-04-11 v3.12.1
+  This release includes several bug fixes. This release is ABI
+  compatible with the last release. See
+  https://aomedia.googlesource.com/aom/+log/v3.12.0..v3.12.1 for all the
+  commits in this release.
+
+  - Bug Fixes
+    * b:396169342: Assertion
+      `av1_is_subpelmv_in_range(&ms_params.mv_limits, start_mv)' failed.
+    * b:401671154: typo in void init_src_params(...)
+    * Coverity defect 323670: Uninitialized scalar variable in
+      encode_with_and_without_superres()
+    * cmake: bump minimum version to 3.16
+    * cfl_ppc: fix subtract_average_vsx
+    * Fix an incorrect index in av1_highbd_pixel_proj_error_neon
+
 2025-02-10 v3.12.0
 This release includes new codec interfaces, compression efficiency and
 perceptual improvements, speedup and memory optimizations, and bug

View File

@@ -59,7 +59,7 @@ endif()
 #
 # The VERSION number in project() should be updated when these variables are.
 set(LT_CURRENT 15)
-set(LT_REVISION 0)
+set(LT_REVISION 1)
 set(LT_AGE 12)
 math(EXPR SO_VERSION "${LT_CURRENT} - ${LT_AGE}")
 set(SO_FILE_VERSION "${SO_VERSION}.${LT_AGE}.${LT_REVISION}")
@@ -270,6 +270,43 @@ add_rtcd_build_step("${AOM_ROOT}/av1/common/av1_rtcd_defs.pl"
 add_library(aom_rtcd OBJECT ${AOM_RTCD_SOURCES})
 add_dependencies(aom_rtcd aom_version)

+if(CONFIG_HIGHWAY)
+  list(APPEND AOM_HIGHWAY_SOURCES
+       "${AOM_ROOT}/third_party/highway/hwy/abort.h"
+       "${AOM_ROOT}/third_party/highway/hwy/aligned_allocator.h"
+       "${AOM_ROOT}/third_party/highway/hwy/base.h"
+       "${AOM_ROOT}/third_party/highway/hwy/cache_control.h"
+       "${AOM_ROOT}/third_party/highway/hwy/per_target.h"
+       "${AOM_ROOT}/third_party/highway/hwy/print.h"
+       "${AOM_ROOT}/third_party/highway/hwy/foreach_target.h"
+       "${AOM_ROOT}/third_party/highway/hwy/highway_export.h"
+       "${AOM_ROOT}/third_party/highway/hwy/highway.h"
+       "${AOM_ROOT}/third_party/highway/hwy/print-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/timer-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/detect_compiler_arch.h"
+       "${AOM_ROOT}/third_party/highway/hwy/detect_targets.h"
+       "${AOM_ROOT}/third_party/highway/hwy/targets.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/arm_neon-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/arm_sve-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/emu128-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/generic_ops-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/scalar-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/set_macros-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/shared-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/x86_128-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/x86_256-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/x86_512-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/ops/x86_avx3-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/contrib/algo/copy-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/contrib/algo/find-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/contrib/algo/transform-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/contrib/dot/dot-inl.h"
+       "${AOM_ROOT}/third_party/highway/hwy/contrib/image/image.h"
+       "${AOM_ROOT}/third_party/highway/hwy/contrib/math/math-inl.h")
+  add_library(aom_hwy OBJECT ${AOM_HIGHWAY_SOURCES})
+  set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} aom_hwy)
+endif()
+
 if(ENABLE_EXAMPLES)
   add_library(aom_encoder_stats OBJECT ${AOM_ENCODER_STATS_SOURCES})
   set(AOM_LIB_TARGETS ${AOM_LIB_TARGETS} aom_encoder_stats)

View File

@@ -181,6 +181,11 @@ if(CONFIG_AV1_ENCODER)
"${AOM_ROOT}/aom_dsp/variance.c" "${AOM_ROOT}/aom_dsp/variance.c"
"${AOM_ROOT}/aom_dsp/variance.h") "${AOM_ROOT}/aom_dsp/variance.h")
if(CONFIG_HIGHWAY)
list(APPEND AOM_DSP_ENCODER_SOURCES "${AOM_ROOT}/aom_dsp/reduce_sum_hwy.h"
"${AOM_ROOT}/aom_dsp/sad_hwy.h")
endif()
# Flow estimation library and grain/noise table/model. # Flow estimation library and grain/noise table/model.
if(NOT CONFIG_REALTIME_ONLY) if(NOT CONFIG_REALTIME_ONLY)
list(APPEND AOM_DSP_ENCODER_SOURCES list(APPEND AOM_DSP_ENCODER_SOURCES
@@ -259,6 +264,11 @@ if(CONFIG_AV1_ENCODER)
"${AOM_ROOT}/aom_dsp/x86/blk_sse_sum_avx2.c" "${AOM_ROOT}/aom_dsp/x86/blk_sse_sum_avx2.c"
"${AOM_ROOT}/aom_dsp/x86/sum_squares_avx2.c") "${AOM_ROOT}/aom_dsp/x86/sum_squares_avx2.c")
if(CONFIG_HIGHWAY)
list(APPEND AOM_DSP_ENCODER_INTRIN_AVX2
"${AOM_ROOT}/aom_dsp/x86/sad_hwy_avx2.cc")
endif()
list(APPEND AOM_DSP_ENCODER_INTRIN_AVX list(APPEND AOM_DSP_ENCODER_INTRIN_AVX
"${AOM_ROOT}/aom_dsp/x86/aom_quantize_avx.c") "${AOM_ROOT}/aom_dsp/x86/aom_quantize_avx.c")

View File

@@ -0,0 +1,69 @@
/*
 * Copyright (c) 2025, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AOM_AOM_DSP_REDUCE_SUM_HWY_H_
#define AOM_AOM_DSP_REDUCE_SUM_HWY_H_

#include <type_traits>

#include "third_party/highway/hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

template <size_t NumBlocks>
struct BlockReduceTraits;

template <>
struct BlockReduceTraits<1> {
  template <typename D>
  HWY_ATTR HWY_INLINE static hn::VFromD<D> ReduceSum(D d, hn::VFromD<D> v) {
    (void)d;
    return v;
  }
};

template <size_t NumBlocks>
struct BlockReduceTraits {
  static_assert(NumBlocks > 1,
                "Primary template BlockReduceTraits assumes NumBlocks > 1");
  static_assert((NumBlocks & (NumBlocks - 1)) == 0,
                "BlockReduceTraits requires NumBlocks to be a power of 2.");
  template <typename D>
  HWY_ATTR HWY_INLINE static hn::VFromD<hn::BlockDFromD<D>> ReduceSum(
      D d, hn::VFromD<D> v) {
    (void)d;
    constexpr hn::Half<D> half_d;
    auto v_half = hn::Add(hn::LowerHalf(half_d, v), hn::UpperHalf(half_d, v));
    return BlockReduceTraits<NumBlocks / 2>::ReduceSum(half_d, v_half);
  }
};

// ReduceSum across blocks.
// For example, with a 4-block vector with 16 lanes of uint32_t:
// [a3 b3 c3 d3 a2 b2 c2 d2 a1 b1 c1 d1 a0 b0 c0 d0]
// returns a vector with 4 lanes:
// [a3+a2+a1+a0 b3+b2+b1+b0 c3+c2+c1+c0 d3+d2+d1+d0]
template <typename D>
HWY_ATTR HWY_INLINE hn::Vec<hn::BlockDFromD<D>> BlockReduceSum(
    D int_tag, hn::VFromD<D> v) {
  return BlockReduceTraits<int_tag.MaxBlocks()>::ReduceSum(int_tag, v);
}

}  // namespace HWY_NAMESPACE
}  // namespace
HWY_AFTER_NAMESPACE();

#endif // AOM_AOM_DSP_REDUCE_SUM_HWY_H_
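
For intuition, the block reduction above keeps adding the upper half of the lane array onto the lower half until one 128-bit block's worth of lanes remains. The following plain scalar sketch is illustrative only and not part of the patch; the name BlockReduceSumScalar and the std::vector representation are assumptions made here for clarity.

#include <cstddef>
#include <cstdint>
#include <vector>

// Fold a multi-block lane array down to one block by repeated halving:
// on each step lane i accumulates lane i + half, so after log2(blocks)
// steps lane i holds the sum of lane i from every block.
std::vector<std::uint32_t> BlockReduceSumScalar(std::vector<std::uint32_t> v,
                                                std::size_t lanes_per_block) {
  while (v.size() > lanes_per_block) {
    const std::size_t half = v.size() / 2;
    for (std::size_t i = 0; i < half; ++i) v[i] += v[i + half];
    v.resize(half);
  }
  return v;
}

With 16 input lanes and lanes_per_block = 4 this reproduces the four per-lane sums given in the comment above BlockReduceSum.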

third_party/aom/aom_dsp/sad_hwy.h (vendored new file, 77 lines)
View File

@@ -0,0 +1,77 @@
/*
 * Copyright (c) 2025, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AOM_AOM_DSP_SAD_HWY_H_
#define AOM_AOM_DSP_SAD_HWY_H_

#include "aom_dsp/reduce_sum_hwy.h"
#include "third_party/highway/hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

template <int BlockWidth>
HWY_MAYBE_UNUSED unsigned int SumOfAbsoluteDiff(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred = nullptr) {
  constexpr hn::CappedTag<uint8_t, BlockWidth> pixel_tag;
  constexpr hn::Repartition<uint64_t, decltype(pixel_tag)> intermediate_sum_tag;
  const int vw = hn::Lanes(pixel_tag);
  auto sum_sad = hn::Zero(intermediate_sum_tag);
  const bool is_sad_avg = second_pred != nullptr;
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < BlockWidth; j += vw) {
      auto src_vec = hn::LoadU(pixel_tag, &src_ptr[j]);
      auto ref_vec = hn::LoadU(pixel_tag, &ref_ptr[j]);
      if (is_sad_avg) {
        auto sec_pred_vec = hn::LoadU(pixel_tag, &second_pred[j]);
        ref_vec = hn::AverageRound(ref_vec, sec_pred_vec);
      }
      auto sad = hn::SumsOf8AbsDiff(src_vec, ref_vec);
      sum_sad = hn::Add(sum_sad, sad);
    }
    src_ptr += src_stride;
    ref_ptr += ref_stride;
    if (is_sad_avg) {
      second_pred += BlockWidth;
    }
  }
  return static_cast<unsigned int>(
      hn::ReduceSum(intermediate_sum_tag, sum_sad));
}

}  // namespace HWY_NAMESPACE
}  // namespace

#define FSAD(w, h, suffix)                                                   \
  extern "C" unsigned int aom_sad##w##x##h##_##suffix(                       \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,        \
      int ref_stride);                                                       \
  HWY_ATTR unsigned int aom_sad##w##x##h##_##suffix(                         \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,        \
      int ref_stride) {                                                      \
    return HWY_NAMESPACE::SumOfAbsoluteDiff<w>(src_ptr, src_stride, ref_ptr, \
                                               ref_stride, h);               \
  }

#define FOR_EACH_SAD_BLOCK_SIZE(X, suffix) \
  X(128, 128, suffix)                      \
  X(128, 64, suffix)                       \
  X(64, 128, suffix)                       \
  X(64, 64, suffix)                        \
  X(64, 32, suffix)

HWY_AFTER_NAMESPACE();

#endif // AOM_AOM_DSP_SAD_HWY_H_
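
The kernel above is a straight sum of absolute differences over a BlockWidth x h block, with an optional round-up average against a second prediction (hn::AverageRound computes (a + b + 1) >> 1 for unsigned lanes). The following scalar model is for reference only; SadScalar is a hypothetical helper written here, not part of the patch.

#include <cstdint>
#include <cstdlib>

// Reference model of the vector kernel: |src - ref| summed over the block,
// where ref is optionally replaced by the rounded-up average of ref and a
// second predictor that advances by one block row per line.
unsigned int SadScalar(const std::uint8_t *src, int src_stride,
                       const std::uint8_t *ref, int ref_stride,
                       int block_width, int h,
                       const std::uint8_t *second_pred = nullptr) {
  unsigned int sad = 0;
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < block_width; ++j) {
      int r = ref[j];
      if (second_pred != nullptr) r = (r + second_pred[j] + 1) >> 1;
      sad += static_cast<unsigned int>(std::abs(static_cast<int>(src[j]) - r));
    }
    src += src_stride;
    ref += ref_stride;
    if (second_pred != nullptr) second_pred += block_width;
  }
  return sad;
}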

View File

@@ -101,11 +101,17 @@ static inline unsigned int sad32xh_avx2(const uint8_t *src_ptr, int src_stride,
       h / 2); \
 }

+#if CONFIG_HIGHWAY
+#define FSAD64 \
+  FSADS64_H(64) \
+  FSADS64_H(32)
+#else
 #define FSAD64 \
   FSAD64_H(64) \
   FSAD64_H(32) \
   FSADS64_H(64) \
   FSADS64_H(32)
+#endif

 #define FSAD32 \
   FSAD32_H(64) \

View File

@@ -0,0 +1,17 @@
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#define HWY_BASELINE_TARGETS HWY_AVX2
#define HWY_BROKEN_32BIT 0

#include "aom_dsp/sad_hwy.h"

FOR_EACH_SAD_BLOCK_SIZE(FSAD, avx2)
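
FOR_EACH_SAD_BLOCK_SIZE(FSAD, avx2) instantiates the FSAD macro from sad_hwy.h once per listed size, so this translation unit ends up defining aom_sad128x128_avx2, aom_sad128x64_avx2, aom_sad64x128_avx2, aom_sad64x64_avx2 and aom_sad64x32_avx2. For a single size the expansion is roughly the following (whitespace added for readability; this is an illustration of the macro, not extra code in the commit):

extern "C" unsigned int aom_sad64x64_avx2(const uint8_t *src_ptr,
                                          int src_stride,
                                          const uint8_t *ref_ptr,
                                          int ref_stride);
HWY_ATTR unsigned int aom_sad64x64_avx2(const uint8_t *src_ptr, int src_stride,
                                        const uint8_t *ref_ptr,
                                        int ref_stride) {
  return HWY_NAMESPACE::SumOfAbsoluteDiff<64>(src_ptr, src_stride, ref_ptr,
                                              ref_stride, 64);
}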

View File

@@ -56,6 +56,7 @@ static unsigned int sad64x64(const uint8_t *src_ptr, int src_stride,
   return sum;
 }

+#if !CONFIG_HIGHWAY
 unsigned int aom_sad128x64_avx2(const uint8_t *src_ptr, int src_stride,
                                 const uint8_t *ref_ptr, int ref_stride) {
   unsigned int half_width = 64;
@@ -83,6 +84,7 @@ unsigned int aom_sad128x128_avx2(const uint8_t *src_ptr, int src_stride,
   sum += aom_sad128x64_avx2(src_ptr, src_stride, ref_ptr, ref_stride);
   return sum;
 }
+#endif

 unsigned int aom_sad_skip_128x64_avx2(const uint8_t *src_ptr, int src_stride,
                                       const uint8_t *ref_ptr, int ref_stride) {

View File

@@ -46,16 +46,6 @@ static inline __m128i xx_loadu_128(const void *a) {
   return _mm_loadu_si128((const __m128i *)a);
 }

-// _mm_loadu_si64 has been introduced in GCC 9, reimplement the function
-// manually on older compilers.
-#if !defined(__clang__) && __GNUC_MAJOR__ < 9
-static inline __m128i xx_loadu_2x64(const void *hi, const void *lo) {
-  __m64 hi_, lo_;
-  memcpy(&hi_, hi, sizeof(hi_));
-  memcpy(&lo_, lo, sizeof(lo_));
-  return _mm_set_epi64(hi_, lo_);
-}
-#else
 // Load 64 bits from each of hi and low, and pack into an SSE register
 // Since directly loading as `int64_t`s and using _mm_set_epi64 may violate
 // the strict aliasing rule, this takes a different approach
@@ -63,7 +53,6 @@ static inline __m128i xx_loadu_2x64(const void *hi, const void *lo) {
   return _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)lo),
                             _mm_loadl_epi64((const __m128i *)hi));
 }
-#endif

 static inline void xx_storel_32(void *const a, const __m128i v) {
   const int val = _mm_cvtsi128_si32(v);

View File

@@ -76,26 +76,11 @@ static inline __m256i yy_loadu_4x64(const void *e3, const void *e2,
   return yy_set_m128i(_mm_castpd_si128(v23), _mm_castpd_si128(v01));
 }

-#define GCC_VERSION (__GNUC__ * 10000 \
-                     + __GNUC_MINOR__ * 100 \
-                     + __GNUC_PATCHLEVEL__)
-
-// _mm256_loadu2_m128i has been introduced in GCC 10.1
-#if !defined(__clang__) && GCC_VERSION < 101000
-static inline __m256i yy_loadu2_128(const void *hi, const void *lo) {
-  __m128i mhi = _mm_loadu_si128((const __m128i *)(hi));
-  __m128i mlo = _mm_loadu_si128((const __m128i *)(lo));
-  return _mm256_set_m128i(mhi, mlo);
-}
-#else
 static inline __m256i yy_loadu2_128(const void *hi, const void *lo) {
   __m128i mhi = _mm_loadu_si128((const __m128i *)(hi));
   __m128i mlo = _mm_loadu_si128((const __m128i *)(lo));
   return yy_set_m128i(mhi, mlo);
 }
-#endif
-#undef GCC_VERSION

 static inline void yy_storeu2_128(void *hi, void *lo, const __m256i a) {
   _mm_storeu_si128((__m128i *)hi, _mm256_extracti128_si256(a, 1));

View File

@@ -171,6 +171,19 @@ static inline uint64_t xgetbv(void) {
 #define BIT(n) (1u << (n))
 #endif

+#define MMX_BITS BIT(23)
+#define SSE_BITS BIT(25)
+#define SSE2_BITS BIT(26)
+#define SSE3_BITS BIT(0)
+#define SSSE3_BITS BIT(9)
+#define SSE4_1_BITS BIT(19)
+// Bits 27 (OSXSAVE) & 28 (256-bit AVX)
+#define AVX_BITS (BIT(27) | BIT(28))
+#define AVX2_BITS BIT(5)
+
+#define FEATURE_SET(reg, feature) \
+  (((reg) & (feature##_BITS)) == (feature##_BITS))
+
 static inline int x86_simd_caps(void) {
   unsigned int flags = 0;
   unsigned int mask = ~0u;
@@ -179,11 +192,9 @@ static inline int x86_simd_caps(void) {
   /* See if the CPU capabilities are being overridden by the environment */
   env = getenv("AOM_SIMD_CAPS");
   if (env && *env) return (int)strtol(env, NULL, 0);
   env = getenv("AOM_SIMD_CAPS_MASK");
   if (env && *env) mask = (unsigned int)strtoul(env, NULL, 0);
   /* Ensure that the CPUID instruction supports extended features */
@@ -194,37 +205,26 @@ static inline int x86_simd_caps(void) {
   /* Get the standard feature flags */
   cpuid(1, 0, reg_eax, reg_ebx, reg_ecx, reg_edx);
-  if (reg_edx & BIT(23)) flags |= HAS_MMX;
-  if (reg_edx & BIT(25)) flags |= HAS_SSE; /* aka xmm */
-  if (reg_edx & BIT(26)) flags |= HAS_SSE2; /* aka wmt */
-  if (reg_ecx & BIT(0)) flags |= HAS_SSE3;
-  if (reg_ecx & BIT(9)) flags |= HAS_SSSE3;
-  if (reg_ecx & BIT(19)) flags |= HAS_SSE4_1;
-  if (reg_ecx & BIT(20)) flags |= HAS_SSE4_2;
+  flags |= FEATURE_SET(reg_edx, MMX) ? HAS_MMX : 0;
+  flags |= FEATURE_SET(reg_edx, SSE) ? HAS_SSE : 0;
+  flags |= FEATURE_SET(reg_edx, SSE2) ? HAS_SSE2 : 0;
+  flags |= FEATURE_SET(reg_ecx, SSE3) ? HAS_SSE3 : 0;
+  flags |= FEATURE_SET(reg_ecx, SSSE3) ? HAS_SSSE3 : 0;
+  flags |= FEATURE_SET(reg_ecx, SSE4_1) ? HAS_SSE4_1 : 0;
   // bits 27 (OSXSAVE) & 28 (256-bit AVX)
-  if ((reg_ecx & (BIT(27) | BIT(28))) == (BIT(27) | BIT(28))) {
+  if (FEATURE_SET(reg_ecx, AVX)) {
     // Check for OS-support of YMM state. Necessary for AVX and AVX2.
     if ((xgetbv() & 0x6) == 0x6) {
       flags |= HAS_AVX;
       if (max_cpuid_val >= 7) {
         /* Get the leaf 7 feature flags. Needed to check for AVX2 support */
         cpuid(7, 0, reg_eax, reg_ebx, reg_ecx, reg_edx);
-        if (reg_ebx & BIT(5)) flags |= HAS_AVX2;
+        flags |= FEATURE_SET(reg_ebx, AVX2) ? HAS_AVX2 : 0;
       }
     }
   }
   (void)reg_eax;  // Avoid compiler warning on unused-but-set variable.
   return flags & mask;
 }
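
FEATURE_SET reports a feature only when every bit of its mask is set, which is what lets the multi-bit AVX check (OSXSAVE plus the AVX bit) read the same as the single-bit ones. A standalone sketch of that behaviour follows; it is written here for illustration and is not taken from the aom sources.

#include <cstdio>

#define BIT(n) (1u << (n))
#define AVX_BITS (BIT(27) | BIT(28))
#define FEATURE_SET(reg, feature) \
  (((reg) & (feature##_BITS)) == (feature##_BITS))

int main() {
  const unsigned int avx_without_osxsave = BIT(28);
  const unsigned int avx_with_osxsave = BIT(27) | BIT(28);
  // Prints "0 1": AVX is only reported when both required bits are present.
  std::printf("%d %d\n", FEATURE_SET(avx_without_osxsave, AVX),
              FEATURE_SET(avx_with_osxsave, AVX));
  return 0;
}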

View File

@@ -3883,6 +3883,38 @@ static aom_codec_err_t ctrl_set_svc_params(aom_codec_alg_priv_t *ctx,
   ppi->number_temporal_layers = params->number_temporal_layers;
   cpi->svc.number_spatial_layers = params->number_spatial_layers;
   cpi->svc.number_temporal_layers = params->number_temporal_layers;
+  // Sequence parameters (operating_points_cnt_minus_1, operating_point_idc[])
+  // need to be updated if the number of layers have changed.
+  // Force a keyframe here and update the two relevant sequence parameters.
+  if (cpi->svc.prev_number_temporal_layers &&
+      cpi->svc.prev_number_spatial_layers &&
+      (cpi->svc.number_temporal_layers !=
+           cpi->svc.prev_number_temporal_layers ||
+       cpi->svc.number_spatial_layers != cpi->svc.prev_number_spatial_layers)) {
+    SequenceHeader *const seq_params = &ppi->seq_params;
+    seq_params->operating_points_cnt_minus_1 =
+        ppi->number_spatial_layers * ppi->number_temporal_layers - 1;
+    ctx->next_frame_flags |= AOM_EFLAG_FORCE_KF;
+    av1_set_svc_seq_params(ppi);
+    // Check for valid values for the spatial/temporal_layer_id here, since
+    // there has been a dynamic change in the number_spatial/temporal_layers,
+    // and if the ctrl_set_layer_id is not used after this call, the
+    // previous (last_encoded) values of spatial/temporal_layer_id will be used,
+    // which may be invalid.
+    cpi->svc.spatial_layer_id = AOMMAX(
+        0,
+        AOMMIN(cpi->svc.spatial_layer_id, cpi->svc.number_spatial_layers - 1));
+    cpi->svc.temporal_layer_id =
+        AOMMAX(0, AOMMIN(cpi->svc.temporal_layer_id,
+                         cpi->svc.number_temporal_layers - 1));
+    cpi->common.spatial_layer_id =
+        AOMMAX(0, AOMMIN(cpi->common.spatial_layer_id,
+                         cpi->svc.number_spatial_layers - 1));
+    cpi->common.temporal_layer_id =
+        AOMMAX(0, AOMMIN(cpi->common.temporal_layer_id,
+                         cpi->svc.number_temporal_layers - 1));
+  }
+
   if (ppi->number_spatial_layers > 1 || ppi->number_temporal_layers > 1) {
     unsigned int sl, tl;
     ctx->ppi->use_svc = 1;

View File

@@ -19,7 +19,6 @@
 #define OFF_1 16
 #define OFF_2 32
 #define OFF_3 48
-#define CFL_BUF_LINE_BYTES 64
 #define CFL_LINE_1 64
 #define CFL_LINE_2 128
 #define CFL_LINE_3 192
@@ -35,8 +34,6 @@ typedef vector unsigned long long uint64x2_t; // NOLINT(runtime/int)
 static inline void subtract_average_vsx(const uint16_t *src_ptr, int16_t *dst,
                                         int width, int height, int round_offset,
                                         int num_pel_log2) {
-  // int16_t *dst = dst_ptr;
-  const int16_t *dst_end = dst + height * CFL_BUF_LINE;
   const int16_t *sum_buf = (const int16_t *)src_ptr;
   const int16_t *end = sum_buf + height * CFL_BUF_LINE;
   const uint32x4_t div_shift = vec_splats((uint32_t)num_pel_log2);
@@ -63,7 +60,8 @@ static inline void subtract_average_vsx(const uint16_t *src_ptr, int16_t *dst,
       sum_32x4_1 =
           vec_sum4s(vec_vsx_ld(OFF_3 + CFL_LINE_1, sum_buf), sum_32x4_1);
     }
-  } while ((sum_buf += (CFL_BUF_LINE * 2)) < end);
+    sum_buf += CFL_BUF_LINE * 2;
+  } while (sum_buf < end);
   int32x4_t sum_32x4 = vec_add(sum_32x4_0, sum_32x4_1);
   const int32x4_t perm_64 = vec_perm(sum_32x4, sum_32x4, mask_64);
@@ -72,41 +70,44 @@ static inline void subtract_average_vsx(const uint16_t *src_ptr, int16_t *dst,
   sum_32x4 = vec_add(sum_32x4, perm_32);
   const int32x4_t avg = vec_sr(sum_32x4, div_shift);
   const int16x8_t vec_avg = vec_pack(avg, avg);
+  const int16_t *src = (const int16_t *)src_ptr;
   do {
-    vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0, dst), vec_avg), OFF_0, dst);
-    vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_1, dst), vec_avg),
-               OFF_0 + CFL_BUF_LINE_BYTES, dst);
-    vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_2, dst), vec_avg),
-               OFF_0 + CFL_LINE_2, dst);
-    vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_3, dst), vec_avg),
-               OFF_0 + CFL_LINE_3, dst);
+    vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0, src), vec_avg), OFF_0, dst);
+    vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_1, src), vec_avg),
+               OFF_0 + CFL_LINE_1, dst);
+    vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_2, src), vec_avg),
+               OFF_0 + CFL_LINE_2, dst);
+    vec_vsx_st(vec_sub(vec_vsx_ld(OFF_0 + CFL_LINE_3, src), vec_avg),
+               OFF_0 + CFL_LINE_3, dst);
     if (width >= 16) {
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1, dst), vec_avg), OFF_1, dst);
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_1, dst), vec_avg),
-                 OFF_1 + CFL_LINE_1, dst);
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_2, dst), vec_avg),
-                 OFF_1 + CFL_LINE_2, dst);
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_3, dst), vec_avg),
-                 OFF_1 + CFL_LINE_3, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1, src), vec_avg), OFF_1, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_1, src), vec_avg),
+                 OFF_1 + CFL_LINE_1, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_2, src), vec_avg),
+                 OFF_1 + CFL_LINE_2, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_1 + CFL_LINE_3, src), vec_avg),
+                 OFF_1 + CFL_LINE_3, dst);
     }
     if (width == 32) {
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2, dst), vec_avg), OFF_2, dst);
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_1, dst), vec_avg),
-                 OFF_2 + CFL_LINE_1, dst);
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_2, dst), vec_avg),
-                 OFF_2 + CFL_LINE_2, dst);
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_3, dst), vec_avg),
-                 OFF_2 + CFL_LINE_3, dst);
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3, dst), vec_avg), OFF_3, dst);
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_1, dst), vec_avg),
-                 OFF_3 + CFL_LINE_1, dst);
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_2, dst), vec_avg),
-                 OFF_3 + CFL_LINE_2, dst);
-      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_3, dst), vec_avg),
-                 OFF_3 + CFL_LINE_3, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2, src), vec_avg), OFF_2, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_1, src), vec_avg),
+                 OFF_2 + CFL_LINE_1, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_2, src), vec_avg),
+                 OFF_2 + CFL_LINE_2, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_2 + CFL_LINE_3, src), vec_avg),
+                 OFF_2 + CFL_LINE_3, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3, src), vec_avg), OFF_3, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_1, src), vec_avg),
+                 OFF_3 + CFL_LINE_1, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_2, src), vec_avg),
+                 OFF_3 + CFL_LINE_2, dst);
+      vec_vsx_st(vec_sub(vec_vsx_ld(OFF_3 + CFL_LINE_3, src), vec_avg),
+                 OFF_3 + CFL_LINE_3, dst);
     }
-  } while ((dst += CFL_BUF_LINE * 4) < dst_end);
+    src += CFL_BUF_LINE * 4;
+    dst += CFL_BUF_LINE * 4;
+  } while (src < end);
 }
// Declare wrappers for VSX sizes // Declare wrappers for VSX sizes

View File

@@ -235,12 +235,6 @@ static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x,
     model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
                                                   &rate, &dist);
-    // int rate2;
-    // int64_t dist2;
-    // model_rd_with_curvfit(cpi, x, bsize, 0, sse, N, &rate2, &dist2);
-    // printf("sse %"PRId64": leagacy: %d %"PRId64", curvfit %d %"PRId64"\n",
-    // sse, rate, dist, rate2, dist2); dist = dist2;
-    // rate = rate2;
     rate += x->mode_costs.wedge_idx_cost[bsize][wedge_index];
     rd = RDCOST(x->rdmult, rate, dist);

View File

@@ -2060,6 +2060,8 @@ static inline void encode_frame_internal(AV1_COMP *cpi) {
   start_timing(cpi, av1_setup_motion_field_time);
 #endif
   av1_calculate_ref_frame_side(cm);
+  features->allow_ref_frame_mvs &= !cpi->sf.hl_sf.disable_ref_frame_mvs;
   if (features->allow_ref_frame_mvs) av1_setup_motion_field(cm);
 #if CONFIG_COLLECT_COMPONENT_TIMING
   end_timing(cpi, av1_setup_motion_field_time);

View File

@@ -487,6 +487,32 @@ static void set_bitstream_level_tier(AV1_PRIMARY *const ppi, int width,
   }
 }

+void av1_set_svc_seq_params(AV1_PRIMARY *const ppi) {
+  SequenceHeader *const seq = &ppi->seq_params;
+  if (seq->operating_points_cnt_minus_1 == 0) {
+    seq->operating_point_idc[0] = 0;
+    seq->has_nonzero_operating_point_idc = false;
+  } else {
+    // Set operating_point_idc[] such that the i=0 point corresponds to the
+    // highest quality operating point (all layers), and subsequent
+    // operarting points (i > 0) are lower quality corresponding to
+    // skip decoding enhancement layers (temporal first).
+    int i = 0;
+    assert(seq->operating_points_cnt_minus_1 ==
+           (int)(ppi->number_spatial_layers * ppi->number_temporal_layers - 1));
+    for (unsigned int sl = 0; sl < ppi->number_spatial_layers; sl++) {
+      for (unsigned int tl = 0; tl < ppi->number_temporal_layers; tl++) {
+        seq->operating_point_idc[i] =
+            (~(~0u << (ppi->number_spatial_layers - sl)) << 8) |
+            ~(~0u << (ppi->number_temporal_layers - tl));
+        assert(seq->operating_point_idc[i] != 0);
+        i++;
+      }
+    }
+    seq->has_nonzero_operating_point_idc = true;
+  }
+}
+
 static void init_seq_coding_tools(AV1_PRIMARY *const ppi,
                                   const AV1EncoderConfig *oxcf,
                                   int disable_frame_id_numbers) {
@@ -551,29 +577,7 @@ static void init_seq_coding_tools(AV1_PRIMARY *const ppi,
   set_bitstream_level_tier(ppi, frm_dim_cfg->width, frm_dim_cfg->height,
                            oxcf->input_cfg.init_framerate);
-  if (seq->operating_points_cnt_minus_1 == 0) {
-    seq->operating_point_idc[0] = 0;
-    seq->has_nonzero_operating_point_idc = false;
-  } else {
-    // Set operating_point_idc[] such that the i=0 point corresponds to the
-    // highest quality operating point (all layers), and subsequent
-    // operarting points (i > 0) are lower quality corresponding to
-    // skip decoding enhancement layers (temporal first).
-    int i = 0;
-    assert(seq->operating_points_cnt_minus_1 ==
-           (int)(ppi->number_spatial_layers * ppi->number_temporal_layers - 1));
-    for (unsigned int sl = 0; sl < ppi->number_spatial_layers; sl++) {
-      for (unsigned int tl = 0; tl < ppi->number_temporal_layers; tl++) {
-        seq->operating_point_idc[i] =
-            (~(~0u << (ppi->number_spatial_layers - sl)) << 8) |
-            ~(~0u << (ppi->number_temporal_layers - tl));
-        assert(seq->operating_point_idc[i] != 0);
-        i++;
-      }
-    }
-    seq->has_nonzero_operating_point_idc = true;
-  }
+  av1_set_svc_seq_params(ppi);
 }

 static void init_config_sequence(struct AV1_PRIMARY *ppi,
@@ -770,6 +774,11 @@ void av1_change_config_seq(struct AV1_PRIMARY *ppi,
   // Init sequence level coding tools
   // This should not be called after the first key frame.
+  // Note that for SVC encoding the sequence parameters
+  // (operating_points_cnt_minus_1, operating_point_idc[],
+  // has_nonzero_operating_point_idc) should be updated whenever the
+  // number of layers is changed. This is done in the
+  // ctrl_set_svc_params().
   if (!ppi->seq_params_locked) {
     seq_params->operating_points_cnt_minus_1 =
         (ppi->number_spatial_layers > 1 || ppi->number_temporal_layers > 1)
@@ -2276,7 +2285,12 @@ void av1_set_frame_size(AV1_COMP *cpi, int width, int height) {
       if (av1_is_scaled(sf)) aom_extend_frame_borders(&buf->buf, num_planes);
     }
   }
-  if (!frame_is_intra_only(cm) && !has_valid_ref_frame) {
+  // For 1 pass CBR mode: we can skip this check for spatial enhancement
+  // layer if the target_bandwidth is zero, since it will be dropped.
+  const bool dropped_frame =
+      has_no_stats_stage(cpi) && cpi->oxcf.rc_cfg.mode == AOM_CBR &&
+      cpi->svc.spatial_layer_id > 0 && cpi->oxcf.rc_cfg.target_bandwidth == 0;
+  if (!frame_is_intra_only(cm) && !has_valid_ref_frame && !dropped_frame) {
     aom_internal_error(
         cm->error, AOM_CODEC_CORRUPT_FRAME,
         "Can't find at least one reference frame with valid size");
@@ -2990,10 +3004,6 @@ static int encode_with_recode_loop(AV1_COMP *cpi, size_t *size, uint8_t *dest,
     av1_set_variance_partition_thresholds(cpi, q, 0);

-    // printf("Frame %d/%d: q = %d, frame_type = %d superres_denom = %d\n",
-    //        cm->current_frame.frame_number, cm->show_frame, q,
-    //        cm->current_frame.frame_type, cm->superres_scale_denominator);
     if (loop_count == 0) {
       av1_setup_frame(cpi);
     } else if (get_primary_ref_frame_buf(cm) == NULL) {
@@ -4010,8 +4020,10 @@ static int encode_frame_to_data_rate(AV1_COMP *cpi, size_t *size, uint8_t *dest,
     cpi->frames_since_last_update = 1;
   }

-  if (cpi->svc.spatial_layer_id == cpi->svc.number_spatial_layers - 1)
+  if (cpi->svc.spatial_layer_id == cpi->svc.number_spatial_layers - 1) {
     cpi->svc.prev_number_spatial_layers = cpi->svc.number_spatial_layers;
+  }
+  cpi->svc.prev_number_temporal_layers = cpi->svc.number_temporal_layers;

   // Clear the one shot update flags for segmentation map and mode/ref loop
   // filter deltas.
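
To make the operating_point_idc[] construction in av1_set_svc_seq_params above concrete: bits 8 and up mask the spatial layers and the low 8 bits mask the temporal layers, so for 2 spatial x 2 temporal layers the four operating points come out as 0x303, 0x301, 0x103 and 0x101. The following standalone C++ sketch reproduces that arithmetic for illustration only; it is not part of the commit.

#include <cstdio>

int main() {
  const unsigned int spatial_layers = 2, temporal_layers = 2;
  int i = 0;
  for (unsigned int sl = 0; sl < spatial_layers; ++sl) {
    for (unsigned int tl = 0; tl < temporal_layers; ++tl) {
      // Same bit pattern as the encoder: spatial mask shifted into bits 8+,
      // temporal mask in the low bits; point 0 is the full-quality point.
      const unsigned int idc = (~(~0u << (spatial_layers - sl)) << 8) |
                               ~(~0u << (temporal_layers - tl));
      std::printf("operating_point_idc[%d] = 0x%03x\n", i++, idc);
    }
  }
  return 0;
}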

View File

@@ -2675,7 +2675,11 @@ typedef struct AV1_PRIMARY {
   /*!
    * Sequence parameters have been transmitted already and locked
    * or not. Once locked av1_change_config cannot change the seq
-   * parameters.
+   * parameters. Note that for SVC encoding the sequence parameters
+   * (operating_points_cnt_minus_1, operating_point_idc[],
+   * has_nonzero_operating_point_idc) should be updated whenever the
+   * number of layers is changed. This is done in the
+   * ctrl_set_svc_params().
    */
   int seq_params_locked;
@@ -3905,6 +3909,8 @@ void av1_set_screen_content_options(struct AV1_COMP *cpi,
 void av1_update_frame_size(AV1_COMP *cpi);

+void av1_set_svc_seq_params(AV1_PRIMARY *const ppi);
+
 typedef struct {
   int pyr_level;
   int disp_order;

View File

@@ -518,8 +518,6 @@ static void process_tpl_stats_frame(AV1_COMP *cpi) {
     const int gfu_boost = get_gfu_boost_from_r0_lap(
         min_boost_factor, MAX_GFUBOOST_FACTOR, cpi->rd.r0,
         cpi->ppi->p_rc.num_stats_required_for_gfu_boost);
-    // printf("old boost %d new boost %d\n", cpi->rc.gfu_boost,
-    //        gfu_boost);
     cpi->ppi->p_rc.gfu_boost = combine_prior_with_tpl_boost(
         min_boost_factor, MAX_BOOST_COMBINE_FACTOR,
         cpi->ppi->p_rc.gfu_boost, gfu_boost,
@@ -840,10 +838,12 @@ BLOCK_SIZE av1_select_sb_size(const AV1EncoderConfig *const oxcf, int width,
   }
   assert(oxcf->tool_cfg.superblock_size == AOM_SUPERBLOCK_SIZE_DYNAMIC);

-  if (number_spatial_layers > 1 ||
-      oxcf->resize_cfg.resize_mode != RESIZE_NONE) {
-    // Use the configured size (top resolution) for spatial layers or
-    // on resize.
+  if (number_spatial_layers > 1) {
+    // For spatial layers better selection may be done given the resolutions
+    // used across the layers, but for now use 64x64 for spatial layers.
+    return BLOCK_64X64;
+  } else if (oxcf->resize_cfg.resize_mode != RESIZE_NONE) {
+    // Use the configured size (top resolution) for resize.
     return AOMMIN(oxcf->frm_dim_cfg.width, oxcf->frm_dim_cfg.height) > 720
                ? BLOCK_128X128
                : BLOCK_64X64;

View File

@@ -30,8 +30,9 @@
 // Border over which to compute the global motion
 #define ERRORADV_BORDER 0

-int av1_is_enough_erroradvantage(double best_erroradvantage, int params_cost) {
-  return best_erroradvantage < erroradv_tr &&
+int av1_is_enough_erroradvantage(double best_erroradvantage, int params_cost,
+                                 double gm_erroradv_tr) {
+  return best_erroradvantage < gm_erroradv_tr &&
          best_erroradvantage * params_cost < erroradv_prod_tr;
 }
@@ -364,7 +365,8 @@ int64_t av1_refine_integerized_param(
     WarpedMotionParams *wm, TransformationType wmtype, int use_hbd, int bd,
     uint8_t *ref, int r_width, int r_height, int r_stride, uint8_t *dst,
     int d_width, int d_height, int d_stride, int n_refinements,
-    int64_t ref_frame_error, uint8_t *segment_map, int segment_map_stride) {
+    int64_t ref_frame_error, uint8_t *segment_map, int segment_map_stride,
+    double gm_erroradv_tr) {
   static const int max_trans_model_params[TRANS_TYPES] = { 0, 2, 4, 6 };
   const int border = ERRORADV_BORDER;
   int i = 0, p;
@@ -383,7 +385,8 @@ int64_t av1_refine_integerized_param(
   // Compute the maximum error value that will be accepted, so that
   // get_warp_error can terminate early if it proves the model will not
   // be accepted.
-  int64_t selection_threshold = (int64_t)lrint(ref_frame_error * erroradv_tr);
+  int64_t selection_threshold =
+      (int64_t)lrint(ref_frame_error * gm_erroradv_tr);
   return get_warp_error(wm, use_hbd, bd, ref, r_width, r_height, r_stride,
                         dst + border * d_stride + border, d_stride, border,
                         border, d_width - 2 * border, d_height - 2 * border,

View File

@@ -77,7 +77,7 @@ void av1_convert_model_to_params(const double *params,
                                  WarpedMotionParams *model);

 // Criteria for accepting a global motion model
-static const double erroradv_tr = 0.65;
+static const double erroradv_tr[2] = { 0.65, 0.2 };
 static const double erroradv_prod_tr = 20000;

 // Early exit threshold for global motion refinement
@@ -91,7 +91,8 @@ static const double erroradv_prod_tr = 20000;
 // threshold even if the model is initially above the threshold
 static const double erroradv_early_tr = 0.70;

-int av1_is_enough_erroradvantage(double best_erroradvantage, int params_cost);
+int av1_is_enough_erroradvantage(double best_erroradvantage, int params_cost,
+                                 double gm_erroradv_tr);

 void av1_compute_feature_segmentation_map(uint8_t *segment_map, int width,
                                           int height, int *inliers,
@@ -109,7 +110,8 @@ int64_t av1_refine_integerized_param(
     WarpedMotionParams *wm, TransformationType wmtype, int use_hbd, int bd,
     uint8_t *ref, int r_width, int r_height, int r_stride, uint8_t *dst,
     int d_width, int d_height, int d_stride, int n_refinements,
-    int64_t ref_frame_error, uint8_t *segment_map, int segment_map_stride);
+    int64_t ref_frame_error, uint8_t *segment_map, int segment_map_stride,
+    double gm_erroradv_tr);

 #ifdef __cplusplus
 }  // extern "C"

View File

@@ -91,13 +91,15 @@ static inline void compute_global_motion_for_ref_frame(
   GlobalMotionMethod global_motion_method = default_global_motion_method;
   int downsample_level = cpi->sf.gm_sf.downsample_level;
   int num_refinements = cpi->sf.gm_sf.num_refinement_steps;
+  int gm_erroradv_tr_level = cpi->sf.gm_sf.gm_erroradv_tr_level;
   bool mem_alloc_failed = false;
+  assert(gm_erroradv_tr_level < 2);

   // Select the best model based on fractional error reduction.
   // By initializing this to erroradv_tr, the same logic which is used to
   // select the best model will automatically filter out any model which
   // doesn't meet the required quality threshold
-  double best_erroradv = erroradv_tr;
+  double best_erroradv = erroradv_tr[gm_erroradv_tr_level];
   for (TransformationType model = FIRST_GLOBAL_TRANS_TYPE;
        model <= LAST_GLOBAL_TRANS_TYPE; ++model) {
     if (!aom_compute_global_motion(model, cpi->source, ref_buf[frame],
@@ -148,7 +150,8 @@ static inline void compute_global_motion_for_ref_frame(
         ref_buf[frame]->y_buffer, ref_buf[frame]->y_crop_width,
         ref_buf[frame]->y_crop_height, ref_buf[frame]->y_stride,
         cpi->source->y_buffer, src_width, src_height, src_stride,
-        num_refinements, ref_frame_error, segment_map, segment_map_w);
+        num_refinements, ref_frame_error, segment_map, segment_map_w,
+        erroradv_tr[gm_erroradv_tr_level]);

     // av1_refine_integerized_param() can return a simpler model type than
     // its input, so re-check model type here
@@ -160,7 +163,8 @@ static inline void compute_global_motion_for_ref_frame(
     if (!av1_is_enough_erroradvantage(
             erroradvantage,
             gm_get_params_cost(&tmp_wm_params, ref_params,
-                               cm->features.allow_high_precision_mv))) {
+                               cm->features.allow_high_precision_mv),
+            erroradv_tr[gm_erroradv_tr_level])) {
       continue;
     }

View File

@@ -642,7 +642,7 @@ static int construct_multi_layer_gf_structure(
                                       : gf_group->is_sframe_due ? S_FRAME
                                                                 : INTER_FRAME;
     gf_group->is_sframe_due =
-        sframe_enabled && !(gf_group->frame_type[frame_index] == S_FRAME);
+        sframe_enabled && gf_group->frame_type[frame_index] != S_FRAME;
     gf_group->refbuf_state[frame_index] = REFBUF_UPDATE;
     gf_group->max_layer_depth = 1;
     gf_group->arf_index = frame_index;

View File

@@ -102,21 +102,18 @@ void av1_make_default_fullpel_ms_params(
   ms_params->mv_limits = x->mv_limits;
   av1_set_mv_search_range(&ms_params->mv_limits, ref_mv);

-  if (cpi->oxcf.algo_cfg.sharpness) {
+  if (cpi->oxcf.algo_cfg.sharpness == 3) {
     int top_margin = x->e_mbd.mi_row * MI_SIZE + 8;
     int left_margin = x->e_mbd.mi_col * MI_SIZE + 8;
-    int bottom_margin = cpi->common.cur_frame->height -
-                        mi_size_high[bsize] * MI_SIZE - top_margin + 16;
-    int right_margin = cpi->common.cur_frame->width -
-                       mi_size_wide[bsize] * MI_SIZE - left_margin + 16;
-    if (ms_params->mv_limits.row_min < -top_margin)
-      ms_params->mv_limits.row_min = -top_margin;
-    if (ms_params->mv_limits.row_max > bottom_margin)
-      ms_params->mv_limits.row_max = bottom_margin;
-    if (ms_params->mv_limits.col_min < -left_margin)
-      ms_params->mv_limits.col_min = -left_margin;
-    if (ms_params->mv_limits.col_max > right_margin)
-      ms_params->mv_limits.col_max = right_margin;
+    int bottom_margin =
+        cpi->common.height - mi_size_high[bsize] * MI_SIZE - top_margin + 16;
+    int right_margin =
+        cpi->common.width - mi_size_wide[bsize] * MI_SIZE - left_margin + 16;
+    FullMvLimits *mv_limits = &ms_params->mv_limits;
+    mv_limits->row_min = AOMMAX(mv_limits->row_min, -top_margin);
+    mv_limits->row_max = AOMMIN(mv_limits->row_max, bottom_margin);
+    mv_limits->col_min = AOMMAX(mv_limits->col_min, -left_margin);
+    mv_limits->col_max = AOMMIN(mv_limits->col_max, right_margin);
   }

   // Mvcost params
@@ -193,6 +190,22 @@ void av1_make_default_subpel_ms_params(SUBPEL_MOTION_SEARCH_PARAMS *ms_params,
   av1_set_subpel_mv_search_range(&ms_params->mv_limits, &x->mv_limits, ref_mv);

+  if (cpi->oxcf.algo_cfg.sharpness == 3) {
+    int top_margin = GET_MV_SUBPEL(x->e_mbd.mi_row * MI_SIZE + 8);
+    int left_margin = GET_MV_SUBPEL(x->e_mbd.mi_col * MI_SIZE + 8);
+    int bottom_margin =
+        GET_MV_SUBPEL(cpi->common.height - mi_size_high[bsize] * MI_SIZE -
+                      x->e_mbd.mi_row * MI_SIZE + 8);
+    int right_margin =
+        GET_MV_SUBPEL(cpi->common.width - mi_size_wide[bsize] * MI_SIZE -
+                      x->e_mbd.mi_col * MI_SIZE + 8);
+    SubpelMvLimits *mv_limits = &ms_params->mv_limits;
+    mv_limits->row_min = AOMMAX(mv_limits->row_min, -top_margin);
+    mv_limits->row_max = AOMMIN(mv_limits->row_max, bottom_margin);
+    mv_limits->col_min = AOMMAX(mv_limits->col_min, -left_margin);
+    mv_limits->col_max = AOMMIN(mv_limits->col_max, right_margin);
+  }
+
   // Mvcost params
   init_mv_cost_params(&ms_params->mv_cost_params, x->mv_costs, ref_mv,
                       x->errorperbit, x->sadperbit);
@@ -230,10 +243,10 @@ void av1_set_mv_search_range(FullMvLimits *mv_limits, const MV *mv) {
   // Get intersection of UMV window and valid MV window to reduce # of checks
   // in diamond search.
-  if (mv_limits->col_min < col_min) mv_limits->col_min = col_min;
-  if (mv_limits->col_max > col_max) mv_limits->col_max = col_max;
-  if (mv_limits->row_min < row_min) mv_limits->row_min = row_min;
-  if (mv_limits->row_max > row_max) mv_limits->row_max = row_max;
+  mv_limits->col_min = AOMMAX(mv_limits->col_min, col_min);
+  mv_limits->col_max = AOMMIN(mv_limits->col_max, col_max);
+  mv_limits->row_min = AOMMAX(mv_limits->row_min, row_min);
+  mv_limits->row_max = AOMMIN(mv_limits->row_max, row_max);
   mv_limits->col_max = AOMMAX(mv_limits->col_min, mv_limits->col_max);
   mv_limits->row_max = AOMMAX(mv_limits->row_min, mv_limits->row_max);

View File

@@ -3408,9 +3408,6 @@ static void find_next_key_frame(AV1_COMP *cpi, FIRSTPASS_STATS *this_frame) {
   kf_bits = calculate_boost_bits(
       AOMMIN(rc->frames_to_key, frames_to_key_clipped) - 1, p_rc->kf_boost,
       AOMMIN(twopass->kf_group_bits, kf_group_bits_clipped));
-  // printf("kf boost = %d kf_bits = %d kf_zeromotion_pct = %d\n",
-  //        p_rc->kf_boost,
-  //        kf_bits, twopass->kf_zeromotion_pct);
   kf_bits = adjust_boost_bits_for_target_level(cpi, rc, kf_bits,
                                                twopass->kf_group_bits, 0);

View File

@@ -1507,7 +1507,6 @@ static int64_t finer_search_wiener(const RestSearchCtxt *rsc,
   WienerInfo *plane_wiener = &rui->wiener_info;

-  // printf("err pre = %"PRId64"\n", err);
   const int start_step = 4;
   for (int s = start_step; s >= 1; s >>= 1) {
     for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
@@ -1593,7 +1592,6 @@ static int64_t finer_search_wiener(const RestSearchCtxt *rsc,
       } while (1);
     }
   }
-  // printf("err post = %"PRId64"\n", err);
   return err;
 }
@@ -2052,6 +2050,8 @@ void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi) {
     min_lr_unit_size =
         AOMMAX(min_lr_unit_size, block_size_wide[cm->seq_params->sb_size]);
+    max_lr_unit_size = AOMMAX(min_lr_unit_size, max_lr_unit_size);
   for (int plane = 0; plane < num_planes; ++plane) {
     cpi->pick_lr_ctxt.rusi[plane] = allocate_search_structs(
         cm, &cm->rst_info[plane], plane > 0, min_lr_unit_size);

View File

@@ -1448,8 +1448,6 @@ static int get_active_cq_level(const RATE_CONTROL *rc,
   static const double cq_adjust_threshold = 0.1;
   int active_cq_level = rc_cfg->cq_level;
   if (rc_cfg->mode == AOM_CQ || rc_cfg->mode == AOM_Q) {
-    // printf("Superres %d %d %d = %d\n", superres_denom, intra_only,
-    //        rc->frames_to_key, !(intra_only && rc->frames_to_key <= 1));
     if ((superres_mode == AOM_SUPERRES_QTHRESH ||
          superres_mode == AOM_SUPERRES_AUTO) &&
         superres_denom != SCALE_NUMERATOR) {
@@ -2497,6 +2495,7 @@ void av1_rc_postencode_update_drop_frame(AV1_COMP *cpi) {
   if (cpi->svc.spatial_layer_id == cpi->svc.number_spatial_layers - 1) {
     cpi->svc.prev_number_spatial_layers = cpi->svc.number_spatial_layers;
   }
+  cpi->svc.prev_number_temporal_layers = cpi->svc.number_temporal_layers;
 }
int av1_find_qindex(double desired_q, aom_bit_depth_t bit_depth, int av1_find_qindex(double desired_q, aom_bit_depth_t bit_depth,

View File

@@ -649,10 +649,6 @@ void av1_fill_coeff_costs(CoeffCosts *coeff_costs, FRAME_CONTEXT *fc,
av1_cost_tokens_from_cdf( av1_cost_tokens_from_cdf(
br_rate, fc->coeff_br_cdf[AOMMIN(tx_size, TX_32X32)][plane][ctx], br_rate, fc->coeff_br_cdf[AOMMIN(tx_size, TX_32X32)][plane][ctx],
NULL); NULL);
// printf("br_rate: ");
// for(j = 0; j < BR_CDF_SIZE; j++)
// printf("%4d ", br_rate[j]);
// printf("\n");
for (i = 0; i < COEFF_BASE_RANGE; i += BR_CDF_SIZE - 1) { for (i = 0; i < COEFF_BASE_RANGE; i += BR_CDF_SIZE - 1) {
for (j = 0; j < BR_CDF_SIZE - 1; j++) { for (j = 0; j < BR_CDF_SIZE - 1; j++) {
pcost->lps_cost[ctx][i + j] = prev_cost + br_rate[j]; pcost->lps_cost[ctx][i + j] = prev_cost + br_rate[j];
@@ -660,10 +656,6 @@ void av1_fill_coeff_costs(CoeffCosts *coeff_costs, FRAME_CONTEXT *fc,
prev_cost += br_rate[j]; prev_cost += br_rate[j];
} }
pcost->lps_cost[ctx][i] = prev_cost; pcost->lps_cost[ctx][i] = prev_cost;
// printf("lps_cost: %d %d %2d : ", tx_size, plane, ctx);
// for (i = 0; i <= COEFF_BASE_RANGE; i++)
// printf("%5d ", pcost->lps_cost[ctx][i]);
// printf("\n");
} }
for (int ctx = 0; ctx < LEVEL_CONTEXTS; ++ctx) { for (int ctx = 0; ctx < LEVEL_CONTEXTS; ++ctx) {
pcost->lps_cost[ctx][0 + COEFF_BASE_RANGE + 1] = pcost->lps_cost[ctx][0 + COEFF_BASE_RANGE + 1] =

View File

@@ -607,6 +607,77 @@ void av1_get_horver_correlation_full_c(const int16_t *diff, int stride,
} }
} }
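// Computes the summed variance of the source and reconstructed blocks,
// measured against a constant 128 reference, scaled up by 16.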
static void get_variance_stats(const AV1_COMP *cpi, const MACROBLOCK *x,
int num_planes, int64_t *src_var,
int64_t *rec_var) {
const MACROBLOCKD *xd = &x->e_mbd;
const MB_MODE_INFO *mbmi = xd->mi[0];
DECLARE_ALIGNED(16, uint8_t, dclevel[MAX_SB_SQUARE]);
memset(dclevel, 128, sizeof(dclevel));
int dclevel_stride = block_size_wide[mbmi->bsize];
*src_var = 0;
*rec_var = 0;
for (int plane = 0; plane < num_planes; ++plane) {
if (plane && !xd->is_chroma_ref) break;
const struct macroblock_plane *const p = &x->plane[plane];
const struct macroblockd_plane *const pd = &xd->plane[plane];
const BLOCK_SIZE bs =
get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y);
unsigned int sse;
int64_t var = cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, dclevel,
dclevel_stride, &sse);
*src_var += var;
var = cpi->ppi->fn_ptr[bs].vf(pd->dst.buf, pd->dst.stride, dclevel,
dclevel_stride, &sse);
*rec_var += var;
}
*src_var <<= 4;
*rec_var <<= 4;
}
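// With sharpness level 3 on non-kf/gf/arf frames: if the reconstruction has
// lower variance than the source, add the shortfall to the distortion and
// recompute the RD cost.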
static void adjust_rdcost(const AV1_COMP *cpi, const MACROBLOCK *x,
RD_STATS *rd_cost) {
if (cpi->oxcf.algo_cfg.sharpness != 3) return;
if (frame_is_kf_gf_arf(cpi)) return;
int64_t src_var, rec_var;
get_variance_stats(cpi, x, 1, &src_var, &rec_var);
if (src_var <= rec_var) return;
int64_t var_offset = src_var - rec_var;
rd_cost->dist += var_offset;
rd_cost->rdcost = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
}
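// Same adjustment as adjust_rdcost(), applied to a standalone RD cost value.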
static void adjust_cost(const AV1_COMP *cpi, const MACROBLOCK *x,
int64_t *rd_cost) {
if (cpi->oxcf.algo_cfg.sharpness != 3) return;
if (frame_is_kf_gf_arf(cpi)) return;
int64_t src_var, rec_var;
get_variance_stats(cpi, x, 1, &src_var, &rec_var);
if (src_var <= rec_var) return;
int64_t var_offset = src_var - rec_var;
*rd_cost += RDCOST(x->rdmult, 0, var_offset);
}
static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x, static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x,
int64_t *sse_y) { int64_t *sse_y) {
const AV1_COMMON *cm = &cpi->common; const AV1_COMMON *cm = &cpi->common;
@@ -5501,6 +5572,11 @@ static inline void search_intra_modes_in_interframe(
intra_search_state, cpi, x, bsize, intra_ref_frame_cost, ctx, intra_search_state, cpi, x, bsize, intra_ref_frame_cost, ctx,
&intra_rd_stats_y, search_state->best_rd, &mode_cost_y, &intra_rd_y, &intra_rd_stats_y, search_state->best_rd, &mode_cost_y, &intra_rd_y,
&best_model_rd, top_intra_model_rd); &best_model_rd, top_intra_model_rd);
if (intra_rd_y < INT64_MAX) {
adjust_cost(cpi, x, &intra_rd_y);
}
if (is_luma_result_valid && intra_rd_y < yrd_threshold) { if (is_luma_result_valid && intra_rd_y < yrd_threshold) {
is_best_y_mode_intra = 1; is_best_y_mode_intra = 1;
if (intra_rd_y < best_rd_y) { if (intra_rd_y < best_rd_y) {
@@ -5586,6 +5662,8 @@ static inline void search_intra_modes_in_interframe(
intra_rd_stats.rdcost = this_rd; intra_rd_stats.rdcost = this_rd;
adjust_rdcost(cpi, x, &intra_rd_stats);
// Collect mode stats for multiwinner mode processing // Collect mode stats for multiwinner mode processing
const int txfm_search_done = 1; const int txfm_search_done = 1;
store_winner_mode_stats( store_winner_mode_stats(
@@ -6019,6 +6097,12 @@ void av1_rd_pick_inter_mode(struct AV1_COMP *cpi, struct TileDataEnc *tile_data,
args.best_pred_sse = search_state.best_pred_sse; args.best_pred_sse = search_state.best_pred_sse;
args.skip_ifs = skip_interp_filter_search(cpi, is_single_pred); args.skip_ifs = skip_interp_filter_search(cpi, is_single_pred);
if (!frame_is_kf_gf_arf(cpi) && cpi->oxcf.algo_cfg.sharpness == 3) {
if (ref_frame != ALTREF_FRAME && ref_frame != GOLDEN_FRAME &&
ref_frame != INTRA_FRAME)
continue;
}
int64_t skip_rd[2] = { search_state.best_skip_rd[0], int64_t skip_rd[2] = { search_state.best_skip_rd[0],
search_state.best_skip_rd[1] }; search_state.best_skip_rd[1] };
int64_t this_yrd = INT64_MAX; int64_t this_yrd = INT64_MAX;
@@ -6057,6 +6141,9 @@ void av1_rd_pick_inter_mode(struct AV1_COMP *cpi, struct TileDataEnc *tile_data,
ref_frame_rd[ref_frame] = this_rd; ref_frame_rd[ref_frame] = this_rd;
} }
adjust_cost(cpi, x, &this_rd);
adjust_rdcost(cpi, x, &rd_stats);
// Did this mode help, i.e., is it the new best mode // Did this mode help, i.e., is it the new best mode
if (this_rd < search_state.best_rd) { if (this_rd < search_state.best_rd) {
assert(IMPLIES(comp_pred, assert(IMPLIES(comp_pred,

View File

@@ -614,6 +614,10 @@ static void set_good_speed_features_lc_dec_framesize_dependent(
(update_type == LF_UPDATE || update_type == OVERLAY_UPDATE || (update_type == LF_UPDATE || update_type == OVERLAY_UPDATE ||
update_type == INTNL_OVERLAY_UPDATE); update_type == INTNL_OVERLAY_UPDATE);
if (leaf_and_overlay_frames) sf->gm_sf.gm_search_type = GM_DISABLE_SEARCH; if (leaf_and_overlay_frames) sf->gm_sf.gm_search_type = GM_DISABLE_SEARCH;
sf->hl_sf.disable_ref_frame_mvs = 1;
} else if (is_608p_or_larger) {
sf->gm_sf.gm_erroradv_tr_level = 1;
} }
} }
@@ -2028,6 +2032,7 @@ static inline void init_hl_sf(HIGH_LEVEL_SPEED_FEATURES *hl_sf) {
hl_sf->accurate_bit_estimate = 0; hl_sf->accurate_bit_estimate = 0;
hl_sf->weight_calc_level_in_tf = 0; hl_sf->weight_calc_level_in_tf = 0;
hl_sf->allow_sub_blk_me_in_tf = 0; hl_sf->allow_sub_blk_me_in_tf = 0;
hl_sf->disable_ref_frame_mvs = 0;
} }
static inline void init_fp_sf(FIRST_PASS_SPEED_FEATURES *fp_sf) { static inline void init_fp_sf(FIRST_PASS_SPEED_FEATURES *fp_sf) {
@@ -2059,6 +2064,7 @@ static inline void init_gm_sf(GLOBAL_MOTION_SPEED_FEATURES *gm_sf) {
gm_sf->disable_gm_search_based_on_stats = 0; gm_sf->disable_gm_search_based_on_stats = 0;
gm_sf->downsample_level = 0; gm_sf->downsample_level = 0;
gm_sf->num_refinement_steps = GM_MAX_REFINEMENT_STEPS; gm_sf->num_refinement_steps = GM_MAX_REFINEMENT_STEPS;
gm_sf->gm_erroradv_tr_level = 0;
} }
static inline void init_part_sf(PARTITION_SPEED_FEATURES *part_sf) { static inline void init_part_sf(PARTITION_SPEED_FEATURES *part_sf) {
@@ -2613,6 +2619,29 @@ void av1_set_speed_features_framesize_independent(AV1_COMP *cpi, int speed) {
sf->rt_sf.gf_refresh_based_on_qp = 0; sf->rt_sf.gf_refresh_based_on_qp = 0;
} }
// Override some speed features for low complexity decode based on qindex.
static void set_speed_features_lc_dec_qindex_dependent(
const AV1_COMP *const cpi, SPEED_FEATURES *const sf, int speed) {
if (speed < 1 || speed > 3) return;
const AV1_COMMON *const cm = &cpi->common;
const int short_dimension = AOMMIN(cm->width, cm->height);
const int is_720p_or_larger = AOMMIN(cm->width, cm->height) >= 720;
const FRAME_UPDATE_TYPE update_type =
get_frame_update_type(&cpi->ppi->gf_group, cpi->gf_frame_index);
const int leaf_and_overlay_frames =
(update_type == LF_UPDATE || update_type == OVERLAY_UPDATE ||
update_type == INTNL_OVERLAY_UPDATE);
if (short_dimension > 480 && short_dimension < 720) {
sf->lpf_sf.min_lr_unit_size = RESTORATION_UNITSIZE_MAX >> 1;
sf->lpf_sf.max_lr_unit_size = RESTORATION_UNITSIZE_MAX >> 1;
} else if (is_720p_or_larger && speed <= 2 && leaf_and_overlay_frames) {
sf->lpf_sf.min_lr_unit_size = RESTORATION_UNITSIZE_MAX >> 1;
sf->lpf_sf.max_lr_unit_size = RESTORATION_UNITSIZE_MAX >> 1;
}
}
// Override some speed features based on qindex // Override some speed features based on qindex
void av1_set_speed_features_qindex_dependent(AV1_COMP *cpi, int speed) { void av1_set_speed_features_qindex_dependent(AV1_COMP *cpi, int speed) {
AV1_COMMON *const cm = &cpi->common; AV1_COMMON *const cm = &cpi->common;
@@ -2815,4 +2844,7 @@ void av1_set_speed_features_qindex_dependent(AV1_COMP *cpi, int speed) {
set_subpel_search_method(&cpi->mv_search_params, set_subpel_search_method(&cpi->mv_search_params,
cpi->oxcf.unit_test_cfg.motion_vector_unit_test, cpi->oxcf.unit_test_cfg.motion_vector_unit_test,
sf->mv_sf.subpel_search_method); sf->mv_sf.subpel_search_method);
if (cpi->oxcf.enable_low_complexity_decode)
set_speed_features_lc_dec_qindex_dependent(cpi, sf, speed);
} }

View File

@@ -482,6 +482,11 @@ typedef struct HIGH_LEVEL_SPEED_FEATURES {
* 1: Conditionally allow motion estimation based on 4x4 sub-blocks variance. * 1: Conditionally allow motion estimation based on 4x4 sub-blocks variance.
*/ */
int allow_sub_blk_me_in_tf; int allow_sub_blk_me_in_tf;
/*!
* When set to 1, disables temporal MV prediction (ref frame MVs).
*/
int disable_ref_frame_mvs;
} HIGH_LEVEL_SPEED_FEATURES; } HIGH_LEVEL_SPEED_FEATURES;
/*! /*!
@@ -592,6 +597,10 @@ typedef struct GLOBAL_MOTION_SPEED_FEATURES {
// Number of refinement steps to apply after initial model generation // Number of refinement steps to apply after initial model generation
int num_refinement_steps; int num_refinement_steps;
// Error advantage threshold level used to determine whether global motion
// compensation should be enabled
int gm_erroradv_tr_level;
} GLOBAL_MOTION_SPEED_FEATURES; } GLOBAL_MOTION_SPEED_FEATURES;
typedef struct PARTITION_SPEED_FEATURES { typedef struct PARTITION_SPEED_FEATURES {

View File

@@ -93,6 +93,7 @@ typedef struct SVC {
int number_spatial_layers; int number_spatial_layers;
int number_temporal_layers; int number_temporal_layers;
int prev_number_spatial_layers; int prev_number_spatial_layers;
int prev_number_temporal_layers;
int use_flexible_mode; int use_flexible_mode;
int ksvc_fixed_mode; int ksvc_fixed_mode;
/*!\endcond */ /*!\endcond */

View File

@@ -184,6 +184,10 @@ static void tf_motion_search(AV1_COMP *cpi, MACROBLOCK *mb,
mbd->plane[0].pre[0].buf = ref_frame->y_buffer + y_offset; mbd->plane[0].pre[0].buf = ref_frame->y_buffer + y_offset;
mbd->plane[0].pre[0].stride = y_stride; mbd->plane[0].pre[0].stride = y_stride;
mbd->plane[0].pre[0].width = ref_width; mbd->plane[0].pre[0].width = ref_width;
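// Record the block position in 4x4 (mi) units.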
mbd->mi_row =
mb_row * (block_size_high[block_size] / block_size_high[BLOCK_4X4]);
mbd->mi_col =
mb_col * (block_size_wide[block_size] / block_size_wide[BLOCK_4X4]);
*is_dc_diff_large = 0; *is_dc_diff_large = 0;
const SEARCH_METHODS search_method = NSTEP; const SEARCH_METHODS search_method = NSTEP;

View File

@@ -182,6 +182,8 @@ set_aom_config_var(CONFIG_CWG_E050 0
set_aom_config_var(CONFIG_LIBVMAF_PSNR_PEAK 1 set_aom_config_var(CONFIG_LIBVMAF_PSNR_PEAK 1
"Use libvmaf PSNR peak for 10- and 12-bit") "Use libvmaf PSNR peak for 10- and 12-bit")
set_aom_config_var(CONFIG_HIGHWAY 0 "Use Highway for SIMD.")
# #
# Variables in this section control optional features of the build system. # Variables in this section control optional features of the build system.
# #
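As a hypothetical illustration (the actual gating sites are not part of this diff), a flag registered this way becomes a 0/1 macro in the generated config headers and is typically consumed like the sketch below; use_highway_simd is a made-up name.

#include "config/aom_config.h"

// Illustration only: reports whether Highway-backed SIMD paths were built in.
static int use_highway_simd(void) {
#if CONFIG_HIGHWAY
  return 1;
#else
  return 0;  // CONFIG_HIGHWAY defaults to 0 in this update.
#endif
}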

View File

@@ -40,7 +40,7 @@ class ACMRandom {
int16_t Rand16Signed() { return static_cast<int16_t>(Rand16()); } int16_t Rand16Signed() { return static_cast<int16_t>(Rand16()); }
int16_t Rand15() { uint16_t Rand15() {
const uint32_t value = const uint32_t value =
random_.Generate(testing::internal::Random::kMaxRange); random_.Generate(testing::internal::Random::kMaxRange);
// There's a bit more entropy in the upper bits of this implementation. // There's a bit more entropy in the upper bits of this implementation.

View File

@@ -246,7 +246,6 @@ void AV1FwdTxfm2dMatchTest(TX_SIZE tx_size, lowbd_fwd_txfm_func target_func) {
memset(&param, 0, sizeof(param)); memset(&param, 0, sizeof(param));
const int rows = tx_size_high[tx_size]; const int rows = tx_size_high[tx_size];
const int cols = tx_size_wide[tx_size]; const int cols = tx_size_wide[tx_size];
// printf("%d x %d\n", cols, rows);
for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) { for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
if (libaom_test::IsTxSizeTypeValid( if (libaom_test::IsTxSizeTypeValid(
tx_size, static_cast<TX_TYPE>(tx_type)) == false) { tx_size, static_cast<TX_TYPE>(tx_type)) == false) {

View File

@@ -175,7 +175,7 @@ class CFLTestWithAlignedData : public CFLTest {
typedef cfl_subtract_average_fn (*sub_avg_fn)(TX_SIZE tx_size); typedef cfl_subtract_average_fn (*sub_avg_fn)(TX_SIZE tx_size);
typedef std::tuple<TX_SIZE, sub_avg_fn> sub_avg_param; typedef std::tuple<TX_SIZE, sub_avg_fn> sub_avg_param;
class CFLSubAvgTest : public ::testing::TestWithParam<sub_avg_param>, class CFLSubAvgTest : public ::testing::TestWithParam<sub_avg_param>,
public CFLTestWithData<int16_t> { public CFLTestWithData<uint16_t> {
public: public:
void SetUp() override { void SetUp() override {
CFLTest::init(std::get<0>(this->GetParam())); CFLTest::init(std::get<0>(this->GetParam()));
@@ -191,27 +191,31 @@ class CFLSubAvgTest : public ::testing::TestWithParam<sub_avg_param>,
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CFLSubAvgTest); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CFLSubAvgTest);
TEST_P(CFLSubAvgTest, SubAvgTest) { TEST_P(CFLSubAvgTest, SubAvgTest) {
int16_t dst[CFL_BUF_SQUARE];
int16_t dst_ref[CFL_BUF_SQUARE];
for (int it = 0; it < NUM_ITERATIONS; it++) { for (int it = 0; it < NUM_ITERATIONS; it++) {
randData(&ACMRandom::Rand15); randData(&ACMRandom::Rand15);
sub_avg((uint16_t *)data, data); sub_avg(data, dst);
sub_avg_ref((uint16_t *)data_ref, data_ref); sub_avg_ref(data_ref, dst_ref);
assert_eq<int16_t>(data, data_ref, width, height); assert_eq<int16_t>(dst, dst_ref, width, height);
} }
} }
TEST_P(CFLSubAvgTest, DISABLED_SubAvgSpeedTest) { TEST_P(CFLSubAvgTest, DISABLED_SubAvgSpeedTest) {
int16_t dst[CFL_BUF_SQUARE];
int16_t dst_ref[CFL_BUF_SQUARE];
aom_usec_timer ref_timer; aom_usec_timer ref_timer;
aom_usec_timer timer; aom_usec_timer timer;
randData(&ACMRandom::Rand15); randData(&ACMRandom::Rand15);
aom_usec_timer_start(&ref_timer); aom_usec_timer_start(&ref_timer);
for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) {
sub_avg_ref((uint16_t *)data_ref, data_ref); sub_avg_ref(data_ref, dst_ref);
} }
aom_usec_timer_mark(&ref_timer); aom_usec_timer_mark(&ref_timer);
int ref_elapsed_time = (int)aom_usec_timer_elapsed(&ref_timer); int ref_elapsed_time = (int)aom_usec_timer_elapsed(&ref_timer);
aom_usec_timer_start(&timer); aom_usec_timer_start(&timer);
for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) {
sub_avg((uint16_t *)data, data); sub_avg(data, dst);
} }
aom_usec_timer_mark(&timer); aom_usec_timer_mark(&timer);
int elapsed_time = (int)aom_usec_timer_elapsed(&timer); int elapsed_time = (int)aom_usec_timer_elapsed(&timer);
@@ -261,13 +265,13 @@ class CFLSubsampleTest : public ::testing::TestWithParam<S>,
CFLTestWithData<I>::randData(random); CFLTestWithData<I>::randData(random);
aom_usec_timer_start(&ref_timer); aom_usec_timer_start(&ref_timer);
for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) {
fun_ref(this->data_ref, CFL_BUF_LINE, sub_luma_pels); fun_ref(this->data_ref, CFL_BUF_LINE, sub_luma_pels_ref);
} }
aom_usec_timer_mark(&ref_timer); aom_usec_timer_mark(&ref_timer);
int ref_elapsed_time = (int)aom_usec_timer_elapsed(&ref_timer); int ref_elapsed_time = (int)aom_usec_timer_elapsed(&ref_timer);
aom_usec_timer_start(&timer); aom_usec_timer_start(&timer);
for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) {
fun(this->data, CFL_BUF_LINE, sub_luma_pels_ref); fun(this->data, CFL_BUF_LINE, sub_luma_pels);
} }
aom_usec_timer_mark(&timer); aom_usec_timer_mark(&timer);
int elapsed_time = (int)aom_usec_timer_elapsed(&timer); int elapsed_time = (int)aom_usec_timer_elapsed(&timer);

View File

@@ -98,15 +98,21 @@ class RcInterfaceTest : public ::libaom_test::EncoderTest,
// Go down to 2 temporal layers. // Go down to 2 temporal layers.
SetConfigSvc(3, 2); SetConfigSvc(3, 2);
encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_);
frame_flags_ = AOM_EFLAG_FORCE_KF;
frame_params_.frame_type = aom::kKeyFrame;
ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_)); ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_));
} else if (superframe_cnt_ == 200 && layer_id_.spatial_layer_id == 0) { } else if (superframe_cnt_ == 200 && layer_id_.spatial_layer_id == 0) {
// Go down to 1 temporal layer. // Go down to 1 temporal layer.
SetConfigSvc(3, 1); SetConfigSvc(3, 1);
encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_);
frame_flags_ = AOM_EFLAG_FORCE_KF;
frame_params_.frame_type = aom::kKeyFrame;
ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_)); ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_));
} else if (superframe_cnt_ == 300 && layer_id_.spatial_layer_id == 0) { } else if (superframe_cnt_ == 300 && layer_id_.spatial_layer_id == 0) {
// Go back up to 3 temporal layers. // Go back up to 3 temporal layers.
SetConfigSvc(3, 3); SetConfigSvc(3, 3);
frame_flags_ = AOM_EFLAG_FORCE_KF;
frame_params_.frame_type = aom::kKeyFrame;
encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_);
ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_)); ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_));
} }
@@ -117,11 +123,15 @@ class RcInterfaceTest : public ::libaom_test::EncoderTest,
// Change to 2 spatial layers (240p, 480p). // Change to 2 spatial layers (240p, 480p).
SetConfigSvc(2, 3); SetConfigSvc(2, 3);
encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_);
frame_flags_ = AOM_EFLAG_FORCE_KF;
frame_params_.frame_type = aom::kKeyFrame;
ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_)); ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_));
} else if (superframe_cnt_ == 200 && layer_id_.spatial_layer_id == 0) { } else if (superframe_cnt_ == 200 && layer_id_.spatial_layer_id == 0) {
// Change to 1 spatial layer (480p). // Change to 1 spatial layer (480p).
SetConfigSvc(1, 3); SetConfigSvc(1, 3);
encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_);
frame_flags_ = AOM_EFLAG_FORCE_KF;
frame_params_.frame_type = aom::kKeyFrame;
ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_)); ASSERT_TRUE(rc_api_->UpdateRateControl(rc_cfg_));
} else if (superframe_cnt_ == 300 && layer_id_.spatial_layer_id == 0) { } else if (superframe_cnt_ == 300 && layer_id_.spatial_layer_id == 0) {
// Go back to 3 spatial layers (120p, 240p, 480p). // Go back to 3 spatial layers (120p, 240p, 480p).
@@ -148,6 +158,10 @@ class RcInterfaceTest : public ::libaom_test::EncoderTest,
if (encoder_exit_) { if (encoder_exit_) {
return; return;
} }
int num_operating_points;
encoder->Control(AV1E_GET_NUM_OPERATING_POINTS, &num_operating_points);
ASSERT_EQ(num_operating_points,
rc_cfg_.ss_number_layers * rc_cfg_.ts_number_layers);
layer_frame_cnt_++; layer_frame_cnt_++;
frame_cnt_++; frame_cnt_++;
if (layer_id_.spatial_layer_id == rc_cfg_.ss_number_layers - 1) if (layer_id_.spatial_layer_id == rc_cfg_.ss_number_layers - 1)

View File

@@ -79,9 +79,12 @@ void ScaleForFrameNumber(unsigned int frame, unsigned int initial_w,
class ResizingVideoSource : public ::libaom_test::DummyVideoSource { class ResizingVideoSource : public ::libaom_test::DummyVideoSource {
public: public:
explicit ResizingVideoSource(int external_resize_pattern) { explicit ResizingVideoSource(int external_resize_pattern, int width,
int height) {
external_resize_pattern_ = external_resize_pattern; external_resize_pattern_ = external_resize_pattern;
SetSize(1280, 720); top_width_ = width;
top_height_ = height;
SetSize(top_width_, top_height_);
limit_ = 300; limit_ = 300;
} }
~ResizingVideoSource() override = default; ~ResizingVideoSource() override = default;
@@ -92,7 +95,7 @@ class ResizingVideoSource : public ::libaom_test::DummyVideoSource {
unsigned int width = 0; unsigned int width = 0;
unsigned int height = 0; unsigned int height = 0;
libaom_test::ACMRandom rnd(libaom_test::ACMRandom::DeterministicSeed()); libaom_test::ACMRandom rnd(libaom_test::ACMRandom::DeterministicSeed());
ScaleForFrameNumber(frame_, 1280, 720, &width, &height, ScaleForFrameNumber(frame_, top_width_, top_height_, &width, &height,
external_resize_pattern_); external_resize_pattern_);
SetSize(width, height); SetSize(width, height);
FillFrame(); FillFrame();
@@ -104,6 +107,9 @@ class ResizingVideoSource : public ::libaom_test::DummyVideoSource {
private: private:
int external_resize_pattern_; int external_resize_pattern_;
// top_width_/top_height_ hold the resolution configured when the codec is created.
int top_width_;
int top_height_;
}; };
class DatarateTestSVC class DatarateTestSVC
@@ -172,6 +178,7 @@ class DatarateTestSVC
use_last_as_scaled_single_ref_ = false; use_last_as_scaled_single_ref_ = false;
external_resize_dynamic_drop_layer_ = false; external_resize_dynamic_drop_layer_ = false;
external_resize_pattern_ = 0; external_resize_pattern_ = 0;
dynamic_tl_ = false;
} }
void PreEncodeFrameHook(::libaom_test::VideoSource *video, void PreEncodeFrameHook(::libaom_test::VideoSource *video,
@@ -309,9 +316,6 @@ class DatarateTestSVC
} }
if (layer_id_.spatial_layer_id == 0 && if (layer_id_.spatial_layer_id == 0 &&
(video->frame() == 1 || video->frame() == 150)) { (video->frame() == 1 || video->frame() == 150)) {
// Set the new top width/height for external resize.
top_sl_width_ = video->img()->d_w;
top_sl_height_ = video->img()->d_h;
for (int i = 0; i < 9; ++i) { for (int i = 0; i < 9; ++i) {
bitrate_layer_[i] = svc_params_.layer_target_bitrate[i]; bitrate_layer_[i] = svc_params_.layer_target_bitrate[i];
} }
@@ -345,8 +349,6 @@ class DatarateTestSVC
encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_);
} else if (layer_id_.spatial_layer_id == 0 && } else if (layer_id_.spatial_layer_id == 0 &&
(video->frame() == 50 || video->frame() == 200)) { (video->frame() == 50 || video->frame() == 200)) {
top_sl_width_ = video->img()->d_w;
top_sl_height_ = video->img()->d_h;
if (external_resize_pattern_ == 1) { if (external_resize_pattern_ == 1) {
// Input size is 1/2. Change layer bitrates to set top layer to 0. // Input size is 1/2. Change layer bitrates to set top layer to 0.
// This will trigger skip encoding/dropping of top spatial layer. // This will trigger skip encoding/dropping of top spatial layer.
@@ -377,8 +379,6 @@ class DatarateTestSVC
encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_);
} else if (layer_id_.spatial_layer_id == 0 && } else if (layer_id_.spatial_layer_id == 0 &&
(video->frame() == 100 || video->frame() == 250)) { (video->frame() == 100 || video->frame() == 250)) {
top_sl_width_ = video->img()->d_w;
top_sl_height_ = video->img()->d_h;
// Input is original size. Change layer bitrates to nonzero for all // Input is original size. Change layer bitrates to nonzero for all
// layers. // layers.
cfg_.rc_target_bitrate = cfg_.rc_target_bitrate =
@@ -395,6 +395,26 @@ class DatarateTestSVC
encoder->Config(&cfg_); encoder->Config(&cfg_);
encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_); encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_);
} }
} else if (dynamic_tl_) {
if (video->frame() == 100) {
// Enable 3 temporal layers.
svc_params_.number_temporal_layers = 3;
number_temporal_layers_ = 3;
svc_params_.layer_target_bitrate[0] = 60 * cfg_.rc_target_bitrate / 100;
svc_params_.layer_target_bitrate[1] = 80 * cfg_.rc_target_bitrate / 100;
svc_params_.layer_target_bitrate[2] = cfg_.rc_target_bitrate;
svc_params_.framerate_factor[0] = 4;
svc_params_.framerate_factor[1] = 2;
svc_params_.framerate_factor[2] = 1;
encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_);
} else if (video->frame() == 200) {
// Go back to 1 temporal layer.
svc_params_.number_temporal_layers = 1;
number_temporal_layers_ = 1;
svc_params_.layer_target_bitrate[0] = cfg_.rc_target_bitrate;
svc_params_.framerate_factor[0] = 1;
encoder->Control(AV1E_SET_SVC_PARAMS, &svc_params_);
}
} }
layer_frame_cnt_++; layer_frame_cnt_++;
DatarateTest::PreEncodeFrameHook(video, encoder); DatarateTest::PreEncodeFrameHook(video, encoder);
@@ -2853,9 +2873,47 @@ class DatarateTestSVC
cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)]; cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)];
cfg_.g_w = 1280; cfg_.g_w = 1280;
cfg_.g_h = 720; cfg_.g_h = 720;
top_sl_width_ = 1280; ResizingVideoSource video(1, 1280, 720);
top_sl_height_ = 720; ResetModel();
ResizingVideoSource video(1); external_resize_dynamic_drop_layer_ = true;
external_resize_pattern_ = 1;
number_temporal_layers_ = 3;
number_spatial_layers_ = 3;
// SL0
const int bitrate_sl0 = 1 * cfg_.rc_target_bitrate / 8;
target_layer_bitrate_[0] = 50 * bitrate_sl0 / 100;
target_layer_bitrate_[1] = 70 * bitrate_sl0 / 100;
target_layer_bitrate_[2] = bitrate_sl0;
// SL1
const int bitrate_sl1 = 3 * cfg_.rc_target_bitrate / 8;
target_layer_bitrate_[3] = 50 * bitrate_sl1 / 100;
target_layer_bitrate_[4] = 70 * bitrate_sl1 / 100;
target_layer_bitrate_[5] = bitrate_sl1;
// SL2
const int bitrate_sl2 = 4 * cfg_.rc_target_bitrate / 8;
target_layer_bitrate_[6] = 50 * bitrate_sl2 / 100;
target_layer_bitrate_[7] = 70 * bitrate_sl2 / 100;
target_layer_bitrate_[8] = bitrate_sl2;
ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
}
virtual void BasicRateTargetingSVC3TL3SLExternalResizePattern1HighResTest() {
cfg_.rc_buf_initial_sz = 500;
cfg_.rc_buf_optimal_sz = 500;
cfg_.rc_buf_sz = 1000;
cfg_.rc_dropframe_thresh = 0;
cfg_.rc_min_quantizer = 0;
cfg_.rc_max_quantizer = 63;
cfg_.rc_end_usage = AOM_CBR;
cfg_.g_lag_in_frames = 0;
cfg_.g_error_resilient = 0;
const int bitrate_array[2] = { 600, 1200 };
cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)];
cfg_.g_w = 1850;
cfg_.g_h = 1110;
cfg_.g_forced_max_frame_width = 1850;
cfg_.g_forced_max_frame_height = 1110;
ResizingVideoSource video(1, 1850, 1110);
ResetModel(); ResetModel();
external_resize_dynamic_drop_layer_ = true; external_resize_dynamic_drop_layer_ = true;
external_resize_pattern_ = 1; external_resize_pattern_ = 1;
@@ -2893,9 +2951,7 @@ class DatarateTestSVC
cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)]; cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)];
cfg_.g_w = 1280; cfg_.g_w = 1280;
cfg_.g_h = 720; cfg_.g_h = 720;
top_sl_width_ = 1280; ResizingVideoSource video(2, 1280, 720);
top_sl_height_ = 720;
ResizingVideoSource video(2);
ResetModel(); ResetModel();
external_resize_dynamic_drop_layer_ = true; external_resize_dynamic_drop_layer_ = true;
external_resize_pattern_ = 2; external_resize_pattern_ = 2;
@@ -2919,6 +2975,70 @@ class DatarateTestSVC
ASSERT_NO_FATAL_FAILURE(RunLoop(&video)); ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
} }
virtual void BasicRateTargetingSVC3TL3SLExternalResizePattern2HighResTest() {
cfg_.rc_buf_initial_sz = 500;
cfg_.rc_buf_optimal_sz = 500;
cfg_.rc_buf_sz = 1000;
cfg_.rc_dropframe_thresh = 0;
cfg_.rc_min_quantizer = 0;
cfg_.rc_max_quantizer = 63;
cfg_.rc_end_usage = AOM_CBR;
cfg_.g_lag_in_frames = 0;
cfg_.g_error_resilient = 0;
const int bitrate_array[2] = { 600, 1200 };
cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)];
cfg_.g_w = 1850;
cfg_.g_h = 1110;
cfg_.g_forced_max_frame_width = 1850;
cfg_.g_forced_max_frame_height = 1110;
ResizingVideoSource video(2, 1850, 1110);
ResetModel();
external_resize_dynamic_drop_layer_ = true;
external_resize_pattern_ = 2;
number_temporal_layers_ = 3;
number_spatial_layers_ = 3;
// SL0
const int bitrate_sl0 = 1 * cfg_.rc_target_bitrate / 8;
target_layer_bitrate_[0] = 50 * bitrate_sl0 / 100;
target_layer_bitrate_[1] = 70 * bitrate_sl0 / 100;
target_layer_bitrate_[2] = bitrate_sl0;
// SL1
const int bitrate_sl1 = 3 * cfg_.rc_target_bitrate / 8;
target_layer_bitrate_[3] = 50 * bitrate_sl1 / 100;
target_layer_bitrate_[4] = 70 * bitrate_sl1 / 100;
target_layer_bitrate_[5] = bitrate_sl1;
// SL2
const int bitrate_sl2 = 4 * cfg_.rc_target_bitrate / 8;
target_layer_bitrate_[6] = 50 * bitrate_sl2 / 100;
target_layer_bitrate_[7] = 70 * bitrate_sl2 / 100;
target_layer_bitrate_[8] = bitrate_sl2;
ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
}
virtual void BasicRateTargetingSVC3TL1SLDynamicTLTest() {
cfg_.rc_buf_initial_sz = 500;
cfg_.rc_buf_optimal_sz = 500;
cfg_.rc_buf_sz = 1000;
cfg_.rc_dropframe_thresh = 0;
cfg_.rc_min_quantizer = 0;
cfg_.rc_max_quantizer = 63;
cfg_.rc_end_usage = AOM_CBR;
cfg_.g_lag_in_frames = 0;
cfg_.g_error_resilient = 0;
::libaom_test::I420VideoSource video("niklas_640_480_30.yuv", 640, 480, 30,
1, 0, 400);
const int bitrate_array[2] = { 600, 1200 };
cfg_.rc_target_bitrate = bitrate_array[GET_PARAM(4)];
target_layer_bitrate_[0] = cfg_.rc_target_bitrate;
cfg_.g_w = 640;
cfg_.g_h = 480;
ResetModel();
number_temporal_layers_ = 1;
number_spatial_layers_ = 1;
dynamic_tl_ = true;
ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
}
int layer_frame_cnt_; int layer_frame_cnt_;
int superframe_cnt_; int superframe_cnt_;
int number_temporal_layers_; int number_temporal_layers_;
@@ -2961,8 +3081,7 @@ class DatarateTestSVC
bool external_resize_dynamic_drop_layer_; bool external_resize_dynamic_drop_layer_;
int bitrate_layer_[9]; int bitrate_layer_[9];
int external_resize_pattern_; int external_resize_pattern_;
int top_sl_width_; bool dynamic_tl_;
int top_sl_height_;
}; };
// Check basic rate targeting for CBR, for 3 temporal layers, 1 spatial. // Check basic rate targeting for CBR, for 3 temporal layers, 1 spatial.
@@ -3259,7 +3378,7 @@ TEST_P(DatarateTestSVC, BasicRateTargetingRPS1TL1SLDropFrames) {
// and denoiser enabled. The external resizer will resize down and back up, // and denoiser enabled. The external resizer will resize down and back up,
// setting 0/nonzero bitrate on spatial enhancement layers to disable/enable // setting 0/nonzero bitrate on spatial enhancement layers to disable/enable
// layers. Resizing starts on first frame and the pattern is: // layers. Resizing starts on first frame and the pattern is:
// 1/4 -> 1/2 -> 1 -> 1/4 -> 1/2. // 1/4 -> 1/2 -> 1 -> 1/4 -> 1/2. Configured resolution is 1280x720.
TEST_P(DatarateTestSVC, BasicRateTargetingSVC3TL3SLExternalResizePattern1) { TEST_P(DatarateTestSVC, BasicRateTargetingSVC3TL3SLExternalResizePattern1) {
BasicRateTargetingSVC3TL3SLExternalResizePattern1Test(); BasicRateTargetingSVC3TL3SLExternalResizePattern1Test();
} }
@@ -3268,11 +3387,38 @@ TEST_P(DatarateTestSVC, BasicRateTargetingSVC3TL3SLExternalResizePattern1) {
// and denoiser enabled. The external resizer will resize down and back up, // and denoiser enabled. The external resizer will resize down and back up,
// setting 0/nonzero bitrate on spatial enhancement layers to disable/enable // setting 0/nonzero bitrate on spatial enhancement layers to disable/enable
// layers. Resizing starts on first frame and the pattern is: // layers. Resizing starts on first frame and the pattern is:
// 1/2 -> 1/4 -> 1 -> 1/2 -> 1/4. // 1/4 -> 1/2 -> 1 -> 1/4 -> 1/2. Configured resolution is 1850x1110.
TEST_P(DatarateTestSVC,
BasicRateTargetingSVC3TL3SLExternalResizePattern1HighRes) {
BasicRateTargetingSVC3TL3SLExternalResizePattern1HighResTest();
}
// For 1 pass CBR SVC with 3 spatial and 3 temporal layers with external resize
// and denoiser enabled. The external resizer will resize down and back up,
// setting 0/nonzero bitrate on spatial enhancement layers to disable/enable
// layers. Resizing starts on first frame and the pattern is:
// 1/2 -> 1/4 -> 1 -> 1/2 -> 1/4. Configured resolution is 1280x720.
TEST_P(DatarateTestSVC, BasicRateTargetingSVC3TL3SLExternalResizePattern2) { TEST_P(DatarateTestSVC, BasicRateTargetingSVC3TL3SLExternalResizePattern2) {
BasicRateTargetingSVC3TL3SLExternalResizePattern2Test(); BasicRateTargetingSVC3TL3SLExternalResizePattern2Test();
} }
// For 1 pass CBR SVC with 3 spatial and 3 temporal layers with external resize
// and denoiser enabled. The external resizer will resize down and back up,
// setting 0/nonzero bitrate on spatial enhancement layers to disable/enable
// layers. Resizing starts on first frame and the pattern is:
// 1/2 -> 1/4 -> 1 -> 1/2 -> 1/4. Configured resolution is 1850x1110.
TEST_P(DatarateTestSVC,
BasicRateTargetingSVC3TL3SLExternalResizePattern2HighRes) {
BasicRateTargetingSVC3TL3SLExternalResizePattern2HighResTest();
}
// For 1 pass CBR SVC with 1 spatial and dynamic temporal layers.
// Start/initialize with 1 temporal layer and then enable 3 temporal layers
// during the sequence, and then back to 1.
TEST_P(DatarateTestSVC, BasicRateTargetingSVC3TL1SLDynamicTL) {
BasicRateTargetingSVC3TL1SLDynamicTLTest();
}
TEST(SvcParams, BitrateOverflow) { TEST(SvcParams, BitrateOverflow) {
uint8_t buf[6] = { 0 }; uint8_t buf[6] = { 0 };
aom_image_t img; aom_image_t img;

View File

@@ -0,0 +1,26 @@
Copyright (c) The Highway Project Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,11 @@
URL: https://github.com/google/highway
Version: e92c12750d18c372867809b882dd3ec6874ecc73
License: BSD-3-clause clear
License File: LICENSE-BSD3
Description:
Highway is a C++ library that provides portable SIMD/vector intrinsics.
Local Changes:
Remove everything except hwy/ and LICENSE-BSD3
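For orientation, a minimal usage sketch (not part of the vendored files) of the aligned-allocation helpers this import provides; it assumes the hwy/ sources are built and linked.

#include <cstddef>
#include <cstdio>
#include "third_party/highway/hwy/aligned_allocator.h"

int main() {
  // Allocate 256 floats aligned to at least HWY_ALIGNMENT (128 bytes).
  hwy::AlignedFreeUniquePtr<float[]> buf = hwy::AllocateAligned<float>(256);
  if (!buf) return 1;  // allocation failed or the item count overflowed
  for (size_t i = 0; i < 256; ++i) buf[i] = static_cast<float>(i);
  printf("aligned: %d\n", static_cast<int>(hwy::IsAligned(buf.get())));
  return 0;  // the AlignedFreer deleter releases the buffer here
}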

View File

@@ -0,0 +1,11 @@
// Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: BSD-3-Clause
#ifndef HIGHWAY_HWY_ABORT_H_
#define HIGHWAY_HWY_ABORT_H_
// Empty header for compatibility.
// All Abort/Warn functionalities are in base.h.
#endif // HIGHWAY_HWY_ABORT_H_

View File

@@ -0,0 +1,426 @@
// Copyright 2020 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_
#define HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_
// Memory allocator with support for alignment and offsets.
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include "third_party/highway/hwy/base.h"
#include "third_party/highway/hwy/per_target.h"
namespace hwy {
// Minimum alignment of allocated memory for use in HWY_ASSUME_ALIGNED, which
// requires a literal. To prevent false sharing, this should be at least the
// L1 cache line size, usually 64 bytes. However, Intel's L2 prefetchers may
// access pairs of lines, and M1 L2 and POWER8 lines are also 128 bytes.
#define HWY_ALIGNMENT 128
template <typename T>
HWY_API constexpr bool IsAligned(T* ptr, size_t align = HWY_ALIGNMENT) {
return reinterpret_cast<uintptr_t>(ptr) % align == 0;
}
// Pointers to functions equivalent to malloc/free with an opaque void* passed
// to them.
using AllocPtr = void* (*)(void* opaque, size_t bytes);
using FreePtr = void (*)(void* opaque, void* memory);
// Returns null or a pointer to at least `payload_size` (which can be zero)
// bytes of newly allocated memory, aligned to the larger of HWY_ALIGNMENT and
// the vector size. Calls `alloc` with the passed `opaque` pointer to obtain
// memory or malloc() if it is null.
HWY_DLLEXPORT void* AllocateAlignedBytes(size_t payload_size,
AllocPtr alloc_ptr = nullptr,
void* opaque_ptr = nullptr);
// Frees all memory. No effect if `aligned_pointer` == nullptr, otherwise it
// must have been returned from a previous call to `AllocateAlignedBytes`.
// Calls `free_ptr` with the passed `opaque_ptr` pointer to free the memory; if
// `free_ptr` function is null, uses the default free().
HWY_DLLEXPORT void FreeAlignedBytes(const void* aligned_pointer,
FreePtr free_ptr, void* opaque_ptr);
// Class that deletes the aligned pointer passed to operator() calling the
// destructor before freeing the pointer. This is equivalent to the
// std::default_delete but for aligned objects. For a similar deleter equivalent
// to free() for aligned memory see AlignedFreer().
class AlignedDeleter {
public:
AlignedDeleter() : free_(nullptr), opaque_ptr_(nullptr) {}
AlignedDeleter(FreePtr free_ptr, void* opaque_ptr)
: free_(free_ptr), opaque_ptr_(opaque_ptr) {}
template <typename T>
void operator()(T* aligned_pointer) const {
return DeleteAlignedArray(aligned_pointer, free_, opaque_ptr_,
TypedArrayDeleter<T>);
}
private:
template <typename T>
static void TypedArrayDeleter(void* ptr, size_t size_in_bytes) {
size_t elems = size_in_bytes / sizeof(T);
for (size_t i = 0; i < elems; i++) {
// Explicitly call the destructor on each element.
(static_cast<T*>(ptr) + i)->~T();
}
}
// Function prototype that calls the destructor for each element in a typed
// array. TypedArrayDeleter<T> would match this prototype.
using ArrayDeleter = void (*)(void* t_ptr, size_t t_size);
HWY_DLLEXPORT static void DeleteAlignedArray(void* aligned_pointer,
FreePtr free_ptr,
void* opaque_ptr,
ArrayDeleter deleter);
FreePtr free_;
void* opaque_ptr_;
};
// Unique pointer to T with custom aligned deleter. This can be a single
// element U or an array of element if T is a U[]. The custom aligned deleter
// will call the destructor on U or each element of a U[] in the array case.
template <typename T>
using AlignedUniquePtr = std::unique_ptr<T, AlignedDeleter>;
// Aligned memory equivalent of make_unique<T> using the custom allocators
// alloc/free with the passed `opaque` pointer. This function calls the
// constructor with the passed Args... and calls the destructor of the object
// when the AlignedUniquePtr is destroyed.
template <typename T, typename... Args>
AlignedUniquePtr<T> MakeUniqueAlignedWithAlloc(AllocPtr alloc, FreePtr free,
void* opaque, Args&&... args) {
T* ptr = static_cast<T*>(AllocateAlignedBytes(sizeof(T), alloc, opaque));
return AlignedUniquePtr<T>(new (ptr) T(std::forward<Args>(args)...),
AlignedDeleter(free, opaque));
}
// Similar to MakeUniqueAlignedWithAlloc but using the default alloc/free
// functions.
template <typename T, typename... Args>
AlignedUniquePtr<T> MakeUniqueAligned(Args&&... args) {
T* ptr = static_cast<T*>(AllocateAlignedBytes(sizeof(T)));
return AlignedUniquePtr<T>(new (ptr) T(std::forward<Args>(args)...),
AlignedDeleter());
}
template <class T>
struct AlignedAllocator {
using value_type = T;
AlignedAllocator() = default;
template <class V>
explicit AlignedAllocator(const AlignedAllocator<V>&) noexcept {}
template <class V>
value_type* allocate(V n) {
static_assert(std::is_integral<V>::value,
"AlignedAllocator only supports integer types");
static_assert(sizeof(V) <= sizeof(std::size_t),
"V n must be smaller or equal size_t to avoid overflow");
return static_cast<value_type*>(
AllocateAlignedBytes(static_cast<std::size_t>(n) * sizeof(value_type)));
}
template <class V>
void deallocate(value_type* p, HWY_MAYBE_UNUSED V n) {
return FreeAlignedBytes(p, nullptr, nullptr);
}
};
template <class T, class V>
constexpr bool operator==(const AlignedAllocator<T>&,
const AlignedAllocator<V>&) noexcept {
return true;
}
template <class T, class V>
constexpr bool operator!=(const AlignedAllocator<T>&,
const AlignedAllocator<V>&) noexcept {
return false;
}
template <class T>
using AlignedVector = std::vector<T, AlignedAllocator<T>>;
// Helpers for array allocators (avoids overflow)
namespace detail {
// Returns x such that 1u << x == n (if n is a power of two).
static inline constexpr size_t ShiftCount(size_t n) {
return (n <= 1) ? 0 : 1 + ShiftCount(n / 2);
}
template <typename T>
T* AllocateAlignedItems(size_t items, AllocPtr alloc_ptr, void* opaque_ptr) {
constexpr size_t kSize = sizeof(T);
constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
constexpr size_t kBits = ShiftCount(kSize);
static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
const size_t bytes = kIsPow2 ? items << kBits : items * kSize;
const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
if (check != items) {
return nullptr; // overflowed
}
return static_cast<T*>(AllocateAlignedBytes(bytes, alloc_ptr, opaque_ptr));
}
} // namespace detail
// Aligned memory equivalent of make_unique<T[]> for array types using the
// custom allocators alloc/free. This function calls the constructor with the
// passed Args... on every created item. The destructor of each element will be
// called when the AlignedUniquePtr is destroyed.
template <typename T, typename... Args>
AlignedUniquePtr<T[]> MakeUniqueAlignedArrayWithAlloc(
size_t items, AllocPtr alloc, FreePtr free, void* opaque, Args&&... args) {
T* ptr = detail::AllocateAlignedItems<T>(items, alloc, opaque);
if (ptr != nullptr) {
for (size_t i = 0; i < items; i++) {
new (ptr + i) T(std::forward<Args>(args)...);
}
}
return AlignedUniquePtr<T[]>(ptr, AlignedDeleter(free, opaque));
}
template <typename T, typename... Args>
AlignedUniquePtr<T[]> MakeUniqueAlignedArray(size_t items, Args&&... args) {
return MakeUniqueAlignedArrayWithAlloc<T, Args...>(
items, nullptr, nullptr, nullptr, std::forward<Args>(args)...);
}
// Custom deleter for std::unique_ptr equivalent to using free() as a deleter
// but for aligned memory.
class AlignedFreer {
public:
// Pass address of this to ctor to skip deleting externally-owned memory.
static void DoNothing(void* /*opaque*/, void* /*aligned_pointer*/) {}
AlignedFreer() : free_(nullptr), opaque_ptr_(nullptr) {}
AlignedFreer(FreePtr free_ptr, void* opaque_ptr)
: free_(free_ptr), opaque_ptr_(opaque_ptr) {}
template <typename T>
void operator()(T* aligned_pointer) const {
FreeAlignedBytes(aligned_pointer, free_, opaque_ptr_);
}
private:
FreePtr free_;
void* opaque_ptr_;
};
// Unique pointer to single POD, or (if T is U[]) an array of POD. For non POD
// data use AlignedUniquePtr.
template <typename T>
using AlignedFreeUniquePtr = std::unique_ptr<T, AlignedFreer>;
// Allocate an aligned and uninitialized array of POD values as a unique_ptr.
// Upon destruction of the unique_ptr the aligned array will be freed.
template <typename T>
AlignedFreeUniquePtr<T[]> AllocateAligned(const size_t items, AllocPtr alloc,
FreePtr free, void* opaque) {
static_assert(std::is_trivially_copyable<T>::value,
"AllocateAligned: requires trivially copyable T");
static_assert(std::is_trivially_destructible<T>::value,
"AllocateAligned: requires trivially destructible T");
return AlignedFreeUniquePtr<T[]>(
detail::AllocateAlignedItems<T>(items, alloc, opaque),
AlignedFreer(free, opaque));
}
// Same as previous AllocateAligned(), using default allocate/free functions.
template <typename T>
AlignedFreeUniquePtr<T[]> AllocateAligned(const size_t items) {
return AllocateAligned<T>(items, nullptr, nullptr, nullptr);
}
// A simple span containing data and size of data.
template <typename T>
class Span {
public:
Span() = default;
Span(T* data, size_t size) : size_(size), data_(data) {}
template <typename U>
Span(U u) : Span(u.data(), u.size()) {}
Span(std::initializer_list<const T> v) : Span(v.begin(), v.size()) {}
// Copies the contents of the initializer list to the span.
Span<T>& operator=(std::initializer_list<const T> v) {
HWY_DASSERT(size_ == v.size());
CopyBytes(v.begin(), data_, sizeof(T) * std::min(size_, v.size()));
return *this;
}
// Returns the size of the contained data.
size_t size() const { return size_; }
// Returns a pointer to the contained data.
T* data() { return data_; }
T* data() const { return data_; }
// Returns the element at index.
T& operator[](size_t index) const { return data_[index]; }
// Returns an iterator pointing to the first element of this span.
T* begin() { return data_; }
// Returns a const iterator pointing to the first element of this span.
constexpr const T* cbegin() const { return data_; }
// Returns an iterator pointing just beyond the last element at the
// end of this span.
T* end() { return data_ + size_; }
// Returns a const iterator pointing just beyond the last element at the
// end of this span.
constexpr const T* cend() const { return data_ + size_; }
private:
size_t size_ = 0;
T* data_ = nullptr;
};
// A multi dimensional array containing an aligned buffer.
//
// To maintain alignment, the innermost dimension will be padded to ensure all
// innermost arrays are aligned.
template <typename T, size_t axes>
class AlignedNDArray {
static_assert(std::is_trivial<T>::value,
"AlignedNDArray can only contain trivial types");
public:
AlignedNDArray(AlignedNDArray&& other) = default;
AlignedNDArray& operator=(AlignedNDArray&& other) = default;
// Constructs an array of the provided shape and fills it with zeros.
explicit AlignedNDArray(std::array<size_t, axes> shape) : shape_(shape) {
sizes_ = ComputeSizes(shape_);
memory_shape_ = shape_;
// Round the innermost dimension up to the number of bytes available for
// SIMD operations on this architecture to make sure that each innermost
// array is aligned from the first element.
memory_shape_[axes - 1] = RoundUpTo(memory_shape_[axes - 1], VectorBytes());
memory_sizes_ = ComputeSizes(memory_shape_);
buffer_ = hwy::AllocateAligned<T>(memory_size());
hwy::ZeroBytes(buffer_.get(), memory_size() * sizeof(T));
}
// Returns a span containing the innermost array at the provided indices.
Span<T> operator[](std::array<const size_t, axes - 1> indices) {
return Span<T>(buffer_.get() + Offset(indices), sizes_[indices.size()]);
}
// Returns a const span containing the innermost array at the provided
// indices.
Span<const T> operator[](std::array<const size_t, axes - 1> indices) const {
return Span<const T>(buffer_.get() + Offset(indices),
sizes_[indices.size()]);
}
// Returns the shape of the array, which might be smaller than the allocated
// buffer after padding the last axis to alignment.
const std::array<size_t, axes>& shape() const { return shape_; }
// Returns the shape of the allocated buffer, which might be larger than the
// used size of the array after padding to alignment.
const std::array<size_t, axes>& memory_shape() const { return memory_shape_; }
// Returns the size of the array, which might be smaller than the allocated
// buffer after padding the last axis to alignment.
size_t size() const { return sizes_[0]; }
// Returns the size of the allocated buffer, which might be larger than the
// used size of the array after padding to alignment.
size_t memory_size() const { return memory_sizes_[0]; }
// Returns a pointer to the allocated buffer.
T* data() { return buffer_.get(); }
// Returns a const pointer to the buffer.
const T* data() const { return buffer_.get(); }
// Truncates the array by updating its shape.
//
// The new shape must be equal to or less than the old shape in all axes.
//
// Doesn't modify underlying memory.
void truncate(const std::array<size_t, axes>& new_shape) {
#if HWY_IS_DEBUG_BUILD
for (size_t axis_index = 0; axis_index < axes; ++axis_index) {
HWY_ASSERT(new_shape[axis_index] <= shape_[axis_index]);
}
#endif
shape_ = new_shape;
sizes_ = ComputeSizes(shape_);
}
private:
std::array<size_t, axes> shape_;
std::array<size_t, axes> memory_shape_;
std::array<size_t, axes + 1> sizes_;
std::array<size_t, axes + 1> memory_sizes_;
hwy::AlignedFreeUniquePtr<T[]> buffer_;
// Computes offset in the buffer based on the provided indices.
size_t Offset(std::array<const size_t, axes - 1> indices) const {
size_t offset = 0;
size_t shape_index = 0;
for (const size_t axis_index : indices) {
offset += memory_sizes_[shape_index + 1] * axis_index;
shape_index++;
}
return offset;
}
// Computes the sizes of all sub arrays based on the sizes of each axis.
//
// Does this by multiplying the size of each axis with the previous one in
// reverse order, starting with the conceptual axis of size 1 containing the
// actual elements in the array.
static std::array<size_t, axes + 1> ComputeSizes(
std::array<size_t, axes> shape) {
std::array<size_t, axes + 1> sizes;
size_t axis = shape.size();
sizes[axis] = 1;
while (axis > 0) {
--axis;
sizes[axis] = sizes[axis + 1] * shape[axis];
}
return sizes;
}
};
} // namespace hwy
#endif // HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_
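A brief, illustrative sketch of the AlignedNDArray/Span helpers defined above (assumes the Highway sources are linked):

#include <cstdio>
#include "third_party/highway/hwy/aligned_allocator.h"

int main() {
  // 3x5 float array; the innermost dimension is padded up to the vector
  // width so that every row starts on an aligned boundary.
  hwy::AlignedNDArray<float, 2> arr({3, 5});
  arr[{1}][2] = 42.0f;              // rows are exposed as Span<float>
  hwy::Span<float> row = arr[{1}];
  printf("%zu %.1f\n", row.size(), static_cast<double>(row[2]));
  return 0;
}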

View File

@@ -0,0 +1,504 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_AUTO_TUNE_H_
#define HIGHWAY_HWY_AUTO_TUNE_H_
#include <stddef.h>
#include <stdint.h>
#include <string.h> // memmove
#include <cmath>
#include <vector>
#include "third_party/highway/hwy/aligned_allocator.h" // Span
#include "third_party/highway/hwy/base.h" // HWY_MIN
#include "third_party/highway/hwy/contrib/sort/vqsort.h"
// Infrastructure for auto-tuning (choosing optimal parameters at runtime).
namespace hwy {
// O(1) storage to estimate the central tendency of hundreds of independent
// distributions (one per configuration). The number of samples per distribution
// (`kMinSamples`) varies from a few to dozens. We support both by first storing
// values in a buffer, and when full, switching to online variance estimation.
// Modified from `hwy/stats.h`.
class CostDistribution {
public:
static constexpr size_t kMaxValues = 14; // for total size of 128 bytes
void Notify(const double x) {
if (HWY_UNLIKELY(x < 0.0)) {
HWY_WARN("Ignoring negative cost %f.", x);
return;
}
// Online phase after filling and warm-up.
if (HWY_LIKELY(IsOnline())) return OnlineNotify(x);
// Fill phase: store up to `kMaxValues` values.
values_[num_values_++] = x;
HWY_DASSERT(num_values_ <= kMaxValues);
if (HWY_UNLIKELY(num_values_ == kMaxValues)) {
WarmUpOnline();
HWY_DASSERT(IsOnline());
}
}
// Returns an estimate of the true cost, mitigating the impact of noise.
//
// Background and observations from time measurements in `thread_pool.h`:
// - We aim for O(1) storage because there may be hundreds of instances.
// - The mean is biased upwards by mostly additive noise: particularly
// interruptions such as context switches, but also contention.
// - The minimum is not a robust estimator because there are also "lucky
// shots" (1.2-1.6x lower values) where interruptions or contention happen
// to be low.
// - We want to preserve information about contention and a configuration's
// sensitivity to it. Otherwise, we are optimizing for the best-case, not
// the common case.
// - It is still important to minimize the influence of outliers, such as page
// faults, which can cause multiple times larger measurements.
// - Detecting outliers based only on the initial variance is too brittle. If
// the sample is narrow, measurements will fluctuate across runs because
// too many measurements are considered outliers. This would cause the
// 'best' configuration to vary.
//
// Approach:
// - Use Winsorization to reduce the impact of outliers, while preserving
// information on the central tendency.
// - Continually update the thresholds based on the online variance, with
// exponential smoothing for stability.
// - Trim the initial sample via MAD or skewness for a robust estimate of the
// variance.
double EstimateCost() {
if (!IsOnline()) {
WarmUpOnline();
HWY_DASSERT(IsOnline());
}
return Mean();
}
// Multiplex online state into values_ to allow higher `kMaxValues`.
// Public for inspection in tests. Do not use directly.
double& M1() { return values_[0]; } // Moments for variance.
double& M2() { return values_[1]; }
double& Mean() { return values_[2]; } // Exponential smoothing.
double& Stddev() { return values_[3]; }
double& Lower() { return values_[4]; }
double& Upper() { return values_[5]; }
private:
static double Median(double* to_sort, size_t n) {
HWY_DASSERT(n >= 2);
// F64 is supported everywhere except Armv7.
#if !HWY_ARCH_ARM_V7
VQSort(to_sort, n, SortAscending());
#else
// Values are known to be finite and non-negative, hence sorting as U64 is
// equivalent.
VQSort(reinterpret_cast<uint64_t*>(to_sort), n, SortAscending());
#endif
if (n & 1) return to_sort[n / 2];
// Even length: average of two middle elements.
return (to_sort[n / 2] + to_sort[n / 2 - 1]) * 0.5;
}
static double MAD(const double* values, size_t n, const double median) {
double abs_dev[kMaxValues];
for (size_t i = 0; i < n; ++i) {
abs_dev[i] = ScalarAbs(values[i] - median);
}
return Median(abs_dev, n);
}
// If `num_values_` is large enough, sorts and discards outliers: either via
// MAD, or if too many values are equal, by trimming according to skewness.
void RemoveOutliers() {
if (num_values_ < 3) return; // Not enough to discard two.
HWY_DASSERT(num_values_ <= kMaxValues);
// Given the noise level in `auto_tune_test`, it can happen that 1/4 of the
// sample is an outlier *in either direction*. Use median absolute
// deviation, which is robust to almost half of the sample being outliers.
const double median = Median(values_, num_values_); // sorts in-place.
const double mad = MAD(values_, num_values_, median);
// At least half the sample is equal.
if (mad == 0.0) {
// Estimate skewness to decide which side to trim more.
const double skewness =
(values_[num_values_ - 1] - median) - (median - values_[0]);
const size_t trim = HWY_MAX(num_values_ / 2, size_t{2});
const size_t left =
HWY_MAX(skewness < 0.0 ? trim * 3 / 4 : trim / 4, size_t{1});
num_values_ -= trim;
HWY_DASSERT(num_values_ >= 1);
memmove(values_, values_ + left, num_values_ * sizeof(values_[0]));
return;
}
const double upper = median + 5.0 * mad;
const double lower = median - 5.0 * mad;
size_t right = num_values_ - 1;
while (values_[right] > upper) --right;
// Nonzero MAD implies no more than half are equal, so we did not advance
// beyond the median.
HWY_DASSERT(right >= num_values_ / 2);
size_t left = 0;
while (left < right && values_[left] < lower) ++left;
HWY_DASSERT(left <= num_values_ / 2);
num_values_ = right - left + 1;
memmove(values_, values_ + left, num_values_ * sizeof(values_[0]));
}
double SampleMean() const {
// Only called in non-online phase, but buffer might not be full.
HWY_DASSERT(!IsOnline() && 0 != num_values_ && num_values_ <= kMaxValues);
double sum = 0.0;
for (size_t i = 0; i < num_values_; ++i) {
sum += values_[i];
}
return sum / static_cast<double>(num_values_);
}
// Unbiased estimator for population variance even for small `num_values_`.
double SampleVariance(double sample_mean) const {
HWY_DASSERT(sample_mean >= 0.0); // we checked costs are non-negative.
// Only called in non-online phase, but buffer might not be full.
HWY_DASSERT(!IsOnline() && 0 != num_values_ && num_values_ <= kMaxValues);
if (HWY_UNLIKELY(num_values_ == 1)) return 0.0; // prevent divide-by-zero.
double sum2 = 0.0;
for (size_t i = 0; i < num_values_; ++i) {
const double d = values_[i] - sample_mean;
sum2 += d * d;
}
return sum2 / static_cast<double>(num_values_ - 1);
}
bool IsOnline() const { return online_n_ > 0.0; }
void OnlineNotify(double x) {
// Winsorize.
x = HWY_MIN(HWY_MAX(Lower(), x), Upper());
// Welford's online variance estimator.
// https://media.thinkbrg.com/wp-content/uploads/2020/06/19094655/720_720_McCrary_ImplementingAlgorithms_Whitepaper_20151119_WEB.pdf#page=7.09
const double n_minus_1 = online_n_;
online_n_ += 1.0;
const double d = x - M1();
const double d_div_n = d / online_n_;
M1() += d_div_n;
HWY_DASSERT(M1() >= Lower());
M2() += d * n_minus_1 * d_div_n; // d^2 * (N-1)/N
// HWY_MAX avoids divide-by-zero.
const double stddev = std::sqrt(M2() / HWY_MAX(1.0, n_minus_1));
// Exponential smoothing.
constexpr double kNew = 0.2; // relatively fast update
constexpr double kOld = 1.0 - kNew;
Mean() = M1() * kNew + Mean() * kOld;
Stddev() = stddev * kNew + Stddev() * kOld;
// Update thresholds from smoothed mean and stddev to enable recovering from
// a too narrow initial range due to excessive trimming.
Lower() = Mean() - 3.5 * Stddev();
Upper() = Mean() + 3.5 * Stddev();
}
void WarmUpOnline() {
RemoveOutliers();
// Compute and copy before writing to `M1`, which overwrites `values_`!
const double sample_mean = SampleMean();
const double sample_variance = SampleVariance(sample_mean);
double copy[kMaxValues];
hwy::CopyBytes(values_, copy, num_values_ * sizeof(values_[0]));
M1() = M2() = 0.0;
Mean() = sample_mean;
Stddev() = std::sqrt(sample_variance);
// For a single-value or all-equal sample, widen the range; otherwise we
// would only accept that same value.
if (Stddev() == 0.0) Stddev() = Mean() / 2;
// Use a high tolerance because the distribution is not actually Gaussian,
// we trimmed up to *half* of the sample, and we do not want to reject too
// many values in the online phase.
Lower() = Mean() - 4.0 * Stddev();
Upper() = Mean() + 4.0 * Stddev();
// Feed copied values into online estimator.
for (size_t i = 0; i < num_values_; ++i) {
OnlineNotify(copy[i]);
}
HWY_DASSERT(IsOnline());
#if SIZE_MAX == 0xFFFFFFFFu
(void)padding_;
#endif
}
size_t num_values_ = 0; // size of `values_` <= `kMaxValues`
#if SIZE_MAX == 0xFFFFFFFFu
uint32_t padding_ = 0;
#endif
double online_n_ = 0.0; // number of calls to `OnlineNotify`.
double values_[kMaxValues];
};
static_assert(sizeof(CostDistribution) == 128, "");
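// Illustrative usage sketch (not part of the library): feed raw cost samples
// via `Notify` (the same member `AutoTune::NotifyCost` calls below) and read
// back the robust estimate. Assumes at least one sample was recorded.
inline double ExampleEstimateCost(const double* samples, size_t num_samples) {
  CostDistribution cd;
  for (size_t i = 0; i < num_samples; ++i) {
    cd.Notify(samples[i]);  // raw measurements, e.g. elapsed ticks
  }
  // Trims outliers (or defers to the online estimator) and returns the
  // smoothed mean.
  return cd.EstimateCost();
}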
// Implements a counter with wrap-around, plus the ability to skip values.
// O(1) time, O(N) space via doubly-linked list of indices.
class NextWithSkip {
public:
NextWithSkip() {}
explicit NextWithSkip(size_t num) {
links_.reserve(num);
for (size_t i = 0; i < num; ++i) {
links_.emplace_back(i, num);
}
}
size_t Next(size_t pos) {
HWY_DASSERT(pos < links_.size());
HWY_DASSERT(!links_[pos].IsRemoved());
return links_[pos].Next();
}
// Must not be called for an already skipped position. Ignores an attempt to
// skip the last remaining position.
void Skip(size_t pos) {
HWY_DASSERT(!links_[pos].IsRemoved()); // not already skipped.
const size_t prev = links_[pos].Prev();
const size_t next = links_[pos].Next();
if (prev == pos || next == pos) return; // last remaining position.
links_[next].SetPrev(prev);
links_[prev].SetNext(next);
links_[pos].Remove();
}
private:
// Combine prev/next into one array to improve locality/reduce allocations.
class Link {
// Bit-shifts avoid potentially expensive 16-bit loads. Store `next` at the
// top and `prev` at the bottom for extraction with a single shift/AND.
// There may be hundreds of configurations, so 8 bits are not enough.
static constexpr size_t kBits = 14;
static constexpr size_t kShift = 32 - kBits;
static constexpr uint32_t kMaxNum = 1u << kBits;
public:
Link(size_t pos, size_t num) {
HWY_DASSERT(num < kMaxNum);
const size_t prev = pos == 0 ? num - 1 : pos - 1;
const size_t next = pos == num - 1 ? 0 : pos + 1;
bits_ =
(static_cast<uint32_t>(next) << kShift) | static_cast<uint32_t>(prev);
HWY_DASSERT(Next() == next && Prev() == prev);
HWY_DASSERT(!IsRemoved());
}
bool IsRemoved() const { return (bits_ & kMaxNum) != 0; }
void Remove() { bits_ |= kMaxNum; }
size_t Next() const { return bits_ >> kShift; }
size_t Prev() const { return bits_ & (kMaxNum - 1); }
void SetNext(size_t next) {
HWY_DASSERT(next < kMaxNum);
bits_ &= (~0u >> kBits); // clear old next
bits_ |= static_cast<uint32_t>(next) << kShift;
HWY_DASSERT(Next() == next);
HWY_DASSERT(!IsRemoved());
}
void SetPrev(size_t prev) {
HWY_DASSERT(prev < kMaxNum);
bits_ &= ~(kMaxNum - 1); // clear old prev
bits_ |= static_cast<uint32_t>(prev);
HWY_DASSERT(Prev() == prev);
HWY_DASSERT(!IsRemoved());
}
private:
uint32_t bits_;
};
std::vector<Link> links_;
};
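// Illustrative usage sketch (not part of the library): cycle through five
// positions with wrap-around while permanently skipping one of them.
inline size_t ExampleNextWithSkip() {
  NextWithSkip list(5);  // positions 0..4, each linked to its neighbors
  list.Skip(2);          // position 2 is no longer returned
  size_t pos = 0;
  for (int i = 0; i < 8; ++i) {
    pos = list.Next(pos);  // visits 1, 3, 4, 0, 1, 3, 4, 0
  }
  return pos;  // 0 after two full trips around the remaining positions
}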
// State machine for choosing at runtime the lowest-cost `Config`, which is
// typically a struct containing multiple parameters. For an introduction, see
// "Auto-Tuning and Performance Portability on Heterogeneous Hardware".
//
// **Which parameters**
// Note that simple parameters such as the L2 cache size can be queried directly
// via `hwy/contrib/thread_pool/topology.h`. Difficult-to-predict parameters
// such as task granularity are more appropriate for auto-tuning. We also
// suggest that at least some parameters be 'algorithm variants' such as
// parallel vs. serial, or 2D tiling vs. 1D striping.
//
// **Search strategy**
// To guarantee the optimal result, we use exhaustive search, which is suitable
// for around 10 parameters and a few hundred combinations of 'candidate'
// configurations.
//
// **How to generate candidates**
// To keep this framework simple and generic, applications enumerate the search
// space and pass the list of all feasible candidates to `SetCandidates` before
// the first call to `NextConfig`. Applications should prune the space as much
// as possible, e.g. by upper-bounding parameters based on the known cache
// sizes, and applying constraints such as one being a multiple of another.
//
// **Usage**
// Applications typically conditionally branch to the code implementing the
// configuration returned by `NextConfig`. They measure the cost of running it
// and pass that to `NotifyCost`. Branching avoids the complexity and
// opaqueness of a JIT. The number of branches can be reduced (at the cost of
// code size) by inlining low-level decisions into larger code regions, e.g. by
// hoisting them outside hot loops.
//
// **What is cost**
// Cost is an arbitrary `uint64_t`, with lower values being better. Most
// applications will use the elapsed time. If the tasks being tuned are short,
// it is important to use a high-resolution timer such as `hwy/timer.h`. Energy
// may also be useful [https://www.osti.gov/servlets/purl/1361296].
//
// **Online vs. offline**
// Although applications can auto-tune once, offline, it may be difficult to
// ensure the stored configuration still applies to the current circumstances.
// Thus we recommend online auto-tuning, re-discovering the configuration on
// each run. We assume the overhead of bookkeeping and measuring cost is
// negligible relative to the actual work. The cost of auto-tuning is then that
// of running sub-optimal configurations. Assuming the best configuration is
// better than baseline, and the work is performed many thousands of times, the
// cost is outweighed by the benefits.
//
// **kMinSamples**
// To further reduce overhead, after `kMinSamples` rounds (= measurements of
// each configuration) we start excluding configurations from further
// measurements if they are sufficiently worse than the current best.
// `kMinSamples` can be several dozen when the tasks being tuned take a few
// microseconds. Even for longer tasks, it should be at least 2 for some noise
// tolerance. After this, there are another `kMinSamples / 2 + 1` rounds before
// declaring the winner.
template <typename Config, size_t kMinSamples = 2>
class AutoTune {
public:
// Returns non-null best configuration if auto-tuning has already finished.
// Otherwise, callers continue calling `NextConfig` and `NotifyCost`.
// Points into `Candidates()`.
const Config* Best() const { return best_; }
// If false, caller must call `SetCandidates` before `NextConfig`.
bool HasCandidates() const {
HWY_DASSERT(!Best());
return !candidates_.empty();
}
// WARNING: invalidates `Best()`, do not call if that is non-null.
void SetCandidates(std::vector<Config> candidates) {
HWY_DASSERT(!Best() && !HasCandidates());
candidates_.swap(candidates);
HWY_DASSERT(HasCandidates());
costs_.resize(candidates_.size());
list_ = NextWithSkip(candidates_.size());
}
// Typically called after Best() is non-null to compare all candidates' costs.
Span<const Config> Candidates() const {
HWY_DASSERT(HasCandidates());
return Span<const Config>(candidates_.data(), candidates_.size());
}
Span<CostDistribution> Costs() {
return Span<CostDistribution>(costs_.data(), costs_.size());
}
// Returns the current `Config` to measure.
const Config& NextConfig() const {
HWY_DASSERT(!Best() && HasCandidates());
return candidates_[config_idx_];
}
// O(1) except at the end of each round, which is O(N).
void NotifyCost(uint64_t cost) {
HWY_DASSERT(!Best() && HasCandidates());
costs_[config_idx_].Notify(static_cast<double>(cost));
// Save now before we update `config_idx_`.
const size_t my_idx = config_idx_;
// Only retrieve the estimate once we have enough samples; otherwise we would
// switch to the online variance before the buffer is populated.
const double my_cost = rounds_complete_ >= kMinSamples
? costs_[config_idx_].EstimateCost()
: 0.0;
// Advance to next non-skipped config with wrap-around. This decorrelates
// measurements by not immediately re-measuring the same config.
config_idx_ = list_.Next(config_idx_);
// Might still equal `my_idx` if this is the only non-skipped config.
// Disqualify from future `NextConfig` if cost was too far beyond the
// current best. This reduces the number of measurements, while tolerating
// noise in the first few measurements. Must happen after advancing.
if (my_cost > skip_if_above_) {
list_.Skip(my_idx);
}
// Wrap-around indicates the round is complete.
if (HWY_UNLIKELY(config_idx_ <= my_idx)) {
++rounds_complete_;
// Enough samples for stable estimates: update the thresholds.
if (rounds_complete_ >= kMinSamples) {
double best_cost = HighestValue<double>();
size_t idx_min = 0;
for (size_t i = 0; i < candidates_.size(); ++i) {
const double estimate = costs_[i].EstimateCost();
if (estimate < best_cost) {
best_cost = estimate;
idx_min = i;
}
}
skip_if_above_ = best_cost * 1.25;
// After sufficient rounds, declare the winner.
if (HWY_UNLIKELY(rounds_complete_ == 3 * kMinSamples / 2 + 1)) {
best_ = &candidates_[idx_min];
HWY_DASSERT(Best());
}
}
}
}
// Avoid printing during the first few rounds, because measurements are still
// noisy then and poor configurations have not yet been skipped.
bool ShouldPrint() { return rounds_complete_ > kMinSamples; }
private:
const Config* best_ = nullptr;
std::vector<Config> candidates_;
std::vector<CostDistribution> costs_; // one per candidate
size_t config_idx_ = 0; // [0, candidates_.size())
NextWithSkip list_;
size_t rounds_complete_ = 0;
double skip_if_above_ = 0.0;
};
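// Illustrative usage sketch (not part of the library) of the call pattern
// described above. `BlockConfig` and the placeholder cost are hypothetical;
// a real application would run the work for the returned configuration and
// measure its elapsed time, e.g. via hwy/timer.h.
struct BlockConfig {
  size_t block_size;
};
inline size_t ExampleChooseBlockSize() {
  AutoTune<BlockConfig, /*kMinSamples=*/2> tuner;
  tuner.SetCandidates({BlockConfig{64}, BlockConfig{128}, BlockConfig{256}});
  while (!tuner.Best()) {
    const BlockConfig& cfg = tuner.NextConfig();
    // Run the workload with cfg.block_size here; the cost below is a stand-in
    // for a real measurement such as elapsed ticks.
    const uint64_t cost = 100 + cfg.block_size / 64;
    tuner.NotifyCost(cost);
  }
  return tuner.Best()->block_size;  // Best() points into Candidates().
}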
} // namespace hwy
#endif // HIGHWAY_HWY_AUTO_TUNE_H_

File diff suppressed because it is too large

View File

@@ -0,0 +1,158 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_BIT_SET_H_
#define HIGHWAY_HWY_BIT_SET_H_
// BitSet with fast Foreach for up to 64 and 4096 members.
#include <stddef.h>
#include "third_party/highway/hwy/base.h"
namespace hwy {
// 64-bit specialization of std::bitset, which lacks Foreach.
class BitSet64 {
public:
// No harm if `i` is already set.
void Set(size_t i) {
HWY_DASSERT(i < 64);
bits_ |= (1ULL << i);
HWY_DASSERT(Get(i));
}
// Equivalent to Set(i) for i in [0, 64) where (bits >> i) & 1. This does
// not clear any existing bits.
void SetNonzeroBitsFrom64(uint64_t bits) { bits_ |= bits; }
void Clear(size_t i) {
HWY_DASSERT(i < 64);
bits_ &= ~(1ULL << i);
}
bool Get(size_t i) const {
HWY_DASSERT(i < 64);
return (bits_ & (1ULL << i)) != 0;
}
// Returns true if any Get(i) would return true for i in [0, 64).
bool Any() const { return bits_ != 0; }
// Returns lowest i such that Get(i). Caller must ensure Any() beforehand!
size_t First() const {
HWY_DASSERT(Any());
return Num0BitsBelowLS1Bit_Nonzero64(bits_);
}
// Returns uint64_t(Get(i)) << i for i in [0, 64).
uint64_t Get64() const { return bits_; }
// Calls `func(i)` for each `i` in the set. It is safe for `func` to modify
// the set, but the current Foreach call is unaffected.
template <class Func>
void Foreach(const Func& func) const {
uint64_t remaining_bits = bits_;
while (remaining_bits != 0) {
const size_t i = Num0BitsBelowLS1Bit_Nonzero64(remaining_bits);
remaining_bits &= remaining_bits - 1; // clear LSB
func(i);
}
}
size_t Count() const { return PopCount(bits_); }
private:
uint64_t bits_ = 0;
};
// Two-level bitset for up to kMaxSize <= 4096 values.
template <size_t kMaxSize = 4096>
class BitSet4096 {
public:
// No harm if `i` is already set.
void Set(size_t i) {
HWY_DASSERT(i < kMaxSize);
const size_t idx = i / 64;
const size_t mod = i % 64;
bits_[idx].Set(mod);
nonzero_.Set(idx);
HWY_DASSERT(Get(i));
}
// Equivalent to Set(i) for i in [0, 64) where (bits >> i) & 1. This does
// not clear any existing bits.
void SetNonzeroBitsFrom64(uint64_t bits) {
bits_[0].SetNonzeroBitsFrom64(bits);
if (bits) nonzero_.Set(0);
}
void Clear(size_t i) {
HWY_DASSERT(i < kMaxSize);
const size_t idx = i / 64;
const size_t mod = i % 64;
bits_[idx].Clear(mod);
if (!bits_[idx].Any()) {
nonzero_.Clear(idx);
}
HWY_DASSERT(!Get(i));
}
bool Get(size_t i) const {
HWY_DASSERT(i < kMaxSize);
const size_t idx = i / 64;
const size_t mod = i % 64;
return bits_[idx].Get(mod);
}
// Returns true if any Get(i) would return true for i in [0, kMaxSize).
bool Any() const { return nonzero_.Any(); }
// Returns lowest i such that Get(i). Caller must ensure Any() beforehand!
size_t First() const {
HWY_DASSERT(Any());
const size_t idx = nonzero_.First();
return idx * 64 + bits_[idx].First();
}
// Returns uint64_t(Get(i)) << i for i in [0, 64).
uint64_t Get64() const { return bits_[0].Get64(); }
// Calls `func(i)` for each `i` in the set. It is safe for `func` to modify
// the set, but the current Foreach call is only affected if `func` changes
// one of the not-yet-visited BitSet64 for which Any() was true.
template <class Func>
void Foreach(const Func& func) const {
nonzero_.Foreach([&func, this](size_t idx) {
bits_[idx].Foreach([idx, &func](size_t mod) { func(idx * 64 + mod); });
});
}
size_t Count() const {
size_t total = 0;
nonzero_.Foreach(
[&total, this](size_t idx) { total += bits_[idx].Count(); });
return total;
}
private:
static_assert(kMaxSize <= 64 * 64, "One BitSet64 insufficient");
BitSet64 nonzero_;
BitSet64 bits_[kMaxSize / 64];
};
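// Illustrative usage sketch (not part of the library): collect sparse indices
// above 64 and visit them in ascending order.
inline size_t ExampleSumOfSetIndices() {
  BitSet4096<> set;  // default kMaxSize = 4096
  set.Set(3);
  set.Set(70);
  set.Set(4095);
  size_t sum = 0;
  set.Foreach([&sum](size_t i) { sum += i; });  // visits 3, 70, 4095
  return sum;  // 4168
}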
} // namespace hwy
#endif // HIGHWAY_HWY_BIT_SET_H_

View File

@@ -0,0 +1,126 @@
// Copyright 2020 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_CACHE_CONTROL_H_
#define HIGHWAY_HWY_CACHE_CONTROL_H_
#include "third_party/highway/hwy/base.h"
// Requires SSE2; fails to compile on 32-bit Clang 7 (see
// https://github.com/gperftools/gperftools/issues/946).
#if !defined(__SSE2__) || (HWY_COMPILER_CLANG && HWY_ARCH_X86_32)
#undef HWY_DISABLE_CACHE_CONTROL
#define HWY_DISABLE_CACHE_CONTROL
#endif
#ifndef HWY_DISABLE_CACHE_CONTROL
// intrin.h is sufficient on MSVC and already included by base.h.
#if HWY_ARCH_X86 && !HWY_COMPILER_MSVC
#include <emmintrin.h> // SSE2
#include <xmmintrin.h> // _mm_prefetch
#elif HWY_ARCH_ARM_A64
#include <arm_acle.h>
#endif
#endif // HWY_DISABLE_CACHE_CONTROL
namespace hwy {
// Even if N*sizeof(T) is smaller, Stream may write a multiple of this size.
#define HWY_STREAM_MULTIPLE 16
// The following functions may also require an attribute.
#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) && !HWY_COMPILER_MSVC
#define HWY_ATTR_CACHE __attribute__((target("sse2")))
#else
#define HWY_ATTR_CACHE
#endif
// Windows.h #defines this, which causes infinite recursion. Temporarily
// undefine to avoid conflict with our function.
// TODO(janwas): remove when this function is removed.
#pragma push_macro("LoadFence")
#undef LoadFence
// Delays subsequent loads until prior loads are visible. Beware of potentially
// differing behavior across architectures and vendors: on Intel but not AMD
// CPUs, it also serves as a full fence (i.e. waits for all prior instructions
// to complete).
HWY_INLINE HWY_ATTR_CACHE void LoadFence() {
#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL)
_mm_lfence();
#endif
}
// TODO(janwas): remove when this function is removed. (See above.)
#pragma pop_macro("LoadFence")
// Ensures values written by previous `Stream` calls are visible on the current
// core. This is NOT sufficient for synchronizing across cores; when `Stream`
// outputs are to be consumed by other core(s), the producer must publish
// availability (e.g. via mutex or atomic_flag) after `FlushStream`.
HWY_INLINE HWY_ATTR_CACHE void FlushStream() {
#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL)
_mm_sfence();
#endif
}
// Optionally begins loading the cache line containing "p" to reduce latency of
// subsequent actual loads.
template <typename T>
HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T* p) {
(void)p;
#ifndef HWY_DISABLE_CACHE_CONTROL
#if HWY_ARCH_X86
_mm_prefetch(reinterpret_cast<const char*>(p), _MM_HINT_T0);
#elif HWY_COMPILER_GCC // includes clang
// Hint=0 (NTA) behavior differs, but skipping outer caches is probably not
// desirable, so use the default 3 (keep in caches).
__builtin_prefetch(p, /*write=*/0, /*hint=*/3);
#endif
#endif // HWY_DISABLE_CACHE_CONTROL
}
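// Illustrative usage sketch (not part of the library): prefetch a fixed
// distance ahead while streaming through an array. Eight cache lines is an
// arbitrary example distance; good values depend on workload and hardware.
inline uint64_t ExampleSumWithPrefetch(const uint32_t* data, size_t count) {
  constexpr size_t kAhead = 8 * 64 / sizeof(uint32_t);  // ~8 cache lines
  uint64_t sum = 0;
  for (size_t i = 0; i < count; ++i) {
    if (i + kAhead < count) Prefetch(data + i + kAhead);
    sum += data[i];
  }
  return sum;
}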
// Invalidates and flushes the cache line containing "p", if possible.
HWY_INLINE HWY_ATTR_CACHE void FlushCacheline(const void* p) {
#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL)
_mm_clflush(p);
#else
(void)p;
#endif
}
// Hints that we are inside a spin loop and potentially reduces power
// consumption and coherency traffic. For example, x86 avoids multiple
// outstanding load requests, which reduces the memory order violation penalty
// when exiting the loop.
HWY_INLINE HWY_ATTR_CACHE void Pause() {
#ifndef HWY_DISABLE_CACHE_CONTROL
#if HWY_ARCH_X86
_mm_pause();
#elif HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG
// This is documented in ACLE and the YIELD instruction is also available in
// Armv7, but the intrinsic is broken for Armv7 clang, hence A64 only.
__yield();
#elif HWY_ARCH_ARM && HWY_COMPILER_GCC // includes clang
__asm__ volatile("yield" ::: "memory");
#elif HWY_ARCH_PPC && HWY_COMPILER_GCC // includes clang
__asm__ volatile("or 27,27,27" ::: "memory");
#endif
#endif // HWY_DISABLE_CACHE_CONTROL
}
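// Illustrative usage sketch (not part of the library): a generic spin-wait
// that calls Pause() between polls. `still_waiting` is any caller-provided
// callable, e.g. one wrapping an atomic load with acquire ordering.
template <class StillWaiting>
HWY_INLINE HWY_ATTR_CACHE void ExampleSpinWait(
    const StillWaiting& still_waiting) {
  while (still_waiting()) {
    Pause();
  }
}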
} // namespace hwy
#endif // HIGHWAY_HWY_CACHE_CONTROL_H_

View File

@@ -0,0 +1,145 @@
// Copyright 2022 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Per-target include guard
#if defined(HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_) == \
defined(HWY_TARGET_TOGGLE) // NOLINT
#ifdef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_
#undef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_
#endif
#include <stddef.h>
#include <stdint.h>
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
// These functions avoid having to write a loop plus remainder handling in the
// (unfortunately still common) case where arrays are not aligned/padded. If the
// inputs are known to be aligned/padded, it is more efficient to write a single
// loop using Load(). We do not provide a CopyAlignedPadded because it
// would be more verbose than such a loop.
// Fills `to`[0, `count`) with `value`.
template <class D, typename T = TFromD<D>>
void Fill(D d, T value, size_t count, T* HWY_RESTRICT to) {
const size_t N = Lanes(d);
const Vec<D> v = Set(d, value);
size_t idx = 0;
if (count >= N) {
for (; idx <= count - N; idx += N) {
StoreU(v, d, to + idx);
}
}
// `count` was a multiple of the vector length `N`: already done.
if (HWY_UNLIKELY(idx == count)) return;
const size_t remaining = count - idx;
HWY_DASSERT(0 != remaining && remaining < N);
SafeFillN(remaining, value, d, to + idx);
}
// Copies `from`[0, `count`) to `to`, which must not overlap `from`.
template <class D, typename T = TFromD<D>>
void Copy(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to) {
const size_t N = Lanes(d);
size_t idx = 0;
if (count >= N) {
for (; idx <= count - N; idx += N) {
const Vec<D> v = LoadU(d, from + idx);
StoreU(v, d, to + idx);
}
}
// `count` was a multiple of the vector length `N`: already done.
if (HWY_UNLIKELY(idx == count)) return;
const size_t remaining = count - idx;
HWY_DASSERT(0 != remaining && remaining < N);
SafeCopyN(remaining, d, from + idx, to + idx);
}
// For idx in [0, count) in ascending order, appends `from[idx]` to `to` if the
// corresponding mask element of `func(d, v)` is true. Returns the STL-style end
// of the newly written elements in `to`.
//
// `func` is either a functor with a templated operator()(d, v) returning a
// mask, or a generic lambda if using C++14. Due to apparent limitations of
// Clang on Windows, it is currently necessary to add HWY_ATTR before the
// opening { of the lambda to avoid errors about "function .. requires target".
//
// NOTE: this is only supported for 16-, 32- or 64-bit types.
// NOTE: Func may be called a second time for elements it has already seen, but
// these elements will not be written to `to` again.
template <class D, class Func, typename T = TFromD<D>>
T* CopyIf(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to,
const Func& func) {
const size_t N = Lanes(d);
size_t idx = 0;
if (count >= N) {
for (; idx <= count - N; idx += N) {
const Vec<D> v = LoadU(d, from + idx);
to += CompressBlendedStore(v, func(d, v), d, to);
}
}
// `count` was a multiple of the vector length `N`: already done.
if (HWY_UNLIKELY(idx == count)) return to;
#if HWY_MEM_OPS_MIGHT_FAULT
// Proceed one by one.
const CappedTag<T, 1> d1;
for (; idx < count; ++idx) {
using V1 = Vec<decltype(d1)>;
// Workaround for -Waggressive-loop-optimizations on GCC 8
// (iteration 2305843009213693951 invokes undefined behavior for T=i64)
const uintptr_t addr = reinterpret_cast<uintptr_t>(from);
const T* HWY_RESTRICT from_idx =
reinterpret_cast<const T * HWY_RESTRICT>(addr + (idx * sizeof(T)));
const V1 v = LoadU(d1, from_idx);
// Avoid storing to `to` unless we know it should be kept - otherwise, we
// might overrun the end if it was allocated for the exact count.
if (CountTrue(d1, func(d1, v)) == 0) continue;
StoreU(v, d1, to);
to += 1;
}
#else
// Start index of the last unaligned whole vector, ending at the array end.
const size_t last = count - N;
// Number of elements before `from` or already written.
const size_t invalid = idx - last;
HWY_DASSERT(0 != invalid && invalid < N);
const Mask<D> mask = Not(FirstN(d, invalid));
const Vec<D> v = MaskedLoad(mask, d, from + last);
to += CompressBlendedStore(v, And(mask, func(d, v)), d, to);
#endif
return to;
}
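// Illustrative usage sketch (not part of the library): keep only the positive
// elements of `in`, as described above; `T` must be a 16/32/64-bit type such
// as float. Note the HWY_ATTR before the lambda body. Returns the number of
// elements written to `out`.
template <typename T>
size_t ExampleCopyPositive(const T* HWY_RESTRICT in, size_t count,
                           T* HWY_RESTRICT out) {
  const ScalableTag<T> d;
  T* end = CopyIf(d, in, count, out, [](const auto d, const auto v) HWY_ATTR {
    return Gt(v, Zero(d));
  });
  return static_cast<size_t>(end - out);
}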
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_

View File

@@ -0,0 +1,113 @@
// Copyright 2022 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Per-target include guard
#if defined(HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_) == \
defined(HWY_TARGET_TOGGLE) // NOLINT
#ifdef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_
#undef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_
#endif
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
// Returns index of the first element equal to `value` in `in[0, count)`, or
// `count` if not found.
template <class D, typename T = TFromD<D>>
size_t Find(D d, T value, const T* HWY_RESTRICT in, size_t count) {
const size_t N = Lanes(d);
const Vec<D> broadcasted = Set(d, value);
size_t i = 0;
if (count >= N) {
for (; i <= count - N; i += N) {
const intptr_t pos = FindFirstTrue(d, Eq(broadcasted, LoadU(d, in + i)));
if (pos >= 0) return i + static_cast<size_t>(pos);
}
}
if (i != count) {
#if HWY_MEM_OPS_MIGHT_FAULT
// Scan single elements.
const CappedTag<T, 1> d1;
using V1 = Vec<decltype(d1)>;
const V1 broadcasted1 = Set(d1, GetLane(broadcasted));
for (; i < count; ++i) {
if (AllTrue(d1, Eq(broadcasted1, LoadU(d1, in + i)))) {
return i;
}
}
#else
const size_t remaining = count - i;
HWY_DASSERT(0 != remaining && remaining < N);
const Mask<D> mask = FirstN(d, remaining);
const Vec<D> v = MaskedLoad(mask, d, in + i);
// Apply mask so that we don't 'find' the zero-padding from MaskedLoad.
const intptr_t pos = FindFirstTrue(d, And(Eq(broadcasted, v), mask));
if (pos >= 0) return i + static_cast<size_t>(pos);
#endif // HWY_MEM_OPS_MIGHT_FAULT
}
return count; // not found
}
// Returns index of the first element in `in[0, count)` for which `func(d, vec)`
// returns true, otherwise `count`.
template <class D, class Func, typename T = TFromD<D>>
size_t FindIf(D d, const T* HWY_RESTRICT in, size_t count, const Func& func) {
const size_t N = Lanes(d);
size_t i = 0;
if (count >= N) {
for (; i <= count - N; i += N) {
const intptr_t pos = FindFirstTrue(d, func(d, LoadU(d, in + i)));
if (pos >= 0) return i + static_cast<size_t>(pos);
}
}
if (i != count) {
#if HWY_MEM_OPS_MIGHT_FAULT
// Scan single elements.
const CappedTag<T, 1> d1;
for (; i < count; ++i) {
if (AllTrue(d1, func(d1, LoadU(d1, in + i)))) {
return i;
}
}
#else
const size_t remaining = count - i;
HWY_DASSERT(0 != remaining && remaining < N);
const Mask<D> mask = FirstN(d, remaining);
const Vec<D> v = MaskedLoad(mask, d, in + i);
// Apply mask so that we don't 'find' the zero-padding from MaskedLoad.
const intptr_t pos = FindFirstTrue(d, And(func(d, v), mask));
if (pos >= 0) return i + static_cast<size_t>(pos);
#endif // HWY_MEM_OPS_MIGHT_FAULT
}
return count; // not found
}
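// Illustrative usage sketch (not part of the library): index of the first
// element strictly greater than `threshold`, or `count` if there is none.
template <typename T>
size_t ExampleFindAbove(const T* HWY_RESTRICT in, size_t count, T threshold) {
  const ScalableTag<T> d;
  return FindIf(d, in, count,
                [threshold](const auto d, const auto v) HWY_ATTR {
                  return Gt(v, Set(d, threshold));
                });
}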
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_

View File

@@ -0,0 +1,228 @@
// Copyright 2022 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Per-target include guard
#if defined(HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_) == \
defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_
#undef HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_
#endif
#include <stddef.h>
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
// These functions avoid having to write a loop plus remainder handling in the
// (unfortunately still common) case where arrays are not aligned/padded. If the
// inputs are known to be aligned/padded, it is more efficient to write a single
// loop using Load(). We do not provide a TransformAlignedPadded because it
// would be more verbose than such a loop.
//
// Func is either a functor with a templated operator()(d, v[, v1[, v2]]), or a
// generic lambda if using C++14. The d argument is the same as was passed to
// the Generate etc. functions. Due to apparent limitations of Clang, it is
// currently necessary to add HWY_ATTR before the opening { of the lambda to
// avoid errors about "always_inline function .. requires target".
//
// We do not check HWY_MEM_OPS_MIGHT_FAULT because LoadN/StoreN do not fault.
// Fills `out[0, count)` with the vectors returned by `func(d, index_vec)`,
// where `index_vec` is `Vec<RebindToUnsigned<D>>`. On the first call to `func`,
// the value of its lane i is i, and increases by `Lanes(d)` after every call.
// Note that some of these indices may be `>= count`, but the elements that
// `func` returns in those lanes will not be written to `out`.
template <class D, class Func, typename T = TFromD<D>>
void Generate(D d, T* HWY_RESTRICT out, size_t count, const Func& func) {
const RebindToUnsigned<D> du;
using TU = TFromD<decltype(du)>;
const size_t N = Lanes(d);
size_t idx = 0;
Vec<decltype(du)> vidx = Iota(du, 0);
if (count >= N) {
for (; idx <= count - N; idx += N) {
StoreU(func(d, vidx), d, out + idx);
vidx = Add(vidx, Set(du, static_cast<TU>(N)));
}
}
// `count` was a multiple of the vector length `N`: already done.
if (HWY_UNLIKELY(idx == count)) return;
const size_t remaining = count - idx;
HWY_DASSERT(0 != remaining && remaining < N);
StoreN(func(d, vidx), d, out + idx, remaining);
}
// Calls `func(d, v)` for each input vector; out-of-bounds lanes with index
// i >= `count` are instead taken from `no[i % Lanes(d)]`.
template <class D, class Func, typename T = TFromD<D>>
void Foreach(D d, const T* HWY_RESTRICT in, const size_t count, const Vec<D> no,
const Func& func) {
const size_t N = Lanes(d);
size_t idx = 0;
if (count >= N) {
for (; idx <= count - N; idx += N) {
const Vec<D> v = LoadU(d, in + idx);
func(d, v);
}
}
// `count` was a multiple of the vector length `N`: already done.
if (HWY_UNLIKELY(idx == count)) return;
const size_t remaining = count - idx;
HWY_DASSERT(0 != remaining && remaining < N);
const Vec<D> v = LoadNOr(no, d, in + idx, remaining);
func(d, v);
}
// Replaces `inout[idx]` with `func(d, inout[idx])`. Example usage: multiplying
// array elements by a constant.
template <class D, class Func, typename T = TFromD<D>>
void Transform(D d, T* HWY_RESTRICT inout, size_t count, const Func& func) {
const size_t N = Lanes(d);
size_t idx = 0;
if (count >= N) {
for (; idx <= count - N; idx += N) {
const Vec<D> v = LoadU(d, inout + idx);
StoreU(func(d, v), d, inout + idx);
}
}
// `count` was a multiple of the vector length `N`: already done.
if (HWY_UNLIKELY(idx == count)) return;
const size_t remaining = count - idx;
HWY_DASSERT(0 != remaining && remaining < N);
const Vec<D> v = LoadN(d, inout + idx, remaining);
StoreN(func(d, v), d, inout + idx, remaining);
}
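// Illustrative usage sketch (not part of the library): scale `x[0, count)` in
// place by `factor`, the "multiply by a constant" use case mentioned above.
inline void ExampleScale(float* HWY_RESTRICT x, size_t count, float factor) {
  const ScalableTag<float> d;
  Transform(d, x, count, [factor](const auto d, const auto v) HWY_ATTR {
    return Mul(v, Set(d, factor));
  });
}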
// Replaces `inout[idx]` with `func(d, inout[idx], in1[idx])`. Example usage:
// multiplying array elements by those of another array.
template <class D, class Func, typename T = TFromD<D>>
void Transform1(D d, T* HWY_RESTRICT inout, size_t count,
const T* HWY_RESTRICT in1, const Func& func) {
const size_t N = Lanes(d);
size_t idx = 0;
if (count >= N) {
for (; idx <= count - N; idx += N) {
const Vec<D> v = LoadU(d, inout + idx);
const Vec<D> v1 = LoadU(d, in1 + idx);
StoreU(func(d, v, v1), d, inout + idx);
}
}
// `count` was a multiple of the vector length `N`: already done.
if (HWY_UNLIKELY(idx == count)) return;
const size_t remaining = count - idx;
HWY_DASSERT(0 != remaining && remaining < N);
const Vec<D> v = LoadN(d, inout + idx, remaining);
const Vec<D> v1 = LoadN(d, in1 + idx, remaining);
StoreN(func(d, v, v1), d, inout + idx, remaining);
}
// Replaces `inout[idx]` with `func(d, inout[idx], in1[idx], in2[idx])`. Example
// usage: FMA of elements from three arrays, stored into the first array.
template <class D, class Func, typename T = TFromD<D>>
void Transform2(D d, T* HWY_RESTRICT inout, size_t count,
const T* HWY_RESTRICT in1, const T* HWY_RESTRICT in2,
const Func& func) {
const size_t N = Lanes(d);
size_t idx = 0;
if (count >= N) {
for (; idx <= count - N; idx += N) {
const Vec<D> v = LoadU(d, inout + idx);
const Vec<D> v1 = LoadU(d, in1 + idx);
const Vec<D> v2 = LoadU(d, in2 + idx);
StoreU(func(d, v, v1, v2), d, inout + idx);
}
}
// `count` was a multiple of the vector length `N`: already done.
if (HWY_UNLIKELY(idx == count)) return;
const size_t remaining = count - idx;
HWY_DASSERT(0 != remaining && remaining < N);
const Vec<D> v = LoadN(d, inout + idx, remaining);
const Vec<D> v1 = LoadN(d, in1 + idx, remaining);
const Vec<D> v2 = LoadN(d, in2 + idx, remaining);
StoreN(func(d, v, v1, v2), d, inout + idx, remaining);
}
template <class D, typename T = TFromD<D>>
void Replace(D d, T* HWY_RESTRICT inout, size_t count, T new_t, T old_t) {
const size_t N = Lanes(d);
const Vec<D> old_v = Set(d, old_t);
const Vec<D> new_v = Set(d, new_t);
size_t idx = 0;
if (count >= N) {
for (; idx <= count - N; idx += N) {
Vec<D> v = LoadU(d, inout + idx);
StoreU(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx);
}
}
// `count` was a multiple of the vector length `N`: already done.
if (HWY_UNLIKELY(idx == count)) return;
const size_t remaining = count - idx;
HWY_DASSERT(0 != remaining && remaining < N);
const Vec<D> v = LoadN(d, inout + idx, remaining);
StoreN(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx, remaining);
}
template <class D, class Func, typename T = TFromD<D>>
void ReplaceIf(D d, T* HWY_RESTRICT inout, size_t count, T new_t,
const Func& func) {
const size_t N = Lanes(d);
const Vec<D> new_v = Set(d, new_t);
size_t idx = 0;
if (count >= N) {
for (; idx <= count - N; idx += N) {
Vec<D> v = LoadU(d, inout + idx);
StoreU(IfThenElse(func(d, v), new_v, v), d, inout + idx);
}
}
// `count` was a multiple of the vector length `N`: already done.
if (HWY_UNLIKELY(idx == count)) return;
const size_t remaining = count - idx;
HWY_DASSERT(0 != remaining && remaining < N);
const Vec<D> v = LoadN(d, inout + idx, remaining);
StoreN(IfThenElse(func(d, v), new_v, v), d, inout + idx, remaining);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_

File diff suppressed because it is too large

View File

@@ -0,0 +1,460 @@
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// clang-format off
#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT
// clang-format on
#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#endif
#include <stddef.h>
#include <stdint.h>
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
// NOTE: the D argument describes the inputs, not the output, because f32/f32,
// bf16/bf16, and f32/bf16 inputs all accumulate to f32.
struct Dot {
// Specify zero or more of these, ORed together, as the kAssumptions template
// argument to Compute. Each one may improve performance or reduce code size,
// at the cost of additional requirements on the arguments.
enum Assumptions {
// num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T).
kAtLeastOneVector = 1,
// num_elements is divisible by N (a power of two, so this can be used if
// the problem size is known to be a power of two >= HWY_MAX_BYTES /
// sizeof(T)).
kMultipleOfVector = 2,
// RoundUpTo(num_elements, N) elements are accessible; their value does not
// matter (will be treated as if they were zero).
kPaddedToVector = 4,
};
// Returns sum{pa[i] * pb[i]} for floating-point inputs, including float16_t
// and double if HWY_HAVE_FLOAT16/64. Aligning the
// pointers to a multiple of N elements is helpful but not required.
template <int kAssumptions, class D, typename T = TFromD<D>>
static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa,
const T* const HWY_RESTRICT pb,
const size_t num_elements) {
static_assert(IsFloat<T>(), "MulAdd requires float type");
using V = decltype(Zero(d));
HWY_LANES_CONSTEXPR size_t N = Lanes(d);
size_t i = 0;
constexpr bool kIsAtLeastOneVector =
(kAssumptions & kAtLeastOneVector) != 0;
constexpr bool kIsMultipleOfVector =
(kAssumptions & kMultipleOfVector) != 0;
constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
// Won't be able to do a full vector load without padding => scalar loop.
if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
HWY_UNLIKELY(num_elements < N)) {
// Only 2x unroll to avoid excessive code size.
T sum0 = ConvertScalarTo<T>(0);
T sum1 = ConvertScalarTo<T>(0);
for (; i + 2 <= num_elements; i += 2) {
// For reasons unknown, fp16 += does not compile on clang (Arm).
sum0 = ConvertScalarTo<T>(sum0 + pa[i + 0] * pb[i + 0]);
sum1 = ConvertScalarTo<T>(sum1 + pa[i + 1] * pb[i + 1]);
}
if (i < num_elements) {
sum1 = ConvertScalarTo<T>(sum1 + pa[i] * pb[i]);
}
return ConvertScalarTo<T>(sum0 + sum1);
}
// Compiler doesn't make independent sum* accumulators, so unroll manually.
// 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
// for unaligned inputs (each unaligned pointer halves the throughput
// because it occupies both L1 load ports for a cycle). We cannot have
// arrays of vectors on RVV/SVE, so always unroll 4x.
V sum0 = Zero(d);
V sum1 = Zero(d);
V sum2 = Zero(d);
V sum3 = Zero(d);
// Main loop: unrolled
for (; i + 4 * N <= num_elements; /* i += 4 * N */) { // incr in loop
const auto a0 = LoadU(d, pa + i);
const auto b0 = LoadU(d, pb + i);
i += N;
sum0 = MulAdd(a0, b0, sum0);
const auto a1 = LoadU(d, pa + i);
const auto b1 = LoadU(d, pb + i);
i += N;
sum1 = MulAdd(a1, b1, sum1);
const auto a2 = LoadU(d, pa + i);
const auto b2 = LoadU(d, pb + i);
i += N;
sum2 = MulAdd(a2, b2, sum2);
const auto a3 = LoadU(d, pa + i);
const auto b3 = LoadU(d, pb + i);
i += N;
sum3 = MulAdd(a3, b3, sum3);
}
// Up to 3 iterations of whole vectors
for (; i + N <= num_elements; i += N) {
const auto a = LoadU(d, pa + i);
const auto b = LoadU(d, pb + i);
sum0 = MulAdd(a, b, sum0);
}
if (!kIsMultipleOfVector) {
const size_t remaining = num_elements - i;
if (remaining != 0) {
if (kIsPaddedToVector) {
const auto mask = FirstN(d, remaining);
const auto a = LoadU(d, pa + i);
const auto b = LoadU(d, pb + i);
sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
} else {
// Unaligned load such that the last element is in the highest lane -
// ensures we do not touch any elements outside the valid range.
// If we get here, then num_elements >= N.
HWY_DASSERT(i >= N);
i += remaining - N;
const auto skip = FirstN(d, N - remaining);
const auto a = LoadU(d, pa + i); // always unaligned
const auto b = LoadU(d, pb + i);
sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
}
}
} // kMultipleOfVector
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = Add(sum0, sum1);
sum2 = Add(sum2, sum3);
sum0 = Add(sum0, sum2);
return ReduceSum(d, sum0);
}
// f32 * bf16
template <int kAssumptions, class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE float Compute(const DF df,
const float* const HWY_RESTRICT pa,
const hwy::bfloat16_t* const HWY_RESTRICT pb,
const size_t num_elements) {
#if HWY_TARGET == HWY_SCALAR
const Rebind<hwy::bfloat16_t, DF> dbf;
#else
const Repartition<hwy::bfloat16_t, DF> dbf;
using VBF = decltype(Zero(dbf));
#endif
const Half<decltype(dbf)> dbfh;
using VF = decltype(Zero(df));
HWY_LANES_CONSTEXPR size_t NF = Lanes(df);
constexpr bool kIsAtLeastOneVector =
(kAssumptions & kAtLeastOneVector) != 0;
constexpr bool kIsMultipleOfVector =
(kAssumptions & kMultipleOfVector) != 0;
constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
// Won't be able to do a full vector load without padding => scalar loop.
if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
HWY_UNLIKELY(num_elements < NF)) {
// Only 2x unroll to avoid excessive code size.
float sum0 = 0.0f;
float sum1 = 0.0f;
size_t i = 0;
for (; i + 2 <= num_elements; i += 2) {
sum0 += pa[i + 0] * ConvertScalarTo<float>(pb[i + 0]);
sum1 += pa[i + 1] * ConvertScalarTo<float>(pb[i + 1]);
}
for (; i < num_elements; ++i) {
sum1 += pa[i] * ConvertScalarTo<float>(pb[i]);
}
return sum0 + sum1;
}
// Compiler doesn't make independent sum* accumulators, so unroll manually.
// 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
// for unaligned inputs (each unaligned pointer halves the throughput
// because it occupies both L1 load ports for a cycle). We cannot have
// arrays of vectors on RVV/SVE, so always unroll 4x.
VF sum0 = Zero(df);
VF sum1 = Zero(df);
VF sum2 = Zero(df);
VF sum3 = Zero(df);
size_t i = 0;
#if HWY_TARGET != HWY_SCALAR // PromoteUpperTo supported
// Main loop: unrolled
for (; i + 4 * NF <= num_elements; /* i += 4 * N */) { // incr in loop
const VF a0 = LoadU(df, pa + i);
const VBF b0 = LoadU(dbf, pb + i);
i += NF;
sum0 = MulAdd(a0, PromoteLowerTo(df, b0), sum0);
const VF a1 = LoadU(df, pa + i);
i += NF;
sum1 = MulAdd(a1, PromoteUpperTo(df, b0), sum1);
const VF a2 = LoadU(df, pa + i);
const VBF b2 = LoadU(dbf, pb + i);
i += NF;
sum2 = MulAdd(a2, PromoteLowerTo(df, b2), sum2);
const VF a3 = LoadU(df, pa + i);
i += NF;
sum3 = MulAdd(a3, PromoteUpperTo(df, b2), sum3);
}
#endif // HWY_TARGET != HWY_SCALAR
// Up to 3 iterations of whole vectors
for (; i + NF <= num_elements; i += NF) {
const VF a = LoadU(df, pa + i);
const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
sum0 = MulAdd(a, b, sum0);
}
if (!kIsMultipleOfVector) {
const size_t remaining = num_elements - i;
if (remaining != 0) {
if (kIsPaddedToVector) {
const auto mask = FirstN(df, remaining);
const VF a = LoadU(df, pa + i);
const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
} else {
// Unaligned load such that the last element is in the highest lane -
// ensures we do not touch any elements outside the valid range.
// If we get here, then num_elements >= N.
HWY_DASSERT(i >= NF);
i += remaining - NF;
const auto skip = FirstN(df, NF - remaining);
const VF a = LoadU(df, pa + i); // always unaligned
const VF b = PromoteTo(df, LoadU(dbfh, pb + i));
sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
}
}
} // kMultipleOfVector
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = Add(sum0, sum1);
sum2 = Add(sum2, sum3);
sum0 = Add(sum0, sum2);
return ReduceSum(df, sum0);
}
// Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a
// multiple of N elements is helpful but not required.
template <int kAssumptions, class D, HWY_IF_BF16_D(D)>
static HWY_INLINE float Compute(const D d,
const bfloat16_t* const HWY_RESTRICT pa,
const bfloat16_t* const HWY_RESTRICT pb,
const size_t num_elements) {
const RebindToUnsigned<D> du16;
const Repartition<float, D> df32;
using V = decltype(Zero(df32));
HWY_LANES_CONSTEXPR size_t N = Lanes(d);
size_t i = 0;
constexpr bool kIsAtLeastOneVector =
(kAssumptions & kAtLeastOneVector) != 0;
constexpr bool kIsMultipleOfVector =
(kAssumptions & kMultipleOfVector) != 0;
constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
// Won't be able to do a full vector load without padding => scalar loop.
if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
HWY_UNLIKELY(num_elements < N)) {
float sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for..
float sum1 = 0.0f; // this unlikely(?) case.
for (; i + 2 <= num_elements; i += 2) {
sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]);
sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]);
}
if (i < num_elements) {
sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
}
return sum0 + sum1;
}
// See comment in the other Compute() overload. Unroll 2x, but we need
// twice as many sums for ReorderWidenMulAccumulate.
V sum0 = Zero(df32);
V sum1 = Zero(df32);
V sum2 = Zero(df32);
V sum3 = Zero(df32);
// Main loop: unrolled
for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop
const auto a0 = LoadU(d, pa + i);
const auto b0 = LoadU(d, pb + i);
i += N;
sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
const auto a1 = LoadU(d, pa + i);
const auto b1 = LoadU(d, pb + i);
i += N;
sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3);
}
// Possibly one more iteration of whole vectors
if (i + N <= num_elements) {
const auto a0 = LoadU(d, pa + i);
const auto b0 = LoadU(d, pb + i);
i += N;
sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
}
if (!kIsMultipleOfVector) {
const size_t remaining = num_elements - i;
if (remaining != 0) {
if (kIsPaddedToVector) {
const auto mask = FirstN(du16, remaining);
const auto va = LoadU(d, pa + i);
const auto vb = LoadU(d, pb + i);
const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
} else {
// Unaligned load such that the last element is in the highest lane -
// ensures we do not touch any elements outside the valid range.
// If we get here, then num_elements >= N.
HWY_DASSERT(i >= N);
i += remaining - N;
const auto skip = FirstN(du16, N - remaining);
const auto va = LoadU(d, pa + i); // always unaligned
const auto vb = LoadU(d, pb + i);
const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
}
}
} // kMultipleOfVector
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = Add(sum0, sum1);
sum2 = Add(sum2, sum3);
sum0 = Add(sum0, sum2);
return ReduceSum(df32, sum0);
}
// Returns sum{i32(pa[i]) * i32(pb[i])} for i16 inputs. Aligning the pointers
// to a multiple of N elements is helpful but not required.
template <int kAssumptions, class D, HWY_IF_I16_D(D)>
static HWY_INLINE int32_t Compute(const D d,
const int16_t* const HWY_RESTRICT pa,
const int16_t* const HWY_RESTRICT pb,
const size_t num_elements) {
const RebindToUnsigned<D> du16;
const RepartitionToWide<D> di32;
using VI32 = Vec<decltype(di32)>;
HWY_LANES_CONSTEXPR size_t N = Lanes(d);
size_t i = 0;
constexpr bool kIsAtLeastOneVector =
(kAssumptions & kAtLeastOneVector) != 0;
constexpr bool kIsMultipleOfVector =
(kAssumptions & kMultipleOfVector) != 0;
constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
// Won't be able to do a full vector load without padding => scalar loop.
if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
HWY_UNLIKELY(num_elements < N)) {
int32_t sum0 = 0; // Only 2x unroll to avoid excessive code size for..
int32_t sum1 = 0; // this unlikely(?) case.
for (; i + 2 <= num_elements; i += 2) {
sum0 += int32_t{pa[i + 0]} * int32_t{pb[i + 0]};
sum1 += int32_t{pa[i + 1]} * int32_t{pb[i + 1]};
}
if (i < num_elements) {
sum1 += int32_t{pa[i]} * int32_t{pb[i]};
}
return sum0 + sum1;
}
// See comment in the other Compute() overload. Unroll 2x, but we need
// twice as many sums for ReorderWidenMulAccumulate.
VI32 sum0 = Zero(di32);
VI32 sum1 = Zero(di32);
VI32 sum2 = Zero(di32);
VI32 sum3 = Zero(di32);
// Main loop: unrolled
for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop
const auto a0 = LoadU(d, pa + i);
const auto b0 = LoadU(d, pb + i);
i += N;
sum0 = ReorderWidenMulAccumulate(di32, a0, b0, sum0, sum1);
const auto a1 = LoadU(d, pa + i);
const auto b1 = LoadU(d, pb + i);
i += N;
sum2 = ReorderWidenMulAccumulate(di32, a1, b1, sum2, sum3);
}
// Possibly one more iteration of whole vectors
if (i + N <= num_elements) {
const auto a0 = LoadU(d, pa + i);
const auto b0 = LoadU(d, pb + i);
i += N;
sum0 = ReorderWidenMulAccumulate(di32, a0, b0, sum0, sum1);
}
if (!kIsMultipleOfVector) {
const size_t remaining = num_elements - i;
if (remaining != 0) {
if (kIsPaddedToVector) {
const auto mask = FirstN(du16, remaining);
const auto va = LoadU(d, pa + i);
const auto vb = LoadU(d, pb + i);
const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
sum2 = ReorderWidenMulAccumulate(di32, a16, b16, sum2, sum3);
} else {
// Unaligned load such that the last element is in the highest lane -
// ensures we do not touch any elements outside the valid range.
// If we get here, then num_elements >= N.
HWY_DASSERT(i >= N);
i += remaining - N;
const auto skip = FirstN(du16, N - remaining);
const auto va = LoadU(d, pa + i); // always unaligned
const auto vb = LoadU(d, pb + i);
const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
sum2 = ReorderWidenMulAccumulate(di32, a16, b16, sum2, sum3);
}
}
} // kMultipleOfVector
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = Add(sum0, sum1);
sum2 = Add(sum2, sum3);
sum0 = Add(sum0, sum2);
return ReduceSum(di32, sum0);
}
};
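// Illustrative usage sketch (not part of the library): dot product of two
// float arrays of arbitrary length, with no assumptions about length, padding
// or alignment.
inline float ExampleDotF32(const float* HWY_RESTRICT pa,
                           const float* HWY_RESTRICT pb, size_t num) {
  const ScalableTag<float> d;
  return Dot::Compute</*kAssumptions=*/0>(d, pa, pb, num);
}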
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_

View File

@@ -0,0 +1,467 @@
// Copyright 2020 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_
#define HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_
// SIMD/multicore-friendly planar image representation with row accessors.
#include <string.h>
#include <utility> // std::move
#include "third_party/highway/hwy/aligned_allocator.h"
#include "third_party/highway/hwy/base.h"
namespace hwy {
// Type-independent parts of Image<> - reduces code duplication and facilitates
// moving member function implementations to cc file.
struct HWY_CONTRIB_DLLEXPORT ImageBase {
// Returns required alignment in bytes for externally allocated memory.
static size_t VectorSize();
// Returns distance [bytes] between the start of two consecutive rows, a
// multiple of VectorSize but NOT kAlias (see implementation).
static size_t BytesPerRow(size_t xsize, size_t sizeof_t);
// No allocation (for output params or unused images)
ImageBase()
: xsize_(0),
ysize_(0),
bytes_per_row_(0),
bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) {}
// Allocates memory (this is the common case)
ImageBase(size_t xsize, size_t ysize, size_t sizeof_t);
// References but does not take ownership of external memory. Useful for
// interoperability with other libraries. `aligned` must be aligned to a
// multiple of VectorSize() and `bytes_per_row` must also be a multiple of
// VectorSize() or preferably equal to BytesPerRow().
ImageBase(size_t xsize, size_t ysize, size_t bytes_per_row, void* aligned);
// Copy construction/assignment is forbidden to avoid inadvertent copies,
// which can be very expensive. Use CopyImageTo() instead.
ImageBase(const ImageBase& other) = delete;
ImageBase& operator=(const ImageBase& other) = delete;
// Move constructor (required for returning Image from function)
ImageBase(ImageBase&& other) noexcept = default;
// Move assignment (required for std::vector)
ImageBase& operator=(ImageBase&& other) noexcept = default;
void Swap(ImageBase& other);
// Useful for pre-allocating image with some padding for alignment purposes
// and later reporting the actual valid dimensions. Caller is responsible
// for ensuring xsize/ysize are <= the original dimensions.
void ShrinkTo(const size_t xsize, const size_t ysize) {
xsize_ = static_cast<uint32_t>(xsize);
ysize_ = static_cast<uint32_t>(ysize);
// NOTE: we can't recompute bytes_per_row for more compact storage and
// better locality because that would invalidate the image contents.
}
// How many pixels.
HWY_INLINE size_t xsize() const { return xsize_; }
HWY_INLINE size_t ysize() const { return ysize_; }
// NOTE: do not use this for copying rows - the valid xsize may be much less.
HWY_INLINE size_t bytes_per_row() const { return bytes_per_row_; }
// Raw access to byte contents, for interfacing with other libraries.
// Unsigned char instead of char to avoid surprises (sign extension).
HWY_INLINE uint8_t* bytes() {
void* p = bytes_.get();
return static_cast<uint8_t * HWY_RESTRICT>(HWY_ASSUME_ALIGNED(p, 64));
}
HWY_INLINE const uint8_t* bytes() const {
const void* p = bytes_.get();
return static_cast<const uint8_t * HWY_RESTRICT>(HWY_ASSUME_ALIGNED(p, 64));
}
protected:
// Returns pointer to the start of a row.
HWY_INLINE void* VoidRow(const size_t y) const {
#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN
if (y >= ysize_) {
HWY_ABORT("Row(%d) >= %u\n", static_cast<int>(y), ysize_);
}
#endif
void* row = bytes_.get() + y * bytes_per_row_;
return HWY_ASSUME_ALIGNED(row, 64);
}
enum class Padding {
// Allow Load(d, row + x) for x = 0; x < xsize(); x += Lanes(d). Default.
kRoundUp,
// Allow LoadU(d, row + x) for x <= xsize() - 1. This requires an extra
// vector to be initialized. If done by default, this would suppress
// legitimate msan warnings. We therefore require users to explicitly call
// InitializePadding before using unaligned loads (e.g. convolution).
kUnaligned
};
// Initializes the minimum bytes required to suppress msan warnings from
// legitimate (according to Padding mode) vector loads/stores on the right
// border, where some lanes are uninitialized and assumed to be unused.
void InitializePadding(size_t sizeof_t, Padding padding);
// (Members are non-const to enable assignment during move-assignment.)
uint32_t xsize_; // In valid pixels, not including any padding.
uint32_t ysize_;
size_t bytes_per_row_; // Includes padding.
AlignedFreeUniquePtr<uint8_t[]> bytes_;
};
// Single channel, aligned rows separated by padding. T must be POD.
//
// 'Single channel' (one 2D array per channel) simplifies vectorization
// (repeating the same operation on multiple adjacent components) without the
// complexity of a hybrid layout (8 R, 8 G, 8 B, ...). In particular, clients
// can easily iterate over all components in a row and Image requires no
// knowledge of the pixel format beyond the component type "T".
//
// 'Aligned' means each row is aligned to the L1 cache line size. This prevents
// false sharing between two threads operating on adjacent rows.
//
// 'Padding' is still relevant because vectors could potentially be larger than
// a cache line. By rounding up row sizes to the vector size, we allow
// reading/writing ALIGNED vectors whose first lane is a valid sample. This
// avoids needing a separate loop to handle remaining unaligned lanes.
//
// This image layout could also be achieved with a vector and a row accessor
// function, but a class wrapper with support for "deleter" allows wrapping
// existing memory allocated by clients without copying the pixels. It also
// provides convenient accessors for xsize/ysize, which shortens function
// argument lists. Supports move-construction so it can be stored in containers.
template <typename ComponentType>
class Image : public ImageBase {
public:
using T = ComponentType;
Image() = default;
Image(const size_t xsize, const size_t ysize)
: ImageBase(xsize, ysize, sizeof(T)) {}
Image(const size_t xsize, const size_t ysize, size_t bytes_per_row,
void* aligned)
: ImageBase(xsize, ysize, bytes_per_row, aligned) {}
void InitializePaddingForUnalignedAccesses() {
InitializePadding(sizeof(T), Padding::kUnaligned);
}
HWY_INLINE const T* ConstRow(const size_t y) const {
return static_cast<const T*>(VoidRow(y));
}
HWY_INLINE const T* ConstRow(const size_t y) {
return static_cast<const T*>(VoidRow(y));
}
// Returns pointer to non-const. This allows passing const Image* parameters
// when the callee is only supposed to fill the pixels, as opposed to
// allocating or resizing the image.
HWY_INLINE T* MutableRow(const size_t y) const {
return static_cast<T*>(VoidRow(y));
}
HWY_INLINE T* MutableRow(const size_t y) {
return static_cast<T*>(VoidRow(y));
}
// Returns number of pixels (some of which are padding) per row. Useful for
// computing other rows via pointer arithmetic. WARNING: this must
// NOT be used to determine xsize.
HWY_INLINE intptr_t PixelsPerRow() const {
return static_cast<intptr_t>(bytes_per_row_ / sizeof(T));
}
};
using ImageF = Image<float>;
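// Example (illustrative sketch, not part of the original header): fills an
// ImageF row by row via the accessors above. The helper name is hypothetical;
// the padded tail of each row is simply left untouched.
static inline ImageF MakeRampImage(const size_t xsize, const size_t ysize) {
  ImageF image(xsize, ysize);  // allocates aligned, padded rows
  for (size_t y = 0; y < image.ysize(); ++y) {
    float* HWY_RESTRICT row = image.MutableRow(y);
    for (size_t x = 0; x < image.xsize(); ++x) {
      row[x] = static_cast<float>(x + y);
    }
  }
  return image;  // moved out; copying is deliberately deleted
}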
// A bundle of 3 same-sized images. To fill an existing Image3 using
// single-channel producers, we also need access to each const Image*. Const
// prevents breaking the same-size invariant, while still allowing pixels to be
// changed via MutableRow.
template <typename ComponentType>
class Image3 {
public:
using T = ComponentType;
using ImageT = Image<T>;
static constexpr size_t kNumPlanes = 3;
Image3() : planes_{ImageT(), ImageT(), ImageT()} {}
Image3(const size_t xsize, const size_t ysize)
: planes_{ImageT(xsize, ysize), ImageT(xsize, ysize),
ImageT(xsize, ysize)} {}
Image3(Image3&& other) noexcept {
for (size_t i = 0; i < kNumPlanes; i++) {
planes_[i] = std::move(other.planes_[i]);
}
}
Image3(ImageT&& plane0, ImageT&& plane1, ImageT&& plane2) {
if (!SameSize(plane0, plane1) || !SameSize(plane0, plane2)) {
HWY_ABORT(
"Not same size: %d x %d, %d x %d, %d x %d\n",
static_cast<int>(plane0.xsize()), static_cast<int>(plane0.ysize()),
static_cast<int>(plane1.xsize()), static_cast<int>(plane1.ysize()),
static_cast<int>(plane2.xsize()), static_cast<int>(plane2.ysize()));
}
planes_[0] = std::move(plane0);
planes_[1] = std::move(plane1);
planes_[2] = std::move(plane2);
}
// Copy construction/assignment is forbidden to avoid inadvertent copies,
// which can be very expensive. Use CopyImageTo instead.
Image3(const Image3& other) = delete;
Image3& operator=(const Image3& other) = delete;
Image3& operator=(Image3&& other) noexcept {
for (size_t i = 0; i < kNumPlanes; i++) {
planes_[i] = std::move(other.planes_[i]);
}
return *this;
}
HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) const {
return static_cast<const T*>(VoidPlaneRow(c, y));
}
HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) {
return static_cast<const T*>(VoidPlaneRow(c, y));
}
HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) const {
return static_cast<T*>(VoidPlaneRow(c, y));
}
HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) {
return static_cast<T*>(VoidPlaneRow(c, y));
}
HWY_INLINE const ImageT& Plane(size_t idx) const { return planes_[idx]; }
void Swap(Image3& other) {
for (size_t c = 0; c < 3; ++c) {
other.planes_[c].Swap(planes_[c]);
}
}
void ShrinkTo(const size_t xsize, const size_t ysize) {
for (ImageT& plane : planes_) {
plane.ShrinkTo(xsize, ysize);
}
}
// Sizes of all three images are guaranteed to be equal.
HWY_INLINE size_t xsize() const { return planes_[0].xsize(); }
HWY_INLINE size_t ysize() const { return planes_[0].ysize(); }
// Returns offset [bytes] from one row to the next row of the same plane.
// WARNING: this must NOT be used to determine xsize, nor for copying rows -
// the valid xsize may be much less.
HWY_INLINE size_t bytes_per_row() const { return planes_[0].bytes_per_row(); }
// Returns number of pixels (some of which are padding) per row. Useful for
// computing other rows via pointer arithmetic. WARNING: this must NOT be used
// to determine xsize.
HWY_INLINE intptr_t PixelsPerRow() const { return planes_[0].PixelsPerRow(); }
private:
// Returns pointer to the start of a row.
HWY_INLINE void* VoidPlaneRow(const size_t c, const size_t y) const {
#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN
if (c >= kNumPlanes || y >= ysize()) {
HWY_ABORT("PlaneRow(%d, %d) >= %d\n", static_cast<int>(c),
static_cast<int>(y), static_cast<int>(ysize()));
}
#endif
// Use the first plane's stride because the compiler might not realize they
// are all equal. Thus we only need a single multiplication for all planes.
const size_t row_offset = y * planes_[0].bytes_per_row();
const void* row = planes_[c].bytes() + row_offset;
return const_cast<void*>(
    static_cast<const void*>(HWY_ASSUME_ALIGNED(row, HWY_ALIGNMENT)));
}
private:
ImageT planes_[kNumPlanes];
};
using Image3F = Image3<float>;
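// Example (illustrative sketch, not part of the original header): writes a
// constant into one plane of an Image3F; the other planes are untouched. The
// helper name is hypothetical.
static inline void FillPlane(Image3F* image, const size_t c, const float value) {
  for (size_t y = 0; y < image->ysize(); ++y) {
    float* HWY_RESTRICT row = image->MutablePlaneRow(c, y);
    for (size_t x = 0; x < image->xsize(); ++x) {
      row[x] = value;
    }
  }
}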
// Rectangular region in image(s). Factoring this out of Image instead of
// shifting the pointer by x0/y0 allows this to apply to multiple images with
// different resolutions. Can compare size via SameSize(rect1, rect2).
class Rect {
public:
// Most windows are xsize_max * ysize_max, except those on the borders where
// begin + size_max > end.
constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize_max,
size_t ysize_max, size_t xend, size_t yend)
: x0_(xbegin),
y0_(ybegin),
xsize_(ClampedSize(xbegin, xsize_max, xend)),
ysize_(ClampedSize(ybegin, ysize_max, yend)) {}
// Construct with origin and known size (typically from another Rect).
constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize, size_t ysize)
: x0_(xbegin), y0_(ybegin), xsize_(xsize), ysize_(ysize) {}
// Construct a rect that covers a whole image.
template <typename Image>
explicit Rect(const Image& image)
: Rect(0, 0, image.xsize(), image.ysize()) {}
Rect() : Rect(0, 0, 0, 0) {}
Rect(const Rect&) = default;
Rect& operator=(const Rect&) = default;
Rect Subrect(size_t xbegin, size_t ybegin, size_t xsize_max,
size_t ysize_max) {
return Rect(x0_ + xbegin, y0_ + ybegin, xsize_max, ysize_max, x0_ + xsize_,
y0_ + ysize_);
}
template <typename T>
const T* ConstRow(const Image<T>* image, size_t y) const {
return image->ConstRow(y + y0_) + x0_;
}
template <typename T>
T* MutableRow(const Image<T>* image, size_t y) const {
return image->MutableRow(y + y0_) + x0_;
}
template <typename T>
const T* ConstPlaneRow(const Image3<T>& image, size_t c, size_t y) const {
return image.ConstPlaneRow(c, y + y0_) + x0_;
}
template <typename T>
T* MutablePlaneRow(Image3<T>* image, const size_t c, size_t y) const {
return image->MutablePlaneRow(c, y + y0_) + x0_;
}
// Returns true if this Rect fully resides in the given image. ImageT could be
// Image<T> or Image3<T>; however if ImageT is Rect, results are nonsensical.
template <class ImageT>
bool IsInside(const ImageT& image) const {
return (x0_ + xsize_ <= image.xsize()) && (y0_ + ysize_ <= image.ysize());
}
size_t x0() const { return x0_; }
size_t y0() const { return y0_; }
size_t xsize() const { return xsize_; }
size_t ysize() const { return ysize_; }
private:
// Returns size_max, or whatever is left in [begin, end).
static constexpr size_t ClampedSize(size_t begin, size_t size_max,
size_t end) {
return (begin + size_max <= end) ? size_max
: (end > begin ? end - begin : 0);
}
size_t x0_;
size_t y0_;
size_t xsize_;
size_t ysize_;
};
// Works for any image-like input type(s).
template <class Image1, class Image2>
HWY_MAYBE_UNUSED bool SameSize(const Image1& image1, const Image2& image2) {
return image1.xsize() == image2.xsize() && image1.ysize() == image2.ysize();
}
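// Example (illustrative sketch, not part of the original header): sums the
// samples of an ImageF inside a Rect. The Rect only shifts row pointers by
// x0/y0; the image itself is neither copied nor re-allocated. The helper name
// is hypothetical.
static inline double SumOfRect(const Rect& rect, const ImageF& image) {
  HWY_DASSERT(rect.IsInside(image));
  double sum = 0.0;
  for (size_t y = 0; y < rect.ysize(); ++y) {
    const float* HWY_RESTRICT row = rect.ConstRow(&image, y);
    for (size_t x = 0; x < rect.xsize(); ++x) {
      sum += row[x];
    }
  }
  return sum;
}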
// Mirrors out of bounds coordinates and returns valid coordinates unchanged.
// We assume the radius (distance outside the image) is small compared to the
// image size, otherwise this might not terminate.
// The mirror is outside the last column (border pixel is also replicated).
static HWY_INLINE HWY_MAYBE_UNUSED size_t Mirror(int64_t x,
const int64_t xsize) {
HWY_DASSERT(xsize != 0);
// TODO(janwas): replace with branchless version
while (x < 0 || x >= xsize) {
if (x < 0) {
x = -x - 1;
} else {
x = 2 * xsize - 1 - x;
}
}
return static_cast<size_t>(x);
}
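// Example (illustrative sketch, not part of the original header): reads a
// sample with mirrored boundary handling, e.g. for a convolution tap that may
// fall outside the image. x = -1 maps to 0 and x = xsize maps to xsize - 1,
// i.e. the border sample is repeated once. The helper name is hypothetical.
static inline float MirroredAt(const ImageF& image, int64_t x, int64_t y) {
  const size_t xm = Mirror(x, static_cast<int64_t>(image.xsize()));
  const size_t ym = Mirror(y, static_cast<int64_t>(image.ysize()));
  return image.ConstRow(ym)[xm];
}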
// Wrap modes for ensuring X/Y coordinates are in the valid range [0, size):
// Mirrors (repeating the edge pixel once). Useful for convolutions.
struct WrapMirror {
HWY_INLINE size_t operator()(const int64_t coord, const size_t size) const {
return Mirror(coord, static_cast<int64_t>(size));
}
};
// Returns the same coordinate, for when we know "coord" is already valid (e.g.
// interior of an image).
struct WrapUnchanged {
HWY_INLINE size_t operator()(const int64_t coord, size_t /*size*/) const {
return static_cast<size_t>(coord);
}
};
// Similar to Wrap* but for row pointers (reduces Row() multiplications).
class WrapRowMirror {
public:
template <class View>
WrapRowMirror(const View& image, size_t ysize)
: first_row_(image.ConstRow(0)), last_row_(image.ConstRow(ysize - 1)) {}
const float* operator()(const float* const HWY_RESTRICT row,
const int64_t stride) const {
if (row < first_row_) {
const int64_t num_before = first_row_ - row;
// Mirrored; one row before => row 0, two before = row 1, ...
return first_row_ + num_before - stride;
}
if (row > last_row_) {
const int64_t num_after = row - last_row_;
// Mirrored; one row after => last row, two after = last - 1, ...
return last_row_ - num_after + stride;
}
return row;
}
private:
const float* const HWY_RESTRICT first_row_;
const float* const HWY_RESTRICT last_row_;
};
struct WrapRowUnchanged {
HWY_INLINE const float* operator()(const float* const HWY_RESTRICT row,
int64_t /*stride*/) const {
return row;
}
};
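// Example (illustrative sketch, not part of the original header): a 3-tap
// vertical average using WrapRowMirror, so the first and last rows reuse the
// border row instead of reading out of bounds. The helper name is
// hypothetical.
static inline float VerticalAverage3(const ImageF& image, size_t x, size_t y) {
  const WrapRowMirror wrap(image, image.ysize());
  const int64_t stride = image.PixelsPerRow();
  const float* HWY_RESTRICT center = image.ConstRow(y);
  const float* HWY_RESTRICT above = wrap(center - stride, stride);
  const float* HWY_RESTRICT below = wrap(center + stride, stride);
  return (above[x] + center[x] + below[x]) / 3.0f;
}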
} // namespace hwy
#endif // HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_

File diff suppressed because it is too large

View File

@@ -0,0 +1,451 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard (still compiled once per target)
#if defined(HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_) == \
defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#undef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_
#endif
#include <stddef.h>
#include "third_party/highway/hwy/cache_control.h"
#include "third_party/highway/hwy/contrib/thread_pool/thread_pool.h"
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
template <typename TA, typename TB>
TA AddScalar(TA a, TB b) {
return ConvertScalarTo<TA>(ConvertScalarTo<float>(a) +
ConvertScalarTo<float>(b));
}
template <size_t kOuter, size_t kInner, typename T, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const T* HWY_RESTRICT mat,
const T* HWY_RESTRICT vec,
const T* HWY_RESTRICT add, T* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
(void)add;
// Process multiple rows at a time so that we write multiples of a cache line
// to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
// parallelization potential.
constexpr size_t kChunkSize = 64 / sizeof(T);
const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize);
const ScalableTag<T> d;
const size_t N = Lanes(d);
// Required for Stream loop, otherwise we might have partial vectors.
HWY_DASSERT(kChunkSize >= N);
pool.Run(0, num_chunks,
[&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
// MSVC workaround: duplicate to ensure constexpr.
constexpr size_t kChunkSize = 64 / sizeof(T);
// Software write-combining to avoid cache pollution from out.
// Although `out` may be used later, keeping it out of the cache
// now and avoiding RFOs is a consistent 5% overall win.
HWY_ALIGN T buf[kChunkSize];
// Only handle entire chunks here because the Stream is not masked.
// Remaining rows are handled after the pool.Run.
const size_t begin = static_cast<size_t>(chunk * kChunkSize);
for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
auto sum0 = Zero(d);
auto sum1 = Zero(d);
// 4x unrolling barely helps SKX but likely helps Arm V2.
auto sum2 = Zero(d);
auto sum3 = Zero(d);
const T* HWY_RESTRICT row = &mat[(begin + idx_row) * kInner];
size_t i = 0;
// No clear win from prefetching from the next 1..3 rows.
// clflush &row[i] is slow, clflushopt less so but not helping.
HWY_UNROLL(1)
for (; i + 4 * N <= kInner; i += 4 * N) {
const auto a0 = LoadU(d, row + i + 0 * N);
const auto v0 = LoadU(d, vec + i + 0 * N);
sum0 = MulAdd(a0, v0, sum0);
const auto a1 = LoadU(d, row + i + 1 * N);
const auto v1 = LoadU(d, vec + i + 1 * N);
sum1 = MulAdd(a1, v1, sum1);
const auto a2 = LoadU(d, row + i + 2 * N);
const auto v2 = LoadU(d, vec + i + 2 * N);
sum2 = MulAdd(a2, v2, sum2);
const auto a3 = LoadU(d, row + i + 3 * N);
const auto v3 = LoadU(d, vec + i + 3 * N);
sum3 = MulAdd(a3, v3, sum3);
}
// Last entire vectors
for (; i + N <= kInner; i += N) {
const auto a0 = LoadU(d, row + i);
const auto v0 = LoadU(d, vec + i);
sum0 = MulAdd(a0, v0, sum0);
}
const size_t remainder = kInner - i;
if (remainder != 0) {
const auto a0 = LoadN(d, row + i, remainder);
const auto v0 = LoadN(d, vec + i, remainder);
sum1 = MulAdd(a0, v0, sum1);
}
// Reduction tree: sum of all accumulators, then their lanes
sum2 = Add(sum2, sum3);
sum0 = Add(sum0, sum1);
sum0 = Add(sum0, sum2);
buf[idx_row] = ReduceSum(d, sum0);
HWY_IF_CONSTEXPR(kAdd) {
buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
}
} // idx_row
HWY_UNROLL(4) // 1..4 iterations
for (size_t i = 0; i != kChunkSize; i += N) {
Stream(Load(d, buf + i), d, out + begin + i);
}
});
hwy::FlushStream();
// Handle remainder rows which are not a multiple of the chunk size.
for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
auto sum0 = Zero(d);
const T* HWY_RESTRICT row = &mat[r * kInner];
size_t i = 0;
HWY_UNROLL(1)
for (; i + N <= kInner; i += N) {
const auto a0 = LoadU(d, row + i);
const auto v0 = LoadU(d, vec + i);
sum0 = MulAdd(a0, v0, sum0);
}
const size_t remainder = kInner - i;
if (remainder != 0) {
const auto a0 = LoadN(d, row + i, remainder);
const auto v0 = LoadN(d, vec + i, remainder);
sum0 = MulAdd(a0, v0, sum0);
}
out[r] = ReduceSum(d, sum0);
HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
} // r
}
// Multiplies mat with vec, adds add and puts the result in out.
//
// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at
// index i * kInner + j.
//
// vec is a (kInner,)-shaped array.
//
// add is a (kOuter,)-shaped array.
//
// out is a (kOuter,)-shaped array that will be set to mat @ vec + add.
template <size_t kOuter, size_t kInner, typename T>
HWY_NOINLINE void MatVecAdd(const T* HWY_RESTRICT mat,
const T* HWY_RESTRICT vec,
const T* HWY_RESTRICT add, T* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
MatVecAddImpl<kOuter, kInner, T, true>(mat, vec, add, out, pool);
}
// Multiplies mat with vec and puts the result in out.
//
// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at
// index i * kInner + j.
//
// vec is a (kInner,)-shaped array.
//
// out is a (kOuter,)-shaped array that will be set to mat @ vec.
template <size_t kOuter, size_t kInner, typename T>
HWY_NOINLINE void MatVec(const T* HWY_RESTRICT mat, const T* HWY_RESTRICT vec,
T* HWY_RESTRICT out, hwy::ThreadPool& pool) {
MatVecAddImpl<kOuter, kInner, T, false>(mat, vec, /*add=*/nullptr, out, pool);
}
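// Example (illustrative sketch, not part of the original header): multiplies a
// small compile-time-sized matrix by a vector. The shapes are template
// parameters, so mat has kOuter rows of kInner elements, stored row-major. The
// helper name is hypothetical.
static HWY_MAYBE_UNUSED void MatVecExample(hwy::ThreadPool& pool) {
  constexpr size_t kOuter = 8;  // rows of mat == length of out
  constexpr size_t kInner = 4;  // columns of mat == length of vec
  HWY_ALIGN float mat[kOuter * kInner];
  HWY_ALIGN float vec[kInner];
  HWY_ALIGN float out[kOuter];
  for (size_t i = 0; i < kOuter * kInner; ++i) mat[i] = static_cast<float>(i);
  for (size_t j = 0; j < kInner; ++j) vec[j] = 1.0f;
  // out[i] = sum over j of mat[i * kInner + j] * vec[j]
  MatVec<kOuter, kInner>(mat, vec, out, pool);
}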
// This target lacks too many ops required in our implementation, use
// HWY_EMU128 instead.
#if HWY_TARGET != HWY_SCALAR
// Specialization for bf16 matrix, which halves memory bandwidth requirements.
template <size_t kOuter, size_t kInner, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
const float* HWY_RESTRICT vec,
const float* HWY_RESTRICT add,
float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
// Process multiple rows at a time so that we write multiples of a cache line
// to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
// parallelization potential.
constexpr size_t kChunkSize = 64 / sizeof(float);
const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize);
const ScalableTag<float> d;
const Repartition<hwy::bfloat16_t, decltype(d)> d16;
// In the remainder loop, we only process a single f32 vector, so load half
// vectors of bf16 to avoid overrun.
const Half<decltype(d16)> d16h;
using V = Vec<decltype(d)>;
using V16 = Vec<decltype(d16)>;
using V16H = Vec<decltype(d16h)>;
const size_t N = Lanes(d);
// Required for Stream loop, otherwise we might have partial vectors.
HWY_DASSERT(kChunkSize >= N);
pool.Run(0, num_chunks,
[&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
// MSVC workaround: duplicate to ensure constexpr.
constexpr size_t kChunkSize = 64 / sizeof(float);
// Software write-combining to avoid cache pollution from out.
// Although `out` may be used later, keeping it out of the cache
// now and avoiding RFOs is a consistent 5% overall win.
HWY_ALIGN float buf[kChunkSize];
// Only handle entire chunks here because the Stream is not masked.
// Remaining rows are handled after the pool.Run.
const size_t begin = static_cast<size_t>(chunk * kChunkSize);
for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
auto sum0 = Zero(d);
auto sum1 = Zero(d);
// 4x unrolling barely helps SKX but likely helps Arm V2.
auto sum2 = Zero(d);
auto sum3 = Zero(d);
const hwy::bfloat16_t* HWY_RESTRICT row =
&mat[(begin + idx_row) * kInner];
size_t i = 0;
// No clear win from prefetching from the next 1..3 rows.
// clflush &row[i] is slow, clflushopt less so but not helping.
HWY_UNROLL(1)
for (; i + 4 * N <= kInner; i += 4 * N) {
const V16 b0 = LoadU(d16, row + i + 0 * N);
const V a0 = PromoteLowerTo(d, b0);
const V a1 = PromoteUpperTo(d, b0);
const V16 b1 = LoadU(d16, row + i + 2 * N);
const V a2 = PromoteLowerTo(d, b1);
const V a3 = PromoteUpperTo(d, b1);
const V v0 = LoadU(d, vec + i + 0 * N);
sum0 = MulAdd(a0, v0, sum0);
const V v1 = LoadU(d, vec + i + 1 * N);
sum1 = MulAdd(a1, v1, sum1);
const V v2 = LoadU(d, vec + i + 2 * N);
sum2 = MulAdd(a2, v2, sum2);
const V v3 = LoadU(d, vec + i + 3 * N);
sum3 = MulAdd(a3, v3, sum3);
}
// Last entire vectors
for (; i + N <= kInner; i += N) {
const V16H b0 = LoadU(d16h, row + i);
const V a0 = PromoteTo(d, b0);
const V v0 = LoadU(d, vec + i);
sum0 = MulAdd(a0, v0, sum0);
}
const size_t remainder = kInner - i;
if (remainder != 0) {
const V16H b0 = LoadN(d16h, row + i, remainder);
const V a0 = PromoteTo(d, b0);
const V v0 = LoadN(d, vec + i, remainder);
sum1 = MulAdd(a0, v0, sum1);
}
// Reduction tree: sum of all accumulators, then their lanes
sum2 = Add(sum2, sum3);
sum0 = Add(sum0, sum1);
sum0 = Add(sum0, sum2);
buf[idx_row] = ReduceSum(d, sum0);
HWY_IF_CONSTEXPR(kAdd) {
buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
}
} // idx_row
HWY_UNROLL(4) // 1..4 iterations
for (size_t i = 0; i != kChunkSize; i += N) {
Stream(Load(d, buf + i), d, out + begin + i);
}
});
hwy::FlushStream();
// Handle remainder rows which are not a multiple of the chunk size.
for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
auto sum0 = Zero(d);
const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner];
size_t i = 0;
HWY_UNROLL(1)
for (; i + N <= kInner; i += N) {
const V16H b0 = LoadU(d16h, row + i);
const V a0 = PromoteTo(d, b0);
const V v0 = LoadU(d, vec + i);
sum0 = MulAdd(a0, v0, sum0);
}
const size_t remainder = kInner - i;
if (remainder != 0) {
const V16H b0 = LoadN(d16h, row + i, remainder);
const V a0 = PromoteTo(d, b0);
const V v0 = LoadN(d, vec + i, remainder);
sum0 = MulAdd(a0, v0, sum0);
}
out[r] = ReduceSum(d, sum0);
HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
} // r
}
template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat,
const float* HWY_RESTRICT vec,
const float* HWY_RESTRICT add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
MatVecAddImpl<kOuter, kInner, true>(mat, vec, add, out, pool);
}
template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat,
const float* HWY_RESTRICT vec, float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
MatVecAddImpl<kOuter, kInner, false>(mat, vec, /*add=*/nullptr, out, pool);
}
// Both mat and vec are bf16.
template <size_t kOuter, size_t kInner, bool kAdd>
HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
const hwy::bfloat16_t* HWY_RESTRICT vec,
const hwy::bfloat16_t* HWY_RESTRICT add,
float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
// Process multiple rows at a time so that we write multiples of a cache line
// to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
// parallelization potential.
constexpr size_t kChunkSize = 64 / sizeof(bfloat16_t);
const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize);
const ScalableTag<float> df;
const Repartition<hwy::bfloat16_t, decltype(df)> d16;
using V16 = Vec<decltype(d16)>;
const size_t N = Lanes(d16);
// Required for Stream loop, otherwise we might have partial vectors.
HWY_DASSERT(kChunkSize >= N);
pool.Run(0, num_chunks,
[&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
// MSVC workaround: duplicate to ensure constexpr.
constexpr size_t kChunkSize = 64 / sizeof(bfloat16_t);
// Software write-combining to avoid cache pollution from out.
// Although `out` may be used later, keeping it out of the cache
// now and avoiding RFOs is a consistent 5% overall win.
HWY_ALIGN float buf[kChunkSize];
// Only handle entire chunks here because the Stream is not masked.
// Remaining rows are handled after the pool.Run.
const size_t begin = static_cast<size_t>(chunk * kChunkSize);
for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) {
auto sum0 = Zero(df);
auto sum1 = Zero(df);
auto sum2 = Zero(df);
auto sum3 = Zero(df);
const hwy::bfloat16_t* HWY_RESTRICT row =
&mat[(begin + idx_row) * kInner];
size_t i = 0;
// No clear win from prefetching from the next 1..3 rows.
// clflush &row[i] is slow, clflushopt less so but not helping.
HWY_UNROLL(1)
for (; i + 2 * N <= kInner; i += 2 * N) {
const V16 b0 = LoadU(d16, row + i + 0 * N);
const V16 b1 = LoadU(d16, row + i + 1 * N);
const V16 v0 = LoadU(d16, vec + i + 0 * N);
const V16 v1 = LoadU(d16, vec + i + 1 * N);
sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
sum2 = ReorderWidenMulAccumulate(df, b1, v1, sum2, sum3);
}
// Last entire vector
for (; i + N <= kInner; i += N) {
const V16 b0 = LoadU(d16, row + i);
const V16 v0 = LoadU(d16, vec + i);
sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
}
const size_t remainder = kInner - i;
if (remainder != 0) {
const V16 b0 = LoadN(d16, row + i, remainder);
const V16 v0 = LoadN(d16, vec + i, remainder);
sum2 = ReorderWidenMulAccumulate(df, b0, v0, sum2, sum3);
}
// Reduction tree: sum of all accumulators, then their lanes
sum0 = Add(sum0, sum1);
sum2 = Add(sum2, sum3);
sum0 = Add(sum0, sum2);
buf[idx_row] = ReduceSum(df, sum0);
HWY_IF_CONSTEXPR(kAdd) {
buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]);
}
} // idx_row
HWY_UNROLL(4) // 1..4 iterations
for (size_t i = 0; i != kChunkSize; i += N / 2) {
Stream(Load(df, buf + i), df, out + begin + i);
}
});
hwy::FlushStream();
// Handle remainder rows which are not a multiple of the chunk size.
for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
auto sum0 = Zero(df);
auto sum1 = Zero(df);
const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner];
size_t i = 0;
HWY_UNROLL(1)
for (; i + N <= kInner; i += N) {
const V16 b0 = LoadU(d16, row + i);
const V16 v0 = LoadU(d16, vec + i);
sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
}
const size_t remainder = kInner - i;
if (remainder != 0) {
const V16 b0 = LoadN(d16, row + i, remainder);
const V16 v0 = LoadN(d16, vec + i, remainder);
sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1);
}
out[r] = ReduceSum(df, Add(sum0, sum1));
HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); }
} // r
}
template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat,
const hwy::bfloat16_t* HWY_RESTRICT vec,
const hwy::bfloat16_t* HWY_RESTRICT add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
MatVecAddImpl<kOuter, kInner, true>(mat, vec, add, out, pool);
}
template <size_t kOuter, size_t kInner>
HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat,
const hwy::bfloat16_t* HWY_RESTRICT vec,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
MatVecAddImpl<kOuter, kInner, false>(mat, vec, /*add=*/nullptr, out, pool);
}
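// Example (illustrative sketch, not part of the original header): the bf16
// overloads follow the same shape conventions; only the element types differ
// and accumulation is still in f32. The helper name is hypothetical.
static HWY_MAYBE_UNUSED void MatVecBF16Example(hwy::ThreadPool& pool) {
  constexpr size_t kOuter = 8;
  constexpr size_t kInner = 4;
  HWY_ALIGN hwy::bfloat16_t mat[kOuter * kInner];
  HWY_ALIGN float vec[kInner];
  HWY_ALIGN float out[kOuter];
  for (size_t i = 0; i < kOuter * kInner; ++i) {
    mat[i] = ConvertScalarTo<hwy::bfloat16_t>(1.0f);
  }
  for (size_t j = 0; j < kInner; ++j) vec[j] = static_cast<float>(j);
  MatVec<kOuter, kInner>(mat, vec, out, pool);  // every out[i] == 0+1+2+3 == 6
}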
#endif // HWY_TARGET != HWY_SCALAR
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_

View File

@@ -0,0 +1,384 @@
/*
* Original implementation written in 2019
* by David Blackman and Sebastiano Vigna (vigna@acm.org)
* Available at https://prng.di.unimi.it/ with creative commons license:
* To the extent possible under law, the author has dedicated all copyright
* and related and neighboring rights to this software to the public domain
* worldwide. This software is distributed without any warranty.
* See <http://creativecommons.org/publicdomain/zero/1.0/>.
*
* This implementation is a Vector port of the original implementation
* written by Marco Barbone (m.barbone19@imperial.ac.uk).
* I take no credit for the original implementation.
* The code is provided as is and the original license applies.
*/
#if defined(HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_) == \
defined(HWY_TARGET_TOGGLE) // NOLINT
#ifdef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
#undef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
#else
#define HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
#endif
#include <array>
#include <cstdint>
#include <limits>
#include "third_party/highway/hwy/aligned_allocator.h"
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE(); // required if not using HWY_ATTR
namespace hwy {
namespace HWY_NAMESPACE { // required: unique per target
namespace internal {
namespace {
#if HWY_HAVE_FLOAT64
// C++ < 17 does not support hexfloat
#if __cpp_hex_float > 201603L
constexpr double kMulConst = 0x1.0p-53;
#else
constexpr double kMulConst =
0.00000000000000011102230246251565404236316680908203125;
#endif // __cpp_hex_float
#endif // HWY_HAVE_FLOAT64
constexpr std::uint64_t kJump[] = {0x180ec6d33cfd0aba, 0xd5a61266f0c9392c,
0xa9582618e03fc9aa, 0x39abdc4529b1661c};
constexpr std::uint64_t kLongJump[] = {0x76e15d3efefdcbbf, 0xc5004e441c522fb3,
0x77710069854ee241, 0x39109bb02acbe635};
} // namespace
class SplitMix64 {
public:
constexpr explicit SplitMix64(const std::uint64_t state) noexcept
: state_(state) {}
HWY_CXX14_CONSTEXPR std::uint64_t operator()() {
std::uint64_t z = (state_ += 0x9e3779b97f4a7c15);
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
return z ^ (z >> 31);
}
private:
std::uint64_t state_;
};
class Xoshiro {
public:
HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed) noexcept
: state_{} {
SplitMix64 splitMix64{seed};
for (auto &element : state_) {
element = splitMix64();
}
}
HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed,
const std::uint64_t thread_id) noexcept
: Xoshiro(seed) {
for (auto i = UINT64_C(0); i < thread_id; ++i) {
Jump();
}
}
HWY_CXX14_CONSTEXPR std::uint64_t operator()() noexcept { return Next(); }
#if HWY_HAVE_FLOAT64
HWY_CXX14_CONSTEXPR double Uniform() noexcept {
return static_cast<double>(Next() >> 11) * kMulConst;
}
#endif
HWY_CXX14_CONSTEXPR std::array<std::uint64_t, 4> GetState() const {
return {state_[0], state_[1], state_[2], state_[3]};
}
HWY_CXX17_CONSTEXPR void SetState(
std::array<std::uint64_t, 4> state) noexcept {
state_[0] = state[0];
state_[1] = state[1];
state_[2] = state[2];
state_[3] = state[3];
}
static constexpr std::uint64_t StateSize() noexcept { return 4; }
/* This is the jump function for the generator. It is equivalent to 2^128
* calls to next(); it can be used to generate 2^128 non-overlapping
* subsequences for parallel computations. */
HWY_CXX14_CONSTEXPR void Jump() noexcept { Jump(kJump); }
/* This is the long-jump function for the generator. It is equivalent to 2^192
* calls to next(); it can be used to generate 2^64 starting points, from each
* of which jump() will generate 2^64 non-overlapping subsequences for
* parallel distributed computations. */
HWY_CXX14_CONSTEXPR void LongJump() noexcept { Jump(kLongJump); }
private:
std::uint64_t state_[4];
static constexpr std::uint64_t Rotl(const std::uint64_t x, int k) noexcept {
return (x << k) | (x >> (64 - k));
}
HWY_CXX14_CONSTEXPR std::uint64_t Next() noexcept {
const std::uint64_t result = Rotl(state_[0] + state_[3], 23) + state_[0];
const std::uint64_t t = state_[1] << 17;
state_[2] ^= state_[0];
state_[3] ^= state_[1];
state_[1] ^= state_[2];
state_[0] ^= state_[3];
state_[2] ^= t;
state_[3] = Rotl(state_[3], 45);
return result;
}
HWY_CXX14_CONSTEXPR void Jump(const std::uint64_t (&jumpArray)[4]) noexcept {
std::uint64_t s0 = 0;
std::uint64_t s1 = 0;
std::uint64_t s2 = 0;
std::uint64_t s3 = 0;
for (const std::uint64_t i : jumpArray)
for (std::uint_fast8_t b = 0; b < 64; b++) {
if (i & std::uint64_t{1UL} << b) {
s0 ^= state_[0];
s1 ^= state_[1];
s2 ^= state_[2];
s3 ^= state_[3];
}
Next();
}
state_[0] = s0;
state_[1] = s1;
state_[2] = s2;
state_[3] = s3;
}
};
} // namespace internal
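// Example (illustrative sketch, not part of the original header): two scalar
// generators seeded identically but with different thread ids produce
// non-overlapping subsequences, because the second constructor applies Jump()
// once per thread id. The helper name is hypothetical.
static HWY_MAYBE_UNUSED void ScalarXoshiroExample() {
  internal::Xoshiro rng_thread0(/*seed=*/42, /*thread_id=*/0);
  internal::Xoshiro rng_thread1(/*seed=*/42, /*thread_id=*/1);
  const std::uint64_t a = rng_thread0();
  const std::uint64_t b = rng_thread1();
  (void)a;
  (void)b;
}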
class VectorXoshiro {
private:
using VU64 = Vec<ScalableTag<std::uint64_t>>;
using StateType = AlignedNDArray<std::uint64_t, 2>;
#if HWY_HAVE_FLOAT64
using VF64 = Vec<ScalableTag<double>>;
#endif
public:
explicit VectorXoshiro(const std::uint64_t seed,
const std::uint64_t threadNumber = 0)
: state_{{internal::Xoshiro::StateSize(),
Lanes(ScalableTag<std::uint64_t>{})}},
streams{state_.shape().back()} {
internal::Xoshiro xoshiro{seed};
for (std::uint64_t i = 0; i < threadNumber; ++i) {
xoshiro.LongJump();
}
for (size_t i = 0UL; i < streams; ++i) {
const auto state = xoshiro.GetState();
for (size_t j = 0UL; j < internal::Xoshiro::StateSize(); ++j) {
state_[{j}][i] = state[j];
}
xoshiro.Jump();
}
}
HWY_INLINE VU64 operator()() noexcept { return Next(); }
AlignedVector<std::uint64_t> operator()(const std::size_t n) {
AlignedVector<std::uint64_t> result(n);
const ScalableTag<std::uint64_t> tag{};
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
auto s3 = Load(tag, state_[{3}].data());
for (std::uint64_t i = 0; i < n; i += Lanes(tag)) {
const auto next = Update(s0, s1, s2, s3);
Store(next, tag, result.data() + i);
}
Store(s0, tag, state_[{0}].data());
Store(s1, tag, state_[{1}].data());
Store(s2, tag, state_[{2}].data());
Store(s3, tag, state_[{3}].data());
return result;
}
template <std::uint64_t N>
std::array<std::uint64_t, N> operator()() noexcept {
alignas(HWY_ALIGNMENT) std::array<std::uint64_t, N> result;
const ScalableTag<std::uint64_t> tag{};
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
auto s3 = Load(tag, state_[{3}].data());
for (std::uint64_t i = 0; i < N; i += Lanes(tag)) {
const auto next = Update(s0, s1, s2, s3);
Store(next, tag, result.data() + i);
}
Store(s0, tag, state_[{0}].data());
Store(s1, tag, state_[{1}].data());
Store(s2, tag, state_[{2}].data());
Store(s3, tag, state_[{3}].data());
return result;
}
std::uint64_t StateSize() const noexcept {
return streams * internal::Xoshiro::StateSize();
}
const StateType &GetState() const { return state_; }
#if HWY_HAVE_FLOAT64
HWY_INLINE VF64 Uniform() noexcept {
const ScalableTag<double> real_tag{};
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
const auto bits = ShiftRight<11>(Next());
const auto real = ConvertTo(real_tag, bits);
return Mul(real, MUL_VALUE);
}
AlignedVector<double> Uniform(const std::size_t n) {
AlignedVector<double> result(n);
const ScalableTag<std::uint64_t> tag{};
const ScalableTag<double> real_tag{};
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
auto s3 = Load(tag, state_[{3}].data());
for (std::uint64_t i = 0; i < n; i += Lanes(real_tag)) {
const auto next = Update(s0, s1, s2, s3);
const auto bits = ShiftRight<11>(next);
const auto real = ConvertTo(real_tag, bits);
const auto uniform = Mul(real, MUL_VALUE);
Store(uniform, real_tag, result.data() + i);
}
Store(s0, tag, state_[{0}].data());
Store(s1, tag, state_[{1}].data());
Store(s2, tag, state_[{2}].data());
Store(s3, tag, state_[{3}].data());
return result;
}
template <std::uint64_t N>
std::array<double, N> Uniform() noexcept {
alignas(HWY_ALIGNMENT) std::array<double, N> result;
const ScalableTag<std::uint64_t> tag{};
const ScalableTag<double> real_tag{};
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
auto s3 = Load(tag, state_[{3}].data());
for (std::uint64_t i = 0; i < N; i += Lanes(real_tag)) {
const auto next = Update(s0, s1, s2, s3);
const auto bits = ShiftRight<11>(next);
const auto real = ConvertTo(real_tag, bits);
const auto uniform = Mul(real, MUL_VALUE);
Store(uniform, real_tag, result.data() + i);
}
Store(s0, tag, state_[{0}].data());
Store(s1, tag, state_[{1}].data());
Store(s2, tag, state_[{2}].data());
Store(s3, tag, state_[{3}].data());
return result;
}
#endif
private:
StateType state_;
const std::uint64_t streams;
HWY_INLINE static VU64 Update(VU64 &s0, VU64 &s1, VU64 &s2,
VU64 &s3) noexcept {
const auto result = Add(RotateRight<41>(Add(s0, s3)), s0);
const auto t = ShiftLeft<17>(s1);
s2 = Xor(s2, s0);
s3 = Xor(s3, s1);
s1 = Xor(s1, s2);
s0 = Xor(s0, s3);
s2 = Xor(s2, t);
s3 = RotateRight<19>(s3);
return result;
}
HWY_INLINE VU64 Next() noexcept {
const ScalableTag<std::uint64_t> tag{};
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
auto s3 = Load(tag, state_[{3}].data());
auto result = Update(s0, s1, s2, s3);
Store(s0, tag, state_[{0}].data());
Store(s1, tag, state_[{1}].data());
Store(s2, tag, state_[{2}].data());
Store(s3, tag, state_[{3}].data());
return result;
}
};
template <std::uint64_t size = 1024>
class CachedXoshiro {
public:
using result_type = std::uint64_t;
static constexpr result_type(min)() {
return (std::numeric_limits<result_type>::min)();
}
static constexpr result_type(max)() {
return (std::numeric_limits<result_type>::max)();
}
explicit CachedXoshiro(const result_type seed,
const result_type threadNumber = 0)
: generator_{seed, threadNumber},
cache_{generator_.operator()<size>()},
index_{0} {}
result_type operator()() noexcept {
if (HWY_UNLIKELY(index_ == size)) {
cache_ = std::move(generator_.operator()<size>());
index_ = 0;
}
return cache_[index_++];
}
private:
VectorXoshiro generator_;
alignas(HWY_ALIGNMENT) std::array<result_type, size> cache_;
std::size_t index_;
static_assert((size & (size - 1)) == 0 && size != 0,
"only powers of 2 are supported");
};
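// Example (illustrative sketch, not part of the original header): draws scalar
// values via CachedXoshiro and, where f64 is available, a block of uniform
// doubles via VectorXoshiro. Seed and counts are arbitrary; the helper name is
// hypothetical.
static HWY_MAYBE_UNUSED void XoshiroExample() {
  CachedXoshiro<> scalar_rng(/*seed=*/12345);
  const std::uint64_t bits = scalar_rng();  // refills its cache as needed
  (void)bits;
#if HWY_HAVE_FLOAT64
  VectorXoshiro vector_rng(/*seed=*/12345, /*threadNumber=*/0);
  const AlignedVector<double> uniforms = vector_rng.Uniform(1024);  // in [0, 1)
  (void)uniforms;
#endif
}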
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_

View File

@@ -0,0 +1,265 @@
package(
default_applicable_licenses = ["//:license"],
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
# Unused on Bazel builds, where this is not defined/known; Copybara replaces
# usages with an empty list.
COMPAT = [
"//buildenv/target:non_prod", # includes mobile/vendor.
]
cc_library(
name = "intel",
# hdrs = select({
# "//third_party/bazel_platforms/cpu:x86_64": [
# "avx512-16bit-common.h",
# "avx512-16bit-qsort.hpp",
# "avx512-32bit-qsort.hpp",
# "avx512-64bit-common.h",
# "avx512-64bit-qsort.hpp",
# "avx512-common-qsort.h",
# ],
# "//conditions:default": [],
# }),
compatible_with = [],
)
cc_library(
name = "vxsort",
srcs = [
# "vxsort/isa_detection.cpp",
# "vxsort/isa_detection_msvc.cpp",
# "vxsort/isa_detection_sane.cpp",
# "vxsort/machine_traits.avx2.cpp",
# "vxsort/smallsort/avx2_load_mask_tables.cpp",
# "vxsort/smallsort/bitonic_sort.AVX2.double.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX2.float.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX2.int32_t.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX2.int64_t.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX2.uint32_t.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX2.uint64_t.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX512.double.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX512.float.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX512.int32_t.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX512.int64_t.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX512.uint32_t.generated.cpp",
# "vxsort/smallsort/bitonic_sort.AVX512.uint64_t.generated.cpp",
# "vxsort/vxsort_stats.cpp",
],
hdrs = [
# "vxsort/alignment.h",
# "vxsort/defs.h",
# "vxsort/isa_detection.h",
# "vxsort/machine_traits.avx2.h",
# "vxsort/machine_traits.avx512.h",
# "vxsort/machine_traits.h",
# "vxsort/packer.h",
# "vxsort/smallsort/bitonic_sort.AVX2.double.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX2.float.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX2.int32_t.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX2.int64_t.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX2.uint32_t.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX2.uint64_t.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX512.double.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX512.float.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX512.int32_t.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX512.int64_t.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX512.uint32_t.generated.h",
# "vxsort/smallsort/bitonic_sort.AVX512.uint64_t.generated.h",
# "vxsort/smallsort/bitonic_sort.h",
# "vxsort/vxsort.h",
# "vxsort/vxsort_stats.h",
],
compatible_with = [],
textual_hdrs = [
# "vxsort/vxsort_targets_disable.h",
# "vxsort/vxsort_targets_enable_avx2.h",
# "vxsort/vxsort_targets_enable_avx512.h",
],
)
VQSORT_SRCS = [
"vqsort.cc",
# Split into separate files to reduce MSVC build time.
"vqsort_128a.cc",
"vqsort_128d.cc",
"vqsort_f16a.cc",
"vqsort_f16d.cc",
"vqsort_f32a.cc",
"vqsort_f32d.cc",
"vqsort_f64a.cc",
"vqsort_f64d.cc",
"vqsort_i16a.cc",
"vqsort_i16d.cc",
"vqsort_i32a.cc",
"vqsort_i32d.cc",
"vqsort_i64a.cc",
"vqsort_i64d.cc",
"vqsort_kv64a.cc",
"vqsort_kv64d.cc",
"vqsort_kv128a.cc",
"vqsort_kv128d.cc",
"vqsort_u16a.cc",
"vqsort_u16d.cc",
"vqsort_u32a.cc",
"vqsort_u32d.cc",
"vqsort_u64a.cc",
"vqsort_u64d.cc",
]
VQSORT_TEXTUAL_HDRS = [
"shared-inl.h",
"sorting_networks-inl.h",
"traits-inl.h",
"traits128-inl.h",
"vqsort-inl.h",
# Placeholder for internal instrumentation. Do not remove.
]
cc_library(
name = "vqsort",
srcs = VQSORT_SRCS,
hdrs = [
"order.h", # part of public interface, included by vqsort.h
"vqsort.h", # public interface
],
compatible_with = [],
local_defines = ["hwy_contrib_EXPORTS"],
textual_hdrs = VQSORT_TEXTUAL_HDRS,
deps = [
":intel", # required if HAVE_INTEL
":vxsort", # required if HAVE_VXSORT
"//:algo",
"//:hwy",
],
)
# -----------------------------------------------------------------------------
# Internal-only targets
# Same as vqsort, but add HWY_COMPILE_ALL_ATTAINABLE to ensure we cover all
# targets. Do not enable this in the main vqsort because it increases
# compile times.
cc_library(
name = "vqsort_for_test",
srcs = VQSORT_SRCS,
hdrs = [
"order.h", # part of public interface, included by vqsort.h
"vqsort.h", # public interface
],
compatible_with = [],
local_defines = [
"hwy_contrib_EXPORTS",
# Build for all targets because sort_test will dynamic-dispatch to all.
"HWY_COMPILE_ALL_ATTAINABLE",
],
textual_hdrs = VQSORT_TEXTUAL_HDRS,
deps = [
"//:algo",
"//:hwy",
],
)
cc_library(
name = "helpers",
testonly = 1,
textual_hdrs = [
"algo-inl.h",
"result-inl.h",
],
deps = [
":vqsort",
"//:nanobenchmark",
# Required for HAVE_PDQSORT, but that is unused and this is
# unavailable to Bazel builds, hence commented out.
# "//third_party/boost/allowed",
# Avoid ips4o and thus TBB to work around hwloc build failure.
],
)
cc_binary(
name = "print_network",
testonly = 1,
srcs = ["print_network.cc"],
deps = [
":helpers",
":vqsort",
"//:hwy",
],
)
TEST_MAIN = select({
"//:compiler_msvc": [],
"//conditions:default": ["@com_google_googletest//:gtest_main"],
})
cc_test(
name = "sort_unit_test",
size = "small",
srcs = ["sort_unit_test.cc"],
# Do not enable fully_static_link (pthread crash on bazel)
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":helpers",
":vqsort_for_test",
"//:hwy",
"//:hwy_test_util",
] + TEST_MAIN,
)
cc_test(
name = "sort_test",
size = "medium",
timeout = "long",
srcs = ["sort_test.cc"],
# Do not enable fully_static_link (pthread crash on bazel)
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":helpers",
":vqsort_for_test",
"//:hwy",
"//:hwy_test_util",
"//:thread_pool",
"//:topology",
] + TEST_MAIN,
)
cc_test(
name = "bench_sort",
size = "medium",
srcs = ["bench_sort.cc"],
# Do not enable fully_static_link (pthread crash on bazel)
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":helpers",
":vqsort",
"//:hwy",
"//:hwy_test_util",
"//:nanobenchmark",
"//:thread_pool",
] + TEST_MAIN,
)
cc_binary(
name = "bench_parallel",
testonly = 1,
srcs = ["bench_parallel.cc"],
# Do not enable fully_static_link (pthread crash on bazel)
local_defines = ["HWY_IS_TEST"],
deps = [
":helpers",
":vqsort",
"//:hwy",
"//:hwy_test_util",
"//:nanobenchmark",
] + TEST_MAIN,
)

View File

@@ -0,0 +1,620 @@
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Normal include guard for target-independent parts
#ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <algorithm> // std::sort
#include <functional> // std::less, std::greater
#include <vector>
#include "third_party/highway/hwy/contrib/sort/vqsort.h"
#include "third_party/highway/hwy/highway.h"
#include "third_party/highway/hwy/print.h"
// Third-party algorithms
#define HAVE_AVX2SORT 0
#define HAVE_IPS4O 0
// When enabling, consider changing max_threads (required for Table 1a)
#define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1)
#define HAVE_PDQSORT 0
#define HAVE_SORT512 0
#define HAVE_VXSORT 0
#if HWY_ARCH_X86
#define HAVE_INTEL 0
#else
#define HAVE_INTEL 0
#endif
#if HAVE_PARALLEL_IPS4O
#include <thread> // NOLINT
#endif
#if HAVE_AVX2SORT
HWY_PUSH_ATTRIBUTES("avx2,avx")
#include "avx2sort.h" //NOLINT
HWY_POP_ATTRIBUTES
#endif
#if HAVE_IPS4O || HAVE_PARALLEL_IPS4O
#include "third_party/ips4o/include/ips4o.hpp"
#include "third_party/ips4o/include/ips4o/thread_pool.hpp"
#endif
#if HAVE_PDQSORT
#include "third_party/boost/allowed/sort/sort.hpp"
#endif
#if HAVE_SORT512
#include "sort512.h" //NOLINT
#endif
// vxsort is difficult to compile for multiple targets because it also uses
// .cpp files, and we'd also have to #undef its include guards. Instead, compile
// only for AVX2 or AVX3 depending on this macro.
#define VXSORT_AVX3 1
#if HAVE_VXSORT
// inlined from vxsort_targets_enable_avx512 (must close before end of header)
#ifdef __GNUC__
#ifdef __clang__
#if VXSORT_AVX3
#pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \
apply_to = any(function))
#else
#pragma clang attribute push(__attribute__((target("avx2"))), \
apply_to = any(function))
#endif // VXSORT_AVX3
#else
#pragma GCC push_options
#if VXSORT_AVX3
#pragma GCC target("avx512f,avx512dq")
#else
#pragma GCC target("avx2")
#endif // VXSORT_AVX3
#endif
#endif
#if VXSORT_AVX3
#include "vxsort/machine_traits.avx512.h"
#else
#include "vxsort/machine_traits.avx2.h"
#endif // VXSORT_AVX3
#include "vxsort/vxsort.h"
#ifdef __GNUC__
#ifdef __clang__
#pragma clang attribute pop
#else
#pragma GCC pop_options
#endif
#endif
#endif // HAVE_VXSORT
namespace hwy {
enum class Dist { kUniform8, kUniform16, kUniform32 };
static inline std::vector<Dist> AllDist() {
// Also include lower-entropy distributions to test MaybePartitionTwoValue.
return {Dist::kUniform8, /*Dist::kUniform16,*/ Dist::kUniform32};
}
static inline const char* DistName(Dist dist) {
switch (dist) {
case Dist::kUniform8:
return "uniform8";
case Dist::kUniform16:
return "uniform16";
case Dist::kUniform32:
return "uniform32";
}
return "unreachable";
}
template <typename T>
class InputStats {
public:
void Notify(T value) {
min_ = HWY_MIN(min_, value);
max_ = HWY_MAX(max_, value);
// Converting to integer would truncate floats, multiplying to save digits
// risks overflow especially when casting, so instead take the sum of the
// bit representations as the checksum.
uint64_t bits = 0;
static_assert(sizeof(T) <= 8, "Expected a built-in type");
CopyBytes<sizeof(T)>(&value, &bits); // not same size
sum_ += bits;
count_ += 1;
}
bool operator==(const InputStats& other) const {
char type_name[100];
detail::TypeName(hwy::detail::MakeTypeInfo<T>(), 1, type_name);
if (count_ != other.count_) {
HWY_ABORT("Sort %s: count %d vs %d\n", type_name,
static_cast<int>(count_), static_cast<int>(other.count_));
}
if (min_ != other.min_ || max_ != other.max_) {
HWY_ABORT("Sort %s: minmax %f/%f vs %f/%f\n", type_name,
static_cast<double>(min_), static_cast<double>(max_),
static_cast<double>(other.min_),
static_cast<double>(other.max_));
}
// Sum helps detect duplicated/lost values
if (sum_ != other.sum_) {
HWY_ABORT("Sort %s: Sum mismatch %g %g; min %g max %g\n", type_name,
static_cast<double>(sum_), static_cast<double>(other.sum_),
static_cast<double>(min_), static_cast<double>(max_));
}
return true;
}
private:
T min_ = hwy::HighestValue<T>();
T max_ = hwy::LowestValue<T>();
uint64_t sum_ = 0;
size_t count_ = 0;
};
enum class Algo {
#if HAVE_INTEL
kIntel,
#endif
#if HAVE_AVX2SORT
kSEA,
#endif
#if HAVE_IPS4O
kIPS4O,
#endif
#if HAVE_PARALLEL_IPS4O
kParallelIPS4O,
#endif
#if HAVE_PDQSORT
kPDQ,
#endif
#if HAVE_SORT512
kSort512,
#endif
#if HAVE_VXSORT
kVXSort,
#endif
kStdSort,
kStdSelect,
kStdPartialSort,
kVQSort,
kVQPartialSort,
kVQSelect,
kHeapSort,
kHeapPartialSort,
kHeapSelect,
};
static inline bool IsVQ(Algo algo) {
switch (algo) {
case Algo::kVQSort:
case Algo::kVQPartialSort:
case Algo::kVQSelect:
return true;
default:
return false;
}
}
static inline bool IsSelect(Algo algo) {
switch (algo) {
case Algo::kStdSelect:
case Algo::kVQSelect:
case Algo::kHeapSelect:
return true;
default:
return false;
}
}
static inline bool IsPartialSort(Algo algo) {
switch (algo) {
case Algo::kStdPartialSort:
case Algo::kVQPartialSort:
case Algo::kHeapPartialSort:
return true;
default:
return false;
}
}
static inline Algo ReferenceAlgoFor(Algo algo) {
if (IsPartialSort(algo)) return Algo::kStdPartialSort;
#if HAVE_PDQSORT
return Algo::kPDQ;
#else
return Algo::kStdSort;
#endif
}
static inline const char* AlgoName(Algo algo) {
switch (algo) {
#if HAVE_INTEL
case Algo::kIntel:
return "intel";
#endif
#if HAVE_AVX2SORT
case Algo::kSEA:
return "sea";
#endif
#if HAVE_IPS4O
case Algo::kIPS4O:
return "ips4o";
#endif
#if HAVE_PARALLEL_IPS4O
case Algo::kParallelIPS4O:
return "par_ips4o";
#endif
#if HAVE_PDQSORT
case Algo::kPDQ:
return "pdq";
#endif
#if HAVE_SORT512
case Algo::kSort512:
return "sort512";
#endif
#if HAVE_VXSORT
case Algo::kVXSort:
return "vxsort";
#endif
case Algo::kStdSort:
return "std";
case Algo::kStdPartialSort:
return "std_partial";
case Algo::kStdSelect:
return "std_select";
case Algo::kVQSort:
return "vq";
case Algo::kVQPartialSort:
return "vq_partial";
case Algo::kVQSelect:
return "vq_select";
case Algo::kHeapSort:
return "heap";
case Algo::kHeapPartialSort:
return "heap_partial";
case Algo::kHeapSelect:
return "heap_select";
}
return "unreachable";
}
} // namespace hwy
#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
// Per-target
// clang-format off
#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == defined(HWY_TARGET_TOGGLE) // NOLINT
#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
#else
#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
#endif
// clang-format on
#include "third_party/highway/hwy/aligned_allocator.h"
#include "third_party/highway/hwy/contrib/sort/traits-inl.h"
#include "third_party/highway/hwy/contrib/sort/traits128-inl.h"
#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h" // HeapSort
HWY_BEFORE_NAMESPACE();
// Requires target pragma set by HWY_BEFORE_NAMESPACE
#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3
// #include "avx512-16bit-qsort.hpp" // requires AVX512-VBMI2
#include "avx512-32bit-qsort.hpp"
#include "avx512-64bit-qsort.hpp"
#endif
namespace hwy {
namespace HWY_NAMESPACE {
#if HAVE_INTEL || HAVE_VXSORT // only supports ascending order
template <typename T>
using OtherOrder = detail::OrderAscending<T>;
#else
template <typename T>
using OtherOrder = detail::OrderDescending<T>;
#endif
class Xorshift128Plus {
static HWY_INLINE uint64_t SplitMix64(uint64_t z) {
z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
return z ^ (z >> 31);
}
public:
// Generates two vectors of 64-bit seeds via SplitMix64 and stores into
// `seeds`. Generating these afresh in each ChoosePivot is too expensive.
template <class DU64>
static void GenerateSeeds(DU64 du64, TFromD<DU64>* HWY_RESTRICT seeds) {
seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull);
for (size_t i = 1; i < 2 * Lanes(du64); ++i) {
seeds[i] = SplitMix64(seeds[i - 1]);
}
}
// Need to pass in the state because vectors cannot be class members.
template <class VU64>
static VU64 RandomBits(VU64& state0, VU64& state1) {
VU64 s1 = state0;
VU64 s0 = state1;
const VU64 bits = Add(s1, s0);
state0 = s0;
s1 = Xor(s1, ShiftLeft<23>(s1));
state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0))));
return bits;
}
};
template <class D, class VU64, HWY_IF_NOT_FLOAT_D(D)>
Vec<D> RandomValues(D d, VU64& s0, VU64& s1, const VU64 mask) {
const VU64 bits = Xorshift128Plus::RandomBits(s0, s1);
return BitCast(d, And(bits, mask));
}
// It is important to avoid denormals, which are flushed to zero by SIMD but not
// scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD.
template <class DF, class VU64, HWY_IF_FLOAT_D(DF)>
Vec<DF> RandomValues(DF df, VU64& s0, VU64& s1, const VU64 mask) {
using TF = TFromD<DF>;
const RebindToUnsigned<decltype(df)> du;
using VU = Vec<decltype(du)>;
const VU64 bits64 = And(Xorshift128Plus::RandomBits(s0, s1), mask);
#if HWY_TARGET == HWY_SCALAR // Cannot repartition u64 to smaller types
using TU = MakeUnsigned<TF>;
const VU bits = Set(du, static_cast<TU>(GetLane(bits64) & LimitsMax<TU>()));
#else
const VU bits = BitCast(du, bits64);
#endif
// Avoid NaN/denormal by only generating values in [1, 2), i.e. random
// mantissas with the exponent taken from the representation of 1.0.
const VU k1 = BitCast(du, Set(df, TF{1.0}));
const VU mantissa_mask = Set(du, MantissaMask<TF>());
const VU representation = OrAnd(k1, bits, mantissa_mask);
return BitCast(df, representation);
}
template <class DU64>
Vec<DU64> MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) {
switch (sizeof_t) {
case 2:
return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull
: 0xFFFFFFFFFFFFFFFFull);
case 4:
return Set(du64, (dist == Dist::kUniform8) ? 0x000000FF000000FFull
: (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull
: 0xFFFFFFFFFFFFFFFFull);
case 8:
return Set(du64, (dist == Dist::kUniform8) ? 0x00000000000000FFull
: (dist == Dist::kUniform16) ? 0x000000000000FFFFull
: 0x00000000FFFFFFFFull);
default:
HWY_ABORT("Logic error");
return Zero(du64);
}
}
template <typename T>
InputStats<T> GenerateInput(const Dist dist, T* v, size_t num_lanes) {
SortTag<uint64_t> du64;
using VU64 = Vec<decltype(du64)>;
const size_t N64 = Lanes(du64);
auto seeds = hwy::AllocateAligned<uint64_t>(2 * N64);
Xorshift128Plus::GenerateSeeds(du64, seeds.get());
VU64 s0 = Load(du64, seeds.get());
VU64 s1 = Load(du64, seeds.get() + N64);
#if HWY_TARGET == HWY_SCALAR
const Sisd<T> d;
#else
const Repartition<T, decltype(du64)> d;
#endif
using V = Vec<decltype(d)>;
const size_t N = Lanes(d);
const VU64 mask = MaskForDist(du64, dist, sizeof(T));
auto buf = hwy::AllocateAligned<T>(N);
size_t i = 0;
for (; i + N <= num_lanes; i += N) {
const V values = RandomValues(d, s0, s1, mask);
StoreU(values, d, v + i);
}
if (i < num_lanes) {
const V values = RandomValues(d, s0, s1, mask);
StoreU(values, d, buf.get());
CopyBytes(buf.get(), v + i, (num_lanes - i) * sizeof(T));
}
InputStats<T> input_stats;
for (size_t i = 0; i < num_lanes; ++i) {
input_stats.Notify(v[i]);
}
return input_stats;
}
struct SharedState {
#if HAVE_PARALLEL_IPS4O
const unsigned max_threads = hwy::LimitsMax<unsigned>(); // 16 for Table 1a
ips4o::StdThreadPool pool{static_cast<int>(
HWY_MIN(max_threads, std::thread::hardware_concurrency() / 2))};
#endif
};
// Adapters from Run's num_keys to vqsort-inl.h num_lanes.
template <typename KeyType, class Order>
void CallHeapSort(KeyType* keys, const size_t num_keys, Order) {
const detail::MakeTraits<KeyType, Order> st;
using LaneType = typename decltype(st)::LaneType;
return detail::HeapSort(st, reinterpret_cast<LaneType*>(keys),
num_keys * st.LanesPerKey());
}
template <typename KeyType, class Order>
void CallHeapPartialSort(KeyType* keys, const size_t num_keys,
const size_t k_keys, Order) {
const detail::MakeTraits<KeyType, Order> st;
using LaneType = typename decltype(st)::LaneType;
detail::HeapPartialSort(st, reinterpret_cast<LaneType*>(keys),
num_keys * st.LanesPerKey(),
k_keys * st.LanesPerKey());
}
template <typename KeyType, class Order>
void CallHeapSelect(KeyType* keys, const size_t num_keys, const size_t k_keys,
Order) {
const detail::MakeTraits<KeyType, Order> st;
using LaneType = typename decltype(st)::LaneType;
detail::HeapSelect(st, reinterpret_cast<LaneType*>(keys),
num_keys * st.LanesPerKey(), k_keys * st.LanesPerKey());
}
template <typename KeyType, class Order>
void Run(Algo algo, KeyType* inout, size_t num_keys, SharedState& shared,
size_t /*thread*/, size_t k_keys, Order) {
const std::less<KeyType> less;
const std::greater<KeyType> greater;
constexpr bool kAscending = Order::IsAscending();
#if !HAVE_PARALLEL_IPS4O
(void)shared;
#endif
switch (algo) {
#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3
case Algo::kIntel:
return avx512_qsort<KeyType>(inout, static_cast<int64_t>(num_keys));
#endif
#if HAVE_AVX2SORT
case Algo::kSEA:
return avx2::quicksort(inout, static_cast<int>(num_keys));
#endif
#if HAVE_IPS4O
case Algo::kIPS4O:
if (kAscending) {
return ips4o::sort(inout, inout + num_keys, less);
} else {
return ips4o::sort(inout, inout + num_keys, greater);
}
#endif
#if HAVE_PARALLEL_IPS4O
case Algo::kParallelIPS4O:
if (kAscending) {
return ips4o::parallel::sort(inout, inout + num_keys, less,
shared.pool);
} else {
return ips4o::parallel::sort(inout, inout + num_keys, greater,
shared.pool);
}
#endif
#if HAVE_SORT512
case Algo::kSort512:
HWY_ABORT("not supported");
// return Sort512::Sort(inout, num_keys);
#endif
#if HAVE_PDQSORT
case Algo::kPDQ:
if (kAscending) {
return boost::sort::pdqsort_branchless(inout, inout + num_keys, less);
} else {
return boost::sort::pdqsort_branchless(inout, inout + num_keys,
greater);
}
#endif
#if HAVE_VXSORT
case Algo::kVXSort: {
#if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \
(!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2)
HWY_WARN("Do not call for target %s\n", hwy::TargetName(HWY_TARGET));
return;
#else
#if VXSORT_AVX3
vxsort::vxsort<KeyType, vxsort::AVX512> vx;
#else
vxsort::vxsort<KeyType, vxsort::AVX2> vx;
#endif
if (kAscending) {
return vx.sort(inout, inout + num_keys - 1);
} else {
HWY_WARN("Skipping VX - does not support descending order\n");
return;
}
#endif // enabled for this target
}
#endif // HAVE_VXSORT
case Algo::kStdSort:
if (kAscending) {
return std::sort(inout, inout + num_keys, less);
} else {
return std::sort(inout, inout + num_keys, greater);
}
case Algo::kStdPartialSort:
if (kAscending) {
return std::partial_sort(inout, inout + k_keys, inout + num_keys, less);
} else {
return std::partial_sort(inout, inout + k_keys, inout + num_keys,
greater);
}
case Algo::kStdSelect:
if (kAscending) {
return std::nth_element(inout, inout + k_keys, inout + num_keys, less);
} else {
return std::nth_element(inout, inout + k_keys, inout + num_keys,
greater);
}
case Algo::kVQSort:
return VQSort(inout, num_keys, Order());
case Algo::kVQPartialSort:
return VQPartialSort(inout, num_keys, k_keys, Order());
case Algo::kVQSelect:
return VQSelect(inout, num_keys, k_keys, Order());
case Algo::kHeapSort:
return CallHeapSort(inout, num_keys, Order());
case Algo::kHeapPartialSort:
return CallHeapPartialSort(inout, num_keys, k_keys, Order());
case Algo::kHeapSelect:
return CallHeapSelect(inout, num_keys, k_keys, Order());
default:
HWY_ABORT("Not implemented");
}
}
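// A minimal usage sketch of Run() above (illustrative only, not the library's
// own benchmark harness). The key type, key count and the Dist::kUniform32
// enumerator are assumptions of this sketch.
template <typename KeyType = uint32_t>
void RunExampleSketch(size_t num_keys = 1000) {
  SharedState shared;
  auto keys = hwy::AllocateAligned<KeyType>(num_keys);
  HWY_ASSERT(keys);
  (void)GenerateInput(Dist::kUniform32, keys.get(), num_keys);
  // Full ascending sort; k_keys is ignored for Algo::kVQSort but passed anyway.
  Run(Algo::kVQSort, keys.get(), num_keys, shared, /*thread=*/0,
      /*k_keys=*/num_keys, SortAscending());
}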
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE


@@ -0,0 +1,34 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Tag arguments that determine the sort order. Used by both vqsort.h and the
// VQSortStatic in vqsort-inl.h. Moved to a separate header so that the latter
// can be used without pulling in the dllimport statements in vqsort.h.
#ifndef HIGHWAY_HWY_CONTRIB_SORT_ORDER_H_
#define HIGHWAY_HWY_CONTRIB_SORT_ORDER_H_
namespace hwy {
struct SortAscending {
static constexpr bool IsAscending() { return true; }
};
struct SortDescending {
static constexpr bool IsAscending() { return false; }
};
} // namespace hwy
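// The tags carry the sort direction purely in the type: callers pass an
// instance (e.g. to VQSort or Run) and the sorter queries IsAscending() at
// compile time. A minimal sanity check of that contract:
static_assert(hwy::SortAscending::IsAscending(), "tag must select ascending");
static_assert(!hwy::SortDescending::IsAscending(), "tag must select descending");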
#endif // HIGHWAY_HWY_CONTRIB_SORT_ORDER_H_


@@ -0,0 +1,291 @@
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "third_party/highway/hwy/contrib/sort/algo-inl.h"
// Normal include guard for non-SIMD parts
#ifndef HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_
#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_
#include <stdint.h>
#include <stdio.h>
#include <time.h>
#include <algorithm> // std::sort
#include <string>
#include <vector>
#include "third_party/highway/hwy/aligned_allocator.h"
#include "third_party/highway/hwy/base.h"
#include "third_party/highway/hwy/contrib/sort/order.h"
#include "third_party/highway/hwy/per_target.h" // DispatchedTarget
#include "third_party/highway/hwy/targets.h" // TargetName
namespace hwy {
// Returns trimmed mean (we don't want to run an out-of-L3-cache sort often
// enough for the mode to be reliable).
static inline double SummarizeMeasurements(std::vector<double>& seconds) {
std::sort(seconds.begin(), seconds.end());
double sum = 0;
int count = 0;
const size_t num = seconds.size();
for (size_t i = num / 4; i < num / 2; ++i) {
sum += seconds[i];
count += 1;
}
return sum / count;
}
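// Worked example (hypothetical timings): with eight sorted measurements
// {1.0, 1.1, 1.2, 1.3, 5.0, 5.1, 5.2, 9.0}, the loop above averages
// seconds[2] and seconds[3], i.e. the second quartile, returning
// (1.2 + 1.3) / 2 = 1.25 and discarding both the fastest and slowest runs.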
struct SortResult {
SortResult() {}
SortResult(const Algo algo, Dist dist, size_t num_keys, size_t num_threads,
double sec, size_t sizeof_key, const char* key_name)
: target(DispatchedTarget()),
algo(algo),
dist(dist),
num_keys(num_keys),
num_threads(num_threads),
sec(sec),
sizeof_key(sizeof_key),
key_name(key_name) {}
void Print() const {
const double bytes = static_cast<double>(num_keys) *
static_cast<double>(num_threads) *
static_cast<double>(sizeof_key);
printf("%10s: %12s: %7s: %9s: %05g %4.0f MB/s (%2zu threads)\n",
hwy::TargetName(target), AlgoName(algo), key_name.c_str(),
DistName(dist), static_cast<double>(num_keys), bytes * 1E-6 / sec,
num_threads);
}
int64_t target;
Algo algo;
Dist dist;
size_t num_keys = 0;
size_t num_threads = 0;
double sec = 0.0;
size_t sizeof_key = 0;
std::string key_name;
};
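// Worked example for Print() above (hypothetical numbers): 1000000 u32 keys on
// one thread sorted in 0.01 s gives bytes = 4E6, so the reported throughput is
// 4E6 * 1E-6 / 0.01 = 400 MB/s.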
} // namespace hwy
#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_
// Per-target
#if defined(HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE) == \
defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE
#undef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE
#else
#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE
#endif
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
// Copies the input, and compares results to that of a reference algorithm.
template <class Traits>
class ReferenceSortVerifier {
using LaneType = typename Traits::LaneType;
using KeyType = typename Traits::KeyType;
using Order = typename Traits::Order;
static constexpr bool kAscending = Order::IsAscending();
static constexpr size_t kLPK = Traits().LanesPerKey();
public:
ReferenceSortVerifier(const LaneType* in_lanes, size_t num_lanes) {
num_lanes_ = num_lanes;
num_keys_ = num_lanes / kLPK;
in_lanes_ = hwy::AllocateAligned<LaneType>(num_lanes);
HWY_ASSERT(in_lanes_);
CopyBytes(in_lanes, in_lanes_.get(), num_lanes * sizeof(LaneType));
}
// For full sorts, k_keys == num_keys.
void operator()(Algo algo, const LaneType* out_lanes, size_t k_keys) {
SharedState shared;
const Traits st;
const CappedTag<LaneType, kLPK> d;
HWY_ASSERT(hwy::IsAligned(in_lanes_.get(), sizeof(KeyType)));
KeyType* in_keys = HWY_RCAST_ALIGNED(KeyType*, in_lanes_.get());
char caption[10];
const char* algo_type = IsPartialSort(algo) ? "PartialSort" : "Sort";
HWY_ASSERT(k_keys <= num_keys_);
Run(ReferenceAlgoFor(algo), in_keys, num_keys_, shared, /*thread=*/0,
k_keys, Order());
if (IsSelect(algo)) {
// Print lanes centered around k_keys.
if (VQSORT_PRINT >= 3) {
const size_t begin_lane = k_keys < 3 ? 0 : (k_keys - 3) * kLPK;
const size_t end_lane = HWY_MIN(num_lanes_, (k_keys + 3) * kLPK);
fprintf(stderr, "\nExpected:\n");
for (size_t i = begin_lane; i < end_lane; i += kLPK) {
snprintf(caption, sizeof(caption), "%4zu ", i / kLPK);
Print(d, caption, st.SetKey(d, &in_lanes_[i]));
}
fprintf(stderr, "\n\nActual:\n");
for (size_t i = begin_lane; i < end_lane; i += kLPK) {
snprintf(caption, sizeof(caption), "%4zu ", i / kLPK);
Print(d, caption, st.SetKey(d, &out_lanes[i]));
}
fprintf(stderr, "\n\n");
}
// At k_keys: should be equivalent, i.e. neither a < b nor b < a.
// SortOrderVerifier will also check the ordering of the rest of the keys.
const size_t k = k_keys * kLPK;
if (st.Compare1(&in_lanes_[k], &out_lanes[k]) ||
st.Compare1(&out_lanes[k], &in_lanes_[k])) {
Print(d, "Expected", st.SetKey(d, &in_lanes_[k]));
Print(d, " Actual", st.SetKey(d, &out_lanes[k]));
HWY_ABORT("Select %s asc=%d: mismatch at k_keys=%zu, num_keys=%zu\n",
st.KeyString(), kAscending, k_keys, num_keys_);
}
} else {
if (VQSORT_PRINT >= 3) {
const size_t lanes_to_print = HWY_MIN(40, k_keys * kLPK);
fprintf(stderr, "\nExpected:\n");
for (size_t i = 0; i < lanes_to_print; i += kLPK) {
snprintf(caption, sizeof(caption), "%4zu ", i / kLPK);
Print(d, caption, st.SetKey(d, &in_lanes_[i]));
}
fprintf(stderr, "\n\nActual:\n");
for (size_t i = 0; i < lanes_to_print; i += kLPK) {
snprintf(caption, sizeof(caption), "%4zu ", i / kLPK);
Print(d, caption, st.SetKey(d, &out_lanes[i]));
}
fprintf(stderr, "\n\n");
}
// Full or partial sort: all elements up to k_keys are equivalent to the
// reference sort. SortOrderVerifier also checks the output's ordering.
for (size_t i = 0; i < k_keys * kLPK; i += kLPK) {
// All up to k_keys should be equivalent, i.e. neither a < b nor b < a.
if (st.Compare1(&in_lanes_[i], &out_lanes[i]) ||
st.Compare1(&out_lanes[i], &in_lanes_[i])) {
Print(d, "Expected", st.SetKey(d, &in_lanes_[i]));
Print(d, " Actual", st.SetKey(d, &out_lanes[i]));
HWY_ABORT("%s %s asc=%d: mismatch at %zu, k_keys=%zu, num_keys=%zu\n",
algo_type, st.KeyString(), kAscending, i / kLPK, k_keys,
num_keys_);
}
}
}
}
private:
hwy::AlignedFreeUniquePtr<LaneType[]> in_lanes_;
size_t num_lanes_;
size_t num_keys_;
};
// Faster than ReferenceSortVerifier, for use in bench_sort. Only verifies
// order, without running a slow reference sorter. This means it can't verify
// Select places the correct key at `k_keys`, nor that input and output keys are
// the same.
template <class Traits>
class SortOrderVerifier {
using LaneType = typename Traits::LaneType;
using Order = typename Traits::Order;
static constexpr bool kAscending = Order::IsAscending();
static constexpr size_t kLPK = Traits().LanesPerKey();
public:
void operator()(Algo algo, const InputStats<LaneType>& input_stats,
const LaneType* output, size_t num_keys, size_t k_keys) {
if (IsSelect(algo)) {
CheckSelectOrder(input_stats, output, num_keys, k_keys);
} else {
CheckSortedOrder(algo, input_stats, output, num_keys, k_keys);
}
}
private:
// For full or partial sorts: ensures keys are in sorted order.
void CheckSortedOrder(const Algo algo,
const InputStats<LaneType>& input_stats,
const LaneType* output, const size_t num_keys,
const size_t k_keys) {
const Traits st;
const CappedTag<LaneType, kLPK> d;
const size_t num_lanes = num_keys * kLPK;
const size_t k = k_keys * kLPK;
const char* algo_type = IsPartialSort(algo) ? "PartialSort" : "Sort";
InputStats<LaneType> output_stats;
// Even for partial sorts, loop over all keys to verify none disappeared.
for (size_t i = 0; i < num_lanes - kLPK; i += kLPK) {
output_stats.Notify(output[i]);
if (kLPK == 2) output_stats.Notify(output[i + 1]);
// Only check the first k_keys (== num_keys for a full sort).
// Reverse order instead of checking !Compare1 so we accept equal keys.
if (i < k - kLPK && st.Compare1(output + i + kLPK, output + i)) {
Print(d, " cur", st.SetKey(d, &output[i]));
Print(d, "next", st.SetKey(d, &output[i + kLPK]));
HWY_ABORT(
"%s %s asc=%d: wrong order at %zu, k_keys=%zu, num_keys=%zu\n",
algo_type, st.KeyString(), kAscending, i / kLPK, k_keys, num_keys);
}
}
output_stats.Notify(output[num_lanes - kLPK]);
if (kLPK == 2) output_stats.Notify(output[num_lanes - kLPK + 1]);
HWY_ASSERT(input_stats == output_stats);
}
// Ensures keys below index k_keys are less, and all above are greater.
void CheckSelectOrder(const InputStats<LaneType>& input_stats,
const LaneType* output, const size_t num_keys,
const size_t k_keys) {
const Traits st;
const CappedTag<LaneType, kLPK> d;
const size_t num_lanes = num_keys * kLPK;
const size_t k = k_keys * kLPK;
InputStats<LaneType> output_stats;
for (size_t i = 0; i < num_lanes - kLPK; i += kLPK) {
output_stats.Notify(output[i]);
if (kLPK == 2) output_stats.Notify(output[i + 1]);
// Reverse order instead of checking !Compare1 so we accept equal keys.
if (i < k ? st.Compare1(output + k, output + i)
: st.Compare1(output + i, output + k)) {
Print(d, "cur", st.SetKey(d, &output[i]));
Print(d, "kth", st.SetKey(d, &output[k]));
HWY_ABORT(
"Select %s asc=%d: wrong order at %zu, k_keys=%zu, num_keys=%zu\n",
st.KeyString(), kAscending, i / kLPK, k_keys, num_keys);
}
}
output_stats.Notify(output[num_lanes - kLPK]);
if (kLPK == 2) output_stats.Notify(output[num_lanes - kLPK + 1]);
HWY_ASSERT(input_stats == output_stats);
}
};
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE


@@ -0,0 +1,181 @@
// Copyright 2021 Google LLC
// Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: BSD-3-Clause
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Definitions shared between vqsort-inl and sorting_networks-inl.
// Normal include guard for target-independent parts
#ifndef HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_
#define HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_
#include "third_party/highway/hwy/base.h"
namespace hwy {
// Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028
static HWY_INLINE uint64_t RandomBits(uint64_t* HWY_RESTRICT state) {
const uint64_t a = state[0];
const uint64_t b = state[1];
const uint64_t w = state[2] + 1;
const uint64_t next = a ^ w;
state[0] = (b + (b << 3)) ^ (b >> 11);
const uint64_t rot = (b << 24) | (b >> 40);
state[1] = rot + next;
state[2] = w;
return next;
}
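// A minimal caller sketch for RandomBits() above (illustrative; the seed
// constants are arbitrary). The three-word state only needs to be initialized
// once; the third word acts as a counter incremented on every call.
static inline uint64_t RandomBitsExampleSketch() {
  uint64_t state[3] = {0x123456789ABCDEF0ull, 0xFEDCBA9876543210ull, 0};
  const uint64_t r0 = RandomBits(state);
  const uint64_t r1 = RandomBits(state);  // a different draw from the same stream
  return r0 ^ r1;
}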
// Internal constants - these are to avoid magic numbers/literals and cannot be
// changed without also changing the associated code.
struct SortConstants {
// SortingNetwork reshapes its input into a matrix. This is the maximum number
// of *lanes* per vector. Must be at least 8 because SortSamples assumes the
// sorting network can handle 128 bytes with 8 rows, so 16 bytes per vector,
// which means 8 lanes for 16-bit types.
#if HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD
static constexpr size_t kMaxCols = 8; // avoid build timeout/stack overflow
#else
static constexpr size_t kMaxCols = 16; // enough for u32 in 512-bit vector
#endif
// 16 rows is a compromise between using the 32 AVX-512/SVE/RVV registers,
// fitting within 16 AVX2 registers with only a few spills, keeping BaseCase
// code size reasonable, and minimizing the extra logN factor for larger
// networks (for which only loose upper bounds on size are known).
static constexpr size_t kMaxRows = 16;
// Template argument ensures there is no actual division instruction.
template <size_t kLPK>
static constexpr HWY_INLINE size_t BaseCaseNumLanes(size_t N) {
// We use 8, 8x2, 8x4, and 16x{4..} networks, in units of keys. For N/kLPK
// < 4, we cannot use the 16-row networks.
return (((N / kLPK) >= 4) ? kMaxRows : 8) * HWY_MIN(N, kMaxCols);
}
// Unrolling is important (pipelining and amortizing branch mispredictions);
// 2x is sufficient to reach full memory bandwidth on SKX in Partition, but
// somewhat slower for sorting than 4x.
//
// To change, must also update left + 3 * N etc. in the loop.
static constexpr size_t kPartitionUnroll = 4;
// Chunk := group of keys loaded for sampling a pivot. Matches the typical
// cache line size of 64 bytes to get maximum benefit per L2 miss. Sort()
// ensures vectors are no larger than that, so this can be independent of the
// vector size and thus constexpr.
static constexpr HWY_INLINE size_t LanesPerChunk(size_t sizeof_t) {
return 64 / sizeof_t;
}
template <typename T>
static constexpr HWY_INLINE size_t SampleLanes() {
return 2 * LanesPerChunk(sizeof(T)); // Stored samples
}
static constexpr HWY_INLINE size_t PartitionBufNum(size_t N) {
// The main loop reads kPartitionUnroll vectors, and first loads from
// both left and right beforehand, so it requires 2 * kPartitionUnroll
// vectors. To handle amounts between that and BaseCaseNumLanes(), we
// partition up to 3 * kPartitionUnroll + 1 vectors into a two-part buffer.
return 2 * (3 * kPartitionUnroll + 1) * N;
}
// Max across the three buffer usages.
template <typename T, size_t kLPK>
static constexpr HWY_INLINE size_t BufNum(size_t N) {
// BaseCase may write one padding vector, and SortSamples uses the space
// after samples as the buffer.
return HWY_MAX(SampleLanes<T>() + BaseCaseNumLanes<kLPK>(N) + N,
PartitionBufNum(N));
}
// Translates vector_size to lanes and returns size in bytes.
template <typename T, size_t kLPK>
static constexpr HWY_INLINE size_t BufBytes(size_t vector_size) {
return BufNum<T, kLPK>(vector_size / sizeof(T)) * sizeof(T);
}
// Returns max for any type.
template <size_t kLPK>
static constexpr HWY_INLINE size_t MaxBufBytes(size_t vector_size) {
// If 2 lanes per key, it's a 128-bit key with u64 lanes.
return kLPK == 2 ? BufBytes<uint64_t, 2>(vector_size)
: HWY_MAX((BufBytes<uint16_t, 1>(vector_size)),
HWY_MAX((BufBytes<uint32_t, 1>(vector_size)),
(BufBytes<uint64_t, 1>(vector_size))));
}
};
static_assert(SortConstants::MaxBufBytes<1>(64) <= 1664, "Unexpectedly high");
static_assert(SortConstants::MaxBufBytes<2>(64) <= 1664, "Unexpectedly high");
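// Worked example behind the 1664-byte bound: for 64-byte vectors, uint32_t keys
// (N = 16, kLPK = 1) and kMaxCols = 16, SampleLanes() = 32 and
// BaseCaseNumLanes() = 16 * 16 = 256, so the first BufNum() term is
// 32 + 256 + 16 = 304 lanes, while PartitionBufNum() = 2 * (3 * 4 + 1) * 16 =
// 416 lanes. The max is 416 lanes * 4 bytes = 1664 bytes. PartitionBufNum()
// dominates and scales with the 64-byte vector size, so the u16, u64 and
// 128-bit key cases land on the same figure, which is why both asserts use 1664.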
} // namespace hwy
#endif // HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_
// Per-target
// clang-format off
#if defined(HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE) == defined(HWY_TARGET_TOGGLE) // NOLINT
// clang-format on
#ifdef HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE
#undef HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE
#else
#define HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE
#endif
#include "third_party/highway/hwy/highway.h"
// vqsort isn't available on HWY_SCALAR, and builds time out on MSVC opt and
// Armv7 debug, and Armv8 GCC 11 asan hits an internal compiler error likely
// due to https://gcc.gnu.org/bugzilla/show_bug.cgi?id=97696. Armv8 Clang
// hwasan/msan/tsan/asan also fail to build SVE (b/335157772). RVV currently
// has a compiler issue.
#undef VQSORT_ENABLED
#undef VQSORT_COMPILER_COMPATIBLE
#if (HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD) || \
(HWY_ARCH_ARM_V7 && HWY_IS_DEBUG_BUILD) || \
(HWY_ARCH_ARM_A64 && HWY_COMPILER_GCC_ACTUAL && HWY_IS_ASAN) || \
(HWY_ARCH_RISCV)
#define VQSORT_COMPILER_COMPATIBLE 0
#else
#define VQSORT_COMPILER_COMPATIBLE 1
#endif
#if (HWY_TARGET == HWY_SCALAR) || !VQSORT_COMPILER_COMPATIBLE
#define VQSORT_ENABLED 0
#else
#define VQSORT_ENABLED 1
#endif
namespace hwy {
namespace HWY_NAMESPACE {
// Default tag / vector width selector.
#if HWY_TARGET == HWY_RVV
// Use LMUL = 1/2; for SEW=64 this ends up emulated via VSETVLI.
template <typename T>
using SortTag = ScalableTag<T, -1>;
#else
template <typename T>
using SortTag = ScalableTag<T>;
#endif
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
#endif // HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE


@@ -0,0 +1,902 @@
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Per-target
#if defined(HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE) == \
defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE
#undef HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE
#else
#define HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE
#endif
#include "third_party/highway/hwy/contrib/sort/shared-inl.h" // SortConstants
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace detail {
#if VQSORT_ENABLED
using Constants = hwy::SortConstants;
// ------------------------------ SharedTraits
// Code shared between all traits. It's unclear whether these can profitably be
// specialized for Lane vs Block, or optimized like SortPairsDistance1 using
// Compare/DupOdd.
template <class Base>
struct SharedTraits : public Base {
using SharedTraitsForSortingNetwork =
SharedTraits<typename Base::TraitsForSortingNetwork>;
// Conditionally swaps lane 0 with 2, 1 with 3 etc.
template <class D>
HWY_INLINE Vec<D> SortPairsDistance2(D d, Vec<D> v) const {
const Base* base = static_cast<const Base*>(this);
Vec<D> swapped = base->SwapAdjacentPairs(d, v);
base->Sort2(d, v, swapped);
return base->OddEvenPairs(d, swapped, v);
}
// Swaps with the vector formed by reversing contiguous groups of 8 keys.
template <class D>
HWY_INLINE Vec<D> SortPairsReverse8(D d, Vec<D> v) const {
const Base* base = static_cast<const Base*>(this);
Vec<D> swapped = base->ReverseKeys8(d, v);
base->Sort2(d, v, swapped);
return base->OddEvenQuads(d, swapped, v);
}
// Swaps with the vector formed by reversing contiguous groups of 16 keys.
template <class D>
HWY_INLINE Vec<D> SortPairsReverse16(D d, Vec<D> v) const {
const Base* base = static_cast<const Base*>(this);
static_assert(Constants::kMaxCols <= 16, "Need actual Reverse16");
Vec<D> swapped = base->ReverseKeys(d, v);
base->Sort2(d, v, swapped);
return ConcatUpperLower(d, swapped, v); // 8 = half of the vector
}
};
// ------------------------------ Sorting network
// Sorting networks for independent columns in 2, 4 and 8 vectors from
// https://bertdobbelaere.github.io/sorting_networks.html.
template <class D, class Traits, class V = Vec<D>>
HWY_INLINE void Sort2(D d, Traits st, V& v0, V& v1) {
st.Sort2(d, v0, v1);
}
template <class D, class Traits, class V = Vec<D>>
HWY_INLINE void Sort4(D d, Traits st, V& v0, V& v1, V& v2, V& v3) {
st.Sort2(d, v0, v2);
st.Sort2(d, v1, v3);
st.Sort2(d, v0, v1);
st.Sort2(d, v2, v3);
st.Sort2(d, v1, v2);
}
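// A scalar sketch of the same 5-comparator network as Sort4 above (assuming an
// ascending comparison): each st.Sort2 call performs this compare-and-swap on
// every column, i.e. on all lanes of the two vectors in parallel.
template <typename T>
HWY_INLINE void Sort4OneColumnSketch(T& v0, T& v1, T& v2, T& v3) {
  const auto cswap = [](T& a, T& b) {
    if (b < a) {
      const T t = a;
      a = b;
      b = t;
    }
  };
  cswap(v0, v2);
  cswap(v1, v3);
  cswap(v0, v1);
  cswap(v2, v3);
  cswap(v1, v2);
}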
template <class D, class Traits, class V = Vec<D>>
HWY_INLINE void Sort8(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5,
V& v6, V& v7) {
st.Sort2(d, v0, v2);
st.Sort2(d, v1, v3);
st.Sort2(d, v4, v6);
st.Sort2(d, v5, v7);
st.Sort2(d, v0, v4);
st.Sort2(d, v1, v5);
st.Sort2(d, v2, v6);
st.Sort2(d, v3, v7);
st.Sort2(d, v0, v1);
st.Sort2(d, v2, v3);
st.Sort2(d, v4, v5);
st.Sort2(d, v6, v7);
st.Sort2(d, v2, v4);
st.Sort2(d, v3, v5);
st.Sort2(d, v1, v4);
st.Sort2(d, v3, v6);
st.Sort2(d, v1, v2);
st.Sort2(d, v3, v4);
st.Sort2(d, v5, v6);
}
// (Green's irregular) sorting network for independent columns in 16 vectors.
template <class D, class Traits, class V = Vec<D>>
HWY_INLINE void Sort16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5,
V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd,
V& ve, V& vf) {
st.Sort2(d, v0, v1);
st.Sort2(d, v2, v3);
st.Sort2(d, v4, v5);
st.Sort2(d, v6, v7);
st.Sort2(d, v8, v9);
st.Sort2(d, va, vb);
st.Sort2(d, vc, vd);
st.Sort2(d, ve, vf);
st.Sort2(d, v0, v2);
st.Sort2(d, v1, v3);
st.Sort2(d, v4, v6);
st.Sort2(d, v5, v7);
st.Sort2(d, v8, va);
st.Sort2(d, v9, vb);
st.Sort2(d, vc, ve);
st.Sort2(d, vd, vf);
st.Sort2(d, v0, v4);
st.Sort2(d, v1, v5);
st.Sort2(d, v2, v6);
st.Sort2(d, v3, v7);
st.Sort2(d, v8, vc);
st.Sort2(d, v9, vd);
st.Sort2(d, va, ve);
st.Sort2(d, vb, vf);
st.Sort2(d, v0, v8);
st.Sort2(d, v1, v9);
st.Sort2(d, v2, va);
st.Sort2(d, v3, vb);
st.Sort2(d, v4, vc);
st.Sort2(d, v5, vd);
st.Sort2(d, v6, ve);
st.Sort2(d, v7, vf);
st.Sort2(d, v5, va);
st.Sort2(d, v6, v9);
st.Sort2(d, v3, vc);
st.Sort2(d, v7, vb);
st.Sort2(d, vd, ve);
st.Sort2(d, v4, v8);
st.Sort2(d, v1, v2);
st.Sort2(d, v1, v4);
st.Sort2(d, v7, vd);
st.Sort2(d, v2, v8);
st.Sort2(d, vb, ve);
st.Sort2(d, v2, v4);
st.Sort2(d, v5, v6);
st.Sort2(d, v9, va);
st.Sort2(d, vb, vd);
st.Sort2(d, v3, v8);
st.Sort2(d, v7, vc);
st.Sort2(d, v3, v5);
st.Sort2(d, v6, v8);
st.Sort2(d, v7, v9);
st.Sort2(d, va, vc);
st.Sort2(d, v3, v4);
st.Sort2(d, v5, v6);
st.Sort2(d, v7, v8);
st.Sort2(d, v9, va);
st.Sort2(d, vb, vc);
st.Sort2(d, v6, v7);
st.Sort2(d, v8, v9);
}
// ------------------------------ Merging networks
// Blacher's hybrid bitonic/odd-even networks, generated by print_network.cc.
// For acceptable performance, these must be inlined, otherwise vectors are
// loaded from the stack. The kKeysPerVector allows calling from generic code
// but skipping the functions when vectors have too few lanes for
// st.SortPairsDistance1 to compile. `if constexpr` in the caller would also
// work, but is not available in C++11. We write out the (unused) argument types
// rather than `...` because GCC 9 (but not 10) fails to compile with `...`.
template <size_t kKeysPerVector, class D, class Traits, class V,
HWY_IF_LANES_LE(kKeysPerVector, 1)>
HWY_INLINE void Merge8x2(D, Traits, V, V, V, V, V, V, V, V) {}
template <size_t kKeysPerVector, class D, class Traits, class V,
HWY_IF_LANES_LE(kKeysPerVector, 2)>
HWY_INLINE void Merge8x4(D, Traits, V, V, V, V, V, V, V, V) {}
template <size_t kKeysPerVector, class D, class Traits, class V,
HWY_IF_LANES_LE(kKeysPerVector, 1)>
HWY_INLINE void Merge16x2(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V,
V, V) {}
template <size_t kKeysPerVector, class D, class Traits, class V,
HWY_IF_LANES_LE(kKeysPerVector, 2)>
HWY_INLINE void Merge16x4(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V,
V, V) {}
template <size_t kKeysPerVector, class D, class Traits, class V,
HWY_IF_LANES_LE(kKeysPerVector, 4)>
HWY_INLINE void Merge16x8(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V,
V, V) {}
template <size_t kKeysPerVector, class D, class Traits, class V,
HWY_IF_LANES_LE(kKeysPerVector, 8)>
HWY_INLINE void Merge16x16(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V,
V, V) {}
template <size_t kKeysPerVector, class D, class Traits, class V = Vec<D>,
HWY_IF_LANES_GT(kKeysPerVector, 1)>
HWY_INLINE void Merge8x2(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4,
V& v5, V& v6, V& v7) {
v7 = st.ReverseKeys2(d, v7);
v6 = st.ReverseKeys2(d, v6);
v5 = st.ReverseKeys2(d, v5);
v4 = st.ReverseKeys2(d, v4);
st.Sort2(d, v0, v7);
st.Sort2(d, v1, v6);
st.Sort2(d, v2, v5);
st.Sort2(d, v3, v4);
v3 = st.ReverseKeys2(d, v3);
v2 = st.ReverseKeys2(d, v2);
v7 = st.ReverseKeys2(d, v7);
v6 = st.ReverseKeys2(d, v6);
st.Sort2(d, v0, v3);
st.Sort2(d, v1, v2);
st.Sort2(d, v4, v7);
st.Sort2(d, v5, v6);
v1 = st.ReverseKeys2(d, v1);
v3 = st.ReverseKeys2(d, v3);
v5 = st.ReverseKeys2(d, v5);
v7 = st.ReverseKeys2(d, v7);
st.Sort2(d, v0, v1);
st.Sort2(d, v2, v3);
st.Sort2(d, v4, v5);
st.Sort2(d, v6, v7);
v0 = st.SortPairsDistance1(d, v0);
v1 = st.SortPairsDistance1(d, v1);
v2 = st.SortPairsDistance1(d, v2);
v3 = st.SortPairsDistance1(d, v3);
v4 = st.SortPairsDistance1(d, v4);
v5 = st.SortPairsDistance1(d, v5);
v6 = st.SortPairsDistance1(d, v6);
v7 = st.SortPairsDistance1(d, v7);
}
template <size_t kKeysPerVector, class D, class Traits, class V = Vec<D>,
HWY_IF_LANES_GT(kKeysPerVector, 2)>
HWY_INLINE void Merge8x4(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4,
V& v5, V& v6, V& v7) {
v7 = st.ReverseKeys4(d, v7);
v6 = st.ReverseKeys4(d, v6);
v5 = st.ReverseKeys4(d, v5);
v4 = st.ReverseKeys4(d, v4);
st.Sort2(d, v0, v7);
st.Sort2(d, v1, v6);
st.Sort2(d, v2, v5);
st.Sort2(d, v3, v4);
v3 = st.ReverseKeys4(d, v3);
v2 = st.ReverseKeys4(d, v2);
v7 = st.ReverseKeys4(d, v7);
v6 = st.ReverseKeys4(d, v6);
st.Sort2(d, v0, v3);
st.Sort2(d, v1, v2);
st.Sort2(d, v4, v7);
st.Sort2(d, v5, v6);
v1 = st.ReverseKeys4(d, v1);
v3 = st.ReverseKeys4(d, v3);
v5 = st.ReverseKeys4(d, v5);
v7 = st.ReverseKeys4(d, v7);
st.Sort2(d, v0, v1);
st.Sort2(d, v2, v3);
st.Sort2(d, v4, v5);
st.Sort2(d, v6, v7);
v0 = st.SortPairsReverse4(d, v0);
v1 = st.SortPairsReverse4(d, v1);
v2 = st.SortPairsReverse4(d, v2);
v3 = st.SortPairsReverse4(d, v3);
v4 = st.SortPairsReverse4(d, v4);
v5 = st.SortPairsReverse4(d, v5);
v6 = st.SortPairsReverse4(d, v6);
v7 = st.SortPairsReverse4(d, v7);
v0 = st.SortPairsDistance1(d, v0);
v1 = st.SortPairsDistance1(d, v1);
v2 = st.SortPairsDistance1(d, v2);
v3 = st.SortPairsDistance1(d, v3);
v4 = st.SortPairsDistance1(d, v4);
v5 = st.SortPairsDistance1(d, v5);
v6 = st.SortPairsDistance1(d, v6);
v7 = st.SortPairsDistance1(d, v7);
}
// Only used by the now-deprecated SortingNetwork().
template <size_t kKeysPerVector, class D, class Traits, class V = Vec<D>,
HWY_IF_LANES_GT(kKeysPerVector, 1)>
HWY_INLINE void Merge16x2(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4,
V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb,
V& vc, V& vd, V& ve, V& vf) {
vf = st.ReverseKeys2(d, vf);
ve = st.ReverseKeys2(d, ve);
vd = st.ReverseKeys2(d, vd);
vc = st.ReverseKeys2(d, vc);
vb = st.ReverseKeys2(d, vb);
va = st.ReverseKeys2(d, va);
v9 = st.ReverseKeys2(d, v9);
v8 = st.ReverseKeys2(d, v8);
st.Sort2(d, v0, vf);
st.Sort2(d, v1, ve);
st.Sort2(d, v2, vd);
st.Sort2(d, v3, vc);
st.Sort2(d, v4, vb);
st.Sort2(d, v5, va);
st.Sort2(d, v6, v9);
st.Sort2(d, v7, v8);
v7 = st.ReverseKeys2(d, v7);
v6 = st.ReverseKeys2(d, v6);
v5 = st.ReverseKeys2(d, v5);
v4 = st.ReverseKeys2(d, v4);
vf = st.ReverseKeys2(d, vf);
ve = st.ReverseKeys2(d, ve);
vd = st.ReverseKeys2(d, vd);
vc = st.ReverseKeys2(d, vc);
st.Sort2(d, v0, v7);
st.Sort2(d, v1, v6);
st.Sort2(d, v2, v5);
st.Sort2(d, v3, v4);
st.Sort2(d, v8, vf);
st.Sort2(d, v9, ve);
st.Sort2(d, va, vd);
st.Sort2(d, vb, vc);
v3 = st.ReverseKeys2(d, v3);
v2 = st.ReverseKeys2(d, v2);
v7 = st.ReverseKeys2(d, v7);
v6 = st.ReverseKeys2(d, v6);
vb = st.ReverseKeys2(d, vb);
va = st.ReverseKeys2(d, va);
vf = st.ReverseKeys2(d, vf);
ve = st.ReverseKeys2(d, ve);
st.Sort2(d, v0, v3);
st.Sort2(d, v1, v2);
st.Sort2(d, v4, v7);
st.Sort2(d, v5, v6);
st.Sort2(d, v8, vb);
st.Sort2(d, v9, va);
st.Sort2(d, vc, vf);
st.Sort2(d, vd, ve);
v1 = st.ReverseKeys2(d, v1);
v3 = st.ReverseKeys2(d, v3);
v5 = st.ReverseKeys2(d, v5);
v7 = st.ReverseKeys2(d, v7);
v9 = st.ReverseKeys2(d, v9);
vb = st.ReverseKeys2(d, vb);
vd = st.ReverseKeys2(d, vd);
vf = st.ReverseKeys2(d, vf);
st.Sort2(d, v0, v1);
st.Sort2(d, v2, v3);
st.Sort2(d, v4, v5);
st.Sort2(d, v6, v7);
st.Sort2(d, v8, v9);
st.Sort2(d, va, vb);
st.Sort2(d, vc, vd);
st.Sort2(d, ve, vf);
v0 = st.SortPairsDistance1(d, v0);
v1 = st.SortPairsDistance1(d, v1);
v2 = st.SortPairsDistance1(d, v2);
v3 = st.SortPairsDistance1(d, v3);
v4 = st.SortPairsDistance1(d, v4);
v5 = st.SortPairsDistance1(d, v5);
v6 = st.SortPairsDistance1(d, v6);
v7 = st.SortPairsDistance1(d, v7);
v8 = st.SortPairsDistance1(d, v8);
v9 = st.SortPairsDistance1(d, v9);
va = st.SortPairsDistance1(d, va);
vb = st.SortPairsDistance1(d, vb);
vc = st.SortPairsDistance1(d, vc);
vd = st.SortPairsDistance1(d, vd);
ve = st.SortPairsDistance1(d, ve);
vf = st.SortPairsDistance1(d, vf);
}
template <size_t kKeysPerVector, class D, class Traits, class V = Vec<D>,
HWY_IF_LANES_GT(kKeysPerVector, 2)>
HWY_INLINE void Merge16x4(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4,
V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb,
V& vc, V& vd, V& ve, V& vf) {
vf = st.ReverseKeys4(d, vf);
ve = st.ReverseKeys4(d, ve);
vd = st.ReverseKeys4(d, vd);
vc = st.ReverseKeys4(d, vc);
vb = st.ReverseKeys4(d, vb);
va = st.ReverseKeys4(d, va);
v9 = st.ReverseKeys4(d, v9);
v8 = st.ReverseKeys4(d, v8);
st.Sort2(d, v0, vf);
st.Sort2(d, v1, ve);
st.Sort2(d, v2, vd);
st.Sort2(d, v3, vc);
st.Sort2(d, v4, vb);
st.Sort2(d, v5, va);
st.Sort2(d, v6, v9);
st.Sort2(d, v7, v8);
v7 = st.ReverseKeys4(d, v7);
v6 = st.ReverseKeys4(d, v6);
v5 = st.ReverseKeys4(d, v5);
v4 = st.ReverseKeys4(d, v4);
vf = st.ReverseKeys4(d, vf);
ve = st.ReverseKeys4(d, ve);
vd = st.ReverseKeys4(d, vd);
vc = st.ReverseKeys4(d, vc);
st.Sort2(d, v0, v7);
st.Sort2(d, v1, v6);
st.Sort2(d, v2, v5);
st.Sort2(d, v3, v4);
st.Sort2(d, v8, vf);
st.Sort2(d, v9, ve);
st.Sort2(d, va, vd);
st.Sort2(d, vb, vc);
v3 = st.ReverseKeys4(d, v3);
v2 = st.ReverseKeys4(d, v2);
v7 = st.ReverseKeys4(d, v7);
v6 = st.ReverseKeys4(d, v6);
vb = st.ReverseKeys4(d, vb);
va = st.ReverseKeys4(d, va);
vf = st.ReverseKeys4(d, vf);
ve = st.ReverseKeys4(d, ve);
st.Sort2(d, v0, v3);
st.Sort2(d, v1, v2);
st.Sort2(d, v4, v7);
st.Sort2(d, v5, v6);
st.Sort2(d, v8, vb);
st.Sort2(d, v9, va);
st.Sort2(d, vc, vf);
st.Sort2(d, vd, ve);
v1 = st.ReverseKeys4(d, v1);
v3 = st.ReverseKeys4(d, v3);
v5 = st.ReverseKeys4(d, v5);
v7 = st.ReverseKeys4(d, v7);
v9 = st.ReverseKeys4(d, v9);
vb = st.ReverseKeys4(d, vb);
vd = st.ReverseKeys4(d, vd);
vf = st.ReverseKeys4(d, vf);
st.Sort2(d, v0, v1);
st.Sort2(d, v2, v3);
st.Sort2(d, v4, v5);
st.Sort2(d, v6, v7);
st.Sort2(d, v8, v9);
st.Sort2(d, va, vb);
st.Sort2(d, vc, vd);
st.Sort2(d, ve, vf);
v0 = st.SortPairsReverse4(d, v0);
v1 = st.SortPairsReverse4(d, v1);
v2 = st.SortPairsReverse4(d, v2);
v3 = st.SortPairsReverse4(d, v3);
v4 = st.SortPairsReverse4(d, v4);
v5 = st.SortPairsReverse4(d, v5);
v6 = st.SortPairsReverse4(d, v6);
v7 = st.SortPairsReverse4(d, v7);
v8 = st.SortPairsReverse4(d, v8);
v9 = st.SortPairsReverse4(d, v9);
va = st.SortPairsReverse4(d, va);
vb = st.SortPairsReverse4(d, vb);
vc = st.SortPairsReverse4(d, vc);
vd = st.SortPairsReverse4(d, vd);
ve = st.SortPairsReverse4(d, ve);
vf = st.SortPairsReverse4(d, vf);
v0 = st.SortPairsDistance1(d, v0);
v1 = st.SortPairsDistance1(d, v1);
v2 = st.SortPairsDistance1(d, v2);
v3 = st.SortPairsDistance1(d, v3);
v4 = st.SortPairsDistance1(d, v4);
v5 = st.SortPairsDistance1(d, v5);
v6 = st.SortPairsDistance1(d, v6);
v7 = st.SortPairsDistance1(d, v7);
v8 = st.SortPairsDistance1(d, v8);
v9 = st.SortPairsDistance1(d, v9);
va = st.SortPairsDistance1(d, va);
vb = st.SortPairsDistance1(d, vb);
vc = st.SortPairsDistance1(d, vc);
vd = st.SortPairsDistance1(d, vd);
ve = st.SortPairsDistance1(d, ve);
vf = st.SortPairsDistance1(d, vf);
}
template <size_t kKeysPerVector, class D, class Traits, class V = Vec<D>,
HWY_IF_LANES_GT(kKeysPerVector, 4)>
HWY_INLINE void Merge16x8(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4,
V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb,
V& vc, V& vd, V& ve, V& vf) {
vf = st.ReverseKeys8(d, vf);
ve = st.ReverseKeys8(d, ve);
vd = st.ReverseKeys8(d, vd);
vc = st.ReverseKeys8(d, vc);
vb = st.ReverseKeys8(d, vb);
va = st.ReverseKeys8(d, va);
v9 = st.ReverseKeys8(d, v9);
v8 = st.ReverseKeys8(d, v8);
st.Sort2(d, v0, vf);
st.Sort2(d, v1, ve);
st.Sort2(d, v2, vd);
st.Sort2(d, v3, vc);
st.Sort2(d, v4, vb);
st.Sort2(d, v5, va);
st.Sort2(d, v6, v9);
st.Sort2(d, v7, v8);
v7 = st.ReverseKeys8(d, v7);
v6 = st.ReverseKeys8(d, v6);
v5 = st.ReverseKeys8(d, v5);
v4 = st.ReverseKeys8(d, v4);
vf = st.ReverseKeys8(d, vf);
ve = st.ReverseKeys8(d, ve);
vd = st.ReverseKeys8(d, vd);
vc = st.ReverseKeys8(d, vc);
st.Sort2(d, v0, v7);
st.Sort2(d, v1, v6);
st.Sort2(d, v2, v5);
st.Sort2(d, v3, v4);
st.Sort2(d, v8, vf);
st.Sort2(d, v9, ve);
st.Sort2(d, va, vd);
st.Sort2(d, vb, vc);
v3 = st.ReverseKeys8(d, v3);
v2 = st.ReverseKeys8(d, v2);
v7 = st.ReverseKeys8(d, v7);
v6 = st.ReverseKeys8(d, v6);
vb = st.ReverseKeys8(d, vb);
va = st.ReverseKeys8(d, va);
vf = st.ReverseKeys8(d, vf);
ve = st.ReverseKeys8(d, ve);
st.Sort2(d, v0, v3);
st.Sort2(d, v1, v2);
st.Sort2(d, v4, v7);
st.Sort2(d, v5, v6);
st.Sort2(d, v8, vb);
st.Sort2(d, v9, va);
st.Sort2(d, vc, vf);
st.Sort2(d, vd, ve);
v1 = st.ReverseKeys8(d, v1);
v3 = st.ReverseKeys8(d, v3);
v5 = st.ReverseKeys8(d, v5);
v7 = st.ReverseKeys8(d, v7);
v9 = st.ReverseKeys8(d, v9);
vb = st.ReverseKeys8(d, vb);
vd = st.ReverseKeys8(d, vd);
vf = st.ReverseKeys8(d, vf);
st.Sort2(d, v0, v1);
st.Sort2(d, v2, v3);
st.Sort2(d, v4, v5);
st.Sort2(d, v6, v7);
st.Sort2(d, v8, v9);
st.Sort2(d, va, vb);
st.Sort2(d, vc, vd);
st.Sort2(d, ve, vf);
v0 = st.SortPairsReverse8(d, v0);
v1 = st.SortPairsReverse8(d, v1);
v2 = st.SortPairsReverse8(d, v2);
v3 = st.SortPairsReverse8(d, v3);
v4 = st.SortPairsReverse8(d, v4);
v5 = st.SortPairsReverse8(d, v5);
v6 = st.SortPairsReverse8(d, v6);
v7 = st.SortPairsReverse8(d, v7);
v8 = st.SortPairsReverse8(d, v8);
v9 = st.SortPairsReverse8(d, v9);
va = st.SortPairsReverse8(d, va);
vb = st.SortPairsReverse8(d, vb);
vc = st.SortPairsReverse8(d, vc);
vd = st.SortPairsReverse8(d, vd);
ve = st.SortPairsReverse8(d, ve);
vf = st.SortPairsReverse8(d, vf);
v0 = st.SortPairsDistance2(d, v0);
v1 = st.SortPairsDistance2(d, v1);
v2 = st.SortPairsDistance2(d, v2);
v3 = st.SortPairsDistance2(d, v3);
v4 = st.SortPairsDistance2(d, v4);
v5 = st.SortPairsDistance2(d, v5);
v6 = st.SortPairsDistance2(d, v6);
v7 = st.SortPairsDistance2(d, v7);
v8 = st.SortPairsDistance2(d, v8);
v9 = st.SortPairsDistance2(d, v9);
va = st.SortPairsDistance2(d, va);
vb = st.SortPairsDistance2(d, vb);
vc = st.SortPairsDistance2(d, vc);
vd = st.SortPairsDistance2(d, vd);
ve = st.SortPairsDistance2(d, ve);
vf = st.SortPairsDistance2(d, vf);
v0 = st.SortPairsDistance1(d, v0);
v1 = st.SortPairsDistance1(d, v1);
v2 = st.SortPairsDistance1(d, v2);
v3 = st.SortPairsDistance1(d, v3);
v4 = st.SortPairsDistance1(d, v4);
v5 = st.SortPairsDistance1(d, v5);
v6 = st.SortPairsDistance1(d, v6);
v7 = st.SortPairsDistance1(d, v7);
v8 = st.SortPairsDistance1(d, v8);
v9 = st.SortPairsDistance1(d, v9);
va = st.SortPairsDistance1(d, va);
vb = st.SortPairsDistance1(d, vb);
vc = st.SortPairsDistance1(d, vc);
vd = st.SortPairsDistance1(d, vd);
ve = st.SortPairsDistance1(d, ve);
vf = st.SortPairsDistance1(d, vf);
}
// Unused on MSVC, see below
#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD
template <size_t kKeysPerVector, class D, class Traits, class V = Vec<D>,
HWY_IF_LANES_GT(kKeysPerVector, 8)>
HWY_INLINE void Merge16x16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4,
V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb,
V& vc, V& vd, V& ve, V& vf) {
vf = st.ReverseKeys16(d, vf);
ve = st.ReverseKeys16(d, ve);
vd = st.ReverseKeys16(d, vd);
vc = st.ReverseKeys16(d, vc);
vb = st.ReverseKeys16(d, vb);
va = st.ReverseKeys16(d, va);
v9 = st.ReverseKeys16(d, v9);
v8 = st.ReverseKeys16(d, v8);
st.Sort2(d, v0, vf);
st.Sort2(d, v1, ve);
st.Sort2(d, v2, vd);
st.Sort2(d, v3, vc);
st.Sort2(d, v4, vb);
st.Sort2(d, v5, va);
st.Sort2(d, v6, v9);
st.Sort2(d, v7, v8);
v7 = st.ReverseKeys16(d, v7);
v6 = st.ReverseKeys16(d, v6);
v5 = st.ReverseKeys16(d, v5);
v4 = st.ReverseKeys16(d, v4);
vf = st.ReverseKeys16(d, vf);
ve = st.ReverseKeys16(d, ve);
vd = st.ReverseKeys16(d, vd);
vc = st.ReverseKeys16(d, vc);
st.Sort2(d, v0, v7);
st.Sort2(d, v1, v6);
st.Sort2(d, v2, v5);
st.Sort2(d, v3, v4);
st.Sort2(d, v8, vf);
st.Sort2(d, v9, ve);
st.Sort2(d, va, vd);
st.Sort2(d, vb, vc);
v3 = st.ReverseKeys16(d, v3);
v2 = st.ReverseKeys16(d, v2);
v7 = st.ReverseKeys16(d, v7);
v6 = st.ReverseKeys16(d, v6);
vb = st.ReverseKeys16(d, vb);
va = st.ReverseKeys16(d, va);
vf = st.ReverseKeys16(d, vf);
ve = st.ReverseKeys16(d, ve);
st.Sort2(d, v0, v3);
st.Sort2(d, v1, v2);
st.Sort2(d, v4, v7);
st.Sort2(d, v5, v6);
st.Sort2(d, v8, vb);
st.Sort2(d, v9, va);
st.Sort2(d, vc, vf);
st.Sort2(d, vd, ve);
v1 = st.ReverseKeys16(d, v1);
v3 = st.ReverseKeys16(d, v3);
v5 = st.ReverseKeys16(d, v5);
v7 = st.ReverseKeys16(d, v7);
v9 = st.ReverseKeys16(d, v9);
vb = st.ReverseKeys16(d, vb);
vd = st.ReverseKeys16(d, vd);
vf = st.ReverseKeys16(d, vf);
st.Sort2(d, v0, v1);
st.Sort2(d, v2, v3);
st.Sort2(d, v4, v5);
st.Sort2(d, v6, v7);
st.Sort2(d, v8, v9);
st.Sort2(d, va, vb);
st.Sort2(d, vc, vd);
st.Sort2(d, ve, vf);
v0 = st.SortPairsReverse16(d, v0);
v1 = st.SortPairsReverse16(d, v1);
v2 = st.SortPairsReverse16(d, v2);
v3 = st.SortPairsReverse16(d, v3);
v4 = st.SortPairsReverse16(d, v4);
v5 = st.SortPairsReverse16(d, v5);
v6 = st.SortPairsReverse16(d, v6);
v7 = st.SortPairsReverse16(d, v7);
v8 = st.SortPairsReverse16(d, v8);
v9 = st.SortPairsReverse16(d, v9);
va = st.SortPairsReverse16(d, va);
vb = st.SortPairsReverse16(d, vb);
vc = st.SortPairsReverse16(d, vc);
vd = st.SortPairsReverse16(d, vd);
ve = st.SortPairsReverse16(d, ve);
vf = st.SortPairsReverse16(d, vf);
v0 = st.SortPairsDistance4(d, v0);
v1 = st.SortPairsDistance4(d, v1);
v2 = st.SortPairsDistance4(d, v2);
v3 = st.SortPairsDistance4(d, v3);
v4 = st.SortPairsDistance4(d, v4);
v5 = st.SortPairsDistance4(d, v5);
v6 = st.SortPairsDistance4(d, v6);
v7 = st.SortPairsDistance4(d, v7);
v8 = st.SortPairsDistance4(d, v8);
v9 = st.SortPairsDistance4(d, v9);
va = st.SortPairsDistance4(d, va);
vb = st.SortPairsDistance4(d, vb);
vc = st.SortPairsDistance4(d, vc);
vd = st.SortPairsDistance4(d, vd);
ve = st.SortPairsDistance4(d, ve);
vf = st.SortPairsDistance4(d, vf);
v0 = st.SortPairsDistance2(d, v0);
v1 = st.SortPairsDistance2(d, v1);
v2 = st.SortPairsDistance2(d, v2);
v3 = st.SortPairsDistance2(d, v3);
v4 = st.SortPairsDistance2(d, v4);
v5 = st.SortPairsDistance2(d, v5);
v6 = st.SortPairsDistance2(d, v6);
v7 = st.SortPairsDistance2(d, v7);
v8 = st.SortPairsDistance2(d, v8);
v9 = st.SortPairsDistance2(d, v9);
va = st.SortPairsDistance2(d, va);
vb = st.SortPairsDistance2(d, vb);
vc = st.SortPairsDistance2(d, vc);
vd = st.SortPairsDistance2(d, vd);
ve = st.SortPairsDistance2(d, ve);
vf = st.SortPairsDistance2(d, vf);
v0 = st.SortPairsDistance1(d, v0);
v1 = st.SortPairsDistance1(d, v1);
v2 = st.SortPairsDistance1(d, v2);
v3 = st.SortPairsDistance1(d, v3);
v4 = st.SortPairsDistance1(d, v4);
v5 = st.SortPairsDistance1(d, v5);
v6 = st.SortPairsDistance1(d, v6);
v7 = st.SortPairsDistance1(d, v7);
v8 = st.SortPairsDistance1(d, v8);
v9 = st.SortPairsDistance1(d, v9);
va = st.SortPairsDistance1(d, va);
vb = st.SortPairsDistance1(d, vb);
vc = st.SortPairsDistance1(d, vc);
vd = st.SortPairsDistance1(d, vd);
ve = st.SortPairsDistance1(d, ve);
vf = st.SortPairsDistance1(d, vf);
}
#endif // !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD
// Reshapes `buf` into a matrix, sorts columns independently, and then merges
// into a sorted 1D array without transposing.
//
// DEPRECATED, use BaseCase() instead.
template <class Traits, class V>
HWY_INLINE void SortingNetwork(Traits st, size_t cols, V& v0, V& v1, V& v2,
V& v3, V& v4, V& v5, V& v6, V& v7, V& v8, V& v9,
V& va, V& vb, V& vc, V& vd, V& ve, V& vf) {
// traits*-inl assume 'full' vectors (but still capped to kMaxCols).
const CappedTag<typename Traits::LaneType, Constants::kMaxCols> d;
HWY_DASSERT(cols <= Constants::kMaxCols);
// The network width depends on the number of keys, not lanes.
constexpr size_t kLanesPerKey = st.LanesPerKey();
const size_t keys = cols / kLanesPerKey;
constexpr size_t kMaxKeys = MaxLanes(d) / kLanesPerKey;
Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf);
// Checking MaxLanes avoids generating HWY_ASSERT code for the unreachable
// code paths: if MaxLanes < 2, then keys <= cols < 2.
if (HWY_LIKELY(keys >= 2 && kMaxKeys >= 2)) {
Merge16x2<kMaxKeys>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
vc, vd, ve, vf);
if (HWY_LIKELY(keys >= 4 && kMaxKeys >= 4)) {
Merge16x4<kMaxKeys>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
vc, vd, ve, vf);
if (HWY_LIKELY(keys >= 8 && kMaxKeys >= 8)) {
Merge16x8<kMaxKeys>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va,
vb, vc, vd, ve, vf);
// Avoids build timeout. Must match #if condition in kMaxCols.
#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD
if (HWY_LIKELY(keys >= 16 && kMaxKeys >= 16)) {
Merge16x16<kMaxKeys>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9,
va, vb, vc, vd, ve, vf);
static_assert(Constants::kMaxCols <= 16, "Add more branches");
}
#endif
}
}
}
}
// As above, but loads from/stores to `buf`. This ensures full vectors are
// aligned, and enables loads/stores without bounds checks.
//
// DEPRECATED, use BaseCase() instead.
template <class Traits, typename T>
HWY_NOINLINE void SortingNetwork(Traits st, T* HWY_RESTRICT buf, size_t cols) {
// traits*-inl assume 'full' vectors (but still capped to kMaxCols).
// However, for smaller arrays and sub-maximal `cols` we have overlapping
// loads where only the lowest `cols` are valid, and we skip Merge16 etc.
const CappedTag<T, Constants::kMaxCols> d;
using V = decltype(Zero(d));
HWY_DASSERT(cols <= Constants::kMaxCols);
// These are aligned iff cols == Lanes(d). We prefer unaligned/non-constexpr
// offsets to duplicating this code for every value of cols.
static_assert(Constants::kMaxRows == 16, "Update loads/stores/args");
V v0 = LoadU(d, buf + 0x0 * cols);
V v1 = LoadU(d, buf + 0x1 * cols);
V v2 = LoadU(d, buf + 0x2 * cols);
V v3 = LoadU(d, buf + 0x3 * cols);
V v4 = LoadU(d, buf + 0x4 * cols);
V v5 = LoadU(d, buf + 0x5 * cols);
V v6 = LoadU(d, buf + 0x6 * cols);
V v7 = LoadU(d, buf + 0x7 * cols);
V v8 = LoadU(d, buf + 0x8 * cols);
V v9 = LoadU(d, buf + 0x9 * cols);
V va = LoadU(d, buf + 0xa * cols);
V vb = LoadU(d, buf + 0xb * cols);
V vc = LoadU(d, buf + 0xc * cols);
V vd = LoadU(d, buf + 0xd * cols);
V ve = LoadU(d, buf + 0xe * cols);
V vf = LoadU(d, buf + 0xf * cols);
SortingNetwork(st, cols, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc,
vd, ve, vf);
StoreU(v0, d, buf + 0x0 * cols);
StoreU(v1, d, buf + 0x1 * cols);
StoreU(v2, d, buf + 0x2 * cols);
StoreU(v3, d, buf + 0x3 * cols);
StoreU(v4, d, buf + 0x4 * cols);
StoreU(v5, d, buf + 0x5 * cols);
StoreU(v6, d, buf + 0x6 * cols);
StoreU(v7, d, buf + 0x7 * cols);
StoreU(v8, d, buf + 0x8 * cols);
StoreU(v9, d, buf + 0x9 * cols);
StoreU(va, d, buf + 0xa * cols);
StoreU(vb, d, buf + 0xb * cols);
StoreU(vc, d, buf + 0xc * cols);
StoreU(vd, d, buf + 0xd * cols);
StoreU(ve, d, buf + 0xe * cols);
StoreU(vf, d, buf + 0xf * cols);
}
#else
template <class Base>
struct SharedTraits : public Base {};
#endif // VQSORT_ENABLED
} // namespace detail
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE


@@ -0,0 +1,618 @@
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Per-target
#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE) == \
defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
#else
#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
#endif
#include <stddef.h>
#include <stdint.h>
#include "third_party/highway/hwy/contrib/sort/order.h" // SortDescending
#include "third_party/highway/hwy/contrib/sort/shared-inl.h" // SortConstants
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace detail {
// Base class of both KeyLane variants
template <typename LaneTypeArg, typename KeyTypeArg>
struct KeyLaneBase {
static constexpr bool Is128() { return false; }
constexpr size_t LanesPerKey() const { return 1; }
// What type bench_sort should allocate for generating inputs.
using LaneType = LaneTypeArg;
// What type to pass to VQSort.
using KeyType = KeyTypeArg;
const char* KeyString() const {
return IsSame<KeyTypeArg, float16_t>() ? "f16"
: IsSame<KeyTypeArg, float>() ? "f32"
: IsSame<KeyTypeArg, double>() ? "f64"
: IsSame<KeyTypeArg, int16_t>() ? "i16"
: IsSame<KeyTypeArg, int32_t>() ? "i32"
: IsSame<KeyTypeArg, int64_t>() ? "i64"
: IsSame<KeyTypeArg, uint16_t>() ? "u16"
: IsSame<KeyTypeArg, uint32_t>() ? "u32"
: IsSame<KeyTypeArg, uint64_t>() ? "u64"
: IsSame<KeyTypeArg, hwy::K32V32>() ? "k+v=64"
: "?";
}
};
// Wrapper functions so we can specialize for floats - infinity trumps
// HighestValue (the normal value with the largest magnitude). Must be outside
// Order* classes to enable SFINAE.
template <class D, HWY_IF_FLOAT_OR_SPECIAL_D(D)>
Vec<D> LargestSortValue(D d) {
return Inf(d);
}
template <class D, HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
Vec<D> LargestSortValue(D d) {
return Set(d, hwy::HighestValue<TFromD<D>>());
}
template <class D, HWY_IF_FLOAT_OR_SPECIAL_D(D)>
Vec<D> SmallestSortValue(D d) {
return Neg(Inf(d));
}
template <class D, HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
Vec<D> SmallestSortValue(D d) {
return Set(d, hwy::LowestValue<TFromD<D>>());
}
// Returns the next distinct larger value unless already +inf.
template <class D, HWY_IF_FLOAT_OR_SPECIAL_D(D)>
Vec<D> LargerSortValue(D d, Vec<D> v) {
HWY_DASSERT(AllFalse(d, IsNaN(v))); // we replaced all NaN with LastValue.
using T = TFromD<decltype(d)>;
const RebindToUnsigned<D> du;
using VU = Vec<decltype(du)>;
using TU = TFromD<decltype(du)>;
const VU vu = BitCast(du, Abs(v));
// The direction depends on the original sign. Integer comparison is cheaper
// than float comparison and treats -0 as 0 (so we return +epsilon).
const Mask<decltype(du)> was_pos = Le(BitCast(du, v), SignBit(du));
// If positive, add 1, else -1.
const VU add = IfThenElse(was_pos, Set(du, 1u), Set(du, LimitsMax<TU>()));
// Prev/next integer is the prev/next value, even if mantissa under/overflows.
v = BitCast(d, Add(vu, add));
// But we may have overflowed into inf or NaN; replace with inf if positive,
// but the largest (later negated!) value if the input was -inf.
const Mask<D> was_pos_f = RebindMask(d, was_pos);
v = IfThenElse(IsFinite(v), v,
IfThenElse(was_pos_f, Inf(d), Set(d, HighestValue<T>())));
// Restore the original sign - not via CopySignToAbs because we used a mask.
return IfThenElse(was_pos_f, v, Neg(v));
}
// Returns the next distinct smaller value unless already -inf.
template <class D, HWY_IF_FLOAT_OR_SPECIAL_D(D)>
Vec<D> SmallerSortValue(D d, Vec<D> v) {
HWY_DASSERT(AllFalse(d, IsNaN(v))); // we replaced all NaN with LastValue.
using T = TFromD<decltype(d)>;
const RebindToUnsigned<D> du;
using VU = Vec<decltype(du)>;
using TU = TFromD<decltype(du)>;
const VU vu = BitCast(du, Abs(v));
// The direction depends on the original sign. Float comparison because we
// want to treat 0 as -0 so we return -epsilon.
const Mask<D> was_pos = Gt(v, Zero(d));
// If positive, add -1, else 1.
const VU add =
IfThenElse(RebindMask(du, was_pos), Set(du, LimitsMax<TU>()), Set(du, 1));
// Prev/next integer is the prev/next value, even if mantissa under/overflows.
v = BitCast(d, Add(vu, add));
// But we may have overflowed into inf or NaN; replace with +inf (which will
// later be negated) if negative, but the largest value if the input was +inf.
v = IfThenElse(IsFinite(v), v,
IfThenElse(was_pos, Set(d, HighestValue<T>()), Inf(d)));
// Restore the original sign - not via CopySignToAbs because we used a mask.
return IfThenElse(was_pos, v, Neg(v));
}
template <class D, HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
Vec<D> LargerSortValue(D d, Vec<D> v) {
return Add(v, Set(d, TFromD<D>{1}));
}
template <class D, HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
Vec<D> SmallerSortValue(D d, Vec<D> v) {
return Sub(v, Set(d, TFromD<D>{1}));
}
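// A scalar sketch of the float bit trick used above (assuming x is positive
// and finite): incrementing the value's integer representation yields the next
// representable float, which is what LargerSortValue computes per lane.
static inline float NextLargerFloatSketch(float x) {
  uint32_t bits;
  CopyBytes(&x, &bits, sizeof(bits));  // type-pun via byte copy, as elsewhere
  ++bits;
  CopyBytes(&bits, &x, sizeof(x));
  return x;
}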
// Highway does not provide a lane type for 128-bit keys, so we use uint64_t
// along with an abstraction layer for single-lane vs. lane-pair, which is
// independent of the order.
template <typename LaneType, typename KeyType>
struct KeyLane : public KeyLaneBase<LaneType, KeyType> {
// For HeapSort
HWY_INLINE void Swap(LaneType* a, LaneType* b) const {
const LaneType temp = *a;
*a = *b;
*b = temp;
}
template <class V, class M>
HWY_INLINE V CompressKeys(V keys, M mask) const {
return CompressNot(keys, mask);
}
// Broadcasts one key into a vector
template <class D>
HWY_INLINE Vec<D> SetKey(D d, const LaneType* key) const {
return Set(d, *key);
}
template <class D>
HWY_INLINE Mask<D> EqualKeys(D /*tag*/, Vec<D> a, Vec<D> b) const {
return Eq(a, b);
}
template <class D>
HWY_INLINE Mask<D> NotEqualKeys(D /*tag*/, Vec<D> a, Vec<D> b) const {
return Ne(a, b);
}
// For keys=lanes, any difference counts.
template <class D>
HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec<D> diff) const {
// Must avoid floating-point comparisons (for -0)
const RebindToUnsigned<D> du;
return AllTrue(du, Eq(BitCast(du, diff), Zero(du)));
}
HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const {
return *a == *b;
}
template <class D>
HWY_INLINE Vec<D> ReverseKeys(D d, Vec<D> v) const {
return Reverse(d, v);
}
template <class D>
HWY_INLINE Vec<D> ReverseKeys2(D d, Vec<D> v) const {
return Reverse2(d, v);
}
template <class D>
HWY_INLINE Vec<D> ReverseKeys4(D d, Vec<D> v) const {
return Reverse4(d, v);
}
template <class D>
HWY_INLINE Vec<D> ReverseKeys8(D d, Vec<D> v) const {
return Reverse8(d, v);
}
template <class D>
HWY_INLINE Vec<D> ReverseKeys16(D d, Vec<D> v) const {
static_assert(SortConstants::kMaxCols <= 16, "Assumes u32x16 = 512 bit");
return ReverseKeys(d, v);
}
template <class V>
HWY_INLINE V OddEvenKeys(const V odd, const V even) const {
return OddEven(odd, even);
}
template <class D, HWY_IF_T_SIZE_D(D, 2)>
HWY_INLINE Vec<D> SwapAdjacentPairs(D d, const Vec<D> v) const {
const Repartition<uint32_t, D> du32;
return BitCast(d, Shuffle2301(BitCast(du32, v)));
}
template <class D, HWY_IF_T_SIZE_D(D, 4)>
HWY_INLINE Vec<D> SwapAdjacentPairs(D /* tag */, const Vec<D> v) const {
return Shuffle1032(v);
}
template <class D, HWY_IF_T_SIZE_D(D, 8)>
HWY_INLINE Vec<D> SwapAdjacentPairs(D /* tag */, const Vec<D> v) const {
return SwapAdjacentBlocks(v);
}
template <class D, HWY_IF_NOT_T_SIZE_D(D, 8)>
HWY_INLINE Vec<D> SwapAdjacentQuads(D d, const Vec<D> v) const {
#if HWY_HAVE_FLOAT64 // in case D is float32
const RepartitionToWide<D> dw;
#else
const RepartitionToWide<RebindToUnsigned<D>> dw;
#endif
return BitCast(d, SwapAdjacentPairs(dw, BitCast(dw, v)));
}
template <class D, HWY_IF_T_SIZE_D(D, 8)>
HWY_INLINE Vec<D> SwapAdjacentQuads(D d, const Vec<D> v) const {
// Assumes max vector size = 512
return ConcatLowerUpper(d, v, v);
}
template <class D, HWY_IF_NOT_T_SIZE_D(D, 8)>
HWY_INLINE Vec<D> OddEvenPairs(D d, const Vec<D> odd,
const Vec<D> even) const {
#if HWY_HAVE_FLOAT64 // in case D is float32
const RepartitionToWide<D> dw;
#else
const RepartitionToWide<RebindToUnsigned<D>> dw;
#endif
return BitCast(d, OddEven(BitCast(dw, odd), BitCast(dw, even)));
}
template <class D, HWY_IF_T_SIZE_D(D, 8)>
HWY_INLINE Vec<D> OddEvenPairs(D /* tag */, Vec<D> odd, Vec<D> even) const {
return OddEvenBlocks(odd, even);
}
template <class D, HWY_IF_NOT_T_SIZE_D(D, 8)>
HWY_INLINE Vec<D> OddEvenQuads(D d, Vec<D> odd, Vec<D> even) const {
#if HWY_HAVE_FLOAT64 // in case D is float32
const RepartitionToWide<D> dw;
#else
const RepartitionToWide<RebindToUnsigned<D>> dw;
#endif
return BitCast(d, OddEvenPairs(dw, BitCast(dw, odd), BitCast(dw, even)));
}
template <class D, HWY_IF_T_SIZE_D(D, 8)>
HWY_INLINE Vec<D> OddEvenQuads(D d, Vec<D> odd, Vec<D> even) const {
return ConcatUpperLower(d, odd, even);
}
};
// Anything order-related depends on the key traits *and* the order (see
// FirstOfLanes). We cannot implement just one Compare function because Lt128
// only compiles if the lane type is u64. Thus we need either overloaded
// functions with a tag type, class specializations, or separate classes.
// We avoid overloaded functions because we want all functions to be callable
// from a SortTraits without per-function wrappers. Specializing would work, but
// we are anyway going to specialize at a higher level.
template <typename T>
struct OrderAscending : public KeyLane<T, T> {
// False indicates the entire key (i.e. lane) should be compared. KV stands
// for key-value.
static constexpr bool IsKV() { return false; }
using Order = SortAscending;
using OrderForSortingNetwork = OrderAscending<T>;
HWY_INLINE bool Compare1(const T* a, const T* b) const { return *a < *b; }
template <class D>
HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const {
return Lt(a, b);
}
// Two halves of Sort2, used in ScanMinMax.
template <class D>
HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const {
return Min(a, b);
}
template <class D>
HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const {
return Max(a, b);
}
template <class D>
HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
T* HWY_RESTRICT /* buf */) const {
return MinOfLanes(d, v);
}
template <class D>
HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
T* HWY_RESTRICT /* buf */) const {
return MaxOfLanes(d, v);
}
template <class D>
HWY_INLINE Vec<D> FirstValue(D d) const {
return SmallestSortValue(d);
}
template <class D>
HWY_INLINE Vec<D> LastValue(D d) const {
return LargestSortValue(d);
}
template <class D>
HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const {
return SmallerSortValue(d, v);
}
};
template <typename T>
struct OrderDescending : public KeyLane<T, T> {
// False indicates the entire key (i.e. lane) should be compared. KV stands
// for key-value.
static constexpr bool IsKV() { return false; }
using Order = SortDescending;
using OrderForSortingNetwork = OrderDescending<T>;
HWY_INLINE bool Compare1(const T* a, const T* b) const { return *b < *a; }
template <class D>
HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const {
return Lt(b, a);
}
template <class D>
HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const {
return Max(a, b);
}
template <class D>
HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const {
return Min(a, b);
}
template <class D>
HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
T* HWY_RESTRICT /* buf */) const {
return MaxOfLanes(d, v);
}
template <class D>
HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
T* HWY_RESTRICT /* buf */) const {
return MinOfLanes(d, v);
}
template <class D>
HWY_INLINE Vec<D> FirstValue(D d) const {
return LargestSortValue(d);
}
template <class D>
HWY_INLINE Vec<D> LastValue(D d) const {
return SmallestSortValue(d);
}
template <class D>
HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const {
return LargerSortValue(d, v);
}
};
struct KeyValue64 : public KeyLane<uint64_t, hwy::K32V32> {
// True indicates only part of the key (i.e. lane) should be compared. KV
// stands for key-value.
static constexpr bool IsKV() { return true; }
template <class D>
HWY_INLINE Mask<D> EqualKeys(D /*tag*/, Vec<D> a, Vec<D> b) const {
return Eq(ShiftRight<32>(a), ShiftRight<32>(b));
}
template <class D>
HWY_INLINE Mask<D> NotEqualKeys(D /*tag*/, Vec<D> a, Vec<D> b) const {
return Ne(ShiftRight<32>(a), ShiftRight<32>(b));
}
HWY_INLINE bool Equal1(const uint64_t* a, const uint64_t* b) const {
return (*a >> 32) == (*b >> 32);
}
// Only count differences in the actual key, not the value.
template <class D>
HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec<D> diff) const {
// Must avoid floating-point comparisons (for -0)
const RebindToUnsigned<D> du;
const Vec<decltype(du)> zero = Zero(du);
const Vec<decltype(du)> keys = ShiftRight<32>(diff); // clear values
return AllTrue(du, Eq(BitCast(du, keys), zero));
}
};
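// A minimal scalar sketch of the lane layout the shifts above rely on: each
// K32V32 pair occupies one u64 lane with the key in the upper 32 bits and the
// value in the lower 32, so ShiftRight<32> isolates the key for comparisons.
//   uint64_t lane = (uint64_t{key} << 32) | value;
//   bool keys_equal = (lane_a >> 32) == (lane_b >> 32);  // value is ignored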
struct OrderAscendingKV64 : public KeyValue64 {
using Order = SortAscending;
using OrderForSortingNetwork = OrderAscending<LaneType>;
HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const {
return (*a >> 32) < (*b >> 32);
}
template <class D>
HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const {
return Lt(ShiftRight<32>(a), ShiftRight<32>(b));
}
// Not required to be stable (preserving the order of equivalent keys), so
// we can include the value in the comparison.
template <class D>
HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const {
return Min(a, b);
}
template <class D>
HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const {
return Max(a, b);
}
template <class D>
HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
uint64_t* HWY_RESTRICT /* buf */) const {
return MinOfLanes(d, v);
}
template <class D>
HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
uint64_t* HWY_RESTRICT /* buf */) const {
return MaxOfLanes(d, v);
}
// Same as for regular lanes.
template <class D>
HWY_INLINE Vec<D> FirstValue(D d) const {
return Set(d, hwy::LowestValue<TFromD<D>>());
}
template <class D>
HWY_INLINE Vec<D> LastValue(D d) const {
return Set(d, hwy::HighestValue<TFromD<D>>());
}
template <class D>
HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const {
return Sub(v, Set(d, uint64_t{1} << 32));
}
};
struct OrderDescendingKV64 : public KeyValue64 {
using Order = SortDescending;
using OrderForSortingNetwork = OrderDescending<LaneType>;
HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const {
return (*b >> 32) < (*a >> 32);
}
template <class D>
HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const {
return Lt(ShiftRight<32>(b), ShiftRight<32>(a));
}
// Not required to be stable (preserving the order of equivalent keys), so
// we can include the value in the comparison.
template <class D>
HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const {
return Max(a, b);
}
template <class D>
HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const {
return Min(a, b);
}
template <class D>
HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
uint64_t* HWY_RESTRICT /* buf */) const {
return MaxOfLanes(d, v);
}
template <class D>
HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
uint64_t* HWY_RESTRICT /* buf */) const {
return MinOfLanes(d, v);
}
template <class D>
HWY_INLINE Vec<D> FirstValue(D d) const {
return Set(d, hwy::HighestValue<TFromD<D>>());
}
template <class D>
HWY_INLINE Vec<D> LastValue(D d) const {
return Set(d, hwy::LowestValue<TFromD<D>>());
}
template <class D>
HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const {
return Add(v, Set(d, uint64_t{1} << 32));
}
};
// Shared code that depends on Order.
template <class Base>
struct TraitsLane : public Base {
using TraitsForSortingNetwork =
TraitsLane<typename Base::OrderForSortingNetwork>;
// For each lane i: replaces a[i] with the first and b[i] with the second
// according to Base.
// Corresponds to a conditional swap, which is one "node" of a sorting
// network. Min/Max are cheaper than compare + blend at least for integers.
template <class D>
HWY_INLINE void Sort2(D d, Vec<D>& a, Vec<D>& b) const {
const Base* base = static_cast<const Base*>(this);
const Vec<D> a_copy = a;
// Prior to AVX3, there is no native 64-bit Min/Max, so they compile to 4
// instructions. We can reduce it to a compare + 2 IfThenElse.
#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3
if (sizeof(TFromD<D>) == 8) {
const Mask<D> cmp = base->Compare(d, a, b);
a = IfThenElse(cmp, a, b);
b = IfThenElse(cmp, b, a_copy);
return;
}
#endif
a = base->First(d, a, b);
b = base->Last(d, a_copy, b);
}
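// Scalar analogue of the conditional-swap node above (a minimal sketch,
// assuming OrderAscending, where First = Min and Last = Max):
//   const T lo = (a < b) ? a : b;  // First
//   const T hi = (a < b) ? b : a;  // Last
//   a = lo; b = hi;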
// Conditionally swaps even-numbered lanes with their odd-numbered neighbor.
template <class D, HWY_IF_T_SIZE_D(D, 8)>
HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const {
const Base* base = static_cast<const Base*>(this);
Vec<D> swapped = base->ReverseKeys2(d, v);
// Further to the above optimization, Sort2+OddEvenKeys compile to four
// instructions; we can save one by combining two blends.
#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3
const Vec<D> cmp = VecFromMask(d, base->Compare(d, v, swapped));
return IfVecThenElse(DupOdd(cmp), swapped, v);
#else
Sort2(d, v, swapped);
return base->OddEvenKeys(swapped, v);
#endif
}
// (See above - we use Sort2 for non-64-bit types.)
template <class D, HWY_IF_NOT_T_SIZE_D(D, 8)>
HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const {
const Base* base = static_cast<const Base*>(this);
Vec<D> swapped = base->ReverseKeys2(d, v);
Sort2(d, v, swapped);
return base->OddEvenKeys(swapped, v);
}
// Swaps with the vector formed by reversing contiguous groups of 4 keys.
template <class D>
HWY_INLINE Vec<D> SortPairsReverse4(D d, Vec<D> v) const {
const Base* base = static_cast<const Base*>(this);
Vec<D> swapped = base->ReverseKeys4(d, v);
Sort2(d, v, swapped);
return base->OddEvenPairs(d, swapped, v);
}
// Conditionally swaps lane 0 with 4, 1 with 5 etc.
template <class D>
HWY_INLINE Vec<D> SortPairsDistance4(D d, Vec<D> v) const {
const Base* base = static_cast<const Base*>(this);
Vec<D> swapped = base->SwapAdjacentQuads(d, v);
// Only used in Merge16, which is never reached for 64-bit keys on AVX2
// (only 4 u64 lanes), so we skip the above optimization for 64-bit AVX2.

Sort2(d, v, swapped);
return base->OddEvenQuads(d, swapped, v);
}
};
} // namespace detail
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE

View File

@@ -0,0 +1,549 @@
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Per-target
#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE) == \
defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
#else
#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
#endif
#include <stddef.h>
#include <stdint.h>
#include "third_party/highway/hwy/contrib/sort/order.h" // SortDescending
#include "third_party/highway/hwy/contrib/sort/shared-inl.h"
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace detail {
// Also used by HeapSort, so do not require VQSORT_ENABLED.
#if HWY_TARGET != HWY_SCALAR || HWY_IDE
// Highway does not provide a lane type for 128-bit keys, so we use uint64_t
// along with an abstraction layer for single-lane vs. lane-pair, which is
// independent of the order.
struct KeyAny128 {
static constexpr bool Is128() { return true; }
constexpr size_t LanesPerKey() const { return 2; }
// What type bench_sort should allocate for generating inputs.
using LaneType = uint64_t;
// KeyType and KeyString are defined by derived classes.
HWY_INLINE void Swap(LaneType* a, LaneType* b) const {
const FixedTag<LaneType, 2> d;
const auto temp = LoadU(d, a);
StoreU(LoadU(d, b), d, a);
StoreU(temp, d, b);
}
template <class V, class M>
HWY_INLINE V CompressKeys(V keys, M mask) const {
return CompressBlocksNot(keys, mask);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> SetKey(D d, const TFromD<D>* key) const {
return LoadDup128(d, key);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> ReverseKeys(D d, Vec<D> v) const {
return ReverseBlocks(d, v);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> ReverseKeys2(D /* tag */, const Vec<D> v) const {
HWY_DASSERT(Lanes(D()) >= 4); // at least 2 keys
return SwapAdjacentBlocks(v);
}
// Only called for 4 keys because we do not support >512-bit vectors.
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> ReverseKeys4(D d, const Vec<D> v) const {
HWY_DASSERT(Lanes(D()) == 8); // exactly 4 keys: the 512-bit limit
return ReverseKeys(d, v);
}
// Only called for 4 keys because we do not support >512-bit vectors.
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> OddEvenPairs(D d, const Vec<D> odd,
const Vec<D> even) const {
HWY_DASSERT(Lanes(D()) == 8); // exactly 4 keys: the 512-bit limit
return ConcatUpperLower(d, odd, even);
}
template <class V>
HWY_INLINE V OddEvenKeys(const V odd, const V even) const {
return OddEvenBlocks(odd, even);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> ReverseKeys8(D, Vec<D>) const {
HWY_ASSERT(0); // not supported: would require 1024-bit vectors
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> ReverseKeys16(D, Vec<D>) const {
HWY_ASSERT(0); // not supported: would require 2048-bit vectors
}
// This is only called for 8/16 col networks (not supported).
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> SwapAdjacentPairs(D, Vec<D>) const {
HWY_ASSERT(0);
}
// This is only called for 16 col networks (not supported).
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> SwapAdjacentQuads(D, Vec<D>) const {
HWY_ASSERT(0);
}
// This is only called for 8 col networks (not supported).
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> OddEvenQuads(D, Vec<D>, Vec<D>) const {
HWY_ASSERT(0);
}
};
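// Lane-layout sketch for the two-lane keys above (lo64/hi64 are placeholder
// names for the two halves of one 128-bit key): lane 0 holds the low half and
// lane 1 the high half, which is why Compare1 below checks index 1 first.
//   uint64_t lanes[2] = {lo64, hi64};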
// Base class shared between OrderAscending128, OrderDescending128.
struct Key128 : public KeyAny128 {
// False indicates the entire key should be compared. KV means key-value.
static constexpr bool IsKV() { return false; }
// What type to pass to VQSort.
using KeyType = hwy::uint128_t;
const char* KeyString() const { return "U128"; }
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Mask<D> EqualKeys(D d, Vec<D> a, Vec<D> b) const {
return Eq128(d, a, b);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Mask<D> NotEqualKeys(D d, Vec<D> a, Vec<D> b) const {
return Ne128(d, a, b);
}
// For keys=entire 128 bits, any difference counts.
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec<D> diff) const {
// Must avoid floating-point comparisons (for -0)
const RebindToUnsigned<D> du;
return AllTrue(du, Eq(BitCast(du, diff), Zero(du)));
}
HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const {
return a[0] == b[0] && a[1] == b[1];
}
// Returns vector with only the top half of each block valid. This allows
// fusing the "replicate upper to lower half" step with a subsequent permute.
template <class Order, class D>
HWY_INLINE HWY_MAYBE_UNUSED Vec<D> CompareTop(D d, Vec<D> a, Vec<D> b) const {
const Mask<D> eqHL = Eq(a, b);
const Vec<D> ltHL = VecFromMask(d, Order().CompareLanes(a, b));
#if HWY_TARGET <= HWY_AVX2 // slightly faster
const Vec<D> ltLX = ShiftLeftLanes<1>(ltHL);
return OrAnd(ltHL, VecFromMask(d, eqHL), ltLX);
#else
return IfThenElse(eqHL, DupEven(ltHL), ltHL);
#endif
}
};
// Anything order-related depends on the key traits *and* the order (see
// FirstOfLanes). We cannot implement just one Compare function because Lt128
// only compiles if the lane type is u64. Thus we need either overloaded
// functions with a tag type, class specializations, or separate classes.
// We avoid overloaded functions because we want all functions to be callable
// from a SortTraits without per-function wrappers. Specializing would work, but
// we are anyway going to specialize at a higher level.
struct OrderAscending128 : public Key128 {
using Order = SortAscending;
using OrderForSortingNetwork = OrderAscending128;
HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const {
return (a[1] == b[1]) ? a[0] < b[0] : a[1] < b[1];
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
return Lt128(d, a, b);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
return Min128(d, a, b);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
return Max128(d, a, b);
}
// FirstOfLanes/LastOfLanes are implemented in Traits128.
// Same as for regular lanes because 128-bit keys are u64.
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> FirstValue(D d) const {
return Set(d, hwy::LowestValue<TFromD<D> >());
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> LastValue(D d) const {
return Set(d, hwy::HighestValue<TFromD<D> >());
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const {
const Vec<D> k0 = Zero(d);
const Vec<D> k1 = OddEven(k0, Set(d, uint64_t{1}));
const Mask<D> borrow = Eq(v, k0); // don't-care, lo == 0
// lo == 0? 1 : 0, 0
const Vec<D> adjust = ShiftLeftLanes<1>(IfThenElseZero(borrow, k1));
return Sub(Sub(v, k1), adjust);
}
// 'Private', used by base class Key128::CompareTop.
template <class V>
HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
return Lt(a, b);
}
};
struct OrderDescending128 : public Key128 {
using Order = SortDescending;
using OrderForSortingNetwork = OrderDescending128;
HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const {
return (a[1] == b[1]) ? b[0] < a[0] : b[1] < a[1];
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
return Lt128(d, b, a);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
return Max128(d, a, b);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
return Min128(d, a, b);
}
// FirstOfLanes/LastOfLanes are implemented in Traits128.
// Same as for regular lanes because 128-bit keys are u64.
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> FirstValue(D d) const {
return Set(d, hwy::HighestValue<TFromD<D> >());
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> LastValue(D d) const {
return Set(d, hwy::LowestValue<TFromD<D> >());
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const {
const Vec<D> k1 = OddEven(Zero(d), Set(d, uint64_t{1}));
const Vec<D> added = Add(v, k1);
const Mask<D> overflowed = Lt(added, v); // false, overflowed
// overflowed? 1 : 0, 0
const Vec<D> adjust = ShiftLeftLanes<1>(IfThenElseZero(overflowed, k1));
return Add(added, adjust);
}
// 'Private', used by base class Key128::CompareTop.
template <class V>
HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
return Lt(b, a);
}
};
// Base class shared between OrderAscendingKV128, OrderDescendingKV128.
struct KeyValue128 : public KeyAny128 {
// True indicates only part of the key (the more significant lane) should be
// compared. KV stands for key-value.
static constexpr bool IsKV() { return true; }
// What type to pass to VQSort.
using KeyType = K64V64;
const char* KeyString() const { return "k+v=128"; }
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Mask<D> EqualKeys(D d, Vec<D> a, Vec<D> b) const {
return Eq128Upper(d, a, b);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Mask<D> NotEqualKeys(D d, Vec<D> a, Vec<D> b) const {
return Ne128Upper(d, a, b);
}
HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const {
return a[1] == b[1];
}
// Only count differences in the actual key, not the value.
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec<D> diff) const {
// Must avoid floating-point comparisons (for -0)
const RebindToUnsigned<D> du;
const Vec<decltype(du)> zero = Zero(du);
const Vec<decltype(du)> keys = OddEven(diff, zero); // clear values
return AllTrue(du, Eq(BitCast(du, keys), zero));
}
// Returns vector with only the top half of each block valid. This allows
// fusing the "replicate upper to lower half" step with a subsequent permute.
template <class Order, class D>
HWY_INLINE HWY_MAYBE_UNUSED Vec<D> CompareTop(D d, Vec<D> a, Vec<D> b) const {
// Only the upper lane of each block is a key, and only that lane is
// required to be valid, so comparing all lanes is sufficient.
return VecFromMask(d, Order().CompareLanes(a, b));
}
};
struct OrderAscendingKV128 : public KeyValue128 {
using Order = SortAscending;
using OrderForSortingNetwork = OrderAscending128;
HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const {
return a[1] < b[1];
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
return Lt128Upper(d, a, b);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
return Min128Upper(d, a, b);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
return Max128Upper(d, a, b);
}
// FirstOfLanes/LastOfLanes are implemented in Traits128.
// Same as for regular lanes because 128-bit keys are u64.
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> FirstValue(D d) const {
return Set(d, hwy::LowestValue<TFromD<D> >());
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> LastValue(D d) const {
return Set(d, hwy::HighestValue<TFromD<D> >());
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const {
const Vec<D> k1 = OddEven(Set(d, uint64_t{1}), Zero(d));
return Sub(v, k1);
}
// 'Private', used by base class KeyValue128::CompareTop.
template <class V>
HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
return Lt(a, b);
}
};
struct OrderDescendingKV128 : public KeyValue128 {
using Order = SortDescending;
using OrderForSortingNetwork = OrderDescending128;
HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const {
return b[1] < a[1];
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
return Lt128Upper(d, b, a);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
return Max128Upper(d, a, b);
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
return Min128Upper(d, a, b);
}
// FirstOfLanes/LastOfLanes are implemented in Traits128.
// Same as for regular lanes because 128-bit keys are u64.
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> FirstValue(D d) const {
return Set(d, hwy::HighestValue<TFromD<D> >());
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> LastValue(D d) const {
return Set(d, hwy::LowestValue<TFromD<D> >());
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const {
const Vec<D> k1 = OddEven(Set(d, uint64_t{1}), Zero(d));
return Add(v, k1);
}
// 'Private', used by base class KeyValue128::CompareTop.
template <class V>
HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
return Lt(b, a);
}
};
// We want to swap 2 u128, i.e. 4 u64 lanes, based on the 0 or FF..FF mask in
// the most-significant of those lanes (the result of CompareTop), so
// replicate it 4x. Only called for >= 256-bit vectors.
#if HWY_TARGET <= HWY_AVX3
template <class V, HWY_IF_V_SIZE_V(V, 64)>
HWY_INLINE V ReplicateTop4x(V v) {
return V{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))};
}
#endif // HWY_TARGET <= HWY_AVX3
#if HWY_TARGET <= HWY_AVX2
template <class V, HWY_IF_V_SIZE_V(V, 32)>
HWY_INLINE V ReplicateTop4x(V v) {
return V{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))};
}
#else // HWY_TARGET > HWY_AVX2
template <class V>
HWY_INLINE V ReplicateTop4x(V v) {
#if HWY_TARGET == HWY_SVE_256
return svdup_lane_u64(v, 3);
#else
const ScalableTag<uint64_t> d;
HWY_DASSERT(Lanes(d) == 4 || Lanes(d) == 8); // for table below
HWY_ALIGN static constexpr uint64_t kIndices[8] = {3, 3, 3, 3, 7, 7, 7, 7};
return TableLookupLanes(v, SetTableIndices(d, kIndices));
#endif
}
#endif // HWY_TARGET <= HWY_AVX2
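// Sketch of the effect of ReplicateTop4x on a 256-bit vector (4 u64 lanes):
// the comparison result produced by CompareTop lives in the top lane of each
// 128-bit block, and broadcasting lane 3 across lanes 0..3 yields
//   {m0, m1, m2, m3} -> {m3, m3, m3, m3}
// so a single IfVecThenElse can swap an entire pair of 128-bit keys.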
// Shared code that depends on Order.
template <class Base>
struct Traits128 : public Base {
using TraitsForSortingNetwork =
Traits128<typename Base::OrderForSortingNetwork>;
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
TFromD<D>* HWY_RESTRICT buf) const {
const Base* base = static_cast<const Base*>(this);
const size_t N = Lanes(d);
Store(v, d, buf);
v = base->SetKey(d, buf + 0); // result must be broadcasted
for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) {
v = base->First(d, v, base->SetKey(d, buf + i));
}
return v;
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
TFromD<D>* HWY_RESTRICT buf) const {
const Base* base = static_cast<const Base*>(this);
const size_t N = Lanes(d);
Store(v, d, buf);
v = base->SetKey(d, buf + 0); // result must be broadcasted
for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) {
v = base->Last(d, v, base->SetKey(d, buf + i));
}
return v;
}
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE void Sort2(D d, Vec<D>& a, Vec<D>& b) const {
const Base* base = static_cast<const Base*>(this);
const Vec<D> a_copy = a;
const auto lt = base->Compare(d, a, b);
a = IfThenElse(lt, a, b);
b = IfThenElse(lt, b, a_copy);
}
// Conditionally swaps even-numbered keys with their odd-numbered neighbor.
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const {
HWY_DASSERT(Lanes(d) >= 4); // required by ReplicateTop4x
const Base* base = static_cast<const Base*>(this);
Vec<D> swapped = base->ReverseKeys2(d, v);
const Vec<D> cmpHx = base->template CompareTop<Base>(d, v, swapped);
return IfVecThenElse(ReplicateTop4x(cmpHx), swapped, v);
}
// Swaps with the vector formed by reversing contiguous groups of four 128-bit
// keys, which implies 512-bit vectors (we do not support more than that).
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> SortPairsReverse4(D d, Vec<D> v) const {
HWY_DASSERT(Lanes(d) == 8); // For TableLookupLanes below
const Base* base = static_cast<const Base*>(this);
Vec<D> swapped = base->ReverseKeys4(d, v);
const Vec<D> cmpHx = base->template CompareTop<Base>(d, v, swapped);
// Similar to ReplicateTop4x, we want to gang together 2 comparison results
// (4 lanes). They are not contiguous, so use permute to replicate 4x.
HWY_ALIGN uint64_t kIndices[8] = {7, 7, 5, 5, 5, 5, 7, 7};
const Vec<D> select = TableLookupLanes(cmpHx, SetTableIndices(d, kIndices));
return IfVecThenElse(select, swapped, v);
}
// Conditionally swaps lane 0 with 4, 1 with 5 etc.
template <class D, HWY_IF_U64_D(D)>
HWY_INLINE Vec<D> SortPairsDistance4(D, Vec<D>) const {
// Only used by Merge16, which would require 2048 bit vectors (unsupported).
HWY_ASSERT(0);
}
};
#endif // HWY_TARGET != HWY_SCALAR
} // namespace detail
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE

File diff suppressed because it is too large

View File

@@ -0,0 +1,303 @@
// Copyright 2022 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Interface to vectorized quicksort with dynamic dispatch. For static dispatch
// without any DLLEXPORT, avoid including this header and instead define
// VQSORT_ONLY_STATIC, then call VQSortStatic* in vqsort-inl.h.
//
// Blog post: https://tinyurl.com/vqsort-blog
// Paper with measurements: https://arxiv.org/abs/2205.05982
//
// To ensure the overhead of using wide vectors (e.g. AVX2 or AVX-512) is
// worthwhile, we recommend using this code for sorting arrays whose size is at
// least 100 KiB. See the README for details.
#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_
#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_
// IWYU pragma: begin_exports
#include <stddef.h>
#include "third_party/highway/hwy/base.h"
#include "third_party/highway/hwy/contrib/sort/order.h" // SortAscending
// IWYU pragma: end_exports
namespace hwy {
// Vectorized Quicksort: sorts keys[0, n). Does not preserve the ordering of
// equivalent keys (defined as: neither greater nor less than another).
// Dispatches to the best available instruction set. Does not allocate memory.
// Uses about 1.2 KiB stack plus an internal 3-word TLS cache for random state.
HWY_CONTRIB_DLLEXPORT void VQSort(uint16_t* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(uint16_t* HWY_RESTRICT keys, size_t n,
SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSort(uint32_t* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(uint32_t* HWY_RESTRICT keys, size_t n,
SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSort(uint64_t* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(uint64_t* HWY_RESTRICT keys, size_t n,
SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSort(int16_t* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(int16_t* HWY_RESTRICT keys, size_t n,
SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSort(int32_t* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(int32_t* HWY_RESTRICT keys, size_t n,
SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSort(int64_t* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(int64_t* HWY_RESTRICT keys, size_t n,
SortDescending);
// These two must only be called if hwy::HaveFloat16() is true.
HWY_CONTRIB_DLLEXPORT void VQSort(float16_t* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(float16_t* HWY_RESTRICT keys, size_t n,
SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSort(float* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(float* HWY_RESTRICT keys, size_t n,
SortDescending);
// These two must only be called if hwy::HaveFloat64() is true.
HWY_CONTRIB_DLLEXPORT void VQSort(double* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(double* HWY_RESTRICT keys, size_t n,
SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSort(K32V32* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(K32V32* HWY_RESTRICT keys, size_t n,
SortDescending);
// 128-bit types: `n` is still in units of the 128-bit keys.
HWY_CONTRIB_DLLEXPORT void VQSort(uint128_t* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(uint128_t* HWY_RESTRICT keys, size_t n,
SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSort(K64V64* HWY_RESTRICT keys, size_t n,
SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSort(K64V64* HWY_RESTRICT keys, size_t n,
SortDescending);
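// A minimal usage sketch for the overloads above (keys and n are
// caller-provided; assumes this header is included and the contrib sort
// library is linked):
//   void SortU64(uint64_t* keys, size_t n) {
//     hwy::VQSort(keys, n, hwy::SortAscending());
//   }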
// Vectorized partial Quicksort:
// Rearranges elements such that the range [0, k) contains the sorted first k
// elements in the range [0, n). Does not preserve the ordering of equivalent
// keys (defined as: neither greater nor less than another).
// Dispatches to the best available instruction set. Does not allocate memory.
// Uses about 1.2 KiB stack plus an internal 3-word TLS cache for random state.
HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint32_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint32_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint64_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint64_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(int16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(int16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(int32_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(int32_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(int64_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(int64_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
// These two must only be called if hwy::HaveFloat16() is true.
HWY_CONTRIB_DLLEXPORT void VQPartialSort(float16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(float16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(float* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(float* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
// These two must only be called if hwy::HaveFloat64() is true.
HWY_CONTRIB_DLLEXPORT void VQPartialSort(double* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(double* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(K32V32* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(K32V32* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
// 128-bit types: `n` and `k` are still in units of the 128-bit keys.
HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint128_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint128_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(K64V64* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQPartialSort(K64V64* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
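// Usage sketch for the overloads above (keys, n, k are caller-provided and
// `keys` is assumed to point to uint64_t): keep the k smallest keys, sorted,
// at the front of the array.
//   hwy::VQPartialSort(keys, n, k, hwy::SortAscending());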
// Vectorized Quickselect:
// rearranges elements in [0, n) such that the element at index k becomes the
// element that would occupy that position if [0, n) were sorted. All of the
// elements before this new k-th element are less than or equal to the
// elements after it.
HWY_CONTRIB_DLLEXPORT void VQSelect(uint16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(uint16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSelect(uint32_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(uint32_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSelect(uint64_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(uint64_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSelect(int16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(int16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSelect(int32_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(int32_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSelect(int64_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(int64_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
// These two must only be called if hwy::HaveFloat16() is true.
HWY_CONTRIB_DLLEXPORT void VQSelect(float16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(float16_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSelect(float* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(float* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
// These two must only be called if hwy::HaveFloat64() is true.
HWY_CONTRIB_DLLEXPORT void VQSelect(double* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(double* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSelect(K32V32* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(K32V32* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
// 128-bit types: `n` and `k` are still in units of the 128-bit keys.
HWY_CONTRIB_DLLEXPORT void VQSelect(uint128_t* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(uint128_t* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
HWY_CONTRIB_DLLEXPORT void VQSelect(K64V64* HWY_RESTRICT keys, size_t n,
size_t k, SortAscending);
HWY_CONTRIB_DLLEXPORT void VQSelect(K64V64* HWY_RESTRICT keys, size_t n,
size_t k, SortDescending);
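// Usage sketch for the overloads above: an ascending median selection places
// the element that would land at index n / 2 in a full sort at that index,
// without fully sorting.
//   hwy::VQSelect(keys, n, n / 2, hwy::SortAscending());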
// User-level caching is no longer required, so this class is no longer
// beneficial. We recommend using the simpler VQSort() interface instead, and
// retain this class only for compatibility. It now just calls VQSort.
class HWY_CONTRIB_DLLEXPORT Sorter {
public:
Sorter();
~Sorter() { Delete(); }
// Move-only
Sorter(const Sorter&) = delete;
Sorter& operator=(const Sorter&) = delete;
Sorter(Sorter&& /*other*/) {}
Sorter& operator=(Sorter&& /*other*/) { return *this; }
void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
// These two must only be called if hwy::HaveFloat16() is true.
void operator()(float16_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(float16_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
void operator()(float* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(float* HWY_RESTRICT keys, size_t n, SortDescending) const;
// These two must only be called if hwy::HaveFloat64() is true.
void operator()(double* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(double* HWY_RESTRICT keys, size_t n, SortDescending) const;
void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortDescending) const;
void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortDescending) const;
void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortAscending) const;
void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortDescending) const;
// Unused
static void Fill24Bytes(const void*, size_t, void*);
static bool HaveFloat64(); // Can also use hwy::HaveFloat64 directly.
private:
void Delete();
template <typename T>
T* Get() const {
return unused_;
}
#if HWY_COMPILER_CLANG
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wunused-private-field")
#endif
void* unused_ = nullptr;
#if HWY_COMPILER_CLANG
HWY_DIAGNOSTICS(pop)
#endif
};
// Used by vqsort-inl.h unless VQSORT_ONLY_STATIC.
HWY_CONTRIB_DLLEXPORT bool Fill16BytesSecure(void* bytes);
// Unused, only provided for binary compatibility.
HWY_CONTRIB_DLLEXPORT uint64_t* GetGeneratorState();
} // namespace hwy
#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_

View File

@@ -0,0 +1,247 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_FUTEX_H_
#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_FUTEX_H_
// Keyed event (futex): kernel queue of blocked threads, identified by the
// address of an atomic u32 called `current` within the same process (do NOT
// use with shared-memory mappings).
//
// Futex equivalents: https://outerproduct.net/futex-dictionary.html; we
// support Linux/Emscripten/Apple/Windows and C++20 std::atomic::wait, plus a
// NanoSleep fallback.
#include <time.h>
#include <atomic>
#include <climits> // INT_MAX
#include "third_party/highway/hwy/base.h"
#if HWY_ARCH_WASM
#include <emscripten/threading.h>
#include <math.h> // INFINITY
#elif HWY_OS_LINUX
#include <errno.h> // IWYU pragma: keep
#include <linux/futex.h> // FUTEX_*
#include <pthread.h>
#include <sys/syscall.h> // SYS_*
#include <unistd.h>
// Android may not declare these:
#ifndef SYS_futex
#ifdef SYS_futex_time64 // 32-bit with 64-bit time_t
#define SYS_futex SYS_futex_time64
#else
#define SYS_futex __NR_futex
#endif // SYS_futex_time64
#endif // SYS_futex
#ifndef FUTEX_WAIT_PRIVATE
#define FUTEX_WAIT_PRIVATE (FUTEX_WAIT | 128)
#endif
#ifndef FUTEX_WAKE_PRIVATE
#define FUTEX_WAKE_PRIVATE (FUTEX_WAKE | 128)
#endif
#elif HWY_OS_APPLE && !defined(HWY_DISABLE_FUTEX)
// These are private APIs, so add an opt-out.
extern "C" {
int __ulock_wait(uint32_t op, void* address, uint64_t val, uint32_t max_us);
int __ulock_wake(uint32_t op, void* address, uint64_t zero);
} // extern "C"
#define UL_COMPARE_AND_WAIT 1
#define ULF_WAKE_ALL 0x00000100
#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX)
// WakeByAddressAll requires Windows 8, so add an opt-out.
#ifndef NOMINMAX
#define NOMINMAX
#endif // NOMINMAX
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif // WIN32_LEAN_AND_MEAN
#include <windows.h>
#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL
#pragma comment(lib, "synchronization.lib")
#endif
#elif HWY_CXX_LANG < 202002L // NOT C++20, which has native support
#define HWY_FUTEX_SLEEP
#endif
namespace hwy {
// Attempts to pause for the specified nanoseconds, though the resolution is
// closer to 0.1 microseconds. Returns false if no wait happened. Thread-safe.
static inline bool NanoSleep(uint64_t ns) {
#if HWY_OS_WIN
static thread_local HANDLE hTimer = nullptr;
if (HWY_UNLIKELY(hTimer == nullptr)) {
// Must be manual reset: auto-reset would immediately signal after the next
// SetWaitableTimer.
hTimer = CreateWaitableTimer(nullptr, TRUE, nullptr);
if (hTimer == nullptr) return false;
}
// Negative means relative, in units of 100 ns.
LARGE_INTEGER time;
time.QuadPart = -static_cast<LONGLONG>(ns / 100);
const LONG period = 0; // signal once
if (!SetWaitableTimer(hTimer, &time, period, nullptr, nullptr, FALSE)) {
return false;
}
(void)WaitForSingleObject(hTimer, INFINITE);
return true;
#else
timespec duration;
duration.tv_sec = static_cast<time_t>(ns / 1000000000);
duration.tv_nsec = static_cast<decltype(duration.tv_nsec)>(ns % 1000000000);
timespec remainder;
// Repeat if interrupted by a signal. Note that the remainder may be rounded
// up, which could cause an infinite loop if continually interrupted. Using
// clock_nanosleep would work, but we'd have to get the current time. We
// assume durations are short, and instead just cap the number of retries.
for (int rep = 0; rep < 3; ++rep) {
if (nanosleep(&duration, &remainder) == 0 || errno != EINTR) break;
duration = remainder;
}
return true;
#endif
}
// Waits until `current != prev` and returns the new value. May return
// immediately if `current` already changed, or after blocking and waking.
static inline uint32_t BlockUntilDifferent(
const uint32_t prev, const std::atomic<uint32_t>& current) {
const auto acq = std::memory_order_acquire;
#if HWY_ARCH_WASM
// It is always safe to cast to void.
volatile void* address =
const_cast<volatile void*>(static_cast<const volatile void*>(&current));
const double max_ms = INFINITY;
for (;;) {
const uint32_t next = current.load(acq);
if (next != prev) return next;
const int ret = emscripten_futex_wait(address, prev, max_ms);
HWY_DASSERT(ret >= 0);
(void)ret;
}
#elif HWY_OS_LINUX
// Safe to cast because std::atomic is a standard layout type.
const uint32_t* address = reinterpret_cast<const uint32_t*>(&current);
// _PRIVATE requires this only be used in the same process, and avoids
// virtual->physical lookups and atomic reference counting.
const int op = FUTEX_WAIT_PRIVATE;
for (;;) {
const uint32_t next = current.load(acq);
if (next != prev) return next;
// timeout=null may prevent interrupts via signal. No lvalue because
// the timespec type is only standardized since C++17 or C11.
const auto ret = syscall(SYS_futex, address, op, prev, nullptr, nullptr, 0);
if (ret == -1) {
HWY_DASSERT(errno == EAGAIN); // otherwise an actual error
}
}
#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX)
// It is always safe to cast to void.
volatile void* address =
const_cast<volatile void*>(static_cast<const volatile void*>(&current));
// API is not const-correct, but only loads from the pointer.
PVOID pprev = const_cast<void*>(static_cast<const void*>(&prev));
const DWORD max_ms = INFINITE;
for (;;) {
const uint32_t next = current.load(acq);
if (next != prev) return next;
const BOOL ok = WaitOnAddress(address, pprev, sizeof(prev), max_ms);
HWY_DASSERT(ok);
(void)ok;
}
#elif HWY_OS_APPLE && !defined(HWY_DISABLE_FUTEX)
// It is always safe to cast to void.
void* address = const_cast<void*>(static_cast<const void*>(&current));
for (;;) {
const uint32_t next = current.load(acq);
if (next != prev) return next;
__ulock_wait(UL_COMPARE_AND_WAIT, address, prev, 0);
}
#elif defined(HWY_FUTEX_SLEEP)
for (;;) {
const uint32_t next = current.load(acq);
if (next != prev) return next;
NanoSleep(2000);
}
#elif HWY_CXX_LANG >= 202002L
current.wait(prev, acq); // No spurious wakeup.
const uint32_t next = current.load(acq);
HWY_DASSERT(next != prev);
return next;
#else
#error "Logic error, should have reached HWY_FUTEX_SLEEP"
#endif // HWY_OS_*
} // BlockUntilDifferent
// Wakes all threads, if any, that are waiting because they called
// `BlockUntilDifferent` with the same `current`.
static inline void WakeAll(std::atomic<uint32_t>& current) {
#if HWY_ARCH_WASM
// It is always safe to cast to void.
volatile void* address = static_cast<volatile void*>(&current);
const int max_to_wake = INT_MAX; // actually signed
const int ret = emscripten_futex_wake(address, max_to_wake);
HWY_DASSERT(ret >= 0);
(void)ret;
#elif HWY_OS_LINUX
// Safe to cast because std::atomic is a standard layout type.
uint32_t* address = reinterpret_cast<uint32_t*>(&current);
const int max_to_wake = INT_MAX; // actually signed
const auto ret = syscall(SYS_futex, address, FUTEX_WAKE_PRIVATE, max_to_wake,
nullptr, nullptr, 0);
HWY_DASSERT(ret >= 0); // number woken
(void)ret;
#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX)
// It is always safe to cast to void.
void* address = static_cast<void*>(&current);
WakeByAddressAll(address);
#elif HWY_OS_APPLE && !defined(HWY_DISABLE_FUTEX)
// It is always safe to cast to void.
void* address = static_cast<void*>(&current);
__ulock_wake(UL_COMPARE_AND_WAIT | ULF_WAKE_ALL, address, 0);
#elif defined(HWY_FUTEX_SLEEP)
// NanoSleep loop does not require wakeup.
(void)current;
#elif HWY_CXX_LANG >= 202002L
current.notify_all();
#else
#error "Logic error, should have reached HWY_FUTEX_SLEEP"
#endif
} // WakeAll
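// A minimal wait/notify sketch using the two functions above (`flag` is a
// caller-owned atomic shared by threads of the same process):
//   std::atomic<uint32_t> flag{0};
//   // Waiter thread: blocks until the value is no longer 0.
//   const uint32_t seen = BlockUntilDifferent(0, flag);
//   // Notifier thread: publish the new value, then wake all waiters.
//   flag.store(1, std::memory_order_release);
//   WakeAll(flag);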
} // namespace hwy
#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_FUTEX_H_

View File

@@ -0,0 +1,328 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_
#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_
// Relatively power-efficient spin lock for low-latency synchronization.
#include <stdint.h>
#include <atomic>
#include "third_party/highway/hwy/base.h"
#include "third_party/highway/hwy/cache_control.h" // Pause
#ifndef HWY_ENABLE_MONITORX // allow override
// Clang 3.9 suffices for mwaitx, but the target pragma requires 9.0.
#if HWY_ARCH_X86 && ((HWY_COMPILER_CLANG >= 900) || \
(HWY_COMPILER_GCC_ACTUAL >= 502) || defined(__MWAITX__))
#define HWY_ENABLE_MONITORX 1
#else
#define HWY_ENABLE_MONITORX 0
#endif
#endif // HWY_ENABLE_MONITORX
#ifndef HWY_ENABLE_UMONITOR // allow override
#if HWY_ARCH_X86 && ((HWY_COMPILER_CLANG >= 900) || \
(HWY_COMPILER_GCC_ACTUAL >= 901) || defined(__WAITPKG__))
#define HWY_ENABLE_UMONITOR 1
#else
#define HWY_ENABLE_UMONITOR 0
#endif
#endif // HWY_ENABLE_UMONITOR
// Inline assembly is preferred because it allows inlining of `UntilDifferent`
// etc, but we also support intrinsics for MSVC.
#ifndef HWY_ENABLE_SPIN_ASM // allow override
#if (HWY_COMPILER_CLANG || HWY_COMPILER_GCC) && HWY_ARCH_X86_64
#define HWY_ENABLE_SPIN_ASM 1
#else
#define HWY_ENABLE_SPIN_ASM 0
#endif
#endif // HWY_ENABLE_SPIN_ASM
#if HWY_ENABLE_MONITORX || HWY_ENABLE_UMONITOR
#if HWY_ENABLE_SPIN_ASM
#define HWY_INLINE_SPIN HWY_INLINE // can inline functions with inline assembly
#else
// Intrinsics require attributes, which prevent inlining.
#define HWY_INLINE_SPIN
#include <x86intrin.h>
#endif // HWY_ENABLE_SPIN_ASM
#include "third_party/highway/hwy/x86_cpuid.h"
#endif // HWY_ENABLE_MONITORX || HWY_ENABLE_UMONITOR
namespace hwy {
// Returned by `UntilDifferent` in a single register.
struct SpinResult {
// We also use u32 because that is all that futex.h supports.
uint32_t current;
// Number of retries before returning, useful for checking that the
// monitor/wait did not just return immediately.
uint32_t reps;
};
// User-space monitor/wait are supported on Zen2+ AMD and SPR+ Intel. Spin waits
// are rarely called from SIMD code, hence we do not integrate this into
// `HWY_TARGET` and its runtime dispatch mechanism. Returned by `Type()`, also
// used by callers to set the `disabled` argument for `DetectSpin`.
enum class SpinType : uint8_t {
kMonitorX = 1, // AMD
kUMonitor, // Intel
kPause,
kSentinel // for iterating over all enumerators. Must be last.
};
// For printing which is in use.
static inline const char* ToString(SpinType type) {
switch (type) {
case SpinType::kMonitorX:
return "MonitorX_C1";
case SpinType::kUMonitor:
return "UMonitor_C0.2";
case SpinType::kPause:
return "Pause";
case SpinType::kSentinel:
return nullptr;
default:
HWY_UNREACHABLE;
}
}
// Indirect function calls turn out to be too expensive because this is called
// multiple times per ThreadPool barrier. We will instead inline the spin and
// barrier using policy classes. This one is always available; use it as a
// reference for the interface. Note that Pause varies across CPUs: it can be
// a no-op, or wait 140 cycles.
struct SpinPause {
SpinType Type() const { return SpinType::kPause; }
// Spins until `watched != prev` and returns the new value, similar to
// `BlockUntilDifferent` in `futex.h`.
HWY_INLINE SpinResult UntilDifferent(
const uint32_t prev, const std::atomic<uint32_t>& watched) const {
for (uint32_t reps = 0;; ++reps) {
const uint32_t current = watched.load(std::memory_order_acquire);
if (current != prev) return SpinResult{current, reps};
hwy::Pause();
}
}
// Returns number of retries until `watched == expected`.
HWY_INLINE size_t UntilEqual(const uint32_t expected,
const std::atomic<uint32_t>& watched) const {
for (size_t reps = 0;; ++reps) {
const uint32_t current = watched.load(std::memory_order_acquire);
if (current == expected) return reps;
hwy::Pause();
}
}
};
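// Usage sketch of the reference interface above (`prev` and `watched` are
// caller-provided; busy-waits until another thread stores a new value):
//   const SpinResult result = SpinPause().UntilDifferent(prev, watched);
//   // result.current is the new value, result.reps the number of retries.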
#if HWY_ENABLE_MONITORX || HWY_IDE
#if !HWY_ENABLE_SPIN_ASM
HWY_PUSH_ATTRIBUTES("mwaitx")
#endif
// AMD's user-mode monitor/wait (Zen2+).
class SpinMonitorX {
public:
SpinType Type() const { return SpinType::kMonitorX; }
HWY_INLINE_SPIN SpinResult UntilDifferent(
const uint32_t prev, const std::atomic<uint32_t>& watched) const {
for (uint32_t reps = 0;; ++reps) {
uint32_t current = watched.load(std::memory_order_acquire);
if (current != prev) return SpinResult{current, reps};
Monitor(&watched);
// Double-checked 'lock' to avoid missed events:
current = watched.load(std::memory_order_acquire);
if (current != prev) return SpinResult{current, reps};
Wait();
}
}
HWY_INLINE_SPIN size_t UntilEqual(
const uint32_t expected, const std::atomic<uint32_t>& watched) const {
for (size_t reps = 0;; ++reps) {
uint32_t current = watched.load(std::memory_order_acquire);
if (current == expected) return reps;
Monitor(&watched);
// Double-checked 'lock' to avoid missed events:
current = watched.load(std::memory_order_acquire);
if (current == expected) return reps;
Wait();
}
}
private:
static HWY_INLINE void Monitor(const void* addr) {
// No extensions/hints currently defined.
#if HWY_ENABLE_SPIN_ASM
asm volatile("monitorx" ::"a"(addr), "c"(0), "d"(0));
#else
_mm_monitorx(const_cast<void*>(addr), 0, 0);
#endif
}
static HWY_INLINE void Wait() {
#if HWY_ENABLE_SPIN_ASM
// EBX=0 cycles means no timeout/infinite.
asm volatile("mwaitx" ::"a"(kHints), "b"(0), "c"(kExtensions));
#else
_mm_mwaitx(kExtensions, kHints, /*cycles=*/0);
#endif
}
// 0xF would be C0. Its wakeup latency is less than 0.1 us shorter, and
// package power is sometimes actually higher than with Pause. The
// difference in spurious wakeups is minor.
static constexpr unsigned kHints = 0x0; // C1: a bit deeper than C0
// No timeout required, we assume the mwaitx does not miss stores, see
// https://www.usenix.org/system/files/usenixsecurity23-zhang-ruiyi.pdf.
static constexpr unsigned kExtensions = 0;
};
#if !HWY_ENABLE_SPIN_ASM
HWY_POP_ATTRIBUTES
#endif
#endif // HWY_ENABLE_MONITORX
#if HWY_ENABLE_UMONITOR || HWY_IDE
#if !HWY_ENABLE_SPIN_ASM
HWY_PUSH_ATTRIBUTES("waitpkg")
#endif
// Intel's user-mode monitor/wait (SPR+).
class SpinUMonitor {
public:
SpinType Type() const { return SpinType::kUMonitor; }
HWY_INLINE_SPIN SpinResult UntilDifferent(
const uint32_t prev, const std::atomic<uint32_t>& watched) const {
for (uint32_t reps = 0;; ++reps) {
uint32_t current = watched.load(std::memory_order_acquire);
if (current != prev) return SpinResult{current, reps};
Monitor(&watched);
// Double-checked 'lock' to avoid missed events:
current = watched.load(std::memory_order_acquire);
if (current != prev) return SpinResult{current, reps};
Wait();
}
}
HWY_INLINE_SPIN size_t UntilEqual(
const uint32_t expected, const std::atomic<uint32_t>& watched) const {
for (size_t reps = 0;; ++reps) {
uint32_t current = watched.load(std::memory_order_acquire);
if (current == expected) return reps;
Monitor(&watched);
// Double-checked 'lock' to avoid missed events:
current = watched.load(std::memory_order_acquire);
if (current == expected) return reps;
Wait();
}
}
private:
static HWY_INLINE void Monitor(const void* addr) {
#if HWY_ENABLE_SPIN_ASM
asm volatile("umonitor %%rcx" ::"c"(addr));
#else
_umonitor(const_cast<void*>(addr));
#endif
}
static HWY_INLINE void Wait() {
#if HWY_ENABLE_SPIN_ASM
asm volatile("umwait %%ecx" ::"c"(kControl), "d"(kDeadline >> 32),
"a"(kDeadline & 0xFFFFFFFFu));
#else
_umwait(kControl, kDeadline);
#endif
}
// 1 would be C0.1. C0.2 has 20x fewer spurious wakeups and additional 4%
// package power savings vs Pause on SPR. It comes at the cost of
// 0.4-0.6us higher wake latency, but the total is comparable to Zen4.
static constexpr unsigned kControl = 0; // C0.2 for deeper sleep
static constexpr uint64_t kDeadline = ~uint64_t{0}; // no timeout, see above
};
#if !HWY_ENABLE_SPIN_ASM
HWY_POP_ATTRIBUTES
#endif
#endif // HWY_ENABLE_UMONITOR
// TODO(janwas): add WFE on Arm. May wake at 10 kHz, but still worthwhile.
// Returns the best-available type whose bit in `disabled` is not set. Example:
// to disable kUMonitor, pass `1 << static_cast<int>(SpinType::kUMonitor)`.
// Ignores `disabled` for `kPause` if it is the only supported and enabled type.
// Somewhat expensive, typically called during initialization.
static inline SpinType DetectSpin(int disabled = 0) {
const auto HWY_MAYBE_UNUSED enabled = [disabled](SpinType type) {
return (disabled & (1 << static_cast<int>(type))) == 0;
};
#if HWY_ENABLE_MONITORX
if (enabled(SpinType::kMonitorX) && x86::IsAMD()) {
uint32_t abcd[4];
x86::Cpuid(0x80000001U, 0, abcd);
if (x86::IsBitSet(abcd[2], 29)) return SpinType::kMonitorX;
}
#endif // HWY_ENABLE_MONITORX
#if HWY_ENABLE_UMONITOR
if (enabled(SpinType::kUMonitor) && x86::MaxLevel() >= 7) {
uint32_t abcd[4];
x86::Cpuid(7, 0, abcd);
if (x86::IsBitSet(abcd[2], 5)) return SpinType::kUMonitor;
}
#endif // HWY_ENABLE_UMONITOR
if (!enabled(SpinType::kPause)) {
HWY_WARN("Ignoring attempt to disable Pause, it is the only option left.");
}
return SpinType::kPause;
}
// Calls `func(spin)` for the given `spin_type`.
template <class Func>
HWY_INLINE void CallWithSpin(SpinType spin_type, Func&& func) {
switch (spin_type) {
#if HWY_ENABLE_MONITORX
case SpinType::kMonitorX:
func(SpinMonitorX());
break;
#endif
#if HWY_ENABLE_UMONITOR
case SpinType::kUMonitor:
func(SpinUMonitor());
break;
#endif
case SpinType::kPause:
default:
func(SpinPause());
break;
}
}
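// Illustrative sketch (not part of the library): block until `watched` equals
// `expected`, using whichever spin method `DetectSpin()` selected. The names
// `ExampleSpinWaiter` and `ExampleWaitUntilEqual` are hypothetical and only
// demonstrate the DetectSpin/CallWithSpin API above.
struct ExampleSpinWaiter {
  template <class Spin>
  void operator()(Spin spin) const {
    // UntilEqual returns the number of retries; not needed here.
    (void)spin.UntilEqual(expected, *watched);
  }
  uint32_t expected;
  const std::atomic<uint32_t>* watched;
};
static inline void ExampleWaitUntilEqual(SpinType spin_type, uint32_t expected,
                                         const std::atomic<uint32_t>& watched) {
  CallWithSpin(spin_type, ExampleSpinWaiter{expected, &watched});
}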
} // namespace hwy
#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_

File diff suppressed because it is too large

View File

@@ -0,0 +1,141 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_TOPOLOGY_H_
#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_TOPOLOGY_H_
// OS-specific functions for processor topology and thread affinity.
#include <stddef.h>
#include <vector>
#include "third_party/highway/hwy/base.h"
#include "third_party/highway/hwy/bit_set.h"
namespace hwy {
// Returns false if std::thread should not be used.
HWY_CONTRIB_DLLEXPORT bool HaveThreadingSupport();
// Upper bound on logical processors, including hyperthreads.
static constexpr size_t kMaxLogicalProcessors = 1024; // matches glibc
// Set used by Get/SetThreadAffinity.
using LogicalProcessorSet = BitSet4096<kMaxLogicalProcessors>;
// Returns false, or sets `lps` to all logical processors which are online and
// available to the current thread.
HWY_CONTRIB_DLLEXPORT bool GetThreadAffinity(LogicalProcessorSet& lps);
// Ensures the current thread can only run on the logical processors in `lps`.
// Returns false if not supported (in particular on Apple), or if the
// intersection between `lps` and `GetThreadAffinity` is the empty set.
HWY_CONTRIB_DLLEXPORT bool SetThreadAffinity(const LogicalProcessorSet& lps);
// Returns false, or ensures the current thread will only run on `lp`, which
// must not exceed `TotalLogicalProcessors`. Note that this merely calls
// `SetThreadAffinity`, see the comment there.
static inline bool PinThreadToLogicalProcessor(size_t lp) {
LogicalProcessorSet lps;
lps.Set(lp);
return SetThreadAffinity(lps);
}
// Returns 1 if unknown, otherwise the total number of logical processors
// provided by the hardware clamped to `kMaxLogicalProcessors`.
// These processors are not necessarily all usable; you can determine which are
// via GetThreadAffinity().
HWY_CONTRIB_DLLEXPORT size_t TotalLogicalProcessors();
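// Illustrative sketch (not part of the library): restrict the calling thread
// to logical processors 0 and 1. `ExamplePinToFirstTwoLPs` is a hypothetical
// name; it only demonstrates the affinity API declared above.
static inline bool ExamplePinToFirstTwoLPs() {
  if (TotalLogicalProcessors() < 2) return false;
  LogicalProcessorSet lps;
  lps.Set(0);
  lps.Set(1);
  // False where affinity is unsupported (e.g. Apple) or neither LP is
  // available to this thread.
  return SetThreadAffinity(lps);
}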
struct Topology {
// Caller must check packages.empty(); if so, do not use any fields.
HWY_CONTRIB_DLLEXPORT Topology();
// Clique of cores with lower latency to each other. On Apple M1 these are
// four cores sharing an L2. On Zen4 these 'CCX' are up to eight cores sharing
// an L3 and a memory controller, or for Zen4c up to 16 and half the L3 size.
struct Cluster {
LogicalProcessorSet lps;
uint64_t private_kib = 0; // 0 if unknown
uint64_t shared_kib = 0; // 0 if unknown
uint64_t reserved1 = 0;
uint64_t reserved2 = 0;
uint64_t reserved3 = 0;
};
struct Core {
LogicalProcessorSet lps;
uint64_t reserved = 0;
};
struct Package {
std::vector<Cluster> clusters;
std::vector<Core> cores;
};
std::vector<Package> packages;
// Several hundred instances, so prefer a compact representation.
#pragma pack(push, 1)
struct LP {
uint16_t cluster = 0; // < packages[package].clusters.size()
uint16_t core = 0; // < packages[package].cores.size()
uint8_t package = 0; // < packages.size()
uint8_t smt = 0; // < packages[package].cores[core].lps.Count()
uint8_t node = 0;
uint8_t reserved = 0;
};
#pragma pack(pop)
std::vector<LP> lps; // size() == TotalLogicalProcessors().
};
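// Illustrative sketch (not part of the library): total physical core count
// from a detected topology, or 0 on failure. `ExampleCoreCount` is a
// hypothetical name; it only shows how the structs above nest.
static inline size_t ExampleCoreCount(const Topology& topology) {
  if (topology.packages.empty()) return 0;  // detection failed, see above
  size_t cores = 0;
  for (const Topology::Package& package : topology.packages) {
    cores += package.cores.size();
  }
  return cores;
}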
#pragma pack(push, 1)
// Cache parameters. Note the overlap with `HWY_ALIGNMENT`, which is intended
// but not guaranteed to be an upper bound for L1/L2 line sizes, and
// `Topology::Cluster::private_kib/shared_kib`, which are intended but not
// guaranteed to be the L2/L3 sizes. Getting the exact parameters, including the
// ways of associativity, can be useful for modeling cache conflicts.
//
// Uses packed fields so the array of `Cache` fits in a typical cache line.
struct Cache {
// Arbitrary upper bound for sanity checking.
static constexpr uint16_t kMaxAssociativity = 128;
// Zero if the level does not exist; *per-core* portion for shared caches.
uint32_t size_kib = 0;
// Also per-core portion, computed as number of lines / associativity.
uint32_t sets = 0;
uint16_t bytes_per_line = 0;
uint16_t associativity = 0; // number of ways
uint16_t cores_sharing = 0; // usually 1 for L1
uint16_t reserved = 0;
};
static_assert(sizeof(Cache) == 16, "Unexpected size");
#pragma pack(pop)
// Returns null if unknown, otherwise pointer to an array of `Cache` instances,
// where entry 0 is reserved, entry 1 describes the L1 data cache, entry 2
// describes the (possibly unified or shared) L2, and entry 3 describes the L3
// if its `size_kib != 0`.
//
// Initializes on-demand, which has some overhead for thread safety, hence
// callers should cache the result.
HWY_CONTRIB_DLLEXPORT const Cache* DataCaches();
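// Illustrative sketch (not part of the library): choose a working-set size of
// half the per-core L2 portion, falling back to 64 KiB when cache detection is
// unavailable. `ExampleBlockBytes` is a hypothetical name.
static inline size_t ExampleBlockBytes() {
  const Cache* caches = DataCaches();
  // Entry 2 is the L2 per the contract above; size_kib == 0 means unknown.
  if (caches == nullptr || caches[2].size_kib == 0) return 64 * 1024;
  return static_cast<size_t>(caches[2].size_kib) * 1024 / 2;
}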
} // namespace hwy
#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_TOPOLOGY_H_

View File

@@ -0,0 +1,473 @@
// Copyright 2023 Matthew Kolbe
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_) == \
defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
#undef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
#endif
#include <cstdlib> // std::abs
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <class DERIVED, typename IN_T, typename OUT_T>
struct UnrollerUnit {
static constexpr size_t kMaxTSize = HWY_MAX(sizeof(IN_T), sizeof(OUT_T));
using LargerT = SignedFromSize<kMaxTSize>; // only the size matters.
DERIVED* me() { return static_cast<DERIVED*>(this); }
static constexpr size_t MaxUnitLanes() {
return HWY_MAX_LANES_D(hn::ScalableTag<LargerT>);
}
static size_t ActualLanes() { return Lanes(hn::ScalableTag<LargerT>()); }
using LargerD = hn::CappedTag<LargerT, MaxUnitLanes()>;
using IT = hn::Rebind<IN_T, LargerD>;
using OT = hn::Rebind<OUT_T, LargerD>;
IT d_in;
OT d_out;
using Y_VEC = hn::Vec<OT>;
using X_VEC = hn::Vec<IT>;
Y_VEC Func(const ptrdiff_t idx, const X_VEC x, const Y_VEC y) {
return me()->Func(idx, x, y);
}
X_VEC X0Init() { return me()->X0InitImpl(); }
X_VEC X0InitImpl() { return hn::Zero(d_in); }
Y_VEC YInit() { return me()->YInitImpl(); }
Y_VEC YInitImpl() { return hn::Zero(d_out); }
X_VEC Load(const ptrdiff_t idx, const IN_T* from) {
return me()->LoadImpl(idx, from);
}
X_VEC LoadImpl(const ptrdiff_t idx, const IN_T* from) {
return hn::LoadU(d_in, from + idx);
}
// MaskLoad accepts either a positive or negative number for `places`. If it is
// positive, the first `places` lanes are loaded; if it is negative, the last
// |places| lanes are loaded. Example: places = 3
// | o | o | o | x | x | x | x | x |
// Example: places = -3
// | x | x | x | x | x | o | o | o |
X_VEC MaskLoad(const ptrdiff_t idx, const IN_T* from,
const ptrdiff_t places) {
return me()->MaskLoadImpl(idx, from, places);
}
X_VEC MaskLoadImpl(const ptrdiff_t idx, const IN_T* from,
const ptrdiff_t places) {
auto mask = hn::FirstN(d_in, static_cast<size_t>(places));
auto maskneg = hn::Not(hn::FirstN(
d_in,
static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
if (places < 0) mask = maskneg;
return hn::MaskedLoad(mask, d_in, from + idx);
}
bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
return me()->StoreAndShortCircuitImpl(idx, to, x);
}
bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
hn::StoreU(x, d_out, to + idx);
return true;
}
ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
ptrdiff_t const places) {
return me()->MaskStoreImpl(idx, to, x, places);
}
ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
const ptrdiff_t places) {
auto mask = hn::FirstN(d_out, static_cast<size_t>(places));
auto maskneg = hn::Not(hn::FirstN(
d_out,
static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
if (places < 0) mask = maskneg;
hn::BlendedStore(x, mask, d_out, to + idx);
return std::abs(places);
}
ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); }
ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) {
// default does nothing
(void)x;
(void)to;
return 0;
}
void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
me()->ReduceImpl(x0, x1, x2, y);
}
void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
// default does nothing
(void)x0;
(void)x1;
(void)x2;
(void)y;
}
};
template <class DERIVED, typename IN0_T, typename IN1_T, typename OUT_T>
struct UnrollerUnit2D {
DERIVED* me() { return static_cast<DERIVED*>(this); }
static constexpr size_t kMaxTSize =
HWY_MAX(sizeof(IN0_T), HWY_MAX(sizeof(IN1_T), sizeof(OUT_T)));
using LargerT = SignedFromSize<kMaxTSize>; // only the size matters.
static constexpr size_t MaxUnitLanes() {
return HWY_MAX_LANES_D(hn::ScalableTag<LargerT>);
}
static size_t ActualLanes() { return Lanes(hn::ScalableTag<LargerT>()); }
using LargerD = hn::CappedTag<LargerT, MaxUnitLanes()>;
using I0T = hn::Rebind<IN0_T, LargerD>;
using I1T = hn::Rebind<IN1_T, LargerD>;
using OT = hn::Rebind<OUT_T, LargerD>;
I0T d_in0;
I1T d_in1;
OT d_out;
using Y_VEC = hn::Vec<OT>;
using X0_VEC = hn::Vec<I0T>;
using X1_VEC = hn::Vec<I1T>;
hn::Vec<OT> Func(const ptrdiff_t idx, const hn::Vec<I0T> x0,
const hn::Vec<I1T> x1, const Y_VEC y) {
return me()->Func(idx, x0, x1, y);
}
X0_VEC X0Init() { return me()->X0InitImpl(); }
X0_VEC X0InitImpl() { return hn::Zero(d_in0); }
X1_VEC X1Init() { return me()->X1InitImpl(); }
X1_VEC X1InitImpl() { return hn::Zero(d_in1); }
Y_VEC YInit() { return me()->YInitImpl(); }
Y_VEC YInitImpl() { return hn::Zero(d_out); }
X0_VEC Load0(const ptrdiff_t idx, const IN0_T* from) {
return me()->Load0Impl(idx, from);
}
X0_VEC Load0Impl(const ptrdiff_t idx, const IN0_T* from) {
return hn::LoadU(d_in0, from + idx);
}
X1_VEC Load1(const ptrdiff_t idx, const IN1_T* from) {
return me()->Load1Impl(idx, from);
}
X1_VEC Load1Impl(const ptrdiff_t idx, const IN1_T* from) {
return hn::LoadU(d_in1, from + idx);
}
// MaskLoad0/MaskLoad1 accept either a positive or negative number for
// `places`. If it is positive, the first `places` lanes are loaded; if it is
// negative, the last |places| lanes are loaded. Example: places = 3
// | o | o | o | x | x | x | x | x |
// Example: places = -3
// | x | x | x | x | x | o | o | o |
X0_VEC MaskLoad0(const ptrdiff_t idx, const IN0_T* from,
const ptrdiff_t places) {
return me()->MaskLoad0Impl(idx, from, places);
}
X0_VEC MaskLoad0Impl(const ptrdiff_t idx, const IN0_T* from,
const ptrdiff_t places) {
auto mask = hn::FirstN(d_in0, static_cast<size_t>(places));
auto maskneg = hn::Not(hn::FirstN(
d_in0,
static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
if (places < 0) mask = maskneg;
return hn::MaskedLoad(mask, d_in0, from + idx);
}
hn::Vec<I1T> MaskLoad1(const ptrdiff_t idx, const IN1_T* from,
const ptrdiff_t places) {
return me()->MaskLoad1Impl(idx, from, places);
}
hn::Vec<I1T> MaskLoad1Impl(const ptrdiff_t idx, const IN1_T* from,
const ptrdiff_t places) {
auto mask = hn::FirstN(d_in1, static_cast<size_t>(places));
auto maskneg = hn::Not(hn::FirstN(
d_in1,
static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
if (places < 0) mask = maskneg;
return hn::MaskedLoad(mask, d_in1, from + idx);
}
// StoreAndShortCircuit returns a bool that is `false` when the Unroller loop
// should stop early (short-circuit); the default implementation below always
// stores and returns `true`.
bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
return me()->StoreAndShortCircuitImpl(idx, to, x);
}
bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
hn::StoreU(x, d_out, to + idx);
return true;
}
ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
const ptrdiff_t places) {
return me()->MaskStoreImpl(idx, to, x, places);
}
ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
const ptrdiff_t places) {
auto mask = hn::FirstN(d_out, static_cast<size_t>(places));
auto maskneg = hn::Not(hn::FirstN(
d_out,
static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
if (places < 0) mask = maskneg;
hn::BlendedStore(x, mask, d_out, to + idx);
return std::abs(places);
}
ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); }
ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) {
// default does nothing
(void)x;
(void)to;
return 0;
}
void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
me()->ReduceImpl(x0, x1, x2, y);
}
void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
// default does nothing
(void)x0;
(void)x1;
(void)x2;
(void)y;
}
};
template <class FUNC, typename IN_T, typename OUT_T>
inline void Unroller(FUNC& f, const IN_T* HWY_RESTRICT x, OUT_T* HWY_RESTRICT y,
const ptrdiff_t n) {
auto xx = f.X0Init();
auto yy = f.YInit();
ptrdiff_t i = 0;
#if HWY_MEM_OPS_MIGHT_FAULT
constexpr auto lane_sz =
static_cast<ptrdiff_t>(RemoveRef<FUNC>::MaxUnitLanes());
if (n < lane_sz) {
const DFromV<decltype(yy)> d;
// this may not fit on the stack for HWY_RVV, but we do not reach this code
// there
HWY_ALIGN IN_T xtmp[static_cast<size_t>(lane_sz)];
HWY_ALIGN OUT_T ytmp[static_cast<size_t>(lane_sz)];
CopyBytes(x, xtmp, static_cast<size_t>(n) * sizeof(IN_T));
xx = f.MaskLoad(0, xtmp, n);
yy = f.Func(0, xx, yy);
Store(Zero(d), d, ytmp);
i += f.MaskStore(0, ytmp, yy, n);
i += f.Reduce(yy, ytmp);
CopyBytes(ytmp, y, static_cast<size_t>(i) * sizeof(OUT_T));
return;
}
#endif
const ptrdiff_t actual_lanes =
static_cast<ptrdiff_t>(RemoveRef<FUNC>::ActualLanes());
if (n > 4 * actual_lanes) {
auto xx1 = f.X0Init();
auto yy1 = f.YInit();
auto xx2 = f.X0Init();
auto yy2 = f.YInit();
auto xx3 = f.X0Init();
auto yy3 = f.YInit();
while (i + 4 * actual_lanes - 1 < n) {
xx = f.Load(i, x);
i += actual_lanes;
xx1 = f.Load(i, x);
i += actual_lanes;
xx2 = f.Load(i, x);
i += actual_lanes;
xx3 = f.Load(i, x);
i -= 3 * actual_lanes;
yy = f.Func(i, xx, yy);
yy1 = f.Func(i + actual_lanes, xx1, yy1);
yy2 = f.Func(i + 2 * actual_lanes, xx2, yy2);
yy3 = f.Func(i + 3 * actual_lanes, xx3, yy3);
if (!f.StoreAndShortCircuit(i, y, yy)) return;
i += actual_lanes;
if (!f.StoreAndShortCircuit(i, y, yy1)) return;
i += actual_lanes;
if (!f.StoreAndShortCircuit(i, y, yy2)) return;
i += actual_lanes;
if (!f.StoreAndShortCircuit(i, y, yy3)) return;
i += actual_lanes;
}
f.Reduce(yy3, yy2, yy1, &yy);
}
while (i + actual_lanes - 1 < n) {
xx = f.Load(i, x);
yy = f.Func(i, xx, yy);
if (!f.StoreAndShortCircuit(i, y, yy)) return;
i += actual_lanes;
}
if (i != n) {
xx = f.MaskLoad(n - actual_lanes, x, i - n);
yy = f.Func(n - actual_lanes, xx, yy);
f.MaskStore(n - actual_lanes, y, yy, i - n);
}
f.Reduce(yy, y);
}
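// Illustrative sketch (not part of the library): a minimal unit that squares
// each element. Only Func() must be provided; the inherited *Impl defaults
// handle the (masked) loads and stores. `ExampleSquareUnit` is a hypothetical
// name, and a caller would invoke `Unroller(unit, in, out, n)` with it.
struct ExampleSquareUnit
    : public UnrollerUnit<ExampleSquareUnit, float, float> {
  Y_VEC Func(const ptrdiff_t /*idx*/, const X_VEC x, const Y_VEC /*y*/) {
    return hn::Mul(x, x);  // out[i] = in[i] * in[i]
  }
};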
template <class FUNC, typename IN0_T, typename IN1_T, typename OUT_T>
inline void Unroller(FUNC& HWY_RESTRICT f, IN0_T* HWY_RESTRICT x0,
IN1_T* HWY_RESTRICT x1, OUT_T* HWY_RESTRICT y,
const ptrdiff_t n) {
const ptrdiff_t lane_sz =
static_cast<ptrdiff_t>(RemoveRef<FUNC>::ActualLanes());
auto xx00 = f.X0Init();
auto xx10 = f.X1Init();
auto yy = f.YInit();
ptrdiff_t i = 0;
#if HWY_MEM_OPS_MIGHT_FAULT
if (n < lane_sz) {
const DFromV<decltype(yy)> d;
// this may not fit on the stack for HWY_RVV, but we do not reach this code
// there
constexpr auto max_lane_sz =
static_cast<ptrdiff_t>(RemoveRef<FUNC>::MaxUnitLanes());
HWY_ALIGN IN0_T xtmp0[static_cast<size_t>(max_lane_sz)];
HWY_ALIGN IN1_T xtmp1[static_cast<size_t>(max_lane_sz)];
HWY_ALIGN OUT_T ytmp[static_cast<size_t>(max_lane_sz)];
CopyBytes(x0, xtmp0, static_cast<size_t>(n) * sizeof(IN0_T));
CopyBytes(x1, xtmp1, static_cast<size_t>(n) * sizeof(IN1_T));
xx00 = f.MaskLoad0(0, xtmp0, n);
xx10 = f.MaskLoad1(0, xtmp1, n);
yy = f.Func(0, xx00, xx10, yy);
Store(Zero(d), d, ytmp);
i += f.MaskStore(0, ytmp, yy, n);
i += f.Reduce(yy, ytmp);
CopyBytes(ytmp, y, static_cast<size_t>(i) * sizeof(OUT_T));
return;
}
#endif
if (n > 4 * lane_sz) {
auto xx01 = f.X0Init();
auto xx11 = f.X1Init();
auto yy1 = f.YInit();
auto xx02 = f.X0Init();
auto xx12 = f.X1Init();
auto yy2 = f.YInit();
auto xx03 = f.X0Init();
auto xx13 = f.X1Init();
auto yy3 = f.YInit();
while (i + 4 * lane_sz - 1 < n) {
xx00 = f.Load0(i, x0);
xx10 = f.Load1(i, x1);
i += lane_sz;
xx01 = f.Load0(i, x0);
xx11 = f.Load1(i, x1);
i += lane_sz;
xx02 = f.Load0(i, x0);
xx12 = f.Load1(i, x1);
i += lane_sz;
xx03 = f.Load0(i, x0);
xx13 = f.Load1(i, x1);
i -= 3 * lane_sz;
yy = f.Func(i, xx00, xx10, yy);
yy1 = f.Func(i + lane_sz, xx01, xx11, yy1);
yy2 = f.Func(i + 2 * lane_sz, xx02, xx12, yy2);
yy3 = f.Func(i + 3 * lane_sz, xx03, xx13, yy3);
if (!f.StoreAndShortCircuit(i, y, yy)) return;
i += lane_sz;
if (!f.StoreAndShortCircuit(i, y, yy1)) return;
i += lane_sz;
if (!f.StoreAndShortCircuit(i, y, yy2)) return;
i += lane_sz;
if (!f.StoreAndShortCircuit(i, y, yy3)) return;
i += lane_sz;
}
f.Reduce(yy3, yy2, yy1, &yy);
}
while (i + lane_sz - 1 < n) {
xx00 = f.Load0(i, x0);
xx10 = f.Load1(i, x1);
yy = f.Func(i, xx00, xx10, yy);
if (!f.StoreAndShortCircuit(i, y, yy)) return;
i += lane_sz;
}
if (i != n) {
xx00 = f.MaskLoad0(n - lane_sz, x0, i - n);
xx10 = f.MaskLoad1(n - lane_sz, x1, i - n);
yy = f.Func(n - lane_sz, xx00, xx10, yy);
f.MaskStore(n - lane_sz, y, yy, i - n);
}
f.Reduce(yy, y);
}
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
#endif // HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_

View File

@@ -0,0 +1,395 @@
// Copyright 2020 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_
#define HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_
// Detects compiler and arch from predefined macros. Zero dependencies for
// inclusion by foreach_target.h.
// Add to #if conditions to prevent IDE from graying out code.
#if (defined __CDT_PARSER__) || (defined __INTELLISENSE__) || \
(defined Q_CREATOR_RUN) || (defined __CLANGD__) || \
(defined GROK_ELLIPSIS_BUILD)
#define HWY_IDE 1
#else
#define HWY_IDE 0
#endif
//------------------------------------------------------------------------------
// Compiler
// Actual MSVC, not clang-cl, which defines _MSC_VER but doesn't behave like
// MSVC in other aspects (e.g. HWY_DIAGNOSTICS).
#if defined(_MSC_VER) && !defined(__clang__)
#define HWY_COMPILER_MSVC _MSC_VER
#else
#define HWY_COMPILER_MSVC 0
#endif
#if defined(_MSC_VER) && defined(__clang__)
#define HWY_COMPILER_CLANGCL _MSC_VER
#else
#define HWY_COMPILER_CLANGCL 0
#endif
#ifdef __INTEL_COMPILER
#define HWY_COMPILER_ICC __INTEL_COMPILER
#else
#define HWY_COMPILER_ICC 0
#endif
#ifdef __INTEL_LLVM_COMPILER
#define HWY_COMPILER_ICX __INTEL_LLVM_COMPILER
#else
#define HWY_COMPILER_ICX 0
#endif
// HWY_COMPILER_GCC is a generic macro for all compilers implementing the GNU
// compiler extensions (e.g. Clang, Intel...)
#ifdef __GNUC__
#define HWY_COMPILER_GCC (__GNUC__ * 100 + __GNUC_MINOR__)
#else
#define HWY_COMPILER_GCC 0
#endif
// Clang or clang-cl, not GCC.
#ifdef __clang__
// In case of Apple LLVM (whose version number is unrelated to that of LLVM) or
// an invalid version number, deduce it from the presence of warnings.
// Originally based on
// https://github.com/simd-everywhere/simde/blob/47d6e603de9d04ee05cdfbc57cf282a02be1bf2a/simde/simde-detect-clang.h#L59.
// Please send updates below to them as well, thanks!
#if defined(__apple_build_version__) || __clang_major__ >= 999
#if __has_warning("-Woverriding-option")
#define HWY_COMPILER_CLANG 1801
// No new warnings in 17.0, and Apple LLVM 15.3, which should be 1600, already
// has the unsafe_buffer_usage attribute, so we instead check for new builtins.
#elif __has_builtin(__builtin_nondeterministic_value)
#define HWY_COMPILER_CLANG 1700
#elif __has_attribute(nouwtable) // no new warnings in 16.0
#define HWY_COMPILER_CLANG 1600
#elif __has_warning("-Warray-parameter")
#define HWY_COMPILER_CLANG 1500
#elif __has_warning("-Wbitwise-instead-of-logical")
#define HWY_COMPILER_CLANG 1400
#elif __has_warning("-Wreserved-identifier")
#define HWY_COMPILER_CLANG 1300
#elif __has_warning("-Wformat-insufficient-args")
#define HWY_COMPILER_CLANG 1200
#elif __has_warning("-Wimplicit-const-int-float-conversion")
#define HWY_COMPILER_CLANG 1100
#elif __has_warning("-Wmisleading-indentation")
#define HWY_COMPILER_CLANG 1000
#elif defined(__FILE_NAME__)
#define HWY_COMPILER_CLANG 900
#elif __has_warning("-Wextra-semi-stmt") || \
__has_builtin(__builtin_rotateleft32)
#define HWY_COMPILER_CLANG 800
// For reasons unknown, XCode 10.3 (Apple LLVM version 10.0.1) is apparently
// based on Clang 7, but does not support the warning we test.
// See https://en.wikipedia.org/wiki/Xcode#Toolchain_versions and
// https://trac.macports.org/wiki/XcodeVersionInfo.
#elif __has_warning("-Wc++98-compat-extra-semi") || \
(defined(__apple_build_version__) && __apple_build_version__ >= 10010000)
#define HWY_COMPILER_CLANG 700
#else // Anything older than 7.0 is not recommended for Highway.
#define HWY_COMPILER_CLANG 600
#endif // __has_warning chain
#define HWY_COMPILER3_CLANG (HWY_COMPILER_CLANG * 100)
#else // use normal version
#define HWY_COMPILER_CLANG (__clang_major__ * 100 + __clang_minor__)
#define HWY_COMPILER3_CLANG \
(__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__)
#endif
#else // Not clang
#define HWY_COMPILER_CLANG 0
#define HWY_COMPILER3_CLANG 0
#endif
#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG && !HWY_COMPILER_ICC && \
!HWY_COMPILER_ICX
#define HWY_COMPILER_GCC_ACTUAL HWY_COMPILER_GCC
#else
#define HWY_COMPILER_GCC_ACTUAL 0
#endif
// More than one may be nonzero, but we want at least one.
#if 0 == (HWY_COMPILER_MSVC + HWY_COMPILER_CLANGCL + HWY_COMPILER_ICC + \
HWY_COMPILER_ICX + HWY_COMPILER_GCC + HWY_COMPILER_CLANG)
#error "Unsupported compiler"
#endif
// We should only detect one of these (only clang/clangcl/icx overlap)
#if 1 < (!!HWY_COMPILER_MSVC + (!!HWY_COMPILER_ICC & !HWY_COMPILER_ICX) + \
!!HWY_COMPILER_GCC_ACTUAL + \
!!(HWY_COMPILER_ICX | HWY_COMPILER_CLANGCL | HWY_COMPILER_CLANG))
#error "Detected multiple compilers"
#endif
//------------------------------------------------------------------------------
// Compiler features and C++ version
#ifdef __has_builtin
#define HWY_HAS_BUILTIN(name) __has_builtin(name)
#else
#define HWY_HAS_BUILTIN(name) 0
#endif
#ifdef __has_attribute
#define HWY_HAS_ATTRIBUTE(name) __has_attribute(name)
#else
#define HWY_HAS_ATTRIBUTE(name) 0
#endif
#ifdef __has_cpp_attribute
#define HWY_HAS_CPP_ATTRIBUTE(name) __has_cpp_attribute(name)
#else
#define HWY_HAS_CPP_ATTRIBUTE(name) 0
#endif
#ifdef __has_feature
#define HWY_HAS_FEATURE(name) __has_feature(name)
#else
#define HWY_HAS_FEATURE(name) 0
#endif
// NOTE: clang ~17 does not correctly handle wrapping __has_include in a macro.
#if HWY_COMPILER_MSVC && defined(_MSVC_LANG) && _MSVC_LANG > __cplusplus
#define HWY_CXX_LANG _MSVC_LANG
#else
#define HWY_CXX_LANG __cplusplus
#endif
#if defined(__cpp_constexpr) && __cpp_constexpr >= 201603L
#define HWY_CXX17_CONSTEXPR constexpr
#else
#define HWY_CXX17_CONSTEXPR
#endif
#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304L
#define HWY_CXX14_CONSTEXPR constexpr
#else
#define HWY_CXX14_CONSTEXPR
#endif
#if HWY_CXX_LANG >= 201703L
#define HWY_IF_CONSTEXPR if constexpr
#else
#define HWY_IF_CONSTEXPR if
#endif
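// Illustrative usage (not part of the library): the detection macros above
// degrade gracefully on compilers lacking the extension, e.g.
//   #if HWY_HAS_BUILTIN(__builtin_expect)
//     // ... use __builtin_expect ...
//   #else
//     // ... portable fallback ...
//   #endif
// Note that HWY_IF_CONSTEXPR expands to a plain `if` before C++17, so both
// branches must still compile for every instantiation.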
//------------------------------------------------------------------------------
// Architecture
#if defined(__i386__) || defined(_M_IX86)
#define HWY_ARCH_X86_32 1
#else
#define HWY_ARCH_X86_32 0
#endif
#if defined(__x86_64__) || defined(_M_X64)
#define HWY_ARCH_X86_64 1
#else
#define HWY_ARCH_X86_64 0
#endif
#if HWY_ARCH_X86_32 && HWY_ARCH_X86_64
#error "Cannot have both x86-32 and x86-64"
#endif
#if HWY_ARCH_X86_32 || HWY_ARCH_X86_64
#define HWY_ARCH_X86 1
#else
#define HWY_ARCH_X86 0
#endif
#if defined(__powerpc64__) || defined(_M_PPC) || defined(__powerpc__)
#define HWY_ARCH_PPC 1
#else
#define HWY_ARCH_PPC 0
#endif
#if defined(__powerpc64__) || (HWY_ARCH_PPC && defined(__64BIT__))
#define HWY_ARCH_PPC_64 1
#else
#define HWY_ARCH_PPC_64 0
#endif
// aarch32 is currently not supported; please raise an issue if you want it.
#if defined(__ARM_ARCH_ISA_A64) || defined(__aarch64__) || defined(_M_ARM64)
#define HWY_ARCH_ARM_A64 1
#else
#define HWY_ARCH_ARM_A64 0
#endif
#if (defined(__ARM_ARCH) && __ARM_ARCH == 7) || (defined(_M_ARM) && _M_ARM == 7)
#define HWY_ARCH_ARM_V7 1
#else
#define HWY_ARCH_ARM_V7 0
#endif
#if HWY_ARCH_ARM_A64 && HWY_ARCH_ARM_V7
#error "Cannot have both A64 and V7"
#endif
// Any *supported* version of Arm, i.e. 7 or later
#if HWY_ARCH_ARM_A64 || HWY_ARCH_ARM_V7
#define HWY_ARCH_ARM 1
#else
#define HWY_ARCH_ARM 0
#endif
// Older than Armv7 (e.g. armel aka Armv5) => we do not support SIMD.
#if (defined(__arm__) || defined(_M_ARM)) && !HWY_ARCH_ARM
#define HWY_ARCH_ARM_OLD 1
#else
#define HWY_ARCH_ARM_OLD 0
#endif
#if defined(__EMSCRIPTEN__) || defined(__wasm__) || defined(__WASM__)
#define HWY_ARCH_WASM 1
#else
#define HWY_ARCH_WASM 0
#endif
#ifdef __riscv
#define HWY_ARCH_RISCV 1
#else
#define HWY_ARCH_RISCV 0
#endif
// DEPRECATED names; please use HWY_ARCH_RISCV instead.
#define HWY_ARCH_RVV HWY_ARCH_RISCV
#if HWY_ARCH_RISCV && defined(__riscv_xlen)
#if __riscv_xlen == 32
#define HWY_ARCH_RISCV_32 1
#else
#define HWY_ARCH_RISCV_32 0
#endif
#if __riscv_xlen == 64
#define HWY_ARCH_RISCV_64 1
#else
#define HWY_ARCH_RISCV_64 0
#endif
#else // !HWY_ARCH_RISCV || !defined(__riscv_xlen)
#define HWY_ARCH_RISCV_32 0
#define HWY_ARCH_RISCV_64 0
#endif // HWY_ARCH_RISCV && defined(__riscv_xlen)
#if HWY_ARCH_RISCV_32 && HWY_ARCH_RISCV_64
#error "Cannot have both RISCV_32 and RISCV_64"
#endif
#if defined(__s390x__)
#define HWY_ARCH_S390X 1
#else
#define HWY_ARCH_S390X 0
#endif
#if defined(__loongarch64__) || defined(__loongarch64) || \
(defined(__loongarch_grlen) && __loongarch_grlen == 64)
#define HWY_ARCH_LOONGARCH_64 1
#else
#define HWY_ARCH_LOONGARCH_64 0
#endif
#if defined(__loongarch__) && !HWY_ARCH_LOONGARCH_64
#define HWY_ARCH_LOONGARCH_32 1
#else
#define HWY_ARCH_LOONGARCH_32 0
#endif
#if HWY_ARCH_LOONGARCH_64 || HWY_ARCH_LOONGARCH_32
#define HWY_ARCH_LOONGARCH 1
#else
#define HWY_ARCH_LOONGARCH 0
#endif
// It is an error to detect multiple architectures at the same time, but OK to
// detect none of the above.
#if (HWY_ARCH_X86 + HWY_ARCH_PPC + HWY_ARCH_ARM + HWY_ARCH_ARM_OLD + \
HWY_ARCH_WASM + HWY_ARCH_RISCV + HWY_ARCH_S390X + HWY_ARCH_LOONGARCH) > 1
#error "Must not detect more than one architecture"
#endif
//------------------------------------------------------------------------------
// Operating system
#if defined(_WIN32) || defined(_WIN64)
#define HWY_OS_WIN 1
#else
#define HWY_OS_WIN 0
#endif
#if defined(linux) || defined(__linux__)
#define HWY_OS_LINUX 1
#else
#define HWY_OS_LINUX 0
#endif
// iOS or Mac
#if defined(__APPLE__)
#define HWY_OS_APPLE 1
#else
#define HWY_OS_APPLE 0
#endif
#if defined(__FreeBSD__)
#define HWY_OS_FREEBSD 1
#else
#define HWY_OS_FREEBSD 0
#endif
// It is an error to detect multiple OSes at the same time, but OK to
// detect none of the above.
#if (HWY_OS_WIN + HWY_OS_LINUX + HWY_OS_APPLE + HWY_OS_FREEBSD) > 1
#error "Must not detect more than one OS"
#endif
//------------------------------------------------------------------------------
// Endianness
#if HWY_COMPILER_MSVC
#if HWY_ARCH_PPC && defined(_XBOX_VER) && _XBOX_VER >= 200
// XBox 360 is big-endian
#define HWY_IS_LITTLE_ENDIAN 0
#define HWY_IS_BIG_ENDIAN 1
#else
// All other targets supported by MSVC are little-endian
#define HWY_IS_LITTLE_ENDIAN 1
#define HWY_IS_BIG_ENDIAN 0
#endif // HWY_ARCH_PPC && defined(_XBOX_VER) && _XBOX_VER >= 200
#elif defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && \
__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
#define HWY_IS_LITTLE_ENDIAN 1
#define HWY_IS_BIG_ENDIAN 0
#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define HWY_IS_LITTLE_ENDIAN 0
#define HWY_IS_BIG_ENDIAN 1
#else
#error "Unable to detect endianness or unsupported byte order"
#endif
#if (HWY_IS_LITTLE_ENDIAN + HWY_IS_BIG_ENDIAN) != 1
#error "Must only detect one byte order"
#endif
#endif // HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_

View File

@@ -0,0 +1,930 @@
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_DETECT_TARGETS_H_
#define HIGHWAY_HWY_DETECT_TARGETS_H_
// Defines targets and chooses which to enable.
#include "third_party/highway/hwy/detect_compiler_arch.h"
//------------------------------------------------------------------------------
// Optional configuration
// See g3doc/quick_reference.md for documentation of these macros.
// Uncomment to override the default baseline determined from predefined macros:
// #define HWY_BASELINE_TARGETS (HWY_SSE4 | HWY_SCALAR)
// Uncomment to override the default blocklist:
// #define HWY_BROKEN_TARGETS HWY_AVX3
// Uncomment to definitely avoid generating those target(s):
// #define HWY_DISABLED_TARGETS HWY_SSE4
// Uncomment to avoid emitting BMI/BMI2/FMA instructions (allows generating
// AVX2 target for VMs which support AVX2 but not the other instruction sets)
// #define HWY_DISABLE_BMI2_FMA
// Uncomment to enable these on MSVC even if the predefined macros are not set.
// #define HWY_WANT_SSE2 1
// #define HWY_WANT_SSSE3 1
// #define HWY_WANT_SSE4 1
//------------------------------------------------------------------------------
// Targets
// Unique bit value for each target. A lower value is "better" (e.g. more lanes)
// than a higher value within the same group/platform - see HWY_STATIC_TARGET.
//
// All values are unconditionally defined so we can test HWY_TARGETS without
// first checking the HWY_ARCH_*.
//
// The C99 preprocessor evaluates #if expressions using intmax_t types. This
// holds at least 64 bits in practice (verified 2022-07-18 via Godbolt on
// 32-bit clang/GCC/MSVC compilers for x86/Arm7/AArch32/RISC-V/WASM). We now
// avoid overflow when computing HWY_TARGETS (subtracting one instead of
// left-shifting 2^62), but still do not use bit 63 because it is the sign bit.
// --------------------------- x86: 15 targets (+ one fallback)
// Bits 0..2 reserved (3 targets)
#define HWY_AVX10_2_512 (1LL << 3) // AVX10.2 with 512-bit vectors
#define HWY_AVX3_SPR (1LL << 4)
#define HWY_AVX10_2 (1LL << 5) // AVX10.2 with 256-bit vectors
// Currently `HWY_AVX3_DL` plus `AVX512BF16` and a special case for
// `CompressStore` (10x as fast, still useful on Zen5). We may later also use
// `VPCONFLICT`. Note that `VP2INTERSECT` is available in Zen5.
#define HWY_AVX3_ZEN4 (1LL << 6) // see HWY_WANT_AVX3_ZEN4 below
// Currently satisfiable by Ice Lake (`VNNI`, `VPCLMULQDQ`, `VPOPCNTDQ`,
// `VBMI`, `VBMI2`, `VAES`, `BITALG`, `GFNI`).
#define HWY_AVX3_DL (1LL << 7)
#define HWY_AVX3 (1LL << 8) // HWY_AVX2 plus AVX-512F/BW/CD/DQ/VL
#define HWY_AVX2 (1LL << 9) // HWY_SSE4 plus BMI2 + F16 + FMA
// Bit 10: reserved
#define HWY_SSE4 (1LL << 11) // SSE4.2 plus AES + CLMUL
#define HWY_SSSE3 (1LL << 12) // S-SSE3
// Bit 13: reserved for SSE3
#define HWY_SSE2 (1LL << 14)
// The highest bit in the HWY_TARGETS mask that a x86 target can have. Used for
// dynamic dispatch. All x86 target bits must be lower or equal to
// (1 << HWY_HIGHEST_TARGET_BIT_X86) and they can only use
// HWY_MAX_DYNAMIC_TARGETS in total.
#define HWY_HIGHEST_TARGET_BIT_X86 14
// --------------------------- Arm: 15 targets (+ one fallback)
// Bits 15..17 reserved (3 targets)
#define HWY_SVE2_128 (1LL << 18) // specialized (e.g. Neoverse V2/N2/N3)
#define HWY_SVE_256 (1LL << 19) // specialized (Neoverse V1)
// Bits 20-22 reserved for later SVE (3 targets)
#define HWY_SVE2 (1LL << 23)
#define HWY_SVE (1LL << 24)
// Bit 25 reserved for NEON
#define HWY_NEON_BF16 (1LL << 26) // fp16/dot/bf16 (e.g. Neoverse V2/N2/N3)
// Bit 27 reserved for NEON
#define HWY_NEON (1LL << 28) // Implies support for AES
#define HWY_NEON_WITHOUT_AES (1LL << 29)
#define HWY_HIGHEST_TARGET_BIT_ARM 29
#define HWY_ALL_NEON (HWY_NEON_WITHOUT_AES | HWY_NEON | HWY_NEON_BF16)
#define HWY_ALL_SVE (HWY_SVE | HWY_SVE2 | HWY_SVE_256 | HWY_SVE2_128)
// --------------------------- RISC-V: 9 targets (+ one fallback)
// Bits 30..36 reserved (7 targets)
#define HWY_RVV (1LL << 37)
// Bit 38 reserved
#define HWY_HIGHEST_TARGET_BIT_RVV 38
// --------------------------- LoongArch: 3 targets (+ one fallback)
// Bit 39 reserved (1 target)
#define HWY_LASX (1LL << 40)
#define HWY_LSX (1LL << 41)
#define HWY_HIGHEST_TARGET_BIT_LOONGARCH 41
// --------------------------- Future expansion: 1 target
// Bit 42 reserved
// --------------------------- IBM Power/ZSeries: 9 targets (+ one fallback)
// Bits 43..46 reserved (4 targets)
#define HWY_PPC10 (1LL << 47) // v3.1
#define HWY_PPC9 (1LL << 48) // v3.0
#define HWY_PPC8 (1LL << 49) // v2.07
#define HWY_Z15 (1LL << 50) // Z15
#define HWY_Z14 (1LL << 51) // Z14
#define HWY_HIGHEST_TARGET_BIT_PPC 51
#define HWY_ALL_PPC (HWY_PPC8 | HWY_PPC9 | HWY_PPC10)
// --------------------------- WebAssembly: 9 targets (+ one fallback)
// Bits 52..57 reserved (6 targets)
#define HWY_WASM_EMU256 (1LL << 58) // Experimental
#define HWY_WASM (1LL << 59)
// Bit 60 reserved
#define HWY_HIGHEST_TARGET_BIT_WASM 60
// --------------------------- Emulation: 2 targets
#define HWY_EMU128 (1LL << 61)
// We do not add/left-shift, so this will not overflow to a negative number.
#define HWY_SCALAR (1LL << 62)
#define HWY_HIGHEST_TARGET_BIT_SCALAR 62
// Do not use bit 63 - would be confusing to have negative numbers.
//------------------------------------------------------------------------------
// Set default blocklists
// Disabled means excluded from enabled at user's request. A separate config
// macro allows disabling without deactivating the blocklist below.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS 0
#endif
// Broken means excluded from enabled due to known compiler issues. We define
// separate HWY_BROKEN_* and then OR them together (more than one might apply).
#ifndef HWY_BROKEN_CLANG6 // allow override
// x86 clang-6: we saw multiple AVX2/3 compile errors and in one case invalid
// SSE4 codegen (possibly only for msan), so disable all those targets.
#if HWY_ARCH_X86 && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700)
#define HWY_BROKEN_CLANG6 (HWY_SSE4 | (HWY_SSE4 - 1))
// This entails a major speed reduction, so warn unless the user explicitly
// opts in to scalar-only.
#if !defined(HWY_COMPILE_ONLY_SCALAR)
#pragma message("x86 Clang <= 6: define HWY_COMPILE_ONLY_SCALAR or upgrade.")
#endif
#else
#define HWY_BROKEN_CLANG6 0
#endif
#endif // HWY_BROKEN_CLANG6
#ifndef HWY_BROKEN_32BIT // allow override
// 32-bit may fail to compile AVX2/3.
#if HWY_ARCH_X86_32
// GCC-13 is ok with AVX2:
#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1300)
#define HWY_BROKEN_32BIT (HWY_AVX3 | (HWY_AVX3 - 1))
#else
#define HWY_BROKEN_32BIT (HWY_AVX2 | (HWY_AVX2 - 1))
#endif
#else
#define HWY_BROKEN_32BIT 0
#endif
#endif // HWY_BROKEN_32BIT
#ifndef HWY_BROKEN_MSVC // allow override
// MSVC AVX3 support is buggy: https://github.com/Mysticial/Flops/issues/16
#if HWY_COMPILER_MSVC != 0
#define HWY_BROKEN_MSVC (HWY_AVX3 | (HWY_AVX3 - 1))
#else
#define HWY_BROKEN_MSVC 0
#endif
#endif // HWY_BROKEN_MSVC
#ifndef HWY_BROKEN_AVX3_DL_ZEN4 // allow override
// AVX3_DL and AVX3_ZEN4 require clang >= 7 (ensured above), gcc >= 8.1 or ICC
// 2021.
#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 801) || \
(HWY_COMPILER_ICC && HWY_COMPILER_ICC < 2021)
#define HWY_BROKEN_AVX3_DL_ZEN4 (HWY_AVX3_DL | HWY_AVX3_ZEN4)
#else
#define HWY_BROKEN_AVX3_DL_ZEN4 0
#endif
#endif // HWY_BROKEN_AVX3_DL_ZEN4
#ifndef HWY_BROKEN_AVX3_SPR // allow override
// AVX3_SPR requires clang >= 14, gcc >= 12, or ICC 2021.
#if (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1400) || \
(HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200) || \
(HWY_COMPILER_ICC && HWY_COMPILER_ICC < 2021)
#define HWY_BROKEN_AVX3_SPR (HWY_AVX3_SPR)
#else
#define HWY_BROKEN_AVX3_SPR 0
#endif
#endif // HWY_BROKEN_AVX3_SPR
#ifndef HWY_BROKEN_ARM7_BIG_ENDIAN // allow override
// armv7be has not been tested and is not yet supported.
#if HWY_ARCH_ARM_V7 && HWY_IS_BIG_ENDIAN
#define HWY_BROKEN_ARM7_BIG_ENDIAN HWY_ALL_NEON
#else
#define HWY_BROKEN_ARM7_BIG_ENDIAN 0
#endif
#endif // HWY_BROKEN_ARM7_BIG_ENDIAN
#ifdef __ARM_NEON_FP
#define HWY_HAVE_NEON_FP __ARM_NEON_FP
#else
#define HWY_HAVE_NEON_FP 0
#endif
#ifndef HWY_BROKEN_ARM7_WITHOUT_VFP4 // allow override
// armv7-a without a detected vfpv4 is not supported
// (for example Cortex-A8, Cortex-A9)
// vfpv4 always has NEON half-float _and_ FMA.
#if HWY_ARCH_ARM_V7 && (__ARM_ARCH_PROFILE == 'A') && \
!defined(__ARM_VFPV4__) && \
!((HWY_HAVE_NEON_FP & 0x2 /* half-float */) && (__ARM_FEATURE_FMA == 1))
#define HWY_BROKEN_ARM7_WITHOUT_VFP4 HWY_ALL_NEON
#else
#define HWY_BROKEN_ARM7_WITHOUT_VFP4 0
#endif
#endif // HWY_BROKEN_ARM7_WITHOUT_VFP4
#ifndef HWY_BROKEN_NEON_BF16 // allow override
// HWY_NEON_BF16 requires recent compilers.
#if (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1700) || \
(HWY_COMPILER_GCC_ACTUAL != 0 && HWY_COMPILER_GCC_ACTUAL < 1302)
#define HWY_BROKEN_NEON_BF16 (HWY_NEON_BF16)
#else
#define HWY_BROKEN_NEON_BF16 0
#endif
#endif // HWY_BROKEN_NEON_BF16
// SVE[2] require recent clang or gcc versions.
#ifndef HWY_BROKEN_SVE // allow override
// GCC 10+. Clang 19 still has many test failures for SVE. No Apple CPU (at
// least up to and including M4 and A18) has SVE.
#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2000) || \
(HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \
HWY_OS_APPLE
#define HWY_BROKEN_SVE (HWY_SVE | HWY_SVE_256)
#else
#define HWY_BROKEN_SVE 0
#endif
#endif // HWY_BROKEN_SVE
#ifndef HWY_BROKEN_SVE2 // allow override
// Clang 19 still has many test failures for SVE2.
#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2000) || \
(HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \
HWY_OS_APPLE
#define HWY_BROKEN_SVE2 (HWY_SVE2 | HWY_SVE2_128)
#else
#define HWY_BROKEN_SVE2 0
#endif
#endif // HWY_BROKEN_SVE2
#ifndef HWY_BROKEN_PPC10 // allow override
#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1100)
// GCC 10 supports the -mcpu=power10 option but does not support the PPC10
// vector intrinsics
#define HWY_BROKEN_PPC10 (HWY_PPC10)
#elif HWY_ARCH_PPC && HWY_IS_BIG_ENDIAN && \
((HWY_COMPILER3_CLANG && HWY_COMPILER3_CLANG < 160001) || \
(HWY_COMPILER_GCC_ACTUAL >= 1200 && HWY_COMPILER_GCC_ACTUAL <= 1203) || \
(HWY_COMPILER_GCC_ACTUAL >= 1300 && HWY_COMPILER_GCC_ACTUAL <= 1301))
// GCC 12.0 through 12.3 and GCC 13.0 through 13.1 have a compiler bug where the
// vsldoi instruction is sometimes incorrectly optimized out (and this causes
// some of the Highway unit tests to fail on big-endian PPC10). Details about
// this compiler bug can be found at
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109069, and this bug will be
// fixed in the upcoming GCC 12.4 and 13.2 releases.
// Clang 16.0.0 and earlier (but not Clang 16.0.1 and later) have a compiler
// bug in the LLVM DAGCombiner that causes a zero-extend followed by an
// element insert into a vector, followed by a vector shuffle to be incorrectly
// optimized on big-endian PPC (and which caused some of the Highway unit tests
// to fail on big-endian PPC10).
// Details about this bug, which has already been fixed in Clang 16.0.1 and
// later, can be found at https://github.com/llvm/llvm-project/issues/61315.
#define HWY_BROKEN_PPC10 (HWY_PPC10)
#else
#define HWY_BROKEN_PPC10 0
#endif
#endif // HWY_BROKEN_PPC10
#ifndef HWY_BROKEN_PPC_32BIT // allow override
// PPC8/PPC9/PPC10 targets may fail to compile on 32-bit PowerPC
#if HWY_ARCH_PPC && !HWY_ARCH_PPC_64
#define HWY_BROKEN_PPC_32BIT (HWY_PPC8 | HWY_PPC9 | HWY_PPC10)
#else
#define HWY_BROKEN_PPC_32BIT 0
#endif
#endif // HWY_BROKEN_PPC_32BIT
#ifndef HWY_BROKEN_RVV // allow override
// HWY_RVV fails to compile with GCC < 13 or Clang < 16.
#if HWY_ARCH_RISCV && \
((HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1600) || \
(HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1300))
#define HWY_BROKEN_RVV (HWY_RVV)
#else
#define HWY_BROKEN_RVV 0
#endif
#endif // HWY_BROKEN_RVV
#ifndef HWY_BROKEN_LOONGARCH // allow override
// HWY_LSX/HWY_LASX require GCC 14 or Clang 18.
#if HWY_ARCH_LOONGARCH && \
((HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1800) || \
(HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400))
#define HWY_BROKEN_LOONGARCH (HWY_LSX | HWY_LASX)
#else
#define HWY_BROKEN_LOONGARCH 0
#endif
#endif // HWY_BROKEN_LOONGARCH
#ifndef HWY_BROKEN_Z14 // allow override
#if HWY_ARCH_S390X
#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1900
// Clang 18 and earlier have bugs with some ZVector intrinsics
#define HWY_BROKEN_Z14 (HWY_Z14 | HWY_Z15)
#elif HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
// Z15 target requires GCC 9 or later
#define HWY_BROKEN_Z14 (HWY_Z15)
#else
#define HWY_BROKEN_Z14 0
#endif
#else // !HWY_ARCH_S390X
#define HWY_BROKEN_Z14 0
#endif // HWY_ARCH_S390X
#endif // HWY_BROKEN_Z14
// Allow the user to override this without any guarantee of success.
#ifndef HWY_BROKEN_TARGETS
#define HWY_BROKEN_TARGETS \
(HWY_BROKEN_CLANG6 | HWY_BROKEN_32BIT | HWY_BROKEN_MSVC | \
HWY_BROKEN_AVX3_DL_ZEN4 | HWY_BROKEN_AVX3_SPR | \
HWY_BROKEN_ARM7_BIG_ENDIAN | HWY_BROKEN_ARM7_WITHOUT_VFP4 | \
HWY_BROKEN_NEON_BF16 | HWY_BROKEN_SVE | HWY_BROKEN_SVE2 | \
HWY_BROKEN_PPC10 | HWY_BROKEN_PPC_32BIT | HWY_BROKEN_RVV | \
HWY_BROKEN_LOONGARCH | HWY_BROKEN_Z14)
#endif // HWY_BROKEN_TARGETS
// Enabled means not disabled nor blocklisted.
#define HWY_ENABLED(targets) \
((targets) & ~((HWY_DISABLED_TARGETS) | (HWY_BROKEN_TARGETS)))
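// Worked example (illustrative): with HWY_DISABLED_TARGETS == HWY_SSE4 and a
// blocklist containing HWY_AVX3, HWY_ENABLED(HWY_SSE4 | HWY_AVX2 | HWY_AVX3)
// evaluates to HWY_AVX2 only.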
// Opt-out for EMU128 (affected by a GCC bug on multiple arches, fixed in 12.3:
// see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=106322). An issue still
// remains with 13.2, see #1683. This is separate from HWY_BROKEN_TARGETS
// because it affects the fallback target, which must always be enabled. If 1,
// we instead choose HWY_SCALAR even without HWY_COMPILE_ONLY_SCALAR being set.
#if !defined(HWY_BROKEN_EMU128) // allow overriding
#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400) || \
defined(HWY_NO_LIBCXX)
#define HWY_BROKEN_EMU128 1
#else
#define HWY_BROKEN_EMU128 0
#endif
#endif // HWY_BROKEN_EMU128
//------------------------------------------------------------------------------
// Detect baseline targets using predefined macros
// Baseline means the targets for which the compiler is allowed to generate
// instructions, implying the target CPU would have to support them. This does
// not take the blocklist into account.
#if defined(HWY_COMPILE_ONLY_SCALAR) || HWY_BROKEN_EMU128
#define HWY_BASELINE_SCALAR HWY_SCALAR
#else
#define HWY_BASELINE_SCALAR HWY_EMU128
#endif
// Also check HWY_ARCH to ensure that simulating unknown platforms ends up with
// HWY_TARGET == HWY_BASELINE_SCALAR.
#if HWY_ARCH_WASM && defined(__wasm_simd128__)
#if defined(HWY_WANT_WASM2)
#define HWY_BASELINE_WASM HWY_WASM_EMU256
#else
#define HWY_BASELINE_WASM HWY_WASM
#endif // HWY_WANT_WASM2
#else
#define HWY_BASELINE_WASM 0
#endif
// GCC or Clang.
#if HWY_ARCH_PPC && HWY_COMPILER_GCC && defined(__ALTIVEC__) && \
defined(__VSX__) && defined(__POWER8_VECTOR__) && \
(defined(__CRYPTO__) || defined(HWY_DISABLE_PPC8_CRYPTO))
#define HWY_BASELINE_PPC8 HWY_PPC8
#else
#define HWY_BASELINE_PPC8 0
#endif
#if HWY_BASELINE_PPC8 != 0 && defined(__POWER9_VECTOR__)
#define HWY_BASELINE_PPC9 HWY_PPC9
#else
#define HWY_BASELINE_PPC9 0
#endif
#if HWY_BASELINE_PPC9 != 0 && \
(defined(_ARCH_PWR10) || defined(__POWER10_VECTOR__))
#define HWY_BASELINE_PPC10 HWY_PPC10
#else
#define HWY_BASELINE_PPC10 0
#endif
#if HWY_ARCH_S390X && defined(__VEC__) && defined(__ARCH__) && __ARCH__ >= 12
#define HWY_BASELINE_Z14 HWY_Z14
#else
#define HWY_BASELINE_Z14 0
#endif
#if HWY_BASELINE_Z14 && __ARCH__ >= 13
#define HWY_BASELINE_Z15 HWY_Z15
#else
#define HWY_BASELINE_Z15 0
#endif
#define HWY_BASELINE_SVE2 0
#define HWY_BASELINE_SVE 0
#define HWY_BASELINE_NEON 0
#if HWY_ARCH_ARM
// Also check compiler version as done for HWY_ATTAINABLE_SVE2 because the
// static target (influenced here) must be one of the attainable targets.
#if defined(__ARM_FEATURE_SVE2) && \
(HWY_COMPILER_CLANG >= 1400 || HWY_COMPILER_GCC_ACTUAL >= 1200)
#undef HWY_BASELINE_SVE2 // was 0, will be re-defined
// If user specified -msve-vector-bits=128, they assert the vector length is
// 128 bits and we should use HWY_SVE2_128 (more efficient for some ops).
#if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 128
#define HWY_BASELINE_SVE2 HWY_SVE2_128
// Otherwise we're not sure what the vector length will be. The baseline must be
// unconditionally valid, so we can only assume HWY_SVE2. However, when running
// on a CPU with 128-bit vectors, user code that supports dynamic dispatch will
// still benefit from HWY_SVE2_128 because we add it to HWY_ATTAINABLE_TARGETS.
#else
#define HWY_BASELINE_SVE2 HWY_SVE2
#endif // __ARM_FEATURE_SVE_BITS
#endif // __ARM_FEATURE_SVE2
#if defined(__ARM_FEATURE_SVE) && \
(HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800)
#undef HWY_BASELINE_SVE // was 0, will be re-defined
// See above. If user-specified vector length matches our optimization, use it.
#if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 256
#define HWY_BASELINE_SVE HWY_SVE_256
#else
#define HWY_BASELINE_SVE HWY_SVE
#endif // __ARM_FEATURE_SVE_BITS
#endif // __ARM_FEATURE_SVE
// GCC 4.5.4 only defines __ARM_NEON__; 5.4 defines both.
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#undef HWY_BASELINE_NEON
#if defined(__ARM_FEATURE_AES) && \
defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \
defined(__ARM_FEATURE_DOTPROD) && \
defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
#define HWY_BASELINE_NEON HWY_ALL_NEON
#elif defined(__ARM_FEATURE_AES)
#define HWY_BASELINE_NEON (HWY_NEON_WITHOUT_AES | HWY_NEON)
#else
#define HWY_BASELINE_NEON (HWY_NEON_WITHOUT_AES)
#endif // __ARM_FEATURE*
#endif // __ARM_NEON
#endif // HWY_ARCH_ARM
// Special handling for MSVC because it has fewer predefined macros:
#if HWY_COMPILER_MSVC
#if HWY_ARCH_X86_32
#if _M_IX86_FP >= 2
#define HWY_CHECK_SSE2 1
#else
#define HWY_CHECK_SSE2 0
#endif
#elif HWY_ARCH_X86_64
#define HWY_CHECK_SSE2 1
#else
#define HWY_CHECK_SSE2 0
#endif
// 1) We can only be sure SSSE3/SSE4 are enabled if AVX is:
// https://stackoverflow.com/questions/18563978/.
#if defined(__AVX__)
#define HWY_CHECK_SSSE3 1
#define HWY_CHECK_SSE4 1
#else
#define HWY_CHECK_SSSE3 0
#define HWY_CHECK_SSE4 0
#endif
// 2) Cannot check for PCLMUL/AES and BMI2/FMA/F16C individually; we assume
// PCLMUL/AES are available if SSE4 is, and BMI2/FMA/F16C if AVX2 is.
#define HWY_CHECK_PCLMUL_AES 1
#define HWY_CHECK_BMI2_FMA 1
#define HWY_CHECK_F16C 1
#else // non-MSVC
#if defined(__SSE2__)
#define HWY_CHECK_SSE2 1
#else
#define HWY_CHECK_SSE2 0
#endif
#if defined(__SSSE3__)
#define HWY_CHECK_SSSE3 1
#else
#define HWY_CHECK_SSSE3 0
#endif
#if defined(__SSE4_1__) && defined(__SSE4_2__)
#define HWY_CHECK_SSE4 1
#else
#define HWY_CHECK_SSE4 0
#endif
// If these are disabled, they should not gate the availability of SSE4/AVX2.
#if defined(HWY_DISABLE_PCLMUL_AES) || (defined(__PCLMUL__) && defined(__AES__))
#define HWY_CHECK_PCLMUL_AES 1
#else
#define HWY_CHECK_PCLMUL_AES 0
#endif
#if defined(HWY_DISABLE_BMI2_FMA) || (defined(__BMI2__) && defined(__FMA__))
#define HWY_CHECK_BMI2_FMA 1
#else
#define HWY_CHECK_BMI2_FMA 0
#endif
#if defined(HWY_DISABLE_F16C) || defined(__F16C__)
#define HWY_CHECK_F16C 1
#else
#define HWY_CHECK_F16C 0
#endif
#endif // non-MSVC
#if HWY_ARCH_X86 && \
((defined(HWY_WANT_SSE2) && HWY_WANT_SSE2) || HWY_CHECK_SSE2)
#define HWY_BASELINE_SSE2 HWY_SSE2
#else
#define HWY_BASELINE_SSE2 0
#endif
#if HWY_ARCH_X86 && \
((defined(HWY_WANT_SSSE3) && HWY_WANT_SSSE3) || HWY_CHECK_SSSE3)
#define HWY_BASELINE_SSSE3 HWY_SSSE3
#else
#define HWY_BASELINE_SSSE3 0
#endif
#if HWY_ARCH_X86 && ((defined(HWY_WANT_SSE4) && HWY_WANT_SSE4) || \
(HWY_CHECK_SSE4 && HWY_CHECK_PCLMUL_AES))
#define HWY_BASELINE_SSE4 HWY_SSE4
#else
#define HWY_BASELINE_SSE4 0
#endif
#if HWY_BASELINE_SSE4 != 0 && HWY_CHECK_BMI2_FMA && HWY_CHECK_F16C && \
defined(__AVX2__)
#define HWY_BASELINE_AVX2 HWY_AVX2
#else
#define HWY_BASELINE_AVX2 0
#endif
// Require everything in AVX2 plus AVX-512 flags (also set by MSVC)
#if HWY_BASELINE_AVX2 != 0 && defined(__AVX512F__) && defined(__AVX512BW__) && \
defined(__AVX512DQ__) && defined(__AVX512VL__) && \
((!HWY_COMPILER_GCC_ACTUAL && !HWY_COMPILER_CLANG) || \
HWY_COMPILER_GCC_ACTUAL < 1400 || HWY_COMPILER_CLANG < 1800 || \
defined(__EVEX512__))
#define HWY_BASELINE_AVX3 HWY_AVX3
#else
#define HWY_BASELINE_AVX3 0
#endif
// TODO(janwas): not yet known whether these will be set by MSVC
#if HWY_BASELINE_AVX3 != 0 && defined(__AVX512VNNI__) && defined(__VAES__) && \
defined(__VPCLMULQDQ__) && defined(__AVX512VBMI__) && \
defined(__AVX512VBMI2__) && defined(__AVX512VPOPCNTDQ__) && \
defined(__AVX512BITALG__)
#define HWY_BASELINE_AVX3_DL HWY_AVX3_DL
#else
#define HWY_BASELINE_AVX3_DL 0
#endif
// The ZEN4-optimized AVX3 target is numerically lower than AVX3_DL and is thus
// considered better. Do not enable it unless the user explicitly requests it -
// we do not want to choose the ZEN4 path on Intel because it could be slower.
#if defined(HWY_WANT_AVX3_ZEN4) && HWY_BASELINE_AVX3_DL != 0
#define HWY_BASELINE_AVX3_ZEN4 HWY_AVX3_ZEN4
#else
#define HWY_BASELINE_AVX3_ZEN4 0
#endif
#if HWY_BASELINE_AVX2 != 0 && defined(__AVX10_2__)
#define HWY_BASELINE_AVX10_2 HWY_AVX10_2
#else
#define HWY_BASELINE_AVX10_2 0
#endif
#if HWY_BASELINE_AVX3_DL != 0 && defined(__AVX512BF16__) && \
defined(__AVX512FP16__)
#define HWY_BASELINE_AVX3_SPR HWY_AVX3_SPR
#else
#define HWY_BASELINE_AVX3_SPR 0
#endif
#if HWY_BASELINE_AVX3_SPR != 0 && defined(__AVX10_2_512__)
#define HWY_BASELINE_AVX10_2_512 HWY_AVX10_2_512
#else
#define HWY_BASELINE_AVX10_2_512 0
#endif
// RVV requires intrinsics 0.11 or later, see #1156.
#if HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \
__riscv_v_intrinsic >= 11000
#define HWY_BASELINE_RVV HWY_RVV
#else
#define HWY_BASELINE_RVV 0
#endif
#if HWY_ARCH_LOONGARCH && defined(__loongarch_sx) && defined(__loongarch_asx)
#define HWY_BASELINE_LOONGARCH (HWY_LSX | HWY_LASX)
#elif HWY_ARCH_LOONGARCH && defined(__loongarch_sx)
#define HWY_BASELINE_LOONGARCH (HWY_LSX)
#else
#define HWY_BASELINE_LOONGARCH 0
#endif
// Allow the user to override this without any guarantee of success.
#ifndef HWY_BASELINE_TARGETS
#define HWY_BASELINE_TARGETS \
(HWY_BASELINE_SCALAR | HWY_BASELINE_WASM | HWY_BASELINE_PPC8 | \
HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10 | HWY_BASELINE_Z14 | \
HWY_BASELINE_Z15 | HWY_BASELINE_SVE2 | HWY_BASELINE_SVE | \
HWY_BASELINE_NEON | HWY_BASELINE_SSE2 | HWY_BASELINE_SSSE3 | \
HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | HWY_BASELINE_AVX3 | \
HWY_BASELINE_AVX3_DL | HWY_BASELINE_AVX3_ZEN4 | HWY_BASELINE_AVX10_2 | \
HWY_BASELINE_AVX3_SPR | HWY_BASELINE_AVX10_2_512 | HWY_BASELINE_RVV | \
HWY_BASELINE_LOONGARCH)
#endif // HWY_BASELINE_TARGETS
//------------------------------------------------------------------------------
// Choose target for static dispatch
#define HWY_ENABLED_BASELINE HWY_ENABLED(HWY_BASELINE_TARGETS)
#if HWY_ENABLED_BASELINE == 0
#error "At least one baseline target must be defined and enabled"
#endif
// Best baseline, used for static dispatch. This is the least-significant 1-bit
// within HWY_ENABLED_BASELINE and lower bit values imply "better".
#define HWY_STATIC_TARGET (HWY_ENABLED_BASELINE & -HWY_ENABLED_BASELINE)
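// Worked example (illustrative): if HWY_ENABLED_BASELINE were
// (HWY_AVX2 | HWY_SSE4 | HWY_SSE2 | HWY_EMU128), then `x & -x` isolates the
// lowest set bit, i.e. HWY_AVX2 (1LL << 9), so HWY_STATIC_TARGET == HWY_AVX2.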
// Start by assuming static dispatch. If we later use dynamic dispatch, this
// will be defined to other targets during the multiple-inclusion, and finally
// return to the initial value. Defining this outside begin/end_target ensures
// inl headers successfully compile by themselves (required by Bazel).
#define HWY_TARGET HWY_STATIC_TARGET
//------------------------------------------------------------------------------
// Choose targets for dynamic dispatch according to one of four policies
// TODO: remove once HWY_LSX is actually supported
#if HWY_ARCH_LOONGARCH && !defined(HWY_COMPILE_ONLY_SCALAR) && \
!defined(HWY_COMPILE_ONLY_EMU128)
#undef HWY_COMPILE_ONLY_STATIC
#define HWY_COMPILE_ONLY_EMU128
#endif
#if 1 < (defined(HWY_COMPILE_ONLY_SCALAR) + defined(HWY_COMPILE_ONLY_EMU128) + \
defined(HWY_COMPILE_ONLY_STATIC))
#error "Can only define one of HWY_COMPILE_ONLY_{SCALAR|EMU128|STATIC} - bug?"
#endif
// Defining one of HWY_COMPILE_ONLY_* will trump HWY_COMPILE_ALL_ATTAINABLE.
#ifndef HWY_HAVE_ASM_HWCAP // allow override
#ifdef TOOLCHAIN_MISS_ASM_HWCAP_H
#define HWY_HAVE_ASM_HWCAP 0 // CMake failed to find the header
#elif defined(__has_include) // note: wrapper macro fails on Clang ~17
// clang-format off
#if __has_include(<asm/hwcap.h>)
// clang-format on
#define HWY_HAVE_ASM_HWCAP 1 // header present
#else
#define HWY_HAVE_ASM_HWCAP 0 // header not present
#endif // __has_include
#else // compiler lacks __has_include
#define HWY_HAVE_ASM_HWCAP 0
#endif
#endif // HWY_HAVE_ASM_HWCAP
#ifndef HWY_HAVE_AUXV // allow override
#ifdef TOOLCHAIN_MISS_SYS_AUXV_H
#define HWY_HAVE_AUXV 0 // CMake failed to find the header
// glibc 2.16 added auxv, but checking for that requires features.h, and we do
// not want to include system headers here. Instead check for the header
// directly, which has been supported at least since GCC 5.4 and Clang 3.
#elif defined(__has_include) // note: wrapper macro fails on Clang ~17
// clang-format off
#if __has_include(<sys/auxv.h>)
// clang-format on
#define HWY_HAVE_AUXV 1 // header present
#else
#define HWY_HAVE_AUXV 0 // header not present
#endif // __has_include
#else // compiler lacks __has_include
#define HWY_HAVE_AUXV 0
#endif
#endif // HWY_HAVE_AUXV
#ifndef HWY_HAVE_RUNTIME_DISPATCH_RVV // allow override
// The riscv_vector.h in Clang 16-18 requires compiler flags, and 19 still has
// some missing intrinsics, see
// https://github.com/llvm/llvm-project/issues/56592. GCC 13.3 also has an
// #error check, whereas 14.1 fails with "argument type 'vuint16m8_t' requires
// the V ISA extension": https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115325.
#if HWY_ARCH_RISCV && HWY_COMPILER_CLANG >= 1900 && 0
#define HWY_HAVE_RUNTIME_DISPATCH_RVV 1
#else
#define HWY_HAVE_RUNTIME_DISPATCH_RVV 0
#endif
#endif // HWY_HAVE_RUNTIME_DISPATCH_RVV
#ifndef HWY_HAVE_RUNTIME_DISPATCH_APPLE // allow override
#if HWY_ARCH_ARM_A64 && HWY_OS_APPLE && \
(HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1700)
#define HWY_HAVE_RUNTIME_DISPATCH_APPLE 1
#else
#define HWY_HAVE_RUNTIME_DISPATCH_APPLE 0
#endif
#endif // HWY_HAVE_RUNTIME_DISPATCH_APPLE
#ifndef HWY_HAVE_RUNTIME_DISPATCH_LINUX // allow override
#if (HWY_ARCH_ARM || HWY_ARCH_PPC || HWY_ARCH_S390X) && HWY_OS_LINUX && \
(HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1700) && HWY_HAVE_AUXV
#define HWY_HAVE_RUNTIME_DISPATCH_LINUX 1
#else
#define HWY_HAVE_RUNTIME_DISPATCH_LINUX 0
#endif
#endif // HWY_HAVE_RUNTIME_DISPATCH_LINUX
// Allow opting out, and (without a guarantee of success) opting in.
#ifndef HWY_HAVE_RUNTIME_DISPATCH
// Clang, GCC and MSVC allow OS-independent runtime dispatch on x86.
#if HWY_ARCH_X86 || HWY_HAVE_RUNTIME_DISPATCH_RVV || \
HWY_HAVE_RUNTIME_DISPATCH_APPLE || HWY_HAVE_RUNTIME_DISPATCH_LINUX
#define HWY_HAVE_RUNTIME_DISPATCH 1
#else
#define HWY_HAVE_RUNTIME_DISPATCH 0
#endif
#endif // HWY_HAVE_RUNTIME_DISPATCH
#if HWY_ARCH_ARM_A64 && HWY_HAVE_RUNTIME_DISPATCH
#define HWY_ATTAINABLE_NEON HWY_ALL_NEON
#elif HWY_ARCH_ARM // static dispatch, or HWY_ARCH_ARM_V7
#define HWY_ATTAINABLE_NEON (HWY_BASELINE_NEON)
#else
#define HWY_ATTAINABLE_NEON 0
#endif
#if HWY_ARCH_ARM_A64 && \
(HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800) && \
(HWY_HAVE_RUNTIME_DISPATCH || \
(HWY_ENABLED_BASELINE & (HWY_SVE | HWY_SVE_256)))
#define HWY_ATTAINABLE_SVE (HWY_SVE | HWY_SVE_256)
#else
#define HWY_ATTAINABLE_SVE 0
#endif
#if HWY_ARCH_ARM_A64 && \
(HWY_COMPILER_CLANG >= 1400 || HWY_COMPILER_GCC_ACTUAL >= 1200) && \
(HWY_HAVE_RUNTIME_DISPATCH || \
(HWY_ENABLED_BASELINE & (HWY_SVE2 | HWY_SVE2_128)))
#define HWY_ATTAINABLE_SVE2 (HWY_SVE2 | HWY_SVE2_128)
#else
#define HWY_ATTAINABLE_SVE2 0
#endif
#if HWY_ARCH_PPC && defined(__ALTIVEC__) && \
(!HWY_COMPILER_CLANG || HWY_BASELINE_PPC8 != 0)
#if (HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10) && \
!defined(HWY_SKIP_NON_BEST_BASELINE)
// On POWER with -m flags, we get compile errors (#1707) for targets older than
// the baseline specified via -m, so only generate the static target and better.
// Note that some Linux distros actually do set POWER9 as the baseline.
// This works by skipping case 3 below, so case 4 is reached.
#define HWY_SKIP_NON_BEST_BASELINE
#endif
#define HWY_ATTAINABLE_PPC (HWY_PPC8 | HWY_PPC9 | HWY_PPC10)
#else
#define HWY_ATTAINABLE_PPC 0
#endif
#if HWY_ARCH_S390X && HWY_BASELINE_Z14 != 0
#define HWY_ATTAINABLE_S390X (HWY_Z14 | HWY_Z15)
#else
#define HWY_ATTAINABLE_S390X 0
#endif
#if HWY_ARCH_RISCV && HWY_HAVE_RUNTIME_DISPATCH
#define HWY_ATTAINABLE_RISCV HWY_RVV
#else
#define HWY_ATTAINABLE_RISCV HWY_BASELINE_RVV
#endif
#if HWY_ARCH_LOONGARCH && HWY_HAVE_RUNTIME_DISPATCH
#define HWY_ATTAINABLE_LOONGARCH (HWY_LSX | HWY_LASX)
#else
#define HWY_ATTAINABLE_LOONGARCH HWY_BASELINE_LOONGARCH
#endif
#ifndef HWY_ATTAINABLE_TARGETS_X86 // allow override
#if HWY_COMPILER_MSVC && defined(HWY_SLOW_MSVC)
// Fewer targets for faster builds.
#define HWY_ATTAINABLE_TARGETS_X86 \
HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_STATIC_TARGET | HWY_AVX2)
#else // !HWY_COMPILER_MSVC
#define HWY_ATTAINABLE_TARGETS_X86 \
HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | \
HWY_AVX2 | HWY_AVX3 | HWY_AVX3_DL | HWY_AVX3_ZEN4 | \
HWY_AVX3_SPR)
#endif // !HWY_COMPILER_MSVC
#endif // HWY_ATTAINABLE_TARGETS_X86
// Attainable means enabled and the compiler allows intrinsics (even when not
// allowed to auto-vectorize). Used in 3 and 4.
#if HWY_ARCH_X86
#define HWY_ATTAINABLE_TARGETS HWY_ATTAINABLE_TARGETS_X86
#elif HWY_ARCH_ARM
#define HWY_ATTAINABLE_TARGETS \
HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_NEON | HWY_ATTAINABLE_SVE | \
HWY_ATTAINABLE_SVE2)
#elif HWY_ARCH_PPC
#define HWY_ATTAINABLE_TARGETS \
HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_PPC)
#elif HWY_ARCH_S390X
#define HWY_ATTAINABLE_TARGETS \
HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_S390X)
#elif HWY_ARCH_RISCV
#define HWY_ATTAINABLE_TARGETS \
HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_RISCV)
#elif HWY_ARCH_LOONGARCH
#define HWY_ATTAINABLE_TARGETS \
HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_LOONGARCH)
#else
#define HWY_ATTAINABLE_TARGETS (HWY_ENABLED_BASELINE)
#endif // HWY_ARCH_*
// 1) For older compilers: avoid SIMD intrinsics, but still support all ops.
#if defined(HWY_COMPILE_ONLY_EMU128) && !HWY_BROKEN_EMU128
#undef HWY_STATIC_TARGET
#define HWY_STATIC_TARGET HWY_EMU128 // override baseline
#define HWY_TARGETS HWY_EMU128
// 1b) HWY_SCALAR is less capable than HWY_EMU128 (which supports all ops), but
// we currently still support it for backwards compatibility.
#elif defined(HWY_COMPILE_ONLY_SCALAR) || \
(defined(HWY_COMPILE_ONLY_EMU128) && HWY_BROKEN_EMU128)
#undef HWY_STATIC_TARGET
#define HWY_STATIC_TARGET HWY_SCALAR // override baseline
#define HWY_TARGETS HWY_SCALAR
// 2) For forcing static dispatch without code changes (removing HWY_EXPORT)
#elif defined(HWY_COMPILE_ONLY_STATIC)
#define HWY_TARGETS HWY_STATIC_TARGET
// 3) For tests: include all attainable targets (in particular: scalar)
#elif (defined(HWY_COMPILE_ALL_ATTAINABLE) || defined(HWY_IS_TEST)) && \
!defined(HWY_SKIP_NON_BEST_BASELINE)
#define HWY_TARGETS HWY_ATTAINABLE_TARGETS
// 4) Default: attainable WITHOUT non-best baseline. This reduces code size by
// excluding superseded targets, in particular scalar. Note: HWY_STATIC_TARGET
// may be 2^62 (HWY_SCALAR), so we must not left-shift/add it. Subtracting one
// sets all lower bits (better targets), then we also include the static target.
#else
#define HWY_TARGETS \
(HWY_ATTAINABLE_TARGETS & ((HWY_STATIC_TARGET - 1LL) | HWY_STATIC_TARGET))
#endif // target policy
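// Hedged illustration of the case-4 mask (the bit positions below are made up
// for the example; the real values are the HWY_* target bits defined earlier):
// suppose HWY_STATIC_TARGET == (1 << 9) and HWY_ATTAINABLE_TARGETS has bits
// 6, 9, 15 and 62 set. Then ((1 << 9) - 1) | (1 << 9) == 0x3FF covers the
// static target plus every better (lower-bit) target, so the intersection
// keeps only bits 6 and 9. The superseded bit-15 target and the bit-62 scalar
// target are dropped, which is the intended "attainable WITHOUT non-best
// baseline" behavior.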
// HWY_ONCE and the multiple-inclusion mechanism rely on HWY_STATIC_TARGET being
// one of the dynamic targets. This also implies HWY_TARGETS != 0 and
// (HWY_TARGETS & HWY_ENABLED_BASELINE) != 0.
#if (HWY_TARGETS & HWY_STATIC_TARGET) == 0
#error "Logic error: best baseline should be included in dynamic targets"
#endif
#endif // HIGHWAY_HWY_DETECT_TARGETS_H_

View File

@@ -0,0 +1,64 @@
// Copyright 2020 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Demo of functions that might be called from multiple SIMD modules (either
// other -inl.h files, or a .cc file between begin/end_target-inl). This is
// optional - all SIMD code can reside in .cc files. However, this allows
// splitting code into different files while still inlining instead of requiring
// calling through function pointers.
// Per-target include guard. This is only required when using dynamic dispatch,
// i.e. including foreach_target.h. For static dispatch, a normal include
// guard would be fine because the header is only compiled once.
#if defined(HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_) == defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_
#undef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_
#else
#define HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_
#endif
// It is fine to #include normal or *-inl headers.
#include "third_party/highway/hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace skeleton {
namespace HWY_NAMESPACE {
// Highway ops reside here; ADL does not find templates nor builtins.
namespace hn = hwy::HWY_NAMESPACE;
// Example of a type-agnostic (caller-specified lane type) and width-agnostic
// (uses best available instruction set) function in a header.
//
// Computes x[i] = mul_array[i] * x_array[i] + add_array[i] for i < size.
template <class D, typename T>
HWY_MAYBE_UNUSED void MulAddLoop(const D d, const T* HWY_RESTRICT mul_array,
const T* HWY_RESTRICT add_array,
const size_t size, T* HWY_RESTRICT x_array) {
for (size_t i = 0; i < size; i += hn::Lanes(d)) {
const auto mul = hn::Load(d, mul_array + i);
const auto add = hn::Load(d, add_array + i);
auto x = hn::Load(d, x_array + i);
x = hn::MulAdd(mul, x, add);
hn::Store(x, d, x_array + i);
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace skeleton
HWY_AFTER_NAMESPACE();
#endif // include guard
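For illustration, a caller of MulAddLoop might look like the following sketch. It is hypothetical (not part of skeleton.cc), assumes float data and static dispatch only, and, like MulAddLoop itself, assumes size is a multiple of the vector length.

#include <cstddef>
#include "third_party/highway/hwy/highway.h"
// ...plus the skeleton-inl.h header above, at whatever path the build uses.

HWY_BEFORE_NAMESPACE();
namespace skeleton {
namespace HWY_NAMESPACE {

// Hypothetical wrapper: x[i] = mul[i] * x[i] + add[i] over float arrays.
void MulAddFloats(const float* HWY_RESTRICT mul, const float* HWY_RESTRICT add,
                  size_t size, float* HWY_RESTRICT x) {
  const hwy::HWY_NAMESPACE::ScalableTag<float> d;  // full native vectors
  MulAddLoop(d, mul, add, size, x);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace skeleton
HWY_AFTER_NAMESPACE();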

View File

@@ -0,0 +1,38 @@
// Copyright 2020 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Demo interface to target-specific code in skeleton.cc
// Normal header with include guard and namespace.
#ifndef HIGHWAY_HWY_EXAMPLES_SKELETON_H_
#define HIGHWAY_HWY_EXAMPLES_SKELETON_H_
// Platform-specific definitions used for declaring an interface, independent of
// the SIMD instruction set.
#include "third_party/highway/hwy/base.h" // HWY_RESTRICT
namespace skeleton {
// Computes base-2 logarithm by converting to float. Supports dynamic dispatch.
HWY_DLLEXPORT void CallFloorLog2(const uint8_t* HWY_RESTRICT in, size_t count,
uint8_t* HWY_RESTRICT out);
// Same, but uses HWY_DYNAMIC_POINTER to save a function pointer and call it.
HWY_DLLEXPORT void SavedCallFloorLog2(const uint8_t* HWY_RESTRICT in,
size_t count, uint8_t* HWY_RESTRICT out);
} // namespace skeleton
#endif // HIGHWAY_HWY_EXAMPLES_SKELETON_H_
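As a sketch, application code that only needs the declarations above could look like this; the include path, buffer size and contents are assumptions, and the caller requires no SIMD compiler flags.

#include <cstddef>
#include <cstdint>
#include <vector>
#include "third_party/highway/hwy/examples/skeleton.h"  // assumed location of the header above

int main() {
  const size_t kCount = 256;            // hypothetical buffer size
  std::vector<uint8_t> in(kCount, 8);   // floor(log2(8)) == 3
  std::vector<uint8_t> out(kCount);
  skeleton::CallFloorLog2(in.data(), kCount, out.data());
  // out[i] is expected to hold the floor of log2(in[i]).
  return out[0] == 3 ? 0 : 1;
}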

View File

@@ -0,0 +1,421 @@
// Copyright 2020 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_FOREACH_TARGET_H_
#define HIGHWAY_HWY_FOREACH_TARGET_H_
// Re-includes the translation unit zero or more times to compile for any
// targets except HWY_STATIC_TARGET. Defines unique HWY_TARGET each time so that
// highway.h defines the corresponding macro/namespace.
#include "third_party/highway/hwy/detect_targets.h"
// *_inl.h may include other headers, which requires include guards to prevent
// repeated inclusion. The guards must be reset after compiling each target, so
// the header is again visible. This is done by flipping HWY_TARGET_TOGGLE,
// defining it if undefined and vice versa. This macro is initially undefined
// so that IDEs don't gray out the contents of each header.
#ifdef HWY_TARGET_TOGGLE
#error "This macro must not be defined outside foreach_target.h"
#endif
#ifdef HWY_HIGHWAY_INCLUDED // highway.h include guard
// Trigger fixup at the bottom of this header.
#define HWY_ALREADY_INCLUDED
// The next highway.h must re-include set_macros-inl.h because the first
// highway.h chose the static target instead of what we will set below.
#undef HWY_SET_MACROS_PER_TARGET
#endif
// Disable HWY_EXPORT in user code until we have generated all targets. Note
// that a subsequent highway.h will not override this definition.
#undef HWY_ONCE
#define HWY_ONCE (0 || HWY_IDE)
// Avoid warnings on #include HWY_TARGET_INCLUDE by hiding them from the IDE;
// also skip if only 1 target defined (no re-inclusion will be necessary).
#if !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET)
#if !defined(HWY_TARGET_INCLUDE)
#error ">1 target enabled => define HWY_TARGET_INCLUDE before foreach_target.h"
#endif
// ------------------------------ HWY_ARCH_X86
#if (HWY_TARGETS & HWY_SSE2) && (HWY_STATIC_TARGET != HWY_SSE2)
#undef HWY_TARGET
#define HWY_TARGET HWY_SSE2
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_SSSE3) && (HWY_STATIC_TARGET != HWY_SSSE3)
#undef HWY_TARGET
#define HWY_TARGET HWY_SSSE3
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_SSE4) && (HWY_STATIC_TARGET != HWY_SSE4)
#undef HWY_TARGET
#define HWY_TARGET HWY_SSE4
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_AVX2) && (HWY_STATIC_TARGET != HWY_AVX2)
#undef HWY_TARGET
#define HWY_TARGET HWY_AVX2
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_AVX3) && (HWY_STATIC_TARGET != HWY_AVX3)
#undef HWY_TARGET
#define HWY_TARGET HWY_AVX3
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_AVX3_DL) && (HWY_STATIC_TARGET != HWY_AVX3_DL)
#undef HWY_TARGET
#define HWY_TARGET HWY_AVX3_DL
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_AVX3_ZEN4) && (HWY_STATIC_TARGET != HWY_AVX3_ZEN4)
#undef HWY_TARGET
#define HWY_TARGET HWY_AVX3_ZEN4
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_AVX3_SPR) && (HWY_STATIC_TARGET != HWY_AVX3_SPR)
#undef HWY_TARGET
#define HWY_TARGET HWY_AVX3_SPR
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_AVX10_2) && (HWY_STATIC_TARGET != HWY_AVX10_2)
#undef HWY_TARGET
#define HWY_TARGET HWY_AVX10_2
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_AVX10_2_512) && (HWY_STATIC_TARGET != HWY_AVX10_2_512)
#undef HWY_TARGET
#define HWY_TARGET HWY_AVX10_2_512
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
// ------------------------------ HWY_ARCH_ARM
#if (HWY_TARGETS & HWY_NEON_WITHOUT_AES) && \
(HWY_STATIC_TARGET != HWY_NEON_WITHOUT_AES)
#undef HWY_TARGET
#define HWY_TARGET HWY_NEON_WITHOUT_AES
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_NEON) && (HWY_STATIC_TARGET != HWY_NEON)
#undef HWY_TARGET
#define HWY_TARGET HWY_NEON
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_NEON_BF16) && (HWY_STATIC_TARGET != HWY_NEON_BF16)
#undef HWY_TARGET
#define HWY_TARGET HWY_NEON_BF16
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_SVE) && (HWY_STATIC_TARGET != HWY_SVE)
#undef HWY_TARGET
#define HWY_TARGET HWY_SVE
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_SVE2) && (HWY_STATIC_TARGET != HWY_SVE2)
#undef HWY_TARGET
#define HWY_TARGET HWY_SVE2
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_SVE_256) && (HWY_STATIC_TARGET != HWY_SVE_256)
#undef HWY_TARGET
#define HWY_TARGET HWY_SVE_256
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_SVE2_128) && (HWY_STATIC_TARGET != HWY_SVE2_128)
#undef HWY_TARGET
#define HWY_TARGET HWY_SVE2_128
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
// ------------------------------ HWY_ARCH_WASM
#if (HWY_TARGETS & HWY_WASM_EMU256) && (HWY_STATIC_TARGET != HWY_WASM_EMU256)
#undef HWY_TARGET
#define HWY_TARGET HWY_WASM_EMU256
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_WASM) && (HWY_STATIC_TARGET != HWY_WASM)
#undef HWY_TARGET
#define HWY_TARGET HWY_WASM
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
// ------------------------------ HWY_ARCH_PPC
#if (HWY_TARGETS & HWY_PPC8) && (HWY_STATIC_TARGET != HWY_PPC8)
#undef HWY_TARGET
#define HWY_TARGET HWY_PPC8
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_PPC9) && (HWY_STATIC_TARGET != HWY_PPC9)
#undef HWY_TARGET
#define HWY_TARGET HWY_PPC9
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_PPC10) && (HWY_STATIC_TARGET != HWY_PPC10)
#undef HWY_TARGET
#define HWY_TARGET HWY_PPC10
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
// ------------------------------ HWY_ARCH_S390X
#if (HWY_TARGETS & HWY_Z14) && (HWY_STATIC_TARGET != HWY_Z14)
#undef HWY_TARGET
#define HWY_TARGET HWY_Z14
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_Z15) && (HWY_STATIC_TARGET != HWY_Z15)
#undef HWY_TARGET
#define HWY_TARGET HWY_Z15
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
// ------------------------------ HWY_ARCH_RISCV
#if (HWY_TARGETS & HWY_RVV) && (HWY_STATIC_TARGET != HWY_RVV)
#undef HWY_TARGET
#define HWY_TARGET HWY_RVV
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
// ------------------------------ HWY_ARCH_LOONGARCH
#if (HWY_TARGETS & HWY_LSX) && (HWY_STATIC_TARGET != HWY_LSX)
#undef HWY_TARGET
#define HWY_TARGET HWY_LSX
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_LASX) && (HWY_STATIC_TARGET != HWY_LASX)
#undef HWY_TARGET
#define HWY_TARGET HWY_LASX
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
// ------------------------------ Scalar
#if (HWY_TARGETS & HWY_EMU128) && (HWY_STATIC_TARGET != HWY_EMU128)
#undef HWY_TARGET
#define HWY_TARGET HWY_EMU128
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#if (HWY_TARGETS & HWY_SCALAR) && (HWY_STATIC_TARGET != HWY_SCALAR)
#undef HWY_TARGET
#define HWY_TARGET HWY_SCALAR
#include HWY_TARGET_INCLUDE
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
#endif
#endif // !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET)
// Now that all but the static target have been generated, re-enable HWY_EXPORT.
#undef HWY_ONCE
#define HWY_ONCE 1
// If we re-include once per enabled target, the translation unit's
// implementation would have to be skipped via #if to avoid redefining symbols.
// We instead skip the re-include for HWY_STATIC_TARGET, and generate its
// implementation when resuming compilation of the translation unit.
#undef HWY_TARGET
#define HWY_TARGET HWY_STATIC_TARGET
#ifdef HWY_ALREADY_INCLUDED
// Revert the previous toggle to prevent redefinitions for the static target.
#ifdef HWY_TARGET_TOGGLE
#undef HWY_TARGET_TOGGLE
#else
#define HWY_TARGET_TOGGLE
#endif
// Force re-inclusion of set_macros-inl.h now that HWY_TARGET is restored.
#ifdef HWY_SET_MACROS_PER_TARGET
#undef HWY_SET_MACROS_PER_TARGET
#else
#define HWY_SET_MACROS_PER_TARGET
#endif
#endif
#endif // HIGHWAY_HWY_FOREACH_TARGET_H_
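To tie the mechanism above together, here is a minimal sketch of a translation unit that uses foreach_target.h for dynamic dispatch. The file name my_module.cc, the namespace my_module, and MyFunc/CallMyFunc are placeholders; the loop assumes n is a multiple of the vector length.

// my_module.cc (hypothetical): recompiled once per enabled target.
#include <cstddef>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "my_module.cc"  // path of this very file
#include "third_party/highway/hwy/foreach_target.h"  // must precede highway.h
#include "third_party/highway/hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace my_module {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

// Per-target implementation; HWY_NAMESPACE differs for each HWY_TARGET.
void MyFunc(float* HWY_RESTRICT out, size_t n) {
  const hn::ScalableTag<float> d;
  for (size_t i = 0; i < n; i += hn::Lanes(d)) {
    hn::Store(hn::Zero(d), d, out + i);
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace my_module
HWY_AFTER_NAMESPACE();

// HWY_ONCE is 1 only while compiling the static target, i.e. exactly once.
#if HWY_ONCE
namespace my_module {
HWY_EXPORT(MyFunc);  // table of per-target function pointers

void CallMyFunc(float* out, size_t n) {
  return HWY_DYNAMIC_DISPATCH(MyFunc)(out, n);
}
}  // namespace my_module
#endif  // HWY_ONCE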

View File

@@ -0,0 +1,642 @@
// Copyright 2020 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Main header required before using vector types.
// IWYU pragma: begin_exports
#include "third_party/highway/hwy/base.h"
#include "third_party/highway/hwy/detect_compiler_arch.h"
#include "third_party/highway/hwy/detect_targets.h"
#include "third_party/highway/hwy/highway_export.h"
#include "third_party/highway/hwy/targets.h"
// IWYU pragma: end_exports
#if HWY_CXX_LANG < 201703L
#define HWY_DISPATCH_MAP 1
#else
#define HWY_DISPATCH_MAP 0
#endif
// This include guard is checked by foreach_target, so avoid the usual _H_
// suffix to prevent copybara from renaming it. NOTE: ops/*-inl.h are included
// after/outside this include guard.
#ifndef HWY_HIGHWAY_INCLUDED
#define HWY_HIGHWAY_INCLUDED
namespace hwy {
//------------------------------------------------------------------------------
// Shorthand for tags (defined in shared-inl.h) used to select overloads.
// Note that ScalableTag<T> is preferred over HWY_FULL, and CappedTag<T, N> over
// HWY_CAPPED(T, N).
// HWY_FULL(T[,LMUL=1]) is a native vector/group. LMUL is the number of
// registers in the group, and is ignored on targets that do not support groups.
#define HWY_FULL1(T) hwy::HWY_NAMESPACE::ScalableTag<T>
#define HWY_FULL2(T, LMUL) \
hwy::HWY_NAMESPACE::ScalableTag<T, hwy::CeilLog2(HWY_MAX(0, LMUL))>
#define HWY_3TH_ARG(arg1, arg2, arg3, ...) arg3
// Workaround for MSVC grouping __VA_ARGS__ into a single argument
#define HWY_FULL_RECOMPOSER(args_with_paren) HWY_3TH_ARG args_with_paren
// Trailing comma avoids -pedantic false alarm
#define HWY_CHOOSE_FULL(...) \
HWY_FULL_RECOMPOSER((__VA_ARGS__, HWY_FULL2, HWY_FULL1, ))
#define HWY_FULL(...) HWY_CHOOSE_FULL(__VA_ARGS__())(__VA_ARGS__)
// Vector of up to MAX_N lanes. It's better to use full vectors where possible.
#define HWY_CAPPED(T, MAX_N) hwy::HWY_NAMESPACE::CappedTag<T, MAX_N>
//------------------------------------------------------------------------------
// Export user functions for static/dynamic dispatch
// Evaluates to 0 inside a translation unit if it is generating anything but the
// static target (the last one if multiple targets are enabled). Used to prevent
// redefinitions of HWY_EXPORT. Unless foreach_target.h is included, we only
// compile once anyway, so this is 1 unless it is or has been included.
#ifndef HWY_ONCE
#define HWY_ONCE 1
#endif
// HWY_STATIC_DISPATCH(FUNC_NAME) is the namespace-qualified FUNC_NAME for
// HWY_STATIC_TARGET (the only defined namespace unless HWY_TARGET_INCLUDE is
// defined), and can be used to deduce the return type of Choose*.
#if HWY_STATIC_TARGET == HWY_SCALAR
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SCALAR::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_EMU128
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_EMU128::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_WASM
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_WASM_EMU256
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM_EMU256::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_Z14
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_Z14::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_Z15
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_Z15::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_PPC8
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC8::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_PPC9
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC9::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_PPC10
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC10::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_LSX
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_LSX::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_LASX
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_LASX::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_RVV
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_RVV::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_NEON_WITHOUT_AES
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON_WITHOUT_AES::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_NEON
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_NEON_BF16
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON_BF16::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_SVE
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_SVE2
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE2::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_SVE_256
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE_256::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_SVE2_128
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE2_128::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_SSE2
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSE2::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_SSSE3
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSSE3::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_SSE4
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSE4::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_AVX2
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX2::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_AVX3
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_AVX3_DL
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_DL::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_AVX3_ZEN4
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_ZEN4::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_AVX10_2
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX10_2::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_AVX3_SPR
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_SPR::FUNC_NAME
#elif HWY_STATIC_TARGET == HWY_AVX10_2_512
#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX10_2_512::FUNC_NAME
#endif
// HWY_CHOOSE_*(FUNC_NAME) expands to the function pointer for that target or
// nullptr if that target was not compiled.
#if HWY_TARGETS & HWY_EMU128
#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_EMU128::FUNC_NAME
#elif HWY_TARGETS & HWY_SCALAR
#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_SCALAR::FUNC_NAME
#else
// When HWY_SCALAR/HWY_EMU128 are not present and other targets were disabled at
// runtime, fall back to the baseline with HWY_STATIC_DISPATCH().
#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME)
#endif
#if HWY_TARGETS & HWY_WASM
#define HWY_CHOOSE_WASM(FUNC_NAME) &N_WASM::FUNC_NAME
#else
#define HWY_CHOOSE_WASM(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_WASM_EMU256
#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) &N_WASM_EMU256::FUNC_NAME
#else
#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_Z14
#define HWY_CHOOSE_Z14(FUNC_NAME) &N_Z14::FUNC_NAME
#else
#define HWY_CHOOSE_Z14(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_Z15
#define HWY_CHOOSE_Z15(FUNC_NAME) &N_Z15::FUNC_NAME
#else
#define HWY_CHOOSE_Z15(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_PPC8
#define HWY_CHOOSE_PPC8(FUNC_NAME) &N_PPC8::FUNC_NAME
#else
#define HWY_CHOOSE_PPC8(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_PPC9
#define HWY_CHOOSE_PPC9(FUNC_NAME) &N_PPC9::FUNC_NAME
#else
#define HWY_CHOOSE_PPC9(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_LSX
#define HWY_CHOOSE_LSX(FUNC_NAME) &N_LSX::FUNC_NAME
#else
#define HWY_CHOOSE_LSX(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_LASX
#define HWY_CHOOSE_LASX(FUNC_NAME) &N_LASX::FUNC_NAME
#else
#define HWY_CHOOSE_LASX(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_PPC10
#define HWY_CHOOSE_PPC10(FUNC_NAME) &N_PPC10::FUNC_NAME
#else
#define HWY_CHOOSE_PPC10(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_RVV
#define HWY_CHOOSE_RVV(FUNC_NAME) &N_RVV::FUNC_NAME
#else
#define HWY_CHOOSE_RVV(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_NEON_WITHOUT_AES
#define HWY_CHOOSE_NEON_WITHOUT_AES(FUNC_NAME) &N_NEON_WITHOUT_AES::FUNC_NAME
#else
#define HWY_CHOOSE_NEON_WITHOUT_AES(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_NEON
#define HWY_CHOOSE_NEON(FUNC_NAME) &N_NEON::FUNC_NAME
#else
#define HWY_CHOOSE_NEON(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_NEON_BF16
#define HWY_CHOOSE_NEON_BF16(FUNC_NAME) &N_NEON_BF16::FUNC_NAME
#else
#define HWY_CHOOSE_NEON_BF16(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_SVE
#define HWY_CHOOSE_SVE(FUNC_NAME) &N_SVE::FUNC_NAME
#else
#define HWY_CHOOSE_SVE(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_SVE2
#define HWY_CHOOSE_SVE2(FUNC_NAME) &N_SVE2::FUNC_NAME
#else
#define HWY_CHOOSE_SVE2(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_SVE_256
#define HWY_CHOOSE_SVE_256(FUNC_NAME) &N_SVE_256::FUNC_NAME
#else
#define HWY_CHOOSE_SVE_256(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_SVE2_128
#define HWY_CHOOSE_SVE2_128(FUNC_NAME) &N_SVE2_128::FUNC_NAME
#else
#define HWY_CHOOSE_SVE2_128(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_SSE2
#define HWY_CHOOSE_SSE2(FUNC_NAME) &N_SSE2::FUNC_NAME
#else
#define HWY_CHOOSE_SSE2(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_SSSE3
#define HWY_CHOOSE_SSSE3(FUNC_NAME) &N_SSSE3::FUNC_NAME
#else
#define HWY_CHOOSE_SSSE3(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_SSE4
#define HWY_CHOOSE_SSE4(FUNC_NAME) &N_SSE4::FUNC_NAME
#else
#define HWY_CHOOSE_SSE4(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_AVX2
#define HWY_CHOOSE_AVX2(FUNC_NAME) &N_AVX2::FUNC_NAME
#else
#define HWY_CHOOSE_AVX2(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_AVX3
#define HWY_CHOOSE_AVX3(FUNC_NAME) &N_AVX3::FUNC_NAME
#else
#define HWY_CHOOSE_AVX3(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_AVX3_DL
#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) &N_AVX3_DL::FUNC_NAME
#else
#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_AVX3_ZEN4
#define HWY_CHOOSE_AVX3_ZEN4(FUNC_NAME) &N_AVX3_ZEN4::FUNC_NAME
#else
#define HWY_CHOOSE_AVX3_ZEN4(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_AVX10_2
#define HWY_CHOOSE_AVX10_2(FUNC_NAME) &N_AVX10_2::FUNC_NAME
#else
#define HWY_CHOOSE_AVX10_2(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_AVX3_SPR
#define HWY_CHOOSE_AVX3_SPR(FUNC_NAME) &N_AVX3_SPR::FUNC_NAME
#else
#define HWY_CHOOSE_AVX3_SPR(FUNC_NAME) nullptr
#endif
#if HWY_TARGETS & HWY_AVX10_2_512
#define HWY_CHOOSE_AVX10_2_512(FUNC_NAME) &N_AVX10_2_512::FUNC_NAME
#else
#define HWY_CHOOSE_AVX10_2_512(FUNC_NAME) nullptr
#endif
// MSVC 2017 workaround: the non-type template parameter to ChooseAndCall
// apparently cannot be an array. Use a function pointer instead, which has the
// disadvantage that we call the static (not best) target on the first call to
// any HWY_DYNAMIC_DISPATCH.
#if (HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1915) || \
(HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700)
#define HWY_DISPATCH_WORKAROUND 1
#else
#define HWY_DISPATCH_WORKAROUND 0
#endif
#if HWY_DISPATCH_MAP
struct AllExports {
template <class FuncPtr, class ExportsKey, uint64_t kHash>
static const FuncPtr*& GetRefToExportsPtr() {
static const FuncPtr* s_exports = nullptr;
return s_exports;
}
};
#endif
// Provides a static member function which is what is called during the first
// HWY_DYNAMIC_DISPATCH, where GetIndex is still zero, and instantiations of
// this function are the first entry in the tables created by HWY_EXPORT[_T].
template <typename RetType, typename... Args>
struct FunctionCache {
public:
typedef RetType(FuncType)(Args...);
using FuncPtr = FuncType*;
// A template function that when instantiated has the same signature as the
// function being called. This function initializes the bit array of targets
// supported by the current CPU and then calls the appropriate entry within
// the HWY_EXPORT table. Subsequent calls via HWY_DYNAMIC_DISPATCH to any
// exported functions, even those defined by different translation units,
// will dispatch directly to the best available target.
#if HWY_DISPATCH_MAP
template <class ExportsKey, uint64_t kHash>
static RetType ChooseAndCall(Args... args) {
ChosenTarget& chosen_target = GetChosenTarget();
chosen_target.Update(SupportedTargets());
const FuncPtr* table = AllExports::template GetRefToExportsPtr<
FuncPtr, RemoveCvRef<ExportsKey>, kHash>();
HWY_ASSERT(table);
return (table[chosen_target.GetIndex()])(args...);
}
#if !HWY_DISPATCH_WORKAROUND
template <const FuncPtr* table>
static RetType TableChooseAndCall(Args... args) {
ChosenTarget& chosen_target = GetChosenTarget();
chosen_target.Update(SupportedTargets());
return (table[chosen_target.GetIndex()])(args...);
}
#endif // !HWY_DISPATCH_WORKAROUND
#else // !HWY_DISPATCH_MAP: zero-overhead, but requires C++17
template <const FuncPtr* table>
static RetType ChooseAndCall(Args... args) {
ChosenTarget& chosen_target = GetChosenTarget();
chosen_target.Update(SupportedTargets());
return (table[chosen_target.GetIndex()])(args...);
}
#endif // HWY_DISPATCH_MAP
};
// Used to deduce the template parameters RetType and Args from a function.
template <typename RetType, typename... Args>
FunctionCache<RetType, Args...> DeduceFunctionCache(RetType (*)(Args...)) {
return FunctionCache<RetType, Args...>();
}
#define HWY_DISPATCH_TABLE(FUNC_NAME) \
HWY_CONCAT(FUNC_NAME, HighwayDispatchTable)
// HWY_EXPORT(FUNC_NAME); expands to a static array that is used by
// HWY_DYNAMIC_DISPATCH() to call the appropriate function at runtime.
// After being exported, it can be called from other parts of the same source
// file using HWY_DYNAMIC_DISPATCH(), in particular from a function wrapper
// like in the following example:
//
// #include "third_party/highway/hwy/highway.h"
// HWY_BEFORE_NAMESPACE();
// namespace skeleton {
// namespace HWY_NAMESPACE {
//
// void MyFunction(int a, char b, const char* c) { ... }
//
// // NOLINTNEXTLINE(google-readability-namespace-comments)
// } // namespace HWY_NAMESPACE
// } // namespace skeleton
// HWY_AFTER_NAMESPACE();
//
// namespace skeleton {
// HWY_EXPORT(MyFunction); // Defines the dispatch table in this scope.
//
// void MyFunction(int a, char b, const char* c) {
// return HWY_DYNAMIC_DISPATCH(MyFunction)(a, b, c);
// }
// } // namespace skeleton
//
// For templated code with a single type parameter, instead use HWY_EXPORT_T and
// its HWY_DYNAMIC_DISPATCH_T counterpart:
//
// template <typename T>
// void MyFunctionCaller(T ...) {
// // First argument to both HWY_EXPORT_T and HWY_DYNAMIC_DISPATCH_T is an
// // arbitrary table name; you must provide the same name for each call.
// // It is fine to have multiple HWY_EXPORT_T in a function, but a 64-bit
// // FNV hash collision among *any* table names will trigger HWY_ABORT.
// HWY_EXPORT_T(Table1, MyFunction<T>)
// HWY_DYNAMIC_DISPATCH_T(Table1)(a, b, c);
// }
//
// Note that HWY_EXPORT_T must be invoked inside a template (in the above
// example: `MyFunctionCaller`), so that a separate table will be created for
// each template instantiation. For convenience, we also provide a macro that
// combines both steps and avoids the need to pick a table name:
//
// template <typename T>
// void MyFunctionCaller(T ...) {
// // Table name is automatically chosen. Note that this variant must be
// // called in statement context; it is not a valid expression.
// HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(MyFunction<T>)(a, b, c);
// }
// Simplified version for IDE or the dynamic dispatch case with only one target.
#if HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0)
// We use a table to provide the same compile error conditions as with the
// non-simplified case, but the table only has a single entry.
#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \
HWY_MAYBE_UNUSED static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const \
HWY_DISPATCH_TABLE(TABLE_NAME)[1] = {&HWY_STATIC_DISPATCH(FUNC_NAME)}
// Use the table, not just STATIC_DISPATCH as in DYNAMIC_DISPATCH, because
// TABLE_NAME might not match the function name.
#define HWY_DYNAMIC_POINTER_T(TABLE_NAME) (HWY_DISPATCH_TABLE(TABLE_NAME)[0])
#define HWY_DYNAMIC_DISPATCH_T(TABLE_NAME) \
(*(HWY_DYNAMIC_POINTER_T(TABLE_NAME)))
#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME)
#define HWY_DYNAMIC_POINTER(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME)
#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) HWY_STATIC_DISPATCH(FUNC_NAME)
#else // not simplified: full table
// Pre-C++17 workaround: non-type template arguments must have linkage, which
// means we cannot pass &table as a template argument to ChooseAndCall.
// ChooseAndCall must find a way to access the table in order to dispatch to the
// chosen target:
// 0) Skipping this by dispatching to the static target would be surprising to
// users and may have serious performance implications.
// 1) An extra function parameter would be unacceptable because it changes the
// user-visible function signature.
// 2) Declaring a table, then defining a pointer to it would work, but requires
// an additional DECLARE step outside the function so that the pointer has
// linkage, which breaks existing code.
// 3) We instead associate the function with the table using an instance of an
// unnamed struct and the hash of the table name as the key. Because
// ChooseAndCall has the type information, it can then cast to the function
// pointer type. However, we cannot simply pass the name as a template
// argument to ChooseAndCall because this requires char*, which hits the same
// linkage problem. We instead hash the table name, which assumes the
// function names do not have collisions.
#if HWY_DISPATCH_MAP
static constexpr uint64_t FNV(const char* name) {
return *name ? static_cast<uint64_t>(static_cast<uint8_t>(*name)) ^
(0x100000001b3ULL * FNV(name + 1))
: 0xcbf29ce484222325ULL;
}
template <uint64_t kHash>
struct AddExport {
template <class ExportsKey, class FuncPtr>
AddExport(ExportsKey /*exports_key*/, const char* table_name,
const FuncPtr* table) {
using FuncCache = decltype(DeduceFunctionCache(hwy::DeclVal<FuncPtr>()));
static_assert(
hwy::IsSame<RemoveCvRef<FuncPtr>, typename FuncCache::FuncPtr>(),
"FuncPtr should be same type as FuncCache::FuncPtr");
const FuncPtr*& exports_ptr = AllExports::template GetRefToExportsPtr<
RemoveCvRef<FuncPtr>, RemoveCvRef<ExportsKey>, kHash>();
if (exports_ptr && exports_ptr != table) {
HWY_ABORT("Hash collision for %s, rename the function\n", table_name);
} else {
exports_ptr = table;
}
}
};
// Dynamic dispatch: defines table of function pointers. This must be invoked
// from inside the function template that calls the template we are exporting.
// TABLE_NAME must match the one passed to HWY_DYNAMIC_DISPATCH_T. This
// argument allows multiple exports within one function.
#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \
static const struct { \
} HWY_CONCAT(TABLE_NAME, HighwayDispatchExportsKey) = {}; \
static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \
TABLE_NAME)[static_cast<size_t>(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \
/* The first entry in the table initializes the global cache and \
* calls the appropriate function. */ \
&decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \
template ChooseAndCall<decltype(HWY_CONCAT( \
TABLE_NAME, HighwayDispatchExportsKey)), \
hwy::FNV(#TABLE_NAME)>, \
HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \
HWY_CHOOSE_FALLBACK(FUNC_NAME), \
}; \
HWY_MAYBE_UNUSED static hwy::AddExport<hwy::FNV(#TABLE_NAME)> HWY_CONCAT( \
HighwayAddTable, __LINE__)( \
HWY_CONCAT(TABLE_NAME, HighwayDispatchExportsKey), #TABLE_NAME, \
HWY_DISPATCH_TABLE(TABLE_NAME))
// For non-template functions. Not necessarily invoked within a function, hence
// we derive the string and variable names from FUNC_NAME, not HWY_FUNCTION.
#if HWY_DISPATCH_WORKAROUND
#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME)
#else
#define HWY_EXPORT(FUNC_NAME) \
static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \
FUNC_NAME)[static_cast<size_t>(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \
/* The first entry in the table initializes the global cache and \
* calls the appropriate function. */ \
&decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \
template TableChooseAndCall<HWY_DISPATCH_TABLE(FUNC_NAME)>, \
HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \
HWY_CHOOSE_FALLBACK(FUNC_NAME), \
}
#endif // HWY_DISPATCH_WORKAROUND
#else // !HWY_DISPATCH_MAP
// Zero-overhead, but requires C++17 for non-type template arguments without
// linkage, because HWY_EXPORT_T tables are local static variables.
#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \
static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \
TABLE_NAME)[static_cast<size_t>(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \
/* The first entry in the table initializes the global cache and \
* calls the appropriate function. */ \
&decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \
template ChooseAndCall<HWY_DISPATCH_TABLE(TABLE_NAME)>, \
HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \
HWY_CHOOSE_FALLBACK(FUNC_NAME), \
}
#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME)
#endif // HWY_DISPATCH_MAP
// HWY_DISPATCH_MAP only affects how tables are created, not their usage.
// Evaluates to the function pointer for the chosen target.
#define HWY_DYNAMIC_POINTER(FUNC_NAME) \
(HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::GetChosenTarget().GetIndex()])
// Calls the function pointer for the chosen target.
#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) (*(HWY_DYNAMIC_POINTER(FUNC_NAME)))
// Same as DISPATCH, but provide a different arg name to clarify usage.
#define HWY_DYNAMIC_DISPATCH_T(TABLE_NAME) HWY_DYNAMIC_DISPATCH(TABLE_NAME)
#define HWY_DYNAMIC_POINTER_T(TABLE_NAME) HWY_DYNAMIC_POINTER(TABLE_NAME)
#endif // HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0)
// Returns the name of an anonymous dispatch table that is only shared with
// macro invocations coming from the same source line.
#define HWY_DISPATCH_TABLE_T() HWY_CONCAT(HighwayDispatchTableT, __LINE__)
// For templated code, combines export and dispatch using an anonymous table.
#define HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC_NAME) \
HWY_EXPORT_T(HWY_DISPATCH_TABLE_T(), FUNC_NAME); \
HWY_DYNAMIC_DISPATCH_T(HWY_DISPATCH_TABLE_T())
// DEPRECATED names; please use HWY_HAVE_* instead.
#define HWY_CAP_INTEGER64 HWY_HAVE_INTEGER64
#define HWY_CAP_FLOAT16 HWY_HAVE_FLOAT16
#define HWY_CAP_FLOAT64 HWY_HAVE_FLOAT64
} // namespace hwy
#endif // HWY_HIGHWAY_INCLUDED
//------------------------------------------------------------------------------
// NOTE: the following definitions and ops/*.h depend on HWY_TARGET, so we want
// to include them once per target, which is ensured by the toggle check.
// Because ops/*.h are included under it, they do not need their own guard.
#if defined(HWY_HIGHWAY_PER_TARGET) == defined(HWY_TARGET_TOGGLE)
#ifdef HWY_HIGHWAY_PER_TARGET
#undef HWY_HIGHWAY_PER_TARGET
#else
#define HWY_HIGHWAY_PER_TARGET
#endif
// These define ops inside namespace hwy::HWY_NAMESPACE.
#if HWY_TARGET == HWY_SSE2 || HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4
#include "third_party/highway/hwy/ops/x86_128-inl.h"
#elif HWY_TARGET == HWY_AVX2
#include "third_party/highway/hwy/ops/x86_256-inl.h"
#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL || \
HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX10_2 || \
HWY_TARGET == HWY_AVX3_SPR || HWY_TARGET == HWY_AVX10_2_512
#include "third_party/highway/hwy/ops/x86_avx3-inl.h"
#elif HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 || \
(HWY_TARGET & HWY_ALL_PPC)
#include "third_party/highway/hwy/ops/ppc_vsx-inl.h"
#elif HWY_TARGET & HWY_ALL_NEON
#include "third_party/highway/hwy/ops/arm_neon-inl.h"
#elif HWY_TARGET & HWY_ALL_SVE
#include "third_party/highway/hwy/ops/arm_sve-inl.h"
#elif HWY_TARGET == HWY_WASM_EMU256
#include "third_party/highway/hwy/ops/wasm_256-inl.h"
#elif HWY_TARGET == HWY_WASM
#include "third_party/highway/hwy/ops/wasm_128-inl.h"
#elif HWY_TARGET == HWY_RVV
#include "third_party/highway/hwy/ops/rvv-inl.h"
#elif HWY_TARGET == HWY_EMU128
#include "third_party/highway/hwy/ops/emu128-inl.h"
#elif HWY_TARGET == HWY_SCALAR
#include "third_party/highway/hwy/ops/scalar-inl.h"
#elif HWY_TARGET == HWY_LSX || HWY_TARGET == HWY_LASX
#include "third_party/highway/hwy/ops/loongarch_lsx-inl.h"
#else
#pragma message("HWY_TARGET does not match any known target")
#endif // HWY_TARGET
#include "third_party/highway/hwy/ops/generic_ops-inl.h"
#endif // HWY_HIGHWAY_PER_TARGET

View File

@@ -0,0 +1,74 @@
// Pseudo-generated file to handle both the CMake and Bazel build systems.
// Initial generation was done using the CMake code:
// include(GenerateExportHeader)
// generate_export_header(hwy EXPORT_MACRO_NAME HWY_DLLEXPORT EXPORT_FILE_NAME
// hwy/highway_export.h)
// code reformatted using clang-format --style=Google
#ifndef HWY_DLLEXPORT_H
#define HWY_DLLEXPORT_H
#if !defined(HWY_SHARED_DEFINE)
#define HWY_DLLEXPORT
#define HWY_CONTRIB_DLLEXPORT
#define HWY_TEST_DLLEXPORT
#else // !HWY_SHARED_DEFINE
#ifndef HWY_DLLEXPORT
#if defined(hwy_EXPORTS)
/* We are building this library */
#ifdef _WIN32
#define HWY_DLLEXPORT __declspec(dllexport)
#else
#define HWY_DLLEXPORT __attribute__((visibility("default")))
#endif
#else // defined(hwy_EXPORTS)
/* We are using this library */
#ifdef _WIN32
#define HWY_DLLEXPORT __declspec(dllimport)
#else
#define HWY_DLLEXPORT __attribute__((visibility("default")))
#endif
#endif // defined(hwy_EXPORTS)
#endif // HWY_DLLEXPORT
#ifndef HWY_CONTRIB_DLLEXPORT
#if defined(hwy_contrib_EXPORTS)
/* We are building this library */
#ifdef _WIN32
#define HWY_CONTRIB_DLLEXPORT __declspec(dllexport)
#else
#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default")))
#endif
#else // defined(hwy_contrib_EXPORTS)
/* We are using this library */
#ifdef _WIN32
#define HWY_CONTRIB_DLLEXPORT __declspec(dllimport)
#else
#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default")))
#endif
#endif // defined(hwy_contrib_EXPORTS)
#endif // HWY_CONTRIB_DLLEXPORT
#ifndef HWY_TEST_DLLEXPORT
#if defined(hwy_test_EXPORTS)
/* We are building this library */
#ifdef _WIN32
#define HWY_TEST_DLLEXPORT __declspec(dllexport)
#else
#define HWY_TEST_DLLEXPORT __attribute__((visibility("default")))
#endif
#else // defined(hwy_test_EXPORTS)
/* We are using this library */
#ifdef _WIN32
#define HWY_TEST_DLLEXPORT __declspec(dllimport)
#else
#define HWY_TEST_DLLEXPORT __attribute__((visibility("default")))
#endif
#endif // defined(hwy_test_EXPORTS)
#endif // HWY_TEST_DLLEXPORT
#endif // !HWY_SHARED_DEFINE
#endif /* HWY_DLLEXPORT_H */

View File

@@ -0,0 +1,19 @@
HWY_0 {
global:
extern "C++" {
*hwy::*;
};
local:
# Hide all std namespace symbols. The std namespace is explicitly marked as
# visibility(default), so header-only functions and methods (such as those
# from templates) would otherwise be exposed in shared libraries as weak
# symbols. That is only needed when std types appear in the shared library
# API in any way; we do not use C++ std types in the API, and the library
# does not support exceptions.
# See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=36022 for a discussion
# about this.
extern "C++" {
*std::*;
};
};

View File

@@ -0,0 +1,153 @@
// Copyright 2019 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HIGHWAY_HWY_NANOBENCHMARK_H_
#define HIGHWAY_HWY_NANOBENCHMARK_H_
// Benchmarks functions of a single integer argument with realistic branch
// prediction hit rates. Uses a robust estimator to summarize the measurements.
// The precision is about 0.2%.
//
// Examples: see nanobenchmark_test.cc.
//
// Background: Microbenchmarks such as http://github.com/google/benchmark
// can measure elapsed times on the order of a microsecond. Shorter functions
// are typically measured by repeating them thousands of times and dividing
// the total elapsed time by this count. Unfortunately, repetition (especially
// with the same input parameter!) influences the runtime. In time-critical
// code, it is reasonable to expect warm instruction/data caches and TLBs,
// but a perfect record of which branches will be taken is unrealistic.
// Unless the application also repeatedly invokes the measured function with
// the same parameter, the benchmark is measuring something very different -
// a best-case result, almost as if the parameter were made a compile-time
// constant. This may lead to erroneous conclusions about branch-heavy
// algorithms outperforming branch-free alternatives.
//
// Our approach differs in three ways. Adding fences to the timer functions
// reduces variability due to instruction reordering, improving the timer
// resolution to about 40 CPU cycles. However, shorter functions must still
// be invoked repeatedly. For more realistic branch prediction performance,
// we vary the input parameter according to a user-specified distribution.
// Thus, instead of VaryInputs(Measure(Repeat(func))), we change the
// loop nesting to Measure(Repeat(VaryInputs(func))). We also estimate the
// central tendency of the measurement samples with the "half sample mode",
// which is more robust to outliers and skewed data than the mean or median.
#include <stddef.h>
#include <stdint.h>
#include "third_party/highway/hwy/highway_export.h"
#include "third_party/highway/hwy/timer.h" // IWYU pragma: export
namespace hwy {
// Returns 1, but without the compiler knowing what the value is. This prevents
// optimizing out code.
HWY_DLLEXPORT int Unpredictable1();
// Input influencing the function being measured (e.g. number of bytes to copy).
using FuncInput = size_t;
// "Proof of work" returned by Func to ensure the compiler does not elide it.
using FuncOutput = uint64_t;
// Function to measure: either 1) a captureless lambda or function with two
// arguments or 2) a lambda with capture, in which case the first argument
// is reserved for use by MeasureClosure.
using Func = FuncOutput (*)(const void*, FuncInput);
// Internal parameters that determine precision/resolution/measuring time.
struct Params {
// Best-case precision, expressed as a divisor of the timer resolution.
// Larger => more calls to Func and higher precision.
size_t precision_divisor = 1024;
// Ratio between full and subset input distribution sizes. Cannot be less
// than 2; larger values increase measurement time but more faithfully
// model the given input distribution.
size_t subset_ratio = 2;
// Together with the estimated Func duration, determines how many times to
// call Func before checking the sample variability. Larger values increase
// measurement time, memory/cache use and precision.
double seconds_per_eval = 4E-3;
// The minimum number of samples before estimating the central tendency.
size_t min_samples_per_eval = 7;
// The mode is better than median for estimating the central tendency of
// skewed/fat-tailed distributions, but it requires sufficient samples
// relative to the width of half-ranges.
size_t min_mode_samples = 64;
// Maximum permissible variability (= median absolute deviation / center).
double target_rel_mad = 0.002;
// Abort after this many evals without reaching target_rel_mad. This
// prevents infinite loops.
size_t max_evals = 9;
// Whether to print additional statistics to stdout.
bool verbose = true;
};
// Measurement result for each unique input.
struct Result {
FuncInput input;
// Robust estimate (mode or median) of duration.
float ticks;
// Measure of variability (median absolute deviation relative to "ticks").
float variability;
};
// Precisely measures the number of ticks elapsed when calling "func" with the
// given inputs, shuffled to ensure realistic branch prediction hit rates.
//
// "func" returns a 'proof of work' to ensure its computations are not elided.
// "arg" is passed to Func, or reserved for internal use by MeasureClosure.
// "inputs" is an array of "num_inputs" (not necessarily unique) arguments to
// "func". The values should be chosen to maximize coverage of "func". This
// represents a distribution, so a value's frequency should reflect its
// probability in the real application. Order does not matter; for example, a
// uniform distribution over [0, 4) could be represented as {3,0,2,1}.
// Returns how many Results were written to "results": one per unique input, or
// zero if the measurement failed (an error message goes to stderr).
HWY_DLLEXPORT size_t Measure(Func func, const uint8_t* arg,
const FuncInput* inputs, size_t num_inputs,
Result* results, const Params& p = Params());
// Calls operator() of the given closure (lambda function).
template <class Closure>
static FuncOutput CallClosure(const Closure* f, const FuncInput input) {
return (*f)(input);
}
// Same as Measure, except "closure" is typically a lambda function of
// FuncInput -> FuncOutput with a capture list.
template <class Closure>
static inline size_t MeasureClosure(const Closure& closure,
const FuncInput* inputs,
const size_t num_inputs, Result* results,
const Params& p = Params()) {
return Measure(reinterpret_cast<Func>(&CallClosure<Closure>),
reinterpret_cast<const uint8_t*>(&closure), inputs, num_inputs,
results, p);
}
} // namespace hwy
#endif // HIGHWAY_HWY_NANOBENCHMARK_H_
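A hedged sketch of using the closure variant; the lambda body and the input values are placeholders standing in for real application code and its input distribution.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include "third_party/highway/hwy/nanobenchmark.h"

int main() {
  // Inputs in any order; frequencies model the real distribution.
  const hwy::FuncInput kInputs[] = {3, 0, 2, 1, 3, 2};
  hwy::Result results[4];  // at most one Result per unique input
  const auto closure = [](hwy::FuncInput in) -> hwy::FuncOutput {
    uint64_t sum = 0;
    for (size_t i = 0; i < in; ++i) sum += i;  // stand-in for code to measure
    return sum;  // 'proof of work' so the loop is not optimized out
  };
  const size_t num = hwy::MeasureClosure(closure, kInputs, 6, results);
  for (size_t i = 0; i < num; ++i) {  // num == 0 means measurement failed
    std::printf("input %zu: %.2f ticks (variability %.4f)\n", results[i].input,
                results[i].ticks, results[i].variability);
  }
  return 0;
}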

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff