feat: M2-M4 完成,添加 AI 增强、设计系统、App Store 准备
新增功能: - AI 超分辨率模块 (Real-ESRGAN Core ML) - Soft UI 设计系统 (DesignSystem.swift) - 设置页、隐私政策页、引导页 - 最近作品管理器 App Store 准备: - 完善截图 (iPhone 6.7"/6.5", iPad 12.9") - App Store 元数据文档 - 修复应用图标 alpha 通道 - 更新显示名称为 Live Photo Studio 工程配置: - 配置 Git LFS 跟踪 mlmodel 文件 - 添加 Claude skill 开发指南 - 更新 .gitignore 规则 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
213
Sources/LivePhotoCore/AIEnhancer/RealESRGANProcessor.swift
Normal file
213
Sources/LivePhotoCore/AIEnhancer/RealESRGANProcessor.swift
Normal file
@@ -0,0 +1,213 @@
|
||||
//
|
||||
// RealESRGANProcessor.swift
|
||||
// LivePhotoCore
|
||||
//
|
||||
// Core ML inference logic for Real-ESRGAN model.
|
||||
// This model requires fixed 512x512 input and outputs 2048x2048.
|
||||
//
|
||||
|
||||
import Accelerate
|
||||
import CoreML
|
||||
import CoreVideo
|
||||
import Foundation
|
||||
import os
|
||||
|
||||
/// Real-ESRGAN Core ML model processor.
///
/// The bundled model has a fixed 512x512 input and produces a 4x upscaled
/// 2048x2048 output. Model state (`model`) is isolated to this actor; the
/// synchronous, compute-heavy Core ML prediction is dispatched off the actor
/// so other callers are not blocked while inference runs. Only nonisolated
/// members are touched from the background queue.
actor RealESRGANProcessor {
    private var model: MLModel?
    // Immutable Sendable `let`: safe to read from nonisolated contexts too.
    private let logger = Logger(subsystem: "LivePhotoCore", category: "RealESRGANProcessor")

    /// Fixed input edge length required by the model (512x512).
    static let inputSize: Int = 512

    /// Upscale factor (4x for Real-ESRGAN x4plus).
    static let scaleFactor: Int = 4

    /// Output edge length (`inputSize * scaleFactor` = 2048).
    static let outputSize: Int = inputSize * scaleFactor // 2048

    init() {}

    /// Load the Core ML model from the SPM resource bundle or the main bundle.
    ///
    /// Prefers a precompiled `.mlmodelc`; falls back to an uncompiled
    /// `.mlpackage`, which is compiled on-device before loading, because
    /// `MLModel.load(contentsOf:)` requires a compiled model and fails at
    /// runtime when handed a raw `.mlpackage` URL.
    ///
    /// - Throws: `AIEnhanceError.modelNotFound` if no model file is present;
    ///   `AIEnhanceError.modelLoadFailed` if compilation or loading fails.
    func loadModel() async throws {
        guard model == nil else {
            logger.debug("Model already loaded")
            return
        }

        logger.info("Loading Real-ESRGAN Core ML model...")

        let modelName = "RealESRGAN_x4plus"
        var modelURL: URL?
        var needsCompilation = false

        // Prefer the SPM resource bundle when building as a package.
        #if SWIFT_PACKAGE
        if let url = Bundle.module.url(forResource: modelName, withExtension: "mlmodelc") {
            modelURL = url
        } else if let url = Bundle.module.url(forResource: modelName, withExtension: "mlpackage") {
            modelURL = url
            needsCompilation = true
        }
        #endif

        // Fall back to the app's main bundle.
        if modelURL == nil {
            if let url = Bundle.main.url(forResource: modelName, withExtension: "mlmodelc") {
                modelURL = url
            } else if let url = Bundle.main.url(forResource: modelName, withExtension: "mlpackage") {
                modelURL = url
                needsCompilation = true
            }
        }

        guard var url = modelURL else {
            logger.error("Model file not found: \(modelName)")
            throw AIEnhanceError.modelNotFound
        }

        logger.info("Found model at: \(url.path)")

        // Configure model for optimal performance.
        let config = MLModelConfiguration()
        config.computeUnits = .all // Use the Neural Engine when available.

        do {
            if needsCompilation {
                // Compile the .mlpackage to an .mlmodelc in the caches
                // directory; MLModel.load only accepts compiled models.
                url = try await MLModel.compileModel(at: url)
                logger.info("Compiled model to: \(url.path)")
            }
            model = try await MLModel.load(contentsOf: url, configuration: config)
            logger.info("Model loaded successfully")
        } catch {
            logger.error("Failed to load model: \(error.localizedDescription)")
            throw AIEnhanceError.modelLoadFailed(error.localizedDescription)
        }
    }

    /// Unload the model, releasing its memory.
    func unloadModel() {
        model = nil
        logger.info("Model unloaded from memory")
    }

    /// Process a single 512x512 frame through the model.
    ///
    /// - Parameter pixelBuffer: Input image as a `CVPixelBuffer`. Must be
    ///   512x512; the pixel format must match the model's input description
    ///   (presumably BGRA per the original header — TODO confirm against the
    ///   model spec).
    /// - Returns: Enhanced image as interleaved RGBA bytes (2048x2048).
    /// - Throws: `AIEnhanceError.modelNotFound` if `loadModel()` has not
    ///   succeeded; `AIEnhanceError.inferenceError` for size mismatches or
    ///   inference failures; `CancellationError` if the task was cancelled.
    func processImage(_ pixelBuffer: CVPixelBuffer) async throws -> [UInt8] {
        guard let model else {
            throw AIEnhanceError.modelNotFound
        }

        // Verify the fixed input size up front so callers get a clear error
        // instead of an opaque Core ML failure.
        let width = CVPixelBufferGetWidth(pixelBuffer)
        let height = CVPixelBufferGetHeight(pixelBuffer)
        guard width == Self.inputSize, height == Self.inputSize else {
            throw AIEnhanceError.inferenceError(
                "Invalid input size \(width)x\(height), expected \(Self.inputSize)x\(Self.inputSize)"
            )
        }

        // Cooperative cancellation before committing to expensive work.
        try Task.checkCancellation()

        logger.info("Running inference on \(width)x\(height) image...")

        // MLModel.prediction(from:) is synchronous and thread-safe; run it on
        // a background queue so the actor stays responsive. The closure only
        // touches nonisolated state: the captured local `model`, the immutable
        // Sendable `logger`, and the `nonisolated` converter below — calling
        // an actor-isolated method here would violate actor isolation.
        let output: [UInt8] = try await withCheckedThrowingContinuation { continuation in
            DispatchQueue.global(qos: .userInitiated).async {
                do {
                    // Build the input feature dictionary for the model.
                    let inputFeature = try MLFeatureValue(pixelBuffer: pixelBuffer)
                    let inputProvider = try MLDictionaryFeatureProvider(
                        dictionary: ["input": inputFeature]
                    )

                    let prediction = try model.prediction(from: inputProvider)

                    // The model exposes its result as "activation_out", either
                    // as an MLMultiArray or as an image buffer.
                    guard let outputValue = prediction.featureValue(for: "activation_out") else {
                        continuation.resume(throwing: AIEnhanceError.inferenceError(
                            "Model output 'activation_out' not found"
                        ))
                        return
                    }

                    let rgbaData: [UInt8]
                    if let multiArray = outputValue.multiArrayValue {
                        // MultiArray path: planar [C, H, W] floats.
                        self.logger.info("Output is MultiArray: \(multiArray.shape)")
                        rgbaData = try self.multiArrayToRGBA(multiArray)
                    } else if let outputBuffer = outputValue.imageBufferValue {
                        // Image path: already a CVPixelBuffer.
                        let outWidth = CVPixelBufferGetWidth(outputBuffer)
                        let outHeight = CVPixelBufferGetHeight(outputBuffer)
                        self.logger.info("Output is Image: \(outWidth)x\(outHeight)")
                        rgbaData = try ImageFormatConverter.pixelBufferToRGBAData(outputBuffer)
                    } else {
                        continuation.resume(throwing: AIEnhanceError.inferenceError(
                            "Cannot extract data from model output"
                        ))
                        return
                    }

                    continuation.resume(returning: rgbaData)
                } catch let error as AIEnhanceError {
                    continuation.resume(throwing: error)
                } catch {
                    continuation.resume(throwing: AIEnhanceError.inferenceError(error.localizedDescription))
                }
            }
        }

        return output
    }

    /// Convert a planar CHW `MLMultiArray` into interleaved RGBA bytes.
    ///
    /// `nonisolated` so the background inference queue can call it directly
    /// without hopping back onto the actor; it reads no mutable actor state.
    ///
    /// - Parameter multiArray: Model output with shape [3, H, W] (RGB),
    ///   Float32 elements, values expected in [0, 1].
    /// - Returns: RGBA byte array of length `H * W * 4`, alpha fixed at 255.
    /// - Throws: `AIEnhanceError.inferenceError` on an unexpected shape or
    ///   element type.
    nonisolated private func multiArrayToRGBA(_ multiArray: MLMultiArray) throws -> [UInt8] {
        let shape = multiArray.shape.map { $0.intValue }

        // Expect [3, H, W] for RGB.
        guard shape.count == 3, shape[0] == 3 else {
            throw AIEnhanceError.inferenceError(
                "Unexpected output shape: \(shape), expected [3, H, W]"
            )
        }

        // dataPointer is only valid to reinterpret as Float32 when the array
        // actually holds Float32 elements; reading Float16/Double storage as
        // Float32 would silently produce garbage pixels.
        guard multiArray.dataType == .float32 else {
            throw AIEnhanceError.inferenceError(
                "Unexpected output data type \(multiArray.dataType.rawValue), expected Float32"
            )
        }

        let height = shape[1]
        let width = shape[2]

        logger.info("Converting MultiArray 3x\(height)x\(width) to RGBA")

        // Pre-fill with 255 so the alpha channel needs no per-pixel write.
        var rgbaData = [UInt8](repeating: 255, count: width * height * 4)

        let dataPointer = multiArray.dataPointer.assumingMemoryBound(to: Float32.self)
        let channelStride = height * width

        // Planar CHW -> interleaved RGBA. Model outputs are in [0, 1];
        // clamp, then scale to [0, 255] (truncating, matching the original).
        for pixelIndex in 0..<channelStride {
            let rgbaIndex = pixelIndex * 4

            let r = dataPointer[pixelIndex]
            let g = dataPointer[channelStride + pixelIndex]
            let b = dataPointer[2 * channelStride + pixelIndex]

            rgbaData[rgbaIndex + 0] = UInt8(clamping: Int(max(0, min(1, r)) * 255))
            rgbaData[rgbaIndex + 1] = UInt8(clamping: Int(max(0, min(1, g)) * 255))
            rgbaData[rgbaIndex + 2] = UInt8(clamping: Int(max(0, min(1, b)) * 255))
        }

        return rgbaData
    }
}
|
||||
Reference in New Issue
Block a user