添加chat gemini、chatfire端点、 图片生成 gemini、chatfire 更轻松的AI配置
This commit is contained in:
@@ -25,6 +25,7 @@ func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
|
||||
type CreateAIConfigRequest struct {
|
||||
ServiceType string `json:"service_type" binding:"required,oneof=text image video"`
|
||||
Name string `json:"name" binding:"required,min=1,max=100"`
|
||||
Provider string `json:"provider" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required,url"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
Model models.ModelField `json:"model" binding:"required"`
|
||||
@@ -37,6 +38,7 @@ type CreateAIConfigRequest struct {
|
||||
|
||||
type UpdateAIConfigRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,min=1,max=100"`
|
||||
Provider string `json:"provider"`
|
||||
BaseURL string `json:"base_url" binding:"omitempty,url"`
|
||||
APIKey string `json:"api_key"`
|
||||
Model *models.ModelField `json:"model"`
|
||||
@@ -52,18 +54,53 @@ type TestConnectionRequest struct {
|
||||
BaseURL string `json:"base_url" binding:"required,url"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
Model models.ModelField `json:"model" binding:"required"`
|
||||
Provider string `json:"provider"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
|
||||
// 根据 provider 和 service_type 自动设置 endpoint
|
||||
endpoint := req.Endpoint
|
||||
queryEndpoint := req.QueryEndpoint
|
||||
|
||||
if endpoint == "" {
|
||||
switch req.Provider {
|
||||
case "gemini", "google":
|
||||
if req.ServiceType == "text" {
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
} else if req.ServiceType == "image" {
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
}
|
||||
case "openai", "chatfire":
|
||||
if req.ServiceType == "text" {
|
||||
endpoint = "/chat/completions"
|
||||
} else if req.ServiceType == "image" {
|
||||
endpoint = "/images/generations"
|
||||
} else if req.ServiceType == "video" {
|
||||
endpoint = "/video/generations"
|
||||
if queryEndpoint == "" {
|
||||
queryEndpoint = "/v1/video/task/{taskId}"
|
||||
}
|
||||
}
|
||||
default:
|
||||
// 默认使用 OpenAI 格式
|
||||
if req.ServiceType == "text" {
|
||||
endpoint = "/chat/completions"
|
||||
} else if req.ServiceType == "image" {
|
||||
endpoint = "/images/generations"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config := &models.AIServiceConfig{
|
||||
ServiceType: req.ServiceType,
|
||||
Name: req.Name,
|
||||
Provider: req.Provider,
|
||||
BaseURL: req.BaseURL,
|
||||
APIKey: req.APIKey,
|
||||
Model: req.Model,
|
||||
Endpoint: req.Endpoint,
|
||||
QueryEndpoint: req.QueryEndpoint,
|
||||
Endpoint: endpoint,
|
||||
QueryEndpoint: queryEndpoint,
|
||||
Priority: req.Priority,
|
||||
IsDefault: req.IsDefault,
|
||||
IsActive: true,
|
||||
@@ -75,7 +112,7 @@ func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceC
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.log.Infow("AI config created", "config_id", config.ID)
|
||||
s.log.Infow("AI config created", "config_id", config.ID, "provider", req.Provider, "endpoint", endpoint)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -125,6 +162,9 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
|
||||
if req.Name != "" {
|
||||
updates["name"] = req.Name
|
||||
}
|
||||
if req.Provider != "" {
|
||||
updates["provider"] = req.Provider
|
||||
}
|
||||
if req.BaseURL != "" {
|
||||
updates["base_url"] = req.BaseURL
|
||||
}
|
||||
@@ -137,9 +177,30 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
|
||||
if req.Priority != nil {
|
||||
updates["priority"] = *req.Priority
|
||||
}
|
||||
if req.Endpoint != "" {
|
||||
|
||||
// 如果提供了 provider,根据 provider 和 service_type 自动设置 endpoint
|
||||
if req.Provider != "" && req.Endpoint == "" {
|
||||
provider := req.Provider
|
||||
serviceType := config.ServiceType
|
||||
|
||||
switch provider {
|
||||
case "gemini", "google":
|
||||
if serviceType == "text" || serviceType == "image" {
|
||||
updates["endpoint"] = "/v1beta/models/{model}:generateContent"
|
||||
}
|
||||
case "openai", "chatfire":
|
||||
if serviceType == "text" {
|
||||
updates["endpoint"] = "/chat/completions"
|
||||
} else if serviceType == "image" {
|
||||
updates["endpoint"] = "/images/generations"
|
||||
} else if serviceType == "video" {
|
||||
updates["endpoint"] = "/video/generations"
|
||||
}
|
||||
}
|
||||
} else if req.Endpoint != "" {
|
||||
updates["endpoint"] = req.Endpoint
|
||||
}
|
||||
|
||||
// 允许清空query_endpoint,所以不检查是否为空
|
||||
updates["query_endpoint"] = req.QueryEndpoint
|
||||
if req.Settings != "" {
|
||||
@@ -179,13 +240,51 @@ func (s *AIService) DeleteConfig(configID uint) error {
|
||||
}
|
||||
|
||||
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
|
||||
s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model))
|
||||
|
||||
// 使用第一个模型进行测试
|
||||
model := ""
|
||||
if len(req.Model) > 0 {
|
||||
model = req.Model[0]
|
||||
}
|
||||
client := ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, req.Endpoint)
|
||||
return client.TestConnection()
|
||||
s.log.Infow("Using model for test", "model", model, "provider", req.Provider)
|
||||
|
||||
// 根据 provider 参数选择客户端
|
||||
var client ai.AIClient
|
||||
var endpoint string
|
||||
|
||||
switch req.Provider {
|
||||
case "gemini", "google":
|
||||
// Gemini
|
||||
s.log.Infow("Using Gemini client", "baseURL", req.BaseURL)
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||
case "openai", "chatfire":
|
||||
// OpenAI 格式(包括 chatfire 等)
|
||||
s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider)
|
||||
endpoint = req.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = "/chat/completions"
|
||||
}
|
||||
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||
default:
|
||||
// 默认使用 OpenAI 格式
|
||||
s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL)
|
||||
endpoint = req.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = "/chat/completions"
|
||||
}
|
||||
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||
}
|
||||
|
||||
s.log.Infow("Calling TestConnection on client", "endpoint", endpoint)
|
||||
err := client.TestConnection()
|
||||
if err != nil {
|
||||
s.log.Errorw("TestConnection failed", "error", err)
|
||||
} else {
|
||||
s.log.Infow("TestConnection succeeded")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
|
||||
@@ -228,7 +327,7 @@ func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*mo
|
||||
return nil, errors.New("no config found for model: " + modelName)
|
||||
}
|
||||
|
||||
func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
|
||||
func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) {
|
||||
config, err := s.GetDefaultConfig(serviceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -240,7 +339,25 @@ func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, config.Endpoint), nil
|
||||
// 使用数据库配置中的 endpoint,如果为空则根据 provider 设置默认值
|
||||
endpoint := config.Endpoint
|
||||
if endpoint == "" {
|
||||
switch config.Provider {
|
||||
case "gemini", "google":
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
default:
|
||||
endpoint = "/chat/completions"
|
||||
}
|
||||
}
|
||||
|
||||
// 根据 provider 创建对应的客户端
|
||||
switch config.Provider {
|
||||
case "gemini", "google":
|
||||
return ai.NewGeminiClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
default:
|
||||
// openai, chatfire 等其他厂商都使用 OpenAI 格式
|
||||
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {
|
||||
|
||||
Reference in New Issue
Block a user