Files
huobao-drama/pkg/video/chatfire_client.go

185 lines
4.3 KiB
Go

package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// ChatfireClient is the client for Chatfire video generation.
type ChatfireClient struct {
	// BaseURL is prepended to Endpoint and QueryEndpoint to form request URLs.
	BaseURL string
	// APIKey is sent as a Bearer token in the Authorization header.
	APIKey string
	// Model is the model name used when a call does not supply its own.
	Model string
	// Endpoint is the path of the task-creation (generation) request.
	Endpoint string
	// QueryEndpoint is the path of the task-status request. It may contain
	// a "{taskId}" or "{task_id}" placeholder for the task identifier.
	QueryEndpoint string
	// HTTPClient performs every HTTP request issued by this client.
	HTTPClient *http.Client
}
// ChatfireRequest is the JSON body posted to the generation endpoint.
type ChatfireRequest struct {
	// Model is the name of the model that should produce the video.
	Model string `json:"model"`
	// Prompt is the text prompt for the generation.
	Prompt string `json:"prompt"`
	// ImageURL is an optional image reference; left out of the JSON when empty.
	ImageURL string `json:"image_url,omitempty"`
	// Duration is the requested duration; left out of the JSON when zero.
	Duration int `json:"duration,omitempty"`
	// Size is left out of the JSON when empty. GenerateVideo fills it from
	// the aspect-ratio option.
	Size string `json:"size,omitempty"`
}
// ChatfireResponse is the JSON reply decoded after submitting a generation
// request.
type ChatfireResponse struct {
	// TaskID identifies the created task.
	TaskID string `json:"task_id"`
	// Status is the task state as reported by the API.
	Status string `json:"status"`
	// Error holds an API-reported error message; empty when there is none.
	Error string `json:"error,omitempty"`
}
// ChatfireTaskResponse is the JSON reply decoded when querying a task.
type ChatfireTaskResponse struct {
	// TaskID identifies the queried task.
	TaskID string `json:"task_id"`
	// Status is the task state as reported by the API.
	Status string `json:"status"`
	// VideoURL is the location of the generated video; empty when absent.
	VideoURL string `json:"video_url,omitempty"`
	// Error holds an API-reported error message; empty when there is none.
	Error string `json:"error,omitempty"`
}
// NewChatfireClient returns a ChatfireClient configured with the given base
// URL, API key and model.
//
// An empty endpoint falls back to "/video/generations" and an empty
// queryEndpoint falls back to "/v1/video/task/{taskId}". The embedded HTTP
// client uses a 300-second timeout.
func NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *ChatfireClient {
	// Start from the defaults and override only what the caller supplied.
	client := &ChatfireClient{
		BaseURL:       baseURL,
		APIKey:        apiKey,
		Model:         model,
		Endpoint:      "/video/generations",
		QueryEndpoint: "/v1/video/task/{taskId}",
		HTTPClient:    &http.Client{Timeout: 5 * time.Minute},
	}
	if endpoint != "" {
		client.Endpoint = endpoint
	}
	if queryEndpoint != "" {
		client.QueryEndpoint = queryEndpoint
	}
	return client
}
// GenerateVideo submits a video generation request to the configured
// endpoint and returns the task the API reports back.
//
// Options default to Duration 5 and AspectRatio "16:9"; each opt is applied
// in order on top of those defaults. A model set through the options takes
// precedence over the client's Model.
//
// An error is returned when the request cannot be built or sent, when the
// HTTP status is neither 200 nor 201, when the body is not valid JSON for
// ChatfireResponse, or when the response carries a non-empty error field.
// The returned result is marked Completed when the reported status is
// "completed" or "succeeded".
func (c *ChatfireClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
	// Defaults first; caller-supplied options override them.
	settings := &VideoOptions{Duration: 5, AspectRatio: "16:9"}
	for _, apply := range opts {
		apply(settings)
	}

	// A per-call model wins over the client-wide one.
	modelName := settings.Model
	if modelName == "" {
		modelName = c.Model
	}

	payload, err := json.Marshal(ChatfireRequest{
		Model:    modelName,
		Prompt:   prompt,
		ImageURL: imageURL,
		Duration: settings.Duration,
		Size:     settings.AspectRatio,
	})
	if err != nil {
		return nil, fmt.Errorf("marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, c.BaseURL+c.Endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.APIKey)

	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so it can be quoted in the
	// error message for a rejected request.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}

	switch resp.StatusCode {
	case http.StatusOK, http.StatusCreated:
		// Accepted; continue with decoding.
	default:
		return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, body)
	}

	var created ChatfireResponse
	if err := json.Unmarshal(body, &created); err != nil {
		return nil, fmt.Errorf("parse response: %w", err)
	}
	if created.Error != "" {
		return nil, fmt.Errorf("chatfire error: %s", created.Error)
	}

	finished := created.Status == "completed" || created.Status == "succeeded"
	return &VideoResult{
		TaskID:    created.TaskID,
		Status:    created.Status,
		Completed: finished,
		Duration:  settings.Duration,
	}, nil
}
// GetTaskStatus queries the state of a previously created task.
//
// The request path is derived from QueryEndpoint: a "{taskId}" placeholder
// is substituted if present, otherwise a "{task_id}" placeholder, and if
// neither is present the task ID is appended after a "/".
//
// An error is returned when taskID is empty, when the request cannot be
// built or sent, when the HTTP status is outside the 2xx range, or when the
// body is not valid JSON for ChatfireTaskResponse. An error message carried
// in a successful response is not turned into a Go error; it is copied into
// the result's Error field.
//
// The result is marked Completed when the reported status is "completed" or
// "succeeded", or when the response contains a video URL.
func (c *ChatfireClient) GetTaskStatus(taskID string) (*VideoResult, error) {
	// An empty ID can never address a task; fail before hitting the network.
	if taskID == "" {
		return nil, fmt.Errorf("task id is empty")
	}

	queryPath := c.QueryEndpoint
	switch {
	case strings.Contains(queryPath, "{taskId}"):
		queryPath = strings.ReplaceAll(queryPath, "{taskId}", taskID)
	case strings.Contains(queryPath, "{task_id}"):
		queryPath = strings.ReplaceAll(queryPath, "{task_id}", taskID)
	default:
		queryPath += "/" + taskID
	}

	req, err := http.NewRequest(http.MethodGet, c.BaseURL+queryPath, nil)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+c.APIKey)

	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}

	// http.Client.Do does not report non-2xx statuses as errors, so check
	// explicitly. Without this, an error page would surface as a confusing
	// "parse response" failure or as an empty, seemingly valid result.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
	}

	var result ChatfireTaskResponse
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("parse response: %w", err)
	}

	videoResult := &VideoResult{
		TaskID:    result.TaskID,
		Status:    result.Status,
		Completed: result.Status == "completed" || result.Status == "succeeded",
	}
	if result.Error != "" {
		videoResult.Error = result.Error
	}
	// A video URL implies the task finished, whatever status string was sent.
	if result.VideoURL != "" {
		videoResult.VideoURL = result.VideoURL
		videoResult.Completed = true
	}
	return videoResult, nil
}