Working on script
jakep-allenai committed Nov 8, 2024
1 parent e5fb7c0 commit 37dc412
Showing 1 changed file with 98 additions and 19 deletions.
117 changes: 98 additions & 19 deletions pdelfin/beakerpipeline.py
@@ -10,6 +10,7 @@
import hashlib
import base64
import asyncio
import aiohttp

from tqdm import tqdm
from io import BytesIO
@@ -179,31 +180,89 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:

    return queue

async def process_page(session, pdf_path, page_num, args):
    # Build the per-page query and POST it to the local inference server's
    # chat completions endpoint, returning (page_num, result) on success.
    query = await build_page_query(
        pdf_path,
        page_num,
        args.target_longest_image_dim,
        args.target_anchor_text_len
    )
    URL = "http://localhost:30000/v1/chat/completions"

    try:
        async with session.post(URL, json=query) as response:
            if response.status == 200:
                result = await response.json()
                return (page_num, result)
            else:
                logger.warning(f"Request failed with status {response.status} for page {page_num}")
                return None
    except Exception as e:
        logger.error(f"Exception while processing page {page_num}: {e}")
        return None
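
For reference, process_page can be driven on its own once the server is answering. A minimal sketch, in which the PDF path, page number, image dimension, and anchor-text length are placeholder values rather than anything taken from this repository:

import argparse
import asyncio
import aiohttp

async def try_one_page():
    # Placeholder arguments mirroring only the fields that process_page reads.
    args = argparse.Namespace(target_longest_image_dim=1024, target_anchor_text_len=6000)
    async with aiohttp.ClientSession() as session:
        out = await process_page(session, "example.pdf", 1, args)
        if out is not None:
            page_num, result = out
            print(page_num, result["choices"][0]["message"]["content"][:200])

# asyncio.run(try_one_page())
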
async def process_pdf(args, pdf_s3_path):
    with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
        # TODO Switch to aioboto3 or something
        # Fetch the PDF from S3 in a worker thread so the event loop is not blocked
        data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
        tf.write(data)
        tf.flush()


        reader = PdfReader(tf.name)
        num_pages = reader.get_num_pages()

        # List to hold the tasks for processing each page
        page_tasks = []

        async with aiohttp.ClientSession() as session:
            for page_num in range(1, num_pages + 1):
                # Create a task for each page
                task = asyncio.create_task(process_page(session, tf.name, page_num, args))
                page_tasks.append(task)

            # Gather results from all page processing tasks
            page_results = await asyncio.gather(*page_tasks)

    # If we failed to build a page, then this document is toast
    # TODO Abort earlier, if a page returns a None, then we can stop processing the whole pdf
    if any(page is None for page in page_results):
        return None

    # Build the document text and page spans
    document_text = ''
    pdf_page_spans = []
    current_char_pos = 0

    for page_num, result in page_results:
        try:
            content = result['choices'][0]['message']['content']
        except (KeyError, IndexError) as e:
            logger.error(f"Failed to extract content for page {page_num}: {e}")
            continue

        start_pos = current_char_pos
        document_text += content
        current_char_pos = len(document_text)
        pdf_page_spans.append({
            'pdf_page_number': page_num,
            'start_char': start_pos,
            'end_char': current_char_pos
        })

    if not document_text:
        return None  # Return None if the document text is empty

    # Build the Dolma document
    metadata = {
        "Source-File": pdf_s3_path,
        "pdf-total-pages": num_pages,
    }

    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
@@ -221,8 +280,6 @@ async def process_pdf(args, pdf_s3_path):
    return dolma_doc
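
The span bookkeeping above records, for every page, a half-open character range into the concatenated document_text. A tiny worked example with made-up page contents shows how each span slices its page back out:

document_text = "Hello " + "world"   # page 1 then page 2
pdf_page_spans = [
    {"pdf_page_number": 1, "start_char": 0, "end_char": 6},
    {"pdf_page_number": 2, "start_char": 6, "end_char": 11},
]
assert document_text[0:6] == "Hello "
assert document_text[6:11] == "world"
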




async def worker(args, queue):
    while True:
        [work_hash, pdfs] = await queue.get()
@@ -251,6 +308,27 @@ async def sglang_server_task(args):

    await proc.wait()
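
The body of sglang_server_task is folded in this view. For orientation, what it manages is an sglang server exposing the OpenAI-compatible API on port 30000, the same endpoint process_page posts to. A launch along the following lines is one plausible shape; the model path is a placeholder and the exact flags used by the folded code may differ:

import asyncio

async def launch_sglang_sketch(model_path: str):
    # Sketch only: start the sglang OpenAI-compatible server as a subprocess.
    proc = await asyncio.create_subprocess_exec(
        "python", "-m", "sglang.launch_server",
        "--model-path", model_path,  # placeholder model location
        "--port", "30000",           # matches the URL used by process_page
    )
    return proc
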

async def sglang_server_ready():
    max_attempts = 60
    delay_sec = 1
    url = 'http://localhost:30000/v1/models'

    for attempt in range(1, max_attempts + 1):
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    if response.status == 200:
                        logger.info("sglang server is ready.")
                        return
                    else:
                        logger.info(f"Attempt {attempt}: Unexpected status code {response.status}")
        except Exception as e:
            logger.warning(f"Attempt {attempt}: Exception occurred: {e}")

        await asyncio.sleep(delay_sec)

    raise Exception("sglang server did not become ready after waiting.")
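
In main, this readiness probe is awaited before any worker task is created. A natural way to pair it with sglang_server_task, whether or not main does exactly this beyond what the diff shows, is a small sketch like:

async def start_inference_server(args):
    # Launch the sglang server task in the background...
    server_task = asyncio.create_task(sglang_server_task(args))
    # ...and only return once /v1/models answers with HTTP 200.
    await sglang_server_ready()
    return server_task
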


async def main():
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
@@ -289,6 +367,7 @@ async def main():
    work_queue = await load_pdf_work_queue(args)
    logger.info(f"Work queue prepared with {work_queue.qsize()} items")

    await sglang_server_ready()

    # Create worker tasks to process the queue concurrently.
    worker_tasks = []
