This commit is contained in:
Connor
2026-01-12 13:17:11 +08:00
parent 95851f8e69
commit 9600fc542c
132 changed files with 35734 additions and 5 deletions

188
pkg/ai/openai_client.go Normal file
View File

@@ -0,0 +1,188 @@
package ai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type OpenAIClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
HTTPClient *http.Client
}
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
type ErrorResponse struct {
Error struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code"`
} `json:"error"`
}
func NewOpenAIClient(baseURL, apiKey, model, endpoint string) *OpenAIClient {
if endpoint == "" {
endpoint = "/v1/chat/completions"
}
return &OpenAIClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*ChatCompletionRequest)) (*ChatCompletionResponse, error) {
req := &ChatCompletionRequest{
Model: c.Model,
Messages: messages,
}
for _, option := range options {
option(req)
}
return c.sendChatRequest(req)
}
func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
jsonData, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
url := c.BaseURL + c.Endpoint
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
var errResp ErrorResponse
if err := json.Unmarshal(body, &errResp); err != nil {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
return nil, fmt.Errorf("API error: %s", errResp.Error.Message)
}
var chatResp ChatCompletionResponse
if err := json.Unmarshal(body, &chatResp); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return &chatResp, nil
}
func WithTemperature(temp float64) func(*ChatCompletionRequest) {
return func(req *ChatCompletionRequest) {
req.Temperature = temp
}
}
func WithMaxTokens(tokens int) func(*ChatCompletionRequest) {
return func(req *ChatCompletionRequest) {
req.MaxTokens = tokens
}
}
func WithTopP(topP float64) func(*ChatCompletionRequest) {
return func(req *ChatCompletionRequest) {
req.TopP = topP
}
}
func (c *OpenAIClient) GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error) {
messages := []ChatMessage{}
if systemPrompt != "" {
messages = append(messages, ChatMessage{
Role: "system",
Content: systemPrompt,
})
}
messages = append(messages, ChatMessage{
Role: "user",
Content: prompt,
})
resp, err := c.ChatCompletion(messages, options...)
if err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no response from API")
}
return resp.Choices[0].Message.Content, nil
}
func (c *OpenAIClient) TestConnection() error {
messages := []ChatMessage{
{
Role: "user",
Content: "Hello",
},
}
_, err := c.ChatCompletion(messages, WithMaxTokens(10))
return err
}

89
pkg/config/config.go Normal file
View File

@@ -0,0 +1,89 @@
package config
import (
"fmt"
"github.com/spf13/viper"
)
type Config struct {
App AppConfig `mapstructure:"app"`
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Storage StorageConfig `mapstructure:"storage"`
AI AIConfig `mapstructure:"ai"`
}
type AppConfig struct {
Name string `mapstructure:"name"`
Version string `mapstructure:"version"`
Debug bool `mapstructure:"debug"`
}
type ServerConfig struct {
Port int `mapstructure:"port"`
Host string `mapstructure:"host"`
CORSOrigins []string `mapstructure:"cors_origins"`
ReadTimeout int `mapstructure:"read_timeout"`
WriteTimeout int `mapstructure:"write_timeout"`
}
type DatabaseConfig struct {
Type string `mapstructure:"type"` // sqlite, mysql
Path string `mapstructure:"path"` // SQLite数据库文件路径
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Database string `mapstructure:"database"`
Charset string `mapstructure:"charset"`
MaxIdle int `mapstructure:"max_idle"`
MaxOpen int `mapstructure:"max_open"`
}
type StorageConfig struct {
Type string `mapstructure:"type"` // local, minio
LocalPath string `mapstructure:"local_path"` // 本地存储路径
BaseURL string `mapstructure:"base_url"` // 访问URL前缀
}
type AIConfig struct {
DefaultTextProvider string `mapstructure:"default_text_provider"`
DefaultImageProvider string `mapstructure:"default_image_provider"`
DefaultVideoProvider string `mapstructure:"default_video_provider"`
}
func LoadConfig() (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath("./configs")
viper.AddConfigPath(".")
viper.AutomaticEnv()
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config: %w", err)
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
return &config, nil
}
func (c *DatabaseConfig) DSN() string {
if c.Type == "sqlite" {
return c.Path
}
// MySQL DSN
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
c.User,
c.Password,
c.Host,
c.Port,
c.Database,
c.Charset,
)
}

