cloud-fan commented on code in PR #50587:
URL: https://github.com/apache/spark/pull/50587#discussion_r2050506132


##########
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala:
##########
@@ -717,35 +718,50 @@ class BytesToString extends 
org.apache.spark.api.java.function.Function[Array[By
  * collects a list of pickled strings that we pass to Python through a socket.
  */
 private[spark] class PythonAccumulatorV2(
-    @transient private val serverHost: String,
-    private val serverPort: Int,
-    private val secretToken: String)
+    @transient private val serverHost: Option[String],
+    private val serverPort: Option[Int],
+    private val secretToken: Option[String],
+    @transient private val socketPath: Option[String])
   extends CollectionAccumulator[Array[Byte]] with Logging {
 
-  Utils.checkHost(serverHost)
+  // Unix domain socket
+  def this(socketPath: String) = this(None, None, None, Some(socketPath))
+  // TCP socket
+  def this(serverHost: String, serverPort: Int, secretToken: String) = this(
+    Some(serverHost), Some(serverPort), Some(secretToken), None)
+
+  serverHost.foreach(Utils.checkHost)
 
   val bufferSize = SparkEnv.get.conf.get(BUFFER_SIZE)
+  val isUnixDomainSock = 
SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
 
   /**
    * We try to reuse a single Socket to transfer accumulator updates, as they 
are all added
    * by the DAGScheduler's single-threaded RpcEndpoint anyway.
    */
-  @transient private var socket: Socket = _
+  @transient private var socket: SocketChannel = _
 
-  private def openSocket(): Socket = synchronized {
-    if (socket == null || socket.isClosed) {
-      socket = new Socket(serverHost, serverPort)
-      logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, 
serverHost)}" +
-        log" port: ${MDC(PORT, serverPort)}")
+  private def openSocket(): SocketChannel = synchronized {
+    if (socket == null || !socket.isOpen) {
+      if (isUnixDomainSock) {
+        socket = SocketChannel.open(UnixDomainSocketAddress.of(socketPath.get))
+        logInfo(log"Connected to AccumulatorServer at socket: 
${MDC(SOCKET_ADDRESS, serverHost)}")
+      } else {
+        socket = SocketChannel.open(new InetSocketAddress(serverHost.get, 
serverPort.get))
+        logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, 
serverHost)}" +
+          log" port: ${MDC(PORT, serverPort)}")

Review Comment:
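   This should be `serverPort.get`: `serverPort` is now an `Option[Int]`, so logging it directly would print `Some(<port>)` rather than the port number.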



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to