beanz created this revision.
beanz added reviewers: aaron.ballman, kuhar, bogner, rnk, python3kgae.
Herald added subscribers: Anastasia, jeroen.dobbelaere, mgorny.
Herald added a project: All.
beanz requested review of this revision.
Herald added a project: clang.

HLSL vector types are ext_vector types, but they are also exposed via a
template syntax `vector<T, #>`. This is morally equivalent to the code:

  c++
  template <typename T, int Size>
  using vector = T __attribute__((ext_vector_type(Size)));

The problem is that templates aren't supported before HLSL 2021, and
type aliases still aren't supported in HLSL.

To resolve this (and other issues where HLSL can't represent its own
types), we rely on an external AST & Sema source being registered for
HLSL code.

This patch adds the HLSLExternalSemaSource and registers the vector
type alias.

Depends on D127802 <https://reviews.llvm.org/D127802>


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D128012

Files:
  clang/include/clang/Sema/HLSLExternalSemaSource.h
  clang/lib/Frontend/FrontendAction.cpp
  clang/lib/Headers/hlsl/hlsl_basic_types.h
  clang/lib/Sema/CMakeLists.txt
  clang/lib/Sema/HLSLExternalSemaSource.cpp
  clang/test/AST/HLSL/vector-alias.hlsl

Index: clang/test/AST/HLSL/vector-alias.hlsl
===================================================================
--- /dev/null
+++ clang/test/AST/HLSL/vector-alias.hlsl
@@ -0,0 +1,53 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s 
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'element __attribute__((ext_vector_type(element_count)))'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'element __attribute__((ext_vector_type(element_count)))' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+  // Verify that the alias is reachable through the explicit hlsl:: qualifier.
+  hlsl::vector<float, 2> Vec2 = {1.0, 2.0};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:24:3, col:43>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:42> col:26 Vec2 'hlsl::vector<float, 2>':'float __attribute__((ext_vector_type(2)))' cinit
+
+  // Verify that the unqualified name also resolves (no hlsl:: needed).
+  vector<int, 2> Vec2a = {1, 2};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:30:3, col:32>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:31> col:18 Vec2a 'vector<int, 2>':'int __attribute__((ext_vector_type(2)))' cinit
+
+  // Build a four-element vector of doubles.
+  vector<double, 4> Vec4 = {1.0, 2.0, 3.0, 4.0};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:36:3, col:48>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:47> col:21 used Vec4 'vector<double, 4>':'double __attribute__((ext_vector_type(4)))' cinit
+
+  // Verify that swizzles still work on the alias type.
+  vector<double, 3> Vec3 = Vec4.xyz;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:42:3, col:36>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:33> col:21 Vec3 'vector<double, 3>':'double __attribute__((ext_vector_type(3)))' cinit
+
+  // Verify that omitting both template arguments yields a float 4-vector.
+  vector<> ImpVec4 = {1.0, 2.0, 3.0, 4.0};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:48:3, col:42>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:41> col:12 ImpVec4 'vector<>':'float __attribute__((ext_vector_type(4)))' cinit
+  return 1;
+}
Index: clang/lib/Sema/HLSLExternalSemaSource.cpp
===================================================================
--- /dev/null
+++ clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -0,0 +1,98 @@
+//===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Sema/HLSLExternalSemaSource.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/Basic/AttrKinds.h"
+#include "clang/Sema/Sema.h"
+
+using namespace clang;
+
+char HLSLExternalSemaSource::ID;
+
+HLSLExternalSemaSource::~HLSLExternalSemaSource() = default;
+
+void HLSLExternalSemaSource::InitializeSema(Sema &S) {
+  SemaPtr = &S;
+  ASTContext &Ctx = S.getASTContext();
+  auto *TU = Ctx.getTranslationUnitDecl();
+
+  // Create the implicit `hlsl` namespace at translation-unit scope and
+  // populate it with the vector alias template.
+  HLSLNamespace = NamespaceDecl::Create(
+      Ctx, TU, false, SourceLocation(), SourceLocation(),
+      &Ctx.Idents.get("hlsl", tok::TokenKind::identifier), nullptr);
+  HLSLNamespace->setImplicit(true);
+  TU->addDecl(HLSLNamespace);
+  defineHLSLVectorAlias();
+
+  // Emit an implicit `using namespace hlsl` directive. DXC does not currently
+  // place HLSL's built-in types in a namespace, though that is planned to
+  // change soon. Older HLSL language versions must implicitly use the hlsl
+  // namespace to stay source compatible, so for now clang adds everything to
+  // the namespace and the using directive can be dropped for future language
+  // versions as HLSL evolves.
+  TU->addDecl(UsingDirectiveDecl::Create(
+      Ctx, TU, SourceLocation(), SourceLocation(), NestedNameSpecifierLoc(),
+      SourceLocation(), HLSLNamespace, TU));
+}
+
+/// Define the implicit alias template `hlsl::vector<element, element_count>`.
+void HLSLExternalSemaSource::defineHLSLVectorAlias() {
+  ASTContext &Ctx = SemaPtr->getASTContext();
+  const SourceLocation NoLoc{};
+
+  // First parameter: `element`, the element type, defaulting to float.
+  auto *ElemParam = TemplateTypeParmDecl::Create(
+      Ctx, HLSLNamespace, NoLoc, NoLoc, 0, 0,
+      &Ctx.Idents.get("element", tok::TokenKind::identifier), false, false);
+  ElemParam->setDefaultArgument(Ctx.getTrivialTypeSourceInfo(Ctx.FloatTy));
+
+  // Second parameter: `element_count`, an int, defaulting to the literal 4.
+  auto *CountParam = NonTypeTemplateParmDecl::Create(
+      Ctx, HLSLNamespace, NoLoc, NoLoc, 0, 1,
+      &Ctx.Idents.get("element_count", tok::TokenKind::identifier), Ctx.IntTy,
+      false, Ctx.getTrivialTypeSourceInfo(Ctx.IntTy));
+  const llvm::APInt DefaultCount(Ctx.getIntWidth(Ctx.IntTy), 4);
+  CountParam->setDefaultArgument(
+      IntegerLiteral::Create(Ctx, DefaultCount, Ctx.IntTy, NoLoc));
+
+  NamedDecl *Params[] = {ElemParam, CountParam};
+  auto *ParamList = TemplateParameterList::Create(Ctx, NoLoc, NoLoc, Params,
+                                                  NoLoc, nullptr);
+
+  // The aliased type is the dependent ext_vector
+  // `element __attribute__((ext_vector_type(element_count)))`.
+  Expr *CountRef = DeclRefExpr::Create(
+      Ctx, NestedNameSpecifierLoc(), NoLoc, CountParam, false,
+      DeclarationNameInfo(CountParam->getDeclName(), NoLoc), Ctx.IntTy,
+      VK_LValue);
+  QualType ElemTy = Ctx.getTemplateTypeParmType(0, 0, false, ElemParam);
+  QualType VecTy = Ctx.getDependentSizedExtVectorType(ElemTy, CountRef, NoLoc);
+
+  IdentifierInfo &VectorId =
+      Ctx.Idents.get("vector", tok::TokenKind::identifier);
+  auto *Alias = TypeAliasDecl::Create(Ctx, HLSLNamespace, NoLoc, NoLoc,
+                                      &VectorId,
+                                      Ctx.getTrivialTypeSourceInfo(VecTy));
+  Alias->setImplicit(true);
+
+  // Wrap the alias in its template and cross-link the two declarations.
+  auto *AliasTemplate = TypeAliasTemplateDecl::Create(
+      Ctx, HLSLNamespace, NoLoc, Alias->getIdentifier(), ParamList, Alias);
+  Alias->setDescribedAliasTemplate(AliasTemplate);
+  AliasTemplate->setImplicit(true);
+  AliasTemplate->setLexicalDeclContext(Alias->getDeclContext());
+
+  // Only the template is added to the namespace.
+  HLSLNamespace->addDecl(AliasTemplate);
+}
Index: clang/lib/Sema/CMakeLists.txt
===================================================================
--- clang/lib/Sema/CMakeLists.txt
+++ clang/lib/Sema/CMakeLists.txt
@@ -15,6 +15,7 @@
   CodeCompleteConsumer.cpp
   DeclSpec.cpp
   DelayedDiagnostic.cpp
