Skip to content

Commit

Permalink
feat(NODE-5939): Implement 6.x: cache the AWS credentials provider in…
Browse files Browse the repository at this point in the history
… the MONGODB-AWS auth logic (#3991)

Co-authored-by: Durran Jordan <durran@gmail.com>
  • Loading branch information
alenakhineika and durran committed Feb 21, 2024
1 parent 38742c2 commit e0a37e5
Show file tree
Hide file tree
Showing 15 changed files with 229 additions and 102 deletions.
4 changes: 4 additions & 0 deletions src/cmap/auth/auth_provider.ts
Expand Up @@ -34,6 +34,10 @@ export class AuthContext {
}
}

/**
* Provider used during authentication.
* @internal
*/
export abstract class AuthProvider {
/**
* Prepare the handshake document before the initial handshake.
Expand Down
115 changes: 60 additions & 55 deletions src/cmap/auth/mongodb_aws.ts
Expand Up @@ -4,7 +4,7 @@ import { promisify } from 'util';

import type { Binary, BSONSerializeOptions } from '../../bson';
import * as BSON from '../../bson';
import { aws4, getAwsCredentialProvider } from '../../deps';
import { aws4, type AWSCredentials, getAwsCredentialProvider } from '../../deps';
import {
MongoAWSError,
MongoCompatibilityError,
Expand Down Expand Up @@ -57,12 +57,42 @@ interface AWSSaslContinuePayload {
}

export class MongoDBAWS extends AuthProvider {
static credentialProvider: ReturnType<typeof getAwsCredentialProvider> | null = null;
static credentialProvider: ReturnType<typeof getAwsCredentialProvider>;
provider?: () => Promise<AWSCredentials>;
randomBytesAsync: (size: number) => Promise<Buffer>;

constructor() {
super();
this.randomBytesAsync = promisify(crypto.randomBytes);
MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();

let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
AWS_REGION = AWS_REGION.toLowerCase();

/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
const awsRegionSettingsExist =
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;

/**
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
*
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
* That is not our bug to fix here. We leave that up to the SDK.
*/
const useRegionalSts =
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));

if ('fromNodeProviderChain' in MongoDBAWS.credentialProvider) {
this.provider =
awsRegionSettingsExist && useRegionalSts
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
clientConfig: { region: AWS_REGION }
})
: MongoDBAWS.credentialProvider.fromNodeProviderChain();
}
}

override async auth(authContext: AuthContext): Promise<void> {
Expand All @@ -83,7 +113,7 @@ export class MongoDBAWS extends AuthProvider {
}

if (!authContext.credentials.username) {
authContext.credentials = await makeTempCredentials(authContext.credentials);
authContext.credentials = await makeTempCredentials(authContext.credentials, this.provider);
}

const { credentials } = authContext;
Expand Down Expand Up @@ -181,7 +211,10 @@ interface AWSTempCredentials {
Expiration?: Date;
}

async function makeTempCredentials(credentials: MongoCredentials): Promise<MongoCredentials> {
async function makeTempCredentials(
credentials: MongoCredentials,
provider?: () => Promise<AWSCredentials>
): Promise<MongoCredentials> {
function makeMongoCredentialsFromAWSTemp(creds: AWSTempCredentials) {
if (!creds.AccessKeyId || !creds.SecretAccessKey || !creds.Token) {
throw new MongoMissingCredentialsError('Could not obtain temporary MONGODB-AWS credentials');
Expand All @@ -198,11 +231,31 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
});
}

MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();

// Check if the AWS credential provider from the SDK is present. If not,
// use the old method.
if ('kModuleError' in MongoDBAWS.credentialProvider) {
if (provider && !('kModuleError' in MongoDBAWS.credentialProvider)) {
/*
* Creates a credential provider that will attempt to find credentials from the
* following sources (listed in order of precedence):
*
* - Environment variables exposed via process.env
* - SSO credentials from token cache
* - Web identity token credentials
* - Shared credentials and config ini files
* - The EC2/ECS Instance Metadata Service
*/
try {
const creds = await provider();
return makeMongoCredentialsFromAWSTemp({
AccessKeyId: creds.accessKeyId,
SecretAccessKey: creds.secretAccessKey,
Token: creds.sessionToken,
Expiration: creds.expiration
});
} catch (error) {
throw new MongoAWSError(error.message);
}
} else {
// If the environment variable AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
// is set then drivers MUST assume that it was set by an AWS ECS agent
if (process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI) {
Expand Down Expand Up @@ -232,54 +285,6 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
});

return makeMongoCredentialsFromAWSTemp(creds);
} else {
let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
AWS_REGION = AWS_REGION.toLowerCase();

/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
const awsRegionSettingsExist =
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;

/**
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
*
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
* That is not our bug to fix here. We leave that up to the SDK.
*/
const useRegionalSts =
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));

const provider =
awsRegionSettingsExist && useRegionalSts
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
clientConfig: { region: AWS_REGION }
})
: MongoDBAWS.credentialProvider.fromNodeProviderChain();

