feat(core): support apple sign in (#12424)

This commit is contained in:
liuyi
2025-05-23 15:27:27 +08:00
committed by GitHub
parent a96cd3eb0a
commit 41781902f6
22 changed files with 629 additions and 260 deletions

View File

@@ -1,3 +1,5 @@
import { z } from 'zod';
import { defineModuleConfig, JSONSchema } from '../../base';
export interface OAuthProviderConfig {
@@ -21,6 +23,7 @@ export interface OAuthOIDCProviderConfig extends OAuthProviderConfig {
export enum OAuthProviderName {
Google = 'google',
GitHub = 'github',
Apple = 'apple',
OIDC = 'oidc',
}
declare global {
@@ -29,6 +32,7 @@ declare global {
providers: {
[OAuthProviderName.Google]: ConfigItem<OAuthProviderConfig>;
[OAuthProviderName.GitHub]: ConfigItem<OAuthProviderConfig>;
[OAuthProviderName.Apple]: ConfigItem<OAuthProviderConfig>;
[OAuthProviderName.OIDC]: ConfigItem<OAuthOIDCProviderConfig>;
};
};
@@ -71,5 +75,29 @@ defineModuleConfig('oauth', {
issuer: '',
args: {},
},
schema,
link: 'https://openid.net/specs/openid-connect-core-1_0.html',
shape: z.object({
issuer: z
.string()
.url()
.regex(/^https?:\/\//, 'issuer must be a valid URL')
.or(z.string().length(0)),
args: z.object({
scope: z.string().optional(),
claim_id: z.string().optional(),
claim_email: z.string().optional(),
claim_name: z.string().optional(),
}),
}),
},
'providers.apple': {
desc: 'Apple OAuth provider config',
default: {
clientId: '',
clientSecret: '',
},
schema,
link: 'https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/implementing_sign_in_with_apple_in_your_app',
},
});

View File

@@ -19,6 +19,7 @@ import {
OauthAccountAlreadyConnected,
OauthStateExpired,
UnknownOauthProvider,
URLHelper,
UseNamedGuard,
} from '../../base';
import { AuthService, Public } from '../../core/auth';
@@ -36,7 +37,8 @@ export class OAuthController {
private readonly auth: AuthService,
private readonly oauth: OAuthService,
private readonly models: Models,
private readonly providerFactory: OAuthProviderFactory
private readonly providerFactory: OAuthProviderFactory,
private readonly url: URLHelper
) {}
@Public()
@@ -67,10 +69,14 @@ export class OAuthController {
clientNonce,
});
const stateStr = JSON.stringify({
state,
client,
provider: unknownProviderName,
});
return {
url: provider.getAuthUrl(
JSON.stringify({ state, client, provider: unknownProviderName })
),
url: provider.getAuthUrl(stateStr, clientNonce),
};
}
@@ -85,6 +91,7 @@ export class OAuthController {
@Body('state') stateStr?: string,
@Body('client_nonce') clientNonce?: string
) {
// TODO(@forehalo): refactor and remove deprecated code in 0.23
if (!code) {
throw new MissingOauthQueryParameter({ name: 'code' });
}
@@ -93,6 +100,17 @@ export class OAuthController {
throw new MissingOauthQueryParameter({ name: 'state' });
}
// NOTE(@forehalo): Apple sign in will directly post /callback, with `state` set at #L73
let rawState = null;
if (typeof stateStr === 'string' && stateStr.length > 36) {
try {
rawState = JSON.parse(stateStr);
stateStr = rawState.state;
} catch {
/* noop */
}
}
if (typeof stateStr !== 'string' || !this.oauth.isValidState(stateStr)) {
throw new InvalidOauthCallbackState();
}
@@ -103,8 +121,38 @@ export class OAuthController {
throw new OauthStateExpired();
}
if (
state.provider === OAuthProviderName.Apple &&
rawState &&
state.client &&
state.client !== 'web'
) {
const clientUrl = new URL(`${state.client}://authentication`);
clientUrl.searchParams.set('method', 'oauth');
clientUrl.searchParams.set(
'payload',
JSON.stringify({
state: stateStr,
code,
provider: rawState.provider,
})
);
clientUrl.searchParams.set('server', this.url.origin);
return res.redirect(
this.url.link('/open-app/url?', {
url: clientUrl.toString(),
})
);
}
// TODO(@fengmk2): clientNonce should be required after the client version >= 0.21.0
if (state.clientNonce && state.clientNonce !== clientNonce) {
if (
state.clientNonce &&
state.clientNonce !== clientNonce &&
// apple sign in with nonce stored in id token
state.provider !== OAuthProviderName.Apple
) {
throw new InvalidAuthState();
}
@@ -132,7 +180,8 @@ export class OAuthController {
);
throw err;
}
const externAccount = await provider.getUser(tokens.accessToken);
const externAccount = await provider.getUser(tokens, state);
const user = await this.loginFromOauth(
state.provider,
externAccount,
@@ -140,6 +189,14 @@ export class OAuthController {
);
await this.auth.setCookies(req, res, user.id);
if (
state.provider === OAuthProviderName.Apple &&
(!state.client || state.client === 'web')
) {
return res.redirect(this.url.link(state.redirectUri ?? '/'));
}
res.send({
id: user.id,
redirectUri: state.redirectUri,
@@ -170,7 +227,9 @@ export class OAuthController {
userId: user.id,
provider,
providerAccountId: externalAccount.id,
...tokens,
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
expiresAt: tokens.expiresAt,
});
return user;
@@ -180,10 +239,11 @@ export class OAuthController {
connectedAccount: ConnectedAccount,
tokens: Tokens
) {
return await this.models.user.updateConnectedAccount(
connectedAccount.id,
tokens
);
return await this.models.user.updateConnectedAccount(connectedAccount.id, {
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
expiresAt: tokens.expiresAt,
});
}
/**
@@ -210,7 +270,9 @@ export class OAuthController {
userId: user.id,
provider,
providerAccountId: externalAccount.id,
...tokens,
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
expiresAt: tokens.expiresAt,
});
}
}

View File

@@ -0,0 +1,131 @@
import { JsonWebKey } from 'node:crypto';
import { Injectable } from '@nestjs/common';
import jwt, { type JwtPayload } from 'jsonwebtoken';
import {
InternalServerError,
InvalidOauthCallbackCode,
URLHelper,
} from '../../../base';
import { OAuthProviderName } from '../config';
import { OAuthProvider, Tokens } from './def';
interface AuthTokenResponse {
access_token: string;
refresh_token: string;
id_token: string;
token_type: string;
expires_in: number;
}
@Injectable()
export class AppleOAuthProvider extends OAuthProvider {
provider = OAuthProviderName.Apple;
constructor(private readonly url: URLHelper) {
super();
}
getAuthUrl(state: string, clientNonce?: string): string {
return `https://appleid.apple.com/auth/authorize?${this.url.stringify({
client_id: this.config.clientId,
redirect_uri: this.url.link('/api/oauth/callback'),
scope: 'name email',
response_type: 'code',
response_mode: 'form_post',
...this.config.args,
state,
nonce: clientNonce,
})}`;
}
async getToken(code: string) {
const response = await fetch('https://appleid.apple.com/auth/token', {
method: 'POST',
body: this.url.stringify({
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: this.url.link('/api/oauth/callback'),
grant_type: 'authorization_code',
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
});
if (response.ok) {
const appleToken = (await response.json()) as AuthTokenResponse;
return {
accessToken: appleToken.access_token,
refreshToken: appleToken.refresh_token,
expiresAt: new Date(Date.now() + appleToken.expires_in * 1000),
idToken: appleToken.id_token,
};
} else {
const body = await response.text();
if (response.status < 500) {
throw new InvalidOauthCallbackCode({ status: response.status, body });
}
throw new Error(
`Server responded with non-success status ${response.status}, body: ${body}`
);
}
}
async getUser(
tokens: Tokens & { idToken: string },
state: { clientNonce: string }
) {
const keysReq = await fetch('https://appleid.apple.com/auth/keys', {
method: 'GET',
});
const { keys } = (await keysReq.json()) as { keys: JsonWebKey[] };
const payload = await new Promise<JwtPayload>((resolve, reject) => {
jwt.verify(
tokens.idToken,
(header, callback) => {
const key = keys.find(key => key.kid === header.kid);
if (!key) {
callback(
new InternalServerError(
'Cannot find match apple public sign key.'
)
);
} else {
callback(null, {
format: 'jwk',
key,
});
}
},
{
issuer: 'https://appleid.apple.com',
audience: this.config.clientId,
nonce: state.clientNonce,
},
(err, payload) => {
if (err || !payload || typeof payload === 'string') {
reject(err || new InternalServerError('Invalid jwt payload'));
return;
}
resolve(payload);
}
);
});
// see https://developer.apple.com/documentation/signinwithapple/authenticating-users-with-sign-in-with-apple
if (!payload.sub || !payload.email) {
throw new Error('Invalid jwt payload');
}
return {
id: payload.sub,
email: payload.email,
};
}
}

View File

@@ -17,12 +17,19 @@ export interface Tokens {
expiresAt?: Date;
}
export interface AuthOptions {
client_id: string;
redirect_uri: string;
scope: string;
state: string;
}
@Injectable()
export abstract class OAuthProvider {
abstract provider: OAuthProviderName;
abstract getAuthUrl(state: string): string;
abstract getAuthUrl(state: string, clientNonce?: string): string;
abstract getToken(code: string): Promise<Tokens>;
abstract getUser(token: string): Promise<OAuthAccount>;
abstract getUser(tokens: Tokens, state: any): Promise<OAuthAccount>;
protected readonly logger = new Logger(this.constructor.name);
@Inject() private readonly factory!: OAuthProviderFactory;
@@ -33,7 +40,9 @@ export abstract class OAuthProvider {
}
get configured() {
return this.config && this.config.clientId && this.config.clientSecret;
return (
!!this.config && !!this.config.clientId && !!this.config.clientSecret
);
}
@OnEvent('config.init')

View File

@@ -2,7 +2,7 @@ import { Injectable } from '@nestjs/common';
import { InvalidOauthCallbackCode, URLHelper } from '../../../base';
import { OAuthProviderName } from '../config';
import { OAuthProvider } from './def';
import { OAuthProvider, Tokens } from './def';
interface AuthTokenResponse {
access_token: string;
@@ -71,11 +71,11 @@ export class GithubOAuthProvider extends OAuthProvider {
}
}
async getUser(token: string) {
async getUser(tokens: Tokens) {
const response = await fetch('https://api.github.com/user', {
method: 'GET',
headers: {
Authorization: `Bearer ${token}`,
Authorization: `Bearer ${tokens.accessToken}`,
},
});

View File

@@ -2,7 +2,7 @@ import { Injectable } from '@nestjs/common';
import { InvalidOauthCallbackCode, URLHelper } from '../../../base';
import { OAuthProviderName } from '../config';
import { OAuthProvider } from './def';
import { OAuthProvider, Tokens } from './def';
interface GoogleOAuthTokenResponse {
access_token: string;
@@ -76,13 +76,13 @@ export class GoogleOAuthProvider extends OAuthProvider {
}
}
async getUser(token: string) {
async getUser(tokens: Tokens) {
const response = await fetch(
'https://www.googleapis.com/oauth2/v2/userinfo',
{
method: 'GET',
headers: {
Authorization: `Bearer ${token}`,
Authorization: `Bearer ${tokens.accessToken}`,
},
}
);

View File

@@ -1,3 +1,4 @@
import { AppleOAuthProvider } from './apple';
import { GithubOAuthProvider } from './github';
import { GoogleOAuthProvider } from './google';
import { OIDCProvider } from './oidc';
@@ -6,4 +7,5 @@ export const OAuthProviders = [
GoogleOAuthProvider,
GithubOAuthProvider,
OIDCProvider,
AppleOAuthProvider,
];

View File

@@ -1,12 +1,13 @@
import { Injectable, Logger } from '@nestjs/common';
import { Injectable } from '@nestjs/common';
import { omit } from 'lodash-es';
import { z } from 'zod';
import { URLHelper } from '../../../base';
import {
OAuthOIDCProviderConfig,
OAuthProviderName,
OIDCArgs,
} from '../config';
InvalidOauthCallbackCode,
InvalidOauthResponse,
URLHelper,
} from '../../../base';
import { OAuthOIDCProviderConfig, OAuthProviderName } from '../config';
import { OAuthAccount, OAuthProvider, Tokens } from './def';
const OIDCTokenSchema = z.object({
@@ -27,8 +28,6 @@ const OIDCUserInfoSchema = z
})
.passthrough();
type OIDCUserInfo = z.infer<typeof OIDCUserInfoSchema>;
const OIDCConfigurationSchema = z.object({
authorization_endpoint: z.string().url(),
token_endpoint: z.string().url(),
@@ -37,173 +36,142 @@ const OIDCConfigurationSchema = z.object({
type OIDCConfiguration = z.infer<typeof OIDCConfigurationSchema>;
const logger = new Logger('OIDCClient');
class OIDCClient {
private static async fetch<T = any>(
url: string,
options: RequestInit,
verifier: z.Schema<T>
): Promise<T> {
const response = await fetch(url, options);
if (!response.ok) {
logger.error('Failed to fetch OIDC configuration', await response.json());
throw new Error(`Failed to configure client`);
}
const data = await response.json();
return verifier.parse(data);
}
static async create(config: OAuthOIDCProviderConfig, url: URLHelper) {
const { args, clientId, clientSecret, issuer } = config;
if (!url.verify(issuer)) {
throw new Error('OIDC Issuer is invalid.');
}
const oidcConfig = await OIDCClient.fetch(
`${issuer}/.well-known/openid-configuration`,
{
method: 'GET',
headers: { Accept: 'application/json' },
},
OIDCConfigurationSchema
);
return new OIDCClient(clientId, clientSecret, args, oidcConfig, url);
}
private constructor(
private readonly clientId: string,
private readonly clientSecret: string,
private readonly args: OIDCArgs | undefined,
private readonly config: OIDCConfiguration,
private readonly url: URLHelper
) {}
authorize(state: string): string {
const args = Object.assign({}, this.args);
if ('claim_id' in args) delete args.claim_id;
if ('claim_email' in args) delete args.claim_email;
if ('claim_name' in args) delete args.claim_name;
return `${this.config.authorization_endpoint}?${this.url.stringify({
client_id: this.clientId,
redirect_uri: this.url.link('/oauth/callback'),
response_type: 'code',
...args,
scope: this.args?.scope || 'openid profile email',
state,
})}`;
}
async token(code: string): Promise<Tokens> {
const token = await OIDCClient.fetch(
this.config.token_endpoint,
{
method: 'POST',
body: this.url.stringify({
code,
client_id: this.clientId,
client_secret: this.clientSecret,
redirect_uri: this.url.link('/oauth/callback'),
grant_type: 'authorization_code',
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
},
OIDCTokenSchema
);
return {
accessToken: token.access_token,
refreshToken: token.refresh_token,
expiresAt: new Date(Date.now() + token.expires_in * 1000),
scope: token.scope,
};
}
private mapUserInfo(
user: OIDCUserInfo,
claimsMap: Record<string, string>
): OAuthAccount {
const mappedUser: Partial<OAuthAccount> = {};
for (const [key, value] of Object.entries(claimsMap)) {
const claimValue = user[value];
if (claimValue !== undefined) {
mappedUser[key as keyof OAuthAccount] = claimValue as string;
}
}
return mappedUser as OAuthAccount;
}
async userinfo(token: string) {
const user = await OIDCClient.fetch(
this.config.userinfo_endpoint,
{
method: 'GET',
headers: {
Accept: 'application/json',
Authorization: `Bearer ${token}`,
},
},
OIDCUserInfoSchema
);
const claimsMap = {
id: this.args?.claim_id || 'preferred_username',
email: this.args?.claim_email || 'email',
name: this.args?.claim_name || 'name',
};
const userinfo = this.mapUserInfo(user, claimsMap);
return { id: userinfo.id, email: userinfo.email };
}
}
@Injectable()
export class OIDCProvider extends OAuthProvider {
override provider = OAuthProviderName.OIDC;
private client: OIDCClient | null = null;
#endpoints: OIDCConfiguration | null = null;
constructor(private readonly url: URLHelper) {
super();
}
protected override setup() {
super.setup();
if (this.configured) {
OIDCClient.create(this.config as OAuthOIDCProviderConfig, this.url)
.then(client => {
this.client = client;
})
.catch(e => {
this.logger.error('Failed to create OIDC client', e);
});
} else {
this.client = null;
private get endpoints() {
if (!this.#endpoints) {
throw new Error('OIDC provider is not configured');
}
return this.#endpoints;
}
private checkOIDCClient(
client: OIDCClient | null
): asserts client is OIDCClient {
if (!client) {
throw new Error('OIDC client has not been loaded yet.');
}
override get configured() {
return this.#endpoints !== null;
}
protected override setup() {
const validate = async () => {
this.#endpoints = null;
if (this.configured) {
const config = this.config as OAuthOIDCProviderConfig;
try {
const res = await fetch(
`${config.issuer}/.well-known/openid-configuration`,
{
method: 'GET',
headers: { Accept: 'application/json' },
}
);
if (res.ok) {
this.#endpoints = OIDCConfigurationSchema.parse(await res.json());
super.setup();
} else {
this.logger.error(`Invalid OIDC issuer ${config.issuer}`);
}
} catch (e) {
this.logger.error('Failed to validate OIDC configuration', e);
}
}
};
validate().catch(() => {
/* noop */
});
}
getAuthUrl(state: string): string {
this.checkOIDCClient(this.client);
return this.client.authorize(state);
return `${this.endpoints.authorization_endpoint}?${this.url.stringify({
client_id: this.config.clientId,
redirect_uri: this.url.link('/oauth/callback'),
scope: this.config.args?.scope || 'openid profile email',
response_type: 'code',
...omit(this.config.args, 'claim_id', 'claim_email', 'claim_name'),
state,
})}`;
}
async getToken(code: string): Promise<Tokens> {
this.checkOIDCClient(this.client);
return await this.client.token(code);
const res = await fetch(this.endpoints.token_endpoint, {
method: 'POST',
body: this.url.stringify({
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: this.url.link('/oauth/callback'),
grant_type: 'authorization_code',
}),
headers: {
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
},
});
if (res.ok) {
const data = await res.json();
const tokens = OIDCTokenSchema.parse(data);
return {
accessToken: tokens.access_token,
refreshToken: tokens.refresh_token,
expiresAt: new Date(Date.now() + tokens.expires_in * 1000),
scope: tokens.scope,
};
}
throw new InvalidOauthCallbackCode({
status: res.status,
body: await res.text(),
});
}
async getUser(token: string): Promise<OAuthAccount> {
this.checkOIDCClient(this.client);
return await this.client.userinfo(token);
async getUser(tokens: Tokens): Promise<OAuthAccount> {
const res = await fetch(this.endpoints.userinfo_endpoint, {
method: 'GET',
headers: {
Accept: 'application/json',
Authorization: `Bearer ${tokens.accessToken}`,
},
});
if (res.ok) {
const body = await res.json();
const user = OIDCUserInfoSchema.parse(body);
const args = this.config.args ?? {};
const claimsMap = {
id: args.claim_id || 'preferred_username',
email: args.claim_email || 'email',
name: args.claim_name || 'name',
};
const identities = {
id: user[claimsMap.id] as string,
email: user[claimsMap.email] as string,
};
if (!identities.id || !identities.email) {
throw new InvalidOauthResponse({
reason: `Missing required claims: ${Object.keys(identities)
.filter(key => !identities[key as keyof typeof identities])
.join(', ')}`,
});
}
return identities;
}
throw new InvalidOauthCallbackCode({
status: res.status,
body: await res.text(),
});
}
}