The following implements masked load-lane discovery for SLP.  The
challenge here is that a masked load has a full-width mask with
group-size many elements, while when this becomes a masked load-lanes
instruction a single mask element gates all group members.  We already
have some discovery hints in place, namely STMT_VINFO_SLP_VECT_ONLY
to guard non-uniform masks, but we need to choose a way for SLP
discovery to handle possible masked load-lanes SLP trees.
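
To make that semantic difference concrete, here is a minimal scalar
model of the two mask shapes for a group of size 2 (illustration
only, not part of the patch; all names are made up): a masked load
carries one mask element per loaded element, while a masked
load-lanes carries one mask element per group that gates both
members at once.

  #include <cstdio>

  int
  main ()
  {
    const int n = 4, group = 2;
    int data[n * group] = {10, 11, 20, 21, 30, 31, 40, 41};

    /* Masked load: group-size mask elements per group, one per
       loaded element.  */
    bool elem_mask[n * group] = {true, true, false, true,
                                 true, true, false, false};
    /* Masked load-lanes: a single mask element gates all group
       members.  */
    bool lane_mask[n] = {true, false, true, true};

    for (int i = 0; i < n; ++i)
      {
        /* Per-element gating, as IFN_MASK_LOAD provides.  */
        int a = elem_mask[i * group] ? data[i * group] : 0;
        int b = elem_mask[i * group + 1] ? data[i * group + 1] : 0;
        /* Per-group gating, as a load-lanes instruction provides;
           rewriting the former into the latter is only valid when
           the element mask is uniform within each group.  */
        int c = lane_mask[i] ? data[i * group] : 0;
        int d = lane_mask[i] ? data[i * group + 1] : 0;
        printf ("%d %d | %d %d\n", a, b, c, d);
      }
    return 0;
  }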

I have this time chosen to handle load-lanes discovery at the point
where permute optimization has already been performed and we
conveniently have the graph with predecessor edges built.  This is
because, unlike non-masked loads, masked loads with a load_permutation
are never produced by SLP discovery (load permutation handling does
not know how to un-permute the mask), and thus the load-permutation
lowering which handles non-masked load-lanes discovery does not
trigger.

With this, SLP discovery for a possible masked load-lanes, that is,
a masked load with a uniform mask, produces a splat of a single-lane
sub-graph as the mask SLP operand.  This representation should not
pessimize the plain masked-load case and allows the masked load-lanes
transform to simply elide this splat.
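
The splat is encoded like any other VEC_PERM_EXPR SLP node: every
output lane selects lane 0 of the single child, which is exactly the
(operand, lane) pair the patch pushes and what
decide_masked_load_lanes later tests for.  A standalone sketch of
that encoding (plain C++ with standard containers instead of GCC's
internal vec, for illustration only):

  #include <cstdio>
  #include <utility>
  #include <vector>

  int
  main ()
  {
    const unsigned group_size = 4;

    /* Each lane-permutation element is a (child, lane) pair.  A
       splat of a single-lane sub-graph selects lane 0 of child 0
       for every output lane.  */
    std::vector<std::pair<unsigned, unsigned> > lane_perm;
    for (unsigned k = 0; k < group_size; ++k)
      lane_perm.push_back (std::make_pair (0u, 0u));

    /* This mirrors the uniformity check in
       decide_masked_load_lanes: if any pair is not (0, 0) the node
       is not a splat and must be kept.  */
    bool is_splat = true;
    for (auto perm : lane_perm)
      if (perm.first != 0 || perm.second != 0)
        is_splat = false;

    /* A splat node can be elided by replacing it with its only
       child.  */
    printf ("is_splat = %d\n", (int) is_splat);
    return 0;
  }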

This fixes the aarch64-sve.exp mask_struct_load*.c testcases with
--param vect-force-slp=1.
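
For reference, those testcases contain loops of roughly the following
shape (a simplified sketch, not copied verbatim from the testsuite):
the same condition guards every member of an interleaved load group,
so the element mask is uniform across the group and the loads can
become a masked load-lanes.

  void
  f (int *__restrict dest, int *__restrict src,
     int *__restrict cond, int n)
  {
    for (int i = 0; i < n; ++i)
      if (cond[i])
        dest[i] = src[i * 2] + src[i * 2 + 1];
  }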

Re-bootstrap & regtest running on x86_64-unknown-linux-gnu; the
previously observed CI FAILs are gone.

        PR tree-optimization/116575
        * tree-vect-slp.cc (vect_get_and_check_slp_defs): Handle
        gaps, aka NULL scalar stmt.
        (vect_build_slp_tree_2): Allow gaps in the middle of a
        grouped mask load.  When the mask of a grouped mask load
        is uniform do single-lane discovery for the mask and
        insert a splat VEC_PERM_EXPR node.
        (vect_optimize_slp_pass::decide_masked_load_lanes): New
        function.
        (vect_optimize_slp_pass::run): Call it.
---
 gcc/tree-vect-slp.cc | 141 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 138 insertions(+), 3 deletions(-)

diff --git a/gcc/tree-vect-slp.cc b/gcc/tree-vect-slp.cc
index 53f5400a961..b192328e3eb 100644
--- a/gcc/tree-vect-slp.cc
+++ b/gcc/tree-vect-slp.cc
@@ -641,6 +641,16 @@ vect_get_and_check_slp_defs (vec_info *vinfo, unsigned char swap,
   unsigned int commutative_op = -1U;
   bool first = stmt_num == 0;
 
+  if (!stmt_info)
+    {
+      for (auto oi : *oprnds_info)
+       {
+         oi->def_stmts.quick_push (NULL);
+         oi->ops.quick_push (NULL_TREE);
+       }
+      return 0;
+    }
+
   if (!is_a<gcall *> (stmt_info->stmt)
       && !is_a<gassign *> (stmt_info->stmt)
       && !is_a<gphi *> (stmt_info->stmt))
@@ -2029,9 +2039,11 @@ vect_build_slp_tree_2 (vec_info *vinfo, slp_tree node,
                    has_gaps = true;
              /* We cannot handle permuted masked loads directly, see
                 PR114375.  We cannot handle strided masked loads or masked
-                loads with gaps.  */
+                loads with gaps unless the mask is uniform.  */
              if ((STMT_VINFO_GROUPED_ACCESS (stmt_info)
-                  && (DR_GROUP_GAP (first_stmt_info) != 0 || has_gaps))
+                  && (DR_GROUP_GAP (first_stmt_info) != 0
+                      || (has_gaps
+                          && STMT_VINFO_SLP_VECT_ONLY (first_stmt_info))))
                  || STMT_VINFO_STRIDED_P (stmt_info))
                {
                  load_permutation.release ();
@@ -2054,7 +2066,12 @@ vect_build_slp_tree_2 (vec_info *vinfo, slp_tree node,
                  unsigned i = 0;
                  for (stmt_vec_info si = first_stmt_info;
                       si; si = DR_GROUP_NEXT_ELEMENT (si))
-                   stmts2[i++] = si;
+                   {
+                     if (si != first_stmt_info)
+                       for (unsigned k = 1; k < DR_GROUP_GAP (si); ++k)
+                         stmts2[i++] = NULL;
+                     stmts2[i++] = si;
+                   }
                  bool *matches2 = XALLOCAVEC (bool, dr_group_size);
                  slp_tree unperm_load
                    = vect_build_slp_tree (vinfo, stmts2, dr_group_size,
@@ -2683,6 +2700,46 @@ out:
          continue;
        }
 
+      /* When we have a masked load with uniform mask discover this
+        as a single-lane mask with a splat permute.  This way we can
+        recognize this as a masked load-lane by stripping the splat.  */
+      if (is_a <gcall *> (STMT_VINFO_STMT (stmt_info))
+         && gimple_call_internal_p (STMT_VINFO_STMT (stmt_info),
+                                    IFN_MASK_LOAD)
+         && STMT_VINFO_GROUPED_ACCESS (stmt_info)
+         && ! STMT_VINFO_SLP_VECT_ONLY (DR_GROUP_FIRST_ELEMENT (stmt_info)))
+       {
+         vec<stmt_vec_info> def_stmts2;
+         def_stmts2.create (1);
+         def_stmts2.quick_push (oprnd_info->def_stmts[0]);
+         child = vect_build_slp_tree (vinfo, def_stmts2, 1,
+                                      &this_max_nunits,
+                                      matches, limit,
+                                      &this_tree_size, bst_map);
+         if (child)
+           {
+             slp_tree pnode = vect_create_new_slp_node (1, VEC_PERM_EXPR);
+             SLP_TREE_VECTYPE (pnode) = SLP_TREE_VECTYPE (child);
+             SLP_TREE_LANES (pnode) = group_size;
+             SLP_TREE_SCALAR_STMTS (pnode).create (group_size);
+             SLP_TREE_LANE_PERMUTATION (pnode).create (group_size);
+             for (unsigned k = 0; k < group_size; ++k)
+               {
+                 SLP_TREE_SCALAR_STMTS (pnode)
+                   .quick_push (oprnd_info->def_stmts[0]);
+                 SLP_TREE_LANE_PERMUTATION (pnode)
+                   .quick_push (std::make_pair (0u, 0u));
+               }
+             SLP_TREE_CHILDREN (pnode).quick_push (child);
+             pnode->max_nunits = child->max_nunits;
+             children.safe_push (pnode);
+             oprnd_info->def_stmts = vNULL;
+             continue;
+           }
+         else
+           def_stmts2.release ();
+       }
+
       if ((child = vect_build_slp_tree (vinfo, oprnd_info->def_stmts,
                                        group_size, &this_max_nunits,
                                        matches, limit,
@@ -5462,6 +5519,9 @@ private:
   /* Clean-up.  */
   void remove_redundant_permutations ();
 
+  /* Masked load lanes discovery.  */
+  void decide_masked_load_lanes ();
+
   void dump ();
 
   vec_info *m_vinfo;
@@ -7090,6 +7150,80 @@ vect_optimize_slp_pass::dump ()
     }
 }
 
+/* Masked load lanes discovery.  */
+
+void
+vect_optimize_slp_pass::decide_masked_load_lanes ()
+{
+  for (auto v : m_vertices)
+    {
+      slp_tree node = v.node;
+      if (SLP_TREE_DEF_TYPE (node) != vect_internal_def
+         || SLP_TREE_CODE (node) == VEC_PERM_EXPR)
+       continue;
+      stmt_vec_info stmt_info = SLP_TREE_REPRESENTATIVE (node);
+      if (! STMT_VINFO_GROUPED_ACCESS (stmt_info)
+         /* The mask has to be uniform.  */
+         || STMT_VINFO_SLP_VECT_ONLY (stmt_info)
+         || ! is_a <gcall *> (STMT_VINFO_STMT (stmt_info))
+         || ! gimple_call_internal_p (STMT_VINFO_STMT (stmt_info),
+                                      IFN_MASK_LOAD))
+       continue;
+      stmt_info = DR_GROUP_FIRST_ELEMENT (stmt_info);
+      if (STMT_VINFO_STRIDED_P (stmt_info)
+         || compare_step_with_zero (m_vinfo, stmt_info) <= 0
+         || vect_load_lanes_supported (SLP_TREE_VECTYPE (node),
+                                       DR_GROUP_SIZE (stmt_info),
+                                       true) == IFN_LAST)
+       continue;
+
+      /* Uniform masks need to be suitably represented.  */
+      slp_tree mask = SLP_TREE_CHILDREN (node)[0];
+      if (SLP_TREE_CODE (mask) != VEC_PERM_EXPR
+         || SLP_TREE_CHILDREN (mask).length () != 1)
+       continue;
+      bool match = true;
+      for (auto perm : SLP_TREE_LANE_PERMUTATION (mask))
+       if (perm.first != 0 || perm.second != 0)
+         {
+           match = false;
+           break;
+         }
+      if (!match)
+       continue;
+
+      /* Now see if the consumer side matches.  */
+      for (graph_edge *pred = m_slpg->vertices[node->vertex].pred;
+          pred; pred = pred->pred_next)
+       {
+         slp_tree pred_node = m_vertices[pred->src].node;
+         /* All consumers should be a permute with a single outgoing lane.  */
+         if (SLP_TREE_CODE (pred_node) != VEC_PERM_EXPR
+             || SLP_TREE_LANES (pred_node) != 1)
+           {
+             match = false;
+             break;
+           }
+         gcc_assert (SLP_TREE_CHILDREN (pred_node).length () == 1);
+       }
+      if (!match)
+       continue;
+      /* Now we can mark the nodes as to use load lanes.  */
+      node->ldst_lanes = true;
+      for (graph_edge *pred = m_slpg->vertices[node->vertex].pred;
+          pred; pred = pred->pred_next)
+       m_vertices[pred->src].node->ldst_lanes = true;
+      /* The catch is we have to massage the mask.  We have arranged
+        analyzed uniform masks to be represented by a splat VEC_PERM
+        which we can now simply elide as we cannot easily re-do SLP
+        discovery here.  */
+      slp_tree new_mask = SLP_TREE_CHILDREN (mask)[0];
+      SLP_TREE_REF_COUNT (new_mask)++;
+      SLP_TREE_CHILDREN (node)[0] = new_mask;
+      vect_free_slp_tree (mask);
+    }
+}
+
 /* Main entry point for the SLP graph optimization pass.  */
 
 void
@@ -7110,6 +7244,7 @@ vect_optimize_slp_pass::run ()
     }
   else
     remove_redundant_permutations ();
+  decide_masked_load_lanes ();
   free_graph (m_slpg);
 }
 
-- 
2.43.0
