DavidSpickett updated this revision to Diff 556559.
DavidSpickett added a comment.

Rebase


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D159505/new/

https://reviews.llvm.org/D159505

Files:
  
lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py
  
lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py
  
lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/Makefile
  
lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py
  
lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/main.c
  
lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/Makefile
  
lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/TestZARegisterSaveRestore.py
  
lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/main.c

Index: lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/main.c
===================================================================
--- /dev/null
+++ lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/main.c
@@ -0,0 +1,226 @@
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/prctl.h>
+
+// Important details for this program:
+// * Making a syscall will disable streaming mode if it is active.
+// * Changing the vector length will make streaming mode and ZA inactive.
+// * ZA can be active independent of streaming mode.
+// * ZA's size is the streaming vector length squared.
+
+#ifndef PR_SME_SET_VL
+#define PR_SME_SET_VL 63
+#endif
+
+#ifndef PR_SME_GET_VL
+#define PR_SME_GET_VL 64
+#endif
+
+#ifndef PR_SME_VL_LEN_MASK
+#define PR_SME_VL_LEN_MASK 0xffff
+#endif
+
+#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr")
+#define SMSTART SM_INST(7)
+#define SMSTART_SM SM_INST(3)
+#define SMSTART_ZA SM_INST(5)
+#define SMSTOP SM_INST(6)
+#define SMSTOP_SM SM_INST(2)
+#define SMSTOP_ZA SM_INST(4)
+
+int start_vl = 0;
+int other_vl = 0;
+
+void write_sve_regs() {
+  // We assume the smefa64 feature is present, which allows ffr access
+  // in streaming mode.
+  asm volatile("setffr\n\t");
+  asm volatile("ptrue p0.b\n\t");
+  asm volatile("ptrue p1.h\n\t");
+  asm volatile("ptrue p2.s\n\t");
+  asm volatile("ptrue p3.d\n\t");
+  asm volatile("pfalse p4.b\n\t");
+  asm volatile("ptrue p5.b\n\t");
+  asm volatile("ptrue p6.h\n\t");
+  asm volatile("ptrue p7.s\n\t");
+  asm volatile("ptrue p8.d\n\t");
+  asm volatile("pfalse p9.b\n\t");
+  asm volatile("ptrue p10.b\n\t");
+  asm volatile("ptrue p11.h\n\t");
+  asm volatile("ptrue p12.s\n\t");
+  asm volatile("ptrue p13.d\n\t");
+  asm volatile("pfalse p14.b\n\t");
+  asm volatile("ptrue p15.b\n\t");
+
+  asm volatile("cpy  z0.b, p0/z, #1\n\t");
+  asm volatile("cpy  z1.b, p5/z, #2\n\t");
+  asm volatile("cpy  z2.b, p10/z, #3\n\t");
+  asm volatile("cpy  z3.b, p15/z, #4\n\t");
+  asm volatile("cpy  z4.b, p0/z, #5\n\t");
+  asm volatile("cpy  z5.b, p5/z, #6\n\t");
+  asm volatile("cpy  z6.b, p10/z, #7\n\t");
+  asm volatile("cpy  z7.b, p15/z, #8\n\t");
+  asm volatile("cpy  z8.b, p0/z, #9\n\t");
+  asm volatile("cpy  z9.b, p5/z, #10\n\t");
+  asm volatile("cpy  z10.b, p10/z, #11\n\t");
+  asm volatile("cpy  z11.b, p15/z, #12\n\t");
+  asm volatile("cpy  z12.b, p0/z, #13\n\t");
+  asm volatile("cpy  z13.b, p5/z, #14\n\t");
+  asm volatile("cpy  z14.b, p10/z, #15\n\t");
+  asm volatile("cpy  z15.b, p15/z, #16\n\t");
+  asm volatile("cpy  z16.b, p0/z, #17\n\t");
+  asm volatile("cpy  z17.b, p5/z, #18\n\t");
+  asm volatile("cpy  z18.b, p10/z, #19\n\t");
+  asm volatile("cpy  z19.b, p15/z, #20\n\t");
+  asm volatile("cpy  z20.b, p0/z, #21\n\t");
+  asm volatile("cpy  z21.b, p5/z, #22\n\t");
+  asm volatile("cpy  z22.b, p10/z, #23\n\t");
+  asm volatile("cpy  z23.b, p15/z, #24\n\t");
+  asm volatile("cpy  z24.b, p0/z, #25\n\t");
+  asm volatile("cpy  z25.b, p5/z, #26\n\t");
+  asm volatile("cpy  z26.b, p10/z, #27\n\t");
+  asm volatile("cpy  z27.b, p15/z, #28\n\t");
+  asm volatile("cpy  z28.b, p0/z, #29\n\t");
+  asm volatile("cpy  z29.b, p5/z, #30\n\t");
+  asm volatile("cpy  z30.b, p10/z, #31\n\t");
+  asm volatile("cpy  z31.b, p15/z, #32\n\t");
+}
+
+// Write something different so we will know if we didn't restore them
+// correctly.
+void write_sve_regs_expr() {
+  asm volatile("pfalse p0.b\n\t");
+  asm volatile("wrffr p0.b\n\t");
+  asm volatile("pfalse p1.b\n\t");
+  asm volatile("pfalse p2.b\n\t");
+  asm volatile("pfalse p3.b\n\t");
+  asm volatile("ptrue p4.b\n\t");
+  asm volatile("pfalse p5.b\n\t");
+  asm volatile("pfalse p6.b\n\t");
+  asm volatile("pfalse p7.b\n\t");
+  asm volatile("pfalse p8.b\n\t");
+  asm volatile("ptrue p9.b\n\t");
+  asm volatile("pfalse p10.b\n\t");
+  asm volatile("pfalse p11.b\n\t");
+  asm volatile("pfalse p12.b\n\t");
+  asm volatile("pfalse p13.b\n\t");
+  asm volatile("ptrue p14.b\n\t");
+  asm volatile("pfalse p15.b\n\t");
+
+  asm volatile("cpy  z0.b, p0/z, #2\n\t");
+  asm volatile("cpy  z1.b, p5/z, #3\n\t");
+  asm volatile("cpy  z2.b, p10/z, #4\n\t");
+  asm volatile("cpy  z3.b, p15/z, #5\n\t");
+  asm volatile("cpy  z4.b, p0/z, #6\n\t");
+  asm volatile("cpy  z5.b, p5/z, #7\n\t");
+  asm volatile("cpy  z6.b, p10/z, #8\n\t");
+  asm volatile("cpy  z7.b, p15/z, #9\n\t");
+  asm volatile("cpy  z8.b, p0/z, #10\n\t");
+  asm volatile("cpy  z9.b, p5/z, #11\n\t");
+  asm volatile("cpy  z10.b, p10/z, #12\n\t");
+  asm volatile("cpy  z11.b, p15/z, #13\n\t");
+  asm volatile("cpy  z12.b, p0/z, #14\n\t");
+  asm volatile("cpy  z13.b, p5/z, #15\n\t");
+  asm volatile("cpy  z14.b, p10/z, #16\n\t");
+  asm volatile("cpy  z15.b, p15/z, #17\n\t");
+  asm volatile("cpy  z16.b, p0/z, #18\n\t");
+  asm volatile("cpy  z17.b, p5/z, #19\n\t");
+  asm volatile("cpy  z18.b, p10/z, #20\n\t");
+  asm volatile("cpy  z19.b, p15/z, #21\n\t");
+  asm volatile("cpy  z20.b, p0/z, #22\n\t");
+  asm volatile("cpy  z21.b, p5/z, #23\n\t");
+  asm volatile("cpy  z22.b, p10/z, #24\n\t");
+  asm volatile("cpy  z23.b, p15/z, #25\n\t");
+  asm volatile("cpy  z24.b, p0/z, #26\n\t");
+  asm volatile("cpy  z25.b, p5/z, #27\n\t");
+  asm volatile("cpy  z26.b, p10/z, #28\n\t");
+  asm volatile("cpy  z27.b, p15/z, #29\n\t");
+  asm volatile("cpy  z28.b, p0/z, #30\n\t");
+  asm volatile("cpy  z29.b, p5/z, #31\n\t");
+  asm volatile("cpy  z30.b, p10/z, #32\n\t");
+  asm volatile("cpy  z31.b, p15/z, #33\n\t");
+}
+
+void set_za_register(int svl, int value_offset) {
+#define MAX_VL_BYTES 256
+  uint8_t data[MAX_VL_BYTES];
+
+  // ldr za will actually wrap the selected vector row, by the number of rows
+  // you have. So setting one that didn't exist would actually set one that did.
+  // That's why we need the streaming vector length here.
+  for (int i = 0; i < svl; ++i) {
+    memset(data, i + value_offset, MAX_VL_BYTES);
+    // Each one of these loads a VL sized row of ZA.
+    asm volatile("mov w12, %w0\n\t"
+                 "ldr za[w12, 0], [%1]\n\t" ::"r"(i),
+                 "r"(&data)
+                 : "w12");
+  }
+}
+
+void expr_disable_za() {
+  SMSTOP_ZA;
+  write_sve_regs_expr();
+}
+
+void expr_enable_za() {
+  SMSTART_ZA;
+  set_za_register(start_vl, 2);
+  write_sve_regs_expr();
+}
+
+void expr_start_vl() {
+  prctl(PR_SME_SET_VL, start_vl);
+  SMSTART_ZA;
+  set_za_register(start_vl, 4);
+  write_sve_regs_expr();
+}
+
+void expr_other_vl() {
+  prctl(PR_SME_SET_VL, other_vl);
+  SMSTART_ZA;
+  set_za_register(other_vl, 5);
+  write_sve_regs_expr();
+}
+
+void expr_enable_sm() {
+  SMSTART_SM;
+  write_sve_regs_expr();
+}
+
+void expr_disable_sm() {
+  SMSTOP_SM;
+  write_sve_regs_expr();
+}
+
+int main(int argc, char *argv[]) {
+  // We expect to get:
+  // * whether to enable streaming mode
+  // * whether to enable ZA
+  // * what the starting VL should be
+  // * what the other VL should be
+  if (argc != 5)
+    return 1;
+
+  bool ssve = argv[1][0] == '1';
+  bool za = argv[2][0] == '1';
+  start_vl = atoi(argv[3]);
+  other_vl = atoi(argv[4]);
+
+  prctl(PR_SME_SET_VL, start_vl);
+
+  if (ssve)
+    SMSTART_SM;
+
+  if (za) {
+    SMSTART_ZA;
+    set_za_register(start_vl, 1);
+  }
+
+  write_sve_regs();
+
+  return 0; // Set a break point here.
+}
+
Index: lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/TestZARegisterSaveRestore.py
===================================================================
--- /dev/null
+++ lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/TestZARegisterSaveRestore.py
@@ -0,0 +1,238 @@
+"""
+Test the AArch64 SME ZA register is saved and restored around expressions.
+
+This attempts to cover expressions that change the following:
+* ZA enabled or not.
+* Streaming mode or not.
+* Streaming vector length (increasing and decreasing).
+* Some combinations of the above.
+"""
+
+from enum import IntEnum
+import lldb
+from lldbsuite.test.decorators import *
+from lldbsuite.test.lldbtest import *
+from lldbsuite.test import lldbutil
+
+
+# These enum values match the flag values used in the test program.
+class Mode(IntEnum):
+    SVE = 0
+    SSVE = 1
+
+
+class ZA(IntEnum):
+    Disabled = 0
+    Enabled = 1
+
+
+class AArch64ZATestCase(TestBase):
+    def get_supported_svg(self):
+        # Always build this probe program to start as streaming SVE.
+        # We will read/write "vg" here but since we are in streaming mode "svg"
+        # is really what we are writing ("svg" is a read only pseudo).
+        self.build()
+
+        exe = self.getBuildArtifact("a.out")
+        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
+        # Enter streaming mode, don't enable ZA, start_vl and other_vl don't
+        # matter here.
+        self.runCmd("settings set target.run-args 1 0 0 0")
+
+        stop_line = line_number("main.c", "// Set a break point here.")
+        lldbutil.run_break_set_by_file_and_line(self, "main.c", stop_line,
+                                                num_expected_locations=1)
+
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        self.expect(
+            "thread info 1",
+            STOPPED_DUE_TO_BREAKPOINT,
+            substrs=["stop reason = breakpoint"],
+        )
+
+        # Write back the current vg to confirm read/write works at all.
+        current_svg = self.match("register read vg", ["(0x[0-9]+)"])
+        self.assertTrue(current_svg is not None)
+        self.expect("register write vg {}".format(current_svg.group()))
+
+        # Aka 128, 256 and 512 bit.
+        supported_svg = []
+        for svg in [2, 4, 8]:
+            # This could mask other errors but writing vg is tested elsewhere
+            # so we assume the hardware rejected the value.
+            self.runCmd("register write vg {}".format(svg), check=False)
+            if not self.res.GetError():
+                supported_svg.append(svg)
+
+        self.runCmd("breakpoint delete 1")
+        self.runCmd("continue")
+
+        return supported_svg
+
+    def read_vg(self):
+        process = self.dbg.GetSelectedTarget().GetProcess()
+        registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters()
+        sve_registers = registerSets.GetFirstValueByName("Scalable Vector Extension Registers")
+        return sve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned()
+
+    def read_svg(self):
+        process = self.dbg.GetSelectedTarget().GetProcess()
+        registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters()
+        sve_registers = registerSets.GetFirstValueByName("Scalable Matrix Extension Registers")
+        return sve_registers.GetChildMemberWithName("svg").GetValueAsUnsigned()
+
+    def make_za_value(self, vl, generator):
+        # Generate a vector value string "{0x00 0x01....}".
+        rows = []
+        for row in range(vl):
+            byte = "0x{:02x}".format(generator(row))
+            rows.append(" ".join([byte]*vl))
+        return "{" + " ".join(rows) + "}"
+
+    def check_za(self, vl):
+        # We expect an increasing value starting at 1. Row 0=1, row 1 = 2, etc.
+        self.expect("register read za", substrs=[
+            self.make_za_value(vl, lambda row: row+1)])
+
+    def check_za_disabled(self, vl):
+        # When ZA is disabled, lldb will show ZA as all 0s.
+        self.expect("register read za", substrs=[
+            self.make_za_value(vl, lambda row: 0)])
+
+    def za_expr_test_impl(self, sve_mode, za_state, swap_start_vl):
+        if not self.isAArch64SME():
+            self.skipTest("SME must be present.")
+
+        supported_svg = self.get_supported_svg()
+        if len(supported_svg) < 2:
+            self.skipTest("Target must support at least 2 streaming vector lengths.")
+
+        # vg is in units of 8 bytes.
+        start_vl = supported_svg[1] * 8
+        other_vl = supported_svg[2] * 8
+
+        if swap_start_vl:
+            start_vl, other_vl = other_vl, start_vl
+
+        self.line = line_number("main.c", "// Set a break point here.")
+
+        exe = self.getBuildArtifact("a.out")
+        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
+        self.runCmd("settings set target.run-args {} {} {} {}".format(sve_mode,
+                    za_state, start_vl, other_vl))
+
+        lldbutil.run_break_set_by_file_and_line(
+            self, "main.c", self.line, num_expected_locations=1
+        )
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        self.expect(
+            "thread backtrace",
+            STOPPED_DUE_TO_BREAKPOINT,
+            substrs=["stop reason = breakpoint 1."],
+        )
+
+        exprs = ["expr_disable_za", "expr_enable_za", "expr_start_vl",
+                 "expr_other_vl", "expr_enable_sm", "expr_disable_sm"]
+
+        # This may be the streaming or non-streaming vg. All that matters is
+        # that it is saved and restored, remaining constant throughout.
+        start_vg = self.read_vg()
+
+        # Check SVE registers to make sure that combination of scaling SVE
+        # and scaling ZA works properly. This is a brittle check, but failures
+        # are likely to be catastrophic when they do happen anyway.
+        sve_reg_names = "ffr {} {}".format(
+            " ".join(["z{}".format(n) for n in range(32)]),
+            " ".join(["p{}".format(n) for n in range(16)]))
+        self.runCmd("register read " + sve_reg_names)
+        sve_values = self.res.GetOutput()
+
+        def check_regs():
+            if za_state == ZA.Enabled:
+                self.check_za(start_vl)
+            else:
+                self.check_za_disabled(start_vl)
+
+            # svg and vg are in units of 8 bytes.
+            self.assertEqual(start_vl, self.read_svg()*8)
+            self.assertEqual(start_vg, self.read_vg())
+
+            self.expect("register read " + sve_reg_names, substrs=[sve_values])
+
+        for expr in exprs:
+            expr_cmd = "expression {}()".format(expr)
+
+            # We do this twice because there were issues in development where
+            # using data stored by a previous WriteAllRegisterValues would crash
+            # the second time around.
+            self.runCmd(expr_cmd)
+            check_regs()
+            self.runCmd(expr_cmd)
+            check_regs()
+
+        # Run them in sequence to make sure there is no state lingering between
+        # them after a restore.
+        for expr in exprs:
+            self.runCmd("expression {}()".format(expr))
+            check_regs()
+
+        for expr in reversed(exprs):
+            self.runCmd("expression {}()".format(expr))
+            check_regs()
+
+    # These tests start with the 1st supported SVL and change to the 2nd
+    # supported SVL as needed.
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_ssve_za_enabled(self):
+        self.za_expr_test_impl(Mode.SSVE, ZA.Enabled, False)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_ssve_za_disabled(self):
+        self.za_expr_test_impl(Mode.SSVE, ZA.Disabled, False)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_sve_za_enabled(self):
+        self.za_expr_test_impl(Mode.SVE, ZA.Enabled, False)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_sve_za_disabled(self):
+        self.za_expr_test_impl(Mode.SVE, ZA.Disabled, False)
+
+    # These tests start in the 2nd supported SVL and change to the 1st supported
+    # SVL as needed.
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_ssve_za_enabled_different_vl(self):
+        self.za_expr_test_impl(Mode.SSVE, ZA.Enabled, True)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_ssve_za_disabled_different_vl(self):
+        self.za_expr_test_impl(Mode.SSVE, ZA.Disabled, True)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_sve_za_enabled_different_vl(self):
+        self.za_expr_test_impl(Mode.SVE, ZA.Enabled, True)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_sve_za_disabled_different_vl(self):
+        self.za_expr_test_impl(Mode.SVE, ZA.Disabled, True)
+
Index: lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/Makefile
===================================================================
--- /dev/null
+++ lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/Makefile
@@ -0,0 +1,5 @@
+C_SOURCES := main.c
+
+CFLAGS_EXTRAS := -march=armv8-a+sve+sme
+
+include Makefile.rules
Index: lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/main.c
===================================================================
--- /dev/null
+++ lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/main.c
@@ -0,0 +1,103 @@
+#include <pthread.h>
+#include <stdatomic.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <string.h>
+#include <sys/prctl.h>
+
+// Important notes for this test:
+// * Making a syscall will disable streaming mode.
+// * LLDB writing to vg while in streaming mode will disable ZA
+//   (this is just how ptrace works).
+// * Writing to an inactive ZA produces a SIGILL.
+
+#ifndef PR_SME_SET_VL
+#define PR_SME_SET_VL 63
+#endif
+
+#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr")
+#define SMSTART_SM SM_INST(3)
+#define SMSTART_ZA SM_INST(5)
+
+void set_za_register(int svl, int value_offset) {
+#define MAX_VL_BYTES 256
+  uint8_t data[MAX_VL_BYTES];
+
+  // ldr za will actually wrap the selected vector row, by the number of rows
+  // you have. So setting one that didn't exist would actually set one that did.
+  // That's why we need the streaming vector length here.
+  for (int i = 0; i < svl; ++i) {
+    memset(data, i + value_offset, MAX_VL_BYTES);
+    // Each one of these loads a VL sized row of ZA.
+    asm volatile("mov w12, %w0\n\t"
+                 "ldr za[w12, 0], [%1]\n\t" ::"r"(i),
+                 "r"(&data)
+                 : "w12");
+  }
+}
+
+// These are used to make sure we only break in each thread once both of the
+// threads have been started. Otherwise when the test does "process continue"
+// it could stop in one thread and wait forever for the other one to start.
+atomic_bool threadX_ready = false;
+atomic_bool threadY_ready = false;
+
+void *threadX_func(void *x_arg) {
+  threadX_ready = true;
+  while (!threadY_ready) {
+  }
+
+  prctl(PR_SME_SET_VL, 8 * 4);
+  SMSTART_SM;
+  SMSTART_ZA;
+  set_za_register(8 * 4, 2);
+  SMSTART_ZA; // Thread X breakpoint 1
+  set_za_register(8 * 2, 2);
+  return NULL; // Thread X breakpoint 2
+}
+
+void *threadY_func(void *y_arg) {
+  threadY_ready = true;
+  while (!threadX_ready) {
+  }
+
+  prctl(PR_SME_SET_VL, 8 * 2);
+  SMSTART_SM;
+  SMSTART_ZA;
+  set_za_register(8 * 2, 3);
+  SMSTART_ZA; // Thread Y breakpoint 1
+  set_za_register(8 * 4, 3);
+  return NULL; // Thread Y breakpoint 2
+}
+
+int main(int argc, char *argv[]) {
+  // Expecting argument to tell us whether to enable ZA on the main thread.
+  if (argc != 2)
+    return 1;
+
+  prctl(PR_SME_SET_VL, 8 * 8);
+  SMSTART_SM;
+
+  if (argv[1][0] == '1') {
+    SMSTART_ZA;
+    set_za_register(8 * 8, 1);
+  }
+  // else we do not enable ZA and lldb will show 0s for it.
+
+  pthread_t x_thread;
+  if (pthread_create(&x_thread, NULL, threadX_func, 0)) // Break in main thread
+    return 1;
+
+  pthread_t y_thread;
+  if (pthread_create(&y_thread, NULL, threadY_func, 0))
+    return 1;
+
+  if (pthread_join(x_thread, NULL))
+    return 2;
+
+  if (pthread_join(y_thread, NULL))
+    return 2;
+
+  return 0;
+}
+
Index: lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py
===================================================================
--- lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py
+++ lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py
@@ -1,11 +1,6 @@
 """
-Test the AArch64 SVE and Streaming SVE (SSVE) registers dynamic resize with
+Test the AArch64 SME Array Storage (ZA) register dynamic resize with
 multiple threads.
-
-This test assumes a minimum supported vector length (VL) of 256 bits
-and will test 512 bits if possible. We refer to "vg" which is the
-register shown in lldb. This is in units of 64 bits. 256 bit VL is
-the same as a vg of 4.
 """
 
 from enum import Enum
