Meinersbur updated this revision to Diff 233704.
Meinersbur added a comment.

- Simplify transformation classes


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D69089/new/

https://reviews.llvm.org/D69089

Files:
  clang/include/clang/AST/StmtTransform.h
  clang/include/clang/AST/TransformClauseKinds.def
  clang/include/clang/Basic/DiagnosticGroups.td
  clang/include/clang/Basic/DiagnosticParseKinds.td
  clang/include/clang/Basic/Transform.h
  clang/include/clang/Basic/TransformKinds.def
  clang/include/clang/Parse/Parser.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/CMakeLists.txt
  clang/lib/AST/StmtTransform.cpp
  clang/lib/Basic/CMakeLists.txt
  clang/lib/Basic/Transform.cpp
  clang/lib/Parse/CMakeLists.txt
  clang/lib/Parse/ParseStmt.cpp
  clang/lib/Parse/ParseTransform.cpp
  clang/lib/Sema/CMakeLists.txt
  clang/lib/Sema/SemaTransform.cpp
  clang/test/Parser/pragma-transform.cpp

Index: clang/test/Parser/pragma-transform.cpp
===================================================================
--- /dev/null
+++ clang/test/Parser/pragma-transform.cpp
@@ -0,0 +1,92 @@
+// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -verify %s
+
+void pragma_transform(int *List, int Length) {
+// FIXME: This does not emit an error
+#pragma clang
+
+/* expected-error@+1 {{expected a transformation name}} */
+#pragma clang transform
+  for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+
+/* expected-error@+1 {{unknown transformation}} */
+#pragma clang transform unknown_transformation
+  for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+
+/* expected-error@+2 {{expected loop after transformation pragma}} */
+#pragma clang transform unroll
+  pragma_transform(List, Length);
+
+/* expected-error@+1 {{unknown clause name}} */
+#pragma clang transform unroll unknown_clause
+  for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+
+/* expected-error@+1 {{expected '(' after 'partial'}} */
+#pragma clang transform unroll partial
+  for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+
+/* expected-error@+1 {{expected expression}} */
+#pragma clang transform unroll partial(
+  for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+
+/* expected-error@+1 {{expected '(' after 'partial'}} */
+#pragma clang transform unroll partial)
+  for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+
+/* expected-error@+2 {{expected ')'}} */
+/* expected-note@+1 {{to match this '('}} */
+#pragma clang transform unroll partial(4
+  for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+
+/* expected-error@+1 {{expected expression}} */
+#pragma clang transform unroll partial()
+  for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+
+/* expected-error@+1 {{use of undeclared identifier 'badvalue'}} */
+#pragma clang transform unroll partial(badvalue)
+  for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+
+  {
+/* expected-error@+2 {{expected statement}} */
+#pragma clang transform unroll
+  }
+}
+
+/* expected-error@+1 {{expected unqualified-id}} */
+#pragma clang transform unroll
+int I;
+
+/* expected-error@+1 {{expected unqualified-id}} */
+#pragma clang transform unroll
+void func();
+
+class C1 {
+/* expected-error@+3 {{this pragma cannot appear in class declaration}} */
+/* expected-error@+2 {{expected member name or ';' after declaration specifiers}} */
+/* expected-error@+1 {{unknown type name 'unroll'}} */
+#pragma clang transform unroll
+};
+
+template<int F>
+void pragma_transform_template_func(int *List, int Length) {
+#pragma clang transform unroll partial(F)
+  for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+}
+
+template<int F>
+class C2 {
+  void pragma_transform_template_class(int *List, int Length) {
+#pragma clang transform unroll partial(F)
+    for (int i = 0; i < Length; i+=1)
+        List[i] = i;
+  }
+};
Index: clang/lib/Sema/SemaTransform.cpp
===================================================================
--- /dev/null
+++ clang/lib/Sema/SemaTransform.cpp
@@ -0,0 +1,49 @@
+//===---- SemaTransform.cpp ----------------------------------- -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Semantic analysis for code transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/StmtTransform.h"
+#include "clang/Basic/Transform.h"
+#include "clang/Sema/Sema.h"
+#include "clang/Sema/SemaDiagnostic.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringMap.h"
+
+using namespace clang;
+
+StmtResult
+Sema::ActOnLoopTransformDirective(Transform::Kind Kind,
+                                  llvm::ArrayRef<TransformClause *> Clauses,
+                                  Stmt *AStmt, SourceRange Loc) {
+  // TODO: implement; this placeholder always returns StmtError().
+  return StmtError();
+}
+
+TransformClause *Sema::ActOnFullClause(SourceRange Loc) {
+  // TODO: implement; this placeholder creates no clause and returns nullptr.
+  return nullptr;
+}
+
+TransformClause *Sema::ActOnPartialClause(SourceRange Loc, Expr *Factor) {
+  // TODO: implement; this placeholder creates no clause and returns nullptr.
+  return nullptr;
+}
+
+TransformClause *Sema::ActOnWidthClause(SourceRange Loc, Expr *Width) {
+  // TODO: implement; this placeholder creates no clause and returns nullptr.
+  return nullptr;
+}
+
+TransformClause *Sema::ActOnFactorClause(SourceRange Loc, Expr *Factor) {
+  // TODO: implement; this placeholder creates no clause and returns nullptr.
+  return nullptr;
+}
Index: clang/lib/Sema/CMakeLists.txt
===================================================================
--- clang/lib/Sema/CMakeLists.txt
+++ clang/lib/Sema/CMakeLists.txt
@@ -63,6 +63,7 @@
   SemaTemplateInstantiate.cpp
   SemaTemplateInstantiateDecl.cpp
   SemaTemplateVariadic.cpp
