feat: 提升OCR稳定性并支持多栏地址解析

This commit is contained in:
empty
2026-02-24 22:45:11 +08:00
parent 1d6ee0a95e
commit 6ce4b7b363
6 changed files with 1026 additions and 216 deletions

View File

@@ -0,0 +1,91 @@
#!/usr/bin/env python3
"""
解析器快速自测脚本
运行方式:
python scripts/test_extract_info.py
"""
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
from processor import extract_info # noqa: E402
def _print_case(name: str, result: dict) -> None:
print(f"\n=== {name} ===")
for key in ["编号", "邮编", "地址", "联系人/单位名", "电话"]:
print(f"{key}: {result.get(key, '')}")
def case_layout_multi_column() -> None:
"""多栏场景:左侧地址、右侧单位+联系人。"""
ocr_lines = [
{"text": "518000", "box": [[80, 40], [180, 40], [180, 80], [80, 80]], "source": "main"},
{"text": "广东省深圳市南山区", "box": [[80, 100], [450, 100], [450, 132], [80, 132]], "source": "main"},
{"text": "科技园高新南一道18号", "box": [[80, 140], [520, 140], [520, 172], [80, 172]], "source": "main"},
{"text": "创新大厦3栋1201", "box": [[80, 180], [420, 180], [420, 212], [80, 212]], "source": "main"},
{"text": "华南建设小组办公室", "box": [[620, 182], [960, 182], [960, 214], [620, 214]], "source": "main"},
{"text": "张三13800138000", "box": [[620, 222], [960, 222], [960, 254], [620, 254]], "source": "main"},
{"text": "202602241234567890", "box": [[280, 60], [760, 60], [760, 94], [280, 94]], "source": "number"},
]
result = extract_info(ocr_lines)
_print_case("多栏版面", result)
assert result["邮编"] == "518000"
assert result["电话"] == "13800138000"
assert "广东省深圳市南山区" in result["地址"]
assert "科技园高新南一道18号" in result["地址"]
assert "华南建设小组办公室" in result["联系人/单位名"]
assert result["编号"] == "202602241234567890"
def case_layout_single_column() -> None:
"""单列场景:邮编后连续地址,电话行包含联系人。"""
ocr_lines = [
{"text": "200120", "box": [[90, 42], [188, 42], [188, 76], [90, 76]], "source": "main"},
{"text": "上海市浦东新区世纪大道100号", "box": [[90, 96], [620, 96], [620, 128], [90, 128]], "source": "main"},
{"text": "A座1201室", "box": [[90, 136], [300, 136], [300, 168], [90, 168]], "source": "main"},
{"text": "李四021-12345678", "box": [[90, 178], [420, 178], [420, 210], [90, 210]], "source": "main"},
]
result = extract_info(ocr_lines)
_print_case("单列版面", result)
assert result["邮编"] == "200120"
assert result["电话"] == "021-12345678"
assert "上海市浦东新区世纪大道100号" in result["地址"]
assert "A座1201室" in result["地址"]
assert result["联系人/单位名"] == "李四"
def case_text_fallback() -> None:
"""无坐标回退:纯文本顺序规则。"""
ocr_texts = [
"518000",
"广东省深圳市南山区科技园",
"高新南一道18号",
"华南建设小组办公室",
"王五 13911112222",
]
result = extract_info(ocr_texts)
_print_case("纯文本回退", result)
assert result["邮编"] == "518000"
assert result["电话"] == "13911112222"
assert "广东省深圳市南山区科技园" in result["地址"]
assert "高新南一道18号" in result["地址"]
assert "华南建设小组办公室" in result["联系人/单位名"] or result["联系人/单位名"] == "王五"
def main() -> None:
case_layout_multi_column()
case_layout_single_column()
case_text_fallback()
print("\n所有场景断言通过。")
if __name__ == "__main__":
main()

View File