@@ -15,21 +10,15 @@
 from lldbsuite.test import lldbutil
 
 
-class Mode(Enum):
-    SVE = 0
-    SSVE = 1
-
-
-class RegisterCommandsTestCase(TestBase):
+class AArch64ZAThreadedTestCase(TestBase):
     def get_supported_vg(self):
-        # Changing VL trashes the register state, so we need to run the program
-        # just to test this. Then run it again for the test.
         exe = self.getBuildArtifact("a.out")
         self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
 
         main_thread_stop_line = line_number("main.c", "// Break in main thread")
         lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line)
 
+        self.runCmd("settings set target.run-args 0")
         self.runCmd("run", RUN_SUCCEEDED)
 
         self.expect(
@@ -38,7 +27,6 @@
             substrs=["stop reason = breakpoint"],
         )
 
-        # Write back the current vg to confirm read/write works at all.
         current_vg = self.match("register read vg", ["(0x[0-9]+)"])
         self.assertTrue(current_vg is not None)
         self.expect("register write vg {}".format(current_vg.group()))
@@ -57,64 +45,36 @@
 
         return supported_vg
 
-    def check_sve_registers(self, vg_test_value):
-        z_reg_size = vg_test_value * 8
-        p_reg_size = int(z_reg_size / 8)
-
-        p_value_bytes = ["0xff", "0x55", "0x11", "0x01", "0x00"]
-
-        for i in range(32):
-            s_reg_value = "s%i = 0x" % (i) + "".join(
-                "{:02x}".format(i + 1) for _ in range(4)
-            )
-
-            d_reg_value = "d%i = 0x" % (i) + "".join(
-                "{:02x}".format(i + 1) for _ in range(8)
-            )
-
-            v_reg_value = "v%i = 0x" % (i) + "".join(
-                "{:02x}".format(i + 1) for _ in range(16)
-            )
-
-            z_reg_value = (
-                "{"
-                + " ".join("0x{:02x}".format(i + 1) for _ in range(z_reg_size))
-                + "}"
-            )
-
-            self.expect("register read -f hex " + "s%i" % (i), substrs=[s_reg_value])
+    def gen_za_value(self, svg, value_generator):
+        svl = svg*8
 
