================
@@ -0,0 +1,364 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Unit tests for CIR implementation of OpenACC's PointerLikeType interface
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Value.h"
+#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/IR/CIRTypes.h"
+#include "clang/CIR/Dialect/OpenACC/CIROpenACCTypeInterfaces.h"
+#include "clang/CIR/Dialect/OpenACC/RegisterOpenACCExtensions.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace cir;
+
+//===----------------------------------------------------------------------===//
+// Test Fixture
+//===----------------------------------------------------------------------===//
+
+class CIROpenACCPointerLikeTest : public ::testing::Test {
+protected:
+  // Sets up the fixture: loads the CIR and OpenACC dialects into the context,
+  // then applies the extensions that integrate CIR types with OpenACC.
+  CIROpenACCPointerLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
+    context.loadDialect<cir::CIRDialect, mlir::acc::OpenACCDialect>();
+
+    // The extensions are collected in a scratch registry which is then
+    // appended to the context.
+    mlir::DialectRegistry extensions;
+    cir::acc::registerOpenACCExtensions(extensions);
+    context.appendDialectRegistry(extensions);
+  }
+
+  MLIRContext context; // Context the dialects and extensions are loaded into.
+  OpBuilder b;         // Builder constructed from `context`; creates test ops.
+  Location loc;        // Unknown location passed when creating ops.
+  llvm::StringMap<unsigned> recordNames; // Suffix counter per record base name.
+
+  // Returns a 64-bit integer attribute holding an alignment of one, as passed
+  // to the cir::AllocaOp builders in these tests.
+  mlir::IntegerAttr getAlignOne(mlir::MLIRContext *ctx) {
+    // A plain mlir::IntegerType is enough here; cir::IntType would only add
+    // sign information that this attribute has no use for.
+    mlir::Type alignTy = mlir::IntegerType::get(ctx, 64);
+    const clang::CharUnits one = clang::CharUnits::One();
+    return mlir::IntegerAttr::get(alignTy, one.getQuantity());
+  }
+
+  // Produces a record name derived from `baseName`. The first request for a
+  // given base name returns the base name itself; every later request returns
+  // "<baseName>.<N>", with N counting up from zero.
+  mlir::StringAttr getUniqueRecordName(const std::string &baseName) {
+    // try_emplace only inserts (with a zero counter) when the name is new, so
+    // one lookup tells the first use apart from the repeats.
+    auto lookup = recordNames.try_emplace(baseName, 0u);
+    if (lookup.second)
+      return b.getStringAttr(baseName);
+
+    unsigned &nextSuffix = lookup.first->second;
+    std::string suffixedName = baseName + "." + std::to_string(nextSuffix);
+    ++nextSuffix;
+    return b.getStringAttr(suffixedName);
+  }
+
+  // Generic check for pointee types that have no dedicated test: a pointer to
+  // `ty` must be castable to acc::PointerLikeType, report `ty` as its element
+  // type, and categorize an alloca of `ty` as `expectedTypeCategory`.
+  void testSingleType(mlir::Type ty,
+                      mlir::acc::VariableTypeCategory expectedTypeCategory) {
+    mlir::Type pointerTy = cir::PointerType::get(ty);
+
+    // cir::PointerType should be castable to acc::PointerLikeType.
+    auto pointerLike =
+        dyn_cast_if_present<mlir::acc::PointerLikeType>(pointerTy);
+    ASSERT_NE(pointerLike, nullptr);
+    EXPECT_EQ(pointerLike.getElementType(), ty);
+
+    // The alloca is held in an OwningOpRef rather than inserted anywhere.
+    OwningOpRef<cir::AllocaOp> allocaOp = b.create<cir::AllocaOp>(
+        loc, pointerTy, ty, "", getAlignOne(&context));
+    mlir::Value allocaVal = allocaOp.get();
+
+    // Query the category through the pointer-like interface and compare.
+    auto typedPtr = cast<TypedValue<mlir::acc::PointerLikeType>>(allocaVal);
+    mlir::Type varTy = mlir::acc::getVarType(allocaOp.get());
+    mlir::acc::VariableTypeCategory actualCategory =
+        pointerLike.getPointeeTypeCategory(typedPtr, varTy);
+    EXPECT_EQ(actualCategory, expectedTypeCategory);
+  }
+
+  // Expects a pointer to `ty` to be categorized as scalar.
+  void testScalarType(mlir::Type ty) {
+    const auto expected = mlir::acc::VariableTypeCategory::scalar;
+    testSingleType(ty, expected);
+  }
+
+  // Expects a pointer to `ty` to be categorized as nonscalar.
+  void testNonScalarType(mlir::Type ty) {
+    const auto expected = mlir::acc::VariableTypeCategory::nonscalar;
+    testSingleType(ty, expected);
+  }
+
+  // Expects a pointer to `ty` to be categorized as uncategorized.
+  void testUncategorizedType(mlir::Type ty) {
+    const auto expected = mlir::acc::VariableTypeCategory::uncategorized;
+    testSingleType(ty, expected);
+  }
+
+  // Checks the PointerLikeType behavior of a pointer to a 10-element array of
+  // `ty`, and of the pointers derived from it by an array-to-pointer decay and
+  // by a subsequent element access. All three are expected to be categorized
+  // as arrays.
+  //
+  // Every op created here is held in an OwningOpRef so that it is erased when
+  // the function returns (including the early returns taken by ASSERT_NE).
+  // The owners are declared in def-before-use order, so each user is destroyed
+  // before the op whose result it uses.
+  void testArrayType(mlir::Type ty) {
+    // Build the array pointer type.
+    mlir::Type arrTy = cir::ArrayType::get(ty, 10);
+    mlir::Type ptrTy = cir::PointerType::get(arrTy);
+
+    // Verify that the pointer points to the array type.
+    auto pltTy = dyn_cast_if_present<mlir::acc::PointerLikeType>(ptrTy);
+    ASSERT_NE(pltTy, nullptr);
+    EXPECT_EQ(pltTy.getElementType(), arrTy);
+
+    // Create an alloca for the array.
+    OwningOpRef<cir::AllocaOp> varPtrOp =
+        b.create<cir::AllocaOp>(loc, ptrTy, arrTy, "", getAlignOne(&context));
+
+    // Verify that the type category is array.
+    mlir::Value val = varPtrOp.get();
+    mlir::acc::VariableTypeCategory typeCategory =
+        pltTy.getPointeeTypeCategory(
+            cast<TypedValue<mlir::acc::PointerLikeType>>(val),
+            mlir::acc::getVarType(varPtrOp.get()));
+    EXPECT_EQ(typeCategory, mlir::acc::VariableTypeCategory::array);
+
+    // Create an array-to-pointer decay cast.
+    mlir::Type ptrToElemTy = cir::PointerType::get(ty);
+    OwningOpRef<cir::CastOp> decayPtr = b.create<cir::CastOp>(
+        loc, ptrToElemTy, cir::CastKind::array_to_ptrdecay, val);
+    mlir::Value decayVal = decayPtr.get();
+
+    // Verify that we still get the expected element type.
+    auto decayPltTy =
+        dyn_cast_if_present<mlir::acc::PointerLikeType>(decayVal.getType());
+    ASSERT_NE(decayPltTy, nullptr);
+    EXPECT_EQ(decayPltTy.getElementType(), ty);
+
+    // Verify that we still identify the type category as an array.
+    mlir::acc::VariableTypeCategory decayTypeCategory =
+        decayPltTy.getPointeeTypeCategory(
+            cast<TypedValue<mlir::acc::PointerLikeType>>(decayVal),
+            mlir::acc::getVarType(decayPtr.get()));
+    EXPECT_EQ(decayTypeCategory, mlir::acc::VariableTypeCategory::array);
+
+    // Create an element access. The index constant is owned too: keeping only
+    // its mlir::Value would leave the detached op with nothing to erase it.
+    mlir::Type i32Ty = cir::IntType::get(&context, 32, true);
+    OwningOpRef<cir::ConstantOp> indexOp =
+        b.create<cir::ConstantOp>(loc, cir::IntAttr::get(i32Ty, 2));
+    mlir::Value index = indexOp.get();
+    OwningOpRef<cir::PtrStrideOp> accessPtr =
+        b.create<cir::PtrStrideOp>(loc, ptrToElemTy, decayVal, index);
+    mlir::Value accessVal = accessPtr.get();
+
+    // Verify that we still get the expected element type.
+    auto accessPltTy =
+        dyn_cast_if_present<mlir::acc::PointerLikeType>(accessVal.getType());
+    ASSERT_NE(accessPltTy, nullptr);
+    EXPECT_EQ(accessPltTy.getElementType(), ty);
+
+    // Verify that we still identify the type category as an array.
+    mlir::acc::VariableTypeCategory accessTypeCategory =
+        accessPltTy.getPointeeTypeCategory(
+            cast<TypedValue<mlir::acc::PointerLikeType>>(accessVal),
+            mlir::acc::getVarType(accessPtr.get()));
+    EXPECT_EQ(accessTypeCategory, mlir::acc::VariableTypeCategory::array);
+  }
+
+  // Shared check for structures and unions, which are accessed the same way.
+  // Builds a two-member record of `kind` with members `ty1` and `ty2`, then
+  // verifies the element type and type category seen through a pointer to the
+  // record and through pointers to each of its members.
+  void testRecordType(mlir::Type ty1, mlir::Type ty2,
+                      cir::RecordType::RecordKind kind) {
+    // Build the record type and a pointer to it.
+    cir::RecordType recordTy =
+        cir::RecordType::get(&context, getUniqueRecordName("S"), kind);
+    recordTy.complete({ty1, ty2}, false, false);
+    mlir::Type recordPtrTy = cir::PointerType::get(recordTy);
+
+    // The pointer must be pointer-like and point at the record type.
+    auto recordPlt =
+        dyn_cast_if_present<mlir::acc::PointerLikeType>(recordPtrTy);
+    ASSERT_NE(recordPlt, nullptr);
+    EXPECT_EQ(recordPlt.getElementType(), recordTy);
+
+    // Create an alloca for the record.
+    OwningOpRef<cir::AllocaOp> recordAlloca = b.create<cir::AllocaOp>(
+        loc, recordPtrTy, recordTy, "", getAlignOne(&context));
+    mlir::Value recordPtr = recordAlloca.get();
+
+    // The record itself is expected to be categorized as composite.
+    mlir::acc::VariableTypeCategory recordCategory =
+        recordPlt.getPointeeTypeCategory(
+            cast<TypedValue<mlir::acc::PointerLikeType>>(recordPtr),
+            mlir::acc::getVarType(recordAlloca.get()));
+    EXPECT_EQ(recordCategory, mlir::acc::VariableTypeCategory::composite);
+
+    // Take the address of the first member (index 0).
+    OwningOpRef<cir::GetMemberOp> firstMember = b.create<cir::GetMemberOp>(
+        loc, cir::PointerType::get(ty1), recordPtr, b.getStringAttr("f1"), 0);
+    mlir::Value firstMemberPtr = firstMember.get();
+
+    // Its pointee type must be the first member's type.
+    auto firstMemberPlt = dyn_cast_if_present<mlir::acc::PointerLikeType>(
+        firstMemberPtr.getType());
+    ASSERT_NE(firstMemberPlt, nullptr);
+    EXPECT_EQ(firstMemberPlt.getElementType(), ty1);
+
+    // A member access is still expected to be categorized as composite.
+    mlir::acc::VariableTypeCategory firstMemberCategory =
+        firstMemberPlt.getPointeeTypeCategory(
+            cast<TypedValue<mlir::acc::PointerLikeType>>(firstMemberPtr),
+            mlir::acc::getVarType(firstMember.get()));
+    EXPECT_EQ(firstMemberCategory, mlir::acc::VariableTypeCategory::composite);
+
+    // Take the address of the second member (index 1).
+    OwningOpRef<cir::GetMemberOp> secondMember = b.create<cir::GetMemberOp>(
+        loc, cir::PointerType::get(ty2), recordPtr, b.getStringAttr("f2"), 1);
+    mlir::Value secondMemberPtr = secondMember.get();
+
+    // Its pointee type must be the second member's type.
+    auto secondMemberPlt = dyn_cast_if_present<mlir::acc::PointerLikeType>(
+        secondMemberPtr.getType());
+    ASSERT_NE(secondMemberPlt, nullptr);
+    EXPECT_EQ(secondMemberPlt.getElementType(), ty2);
+
+    // Again, the member access is expected to be categorized as composite.
+    mlir::acc::VariableTypeCategory secondMemberCategory =
+        secondMemberPlt.getPointeeTypeCategory(
+            cast<TypedValue<mlir::acc::PointerLikeType>>(secondMemberPtr),
+            mlir::acc::getVarType(secondMember.get()));
+    EXPECT_EQ(secondMemberCategory, mlir::acc::VariableTypeCategory::composite);
+  }
+
+  // Runs the shared record checks on a struct with members `ty1` and `ty2`.
+  void testStructType(mlir::Type ty1, mlir::Type ty2) {
+    const auto kind = cir::RecordType::RecordKind::Struct;
+    testRecordType(ty1, ty2, kind);
+  }
+
+  // Runs the shared record checks on a union with members `ty1` and `ty2`.
+  void testUnionType(mlir::Type ty1, mlir::Type ty2) {
+    const auto kind = cir::RecordType::RecordKind::Union;
+    testRecordType(ty1, ty2, kind);
+  }
+
+  // This is testing a case like this:
+  //
+  // struct S {
+  //   int *f1;
+  //   int *f2;
+  // } *p;
+  // int *pMember = p->f2;
----------------
andykaylor wrote:

I talked this over with @razvanlupusoru offline, and if I understood what he
was telling me correctly, this case isn't reachable from OpenACC source code.
What I'm testing here is the pointer loaded from `p->f2`, but if one were to put
`pMember` in an OpenACC copy clause, we'd get a subsequent load from the
`cir.alloca` instruction (which is covered by other tests), not the load from
`p->f2`, which is what I'm testing here. My understanding is that if you put
`p->f2` in the OpenACC copy clause, we'd be using the result of the
`cir.get_member` operation, which is also covered by other tests.

I added this test because I wanted to cover the behavior of checking arbitrary 
`cir.load` operations in the middle of the CIR, but it sounds like that's not 
required by the `acc` dialect. Does that mean the "correct" result for this 
case is unspecified?

https://github.com/llvm/llvm-project/pull/139768
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to