From e7b22f97f960b62e555dfd6f2e3ae43973fcbb3e Mon Sep 17 00:00:00 2001
From: Pronin Alexander 00812787 <pronin.alexander@huawei.com>
Date: Wed, 25 Jan 2023 15:04:07 +0300
Subject: [PATCH 05/18] Match double sized mul pattern

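Add pattern matching for multiplications computing a double sized result
from half-width parts, and replace the matched sequence with a single
multiplication in the twice wider type when the target supports it.  The
candidate is matched in match.pd; the replacement is done in the
widening_mul pass (tree-ssa-math-opts.cc).

For illustration, one of the sequences this recognizes, taken from the
new double_sized_mul-1.c test (comments added here): an open-coded
32x32->64 bit multiplication, which is rewritten into casts of both
arguments to uint64_t, one wide multiplication, and a shift for the
high part.

  #include <stdint.h>

  uint64_t mul64 (uint32_t a, uint32_t b)
  {
    /* Split the arguments into 16-bit halves.  */
    uint32_t a_lo = a & 0xFFFF;
    uint32_t b_lo = b & 0xFFFF;
    uint32_t a_hi = a >> 16;
    uint32_t b_hi = b >> 16;
    /* The four partial products.  */
    uint32_t lolo = a_lo * b_lo;
    uint32_t lohi = a_lo * b_hi;
    uint32_t hilo = a_hi * b_lo;
    uint32_t hihi = a_hi * b_hi;
    /* Middle sum and its halves around the 16-bit boundary.  */
    uint32_t middle = hilo + lohi;
    uint32_t middle_hi = middle >> 16;
    uint32_t middle_lo = middle << 16;
    /* Combine, propagating both possible carries into the high part.  */
    uint32_t res_lo = lolo + middle_lo;
    uint32_t res_hi = hihi + middle_hi;
    res_hi += (res_lo < middle_lo ? 1 : 0);
    res_hi += (middle < hilo ? 0x10000 : 0);
    uint64_t res = ((uint64_t) res_hi) << 32;
    res += res_lo;
    return res;
  }

The tests use -fif-conversion-gimple and -fuaddsub-overflow-match-all,
which are required for proper overflow detection in some cases.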
---
 gcc/match.pd                              | 143 +++++++++++++++++++++
 gcc/testsuite/gcc.dg/double_sized_mul-1.c | 141 ++++++++++++++++++++++
 gcc/testsuite/gcc.dg/double_sized_mul-2.c |  62 ++++++++++
 gcc/tree-ssa-math-opts.cc                 |  80 ++++++++++++
 4 files changed, 426 insertions(+)
 create mode 100644 gcc/testsuite/gcc.dg/double_sized_mul-1.c
 create mode 100644 gcc/testsuite/gcc.dg/double_sized_mul-2.c

diff --git a/gcc/match.pd b/gcc/match.pd
index 3cbaf2a5b..61866cb90 100644
--- a/gcc/match.pd
+++ b/gcc/match.pd
@@ -7895,3 +7895,146 @@ and,
 	       == TYPE_UNSIGNED (TREE_TYPE (@3))))
        && single_use (@4)
        && single_use (@5))))
+
+/* Match multiplication with double sized result.
+
+   Consider the following calculations:
+   arg0 * arg1 = (2^(bit_size/2) * arg0_hi + arg0_lo)
+	       * (2^(bit_size/2) * arg1_hi + arg1_lo)
+   arg0 * arg1 = 2^bit_size * arg0_hi * arg1_hi
+	       + 2^(bit_size/2) * (arg0_hi * arg1_lo + arg0_lo * arg1_hi)
+	       + arg0_lo * arg1_lo
+
+   The products of the high and low parts fit in bit_size values, so they
+   are placed in the high and low parts of the result respectively.
+
+   The sum of the mixed products may overflow, so we need to detect that.
+   It is also offset by bit_size/2, intersecting both result parts; due to
+   this, the detected overflow carry is shifted left by bit_size/2.
+
+   With this info (possible_middle_overflow includes the bit_size/2 shift):
+   arg0 * arg1 = 2^bit_size * arg0_hi * arg1_hi
+	       + 2^(bit_size/2) * middle
+	       + 2^bit_size * possible_middle_overflow
+	       + arg0_lo * arg1_lo
+   arg0 * arg1 = 2^bit_size * (arg0_hi * arg1_hi + possible_middle_overflow)
+	       + 2^(bit_size/2) * (2^(bit_size/2) * middle_hi + middle_lo)
+	       + arg0_lo * arg1_lo
+   arg0 * arg1 = 2^bit_size * (arg0_hi * arg1_hi + middle_hi
+	       +	       possible_middle_overflow)
+	       + 2^(bit_size/2) * middle_lo
+	       + arg0_lo * arg1_lo
+
+   The sum of the two last terms may overflow into the high part.  With this:
+   arg0 * arg1 = 2^bit_size * (arg0_hi * arg1_hi + possible_middle_overflow
+	       +	       possible_res_lo_overflow + middle_hi)
+	       + res_lo
+	       = res_hi + res_lo
+
+   This formula is too big to fit into a single match pattern together with
+   all the permutations of its terms, so a number of helper patterns are
+   used for better code readability.
+
+   The match is rooted at res_hi: computing res_lo alone is not a realistic
+   case for such calculations.
+
+   Overflows are handled by matching IFN_ADD_OVERFLOW calls with complex
+   results: the realpart and imagpart give the sum and the carry.  */
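+/* As a concrete illustration with bit_size = 8 (cf. mul16 in the new
+   tests), take 0xAB * 0xCD = 0x88EF:
+   arg0_hi = 0xA, arg0_lo = 0xB, arg1_hi = 0xC, arg1_lo = 0xD;
+   lolo = 0x8F, lohi = 0x84, hilo = 0x82, hihi = 0x78;
+   middle = 0x82 + 0x84 = 0x06 plus a carry (possible_middle_overflow);
+   res_lo = 0x8F + 0x60 = 0xEF with no carry into the high part;
+   res_hi = 0x78 + 0x00 + 0x10 = 0x88.  */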
+/* Match low and high parts of the argument.  */
+(match (double_size_mul_arg_lo @0 @1)
+ (bit_and @0 INTEGER_CST@1)
+  (if (wi::to_wide (@1)
+       == wi::mask (TYPE_PRECISION (type) / 2, false, TYPE_PRECISION (type)))))
+(match (double_size_mul_arg_hi @0 @1)
+ (rshift @0 INTEGER_CST@1)
+  (if (wi::to_wide (@1) == TYPE_PRECISION (type) / 2)))
+
+/* Match various argument parts products.  */
+(match (double_size_mul_lolo @0 @1)
+ (mult@4 (double_size_mul_arg_lo @0 @2) (double_size_mul_arg_lo @1 @3))
+  (if (single_use (@4))))
+(match (double_size_mul_hihi @0 @1)
+ (mult@4 (double_size_mul_arg_hi @0 @2) (double_size_mul_arg_hi @1 @3))
+  (if (single_use (@4))))
+(match (double_size_mul_lohi @0 @1)
+ (mult:c@4 (double_size_mul_arg_lo @0 @2) (double_size_mul_arg_hi @1 @3))
+  (if (single_use (@4))))
+
+/* Match complex middle sum.  */
+(match (double_size_mul_middle_complex @0 @1)
+ (IFN_ADD_OVERFLOW@2 (double_size_mul_lohi @0 @1) (double_size_mul_lohi @1 @0))
+  (if (num_imm_uses (@2) == 2)))
+
+/* Match the middle sum value (realpart) and its shifted variants.  */
+(match (double_size_mul_middle @0 @1)
+ (realpart@2 (double_size_mul_middle_complex @0 @1))
+  (if (num_imm_uses (@2) == 2)))
+(match (double_size_mul_middleres_lo @0 @1)
+ (lshift@3 (double_size_mul_middle @0 @1) INTEGER_CST@2)
+  (if (wi::to_wide (@2) == TYPE_PRECISION (type) / 2
+       && single_use (@3))))
+(match (double_size_mul_middleres_hi @0 @1)
+ (rshift@3 (double_size_mul_middle @0 @1) INTEGER_CST@2)
+  (if (wi::to_wide (@2) == TYPE_PRECISION (type) / 2
+       && single_use (@3))))
+
+/* Match low result part.  */
+/* The number of uses may be less than 2 when we are interested in
+   the high part only.  */
+(match (double_size_mul_res_lo_complex @0 @1)
+ (IFN_ADD_OVERFLOW:c@2
+  (double_size_mul_lolo:c @0 @1) (double_size_mul_middleres_lo @0 @1))
+  (if (num_imm_uses (@2) <= 2)))
+(match (double_size_mul_res_lo @0 @1)
+ (realpart (double_size_mul_res_lo_complex @0 @1)))
+
+/* Match overflow terms.  */
+(match (double_size_mul_overflow_check_lo @0 @1 @5)
+ (convert@4 (ne@3
+  (imagpart@2 (double_size_mul_res_lo_complex@5 @0 @1)) integer_zerop))
+  (if (single_use (@2) && single_use (@3) && single_use (@4))))
+(match (double_size_mul_overflow_check_hi @0 @1)
+ (lshift@6 (convert@5 (ne@4
+  (imagpart@3 (double_size_mul_middle_complex @0 @1)) integer_zerop))
+	   INTEGER_CST@2)
+  (if (wi::to_wide (@2) == TYPE_PRECISION (type) / 2
+       && single_use (@3) && single_use (@4) && single_use (@5)
+       && single_use (@6))))
+
+/* Match all possible permutations for high result part calculations.  */
+(for op1 (double_size_mul_hihi
+	  double_size_mul_overflow_check_hi
+	  double_size_mul_middleres_hi)
+     op2 (double_size_mul_overflow_check_hi
+	  double_size_mul_middleres_hi
+	  double_size_mul_hihi)
+     op3 (double_size_mul_middleres_hi
+	  double_size_mul_hihi
+	  double_size_mul_overflow_check_hi)
+ (match (double_size_mul_candidate @0 @1 @2 @3)
+  (plus:c@2
+   (plus:c@4 (double_size_mul_overflow_check_lo @0 @1 @3) (op1:c @0 @1))
+   (plus:c@5 (op2:c @0 @1) (op3:c @0 @1)))
+    (if (single_use (@4) && single_use (@5))))
+ (match (double_size_mul_candidate @0 @1 @2 @3)
+  (plus:c@2 (double_size_mul_overflow_check_lo @0 @1 @3)
+   (plus:c@4 (op1:c @0 @1)
+    (plus:c@5 (op2:c @0 @1) (op3:c @0 @1))))
+     (if (single_use (@4) && single_use (@5))))
+ (match (double_size_mul_candidate @0 @1 @2 @3)
+  (plus:c@2 (op1:c @0 @1)
+   (plus:c@4 (double_size_mul_overflow_check_lo @0 @1 @3)
+    (plus:c@5 (op2:c @0 @1) (op3:c @0 @1))))
+     (if (single_use (@4) && single_use (@5))))
+ (match (double_size_mul_candidate @0 @1 @2 @3)
+  (plus:c@2 (op1:c @0 @1)
+   (plus:c@4 (op2:c @0 @1)
+    (plus:c@5 (double_size_mul_overflow_check_lo @0 @1 @3) (op3:c @0 @1))))
+     (if (single_use (@4) && single_use (@5)))))
diff --git a/gcc/testsuite/gcc.dg/double_sized_mul-1.c b/gcc/testsuite/gcc.dg/double_sized_mul-1.c
new file mode 100644
index 000000000..4d475cc8a
--- /dev/null
+++ b/gcc/testsuite/gcc.dg/double_sized_mul-1.c
@@ -0,0 +1,141 @@
+/* { dg-do compile } */
+/* -fif-conversion-gimple and -fuaddsub-overflow-match-all are required
+   for proper overflow detection in some cases.  */
+/* { dg-options "-O2 -fif-conversion-gimple -fuaddsub-overflow-match-all -fdump-tree-widening_mul-stats" } */
+#include <stdint.h>
+
+typedef unsigned __int128 uint128_t;
+
+uint16_t mul16 (uint8_t a, uint8_t b)
+{
+  uint8_t a_lo = a & 0xF;
+  uint8_t b_lo = b & 0xF;
+  uint8_t a_hi = a >> 4;
+  uint8_t b_hi = b >> 4;
+  uint8_t lolo = a_lo * b_lo;
+  uint8_t lohi = a_lo * b_hi;
+  uint8_t hilo = a_hi * b_lo;
+  uint8_t hihi = a_hi * b_hi;
+  uint8_t middle = hilo + lohi;
+  uint8_t middle_hi = middle >> 4;
+  uint8_t middle_lo = middle << 4;
+  uint8_t res_lo = lolo + middle_lo;
+  uint8_t res_hi = hihi + middle_hi;
+  res_hi += (res_lo < middle_lo ? 1 : 0);
+  res_hi += (middle < hilo ? 0x10 : 0);
+  uint16_t res = ((uint16_t) res_hi) << 8;
+  res += res_lo;
+  return res;
+}
+
+uint32_t mul32 (uint16_t a, uint16_t b)
+{
+  uint16_t a_lo = a & 0xFF;
+  uint16_t b_lo = b & 0xFF;
+  uint16_t a_hi = a >> 8;
+  uint16_t b_hi = b >> 8;
+  uint16_t lolo = a_lo * b_lo;
+  uint16_t lohi = a_lo * b_hi;
+  uint16_t hilo = a_hi * b_lo;
+  uint16_t hihi = a_hi * b_hi;
+  uint16_t middle = hilo + lohi;
+  uint16_t middle_hi = middle >> 8;
+  uint16_t middle_lo = middle << 8;
+  uint16_t res_lo = lolo + middle_lo;
+  uint16_t res_hi = hihi + middle_hi;
+  res_hi += (res_lo < middle_lo ? 1 : 0);
+  res_hi += (middle < hilo ? 0x100 : 0);
+  uint32_t res = ((uint32_t) res_hi) << 16;
+  res += res_lo;
+  return res;
+}
+
+uint64_t mul64 (uint32_t a, uint32_t b)
+{
+  uint32_t a_lo = a & 0xFFFF;
+  uint32_t b_lo = b & 0xFFFF;
+  uint32_t a_hi = a >> 16;
+  uint32_t b_hi = b >> 16;
+  uint32_t lolo = a_lo * b_lo;
+  uint32_t lohi = a_lo * b_hi;
+  uint32_t hilo = a_hi * b_lo;
+  uint32_t hihi = a_hi * b_hi;
+  uint32_t middle = hilo + lohi;
+  uint32_t middle_hi = middle >> 16;
+  uint32_t middle_lo = middle << 16;
+  uint32_t res_lo = lolo + middle_lo;
+  uint32_t res_hi = hihi + middle_hi;
+  res_hi += (res_lo < middle_lo ? 1 : 0);
+  res_hi += (middle < hilo ? 0x10000 : 0);
+  uint64_t res = ((uint64_t) res_hi) << 32;
+  res += res_lo;
+  return res;
+}
+
+uint128_t mul128 (uint64_t a, uint64_t b)
+{
+  uint64_t a_lo = a & 0xFFFFFFFF;
+  uint64_t b_lo = b & 0xFFFFFFFF;
+  uint64_t a_hi = a >> 32;
+  uint64_t b_hi = b >> 32;
+  uint64_t lolo = a_lo * b_lo;
+  uint64_t lohi = a_lo * b_hi;
+  uint64_t hilo = a_hi * b_lo;
+  uint64_t hihi = a_hi * b_hi;
+  uint64_t middle = hilo + lohi;
+  uint64_t middle_hi = middle >> 32;
+  uint64_t middle_lo = middle << 32;
+  uint64_t res_lo = lolo + middle_lo;
+  uint64_t res_hi = hihi + middle_hi;
+  res_hi += (res_lo < middle_lo ? 1 : 0);
+  res_hi += (middle < hilo ? 0x100000000 : 0);
+  uint128_t res = ((uint128_t) res_hi) << 64;
+  res += res_lo;
+  return res;
+}
+
+uint64_t mul64_perm (uint32_t a, uint32_t b)
+{
+  uint32_t a_lo = a & 0xFFFF;
+  uint32_t b_lo = b & 0xFFFF;
+  uint32_t a_hi = a >> 16;
+  uint32_t b_hi = b >> 16;
+  uint32_t lolo = a_lo * b_lo;
+  uint32_t lohi = a_lo * b_hi;
+  uint32_t hilo = a_hi * b_lo;
+  uint32_t hihi = a_hi * b_hi;
+  uint32_t middle = hilo + lohi;
+  uint32_t middle_hi = middle >> 16;
+  uint32_t middle_lo = middle << 16;
+  uint32_t res_lo = lolo + middle_lo;
+  uint32_t res_hi = hihi + middle_hi;
+  res_hi = res_lo < middle_lo ? res_hi + 1 : res_hi;
+  res_hi = middle < hilo ? res_hi + 0x10000 : res_hi;
+  uint64_t res = ((uint64_t) res_hi) << 32;
+  res += res_lo;
+  return res;
+}
+
+uint128_t mul128_perm (uint64_t a, uint64_t b)
+{
+  uint64_t a_lo = a & 0xFFFFFFFF;
+  uint64_t b_lo = b & 0xFFFFFFFF;
+  uint64_t a_hi = a >> 32;
+  uint64_t b_hi = b >> 32;
+  uint64_t lolo = a_lo * b_lo;
+  uint64_t lohi = a_lo * b_hi;
+  uint64_t hilo = a_hi * b_lo;
+  uint64_t hihi = a_hi * b_hi;
+  uint64_t middle = hilo + lohi;
+  uint64_t middle_hi = middle >> 32;
+  uint64_t middle_lo = middle << 32;
+  uint64_t res_lo = lolo + middle_lo;
+  uint64_t res_hi = hihi + middle_hi;
+  res_hi = res_lo < middle_lo ? res_hi + 1 : res_hi;
+  res_hi = middle < hilo ? res_hi + 0x100000000 : res_hi;
+  uint128_t res = ((uint128_t) res_hi) << 64;
+  res += res_lo;
+  return res;
+}
+
+/* { dg-final { scan-tree-dump-times "double sized mul optimized: 1" 6 "widening_mul" } } */
diff --git a/gcc/testsuite/gcc.dg/double_sized_mul-2.c b/gcc/testsuite/gcc.dg/double_sized_mul-2.c
new file mode 100644
index 000000000..cc6e5af25
--- /dev/null
+++ b/gcc/testsuite/gcc.dg/double_sized_mul-2.c
@@ -0,0 +1,62 @@
+/* { dg-do compile } */
+/* -fif-conversion-gimple is required for proper overflow detection
+   in some cases.  */
+/* { dg-options "-O2 -fif-conversion-gimple -fuaddsub-overflow-match-all -fdump-tree-widening_mul-stats" } */
+#include <stdint.h>
+
+typedef unsigned __int128 uint128_t;
+typedef struct uint256_t
+{
+    uint128_t lo;
+    uint128_t hi;
+} uint256_t;
+
+uint64_t mul64_double_use (uint32_t a, uint32_t b)
+{
+  uint32_t a_lo = a & 0xFFFF;
+  uint32_t b_lo = b & 0xFFFF;
+  uint32_t a_hi = a >> 16;
+  uint32_t b_hi = b >> 16;
+  uint32_t lolo = a_lo * b_lo;
+  uint32_t lohi = a_lo * b_hi;
+  uint32_t hilo = a_hi * b_lo;
+  uint32_t hihi = a_hi * b_hi;
+  uint32_t middle = hilo + lohi;
+  uint32_t middle_hi = middle >> 16;
+  uint32_t middle_lo = middle << 16;
+  uint32_t res_lo = lolo + middle_lo;
+  uint32_t res_hi = hihi + middle_hi;
+  res_hi += (res_lo < middle_lo ? 1 : 0);
+  res_hi += (middle < hilo ? 0x10000 : 0);
+  uint64_t res = ((uint64_t) res_hi) << 32;
+  res += res_lo;
+  return res + lolo;
+}
+
+uint256_t mul256 (uint128_t a, uint128_t b)
+{
+  uint128_t a_lo = a & 0xFFFFFFFFFFFFFFFF;
+  uint128_t b_lo = b & 0xFFFFFFFFFFFFFFFF;
+  uint128_t a_hi = a >> 64;
+  uint128_t b_hi = b >> 64;
+  uint128_t lolo = a_lo * b_lo;
+  uint128_t lohi = a_lo * b_hi;
+  uint128_t hilo = a_hi * b_lo;
+  uint128_t hihi = a_hi * b_hi;
+  uint128_t middle = hilo + lohi;
+  uint128_t middle_hi = middle >> 64;
+  uint128_t middle_lo = middle << 64;
+  uint128_t res_lo = lolo + middle_lo;
+  uint128_t res_hi = hihi + middle_hi;
+  res_hi += (res_lo < middle_lo ? 1 : 0);
+  /* Work around the "constant is too large" warning via a shift.  */
+  uint128_t overflow_tmp = (middle < hilo ? 1 : 0);
+  overflow_tmp <<= 64;
+  res_hi += overflow_tmp;
+  uint256_t res;
+  res.lo = res_lo;
+  res.hi = res_hi;
+  return res;
+}
+
+/* { dg-final { scan-tree-dump-not "double sized mul optimized" "widening_mul" } } */
diff --git a/gcc/tree-ssa-math-opts.cc b/gcc/tree-ssa-math-opts.cc
index 55d6ee8ae..2c06b8a60 100644
--- a/gcc/tree-ssa-math-opts.cc
+++ b/gcc/tree-ssa-math-opts.cc
@@ -210,6 +210,9 @@ static struct
 
   /* Number of highpart multiplication ops inserted.  */
   int highpart_mults_inserted;
+
+  /* Number of optimized double sized multiplications.  */
+  int double_sized_mul_optimized;
 } widen_mul_stats;
 
 /* The instance of "struct occurrence" representing the highest
@@ -4893,6 +4896,78 @@ optimize_spaceship (gimple *stmt)
 }
 
 
+/* Pattern matcher for double sized multiplication defined in match.pd.  */
+extern bool gimple_double_size_mul_candidate (tree, tree*, tree (*)(tree));
+
+static bool
+convert_double_size_mul (gimple_stmt_iterator *gsi, gimple *stmt)
+{
+  gimple *use_stmt, *complex_res_lo;
+  gimple_stmt_iterator insert_before;
+  imm_use_iterator use_iter;
+  tree match[4]; /* arg0, arg1, res_hi, complex_res_lo.  */
+  tree arg0, arg1, widen_mult, new_type, tmp;
+  tree lhs = gimple_assign_lhs (stmt);
+  location_t loc = UNKNOWN_LOCATION;
+  machine_mode mode;
+
+  if (!gimple_double_size_mul_candidate (lhs, match, NULL))
+    return false;
+
+  new_type = build_nonstandard_integer_type (
+	  TYPE_PRECISION (TREE_TYPE (match[0])) * 2, 1);
+  mode = TYPE_MODE (new_type);
+
+  /* Early return if the target doesn't support the wider multiplication.  */
+  if (optab_handler (smul_optab, mode) == CODE_FOR_nothing
+      && !wider_optab_check_p (smul_optab, mode, 1))
+    return false;
+
+  /* Determine the point where the wide multiplication should be
+     inserted.  The complex low result works since both the high and low
+     part getters require it, so it dominates both of them.  */
+  complex_res_lo = SSA_NAME_DEF_STMT (match[3]);
+  insert_before = gsi_for_stmt (complex_res_lo);
+  gsi_next (&insert_before);
+
+  /* Create the widen multiplication.  */
+  arg0 = build_and_insert_cast (&insert_before, loc, new_type, match[0]);
+  arg1 = build_and_insert_cast (&insert_before, loc, new_type, match[1]);
+  widen_mult = build_and_insert_binop (&insert_before, loc, "widen_mult",
+				       MULT_EXPR, arg0, arg1);
+
+  /* Find the statement extracting the low part of the result.  */
+  FOR_EACH_IMM_USE_STMT (use_stmt, use_iter, match[3])
+    if (gimple_assign_rhs_code (use_stmt) == REALPART_EXPR)
+      break;
+
+  /* Create the low (if needed) and high part extractors.  */
+  /* Low part.  */
+  if (use_stmt)
+    {
+      loc = gimple_location (use_stmt);
+      tmp = build_and_insert_cast (&insert_before, loc,
+				   TREE_TYPE (gimple_get_lhs (use_stmt)),
+				   widen_mult);
+      gassign *new_stmt = gimple_build_assign (gimple_get_lhs (use_stmt),
+					       NOP_EXPR, tmp);
+      gsi_replace (&insert_before, new_stmt, true);
+    }
+
+  /* High part.  */
+  loc = gimple_location (stmt);
+  tmp = build_and_insert_binop (gsi, loc, "widen_mult_hi",
+				RSHIFT_EXPR, widen_mult,
+				build_int_cst (new_type,
+					       TYPE_PRECISION (new_type) / 2));
+  tmp = build_and_insert_cast (gsi, loc, TREE_TYPE (lhs), tmp);
+  gassign *new_stmt = gimple_build_assign (lhs, NOP_EXPR, tmp);
+  gsi_replace (gsi, new_stmt, true);
+
+  widen_mul_stats.double_sized_mul_optimized++;
+  return true;
+}
+
 /* Find integer multiplications where the operands are extended from
    smaller types, and replace the MULT_EXPR with a WIDEN_MULT_EXPR
    or MULT_HIGHPART_EXPR where appropriate.  */
@@ -4987,6 +5062,9 @@ math_opts_dom_walker::after_dom_children (basic_block bb)
 	      break;
 
 	    case PLUS_EXPR:
+	      if (convert_double_size_mul (&gsi, stmt))
+		break;
+	      gcc_fallthrough ();
 	    case MINUS_EXPR:
 	      if (!convert_plusminus_to_widen (&gsi, stmt, code))
 		match_arith_overflow (&gsi, stmt, code, m_cfg_changed_p);
@@ -5091,6 +5169,8 @@ pass_optimize_widening_mul::execute (function *fun)
 			    widen_mul_stats.divmod_calls_inserted);
   statistics_counter_event (fun, "highpart multiplications inserted",
 			    widen_mul_stats.highpart_mults_inserted);
+  statistics_counter_event (fun, "double sized mul optimized",
+			    widen_mul_stats.double_sized_mul_optimized);
 
   return cfg_changed ? TODO_cleanup_cfg : 0;
 }
-- 
2.33.0