Many thanks to Uros for reviewing/approving all of the previous pieces.
This patch adds support for converting 128-bit TImode shifts and rotates
to SSE equivalents using V1TImode during the TImode STV pass.
Previously, only logical shifts by multiples of 8 were handled
(from my patch earlier this month).

As an example of the benefits, the following rotate by 32 bits:

unsigned __int128 a, b;
void rot32() { a = (b >> 32) | (b << 96); }

when compiled on x86_64 with -O2 previously generated:

        movq    b(%rip), %rax
        movq    b+8(%rip), %rdx
        movq    %rax, %rcx
        shrdq   $32, %rdx, %rax
        shrdq   $32, %rcx, %rdx
        movq    %rax, a(%rip)
        movq    %rdx, a+8(%rip)
        ret

with this patch, now generates:

        movdqa  b(%rip), %xmm0
        pshufd  $57, %xmm0, %xmm0
        movaps  %xmm0, a(%rip)
        ret

[which uses a V4SI permutation, for those who don't read SSE].
This should help 128-bit cryptography code that interleaves XORs
with rotations (but doesn't use additions or subtractions).

This patch has been tested on x86_64-pc-linux-gnu with make bootstrap
and make -k check, both with and without --target_board=unix{-m32},
with no new failures.  Ok for mainline?


2022-08-15  Roger Sayle  <ro...@nextmovesoftware.com>

gcc/ChangeLog
        * config/i386/i386-features.cc
        (timode_scalar_chain::compute_convert_gain): Provide costs for
        shifts and rotates.  Provide gains for comparisons against 0/-1.
        (timode_scalar_chain::convert_insn): Handle ASHIFTRT, ROTATERT
        and ROTATE just like existing ASHIFT and LSHIFTRT cases.
        (timode_scalar_to_vector_candidate_p): Handle all shifts and
        rotates by integer constants between 0 and 127.

gcc/testsuite/ChangeLog
        * gcc.target/i386/sse4_1-stv-9.c: New test case.


Thanks in advance,
Roger
--

diff --git a/gcc/config/i386/i386-features.cc b/gcc/config/i386/i386-features.cc
index effc2f2..8ab65c8 100644
--- a/gcc/config/i386/i386-features.cc
+++ b/gcc/config/i386/i386-features.cc
@@ -1209,6 +1209,8 @@ timode_scalar_chain::compute_convert_gain ()
       rtx def_set = single_set (insn);
       rtx src = SET_SRC (def_set);
       rtx dst = SET_DEST (def_set);
+      HOST_WIDE_INT op1val;
+      int scost, vcost;
       int igain = 0;
 
       switch (GET_CODE (src))
@@ -1245,9 +1247,157 @@ timode_scalar_chain::compute_convert_gain ()
 
        case ASHIFT:
        case LSHIFTRT:
-         /* For logical shifts by constant multiples of 8. */
-         igain = optimize_insn_for_size_p () ? COSTS_N_BYTES (4)
-                                             : COSTS_N_INSNS (1);
+         /* See ix86_expand_v1ti_shift.  */
+         op1val = XINT (src, 1);
+         if (optimize_insn_for_size_p ())
+           {
+             if (op1val == 64 || op1val == 65)
+               scost = COSTS_N_BYTES (5);
+             else if (op1val >= 66)
+               scost = COSTS_N_BYTES (6);
+             else if (op1val == 1)
+               scost = COSTS_N_BYTES (8);
+             else
+               scost = COSTS_N_BYTES (9);
+
+             if ((op1val & 7) == 0)
+               vcost = COSTS_N_BYTES (5);
+             else if (op1val > 64)
+               vcost = COSTS_N_BYTES (10);
+             else
+               vcost = TARGET_AVX ? COSTS_N_BYTES (19) : COSTS_N_BYTES (23);
+           }
+         else
+           {
+             scost = COSTS_N_INSNS (2);
+             if ((op1val & 7) == 0)
+               vcost = COSTS_N_INSNS (1);
+             else if (op1val > 64)
+               vcost = COSTS_N_INSNS (2);
+             else
+               vcost = TARGET_AVX ? COSTS_N_INSNS (4) : COSTS_N_INSNS (5);
+           }
+         igain = scost - vcost;
+         break;
+
+       case ASHIFTRT:
+         /* See ix86_expand_v1ti_ashiftrt.  */
+         op1val = XINT (src, 1);
+         if (optimize_insn_for_size_p ())
+           {
+             if (op1val == 64 || op1val == 127)
+               scost = COSTS_N_BYTES (7);
+             else if (op1val == 1)
+               scost = COSTS_N_BYTES (8);
+             else if (op1val == 65)
+               scost = COSTS_N_BYTES (10);
+             else if (op1val >= 66)
+               scost = COSTS_N_BYTES (11);
+             else
+               scost = COSTS_N_BYTES (9);
+
+             if (op1val == 127)
+               vcost = COSTS_N_BYTES (10);
+             else if (op1val == 64)
+               vcost = COSTS_N_BYTES (14);
+             else if (op1val == 96)
+               vcost = COSTS_N_BYTES (18);
+             else if (op1val >= 111)
+               vcost = COSTS_N_BYTES (15);
+              else if (TARGET_AVX2 && op1val == 32)
+               vcost = COSTS_N_BYTES (16);
+             else if (TARGET_SSE4_1 && op1val == 32)
+               vcost = COSTS_N_BYTES (20);
+             else if (op1val >= 96)
+               vcost = COSTS_N_BYTES (23);
+             else if ((op1val & 7) == 0)
+               vcost = COSTS_N_BYTES (28);
+              else if (TARGET_AVX2 && op1val < 32)
+               vcost = COSTS_N_BYTES (30);
+             else if (op1val == 1 || op1val >= 64)
+               vcost = COSTS_N_BYTES (42);
+             else
+               vcost = COSTS_N_BYTES (47);
+           }
+         else
+           {
+             if (op1val >= 65 && op1val <= 126)
+               scost = COSTS_N_INSNS (3);
+             else
+               scost = COSTS_N_INSNS (2);
+
+             if (op1val == 127)
+               vcost = COSTS_N_INSNS (2);
+             else if (op1val == 64)
+               vcost = COSTS_N_INSNS (3);
+             else if (op1val == 96)
+               vcost = COSTS_N_INSNS (4);
+             else if (op1val >= 111)
+               vcost = COSTS_N_INSNS (3);
+              else if (TARGET_AVX2 && op1val == 32)
+               vcost = COSTS_N_INSNS (3);
+             else if (TARGET_SSE4_1 && op1val == 32)
+               vcost = COSTS_N_INSNS (4);
+             else if (op1val >= 96)
+               vcost = COSTS_N_INSNS (5);
+             else if ((op1val & 7) == 0)
+               vcost = COSTS_N_INSNS (6);
+              else if (TARGET_AVX2 && op1val < 32)
+               vcost = COSTS_N_INSNS (6);
+             else if (op1val == 1 || op1val >= 64)
+               vcost = COSTS_N_INSNS (9);
+             else
+               vcost = COSTS_N_INSNS (10);
+           }
+         igain = scost - vcost;
+         break;
+
+       case ROTATE:
+       case ROTATERT:
+         /* See ix86_expand_v1ti_rotate.  */
+         op1val = XINT (src, 1);
+         if (optimize_insn_for_size_p ())
+           {
+             scost = COSTS_N_BYTES (13);
+             if ((op1val & 31) == 0)
+               vcost = COSTS_N_BYTES (5);
+             else if ((op1val & 7) == 0)
+               vcost = TARGET_AVX ? COSTS_N_BYTES (13) : COSTS_N_BYTES (18);
+              else if (op1val > 32 && op1val < 96)
+               vcost = COSTS_N_BYTES (24);
+             else
+               vcost = COSTS_N_BYTES (19);
+           }
+         else
+           {
+             scost = COSTS_N_INSNS (3);
+             if ((op1val & 31) == 0)
+               vcost = COSTS_N_INSNS (1);
+             else if ((op1val & 7) == 0)
+               vcost = TARGET_AVX ? COSTS_N_INSNS (3) : COSTS_N_INSNS (4);
+              else if (op1val > 32 && op1val < 96)
+               vcost = COSTS_N_INSNS (5);
+             else
+               vcost = COSTS_N_INSNS (1);
+           }
+         igain = scost - vcost;
+         break;
+
+       case COMPARE:
+         if (XEXP (src, 1) == const0_rtx)
+           {
+             if (GET_CODE (XEXP (src, 0)) == AND)
+               /* and;and;or (9 bytes) vs. ptest (5 bytes).  */
+               igain = optimize_insn_for_size_p() ? COSTS_N_BYTES (4)
+                                                  : COSTS_N_INSNS (2);
+             /* or (3 bytes) vs. ptest (5 bytes).  */
+             else if (optimize_insn_for_size_p ())
+               igain = -COSTS_N_BYTES (2);
+           }
+         else if (XEXP (src, 1) == const1_rtx)
+           /* and;cmp -1 (7 bytes) vs. pcmpeqd;pxor;ptest (13 bytes).  */
+           igain = optimize_insn_for_size_p() ? -COSTS_N_BYTES (6)
+                                              : -COSTS_N_INSNS (1);
          break;
 
        default:
@@ -1503,6 +1653,9 @@ timode_scalar_chain::convert_insn (rtx_insn *insn)
 
     case ASHIFT:
     case LSHIFTRT:
+    case ASHIFTRT:
+    case ROTATERT:
+    case ROTATE:
       convert_op (&XEXP (src, 0), insn);
       PUT_MODE (src, V1TImode);
       break;
@@ -1861,11 +2014,13 @@ timode_scalar_to_vector_candidate_p (rtx_insn *insn)
 
     case ASHIFT:
     case LSHIFTRT:
-      /* Handle logical shifts by integer constants between 0 and 120
-        that are multiples of 8.  */
+    case ASHIFTRT:
+    case ROTATERT:
+    case ROTATE:
+      /* Handle shifts/rotates by integer constants between 0 and 127.  */
       return REG_P (XEXP (src, 0))
             && CONST_INT_P (XEXP (src, 1))
-            && (INTVAL (XEXP (src, 1)) & ~0x78) == 0;
+            && (INTVAL (XEXP (src, 1)) & ~0x7f) == 0;
 
     default:
       return false;
diff --git a/gcc/testsuite/gcc.target/i386/sse4_1-stv-9.c b/gcc/testsuite/gcc.target/i386/sse4_1-stv-9.c
new file mode 100644
index 0000000..ee5af3c
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/sse4_1-stv-9.c
@@ -0,0 +1,12 @@
+/* { dg-do compile { target int128 } } */
+/* { dg-options "-O2 -msse4.1 -mstv -mno-stackrealign" } */
+
+unsigned __int128 a, b;
+void rot1()  { a = (b >> 1) | (b << 127); }
+void rot4()  { a = (b >> 4) | (b << 124); }
+void rot8()  { a = (b >> 8) | (b << 120); }
+void rot32() { a = (b >> 32) | (b << 96); }
+void rot64() { a = (b >> 64) | (b << 64); }
+
+/* { dg-final { scan-assembler-not "shrdq" } } */
+/* { dg-final { scan-assembler "pshufd" } } */

Reply via email to