## 主要更新

- ✨ 更新所有依赖到最新稳定版本
- 📝 添加详细的项目文档和模型推荐
- 🔧 配置 VSCode Cloud Studio 预览功能
- 🐛 修复 PyTorch API 弃用警告

## 依赖更新

- diffusers: 0.27.2 → 0.35.2
- gradio: 4.21.0 → 5.46.0
- peft: 0.7.1 → 0.18.0
- Pillow: 9.5.0 → 11.3.0
- fastapi: 0.108.0 → 0.116.2

## 新增文件

- CLAUDE.md - 项目架构和开发指南
- UPGRADE_NOTES.md - 详细的升级说明
- .vscode/preview.yml - 预览配置
- .vscode/LAUNCH_GUIDE.md - 启动指南
- .gitignore - 更新的忽略规则

## 代码修复

- 修复 iopaint/model/ldm.py 中的 `torch.cuda.amp.autocast()` 弃用警告

## 文档更新

- README.md - 添加模型推荐和使用指南
- 完整的项目源码(iopaint/)
- Web 前端源码(web_app/)

🤖 Generated with Claude Code
46 lines
1.6 KiB
Python
Executable File
46 lines
1.6 KiB
Python
Executable File
from torch import nn
|
|
from .RNN import SequenceEncoder, Im2Seq, Im2Im
|
|
from .RecMv1_enhance import MobileNetV1Enhance
|
|
|
|
from .RecCTCHead import CTCHead
|
|
|
|
# Name -> class registries. RecModel reads the ``type`` key of the matching
# config section and instantiates the class registered under that name.
backbone_dict = {
    "MobileNetV1Enhance": MobileNetV1Enhance,
}

neck_dict = {
    'SequenceEncoder': SequenceEncoder,
    'Im2Seq': Im2Seq,
    # Keyed by the literal string 'None' (not the None object).
    'None': Im2Im,
}

head_dict = {
    'CTCHead': CTCHead,
}
|
|
|
|
|
|
class RecModel(nn.Module):
    """Recognition model composed of a backbone, a neck and a head.

    The concrete class for each stage is chosen by the ``type`` key of the
    matching config section and looked up in ``backbone_dict``, ``neck_dict``
    or ``head_dict``; every other key in that section is passed to the chosen
    class as a keyword argument.
    """

    def __init__(self, config):
        """Build the three stages described by ``config``.

        Args:
            config: mapping that also supports attribute access (it is read
                both as ``'in_channels' in config`` and as ``config.backbone``).
                Must provide ``in_channels`` plus ``backbone``, ``neck`` and
                ``head`` sections, each holding a ``type`` key and the
                constructor kwargs for that stage. The sections are copied,
                not modified, so one config can build several models.

        Raises:
            AssertionError: ``in_channels`` is missing, or a section's
                ``type`` is not registered.
            KeyError: a section has no ``type`` key.
        """
        super().__init__()
        # NOTE: validation uses ``assert`` (kept for compatibility), so these
        # checks are skipped when Python runs with -O.
        assert 'in_channels' in config, 'in_channels must in model config'

        backbone_type, backbone_kwargs = self._split_type(
            config.backbone, backbone_dict, 'backbone')
        self.backbone = backbone_dict[backbone_type](
            config.in_channels, **backbone_kwargs)

        # Each later stage is sized from the previous stage's out_channels.
        neck_type, neck_kwargs = self._split_type(config.neck, neck_dict, 'neck')
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **neck_kwargs)

        head_type, head_kwargs = self._split_type(config.head, head_dict, 'head')
        self.head = head_dict[head_type](self.neck.out_channels, **head_kwargs)

        self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'

    @staticmethod
    def _split_type(section, registry, section_name):
        """Return ``(type_name, kwargs)`` for one config section.

        The section is shallow-copied before ``type`` is popped. Popping from
        the caller's own mapping would leave it without ``type``, so a second
        ``RecModel(config)`` with the same config would raise KeyError.
        """
        kwargs = dict(section)
        type_name = kwargs.pop('type')
        assert type_name in registry, f'{section_name}.type must in {registry}'
        return type_name, kwargs

    def load_3rd_state_dict(self, _3rd_name, _state):
        """Hand third-party weights to each stage in turn.

        Calls ``load_3rd_state_dict(_3rd_name, _state)`` on the backbone, neck
        and head, in that order; every stage class must implement it.
        """
        self.backbone.load_3rd_state_dict(_3rd_name, _state)
        self.neck.load_3rd_state_dict(_3rd_name, _state)
        self.head.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        """Run ``x`` through backbone, neck and head; return the head output.

        ``x`` presumably has ``in_channels`` channels, since the backbone is
        built with that value (TODO confirm the expected layout with callers).
        """
        x = self.backbone(x)
        x = self.neck(x)
        return self.head(x)

    def encode(self, x):
        """Like ``forward``, but finish with ``self.head.ctc_encoder``.

        Presumably yields intermediate features rather than the head's final
        output; verify against the head implementation.
        """
        x = self.backbone(x)
        x = self.neck(x)
        return self.head.ctc_encoder(x)
|