/*
* Creates a credential provider that will attempt to find credentials from the
* following sources (listed in order of precedence):
*
* - Environment variables exposed via process.env
* - SSO credentials from token cache
* - Web identity token credentials
* - Shared credentials and config ini files
* - The EC2/ECS Instance Metadata Service
*/
try {
const creds = await provider();
return makeMongoCredentialsFromAWSTemp({
AccessKeyId: creds.accessKeyId,
SecretAccessKey: creds.secretAccessKey,
Token: creds.sessionToken,
Expiration: creds.expiration
});
} catch (error) {
throw new MongoAWSError(error.message);
}
}
}

Expand Down
39 changes: 13 additions & 26 deletions src/cmap/connect.ts
Expand Up @@ -16,16 +16,10 @@ import {
MongoRuntimeError,
needsRetryableWriteLabel
} from '../error';
import { type MongoClientAuthProviders } from '../mongo_client_auth_providers';
import { HostAddress, ns, promiseWithResolvers } from '../utils';
import { AuthContext, type AuthProvider } from './auth/auth_provider';
import { GSSAPI } from './auth/gssapi';
import { MongoCR } from './auth/mongocr';
import { MongoDBAWS } from './auth/mongodb_aws';
import { MongoDBOIDC } from './auth/mongodb_oidc';
import { Plain } from './auth/plain';
import { AuthContext } from './auth/auth_provider';
import { AuthMechanism } from './auth/providers';
import { ScramSHA1, ScramSHA256 } from './auth/scram';
import { X509 } from './auth/x509';
import {
type CommandOptions,
Connection,
Expand All @@ -40,18 +34,6 @@ import {
MIN_SUPPORTED_WIRE_VERSION
} from './wire_protocol/constants';

/** @internal */
export const AUTH_PROVIDERS = new Map<AuthMechanism | string, AuthProvider>([
[AuthMechanism.MONGODB_AWS, new MongoDBAWS()],
[AuthMechanism.MONGODB_CR, new MongoCR()],
[AuthMechanism.MONGODB_GSSAPI, new GSSAPI()],
[AuthMechanism.MONGODB_OIDC, new MongoDBOIDC()],
[AuthMechanism.MONGODB_PLAIN, new Plain()],
[AuthMechanism.MONGODB_SCRAM_SHA1, new ScramSHA1()],
[AuthMechanism.MONGODB_SCRAM_SHA256, new ScramSHA256()],
[AuthMechanism.MONGODB_X509, new X509()]
]);

/** @public */
export type Stream = Socket | TLSSocket;

Expand Down Expand Up @@ -111,7 +93,7 @@ export async function performInitialHandshake(
if (credentials) {
if (
!(credentials.mechanism === AuthMechanism.MONGODB_DEFAULT) &&
!AUTH_PROVIDERS.get(credentials.mechanism)
!options.authProviders.getOrCreateProvider(credentials.mechanism)
) {
throw new MongoInvalidArgumentError(`AuthMechanism '${credentials.mechanism}' not supported`);
}
Expand All @@ -120,7 +102,7 @@ export async function performInitialHandshake(
const authContext = new AuthContext(conn, credentials, options);
conn.authContext = authContext;

const handshakeDoc = await prepareHandshakeDocument(authContext);
const handshakeDoc = await prepareHandshakeDocument(authContext, options.authProviders);

// @ts-expect-error: TODO(NODE-5141): The options need to be filtered properly, Connection options differ from Command options
const handshakeOptions: CommandOptions = { ...options };
Expand Down Expand Up @@ -166,7 +148,7 @@ export async function performInitialHandshake(
authContext.response = response;

const resolvedCredentials = credentials.resolveAuthMechanism(response);
const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism);
const provider = options.authProviders.getOrCreateProvider(resolvedCredentials.mechanism);
if (!provider) {
throw new MongoInvalidArgumentError(
`No AuthProvider for ${resolvedCredentials.mechanism} defined.`
Expand All @@ -191,6 +173,10 @@ export async function performInitialHandshake(
conn.established = true;
}

/**
* HandshakeDocument used during authentication.
* @internal
*/
export interface HandshakeDocument extends Document {
/**
* @deprecated Use hello instead
Expand All @@ -210,7 +196,8 @@ export interface HandshakeDocument extends Document {
* This function is only exposed for testing purposes.
*/
export async function prepareHandshakeDocument(
authContext: AuthContext
authContext: AuthContext,
authProviders: MongoClientAuthProviders
): Promise<HandshakeDocument> {
const options = authContext.options;
const compressors = options.compressors ? options.compressors : [];
Expand All @@ -232,7 +219,7 @@ export async function prepareHandshakeDocument(
if (credentials.mechanism === AuthMechanism.MONGODB_DEFAULT && credentials.username) {
handshakeDoc.saslSupportedMechs = `${credentials.source}.${credentials.username}`;

const provider = AUTH_PROVIDERS.get(AuthMechanism.MONGODB_SCRAM_SHA256);
const provider = authProviders.getOrCreateProvider(AuthMechanism.MONGODB_SCRAM_SHA256);
if (!provider) {
// This auth mechanism is always present.
throw new MongoInvalidArgumentError(
Expand All @@ -241,7 +228,7 @@ export async function prepareHandshakeDocument(
}
return provider.prepare(handshakeDoc, authContext);
}
const provider = AUTH_PROVIDERS.get(credentials.mechanism);
const provider = authProviders.getOrCreateProvider(credentials.mechanism);
if (!provider) {
throw new MongoInvalidArgumentError(`No AuthProvider for ${credentials.mechanism} defined.`);
}
Expand Down
3 changes: 3 additions & 0 deletions src/cmap/connection.ts
Expand Up @@ -24,6 +24,7 @@ import {
MongoWriteConcernError
} from '../error';
import type { ServerApi, SupportedNodeConnectionOptions } from '../mongo_client';
import { type MongoClientAuthProviders } from '../mongo_client_auth_providers';
import { MongoLoggableComponent, type MongoLogger, SeverityLevel } from '../mongo_logger';
import { type CancellationToken, TypedEventEmitter } from '../mongo_types';
import type { ReadPreferenceLike } from '../read_preference';
Expand Down Expand Up @@ -109,6 +110,8 @@ export interface ConnectionOptions
/** @internal */
connectionType?: any;
credentials?: MongoCredentials;
/** @internal */
authProviders: MongoClientAuthProviders;
connectTimeoutMS?: number;
tls: boolean;
noDelay?: boolean;
Expand Down
9 changes: 6 additions & 3 deletions src/cmap/connection_pool.ts
Expand Up @@ -28,7 +28,7 @@ import {
import { CancellationToken, TypedEventEmitter } from '../mongo_types';
import type { Server } from '../sdam/server';
import { type Callback, eachAsync, List, makeCounter, TimeoutController } from '../utils';
import { AUTH_PROVIDERS, connect } from './connect';
import { connect } from './connect';
import { Connection, type ConnectionEvents, type ConnectionOptions } from './connection';
import {
ConnectionCheckedInEvent,
Expand Down Expand Up @@ -622,7 +622,9 @@ export class ConnectionPool extends TypedEventEmitter<ConnectionPoolEvents> {
);
}
const resolvedCredentials = credentials.resolveAuthMechanism(connection.hello);
const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism);
const provider = this[kServer].topology.client.s.authProviders.getOrCreateProvider(
resolvedCredentials.mechanism
);
if (!provider) {
return callback(
new MongoMissingCredentialsError(
Expand Down Expand Up @@ -700,7 +702,8 @@ export class ConnectionPool extends TypedEventEmitter<ConnectionPoolEvents> {
id: this[kConnectionCounter].next().value,
generation: this[kGeneration],
cancellationToken: this[kCancellationToken],
mongoLogger: this.mongoLogger
mongoLogger: this.mongoLogger,
authProviders: this[kServer].topology.client.s.authProviders
};

this[kPending]++;
Expand Down
4 changes: 3 additions & 1 deletion src/index.ts
Expand Up @@ -243,7 +243,7 @@ export type {
CSFLEKMSTlsOptions,
StateMachineExecutable
} from './client-side-encryption/state_machine';
export type { AuthContext } from './cmap/auth/auth_provider';
export type { AuthContext, AuthProvider } from './cmap/auth/auth_provider';
export type {
AuthMechanismProperties,
MongoCredentials,
Expand All @@ -268,6 +268,7 @@ export type {
OpResponseOptions,
WriteProtocolMessageType
} from './cmap/commands';
export type { HandshakeDocument } from './cmap/connect';
export type { LEGAL_TCP_SOCKET_OPTIONS, LEGAL_TLS_SOCKET_OPTIONS, Stream } from './cmap/connect';
export type {
CommandOptions,
Expand Down Expand Up @@ -365,6 +366,7 @@ export type {
SupportedTLSSocketOptions,
WithSessionCallback
} from './mongo_client';
export { MongoClientAuthProviders } from './mongo_client_auth_providers';
export type {
Log,
LogComponentSeveritiesClientOptions,
Expand Down
8 changes: 6 additions & 2 deletions src/mongo_client.ts
Expand Up @@ -21,6 +21,7 @@ import { MONGO_CLIENT_EVENTS } from './constants';
import { Db, type DbOptions } from './db';
import type { Encrypter } from './encrypter';
import { MongoInvalidArgumentError } from './error';
import { MongoClientAuthProviders } from './mongo_client_auth_providers';
import {
type LogComponentSeveritiesClientOptions,
type MongoDBLogWritable,
Expand Down Expand Up @@ -297,6 +298,7 @@ export interface MongoClientPrivate {
bsonOptions: BSONSerializeOptions;
namespace: MongoDBNamespace;
hasBeenClosed: boolean;
authProviders: MongoClientAuthProviders;
/**
* We keep a reference to the sessions that are acquired from the pool.
* - used to track and close all sessions in client.close() (which is non-standard behavior)
Expand All @@ -319,6 +321,7 @@ export type MongoClientEvents = Pick<TopologyEvents, (typeof MONGO_CLIENT_EVENTS
};

/** @internal */

const kOptions = Symbol('options');

/**
Expand Down Expand Up @@ -379,6 +382,7 @@ export class MongoClient extends TypedEventEmitter<MongoClientEvents> {
hasBeenClosed: false,
sessionPool: new ServerSessionPool(this),
activeSessions: new Set(),
authProviders: new MongoClientAuthProviders(),

get options() {
return client[kOptions];
Expand Down Expand Up @@ -829,10 +833,10 @@ export interface MongoOptions
proxyUsername?: string;
proxyPassword?: string;
serverMonitoringMode: ServerMonitoringMode;

/** @internal */
connectionType?: typeof Connection;

/** @internal */
authProviders: MongoClientAuthProviders;
/** @internal */
encrypter: Encrypter;
/** @internal */
Expand Down

0 comments on commit e0a37e5

Please sign in to comment.