From 68549f47dfb2296deaf27ce403490af4e513de95 Mon Sep 17 00:00:00 2001
From: Kyrylo Tkachov <ktkachov@nvidia.com>
Date: Tue, 15 Oct 2024 06:32:31 -0700
Subject: [PATCH 1/3] PR 117048: simplify-rtx: Simplify (X << C1) [+,^] (X >>
 C2) into ROTATE

simplify-rtx can transform (X << C1) | (X >> C2) into ROTATE (X, C1) when
C1 + C2 equals the mode width.  But the transformation is also valid for
PLUS and XOR; indeed GIMPLE already performs the fold.  Let's teach RTL to
do it too.
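
To see why the three combinations agree: for C1 in [1, mode width - 1]
the values (X << C1) and (X >> C2) have no set bits in common, so no
carries can occur and IOR, XOR and PLUS all compute the same result.
For example, with 64-bit types (illustrative only, not part of the patch):

  uint64_t r_ior = (x << c) | (x >> (64 - c));
  uint64_t r_xor = (x << c) ^ (x >> (64 - c));  /* Same value as r_ior.  */
  uint64_t r_add = (x << c) + (x >> (64 - c));  /* Likewise: no carries.  */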

The motivating testcase for this is in AArch64 intrinsics:

uint64x2_t G2(uint64x2_t a, uint64x2_t b) {
    uint64x2_t c = veorq_u64(a, b);
    return veorq_u64(vaddq_u64(c, c), vshrq_n_u64(c, 63));
}

which I was hoping would fold to a single XAR (a ROTATE+XOR instruction),
but GCC failed to detect the rotate operation for two reasons:
1) The two arms of the expression are combined with XOR rather than the
IOR that simplify-rtx currently handles.
2) The left shift is expressed as a (PLUS X X) operation and is thus not
detected as the shift arm of the pattern we require (see the sketch below).
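
Schematically, the RTL for the combined expression (eliding the
const_vector duplicate form that vector shift amounts take) is:

  (xor:V2DI (plus:V2DI (reg:V2DI x) (reg:V2DI x))
            (lshiftrt:V2DI (reg:V2DI x) (const_int 63)))

which we want to recognize as:

  (rotate:V2DI (reg:V2DI x) (const_int 1))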

The patch fixes both issues.  The analysis of the two arms of the rotation
expression is factored out into a common helper simplify_rotate_op, which
is then used in the PLUS, XOR and IOR cases in simplify_binary_operation_1.

The assembly-scanning testcase for this is added in the following patch
because it needs some extra AArch64 backend work, but I've added self-tests
in this patch to validate the transformation.

Bootstrapped and tested on aarch64-none-linux-gnu.

Signed-off-by: Kyrylo Tkachov <ktkachov@nvidia.com>

gcc/ChangeLog:

	PR target/117048
	* simplify-rtx.cc (extract_ashift_operands_p): Define.
	(simplify_rotate_op): Likewise.
	(simplify_context::simplify_binary_operation_1): Use the above in
	the PLUS, IOR, XOR cases.
	(test_vector_rotate): Define.
	(test_vector_ops): Use the above.
---
 gcc/simplify-rtx.cc | 204 +++++++++++++++++++++++++++++++++-----------
 1 file changed, 156 insertions(+), 48 deletions(-)

diff --git a/gcc/simplify-rtx.cc b/gcc/simplify-rtx.cc
index 4d024ec523b..089e03c2a7a 100644
--- a/gcc/simplify-rtx.cc
+++ b/gcc/simplify-rtx.cc
@@ -2820,6 +2820,104 @@ reverse_rotate_by_imm_p (machine_mode mode, unsigned int left, rtx op1)
   return false;
 }
 
+/* Analyze argument X to see if it represents an (ASHIFT OP AMNT) operation
+   and return the operand to be shifted in SHIFT_OPND and the shift amount
+   in SHIFT_AMNT.  This is primarily used to group the handling of
+   (ASHIFT X CST) and (PLUS X X) in one place.  If X is not equivalent to
+   an ASHIFT then return false and set SHIFT_OPND and SHIFT_AMNT to
+   NULL_RTX.  */
+
+static bool
+extract_ashift_operands_p (rtx x, rtx *shift_opnd, rtx *shift_amnt)
+{
+  if (GET_CODE (x) == ASHIFT)
+    {
+      *shift_opnd = XEXP (x, 0);
+      *shift_amnt = XEXP (x, 1);
+      return true;
+    }
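+  /* (PLUS X X) is equivalent to (ASHIFT X 1), so accept that form too.  */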
+  if (GET_CODE (x) == PLUS && rtx_equal_p (XEXP (x, 0), XEXP (x, 1)))
+    {
+      *shift_opnd = XEXP (x, 0);
+      *shift_amnt = CONST1_RTX (GET_MODE (x));
+      return true;
+    }
+  *shift_opnd = NULL_RTX;
+  *shift_amnt = NULL_RTX;
+  return false;
+}
+
+/* OP0 and OP1 are combined under a binary operation with result mode MODE
+   that can potentially result in a ROTATE expression.  Analyze OP0 and
+   OP1 and return the resulting ROTATE expression if so, or NULL_RTX
+   otherwise.  This is used in detecting the patterns
+   (X << C1) [+,|,^] (X >> C2)
+   where C1 + C2 == GET_MODE_UNIT_PRECISION (mode).  (X << C1) and
+   (X >> C2) would be OP0 and OP1, in either order.  */
+
+static rtx
+simplify_rotate_op (rtx op0, rtx op1, machine_mode mode)
+{
+  /* Convert (op (ashift A CX) (lshiftrt A CY)), where OP is PLUS, IOR or
+     XOR and CX+CY equals the element precision of the mode, to
+     (rotate A CX).  */
+
+  rtx opleft, opright;
+  rtx ashift_opnd, ashift_amnt;
+  /* In some cases the shift is not expressed as a direct ASHIFT, e.g. it
+     may be a (PLUS X X).  Look deeper and extract the relevant operands
+     here.  */
+  bool ashift_op_p
+    = extract_ashift_operands_p (op1, &ashift_opnd, &ashift_amnt);
+
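+  /* If OP1 contains the ASHIFT (or is a SUBREG, for the wider-mode case
+     handled below) treat it as the left arm of the rotate and OP0 as the
+     right arm, otherwise try the arms the other way around.  */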
+  if (ashift_op_p || GET_CODE (op1) == SUBREG)
+    {
+      opleft = op1;
+      opright = op0;
+    }
+  else
+    {
+      opright = op1;
+      opleft = op0;
+      ashift_op_p
+	= extract_ashift_operands_p (opleft, &ashift_opnd, &ashift_amnt);
+    }
+
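+  /* A rotate requires both arms to shift the same operand and the shift
+     amounts to sum to the element precision.  */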
+  if (ashift_op_p && GET_CODE (opright) == LSHIFTRT
+      && rtx_equal_p (ashift_opnd, XEXP (opright, 0)))
+    {
+      rtx leftcst = unwrap_const_vec_duplicate (ashift_amnt);
+      rtx rightcst = unwrap_const_vec_duplicate (XEXP (opright, 1));
+
+      if (CONST_INT_P (leftcst) && CONST_INT_P (rightcst)
+	  && (INTVAL (leftcst) + INTVAL (rightcst)
+	      == GET_MODE_UNIT_PRECISION (mode)))
+	return gen_rtx_ROTATE (mode, XEXP (opright, 0), ashift_amnt);
+    }
+
+  /* Same, but for ashift that has been "simplified" to a wider mode
+     by simplify_shift_const.  */
+  scalar_int_mode int_mode, inner_mode;
+
+  if (GET_CODE (opleft) == SUBREG
+      && is_a <scalar_int_mode> (mode, &int_mode)
+      && is_a <scalar_int_mode> (GET_MODE (SUBREG_REG (opleft)),
+				 &inner_mode)
+      && GET_CODE (SUBREG_REG (opleft)) == ASHIFT
+      && GET_CODE (opright) == LSHIFTRT
+      && GET_CODE (XEXP (opright, 0)) == SUBREG
+      && known_eq (SUBREG_BYTE (opleft), SUBREG_BYTE (XEXP (opright, 0)))
+      && GET_MODE_SIZE (int_mode) < GET_MODE_SIZE (inner_mode)
+      && rtx_equal_p (XEXP (SUBREG_REG (opleft), 0),
+		      SUBREG_REG (XEXP (opright, 0)))
+      && CONST_INT_P (XEXP (SUBREG_REG (opleft), 1))
+      && CONST_INT_P (XEXP (opright, 1))
+      && (INTVAL (XEXP (SUBREG_REG (opleft), 1))
+	  + INTVAL (XEXP (opright, 1))
+	  == GET_MODE_PRECISION (int_mode)))
+    return gen_rtx_ROTATE (int_mode, XEXP (opright, 0),
+			   XEXP (SUBREG_REG (opleft), 1));
+  return NULL_RTX;
+}
+
 /* Subroutine of simplify_binary_operation.  Simplify a binary operation
    CODE with result mode MODE, operating on OP0 and OP1.  If OP0 and/or
    OP1 are constant pool references, TRUEOP0 and TRUEOP1 represent the
@@ -2831,7 +2929,7 @@ simplify_context::simplify_binary_operation_1 (rtx_code code,
 					       rtx op0, rtx op1,
 					       rtx trueop0, rtx trueop1)
 {
-  rtx tem, reversed, opleft, opright, elt0, elt1;
+  rtx tem, reversed, elt0, elt1;
   HOST_WIDE_INT val;
   scalar_int_mode int_mode, inner_mode;
   poly_int64 offset;
@@ -3030,6 +3128,11 @@ simplify_context::simplify_binary_operation_1 (rtx_code code,
 	return
 	  simplify_gen_unary (NEG, mode, reversed, mode);
 
+      /* Convert (plus (ashift A CX) (lshiftrt A CY)) where CX+CY equals the
+	 mode size to (rotate A CX).  */
+      if ((tem = simplify_rotate_op (op0, op1, mode)))
+	return tem;
+
       /* If one of the operands is a PLUS or a MINUS, see if we can
 	 simplify this by the associative law.
 	 Don't use the associative law for floating point.
@@ -3462,53 +3565,10 @@ simplify_context::simplify_binary_operation_1 (rtx_code code,
 	return op1;
 
       /* Convert (ior (ashift A CX) (lshiftrt A CY)) where CX+CY equals the
-         mode size to (rotate A CX).  */
-
-      if (GET_CODE (op1) == ASHIFT
-          || GET_CODE (op1) == SUBREG)
-        {
-	  opleft = op1;
-	  opright = op0;
-	}
-      else
-        {
-	  opright = op1;
-	  opleft = op0;
-	}
-
-      if (GET_CODE (opleft) == ASHIFT && GET_CODE (opright) == LSHIFTRT
-	  && rtx_equal_p (XEXP (opleft, 0), XEXP (opright, 0)))
-	{
-	  rtx leftcst = unwrap_const_vec_duplicate (XEXP (opleft, 1));
-	  rtx rightcst = unwrap_const_vec_duplicate (XEXP (opright, 1));
-
-	  if (CONST_INT_P (leftcst) && CONST_INT_P (rightcst)
-	      && (INTVAL (leftcst) + INTVAL (rightcst)
-		  == GET_MODE_UNIT_PRECISION (mode)))
-	    return gen_rtx_ROTATE (mode, XEXP (opright, 0), XEXP (opleft, 1));
-	}
-
-      /* Same, but for ashift that has been "simplified" to a wider mode
-        by simplify_shift_const.  */
-
-      if (GET_CODE (opleft) == SUBREG
-	  && is_a <scalar_int_mode> (mode, &int_mode)
-	  && is_a <scalar_int_mode> (GET_MODE (SUBREG_REG (opleft)),
-				     &inner_mode)
-          && GET_CODE (SUBREG_REG (opleft)) == ASHIFT
-          && GET_CODE (opright) == LSHIFTRT
-          && GET_CODE (XEXP (opright, 0)) == SUBREG
-	  && known_eq (SUBREG_BYTE (opleft), SUBREG_BYTE (XEXP (opright, 0)))
-	  && GET_MODE_SIZE (int_mode) < GET_MODE_SIZE (inner_mode)
-          && rtx_equal_p (XEXP (SUBREG_REG (opleft), 0),
-                          SUBREG_REG (XEXP (opright, 0)))
-          && CONST_INT_P (XEXP (SUBREG_REG (opleft), 1))
-          && CONST_INT_P (XEXP (opright, 1))
-	  && (INTVAL (XEXP (SUBREG_REG (opleft), 1))
-	      + INTVAL (XEXP (opright, 1))
-	      == GET_MODE_PRECISION (int_mode)))
-	return gen_rtx_ROTATE (int_mode, XEXP (opright, 0),
-			       XEXP (SUBREG_REG (opleft), 1));
+	 mode size to (rotate A CX).  */
+      tem = simplify_rotate_op (op0, op1, mode);
+      if (tem)
+	return tem;
 
       /* If OP0 is (ashiftrt (plus ...) C), it might actually be
          a (sign_extend (plus ...)).  Then check if OP1 is a CONST_INT and
@@ -3838,6 +3898,12 @@ simplify_context::simplify_binary_operation_1 (rtx_code code,
 	    return tem;
 	}
 
+      /* Convert (xor (ashift A CX) (lshiftrt A CY)) where CX+CY equals the
+	 mode size to (rotate A CX).  */
+      tem = simplify_rotate_op (op0, op1, mode);
+      if (tem)
+	return tem;
+
       /* Convert (xor (and (not A) B) A) into A | B.  */
       if (GET_CODE (op0) == AND
 	  && GET_CODE (XEXP (op0, 0)) == NOT
@@ -8654,6 +8720,46 @@ test_vec_merge (machine_mode mode)
 		 simplify_rtx (nvm));
 }
 
+/* Test that vector rotate formation works at the RTL level.  Try various
+   combinations of (REG << C) [|,^,+] (REG >> (<bitwidth> - C)).  */
+
+static void
+test_vector_rotate (rtx reg)
+{
+  machine_mode mode = GET_MODE (reg);
+  unsigned bitwidth = GET_MODE_UNIT_SIZE (mode) * BITS_PER_UNIT;
+  rtx plus_rtx = gen_rtx_PLUS (mode, reg, reg);
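+  /* A right shift by <bitwidth> - 1 pairs with the shift left by 1 that
+     the (plus reg reg) above represents.  */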
+  rtx lshftrt_amnt = GEN_INT (bitwidth - 1);
+  lshftrt_amnt = gen_const_vec_duplicate (mode, lshftrt_amnt);
+  rtx lshiftrt_rtx = gen_rtx_LSHIFTRT (mode, reg, lshftrt_amnt);
+  rtx rotate_rtx = gen_rtx_ROTATE (mode, reg, CONST1_RTX (mode));
+  /* Test explicitly the case where ASHIFT (x, 1) is a PLUS (x, x).  */
+  ASSERT_RTX_EQ (rotate_rtx,
+		 simplify_rtx (gen_rtx_IOR (mode, plus_rtx, lshiftrt_rtx)));
+  ASSERT_RTX_EQ (rotate_rtx,
+		 simplify_rtx (gen_rtx_XOR (mode, plus_rtx, lshiftrt_rtx)));
+  ASSERT_RTX_EQ (rotate_rtx,
+		 simplify_rtx (gen_rtx_PLUS (mode, plus_rtx, lshiftrt_rtx)));
+
+  /* Don't go through every possible rotate amount to save execution time.
+     Amounts that are a multiple of BITS_PER_UNIT could conceivably be
+     simplified to other operations, such as bswap, instead.  Go through
+     just the odd amounts.  */
+  for (unsigned i = 3; i < bitwidth - 2; i += 2)
+    {
+      rtx rot_amnt = gen_const_vec_duplicate (mode, GEN_INT (i));
+      rtx ashift_rtx = gen_rtx_ASHIFT (mode, reg, rot_amnt);
+      lshftrt_amnt = gen_const_vec_duplicate (mode, GEN_INT (bitwidth - i));
+      lshiftrt_rtx = gen_rtx_LSHIFTRT (mode, reg, lshftrt_amnt);
+      rotate_rtx = gen_rtx_ROTATE (mode, reg, rot_amnt);
+      ASSERT_RTX_EQ (rotate_rtx,
+		     simplify_rtx (gen_rtx_IOR (mode, ashift_rtx,
+						lshiftrt_rtx)));
+      ASSERT_RTX_EQ (rotate_rtx,
+		     simplify_rtx (gen_rtx_XOR (mode, ashift_rtx,
+						lshiftrt_rtx)));
+      ASSERT_RTX_EQ (rotate_rtx,
+		     simplify_rtx (gen_rtx_PLUS (mode, ashift_rtx,
+						 lshiftrt_rtx)));
+    }
+}
+
 /* Test subregs of integer vector constant X, trying elements in
    the range [ELT_BIAS, ELT_BIAS + constant_lower_bound (NELTS)),
    where NELTS is the number of elements in X.  Subregs involving
@@ -8825,11 +8931,13 @@ test_vector_ops ()
 	{
 	  rtx scalar_reg = make_test_reg (GET_MODE_INNER (mode));
 	  test_vector_ops_duplicate (mode, scalar_reg);
+	  rtx vector_reg = make_test_reg (mode);
 	  if (GET_MODE_CLASS (mode) == MODE_VECTOR_INT
 	      && maybe_gt (GET_MODE_NUNITS (mode), 2))
 	    {
 	      test_vector_ops_series (mode, scalar_reg);
 	      test_vector_subregs (mode);
+	      test_vector_rotate (vector_reg);
 	    }
 	  test_vec_merge (mode);
 	}
-- 
2.44.0

