DavidSpickett updated this revision to Diff 557007.
DavidSpickett added a comment.

Rebase


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D159502/new/

https://reviews.llvm.org/D159502

Files:
  lldb/include/lldb/Utility/RegisterValue.h
  lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.cpp
  lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.h
  lldb/source/Plugins/Process/Utility/LinuxPTraceDefines_arm64sve.h
  lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.cpp
  lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.h
  lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.cpp
  lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.h
  lldb/source/Plugins/Process/elf-core/RegisterContextPOSIXCore_arm64.cpp
  lldb/source/Plugins/Process/elf-core/RegisterUtilities.h

Index: lldb/source/Plugins/Process/elf-core/RegisterUtilities.h
===================================================================
--- lldb/source/Plugins/Process/elf-core/RegisterUtilities.h
+++ lldb/source/Plugins/Process/elf-core/RegisterUtilities.h
@@ -119,6 +119,10 @@
     {llvm::Triple::Linux, llvm::Triple::aarch64, llvm::ELF::NT_ARM_SVE},
 };
 
+constexpr RegsetDesc AARCH64_ZA_Desc[] = {
+    {llvm::Triple::Linux, llvm::Triple::aarch64, llvm::ELF::NT_ARM_ZA},
+};
+
 constexpr RegsetDesc AARCH64_PAC_Desc[] = {
     {llvm::Triple::Linux, llvm::Triple::aarch64, llvm::ELF::NT_ARM_PAC_MASK},
 };
Index: lldb/source/Plugins/Process/elf-core/RegisterContextPOSIXCore_arm64.cpp
===================================================================
--- lldb/source/Plugins/Process/elf-core/RegisterContextPOSIXCore_arm64.cpp
+++ lldb/source/Plugins/Process/elf-core/RegisterContextPOSIXCore_arm64.cpp
@@ -113,7 +113,7 @@
     m_sve_state = SVEState::Disabled;
 
   if (m_sve_state != SVEState::Disabled)
-    m_register_info_up->ConfigureVectorLength(
+    m_register_info_up->ConfigureVectorLengthSVE(
         sve::vq_from_vl(m_sve_vector_length));
 }
 
Index: lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.h
===================================================================
--- lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.h
+++ lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.h
@@ -30,6 +30,7 @@
     eRegsetMaskPAuth = 4,
     eRegsetMaskMTE = 8,
     eRegsetMaskTLS = 16,
+    eRegsetMaskZA = 32,
     eRegsetMaskDynamic = ~1,
   };
 
@@ -106,7 +107,11 @@
 
   void AddRegSetTLS(bool has_tpidr2);
 
-  uint32_t ConfigureVectorLength(uint32_t sve_vq);
+  void AddRegSetSME();
+
+  uint32_t ConfigureVectorLengthSVE(uint32_t sve_vq);
+
+  void ConfigureVectorLengthZA(uint32_t za_vq);
 
   bool VectorSizeIsValid(uint32_t vq) {
     // coverity[unsigned_compare]
@@ -117,6 +122,7 @@
 
   bool IsSVEEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskSVE); }
   bool IsSSVEEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskSSVE); }
+  bool IsZAEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskZA); }
   bool IsPAuthEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskPAuth); }
   bool IsMTEEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskMTE); }
   bool IsTLSEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskTLS); }
@@ -128,6 +134,7 @@
   bool IsPAuthReg(unsigned reg) const;
   bool IsMTEReg(unsigned reg) const;
   bool IsTLSReg(unsigned reg) const;
+  bool IsSMEReg(unsigned reg) const;
 
   uint32_t GetRegNumSVEZ0() const;
   uint32_t GetRegNumSVEFFR() const;
@@ -137,6 +144,7 @@
   uint32_t GetPAuthOffset() const;
   uint32_t GetMTEOffset() const;
   uint32_t GetTLSOffset() const;
+  uint32_t GetSMEOffset() const;
 
 private:
   typedef std::map<uint32_t, std::vector<lldb_private::RegisterInfo>>
@@ -145,7 +153,10 @@
   per_vq_register_infos m_per_vq_reg_infos;
 
   uint32_t m_vector_reg_vq = eVectorQuadwordAArch64;
+  uint32_t m_za_reg_vq = eVectorQuadwordAArch64;
 