@@ -11,6 +11,7 @@ import time
import logging import logging
import threading import threading
import queue import queue
import multiprocessing as mp
import subprocess import subprocess
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -24,8 +25,8 @@ from PyQt6.QtWidgets import (
from PyQt6.QtCore import Qt, QTimer, pyqtSignal, QObject, pyqtSlot from PyQt6.QtCore import Qt, QTimer, pyqtSignal, QObject, pyqtSlot
from PyQt6.QtGui import QImage, QPixmap, QFont, QAction, QKeySequence, QShortcut from PyQt6.QtGui import QImage, QPixmap, QFont, QAction, QKeySequence, QShortcut
from processor import extract_info from ocr_offline import get_models_base_dir
from ocr_offline import create_offline_ocr, get_models_base_dir from ocr_worker_process import run_ocr_worker
logger = logging.getLogger("post_ocr.desktop") logger = logging.getLogger("post_ocr.desktop")
@@ -70,12 +71,12 @@ def setup_logging() -> Path:
class OCRService(QObject): class OCRService(QObject):
""" """
OCR 后台服务(运行在标准 Python 线程内)。 OCR 后台服务(运行在独立子进程中)。
关键点: 关键点:
- 避免使用 QThread在 macOS 上QThread(Dummy-*) 内 import paddleocr 可能卡死 - PaddleOCR 初始化与推理都放到子进程,避免阻塞 UI 主进程
- PaddleOCR 实例在后台线程内创建并使用,避免跨线程调用导致卡死/死锁 - 主进程只做任务投递与结果回调
- 单线程串行处理任务:避免并发推理挤爆内存或引发底层库竞争 - 子进程异常或卡住时,可通过重启服务恢复
""" """
finished = pyqtSignal(int, dict, list) finished = pyqtSignal(int, dict, list)
@@ -87,11 +88,20 @@ class OCRService(QObject):
def __init__(self, models_base_dir: Path): def __init__(self, models_base_dir: Path):
super().__init__() super().__init__()
self._models_base_dir = models_base_dir self._models_base_dir = models_base_dir
self._ocr = None
self._busy = False self._busy = False
self._stop_event = threading.Event() self._stop_event = threading.Event()
self._queue: "queue.Queue[tuple[int, object] | None]" = queue.Queue() method_default = "fork" if sys.platform == "darwin" else "spawn"
self._thread = threading.Thread(target=self._run, name="OCRThread", daemon=True) method = os.environ.get("POST_OCR_MP_START_METHOD", method_default).strip() or method_default
try:
self._ctx = mp.get_context(method)
except ValueError:
method = method_default
self._ctx = mp.get_context(method_default)
logger.info("OCR multiprocessing start_method=%s", method)
self._req_q = None
self._resp_q = None
self._proc = None
self._reader_thread = None
def _set_busy(self, busy: bool) -> None: def _set_busy(self, busy: bool) -> None:
if self._busy != busy: if self._busy != busy:
@@ -99,118 +109,152 @@ class OCRService(QObject):
self.busy_changed.emit(busy) self.busy_changed.emit(busy)
def start(self) -> None: def start(self) -> None:
"""启动后台线程并执行 warmup""" """启动 OCR 子进程与响应监听线程"""
self._thread.start() self._stop_event.clear()
self._req_q = self._ctx.Queue(maxsize=1)
self._resp_q = self._ctx.Queue()
self._proc = self._ctx.Process(
target=run_ocr_worker,
args=(str(self._models_base_dir), self._req_q, self._resp_q),
name="OCRProcess",
daemon=True,
)
self._proc.start()
self._reader_thread = threading.Thread(
target=self._read_responses,
name="OCRRespReader",
daemon=True,
)
self._reader_thread.start()
def stop(self, timeout_ms: int = 8000) -> bool: def stop(self, timeout_ms: int = 8000) -> bool:
"""请求停止后台线程并等待退出(后台线程为 daemon退出失败也不阻塞进程""" """停止 OCR 子进程与监听线程"""
try: try:
self._stop_event.set() self._stop_event.set()
# 用 sentinel 唤醒阻塞在 queue.get() 的线程
try: try:
self._queue.put_nowait(None) if self._req_q is not None:
self._req_q.put_nowait(None)
except Exception: except Exception:
pass pass
self._thread.join(timeout=max(0.0, timeout_ms / 1000.0)) if self._reader_thread is not None:
return not self._thread.is_alive() self._reader_thread.join(timeout=max(0.0, timeout_ms / 1000.0))
proc_alive = False
if self._proc is not None:
self._proc.join(timeout=max(0.0, timeout_ms / 1000.0))
if self._proc.is_alive():
proc_alive = True
self._proc.terminate()
self._proc.join(timeout=1.0)
self._set_busy(False)
return not proc_alive
except Exception: except Exception:
self._set_busy(False)
return False return False
finally:
self._proc = None
self._reader_thread = None
self._req_q = None
self._resp_q = None
def _ensure_ocr(self) -> None: def _read_responses(self) -> None:
if self._ocr is None: """读取 OCR 子进程响应并转发为 Qt 信号。"""
logger.info("OCR ensure_ocr: 开始创建 PaddleOCR线程=%s", threading.current_thread().name)
self._ocr = create_offline_ocr(models_base_dir=self._models_base_dir)
logger.info("OCR ensure_ocr: PaddleOCR 创建完成")
self.ready.emit()
def _warmup(self) -> None:
"""提前加载 OCR 模型,避免首次识别时才初始化导致“像卡死”"""
logger.info("OCR 预热开始(线程=%s", threading.current_thread().name)
self._ensure_ocr()
logger.info("OCR 预热完成")
def _run(self) -> None:
try:
self._warmup()
except Exception as e:
logger.exception("OCR 预热失败:%s", str(e))
self.init_error.emit(str(e))
return
while not self._stop_event.is_set(): while not self._stop_event.is_set():
item = None
try: try:
item = self._queue.get() if self._resp_q is None:
except Exception: return
msg = self._resp_q.get(timeout=0.2)
except queue.Empty:
continue continue
except Exception:
if not self._stop_event.is_set():
self.init_error.emit("OCR 子进程通信失败")
return
if item is None: if not isinstance(msg, dict):
# sentinel: stop continue
break msg_type = str(msg.get("type", "")).strip()
if msg_type == "progress":
job_id, images = item job_id = msg.get("job_id", "-")
if self._stop_event.is_set(): stage = msg.get("stage", "")
break extra = []
self._process_job(job_id, images) if "images" in msg:
extra.append(f"images={msg.get('images')}")
if "texts" in msg:
extra.append(f"texts={msg.get('texts')}")
suffix = f" ({', '.join(extra)})" if extra else ""
logger.info("OCR 子进程进度 job=%s stage=%s%s", job_id, stage, suffix)
continue
if msg_type == "ready":
logger.info("OCR 子进程已就绪 pid=%s", getattr(self._proc, "pid", None))
self.ready.emit()
continue
if msg_type == "init_error":
self._set_busy(False)
self.init_error.emit(str(msg.get("error", "OCR 初始化失败")))
continue
if msg_type == "result":
self._set_busy(False)
try:
job_id = int(msg.get("job_id"))
except Exception:
job_id = -1
record = msg.get("record") if isinstance(msg.get("record"), dict) else {}
texts = msg.get("texts") if isinstance(msg.get("texts"), list) else []
self.finished.emit(job_id, record, texts)
continue
if msg_type == "error":
self._set_busy(False)
try:
job_id = int(msg.get("job_id"))
except Exception:
job_id = -1
self.error.emit(job_id, str(msg.get("error", "OCR 处理失败")))
continue
@pyqtSlot(int, object) @pyqtSlot(int, object)
def process(self, job_id: int, images: object) -> None: def process(self, job_id: int, images: object) -> None:
"""接收 UI 请求:把任务放进队列,由后台线程串行处理""" """接收 UI 请求并投递到 OCR 子进程"""
if self._stop_event.is_set(): if self._stop_event.is_set():
self.error.emit(job_id, "OCR 服务正在关闭,请稍后重试。") self.error.emit(job_id, "OCR 服务正在关闭,请稍后重试。")
return return
# 忙碌或已有排队任务时,直接拒绝,避免积压导致“看起来一直在识别” if self._proc is None or (not self._proc.is_alive()):
if self._busy or (not self._queue.empty()): self.error.emit(job_id, "OCR 服务未就绪,请稍后重试。")
return
if self._busy:
self.error.emit(job_id, "OCR 正在进行中,请稍后再试。") self.error.emit(job_id, "OCR 正在进行中,请稍后再试。")
return return
if not isinstance(images, (list, tuple)) or len(images) == 0:
self.error.emit(job_id, "内部错误:未传入有效图片数据")
return
try: try:
# 注意:这里不做耗时工作,只入队,避免阻塞 UI
self._queue.put_nowait((job_id, images))
except Exception as e:
self.error.emit(job_id, f"OCR 入队失败:{str(e)}")
def _process_job(self, job_id: int, images: object) -> None:
self._set_busy(True)
try:
self._ensure_ocr()
if not isinstance(images, (list, tuple)) or len(images) == 0:
raise ValueError("内部错误:未传入有效图片数据")
shapes = [] shapes = []
for img in images: for item in images:
img = item
source = "main"
if isinstance(item, dict):
img = item.get("img")
source = str(item.get("source", "main"))
try: try:
shapes.append(getattr(img, "shape", None)) shapes.append({"source": source, "shape": getattr(img, "shape", None)})
except Exception: except Exception:
shapes.append(None) shapes.append({"source": source, "shape": None})
logger.info("OCR job=%s 开始images=%s", job_id, shapes) logger.info("OCR job=%s 投递到子进程images=%s", job_id, shapes)
ocr_texts: list[str] = [] self._set_busy(True)
for img in images: if self._req_q is None:
if img is None: raise RuntimeError("OCR 请求队列不可用")
continue self._req_q.put_nowait((int(job_id), list(images)))
result = self._ocr.ocr(img, cls=False) except queue.Full:
if result and result[0]:
for line in result[0]:
if line and len(line) >= 2:
ocr_texts.append(line[1][0])
record = extract_info(ocr_texts)
logger.info(
"OCR job=%s 完成lines=%s, record_keys=%s",
job_id,
len(ocr_texts),
list(record.keys()),
)
self.finished.emit(job_id, record, ocr_texts)
except Exception as e:
logger.exception("OCR job=%s 失败:%s", job_id, str(e))
self.error.emit(job_id, str(e))
finally:
self._set_busy(False) self._set_busy(False)
self.error.emit(job_id, "OCR 队列已满,请稍后再试。")
except Exception as e:
self._set_busy(False)
self.error.emit(job_id, f"OCR 入队失败:{str(e)}")
class MainWindow(QMainWindow): class MainWindow(QMainWindow):
@@ -223,17 +267,22 @@ class MainWindow(QMainWindow):
# OCR 工作线程(避免 UI 卡死) # OCR 工作线程(避免 UI 卡死)
self._ocr_job_id = 0 self._ocr_job_id = 0
self._ocr_pending_job_id = None
self._ocr_start_time_by_job: dict[int, float] = {} self._ocr_start_time_by_job: dict[int, float] = {}
self._ocr_ready = False self._ocr_ready = False
self._ocr_busy = False self._ocr_busy = False
self._shutting_down = False self._shutting_down = False
self._ocr_timeout_prompted = False self._ocr_timeout_prompted = False
self._ocr_restarting = False
# 摄像头 # 摄像头
self.cap = None self.cap = None
self.timer = QTimer() self.timer = QTimer()
self.timer.timeout.connect(self.update_frame) self.timer.timeout.connect(self.update_frame)
self._frame_fail_count = 0 self._frame_fail_count = 0
self._last_frame = None
self._last_frame_ts = 0.0
self._capture_in_progress = False
# 状态栏进度(识别中显示) # 状态栏进度(识别中显示)
self._progress = QProgressBar() self._progress = QProgressBar()
@@ -252,17 +301,44 @@ class MainWindow(QMainWindow):
self.init_ui() self.init_ui()
self.load_cameras() self.load_cameras()
# 主线程预加载:在 macOS 上,必须在主线程 import paddleocr,否则后台线程会卡死 # 历史上主线程直接 import paddleocr 偶发卡死
self.statusBar().showMessage("正在加载 OCR 模块...") # 默认跳过该步骤,避免 UI 被阻塞;如需诊断可打开轻量预检(子进程 + 超时)。
QApplication.processEvents() if os.environ.get("POST_OCR_PRECHECK_IMPORT", "0").strip() == "1":
try: timeout_sec = 8
logger.info("主线程预加载import paddleocr") try:
import paddleocr # noqa: F401 timeout_sec = max(
logger.info("主线程预加载paddleocr 导入完成") 2,
except Exception as e: int(
logger.error("主线程预加载失败:%s", e, exc_info=True) os.environ.get("POST_OCR_PRECHECK_TIMEOUT_SEC", "8").strip()
QMessageBox.critical(self, "启动失败", f"无法加载 OCR 模块:{e}") or "8"
raise ),
)
except Exception:
timeout_sec = 8
self.statusBar().showMessage("正在预检 OCR 模块...")
QApplication.processEvents()
try:
logger.info("OCR 预检开始子进程timeout=%ss", timeout_sec)
proc = subprocess.run(
[sys.executable, "-c", "import paddleocr"],
capture_output=True,
text=True,
timeout=timeout_sec,
)
if proc.returncode == 0:
logger.info("OCR 预检通过")
else:
logger.warning(
"OCR 预检失败rc=%s%s",
proc.returncode,
(proc.stderr or "").strip(),
)
except subprocess.TimeoutExpired:
logger.warning("OCR 预检超时(%ss跳过预检继续启动。", timeout_sec)
except Exception as e:
logger.warning("OCR 预检异常:%s(忽略并继续)", str(e))
else:
logger.info("已跳过主线程 OCR 预检POST_OCR_PRECHECK_IMPORT=0")
# OCR 服务放在 UI 初始化之后启动,避免 ready/busy 信号回调时 btn_capture 尚未创建 # OCR 服务放在 UI 初始化之后启动,避免 ready/busy 信号回调时 btn_capture 尚未创建
self.statusBar().showMessage("正在启动 OCR 服务...") self.statusBar().showMessage("正在启动 OCR 服务...")
@@ -308,6 +384,8 @@ class MainWindow(QMainWindow):
self._ocr_ready = False self._ocr_ready = False
self._ocr_busy = False self._ocr_busy = False
self._ocr_timeout_prompted = False self._ocr_timeout_prompted = False
self._ocr_pending_job_id = None
self._ocr_start_time_by_job.clear()
try: try:
self._progress.setVisible(False) self._progress.setVisible(False)
except Exception: except Exception:
@@ -316,10 +394,13 @@ class MainWindow(QMainWindow):
try: try:
svc = getattr(self, "_ocr_service", None) svc = getattr(self, "_ocr_service", None)
if svc is not None: if svc is not None:
try:
self.request_ocr.disconnect(svc.process)
except Exception:
pass
ok = svc.stop(timeout_ms=8000 if force else 3000) ok = svc.stop(timeout_ms=8000 if force else 3000)
if (not ok) and force: if (not ok) and force:
# Python 线程无法可靠“强杀”,这里只做提示并继续退出流程。 logger.warning("OCR 服务停止超时:子进程可能仍在退出中,建议重启应用。")
logger.warning("OCR 服务停止超时:后台线程可能仍在运行,建议重启应用。")
except Exception: except Exception:
pass pass
@@ -333,9 +414,15 @@ class MainWindow(QMainWindow):
if self._shutting_down: if self._shutting_down:
return return
self.statusBar().showMessage("正在重启 OCR 服务...") if self._ocr_restarting:
self._stop_ocr_service(force=True) return
self._init_ocr_service() self._ocr_restarting = True
try:
self.statusBar().showMessage("正在重启 OCR 服务...")
self._stop_ocr_service(force=True)
self._init_ocr_service()
finally:
self._ocr_restarting = False
def _init_ocr_service(self) -> None: def _init_ocr_service(self) -> None:
models_dir = get_models_base_dir() models_dir = get_models_base_dir()
@@ -347,7 +434,7 @@ class MainWindow(QMainWindow):
self._ocr_service = OCRService(models_base_dir=models_dir) self._ocr_service = OCRService(models_base_dir=models_dir)
# 注意OCRService 内部使用 Python 线程做 warmup 与推理。 # 注意OCRService 内部使用独立子进程做 warmup 与推理。
# 这里强制使用 QueuedConnection确保 UI 回调始终在主线程执行。 # 这里强制使用 QueuedConnection确保 UI 回调始终在主线程执行。
self.request_ocr.connect(self._ocr_service.process, Qt.ConnectionType.QueuedConnection) self.request_ocr.connect(self._ocr_service.process, Qt.ConnectionType.QueuedConnection)
self._ocr_service.ready.connect(self._on_ocr_ready, Qt.ConnectionType.QueuedConnection) self._ocr_service.ready.connect(self._on_ocr_ready, Qt.ConnectionType.QueuedConnection)
@@ -378,6 +465,8 @@ class MainWindow(QMainWindow):
try: try:
self._ocr_busy = busy self._ocr_busy = busy
if busy: if busy:
# OCR 线程已开始处理,提交阶段不再算“待接收”
self._ocr_pending_job_id = None
self._progress.setRange(0, 0) # 不确定进度条 self._progress.setRange(0, 0) # 不确定进度条
self._progress.setVisible(True) self._progress.setVisible(True)
self._ocr_timeout_prompted = False self._ocr_timeout_prompted = False
@@ -391,8 +480,27 @@ class MainWindow(QMainWindow):
except Exception as e: except Exception as e:
logger.exception("处理 OCR busy 回调失败:%s", str(e)) logger.exception("处理 OCR busy 回调失败:%s", str(e))
def _guard_ocr_submission(self, job_id: int) -> None:
"""
兜底保护:
如果提交后一段时间仍未进入 busy 状态,说明任务可能未被 OCR 线程接收,
主动恢复按钮,避免界面一直停留在“正在识别...”。
"""
if job_id != self._ocr_pending_job_id:
return
if self._ocr_busy:
return
self._ocr_pending_job_id = None
self._ocr_start_time_by_job.pop(job_id, None)
logger.warning("OCR job=%s 提交后未被接收,已自动恢复 UI 状态", job_id)
self.statusBar().showMessage("识别请求未被处理,请重试一次(已自动恢复)")
if self.btn_capture is not None:
self.btn_capture.setEnabled(self.cap is not None and self._ocr_ready)
def _tick_ocr_watchdog(self) -> None: def _tick_ocr_watchdog(self) -> None:
"""识别进行中:更新耗时,超时则提示是否重启 OCR 服务。""" """识别进行中:更新耗时,超时自动重启 OCR 服务。"""
if not self._ocr_busy: if not self._ocr_busy:
return return
@@ -402,19 +510,30 @@ class MainWindow(QMainWindow):
cost = time.monotonic() - start_t cost = time.monotonic() - start_t
self.statusBar().showMessage(f"正在识别...(已用 {cost:.1f}s") self.statusBar().showMessage(f"正在识别...(已用 {cost:.1f}s")
# 超时保护:底层推理偶发卡住时,让用户可以自救 # 超时保护:底层推理偶发卡住时,自动重启 OCR 服务并恢复可用状态
if cost >= 45 and not self._ocr_timeout_prompted: timeout_sec = 25
try:
timeout_sec = max(
8, int(os.environ.get("POST_OCR_JOB_TIMEOUT_SEC", "25").strip() or "25")
)
except Exception:
timeout_sec = 25
if cost >= timeout_sec and not self._ocr_timeout_prompted:
self._ocr_timeout_prompted = True self._ocr_timeout_prompted = True
reply = QMessageBox.question( logger.warning("OCR job=%s 超时 %.1fs自动重启 OCR 服务", self._ocr_job_id, cost)
self.statusBar().showMessage(f"识别超时({cost:.1f}s正在自动恢复...")
# 当前任务视为失败并回收,避免界面一直等待结果
self._ocr_start_time_by_job.pop(self._ocr_job_id, None)
self._restart_ocr_service()
QMessageBox.warning(
self, self,
"识别超时", "识别超时",
"识别已超过 45 秒仍未完成,可能卡住。\n\n是否重启 OCR 服务\n(若仍无响应,建议直接退出并重新打开应用)", "本次识别超时,已自动重启 OCR 服务\n请再次拍照识别。",
QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
) )
if reply == QMessageBox.StandardButton.Yes:
self._restart_ocr_service()
def _on_ocr_finished_job(self, job_id: int, record: dict, texts: list) -> None: def _on_ocr_finished_job(self, job_id: int, record: dict, texts: list) -> None:
if self._ocr_pending_job_id == job_id:
self._ocr_pending_job_id = None
start_t = self._ocr_start_time_by_job.pop(job_id, None) start_t = self._ocr_start_time_by_job.pop(job_id, None)
# 只处理最新一次请求,避免旧结果回写 # 只处理最新一次请求,避免旧结果回写
@@ -428,14 +547,18 @@ class MainWindow(QMainWindow):
cost = f"(耗时 {time.monotonic() - start_t:.1f}s" cost = f"(耗时 {time.monotonic() - start_t:.1f}s"
self.statusBar().showMessage(f"识别完成: {record.get('联系人/单位名', '未知')}{cost}") self.statusBar().showMessage(f"识别完成: {record.get('联系人/单位名', '未知')}{cost}")
logger.info("OCR job=%s UI 回写完成 %s", job_id, cost) logger.info("OCR job=%s UI 回写完成 %s", job_id, cost)
self.btn_capture.setEnabled(self.cap is not None and self._ocr_ready and not self._ocr_busy)
def _on_ocr_error_job(self, job_id: int, error: str) -> None: def _on_ocr_error_job(self, job_id: int, error: str) -> None:
if self._ocr_pending_job_id == job_id:
self._ocr_pending_job_id = None
self._ocr_start_time_by_job.pop(job_id, None) self._ocr_start_time_by_job.pop(job_id, None)
if job_id != self._ocr_job_id: if job_id != self._ocr_job_id:
return return
self.statusBar().showMessage("识别失败") self.statusBar().showMessage("识别失败")
QMessageBox.warning(self, "识别失败", error) QMessageBox.warning(self, "识别失败", error)
logger.error("OCR job=%s error: %s", job_id, error) logger.error("OCR job=%s error: %s", job_id, error)
self.btn_capture.setEnabled(self.cap is not None and self._ocr_ready and not self._ocr_busy)
def init_ui(self): def init_ui(self):
central = QWidget() central = QWidget()
@@ -519,6 +642,7 @@ class MainWindow(QMainWindow):
# macOS/Qt 下 Space 经常被控件吞掉(按钮激活/表格选择等),用 ApplicationShortcut 更稳 # macOS/Qt 下 Space 经常被控件吞掉(按钮激活/表格选择等),用 ApplicationShortcut 更稳
self._shortcut_capture2 = QShortcut(QKeySequence("Space"), self) self._shortcut_capture2 = QShortcut(QKeySequence("Space"), self)
self._shortcut_capture2.setContext(Qt.ShortcutContext.ApplicationShortcut) self._shortcut_capture2.setContext(Qt.ShortcutContext.ApplicationShortcut)
self._shortcut_capture2.setAutoRepeat(False)
self._shortcut_capture2.activated.connect(self.capture_and_recognize) self._shortcut_capture2.activated.connect(self.capture_and_recognize)
def load_cameras(self): def load_cameras(self):
@@ -770,6 +894,13 @@ class MainWindow(QMainWindow):
ret, frame = self.cap.read() ret, frame = self.cap.read()
if ret and frame is not None and frame.size > 0: if ret and frame is not None and frame.size > 0:
self._frame_fail_count = 0 self._frame_fail_count = 0
# 缓存原始帧,拍照时直接使用,避免按空格再读摄像头导致主线程阻塞
try:
self._last_frame = frame.copy()
self._last_frame_ts = time.monotonic()
except Exception:
self._last_frame = frame
self._last_frame_ts = time.monotonic()
# 绘制扫描框 # 绘制扫描框
h, w = frame.shape[:2] h, w = frame.shape[:2]
# 框的位置:上方 70%,编号在下方 # 框的位置:上方 70%,编号在下方
@@ -812,6 +943,9 @@ class MainWindow(QMainWindow):
def capture_and_recognize(self): def capture_and_recognize(self):
"""拍照并识别""" """拍照并识别"""
if self._capture_in_progress:
self.statusBar().showMessage("正在拍照,请稍候")
return
if self.cap is None: if self.cap is None:
self.statusBar().showMessage("请先连接摄像头") self.statusBar().showMessage("请先连接摄像头")
return return
@@ -822,61 +956,126 @@ class MainWindow(QMainWindow):
self.statusBar().showMessage("正在识别中,请稍后再按空格") self.statusBar().showMessage("正在识别中,请稍后再按空格")
return return
ret, frame = self.cap.read() self._capture_in_progress = True
if not ret:
self.statusBar().showMessage("拍照失败")
return
# 裁剪两块 ROI主信息框 + 编号区域),显著减小像素量,提升速度与稳定性
h, w = frame.shape[:2]
x1, y1 = int(w * 0.06), int(h * 0.08)
x2 = int(w * 0.94)
y2_box = int(h * 0.78)
roi_images = []
try: try:
roi_box = frame[y1:y2_box, x1:x2] # 直接使用预览缓存帧,避免在按键回调中阻塞式 read 摄像头导致卡顿
if roi_box is not None and roi_box.size > 0: frame = None
roi_images.append(roi_box) now = time.monotonic()
except Exception: if self._last_frame is not None and (now - self._last_frame_ts) <= 1.5:
pass try:
frame = self._last_frame.copy()
except Exception:
frame = self._last_frame
try: if frame is None:
# 编号一般在底部中间,取较小区域即可 self.statusBar().showMessage("尚未拿到稳定画面,请稍后再按空格")
nx1, nx2 = int(w * 0.30), int(w * 0.70) return
ny1, ny2 = int(h * 0.80), int(h * 0.98)
roi_num = frame[ny1:ny2, nx1:nx2]
if roi_num is not None and roi_num.size > 0:
roi_images.append(roi_num)
except Exception:
pass
if not roi_images: # 裁剪主信息 ROI 与编号 ROI
self.statusBar().showMessage("拍照失败:未截取到有效区域") h, w = frame.shape[:2]
return x1, y1 = int(w * 0.06), int(h * 0.08)
x2 = int(w * 0.94)
y2_box = int(h * 0.78)
# 超大分辨率下适当缩放(提高稳定性与速度) roi_inputs = []
resized_images = []
for img in roi_images:
try: try:
max_w = 1400 roi_box = frame[y1:y2_box, x1:x2]
if img.shape[1] > max_w: if roi_box is not None and roi_box.size > 0:
scale = max_w / img.shape[1] # 主信息区域切成多段,规避大图整块检测偶发卡住
img = cv2.resize(img, (int(img.shape[1] * scale), int(img.shape[0] * scale))) split_count = 2
try:
split_count = max(
1,
int(
os.environ.get("POST_OCR_MAIN_SPLIT", "2").strip()
or "2"
),
)
except Exception:
split_count = 2
split_count = min(split_count, 4)
if split_count <= 1 or roi_box.shape[0] < 120:
roi_inputs.append({"img": roi_box, "source": "main"})
else:
h_box = roi_box.shape[0]
step = h_box / float(split_count)
overlap = max(8, int(h_box * 0.06))
for i in range(split_count):
sy = int(max(0, i * step - (overlap if i > 0 else 0)))
ey = int(
min(
h_box,
(i + 1) * step
+ (overlap if i < split_count - 1 else 0),
)
)
part = roi_box[sy:ey, :]
if part is not None and part.size > 0:
roi_inputs.append({"img": part, "source": "main"})
except Exception: except Exception:
pass pass
resized_images.append(img)
logger.info("UI 触发识别frame=%s, rois=%s", getattr(frame, "shape", None), [getattr(i, "shape", None) for i in resized_images]) try:
# 编号一般在底部中间,取较小区域即可
nx1, nx2 = int(w * 0.30), int(w * 0.70)
ny1, ny2 = int(h * 0.80), int(h * 0.98)
roi_num = frame[ny1:ny2, nx1:nx2]
if roi_num is not None and roi_num.size > 0:
roi_inputs.append({"img": roi_num, "source": "number"})
except Exception:
pass
self.statusBar().showMessage("正在识别...") if not roi_inputs:
self.btn_capture.setEnabled(False) self.statusBar().showMessage("拍照失败:未截取到有效区域")
return
# 派发到 OCR 工作线程 # 超大分辨率下适当缩放(提高稳定性与速度)
self._ocr_job_id += 1 resized_inputs = []
job_id = self._ocr_job_id max_w = 960
self._ocr_start_time_by_job[job_id] = time.monotonic() try:
self.request_ocr.emit(job_id, resized_images) max_w = max(
600, int(os.environ.get("POST_OCR_MAX_ROI_WIDTH", "960").strip() or "960")
)
except Exception:
max_w = 960
for item in roi_inputs:
img = item.get("img")
source = item.get("source", "main")
try:
if img is not None and img.shape[1] > max_w:
scale = max_w / img.shape[1]
img = cv2.resize(img, (int(img.shape[1] * scale), int(img.shape[0] * scale)))
except Exception:
pass
resized_inputs.append({"img": img, "source": source})
logger.info(
"UI 触发识别frame=%s, rois=%s, frame_age=%.3fs",
getattr(frame, "shape", None),
[
{
"source": item.get("source", "main"),
"shape": getattr(item.get("img"), "shape", None),
}
for item in resized_inputs
],
max(0.0, now - self._last_frame_ts),
)
self.statusBar().showMessage("正在识别...")
self.btn_capture.setEnabled(False)
# 派发到 OCR 工作线程
self._ocr_job_id += 1
job_id = self._ocr_job_id
self._ocr_pending_job_id = job_id
self._ocr_start_time_by_job[job_id] = time.monotonic()
self.request_ocr.emit(job_id, resized_inputs)
QTimer.singleShot(2000, lambda j=job_id: self._guard_ocr_submission(j))
finally:
self._capture_in_progress = False
def update_table(self): def update_table(self):
"""更新表格""" """更新表格"""

