Merge pull request #2 from itzhan/main

添加claudeCode的count路径
This commit is contained in:
1e0n
2025-10-13 19:03:31 +08:00
committed by GitHub
3 changed files with 94 additions and 6 deletions

4
package-lock.json generated
View File

@@ -1,12 +1,12 @@
{ {
"name": "droid2api", "name": "droid2api",
"version": "1.0.0", "version": "1.3.1",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "droid2api", "name": "droid2api",
"version": "1.0.0", "version": "1.3.1",
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
"express": "^4.18.2", "express": "^4.18.2",

View File

@@ -508,9 +508,94 @@ async function handleDirectMessages(req, res) {
} }
} }
// 处理 Anthropic count_tokens 请求
async function handleCountTokens(req, res) {
logInfo('POST /v1/messages/count_tokens');
try {
const anthropicRequest = req.body;
const modelId = anthropicRequest.model;
if (!modelId) {
return res.status(400).json({ error: 'model is required' });
}
const model = getModelById(modelId);
if (!model) {
return res.status(404).json({ error: `Model ${modelId} not found` });
}
// 只允许 anthropic 类型端点
if (model.type !== 'anthropic') {
return res.status(400).json({
error: 'Invalid endpoint type',
message: `/v1/messages/count_tokens 接口只支持 anthropic 类型端点,当前模型 ${modelId}${model.type} 类型`
});
}
const endpoint = getEndpointByType('anthropic');
if (!endpoint) {
return res.status(500).json({ error: 'Endpoint type anthropic not found' });
}
// Get API key
let authHeader;
try {
const clientAuthFromXApiKey = req.headers['x-api-key']
? `Bearer ${req.headers['x-api-key']}`
: null;
authHeader = await getApiKey(req.headers.authorization || clientAuthFromXApiKey);
} catch (error) {
logError('Failed to get API key', error);
return res.status(500).json({
error: 'API key not available',
message: 'Failed to get or refresh API key. Please check server logs.'
});
}
const clientHeaders = req.headers;
const headers = getAnthropicHeaders(authHeader, clientHeaders, false, modelId);
// 构建 count_tokens 端点 URL
const countTokensUrl = endpoint.base_url.replace('/v1/messages', '/v1/messages/count_tokens');
logInfo(`Forwarding to count_tokens endpoint: ${countTokensUrl}`);
logRequest('POST', countTokensUrl, headers, anthropicRequest);
const response = await fetch(countTokensUrl, {
method: 'POST',
headers,
body: JSON.stringify(anthropicRequest)
});
logInfo(`Response status: ${response.status}`);
if (!response.ok) {
const errorText = await response.text();
logError(`Count tokens error: ${response.status}`, new Error(errorText));
return res.status(response.status).json({
error: `Endpoint returned ${response.status}`,
details: errorText
});
}
const data = await response.json();
logResponse(200, null, data);
res.json(data);
} catch (error) {
logError('Error in /v1/messages/count_tokens', error);
res.status(500).json({
error: 'Internal server error',
message: error.message
});
}
}
// 注册路由 // 注册路由
router.post('/v1/chat/completions', handleChatCompletions); router.post('/v1/chat/completions', handleChatCompletions);
router.post('/v1/responses', handleDirectResponses); router.post('/v1/responses', handleDirectResponses);
router.post('/v1/messages', handleDirectMessages); router.post('/v1/messages', handleDirectMessages);
router.post('/v1/messages/count_tokens', handleCountTokens);
export default router; export default router;

View File

@@ -31,7 +31,8 @@ app.get('/', (req, res) => {
'GET /v1/models', 'GET /v1/models',
'POST /v1/chat/completions', 'POST /v1/chat/completions',
'POST /v1/responses', 'POST /v1/responses',
'POST /v1/messages' 'POST /v1/messages',
'POST /v1/messages/count_tokens'
] ]
}); });
}); });
@@ -90,7 +91,8 @@ app.use((req, res, next) => {
'GET /v1/models', 'GET /v1/models',
'POST /v1/chat/completions', 'POST /v1/chat/completions',
'POST /v1/responses', 'POST /v1/responses',
'POST /v1/messages' 'POST /v1/messages',
'POST /v1/messages/count_tokens'
] ]
}); });
}); });
@@ -125,6 +127,7 @@ app.use((err, req, res, next) => {
logInfo(' POST /v1/chat/completions'); logInfo(' POST /v1/chat/completions');
logInfo(' POST /v1/responses'); logInfo(' POST /v1/responses');
logInfo(' POST /v1/messages'); logInfo(' POST /v1/messages');
logInfo(' POST /v1/messages/count_tokens');
}) })
.on('error', (err) => { .on('error', (err) => {
if (err.code === 'EADDRINUSE') { if (err.code === 'EADDRINUSE') {