# id-scanner / app.py
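# Detect the four corners of an ID document with an ONNX keypoint model, draw the
# detected quadrilateral on the input, and return a perspective-rectified crop.
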
import logging
import math
import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image, ImageOps

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

MODEL_PATH = "model.onnx"
IMAGE_SIZE = 480  # square input resolution expected by the model

SESSION = ort.InferenceSession(MODEL_PATH)
INPUT_NAME = SESSION.get_inputs()[0].name
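
# Assumed model contract (inferred from how the output is consumed below): input is a
# (1, 3, 480, 480) normalized image batch; output [0][0] is a (4, 3) array of corner
# predictions, each row holding (x, y, confidence) in 480x480 letterboxed coordinates.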

def preprocess(img: Image.Image) -> np.ndarray:
    """Letterbox to IMAGE_SIZE x IMAGE_SIZE and normalize to [-1, 1], CHW float32."""
    # keep aspect ratio; centering=(0, 0) puts the padding on the right/bottom
    resized_img = ImageOps.pad(img, (IMAGE_SIZE, IMAGE_SIZE), centering=(0, 0))
    img_chw = np.array(resized_img).transpose(2, 0, 1).astype(np.float32) / 255
    img_chw = (img_chw - 0.5) / 0.5
    return img_chw
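
# Minimal usage sketch ("card.jpg" is a placeholder file name):
#   chw = preprocess(Image.open("card.jpg").convert("RGB"))
#   chw.shape, chw.dtype  # -> (3, 480, 480), float32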

def distance(p1, p2):
    # Euclidean distance between two 2D points
    return ((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5

# https://stackoverflow.com/a/1222855
# https://www.microsoft.com/en-us/research/wp-content/uploads/2016/11/Digital-Signal-Processing.pdf
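# Given the four projected corners of a rectangle, this recovers the focal length from
# the two vanishing directions and then the rectangle's true width/height ratio,
# assuming square pixels and the principal point at the image center.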
def get_aspect_ratio_zhang(keypoints: np.ndarray, img_width: int, img_height: int):
    """Estimate the true width/height ratio of a perspective-projected rectangle."""
    keypoints = keypoints[[3, 2, 0, 1]]  # re-arrange keypoints according to Zhang 2006 Figure 6
    keypoints = np.concatenate([keypoints, np.ones((4, 1))], axis=1)  # convert to homogeneous coordinates

    # equation (11) and (12)
    k2 = np.cross(keypoints[0], keypoints[3]).dot(keypoints[2]) / np.cross(keypoints[1], keypoints[3]).dot(keypoints[2])
    k3 = np.cross(keypoints[0], keypoints[3]).dot(keypoints[1]) / np.cross(keypoints[2], keypoints[3]).dot(keypoints[1])

    # equation (14) and (16)
    n2 = k2 * keypoints[1] - keypoints[0]
    n3 = k3 * keypoints[2] - keypoints[0]

    # equation (21): squared focal length, with the principal point at the image center
    u0 = img_width / 2
    v0 = img_height / 2
    f2 = -(
        n2[0] * n3[0] - (n2[0] * n3[2] + n2[2] * n3[0]) * u0 + n2[2] * n3[2] * u0 * u0
        + n2[1] * n3[1] - (n2[1] * n3[2] + n2[2] * n3[1]) * v0 + n2[2] * n3[2] * v0 * v0
    ) / (n2[2] * n3[2])
    f = math.sqrt(f2)

    # equation (20)
    A = np.array([[f, 0, u0], [0, f, v0], [0, 0, 1]])  # pinhole camera intrinsics
    A_inv = np.linalg.inv(A)
    mid = A_inv.T.dot(A_inv)
    wh_ratio2 = n2.dot(mid).dot(n2) / n3.dot(mid).dot(n3)
    return math.sqrt(wh_ratio2)
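
# Note: when opposite edges of the quadrilateral are (near-)parallel, the vanishing
# points go to infinity, n2[2] * n3[2] approaches 0, and f2 becomes negative, infinite,
# or NaN; rectify() below guards this call with a try/except for that reason.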

def rectify(img_np: np.ndarray, keypoints: np.ndarray):
    """Perspective-correct the quadrilateral spanned by keypoints (TL, TR, BR, BL order)."""
    img_height, img_width = img_np.shape[:2]
    # height: average of the two vertical edges measured in the image
    h1 = distance(keypoints[0], keypoints[3])
    h2 = distance(keypoints[1], keypoints[2])
    h = (h1 + h2) * 0.5

    # this may fail if two lines are parallel
    try:
        wh_ratio = get_aspect_ratio_zhang(keypoints, img_width, img_height)
        w = h * wh_ratio
    except Exception:
        logging.exception("Failed to estimate aspect ratio from perspective")
        # fall back to the average of the two horizontal edges measured in the image
        w1 = distance(keypoints[0], keypoints[1])
        w2 = distance(keypoints[3], keypoints[2])
        w = (w1 + w2) * 0.5

    # map the corners to an axis-aligned w x h rectangle with a 1-pixel border
    target_kpts = np.array([[1, 1], [w + 1, 1], [w + 1, h + 1], [1, h + 1]], dtype=np.float32)
    transform = cv2.getPerspectiveTransform(keypoints, target_kpts)
    cropped = cv2.warpPerspective(img_np, transform, (round(w) + 2, round(h) + 2), flags=cv2.INTER_CUBIC)
    return cropped
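
# Usage sketch with hypothetical corner coordinates (TL, TR, BR, BL):
#   corners = np.array([[120, 80], [620, 110], [600, 420], [100, 400]], dtype=np.float32)
#   flat = rectify(np.asarray(Image.open("card.jpg")), corners)  # "card.jpg" is a placeholder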

def predict(img: Image.Image):
    img_chw = preprocess(img)
    pred_kpts = SESSION.run(None, {INPUT_NAME: img_chw[None]})[0][0]  # (4, 3) -> (x, y, confidence)
    # undo the letterbox: scale from 480x480 model space back to original pixel coordinates
    kpts_xy = pred_kpts[:, :2] * max(img.size) / IMAGE_SIZE

    # draw the detected quadrilateral on a copy of the input
    img_np = np.array(img)
    cv2.polylines(
        img_np,
        [kpts_xy.astype(np.int32)],  # polylines expects int32 points
        True,
        (0, 255, 0),
        thickness=5,
        lineType=cv2.LINE_AA,
    )

    # only rectify when all four corners are detected with enough confidence
    if (pred_kpts[:, 2] >= 0.25).all():
        cropped = rectify(np.array(img), kpts_xy)
    else:
        cropped = None
    return cropped, img_np
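
# Gradio wiring: predict() returns (rectified crop, annotated input), matching the two
# "image" outputs below; the crop is None when any corner confidence falls below 0.25.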
gr.Interface(
    predict,
    inputs=[gr.Image(type="pil")],
    outputs=["image", "image"],
    examples=["estonia_id_card.jpg", "german_bundesdruckerei_passport.webp"],
).launch(server_name="0.0.0.0")