#!/usr/bin/env python3
"""
Apache Beam pipeline for processing PDFs with Triton server and saving
results to BigQuery.
This pipeline combines functionality from test_triton_document.py,
create_bigquery_tables.py, and save_to_bigquery.py into a single workflow.
"""

import os
import sys
import json
import uuid
import argparse
import logging
import tempfile
import datetime
import requests
import numpy as np
import cv2
from PIL import Image
import fitz  # PyMuPDF
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional, Iterator

# Apache Beam imports
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
from apache_beam.ml.inference.base import RemoteModelHandler, PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorrt_inference import TensorRTEngineHandlerNumPy
from apache_beam.ml.inference.utils import _convert_to_result
from apache_beam.io.gcp.bigquery import WriteToBigQuery
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.gcsio import GcsIO

# Google Cloud imports
from google.cloud import storage
from google.cloud import bigquery

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# DocLayNet classes
CLASS_ID_TO_NAME = {
    0: 'Caption',
    1: 'Footnote',
    2: 'Formula',
    3: 'List-item',
    4: 'Page-footer',
    5: 'Page-header',
    6: 'Picture',
    7: 'Section-header',
    8: 'Table',
    9: 'Text',
    10: 'Title'
}
class DownloadPDFFromGCS(beam.DoFn):
    """Download a PDF from Google Cloud Storage."""

    def __init__(self, temp_dir=None):
        self.temp_dir = temp_dir or tempfile.gettempdir()

    def process(self, gcs_uri):
        try:
            # Parse GCS URI
            if not gcs_uri.startswith("gs://"):
                raise ValueError(f"Invalid GCS URI: {gcs_uri}")

            # Remove gs:// prefix and split into bucket and blob path
            path_parts = gcs_uri[5:].split("/", 1)
            bucket_name = path_parts[0]
            blob_path = path_parts[1]

            # Get filename from blob path
            filename = os.path.basename(blob_path)
            local_path = os.path.join(self.temp_dir, filename)

            # Create temp directory if it doesn't exist
            os.makedirs(self.temp_dir, exist_ok=True)

            try:
                # Download using Beam's GcsIO
                with FileSystems.open(gcs_uri, 'rb') as gcs_file:
                    with open(local_path, 'wb') as local_file:
                        local_file.write(gcs_file.read())

                logger.info(f"Downloaded {gcs_uri} to {local_path}")

                # Return a dictionary with the local path and original URI
                yield {
                    'local_path': local_path,
                    'gcs_uri': gcs_uri,
                    'filename': filename
                }
            except Exception as e:
                logger.error(f"Error reading from GCS: {str(e)}")
                # Try alternative download method
                logger.info(f"Trying alternative download method for
{gcs_uri}")

                # For testing with local files
                if os.path.exists(gcs_uri.replace("gs://", "")):
                    local_path = gcs_uri.replace("gs://", "")
                    logger.info(f"Using local file: {local_path}")
                    yield {
                        'local_path': local_path,
                        'gcs_uri': gcs_uri,
                        'filename': os.path.basename(local_path)
                    }
                else:
                    # Try using gsutil command
                    import subprocess
                    try:
                        subprocess.run(["gsutil", "cp", gcs_uri,
local_path], check=True)
                        logger.info(f"Downloaded {gcs_uri} to {local_path}
using gsutil")
                        yield {
                            'local_path': local_path,
                            'gcs_uri': gcs_uri,
                            'filename': filename
                        }
                    except Exception as e2:
                        logger.error(f"Failed to download using gsutil:
{str(e2)}")

        except Exception as e:
            logger.error(f"Error downloading {gcs_uri}: {str(e)}")
class LoadPDFPages(beam.DoFn):
    """Load PDF pages as images."""

    def __init__(self, dpi=200):
        self.dpi = dpi

    def process(self, element):
        doc = None
        try:
            # Make sure we have all required fields
            if not isinstance(element, dict):
                logger.error(f"Expected dictionary, got {type(element)}")
                return

            if 'local_path' not in element:
                logger.error("Missing 'local_path' in element")
                return

            local_path = element['local_path']
            gcs_uri = element.get('gcs_uri', '')

            # Extract filename from local_path if not provided
            filename = element.get('filename', os.path.basename(local_path))

            logger.info(f"Loading PDF: {local_path}, filename: {filename}")

            # Check if file exists and is accessible
            if not os.path.exists(local_path):
                logger.error(f"File not found: {local_path}")
                return

            if not os.access(local_path, os.R_OK):
                logger.error(f"File not readable: {local_path}")
                return

            # Open the PDF
            try:
                doc = fitz.open(local_path)
                if doc.is_closed:
                    logger.error(f"Failed to open PDF: {local_path}")
                    return
            except Exception as e:
                logger.error(f"Error opening PDF {local_path}: {str(e)}")
                return

            # Process each page
            page_count = len(doc)
            logger.info(f"Processing {page_count} pages from {local_path}")

            for i in range(page_count):
                try:
                    if doc.is_closed:
                        logger.error(f"Document was closed unexpectedly
while processing page {i}")
                        break

                    page = doc[i]
                    if page is None:
                        logger.error(f"Failed to get page {i} from
document")
                        continue

                    # Use a higher resolution for better quality
                    scale = self.dpi / 72.0
                    mat = fitz.Matrix(scale, scale)

                    try:
                        pix = page.get_pixmap(matrix=mat, alpha=False)
                    except Exception as e:
                        logger.error(f"Error getting pixmap for page {i}:
{str(e)}")
                        continue

                    # Check pixmap dimensions
                    if pix.height <= 0 or pix.width <= 0 or pix.n <= 0:
                        logger.error(f"Invalid pixmap dimensions:
{pix.width}x{pix.height}x{pix.n}")
                        continue

                    # Convert to numpy array
                    try:
                        arr = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
                    except Exception as e:
                        logger.error(f"Error converting pixmap to numpy
array: {str(e)}")
                        continue

                    # Convert BGR to RGB if needed
                    if pix.n == 3:  # RGB
                        try:
                            arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
                        except Exception as e:
                            logger.error(f"Error converting BGR to RGB:
{str(e)}")
                            continue

                    # Store original size for later use
                    original_size = (arr.shape[0], arr.shape[1])

                    # Create page info
                    page_info = {
                        'page_num': i,
                        'image': arr,
                        'original_size': original_size,
                        'local_path': local_path,
                        'gcs_uri': gcs_uri,
                        'filename': filename
                    }

                    # Use document ID and page number as key
                    doc_id = os.path.splitext(filename)[0]
                    key = f"{doc_id}_{i}"

                    yield (key, page_info)
                except Exception as e:
                    import traceback
                    logger.error(f"Error processing page {i}: {str(e)}")
                    logger.error(traceback.format_exc())

            logger.info(f"Loaded {len(doc)} pages from {local_path}")

        except Exception as e:
            import traceback
            logger.error(f"Error loading PDF: {str(e)}")
            logger.error(traceback.format_exc())
        finally:
            # Make sure to close the document only if it was successfully opened
            if doc is not None:
                try:
                    if not doc.is_closed:
                        doc.close()
                except Exception as e:
                    logger.debug(f"Error closing document: {str(e)}")

class PreprocessImage(beam.DoFn):
    """Preprocess image for Triton server."""

    def __init__(self, size=1024):
        self.size = size

    def letterbox(self, img, new_shape=1024, color=(114,114,114)):
        """Resize and pad image to target size."""
        h, w = img.shape[:2]
        r = min(new_shape / h, new_shape / w)
        nh, nw = int(round(h * r)), int(round(w * r))
        pad_h, pad_w = new_shape - nh, new_shape - nw
        top = pad_h // 2
        bottom = pad_h - top
        left = pad_w // 2
        right = pad_w - left
        img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR)
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
        return img, r, left, top

    def process(self, element):
        try:
            if not isinstance(element, tuple) or len(element) != 2:
                logger.error(f"Expected (key, value) tuple, got
{type(element)}")
                return

            key, page_info = element

            if not isinstance(page_info, dict):
                logger.error(f"Expected dictionary for page_info, got
{type(page_info)}")
                return

            if 'image' not in page_info:
                logger.error("Missing 'image' in page_info")
                return

            # Create a new dictionary to avoid modifying the input
            new_page_info = dict(page_info)

            # Apply letterbox resize
            img = new_page_info['image']
            lb, r, left, top = self.letterbox(img, new_shape=self.size)

            # Convert to float32 and normalize to [0,1]
            x = lb.astype(np.float32) / 255.0

            # Convert to CHW format
            x = np.transpose(x, (2, 0, 1))

            # Add batch dimension
            batched_img = np.expand_dims(x, axis=0)

            # Update page info
            new_page_info['preprocessed_image'] = batched_img
            new_page_info['letterbox_info'] = (r, left, top)

            yield (key, new_page_info)

        except Exception as e:
            import traceback
            logger.error(f"Error preprocessing image: {str(e)}")
            logger.error(traceback.format_exc())



class ExtractBoxes(beam.DoFn):
    """Extract bounding boxes from Triton response."""

    def __init__(self, conf_th=0.25, iou_th=0.7, model_size=1024):
        self.conf_th = conf_th
        self.iou_th = iou_th
        self.model_size = model_size

    def _nms(self, boxes, scores, iou_th=0.7):
        """Non-Maximum Suppression"""
        if len(boxes) == 0:
            return []

        boxes = boxes.astype(np.float32)
        x1, y1, x2, y2 = boxes.T
        areas = (x2 - x1) * (y2 - y1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)

            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1)
            h = np.maximum(0.0, yy2 - yy1)
            inter = w * h

            iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-9)
            inds = np.where(iou <= iou_th)[0]
            order = order[inds + 1]

        return keep
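
    # Worked example for _nms above: boxes [0, 0, 10, 10] and [1, 1, 11, 11]
    # (areas 100 each) overlap with intersection 9 * 9 = 81, so
    # IoU = 81 / (100 + 100 - 81) ≈ 0.68; with the default iou_th=0.7 the
    # lower-scoring box is kept as well, since suppression only removes boxes
    # whose IoU with a kept box is strictly above the threshold.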

    def process(self, page_info):
        try:
            triton_response = page_info['triton_response']
            original_size = page_info['original_size']
            r, left, top = page_info['letterbox_info']

            if "outputs" not in triton_response or not
triton_response["outputs"]:
                logger.error("Invalid response from Triton server")
                return []

            out_meta = triton_response["outputs"][0]
            shape = out_meta["shape"]
            data = np.array(out_meta["data"], dtype=np.float32).reshape(shape)

            logger.info(f"Output shape: {shape}")

            # For YOLO output [B, C, P] where C is channels (box coords + objectness + classes)
            B, C, P = shape

            # Assuming 4 box coordinates + class probabilities (no objectness)
            has_objectness = False
            num_classes = C - 5 if has_objectness else C - 4

            # Extract data
            xywh = data[:, 0:4, :]
            if has_objectness:
                obj = data[:, 4:5, :]
                cls = data[:, 5:5 + num_classes, :]
            else:
                obj = None
                cls = data[:, 4:4 + num_classes, :]

            # Process batch item (we only have one)
            b = 0
            h, w = original_size

            xywh_b = xywh[b].T  # (P,4)
            if obj is not None:
                obj_b = obj[b].T.squeeze(1)  # (P,)
            else:
                obj_b = np.ones((P,), dtype=np.float32)
            cls_b = cls[b].T  # (P,nc)

            # Get scores and labels
            scores_all = (obj_b[:, None] * cls_b) if obj is not None else cls_b
            labels = scores_all.argmax(axis=1)
            scores = scores_all.max(axis=1)

            # Filter by confidence threshold
            keep = scores >= self.conf_th
            if not np.any(keep):
                logger.info(f"No detections above threshold {self.conf_th}")
                return []

            xywh_k = xywh_b[keep]
            scores_k = scores[keep]
            labels_k = labels[keep]

            # xywh -> xyxy in model space
            cx, cy, ww, hh = xywh_k.T
            xyxy_model = np.stack([cx - ww / 2, cy - hh / 2, cx + ww / 2, cy + hh / 2], axis=1)

            # Apply NMS per class
            final_boxes = []
            final_scores = []
            final_labels = []

            for c in np.unique(labels_k):
                idxs = np.where(labels_k == c)[0]
                if idxs.size == 0:
                    continue
                keep_idx = self._nms(xyxy_model[idxs], scores_k[idxs], iou_th=self.iou_th)
                final_boxes.append(xyxy_model[idxs][keep_idx])
                final_scores.append(scores_k[idxs][keep_idx])
                final_labels.append(np.full(len(keep_idx), c, dtype=int))

            if not final_boxes:
                logger.info("No detections after NMS")
                return []

            xyxy_model = np.vstack(final_boxes)
            scores_k = np.concatenate(final_scores)
            labels_k = np.concatenate(final_labels)

            # Map boxes from model space to original image space
            xyxy_orig = xyxy_model.copy()

            # Remove padding
            xyxy_orig[:, [0, 2]] -= left
            xyxy_orig[:, [1, 3]] -= top

            # Scale back to original size
            xyxy_orig /= r

            # Clip to image boundaries
            xyxy_orig[:, 0::2] = np.clip(xyxy_orig[:, 0::2], 0, w - 1)
            xyxy_orig[:, 1::2] = np.clip(xyxy_orig[:, 1::2], 0, h - 1)

            # Format as requested: x_min, y_min, x_max, y_max, class, probability
            boxes = []
            for (x1, y1, x2, y2), label, score in zip(xyxy_orig, labels_k, scores_k):
                class_name = CLASS_ID_TO_NAME.get(int(label))
                box_info = {
                    "page": page_info['page_num'],
                    "x_min": float(x1),
                    "y_min": float(y1),
                    "x_max": float(x2),
                    "y_max": float(y2),
                    "class": int(label),
                    "class_name": class_name,
                    "probability": float(score),
                    "filename": page_info['filename'],
                    "local_path": page_info['local_path'],
                    "gcs_uri": page_info['gcs_uri']
                }
                boxes.append(box_info)

            logger.info(f"Extracted {len(boxes)} boxes from page
{page_info['page_num']}")

            return boxes

        except Exception as e:
            logger.error(f"Error extracting boxes: {str(e)}")
            return []
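
# Continuing the letterbox example above (illustrative numbers): with
# letterbox_info r ≈ 0.4655, left=116, top=0, a model-space box
# (300, 200, 700, 500) maps back to roughly (395.3, 429.7, 1254.7, 1074.2)
# on the original 1700 x 2200 page: subtract the padding, divide by r, then
# clip to the image bounds.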

class PrepareForBigQuery(beam.DoFn):
    """Prepare data for BigQuery insertion."""

    def process(self, box_info):
        try:
            # Generate UUIDs for primary keys
            v_note_id = str(uuid.uuid4())
            page_ocr_id = str(uuid.uuid4())
            class_prediction_id = str(uuid.uuid4())

            # Create timestamp
            processing_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            # Create ocr_results row
            ocr_results_row = {
                "v_note_id": v_note_id,
                "filename": box_info['filename'],
                "file_path": box_info['gcs_uri'],
                "processing_time": processing_time,
                "file_type": "pdf"
            }

            # Create page_ocr row
            page_ocr_row = {
                "page_ocr_id": page_ocr_id,
                "v_note_id": v_note_id,
                "page_number": box_info['page']
            }

            # Create class_prediction row
            class_prediction_row = {
                "class_prediction_id": class_prediction_id,
                "page_ocr_id": page_ocr_id,
                "xmin": box_info['x_min'],
                "ymin": box_info['y_min'],
                "xmax": box_info['x_max'],
                "ymax": box_info['y_max'],
                "class": box_info['class_name'] if box_info['class_name']
else str(box_info['class']),
                "confidence": box_info['probability']
            }

            # Return all three rows with table names
            return [
                ('ocr_results', ocr_results_row),
                ('page_ocr', page_ocr_row),
                ('class_prediction', class_prediction_row)
            ]

        except Exception as e:
            logger.error(f"Error preparing for BigQuery: {str(e)}")
            return []

model_handler = TensorRTEngineHandlerNumPy(
  min_batch_size=1,
  max_batch_size=1,
  engine_path="gs://temp/yolov11l-doclaynet.engine",
)
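
# `options` used below is not defined in this snippet. A minimal sketch for a
# Dataflow run with the custom TensorRT container (project, bucket, image and
# accelerator values are placeholders / assumptions, not from the original):
options = PipelineOptions(
    runner="DataflowRunner",
    project="my-gcp-project",
    region="us-central1",
    temp_location="gs://my-bucket/tmp",
    sdk_container_image="us-central1-docker.pkg.dev/my-gcp-project/repo/beam-tensorrt:latest",
    dataflow_service_options=[
        "worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver"
    ],
)
options.view_as(SetupOptions).save_main_session = True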


with beam.Pipeline(options=options) as pipeline:

        # Create PCollection from input URIs
        pdf_uris = (
            pipeline
            | "Create URIs" >> beam.Create(["tmp.pdf"])
        )

        # Download PDFs
        local_pdfs = (
            pdf_uris
            | "Download PDFs" >> beam.ParDo(DownloadPDFFromGCS())
        )

        # Load PDF pages
        pdf_pages = (
            local_pdfs
            | "Load PDF Pages" >> beam.ParDo(LoadPDFPages())
            #| "Flatten Pages" >> beam.FlatMap(lambda x: x)
        )

        # Preprocess images
        preprocessed_pages = (
            pdf_pages
            | "Preprocess Images" >> beam.ParDo(PreprocessImage())
        )
        inference_results = (
            preprocessed_pages
            | "Run Inference" >> RunInference(model_handler=model_handler)
        )
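
        # The post-processing classes defined above (ExtractBoxes,
        # PrepareForBigQuery) are not wired in yet. A hedged sketch of the
        # remaining stages: ExtractBoxes currently reads a Triton HTTP-style
        # 'triton_response', so an adapter from the PredictionResult objects
        # produced by RunInference would be needed first, and the BigQuery
        # table below is a placeholder, not from the original:
        #
        # bq_rows = (
        #     inference_results
        #     | "Extract Boxes" >> beam.ParDo(ExtractBoxes())
        #     | "Prepare BQ Rows" >> beam.ParDo(PrepareForBigQuery())
        # )
        #
        # _ = (
        #     bq_rows
        #     | "Keep class_prediction" >> beam.Filter(lambda kv: kv[0] == 'class_prediction')
        #     | "Drop table key" >> beam.Map(lambda kv: kv[1])
        #     | "Write class_prediction" >> WriteToBigQuery(
        #         table="my-project:my_dataset.class_prediction",
        #         write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
        #         create_disposition=beam.io.BigQueryDisposition.CREATE_NEVER,
        #     )
        # )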

On Tue, 16 Sept 2025 at 21:23, XQ Hu <[email protected]> wrote:

> Can you share your commands and outputs?
>
> On Tue, Sep 16, 2025 at 9:02 PM Sai Shashank <[email protected]>
> wrote:
>
>> Okay, I have changed the Docker image to now RUN the python command, but
>> it is still halting without any errors or warnings.
>>
>> On Tue, 16 Sept 2025 at 17:38, XQ Hu via dev <[email protected]> wrote:
>>
>>> The CMD is not necessary as it will be overridden by the ENTRYPOINT just
>>> like your comment.
>>>
>>> If you ssh to your Docker container like `docker run --rm -it
>>> --entrypoint=/bin/bash $CUSTOM_CONTAINER_IMAGE`, can you run python and
>>> some Beam pipelines with a direct runner in the container? This can help
>>> test the environment works fine.
>>>
>>> I have one old Dockerfile that used to work with the old Beam:
>>> https://github.com/google/dataflow-ml-starter/blob/main/tensor_rt.Dockerfile
>>> .
>>>
>>> On Tue, Sep 16, 2025 at 4:56 PM Sai Shashank <[email protected]>
>>> wrote:
>>>
>>>>
>>>>
>>>> ---------- Forwarded message ---------
>>>> From: Sai Shashank <[email protected]>
>>>> Date: Tue, Sep 16, 2025 at 4:27 PM
>>>> Subject: TensorRT inference not starting
>>>> To: <[email protected]>
>>>>
>>>>
>>>> Hey Everyone,
>>>>                          I was trying to use TensorRT within Apache Beam on
>>>> Dataflow, but somehow Dataflow didn't start; it did not even give me worker
>>>> logs. Below is the Dockerfile I use to create a custom image. At first I
>>>> thought it was a version mismatch, but that usually gives me a harness error.
>>>>
>>>> ARG BUILD_IMAGE=nvcr.io/nvidia/tensorrt:25.08-py3
>>>> FROM ${BUILD_IMAGE}
>>>> ENV PATH="/usr/src/tensorrt/bin:${PATH}"
>>>>
>>>> WORKDIR /workspace
>>>>
>>>> RUN apt-get update -y && apt-get install -y python3-venv
>>>> RUN pip install --no-cache-dir apache-beam[gcp]==2.67.0
>>>>
>>>> COPY --from=apache/beam_python3.10_sdk:2.67.0 /opt/apache/beam
>>>> /opt/apache/beam
>>>>
>>>> # Install additional dependencies
>>>> RUN pip install --upgrade pip \
>>>>     && pip install torch \
>>>>     && pip install torchvision \
>>>>     && pip install pillow>=8.0.0 \
>>>>     && pip install transformers>=4.18.0 \
>>>>     && pip install cuda-python \
>>>>     && pip install opencv-python==4.7.0.72 \
>>>>     && pip install PyMuPDF==1.22.5 \
>>>>     && pip install requests==2.31.0
>>>>
>>>> # Set the default command to run the inference script
>>>> # This will be overridden by the Apache Beam boot script
>>>> CMD ["python", "/workspace/inference.py"]
>>>>
>>>> # Use the Apache Beam boot script as the entrypoint
>>>> ENTRYPOINT ["/opt/apache/beam/boot"]
>>>>
>>>>
>>>>
