添加chat gemini、chatfire端点、 图片生成 gemini、chatfire 更轻松的AI配置
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -63,3 +63,5 @@ configs/config.yaml
|
|||||||
# Docker publish documentation (optional)
|
# Docker publish documentation (optional)
|
||||||
DOCKER_PUBLISH.md
|
DOCKER_PUBLISH.md
|
||||||
build.sh
|
build.sh
|
||||||
|
/data/storage/
|
||||||
|
/web/package-lock.json
|
||||||
|
|||||||
13
Dockerfile
13
Dockerfile
@@ -3,6 +3,9 @@
|
|||||||
# ==================== 阶段1: 构建前端 ====================
|
# ==================== 阶段1: 构建前端 ====================
|
||||||
FROM node:20-alpine AS frontend-builder
|
FROM node:20-alpine AS frontend-builder
|
||||||
|
|
||||||
|
# 配置 npm 镜像源(国内加速)
|
||||||
|
RUN npm config set registry https://registry.npmmirror.com
|
||||||
|
|
||||||
WORKDIR /app/web
|
WORKDIR /app/web
|
||||||
|
|
||||||
# 复制前端依赖文件
|
# 复制前端依赖文件
|
||||||
@@ -59,15 +62,12 @@ RUN apk add --no-cache \
|
|||||||
tzdata \
|
tzdata \
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
sqlite-libs \
|
sqlite-libs \
|
||||||
|
wget \
|
||||||
&& rm -rf /var/cache/apk/*
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
# 设置时区
|
# 设置时区
|
||||||
ENV TZ=Asia/Shanghai
|
ENV TZ=Asia/Shanghai
|
||||||
|
|
||||||
# 创建非 root 用户
|
|
||||||
RUN addgroup -g 1000 app && \
|
|
||||||
adduser -D -u 1000 -G app app
|
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# 从构建阶段复制可执行文件
|
# 从构建阶段复制可执行文件
|
||||||
@@ -83,10 +83,7 @@ RUN cp ./configs/config.example.yaml ./configs/config.yaml
|
|||||||
# 复制数据库迁移文件
|
# 复制数据库迁移文件
|
||||||
COPY migrations ./migrations/
|
COPY migrations ./migrations/
|
||||||
|
|
||||||
# 切换到非 root 用户
|
# 创建数据目录(root 用户运行,无需权限设置)
|
||||||
USER app
|
|
||||||
|
|
||||||
# 创建数据目录(在 app 用户下创建,确保权限正确)
|
|
||||||
RUN mkdir -p /app/data/storage
|
RUN mkdir -p /app/data/storage
|
||||||
|
|
||||||
# 暴露端口
|
# 暴露端口
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
|
|||||||
type CreateAIConfigRequest struct {
|
type CreateAIConfigRequest struct {
|
||||||
ServiceType string `json:"service_type" binding:"required,oneof=text image video"`
|
ServiceType string `json:"service_type" binding:"required,oneof=text image video"`
|
||||||
Name string `json:"name" binding:"required,min=1,max=100"`
|
Name string `json:"name" binding:"required,min=1,max=100"`
|
||||||
|
Provider string `json:"provider" binding:"required"`
|
||||||
BaseURL string `json:"base_url" binding:"required,url"`
|
BaseURL string `json:"base_url" binding:"required,url"`
|
||||||
APIKey string `json:"api_key" binding:"required"`
|
APIKey string `json:"api_key" binding:"required"`
|
||||||
Model models.ModelField `json:"model" binding:"required"`
|
Model models.ModelField `json:"model" binding:"required"`
|
||||||
@@ -37,6 +38,7 @@ type CreateAIConfigRequest struct {
|
|||||||
|
|
||||||
type UpdateAIConfigRequest struct {
|
type UpdateAIConfigRequest struct {
|
||||||
Name string `json:"name" binding:"omitempty,min=1,max=100"`
|
Name string `json:"name" binding:"omitempty,min=1,max=100"`
|
||||||
|
Provider string `json:"provider"`
|
||||||
BaseURL string `json:"base_url" binding:"omitempty,url"`
|
BaseURL string `json:"base_url" binding:"omitempty,url"`
|
||||||
APIKey string `json:"api_key"`
|
APIKey string `json:"api_key"`
|
||||||
Model *models.ModelField `json:"model"`
|
Model *models.ModelField `json:"model"`
|
||||||
@@ -52,18 +54,53 @@ type TestConnectionRequest struct {
|
|||||||
BaseURL string `json:"base_url" binding:"required,url"`
|
BaseURL string `json:"base_url" binding:"required,url"`
|
||||||
APIKey string `json:"api_key" binding:"required"`
|
APIKey string `json:"api_key" binding:"required"`
|
||||||
Model models.ModelField `json:"model" binding:"required"`
|
Model models.ModelField `json:"model" binding:"required"`
|
||||||
|
Provider string `json:"provider"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
|
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
|
||||||
|
// 根据 provider 和 service_type 自动设置 endpoint
|
||||||
|
endpoint := req.Endpoint
|
||||||
|
queryEndpoint := req.QueryEndpoint
|
||||||
|
|
||||||
|
if endpoint == "" {
|
||||||
|
switch req.Provider {
|
||||||
|
case "gemini", "google":
|
||||||
|
if req.ServiceType == "text" {
|
||||||
|
endpoint = "/v1beta/models/{model}:generateContent"
|
||||||
|
} else if req.ServiceType == "image" {
|
||||||
|
endpoint = "/v1beta/models/{model}:generateContent"
|
||||||
|
}
|
||||||
|
case "openai", "chatfire":
|
||||||
|
if req.ServiceType == "text" {
|
||||||
|
endpoint = "/chat/completions"
|
||||||
|
} else if req.ServiceType == "image" {
|
||||||
|
endpoint = "/images/generations"
|
||||||
|
} else if req.ServiceType == "video" {
|
||||||
|
endpoint = "/video/generations"
|
||||||
|
if queryEndpoint == "" {
|
||||||
|
queryEndpoint = "/v1/video/task/{taskId}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// 默认使用 OpenAI 格式
|
||||||
|
if req.ServiceType == "text" {
|
||||||
|
endpoint = "/chat/completions"
|
||||||
|
} else if req.ServiceType == "image" {
|
||||||
|
endpoint = "/images/generations"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
config := &models.AIServiceConfig{
|
config := &models.AIServiceConfig{
|
||||||
ServiceType: req.ServiceType,
|
ServiceType: req.ServiceType,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
|
Provider: req.Provider,
|
||||||
BaseURL: req.BaseURL,
|
BaseURL: req.BaseURL,
|
||||||
APIKey: req.APIKey,
|
APIKey: req.APIKey,
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
Endpoint: req.Endpoint,
|
Endpoint: endpoint,
|
||||||
QueryEndpoint: req.QueryEndpoint,
|
QueryEndpoint: queryEndpoint,
|
||||||
Priority: req.Priority,
|
Priority: req.Priority,
|
||||||
IsDefault: req.IsDefault,
|
IsDefault: req.IsDefault,
|
||||||
IsActive: true,
|
IsActive: true,
|
||||||
@@ -75,7 +112,7 @@ func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceC
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.log.Infow("AI config created", "config_id", config.ID)
|
s.log.Infow("AI config created", "config_id", config.ID, "provider", req.Provider, "endpoint", endpoint)
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,6 +162,9 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
|
|||||||
if req.Name != "" {
|
if req.Name != "" {
|
||||||
updates["name"] = req.Name
|
updates["name"] = req.Name
|
||||||
}
|
}
|
||||||
|
if req.Provider != "" {
|
||||||
|
updates["provider"] = req.Provider
|
||||||
|
}
|
||||||
if req.BaseURL != "" {
|
if req.BaseURL != "" {
|
||||||
updates["base_url"] = req.BaseURL
|
updates["base_url"] = req.BaseURL
|
||||||
}
|
}
|
||||||
@@ -137,9 +177,30 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
|
|||||||
if req.Priority != nil {
|
if req.Priority != nil {
|
||||||
updates["priority"] = *req.Priority
|
updates["priority"] = *req.Priority
|
||||||
}
|
}
|
||||||
if req.Endpoint != "" {
|
|
||||||
|
// 如果提供了 provider,根据 provider 和 service_type 自动设置 endpoint
|
||||||
|
if req.Provider != "" && req.Endpoint == "" {
|
||||||
|
provider := req.Provider
|
||||||
|
serviceType := config.ServiceType
|
||||||
|
|
||||||
|
switch provider {
|
||||||
|
case "gemini", "google":
|
||||||
|
if serviceType == "text" || serviceType == "image" {
|
||||||
|
updates["endpoint"] = "/v1beta/models/{model}:generateContent"
|
||||||
|
}
|
||||||
|
case "openai", "chatfire":
|
||||||
|
if serviceType == "text" {
|
||||||
|
updates["endpoint"] = "/chat/completions"
|
||||||
|
} else if serviceType == "image" {
|
||||||
|
updates["endpoint"] = "/images/generations"
|
||||||
|
} else if serviceType == "video" {
|
||||||
|
updates["endpoint"] = "/video/generations"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if req.Endpoint != "" {
|
||||||
updates["endpoint"] = req.Endpoint
|
updates["endpoint"] = req.Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
// 允许清空query_endpoint,所以不检查是否为空
|
// 允许清空query_endpoint,所以不检查是否为空
|
||||||
updates["query_endpoint"] = req.QueryEndpoint
|
updates["query_endpoint"] = req.QueryEndpoint
|
||||||
if req.Settings != "" {
|
if req.Settings != "" {
|
||||||
@@ -179,13 +240,51 @@ func (s *AIService) DeleteConfig(configID uint) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
|
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
|
||||||
|
s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model))
|
||||||
|
|
||||||
// 使用第一个模型进行测试
|
// 使用第一个模型进行测试
|
||||||
model := ""
|
model := ""
|
||||||
if len(req.Model) > 0 {
|
if len(req.Model) > 0 {
|
||||||
model = req.Model[0]
|
model = req.Model[0]
|
||||||
}
|
}
|
||||||
client := ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, req.Endpoint)
|
s.log.Infow("Using model for test", "model", model, "provider", req.Provider)
|
||||||
return client.TestConnection()
|
|
||||||
|
// 根据 provider 参数选择客户端
|
||||||
|
var client ai.AIClient
|
||||||
|
var endpoint string
|
||||||
|
|
||||||
|
switch req.Provider {
|
||||||
|
case "gemini", "google":
|
||||||
|
// Gemini
|
||||||
|
s.log.Infow("Using Gemini client", "baseURL", req.BaseURL)
|
||||||
|
endpoint = "/v1beta/models/{model}:generateContent"
|
||||||
|
client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||||
|
case "openai", "chatfire":
|
||||||
|
// OpenAI 格式(包括 chatfire 等)
|
||||||
|
s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider)
|
||||||
|
endpoint = req.Endpoint
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = "/chat/completions"
|
||||||
|
}
|
||||||
|
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||||
|
default:
|
||||||
|
// 默认使用 OpenAI 格式
|
||||||
|
s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL)
|
||||||
|
endpoint = req.Endpoint
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = "/chat/completions"
|
||||||
|
}
|
||||||
|
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Infow("Calling TestConnection on client", "endpoint", endpoint)
|
||||||
|
err := client.TestConnection()
|
||||||
|
if err != nil {
|
||||||
|
s.log.Errorw("TestConnection failed", "error", err)
|
||||||
|
} else {
|
||||||
|
s.log.Infow("TestConnection succeeded")
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
|
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
|
||||||
@@ -228,7 +327,7 @@ func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*mo
|
|||||||
return nil, errors.New("no config found for model: " + modelName)
|
return nil, errors.New("no config found for model: " + modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
|
func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) {
|
||||||
config, err := s.GetDefaultConfig(serviceType)
|
config, err := s.GetDefaultConfig(serviceType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -240,7 +339,25 @@ func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
|
|||||||
model = config.Model[0]
|
model = config.Model[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, config.Endpoint), nil
|
// 使用数据库配置中的 endpoint,如果为空则根据 provider 设置默认值
|
||||||
|
endpoint := config.Endpoint
|
||||||
|
if endpoint == "" {
|
||||||
|
switch config.Provider {
|
||||||
|
case "gemini", "google":
|
||||||
|
endpoint = "/v1beta/models/{model}:generateContent"
|
||||||
|
default:
|
||||||
|
endpoint = "/chat/completions"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据 provider 创建对应的客户端
|
||||||
|
switch config.Provider {
|
||||||
|
case "gemini", "google":
|
||||||
|
return ai.NewGeminiClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||||
|
default:
|
||||||
|
// openai, chatfire 等其他厂商都使用 OpenAI 格式
|
||||||
|
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {
|
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
models "github.com/drama-generator/backend/domain/models"
|
models "github.com/drama-generator/backend/domain/models"
|
||||||
@@ -23,6 +24,24 @@ type ImageGenerationService struct {
|
|||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// truncateImageURL 截断图片 URL,避免 base64 格式的 URL 占满日志
|
||||||
|
func truncateImageURL(url string) string {
|
||||||
|
if url == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// 如果是 data URI 格式(base64),只显示前缀
|
||||||
|
if strings.HasPrefix(url, "data:") {
|
||||||
|
if len(url) > 50 {
|
||||||
|
return url[:50] + "...[base64 data]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 普通 URL 如果过长也截断
|
||||||
|
if len(url) > 100 {
|
||||||
|
return url[:100] + "..."
|
||||||
|
}
|
||||||
|
return url
|
||||||
|
}
|
||||||
|
|
||||||
func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService {
|
func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService {
|
||||||
return &ImageGenerationService{
|
return &ImageGenerationService{
|
||||||
db: db,
|
db: db,
|
||||||
@@ -246,17 +265,23 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// 下载图片到本地存储(仅用于缓存,不更新数据库)
|
// 下载图片到本地存储(仅用于缓存,不更新数据库)
|
||||||
if s.localStorage != nil && result.ImageURL != "" {
|
// 仅下载 HTTP/HTTPS URL,跳过 data URI
|
||||||
|
if s.localStorage != nil && result.ImageURL != "" &&
|
||||||
|
(strings.HasPrefix(result.ImageURL, "http://") || strings.HasPrefix(result.ImageURL, "https://")) {
|
||||||
_, err := s.localStorage.DownloadFromURL(result.ImageURL, "images")
|
_, err := s.localStorage.DownloadFromURL(result.ImageURL, "images")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
errStr := err.Error()
|
||||||
|
if len(errStr) > 200 {
|
||||||
|
errStr = errStr[:200] + "..."
|
||||||
|
}
|
||||||
s.log.Warnw("Failed to download image to local storage",
|
s.log.Warnw("Failed to download image to local storage",
|
||||||
"error", err,
|
"error", errStr,
|
||||||
"id", imageGenID,
|
"id", imageGenID,
|
||||||
"original_url", result.ImageURL)
|
"original_url", truncateImageURL(result.ImageURL))
|
||||||
} else {
|
} else {
|
||||||
s.log.Infow("Image downloaded to local storage for caching",
|
s.log.Infow("Image downloaded to local storage for caching",
|
||||||
"id", imageGenID,
|
"id", imageGenID,
|
||||||
"original_url", result.ImageURL)
|
"original_url", truncateImageURL(result.ImageURL))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -291,7 +316,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
|||||||
} else {
|
} else {
|
||||||
s.log.Infow("Storyboard updated with composed image",
|
s.log.Infow("Storyboard updated with composed image",
|
||||||
"storyboard_id", *imageGen.StoryboardID,
|
"storyboard_id", *imageGen.StoryboardID,
|
||||||
"composed_image", result.ImageURL)
|
"composed_image", truncateImageURL(result.ImageURL))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,7 +331,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
|||||||
} else {
|
} else {
|
||||||
s.log.Infow("Scene updated with generated image",
|
s.log.Infow("Scene updated with generated image",
|
||||||
"scene_id", *imageGen.SceneID,
|
"scene_id", *imageGen.SceneID,
|
||||||
"image_url", result.ImageURL)
|
"image_url", truncateImageURL(result.ImageURL))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -317,7 +342,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
|||||||
} else {
|
} else {
|
||||||
s.log.Infow("Character updated with generated image",
|
s.log.Infow("Character updated with generated image",
|
||||||
"character_id", *imageGen.CharacterID,
|
"character_id", *imageGen.CharacterID,
|
||||||
"image_url", result.ImageURL)
|
"image_url", truncateImageURL(result.ImageURL))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -356,13 +381,33 @@ func (s *ImageGenerationService) getImageClient(provider string) (image.ImageCli
|
|||||||
model = config.Model[0]
|
model = config.Model[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
switch provider {
|
// 使用配置中的 provider,如果没有则使用传入的 provider
|
||||||
|
actualProvider := config.Provider
|
||||||
|
if actualProvider == "" {
|
||||||
|
actualProvider = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据 provider 自动设置默认端点
|
||||||
|
var endpoint string
|
||||||
|
var queryEndpoint string
|
||||||
|
|
||||||
|
switch actualProvider {
|
||||||
case "openai", "dalle":
|
case "openai", "dalle":
|
||||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
endpoint = "/images/generations"
|
||||||
case "stable_diffusion", "sd":
|
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||||
return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
|
case "chatfire":
|
||||||
|
endpoint = "/images/generations"
|
||||||
|
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||||
|
case "volcengine", "volces", "doubao":
|
||||||
|
endpoint = "/images/generations"
|
||||||
|
queryEndpoint = ""
|
||||||
|
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||||
|
case "gemini", "google":
|
||||||
|
endpoint = "/v1beta/models/{model}:generateContent"
|
||||||
|
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||||
default:
|
default:
|
||||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
endpoint = "/images/generations"
|
||||||
|
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -394,13 +439,33 @@ func (s *ImageGenerationService) getImageClientWithModel(provider string, modelN
|
|||||||
model = config.Model[0]
|
model = config.Model[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
switch provider {
|
// 使用配置中的 provider,如果没有则使用传入的 provider
|
||||||
|
actualProvider := config.Provider
|
||||||
|
if actualProvider == "" {
|
||||||
|
actualProvider = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据 provider 自动设置默认端点
|
||||||
|
var endpoint string
|
||||||
|
var queryEndpoint string
|
||||||
|
|
||||||
|
switch actualProvider {
|
||||||
case "openai", "dalle":
|
case "openai", "dalle":
|
||||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
endpoint = "/images/generations"
|
||||||
case "stable_diffusion", "sd":
|
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||||
return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
|
case "chatfire":
|
||||||
|
endpoint = "/images/generations"
|
||||||
|
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||||
|
case "volcengine", "volces", "doubao":
|
||||||
|
endpoint = "/images/generations"
|
||||||
|
queryEndpoint = ""
|
||||||
|
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||||
|
case "gemini", "google":
|
||||||
|
endpoint = "/v1beta/models/{model}:generateContent"
|
||||||
|
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||||
default:
|
default:
|
||||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
endpoint = "/images/generations"
|
||||||
|
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -716,21 +781,11 @@ func (s *ImageGenerationService) extractBackgroundsFromScript(scriptContent stri
|
|||||||
|
|
||||||
请严格按照JSON格式输出,确保所有字段都使用中文。`, scriptContent)
|
请严格按照JSON格式输出,确保所有字段都使用中文。`, scriptContent)
|
||||||
|
|
||||||
messages := []ai.ChatMessage{
|
response, err := client.GenerateText(prompt, "", ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
|
||||||
{Role: "user", Content: prompt},
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.ChatCompletion(messages, ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Errorw("Failed to extract backgrounds with AI", "error", err)
|
s.log.Errorw("Failed to extract backgrounds with AI", "error", err)
|
||||||
return nil, fmt.Errorf("AI提取场景失败: %w", err)
|
return nil, fmt.Errorf("AI提取场景失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(resp.Choices) == 0 {
|
|
||||||
return nil, fmt.Errorf("AI未返回有效响应")
|
|
||||||
}
|
|
||||||
|
|
||||||
response := resp.Choices[0].Message.Content
|
|
||||||
s.log.Infow("AI backgrounds extraction response", "length", len(response))
|
s.log.Infow("AI backgrounds extraction response", "length", len(response))
|
||||||
|
|
||||||
// 解析JSON响应
|
// 解析JSON响应
|
||||||
|
|||||||
@@ -185,17 +185,17 @@ func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequ
|
|||||||
count = 5
|
count = 5
|
||||||
}
|
}
|
||||||
|
|
||||||
systemPrompt := `你是一个专业的角色设计师,擅长创作立体丰富的剧中角色。
|
systemPrompt := `你是一个专业的角色分析师,擅长从剧本中提取和分析角色信息。
|
||||||
|
|
||||||
你的任务是根据提供的剧本大纲,创作符合故事需求的角色设定。
|
你的任务是根据提供的剧本内容,提取并整理剧中出现的所有角色的详细设定。
|
||||||
|
|
||||||
要求:
|
要求:
|
||||||
1. 角色必须服务于大纲中的故事情节和冲突
|
1. 仔细阅读剧本,识别所有出现的角色
|
||||||
2. 角色性格鲜明,有辨识度,符合故事类型
|
2. 根据剧本中的对话、行为和描述,总结角色的性格特点
|
||||||
3. 每个角色都有清晰的动机和目标,与大纲中的矛盾冲突相关
|
3. 提取角色在剧本中的关键信息:背景、动机、目标、关系等
|
||||||
4. 角色之间有合理的关系和联系
|
4. 角色之间的关系必须基于剧本中的实际描述
|
||||||
5. 外貌描述必须极其详细,便于AI绘画生成角色形象
|
5. 外貌描述必须极其详细,如果剧本中有描述则使用,如果没有则根据角色设定合理推断,便于AI绘画生成角色形象
|
||||||
6. 根据大纲的关键场景,合理设置角色数量(通常3-6个主要角色)
|
6. 优先提取主要角色和重要配角,次要角色可以简略
|
||||||
|
|
||||||
请严格按照以下 JSON 格式输出,不要添加任何其他文字:
|
请严格按照以下 JSON 格式输出,不要添加任何其他文字:
|
||||||
|
|
||||||
@@ -213,21 +213,21 @@ func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequ
|
|||||||
}
|
}
|
||||||
|
|
||||||
注意:
|
注意:
|
||||||
- 角色数量根据故事复杂度确定,不要过多
|
- 必须基于剧本内容提取角色,不要凭空创作
|
||||||
- 每个角色都要与大纲中的故事线有明确关联
|
- 优先提取主要角色和重要配角,数量根据剧本实际情况确定
|
||||||
- description、personality、appearance、voice_style都必须详细描述,字数要充足
|
- description、personality、appearance、voice_style都必须详细描述,字数要充足
|
||||||
- appearance外貌描述是重中之重,必须极其详细具体,要能让AI准确生成角色形象
|
- appearance外貌描述是重中之重,必须极其详细具体,要能让AI准确生成角色形象
|
||||||
- 避免模糊描述,多用具体的视觉特征和细节`
|
- 如果剧本中角色信息不完整,可以根据角色设定合理补充,但要符合剧本整体风格`
|
||||||
|
|
||||||
outlineText := req.Outline
|
outlineText := req.Outline
|
||||||
if outlineText == "" {
|
if outlineText == "" {
|
||||||
outlineText = fmt.Sprintf("剧名:%s\n简介:%s\n类型:%s", drama.Title, drama.Description, drama.Genre)
|
outlineText = fmt.Sprintf("剧名:%s\n简介:%s\n类型:%s", drama.Title, drama.Description, drama.Genre)
|
||||||
}
|
}
|
||||||
|
|
||||||
userPrompt := fmt.Sprintf(`剧本大纲:
|
userPrompt := fmt.Sprintf(`剧本内容:
|
||||||
%s
|
%s
|
||||||
|
|
||||||
请创作 %d 个角色的详细设定。`, outlineText, count)
|
请从剧本中提取并整理最多 %d 个主要角色的详细设定。`, outlineText, count)
|
||||||
|
|
||||||
temperature := req.Temperature
|
temperature := req.Temperature
|
||||||
if temperature == 0 {
|
if temperature == 0 {
|
||||||
|
|||||||
@@ -414,24 +414,33 @@ func (s *VideoGenerationService) getVideoClient(provider string, modelName strin
|
|||||||
// 使用配置中的信息创建客户端
|
// 使用配置中的信息创建客户端
|
||||||
baseURL := config.BaseURL
|
baseURL := config.BaseURL
|
||||||
apiKey := config.APIKey
|
apiKey := config.APIKey
|
||||||
endpoint := config.Endpoint
|
|
||||||
queryEndpoint := config.QueryEndpoint
|
|
||||||
model := modelName
|
model := modelName
|
||||||
if model == "" && len(config.Model) > 0 {
|
if model == "" && len(config.Model) > 0 {
|
||||||
model = config.Model[0]
|
model = config.Model[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据 provider 自动设置默认端点
|
||||||
|
var endpoint string
|
||||||
|
var queryEndpoint string
|
||||||
|
|
||||||
switch provider {
|
switch provider {
|
||||||
case "doubao":
|
case "chatfire":
|
||||||
|
endpoint = "/video/generations"
|
||||||
|
queryEndpoint = "/v1/video/task/{taskId}"
|
||||||
|
return video.NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
|
||||||
|
case "doubao", "volcengine", "volces":
|
||||||
|
endpoint = "/contents/generations/tasks"
|
||||||
|
queryEndpoint = "/generations/tasks/{taskId}"
|
||||||
return video.NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
|
return video.NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
|
||||||
|
case "openai":
|
||||||
|
// OpenAI Sora 使用 /v1/videos 端点
|
||||||
|
return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
|
||||||
case "runway":
|
case "runway":
|
||||||
return video.NewRunwayClient(baseURL, apiKey, model), nil
|
return video.NewRunwayClient(baseURL, apiKey, model), nil
|
||||||
case "pika":
|
case "pika":
|
||||||
return video.NewPikaClient(baseURL, apiKey, model), nil
|
return video.NewPikaClient(baseURL, apiKey, model), nil
|
||||||
case "minimax":
|
case "minimax":
|
||||||
return video.NewMinimaxClient(baseURL, apiKey, model), nil
|
return video.NewMinimaxClient(baseURL, apiKey, model), nil
|
||||||
case "openai":
|
|
||||||
return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported video provider: %s", provider)
|
return nil, fmt.Errorf("unsupported video provider: %s", provider)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -297,6 +297,10 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient,
|
|||||||
model = config.Model[0]
|
model = config.Model[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据 provider 自动设置默认端点
|
||||||
|
var endpoint string
|
||||||
|
var queryEndpoint string
|
||||||
|
|
||||||
switch provider {
|
switch provider {
|
||||||
case "runway":
|
case "runway":
|
||||||
return video.NewRunwayClient(config.BaseURL, config.APIKey, model), nil
|
return video.NewRunwayClient(config.BaseURL, config.APIKey, model), nil
|
||||||
@@ -306,10 +310,18 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient,
|
|||||||
return video.NewOpenAISoraClient(config.BaseURL, config.APIKey, model), nil
|
return video.NewOpenAISoraClient(config.BaseURL, config.APIKey, model), nil
|
||||||
case "minimax":
|
case "minimax":
|
||||||
return video.NewMinimaxClient(config.BaseURL, config.APIKey, model), nil
|
return video.NewMinimaxClient(config.BaseURL, config.APIKey, model), nil
|
||||||
|
case "chatfire":
|
||||||
|
endpoint = "/video/generations"
|
||||||
|
queryEndpoint = "/v1/video/task/{taskId}"
|
||||||
|
return video.NewChatfireClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||||
case "doubao", "volces", "ark":
|
case "doubao", "volces", "ark":
|
||||||
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, config.Endpoint, config.QueryEndpoint), nil
|
endpoint = "/contents/generations/tasks"
|
||||||
|
queryEndpoint = "/generations/tasks/{taskId}"
|
||||||
|
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||||
default:
|
default:
|
||||||
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, config.Endpoint, config.QueryEndpoint), nil
|
endpoint = "/contents/generations/tasks"
|
||||||
|
queryEndpoint = "/generations/tasks/{taskId}"
|
||||||
|
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- "5678:5678"
|
- "5678:5678"
|
||||||
volumes:
|
volumes:
|
||||||
# 持久化数据目录(使用命名卷)
|
# 持久化数据目录(使用命名卷,容器内以 root 运行)
|
||||||
- huobao-data:/app/data
|
- huobao-data:/app/data
|
||||||
# 挂载配置文件(可选,如需自定义配置请取消注释)
|
# 挂载配置文件(可选,如需自定义配置请取消注释)
|
||||||
# - ./configs/config.yaml:/app/configs/config.yaml:ro
|
# - ./configs/config.yaml:/app/configs/config.yaml:ro
|
||||||
|
# 注意:如果使用本地目录挂载,需要确保目录权限正确
|
||||||
|
# 例如:- ./data:/app/data (需要 chmod 777 ./data)
|
||||||
environment:
|
environment:
|
||||||
- TZ=Asia/Shanghai
|
- TZ=Asia/Shanghai
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
type AIServiceConfig struct {
|
type AIServiceConfig struct {
|
||||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
ServiceType string `gorm:"type:varchar(50);not null" json:"service_type"` // text, image, video
|
ServiceType string `gorm:"type:varchar(50);not null" json:"service_type"` // text, image, video
|
||||||
|
Provider string `gorm:"type:varchar(50)" json:"provider"` // openai, gemini, volcengine, etc.
|
||||||
Name string `gorm:"type:varchar(100);not null" json:"name"`
|
Name string `gorm:"type:varchar(100);not null" json:"name"`
|
||||||
BaseURL string `gorm:"type:varchar(255);not null" json:"base_url"`
|
BaseURL string `gorm:"type:varchar(255);not null" json:"base_url"`
|
||||||
APIKey string `gorm:"type:varchar(255);not null" json:"api_key"`
|
APIKey string `gorm:"type:varchar(255);not null" json:"api_key"`
|
||||||
|
|||||||
103
infrastructure/database/custom_logger.go
Normal file
103
infrastructure/database/custom_logger.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CustomLogger 自定义 GORM logger,截断过长的 SQL 参数(如 base64 数据)
|
||||||
|
type CustomLogger struct {
|
||||||
|
logger.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCustomLogger 创建自定义 logger
|
||||||
|
func NewCustomLogger() logger.Interface {
|
||||||
|
return &CustomLogger{
|
||||||
|
Interface: logger.Default.LogMode(logger.Silent),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trace 重写 Trace 方法,禁用 SQL 日志输出
|
||||||
|
func (l *CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||||
|
// 不输出任何 SQL 日志
|
||||||
|
// 如果需要调试,可以临时取消注释下面的代码
|
||||||
|
/*
|
||||||
|
sql, rows := fc()
|
||||||
|
sql = truncateLongValues(sql)
|
||||||
|
elapsed := time.Since(begin)
|
||||||
|
if err != nil {
|
||||||
|
l.Interface.Error(ctx, "SQL error: %v [%v] %s", err, elapsed, sql)
|
||||||
|
} else {
|
||||||
|
l.Interface.Info(ctx, "[%.3fms] [rows:%d] %s", float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncateLongValues 截断 SQL 中的长字符串值
|
||||||
|
func truncateLongValues(sql string) string {
|
||||||
|
// 查找 base64 格式的数据 (data:image/...;base64,...)
|
||||||
|
if strings.Contains(sql, "data:image/") && strings.Contains(sql, ";base64,") {
|
||||||
|
parts := strings.Split(sql, "\"")
|
||||||
|
for i, part := range parts {
|
||||||
|
if strings.HasPrefix(part, "data:image/") && strings.Contains(part, ";base64,") {
|
||||||
|
if len(part) > 100 {
|
||||||
|
// 保留前50字符,添加截断标记
|
||||||
|
parts[i] = part[:50] + "...[base64 data truncated]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sql = strings.Join(parts, "\"")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 截断其他过长的值
|
||||||
|
if len(sql) > 5000 {
|
||||||
|
// 查找 VALUES 或 SET 后的内容
|
||||||
|
if idx := strings.Index(sql, " VALUES "); idx > 0 && len(sql) > idx+5000 {
|
||||||
|
sql = sql[:idx+5000] + "...[truncated]"
|
||||||
|
} else if idx := strings.Index(sql, " SET "); idx > 0 && len(sql) > idx+3000 {
|
||||||
|
sql = sql[:idx+3000] + "...[truncated]"
|
||||||
|
} else if len(sql) > 5000 {
|
||||||
|
sql = sql[:5000] + "...[truncated]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info 实现 Info 方法
|
||||||
|
func (l *CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||||
|
l.Interface.Info(ctx, msg, data...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn 实现 Warn 方法
|
||||||
|
func (l *CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||||
|
l.Interface.Warn(ctx, msg, data...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error 实现 Error 方法
|
||||||
|
func (l *CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||||
|
// 检查并截断 data 中的长字符串
|
||||||
|
truncatedData := make([]interface{}, len(data))
|
||||||
|
for i, d := range data {
|
||||||
|
if str, ok := d.(string); ok && len(str) > 200 {
|
||||||
|
if strings.HasPrefix(str, "data:image/") {
|
||||||
|
truncatedData[i] = str[:50] + "...[base64 data]"
|
||||||
|
} else {
|
||||||
|
truncatedData[i] = str[:200] + "..."
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
truncatedData[i] = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.Interface.Error(ctx, msg, truncatedData...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogMode 实现 LogMode 方法
|
||||||
|
func (l *CustomLogger) LogMode(level logger.LogLevel) logger.Interface {
|
||||||
|
newLogger := *l
|
||||||
|
newLogger.Interface = l.Interface.LogMode(level)
|
||||||
|
return &newLogger
|
||||||
|
}
|
||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"gorm.io/driver/mysql"
|
"gorm.io/driver/mysql"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
|
func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
|
||||||
@@ -25,7 +24,7 @@ func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
gormConfig := &gorm.Config{
|
gormConfig := &gorm.Config{
|
||||||
Logger: logger.Default.LogMode(logger.Info),
|
Logger: NewCustomLogger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
var db *gorm.DB
|
var db *gorm.DB
|
||||||
|
|||||||
@@ -445,12 +445,14 @@ CREATE INDEX IF NOT EXISTS idx_asset_collection_relations_collection_id ON asset
|
|||||||
CREATE TABLE IF NOT EXISTS ai_service_configs (
|
CREATE TABLE IF NOT EXISTS ai_service_configs (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
service_type TEXT NOT NULL, -- text, image, video
|
service_type TEXT NOT NULL, -- text, image, video
|
||||||
|
provider TEXT, -- openai, gemini, volcengine, etc.
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
base_url TEXT NOT NULL,
|
base_url TEXT NOT NULL,
|
||||||
api_key TEXT NOT NULL,
|
api_key TEXT NOT NULL,
|
||||||
model TEXT,
|
model TEXT,
|
||||||
endpoint TEXT,
|
endpoint TEXT,
|
||||||
query_endpoint TEXT,
|
query_endpoint TEXT,
|
||||||
|
priority INTEGER NOT NULL DEFAULT 0,
|
||||||
is_default INTEGER NOT NULL DEFAULT 0,
|
is_default INTEGER NOT NULL DEFAULT 0,
|
||||||
is_active INTEGER NOT NULL DEFAULT 1,
|
is_active INTEGER NOT NULL DEFAULT 1,
|
||||||
settings TEXT, -- JSON存储
|
settings TEXT, -- JSON存储
|
||||||
@@ -489,7 +491,8 @@ INSERT OR IGNORE INTO ai_service_providers (name, display_name, service_type, de
|
|||||||
('openai-dalle', 'OpenAI DALL-E', 'image', 'https://api.openai.com/v1', 'OpenAI DALL-E图片生成'),
|
('openai-dalle', 'OpenAI DALL-E', 'image', 'https://api.openai.com/v1', 'OpenAI DALL-E图片生成'),
|
||||||
('openai-sora', 'OpenAI Sora', 'video', 'https://api.openai.com/v1', 'OpenAI Sora视频生成'),
|
('openai-sora', 'OpenAI Sora', 'video', 'https://api.openai.com/v1', 'OpenAI Sora视频生成'),
|
||||||
('midjourney', 'Midjourney', 'image', '', 'Midjourney图片生成'),
|
('midjourney', 'Midjourney', 'image', '', 'Midjourney图片生成'),
|
||||||
('stable-diffusion', 'Stable Diffusion', 'image', '', 'Stable Diffusion图片生成'),
|
('doubao-image', '豆包(火山引擎)', 'image', 'https://ark.cn-beijing.volces.com', '火山引擎豆包图片生成'),
|
||||||
|
('gemini-image', 'Google Gemini', 'image', 'https://generativelanguage.googleapis.com', 'Google Gemini原生图片生成(base64)'),
|
||||||
('runway', 'Runway', 'video', '', 'Runway视频生成'),
|
('runway', 'Runway', 'video', '', 'Runway视频生成'),
|
||||||
('pika', 'Pika Labs', 'video', '', 'Pika视频生成'),
|
('pika', 'Pika Labs', 'video', '', 'Pika视频生成'),
|
||||||
('doubao', '豆包(火山引擎)', 'video', 'https://ark.cn-beijing.volces.com', '火山引擎豆包视频生成'),
|
('doubao', '豆包(火山引擎)', 'video', 'https://ark.cn-beijing.volces.com', '火山引擎豆包视频生成'),
|
||||||
|
|||||||
7
pkg/ai/client.go
Normal file
7
pkg/ai/client.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
// AIClient 定义文本生成客户端接口
|
||||||
|
type AIClient interface {
|
||||||
|
GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error)
|
||||||
|
TestConnection() error
|
||||||
|
}
|
||||||
195
pkg/ai/gemini_client.go
Normal file
195
pkg/ai/gemini_client.go
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GeminiClient struct {
|
||||||
|
BaseURL string
|
||||||
|
APIKey string
|
||||||
|
Model string
|
||||||
|
Endpoint string
|
||||||
|
HTTPClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiTextRequest struct {
|
||||||
|
Contents []GeminiContent `json:"contents"`
|
||||||
|
SystemInstruction *GeminiInstruction `json:"systemInstruction,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiContent struct {
|
||||||
|
Parts []GeminiPart `json:"parts"`
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiPart struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiInstruction struct {
|
||||||
|
Parts []GeminiPart `json:"parts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiTextResponse struct {
|
||||||
|
Candidates []struct {
|
||||||
|
Content struct {
|
||||||
|
Parts []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"parts"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
} `json:"content"`
|
||||||
|
FinishReason string `json:"finishReason"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
SafetyRatings []struct {
|
||||||
|
Category string `json:"category"`
|
||||||
|
Probability string `json:"probability"`
|
||||||
|
} `json:"safetyRatings"`
|
||||||
|
} `json:"candidates"`
|
||||||
|
UsageMetadata struct {
|
||||||
|
PromptTokenCount int `json:"promptTokenCount"`
|
||||||
|
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||||
|
TotalTokenCount int `json:"totalTokenCount"`
|
||||||
|
} `json:"usageMetadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGeminiClient(baseURL, apiKey, model, endpoint string) *GeminiClient {
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
|
}
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = "/v1beta/models/{model}:generateContent"
|
||||||
|
}
|
||||||
|
if model == "" {
|
||||||
|
model = "gemini-3-pro"
|
||||||
|
}
|
||||||
|
return &GeminiClient{
|
||||||
|
BaseURL: baseURL,
|
||||||
|
APIKey: apiKey,
|
||||||
|
Model: model,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
HTTPClient: &http.Client{
|
||||||
|
Timeout: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GeminiClient) GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error) {
|
||||||
|
model := c.Model
|
||||||
|
|
||||||
|
// 构建请求体
|
||||||
|
reqBody := GeminiTextRequest{
|
||||||
|
Contents: []GeminiContent{
|
||||||
|
{
|
||||||
|
Parts: []GeminiPart{{Text: prompt}},
|
||||||
|
Role: "user",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 systemInstruction 字段处理系统提示
|
||||||
|
if systemPrompt != "" {
|
||||||
|
reqBody.SystemInstruction = &GeminiInstruction{
|
||||||
|
Parts: []GeminiPart{{Text: systemPrompt}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Gemini: Failed to marshal request: %v\n", err)
|
||||||
|
return "", fmt.Errorf("marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 替换端点中的 {model} 占位符
|
||||||
|
endpoint := c.BaseURL + c.Endpoint
|
||||||
|
endpoint = strings.ReplaceAll(endpoint, "{model}", model)
|
||||||
|
url := fmt.Sprintf("%s?key=%s", endpoint, c.APIKey)
|
||||||
|
|
||||||
|
// 打印请求信息(隐藏 API Key)
|
||||||
|
safeURL := strings.Replace(url, c.APIKey, "***", 1)
|
||||||
|
fmt.Printf("Gemini: Sending request to: %s\n", safeURL)
|
||||||
|
requestPreview := string(jsonData)
|
||||||
|
if len(jsonData) > 300 {
|
||||||
|
requestPreview = string(jsonData[:300]) + "..."
|
||||||
|
}
|
||||||
|
fmt.Printf("Gemini: Request body: %s\n", requestPreview)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Gemini: Failed to create request: %v\n", err)
|
||||||
|
return "", fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
fmt.Printf("Gemini: Executing HTTP request...\n")
|
||||||
|
resp, err := c.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Gemini: HTTP request failed: %v\n", err)
|
||||||
|
return "", fmt.Errorf("send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
fmt.Printf("Gemini: Received response with status: %d\n", resp.StatusCode)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Gemini: Failed to read response body: %v\n", err)
|
||||||
|
return "", fmt.Errorf("read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
fmt.Printf("Gemini: API error (status %d): %s\n", resp.StatusCode, string(body))
|
||||||
|
return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 打印响应体用于调试
|
||||||
|
bodyPreview := string(body)
|
||||||
|
if len(body) > 500 {
|
||||||
|
bodyPreview = string(body[:500]) + "..."
|
||||||
|
}
|
||||||
|
fmt.Printf("Gemini: Response body: %s\n", bodyPreview)
|
||||||
|
|
||||||
|
var result GeminiTextResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
errorPreview := string(body)
|
||||||
|
if len(body) > 200 {
|
||||||
|
errorPreview = string(body[:200])
|
||||||
|
}
|
||||||
|
fmt.Printf("Gemini: Failed to parse response: %v\n", err)
|
||||||
|
return "", fmt.Errorf("parse response: %w, body preview: %s", err, errorPreview)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Gemini: Successfully parsed response, candidates count: %d\n", len(result.Candidates))
|
||||||
|
|
||||||
|
if len(result.Candidates) == 0 {
|
||||||
|
fmt.Printf("Gemini: No candidates in response\n")
|
||||||
|
return "", fmt.Errorf("no candidates in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Candidates[0].Content.Parts) == 0 {
|
||||||
|
fmt.Printf("Gemini: No parts in first candidate\n")
|
||||||
|
return "", fmt.Errorf("no parts in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
responseText := result.Candidates[0].Content.Parts[0].Text
|
||||||
|
fmt.Printf("Gemini: Generated text: %s\n", responseText)
|
||||||
|
|
||||||
|
return responseText, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GeminiClient) TestConnection() error {
|
||||||
|
fmt.Printf("Gemini: TestConnection called with BaseURL=%s, Model=%s, Endpoint=%s\n", c.BaseURL, c.Model, c.Endpoint)
|
||||||
|
_, err := c.GenerateText("Hello", "")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Gemini: TestConnection failed: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Gemini: TestConnection succeeded\n")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -91,30 +91,48 @@ func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*C
|
|||||||
func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
||||||
jsonData, err := json.Marshal(req)
|
jsonData, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Printf("OpenAI: Failed to marshal request: %v\n", err)
|
||||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
url := c.BaseURL + c.Endpoint
|
url := c.BaseURL + c.Endpoint
|
||||||
|
|
||||||
|
// 打印请求信息
|
||||||
|
fmt.Printf("OpenAI: Sending request to: %s\n", url)
|
||||||
|
fmt.Printf("OpenAI: BaseURL=%s, Endpoint=%s, Model=%s\n", c.BaseURL, c.Endpoint, c.Model)
|
||||||
|
requestPreview := string(jsonData)
|
||||||
|
if len(jsonData) > 300 {
|
||||||
|
requestPreview = string(jsonData[:300]) + "..."
|
||||||
|
}
|
||||||
|
fmt.Printf("OpenAI: Request body: %s\n", requestPreview)
|
||||||
|
|
||||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Printf("OpenAI: Failed to create request: %v\n", err)
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+c.APIKey)
|
httpReq.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||||
|
|
||||||
|
fmt.Printf("OpenAI: Executing HTTP request...\n")
|
||||||
resp, err := c.HTTPClient.Do(httpReq)
|
resp, err := c.HTTPClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Printf("OpenAI: HTTP request failed: %v\n", err)
|
||||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
fmt.Printf("OpenAI: Received response with status: %d\n", resp.StatusCode)
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Printf("OpenAI: Failed to read response body: %v\n", err)
|
||||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
fmt.Printf("OpenAI: API error (status %d): %s\n", resp.StatusCode, string(body))
|
||||||
var errResp ErrorResponse
|
var errResp ErrorResponse
|
||||||
if err := json.Unmarshal(body, &errResp); err != nil {
|
if err := json.Unmarshal(body, &errResp); err != nil {
|
||||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||||
@@ -122,11 +140,25 @@ func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatComplet
|
|||||||
return nil, fmt.Errorf("API error: %s", errResp.Error.Message)
|
return nil, fmt.Errorf("API error: %s", errResp.Error.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 打印响应体用于调试
|
||||||
|
bodyPreview := string(body)
|
||||||
|
if len(body) > 500 {
|
||||||
|
bodyPreview = string(body[:500]) + "..."
|
||||||
|
}
|
||||||
|
fmt.Printf("OpenAI: Response body: %s\n", bodyPreview)
|
||||||
|
|
||||||
var chatResp ChatCompletionResponse
|
var chatResp ChatCompletionResponse
|
||||||
if err := json.Unmarshal(body, &chatResp); err != nil {
|
if err := json.Unmarshal(body, &chatResp); err != nil {
|
||||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
errorPreview := string(body)
|
||||||
|
if len(body) > 200 {
|
||||||
|
errorPreview = string(body[:200])
|
||||||
|
}
|
||||||
|
fmt.Printf("OpenAI: Failed to parse response: %v\n", err)
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal response: %w, body preview: %s", err, errorPreview)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Printf("OpenAI: Successfully parsed response, choices count: %d\n", len(chatResp.Choices))
|
||||||
|
|
||||||
return &chatResp, nil
|
return &chatResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,6 +208,8 @@ func (c *OpenAIClient) GenerateText(prompt string, systemPrompt string, options
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *OpenAIClient) TestConnection() error {
|
func (c *OpenAIClient) TestConnection() error {
|
||||||
|
fmt.Printf("OpenAI: TestConnection called with BaseURL=%s, Endpoint=%s, Model=%s\n", c.BaseURL, c.Endpoint, c.Model)
|
||||||
|
|
||||||
messages := []ChatMessage{
|
messages := []ChatMessage{
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
@@ -184,5 +218,10 @@ func (c *OpenAIClient) TestConnection() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := c.ChatCompletion(messages, WithMaxTokens(10))
|
_, err := c.ChatCompletion(messages, WithMaxTokens(10))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("OpenAI: TestConnection failed: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("OpenAI: TestConnection succeeded\n")
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
277
pkg/image/gemini_image_client.go
Normal file
277
pkg/image/gemini_image_client.go
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
package image
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GeminiImageClient struct {
|
||||||
|
BaseURL string
|
||||||
|
APIKey string
|
||||||
|
Model string
|
||||||
|
Endpoint string
|
||||||
|
HTTPClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiImageRequest struct {
|
||||||
|
Contents []struct {
|
||||||
|
Parts []GeminiPart `json:"parts"`
|
||||||
|
} `json:"contents"`
|
||||||
|
GenerationConfig struct {
|
||||||
|
ResponseModalities []string `json:"responseModalities"`
|
||||||
|
} `json:"generationConfig"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiPart struct {
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiInlineData struct {
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Data string `json:"data"` // base64 编码的图片数据
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiImageResponse struct {
|
||||||
|
Candidates []struct {
|
||||||
|
Content struct {
|
||||||
|
Parts []struct {
|
||||||
|
InlineData struct {
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Data string `json:"data"`
|
||||||
|
} `json:"inlineData,omitempty"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
} `json:"parts"`
|
||||||
|
} `json:"content"`
|
||||||
|
} `json:"candidates"`
|
||||||
|
UsageMetadata struct {
|
||||||
|
PromptTokenCount int `json:"promptTokenCount"`
|
||||||
|
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||||
|
TotalTokenCount int `json:"totalTokenCount"`
|
||||||
|
} `json:"usageMetadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// downloadImageToBase64 下载图片 URL 并转换为 base64
|
||||||
|
func downloadImageToBase64(imageURL string) (string, string, error) {
|
||||||
|
resp, err := http.Get(imageURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("download image: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", "", fmt.Errorf("download image failed with status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
imageData, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("read image data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据 Content-Type 确定 mimeType
|
||||||
|
mimeType := resp.Header.Get("Content-Type")
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "image/jpeg"
|
||||||
|
}
|
||||||
|
|
||||||
|
base64Data := base64.StdEncoding.EncodeToString(imageData)
|
||||||
|
return base64Data, mimeType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGeminiImageClient(baseURL, apiKey, model, endpoint string) *GeminiImageClient {
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
|
}
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = "/v1beta/models/{model}:generateContent"
|
||||||
|
}
|
||||||
|
if model == "" {
|
||||||
|
model = "gemini-3-pro-image-preview"
|
||||||
|
}
|
||||||
|
return &GeminiImageClient{
|
||||||
|
BaseURL: baseURL,
|
||||||
|
APIKey: apiKey,
|
||||||
|
Model: model,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
HTTPClient: &http.Client{
|
||||||
|
Timeout: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GeminiImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
|
||||||
|
options := &ImageOptions{
|
||||||
|
Size: "1024x1024",
|
||||||
|
Quality: "standard",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
model := c.Model
|
||||||
|
if options.Model != "" {
|
||||||
|
model = options.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
promptText := prompt
|
||||||
|
if options.NegativePrompt != "" {
|
||||||
|
promptText += fmt.Sprintf("\n\nNegative prompt: %s", options.NegativePrompt)
|
||||||
|
}
|
||||||
|
if options.Size != "" {
|
||||||
|
promptText += fmt.Sprintf("\n\nImage size: %s", options.Size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建请求的 parts,支持参考图
|
||||||
|
parts := []GeminiPart{}
|
||||||
|
|
||||||
|
// 如果有参考图,先添加参考图
|
||||||
|
if len(options.ReferenceImages) > 0 {
|
||||||
|
for _, refImg := range options.ReferenceImages {
|
||||||
|
var base64Data string
|
||||||
|
var mimeType string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 检查是否是 HTTP/HTTPS URL
|
||||||
|
if strings.HasPrefix(refImg, "http://") || strings.HasPrefix(refImg, "https://") {
|
||||||
|
// 下载图片并转换为 base64
|
||||||
|
base64Data, mimeType, err = downloadImageToBase64(refImg)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(refImg, "data:") {
|
||||||
|
// 如果是 data URI 格式,需要解析
|
||||||
|
// 格式: data:image/jpeg;base64,xxxxx
|
||||||
|
mimeType = "image/jpeg"
|
||||||
|
parts := []byte(refImg)
|
||||||
|
for i := 0; i < len(parts); i++ {
|
||||||
|
if parts[i] == ',' {
|
||||||
|
base64Data = refImg[i+1:]
|
||||||
|
// 提取 mime type
|
||||||
|
if i > 11 {
|
||||||
|
mimeTypeEnd := i
|
||||||
|
for j := 5; j < i; j++ {
|
||||||
|
if parts[j] == ';' {
|
||||||
|
mimeTypeEnd = j
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mimeType = refImg[5:mimeTypeEnd]
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 假设已经是 base64 编码
|
||||||
|
base64Data = refImg
|
||||||
|
mimeType = "image/jpeg"
|
||||||
|
}
|
||||||
|
|
||||||
|
if base64Data != "" {
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
InlineData: &GeminiInlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: base64Data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加文本提示词
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
Text: promptText,
|
||||||
|
})
|
||||||
|
|
||||||
|
reqBody := GeminiImageRequest{
|
||||||
|
Contents: []struct {
|
||||||
|
Parts []GeminiPart `json:"parts"`
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Parts: parts,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
GenerationConfig: struct {
|
||||||
|
ResponseModalities []string `json:"responseModalities"`
|
||||||
|
}{
|
||||||
|
ResponseModalities: []string{"IMAGE"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := c.BaseURL + c.Endpoint
|
||||||
|
endpoint = replaceModelPlaceholder(endpoint, model)
|
||||||
|
url := fmt.Sprintf("%s?key=%s", endpoint, c.APIKey)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyStr := string(body)
|
||||||
|
if len(bodyStr) > 1000 {
|
||||||
|
bodyStr = fmt.Sprintf("%s ... %s", bodyStr[:500], bodyStr[len(bodyStr)-500:])
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, bodyStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result GeminiImageResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
|
||||||
|
return nil, fmt.Errorf("no image generated in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
base64Data := result.Candidates[0].Content.Parts[0].InlineData.Data
|
||||||
|
if base64Data == "" {
|
||||||
|
return nil, fmt.Errorf("no base64 image data in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
dataURI := fmt.Sprintf("data:image/jpeg;base64,%s", base64Data)
|
||||||
|
|
||||||
|
return &ImageResult{
|
||||||
|
Status: "completed",
|
||||||
|
ImageURL: dataURI,
|
||||||
|
Completed: true,
|
||||||
|
Width: 1024,
|
||||||
|
Height: 1024,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GeminiImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
|
||||||
|
return nil, fmt.Errorf("not supported for Gemini (synchronous generation)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func replaceModelPlaceholder(endpoint, model string) string {
|
||||||
|
result := endpoint
|
||||||
|
if bytes.Contains([]byte(result), []byte("{model}")) {
|
||||||
|
result = string(bytes.ReplaceAll([]byte(result), []byte("{model}"), []byte(model)))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
@@ -1,14 +1,5 @@
|
|||||||
package image
|
package image
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ImageClient interface {
|
type ImageClient interface {
|
||||||
GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error)
|
GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error)
|
||||||
GetTaskStatus(taskID string) (*ImageResult, error)
|
GetTaskStatus(taskID string) (*ImageResult, error)
|
||||||
@@ -100,285 +91,3 @@ func WithReferenceImages(images []string) ImageOption {
|
|||||||
o.ReferenceImages = images
|
o.ReferenceImages = images
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIImageClient struct {
|
|
||||||
BaseURL string
|
|
||||||
APIKey string
|
|
||||||
Model string
|
|
||||||
HTTPClient *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
type DALLERequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Quality string `json:"quality,omitempty"`
|
|
||||||
N int `json:"n"`
|
|
||||||
Image []string `json:"image,omitempty"` // 参考图片URL列表
|
|
||||||
}
|
|
||||||
|
|
||||||
type DALLEResponse struct {
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Data []struct {
|
|
||||||
URL string `json:"url"`
|
|
||||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewOpenAIImageClient(baseURL, apiKey, model string) *OpenAIImageClient {
|
|
||||||
return &OpenAIImageClient{
|
|
||||||
BaseURL: baseURL,
|
|
||||||
APIKey: apiKey,
|
|
||||||
Model: model,
|
|
||||||
HTTPClient: &http.Client{
|
|
||||||
Timeout: 10 * time.Minute,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *OpenAIImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
|
|
||||||
options := &ImageOptions{
|
|
||||||
Size: "1920x1920",
|
|
||||||
Quality: "standard",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(options)
|
|
||||||
}
|
|
||||||
|
|
||||||
model := c.Model
|
|
||||||
if options.Model != "" {
|
|
||||||
model = options.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
reqBody := DALLERequest{
|
|
||||||
Model: model,
|
|
||||||
Prompt: prompt,
|
|
||||||
Size: options.Size,
|
|
||||||
Quality: options.Quality,
|
|
||||||
N: 1,
|
|
||||||
Image: options.ReferenceImages,
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(reqBody)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshal request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
endpoint := c.BaseURL + "/v1/images/generations"
|
|
||||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
|
||||||
|
|
||||||
resp, err := c.HTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("send request: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 打印原始响应以便调试
|
|
||||||
fmt.Printf("OpenAI API Response: %s\n", string(body))
|
|
||||||
|
|
||||||
var result DALLEResponse
|
|
||||||
if err := json.Unmarshal(body, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(result.Data) == 0 {
|
|
||||||
return nil, fmt.Errorf("no image generated, response: %s", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ImageResult{
|
|
||||||
Status: "completed",
|
|
||||||
ImageURL: result.Data[0].URL,
|
|
||||||
Completed: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *OpenAIImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
|
|
||||||
return nil, fmt.Errorf("not supported for OpenAI/DALL-E")
|
|
||||||
}
|
|
||||||
|
|
||||||
type StableDiffusionClient struct {
|
|
||||||
BaseURL string
|
|
||||||
APIKey string
|
|
||||||
Model string
|
|
||||||
HTTPClient *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
type SDRequest struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
|
||||||
Model string `json:"model,omitempty"`
|
|
||||||
Width int `json:"width,omitempty"`
|
|
||||||
Height int `json:"height,omitempty"`
|
|
||||||
Steps int `json:"steps,omitempty"`
|
|
||||||
CfgScale float64 `json:"cfg_scale,omitempty"`
|
|
||||||
Seed int64 `json:"seed,omitempty"`
|
|
||||||
Samples int `json:"samples"`
|
|
||||||
Image []string `json:"image,omitempty"` // 参考图片URL列表
|
|
||||||
}
|
|
||||||
|
|
||||||
type SDResponse struct {
|
|
||||||
Status string `json:"status"`
|
|
||||||
TaskID string `json:"task_id,omitempty"`
|
|
||||||
Output []struct {
|
|
||||||
URL string `json:"url"`
|
|
||||||
} `json:"output,omitempty"`
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewStableDiffusionClient(baseURL, apiKey, model string) *StableDiffusionClient {
|
|
||||||
return &StableDiffusionClient{
|
|
||||||
BaseURL: baseURL,
|
|
||||||
APIKey: apiKey,
|
|
||||||
Model: model,
|
|
||||||
HTTPClient: &http.Client{
|
|
||||||
Timeout: 10 * time.Minute,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *StableDiffusionClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
|
|
||||||
options := &ImageOptions{
|
|
||||||
Width: 1024,
|
|
||||||
Height: 1024,
|
|
||||||
Steps: 30,
|
|
||||||
CfgScale: 7.5,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(options)
|
|
||||||
}
|
|
||||||
|
|
||||||
model := c.Model
|
|
||||||
if options.Model != "" {
|
|
||||||
model = options.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
reqBody := SDRequest{
|
|
||||||
Prompt: prompt,
|
|
||||||
NegativePrompt: options.NegativePrompt,
|
|
||||||
Model: model,
|
|
||||||
Width: options.Width,
|
|
||||||
Height: options.Height,
|
|
||||||
Steps: options.Steps,
|
|
||||||
CfgScale: options.CfgScale,
|
|
||||||
Seed: options.Seed,
|
|
||||||
Samples: 1,
|
|
||||||
Image: options.ReferenceImages,
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(reqBody)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshal request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
endpoint := c.BaseURL + "/v1/images/generations"
|
|
||||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
|
||||||
|
|
||||||
resp, err := c.HTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("send request: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
var result SDResponse
|
|
||||||
if err := json.Unmarshal(body, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Error != "" {
|
|
||||||
return nil, fmt.Errorf("SD error: %s", result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Status == "processing" {
|
|
||||||
return &ImageResult{
|
|
||||||
TaskID: result.TaskID,
|
|
||||||
Status: "processing",
|
|
||||||
Completed: false,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(result.Output) == 0 {
|
|
||||||
return nil, fmt.Errorf("no image generated")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ImageResult{
|
|
||||||
Status: "completed",
|
|
||||||
ImageURL: result.Output[0].URL,
|
|
||||||
Width: options.Width,
|
|
||||||
Height: options.Height,
|
|
||||||
Completed: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *StableDiffusionClient) GetTaskStatus(taskID string) (*ImageResult, error) {
|
|
||||||
endpoint := c.BaseURL + "/v1/images/status/" + taskID
|
|
||||||
req, err := http.NewRequest("GET", endpoint, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
|
||||||
|
|
||||||
resp, err := c.HTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("send request: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var result SDResponse
|
|
||||||
if err := json.Unmarshal(body, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
imageResult := &ImageResult{
|
|
||||||
TaskID: taskID,
|
|
||||||
Status: result.Status,
|
|
||||||
Completed: result.Status == "completed",
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Error != "" {
|
|
||||||
imageResult.Error = result.Error
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(result.Output) > 0 {
|
|
||||||
imageResult.ImageURL = result.Output[0].URL
|
|
||||||
}
|
|
||||||
|
|
||||||
return imageResult, nil
|
|
||||||
}
|
|
||||||
|
|||||||
128
pkg/image/openai_image_client.go
Normal file
128
pkg/image/openai_image_client.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package image
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenAIImageClient struct {
|
||||||
|
BaseURL string
|
||||||
|
APIKey string
|
||||||
|
Model string
|
||||||
|
Endpoint string
|
||||||
|
HTTPClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type DALLERequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
N int `json:"n"`
|
||||||
|
Image []string `json:"image,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DALLEResponse struct {
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Data []struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOpenAIImageClient(baseURL, apiKey, model, endpoint string) *OpenAIImageClient {
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = "/v1/images/generations"
|
||||||
|
}
|
||||||
|
return &OpenAIImageClient{
|
||||||
|
BaseURL: baseURL,
|
||||||
|
APIKey: apiKey,
|
||||||
|
Model: model,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
HTTPClient: &http.Client{
|
||||||
|
Timeout: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
|
||||||
|
options := &ImageOptions{
|
||||||
|
Size: "1920x1920",
|
||||||
|
Quality: "standard",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
model := c.Model
|
||||||
|
if options.Model != "" {
|
||||||
|
model = options.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := DALLERequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: prompt,
|
||||||
|
Size: options.Size,
|
||||||
|
Quality: options.Quality,
|
||||||
|
N: 1,
|
||||||
|
Image: options.ReferenceImages,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
url := c.BaseURL + c.Endpoint
|
||||||
|
fmt.Printf("[OpenAI Image] Request URL: %s\n", url)
|
||||||
|
fmt.Printf("[OpenAI Image] Request Body: %s\n", string(jsonData))
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||||
|
|
||||||
|
resp, err := c.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("OpenAI API Response: %s\n", string(body))
|
||||||
|
|
||||||
|
var result DALLEResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Data) == 0 {
|
||||||
|
return nil, fmt.Errorf("no image generated, response: %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ImageResult{
|
||||||
|
Status: "completed",
|
||||||
|
ImageURL: result.Data[0].URL,
|
||||||
|
Completed: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
|
||||||
|
return nil, fmt.Errorf("not supported for OpenAI/DALL-E")
|
||||||
|
}
|
||||||
158
pkg/image/volcengine_image_client.go
Normal file
158
pkg/image/volcengine_image_client.go
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
package image
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type VolcEngineImageClient struct {
|
||||||
|
BaseURL string
|
||||||
|
APIKey string
|
||||||
|
Model string
|
||||||
|
Endpoint string
|
||||||
|
QueryEndpoint string
|
||||||
|
HTTPClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type VolcEngineImageRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Image []string `json:"image,omitempty"`
|
||||||
|
SequentialImageGeneration string `json:"sequential_image_generation,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
Watermark bool `json:"watermark,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type VolcEngineImageResponse struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Data []struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Size string `json:"size"`
|
||||||
|
} `json:"data"`
|
||||||
|
Usage struct {
|
||||||
|
GeneratedImages int `json:"generated_images"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
Error interface{} `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewVolcEngineImageClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *VolcEngineImageClient {
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = "/api/v3/images/generations"
|
||||||
|
}
|
||||||
|
if queryEndpoint == "" {
|
||||||
|
queryEndpoint = endpoint
|
||||||
|
}
|
||||||
|
return &VolcEngineImageClient{
|
||||||
|
BaseURL: baseURL,
|
||||||
|
APIKey: apiKey,
|
||||||
|
Model: model,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
QueryEndpoint: queryEndpoint,
|
||||||
|
HTTPClient: &http.Client{
|
||||||
|
Timeout: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *VolcEngineImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
|
||||||
|
options := &ImageOptions{
|
||||||
|
Size: "1024x1024",
|
||||||
|
Quality: "standard",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
model := c.Model
|
||||||
|
if options.Model != "" {
|
||||||
|
model = options.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
promptText := prompt
|
||||||
|
if options.NegativePrompt != "" {
|
||||||
|
promptText += fmt.Sprintf(". Negative: %s", options.NegativePrompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := options.Size
|
||||||
|
if size == "" {
|
||||||
|
if model == "doubao-seedream-4-5-251128" {
|
||||||
|
size = "2K"
|
||||||
|
} else {
|
||||||
|
size = "1K"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := VolcEngineImageRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: promptText,
|
||||||
|
Image: options.ReferenceImages,
|
||||||
|
SequentialImageGeneration: "disabled",
|
||||||
|
Size: size,
|
||||||
|
Watermark: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
url := c.BaseURL + c.Endpoint
|
||||||
|
fmt.Printf("[VolcEngine Image] Request URL: %s\n", url)
|
||||||
|
fmt.Printf("[VolcEngine Image] Request Body: %s\n", string(jsonData))
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||||
|
|
||||||
|
resp, err := c.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("VolcEngine Image API Response: %s\n", string(body))
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||||
|
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result VolcEngineImageResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, fmt.Errorf("volcengine error: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Data) == 0 {
|
||||||
|
return nil, fmt.Errorf("no image generated")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ImageResult{
|
||||||
|
Status: "completed",
|
||||||
|
ImageURL: result.Data[0].URL,
|
||||||
|
Completed: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *VolcEngineImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
|
||||||
|
return nil, fmt.Errorf("not supported for VolcEngine Seedream (synchronous generation)")
|
||||||
|
}
|
||||||
184
pkg/video/chatfire_client.go
Normal file
184
pkg/video/chatfire_client.go
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
package video
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatfireClient Chatfire 视频生成客户端
|
||||||
|
type ChatfireClient struct {
|
||||||
|
BaseURL string
|
||||||
|
APIKey string
|
||||||
|
Model string
|
||||||
|
Endpoint string
|
||||||
|
QueryEndpoint string
|
||||||
|
HTTPClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatfireRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
ImageURL string `json:"image_url,omitempty"`
|
||||||
|
Duration int `json:"duration,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatfireResponse struct {
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatfireTaskResponse struct {
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
VideoURL string `json:"video_url,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *ChatfireClient {
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = "/video/generations"
|
||||||
|
}
|
||||||
|
if queryEndpoint == "" {
|
||||||
|
queryEndpoint = "/v1/video/task/{taskId}"
|
||||||
|
}
|
||||||
|
return &ChatfireClient{
|
||||||
|
BaseURL: baseURL,
|
||||||
|
APIKey: apiKey,
|
||||||
|
Model: model,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
QueryEndpoint: queryEndpoint,
|
||||||
|
HTTPClient: &http.Client{
|
||||||
|
Timeout: 300 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatfireClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||||
|
options := &VideoOptions{
|
||||||
|
Duration: 5,
|
||||||
|
AspectRatio: "16:9",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
model := c.Model
|
||||||
|
if options.Model != "" {
|
||||||
|
model = options.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := ChatfireRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: prompt,
|
||||||
|
ImageURL: imageURL,
|
||||||
|
Duration: options.Duration,
|
||||||
|
Size: options.AspectRatio,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := c.BaseURL + c.Endpoint
|
||||||
|
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||||
|
|
||||||
|
resp, err := c.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||||
|
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result ChatfireResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error != "" {
|
||||||
|
return nil, fmt.Errorf("chatfire error: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
videoResult := &VideoResult{
|
||||||
|
TaskID: result.TaskID,
|
||||||
|
Status: result.Status,
|
||||||
|
Completed: result.Status == "completed" || result.Status == "succeeded",
|
||||||
|
Duration: options.Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
return videoResult, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatfireClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||||
|
queryPath := c.QueryEndpoint
|
||||||
|
if strings.Contains(queryPath, "{taskId}") {
|
||||||
|
queryPath = strings.ReplaceAll(queryPath, "{taskId}", taskID)
|
||||||
|
} else if strings.Contains(queryPath, "{task_id}") {
|
||||||
|
queryPath = strings.ReplaceAll(queryPath, "{task_id}", taskID)
|
||||||
|
} else {
|
||||||
|
queryPath = queryPath + "/" + taskID
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := c.BaseURL + queryPath
|
||||||
|
req, err := http.NewRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||||
|
|
||||||
|
resp, err := c.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result ChatfireTaskResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
videoResult := &VideoResult{
|
||||||
|
TaskID: result.TaskID,
|
||||||
|
Status: result.Status,
|
||||||
|
Completed: result.Status == "completed" || result.Status == "succeeded",
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error != "" {
|
||||||
|
videoResult.Error = result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.VideoURL != "" {
|
||||||
|
videoResult.VideoURL = result.VideoURL
|
||||||
|
videoResult.Completed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return videoResult, nil
|
||||||
|
}
|
||||||
@@ -83,7 +83,7 @@ func (c *OpenAISoraClient) GenerateVideo(imageURL, prompt string, opts ...VideoO
|
|||||||
|
|
||||||
writer.Close()
|
writer.Close()
|
||||||
|
|
||||||
endpoint := c.BaseURL + "/v1/videos"
|
endpoint := c.BaseURL + "/videos"
|
||||||
req, err := http.NewRequest("POST", endpoint, body)
|
req, err := http.NewRequest("POST", endpoint, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create request: %w", err)
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
</template>
|
</template>
|
||||||
</el-page-header>
|
</el-page-header>
|
||||||
|
|
||||||
<el-tabs v-model="activeTab" @tab-change="loadConfigs">
|
<el-tabs v-model="activeTab" @tab-change="handleTabChange">
|
||||||
<el-tab-pane label="文本生成" name="text">
|
<el-tab-pane label="文本生成" name="text">
|
||||||
<ConfigList
|
<ConfigList
|
||||||
:configs="configs"
|
:configs="configs"
|
||||||
@@ -114,7 +114,11 @@
|
|||||||
|
|
||||||
<el-form-item label="Base URL" prop="base_url">
|
<el-form-item label="Base URL" prop="base_url">
|
||||||
<el-input v-model="form.base_url" placeholder="https://api.openai.com" />
|
<el-input v-model="form.base_url" placeholder="https://api.openai.com" />
|
||||||
<div class="form-tip">API 服务的基础地址</div>
|
<div class="form-tip">
|
||||||
|
API 服务的基础地址,如 Chatfire: https://api.chatfire.site/v1,Gemini: https://generativelanguage.googleapis.com(无需 /v1)
|
||||||
|
<br>
|
||||||
|
完整调用路径: {{ fullEndpointExample }}
|
||||||
|
</div>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
|
||||||
<el-form-item label="API Key" prop="api_key">
|
<el-form-item label="API Key" prop="api_key">
|
||||||
@@ -127,16 +131,6 @@
|
|||||||
<div class="form-tip">您的 API 密钥</div>
|
<div class="form-tip">您的 API 密钥</div>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
|
||||||
<el-form-item label="端点路径" prop="endpoint">
|
|
||||||
<el-input v-model="form.endpoint" placeholder="/v1/chat/completions" />
|
|
||||||
<div class="form-tip">API 端点路径,默认为 /v1/chat/completions</div>
|
|
||||||
</el-form-item>
|
|
||||||
|
|
||||||
<el-form-item v-if="form.service_type === 'video'" label="查询端点" prop="query_endpoint">
|
|
||||||
<el-input v-model="form.query_endpoint" placeholder="/v1/video/task/{taskId}" />
|
|
||||||
<div class="form-tip">异步任务查询端点(仅视频服务需要),支持 {taskId} 占位符</div>
|
|
||||||
</el-form-item>
|
|
||||||
|
|
||||||
<el-form-item v-if="isEdit" label="启用状态">
|
<el-form-item v-if="isEdit" label="启用状态">
|
||||||
<el-switch v-model="form.is_active" />
|
<el-switch v-model="form.is_active" />
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
@@ -181,8 +175,6 @@ const form = reactive<CreateAIConfigRequest & { is_active?: boolean, provider?:
|
|||||||
base_url: '',
|
base_url: '',
|
||||||
api_key: '',
|
api_key: '',
|
||||||
model: [], // 改为数组支持多选
|
model: [], // 改为数组支持多选
|
||||||
endpoint: '/v1/chat/completions',
|
|
||||||
query_endpoint: '', // 异步查询端点
|
|
||||||
priority: 0, // 默认优先级为0
|
priority: 0, // 默认优先级为0
|
||||||
is_active: true
|
is_active: true
|
||||||
})
|
})
|
||||||
@@ -197,10 +189,54 @@ interface ProviderConfig {
|
|||||||
|
|
||||||
const providerConfigs: Record<AIServiceType, ProviderConfig[]> = {
|
const providerConfigs: Record<AIServiceType, ProviderConfig[]> = {
|
||||||
text: [
|
text: [
|
||||||
{ id: 'openai', name: 'OpenAI', models: ['gpt-5.2', 'gemini-3-pro-preview'], disabled: true }
|
{ id: 'openai', name: 'OpenAI', models: ['gpt-5.2', 'gemini-3-pro-preview'] },
|
||||||
|
{
|
||||||
|
id: 'chatfire',
|
||||||
|
name: 'Chatfire',
|
||||||
|
models: [
|
||||||
|
'gpt-4o',
|
||||||
|
'claude-sonnet-4-5-20250929',
|
||||||
|
'doubao-seed-1-8-251228',
|
||||||
|
'kimi-k2-thinking',
|
||||||
|
'gemini-3-pro',
|
||||||
|
'gemini-2.5-pro',
|
||||||
|
'gemini-3-pro-preview'
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'gemini',
|
||||||
|
name: 'Google Gemini',
|
||||||
|
models: [
|
||||||
|
'gemini-2.5-pro',
|
||||||
|
'gemini-3-pro-preview'
|
||||||
|
]
|
||||||
|
}
|
||||||
],
|
],
|
||||||
image: [
|
image: [
|
||||||
{ id: 'openai', name: 'OpenAI', models: ['nano-banana-pro', 'doubao-seedream-4-5-251128'] }
|
{
|
||||||
|
id: 'volcengine',
|
||||||
|
name: '火山引擎',
|
||||||
|
models: [
|
||||||
|
'doubao-seedream-4-5-251128',
|
||||||
|
'doubao-seedream-4-0-250828',
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'chatfire',
|
||||||
|
name: 'Chatfire',
|
||||||
|
models: [
|
||||||
|
'doubao-seedream-4-5-251128',
|
||||||
|
'nano-banana-pro',
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'gemini',
|
||||||
|
name: 'Google Gemini',
|
||||||
|
models: [
|
||||||
|
'gemini-3-pro-image-preview',
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{ id: 'openai', name: 'OpenAI', models: ['dall-e-3', 'dall-e-2'] }
|
||||||
],
|
],
|
||||||
video: [
|
video: [
|
||||||
{
|
{
|
||||||
@@ -214,6 +250,19 @@ const providerConfigs: Record<AIServiceType, ProviderConfig[]> = {
|
|||||||
'doubao-seedance-1-0-pro-fast-251015'
|
'doubao-seedance-1-0-pro-fast-251015'
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
id: 'chatfire',
|
||||||
|
name: 'Chatfire',
|
||||||
|
models: [
|
||||||
|
'doubao-seedance-1-5-pro-251215',
|
||||||
|
'doubao-seedance-1-0-lite-i2v-250428',
|
||||||
|
'doubao-seedance-1-0-lite-t2v-250428',
|
||||||
|
'doubao-seedance-1-0-pro-250528',
|
||||||
|
'doubao-seedance-1-0-pro-fast-251015',
|
||||||
|
'sora',
|
||||||
|
'sora-pro'
|
||||||
|
]
|
||||||
|
},
|
||||||
{ id: 'openai', name: 'OpenAI', models: ['sora-2', 'sora-2-pro'] },
|
{ id: 'openai', name: 'OpenAI', models: ['sora-2', 'sora-2-pro'] },
|
||||||
// { id: 'minimax', name: 'MiniMax', models: ['MiniMax-Hailuo-2.3', 'MiniMax-Hailuo-2.3-Fast', 'MiniMax-Hailuo-02'] }
|
// { id: 'minimax', name: 'MiniMax', models: ['MiniMax-Hailuo-2.3', 'MiniMax-Hailuo-2.3-Fast', 'MiniMax-Hailuo-02'] }
|
||||||
]
|
]
|
||||||
@@ -231,6 +280,41 @@ const availableModels = computed(() => {
|
|||||||
return provider?.models || []
|
return provider?.models || []
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 完整端点示例
|
||||||
|
const fullEndpointExample = computed(() => {
|
||||||
|
const baseUrl = form.base_url || 'https://api.example.com'
|
||||||
|
const provider = form.provider
|
||||||
|
const serviceType = form.service_type
|
||||||
|
|
||||||
|
let endpoint = ''
|
||||||
|
|
||||||
|
if (serviceType === 'text') {
|
||||||
|
if (provider === 'gemini' || provider === 'google') {
|
||||||
|
endpoint = '/v1beta/models/{model}:generateContent'
|
||||||
|
} else {
|
||||||
|
endpoint = '/chat/completions'
|
||||||
|
}
|
||||||
|
} else if (serviceType === 'image') {
|
||||||
|
if (provider === 'gemini' || provider === 'google') {
|
||||||
|
endpoint = '/v1beta/models/{model}:generateContent'
|
||||||
|
} else {
|
||||||
|
endpoint = '/images/generations'
|
||||||
|
}
|
||||||
|
} else if (serviceType === 'video') {
|
||||||
|
if (provider === 'chatfire') {
|
||||||
|
endpoint = '/video/generations'
|
||||||
|
} else if (provider === 'doubao' || provider === 'volcengine' || provider === 'volces') {
|
||||||
|
endpoint = '/contents/generations/tasks'
|
||||||
|
} else if (provider === 'openai') {
|
||||||
|
endpoint = '/videos'
|
||||||
|
} else {
|
||||||
|
endpoint = '/video/generations'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return baseUrl + endpoint
|
||||||
|
})
|
||||||
|
|
||||||
const rules: FormRules = {
|
const rules: FormRules = {
|
||||||
name: [
|
name: [
|
||||||
{ required: true, message: '请输入配置名称', trigger: 'blur' }
|
{ required: true, message: '请输入配置名称', trigger: 'blur' }
|
||||||
@@ -274,17 +358,39 @@ const loadConfigs = async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 生成随机配置名称
|
||||||
|
const generateConfigName = (provider: string, serviceType: AIServiceType): string => {
|
||||||
|
const providerNames: Record<string, string> = {
|
||||||
|
'chatfire': 'ChatFire',
|
||||||
|
'openai': 'OpenAI',
|
||||||
|
'gemini': 'Gemini',
|
||||||
|
'google': 'Google'
|
||||||
|
}
|
||||||
|
|
||||||
|
const serviceNames: Record<AIServiceType, string> = {
|
||||||
|
'text': '文本',
|
||||||
|
'image': '图片',
|
||||||
|
'video': '视频'
|
||||||
|
}
|
||||||
|
|
||||||
|
const randomNum = Math.floor(Math.random() * 10000).toString().padStart(4, '0')
|
||||||
|
const providerName = providerNames[provider] || provider
|
||||||
|
const serviceName = serviceNames[serviceType] || serviceType
|
||||||
|
|
||||||
|
return `${providerName}-${serviceName}-${randomNum}`
|
||||||
|
}
|
||||||
|
|
||||||
const showCreateDialog = () => {
|
const showCreateDialog = () => {
|
||||||
isEdit.value = false
|
isEdit.value = false
|
||||||
editingId.value = undefined
|
editingId.value = undefined
|
||||||
resetForm()
|
resetForm()
|
||||||
form.service_type = activeTab.value
|
form.service_type = activeTab.value
|
||||||
// 根据服务类型设置默认端点路径
|
// 默认选择 chatfire
|
||||||
form.endpoint = getDefaultEndpoint(activeTab.value)
|
form.provider = 'chatfire'
|
||||||
// 文本生成默认选择openai
|
// 设置默认 base_url
|
||||||
if (activeTab.value === 'text') {
|
form.base_url = 'https://api.chatfire.site/v1'
|
||||||
form.provider = 'openai'
|
// 自动生成随机配置名称
|
||||||
}
|
form.name = generateConfigName('chatfire', activeTab.value)
|
||||||
dialogVisible.value = true
|
dialogVisible.value = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,26 +398,13 @@ const handleEdit = (config: AIServiceConfig) => {
|
|||||||
isEdit.value = true
|
isEdit.value = true
|
||||||
editingId.value = config.id
|
editingId.value = config.id
|
||||||
|
|
||||||
// 根据模型名称推断厂商
|
|
||||||
const inferProvider = (model: string, serviceType: AIServiceType): string => {
|
|
||||||
const providers = providerConfigs[serviceType]
|
|
||||||
for (const provider of providers) {
|
|
||||||
if (provider.models.includes(model)) {
|
|
||||||
return provider.id
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return providers[0]?.id || ''
|
|
||||||
}
|
|
||||||
|
|
||||||
Object.assign(form, {
|
Object.assign(form, {
|
||||||
service_type: config.service_type,
|
service_type: config.service_type,
|
||||||
provider: inferProvider(Array.isArray(config.model) ? config.model[0] : config.model, config.service_type),
|
provider: config.provider || 'chatfire', // 直接使用配置中的 provider,默认为 chatfire
|
||||||
name: config.name,
|
name: config.name,
|
||||||
base_url: config.base_url,
|
base_url: config.base_url,
|
||||||
api_key: config.api_key,
|
api_key: config.api_key,
|
||||||
model: Array.isArray(config.model) ? config.model : [config.model], // 统一转换为数组
|
model: Array.isArray(config.model) ? config.model : [config.model], // 统一转换为数组
|
||||||
endpoint: config.endpoint,
|
|
||||||
query_endpoint: config.query_endpoint || '',
|
|
||||||
priority: config.priority || 0,
|
priority: config.priority || 0,
|
||||||
is_active: config.is_active
|
is_active: config.is_active
|
||||||
})
|
})
|
||||||
@@ -359,7 +452,7 @@ const testConnection = async () => {
|
|||||||
base_url: form.base_url,
|
base_url: form.base_url,
|
||||||
api_key: form.api_key,
|
api_key: form.api_key,
|
||||||
model: form.model,
|
model: form.model,
|
||||||
endpoint: form.endpoint
|
provider: form.provider
|
||||||
})
|
})
|
||||||
ElMessage.success('连接测试成功!')
|
ElMessage.success('连接测试成功!')
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
@@ -376,7 +469,7 @@ const handleTest = async (config: AIServiceConfig) => {
|
|||||||
base_url: config.base_url,
|
base_url: config.base_url,
|
||||||
api_key: config.api_key,
|
api_key: config.api_key,
|
||||||
model: config.model,
|
model: config.model,
|
||||||
endpoint: config.endpoint
|
provider: config.provider
|
||||||
})
|
})
|
||||||
ElMessage.success('连接测试成功!')
|
ElMessage.success('连接测试成功!')
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
@@ -397,11 +490,10 @@ const handleSubmit = async () => {
|
|||||||
if (isEdit.value && editingId.value) {
|
if (isEdit.value && editingId.value) {
|
||||||
const updateData: UpdateAIConfigRequest = {
|
const updateData: UpdateAIConfigRequest = {
|
||||||
name: form.name,
|
name: form.name,
|
||||||
|
provider: form.provider,
|
||||||
base_url: form.base_url,
|
base_url: form.base_url,
|
||||||
api_key: form.api_key,
|
api_key: form.api_key,
|
||||||
model: form.model,
|
model: form.model,
|
||||||
endpoint: form.endpoint,
|
|
||||||
query_endpoint: form.query_endpoint,
|
|
||||||
priority: form.priority,
|
priority: form.priority,
|
||||||
is_active: form.is_active
|
is_active: form.is_active
|
||||||
}
|
}
|
||||||
@@ -422,16 +514,36 @@ const handleSubmit = async () => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleTabChange = (tabName: string | number) => {
|
||||||
|
// 标签页切换时重新加载对应服务类型的配置
|
||||||
|
activeTab.value = tabName as AIServiceType
|
||||||
|
loadConfigs()
|
||||||
|
}
|
||||||
|
|
||||||
const handleProviderChange = () => {
|
const handleProviderChange = () => {
|
||||||
// 切换厂商时清空已选模型
|
// 切换厂商时清空已选模型
|
||||||
form.model = []
|
form.model = []
|
||||||
|
|
||||||
|
// 根据厂商自动设置默认 base_url
|
||||||
|
if (form.provider === 'gemini' || form.provider === 'google') {
|
||||||
|
form.base_url = 'https://api.chatfire.site'
|
||||||
|
} else {
|
||||||
|
// openai, chatfire 等其他厂商
|
||||||
|
form.base_url = 'https://api.chatfire.site/v1'
|
||||||
|
}
|
||||||
|
|
||||||
|
// 仅在新建配置时自动更新名称
|
||||||
|
if (!isEdit.value) {
|
||||||
|
form.name = generateConfigName(form.provider, form.service_type)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 根据服务类型获取默认端点路径
|
// getDefaultEndpoint 已移除,端点由后端根据 provider 自动设置
|
||||||
|
// 保留该函数定义以避免编译错误
|
||||||
const getDefaultEndpoint = (serviceType: AIServiceType): string => {
|
const getDefaultEndpoint = (serviceType: AIServiceType): string => {
|
||||||
switch (serviceType) {
|
switch (serviceType) {
|
||||||
case 'text':
|
case 'text':
|
||||||
return '/v1/chat/completions'
|
return ''
|
||||||
case 'image':
|
case 'image':
|
||||||
return '/v1/images/generations'
|
return '/v1/images/generations'
|
||||||
case 'video':
|
case 'video':
|
||||||
@@ -450,8 +562,6 @@ const resetForm = () => {
|
|||||||
base_url: '',
|
base_url: '',
|
||||||
api_key: '',
|
api_key: '',
|
||||||
model: [], // 改为空数组
|
model: [], // 改为空数组
|
||||||
endpoint: getDefaultEndpoint(serviceType),
|
|
||||||
query_endpoint: '',
|
|
||||||
priority: 0,
|
priority: 0,
|
||||||
is_active: true
|
is_active: true
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user