-            self.expect("register read -f hex " + "d%i" % (i), substrs=[d_reg_value])
+        rows = []
+        for row in range(svl):
+            byte = "0x{:02x}".format(value_generator(row))
+            rows.append(" ".join([byte]*svl))
 
-            self.expect("register read -f hex " + "v%i" % (i), substrs=[v_reg_value])
+        return "{" + " ".join(rows) + "}"
 
-            self.expect("register read " + "z%i" % (i), substrs=[z_reg_value])
+    def check_za_register(self, svg, value_offset):
+        self.expect("register read za", substrs=[
+            self.gen_za_value(svg, lambda r: r+value_offset)])
 
-        for i in range(16):
-            p_regs_value = (
-                "{" + " ".join(p_value_bytes[i % 5] for _ in range(p_reg_size)) + "}"
-            )
-            self.expect("register read " + "p%i" % (i), substrs=[p_regs_value])
+    def check_disabled_za_register(self, svg):
+        self.expect("register read za", substrs=[
+            self.gen_za_value(svg, lambda r: 0)])
 
-        self.expect("register read ffr", substrs=[p_regs_value])
-
-    def run_sve_test(self, mode):
-        if (mode == Mode.SVE) and not self.isAArch64SVE():
-            self.skipTest("SVE registers must be supported.")
-
-        if (mode == Mode.SSVE) and not self.isAArch64SME():
-            self.skipTest("Streaming SVE registers must be supported.")
-
-        cflags = "-march=armv8-a+sve -lpthread"
-        if mode == Mode.SSVE:
-            cflags += " -DUSE_SSVE"
-        self.build(dictionary={"CFLAGS_EXTRAS": cflags})
+    def za_test_impl(self, enable_za):
+        if not self.isAArch64SME():
+            self.skipTest("SME must be present.")
 
         self.build()
         supported_vg = self.get_supported_vg()
 
