Hi,

MULT, fused/chained multiply with add, and similar operations
with MINUS and NEG should all be handled in a consistent way.

To do that, we can pull out a common "mult" core from the partial
implementations found around aarch64_rtx_costs.

This patch performs that refactoring. One additional change we make
is to aarch64_strip_shift_or_extend, which becomes aarch64_strip_extend.
This allows us to catch the shift in our multiply code, and cost it
as appropriate. To maintain the precondition of
aarch64_is_extend_from_extract (its operands must come from an extract
of a MULT), we add a check that the extracted operand is a MULT before
calling it.

With this patch, we refactor the PLUS, MINUS, NEG and MULT costing cases.

Tested in series on aarch64-none-elf with no regressions.

OK for stage 1?

Thanks,
James

---
2014-03-27  James Greenhalgh  <james.greenha...@arm.com>
            Philipp Tomsich  <philipp.toms...@theobroma-systems.com>

        * config/aarch64/aarch64.c (aarch64_strip_shift_or_extend): Rename
        to...
        (aarch64_strip_extend): ...this.  Don't strip shifts, check RTX is
        well formed.
        (aarch64_rtx_mult_cost): New.
        (aarch64_rtx_costs): Use it, refactor as appropriate.
diff --git a/gcc/config/aarch64/aarch64.c b/gcc/config/aarch64/aarch64.c
index af947ca..11dc788 100644
--- a/gcc/config/aarch64/aarch64.c
+++ b/gcc/config/aarch64/aarch64.c
@@ -499,7 +499,7 @@ aarch64_is_long_call_p (rtx sym)
    represent an expression that matches an extend operation.  The
    operands represent the paramters from
 
-   (extract (mult (reg) (mult_imm)) (extract_imm) (const_int 0)).  */
+   (extract:MODE (mult (reg) (MULT_IMM)) (EXTRACT_IMM) (const_int 0)).  */
 bool
 aarch64_is_extend_from_extract (enum machine_mode mode, rtx mult_imm,
 				rtx extract_imm)
@@ -4543,18 +4543,19 @@ aarch64_strip_shift (rtx x)
   return x;
 }
 
-/* Helper function for rtx cost calculation.  Strip a shift or extend
+/* Helper function for rtx cost calculation.  Strip an extend
    expression from X.  Returns the inner operand if successful, or the
    original expression on failure.  We deal with a number of possible
    canonicalization variations here.  */
 static rtx
