anti-prestige-tool/reforger_queue/ocr.py

294 lines
10 KiB
Python

import json
import os
import sys
from .config import AUTO_FONT_CANDIDATES
from .magick import image_size, magick_fonts, ppm_from_magick, try_ppm_from_magick
def orange_digit_mask(width, height, raw):
    """Build a boolean mask of the pixels that fall inside the orange colour band.

    Args:
        width: Number of pixels per row.
        height: Number of rows.
        raw: Indexable pixel buffer read as three consecutive values
            (red, green, blue) per pixel, rows stored one after another.

    Returns:
        ``height`` lists of ``width`` booleans, indexed ``mask[y][x]``.
    """

    def is_orange(base):
        # All three channels are read before testing, so a short buffer
        # always fails with IndexError regardless of the pixel's colour.
        red, green, blue = raw[base], raw[base + 1], raw[base + 2]
        return red >= 115 and 55 <= green <= 205 and blue <= 100 and red - green >= 20

    stride = width * 3
    return [
        [is_orange(row * stride + column * 3) for column in range(width)]
        for row in range(height)
    ]
def white_mask(width, height, raw):
    """Build a boolean mask of the pixels whose three channels are all >= 120.

    Args:
        width: Number of pixels per row.
        height: Number of rows.
        raw: Indexable pixel buffer read as three consecutive values per
            pixel, rows stored one after another.

    Returns:
        ``height`` lists of ``width`` booleans, indexed ``mask[y][x]``.
    """

    def is_bright(base):
        # Read every channel up front so truncated buffers raise IndexError.
        red, green, blue = raw[base], raw[base + 1], raw[base + 2]
        return min(red, green, blue) >= 120

    row_span = width * 3
    rows = []
    for row_index in range(height):
        first = row_index * row_span
        rows.append([is_bright(base) for base in range(first, first + row_span, 3)])
    return rows
def connected_components(mask):
    """Find 8-connected regions of truthy cells that are large enough to keep.

    Args:
        mask: List of rows of truthy/falsy cells, indexed ``mask[y][x]``.
            Every row is assumed to be as long as the first one.

    Returns:
        A list of dicts, one per retained region, ordered by the left edge of
        the bounding box.  Each dict holds ``"points"`` (a set of ``(x, y)``
        tuples), ``"bbox"`` (inclusive ``(min_x, min_y, max_x, max_y)``) and
        ``"area"`` (the number of points).  Regions with fewer than 40 points,
        or whose bounding box is narrower than 5 or shorter than 15 cells,
        are dropped.
    """
    row_count = len(mask)
    column_count = len(mask[0]) if row_count else 0
    visited = [[False] * column_count for _ in range(row_count)]
    # The centre offset is harmless: that cell is always already visited.
    steps = [(dx, dy) for dy in (-1, 0, 1) for dx in (-1, 0, 1)]
    regions = []
    for start_y in range(row_count):
        for start_x in range(column_count):
            if visited[start_y][start_x] or not mask[start_y][start_x]:
                continue
            # Flood-fill outwards from this seed using an explicit stack.
            visited[start_y][start_x] = True
            pending = [(start_x, start_y)]
            pixels = set()
            while pending:
                current = pending.pop()
                pixels.add(current)
                for dx, dy in steps:
                    near_x, near_y = current[0] + dx, current[1] + dy
                    if not (0 <= near_x < column_count and 0 <= near_y < row_count):
                        continue
                    if visited[near_y][near_x] or not mask[near_y][near_x]:
                        continue
                    visited[near_y][near_x] = True
                    pending.append((near_x, near_y))
            columns = {px for px, _ in pixels}
            rows = {py for _, py in pixels}
            left, top, right, bottom = min(columns), min(rows), max(columns), max(rows)
            too_small = len(pixels) < 40 or right - left + 1 < 5 or bottom - top + 1 < 15
            if not too_small:
                regions.append({"points": pixels, "bbox": (left, top, right, bottom), "area": len(pixels)})
    # list.sort is stable, so regions sharing a left edge keep discovery order.
    regions.sort(key=lambda region: region["bbox"][0])
    return regions
def normalize(component, width=24, height=36):
    """Resample a component's bounding box onto a fixed-size boolean grid.

    Each output cell samples the source pixel under the cell's centre
    (nearest-neighbour) and is True when that pixel is in the component.

    Args:
        component: Dict with ``"bbox"`` as inclusive ``(x1, y1, x2, y2)`` and
            ``"points"`` as a container of ``(x, y)`` tuples.
        width: Number of columns in the output grid.
        height: Number of rows in the output grid.

    Returns:
        ``height`` lists of ``width`` booleans.
    """
    left, top, right, bottom = component["bbox"]
    span_x = right - left + 1
    span_y = bottom - top + 1
    members = component["points"]
    sample_columns = [left + int((column + 0.5) * span_x / width) for column in range(width)]
    sample_rows = [top + int((row + 0.5) * span_y / height) for row in range(height)]
    return [
        [(sample_x, sample_y) in members for sample_x in sample_columns]
        for sample_y in sample_rows
    ]
def template_distance(left, right):
    """Compare two boolean grids cell by cell.

    Args:
        left: Non-empty list of rows; its dimensions define the cell count.
        right: Grid compared against ``left`` (rows and cells are paired
            with ``zip``, so any extra rows or cells are ignored).

    Returns:
        ``(hamming, iou)``: the fraction of ``left``'s cells that differ, and
        the intersection-over-union of the True cells (0.0 when neither grid
        has a True cell).
    """
    cell_count = len(left) * len(left[0])
    pairs = [
        (a, b)
        for row_a, row_b in zip(left, right)
        for a, b in zip(row_a, row_b)
    ]
    mismatched = sum(a != b for a, b in pairs)
    overlap = sum(a and b for a, b in pairs)
    covered = sum(a or b for a, b in pairs)
    hamming = mismatched / cell_count
    if not covered:
        return hamming, 0.0
    return hamming, overlap / covered
def template_font_candidates(font):
    """List the fonts to try, in order, when rendering digit templates.

    Args:
        font: A font name, or ``"auto"`` to pick from ``AUTO_FONT_CANDIDATES``.

    Returns:
        ``[font]`` for any value other than ``"auto"``.  For ``"auto"``, the
        entries of ``AUTO_FONT_CANDIDATES`` that are present in
        ``magick_fonts()``, followed by ``None`` (which makes
        ``build_template_args`` omit the ``-font`` option).
    """
    if font == "auto":
        installed = magick_fonts()
        usable = [name for name in AUTO_FONT_CANDIDATES if name in installed]
        return usable + [None]
    return [font]
def build_template_args(font, pointsize):
    """Assemble the ImageMagick arguments that render white digits on black.

    Args:
        font: Font name passed to ``-font``; a falsy value omits the option.
        pointsize: Value for ``-pointsize`` (converted with ``str``).

    Returns:
        A new list of argument strings ending in the ``label:0123456789`` source.
    """
    font_option = ["-font", font] if font else []
    return [
        "-background",
        "black",
        "-fill",
        "white",
        *font_option,
        "-pointsize",
        str(pointsize),
        "label:0123456789",
    ]
def build_templates(font, pointsize):
    """Render the digits 0-9 with ImageMagick and turn them into templates.

    Every candidate from ``template_font_candidates(font)`` is tried in order;
    the first one that renders at least ten components wins.

    Args:
        font: Font name, or ``"auto"`` to try the configured candidates.
        pointsize: Point size used for the rendered label.

    Returns:
        A dict mapping ``"0"``..``"9"`` to a one-element list holding the
        normalized grid of the corresponding component (the first ten
        components in the left-to-right order produced by
        ``connected_components``).

    Raises:
        SystemExit: If no candidate could be rendered.  The message lists the
            failure recorded for each candidate, for explicit fonts as well
            as for ``"auto"``.
    """
    errors = []
    for candidate in template_font_candidates(font):
        label = candidate or "ImageMagick default"
        ppm, error = try_ppm_from_magick(build_template_args(candidate, pointsize))
        if ppm is None:
            errors.append(f"{label}: {error}")
            continue
        width, height, raw = ppm
        components = connected_components(white_mask(width, height, raw))
        if len(components) < 10:
            errors.append(f"{label}: rendered only {len(components)} digit components")
            continue
        return {str(index): [normalize(component)] for index, component in enumerate(components[:10])}
    # Previously the recorded reasons were dropped for explicit fonts, hiding
    # the actual ImageMagick failure from the user.
    details = "\n " + "\n ".join(errors)
    if font == "auto":
        sys.exit("could not render digit templates with any available ImageMagick font:" + details)
    sys.exit(
        f"could not render digit templates with font {font!r}. "
        "Install that font, use --font auto, or pass another ImageMagick font name."
        + details
    )
def validate_template_grid(grid, digit, sample_index, path):
    """Exit unless ``grid`` is a list of 36 rows of 24 booleans each.

    Args:
        grid: Candidate template grid.
        digit: Digit the grid belongs to (used only in the error message).
        sample_index: Index of the sample (used only in the error message).
        path: Template-set file being validated (used only in the error message).

    Raises:
        SystemExit: With a descriptive message when the shape or cell types
            are wrong.
    """

    def reject(reason):
        sys.exit(f"invalid template grid for digit {digit} sample {sample_index} in {path}: {reason}")

    def row_is_valid(row):
        return isinstance(row, list) and len(row) == 24 and all(isinstance(cell, bool) for cell in row)

    if not (isinstance(grid, list) and len(grid) == 36):
        reject("expected 36 rows")
    if not all(row_is_valid(row) for row in grid):
        reject("expected 24 boolean columns per row")
def load_template_set(path):
    """Load and validate a JSON digit template set.

    The file must contain an object with a ``"digits"`` object that maps each
    of ``"0"``..``"9"`` to a non-empty list of template grids; every grid is
    checked with ``validate_template_grid``.

    Args:
        path: Location of the UTF-8 encoded JSON file.

    Returns:
        A dict mapping each digit string to a new list of its grids.

    Raises:
        SystemExit: If the file is missing, unreadable, not valid UTF-8 JSON,
            or does not have the expected structure.
    """
    try:
        with open(path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except FileNotFoundError:
        sys.exit(f"missing digit template set {path}; regenerate it from the regression dataset")
    except OSError as exc:
        # e.g. the path is a directory or is not readable; report it like the
        # other failures instead of leaking a traceback.
        sys.exit(f"failed to read digit template set {path}: {exc}")
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        sys.exit(f"failed to parse digit template set {path}: {exc}")
    digits = payload.get("digits") if isinstance(payload, dict) else None
    if not isinstance(digits, dict):
        sys.exit(f"invalid digit template set {path}: missing digits object")
    templates = {}
    for digit in "0123456789":
        samples = digits.get(digit)
        if not isinstance(samples, list) or not samples:
            sys.exit(f"invalid digit template set {path}: missing samples for digit {digit}")
        for index, grid in enumerate(samples):
            validate_template_grid(grid, digit, index, path)
        templates[digit] = list(samples)
    return templates
def templates_from_dataset(paths, reference_crop, reference_size, scale_mode, cropped):
    """Build digit templates from images whose file names spell their digits.

    For each path the file name without extension is the expected digit
    string.  The image is loaded (cropped via ``resolve_crop`` unless
    ``cropped`` is set), split into components with the orange mask, and each
    component, taken left to right, is stored under the matching digit.

    Args:
        paths: Iterable of image paths named like ``<digits>.<ext>``.
        reference_crop: Reference crop forwarded to ``resolve_crop``.
        reference_size: Reference image size forwarded to ``resolve_crop``.
        scale_mode: Scale mode forwarded to ``resolve_crop``.
        cropped: Whether the images are already cropped.

    Returns:
        A dict mapping ``"0"``..``"9"`` to a non-empty list of normalized grids.

    Raises:
        SystemExit: If a file name is not made of the ASCII digits 0-9, if an
            image yields a different number of components than its name has
            digits, or if some digit ends up with no samples.
    """
    templates = {str(digit): [] for digit in range(10)}
    for path in paths:
        expected = os.path.splitext(os.path.basename(path))[0]
        # str.isdigit() also accepts characters such as "²" or Arabic-Indic
        # digits, which have no template bucket and would fail later with a
        # KeyError; only characters that are template keys are allowed.
        if not expected or not set(expected).issubset(templates):
            sys.exit(f"template source filename must be numeric: {path}")
        crop = resolve_crop(path, None, reference_crop, reference_size, scale_mode, cropped)
        width, height, raw = load_queue_image(path, crop, cropped)
        components = connected_components(orange_digit_mask(width, height, raw))
        if len(components) != len(expected):
            sys.exit(
                f"template source {path} produced {len(components)} digit components, "
                f"but filename expects {len(expected)}"
            )
        for digit, component in zip(expected, components):
            templates[digit].append(normalize(component))
    missing = [digit for digit, samples in templates.items() if not samples]
    if missing:
        sys.exit(f"template dataset has no samples for digits: {', '.join(missing)}")
    return templates
def write_template_set(output, templates, source_paths):
    """Write a template set to ``output`` as compact JSON.

    The payload is written to ``<output>.tmp`` first and then moved over
    ``output`` with ``os.replace``, so readers never observe a partially
    written file.  The parent directory is created when needed.

    Args:
        output: Destination file path.
        templates: JSON-serializable mapping stored under ``"digits"``.
        source_paths: Paths whose base names are recorded under
            ``"source_files"``.

    Raises:
        OSError, TypeError, ValueError: Whatever the file operations or
            ``json.dump`` raise.  The temporary file is removed before the
            exception propagates.
    """
    payload = {
        "version": 1,
        "normalize_size": [24, 36],
        "source": "regression digit crops",
        "source_files": [os.path.basename(path) for path in source_paths],
        "digits": templates,
    }
    directory = os.path.dirname(output)
    if directory:
        os.makedirs(directory, exist_ok=True)
    temp_path = f"{output}.tmp"
    try:
        with open(temp_path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, separators=(",", ":"))
            handle.write("\n")
        os.replace(temp_path, output)
    except BaseException:
        # Do not leave a half-written temporary file behind on failure.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
def resolve_templates(template_set, font, pointsize):
    """Pick the template source: rendered from a font, or loaded from a file.

    Args:
        template_set: Path of the JSON template set, used when ``font`` is falsy.
        font: Font name (or ``"auto"``); any truthy value selects rendering.
        pointsize: Point size forwarded to ``build_templates``.

    Returns:
        The templates returned by ``build_templates`` or ``load_template_set``.
    """
    if not font:
        return load_template_set(template_set)
    return build_templates(font, pointsize)
def classify(component, templates):
    """Return the digit whose template best matches ``component``.

    The component is normalized and compared with every template.  The winner
    has the lowest Hamming distance; ties go to the higher IoU and then to the
    smaller digit string.

    Args:
        component: Component dict accepted by ``normalize``.
        templates: Mapping of digit string to a list of template grids.

    Returns:
        ``(digit, hamming, iou)`` for the best-matching template.

    Raises:
        ValueError: If ``templates`` contains no template grid at all.
    """
    sample = normalize(component)

    def scored():
        for digit, digit_templates in templates.items():
            for template in digit_templates:
                hamming, iou = template_distance(sample, template)
                # -iou makes a larger overlap sort first among equal distances.
                yield hamming, -iou, digit, iou

    # Only the best candidate is needed, so take the minimum instead of
    # sorting every score.
    best = min(scored(), default=None)
    if best is None:
        raise ValueError("cannot classify component: no digit templates available")
    hamming, _, digit, iou = best
    return digit, hamming, iou
def load_queue_image(image_path, crop, already_cropped=False):
    """Read an image through ImageMagick, cropping it first unless told not to.

    Args:
        image_path: Path of the image passed to ImageMagick.
        crop: ``(x, y, width, height)`` region; ignored when
            ``already_cropped`` is true.
        already_cropped: Skip the ``-crop`` step when true.

    Returns:
        The result of ``ppm_from_magick``; callers in this module unpack it
        as ``(width, height, raw)``.
    """
    magick_args = [image_path, "+repage"]
    if not already_cropped:
        x, y, width, height = crop
        magick_args += ["-crop", f"{width}x{height}+{x}+{y}", "+repage"]
    return ppm_from_magick(magick_args)
def read_queue_number(image_path, crop, template_set, font, pointsize, already_cropped=False):
    """Read the number shown in an image by classifying its orange components.

    Args:
        image_path: Image to read.
        crop: ``(x, y, width, height)`` region, unused when ``already_cropped``.
        template_set: Path of the JSON template set (used when ``font`` is falsy).
        font: Font used to render templates instead of loading them, if truthy.
        pointsize: Point size for rendered templates.
        already_cropped: Whether the image is already cropped.

    Returns:
        ``(text, details)`` where ``text`` joins the recognized digits in
        component order and ``details`` holds one
        ``(digit, bbox, area, hamming, iou)`` tuple per component.
    """
    width, height, raw = load_queue_image(image_path, crop, already_cropped)
    components = connected_components(orange_digit_mask(width, height, raw))
    templates = resolve_templates(template_set, font, pointsize)
    details = []
    for component in components:
        digit, hamming, iou = classify(component, templates)
        details.append((digit, component["bbox"], component["area"], hamming, iou))
    text = "".join(entry[0] for entry in details)
    return text, details
def scale_crop(reference_crop, reference_size, target_size, scale_mode):
    """Scale a crop rectangle from a reference image size to a target size.

    Args:
        reference_crop: ``(x, y, width, height)`` in the reference image.
        reference_size: ``(width, height)`` of the reference image.
        target_size: ``(width, height)`` of the target image.
        scale_mode: ``"width"`` applies the horizontal ratio to both axes;
            ``"independent"`` scales each axis by its own ratio.

    Returns:
        ``(x, y, width, height)`` rounded with ``round``; width and height
        are at least 1.

    Raises:
        ValueError: If ``scale_mode`` is not one of the supported modes.
    """
    reference_width, reference_height = reference_size
    target_width, target_height = target_size
    left, top, crop_width, crop_height = reference_crop
    if scale_mode not in ("width", "independent"):
        raise ValueError(f"unsupported scale mode: {scale_mode}")
    horizontal = target_width / reference_width
    vertical = horizontal if scale_mode == "width" else target_height / reference_height
    scaled_width = round(crop_width * horizontal)
    scaled_height = round(crop_height * vertical)
    return (
        round(left * horizontal),
        round(top * vertical),
        scaled_width if scaled_width > 1 else 1,
        scaled_height if scaled_height > 1 else 1,
    )
def resolve_crop(image_path, explicit_crop, reference_crop, reference_size, scale_mode, already_cropped):
    """Decide which crop rectangle applies to an image.

    Args:
        image_path: Image whose size is queried when the crop must be scaled.
        explicit_crop: Crop to use as-is when truthy.
        reference_crop: Reference crop scaled to the image otherwise.
        reference_size: Size of the image the reference crop was defined on.
        scale_mode: Scale mode forwarded to ``scale_crop``.
        already_cropped: When true no crop is needed.

    Returns:
        ``None`` if ``already_cropped``; otherwise ``explicit_crop`` when it
        is truthy, else the reference crop scaled to the image's size.
    """
    if already_cropped:
        return None
    return explicit_crop or scale_crop(reference_crop, reference_size, image_size(image_path), scale_mode)