+  SemaTransform.cpp
   SemaType.cpp
   TypeLocBuilder.cpp
 
Index: clang/lib/Parse/ParseTransform.cpp
===================================================================
--- /dev/null
+++ clang/lib/Parse/ParseTransform.cpp
@@ -0,0 +1,145 @@
+//===---- ParseTransform.cpp ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Parse #pragma clang transform ...
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/StmtTransform.h"
+#include "clang/Parse/Parser.h"
+#include "clang/Parse/RAIIObjectsForParser.h"
+
+using namespace clang;
+
+Transform::Kind
+Parser::tryParsePragmaTransform(SourceLocation BeginLoc,
+                                ParsedStmtContext StmtCtx,
+                                SmallVectorImpl<TransformClause *> &Clauses) {
+  // ... Tok=<transform> | <...> tok::annot_pragma_transform_end ...
+  if (Tok.isNot(tok::identifier)) {
+    Diag(Tok, diag::err_pragma_transform_expected_directive);
+    return Transform::UnknownKind;
+  }
+  std::string DirectiveStr = PP.getSpelling(Tok);
+  Transform::Kind DirectiveKind =
+      Transform::getTransformDirectiveKind(DirectiveStr);
+  ConsumeToken();
+
+  switch (DirectiveKind) {
+  case Transform::LoopUnrollKind:
+  case Transform::LoopUnrollAndJamKind:
+  case Transform::LoopDistributionKind:
+  case Transform::LoopVectorizationKind:
+  case Transform::LoopInterleavingKind:
+    break;
+  default:
+    Diag(Tok, diag::err_pragma_transform_unknown_directive);
+    return Transform::UnknownKind;
+  }
+
+  while (true) {
+    TransformClauseResult Clause = ParseTransformClause(DirectiveKind);
+    if (Clause.isInvalid())
+      return Transform::UnknownKind;
+    if (!Clause.isUsable())
+      break;
+
+    Clauses.push_back(Clause.get());
+  }
+
+  assert(Tok.is(tok::annot_pragma_transform_end));
+  return DirectiveKind;
+}
+
+StmtResult Parser::ParsePragmaTransform(ParsedStmtContext StmtCtx) {
+  assert(Tok.is(tok::annot_pragma_transform) && "Not a transform directive!");
+
+  // ... Tok=annot_pragma_transform | <trans> <...> annot_pragma_transform_end
+  // ...
+  SourceLocation BeginLoc = ConsumeAnnotationToken();
+
+  ParenBraceBracketBalancer BalancerRAIIObj(*this);
+
+  SmallVector<TransformClause *, 8> DirectiveClauses;
+  Transform::Kind DirectiveKind =
+      tryParsePragmaTransform(BeginLoc, StmtCtx, DirectiveClauses);
+  if (DirectiveKind == Transform::UnknownKind) {
+    SkipUntil(tok::annot_pragma_transform_end);
+    return StmtError();
+  }
+
+  assert(Tok.is(tok::annot_pragma_transform_end));
+  SourceLocation EndLoc = ConsumeAnnotationToken();
+
+  SourceLocation PreStmtLoc = Tok.getLocation();
+  StmtResult AssociatedStmt = ParseStatement();
+  if (AssociatedStmt.isInvalid())
+    return AssociatedStmt;
+  if (!getAssociatedLoop(AssociatedStmt.get()))
+    return StmtError(
+        Diag(PreStmtLoc, diag::err_pragma_transform_expected_loop));
+
+  return Actions.ActOnLoopTransformDirective(DirectiveKind, DirectiveClauses,
+                                             AssociatedStmt.get(),
+                                             {BeginLoc, EndLoc});
+}
+
+Parser::TransformClauseResult
+Parser::ParseTransformClause(Transform::Kind TransformKind) {
+  // No more clauses
+  if (Tok.is(tok::annot_pragma_transform_end))
+    return ClauseEmpty();
+
+  SourceLocation StartLoc = Tok.getLocation();
+  if (Tok.isNot(tok::identifier))
+    return ClauseError(Diag(Tok, diag::err_pragma_transform_expected_clause));
+  std::string ClauseKeyword = PP.getSpelling(Tok);
+  ConsumeToken();
+  TransformClause::Kind Kind =
+      TransformClause::getClauseKind(TransformKind, ClauseKeyword);
+
+  switch (Kind) {
+  case TransformClause::UnknownKind:
+    return ClauseError(Diag(Tok, diag::err_pragma_transform_unknown_clause));
+
+    // Clauses without arguments.
+  case TransformClause::FullKind:
+    return Actions.ActOnFullClause(SourceRange{StartLoc, StartLoc});
+
+    // Clauses with integer argument.
+  case TransformClause::PartialKind:
+  case TransformClause::WidthKind:
+  case TransformClause::FactorKind: {
+    BalancedDelimiterTracker T(*this, tok::l_paren,
+                               tok::annot_pragma_transform_end);
+    if (T.expectAndConsume(diag::err_expected_lparen_after,
+                           ClauseKeyword.data()))
+      return ClauseError();
+
+    ExprResult Expr = ParseConstantExpression();
+    if (Expr.isInvalid())
+      return ClauseError();
+
+    if (T.consumeClose())
+      return ClauseError();
+    SourceLocation EndLoc = T.getCloseLocation();
+    SourceRange Range{StartLoc, EndLoc};
+    switch (Kind) {
+    case TransformClause::PartialKind:
+      return Actions.ActOnPartialClause(Range, Expr.get());
+    case TransformClause::WidthKind:
+      return Actions.ActOnWidthClause(Range, Expr.get());
+    case TransformClause::FactorKind:
+      return Actions.ActOnFactorClause(Range, Expr.get());
+    default:
+      llvm_unreachable("Unhandled clause");
+    }
+  }
+  }
+  llvm_unreachable("Unhandled clause");
+}
Index: clang/lib/Parse/ParseStmt.cpp
===================================================================
--- clang/lib/Parse/ParseStmt.cpp
+++ clang/lib/Parse/ParseStmt.cpp
@@ -14,6 +14,7 @@
 #include "clang/AST/PrettyDeclStackTrace.h"
 #include "clang/Basic/Attributes.h"
 #include "clang/Basic/PrettyStackTrace.h"
