llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-backend-webassembly

Author: Hood Chatham (hoodmane)

<details>
<summary>Changes</summary>

This is a setjmp-like API that catches JavaScript errors. Its signature is 
`__externref_t __builtin_wasm_js_catch(int *status);`. The first time it returns, 
the return value is a null externref and `*status` is set to 0. If a function called later 
raises a JavaScript error, control jumps back to the most recently 
executed call to `__builtin_wasm_js_catch()`, which this time sets `*status` to 
1 and returns the caught JS error.

This is generally useful for handling errors thrown from `EM_JS` functions, 
although JS try/catch can already handle those. More importantly, it is 
necessary for catching a SuspendError: that error is generated internally by 
the runtime's import wrapper, so there is no way to catch the SuspendError in 
JavaScript.

The implementation is adapted from the Wasm SjLj implementation. Where code 
could be shared directly, I factored it out of the SjLj implementation. I did 
not add support for JS-based (Emscripten) EH because JSPI does not work with 
JS-based EH anyway.

Example:
<details><summary>Details</summary>
<p>


```C
int main(void) {
    int res1 = -7;
    int res2 = -7;
    // jsfunc1();

    __externref_t caught = __builtin_wasm_js_catch(&amp;res1);
    printf("catch1 res1: %d, res2: %d\n", res1, res2);
    if (res1 == 0) {
      printf("calling jsfunc1\n");
      jsfunc1();
    } else {
      printf("calling handle_error (1)\n");
      handle_error(caught);
    }
    __externref_t caught2 = __builtin_wasm_js_catch(&amp;res2);
    printf("catch2 res1: %d, res2: %d\n", res1, res2);
    if (res2 == 0) {
      printf("calling jsfunc2\n");
      jsfunc2();
    } else {
      printf("calling handle_error (2)\n");
      handle_error(caught2);
    }
    return 0;
}

EM_JS(int, handle_error, (__externref_t err), {
    console.log(err);
    return 8;
});

EM_JS(void, jsfunc1, (), {
    throw new Error("jsfunc1");
})

EM_JS(void, jsfunc2, (), {
    throw new Error("jsfunc2");
})
```

Output:

```
catch1 res1: 0, res2: -7
calling jsfunc1
catch1 res1: 1, res2: -7
calling handle_error (1)
Error: jsfunc1
    at <... JS stack trace>
catch2 res1: 1, res2: 0
calling jsfunc2
catch2 res1: 1, res2: 1
calling handle_error (2)
Error: jsfunc2
    at <... JS stack trace>
```

</p>
</details>


---

Patch is 38.27 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/153767.diff


12 Files Affected:

- (modified) clang/include/clang/Basic/BuiltinsWebAssembly.def (+2) 
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2) 
- (modified) clang/include/clang/Sema/SemaWasm.h (+1) 
- (modified) clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp (+14) 
- (modified) clang/lib/Sema/SemaWasm.cpp (+20) 
- (modified) clang/test/Sema/builtins-wasm.c (+16) 
- (modified) llvm/include/llvm/CodeGen/WasmEHFuncInfo.h (+1-1) 
- (modified) llvm/include/llvm/IR/IntrinsicsWebAssembly.td (+1) 
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp (+15-3) 
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp (+34-4) 
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp 
(+426-16) 
- (added) llvm/test/CodeGen/WebAssembly/lower-js-catch.ll (+87) 


``````````diff
diff --git a/clang/include/clang/Basic/BuiltinsWebAssembly.def 
b/clang/include/clang/Basic/BuiltinsWebAssembly.def
index d31b72696ff4e..44b935619e01b 100644
--- a/clang/include/clang/Basic/BuiltinsWebAssembly.def
+++ b/clang/include/clang/Basic/BuiltinsWebAssembly.def
@@ -205,6 +205,8 @@ TARGET_BUILTIN(__builtin_wasm_ref_null_func, "i", "nct", 
"reference-types")
 // ref.test to test the type.
 TARGET_BUILTIN(__builtin_wasm_test_function_pointer_signature, "i.", "nct", 
"gc")
 
+// Not "c" (const): it stores to *status and can return twice ("j").
+TARGET_BUILTIN(__builtin_wasm_js_catch, "ii*", "ntj", "exception-handling")
+
 // Table builtins
 TARGET_BUILTIN(__builtin_wasm_table_set,  "viii", "t", "reference-types")
 TARGET_BUILTIN(__builtin_wasm_table_get,  "iii", "t", "reference-types")
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td 
b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index a7f3d37823075..59b584f02d530 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -13241,6 +13241,8 @@ def 
err_wasm_builtin_test_fp_sig_cannot_include_struct_or_union
     : Error<"not supported with the multivalue ABI for "
             "function pointers with a struct/union as %select{return "
             "value|parameter}0">;
+def err_wasm_builtin_arg_must_be_pointer_to_integer_type
+    : Error<"%ordinal0 argument must be a pointer to an integer">;
 
 // OpenACC diagnostics.
 def warn_acc_routine_unimplemented
diff --git a/clang/include/clang/Sema/SemaWasm.h 
b/clang/include/clang/Sema/SemaWasm.h
index f82590755d183..8f3d456c964ba 100644
--- a/clang/include/clang/Sema/SemaWasm.h
+++ b/clang/include/clang/Sema/SemaWasm.h
@@ -39,6 +39,7 @@ class SemaWasm : public SemaBase {
   bool BuiltinWasmTableCopy(CallExpr *TheCall);
   bool BuiltinWasmTestFunctionPointerSignature(const TargetInfo &TI,
                                                CallExpr *TheCall);
+  bool BuiltinWasmJsCatch(CallExpr *TheCall);
 
   WebAssemblyImportNameAttr *
   mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL);
diff --git a/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp 
b/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp
index 1a1889a4139d3..1d16818441a69 100644
--- a/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp
@@ -270,6 +270,20 @@ Value 
*CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
     Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_test_func);
     return Builder.CreateCall(Callee, Args);
   }
+  case WebAssembly::BI__builtin_wasm_js_catch: {
+    // We don't do any real lowering here, just make a new function also called
+    // __builtin_wasm_js_catch() and give it the same arguments. We'll lower
+    // it in WebAssemblyLowerEmscriptenEHSjLj.cpp
+    auto *Type = llvm::FunctionType::get(
+        llvm::PointerType::get(getLLVMContext(), 10),
+        {llvm::PointerType::get(getLLVMContext(), 0)}, false);
+    auto Attrs =
+        AttributeList().addFnAttribute(getLLVMContext(), "returns_twice");
+    FunctionCallee Callee = CGM.getModule().getOrInsertFunction(
+        "__builtin_wasm_js_catch", Type, Attrs);
+    Value *Status = EmitScalarExpr(E->getArg(0));
+    return Builder.CreateCall(Callee, {Status});
+  }
   case WebAssembly::BI__builtin_wasm_swizzle_i8x16: {
     Value *Src = EmitScalarExpr(E->getArg(0));
     Value *Indices = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Sema/SemaWasm.cpp b/clang/lib/Sema/SemaWasm.cpp
index e7731136720e8..b2804ad1a13d0 100644
--- a/clang/lib/Sema/SemaWasm.cpp
+++ b/clang/lib/Sema/SemaWasm.cpp
@@ -280,6 +280,24 @@ bool 
SemaWasm::BuiltinWasmTestFunctionPointerSignature(const TargetInfo &TI,
   return false;
 }
 
+bool SemaWasm::BuiltinWasmJsCatch(CallExpr *TheCall) {
+  if (SemaRef.checkArgCount(TheCall, 1))
+    return true;
+
+  // The lowering stores an i32 0/1 here: require a non-const int-sized integer.
+  Expr *PtrArg = TheCall->getArg(0);
+  ASTContext &Ctx = getASTContext();
+  const PointerType *PtrTy = PtrArg->getType()->getAs<PointerType>();
+  QualType Pointee = PtrTy ? PtrTy->getPointeeType() : QualType();
+  if (!PtrTy || !Pointee->isIntegerType() || Pointee.isConstQualified() ||
+      Ctx.getTypeSize(Pointee) != Ctx.getTypeSize(Ctx.IntTy))
+    return Diag(PtrArg->getBeginLoc(),
+                diag::err_wasm_builtin_arg_must_be_pointer_to_integer_type)
+           << 1 << PtrArg->getSourceRange();
+  TheCall->setType(Ctx.getWebAssemblyExternrefType());
+  return false;
+}
+
 bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
                                                    unsigned BuiltinID,
                                                    CallExpr *TheCall) {
@@ -304,6 +322,8 @@ bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const 
TargetInfo &TI,
     return BuiltinWasmTableCopy(TheCall);
   case WebAssembly::BI__builtin_wasm_test_function_pointer_signature:
     return BuiltinWasmTestFunctionPointerSignature(TI, TheCall);
+  case WebAssembly::BI__builtin_wasm_js_catch:
+    return BuiltinWasmJsCatch(TheCall);
   }
 
   return false;
diff --git a/clang/test/Sema/builtins-wasm.c b/clang/test/Sema/builtins-wasm.c
index 9075e9eaa5230..87cfb72170ebe 100644
--- a/clang/test/Sema/builtins-wasm.c
+++ b/clang/test/Sema/builtins-wasm.c
@@ -87,3 +87,19 @@ void test_function_pointer_signature() {
   (void)__builtin_wasm_test_function_pointer_signature((F4)0);
 #endif
 }
+
+void test_js_catch() {
+  // Test argument count validation
+  __builtin_wasm_js_catch(); // expected-error {{too few arguments to function 
call, expected 1, have 0}}
+  __builtin_wasm_js_catch(0, 0); // expected-error {{too many arguments to 
function call, expected 1, have 2}}
+
+  // Test argument type validation - should require pointer to int
+  __builtin_wasm_js_catch((void*)0); // expected-error {{1st argument must be 
a pointer to an integer}}
+  __builtin_wasm_js_catch(((int)0));   // expected-error {{1st argument must 
be a pointer to an integer}}
+
+  int res;
+  __externref_t exception = __builtin_wasm_js_catch(&res);
+
+  // Test return type
+  _Static_assert(EXPR_HAS_TYPE(__builtin_wasm_js_catch(&res), __externref_t), 
"");
+}
diff --git a/llvm/include/llvm/CodeGen/WasmEHFuncInfo.h 
b/llvm/include/llvm/CodeGen/WasmEHFuncInfo.h
index ab6b897e9f999..e8cc8ca79fc1b 100644
--- a/llvm/include/llvm/CodeGen/WasmEHFuncInfo.h
+++ b/llvm/include/llvm/CodeGen/WasmEHFuncInfo.h
@@ -24,7 +24,7 @@ class Function;
 class MachineBasicBlock;
 
 namespace WebAssembly {
-enum Tag { CPP_EXCEPTION = 0, C_LONGJMP = 1 };
+enum Tag { CPP_EXCEPTION = 0, C_LONGJMP = 1, JS_EXCEPTION = 2 };
 }
 
 using BBOrMBB = PointerUnion<const BasicBlock *, MachineBasicBlock *>;
diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td 
b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
index c1e4b97e96bc8..50b80d6c6fb69 100644
--- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
+++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
@@ -147,6 +147,7 @@ def int_wasm_get_ehselector :
 def int_wasm_catch :
   DefaultAttrsIntrinsic<[llvm_ptr_ty], [llvm_i32_ty],
                         [IntrHasSideEffects, ImmArg<ArgIndex<0>>]>;
+def int_wasm_catch_js : DefaultAttrsIntrinsic<[llvm_externref_ty], []>;
 
 // WebAssembly EH must maintain the landingpads in the order assigned to them
 // by WasmEHPrepare pass to generate landingpad table in EHStreamer. This is
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp 
b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
index db832bc91ddb5..06133c7ee7a74 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
@@ -247,7 +247,12 @@ MCSymbol 
*WebAssemblyAsmPrinter::getOrCreateWasmSymbol(StringRef Name) {
 
   SmallVector<wasm::ValType, 4> Returns;
   SmallVector<wasm::ValType, 4> Params;
-  if (Name == "__cpp_exception" || Name == "__c_longjmp") {
+  if (Name == "__js_exception") {
+    WasmSym->setType(wasm::WASM_SYMBOL_TYPE_TAG);
+    WasmSym->setImportModule("env");
+    WasmSym->setImportName("__js_exception");
+    Params.push_back(wasm::ValType::EXTERNREF);
+  } else if (Name == "__cpp_exception" || Name == "__c_longjmp") {
     WasmSym->setType(wasm::WASM_SYMBOL_TYPE_TAG);
     // In static linking we define tag symbols in WasmException::endModule().
     // But we may have multiple objects to be linked together, each of which
@@ -321,9 +326,16 @@ void WebAssemblyAsmPrinter::emitDecls(const Module &M) {
     // Emit .globaltype, .tagtype, or .tabletype declarations for extern
     // declarations, i.e. those that have only been declared (but not defined)
     // in the current module
-    auto Sym = static_cast<MCSymbolWasm *>(It.getValue().Symbol);
-    if (Sym && !Sym->isDefined())
+    auto *Sym = static_cast<MCSymbolWasm *>(It.getValue().Symbol);
+    if (Sym && !Sym->isDefined()) {
       emitSymbolType(Sym);
+      if (Sym->hasImportModule()) {
+        getTargetStreamer()->emitImportModule(Sym, Sym->getImportModule());
+      }
+      if (Sym->hasImportName()) {
+        getTargetStreamer()->emitImportName(Sym, Sym->getImportName());
+      }
+    }
   }
 
   DenseSet<MCSymbol *> InvokeSymbols;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp 
b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
index fc852d0a12e14..250dc0762b275 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
@@ -110,13 +110,25 @@ void WebAssemblyDAGToDAGISel::PreprocessISelDAG() {
 }
 
 static SDValue getTagSymNode(int Tag, SelectionDAG *DAG) {
-  assert(Tag == WebAssembly::CPP_EXCEPTION || WebAssembly::C_LONGJMP);
+  assert(Tag == WebAssembly::CPP_EXCEPTION || Tag == WebAssembly::C_LONGJMP ||
+         Tag == WebAssembly::JS_EXCEPTION);
   auto &MF = DAG->getMachineFunction();
   const auto &TLI = DAG->getTargetLoweringInfo();
   MVT PtrVT = TLI.getPointerTy(DAG->getDataLayout());
-  const char *SymName = Tag == WebAssembly::CPP_EXCEPTION
-                            ? MF.createExternalSymbolName("__cpp_exception")
-                            : MF.createExternalSymbolName("__c_longjmp");
+  const char *SymName;
+  switch (Tag) {
+  case WebAssembly::CPP_EXCEPTION:
+    SymName = MF.createExternalSymbolName("__cpp_exception");
+    break;
+  case WebAssembly::C_LONGJMP:
+    SymName = MF.createExternalSymbolName("__c_longjmp");
+    break;
+  case WebAssembly::JS_EXCEPTION:
+    SymName = MF.createExternalSymbolName("__js_exception");
+    break;
+  default:
+    llvm_unreachable("Should not happen");
+  };
   return DAG->getTargetExternalSymbol(SymName, PtrVT);
 }
 
@@ -334,6 +346,24 @@ void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
       ReplaceNode(Node, Catch);
       return;
     }
+    case Intrinsic::wasm_catch_js: {
+      unsigned CatchOpcode = WebAssembly::WasmUseLegacyEH
+                                 ? WebAssembly::CATCH_LEGACY
+                                 : WebAssembly::CATCH;
+      SDValue SymNode = getTagSymNode(WebAssembly::JS_EXCEPTION, CurDAG);
+      MachineSDNode *Catch =
+          CurDAG->getMachineNode(CatchOpcode, DL,
+                                 {
+                                     MVT::externref, // exception pointer
+                                     MVT::Other      // outchain type
+                                 },
+                                 {
+                                     SymNode,             // exception symbol
+                                     Node->getOperand(0), // inchain
+                                 });
+      ReplaceNode(Node, Catch);
+      return;
+    }
     }
     break;
   }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp 
b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
index c3990d12f2c28..36ed9746b7c59 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
@@ -259,6 +259,75 @@
 ///     __wasm_longjmp(%env, %val)
 ///   catchret to %setjmp.dispatch
 ///