+  // In normal operation this is const. Only when SVE or SME registers change
+  // size is it either replaced or the content modified.
   const lldb_private::RegisterInfo *m_register_info_p;
   uint32_t m_register_info_count;
 
@@ -164,6 +175,7 @@
   std::vector<uint32_t> pauth_regnum_collection;
   std::vector<uint32_t> m_mte_regnum_collection;
   std::vector<uint32_t> m_tls_regnum_collection;
+  std::vector<uint32_t> m_sme_regnum_collection;
 };
 
 #endif
Index: lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.cpp
===================================================================
--- lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.cpp
+++ lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.cpp
@@ -83,6 +83,11 @@
     // Only present when SME is present
     DEFINE_EXTENSION_REG(tpidr2)};
 
+static lldb_private::RegisterInfo g_register_infos_sme[] =
+    // 16 is a default size we will change later.
+    {{"za", nullptr, 16, 0, lldb::eEncodingVector, lldb::eFormatVectorOfUInt8,
+      KIND_ALL_INVALID, nullptr, nullptr, nullptr}};
+
 // Number of register sets provided by this context.
 enum {
   k_num_gpr_registers = gpr_w28 - gpr_x0 + 1,
@@ -91,6 +96,7 @@
   k_num_mte_register = 1,
   // Number of TLS registers is dynamic so it is not listed here.
   k_num_pauth_register = 2,
+  k_num_sme_register = 1,
   k_num_register_sets_default = 2,
   k_num_register_sets = 3
 };
@@ -197,6 +203,9 @@
 
 // The size of the TLS set is dynamic, so not listed here.
 
+static const lldb_private::RegisterSet g_reg_set_sme_arm64 = {
+    "Scalable Matrix Extension Registers", "sme", k_num_sme_register, nullptr};
+
 RegisterInfoPOSIX_arm64::RegisterInfoPOSIX_arm64(
     const lldb_private::ArchSpec &target_arch, lldb_private::Flags opt_regsets)
     : lldb_private::RegisterInfoAndSetInterface(target_arch),
@@ -241,6 +250,9 @@
       // present.
       AddRegSetTLS(m_opt_regsets.AllSet(eRegsetMaskSSVE));
 
+      if (m_opt_regsets.AnySet(eRegsetMaskSSVE))
+        AddRegSetSME();
+
       m_register_info_count = m_dynamic_reg_infos.size();
       m_register_info_p = m_dynamic_reg_infos.data();
       m_register_set_p = m_dynamic_reg_sets.data();
@@ -344,7 +356,25 @@
   m_dynamic_reg_sets.back().registers = m_tls_regnum_collection.data();
 }
 
