添加chat gemini、chatfire端点、 图片生成 gemini、chatfire 更轻松的AI配置

This commit is contained in:
Connor
2026-01-14 02:25:41 +08:00
parent 4d38357ff6
commit 23b45efae9
22 changed files with 1512 additions and 405 deletions

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
models "github.com/drama-generator/backend/domain/models"
@@ -23,6 +24,24 @@ type ImageGenerationService struct {
log *logger.Logger
}
// truncateImageURL 截断图片 URL避免 base64 格式的 URL 占满日志
func truncateImageURL(url string) string {
if url == "" {
return ""
}
// 如果是 data URI 格式base64只显示前缀
if strings.HasPrefix(url, "data:") {
if len(url) > 50 {
return url[:50] + "...[base64 data]"
}
}
// 普通 URL 如果过长也截断
if len(url) > 100 {
return url[:100] + "..."
}
return url
}
func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService {
return &ImageGenerationService{
db: db,
@@ -246,17 +265,23 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
now := time.Now()
// 下载图片到本地存储(仅用于缓存,不更新数据库)
if s.localStorage != nil && result.ImageURL != "" {
// 仅下载 HTTP/HTTPS URL跳过 data URI
if s.localStorage != nil && result.ImageURL != "" &&
(strings.HasPrefix(result.ImageURL, "http://") || strings.HasPrefix(result.ImageURL, "https://")) {
_, err := s.localStorage.DownloadFromURL(result.ImageURL, "images")
if err != nil {
errStr := err.Error()
if len(errStr) > 200 {
errStr = errStr[:200] + "..."
}
s.log.Warnw("Failed to download image to local storage",
"error", err,
"error", errStr,
"id", imageGenID,
"original_url", result.ImageURL)
"original_url", truncateImageURL(result.ImageURL))
} else {
s.log.Infow("Image downloaded to local storage for caching",
"id", imageGenID,
"original_url", result.ImageURL)
"original_url", truncateImageURL(result.ImageURL))
}
}
@@ -291,7 +316,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
} else {
s.log.Infow("Storyboard updated with composed image",
"storyboard_id", *imageGen.StoryboardID,
"composed_image", result.ImageURL)
"composed_image", truncateImageURL(result.ImageURL))
}
}
@@ -306,7 +331,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
} else {
s.log.Infow("Scene updated with generated image",
"scene_id", *imageGen.SceneID,
"image_url", result.ImageURL)
"image_url", truncateImageURL(result.ImageURL))
}
}
@@ -317,7 +342,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
} else {
s.log.Infow("Character updated with generated image",
"character_id", *imageGen.CharacterID,
"image_url", result.ImageURL)
"image_url", truncateImageURL(result.ImageURL))
}
}
}
@@ -356,13 +381,33 @@ func (s *ImageGenerationService) getImageClient(provider string) (image.ImageCli
model = config.Model[0]
}
switch provider {
// 使用配置中的 provider如果没有则使用传入的 provider
actualProvider := config.Provider
if actualProvider == "" {
actualProvider = provider
}
// 根据 provider 自动设置默认端点
var endpoint string
var queryEndpoint string
switch actualProvider {
case "openai", "dalle":
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
case "stable_diffusion", "sd":
return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "chatfire":
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "volcengine", "volces", "doubao":
endpoint = "/images/generations"
queryEndpoint = ""
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
case "gemini", "google":
endpoint = "/v1beta/models/{model}:generateContent"
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
default:
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
}
}
@@ -394,13 +439,33 @@ func (s *ImageGenerationService) getImageClientWithModel(provider string, modelN
model = config.Model[0]
}
switch provider {
// 使用配置中的 provider如果没有则使用传入的 provider
actualProvider := config.Provider
if actualProvider == "" {
actualProvider = provider
}
// 根据 provider 自动设置默认端点
var endpoint string
var queryEndpoint string
switch actualProvider {
case "openai", "dalle":
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
case "stable_diffusion", "sd":
return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "chatfire":
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "volcengine", "volces", "doubao":
endpoint = "/images/generations"
queryEndpoint = ""
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
case "gemini", "google":
endpoint = "/v1beta/models/{model}:generateContent"
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
default:
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
}
}
@@ -716,21 +781,11 @@ func (s *ImageGenerationService) extractBackgroundsFromScript(scriptContent stri
请严格按照JSON格式输出确保所有字段都使用中文。`, scriptContent)
messages := []ai.ChatMessage{
{Role: "user", Content: prompt},
}
resp, err := client.ChatCompletion(messages, ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
response, err := client.GenerateText(prompt, "", ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
if err != nil {
s.log.Errorw("Failed to extract backgrounds with AI", "error", err)
return nil, fmt.Errorf("AI提取场景失败: %w", err)
}
if len(resp.Choices) == 0 {
return nil, fmt.Errorf("AI未返回有效响应")
}
response := resp.Choices[0].Message.Content
s.log.Infow("AI backgrounds extraction response", "length", len(response))
// 解析JSON响应