HyukjinKwon commented on code in PR #50466:
URL: https://github.com/apache/spark/pull/50466#discussion_r2025806593


##########
python/pyspark/util.py:
##########
@@ -731,34 +731,59 @@ def __del__(self) -> None:
     return iter(PyLocalIterable(sock_info, serializer))
 
 
-def local_connect_and_auth(port: Optional[Union[str, int]], auth_secret: str) 
-> Tuple:
+def local_connect_and_auth(
+    conn_info: Optional[Union[str, int]], auth_secret: Optional[str]
+) -> Tuple:
     """
     Connect to local host, authenticate with it, and return a (sockfile,sock) 
for that connection.
     Handles IPV4 & IPV6, does some error handling.
 
     Parameters
     ----------
     port : str or int, optional
-    auth_secret : str
+    auth_secret : str, optional
 
     Returns
     -------
     tuple
         with (sockfile, sock)
     """
+    is_unix_domain_socket = str(conn_info) and auth_secret is None
+    if is_unix_domain_socket:
+        sock_path = conn_info
+        assert isinstance(sock_path, str)
+        sock = None
+        try:
+            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+            sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 
15)))
+            sock.connect(sock_path)
+            sockfile = sock.makefile("rwb", 
int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+            return (sockfile, sock)
+        except socket.error as e:
+            if sock is not None:
+                sock.close()
+            raise PySparkRuntimeError(
+                errorClass="CANNOT_OPEN_SOCKET",
+                messageParameters={
+                    "errors": "tried to connect to %s, but an error occurred: 
%s"
+                    % (sock_path, str(e)),
+                },
+            )
+
     sock = None
     errors = []
     # Support for both IPv4 and IPv6.
     addr = "127.0.0.1"
     if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
         addr = "::1"
-    for res in socket.getaddrinfo(addr, port, socket.AF_UNSPEC, 
socket.SOCK_STREAM):
+    for res in socket.getaddrinfo(addr, conn_info, socket.AF_UNSPEC, 
socket.SOCK_STREAM):
         af, socktype, proto, _, sa = res
         try:
             sock = socket.socket(af, socktype, proto)
             sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 
15)))
             sock.connect(sa)
             sockfile = sock.makefile("rwb", 
int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+            assert str(auth_secret)

Review Comment:
   For making linters happy ... :-)



##########
python/pyspark/util.py:
##########
@@ -731,34 +731,59 @@ def __del__(self) -> None:
     return iter(PyLocalIterable(sock_info, serializer))
 
 
-def local_connect_and_auth(port: Optional[Union[str, int]], auth_secret: str) 
-> Tuple:
+def local_connect_and_auth(
+    conn_info: Optional[Union[str, int]], auth_secret: Optional[str]
+) -> Tuple:
     """
     Connect to local host, authenticate with it, and return a (sockfile,sock) 
for that connection.
     Handles IPV4 & IPV6, does some error handling.
 
     Parameters
     ----------
     port : str or int, optional
-    auth_secret : str
+    auth_secret : str, optional
 
     Returns
     -------
     tuple
         with (sockfile, sock)
     """
+    is_unix_domain_socket = str(conn_info) and auth_secret is None
+    if is_unix_domain_socket:
+        sock_path = conn_info
+        assert isinstance(sock_path, str)
+        sock = None
+        try:
+            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+            sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 
15)))
+            sock.connect(sock_path)
+            sockfile = sock.makefile("rwb", 
int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+            return (sockfile, sock)
+        except socket.error as e:
+            if sock is not None:
+                sock.close()
+            raise PySparkRuntimeError(
+                errorClass="CANNOT_OPEN_SOCKET",
+                messageParameters={
+                    "errors": "tried to connect to %s, but an error occurred: 
%s"
+                    % (sock_path, str(e)),
+                },
+            )
+
     sock = None
     errors = []
     # Support for both IPv4 and IPv6.
     addr = "127.0.0.1"
     if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
         addr = "::1"
-    for res in socket.getaddrinfo(addr, port, socket.AF_UNSPEC, 
socket.SOCK_STREAM):
+    for res in socket.getaddrinfo(addr, conn_info, socket.AF_UNSPEC, 
socket.SOCK_STREAM):
         af, socktype, proto, _, sa = res
         try:
             sock = socket.socket(af, socktype, proto)
             sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 
15)))
             sock.connect(sa)
             sockfile = sock.makefile("rwb", 
int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+            assert str(auth_secret)

Review Comment:
   For making linters happy ... :-)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to