cloud-fan commented on code in PR #50587: URL: https://github.com/apache/spark/pull/50587#discussion_r2050505986
########## core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala: ########## @@ -717,35 +718,50 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By * collects a list of pickled strings that we pass to Python through a socket. */ private[spark] class PythonAccumulatorV2( - @transient private val serverHost: String, - private val serverPort: Int, - private val secretToken: String) + @transient private val serverHost: Option[String], + private val serverPort: Option[Int], + private val secretToken: Option[String], + @transient private val socketPath: Option[String]) extends CollectionAccumulator[Array[Byte]] with Logging { - Utils.checkHost(serverHost) + // Unix domain socket + def this(socketPath: String) = this(None, None, None, Some(socketPath)) + // TPC socket + def this(serverHost: String, serverPort: Int, secretToken: String) = this( + Some(serverHost), Some(serverPort), Some(secretToken), None) + + serverHost.foreach(Utils.checkHost) val bufferSize = SparkEnv.get.conf.get(BUFFER_SIZE) + val isUnixDomainSock = SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED) /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added * by the DAGScheduler's single-threaded RpcEndpoint anyway. */ - @transient private var socket: Socket = _ + @transient private var socket: SocketChannel = _ - private def openSocket(): Socket = synchronized { - if (socket == null || socket.isClosed) { - socket = new Socket(serverHost, serverPort) - logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, serverHost)}" + - log" port: ${MDC(PORT, serverPort)}") + private def openSocket(): SocketChannel = synchronized { + if (socket == null || !socket.isOpen) { + if (isUnixDomainSock) { + socket = SocketChannel.open(UnixDomainSocketAddress.of(socketPath.get)) + logInfo(log"Connected to AccumulatorServer at socket: ${MDC(SOCKET_ADDRESS, serverHost)}") Review Comment: serverHost.get ########## core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala: ########## @@ -717,35 +718,50 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By * collects a list of pickled strings that we pass to Python through a socket. */ private[spark] class PythonAccumulatorV2( - @transient private val serverHost: String, - private val serverPort: Int, - private val secretToken: String) + @transient private val serverHost: Option[String], + private val serverPort: Option[Int], + private val secretToken: Option[String], + @transient private val socketPath: Option[String]) extends CollectionAccumulator[Array[Byte]] with Logging { - Utils.checkHost(serverHost) + // Unix domain socket + def this(socketPath: String) = this(None, None, None, Some(socketPath)) + // TPC socket + def this(serverHost: String, serverPort: Int, secretToken: String) = this( + Some(serverHost), Some(serverPort), Some(secretToken), None) + + serverHost.foreach(Utils.checkHost) val bufferSize = SparkEnv.get.conf.get(BUFFER_SIZE) + val isUnixDomainSock = SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED) /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added * by the DAGScheduler's single-threaded RpcEndpoint anyway. */ - @transient private var socket: Socket = _ + @transient private var socket: SocketChannel = _ - private def openSocket(): Socket = synchronized { - if (socket == null || socket.isClosed) { - socket = new Socket(serverHost, serverPort) - logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, serverHost)}" + - log" port: ${MDC(PORT, serverPort)}") + private def openSocket(): SocketChannel = synchronized { + if (socket == null || !socket.isOpen) { + if (isUnixDomainSock) { + socket = SocketChannel.open(UnixDomainSocketAddress.of(socketPath.get)) + logInfo(log"Connected to AccumulatorServer at socket: ${MDC(SOCKET_ADDRESS, serverHost)}") + } else { + socket = SocketChannel.open(new InetSocketAddress(serverHost.get, serverPort.get)) + logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, serverHost)}" + Review Comment: serverHost.get -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org