init
This commit is contained in:
188
pkg/ai/openai_client.go
Normal file
188
pkg/ai/openai_client.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpenAIClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Endpoint string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func NewOpenAIClient(baseURL, apiKey, model, endpoint string) *OpenAIClient {
|
||||
if endpoint == "" {
|
||||
endpoint = "/v1/chat/completions"
|
||||
}
|
||||
|
||||
return &OpenAIClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
Endpoint: endpoint,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*ChatCompletionRequest)) (*ChatCompletionResponse, error) {
|
||||
req := &ChatCompletionRequest{
|
||||
Model: c.Model,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option(req)
|
||||
}
|
||||
|
||||
return c.sendChatRequest(req)
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := c.BaseURL + c.Endpoint
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp ErrorResponse
|
||||
if err := json.Unmarshal(body, &errResp); err != nil {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil, fmt.Errorf("API error: %s", errResp.Error.Message)
|
||||
}
|
||||
|
||||
var chatResp ChatCompletionResponse
|
||||
if err := json.Unmarshal(body, &chatResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
return &chatResp, nil
|
||||
}
|
||||
|
||||
func WithTemperature(temp float64) func(*ChatCompletionRequest) {
|
||||
return func(req *ChatCompletionRequest) {
|
||||
req.Temperature = temp
|
||||
}
|
||||
}
|
||||
|
||||
func WithMaxTokens(tokens int) func(*ChatCompletionRequest) {
|
||||
return func(req *ChatCompletionRequest) {
|
||||
req.MaxTokens = tokens
|
||||
}
|
||||
}
|
||||
|
||||
func WithTopP(topP float64) func(*ChatCompletionRequest) {
|
||||
return func(req *ChatCompletionRequest) {
|
||||
req.TopP = topP
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error) {
|
||||
messages := []ChatMessage{}
|
||||
|
||||
if systemPrompt != "" {
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
})
|
||||
|
||||
resp, err := c.ChatCompletion(messages, options...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no response from API")
|
||||
}
|
||||
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) TestConnection() error {
|
||||
messages := []ChatMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := c.ChatCompletion(messages, WithMaxTokens(10))
|
||||
return err
|
||||
}
|
||||
89
pkg/config/config.go
Normal file
89
pkg/config/config.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
App AppConfig `mapstructure:"app"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Storage StorageConfig `mapstructure:"storage"`
|
||||
AI AIConfig `mapstructure:"ai"`
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
Name string `mapstructure:"name"`
|
||||
Version string `mapstructure:"version"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port int `mapstructure:"port"`
|
||||
Host string `mapstructure:"host"`
|
||||
CORSOrigins []string `mapstructure:"cors_origins"`
|
||||
ReadTimeout int `mapstructure:"read_timeout"`
|
||||
WriteTimeout int `mapstructure:"write_timeout"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Type string `mapstructure:"type"` // sqlite, mysql
|
||||
Path string `mapstructure:"path"` // SQLite数据库文件路径
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
Database string `mapstructure:"database"`
|
||||
Charset string `mapstructure:"charset"`
|
||||
MaxIdle int `mapstructure:"max_idle"`
|
||||
MaxOpen int `mapstructure:"max_open"`
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
Type string `mapstructure:"type"` // local, minio
|
||||
LocalPath string `mapstructure:"local_path"` // 本地存储路径
|
||||
BaseURL string `mapstructure:"base_url"` // 访问URL前缀
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
DefaultTextProvider string `mapstructure:"default_text_provider"`
|
||||
DefaultImageProvider string `mapstructure:"default_image_provider"`
|
||||
DefaultVideoProvider string `mapstructure:"default_video_provider"`
|
||||
}
|
||||
|
||||
func LoadConfig() (*Config, error) {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath("./configs")
|
||||
viper.AddConfigPath(".")
|
||||
|
||||
viper.AutomaticEnv()
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read config: %w", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := viper.Unmarshal(&config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (c *DatabaseConfig) DSN() string {
|
||||
if c.Type == "sqlite" {
|
||||
return c.Path
|
||||
}
|
||||
// MySQL DSN
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
|
||||
c.User,
|
||||
c.Password,
|
||||
c.Host,
|
||||
c.Port,
|
||||
c.Database,
|
||||
c.Charset,
|
||||
)
|
||||
}
|
||||
384
pkg/image/image_client.go
Normal file
384
pkg/image/image_client.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ImageClient interface {
|
||||
GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error)
|
||||
GetTaskStatus(taskID string) (*ImageResult, error)
|
||||
}
|
||||
|
||||
type ImageResult struct {
|
||||
TaskID string
|
||||
Status string
|
||||
ImageURL string
|
||||
Width int
|
||||
Height int
|
||||
Error string
|
||||
Completed bool
|
||||
}
|
||||
|
||||
type ImageOptions struct {
|
||||
NegativePrompt string
|
||||
Size string
|
||||
Quality string
|
||||
Style string
|
||||
Steps int
|
||||
CfgScale float64
|
||||
Seed int64
|
||||
Model string
|
||||
Width int
|
||||
Height int
|
||||
ReferenceImages []string // 参考图片URL列表
|
||||
}
|
||||
|
||||
type ImageOption func(*ImageOptions)
|
||||
|
||||
func WithNegativePrompt(prompt string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.NegativePrompt = prompt
|
||||
}
|
||||
}
|
||||
|
||||
func WithSize(size string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Size = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithQuality(quality string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Quality = quality
|
||||
}
|
||||
}
|
||||
|
||||
func WithStyle(style string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Style = style
|
||||
}
|
||||
}
|
||||
|
||||
func WithSteps(steps int) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Steps = steps
|
||||
}
|
||||
}
|
||||
|
||||
func WithCfgScale(scale float64) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.CfgScale = scale
|
||||
}
|
||||
}
|
||||
|
||||
func WithSeed(seed int64) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Seed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithModel(model string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Model = model
|
||||
}
|
||||
}
|
||||
|
||||
func WithDimensions(width, height int) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Width = width
|
||||
o.Height = height
|
||||
}
|
||||
}
|
||||
|
||||
func WithReferenceImages(images []string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.ReferenceImages = images
|
||||
}
|
||||
}
|
||||
|
||||
type OpenAIImageClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type DALLERequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
N int `json:"n"`
|
||||
Image []string `json:"image,omitempty"` // 参考图片URL列表
|
||||
}
|
||||
|
||||
type DALLEResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []struct {
|
||||
URL string `json:"url"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func NewOpenAIImageClient(baseURL, apiKey, model string) *OpenAIImageClient {
|
||||
return &OpenAIImageClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpenAIImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
|
||||
options := &ImageOptions{
|
||||
Size: "1920x1920",
|
||||
Quality: "standard",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
reqBody := DALLERequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Size: options.Size,
|
||||
Quality: options.Quality,
|
||||
N: 1,
|
||||
Image: options.ReferenceImages,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + "/v1/images/generations"
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 打印原始响应以便调试
|
||||
fmt.Printf("OpenAI API Response: %s\n", string(body))
|
||||
|
||||
var result DALLEResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
|
||||
}
|
||||
|
||||
if len(result.Data) == 0 {
|
||||
return nil, fmt.Errorf("no image generated, response: %s", string(body))
|
||||
}
|
||||
|
||||
return &ImageResult{
|
||||
Status: "completed",
|
||||
ImageURL: result.Data[0].URL,
|
||||
Completed: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
|
||||
return nil, fmt.Errorf("not supported for OpenAI/DALL-E")
|
||||
}
|
||||
|
||||
type StableDiffusionClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type SDRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Samples int `json:"samples"`
|
||||
Image []string `json:"image,omitempty"` // 参考图片URL列表
|
||||
}
|
||||
|
||||
type SDResponse struct {
|
||||
Status string `json:"status"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
Output []struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"output,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func NewStableDiffusionClient(baseURL, apiKey, model string) *StableDiffusionClient {
|
||||
return &StableDiffusionClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StableDiffusionClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
|
||||
options := &ImageOptions{
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 30,
|
||||
CfgScale: 7.5,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
reqBody := SDRequest{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: options.NegativePrompt,
|
||||
Model: model,
|
||||
Width: options.Width,
|
||||
Height: options.Height,
|
||||
Steps: options.Steps,
|
||||
CfgScale: options.CfgScale,
|
||||
Seed: options.Seed,
|
||||
Samples: 1,
|
||||
Image: options.ReferenceImages,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + "/v1/images/generations"
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result SDResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, fmt.Errorf("SD error: %s", result.Error)
|
||||
}
|
||||
|
||||
if result.Status == "processing" {
|
||||
return &ImageResult{
|
||||
TaskID: result.TaskID,
|
||||
Status: "processing",
|
||||
Completed: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(result.Output) == 0 {
|
||||
return nil, fmt.Errorf("no image generated")
|
||||
}
|
||||
|
||||
return &ImageResult{
|
||||
Status: "completed",
|
||||
ImageURL: result.Output[0].URL,
|
||||
Width: options.Width,
|
||||
Height: options.Height,
|
||||
Completed: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *StableDiffusionClient) GetTaskStatus(taskID string) (*ImageResult, error) {
|
||||
endpoint := c.BaseURL + "/v1/images/status/" + taskID
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
var result SDResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
imageResult := &ImageResult{
|
||||
TaskID: taskID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
imageResult.Error = result.Error
|
||||
}
|
||||
|
||||
if len(result.Output) > 0 {
|
||||
imageResult.ImageURL = result.Output[0].URL
|
||||
}
|
||||
|
||||
return imageResult, nil
|
||||
}
|
||||
35
pkg/logger/logger.go
Normal file
35
pkg/logger/logger.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
*zap.SugaredLogger
|
||||
}
|
||||
|
||||
func NewLogger(debug bool) *Logger {
|
||||
var config zap.Config
|
||||
|
||||
if debug {
|
||||
config = zap.NewDevelopmentConfig()
|
||||
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||
// 在开发模式下,禁用时间戳和调用者信息,使输出更简洁
|
||||
config.EncoderConfig.TimeKey = ""
|
||||
config.EncoderConfig.CallerKey = ""
|
||||
} else {
|
||||
config = zap.NewProductionConfig()
|
||||
config.EncoderConfig.TimeKey = "timestamp"
|
||||
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
}
|
||||
|
||||
logger, err := config.Build()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &Logger{
|
||||
SugaredLogger: logger.Sugar(),
|
||||
}
|
||||
}
|
||||
119
pkg/response/response.go
Normal file
119
pkg/response/response.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Response struct {
|
||||
Success bool `json:"success"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Error *ErrorInfo `json:"error,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
}
|
||||
|
||||
type ErrorInfo struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details interface{} `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
type PaginationData struct {
|
||||
Items interface{} `json:"items"`
|
||||
Pagination Pagination `json:"pagination"`
|
||||
}
|
||||
|
||||
type Pagination struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Total int64 `json:"total"`
|
||||
TotalPages int64 `json:"total_pages"`
|
||||
}
|
||||
|
||||
func Success(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func SuccessWithMessage(c *gin.Context, message string, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Message: message,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func Created(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusCreated, Response{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func SuccessWithPagination(c *gin.Context, items interface{}, total int64, page int, pageSize int) {
|
||||
totalPages := (total + int64(pageSize) - 1) / int64(pageSize)
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Success: true,
|
||||
Data: PaginationData{
|
||||
Items: items,
|
||||
Pagination: Pagination{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Total: total,
|
||||
TotalPages: totalPages,
|
||||
},
|
||||
},
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func Error(c *gin.Context, statusCode int, errCode string, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Success: false,
|
||||
Error: &ErrorInfo{
|
||||
Code: errCode,
|
||||
Message: message,
|
||||
},
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func ErrorWithDetails(c *gin.Context, statusCode int, errCode string, message string, details interface{}) {
|
||||
c.JSON(statusCode, Response{
|
||||
Success: false,
|
||||
Error: &ErrorInfo{
|
||||
Code: errCode,
|
||||
Message: message,
|
||||
Details: details,
|
||||
},
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func BadRequest(c *gin.Context, message string) {
|
||||
Error(c, http.StatusBadRequest, "BAD_REQUEST", message)
|
||||
}
|
||||
|
||||
func Unauthorized(c *gin.Context, message string) {
|
||||
Error(c, http.StatusUnauthorized, "UNAUTHORIZED", message)
|
||||
}
|
||||
|
||||
func Forbidden(c *gin.Context, message string) {
|
||||
Error(c, http.StatusForbidden, "FORBIDDEN", message)
|
||||
}
|
||||
|
||||
func NotFound(c *gin.Context, message string) {
|
||||
Error(c, http.StatusNotFound, "NOT_FOUND", message)
|
||||
}
|
||||
|
||||
func InternalError(c *gin.Context, message string) {
|
||||
Error(c, http.StatusInternalServerError, "INTERNAL_ERROR", message)
|
||||
}
|
||||
153
pkg/utils/json_parser.go
Normal file
153
pkg/utils/json_parser.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SafeParseAIJSON 安全地解析AI返回的JSON,处理常见的格式问题
|
||||
// 包括:
|
||||
// 1. 移除Markdown代码块标记
|
||||
// 2. 提取JSON对象
|
||||
// 3. 清理多余的空白和换行
|
||||
// 4. 尝试修复截断的JSON
|
||||
// 5. 提供详细的错误信息
|
||||
func SafeParseAIJSON(aiResponse string, v interface{}) error {
|
||||
if aiResponse == "" {
|
||||
return fmt.Errorf("AI返回内容为空")
|
||||
}
|
||||
|
||||
// 1. 移除可能的Markdown代码块标记
|
||||
cleaned := strings.TrimSpace(aiResponse)
|
||||
cleaned = regexp.MustCompile("(?m)^```json\\s*").ReplaceAllString(cleaned, "")
|
||||
cleaned = regexp.MustCompile("(?m)^```\\s*").ReplaceAllString(cleaned, "")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
|
||||
// 2. 提取JSON对象 (查找第一个 { 到最后一个 })
|
||||
jsonRegex := regexp.MustCompile(`(?s)\{.*\}`)
|
||||
jsonMatch := jsonRegex.FindString(cleaned)
|
||||
|
||||
if jsonMatch == "" {
|
||||
return fmt.Errorf("响应中未找到有效的JSON对象,原始响应: %s", truncateString(aiResponse, 200))
|
||||
}
|
||||
|
||||
// 3. 尝试解析JSON
|
||||
err := json.Unmarshal([]byte(jsonMatch), v)
|
||||
if err == nil {
|
||||
return nil // 解析成功
|
||||
}
|
||||
|
||||
// 4. 如果解析失败,尝试修复截断的JSON
|
||||
fixedJSON := attemptJSONRepair(jsonMatch)
|
||||
if fixedJSON != jsonMatch {
|
||||
if err := json.Unmarshal([]byte(fixedJSON), v); err == nil {
|
||||
return nil // 修复后解析成功
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 提供详细的错误上下文
|
||||
if jsonErr, ok := err.(*json.SyntaxError); ok {
|
||||
errorPos := int(jsonErr.Offset)
|
||||
start := maxInt(0, errorPos-100)
|
||||
end := minInt(len(jsonMatch), errorPos+100)
|
||||
|
||||
context := jsonMatch[start:end]
|
||||
marker := strings.Repeat(" ", errorPos-start) + "^"
|
||||
|
||||
return fmt.Errorf(
|
||||
"JSON解析失败: %s\n错误位置附近:\n%s\n%s",
|
||||
jsonErr.Error(),
|
||||
context,
|
||||
marker,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf("JSON解析失败: %w\n原始响应: %s", err, truncateString(jsonMatch, 300))
|
||||
}
|
||||
|
||||
// attemptJSONRepair 尝试修复常见的JSON问题
|
||||
func attemptJSONRepair(jsonStr string) string {
|
||||
// 1. 处理未闭合的字符串
|
||||
// 如果最后一个字符不是 },尝试补全
|
||||
trimmed := strings.TrimSpace(jsonStr)
|
||||
|
||||
// 2. 检查是否有未闭合的引号
|
||||
if strings.Count(trimmed, `"`)%2 != 0 {
|
||||
// 有奇数个引号,尝试补全最后一个引号
|
||||
trimmed += `"`
|
||||
}
|
||||
|
||||
// 3. 统计括号
|
||||
openBraces := strings.Count(trimmed, "{")
|
||||
closeBraces := strings.Count(trimmed, "}")
|
||||
openBrackets := strings.Count(trimmed, "[")
|
||||
closeBrackets := strings.Count(trimmed, "]")
|
||||
|
||||
// 4. 补全未闭合的数组
|
||||
for i := 0; i < openBrackets-closeBrackets; i++ {
|
||||
trimmed += "]"
|
||||
}
|
||||
|
||||
// 5. 补全未闭合的对象
|
||||
for i := 0; i < openBraces-closeBraces; i++ {
|
||||
trimmed += "}"
|
||||
}
|
||||
|
||||
return trimmed
|
||||
}
|
||||
|
||||
// ExtractJSONFromText 从文本中提取JSON对象或数组
|
||||
func ExtractJSONFromText(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
|
||||
// 移除Markdown代码块
|
||||
text = regexp.MustCompile("(?m)^```json\\s*").ReplaceAllString(text, "")
|
||||
text = regexp.MustCompile("(?m)^```\\s*").ReplaceAllString(text, "")
|
||||
text = strings.TrimSpace(text)
|
||||
|
||||
// 查找JSON对象
|
||||
if idx := strings.Index(text, "{"); idx != -1 {
|
||||
if lastIdx := strings.LastIndex(text, "}"); lastIdx != -1 && lastIdx > idx {
|
||||
return text[idx : lastIdx+1]
|
||||
}
|
||||
}
|
||||
|
||||
// 查找JSON数组
|
||||
if idx := strings.Index(text, "["); idx != -1 {
|
||||
if lastIdx := strings.LastIndex(text, "]"); lastIdx != -1 && lastIdx > idx {
|
||||
return text[idx : lastIdx+1]
|
||||
}
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
// ValidateJSON 验证JSON字符串是否有效
|
||||
func ValidateJSON(jsonStr string) error {
|
||||
var js json.RawMessage
|
||||
return json.Unmarshal([]byte(jsonStr), &js)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
28
pkg/utils/random.go
Normal file
28
pkg/utils/random.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func GenerateVerificationCode(length int) string {
|
||||
digits := "0123456789"
|
||||
code := make([]byte, length)
|
||||
for i := range code {
|
||||
code[i] = digits[rand.Intn(len(digits))]
|
||||
}
|
||||
return string(code)
|
||||
}
|
||||
|
||||
func GenerateRandomString(length int) string {
|
||||
chars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, length)
|
||||
for i := range result {
|
||||
result[i] = chars[rand.Intn(len(chars))]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
192
pkg/video/minimax_client.go
Normal file
192
pkg/video/minimax_client.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MinimaxClient Minimax视频生成客户端
|
||||
type MinimaxClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type MinimaxSubjectReference struct {
|
||||
Type string `json:"type"`
|
||||
Image []string `json:"image"`
|
||||
}
|
||||
|
||||
type MinimaxRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
FirstFrameImage string `json:"first_frame_image,omitempty"`
|
||||
LastFrameImage string `json:"last_frame_image,omitempty"`
|
||||
SubjectReference []MinimaxSubjectReference `json:"subject_reference,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
}
|
||||
|
||||
type MinimaxResponse struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
BaseResp struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
StatusMsg string `json:"status_msg"`
|
||||
} `json:"base_resp"`
|
||||
Video struct {
|
||||
URL string `json:"url"`
|
||||
Duration int `json:"duration"`
|
||||
} `json:"video"`
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func NewMinimaxClient(baseURL, apiKey, model string) *MinimaxClient {
|
||||
return &MinimaxClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 300 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateVideo 生成视频(支持首尾帧和主体参考)
|
||||
func (c *MinimaxClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 6,
|
||||
Resolution: "1080P",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
reqBody := MinimaxRequest{
|
||||
Prompt: prompt,
|
||||
Model: model,
|
||||
Duration: options.Duration,
|
||||
}
|
||||
|
||||
// 设置分辨率
|
||||
if options.Resolution != "" {
|
||||
reqBody.Resolution = options.Resolution
|
||||
}
|
||||
|
||||
// 如果有首帧图片(从imageURL或FirstFrameURL)
|
||||
if options.FirstFrameURL != "" {
|
||||
reqBody.FirstFrameImage = options.FirstFrameURL
|
||||
} else if imageURL != "" {
|
||||
reqBody.FirstFrameImage = imageURL
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + "/v1/video_generation"
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result MinimaxResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.Error.Message != "" {
|
||||
return nil, fmt.Errorf("minimax error: %s", result.Error.Message)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.TaskID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
Duration: result.Video.Duration,
|
||||
}
|
||||
|
||||
if result.Video.URL != "" {
|
||||
videoResult.VideoURL = result.Video.URL
|
||||
videoResult.Completed = true
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *MinimaxClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
endpoint := c.BaseURL + "/v1/video_generation/" + taskID
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
var result MinimaxResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.TaskID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
Duration: result.Video.Duration,
|
||||
}
|
||||
|
||||
if result.Error.Message != "" {
|
||||
videoResult.Error = result.Error.Message
|
||||
}
|
||||
|
||||
if result.Video.URL != "" {
|
||||
videoResult.VideoURL = result.Video.URL
|
||||
videoResult.Completed = true
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
178
pkg/video/openai_sora_client.go
Normal file
178
pkg/video/openai_sora_client.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpenAISoraClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type OpenAISoraResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Model string `json:"model"`
|
||||
Status string `json:"status"`
|
||||
Progress int `json:"progress"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
CompletedAt int64 `json:"completed_at"`
|
||||
Size string `json:"size"`
|
||||
Seconds string `json:"seconds"`
|
||||
Quality string `json:"quality"`
|
||||
VideoURL string `json:"video_url"` // 直接的video_url字段
|
||||
Video struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"video"` // 嵌套的video.url字段(兼容)
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func NewOpenAISoraClient(baseURL, apiKey, model string) *OpenAISoraClient {
|
||||
return &OpenAISoraClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 300 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpenAISoraClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 4,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
writer.WriteField("model", model)
|
||||
writer.WriteField("prompt", prompt)
|
||||
|
||||
if imageURL != "" {
|
||||
writer.WriteField("input_reference", imageURL)
|
||||
}
|
||||
|
||||
if options.Duration > 0 {
|
||||
writer.WriteField("seconds", fmt.Sprintf("%d", options.Duration))
|
||||
}
|
||||
|
||||
if options.Resolution != "" {
|
||||
writer.WriteField("size", options.Resolution)
|
||||
}
|
||||
|
||||
writer.Close()
|
||||
|
||||
endpoint := c.BaseURL + "/v1/videos"
|
||||
req, err := http.NewRequest("POST", endpoint, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result OpenAISoraResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.Error.Message != "" {
|
||||
return nil, fmt.Errorf("openai error: %s", result.Error.Message)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
}
|
||||
|
||||
// 优先使用video_url字段,兼容video.url嵌套结构
|
||||
if result.VideoURL != "" {
|
||||
videoResult.VideoURL = result.VideoURL
|
||||
} else if result.Video.URL != "" {
|
||||
videoResult.VideoURL = result.Video.URL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *OpenAISoraClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
endpoint := c.BaseURL + "/v1/videos/" + taskID
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
var result OpenAISoraResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
}
|
||||
|
||||
if result.Error.Message != "" {
|
||||
videoResult.Error = result.Error.Message
|
||||
}
|
||||
|
||||
// 优先使用video_url字段,兼容video.url嵌套结构
|
||||
if result.VideoURL != "" {
|
||||
videoResult.VideoURL = result.VideoURL
|
||||
} else if result.Video.URL != "" {
|
||||
videoResult.VideoURL = result.Video.URL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
427
pkg/video/video_client.go
Normal file
427
pkg/video/video_client.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type VideoClient interface {
|
||||
GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error)
|
||||
GetTaskStatus(taskID string) (*VideoResult, error)
|
||||
}
|
||||
|
||||
type VideoResult struct {
|
||||
TaskID string
|
||||
Status string
|
||||
VideoURL string
|
||||
ThumbnailURL string
|
||||
Duration int
|
||||
Width int
|
||||
Height int
|
||||
Error string
|
||||
Completed bool
|
||||
}
|
||||
|
||||
type VideoOptions struct {
|
||||
Model string
|
||||
Duration int
|
||||
FPS int
|
||||
Resolution string
|
||||
AspectRatio string
|
||||
Style string
|
||||
MotionLevel int
|
||||
CameraMotion string
|
||||
Seed int64
|
||||
FirstFrameURL string
|
||||
LastFrameURL string
|
||||
ReferenceImageURLs []string
|
||||
}
|
||||
|
||||
type VideoOption func(*VideoOptions)
|
||||
|
||||
func WithModel(model string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.Model = model
|
||||
}
|
||||
}
|
||||
|
||||
func WithDuration(duration int) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.Duration = duration
|
||||
}
|
||||
}
|
||||
|
||||
func WithFPS(fps int) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.FPS = fps
|
||||
}
|
||||
}
|
||||
|
||||
func WithResolution(resolution string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.Resolution = resolution
|
||||
}
|
||||
}
|
||||
|
||||
func WithAspectRatio(ratio string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.AspectRatio = ratio
|
||||
}
|
||||
}
|
||||
|
||||
func WithStyle(style string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.Style = style
|
||||
}
|
||||
}
|
||||
|
||||
func WithMotionLevel(level int) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.MotionLevel = level
|
||||
}
|
||||
}
|
||||
|
||||
func WithCameraMotion(motion string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.CameraMotion = motion
|
||||
}
|
||||
}
|
||||
|
||||
func WithSeed(seed int64) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.Seed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithFirstFrame(url string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.FirstFrameURL = url
|
||||
}
|
||||
}
|
||||
|
||||
func WithLastFrame(url string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.LastFrameURL = url
|
||||
}
|
||||
}
|
||||
|
||||
func WithReferenceImages(urls []string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.ReferenceImageURLs = urls
|
||||
}
|
||||
}
|
||||
|
||||
type RunwayClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type RunwayRequest struct {
|
||||
Model string `json:"model"`
|
||||
PromptImage string `json:"prompt_image"`
|
||||
PromptText string `json:"prompt_text"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
type RunwayResponse struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Output struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func NewRunwayClient(baseURL, apiKey, model string) *RunwayClient {
|
||||
return &RunwayClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 180 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RunwayClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 5,
|
||||
AspectRatio: "16:9",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
reqBody := RunwayRequest{
|
||||
Model: model,
|
||||
PromptImage: imageURL,
|
||||
PromptText: prompt,
|
||||
Duration: options.Duration,
|
||||
AspectRatio: options.AspectRatio,
|
||||
Seed: options.Seed,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + "/v1/video/generate"
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result RunwayResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, fmt.Errorf("runway error: %s", result.Error)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "succeeded",
|
||||
}
|
||||
|
||||
if result.Output.URL != "" {
|
||||
videoResult.VideoURL = result.Output.URL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *RunwayClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
endpoint := c.BaseURL + "/v1/video/status/" + taskID
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
var result RunwayResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "succeeded",
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
videoResult.Error = result.Error
|
||||
}
|
||||
|
||||
if result.Output.URL != "" {
|
||||
videoResult.VideoURL = result.Output.URL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
type PikaClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type PikaRequest struct {
|
||||
Model string `json:"model"`
|
||||
Image string `json:"image"`
|
||||
Prompt string `json:"prompt"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
Motion int `json:"motion,omitempty"`
|
||||
CameraMotion string `json:"camera_motion,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
type PikaResponse struct {
|
||||
JobID string `json:"job_id"`
|
||||
Status string `json:"status"`
|
||||
Result struct {
|
||||
VideoURL string `json:"video_url"`
|
||||
} `json:"result"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func NewPikaClient(baseURL, apiKey, model string) *PikaClient {
|
||||
return &PikaClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 180 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PikaClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 3,
|
||||
AspectRatio: "16:9",
|
||||
MotionLevel: 50,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
reqBody := PikaRequest{
|
||||
Model: model,
|
||||
Image: imageURL,
|
||||
Prompt: prompt,
|
||||
Duration: options.Duration,
|
||||
AspectRatio: options.AspectRatio,
|
||||
Motion: options.MotionLevel,
|
||||
CameraMotion: options.CameraMotion,
|
||||
Seed: options.Seed,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + "/v1/video/generate"
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result PikaResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, fmt.Errorf("pika error: %s", result.Error)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.JobID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
}
|
||||
|
||||
if result.Result.VideoURL != "" {
|
||||
videoResult.VideoURL = result.Result.VideoURL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *PikaClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
endpoint := c.BaseURL + "/v1/video/status/" + taskID
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
var result PikaResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.JobID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
videoResult.Error = result.Error
|
||||
}
|
||||
|
||||
if result.Result.VideoURL != "" {
|
||||
videoResult.VideoURL = result.Result.VideoURL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
288
pkg/video/volces_ark_client.go
Normal file
288
pkg/video/volces_ark_client.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// VolcesArkClient 火山引擎ARK视频生成客户端
|
||||
type VolcesArkClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Endpoint string
|
||||
QueryEndpoint string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type VolcesArkContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL map[string]interface{} `json:"image_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}
|
||||
|
||||
type VolcesArkRequest struct {
|
||||
Model string `json:"model"`
|
||||
Content []VolcesArkContent `json:"content"`
|
||||
GenerateAudio bool `json:"generate_audio,omitempty"`
|
||||
}
|
||||
|
||||
type VolcesArkResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Status string `json:"status"`
|
||||
Content struct {
|
||||
VideoURL string `json:"video_url"`
|
||||
} `json:"content"`
|
||||
Usage struct {
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
Seed int `json:"seed"`
|
||||
Resolution string `json:"resolution"`
|
||||
Ratio string `json:"ratio"`
|
||||
Duration int `json:"duration"`
|
||||
FramesPerSecond int `json:"framespersecond"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
ExecutionExpiresAfter int `json:"execution_expires_after"`
|
||||
GenerateAudio bool `json:"generate_audio"`
|
||||
Error interface{} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *VolcesArkClient {
|
||||
if endpoint == "" {
|
||||
endpoint = "/api/v3/contents/generations/tasks"
|
||||
}
|
||||
if queryEndpoint == "" {
|
||||
queryEndpoint = endpoint
|
||||
}
|
||||
return &VolcesArkClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
Endpoint: endpoint,
|
||||
QueryEndpoint: queryEndpoint,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 300 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateVideo 生成视频(支持首帧、首尾帧、参考图等多种模式)
|
||||
func (c *VolcesArkClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 5,
|
||||
AspectRatio: "adaptive",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
// 构建prompt文本(包含duration和ratio参数)
|
||||
promptText := prompt
|
||||
if options.AspectRatio != "" {
|
||||
promptText += fmt.Sprintf(" --ratio %s", options.AspectRatio)
|
||||
}
|
||||
if options.Duration > 0 {
|
||||
promptText += fmt.Sprintf(" --dur %d", options.Duration)
|
||||
}
|
||||
|
||||
content := []VolcesArkContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: promptText,
|
||||
},
|
||||
}
|
||||
|
||||
// 处理不同的图片模式
|
||||
// 1. 组图模式(多个reference_image)
|
||||
if len(options.ReferenceImageURLs) > 0 {
|
||||
for _, refURL := range options.ReferenceImageURLs {
|
||||
content = append(content, VolcesArkContent{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": refURL,
|
||||
},
|
||||
Role: "reference_image",
|
||||
})
|
||||
}
|
||||
} else if options.FirstFrameURL != "" && options.LastFrameURL != "" {
|
||||
// 2. 首尾帧模式
|
||||
content = append(content, VolcesArkContent{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": options.FirstFrameURL,
|
||||
},
|
||||
Role: "first_frame",
|
||||
})
|
||||
content = append(content, VolcesArkContent{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": options.LastFrameURL,
|
||||
},
|
||||
Role: "last_frame",
|
||||
})
|
||||
} else if imageURL != "" {
|
||||
// 3. 单图模式(默认)
|
||||
content = append(content, VolcesArkContent{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": imageURL,
|
||||
},
|
||||
// 单图模式不需要role
|
||||
})
|
||||
} else if options.FirstFrameURL != "" {
|
||||
// 4. 只有首帧
|
||||
content = append(content, VolcesArkContent{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": options.FirstFrameURL,
|
||||
},
|
||||
Role: "first_frame",
|
||||
})
|
||||
}
|
||||
|
||||
// 只有 seedance-1-5-pro 模型支持 generate_audio 参数
|
||||
generateAudio := false
|
||||
if strings.Contains(strings.ToLower(model), "seedance-1-5-pro") {
|
||||
generateAudio = true
|
||||
}
|
||||
|
||||
reqBody := VolcesArkRequest{
|
||||
Model: model,
|
||||
Content: content,
|
||||
GenerateAudio: generateAudio,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + c.Endpoint
|
||||
fmt.Printf("[VolcesARK] Generating video - Endpoint: %s, FullURL: %s, Model: %s\n", c.Endpoint, endpoint, model)
|
||||
fmt.Printf("[VolcesARK] Request body: %s\n", string(jsonData))
|
||||
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[VolcesARK] Response status: %d, body: %s\n", resp.StatusCode, string(body))
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result VolcesArkResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[VolcesARK] Video generation initiated - TaskID: %s, Status: %s\n", result.ID, result.Status)
|
||||
|
||||
if result.Error != nil {
|
||||
errorMsg := fmt.Sprintf("%v", result.Error)
|
||||
return nil, fmt.Errorf("volces error: %s", errorMsg)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed" || result.Status == "succeeded",
|
||||
Duration: result.Duration,
|
||||
}
|
||||
|
||||
if result.Content.VideoURL != "" {
|
||||
videoResult.VideoURL = result.Content.VideoURL
|
||||
videoResult.Completed = true
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *VolcesArkClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
// 替换占位符{taskId}或直接拼接
|
||||
queryPath := c.QueryEndpoint
|
||||
if contains := bytes.Contains([]byte(queryPath), []byte("{taskId}")); contains {
|
||||
queryPath = string(bytes.ReplaceAll([]byte(queryPath), []byte("{taskId}"), []byte(taskID)))
|
||||
} else {
|
||||
queryPath = queryPath + "/" + taskID
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + queryPath
|
||||
fmt.Printf("[VolcesARK] Querying task status - TaskID: %s, QueryEndpoint: %s, FullURL: %s\n", taskID, c.QueryEndpoint, endpoint)
|
||||
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[VolcesARK] Response body: %s\n", string(body))
|
||||
|
||||
var result VolcesArkResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[VolcesARK] Parsed result - ID: %s, Status: %s, VideoURL: %s\n", result.ID, result.Status, result.Content.VideoURL)
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed" || result.Status == "succeeded",
|
||||
Duration: result.Duration,
|
||||
}
|
||||
|
||||
if result.Error != nil {
|
||||
videoResult.Error = fmt.Sprintf("%v", result.Error)
|
||||
}
|
||||
|
||||
if result.Content.VideoURL != "" {
|
||||
videoResult.VideoURL = result.Content.VideoURL
|
||||
videoResult.Completed = true
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
Reference in New Issue
Block a user