#36916: Add support for streaming with TaskGroups
-------------------------------------+-------------------------------------
     Reporter:  Thomas Grainger      |                    Owner:  Thomas
                                     |  Grainger
         Type:  New feature          |                   Status:  assigned
    Component:  HTTP handling        |                  Version:  dev
     Severity:  Normal               |               Resolution:
     Keywords:  structured-          |             Triage Stage:  Accepted
  concurrency, taskgroups            |
    Has patch:  0                    |      Needs documentation:  1
  Needs tests:  0                    |  Patch needs improvement:  0
Easy pickings:  0                    |                    UI/UX:  0
-------------------------------------+-------------------------------------
Comment (by Thomas Grainger):

 How about this approach? I'm working from my phone, so I'm posting this
 as a patch rather than pushing directly.

 == Summary of changes (from Claude) ==

 **`django/http/response.py`**

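 - `_set_streaming_content` detects async context managers by checking
 `hasattr(value, "__aenter__") and hasattr(value, "__aexit__")`. For such
 values it stores the value in `self.__acmgr`, sets `self.is_acmgr = True`
 and `self.is_async = True`, and returns early; otherwise it sets
 `self.is_acmgr = False` and continues as before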
 - `streaming_content` property: the `is_acmgr` check is an early return
 at the top (not nested under `is_async`). When `is_acmgr` is True it
 returns an async context manager built with `@asynccontextmanager`, so
 callers that check `is_acmgr` use `async with response.streaming_content
 as agen`. Both the acmgr and the regular async paths apply `make_bytes`
 and `aclose` their underlying iterator in a `finally` block
 - `StreamingHttpResponse` gains `__aenter__`/`__aexit__`. For acmgr
 responses, `__aenter__` enters the context manager returned by
 `self.streaming_content` (so the `awrapper` logic is not duplicated) and
 stores both the context manager and the yielded generator; `__aexit__`
 then closes the generator before exiting the context, in that order, so
 the `TaskGroup` cleans up correctly. For other streaming responses,
 `__aenter__` returns `aiter(self)`, which `__aexit__` also closes
 - `__iter__`, `__aiter__`, and `getvalue` all raise `IsAcmgrException`
 when `is_acmgr` is True, telling the caller to consume the response via
 `async with response`
 - `IsAcmgrException` is exported from `django.http`

 **`django/core/handlers/asgi.py`**

 - `send_response` is simplified to `async with response as content`,
 which works for both regular streaming and acmgr responses, so `aclosing`
 (and its import) is no longer needed

 **`django/middleware/gzip.py`**

 - Adds an `is_acmgr` branch that captures `response.streaming_content`
 (an async context manager) and wraps it in a new `@asynccontextmanager`
 that feeds the yielded generator into `acompress_sequence`

 **`django/utils/text.py`**

 - `acompress_sequence` wraps its entire body in `try`/`finally` so that,
 on exit, it calls `aclose()` on its input `sequence` if that method exists

 {{{#!diff
 diff --git a/django/core/handlers/asgi.py b/django/core/handlers/asgi.py
 index 9555860..0f55717 100644
 --- a/django/core/handlers/asgi.py
 +++ b/django/core/handlers/asgi.py
 @@ -4,7 +4,7 @@ import sys
  import tempfile
  import traceback
  from collections import defaultdict
 -from contextlib import aclosing, closing
 +from contextlib import closing

  from asgiref.sync import ThreadSensitiveContext, sync_to_async

 @@ -315,11 +315,7 @@ class ASGIHandler(base.BaseHandler):
          )
          # Streaming responses need to be pinned to their iterator.
          if response.streaming:
 -            # - Consume via `__aiter__` and not `streaming_content`
 directly,
 -            #   to allow mapping of a sync iterator.
 -            # - Use aclosing() when consuming aiter. See
 -            #
 
https://github.com/python/cpython/commit/6e8dcdaaa49d4313bf9fab9f9923ca5828fbb10e
 -            async with aclosing(aiter(response)) as content:
 +            async with response as content:
                  async for part in content:
                      for chunk, _ in self.chunk_bytes(part):
                          await send(
 diff --git a/django/http/__init__.py b/django/http/__init__.py
 index 628564e..2df8fa6 100644
 --- a/django/http/__init__.py
 +++ b/django/http/__init__.py
 @@ -8,6 +8,7 @@ from django.http.request import (
  )
  from django.http.response import (
      BadHeaderError,
 +    IsAcmgrException,
      FileResponse,
      Http404,
      HttpResponse,
 @@ -47,6 +48,7 @@ __all__ = [
      "HttpResponseServerError",
      "Http404",
      "BadHeaderError",
 +    "IsAcmgrException",
      "JsonResponse",
      "FileResponse",
  ]
 diff --git a/django/http/response.py b/django/http/response.py
 index 9bf0b14..e3976da 100644
 --- a/django/http/response.py
 +++ b/django/http/response.py
 @@ -7,6 +7,7 @@ import re
  import sys
  import time
  import warnings
 +from contextlib import asynccontextmanager
  from email.header import Header
  from http.client import responses
  from urllib.parse import urlsplit
 @@ -104,6 +105,10 @@ class BadHeaderError(ValueError):
      pass


 +class IsAcmgrException(Exception):
 +    pass
 +
 +
  class HttpResponseBase:
      """
      An HTTP response base class with dictionary-accessed headers.
 @@ -479,14 +484,36 @@ class StreamingHttpResponse(HttpResponseBase):

      @property
      def streaming_content(self):
 +        if self.is_acmgr:
 +            # Pull to lexical scope in case streaming_content is set
 again.
 +            _acmgr = self.__acmgr
 +
 +            @asynccontextmanager
 +            async def acmgr_wrapper():
 +                async with _acmgr as agen:
 +                    try:
 +                        async def awrapper():
 +                            async for part in agen:
 +                                yield self.make_bytes(part)
 +
 +                        yield awrapper()
 +                    finally:
 +                        if hasattr(agen, "aclose"):
 +                            await agen.aclose()
 +
 +            return acmgr_wrapper()
          if self.is_async:
              # pull to lexical scope to capture fixed reference in case
              # streaming_content is set again later.
              _iterator = self._iterator

              async def awrapper():
 -                async for part in _iterator:
 -                    yield self.make_bytes(part)
 +                try:
 +                    async for part in _iterator:
 +                        yield self.make_bytes(part)
 +                finally:
 +                    if hasattr(_iterator, "aclose"):
 +                        await _iterator.aclose()

              return awrapper()
          else:
 @@ -498,6 +525,12 @@ class StreamingHttpResponse(HttpResponseBase):

      def _set_streaming_content(self, value):
          # Ensure we can never iterate on "value" more than once.
 +        if hasattr(value, "__aenter__") and hasattr(value, "__aexit__"):
 +            self.__acmgr = value
 +            self.is_acmgr = True
 +            self.is_async = True
 +            return
 +        self.is_acmgr = False
          try:
              self._iterator = iter(value)
              self.is_async = False
 @@ -507,7 +540,28 @@ class StreamingHttpResponse(HttpResponseBase):
          if hasattr(value, "close"):
              self._resource_closers.append(value.close)

 +    async def __aenter__(self):
 +        if self.is_acmgr:
 +            self.__acmgr_ctx = self.streaming_content
 +            self.__agen = await self.__acmgr_ctx.__aenter__()
 +        else:
 +            self.__agen = aiter(self)
 +        return self.__agen
 +
 +    async def __aexit__(self, *exc_info):
 +        try:
 +            await self.__agen.aclose()
 +        finally:
 +            if self.is_acmgr:
 +                return await self.__acmgr_ctx.__aexit__(*exc_info)
 +        return None
 +
      def __iter__(self):
 +        if self.is_acmgr:
 +            raise IsAcmgrException(
 +                "%s must be consumed via `async with`. Use `async with
 response` "
 +                "and iterate the yielded content." %
 self.__class__.__name__
 +            )
          try:
              return iter(self.streaming_content)
          except TypeError:
 @@ -528,6 +582,11 @@ class StreamingHttpResponse(HttpResponseBase):
              return map(self.make_bytes,
 iter(async_to_sync(to_list)(self._iterator)))

      async def __aiter__(self):
 +        if self.is_acmgr:
 +            raise IsAcmgrException(
 +                "%s must be consumed via `async with`. Use `async with
 response` "
 +                "and iterate the yielded content." %
 self.__class__.__name__
 +            )
          try:
              async for part in self.streaming_content:
                  yield part
 @@ -544,6 +603,11 @@ class StreamingHttpResponse(HttpResponseBase):
                  yield part

      def getvalue(self):
 +        if self.is_acmgr:
 +            raise IsAcmgrException(
 +                "%s must be consumed via `async with`. Use `async with
 response` "
 +                "and iterate the yielded content." %
 self.__class__.__name__
 +            )
          return b"".join(self.streaming_content)


 diff --git a/django/middleware/gzip.py b/django/middleware/gzip.py
 index eb151d7..78b5739 100644
 --- a/django/middleware/gzip.py
 +++ b/django/middleware/gzip.py
 @@ -1,3 +1,5 @@
 +from contextlib import asynccontextmanager
 +
  from django.utils.cache import patch_vary_headers
  from django.utils.deprecation import MiddlewareMixin
  from django.utils.regex_helper import _lazy_re_compile
 @@ -31,7 +33,20 @@ class GZipMiddleware(MiddlewareMixin):
              return response

          if response.streaming:
 -            if response.is_async:
 +            if response.is_acmgr:
 +                original_acmgr = response.streaming_content
 +                max_random_bytes = self.max_random_bytes
 +
 +                @asynccontextmanager
 +                async def compressed_acmgr():
 +                    async with original_acmgr as agen:
 +                        yield acompress_sequence(
 +                            agen,
 +                            max_random_bytes=max_random_bytes,
 +                        )
 +
 +                response.streaming_content = compressed_acmgr()
 +            elif response.is_async:
                  response.streaming_content = acompress_sequence(
                      response.streaming_content,
                      max_random_bytes=self.max_random_bytes,
 diff --git a/django/utils/text.py b/django/utils/text.py
 index d1306f9..55bd6f5 100644
 --- a/django/utils/text.py
 +++ b/django/utils/text.py
 @@ -390,20 +390,24 @@ def compress_sequence(sequence, *,
 max_random_bytes=None):


  async def acompress_sequence(sequence, *, max_random_bytes=None):
 -    buf = StreamingBuffer()
 -    filename = _get_random_filename(max_random_bytes) if max_random_bytes
 else None
 -    with GzipFile(
 -        filename=filename, mode="wb", compresslevel=6, fileobj=buf,
 mtime=0
 -    ) as zfile:
 -        # Output headers...
 +    try:
 +        buf = StreamingBuffer()
 +        filename = _get_random_filename(max_random_bytes) if
 max_random_bytes else None
 +        with GzipFile(
 +            filename=filename, mode="wb", compresslevel=6, fileobj=buf,
 mtime=0
 +        ) as zfile:
 +            # Output headers...
 +            yield buf.read()
 +            async for item in sequence:
 +                zfile.write(item)
 +                zfile.flush()
 +                data = buf.read()
 +                if data:
 +                    yield data
          yield buf.read()
 -        async for item in sequence:
 -            zfile.write(item)
 -            zfile.flush()
 -            data = buf.read()
 -            if data:
 -                yield data
 -    yield buf.read()
 +    finally:
 +        if hasattr(sequence, "aclose"):
 +            await sequence.aclose()
 }}}
-- 
Ticket URL: <https://code.djangoproject.com/ticket/36916#comment:8>
Django <https://code.djangoproject.com/>
The Web framework for perfectionists with deadlines.

-- 
You received this message because you are subscribed to the Google Groups 
"Django updates" group.
To unsubscribe from this group and stop receiving emails from it, send an email 
to django-updates+unsubscribe@googlegroups.com.
To view this discussion visit 
https://groups.google.com/d/msgid/django-updates/0107019d6c1b1b94-a77015d5-81a9-446b-95bd-0b4412079e37-000000%40eu-central-1.amazonses.com.

Reply via email to