Files
to-live-photo/Sources/LivePhotoCore/AIEnhancer/RealESRGANProcessor.swift
empty 5aba93e967 feat: M2-M4 完成,添加 AI 增强、设计系统、App Store 准备
新增功能:
- AI 超分辨率模块 (Real-ESRGAN Core ML)
- Soft UI 设计系统 (DesignSystem.swift)
- 设置页、隐私政策页、引导页
- 最近作品管理器

App Store 准备:
- 完善截图 (iPhone 6.7"/6.5", iPad 12.9")
- App Store 元数据文档
- 修复应用图标 alpha 通道
- 更新显示名称为 Live Photo Studio

工程配置:
- 配置 Git LFS 跟踪 mlmodel 文件
- 添加 Claude skill 开发指南
- 更新 .gitignore 规则

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-16 10:24:31 +08:00

214 lines
8.2 KiB
Swift

//
// RealESRGANProcessor.swift
// LivePhotoCore
//
// Core ML inference logic for Real-ESRGAN model.
// This model requires fixed 512x512 input and outputs 2048x2048.
//
import Accelerate
import CoreML
import CoreVideo
import Foundation
import os
/// Core ML processor for the Real-ESRGAN x4plus super-resolution model.
///
/// The model takes a fixed 512x512 input and produces a 2048x2048 output.
/// Actor isolation protects the cached `MLModel`. The inference and pixel-conversion helpers are
/// `nonisolated` / `static` because they run on a background dispatch queue and only touch
/// their arguments and the immutable `logger`.
actor RealESRGANProcessor {
    private var model: MLModel?
    private let logger = Logger(subsystem: "LivePhotoCore", category: "RealESRGANProcessor")

    /// Fixed input size required by the model (512x512)
    static let inputSize: Int = 512

    /// Scale factor (4x for Real-ESRGAN x4plus)
    static let scaleFactor: Int = 4

    /// Output size (inputSize * scaleFactor = 2048)
    static let outputSize: Int = inputSize * scaleFactor // 2048

    init() {}

    // MARK: - Model lifecycle

    /// Load the Core ML model from the SPM resource bundle or the main bundle.
    ///
    /// Does nothing when a model is already loaded. A compiled `.mlmodelc` is preferred; when only
    /// an `.mlpackage` is found it is compiled first, because `MLModel.load` requires a compiled model.
    ///
    /// - Throws: `AIEnhanceError.modelNotFound` when no model resource exists,
    ///   `AIEnhanceError.modelLoadFailed` when compiling or loading fails.
    func loadModel() async throws {
        guard model == nil else {
            logger.debug("Model already loaded")
            return
        }

        logger.info("Loading Real-ESRGAN Core ML model...")

        let modelName = "RealESRGAN_x4plus"
        guard let url = Self.locateModel(named: modelName) else {
            logger.error("Model file not found: \(modelName)")
            throw AIEnhanceError.modelNotFound
        }
        logger.info("Found model at: \(url.path)")

        // Configure model for optimal performance
        let config = MLModelConfiguration()
        config.computeUnits = .all // Use Neural Engine when available

        do {
            let loadableURL: URL
            if url.pathExtension == "mlmodelc" {
                loadableURL = url
            } else {
                // An .mlpackage is the uncompiled source form; loading it directly fails.
                logger.info("Compiling model package before loading...")
                loadableURL = try Self.compileModel(at: url)
            }
            model = try await MLModel.load(contentsOf: loadableURL, configuration: config)
            logger.info("Model loaded successfully")
        } catch {
            logger.error("Failed to load model: \(error.localizedDescription)")
            throw AIEnhanceError.modelLoadFailed(error.localizedDescription)
        }
    }

    /// Unload model from memory
    func unloadModel() {
        model = nil
        logger.info("Model unloaded from memory")
    }

    // MARK: - Inference

    /// Process a 512x512 image through the model.
    ///
    /// - Parameter pixelBuffer: Input image as CVPixelBuffer (must be 512x512, BGRA format).
    ///   Only the dimensions are verified here.
    /// - Returns: Enhanced image as an interleaved RGBA byte array (2048x2048 for this model).
    /// - Throws: `AIEnhanceError.modelNotFound` when no model is loaded,
    ///   `AIEnhanceError.inferenceError` for a wrong input size or any prediction/conversion failure,
    ///   `CancellationError` when the task was cancelled before inference started.
    func processImage(_ pixelBuffer: CVPixelBuffer) async throws -> [UInt8] {
        guard let model else {
            throw AIEnhanceError.modelNotFound
        }

        // Verify input size
        let width = CVPixelBufferGetWidth(pixelBuffer)
        let height = CVPixelBufferGetHeight(pixelBuffer)
        guard width == Self.inputSize, height == Self.inputSize else {
            throw AIEnhanceError.inferenceError(
                "Invalid input size \(width)x\(height), expected \(Self.inputSize)x\(Self.inputSize)"
            )
        }

        // Check for cancellation
        try Task.checkCancellation()

        logger.info("Running inference on \(width)x\(height) image...")

        // The synchronous prediction is pushed onto a global queue so it does not block the actor.
        // The continuation is resumed exactly once, with whatever the helper returned or threw.
        return try await withCheckedThrowingContinuation { continuation in
            DispatchQueue.global(qos: .userInitiated).async {
                continuation.resume(with: Result {
                    try self.runInference(on: pixelBuffer, using: model)
                })
            }
        }
    }

    // MARK: - Private helpers

    /// Look up the model resource, preferring the SPM bundle over the main bundle and a compiled
    /// `.mlmodelc` over an `.mlpackage`.
    private static func locateModel(named name: String) -> URL? {
        var bundles: [Bundle] = []
        #if SWIFT_PACKAGE
        bundles.append(Bundle.module)
        #endif
        bundles.append(Bundle.main)

        for bundle in bundles {
            for fileExtension in ["mlmodelc", "mlpackage"] {
                if let url = bundle.url(forResource: name, withExtension: fileExtension) {
                    return url
                }
            }
        }
        return nil
    }

    /// Compile an uncompiled model and return the URL of the resulting `.mlmodelc`.
    ///
    /// Kept as a synchronous function so the call resolves to the synchronous
    /// `MLModel.compileModel(at:)` overload, which is available on older OS versions.
    private static func compileModel(at url: URL) throws -> URL {
        try MLModel.compileModel(at: url)
    }

    /// Run one synchronous prediction and convert the `activation_out` feature to RGBA bytes.
    ///
    /// `AIEnhanceError` values pass through unchanged; any other error is wrapped in
    /// `AIEnhanceError.inferenceError`.
    private nonisolated func runInference(
        on pixelBuffer: CVPixelBuffer,
        using model: MLModel
    ) throws -> [UInt8] {
        do {
            // Create input feature from pixel buffer
            let inputFeature = MLFeatureValue(pixelBuffer: pixelBuffer)
            let inputProvider = try MLDictionaryFeatureProvider(
                dictionary: ["input": inputFeature]
            )

            let prediction = try model.prediction(from: inputProvider)

            // The model outputs to "activation_out" as either MultiArray or Image
            guard let outputValue = prediction.featureValue(for: "activation_out") else {
                throw AIEnhanceError.inferenceError("Model output 'activation_out' not found")
            }

            if let multiArray = outputValue.multiArrayValue {
                logger.info("Output is MultiArray: \(multiArray.shape)")
                return try multiArrayToRGBA(multiArray)
            }

            if let outputBuffer = outputValue.imageBufferValue {
                let outWidth = CVPixelBufferGetWidth(outputBuffer)
                let outHeight = CVPixelBufferGetHeight(outputBuffer)
                logger.info("Output is Image: \(outWidth)x\(outHeight)")
                return try ImageFormatConverter.pixelBufferToRGBAData(outputBuffer)
            }

            throw AIEnhanceError.inferenceError("Cannot extract data from model output")
        } catch let error as AIEnhanceError {
            throw error
        } catch {
            throw AIEnhanceError.inferenceError(error.localizedDescription)
        }
    }

    /// Convert a channel-first RGB `MLMultiArray` to an interleaved RGBA byte array.
    ///
    /// - Parameter multiArray: Model output shaped `[3, H, W]`, optionally preceded by singleton
    ///   dimensions (for example `[1, 3, H, W]`). Element type must be float32 or double.
    /// - Returns: RGBA byte array of `H * W * 4` bytes with alpha set to 255.
    /// - Throws: `AIEnhanceError.inferenceError` for an unexpected shape or element type.
    private nonisolated func multiArrayToRGBA(_ multiArray: MLMultiArray) throws -> [UInt8] {
        let shape = multiArray.shape.map(\.intValue)
        let strides = multiArray.strides.map(\.intValue)

        // Accept [3, H, W] with any number of leading singleton (batch) dimensions.
        guard shape.count >= 3,
              strides.count == shape.count,
              shape.dropLast(3).allSatisfy({ $0 == 1 }),
              shape[shape.count - 3] == 3
        else {
            throw AIEnhanceError.inferenceError(
                "Unexpected output shape: \(shape), expected [3, H, W]"
            )
        }

        let channels = shape[shape.count - 3]
        let height = shape[shape.count - 2]
        let width = shape[shape.count - 1]

        // Use the array's own strides: the backing storage is not guaranteed to be contiguous.
        let channelStride = strides[strides.count - 3]
        let rowStride = strides[strides.count - 2]
        let columnStride = strides[strides.count - 1]

        logger.info("Converting MultiArray \(channels)x\(height)x\(width) to RGBA")

        // The raw buffer must be read with the element type the array actually holds;
        // reinterpreting another type as Float32 produces garbage or reads out of bounds.
        switch multiArray.dataType {
        case .float32:
            return Self.interleaveRGBA(
                planes: UnsafePointer(multiArray.dataPointer.assumingMemoryBound(to: Float.self)),
                width: width,
                height: height,
                channelStride: channelStride,
                rowStride: rowStride,
                columnStride: columnStride
            )
        case .double:
            return Self.interleaveRGBA(
                planes: UnsafePointer(multiArray.dataPointer.assumingMemoryBound(to: Double.self)),
                width: width,
                height: height,
                channelStride: channelStride,
                rowStride: rowStride,
                columnStride: columnStride
            )
        default:
            throw AIEnhanceError.inferenceError(
                "Unsupported output data type: \(multiArray.dataType.rawValue)"
            )
        }
    }

    /// Interleave three strided planes of unit-range samples into RGBA bytes.
    ///
    /// Samples are clamped to `[0, 1]`, scaled to `[0, 255]` and rounded to the nearest integer.
    private static func interleaveRGBA<Scalar: BinaryFloatingPoint>(
        planes: UnsafePointer<Scalar>,
        width: Int,
        height: Int,
        channelStride: Int,
        rowStride: Int,
        columnStride: Int
    ) -> [UInt8] {
        // Pre-filled with 255 so the alpha byte of every pixel is already opaque.
        var rgba = [UInt8](repeating: 255, count: width * height * 4)

        for y in 0..<height {
            let rowOffset = y * rowStride
            for x in 0..<width {
                let sourceOffset = rowOffset + x * columnStride
                let destination = (y * width + x) * 4
                for channel in 0..<3 {
                    let sample = Float(planes[channel * channelStride + sourceOffset])
                    // Argument order matters: `min(1, sample)` maps NaN to 1, so a NaN can never
                    // reach the integer conversion below (which would trap).
                    let clamped = max(0, min(1, sample))
                    rgba[destination + channel] = UInt8((clamped * 255).rounded())
                }
            }
        }
        return rgba
    }
}