+  HLSLExternalSemaSource.cpp
   IdentifierResolver.cpp
   JumpDiagnostics.cpp
   MultiplexExternalSemaSource.cpp
Index: clang/lib/Headers/hlsl/hlsl_basic_types.h
===================================================================
--- clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -27,38 +27,38 @@
 // built-in vector data types:
 
 #ifdef __HLSL_ENABLE_16_BIT
-typedef int16_t int16_t2 __attribute__((ext_vector_type(2)));
-typedef int16_t int16_t3 __attribute__((ext_vector_type(3)));
-typedef int16_t int16_t4 __attribute__((ext_vector_type(4)));
-typedef uint16_t uint16_t2 __attribute__((ext_vector_type(2)));
-typedef uint16_t uint16_t3 __attribute__((ext_vector_type(3)));
-typedef uint16_t uint16_t4 __attribute__((ext_vector_type(4)));
+typedef vector<int16_t, 2> int16_t2;
+typedef vector<int16_t, 3> int16_t3;
+typedef vector<int16_t, 4> int16_t4;
+typedef vector<uint16_t, 2> uint16_t2;
+typedef vector<uint16_t, 3> uint16_t3;
+typedef vector<uint16_t, 4> uint16_t4;
 #endif
 
-typedef int int2 __attribute__((ext_vector_type(2)));
-typedef int int3 __attribute__((ext_vector_type(3)));
-typedef int int4 __attribute__((ext_vector_type(4)));
-typedef uint uint2 __attribute__((ext_vector_type(2)));
-typedef uint uint3 __attribute__((ext_vector_type(3)));
-typedef uint uint4 __attribute__((ext_vector_type(4)));
-typedef int64_t int64_t2 __attribute__((ext_vector_type(2)));
-typedef int64_t int64_t3 __attribute__((ext_vector_type(3)));
-typedef int64_t int64_t4 __attribute__((ext_vector_type(4)));
-typedef uint64_t uint64_t2 __attribute__((ext_vector_type(2)));
-typedef uint64_t uint64_t3 __attribute__((ext_vector_type(3)));
-typedef uint64_t uint64_t4 __attribute__((ext_vector_type(4)));
+typedef vector<int, 2> int2;
+typedef vector<int, 3> int3;
+typedef vector<int, 4> int4;
+typedef vector<uint, 2> uint2;
+typedef vector<uint, 3> uint3;
+typedef vector<uint, 4> uint4;
+typedef vector<int64_t, 2> int64_t2;
+typedef vector<int64_t, 3> int64_t3;
+typedef vector<int64_t, 4> int64_t4;
+typedef vector<uint64_t, 2> uint64_t2;
+typedef vector<uint64_t, 3> uint64_t3;
+typedef vector<uint64_t, 4> uint64_t4;
 
 #ifdef __HLSL_ENABLE_16_BIT
