feat(server): passkey pre-refactor (#15060)

#### PR Dependency Tree


* **PR #15060** 👈

This tree was auto-generated by
[Charcoal](https://github.com/danerwilliams/charcoal)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* OpenApp native sign-in and native session exchange (JWT) for mobile &
desktop.
  * Centralized short-lived auth challenge store for one-time tokens.
* Encrypted per-endpoint token storage and native token handlers
(Android, iOS, Electron).

* **Improvements**
* Richer auth-method reporting (password, magic link, OAuth, passkey)
and improved sign-in flows.
* Hardened magic-link, OAuth, and session issuance; JWT-backed sessions
and websocket JWT support.
* UX tweaks: form-based password submit, OTP autocomplete, adjusted
captcha flow.

* **Bug Fixes**
  * Expanded tests and auth-state resets to avoid cross-test leakage.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
DarkSky
2026-06-01 17:11:15 +08:00
committed by GitHub
parent 5b9d51b41b
commit ce9841df9d
74 changed files with 3719 additions and 939 deletions
@@ -2,6 +2,7 @@ import { randomBytes } from 'node:crypto';
import type { TestFn } from 'ava';
import ava from 'ava';
import supertest from 'supertest';
import {
changeEmail,
@@ -33,6 +34,10 @@ test('change email', async t => {
const u2Email = 'u2@affine.pro';
const user = await app.signupV1(u1Email);
const signedIn = await currentUser(app);
const jwt = signedIn?.token.token;
t.truthy(jwt);
await sendChangeEmail(app, u1Email, '/email-change');
const changeMail = app.mails.last('ChangeEmail');
@@ -77,7 +82,16 @@ test('change email', async t => {
t.is(changedMail.to, u2Email);
t.is(changedMail.props.to, u2Email);
await app.logout();
const revokedCookieSession = await currentUser(app);
t.is(revokedCookieSession, null);
const revokedJwtSession = await supertest(app.getHttpServer())
.get('/api/auth/session')
.set('Authorization', `Bearer ${jwt}`)
.expect(200);
t.falsy(revokedJwtSession.body.user);
app.clearAuth();
await app.login({
...user,
email: u2Email,
@@ -0,0 +1,116 @@
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { SessionCache } from '../../base';
import { AuthChallengeStore, AuthModule } from '../../core/auth';
import { createTestingApp, TestingApp } from '../utils';
const test = ava as TestFn<{
app: TestingApp;
challenges: AuthChallengeStore;
}>;
test.before(async t => {
const app = await createTestingApp({
imports: [AuthModule],
});
t.context.app = app;
t.context.challenges = app.get(AuthChallengeStore);
});
test.beforeEach(() => {
Sinon.restore();
});
test.after.always(async t => {
await t.context.app.close();
});
test('should create and get challenge payload without consuming it', async t => {
const token = await t.context.challenges.create(
'oauth_state',
{ provider: 'Google' },
30_000
);
t.deepEqual(await t.context.challenges.get('oauth_state', token), {
provider: 'Google',
});
t.deepEqual(await t.context.challenges.get('oauth_state', token), {
provider: 'Google',
});
});
test('should consume challenge payload once', async t => {
const token = await t.context.challenges.create(
'open_app_sign_in',
{ userId: 'u1' },
30_000
);
t.deepEqual(await t.context.challenges.consume('open_app_sign_in', token), {
userId: 'u1',
});
t.is(await t.context.challenges.consume('open_app_sign_in', token), null);
});
test('should isolate challenges by purpose', async t => {
const token = await t.context.challenges.create(
'open_app_sign_in',
{ userId: 'u1' },
30_000
);
t.is(await t.context.challenges.get('oauth_state', token), null);
t.is(await t.context.challenges.consume('oauth_state', token), null);
t.deepEqual(await t.context.challenges.consume('open_app_sign_in', token), {
userId: 'u1',
});
});
test('should return null for expired challenge', async t => {
const token = await t.context.challenges.create(
'open_app_sign_in',
{ userId: 'u1' },
1
);
await new Promise(resolve => setTimeout(resolve, 10));
t.is(await t.context.challenges.get('open_app_sign_in', token), null);
t.is(await t.context.challenges.consume('open_app_sign_in', token), null);
});
test('should reject invalid challenge ttl', async t => {
await t.throwsAsync(
t.context.challenges.create('open_app_sign_in', { userId: 'u1' }, 0),
{ message: /Invalid auth state/ }
);
});
test('should reject challenge creation when cache write fails', async t => {
Sinon.stub(t.context.app.get(SessionCache), 'set').resolves(false);
await t.throwsAsync(
t.context.challenges.create('open_app_sign_in', { userId: 'u1' }, 30_000),
{ message: /Invalid auth state/ }
);
});
test('should atomically allow one concurrent consume', async t => {
const token = await t.context.challenges.create(
'open_app_sign_in',
{ userId: 'u1' },
30_000
);
const results = await Promise.all(
Array.from({ length: 8 }, () =>
t.context.challenges.consume('open_app_sign_in', token)
)
);
t.is(results.filter(Boolean).length, 1);
t.deepEqual(results.find(Boolean), { userId: 'u1' });
});
@@ -3,11 +3,17 @@ import { IncomingMessage } from 'node:http';
import { HttpStatus } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import ava, { ExecutionContext, TestFn } from 'ava';
import Sinon from 'sinon';
import supertest from 'supertest';
import { parseCookies as safeParseCookies } from '../../base/utils/request';
import { ConfigFactory } from '../../base';
import {
getRequestCookie,
getRequestHeader,
parseCookies as safeParseCookies,
} from '../../base/utils/request';
import { MagicLinkAuthService } from '../../core/auth/magic-link';
import { AuthService } from '../../core/auth/service';
import {
createTestingApp,
@@ -18,7 +24,9 @@ import {
const test = ava as TestFn<{
auth: AuthService;
magicLink: MagicLinkAuthService;
db: PrismaClient;
config: ConfigFactory;
app: TestingApp;
}>;
@@ -26,13 +34,18 @@ test.before(async t => {
const app = await createTestingApp();
t.context.auth = app.get(AuthService);
t.context.magicLink = app.get(MagicLinkAuthService);
t.context.db = app.get(PrismaClient);
t.context.config = app.get(ConfigFactory);
t.context.app = app;
});
test.beforeEach(async t => {
Sinon.reset();
await t.context.app.initTestingDB();
t.context.config.override({
auth: { allowSignup: true, requireEmailDomainVerification: false },
});
});
test.after.always(async t => {
@@ -44,15 +57,102 @@ test('should be able to sign in with credential', async t => {
const u1 = await app.createUser('u1@affine.pro');
await app
const res = await app
.POST('/api/auth/sign-in')
.send({ email: u1.email, password: u1.password })
.expect(200);
t.is(res.body.id, u1.id);
t.falsy(res.body.token);
t.falsy(res.body.expiresAt);
const session = await currentUser(app);
t.is(session?.id, u1.id);
});
async function exchangeSession(app: TestingApp, code: string) {
return await supertest(app.getHttpServer())
.post('/api/auth/native/exchange')
.set('x-affine-client-kind', 'native')
.send({ code })
.expect(201);
}
function assertClearsNativeAuthCookies(
t: ExecutionContext,
res: supertest.Response
) {
const setCookies = res.get('Set-Cookie') ?? [];
for (const name of [
AuthService.sessionCookieName,
AuthService.userCookieName,
AuthService.csrfCookieName,
]) {
t.true(
setCookies.some(
cookie =>
cookie.startsWith(`${name}=;`) &&
/Expires=Thu, 01 Jan 1970/i.test(cookie)
)
);
}
}
test('should issue exchange code only for native credential sign in', async t => {
const { app } = t.context;
const u1 = await app.createUser('native@affine.pro');
const res = await app
.POST('/api/auth/sign-in')
.set('x-affine-client-kind', 'native')
.send({ email: u1.email, password: u1.password })
.expect(200);
t.is(res.body.id, u1.id);
t.truthy(res.body.exchangeCode);
assertClearsNativeAuthCookies(t, res);
const exchangeRes = await exchangeSession(app, res.body.exchangeCode);
t.truthy(exchangeRes.body.token);
t.truthy(exchangeRes.body.expiresAt);
});
test('should not issue jwt for browser-origin credential sign in', async t => {
const { app } = t.context;
const u1 = await app.createUser('browser@affine.pro');
const res = await app
.POST('/api/auth/sign-in')
.set('origin', 'https://app.affine.pro')
.set('x-affine-client-kind', 'native')
.send({ email: u1.email, password: u1.password })
.expect(200);
t.is(res.body.id, u1.id);
t.falsy(res.body.token);
t.falsy(res.body.expiresAt);
t.falsy(res.body.exchangeCode);
});
test('should write legacy auth cookies when signing in with credential', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
const res = await app
.POST('/api/auth/sign-in')
.send({ email: u1.email, password: u1.password })
.expect(200);
const cookies = parseCookies(res);
t.truthy(cookies[AuthService.sessionCookieName]);
t.truthy(cookies[AuthService.userCookieName]);
t.truthy(cookies[AuthService.csrfCookieName]);
});
test('should record sign in client version when header is provided', async t => {
const { app, db } = t.context;
@@ -81,6 +181,126 @@ test('should record sign in client version when header is provided', async t =>
t.is(userSession2?.signInClientVersion, '0.25.1');
});
test('should return method-oriented preflight for registered password users', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
const res = await app
.POST('/api/auth/preflight')
.send({ email: u1.email })
.expect(201);
t.true(res.body.registered);
t.deepEqual(res.body.methods.password, { available: true });
t.deepEqual(res.body.methods.magicLink, { available: true });
t.deepEqual(res.body.methods.passkey, {
available: false,
discoverable: false,
});
t.false('hasPassword' in res.body);
});
test('should return method-oriented preflight for unknown users', async t => {
const { app } = t.context;
const res = await app
.POST('/api/auth/preflight')
.send({ email: 'unknown@affine.pro' })
.expect(201);
t.false(res.body.registered);
t.deepEqual(res.body.methods.password, { available: false });
t.deepEqual(res.body.methods.magicLink, { available: true });
t.deepEqual(res.body.methods.passkey, {
available: false,
discoverable: false,
});
t.false('hasPassword' in res.body);
});
test('should return password unavailable for registered users without password', async t => {
const { app } = t.context;
const u1 = await app.createUser('passwordless@affine.pro', {
password: null,
});
const res = await app
.POST('/api/auth/preflight')
.send({ email: u1.email })
.expect(201);
t.true(res.body.registered);
t.deepEqual(res.body.methods.password, { available: false });
t.false('hasPassword' in res.body);
});
test('should return methods unavailable for disabled users', async t => {
const { app } = t.context;
const u1 = await app.createUser('disabled@affine.pro', {
disabled: true,
});
const res = await app
.POST('/api/auth/preflight')
.send({ email: u1.email })
.expect(201);
t.false(res.body.registered);
t.deepEqual(res.body.methods.password, { available: false });
t.deepEqual(res.body.methods.magicLink, { available: false });
});
test('should return magic link unavailable for unknown users when signup is disabled', async t => {
const { app, config } = t.context;
config.override({
auth: {
allowSignup: false,
},
});
const res = await app
.POST('/api/auth/preflight')
.send({ email: 'unknown@affine.pro' })
.expect(201);
t.false(res.body.registered);
t.deepEqual(res.body.methods.magicLink, { available: false });
});
test('should return magic link unavailable when domain verification rejects signup email', async t => {
const { app, config } = t.context;
config.override({
auth: {
requireEmailDomainVerification: true,
},
});
const res = await app
.POST('/api/auth/preflight')
.send({ email: 'unknown+alias@affine.pro' })
.expect(201);
t.false(res.body.registered);
t.deepEqual(res.body.methods.magicLink, { available: false });
});
test('should return bound auth methods for current account', async t => {
const { app } = t.context;
await app.signupV1('bound-methods@affine.pro');
const res = await app.GET('/api/auth/methods').expect(200);
t.deepEqual(res.body.password, { bound: true });
t.deepEqual(res.body.oauth, { bound: false, providers: [] });
t.deepEqual(res.body.passkey, { bound: false, count: 0 });
});
test('should be able to sign in with email', async t => {
const { app } = t.context;
@@ -100,7 +320,19 @@ test('should be able to sign in with email', async t => {
const email = url.searchParams.get('email');
const token = url.searchParams.get('token');
await app.POST('/api/auth/magic-link').send({ email, token }).expect(201);
const signInRes = await app
.POST('/api/auth/magic-link')
.send({ email, token })
.expect(201);
t.is(signInRes.body.id, u1.id);
t.falsy(signInRes.body.token);
t.falsy(signInRes.body.expiresAt);
const cookies = parseCookies(signInRes);
t.truthy(cookies[AuthService.sessionCookieName]);
t.truthy(cookies[AuthService.userCookieName]);
t.truthy(cookies[AuthService.csrfCookieName]);
const session = await currentUser(app);
t.is(session?.id, u1.id);
@@ -140,6 +372,17 @@ test('should not be able to sign in if email is invalid', async t => {
t.is(res.body.message, 'An invalid email provided: ');
});
test('should not create magic-link state if email is invalid', async t => {
const { app, magicLink } = t.context;
await t.throwsAsync(magicLink.send('invalid-email'), {
message: 'An invalid email provided: invalid-email',
});
t.is(app.mails.count('SignIn'), 0);
t.is(app.mails.count('SignUp'), 0);
});
test('should not be able to sign in if forbidden', async t => {
const { app, auth } = t.context;
@@ -202,7 +445,7 @@ test('should be able to sign out', async t => {
t.falsy(session);
});
test('should be able to sign out when csrf header is missing (compat)', async t => {
test('should reject cookie sign out when csrf header is missing', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
@@ -220,16 +463,134 @@ test('should be able to sign out when csrf header is missing (compat)', async t
await supertest(app.getHttpServer())
.post('/api/auth/sign-out')
.set('Cookie', cookieHeader)
.expect(200);
.expect(HttpStatus.FORBIDDEN);
const sessionRes = await supertest(app.getHttpServer())
.get('/api/auth/session')
.set('Cookie', cookieHeader)
.expect(200);
t.is(sessionRes.body.user.id, u1.id);
});
test('should be able to sign out with jwt without csrf', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
const signInRes = await supertest(app.getHttpServer())
.post('/api/auth/sign-in')
.set('x-affine-client-kind', 'native')
.send({ email: u1.email, password: u1.password })
.expect(200);
const token = (await exchangeSession(app, signInRes.body.exchangeCode)).body
.token;
await supertest(app.getHttpServer())
.post('/api/auth/sign-out')
.set('Authorization', `Bearer ${token}`)
.expect(200);
const sessionRes = await supertest(app.getHttpServer())
.get('/api/auth/session')
.set('Authorization', `Bearer ${token}`)
.expect(200);
t.falsy(sessionRes.body.user);
});
test('should ignore user_id query when signing out with jwt', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
const u2 = await app.createUser('u2@affine.pro');
const u1SignIn = await app
.POST('/api/auth/sign-in')
.set('x-affine-client-kind', 'native')
.send({ email: u1.email, password: u1.password })
.expect(200);
const u1Token = (await exchangeSession(app, u1SignIn.body.exchangeCode)).body
.token;
await app
.POST('/api/auth/sign-in')
.send({ email: u2.email, password: u2.password })
.expect(200);
await supertest(app.getHttpServer())
.post(`/api/auth/sign-out?user_id=${u2.id}`)
.set('Authorization', `Bearer ${u1Token}`)
.expect(200);
const u1Session = await supertest(app.getHttpServer())
.get('/api/auth/session')
.set('Authorization', `Bearer ${u1Token}`)
.expect(200);
t.falsy(u1Session.body.user);
const cookieSession = await app.GET('/api/auth/session').expect(200);
t.is(cookieSession.body.user.id, u2.id);
});
test('should reuse jwt session when signing in another account without cookies', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
const u2 = await app.createUser('u2@affine.pro');
const u1SignIn = await supertest(app.getHttpServer())
.post('/api/auth/sign-in')
.set('x-affine-client-kind', 'native')
.send({ email: u1.email, password: u1.password })
.expect(200);
const u1Token = (await exchangeSession(app, u1SignIn.body.exchangeCode)).body
.token;
const u2SignIn = await supertest(app.getHttpServer())
.post('/api/auth/sign-in')
.set('Authorization', `Bearer ${u1Token}`)
.send({ email: u2.email, password: u2.password })
.expect(200);
const u1Session = await t.context.db.userSession.findFirstOrThrow({
where: { userId: u1.id },
});
const u2Session = await t.context.db.userSession.findFirstOrThrow({
where: { userId: u2.id },
});
t.is(u2SignIn.body.id, u2.id);
t.is(u2Session.sessionId, u1Session.sessionId);
});
test('should not reuse legacy bearer session id when signing in another account without cookies', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
const u2 = await app.createUser('u2@affine.pro');
await supertest(app.getHttpServer())
.post('/api/auth/sign-in')
.send({ email: u1.email, password: u1.password })
.expect(200);
const u1Session = await t.context.db.userSession.findFirstOrThrow({
where: { userId: u1.id },
});
await supertest(app.getHttpServer())
.post('/api/auth/sign-in')
.set('Authorization', `Bearer ${u1Session.sessionId}`)
.send({ email: u2.email, password: u2.password })
.expect(200);
const u2Session = await t.context.db.userSession.findFirstOrThrow({
where: { userId: u2.id },
});
t.not(u2Session.sessionId, u1Session.sessionId);
});
test('should be able to sign out when duplicated csrf cookies exist', async t => {
const { app } = t.context;
@@ -264,23 +625,6 @@ test('should be able to sign out when duplicated csrf cookies exist', async t =>
t.falsy(sessionRes.body.user);
});
test('should be able to sign out via GET /api/auth/sign-out (deprecated)', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
await app
.POST('/api/auth/sign-in')
.send({ email: u1.email, password: u1.password })
.expect(200);
const res = await app.GET('/api/auth/sign-out').expect(200);
t.is(res.headers.deprecation, 'true');
const session = await currentUser(app);
t.falsy(session);
});
test('should reject sign out when csrf token mismatched', async t => {
const { app } = t.context;
@@ -317,20 +661,20 @@ test('should sign in desktop app via one-time open-app code', async t => {
const exchangeRes = await supertest(app.getHttpServer())
.post('/api/auth/open-app/sign-in')
.set('x-affine-client-kind', 'native')
.send({ code })
.expect(201);
const exchangedCookies = exchangeRes.get('Set-Cookie') ?? [];
t.true(
exchangedCookies.some(c =>
c.startsWith(`${AuthService.sessionCookieName}=`)
)
);
t.is(exchangeRes.body.id, u1.id);
t.truthy(exchangeRes.body.exchangeCode);
assertClearsNativeAuthCookies(t, exchangeRes);
const tokenRes = await exchangeSession(app, exchangeRes.body.exchangeCode);
t.truthy(tokenRes.body.token);
t.truthy(tokenRes.body.expiresAt);
const cookieHeader = exchangedCookies.map(c => c.split(';')[0]).join('; ');
const sessionRes = await supertest(app.getHttpServer())
.get('/api/auth/session')
.set('Cookie', cookieHeader)
.set('Authorization', `Bearer ${tokenRes.body.token}`)
.expect(200);
t.is(sessionRes.body.user?.id, u1.id);
@@ -379,6 +723,35 @@ test('should not throw on parse of a bad cookie', async t => {
t.is(req.cookies?.[badCookieKey], badCookieVal);
});
test('should only read string request cookies', t => {
const req = {
headers: {},
cookies: {
empty: '',
list: ['session'],
object: { value: 'session' },
session: 'valid_session',
},
} as unknown as IncomingMessage & { cookies?: Record<string, unknown> };
t.is(getRequestCookie(req, 'session'), 'valid_session');
t.is(getRequestCookie(req, 'empty'), undefined);
t.is(getRequestCookie(req, 'list'), undefined);
t.is(getRequestCookie(req, 'object'), undefined);
});
test('should only read string request headers', t => {
const req = {
headers: {
'x-list': ['value'],
'x-string': 'value',
},
} as unknown as IncomingMessage;
t.is(getRequestHeader(req, 'x-string'), 'value');
t.is(getRequestHeader(req, 'x-list'), undefined);
});
// multiple accounts session tests
test('should be able to sign in another account in one session', async t => {
const { app } = t.context;
@@ -400,15 +773,6 @@ test('should be able to sign in another account in one session', async t => {
.send({ email: u2.email, password: u2.password })
.expect(200);
// list [u1, u2]
const sessions = await app.GET('/api/auth/sessions').expect(200);
t.is(sessions.body.users.length, 2);
t.like(
sessions.body.users.map((u: any) => u.id),
[u1.id, u2.id]
);
// default to latest signed in user: u2
let session = await app.GET('/api/auth/session').expect(200);
@@ -0,0 +1,56 @@
import ava from 'ava';
import { verifyEmailDomainRecords } from '../../core/auth/email-domain';
const test = ava;
test('should verify email domain records', async t => {
const ok = await verifyEmailDomainRecords(
'user@example.com',
{
resolveMx: async () => [{ exchange: 'mx.example.com', priority: 10 }],
resolveTxt: async domain =>
domain === '_dmarc.example.com'
? [['v=DMARC1; p=none']]
: [['v=spf1 include:_spf.example.com ~all']],
},
100
);
t.true(ok);
});
test('should verify split txt record chunks', async t => {
const ok = await verifyEmailDomainRecords(
'user@example.com',
{
resolveMx: async () => [{ exchange: 'mx.example.com', priority: 10 }],
resolveTxt: async domain =>
domain === '_dmarc.example.com'
? [['v=DM', 'ARC1; p=none']]
: [['v=spf', '1 include:_spf.example.com ~all']],
},
100
);
t.true(ok);
});
test('should fail closed when email domain lookup times out', async t => {
const ok = await verifyEmailDomainRecords(
'user@example.com',
{
resolveMx: async () =>
new Promise(resolve =>
setTimeout(
() => resolve([{ exchange: 'mx.example.com', priority: 10 }]),
50
)
),
resolveTxt: async () => [['v=spf1 include:_spf.example.com ~all']],
},
1
);
t.false(ok);
});
@@ -5,7 +5,13 @@ import Sinon from 'sinon';
import request from 'supertest';
import { CANARY_CLIENT_VERSION_MAX_AGE_DAYS, ConfigFactory } from '../../base';
import { AuthModule, CurrentUser, Public, Session } from '../../core/auth';
import {
AuthModule,
CurrentUser,
JwtSessionService,
Public,
Session,
} from '../../core/auth';
import { AuthService } from '../../core/auth/service';
import { Models } from '../../models';
import { createTestingApp, TestingApp } from '../utils';
@@ -37,6 +43,7 @@ const test = ava as TestFn<{
app: TestingApp;
server: any;
auth: AuthService;
jwtSession: JwtSessionService;
models: Models;
db: PrismaClient;
config: ConfigFactory;
@@ -53,6 +60,7 @@ test.before(async t => {
t.context.app = app;
t.context.server = app.getHttpServer();
t.context.auth = app.get(AuthService);
t.context.jwtSession = app.get(JwtSessionService);
t.context.models = app.get(Models);
t.context.db = app.get(PrismaClient);
t.context.config = app.get(ConfigFactory);
@@ -110,7 +118,7 @@ test('should not be able to visit private api if not signed in', async t => {
t.assert(true);
});
test('should be able to visit private api if signed in', async t => {
test('should be able to visit private api with cookie session', async t => {
const res = await request(t.context.server)
.get('/private')
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
@@ -119,16 +127,115 @@ test('should be able to visit private api if signed in', async t => {
t.is(res.body.user.id, t.context.u1.id);
});
test('should be able to visit private api with access token', async t => {
const models = t.context.app.get(Models);
const token = await models.accessToken.create({
test('should be able to visit private api with legacy bearer session id', async t => {
const res = await request(t.context.server)
.get('/private')
.set('Authorization', `Bearer ${t.context.sessionId}`)
.expect(HttpStatus.OK);
t.is(res.body.user.id, t.context.u1.id);
});
test('should be able to visit private api with personal access token', async t => {
const accessToken = await t.context.models.accessToken.create({
userId: t.context.u1.id,
name: 'test',
});
const res = await request(t.context.server)
.get('/private')
.set('Authorization', `Bearer ${token.token}`)
.set('Authorization', `Bearer ${accessToken.token}`)
.expect(HttpStatus.OK);
t.is(res.body.user.id, t.context.u1.id);
});
test('should be able to visit private api with jwt session', async t => {
const jwt = t.context.jwtSession.sign(t.context.u1.id, t.context.sessionId);
const res = await request(t.context.server)
.get('/private')
.set('Authorization', `Bearer ${jwt.token}`)
.expect(HttpStatus.OK);
t.is(res.body.user.id, t.context.u1.id);
});
test('should prefer bearer jwt over cookie session', async t => {
const u2 = await t.context.auth.signUp('u2@affine.pro', '1');
const u2Session = await t.context.auth.createUserSession(u2.id);
const jwt = t.context.jwtSession.sign(u2.id, u2Session.sessionId);
const res = await request(t.context.server)
.get('/private')
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
.set('Authorization', `Bearer ${jwt.token}`)
.expect(HttpStatus.OK);
t.is(res.body.user.id, u2.id);
});
test('should reject jwt after its user session is deleted', async t => {
const jwt = t.context.jwtSession.sign(t.context.u1.id, t.context.sessionId);
await t.context.auth.signOut(t.context.sessionId, t.context.u1.id);
await request(t.context.server)
.get('/private')
.set('Authorization', `Bearer ${jwt.token}`)
.expect(HttpStatus.UNAUTHORIZED);
t.pass();
});
test('should enforce client version for jwt and bearer session id auth', async t => {
t.context.config.override({
client: {
versionControl: {
enabled: true,
requiredVersion: '>=0.25.0',
},
},
});
const cases = [
{
name: 'jwt',
token: async () => {
const session = await t.context.auth.createUserSession(t.context.u1.id);
return t.context.jwtSession.sign(t.context.u1.id, session.sessionId)
.token;
},
},
{
name: 'bearer session id',
token: async () => {
const session = await t.context.auth.createUserSession(t.context.u1.id);
return session.sessionId;
},
},
];
for (const testCase of cases) {
const res = await request(t.context.server)
.get('/private')
.set('Authorization', `Bearer ${await testCase.token()}`)
.set('x-affine-version', '0.24.0')
.expect(HttpStatus.FORBIDDEN);
t.is(
res.body.message,
'Unsupported client with version [0.24.0], required version is [>=0.25.0].',
testCase.name
);
}
});
test('should fall back to cookie session on public api when jwt is invalid', async t => {
const res = await request(t.context.server)
.get('/public')
.set('Cookie', `${AuthService.sessionCookieName}=${t.context.sessionId}`)
.set('Authorization', 'Bearer invalid.jwt.token')
.expect(HttpStatus.OK);
t.is(res.body.user.id, t.context.u1.id);
@@ -0,0 +1,182 @@
import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import jwt from 'jsonwebtoken';
import { CryptoHelper } from '../../base/helpers';
import {
AuthModule,
AuthService,
type CurrentUser,
JwtSessionService,
} from '../../core/auth';
import { Models } from '../../models';
import { createTestingApp, TestingApp } from '../utils';
const test = ava as TestFn<{
app: TestingApp;
auth: AuthService;
jwtSession: JwtSessionService;
crypto: CryptoHelper;
models: Models;
db: PrismaClient;
user: CurrentUser;
sessionId: string;
}>;
test.before(async t => {
const app = await createTestingApp({
imports: [AuthModule],
});
t.context.app = app;
t.context.auth = app.get(AuthService);
t.context.jwtSession = app.get(JwtSessionService);
t.context.crypto = app.get(CryptoHelper);
t.context.models = app.get(Models);
t.context.db = app.get(PrismaClient);
});
test.beforeEach(async t => {
await t.context.app.initTestingDB();
t.context.user = await t.context.auth.signUp('u1@affine.pro', '1');
const session = await t.context.auth.createUserSession(t.context.user.id);
t.context.sessionId = session.sessionId;
});
test.after.always(async t => {
await t.context.app.close();
});
function currentJwtKey(crypto: CryptoHelper) {
return Buffer.concat([
Buffer.from('affine:user-session-jwt:v1:'),
crypto.keyPair.sha256.privateKey,
]);
}
test('should sign and verify a user session jwt', async t => {
const signed = t.context.jwtSession.sign(
t.context.user.id,
t.context.sessionId
);
const session = await t.context.jwtSession.verify(signed.token);
t.is(session.user.id, t.context.user.id);
t.is(session.sessionId, t.context.sessionId);
t.true(signed.expiresAt.getTime() > Date.now());
});
test('should reject invalid jwt cases', async t => {
const cases: Array<{ name: string; token: string }> = [
{
name: 'expired token',
token: jwt.sign(
{ sid: t.context.sessionId, typ: 'user_session' },
currentJwtKey(t.context.crypto),
{
algorithm: 'HS256',
audience: 'affine-client',
expiresIn: -1,
issuer: 'affine',
subject: t.context.user.id,
}
),
},
{
name: 'wrong signature',
token: jwt.sign(
{ sid: t.context.sessionId, typ: 'user_session' },
'wrong-key',
{
algorithm: 'HS256',
audience: 'affine-client',
expiresIn: 60,
issuer: 'affine',
subject: t.context.user.id,
}
),
},
{
name: 'wrong issuer',
token: jwt.sign(
{ sid: t.context.sessionId, typ: 'user_session' },
currentJwtKey(t.context.crypto),
{
algorithm: 'HS256',
audience: 'affine-client',
expiresIn: 60,
issuer: 'other-issuer',
subject: t.context.user.id,
}
),
},
{
name: 'wrong audience',
token: jwt.sign(
{ sid: t.context.sessionId, typ: 'user_session' },
currentJwtKey(t.context.crypto),
{
algorithm: 'HS256',
audience: 'other-audience',
expiresIn: 60,
issuer: 'affine',
subject: t.context.user.id,
}
),
},
{
name: 'wrong type',
token: jwt.sign(
{ sid: t.context.sessionId, typ: 'personal_access_token' },
currentJwtKey(t.context.crypto),
{
algorithm: 'HS256',
audience: 'affine-client',
expiresIn: 60,
issuer: 'affine',
subject: t.context.user.id,
}
),
},
];
for (const testCase of cases) {
await t.throwsAsync(() => t.context.jwtSession.verify(testCase.token), {
message: 'You must sign in first to access this resource.',
});
}
});
test('should reject jwt when its user session is missing or expired', async t => {
const signed = t.context.jwtSession.sign(
t.context.user.id,
t.context.sessionId
);
await t.context.auth.signOut(t.context.sessionId, t.context.user.id);
await t.throwsAsync(() => t.context.jwtSession.verify(signed.token), {
message: 'You must sign in first to access this resource.',
});
const refreshed = await t.context.auth.createUserSession(t.context.user.id);
const expired = t.context.jwtSession.sign(
t.context.user.id,
refreshed.sessionId
);
await t.context.db.userSession.updateMany({
where: {
userId: t.context.user.id,
sessionId: refreshed.sessionId,
},
data: {
expiresAt: new Date(Date.now() - 1000),
},
});
await t.throwsAsync(() => t.context.jwtSession.verify(expired.token), {
message: 'You must sign in first to access this resource.',
});
});
@@ -0,0 +1,64 @@
import ava, { TestFn } from 'ava';
import { AuthMethodsService, AuthModule } from '../../core/auth';
import { Models } from '../../models';
import { createTestingApp, TestingApp } from '../utils';
const test = ava as TestFn<{
app: TestingApp;
authMethods: AuthMethodsService;
models: Models;
}>;
test.before(async t => {
const app = await createTestingApp({
imports: [AuthModule],
});
t.context.app = app;
t.context.authMethods = app.get(AuthMethodsService);
t.context.models = app.get(Models);
});
test.beforeEach(async t => {
await t.context.app.initTestingDB();
});
test.after.always(async t => {
await t.context.app.close();
});
test('should return login preflight methods without top-level has fields', async t => {
const user = await t.context.app.createUser('methods@affine.pro');
const preflight = await t.context.authMethods.loginPreflight(user.email);
t.true(preflight.registered);
t.deepEqual(preflight.methods.password, { available: true });
t.deepEqual(preflight.methods.magicLink, { available: true });
t.deepEqual(preflight.methods.passkey, {
available: false,
discoverable: false,
});
t.false('hasPassword' in preflight);
});
test('should return bound account methods for settings', async t => {
const user = await t.context.app.createUser('bound-methods@affine.pro');
await t.context.models.user.createConnectedAccount({
userId: user.id,
provider: 'Google',
providerAccountId: 'google-account',
accessToken: 'access-token',
});
const methods = await t.context.authMethods.boundMethods(user.id);
t.deepEqual(methods.password, { bound: true });
t.deepEqual(methods.oauth, {
bound: true,
providers: ['Google'],
});
t.deepEqual(methods.passkey, { bound: false, count: 0 });
});
@@ -50,6 +50,13 @@ test('should be able to set cache with ttl', async t => {
t.true(ttl > 0);
});
test('should reject invalid ttl options', async t => {
t.false(await cache.set(key('test-invalid-ttl'), 1, { ttl: 0 }));
t.is(await cache.get(key('test-invalid-ttl')), undefined);
t.false(await cache.setnx(key('test-invalid-ttl-nx'), 1, { ttl: 0 }));
t.is(await cache.get(key('test-invalid-ttl-nx')), undefined);
});
test('should be able to incr/decr number cache', async t => {
t.true(await cache.set(key('test-incr'), 1));
t.is(await cache.increase(key('test-incr')), 2);
@@ -63,6 +63,14 @@ export class TestingApp extends NestApplication {
await this.close();
}
clearAuth() {
this.resetRateLimit();
this.sessionCookie = null;
this.currentUserCookie = null;
this.csrfCookie = null;
this.userCookies.clear();
}
request(
method: 'options' | 'get' | 'post' | 'put' | 'delete' | 'patch',
path: string
@@ -7,7 +7,7 @@ import { app, e2e, Mockers } from '../test';
e2e('user(email) should return null without auth', async t => {
const user = await app.create(Mockers.User);
await app.logout();
app.clearAuth();
const res = await app.gql({
query: getUserQuery,
@@ -18,7 +18,7 @@ e2e('user(email) should return null without auth', async t => {
});
e2e('user(email) should return null outside workspace scope', async t => {
await app.logout();
app.clearAuth();
const me = await app.signup();
const other = await app.create(Mockers.User);
@@ -43,7 +43,7 @@ e2e('user(email) should return null outside workspace scope', async t => {
});
e2e('user(email) should return user within workspace scope', async t => {
await app.logout();
app.clearAuth();
const me = await app.signup();
const other = await app.create(Mockers.User);
const ws = await app.create(Mockers.Workspace, { owner: me });
@@ -67,7 +67,7 @@ e2e('user(email) should return user within workspace scope', async t => {
});
e2e('user(email) should be rate limited', async t => {
await app.logout();
app.clearAuth();
const me = await app.signup();
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
@@ -7,6 +7,7 @@ import Sinon from 'sinon';
import { AppModule } from '../../app.module';
import { ConfigFactory, InvalidOauthResponse, URLHelper } from '../../base';
import { SessionCache } from '../../base/cache';
import { ConfigModule } from '../../base/config';
import { CurrentUser } from '../../core/auth';
import { AuthService } from '../../core/auth/service';
@@ -23,6 +24,7 @@ import { createTestingApp, currentUser, TestingApp } from '../utils';
const test = ava as TestFn<{
auth: AuthService;
oauth: OAuthService;
cache: SessionCache;
models: Models;
u1: CurrentUser;
db: PrismaClient;
@@ -62,6 +64,7 @@ test.before(async t => {
t.context.auth = app.get(AuthService);
t.context.oauth = app.get(OAuthService);
t.context.cache = app.get(SessionCache);
t.context.models = app.get(Models);
t.context.db = app.get(PrismaClient);
t.context.app = app;
@@ -244,6 +247,7 @@ test('should forbid preflight with untrusted redirect_uri', async t => {
test('should throw if client_nonce is missing in preflight', async t => {
const { app } = t.context;
app.clearAuth();
await app
.POST('/api/oauth/preflight')
@@ -293,6 +297,19 @@ test('should be able to save oauth state', async t => {
t.is(state!.provider, OAuthProviderName.Google);
});
test('should save oauth state with three hour ttl', async t => {
const { cache, oauth } = t.context;
const id = await oauth.saveOAuthState({
provider: OAuthProviderName.Google,
});
const ttl = await cache.ttl(`auth_challenge:oauth_state:${id}`);
t.true(ttl > 2 * 3600);
t.true(ttl <= 3 * 3600);
});
test('should be able to get registered oauth providers', async t => {
const { oauth } = t.context;
@@ -550,12 +567,40 @@ test('should be able to sign up with oauth', async t => {
const clientNonce = mockOAuthProvider(app, 'u2@affine.pro');
await app
const res = await app
.POST('/api/oauth/callback')
.set('x-affine-client-kind', 'native')
.send({ code: '1', state: '1', client_nonce: clientNonce })
.expect(HttpStatus.OK);
const sessionUser = await currentUser(app);
t.truthy(res.body.exchangeCode);
const tokenRes = await app
.POST('/api/auth/native/exchange')
.set('x-affine-client-kind', 'native')
.send({ code: res.body.exchangeCode })
.expect(201);
t.truthy(tokenRes.body.token);
t.truthy(tokenRes.body.expiresAt);
const setCookies = res.get('Set-Cookie') ?? [];
for (const name of [
AuthService.sessionCookieName,
AuthService.userCookieName,
AuthService.csrfCookieName,
]) {
t.true(
setCookies.some(
cookie =>
cookie.startsWith(`${name}=;`) &&
/Expires=Thu, 01 Jan 1970/i.test(cookie)
)
);
}
const sessionUserRes = await app
.GET('/api/auth/session')
.set('Authorization', `Bearer ${tokenRes.body.token}`)
.expect(200);
const sessionUser = sessionUserRes.body.user;
t.truthy(sessionUser);
t.is(sessionUser!.email, 'u2@affine.pro');
@@ -60,14 +60,17 @@ async function withTimeout<T>(
}
}
function createClient(url: string, cookie: string): SocketIOClient {
function createClient(
url: string,
cookie?: string,
auth?: Record<string, unknown>
): SocketIOClient {
return io(url, {
transports: ['websocket'],
reconnection: false,
forceNew: true,
extraHeaders: {
cookie,
},
...(cookie ? { extraHeaders: { cookie } } : {}),
...(auth ? { auth } : {}),
});
}
@@ -146,14 +149,24 @@ function expectNoEvent(
async function login(app: TestingApp) {
const user = await app.createUser();
const res = await app
const cookieRes = await app
.POST('/api/auth/sign-in')
.send({ email: user.email, password: user.password })
.expect(200);
const nativeRes = await app
.POST('/api/auth/sign-in')
.set('x-affine-client-kind', 'native')
.send({ email: user.email, password: user.password })
.expect(200);
const tokenRes = await app
.POST('/api/auth/native/exchange')
.set('x-affine-client-kind', 'native')
.send({ code: nativeRes.body.exchangeCode })
.expect(201);
const cookies = res.get('Set-Cookie') ?? [];
const cookies = cookieRes.get('Set-Cookie') ?? [];
const cookieHeader = cookies.map(c => c.split(';')[0]).join('; ');
return { user, cookieHeader };
return { user, cookieHeader, token: tokenRes.body.token as string };
}
function createYjsUpdateBase64() {
@@ -217,6 +230,52 @@ test.after.always(async () => {
await app.close();
});
test('should reject websocket legacy session token auth', async t => {
const { cookieHeader } = await login(app);
const sessionCookie = cookieHeader
.split('; ')
.find(cookie => cookie.startsWith('affine_session='));
const token = sessionCookie?.split('=')[1];
t.truthy(token);
const socket = createClient(url, undefined, { token });
try {
await t.throwsAsync(() => waitForConnect(socket));
} finally {
socket.disconnect();
}
});
test('should connect websocket with jwt auth', async t => {
const { token } = await login(app);
const socket = createClient(url, undefined, { token, tokenType: 'jwt' });
try {
await waitForConnect(socket);
t.true(socket.connected);
} finally {
socket.disconnect();
}
});
test('should reject websocket jwt auth after session deletion', async t => {
const { token } = await login(app);
await app
.POST('/api/auth/sign-out')
.set('Authorization', `Bearer ${token}`)
.expect(200);
const socket = createClient(url, undefined, { token, tokenType: 'jwt' });
try {
await t.throwsAsync(() => waitForConnect(socket));
} finally {
socket.disconnect();
}
});
test('clientVersion=0.25.0 should only receive space:broadcast-doc-update', async t => {
const { user, cookieHeader } = await login(app);
const spaceId = user.id;
@@ -110,6 +110,10 @@ export class TestingApp extends ApplyType<INestApplication>() {
async initTestingDB() {
await initTestingDB(this);
this.clearAuth();
}
clearAuth() {
this.sessionCookie = null;
this.currentUserCookie = null;
this.csrfCookie = null;
+23 -2
View File
@@ -7,6 +7,14 @@ export interface CacheSetOptions {
ttl?: number;
}
const GET_AND_DELETE_LUA = `
local value = redis.call("GET", KEYS[1])
if value then
redis.call("DEL", KEYS[1])
end
return value
`;
export function isValidCacheTtl(ttl: unknown): ttl is number {
return typeof ttl === 'number' && Number.isSafeInteger(ttl) && ttl > 0;
}
@@ -32,12 +40,15 @@ export class CacheProvider {
value: T,
opts: CacheSetOptions = {}
): Promise<boolean> {
if (opts.ttl) {
if (isValidCacheTtl(opts.ttl)) {
return this.redis
.set(key, JSON.stringify(value), 'PX', opts.ttl)
.then(() => true)
.catch(() => false);
}
if (opts.ttl !== undefined) {
return false;
}
return this.redis
.set(key, JSON.stringify(value))
@@ -58,12 +69,15 @@ export class CacheProvider {
value: T,
opts: CacheSetOptions = {}
): Promise<boolean> {
if (opts.ttl) {
if (isValidCacheTtl(opts.ttl)) {
return this.redis
.set(key, JSON.stringify(value), 'PX', opts.ttl, 'NX')
.then(v => !!v)
.catch(() => false);
}
if (opts.ttl !== undefined) {
return false;
}
return this.redis
.set(key, JSON.stringify(value), 'NX')
@@ -78,6 +92,13 @@ export class CacheProvider {
.catch(() => false);
}
async getAndDelete<T = unknown>(key: string): Promise<T | undefined> {
return this.redis
.eval(GET_AND_DELETE_LUA, 1, key)
.then(v => (typeof v === 'string' ? JSON.parse(v) : undefined))
.catch(() => undefined);
}
async has(key: string): Promise<boolean> {
return this.redis
.exists(key)
@@ -7,12 +7,16 @@ import { GqlArgumentsHost } from '@nestjs/graphql';
import type { Request, Response } from 'express';
import { ClsServiceManager } from 'nestjs-cls';
import type { Socket } from 'socket.io';
import { z } from 'zod';
type RequestResponse = {
req: Request;
res?: Response;
};
const RequestCookieValueSchema = z.string().min(1);
const RequestHeaderValueSchema = z.string().min(1);
export function getRequestResponseFromHost(
host: ArgumentsHost
): RequestResponse {
@@ -68,9 +72,7 @@ export function getRequestResponseFromContext(
export function parseCookies(
req: IncomingMessage & { cookies?: Record<string, string> }
) {
if (req.cookies) {
return;
}
if (req.cookies) return;
const cookieStr = req.headers.cookie ?? '';
req.cookies = cookieStr.split(';').reduce(
@@ -103,6 +105,25 @@ export function parseCookies(
);
}
export function getRequestCookie(
req: IncomingMessage & { cookies?: Record<string, unknown> },
name: string
) {
parseCookies(req as IncomingMessage & { cookies?: Record<string, string> });
const value = req.cookies?.[name];
const parsed = RequestCookieValueSchema.safeParse(value);
return parsed.success ? parsed.data : undefined;
}
export function getRequestHeader(req: IncomingMessage, name: string) {
const value = req.headers[name.toLowerCase()];
const parsed = RequestHeaderValueSchema.safeParse(value);
return parsed.success ? parsed.data : undefined;
}
/**
* Request type
*
@@ -0,0 +1,54 @@
import { randomUUID } from 'node:crypto';
import { Injectable } from '@nestjs/common';
import { InvalidAuthState, SessionCache } from '../../base';
import { isValidCacheTtl } from '../../base/cache/provider';
export type AuthChallengePurpose =
| 'oauth_state'
| 'open_app_sign_in'
| 'native_session_exchange'
| 'captcha'
| 'passkey_registration'
| 'passkey_authentication';
@Injectable()
export class AuthChallengeStore {
constructor(private readonly cache: SessionCache) {}
async create<T>(
purpose: AuthChallengePurpose,
payload: T | ((token: string) => T),
ttlMs: number
): Promise<string> {
if (!isValidCacheTtl(ttlMs)) {
throw new InvalidAuthState();
}
const token = randomUUID();
const value =
typeof payload === 'function'
? (payload as (token: string) => T)(token)
: payload;
const stored = await this.cache.set(this.key(purpose, token), value, {
ttl: ttlMs,
});
if (!stored) {
throw new InvalidAuthState();
}
return token;
}
async get<T>(purpose: AuthChallengePurpose, token: string) {
return (await this.cache.get<T>(this.key(purpose, token))) ?? null;
}
async consume<T>(purpose: AuthChallengePurpose, token: string) {
return (await this.cache.getAndDelete<T>(this.key(purpose, token))) ?? null;
}
private key(purpose: AuthChallengePurpose, token: string) {
return `auth_challenge:${purpose}:${token}`;
}
}
@@ -1,4 +1,4 @@
import { resolveMx, resolveTxt, setServers } from 'node:dns/promises';
import { setServers } from 'node:dns/promises';
import {
Body,
@@ -6,7 +6,6 @@ import {
Get,
Header,
HttpStatus,
Logger,
Post,
Query,
Req,
@@ -16,27 +15,33 @@ import type { Request, Response } from 'express';
import {
ActionForbidden,
Config,
CryptoHelper,
EmailTokenNotFound,
getRequestCookie,
InvalidAuthState,
InvalidEmail,
InvalidEmailToken,
SignUpForbidden,
Throttle,
URLHelper,
UseNamedGuard,
WrongSignInCredentials,
} from '../../base';
import { Models, TokenType } from '../../models';
import { Models } from '../../models';
import { validators } from '../utils/validators';
import { Public } from './guard';
import { AuthService } from './service';
import { MagicLinkAuthService } from './magic-link';
import { AuthMethodsService } from './methods';
import { SessionExchangeService } from './native-exchange';
import { OpenAppAuthService } from './open-app';
import { AuthService, sessionUser } from './service';
import { CurrentUser, Session } from './session';
import { SessionIssuer } from './session-issuer';
interface PreflightResponse {
registered: boolean;
hasPassword: boolean;
methods: {
password: { available: boolean };
magicLink: { available: boolean };
oauth: { available: boolean; providers: string[] };
passkey: { available: boolean; discoverable: boolean };
};
}
interface SignInCredential {
@@ -56,17 +61,25 @@ interface OpenAppSignInCredential {
code: string;
}
interface NativeSessionExchangeCredential {
code: string;
}
type SignInResponse = CurrentUser & {
exchangeCode?: string;
};
@Throttle('strict')
@Controller('/api/auth')
export class AuthController {
private readonly logger = new Logger(AuthController.name);
constructor(
private readonly url: URLHelper,
private readonly auth: AuthService,
private readonly models: Models,
private readonly config: Config,
private readonly crypto: CryptoHelper
private readonly sessionIssuer: SessionIssuer,
private readonly magicLink: MagicLinkAuthService,
private readonly openApp: OpenAppAuthService,
private readonly authMethods: AuthMethodsService,
private readonly sessionExchange: SessionExchangeService,
private readonly models: Models
) {
if (env.dev) {
// set DNS servers in dev mode
@@ -89,19 +102,13 @@ export class AuthController {
}
validators.assertValidEmail(params.email);
const user = await this.models.user.getUserByEmail(params.email);
return this.authMethods.loginPreflight(params.email);
}
if (!user) {
return {
registered: false,
hasPassword: false,
};
}
return {
registered: user.registered,
hasPassword: !!user.password,
};
@UseNamedGuard('version')
@Get('/methods')
async boundMethods(@CurrentUser() user: CurrentUser) {
return this.authMethods.boundMethods(user.id);
}
@Public()
@@ -142,10 +149,17 @@ export class AuthController {
email: string,
password: string
) {
const user = await this.auth.signIn(email, password);
const identity = await this.auth.verifyPassword(email, password);
await this.auth.setCookies(req, res, user.id);
res.status(HttpStatus.OK).send(user);
const { exchangeCode } = await this.sessionIssuer.issue(req, res, identity);
const user = await this.models.user.get(identity.userId);
if (!user) {
throw new WrongSignInCredentials({ email });
}
res.status(HttpStatus.OK).send({
...sessionUser(user),
exchangeCode,
} satisfies SignInResponse);
}
async sendMagicLink(
@@ -154,105 +168,10 @@ export class AuthController {
callbackUrl = '/magic-link',
clientNonce?: string
) {
if (!this.url.isAllowedCallbackUrl(callbackUrl)) {
throw new ActionForbidden();
}
const callbackUrlObj = this.url.url(callbackUrl);
const redirectUriInCallback =
callbackUrlObj.searchParams.get('redirect_uri');
if (
redirectUriInCallback &&
!this.url.isAllowedRedirectUri(redirectUriInCallback)
) {
throw new ActionForbidden();
}
// send email magic link
const user = await this.models.user.getUserByEmail(email, {
withDisabled: true,
});
if (!user) {
if (!this.config.auth.allowSignup) {
throw new SignUpForbidden();
}
if (this.config.auth.requireEmailDomainVerification) {
// verify domain has MX, SPF, DMARC records
const [name, domain, ...rest] = email.split('@');
if (rest.length || !domain) {
throw new InvalidEmail({ email });
}
const [mx, spf, dmarc] = await Promise.allSettled([
resolveMx(domain).then(t => t.map(mx => mx.exchange).filter(Boolean)),
resolveTxt(domain).then(t =>
t.map(([k]) => k).filter(txt => txt.includes('v=spf1'))
),
resolveTxt('_dmarc.' + domain).then(t =>
t.map(([k]) => k).filter(txt => txt.includes('v=DMARC1'))
),
]).then(t => t.filter(t => t.status === 'fulfilled').map(t => t.value));
if (!mx?.length || !spf?.length || !dmarc?.length) {
throw new InvalidEmail({ email });
}
// filter out alias emails
if (name.includes('+')) {
throw new InvalidEmail({ email });
}
}
} else if (user.disabled) {
throw new WrongSignInCredentials({ email });
}
const ttlInSec = 30 * 60;
const token = await this.models.verificationToken.create(
TokenType.SignIn,
email,
ttlInSec
);
const otp = this.crypto.otp();
await this.models.magicLinkOtp.upsert(email, otp, token, clientNonce);
const magicLink = this.url.link(callbackUrl, { token: otp, email });
if (env.dev) {
// make it easier to test in dev mode
this.logger.debug(`Magic link: ${magicLink}`);
}
await this.auth.sendSignInEmail(email, magicLink, otp, !user);
res.status(HttpStatus.OK).send({
email: email,
});
const payload = await this.magicLink.send(email, callbackUrl, clientNonce);
res.status(HttpStatus.OK).send(payload);
}
@Public()
/**
* @deprecated Kept for 0.25 clients that still call GET `/api/auth/sign-out`.
* Use POST `/api/auth/sign-out` instead.
*/
@Get('/sign-out')
async signOutDeprecated(
@Res() res: Response,
@Session() session: Session | undefined,
@Query('user_id') userId: string | undefined
) {
res.setHeader('Deprecation', 'true');
if (!session) {
res.status(HttpStatus.OK).send({});
return;
}
await this.auth.signOut(session.sessionId, userId);
await this.auth.refreshCookies(res, session.sessionId);
res.status(HttpStatus.OK).send({});
}
@Public()
@Post('/sign-out')
async signOut(
@Req() req: Request,
@@ -265,14 +184,15 @@ export class AuthController {
return;
}
const csrfCookie = req.cookies?.[AuthService.csrfCookieName] as
| string
| undefined;
if (req.authType === 'jwt') {
await this.auth.signOut(session.sessionId, session.user.id);
res.status(HttpStatus.OK).send({});
return;
}
const csrfCookie = getRequestCookie(req, AuthService.csrfCookieName);
const csrfHeader = req.get('x-affine-csrf-token');
if (
csrfHeader && // optional for backward compatibility, drop after 0.25.0 outdated
(!csrfCookie || csrfCookie !== csrfHeader)
) {
if (!csrfHeader || !csrfCookie || csrfCookie !== csrfHeader) {
throw new ActionForbidden();
}
@@ -286,17 +206,8 @@ export class AuthController {
@UseNamedGuard('version')
@Post('/open-app/sign-in-code')
async openAppSignInCode(@CurrentUser() user?: CurrentUser) {
if (!user) {
throw new ActionForbidden();
}
// short-lived one-time code for handing off the authenticated session
const code = await this.models.verificationToken.create(
TokenType.OpenAppSignIn,
user.id,
5 * 60
);
if (!user) throw new ActionForbidden();
const code = await this.openApp.createSignInCode(user);
return { code };
}
@@ -308,21 +219,21 @@ export class AuthController {
@Res() res: Response,
@Body() credential: OpenAppSignInCredential
) {
if (!credential?.code) {
throw new InvalidAuthState();
}
if (!credential?.code) throw new InvalidAuthState();
const identity = await this.openApp.verifySignInCode(credential.code);
const { exchangeCode } = await this.sessionIssuer.issue(req, res, identity);
res.send({ id: identity.userId, exchangeCode });
}
const tokenRecord = await this.models.verificationToken.get(
TokenType.OpenAppSignIn,
credential.code
);
if (!tokenRecord?.credential) {
throw new InvalidAuthState();
}
await this.auth.setCookies(req, res, tokenRecord.credential);
res.send({ id: tokenRecord.credential });
@Public()
@UseNamedGuard('version')
@Post('/native/exchange')
async exchangeSession(
@Req() req: Request,
@Body() credential: NativeSessionExchangeCredential
) {
if (!credential?.code) throw new InvalidAuthState();
return await this.sessionExchange.exchange(req, credential.code);
}
@Public()
@@ -334,42 +245,11 @@ export class AuthController {
@Body()
{ email, token: otp, client_nonce: clientNonce }: MagicLinkCredential
) {
if (!otp || !email) {
throw new EmailTokenNotFound();
}
if (!otp || !email) throw new EmailTokenNotFound();
validators.assertValidEmail(email);
const consumed = await this.models.magicLinkOtp.consume(
email,
otp,
clientNonce
);
if (!consumed.ok) {
if (consumed.reason === 'nonce_mismatch') {
throw new InvalidAuthState();
}
throw new InvalidEmailToken();
}
const token = consumed.token;
const tokenRecord = await this.models.verificationToken.verify(
TokenType.SignIn,
token,
{
credential: email,
}
);
if (!tokenRecord) {
throw new InvalidEmailToken();
}
const user = await this.models.user.fulfill(email);
await this.auth.setCookies(req, res, user.id);
res.send({ id: user.id });
const identity = await this.magicLink.verify(email, otp, clientNonce);
const { exchangeCode } = await this.sessionIssuer.issue(req, res, identity);
res.send({ id: identity.userId, exchangeCode });
}
@UseNamedGuard('version')
@@ -377,24 +257,6 @@ export class AuthController {
@Public()
@Get('/session')
async currentSessionUser(@CurrentUser() user?: CurrentUser) {
return {
user,
};
}
@Throttle('default', { limit: 1200 })
@Public()
@Get('/sessions')
async currentSessionUsers(@Req() req: Request) {
const token = req.cookies[AuthService.sessionCookieName];
if (!token) {
return {
users: [],
};
}
return {
users: await this.auth.getUserList(token),
};
return { user };
}
}
@@ -0,0 +1,77 @@
import { resolveMx, resolveTxt } from 'node:dns/promises';
const EMAIL_DOMAIN_DNS_TIMEOUT_MS = 2_000;
type DomainLookups = {
resolveMx: typeof resolveMx;
resolveTxt: typeof resolveTxt;
};
const defaultLookups: DomainLookups = {
resolveMx,
resolveTxt,
};
function joinTxtRecords(records: string[][]) {
return records.map(record => record.join(''));
}
async function withTimeout<T>(promise: Promise<T>, timeoutMs: number) {
let timeout: ReturnType<typeof setTimeout> | undefined;
const timeoutPromise = new Promise<never>((_, reject) => {
timeout = setTimeout(
() => reject(new Error('DNS lookup timed out')),
timeoutMs
);
});
try {
return await Promise.race([promise, timeoutPromise]);
} finally {
if (timeout) {
clearTimeout(timeout);
}
}
}
export async function verifyEmailDomainRecords(
email: string,
lookups: DomainLookups = defaultLookups,
timeoutMs = EMAIL_DOMAIN_DNS_TIMEOUT_MS
) {
const [name, domain, ...rest] = email.split('@');
if (rest.length || !domain || name.includes('+')) {
return false;
}
const [mx, spf, dmarc] = await Promise.allSettled([
withTimeout(
lookups
.resolveMx(domain)
.then(records => records.map(mx => mx.exchange).filter(Boolean)),
timeoutMs
),
withTimeout(
lookups
.resolveTxt(domain)
.then(records =>
joinTxtRecords(records).filter(txt => txt.includes('v=spf1'))
),
timeoutMs
),
withTimeout(
lookups
.resolveTxt('_dmarc.' + domain)
.then(records =>
joinTxtRecords(records).filter(txt => txt.includes('v=DMARC1'))
),
timeoutMs
),
]).then(results =>
results
.filter(result => result.status === 'fulfilled')
.map(result => result.value)
);
return !!mx?.length && !!spf?.length && !!dmarc?.length;
}
+161 -41
View File
@@ -23,6 +23,12 @@ import {
UnsupportedClientVersion,
} from '../../base';
import { WEBSOCKET_OPTIONS } from '../../base/websocket';
import {
extractTokenFromHeader,
getSessionOptionsFromRequest,
SessionIdSchema,
} from './input';
import { isLikelyJwt, JwtSessionService } from './jwt-session';
import { AuthService } from './service';
import { Session, TokenSession } from './session';
@@ -31,9 +37,16 @@ const INTERNAL_ENTRYPOINT_SYMBOL = Symbol('internal');
const INTERNAL_ACCESS_TOKEN_TTL_MS = 5 * 60 * 1000;
const INTERNAL_ACCESS_TOKEN_CLOCK_SKEW_MS = 30 * 1000;
type AuthenticatedRequestSession =
| { type: 'jwt'; session: Session }
| { type: 'cookie_session'; session: Session }
| { type: 'legacy_bearer_session'; session: Session }
| { type: 'access_token'; token: TokenSession };
@Injectable()
export class AuthGuard implements CanActivate, OnModuleInit {
private auth!: AuthService;
private jwtSession!: JwtSessionService;
private readonly cachedVersionRange = new Map<string, semver.Range | null>();
private static readonly HARD_REQUIRED_VERSION = '>=0.25.0';
private static readonly CANARY_REQUIRED_VERSION = 'canary (within 2 months)';
@@ -48,6 +61,7 @@ export class AuthGuard implements CanActivate, OnModuleInit {
onModuleInit() {
this.auth = this.ref.get(AuthService, { strict: false });
this.jwtSession = this.ref.get(JwtSessionService, { strict: false });
}
async canActivate(context: ExecutionContext) {
@@ -110,12 +124,102 @@ export class AuthGuard implements CanActivate, OnModuleInit {
res?: Response,
isPublic = false
): Promise<Session | TokenSession | null> {
const userSession = await this.signInWithCookie(req, res, isPublic);
if (userSession) {
return userSession;
const result = await this.resolveRequestSession(req, res, isPublic);
return result?.type === 'access_token'
? result.token
: (result?.session ?? null);
}
private async resolveRequestSession(
req: Request,
res?: Response,
isPublic = false
): Promise<AuthenticatedRequestSession | null> {
const bearer = req.headers.authorization
? extractTokenFromHeader(req.headers.authorization)
: undefined;
let ignoredInvalidPublicJwt = false;
if (bearer && isLikelyJwt(bearer)) {
try {
const session = await this.signInWithJwt(req, bearer, res, isPublic);
return session ? { type: 'jwt', session } : null;
} catch (err) {
if (!isPublic) throw err;
ignoredInvalidPublicJwt = true;
}
}
return await this.signInWithAccessToken(req);
if (bearer && !ignoredInvalidPublicJwt) {
// Legacy auth compatibility: old clients may still send opaque session ids as bearer tokens.
const legacyBearerSession = await this.signInWithSessionId(
req,
bearer,
res,
isPublic
);
if (legacyBearerSession) {
return { type: 'legacy_bearer_session', session: legacyBearerSession };
}
const token = await this.signInWithAccessToken(req);
return token ? { type: 'access_token', token } : null;
}
const session = await this.signInWithCookie(req, res, isPublic);
return session ? { type: 'cookie_session', session } : null;
}
async signInWithJwt(
req: Request,
token: string,
res?: Response,
isPublic = false
): Promise<Session | null> {
if (req.session && req.authType === 'jwt') return req.session;
const session = await this.jwtSession.verify(token);
const versionAllowed = await this.checkUserSessionClientVersion(
req,
session,
res,
isPublic
);
if (!versionAllowed) return null;
req.session = session;
req.authType = 'jwt';
return req.session;
}
async signInWithSessionId(
req: Request,
sessionId: string,
res?: Response,
isPublic = false
): Promise<Session | null> {
if (req.session && req.session.sessionId === sessionId) return req.session;
const parsedSessionId = SessionIdSchema.safeParse(sessionId);
if (!parsedSessionId.success) return null;
const { userId } = getSessionOptionsFromRequest(req);
const userSession = await this.auth.getUserSession(
parsedSessionId.data,
userId
);
if (!userSession) return null;
req.session = { ...userSession.session, user: userSession.user };
const versionAllowed = await this.checkUserSessionClientVersion(
req,
req.session,
res,
isPublic
);
if (!versionAllowed) {
req.session = undefined;
return null;
}
req.authType = 'session';
return req.session;
}
async signInWithCookie(
@@ -123,37 +227,24 @@ export class AuthGuard implements CanActivate, OnModuleInit {
res?: Response,
isPublic = false
): Promise<Session | null> {
if (req.session) {
return req.session;
}
if (req.session) return req.session;
// TODO(@forehalo): a cache for user session
const userSession = await this.auth.getUserSessionFromRequest(req, res);
if (userSession) {
const headerClientVersion = getClientVersionFromRequest(req);
if (this.config.client.versionControl.enabled) {
const clientVersion =
headerClientVersion ??
userSession.session.refreshClientVersion ??
userSession.session.signInClientVersion;
req.session = { ...userSession.session, user: userSession.user };
const versionCheckResult = this.checkClientVersion(clientVersion);
if (!versionCheckResult.ok) {
await this.auth.signOut(userSession.session.sessionId);
if (res) {
await this.auth.refreshCookies(res, userSession.session.sessionId);
}
if (isPublic) {
return null;
}
throw new UnsupportedClientVersion({
clientVersion: clientVersion ?? 'unset_or_invalid',
requiredVersion: versionCheckResult.requiredVersion,
});
}
const versionAllowed = await this.checkUserSessionClientVersion(
req,
req.session,
res,
isPublic
);
if (!versionAllowed) {
req.session = undefined;
return null;
}
if (res) {
@@ -165,10 +256,7 @@ export class AuthGuard implements CanActivate, OnModuleInit {
);
}
req.session = {
...userSession.session,
user: userSession.user,
};
req.authType = 'session';
return req.session;
}
@@ -176,6 +264,42 @@ export class AuthGuard implements CanActivate, OnModuleInit {
return null;
}
private async checkUserSessionClientVersion(
req: Request,
session: Session,
res?: Response,
isPublic = false
) {
if (!this.config.client.versionControl.enabled) {
return true;
}
const headerClientVersion = getClientVersionFromRequest(req);
const clientVersion =
headerClientVersion ??
session.refreshClientVersion ??
session.signInClientVersion;
const versionCheckResult = this.checkClientVersion(clientVersion);
if (versionCheckResult.ok) {
return true;
}
await this.auth.signOut(session.sessionId);
if (res) {
await this.auth.refreshCookies(res, session.sessionId);
}
if (isPublic) {
return false;
}
throw new UnsupportedClientVersion({
clientVersion: clientVersion ?? 'unset_or_invalid',
requiredVersion: versionCheckResult.requiredVersion,
});
}
async signInWithAccessToken(req: Request): Promise<TokenSession | null> {
if (req.token) {
return req.token;
@@ -184,10 +308,8 @@ export class AuthGuard implements CanActivate, OnModuleInit {
const tokenSession = await this.auth.getTokenSessionFromRequest(req);
if (tokenSession) {
req.token = {
...tokenSession.token,
user: tokenSession.user,
};
req.token = { ...tokenSession.token, user: tokenSession.user };
req.authType = 'access_token';
return req.token;
}
@@ -280,11 +402,9 @@ export const AuthWebsocketOptionsProvider: FactoryProvider = {
// compatibility with websocket request
parseCookies(upgradeReq);
upgradeReq.cookies = {
[AuthService.sessionCookieName]: handshake.auth.token,
[AuthService.userCookieName]: handshake.auth.userId,
...upgradeReq.cookies,
};
if (handshake.auth.tokenType === 'jwt') {
upgradeReq.headers.authorization = `Bearer ${handshake.auth.token}`;
}
const session = await (async () => {
try {
@@ -0,0 +1,12 @@
export type AuthMethod =
| 'password'
| 'magic_link'
| 'oauth'
| 'open_app'
| 'passkey';
export interface VerifiedIdentity {
userId: string;
method: AuthMethod;
clientVersion?: string;
}
+34 -2
View File
@@ -6,11 +6,18 @@ import { FeatureModule } from '../features';
import { MailModule } from '../mail';
import { QuotaModule } from '../quota';
import { UserModule } from '../user';
import { AuthChallengeStore } from './challenge-store';
import { AuthController } from './controller';
import { AuthGuard, AuthWebsocketOptionsProvider } from './guard';
import { AuthCronJob } from './job';
import { JwtSessionService } from './jwt-session';
import { MagicLinkAuthService } from './magic-link';
import { AuthMethodsService } from './methods';
import { SessionExchangeService } from './native-exchange';
import { OpenAppAuthService } from './open-app';
import { AuthResolver } from './resolver';
import { AuthService } from './service';
import { SessionIssuer } from './session-issuer';
@Module({
imports: [FeatureModule, UserModule, QuotaModule, MailModule],
@@ -18,15 +25,40 @@ import { AuthService } from './service';
AuthService,
AuthResolver,
AuthGuard,
JwtSessionService,
SessionIssuer,
AuthChallengeStore,
MagicLinkAuthService,
OpenAppAuthService,
AuthMethodsService,
SessionExchangeService,
AuthCronJob,
AuthWebsocketOptionsProvider,
],
exports: [AuthService, AuthGuard, AuthWebsocketOptionsProvider],
exports: [
AuthService,
AuthGuard,
JwtSessionService,
SessionIssuer,
AuthChallengeStore,
MagicLinkAuthService,
OpenAppAuthService,
AuthMethodsService,
SessionExchangeService,
AuthWebsocketOptionsProvider,
],
controllers: [AuthController],
})
export class AuthModule {}
export { AuthChallengeStore } from './challenge-store';
export * from './guard';
export * from './identity';
export * from './input';
export { MagicLinkAuthService } from './magic-link';
export * from './methods';
export { SessionExchangeService };
export { OpenAppAuthService } from './open-app';
export { ClientTokenType } from './resolver';
export { AuthService };
export { AuthService, JwtSessionService, SessionIssuer };
export * from './session';
@@ -0,0 +1,86 @@
import type { Request } from 'express';
import { z } from 'zod';
import { getRequestCookie, getRequestHeader } from '../../base';
export const CLIENT_KIND_HEADER = 'x-affine-client-kind';
export const SESSION_COOKIE_NAME = 'affine_session';
export const USER_COOKIE_NAME = 'affine_user_id';
export const CSRF_COOKIE_NAME = 'affine_csrf_token';
const NativeClientOriginSchema = z
.enum(['capacitor://localhost', 'ionic://localhost', 'https://localhost'])
.optional();
const NativeClientHeadersSchema = z.object({
clientKind: z.literal('native'),
origin: NativeClientOriginSchema,
});
export const BearerHeaderSchema = z
.string()
.regex(/^Bearer\s+\S+$/i)
.transform(value => value.replace(/^Bearer\s+/i, ''));
export function extractTokenFromHeader(authorization: string) {
const parsed = BearerHeaderSchema.safeParse(authorization);
return parsed.success ? parsed.data : undefined;
}
export const SessionIdSchema = z.string().uuid();
export const UserIdSchema = z.union([
z.string().uuid(),
z.string().regex(/^[A-Za-z0-9_-]{1,128}$/),
]);
export const OAuthCallbackBodySchema = z.object({
code: z.string().min(1),
state: z.string().min(1),
client_nonce: z
.string()
.min(1)
.nullish()
.transform(value => value ?? undefined),
});
export const OAuthPreflightBodySchema = z.object({
provider: z.string().min(1),
redirect_uri: z
.string()
.min(1)
.nullish()
.transform(value => value ?? undefined),
client: z
.string()
.min(1)
.nullish()
.transform(value => value ?? undefined),
client_nonce: z.string().min(1),
});
export const OAuthStateEnvelopeSchema = z.object({
state: z.string().min(1),
provider: z.string().min(1).optional(),
});
export function getSessionOptionsFromRequest(req: Request) {
const sessionId = SessionIdSchema.safeParse(
getRequestCookie(req, SESSION_COOKIE_NAME)
);
const userId = UserIdSchema.safeParse(
getRequestCookie(req, USER_COOKIE_NAME)
);
return {
sessionId: sessionId.success ? sessionId.data : undefined,
userId: userId.success ? userId.data : undefined,
};
}
export function isNativeClientRequest(req: Request) {
return NativeClientHeadersSchema.safeParse({
clientKind: getRequestHeader(req, CLIENT_KIND_HEADER),
origin: getRequestHeader(req, 'origin'),
}).success;
}
@@ -0,0 +1,92 @@
import { Injectable } from '@nestjs/common';
import jwt, { type JwtPayload } from 'jsonwebtoken';
import { AuthenticationRequired, CryptoHelper } from '../../base';
import { Models } from '../../models';
import { sessionUser } from './service';
import type { CurrentUser, Session } from './session';
const JWT_SESSION_TYPE = 'user_session';
const JWT_SESSION_ISSUER = 'affine';
const JWT_SESSION_AUDIENCE = 'affine-client';
const JWT_SESSION_TTL = 15 * 60;
export interface SignedJwtSession {
token: string;
expiresAt: Date;
}
interface UserSessionJwtPayload extends JwtPayload {
sub: string;
sid: string;
typ: typeof JWT_SESSION_TYPE;
}
function isUserSessionJwtPayload(
payload: string | JwtPayload
): payload is UserSessionJwtPayload {
return (
typeof payload !== 'string' &&
typeof payload.sub === 'string' &&
typeof payload.sid === 'string' &&
payload.typ === JWT_SESSION_TYPE
);
}
@Injectable()
export class JwtSessionService {
constructor(
private readonly crypto: CryptoHelper,
private readonly models: Models
) {}
private get currentKey() {
return Buffer.concat([
Buffer.from('affine:user-session-jwt:v1:'),
this.crypto.keyPair.sha256.privateKey,
]);
}
sign(userId: string, sessionId: string): SignedJwtSession {
const expiresAt = new Date(Date.now() + JWT_SESSION_TTL * 1000);
const token = jwt.sign(
{ sid: sessionId, typ: JWT_SESSION_TYPE },
this.currentKey,
{
algorithm: 'HS256',
audience: JWT_SESSION_AUDIENCE,
expiresIn: JWT_SESSION_TTL,
issuer: JWT_SESSION_ISSUER,
subject: userId,
}
);
return { token, expiresAt };
}
async verify(token: string): Promise<Session> {
let payload: string | JwtPayload;
try {
payload = jwt.verify(token, this.currentKey, {
algorithms: ['HS256'],
audience: JWT_SESSION_AUDIENCE,
issuer: JWT_SESSION_ISSUER,
});
} catch {
throw new AuthenticationRequired();
}
if (!isUserSessionJwtPayload(payload)) throw new AuthenticationRequired();
const userSession = await this.models.session
.findUserSessionsBySessionId(payload.sid)
.then(sessions => sessions.find(s => s.userId === payload.sub));
if (!userSession) throw new AuthenticationRequired();
const user = await this.models.user.get(payload.sub);
if (!user) throw new AuthenticationRequired();
return { ...userSession, user: sessionUser(user) as CurrentUser };
}
}
export function isLikelyJwt(token: string) {
return token.split('.').length === 3;
}
@@ -0,0 +1,128 @@
import { Injectable, Logger } from '@nestjs/common';
import {
ActionForbidden,
Config,
CryptoHelper,
InvalidAuthState,
InvalidEmail,
InvalidEmailToken,
SignUpForbidden,
URLHelper,
WrongSignInCredentials,
} from '../../base';
import { Models, TokenType } from '../../models';
import { validators } from '../utils/validators';
import { verifyEmailDomainRecords } from './email-domain';
import type { VerifiedIdentity } from './identity';
import { AuthService } from './service';
@Injectable()
export class MagicLinkAuthService {
private readonly logger = new Logger(MagicLinkAuthService.name);
constructor(
private readonly url: URLHelper,
private readonly auth: AuthService,
private readonly models: Models,
private readonly config: Config,
private readonly crypto: CryptoHelper
) {}
async send(email: string, callbackUrl = '/magic-link', clientNonce?: string) {
validators.assertValidEmail(email);
if (!this.url.isAllowedCallbackUrl(callbackUrl)) {
throw new ActionForbidden();
}
const callbackUrlObj = this.url.url(callbackUrl);
const redirectUriInCallback =
callbackUrlObj.searchParams.get('redirect_uri');
if (
redirectUriInCallback &&
!this.url.isAllowedRedirectUri(redirectUriInCallback)
) {
throw new ActionForbidden();
}
const user = await this.models.user.getUserByEmail(email, {
withDisabled: true,
});
if (!user) {
await this.assertSignupAllowed(email);
} else if (user.disabled) {
throw new WrongSignInCredentials({ email });
}
const ttlInSec = 30 * 60;
const token = await this.models.verificationToken.create(
TokenType.SignIn,
email,
ttlInSec
);
const otp = this.crypto.otp();
await this.models.magicLinkOtp.upsert(email, otp, token, clientNonce);
const magicLink = this.url.link(callbackUrl, { token: otp, email });
if (env.dev) {
this.logger.debug(`Magic link: ${magicLink}`);
}
await this.auth.sendSignInEmail(email, magicLink, otp, !user);
return { email };
}
async verify(
email: string,
otp: string,
clientNonce?: string
): Promise<VerifiedIdentity> {
validators.assertValidEmail(email);
const consumed = await this.models.magicLinkOtp.consume(
email,
otp,
clientNonce
);
if (!consumed.ok) {
if (consumed.reason === 'nonce_mismatch') {
throw new InvalidAuthState();
}
throw new InvalidEmailToken();
}
const tokenRecord = await this.models.verificationToken.verify(
TokenType.SignIn,
consumed.token,
{
credential: email,
}
);
if (!tokenRecord) {
throw new InvalidEmailToken();
}
const user = await this.models.user.fulfill(email);
return { userId: user.id, method: 'magic_link' };
}
private async assertSignupAllowed(email: string) {
if (!this.config.auth.allowSignup) {
throw new SignUpForbidden();
}
if (!this.config.auth.requireEmailDomainVerification) {
return;
}
if (!(await verifyEmailDomainRecords(email))) {
throw new InvalidEmail({ email });
}
}
}
@@ -0,0 +1,131 @@
import { Injectable } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import { PrismaClient } from '@prisma/client';
import { Config } from '../../base';
import { Models, type User } from '../../models';
import { verifyEmailDomainRecords } from './email-domain';
export const AUTH_OAUTH_PROVIDER_READER = Symbol('AUTH_OAUTH_PROVIDER_READER');
interface OAuthProviderReader {
providers: string[];
}
export interface LoginAuthMethods {
password: { available: boolean };
magicLink: { available: boolean };
oauth: { available: boolean; providers: string[] };
passkey: { available: boolean; discoverable: boolean };
}
export interface BoundAuthMethods {
password: { bound: boolean };
oauth: { bound: boolean; providers: string[] };
passkey: { bound: boolean; count: number };
}
@Injectable()
export class AuthMethodsService {
constructor(
private readonly config: Config,
private readonly models: Models,
private readonly db: PrismaClient,
private readonly ref: ModuleRef
) {}
async loginPreflight(email: string) {
const [user, userWithDisabled] = await Promise.all([
this.models.user.getUserByEmail(email),
this.models.user.getUserByEmail(email, {
withDisabled: true,
}),
]);
const disabledUser =
userWithDisabled?.disabled && !user ? userWithDisabled : null;
const providers = this.oauthProviders();
return {
registered: !!user?.registered,
methods: {
password: {
available:
!!user?.password &&
!user.disabled &&
(await this.canPasswordSignIn(email)),
},
magicLink: {
available: await this.canMagicLinkSignIn(email, user, disabledUser),
},
oauth: {
available: providers.length > 0,
providers,
},
passkey: {
available: false,
discoverable: false,
},
} satisfies LoginAuthMethods,
};
}
async boundMethods(userId: string): Promise<BoundAuthMethods> {
const [user, connectedAccounts] = await Promise.all([
this.models.user.get(userId),
this.db.connectedAccount.findMany({
select: { provider: true },
where: { userId },
}),
]);
const providers = Array.from(
new Set(connectedAccounts.map(account => account.provider))
);
return {
password: { bound: !!user?.password },
oauth: { bound: providers.length > 0, providers },
passkey: { bound: false, count: 0 },
};
}
private async canPasswordSignIn(_email: string) {
return true;
}
private async canMagicLinkSignIn(
email: string,
user: User | null,
disabledUser: User | null
) {
if (disabledUser) {
return false;
}
if (user) {
return !user.disabled;
}
if (!this.config.auth.allowSignup) {
return false;
}
return this.emailDomainAllowed(email);
}
private async emailDomainAllowed(email: string) {
if (!this.config.auth.requireEmailDomainVerification) {
return true;
}
return verifyEmailDomainRecords(email);
}
private oauthProviders() {
try {
const reader = this.ref.get<OAuthProviderReader>(
AUTH_OAUTH_PROVIDER_READER,
{ strict: false }
);
return reader.providers;
} catch {
return [];
}
}
}
@@ -0,0 +1,59 @@
import { Injectable } from '@nestjs/common';
import type { Request } from 'express';
import { ActionForbidden, InvalidAuthState } from '../../base';
import { AuthChallengeStore } from './challenge-store';
import { isNativeClientRequest } from './input';
import { JwtSessionService } from './jwt-session';
import { AuthService } from './service';
interface SessionExchangePayload {
userId: string;
sessionId: string;
}
@Injectable()
export class SessionExchangeService {
constructor(
private readonly auth: AuthService,
private readonly challenges: AuthChallengeStore,
private readonly jwtSession: JwtSessionService
) {}
async createCode(req: Request, userId: string, sessionId: string) {
if (!isNativeClientRequest(req)) {
return;
}
return this.challenges.create<SessionExchangePayload>(
'native_session_exchange',
{ userId, sessionId },
60 * 1000
);
}
async exchange(req: Request, code: string) {
if (!isNativeClientRequest(req)) {
throw new ActionForbidden();
}
const payload = await this.challenges.consume<SessionExchangePayload>(
'native_session_exchange',
code
);
if (!payload?.userId || !payload.sessionId) {
throw new InvalidAuthState();
}
const session = await this.auth.getUserSession(
payload.sessionId,
payload.userId
);
if (!session) {
throw new InvalidAuthState();
}
return this.jwtSession.sign(payload.userId, payload.sessionId);
}
}
@@ -0,0 +1,32 @@
import { Injectable } from '@nestjs/common';
import { InvalidAuthState } from '../../base';
import { AuthChallengeStore } from './challenge-store';
import type { VerifiedIdentity } from './identity';
import type { CurrentUser } from './session';
@Injectable()
export class OpenAppAuthService {
constructor(private readonly challenges: AuthChallengeStore) {}
async createSignInCode(user: CurrentUser) {
return this.challenges.create(
'open_app_sign_in',
{ userId: user.id },
5 * 60 * 1000
);
}
async verifySignInCode(code: string): Promise<VerifiedIdentity> {
const payload = await this.challenges.consume<{ userId?: string }>(
'open_app_sign_in',
code
);
if (!payload?.userId) {
throw new InvalidAuthState();
}
return { userId: payload.userId, method: 'open_app' };
}
}
@@ -63,7 +63,7 @@ export class AuthResolver {
@ResolveField(() => ClientTokenType, {
name: 'token',
deprecationReason: 'use [/api/auth/sign-in?native=true] instead',
deprecationReason: 'use native session exchange instead',
})
async clientToken(
@CurrentUser() currentUser: CurrentUser,
@@ -4,14 +4,18 @@ import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
import type { CookieOptions, Request, Response } from 'express';
import { assign, pick } from 'lodash-es';
import {
Config,
getClientVersionFromRequest,
SignUpForbidden,
} from '../../base';
import { Config, SignUpForbidden } from '../../base';
import { Models, type User, type UserSession } from '../../models';
import { Mailer } from '../mail/mailer';
import { createDevUsers } from './dev';
import type { VerifiedIdentity } from './identity';
import {
CSRF_COOKIE_NAME,
extractTokenFromHeader,
getSessionOptionsFromRequest,
SESSION_COOKIE_NAME,
USER_COOKIE_NAME,
} from './input';
import type { CurrentUser } from './session';
export function sessionUser(
@@ -27,20 +31,12 @@ export function sessionUser(
});
}
function extractTokenFromHeader(authorization: string) {
if (!/^Bearer\s/i.test(authorization)) {
return;
}
return authorization.substring(7);
}
@Injectable()
export class AuthService implements OnApplicationBootstrap {
readonly cookieOptions: CookieOptions;
static readonly sessionCookieName = 'affine_session';
static readonly userCookieName = 'affine_user_id';
static readonly csrfCookieName = 'affine_csrf_token';
static readonly sessionCookieName = SESSION_COOKIE_NAME;
static readonly userCookieName = USER_COOKIE_NAME;
static readonly csrfCookieName = CSRF_COOKIE_NAME;
constructor(
private readonly config: Config,
@@ -90,6 +86,14 @@ export class AuthService implements OnApplicationBootstrap {
return this.models.user.signIn(email, password).then(sessionUser);
}
async verifyPassword(
email: string,
password: string
): Promise<VerifiedIdentity> {
const user = await this.models.user.signIn(email, password);
return { userId: user.id, method: 'password' };
}
async signOut(sessionId: string, userId?: string) {
// sign out all users in the session
if (!userId) {
@@ -104,10 +108,7 @@ export class AuthService implements OnApplicationBootstrap {
userId?: string
): Promise<{ user: CurrentUser; session: UserSession } | null> {
const sessions = await this.getUserSessions(sessionId);
if (!sessions.length) {
return null;
}
if (!sessions.length) return null;
let userSession: UserSession | undefined;
@@ -201,55 +202,6 @@ export class AuthService implements OnApplicationBootstrap {
return await this.models.session.deleteUserSessions(userId);
}
getSessionOptionsFromRequest(req: Request) {
let sessionId: string | undefined =
req.cookies[AuthService.sessionCookieName];
if (!sessionId && req.headers.authorization) {
sessionId = extractTokenFromHeader(req.headers.authorization);
}
const userId: string | undefined =
req.cookies[AuthService.userCookieName] ||
req.headers[AuthService.userCookieName.replaceAll('_', '-')];
return {
sessionId,
userId,
};
}
async setCookies(
req: Request,
res: Response,
userId: string,
clientVersion?: string
) {
const { sessionId } = this.getSessionOptionsFromRequest(req);
const signInClientVersion =
clientVersion ?? getClientVersionFromRequest(req);
const userSession = await this.createUserSession(
userId,
sessionId,
undefined,
signInClientVersion
);
res.cookie(AuthService.sessionCookieName, userSession.sessionId, {
...this.cookieOptions,
expires: userSession.expiresAt ?? void 0,
});
res.cookie(AuthService.csrfCookieName, randomUUID(), {
...this.cookieOptions,
httpOnly: false,
expires: userSession.expiresAt ?? void 0,
});
this.setUserCookie(res, userId);
}
async refreshCookies(res: Response, sessionId?: string) {
if (sessionId) {
const users = await this.getUserList(sessionId);
@@ -264,7 +216,7 @@ export class AuthService implements OnApplicationBootstrap {
this.clearCookies(res);
}
private clearCookies(res: Response<any, Record<string, any>>) {
clearCookies(res: Response<any, Record<string, any>>) {
res.clearCookie(AuthService.sessionCookieName);
res.clearCookie(AuthService.userCookieName);
res.clearCookie(AuthService.csrfCookieName);
@@ -281,12 +233,8 @@ export class AuthService implements OnApplicationBootstrap {
}
async getUserSessionFromRequest(req: Request, res?: Response) {
const { sessionId, userId } = this.getSessionOptionsFromRequest(req);
if (!sessionId) {
return null;
}
const { sessionId, userId } = getSessionOptionsFromRequest(req);
if (!sessionId) return null;
const session = await this.getUserSession(sessionId, userId);
if (res) {
@@ -0,0 +1,73 @@
import { randomUUID } from 'node:crypto';
import { Injectable } from '@nestjs/common';
import type { Request, Response } from 'express';
import { getClientVersionFromRequest, getRequestCookie } from '../../base';
import type { VerifiedIdentity } from './identity';
import { isNativeClientRequest } from './input';
import { SessionExchangeService } from './native-exchange';
import { AuthService } from './service';
export type IssuedSession = {
userId: string;
sessionId: string;
exchangeCode?: string;
};
@Injectable()
export class SessionIssuer {
constructor(
private readonly auth: AuthService,
private readonly sessionExchange: SessionExchangeService
) {}
async issue(
req: Request,
res: Response,
identity: VerifiedIdentity
): Promise<IssuedSession> {
const nativeClient = isNativeClientRequest(req);
const sessionId =
req.authType === 'jwt'
? req.session?.sessionId
: getRequestCookie(req, AuthService.sessionCookieName);
const signInClientVersion =
identity.clientVersion ?? getClientVersionFromRequest(req);
const userSession = await this.auth.createUserSession(
identity.userId,
sessionId,
undefined,
signInClientVersion
);
if (nativeClient) {
this.auth.clearCookies(res);
} else {
res.cookie(AuthService.sessionCookieName, userSession.sessionId, {
...this.auth.cookieOptions,
expires: userSession.expiresAt ?? void 0,
});
res.cookie(AuthService.csrfCookieName, randomUUID(), {
...this.auth.cookieOptions,
httpOnly: false,
expires: userSession.expiresAt ?? void 0,
});
this.auth.setUserCookie(res, identity.userId);
}
const exchangeCode = await this.sessionExchange.createCode(
req,
identity.userId,
userSession.sessionId
);
return {
userId: identity.userId,
sessionId: userSession.sessionId,
exchangeCode,
};
}
}
@@ -10,7 +10,7 @@ import {
UseNamedGuard,
} from '../../base';
import { Models } from '../../models';
import { AuthService, Public } from '../auth';
import { Public, SessionIssuer } from '../auth';
import { ServerService } from '../config';
import { validators } from '../utils/validators';
@@ -26,7 +26,7 @@ export class CustomSetupController {
constructor(
private readonly config: Config,
private readonly models: Models,
private readonly auth: AuthService,
private readonly sessionIssuer: SessionIssuer,
private readonly mutex: Mutex,
private readonly server: ServerService
) {}
@@ -72,7 +72,10 @@ export class CustomSetupController {
'selfhost setup'
);
await this.auth.setCookies(req, res, user.id);
await this.sessionIssuer.issue(req, res, {
userId: user.id,
method: 'password',
});
res.send({ id: user.id, email: user.email, name: user.name });
} catch (e) {
await this.models.user.delete(user.id);
+1
View File
@@ -2,6 +2,7 @@ declare namespace Express {
interface Request {
session?: import('./core/auth/session').Session;
token?: import('./core/auth/session').TokenSession;
authType?: 'jwt' | 'session' | 'access_token';
}
}
@@ -13,8 +13,6 @@ export enum TokenType {
VerifyEmail,
ChangeEmail,
ChangePassword,
Challenge,
OpenAppSignIn,
}
@Injectable()
@@ -12,7 +12,7 @@ import {
OnEvent,
} from '../../base';
import { ServerFeature, ServerService } from '../../core';
import { Models, TokenType } from '../../models';
import { AuthChallengeStore } from '../../core/auth';
import { verifyChallengeResponse } from '../../native';
import { CaptchaConfig } from './types';
@@ -28,7 +28,7 @@ export class CaptchaService {
constructor(
private readonly config: Config,
private readonly models: Models,
private readonly challenges: AuthChallengeStore,
private readonly server: ServerService
) {
this.captcha = config.captcha.config;
@@ -93,10 +93,10 @@ export class CaptchaService {
async getChallengeToken() {
const resource = randomUUID();
const challenge = await this.models.verificationToken.create(
TokenType.Challenge,
const challenge = await this.challenges.create(
'captcha',
resource,
5 * 60
5 * 60 * 1000
);
return {
@@ -117,9 +117,7 @@ export class CaptchaService {
const challenge = credential.challenge;
let resource: string | null = null;
if (typeof challenge === 'string' && challenge) {
resource = await this.models.verificationToken
.get(TokenType.Challenge, challenge)
.then(token => token?.credential || null);
resource = await this.challenges.consume<string>('captcha', challenge);
}
if (resource) {
@@ -3,68 +3,63 @@ import {
Controller,
HttpCode,
HttpStatus,
Logger,
Post,
type RawBodyRequest,
Req,
Res,
} from '@nestjs/common';
import { ConnectedAccount } from '@prisma/client';
import type { Request, Response } from 'express';
import {
ActionForbidden,
Config,
getClientVersionFromRequest,
InvalidAuthState,
InvalidOauthCallbackState,
MissingOauthQueryParameter,
OauthAccountAlreadyConnected,
OauthStateExpired,
SignUpForbidden,
UnknownOauthProvider,
URLHelper,
UseNamedGuard,
} from '../../base';
import { AuthService, Public } from '../../core/auth';
import { Models } from '../../models';
import {
OAuthCallbackBodySchema,
OAuthPreflightBodySchema,
Public,
SessionIssuer,
} from '../../core/auth';
import { OAuthProviderName } from './config';
import { OAuthProviderFactory } from './factory';
import { OAuthAccount, Tokens } from './providers/def';
import { OAuthService } from './service';
@Controller('/api/oauth')
export class OAuthController {
private readonly logger = new Logger(OAuthController.name);
constructor(
private readonly auth: AuthService,
private readonly sessionIssuer: SessionIssuer,
private readonly oauth: OAuthService,
private readonly models: Models,
private readonly providerFactory: OAuthProviderFactory,
private readonly url: URLHelper,
private readonly config: Config
private readonly url: URLHelper
) {}
@Public()
@UseNamedGuard('version')
@Post('/preflight')
@HttpCode(HttpStatus.OK)
async preflight(
@Req() req: Request,
@Body('provider') unknownProviderName?: keyof typeof OAuthProviderName,
@Body('redirect_uri') redirectUri?: string,
@Body('client') client?: string,
@Body('client_nonce') clientNonce?: string
) {
if (!unknownProviderName) {
async preflight(@Req() req: Request, @Body() body?: unknown) {
const input = OAuthPreflightBodySchema.safeParse(body);
if (!input.success) {
const fields = new Set(input.error.issues.map(issue => issue.path[0]));
if (fields.has('client_nonce')) {
throw new MissingOauthQueryParameter({ name: 'client_nonce' });
}
throw new MissingOauthQueryParameter({ name: 'provider' });
}
if (!clientNonce) {
throw new MissingOauthQueryParameter({ name: 'client_nonce' });
}
const providerName = OAuthProviderName[unknownProviderName];
const {
provider: unknownProviderName,
redirect_uri: redirectUri,
client,
client_nonce: clientNonce,
} = input.data;
const providerName =
OAuthProviderName[unknownProviderName as keyof typeof OAuthProviderName];
const provider = this.providerFactory.get(providerName);
if (!provider) {
@@ -123,57 +118,38 @@ export class OAuthController {
async callback(
@Req() req: RawBodyRequest<Request>,
@Res() res: Response,
@Body('code') code?: string,
@Body('state') stateStr?: string,
@Body('client_nonce') clientNonce?: string
@Body() body?: unknown
) {
// TODO(@forehalo): refactor and remove deprecated code in 0.23
if (!code) {
throw new MissingOauthQueryParameter({ name: 'code' });
}
if (!stateStr) {
const input = OAuthCallbackBodySchema.safeParse(body);
if (!input.success) {
const fields = new Set(input.error.issues.map(issue => issue.path[0]));
if (fields.has('code')) {
throw new MissingOauthQueryParameter({ name: 'code' });
}
if (fields.has('state')) {
throw new MissingOauthQueryParameter({ name: 'state' });
}
throw new MissingOauthQueryParameter({ name: 'state' });
}
// NOTE(@forehalo): Apple sign in will directly post /callback, with `state` set at #L73
let rawState = null;
if (typeof stateStr === 'string' && stateStr.length > 36) {
try {
rawState = JSON.parse(stateStr);
stateStr = rawState.state;
} catch {
/* noop */
}
}
const { code, state: stateStr, client_nonce: clientNonce } = input.data;
if (typeof stateStr !== 'string' || !this.oauth.isValidState(stateStr)) {
throw new InvalidOauthCallbackState();
}
const verified = await this.oauth.verifyCallback({
code,
stateStr,
clientNonce,
rawBody: req.rawBody,
});
const state = await this.oauth.getOAuthState(stateStr);
if (!state) {
throw new OauthStateExpired();
}
if (!state.token) {
state.token = stateStr;
}
if (
state.provider === OAuthProviderName.Apple &&
rawState &&
state.client &&
state.client !== 'web'
) {
const clientUrl = new URL(`${state.client}://authentication`);
if (verified.type === 'handoff') {
const clientUrl = new URL(`${verified.state.client}://authentication`);
clientUrl.searchParams.set('method', 'oauth');
clientUrl.searchParams.set(
'payload',
JSON.stringify({
state: stateStr,
state: verified.stateToken,
code,
provider: rawState.provider,
provider: verified.provider,
})
);
clientUrl.searchParams.set('server', this.url.requestOrigin);
@@ -185,46 +161,8 @@ export class OAuthController {
);
}
if (!state.provider) {
throw new MissingOauthQueryParameter({ name: 'provider' });
}
const provider = this.providerFactory.get(state.provider);
if (!provider) {
throw new UnknownOauthProvider({ name: state.provider ?? 'unknown' });
}
if (
state.provider !== OAuthProviderName.Apple &&
(!clientNonce || !state.clientNonce || state.clientNonce !== clientNonce)
) {
throw new InvalidAuthState();
}
let tokens: Tokens;
try {
tokens = await provider.getToken(code, state);
} catch (err) {
let rayBodyString = '';
if (req.rawBody) {
// only log the first 4096 bytes of the raw body
rayBodyString = req.rawBody.subarray(0, 4096).toString('utf-8');
}
this.logger.warn(
`Error getting oauth token for ${state.provider}, callback code: ${code}, stateStr: ${stateStr}, rawBody: ${rayBodyString}, error: ${err}`
);
throw err;
}
const externAccount = await provider.getUser(tokens, state);
const user = await this.getOrCreateUserFromOauth(
state.provider,
externAccount,
tokens
);
await this.auth.setCookies(req, res, user.id, state.clientVersion);
const { identity, state } = verified;
const { exchangeCode } = await this.sessionIssuer.issue(req, res, identity);
if (
state.provider === OAuthProviderName.Apple &&
@@ -234,96 +172,9 @@ export class OAuthController {
}
res.send({
id: user.id,
id: identity.userId,
exchangeCode,
redirectUri: state.redirectUri,
});
}
private async getOrCreateUserFromOauth(
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
const connectedAccount = await this.models.user.getConnectedAccount(
provider,
externalAccount.id
);
if (connectedAccount) {
// already connected
await this.updateConnectedAccount(connectedAccount, tokens);
if (
!connectedAccount.user.emailVerifiedAt &&
// external email may change, check if it matches exists email
externalAccount.email.toLowerCase() ===
connectedAccount.user.email.toLowerCase()
) {
await this.auth.setEmailVerified(connectedAccount.userId);
}
return connectedAccount.user;
}
if (!this.config.auth.allowSignupForOauth) {
throw new SignUpForbidden();
}
const user = await this.models.user.fulfill(externalAccount.email, {
name: externalAccount.name,
avatarUrl: externalAccount.avatarUrl,
});
await this.models.user.createConnectedAccount({
userId: user.id,
provider,
providerAccountId: externalAccount.id,
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
expiresAt: tokens.expiresAt,
});
return user;
}
private async updateConnectedAccount(
connectedAccount: ConnectedAccount,
tokens: Tokens
) {
return await this.models.user.updateConnectedAccount(connectedAccount.id, {
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
expiresAt: tokens.expiresAt,
});
}
/**
* we currently don't support connect oauth account to existing user
* keep it incase we need it in the future
*/
// @ts-expect-error allow unused
private async _connectAccount(
user: { id: string },
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
const connectedAccount = await this.models.user.getConnectedAccount(
provider,
externalAccount.id
);
if (connectedAccount) {
if (connectedAccount.userId !== user.id) {
throw new OauthAccountAlreadyConnected();
}
} else {
await this.models.user.createConnectedAccount({
userId: user.id,
provider,
providerAccountId: externalAccount.id,
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
expiresAt: tokens.expiresAt,
});
}
}
}
@@ -3,7 +3,7 @@ import './config';
import { Module } from '@nestjs/common';
import { ServerConfigModule } from '../../core';
import { AuthModule } from '../../core/auth';
import { AUTH_OAUTH_PROVIDER_READER, AuthModule } from '../../core/auth';
import { UserModule } from '../../core/user';
import { OAuthController } from './controller';
import { OAuthProviderFactory } from './factory';
@@ -15,6 +15,7 @@ import { OAuthService } from './service';
imports: [AuthModule, UserModule, ServerConfigModule],
providers: [
OAuthProviderFactory,
{ provide: AUTH_OAUTH_PROVIDER_READER, useExisting: OAuthProviderFactory },
OAuthService,
OAuthResolver,
...OAuthProviders,
@@ -1,18 +1,57 @@
import { createHash, randomBytes, randomUUID } from 'node:crypto';
import { createHash, randomBytes } from 'node:crypto';
import { Injectable } from '@nestjs/common';
import { Injectable, Logger } from '@nestjs/common';
import { ConnectedAccount } from '@prisma/client';
import { SessionCache } from '../../base';
import {
Config,
InvalidAuthState,
InvalidOauthCallbackState,
MissingOauthQueryParameter,
OauthStateExpired,
SignUpForbidden,
UnknownOauthProvider,
} from '../../base';
import {
AuthChallengeStore,
AuthService,
OAuthStateEnvelopeSchema,
type VerifiedIdentity,
} from '../../core/auth';
import { Models } from '../../models';
import { OAuthProviderName } from './config';
import { OAuthProviderFactory } from './factory';
import { OAuthAccount, Tokens } from './providers/def';
import { OAuthPkceChallenge, OAuthState } from './types';
const OAUTH_STATE_KEY = 'OAUTH_STATE';
type HandoffResult = {
type: 'handoff';
code: string;
provider: unknown;
state: OAuthState;
stateToken: string;
};
type IdentityResult = {
type: 'identity';
identity: VerifiedIdentity;
state: OAuthState;
};
type VerifyCallbackResult = HandoffResult | IdentityResult;
const OAUTH_STATE_TTL_MS = 3600 * 3 * 1000;
@Injectable()
export class OAuthService {
private readonly logger = new Logger(OAuthService.name);
constructor(
private readonly providerFactory: OAuthProviderFactory,
private readonly cache: SessionCache
private readonly challenges: AuthChallengeStore,
private readonly auth: AuthService,
private readonly models: Models,
private readonly config: Config
) {}
isValidState(stateStr: string) {
@@ -20,23 +59,191 @@ export class OAuthService {
}
async saveOAuthState(state: OAuthState) {
const token = randomUUID();
const payload: OAuthState = { ...state, token };
await this.cache.set(`${OAUTH_STATE_KEY}:${token}`, payload, {
ttl: 3600 * 3 * 1000 /* 3 hours */,
});
return token;
return this.challenges.create<OAuthState>(
'oauth_state',
token => ({ ...state, token }),
OAUTH_STATE_TTL_MS
);
}
async getOAuthState(token: string) {
return this.cache.get<OAuthState>(`${OAUTH_STATE_KEY}:${token}`);
return this.challenges.get<OAuthState>('oauth_state', token);
}
availableOAuthProviders() {
return this.providerFactory.providers;
}
async verifyCallback(input: {
code: string;
stateStr: string;
clientNonce?: string;
rawBody?: Buffer;
}): Promise<VerifyCallbackResult> {
let stateStr = input.stateStr;
let rawState: { state: string; provider?: string } | null = null;
if (typeof stateStr === 'string' && stateStr.length > 36) {
try {
const parsed = OAuthStateEnvelopeSchema.safeParse(JSON.parse(stateStr));
if (parsed.success) {
rawState = parsed.data;
stateStr = rawState.state;
}
} catch {} // noop
}
if (typeof stateStr !== 'string' || !this.isValidState(stateStr)) {
throw new InvalidOauthCallbackState();
}
const state = await this.getOAuthState(stateStr);
if (!state) throw new OauthStateExpired();
if (!state.token) state.token = stateStr;
if (
state.provider === OAuthProviderName.Apple &&
rawState &&
state.client &&
state.client !== 'web'
) {
return {
type: 'handoff',
code: input.code,
provider: rawState.provider,
state,
stateToken: stateStr,
};
}
if (!state.provider) {
throw new MissingOauthQueryParameter({ name: 'provider' });
}
const provider = this.providerFactory.get(state.provider);
if (!provider) {
throw new UnknownOauthProvider({ name: state.provider ?? 'unknown' });
}
if (
state.provider !== OAuthProviderName.Apple &&
(!input.clientNonce ||
!state.clientNonce ||
state.clientNonce !== input.clientNonce)
) {
throw new InvalidAuthState();
}
return {
type: 'identity',
identity: await this.verifyCallbackIdentity(
input.code,
state,
stateStr,
input.rawBody
),
state,
};
}
async verifyCallbackIdentity(
code: string,
state: OAuthState,
stateStr: string,
rawBody?: Buffer
): Promise<VerifiedIdentity> {
if (!state.provider) {
throw new UnknownOauthProvider({ name: 'unknown' });
}
const provider = this.providerFactory.get(state.provider);
if (!provider) {
throw new UnknownOauthProvider({ name: state.provider });
}
let tokens: Tokens;
try {
tokens = await provider.getToken(code, state);
} catch (err) {
const rawBodyString = rawBody
? rawBody.subarray(0, 4096).toString('utf-8')
: '';
this.logger.warn(
`Error getting oauth token for ${state.provider}, callback code: ${code}, stateStr: ${stateStr}, rawBody: ${rawBodyString}, error: ${err}`
);
throw err;
}
const externalAccount = await provider.getUser(tokens, state);
const user = await this.getOrCreateUserFromOauth(
state.provider,
externalAccount,
tokens
);
return {
userId: user.id,
method: 'oauth',
clientVersion: state.clientVersion,
};
}
private async getOrCreateUserFromOauth(
provider: OAuthProviderName,
externalAccount: OAuthAccount,
tokens: Tokens
) {
const connectedAccount = await this.models.user.getConnectedAccount(
provider,
externalAccount.id
);
if (connectedAccount) {
await this.updateConnectedAccount(connectedAccount, tokens);
if (
!connectedAccount.user.emailVerifiedAt &&
externalAccount.email.toLowerCase() ===
connectedAccount.user.email.toLowerCase()
) {
await this.auth.setEmailVerified(connectedAccount.userId);
}
return connectedAccount.user;
}
if (!this.config.auth.allowSignupForOauth) {
throw new SignUpForbidden();
}
const user = await this.models.user.fulfill(externalAccount.email, {
name: externalAccount.name,
avatarUrl: externalAccount.avatarUrl,
});
await this.models.user.createConnectedAccount({
userId: user.id,
provider,
providerAccountId: externalAccount.id,
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
expiresAt: tokens.expiresAt,
});
return user;
}
private async updateConnectedAccount(
connectedAccount: ConnectedAccount,
tokens: Tokens
) {
return await this.models.user.updateConnectedAccount(connectedAccount.id, {
accessToken: tokens.accessToken,
refreshToken: tokens.refreshToken,
expiresAt: tokens.expiresAt,
});
}
createPkcePair(): OAuthPkceChallenge {
const codeVerifier = this.randomBase64Url(96);
const hash = createHash('sha256').update(codeVerifier).digest();
+1 -1
View File
@@ -2679,7 +2679,7 @@ type UserType {
"""Get user settings"""
settings: UserSettingsType!
subscriptions: [SubscriptionType!]!
token: tokenType! @deprecated(reason: "use [/api/auth/sign-in?native=true] instead")
token: tokenType! @deprecated(reason: "use native session exchange instead")
}
type ValidationErrorDataType {
+1 -1
View File
@@ -1887,7 +1887,7 @@ export const getCurrentUserQuery = {
}
}
}`,
deprecations: ["'token' is deprecated: use [/api/auth/sign-in?native=true] instead"],
deprecations: ["'token' is deprecated: use native session exchange instead"],
};
export const getDocCreatedByUpdatedByListQuery = {
+1 -1
View File
@@ -3409,7 +3409,7 @@ export interface UserType {
/** Get user settings */
settings: UserSettingsType;
subscriptions: Array<SubscriptionType>;
/** @deprecated use [/api/auth/sign-in?native=true] instead */
/** @deprecated use native session exchange instead */
token: TokenType;
}
@@ -1,9 +1,6 @@
package app.affine.pro
import android.webkit.WebView
import app.affine.pro.service.CookieStore
import app.affine.pro.utils.dataStore
import app.affine.pro.utils.get
import app.affine.pro.utils.getCurrentServerBaseUrl
import app.affine.pro.utils.logger.FileTree
import com.getcapacitor.Bridge
@@ -11,7 +8,6 @@ import com.getcapacitor.WebViewListener
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.MainScope
import kotlinx.coroutines.launch
import okhttp3.Cookie
import okhttp3.HttpUrl.Companion.toHttpUrl
import timber.log.Timber
@@ -23,30 +19,11 @@ object AuthInitializer {
bridge.removeWebViewListener(this)
MainScope().launch(Dispatchers.IO) {
try {
val server = bridge.getCurrentServerBaseUrl().toHttpUrl()
val sessionCookieStr = AFFiNEApp.context().dataStore
.get(server.host + CookieStore.AFFINE_SESSION)
val userIdCookieStr = AFFiNEApp.context().dataStore
.get(server.host + CookieStore.AFFINE_USER_ID)
val csrfCookieStr = AFFiNEApp.context().dataStore
.get(server.host + CookieStore.AFFINE_CSRF_TOKEN)
if (sessionCookieStr.isEmpty() || userIdCookieStr.isEmpty() || csrfCookieStr.isEmpty()) {
Timber.i("[init] user has not signed in yet.")
return@launch
}
Timber.i("[init] user already signed in.")
val cookies = listOf(
Cookie.parse(server, sessionCookieStr)
?: error("Parse session cookie fail:[ cookie = $sessionCookieStr ]"),
Cookie.parse(server, userIdCookieStr)
?: error("Parse user id cookie fail:[ cookie = $userIdCookieStr ]"),
Cookie.parse(server, csrfCookieStr)
?: error("Parse csrf token cookie fail:[ cookie = $csrfCookieStr ]"),
FileTree.get()?.checkAndUploadOldLogs(
bridge.getCurrentServerBaseUrl().toHttpUrl()
)
CookieStore.saveCookies(server.host, cookies)
FileTree.get()?.checkAndUploadOldLogs(server)
} catch (e: Exception) {
Timber.w(e, "[init] load persistent cookies fail.")
Timber.w(e, "[init] auth initializer fail.")
}
}
}
@@ -1,8 +1,16 @@
package app.affine.pro.plugin
import android.annotation.SuppressLint
import android.security.keystore.KeyGenParameterSpec
import android.security.keystore.KeyProperties
import android.util.Base64
import app.affine.pro.AFFiNEApp
import app.affine.pro.service.AuthHttp
import app.affine.pro.service.CookieStore
import app.affine.pro.service.OkHttp
import app.affine.pro.utils.dataStore
import app.affine.pro.utils.del
import app.affine.pro.utils.get
import app.affine.pro.utils.set
import com.getcapacitor.JSObject
import com.getcapacitor.Plugin
import com.getcapacitor.PluginCall
@@ -10,6 +18,7 @@ import com.getcapacitor.PluginMethod
import com.getcapacitor.annotation.CapacitorPlugin
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import okhttp3.HttpUrl
import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.MediaType.Companion.toMediaTypeOrNull
import okhttp3.Request
@@ -17,10 +26,90 @@ import okhttp3.RequestBody.Companion.toRequestBody
import okhttp3.coroutines.executeAsync
import org.json.JSONObject
import timber.log.Timber
import java.security.KeyStore
import javax.crypto.Cipher
import javax.crypto.KeyGenerator
import javax.crypto.SecretKey
import javax.crypto.spec.GCMParameterSpec
@OptIn(ExperimentalCoroutinesApi::class)
@CapacitorPlugin(name = "Auth")
class AuthPlugin : Plugin() {
private fun canonicalEndpoint(endpoint: String): String = try {
val url = endpoint.toHttpUrl()
val port = if (url.port == HttpUrl.defaultPort(url.scheme)) "" else ":${url.port}"
"${url.scheme}://${url.host}$port"
} catch (_: Exception) {
endpoint
}
private fun tokenKey(endpoint: String) = "auth-token:${canonicalEndpoint(endpoint)}"
private fun legacyTokenKey(endpoint: String) = "auth-token:$endpoint"
private val tokenCipher = TokenCipher()
@PluginMethod
fun readEndpointToken(call: PluginCall) {
launch(Dispatchers.IO) {
try {
val endpoint = call.getStringEnsure("endpoint")
val key = tokenKey(endpoint)
val legacyKey = legacyTokenKey(endpoint)
val store = AFFiNEApp.context().dataStore
val storedKey = key.takeIf { store.get(it).isNotEmpty() }
?: legacyKey.takeIf { it != key && store.get(it).isNotEmpty() }
val storedToken = storedKey?.let { store.get(it) }?.takeIf { it.isNotEmpty() }
val token = storedToken?.let {
tokenCipher.decrypt(it) ?: tokenCipher.legacyPlaintext(it)
}
if (
storedToken != null &&
token != null &&
(storedKey != key || !tokenCipher.isEncrypted(storedToken))
) {
store.set(key, tokenCipher.encrypt(token))
storedKey?.let {
if (it != key) {
store.del(it)
}
}
}
call.resolve(JSObject().put("token", token))
} catch (e: Exception) {
call.reject("Failed to read endpoint token.", null, e)
}
}
}
@PluginMethod
fun writeEndpointToken(call: PluginCall) {
launch(Dispatchers.IO) {
try {
val endpoint = call.getStringEnsure("endpoint")
val token = call.getStringEnsure("token")
AFFiNEApp.context().dataStore.set(
tokenKey(endpoint),
tokenCipher.encrypt(token)
)
call.resolve(JSObject().put("ok", true))
} catch (e: Exception) {
call.reject("Failed to write endpoint token.", null, e)
}
}
}
@PluginMethod
fun deleteEndpointToken(call: PluginCall) {
launch(Dispatchers.IO) {
try {
val endpoint = call.getStringEnsure("endpoint")
AFFiNEApp.context().dataStore.del(tokenKey(endpoint))
AFFiNEApp.context().dataStore.del(legacyTokenKey(endpoint))
call.resolve(JSObject().put("ok", true))
} catch (e: Exception) {
call.reject("Failed to delete endpoint token.", null, e)
}
}
}
@PluginMethod
fun signInMagicLink(call: PluginCall) {
@@ -32,6 +121,11 @@ class AuthPlugin : Plugin() {
processSignIn(call, SignInMethod.Oauth)
}
@PluginMethod
fun signInOpenApp(call: PluginCall) {
processSignIn(call, SignInMethod.OpenApp)
}
@SuppressLint("BuildListAdds")
@PluginMethod
fun signInPassword(call: PluginCall) {
@@ -43,21 +137,22 @@ class AuthPlugin : Plugin() {
launch(Dispatchers.IO) {
try {
val endpoint = call.getStringEnsure("endpoint")
val csrfToken = CookieStore.getCookie(endpoint.toHttpUrl(), CookieStore.AFFINE_CSRF_TOKEN)
val token = call.getString("token")
val request = Request.Builder()
.url("$endpoint/api/auth/sign-out")
.post("".toRequestBody("application/json".toMediaTypeOrNull()))
.apply {
if (csrfToken != null) {
addHeader("x-affine-csrf-token", csrfToken)
if (token != null) {
addHeader("Authorization", "Bearer $token")
}
}
.build()
OkHttp.client.newCall(request).executeAsync().use { response ->
AuthHttp.client.newCall(request).executeAsync().use { response ->
if (response.code >= 400) {
call.reject(response.body.string())
return@launch
}
CookieStore.clearAuthCookies(endpoint.toHttpUrl().host)
Timber.i("Sign out success.")
call.resolve(JSObject().put("ok", true))
}
@@ -69,7 +164,7 @@ class AuthPlugin : Plugin() {
}
private enum class SignInMethod {
Password, Oauth, MagicLink
Password, Oauth, MagicLink, OpenApp
}
private fun processSignIn(call: PluginCall, method: SignInMethod) {
@@ -92,6 +187,7 @@ class AuthPlugin : Plugin() {
val requestBuilder = Request.Builder()
.url("$endpoint/api/auth/sign-in")
.addHeader("x-affine-client-kind", "native")
.post(body)
if (verifyToken != null) {
requestBuilder.addHeader("x-captcha-token", verifyToken)
@@ -117,6 +213,7 @@ class AuthPlugin : Plugin() {
Request.Builder()
.url("$endpoint/api/oauth/callback")
.addHeader("x-affine-client-kind", "native")
.post(body)
.build()
}
@@ -136,19 +233,41 @@ class AuthPlugin : Plugin() {
Request.Builder()
.url("$endpoint/api/auth/magic-link")
.addHeader("x-affine-client-kind", "native")
.post(body)
.build()
}
SignInMethod.OpenApp -> {
val code = call.getStringEnsure("code")
val body = JSONObject()
.apply { put("code", code) }
.toString()
.toRequestBody("application/json".toMediaTypeOrNull())
Request.Builder()
.url("$endpoint/api/auth/open-app/sign-in")
.addHeader("x-affine-client-kind", "native")
.post(body)
.build()
}
}
OkHttp.client.newCall(request).executeAsync().use { response ->
AuthHttp.client.newCall(request).executeAsync().use { response ->
if (response.code >= 400) {
call.reject(response.body.string())
return@launch
}
CookieStore.getCookie(endpoint.toHttpUrl(), CookieStore.AFFINE_SESSION)?.let {
val exchangeCode = JSONObject(response.body.string()).optString("exchangeCode").takeIf { it.isNotEmpty() }
if (exchangeCode == null) {
Timber.w("$method sign in fail, exchange code not found.")
call.reject("$method sign in fail, exchange code not found")
return@launch
}
val token = exchangeSession(endpoint, exchangeCode)
token.takeIf { it.isNotEmpty() }?.let {
CookieStore.clearAuthCookies(endpoint.toHttpUrl().host)
Timber.i("$method sign in success.")
Timber.d("Update session [$it]")
call.resolve(JSObject().put("token", it))
} ?: run {
Timber.w("$method sign in fail, token not found.")
@@ -161,4 +280,88 @@ class AuthPlugin : Plugin() {
}
}
}
private suspend fun exchangeSession(endpoint: String, code: String): String {
val body = JSONObject()
.apply { put("code", code) }
.toString()
.toRequestBody("application/json".toMediaTypeOrNull())
val request = Request.Builder()
.url("$endpoint/api/auth/native/exchange")
.addHeader("x-affine-client-kind", "native")
.post(body)
.build()
AuthHttp.client.newCall(request).executeAsync().use { response ->
if (response.code >= 400) {
throw Exception(response.body.string())
}
return JSONObject(response.body.string()).optString("token")
}
}
}
private class TokenCipher {
private val alias = "affine-native-auth-token"
private val transformation = "AES/GCM/NoPadding"
fun encrypt(plaintext: String): String {
val cipher = Cipher.getInstance(transformation)
cipher.init(Cipher.ENCRYPT_MODE, secretKey())
val ciphertext = cipher.doFinal(plaintext.toByteArray(Charsets.UTF_8))
return listOf(
"v1",
Base64.encodeToString(cipher.iv, Base64.NO_WRAP),
Base64.encodeToString(ciphertext, Base64.NO_WRAP),
).joinToString(":")
}
fun decrypt(encoded: String): String? {
val parts = encoded.split(":")
if (parts.size != 3 || parts[0] != "v1") {
return null
}
return try {
val iv = Base64.decode(parts[1], Base64.NO_WRAP)
val ciphertext = Base64.decode(parts[2], Base64.NO_WRAP)
val cipher = Cipher.getInstance(transformation)
cipher.init(
Cipher.DECRYPT_MODE,
secretKey(),
GCMParameterSpec(128, iv)
)
String(cipher.doFinal(ciphertext), Charsets.UTF_8)
} catch (e: Exception) {
Timber.w(e, "Failed to decrypt auth token.")
null
}
}
fun isEncrypted(value: String) = value.startsWith("v1:")
fun legacyPlaintext(value: String) =
value.takeIf { !isEncrypted(it) && it.isNotBlank() }
private fun secretKey(): SecretKey {
val keyStore = KeyStore.getInstance("AndroidKeyStore").apply { load(null) }
(keyStore.getEntry(alias, null) as? KeyStore.SecretKeyEntry)?.let {
return it.secretKey
}
val keyGenerator = KeyGenerator.getInstance(
KeyProperties.KEY_ALGORITHM_AES,
"AndroidKeyStore"
)
val spec = KeyGenParameterSpec.Builder(
alias,
KeyProperties.PURPOSE_ENCRYPT or KeyProperties.PURPOSE_DECRYPT
)
.setBlockModes(KeyProperties.BLOCK_MODE_GCM)
.setEncryptionPaddings(KeyProperties.ENCRYPTION_PADDING_NONE)
.setRandomizedEncryptionRequired(true)
.build()
keyGenerator.init(spec)
return keyGenerator.generateKey()
}
}
@@ -3,6 +3,7 @@ package app.affine.pro.service
import app.affine.pro.AFFiNEApp
import app.affine.pro.CapacitorConfig
import app.affine.pro.utils.dataStore
import app.affine.pro.utils.del
import app.affine.pro.utils.set
import com.google.firebase.crashlytics.ktx.crashlytics
import com.google.firebase.ktx.Firebase
@@ -50,6 +51,20 @@ object OkHttp {
}
object AuthHttp {
val client = OkHttpClient.Builder()
.cookieJar(CookieJar.NO_COOKIES)
.addInterceptor {
it.proceed(
it.request()
.newBuilder()
.addHeader("x-affine-version", CapacitorConfig.getAffineVersion())
.build()
)
}
.build()
}
object CookieStore {
const val AFFINE_SESSION = "affine_session"
@@ -61,9 +76,6 @@ object CookieStore {
fun saveCookies(host: String, cookies: List<Cookie>) {
_cookies[host] = cookies
MainScope().launch(Dispatchers.IO) {
cookies.find { it.name == AFFINE_SESSION }?.let {
AFFiNEApp.context().dataStore.set(host + AFFINE_SESSION, it.toString())
}
cookies.find { it.name == AFFINE_USER_ID }?.let {
Timber.d("Update user id [${it.value}]")
AFFiNEApp.context().dataStore.set(host + AFFINE_USER_ID, it.toString())
@@ -77,6 +89,18 @@ object CookieStore {
fun getCookies(host: String) = _cookies[host] ?: emptyList()
fun clearAuthCookies(host: String) {
val cookies = _cookies[host] ?: emptyList()
_cookies[host] = cookies.filter {
it.name != AFFINE_SESSION && it.name != AFFINE_USER_ID && it.name != AFFINE_CSRF_TOKEN
}
MainScope().launch(Dispatchers.IO) {
AFFiNEApp.context().dataStore.del(host + AFFINE_USER_ID)
AFFiNEApp.context().dataStore.del(host + AFFINE_CSRF_TOKEN)
Firebase.crashlytics.setUserId("")
}
}
fun getCookie(url: HttpUrl, name: String) = url.host
.let { _cookies[it] }
?.find { cookie -> cookie.name == name }
@@ -17,6 +17,12 @@ suspend fun DataStore<Preferences>.set(key: String, value: String) {
}
}
suspend fun DataStore<Preferences>.del(key: String) {
edit {
it.remove(stringPreferencesKey(key))
}
}
suspend fun DataStore<Preferences>.get(key: String) = data.map {
it[stringPreferencesKey(key)] ?: ""
}.first()
}.first()
+32 -3
View File
@@ -57,7 +57,11 @@ import { Auth } from './plugins/auth';
import { HashCash } from './plugins/hashcash';
import { NbStoreNativeDBApis } from './plugins/nbstore';
import { Preview } from './plugins/preview';
import { writeEndpointToken } from './proxy';
import {
deleteEndpointToken,
readEndpointToken,
writeEndpointToken,
} from './proxy';
const storeManagerClient = createStoreManagerClient();
setTelemetryTransport(storeManagerClient.telemetry);
@@ -206,10 +210,20 @@ framework.scope(ServerScope).override(AuthProvider, resolver => {
});
await writeEndpointToken(endpoint, token);
},
async signOut() {
await Auth.signOut({
async signInOpenAppSignInCode(code) {
const { token } = await Auth.signInOpenApp({
endpoint,
code,
});
await writeEndpointToken(endpoint, token);
},
async signOut() {
const token = await readEndpointToken(endpoint);
try {
await Auth.signOut({ endpoint, token });
} finally {
await deleteEndpointToken(endpoint);
}
},
};
});
@@ -442,5 +456,20 @@ function createStoreManagerClient() {
},
[nativeDBApiChannelClient]
);
const { port1: authTokenChannelServer, port2: authTokenChannelClient } =
new MessageChannel();
authTokenChannelServer.addEventListener('message', event => {
const { id, endpoint } = event.data as { id?: string; endpoint?: string };
if (!id || !endpoint) return;
readEndpointToken(endpoint)
.then(token => authTokenChannelServer.postMessage({ id, token }))
.catch(() => authTokenChannelServer.postMessage({ id, token: null }));
});
authTokenChannelServer.start();
worker.postMessage(
{ type: 'native-auth-token-channel', port: authTokenChannelClient },
[authTokenChannelClient]
);
return new StoreManagerClient(new OpClient(worker));
}
@@ -18,19 +18,28 @@ import {
import { type MessageCommunicapable, OpConsumer } from '@toeverything/infra/op';
import { AsyncCall } from 'async-call-rpc';
import { readEndpointToken } from './proxy';
let authTokenPort: MessagePort | undefined;
const pendingTokenRequests = new Map<string, (token: string | null) => void>();
configureSocketAuthMethod((endpoint, cb) => {
readEndpointToken(endpoint)
.then(token => {
cb({ token });
})
.catch(e => {
console.error(e);
});
.then(token => cb(token ? { token, tokenType: 'jwt' } : {}))
.catch(() => cb({}));
});
globalThis.addEventListener('message', e => {
if (e.data.type === 'native-auth-token-channel') {
authTokenPort = e.ports[0] as MessagePort;
authTokenPort.addEventListener('message', e => {
const { id, token } = e.data as { id?: string; token?: string | null };
if (!id) return;
pendingTokenRequests.get(id)?.(token ?? null);
pendingTokenRequests.delete(id);
});
authTokenPort.start();
return;
}
if (e.data.type === 'native-db-api-channel') {
const port = e.ports[0] as MessagePort;
const rpc = AsyncCall<NativeDBApis>(
@@ -57,6 +66,25 @@ globalThis.addEventListener('message', e => {
}
});
function readEndpointToken(endpoint: string) {
if (!authTokenPort) {
return Promise.resolve(null);
}
const id = `${Date.now()}:${Math.random()}`;
return new Promise<string | null>(resolve => {
const timeout = setTimeout(() => {
pendingTokenRequests.delete(id);
resolve(null);
}, 5000);
pendingTokenRequests.set(id, token => {
clearTimeout(timeout);
resolve(token);
});
authTokenPort?.postMessage({ id, endpoint });
});
}
const consumer = new OpConsumer<WorkerManagerOps>(
globalThis as MessageCommunicapable
);
@@ -18,5 +18,17 @@ export interface AuthPlugin {
verifyToken?: string;
challenge?: string;
}): Promise<{ token: string }>;
signOut(options: { endpoint: string }): Promise<void>;
signInOpenApp(options: {
endpoint: string;
code: string;
}): Promise<{ token: string }>;
signOut(options: { endpoint: string; token?: string | null }): Promise<void>;
readEndpointToken(options: {
endpoint: string;
}): Promise<{ token?: string | null }>;
writeEndpointToken(options: {
endpoint: string;
token: string;
}): Promise<void>;
deleteEndpointToken(options: { endpoint: string }): Promise<void>;
}
+53 -24
View File
@@ -1,4 +1,19 @@
import { openDB } from 'idb';
import { Auth } from './plugins/auth';
function authEndpointForUrl(url: string | URL) {
try {
const parsed = new URL(url, globalThis.location.origin);
return parsed.protocol === 'http:' || parsed.protocol === 'https:'
? parsed.origin
: null;
} catch {
return null;
}
}
function canonicalEndpoint(endpoint: string) {
return authEndpointForUrl(endpoint) ?? endpoint;
}
/**
* the below code includes the custom fetch and xmlhttprequest implementation for ios webview.
@@ -8,9 +23,11 @@ const rawFetch = globalThis.fetch;
globalThis.fetch = async (input: RequestInfo | URL, init?: RequestInit) => {
const request = new Request(input, init);
const origin = new URL(request.url, globalThis.location.origin).origin;
const origin = authEndpointForUrl(request.url);
const token = await readEndpointToken(origin);
const token = origin
? await readEndpointToken(origin).catch(() => null)
: null;
if (token) {
request.headers.set('Authorization', `Bearer ${token}`);
}
@@ -19,11 +36,30 @@ globalThis.fetch = async (input: RequestInfo | URL, init?: RequestInit) => {
};
const rawXMLHttpRequest = globalThis.XMLHttpRequest;
const xhrRequestUrls = new WeakMap<XMLHttpRequest, string>();
globalThis.XMLHttpRequest = class extends rawXMLHttpRequest {
override send(body?: Document | XMLHttpRequestBodyInit | null): void {
const origin = new URL(this.responseURL, globalThis.location.origin).origin;
override open(
method: string,
url: string | URL,
async: boolean = true,
username?: string | null,
password?: string | null
): void {
xhrRequestUrls.set(this, url.toString());
return super.open(
method,
url,
async,
username ?? undefined,
password ?? undefined
);
}
readEndpointToken(origin).then(
override send(body?: Document | XMLHttpRequestBodyInit | null): void {
const requestUrl = xhrRequestUrls.get(this);
const origin = authEndpointForUrl(requestUrl ?? globalThis.location.href);
(origin ? readEndpointToken(origin) : Promise.resolve(null)).then(
token => {
if (token) {
this.setRequestHeader('Authorization', `Bearer ${token}`);
@@ -31,7 +67,7 @@ globalThis.XMLHttpRequest = class extends rawXMLHttpRequest {
return super.send(body);
},
() => {
throw new Error('Failed to read token');
return super.send(body);
}
);
}
@@ -40,26 +76,19 @@ globalThis.XMLHttpRequest = class extends rawXMLHttpRequest {
export async function readEndpointToken(
endpoint: string
): Promise<string | null> {
const idb = await openDB('affine-token', 1, {
upgrade(db) {
if (!db.objectStoreNames.contains('tokens')) {
db.createObjectStore('tokens', { keyPath: 'endpoint' });
}
},
const { token } = await Auth.readEndpointToken({
endpoint: canonicalEndpoint(endpoint),
});
const token = await idb.get('tokens', endpoint);
return token ? token.token : null;
return token ?? null;
}
export async function writeEndpointToken(endpoint: string, token: string) {
const db = await openDB('affine-token', 1, {
upgrade(db) {
if (!db.objectStoreNames.contains('tokens')) {
db.createObjectStore('tokens', { keyPath: 'endpoint' });
}
},
await Auth.writeEndpointToken({
endpoint: canonicalEndpoint(endpoint),
token,
});
await db.put('tokens', { endpoint, token });
}
export async function deleteEndpointToken(endpoint: string) {
await Auth.deleteEndpointToken({ endpoint: canonicalEndpoint(endpoint) });
}
@@ -2,7 +2,12 @@ import { configureElectronStateStorageImpls } from '@affine/core/desktop/storage
import { configureCommonModules } from '@affine/core/modules';
import { configureAppTabsHeaderModule } from '@affine/core/modules/app-tabs-header';
import { configureDesktopBackupModule } from '@affine/core/modules/backup';
import { ValidatorProvider } from '@affine/core/modules/cloud';
import {
AuthProvider,
ServerScope,
ServerService,
ValidatorProvider,
} from '@affine/core/modules/cloud';
import {
configureDesktopApiModule,
DesktopApiService,
@@ -63,6 +68,39 @@ export function setupModules() {
},
};
});
framework.scope(ServerScope).override(AuthProvider, p => {
const apis = p.get(DesktopApiService).api;
const serverService = p.get(ServerService);
const endpoint = serverService.server.baseUrl;
return {
async signInMagicLink(email, token, clientNonce) {
await apis.handler.auth.signInMagicLink(
endpoint,
email,
token,
clientNonce
);
},
async signInOauth(code, state, _provider, clientNonce) {
return await apis.handler.auth.signInOauth(
endpoint,
code,
state,
clientNonce
);
},
async signInPassword(credential) {
await apis.handler.auth.signInPassword(endpoint, credential);
},
async signInOpenAppSignInCode(code) {
await apis.handler.auth.signInOpenAppSignInCode(endpoint, code);
},
async signOut() {
await apis.handler.auth.signOut(endpoint);
},
};
});
const frameworkProvider = framework.provider();
@@ -2,7 +2,10 @@ import '@affine/core/bootstrap/electron';
import { apis } from '@affine/electron-api';
import { broadcastChannelStorages } from '@affine/nbstore/broadcast-channel';
import { cloudStorages } from '@affine/nbstore/cloud';
import {
cloudStorages,
configureSocketAuthMethod,
} from '@affine/nbstore/cloud';
import { bindNativeDBApis, sqliteStorages } from '@affine/nbstore/sqlite';
import {
bindNativeDBV1Apis,
@@ -18,6 +21,15 @@ import { OpConsumer } from '@toeverything/infra/op';
bindNativeDBApis(apis!.nbstore);
// oxlint-disable-next-line no-non-null-assertion
bindNativeDBV1Apis(apis!.db);
configureSocketAuthMethod((endpoint, cb) => {
// oxlint-disable-next-line no-non-null-assertion
apis!.auth
.readEndpointToken(endpoint)
.then(({ token }: { token?: string | null }) => {
cb(token ? { token, tokenType: 'jwt' } : {});
})
.catch(() => cb({}));
});
const storeManager = new StoreManagerConsumer([
...sqliteStorages,
@@ -0,0 +1,177 @@
import { net, session } from 'electron';
import { logger } from '../logger';
import type { NamespaceHandlers } from '../type';
import {
deleteNativeAuthToken,
getNativeAuthToken,
setNativeAuthToken,
} from './native-token';
interface SignInResponse {
exchangeCode?: string;
redirectUri?: string;
}
interface ExchangeResponse {
token?: string;
}
const authCookieNames = [
'affine_session',
'affine_user_id',
'affine_csrf_token',
];
function authUrl(endpoint: string, path: string) {
return new URL(path, endpoint).toString();
}
async function readJson<T>(response: Response): Promise<T> {
const text = await response.text();
if (!response.ok) {
throw new Error(text || response.statusText);
}
return text ? JSON.parse(text) : ({} as T);
}
async function fetchAuth(endpoint: string, path: string, body?: unknown) {
return await net.fetch(authUrl(endpoint, path), {
method: 'POST',
headers: {
'content-type': 'application/json',
'x-affine-client-kind': 'native',
'x-affine-version': BUILD_CONFIG.appVersion,
},
body: body === undefined ? undefined : JSON.stringify(body),
});
}
async function clearAuthCookies(endpoint: string) {
await Promise.all(
authCookieNames.map(name =>
session.defaultSession.cookies
.remove(endpoint, name)
.catch(error =>
logger.debug(
'failed to clear native auth cookie',
endpoint,
name,
error
)
)
)
);
}
async function exchangeSession(endpoint: string, response: SignInResponse) {
if (!response.exchangeCode) {
throw new Error('Missing native auth exchange code.');
}
const exchangeResponse = await fetchAuth(
endpoint,
'/api/auth/native/exchange',
{ code: response.exchangeCode }
);
const body = await readJson<ExchangeResponse>(exchangeResponse);
if (!body.token) {
throw new Error('Missing native auth token.');
}
setNativeAuthToken(endpoint, body.token);
await clearAuthCookies(endpoint);
}
export const authHandlers = {
signInMagicLink: async (
_,
endpoint: string,
email: string,
token: string,
clientNonce?: string
) => {
const response = await fetchAuth(endpoint, '/api/auth/magic-link', {
email,
token,
client_nonce: clientNonce,
});
await exchangeSession(endpoint, await readJson(response));
},
signInOauth: async (
_,
endpoint: string,
code: string,
state: string,
clientNonce?: string
) => {
const response = await fetchAuth(endpoint, '/api/oauth/callback', {
code,
state,
client_nonce: clientNonce,
});
const body = await readJson<SignInResponse>(response);
await exchangeSession(endpoint, body);
return { redirectUri: body.redirectUri };
},
signInPassword: async (
_,
endpoint: string,
credential: {
email: string;
password: string;
verifyToken?: string;
challenge?: string;
}
) => {
const response = await net.fetch(authUrl(endpoint, '/api/auth/sign-in'), {
method: 'POST',
headers: {
'content-type': 'application/json',
'x-affine-client-kind': 'native',
'x-affine-version': BUILD_CONFIG.appVersion,
...(credential.verifyToken
? { 'x-captcha-token': credential.verifyToken }
: {}),
...(credential.challenge
? { 'x-captcha-challenge': credential.challenge }
: {}),
},
body: JSON.stringify({
email: credential.email,
password: credential.password,
}),
});
await exchangeSession(endpoint, await readJson(response));
},
signInOpenAppSignInCode: async (_e, endpoint: string, code: string) => {
const response = await fetchAuth(endpoint, '/api/auth/open-app/sign-in', {
code,
});
await exchangeSession(endpoint, await readJson(response));
},
signOut: async (_e, endpoint: string) => {
const token = getNativeAuthToken(endpoint);
if (token) {
await net.fetch(authUrl(endpoint, '/api/auth/sign-out'), {
method: 'POST',
headers: {
Authorization: `Bearer ${token}`,
'x-affine-version': BUILD_CONFIG.appVersion,
},
});
}
deleteNativeAuthToken(endpoint);
await clearAuthCookies(endpoint);
},
readEndpointToken: async (_e, endpoint: string) => {
return { token: getNativeAuthToken(endpoint) };
},
} satisfies NamespaceHandlers;
@@ -0,0 +1,83 @@
import fs from 'node:fs';
import path from 'node:path';
import { app, safeStorage } from 'electron';
import { logger } from '../logger';
const FILEPATH = path.join(app.getPath('userData'), 'native-auth-tokens.json');
type TokenRecord = {
token: string;
};
function normalizeEndpoint(endpoint: string) {
return new URL(endpoint).origin;
}
function readStore(): Record<string, string> {
if (!fs.existsSync(FILEPATH)) return {};
try {
return JSON.parse(fs.readFileSync(FILEPATH, 'utf-8'));
} catch (error) {
logger.error('failed to read native auth token store', error);
return {};
}
}
function writeStore(store: Record<string, string>) {
fs.writeFileSync(FILEPATH, JSON.stringify(store, null, 2));
}
function encryptToken(record: TokenRecord) {
if (!safeStorage.isEncryptionAvailable()) {
throw new Error('Secure native auth token storage is not available.');
}
return safeStorage.encryptString(JSON.stringify(record)).toString('base64');
}
function decryptToken(value: string): TokenRecord | null {
if (!safeStorage.isEncryptionAvailable()) {
return null;
}
try {
return JSON.parse(safeStorage.decryptString(Buffer.from(value, 'base64')));
} catch (error) {
logger.error('failed to decrypt native auth token', error);
return null;
}
}
export function setNativeAuthToken(endpoint: string, token: string) {
const store = readStore();
store[normalizeEndpoint(endpoint)] = encryptToken({ token });
writeStore(store);
}
export function deleteNativeAuthToken(endpoint: string) {
const store = readStore();
delete store[normalizeEndpoint(endpoint)];
writeStore(store);
}
export function getNativeAuthToken(endpoint: string) {
const encrypted = readStore()[normalizeEndpoint(endpoint)];
if (!encrypted) return null;
return decryptToken(encrypted)?.token ?? null;
}
export function getAuthTokenForUrl(url: string) {
try {
const parsed = new URL(url);
if (parsed.protocol === 'ws:') {
parsed.protocol = 'http:';
} else if (parsed.protocol === 'wss:') {
parsed.protocol = 'https:';
}
return getNativeAuthToken(parsed.origin);
} catch {
return null;
}
}
@@ -2,6 +2,7 @@ import { I18n } from '@affine/i18n';
import { ipcMain } from 'electron';
import { AFFINE_API_CHANNEL_NAME } from '../shared/type';
import { authHandlers } from './auth/handlers';
import { byokStorageHandlers } from './byok-storage/handlers';
import { clipboardHandlers } from './clipboard';
import { configStorageHandlers } from './config-storage';
@@ -44,6 +45,7 @@ export const allHandlers = {
popup: popupHandlers,
i18n: i18nHandlers,
byokStorage: byokStorageHandlers,
auth: authHandlers,
};
export const registerHandlers = () => {
@@ -2,7 +2,6 @@ import path, { join } from 'node:path';
import { pathToFileURL } from 'node:url';
import { app, net, protocol, session } from 'electron';
import cookieParser from 'set-cookie-parser';
import { anotherHost, mainHost } from '../shared/internal-origin';
import {
@@ -12,6 +11,7 @@ import {
resolvePathInBase,
resourcesPath,
} from '../shared/utils';
import { getAuthTokenForUrl } from './auth/native-token';
import { buildType, isDev } from './config';
import { logger } from './logger';
@@ -64,7 +64,27 @@ function buildTargetUrl(base: string, urlObject: URL) {
return new URL(`${urlObject.pathname}${urlObject.search}`, base).toString();
}
function proxyRequest(
async function buildAuthorizedRequest(request: Request, targetUrl: string) {
const clonedRequest = request.clone();
const headers = new Headers(clonedRequest.headers);
const token = getAuthTokenForUrl(targetUrl);
if (token) {
headers.set('Authorization', `Bearer ${token}`);
}
return new Request(targetUrl, {
body:
clonedRequest.method === 'GET' || clonedRequest.method === 'HEAD'
? undefined
: clonedRequest.body,
headers,
method: clonedRequest.method,
redirect: clonedRequest.redirect,
signal: clonedRequest.signal,
});
}
async function proxyRequest(
request: Request,
urlObject: URL,
base: string,
@@ -72,12 +92,13 @@ function proxyRequest(
) {
const { bypassCustomProtocolHandlers = true } = options;
const targetUrl = buildTargetUrl(base, urlObject);
const authorizedRequest = await buildAuthorizedRequest(request, targetUrl);
const proxiedRequest = bypassCustomProtocolHandlers
? Object.assign(request.clone(), {
? Object.assign(authorizedRequest, {
bypassCustomProtocolHandlers: true,
})
: request;
return net.fetch(targetUrl, proxiedRequest);
: authorizedRequest;
return net.fetch(proxiedRequest);
}
async function handleFileRequest(request: Request) {
@@ -218,41 +239,6 @@ export function registerProtocol() {
const { responseHeaders, url } = responseDetails;
(async () => {
if (responseHeaders) {
const originalCookie =
responseHeaders['set-cookie'] || responseHeaders['Set-Cookie'];
if (originalCookie) {
// save the cookies, to support third party cookies
for (const cookies of originalCookie) {
const parsedCookies = cookieParser.parse(cookies);
for (const parsedCookie of parsedCookies) {
if (!parsedCookie.value) {
await session.defaultSession.cookies.remove(
responseDetails.url,
parsedCookie.name
);
} else {
await session.defaultSession.cookies.set({
url: responseDetails.url,
domain: parsedCookie.domain,
expirationDate: parsedCookie.expires?.getTime(),
httpOnly: parsedCookie.httpOnly,
secure: parsedCookie.secure,
value: parsedCookie.value,
name: parsedCookie.name,
path: parsedCookie.path,
sameSite: parsedCookie.sameSite?.toLowerCase() as
| 'unspecified'
| 'no_restriction'
| 'lax'
| 'strict'
| undefined,
});
}
}
}
}
const { protocol, hostname } = new URL(url);
// Adjust CORS for assets responses and allow blob redirects on affine domains
@@ -284,23 +270,17 @@ export function registerProtocol() {
const url = new URL(details.url);
(async () => {
// session cookies are set to assets:// on production
// if sending request to the cloud, attach the session cookie (to affine cloud server)
if (
url.protocol === 'http:' ||
url.protocol === 'https:' ||
url.protocol === 'ws:' ||
url.protocol === 'wss:'
) {
const cookies = await session.defaultSession.cookies.get({
url: details.url,
});
const cookieString = cookies
.map(c => `${c.name}=${c.value}`)
.join('; ');
delete details.requestHeaders['cookie'];
details.requestHeaders['Cookie'] = cookieString;
const token = getAuthTokenForUrl(details.url);
if (token) {
delete details.requestHeaders.authorization;
details.requestHeaders.Authorization = `Bearer ${token}`;
}
}
const hostname = url.hostname;
@@ -1,5 +1,6 @@
import Capacitor
import Foundation
import Security
public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
public let identifier = "AuthPlugin"
@@ -7,10 +8,70 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
public let pluginMethods: [CAPPluginMethod] = [
CAPPluginMethod(name: "signInMagicLink", returnType: CAPPluginReturnPromise),
CAPPluginMethod(name: "signInOauth", returnType: CAPPluginReturnPromise),
CAPPluginMethod(name: "signInOpenApp", returnType: CAPPluginReturnPromise),
CAPPluginMethod(name: "signInPassword", returnType: CAPPluginReturnPromise),
CAPPluginMethod(name: "signOut", returnType: CAPPluginReturnPromise),
CAPPluginMethod(name: "readEndpointToken", returnType: CAPPluginReturnPromise),
CAPPluginMethod(name: "writeEndpointToken", returnType: CAPPluginReturnPromise),
CAPPluginMethod(name: "deleteEndpointToken", returnType: CAPPluginReturnPromise),
]
private let tokenService = "app.affine.pro.auth-token"
private let authCookieNames = Set(["affine_session", "affine_user_id", "affine_csrf_token"])
private func canonicalEndpoint(_ endpoint: String) -> String {
guard let url = URL(string: endpoint), let scheme = url.scheme, let host = url.host else {
return endpoint
}
let normalizedScheme = scheme.lowercased()
let normalizedHost = host.lowercased()
let defaultPort: Int?
if normalizedScheme == "http" {
defaultPort = 80
} else if normalizedScheme == "https" {
defaultPort = 443
} else {
defaultPort = nil
}
let port = url.port.flatMap { $0 == defaultPort ? nil : ":\($0)" } ?? ""
return "\(normalizedScheme)://\(normalizedHost)\(port)"
}
@objc public func readEndpointToken(_ call: CAPPluginCall) {
do {
let endpoint = try call.getStringEnsure("endpoint")
if let token = try self.readToken(endpoint) {
call.resolve(["token": token])
} else {
call.resolve(["token": NSNull()])
}
} catch {
call.reject("Failed to read endpoint token, \(error)", nil, error)
}
}
@objc public func writeEndpointToken(_ call: CAPPluginCall) {
do {
let endpoint = try call.getStringEnsure("endpoint")
let token = try call.getStringEnsure("token")
try self.writeToken(endpoint, token)
call.resolve(["ok": true])
} catch {
call.reject("Failed to write endpoint token, \(error)", nil, error)
}
}
@objc public func deleteEndpointToken(_ call: CAPPluginCall) {
do {
let endpoint = try call.getStringEnsure("endpoint")
try self.deleteToken(endpoint)
call.resolve(["ok": true])
} catch {
call.reject("Failed to delete endpoint token, \(error)", nil, error)
}
}
@objc public func signInMagicLink(_ call: CAPPluginCall) {
Task {
do {
@@ -19,7 +80,11 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
let token = try call.getStringEnsure("token")
let clientNonce = call.getString("clientNonce")
let (data, response) = try await self.fetch(endpoint, method: "POST", action: "/api/auth/magic-link", headers: [:], body: ["email": email, "token": token, "client_nonce": clientNonce])
let (data, response) = try await self.fetch(
endpoint, method: "POST", action: "/api/auth/magic-link",
headers: [
"x-affine-client-kind": "native"
], body: ["email": email, "token": token, "client_nonce": clientNonce])
if response.statusCode >= 400 {
if let textBody = String(data: data, encoding: .utf8) {
@@ -30,12 +95,7 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
return
}
guard let token = try self.tokenFromCookie(endpoint) else {
call.reject("token not found")
return
}
call.resolve(["token": token])
call.resolve(["token": try await self.exchangeSession(endpoint, data)])
} catch {
call.reject("Failed to sign in, \(error)", nil, error)
}
@@ -50,7 +110,11 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
let state = try call.getStringEnsure("state")
let clientNonce = call.getString("clientNonce")
let (data, response) = try await self.fetch(endpoint, method: "POST", action: "/api/oauth/callback", headers: [:], body: ["code": code, "state": state, "client_nonce": clientNonce])
let (data, response) = try await self.fetch(
endpoint, method: "POST", action: "/api/oauth/callback",
headers: [
"x-affine-client-kind": "native"
], body: ["code": code, "state": state, "client_nonce": clientNonce])
if response.statusCode >= 400 {
if let textBody = String(data: data, encoding: .utf8) {
@@ -61,12 +125,7 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
return
}
guard let token = try self.tokenFromCookie(endpoint) else {
call.reject("token not found")
return
}
call.resolve(["token": token])
call.resolve(["token": try await self.exchangeSession(endpoint, data)])
} catch {
call.reject("Failed to sign in, \(error)", nil, error)
}
@@ -82,10 +141,13 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
let verifyToken = call.getString("verifyToken")
let challenge = call.getString("challenge")
let (data, response) = try await self.fetch(endpoint, method: "POST", action: "/api/auth/sign-in", headers: [
"x-captcha-token": verifyToken,
"x-captcha-challenge": challenge,
], body: ["email": email, "password": password])
let (data, response) = try await self.fetch(
endpoint, method: "POST", action: "/api/auth/sign-in",
headers: [
"x-affine-client-kind": "native",
"x-captcha-token": verifyToken,
"x-captcha-challenge": challenge,
], body: ["email": email, "password": password])
if response.statusCode >= 400 {
if let textBody = String(data: data, encoding: .utf8) {
@@ -96,12 +158,35 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
return
}
guard let token = try self.tokenFromCookie(endpoint) else {
call.reject("token not found")
call.resolve(["token": try await self.exchangeSession(endpoint, data)])
} catch {
call.reject("Failed to sign in, \(error)", nil, error)
}
}
}
@objc public func signInOpenApp(_ call: CAPPluginCall) {
Task {
do {
let endpoint = try call.getStringEnsure("endpoint")
let code = try call.getStringEnsure("code")
let (data, response) = try await self.fetch(
endpoint, method: "POST", action: "/api/auth/open-app/sign-in",
headers: [
"x-affine-client-kind": "native"
], body: ["code": code])
if response.statusCode >= 400 {
if let textBody = String(data: data, encoding: .utf8) {
call.reject(textBody)
} else {
call.reject("Failed to sign in")
}
return
}
call.resolve(["token": token])
call.resolve(["token": try await self.exchangeSession(endpoint, data)])
} catch {
call.reject("Failed to sign in, \(error)", nil, error)
}
@@ -112,11 +197,13 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
Task {
do {
let endpoint = try call.getStringEnsure("endpoint")
let csrfToken = try self.csrfTokenFromCookie(endpoint)
let token = call.getString("token")
let (data, response) = try await self.fetch(endpoint, method: "POST", action: "/api/auth/sign-out", headers: [
"x-affine-csrf-token": csrfToken,
], body: nil)
let (data, response) = try await self.fetch(
endpoint, method: "POST", action: "/api/auth/sign-out",
headers: [
"Authorization": token.map { "Bearer \($0)" }
], body: nil)
if response.statusCode >= 400 {
if let textBody = String(data: data, encoding: .utf8) {
@@ -127,6 +214,7 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
return
}
self.clearAuthCookies(endpoint)
call.resolve(["ok": true])
} catch {
call.reject("Failed to sign out, \(error)", nil, error)
@@ -134,38 +222,147 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
}
}
private func tokenFromCookie(_ endpoint: String) throws -> String? {
guard let endpointUrl = URL(string: endpoint) else {
throw AuthError.invalidEndpoint
private func tokenFromResponse(_ data: Data) throws -> String {
guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any],
let token = json["token"] as? String
else {
throw AuthError.tokenNotFound
}
if let cookie = HTTPCookieStorage.shared.cookies(for: endpointUrl)?.first(where: {
$0.name == "affine_session"
}) {
return cookie.value
} else {
return nil
return token
}
private func exchangeCodeFromResponse(_ data: Data) throws -> String {
guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any],
let code = json["exchangeCode"] as? String
else {
throw AuthError.exchangeCodeNotFound
}
return code
}
private func exchangeSession(_ endpoint: String, _ signInData: Data) async throws -> String {
let code = try exchangeCodeFromResponse(signInData)
let (data, response) = try await self.fetch(
endpoint, method: "POST", action: "/api/auth/native/exchange",
headers: [
"x-affine-client-kind": "native"
], body: ["code": code])
if response.statusCode >= 400 {
throw AuthError.exchangeFailed
}
let token = try tokenFromResponse(data)
self.clearAuthCookies(endpoint)
return token
}
private func clearAuthCookies(_ endpoint: String) {
guard let url = URL(string: endpoint), let host = url.host else {
return
}
let normalizedHost = host.lowercased()
HTTPCookieStorage.shared.cookies?.forEach { cookie in
let domain = cookie.domain.lowercased().trimmingCharacters(in: CharacterSet(charactersIn: "."))
let domainMatches = normalizedHost == domain || normalizedHost.hasSuffix(".\(domain)")
if domainMatches && authCookieNames.contains(cookie.name) {
HTTPCookieStorage.shared.deleteCookie(cookie)
}
}
}
private func csrfTokenFromCookie(_ endpoint: String) throws -> String? {
guard let endpointUrl = URL(string: endpoint) else {
throw AuthError.invalidEndpoint
}
return HTTPCookieStorage.shared.cookies(for: endpointUrl)?.first(where: {
$0.name == "affine_csrf_token"
})?.value
private func tokenQuery(_ endpoint: String) -> [String: Any] {
[
kSecClass as String: kSecClassGenericPassword,
kSecAttrService as String: tokenService,
kSecAttrAccount as String: canonicalEndpoint(endpoint),
]
}
private func fetch(_ endpoint: String, method: String, action: String, headers: [String: String?], body: Encodable?) async throws -> (Data, HTTPURLResponse) {
private func legacyTokenQuery(_ endpoint: String) -> [String: Any] {
[
kSecClass as String: kSecClassGenericPassword,
kSecAttrService as String: tokenService,
kSecAttrAccount as String: endpoint,
]
}
private func readToken(_ endpoint: String) throws -> String? {
var query = tokenQuery(endpoint)
query[kSecReturnData as String] = true
query[kSecMatchLimit as String] = kSecMatchLimitOne
var item: CFTypeRef?
let status = SecItemCopyMatching(query as CFDictionary, &item)
if status == errSecItemNotFound {
guard canonicalEndpoint(endpoint) != endpoint else {
return nil
}
var legacyQuery = legacyTokenQuery(endpoint)
legacyQuery[kSecReturnData as String] = true
legacyQuery[kSecMatchLimit as String] = kSecMatchLimitOne
let legacyStatus = SecItemCopyMatching(legacyQuery as CFDictionary, &item)
if legacyStatus == errSecItemNotFound {
return nil
}
guard legacyStatus == errSecSuccess, let data = item as? Data else {
throw AuthError.internalError
}
let token = String(data: data, encoding: .utf8)
if let token = token {
try writeToken(endpoint, token)
let deleteStatus = SecItemDelete(legacyTokenQuery(endpoint) as CFDictionary)
guard deleteStatus == errSecSuccess || deleteStatus == errSecItemNotFound else {
throw AuthError.internalError
}
}
return token
}
guard status == errSecSuccess, let data = item as? Data else {
throw AuthError.internalError
}
return String(data: data, encoding: .utf8)
}
private func writeToken(_ endpoint: String, _ token: String) throws {
try deleteToken(endpoint)
var query = tokenQuery(endpoint)
query[kSecValueData as String] = Data(token.utf8)
query[kSecAttrAccessible as String] = kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly
let status = SecItemAdd(query as CFDictionary, nil)
guard status == errSecSuccess else {
throw AuthError.internalError
}
}
private func deleteToken(_ endpoint: String) throws {
let status = SecItemDelete(tokenQuery(endpoint) as CFDictionary)
guard status == errSecSuccess || status == errSecItemNotFound else {
throw AuthError.internalError
}
if canonicalEndpoint(endpoint) != endpoint {
let legacyStatus = SecItemDelete(legacyTokenQuery(endpoint) as CFDictionary)
guard legacyStatus == errSecSuccess || legacyStatus == errSecItemNotFound else {
throw AuthError.internalError
}
}
}
private func fetch(
_ endpoint: String, method: String, action: String, headers: [String: String?], body: Encodable?
) async throws -> (Data, HTTPURLResponse) {
guard let targetUrl = URL(string: "\(endpoint)\(action)") else {
throw AuthError.invalidEndpoint
}
var request = URLRequest(url: targetUrl)
request.httpMethod = method
request.httpShouldHandleCookies = true
request.httpShouldHandleCookies = false
for (key, value) in headers {
request.setValue(value, forHTTPHeaderField: key)
}
@@ -174,7 +371,7 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
request.httpBody = try JSONEncoder().encode(body!)
}
request.setValue(AppConfigManager.getAffineVersion(), forHTTPHeaderField: "x-affine-version")
request.timeoutInterval = 10 // time out 10s
request.timeoutInterval = 10 // time out 10s
let (data, response) = try await URLSession.shared.data(for: request)
guard let httpResponse = response as? HTTPURLResponse else {
@@ -185,5 +382,5 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
}
enum AuthError: Error {
case invalidEndpoint, internalError
case invalidEndpoint, internalError, tokenNotFound, exchangeCodeNotFound, exchangeFailed
}
@@ -57,7 +57,7 @@ public class GetCurrentUserQuery: GraphQLQuery {
public var emailVerified: Bool { __data["emailVerified"] }
/// User avatar url
public var avatarUrl: String? { __data["avatarUrl"] }
@available(*, deprecated, message: "use [/api/auth/sign-in?native=true] instead")
@available(*, deprecated, message: "use native session exchange instead")
public var token: Token { __data["token"] }
/// CurrentUser.Token
+35 -13
View File
@@ -74,7 +74,11 @@ import { Hashcash } from './plugins/hashcash';
import { NbStoreNativeDBApis } from './plugins/nbstore';
import { PayWall } from './plugins/paywall';
import { Preview } from './plugins/preview';
import { writeEndpointToken } from './proxy';
import {
deleteEndpointToken,
readEndpointToken,
writeEndpointToken,
} from './proxy';
import { enableNavigationGesture$ } from './web-navigation-control';
const storeManagerClient = createStoreManagerClient();
@@ -204,10 +208,20 @@ framework.scope(ServerScope).override(AuthProvider, resolver => {
});
await writeEndpointToken(endpoint, token);
},
async signOut() {
await Auth.signOut({
async signInOpenAppSignInCode(code) {
const { token } = await Auth.signInOpenApp({
endpoint,
code,
});
await writeEndpointToken(endpoint, token);
},
async signOut() {
const token = await readEndpointToken(endpoint);
try {
await Auth.signOut({ endpoint, token });
} finally {
await deleteEndpointToken(endpoint);
}
},
};
});
@@ -541,13 +555,9 @@ function createStoreManagerClient() {
AsyncCall<typeof NbStoreNativeDBApis>(NbStoreNativeDBApis, {
channel: {
on(listener) {
const f = (e: MessageEvent<any>) => {
listener(e.data);
};
const f = (e: MessageEvent<any>) => listener(e.data);
nativeDBApiChannelServer.addEventListener('message', f);
return () => {
nativeDBApiChannelServer.removeEventListener('message', f);
};
return () => nativeDBApiChannelServer.removeEventListener('message', f);
},
send(data) {
nativeDBApiChannelServer.postMessage(data);
@@ -557,11 +567,23 @@ function createStoreManagerClient() {
});
nativeDBApiChannelServer.start();
worker.postMessage(
{
type: 'native-db-api-channel',
port: nativeDBApiChannelClient,
},
{ type: 'native-db-api-channel', port: nativeDBApiChannelClient },
[nativeDBApiChannelClient]
);
const { port1: authTokenChannelServer, port2: authTokenChannelClient } =
new MessageChannel();
authTokenChannelServer.addEventListener('message', event => {
const { id, endpoint } = event.data as { id?: string; endpoint?: string };
if (!id || !endpoint) return;
readEndpointToken(endpoint)
.then(token => authTokenChannelServer.postMessage({ id, token }))
.catch(() => authTokenChannelServer.postMessage({ id, token: null }));
});
authTokenChannelServer.start();
worker.postMessage(
{ type: 'native-auth-token-channel', port: authTokenChannelClient },
[authTokenChannelClient]
);
return new StoreManagerClient(new OpClient(worker));
}
@@ -18,19 +18,28 @@ import {
import { type MessageCommunicapable, OpConsumer } from '@toeverything/infra/op';
import { AsyncCall } from 'async-call-rpc';
import { readEndpointToken } from './proxy';
let authTokenPort: MessagePort | undefined;
const pendingTokenRequests = new Map<string, (token: string | null) => void>();
configureSocketAuthMethod((endpoint, cb) => {
readEndpointToken(endpoint)
.then(token => {
cb({ token });
})
.catch(e => {
console.error(e);
});
.then(token => cb(token ? { token, tokenType: 'jwt' } : {}))
.catch(() => cb({}));
});
globalThis.addEventListener('message', e => {
if (e.data.type === 'native-auth-token-channel') {
authTokenPort = e.ports[0] as MessagePort;
authTokenPort.addEventListener('message', e => {
const { id, token } = e.data as { id?: string; token?: string | null };
if (!id) return;
pendingTokenRequests.get(id)?.(token ?? null);
pendingTokenRequests.delete(id);
});
authTokenPort.start();
return;
}
if (e.data.type === 'native-db-api-channel') {
const port = e.ports[0] as MessagePort;
const rpc = AsyncCall<NativeDBApis>(
@@ -57,6 +66,25 @@ globalThis.addEventListener('message', e => {
}
});
function readEndpointToken(endpoint: string) {
if (!authTokenPort) {
return Promise.resolve(null);
}
const id = `${Date.now()}:${Math.random()}`;
return new Promise<string | null>(resolve => {
const timeout = setTimeout(() => {
pendingTokenRequests.delete(id);
resolve(null);
}, 5000);
pendingTokenRequests.set(id, token => {
clearTimeout(timeout);
resolve(token);
});
authTokenPort?.postMessage({ id, endpoint });
});
}
const consumer = new OpConsumer<WorkerManagerOps>(
globalThis as MessageCommunicapable
);
@@ -18,5 +18,17 @@ export interface AuthPlugin {
verifyToken?: string;
challenge?: string;
}): Promise<{ token: string }>;
signOut(options: { endpoint: string }): Promise<void>;
signInOpenApp(options: {
endpoint: string;
code: string;
}): Promise<{ token: string }>;
signOut(options: { endpoint: string; token?: string | null }): Promise<void>;
readEndpointToken(options: {
endpoint: string;
}): Promise<{ token?: string | null }>;
writeEndpointToken(options: {
endpoint: string;
token: string;
}): Promise<void>;
deleteEndpointToken(options: { endpoint: string }): Promise<void>;
}
+53 -24
View File
@@ -1,4 +1,19 @@
import { openDB } from 'idb';
import { Auth } from './plugins/auth';
function authEndpointForUrl(url: string | URL) {
try {
const parsed = new URL(url, globalThis.location.origin);
return parsed.protocol === 'http:' || parsed.protocol === 'https:'
? parsed.origin
: null;
} catch {
return null;
}
}
function canonicalEndpoint(endpoint: string) {
return authEndpointForUrl(endpoint) ?? endpoint;
}
/**
* the below code includes the custom fetch and xmlhttprequest implementation for ios webview.
@@ -8,9 +23,11 @@ const rawFetch = globalThis.fetch;
globalThis.fetch = async (input: RequestInfo | URL, init?: RequestInit) => {
const request = new Request(input, init);
const origin = new URL(request.url, globalThis.location.origin).origin;
const origin = authEndpointForUrl(request.url);
const token = await readEndpointToken(origin);
const token = origin
? await readEndpointToken(origin).catch(() => null)
: null;
if (token) {
request.headers.set('Authorization', `Bearer ${token}`);
}
@@ -19,11 +36,30 @@ globalThis.fetch = async (input: RequestInfo | URL, init?: RequestInit) => {
};
const rawXMLHttpRequest = globalThis.XMLHttpRequest;
const xhrRequestUrls = new WeakMap<XMLHttpRequest, string>();
globalThis.XMLHttpRequest = class extends rawXMLHttpRequest {
override send(body?: Document | XMLHttpRequestBodyInit | null): void {
const origin = new URL(this.responseURL, globalThis.location.origin).origin;
override open(
method: string,
url: string | URL,
async: boolean = true,
username?: string | null,
password?: string | null
): void {
xhrRequestUrls.set(this, url.toString());
return super.open(
method,
url,
async,
username ?? undefined,
password ?? undefined
);
}
readEndpointToken(origin).then(
override send(body?: Document | XMLHttpRequestBodyInit | null): void {
const requestUrl = xhrRequestUrls.get(this);
const origin = authEndpointForUrl(requestUrl ?? globalThis.location.href);
(origin ? readEndpointToken(origin) : Promise.resolve(null)).then(
token => {
if (token) {
this.setRequestHeader('Authorization', `Bearer ${token}`);
@@ -31,7 +67,7 @@ globalThis.XMLHttpRequest = class extends rawXMLHttpRequest {
return super.send(body);
},
() => {
throw new Error('Failed to read token');
return super.send(body);
}
);
}
@@ -40,26 +76,19 @@ globalThis.XMLHttpRequest = class extends rawXMLHttpRequest {
export async function readEndpointToken(
endpoint: string
): Promise<string | null> {
const idb = await openDB('affine-token', 1, {
upgrade(db) {
if (!db.objectStoreNames.contains('tokens')) {
db.createObjectStore('tokens', { keyPath: 'endpoint' });
}
},
const { token } = await Auth.readEndpointToken({
endpoint: canonicalEndpoint(endpoint),
});
const token = await idb.get('tokens', endpoint);
return token ? token.token : null;
return token ?? null;
}
export async function writeEndpointToken(endpoint: string, token: string) {
const db = await openDB('affine-token', 1, {
upgrade(db) {
if (!db.objectStoreNames.contains('tokens')) {
db.createObjectStore('tokens', { keyPath: 'endpoint' });
}
},
await Auth.writeEndpointToken({
endpoint: canonicalEndpoint(endpoint),
token,
});
await db.put('tokens', { endpoint, token });
}
export async function deleteEndpointToken(endpoint: string) {
await Auth.deleteEndpointToken({ endpoint: canonicalEndpoint(endpoint) });
}
@@ -41,7 +41,6 @@ describe('AuthService oauthPreflight', () => {
framework.service(NbstoreService, {
realtime: { subscribe: () => of() },
} as any);
framework.service(AuthService, [
FetchService,
AuthStore,
@@ -16,6 +16,7 @@ export const Captcha = () => {
const handleTurnstileSuccess = useCallback(
(token: string) => {
captchaService.challenge$.next(undefined);
captchaService.verifyToken$.next(token);
},
[captchaService]
@@ -86,7 +86,6 @@ export const SignInWithEmailStep = ({
setIsSending(true);
try {
setResendCountDown(60);
captchaService.revalidate();
await authService.sendEmailMagicLink(
email,
verifyToken,
@@ -100,6 +99,7 @@ export const SignInWithEmailStep = ({
title: 'Failed to sign in',
message: t[`error.${error.name}`](error.data),
});
captchaService.revalidate();
}
setIsSending(false);
}, [
@@ -182,6 +182,7 @@ export const SignInWithEmailStep = ({
errorHint={otpError}
onEnter={onContinue}
type="text"
autoComplete="one-time-code"
required={true}
maxLength={6}
/>
@@ -85,7 +85,6 @@ export const SignInWithPasswordStep = ({
setIsLoading(true);
try {
captchaService.revalidate();
await authService.signInPassword({
email,
password,
@@ -111,6 +110,7 @@ export const SignInWithPasswordStep = ({
: t[`error.${error.name}`](error.data),
});
}
captchaService.revalidate();
} finally {
setIsLoading(false);
}
@@ -138,28 +138,50 @@ export const SignInWithPasswordStep = ({
/>
<AuthContent>
<AuthInput
label={t['com.affine.settings.email']()}
disabled={true}
value={email}
/>
<AuthInput
autoFocus
data-testid="password-input"
label={t['com.affine.auth.password']()}
value={password}
type="password"
onChange={(value: string) => {
setPassword(value);
if (passwordError) {
setPasswordError(false);
setPasswordErrorHint(t['com.affine.auth.password.error']());
}
<form
onSubmit={event => {
event.preventDefault();
onSignIn();
}}
error={passwordError}
errorHint={passwordErrorHint}
onEnter={onSignIn}
/>
>
<AuthInput
label={t['com.affine.settings.email']()}
readOnly={true}
value={email}
type="email"
name="username"
autoComplete="username"
/>
<AuthInput
autoFocus
data-testid="password-input"
label={t['com.affine.auth.password']()}
value={password}
type="password"
name="password"
autoComplete="current-password"
onChange={(value: string) => {
setPassword(value);
if (passwordError) {
setPasswordError(false);
setPasswordErrorHint(t['com.affine.auth.password.error']());
}
}}
error={passwordError}
errorHint={passwordErrorHint}
onEnter={onSignIn}
/>
{!verifyToken && needCaptcha && <Captcha />}
<Button
data-testid="sign-in-button"
variant="primary"
size="extraLarge"
style={{ width: '100%' }}
disabled={isLoading || (!verifyToken && needCaptcha)}
>
{t['com.affine.auth.sign.in']()}
</Button>
</form>
{!isSelfhosted && (
<div className={styles.passwordButtonRow}>
<a
@@ -171,17 +193,6 @@ export const SignInWithPasswordStep = ({
</a>
</div>
)}
{!verifyToken && needCaptcha && <Captcha />}
<Button
data-testid="sign-in-button"
variant="primary"
size="extraLarge"
style={{ width: '100%' }}
disabled={isLoading || (!verifyToken && needCaptcha)}
onClick={onSignIn}
>
{t['com.affine.auth.sign.in']()}
</Button>
</AuthContent>
<AuthFooter>
<Back changeState={changeState} />
@@ -90,7 +90,9 @@ export const SignInStep = ({
setIsMutating(true);
try {
const { hasPassword } = await authService.checkUserByEmail(email);
const { methods } = await authService.checkUserByEmail(email);
const hasPassword = methods.password.available;
const canUseMagicLink = methods.magicLink.available;
if (hasPassword) {
changeState(prev => ({
@@ -99,13 +101,18 @@ export const SignInStep = ({
step: 'signInWithPassword',
hasPassword: true,
}));
} else {
} else if (canUseMagicLink) {
changeState(prev => ({
...prev,
email,
step: 'signInWithEmail',
hasPassword: false,
}));
} else {
notify.error({
title: 'Failed to sign in',
message: 'This email is not available for sign in.',
});
}
} catch (err: any) {
console.error(err);
@@ -151,31 +158,41 @@ export const SignInStep = ({
<AuthContent>
<OAuth redirectUrl={state.redirectUrl} />
<AuthInput
className={style.authInput}
label={t['com.affine.settings.email']()}
placeholder={t['com.affine.auth.sign.email.placeholder']()}
onChange={setEmail}
error={!isValidEmail}
errorHint={
isValidEmail ? '' : t['com.affine.auth.sign.email.error']()
}
onEnter={onContinue}
/>
<Button
className={style.signInButton}
style={{ width: '100%' }}
size="extraLarge"
data-testid="continue-login-button"
block
loading={isMutating}
suffix={<ArrowRightBigIcon />}
suffixStyle={{ width: 20, height: 20, color: cssVar('blue') }}
onClick={onContinue}
<form
onSubmit={event => {
event.preventDefault();
onContinue();
}}
>
{t['com.affine.auth.sign.email.continue']()}
</Button>
<AuthInput
className={style.authInput}
label={t['com.affine.settings.email']()}
placeholder={t['com.affine.auth.sign.email.placeholder']()}
onChange={setEmail}
error={!isValidEmail}
errorHint={
isValidEmail ? '' : t['com.affine.auth.sign.email.error']()
}
onEnter={onContinue}
type="email"
name="username"
autoComplete="username"
/>
<Button
className={style.signInButton}
style={{ width: '100%' }}
size="extraLarge"
data-testid="continue-login-button"
block
loading={isMutating}
disabled={isMutating}
suffix={<ArrowRightBigIcon />}
suffixStyle={{ width: 20, height: 20, color: cssVar('blue') }}
>
{t['com.affine.auth.sign.email.continue']()}
</Button>
</form>
{!isSelfhosted && (
<>
@@ -25,6 +25,7 @@ import { useEffect, useState } from 'react';
export const ChangePasswordDialog = ({
close,
hasPassword: hasPasswordProp,
server: serverBaseUrl,
}: DialogComponentProps<GLOBAL_DIALOG_SCHEMA['change-password']>) => {
const t = useI18n();
@@ -44,7 +45,8 @@ export const ChangePasswordDialog = ({
const authService = server.scope.get(AuthService);
const account = useLiveData(authService.session.account$);
const email = account?.email;
const hasPassword = account?.info?.hasPassword;
const hasPassword =
hasPasswordProp ?? account?.info?.authMethods?.password.bound ?? false;
const [hasSentEmail, setHasSentEmail] = useState(false);
const [loading, setLoading] = useState(false);
const passwordLimits = useLiveData(
@@ -201,13 +201,19 @@ export const AccountSetting = ({
const onPasswordButtonClick = useCallback(() => {
globalDialogService.open('change-password', {
hasPassword: account?.info?.authMethods?.password.bound,
server: serverService.server.baseUrl,
});
}, [globalDialogService, serverService.server.baseUrl]);
}, [
account?.info?.authMethods?.password.bound,
globalDialogService,
serverService.server.baseUrl,
]);
if (!account) {
return null;
}
const hasPassword = account.info?.authMethods?.password.bound;
return (
<>
@@ -233,7 +239,7 @@ export const AccountSetting = ({
desc={t['com.affine.settings.password.message']()}
>
<Button onClick={onPasswordButtonClick}>
{account.info?.hasPassword
{hasPassword
? t['com.affine.settings.password.action.change']()
: t['com.affine.settings.password.action.set']()}
</Button>
@@ -79,6 +79,13 @@ export function configureDefaultAuthProvider(framework: Framework) {
},
});
},
async signInOpenAppSignInCode(code: string) {
await fetchService.fetch('/api/auth/open-app/sign-in', {
method: 'POST',
body: JSON.stringify({ code }),
headers: { 'content-type': 'application/json' },
});
},
async signOut() {
const csrfToken = getCookieValue(CSRF_COOKIE_NAME);
await fetchService.fetch('/api/auth/sign-out', {
@@ -21,6 +21,8 @@ export interface AuthProvider {
challenge?: string;
}): Promise<void>;
signInOpenAppSignInCode(code: string): Promise<void>;
signOut(): Promise<void>;
}
@@ -222,11 +222,7 @@ export class AuthService extends Service {
}
async signInOpenAppSignInCode(code: string) {
await this.fetchService.fetch('/api/auth/open-app/sign-in', {
method: 'POST',
body: JSON.stringify({ code }),
headers: { 'content-type': 'application/json' },
});
await this.store.signInOpenAppSignInCode(code);
this.session.revalidate();
}
@@ -33,7 +33,7 @@ export class CaptchaService extends Service {
revalidate = effect(
exhaustMap(() => {
return fromPromise(async signal => {
if (!this.needCaptcha$.value) {
if (!this.needCaptcha$.value || !this.validatorProvider) {
return {};
}
const res = await this.fetchService.fetch('/api/auth/challenge', {
@@ -46,17 +46,14 @@ export class CaptchaService extends Service {
if (!data || !data.challenge || !data.resource) {
throw new Error('Invalid challenge');
}
if (this.validatorProvider) {
const token = await this.validatorProvider.validate(
data.challenge,
data.resource
);
return {
token,
challenge: data.challenge,
};
}
return { challenge: data.challenge, token: undefined };
const token = await this.validatorProvider.validate(
data.challenge,
data.resource
);
return {
token,
challenge: data.challenge,
};
}).pipe(
tap(({ challenge, token }) => {
this.verifyToken$.next(token);
@@ -19,6 +19,11 @@ export interface AccountProfile {
email: string;
name: string;
hasPassword: boolean;
authMethods?: {
password: { bound: boolean };
oauth: { bound: boolean; providers: string[] };
passkey: { bound: boolean; count: number };
};
avatarUrl: string | null;
emailVerified: string | null;
features?: string[];
@@ -61,16 +66,20 @@ export class AuthStore extends Store {
}
async fetchSession() {
const { user } = await this.nbstoreService.realtime.request(
'user.profile.get',
{},
{ timeoutMs: 10000 }
);
const { user } = await this.fetchService
.fetch('/api/auth/session')
.then(res => res.json());
const authMethods = user
? await this.fetchService
.fetch('/api/auth/methods')
.then(res => (res.ok ? res.json() : undefined))
: undefined;
return {
user: user
? {
...user,
hasPassword: Boolean(user.hasPassword),
authMethods,
emailVerified: user.emailVerified ? 'true' : null,
}
: null,
@@ -103,6 +112,10 @@ export class AuthStore extends Store {
await this.authProvider.signInPassword(credential);
}
async signInOpenAppSignInCode(code: string) {
await this.authProvider.signInOpenAppSignInCode(code);
}
async signOut() {
await this.authProvider.signOut();
await this.nbstoreService.realtime.configure({
@@ -155,8 +168,12 @@ export class AuthStore extends Store {
const data = (await res.json()) as {
registered: boolean;
hasPassword: boolean;
magicLink: boolean;
methods: {
password: { available: boolean };
magicLink: { available: boolean };
oauth: { available: boolean; providers: string[] };
passkey: { available: boolean; discoverable: boolean };
};
};
return data;
@@ -30,7 +30,10 @@ export type GLOBAL_DIALOG_SCHEMA = {
snapshotUrl: string;
}) => void;
'sign-in': (props: { server?: string; step?: string }) => void;
'change-password': (props: { server?: string }) => void;
'change-password': (props: {
server?: string;
hasPassword?: boolean;
}) => void;
'verify-email': (props: { server?: string; changeEmail?: boolean }) => void;
'enable-cloud': (props: {
workspaceId: string;
@@ -12,7 +12,7 @@
"fr": 97,
"hi": 1,
"it": 94,
"ja": 93,
"ja": 92,
"kk": 100,
"ko": 93,
"nb-NO": 46,
@@ -20,9 +20,9 @@
"pt-BR": 92,
"ru": 95,
"sv-SE": 93,
"uk": 93,
"tr": 100,
"uk": 92,
"ur": 100,
"zh-Hans": 100,
"zh-Hant": 93,
"tr": 100
"zh-Hant": 93
}