fix: harden url fetch dns pinning

This commit is contained in:
Peter Steinberger
2026-01-26 16:05:20 +00:00
parent 8b68cdd9bc
commit b623557a2e
7 changed files with 429 additions and 198 deletions

View File

@@ -41,6 +41,7 @@ Status: unreleased.
- Security: harden Tailscale Serve auth by validating identity via local tailscaled before trusting headers. - Security: harden Tailscale Serve auth by validating identity via local tailscaled before trusting headers.
- Build: align memory-core peer dependency with lockfile. - Build: align memory-core peer dependency with lockfile.
- Security: add mDNS discovery mode with minimal default to reduce information disclosure. (#1882) Thanks @orlyjamie. - Security: add mDNS discovery mode with minimal default to reduce information disclosure. (#1882) Thanks @orlyjamie.
- Security: harden URL fetches with DNS pinning to reduce rebinding risk. Thanks Chris Zheng.
- Web UI: improve WebChat image paste previews and allow image-only sends. (#1925) Thanks @smartprogrammer93. - Web UI: improve WebChat image paste previews and allow image-only sends. (#1925) Thanks @smartprogrammer93.
- Security: wrap external hook content by default with a per-hook opt-out. (#1827) Thanks @mertcicekci0. - Security: wrap external hook content by default with a per-hook opt-out. (#1827) Thanks @mertcicekci0.
- Gateway: default auth now fail-closed (token/password required; Tailscale Serve identity remains allowed). - Gateway: default auth now fail-closed (token/password required; Tailscale Serve identity remains allowed).

View File

@@ -1,7 +1,13 @@
import { Type } from "@sinclair/typebox"; import { Type } from "@sinclair/typebox";
import type { ClawdbotConfig } from "../../config/config.js"; import type { ClawdbotConfig } from "../../config/config.js";
import { assertPublicHostname, SsrFBlockedError } from "../../infra/net/ssrf.js"; import {
closeDispatcher,
createPinnedDispatcher,
resolvePinnedHostname,
SsrFBlockedError,
} from "../../infra/net/ssrf.js";
import type { Dispatcher } from "undici";
import { stringEnum } from "../schema/typebox.js"; import { stringEnum } from "../schema/typebox.js";
import type { AnyAgentTool } from "./common.js"; import type { AnyAgentTool } from "./common.js";
import { jsonResult, readNumberParam, readStringParam } from "./common.js"; import { jsonResult, readNumberParam, readStringParam } from "./common.js";
@@ -167,7 +173,7 @@ async function fetchWithRedirects(params: {
maxRedirects: number; maxRedirects: number;
timeoutSeconds: number; timeoutSeconds: number;
userAgent: string; userAgent: string;
}): Promise<{ response: Response; finalUrl: string }> { }): Promise<{ response: Response; finalUrl: string; dispatcher: Dispatcher }> {
const signal = withTimeout(undefined, params.timeoutSeconds * 1000); const signal = withTimeout(undefined, params.timeoutSeconds * 1000);
const visited = new Set<string>(); const visited = new Set<string>();
let currentUrl = params.url; let currentUrl = params.url;
@@ -184,39 +190,50 @@ async function fetchWithRedirects(params: {
throw new Error("Invalid URL: must be http or https"); throw new Error("Invalid URL: must be http or https");
} }
await assertPublicHostname(parsedUrl.hostname); const pinned = await resolvePinnedHostname(parsedUrl.hostname);
const dispatcher = createPinnedDispatcher(pinned);
const res = await fetch(parsedUrl.toString(), { let res: Response;
method: "GET", try {
headers: { res = await fetch(parsedUrl.toString(), {
Accept: "*/*", method: "GET",
"User-Agent": params.userAgent, headers: {
"Accept-Language": "en-US,en;q=0.9", Accept: "*/*",
}, "User-Agent": params.userAgent,
signal, "Accept-Language": "en-US,en;q=0.9",
redirect: "manual", },
}); signal,
redirect: "manual",
dispatcher,
} as RequestInit);
} catch (err) {
await closeDispatcher(dispatcher);
throw err;
}
if (isRedirectStatus(res.status)) { if (isRedirectStatus(res.status)) {
const location = res.headers.get("location"); const location = res.headers.get("location");
if (!location) { if (!location) {
await closeDispatcher(dispatcher);
throw new Error(`Redirect missing location header (${res.status})`); throw new Error(`Redirect missing location header (${res.status})`);
} }
redirectCount += 1; redirectCount += 1;
if (redirectCount > params.maxRedirects) { if (redirectCount > params.maxRedirects) {
await closeDispatcher(dispatcher);
throw new Error(`Too many redirects (limit: ${params.maxRedirects})`); throw new Error(`Too many redirects (limit: ${params.maxRedirects})`);
} }
const nextUrl = new URL(location, parsedUrl).toString(); const nextUrl = new URL(location, parsedUrl).toString();
if (visited.has(nextUrl)) { if (visited.has(nextUrl)) {
await closeDispatcher(dispatcher);
throw new Error("Redirect loop detected"); throw new Error("Redirect loop detected");
} }
visited.add(nextUrl); visited.add(nextUrl);
void res.body?.cancel(); void res.body?.cancel();
await closeDispatcher(dispatcher);
currentUrl = nextUrl; currentUrl = nextUrl;
continue; continue;
} }
return { response: res, finalUrl: currentUrl }; return { response: res, finalUrl: currentUrl, dispatcher };
} }
} }
@@ -348,6 +365,7 @@ async function runWebFetch(params: {
const start = Date.now(); const start = Date.now();
let res: Response; let res: Response;
let dispatcher: Dispatcher | null = null;
let finalUrl = params.url; let finalUrl = params.url;
try { try {
const result = await fetchWithRedirects({ const result = await fetchWithRedirects({
@@ -358,6 +376,7 @@ async function runWebFetch(params: {
}); });
res = result.response; res = result.response;
finalUrl = result.finalUrl; finalUrl = result.finalUrl;
dispatcher = result.dispatcher;
} catch (error) { } catch (error) {
if (error instanceof SsrFBlockedError) { if (error instanceof SsrFBlockedError) {
throw error; throw error;
@@ -396,108 +415,112 @@ async function runWebFetch(params: {
throw error; throw error;
} }
if (!res.ok) { try {
if (params.firecrawlEnabled && params.firecrawlApiKey) { if (!res.ok) {
const firecrawl = await fetchFirecrawlContent({ if (params.firecrawlEnabled && params.firecrawlApiKey) {
url: params.url, const firecrawl = await fetchFirecrawlContent({
extractMode: params.extractMode, url: params.url,
apiKey: params.firecrawlApiKey, extractMode: params.extractMode,
baseUrl: params.firecrawlBaseUrl, apiKey: params.firecrawlApiKey,
onlyMainContent: params.firecrawlOnlyMainContent, baseUrl: params.firecrawlBaseUrl,
maxAgeMs: params.firecrawlMaxAgeMs, onlyMainContent: params.firecrawlOnlyMainContent,
proxy: params.firecrawlProxy, maxAgeMs: params.firecrawlMaxAgeMs,
storeInCache: params.firecrawlStoreInCache, proxy: params.firecrawlProxy,
timeoutSeconds: params.firecrawlTimeoutSeconds, storeInCache: params.firecrawlStoreInCache,
}); timeoutSeconds: params.firecrawlTimeoutSeconds,
const truncated = truncateText(firecrawl.text, params.maxChars); });
const payload = { const truncated = truncateText(firecrawl.text, params.maxChars);
url: params.url, const payload = {
finalUrl: firecrawl.finalUrl || finalUrl, url: params.url,
status: firecrawl.status ?? res.status, finalUrl: firecrawl.finalUrl || finalUrl,
contentType: "text/markdown", status: firecrawl.status ?? res.status,
title: firecrawl.title, contentType: "text/markdown",
extractMode: params.extractMode, title: firecrawl.title,
extractor: "firecrawl", extractMode: params.extractMode,
truncated: truncated.truncated, extractor: "firecrawl",
length: truncated.text.length, truncated: truncated.truncated,
fetchedAt: new Date().toISOString(), length: truncated.text.length,
tookMs: Date.now() - start, fetchedAt: new Date().toISOString(),
text: truncated.text, tookMs: Date.now() - start,
warning: firecrawl.warning, text: truncated.text,
}; warning: firecrawl.warning,
writeCache(FETCH_CACHE, cacheKey, payload, params.cacheTtlMs); };
return payload; writeCache(FETCH_CACHE, cacheKey, payload, params.cacheTtlMs);
} return payload;
const rawDetail = await readResponseText(res);
const detail = formatWebFetchErrorDetail({
detail: rawDetail,
contentType: res.headers.get("content-type"),
maxChars: DEFAULT_ERROR_MAX_CHARS,
});
throw new Error(`Web fetch failed (${res.status}): ${detail || res.statusText}`);
}
const contentType = res.headers.get("content-type") ?? "application/octet-stream";
const body = await readResponseText(res);
let title: string | undefined;
let extractor = "raw";
let text = body;
if (contentType.includes("text/html")) {
if (params.readabilityEnabled) {
const readable = await extractReadableContent({
html: body,
url: finalUrl,
extractMode: params.extractMode,
});
if (readable?.text) {
text = readable.text;
title = readable.title;
extractor = "readability";
} else {
const firecrawl = await tryFirecrawlFallback({ ...params, url: finalUrl });
if (firecrawl) {
text = firecrawl.text;
title = firecrawl.title;
extractor = "firecrawl";
} else {
throw new Error(
"Web fetch extraction failed: Readability and Firecrawl returned no content.",
);
}
} }
} else { const rawDetail = await readResponseText(res);
throw new Error( const detail = formatWebFetchErrorDetail({
"Web fetch extraction failed: Readability disabled and Firecrawl unavailable.", detail: rawDetail,
); contentType: res.headers.get("content-type"),
maxChars: DEFAULT_ERROR_MAX_CHARS,
});
throw new Error(`Web fetch failed (${res.status}): ${detail || res.statusText}`);
} }
} else if (contentType.includes("application/json")) {
try {
text = JSON.stringify(JSON.parse(body), null, 2);
extractor = "json";
} catch {
text = body;
extractor = "raw";
}
}
const truncated = truncateText(text, params.maxChars); const contentType = res.headers.get("content-type") ?? "application/octet-stream";
const payload = { const body = await readResponseText(res);
url: params.url,
finalUrl, let title: string | undefined;
status: res.status, let extractor = "raw";
contentType, let text = body;
title, if (contentType.includes("text/html")) {
extractMode: params.extractMode, if (params.readabilityEnabled) {
extractor, const readable = await extractReadableContent({
truncated: truncated.truncated, html: body,
length: truncated.text.length, url: finalUrl,
fetchedAt: new Date().toISOString(), extractMode: params.extractMode,
tookMs: Date.now() - start, });
text: truncated.text, if (readable?.text) {
}; text = readable.text;
writeCache(FETCH_CACHE, cacheKey, payload, params.cacheTtlMs); title = readable.title;
return payload; extractor = "readability";
} else {
const firecrawl = await tryFirecrawlFallback({ ...params, url: finalUrl });
if (firecrawl) {
text = firecrawl.text;
title = firecrawl.title;
extractor = "firecrawl";
} else {
throw new Error(
"Web fetch extraction failed: Readability and Firecrawl returned no content.",
);
}
}
} else {
throw new Error(
"Web fetch extraction failed: Readability disabled and Firecrawl unavailable.",
);
}
} else if (contentType.includes("application/json")) {
try {
text = JSON.stringify(JSON.parse(body), null, 2);
extractor = "json";
} catch {
text = body;
extractor = "raw";
}
}
const truncated = truncateText(text, params.maxChars);
const payload = {
url: params.url,
finalUrl,
status: res.status,
contentType,
title,
extractMode: params.extractMode,
extractor,
truncated: truncated.truncated,
length: truncated.text.length,
fetchedAt: new Date().toISOString(),
tookMs: Date.now() - start,
text: truncated.text,
};
writeCache(FETCH_CACHE, cacheKey, payload, params.cacheTtlMs);
return payload;
} finally {
await closeDispatcher(dispatcher);
}
} }
async function tryFirecrawlFallback(params: { async function tryFirecrawlFallback(params: {

View File

@@ -0,0 +1,63 @@
import { describe, expect, it, vi } from "vitest";
import { createPinnedLookup, resolvePinnedHostname } from "./ssrf.js";
describe("ssrf pinning", () => {
it("pins resolved addresses for the target hostname", async () => {
const lookup = vi.fn(async () => [
{ address: "93.184.216.34", family: 4 },
{ address: "93.184.216.35", family: 4 },
]);
const pinned = await resolvePinnedHostname("Example.com.", lookup);
expect(pinned.hostname).toBe("example.com");
expect(pinned.addresses).toEqual(["93.184.216.34", "93.184.216.35"]);
const first = await new Promise<{ address: string; family?: number }>((resolve, reject) => {
pinned.lookup("example.com", (err, address, family) => {
if (err) reject(err);
else resolve({ address: address as string, family });
});
});
expect(first.address).toBe("93.184.216.34");
expect(first.family).toBe(4);
const all = await new Promise<unknown>((resolve, reject) => {
pinned.lookup("example.com", { all: true }, (err, addresses) => {
if (err) reject(err);
else resolve(addresses);
});
});
expect(Array.isArray(all)).toBe(true);
expect((all as Array<{ address: string }>).map((entry) => entry.address)).toEqual(
pinned.addresses,
);
});
it("rejects private DNS results", async () => {
const lookup = vi.fn(async () => [{ address: "10.0.0.8", family: 4 }]);
await expect(resolvePinnedHostname("example.com", lookup)).rejects.toThrow(/private|internal/i);
});
it("falls back for non-matching hostnames", async () => {
const fallback = vi.fn((host: string, options?: unknown, callback?: unknown) => {
const cb = typeof options === "function" ? options : (callback as () => void);
(cb as (err: null, address: string, family: number) => void)(null, "1.2.3.4", 4);
});
const lookup = createPinnedLookup({
hostname: "example.com",
addresses: ["93.184.216.34"],
fallback,
});
const result = await new Promise<{ address: string }>((resolve, reject) => {
lookup("other.test", (err, address) => {
if (err) reject(err);
else resolve({ address: address as string });
});
});
expect(fallback).toHaveBeenCalledTimes(1);
expect(result.address).toBe("1.2.3.4");
});
});

View File

@@ -1,4 +1,12 @@
import { lookup as dnsLookup } from "node:dns/promises"; import { lookup as dnsLookup } from "node:dns/promises";
import { lookup as dnsLookupCb, type LookupAddress } from "node:dns";
import { Agent, type Dispatcher } from "undici";
type LookupCallback = (
err: NodeJS.ErrnoException | null,
address: string | LookupAddress[],
family?: number,
) => void;
export class SsrFBlockedError extends Error { export class SsrFBlockedError extends Error {
constructor(message: string) { constructor(message: string) {
@@ -101,10 +109,71 @@ export function isBlockedHostname(hostname: string): boolean {
); );
} }
export async function assertPublicHostname( export function createPinnedLookup(params: {
hostname: string;
addresses: string[];
fallback?: typeof dnsLookupCb;
}): typeof dnsLookupCb {
const normalizedHost = normalizeHostname(params.hostname);
const fallback = params.fallback ?? dnsLookupCb;
const fallbackLookup = fallback as unknown as (
hostname: string,
callback: LookupCallback,
) => void;
const fallbackWithOptions = fallback as unknown as (
hostname: string,
options: unknown,
callback: LookupCallback,
) => void;
const records = params.addresses.map((address) => ({
address,
family: address.includes(":") ? 6 : 4,
}));
let index = 0;
return ((host: string, options?: unknown, callback?: unknown) => {
const cb: LookupCallback =
typeof options === "function" ? (options as LookupCallback) : (callback as LookupCallback);
if (!cb) return;
const normalized = normalizeHostname(host);
if (!normalized || normalized !== normalizedHost) {
if (typeof options === "function" || options === undefined) {
return fallbackLookup(host, cb);
}
return fallbackWithOptions(host, options, cb);
}
const opts =
typeof options === "object" && options !== null
? (options as { all?: boolean; family?: number })
: {};
const requestedFamily =
typeof options === "number" ? options : typeof opts.family === "number" ? opts.family : 0;
const candidates =
requestedFamily === 4 || requestedFamily === 6
? records.filter((entry) => entry.family === requestedFamily)
: records;
const usable = candidates.length > 0 ? candidates : records;
if (opts.all) {
cb(null, usable as LookupAddress[]);
return;
}
const chosen = usable[index % usable.length];
index += 1;
cb(null, chosen.address, chosen.family);
}) as typeof dnsLookupCb;
}
export type PinnedHostname = {
hostname: string;
addresses: string[];
lookup: typeof dnsLookupCb;
};
export async function resolvePinnedHostname(
hostname: string, hostname: string,
lookupFn: LookupFn = dnsLookup, lookupFn: LookupFn = dnsLookup,
): Promise<void> { ): Promise<PinnedHostname> {
const normalized = normalizeHostname(hostname); const normalized = normalizeHostname(hostname);
if (!normalized) { if (!normalized) {
throw new Error("Invalid hostname"); throw new Error("Invalid hostname");
@@ -128,4 +197,46 @@ export async function assertPublicHostname(
throw new SsrFBlockedError("Blocked: resolves to private/internal IP address"); throw new SsrFBlockedError("Blocked: resolves to private/internal IP address");
} }
} }
const addresses = Array.from(new Set(results.map((entry) => entry.address)));
if (addresses.length === 0) {
throw new Error(`Unable to resolve hostname: ${hostname}`);
}
return {
hostname: normalized,
addresses,
lookup: createPinnedLookup({ hostname: normalized, addresses }),
};
}
export function createPinnedDispatcher(pinned: PinnedHostname): Dispatcher {
return new Agent({
connect: {
lookup: pinned.lookup,
},
});
}
export async function closeDispatcher(dispatcher?: Dispatcher | null): Promise<void> {
if (!dispatcher) return;
const candidate = dispatcher as { close?: () => Promise<void> | void; destroy?: () => void };
try {
if (typeof candidate.close === "function") {
await candidate.close();
return;
}
if (typeof candidate.destroy === "function") {
candidate.destroy();
}
} catch {
// ignore dispatcher cleanup errors
}
}
export async function assertPublicHostname(
hostname: string,
lookupFn: LookupFn = dnsLookup,
): Promise<void> {
await resolvePinnedHostname(hostname, lookupFn);
} }

View File

@@ -1,5 +1,10 @@
import { logWarn } from "../logger.js"; import { logWarn } from "../logger.js";
import { assertPublicHostname } from "../infra/net/ssrf.js"; import {
closeDispatcher,
createPinnedDispatcher,
resolvePinnedHostname,
} from "../infra/net/ssrf.js";
import type { Dispatcher } from "undici";
type CanvasModule = typeof import("@napi-rs/canvas"); type CanvasModule = typeof import("@napi-rs/canvas");
type PdfJsModule = typeof import("pdfjs-dist/legacy/build/pdf.mjs"); type PdfJsModule = typeof import("pdfjs-dist/legacy/build/pdf.mjs");
@@ -154,50 +159,57 @@ export async function fetchWithGuard(params: {
if (!["http:", "https:"].includes(parsedUrl.protocol)) { if (!["http:", "https:"].includes(parsedUrl.protocol)) {
throw new Error(`Invalid URL protocol: ${parsedUrl.protocol}. Only HTTP/HTTPS allowed.`); throw new Error(`Invalid URL protocol: ${parsedUrl.protocol}. Only HTTP/HTTPS allowed.`);
} }
await assertPublicHostname(parsedUrl.hostname); const pinned = await resolvePinnedHostname(parsedUrl.hostname);
const dispatcher = createPinnedDispatcher(pinned);
const response = await fetch(parsedUrl, { try {
signal: controller.signal, const response = await fetch(parsedUrl, {
headers: { "User-Agent": "Clawdbot-Gateway/1.0" }, signal: controller.signal,
redirect: "manual", headers: { "User-Agent": "Clawdbot-Gateway/1.0" },
}); redirect: "manual",
dispatcher,
} as RequestInit & { dispatcher: Dispatcher });
if (isRedirectStatus(response.status)) { if (isRedirectStatus(response.status)) {
const location = response.headers.get("location"); const location = response.headers.get("location");
if (!location) { if (!location) {
throw new Error(`Redirect missing location header (${response.status})`); throw new Error(`Redirect missing location header (${response.status})`);
}
redirectCount += 1;
if (redirectCount > params.maxRedirects) {
throw new Error(`Too many redirects (limit: ${params.maxRedirects})`);
}
void response.body?.cancel();
currentUrl = new URL(location, parsedUrl).toString();
continue;
} }
redirectCount += 1;
if (redirectCount > params.maxRedirects) { if (!response.ok) {
throw new Error(`Too many redirects (limit: ${params.maxRedirects})`); throw new Error(`Failed to fetch: ${response.status} ${response.statusText}`);
} }
currentUrl = new URL(location, parsedUrl).toString();
continue;
}
if (!response.ok) { const contentLength = response.headers.get("content-length");
throw new Error(`Failed to fetch: ${response.status} ${response.statusText}`); if (contentLength) {
} const size = parseInt(contentLength, 10);
if (size > params.maxBytes) {
const contentLength = response.headers.get("content-length"); throw new Error(`Content too large: ${size} bytes (limit: ${params.maxBytes} bytes)`);
if (contentLength) { }
const size = parseInt(contentLength, 10);
if (size > params.maxBytes) {
throw new Error(`Content too large: ${size} bytes (limit: ${params.maxBytes} bytes)`);
} }
}
const buffer = Buffer.from(await response.arrayBuffer()); const buffer = Buffer.from(await response.arrayBuffer());
if (buffer.byteLength > params.maxBytes) { if (buffer.byteLength > params.maxBytes) {
throw new Error( throw new Error(
`Content too large: ${buffer.byteLength} bytes (limit: ${params.maxBytes} bytes)`, `Content too large: ${buffer.byteLength} bytes (limit: ${params.maxBytes} bytes)`,
); );
} }
const contentType = response.headers.get("content-type") || undefined; const contentType = response.headers.get("content-type") || undefined;
const parsed = parseContentType(contentType); const parsed = parseContentType(contentType);
const mimeType = parsed.mimeType ?? "application/octet-stream"; const mimeType = parsed.mimeType ?? "application/octet-stream";
return { buffer, mimeType, contentType }; return { buffer, mimeType, contentType };
} finally {
await closeDispatcher(dispatcher);
}
} }
} finally { } finally {
clearTimeout(timeoutId); clearTimeout(timeoutId);

View File

@@ -18,6 +18,9 @@ vi.doMock("node:os", () => ({
vi.doMock("node:https", () => ({ vi.doMock("node:https", () => ({
request: (...args: unknown[]) => mockRequest(...args), request: (...args: unknown[]) => mockRequest(...args),
})); }));
vi.doMock("node:dns/promises", () => ({
lookup: async () => [{ address: "93.184.216.34", family: 4 }],
}));
const loadStore = async () => await import("./store.js"); const loadStore = async () => await import("./store.js");

View File

@@ -1,10 +1,12 @@
import crypto from "node:crypto"; import crypto from "node:crypto";
import { createWriteStream } from "node:fs"; import { createWriteStream } from "node:fs";
import fs from "node:fs/promises"; import fs from "node:fs/promises";
import { request } from "node:https"; import { request as httpRequest } from "node:http";
import { request as httpsRequest } from "node:https";
import path from "node:path"; import path from "node:path";
import { pipeline } from "node:stream/promises"; import { pipeline } from "node:stream/promises";
import { resolveConfigDir } from "../utils.js"; import { resolveConfigDir } from "../utils.js";
import { resolvePinnedHostname } from "../infra/net/ssrf.js";
import { detectMime, extensionForMime } from "./mime.js"; import { detectMime, extensionForMime } from "./mime.js";
const resolveMediaDir = () => path.join(resolveConfigDir(), "media"); const resolveMediaDir = () => path.join(resolveConfigDir(), "media");
@@ -88,51 +90,67 @@ async function downloadToFile(
maxRedirects = 5, maxRedirects = 5,
): Promise<{ headerMime?: string; sniffBuffer: Buffer; size: number }> { ): Promise<{ headerMime?: string; sniffBuffer: Buffer; size: number }> {
return await new Promise((resolve, reject) => { return await new Promise((resolve, reject) => {
const req = request(url, { headers }, (res) => { let parsedUrl: URL;
// Follow redirects try {
if (res.statusCode && res.statusCode >= 300 && res.statusCode < 400) { parsedUrl = new URL(url);
const location = res.headers.location; } catch {
if (!location || maxRedirects <= 0) { reject(new Error("Invalid URL"));
reject(new Error(`Redirect loop or missing Location header`)); return;
return; }
} if (!["http:", "https:"].includes(parsedUrl.protocol)) {
const redirectUrl = new URL(location, url).href; reject(new Error(`Invalid URL protocol: ${parsedUrl.protocol}. Only HTTP/HTTPS allowed.`));
resolve(downloadToFile(redirectUrl, dest, headers, maxRedirects - 1)); return;
return; }
} const requestImpl = parsedUrl.protocol === "https:" ? httpsRequest : httpRequest;
if (!res.statusCode || res.statusCode >= 400) { resolvePinnedHostname(parsedUrl.hostname)
reject(new Error(`HTTP ${res.statusCode ?? "?"} downloading media`)); .then((pinned) => {
return; const req = requestImpl(parsedUrl, { headers, lookup: pinned.lookup }, (res) => {
} // Follow redirects
let total = 0; if (res.statusCode && res.statusCode >= 300 && res.statusCode < 400) {
const sniffChunks: Buffer[] = []; const location = res.headers.location;
let sniffLen = 0; if (!location || maxRedirects <= 0) {
const out = createWriteStream(dest); reject(new Error(`Redirect loop or missing Location header`));
res.on("data", (chunk) => { return;
total += chunk.length; }
if (sniffLen < 16384) { const redirectUrl = new URL(location, url).href;
sniffChunks.push(chunk); resolve(downloadToFile(redirectUrl, dest, headers, maxRedirects - 1));
sniffLen += chunk.length; return;
} }
if (total > MAX_BYTES) { if (!res.statusCode || res.statusCode >= 400) {
req.destroy(new Error("Media exceeds 5MB limit")); reject(new Error(`HTTP ${res.statusCode ?? "?"} downloading media`));
} return;
}); }
pipeline(res, out) let total = 0;
.then(() => { const sniffChunks: Buffer[] = [];
const sniffBuffer = Buffer.concat(sniffChunks, Math.min(sniffLen, 16384)); let sniffLen = 0;
const rawHeader = res.headers["content-type"]; const out = createWriteStream(dest);
const headerMime = Array.isArray(rawHeader) ? rawHeader[0] : rawHeader; res.on("data", (chunk) => {
resolve({ total += chunk.length;
headerMime, if (sniffLen < 16384) {
sniffBuffer, sniffChunks.push(chunk);
size: total, sniffLen += chunk.length;
}
if (total > MAX_BYTES) {
req.destroy(new Error("Media exceeds 5MB limit"));
}
}); });
}) pipeline(res, out)
.catch(reject); .then(() => {
}); const sniffBuffer = Buffer.concat(sniffChunks, Math.min(sniffLen, 16384));
req.on("error", reject); const rawHeader = res.headers["content-type"];
req.end(); const headerMime = Array.isArray(rawHeader) ? rawHeader[0] : rawHeader;
resolve({
headerMime,
sniffBuffer,
size: total,
});
})
.catch(reject);
});
req.on("error", reject);
req.end();
})
.catch(reject);
}); });
} }