package services import ( "errors" "fmt" "github.com/drama-generator/backend/domain/models" "github.com/drama-generator/backend/pkg/ai" "github.com/drama-generator/backend/pkg/logger" "gorm.io/gorm" ) type AIService struct { db *gorm.DB log *logger.Logger } func NewAIService(db *gorm.DB, log *logger.Logger) *AIService { return &AIService{ db: db, log: log, } } type CreateAIConfigRequest struct { ServiceType string `json:"service_type" binding:"required,oneof=text image video"` Name string `json:"name" binding:"required,min=1,max=100"` Provider string `json:"provider" binding:"required"` BaseURL string `json:"base_url" binding:"required,url"` APIKey string `json:"api_key" binding:"required"` Model models.ModelField `json:"model" binding:"required"` Endpoint string `json:"endpoint"` QueryEndpoint string `json:"query_endpoint"` Priority int `json:"priority"` IsDefault bool `json:"is_default"` Settings string `json:"settings"` } type UpdateAIConfigRequest struct { Name string `json:"name" binding:"omitempty,min=1,max=100"` Provider string `json:"provider"` BaseURL string `json:"base_url" binding:"omitempty,url"` APIKey string `json:"api_key"` Model *models.ModelField `json:"model"` Endpoint string `json:"endpoint"` QueryEndpoint string `json:"query_endpoint"` Priority *int `json:"priority"` IsDefault bool `json:"is_default"` IsActive bool `json:"is_active"` Settings string `json:"settings"` } type TestConnectionRequest struct { BaseURL string `json:"base_url" binding:"required,url"` APIKey string `json:"api_key" binding:"required"` Model models.ModelField `json:"model" binding:"required"` Provider string `json:"provider"` Endpoint string `json:"endpoint"` } func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) { // 根据 provider 和 service_type 自动设置 endpoint endpoint := req.Endpoint queryEndpoint := req.QueryEndpoint if endpoint == "" { switch req.Provider { case "gemini", "google": if req.ServiceType == "text" { endpoint = "/v1beta/models/{model}:generateContent" } else if req.ServiceType == "image" { endpoint = "/v1beta/models/{model}:generateContent" } case "openai": if req.ServiceType == "text" { endpoint = "/chat/completions" } else if req.ServiceType == "image" { endpoint = "/images/generations" } else if req.ServiceType == "video" { endpoint = "/videos" if queryEndpoint == "" { queryEndpoint = "/videos/{taskId}" } } case "chatfire": if req.ServiceType == "text" { endpoint = "/chat/completions" } else if req.ServiceType == "image" { endpoint = "/images/generations" } else if req.ServiceType == "video" { endpoint = "/video/generations" if queryEndpoint == "" { queryEndpoint = "/video/task/{taskId}" } } case "doubao", "volcengine", "volces": if req.ServiceType == "video" { endpoint = "/contents/generations/tasks" if queryEndpoint == "" { queryEndpoint = "/generations/tasks/{taskId}" } } default: // 默认使用 OpenAI 格式 if req.ServiceType == "text" { endpoint = "/chat/completions" } else if req.ServiceType == "image" { endpoint = "/images/generations" } } } config := &models.AIServiceConfig{ ServiceType: req.ServiceType, Name: req.Name, Provider: req.Provider, BaseURL: req.BaseURL, APIKey: req.APIKey, Model: req.Model, Endpoint: endpoint, QueryEndpoint: queryEndpoint, Priority: req.Priority, IsDefault: req.IsDefault, IsActive: true, Settings: req.Settings, } if err := s.db.Create(config).Error; err != nil { s.log.Errorw("Failed to create AI config", "error", err) return nil, err } s.log.Infow("AI config created", "config_id", config.ID, "provider", req.Provider, "endpoint", endpoint) return config, nil } func (s *AIService) GetConfig(configID uint) (*models.AIServiceConfig, error) { var config models.AIServiceConfig err := s.db.Where("id = ? ", configID).First(&config).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("config not found") } return nil, err } return &config, nil } func (s *AIService) ListConfigs(serviceType string) ([]models.AIServiceConfig, error) { var configs []models.AIServiceConfig query := s.db if serviceType != "" { query = query.Where("service_type = ?", serviceType) } err := query.Order("priority DESC, created_at DESC").Find(&configs).Error if err != nil { s.log.Errorw("Failed to list AI configs", "error", err) return nil, err } return configs, nil } func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*models.AIServiceConfig, error) { var config models.AIServiceConfig if err := s.db.Where("id = ? ", configID).First(&config).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("config not found") } return nil, err } tx := s.db.Begin() // 不再需要is_default独占逻辑 updates := make(map[string]interface{}) if req.Name != "" { updates["name"] = req.Name } if req.Provider != "" { updates["provider"] = req.Provider } if req.BaseURL != "" { updates["base_url"] = req.BaseURL } if req.APIKey != "" { updates["api_key"] = req.APIKey } if req.Model != nil && len(*req.Model) > 0 { updates["model"] = *req.Model } if req.Priority != nil { updates["priority"] = *req.Priority } // 如果提供了 provider,根据 provider 和 service_type 自动设置 endpoint if req.Provider != "" && req.Endpoint == "" { provider := req.Provider serviceType := config.ServiceType switch provider { case "gemini", "google": if serviceType == "text" || serviceType == "image" { updates["endpoint"] = "/v1beta/models/{model}:generateContent" } case "openai": if serviceType == "text" { updates["endpoint"] = "/chat/completions" } else if serviceType == "image" { updates["endpoint"] = "/images/generations" } else if serviceType == "video" { updates["endpoint"] = "/videos" updates["query_endpoint"] = "/videos/{taskId}" } case "chatfire": if serviceType == "text" { updates["endpoint"] = "/chat/completions" } else if serviceType == "image" { updates["endpoint"] = "/images/generations" } else if serviceType == "video" { updates["endpoint"] = "/video/generations" updates["query_endpoint"] = "/video/task/{taskId}" } } } else if req.Endpoint != "" { updates["endpoint"] = req.Endpoint } // 允许清空query_endpoint,所以不检查是否为空 updates["query_endpoint"] = req.QueryEndpoint if req.Settings != "" { updates["settings"] = req.Settings } updates["is_default"] = req.IsDefault updates["is_active"] = req.IsActive if err := tx.Model(&config).Updates(updates).Error; err != nil { tx.Rollback() s.log.Errorw("Failed to update AI config", "error", err) return nil, err } if err := tx.Commit().Error; err != nil { return nil, err } s.log.Infow("AI config updated", "config_id", configID) return &config, nil } func (s *AIService) DeleteConfig(configID uint) error { result := s.db.Where("id = ? ", configID).Delete(&models.AIServiceConfig{}) if result.Error != nil { s.log.Errorw("Failed to delete AI config", "error", result.Error) return result.Error } if result.RowsAffected == 0 { return errors.New("config not found") } s.log.Infow("AI config deleted", "config_id", configID) return nil } func (s *AIService) TestConnection(req *TestConnectionRequest) error { s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model)) // 使用第一个模型进行测试 model := "" if len(req.Model) > 0 { model = req.Model[0] } s.log.Infow("Using model for test", "model", model, "provider", req.Provider) // 根据 provider 参数选择客户端 var client ai.AIClient var endpoint string switch req.Provider { case "gemini", "google": // Gemini s.log.Infow("Using Gemini client", "baseURL", req.BaseURL) endpoint = "/v1beta/models/{model}:generateContent" client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint) case "openai", "chatfire": // OpenAI 格式(包括 chatfire 等) s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider) endpoint = req.Endpoint if endpoint == "" { endpoint = "/chat/completions" } client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint) default: // 默认使用 OpenAI 格式 s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL) endpoint = req.Endpoint if endpoint == "" { endpoint = "/chat/completions" } client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint) } s.log.Infow("Calling TestConnection on client", "endpoint", endpoint) err := client.TestConnection() if err != nil { s.log.Errorw("TestConnection failed", "error", err) } else { s.log.Infow("TestConnection succeeded") } return err } func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) { var config models.AIServiceConfig // 按优先级降序获取第一个配置 err := s.db.Where("service_type = ?", serviceType). Order("priority DESC, created_at DESC"). First(&config).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("no config found") } return nil, err } return &config, nil } // GetConfigForModel 根据服务类型和模型名称获取优先级最高的配置 func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*models.AIServiceConfig, error) { var configs []models.AIServiceConfig err := s.db.Where("service_type = ?", serviceType). Order("priority DESC, created_at DESC"). Find(&configs).Error if err != nil { return nil, err } // 查找包含指定模型的配置 for _, config := range configs { for _, model := range config.Model { if model == modelName { return &config, nil } } } return nil, errors.New("no config found for model: " + modelName) } func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) { config, err := s.GetDefaultConfig(serviceType) if err != nil { return nil, err } // 使用第一个模型 model := "" if len(config.Model) > 0 { model = config.Model[0] } // 使用数据库配置中的 endpoint,如果为空则根据 provider 设置默认值 endpoint := config.Endpoint if endpoint == "" { switch config.Provider { case "gemini", "google": endpoint = "/v1beta/models/{model}:generateContent" default: endpoint = "/chat/completions" } } // 根据 provider 创建对应的客户端 switch config.Provider { case "gemini", "google": return ai.NewGeminiClient(config.BaseURL, config.APIKey, model, endpoint), nil default: // openai, chatfire 等其他厂商都使用 OpenAI 格式 return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, endpoint), nil } } func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) { client, err := s.GetAIClient("text") if err != nil { return "", fmt.Errorf("failed to get AI client: %w", err) } return client.GenerateText(prompt, systemPrompt, options...) }