+        self.runCmd("settings set target.run-args {}".format(
+            '1' if enable_za else '0'))
+
         if not (2 in supported_vg and 4 in supported_vg):
-            self.skipTest("Not all required SVE vector lengths are supported.")
+            self.skipTest("Not all required streaming vector lengths are supported.")
 
         main_thread_stop_line = line_number("main.c", "// Break in main thread")
         lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line)
@@ -133,8 +93,6 @@
 
         self.runCmd("run", RUN_SUCCEEDED)
 
-        process = self.dbg.GetSelectedTarget().GetProcess()
-
         self.expect(
             "thread info 1",
             STOPPED_DUE_TO_BREAKPOINT,
@@ -142,12 +100,19 @@
         )
 
         if 8 in supported_vg:
-            self.check_sve_registers(8)
+            if enable_za:
+                self.check_za_register(8, 1)
+            else:
+                self.check_disabled_za_register(8)
         else:
-            self.check_sve_registers(4)
+            if enable_za:
+                self.check_za_register(4, 1)
+            else:
+                self.check_disabled_za_register(4)
 
         self.runCmd("process continue", RUN_SUCCEEDED)
 
+        process = self.dbg.GetSelectedTarget().GetProcess()
         for idx in range(1, process.GetNumThreads()):
             thread = process.GetThreadAtIndex(idx)
             if thread.GetStopReason() != lldb.eStopReasonBreakpoint:
@@ -158,12 +123,12 @@
 
             if stopped_at_line_number == thX_break_line1:
                 self.runCmd("thread select %d" % (idx + 1))
-                self.check_sve_registers(4)
+                self.check_za_register(4, 2)
                 self.runCmd("register write vg 2")
 
             elif stopped_at_line_number == thY_break_line1:
                 self.runCmd("thread select %d" % (idx + 1))
-                self.check_sve_registers(2)
+                self.check_za_register(2, 3)
                 self.runCmd("register write vg 4")
 
         self.runCmd("thread continue 2")
@@ -177,22 +142,25 @@
 
             if stopped_at_line_number == thX_break_line2:
                 self.runCmd("thread select %d" % (idx + 1))
-                self.check_sve_registers(2)
+                self.check_za_register(2, 2)
 
             elif stopped_at_line_number == thY_break_line2:
                 self.runCmd("thread select %d" % (idx + 1))
-                self.check_sve_registers(4)
+                self.check_za_register(4, 3)
 
     @no_debug_info_test
     @skipIf(archs=no_match(["aarch64"]))
     @skipIf(oslist=no_match(["linux"]))
-    def test_sve_registers_dynamic_config(self):
-        """Test AArch64 SVE registers multi-threaded dynamic resize."""
-        self.run_sve_test(Mode.SVE)
+    def test_za_register_dynamic_config_main_enabled(self):
+        """ Test multiple threads resizing ZA, with the main thread's ZA
+            enabled."""
+        self.za_test_impl(True)
 
     @no_debug_info_test
     @skipIf(archs=no_match(["aarch64"]))
     @skipIf(oslist=no_match(["linux"]))