+/// * Wasm __builtin_wasm_js_catch handling
+/// Because JS exceptions don't have a specific target, we always transfer
+/// control to the most recently executed __builtin_wasm_js_catch() call and
+/// set its int* status argument to 1 when we do this transfer. On the first
+/// (normal) return, the status is 0 and the return value is ref.null.extern.
+///
+/// This is mostly similar to Wasm setjmp/longjmp, with these differences:
+///
+/// 1) There is no longjmp equivalent.
+/// 2) We don't need functionInvocationId, but we do set up two stack variables,
+/// DispatchTarget and DispatchArgument, plus the ref.null.extern value used as
+/// the first return value of __builtin_wasm_js_catch(). DispatchTarget stores
+/// the label index of the most recently executed __builtin_wasm_js_catch(), or
+/// 0 if none has been executed yet. DispatchArgument stores the status pointer
+/// passed to the most recently executed __builtin_wasm_js_catch().
+///
+/// 3) Lower
+///      __externref_t error = __builtin_wasm_js_catch(&result)
+///    into
+///      result = 0
+///      error = ref.null
+///      DispatchTarget = (index of js_catch call)
+///      DispatchArgument = &result
+///
+/// 4) Create a catchpad with a wasm.catch.js() intrinsic, which returns an
+/// externref containing the JavaScript exception. This is directly used as the
+/// error return value of __builtin_wasm_js_catch().
+///
+/// All function calls that can throw will be converted to invokes that
+/// unwind to this catchpad when a JS exception is thrown.
+/// If DispatchTarget is still 0 because we haven't executed a
+/// __builtin_wasm_js_catch() yet, then we rethrow the error. Otherwise, we
+/// jump to the beginning of the function, which contains a switch to each
+/// post-jscatch BB.
+///
+/// The below is the pseudocode for what we have described
+///
+/// entry:
+///   Initialize *DispatchTarget = 0, DispatchArgument, %nullref
+///
+/// setjmp.dispatch:
+///    switch %label {
+///      label 1: goto post-jscatch BB 1
+///      label 2: goto post-jscatch BB 2
+///      ...
+///      default: goto split entry BB
+///    }
+/// ...
+///
+///    ;; error = __builtin_wasm_js_catch(&result) expands to
+///    *DispatchTarget = CurrentIndex
+///    *DispatchArgument = &result
+///    error = %nullref
+///    result = 0
+///
+/// bb:
+///   invoke void @foo() ;; foo is a function which can raise a js error
+///     to label %next unwind label %catch.dispatch.jserror
+/// ...
+///
+/// catch.dispatch.jserror:
+///   %0 = catchswitch within none [label %catch.jserror] unwind to caller
+///
+/// catch.jserror:
+///   %jserror = wasm.catch.js()
+///   %label = *DispatchTarget
+///   if (%label == 0)
+///     __builtin_wasm_rethrow(%jserror)
+///   catchret to %setjmp.dispatch
 
///===----------------------------------------------------------------------===//
 
 #include "WebAssembly.h"