+#include "clang/Basic/Transform.h"
 #include "clang/Parse/LoopHint.h"
 #include "clang/Parse/Parser.h"
 #include "clang/Parse/RAIIObjectsForParser.h"
@@ -400,6 +401,10 @@
     ProhibitAttributes(Attrs);
     return ParsePragmaLoopHint(Stmts, StmtCtx, TrailingElseLoc, Attrs);
 
+  case tok::annot_pragma_transform:
+    ProhibitAttributes(Attrs);
+    return ParsePragmaTransform(StmtCtx);
+
   case tok::annot_pragma_dump:
     HandlePragmaDump();
     return StmtEmpty();
Index: clang/lib/Parse/CMakeLists.txt
===================================================================
--- clang/lib/Parse/CMakeLists.txt
+++ clang/lib/Parse/CMakeLists.txt
@@ -20,6 +20,7 @@
   ParseStmtAsm.cpp
   ParseTemplate.cpp
   ParseTentative.cpp
+  ParseTransform.cpp
   Parser.cpp
 
   LINK_LIBS
Index: clang/lib/Basic/Transform.cpp
===================================================================
--- /dev/null
+++ clang/lib/Basic/Transform.cpp
@@ -0,0 +1,62 @@
+//===--- Transform.cpp - Code transformation classes ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines classes used for code transformations such as
+//  #pragma clang transform ...
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Basic/Transform.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Casting.h"
+
+using namespace clang;
+
+Transform::Kind Transform ::getTransformDirectiveKind(llvm::StringRef Str) {
+  return llvm::StringSwitch<Transform::Kind>(Str)
+      .Case("unroll", LoopUnrollKind)
+      .Case("unrollandjam", LoopUnrollAndJamKind)
+      .Case("vectorize", LoopVectorizationKind)
+      .Case("interleave", LoopInterleavingKind)
+      .Case("distribute", LoopDistributionKind)
+      .Default(UnknownKind);
+}
+
+llvm::StringRef Transform ::getTransformDirectiveKeyword(Kind K) {
+  switch (K) {
+  case UnknownKind:
+    break;
+  case LoopUnrollKind:
+    return "unroll";
+  case LoopUnrollAndJamKind:
+    return "unrollandjam";
+  case LoopVectorizationKind:
+    return "vectorize";
+  case LoopInterleavingKind:
+    return "interleave";
+  case LoopDistributionKind:
+    return "distribute";
+  }
+  llvm_unreachable("Not a known transformation");
+}
+
+int Transform::getLoopPipelineStage() const {
+  switch (getKind()) {
+  case Transform::Kind::LoopUnrollKind:
+    return cast<LoopUnrollTransform>(this)->isFull() ? 0 : 4;
+  case Transform::Kind::LoopDistributionKind:
+    return 1;
+  case Transform::Kind::LoopInterleavingKind:
+  case Transform::Kind::LoopVectorizationKind:
+    return 2;
+  case Transform::Kind::LoopUnrollAndJamKind:
+    return 3;
+  default:
+    return -1;
+  }
+}
Index: clang/lib/Basic/CMakeLists.txt
===================================================================
--- clang/lib/Basic/CMakeLists.txt
+++ clang/lib/Basic/CMakeLists.txt
@@ -87,6 +87,7 @@
   Targets/X86.cpp
   Targets/XCore.cpp
   TokenKinds.cpp
+  Transform.cpp
   Version.cpp
   Warnings.cpp
   XRayInstr.cpp
Index: clang/lib/AST/StmtTransform.cpp
===================================================================
--- /dev/null
+++ clang/lib/AST/StmtTransform.cpp
@@ -0,0 +1,69 @@
+//===--- StmtTransform.cpp - Code transformation AST nodes ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  Transformation directive statement and clauses for the AST.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/StmtTransform.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Stmt.h"
+#include "clang/AST/StmtOpenMP.h"
+
+using namespace clang;
+
+bool TransformClause::isValidForTransform(Transform::Kind TransformKind,
+                                          TransformClause::Kind ClauseKind) {
+  switch (TransformKind) {
+  case clang::Transform::LoopUnrollKind:
+    return ClauseKind == PartialKind || ClauseKind == FullKind;
+  case clang::Transform::LoopUnrollAndJamKind:
+    return ClauseKind == PartialKind;
+  case clang::Transform::LoopVectorizationKind:
+    return ClauseKind == WidthKind;
+  case clang::Transform::LoopInterleavingKind:
+    return ClauseKind == FactorKind;
+  default:
+    return false;
+  }
+}
+
+TransformClause::Kind
+TransformClause ::getClauseKind(Transform::Kind TransformKind,
+                                llvm::StringRef Str) {
+#define TRANSFORM_CLAUSE(Keyword, Name)                                        \
+  if (isValidForTransform(TransformKind, TransformClause::Kind::Name##Kind) && \
+      Str == #Keyword)                                                         \
+    return TransformClause::Kind::Name##Kind;
+#include "clang/AST/TransformClauseKinds.def"
+  return TransformClause::UnknownKind;
+}
+
+llvm::StringRef
+TransformClause ::getClauseKeyword(TransformClause::Kind ClauseKind) {
+  assert(ClauseKind > UnknownKind);
+  assert(ClauseKind <= LastKind);
+  static const char *ClauseKeyword[LastKind] = {
+#define TRANSFORM_CLAUSE(Keyword, Name) #Keyword,
+#include "clang/AST/TransformClauseKinds.def"
+
+  };
+  return ClauseKeyword[ClauseKind - 1];
+}
+
+const Stmt *clang::getAssociatedLoop(const Stmt *S) {
+  switch (S->getStmtClass()) {
+  case Stmt::ForStmtClass:
+  case Stmt::WhileStmtClass:
+  case Stmt::DoStmtClass:
+  case Stmt::CXXForRangeStmtClass:
+    return S;
+  default:
+    return nullptr;
+  }
+}
Index: clang/lib/AST/CMakeLists.txt
===================================================================
--- clang/lib/AST/CMakeLists.txt
+++ clang/lib/AST/CMakeLists.txt
@@ -100,6 +100,7 @@
   StmtOpenMP.cpp
   StmtPrinter.cpp
   StmtProfile.cpp
+  StmtTransform.cpp
   StmtViz.cpp
   TemplateBase.cpp
   TemplateName.cpp
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -29,6 +29,7 @@
 #include "clang/AST/NSAPI.h"
 #include "clang/AST/PrettyPrinter.h"
 #include "clang/AST/StmtCXX.h"
+#include "clang/AST/StmtTransform.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/AST/TypeOrdering.h"
 #include "clang/Basic/ExpressionTraits.h"
@@ -11749,6 +11750,16 @@
     ConstructorDestructor,
     BuiltinFunction
   };
+
+  StmtResult
+  ActOnLoopTransformDirective(Transform::Kind Kind,
+                              llvm::ArrayRef<TransformClause *> Clauses,
+                              Stmt *AStmt, SourceRange Loc);
+
+  TransformClause *ActOnFullClause(SourceRange Loc);
+  TransformClause *ActOnPartialClause(SourceRange Loc, Expr *Factor);
+  TransformClause *ActOnWidthClause(SourceRange Loc, Expr *Width);
+  TransformClause *ActOnFactorClause(SourceRange Loc, Expr *Factor);
 };
 
 /// RAII object that enters a new expression evaluation context.
Index: clang/include/clang/Parse/Parser.h
===================================================================
--- clang/include/clang/Parse/Parser.h
+++ clang/include/clang/Parse/Parser.h
@@ -1647,6 +1647,17 @@
     IsTypeCast
   };
 
+  using TransformClauseResult = ActionResult<TransformClause *>;
+  static TransformClauseResult ClauseError() {
+    return TransformClauseResult(true);
+  }
+  static TransformClauseResult ClauseError(const DiagnosticBuilder &) {
+    return ClauseError();
+  }
+  static TransformClauseResult ClauseEmpty() {
+    return TransformClauseResult(false);
+  }
+
   ExprResult ParseExpression(TypeCastState isTypeCast = NotTypeCast);
   ExprResult ParseConstantExpressionInExprEvalContext(
       TypeCastState isTypeCast = NotTypeCast);
@@ -1983,6 +1994,12 @@
                                  SourceLocation *TrailingElseLoc,
                                  ParsedAttributesWithRange &Attrs);
 
+  Transform::Kind
+  tryParsePragmaTransform(SourceLocation BeginLoc, ParsedStmtContext StmtCtx,
+                          SmallVectorImpl<TransformClause *> &Clauses);
+  StmtResult ParsePragmaTransform(ParsedStmtContext StmtCtx);
+  TransformClauseResult ParseTransformClause(Transform::Kind TransformKind);
+
   /// Describes the behavior that should be taken for an __if_exists
   /// block.
   enum IfExistsBehavior {
Index: clang/include/clang/Basic/TransformKinds.def
===================================================================
--- /dev/null
+++ clang/include/clang/Basic/TransformKinds.def
@@ -0,0 +1,18 @@
+
+#ifndef TRANSFORM_DIRECTIVE
+#  define TRANSFORM_DIRECTIVE(Name)
+#endif
+#ifndef TRANSFORM_DIRECTIVE_LAST
+#  define TRANSFORM_DIRECTIVE_LAST(Name) TRANSFORM_DIRECTIVE(Name)
+#endif
+
+// Loop transformations accessible through "#pragma clang transform".
+TRANSFORM_DIRECTIVE(LoopUnroll)
+TRANSFORM_DIRECTIVE(LoopUnrollAndJam)
+TRANSFORM_DIRECTIVE(LoopDistribution)
+TRANSFORM_DIRECTIVE(LoopVectorization)
+TRANSFORM_DIRECTIVE_LAST(LoopInterleaving)
+
+
+#undef TRANSFORM_DIRECTIVE
+#undef TRANSFORM_DIRECTIVE_LAST
Index: clang/include/clang/Basic/Transform.h
===================================================================
--- /dev/null
+++ clang/include/clang/Basic/Transform.h
@@ -0,0 +1,388 @@
+//===--- Transform.h - Code transformation classes --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines classes used for code transformations such as
+//  #pragma clang transform ...
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_BASIC_TRANSFORM_H
+#define LLVM_CLANG_BASIC_TRANSFORM_H
+
+#include "clang/AST/Stmt.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace clang {
+
+class Transform {
+public:
+  enum Kind {
+    UnknownKind,
+#define TRANSFORM_DIRECTIVE(Name) Name##Kind,
+#define TRANSFORM_DIRECTIVE_LAST(Name)                                         \
+  TRANSFORM_DIRECTIVE(Name)                                                    \
+  LastKind = Name##Kind
+#include "TransformKinds.def"
+  };
+
+  static Kind getTransformDirectiveKind(llvm::StringRef Str);
+  static llvm::StringRef getTransformDirectiveKeyword(Kind K);
+
+private:
+  Kind TransformKind;
+  SourceRange LocRange;
+
+protected:
+  Transform(Kind K, SourceRange LocRange)
+      : TransformKind(K), LocRange(LocRange) {}
+
+public:
+  virtual ~Transform() {}
+
+  Kind getKind() const { return TransformKind; }
+  static bool classof(const Transform *Trans) { return true; }
+
+  /// Source location of the code transformation directive.
+  /// @{
+  SourceRange getRange() const { return LocRange; }
+  SourceLocation getBeginLoc() const { return LocRange.getBegin(); }
+  SourceLocation getEndLoc() const { return LocRange.getEnd(); }
+  void setRange(SourceRange L) { LocRange = L; }
+  void setRange(SourceLocation BeginLoc, SourceLocation EndLoc) {
+    LocRange = SourceRange(BeginLoc, EndLoc);
+  }
+  /// @}
+
+  /// Each transformation defines how many loops it consumes and generates.
+  /// Users of this class can store arrays holding the information regarding the
+  /// loops, such as pointer to the AST node or the loop name. The index in this
+  /// array is its "role".
+  /// @{
+  virtual int getNumInputs() const { return 1; }
+  virtual int getNumFollowups() const { return 0; }
+  /// @}
+
+  /// A meta role may apply to multiple output loops, its attributes are added
+  /// to each of them. A typical example is the 'all' followup which applies to
+  /// all loops emitted by a transformation. The "all" follow-up role is a meta
+  /// output whose attributes are added to all generated loops.
+  bool isMetaRole(int R) const { return R == 0; }
+
+  /// Used to warn users that the current LLVM pass pipeline cannot apply
+  /// arbitrary transformation orders yet.
+  int getLoopPipelineStage() const;
+};
+
+/// Partially or fully unroll a loop.
+///
+/// A full unroll transforms a loop such as
+///
+///     for (int i = 0; i < 2; i+=1)
+///       Stmt(i);
+///
+/// into
+///
+///     {
+///       Stmt(0);
+///       Stmt(1);
+///     }
+///
+/// Partial unrolling can also be applied when the loop trip count is only known
+/// at runtime. For instance, partial unrolling by a factor of 2 transforms
+///
+///     for (int i = 0; i < N; i+=1)
+///       Stmt(i);
+///
+/// into
+///
+///     int i = 0;
+///     for (; i < N; i+=2) { // unrolled
+///       Stmt(i);
+///       Stmt(i+1);
+///     }
+///     for (; i < N; i+=1) // epilogue/remainder
+///       Stmt(i);
+///
+/// LLVM's LoopUnroll pass uses the name runtime unrolling if N is not a
+/// constant.
+///
+/// When using heuristic unrolling, the optimizer decides itself whether to
+/// unroll fully or partially. Because the front-end does not know what the
+/// optimizer will do, there is no followup loop. Note that this is different to
+/// partial unrolling with an undefined factor, which always has followup
+/// loops, although they may not be executed.
+class LoopUnrollTransform final : public Transform {
+private:
+  int64_t Factor; // >= 2 explicit, -1 optimizer-chosen, -2 full, -3 heuristic.
+
+  LoopUnrollTransform(SourceRange Loc, int64_t Factor)
+      : Transform(LoopUnrollKind, Loc), Factor(Factor) {
+    assert(Factor >= 2 || (Factor >= -3 && Factor <= -1)); // incl. sentinels
+  }
+
+public:
+  static bool classof(const LoopUnrollTransform *Trans) { return true; }
+  static bool classof(const Transform *Trans) {
+    return Trans->getKind() == LoopUnrollKind;
+  }
+
+  /// Create an instance of partial unrolling. The unroll factor must be at
+  /// least 2 or -1. When -1, the unroll factor can be chosen by the optimizer.
+  /// An unroll factor of 0 or 1 is not valid.
+  static LoopUnrollTransform *createPartial(SourceRange Loc,
+                                            int64_t Factor = -1) {
+    assert(Factor >= 2 || Factor == -1);
+    LoopUnrollTransform *Instance = new LoopUnrollTransform(Loc, Factor);
+    assert(Instance->isPartial());
+    return Instance;
+  }
+
+  static LoopUnrollTransform *createFull(SourceRange Loc) {
+    LoopUnrollTransform *Instance = new LoopUnrollTransform(Loc, -2);
+    assert(Instance->isFull());
+    return Instance;
+  }
+
+  static LoopUnrollTransform *createHeuristic(SourceRange Loc) {
+    LoopUnrollTransform *Instance = new LoopUnrollTransform(Loc, -3);
+    assert(Instance->isHeuristic());
+    return Instance;
+  }
+
+  bool isPartial() const { return Factor >= 2 || Factor == -1; }
+  bool isFull() const { return Factor == -2; }
+  bool isHeuristic() const { return Factor == -3; }
+
+  enum Input { InputToUnroll };
+  int getNumInputs() const override { return 1; }
+
+  enum Followup {
+    FollowupAll,
+    FollowupUnrolled, // only for partial unrolling
+    FollowupRemainder // only for partial unrolling
+  };
+  int getNumFollowups() const override {
+    if (isPartial())
+      return 3;
+    return 0;
+  }
+
+  int64_t getFactor() const { return Factor; }
+};
+
+/// Apply partial unroll-and-jam to a loop.
+///
+/// That is, with a unroll factor of 2, transform
+///
+///     for (int i = 0; i < N; i+=1)
+///        for (int j = 0; j < M; j+=1)
+///          Stmt(i,j);
+///
+/// into
+///
+///     int i = 0;
+///     for (; i < N; i+=2)               // outer
+///        for (int j = 0; j < M; j+=1) { // inner
+///          Stmt(i,j);
+///          Stmt(i+1,j);
+///        }
+///     for (; i < N; i+=1)               // remainder/epilogue
+///        for (int j = 0; j < M; j+=1)
+///          Stmt(i,j);
+///
+/// Note that LLVM's LoopUnrollAndJam pass does not support full unroll.
+class LoopUnrollAndJamTransform final : public Transform {
+private:
+  int64_t Factor;
+
+  LoopUnrollAndJamTransform(SourceRange Loc, int64_t Factor)
+      : Transform(LoopUnrollAndJamKind, Loc), Factor(Factor) {}
+
+public:
+  static bool classof(const LoopUnrollAndJamTransform *Trans) { return true; }
+  static bool classof(const Transform *Trans) {
+    return Trans->getKind() == LoopUnrollAndJamKind;
+  }
+
+  /// Create an instance of unroll-and-jam. The unroll factor must be at least 2
+  /// or -1. When -1, the unroll factor can be chosen by the optimizer. An
+  /// unroll factor of 0 or 1 is not valid.
+  static LoopUnrollAndJamTransform *createPartial(SourceRange Loc,
+                                                  int64_t Factor = -1) {
+    assert(Factor >= 2 || Factor == -1);
+    LoopUnrollAndJamTransform *Instance =
+        new LoopUnrollAndJamTransform(Loc, Factor);
+    assert(Instance->isPartial());
+    return Instance;
+  }
+
+  static LoopUnrollAndJamTransform *createHeuristic(SourceRange Loc) {
+    LoopUnrollAndJamTransform *Instance =
+        new LoopUnrollAndJamTransform(Loc, -3);
+    assert(Instance->isHeuristic());
+    return Instance;
+  }
+
+  bool isPartial() const { return Factor >= 2 || Factor == -1; }
+  bool isHeuristic() const { return Factor == -3; }
+
+  enum Input { InputOuter, InputInner };
+  int getNumInputs() const override { return 2; }
+
+  enum Followup { FollowupAll, FollowupOuter, FollowupInner };
+  int getNumFollowups() const override {
+    if (isPartial())
+      return 3;
+    return 0;
+  }
+
+  int64_t getFactor() const { return Factor; }
+};
+
+/// Apply loop distribution (aka fission) to a loop.
+///
+/// For example, transform the loop
+///
+///     for (int i = 0; i < N; i+=1) {
+///       StmtA(i);
+///       StmtB(i);
+///     }
+///
+/// into
+///
+///     for (int i = 0; i < N; i+=1)
+///       StmtA(i);
+///     for (int i = 0; i < N; i+=1)
+///       StmtB(i);
+///
+/// LLVM's LoopDistribute pass does not allow to control how the loop is
+/// distributed. Hence, there are no non-meta followups.
+class LoopDistributionTransform final : public Transform {
+private:
+  LoopDistributionTransform(SourceRange Loc)
+      : Transform(LoopDistributionKind, Loc) {}
+
+public:
+  static bool classof(const LoopDistributionTransform *Trans) { return true; }
+  static bool classof(const Transform *Trans) {
+    return Trans->getKind() == LoopDistributionKind;
+  }
+
+  static LoopDistributionTransform *create(SourceRange Loc) {
+    return new LoopDistributionTransform(Loc);
+  }
+
+  enum Input { InputToDistribute };
+  int getNumInputs() const override { return 1; }
+
+  enum Followup { FollowupAll };
+  int getNumFollowups() const override { return 1; }
+};
+
+/// Vectorize a loop by executing multiple loop iterations at the same time in
+/// vector lanes.
+///
+/// For example, transform
+///
+///     for (int i = 0; i < N; i+=1)
+///       Stmt(i);
+///
+/// into
+///
+///     int i = 0;
+///     for (; i < N; i+=2) // vectorized
+///       Stmt(i:i+1);
+///     for (; i < N; i+=1) // epilogue/remainder
+///       Stmt(i);
+class LoopVectorizationTransform final : public Transform {
+private:
+  int64_t VectorizeWidth; // >= 2, or -1 (the default passed by create()).
+
+  LoopVectorizationTransform(SourceRange Loc, int64_t VectorizeWidth)
+      : Transform(LoopVectorizationKind, Loc), VectorizeWidth(VectorizeWidth) {
+    assert(VectorizeWidth >= 2 || VectorizeWidth == -1); // same as create()
+  }
+
+public:
+  static bool classof(const LoopVectorizationTransform *Trans) { return true; }
+  static bool classof(const Transform *Trans) {
+    return Trans->getKind() == LoopVectorizationKind;
+  }
+
+  static LoopVectorizationTransform *create(SourceRange Loc,
+                                            int64_t VectorizeWidth = -1) {
+    assert(VectorizeWidth >= 2 || VectorizeWidth == -1);
+    return new LoopVectorizationTransform(Loc, VectorizeWidth);
+  }
+
+  enum Input { InputToVectorize };
+  int getNumInputs() const override { return 1; }
+
+  enum Followup { FollowupAll, FollowupVectorized, FollowupEpilogue };
+  int getNumFollowups() const override { return 3; }
+
+  int64_t getWidth() const { return VectorizeWidth; }
+};
+
+/// Execute multiple loop iterations at once by duplicating instructions. This
+/// is different from unrolling in that it copies each instruction n times
+/// instead of the entire loop body as loop unrolling does.
+///
+/// For example, transform
+///
+///     for (int i = 0; i < N; i+=1) {
+///       InstA(i);
+///       InstB(i);
+///       InstC(i);
+///     }
+///
+/// into
+///
+///     int i = 0;
+///     for (; i < N; i+=2) { // interleaved
+///       InstA(i);
+///       InstA(i+1);
+///       InstB(i);
+///       InstB(i+1);
+///       InstC(i);
+///       InstC(i+1);
+///     }
+///     for (; i < N; i+=1) { // epilogue/remainder
+///       InstA(i);
+///       InstB(i);
+///       InstC(i);
+///     }
+class LoopInterleavingTransform final : public Transform {
+private:
+  int64_t Factor;
+
+  LoopInterleavingTransform(SourceRange Loc, int64_t Factor)
+      : Transform(LoopInterleavingKind, Loc), Factor(Factor) {}
+
+public:
+  static bool classof(const LoopInterleavingTransform *Trans) { return true; }
+  static bool classof(const Transform *Trans) {
+    return Trans->getKind() == LoopInterleavingKind;
+  }
+
+  static LoopInterleavingTransform *create(SourceRange Loc, int64_t Factor) {
+    assert(Factor == -1 || Factor >= 2);
+    return new LoopInterleavingTransform(Loc, Factor);
+  }
+
+  enum Input { InputToVectorize };
+  int getNumInputs() const override { return 1; }
+
+  enum Followup { FollowupAll, FollowupInterleaved, FollowupEpilogue };
+  int getNumFollowups() const override { return 3; }
+
+  int64_t getFactor() const { return Factor; }
+};
+
+} // namespace clang
+#endif /* LLVM_CLANG_BASIC_TRANSFORM_H */
Index: clang/include/clang/Basic/DiagnosticParseKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticParseKinds.td
+++ clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1245,6 +1245,18 @@
   "vectorize_width, interleave, interleave_count, unroll, unroll_count, "
   "pipeline, pipeline_initiation_interval, vectorize_predicate, or distribute">;
 
+// Parser diagnostics for '#pragma clang transform'.
+def err_pragma_transform_expected_directive : Error<
+  "expected a transformation name">;
+def err_pragma_transform_unknown_directive : Error<
+  "unknown transformation">;
+def err_pragma_transform_expected_loop : Error<
+  "expected loop after transformation pragma">;
+def err_pragma_transform_expected_clause : Error<
+  "expected a clause name">;
+def err_pragma_transform_unknown_clause : Error<
+  "unknown clause name">;
+
 def err_pragma_fp_invalid_option : Error<
   "%select{invalid|missing}0 option%select{ %1|}0; expected contract">;
 def err_pragma_fp_invalid_argument : Error<
Index: clang/include/clang/Basic/DiagnosticGroups.td
===================================================================
--- clang/include/clang/Basic/DiagnosticGroups.td
+++ clang/include/clang/Basic/DiagnosticGroups.td
@@ -1128,3 +1128,6 @@
 def CTADMaybeUnsupported : DiagGroup<"ctad-maybe-unsupported">;
 
 def FortifySource : DiagGroup<"fortify-source">;
+
+// Warning group for '#pragma clang transform' (-Wpragma-transform).
+def ClangTransform : DiagGroup<"pragma-transform">;
Index: clang/include/clang/AST/TransformClauseKinds.def
===================================================================
--- /dev/null
+++ clang/include/clang/AST/TransformClauseKinds.def
@@ -0,0 +1,16 @@
+// X-macro list of '#pragma clang transform' clauses as (Keyword, Name) pairs.
+#ifndef TRANSFORM_CLAUSE
+#  define TRANSFORM_CLAUSE(Keyword, Name)
+#endif
+#ifndef TRANSFORM_CLAUSE_LAST
+#  define TRANSFORM_CLAUSE_LAST(Keyword, Name)  TRANSFORM_CLAUSE(Keyword, Name)
+#endif
+// Keep TRANSFORM_CLAUSE_LAST on the final entry; includers may key off it.
+TRANSFORM_CLAUSE(full,Full)
+TRANSFORM_CLAUSE(partial,Partial)
+
+TRANSFORM_CLAUSE(width,Width)
+TRANSFORM_CLAUSE_LAST(factor,Factor)
+
+#undef TRANSFORM_CLAUSE
+#undef TRANSFORM_CLAUSE_LAST
Index: clang/include/clang/AST/StmtTransform.h
===================================================================
--- /dev/null
+++ clang/include/clang/AST/StmtTransform.h
@@ -0,0 +1,52 @@
+//===--- StmtTransform.h - Code transformation AST nodes --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  Transformation directive statement and clauses for the AST.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_STMTTRANSFORM_H
+#define LLVM_CLANG_AST_STMTTRANSFORM_H
+
+#include "clang/AST/Stmt.h"
+#include "clang/Basic/Transform.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace clang {
+
+/// Represents a clause of a \c TransformExecutableDirective.
+class TransformClause {
+public:
+  enum Kind {
+    UnknownKind, // NOTE(review): presumably "no such clause" -- confirm.
+#define TRANSFORM_CLAUSE(Keyword, Name) Name##Kind,
+#define TRANSFORM_CLAUSE_LAST(Keyword, Name) Name##Kind, LastKind = Name##Kind
+#include "clang/AST/TransformClauseKinds.def" // Expands to the enumerators.
+  };
+
+  static bool isValidForTransform(Transform::Kind TransformKind,
+                                  TransformClause::Kind ClauseKind);
+  static Kind getClauseKind(Transform::Kind TransformKind, llvm::StringRef Str);
+  static llvm::StringRef getClauseKeyword(TransformClause::Kind ClauseKind);
+
+  // TODO: implement
+};
+
+/// Represents
+///
+///   #pragma clang transform
+///
+/// in the AST.
+class TransformExecutableDirective final {
+  // TODO: implement
+};
+
+const Stmt *getAssociatedLoop(const Stmt *S);
+} // namespace clang
+
+#endif /* LLVM_CLANG_AST_STMTTRANSFORM_H */
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to