================
@@ -0,0 +1,586 @@
+//===--- HLSLEmitter.cpp - HLSL intrinsic header generator ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This tablegen backend generates hlsl_alias_intrinsics_gen.inc (alias
+// overloads) and hlsl_inline_intrinsics_gen.inc (inline/detail overloads) for
+// HLSL intrinsic functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace llvm;
+
+namespace {
+
+/// Shader model version string for the minimum version that supports 16-bit
+/// types.
+constexpr StringLiteral SM6_2("6.2");
+
+//===----------------------------------------------------------------------===//
+// Type name helpers
+//===----------------------------------------------------------------------===//
+
+/// Build a vector type name by appending the element count \p N to the scalar
+/// element type name \p ElemType (e.g. "float" and 3 give "float3").
+static std::string getVectorTypeName(StringRef ElemType, unsigned N) {
+  std::string Name = ElemType.str();
+  Name += Twine(N).str();
+  return Name;
+}
+
+/// Build a matrix type name of the form <ElemType><Rows>x<Cols>
+/// (e.g. "float", 3 and 4 give "float3x4").
+static std::string getMatrixTypeName(StringRef ElemType, unsigned Rows,
+                                     unsigned Cols) {
+  std::string Name = ElemType.str();
+  Name += Twine(Rows).str();
+  Name += "x";
+  Name += Twine(Cols).str();
+  return Name;
+}
+
+/// Return the concrete type name spelled by \p R.
+///
+/// A VectorType record yields its element type's "Name" followed by its
+/// "Size"; any other record is asserted to be an HLSLType and yields its own
+/// "Name" field.
+static std::string getFixedTypeName(const Record *R) {
+  if (!R->isSubClassOf("VectorType")) {
+    assert(R->isSubClassOf("HLSLType"));
+    return R->getValueAsString("Name").str();
+  }
+
+  const StringRef ElemName =
+      R->getValueAsDef("ElementType")->getValueAsString("Name");
+  const unsigned NumElems = R->getValueAsInt("Size");
+  return getVectorTypeName(ElemName, NumElems);
+}
+
+/// Return the record describing the scalar element type of \p R: the
+/// "ElementType" def of a VectorType, or \p R itself when it is an HLSLType
+/// (which is already a scalar element type).
+static const Record *getElementTypeRecord(const Record *R) {
+  const bool IsVector = R->isSubClassOf("VectorType");
+  // Anything that is not a vector must already be a scalar HLSLType.
+  assert(IsVector || R->isSubClassOf("HLSLType"));
+  return IsVector ? R->getValueAsDef("ElementType") : R;
+}
+
+//===----------------------------------------------------------------------===//
+// Type information
+//===----------------------------------------------------------------------===//
+
+/// Describes where an argument or return type gets its concrete spelling from
+/// when an intrinsic is expanded into overloads. Enumerators are numbered
+/// consecutively from zero.
+enum TypeKindEnum {
+  /// Type matches the full varying type (e.g. float3).
+  TK_Varying,
+  /// Type is the scalar element type (e.g. float).
+  TK_ElemType,
+  /// Type uses the varying shape with a fixed element.
+  TK_VaryingShape,
+  /// Type is a fixed concrete type (e.g. "half2").
+  TK_FixedType,
+  /// Type is void (only valid for return types).
+  TK_Void
+};
+
+/// Metadata describing how a type (argument or return) varies across
+/// overloads.
+struct TypeInfo {
+  /// Classification of how this type varies across overloads.
+  TypeKindEnum Kind = TK_Varying;
+
+  /// Fixed type name (e.g. "half2") for types with a concrete type that does
+  /// not vary across overloads. Empty for varying types.
+  std::string FixedType;
+
+  /// Element type name for TK_VaryingShape types (e.g. "bool" for
+  /// VaryingShape<BoolTy>). Empty for other type kinds.
+  StringRef ShapeElemType;
+
+  /// Explicit parameter name (e.g. "eta"). Empty to use the default "p0",
+  /// "p1", ... naming. Only meaningful for argument types.
+  StringRef Name;
+
+  /// Construct a TypeInfo from a TableGen record.
+  ///
+  /// The record is classified first by name ("VoidTy", "Varying",
+  /// "VaryingElemType") and then by superclass (VaryingShape, VectorType,
+  /// HLSLType). Only Kind, ShapeElemType and FixedType are filled in here;
+  /// Name is left empty.
+  ///
+  /// A record matching none of these comes from the .td input, so it is
+  /// reported as a fatal TableGen error on that record instead of being
+  /// treated as unreachable code.
+  static TypeInfo resolve(const Record *Rec) {
+    TypeInfo TI;
+    if (Rec->getName() == "VoidTy") {
+      TI.Kind = TK_Void;
+    } else if (Rec->getName() == "Varying") {
+      TI.Kind = TK_Varying;
+    } else if (Rec->getName() == "VaryingElemType") {
+      TI.Kind = TK_ElemType;
+    } else if (Rec->isSubClassOf("VaryingShape")) {
+      TI.Kind = TK_VaryingShape;
+      TI.ShapeElemType =
+          Rec->getValueAsDef("ElementType")->getValueAsString("Name");
+    } else if (Rec->isSubClassOf("VectorType") ||
+               Rec->isSubClassOf("HLSLType")) {
+      TI.Kind = TK_FixedType;
+      TI.FixedType = getFixedTypeName(Rec);
+    } else {
+      PrintFatalError(Rec, "unhandled record for type resolution");
+    }
+    return TI;
+  }
+
+  /// Resolve this type to a concrete type name string.
+  /// \p ElemType is the scalar element type for the current overload.
+  /// \p FormatVarying formats a scalar element type into the shaped type name.
+  /// \returns "void" for TK_Void, \p ElemType itself for TK_ElemType, the
+  /// stored FixedType for TK_FixedType, and the result of \p FormatVarying
+  /// applied to \p ElemType (TK_Varying) or ShapeElemType (TK_VaryingShape).
+  std::string
+  toTypeString(StringRef ElemType,
+               function_ref<std::string(StringRef)> FormatVarying) const {
+    switch (Kind) {
+    case TK_Void:
+      return "void";
+    case TK_Varying:
+      return FormatVarying(ElemType);
+    case TK_ElemType:
+      return ElemType.str();
+    case TK_VaryingShape:
+      return FormatVarying(ShapeElemType);
+    case TK_FixedType:
+      assert(!FixedType.empty() &&
+             "TK_FixedType requires non-empty FixedType");
+      return FixedType;
+    }
+    llvm_unreachable("unhandled TypeKindEnum");
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Availability helpers
+//===----------------------------------------------------------------------===//
+
+/// Write one availability annotation line for shader model \p Version to
+/// \p OS. When \p Use16Bit is true the _HLSL_16BIT_AVAILABILITY macro is used,
+/// otherwise _HLSL_AVAILABILITY.
+static void emitAvailability(raw_ostream &OS, StringRef Version,
+                             bool Use16Bit = false) {
+  const char *MacroPrefix = Use16Bit ? "_HLSL_16BIT_AVAILABILITY(shadermodel, "
+                                     : "_HLSL_AVAILABILITY(shadermodel, ";
+  OS << MacroPrefix << Version << ")\n";
+}
+
+/// Format the "Major" and "Minor" fields of record \p SM as "Major.Minor".
+/// Returns an empty string when both fields are zero.
+static std::string getVersionString(const Record *SM) {
+  const unsigned Major = SM->getValueAsInt("Major");
+  const unsigned Minor = SM->getValueAsInt("Minor");
+  const bool HasVersion = Major != 0 || Minor != 0;
+  return HasVersion ? (Twine(Major) + "." + Twine(Minor)).str()
+                    : std::string();
+}
+
+//===----------------------------------------------------------------------===//
+// Type work item — describes one element type to emit overloads for
+//===----------------------------------------------------------------------===//
+
+/// One worklist entry: a single element type for which an intrinsic's
+/// overloads are to be emitted, plus how those overloads are annotated.
+struct TypeWorkItem {
+  /// Name of the element type (e.g. "half", "float"). Left empty for
+  /// fixed-arg-only intrinsics, which have no type expansion.
+  StringRef ElemType;
+
+  /// Version string to use in the availability attribute (e.g. "6.2"). Left
+  /// empty when no availability annotation is needed.
+  StringRef Availability;
+
+  /// Selects _HLSL_16BIT_AVAILABILITY over _HLSL_AVAILABILITY when true.
+  bool Use16BitAvail{false};
+
+  /// When true, the overloads are wrapped in
+  /// #ifdef __HLSL_ENABLE_16_BIT / #endif.
+  bool NeedsIfdefGuard{false};
+};
+
+/// Fixed canonical ordering for overload types, keyed on the "Name" field of
+/// the element type record \p ET. Types are grouped as:
+///   0: conditionally-16-bit (half)
+///   1-2: 16-bit integers (int16_t, uint16_t) — ifdef-guarded
+///   3-10: regular types (bool, int, uint, uint32_t, int64_t, uint64_t,
+///         float, double)
+///   11: any name not listed above
+/// Within each group, signed precedes unsigned, smaller precedes larger,
+/// and integer types precede floating-point types.
+static int getTypeSortPriority(const Record *ET) {
+  // The index of a name in this table is its priority.
+  static constexpr StringLiteral CanonicalOrder[] = {
+      "half", "int16_t", "uint16_t", "bool",  "int",   "uint",
+      "uint32_t", "int64_t", "uint64_t", "float", "double"};
+
+  const StringRef TypeName = ET->getValueAsString("Name");
+  const StringLiteral *Pos = llvm::find(CanonicalOrder, TypeName);
+  // A miss leaves Pos one past the last entry, giving the fallback value 11.
+  return static_cast<int>(Pos - CanonicalOrder);
+}
+
+//===----------------------------------------------------------------------===//
+// Overload context — shared state across all overloads of one intrinsic
+//===----------------------------------------------------------------------===//
+
+/// State shared by every overload emitted for one HLSL intrinsic.
+struct OverloadContext {
+  /// Bind the context to the stream \p OS; every other field starts out
+  /// empty or false.
+  explicit OverloadContext(raw_ostream &OS) : OS(OS) {}
+
+  /// Stream the generated code is written to.
+  raw_ostream &OS;
+
+  /// Name of the builtin passed to _HLSL_BUILTIN_ALIAS
+  /// (e.g. "__builtin_hlsl_dot"). Empty for inline/detail intrinsics.
+  StringRef Builtin;
+
+  /// Name of the __detail helper function to call (e.g. "refract_impl").
+  /// Empty for alias and inline-body intrinsics.
+  StringRef DetailFunc;
+
+  /// Literal body of an inline function (e.g. "return p0;").
+  /// Empty for alias and detail intrinsics.
+  StringRef Body;
+
+  /// Name of the HLSL function being emitted (e.g. "dot", "refract").
+  StringRef FuncName;
+
+  /// Return type metadata, including how it varies across overloads.
+  TypeInfo RetType;
+
+  /// One metadata entry per argument, describing its type and variation.
+  SmallVector<TypeInfo, 4> Args;
+
+  /// True when the function is to be emitted as constexpr.
+  bool IsConstexpr{false};
+
+  /// True when the __attribute__((convergent)) annotation is to be emitted.
+  bool IsConvergent{false};
+
+  /// True when any fixed arg has a 16-bit integer type (e.g. int16_t).
+  bool Uses16BitType{false};
+
+  /// True when any fixed arg has a conditionally-16-bit type (half).
+  bool UsesConditionally16BitType{false};
+};
+
+/// Emit a complete function declaration or definition with pre-resolved types.
+static void emitDeclaration(const OverloadContext &Ctx, StringRef RetType,
+                            ArrayRef<std::string> ArgTypes) {
+  raw_ostream &OS = Ctx.OS;
+  bool IsDetail = !Ctx.DetailFunc.empty();
+  bool IsInline = !Ctx.Body.empty();
+  bool HasBody = IsDetail || IsInline;
+
+  bool EmitNames = HasBody || llvm::any_of(Ctx.Args, [](const TypeInfo &A) {
+                     return !A.Name.empty();
+                   });
+
+  auto GetParamName = [&](unsigned I) -> std::string {
+    if (!Ctx.Args[I].Name.empty())
+      return Ctx.Args[I].Name.str();
+    return ("p" + Twine(I)).str();
+  };
+
+  if (!HasBody)
+    OS << "_HLSL_BUILTIN_ALIAS(" << Ctx.Builtin << ")\n";
+  if (Ctx.IsConvergent)
+    OS << "__attribute__((convergent)) ";
+  if (HasBody)
+    OS << (Ctx.IsConstexpr ? "constexpr " : "inline ");
+  OS << RetType << " " << Ctx.FuncName << "(";
+
+  for (unsigned I = 0, N = ArgTypes.size(); I < N; ++I) {
+    if (I > 0)
----------------
jurahul wrote:

nit: use ListSeparator instead of manual code for inserting commas

https://github.com/llvm/llvm-project/pull/187610
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to