-uint32_t RegisterInfoPOSIX_arm64::ConfigureVectorLength(uint32_t sve_vq) {
+void RegisterInfoPOSIX_arm64::AddRegSetSME() {
+  uint32_t sme_regnum = m_dynamic_reg_infos.size();
+  for (uint32_t i = 0; i < k_num_sme_register; i++) {
+    m_sme_regnum_collection.push_back(sme_regnum + i);
+    m_dynamic_reg_infos.push_back(g_register_infos_sme[i]);
+    m_dynamic_reg_infos[sme_regnum + i].byte_offset =
+        m_dynamic_reg_infos[sme_regnum + i - 1].byte_offset +
+        m_dynamic_reg_infos[sme_regnum + i - 1].byte_size;
+    m_dynamic_reg_infos[sme_regnum + i].kinds[lldb::eRegisterKindLLDB] =
+        sme_regnum + i;
+  }
+
+  m_per_regset_regnum_range[m_register_set_count] =
+      std::make_pair(sme_regnum, m_dynamic_reg_infos.size());
+  m_dynamic_reg_sets.push_back(g_reg_set_sme_arm64);
+  m_dynamic_reg_sets.back().registers = m_sme_regnum_collection.data();
+}
+
+uint32_t RegisterInfoPOSIX_arm64::ConfigureVectorLengthSVE(uint32_t sve_vq) {
   // sve_vq contains SVE Quad vector length in context of AArch64 SVE.
   // SVE register infos if enabled cannot be disabled by selecting sve_vq = 0.
   // Also if an invalid or previously set vector length is passed to this
@@ -408,6 +438,20 @@
   return m_vector_reg_vq;
 }
 
+void RegisterInfoPOSIX_arm64::ConfigureVectorLengthZA(uint32_t za_vq) {
+  if (!VectorSizeIsValid(za_vq) || m_za_reg_vq == za_vq)
+    return;
+
+  m_za_reg_vq = za_vq;
+
+  // For SVE changes, we replace m_register_info_p completely. ZA is in a
+  // dynamic set and is just 1 register so we make an exception to const here.
+  lldb_private::RegisterInfo *non_const_reginfo =
+      const_cast<lldb_private::RegisterInfo *>(m_register_info_p);
+  non_const_reginfo[m_sme_regnum_collection[0]].byte_size =
+      (za_vq * 16) * (za_vq * 16);
+}
+
 bool RegisterInfoPOSIX_arm64::IsSVEReg(unsigned reg) const {
   if (m_vector_reg_vq > eVectorQuadwordAArch64)
     return (sve_vg <= reg && reg <= sve_ffr);
@@ -439,6 +483,10 @@
   return llvm::is_contained(m_tls_regnum_collection, reg);
 }
 
+bool RegisterInfoPOSIX_arm64::IsSMEReg(unsigned reg) const {
+  return llvm::is_contained(m_sme_regnum_collection, reg);
+}
+
 uint32_t RegisterInfoPOSIX_arm64::GetRegNumSVEZ0() const { return sve_z0; }
 
 uint32_t RegisterInfoPOSIX_arm64::GetRegNumSVEFFR() const { return sve_ffr; }
@@ -460,3 +508,7 @@
 uint32_t RegisterInfoPOSIX_arm64::GetTLSOffset() const {
   return m_register_info_p[m_tls_regnum_collection[0]].byte_offset;
 }
+
+uint32_t RegisterInfoPOSIX_arm64::GetSMEOffset() const {
+  return m_register_info_p[m_sme_regnum_collection[0]].byte_offset;
+}
Index: lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.h
===================================================================
--- lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.h
+++ lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.h
@@ -54,6 +54,7 @@
   size_t GetFPUSize() { return sizeof(RegisterInfoPOSIX_arm64::FPU); }
 
   bool IsSVE(unsigned reg) const;
+  bool IsSME(unsigned reg) const;
   bool IsPAuth(unsigned reg) const;
   bool IsTLS(unsigned reg) const;
 
Index: lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.cpp
===================================================================
--- lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.cpp
+++ lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.cpp
@@ -43,6 +43,10 @@
   return m_register_info_up->IsSVEReg(reg);
 }
 
+bool RegisterContextPOSIX_arm64::IsSME(unsigned reg) const {
+  return m_register_info_up->IsSMEReg(reg);
+}
+
 bool RegisterContextPOSIX_arm64::IsPAuth(unsigned reg) const {
   return m_register_info_up->IsPAuthReg(reg);
 }
Index: lldb/source/Plugins/Process/Utility/LinuxPTraceDefines_arm64sve.h
===================================================================
--- lldb/source/Plugins/Process/Utility/LinuxPTraceDefines_arm64sve.h
+++ lldb/source/Plugins/Process/Utility/LinuxPTraceDefines_arm64sve.h
@@ -152,6 +152,8 @@
   uint16_t reserved;
 };
 
+using user_za_header = user_sve_header;
+
 /* Definitions for user_sve_header.flags: */
 const uint16_t ptrace_regs_mask = 1 << 0;
 const uint16_t ptrace_regs_fpsimd = 0;
Index: lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.h
===================================================================
--- lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.h
+++ lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.h
@@ -85,6 +85,8 @@
   bool m_mte_ctrl_is_valid;
 
   bool m_sve_header_is_valid;
+  bool m_za_buffer_is_valid;
+  bool m_za_header_is_valid;
   bool m_pac_mask_is_valid;
   bool m_tls_is_valid;
   size_t m_tls_size;
@@ -98,6 +100,9 @@
   struct sve::user_sve_header m_sve_header;
   std::vector<uint8_t> m_sve_ptrace_payload;
 
