diff --git a/Package.swift b/Package.swift index a937160..336e978 100644 --- a/Package.swift +++ b/Package.swift @@ -18,9 +18,8 @@ let package = Package( name: "LivePhotoCore", dependencies: [], resources: [ - .copy("Resources/metadata.mov"), - // AI 超分辨率模型(Real-ESRGAN x4plus) - .process("Resources/RealESRGAN_x4plus.mlmodel") + .copy("Resources/metadata.mov") + // AI 模型已移至 On-Demand Resources,按需下载 ] ), .testTarget( diff --git a/Sources/LivePhotoCore/AIEnhancer/ODRManager.swift b/Sources/LivePhotoCore/AIEnhancer/ODRManager.swift new file mode 100644 index 0000000..3fd0209 --- /dev/null +++ b/Sources/LivePhotoCore/AIEnhancer/ODRManager.swift @@ -0,0 +1,201 @@ +// +// ODRManager.swift +// LivePhotoCore +// +// On-Demand Resources manager for AI model download. +// + +import Foundation +import os + +// MARK: - Download State + +/// Model download state +public enum ModelDownloadState: Sendable, Equatable { + case notDownloaded + case downloading(progress: Double) + case downloaded + case failed(String) + + public static func == (lhs: ModelDownloadState, rhs: ModelDownloadState) -> Bool { + switch (lhs, rhs) { + case (.notDownloaded, .notDownloaded): return true + case (.downloaded, .downloaded): return true + case let (.downloading(p1), .downloading(p2)): return p1 == p2 + case let (.failed(e1), .failed(e2)): return e1 == e2 + default: return false + } + } +} + +// MARK: - ODR Manager + +/// On-Demand Resources manager for AI model +public actor ODRManager { + public static let shared = ODRManager() + + private static let modelTag = "ai-model" + private static let modelName = "RealESRGAN_x4plus" + + private var resourceRequest: NSBundleResourceRequest? + private var cachedModelURL: URL? + private let logger = Logger(subsystem: "LivePhotoCore", category: "ODRManager") + + private init() {} + + // MARK: - Public API + + /// Check if model is available locally (either in ODR cache or bundle) + public func isModelAvailable() async -> Bool { + // First check if we have a cached URL + if let url = cachedModelURL, FileManager.default.fileExists(atPath: url.path) { + return true + } + + // Check bundle (development/fallback) + if getBundleModelURL() != nil { + return true + } + + // Check ODR conditionally (only available in app context) + return await checkODRAvailability() + } + + /// Get current download state + public func getDownloadState() async -> ModelDownloadState { + if await isModelAvailable() { + return .downloaded + } + + if resourceRequest != nil { + return .downloading(progress: 0) + } + + return .notDownloaded + } + + /// Download model with progress callback + /// - Parameter progress: Progress callback (0.0 to 1.0) + public func downloadModel(progress: @escaping @Sendable (Double) -> Void) async throws { + // Check if already available + if await isModelAvailable() { + logger.info("Model already available, skipping download") + progress(1.0) + return + } + + logger.info("Starting ODR download for model: \(Self.modelTag)") + + // Create resource request + let request = NSBundleResourceRequest(tags: [Self.modelTag]) + self.resourceRequest = request + + // Set up progress observation + let observation = request.progress.observe(\.fractionCompleted) { progressObj, _ in + Task { @MainActor in + progress(progressObj.fractionCompleted) + } + } + + defer { + observation.invalidate() + } + + do { + // Begin accessing resources + try await request.beginAccessingResources() + + logger.info("ODR download completed successfully") + + // Find and cache the model URL + if let url = findModelInBundle(request.bundle) { + cachedModelURL = url + logger.info("Model cached at: \(url.path)") + } + + progress(1.0) + } catch { + logger.error("ODR download failed: \(error.localizedDescription)") + self.resourceRequest = nil + throw AIEnhanceError.modelLoadFailed("Download failed: \(error.localizedDescription)") + } + } + + /// Get model URL (after download or from bundle) + public func getModelURL() -> URL? { + // Return cached URL if available + if let url = cachedModelURL { + return url + } + + // Check bundle fallback + if let url = getBundleModelURL() { + return url + } + + // Try to find in ODR bundle + if let request = resourceRequest, let url = findModelInBundle(request.bundle) { + cachedModelURL = url + return url + } + + return nil + } + + /// Release ODR resources when not in use + public func releaseResources() { + resourceRequest?.endAccessingResources() + resourceRequest = nil + cachedModelURL = nil + logger.info("ODR resources released") + } + + // MARK: - Private Helpers + + private func checkODRAvailability() async -> Bool { + // Use conditionallyBeginAccessingResources to check without triggering download + let request = NSBundleResourceRequest(tags: [Self.modelTag]) + + return await withCheckedContinuation { continuation in + request.conditionallyBeginAccessingResources { available in + if available { + // Model is already downloaded via ODR + self.logger.debug("ODR model is available locally") + } + continuation.resume(returning: available) + } + } + } + + private func getBundleModelURL() -> URL? { + // Try main bundle first + if let url = Bundle.main.url(forResource: Self.modelName, withExtension: "mlmodelc") { + return url + } + if let url = Bundle.main.url(forResource: Self.modelName, withExtension: "mlpackage") { + return url + } + + // Try SPM bundle (development) + #if SWIFT_PACKAGE + if let url = Bundle.module.url(forResource: Self.modelName, withExtension: "mlmodelc") { + return url + } + if let url = Bundle.module.url(forResource: Self.modelName, withExtension: "mlpackage") { + return url + } + #endif + + return nil + } + + private func findModelInBundle(_ bundle: Bundle) -> URL? { + if let url = bundle.url(forResource: Self.modelName, withExtension: "mlmodelc") { + return url + } + if let url = bundle.url(forResource: Self.modelName, withExtension: "mlpackage") { + return url + } + return nil + } +} diff --git a/to-live-photo/to-live-photo/Views/EditorView.swift b/to-live-photo/to-live-photo/Views/EditorView.swift index b96ef3e..36900e1 100644 --- a/to-live-photo/to-live-photo/Views/EditorView.swift +++ b/to-live-photo/to-live-photo/Views/EditorView.swift @@ -37,6 +37,9 @@ struct EditorView: View { // AI 超分辨率 @State private var aiEnhanceEnabled: Bool = false + @State private var aiModelNeedsDownload: Bool = false + @State private var aiModelDownloading: Bool = false + @State private var aiModelDownloadProgress: Double = 0 // 视频诊断 @State private var videoDiagnosis: VideoDiagnosis? @@ -370,10 +373,45 @@ struct EditorView: View { } } .tint(.purple) - .disabled(!AIEnhancer.isAvailable()) + .disabled(!AIEnhancer.isAvailable() || aiModelDownloading) + .onChange(of: aiEnhanceEnabled) { _, newValue in + if newValue { + checkAndDownloadModel() + } + } - if aiEnhanceEnabled { + // 模型下载进度 + if aiModelDownloading { + VStack(alignment: .leading, spacing: 8) { + HStack(spacing: 8) { + ProgressView() + .scaleEffect(0.8) + Text("正在下载 AI 模型...") + .font(.caption) + .foregroundStyle(.secondary) + } + + ProgressView(value: aiModelDownloadProgress) + .tint(.purple) + + Text(String(format: "%.0f%%", aiModelDownloadProgress * 100)) + .font(.caption2) + .foregroundStyle(.secondary) + } + .padding(.leading, 4) + } + + if aiEnhanceEnabled && !aiModelDownloading { VStack(alignment: .leading, spacing: 6) { + if aiModelNeedsDownload { + HStack(spacing: 4) { + Image(systemName: "arrow.down.circle") + .foregroundStyle(.orange) + .font(.caption) + Text("首次使用需下载 AI 模型(约 64MB)") + .font(.caption) + } + } HStack(spacing: 4) { Image(systemName: "sparkles") .foregroundStyle(.purple) @@ -415,6 +453,10 @@ struct EditorView: View { .padding(16) .background(Color.purple.opacity(0.1)) .clipShape(RoundedRectangle(cornerRadius: 12)) + .task { + // 检查模型是否需要下载 + aiModelNeedsDownload = await AIEnhancer.needsDownload() + } } // MARK: - 兼容模式开关 @@ -681,6 +723,46 @@ struct EditorView: View { return CropRect(x: cropX, y: cropY, width: cropWidth, height: cropHeight) } + private func checkAndDownloadModel() { + guard aiEnhanceEnabled else { return } + + Task { + // 检查是否需要下载 + let needsDownload = await AIEnhancer.needsDownload() + + await MainActor.run { + aiModelNeedsDownload = needsDownload + } + + if needsDownload { + await MainActor.run { + aiModelDownloading = true + aiModelDownloadProgress = 0 + } + + do { + try await AIEnhancer.downloadModel { progress in + Task { @MainActor in + aiModelDownloadProgress = progress + } + } + + await MainActor.run { + aiModelDownloading = false + aiModelNeedsDownload = false + } + } catch { + await MainActor.run { + aiModelDownloading = false + // 下载失败时禁用 AI 增强 + aiEnhanceEnabled = false + } + print("Failed to download AI model: \(error)") + } + } + } + } + private func startProcessing() { Analytics.shared.log(.editorGenerateClick, parameters: [ "trimStart": trimStart,