#!/usr/bin/env python3
"""
Apache Beam pipeline for processing PDFs with a Triton/TensorRT model and
saving the results to BigQuery.

This pipeline combines functionality from test_triton_document.py,
create_bigquery_tables.py, and save_to_bigquery.py into a single workflow.
"""
import os
import sys
import json
import uuid
import argparse
import logging
import tempfile
import datetime
import requests
import numpy as np
import cv2
from PIL import Image
import fitz # PyMuPDF
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional, Iterator
# Apache Beam imports
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
from apache_beam.ml.inference.base import RemoteModelHandler, PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorrt_inference import TensorRTEngineHandlerNumPy
from apache_beam.ml.inference.utils import _convert_to_result
from apache_beam.io.gcp.bigquery import WriteToBigQuery
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.gcsio import GcsIO
# Google Cloud imports
from google.cloud import storage
from google.cloud import bigquery
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# DocLayNet classes
CLASS_ID_TO_NAME = {
0: 'Caption',
1: 'Footnote',
2: 'Formula',
3: 'List-item',
4: 'Page-footer',
5: 'Page-header',
6: 'Picture',
7: 'Section-header',
8: 'Table',
9: 'Text',
10: 'Title'
}
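# Note: these 11 labels follow the DocLayNet dataset taxonomy; the mapping
# assumes the model emits its class channels in this same index order.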
class DownloadPDFFromGCS(beam.DoFn):
"""Download a PDF from Google Cloud Storage."""
def __init__(self, temp_dir=None):
self.temp_dir = temp_dir or tempfile.gettempdir()
def process(self, gcs_uri):
try:
# Parse GCS URI
if not gcs_uri.startswith("gs://"):
raise ValueError(f"Invalid GCS URI: {gcs_uri}")
# Remove gs:// prefix and split into bucket and blob path
path_parts = gcs_uri[5:].split("/", 1)
bucket_name = path_parts[0]
blob_path = path_parts[1]
# Get filename from blob path
filename = os.path.basename(blob_path)
local_path = os.path.join(self.temp_dir, filename)
# Create temp directory if it doesn't exist
os.makedirs(self.temp_dir, exist_ok=True)
try:
# Download using Beam's GcsIO
with FileSystems.open(gcs_uri, 'rb') as gcs_file:
with open(local_path, 'wb') as local_file:
local_file.write(gcs_file.read())
logger.info(f"Downloaded {gcs_uri} to {local_path}")
# Return a dictionary with the local path and original URI
yield {
'local_path': local_path,
'gcs_uri': gcs_uri,
'filename': filename
}
except Exception as e:
logger.error(f"Error reading from GCS: {str(e)}")
# Try alternative download method
logger.info(f"Trying alternative download method for
{gcs_uri}")
# For testing with local files
if os.path.exists(gcs_uri.replace("gs://", "")):
local_path = gcs_uri.replace("gs://", "")
logger.info(f"Using local file: {local_path}")
yield {
'local_path': local_path,
'gcs_uri': gcs_uri,
'filename': os.path.basename(local_path)
}
else:
# Try using gsutil command
import subprocess
try:
subprocess.run(["gsutil", "cp", gcs_uri,
local_path], check=True)
logger.info(f"Downloaded {gcs_uri} to {local_path}
using gsutil")
yield {
'local_path': local_path,
'gcs_uri': gcs_uri,
'filename': filename
}
except Exception as e2:
logger.error(f"Failed to download using gsutil:
{str(e2)}")
except Exception as e:
logger.error(f"Error downloading {gcs_uri}: {str(e)}")
class LoadPDFPages(beam.DoFn):
"""Load PDF pages as images."""
def __init__(self, dpi=200):
self.dpi = dpi
def process(self, element):
doc = None
try:
# Make sure we have all required fields
if not isinstance(element, dict):
logger.error(f"Expected dictionary, got {type(element)}")
return
if 'local_path' not in element:
logger.error("Missing 'local_path' in element")
return
local_path = element['local_path']
gcs_uri = element.get('gcs_uri', '')
# Extract filename from local_path if not provided
filename = element.get('filename', os.path.basename(local_path))
logger.info(f"Loading PDF: {local_path}, filename: {filename}")
# Check if file exists and is accessible
if not os.path.exists(local_path):
logger.error(f"File not found: {local_path}")
return
if not os.access(local_path, os.R_OK):
logger.error(f"File not readable: {local_path}")
return
# Open the PDF
try:
doc = fitz.open(local_path)
if doc.is_closed:
logger.error(f"Failed to open PDF: {local_path}")
return
except Exception as e:
logger.error(f"Error opening PDF {local_path}: {str(e)}")
return
# Process each page
page_count = len(doc)
logger.info(f"Processing {page_count} pages from {local_path}")
for i in range(page_count):
try:
if doc.is_closed:
logger.error(f"Document was closed unexpectedly
while processing page {i}")
break
page = doc[i]
if page is None:
logger.error(f"Failed to get page {i} from
document")
continue
# Use a higher resolution for better quality
scale = self.dpi / 72.0
mat = fitz.Matrix(scale, scale)
try:
pix = page.get_pixmap(matrix=mat, alpha=False)
except Exception as e:
logger.error(f"Error getting pixmap for page {i}:
{str(e)}")
continue
# Check pixmap dimensions
if pix.height <= 0 or pix.width <= 0 or pix.n <= 0:
logger.error(f"Invalid pixmap dimensions:
{pix.width}x{pix.height}x{pix.n}")
continue
# Convert to numpy array
try:
arr = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
except Exception as e:
logger.error(f"Error converting pixmap to numpy
array: {str(e)}")
continue
# Convert BGR to RGB if needed
if pix.n == 3: # RGB
try:
arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
except Exception as e:
logger.error(f"Error converting BGR to RGB:
{str(e)}")
continue
# Store original size for later use
original_size = (arr.shape[0], arr.shape[1])
# Create page info
page_info = {
'page_num': i,
'image': arr,
'original_size': original_size,
'local_path': local_path,
'gcs_uri': gcs_uri,
'filename': filename
}
# Use document ID and page number as key
doc_id = os.path.splitext(filename)[0]
key = f"{doc_id}_{i}"
yield (key, page_info)
except Exception as e:
import traceback
logger.error(f"Error processing page {i}: {str(e)}")
logger.error(traceback.format_exc())
logger.info(f"Loaded {len(doc)} pages from {local_path}")
except Exception as e:
import traceback
logger.error(f"Error loading PDF: {str(e)}")
logger.error(traceback.format_exc())
finally:
# Make sure to close the document only if it was successfully opened
if doc is not None:
try:
if not doc.is_closed:
doc.close()
except Exception as e:
logger.debug(f"Error closing document: {str(e)}")
class PreprocessImage(beam.DoFn):
"""Preprocess image for Triton server."""
def __init__(self, size=1024):
self.size = size
def letterbox(self, img, new_shape=1024, color=(114,114,114)):
"""Resize and pad image to target size."""
h, w = img.shape[:2]
r = min(new_shape / h, new_shape / w)
nh, nw = int(round(h * r)), int(round(w * r))
pad_h, pad_w = new_shape - nh, new_shape - nw
top = pad_h // 2
bottom = pad_h - top
left = pad_w // 2
right = pad_w - left
img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR)
img = cv2.copyMakeBorder(img, top, bottom, left, right,
cv2.BORDER_CONSTANT, value=color)
return img, r, left, top
def process(self, element):
try:
if not isinstance(element, tuple) or len(element) != 2:
logger.error(f"Expected (key, value) tuple, got
{type(element)}")
return
key, page_info = element
if not isinstance(page_info, dict):
logger.error(f"Expected dictionary for page_info, got
{type(page_info)}")
return
if 'image' not in page_info:
logger.error("Missing 'image' in page_info")
return
# Create a new dictionary to avoid modifying the input
new_page_info = dict(page_info)
# Apply letterbox resize
img = new_page_info['image']
lb, r, left, top = self.letterbox(img, new_shape=self.size)
# Convert to float32 and normalize to [0,1]
x = lb.astype(np.float32) / 255.0
# Convert to CHW format
x = np.transpose(x, (2, 0, 1))
# Add batch dimension
batched_img = np.expand_dims(x, axis=0)
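# At this point the tensor is float32 with shape (1, 3, size, size), e.g.
# (1, 3, 1024, 1024); this is assumed to match the engine's input binding.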
# Update page info
new_page_info['preprocessed_image'] = batched_img
new_page_info['letterbox_info'] = (r, left, top)
yield (key, new_page_info)
except Exception as e:
import traceback
logger.error(f"Error preprocessing image: {str(e)}")
logger.error(traceback.format_exc())
class ExtractBoxes(beam.DoFn):
"""Extract bounding boxes from Triton response."""
def __init__(self, conf_th=0.25, iou_th=0.7, model_size=1024):
self.conf_th = conf_th
self.iou_th = iou_th
self.model_size = model_size
def _nms(self, boxes, scores, iou_th=0.7):
"""Non-Maximum Suppression"""
if len(boxes) == 0:
return []
boxes = boxes.astype(np.float32)
x1, y1, x2, y2 = boxes.T
areas = (x2 - x1) * (y2 - y1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1)
h = np.maximum(0.0, yy2 - yy1)
inter = w * h
iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-9)
inds = np.where(iou <= iou_th)[0]
order = order[inds + 1]
return keep
def process(self, page_info):
try:
triton_response = page_info['triton_response']
original_size = page_info['original_size']
r, left, top = page_info['letterbox_info']
if "outputs" not in triton_response or not
triton_response["outputs"]:
logger.error("Invalid response from Triton server")
return []
out_meta = triton_response["outputs"][0]
shape = out_meta["shape"]
data = np.array(out_meta["data"],
dtype=np.float32).reshape(shape)
logger.info(f"Output shape: {shape}")
# For YOLO output [B, C, P] where C is channels (box coords + objectness + classes)
B, C, P = shape
# Assuming 4 box coordinates + class probabilities (no objectness)
has_objectness = False
num_classes = C - 5 if has_objectness else C - 4
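# For the 11 DocLayNet classes with no objectness channel this implies
# C = 4 + 11 = 15; P is the number of candidate predictions (for a
# YOLOv8/v11-style head at 1024x1024 input this is typically
# 128*128 + 64*64 + 32*32 = 21504). Both are assumptions about the exported engine.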
# Extract data
xywh = data[:, 0:4, :]
if has_objectness:
obj = data[:, 4:5, :]
cls = data[:, 5:5 + num_classes, :]
else:
obj = None
cls = data[:, 4:4 + num_classes, :]
# Process batch item (we only have one)
b = 0
h, w = original_size
xywh_b = xywh[b].T # (P,4)
if obj is not None:
obj_b = obj[b].T.squeeze(1) # (P,)
else:
obj_b = np.ones((P,), dtype=np.float32)
cls_b = cls[b].T # (P,nc)
# Get scores and labels
scores_all = (obj_b[:, None] * cls_b) if obj is not None else cls_b
labels = scores_all.argmax(axis=1)
scores = scores_all.max(axis=1)
# Filter by confidence threshold
keep = scores >= self.conf_th
if not np.any(keep):
logger.info(f"No detections above threshold {self.conf_th}")
return []
xywh_k = xywh_b[keep]
scores_k = scores[keep]
labels_k = labels[keep]
# xywh -> xyxy in model space
cx, cy, ww, hh = xywh_k.T
xyxy_model = np.stack([cx - ww / 2, cy - hh / 2, cx + ww / 2,
cy + hh / 2], axis=1)
# Apply NMS per class
final_boxes = []
final_scores = []
final_labels = []
for c in np.unique(labels_k):
idxs = np.where(labels_k == c)[0]
if idxs.size == 0:
continue
keep_idx = self._nms(xyxy_model[idxs], scores_k[idxs],
iou_th=self.iou_th)
final_boxes.append(xyxy_model[idxs][keep_idx])
final_scores.append(scores_k[idxs][keep_idx])
final_labels.append(np.full(len(keep_idx), c, dtype=int))
if not final_boxes:
logger.info("No detections after NMS")
return []
xyxy_model = np.vstack(final_boxes)
scores_k = np.concatenate(final_scores)
labels_k = np.concatenate(final_labels)
# Map boxes from model space to original image space
xyxy_orig = xyxy_model.copy()
# Remove padding
xyxy_orig[:, [0, 2]] -= left
xyxy_orig[:, [1, 3]] -= top
# Scale back to original size
xyxy_orig /= r
# Clip to image boundaries
xyxy_orig[:, 0::2] = np.clip(xyxy_orig[:, 0::2], 0, w - 1)
xyxy_orig[:, 1::2] = np.clip(xyxy_orig[:, 1::2], 0, h - 1)
# Format as requested: x_min, y_min, x_max, y_max, class, probability
boxes = []
for (x1, y1, x2, y2), label, score in zip(xyxy_orig, labels_k,
scores_k):
class_name = CLASS_ID_TO_NAME.get(int(label))
box_info = {
"page": page_info['page_num'],
"x_min": float(x1),
"y_min": float(y1),
"x_max": float(x2),
"y_max": float(y2),
"class": int(label),
"class_name": class_name,
"probability": float(score),
"filename": page_info['filename'],
"local_path": page_info['local_path'],
"gcs_uri": page_info['gcs_uri']
}
boxes.append(box_info)
logger.info(f"Extracted {len(boxes)} boxes from page
{page_info['page_num']}")
return boxes
except Exception as e:
logger.error(f"Error extracting boxes: {str(e)}")
return []
class PrepareForBigQuery(beam.DoFn):
"""Prepare data for BigQuery insertion."""
def process(self, box_info):
try:
# Generate UUIDs for primary keys
v_note_id = str(uuid.uuid4())
page_ocr_id = str(uuid.uuid4())
class_prediction_id = str(uuid.uuid4())
# Create timestamp
processing_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Create ocr_results row
ocr_results_row = {
"v_note_id": v_note_id,
"filename": box_info['filename'],
"file_path": box_info['gcs_uri'],
"processing_time": processing_time,
"file_type": "pdf"
}
# Create page_ocr row
page_ocr_row = {
"page_ocr_id": page_ocr_id,
"v_note_id": v_note_id,
"page_number": box_info['page']
}
# Create class_prediction row
class_prediction_row = {
"class_prediction_id": class_prediction_id,
"page_ocr_id": page_ocr_id,
"xmin": box_info['x_min'],
"ymin": box_info['y_min'],
"xmax": box_info['x_max'],
"ymax": box_info['y_max'],
"class": box_info['class_name'] if box_info['class_name']
else str(box_info['class']),
"confidence": box_info['probability']
}
# Return all three rows with table names
return [
('ocr_results', ocr_results_row),
('page_ocr', page_ocr_row),
('class_prediction', class_prediction_row)
]
except Exception as e:
logger.error(f"Error preparing for BigQuery: {str(e)}")
return []
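# TensorRTEngineHandlerNumPy loads a prebuilt engine; the gs:// path below is
# assumed to be readable by the workers, and the engine must have been built
# with a TensorRT version matching the worker container.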
model_handler = TensorRTEngineHandlerNumPy(
min_batch_size=1,
max_batch_size=1,
engine_path="gs://temp/yolov11l-doclaynet.engine",
)
# Pipeline options; configure the runner, project, region, etc. as needed.
options = PipelineOptions()
with beam.Pipeline(options=options) as pipeline:
# Create PCollection from input URIs
pdf_uris = (
pipeline
| "Create URIs" >> beam.Create(["tmp.pdf"])
)
# Download PDFs
local_pdfs = (
pdf_uris
| "Download PDFs" >> beam.ParDo(DownloadPDFFromGCS())
)
# Load PDF pages
pdf_pages = (
local_pdfs
| "Load PDF Pages" >> beam.ParDo(LoadPDFPages())
#| "Flatten Pages" >> beam.FlatMap(lambda x: x)
)
# Preprocess images
preprocessed_pages = (
pdf_pages
| "Preprocess Images" >> beam.ParDo(PreprocessImage())
)
inference_results = (
preprocessed_pages
| "Run Inference" >> RunInference(model_handler=model_handler)
)
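# The remaining stages (ExtractBoxes / PrepareForBigQuery / WriteToBigQuery)
# are defined above but not wired in; a rough sketch of how they could be
# connected follows. Note the assumptions: ExtractBoxes was written against
# Triton HTTP responses, so the PredictionResult objects returned by
# RunInference would need an adapter first, and the table spec is made up.
#
# boxes = (
#     inference_results
#     | "Adapt Predictions" >> beam.Map(prediction_to_page_info)  # hypothetical adapter
#     | "Extract Boxes" >> beam.ParDo(ExtractBoxes())
# )
# bq_rows = boxes | "Prepare Rows" >> beam.ParDo(PrepareForBigQuery())
# _ = (
#     bq_rows
#     | "Keep class_prediction" >> beam.Filter(lambda kv: kv[0] == 'class_prediction')
#     | "Drop table key" >> beam.Map(lambda kv: kv[1])
#     | "Write to BigQuery" >> WriteToBigQuery(
#         table='my-project:doc_layout.class_prediction',  # assumed table spec
#         write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND)
# )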
On Tue, 16 Sept 2025 at 21:23, XQ Hu <[email protected]> wrote:
> Can you share your commands and outputs?
>
> On Tue, Sep 16, 2025 at 9:02 PM Sai Shashank <[email protected]>
> wrote:
>
>> Okay, I have changed the Docker image so that it now RUNs the python
>> command, but it is still halting without any errors or warnings.
>>
>> On Tue, 16 Sept 2025 at 17:38, XQ Hu via dev <[email protected]> wrote:
>>
>>> The CMD is not necessary, as it will be overridden by the ENTRYPOINT,
>>> just as your comment says.
>>>
>>> If you ssh to your Docker container like `docker run --rm -it
>>> --entrypoint=/bin/bash $CUSTOM_CONTAINER_IMAGE`, can you run python and
>>> some Beam pipelines with a direct runner in the container? This can help
>>> test that the environment works fine.
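>>>
>>> For example, a minimal sanity check along these lines (an illustrative
>>> sketch only, not the TensorRT path itself) should complete with the
>>> default direct runner inside the container:
>>>
>>> python -c "
>>> import apache_beam as beam
>>> with beam.Pipeline() as p:
>>>     _ = (p | beam.Create([1, 2, 3])
>>>            | beam.Map(lambda x: x * 2)
>>>            | beam.Map(print))
>>> "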
>>>
>>> I have one old Dockerfile that used to work with the old Beam:
>>> https://github.com/google/dataflow-ml-starter/blob/main/tensor_rt.Dockerfile
>>> .
>>>
>>> On Tue, Sep 16, 2025 at 4:56 PM Sai Shashank <[email protected]>
>>> wrote:
>>>
>>>>
>>>>
>>>> ---------- Forwarded message ---------
>>>> From: Sai Shashank <[email protected]>
>>>> Date: Tue, Sep 16, 2025 at 4:27 PM
>>>> Subject: TensorRT inference not starting
>>>> To: <[email protected]>
>>>>
>>>>
>>>> Hey Everyone,
>>>> I was trying to use TensorRT within Apache
>>>> Beam on Dataflow, but somehow Dataflow didn't start; it did not even
>>>> give me worker logs. Below is the Dockerfile that I use to create a
>>>> custom image. At first I thought it was a version mismatch, but that
>>>> usually gives me a harness error.
>>>>
>>>> ARG BUILD_IMAGE=nvcr.io/nvidia/tensorrt:25.08-py3
>>>> FROM ${BUILD_IMAGE}
>>>> ENV PATH="/usr/src/tensorrt/bin:${PATH}"
>>>>
>>>> WORKDIR /workspace
>>>>
>>>> RUN apt-get update -y && apt-get install -y python3-venv
>>>> RUN pip install --no-cache-dir apache-beam[gcp]==2.67.0
>>>>
>>>> COPY --from=apache/beam_python3.10_sdk:2.67.0 /opt/apache/beam
>>>> /opt/apache/beam
>>>>
>>>> # Install additional dependencies
>>>> RUN pip install --upgrade pip \
>>>> && pip install torch \
>>>> && pip install torchvision \
>>>> && pip install pillow>=8.0.0 \
>>>> && pip install transformers>=4.18.0 \
>>>> && pip install cuda-python \
>>>> && pip install opencv-python==4.7.0.72 \
>>>> && pip install PyMuPDF==1.22.5 \
>>>> && pip install requests==2.31.0
>>>>
>>>> # Set the default command to run the inference script
>>>> # This will be overridden by the Apache Beam boot script
>>>> CMD ["python", "/workspace/inference.py"]
>>>>
>>>> # Use the Apache Beam boot script as the entrypoint
>>>> ENTRYPOINT ["/opt/apache/beam/boot"]
>>>>
>>>>
>>>>