+  sve::user_za_header m_za_header;
+  std::vector<uint8_t> m_za_ptrace_payload;
+
   bool m_refresh_hwdebug_info;
 
   struct user_pac_mask {
@@ -139,7 +144,18 @@
 
   Status WriteTLS();
 
+  Status ReadZAHeader();
+
+  Status ReadZA();
+
+  Status WriteZA();
+
+  // No WriteZAHeader because writing only the header will disable ZA.
+  // Instead use WriteZA and ensure you have the correct ZA buffer size set
+  // beforehand if you wish to disable it.
+
   bool IsSVE(unsigned reg) const;
+  bool IsSME(unsigned reg) const;
   bool IsPAuth(unsigned reg) const;
   bool IsMTE(unsigned reg) const;
   bool IsTLS(unsigned reg) const;
@@ -150,6 +166,10 @@
 
   void *GetSVEHeader() { return &m_sve_header; }
 
+  void *GetZAHeader() { return &m_za_header; }
+
+  size_t GetZAHeaderSize() { return sizeof(m_za_header); }
+
   void *GetPACMask() { return &m_pac_mask; }
 
   void *GetMTEControl() { return &m_mte_ctrl_reg; }
@@ -166,6 +186,10 @@
 
   unsigned GetSVERegSet();
 
+  void *GetZABuffer() { return m_za_ptrace_payload.data(); };
+
+  size_t GetZABufferSize() { return m_za_ptrace_payload.size(); }
+
   size_t GetMTEControlSize() { return sizeof(m_mte_ctrl_reg); }
 
   size_t GetTLSBufferSize() { return m_tls_size; }
Index: lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.cpp
===================================================================
--- lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.cpp
+++ lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.cpp
@@ -41,6 +41,10 @@
   0x40b /* ARM Scalable Matrix Extension, Streaming SVE mode */
 #endif
 
+#ifndef NT_ARM_ZA
+#define NT_ARM_ZA 0x40c /* ARM Scalable Matrix Extension, Array Storage */
+#endif
+
 #ifndef NT_ARM_PAC_MASK
 #define NT_ARM_PAC_MASK 0x406 /* Pointer authentication code masks */
 #endif
@@ -90,6 +94,16 @@
         opt_regsets.Set(RegisterInfoPOSIX_arm64::eRegsetMaskSSVE);
     }
 
+    sve::user_za_header za_header;
+    ioVec.iov_base = &za_header;
+    ioVec.iov_len = sizeof(za_header);
+    regset = NT_ARM_ZA;
+    if (NativeProcessLinux::PtraceWrapper(PTRACE_GETREGSET,
+                                          native_thread.GetID(), &regset,
+                                          &ioVec, sizeof(za_header))
+            .Success())
+      opt_regsets.Set(RegisterInfoPOSIX_arm64::eRegsetMaskZA);
+
     NativeProcessLinux &process = native_thread.GetProcess();
 
     std::optional<uint64_t> auxv_at_hwcap =
@@ -314,6 +328,31 @@
     offset = reg_info->byte_offset - GetRegisterInfo().GetMTEOffset();
     assert(offset < GetMTEControlSize());
     src = (uint8_t *)GetMTEControl() + offset;
+  } else if (IsSME(reg)) {
+    error = ReadZAHeader();
+    if (error.Fail())
+      return error;
+
+    // If there is only a header and no registers, ZA is inactive. Read as 0
+    // in this case.
+    if (m_za_header.size == sizeof(m_za_header)) {
+      // This will get reconfigured/reset later, so we are safe to use it.
+      // ZA is a square of VL * VL and the ptrace buffer also includes the
+      // header itself.
+      m_za_ptrace_payload.resize(((m_za_header.vl) * (m_za_header.vl)) +
+                                 GetZAHeaderSize());
+      std::fill(m_za_ptrace_payload.begin(), m_za_ptrace_payload.end(), 0);
+    } else {
+      // ZA is active, read the real register.
+      error = ReadZA();
+      if (error.Fail())
+        return error;
+    }
+
+    // ZA is part of the SME set but uses a separate member buffer for storage.
+    // Therefore its effective byte offset is always 0 even if it isn't 0 within
+    // the SME register set.
+    src = (uint8_t *)GetZABuffer() + GetZAHeaderSize();
   } else
     return Status("failed - register wasn't recognized to be a GPR or an FPR, "
                   "write strategy unknown");
@@ -420,8 +459,12 @@
           SetSVERegVG(vg_value);
 
           error = WriteSVEHeader();
-          if (error.Success())
+          if (error.Success()) {
+            // Changing VG during streaming mode also changes the size of ZA.
+            if (m_sve_state == SVEState::Streaming)
+              m_za_header_is_valid = false;
             ConfigureRegisterContext();
+          }
 
           if (m_sve_header_is_valid && vg_value == GetSVERegVG())
             return error;