View File

@@ -40,15 +40,31 @@ def main():
# 2. 提取文字行 # 2. 提取文字行
ocr_texts = [] ocr_texts = []
ocr_lines = []
if result and result[0]: if result and result[0]:
for line in result[0]: for line in result[0]:
# line 格式: [box, (text, confidence)] # line 格式: [box, (text, confidence)]
if line and len(line) >= 2: if line and len(line) >= 2:
ocr_texts.append(line[1][0]) text = str(line[1][0])
ocr_texts.append(text)
conf = None
try:
conf = float(line[1][1])
except Exception:
conf = None
ocr_lines.append(
{
"text": text,
"box": line[0],
"conf": conf,
"source": "main",
"roi_index": 0,
}
)
# 3. 结构化解析 # 3. 结构化解析
if ocr_texts: if ocr_texts:
record = extract_info(ocr_texts) record = extract_info(ocr_lines if ocr_lines else ocr_texts)
all_records.append(record) all_records.append(record)
else: else:
errors.append( errors.append(

View File

@@ -118,7 +118,21 @@ def create_offline_ocr(models_base_dir: Path | None = None):
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
# 构建 PaddleOCR 参数 # 构建 PaddleOCR 参数
# 说明:在部分 macOS/CPU 环境下oneDNN(MKLDNN) 可能出现卡住,默认关闭以换取稳定性。
kwargs = dict(lang="ch", use_angle_cls=False, show_log=False) kwargs = dict(lang="ch", use_angle_cls=False, show_log=False)
disable_mkldnn = os.environ.get("POST_OCR_DISABLE_MKLDNN", "1").strip() == "1"
if disable_mkldnn:
os.environ["FLAGS_use_mkldnn"] = "0"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
kwargs["enable_mkldnn"] = False
try:
kwargs["cpu_threads"] = max(
1, int(os.environ.get("POST_OCR_CPU_THREADS", "1").strip() or "1")
)
except Exception:
kwargs["cpu_threads"] = 1
# 如果 models/ 目录存在离线模型,显式指定路径(打包分发场景) # 如果 models/ 目录存在离线模型,显式指定路径(打包分发场景)
models_dir = models_base_dir or get_models_base_dir() models_dir = models_base_dir or get_models_base_dir()
@@ -134,6 +148,12 @@ def create_offline_ocr(models_base_dir: Path | None = None):
log.info("未找到离线模型,将使用默认路径(可能需要联网下载)") log.info("未找到离线模型,将使用默认路径(可能需要联网下载)")
log.info("create_offline_ocr: creating PaddleOCR(lang=ch)") log.info("create_offline_ocr: creating PaddleOCR(lang=ch)")
ocr = PaddleOCR(**kwargs) try:
ocr = PaddleOCR(**kwargs)
except TypeError:
# 兼容个别 PaddleOCR 版本不支持的参数
kwargs.pop("enable_mkldnn", None)
kwargs.pop("cpu_threads", None)
ocr = PaddleOCR(**kwargs)
log.info("create_offline_ocr: PaddleOCR created") log.info("create_offline_ocr: PaddleOCR created")
return ocr return ocr

92
src/ocr_worker_process.py Normal file
View File

@@ -0,0 +1,92 @@
from __future__ import annotations
# 必须在所有 paddle/numpy import 之前设置,否则 macOS spawn 子进程推理会死锁
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["FLAGS_use_mkldnn"] = "0"
os.environ["PADDLE_DISABLE_SIGNAL_HANDLER"] = "1"
os.environ["PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK"] = "True"
from pathlib import Path
from typing import Any
from ocr_offline import create_offline_ocr
from processor import extract_info
def run_ocr_worker(models_base_dir: str, request_q, response_q) -> None:
"""
OCR 子进程主循环:
- 在子进程内初始化 PaddleOCR避免阻塞主 UI 进程
- 接收任务并返回结构化结果
"""
try:
response_q.put({"type": "progress", "stage": "init_start"})
ocr = create_offline_ocr(models_base_dir=Path(models_base_dir))
response_q.put({"type": "ready"})
except Exception as e:
response_q.put({"type": "init_error", "error": str(e)})
return
while True:
item = request_q.get()
if item is None:
break
job_id = -1
try:
job_id, images = item
if not isinstance(images, (list, tuple)) or len(images) == 0:
raise ValueError("内部错误:未传入有效图片数据")
response_q.put({"type": "progress", "job_id": int(job_id), "stage": "job_received", "images": len(images)})
ocr_texts: list[str] = []
ocr_lines: list[dict[str, Any]] = []
for roi_index, entry in enumerate(images):
source = "main"
img = entry
if isinstance(entry, dict):
source = str(entry.get("source", "main"))
img = entry.get("img")
elif roi_index > 0:
source = "number"
if img is None:
continue
response_q.put({"type": "progress", "job_id": int(job_id), "stage": f"roi_{roi_index}_start"})
result = ocr.ocr(img, cls=False)
response_q.put({"type": "progress", "job_id": int(job_id), "stage": f"roi_{roi_index}_done"})
if result and result[0]:
for line in result[0]:
if line and len(line) >= 2:
text = str(line[1][0])
ocr_texts.append(text)
conf = None
try:
conf = float(line[1][1])
except Exception:
conf = None
ocr_lines.append(
{
"text": text,
"box": line[0],
"conf": conf,
"source": source,
"roi_index": roi_index,
}
)
record = extract_info(ocr_lines if ocr_lines else ocr_texts)
response_q.put({"type": "progress", "job_id": int(job_id), "stage": "parse_done", "texts": len(ocr_texts)})
response_q.put(
{
"type": "result",
"job_id": int(job_id),
"record": record,
"texts": ocr_texts,
}
)
except Exception as e:
response_q.put({"type": "error", "job_id": int(job_id), "error": str(e)})

View File

@@ -1,9 +1,64 @@
import re import re
from dataclasses import dataclass
from statistics import median
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
import pandas as pd import pandas as pd
from typing import List, Dict, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
ZIP_PATTERN = re.compile(r"(?<!\d)(\d{6})(?!\d)")
PHONE_PATTERN = re.compile(r"(?<!\d)(1[3-9]\d{9}|0\d{2,3}-?\d{7,8})(?!\d)")
LONG_NUMBER_PATTERN = re.compile(r"(?<!\d)(\d{10,20})(?!\d)")
ADDRESS_HINT_PATTERN = re.compile(r"(省|市|区|县|乡|镇|街|路|村|号|栋|单元|室)")
@dataclass
class OCRLine:
text: str
source: str
order: int
x1: Optional[float] = None
y1: Optional[float] = None
x2: Optional[float] = None
y2: Optional[float] = None
row_idx: int = -1
col_idx: int = -1
@property
def has_pos(self) -> bool:
return (
self.x1 is not None
and self.y1 is not None
and self.x2 is not None
and self.y2 is not None
)
@property
def cx(self) -> float:
if not self.has_pos:
return float(self.order)
return (self.x1 + self.x2) / 2.0 # type: ignore[operator]
@property
def cy(self) -> float:
if not self.has_pos:
return float(self.order)
return (self.y1 + self.y2) / 2.0 # type: ignore[operator]
@property
def height(self) -> float:
if not self.has_pos:
return 0.0
return max(0.0, float(self.y2) - float(self.y1))
@property
def width(self) -> float:
if not self.has_pos:
return 0.0
return max(0.0, float(self.x2) - float(self.x1))
class EnvelopeRecord(BaseModel): class EnvelopeRecord(BaseModel):
编号: str = "" 编号: str = ""
邮编: str = "" 邮编: str = ""
@@ -13,75 +68,412 @@ class EnvelopeRecord(BaseModel):
def clean_text(text: str) -> str: def clean_text(text: str) -> str:
"""清理OCR识别出的杂质字符""" """清理 OCR 识别文本中的空白和无意义分隔符。"""
return text.strip().replace(" ", "") if not text:
return ""
text = text.replace("\u3000", " ").strip()
return re.sub(r"\s+", "", text)
def extract_info(ocr_results: List[str]) -> Dict[str, str]: def _parse_box(raw_box: Any) -> Tuple[Optional[float], Optional[float], Optional[float], Optional[float]]:
if not isinstance(raw_box, (list, tuple)) or len(raw_box) < 4:
return None, None, None, None
xs: List[float] = []
ys: List[float] = []
for p in raw_box:
if not isinstance(p, (list, tuple)) or len(p) < 2:
continue
try:
xs.append(float(p[0]))
ys.append(float(p[1]))
except Exception:
continue
if not xs or not ys:
return None, None, None, None
return min(xs), min(ys), max(xs), max(ys)
def _to_ocr_line(item: Any, idx: int) -> Optional[OCRLine]:
if isinstance(item, str):
text = clean_text(item)
if not text:
return None
return OCRLine(text=text, source="main", order=idx)
if not isinstance(item, dict):
return None
text = clean_text(str(item.get("text", "")))
if not text:
return None
source = str(item.get("source", "main"))
x1, y1, x2, y2 = _parse_box(item.get("box"))
if x1 is None:
# 兼容直接传坐标的输入
try:
x1 = float(item.get("x1"))
y1 = float(item.get("y1"))
x2 = float(item.get("x2"))
y2 = float(item.get("y2"))
except Exception:
x1, y1, x2, y2 = None, None, None, None
return OCRLine(text=text, source=source, order=idx, x1=x1, y1=y1, x2=x2, y2=y2)
def _normalize_ocr_results(ocr_results: Sequence[Any]) -> List[OCRLine]:
lines: List[OCRLine] = []
seen = set()
for idx, item in enumerate(ocr_results):
line = _to_ocr_line(item, idx)
if line is None:
continue
if line.has_pos:
key = (
line.text,
line.source,
round(line.cx, 1),
round(line.cy, 1),
)
else:
key = (line.text, line.source, line.order)
if key in seen:
continue
seen.add(key)
lines.append(line)
return lines
def _first_match(pattern: re.Pattern[str], text: str) -> str:
m = pattern.search(text)
if not m:
return ""
if m.lastindex:
return m.group(1)
return m.group(0)
def _find_anchor(lines: Iterable[OCRLine], pattern: re.Pattern[str], prefer_bottom: bool) -> Optional[Tuple[OCRLine, str]]:
candidates: List[Tuple[OCRLine, str]] = []
for line in lines:
m = pattern.search(line.text)
if not m:
continue
token = m.group(1) if m.lastindex else m.group(0)
candidates.append((line, token))
if not candidates:
return None
if prefer_bottom:
return max(candidates, key=lambda item: (item[0].row_idx, item[0].cy, item[0].cx, item[0].order))
return min(candidates, key=lambda item: (item[0].row_idx, item[0].cy, item[0].cx, item[0].order))
def _build_rows(lines: List[OCRLine]) -> List[List[OCRLine]]:
positioned = [line for line in lines if line.has_pos]
if not positioned:
return []
positioned.sort(key=lambda line: (line.cy, line.cx))
heights = [line.height for line in positioned if line.height > 1.0]
h_med = median(heights) if heights else 20.0
y_threshold = max(8.0, h_med * 0.65)
rows: List[List[OCRLine]] = []
for line in positioned:
if not rows:
rows.append([line])
continue
row = rows[-1]
mean_y = sum(item.cy for item in row) / len(row)
if abs(line.cy - mean_y) <= y_threshold:
row.append(line)
else:
rows.append([line])
for row_idx, row in enumerate(rows):
row.sort(key=lambda line: (line.cx, line.x1 or 0.0))
for col_idx, line in enumerate(row):
line.row_idx = row_idx
line.col_idx = col_idx
return rows
def _sanitize_address(text: str) -> str:
text = clean_text(text)
text = re.sub(r"^(地址|收件地址|详细地址)[:]?", "", text)
return text
def _sanitize_contact(text: str) -> str:
text = clean_text(text)
text = re.sub(r"^(收件人|联系人|单位|收)[:]?", "", text)
return text.strip(",。;;:")
def _join_entries(entries: List[Tuple[int, int, str]]) -> str:
if not entries:
return ""
entries.sort(key=lambda item: (item[0], item[1]))
merged: List[str] = []
for _, _, text in entries:
if not text:
continue
if merged and merged[-1] == text:
continue
merged.append(text)
return "".join(merged)
def _extract_tracking_number(lines: List[OCRLine], zip_code: str, phone: str) -> str:
phone_digits = re.sub(r"\D", "", phone)
candidates: List[Tuple[int, int, str]] = []
for line in lines:
for match in LONG_NUMBER_PATTERN.finditer(line.text):
number = match.group(1)
if number == zip_code:
continue
if phone and (number == phone or number == phone_digits):
continue
src_score = 2 if line.source == "number" else 1
candidates.append((src_score, len(number), number))
if not candidates:
return ""
candidates.sort(reverse=True)
return candidates[0][2]
def _extract_with_layout(lines: List[OCRLine], data: Dict[str, str]) -> Tuple[str, str, bool]:
main_lines = [line for line in lines if line.source != "number"]
if len(main_lines) < 2:
return "", "", False
rows = _build_rows(main_lines)
if not rows:
return "", "", False
zip_anchor = _find_anchor(main_lines, ZIP_PATTERN, prefer_bottom=False)
phone_anchor = _find_anchor(main_lines, PHONE_PATTERN, prefer_bottom=True)
if zip_anchor and not data["邮编"]:
data["邮编"] = zip_anchor[1]
if phone_anchor and not data["电话"]:
data["电话"] = phone_anchor[1]
if zip_anchor:
start_row = zip_anchor[0].row_idx
else:
start_row = min(line.row_idx for line in main_lines)
if phone_anchor:
end_row = phone_anchor[0].row_idx
else:
end_row = max(line.row_idx for line in main_lines)
if start_row > end_row:
start_row, end_row = end_row, start_row
single_column_mode = False
if zip_anchor and phone_anchor:
line_widths = [line.width for line in main_lines if line.width > 0]
width_ref = median(line_widths) if line_widths else 120.0
single_column_mode = abs(phone_anchor[0].cx - zip_anchor[0].cx) < max(60.0, width_ref * 0.6)
if zip_anchor and phone_anchor and phone_anchor[0].cx > zip_anchor[0].cx and not single_column_mode:
split_x = (zip_anchor[0].cx + phone_anchor[0].cx) / 2.0
elif phone_anchor:
split_x = phone_anchor[0].cx - max(40.0, phone_anchor[0].width * 0.6)
elif zip_anchor:
split_x = zip_anchor[0].cx + max(80.0, zip_anchor[0].width * 1.5)
else:
split_x = median([line.cx for line in main_lines])
address_entries: List[Tuple[int, int, str]] = []
contact_entries: List[Tuple[int, int, str]] = []
for line in main_lines:
if line.row_idx < start_row or line.row_idx > end_row:
continue
text = line.text
if zip_anchor and line is zip_anchor[0]:
text = text.replace(zip_anchor[1], "")
if phone_anchor and line is phone_anchor[0]:
text = text.replace(phone_anchor[1], "")
text = clean_text(text)
if not text:
continue
if re.fullmatch(r"\d{6,20}", text):
continue
if single_column_mode:
if phone_anchor and line is phone_anchor[0]:
contact_entries.append((line.row_idx, line.col_idx, text))
else:
address_entries.append((line.row_idx, line.col_idx, text))
continue
if line.cx <= split_x:
address_entries.append((line.row_idx, line.col_idx, text))
else:
contact_entries.append((line.row_idx, line.col_idx, text))
# 联系人优先取靠近电话的一段,降低把地址误分到联系人的概率
if phone_anchor and contact_entries:
phone_row = phone_anchor[0].row_idx
min_dist = min(abs(item[0] - phone_row) for item in contact_entries)
contact_entries = [
item for item in contact_entries if abs(item[0] - phone_row) <= min_dist + 1
]
contact_text = _sanitize_contact(_join_entries(contact_entries))
address_text = _sanitize_address(_join_entries(address_entries))
# 如果联系人仍为空,尝试从“电话所在行去掉电话号码”的残余文本提取
if not contact_text and phone_anchor:
fallback_contact = clean_text(phone_anchor[0].text.replace(phone_anchor[1], ""))
if fallback_contact and not re.fullmatch(r"\d{2,20}", fallback_contact):
contact_text = _sanitize_contact(fallback_contact)
# 若仍缺联系人,尝试从靠近电话的地址候选中回退一行
if not contact_text and phone_anchor and address_entries:
phone_row = phone_anchor[0].row_idx
sorted_candidates = sorted(
address_entries,
key=lambda item: (abs(item[0] - phone_row), -item[0], item[1]),
)
for row_idx, col_idx, txt in sorted_candidates:
if ADDRESS_HINT_PATTERN.search(txt):
continue
contact_text = _sanitize_contact(txt)
if contact_text:
address_entries = [
item
for item in address_entries
if not (item[0] == row_idx and item[1] == col_idx and item[2] == txt)
]
address_text = _sanitize_address(_join_entries(address_entries))
break
has_signal = bool(zip_anchor or phone_anchor)
return address_text, contact_text, has_signal
def _extract_with_text_order(lines: List[OCRLine], data: Dict[str, str]) -> Tuple[str, str, bool]:
if not lines:
return "", "", False
zip_idx = -1
zip_token = ""
for idx, line in enumerate(lines):
m = ZIP_PATTERN.search(line.text)
if m:
zip_idx = idx
zip_token = m.group(1)
break
phone_idx = -1
phone_token = ""
for idx in range(len(lines) - 1, -1, -1):
m = PHONE_PATTERN.search(lines[idx].text)
if m:
phone_idx = idx
phone_token = m.group(1)
break
if zip_idx < 0 or phone_idx < 0 or zip_idx > phone_idx:
return "", "", False
if not data["邮编"]:
data["邮编"] = zip_token
if not data["电话"]:
data["电话"] = phone_token
address_parts: List[Tuple[int, str]] = []
contact_text = ""
for idx in range(zip_idx, phone_idx + 1):
text = lines[idx].text
if idx == zip_idx:
text = text.replace(zip_token, "")
if idx == phone_idx:
text = text.replace(phone_token, "")
text = clean_text(text)
if not text:
continue
if idx == phone_idx:
contact_text = _sanitize_contact(text)
else:
address_parts.append((idx, text))
if not contact_text and address_parts:
for idx, text in reversed(address_parts):
if ADDRESS_HINT_PATTERN.search(text):
continue
contact_text = _sanitize_contact(text)
if contact_text:
address_parts = [item for item in address_parts if item[0] != idx]
break
address_text = _sanitize_address("".join(text for _, text in address_parts))
return address_text, contact_text, True
def extract_info(ocr_results: List[Any]) -> Dict[str, str]:
""" """
从OCR结果列表中提取结构化信息。 OCR 结果中提取结构化信息。
支持两类输入:
1. 纯文本列表:`List[str]`
2. 带坐标的行对象列表:`List[{"text": "...", "box": [[x,y],...], "source": "..."}]`
""" """
data = {"编号": "", "邮编": "", "地址": "", "联系人/单位名": "", "电话": ""} data = {"编号": "", "邮编": "", "地址": "", "联系人/单位名": "", "电话": ""}
lines = _normalize_ocr_results(ocr_results)
if not lines:
return data
full_content = " ".join(ocr_results) full_content = " ".join(line.text for line in lines)
data["邮编"] = _first_match(ZIP_PATTERN, full_content)
data["电话"] = _first_match(PHONE_PATTERN, full_content)
data["编号"] = _extract_tracking_number(lines, data["邮编"], data["电话"])
# 1. 提取邮编 (6位数字) # 第一优先级:使用版面坐标进行“邮编-电话锚点 + 连续块”解析
zip_match = re.search(r"\b(\d{6})\b", full_content) address_text, contact_text, used_layout = _extract_with_layout(lines, data)
if zip_match: if not used_layout:
data["邮编"] = zip_match.group(1) # 第二优先级:无坐标时按文本顺序回退
address_text, contact_text, _ = _extract_with_text_order(lines, data)
# 2. 提取电话 (11位手机号或带区号固话) data["地址"] = _sanitize_address(address_text)
phone_match = re.search(r"(1[3-9]\d{9}|0\d{2,3}-\d{7,8})", full_content) data["联系人/单位名"] = _sanitize_contact(contact_text)
if phone_match:
data["电话"] = phone_match.group(0)
# 3. 提取联系人 (通常在电话前面,或者是独立的短行) # 最终兜底:联系人和地址任一为空时,补旧规则避免完全丢字段
# 遍历每一行寻找包含电话的行
for line in ocr_results:
if data["电话"] and data["电话"] in line:
# 移除电话部分,剩下的可能是姓名
name_part = line.replace(data["电话"], "").strip()
# 进一步清洗姓名(移除符号)
name_part = re.sub(r"[^\w\u4e00-\u9fa5]", "", name_part)
if name_part:
data["联系人/单位名"] = name_part
break
# 如果还没找到联系人,尝试找不含数字的短行
if not data["联系人/单位名"]: if not data["联系人/单位名"]:
for line in ocr_results: for line in lines:
clean_line = re.sub(r"[^\w\u4e00-\u9fa5]", "", line) text = clean_text(line.text)
if 2 <= len(clean_line) <= 10 and not re.search(r"\d", clean_line): if not text:
data["联系人/单位名"] = clean_line continue
break if data["电话"] and data["电话"] in text:
name_part = _sanitize_contact(text.replace(data["电话"], ""))
if name_part:
data["联系人/单位名"] = name_part
break
if not data["联系人/单位名"]:
for line in lines:
text = clean_text(line.text)
if 2 <= len(text) <= 20 and not re.search(r"\d", text):
data["联系人/单位名"] = _sanitize_contact(text)
break
# 4. 提取地址 if not data["地址"]:
address_match = re.search( hint_lines = [line.text for line in lines if ADDRESS_HINT_PATTERN.search(line.text)]
r"([^,,。\s]*(?:省|市|区|县|乡|镇|路|街|村|组|号)[^,,。\s]*)", full_content if hint_lines:
) hint_lines.sort(key=lambda txt: len(clean_text(txt)), reverse=True)
if address_match: data["地址"] = _sanitize_address(hint_lines[0])
data["地址"] = address_match.group(1)
else:
# 兜底:寻找较长的包含地名特征的行
for line in ocr_results:
if any(k in line for k in ["", "", "", "", "", "", ""]):
data["地址"] = line.strip()
break
# 5. 提取编号 (长数字串)
# 排除邮编和电话后的最长数字串
long_numbers = re.findall(r"\b\d{10,20}\b", full_content)
for num in long_numbers:
if num != data["电话"]:
data["编号"] = num
break
return data return data
def save_to_excel(records: List[Dict[str, Any]], output_path: str): def save_to_excel(records: List[Dict[str, Any]], output_path: str):
df = pd.DataFrame(records) df = pd.DataFrame(records)
# 调整列顺序
cols = ["编号", "邮编", "地址", "联系人/单位名", "电话"] cols = ["编号", "邮编", "地址", "联系人/单位名", "电话"]
df = df.reindex(columns=cols) df = df.reindex(columns=cols)
df.to_excel(output_path, index=False) df.to_excel(output_path, index=False)