384
pkg/image/image_client.go Normal file
View File

@@ -0,0 +1,384 @@
package image
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type ImageClient interface {
GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error)
GetTaskStatus(taskID string) (*ImageResult, error)
}
type ImageResult struct {
TaskID string
Status string
ImageURL string
Width int
Height int
Error string
Completed bool
}
type ImageOptions struct {
NegativePrompt string
Size string
Quality string
Style string
Steps int
CfgScale float64
Seed int64
Model string
Width int
Height int
ReferenceImages []string // 参考图片URL列表
}
type ImageOption func(*ImageOptions)
func WithNegativePrompt(prompt string) ImageOption {
return func(o *ImageOptions) {
o.NegativePrompt = prompt
}
}
func WithSize(size string) ImageOption {
return func(o *ImageOptions) {
o.Size = size
}
}
func WithQuality(quality string) ImageOption {
return func(o *ImageOptions) {
o.Quality = quality
}
}
func WithStyle(style string) ImageOption {
return func(o *ImageOptions) {
o.Style = style
}
}
func WithSteps(steps int) ImageOption {
return func(o *ImageOptions) {
o.Steps = steps
}
}
func WithCfgScale(scale float64) ImageOption {
return func(o *ImageOptions) {
o.CfgScale = scale
}
}
func WithSeed(seed int64) ImageOption {
return func(o *ImageOptions) {
o.Seed = seed
}
}
func WithModel(model string) ImageOption {
return func(o *ImageOptions) {
o.Model = model
}
}
func WithDimensions(width, height int) ImageOption {
return func(o *ImageOptions) {
o.Width = width
o.Height = height
}
}
func WithReferenceImages(images []string) ImageOption {
return func(o *ImageOptions) {
o.ReferenceImages = images
}
}
type OpenAIImageClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type DALLERequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
N int `json:"n"`
Image []string `json:"image,omitempty"` // 参考图片URL列表
}
type DALLEResponse struct {
Created int64 `json:"created"`
Data []struct {
URL string `json:"url"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
} `json:"data"`
}
func NewOpenAIImageClient(baseURL, apiKey, model string) *OpenAIImageClient {
return &OpenAIImageClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *OpenAIImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
options := &ImageOptions{
Size: "1920x1920",
Quality: "standard",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := DALLERequest{
Model: model,
Prompt: prompt,
Size: options.Size,
Quality: options.Quality,
N: 1,
Image: options.ReferenceImages,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + "/v1/images/generations"
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
// 打印原始响应以便调试
fmt.Printf("OpenAI API Response: %s\n", string(body))
var result DALLEResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
}
if len(result.Data) == 0 {
return nil, fmt.Errorf("no image generated, response: %s", string(body))
}
return &ImageResult{
Status: "completed",
ImageURL: result.Data[0].URL,
Completed: true,
}, nil
}
func (c *OpenAIImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
return nil, fmt.Errorf("not supported for OpenAI/DALL-E")
}
type StableDiffusionClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type SDRequest struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
Model string `json:"model,omitempty"`
Width int `json:"width,omitempty"`
Height int `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
CfgScale float64 `json:"cfg_scale,omitempty"`
Seed int64 `json:"seed,omitempty"`
Samples int `json:"samples"`
Image []string `json:"image,omitempty"` // 参考图片URL列表
}
type SDResponse struct {
Status string `json:"status"`
TaskID string `json:"task_id,omitempty"`
Output []struct {
URL string `json:"url"`
} `json:"output,omitempty"`
Error string `json:"error,omitempty"`
}
func NewStableDiffusionClient(baseURL, apiKey, model string) *StableDiffusionClient {
return &StableDiffusionClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *StableDiffusionClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
options := &ImageOptions{
Width: 1024,
Height: 1024,
Steps: 30,
CfgScale: 7.5,
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := SDRequest{
Prompt: prompt,
NegativePrompt: options.NegativePrompt,
Model: model,
Width: options.Width,
Height: options.Height,
Steps: options.Steps,
CfgScale: options.CfgScale,
Seed: options.Seed,
Samples: 1,
Image: options.ReferenceImages,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + "/v1/images/generations"
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result SDResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error != "" {
return nil, fmt.Errorf("SD error: %s", result.Error)
}
if result.Status == "processing" {
return &ImageResult{
TaskID: result.TaskID,
Status: "processing",
Completed: false,
}, nil
}
if len(result.Output) == 0 {
return nil, fmt.Errorf("no image generated")
}
return &ImageResult{
Status: "completed",
ImageURL: result.Output[0].URL,
Width: options.Width,
Height: options.Height,
Completed: true,
}, nil
}
func (c *StableDiffusionClient) GetTaskStatus(taskID string) (*ImageResult, error) {
endpoint := c.BaseURL + "/v1/images/status/" + taskID
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result SDResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
imageResult := &ImageResult{
TaskID: taskID,
Status: result.Status,
Completed: result.Status == "completed",
}
if result.Error != "" {
imageResult.Error = result.Error
}
if len(result.Output) > 0 {
imageResult.ImageURL = result.Output[0].URL
}
return imageResult, nil
}

35
pkg/logger/logger.go Normal file
View File

@@ -0,0 +1,35 @@
package logger
import (
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type Logger struct {
*zap.SugaredLogger
}
func NewLogger(debug bool) *Logger {
var config zap.Config
if debug {
config = zap.NewDevelopmentConfig()
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
// 在开发模式下,禁用时间戳和调用者信息,使输出更简洁
config.EncoderConfig.TimeKey = ""
config.EncoderConfig.CallerKey = ""
} else {
config = zap.NewProductionConfig()
config.EncoderConfig.TimeKey = "timestamp"
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
}
logger, err := config.Build()
if err != nil {
panic(err)
}
return &Logger{
SugaredLogger: logger.Sugar(),
}
}

119
pkg/response/response.go Normal file
View File

@@ -0,0 +1,119 @@
package response
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
)
type Response struct {
Success bool `json:"success"`
Data interface{} `json:"data,omitempty"`
Error *ErrorInfo `json:"error,omitempty"`
Message string `json:"message,omitempty"`
Timestamp string `json:"timestamp"`
}
type ErrorInfo struct {
Code string `json:"code"`
Message string `json:"message"`
Details interface{} `json:"details,omitempty"`
}
type PaginationData struct {
Items interface{} `json:"items"`
Pagination Pagination `json:"pagination"`
}
type Pagination struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Total int64 `json:"total"`
TotalPages int64 `json:"total_pages"`
}
func Success(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, Response{
Success: true,
Data: data,
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func SuccessWithMessage(c *gin.Context, message string, data interface{}) {
c.JSON(http.StatusOK, Response{
Success: true,
Data: data,
Message: message,
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func Created(c *gin.Context, data interface{}) {
c.JSON(http.StatusCreated, Response{
Success: true,
Data: data,
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func SuccessWithPagination(c *gin.Context, items interface{}, total int64, page int, pageSize int) {
totalPages := (total + int64(pageSize) - 1) / int64(pageSize)
c.JSON(http.StatusOK, Response{
Success: true,
Data: PaginationData{
Items: items,
Pagination: Pagination{
Page: page,
PageSize: pageSize,
Total: total,
TotalPages: totalPages,
},
},
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func Error(c *gin.Context, statusCode int, errCode string, message string) {
c.JSON(statusCode, Response{
Success: false,
Error: &ErrorInfo{
Code: errCode,
Message: message,
},
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func ErrorWithDetails(c *gin.Context, statusCode int, errCode string, message string, details interface{}) {
c.JSON(statusCode, Response{
Success: false,
Error: &ErrorInfo{
Code: errCode,
Message: message,
Details: details,
},
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, "BAD_REQUEST", message)
}
func Unauthorized(c *gin.Context, message string) {
Error(c, http.StatusUnauthorized, "UNAUTHORIZED", message)
}
func Forbidden(c *gin.Context, message string) {
Error(c, http.StatusForbidden, "FORBIDDEN", message)
}
func NotFound(c *gin.Context, message string) {
Error(c, http.StatusNotFound, "NOT_FOUND", message)
}
func InternalError(c *gin.Context, message string) {
Error(c, http.StatusInternalServerError, "INTERNAL_ERROR", message)
}

153
pkg/utils/json_parser.go Normal file
View File

@@ -0,0 +1,153 @@
package utils
import (
"encoding/json"
"fmt"
"regexp"
"strings"
)
// SafeParseAIJSON 安全地解析AI返回的JSON处理常见的格式问题
// 包括:
// 1. 移除Markdown代码块标记
// 2. 提取JSON对象
// 3. 清理多余的空白和换行
// 4. 尝试修复截断的JSON
// 5. 提供详细的错误信息
func SafeParseAIJSON(aiResponse string, v interface{}) error {
if aiResponse == "" {
return fmt.Errorf("AI返回内容为空")
}
// 1. 移除可能的Markdown代码块标记
cleaned := strings.TrimSpace(aiResponse)
cleaned = regexp.MustCompile("(?m)^```json\\s*").ReplaceAllString(cleaned, "")
cleaned = regexp.MustCompile("(?m)^```\\s*").ReplaceAllString(cleaned, "")
cleaned = strings.TrimSpace(cleaned)
// 2. 提取JSON对象 (查找第一个 { 到最后一个 })
jsonRegex := regexp.MustCompile(`(?s)\{.*\}`)
jsonMatch := jsonRegex.FindString(cleaned)
if jsonMatch == "" {
return fmt.Errorf("响应中未找到有效的JSON对象原始响应: %s", truncateString(aiResponse, 200))
}
// 3. 尝试解析JSON
err := json.Unmarshal([]byte(jsonMatch), v)
if err == nil {
return nil // 解析成功
}
// 4. 如果解析失败尝试修复截断的JSON
fixedJSON := attemptJSONRepair(jsonMatch)
if fixedJSON != jsonMatch {
if err := json.Unmarshal([]byte(fixedJSON), v); err == nil {
return nil // 修复后解析成功
}
}
// 5. 提供详细的错误上下文
if jsonErr, ok := err.(*json.SyntaxError); ok {
errorPos := int(jsonErr.Offset)
start := maxInt(0, errorPos-100)
end := minInt(len(jsonMatch), errorPos+100)
context := jsonMatch[start:end]
marker := strings.Repeat(" ", errorPos-start) + "^"
return fmt.Errorf(
"JSON解析失败: %s\n错误位置附近:\n%s\n%s",
jsonErr.Error(),
context,
marker,
)
}
return fmt.Errorf("JSON解析失败: %w\n原始响应: %s", err, truncateString(jsonMatch, 300))
}
// attemptJSONRepair 尝试修复常见的JSON问题
func attemptJSONRepair(jsonStr string) string {
// 1. 处理未闭合的字符串
// 如果最后一个字符不是 },尝试补全
trimmed := strings.TrimSpace(jsonStr)
// 2. 检查是否有未闭合的引号
if strings.Count(trimmed, `"`)%2 != 0 {
// 有奇数个引号,尝试补全最后一个引号
trimmed += `"`
}
// 3. 统计括号
openBraces := strings.Count(trimmed, "{")
closeBraces := strings.Count(trimmed, "}")
openBrackets := strings.Count(trimmed, "[")
closeBrackets := strings.Count(trimmed, "]")
// 4. 补全未闭合的数组
for i := 0; i < openBrackets-closeBrackets; i++ {
trimmed += "]"
}
// 5. 补全未闭合的对象
for i := 0; i < openBraces-closeBraces; i++ {
trimmed += "}"
}
return trimmed
}
// ExtractJSONFromText 从文本中提取JSON对象或数组
func ExtractJSONFromText(text string) string {
text = strings.TrimSpace(text)
// 移除Markdown代码块
text = regexp.MustCompile("(?m)^```json\\s*").ReplaceAllString(text, "")
text = regexp.MustCompile("(?m)^```\\s*").ReplaceAllString(text, "")
text = strings.TrimSpace(text)
// 查找JSON对象
if idx := strings.Index(text, "{"); idx != -1 {
if lastIdx := strings.LastIndex(text, "}"); lastIdx != -1 && lastIdx > idx {
return text[idx : lastIdx+1]
}
}
// 查找JSON数组
if idx := strings.Index(text, "["); idx != -1 {
if lastIdx := strings.LastIndex(text, "]"); lastIdx != -1 && lastIdx > idx {
return text[idx : lastIdx+1]
}
}
return text
}
// ValidateJSON 验证JSON字符串是否有效
func ValidateJSON(jsonStr string) error {
var js json.RawMessage
return json.Unmarshal([]byte(jsonStr), &js)
}
// Helper functions
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}
func maxInt(a, b int) int {
if a > b {
return a
}
return b
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}

28
pkg/utils/random.go Normal file
View File

@@ -0,0 +1,28 @@
package utils
import (
"math/rand"
"time"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateVerificationCode(length int) string {
digits := "0123456789"
code := make([]byte, length)
for i := range code {
code[i] = digits[rand.Intn(len(digits))]
}
return string(code)
}
func GenerateRandomString(length int) string {
chars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
result[i] = chars[rand.Intn(len(chars))]
}
return string(result)
}

192
pkg/video/minimax_client.go Normal file
View File

@@ -0,0 +1,192 @@
package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// MinimaxClient Minimax视频生成客户端
type MinimaxClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type MinimaxSubjectReference struct {
Type string `json:"type"`
Image []string `json:"image"`
}
type MinimaxRequest struct {
Prompt string `json:"prompt"`
FirstFrameImage string `json:"first_frame_image,omitempty"`
LastFrameImage string `json:"last_frame_image,omitempty"`
SubjectReference []MinimaxSubjectReference `json:"subject_reference,omitempty"`
Model string `json:"model"`
Duration int `json:"duration,omitempty"`
Resolution string `json:"resolution,omitempty"`
}
type MinimaxResponse struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
BaseResp struct {
StatusCode int `json:"status_code"`
StatusMsg string `json:"status_msg"`
} `json:"base_resp"`
Video struct {
URL string `json:"url"`
Duration int `json:"duration"`
} `json:"video"`
Error struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
func NewMinimaxClient(baseURL, apiKey, model string) *MinimaxClient {
return &MinimaxClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 300 * time.Second,
},
}
}
// GenerateVideo 生成视频(支持首尾帧和主体参考)
func (c *MinimaxClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 6,
Resolution: "1080P",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := MinimaxRequest{
Prompt: prompt,
Model: model,
Duration: options.Duration,
}
// 设置分辨率
if options.Resolution != "" {
reqBody.Resolution = options.Resolution
}
// 如果有首帧图片从imageURL或FirstFrameURL
if options.FirstFrameURL != "" {
reqBody.FirstFrameImage = options.FirstFrameURL
} else if imageURL != "" {
reqBody.FirstFrameImage = imageURL
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + "/v1/video_generation"
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result MinimaxResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error.Message != "" {
return nil, fmt.Errorf("minimax error: %s", result.Error.Message)
}
videoResult := &VideoResult{
TaskID: result.TaskID,
Status: result.Status,
Completed: result.Status == "completed",
Duration: result.Video.Duration,
}
if result.Video.URL != "" {
videoResult.VideoURL = result.Video.URL
videoResult.Completed = true
}
return videoResult, nil
}
func (c *MinimaxClient) GetTaskStatus(taskID string) (*VideoResult, error) {
endpoint := c.BaseURL + "/v1/video_generation/" + taskID
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result MinimaxResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
videoResult := &VideoResult{
TaskID: result.TaskID,
Status: result.Status,
Completed: result.Status == "completed",
Duration: result.Video.Duration,
}
if result.Error.Message != "" {
videoResult.Error = result.Error.Message
}
if result.Video.URL != "" {
videoResult.VideoURL = result.Video.URL
videoResult.Completed = true
}
return videoResult, nil
}

View File

@@ -0,0 +1,178 @@
package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"time"
)
type OpenAISoraClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type OpenAISoraResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Model string `json:"model"`
Status string `json:"status"`
Progress int `json:"progress"`
CreatedAt int64 `json:"created_at"`
CompletedAt int64 `json:"completed_at"`
Size string `json:"size"`
Seconds string `json:"seconds"`
Quality string `json:"quality"`
VideoURL string `json:"video_url"` // 直接的video_url字段
Video struct {
URL string `json:"url"`
} `json:"video"` // 嵌套的video.url字段兼容
Error struct {
Message string `json:"message"`
Type string `json:"type"`
} `json:"error"`
}
func NewOpenAISoraClient(baseURL, apiKey, model string) *OpenAISoraClient {
return &OpenAISoraClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 300 * time.Second,
},
}
}
func (c *OpenAISoraClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 4,
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
writer.WriteField("model", model)
writer.WriteField("prompt", prompt)
if imageURL != "" {
writer.WriteField("input_reference", imageURL)
}
if options.Duration > 0 {
writer.WriteField("seconds", fmt.Sprintf("%d", options.Duration))
}
if options.Resolution != "" {
writer.WriteField("size", options.Resolution)
}
writer.Close()
endpoint := c.BaseURL + "/v1/videos"
req, err := http.NewRequest("POST", endpoint, body)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(respBody))
}
var result OpenAISoraResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error.Message != "" {
return nil, fmt.Errorf("openai error: %s", result.Error.Message)
}
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "completed",
}
// 优先使用video_url字段兼容video.url嵌套结构
if result.VideoURL != "" {
videoResult.VideoURL = result.VideoURL
} else if result.Video.URL != "" {
videoResult.VideoURL = result.Video.URL
}
return videoResult, nil
}
func (c *OpenAISoraClient) GetTaskStatus(taskID string) (*VideoResult, error) {
endpoint := c.BaseURL + "/v1/videos/" + taskID
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result OpenAISoraResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "completed",
}
if result.Error.Message != "" {
videoResult.Error = result.Error.Message
}
// 优先使用video_url字段兼容video.url嵌套结构
if result.VideoURL != "" {
videoResult.VideoURL = result.VideoURL
} else if result.Video.URL != "" {
videoResult.VideoURL = result.Video.URL
}
return videoResult, nil
}

427
pkg/video/video_client.go Normal file
View File

@@ -0,0 +1,427 @@
package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type VideoClient interface {
GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error)
GetTaskStatus(taskID string) (*VideoResult, error)
}
type VideoResult struct {
TaskID string
Status string
VideoURL string
ThumbnailURL string
Duration int
Width int
Height int
Error string
Completed bool
}
type VideoOptions struct {
Model string
Duration int
FPS int
Resolution string
AspectRatio string
Style string
MotionLevel int
CameraMotion string
Seed int64
FirstFrameURL string
LastFrameURL string
ReferenceImageURLs []string
}
type VideoOption func(*VideoOptions)
func WithModel(model string) VideoOption {
return func(o *VideoOptions) {
o.Model = model
}
}
func WithDuration(duration int) VideoOption {
return func(o *VideoOptions) {
o.Duration = duration
}
}
func WithFPS(fps int) VideoOption {
return func(o *VideoOptions) {
o.FPS = fps
}
}
func WithResolution(resolution string) VideoOption {
return func(o *VideoOptions) {
o.Resolution = resolution
}
}
func WithAspectRatio(ratio string) VideoOption {
return func(o *VideoOptions) {
o.AspectRatio = ratio
}
}
func WithStyle(style string) VideoOption {
return func(o *VideoOptions) {
o.Style = style
}
}
func WithMotionLevel(level int) VideoOption {
return func(o *VideoOptions) {
o.MotionLevel = level
}
}
func WithCameraMotion(motion string) VideoOption {
return func(o *VideoOptions) {
o.CameraMotion = motion
}
}
func WithSeed(seed int64) VideoOption {
return func(o *VideoOptions) {
o.Seed = seed
}
}
func WithFirstFrame(url string) VideoOption {
return func(o *VideoOptions) {
o.FirstFrameURL = url
}
}
func WithLastFrame(url string) VideoOption {
return func(o *VideoOptions) {
o.LastFrameURL = url
}
}
func WithReferenceImages(urls []string) VideoOption {
return func(o *VideoOptions) {
o.ReferenceImageURLs = urls
}
}
type RunwayClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type RunwayRequest struct {
Model string `json:"model"`
PromptImage string `json:"prompt_image"`
PromptText string `json:"prompt_text"`
Duration int `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
type RunwayResponse struct {
ID string `json:"id"`
Status string `json:"status"`
Output struct {
URL string `json:"url"`
} `json:"output"`
Error string `json:"error,omitempty"`
}
func NewRunwayClient(baseURL, apiKey, model string) *RunwayClient {
return &RunwayClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 180 * time.Second,
},
}
}
func (c *RunwayClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 5,
AspectRatio: "16:9",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := RunwayRequest{
Model: model,
PromptImage: imageURL,
PromptText: prompt,
Duration: options.Duration,
AspectRatio: options.AspectRatio,
Seed: options.Seed,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + "/v1/video/generate"
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result RunwayResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error != "" {
return nil, fmt.Errorf("runway error: %s", result.Error)
}
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "succeeded",
}
if result.Output.URL != "" {
videoResult.VideoURL = result.Output.URL
}
return videoResult, nil
}
func (c *RunwayClient) GetTaskStatus(taskID string) (*VideoResult, error) {
endpoint := c.BaseURL + "/v1/video/status/" + taskID
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result RunwayResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "succeeded",
}
if result.Error != "" {
videoResult.Error = result.Error
}
if result.Output.URL != "" {
videoResult.VideoURL = result.Output.URL
}
return videoResult, nil
}
type PikaClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type PikaRequest struct {
Model string `json:"model"`
Image string `json:"image"`
Prompt string `json:"prompt"`
Duration int `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
Motion int `json:"motion,omitempty"`
CameraMotion string `json:"camera_motion,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
type PikaResponse struct {
JobID string `json:"job_id"`
Status string `json:"status"`
Result struct {
VideoURL string `json:"video_url"`
} `json:"result"`
Error string `json:"error,omitempty"`
}
func NewPikaClient(baseURL, apiKey, model string) *PikaClient {
return &PikaClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 180 * time.Second,
},
}
}
func (c *PikaClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 3,
AspectRatio: "16:9",
MotionLevel: 50,
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := PikaRequest{
Model: model,
Image: imageURL,
Prompt: prompt,
Duration: options.Duration,
AspectRatio: options.AspectRatio,
Motion: options.MotionLevel,
CameraMotion: options.CameraMotion,
Seed: options.Seed,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + "/v1/video/generate"
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result PikaResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error != "" {
return nil, fmt.Errorf("pika error: %s", result.Error)
}
videoResult := &VideoResult{
TaskID: result.JobID,
Status: result.Status,
Completed: result.Status == "completed",
}
if result.Result.VideoURL != "" {
videoResult.VideoURL = result.Result.VideoURL
}
return videoResult, nil
}
func (c *PikaClient) GetTaskStatus(taskID string) (*VideoResult, error) {
endpoint := c.BaseURL + "/v1/video/status/" + taskID
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result PikaResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
videoResult := &VideoResult{
TaskID: result.JobID,
Status: result.Status,
Completed: result.Status == "completed",
}
if result.Error != "" {
videoResult.Error = result.Error
}
if result.Result.VideoURL != "" {
videoResult.VideoURL = result.Result.VideoURL
}
return videoResult, nil
}

View File

@@ -0,0 +1,288 @@
package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// VolcesArkClient 火山引擎ARK视频生成客户端
type VolcesArkClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
QueryEndpoint string
HTTPClient *http.Client
}
type VolcesArkContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL map[string]interface{} `json:"image_url,omitempty"`
Role string `json:"role,omitempty"`
}
type VolcesArkRequest struct {
Model string `json:"model"`
Content []VolcesArkContent `json:"content"`
GenerateAudio bool `json:"generate_audio,omitempty"`
}
type VolcesArkResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Status string `json:"status"`
Content struct {
VideoURL string `json:"video_url"`
} `json:"content"`
Usage struct {
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
Seed int `json:"seed"`
Resolution string `json:"resolution"`
Ratio string `json:"ratio"`
Duration int `json:"duration"`
FramesPerSecond int `json:"framespersecond"`
ServiceTier string `json:"service_tier"`
ExecutionExpiresAfter int `json:"execution_expires_after"`
GenerateAudio bool `json:"generate_audio"`
Error interface{} `json:"error,omitempty"`
}
func NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *VolcesArkClient {
if endpoint == "" {
endpoint = "/api/v3/contents/generations/tasks"
}
if queryEndpoint == "" {
queryEndpoint = endpoint
}
return &VolcesArkClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
QueryEndpoint: queryEndpoint,
HTTPClient: &http.Client{
Timeout: 300 * time.Second,
},
}
}
// GenerateVideo 生成视频(支持首帧、首尾帧、参考图等多种模式)
func (c *VolcesArkClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 5,
AspectRatio: "adaptive",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
// 构建prompt文本包含duration和ratio参数
promptText := prompt
if options.AspectRatio != "" {
promptText += fmt.Sprintf(" --ratio %s", options.AspectRatio)
}
if options.Duration > 0 {
promptText += fmt.Sprintf(" --dur %d", options.Duration)
}
content := []VolcesArkContent{
{
Type: "text",
Text: promptText,
},
}
// 处理不同的图片模式
// 1. 组图模式多个reference_image
if len(options.ReferenceImageURLs) > 0 {
for _, refURL := range options.ReferenceImageURLs {
content = append(content, VolcesArkContent{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": refURL,
},
Role: "reference_image",
})
}
} else if options.FirstFrameURL != "" && options.LastFrameURL != "" {
// 2. 首尾帧模式
content = append(content, VolcesArkContent{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": options.FirstFrameURL,
},
Role: "first_frame",
})
content = append(content, VolcesArkContent{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": options.LastFrameURL,
},
Role: "last_frame",
})
} else if imageURL != "" {
// 3. 单图模式(默认)
content = append(content, VolcesArkContent{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": imageURL,
},
// 单图模式不需要role
})
} else if options.FirstFrameURL != "" {
// 4. 只有首帧
content = append(content, VolcesArkContent{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": options.FirstFrameURL,
},
Role: "first_frame",
})
}
// 只有 seedance-1-5-pro 模型支持 generate_audio 参数
generateAudio := false
if strings.Contains(strings.ToLower(model), "seedance-1-5-pro") {
generateAudio = true
}
reqBody := VolcesArkRequest{
Model: model,
Content: content,
GenerateAudio: generateAudio,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + c.Endpoint
fmt.Printf("[VolcesARK] Generating video - Endpoint: %s, FullURL: %s, Model: %s\n", c.Endpoint, endpoint, model)
fmt.Printf("[VolcesARK] Request body: %s\n", string(jsonData))
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
fmt.Printf("[VolcesARK] Response status: %d, body: %s\n", resp.StatusCode, string(body))
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result VolcesArkResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
fmt.Printf("[VolcesARK] Video generation initiated - TaskID: %s, Status: %s\n", result.ID, result.Status)
if result.Error != nil {
errorMsg := fmt.Sprintf("%v", result.Error)
return nil, fmt.Errorf("volces error: %s", errorMsg)
}
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "completed" || result.Status == "succeeded",
Duration: result.Duration,
}
if result.Content.VideoURL != "" {
videoResult.VideoURL = result.Content.VideoURL
videoResult.Completed = true
}
return videoResult, nil
}
func (c *VolcesArkClient) GetTaskStatus(taskID string) (*VideoResult, error) {
// 替换占位符{taskId}或直接拼接
queryPath := c.QueryEndpoint
if contains := bytes.Contains([]byte(queryPath), []byte("{taskId}")); contains {
queryPath = string(bytes.ReplaceAll([]byte(queryPath), []byte("{taskId}"), []byte(taskID)))
} else {
queryPath = queryPath + "/" + taskID
}
endpoint := c.BaseURL + queryPath
fmt.Printf("[VolcesARK] Querying task status - TaskID: %s, QueryEndpoint: %s, FullURL: %s\n", taskID, c.QueryEndpoint, endpoint)
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
fmt.Printf("[VolcesARK] Response body: %s\n", string(body))
var result VolcesArkResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
fmt.Printf("[VolcesARK] Parsed result - ID: %s, Status: %s, VideoURL: %s\n", result.ID, result.Status, result.Content.VideoURL)
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "completed" || result.Status == "succeeded",
Duration: result.Duration,
}
if result.Error != nil {
videoResult.Error = fmt.Sprintf("%v", result.Error)
}
if result.Content.VideoURL != "" {
videoResult.VideoURL = result.Content.VideoURL
videoResult.Completed = true
}
return videoResult, nil
}