-    def test_ssve_registers_dynamic_config(self):
-        """Test AArch64 SSVE registers multi-threaded dynamic resize."""
-        self.run_sve_test(Mode.SSVE)
+    def test_za_register_dynamic_config_main_disabled(self):
+        """ Test multiple threads resizing ZA, with the main thread's ZA
+            disabled."""
+        self.za_test_impl(False)
+
Index: lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/Makefile
===================================================================
--- /dev/null
+++ lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/Makefile
@@ -0,0 +1,5 @@
+C_SOURCES := main.c
+
+CFLAGS_EXTRAS := -march=armv8-a+sve+sme -lpthread
+
+include Makefile.rules
Index: lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py
===================================================================
--- lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py
+++ lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py
@@ -98,6 +98,12 @@
 
         self.expect("register read ffr", substrs=[p_regs_value])
 
+    def build_for_mode(self, mode):
+        cflags = "-march=armv8-a+sve -lpthread"
+        if mode == Mode.SSVE:
+            cflags += " -DUSE_SSVE"
+        self.build(dictionary={"CFLAGS_EXTRAS": cflags})
+
     def run_sve_test(self, mode):
         if (mode == Mode.SVE) and not self.isAArch64SVE():
             self.skipTest("SVE registers must be supported.")
@@ -105,12 +111,8 @@
         if (mode == Mode.SSVE) and not self.isAArch64SME():
             self.skipTest("Streaming SVE registers must be supported.")
 
