https://github.com/mmha created https://github.com/llvm/llvm-project/pull/159162
This patch enables calling virtual functions of virtual base classes of a derived class. >From a25ee7f9df6eb7c5d6b237eb5b072c182b4f0681 Mon Sep 17 00:00:00 2001 From: Morris Hafner <mhaf...@nvidia.com> Date: Mon, 8 Sep 2025 16:38:50 +0200 Subject: [PATCH] [CIR] Add virtual base support to getAddressOfBaseClass This patch enables calling virtual functions of virtual base classes of a derived class. --- clang/lib/CIR/CodeGen/CIRGenClass.cpp | 42 +++++++++----- clang/test/CIR/CodeGen/vbase.cpp | 82 ++++++++++++++++++++++++++- 2 files changed, 110 insertions(+), 14 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp b/clang/lib/CIR/CodeGen/CIRGenClass.cpp index 0a8dc2b62fe21..08fdfad899a60 100644 --- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp @@ -865,28 +865,37 @@ Address CIRGenFunction::getAddressOfBaseClass( bool nullCheckValue, SourceLocation loc) { assert(!path.empty() && "Base path should not be empty!"); + CastExpr::path_const_iterator start = path.begin(); + const CXXRecordDecl *vBase = nullptr; + if ((*path.begin())->isVirtual()) { - // The implementation here is actually complete, but let's flag this - // as an error until the rest of the virtual base class support is in place. - cgm.errorNYI(loc, "getAddrOfBaseClass: virtual base"); - return Address::invalid(); + vBase = (*start)->getType()->castAsCXXRecordDecl(); + ++start; } // Compute the static offset of the ultimate destination within its // allocating subobject (the virtual base, if there is one, or else // the "complete" object that we see). - CharUnits nonVirtualOffset = - cgm.computeNonVirtualBaseClassOffset(derived, path); + CharUnits nonVirtualOffset = cgm.computeNonVirtualBaseClassOffset( + vBase ? vBase : derived, {start, path.end()}); + + // If there's a virtual step, we can sometimes "devirtualize" it. + // For now, that's limited to when the derived type is final. + // TODO: "devirtualize" this for accesses to known-complete objects. + if (vBase && derived->hasAttr<FinalAttr>()) { + const ASTRecordLayout &layout = getContext().getASTRecordLayout(derived); + CharUnits vBaseOffset = layout.getVBaseClassOffset(vBase); + nonVirtualOffset += vBaseOffset; + vBase = nullptr; // we no longer have a virtual step + } // Get the base pointer type. mlir::Type baseValueTy = convertType((path.end()[-1])->getType()); assert(!cir::MissingFeatures::addressSpace()); - // The if statement here is redundant now, but it will be needed when we add - // support for virtual base classes. // If there is no virtual base, use cir.base_class_addr. It takes care of // the adjustment and the null pointer check. - if (nonVirtualOffset.isZero()) { + if (nonVirtualOffset.isZero() && !vBase) { assert(!cir::MissingFeatures::sanitizers()); return builder.createBaseClassAddr(getLoc(loc), value, baseValueTy, 0, /*assumeNotNull=*/true); @@ -894,10 +903,17 @@ Address CIRGenFunction::getAddressOfBaseClass( assert(!cir::MissingFeatures::sanitizers()); - // Apply the offset - value = builder.createBaseClassAddr(getLoc(loc), value, baseValueTy, - nonVirtualOffset.getQuantity(), - /*assumeNotNull=*/true); + // Compute the virtual offset. + mlir::Value virtualOffset = nullptr; + if (vBase) { + virtualOffset = cgm.getCXXABI().getVirtualBaseClassOffset( + getLoc(loc), *this, value, derived, vBase); + } + + // Apply both offsets. + value = applyNonVirtualAndVirtualOffset( + getLoc(loc), *this, value, nonVirtualOffset, virtualOffset, derived, + vBase, baseValueTy, not nullCheckValue); // Cast to the destination type. value = value.withElementType(builder, baseValueTy); diff --git a/clang/test/CIR/CodeGen/vbase.cpp b/clang/test/CIR/CodeGen/vbase.cpp index 91396518a40b0..4d57f8ea74e0c 100644 --- a/clang/test/CIR/CodeGen/vbase.cpp +++ b/clang/test/CIR/CodeGen/vbase.cpp @@ -13,19 +13,29 @@ class Base { class Derived : public virtual Base {}; -// This is just here to force the record types to be emitted. void f() { Derived d; + d.f(); +} + +class DerivedFinal final : public virtual Base {}; + +void g() { + DerivedFinal df; + df.f(); } // CIR: !rec_Base = !cir.record<class "Base" {!cir.vptr}> // CIR: !rec_Derived = !cir.record<class "Derived" {!rec_Base}> +// CIR: !rec_DerivedFinal = !cir.record<class "DerivedFinal" {!rec_Base}> // LLVM: %class.Derived = type { %class.Base } // LLVM: %class.Base = type { ptr } +// LLVM: %class.DerivedFinal = type { %class.Base } // OGCG: %class.Derived = type { %class.Base } // OGCG: %class.Base = type { ptr } +// OGCG: %class.DerivedFinal = type { %class.Base } // Test the constructor handling for a class with a virtual base. struct A { @@ -47,6 +57,76 @@ void ppp() { B b; } // OGCG: @_ZTV1B = linkonce_odr unnamed_addr constant { [3 x ptr] } { [3 x ptr] [ptr inttoptr (i64 12 to ptr), ptr null, ptr @_ZTI1B] }, comdat, align 8 +// CIR: cir.func {{.*}}@_Z1fv() { +// CIR: %[[D:.+]] = cir.alloca !rec_Derived, !cir.ptr<!rec_Derived>, ["d", init] +// CIR: cir.call @_ZN7DerivedC1Ev(%[[D]]) nothrow : (!cir.ptr<!rec_Derived>) -> () +// CIR: %[[VPTR_PTR:.+]] = cir.vtable.get_vptr %[[D]] : !cir.ptr<!rec_Derived> -> !cir.ptr<!cir.vptr> +// CIR: %[[VPTR:.+]] = cir.load {{.*}} %[[VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr +// CIR: %[[VPTR_I8:.+]] = cir.cast(bitcast, %[[VPTR]] : !cir.vptr), !cir.ptr<!u8i> +// CIR: %[[NEG32:.+]] = cir.const #cir.int<-32> : !s64i +// CIR: %[[ADJ_VPTR_I8:.+]] = cir.ptr_stride(%[[VPTR_I8]] : !cir.ptr<!u8i>, %[[NEG32]] : !s64i), !cir.ptr<!u8i> +// CIR: %[[OFFSET_PTR:.+]] = cir.cast(bitcast, %[[ADJ_VPTR_I8]] : !cir.ptr<!u8i>), !cir.ptr<!s64i> +// CIR: %[[OFFSET:.+]] = cir.load {{.*}} %[[OFFSET_PTR]] : !cir.ptr<!s64i>, !s64i +// CIR: %[[D_I8:.+]] = cir.cast(bitcast, %[[D]] : !cir.ptr<!rec_Derived>), !cir.ptr<!u8i> +// CIR: %[[ADJ_THIS_I8:.+]] = cir.ptr_stride(%[[D_I8]] : !cir.ptr<!u8i>, %[[OFFSET]] : !s64i), !cir.ptr<!u8i> +// CIR: %[[ADJ_THIS_D:.+]] = cir.cast(bitcast, %[[ADJ_THIS_I8]] : !cir.ptr<!u8i>), !cir.ptr<!rec_Derived> +// CIR: %[[BASE_THIS:.+]] = cir.cast(bitcast, %[[ADJ_THIS_D]] : !cir.ptr<!rec_Derived>), !cir.ptr<!rec_Base> +// CIR: %[[BASE_VPTR_PTR:.+]] = cir.vtable.get_vptr %[[BASE_THIS]] : !cir.ptr<!rec_Base> -> !cir.ptr<!cir.vptr> +// CIR: %[[BASE_VPTR:.+]] = cir.load {{.*}} %[[BASE_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr +// CIR: %[[SLOT_PTR:.+]] = cir.vtable.get_virtual_fn_addr %[[BASE_VPTR]][0] : !cir.vptr -> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>> +// CIR: %[[FN:.+]] = cir.load {{.*}} %[[SLOT_PTR]] : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>>, !cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>> +// CIR: cir.call %[[FN]](%[[BASE_THIS]]) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>, !cir.ptr<!rec_Base>) -> () +// CIR: cir.return + +// CIR: cir.func {{.*}}@_Z1gv() { +// CIR: %[[DF:.+]] = cir.alloca !rec_DerivedFinal, !cir.ptr<!rec_DerivedFinal>, ["df", init] +// CIR: cir.call @_ZN12DerivedFinalC1Ev(%[[DF]]) nothrow : (!cir.ptr<!rec_DerivedFinal>) -> () +// CIR: %[[BASE_THIS_2:.+]] = cir.base_class_addr %[[DF]] : !cir.ptr<!rec_DerivedFinal> nonnull [0] -> !cir.ptr<!rec_Base> +// CIR: %[[BASE_VPTR_PTR_2:.+]] = cir.vtable.get_vptr %[[BASE_THIS_2]] : !cir.ptr<!rec_Base> -> !cir.ptr<!cir.vptr> +// CIR: %[[BASE_VPTR_2:.+]] = cir.load {{.*}} %[[BASE_VPTR_PTR_2]] : !cir.ptr<!cir.vptr>, !cir.vptr +// CIR: %[[SLOT_PTR_2:.+]] = cir.vtable.get_virtual_fn_addr %[[BASE_VPTR_2]][0] : !cir.vptr -> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>> +// CIR: %[[FN_2:.+]] = cir.load {{.*}} %[[SLOT_PTR_2]] : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>>, !cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>> +// CIR: cir.call %[[FN_2]](%[[BASE_THIS_2]]) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>, !cir.ptr<!rec_Base>) -> () +// CIR: cir.return + +// LLVM: define {{.*}}void @_Z1fv() +// LLVM: %[[D:.+]] = alloca {{.*}} +// LLVM: call void @_ZN7DerivedC1Ev(ptr %[[D]]) +// LLVM: %[[VPTR_ADDR:.+]] = load ptr, ptr %[[D]] +// LLVM: %[[NEG32_PTR:.+]] = getelementptr i8, ptr %[[VPTR_ADDR]], i64 -32 +// LLVM: %[[OFF:.+]] = load i64, ptr %[[NEG32_PTR]] +// LLVM: %[[ADJ_THIS:.+]] = getelementptr i8, ptr %[[D]], i64 %[[OFF]] +// LLVM: %[[VFN_TAB:.+]] = load ptr, ptr %[[ADJ_THIS]] +// LLVM: %[[SLOT0:.+]] = getelementptr inbounds ptr, ptr %[[VFN_TAB]], i32 0 +// LLVM: %[[VFN:.+]] = load ptr, ptr %[[SLOT0]] +// LLVM: call void %[[VFN]](ptr %[[ADJ_THIS]]) +// LLVM: ret void + +// LLVM: define {{.*}}void @_Z1gv() +// LLVM: %[[DF:.+]] = alloca {{.*}} +// LLVM: call void @_ZN12DerivedFinalC1Ev(ptr %[[DF]]) +// LLVM: %[[VPTR2:.+]] = load ptr, ptr %[[DF]] +// LLVM: %[[SLOT0_2:.+]] = getelementptr inbounds ptr, ptr %[[VPTR2]], i32 0 +// LLVM: %[[VFN2:.+]] = load ptr, ptr %[[SLOT0_2]] +// LLVM: call void %[[VFN2]](ptr %[[DF]]) +// LLVM: ret void + +// OGCG: define {{.*}}void @_Z1fv() +// OGCG: %[[D:.+]] = alloca {{.*}} +// OGCG: call void @_ZN7DerivedC1Ev(ptr {{.*}} %[[D]]) +// OGCG: %[[VTABLE:.+]] = load ptr, ptr %[[D]] +// OGCG: %[[NEG32_PTR:.+]] = getelementptr i8, ptr %[[VTABLE]], i64 -32 +// OGCG: %[[OFF:.+]] = load i64, ptr %[[NEG32_PTR]] +// OGCG: %[[ADJ_THIS:.+]] = getelementptr inbounds i8, ptr %[[D]], i64 %[[OFF]] +// OGCG: call void @_ZN4Base1fEv(ptr {{.*}} %[[ADJ_THIS]]) +// OGCG: ret void + +// OGCG: define {{.*}}void @_Z1gv() +// OGCG: %[[DF:.+]] = alloca {{.*}} +// OGCG: call void @_ZN12DerivedFinalC1Ev(ptr {{.*}} %[[DF]]) +// OGCG: call void @_ZN4Base1fEv(ptr {{.*}} %[[DF]]) +// OGCG: ret void + // Constructor for B // CIR: cir.func comdat linkonce_odr @_ZN1BC1Ev(%arg0: !cir.ptr<!rec_B> // CIR: %[[THIS_ADDR:.*]] = cir.alloca !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>, ["this", init] _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits