294 lines
10 KiB
Python
294 lines
10 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
|
|
from .config import AUTO_FONT_CANDIDATES
|
|
from .magick import image_size, magick_fonts, ppm_from_magick, try_ppm_from_magick
|
|
|
|
|
|
def orange_digit_mask(width, height, raw):
    """Return a ``height`` x ``width`` boolean grid flagging orange pixels.

    ``raw`` is indexed as packed RGB triples in row-major order (three values
    per pixel, ``width * 3`` values per row). A pixel is flagged when red is
    at least 115, green lies in 55..205, blue is at most 100 and red exceeds
    green by at least 20.
    """

    def is_orange(base):
        # Read all three channels up front, exactly like a plain unpack would.
        red = raw[base]
        green = raw[base + 1]
        blue = raw[base + 2]
        return red >= 115 and 55 <= green <= 205 and blue <= 100 and red - green >= 20

    stride = width * 3
    return [
        [is_orange(row * stride + col * 3) for col in range(width)]
        for row in range(height)
    ]
|
|
|
|
|
|
def white_mask(width, height, raw):
    """Return a ``height`` x ``width`` boolean grid flagging light pixels.

    ``raw`` is indexed as packed RGB triples in row-major order. A pixel is
    flagged when every channel is at least 120.
    """
    rows = []
    for row_index in range(height):
        base = row_index * width * 3
        row = []
        for col_index in range(width):
            pixel = base + 3 * col_index
            red, green, blue = raw[pixel], raw[pixel + 1], raw[pixel + 2]
            # "All channels >= 120" is the same as "the darkest channel >= 120".
            row.append(min(red, green, blue) >= 120)
        rows.append(row)
    return rows
|
|
|
|
|
|
def connected_components(mask, *, min_area=40, min_width=5, min_height=15):
    """Find 8-connected regions of truthy cells in ``mask``, ordered left to right.

    ``mask`` is a list of rows indexed ``mask[y][x]``; its width is taken from
    the first row. Regions are flood-filled with 8-neighbour connectivity
    (diagonal neighbours count). A region is kept only when it has at least
    ``min_area`` cells and its bounding box is at least ``min_width`` wide and
    ``min_height`` tall. The defaults (40, 5, 15) are the thresholds this
    module always used. They are keyword-only so callers working at other
    image scales can tune them without affecting existing call sites.

    Returns a list of dicts with keys ``"points"`` (set of ``(x, y)``),
    ``"bbox"`` (``(x1, y1, x2, y2)``, inclusive) and ``"area"`` (cell count),
    stably sorted by the bounding box's left edge.
    """
    height = len(mask)
    width = len(mask[0]) if height else 0
    seen = [[False] * width for _ in range(height)]
    components = []

    for y in range(height):
        for x in range(width):
            if seen[y][x] or not mask[y][x]:
                continue

            # Iterative flood fill with an explicit stack (no recursion limit).
            seen[y][x] = True
            stack = [(x, y)]
            points = set()
            while stack:
                cx, cy = stack.pop()
                points.add((cx, cy))
                # Clamp the 3x3 neighbourhood to the grid instead of testing
                # each neighbour against the borders.
                for ny in range(max(cy - 1, 0), min(cy + 2, height)):
                    for nx in range(max(cx - 1, 0), min(cx + 2, width)):
                        if not seen[ny][nx] and mask[ny][nx]:
                            seen[ny][nx] = True
                            stack.append((nx, ny))

            bbox = (
                min(px for px, _ in points),
                min(py for _, py in points),
                max(px for px, _ in points),
                max(py for _, py in points),
            )
            box_width = bbox[2] - bbox[0] + 1
            box_height = bbox[3] - bbox[1] + 1

            if len(points) >= min_area and box_width >= min_width and box_height >= min_height:
                components.append({"points": points, "bbox": bbox, "area": len(points)})

    components.sort(key=lambda component: component["bbox"][0])
    return components
|
|
|
|
|
|
def normalize(component, width=24, height=36):
    """Resample a component's bounding box onto a ``height`` x ``width`` boolean grid.

    Each output cell samples the source pixel at the centre of its slice of
    the bounding box (nearest-neighbour, truncated with ``int``). The cell is
    ``True`` iff that pixel is in ``component["points"]``.
    """
    left, top, right, bottom = component["bbox"]
    span_x = right - left + 1
    span_y = bottom - top + 1
    members = component["points"]

    # Column sample positions do not depend on the row, so compute them once.
    columns = [left + int((col + 0.5) * span_x / width) for col in range(width)]
    grid = []
    for row in range(height):
        sample_y = top + int((row + 0.5) * span_y / height)
        grid.append([(sample_x, sample_y) in members for sample_x in columns])
    return grid
|
|
|
|
|
|
def template_distance(left, right):
    """Compare two equally shaped boolean grids.

    Returns ``(hamming, iou)``:

    - ``hamming`` is the fraction of cells whose values differ.
    - ``iou`` is intersection-over-union of the truthy cells. It is 0.0 when
      neither grid has a truthy cell.

    Raises ValueError when the grids are empty or differ in shape. A plain
    ``zip`` would silently truncate the longer grid while still dividing by
    ``left``'s full size. That yields asymmetric, misleading scores.
    """
    rows = len(left)
    columns = len(left[0]) if rows else 0
    if not columns:
        raise ValueError("template_distance needs non-empty grids")
    shapes_match = len(right) == rows and all(
        len(row) == columns for grid in (left, right) for row in grid
    )
    if not shapes_match:
        raise ValueError("template_distance needs two grids of identical shape")

    total = rows * columns
    differences = 0
    intersection = 0
    union = 0
    for left_row, right_row in zip(left, right):
        for left_value, right_value in zip(left_row, right_row):
            differences += left_value != right_value
            intersection += left_value and right_value
            union += left_value or right_value

    hamming = differences / total
    iou = intersection / union if union else 0.0
    return hamming, iou
|
|
|
|
|
|
def template_font_candidates(font):
    """Return the ordered list of fonts to try when rendering digit templates.

    An explicit font name is returned alone. For ``"auto"``, every entry of
    ``AUTO_FONT_CANDIDATES`` that appears in ``magick_fonts()`` is kept in
    order. A trailing ``None`` is appended, meaning "no ``-font`` option".
    """
    if font == "auto":
        installed = magick_fonts()
        return [name for name in AUTO_FONT_CANDIDATES if name in installed] + [None]
    return [font]
|
|
|
|
|
|
def build_template_args(font, pointsize):
    """Build ImageMagick arguments that render the label "0123456789".

    The text is white on a black background at ``pointsize``. ``-font`` is
    emitted only for a truthy ``font``, so ``None`` (or "") leaves the font
    choice to ImageMagick.
    """
    font_args = ["-font", font] if font else []
    return [
        "-background", "black",
        "-fill", "white",
        *font_args,
        "-pointsize", str(pointsize),
        "label:0123456789",
    ]
|
|
|
|
|
|
def build_templates(font, pointsize):
    """Render digit templates with ImageMagick, trying each candidate font.

    For every font from ``template_font_candidates(font)``, the label
    "0123456789" is rendered (``build_template_args``) and thresholded with
    ``white_mask``. The result is split with ``connected_components``. The
    first candidate yielding exactly ten components wins: component ``i``
    (left to right) becomes the single template for digit ``str(i)``.

    Returns ``{"0": [grid], ..., "9": [grid]}`` with normalized boolean
    grids. Exits via ``sys.exit`` with the collected per-font errors when no
    candidate renders usable templates.
    """
    errors = []
    for candidate in template_font_candidates(font):
        label = candidate or "ImageMagick default"
        ppm, error = try_ppm_from_magick(build_template_args(candidate, pointsize))
        if ppm is None:
            errors.append(f"{label}: {error}")
            continue

        width, height, raw = ppm
        components = connected_components(white_mask(width, height, raw))
        # The index -> digit mapping relies on exactly one component per glyph.
        # Extra components (a detached dot or stroke) would shift every later
        # digit onto the wrong template. Reject the font rather than silently
        # keeping the first ten.
        if len(components) != 10:
            errors.append(f"{label}: rendered {len(components)} digit components, expected 10")
            continue

        return {str(index): [normalize(component)] for index, component in enumerate(components)}

    if font == "auto":
        sys.exit("could not render digit templates with any available ImageMagick font:\n " + "\n ".join(errors))
    detail = "; ".join(errors)
    sys.exit(
        f"could not render digit templates with font {font!r} ({detail}). "
        "Install that font, use --font auto, or pass another ImageMagick font name."
    )
|
|
|
|
|
|
def validate_template_grid(grid, digit, sample_index, path):
    """Exit with a diagnostic unless ``grid`` is 36 rows of 24 booleans each.

    ``digit``, ``sample_index`` and ``path`` are used only to build the error
    message. Returns ``None`` when the grid is well formed.
    """
    where = f"digit {digit} sample {sample_index} in {path}"
    if not (isinstance(grid, list) and len(grid) == 36):
        sys.exit(f"invalid template grid for {where}: expected 36 rows")
    for row in grid:
        well_formed = (
            isinstance(row, list)
            and len(row) == 24
            and all(isinstance(cell, bool) for cell in row)
        )
        if not well_formed:
            sys.exit(f"invalid template grid for {where}: expected 24 boolean columns per row")