@@ -494,6 +537,20 @@
     ::memcpy(dst, reg_value.GetBytes(), reg_info->byte_size);
 
     return WriteTLS();
+  } else if (IsSME(reg)) {
+    error = ReadZA();
+    if (error.Fail())
+      return error;
+
+    // ZA is part of the SME set but not stored with the other SME registers.
+    // So its byte offset is effectively always 0.
+    dst = (uint8_t *)GetZABuffer() + GetZAHeaderSize();
+    ::memcpy(dst, reg_value.GetBytes(), reg_info->byte_size);
+
+    // While this is writing a header that contains a vector length, the only
+    // way to change that is via the vg register. So here we assume the length
+    // will always be the current length and no reconfigure is needed.
+    return WriteZA();
   }
 
   return Status("Failed to write register value");
@@ -503,8 +560,10 @@
   GPR,
   SVE, // Used for SVE and SSVE.
   FPR, // When there is no SVE, or SVE in FPSIMD mode.
+  // Pointer authentication registers are read only, so not included here.
   MTE,
   TLS,
+  SME, // ZA only, SVCR and SVG are pseudo registers.
 };
 
 static uint8_t *AddRegisterSetType(uint8_t *dst,
@@ -533,6 +592,24 @@
   if (error.Fail())
     return error;
 
+  // Here this means, does the system have ZA, not whether it is active.
+  if (GetRegisterInfo().IsZAEnabled()) {
+    error = ReadZAHeader();
+    if (error.Fail())
+      return error;
+    // Use header size here because the buffer may contain fake data when ZA is
+    // disabled. We do not want to write this fake data (all 0s) because this
+    // would tell the kernel that we want ZA to become active, which is the
+    // opposite of what we want in the case where it is currently inactive.
+    cached_size += sizeof(RegisterSetType) + m_za_header.size;
+    // For the same reason, we need to force it to be re-read so that it will
+    // always contain the real header.
+    m_za_buffer_is_valid = false;
+    error = ReadZA();
+    if (error.Fail())
+      return error;
+  }
+
   // If SVE is enabled we need not copy FPR separately.
   if (GetRegisterInfo().IsSVEEnabled() || GetRegisterInfo().IsSSVEEnabled()) {
     // Store mode and register data.
@@ -583,6 +660,45 @@
   dst = AddSavedRegisters(dst, RegisterSetType::GPR, GetGPRBuffer(),
                           GetGPRBufferSize());
 
+  // Streaming SVE and the ZA register both use the streaming vector length.
+  // When you change this, the kernel will invalidate parts of the process
+  // state. Therefore we need a specific order of restoration for each mode, if
+  // we also have ZA to restore.
+  //
+  // Streaming mode enabled, ZA enabled:
+  // * Write streaming registers. This sets SVCR.SM and clears SVCR.ZA.
+  // * Write ZA, this sets SVCR.ZA. The register data we provide is written to
+  // ZA.
+  // * Result is SVCR.SM and SVCR.ZA set, with the expected data in both
+  //   register sets.
+  //
+  // Streaming mode disabled, ZA enabled:
+  // * Write ZA. This sets SVCR.ZA, and the ZA content. In the majority of cases
+  //   the streaming vector length is changing, so the thread is converted into
+  //   an FPSIMD thread if it is not already one. This also clears SVCR.SM.
+  // * Write SVE registers, which also clears SVCR.SM but most importantly, puts
+  //   us into full SVE mode instead of FPSIMD mode (where the registers are
+  //   actually the 128 bit Neon registers).
+  // * Result is we have SVCR.SM = 0, SVCR.ZA = 1 and the expected register
+  //   state.
+  //
+  // Restoring in different orders leads to things like the SVE registers being
+  // truncated due to the FPSIMD mode and ZA being disabled or filled with 0s
+  // (disabled and 0s looks the same from inside lldb since we fake the value
+  // when it's disabled).
+  //
+  // For more information on this, look up the uses of the relevant NT_ARM_
+  // constants and the functions vec_set_vector_length, sve_set_common and
+  // za_set in the Linux Kernel.
+
+  if ((m_sve_state != SVEState::Streaming) && GetRegisterInfo().IsZAEnabled()) {
+    // Use the header size not the buffer size, as we may be using the buffer
+    // for fake data, which we do not want to write out.
+    assert(m_za_header.size <= GetZABufferSize());
+    dst = AddSavedRegisters(dst, RegisterSetType::SME, GetZABuffer(),
+                            m_za_header.size);
+  }
+
   if (GetRegisterInfo().IsSVEEnabled() || GetRegisterInfo().IsSSVEEnabled()) {
     dst = AddRegisterSetType(dst, RegisterSetType::SVE);
     *(reinterpret_cast<SVEState *>(dst)) = m_sve_state;
@@ -593,6 +709,12 @@
                             GetFPRSize());
   }
 
+  if ((m_sve_state == SVEState::Streaming) && GetRegisterInfo().IsZAEnabled()) {
+    assert(m_za_header.size <= GetZABufferSize());
+    dst = AddSavedRegisters(dst, RegisterSetType::SME, GetZABuffer(),
+                            m_za_header.size);
+  }
+
   if (GetRegisterInfo().IsMTEEnabled()) {
     dst = AddSavedRegisters(dst, RegisterSetType::MTE, GetMTEControl(),
                             GetMTEControlSize());
@@ -685,6 +807,8 @@
         return error;
 
       // SVE header has been written configure SVE vector length if needed.
+      // This could change ZA data too, but that will be restored again later
+      // anyway.
       ConfigureRegisterContext();
 
       // Write header and register data, incrementing src this time.
@@ -707,6 +831,33 @@
           GetTLSBuffer(), &src, GetTLSBufferSize(), m_tls_is_valid,
           std::bind(&NativeRegisterContextLinux_arm64::WriteTLS, this));
       break;
+    case RegisterSetType::SME:
+      // To enable or disable ZA you write the regset with or without register
+      // data. The kernel detects this by looking at the ioVec's length, not the
+      // ZA header size you pass in. Therefore we must write header and register
+      // data (if present) in one go every time. Read the header only first just
+      // to get the size.
+      ::memcpy(GetZAHeader(), src, GetZAHeaderSize());
+      // Read the header and register data. Can't use the buffer size here, it
+      // may be incorrect due to being filled with dummy data previously. Resize
+      // this so WriteZA uses the correct size.
+      m_za_ptrace_payload.resize(m_za_header.size);
+      ::memcpy(GetZABuffer(), src, GetZABufferSize());
+      m_za_buffer_is_valid = true;
+
+      error = WriteZA();
+      if (error.Fail())
+        return error;
+
+      // Update size of ZA, which resizes the ptrace payload potentially
+      // trashing our copy of the data we just wrote.
+      ConfigureRegisterContext();
+
+      // ZA buffer now has proper size, read back the data we wrote above, from
+      // ptrace.
+      error = ReadZA();
+      src += GetZABufferSize();
+      break;
     }
 
     if (error.Fail())