-typedef half half2 __attribute__((ext_vector_type(2)));
-typedef half half3 __attribute__((ext_vector_type(3)));
-typedef half half4 __attribute__((ext_vector_type(4)));
+typedef vector<half, 2> half2;
+typedef vector<half, 3> half3;
+typedef vector<half, 4> half4;
 #endif
 
-typedef float float2 __attribute__((ext_vector_type(2)));
-typedef float float3 __attribute__((ext_vector_type(3)));
-typedef float float4 __attribute__((ext_vector_type(4)));
-typedef double double2 __attribute__((ext_vector_type(2)));
-typedef double double3 __attribute__((ext_vector_type(3)));
-typedef double double4 __attribute__((ext_vector_type(4)));
+typedef vector<float, 2> float2;
+typedef vector<float, 3> float3;
+typedef vector<float, 4> float4;
+typedef vector<double, 2> double2;
+typedef vector<double, 3> double3;
+typedef vector<double, 4> double4;
 
 #endif //_HLSL_HLSL_BASIC_TYPES_H_
Index: clang/lib/Frontend/FrontendAction.cpp
===================================================================
--- clang/lib/Frontend/FrontendAction.cpp
+++ clang/lib/Frontend/FrontendAction.cpp
@@ -24,6 +24,7 @@
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Lex/PreprocessorOptions.h"
 #include "clang/Parse/ParseAST.h"
+#include "clang/Sema/HLSLExternalSemaSource.h"
 #include "clang/Serialization/ASTDeserializationListener.h"
 #include "clang/Serialization/ASTReader.h"
 #include "clang/Serialization/GlobalModuleIndex.h"
@@ -1014,6 +1015,13 @@
     CI.getASTContext().setExternalSource(Override);
   }
 
+  // Setup HLSL External Sema Source
+  if (CI.getLangOpts().HLSL && CI.hasASTContext()) {
+    IntrusiveRefCntPtr<ExternalASTSource> HLSLSema(
+        new HLSLExternalSemaSource());
+    CI.getASTContext().setExternalSource(HLSLSema);
+  }
+
   FailureCleanup.release();
   return true;
 }
Index: clang/include/clang/Sema/HLSLExternalSemaSource.h
===================================================================
--- /dev/null
+++ clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -0,0 +1,43 @@
+//===--- HLSLExternalSemaSource.h - HLSL Sema Source ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines the HLSLExternalSemaSource interface.
+//
+//===----------------------------------------------------------------------===//
+#ifndef CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H
+#define CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H
+
+#include "clang/Sema/ExternalSemaSource.h"
+
+namespace clang {
+class NamespaceDecl;
+class Sema;
+
+class HLSLExternalSemaSource : public ExternalSemaSource {
+  static char ID;
+
+  Sema *SemaPtr = nullptr;
+  NamespaceDecl *HLSLNamespace = nullptr; // implicit `hlsl` namespace
+  /// Define the implicit `vector` alias template inside HLSLNamespace.
+  void defineHLSLVectorAlias();
+
+public:
+  ~HLSLExternalSemaSource() override;
+
+  /// Initialize the semantic source with the Sema instance being used to
+  /// perform semantic analysis on the abstract syntax tree, and create the
+  /// implicit HLSL declarations.
+  void InitializeSema(Sema &S) override;
+
+  /// Inform the semantic source that Sema is no longer available.
+  void ForgetSema() override { SemaPtr = nullptr; }
+};
+
+} // namespace clang
+
+#endif // CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to