|
|
|
|
|
|
def load_template_set(path):
    """Load and validate a JSON digit template set.

    The file must hold an object whose ``"digits"`` value maps each of
    "0".."9" to a non-empty list of grids accepted by
    ``validate_template_grid``. Returns ``{digit: [grid, ...]}`` for the ten
    digits; other payload keys are ignored.

    Every failure terminates via ``sys.exit`` with a message naming ``path``.
    This covers a missing or unreadable file, invalid UTF-8, malformed JSON
    and a bad structure.
    """
    try:
        with open(path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except FileNotFoundError:
        sys.exit(f"missing digit template set {path}; regenerate it from the regression dataset")
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # Invalid UTF-8 raises UnicodeDecodeError rather than JSONDecodeError;
        # both mean the file content is unusable.
        sys.exit(f"failed to parse digit template set {path}: {exc}")
    except OSError as exc:
        # e.g. PermissionError, or a directory passed by mistake. Report it like
        # the other failures instead of crashing with a traceback.
        sys.exit(f"failed to read digit template set {path}: {exc}")

    digits = payload.get("digits") if isinstance(payload, dict) else None
    if not isinstance(digits, dict):
        sys.exit(f"invalid digit template set {path}: missing digits object")

    templates = {}
    for digit in "0123456789":
        samples = digits.get(digit)
        if not isinstance(samples, list) or not samples:
            sys.exit(f"invalid digit template set {path}: missing samples for digit {digit}")
        templates[digit] = []
        for index, grid in enumerate(samples):
            validate_template_grid(grid, digit, index, path)
            templates[digit].append(grid)

    return templates
|
|
|
|
|
|
def templates_from_dataset(paths, reference_crop, reference_size, scale_mode, cropped):
    """Build digit templates from labelled regression images.

    Each file's stem must be the decimal number it shows (e.g. ``1234.png``).
    For each file:

    - the crop is resolved with ``resolve_crop``;
    - the image is loaded with ``load_queue_image``;
    - orange digit components are extracted left to right;
    - each component is normalized and appended to the samples of the digit
      at the same position in the stem.

    Returns ``{"0": [grid, ...], ..., "9": [...]}``. Exits via ``sys.exit``
    if a stem is not numeric, a component count does not match the stem
    length, or any digit ends up with no samples.
    """
    templates = {str(digit): [] for digit in range(10)}
    for path in paths:
        expected = os.path.splitext(os.path.basename(path))[0]
        # str.isdigit() also accepts non-ASCII digits such as "\u00b2" or
        # "\u0663". Those have no key in ``templates`` and would raise KeyError
        # below, so only plain ASCII 0-9 stems are accepted as labels.
        if not expected or any(char not in "0123456789" for char in expected):
            sys.exit(f"template source filename must be numeric: {path}")

        crop = resolve_crop(path, None, reference_crop, reference_size, scale_mode, cropped)
        width, height, raw = load_queue_image(path, crop, cropped)
        components = connected_components(orange_digit_mask(width, height, raw))
        if len(components) != len(expected):
            sys.exit(
                f"template source {path} produced {len(components)} digit components, "
                f"but filename expects {len(expected)}"
            )

        for digit, component in zip(expected, components):
            templates[digit].append(normalize(component))

    missing = [digit for digit, samples in templates.items() if not samples]
    if missing:
        sys.exit(f"template dataset has no samples for digits: {', '.join(missing)}")
    return templates
|
|
|
|
|
|
def write_template_set(output, templates, source_paths):
    """Write ``templates`` as a compact JSON template set at ``output``.

    The payload records a format version, the normalized grid size, and the
    basenames of ``source_paths``. The parent directory is created if needed.
    The JSON is written to ``<output>.tmp`` first and moved into place with
    ``os.replace``, so readers never see a half-written ``output``.

    If serialization or writing fails, the temporary file is removed and the
    original exception is re-raised.
    """
    payload = {
        "version": 1,
        "normalize_size": [24, 36],
        "source": "regression digit crops",
        "source_files": [os.path.basename(path) for path in source_paths],
        "digits": templates,
    }
    directory = os.path.dirname(output)
    if directory:
        os.makedirs(directory, exist_ok=True)

    temp_path = f"{output}.tmp"
    try:
        with open(temp_path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, separators=(",", ":"))
            handle.write("\n")
        os.replace(temp_path, output)
    except BaseException:
        # Best-effort cleanup of the partial temp file. Any existing
        # ``output`` is untouched because the rename never happened. The
        # original error is re-raised below.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
|
|
|
|
|
|
def resolve_templates(template_set, font, pointsize):
    """Return the digit templates to classify against.

    A truthy ``font`` renders fresh templates with ``build_templates``.
    Otherwise the JSON set at ``template_set`` is read with
    ``load_template_set``.
    """
    return build_templates(font, pointsize) if font else load_template_set(template_set)
|
|
|
|
|
|
def classify(component, templates):
    """Return ``(digit, hamming, iou)`` for the template closest to ``component``.

    The component is normalized to the default grid. It is then compared
    with ``template_distance`` against every sample of every digit in
    ``templates`` (a mapping of digit -> list of grids). The best match has
    the lowest Hamming distance. Ties go to the higher IoU, then to the
    smaller digit key.

    Raises ValueError when ``templates`` contains no samples at all, instead
    of the opaque IndexError an empty ranking would produce.
    """
    sample = normalize(component)
    best = None  # ((hamming, -iou, digit), iou) of the best match so far
    for digit, digit_templates in templates.items():
        for template in digit_templates:
            hamming, iou = template_distance(sample, template)
            key = (hamming, -iou, digit)
            if best is None or key < best[0]:
                best = (key, iou)

    if best is None:
        raise ValueError("no digit templates to classify against")
    (hamming, _, digit), iou = best
    return digit, hamming, iou
|
|
|
|
|
|
def load_queue_image(image_path, crop, already_cropped=False):
    """Decode ``image_path`` through ``ppm_from_magick``, optionally cropping it.

    When ``already_cropped`` is true, ``crop`` is ignored and only
    ``+repage`` is applied. Otherwise ``crop`` is an ``(x, y, width,
    height)`` tuple, passed as ``-crop WxH+X+Y`` between two ``+repage``
    options. Returns ``ppm_from_magick``'s result, which callers unpack as
    ``width, height, raw``.
    """
    args = [image_path, "+repage"]
    if not already_cropped:
        left, top, crop_width, crop_height = crop
        args += ["-crop", f"{crop_width}x{crop_height}+{left}+{top}", "+repage"]
    return ppm_from_magick(args)
|
|
|
|
|
|
def read_queue_number(image_path, crop, template_set, font, pointsize, already_cropped=False):
    """Read the number drawn in orange digits in ``image_path``.

    The image is loaded (and cropped unless ``already_cropped``), and its
    orange components are extracted left to right. Each component is then
    classified against the templates chosen by ``resolve_templates``.

    Returns ``(text, details)``. ``text`` is the classified digits joined
    left to right. ``details`` holds one ``(digit, bbox, area, hamming,
    iou)`` tuple per component.
    """
    width, height, raw = load_queue_image(image_path, crop, already_cropped)
    glyphs = connected_components(orange_digit_mask(width, height, raw))
    templates = resolve_templates(template_set, font, pointsize)

    details = []
    for glyph in glyphs:
        digit, hamming, iou = classify(glyph, templates)
        details.append((digit, glyph["bbox"], glyph["area"], hamming, iou))
    return "".join(entry[0] for entry in details), details
|
|
|
|
|
|
def scale_crop(reference_crop, reference_size, target_size, scale_mode):
    """Map a crop rectangle from a reference image onto a target image.

    ``reference_crop`` is ``(x, y, width, height)``; sizes are ``(width,
    height)`` pairs. With ``scale_mode == "width"``, both axes use the
    horizontal ratio, so the crop keeps its aspect ratio. With
    ``"independent"``, each axis uses its own ratio. Results are rounded with
    ``round`` and the crop width/height are at least 1.

    Raises ValueError for any other ``scale_mode``.
    """
    ref_w, ref_h = reference_size
    tgt_w, tgt_h = target_size
    crop_x, crop_y, crop_w, crop_h = reference_crop

    if scale_mode not in ("width", "independent"):
        raise ValueError(f"unsupported scale mode: {scale_mode}")
    sx = tgt_w / ref_w
    sy = sx if scale_mode == "width" else tgt_h / ref_h

    return (
        round(crop_x * sx),
        round(crop_y * sy),
        max(1, round(crop_w * sx)),
        max(1, round(crop_h * sy)),
    )
|
|
|
|
|
|
def resolve_crop(image_path, explicit_crop, reference_crop, reference_size, scale_mode, already_cropped):
    """Choose the crop rectangle for ``image_path``.

    Returns ``None`` for images that are already cropped. Otherwise it
    returns ``explicit_crop`` when that is truthy. Failing both, it returns
    ``reference_crop`` rescaled to this image's size (from ``image_size``)
    via ``scale_crop``.
    """
    if already_cropped:
        crop = None
    elif explicit_crop:
        crop = explicit_crop
    else:
        crop = scale_crop(reference_crop, reference_size, image_size(image_path), scale_mode)
    return crop
|