@@ -734,6 +885,10 @@
   return GetRegisterInfo().IsSVEReg(reg);
 }
 
+bool NativeRegisterContextLinux_arm64::IsSME(unsigned reg) const {
+  return GetRegisterInfo().IsSMEReg(reg);
+}
+
 bool NativeRegisterContextLinux_arm64::IsPAuth(unsigned reg) const {
   return GetRegisterInfo().IsPAuthReg(reg);
 }
@@ -887,11 +1042,13 @@
   m_fpu_is_valid = false;
   m_sve_buffer_is_valid = false;
   m_sve_header_is_valid = false;
+  m_za_buffer_is_valid = false;
+  m_za_header_is_valid = false;
   m_pac_mask_is_valid = false;
   m_mte_ctrl_is_valid = false;
   m_tls_is_valid = false;
 
-  // Update SVE registers in case there is change in configuration.
+  // Update SVE and ZA registers in case there is a change in configuration.
   ConfigureRegisterContext();
 }
 
@@ -1057,6 +1214,62 @@
   return WriteRegisterSet(&ioVec, GetTLSBufferSize(), NT_ARM_TLS);
 }
 
+Status NativeRegisterContextLinux_arm64::ReadZAHeader() {
+  Status error;
+
+  if (m_za_header_is_valid)
+    return error;
+
+  struct iovec ioVec;
+  ioVec.iov_base = GetZAHeader();
+  ioVec.iov_len = GetZAHeaderSize();
+
+  error = ReadRegisterSet(&ioVec, GetZAHeaderSize(), NT_ARM_ZA);
+
+  if (error.Success())
+    m_za_header_is_valid = true;
+
+  return error;
+}
+
+Status NativeRegisterContextLinux_arm64::ReadZA() {
+  Status error;
+
+  if (m_za_buffer_is_valid)
+    return error;
+
+  struct iovec ioVec;
+  ioVec.iov_base = GetZABuffer();
+  ioVec.iov_len = GetZABufferSize();
+
+  error = ReadRegisterSet(&ioVec, GetZABufferSize(), NT_ARM_ZA);
+
+  if (error.Success())
+    m_za_buffer_is_valid = true;
+
+  return error;
+}
+
+Status NativeRegisterContextLinux_arm64::WriteZA() {
+  // Note that because the ZA ptrace payload contains the header also, this
+  // method will write both. This is done because writing only the header
+  // will disable ZA, even if .size in the header is correct for an enabled ZA.
+  Status error;
+
+  error = ReadZA();
+  if (error.Fail())
+    return error;
+
+  struct iovec ioVec;
+  ioVec.iov_base = GetZABuffer();
+  ioVec.iov_len = GetZABufferSize();
+
+  m_za_buffer_is_valid = false;
+  m_za_header_is_valid = false;
+
+  return WriteRegisterSet(&ioVec, GetZABufferSize(), NT_ARM_ZA);
+}
+
 void NativeRegisterContextLinux_arm64::ConfigureRegisterContext() {
   // ConfigureRegisterContext gets called from InvalidateAllRegisters
   // on every stop and configures SVE vector length and whether we are in
@@ -1097,15 +1310,28 @@
     if (m_sve_state == SVEState::Full || m_sve_state == SVEState::FPSIMD ||
         m_sve_state == SVEState::Streaming) {
       // On every stop we configure SVE vector length by calling
-      // ConfigureVectorLength regardless of current SVEState of this thread.
+      // ConfigureVectorLengthSVE regardless of current SVEState of this thread.
       uint32_t vq = RegisterInfoPOSIX_arm64::eVectorQuadwordAArch64SVE;
       if (sve::vl_valid(m_sve_header.vl))
         vq = sve::vq_from_vl(m_sve_header.vl);
 
-      GetRegisterInfo().ConfigureVectorLength(vq);
+      GetRegisterInfo().ConfigureVectorLengthSVE(vq);
       m_sve_ptrace_payload.resize(sve::PTraceSize(vq, sve::ptrace_regs_sve));
     }
   }