-        cflags = "-march=armv8-a+sve -lpthread"
-        if mode == Mode.SSVE:
-            cflags += " -DUSE_SSVE"
-        self.build(dictionary={"CFLAGS_EXTRAS": cflags})
+        self.build_for_mode(mode)
 
-        self.build()
         supported_vg = self.get_supported_vg()
 
         if not (2 in supported_vg and 4 in supported_vg):
@@ -196,3 +198,95 @@
     def test_ssve_registers_dynamic_config(self):
         """Test AArch64 SSVE registers multi-threaded dynamic resize."""
         self.run_sve_test(Mode.SSVE)
+
+    def setup_svg_test(self, mode):
+        # Even when running in SVE mode, we need access to SVG for these tests.
+        if not self.isAArch64SME():
+            self.skipTest("Streaming SVE registers must be present.")
+
+        self.build_for_mode(mode)
+
+        supported_vg = self.get_supported_vg()
+
+        main_thread_stop_line = line_number("main.c", "// Break in main thread")
+        lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line)
+
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        self.expect(
+            "thread info 1",
+            STOPPED_DUE_TO_BREAKPOINT,
+            substrs=["stop reason = breakpoint"],
+        )
+
+        target = self.dbg.GetSelectedTarget()
+        process = target.GetProcess()
+
+        return process, supported_vg
+
+    def read_reg(self, process, regset, reg):
+        registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters()
+        sve_registers = registerSets.GetFirstValueByName(regset)
+        return sve_registers.GetChildMemberWithName(reg).GetValueAsUnsigned()
+
+    def read_vg(self, process):
+        return self.read_reg(process, "Scalable Vector Extension Registers", "vg")
+
+    def read_svg(self, process):
+        return self.read_reg(process, "Scalable Matrix Extension Registers", "svg")
+
+    def do_svg_test(self, process, vgs, expected_svgs):
+        for vg, svg in zip(vgs, expected_svgs):
+            self.runCmd("register write vg {}".format(vg))
+            self.assertEqual(svg, self.read_svg(process))
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_svg_sve_mode(self):
+        """ When in SVE mode, svg should remain constant as we change vg. """
+        process, supported_vg = self.setup_svg_test(Mode.SVE)
+        svg = self.read_svg(process)
+        self.do_svg_test(process, supported_vg, [svg]*len(supported_vg))
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_svg_ssve_mode(self):
+        """ When in SSVE mode, changing vg should change svg to the same value. """
+        process, supported_vg = self.setup_svg_test(Mode.SSVE)
+        self.do_svg_test(process, supported_vg, supported_vg)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_sme_not_present(self):
+        """ When there is no SME, we should not show the SME register sets."""
+        if self.isAArch64SME():
+            self.skipTest("Streaming SVE registers must not be present.")
+
+        self.build_for_mode(Mode.SVE)
+
+        exe = self.getBuildArtifact("a.out")
+        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
+
+        # This test may run on a non-SVE system, but we'll stop before any
+        # SVE instruction would be run.
+        self.runCmd("b main")
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        self.expect(
+            "thread info 1",
+            STOPPED_DUE_TO_BREAKPOINT,
+            substrs=["stop reason = breakpoint"],
+        )
+
+        target = self.dbg.GetSelectedTarget()
+        process = target.GetProcess()
+
+        registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters()
+        sme_registers = registerSets.GetFirstValueByName("Scalable Matrix Extension Registers")
+        self.assertFalse(sme_registers.IsValid())
+
+        za = registerSets.GetFirstValueByName("Scalable Matrix Array Storage Registers")
+        self.assertFalse(za.IsValid())
Index: lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py
===================================================================
--- lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py
+++ lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py
@@ -70,15 +70,14 @@
         self.runCmd("register write ffr " + "'" + p_regs_value + "'")
         self.expect("register read ffr", substrs=[p_regs_value])
 
-    @no_debug_info_test
-    @skipIf(archs=no_match(["aarch64"]))
-    @skipIf(oslist=no_match(["linux"]))
-    def test_aarch64_dynamic_regset_config(self):
-        """Test AArch64 Dynamic Register sets configuration."""
+
+    def setup_register_config_test(self, run_args=None):
         self.build()
         self.line = line_number("main.c", "// Set a break point here.")
 
         exe = self.getBuildArtifact("a.out")
+        if run_args is not None:
+            self.runCmd("settings set target.run-args " + run_args)
         self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
 
         lldbutil.run_break_set_by_file_and_line(
@@ -97,7 +96,16 @@
         thread = process.GetThreadAtIndex(0)
         currentFrame = thread.GetFrameAtIndex(0)
 
-        for registerSet in currentFrame.GetRegisters():
+        return currentFrame.GetRegisters()
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_aarch64_dynamic_regset_config(self):
+        """Test AArch64 Dynamic Register sets configuration."""
+        register_sets = self.setup_register_config_test()
+
+        for registerSet in register_sets:
             if "Scalable Vector Extension Registers" in registerSet.GetName():
                 self.assertTrue(
                     self.isAArch64SVE(),
@@ -120,6 +128,20 @@
                 )
                 self.expect("register read data_mask", substrs=["data_mask = 0x"])
                 self.expect("register read code_mask", substrs=["code_mask = 0x"])
+            if "Scalable Matrix Extension Registers" in registerSet.GetName():
+                self.assertTrue(self.isAArch64SME(),
+                    "LLDB Enabled SME register set when it was disabled by target")
+            if "Scalable Matrix Array Storage Registers" in registerSet.GetName():
+                self.assertTrue(self.isAArch64SME(),
+                    "LLDB Enabled SME array storage register set when it was disabled by target.")
+
+    def make_za_value(self, vl, generator):
+        # Generate a ZA value string "{0xNN 0xNN ...}": vl rows of vl bytes, every byte in a row being generator(row).
+        rows = []
+        for row in range(vl):
+            byte = "0x{:02x}".format(generator(row))
+            rows.append(" ".join([byte]*vl))
+        return "{" + " ".join(rows) + "}"
 
     @no_debug_info_test
     @skipIf(archs=no_match(["aarch64"]))
@@ -130,32 +152,58 @@
         if not self.isAArch64SME():
             self.skipTest("SME must be present.")
 
-        self.build()
-        self.line = line_number("main.c", "// Set a break point here.")
-
-        exe = self.getBuildArtifact("a.out")
-        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
-
-        lldbutil.run_break_set_by_file_and_line(
-            self, "main.c", self.line, num_expected_locations=1
-        )
-        self.runCmd("settings set target.run-args sme")
-        self.runCmd("run", RUN_SUCCEEDED)
-
-        self.expect(
-            "thread backtrace",
-            STOPPED_DUE_TO_BREAKPOINT,
-            substrs=["stop reason = breakpoint 1."],
-        )
-
-        target = self.dbg.GetSelectedTarget()
-        process = target.GetProcess()
-        thread = process.GetThreadAtIndex(0)
-        currentFrame = thread.GetFrameAtIndex(0)
-
-        register_sets = currentFrame.GetRegisters()
+        register_sets = self.setup_register_config_test("sme")
 
         ssve_registers = register_sets.GetFirstValueByName(
             "Scalable Vector Extension Registers")
         self.assertTrue(ssve_registers.IsValid())
         self.sve_regs_read_dynamic(ssve_registers)
+
+        za_register = register_sets.GetFirstValueByName(
+            "Scalable Matrix Array Storage Registers")
+        self.assertTrue(za_register.IsValid())
+        vg = ssve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned()
+        vl = vg * 8
+        # When first enabled it is all 0s.
+        self.expect("register read za", substrs=[self.make_za_value(vl, lambda r: 0)])
+        za_value = self.make_za_value(vl, lambda r:r+1)
+        self.runCmd("register write za '{}'".format(za_value))
+        self.expect("register read za", substrs=[za_value])
+
+        # SVG should match VG because we're in streaming mode.
+        sme_registers = register_sets.GetFirstValueByName(
+            "Scalable Matrix Extension Registers")
+        self.assertTrue(sme_registers.IsValid())
+        svg = sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned()
+        self.assertEqual(vg, svg)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_aarch64_dynamic_regset_config_sme_za_disabled(self):
+        """Test that ZA shows as 0s when disabled and can be enabled by writing
+           to it."""
+        if not self.isAArch64SME():
+            self.skipTest("SME must be present.")
+
+        # No argument, so ZA will be disabled when we break.
+        register_sets = self.setup_register_config_test()
+
+        # We are in non-streaming mode, so vg is the non-streaming vector length;
+        # ZA's size depends on the streaming vector length, so use svg instead.
+        sme_registers = register_sets.GetFirstValueByName(
+            "Scalable Matrix Extension Registers")
+        self.assertTrue(sme_registers.IsValid())
+        svg = sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned()
+
+        za_register = register_sets.GetFirstValueByName(
+            "Scalable Matrix Array Storage Registers")
+        self.assertTrue(za_register.IsValid())
+        svl = svg * 8
+        # A disabled ZA is shown as all 0s.
+        self.expect("register read za", substrs=[self.make_za_value(svl, lambda r: 0)])
+        za_value = self.make_za_value(svl, lambda r:r+1)
+        # Writing to it enables ZA, so the value should be there when we read
+        # it back.
+        self.runCmd("register write za '{}'".format(za_value))
+        self.expect("register read za", substrs=[za_value])
_______________________________________________
lldb-commits mailing list
lldb-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits

Reply via email to