From: Ju-Zhe Zhong <juzhe.zh...@rivai.ai>

gcc/ChangeLog:

        * internal-fn.cc (expand_partial_store_optab_fn): Adapt for
        LEN_MASK_STORE.
        (internal_load_fn_p): Add LEN_MASK_LOAD.
        (internal_store_fn_p): Add LEN_MASK_STORE.
        (internal_fn_mask_index): Add LEN_MASK_{LOAD,STORE}.
        (internal_fn_stored_value_index): Add LEN_MASK_STORE.
        (internal_len_load_store_bias): Add LEN_MASK_{LOAD,STORE}.
        * optabs-tree.cc (can_vec_mask_load_store_p): Adapt for
        LEN_MASK_{LOAD,STORE}.
        (get_len_load_store_mode): Ditto.
        * optabs-tree.h (can_vec_mask_load_store_p): Ditto.
        (get_len_load_store_mode): Ditto.
        * tree-vect-stmts.cc (check_load_store_for_partial_vectors): Prefer
        length-controlled partial vectors over masks.
        (get_all_ones_mask): New function.
        (vectorizable_store): Use LEN_MASK_{LOAD,STORE} where supported.
        (vectorizable_load): Ditto.

---
 gcc/internal-fn.cc     |  36 ++++++-
 gcc/optabs-tree.cc     |  60 +++++++++--
 gcc/optabs-tree.h      |   6 +-
 gcc/tree-vect-stmts.cc | 220 +++++++++++++++++++++++++++++------------
 4 files changed, 243 insertions(+), 79 deletions(-)

diff --git a/gcc/internal-fn.cc b/gcc/internal-fn.cc
index c911ae790cb..b90bd85df2c 100644
--- a/gcc/internal-fn.cc
+++ b/gcc/internal-fn.cc
@@ -2949,7 +2949,7 @@ expand_partial_load_optab_fn (internal_fn, gcall *stmt, convert_optab optab)
  * OPTAB.  */
 
 static void
-expand_partial_store_optab_fn (internal_fn, gcall *stmt, convert_optab optab)
+expand_partial_store_optab_fn (internal_fn ifn, gcall *stmt, convert_optab optab)
 {
   class expand_operand ops[5];
   tree type, lhs, rhs, maskt, biast;
@@ -2957,7 +2957,7 @@ expand_partial_store_optab_fn (internal_fn, gcall *stmt, convert_optab optab)
   insn_code icode;
 
   maskt = gimple_call_arg (stmt, 2);
-  rhs = gimple_call_arg (stmt, 3);
+  rhs = gimple_call_arg (stmt, internal_fn_stored_value_index (ifn));
   type = TREE_TYPE (rhs);
   lhs = expand_call_mem_ref (type, stmt, 0);
 
@@ -4435,6 +4435,7 @@ internal_load_fn_p (internal_fn fn)
     case IFN_GATHER_LOAD:
     case IFN_MASK_GATHER_LOAD:
     case IFN_LEN_LOAD:
+    case IFN_LEN_MASK_LOAD:
       return true;
 
     default:
@@ -4455,6 +4456,7 @@ internal_store_fn_p (internal_fn fn)
     case IFN_SCATTER_STORE:
     case IFN_MASK_SCATTER_STORE:
     case IFN_LEN_STORE:
+    case IFN_LEN_MASK_STORE:
       return true;
 
     default:
@@ -4498,6 +4500,10 @@ internal_fn_mask_index (internal_fn fn)
     case IFN_MASK_SCATTER_STORE:
       return 4;
 
+    case IFN_LEN_MASK_LOAD:
+    case IFN_LEN_MASK_STORE:
+      return 3;
+
     default:
       return (conditional_internal_fn_code (fn) != ERROR_MARK
               || get_unconditional_internal_fn (fn) != IFN_LAST ? 0 : -1);
@@ -4519,6 +4525,9 @@ internal_fn_stored_value_index (internal_fn fn)
     case IFN_LEN_STORE:
       return 3;
 
+    case IFN_LEN_MASK_STORE:
+      return 4;
+
     default:
       return -1;
     }
