254 lines
7.1 KiB
Go
254 lines
7.1 KiB
Go
package services
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
|
||
"github.com/drama-generator/backend/domain/models"
|
||
"github.com/drama-generator/backend/pkg/ai"
|
||
"github.com/drama-generator/backend/pkg/logger"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type AIService struct {
|
||
db *gorm.DB
|
||
log *logger.Logger
|
||
}
|
||
|
||
func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
|
||
return &AIService{
|
||
db: db,
|
||
log: log,
|
||
}
|
||
}
|
||
|
||
type CreateAIConfigRequest struct {
|
||
ServiceType string `json:"service_type" binding:"required,oneof=text image video"`
|
||
Name string `json:"name" binding:"required,min=1,max=100"`
|
||
BaseURL string `json:"base_url" binding:"required,url"`
|
||
APIKey string `json:"api_key" binding:"required"`
|
||
Model models.ModelField `json:"model" binding:"required"`
|
||
Endpoint string `json:"endpoint"`
|
||
QueryEndpoint string `json:"query_endpoint"`
|
||
Priority int `json:"priority"`
|
||
IsDefault bool `json:"is_default"`
|
||
Settings string `json:"settings"`
|
||
}
|
||
|
||
type UpdateAIConfigRequest struct {
|
||
Name string `json:"name" binding:"omitempty,min=1,max=100"`
|
||
BaseURL string `json:"base_url" binding:"omitempty,url"`
|
||
APIKey string `json:"api_key"`
|
||
Model *models.ModelField `json:"model"`
|
||
Endpoint string `json:"endpoint"`
|
||
QueryEndpoint string `json:"query_endpoint"`
|
||
Priority *int `json:"priority"`
|
||
IsDefault bool `json:"is_default"`
|
||
IsActive bool `json:"is_active"`
|
||
Settings string `json:"settings"`
|
||
}
|
||
|
||
type TestConnectionRequest struct {
|
||
BaseURL string `json:"base_url" binding:"required,url"`
|
||
APIKey string `json:"api_key" binding:"required"`
|
||
Model models.ModelField `json:"model" binding:"required"`
|
||
Endpoint string `json:"endpoint"`
|
||
}
|
||
|
||
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
|
||
config := &models.AIServiceConfig{
|
||
ServiceType: req.ServiceType,
|
||
Name: req.Name,
|
||
BaseURL: req.BaseURL,
|
||
APIKey: req.APIKey,
|
||
Model: req.Model,
|
||
Endpoint: req.Endpoint,
|
||
QueryEndpoint: req.QueryEndpoint,
|
||
Priority: req.Priority,
|
||
IsDefault: req.IsDefault,
|
||
IsActive: true,
|
||
Settings: req.Settings,
|
||
}
|
||
|
||
if err := s.db.Create(config).Error; err != nil {
|
||
s.log.Errorw("Failed to create AI config", "error", err)
|
||
return nil, err
|
||
}
|
||
|
||
s.log.Infow("AI config created", "config_id", config.ID)
|
||
return config, nil
|
||
}
|
||
|
||
func (s *AIService) GetConfig(configID uint) (*models.AIServiceConfig, error) {
|
||
var config models.AIServiceConfig
|
||
err := s.db.Where("id = ? ", configID).First(&config).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("config not found")
|
||
}
|
||
return nil, err
|
||
}
|
||
return &config, nil
|
||
}
|
||
|
||
func (s *AIService) ListConfigs(serviceType string) ([]models.AIServiceConfig, error) {
|
||
var configs []models.AIServiceConfig
|
||
query := s.db
|
||
|
||
if serviceType != "" {
|
||
query = query.Where("service_type = ?", serviceType)
|
||
}
|
||
|
||
err := query.Order("priority DESC, created_at DESC").Find(&configs).Error
|
||
if err != nil {
|
||
s.log.Errorw("Failed to list AI configs", "error", err)
|
||
return nil, err
|
||
}
|
||
|
||
return configs, nil
|
||
}
|
||
|
||
func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*models.AIServiceConfig, error) {
|
||
var config models.AIServiceConfig
|
||
if err := s.db.Where("id = ? ", configID).First(&config).Error; err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("config not found")
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
tx := s.db.Begin()
|
||
|
||
// 不再需要is_default独占逻辑
|
||
|
||
updates := make(map[string]interface{})
|
||
if req.Name != "" {
|
||
updates["name"] = req.Name
|
||
}
|
||
if req.BaseURL != "" {
|
||
updates["base_url"] = req.BaseURL
|
||
}
|
||
if req.APIKey != "" {
|
||
updates["api_key"] = req.APIKey
|
||
}
|
||
if req.Model != nil && len(*req.Model) > 0 {
|
||
updates["model"] = *req.Model
|
||
}
|
||
if req.Priority != nil {
|
||
updates["priority"] = *req.Priority
|
||
}
|
||
if req.Endpoint != "" {
|
||
updates["endpoint"] = req.Endpoint
|
||
}
|
||
// 允许清空query_endpoint,所以不检查是否为空
|
||
updates["query_endpoint"] = req.QueryEndpoint
|
||
if req.Settings != "" {
|
||
updates["settings"] = req.Settings
|
||
}
|
||
updates["is_default"] = req.IsDefault
|
||
updates["is_active"] = req.IsActive
|
||
|
||
if err := tx.Model(&config).Updates(updates).Error; err != nil {
|
||
tx.Rollback()
|
||
s.log.Errorw("Failed to update AI config", "error", err)
|
||
return nil, err
|
||
}
|
||
|
||
if err := tx.Commit().Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
s.log.Infow("AI config updated", "config_id", configID)
|
||
return &config, nil
|
||
}
|
||
|
||
func (s *AIService) DeleteConfig(configID uint) error {
|
||
result := s.db.Where("id = ? ", configID).Delete(&models.AIServiceConfig{})
|
||
|
||
if result.Error != nil {
|
||
s.log.Errorw("Failed to delete AI config", "error", result.Error)
|
||
return result.Error
|
||
}
|
||
|
||
if result.RowsAffected == 0 {
|
||
return errors.New("config not found")
|
||
}
|
||
|
||
s.log.Infow("AI config deleted", "config_id", configID)
|
||
return nil
|
||
}
|
||
|
||
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
|
||
// 使用第一个模型进行测试
|
||
model := ""
|
||
if len(req.Model) > 0 {
|
||
model = req.Model[0]
|
||
}
|
||
client := ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, req.Endpoint)
|
||
return client.TestConnection()
|
||
}
|
||
|
||
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
|
||
var config models.AIServiceConfig
|
||
// 按优先级降序获取第一个启用的配置
|
||
err := s.db.Where("service_type = ? AND is_active = ?", serviceType, true).
|
||
Order("priority DESC, created_at DESC").
|
||
First(&config).Error
|
||
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("no active config found")
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
return &config, nil
|
||
}
|
||
|
||
// GetConfigForModel 根据服务类型和模型名称获取优先级最高的启用配置
|
||
func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*models.AIServiceConfig, error) {
|
||
var configs []models.AIServiceConfig
|
||
err := s.db.Where("service_type = ? AND is_active = ?", serviceType, true).
|
||
Order("priority DESC, created_at DESC").
|
||
Find(&configs).Error
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 查找包含指定模型的配置
|
||
for _, config := range configs {
|
||
for _, model := range config.Model {
|
||
if model == modelName {
|
||
return &config, nil
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil, errors.New("no config found for model: " + modelName)
|
||
}
|
||
|
||
func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
|
||
config, err := s.GetDefaultConfig(serviceType)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 使用第一个模型
|
||
model := ""
|
||
if len(config.Model) > 0 {
|
||
model = config.Model[0]
|
||
}
|
||
|
||
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, config.Endpoint), nil
|
||
}
|
||
|
||
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {
|
||
client, err := s.GetAIClient("text")
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to get AI client: %w", err)
|
||
}
|
||
|
||
return client.GenerateText(prompt, systemPrompt, options...)
|
||
}
|