package services import ( "errors" "fmt" "github.com/drama-generator/backend/domain/models" "github.com/drama-generator/backend/pkg/ai" "github.com/drama-generator/backend/pkg/logger" "gorm.io/gorm" ) type AIService struct { db *gorm.DB log *logger.Logger } func NewAIService(db *gorm.DB, log *logger.Logger) *AIService { return &AIService{ db: db, log: log, } } type CreateAIConfigRequest struct { ServiceType string `json:"service_type" binding:"required,oneof=text image video"` Name string `json:"name" binding:"required,min=1,max=100"` BaseURL string `json:"base_url" binding:"required,url"` APIKey string `json:"api_key" binding:"required"` Model models.ModelField `json:"model" binding:"required"` Endpoint string `json:"endpoint"` QueryEndpoint string `json:"query_endpoint"` Priority int `json:"priority"` IsDefault bool `json:"is_default"` Settings string `json:"settings"` } type UpdateAIConfigRequest struct { Name string `json:"name" binding:"omitempty,min=1,max=100"` BaseURL string `json:"base_url" binding:"omitempty,url"` APIKey string `json:"api_key"` Model *models.ModelField `json:"model"` Endpoint string `json:"endpoint"` QueryEndpoint string `json:"query_endpoint"` Priority *int `json:"priority"` IsDefault bool `json:"is_default"` IsActive bool `json:"is_active"` Settings string `json:"settings"` } type TestConnectionRequest struct { BaseURL string `json:"base_url" binding:"required,url"` APIKey string `json:"api_key" binding:"required"` Model models.ModelField `json:"model" binding:"required"` Endpoint string `json:"endpoint"` } func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) { config := &models.AIServiceConfig{ ServiceType: req.ServiceType, Name: req.Name, BaseURL: req.BaseURL, APIKey: req.APIKey, Model: req.Model, Endpoint: req.Endpoint, QueryEndpoint: req.QueryEndpoint, Priority: req.Priority, IsDefault: req.IsDefault, IsActive: true, Settings: req.Settings, } if err := s.db.Create(config).Error; err != nil { s.log.Errorw("Failed to create AI config", "error", err) return nil, err } s.log.Infow("AI config created", "config_id", config.ID) return config, nil } func (s *AIService) GetConfig(configID uint) (*models.AIServiceConfig, error) { var config models.AIServiceConfig err := s.db.Where("id = ? ", configID).First(&config).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("config not found") } return nil, err } return &config, nil } func (s *AIService) ListConfigs(serviceType string) ([]models.AIServiceConfig, error) { var configs []models.AIServiceConfig query := s.db if serviceType != "" { query = query.Where("service_type = ?", serviceType) } err := query.Order("priority DESC, created_at DESC").Find(&configs).Error if err != nil { s.log.Errorw("Failed to list AI configs", "error", err) return nil, err } return configs, nil } func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*models.AIServiceConfig, error) { var config models.AIServiceConfig if err := s.db.Where("id = ? ", configID).First(&config).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("config not found") } return nil, err } tx := s.db.Begin() // 不再需要is_default独占逻辑 updates := make(map[string]interface{}) if req.Name != "" { updates["name"] = req.Name } if req.BaseURL != "" { updates["base_url"] = req.BaseURL } if req.APIKey != "" { updates["api_key"] = req.APIKey } if req.Model != nil && len(*req.Model) > 0 { updates["model"] = *req.Model } if req.Priority != nil { updates["priority"] = *req.Priority } if req.Endpoint != "" { updates["endpoint"] = req.Endpoint } // 允许清空query_endpoint,所以不检查是否为空 updates["query_endpoint"] = req.QueryEndpoint if req.Settings != "" { updates["settings"] = req.Settings } updates["is_default"] = req.IsDefault updates["is_active"] = req.IsActive if err := tx.Model(&config).Updates(updates).Error; err != nil { tx.Rollback() s.log.Errorw("Failed to update AI config", "error", err) return nil, err } if err := tx.Commit().Error; err != nil { return nil, err } s.log.Infow("AI config updated", "config_id", configID) return &config, nil } func (s *AIService) DeleteConfig(configID uint) error { result := s.db.Where("id = ? ", configID).Delete(&models.AIServiceConfig{}) if result.Error != nil { s.log.Errorw("Failed to delete AI config", "error", result.Error) return result.Error } if result.RowsAffected == 0 { return errors.New("config not found") } s.log.Infow("AI config deleted", "config_id", configID) return nil } func (s *AIService) TestConnection(req *TestConnectionRequest) error { // 使用第一个模型进行测试 model := "" if len(req.Model) > 0 { model = req.Model[0] } client := ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, req.Endpoint) return client.TestConnection() } func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) { var config models.AIServiceConfig // 按优先级降序获取第一个启用的配置 err := s.db.Where("service_type = ? AND is_active = ?", serviceType, true). Order("priority DESC, created_at DESC"). First(&config).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("no active config found") } return nil, err } return &config, nil } // GetConfigForModel 根据服务类型和模型名称获取优先级最高的启用配置 func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*models.AIServiceConfig, error) { var configs []models.AIServiceConfig err := s.db.Where("service_type = ? AND is_active = ?", serviceType, true). Order("priority DESC, created_at DESC"). Find(&configs).Error if err != nil { return nil, err } // 查找包含指定模型的配置 for _, config := range configs { for _, model := range config.Model { if model == modelName { return &config, nil } } } return nil, errors.New("no config found for model: " + modelName) } func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) { config, err := s.GetDefaultConfig(serviceType) if err != nil { return nil, err } // 使用第一个模型 model := "" if len(config.Model) > 0 { model = config.Model[0] } return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, config.Endpoint), nil } func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) { client, err := s.GetAIClient("text") if err != nil { return "", fmt.Errorf("failed to get AI client: %w", err) } return client.GenerateText(prompt, systemPrompt, options...) }