@@ -307,6 +376,9 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public 
ModulePass {
   Function *WasmSetjmpTestF = nullptr;    // __wasm_setjmp_test() (Emscripten)
   Function *WasmLongjmpF = nullptr;       // __wasm_longjmp() (Emscripten)
   Function *CatchF = nullptr;             // wasm.catch() (intrinsic)
+  Function *CatchJsF = nullptr;           // wasm.catch.js() (intrinsic)
+  Function *NullExternF = nullptr;        // wasm.ref.null.extern() (intrinsic)
+  Function *WasmRethrowF = nullptr;       // wasm.rethrow() (intrinsic)
 
   // type of 'struct __WasmLongjmpArgs' defined in emscripten
   Type *LongjmpArgsTy = nullptr;
@@ -318,8 +390,6 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public 
ModulePass {
   StringMap<Function *> InvokeWrappers;
   // Set of allowed function names for exception handling
   std::set<std::string, std::less<>> EHAllowlistSet;
-  // Functions that contains calls to setjmp
-  SmallPtrSet<Function *, 8> SetjmpUsers;
 
   StringRef getPassName() const override {
     return "WebAssembly Lower Emscripten Exceptions";
@@ -328,6 +398,7 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public 
ModulePass {
   using InstVector = SmallVectorImpl<Instruction *>;
   bool runEHOnFunction(Function &F);
   bool runSjLjOnFunction(Function &F);
+  bool runJsCatchOnFunction(Function &F);
   void handleLongjmpableCallsForEmscriptenSjLj(
       Function &F, Instruction *FunctionInvocationId,
       SmallVectorImpl<PHINode *> &SetjmpRetPHIs);
@@ -335,6 +406,15 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public 
ModulePass {
   handleLongjmpableCallsForWasmSjLj(Function &F,
                                     Instruction *FunctionInvocationId,
                                     SmallVectorImpl<PHINode *> &SetjmpRetPHIs);
+  void handleThrowingCallsForJsCatch(Function &F,
+                                     SmallVectorImpl<PHINode *> 
&JsCatchRetPHIs,
+                                     SmallVectorImpl<CallInst *> &CIs,
+                                     Instruction *DispatchTarget,
+                                     Instruction *DispatchArgument);
+  void convertLongJumpableCallsToInvokes(
+      IRBuilder<> &IRB, Function &F,
+      SmallVectorImpl<CallInst *> &LongjmpableCalls,
+      BasicBlock *CatchDispatchLongjmpBB, CatchSwitchInst *CatchSwitchLongjmp);
   Function *getFindMatchingCatch(Module &M, unsigned NumClauses);
 
   Value *wrapInvoke(CallBase *CI);
@@ -902,6 +982,35 @@ static void nullifySetjmp(Function *F) {
     I->eraseFromParent();
 }
 
+/// Lower each __builtin_wasm_js_catch() call in F (which has no calls that can
+/// throw a JS error) to its first return: *status = 0, ref.null.extern result.
+static void nullifyJsCatch(Function *F, Function *NullExternF) {
+  Module &M = *F->getParent();
+  IRBuilder<> IRB(M.getContext());
+  Function *JsCatchF = M.getFunction("__builtin_wasm_js_catch");
+  SmallVector<Instruction *, 1> ToErase;
+
+  for (User *U : make_early_inc_range(JsCatchF->users())) {
+    auto *CB = cast<CallBase>(U);
+    if (CB->getFunction() != F) // in other function
+      continue;
+    // __builtin_wasm_js_catch() cannot throw; lower an invoke to a call.
+    CallInst *CI = nullptr;
+    if (auto *II = dyn_cast<InvokeInst>(CB))
+      CI = llvm::changeToCall(II);
+    else
+      CI = cast<CallInst>(CB);
+    ToErase.push_back(CI);
+    // Insert right before CI: it may be the first instruction in its block.
+    // The status pointer is argument 0 (operand 1 would be the callee).
+    IRB.SetInsertPoint(CI);
+    IRB.CreateStore(IRB.getInt32(0), CI->getArgOperand(0));
+    CI->replaceAllUsesWith(IRB.CreateCall(NullExternF, {}));
+  }
+  for (auto *I : ToErase)
+    I->eraseFromParent();
+}
+
 bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
   LLVM_DEBUG(dbgs() << "********** Lower Emscripten EH & SjLj **********\n");
 
@@ -910,6 +1019,7 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module 
&M) {
 
   Function *SetjmpF = M.getFunction("setjmp");
   Function *LongjmpF = M.getFunction("longjmp");
+  Function *JsCatchF = M.getFunction("__builtin_wasm_js_catch");
 
   // In some platforms _setjmp and _longjmp are used instead. Change these to
   // use setjmp/longjmp instead, because we later detect these functions by
@@ -971,9 +1081,14 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module 
&M) {
     EHTypeIDF = getFunction(EHTypeIDTy, "llvm_eh_typeid_for", &M);
   }
 
-  // Functions that contains calls to setjmp but don't have other longjmpable
+  // Functions that contain calls to setjmp/jscatch
+  SmallPtrSet<Function *, 8> SetjmpUsers;
+  SmallPtrSet<Function *, 8> JsCatchUsers;
+
+  // Functions that contain calls to setjmp/jscatch but don't have longjmpable
   // calls within them.
   SmallPtrSet<Function *, 4> SetjmpUsersToNullify;
+  SmallPtrSet<Function *, 4> JsCatchUsersToNullify;
 
   if ((EnableEmSjLj || EnableWasmSjLj) && SetjmpF) {
     // Precompute setjmp users
@@ -997,9 +1112,40 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module 
&M) {
     }
   }
 
+  if ((EnableEmSjLj || EnableWasmSjLj) && JsCatchF) {
+    for (User *U : JsCatchF->users()) {
+      if (auto *CB = dyn_cast<CallBase>(U)) {
+        auto *UserF = CB->getFunction();
+        if (SetjmpUsers.contains(UserF)) {
+          report_fatal_error("Cannot use JsCatch and setjmp in same function");
+        }
+        // If a function that calls js_catch does not contain any other calls
+        // that can throw JS errors, it does not need the full transformation;
+        // its js_catch calls are simply nullified instead.
+        if (containsLongjmpableCalls(UserF))
+          JsCatchUsers.insert(UserF);
+        else
+          JsCatchUsersToNullify.insert(UserF);
+      } else {
+        std::string S;
+        raw_string_ostream SS(S);
+        SS << *U;
+        report_fatal_error(
+            Twine(
+                ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/153767
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to