diff --git a/mfbt/Casting.h b/mfbt/Casting.h index 3d3fdc88b5ea..d264dc24f44b 100644 --- a/mfbt/Casting.h +++ b/mfbt/Casting.h @@ -249,6 +249,59 @@ inline To ReleaseAssertedCast(const From aFrom) { return static_cast(aFrom); } +/** + * Cast from type From to type To, clamping to minimum and maximum value of the + * destination type if needed. + */ +template +inline To SaturatingCast(const From aFrom) { + static_assert(std::is_arithmetic_v && std::is_arithmetic_v); + // This implementation works up to 64-bits integers. + static_assert(sizeof(From) <= 8 && sizeof(To) <= 8); + constexpr bool fromFloat = std::is_floating_point_v; + constexpr bool toFloat = std::is_floating_point_v; + + // It's not clear what the caller wants here, it could be round, truncate, + // closest value, etc. + static_assert((fromFloat && !toFloat) || (!fromFloat && !toFloat), + "Handle manually depending on desired behaviour"); + + // If the source is floating point and the destination isn't, it can be that + // casting changes the value unexpectedly. Casting to double and clamping to + // the max of the destination type is correct, this also handles infinity. + if constexpr (fromFloat) { + if (aFrom > static_cast(std::numeric_limits::max())) { + return std::numeric_limits::max(); + } + if (aFrom < static_cast(std::numeric_limits::lowest())) { + return std::numeric_limits::lowest(); + } + return static_cast(aFrom); + } + // Source and destination are of opposite signedness + if constexpr (std::is_signed_v != std::is_signed_v) { + // Input is negative, output is unsigned, return 0 + if (std::is_signed_v && aFrom < 0) { + return 0; + } + // At this point the input is positive, cast everything to uint64_t for + // simplicity and compare + uint64_t inflated = AssertedCast(aFrom); + if (inflated > static_cast(std::numeric_limits::max())) { + return std::numeric_limits::max(); + } + return static_cast(aFrom); + } + // Regular case: clamp to destination type range + if (aFrom > std::numeric_limits::max()) { + return std::numeric_limits::max(); + } + if (aFrom < std::numeric_limits::lowest()) { + return std::numeric_limits::lowest(); + } + return static_cast(aFrom); +} + namespace detail { template diff --git a/mfbt/tests/TestCasting.cpp b/mfbt/tests/TestCasting.cpp index 9b040956c73b..165915971fa5 100644 --- a/mfbt/tests/TestCasting.cpp +++ b/mfbt/tests/TestCasting.cpp @@ -11,9 +11,13 @@ #include #include #include +#include +#include +#include using mozilla::AssertedCast; using mozilla::BitwiseCast; +using mozilla::SaturatingCast; using mozilla::detail::IsInBounds; static const uint8_t floatMantissaBitsPlusOne = 24; @@ -243,6 +247,139 @@ void TestFloatConversion() { !(IsInBounds(-std::numeric_limits::max()))); } +#define ASSERT_EQ(a, b) \ + if ((a) != (b)) { \ + std::cerr << __FILE__ << ":" << __LINE__ << " Actual: " << +(a) << ", " \ + << "expected: " << +(b) << std::endl; \ + MOZ_CRASH(); \ + } + +#ifdef ENABLE_DEBUG_PRINT +# define DEBUG_PRINT(in, out) \ + std::cout << "\tIn: " << +in << ", " << "out: " << +out << std::endl; +#else +# define DEBUG_PRINT(in, out) +#endif + +template +void TestTypePairImpl() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + std::cout << std::fixed; + // Test casting infinities to integer works + if constexpr (std::is_floating_point_v && + !std::is_floating_point_v) { + Out v = SaturatingCast(std::numeric_limits::infinity()); + ASSERT_EQ(v, std::numeric_limits::max()); + v = SaturatingCast(-std::numeric_limits::infinity()); + ASSERT_EQ(v, std::numeric_limits::lowest()); + } + // Saturation of a floating point value that is infinity is infinity + if constexpr (std::is_floating_point_v && std::is_floating_point_v) { + In in = std::numeric_limits::infinity(); + Out v = SaturatingCast(in); + DEBUG_PRINT(in, v); + ASSERT_EQ(v, std::numeric_limits::infinity()); + in = -std::numeric_limits::infinity(); + v = SaturatingCast(in); + DEBUG_PRINT(in, v); + ASSERT_EQ(v, -std::numeric_limits::infinity()); + return; + } else { + if constexpr (sizeof(In) > sizeof(Out) && std::is_integral_v) { + // Test with a value just outside the range of the output type + In in = static_cast(std::numeric_limits::max()) + 1ull; + Out v = SaturatingCast(in); + DEBUG_PRINT(in, v); + ASSERT_EQ(v, std::numeric_limits::max()); + + if (std::is_signed_v) { + // Test with a value just below the range of the output type + Out lowest = std::numeric_limits::lowest(); + in = static_cast(lowest) - 1; + v = SaturatingCast(in); + DEBUG_PRINT(in, v); + if constexpr (std::is_signed_v && !std::is_signed_v) { + ASSERT_EQ(v, 0); + } else { + ASSERT_EQ(v, std::numeric_limits::lowest()); + } + } + } else if constexpr (std::is_integral_v && std::is_integral_v && + sizeof(In) == sizeof(Out) && !std::is_signed_v && + std::is_signed_v) { + // Test that max uintXX_t saturates to max intXX_t + In in = static_cast(std::numeric_limits::max()) + 1; + Out v = SaturatingCast(in); + DEBUG_PRINT(in, v); + ASSERT_EQ(v, std::numeric_limits::max()); + } + + // SaturatingCast of zero is zero + In in = static_cast(0); + Out v = SaturatingCast(in); + DEBUG_PRINT(in, v); + ASSERT_EQ(v, 0); + + if constexpr (sizeof(In) >= sizeof(Out) && std::is_signed_v && + std::is_signed_v) { + // Test with a value within the range of the output type + In in = static_cast(std::numeric_limits::max() / 2); + Out v = SaturatingCast(in); + DEBUG_PRINT(in, v); + ASSERT_EQ(v, in); + + // Test with a negative value within the range of the output type + in = static_cast(std::numeric_limits::lowest() / 2); + v = SaturatingCast(in); + DEBUG_PRINT(in, v); + ASSERT_EQ(v, in); + } + } +} + +template +void TestTypePair() { + constexpr bool fromFloat = std::is_floating_point_v; + constexpr bool toFloat = std::is_floating_point_v; + // Don't test casting to the same type + if constexpr (!std::is_same_v) { + if constexpr ((fromFloat && !toFloat) || (!fromFloat && !toFloat)) { + TestTypePairImpl(); + } + } +} + +template +void for_each_type_pair(std::tuple) { + (TestTypePair(), ...); + (TestTypePair(), ...); + if constexpr (sizeof...(Ts) > 1) { + for_each_type_pair(std::tuple{}); + } +} + +template +void TestSaturatingCastImpl() { + for_each_type_pair(std::tuple{}); +} + +template +void TestFirstToOthers() { + (TestTypePair(), ...); +} + +void TestSaturatingCast() { + // Each integer type against every other + TestSaturatingCastImpl(); + + // Floating point types to every integer type + TestFirstToOthers(); + TestFirstToOthers(); +} + int main() { TestBitwiseCast(); @@ -250,6 +387,7 @@ int main() { TestToBiggerSize(); TestToSmallerSize(); TestFloatConversion(); + TestSaturatingCast(); return 0; }