Merge remote-tracking branch 'parent/main' into upstream-20240828

This commit is contained in:
KMY 2024-08-28 20:26:35 +09:00
commit b39054ff3c
136 changed files with 1803 additions and 977 deletions
streaming

View file

@ -8,15 +8,14 @@ import url from 'node:url';
import cors from 'cors';
import dotenv from 'dotenv';
import express from 'express';
import { Redis } from 'ioredis';
import { JSDOM } from 'jsdom';
import pg from 'pg';
import pgConnectionString from 'pg-connection-string';
import { WebSocketServer } from 'ws';
import * as Database from './database.js';
import { AuthenticationError, RequestError, extractStatusAndMessage as extractErrorStatusAndMessage } from './errors.js';
import { logger, httpLogger, initializeLogLevel, attachWebsocketHttpLogger, createWebsocketLogger } from './logging.js';
import { setupMetrics } from './metrics.js';
import * as Redis from './redis.js';
import { isTruthy, normalizeHashtag, firstParam } from './utils.js';
const environment = process.env.NODE_ENV || 'development';
@ -48,23 +47,6 @@ initializeLogLevel(process.env, environment);
* @property {string} deviceId
*/
/**
* @param {RedisConfiguration} config
* @returns {Promise<Redis>}
*/
const createRedisClient = async ({ redisParams, redisUrl }) => {
let client;
if (typeof redisUrl === 'string') {
client = new Redis(redisUrl, redisParams);
} else {
client = new Redis(redisParams);
}
client.on('error', (err) => logger.error({ err }, 'Redis Client Error!'));
return client;
};
/**
* Attempts to safely parse a string as JSON, used when both receiving a message
@ -97,177 +79,6 @@ const parseJSON = (json, req) => {
}
};
/**
* Takes an environment variable that should be an integer, attempts to parse
* it falling back to a default if not set, and handles errors parsing.
* @param {string|undefined} value
* @param {number} defaultValue
* @param {string} variableName
* @returns {number}
*/
const parseIntFromEnv = (value, defaultValue, variableName) => {
if (typeof value === 'string' && value.length > 0) {
const parsedValue = parseInt(value, 10);
if (isNaN(parsedValue)) {
throw new Error(`Invalid ${variableName} environment variable: ${value}`);
}
return parsedValue;
} else {
return defaultValue;
}
};
/**
* @param {NodeJS.ProcessEnv} env the `process.env` value to read configuration from
* @returns {pg.PoolConfig} the configuration for the PostgreSQL connection
*/
const pgConfigFromEnv = (env) => {
/** @type {Record<string, pg.PoolConfig>} */
const pgConfigs = {
development: {
user: env.DB_USER || pg.defaults.user,
password: env.DB_PASS || pg.defaults.password,
database: env.DB_NAME || 'mastodon_development',
host: env.DB_HOST || pg.defaults.host,
port: parseIntFromEnv(env.DB_PORT, pg.defaults.port ?? 5432, 'DB_PORT')
},
production: {
user: env.DB_USER || 'mastodon',
password: env.DB_PASS || '',
database: env.DB_NAME || 'mastodon_production',
host: env.DB_HOST || 'localhost',
port: parseIntFromEnv(env.DB_PORT, 5432, 'DB_PORT')
},
};
/**
* @type {pg.PoolConfig}
*/
let baseConfig = {};
if (env.DATABASE_URL) {
const parsedUrl = pgConnectionString.parse(env.DATABASE_URL);
// The result of dbUrlToConfig from pg-connection-string is not type
// compatible with pg.PoolConfig, since parts of the connection URL may be
// `null` when pg.PoolConfig expects `undefined`, as such we have to
// manually create the baseConfig object from the properties of the
// parsedUrl.
//
// For more information see:
// https://github.com/brianc/node-postgres/issues/2280
//
// FIXME: clean up once brianc/node-postgres#3128 lands
if (typeof parsedUrl.password === 'string') baseConfig.password = parsedUrl.password;
if (typeof parsedUrl.host === 'string') baseConfig.host = parsedUrl.host;
if (typeof parsedUrl.user === 'string') baseConfig.user = parsedUrl.user;
if (typeof parsedUrl.port === 'string') {
const parsedPort = parseInt(parsedUrl.port, 10);
if (isNaN(parsedPort)) {
throw new Error('Invalid port specified in DATABASE_URL environment variable');
}
baseConfig.port = parsedPort;
}
if (typeof parsedUrl.database === 'string') baseConfig.database = parsedUrl.database;
if (typeof parsedUrl.options === 'string') baseConfig.options = parsedUrl.options;
// The pg-connection-string type definition isn't correct, as parsedUrl.ssl
// can absolutely be an Object, this is to work around these incorrect
// types, including the casting of parsedUrl.ssl to Record<string, any>
if (typeof parsedUrl.ssl === 'boolean') {
baseConfig.ssl = parsedUrl.ssl;
} else if (typeof parsedUrl.ssl === 'object' && !Array.isArray(parsedUrl.ssl) && parsedUrl.ssl !== null) {
/** @type {Record<string, any>} */
const sslOptions = parsedUrl.ssl;
baseConfig.ssl = {};
baseConfig.ssl.cert = sslOptions.cert;
baseConfig.ssl.key = sslOptions.key;
baseConfig.ssl.ca = sslOptions.ca;
baseConfig.ssl.rejectUnauthorized = sslOptions.rejectUnauthorized;
}
// Support overriding the database password in the connection URL
if (!baseConfig.password && env.DB_PASS) {
baseConfig.password = env.DB_PASS;
}
} else if (Object.hasOwn(pgConfigs, environment)) {
baseConfig = pgConfigs[environment];
if (env.DB_SSLMODE) {
switch(env.DB_SSLMODE) {
case 'disable':
case '':
baseConfig.ssl = false;
break;
case 'no-verify':
baseConfig.ssl = { rejectUnauthorized: false };
break;
default:
baseConfig.ssl = {};
break;
}
}
} else {
throw new Error('Unable to resolve postgresql database configuration.');
}
return {
...baseConfig,
max: parseIntFromEnv(env.DB_POOL, 10, 'DB_POOL'),
connectionTimeoutMillis: 15000,
// Deliberately set application_name to an empty string to prevent excessive
// CPU usage with PG Bouncer. See:
// - https://github.com/mastodon/mastodon/pull/23958
// - https://github.com/pgbouncer/pgbouncer/issues/349
application_name: '',
};
};
/**
* @typedef RedisConfiguration
* @property {import('ioredis').RedisOptions} redisParams
* @property {string} redisPrefix
* @property {string|undefined} redisUrl
*/
/**
* @param {NodeJS.ProcessEnv} env the `process.env` value to read configuration from
* @returns {RedisConfiguration} configuration for the Redis connection
*/
const redisConfigFromEnv = (env) => {
// ioredis *can* transparently add prefixes for us, but it doesn't *in some cases*,
// which means we can't use it. But this is something that should be looked into.
const redisPrefix = env.REDIS_NAMESPACE ? `${env.REDIS_NAMESPACE}:` : '';
let redisPort = parseIntFromEnv(env.REDIS_PORT, 6379, 'REDIS_PORT');
let redisDatabase = parseIntFromEnv(env.REDIS_DB, 0, 'REDIS_DB');
/** @type {import('ioredis').RedisOptions} */
const redisParams = {
host: env.REDIS_HOST || '127.0.0.1',
port: redisPort,
// Force support for both IPv6 and IPv4, by default ioredis sets this to 4,
// only allowing IPv4 connections:
// https://github.com/redis/ioredis/issues/1576
family: 0,
db: redisDatabase,
password: env.REDIS_PASSWORD || undefined,
};
// redisParams.path takes precedence over host and port.
if (env.REDIS_URL && env.REDIS_URL.startsWith('unix://')) {
redisParams.path = env.REDIS_URL.slice(7);
}
return {
redisParams,
redisPrefix,
redisUrl: typeof env.REDIS_URL === 'string' ? env.REDIS_URL : undefined,
};
};
const PUBLIC_CHANNELS = [
'public',
'public:media',
@ -291,7 +102,12 @@ const CHANNEL_NAMES = [
];
const startServer = async () => {
const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
const pgPool = Database.getPool(Database.configFromEnv(process.env, environment));
const metrics = setupMetrics(CHANNEL_NAMES, pgPool);
const redisConfig = Redis.configFromEnv(process.env);
const redisClient = Redis.createClient(redisConfig, logger);
const server = http.createServer();
const wss = new WebSocketServer({ noServer: true });
@ -383,21 +199,9 @@ const startServer = async () => {
*/
const subs = {};
const redisConfig = redisConfigFromEnv(process.env);
const redisSubscribeClient = await createRedisClient(redisConfig);
const redisClient = await createRedisClient(redisConfig);
const redisSubscribeClient = Redis.createClient(redisConfig, logger);
const { redisPrefix } = redisConfig;
const metrics = setupMetrics(CHANNEL_NAMES, pgPool);
// TODO: migrate all metrics to metrics.X.method() instead of just X.method()
const {
connectedClients,
connectedChannels,
redisSubscriptions,
redisMessagesReceived,
messagesSent,
} = metrics;
// When checking metrics in the browser, the favicon is requested this
// prevents the request from falling through to the API Router, which would
// error for this endpoint:
@ -408,15 +212,7 @@ const startServer = async () => {
res.end('OK');
});
app.get('/metrics', async (req, res) => {
try {
res.set('Content-Type', metrics.register.contentType);
res.end(await metrics.register.metrics());
} catch (ex) {
req.log.error(ex);
res.status(500).end();
}
});
app.get('/metrics', metrics.requestHandler);
/**
* @param {string[]} channels
@ -443,7 +239,7 @@ const startServer = async () => {
* @param {string} message
*/
const onRedisMessage = (channel, message) => {
redisMessagesReceived.inc();
metrics.redisMessagesReceived.inc();
const callbacks = subs[channel];
@ -481,7 +277,7 @@ const startServer = async () => {
if (err) {
logger.error(`Error subscribing to ${channel}`);
} else if (typeof count === 'number') {
redisSubscriptions.set(count);
metrics.redisSubscriptions.set(count);
}
});
}
@ -508,7 +304,7 @@ const startServer = async () => {
if (err) {
logger.error(`Error unsubscribing to ${channel}`);
} else if (typeof count === 'number') {
redisSubscriptions.set(count);
metrics.redisSubscriptions.set(count);
}
});
delete subs[channel];
@ -690,13 +486,13 @@ const startServer = async () => {
unsubscribe(`${redisPrefix}${accessTokenChannelId}`, listener);
unsubscribe(`${redisPrefix}${systemChannelId}`, listener);
connectedChannels.labels({ type: 'eventsource', channel: 'system' }).dec(2);
metrics.connectedChannels.labels({ type: 'eventsource', channel: 'system' }).dec(2);
});
subscribe(`${redisPrefix}${accessTokenChannelId}`, listener);
subscribe(`${redisPrefix}${systemChannelId}`, listener);
connectedChannels.labels({ type: 'eventsource', channel: 'system' }).inc(2);
metrics.connectedChannels.labels({ type: 'eventsource', channel: 'system' }).inc(2);
};
/**
@ -820,7 +616,7 @@ const startServer = async () => {
// TODO: Replace "string"-based delete payloads with object payloads:
const encodedPayload = typeof payload === 'object' ? JSON.stringify(payload) : payload;
messagesSent.labels({ type: destinationType }).inc(1);
metrics.messagesSent.labels({ type: destinationType }).inc(1);
log.debug({ event, payload }, `Transmitting ${event} to ${req.accountId}`);
@ -1087,11 +883,11 @@ const startServer = async () => {
const streamToHttp = (req, res) => {
const channelName = channelNameFromPath(req);
connectedClients.labels({ type: 'eventsource' }).inc();
metrics.connectedClients.labels({ type: 'eventsource' }).inc();
// In theory we'll always have a channel name, but channelNameFromPath can return undefined:
if (typeof channelName === 'string') {
connectedChannels.labels({ type: 'eventsource', channel: channelName }).inc();
metrics.connectedChannels.labels({ type: 'eventsource', channel: channelName }).inc();
}
res.setHeader('Content-Type', 'text/event-stream');
@ -1107,10 +903,10 @@ const startServer = async () => {
// We decrement these counters here instead of in streamHttpEnd as in that
// method we don't have knowledge of the channel names
connectedClients.labels({ type: 'eventsource' }).dec();
metrics.connectedClients.labels({ type: 'eventsource' }).dec();
// In theory we'll always have a channel name, but channelNameFromPath can return undefined:
if (typeof channelName === 'string') {
connectedChannels.labels({ type: 'eventsource', channel: channelName }).dec();
metrics.connectedChannels.labels({ type: 'eventsource', channel: channelName }).dec();
}
clearInterval(heartbeat);
@ -1399,7 +1195,7 @@ const startServer = async () => {
const stopHeartbeat = subscriptionHeartbeat(channelIds);
const listener = streamFrom(channelIds, request, logger, onSend, undefined, 'websocket', options.needsFiltering);
connectedChannels.labels({ type: 'websocket', channel: channelName }).inc();
metrics.connectedChannels.labels({ type: 'websocket', channel: channelName }).inc();
subscriptions[channelIds.join(';')] = {
channelName,
@ -1438,7 +1234,7 @@ const startServer = async () => {
unsubscribe(`${redisPrefix}${channelId}`, subscription.listener);
});
connectedChannels.labels({ type: 'websocket', channel: subscription.channelName }).dec();
metrics.connectedChannels.labels({ type: 'websocket', channel: subscription.channelName }).dec();
subscription.stopHeartbeat();
delete subscriptions[channelIds.join(';')];
@ -1496,7 +1292,7 @@ const startServer = async () => {
},
};
connectedChannels.labels({ type: 'websocket', channel: 'system' }).inc(2);
metrics.connectedChannels.labels({ type: 'websocket', channel: 'system' }).inc(2);
};
/**
@ -1508,7 +1304,7 @@ const startServer = async () => {
// Note: url.parse could throw, which would terminate the connection, so we
// increment the connected clients metric straight away when we establish
// the connection, without waiting:
connectedClients.labels({ type: 'websocket' }).inc();
metrics.connectedClients.labels({ type: 'websocket' }).inc();
// Setup connection keep-alive state:
ws.isAlive = true;
@ -1534,7 +1330,7 @@ const startServer = async () => {
});
// Decrement the metrics for connected clients:
connectedClients.labels({ type: 'websocket' }).dec();
metrics.connectedClients.labels({ type: 'websocket' }).dec();
// We need to unassign the session object as to ensure it correctly gets
// garbage collected, without doing this we could accidentally hold on to
@ -1630,15 +1426,23 @@ const startServer = async () => {
* @param {function(string): void} [onSuccess]
*/
const attachServerWithConfig = (server, onSuccess) => {
if (process.env.SOCKET || process.env.PORT && isNaN(+process.env.PORT)) {
server.listen(process.env.SOCKET || process.env.PORT, () => {
if (process.env.SOCKET) {
server.listen(process.env.SOCKET, () => {
if (onSuccess) {
fs.chmodSync(server.address(), 0o666);
onSuccess(server.address());
}
});
} else {
server.listen(+(process.env.PORT || 4000), process.env.BIND || '127.0.0.1', () => {
const port = +(process.env.PORT || 4000);
let bind = process.env.BIND ?? '127.0.0.1';
// Web uses the URI syntax for BIND, which means IPv6 addresses may
// be wrapped in square brackets:
if (bind.startsWith('[') && bind.endsWith(']')) {
bind = bind.slice(1, -1);
}
server.listen(port, bind, () => {
if (onSuccess) {
onSuccess(`${server.address().address}:${server.address().port}`);
}