@@ -4583,13 +4592,29 @@ internal_len_load_store_bias (internal_fn ifn, machine_mode mode)
 {
   optab optab = direct_internal_fn_optab (ifn);
   insn_code icode = direct_optab_handler (optab, mode);
+  int bias_opno = 3;
+
+  /* If the target has no LEN_{LOAD,STORE} pattern for MODE, fall back to
+     the corresponding LEN_MASK_{LOAD,STORE} pattern, whose bias is
+     operand 4 rather than operand 3.  That pattern is looked up by
+     MODE's mask mode, so skip the fallback if MODE has none.  */
+  machine_mode mask_mode;
+  if (icode == CODE_FOR_nothing
+      && targetm.vectorize.get_mask_mode (mode).exists (&mask_mode))
+    {
+      internal_fn len_mask_ifn
+       = (ifn == IFN_LEN_LOAD ? IFN_LEN_MASK_LOAD : IFN_LEN_MASK_STORE);
+      optab = direct_internal_fn_optab (len_mask_ifn);
+      icode = convert_optab_handler (optab, mode, mask_mode);
+      bias_opno = 4;
+    }
 
   if (icode != CODE_FOR_nothing)
     {
       /* For now we only support biases of 0 or -1.  Try both of them.  */
-      if (insn_operand_matches (icode, 3, GEN_INT (0)))
+      if (insn_operand_matches (icode, bias_opno, GEN_INT (0)))
        return 0;
-      if (insn_operand_matches (icode, 3, GEN_INT (-1)))
+      if (insn_operand_matches (icode, bias_opno, GEN_INT (-1)))
        return -1;
     }
 
diff --git a/gcc/optabs-tree.cc b/gcc/optabs-tree.cc
index 77bf745ae40..e90e1b62ebc 100644
--- a/gcc/optabs-tree.cc
+++ b/gcc/optabs-tree.cc
@@ -548,14 +548,33 @@ target_supports_op_p (tree type, enum tree_code code,
 bool
 can_vec_mask_load_store_p (machine_mode mode,
                           machine_mode mask_mode,
-                          bool is_load)
+                          bool is_load,
+                          internal_fn *ifn)
 {
   optab op = is_load ? maskload_optab : maskstore_optab;
+  optab len_op = is_load ? len_maskload_optab : len_maskstore_optab;
   machine_mode vmode;
 
   /* If mode is vector mode, check it directly.  */
   if (VECTOR_MODE_P (mode))
-    return convert_optab_handler (op, mode, mask_mode) != CODE_FOR_nothing;
+    {
+      /* Prefer MASK_{LOAD,STORE} over LEN_MASK_{LOAD,STORE}, and report
+        the chosen internal function through IFN if it is nonnull.  */
+      if (convert_optab_handler (op, mode, mask_mode) != CODE_FOR_nothing)
+       {
+         if (ifn)
+           *ifn = is_load ? IFN_MASK_LOAD : IFN_MASK_STORE;
+         return true;
+       }
+      else if (convert_optab_handler (len_op, mode, mask_mode)
+              != CODE_FOR_nothing)
+       {
+         if (ifn)
+           *ifn = is_load ? IFN_LEN_MASK_LOAD : IFN_LEN_MASK_STORE;
+         return true;
+       }
+      return false;
+    }
 
   /* Otherwise, return true if there is some vector mode with
      the mask load/store supported.  */
@@ -569,7 +588,9 @@ can_vec_mask_load_store_p (machine_mode mode,
   vmode = targetm.vectorize.preferred_simd_mode (smode);
   if (VECTOR_MODE_P (vmode)
       && targetm.vectorize.get_mask_mode (vmode).exists (&mask_mode)
-      && convert_optab_handler (op, vmode, mask_mode) != CODE_FOR_nothing)
+      && (convert_optab_handler (op, vmode, mask_mode) != CODE_FOR_nothing
+         || convert_optab_handler (len_op, vmode, mask_mode)
+              != CODE_FOR_nothing))
     return true;
 
   auto_vector_modes vector_modes;
@@ -577,7 +598,9 @@ can_vec_mask_load_store_p (machine_mode mode,
   for (machine_mode base_mode : vector_modes)
     if (related_vector_mode (base_mode, smode).exists (&vmode)
        && targetm.vectorize.get_mask_mode (vmode).exists (&mask_mode)
-       && convert_optab_handler (op, vmode, mask_mode) != CODE_FOR_nothing)
+       && (convert_optab_handler (op, vmode, mask_mode) != CODE_FOR_nothing
+           || convert_optab_handler (len_op, vmode, mask_mode)
+                != CODE_FOR_nothing))
       return true;
   return false;
 }
@@ -590,21 +613,37 @@ can_vec_mask_load_store_p (machine_mode mode,
    VnQI to wrap the other supportable same size vector modes.  */
 
 opt_machine_mode
-get_len_load_store_mode (machine_mode mode, bool is_load)
+get_len_load_store_mode (machine_mode mode, bool is_load, internal_fn *ifn)
 {
   optab op = is_load ? len_load_optab : len_store_optab;
+  optab masked_op = is_load ? len_maskload_optab : len_maskstore_optab;
   gcc_assert (VECTOR_MODE_P (mode));
+  /* Default IFN to IFN_LEN_{LOAD,STORE}, which is also what the VnQI
+     fallback below uses; it becomes IFN_LEN_MASK_{LOAD,STORE} only when
+     MODE itself has a len_maskload/store pattern.  */
+  if (ifn)
+    *ifn = is_load ? IFN_LEN_LOAD : IFN_LEN_STORE;
 
   /* Check if length in lanes supported for this mode directly.  */
   if (direct_optab_handler (op, mode))
     return mode;
 
+  /* Check if length in lanes supported by len_maskload/store.  */
+  machine_mode mask_mode;
+  if (targetm.vectorize.get_mask_mode (mode).exists (&mask_mode)
+      && convert_optab_handler (masked_op, mode, mask_mode) != CODE_FOR_nothing)
+    {
+      if (ifn)
+       *ifn = is_load ? IFN_LEN_MASK_LOAD : IFN_LEN_MASK_STORE;
+      return mode;
+    }
+
   /* Check if length in bytes supported for same vector size VnQI.  */
   machine_mode vmode;
   poly_uint64 nunits = GET_MODE_SIZE (mode);
   if (related_vector_mode (mode, QImode, nunits).exists (&vmode)
       && direct_optab_handler (op, vmode))
     return vmode;
 
   return opt_machine_mode ();
 }
diff --git a/gcc/optabs-tree.h b/gcc/optabs-tree.h
index a3f79b6bd43..e421fc24289 100644
--- a/gcc/optabs-tree.h
+++ b/gcc/optabs-tree.h
@@ -47,7 +47,9 @@ bool expand_vec_cond_expr_p (tree, tree, enum tree_code);
 void init_tree_optimization_optabs (tree);
 bool target_supports_op_p (tree, enum tree_code,
                           enum optab_subtype = optab_default);
-bool can_vec_mask_load_store_p (machine_mode, machine_mode, bool);
-opt_machine_mode get_len_load_store_mode (machine_mode, bool);
+bool can_vec_mask_load_store_p (machine_mode, machine_mode, bool,
+                               internal_fn * = nullptr);
+opt_machine_mode get_len_load_store_mode (machine_mode, bool,
+                                         internal_fn * = nullptr);
 
 #endif
diff --git a/gcc/tree-vect-stmts.cc b/gcc/tree-vect-stmts.cc
index 056a0ecb2be..d53c4e4f2e5 100644
--- a/gcc/tree-vect-stmts.cc
+++ b/gcc/tree-vect-stmts.cc
@@ -1819,16 +1819,10 @@ check_load_store_for_partial_vectors (loop_vec_info loop_vinfo, tree vectype,
   poly_uint64 nunits = TYPE_VECTOR_SUBPARTS (vectype);
   poly_uint64 vf = LOOP_VINFO_VECT_FACTOR (loop_vinfo);
   machine_mode mask_mode;
-  bool using_partial_vectors_p = false;
-  if (targetm.vectorize.get_mask_mode (vecmode).exists (&mask_mode)
-      && can_vec_mask_load_store_p (vecmode, mask_mode, is_load))
-    {
-      nvectors = group_memory_nvectors (group_size * vf, nunits);
-      vect_record_loop_mask (loop_vinfo, masks, nvectors, vectype, scalar_mask);
-      using_partial_vectors_p = true;
-    }
-
   machine_mode vmode;
+  bool using_partial_vectors_p = false;
+  /* Prefer length-controlled partial vectors; fall back to masks only
+     when the target cannot control this access by length.  */
   if (get_len_load_store_mode (vecmode, is_load).exists (&vmode))
     {
       nvectors = group_memory_nvectors (group_size * vf, nunits);
@@ -1837,6 +1831,13 @@ check_load_store_for_partial_vectors (loop_vec_info loop_vinfo, tree vectype,
       vect_record_loop_len (loop_vinfo, lens, nvectors, vectype, factor);
       using_partial_vectors_p = true;
     }
+  else if (targetm.vectorize.get_mask_mode (vecmode).exists (&mask_mode)
+          && can_vec_mask_load_store_p (vecmode, mask_mode, is_load))
+    {
+      nvectors = group_memory_nvectors (group_size * vf, nunits);
+      vect_record_loop_mask (loop_vinfo, masks, nvectors, vectype, scalar_mask);
+      using_partial_vectors_p = true;
+    }
 
   if (!using_partial_vectors_p)
     {
@@ -3175,6 +3174,17 @@ vect_get_loop_variant_data_ptr_increment (
   return bump;
 }
 
+/* Return an all-ones mask with one element per lane of vector mode VMODE.  */
+
+static tree
+get_all_ones_mask (machine_mode vmode)
+{
+  machine_mode maskmode = targetm.vectorize.get_mask_mode (vmode).require ();
+  poly_uint64 nunits = GET_MODE_NUNITS (vmode);
+  tree masktype = build_truth_vector_type_for_mode (nunits, maskmode);
+  return constant_boolean_node (true, masktype);
+}
+
 /* Return the amount that should be added to a vector pointer to move
    to the next or previous copy of AGGR_TYPE.  DR_INFO is the data reference
    being vectorized and MEMORY_ACCESS_TYPE describes the type of
@@ -8944,30 +8954,66 @@ vectorizable_store (vec_info *vinfo,
                  vec_oprnd = new_temp;
                }
 
-             /* Arguments are ready.  Create the new vector stmt.  */
-             if (final_mask)
-               {
-                 tree ptr = build_int_cst (ref_type, align * BITS_PER_UNIT);
-                 gcall *call
-                   = gimple_build_call_internal (IFN_MASK_STORE, 4,
-                                                 dataref_ptr, ptr,
-                                                 final_mask, vec_oprnd);
-                 gimple_call_set_nothrow (call, true);
-                 vect_finish_stmt_generation (vinfo, stmt_info, call, gsi);
-                 new_stmt = call;
-               }
-             else if (loop_lens)
+             /* Decide whether LOOP_LENS and/or FINAL_MASK require a partial
+                store and, if so, which internal function implements it.  */
+             machine_mode vmode = TYPE_MODE (vectype);
+             machine_mode new_vmode = vmode;
+             internal_fn partial_ifn = IFN_LAST;
+             /* Produce 'len' and 'bias' argument.  */
+             tree final_len = NULL_TREE;
+             tree bias = NULL_TREE;
+             if (loop_lens)
                {
-                 machine_mode vmode = TYPE_MODE (vectype);
                  opt_machine_mode new_ovmode
-                   = get_len_load_store_mode (vmode, false);
-                 machine_mode new_vmode = new_ovmode.require ();
+                   = get_len_load_store_mode (vmode, false, &partial_ifn);
+                 new_vmode = new_ovmode.require ();
                  unsigned factor
                    = (new_ovmode == vmode) ? 1 : GET_MODE_UNIT_SIZE (vmode);
-                 tree final_len
-                   = vect_get_loop_len (loop_vinfo, gsi, loop_lens,
-                                        vec_num * ncopies, vectype,
-                                        vec_num * j + i, factor);
+                 final_len = vect_get_loop_len (loop_vinfo, gsi, loop_lens,
+                                                vec_num * ncopies, vectype,
+                                                vec_num * j + i, factor);
+               }
+             else if (final_mask)
+               {
+                 /* Take the mask mode from FINAL_MASK itself: it may be
+                    a loop mask, in which case MASK_VECTYPE may be unset.  */
+                 machine_mode final_mask_mode
+                   = TYPE_MODE (TREE_TYPE (final_mask));
+                 can_vec_mask_load_store_p (vmode, final_mask_mode, false,
+                                            &partial_ifn);
+               }
+
+             if (partial_ifn == IFN_LEN_MASK_STORE)
+               {
+                 if (!final_len)
+                   {
+                     /* Without LOOP_LENS, pass the full lane count as
+                        'len' so that only the mask controls the store.
+                        The rgroup IV type is presumably not set up in
+                        that case, so use sizetype.  */
+                     final_len
+                       = build_int_cst (sizetype,
+                                        TYPE_VECTOR_SUBPARTS (vectype));
+                   }
+                 if (!final_mask)
+                   {
+                     /* Pass an all-ones 'mask' so that only the length
+                        controls the store.  */
+                     final_mask = get_all_ones_mask (vmode);
+                   }
+               }
+             if (final_len)
+               {
+                 signed char biasval
+                   = LOOP_VINFO_PARTIAL_LOAD_STORE_BIAS (loop_vinfo);
+
+                 bias = build_int_cst (intQI_type_node, biasval);
+               }
+
+             /* Arguments are ready.  Create the new vector stmt.  */
+             if (final_len)
+               {
+                 gcall *call;
                  tree ptr = build_int_cst (ref_type, align * BITS_PER_UNIT);
                  /* Need conversion if it's wrapped with VnQI.  */
                  if (vmode != new_vmode)
@@ -8987,14 +9033,27 @@ vectorizable_store (vec_info *vinfo,
                      vec_oprnd = var;
                    }
 
-                 signed char biasval =
-                   LOOP_VINFO_PARTIAL_LOAD_STORE_BIAS (loop_vinfo);
-
-                 tree bias = build_int_cst (intQI_type_node, biasval);
+                 if (partial_ifn == IFN_LEN_MASK_STORE)
+                   call = gimple_build_call_internal (IFN_LEN_MASK_STORE, 6,
+                                                      dataref_ptr, ptr,
+                                                      final_len, final_mask,
+                                                      vec_oprnd, bias);
+                 else
+                   call
+                     = gimple_build_call_internal (IFN_LEN_STORE, 5,
+                                                   dataref_ptr, ptr, final_len,
+                                                   vec_oprnd, bias);
+                 gimple_call_set_nothrow (call, true);
+                 vect_finish_stmt_generation (vinfo, stmt_info, call, gsi);
+                 new_stmt = call;
+               }
+             else if (final_mask)
+               {
+                 tree ptr = build_int_cst (ref_type, align * BITS_PER_UNIT);
                  gcall *call
-                   = gimple_build_call_internal (IFN_LEN_STORE, 5, dataref_ptr,
-                                                 ptr, final_len, vec_oprnd,
-                                                 bias);
+                   = gimple_build_call_internal (IFN_MASK_STORE, 4,
+                                                 dataref_ptr, ptr,
+                                                 final_mask, vec_oprnd);
                  gimple_call_set_nothrow (call, true);
                  vect_finish_stmt_generation (vinfo, stmt_info, call, gsi);
                  new_stmt = call;
@@ -10304,45 +10355,82 @@ vectorizable_load (vec_info *vinfo,
                                              align, misalign);
                    align = least_bit_hwi (misalign | align);
 
-                   if (final_mask)
-                     {
-                       tree ptr = build_int_cst (ref_type,
-                                                 align * BITS_PER_UNIT);
-                       gcall *call
-                         = gimple_build_call_internal (IFN_MASK_LOAD, 3,
-                                                       dataref_ptr, ptr,
-                                                       final_mask);
-                       gimple_call_set_nothrow (call, true);
-                       new_stmt = call;
-                       data_ref = NULL_TREE;
-                     }
-                   else if (loop_lens && memory_access_type != VMAT_INVARIANT)
+                   /* Decide whether LOOP_LENS and/or FINAL_MASK require a
+                      partial load and, if so, which internal function
+                      implements it.  */
+                   machine_mode vmode = TYPE_MODE (vectype);
+                   machine_mode new_vmode = vmode;
+                   internal_fn partial_ifn = IFN_LAST;
+                   /* Produce 'len' and 'bias' argument.  */
+                   tree final_len = NULL_TREE;
+                   tree bias = NULL_TREE;
+                   if (loop_lens && memory_access_type != VMAT_INVARIANT)
                      {
-                       machine_mode vmode = TYPE_MODE (vectype);
                        opt_machine_mode new_ovmode
-                         = get_len_load_store_mode (vmode, true);
-                       machine_mode new_vmode = new_ovmode.require ();
+                         = get_len_load_store_mode (vmode, true, &partial_ifn);
+                       new_vmode = new_ovmode.require ();
                        unsigned factor = (new_ovmode == vmode)
                                            ? 1
                                            : GET_MODE_UNIT_SIZE (vmode);
-                       tree final_len
+                       final_len
                          = vect_get_loop_len (loop_vinfo, gsi, loop_lens,
                                               vec_num * ncopies, vectype,
                                               vec_num * j + i, factor);
-                       tree ptr
-                         = build_int_cst (ref_type, align * BITS_PER_UNIT);
-
-                       tree qi_type = unsigned_intQI_type_node;
+                     }
+                   else if (final_mask)
+                     {
+                       /* Take the mask mode from FINAL_MASK itself: it
+                          may be a loop mask, in which case MASK_VECTYPE
+                          may be unset.  */
+                       machine_mode final_mask_mode
+                         = TYPE_MODE (TREE_TYPE (final_mask));
+                       can_vec_mask_load_store_p (vmode, final_mask_mode,
+                                                  true, &partial_ifn);
+                     }
 
-                       signed char biasval =
-                         LOOP_VINFO_PARTIAL_LOAD_STORE_BIAS (loop_vinfo);
+                   if (partial_ifn == IFN_LEN_MASK_LOAD)
+                     {
+                       if (!final_len)
+                         {
+                           /* Without LOOP_LENS, pass the full lane count
+                              as 'len' so that only the mask controls the
+                              load.  The rgroup IV type is presumably not
+                              set up in that case, so use sizetype.  */
+                           final_len
+                             = build_int_cst (sizetype,
+                                              TYPE_VECTOR_SUBPARTS (vectype));
+                         }
+                       if (!final_mask)
+                         {
+                           /* Pass an all-ones 'mask' so that only the
+                              length controls the load.  */
+                           final_mask = get_all_ones_mask (vmode);
+                         }
+                     }
+                   if (final_len)
+                     {
+                       signed char biasval
+                         = LOOP_VINFO_PARTIAL_LOAD_STORE_BIAS (loop_vinfo);
 
-                       tree bias = build_int_cst (intQI_type_node, biasval);
+                       bias = build_int_cst (intQI_type_node, biasval);
+                     }
 
-                       gcall *call
-                         = gimple_build_call_internal (IFN_LEN_LOAD, 4,
-                                                       dataref_ptr, ptr,
-                                                       final_len, bias);
+                   /* A VMAT_INVARIANT access only gets a length here when
+                      it is masked and the target uses LEN_MASK_LOAD.  */
+                   if (final_len)
+                     {
+                       tree ptr
+                         = build_int_cst (ref_type, align * BITS_PER_UNIT);
+                       gcall *call;
+                       if (partial_ifn == IFN_LEN_MASK_LOAD)
+                         call = gimple_build_call_internal (IFN_LEN_MASK_LOAD,
+                                                            5, dataref_ptr,
+                                                            ptr, final_len,
+                                                            final_mask, bias);
+                       else
+                         call = gimple_build_call_internal (IFN_LEN_LOAD, 4,
+                                                            dataref_ptr, ptr,
+                                                            final_len, bias);
                        gimple_call_set_nothrow (call, true);
                        new_stmt = call;
                        data_ref = NULL_TREE;
@@ -10350,8 +10438,8 @@ vectorizable_load (vec_info *vinfo,
                        /* Need conversion if it's wrapped with VnQI.  */
                        if (vmode != new_vmode)
                          {
-                           tree new_vtype
-                             = build_vector_type_for_mode (qi_type, new_vmode);
+                           tree new_vtype = build_vector_type_for_mode
+                             (unsigned_intQI_type_node, new_vmode);
                            tree var = vect_get_new_ssa_name (new_vtype,
                                                              vect_simple_var);
                            gimple_set_lhs (call, var);
@@ -10363,6 +10451,18 @@ vectorizable_load (vec_info *vinfo,
                                                     VIEW_CONVERT_EXPR, op);
                          }
                      }
+                   else if (final_mask)
+                     {
+                       tree ptr = build_int_cst (ref_type,
+                                                 align * BITS_PER_UNIT);
+                       gcall *call
+                         = gimple_build_call_internal (IFN_MASK_LOAD, 3,
+                                                       dataref_ptr, ptr,
+                                                       final_mask);
+                       gimple_call_set_nothrow (call, true);
+                       new_stmt = call;
+                       data_ref = NULL_TREE;
+                     }
                    else
                      {
                        tree ltype = vectype;
-- 
2.36.1

Reply via email to