crepererum commented on code in PR #12086:
URL: https://github.com/apache/datafusion/pull/12086#discussion_r1726597939
##########
datafusion/common-runtime/src/common.rs:
##########
@@ -60,18 +60,67 @@ impl<R: 'static> SpawnedTask<R> {
     }
     /// Joins the task and unwinds the panic if it happens.
-    pub async fn join_unwind(self) -> R {
-        self.join().await.unwrap_or_else(|e| {
+    pub async fn join_unwind(self) -> Result<R, JoinError> {
+        self.join().await.map_err(|e| {
             // `JoinError` can be caused either by panic or cancellation. We have to handle panics:
             if e.is_panic() {
                 std::panic::resume_unwind(e.into_panic());
+            } else if e.is_cancelled() {
+                log::warn!("SpawnedTask was polled during shutdown");
+                e
             } else {
-                // Cancellation may be caused by two reasons:
-                // 1. Abort is called, but since we consumed `self`, it's not our case (`JoinHandle` not accessible outside).
-                // 2. The runtime is shutting down.
-                // So we consider this branch as unreachable.
                 unreachable!("SpawnedTask was cancelled unexpectedly");
             }
         })
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use std::{
+        future::{pending, Pending},
+        sync::{Arc, Mutex},
+    };
+
+    use tokio::runtime::Runtime;
+
+    #[tokio::test]
+    async fn runtime_shutdown() {
+        // capture the panic message
+        let panic_msg = Arc::new(Mutex::new(None));
+        let captured_panic_msg = Arc::clone(&panic_msg);
+        std::panic::set_hook(Box::new(move |e| {
+            let mut guard = captured_panic_msg.lock().unwrap();
+            *guard = Some(e.to_string());
+        }));
+
+        for _ in 0..30 {
+            let rt = Runtime::new().unwrap();
+            let join = rt.spawn(async {
+                let task = SpawnedTask::spawn(async {
+                    let fut: Pending<()> = pending();
+                    fut.await;
+                    unreachable!("should never return");
+                });
+                let _ = task.join_unwind().await;
+            });
+
+            // caller shutdown their DF runtime (e.g. timeout, error in caller, etc)
+            rt.shutdown_background();
+
+            // race condition
+            // poll occurs during shutdown (buffered stream poll calls, etc)
+            let _ = join.await;
Review Comment:
```suggestion
            let task = rt.spawn(async {
                SpawnedTask::spawn(async {
                    let fut: Pending<()> = pending();
                    fut.await;
                    unreachable!("should never return");
                })
            }).await.unwrap();
            // caller shutdown their DF runtime (e.g. timeout, error in caller, etc)
            rt.shutdown_background();
            // race condition
            // poll occurs during shutdown (buffered stream poll calls, etc)
            let _ = task.join_unwind().await;
```
I think the current setup is a bit flaky:
- it depends on your panic hook; see my concerns in the comment on `std::panic::set_hook`
- it depends on your sub-runtime making progress in `spawn`, which is also why you need to loop a bunch of times
I think this change also allows you to remove the loop and the panic hook.
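For reference, the panic branch of `join_unwind` in the diff above re-raises the task's panic with `std::panic::resume_unwind`. The following is a std-only sketch of that pattern, not DataFusion's code: it uses `std::thread` in place of tokio's `JoinSet`, and `join_unwind`/`panic_message` here are hypothetical helpers. It shows that a successful join returns the value unchanged and that a panic reaches the caller with its original payload. The cancellation branch has no counterpart, since `std::thread` cannot cancel a running thread.

```rust
use std::any::Any;
use std::panic::{self, AssertUnwindSafe};
use std::thread;

/// Hypothetical std-only analogue of `SpawnedTask::join_unwind` (the pre-PR
/// shape): join a thread and, if it panicked, re-raise the original panic
/// payload in the caller with `resume_unwind`.
fn join_unwind<R>(handle: thread::JoinHandle<R>) -> R {
    handle
        .join()
        .unwrap_or_else(|payload| panic::resume_unwind(payload))
}

/// Extract the message from a payload created by `panic!` (`&str` or `String`).
fn panic_message(payload: &(dyn Any + Send)) -> Option<&str> {
    payload
        .downcast_ref::<&str>()
        .copied()
        .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
}

fn main() {
    // A thread that finishes normally: its value is returned unchanged.
    assert_eq!(join_unwind(thread::spawn(|| 40 + 2)), 42);

    // A thread that panics: the caller observes the same payload again.
    let result = panic::catch_unwind(AssertUnwindSafe(|| {
        join_unwind(thread::spawn(|| -> i32 { panic!("boom") }))
    }));
    let payload = result.expect_err("the panic should reach the caller");
    assert_eq!(panic_message(&*payload), Some("boom"));
}
```

Note that `resume_unwind` does not invoke the panic hook, so a hook installed with `set_hook` sees the original panic in the worker but not the re-raise.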
##########
datafusion/common-runtime/src/common.rs:
##########
@@ -60,18 +60,67 @@ impl<R: 'static> SpawnedTask<R> {
+        std::panic::set_hook(Box::new(move |e| {
Review Comment:
Can we remove this hook? IIRC it is global process state (similar to env variables) and is shared with all the other tests that run in the same process (so all unit tests in `common-runtime`). This will be a pain for others to debug.
See my comment on the `join.await` line for an alternative.
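To make the hook concern concrete, here is a std-only sketch; it is not the alternative proposed in this review, and `capture_panic_message`, `payload_to_string` and `hook_is_process_global` are hypothetical helpers. `hook_is_process_global` installs a hook, panics on a different thread and observes that the hook still fires, then restores the previous hook. `capture_panic_message` shows that a test needing only the panic message can get the payload back from `std::panic::catch_unwind` without installing any hook. This assumes `panic = "unwind"`, and the currently installed hook (e.g. the default one) still prints the message to stderr.

```rust
use std::any::Any;
use std::panic::{self, AssertUnwindSafe};
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical helper: run `f` and return its panic message, if it panicked.
/// The payload is returned by `catch_unwind`, so no hook is installed.
fn capture_panic_message<F: FnOnce()>(f: F) -> Option<String> {
    let payload = panic::catch_unwind(AssertUnwindSafe(f)).err()?;
    payload_to_string(&*payload)
}

/// Convert a `panic!` payload (`&str` or `String`) into an owned message.
fn payload_to_string(payload: &(dyn Any + Send)) -> Option<String> {
    if let Some(s) = payload.downcast_ref::<&str>() {
        Some((*s).to_string())
    } else {
        payload.downcast_ref::<String>().cloned()
    }
}

/// Demonstrates that the panic hook is process-wide: a hook installed here
/// also fires for a panic on a different thread. Restores the old hook.
fn hook_is_process_global() -> bool {
    let fired = Arc::new(Mutex::new(false));
    let fired_in_hook = Arc::clone(&fired);
    let previous = panic::take_hook();
    panic::set_hook(Box::new(move |_info| {
        *fired_in_hook.lock().unwrap() = true;
    }));

    // Panic on another thread; the hook runs on that thread.
    let _ = thread::spawn(|| -> () { panic!("panic on another thread") }).join();

    // Put back whatever hook was installed before (e.g. the default one).
    panic::set_hook(previous);
    let result = *fired.lock().unwrap();
    result
}

fn main() {
    assert_eq!(capture_panic_message(|| panic!("boom")), Some("boom".to_string()));
    assert_eq!(capture_panic_message(|| panic!("boom {}", 7)), Some("boom 7".to_string()));
    assert_eq!(capture_panic_message(|| {}), None);
    assert!(hook_is_process_global());
}
```

Even with the restore, any other test that panics while the temporary hook is installed would hit it, which is exactly why a global hook in a shared test binary is hard to debug.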
--
This is an automated message from the Apache Git Service.