-aarch64_strip_shift_or_extend (rtx x)
+aarch64_strip_extend (rtx x)
 {
   rtx op = x;
 
   /* Zero and sign extraction of a widened value.  */
   if ((GET_CODE (op) == ZERO_EXTRACT || GET_CODE (op) == SIGN_EXTRACT)
       && XEXP (op, 2) == const0_rtx
+      && GET_CODE (XEXP (op, 0)) == MULT
       && aarch64_is_extend_from_extract (GET_MODE (op), XEXP (XEXP (op, 0), 1),
 					 XEXP (op, 1)))
     return XEXP (XEXP (op, 0), 0);
@@ -4583,7 +4584,122 @@ aarch64_strip_shift_or_extend (rtx x)
   if (op != x)
     return op;
 
-  return aarch64_strip_shift (x);
+  return x;
+}
+
+/* Helper function for rtx cost calculation.  Calculate the cost of
+   a MULT, which may be part of a multiply-accumulate rtx.  Return
+   the calculated cost of the expression, recursing manually into
+   operands where needed.  */
+
+static int
+aarch64_rtx_mult_cost (rtx x, int code, int outer, bool speed)
+{
+  rtx op0, op1;
+  const struct cpu_cost_table *extra_cost
+    = aarch64_tune_params->insn_extra_cost;
+  int cost = 0;
+  bool maybe_fma = (outer == PLUS || outer == MINUS);
+  enum machine_mode mode = GET_MODE (x);
+
+  gcc_checking_assert (code == MULT);
+
+  op0 = XEXP (x, 0);
+  op1 = XEXP (x, 1);
+
+  if (VECTOR_MODE_P (mode))
+    mode = GET_MODE_INNER (mode);
+
+  /* Integer multiply/fma.  */
+  if (GET_MODE_CLASS (mode) == MODE_INT)
+    {
+      /* The multiply will be canonicalized as a shift, cost it as such.  */
+      if (CONST_INT_P (op1)
+	  && exact_log2 (INTVAL (op1)) > 0)
+	{
+	  if (speed)
+	    {
+	      if (maybe_fma)
+		/* ADD (shifted register).  */
+		cost += extra_cost->alu.arith_shift;
+	      else
+		/* LSL (immediate).  */
+		cost += extra_cost->alu.shift;
+	    }
+
+	  cost += rtx_cost (op0, GET_CODE (op0), 0, speed);
+
+	  return cost;
+	}
+
+      /* Integer multiplies or FMAs have zero/sign extending variants.  */
+      if ((GET_CODE (op0) == ZERO_EXTEND
+	   && GET_CODE (op1) == ZERO_EXTEND)
+	  || (GET_CODE (op0) == SIGN_EXTEND
+	      && GET_CODE (op1) == SIGN_EXTEND))
+	{
+	  cost += rtx_cost (XEXP (op0, 0), MULT, 0, speed)
+		  + rtx_cost (XEXP (op1, 0), MULT, 1, speed);
+
+	  if (speed)
+	    {
+	      if (maybe_fma)
+		/* MADD/SMADDL/UMADDL.  */
+		cost += extra_cost->mult[0].extend_add;
+	      else
+		/* MUL/SMULL/UMULL.  */
+		cost += extra_cost->mult[0].extend;
+	    }
+
+	  return cost;
+	}
+
+      /* This is either an integer multiply or an FMA.  In both cases
+	 we want to recurse and cost the operands.  */
+      cost += rtx_cost (op0, MULT, 0, speed)
+	      + rtx_cost (op1, MULT, 1, speed);
+
+      if (speed)
+	{
+	  if (maybe_fma)
+	    /* MADD.  */
+	    cost += extra_cost->mult[mode == DImode].add;
+	  else
+	    /* MUL.  */
+	    cost += extra_cost->mult[mode == DImode].simple;
+	}
+
+      return cost;
+    }
+  else
+    {
+      if (speed)
+	{
+	  /* Floating-point FMA can also support negations of the
+	     operands.  */
+	  if (GET_CODE (op0) == NEG)
+	    {
+	      maybe_fma = true;
+	      op0 = XEXP (op0, 0);
+	    }
+	  if (GET_CODE (op1) == NEG)
+	    {
+	      maybe_fma = true;
+	      op1 = XEXP (op1, 0);
+	    }
+
+	  if (maybe_fma)
+	    /* FMADD/FNMADD/FNMSUB/FMSUB.  */
+	    cost += extra_cost->fp[mode == DFmode].fma;
+	  else
+	    /* FMUL.  */
+	    cost += extra_cost->fp[mode == DFmode].mult;
+	}
+
+      cost += rtx_cost (op0, MULT, 0, speed)
+	      + rtx_cost (op1, MULT, 1, speed);
+      return cost;
+    }
 }
 
 static int
