aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKimplul <kimi.h.kuparinen@gmail.com>2025-05-22 21:52:07 +0300
committerKimplul <kimi.h.kuparinen@gmail.com>2025-05-22 21:52:07 +0300
commitd2c56bc3103a8c459c75c779f5c95d9fd6223d89 (patch)
treeb75fe643c78c8f4214567d341cdf9f2502df718e
downloadsat_arith-master.tar.gz
sat_arith-master.zip
-rw-r--r--.gitignore1
-rw-r--r--Makefile8
-rw-r--r--README.md1
-rw-r--r--include/sat_arith/sat_arith.h142
-rw-r--r--tests/add_sat.c128
-rw-r--r--tests/mul_sat.c128
-rw-r--r--tests/sub_sat.c128
7 files changed, 536 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..378eac2
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+build
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..2dd2764
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,8 @@
+check:
+ mkdir -p build
+ $(CC) -O2 -Iinclude -g -Wall -Wextra tests/add_sat.c -o build/add_sat
+ $(CC) -O2 -Iinclude -g -Wall -Wextra tests/sub_sat.c -o build/sub_sat
+ $(CC) -O2 -Iinclude -g -Wall -Wextra tests/mul_sat.c -o build/mul_sat
+ ./build/add_sat
+ ./build/sub_sat
+ ./build/mul_sat
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..cda3a6e
--- /dev/null
+++ b/README.md
@@ -0,0 +1 @@
+# GNU libstdc++v3 `sat_arith.h` ported to C
diff --git a/include/sat_arith/sat_arith.h b/include/sat_arith/sat_arith.h
new file mode 100644
index 0000000..a05e57f
--- /dev/null
+++ b/include/sat_arith/sat_arith.h
@@ -0,0 +1,142 @@
+#ifndef SAT_ARITH_H
+#define SAT_ARITH_H
+
+#include <limits.h>
+
+#define ADD_SAT(T, n, l) \
+static inline T add_sat_##n(T x, T y) \
+{ \
+ T z; \
+ if (!__builtin_add_overflow(x, y, &z)) \
+ return z; \
+ \
+ SAT_ARITH_UNSIGNED( \
+ return l##_MAX, \
+ \
+ if (x < 0) \
+ return l##_MIN; \
+ else \
+ return l##_MAX; \
+ ); \
+}
+
+#define SUB_SAT(T, n, l) \
+static inline T sub_sat_##n(T x, T y) \
+{ \
+ T z; \
+ if (!__builtin_sub_overflow(x, y, &z)) \
+ return z; \
+ \
+ SAT_ARITH_UNSIGNED( \
+ return 0, \
+ \
+ if (x < 0) \
+ return l##_MIN; \
+ else \
+ return l##_MAX; \
+ ); \
+}
+
+#define MUL_SAT(T, n, l) \
+static inline T mul_sat_##n(T x, T y) \
+{ \
+ T z; \
+ if (!__builtin_mul_overflow(x, y, &z)) \
+ return z; \
+ \
+ SAT_ARITH_UNSIGNED( \
+ return l##_MAX, \
+ \
+ if ((x < 0) != (y < 0)) \
+ return l##_MIN; \
+ else \
+ return l##_MAX; \
+ ); \
+}
+
+#define SAT_ARITH_UNSIGNED(x, y) y
+
+ADD_SAT(signed char , char , SCHAR);
+ADD_SAT(signed short , short, SHRT);
+ADD_SAT(signed int , int , INT);
+ADD_SAT(signed long , long , LONG);
+ADD_SAT(signed long long, llong, LLONG);
+
+SUB_SAT(signed char , char , SCHAR);
+SUB_SAT(signed short , short, SHRT);
+SUB_SAT(signed int , int , INT);
+SUB_SAT(signed long , long , LONG);
+SUB_SAT(signed long long, llong, LLONG);
+
+MUL_SAT(signed char , char , SCHAR);
+MUL_SAT(signed short , short, SHRT);
+MUL_SAT(signed int , int , INT);
+MUL_SAT(signed long , long , LONG);
+MUL_SAT(signed long long, llong, LLONG);
+
+#undef SAT_ARITH_UNSIGNED
+#define SAT_ARITH_UNSIGNED(x, y) x
+
+ADD_SAT(unsigned char , uchar , UCHAR);
+ADD_SAT(unsigned short , ushort, USHRT);
+ADD_SAT(unsigned int , uint , UINT);
+ADD_SAT(unsigned long , ulong , ULONG);
+ADD_SAT(unsigned long long, ullong, ULLONG);
+
+SUB_SAT(unsigned char , uchar , UCHAR);
+SUB_SAT(unsigned short , ushort, USHRT);
+SUB_SAT(unsigned int , uint , UINT);
+SUB_SAT(unsigned long , ulong , ULONG);
+SUB_SAT(unsigned long long, ullong, ULLONG);
+
+MUL_SAT(unsigned char , uchar , UCHAR);
+MUL_SAT(unsigned short , ushort, USHRT);
+MUL_SAT(unsigned int , uint , UINT);
+MUL_SAT(unsigned long , ulong , ULONG);
+MUL_SAT(unsigned long long, ullong, ULLONG);
+
+#undef SAT_ARITH_UNSIGNED
+
+#define add_sat(x, y) \
+ _Generic((x), \
+ signed char : add_sat_char,\
+ signed short : add_sat_short,\
+ signed int : add_sat_short,\
+ signed long : add_sat_long,\
+ signed long long : add_sat_llong,\
+ unsigned char : add_sat_uchar,\
+ unsigned short : add_sat_ushort,\
+ unsigned int : add_sat_ushort,\
+ unsigned long : add_sat_ulong,\
+ unsigned long long: add_sat_ullong\
+ )(x, y)
+
+#define sub_sat(x, y) \
+ _Generic((x), \
+ signed char : sub_sat_char,\
+ signed short : sub_sat_short,\
+ signed int : sub_sat_short,\
+ signed long : sub_sat_long,\
+ signed long long : sub_sat_llong,\
+ unsigned char : sub_sat_uchar,\
+ unsigned short : sub_sat_ushort,\
+ unsigned int : sub_sat_ushort,\
+ unsigned long : sub_sat_ulong,\
+ unsigned long long: sub_sat_ullong\
+ )(x, y)
+
+#define mul_sat(x, y) \
+ _Generic((x), \
+ signed char : mul_sat_char,\
+ signed short : mul_sat_short,\
+ signed int : mul_sat_short,\
+ signed long : mul_sat_long,\
+ signed long long : mul_sat_llong,\
+ unsigned char : mul_sat_uchar,\
+ unsigned short : mul_sat_ushort,\
+ unsigned int : mul_sat_ushort,\
+ unsigned long : mul_sat_ulong,\
+ unsigned long long: mul_sat_ullong\
+ )(x, y)
+
+#endif /* SAT_ARITH_H */
diff --git a/tests/add_sat.c b/tests/add_sat.c
new file mode 100644
index 0000000..b751814
--- /dev/null
+++ b/tests/add_sat.c
@@ -0,0 +1,128 @@
+#include <stdio.h>
+#include <stdint.h>
+#include <sat_arith/sat_arith.h>
+
+static int64_t formal_sat_add(int64_t i, int64_t j, int64_t min, int64_t max)
+{
+ int64_t result = i + j;
+ if (result < min)
+ result = min;
+
+ if (result > max)
+ result = max;
+
+ return result;
+}
+
+static int test_i8()
+{
+ for (int64_t i = INT8_MIN; i <= INT8_MAX; ++i)
+ for (int64_t j = INT8_MIN; j <= INT8_MAX; ++j) {
+ int64_t expected = formal_sat_add(i, j, INT8_MIN, INT8_MAX);
+ int8_t got = add_sat((int8_t)i, (int8_t)j);
+ if (expected != got) {
+ printf("%lld + %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+static int test_i16()
+{
+ for (int64_t i = INT16_MIN; i <= INT16_MAX; ++i)
+ for (int64_t j = INT16_MIN; j <= INT16_MAX; ++j) {
+ int64_t expected = formal_sat_add(i, j, INT16_MIN, INT16_MAX);
+ int16_t got = add_sat((int16_t)i, (int16_t)j);
+ if (expected != got) {
+ printf("%lld + %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+static int test_u16()
+{
+ for (int64_t i = 0; i <= UINT16_MAX; ++i)
+ for (int64_t j = 0; j <= UINT16_MAX; ++j) {
+ int64_t expected = formal_sat_add(i, j, 0, UINT16_MAX);
+ uint16_t got = add_sat((uint16_t)i, (uint16_t)j);
+ if (expected != got) {
+ printf("%lld + %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+static int test_u8()
+{
+ for (int64_t i = 0; i <= UINT8_MAX; ++i)
+ for (int64_t j = 0; j <= UINT8_MAX; ++j) {
+ int64_t expected = formal_sat_add(i, j, 0, UINT8_MAX);
+ uint8_t got = add_sat((uint8_t)i, (uint8_t)j);
+ if (expected != got) {
+ printf("%lld + %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+int main()
+{
+ printf("add_sat\n");
+
+ /* these are still relatively easy to exhaustively check */
+ printf("i8...\n");
+ if (test_i8()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ printf("u8...\n");
+ if (test_u8()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ printf("i16...\n");
+ if (test_i16()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ printf("u16...\n");
+ if (test_u16()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ /* might add some random testing for 32/64 bit values, but presumably if
+ * the narrower types work, the wider ones should work as well (assuming
+ * I don't have some any silly typos or something) */
+
+ printf("OK!\n");
+ return 0;
+}
diff --git a/tests/mul_sat.c b/tests/mul_sat.c
new file mode 100644
index 0000000..42bccfb
--- /dev/null
+++ b/tests/mul_sat.c
@@ -0,0 +1,128 @@
+#include <stdio.h>
+#include <stdint.h>
+#include <sat_arith/sat_arith.h>
+
+static int64_t formal_sat_mul(int64_t i, int64_t j, int64_t min, int64_t max)
+{
+ int64_t result = i * j;
+ if (result < min)
+ result = min;
+
+ if (result > max)
+ result = max;
+
+ return result;
+}
+
+static int test_i8()
+{
+ for (int64_t i = INT8_MIN; i <= INT8_MAX; ++i)
+ for (int64_t j = INT8_MIN; j <= INT8_MAX; ++j) {
+ int64_t expected = formal_sat_mul(i, j, INT8_MIN, INT8_MAX);
+ int8_t got = mul_sat((int8_t)i, (int8_t)j);
+ if (expected != got) {
+ printf("%lld * %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+static int test_i16()
+{
+ for (int64_t i = INT16_MIN; i <= INT16_MAX; ++i)
+ for (int64_t j = INT16_MIN; j <= INT16_MAX; ++j) {
+ int64_t expected = formal_sat_mul(i, j, INT16_MIN, INT16_MAX);
+ int16_t got = mul_sat((int16_t)i, (int16_t)j);
+ if (expected != got) {
+ printf("%lld * %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+static int test_u16()
+{
+ for (int64_t i = 0; i <= UINT16_MAX; ++i)
+ for (int64_t j = 0; j <= UINT16_MAX; ++j) {
+ int64_t expected = formal_sat_mul(i, j, 0, UINT16_MAX);
+ uint16_t got = mul_sat((uint16_t)i, (uint16_t)j);
+ if (expected != got) {
+ printf("%lld * %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+static int test_u8()
+{
+ for (int64_t i = 0; i <= UINT8_MAX; ++i)
+ for (int64_t j = 0; j <= UINT8_MAX; ++j) {
+ int64_t expected = formal_sat_mul(i, j, 0, UINT8_MAX);
+ uint8_t got = mul_sat((uint8_t)i, (uint8_t)j);
+ if (expected != got) {
+ printf("%lld * %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+int main()
+{
+ printf("mul_sat\n");
+
+ /* these are still relatively easy to exhaustively check */
+ printf("i8...\n");
+ if (test_i8()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ printf("u8...\n");
+ if (test_u8()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ printf("i16...\n");
+ if (test_i16()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ printf("u16...\n");
+ if (test_u16()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ /* might add some random testing for 32/64 bit values, but presumably if
+ * the narrower types work, the wider ones should work as well (assuming
+ * I don't have some any silly typos or something) */
+
+ printf("OK!\n");
+ return 0;
+}
diff --git a/tests/sub_sat.c b/tests/sub_sat.c
new file mode 100644
index 0000000..226362c
--- /dev/null
+++ b/tests/sub_sat.c
@@ -0,0 +1,128 @@
+#include <stdio.h>
+#include <stdint.h>
+#include <sat_arith/sat_arith.h>
+
+static int64_t formal_sat_sub(int64_t i, int64_t j, int64_t min, int64_t max)
+{
+ int64_t result = i - j;
+ if (result < min)
+ result = min;
+
+ if (result > max)
+ result = max;
+
+ return result;
+}
+
+static int test_i8()
+{
+ for (int64_t i = INT8_MIN; i <= INT8_MAX; ++i)
+ for (int64_t j = INT8_MIN; j <= INT8_MAX; ++j) {
+ int64_t expected = formal_sat_sub(i, j, INT8_MIN, INT8_MAX);
+ int8_t got = sub_sat((int8_t)i, (int8_t)j);
+ if (expected != got) {
+ printf("%lld - %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+static int test_i16()
+{
+ for (int64_t i = INT16_MIN; i <= INT16_MAX; ++i)
+ for (int64_t j = INT16_MIN; j <= INT16_MAX; ++j) {
+ int64_t expected = formal_sat_sub(i, j, INT16_MIN, INT16_MAX);
+ int16_t got = sub_sat((int16_t)i, (int16_t)j);
+ if (expected != got) {
+ printf("%lld - %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+static int test_u16()
+{
+ for (int64_t i = 0; i <= UINT16_MAX; ++i)
+ for (int64_t j = 0; j <= UINT16_MAX; ++j) {
+ int64_t expected = formal_sat_sub(i, j, 0, UINT16_MAX);
+ uint16_t got = sub_sat((uint16_t)i, (uint16_t)j);
+ if (expected != got) {
+ printf("%lld - %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+static int test_u8()
+{
+ for (int64_t i = 0; i <= UINT8_MAX; ++i)
+ for (int64_t j = 0; j <= UINT8_MAX; ++j) {
+ int64_t expected = formal_sat_sub(i, j, 0, UINT8_MAX);
+ uint8_t got = sub_sat((uint8_t)i, (uint8_t)j);
+ if (expected != got) {
+ printf("%lld - %lld should be %lld, but got %lld!\n",
+ (long long)i,
+ (long long)j,
+ (long long)expected,
+ (long long)got);
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+int main()
+{
+ printf("sub_sat\n");
+
+ /* these are still relatively easy to exhaustively check */
+ printf("i8...\n");
+ if (test_i8()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ printf("u8...\n");
+ if (test_u8()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ printf("i16...\n");
+ if (test_i16()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ printf("u16...\n");
+ if (test_u16()) {
+ printf("FAIL!\n");
+ return -1;
+ }
+
+ /* might add some random testing for 32/64 bit values, but presumably if
+ * the narrower types work, the wider ones should work as well (assuming
+ * I don't have some any silly typos or something) */
+
+ printf("OK!\n");
+ return 0;
+}