Skip to content

Commit

Permalink
Adding some rotation retry contrl
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Oct 28, 2024
1 parent 7678f31 commit 08d51b7
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 3 deletions.
22 changes: 20 additions & 2 deletions pdelfin/birrpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sqlite3
import orjson
import argparse
import uuid
import base64
import tempfile
import datetime
import posixpath
Expand All @@ -15,6 +15,8 @@

from dataclasses import dataclass
from pypdf import PdfReader
from io import BytesIO
from PIL import Image
from tqdm import tqdm
from functools import partial
from typing import Optional, List, Tuple, Dict, Callable, Any
Expand Down Expand Up @@ -383,8 +385,23 @@ def close(self):
thread.join()


def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int) -> dict:
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)

if image_rotation != 0:
image_bytes = base64.b64decode(image_base64)
with Image.open(BytesIO(image_bytes)) as img:
rotated_img = img.rotate(-image_rotation, expand=True)

# Save the rotated image to a bytes buffer
buffered = BytesIO()
rotated_img.save(buffered, format="PNG")

# Encode the rotated image back to base64
image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')


anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)

return {
Expand Down Expand Up @@ -511,6 +528,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})

# TODO: If the rotation was previously invalid, then apply a rotation


# TODO: Try to provide a smaller prompt hint
else:
Expand Down
86 changes: 85 additions & 1 deletion tests/test_birrpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from unittest.mock import MagicMock, patch
import hashlib
import json
import os
import base64
from PIL import Image

# Adjust the import path to match where your code resides
from pdelfin.birrpipeline import build_dolma_doc, DatabaseManager
from pdelfin.birrpipeline import build_dolma_doc, DatabaseManager, build_finetuning_prompt, build_page_query

class TestBuildDolmaDoc(unittest.TestCase):
@patch('pdelfin.birrpipeline.DatabaseManager')
Expand Down Expand Up @@ -121,6 +124,87 @@ def get_s3_bytes_side_effect(s3_client, s3_path, start_index=None, end_index=Non
expected_id = hashlib.sha1(expected_text.encode()).hexdigest()
self.assertEqual(dolma_doc['id'], expected_id)


class TestBuildPageQuery(unittest.TestCase):
def testRotation(self):
# First, generate and save the non-rotated image
query = build_page_query(os.path.join(
os.path.dirname(__file__),
"gnarly_pdfs",
"edgar.pdf"
), "edgar.pdf", 1, 1024, 6000, 0)

# Extract the base64 image from the query
image_content = query["chat_messages"][0]["content"][1]
self.assertEqual(image_content["type"], "image_url")
image_url = image_content["image_url"]["url"]

# Extract base64 string from the data URL
prefix = "data:image/png;base64,"
self.assertTrue(image_url.startswith(prefix))
image_base64 = image_url[len(prefix):]

# Decode the base64 string
image_data = base64.b64decode(image_base64)

# Define the output file path for the non-rotated image
output_image_path = os.path.join(os.path.dirname(__file__), "test_renders", "output_image.png")

# Save the non-rotated image to a file
with open(output_image_path, "wb") as image_file:
image_file.write(image_data)

# Now, generate and save the rotated image (90 degrees clockwise)
query_rotated = build_page_query(os.path.join(
os.path.dirname(__file__),
"gnarly_pdfs",
"edgar.pdf"
), "edgar.pdf", 1, 1024, 6000, 90)

# Extract the base64 image from the rotated query
image_content_rotated = query_rotated["chat_messages"][0]["content"][1]
self.assertEqual(image_content_rotated["type"], "image_url")
image_url_rotated = image_content_rotated["image_url"]["url"]

# Extract base64 string from the data URL for the rotated image
self.assertTrue(image_url_rotated.startswith(prefix))
image_base64_rotated = image_url_rotated[len(prefix):]

# Decode the base64 string for the rotated image
image_data_rotated = base64.b64decode(image_base64_rotated)

# Define the output file path for the rotated image
output_image_rotated_path = os.path.join(os.path.dirname(__file__), "test_renders", "output_image_rotated90.png")

# Save the rotated image to a file
with open(output_image_rotated_path, "wb") as image_file_rotated:
image_file_rotated.write(image_data_rotated)

# Verification Step: Ensure the rotated image is 90 degrees clockwise rotated

# Open both images using PIL
with Image.open(output_image_path) as original_image:
with Image.open(output_image_rotated_path) as rotated_image:

# Compare pixel by pixel
original_pixels = original_image.load()
rotated_pixels = rotated_image.load()
width, height = original_image.size

self.assertEqual(width, rotated_image.size[1])
self.assertEqual(height, rotated_image.size[0])

for x in range(width):
for y in range(height):

self.assertEqual(
original_pixels[x, y], rotated_pixels[height - 1 - y, x],
f"Pixel mismatch at ({x}, {y})"
)

print("Rotation verification passed: The rotated image is correctly rotated 90 degrees clockwise.")


# Run the test
if __name__ == '__main__':
unittest.main()
Binary file added tests/test_renders/output_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/test_renders/output_image_rotated90.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 08d51b7

Please sign in to comment.