@@ -4795,9 +4911,42 @@ aarch64_rtx_costs (rtx x, int code, int outer ATTRIBUTE_UNUSED,
       return true;
 
     case NEG:
-      op0 = CONST0_RTX (GET_MODE (x));
-      op1 = XEXP (x, 0);
-      goto cost_minus;
+      op0 = XEXP (x, 0);
+
+      if (GET_MODE_CLASS (GET_MODE (x)) == MODE_INT)
+       {
+          if (GET_RTX_CLASS (GET_CODE (op0)) == RTX_COMPARE
+              || GET_RTX_CLASS (GET_CODE (op0)) == RTX_COMM_COMPARE)
+            {
+              /* CSETM.  */
+              *cost += rtx_cost (XEXP (op0, 0), NEG, 0, speed);
+              return true;
+            }
+
+	  /* Cost this as SUB wzr, X.  */
+          op0 = CONST0_RTX (GET_MODE (x));
+          op1 = XEXP (x, 0);
+          goto cost_minus;
+        }
+
+      if (GET_MODE_CLASS (GET_MODE (x)) == MODE_FLOAT)
+        {
+          /* Support (neg(fma...)) as a single instruction only if
+             sign of zeros is unimportant.  This matches the decision
+             making in aarch64.md.  */
+          if (GET_CODE (op0) == FMA && !HONOR_SIGNED_ZEROS (GET_MODE (op0)))
+            {
+	      /* FNMADD.  */
+              *cost = rtx_cost (op0, NEG, 0, speed);
+              return true;
+            }
+	  if (speed)
+	    /* FNEG.  */
+	    *cost += extra_cost->fp[mode == DFmode].neg;
+          return false;
+        }
+
+      return false;
 
     case COMPARE:
       op0 = XEXP (x, 0);
@@ -4822,82 +4971,110 @@ aarch64_rtx_costs (rtx x, int code, int outer ATTRIBUTE_UNUSED,
       goto cost_minus;
 
     case MINUS:
-      op0 = XEXP (x, 0);
-      op1 = XEXP (x, 1);
+      {
+	op0 = XEXP (x, 0);
+	op1 = XEXP (x, 1);
+
+cost_minus:
+	/* Detect valid immediates.  */
+	if ((GET_MODE_CLASS (mode) == MODE_INT
+	     || (GET_MODE_CLASS (mode) == MODE_CC
+		 && GET_MODE_CLASS (GET_MODE (op0)) == MODE_INT))
+	    && CONST_INT_P (op1)
+	    && aarch64_uimm12_shift (INTVAL (op1)))
+	  {
+	    *cost += rtx_cost (op0, MINUS, 0, speed);
 
-    cost_minus:
-      if (GET_MODE_CLASS (GET_MODE (x)) == MODE_INT
-	  || (GET_MODE_CLASS (GET_MODE (x)) == MODE_CC
-	      && GET_MODE_CLASS (GET_MODE (op0)) == MODE_INT))
-	{
-	  if (op0 != const0_rtx)
+	    if (speed)
+	      /* SUB(S) (immediate).  */
+	      *cost += extra_cost->alu.arith;
+	    return true;
+
+	  }
+
+	rtx new_op1 = aarch64_strip_extend (op1);
+
+	/* Cost this as an FMA-alike operation.  */
+	if ((GET_CODE (new_op1) == MULT
+	     || GET_CODE (new_op1) == ASHIFT)
+	    && code != COMPARE)
+	  {
+	    *cost += aarch64_rtx_mult_cost (new_op1, MULT,
+					    (enum rtx_code) code,
+					    speed);
 	    *cost += rtx_cost (op0, MINUS, 0, speed);
+	    return true;
+	  }
 
-	  if (CONST_INT_P (op1))
-	    {
-	      if (!aarch64_uimm12_shift (INTVAL (op1)))
-		*cost += rtx_cost (op1, MINUS, 1, speed);
-	    }
-	  else
-	    {
-	      op1 = aarch64_strip_shift_or_extend (op1);
-	      *cost += rtx_cost (op1, MINUS, 1, speed);
-	    }
-	  return true;
-	}
+	*cost += rtx_cost (new_op1, MINUS, 1, speed);
 
-      return false;
+	if (speed)
+	  {
+	    if (GET_MODE_CLASS (mode) == MODE_INT)
+	      /* SUB(S).  */
+	      *cost += extra_cost->alu.arith;
+	    else if (GET_MODE_CLASS (mode) == MODE_FLOAT)
+	      /* FSUB.  */
+	      *cost += extra_cost->fp[mode == DFmode].addsub;
+	  }
+	return true;
+      }
 
     case PLUS:
-      op0 = XEXP (x, 0);
-      op1 = XEXP (x, 1);
+      {
+	rtx new_op0;
 
-      if (GET_MODE_CLASS (GET_MODE (x)) == MODE_INT)
-	{
-	  if (CONST_INT_P (op1) && aarch64_uimm12_shift (INTVAL (op1)))
-	    {
-	      *cost += rtx_cost (op0, PLUS, 0, speed);
-	    }
-	  else
-	    {
-	      rtx new_op0 = aarch64_strip_shift_or_extend (op0);
+	op0 = XEXP (x, 0);
+	op1 = XEXP (x, 1);
 
-	      if (new_op0 == op0
-		  && GET_CODE (op0) == MULT)
-		{
-		  if ((GET_CODE (XEXP (op0, 0)) == ZERO_EXTEND
-		       && GET_CODE (XEXP (op0, 1)) == ZERO_EXTEND)
-		      || (GET_CODE (XEXP (op0, 0)) == SIGN_EXTEND
-			  && GET_CODE (XEXP (op0, 1)) == SIGN_EXTEND))
-		    {
-		      *cost += (rtx_cost (XEXP (XEXP (op0, 0), 0), MULT, 0,
-					  speed)
-				+ rtx_cost (XEXP (XEXP (op0, 1), 0), MULT, 1,
-					    speed)
-				+ rtx_cost (op1, PLUS, 1, speed));
-		      if (speed)
-			*cost +=
-			  extra_cost->mult[GET_MODE (x) == DImode].extend_add;
-		      return true;
-		    }
-
-		  *cost += (rtx_cost (XEXP (op0, 0), MULT, 0, speed)
-			    + rtx_cost (XEXP (op0, 1), MULT, 1, speed)
-			    + rtx_cost (op1, PLUS, 1, speed));
-
-		  if (speed)
-		    *cost += extra_cost->mult[GET_MODE (x) == DImode].add;
-
-		  return true;
-		}
+	if (GET_RTX_CLASS (GET_CODE (op0)) == RTX_COMPARE
+	    || GET_RTX_CLASS (GET_CODE (op0)) == RTX_COMM_COMPARE)
+	  {
+	    /* CSINC.  */
+	    *cost += rtx_cost (XEXP (op0, 0), PLUS, 0, speed);
+	    *cost += rtx_cost (op1, PLUS, 1, speed);
+	    return true;
+	  }
 
-	      *cost += (rtx_cost (new_op0, PLUS, 0, speed)
-			+ rtx_cost (op1, PLUS, 1, speed));
-	    }
-	  return true;
-	}
+	if (GET_MODE_CLASS (mode) == MODE_INT
+	    && CONST_INT_P (op1)
+	    && aarch64_uimm12_shift (INTVAL (op1)))
+	  {
+	    *cost += rtx_cost (op0, PLUS, 0, speed);
 
-      return false;
+	    if (speed)
+	      /* ADD (immediate).  */
+	      *cost += extra_cost->alu.arith;
+	    return true;
+	  }
+
+	/* Strip any extend, leave shifts behind as we will
+	   cost them through mult_cost.  */
+	new_op0 = aarch64_strip_extend (op0);
+
+	if (GET_CODE (new_op0) == MULT
+	    || GET_CODE (new_op0) == ASHIFT)
+	  {
+	    *cost += aarch64_rtx_mult_cost (new_op0, MULT, PLUS,
+					    speed);
+	    *cost += rtx_cost (op1, PLUS, 1, speed);
+	    return true;
+	  }
+
+	*cost += (rtx_cost (new_op0, PLUS, 0, speed)
+		  + rtx_cost (op1, PLUS, 1, speed));
+
+	if (speed)
+	  {
+	    if (GET_MODE_CLASS (mode) == MODE_INT)
+	      /* ADD.  */
+	      *cost += extra_cost->alu.arith;
+	    else if (GET_MODE_CLASS (mode) == MODE_FLOAT)
+	      /* FADD.  */
+	      *cost += extra_cost->fp[mode == DFmode].addsub;
+	  }
+	return true;
+      }
 
     case IOR:
     case XOR:
@@ -4976,43 +5153,10 @@ aarch64_rtx_costs (rtx x, int code, int outer ATTRIBUTE_UNUSED,
       return true;
 
     case MULT:
-      op0 = XEXP (x, 0);
-      op1 = XEXP (x, 1);
-
-      *cost = COSTS_N_INSNS (1);
-      if (GET_MODE_CLASS (GET_MODE (x)) == MODE_INT)
-	{
-	  if (CONST_INT_P (op1)
-	      && exact_log2 (INTVAL (op1)) > 0)
-	    {
-	      *cost += rtx_cost (op0, ASHIFT, 0, speed);
-	      return true;
-	    }
-
-	  if ((GET_CODE (op0) == ZERO_EXTEND
-	       && GET_CODE (op1) == ZERO_EXTEND)
-	      || (GET_CODE (op0) == SIGN_EXTEND
-		  && GET_CODE (op1) == SIGN_EXTEND))
-	    {
-	      *cost += (rtx_cost (XEXP (op0, 0), MULT, 0, speed)
-			+ rtx_cost (XEXP (op1, 0), MULT, 1, speed));
-	      if (speed)
-		*cost += extra_cost->mult[GET_MODE (x) == DImode].extend;
-	      return true;
-	    }
-
-	  if (speed)
-	    *cost += extra_cost->mult[GET_MODE (x) == DImode].simple;
-	}
-      else if (speed)
-	{
-	  if (GET_MODE (x) == DFmode)
-	    *cost += extra_cost->fp[1].mult;
-	  else if (GET_MODE (x) == SFmode)
-	    *cost += extra_cost->fp[0].mult;
-	}
-
-      return false;  /* All arguments need to be in registers.  */
+      *cost += aarch64_rtx_mult_cost (x, MULT, 0, speed);
+      /* aarch64_rtx_mult_cost always handles recursion to its
+	 operands.  */
+      return true;
 
     case MOD:
     case UMOD:

Reply via email to