Files
huobao-drama/application/services/ai_service.go
2026-01-14 12:40:45 +08:00

399 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"errors"
"fmt"
"github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/ai"
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
// AIService stores and retrieves AI provider configurations and builds AI
// clients from them.
type AIService struct {
// db is the GORM handle every query in this service goes through.
db *gorm.DB
// log receives the structured info/error events emitted by the service.
log *logger.Logger
}
// NewAIService returns an AIService that persists through db and reports
// through log. Neither argument is validated here.
func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
	svc := new(AIService)
	svc.db = db
	svc.log = log
	return svc
}
// CreateAIConfigRequest is the payload accepted by AIService.CreateConfig.
// The binding tags declare the validation rules; the validator that enforces
// them is not part of this file.
type CreateAIConfigRequest struct {
// ServiceType selects the kind of service: "text", "image" or "video".
ServiceType string `json:"service_type" binding:"required,oneof=text image video"`
// Name is a label for the config (1-100 characters per the binding tag).
Name string `json:"name" binding:"required,min=1,max=100"`
// Provider picks the default endpoints in CreateConfig, e.g. "openai",
// "gemini", "chatfire", "doubao".
Provider string `json:"provider" binding:"required"`
// BaseURL is the provider's base URL.
BaseURL string `json:"base_url" binding:"required,url"`
// APIKey is the credential stored with the config.
APIKey string `json:"api_key" binding:"required"`
// Model holds the model names served by this config.
Model models.ModelField `json:"model" binding:"required"`
// Endpoint is the request path; when empty CreateConfig derives one from
// Provider and ServiceType.
Endpoint string `json:"endpoint"`
// QueryEndpoint is the task-query path; CreateConfig only fills in a default
// for it when Endpoint is empty as well.
QueryEndpoint string `json:"query_endpoint"`
// Priority orders configs; queries in this file sort by priority DESC.
Priority int `json:"priority"`
// IsDefault is stored as given.
IsDefault bool `json:"is_default"`
// Settings is stored as given; its format is not defined in this file.
Settings string `json:"settings"`
}
// UpdateAIConfigRequest is the payload accepted by AIService.UpdateConfig.
// Empty strings and nil pointers mean "leave unchanged" unless a field says
// otherwise below.
type UpdateAIConfigRequest struct {
// Name replaces the stored name when non-empty.
Name string `json:"name" binding:"omitempty,min=1,max=100"`
// Provider replaces the stored provider when non-empty; with an empty
// Endpoint it also triggers provider-based endpoint defaults.
Provider string `json:"provider"`
// BaseURL replaces the stored base URL when non-empty.
BaseURL string `json:"base_url" binding:"omitempty,url"`
// APIKey replaces the stored key when non-empty.
APIKey string `json:"api_key"`
// Model replaces the stored model list when non-nil and non-empty.
Model *models.ModelField `json:"model"`
// Endpoint replaces the stored endpoint when non-empty.
Endpoint string `json:"endpoint"`
// QueryEndpoint: unlike the other string fields, an empty value is not
// ignored by UpdateConfig.
QueryEndpoint string `json:"query_endpoint"`
// Priority replaces the stored priority when non-nil.
Priority *int `json:"priority"`
// IsDefault is always written by UpdateConfig; being a plain bool, an
// omitted value is indistinguishable from false.
IsDefault bool `json:"is_default"`
// IsActive is always written by UpdateConfig; being a plain bool, an
// omitted value is indistinguishable from false.
IsActive bool `json:"is_active"`
// Settings replaces the stored settings when non-empty.
Settings string `json:"settings"`
}
// TestConnectionRequest is the payload accepted by AIService.TestConnection.
type TestConnectionRequest struct {
// BaseURL is the provider base URL to probe.
BaseURL string `json:"base_url" binding:"required,url"`
// APIKey is the credential used for the probe.
APIKey string `json:"api_key" binding:"required"`
// Model lists model names; TestConnection only uses the first one.
Model models.ModelField `json:"model" binding:"required"`
// Provider selects the client: "gemini"/"google" use the Gemini client,
// everything else the OpenAI-compatible client.
Provider string `json:"provider"`
// Endpoint overrides the request path for OpenAI-compatible providers;
// the Gemini branch of TestConnection does not read it.
Endpoint string `json:"endpoint"`
}
// CreateConfig persists a new AI service configuration and returns it.
//
// When req.Endpoint is empty, an endpoint (and, for video services, a query
// endpoint) is derived from the provider and service type. A caller-supplied
// query endpoint is never replaced. Provider/service combinations without a
// known default leave the endpoint empty. New configs are always stored as
// active. Database errors are logged and returned unchanged.
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
	endpoint, queryEndpoint := req.Endpoint, req.QueryEndpoint

	if endpoint == "" {
		// fallbackQuery is the provider's task-query path; it is only applied
		// when the caller did not supply a query endpoint.
		fallbackQuery := ""

		switch req.Provider {
		case "gemini", "google":
			switch req.ServiceType {
			case "text", "image":
				endpoint = "/v1beta/models/{model}:generateContent"
			}
		case "openai":
			switch req.ServiceType {
			case "text":
				endpoint = "/chat/completions"
			case "image":
				endpoint = "/images/generations"
			case "video":
				endpoint = "/videos"
				fallbackQuery = "/videos/{taskId}"
			}
		case "chatfire":
			switch req.ServiceType {
			case "text":
				endpoint = "/chat/completions"
			case "image":
				endpoint = "/images/generations"
			case "video":
				endpoint = "/video/generations"
				fallbackQuery = "/video/task/{taskId}"
			}
		case "doubao", "volcengine", "volces":
			switch req.ServiceType {
			case "video":
				endpoint = "/contents/generations/tasks"
				fallbackQuery = "/generations/tasks/{taskId}"
			}
		default:
			// Unknown providers fall back to the OpenAI-style paths.
			switch req.ServiceType {
			case "text":
				endpoint = "/chat/completions"
			case "image":
				endpoint = "/images/generations"
			}
		}

		if queryEndpoint == "" {
			queryEndpoint = fallbackQuery
		}
	}

	record := &models.AIServiceConfig{
		ServiceType:   req.ServiceType,
		Name:          req.Name,
		Provider:      req.Provider,
		BaseURL:       req.BaseURL,
		APIKey:        req.APIKey,
		Model:         req.Model,
		Endpoint:      endpoint,
		QueryEndpoint: queryEndpoint,
		Priority:      req.Priority,
		IsDefault:     req.IsDefault,
		IsActive:      true,
		Settings:      req.Settings,
	}

	if createErr := s.db.Create(record).Error; createErr != nil {
		s.log.Errorw("Failed to create AI config", "error", createErr)
		return nil, createErr
	}

	s.log.Infow("AI config created", "config_id", record.ID, "provider", req.Provider, "endpoint", endpoint)
	return record, nil
}
// GetConfig loads the configuration with the given ID. It returns a
// "config not found" error when no row matches and passes any other database
// error through unchanged.
func (s *AIService) GetConfig(configID uint) (*models.AIServiceConfig, error) {
	found := new(models.AIServiceConfig)
	lookupErr := s.db.Where("id = ? ", configID).First(found).Error

	switch {
	case lookupErr == nil:
		return found, nil
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return nil, errors.New("config not found")
	default:
		return nil, lookupErr
	}
}
// ListConfigs returns the stored configurations ordered by priority and then
// creation time, both descending. A non-empty serviceType restricts the
// result to that service type; an empty one lists everything. Database errors
// are logged and returned.
func (s *AIService) ListConfigs(serviceType string) ([]models.AIServiceConfig, error) {
	scope := s.db
	if serviceType != "" {
		scope = scope.Where("service_type = ?", serviceType)
	}

	var found []models.AIServiceConfig
	if listErr := scope.Order("priority DESC, created_at DESC").Find(&found).Error; listErr != nil {
		s.log.Errorw("Failed to list AI configs", "error", listErr)
		return nil, listErr
	}
	return found, nil
}
// UpdateConfig applies a partial update to the configuration with the given
// ID and returns the updated record.
//
// Field handling:
//   - name, provider, base_url, api_key, settings: written only when non-empty.
//   - model, priority: written only when the pointer is non-nil (and, for
//     model, the list is non-empty).
//   - endpoint: an explicit req.Endpoint wins. Otherwise, when req.Provider is
//     set, a default is derived from the provider and the stored service type.
//   - query_endpoint: always written, so an empty value clears it. The one
//     exception is when a provider default was just derived and the request
//     carries no query endpoint; then the provider's query default is used,
//     matching CreateConfig.
//   - is_default, is_active: always written. They are plain bools on the
//     request, so an omitted value is stored as false.
//
// It returns a "config not found" error when the ID does not exist. Update
// failures are logged and returned.
func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*models.AIServiceConfig, error) {
	var config models.AIServiceConfig
	if err := s.db.Where("id = ? ", configID).First(&config).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("config not found")
		}
		return nil, err
	}

	// The is_default exclusivity logic is no longer needed.
	updates := make(map[string]interface{})

	if req.Name != "" {
		updates["name"] = req.Name
	}
	if req.Provider != "" {
		updates["provider"] = req.Provider
	}
	if req.BaseURL != "" {
		updates["base_url"] = req.BaseURL
	}
	if req.APIKey != "" {
		updates["api_key"] = req.APIKey
	}
	if req.Model != nil && len(*req.Model) > 0 {
		updates["model"] = *req.Model
	}
	if req.Priority != nil {
		updates["priority"] = *req.Priority
	}

	// Start from the caller's value so an empty query endpoint still clears
	// the stored one; a provider default may replace it below.
	queryEndpoint := req.QueryEndpoint

	switch {
	case req.Endpoint != "":
		updates["endpoint"] = req.Endpoint
	case req.Provider != "":
		// Provider supplied without an endpoint: derive defaults from the
		// provider and the stored service type.
		var endpoint, defaultQuery string
		switch req.Provider {
		case "gemini", "google":
			switch config.ServiceType {
			case "text", "image":
				endpoint = "/v1beta/models/{model}:generateContent"
			}
		case "openai":
			switch config.ServiceType {
			case "text":
				endpoint = "/chat/completions"
			case "image":
				endpoint = "/images/generations"
			case "video":
				endpoint = "/videos"
				defaultQuery = "/videos/{taskId}"
			}
		case "chatfire":
			switch config.ServiceType {
			case "text":
				endpoint = "/chat/completions"
			case "image":
				endpoint = "/images/generations"
			case "video":
				endpoint = "/video/generations"
				defaultQuery = "/video/task/{taskId}"
			}
		case "doubao", "volcengine", "volces":
			// Same defaults CreateConfig applies for this provider family.
			if config.ServiceType == "video" {
				endpoint = "/contents/generations/tasks"
				defaultQuery = "/generations/tasks/{taskId}"
			}
		}

		if endpoint != "" {
			updates["endpoint"] = endpoint
		}
		// Previously the derived query default was overwritten by the
		// unconditional assignment that followed and never took effect.
		if queryEndpoint == "" {
			queryEndpoint = defaultQuery
		}
	}
	updates["query_endpoint"] = queryEndpoint

	if req.Settings != "" {
		updates["settings"] = req.Settings
	}

	updates["is_default"] = req.IsDefault
	updates["is_active"] = req.IsActive

	// Transaction checks the Begin error and rolls back on error or panic,
	// which the manual Begin/Rollback/Commit sequence did not.
	err := s.db.Transaction(func(tx *gorm.DB) error {
		return tx.Model(&config).Updates(updates).Error
	})
	if err != nil {
		s.log.Errorw("Failed to update AI config", "error", err)
		return nil, err
	}

	s.log.Infow("AI config updated", "config_id", configID)
	return &config, nil
}
// DeleteConfig deletes the configuration with the given ID. It returns a
// "config not found" error when no row was affected; database errors are
// logged and returned.
func (s *AIService) DeleteConfig(configID uint) error {
	outcome := s.db.Where("id = ? ", configID).Delete(&models.AIServiceConfig{})

	switch {
	case outcome.Error != nil:
		s.log.Errorw("Failed to delete AI config", "error", outcome.Error)
		return outcome.Error
	case outcome.RowsAffected == 0:
		return errors.New("config not found")
	}

	s.log.Infow("AI config deleted", "config_id", configID)
	return nil
}
// TestConnection builds a throwaway client from req and runs the client's own
// TestConnection, returning whatever error it reports.
//
// Only the first entry of req.Model is used (empty when the list is empty).
// Providers "gemini" and "google" get the Gemini client with its fixed
// generateContent endpoint; every other provider gets the OpenAI-compatible
// client with req.Endpoint, defaulting to "/chat/completions".
//
// NOTE(review): the Gemini branch ignores req.Endpoint, whereas CreateConfig
// and GetAIClient honour a custom endpoint for Gemini — confirm this is
// intended.
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
	s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model))

	// Probe with the first configured model.
	var model string
	if len(req.Model) > 0 {
		model = req.Model[0]
	}
	s.log.Infow("Using model for test", "model", model, "provider", req.Provider)

	// Pick the client implementation from the provider name.
	var (
		client   ai.AIClient
		endpoint string
	)
	if req.Provider == "gemini" || req.Provider == "google" {
		s.log.Infow("Using Gemini client", "baseURL", req.BaseURL)
		endpoint = "/v1beta/models/{model}:generateContent"
		client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint)
	} else {
		// Known OpenAI-style providers and unknown ones share the client and
		// differ only in the log line.
		if req.Provider == "openai" || req.Provider == "chatfire" {
			s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider)
		} else {
			s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL)
		}
		endpoint = req.Endpoint
		if endpoint == "" {
			endpoint = "/chat/completions"
		}
		client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
	}

	s.log.Infow("Calling TestConnection on client", "endpoint", endpoint)
	if testErr := client.TestConnection(); testErr != nil {
		s.log.Errorw("TestConnection failed", "error", testErr)
		return testErr
	}

	s.log.Infow("TestConnection succeeded")
	return nil
}
// GetDefaultConfig returns the configuration of the given service type with
// the highest priority, ties broken by the most recent creation time. It
// returns a "no config found" error when none exists.
//
// NOTE(review): despite the name, the query filters on neither is_default nor
// is_active — confirm that inactive configs are meant to be eligible.
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
	top := new(models.AIServiceConfig)

	// Take the first row in descending priority order.
	lookupErr := s.db.
		Where("service_type = ?", serviceType).
		Order("priority DESC, created_at DESC").
		First(top).Error

	switch {
	case lookupErr == nil:
		return top, nil
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return nil, errors.New("no config found")
	default:
		return nil, lookupErr
	}
}
// GetConfigForModel returns the highest-priority configuration of the given
// service type whose model list contains modelName. Candidates are scanned in
// priority-descending, newest-first order and the first match wins. An error
// naming the model is returned when nothing matches.
func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*models.AIServiceConfig, error) {
	var candidates []models.AIServiceConfig
	if err := s.db.
		Where("service_type = ?", serviceType).
		Order("priority DESC, created_at DESC").
		Find(&candidates).Error; err != nil {
		return nil, err
	}

	// Look for the first config that lists the requested model.
	for i := range candidates {
		for _, name := range candidates[i].Model {
			if name != modelName {
				continue
			}
			// Return a copy so the result does not point into the slice.
			match := candidates[i]
			return &match, nil
		}
	}

	return nil, errors.New("no config found for model: " + modelName)
}
// GetAIClient builds a client from the configuration that GetDefaultConfig
// selects for serviceType.
//
// The first model of the config is used (empty when the list is empty). The
// stored endpoint is used when set; otherwise it defaults to the Gemini
// generateContent path for "gemini"/"google" and to "/chat/completions" for
// every other provider. Errors from GetDefaultConfig are returned unchanged.
func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) {
	cfg, err := s.GetDefaultConfig(serviceType)
	if err != nil {
		return nil, err
	}

	// Use the first configured model.
	var model string
	if len(cfg.Model) > 0 {
		model = cfg.Model[0]
	}

	endpoint := cfg.Endpoint

	switch cfg.Provider {
	case "gemini", "google":
		if endpoint == "" {
			endpoint = "/v1beta/models/{model}:generateContent"
		}
		return ai.NewGeminiClient(cfg.BaseURL, cfg.APIKey, model, endpoint), nil
	}

	// openai, chatfire and all other vendors use the OpenAI format.
	if endpoint == "" {
		endpoint = "/chat/completions"
	}
	return ai.NewOpenAIClient(cfg.BaseURL, cfg.APIKey, model, endpoint), nil
}
// GenerateText resolves the client for the "text" service type and forwards
// prompt, systemPrompt and options to its GenerateText method. A failure to
// obtain the client is wrapped; the client's own result is returned as is.
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {
	textClient, clientErr := s.GetAIClient("text")
	if clientErr == nil {
		return textClient.GenerateText(prompt, systemPrompt, options...)
	}
	return "", fmt.Errorf("failed to get AI client: %w", clientErr)
}