添加chat gemini、chatfire端点、 图片生成 gemini、chatfire 更轻松的AI配置

This commit is contained in:
Connor
2026-01-14 02:25:41 +08:00
parent 4d38357ff6
commit 23b45efae9
22 changed files with 1512 additions and 405 deletions

View File

@@ -25,6 +25,7 @@ func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
type CreateAIConfigRequest struct {
ServiceType string `json:"service_type" binding:"required,oneof=text image video"`
Name string `json:"name" binding:"required,min=1,max=100"`
Provider string `json:"provider" binding:"required"`
BaseURL string `json:"base_url" binding:"required,url"`
APIKey string `json:"api_key" binding:"required"`
Model models.ModelField `json:"model" binding:"required"`
@@ -37,6 +38,7 @@ type CreateAIConfigRequest struct {
type UpdateAIConfigRequest struct {
Name string `json:"name" binding:"omitempty,min=1,max=100"`
Provider string `json:"provider"`
BaseURL string `json:"base_url" binding:"omitempty,url"`
APIKey string `json:"api_key"`
Model *models.ModelField `json:"model"`
@@ -52,18 +54,53 @@ type TestConnectionRequest struct {
BaseURL string `json:"base_url" binding:"required,url"`
APIKey string `json:"api_key" binding:"required"`
Model models.ModelField `json:"model" binding:"required"`
Provider string `json:"provider"`
Endpoint string `json:"endpoint"`
}
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
// 根据 provider 和 service_type 自动设置 endpoint
endpoint := req.Endpoint
queryEndpoint := req.QueryEndpoint
if endpoint == "" {
switch req.Provider {
case "gemini", "google":
if req.ServiceType == "text" {
endpoint = "/v1beta/models/{model}:generateContent"
} else if req.ServiceType == "image" {
endpoint = "/v1beta/models/{model}:generateContent"
}
case "openai", "chatfire":
if req.ServiceType == "text" {
endpoint = "/chat/completions"
} else if req.ServiceType == "image" {
endpoint = "/images/generations"
} else if req.ServiceType == "video" {
endpoint = "/video/generations"
if queryEndpoint == "" {
queryEndpoint = "/v1/video/task/{taskId}"
}
}
default:
// 默认使用 OpenAI 格式
if req.ServiceType == "text" {
endpoint = "/chat/completions"
} else if req.ServiceType == "image" {
endpoint = "/images/generations"
}
}
}
config := &models.AIServiceConfig{
ServiceType: req.ServiceType,
Name: req.Name,
Provider: req.Provider,
BaseURL: req.BaseURL,
APIKey: req.APIKey,
Model: req.Model,
Endpoint: req.Endpoint,
QueryEndpoint: req.QueryEndpoint,
Endpoint: endpoint,
QueryEndpoint: queryEndpoint,
Priority: req.Priority,
IsDefault: req.IsDefault,
IsActive: true,
@@ -75,7 +112,7 @@ func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceC
return nil, err
}
s.log.Infow("AI config created", "config_id", config.ID)
s.log.Infow("AI config created", "config_id", config.ID, "provider", req.Provider, "endpoint", endpoint)
return config, nil
}
@@ -125,6 +162,9 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
if req.Name != "" {
updates["name"] = req.Name
}
if req.Provider != "" {
updates["provider"] = req.Provider
}
if req.BaseURL != "" {
updates["base_url"] = req.BaseURL
}
@@ -137,9 +177,30 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
if req.Priority != nil {
updates["priority"] = *req.Priority
}
if req.Endpoint != "" {
// 如果提供了 provider根据 provider 和 service_type 自动设置 endpoint
if req.Provider != "" && req.Endpoint == "" {
provider := req.Provider
serviceType := config.ServiceType
switch provider {
case "gemini", "google":
if serviceType == "text" || serviceType == "image" {
updates["endpoint"] = "/v1beta/models/{model}:generateContent"
}
case "openai", "chatfire":
if serviceType == "text" {
updates["endpoint"] = "/chat/completions"
} else if serviceType == "image" {
updates["endpoint"] = "/images/generations"
} else if serviceType == "video" {
updates["endpoint"] = "/video/generations"
}
}
} else if req.Endpoint != "" {
updates["endpoint"] = req.Endpoint
}
// 允许清空query_endpoint所以不检查是否为空
updates["query_endpoint"] = req.QueryEndpoint
if req.Settings != "" {
@@ -179,13 +240,51 @@ func (s *AIService) DeleteConfig(configID uint) error {
}
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model))
// 使用第一个模型进行测试
model := ""
if len(req.Model) > 0 {
model = req.Model[0]
}
client := ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, req.Endpoint)
return client.TestConnection()
s.log.Infow("Using model for test", "model", model, "provider", req.Provider)
// 根据 provider 参数选择客户端
var client ai.AIClient
var endpoint string
switch req.Provider {
case "gemini", "google":
// Gemini
s.log.Infow("Using Gemini client", "baseURL", req.BaseURL)
endpoint = "/v1beta/models/{model}:generateContent"
client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint)
case "openai", "chatfire":
// OpenAI 格式(包括 chatfire 等)
s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider)
endpoint = req.Endpoint
if endpoint == "" {
endpoint = "/chat/completions"
}
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
default:
// 默认使用 OpenAI 格式
s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL)
endpoint = req.Endpoint
if endpoint == "" {
endpoint = "/chat/completions"
}
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
}
s.log.Infow("Calling TestConnection on client", "endpoint", endpoint)
err := client.TestConnection()
if err != nil {
s.log.Errorw("TestConnection failed", "error", err)
} else {
s.log.Infow("TestConnection succeeded")
}
return err
}
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
@@ -228,7 +327,7 @@ func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*mo
return nil, errors.New("no config found for model: " + modelName)
}
func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) {
config, err := s.GetDefaultConfig(serviceType)
if err != nil {
return nil, err
@@ -240,7 +339,25 @@ func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
model = config.Model[0]
}
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, config.Endpoint), nil
// 使用数据库配置中的 endpoint如果为空则根据 provider 设置默认值
endpoint := config.Endpoint
if endpoint == "" {
switch config.Provider {
case "gemini", "google":
endpoint = "/v1beta/models/{model}:generateContent"
default:
endpoint = "/chat/completions"
}
}
// 根据 provider 创建对应的客户端
switch config.Provider {
case "gemini", "google":
return ai.NewGeminiClient(config.BaseURL, config.APIKey, model, endpoint), nil
default:
// openai, chatfire 等其他厂商都使用 OpenAI 格式
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, endpoint), nil
}
}
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {