Commit bd08fdb

fixes missing OSS code for Issue #36
jakep-allenai committed Feb 26, 2025 (1 parent: d4b902c)
Showing 6 changed files with 466 additions and 8 deletions.
69 changes: 69 additions & 0 deletions olmocr/eval/dolma_refine/aligners.py
@@ -0,0 +1,69 @@
from typing import Type

from sequence_align.pairwise import hirschberg, needleman_wunsch

from .registry import BaseRegistry


class AlignerRegistry(BaseRegistry[Type["BaseAligner"]]):
"""A registry for aligners."""


class BaseAligner:
def __init__(self, *args, **kwargs):
super().__init__()

def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
raise NotImplementedError()


@AlignerRegistry.add("hirschberg")
class HirschbergAligner(BaseAligner):
def __init__(
self,
match_score: float = 1.0,
mismatch_score: float = -1.0,
indel_score: float = -1.0,
gap_token: str = "▓",
):
self.match_score = match_score
self.mismatch_score = mismatch_score
self.indel_score = indel_score
self.gap_token = gap_token
super().__init__()

def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
return hirschberg(
gold,
pred,
match_score=self.match_score,
mismatch_score=self.mismatch_score,
indel_score=self.indel_score,
gap=self.gap_token,
)


@AlignerRegistry.add("needleman-wunsch")
class NeedlemanWunschAligner(BaseAligner):
def __init__(
self,
match_score: float = 1.0,
mismatch_score: float = -1.0,
indel_score: float = -1.0,
gap_token: str = "▓",
):
self.match_score = match_score
self.mismatch_score = mismatch_score
self.indel_score = indel_score
self.gap_token = gap_token
super().__init__()

def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
return needleman_wunsch(
gold,
pred,
match_score=self.match_score,
mismatch_score=self.mismatch_score,
indel_score=self.indel_score,
gap=self.gap_token,
)
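
For context, a minimal usage sketch of the aligners above, assuming the sequence-align package is installed; the token lists are illustrative:

from olmocr.eval.dolma_refine.aligners import AlignerRegistry

# Look the class up via the registry; HirschbergAligner() also works directly,
# since every constructor argument has a default.
aligner = AlignerRegistry.get("hirschberg")()
gold_aligned, pred_aligned = aligner.align(
    ["the", " ", "quick", " ", "fox"],
    ["the", " ", "fox"],
)
# Both outputs have equal length; positions missing from one side are padded
# with the "▓" gap token:
#   gold_aligned == ["the", " ", "quick", " ", "fox"]
#   pred_aligned == ["the", " ", "▓", "▓", "fox"]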
237 changes: 237 additions & 0 deletions olmocr/eval/dolma_refine/metrics.py
@@ -0,0 +1,237 @@
import bisect
from typing import Type

import regex as re
from tqdm import tqdm

from .aligners import AlignerRegistry, BaseAligner
from .registry import BaseRegistry
from .segmenters import BaseSegmenter, SegmenterRegistry


class TextMetricRegistry(BaseRegistry[Type["BaseTextMetric"]]):
"""A registry for text metrics."""


class BaseTextMetric:
def __init__(self, *args, **kwargs):
super().__init__()

def compute(self, gold: str, pred: str) -> float:
raise NotImplementedError()

def batch_compute(self, golds: list[str], preds: list[str]) -> list[float]:
it = tqdm(
zip(golds, preds),
total=min(len(golds), len(preds)),
desc=type(self).__name__,
unit="samples",
unit_scale=True,
)
return [self.compute(gold, pred) for gold, pred in it]


class BaseTextAlignMetric(BaseTextMetric):
def __init__(
self,
segmenter: str | BaseSegmenter,
aligner: str | BaseAligner = "hirschberg",
aligner_kwargs: dict = {},
segmenter_kwargs: dict = {},
gap_token: str = "▓",
*args,
**kwargs,
):
if isinstance(segmenter, str):
self.segmenter = SegmenterRegistry.get(segmenter)(segmenter, **segmenter_kwargs)
else:
self.segmenter = segmenter

if isinstance(aligner, str):
            # Forward only the kwargs: passing the registry name positionally
            # would bind it to the aligner's match_score parameter.
            self.aligner = AlignerRegistry.get(aligner)(**aligner_kwargs)
else:
self.aligner = aligner

self.gap_token = gap_token

def segment(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> list[tuple[list[str], list[str]]]:
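        # Base behavior: treat the whole documents as a single segment;
        # subclasses override this to score smaller aligned windows.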
return [(seq_a_tokens, seq_b_tokens)]

def align(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> tuple[list[str], list[str]]:
return self.aligner.align(seq_a_tokens, seq_b_tokens)

def tokenize(self, text: str) -> list[str]:
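        # Split on runs of punctuation or whitespace, keeping the separators as
        # tokens (the `regex` module is required for the \p{P} property class).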
return [w for w in re.split(r"(\p{P}+|\s+)", text) if w]

def compute(self, gold: str, pred: str) -> float:
raise NotImplementedError()


@TextMetricRegistry.add("document_edit_similarity")
class DocumentEditSimilarity(BaseTextAlignMetric):
def _score_aligned(self, aligned_gold_tokens: list[str], aligned_pred_tokens: list[str]) -> float:
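        # A gap on the gold side counts as an insertion in pred, a gap on the
        # pred side as a deletion; similarity is the fraction of exact matches.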
insertions = deletions = matches = substitutions = 0.0
for gold_symbol, pred_symbol in zip(aligned_gold_tokens, aligned_pred_tokens):
if gold_symbol == self.gap_token:
insertions += 1
elif pred_symbol == self.gap_token:
deletions += 1
elif gold_symbol == pred_symbol:
matches += 1
else:
substitutions += 1

if total := insertions + deletions + matches + substitutions:
return matches / total
return 0.0

def compute(self, gold: str, pred: str) -> float:
gold_tokens = self.tokenize(gold)
pred_tokens = self.tokenize(pred)
aligned_gold_tokens, aligned_pred_tokens = self.align(gold_tokens, pred_tokens)
return self._score_aligned(aligned_gold_tokens, aligned_pred_tokens)


def find_align_gaps(aligned_text: list[str], gap_token: str = "▓", gap_threshold: int = 3) -> list[int]:
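    # Record each position where `gap_threshold` consecutive gap tokens have
    # accumulated (the counter resets after each hit, so very long runs are
    # reported multiple times); these mark likely insertion/deletion regions.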
consecutive_gaps_counter = 0
above_threshold_locs: list[int] = []

for aligned_pos, symbol in enumerate(aligned_text):
if symbol == gap_token:
consecutive_gaps_counter += 1
else:
consecutive_gaps_counter = 0

if consecutive_gaps_counter >= gap_threshold:
above_threshold_locs.append(aligned_pos)
consecutive_gaps_counter = 0

return above_threshold_locs


def make_unaligned_text(tokens: list[str], gap_token: str = "▓") -> str:
return "".join(symbol for symbol in tokens if symbol != gap_token)


def find_sentences(
tokens: list[str],
sentences: list[str],
gap_token: str = "▓",
):
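    # Reconstruct the un-gapped text and map each segmenter sentence back to a
    # (start_token, end_token) span over the aligned token sequence.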
matches: list[tuple[int, int]] = []

original_text = ""
original: list[int] = []
original_to_aligned: list[int] = []

for i, token in enumerate(tokens):
if token != gap_token:
original_text += token
original.append(len(original_text))
original_to_aligned.append(i)

for sentence in sentences:
start_pos = original_text.find(sentence)
if start_pos < 0:
continue

end_pos = start_pos + len(sentence)
start_token = original_to_aligned[bisect.bisect_left(original, start_pos)]
end_token = original_to_aligned[min(bisect.bisect_right(original, end_pos), len(original) - 1)]
matches.append((start_token, end_token))

return matches


def merge_spans(spans: list[tuple[int, int]]) -> list[tuple[int, int]]:
if not spans:
return []

# Sort spans based on start position
sorted_spans = sorted(spans, key=lambda x: x[0])

merged = [sorted_spans[0]]

for current in sorted_spans[1:]:
last = merged[-1]

# If current span overlaps with last merged span, update the end of last span
if current[0] <= last[1]:
merged[-1] = (last[0], max(last[1], current[1]))
else:
merged.append(current)

return merged


def make_sentences_around_gaps(sent_locs: list[tuple[int, int]], gaps_locs: list[int], window: int):
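    # Expand each gap location to a window of `window` sentences on either side,
    # then merge overlapping windows into disjoint (start, end) token spans.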
sent_start_only = [start for start, _ in sent_locs]

sentences_with_gaps = []

# collect all sentences that are around the gaps
for gap in gaps_locs:
        start_idx = bisect.bisect_left(sent_start_only, gap)
        first_sent = max(0, start_idx - window)
        last_sent = min(len(sent_locs) - 1, start_idx + window)
        sentences_with_gaps.append((sent_locs[first_sent][0], sent_locs[last_sent][1]))

# merge overlapping sentences
sentences_with_gaps = merge_spans(sentences_with_gaps)

return sentences_with_gaps


@TextMetricRegistry.add("paragraph_edit_similarity")
class ParagraphEditSimilarity(DocumentEditSimilarity):
def __init__(
self,
segmenter: str | BaseSegmenter,
aligner: str | BaseAligner = "hirschberg",
aligner_kwargs: dict = {},
segmenter_kwargs: dict = {},
gap_token: str = "▓",
gap_threshold: int = 3,
sent_window: int = 1,
*args,
**kwargs,
):
super().__init__(
segmenter=segmenter,
aligner=aligner,
aligner_kwargs=aligner_kwargs,
segmenter_kwargs=segmenter_kwargs,
gap_token=gap_token,
)
self.gap_threshold = gap_threshold
self.sent_window = sent_window

def segment(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> list[tuple[list[str], list[str]]]:
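        # Score only sentence windows around long alignment gaps, computed
        # independently for the gold and pred sides and then merged.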

all_spans = []

for seq_tokens in (seq_a_tokens, seq_b_tokens):
text = make_unaligned_text(tokens=seq_tokens, gap_token=self.gap_token)
sentences = self.segmenter.segment(text)

sent_locs = find_sentences(tokens=seq_tokens, sentences=sentences, gap_token=self.gap_token)
            gaps_locs = find_align_gaps(aligned_text=seq_tokens, gap_token=self.gap_token, gap_threshold=self.gap_threshold)

sentences_with_gaps = make_sentences_around_gaps(
sent_locs=sent_locs, gaps_locs=gaps_locs, window=self.sent_window
)
all_spans.extend(sentences_with_gaps)

return [(seq_a_tokens[start:end], seq_b_tokens[start:end]) for start, end in merge_spans(all_spans)]

def compute(self, gold: str, pred: str) -> float:
gold_tokens = self.tokenize(gold)
pred_tokens = self.tokenize(pred)
aligned_gold_tokens, aligned_pred_tokens = self.align(gold_tokens, pred_tokens)

scores = []
for gold_segment, pred_segment in self.segment(aligned_gold_tokens, aligned_pred_tokens):
score = self._score_aligned(gold_segment, pred_segment)
scores.append(score)

return sum(scores) / len(scores) if scores else 1.0
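
An end-to-end sketch of the metric, hedged: "sentence" is a hypothetical segmenter name here, since the registered segmenters live in olmocr/eval/dolma_refine/segmenters.py, outside this diff:

from olmocr.eval.dolma_refine.metrics import TextMetricRegistry

# Document-level score: align the full token sequences, then return the
# fraction of aligned positions that match exactly (1.0 = identical).
metric = TextMetricRegistry.get("document_edit_similarity")(segmenter="sentence")

gold = "The quick brown fox jumps over the lazy dog."
pred = "The quick fox jumps over the dog."
print(metric.compute(gold, pred))

# paragraph_edit_similarity takes the same arguments plus gap_threshold and
# sent_window, and averages the same score over sentence windows around gaps.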
[4 additional changed files not shown]