+
+  if (!m_za_header_is_valid) {
+    Status error = ReadZAHeader();
+    if (error.Success()) {
+      uint32_t vq = RegisterInfoPOSIX_arm64::eVectorQuadwordAArch64SVE;
+      if (sve::vl_valid(m_za_header.vl))
+        vq = sve::vq_from_vl(m_za_header.vl);
+
+      GetRegisterInfo().ConfigureVectorLengthZA(vq);
+      m_za_ptrace_payload.resize(m_za_header.size);
+      m_za_buffer_is_valid = false;
+    }
+  }
 }
 
 uint32_t NativeRegisterContextLinux_arm64::CalculateFprOffset(
@@ -1135,6 +1361,7 @@
     ExpeditedRegs expType) const {
   std::vector<uint32_t> expedited_reg_nums =
       NativeRegisterContext::GetExpeditedRegisters(expType);
+  // SVE, non-streaming vector length.
   if (m_sve_state == SVEState::FPSIMD || m_sve_state == SVEState::Full)
     expedited_reg_nums.push_back(GetRegisterInfo().GetRegNumSVEVG());
 
Index: lldb/include/lldb/Utility/RegisterValue.h
===================================================================
--- lldb/include/lldb/Utility/RegisterValue.h
+++ lldb/include/lldb/Utility/RegisterValue.h
@@ -33,7 +33,9 @@
     // byte AArch64 SVE.
     kTypicalRegisterByteSize = 256u,
     // Anything else we'll heap allocate storage for it.
-    kMaxRegisterByteSize = kTypicalRegisterByteSize,
+    // 256x256 to support 256 byte AArch64 SME's array storage (ZA) register,
+    // which is a square of vector length x vector length.
+    kMaxRegisterByteSize = 256u * 256u,
   };
 
   typedef llvm::SmallVector<uint8_t, kTypicalRegisterByteSize> BytesContainer;
_______________________________________________
lldb-commits mailing list
lldb-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits

Reply via email to