Files
huobao-drama/application/services/ai_service.go
Connor 9600fc542c init
2026-01-12 13:17:11 +08:00

254 lines
7.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"errors"
"fmt"
"github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/ai"
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
// AIService manages persisted AI service configurations and builds
// ai.OpenAIClient instances from them.
type AIService struct {
	db  *gorm.DB       // handle used for every config query
	log *logger.Logger // structured logger for config lifecycle events and failures
}
// NewAIService returns an AIService backed by the given database handle and logger.
func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
	svc := new(AIService)
	svc.db = db
	svc.log = log
	return svc
}
// CreateAIConfigRequest is the payload accepted by CreateConfig. The `binding`
// tags declare validation rules (presumably enforced by the HTTP layer's
// request binding — confirm against the handler).
type CreateAIConfigRequest struct {
	ServiceType   string            `json:"service_type" binding:"required,oneof=text image video"` // "text", "image" or "video"
	Name          string            `json:"name" binding:"required,min=1,max=100"`
	BaseURL       string            `json:"base_url" binding:"required,url"`
	APIKey        string            `json:"api_key" binding:"required"`
	Model         models.ModelField `json:"model" binding:"required"`
	Endpoint      string            `json:"endpoint"`
	QueryEndpoint string            `json:"query_endpoint"`
	Priority      int               `json:"priority"` // configs are selected in descending priority order
	IsDefault     bool              `json:"is_default"`
	Settings      string            `json:"settings"` // stored verbatim; not interpreted by AIService
}
// UpdateAIConfigRequest is the payload accepted by UpdateConfig.
//
// UpdateConfig treats empty string fields as "leave unchanged", except
// QueryEndpoint, which is always written so that it can be cleared.
// IsDefault and IsActive are plain bools and are always written, so a JSON
// body that omits them sets both to false.
type UpdateAIConfigRequest struct {
	Name          string             `json:"name" binding:"omitempty,min=1,max=100"`
	BaseURL       string             `json:"base_url" binding:"omitempty,url"`
	APIKey        string             `json:"api_key"`
	Model         *models.ModelField `json:"model"` // nil or empty keeps the current model list
	Endpoint      string             `json:"endpoint"`
	QueryEndpoint string             `json:"query_endpoint"` // always applied; "" clears it
	Priority      *int               `json:"priority"`       // nil keeps the current priority
	IsDefault     bool               `json:"is_default"`
	IsActive      bool               `json:"is_active"`
	Settings      string             `json:"settings"`
}
// TestConnectionRequest carries the connection parameters checked by
// TestConnection. Only the first entry of Model is used for the check.
type TestConnectionRequest struct {
	BaseURL  string            `json:"base_url" binding:"required,url"`
	APIKey   string            `json:"api_key" binding:"required"`
	Model    models.ModelField `json:"model" binding:"required"`
	Endpoint string            `json:"endpoint"`
}
// CreateConfig persists a new AI service config built from req and returns it.
// New configs are always created active; every other field is copied from req.
// A database error is logged and returned unchanged.
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
	cfg := &models.AIServiceConfig{
		ServiceType:   req.ServiceType,
		Name:          req.Name,
		BaseURL:       req.BaseURL,
		APIKey:        req.APIKey,
		Model:         req.Model,
		Endpoint:      req.Endpoint,
		QueryEndpoint: req.QueryEndpoint,
		Priority:      req.Priority,
		IsDefault:     req.IsDefault,
		IsActive:      true,
		Settings:      req.Settings,
	}

	if res := s.db.Create(cfg); res.Error != nil {
		s.log.Errorw("Failed to create AI config", "error", res.Error)
		return nil, res.Error
	}

	s.log.Infow("AI config created", "config_id", cfg.ID)
	return cfg, nil
}
// GetConfig loads the AI service config with the given ID.
// It returns an error with message "config not found" when no row matches;
// any other database error is returned unchanged.
func (s *AIService) GetConfig(configID uint) (*models.AIServiceConfig, error) {
	config := new(models.AIServiceConfig)

	switch err := s.db.Where("id = ? ", configID).First(config).Error; {
	case err == nil:
		return config, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, errors.New("config not found")
	default:
		return nil, err
	}
}
// ListConfigs returns all AI service configs, highest priority first and then
// newest first. A non-empty serviceType restricts the result to that type.
// A database error is logged and returned unchanged.
func (s *AIService) ListConfigs(serviceType string) ([]models.AIServiceConfig, error) {
	db := s.db
	if len(serviceType) > 0 {
		db = db.Where("service_type = ?", serviceType)
	}

	var out []models.AIServiceConfig
	if err := db.Order("priority DESC, created_at DESC").Find(&out).Error; err != nil {
		s.log.Errorw("Failed to list AI configs", "error", err)
		return nil, err
	}
	return out, nil
}
// UpdateConfig applies a partial update to the AI service config identified by
// configID and returns the updated config.
//
// Field handling, per the update map built below:
//   - Name, BaseURL, APIKey, Endpoint, Settings: written only when non-empty.
//   - Model: written only when non-nil and non-empty.
//   - Priority: written only when non-nil.
//   - QueryEndpoint: always written, so an empty value clears it.
//   - IsDefault, IsActive: always written.
//
// NOTE(review): IsDefault and IsActive are plain bools, so a request body that
// omits them sets both to false, which deactivates the config. Confirm that
// every caller sends the full state, or migrate these fields to *bool.
//
// It returns an error with message "config not found" when no row matches
// configID. Other lookup and database errors are returned unchanged.
func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*models.AIServiceConfig, error) {
	// Reuse GetConfig so the lookup query and its "config not found" error live in one place.
	config, err := s.GetConfig(configID)
	if err != nil {
		return nil, err
	}

	// is_default is no longer exclusive across configs, so only this row is touched.
	updates := make(map[string]interface{})
	if req.Name != "" {
		updates["name"] = req.Name
	}
	if req.BaseURL != "" {
		updates["base_url"] = req.BaseURL
	}
	if req.APIKey != "" {
		updates["api_key"] = req.APIKey
	}
	if req.Model != nil && len(*req.Model) > 0 {
		updates["model"] = *req.Model
	}
	if req.Priority != nil {
		updates["priority"] = *req.Priority
	}
	if req.Endpoint != "" {
		updates["endpoint"] = req.Endpoint
	}
	// query_endpoint is written unconditionally so that callers can clear it.
	updates["query_endpoint"] = req.QueryEndpoint
	if req.Settings != "" {
		updates["settings"] = req.Settings
	}
	updates["is_default"] = req.IsDefault
	updates["is_active"] = req.IsActive

	// db.Transaction returns a failed Begin's error and rolls back on both a
	// returned error and a panic. The pooled connection is therefore never
	// left inside an open transaction.
	err = s.db.Transaction(func(tx *gorm.DB) error {
		if updErr := tx.Model(config).Updates(updates).Error; updErr != nil {
			s.log.Errorw("Failed to update AI config", "error", updErr)
			return updErr
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	s.log.Infow("AI config updated", "config_id", configID)
	return config, nil
}
// DeleteConfig removes the AI service config with the given ID.
// It returns an error with message "config not found" when no row was deleted.
// A database error is logged and returned unchanged.
func (s *AIService) DeleteConfig(configID uint) error {
	res := s.db.Where("id = ? ", configID).Delete(&models.AIServiceConfig{})

	switch {
	case res.Error != nil:
		s.log.Errorw("Failed to delete AI config", "error", res.Error)
		return res.Error
	case res.RowsAffected == 0:
		return errors.New("config not found")
	}

	s.log.Infow("AI config deleted", "config_id", configID)
	return nil
}
// TestConnection builds a client from req and delegates to its TestConnection.
// Only the first model in req.Model is used; when the list is empty, an
// empty model name is passed.
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
	var firstModel string
	if len(req.Model) != 0 {
		firstModel = req.Model[0]
	}
	return ai.NewOpenAIClient(req.BaseURL, req.APIKey, firstModel, req.Endpoint).TestConnection()
}
// GetDefaultConfig returns the active config of serviceType with the highest
// priority; ties are broken by the most recent created_at.
// It returns an error with message "no active config found" when there is none;
// any other database error is returned unchanged.
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
	query := s.db.
		Where("service_type = ? AND is_active = ?", serviceType, true).
		Order("priority DESC, created_at DESC")

	var config models.AIServiceConfig
	err := query.First(&config).Error
	switch {
	case err == nil:
		return &config, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, errors.New("no active config found")
	}
	return nil, err
}
// GetConfigForModel returns the highest-priority active config of serviceType
// whose Model list contains modelName. Among equal priorities, the newest config
// wins. It returns an error when no active config lists the model; database
// errors are returned unchanged.
func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*models.AIServiceConfig, error) {
	var candidates []models.AIServiceConfig
	err := s.db.Where("service_type = ? AND is_active = ?", serviceType, true).
		Order("priority DESC, created_at DESC").
		Find(&candidates).Error
	if err != nil {
		return nil, err
	}

	// Candidates are already sorted by preference, so the first match wins.
	// Indexing avoids copying each config struct while scanning.
	for i := range candidates {
		for _, m := range candidates[i].Model {
			if m == modelName {
				return &candidates[i], nil
			}
		}
	}
	return nil, errors.New("no config found for model: " + modelName)
}
// GetAIClient builds a client from the default (highest-priority active) config
// of serviceType. Only the first model in that config's list is used; when the
// list is empty, an empty model name is passed. Errors from GetDefaultConfig are
// returned unchanged.
func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
	cfg, err := s.GetDefaultConfig(serviceType)
	if err != nil {
		return nil, err
	}

	var primary string
	if len(cfg.Model) != 0 {
		primary = cfg.Model[0]
	}

	client := ai.NewOpenAIClient(cfg.BaseURL, cfg.APIKey, primary, cfg.Endpoint)
	return client, nil
}
// GenerateText sends prompt and systemPrompt to the client for the default
// "text" config, forwarding any request options. A failure to obtain the client
// is wrapped with context; errors from the client itself are returned as-is.
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {
	const serviceType = "text"

	client, err := s.GetAIClient(serviceType)
	if err != nil {
		return "", fmt.Errorf("failed to get AI client: %w", err)
	}
	return client.GenerateText(prompt, systemPrompt, options...)
}