cloud-fan commented on code in PR #48818:
URL: https://github.com/apache/spark/pull/48818#discussion_r1925413987
##########
sql/api/src/main/scala/org/apache/spark/sql/SparkSession.scala:
##########
@@ -776,40 +779,161 @@ abstract class SparkSession extends Serializable with Closeable {
    * means the connection to the server is usable.
    */
   private[sql] def isUsable: Boolean
+
+  /**
+   * Execute a block of code with this session set as the active session, and restore the
+   * previous session on completion.
+   */
+  @DeveloperApi
+  def withActive[T](block: => T): T = {
+    // Use the active session thread local directly to make sure we get the session that is
+    // actually set and not the default session. This prevents us from promoting the default
+    // session to the active session once we are done.
+    val old = SparkSession.getActiveSession.orNull
+    SparkSession.setActiveSession(this)
+    try block
+    finally {
+      SparkSession.setActiveSession(old)
+    }
+  }
 }
 
 object SparkSession extends SparkSessionCompanion {
   type Session = SparkSession
 
-  private[this] val companion: SparkSessionCompanion = {
-    val cls = SparkClassUtils.classForName("org.apache.spark.sql.SparkSession")
+  // Implementation-specific companions
+  private lazy val CLASSIC_COMPANION = lookupCompanion(
+    "org.apache.spark.sql.classic.SparkSession")
+  private lazy val CONNECT_COMPANION = lookupCompanion(
+    "org.apache.spark.sql.connect.SparkSession")
+  private def DEFAULT_COMPANION =
+    Try(CLASSIC_COMPANION).orElse(Try(CONNECT_COMPANION)).getOrElse {
+      throw new IllegalStateException(
+        "Cannot find a SparkSession implementation on the Classpath.")
+    }
+
+  private[this] def lookupCompanion(name: String): SparkSessionCompanion = {
+    val cls = SparkClassUtils.classForName(name)
     val mirror = scala.reflect.runtime.currentMirror
     val module = mirror.classSymbol(cls).companion.asModule
     mirror.reflectModule(module).instance.asInstanceOf[SparkSessionCompanion]
   }
 
   /** @inheritdoc */
-  override def builder(): SparkSessionBuilder = companion.builder()
+  override def builder(): Builder = new Builder
 
   /** @inheritdoc */
-  override def setActiveSession(session: SparkSession): Unit =
-    companion.setActiveSession(session.asInstanceOf[companion.Session])
+  override def setActiveSession(session: SparkSession): Unit = super.setActiveSession(session)
 
   /** @inheritdoc */
-  override def clearActiveSession(): Unit = companion.clearActiveSession()
+  override def setDefaultSession(session: SparkSession): Unit = super.setDefaultSession(session)
 
   /** @inheritdoc */
-  override def setDefaultSession(session: SparkSession): Unit =
-    companion.setDefaultSession(session.asInstanceOf[companion.Session])
+  override def getActiveSession: Option[SparkSession] = super.getActiveSession
 
   /** @inheritdoc */
-  override def clearDefaultSession(): Unit = companion.clearDefaultSession()
+  override def getDefaultSession: Option[SparkSession] = super.getDefaultSession
 
-  /** @inheritdoc */
-  override def getActiveSession: Option[SparkSession] = companion.getActiveSession
+  override protected def tryCastToImplementation(session: SparkSession): Option[SparkSession] =
+    Some(session)
 
-  /** @inheritdoc */
-  override def getDefaultSession: Option[SparkSession] = companion.getDefaultSession
+  class Builder extends SparkSessionBuilder {
+    import SparkSessionBuilder._
+    private val extensionModifications = mutable.Buffer.empty[SparkSessionExtensions => Unit]
+    private var sc: Option[SparkContext] = None
+    private var companion: SparkSessionCompanion = DEFAULT_COMPANION
+
+    /** @inheritdoc */
+    override def appName(name: String): this.type = super.appName(name)
+
+    /** @inheritdoc */
+    override def master(master: String): this.type = super.master(master)
+
+    /** @inheritdoc */
+    override def enableHiveSupport(): this.type = super.enableHiveSupport()
+
+    /** @inheritdoc */
+    override def config(key: String, value: String): this.type = super.config(key, value)
+
+    /** @inheritdoc */
+    override def config(key: String, value: Long): this.type = super.config(key, value)
+
+    /** @inheritdoc */
+    override def config(key: String, value: Double): this.type = super.config(key, value)
+
+    /** @inheritdoc */
+    override def config(key: String, value: Boolean): this.type = super.config(key, value)
+
+    /** @inheritdoc */
+    override def config(map: Map[String, Any]): this.type = super.config(map)
+
+    /** @inheritdoc */
+    override def config(map: util.Map[String, Any]): this.type = super.config(map)
+
+    /** @inheritdoc */
+    override def config(conf: SparkConf): this.type = super.config(conf)
+
+    /** @inheritdoc */
+    override def remote(connectionString: String): this.type = super.remote(connectionString)
+
+    /** @inheritdoc */
+    override def withExtensions(f: SparkSessionExtensions => Unit): this.type = synchronized {
+      extensionModifications += f
+      this
+    }
+
+    /** @inheritdoc */
+    override private[spark] def sparkContext(sparkContext: SparkContext): this.type =
+      synchronized {
+        sc = Option(sparkContext)
+        this
+      }
+
+    /**
+     * Make the builder create a Classic SparkSession.
+     */
+    def classic(): this.type = mode(CLASSIC_COMPANION)
+
+    /**
+     * Make the builder create a Connect SparkSession.
+     */
+    def connect(): this.type = mode(CONNECT_COMPANION)

Review Comment:
   is it the same as setting the Spark Connect config?
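To make the comparison concrete, here is a minimal sketch of the two approaches being asked about, assuming a Connect server at `sc://localhost`. The builder calls come from this diff; the `spark.api.mode` config key is an assumption used for illustration and is not confirmed by this PR:

```scala
import org.apache.spark.sql.SparkSession

// 1. Explicit builder method introduced in this diff:
val viaBuilder = SparkSession.builder()
  .remote("sc://localhost")
  .connect() // forces the Connect implementation
  .getOrCreate()

// 2. Config-driven selection (the key name is a hypothetical stand-in):
val viaConfig = SparkSession.builder()
  .remote("sc://localhost")
  .config("spark.api.mode", "connect")
  .getOrCreate()
```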
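Separately, a minimal usage sketch for the new `withActive` helper, based on the semantics in its scaladoc (the session here is a placeholder; assumes an implementation is on the classpath):

```scala
import org.apache.spark.sql.SparkSession

val session = SparkSession.builder().getOrCreate()

val count = session.withActive {
  // Inside the block, `session` is this thread's active session.
  assert(SparkSession.getActiveSession.contains(session))
  session.range(10).count()
}
// On exit, the previously active session (possibly none) is restored;
// the default session is not promoted to the active slot.
```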