diff --git a/scripts/test_extract_info.py b/scripts/test_extract_info.py new file mode 100644 index 0000000..8f31436 --- /dev/null +++ b/scripts/test_extract_info.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +""" +解析器快速自测脚本 + +运行方式: + python scripts/test_extract_info.py +""" + +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + +from processor import extract_info # noqa: E402 + + +def _print_case(name: str, result: dict) -> None: + print(f"\n=== {name} ===") + for key in ["编号", "邮编", "地址", "联系人/单位名", "电话"]: + print(f"{key}: {result.get(key, '')}") + + +def case_layout_multi_column() -> None: + """多栏场景:左侧地址、右侧单位+联系人。""" + ocr_lines = [ + {"text": "518000", "box": [[80, 40], [180, 40], [180, 80], [80, 80]], "source": "main"}, + {"text": "广东省深圳市南山区", "box": [[80, 100], [450, 100], [450, 132], [80, 132]], "source": "main"}, + {"text": "科技园高新南一道18号", "box": [[80, 140], [520, 140], [520, 172], [80, 172]], "source": "main"}, + {"text": "创新大厦3栋1201", "box": [[80, 180], [420, 180], [420, 212], [80, 212]], "source": "main"}, + {"text": "华南建设小组办公室", "box": [[620, 182], [960, 182], [960, 214], [620, 214]], "source": "main"}, + {"text": "张三13800138000", "box": [[620, 222], [960, 222], [960, 254], [620, 254]], "source": "main"}, + {"text": "202602241234567890", "box": [[280, 60], [760, 60], [760, 94], [280, 94]], "source": "number"}, + ] + result = extract_info(ocr_lines) + _print_case("多栏版面", result) + + assert result["邮编"] == "518000" + assert result["电话"] == "13800138000" + assert "广东省深圳市南山区" in result["地址"] + assert "科技园高新南一道18号" in result["地址"] + assert "华南建设小组办公室" in result["联系人/单位名"] + assert result["编号"] == "202602241234567890" + + +def case_layout_single_column() -> None: + """单列场景:邮编后连续地址,电话行包含联系人。""" + ocr_lines = [ + {"text": "200120", "box": [[90, 42], [188, 42], [188, 76], [90, 76]], "source": "main"}, + {"text": "上海市浦东新区世纪大道100号", "box": [[90, 96], [620, 96], [620, 128], [90, 128]], "source": "main"}, + {"text": "A座1201室", "box": [[90, 136], [300, 136], [300, 168], [90, 168]], "source": "main"}, + {"text": "李四021-12345678", "box": [[90, 178], [420, 178], [420, 210], [90, 210]], "source": "main"}, + ] + result = extract_info(ocr_lines) + _print_case("单列版面", result) + + assert result["邮编"] == "200120" + assert result["电话"] == "021-12345678" + assert "上海市浦东新区世纪大道100号" in result["地址"] + assert "A座1201室" in result["地址"] + assert result["联系人/单位名"] == "李四" + + +def case_text_fallback() -> None: + """无坐标回退:纯文本顺序规则。""" + ocr_texts = [ + "518000", + "广东省深圳市南山区科技园", + "高新南一道18号", + "华南建设小组办公室", + "王五 13911112222", + ] + result = extract_info(ocr_texts) + _print_case("纯文本回退", result) + + assert result["邮编"] == "518000" + assert result["电话"] == "13911112222" + assert "广东省深圳市南山区科技园" in result["地址"] + assert "高新南一道18号" in result["地址"] + assert "华南建设小组办公室" in result["联系人/单位名"] or result["联系人/单位名"] == "王五" + + +def main() -> None: + case_layout_multi_column() + case_layout_single_column() + case_text_fallback() + print("\n所有场景断言通过。") + + +if __name__ == "__main__": + main() diff --git a/src/desktop.py b/src/desktop.py index 19a8517..9b1992b 100644 --- a/src/desktop.py +++ b/src/desktop.py @@ -11,6 +11,7 @@ import time import logging import threading import queue +import multiprocessing as mp import subprocess from datetime import datetime from pathlib import Path @@ -24,8 +25,8 @@ from PyQt6.QtWidgets import ( from PyQt6.QtCore import Qt, QTimer, pyqtSignal, QObject, pyqtSlot from PyQt6.QtGui import QImage, QPixmap, QFont, QAction, QKeySequence, QShortcut -from processor import extract_info -from ocr_offline import create_offline_ocr, get_models_base_dir +from ocr_offline import get_models_base_dir +from ocr_worker_process import run_ocr_worker logger = logging.getLogger("post_ocr.desktop") @@ -70,12 +71,12 @@ def setup_logging() -> Path: class OCRService(QObject): """ - OCR 后台服务(运行在标准 Python 线程内)。 + OCR 后台服务(运行在独立子进程中)。 关键点: - - 避免使用 QThread:在 macOS 上,QThread(Dummy-*) 内 import paddleocr 可能卡死 - - PaddleOCR 实例在后台线程内创建并使用,避免跨线程调用导致卡死/死锁 - - 单线程串行处理任务:避免并发推理挤爆内存或引发底层库竞争 + - PaddleOCR 初始化与推理都放到子进程,避免阻塞 UI 主进程 + - 主进程只做任务投递与结果回调 + - 子进程异常或卡住时,可通过重启服务恢复 """ finished = pyqtSignal(int, dict, list) @@ -87,11 +88,20 @@ class OCRService(QObject): def __init__(self, models_base_dir: Path): super().__init__() self._models_base_dir = models_base_dir - self._ocr = None self._busy = False self._stop_event = threading.Event() - self._queue: "queue.Queue[tuple[int, object] | None]" = queue.Queue() - self._thread = threading.Thread(target=self._run, name="OCRThread", daemon=True) + method_default = "fork" if sys.platform == "darwin" else "spawn" + method = os.environ.get("POST_OCR_MP_START_METHOD", method_default).strip() or method_default + try: + self._ctx = mp.get_context(method) + except ValueError: + method = method_default + self._ctx = mp.get_context(method_default) + logger.info("OCR multiprocessing start_method=%s", method) + self._req_q = None + self._resp_q = None + self._proc = None + self._reader_thread = None def _set_busy(self, busy: bool) -> None: if self._busy != busy: @@ -99,118 +109,152 @@ class OCRService(QObject): self.busy_changed.emit(busy) def start(self) -> None: - """启动后台线程并执行 warmup。""" + """启动 OCR 子进程与响应监听线程。""" - self._thread.start() + self._stop_event.clear() + self._req_q = self._ctx.Queue(maxsize=1) + self._resp_q = self._ctx.Queue() + self._proc = self._ctx.Process( + target=run_ocr_worker, + args=(str(self._models_base_dir), self._req_q, self._resp_q), + name="OCRProcess", + daemon=True, + ) + self._proc.start() + self._reader_thread = threading.Thread( + target=self._read_responses, + name="OCRRespReader", + daemon=True, + ) + self._reader_thread.start() def stop(self, timeout_ms: int = 8000) -> bool: - """请求停止后台线程并等待退出(后台线程为 daemon,退出失败也不阻塞进程)。""" + """停止 OCR 子进程与监听线程。""" try: self._stop_event.set() - # 用 sentinel 唤醒阻塞在 queue.get() 的线程 try: - self._queue.put_nowait(None) + if self._req_q is not None: + self._req_q.put_nowait(None) except Exception: pass - self._thread.join(timeout=max(0.0, timeout_ms / 1000.0)) - return not self._thread.is_alive() + if self._reader_thread is not None: + self._reader_thread.join(timeout=max(0.0, timeout_ms / 1000.0)) + + proc_alive = False + if self._proc is not None: + self._proc.join(timeout=max(0.0, timeout_ms / 1000.0)) + if self._proc.is_alive(): + proc_alive = True + self._proc.terminate() + self._proc.join(timeout=1.0) + + self._set_busy(False) + return not proc_alive except Exception: + self._set_busy(False) return False + finally: + self._proc = None + self._reader_thread = None + self._req_q = None + self._resp_q = None - def _ensure_ocr(self) -> None: - if self._ocr is None: - logger.info("OCR ensure_ocr: 开始创建 PaddleOCR(线程=%s)", threading.current_thread().name) - self._ocr = create_offline_ocr(models_base_dir=self._models_base_dir) - logger.info("OCR ensure_ocr: PaddleOCR 创建完成") - self.ready.emit() - - def _warmup(self) -> None: - """提前加载 OCR 模型,避免首次识别时才初始化导致“像卡死”""" - - logger.info("OCR 预热开始(线程=%s)", threading.current_thread().name) - self._ensure_ocr() - logger.info("OCR 预热完成") - - def _run(self) -> None: - try: - self._warmup() - except Exception as e: - logger.exception("OCR 预热失败:%s", str(e)) - self.init_error.emit(str(e)) - return - + def _read_responses(self) -> None: + """读取 OCR 子进程响应并转发为 Qt 信号。""" while not self._stop_event.is_set(): - item = None try: - item = self._queue.get() - except Exception: + if self._resp_q is None: + return + msg = self._resp_q.get(timeout=0.2) + except queue.Empty: continue + except Exception: + if not self._stop_event.is_set(): + self.init_error.emit("OCR 子进程通信失败") + return - if item is None: - # sentinel: stop - break - - job_id, images = item - if self._stop_event.is_set(): - break - self._process_job(job_id, images) + if not isinstance(msg, dict): + continue + msg_type = str(msg.get("type", "")).strip() + if msg_type == "progress": + job_id = msg.get("job_id", "-") + stage = msg.get("stage", "") + extra = [] + if "images" in msg: + extra.append(f"images={msg.get('images')}") + if "texts" in msg: + extra.append(f"texts={msg.get('texts')}") + suffix = f" ({', '.join(extra)})" if extra else "" + logger.info("OCR 子进程进度 job=%s stage=%s%s", job_id, stage, suffix) + continue + if msg_type == "ready": + logger.info("OCR 子进程已就绪 pid=%s", getattr(self._proc, "pid", None)) + self.ready.emit() + continue + if msg_type == "init_error": + self._set_busy(False) + self.init_error.emit(str(msg.get("error", "OCR 初始化失败"))) + continue + if msg_type == "result": + self._set_busy(False) + try: + job_id = int(msg.get("job_id")) + except Exception: + job_id = -1 + record = msg.get("record") if isinstance(msg.get("record"), dict) else {} + texts = msg.get("texts") if isinstance(msg.get("texts"), list) else [] + self.finished.emit(job_id, record, texts) + continue + if msg_type == "error": + self._set_busy(False) + try: + job_id = int(msg.get("job_id")) + except Exception: + job_id = -1 + self.error.emit(job_id, str(msg.get("error", "OCR 处理失败"))) + continue @pyqtSlot(int, object) def process(self, job_id: int, images: object) -> None: - """接收 UI 请求:把任务放进队列,由后台线程串行处理。""" + """接收 UI 请求并投递到 OCR 子进程。""" if self._stop_event.is_set(): self.error.emit(job_id, "OCR 服务正在关闭,请稍后重试。") return - # 忙碌或已有排队任务时,直接拒绝,避免积压导致“看起来一直在识别” - if self._busy or (not self._queue.empty()): + if self._proc is None or (not self._proc.is_alive()): + self.error.emit(job_id, "OCR 服务未就绪,请稍后重试。") + return + if self._busy: self.error.emit(job_id, "OCR 正在进行中,请稍后再试。") return + if not isinstance(images, (list, tuple)) or len(images) == 0: + self.error.emit(job_id, "内部错误:未传入有效图片数据") + return try: - # 注意:这里不做耗时工作,只入队,避免阻塞 UI - self._queue.put_nowait((job_id, images)) - except Exception as e: - self.error.emit(job_id, f"OCR 入队失败:{str(e)}") - - def _process_job(self, job_id: int, images: object) -> None: - self._set_busy(True) - try: - self._ensure_ocr() - if not isinstance(images, (list, tuple)) or len(images) == 0: - raise ValueError("内部错误:未传入有效图片数据") - shapes = [] - for img in images: + for item in images: + img = item + source = "main" + if isinstance(item, dict): + img = item.get("img") + source = str(item.get("source", "main")) try: - shapes.append(getattr(img, "shape", None)) + shapes.append({"source": source, "shape": getattr(img, "shape", None)}) except Exception: - shapes.append(None) - logger.info("OCR job=%s 开始,images=%s", job_id, shapes) + shapes.append({"source": source, "shape": None}) + logger.info("OCR job=%s 投递到子进程,images=%s", job_id, shapes) - ocr_texts: list[str] = [] - for img in images: - if img is None: - continue - result = self._ocr.ocr(img, cls=False) - if result and result[0]: - for line in result[0]: - if line and len(line) >= 2: - ocr_texts.append(line[1][0]) - - record = extract_info(ocr_texts) - logger.info( - "OCR job=%s 完成,lines=%s, record_keys=%s", - job_id, - len(ocr_texts), - list(record.keys()), - ) - self.finished.emit(job_id, record, ocr_texts) - except Exception as e: - logger.exception("OCR job=%s 失败:%s", job_id, str(e)) - self.error.emit(job_id, str(e)) - finally: + self._set_busy(True) + if self._req_q is None: + raise RuntimeError("OCR 请求队列不可用") + self._req_q.put_nowait((int(job_id), list(images))) + except queue.Full: self._set_busy(False) + self.error.emit(job_id, "OCR 队列已满,请稍后再试。") + except Exception as e: + self._set_busy(False) + self.error.emit(job_id, f"OCR 入队失败:{str(e)}") class MainWindow(QMainWindow): @@ -223,17 +267,22 @@ class MainWindow(QMainWindow): # OCR 工作线程(避免 UI 卡死) self._ocr_job_id = 0 + self._ocr_pending_job_id = None self._ocr_start_time_by_job: dict[int, float] = {} self._ocr_ready = False self._ocr_busy = False self._shutting_down = False self._ocr_timeout_prompted = False + self._ocr_restarting = False # 摄像头 self.cap = None self.timer = QTimer() self.timer.timeout.connect(self.update_frame) self._frame_fail_count = 0 + self._last_frame = None + self._last_frame_ts = 0.0 + self._capture_in_progress = False # 状态栏进度(识别中显示) self._progress = QProgressBar() @@ -252,17 +301,44 @@ class MainWindow(QMainWindow): self.init_ui() self.load_cameras() - # 主线程预加载:在 macOS 上,必须在主线程 import paddleocr,否则后台线程会卡死 - self.statusBar().showMessage("正在加载 OCR 模块...") - QApplication.processEvents() - try: - logger.info("主线程预加载:import paddleocr") - import paddleocr # noqa: F401 - logger.info("主线程预加载:paddleocr 导入完成") - except Exception as e: - logger.error("主线程预加载失败:%s", e, exc_info=True) - QMessageBox.critical(self, "启动失败", f"无法加载 OCR 模块:{e}") - raise + # 历史上主线程直接 import paddleocr 偶发卡死。 + # 默认跳过该步骤,避免 UI 被阻塞;如需诊断可打开轻量预检(子进程 + 超时)。 + if os.environ.get("POST_OCR_PRECHECK_IMPORT", "0").strip() == "1": + timeout_sec = 8 + try: + timeout_sec = max( + 2, + int( + os.environ.get("POST_OCR_PRECHECK_TIMEOUT_SEC", "8").strip() + or "8" + ), + ) + except Exception: + timeout_sec = 8 + self.statusBar().showMessage("正在预检 OCR 模块...") + QApplication.processEvents() + try: + logger.info("OCR 预检开始(子进程,timeout=%ss)", timeout_sec) + proc = subprocess.run( + [sys.executable, "-c", "import paddleocr"], + capture_output=True, + text=True, + timeout=timeout_sec, + ) + if proc.returncode == 0: + logger.info("OCR 预检通过") + else: + logger.warning( + "OCR 预检失败(rc=%s):%s", + proc.returncode, + (proc.stderr or "").strip(), + ) + except subprocess.TimeoutExpired: + logger.warning("OCR 预检超时(%ss),跳过预检继续启动。", timeout_sec) + except Exception as e: + logger.warning("OCR 预检异常:%s(忽略并继续)", str(e)) + else: + logger.info("已跳过主线程 OCR 预检(POST_OCR_PRECHECK_IMPORT=0)") # OCR 服务放在 UI 初始化之后启动,避免 ready/busy 信号回调时 btn_capture 尚未创建 self.statusBar().showMessage("正在启动 OCR 服务...") @@ -308,6 +384,8 @@ class MainWindow(QMainWindow): self._ocr_ready = False self._ocr_busy = False self._ocr_timeout_prompted = False + self._ocr_pending_job_id = None + self._ocr_start_time_by_job.clear() try: self._progress.setVisible(False) except Exception: @@ -316,10 +394,13 @@ class MainWindow(QMainWindow): try: svc = getattr(self, "_ocr_service", None) if svc is not None: + try: + self.request_ocr.disconnect(svc.process) + except Exception: + pass ok = svc.stop(timeout_ms=8000 if force else 3000) if (not ok) and force: - # Python 线程无法可靠“强杀”,这里只做提示并继续退出流程。 - logger.warning("OCR 服务停止超时:后台线程可能仍在运行,建议重启应用。") + logger.warning("OCR 服务停止超时:子进程可能仍在退出中,建议重启应用。") except Exception: pass @@ -333,9 +414,15 @@ class MainWindow(QMainWindow): if self._shutting_down: return - self.statusBar().showMessage("正在重启 OCR 服务...") - self._stop_ocr_service(force=True) - self._init_ocr_service() + if self._ocr_restarting: + return + self._ocr_restarting = True + try: + self.statusBar().showMessage("正在重启 OCR 服务...") + self._stop_ocr_service(force=True) + self._init_ocr_service() + finally: + self._ocr_restarting = False def _init_ocr_service(self) -> None: models_dir = get_models_base_dir() @@ -347,7 +434,7 @@ class MainWindow(QMainWindow): self._ocr_service = OCRService(models_base_dir=models_dir) - # 注意:OCRService 内部使用 Python 线程做 warmup 与推理。 + # 注意:OCRService 内部使用独立子进程做 warmup 与推理。 # 这里强制使用 QueuedConnection,确保 UI 回调始终在主线程执行。 self.request_ocr.connect(self._ocr_service.process, Qt.ConnectionType.QueuedConnection) self._ocr_service.ready.connect(self._on_ocr_ready, Qt.ConnectionType.QueuedConnection) @@ -378,6 +465,8 @@ class MainWindow(QMainWindow): try: self._ocr_busy = busy if busy: + # OCR 线程已开始处理,提交阶段不再算“待接收” + self._ocr_pending_job_id = None self._progress.setRange(0, 0) # 不确定进度条 self._progress.setVisible(True) self._ocr_timeout_prompted = False @@ -391,8 +480,27 @@ class MainWindow(QMainWindow): except Exception as e: logger.exception("处理 OCR busy 回调失败:%s", str(e)) + def _guard_ocr_submission(self, job_id: int) -> None: + """ + 兜底保护: + 如果提交后一段时间仍未进入 busy 状态,说明任务可能未被 OCR 线程接收, + 主动恢复按钮,避免界面一直停留在“正在识别...”。 + """ + + if job_id != self._ocr_pending_job_id: + return + if self._ocr_busy: + return + + self._ocr_pending_job_id = None + self._ocr_start_time_by_job.pop(job_id, None) + logger.warning("OCR job=%s 提交后未被接收,已自动恢复 UI 状态", job_id) + self.statusBar().showMessage("识别请求未被处理,请重试一次(已自动恢复)") + if self.btn_capture is not None: + self.btn_capture.setEnabled(self.cap is not None and self._ocr_ready) + def _tick_ocr_watchdog(self) -> None: - """识别进行中:更新耗时,超时则提示是否重启 OCR 服务。""" + """识别进行中:更新耗时,超时自动重启 OCR 服务。""" if not self._ocr_busy: return @@ -402,19 +510,30 @@ class MainWindow(QMainWindow): cost = time.monotonic() - start_t self.statusBar().showMessage(f"正在识别...(已用 {cost:.1f}s)") - # 超时保护:底层推理偶发卡住时,让用户可以自救 - if cost >= 45 and not self._ocr_timeout_prompted: + # 超时保护:底层推理偶发卡住时,自动重启 OCR 服务并恢复可用状态 + timeout_sec = 25 + try: + timeout_sec = max( + 8, int(os.environ.get("POST_OCR_JOB_TIMEOUT_SEC", "25").strip() or "25") + ) + except Exception: + timeout_sec = 25 + if cost >= timeout_sec and not self._ocr_timeout_prompted: self._ocr_timeout_prompted = True - reply = QMessageBox.question( + logger.warning("OCR job=%s 超时 %.1fs,自动重启 OCR 服务", self._ocr_job_id, cost) + self.statusBar().showMessage(f"识别超时({cost:.1f}s),正在自动恢复...") + # 当前任务视为失败并回收,避免界面一直等待结果 + self._ocr_start_time_by_job.pop(self._ocr_job_id, None) + self._restart_ocr_service() + QMessageBox.warning( self, "识别超时", - "识别已超过 45 秒仍未完成,可能卡住。\n\n是否重启 OCR 服务?\n(若仍无响应,建议直接退出并重新打开应用)", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + "本次识别超时,已自动重启 OCR 服务。\n请再次拍照识别。", ) - if reply == QMessageBox.StandardButton.Yes: - self._restart_ocr_service() def _on_ocr_finished_job(self, job_id: int, record: dict, texts: list) -> None: + if self._ocr_pending_job_id == job_id: + self._ocr_pending_job_id = None start_t = self._ocr_start_time_by_job.pop(job_id, None) # 只处理最新一次请求,避免旧结果回写 @@ -428,14 +547,18 @@ class MainWindow(QMainWindow): cost = f"(耗时 {time.monotonic() - start_t:.1f}s)" self.statusBar().showMessage(f"识别完成: {record.get('联系人/单位名', '未知')}{cost}") logger.info("OCR job=%s UI 回写完成 %s", job_id, cost) + self.btn_capture.setEnabled(self.cap is not None and self._ocr_ready and not self._ocr_busy) def _on_ocr_error_job(self, job_id: int, error: str) -> None: + if self._ocr_pending_job_id == job_id: + self._ocr_pending_job_id = None self._ocr_start_time_by_job.pop(job_id, None) if job_id != self._ocr_job_id: return self.statusBar().showMessage("识别失败") QMessageBox.warning(self, "识别失败", error) logger.error("OCR job=%s error: %s", job_id, error) + self.btn_capture.setEnabled(self.cap is not None and self._ocr_ready and not self._ocr_busy) def init_ui(self): central = QWidget() @@ -519,6 +642,7 @@ class MainWindow(QMainWindow): # macOS/Qt 下 Space 经常被控件吞掉(按钮激活/表格选择等),用 ApplicationShortcut 更稳 self._shortcut_capture2 = QShortcut(QKeySequence("Space"), self) self._shortcut_capture2.setContext(Qt.ShortcutContext.ApplicationShortcut) + self._shortcut_capture2.setAutoRepeat(False) self._shortcut_capture2.activated.connect(self.capture_and_recognize) def load_cameras(self): @@ -770,6 +894,13 @@ class MainWindow(QMainWindow): ret, frame = self.cap.read() if ret and frame is not None and frame.size > 0: self._frame_fail_count = 0 + # 缓存原始帧,拍照时直接使用,避免按空格再读摄像头导致主线程阻塞 + try: + self._last_frame = frame.copy() + self._last_frame_ts = time.monotonic() + except Exception: + self._last_frame = frame + self._last_frame_ts = time.monotonic() # 绘制扫描框 h, w = frame.shape[:2] # 框的位置:上方 70%,编号在下方 @@ -812,6 +943,9 @@ class MainWindow(QMainWindow): def capture_and_recognize(self): """拍照并识别""" + if self._capture_in_progress: + self.statusBar().showMessage("正在拍照,请稍候") + return if self.cap is None: self.statusBar().showMessage("请先连接摄像头") return @@ -822,61 +956,126 @@ class MainWindow(QMainWindow): self.statusBar().showMessage("正在识别中,请稍后再按空格") return - ret, frame = self.cap.read() - if not ret: - self.statusBar().showMessage("拍照失败") - return - - # 裁剪两块 ROI(主信息框 + 编号区域),显著减小像素量,提升速度与稳定性 - h, w = frame.shape[:2] - x1, y1 = int(w * 0.06), int(h * 0.08) - x2 = int(w * 0.94) - y2_box = int(h * 0.78) - - roi_images = [] + self._capture_in_progress = True try: - roi_box = frame[y1:y2_box, x1:x2] - if roi_box is not None and roi_box.size > 0: - roi_images.append(roi_box) - except Exception: - pass + # 直接使用预览缓存帧,避免在按键回调中阻塞式 read 摄像头导致卡顿 + frame = None + now = time.monotonic() + if self._last_frame is not None and (now - self._last_frame_ts) <= 1.5: + try: + frame = self._last_frame.copy() + except Exception: + frame = self._last_frame - try: - # 编号一般在底部中间,取较小区域即可 - nx1, nx2 = int(w * 0.30), int(w * 0.70) - ny1, ny2 = int(h * 0.80), int(h * 0.98) - roi_num = frame[ny1:ny2, nx1:nx2] - if roi_num is not None and roi_num.size > 0: - roi_images.append(roi_num) - except Exception: - pass + if frame is None: + self.statusBar().showMessage("尚未拿到稳定画面,请稍后再按空格") + return - if not roi_images: - self.statusBar().showMessage("拍照失败:未截取到有效区域") - return + # 裁剪主信息 ROI 与编号 ROI + h, w = frame.shape[:2] + x1, y1 = int(w * 0.06), int(h * 0.08) + x2 = int(w * 0.94) + y2_box = int(h * 0.78) - # 超大分辨率下适当缩放(提高稳定性与速度) - resized_images = [] - for img in roi_images: + roi_inputs = [] try: - max_w = 1400 - if img.shape[1] > max_w: - scale = max_w / img.shape[1] - img = cv2.resize(img, (int(img.shape[1] * scale), int(img.shape[0] * scale))) + roi_box = frame[y1:y2_box, x1:x2] + if roi_box is not None and roi_box.size > 0: + # 主信息区域切成多段,规避大图整块检测偶发卡住 + split_count = 2 + try: + split_count = max( + 1, + int( + os.environ.get("POST_OCR_MAIN_SPLIT", "2").strip() + or "2" + ), + ) + except Exception: + split_count = 2 + split_count = min(split_count, 4) + + if split_count <= 1 or roi_box.shape[0] < 120: + roi_inputs.append({"img": roi_box, "source": "main"}) + else: + h_box = roi_box.shape[0] + step = h_box / float(split_count) + overlap = max(8, int(h_box * 0.06)) + for i in range(split_count): + sy = int(max(0, i * step - (overlap if i > 0 else 0))) + ey = int( + min( + h_box, + (i + 1) * step + + (overlap if i < split_count - 1 else 0), + ) + ) + part = roi_box[sy:ey, :] + if part is not None and part.size > 0: + roi_inputs.append({"img": part, "source": "main"}) except Exception: pass - resized_images.append(img) - logger.info("UI 触发识别:frame=%s, rois=%s", getattr(frame, "shape", None), [getattr(i, "shape", None) for i in resized_images]) + try: + # 编号一般在底部中间,取较小区域即可 + nx1, nx2 = int(w * 0.30), int(w * 0.70) + ny1, ny2 = int(h * 0.80), int(h * 0.98) + roi_num = frame[ny1:ny2, nx1:nx2] + if roi_num is not None and roi_num.size > 0: + roi_inputs.append({"img": roi_num, "source": "number"}) + except Exception: + pass - self.statusBar().showMessage("正在识别...") - self.btn_capture.setEnabled(False) + if not roi_inputs: + self.statusBar().showMessage("拍照失败:未截取到有效区域") + return - # 派发到 OCR 工作线程 - self._ocr_job_id += 1 - job_id = self._ocr_job_id - self._ocr_start_time_by_job[job_id] = time.monotonic() - self.request_ocr.emit(job_id, resized_images) + # 超大分辨率下适当缩放(提高稳定性与速度) + resized_inputs = [] + max_w = 960 + try: + max_w = max( + 600, int(os.environ.get("POST_OCR_MAX_ROI_WIDTH", "960").strip() or "960") + ) + except Exception: + max_w = 960 + + for item in roi_inputs: + img = item.get("img") + source = item.get("source", "main") + try: + if img is not None and img.shape[1] > max_w: + scale = max_w / img.shape[1] + img = cv2.resize(img, (int(img.shape[1] * scale), int(img.shape[0] * scale))) + except Exception: + pass + resized_inputs.append({"img": img, "source": source}) + + logger.info( + "UI 触发识别:frame=%s, rois=%s, frame_age=%.3fs", + getattr(frame, "shape", None), + [ + { + "source": item.get("source", "main"), + "shape": getattr(item.get("img"), "shape", None), + } + for item in resized_inputs + ], + max(0.0, now - self._last_frame_ts), + ) + + self.statusBar().showMessage("正在识别...") + self.btn_capture.setEnabled(False) + + # 派发到 OCR 工作线程 + self._ocr_job_id += 1 + job_id = self._ocr_job_id + self._ocr_pending_job_id = job_id + self._ocr_start_time_by_job[job_id] = time.monotonic() + self.request_ocr.emit(job_id, resized_inputs) + QTimer.singleShot(2000, lambda j=job_id: self._guard_ocr_submission(j)) + finally: + self._capture_in_progress = False def update_table(self): """更新表格""" diff --git a/src/main.py b/src/main.py index 184dfb0..339706c 100644 --- a/src/main.py +++ b/src/main.py @@ -40,15 +40,31 @@ def main(): # 2. 提取文字行 ocr_texts = [] + ocr_lines = [] if result and result[0]: for line in result[0]: # line 格式: [box, (text, confidence)] if line and len(line) >= 2: - ocr_texts.append(line[1][0]) + text = str(line[1][0]) + ocr_texts.append(text) + conf = None + try: + conf = float(line[1][1]) + except Exception: + conf = None + ocr_lines.append( + { + "text": text, + "box": line[0], + "conf": conf, + "source": "main", + "roi_index": 0, + } + ) # 3. 结构化解析 if ocr_texts: - record = extract_info(ocr_texts) + record = extract_info(ocr_lines if ocr_lines else ocr_texts) all_records.append(record) else: errors.append( diff --git a/src/ocr_offline.py b/src/ocr_offline.py index b67e653..a5423a0 100644 --- a/src/ocr_offline.py +++ b/src/ocr_offline.py @@ -118,7 +118,21 @@ def create_offline_ocr(models_base_dir: Path | None = None): from paddleocr import PaddleOCR # 构建 PaddleOCR 参数 + # 说明:在部分 macOS/CPU 环境下,oneDNN(MKLDNN) 可能出现卡住,默认关闭以换取稳定性。 kwargs = dict(lang="ch", use_angle_cls=False, show_log=False) + disable_mkldnn = os.environ.get("POST_OCR_DISABLE_MKLDNN", "1").strip() == "1" + if disable_mkldnn: + os.environ["FLAGS_use_mkldnn"] = "0" + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["MKL_NUM_THREADS"] = "1" + os.environ["OPENBLAS_NUM_THREADS"] = "1" + kwargs["enable_mkldnn"] = False + try: + kwargs["cpu_threads"] = max( + 1, int(os.environ.get("POST_OCR_CPU_THREADS", "1").strip() or "1") + ) + except Exception: + kwargs["cpu_threads"] = 1 # 如果 models/ 目录存在离线模型,显式指定路径(打包分发场景) models_dir = models_base_dir or get_models_base_dir() @@ -134,6 +148,12 @@ def create_offline_ocr(models_base_dir: Path | None = None): log.info("未找到离线模型,将使用默认路径(可能需要联网下载)") log.info("create_offline_ocr: creating PaddleOCR(lang=ch)") - ocr = PaddleOCR(**kwargs) + try: + ocr = PaddleOCR(**kwargs) + except TypeError: + # 兼容个别 PaddleOCR 版本不支持的参数 + kwargs.pop("enable_mkldnn", None) + kwargs.pop("cpu_threads", None) + ocr = PaddleOCR(**kwargs) log.info("create_offline_ocr: PaddleOCR created") return ocr diff --git a/src/ocr_worker_process.py b/src/ocr_worker_process.py new file mode 100644 index 0000000..01d63a1 --- /dev/null +++ b/src/ocr_worker_process.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +# 必须在所有 paddle/numpy import 之前设置,否则 macOS spawn 子进程推理会死锁 +import os +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["OPENBLAS_NUM_THREADS"] = "1" +os.environ["VECLIB_MAXIMUM_THREADS"] = "1" +os.environ["FLAGS_use_mkldnn"] = "0" +os.environ["PADDLE_DISABLE_SIGNAL_HANDLER"] = "1" +os.environ["PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK"] = "True" + +from pathlib import Path +from typing import Any + +from ocr_offline import create_offline_ocr +from processor import extract_info + + +def run_ocr_worker(models_base_dir: str, request_q, response_q) -> None: + """ + OCR 子进程主循环: + - 在子进程内初始化 PaddleOCR,避免阻塞主 UI 进程 + - 接收任务并返回结构化结果 + """ + try: + response_q.put({"type": "progress", "stage": "init_start"}) + ocr = create_offline_ocr(models_base_dir=Path(models_base_dir)) + response_q.put({"type": "ready"}) + except Exception as e: + response_q.put({"type": "init_error", "error": str(e)}) + return + + while True: + item = request_q.get() + if item is None: + break + + job_id = -1 + try: + job_id, images = item + if not isinstance(images, (list, tuple)) or len(images) == 0: + raise ValueError("内部错误:未传入有效图片数据") + + response_q.put({"type": "progress", "job_id": int(job_id), "stage": "job_received", "images": len(images)}) + ocr_texts: list[str] = [] + ocr_lines: list[dict[str, Any]] = [] + for roi_index, entry in enumerate(images): + source = "main" + img = entry + if isinstance(entry, dict): + source = str(entry.get("source", "main")) + img = entry.get("img") + elif roi_index > 0: + source = "number" + if img is None: + continue + response_q.put({"type": "progress", "job_id": int(job_id), "stage": f"roi_{roi_index}_start"}) + result = ocr.ocr(img, cls=False) + response_q.put({"type": "progress", "job_id": int(job_id), "stage": f"roi_{roi_index}_done"}) + if result and result[0]: + for line in result[0]: + if line and len(line) >= 2: + text = str(line[1][0]) + ocr_texts.append(text) + conf = None + try: + conf = float(line[1][1]) + except Exception: + conf = None + ocr_lines.append( + { + "text": text, + "box": line[0], + "conf": conf, + "source": source, + "roi_index": roi_index, + } + ) + + record = extract_info(ocr_lines if ocr_lines else ocr_texts) + response_q.put({"type": "progress", "job_id": int(job_id), "stage": "parse_done", "texts": len(ocr_texts)}) + response_q.put( + { + "type": "result", + "job_id": int(job_id), + "record": record, + "texts": ocr_texts, + } + ) + except Exception as e: + response_q.put({"type": "error", "job_id": int(job_id), "error": str(e)}) diff --git a/src/processor.py b/src/processor.py index 599af96..82923a9 100644 --- a/src/processor.py +++ b/src/processor.py @@ -1,9 +1,64 @@ import re +from dataclasses import dataclass +from statistics import median +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + import pandas as pd -from typing import List, Dict, Any from pydantic import BaseModel, Field +ZIP_PATTERN = re.compile(r"(? bool: + return ( + self.x1 is not None + and self.y1 is not None + and self.x2 is not None + and self.y2 is not None + ) + + @property + def cx(self) -> float: + if not self.has_pos: + return float(self.order) + return (self.x1 + self.x2) / 2.0 # type: ignore[operator] + + @property + def cy(self) -> float: + if not self.has_pos: + return float(self.order) + return (self.y1 + self.y2) / 2.0 # type: ignore[operator] + + @property + def height(self) -> float: + if not self.has_pos: + return 0.0 + return max(0.0, float(self.y2) - float(self.y1)) + + @property + def width(self) -> float: + if not self.has_pos: + return 0.0 + return max(0.0, float(self.x2) - float(self.x1)) + + class EnvelopeRecord(BaseModel): 编号: str = "" 邮编: str = "" @@ -13,75 +68,412 @@ class EnvelopeRecord(BaseModel): def clean_text(text: str) -> str: - """清理OCR识别出的杂质字符""" - return text.strip().replace(" ", "") + """清理 OCR 识别文本中的空白和无意义分隔符。""" + if not text: + return "" + text = text.replace("\u3000", " ").strip() + return re.sub(r"\s+", "", text) -def extract_info(ocr_results: List[str]) -> Dict[str, str]: +def _parse_box(raw_box: Any) -> Tuple[Optional[float], Optional[float], Optional[float], Optional[float]]: + if not isinstance(raw_box, (list, tuple)) or len(raw_box) < 4: + return None, None, None, None + + xs: List[float] = [] + ys: List[float] = [] + for p in raw_box: + if not isinstance(p, (list, tuple)) or len(p) < 2: + continue + try: + xs.append(float(p[0])) + ys.append(float(p[1])) + except Exception: + continue + if not xs or not ys: + return None, None, None, None + return min(xs), min(ys), max(xs), max(ys) + + +def _to_ocr_line(item: Any, idx: int) -> Optional[OCRLine]: + if isinstance(item, str): + text = clean_text(item) + if not text: + return None + return OCRLine(text=text, source="main", order=idx) + + if not isinstance(item, dict): + return None + + text = clean_text(str(item.get("text", ""))) + if not text: + return None + + source = str(item.get("source", "main")) + x1, y1, x2, y2 = _parse_box(item.get("box")) + if x1 is None: + # 兼容直接传坐标的输入 + try: + x1 = float(item.get("x1")) + y1 = float(item.get("y1")) + x2 = float(item.get("x2")) + y2 = float(item.get("y2")) + except Exception: + x1, y1, x2, y2 = None, None, None, None + + return OCRLine(text=text, source=source, order=idx, x1=x1, y1=y1, x2=x2, y2=y2) + + +def _normalize_ocr_results(ocr_results: Sequence[Any]) -> List[OCRLine]: + lines: List[OCRLine] = [] + seen = set() + for idx, item in enumerate(ocr_results): + line = _to_ocr_line(item, idx) + if line is None: + continue + if line.has_pos: + key = ( + line.text, + line.source, + round(line.cx, 1), + round(line.cy, 1), + ) + else: + key = (line.text, line.source, line.order) + if key in seen: + continue + seen.add(key) + lines.append(line) + return lines + + +def _first_match(pattern: re.Pattern[str], text: str) -> str: + m = pattern.search(text) + if not m: + return "" + if m.lastindex: + return m.group(1) + return m.group(0) + + +def _find_anchor(lines: Iterable[OCRLine], pattern: re.Pattern[str], prefer_bottom: bool) -> Optional[Tuple[OCRLine, str]]: + candidates: List[Tuple[OCRLine, str]] = [] + for line in lines: + m = pattern.search(line.text) + if not m: + continue + token = m.group(1) if m.lastindex else m.group(0) + candidates.append((line, token)) + + if not candidates: + return None + + if prefer_bottom: + return max(candidates, key=lambda item: (item[0].row_idx, item[0].cy, item[0].cx, item[0].order)) + return min(candidates, key=lambda item: (item[0].row_idx, item[0].cy, item[0].cx, item[0].order)) + + +def _build_rows(lines: List[OCRLine]) -> List[List[OCRLine]]: + positioned = [line for line in lines if line.has_pos] + if not positioned: + return [] + positioned.sort(key=lambda line: (line.cy, line.cx)) + + heights = [line.height for line in positioned if line.height > 1.0] + h_med = median(heights) if heights else 20.0 + y_threshold = max(8.0, h_med * 0.65) + + rows: List[List[OCRLine]] = [] + for line in positioned: + if not rows: + rows.append([line]) + continue + row = rows[-1] + mean_y = sum(item.cy for item in row) / len(row) + if abs(line.cy - mean_y) <= y_threshold: + row.append(line) + else: + rows.append([line]) + + for row_idx, row in enumerate(rows): + row.sort(key=lambda line: (line.cx, line.x1 or 0.0)) + for col_idx, line in enumerate(row): + line.row_idx = row_idx + line.col_idx = col_idx + return rows + + +def _sanitize_address(text: str) -> str: + text = clean_text(text) + text = re.sub(r"^(地址|收件地址|详细地址)[::]?", "", text) + return text + + +def _sanitize_contact(text: str) -> str: + text = clean_text(text) + text = re.sub(r"^(收件人|联系人|单位|收)[::]?", "", text) + return text.strip(",,。;;:") + + +def _join_entries(entries: List[Tuple[int, int, str]]) -> str: + if not entries: + return "" + entries.sort(key=lambda item: (item[0], item[1])) + merged: List[str] = [] + for _, _, text in entries: + if not text: + continue + if merged and merged[-1] == text: + continue + merged.append(text) + return "".join(merged) + + +def _extract_tracking_number(lines: List[OCRLine], zip_code: str, phone: str) -> str: + phone_digits = re.sub(r"\D", "", phone) + candidates: List[Tuple[int, int, str]] = [] + for line in lines: + for match in LONG_NUMBER_PATTERN.finditer(line.text): + number = match.group(1) + if number == zip_code: + continue + if phone and (number == phone or number == phone_digits): + continue + src_score = 2 if line.source == "number" else 1 + candidates.append((src_score, len(number), number)) + if not candidates: + return "" + candidates.sort(reverse=True) + return candidates[0][2] + + +def _extract_with_layout(lines: List[OCRLine], data: Dict[str, str]) -> Tuple[str, str, bool]: + main_lines = [line for line in lines if line.source != "number"] + if len(main_lines) < 2: + return "", "", False + + rows = _build_rows(main_lines) + if not rows: + return "", "", False + + zip_anchor = _find_anchor(main_lines, ZIP_PATTERN, prefer_bottom=False) + phone_anchor = _find_anchor(main_lines, PHONE_PATTERN, prefer_bottom=True) + if zip_anchor and not data["邮编"]: + data["邮编"] = zip_anchor[1] + if phone_anchor and not data["电话"]: + data["电话"] = phone_anchor[1] + + if zip_anchor: + start_row = zip_anchor[0].row_idx + else: + start_row = min(line.row_idx for line in main_lines) + if phone_anchor: + end_row = phone_anchor[0].row_idx + else: + end_row = max(line.row_idx for line in main_lines) + if start_row > end_row: + start_row, end_row = end_row, start_row + + single_column_mode = False + if zip_anchor and phone_anchor: + line_widths = [line.width for line in main_lines if line.width > 0] + width_ref = median(line_widths) if line_widths else 120.0 + single_column_mode = abs(phone_anchor[0].cx - zip_anchor[0].cx) < max(60.0, width_ref * 0.6) + + if zip_anchor and phone_anchor and phone_anchor[0].cx > zip_anchor[0].cx and not single_column_mode: + split_x = (zip_anchor[0].cx + phone_anchor[0].cx) / 2.0 + elif phone_anchor: + split_x = phone_anchor[0].cx - max(40.0, phone_anchor[0].width * 0.6) + elif zip_anchor: + split_x = zip_anchor[0].cx + max(80.0, zip_anchor[0].width * 1.5) + else: + split_x = median([line.cx for line in main_lines]) + + address_entries: List[Tuple[int, int, str]] = [] + contact_entries: List[Tuple[int, int, str]] = [] + + for line in main_lines: + if line.row_idx < start_row or line.row_idx > end_row: + continue + + text = line.text + if zip_anchor and line is zip_anchor[0]: + text = text.replace(zip_anchor[1], "") + if phone_anchor and line is phone_anchor[0]: + text = text.replace(phone_anchor[1], "") + text = clean_text(text) + if not text: + continue + if re.fullmatch(r"\d{6,20}", text): + continue + + if single_column_mode: + if phone_anchor and line is phone_anchor[0]: + contact_entries.append((line.row_idx, line.col_idx, text)) + else: + address_entries.append((line.row_idx, line.col_idx, text)) + continue + + if line.cx <= split_x: + address_entries.append((line.row_idx, line.col_idx, text)) + else: + contact_entries.append((line.row_idx, line.col_idx, text)) + + # 联系人优先取靠近电话的一段,降低把地址误分到联系人的概率 + if phone_anchor and contact_entries: + phone_row = phone_anchor[0].row_idx + min_dist = min(abs(item[0] - phone_row) for item in contact_entries) + contact_entries = [ + item for item in contact_entries if abs(item[0] - phone_row) <= min_dist + 1 + ] + + contact_text = _sanitize_contact(_join_entries(contact_entries)) + address_text = _sanitize_address(_join_entries(address_entries)) + + # 如果联系人仍为空,尝试从“电话所在行去掉电话号码”的残余文本提取 + if not contact_text and phone_anchor: + fallback_contact = clean_text(phone_anchor[0].text.replace(phone_anchor[1], "")) + if fallback_contact and not re.fullmatch(r"\d{2,20}", fallback_contact): + contact_text = _sanitize_contact(fallback_contact) + + # 若仍缺联系人,尝试从靠近电话的地址候选中回退一行 + if not contact_text and phone_anchor and address_entries: + phone_row = phone_anchor[0].row_idx + sorted_candidates = sorted( + address_entries, + key=lambda item: (abs(item[0] - phone_row), -item[0], item[1]), + ) + for row_idx, col_idx, txt in sorted_candidates: + if ADDRESS_HINT_PATTERN.search(txt): + continue + contact_text = _sanitize_contact(txt) + if contact_text: + address_entries = [ + item + for item in address_entries + if not (item[0] == row_idx and item[1] == col_idx and item[2] == txt) + ] + address_text = _sanitize_address(_join_entries(address_entries)) + break + + has_signal = bool(zip_anchor or phone_anchor) + return address_text, contact_text, has_signal + + +def _extract_with_text_order(lines: List[OCRLine], data: Dict[str, str]) -> Tuple[str, str, bool]: + if not lines: + return "", "", False + + zip_idx = -1 + zip_token = "" + for idx, line in enumerate(lines): + m = ZIP_PATTERN.search(line.text) + if m: + zip_idx = idx + zip_token = m.group(1) + break + + phone_idx = -1 + phone_token = "" + for idx in range(len(lines) - 1, -1, -1): + m = PHONE_PATTERN.search(lines[idx].text) + if m: + phone_idx = idx + phone_token = m.group(1) + break + + if zip_idx < 0 or phone_idx < 0 or zip_idx > phone_idx: + return "", "", False + + if not data["邮编"]: + data["邮编"] = zip_token + if not data["电话"]: + data["电话"] = phone_token + + address_parts: List[Tuple[int, str]] = [] + contact_text = "" + for idx in range(zip_idx, phone_idx + 1): + text = lines[idx].text + if idx == zip_idx: + text = text.replace(zip_token, "") + if idx == phone_idx: + text = text.replace(phone_token, "") + text = clean_text(text) + if not text: + continue + if idx == phone_idx: + contact_text = _sanitize_contact(text) + else: + address_parts.append((idx, text)) + + if not contact_text and address_parts: + for idx, text in reversed(address_parts): + if ADDRESS_HINT_PATTERN.search(text): + continue + contact_text = _sanitize_contact(text) + if contact_text: + address_parts = [item for item in address_parts if item[0] != idx] + break + + address_text = _sanitize_address("".join(text for _, text in address_parts)) + return address_text, contact_text, True + + +def extract_info(ocr_results: List[Any]) -> Dict[str, str]: """ - 从OCR结果列表中提取结构化信息。 + 从 OCR 结果中提取结构化信息。 + + 支持两类输入: + 1. 纯文本列表:`List[str]` + 2. 带坐标的行对象列表:`List[{"text": "...", "box": [[x,y],...], "source": "..."}]` """ data = {"编号": "", "邮编": "", "地址": "", "联系人/单位名": "", "电话": ""} + lines = _normalize_ocr_results(ocr_results) + if not lines: + return data - full_content = " ".join(ocr_results) + full_content = " ".join(line.text for line in lines) + data["邮编"] = _first_match(ZIP_PATTERN, full_content) + data["电话"] = _first_match(PHONE_PATTERN, full_content) + data["编号"] = _extract_tracking_number(lines, data["邮编"], data["电话"]) - # 1. 提取邮编 (6位数字) - zip_match = re.search(r"\b(\d{6})\b", full_content) - if zip_match: - data["邮编"] = zip_match.group(1) + # 第一优先级:使用版面坐标进行“邮编-电话锚点 + 连续块”解析 + address_text, contact_text, used_layout = _extract_with_layout(lines, data) + if not used_layout: + # 第二优先级:无坐标时按文本顺序回退 + address_text, contact_text, _ = _extract_with_text_order(lines, data) - # 2. 提取电话 (11位手机号或带区号固话) - phone_match = re.search(r"(1[3-9]\d{9}|0\d{2,3}-\d{7,8})", full_content) - if phone_match: - data["电话"] = phone_match.group(0) + data["地址"] = _sanitize_address(address_text) + data["联系人/单位名"] = _sanitize_contact(contact_text) - # 3. 提取联系人 (通常在电话前面,或者是独立的短行) - # 遍历每一行寻找包含电话的行 - for line in ocr_results: - if data["电话"] and data["电话"] in line: - # 移除电话部分,剩下的可能是姓名 - name_part = line.replace(data["电话"], "").strip() - # 进一步清洗姓名(移除符号) - name_part = re.sub(r"[^\w\u4e00-\u9fa5]", "", name_part) - if name_part: - data["联系人/单位名"] = name_part - break - - # 如果还没找到联系人,尝试找不含数字的短行 + # 最终兜底:联系人和地址任一为空时,补旧规则避免完全丢字段 if not data["联系人/单位名"]: - for line in ocr_results: - clean_line = re.sub(r"[^\w\u4e00-\u9fa5]", "", line) - if 2 <= len(clean_line) <= 10 and not re.search(r"\d", clean_line): - data["联系人/单位名"] = clean_line - break + for line in lines: + text = clean_text(line.text) + if not text: + continue + if data["电话"] and data["电话"] in text: + name_part = _sanitize_contact(text.replace(data["电话"], "")) + if name_part: + data["联系人/单位名"] = name_part + break + if not data["联系人/单位名"]: + for line in lines: + text = clean_text(line.text) + if 2 <= len(text) <= 20 and not re.search(r"\d", text): + data["联系人/单位名"] = _sanitize_contact(text) + break - # 4. 提取地址 - address_match = re.search( - r"([^,,。\s]*(?:省|市|区|县|乡|镇|路|街|村|组|号)[^,,。\s]*)", full_content - ) - if address_match: - data["地址"] = address_match.group(1) - else: - # 兜底:寻找较长的包含地名特征的行 - for line in ocr_results: - if any(k in line for k in ["省", "市", "区", "县", "乡", "镇", "村"]): - data["地址"] = line.strip() - break - - # 5. 提取编号 (长数字串) - # 排除邮编和电话后的最长数字串 - long_numbers = re.findall(r"\b\d{10,20}\b", full_content) - for num in long_numbers: - if num != data["电话"]: - data["编号"] = num - break + if not data["地址"]: + hint_lines = [line.text for line in lines if ADDRESS_HINT_PATTERN.search(line.text)] + if hint_lines: + hint_lines.sort(key=lambda txt: len(clean_text(txt)), reverse=True) + data["地址"] = _sanitize_address(hint_lines[0]) return data def save_to_excel(records: List[Dict[str, Any]], output_path: str): df = pd.DataFrame(records) - # 调整列顺序 cols = ["编号", "邮编", "地址", "联系人/单位名", "电话"] df = df.reindex(columns=cols) df.to_excel(output_path, index=False)