添加chat gemini、chatfire端点、 图片生成 gemini、chatfire 更轻松的AI配置
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
models "github.com/drama-generator/backend/domain/models"
|
||||
@@ -23,6 +24,24 @@ type ImageGenerationService struct {
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
// truncateImageURL 截断图片 URL,避免 base64 格式的 URL 占满日志
|
||||
func truncateImageURL(url string) string {
|
||||
if url == "" {
|
||||
return ""
|
||||
}
|
||||
// 如果是 data URI 格式(base64),只显示前缀
|
||||
if strings.HasPrefix(url, "data:") {
|
||||
if len(url) > 50 {
|
||||
return url[:50] + "...[base64 data]"
|
||||
}
|
||||
}
|
||||
// 普通 URL 如果过长也截断
|
||||
if len(url) > 100 {
|
||||
return url[:100] + "..."
|
||||
}
|
||||
return url
|
||||
}
|
||||
|
||||
func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService {
|
||||
return &ImageGenerationService{
|
||||
db: db,
|
||||
@@ -246,17 +265,23 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
||||
now := time.Now()
|
||||
|
||||
// 下载图片到本地存储(仅用于缓存,不更新数据库)
|
||||
if s.localStorage != nil && result.ImageURL != "" {
|
||||
// 仅下载 HTTP/HTTPS URL,跳过 data URI
|
||||
if s.localStorage != nil && result.ImageURL != "" &&
|
||||
(strings.HasPrefix(result.ImageURL, "http://") || strings.HasPrefix(result.ImageURL, "https://")) {
|
||||
_, err := s.localStorage.DownloadFromURL(result.ImageURL, "images")
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if len(errStr) > 200 {
|
||||
errStr = errStr[:200] + "..."
|
||||
}
|
||||
s.log.Warnw("Failed to download image to local storage",
|
||||
"error", err,
|
||||
"error", errStr,
|
||||
"id", imageGenID,
|
||||
"original_url", result.ImageURL)
|
||||
"original_url", truncateImageURL(result.ImageURL))
|
||||
} else {
|
||||
s.log.Infow("Image downloaded to local storage for caching",
|
||||
"id", imageGenID,
|
||||
"original_url", result.ImageURL)
|
||||
"original_url", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -291,7 +316,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
||||
} else {
|
||||
s.log.Infow("Storyboard updated with composed image",
|
||||
"storyboard_id", *imageGen.StoryboardID,
|
||||
"composed_image", result.ImageURL)
|
||||
"composed_image", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -306,7 +331,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
||||
} else {
|
||||
s.log.Infow("Scene updated with generated image",
|
||||
"scene_id", *imageGen.SceneID,
|
||||
"image_url", result.ImageURL)
|
||||
"image_url", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -317,7 +342,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
||||
} else {
|
||||
s.log.Infow("Character updated with generated image",
|
||||
"character_id", *imageGen.CharacterID,
|
||||
"image_url", result.ImageURL)
|
||||
"image_url", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -356,13 +381,33 @@ func (s *ImageGenerationService) getImageClient(provider string) (image.ImageCli
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
switch provider {
|
||||
// 使用配置中的 provider,如果没有则使用传入的 provider
|
||||
actualProvider := config.Provider
|
||||
if actualProvider == "" {
|
||||
actualProvider = provider
|
||||
}
|
||||
|
||||
// 根据 provider 自动设置默认端点
|
||||
var endpoint string
|
||||
var queryEndpoint string
|
||||
|
||||
switch actualProvider {
|
||||
case "openai", "dalle":
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
||||
case "stable_diffusion", "sd":
|
||||
return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "chatfire":
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "volcengine", "volces", "doubao":
|
||||
endpoint = "/images/generations"
|
||||
queryEndpoint = ""
|
||||
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
case "gemini", "google":
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
default:
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,13 +439,33 @@ func (s *ImageGenerationService) getImageClientWithModel(provider string, modelN
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
switch provider {
|
||||
// 使用配置中的 provider,如果没有则使用传入的 provider
|
||||
actualProvider := config.Provider
|
||||
if actualProvider == "" {
|
||||
actualProvider = provider
|
||||
}
|
||||
|
||||
// 根据 provider 自动设置默认端点
|
||||
var endpoint string
|
||||
var queryEndpoint string
|
||||
|
||||
switch actualProvider {
|
||||
case "openai", "dalle":
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
||||
case "stable_diffusion", "sd":
|
||||
return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "chatfire":
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "volcengine", "volces", "doubao":
|
||||
endpoint = "/images/generations"
|
||||
queryEndpoint = ""
|
||||
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
case "gemini", "google":
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
default:
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -716,21 +781,11 @@ func (s *ImageGenerationService) extractBackgroundsFromScript(scriptContent stri
|
||||
|
||||
请严格按照JSON格式输出,确保所有字段都使用中文。`, scriptContent)
|
||||
|
||||
messages := []ai.ChatMessage{
|
||||
{Role: "user", Content: prompt},
|
||||
}
|
||||
|
||||
resp, err := client.ChatCompletion(messages, ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
|
||||
response, err := client.GenerateText(prompt, "", ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to extract backgrounds with AI", "error", err)
|
||||
return nil, fmt.Errorf("AI提取场景失败: %w", err)
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("AI未返回有效响应")
|
||||
}
|
||||
|
||||
response := resp.Choices[0].Message.Content
|
||||
s.log.Infow("AI backgrounds extraction response", "length", len(response))
|
||||
|
||||
// 解析JSON响应
|
||||
|
||||
Reference in New Issue
Block a user