Compare commits

...

8 Commits

Author SHA1 Message Date
DarkSky
76eefcb4f3 fix: codeql 2026-02-03 23:34:33 +08:00
DarkSky
18b8c7831f fix: lint & test 2026-02-03 18:11:43 +08:00
DarkSky
ae26418281 Merge branch 'canary' into darksky/remove-old-client-support 2026-02-03 17:05:35 +08:00
renovate[bot]
de29e8300a chore: bump up @types/uuid version to v11 (#14364)
This PR contains the following updates:

| Package | Change |
[Age](https://docs.renovatebot.com/merge-confidence/) |
[Confidence](https://docs.renovatebot.com/merge-confidence/) |
|---|---|---|---|
| @​types/uuid | [`^10.0.0` →
`^11.0.0`](https://renovatebot.com/diffs/npm/@types%2fuuid/10.0.0/11.0.0)
|
![age](https://developer.mend.io/api/mc/badges/age/npm/@types%2fuuid/11.0.0?slim=true)
|
![confidence](https://developer.mend.io/api/mc/badges/confidence/npm/@types%2fuuid/10.0.0/11.0.0?slim=true)
|

---

### Configuration

📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).

🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.

♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.

🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.

---

- [ ] <!-- rebase-check -->If you want to rebase/retry this PR, check
this box

---

This PR was generated by [Mend Renovate](https://mend.io/renovate/).
View the [repository job
log](https://developer.mend.io/github/toeverything/AFFiNE).

<!--renovate-debug:eyJjcmVhdGVkSW5WZXIiOiI0Mi45Mi4xIiwidXBkYXRlZEluVmVyIjoiNDIuOTIuMSIsInRhcmdldEJyYW5jaCI6ImNhbmFyeSIsImxhYmVscyI6WyJkZXBlbmRlbmNpZXMiXX0=-->

Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2026-02-02 17:57:23 +00:00
Fahleen Arif
e2b26ffb0c fix: visibility issue of document in print mode (#14367)
fix #14330 

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Improved print-to-PDF rendering by enforcing a consistent light theme,
ensuring better readability and visual consistency in exported PDF
documents.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-02-03 01:52:14 +08:00
DarkSky
63e602a6f5 fix: rebase type error 2026-02-02 21:42:17 +08:00
renovate[bot]
12f0a9ae62 chore: bump up @types/multer version to v2 (#14353)
This PR contains the following updates:

| Package | Change |
[Age](https://docs.renovatebot.com/merge-confidence/) |
[Confidence](https://docs.renovatebot.com/merge-confidence/) |
|---|---|---|---|
|
[@types/multer](https://redirect.github.com/DefinitelyTyped/DefinitelyTyped/tree/master/types/multer)
([source](https://redirect.github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/multer))
| [`^1` →
`^2.0.0`](https://renovatebot.com/diffs/npm/@types%2fmulter/1.4.12/2.0.0)
|
![age](https://developer.mend.io/api/mc/badges/age/npm/@types%2fmulter/2.0.0?slim=true)
|
![confidence](https://developer.mend.io/api/mc/badges/confidence/npm/@types%2fmulter/1.4.12/2.0.0?slim=true)
|

---

### Configuration

📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).

🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.

♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.

🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.

---

- [ ] <!-- rebase-check -->If you want to rebase/retry this PR, check
this box

---

This PR was generated by [Mend Renovate](https://mend.io/renovate/).
View the [repository job
log](https://developer.mend.io/github/toeverything/AFFiNE).

<!--renovate-debug:eyJjcmVhdGVkSW5WZXIiOiI0Mi45Mi4xIiwidXBkYXRlZEluVmVyIjoiNDIuOTIuMSIsInRhcmdldEJyYW5jaCI6ImNhbmFyeSIsImxhYmVscyI6WyJkZXBlbmRlbmNpZXMiXX0=-->

Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: DarkSky <25152247+darkskygit@users.noreply.github.com>
2026-02-02 16:21:40 +08:00
DarkSky
d8404e9df8 feat: drop outdated client support 2026-01-17 22:39:20 +08:00
104 changed files with 3344 additions and 1010 deletions

View File

@@ -832,8 +832,8 @@
},
"versionControl.requiredVersion": {
"type": "string",
"description": "Allowed version range of the app that allowed to access the server. Requires 'client/versionControl.enabled' to be true to take effect.\n@default \">=0.20.0\"",
"default": ">=0.20.0"
"description": "Allowed version range of the app that allowed to access the server. Requires 'client/versionControl.enabled' to be true to take effect.\n@default \">=0.25.0\"",
"default": ">=0.25.0"
}
}
},

View File

@@ -35,9 +35,28 @@ export async function printToPdf(
overflow: initial !important;
print-color-adjust: exact;
-webkit-print-color-adjust: exact;
color: #000 !important;
background: #fff !important;
color-scheme: light !important;
}
::-webkit-scrollbar {
display: none;
::-webkit-scrollbar {
display: none;
}
:root, body {
--affine-text-primary: #000 !important;
--affine-text-secondary: #111 !important;
--affine-text-tertiary: #333 !important;
--affine-background-primary: #fff !important;
--affine-background-secondary: #fff !important;
--affine-background-tertiary: #fff !important;
}
body, [data-theme='dark'] {
color: #000 !important;
background: #fff !important;
}
body * {
color: #000 !important;
-webkit-text-fill-color: #000 !important;
}
:root {
--affine-note-shadow-box: none !important;
@@ -95,6 +114,14 @@ export async function printToPdf(
true
) as HTMLDivElement;
// force light theme in print iframe
iframe.contentWindow.document.documentElement.setAttribute(
'data-theme',
'light'
);
iframe.contentWindow.document.body.setAttribute('data-theme', 'light');
importedRoot.setAttribute('data-theme', 'light');
// draw saved canvas image to canvas
const allImportedCanvas = importedRoot.getElementsByTagName('canvas');
for (const importedCanvas of allImportedCanvas) {

View File

@@ -0,0 +1,42 @@
DO $$
DECLARE error_message TEXT;
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
BEGIN
CREATE EXTENSION IF NOT EXISTS "pgcrypto";
EXCEPTION
WHEN OTHERS THEN
error_message := 'pgcrypto extension not found. access_tokens.token will not be hashed automatically.' || E'\n' ||
'Tokens will be lazily migrated on use.';
RAISE WARNING '%', error_message;
END;
END IF;
IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
UPDATE "access_tokens"
SET "token" = encode(digest("token", 'sha256'), 'hex')
WHERE substr("token", 1, 3) = 'ut_';
END IF;
END $$;
-- CreateTable
CREATE TABLE "magic_link_otps" (
"id" VARCHAR NOT NULL,
"email" TEXT NOT NULL,
"otp_hash" VARCHAR NOT NULL,
"token" TEXT NOT NULL,
"client_nonce" TEXT,
"attempts" INTEGER NOT NULL DEFAULT 0,
"expires_at" TIMESTAMPTZ(3) NOT NULL,
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(3) NOT NULL,
CONSTRAINT "magic_link_otps_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "magic_link_otps_email_key" ON "magic_link_otps"("email");
-- CreateIndex
CREATE INDEX "magic_link_otps_expires_at_idx" ON "magic_link_otps"("expires_at");

View File

@@ -152,6 +152,7 @@
"nodemon": "^3.1.11",
"react-email": "4.0.11",
"sinon": "^21.0.1",
"socket.io-client": "^4.8.3",
"supertest": "^7.1.4",
"why-is-node-running": "^3.2.2"
},

View File

@@ -106,6 +106,21 @@ model VerificationToken {
@@map("verification_tokens")
}
model MagicLinkOtp {
id String @id @default(uuid()) @db.VarChar
email String @unique @db.Text
otpHash String @map("otp_hash") @db.VarChar
token String @db.Text
clientNonce String? @map("client_nonce") @db.Text
attempts Int @default(0)
expiresAt DateTime @map("expires_at") @db.Timestamptz(3)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
@@index([expiresAt])
@@map("magic_link_otps")
}
model Workspace {
// NOTE: manually set this column type to identity in migration file
sid Int @unique @default(autoincrement())

View File

@@ -32,6 +32,16 @@ Generated by [AVA](https://avajs.dev).
> Snapshot 4
{
code: 'Bad Request',
message: 'Invalid header',
name: 'BAD_REQUEST',
status: 400,
type: 'BAD_REQUEST',
}
> Snapshot 5
Buffer @Uint8Array [
66616b65 20696d61 6765
]
@@ -56,7 +66,7 @@ Generated by [AVA](https://avajs.dev).
{
code: 'Bad Request',
message: 'Invalid URL',
message: 'Invalid header',
name: 'BAD_REQUEST',
status: 400,
type: 'BAD_REQUEST',
@@ -64,6 +74,16 @@ Generated by [AVA](https://avajs.dev).
> Snapshot 4
{
code: 'Bad Request',
message: 'Invalid URL',
name: 'BAD_REQUEST',
status: 400,
type: 'BAD_REQUEST',
}
> Snapshot 5
{
description: 'Test Description',
favicons: [
@@ -77,7 +97,7 @@ Generated by [AVA](https://avajs.dev).
videos: [],
}
> Snapshot 5
> Snapshot 6
{
charset: 'gbk',
@@ -90,7 +110,7 @@ Generated by [AVA](https://avajs.dev).
videos: [],
}
> Snapshot 6
> Snapshot 7
{
charset: 'shift_jis',
@@ -103,7 +123,7 @@ Generated by [AVA](https://avajs.dev).
videos: [],
}
> Snapshot 7
> Snapshot 8
{
charset: 'big5',
@@ -116,7 +136,7 @@ Generated by [AVA](https://avajs.dev).
videos: [],
}
> Snapshot 8
> Snapshot 9
{
charset: 'euc-kr',

View File

@@ -33,7 +33,7 @@ test('change email', async t => {
const u2Email = 'u2@affine.pro';
const user = await app.signupV1(u1Email);
await sendChangeEmail(app, u1Email, 'affine.pro');
await sendChangeEmail(app, u1Email, '/email-change');
const changeMail = app.mails.last('ChangeEmail');
@@ -53,7 +53,7 @@ test('change email', async t => {
app,
changeEmailToken as string,
u2Email,
'affine.pro'
'/email-change-verify'
);
const verifyMail = app.mails.last('VerifyChangeEmail');
@@ -94,7 +94,7 @@ test('set and change password', async t => {
const u1Email = 'u1@affine.pro';
const u1 = await app.signupV1(u1Email);
await sendSetPasswordEmail(app, u1Email, 'affine.pro');
await sendSetPasswordEmail(app, u1Email, '/password-change');
const setPasswordMail = app.mails.last('ChangePassword');
const link = new URL(setPasswordMail.props.url);
@@ -131,3 +131,29 @@ test('set and change password', async t => {
t.not(user, null, 'failed to get current user');
t.is(user?.email, u1Email, 'failed to get current user');
});
test('should forbid graphql callbackUrl to external origin', async t => {
const { app } = t.context;
const u1Email = 'u1@affine.pro';
await app.signupV1(u1Email);
const res = await app
.POST('/graphql')
.set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
.send({
query: `
mutation($email: String!, $callbackUrl: String!) {
sendChangeEmail(email: $email, callbackUrl: $callbackUrl)
}
`,
variables: {
email: u1Email,
callbackUrl: 'https://evil.example',
},
})
.expect(200);
t.truthy(res.body.errors?.length);
t.is(res.body.errors[0].extensions?.name, 'ACTION_FORBIDDEN');
});

View File

@@ -5,6 +5,7 @@ import { HttpStatus } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import supertest from 'supertest';
import { parseCookies as safeParseCookies } from '../../base/utils/request';
import { AuthService } from '../../core/auth/service';
@@ -126,6 +127,36 @@ test('should not be able to sign in if forbidden', async t => {
t.pass();
});
test('should forbid magic link with external callbackUrl', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
await app
.POST('/api/auth/sign-in')
.send({
email: u1.email,
callbackUrl: 'https://evil.example/magic-link',
})
.expect(HttpStatus.FORBIDDEN);
t.pass();
});
test('should forbid magic link with untrusted redirect_uri in callbackUrl', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
await app
.POST('/api/auth/sign-in')
.send({
email: u1.email,
callbackUrl: '/magic-link?redirect_uri=https://evil.example',
})
.expect(HttpStatus.FORBIDDEN);
t.pass();
});
test('should be able to sign out', async t => {
const { app } = t.context;
@@ -136,13 +167,82 @@ test('should be able to sign out', async t => {
.send({ email: u1.email, password: u1.password })
.expect(200);
await app.GET('/api/auth/sign-out').expect(200);
await app.POST('/api/auth/sign-out').expect(200);
const session = await currentUser(app);
t.falsy(session);
});
test('should reject sign out when csrf token mismatched', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
await app
.POST('/api/auth/sign-in')
.send({ email: u1.email, password: u1.password })
.expect(200);
await app
.POST('/api/auth/sign-out')
.set('x-affine-csrf-token', 'invalid')
.expect(HttpStatus.FORBIDDEN);
const session = await currentUser(app);
t.is(session?.id, u1.id);
});
test('should sign in desktop app via one-time open-app code', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
await app
.POST('/api/auth/sign-in')
.send({ email: u1.email, password: u1.password })
.expect(200);
const codeRes = await app.POST('/api/auth/open-app/sign-in-code').expect(201);
const code = codeRes.body.code as string;
t.truthy(code);
const exchangeRes = await supertest(app.getHttpServer())
.post('/api/auth/open-app/sign-in')
.send({ code })
.expect(201);
const exchangedCookies = exchangeRes.get('Set-Cookie') ?? [];
t.true(
exchangedCookies.some(c =>
c.startsWith(`${AuthService.sessionCookieName}=`)
)
);
const cookieHeader = exchangedCookies.map(c => c.split(';')[0]).join('; ');
const sessionRes = await supertest(app.getHttpServer())
.get('/api/auth/session')
.set('Cookie', cookieHeader)
.expect(200);
t.is(sessionRes.body.user?.id, u1.id);
// one-time use
await supertest(app.getHttpServer())
.post('/api/auth/open-app/sign-in')
.send({ code })
.expect(400)
.expect({
status: 400,
code: 'Bad Request',
type: 'BAD_REQUEST',
name: 'INVALID_AUTH_STATE',
message:
'Invalid auth state. You might start the auth progress from another device.',
});
});
test('should be able to correct user id cookie', async t => {
const { app } = t.context;
@@ -228,7 +328,7 @@ test('should be able to sign out multiple accounts in one session', async t => {
const u2 = await app.signupV1('u2@affine.pro');
// sign out u2
await app.GET(`/api/auth/sign-out?user_id=${u2.id}`).expect(200);
await app.POST(`/api/auth/sign-out?user_id=${u2.id}`).expect(200);
// list [u1]
let session = await app.GET('/api/auth/session').expect(200);
@@ -241,7 +341,7 @@ test('should be able to sign out multiple accounts in one session', async t => {
.expect(200);
// sign out all account in session
await app.GET('/api/auth/sign-out').expect(200);
await app.POST('/api/auth/sign-out').expect(200);
session = await app.GET('/api/auth/session').expect(200);
t.falsy(session.body.user);
@@ -337,3 +437,56 @@ test('should not be able to sign in if token is invalid', async t => {
t.is(res.body.message, 'An invalid email token provided.');
});
test('should not allow magic link OTP replay', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
await app.POST('/api/auth/sign-in').send({ email: u1.email }).expect(200);
const signInMail = app.mails.last('SignIn');
const url = new URL(signInMail.props.url);
const email = url.searchParams.get('email');
const token = url.searchParams.get('token');
await app.POST('/api/auth/magic-link').send({ email, token }).expect(201);
await app
.POST('/api/auth/magic-link')
.send({ email, token })
.expect(400)
.expect({
status: 400,
code: 'Bad Request',
type: 'INVALID_INPUT',
name: 'INVALID_EMAIL_TOKEN',
message: 'An invalid email token provided.',
});
t.pass();
});
test('should lock magic link OTP after too many attempts', async t => {
const { app } = t.context;
const u1 = await app.createUser('u1@affine.pro');
await app.POST('/api/auth/sign-in').send({ email: u1.email }).expect(200);
const signInMail = app.mails.last('SignIn');
const url = new URL(signInMail.props.url);
const email = url.searchParams.get('email');
const token = url.searchParams.get('token') as string;
const wrongOtp = token === '000000' ? '000001' : '000000';
for (let i = 0; i < 10; i++) {
await app
.POST('/api/auth/magic-link')
.send({ email, token: wrongOtp })
.expect(400);
}
await app.POST('/api/auth/magic-link').send({ email, token }).expect(400);
const session = await currentUser(app);
t.falsy(session);
});

View File

@@ -1,3 +1,5 @@
import { randomUUID } from 'node:crypto';
import { TestingModule } from '@nestjs/testing';
import test from 'ava';
@@ -7,6 +9,8 @@ import { createTestingModule } from './utils';
let cache: Cache;
let module: TestingModule;
const keyPrefix = `test:${randomUUID()}:`;
const key = (name: string) => `${keyPrefix}${name}`;
test.before(async () => {
module = await createTestingModule({
imports: FunctionalityModules,
@@ -19,78 +23,78 @@ test.after.always(async () => {
});
test('should be able to set normal cache', async t => {
t.true(await cache.set('test', 1));
t.is(await cache.get<number>('test'), 1);
t.true(await cache.set(key('test'), 1));
t.is(await cache.get<number>(key('test')), 1);
t.true(await cache.has('test'));
t.true(await cache.delete('test'));
t.is(await cache.get('test'), undefined);
t.true(await cache.has(key('test')));
t.true(await cache.delete(key('test')));
t.is(await cache.get(key('test')), undefined);
t.true(await cache.set('test', { a: 1 }));
t.deepEqual(await cache.get('test'), { a: 1 });
t.true(await cache.set(key('test'), { a: 1 }));
t.deepEqual(await cache.get(key('test')), { a: 1 });
});
test('should be able to set cache with non-exiting flag', async t => {
t.true(await cache.setnx('test-nx', 1));
t.false(await cache.setnx('test-nx', 2));
t.is(await cache.get('test-nx'), 1);
t.true(await cache.setnx(key('test-nx'), 1));
t.false(await cache.setnx(key('test-nx'), 2));
t.is(await cache.get(key('test-nx')), 1);
});
test('should be able to set cache with ttl', async t => {
t.true(await cache.set('test-ttl', 1));
t.is(await cache.get('test-ttl'), 1);
t.true(await cache.set(key('test-ttl'), 1));
t.is(await cache.get(key('test-ttl')), 1);
t.true(await cache.expire('test-ttl', 1 * 1000));
const ttl = await cache.ttl('test-ttl');
t.true(await cache.expire(key('test-ttl'), 1 * 1000));
const ttl = await cache.ttl(key('test-ttl'));
t.true(ttl <= 1 * 1000);
t.true(ttl > 0);
});
test('should be able to incr/decr number cache', async t => {
t.true(await cache.set('test-incr', 1));
t.is(await cache.increase('test-incr'), 2);
t.is(await cache.increase('test-incr'), 3);
t.is(await cache.decrease('test-incr'), 2);
t.is(await cache.decrease('test-incr'), 1);
t.true(await cache.set(key('test-incr'), 1));
t.is(await cache.increase(key('test-incr')), 2);
t.is(await cache.increase(key('test-incr')), 3);
t.is(await cache.decrease(key('test-incr')), 2);
t.is(await cache.decrease(key('test-incr')), 1);
// increase an nonexists number
t.is(await cache.increase('test-incr2'), 1);
t.is(await cache.increase('test-incr2'), 2);
t.is(await cache.increase(key('test-incr2')), 1);
t.is(await cache.increase(key('test-incr2')), 2);
});
test('should be able to manipulate list cache', async t => {
t.is(await cache.pushBack('test-list', 1), 1);
t.is(await cache.pushBack('test-list', 2, 3, 4), 4);
t.is(await cache.len('test-list'), 4);
t.is(await cache.pushBack(key('test-list'), 1), 1);
t.is(await cache.pushBack(key('test-list'), 2, 3, 4), 4);
t.is(await cache.len(key('test-list')), 4);
t.deepEqual(await cache.list('test-list', 1, -1), [2, 3, 4]);
t.deepEqual(await cache.list(key('test-list'), 1, -1), [2, 3, 4]);
t.deepEqual(await cache.popFront('test-list', 2), [1, 2]);
t.deepEqual(await cache.popBack('test-list', 1), [4]);
t.deepEqual(await cache.popFront(key('test-list'), 2), [1, 2]);
t.deepEqual(await cache.popBack(key('test-list'), 1), [4]);
t.is(await cache.pushBack('test-list2', { a: 1 }), 1);
t.deepEqual(await cache.popFront('test-list2', 1), [{ a: 1 }]);
t.is(await cache.pushBack(key('test-list2'), { a: 1 }), 1);
t.deepEqual(await cache.popFront(key('test-list2'), 1), [{ a: 1 }]);
});
test('should be able to manipulate map cache', async t => {
t.is(await cache.mapSet('test-map', 'a', 1), true);
t.is(await cache.mapSet('test-map', 'b', 2), true);
t.is(await cache.mapLen('test-map'), 2);
t.is(await cache.mapSet(key('test-map'), 'a', 1), true);
t.is(await cache.mapSet(key('test-map'), 'b', 2), true);
t.is(await cache.mapLen(key('test-map')), 2);
t.is(await cache.mapGet('test-map', 'a'), 1);
t.is(await cache.mapGet('test-map', 'b'), 2);
t.is(await cache.mapGet(key('test-map'), 'a'), 1);
t.is(await cache.mapGet(key('test-map'), 'b'), 2);
t.is(await cache.mapIncrease('test-map', 'a'), 2);
t.is(await cache.mapIncrease('test-map', 'a'), 3);
t.is(await cache.mapDecrease('test-map', 'b', 3), -1);
t.is(await cache.mapIncrease(key('test-map'), 'a'), 2);
t.is(await cache.mapIncrease(key('test-map'), 'a'), 3);
t.is(await cache.mapDecrease(key('test-map'), 'b', 3), -1);
const keys = await cache.mapKeys('test-map');
const keys = await cache.mapKeys(key('test-map'));
t.deepEqual(keys, ['a', 'b']);
const randomKey = await cache.mapRandomKey('test-map');
const randomKey = await cache.mapRandomKey(key('test-map'));
t.truthy(randomKey);
t.true(keys.includes(randomKey!));
t.is(await cache.mapDelete('test-map', 'a'), true);
t.is(await cache.mapGet('test-map', 'a'), undefined);
t.is(await cache.mapDelete(key('test-map'), 'a'), true);
t.is(await cache.mapGet(key('test-map'), 'a'), undefined);
});

View File

@@ -922,7 +922,6 @@ test('should be able to manage context', async t => {
const { id: fileId } = await addContextFile(
app,
contextId,
'fileId1',
'sample.pdf',
buffer
);

View File

@@ -41,6 +41,7 @@ interface TestingAppMetadata {
export class TestingApp extends NestApplication {
private sessionCookie: string | null = null;
private currentUserCookie: string | null = null;
private csrfCookie: string | null = null;
private readonly userCookies: Set<string> = new Set();
create = createFactory(this.get(PrismaClient, { strict: false }));
@@ -65,12 +66,23 @@ export class TestingApp extends NestApplication {
method: 'options' | 'get' | 'post' | 'put' | 'delete' | 'patch',
path: string
): supertest.Test {
return supertest(this.getHttpServer())
const cookies = [
`${AuthService.sessionCookieName}=${this.sessionCookie ?? ''}`,
`${AuthService.userCookieName}=${this.currentUserCookie ?? ''}`,
];
if (this.csrfCookie) {
cookies.push(`${AuthService.csrfCookieName}=${this.csrfCookie}`);
}
const req = supertest(this.getHttpServer())
[method](path)
.set('Cookie', [
`${AuthService.sessionCookieName}=${this.sessionCookie ?? ''}`,
`${AuthService.userCookieName}=${this.currentUserCookie ?? ''}`,
]);
.set('Cookie', cookies);
if (this.csrfCookie) {
req.set('x-affine-csrf-token', this.csrfCookie);
}
return req;
}
gql = gqlFetcherFactory('', async (_input, init) => {
@@ -123,6 +135,9 @@ export class TestingApp extends NestApplication {
this.sessionCookie = cookies[AuthService.sessionCookieName];
this.currentUserCookie = cookies[AuthService.userCookieName];
if (AuthService.csrfCookieName in cookies) {
this.csrfCookie = cookies[AuthService.csrfCookieName] || null;
}
if (this.currentUserCookie) {
this.userCookies.add(this.currentUserCookie);
}
@@ -180,13 +195,17 @@ export class TestingApp extends NestApplication {
}
async logout(userId?: string) {
const res = await this.GET(
const res = await this.POST(
'/api/auth/sign-out' + (userId ? `?user_id=${userId}` : '')
).expect(200);
const cookies = parseCookies(res);
this.sessionCookie = cookies[AuthService.sessionCookieName];
if (AuthService.csrfCookieName in cookies) {
this.csrfCookie = cookies[AuthService.csrfCookieName] || null;
}
if (!this.sessionCookie) {
this.currentUserCookie = null;
this.csrfCookie = null;
this.userCookies.clear();
} else {
this.currentUserCookie = cookies[AuthService.userCookieName];

View File

@@ -16,9 +16,13 @@ e2e('should get doc markdown success', async t => {
user: owner,
});
const path = `/rpc/workspaces/${workspace.id}/docs/${docSnapshot.id}/markdown`;
const res = await app
.GET(`/rpc/workspaces/${workspace.id}/docs/${docSnapshot.id}/markdown`)
.set('x-access-token', crypto.sign(docSnapshot.id))
.GET(path)
.set(
'x-access-token',
crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect(200)
.expect('Content-Type', 'application/json; charset=utf-8');
@@ -32,9 +36,13 @@ e2e('should get doc markdown return null when doc not exists', async t => {
});
const docId = randomUUID();
const path = `/rpc/workspaces/${workspace.id}/docs/${docId}/markdown`;
const res = await app
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}/markdown`)
.set('x-access-token', crypto.sign(docId))
.GET(path)
.set(
'x-access-token',
crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect(404)
.expect('Content-Type', 'application/json; charset=utf-8');

View File

@@ -39,31 +39,7 @@ Generated by [AVA](https://avajs.dev).
},
}
## should not return apple oauth provider when client version is not specified
> Snapshot 1
{
serverConfig: {
oauthProviders: [
'Google',
],
},
}
## should not return apple oauth provider in version < 0.22.0
> Snapshot 1
{
serverConfig: {
oauthProviders: [
'Google',
],
},
}
## should not return apple oauth provider when client version format is not correct
## should return apple oauth provider when client version is not specified
> Snapshot 1
@@ -71,6 +47,7 @@ Generated by [AVA](https://avajs.dev).
serverConfig: {
oauthProviders: [
'Google',
'Apple',
],
},
}

View File

@@ -71,7 +71,7 @@ e2e('should return apple oauth provider in version >= 0.22.0', async t => {
});
e2e(
'should not return apple oauth provider when client version is not specified',
'should return apple oauth provider when client version is not specified',
async t => {
const res = await app.gql({
query: oauthProvidersQuery,
@@ -80,32 +80,3 @@ e2e(
t.snapshot(res);
}
);
e2e('should not return apple oauth provider in version < 0.22.0', async t => {
const res = await app.gql({
query: oauthProvidersQuery,
context: {
headers: {
'x-affine-version': '0.21.0',
},
},
});
t.snapshot(res);
});
e2e(
'should not return apple oauth provider when client version format is not correct',
async t => {
const res = await app.gql({
query: oauthProvidersQuery,
context: {
headers: {
'x-affine-version': 'mock-invalid-version',
},
},
});
t.snapshot(res);
}
);

View File

@@ -228,11 +228,13 @@ async function getBlobUploadPartUrl(
) {
const data = await gql(
`
mutation getBlobUploadPartUrl($workspaceId: String!, $key: String!, $uploadId: String!, $partNumber: Int!) {
getBlobUploadPartUrl(workspaceId: $workspaceId, key: $key, uploadId: $uploadId, partNumber: $partNumber) {
uploadUrl
headers
expiresAt
query getBlobUploadPartUrl($workspaceId: String!, $key: String!, $uploadId: String!, $partNumber: Int!) {
workspace(id: $workspaceId) {
blobUploadPartUrl(key: $key, uploadId: $uploadId, partNumber: $partNumber) {
uploadUrl
headers
expiresAt
}
}
}
`,
@@ -240,7 +242,7 @@ async function getBlobUploadPartUrl(
'getBlobUploadPartUrl'
);
return data.getBlobUploadPartUrl;
return data.workspace.blobUploadPartUrl;
}
async function setupWorkspace() {

View File

@@ -0,0 +1,89 @@
import { getUserQuery } from '@affine/graphql';
import Sinon from 'sinon';
import { ThrottlerStorage } from '../../../base/throttler';
import { app, e2e, Mockers } from '../test';
e2e('user(email) should return null without auth', async t => {
const user = await app.create(Mockers.User);
await app.logout();
const res = await app.gql({
query: getUserQuery,
variables: { email: user.email },
});
t.is(res.user, null);
});
e2e('user(email) should return null outside workspace scope', async t => {
await app.logout();
const me = await app.signup();
const other = await app.create(Mockers.User);
const res = await app.gql({
query: getUserQuery,
variables: { email: other.email },
});
t.is(res.user, null);
// sanity: querying self is always allowed
const self = await app.gql({
query: getUserQuery,
variables: { email: me.email },
});
t.truthy(self.user);
if (!self.user) return;
t.is(self.user.__typename, 'UserType');
if (self.user.__typename === 'UserType') {
t.is(self.user.id, me.id);
}
});
e2e('user(email) should return user within workspace scope', async t => {
await app.logout();
const me = await app.signup();
const other = await app.create(Mockers.User);
const ws = await app.create(Mockers.Workspace, { owner: me });
await app.create(Mockers.WorkspaceUser, {
workspaceId: ws.id,
userId: other.id,
});
const res = await app.gql({
query: getUserQuery,
variables: { email: other.email },
});
t.truthy(res.user);
if (!res.user) return;
t.is(res.user.__typename, 'UserType');
if (res.user.__typename === 'UserType') {
t.is(res.user.id, other.id);
}
});
e2e('user(email) should be rate limited', async t => {
await app.logout();
const me = await app.signup();
const stub = Sinon.stub(app.get(ThrottlerStorage), 'increment').resolves({
timeToExpire: 10,
totalHits: 21,
isBlocked: true,
timeToBlockExpire: 10,
});
await t.throwsAsync(
app.gql({
query: getUserQuery,
variables: { email: me.email },
}),
{ message: /too many requests/i }
);
stub.restore();
});

View File

@@ -17,17 +17,3 @@ Generated by [AVA](https://avajs.dev).
name: 'Free',
storageQuota: 10737418240,
}
## should get feature if extra fields exist in feature config
> Snapshot 1
{
blobLimit: 10485760,
businessBlobLimit: 104857600,
copilotActionLimit: 10,
historyPeriod: 604800000,
memberLimit: 3,
name: 'Free',
storageQuota: 10737418240,
}

View File

@@ -68,7 +68,7 @@ test("should be able to redirect to oauth provider's login page", async t => {
const res = await app
.POST('/api/oauth/preflight')
.send({ provider: 'Google' })
.send({ provider: 'Google', client_nonce: 'test-nonce' })
.expect(HttpStatus.OK);
const { url } = res.body;
@@ -100,7 +100,7 @@ test('should be able to redirect to oauth provider with multiple hosts', async t
const res = await app
.POST('/api/oauth/preflight')
.set('host', 'test.affine.dev')
.send({ provider: 'Google' })
.send({ provider: 'Google', client_nonce: 'test-nonce' })
.expect(HttpStatus.OK);
const { url } = res.body;
@@ -156,12 +156,45 @@ test('should be able to redirect to oauth provider with client_nonce', async t =
t.truthy(state.state);
});
test('should forbid preflight with untrusted redirect_uri', async t => {
const { app } = t.context;
await app
.POST('/api/oauth/preflight')
.send({
provider: 'Google',
redirect_uri: 'https://evil.example',
client_nonce: 'test-nonce',
})
.expect(HttpStatus.FORBIDDEN);
t.pass();
});
test('should throw if client_nonce is missing in preflight', async t => {
const { app } = t.context;
await app
.POST('/api/oauth/preflight')
.send({ provider: 'Google' })
.expect(HttpStatus.BAD_REQUEST)
.expect({
status: 400,
code: 'Bad Request',
type: 'BAD_REQUEST',
name: 'MISSING_OAUTH_QUERY_PARAMETER',
message: 'Missing query parameter `client_nonce`.',
data: { name: 'client_nonce' },
});
t.pass();
});
test('should throw if provider is invalid', async t => {
const { app } = t.context;
await app
.POST('/api/oauth/preflight')
.send({ provider: 'Invalid' })
.send({ provider: 'Invalid', client_nonce: 'test-nonce' })
.expect(HttpStatus.BAD_REQUEST)
.expect({
status: 400,
@@ -320,7 +353,7 @@ test('should throw if provider is invalid in callback uri', async t => {
function mockOAuthProvider(
app: TestingApp,
email: string,
clientNonce?: string
clientNonce: string = randomUUID()
) {
const provider = app.get(GoogleOAuthProvider);
const oauth = app.get(OAuthService);
@@ -337,16 +370,18 @@ function mockOAuthProvider(
email,
avatarUrl: 'avatar',
});
return clientNonce;
}
test('should be able to sign up with oauth', async t => {
const { app, db } = t.context;
mockOAuthProvider(app, 'u2@affine.pro');
const clientNonce = mockOAuthProvider(app, 'u2@affine.pro');
await app
.POST('/api/oauth/callback')
.send({ code: '1', state: '1' })
.send({ code: '1', state: '1', client_nonce: clientNonce })
.expect(HttpStatus.OK);
const sessionUser = await currentUser(app);
@@ -427,11 +462,11 @@ test('should throw if client_nonce is invalid', async t => {
test('should not throw if account registered', async t => {
const { app, u1 } = t.context;
mockOAuthProvider(app, u1.email);
const clientNonce = mockOAuthProvider(app, u1.email);
const res = await app
.POST('/api/oauth/callback')
.send({ code: '1', state: '1' })
.send({ code: '1', state: '1', client_nonce: clientNonce })
.expect(HttpStatus.OK);
t.is(res.body.id, u1.id);
@@ -442,9 +477,11 @@ test('should be able to fullfil user with oauth sign in', async t => {
const u3 = await app.createUser('u3@affine.pro');
mockOAuthProvider(app, u3.email);
const clientNonce = mockOAuthProvider(app, u3.email);
await app.POST('/api/oauth/callback').send({ code: '1', state: '1' });
await app
.POST('/api/oauth/callback')
.send({ code: '1', state: '1', client_nonce: clientNonce });
const sessionUser = await currentUser(app);

View File

@@ -1,5 +1,339 @@
import test from 'ava';
import test, { type ExecutionContext } from 'ava';
import { io, type Socket as SocketIOClient } from 'socket.io-client';
import { Doc, encodeStateAsUpdate } from 'yjs';
test('should test through sync gateway', t => {
t.pass();
import { createTestingApp, TestingApp } from '../utils';
// Ack payload shape these tests expect from the gateway: either an error
// envelope or a data envelope (see unwrapResponse, which discriminates on `data`).
type WebsocketResponse<T> =
| { error: { name: string; message: string } }
| { data: T };
// Timeout (ms) that the wait/ack helpers below pass to withTimeout.
const WS_TIMEOUT_MS = 5_000;
/**
 * Extracts `data` from a websocket ack, or logs the raw ack to the test
 * context and throws when the ack carries an error instead.
 */
function unwrapResponse<T>(t: ExecutionContext, res: WebsocketResponse<T>): T {
  if (!('data' in res)) {
    t.log(res);
    throw new Error(`Websocket error: ${res.error.name}: ${res.error.message}`);
  }
  return res.data;
}
/**
 * Races `promise` against a deadline of `timeoutMs`.
 *
 * Settles like `promise` if it finishes first; otherwise rejects with
 * `Timeout (<timeoutMs>ms): <label>`. The deadline timer is cleared once the
 * race settles, so it never keeps the process alive.
 */
async function withTimeout<T>(
  promise: Promise<T>,
  timeoutMs: number,
  label: string
) {
  let handle: ReturnType<typeof setTimeout> | undefined;
  const deadline = new Promise<never>((_resolve, rejectDeadline) => {
    handle = setTimeout(
      () => rejectDeadline(new Error(`Timeout (${timeoutMs}ms): ${label}`)),
      timeoutMs
    );
  });
  return Promise.race([promise, deadline]).finally(() => {
    if (handle !== undefined) {
      clearTimeout(handle);
    }
  });
}
/**
 * Opens a websocket-only socket.io client that authenticates with the given
 * cookie header. Reconnection is disabled and `forceNew` is set so every call
 * gets its own connection rather than a shared one.
 */
function createClient(url: string, cookie: string): SocketIOClient {
  const extraHeaders = { cookie };
  return io(url, {
    forceNew: true,
    reconnection: false,
    transports: ['websocket'],
    extraHeaders,
  });
}
/**
 * Resolves once `socket` is connected (immediately if it already is).
 *
 * Rejects with the `connect_error` payload, or with a timeout error after
 * `WS_TIMEOUT_MS`. Both listeners are detached on every outcome, including
 * the timeout. Otherwise a stale `connect`/`connect_error` handler would stay
 * attached to the socket for its whole lifetime.
 */
function waitForConnect(socket: SocketIOClient) {
  if (socket.connected) {
    return Promise.resolve();
  }
  let cleanup = () => {};
  const connected = new Promise<void>((resolve, reject) => {
    const onConnect = () => resolve();
    const onError = (err: Error) => reject(err);
    socket.once('connect', onConnect);
    socket.once('connect_error', onError);
    cleanup = () => {
      socket.off('connect', onConnect);
      socket.off('connect_error', onError);
    };
  });
  return withTimeout(connected, WS_TIMEOUT_MS, 'socket connect').finally(() =>
    cleanup()
  );
}
/**
 * Resolves once `socket` is disconnected (immediately if it already is), or
 * rejects with a timeout error after `WS_TIMEOUT_MS`.
 *
 * The `disconnect` listener is detached on every outcome, so a timed-out
 * wait does not leave a dangling handler behind.
 */
function waitForDisconnect(socket: SocketIOClient) {
  if (socket.disconnected) {
    return Promise.resolve();
  }
  let cleanup = () => {};
  const disconnected = new Promise<void>(resolve => {
    const onDisconnect = () => resolve();
    socket.once('disconnect', onDisconnect);
    cleanup = () => {
      socket.off('disconnect', onDisconnect);
    };
  });
  return withTimeout(disconnected, WS_TIMEOUT_MS, 'socket disconnect').finally(
    () => cleanup()
  );
}
/**
 * Emits `event` with `data` and resolves with the server's ack payload.
 * Rejects with a timeout error if no ack arrives within `WS_TIMEOUT_MS`.
 */
function emitWithAck<T>(socket: SocketIOClient, event: string, data: unknown) {
  const acknowledged = new Promise<WebsocketResponse<T>>(resolve => {
    const onAck = (res: WebsocketResponse<T>) => resolve(res);
    socket.emit(event, data, onAck);
  });
  return withTimeout(acknowledged, WS_TIMEOUT_MS, `ack ${event}`);
}
/**
 * Resolves with the payload of the next `event` emitted on `socket`, or
 * rejects with a timeout error after `WS_TIMEOUT_MS`.
 *
 * The one-shot listener is removed on every outcome. Previously a timed-out
 * waiter stayed subscribed and would silently swallow a later event.
 */
function waitForEvent<T>(socket: SocketIOClient, event: string) {
  let listener: ((payload: T) => void) | undefined;
  const received = new Promise<T>(resolve => {
    listener = (payload: T) => resolve(payload);
    socket.once(event, listener);
  });
  return withTimeout(received, WS_TIMEOUT_MS, `event ${event}`).finally(() => {
    if (listener) {
      socket.off(event, listener);
    }
  });
}
/**
 * Resolves after `durationMs` if `event` is NOT received on `socket` during
 * that window; rejects with `Unexpected event received: <event>` as soon as
 * it arrives. The listener (and the window timer) are cleaned up on both
 * outcomes.
 *
 * The inner timer always settles this promise within `durationMs`, so no
 * extra `withTimeout(WS_TIMEOUT_MS)` wrapper is used. Such a wrapper made any
 * `durationMs >= WS_TIMEOUT_MS` fail with a spurious timeout even though
 * "no event arrived" is the success outcome.
 */
function expectNoEvent(
  socket: SocketIOClient,
  event: string,
  durationMs = 200
) {
  return new Promise<void>((resolve, reject) => {
    const onEvent = () => {
      clearTimeout(timer);
      socket.off(event, onEvent);
      reject(new Error(`Unexpected event received: ${event}`));
    };
    // Registered after `timer` is initialised, so onEvent never sees it in TDZ.
    const timer = setTimeout(() => {
      socket.off(event, onEvent);
      resolve();
    }, durationMs);
    socket.on(event, onEvent);
  });
}
/**
 * Creates `u1@affine.pro`, signs in with its password and returns the user
 * together with a `Cookie` header value built from the response's
 * `Set-Cookie` entries (name=value parts only, joined with "; ").
 */
async function login(app: TestingApp) {
  const user = await app.createUser('u1@affine.pro');
  const { email, password } = user;
  const res = await app
    .POST('/api/auth/sign-in')
    .send({ email, password })
    .expect(200);
  const setCookie = res.get('Set-Cookie') ?? [];
  const cookieHeader = setCookie
    .map(entry => entry.split(';')[0])
    .join('; ');
  return { user, cookieHeader };
}
/**
 * Produces a base64-encoded Yjs state update for a fresh doc whose map `m`
 * holds `k -> 'v'`. Used as a non-empty payload for push/broadcast tests.
 */
function createYjsUpdateBase64() {
  const ydoc = new Doc();
  const map = ydoc.getMap('m');
  map.set('k', 'v');
  return Buffer.from(encodeStateAsUpdate(ydoc)).toString('base64');
}
// One shared testing app (and its base URL) for every test in this file.
let app: TestingApp;
let url: string;

test.before(async () => {
  app = await createTestingApp();
  url = app.url();
});

// Reset the database before each test so every test starts from a clean slate.
test.beforeEach(async () => {
  await app.initTestingDB();
});

test.after.always(async () => {
  // `app` is unset when `test.before` failed; skip close() rather than adding
  // a secondary TypeError on top of the real setup failure.
  if (app) {
    await app.close();
  }
});
// A receiver that joins with clientVersion 0.25.0 must get the per-update
// `space:broadcast-doc-update` event and never the batched
// `space:broadcast-doc-updates` event, even though the sender joined as 0.26.0.
test('clientVersion=0.25.0 should only receive space:broadcast-doc-update', async t => {
const { user, cookieHeader } = await login(app);
// The signed-in user's id is used as the userspace id.
const spaceId = user.id;
const update = createYjsUpdateBase64();
// Both sockets authenticate as the same user via the session cookie.
const sender = createClient(url, cookieHeader);
const receiver = createClient(url, cookieHeader);
try {
await Promise.all([waitForConnect(sender), waitForConnect(receiver)]);
// Receiver is the client version under test (0.25.0).
const receiverJoin = unwrapResponse(
t,
await emitWithAck<{ clientId: string; success: boolean }>(
receiver,
'space:join',
{ spaceType: 'userspace', spaceId, clientVersion: '0.25.0' }
)
);
t.true(receiverJoin.success);
const senderJoin = unwrapResponse(
t,
await emitWithAck<{ clientId: string; success: boolean }>(
sender,
'space:join',
{ spaceType: 'userspace', spaceId, clientVersion: '0.26.0' }
)
);
t.true(senderJoin.success);
// Subscribe BEFORE pushing so the broadcast cannot be missed.
// NOTE(review): these two promises are only awaited after the push. If an
// intermediate step throws, `onUpdate` later rejects on its WS_TIMEOUT_MS
// timer with no handler attached (likely surfacing as an unhandled rejection).
// Consider attaching handlers eagerly.
const onUpdate = waitForEvent<{
spaceType: string;
spaceId: string;
docId: string;
update: string;
}>(receiver, 'space:broadcast-doc-update');
const noUpdates = expectNoEvent(receiver, 'space:broadcast-doc-updates');
const pushRes = await emitWithAck<{ accepted: true; timestamp?: number }>(
sender,
'space:push-doc-update',
{
spaceType: 'userspace',
spaceId,
docId: 'doc-1',
update,
}
);
// Throws (failing the test) if the push was not acknowledged with data.
unwrapResponse(t, pushRes);
const message = await onUpdate;
t.is(message.spaceType, 'userspace');
t.is(message.spaceId, spaceId);
t.is(message.docId, 'doc-1');
t.is(message.update, update);
// The batched event must not have arrived within expectNoEvent's default 200ms window.
await noUpdates;
} finally {
// Always tear down both sockets, even when an assertion failed.
sender.disconnect();
receiver.disconnect();
}
});
// A receiver that joins with clientVersion 0.26.0 must get the batched
// `space:broadcast-doc-updates` event (an `updates` array) and never the
// single `space:broadcast-doc-update` event, even though the sender joined as 0.25.0.
test('clientVersion>=0.26.0 should only receive space:broadcast-doc-updates', async t => {
const { user, cookieHeader } = await login(app);
// The signed-in user's id is used as the userspace id.
const spaceId = user.id;
const update = createYjsUpdateBase64();
// Both sockets authenticate as the same user via the session cookie.
const sender = createClient(url, cookieHeader);
const receiver = createClient(url, cookieHeader);
try {
await Promise.all([waitForConnect(sender), waitForConnect(receiver)]);
// Receiver is the client version under test (0.26.0).
const receiverJoin = unwrapResponse(
t,
await emitWithAck<{ clientId: string; success: boolean }>(
receiver,
'space:join',
{ spaceType: 'userspace', spaceId, clientVersion: '0.26.0' }
)
);
t.true(receiverJoin.success);
const senderJoin = unwrapResponse(
t,
await emitWithAck<{ clientId: string; success: boolean }>(
sender,
'space:join',
{ spaceType: 'userspace', spaceId, clientVersion: '0.25.0' }
)
);
t.true(senderJoin.success);
// Subscribe BEFORE pushing so the broadcast cannot be missed.
// NOTE(review): these two promises are only awaited after the push. If an
// intermediate step throws, `onUpdates` later rejects on its WS_TIMEOUT_MS
// timer with no handler attached (likely surfacing as an unhandled rejection).
const onUpdates = waitForEvent<{
spaceType: string;
spaceId: string;
docId: string;
updates: string[];
}>(receiver, 'space:broadcast-doc-updates');
const noUpdate = expectNoEvent(receiver, 'space:broadcast-doc-update');
const pushRes = await emitWithAck<{ accepted: true; timestamp?: number }>(
sender,
'space:push-doc-update',
{
spaceType: 'userspace',
spaceId,
docId: 'doc-2',
update,
}
);
// Throws (failing the test) if the push was not acknowledged with data.
unwrapResponse(t, pushRes);
const message = await onUpdates;
t.is(message.spaceType, 'userspace');
t.is(message.spaceId, spaceId);
t.is(message.docId, 'doc-2');
// A single push is expected to arrive as a one-element batch.
t.deepEqual(message.updates, [update]);
// The single-update event must not have arrived within expectNoEvent's default 200ms window.
await noUpdate;
} finally {
// Always tear down both sockets, even when an assertion failed.
sender.disconnect();
receiver.disconnect();
}
});
// Joining a space with a client older than 0.25.0 must be refused
// (success: false), after which the server drops the socket.
test('clientVersion<0.25.0 should be rejected and disconnected', async t => {
  const { user, cookieHeader } = await login(app);
  const socket = createClient(url, cookieHeader);
  try {
    await waitForConnect(socket);
    const joinAck = await emitWithAck<{ clientId: string; success: boolean }>(
      socket,
      'space:join',
      { spaceType: 'userspace', spaceId: user.id, clientVersion: '0.24.4' }
    );
    const { success } = unwrapResponse(t, joinAck);
    t.false(success);
    await waitForDisconnect(socket);
  } finally {
    socket.disconnect();
  }
});
// The awareness channel applies the same minimum-version gate as space:join:
// an old client is refused and then disconnected by the server.
test('space:join-awareness should reject clientVersion<0.25.0', async t => {
  const { user, cookieHeader } = await login(app);
  const socket = createClient(url, cookieHeader);
  try {
    await waitForConnect(socket);
    const joinPayload = {
      spaceType: 'userspace',
      spaceId: user.id,
      docId: 'doc-awareness',
      clientVersion: '0.24.4',
    };
    const joinAck = await emitWithAck<{ clientId: string; success: boolean }>(
      socket,
      'space:join-awareness',
      joinPayload
    );
    const { success } = unwrapResponse(t, joinAck);
    t.false(success);
    await waitForDisconnect(socket);
  } finally {
    socket.disconnect();
  }
});

View File

@@ -152,9 +152,13 @@ export async function getBlobUploadPartUrl(
) {
const res = await app.gql(
`
mutation getBlobUploadPartUrl($workspaceId: String!, $key: String!, $uploadId: String!, $partNumber: Int!) {
getBlobUploadPartUrl(workspaceId: $workspaceId, key: $key, uploadId: $uploadId, partNumber: $partNumber) {
uploadUrl
query getBlobUploadPartUrl($workspaceId: String!, $key: String!, $uploadId: String!, $partNumber: Int!) {
workspace(id: $workspaceId) {
blobUploadPartUrl(key: $key, uploadId: $uploadId, partNumber: $partNumber) {
uploadUrl
headers
expiresAt
}
}
}
`,
@@ -165,5 +169,5 @@ export async function getBlobUploadPartUrl(
partNumber,
}
);
return res.getBlobUploadPartUrl;
return res.workspace.blobUploadPartUrl;
}

View File

@@ -250,7 +250,6 @@ export async function listContext(
export async function addContextFile(
app: TestingApp,
contextId: string,
blobId: string,
fileName: string,
content: Buffer
): Promise<{ id: string }> {
@@ -269,7 +268,7 @@ export async function addContextFile(
`,
variables: {
content: null,
options: { contextId, blobId },
options: { contextId },
},
})
)

View File

@@ -139,11 +139,11 @@ export async function revokeUser(
): Promise<boolean> {
const res = await app.gql(`
mutation {
revoke(workspaceId: "${workspaceId}", userId: "${userId}")
revokeMember(workspaceId: "${workspaceId}", userId: "${userId}")
}
`);
return res.revoke;
return res.revokeMember;
}
export async function getInviteInfo(

View File

@@ -14,6 +14,7 @@ import {
GlobalExceptionFilter,
JobQueue,
} from '../../base';
import { SocketIoAdapter } from '../../base/websocket';
import { AuthService } from '../../core/auth';
import { Mailer } from '../../core/mail';
import { UserModel } from '../../models';
@@ -61,6 +62,7 @@ export async function createTestingApp(
);
app.use(cookieParser());
app.useWebSocketAdapter(new SocketIoAdapter(app));
if (moduleDef.tapApp) {
moduleDef.tapApp(app);
@@ -89,6 +91,7 @@ export function parseCookies(res: supertest.Response) {
export class TestingApp extends ApplyType<INestApplication>() {
private sessionCookie: string | null = null;
private currentUserCookie: string | null = null;
private csrfCookie: string | null = null;
private readonly userCookies: Set<string> = new Set();
readonly create!: ReturnType<typeof createFactory>;
@@ -103,6 +106,7 @@ export class TestingApp extends ApplyType<INestApplication>() {
await initTestingDB(this);
this.sessionCookie = null;
this.currentUserCookie = null;
this.csrfCookie = null;
this.userCookies.clear();
}
@@ -118,12 +122,23 @@ export class TestingApp extends ApplyType<INestApplication>() {
method: 'options' | 'get' | 'post' | 'put' | 'delete' | 'patch',
path: string
): supertest.Test {
return supertest(this.getHttpServer())
const cookies = [
`${AuthService.sessionCookieName}=${this.sessionCookie ?? ''}`,
`${AuthService.userCookieName}=${this.currentUserCookie ?? ''}`,
];
if (this.csrfCookie) {
cookies.push(`${AuthService.csrfCookieName}=${this.csrfCookie}`);
}
const req = supertest(this.getHttpServer())
[method](path)
.set('Cookie', [
`${AuthService.sessionCookieName}=${this.sessionCookie ?? ''}`,
`${AuthService.userCookieName}=${this.currentUserCookie ?? ''}`,
]);
.set('Cookie', cookies);
if (this.csrfCookie) {
req.set('x-affine-csrf-token', this.csrfCookie);
}
return req;
}
OPTIONS(path: string): supertest.Test {
@@ -147,6 +162,9 @@ export class TestingApp extends ApplyType<INestApplication>() {
this.sessionCookie = cookies[AuthService.sessionCookieName];
this.currentUserCookie = cookies[AuthService.userCookieName];
if (AuthService.csrfCookieName in cookies) {
this.csrfCookie = cookies[AuthService.csrfCookieName] || null;
}
if (this.currentUserCookie) {
this.userCookies.add(this.currentUserCookie);
}
@@ -270,13 +288,17 @@ export class TestingApp extends ApplyType<INestApplication>() {
}
async logout(userId?: string) {
const res = await this.GET(
const res = await this.POST(
'/api/auth/sign-out' + (userId ? `?user_id=${userId}` : '')
).expect(200);
const cookies = parseCookies(res);
this.sessionCookie = cookies[AuthService.sessionCookieName];
if (AuthService.csrfCookieName in cookies) {
this.csrfCookie = cookies[AuthService.csrfCookieName] || null;
}
if (!this.sessionCookie) {
this.currentUserCookie = null;
this.csrfCookie = null;
this.userCookies.clear();
} else {
this.currentUserCookie = cookies[AuthService.userCookieName];

View File

@@ -188,10 +188,10 @@ export async function revokeMember(
const res = await app.gql(
`
mutation {
revoke(workspaceId: "${workspaceId}", userId: "${userId}")
revokeMember(workspaceId: "${workspaceId}", userId: "${userId}")
}
`
);
return res.revoke;
return res.revokeMember;
}

View File

@@ -27,7 +27,7 @@ function checkVersion(enabled = true) {
client: {
versionControl: {
enabled,
requiredVersion: '>=0.20.0',
requiredVersion: '>=0.25.0',
},
},
});
@@ -88,23 +88,23 @@ test('should passthrough is version range is invalid', async t => {
});
test('should pass if client version is allowed', async t => {
let res = await app.GET('/guarded/test').set('x-affine-version', '0.20.0');
let res = await app.GET('/guarded/test').set('x-affine-version', '0.25.0');
t.is(res.status, 200);
res = await app.GET('/guarded/test').set('x-affine-version', '0.21.0');
res = await app.GET('/guarded/test').set('x-affine-version', '0.26.0');
t.is(res.status, 200);
config.override({
client: {
versionControl: {
requiredVersion: '>=0.19.0',
requiredVersion: '>=0.25.0',
},
},
});
res = await app.GET('/guarded/test').set('x-affine-version', '0.19.0');
res = await app.GET('/guarded/test').set('x-affine-version', '0.25.0');
t.is(res.status, 200);
});
@@ -115,7 +115,7 @@ test('should fail if client version is not set or invalid', async t => {
t.is(res.status, 403);
t.is(
res.body.message,
'Unsupported client with version [unset_or_invalid], required version is [>=0.20.0].'
'Unsupported client with version [unset_or_invalid], required version is [>=0.25.0].'
);
res = await app.GET('/guarded/test').set('x-affine-version', 'invalid');
@@ -123,7 +123,7 @@ test('should fail if client version is not set or invalid', async t => {
t.is(res.status, 403);
t.is(
res.body.message,
'Unsupported client with version [invalid], required version is [>=0.20.0].'
'Unsupported client with version [invalid], required version is [>=0.25.0].'
);
});
@@ -131,17 +131,17 @@ test('should tell upgrade if client version is lower than allowed', async t => {
config.override({
client: {
versionControl: {
requiredVersion: '>=0.21.0 <=0.22.0',
requiredVersion: '>=0.26.0 <=0.27.0',
},
},
});
let res = await app.GET('/guarded/test').set('x-affine-version', '0.20.0');
let res = await app.GET('/guarded/test').set('x-affine-version', '0.25.0');
t.is(res.status, 403);
t.is(
res.body.message,
'Unsupported client with version [0.20.0], required version is [>=0.21.0 <=0.22.0].'
'Unsupported client with version [0.25.0], required version is [>=0.26.0 <=0.27.0].'
);
});
@@ -149,17 +149,17 @@ test('should tell downgrade if client version is higher than allowed', async t =
config.override({
client: {
versionControl: {
requiredVersion: '>=0.20.0 <=0.22.0',
requiredVersion: '>=0.25.0 <=0.26.0',
},
},
});
let res = await app.GET('/guarded/test').set('x-affine-version', '0.23.0');
let res = await app.GET('/guarded/test').set('x-affine-version', '0.27.0');
t.is(res.status, 403);
t.is(
res.body.message,
'Unsupported client with version [0.23.0], required version is [>=0.20.0 <=0.22.0].'
'Unsupported client with version [0.27.0], required version is [>=0.25.0 <=0.26.0].'
);
});
@@ -167,25 +167,25 @@ test('should test prerelease version', async t => {
config.override({
client: {
versionControl: {
requiredVersion: '>=0.19.0',
requiredVersion: '>=0.25.0',
},
},
});
let res = await app
.GET('/guarded/test')
.set('x-affine-version', '0.19.0-canary.1');
.set('x-affine-version', '0.25.0-canary.1');
// 0.19.0-canary.1 is lower than 0.19.0 obviously
// 0.25.0-canary.1 is lower than 0.25.0 obviously
t.is(res.status, 403);
res = await app
.GET('/guarded/test')
.set('x-affine-version', '0.20.0-canary.1');
.set('x-affine-version', '0.26.0-canary.1');
t.is(res.status, 200);
res = await app.GET('/guarded/test').set('x-affine-version', '0.20.0-beta.2');
res = await app.GET('/guarded/test').set('x-affine-version', '0.26.0-beta.2');
t.is(res.status, 200);
});

View File

@@ -1,8 +1,14 @@
import type { ExecutionContext, TestFn } from 'ava';
import ava from 'ava';
import { LookupAddress } from 'dns';
import Sinon from 'sinon';
import type { Response } from 'supertest';
import {
__resetDnsLookupForTests,
__setDnsLookupForTests,
type DnsLookup,
} from '../base/utils/ssrf';
import { createTestingApp, TestingApp } from './utils';
type TestContext = {
@@ -11,15 +17,30 @@ type TestContext = {
const test = ava as TestFn<TestContext>;
// Deterministic stand-in for dns.lookup(): every hostname resolves to one
// fixed public IPv4 address, so SSRF checks run without real DNS.
const LookupAddressStub = (async (_hostname, options) => {
  // Built fresh on each call so callers never share a mutable result.
  const addresses = [{ address: '76.76.21.21', family: 4 }] as LookupAddress[];
  const wantsAll =
    typeof options === 'object' &&
    options !== null &&
    'all' in options &&
    Boolean(options.all);
  // Mirror dns.lookup(): `{ all: true }` yields the array, otherwise a single entry.
  return wantsAll ? addresses : addresses[0];
}) as DnsLookup;
// Boot one self-hosted testing app with DNS stubbed for the whole file.
test.before(async t => {
  // @ts-expect-error test
  env.DEPLOYMENT_TYPE = 'selfhosted';
  // Avoid relying on real DNS during tests. SSRF protection uses dns.lookup().
  __setDnsLookupForTests(LookupAddressStub);
  const app = await createTestingApp();
  t.context.app = app;
});

// Undo stubs and shut the app down, even if tests (or setup) failed.
test.after.always(async t => {
  Sinon.restore();
  __resetDnsLookupForTests();
  // `app` is unset when `test.before` failed; skip close() instead of
  // throwing a secondary TypeError over the real setup failure.
  await t.context.app?.close();
});
@@ -29,7 +50,8 @@ const assertAndSnapshotRaw = async (
message: string,
options?: {
status?: number;
origin?: string;
origin?: string | null;
referer?: string | null;
method?: 'GET' | 'OPTIONS' | 'POST';
body?: any;
checker?: (res: Response) => any;
@@ -37,16 +59,21 @@ const assertAndSnapshotRaw = async (
) => {
const {
status = 200,
origin = 'http://localhost',
origin = 'http://localhost:3010',
referer,
method = 'GET',
checker = () => {},
} = options || {};
const { app } = t.context;
const res = app[method](route)
.set('Origin', origin)
.send(options?.body)
.expect(status)
.expect(checker);
const req = app[method](route);
if (origin) {
req.set('Origin', origin);
}
if (referer) {
req.set('Referer', referer);
}
const res = req.send(options?.body).expect(status).expect(checker);
await t.notThrowsAsync(res, message);
t.snapshot((await res).body);
};
@@ -76,6 +103,14 @@ test('should proxy image', async t => {
);
}
{
await assertAndSnapshot(
'/api/worker/image-proxy?url=http://example.com/image.png',
'should return 400 if origin and referer are missing',
{ status: 400, origin: null, referer: null }
);
}
{
await assertAndSnapshot(
'/api/worker/image-proxy?url=http://example.com/image.png',
@@ -86,17 +121,13 @@ test('should proxy image', async t => {
{
const fakeBuffer = Buffer.from('fake image');
const fakeResponse = {
ok: true,
const fakeResponse = new Response(fakeBuffer, {
status: 200,
headers: {
get: (header: string) => {
if (header.toLowerCase() === 'content-type') return 'image/png';
if (header.toLowerCase() === 'content-disposition') return 'inline';
return null;
},
'content-type': 'image/png',
'content-disposition': 'inline',
},
arrayBuffer: async () => fakeBuffer,
} as any;
});
const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeResponse);
@@ -132,6 +163,18 @@ test('should preview link', async t => {
{ status: 400, method: 'POST' }
);
await assertAndSnapshot(
'/api/worker/link-preview',
'should return 400 if origin and referer are missing',
{
status: 400,
method: 'POST',
origin: null,
referer: null,
body: { url: 'http://external.com/page' },
}
);
await assertAndSnapshot(
'/api/worker/link-preview',
'should return 400 if provided URL is from the same origin',

View File

@@ -275,6 +275,26 @@ export const USER_FRIENDLY_ERRORS = {
args: { message: 'string' },
message: ({ message }) => `HTTP request error, message: ${message}`,
},
ssrf_blocked_error: {
type: 'invalid_input',
args: { reason: 'string' },
message: ({ reason }) => {
switch (reason) {
case 'unresolvable_hostname':
return 'Failed to resolve hostname';
case 'too_many_redirects':
return 'Too many redirects';
default:
return 'Invalid URL';
}
},
},
response_too_large_error: {
type: 'invalid_input',
args: { limitBytes: 'number', receivedBytes: 'number' },
message: ({ limitBytes, receivedBytes }) =>
`Response too large (${receivedBytes} bytes), limit is ${limitBytes} bytes`,
},
email_service_not_configured: {
type: 'internal_server_error',
message: 'Email service is not configured.',

View File

@@ -54,6 +54,27 @@ export class HttpRequestError extends UserFriendlyError {
super('bad_request', 'http_request_error', message, args);
}
}
// GraphQL error-data payload for SsrfBlockedError: a single `reason` string.
// NOTE(review): this file looks generated (see the ErrorDataUnionType list);
// hand-written comments here may be lost on regeneration, so confirm and
// prefer documenting the source error definition instead.
@ObjectType()
class SsrfBlockedErrorDataType {
@Field() reason!: string
}
/**
 * User-friendly error of type `invalid_input`, named `ssrf_blocked_error`.
 * `message` may be a fixed string or a formatter over the error args.
 * Presumably raised when an outbound request is refused by SSRF protection
 * (inferred from the name; confirm at call sites).
 */
export class SsrfBlockedError extends UserFriendlyError {
constructor(args: SsrfBlockedErrorDataType, message?: string | ((args: SsrfBlockedErrorDataType) => string)) {
super('invalid_input', 'ssrf_blocked_error', message, args);
}
}
// GraphQL error-data payload for ResponseTooLargeError: the configured byte
// limit and the byte count that was actually received.
// NOTE(review): this file looks generated; hand-written comments may be lost on regeneration.
@ObjectType()
class ResponseTooLargeErrorDataType {
@Field() limitBytes!: number
@Field() receivedBytes!: number
}
/**
 * User-friendly error of type `invalid_input`, named `response_too_large_error`.
 * `message` may be a fixed string or a formatter over the error args.
 */
export class ResponseTooLargeError extends UserFriendlyError {
constructor(args: ResponseTooLargeErrorDataType, message?: string | ((args: ResponseTooLargeErrorDataType) => string)) {
super('invalid_input', 'response_too_large_error', message, args);
}
}
export class EmailServiceNotConfigured extends UserFriendlyError {
constructor(message?: string) {
@@ -1131,6 +1152,8 @@ export enum ErrorNames {
BAD_REQUEST,
GRAPHQL_BAD_REQUEST,
HTTP_REQUEST_ERROR,
SSRF_BLOCKED_ERROR,
RESPONSE_TOO_LARGE_ERROR,
EMAIL_SERVICE_NOT_CONFIGURED,
QUERY_TOO_LONG,
VALIDATION_ERROR,
@@ -1274,5 +1297,5 @@ registerEnumType(ErrorNames, {
export const ErrorDataUnionType = createUnionType({
name: 'ErrorDataUnion',
types: () =>
[GraphqlBadRequestDataType, HttpRequestErrorDataType, QueryTooLongDataType, ValidationErrorDataType, WrongSignInCredentialsDataType, UnknownOauthProviderDataType, InvalidOauthCallbackCodeDataType, MissingOauthQueryParameterDataType, InvalidOauthResponseDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, WorkspacePermissionNotFoundDataType, SpaceNotFoundDataType, MemberNotFoundInSpaceDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, SpaceShouldHaveOnlyOneOwnerDataType, DocNotFoundDataType, DocActionDeniedDataType, DocUpdateBlockedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, ExpectToGrantDocUserRolesDataType, ExpectToRevokeDocUserRolesDataType, ExpectToUpdateDocUserRoleDataType, NoMoreSeatDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CalendarProviderRequestErrorDataType, NoCopilotProviderAvailableDataType, CopilotFailedToGenerateEmbeddingDataType, CopilotDocNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderNotSupportedDataType, CopilotProviderSideErrorDataType, CopilotInvalidContextDataType, CopilotContextFileNotSupportedDataType, CopilotFailedToModifyContextDataType, CopilotFailedToMatchContextDataType, CopilotFailedToMatchGlobalContextDataType, CopilotFailedToAddWorkspaceFileEmbeddingDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType, InvalidLicenseToActivateDataType, InvalidLicenseUpdateParamsDataType, UnsupportedClientVersionDataType, MentionUserDocAccessDeniedDataType, InvalidAppConfigDataType, InvalidAppConfigInputDataType, InvalidSearchProviderRequestDataType, InvalidIndexerInputDataType] as const,
[GraphqlBadRequestDataType, HttpRequestErrorDataType, SsrfBlockedErrorDataType, ResponseTooLargeErrorDataType, QueryTooLongDataType, ValidationErrorDataType, WrongSignInCredentialsDataType, UnknownOauthProviderDataType, InvalidOauthCallbackCodeDataType, MissingOauthQueryParameterDataType, InvalidOauthResponseDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, WorkspacePermissionNotFoundDataType, SpaceNotFoundDataType, MemberNotFoundInSpaceDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, SpaceShouldHaveOnlyOneOwnerDataType, DocNotFoundDataType, DocActionDeniedDataType, DocUpdateBlockedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, ExpectToGrantDocUserRolesDataType, ExpectToRevokeDocUserRolesDataType, ExpectToUpdateDocUserRoleDataType, NoMoreSeatDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CalendarProviderRequestErrorDataType, NoCopilotProviderAvailableDataType, CopilotFailedToGenerateEmbeddingDataType, CopilotDocNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderNotSupportedDataType, CopilotProviderSideErrorDataType, CopilotInvalidContextDataType, CopilotContextFileNotSupportedDataType, CopilotFailedToModifyContextDataType, CopilotFailedToMatchContextDataType, CopilotFailedToMatchGlobalContextDataType, CopilotFailedToAddWorkspaceFileEmbeddingDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType, InvalidLicenseToActivateDataType, InvalidLicenseUpdateParamsDataType, UnsupportedClientVersionDataType, MentionUserDocAccessDeniedDataType, InvalidAppConfigDataType, InvalidAppConfigInputDataType, InvalidSearchProviderRequestDataType, InvalidIndexerInputDataType] as const,
});

View File

@@ -1,3 +1,5 @@
import { generateKeyPairSync } from 'node:crypto';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
@@ -7,11 +9,20 @@ const test = ava as TestFn<{
crypto: CryptoHelper;
}>;
const privateKey = `-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgS3IAkshQuSmFWGpe
rGTg2vwaC3LdcvBQlYHHMBYJZMyhRANCAAQXdT/TAh4neNEpd4UqpDIEqWv0XvFo
BRJxGsC5I/fetqObdx1+KEjcm8zFU2xLaUTw9IZCu8OslloOjQv4ur0a
-----END PRIVATE KEY-----`;
/**
 * Generates a fresh P-256 (prime256v1) EC private key and returns it as a
 * PKCS#8 PEM string, so tests never depend on a key checked into the repo.
 */
function generateTestPrivateKey(): string {
  const pair = generateKeyPairSync('ec', { namedCurve: 'prime256v1' });
  const pem = pair.privateKey.export({ format: 'pem', type: 'pkcs8' });
  return pem.toString();
}
const privateKey = generateTestPrivateKey();
const privateKey2 = generateTestPrivateKey();
test.beforeEach(async t => {
t.context.crypto = new CryptoHelper({
@@ -30,6 +41,21 @@ test('should be able to sign and verify', t => {
t.false(t.context.crypto.verify(`${data},fake-signature`));
});
// After rotating the private key, signatures made with the old key must still
// verify, alongside signatures made with the new key.
test('should verify signatures across key rotation', t => {
  const helper = t.context.crypto;
  const data = 'hello world';

  const beforeRotation = helper.sign(data);
  t.true(helper.verify(beforeRotation));

  // Rotate: update the config the helper reads, then notify it of the change.
  (helper as any).config.crypto.privateKey = privateKey2;
  helper.onConfigChanged({
    updates: { crypto: { privateKey: privateKey2 } },
  } as any);

  const afterRotation = helper.sign(data);
  t.true(helper.verify(beforeRotation));
  t.true(helper.verify(afterRotation));
});
test('should same data should get different signature', t => {
const data = 'hello world';
const signature = t.context.crypto.sign(data);
@@ -46,11 +72,12 @@ test('should be able to encrypt and decrypt', t => {
);
const encrypted = t.context.crypto.encrypt(data);
const encrypted2 = t.context.crypto.encrypt(data);
const decrypted = t.context.crypto.decrypt(encrypted);
// we are using a stub to make sure the iv is always 0,
// the encrypted result will always be the same
t.is(encrypted, 'AAAAAAAAAAAAAAAAOXbR/9glITL3BcO3kPd6fGOMasSkPQ==');
// the encrypted result will always be the same for the same key+data
t.is(encrypted2, encrypted);
t.is(decrypted, data);
stub.restore();
@@ -75,6 +102,24 @@ test('should be able to safe compare', t => {
t.false(t.context.crypto.compare('abc', 'def'));
});
// Round-trip: a signed internal access token parses back to its v1 payload.
test('should sign and parse internal access token', t => {
  const helper = t.context.crypto;
  const signed = helper.signInternalAccessToken({
    method: 'GET',
    path: '/rpc/workspaces/123/docs/456',
    now: 1700000000000,
    nonce: 'nonce-123',
  });
  const expected = {
    v: 1,
    ts: 1700000000000,
    nonce: 'nonce-123',
    m: 'GET',
    p: '/rpc/workspaces/123/docs/456',
  };
  t.deepEqual(helper.parseInternalAccessToken(signed), expected);
});
test('should be able to hash and verify password', async t => {
const password = 'mySecurePassword';
const hash = await t.context.crypto.encryptPassword(password);

View File

@@ -1,6 +1,7 @@
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import { ActionForbidden } from '../../error';
import { URLHelper } from '../url';
const test = ava as TestFn<{
@@ -85,6 +86,30 @@ test('can create link', t => {
);
});
// Relative paths and same-origin absolute URLs are allowed; foreign origins are not.
test('can validate callbackUrl allowlist', t => {
  const { url } = t.context;
  t.true(url.isAllowedCallbackUrl('/magic-link'));
  t.true(url.isAllowedCallbackUrl('https://app.affine.local/magic-link'));
  t.false(url.isAllowedCallbackUrl('https://evil.example/magic-link'));
});
// Table-driven check of the redirect_uri allowlist, including a non-http(s)
// scheme and a look-alike domain that must not match by suffix.
test('can validate redirect_uri allowlist', t => {
  const cases: Array<[string, boolean]> = [
    ['/redirect-proxy', true],
    ['https://github.com', true],
    ['javascript:alert(1)', false],
    ['https://evilgithub.com', false],
  ];
  for (const [uri, expected] of cases) {
    t.is(t.context.url.isAllowedRedirectUri(uri), expected, uri);
  }
});
// safeLink resolves allowed paths against the app origin and refuses foreign URLs.
test('can create safe link', t => {
  const { url } = t.context;
  const link = url.safeLink('/path');
  t.is(link, 'https://app.affine.local/path');
  t.throws(() => url.safeLink('https://evil.example/magic-link'), {
    instanceOf: ActionForbidden,
  });
});
test('can safe redirect', t => {
const res = {
redirect: (to: string) => to,

View File

@@ -76,6 +76,8 @@ export class CryptoHelper implements OnModuleInit {
};
};
private previousPublicKeys: KeyObject[] = [];
AFFiNEProPublicKey: Buffer | null = null;
AFFiNEProLicenseAESKey: Buffer | null = null;
@@ -101,12 +103,23 @@ export class CryptoHelper implements OnModuleInit {
}
private setup() {
const prevPublicKey = this.keyPair?.publicKey;
const privateKey = this.config.crypto.privateKey || generatePrivateKey();
const { priv, pub } = parseKey(privateKey);
const publicKey = pub
.export({ format: 'pem', type: 'spki' })
.toString('utf8');
if (prevPublicKey) {
const prevPem = prevPublicKey
.export({ format: 'pem', type: 'spki' })
.toString('utf8');
if (prevPem !== publicKey) {
this.previousPublicKeys.unshift(prevPublicKey);
this.previousPublicKeys = this.previousPublicKeys.slice(0, 2);
}
}
this.keyPair = {
publicKey: pub,
privateKey: priv,
@@ -143,15 +156,81 @@ export class CryptoHelper implements OnModuleInit {
}
const input = Buffer.from(data, 'utf-8');
const sigBuf = Buffer.from(signature, 'base64');
if (this.keyType === 'ed25519') {
// Ed25519 verifies the message directly
return verify(null, input, this.keyPair.publicKey, sigBuf);
} else {
// ECDSA with SHA-256
const verify = createVerify('sha256');
verify.update(input);
verify.end();
return verify.verify(this.keyPair.publicKey, sigBuf);
const keys = [this.keyPair.publicKey, ...this.previousPublicKeys];
return keys.some(publicKey => {
const keyType = (publicKey.asymmetricKeyType as string) || 'ec';
if (keyType === 'ed25519') {
// Ed25519 verifies the message directly
return verify(null, input, publicKey, sigBuf);
} else {
// ECDSA with SHA-256
const verifier = createVerify('sha256');
verifier.update(input);
verifier.end();
return verifier.verify(publicKey, sigBuf);
}
});
}
/**
 * Creates a signed internal access token for one request.
 *
 * The payload `{ v: 1, ts, nonce, m, p }` is JSON-encoded, base64url-encoded
 * and passed to `this.sign`. `ts` defaults to `Date.now()`, `nonce` defaults
 * to 16 random bytes (base64url) and the HTTP method is upper-cased.
 * NOTE(review): parseInternalAccessToken splits the result on ',' so the
 * output is presumably `<data>,<signature>`; confirm against `sign`.
 */
signInternalAccessToken(input: {
  method: string;
  path: string;
  now?: number;
  nonce?: string;
}) {
  const timestamp = input.now ?? Date.now();
  const nonce = input.nonce ?? this.randomBytes(16).toString('base64url');
  const json = JSON.stringify({
    v: 1,
    ts: timestamp,
    nonce,
    m: input.method.toUpperCase(),
    p: input.path,
  });
  const encoded = Buffer.from(json, 'utf8').toString('base64url');
  return this.sign(encoded);
}
/**
 * Verifies and decodes a token produced by `signInternalAccessToken`.
 *
 * Returns `null` when the token has no signature part, when the signature
 * fails `this.verify`, when the data is not valid base64url JSON, or when the
 * payload does not have the exact v1 shape. Otherwise it returns the payload
 * with only the known fields.
 */
parseInternalAccessToken(signatureWithData: string): {
  v: 1;
  ts: number;
  nonce: string;
  m: string;
  p: string;
} | null {
  const [encodedPayload, signature] = signatureWithData.split(',');
  if (!signature || !this.verify(signatureWithData)) {
    return null;
  }

  let decoded: unknown;
  try {
    decoded = JSON.parse(
      Buffer.from(encodedPayload, 'base64url').toString('utf8')
    );
  } catch {
    return null;
  }

  if (typeof decoded !== 'object' || decoded === null) {
    return null;
  }

  const { v, ts, nonce, m, p } = decoded as Record<string, unknown>;
  if (
    v === 1 &&
    typeof ts === 'number' &&
    typeof nonce === 'string' &&
    typeof m === 'string' &&
    typeof p === 'string'
  ) {
    return { v: 1, ts, nonce, m, p };
  }
  return null;
}

View File

@@ -5,8 +5,31 @@ import type { Response } from 'express';
import { ClsService } from 'nestjs-cls';
import { Config } from '../config';
import { ActionForbidden } from '../error';
import { OnEvent } from '../event';
/** Only plain web navigation targets are acceptable for callbacks/redirects. */
const ALLOWED_REDIRECT_PROTOCOLS = new Set(['http:', 'https:']);

// Keep in sync with frontend /redirect-proxy allowlist.
// Entries are compared against normalized (lower-cased) hostnames.
const TRUSTED_REDIRECT_DOMAINS = [
  'google.com',
  'stripe.com',
  'github.com',
  'twitter.com',
  'discord.gg',
  'youtube.com',
  't.me',
  'reddit.com',
  'affine.pro',
].map(domain => domain.toLowerCase());

/**
 * Lower-case a hostname and drop one trailing dot
 * (`Example.COM.` -> `example.com`).
 */
function normalizeHostname(hostname: string) {
  const lowered = hostname.toLowerCase();
  return lowered.endsWith('.') ? lowered.slice(0, -1) : lowered;
}

/**
 * True when `hostname` is exactly `domain` or one of its subdomains.
 * Requiring the `.` separator keeps `evilgithub.com` from matching `github.com`.
 */
function hostnameMatchesDomain(hostname: string, domain: string) {
  if (hostname === domain) {
    return true;
  }
  return hostname.endsWith('.' + domain);
}
@Injectable()
export class URLHelper {
redirectAllowHosts!: string[];
@@ -110,6 +133,13 @@ export class URLHelper {
return this.url(path, query).toString();
}
/**
 * Build a link like `link`, but only for targets accepted by
 * `isAllowedCallbackUrl`.
 *
 * @throws ActionForbidden when `path` is not an allowed callback URL.
 */
safeLink(path: string, query: Record<string, any> = {}) {
  const allowed = this.isAllowedCallbackUrl(path);
  if (allowed) {
    return this.link(path, query);
  }
  throw new ActionForbidden();
}
safeRedirect(res: Response, to: string) {
try {
const finalTo = new URL(decodeURIComponent(to), this.requestBaseUrl);
@@ -131,6 +161,68 @@ export class URLHelper {
return res.redirect(this.baseUrl);
}
/**
 * Decide whether `url` may be used as the callback target of an emailed link.
 *
 * Accepted:
 * - relative paths that stay on the current origin (e.g. `/magic-link?...`);
 * - absolute http(s) URLs without credentials whose origin is in
 *   `allowedOrigins`.
 */
isAllowedCallbackUrl(url: string): boolean {
  if (!url) {
    return false;
  }

  // Same-app relative paths. A leading single `/` is not sufficient: the
  // WHATWG URL parser treats `\` like `/` and strips tab/newline characters,
  // so `/\evil.com` or `/\t/evil.com` resolve to evil.com. Resolving against a
  // sentinel base and comparing origins rejects every such variant (and still
  // rejects `//host`).
  if (url.startsWith('/')) {
    const sentinel = 'http://relative-path.invalid';
    try {
      return new URL(url, sentinel).origin === sentinel;
    } catch {
      return false;
    }
  }

  try {
    const u = new URL(url);
    if (!ALLOWED_REDIRECT_PROTOCOLS.has(u.protocol)) {
      return false;
    }
    if (u.username || u.password) {
      return false;
    }
    return this.allowedOrigins.includes(u.origin);
  } catch {
    return false;
  }
}
/**
 * Decide whether `redirectUri` may be used as a post-sign-in redirect target.
 *
 * Accepted:
 * - relative paths that stay on the current origin (e.g. `/` or
 *   `/redirect-proxy?...`);
 * - absolute http(s) URLs without credentials whose hostname equals the
 *   hostname of one of `allowedOrigins`, or equals / is a subdomain of a
 *   `TRUSTED_REDIRECT_DOMAINS` entry. Scheme and port are not compared for
 *   these hostname matches.
 */
isAllowedRedirectUri(redirectUri: string): boolean {
  if (!redirectUri) {
    return false;
  }

  // Internal navigation. The WHATWG URL parser treats `\` like `/` and strips
  // tab/newline, so `/\evil.com` or `/\t/evil.com` would leave the site even
  // though they start with a single `/`. Resolve against a sentinel base and
  // require the origin to be unchanged.
  if (redirectUri.startsWith('/')) {
    const sentinel = 'http://relative-path.invalid';
    try {
      return new URL(redirectUri, sentinel).origin === sentinel;
    } catch {
      return false;
    }
  }

  try {
    const u = new URL(redirectUri);
    if (!ALLOWED_REDIRECT_PROTOCOLS.has(u.protocol)) {
      return false;
    }
    if (u.username || u.password) {
      return false;
    }

    const hostname = normalizeHostname(u.hostname);

    // Allow server known hosts. A malformed configured origin is skipped
    // instead of making every redirect fail.
    const isKnownHost = this.allowedOrigins.some(origin => {
      try {
        return normalizeHostname(new URL(origin).hostname) === hostname;
      } catch {
        return false;
      }
    });
    if (isKnownHost) {
      return true;
    }

    // Allow known trusted domains (for redirect-proxy).
    return TRUSTED_REDIRECT_DOMAINS.some(domain =>
      hostnameMatchesDomain(hostname, domain)
    );
  } catch {
    return false;
  }
}
verify(url: string | URL) {
try {
if (typeof url === 'string') {

View File

@@ -1,6 +1,7 @@
export * from './duration';
export * from './promise';
export * from './request';
export * from './ssrf';
export * from './stream';
export * from './types';
export * from './unit';

View File

@@ -0,0 +1,364 @@
import * as dns from 'node:dns/promises';
import { BlockList, isIP } from 'node:net';
import { Readable } from 'node:stream';
import { ResponseTooLargeError, SsrfBlockedError } from '../error/errors.gen';
import { OneMinute } from './unit';
/** Protocols permitted when the caller does not pass `allowedProtocols`. */
const DEFAULT_ALLOWED_PROTOCOLS = new Set(['http:', 'https:']);

// Ranges refused as outbound targets (populated elsewhere in this module).
const BLOCKED_IPS = new BlockList();
// IPv6 ranges eligible at all; IPv6 outside these is treated as blocked.
const ALLOWED_IPV6 = new BlockList();

export type DnsLookup = typeof dns.lookup;

// Resolver used by the hostname checks; swappable so tests avoid real DNS.
let dnsLookup: DnsLookup = dns.lookup;

/** Test hook: replace the resolver used by the SSRF hostname checks. */
export function __setDnsLookupForTests(lookup: DnsLookup) {
  dnsLookup = lookup;
}

/** Test hook: restore the real `dns.lookup` resolver. */
export function __resetDnsLookupForTests() {
  __setDnsLookupForTests(dns.lookup);
}
/** Machine-readable reason carried by every `SsrfBlockedError`. */
export type SSRFBlockReason =
  | 'invalid_url'
  | 'disallowed_protocol'
  | 'url_has_credentials'
  | 'blocked_hostname'
  | 'unresolvable_hostname'
  | 'blocked_ip'
  | 'too_many_redirects';

type SsrfErrorContext = { url?: string; hostname?: string; address?: string };

/**
 * Create an `SsrfBlockedError` for `reason` and attach `context` to it as a
 * plain `context` property.
 */
function createSsrfBlockedError(
  reason: SSRFBlockReason,
  context?: SsrfErrorContext
) {
  const error = new SsrfBlockedError({ reason });
  // For logging/debugging only (not part of UserFriendlyError JSON).
  (error as SsrfBlockedError & { context?: SsrfErrorContext }).context =
    context;
  return error;
}
/** Options accepted by `assertSsrFSafeUrl` (and, via extension, `safeFetch`). */
export interface SSRFProtectionOptions {
/**
 * URL protocols (with trailing `:`, e.g. `'https:'`) that may be requested.
 * When omitted, only `http:` and `https:` are allowed.
 */
allowedProtocols?: ReadonlySet<string>;
/**
 * Allow fetching private/reserved IPs when URL.origin is allowlisted.
 * Defaults to an empty allowlist (i.e. private IPs are blocked).
 */
allowPrivateOrigins?: ReadonlySet<string>;
}
/**
 * Remove an IPv6 zone/scope suffix (`fe80::1%eth0` -> `fe80::1`).
 * Everything from the first `%` on is dropped; input without `%` is returned
 * unchanged.
 */
function stripZoneId(address: string) {
  const zoneStart = address.indexOf('%');
  if (zoneStart < 0) {
    return address;
  }
  return address.substring(0, zoneStart);
}
// IPv4: RFC1918 + loopback + link-local + CGNAT + special/reserved
const BLOCKED_IPV4_SUBNETS = [
  ['0.0.0.0', 8],
  ['10.0.0.0', 8],
  ['127.0.0.0', 8],
  ['169.254.0.0', 16],
  ['172.16.0.0', 12],
  ['192.168.0.0', 16],
  ['100.64.0.0', 10], // CGNAT
  ['192.0.0.0', 24],
  ['192.0.2.0', 24], // TEST-NET-1
  ['198.51.100.0', 24], // TEST-NET-2
  ['203.0.113.0', 24], // TEST-NET-3
  ['198.18.0.0', 15], // benchmark
  ['192.88.99.0', 24], // 6to4 relay
  ['224.0.0.0', 4], // multicast
  ['240.0.0.0', 4], // reserved (includes broadcast)
] as const;

// IPv6: unspecified and loopback single addresses.
const BLOCKED_IPV6_ADDRESSES = ['::', '::1'] as const;

// IPv6: multicast, unique local, link-local, documentation.
const BLOCKED_IPV6_SUBNETS = [
  ['ff00::', 8], // multicast
  ['fc00::', 7], // unique local
  ['fe80::', 10], // link-local
  ['2001:db8::', 32], // documentation
] as const;

BLOCKED_IPV4_SUBNETS.forEach(([network, prefix]) =>
  BLOCKED_IPS.addSubnet(network, prefix, 'ipv4')
);
BLOCKED_IPV6_ADDRESSES.forEach(address =>
  BLOCKED_IPS.addAddress(address, 'ipv6')
);
BLOCKED_IPV6_SUBNETS.forEach(([network, prefix]) =>
  BLOCKED_IPS.addSubnet(network, prefix, 'ipv6')
);

// Only global unicast IPv6 is eligible at all.
ALLOWED_IPV6.addSubnet('2000::', 3, 'ipv6');
/**
 * Return the IPv4 address carried by an IPv4-mapped (`::ffff:a.b.c.d`) or
 * IPv4-compatible (`::a.b.c.d`) IPv6 literal written with a dotted-quad tail.
 *
 * Any other IPv6 address returns `null`, even if its text ends in a dotted
 * quad (e.g. `fc00::10.0.0.1`). Such addresses are not IPv4 destinations, so
 * they must be judged by the IPv6 rules. Otherwise a ULA or link-local
 * address could pass the check via its harmless-looking IPv4 tail.
 * Hex-only forms such as `::ffff:7f00:1` also return `null`.
 *
 * @param address an IPv6 literal without zone id.
 */
function extractEmbeddedIPv4FromIPv6(address: string): string | null {
  if (!address.includes('.')) {
    return null;
  }
  const idx = address.lastIndexOf(':');
  if (idx === -1) {
    return null;
  }
  const tail = address.slice(idx + 1);
  if (isIP(tail) !== 4) {
    return null;
  }

  // Hextets before the IPv4 tail: `::ffff` -> ['', '', 'ffff'],
  // `:` (from `::a.b.c.d`) -> ['', ''].
  const head = address.slice(0, idx).toLowerCase().split(':');
  const lastHextet = head.pop() ?? '';
  const isZeroHextet = (hextet: string) => /^0{0,4}$/.test(hextet);
  if (!head.every(isZeroHextet)) {
    return null;
  }
  return isZeroHextet(lastHextet) || lastHextet === 'ffff' ? tail : null;
}
/**
 * Decide whether an IP literal must be refused as an outbound target.
 *
 * - IPv4: blocked when inside an IPv4 range of `BLOCKED_IPS`.
 * - IPv6 with an embedded IPv4 (per `extractEmbeddedIPv4FromIPv6`): judged by
 *   that IPv4 address.
 * - Other IPv6: blocked unless inside `ALLOWED_IPV6` and outside the IPv6
 *   ranges of `BLOCKED_IPS`.
 * - Anything that is not an IP literal: blocked (fail closed).
 *
 * A `%zone` suffix is ignored.
 */
function isBlockedIpAddress(address: string): boolean {
  const ip = stripZoneId(address);

  switch (isIP(ip)) {
    case 4:
      return BLOCKED_IPS.check(ip, 'ipv4');
    case 6: {
      const embeddedV4 = extractEmbeddedIPv4FromIPv6(ip);
      if (embeddedV4) {
        return isBlockedIpAddress(embeddedV4);
      }
      return !ALLOWED_IPV6.check(ip, 'ipv6') || BLOCKED_IPS.check(ip, 'ipv6');
    }
    default:
      return true;
  }
}
/**
 * Resolve every address for `hostname` through the swappable `dnsLookup`.
 *
 * `localhost` and `*.localhost` (case-insensitive) are answered locally with
 * the IPv4/IPv6 loopback addresses, without consulting DNS.
 */
async function resolveHostAddresses(hostname: string): Promise<string[]> {
  const name = hostname.toLowerCase();
  const isLocalhostAlias = name === 'localhost' || name.endsWith('.localhost');
  if (isLocalhostAlias) {
    return ['127.0.0.1', '::1'];
  }

  const records = await dnsLookup(hostname, { all: true, verbatim: true });
  return records.map(({ address }) => address);
}
/**
 * Validate that `rawUrl` is safe for the server to request.
 *
 * The checks run in this order:
 * 1. The URL parses.
 * 2. Its protocol is allowed (default `http:`/`https:`).
 * 3. It has no username/password.
 * 4. It has a hostname.
 * 5. Every address it denotes is outside the blocked ranges. For an IP
 *    literal that is the literal itself; otherwise it is every DNS result.
 *    Step 5 is skipped for origins listed in `options.allowPrivateOrigins`.
 *
 * NOTE(review): only the addresses seen at check time are validated. The
 * request made afterwards performs its own DNS resolution, so DNS rebinding
 * is not prevented here.
 *
 * @returns the parsed URL (the same instance when a `URL` is passed in).
 * @throws SsrfBlockedError carrying the matching `SSRFBlockReason`.
 */
export async function assertSsrFSafeUrl(
  rawUrl: string | URL,
  options: SSRFProtectionOptions = {}
): Promise<URL> {
  const allowedProtocols =
    options.allowedProtocols ?? DEFAULT_ALLOWED_PROTOCOLS;

  let url: URL;
  try {
    url = rawUrl instanceof URL ? rawUrl : new URL(rawUrl);
  } catch {
    throw createSsrfBlockedError('invalid_url', {
      url: typeof rawUrl === 'string' ? rawUrl : undefined,
    });
  }

  if (!allowedProtocols.has(url.protocol)) {
    throw createSsrfBlockedError('disallowed_protocol', {
      url: url.toString(),
    });
  }

  if (url.username || url.password) {
    throw createSsrfBlockedError('url_has_credentials', {
      url: url.toString(),
    });
  }

  const hostname = url.hostname;
  if (!hostname) {
    throw createSsrfBlockedError('blocked_hostname', { url: url.toString() });
  }

  // `URL#hostname` keeps the brackets of IPv6 literals (`http://[::1]/` ->
  // `[::1]`). `isIP` rejects that form and DNS cannot resolve it, so strip the
  // brackets to let IPv6 literals take the IP-literal path below.
  const host =
    hostname.startsWith('[') && hostname.endsWith(']')
      ? hostname.slice(1, -1)
      : hostname;

  const allowPrivate = options.allowPrivateOrigins?.has(url.origin) ?? false;

  // IP literal: classify directly, no DNS involved.
  if (isIP(host)) {
    if (!allowPrivate && isBlockedIpAddress(host)) {
      throw createSsrfBlockedError('blocked_ip', {
        url: url.toString(),
        address: host,
      });
    }
    return url;
  }

  let addresses: string[];
  try {
    addresses = await resolveHostAddresses(host);
  } catch {
    throw createSsrfBlockedError('unresolvable_hostname', {
      url: url.toString(),
      hostname: host,
    });
  }

  if (addresses.length === 0) {
    throw createSsrfBlockedError('unresolvable_hostname', {
      url: url.toString(),
      hostname: host,
    });
  }

  // Every resolved address must be acceptable, not just the first one.
  for (const address of addresses) {
    if (!allowPrivate && isBlockedIpAddress(address)) {
      throw createSsrfBlockedError('blocked_ip', {
        url: url.toString(),
        hostname: host,
        address,
      });
    }
  }

  return url;
}
/** `safeFetch` options: the SSRF checks plus request limits. */
export interface SafeFetchOptions extends SSRFProtectionOptions {
// Overall timeout in milliseconds for the whole redirect chain (safeFetch default: 10_000).
timeoutMs?: number;
// Maximum number of redirects followed before failing (safeFetch default: 3).
maxRedirects?: number;
}
/**
 * `fetch` wrapper that runs `assertSsrFSafeUrl` on the initial URL and on
 * every redirect target.
 *
 * Redirects (301/302/303/307/308 with a `Location` header) are followed
 * manually, up to `options.maxRedirects` (default 3). Other 3xx responses are
 * returned as-is. The whole exchange is bounded by `options.timeoutMs`
 * (default 10s), combined with any caller-supplied `init.signal`.
 *
 * Redirect handling follows the Fetch standard's HTTP-redirect rules:
 * - a hop to a different origin drops Authorization, Cookie and
 *   Proxy-Authorization;
 * - 301/302 after POST, or 303 after any method other than GET/HEAD, switches
 *   to a bodiless GET and drops the request-body headers.
 *
 * @throws SsrfBlockedError in three cases: a hop is blocked, a `Location`
 *   value is not a valid URL, or the redirect limit is exceeded.
 */
export async function safeFetch(
  rawUrl: string | URL,
  init: RequestInit = {},
  options: SafeFetchOptions = {}
): Promise<Response> {
  const timeoutMs = options.timeoutMs ?? 10_000;
  const maxRedirects = options.maxRedirects ?? 3;

  const redirectStatuses = new Set([301, 302, 303, 307, 308]);
  // Credentials that must not be replayed to another origin.
  const crossOriginSensitiveHeaders = [
    'authorization',
    'cookie',
    'proxy-authorization',
  ];
  // Headers describing a request body; dropped together with the body.
  const requestBodyHeaders = [
    'content-encoding',
    'content-language',
    'content-location',
    'content-type',
  ];

  const timeoutSignal = AbortSignal.timeout(timeoutMs);
  const signal = init.signal
    ? AbortSignal.any([init.signal, timeoutSignal])
    : timeoutSignal;

  let current = await assertSsrFSafeUrl(rawUrl, options);
  let redirects = 0;

  // Always handle redirects manually (SSRF-safe on each hop).
  let requestInit: RequestInit = {
    ...init,
    redirect: 'manual',
    signal,
  };

  while (true) {
    const response = await fetch(current, requestInit);

    const location = redirectStatuses.has(response.status)
      ? response.headers.get('location')
      : null;
    if (!location) {
      return response;
    }

    // Drain/cancel body before following redirect to avoid leaking resources.
    try {
      await response.body?.cancel();
    } catch {
      // ignore
    }

    if (redirects >= maxRedirects) {
      throw createSsrfBlockedError('too_many_redirects', {
        url: current.toString(),
      });
    }

    let next: URL;
    try {
      next = new URL(location, current);
    } catch {
      throw createSsrfBlockedError('invalid_url', { url: location });
    }
    next = await assertSsrFSafeUrl(next, options);
    redirects += 1;

    const crossOrigin = next.origin !== current.origin;
    const method = (requestInit.method ?? 'GET').toUpperCase();
    const status = response.status;
    const switchToGet =
      ((status === 301 || status === 302) && method === 'POST') ||
      (status === 303 && method !== 'GET' && method !== 'HEAD');

    if (crossOrigin || switchToGet) {
      const headers = new Headers(requestInit.headers);
      if (crossOrigin) {
        for (const name of crossOriginSensitiveHeaders) {
          headers.delete(name);
        }
      }
      if (switchToGet) {
        for (const name of requestBodyHeaders) {
          headers.delete(name);
        }
      }
      requestInit = switchToGet
        ? { ...requestInit, headers, method: 'GET', body: undefined }
        : { ...requestInit, headers };
    }

    current = next;
  }
}
/**
 * Read `response`'s body into one Buffer, failing once it exceeds
 * `limitBytes`.
 *
 * A declared `Content-Length` above the limit is rejected before any reading.
 * The real byte count is enforced while streaming, whatever the headers say.
 * On rejection the body stream is destroyed/cancelled (errors ignored).
 *
 * @returns the body bytes, or an empty Buffer when there is no body.
 * @throws ResponseTooLargeError with the limit and the byte count observed.
 */
export async function readResponseBufferWithLimit(
  response: Response,
  limitBytes: number
): Promise<Buffer> {
  const cancelBody = async () => {
    try {
      await response.body?.cancel();
    } catch {
      // ignore
    }
  };

  const declaredLength = response.headers.get('content-length');
  if (declaredLength) {
    const declaredBytes = Number.parseInt(declaredLength, 10);
    if (Number.isFinite(declaredBytes) && declaredBytes > limitBytes) {
      await cancelBody();
      throw new ResponseTooLargeError({
        limitBytes,
        receivedBytes: declaredBytes,
      });
    }
  }

  const body = response.body;
  if (!body) {
    return Buffer.alloc(0);
  }

  // Convert Web ReadableStream -> Node Readable for consistent limit handling.
  const stream = Readable.fromWeb(body);
  const parts: Buffer[] = [];
  let received = 0;

  try {
    for await (const piece of stream) {
      const part = Buffer.isBuffer(piece) ? piece : Buffer.from(piece);
      received += part.length;
      if (received > limitBytes) {
        try {
          stream.destroy();
        } catch {
          // ignore
        }
        throw new ResponseTooLargeError({
          limitBytes,
          receivedBytes: received,
        });
      }
      parts.push(part);
    }
  } finally {
    if (received > limitBytes) {
      await cancelBody();
    }
  }

  return Buffer.concat(parts, received);
}
type FetchBufferResult = { buffer: Buffer; type: string };

// Attachment download limits: timeout of OneMinute / 6 (presumably 10s) and
// at most 3 redirects.
const ATTACH_GET_PARAMS = { timeoutMs: OneMinute / 6, maxRedirects: 3 };

/**
 * Download `url` into memory, with two optional guards: a byte `limit` and a
 * `contentType` prefix check.
 *
 * `data:` URLs are decoded with the built-in `fetch`. Every other URL goes
 * through `safeFetch` with `ATTACH_GET_PARAMS`. On the error paths below the
 * response body is cancelled first, so the underlying connection is not left
 * open waiting for an unread body.
 *
 * @returns the body bytes and the response content type (defaulting to
 *   `application/octet-stream`).
 * @throws Error on a non-OK status or a content-type mismatch. Errors from
 *   `safeFetch` / `readResponseBufferWithLimit` also propagate.
 */
export async function fetchBuffer(
  url: string,
  limit: number,
  contentType?: string
): Promise<FetchBufferResult> {
  const resp = url.startsWith('data:')
    ? await fetch(url)
    : await safeFetch(url, { method: 'GET' }, ATTACH_GET_PARAMS);

  const discardBody = async () => {
    try {
      await resp.body?.cancel();
    } catch {
      // ignore
    }
  };

  if (!resp.ok) {
    await discardBody();
    throw new Error(
      `Failed to fetch attachment: ${resp.status} ${resp.statusText}`
    );
  }

  const type = resp.headers.get('content-type') || 'application/octet-stream';
  if (contentType && !type.startsWith(contentType)) {
    await discardBody();
    throw new Error(
      `Attachment content-type mismatch: expected ${contentType} but got ${type}`
    );
  }

  const buffer = await readResponseBufferWithLimit(resp, limit);
  return { buffer, type };
}
/**
 * Copy `buffer`'s bytes into a fresh `ArrayBuffer` of exactly the same length.
 *
 * A Node `Buffer` may be a view into a larger shared allocation, so its
 * `.buffer` cannot be handed out directly. The copy is also independent of
 * later writes to `buffer`.
 */
export function bufferToArrayBuffer(buffer: Buffer): ArrayBuffer {
  const out = new Uint8Array(buffer.byteLength);
  out.set(buffer, 0);
  return out.buffer;
}

View File

@@ -28,13 +28,6 @@ class GenerateAccessTokenInput {
export class AccessTokenResolver {
constructor(private readonly models: Models) {}
@Query(() => [AccessToken], {
deprecationReason: 'use currentUser.accessTokens',
})
async accessTokens(@CurrentUser() user: CurrentUser): Promise<AccessToken[]> {
return await this.models.accessToken.list(user.id);
}
@Query(() => [RevealedAccessToken], {
deprecationReason: 'use currentUser.revealedAccessTokens',
})

View File

@@ -16,7 +16,6 @@ import type { Request, Response } from 'express';
import {
ActionForbidden,
Cache,
Config,
CryptoHelper,
EmailTokenNotFound,
@@ -53,7 +52,9 @@ interface MagicLinkCredential {
client_nonce?: string;
}
const OTP_CACHE_KEY = (otp: string) => `magic-link-otp:${otp}`;
/** Request body of `POST /api/auth/open-app/sign-in`. */
interface OpenAppSignInCredential {
// One-time code issued by `POST /api/auth/open-app/sign-in-code`.
code: string;
}
@Throttle('strict')
@Controller('/api/auth')
@@ -65,7 +66,6 @@ export class AuthController {
private readonly auth: AuthService,
private readonly models: Models,
private readonly config: Config,
private readonly cache: Cache,
private readonly crypto: CryptoHelper
) {
if (env.dev) {
@@ -111,11 +111,7 @@ export class AuthController {
async signIn(
@Req() req: Request,
@Res() res: Response,
@Body() credential: SignInCredential,
/**
* @deprecated
*/
@Query('redirect_uri') redirectUri?: string
@Body() credential: SignInCredential
) {
validators.assertValidEmail(credential.email);
const canSignIn = await this.auth.canSignIn(credential.email);
@@ -132,11 +128,9 @@ export class AuthController {
);
} else {
await this.sendMagicLink(
req,
res,
credential.email,
credential.callbackUrl,
redirectUri,
credential.client_nonce
);
}
@@ -155,13 +149,25 @@ export class AuthController {
}
async sendMagicLink(
_req: Request,
res: Response,
email: string,
callbackUrl = '/magic-link',
redirectUrl?: string,
clientNonce?: string
) {
if (!this.url.isAllowedCallbackUrl(callbackUrl)) {
throw new ActionForbidden();
}
const callbackUrlObj = this.url.url(callbackUrl);
const redirectUriInCallback =
callbackUrlObj.searchParams.get('redirect_uri');
if (
redirectUriInCallback &&
!this.url.isAllowedRedirectUri(redirectUriInCallback)
) {
throw new ActionForbidden();
}
// send email magic link
const user = await this.models.user.getUserByEmail(email, {
withDisabled: true,
@@ -207,23 +213,9 @@ export class AuthController {
);
const otp = this.crypto.otp();
// TODO(@forehalo): this is a temporary solution, we should not rely on cache to store the otp
const cacheKey = OTP_CACHE_KEY(otp);
await this.cache.set(
cacheKey,
{ token, clientNonce },
{ ttl: ttlInSec * 1000 }
);
await this.models.magicLinkOtp.upsert(email, otp, token, clientNonce);
const magicLink = this.url.link(callbackUrl, {
token: otp,
email,
...(redirectUrl
? {
redirect_uri: redirectUrl,
}
: {}),
});
const magicLink = this.url.link(callbackUrl, { token: otp, email });
if (env.dev) {
// make it easier to test in dev mode
this.logger.debug(`Magic link: ${magicLink}`);
@@ -237,8 +229,9 @@ export class AuthController {
}
@Public()
@Get('/sign-out')
@Post('/sign-out')
async signOut(
@Req() req: Request,
@Res() res: Response,
@Session() session: Session | undefined,
@Query('user_id') userId: string | undefined
@@ -248,12 +241,63 @@ export class AuthController {
return;
}
const csrfCookie = req.cookies?.[AuthService.csrfCookieName] as
| string
| undefined;
const csrfHeader = req.get('x-affine-csrf-token');
if (!csrfCookie || !csrfHeader || csrfCookie !== csrfHeader) {
throw new ActionForbidden();
}
await this.auth.signOut(session.sessionId, userId);
await this.auth.refreshCookies(res, session.sessionId);
res.status(HttpStatus.OK).send({});
}
/**
 * Issue a short-lived one-time code for handing the caller's authenticated
 * session off. The code can later be exchanged for session cookies via
 * `POST /api/auth/open-app/sign-in`.
 *
 * @throws ActionForbidden when the request is not authenticated.
 */
@Public()
@UseNamedGuard('version')
@Post('/open-app/sign-in-code')
async openAppSignInCode(@CurrentUser() user?: CurrentUser) {
  // The route is @Public, so the signed-in requirement is enforced here.
  if (user) {
    // NOTE(review): presumably a TTL in seconds (5 minutes) — confirm against
    // verificationToken.create.
    const ttl = 5 * 60;
    const code = await this.models.verificationToken.create(
      TokenType.OpenAppSignIn,
      user.id,
      ttl
    );
    return { code };
  }

  throw new ActionForbidden();
}
/**
 * Exchange a code issued by `POST /api/auth/open-app/sign-in-code` for
 * session cookies of the user it was issued to.
 *
 * Responds with `{ id }`, the credential (user id) stored with the code.
 *
 * @throws InvalidAuthState in either of two cases: the code is missing or not
 *   a non-empty string, or it does not resolve to an `OpenAppSignIn` token
 *   with a credential.
 */
@Public()
@UseNamedGuard('version')
@Post('/open-app/sign-in')
async openAppSignIn(
  @Req() req: Request,
  @Res() res: Response,
  @Body() credential: OpenAppSignInCredential
) {
  // The body is untrusted JSON; the declared `string` type is not enforced at
  // runtime, so reject numbers/objects before they reach the token lookup.
  const code: unknown = credential?.code;
  if (typeof code !== 'string' || !code) {
    throw new InvalidAuthState();
  }

  const tokenRecord = await this.models.verificationToken.get(
    TokenType.OpenAppSignIn,
    code
  );
  if (!tokenRecord?.credential) {
    throw new InvalidAuthState();
  }

  await this.auth.setCookies(req, res, tokenRecord.credential);
  res.send({ id: tokenRecord.credential });
}
@Public()
@UseNamedGuard('version')
@Post('/magic-link')
@@ -269,23 +313,20 @@ export class AuthController {
validators.assertValidEmail(email);
const cacheKey = OTP_CACHE_KEY(otp);
const cachedToken = await this.cache.get<{
token: string;
clientNonce: string;
}>(cacheKey);
let token: string | undefined;
if (cachedToken && typeof cachedToken === 'object') {
token = cachedToken.token;
if (cachedToken.clientNonce && cachedToken.clientNonce !== clientNonce) {
const consumed = await this.models.magicLinkOtp.consume(
email,
otp,
clientNonce
);
if (!consumed.ok) {
if (consumed.reason === 'nonce_mismatch') {
throw new InvalidAuthState();
}
}
if (!token) {
throw new InvalidEmailToken();
}
const token = consumed.token;
const tokenRecord = await this.models.verificationToken.verify(
TokenType.SignIn,
token,

View File

@@ -12,6 +12,7 @@ import { Socket } from 'socket.io';
import {
AccessDenied,
AuthenticationRequired,
Cache,
Config,
CryptoHelper,
getRequestResponseFromContext,
@@ -23,6 +24,8 @@ import { Session, TokenSession } from './session';
// Markers for public / internal-only entrypoints (presumably reflector
// metadata keys — confirm against the decorators that set them).
const PUBLIC_ENTRYPOINT_SYMBOL = Symbol('public');
const INTERNAL_ENTRYPOINT_SYMBOL = Symbol('internal');

// An internal access token is accepted up to this long after its `ts`
// (also used as the nonce replay-cache TTL): 5 minutes.
const INTERNAL_ACCESS_TOKEN_TTL_MS = 300_000;
// ...and up to this far in the future, to tolerate clock drift: 30 seconds.
const INTERNAL_ACCESS_TOKEN_CLOCK_SKEW_MS = 30_000;
@Injectable()
export class AuthGuard implements CanActivate, OnModuleInit {
@@ -30,6 +33,7 @@ export class AuthGuard implements CanActivate, OnModuleInit {
constructor(
private readonly crypto: CryptoHelper,
private readonly cache: Cache,
private readonly ref: ModuleRef,
private readonly reflector: Reflector
) {}
@@ -48,10 +52,28 @@ export class AuthGuard implements CanActivate, OnModuleInit {
[clazz, handler]
);
if (isInternal) {
// check access token: data,signature
const accessToken = req.get('x-access-token');
if (accessToken && this.crypto.verify(accessToken)) {
return true;
if (accessToken) {
const payload = this.crypto.parseInternalAccessToken(accessToken);
if (payload) {
const now = Date.now();
const method = req.method.toUpperCase();
const path = req.path;
const timestampInRange =
payload.ts <= now + INTERNAL_ACCESS_TOKEN_CLOCK_SKEW_MS &&
now - payload.ts <= INTERNAL_ACCESS_TOKEN_TTL_MS;
if (timestampInRange && payload.m === method && payload.p === path) {
const nonceKey = `rpc:nonce:${payload.nonce}`;
const ok = await this.cache.setnx(nonceKey, 1, {
ttl: INTERNAL_ACCESS_TOKEN_TTL_MS,
});
if (ok) {
return true;
}
}
}
}
throw new AccessDenied('Invalid internal request');
}

View File

@@ -159,7 +159,7 @@ export class AuthResolver {
user.id
);
const url = this.url.link(callbackUrl, { userId: user.id, token });
const url = this.url.safeLink(callbackUrl, { userId: user.id, token });
return await this.auth.sendChangePasswordEmail(user.email, url);
}
@@ -200,7 +200,7 @@ export class AuthResolver {
user.id
);
const url = this.url.link(callbackUrl, { token });
const url = this.url.safeLink(callbackUrl, { token });
return await this.auth.sendChangeEmail(user.email, url);
}
@@ -244,7 +244,10 @@ export class AuthResolver {
user.id
);
const url = this.url.link(callbackUrl, { token: verifyEmailToken, email });
const url = this.url.safeLink(callbackUrl, {
token: verifyEmailToken,
email,
});
return await this.auth.sendVerifyChangeEmail(email, url);
}
@@ -258,7 +261,7 @@ export class AuthResolver {
user.id
);
const url = this.url.link(callbackUrl, { token });
const url = this.url.safeLink(callbackUrl, { token });
return await this.auth.sendVerifyEmail(user.email, url);
}
@@ -302,6 +305,6 @@ export class AuthResolver {
userId
);
return this.url.link(callbackUrl, { userId, token });
return this.url.safeLink(callbackUrl, { userId, token });
}
}

View File

@@ -1,3 +1,5 @@
import { randomUUID } from 'node:crypto';
import { Injectable, OnApplicationBootstrap } from '@nestjs/common';
import type { CookieOptions, Request, Response } from 'express';
import { assign, pick } from 'lodash-es';
@@ -39,6 +41,7 @@ export class AuthService implements OnApplicationBootstrap {
};
static readonly sessionCookieName = 'affine_session';
static readonly userCookieName = 'affine_user_id';
static readonly csrfCookieName = 'affine_csrf_token';
constructor(
private readonly config: Config,
@@ -171,6 +174,11 @@ export class AuthService implements OnApplicationBootstrap {
expires: newExpiresAt,
...this.cookieOptions,
});
res.cookie(AuthService.csrfCookieName, randomUUID(), {
expires: newExpiresAt,
...this.cookieOptions,
httpOnly: false,
});
return true;
}
@@ -207,6 +215,12 @@ export class AuthService implements OnApplicationBootstrap {
expires: userSession.expiresAt ?? void 0,
});
res.cookie(AuthService.csrfCookieName, randomUUID(), {
...this.cookieOptions,
httpOnly: false,
expires: userSession.expiresAt ?? void 0,
});
this.setUserCookie(res, userId);
}
@@ -227,6 +241,7 @@ export class AuthService implements OnApplicationBootstrap {
private clearCookies(res: Response<any, Record<string, any>>) {
res.clearCookie(AuthService.sessionCookieName);
res.clearCookie(AuthService.userCookieName);
res.clearCookie(AuthService.csrfCookieName);
}
setUserCookie(res: Response, userId: string) {

View File

@@ -240,18 +240,6 @@ export class AppConfigResolver {
return this.validateConfigInternal(updates);
}
@Mutation(() => [AppConfigValidateResult], {
description: 'validate app configuration',
deprecationReason: 'use Query.validateAppConfig',
name: 'validateAppConfig',
})
async validateAppConfigMutation(
@Args('updates', { type: () => [UpdateAppConfigInput] })
updates: UpdateAppConfigInput[]
): Promise<AppConfigValidateResult[]> {
return this.validateConfigInternal(updates);
}
private validateConfigInternal(
updates: UpdateAppConfigInput[]
): AppConfigValidateResult[] {

View File

@@ -77,14 +77,124 @@ test('should forbid access to rpc api with invalid access token', async t => {
t.pass();
});
// Error body returned for every rejected internal access token.
const INVALID_INTERNAL_REQUEST_BODY = {
  status: 403,
  code: 'Forbidden',
  type: 'NO_PERMISSION',
  name: 'ACCESS_DENIED',
  message: 'Invalid internal request',
};

test('should forbid replayed internal access token', async t => {
  const { app, crypto: cryptoHelper } = t.context;
  const workspaceId = '123';
  const docId = '123';
  const path = `/rpc/workspaces/${workspaceId}/docs/${docId}`;
  const token = cryptoHelper.signInternalAccessToken({
    method: 'GET',
    path,
    nonce: `nonce-${randomUUID()}`,
  });

  // First use passes the guard (the doc itself does not exist).
  await app.GET(path).set('x-access-token', token).expect(404);
  // Reusing the same nonce must be rejected.
  await app
    .GET(path)
    .set('x-access-token', token)
    .expect(INVALID_INTERNAL_REQUEST_BODY)
    .expect(403);
  t.pass();
});

test('should forbid internal access token when method mismatched', async t => {
  const { app, crypto: cryptoHelper } = t.context;
  const workspaceId = '123';
  const docId = '123';
  const path = `/rpc/workspaces/${workspaceId}/docs/${docId}/diff`;
  // Signed for GET, sent as POST.
  const token = cryptoHelper.signInternalAccessToken({ method: 'GET', path });

  await app
    .POST(path)
    .set('x-access-token', token)
    .expect(INVALID_INTERNAL_REQUEST_BODY)
    .expect(403);
  t.pass();
});

test('should forbid internal access token when path mismatched', async t => {
  const { app, crypto: cryptoHelper } = t.context;
  const workspaceId = '123';
  const docId = '123';
  const wrongPath = `/rpc/workspaces/${workspaceId}/docs/${docId}`;
  const path = `/rpc/workspaces/${workspaceId}/docs/${docId}/content`;
  // Signed for a different path than the one requested.
  const token = cryptoHelper.signInternalAccessToken({
    method: 'GET',
    path: wrongPath,
  });

  await app
    .GET(path)
    .set('x-access-token', token)
    .expect(INVALID_INTERNAL_REQUEST_BODY)
    .expect(403);
  t.pass();
});

test('should forbid internal access token when expired', async t => {
  const { app, crypto: cryptoHelper } = t.context;
  const workspaceId = '123';
  const docId = '123';
  const path = `/rpc/workspaces/${workspaceId}/docs/${docId}`;
  // Issued 10 minutes ago, beyond the accepted token age.
  const token = cryptoHelper.signInternalAccessToken({
    method: 'GET',
    path,
    now: Date.now() - 10 * 60 * 1000,
    nonce: `nonce-${randomUUID()}`,
  });

  await app
    .GET(path)
    .set('x-access-token', token)
    .expect(INVALID_INTERNAL_REQUEST_BODY)
    .expect(403);
  t.pass();
});
test('should 404 when doc not found', async t => {
const { app } = t.context;
const workspaceId = '123';
const docId = '123';
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}`;
await app
.GET(`/rpc/workspaces/${workspaceId}/docs/${docId}`)
.set('x-access-token', t.context.crypto.sign(docId))
.GET(path)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect({
status: 404,
code: 'Not Found',
@@ -111,9 +221,13 @@ test('should return doc when found', async t => {
},
]);
const path = `/rpc/workspaces/${workspace.id}/docs/${docId}`;
const res = await app
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}`)
.set('x-access-token', t.context.crypto.sign(docId))
.GET(path)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
)
.set('x-cloud-trace-context', 'test-trace-id/span-id')
.expect(200)
.expect('x-request-id', 'test-trace-id')
@@ -129,9 +243,13 @@ test('should 404 when doc diff not found', async t => {
const workspaceId = '123';
const docId = '123';
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}/diff`;
await app
.POST(`/rpc/workspaces/${workspaceId}/docs/${docId}/diff`)
.set('x-access-token', t.context.crypto.sign(docId))
.POST(path)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'POST', path })
)
.expect({
status: 404,
code: 'Not Found',
@@ -148,9 +266,13 @@ test('should 404 when doc content not found', async t => {
const workspaceId = '123';
const docId = '123';
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}/content`;
await app
.GET(`/rpc/workspaces/${workspaceId}/docs/${docId}/content`)
.set('x-access-token', t.context.crypto.sign(docId))
.GET(path)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect({
status: 404,
code: 'Not Found',
@@ -172,9 +294,13 @@ test('should get doc content in json format', async t => {
});
const docId = randomUUID();
const path = `/rpc/workspaces/${workspace.id}/docs/${docId}/content`;
await app
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}/content`)
.set('x-access-token', t.context.crypto.sign(docId))
.GET(path)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect('Content-Type', 'application/json; charset=utf-8')
.expect({
title: 'test title',
@@ -183,8 +309,11 @@ test('should get doc content in json format', async t => {
.expect(200);
await app
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}/content?full=false`)
.set('x-access-token', t.context.crypto.sign(docId))
.GET(`${path}?full=false`)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect('Content-Type', 'application/json; charset=utf-8')
.expect({
title: 'test title',
@@ -204,9 +333,13 @@ test('should get full doc content in json format', async t => {
});
const docId = randomUUID();
const path = `/rpc/workspaces/${workspace.id}/docs/${docId}/content`;
await app
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}/content?full=true`)
.set('x-access-token', t.context.crypto.sign(docId))
.GET(`${path}?full=true`)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect('Content-Type', 'application/json; charset=utf-8')
.expect({
title: 'test title',
@@ -220,9 +353,13 @@ test('should 404 when workspace content not found', async t => {
const { app } = t.context;
const workspaceId = '123';
const path = `/rpc/workspaces/${workspaceId}/content`;
await app
.GET(`/rpc/workspaces/${workspaceId}/content`)
.set('x-access-token', t.context.crypto.sign(workspaceId))
.GET(path)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect({
status: 404,
code: 'Not Found',
@@ -244,9 +381,13 @@ test('should get workspace content in json format', async t => {
});
const workspaceId = randomUUID();
const path = `/rpc/workspaces/${workspaceId}/content`;
await app
.GET(`/rpc/workspaces/${workspaceId}/content`)
.set('x-access-token', t.context.crypto.sign(workspaceId))
.GET(path)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect(200)
.expect({
name: 'test name',
@@ -265,9 +406,13 @@ test('should get doc markdown in json format', async t => {
});
const docId = randomUUID();
const path = `/rpc/workspaces/${workspace.id}/docs/${docId}/markdown`;
await app
.GET(`/rpc/workspaces/${workspace.id}/docs/${docId}/markdown`)
.set('x-access-token', t.context.crypto.sign(docId))
.GET(path)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect('Content-Type', 'application/json; charset=utf-8')
.expect(200)
.expect({
@@ -282,9 +427,13 @@ test('should 404 when doc markdown not found', async t => {
const workspaceId = '123';
const docId = '123';
const path = `/rpc/workspaces/${workspaceId}/docs/${docId}/markdown`;
await app
.GET(`/rpc/workspaces/${workspaceId}/docs/${docId}/markdown`)
.set('x-access-token', t.context.crypto.sign(docId))
.GET(path)
.set(
'x-access-token',
t.context.crypto.signInternalAccessToken({ method: 'GET', path })
)
.expect({
status: 404,
code: 'Not Found',

View File

@@ -257,12 +257,13 @@ export class RpcDocReader extends DatabaseDocReader {
super(cache, models, blobStorage, workspace);
}
private async fetch(
accessToken: string,
url: string,
method: 'GET' | 'POST',
body?: Uint8Array
) {
private async fetch(url: string, method: 'GET' | 'POST', body?: Uint8Array) {
const { pathname } = new URL(url);
const accessToken = this.crypto.signInternalAccessToken({
method,
path: pathname,
});
const headers: Record<string, string> = {
'x-access-token': accessToken,
'x-cloud-trace-context': getOrGenRequestId('rpc'),
@@ -293,9 +294,8 @@ export class RpcDocReader extends DatabaseDocReader {
docId: string
): Promise<DocRecord | null> {
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}`;
const accessToken = this.crypto.sign(docId);
try {
const res = await this.fetch(accessToken, url, 'GET');
const res = await this.fetch(url, 'GET');
if (!res) {
return null;
}
@@ -330,9 +330,8 @@ export class RpcDocReader extends DatabaseDocReader {
aiEditable: boolean
): Promise<DocMarkdown | null> {
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}/markdown?aiEditable=${aiEditable}`;
const accessToken = this.crypto.sign(docId);
try {
const res = await this.fetch(accessToken, url, 'GET');
const res = await this.fetch(url, 'GET');
if (!res) {
return null;
}
@@ -358,9 +357,8 @@ export class RpcDocReader extends DatabaseDocReader {
stateVector?: Uint8Array
): Promise<DocDiff | null> {
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}/diff`;
const accessToken = this.crypto.sign(docId);
try {
const res = await this.fetch(accessToken, url, 'POST', stateVector);
const res = await this.fetch(url, 'POST', stateVector);
if (!res) {
return null;
}
@@ -399,9 +397,8 @@ export class RpcDocReader extends DatabaseDocReader {
fullContent = false
): Promise<PageDocContent | null> {
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/docs/${docId}/content?full=${fullContent}`;
const accessToken = this.crypto.sign(docId);
try {
const res = await this.fetch(accessToken, url, 'GET');
const res = await this.fetch(url, 'GET');
if (!res) {
return null;
}
@@ -427,9 +424,8 @@ export class RpcDocReader extends DatabaseDocReader {
workspaceId: string
): Promise<WorkspaceDocInfo | null> {
const url = `${this.config.docService.endpoint}/rpc/workspaces/${workspaceId}/content`;
const accessToken = this.crypto.sign(workspaceId);
try {
const res = await this.fetch(accessToken, url, 'GET');
const res = await this.fetch(url, 'GET');
if (!res) {
return null;
}

View File

@@ -130,7 +130,7 @@ export abstract class DocStorageAdapter extends Connection {
snapshot: DocRecord | null,
finalUpdate: DocUpdate
) {
this.logger.log(
this.logger.verbose(
`Squashing updates, spaceId: ${spaceId}, docId: ${docId}, updates: ${updates.length}`
);
@@ -152,7 +152,7 @@ export abstract class DocStorageAdapter extends Connection {
// always mark updates as merged unless throws
const count = await this.markUpdatesMerged(spaceId, docId, updates);
this.logger.log(
this.logger.verbose(
`Marked ${count} updates as merged, spaceId: ${spaceId}, docId: ${docId}, timestamp: ${timestamp}`
);

View File

@@ -0,0 +1,90 @@
import ava, { TestFn } from 'ava';
import { createTestingApp, type TestingApp } from '../../../__tests__/utils';
import { buildAppModule } from '../../../app.module';
import { Models } from '../../../models';
const test = ava as TestFn<{
app: TestingApp;
models: Models;
allowlistedAdminToken: string;
nonAllowlistedAdminToken: string;
userToken: string;
}>;
test.before(async t => {
const app = await createTestingApp({
imports: [buildAppModule(globalThis.env)],
});
t.context.app = app;
t.context.models = app.get(Models);
});
test.beforeEach(async t => {
await t.context.app.initTestingDB();
const allowlistedAdmin = await t.context.models.user.create({
email: 'admin@affine.pro',
password: '1',
emailVerifiedAt: new Date(),
});
await t.context.models.userFeature.add(
allowlistedAdmin.id,
'administrator',
'test'
);
const allowlistedAdminToken = await t.context.models.accessToken.create({
userId: allowlistedAdmin.id,
name: 'test',
});
t.context.allowlistedAdminToken = allowlistedAdminToken.token;
const nonAllowlistedAdmin = await t.context.models.user.create({
email: 'admin2@affine.pro',
password: '1',
emailVerifiedAt: new Date(),
});
await t.context.models.userFeature.add(
nonAllowlistedAdmin.id,
'administrator',
'test'
);
const nonAllowlistedAdminToken = await t.context.models.accessToken.create({
userId: nonAllowlistedAdmin.id,
name: 'test',
});
t.context.nonAllowlistedAdminToken = nonAllowlistedAdminToken.token;
const user = await t.context.models.user.create({
email: 'user@affine.pro',
password: '1',
emailVerifiedAt: new Date(),
});
const userToken = await t.context.models.accessToken.create({
userId: user.id,
name: 'test',
});
t.context.userToken = userToken.token;
});
test.after.always(async t => {
await t.context.app.close();
});
test('should return 404 for non-admin user', async t => {
await t.context.app
.GET('/api/queue')
.set('Authorization', `Bearer ${t.context.userToken}`)
.expect(404);
t.pass();
});
test('should allow allowlisted admin', async t => {
await t.context.app
.GET('/api/queue')
.set('Authorization', `Bearer ${t.context.allowlistedAdminToken}`)
.expect(200)
.expect('Content-Type', /text\/html/);
t.pass();
});

View File

@@ -53,12 +53,21 @@ class QueueDashboardService implements OnModuleInit {
): Promise<void> => {
try {
const session = await this.authGuard.signIn(req, res);
const userId = session?.user?.id;
const user = session?.user;
const userId = user?.id;
const email = user?.email?.toLowerCase();
const isAdmin = userId ? await this.feature.isAdmin(userId) : false;
if (!isAdmin) {
res.status(404).end();
return;
}
if (req.method === 'GET' && (req.path === '/' || req.path === '')) {
this.logger.log(
`QueueDash accessed by ${userId} (${email ?? 'n/a'})`
);
}
} catch (error) {
this.logger.warn('QueueDash auth failed', error as Error);
res.status(404).end();

View File

@@ -9,6 +9,7 @@ import {
WebSocketServer,
} from '@nestjs/websockets';
import { ClsInterceptor } from 'nestjs-cls';
import semver from 'semver';
import { type Server, Socket } from 'socket.io';
import {
@@ -49,10 +50,10 @@ type EventResponse<Data = any> = Data extends never
data: Data;
};
// sync-019: legacy 0.19.x clients (broadcast-doc-updates/push-doc-updates).
// Remove after 2026-06-30 once metrics show 0 usage for 30 days.
// 020+: receives space:broadcast-doc-updates (batch) and sends space:push-doc-update.
type RoomType = 'sync' | `${string}:awareness` | 'sync-019';
// sync: shared room for space membership checks and non-protocol broadcasts.
// sync-025: legacy 0.25 doc sync protocol (space:broadcast-doc-update).
// sync-026: current doc sync protocol (space:broadcast-doc-updates).
type RoomType = 'sync' | 'sync-025' | 'sync-026' | `${string}:awareness`;
function Room(
spaceId: string,
@@ -61,6 +62,25 @@ function Room(
return `${spaceId}:${type}`;
}
const MIN_WS_CLIENT_VERSION = new semver.Range('>=0.25.0', {
includePrerelease: true,
});
const DOC_UPDATES_PROTOCOL_026 = new semver.Range('>=0.26.0-0', {
includePrerelease: true,
});
type SyncProtocolRoomType = Extract<RoomType, 'sync-025' | 'sync-026'>;
function isSupportedWsClientVersion(clientVersion: string): boolean {
return Boolean(
semver.valid(clientVersion) && MIN_WS_CLIENT_VERSION.test(clientVersion)
);
}
function getSyncProtocolRoomType(clientVersion: string): SyncProtocolRoomType {
return DOC_UPDATES_PROTOCOL_026.test(clientVersion) ? 'sync-026' : 'sync-025';
}
enum SpaceType {
Workspace = 'workspace',
Userspace = 'userspace',
@@ -90,16 +110,6 @@ interface LeaveSpaceAwarenessMessage {
docId: string;
}
/**
* @deprecated
*/
interface PushDocUpdatesMessage {
spaceType: SpaceType;
spaceId: string;
docId: string;
updates: string[];
}
interface PushDocUpdateMessage {
spaceType: SpaceType;
spaceId: string;
@@ -117,6 +127,15 @@ interface BroadcastDocUpdatesMessage {
compressed?: boolean;
}
interface BroadcastDocUpdateMessage {
spaceType: SpaceType;
spaceId: string;
docId: string;
update: string;
timestamp: number;
editor: string;
}
interface LoadDocMessage {
spaceType: SpaceType;
spaceId: string;
@@ -225,6 +244,11 @@ export class SpaceSyncGateway
}
}
private rejectJoin(client: Socket) {
// Give socket.io a chance to flush the ack packet before disconnecting.
setImmediate(() => client.disconnect());
}
handleConnection() {
this.connectionCount++;
this.logger.debug(`New connection, total: ${this.connectionCount}`);
@@ -252,23 +276,21 @@ export class SpaceSyncGateway
return;
}
const room025 = `${spaceType}:${Room(spaceId, 'sync-025')}`;
const encodedUpdates = this.encodeUpdates(updates);
this.server
.to(Room(spaceId, 'sync-019'))
.emit('space:broadcast-doc-updates', {
spaceType,
for (const update of encodedUpdates) {
const payload: BroadcastDocUpdateMessage = {
spaceType: spaceType as SpaceType,
spaceId,
docId,
updates: encodedUpdates,
update,
timestamp,
editor,
});
metrics.socketio
.counter('sync_019_broadcast')
.add(encodedUpdates.length, { event: 'doc_updates_pushed' });
editor: editor ?? '',
};
this.server.to(room025).emit('space:broadcast-doc-update', payload);
}
const room = `${spaceType}:${Room(spaceId)}`;
const room026 = `${spaceType}:${Room(spaceId, 'sync-026')}`;
const payload = this.buildBroadcastPayload(
spaceType as SpaceType,
spaceId,
@@ -277,7 +299,7 @@ export class SpaceSyncGateway
timestamp,
editor
);
this.server.to(room).emit('space:broadcast-doc-updates', payload);
this.server.to(room026).emit('space:broadcast-doc-updates', payload);
metrics.socketio
.counter('doc_updates_broadcast')
.add(payload.updates.length, {
@@ -314,16 +336,34 @@ export class SpaceSyncGateway
@MessageBody()
{ spaceType, spaceId, clientVersion }: JoinSpaceMessage
): Promise<EventResponse<{ clientId: string; success: boolean }>> {
if (
![SpaceType.Userspace, SpaceType.Workspace].includes(spaceType) ||
/^0.1/.test(clientVersion)
) {
if (![SpaceType.Userspace, SpaceType.Workspace].includes(spaceType)) {
this.rejectJoin(client);
return { data: { clientId: client.id, success: false } };
} else {
if (spaceType === SpaceType.Workspace) {
this.event.emit('workspace.embedding', { workspaceId: spaceId });
}
await this.selectAdapter(client, spaceType).join(user.id, spaceId);
}
if (!isSupportedWsClientVersion(clientVersion)) {
this.rejectJoin(client);
return { data: { clientId: client.id, success: false } };
}
if (spaceType === SpaceType.Workspace) {
this.event.emit('workspace.embedding', { workspaceId: spaceId });
}
const adapter = this.selectAdapter(client, spaceType);
await adapter.join(user.id, spaceId);
const protocolRoomType = getSyncProtocolRoomType(clientVersion);
const protocolRoom = adapter.room(spaceId, protocolRoomType);
const otherProtocolRoom = adapter.room(
spaceId,
protocolRoomType === 'sync-025' ? 'sync-026' : 'sync-025'
);
if (client.rooms.has(otherProtocolRoom)) {
await client.leave(otherProtocolRoom);
}
if (!client.rooms.has(protocolRoom)) {
await client.join(protocolRoom);
}
return { data: { clientId: client.id, success: true } };
@@ -380,68 +420,8 @@ export class SpaceSyncGateway
}
/**
* @deprecated use [space:push-doc-update] instead, client should always merge updates on their own
*
* only 0.19.x client will send this event
* client should always merge updates on their own
*/
@SubscribeMessage('space:push-doc-updates')
async onReceiveDocUpdates(
@ConnectedSocket() client: Socket,
@CurrentUser() user: CurrentUser,
@MessageBody()
message: PushDocUpdatesMessage
): Promise<EventResponse<{ accepted: true; timestamp?: number }>> {
const { spaceType, spaceId, docId, updates } = message;
const adapter = this.selectAdapter(client, spaceType);
const id = new DocID(docId, spaceId);
// TODO(@forehalo): enable after frontend supporting doc revert
// await this.ac.user(user.id).doc(spaceId, id.guid).assert('Doc.Update');
const timestamp = await adapter.push(
spaceId,
id.guid,
updates.map(update => Buffer.from(update, 'base64')),
user.id
);
metrics.socketio
.counter('sync_019_event')
.add(1, { event: 'push-doc-updates' });
// broadcast to 0.19.x clients
client.to(Room(spaceId, 'sync-019')).emit('space:broadcast-doc-updates', {
...message,
timestamp,
editor: user.id,
});
// broadcast to new clients
const decodedUpdates = updates.map(update => Buffer.from(update, 'base64'));
const payload = this.buildBroadcastPayload(
spaceType,
spaceId,
docId,
decodedUpdates,
timestamp,
user.id
);
client
.to(adapter.room(spaceId))
.emit('space:broadcast-doc-updates', payload);
metrics.socketio
.counter('doc_updates_broadcast')
.add(payload.updates.length, {
mode: payload.compressed ? 'compressed' : 'batch',
});
return {
data: {
accepted: true,
timestamp,
},
};
}
@SubscribeMessage('space:push-doc-update')
async onReceiveDocUpdate(
@ConnectedSocket() client: Socket,
@@ -461,16 +441,6 @@ export class SpaceSyncGateway
user.id
);
// broadcast to 0.19.x clients
client.to(Room(spaceId, 'sync-019')).emit('space:broadcast-doc-updates', {
spaceType,
spaceId,
docId,
updates: [update],
timestamp,
editor: user.id,
});
const payload = this.buildBroadcastPayload(
spaceType,
spaceId,
@@ -480,7 +450,7 @@ export class SpaceSyncGateway
user.id
);
client
.to(adapter.room(spaceId))
.to(adapter.room(spaceId, 'sync-026'))
.emit('space:broadcast-doc-updates', payload);
metrics.socketio
.counter('doc_updates_broadcast')
@@ -488,6 +458,17 @@ export class SpaceSyncGateway
mode: payload.compressed ? 'compressed' : 'batch',
});
client
.to(adapter.room(spaceId, 'sync-025'))
.emit('space:broadcast-doc-update', {
spaceType,
spaceId,
docId,
update,
timestamp,
editor: user.id,
} satisfies BroadcastDocUpdateMessage);
return {
data: {
accepted: true,
@@ -516,8 +497,18 @@ export class SpaceSyncGateway
@ConnectedSocket() client: Socket,
@CurrentUser() user: CurrentUser,
@MessageBody()
{ spaceType, spaceId, docId }: JoinSpaceAwarenessMessage
{ spaceType, spaceId, docId, clientVersion }: JoinSpaceAwarenessMessage
) {
if (![SpaceType.Userspace, SpaceType.Workspace].includes(spaceType)) {
this.rejectJoin(client);
return { data: { clientId: client.id, success: false } };
}
if (!isSupportedWsClientVersion(clientVersion)) {
this.rejectJoin(client);
return { data: { clientId: client.id, success: false } };
}
await this.selectAdapter(client, spaceType).join(
user.id,
spaceId,
@@ -555,13 +546,6 @@ export class SpaceSyncGateway
.to(adapter.room(spaceId, roomType))
.emit('space:collect-awareness', { spaceType, spaceId, docId });
// TODO(@forehalo): remove backward compatibility
if (spaceType === SpaceType.Workspace) {
client
.to(adapter.room(spaceId, roomType))
.emit('new-client-awareness-init');
}
return { data: { clientId: client.id } };
}

View File

@@ -66,21 +66,27 @@ export class UserResolver {
): Promise<typeof UserOrLimitedUser | null> {
validators.assertValidEmail(email);
// TODO(@forehalo): need to limit a user can only get another user witch is in the same workspace
// NOTE: prevent user enumeration. Only allow querying users within the same workspace scope.
if (!currentUser) {
return null;
}
const user = await this.models.user.getUserByEmail(email);
// return empty response when user not exists
if (!user) return null;
if (currentUser) {
if (user.id === currentUser.id) {
return sessionUser(user);
}
// only return limited info when not logged in
return {
email: user.email,
hasPassword: !!user.password,
};
const allowed = await this.models.workspaceUser.hasSharedWorkspace(
currentUser.id,
user.id
);
if (!allowed) return null;
return sessionUser(user);
}
@Throttle('strict')

View File

@@ -26,6 +26,6 @@ defineModuleConfig('client', {
},
'versionControl.requiredVersion': {
desc: "Allowed version range of the app that allowed to access the server. Requires 'client/versionControl.enabled' to be true to take effect.",
default: '>=0.20.0',
default: '>=0.25.0',
},
});

View File

@@ -7,7 +7,6 @@ import {
Mutation,
ObjectType,
Parent,
Query,
registerEnumType,
ResolveField,
Resolver,
@@ -33,7 +32,7 @@ import {
MULTIPART_PART_SIZE,
MULTIPART_THRESHOLD,
} from '../../storage/constants';
import { WorkspaceBlobSizes, WorkspaceType } from '../types';
import { WorkspaceType } from '../types';
enum BlobUploadMethod {
GRAPHQL = 'GRAPHQL',
@@ -169,14 +168,6 @@ export class WorkspaceBlobResolver {
return this.getUploadPart(user, workspace.id, key, uploadId, partNumber);
}
@Query(() => WorkspaceBlobSizes, {
deprecationReason: 'use `user.quotaUsage` instead',
})
async collectAllBlobSizes(@CurrentUser() user: CurrentUser) {
const size = await this.quota.getUserStorageUsage(user.id);
return { size };
}
@Mutation(() => String)
async setBlob(
@CurrentUser() user: CurrentUser,
@@ -412,19 +403,6 @@ export class WorkspaceBlobResolver {
return key;
}
@Mutation(() => BlobUploadPart, {
deprecationReason: 'use WorkspaceType.blobUploadPartUrl',
})
async getBlobUploadPartUrl(
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('key') key: string,
@Args('uploadId') uploadId: string,
@Args('partNumber', { type: () => Int }) partNumber: number
): Promise<BlobUploadPart> {
return this.getUploadPart(user, workspaceId, key, uploadId, partNumber);
}
@Mutation(() => Boolean)
async abortBlobUpload(
@CurrentUser() user: CurrentUser,

View File

@@ -238,20 +238,6 @@ export class WorkspaceDocResolver {
return this.models.doc.findPublics(workspace.id);
}
@ResolveField(() => DocType, {
description: 'Get public page of a workspace by page id.',
complexity: 2,
nullable: true,
deprecationReason: 'use [WorkspaceType.doc] instead',
})
async publicPage(
@CurrentUser() me: CurrentUser,
@Parent() workspace: WorkspaceType,
@Args('pageId') pageId: string
) {
return this.doc(me, workspace, pageId);
}
@ResolveField(() => PaginatedDocType)
async docs(
@Parent() workspace: WorkspaceType,
@@ -314,24 +300,6 @@ export class WorkspaceDocResolver {
};
}
@Mutation(() => DocType, {
deprecationReason: 'use publishDoc instead',
})
async publishPage(
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('pageId') pageId: string,
@Args({
name: 'mode',
type: () => PublicDocMode,
nullable: true,
defaultValue: PublicDocMode.Page,
})
mode: PublicDocMode
) {
return this.publishDoc(user, workspaceId, pageId, mode);
}
@Mutation(() => DocType)
async publishDoc(
@CurrentUser() user: CurrentUser,
@@ -364,17 +332,6 @@ export class WorkspaceDocResolver {
return doc;
}
@Mutation(() => DocType, {
deprecationReason: 'use revokePublicDoc instead',
})
async revokePublicPage(
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('docId') docId: string
) {
return this.revokePublicDoc(user, workspaceId, docId);
}
@Mutation(() => DocType)
async revokePublicDoc(
@CurrentUser() user: CurrentUser,

View File

@@ -234,25 +234,6 @@ export class WorkspaceMemberResolver {
return results;
}
/**
* @deprecated
*/
@Mutation(() => [InviteResult], {
deprecationReason: 'use [inviteMembers] instead',
})
async inviteBatch(
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args({ name: 'emails', type: () => [String] }) emails: string[],
@Args('sendInviteMail', {
nullable: true,
deprecationReason: 'never used',
})
_sendInviteMail: boolean = false
) {
return this.inviteMembers(user, workspaceId, emails);
}
@ResolveField(() => InviteLink, {
description: 'invite link for workspace',
nullable: true,
@@ -456,20 +437,6 @@ export class WorkspaceMemberResolver {
return { workspace, user: owner, invitee, status };
}
/**
* @deprecated
*/
@Mutation(() => Boolean, {
deprecationReason: 'use [revokeMember] instead',
})
async revoke(
@CurrentUser() me: CurrentUser,
@Args('workspaceId') workspaceId: string,
@Args('userId') userId: string
) {
return this.revokeMember(me, workspaceId, userId);
}
@Mutation(() => Boolean)
async revokeMember(
@CurrentUser() me: CurrentUser,

View File

@@ -156,40 +156,6 @@ export class WorkspaceResolver {
};
}
@Query(() => Boolean, {
description: 'Get is owner of workspace',
complexity: 2,
deprecationReason: 'use WorkspaceType[role] instead',
})
async isOwner(
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string
) {
const role = await this.models.workspaceUser.getActive(
workspaceId,
user.id
);
return role?.type === WorkspaceRole.Owner;
}
@Query(() => Boolean, {
description: 'Get is admin of workspace',
complexity: 2,
deprecationReason: 'use WorkspaceType[role] instead',
})
async isAdmin(
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string
) {
const role = await this.models.workspaceUser.getActive(
workspaceId,
user.id
);
return role?.type === WorkspaceRole.Admin;
}
@Query(() => [WorkspaceType], {
description: 'Get all accessible workspaces for current user',
complexity: 2,

View File

@@ -1,3 +1,4 @@
import { PrismaClient } from '@prisma/client';
import test from 'ava';
import { createModule } from '../../__tests__/create-module';
@@ -23,8 +24,16 @@ test('should create access token', async t => {
t.is(token.userId, user.id);
t.is(token.name, 'test');
t.truthy(token.token);
t.true(token.token.startsWith('ut_'));
t.truthy(token.createdAt);
t.is(token.expiresAt, null);
const row = await module.get(PrismaClient).accessToken.findUnique({
where: { id: token.id },
});
t.truthy(row);
t.regex(row!.token, /^[0-9a-f]{64}$/);
t.not(row!.token, token.token);
});
test('should create access token with expiration', async t => {
@@ -50,6 +59,22 @@ test('should list access tokens without token value', async t => {
t.is(listed[0].token, undefined);
});
test('should not reveal access token value after creation', async t => {
const user = await module.create(Mockers.User);
const token = await models.accessToken.create({
userId: user.id,
name: 'test',
});
const listed = await models.accessToken.list(user.id, true);
const found = listed.find(item => item.id === token.id);
t.truthy(found);
t.is(found!.token, '[REDACTED]');
t.not(found!.token, token.token);
});
test('should be able to revoke access token', async t => {
const user = await module.create(Mockers.User);
const token = await module.create(Mockers.AccessToken, { userId: user.id });
@@ -62,7 +87,10 @@ test('should be able to revoke access token', async t => {
test('should be able to get access token by token value', async t => {
const user = await module.create(Mockers.User);
const token = await module.create(Mockers.AccessToken, { userId: user.id });
const token = await models.accessToken.create({
userId: user.id,
name: 'test',
});
const found = await models.accessToken.getByToken(token.token);
t.is(found?.id, token.id);
@@ -72,8 +100,9 @@ test('should be able to get access token by token value', async t => {
test('should not get expired access token', async t => {
const user = await module.create(Mockers.User);
const token = await module.create(Mockers.AccessToken, {
const token = await models.accessToken.create({
userId: user.id,
name: 'test',
expiresAt: Due.before('1s'),
});

View File

@@ -3,43 +3,53 @@ import { Injectable } from '@nestjs/common';
import { CryptoHelper } from '../base';
import { BaseModel } from './base';
const REDACTED_TOKEN = '[REDACTED]';
export interface CreateAccessTokenInput {
userId: string;
name: string;
expiresAt?: Date | null;
}
type UserAccessToken = {
id: string;
name: string;
createdAt: Date;
expiresAt: Date | null;
};
@Injectable()
export class AccessTokenModel extends BaseModel {
constructor(private readonly crypto: CryptoHelper) {
super();
}
async list(userId: string, revealed?: false): Promise<UserAccessToken[]>;
async list(
userId: string,
revealed: true
): Promise<(UserAccessToken & { token: string })[]>;
async list(userId: string, revealed: boolean = false) {
return await this.db.accessToken.findMany({
select: {
id: true,
name: true,
createdAt: true,
expiresAt: true,
token: revealed,
},
where: {
userId,
},
const tokens = await this.db.accessToken.findMany({
select: { id: true, name: true, createdAt: true, expiresAt: true },
where: { userId },
});
if (!revealed) return tokens;
return tokens.map(row => ({ ...row, token: REDACTED_TOKEN }));
}
async create(input: CreateAccessTokenInput) {
let token = 'ut_' + this.crypto.randomBytes(40).toString('hex');
token = token.substring(0, 40);
const token = `ut_${this.crypto.randomBytes(32).toString('base64url')}`;
const tokenHash = this.crypto.sha256(token).toString('hex');
return await this.db.accessToken.create({
data: {
token,
...input,
},
const created = await this.db.accessToken.create({
data: { token: tokenHash, ...input },
});
// NOTE: we only return the plaintext token once, at creation time.
return { ...created, token };
}
async revoke(id: string, userId: string) {
@@ -52,20 +62,27 @@ export class AccessTokenModel extends BaseModel {
}
async getByToken(token: string) {
return await this.db.accessToken.findUnique({
where: {
token,
OR: [
{
expiresAt: null,
},
{
expiresAt: {
gt: new Date(),
},
},
],
},
const tokenHash = this.crypto.sha256(token).toString('hex');
const condition = [{ expiresAt: null }, { expiresAt: { gt: new Date() } }];
const found = await this.db.accessToken.findUnique({
where: { token: tokenHash, OR: condition },
});
if (found) return found;
// Compatibility: lazy-migrate old plaintext tokens in DB.
const legacy = await this.db.accessToken.findUnique({
where: { token, OR: condition },
});
if (!legacy) return null;
await this.db.accessToken.update({
where: { id: legacy.id },
data: { token: tokenHash },
});
return { ...legacy, token: tokenHash };
}
}

View File

@@ -131,7 +131,7 @@ export class DocModel extends BaseModel {
},
});
if (count > 0) {
this.logger.log(
this.logger.verbose(
`Deleted ${count} updates for workspace ${workspaceId} doc ${docId}`
);
}
@@ -159,7 +159,7 @@ export class DocModel extends BaseModel {
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
const result: { updatedAt: Date }[] = await this.db.$queryRaw`
INSERT INTO "snapshots" ("workspace_id", "guid", "blob", "size", "created_at", "updated_at", "created_by", "updated_by")
VALUES (${spaceId}, ${docId}, ${blob}, ${size}, DEFAULT, ${updatedAt}, ${editorId}, ${editorId})
VALUES (${spaceId}, ${docId}, ${blob}, ${size}, ${updatedAt}, ${updatedAt}, ${editorId}, ${editorId})
ON CONFLICT ("workspace_id", "guid")
DO UPDATE SET "blob" = ${blob}, "size" = ${size}, "updated_at" = ${updatedAt}, "updated_by" = ${editorId}
WHERE "snapshots"."workspace_id" = ${spaceId} AND "snapshots"."guid" = ${docId} AND "snapshots"."updated_at" <= ${updatedAt}

View File

@@ -24,6 +24,7 @@ import { DocModel } from './doc';
import { DocUserModel } from './doc-user';
import { FeatureModel } from './feature';
import { HistoryModel } from './history';
import { MagicLinkOtpModel } from './magic-link-otp';
import { NotificationModel } from './notification';
import { MODELS_SYMBOL } from './provider';
import { SessionModel } from './session';
@@ -41,6 +42,7 @@ const MODELS = {
user: UserModel,
session: SessionModel,
verificationToken: VerificationTokenModel,
magicLinkOtp: MagicLinkOtpModel,
feature: FeatureModel,
workspace: WorkspaceModel,
userFeature: UserFeatureModel,
@@ -133,6 +135,7 @@ export * from './doc';
export * from './doc-user';
export * from './feature';
export * from './history';
export * from './magic-link-otp';
export * from './notification';
export * from './session';
export * from './user';

View File

@@ -0,0 +1,86 @@
import { Injectable } from '@nestjs/common';
import { Transactional } from '@nestjs-cls/transactional';
import { CryptoHelper } from '../base';
import { BaseModel } from './base';
const MAX_OTP_ATTEMPTS = 10;
const OTP_TTL_IN_SEC = 30 * 60;
export type ConsumeMagicLinkOtpResult =
| { ok: true; token: string }
| { ok: false; reason: 'not_found' | 'expired' | 'invalid_otp' | 'locked' }
| { ok: false; reason: 'nonce_mismatch' };
@Injectable()
export class MagicLinkOtpModel extends BaseModel {
constructor(private readonly crypto: CryptoHelper) {
super();
}
private hash(otp: string) {
return this.crypto.sha256(otp).toString('hex');
}
async upsert(
email: string,
otp: string,
token: string,
clientNonce?: string
) {
const otpHash = this.hash(otp);
const expiresAt = new Date(Date.now() + OTP_TTL_IN_SEC * 1000);
await this.db.magicLinkOtp.upsert({
where: { email },
create: { email, otpHash, token, clientNonce, expiresAt, attempts: 0 },
update: { otpHash, token, clientNonce, expiresAt, attempts: 0 },
});
}
@Transactional()
async consume(
email: string,
otp: string,
clientNonce?: string
): Promise<ConsumeMagicLinkOtpResult> {
const now = new Date();
const otpHash = this.hash(otp);
const record = await this.db.magicLinkOtp.findUnique({ where: { email } });
if (!record) {
return { ok: false, reason: 'not_found' };
}
if (record.expiresAt <= now) {
await this.db.magicLinkOtp.delete({ where: { email } });
return { ok: false, reason: 'expired' };
}
if (record.clientNonce && record.clientNonce !== clientNonce) {
return { ok: false, reason: 'nonce_mismatch' };
}
if (record.attempts >= MAX_OTP_ATTEMPTS) {
await this.db.magicLinkOtp.delete({ where: { email } });
return { ok: false, reason: 'locked' };
}
const matches = this.crypto.compare(record.otpHash, otpHash);
if (!matches) {
const attempts = record.attempts + 1;
if (attempts >= MAX_OTP_ATTEMPTS) {
await this.db.magicLinkOtp.delete({ where: { email } });
return { ok: false, reason: 'locked' };
}
await this.db.magicLinkOtp.update({
where: { email },
data: { attempts },
});
return { ok: false, reason: 'invalid_otp' };
}
await this.db.magicLinkOtp.delete({ where: { email } });
return { ok: true, token: record.token };
}
}

View File

@@ -14,6 +14,7 @@ export enum TokenType {
ChangeEmail,
ChangePassword,
Challenge,
OpenAppSignIn,
}
@Injectable()

View File

@@ -302,6 +302,29 @@ export class WorkspaceUserModel extends BaseModel {
});
}
async hasSharedWorkspace(userId: string, otherUserId: string) {
if (userId === otherUserId) {
return true;
}
const shared = await this.db.workspaceUserRole.findFirst({
select: { id: true },
where: {
userId,
status: WorkspaceMemberStatus.Accepted,
workspace: {
permissions: {
some: {
userId: otherUserId,
},
},
},
},
});
return !!shared;
}
async paginate(workspaceId: string, pagination: PaginationInput) {
return await Promise.all([
this.db.workspaceUserRole.findMany({

View File

@@ -105,10 +105,6 @@ class RemoveContextDocInput {
class AddContextFileInput {
@Field(() => String)
contextId!: string;
// @TODO(@darkskygit): remove this after client lower then 0.22 has been disconnected
@Field(() => String, { nullable: true, deprecationReason: 'Never used' })
blobId!: string | undefined;
}
@InputType()

View File

@@ -1672,42 +1672,12 @@ const imageActions: Prompt[] = [
},
],
},
// TODO(@darkskygit): deprecated, remove it after <0.22 version is outdated
{
name: 'debug:action:fal-remove-bg',
action: 'Remove background',
model: 'imageutils/rembg',
messages: [],
},
{
name: 'debug:action:fal-face-to-sticker',
action: 'Convert to sticker',
model: 'face-to-sticker',
messages: [],
},
{
name: 'debug:action:fal-teed',
action: 'fal-teed',
model: 'workflowutils/teed',
messages: [{ role: 'user', content: '{{content}}' }],
},
{
name: 'debug:action:fal-sd15',
action: 'image',
model: 'lcm-sd15-i2i',
messages: [],
},
{
name: 'debug:action:fal-upscaler',
action: 'Clearer',
model: 'clarity-upscaler',
messages: [
{
role: 'user',
content: 'best quality, 8K resolution, highres, clarity, {{content}}',
},
],
},
];
const modelActions: Prompt[] = [

View File

@@ -24,7 +24,9 @@ import {
CopilotPromptInvalid,
CopilotProviderNotSupported,
CopilotProviderSideError,
fetchBuffer,
metrics,
OneMB,
UserFriendlyError,
} from '../../../base';
import { CopilotProvider } from './provider';
@@ -673,14 +675,12 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
for (const [idx, entry] of attachments.entries()) {
const url = typeof entry === 'string' ? entry : entry.attachment;
const resp = await fetch(url);
if (resp.ok) {
const type = resp.headers.get('content-type');
if (type && type.startsWith('image/')) {
const buffer = new Uint8Array(await resp.arrayBuffer());
const file = new File([buffer], `${idx}.png`, { type });
form.append('image[]', file);
}
try {
const { buffer, type } = await fetchBuffer(url, 10 * OneMB, 'image/');
const file = new File([buffer], `${idx}.png`, { type });
form.append('image[]', file);
} catch {
continue;
}
}

View File

@@ -12,11 +12,22 @@ import {
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';
import z, { ZodType } from 'zod';
import {
bufferToArrayBuffer,
fetchBuffer,
OneMinute,
ResponseTooLargeError,
safeFetch,
SsrfBlockedError,
} from '../../../base';
import { CustomAITools } from '../tools';
import { PromptMessage, StreamObject } from './types';
type ChatMessage = CoreUserMessage | CoreAssistantMessage;
const ATTACHMENT_MAX_BYTES = 20 * 1024 * 1024;
const ATTACH_HEAD_PARAMS = { timeoutMs: OneMinute / 12, maxRedirects: 3 };
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
const FORMAT_INFER_MAP: Record<string, string> = {
pdf: 'application/pdf',
@@ -42,6 +53,11 @@ const FORMAT_INFER_MAP: Record<string, string> = {
flv: 'video/flv',
};
async function fetchArrayBuffer(url: string): Promise<ArrayBuffer> {
const { buffer } = await fetchBuffer(url, ATTACHMENT_MAX_BYTES);
return bufferToArrayBuffer(buffer);
}
export async function inferMimeType(url: string) {
if (url.startsWith('data:')) {
return url.split(';')[0].split(':')[1];
@@ -53,12 +69,15 @@ export async function inferMimeType(url: string) {
if (ext) {
return ext;
}
const mimeType = await fetch(url, {
method: 'HEAD',
redirect: 'follow',
}).then(res => res.headers.get('Content-Type'));
if (mimeType) {
return mimeType;
try {
const mimeType = await safeFetch(
url,
{ method: 'HEAD' },
ATTACH_HEAD_PARAMS
).then(res => res.headers.get('content-type'));
if (mimeType) return mimeType;
} catch {
// ignore and fallback to default
}
}
return 'application/octet-stream';
@@ -106,7 +125,16 @@ export async function chatToGPTMessage(
if (SIMPLE_IMAGE_URL_REGEX.test(attachment)) {
const data =
attachment.startsWith('data:') || useBase64Attachment
? await fetch(attachment).then(r => r.arrayBuffer())
? await fetchArrayBuffer(attachment).catch(error => {
// Avoid leaking internal details for blocked URLs.
if (
error instanceof SsrfBlockedError ||
error instanceof ResponseTooLargeError
) {
throw new Error('Attachment URL is not allowed');
}
throw error;
})
: new URL(attachment);
if (mediaType.startsWith('image/')) {
contents.push({ type: 'image', image: data, mediaType });

View File

@@ -7,7 +7,9 @@ import {
BlobQuotaExceeded,
CallMetric,
Config,
fetchBuffer,
type FileUpload,
OneMB,
OnEvent,
readBuffer,
type StorageProvider,
@@ -16,6 +18,8 @@ import {
} from '../../base';
import { QuotaService } from '../../core/quota';
const REMOTE_BLOB_MAX_BYTES = 20 * OneMB;
@Injectable()
export class CopilotStorage {
public provider!: StorageProvider;
@@ -88,9 +92,8 @@ export class CopilotStorage {
@CallMetric('ai', 'blob_proxy_remote_url')
async handleRemoteLink(userId: string, workspaceId: string, link: string) {
const response = await fetch(link);
const buffer = new Uint8Array(await response.arrayBuffer());
const { buffer } = await fetchBuffer(link, REMOTE_BLOB_MAX_BYTES, 'image/');
const filename = createHash('sha256').update(buffer).digest('base64url');
return this.put(userId, workspaceId, filename, Buffer.from(buffer));
return this.put(userId, workspaceId, filename, buffer);
}
}

View File

@@ -13,6 +13,7 @@ import { ConnectedAccount } from '@prisma/client';
import type { Request, Response } from 'express';
import {
ActionForbidden,
Config,
InvalidAuthState,
InvalidOauthCallbackState,
@@ -57,6 +58,9 @@ export class OAuthController {
if (!unknownProviderName) {
throw new MissingOauthQueryParameter({ name: 'provider' });
}
if (!clientNonce) {
throw new MissingOauthQueryParameter({ name: 'client_nonce' });
}
const providerName = OAuthProviderName[unknownProviderName];
const provider = this.providerFactory.get(providerName);
@@ -67,6 +71,10 @@ export class OAuthController {
const pkce = provider.requiresPkce ? this.oauth.createPkcePair() : null;
if (redirectUri && !this.url.isAllowedRedirectUri(redirectUri)) {
throw new ActionForbidden();
}
const state = await this.oauth.saveOAuthState({
provider: providerName,
redirectUri,
@@ -173,16 +181,6 @@ export class OAuthController {
);
}
// TODO(@fengmk2): clientNonce should be required after the client version >= 0.21.0
if (
state.clientNonce &&
state.clientNonce !== clientNonce &&
// apple sign in with nonce stored in id token
state.provider !== OAuthProviderName.Apple
) {
throw new InvalidAuthState();
}
if (!state.provider) {
throw new MissingOauthQueryParameter({ name: 'provider' });
}
@@ -193,6 +191,13 @@ export class OAuthController {
throw new UnknownOauthProvider({ name: state.provider ?? 'unknown' });
}
if (
state.provider !== OAuthProviderName.Apple &&
(!clientNonce || !state.clientNonce || state.clientNonce !== clientNonce)
) {
throw new InvalidAuthState();
}
let tokens: Tokens;
try {
tokens = await provider.getToken(code, state);
@@ -221,7 +226,7 @@ export class OAuthController {
state.provider === OAuthProviderName.Apple &&
(!state.client || state.client === 'web')
) {
return res.redirect(this.url.link(state.redirectUri ?? '/'));
return this.url.safeRedirect(res, state.redirectUri ?? '/');
}
res.send({

View File

@@ -1,38 +1,17 @@
import {
Context,
registerEnumType,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import type { Request } from 'express';
import semver from 'semver';
import { registerEnumType, ResolveField, Resolver } from '@nestjs/graphql';
import { getClientVersionFromRequest } from '../../base';
import { ServerConfigType } from '../../core/config/types';
import { OAuthProviderName } from './config';
import { OAuthProviderFactory } from './factory';
registerEnumType(OAuthProviderName, { name: 'OAuthProviderType' });
const APPLE_OAUTH_PROVIDER_MIN_VERSION = new semver.Range('>=0.22.0', {
includePrerelease: true,
});
@Resolver(() => ServerConfigType)
export class OAuthResolver {
constructor(private readonly factory: OAuthProviderFactory) {}
@ResolveField(() => [OAuthProviderName])
oauthProviders(@Context() ctx: { req: Request }) {
// Apple oauth provider is not supported in client version < 0.22.0
const providers = this.factory.providers;
if (providers.includes(OAuthProviderName.Apple)) {
const version = getClientVersionFromRequest(ctx.req);
if (!version || !APPLE_OAUTH_PROVIDER_MIN_VERSION.test(version)) {
return providers.filter(p => p !== OAuthProviderName.Apple);
}
}
return providers;
oauthProviders() {
return this.factory.providers;
}
}

View File

@@ -7,10 +7,23 @@ import {
Req,
Res,
} from '@nestjs/common';
import type { Request, Response } from 'express';
import type {
Request as ExpressRequest,
Response as ExpressResponse,
} from 'express';
import { HTMLRewriter } from 'htmlrewriter';
import { BadRequest, Cache, URLHelper, UseNamedGuard } from '../../base';
import {
BadRequest,
Cache,
readResponseBufferWithLimit,
ResponseTooLargeError,
safeFetch,
SsrfBlockedError,
type SSRFBlockReason,
URLHelper,
UseNamedGuard,
} from '../../base';
import { Public } from '../../core/auth';
import { WorkerService } from './service';
import type { LinkPreviewRequest, LinkPreviewResponse } from './types';
@@ -28,6 +41,25 @@ import { decodeWithCharset } from './utils/encoding';
// cache for 30 minutes
const CACHE_TTL = 1000 * 60 * 30;
const MAX_REDIRECTS = 3;
const FETCH_TIMEOUT_MS = 10_000;
const IMAGE_PROXY_MAX_BYTES = 10 * 1024 * 1024;
const LINK_PREVIEW_MAX_BYTES = 2 * 1024 * 1024;
/**
 * Translate an SSRF block reason into the message used for the
 * `BadRequest` thrown by the proxy / preview endpoints.
 *
 * Several distinct reasons (protocol, credentials, hostname, IP, malformed
 * URL) are collapsed into the generic "Invalid URL".
 *
 * @param reason - Block reason carried by an `SsrfBlockedError`.
 * @returns A non-empty message; never `undefined`, even for a reason value
 *   that is not part of `SSRFBlockReason`.
 */
function toBadRequestReason(reason: SSRFBlockReason): string {
  switch (reason) {
    case 'disallowed_protocol':
    case 'url_has_credentials':
    case 'blocked_hostname':
    case 'blocked_ip':
    case 'invalid_url':
      return 'Invalid URL';
    case 'unresolvable_hostname':
      return 'Failed to resolve hostname';
    case 'too_many_redirects':
      return 'Too many redirects';
    default:
      // Call sites read `reason` from error data and cast it to
      // `SSRFBlockReason`, so an unrecognised value can arrive at runtime.
      // Fall back to the generic message instead of returning `undefined`.
      return 'Invalid URL';
  }
}
@Public()
@UseNamedGuard('selfhost')
@@ -45,14 +77,33 @@ export class WorkerController {
return this.service.allowedOrigins;
}
/**
 * CORS preflight for `/image-proxy`.
 *
 * Replies `204 No Content` with the CORS headers derived from the request's
 * `Origin` header, advertising `GET, OPTIONS` and allowing the
 * `Content-Type` request header.
 */
@Options('/image-proxy')
imageProxyOption(
  @Req() request: ExpressRequest,
  @Res() resp: ExpressResponse
) {
  const preflightHeaders = {
    ...getCorsHeaders(request.headers.origin),
    'Access-Control-Allow-Methods': 'GET, OPTIONS',
    'Access-Control-Allow-Headers': 'Content-Type',
  };
  return resp.status(204).header(preflightHeaders).send();
}
@Get('/image-proxy')
async imageProxy(@Req() req: Request, @Res() resp: Response) {
const origin = req.headers.origin ?? '';
async imageProxy(@Req() req: ExpressRequest, @Res() resp: ExpressResponse) {
const origin = req.headers.origin;
const referer = req.headers.referer;
if (
(origin && !isOriginAllowed(origin, this.allowedOrigin)) ||
(referer && !isRefererAllowed(referer, this.allowedOrigin))
) {
const originAllowed = origin
? isOriginAllowed(origin, this.allowedOrigin)
: false;
const refererAllowed = referer
? isRefererAllowed(referer, this.allowedOrigin)
: false;
if (!originAllowed && !refererAllowed) {
this.logger.error('Invalid Origin', 'ERROR', { origin, referer });
throw new BadRequest('Invalid header');
}
@@ -79,24 +130,66 @@ export class WorkerController {
return resp
.status(200)
.header({
'Access-Control-Allow-Origin': origin,
Vary: 'Origin',
...getCorsHeaders(origin),
...(origin ? { Vary: 'Origin' } : {}),
'Access-Control-Allow-Methods': 'GET',
'Content-Type': 'image/*',
})
.send(buffer);
}
const response = await fetch(
new Request(targetURL.toString(), {
method: 'GET',
headers: cloneHeader(req.headers),
})
);
let response: Response;
try {
response = await safeFetch(
targetURL.toString(),
{ method: 'GET', headers: cloneHeader(req.headers) },
{ timeoutMs: FETCH_TIMEOUT_MS, maxRedirects: MAX_REDIRECTS }
);
} catch (error) {
if (error instanceof SsrfBlockedError) {
const reason = error.data?.reason as SSRFBlockReason | undefined;
this.logger.warn('Blocked image proxy target', {
url: imageURL,
reason,
context: (error as any).context,
});
throw new BadRequest(toBadRequestReason(reason ?? 'invalid_url'));
}
if (error instanceof ResponseTooLargeError) {
this.logger.warn('Image proxy response too large', {
url: imageURL,
limitBytes: error.data?.limitBytes,
receivedBytes: error.data?.receivedBytes,
});
throw new BadRequest('Response too large');
}
this.logger.error('Failed to fetch image', {
origin,
url: imageURL,
error,
});
throw new BadRequest('Failed to fetch image');
}
if (response.ok) {
const contentType = response.headers.get('Content-Type');
if (contentType?.startsWith('image/')) {
const buffer = Buffer.from(await response.arrayBuffer());
let buffer: Buffer;
try {
buffer = await readResponseBufferWithLimit(
response,
IMAGE_PROXY_MAX_BYTES
);
} catch (error) {
if (error instanceof ResponseTooLargeError) {
this.logger.warn('Image proxy response too large', {
url: imageURL,
limitBytes: error.data?.limitBytes,
receivedBytes: error.data?.receivedBytes,
});
throw new BadRequest('Response too large');
}
throw error;
}
await this.cache.set(cachedUrl, buffer.toString('base64'), {
ttl: CACHE_TTL,
});
@@ -104,8 +197,8 @@ export class WorkerController {
return resp
.status(200)
.header({
'Access-Control-Allow-Origin': origin ?? 'null',
Vary: 'Origin',
...getCorsHeaders(origin),
...(origin ? { Vary: 'Origin' } : {}),
'Access-Control-Allow-Methods': 'GET',
'Content-Type': contentType,
'Content-Disposition': contentDisposition,
@@ -124,17 +217,20 @@ export class WorkerController {
this.logger.error('Failed to fetch image', {
origin,
url: imageURL,
status: resp.status,
status: response.status,
});
throw new BadRequest('Failed to fetch image');
}
}
@Options('/link-preview')
linkPreviewOption(@Req() request: Request, @Res() resp: Response) {
linkPreviewOption(
@Req() request: ExpressRequest,
@Res() resp: ExpressResponse
) {
const origin = request.headers.origin;
return resp
.status(200)
.status(204)
.header({
...getCorsHeaders(origin),
'Access-Control-Allow-Methods': 'POST, OPTIONS',
@@ -145,15 +241,18 @@ export class WorkerController {
@Post('/link-preview')
async linkPreview(
@Req() request: Request,
@Res() resp: Response
): Promise<Response> {
@Req() request: ExpressRequest,
@Res() resp: ExpressResponse
): Promise<ExpressResponse> {
const origin = request.headers.origin;
const referer = request.headers.referer;
if (
(origin && !isOriginAllowed(origin, this.allowedOrigin)) ||
(referer && !isRefererAllowed(referer, this.allowedOrigin))
) {
const originAllowed = origin
? isOriginAllowed(origin, this.allowedOrigin)
: false;
const refererAllowed = referer
? isRefererAllowed(referer, this.allowedOrigin)
: false;
if (!originAllowed && !refererAllowed) {
this.logger.error('Invalid Origin', { origin, referer });
throw new BadRequest('Invalid header');
}
@@ -183,9 +282,13 @@ export class WorkerController {
.send(cachedResponse);
}
const response = await fetch(targetURL, {
headers: cloneHeader(request.headers),
});
const method: 'GET' | 'HEAD' = requestBody?.head ? 'HEAD' : 'GET';
const response = await safeFetch(
targetURL.toString(),
{ method, headers: cloneHeader(request.headers) },
{ timeoutMs: FETCH_TIMEOUT_MS, maxRedirects: MAX_REDIRECTS }
);
this.logger.debug('Fetched URL', {
origin,
url: targetURL,
@@ -211,7 +314,12 @@ export class WorkerController {
};
if (response.body) {
const resp = await decodeWithCharset(response, res);
const body = await readResponseBufferWithLimit(
response,
LINK_PREVIEW_MAX_BYTES
);
const limitedResponse = new Response(body, response);
const resp = await decodeWithCharset(limitedResponse, res);
const rewriter = new HTMLRewriter()
.on('meta', {
@@ -287,7 +395,11 @@ export class WorkerController {
{
// head default path of favicon
const faviconUrl = new URL('/favicon.ico?v=2', response.url);
const faviconResponse = await fetch(faviconUrl, { method: 'HEAD' });
const faviconResponse = await safeFetch(
faviconUrl.toString(),
{ method: 'HEAD' },
{ timeoutMs: FETCH_TIMEOUT_MS, maxRedirects: MAX_REDIRECTS }
);
if (faviconResponse.ok) {
appendUrl(faviconUrl.toString(), res.favicons);
}
@@ -311,6 +423,25 @@ export class WorkerController {
})
.send(json);
} catch (error) {
if (error instanceof SsrfBlockedError) {
const reason = error.data?.reason as SSRFBlockReason | undefined;
this.logger.warn('Blocked link preview target', {
origin,
url: requestBody?.url,
reason,
context: (error as any).context,
});
throw new BadRequest(toBadRequestReason(reason ?? 'invalid_url'));
}
if (error instanceof ResponseTooLargeError) {
this.logger.warn('Link preview response too large', {
origin,
url: requestBody?.url,
limitBytes: error.data?.limitBytes,
receivedBytes: error.data?.receivedBytes,
});
throw new BadRequest('Response too large');
}
this.logger.error('Error fetching URL', {
origin,
url: targetURL,

View File

@@ -27,7 +27,6 @@ input AddContextDocInput {
}
input AddContextFileInput {
blobId: String
contextId: String!
}
@@ -798,7 +797,7 @@ type EditorType {
name: String!
}
union ErrorDataUnion = AlreadyInSpaceDataType | BlobNotFoundDataType | CalendarProviderRequestErrorDataType | CopilotContextFileNotSupportedDataType | CopilotDocNotFoundDataType | CopilotFailedToAddWorkspaceFileEmbeddingDataType | CopilotFailedToGenerateEmbeddingDataType | CopilotFailedToMatchContextDataType | CopilotFailedToMatchGlobalContextDataType | CopilotFailedToModifyContextDataType | CopilotInvalidContextDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderNotSupportedDataType | CopilotProviderSideErrorDataType | DocActionDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | DocUpdateBlockedDataType | ExpectToGrantDocUserRolesDataType | ExpectToRevokeDocUserRolesDataType | ExpectToUpdateDocUserRoleDataType | GraphqlBadRequestDataType | HttpRequestErrorDataType | InvalidAppConfigDataType | InvalidAppConfigInputDataType | InvalidEmailDataType | InvalidHistoryTimestampDataType | InvalidIndexerInputDataType | InvalidLicenseToActivateDataType | InvalidLicenseUpdateParamsDataType | InvalidOauthCallbackCodeDataType | InvalidOauthResponseDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | InvalidSearchProviderRequestDataType | MemberNotFoundInSpaceDataType | MentionUserDocAccessDeniedDataType | MissingOauthQueryParameterDataType | NoCopilotProviderAvailableDataType | NoMoreSeatDataType | NotInSpaceDataType | QueryTooLongDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SpaceAccessDeniedDataType | SpaceNotFoundDataType | SpaceOwnerNotFoundDataType | SpaceShouldHaveOnlyOneOwnerDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | UnsupportedClientVersionDataType | UnsupportedSubscriptionPlanDataType | ValidationErrorDataType | VersionRejectedDataType | WorkspacePermissionNotFoundDataType | WrongSignInCredentialsDataType
union ErrorDataUnion = AlreadyInSpaceDataType | BlobNotFoundDataType | CalendarProviderRequestErrorDataType | CopilotContextFileNotSupportedDataType | CopilotDocNotFoundDataType | CopilotFailedToAddWorkspaceFileEmbeddingDataType | CopilotFailedToGenerateEmbeddingDataType | CopilotFailedToMatchContextDataType | CopilotFailedToMatchGlobalContextDataType | CopilotFailedToModifyContextDataType | CopilotInvalidContextDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderNotSupportedDataType | CopilotProviderSideErrorDataType | DocActionDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | DocUpdateBlockedDataType | ExpectToGrantDocUserRolesDataType | ExpectToRevokeDocUserRolesDataType | ExpectToUpdateDocUserRoleDataType | GraphqlBadRequestDataType | HttpRequestErrorDataType | InvalidAppConfigDataType | InvalidAppConfigInputDataType | InvalidEmailDataType | InvalidHistoryTimestampDataType | InvalidIndexerInputDataType | InvalidLicenseToActivateDataType | InvalidLicenseUpdateParamsDataType | InvalidOauthCallbackCodeDataType | InvalidOauthResponseDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | InvalidSearchProviderRequestDataType | MemberNotFoundInSpaceDataType | MentionUserDocAccessDeniedDataType | MissingOauthQueryParameterDataType | NoCopilotProviderAvailableDataType | NoMoreSeatDataType | NotInSpaceDataType | QueryTooLongDataType | ResponseTooLargeErrorDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SpaceAccessDeniedDataType | SpaceNotFoundDataType | SpaceOwnerNotFoundDataType | SpaceShouldHaveOnlyOneOwnerDataType | SsrfBlockedErrorDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | UnsupportedClientVersionDataType | UnsupportedSubscriptionPlanDataType | ValidationErrorDataType | VersionRejectedDataType | WorkspacePermissionNotFoundDataType | 
WrongSignInCredentialsDataType
enum ErrorNames {
ACCESS_DENIED
@@ -912,6 +911,7 @@ enum ErrorNames {
PASSWORD_REQUIRED
QUERY_TOO_LONG
REPLY_NOT_FOUND
RESPONSE_TOO_LARGE_ERROR
RUNTIME_CONFIG_NOT_FOUND
SAME_EMAIL_PROVIDED
SAME_SUBSCRIPTION_RECURRING
@@ -921,6 +921,7 @@ enum ErrorNames {
SPACE_NOT_FOUND
SPACE_OWNER_NOT_FOUND
SPACE_SHOULD_HAVE_ONLY_ONE_OWNER
SSRF_BLOCKED_ERROR
STORAGE_QUOTA_EXCEEDED
SUBSCRIPTION_ALREADY_EXISTS
SUBSCRIPTION_EXPIRED
@@ -1453,14 +1454,12 @@ type Mutation {
forkCopilotSession(options: ForkChatSessionInput!): String!
generateLicenseKey(sessionId: String!): String!
generateUserAccessToken(input: GenerateAccessTokenInput!): RevealedAccessToken!
getBlobUploadPartUrl(key: String!, partNumber: Int!, uploadId: String!, workspaceId: String!): BlobUploadPart! @deprecated(reason: "use WorkspaceType.blobUploadPartUrl")
grantDocUserRoles(input: GrantDocUserRolesInput!): Boolean!
grantMember(permission: Permission!, userId: String!, workspaceId: String!): Boolean!
"""import users"""
importUsers(input: ImportUsersInput!): [UserImportResultType!]!
installLicense(license: Upload!, workspaceId: String!): License!
inviteBatch(emails: [String!]!, sendInviteMail: Boolean @deprecated(reason: "never used"), workspaceId: String!): [InviteResult!]! @deprecated(reason: "use [inviteMembers] instead")
inviteMembers(emails: [String!]!, workspaceId: String!): [InviteResult!]!
leaveWorkspace(sendLeaveMail: Boolean @deprecated(reason: "no used anymore"), workspaceId: String!, workspaceName: String @deprecated(reason: "no longer used")): Boolean!
linkCalendarAccount(input: LinkCalendarAccountInput!): String!
@@ -1468,7 +1467,6 @@ type Mutation {
"""mention user in a doc"""
mentionUser(input: MentionInput!): ID!
publishDoc(docId: String!, mode: PublicDocMode = Page, workspaceId: String!): DocType!
publishPage(mode: PublicDocMode = Page, pageId: String!, workspaceId: String!): DocType! @deprecated(reason: "use publishDoc instead")
"""queue workspace doc embedding"""
queueWorkspaceEmbedding(docId: [String!]!, workspaceId: String!): Boolean!
@@ -1510,12 +1508,10 @@ type Mutation {
resolveComment(input: CommentResolveInput!): Boolean!
resumeSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, workspaceId: String): SubscriptionType!
retryAudioTranscription(jobId: String!, workspaceId: String!): TranscriptionResultType
revoke(userId: String!, workspaceId: String!): Boolean! @deprecated(reason: "use [revokeMember] instead")
revokeDocUserRoles(input: RevokeDocUserRoleInput!): Boolean!
revokeInviteLink(workspaceId: String!): Boolean!
revokeMember(userId: String!, workspaceId: String!): Boolean!
revokePublicDoc(docId: String!, workspaceId: String!): DocType!
revokePublicPage(docId: String!, workspaceId: String!): DocType! @deprecated(reason: "use revokePublicDoc instead")
revokeUserAccessToken(id: String!): Boolean!
sendChangeEmail(callbackUrl: String!, email: String): Boolean!
sendChangePasswordEmail(callbackUrl: String!, email: String @deprecated(reason: "fetched from signed in user")): Boolean!
@@ -1574,9 +1570,6 @@ type Mutation {
"""Upload a comment attachment and return the access url"""
uploadCommentAttachment(attachment: Upload!, docId: String!, workspaceId: String!): String!
"""validate app configuration"""
validateAppConfig(updates: [UpdateAppConfigInput!]!): [AppConfigValidateResult!]! @deprecated(reason: "use Query.validateAppConfig")
verifyEmail(token: String!): Boolean!
}
@@ -1754,8 +1747,6 @@ type PublicUserType {
}
type Query {
accessTokens: [AccessToken!]! @deprecated(reason: "use currentUser.accessTokens")
"""Get workspace detail for admin"""
adminWorkspace(id: String!): AdminWorkspace
@@ -1770,7 +1761,6 @@ type Query {
"""Apply updates to a doc using LLM and return the merged markdown."""
applyDocUpdates(docId: String!, op: String!, updates: String!, workspaceId: String!): String! @deprecated(reason: "use Mutation.applyDocUpdates")
collectAllBlobSizes: WorkspaceBlobSizes! @deprecated(reason: "use `user.quotaUsage` instead")
"""Get current user"""
currentUser: UserType
@@ -1779,12 +1769,6 @@ type Query {
"""get workspace invitation info"""
getInviteInfo(inviteId: String!): InvitationType!
"""Get is admin of workspace"""
isAdmin(workspaceId: String!): Boolean! @deprecated(reason: "use WorkspaceType[role] instead")
"""Get is owner of workspace"""
isOwner(workspaceId: String!): Boolean! @deprecated(reason: "use WorkspaceType[role] instead")
"""List all copilot prompts"""
listCopilotPrompts: [CopilotPromptType!]!
prices: [SubscriptionPrice!]!
@@ -1918,6 +1902,11 @@ input ReplyUpdateInput {
id: ID!
}
type ResponseTooLargeErrorDataType {
limitBytes: Int!
receivedBytes: Int!
}
type RevealedAccessToken {
createdAt: DateTime!
expiresAt: DateTime
@@ -2104,6 +2093,10 @@ type SpaceShouldHaveOnlyOneOwnerDataType {
spaceId: String!
}
type SsrfBlockedErrorDataType {
reason: String!
}
type StreamObject {
args: JSON
result: JSON
@@ -2405,10 +2398,6 @@ type VersionRejectedDataType {
version: String!
}
type WorkspaceBlobSizes {
size: SafeInt!
}
input WorkspaceCalendarItemInput {
colorOverride: String
sortOrder: Int
@@ -2591,9 +2580,6 @@ type WorkspaceType {
"""Get public docs of a workspace"""
publicDocs: [DocType!]!
"""Get public page of a workspace by page id."""
publicPage(pageId: String!): DocType @deprecated(reason: "use [WorkspaceType.doc] instead")
"""quota of workspace"""
quota: WorkspaceQuotaType!

View File

@@ -63,7 +63,6 @@ export interface AddContextDocInput {
}
export interface AddContextFileInput {
blobId?: InputMaybe<Scalars['String']['input']>;
contextId: Scalars['String']['input'];
}
@@ -978,12 +977,14 @@ export type ErrorDataUnion =
| NoMoreSeatDataType
| NotInSpaceDataType
| QueryTooLongDataType
| ResponseTooLargeErrorDataType
| RuntimeConfigNotFoundDataType
| SameSubscriptionRecurringDataType
| SpaceAccessDeniedDataType
| SpaceNotFoundDataType
| SpaceOwnerNotFoundDataType
| SpaceShouldHaveOnlyOneOwnerDataType
| SsrfBlockedErrorDataType
| SubscriptionAlreadyExistsDataType
| SubscriptionNotExistsDataType
| SubscriptionPlanNotFoundDataType
@@ -1107,6 +1108,7 @@ export enum ErrorNames {
PASSWORD_REQUIRED = 'PASSWORD_REQUIRED',
QUERY_TOO_LONG = 'QUERY_TOO_LONG',
REPLY_NOT_FOUND = 'REPLY_NOT_FOUND',
RESPONSE_TOO_LARGE_ERROR = 'RESPONSE_TOO_LARGE_ERROR',
RUNTIME_CONFIG_NOT_FOUND = 'RUNTIME_CONFIG_NOT_FOUND',
SAME_EMAIL_PROVIDED = 'SAME_EMAIL_PROVIDED',
SAME_SUBSCRIPTION_RECURRING = 'SAME_SUBSCRIPTION_RECURRING',
@@ -1116,6 +1118,7 @@ export enum ErrorNames {
SPACE_NOT_FOUND = 'SPACE_NOT_FOUND',
SPACE_OWNER_NOT_FOUND = 'SPACE_OWNER_NOT_FOUND',
SPACE_SHOULD_HAVE_ONLY_ONE_OWNER = 'SPACE_SHOULD_HAVE_ONLY_ONE_OWNER',
SSRF_BLOCKED_ERROR = 'SSRF_BLOCKED_ERROR',
STORAGE_QUOTA_EXCEEDED = 'STORAGE_QUOTA_EXCEEDED',
SUBSCRIPTION_ALREADY_EXISTS = 'SUBSCRIPTION_ALREADY_EXISTS',
SUBSCRIPTION_EXPIRED = 'SUBSCRIPTION_EXPIRED',
@@ -1622,23 +1625,17 @@ export interface Mutation {
forkCopilotSession: Scalars['String']['output'];
generateLicenseKey: Scalars['String']['output'];
generateUserAccessToken: RevealedAccessToken;
/** @deprecated use WorkspaceType.blobUploadPartUrl */
getBlobUploadPartUrl: BlobUploadPart;
grantDocUserRoles: Scalars['Boolean']['output'];
grantMember: Scalars['Boolean']['output'];
/** import users */
importUsers: Array<UserImportResultType>;
installLicense: License;
/** @deprecated use [inviteMembers] instead */
inviteBatch: Array<InviteResult>;
inviteMembers: Array<InviteResult>;
leaveWorkspace: Scalars['Boolean']['output'];
linkCalendarAccount: Scalars['String']['output'];
/** mention user in a doc */
mentionUser: Scalars['ID']['output'];
publishDoc: DocType;
/** @deprecated use publishDoc instead */
publishPage: DocType;
/** queue workspace doc embedding */
queueWorkspaceEmbedding: Scalars['Boolean']['output'];
/** mark all notifications as read */
@@ -1668,14 +1665,10 @@ export interface Mutation {
resolveComment: Scalars['Boolean']['output'];
resumeSubscription: SubscriptionType;
retryAudioTranscription: Maybe<TranscriptionResultType>;
/** @deprecated use [revokeMember] instead */
revoke: Scalars['Boolean']['output'];
revokeDocUserRoles: Scalars['Boolean']['output'];
revokeInviteLink: Scalars['Boolean']['output'];
revokeMember: Scalars['Boolean']['output'];
revokePublicDoc: DocType;
/** @deprecated use revokePublicDoc instead */
revokePublicPage: DocType;
revokeUserAccessToken: Scalars['Boolean']['output'];
sendChangeEmail: Scalars['Boolean']['output'];
sendChangePasswordEmail: Scalars['Boolean']['output'];
@@ -1720,11 +1713,6 @@ export interface Mutation {
uploadAvatar: UserType;
/** Upload a comment attachment and return the access url */
uploadCommentAttachment: Scalars['String']['output'];
/**
* validate app configuration
* @deprecated use Query.validateAppConfig
*/
validateAppConfig: Array<AppConfigValidateResult>;
verifyEmail: Scalars['Boolean']['output'];
}
@@ -1925,13 +1913,6 @@ export interface MutationGenerateUserAccessTokenArgs {
input: GenerateAccessTokenInput;
}
export interface MutationGetBlobUploadPartUrlArgs {
key: Scalars['String']['input'];
partNumber: Scalars['Int']['input'];
uploadId: Scalars['String']['input'];
workspaceId: Scalars['String']['input'];
}
export interface MutationGrantDocUserRolesArgs {
input: GrantDocUserRolesInput;
}
@@ -1951,12 +1932,6 @@ export interface MutationInstallLicenseArgs {
workspaceId: Scalars['String']['input'];
}
export interface MutationInviteBatchArgs {
emails: Array<Scalars['String']['input']>;
sendInviteMail?: InputMaybe<Scalars['Boolean']['input']>;
workspaceId: Scalars['String']['input'];
}
export interface MutationInviteMembersArgs {
emails: Array<Scalars['String']['input']>;
workspaceId: Scalars['String']['input'];
@@ -1982,12 +1957,6 @@ export interface MutationPublishDocArgs {
workspaceId: Scalars['String']['input'];
}
export interface MutationPublishPageArgs {
mode?: InputMaybe<PublicDocMode>;
pageId: Scalars['String']['input'];
workspaceId: Scalars['String']['input'];
}
export interface MutationQueueWorkspaceEmbeddingArgs {
docId: Array<Scalars['String']['input']>;
workspaceId: Scalars['String']['input'];
@@ -2052,11 +2021,6 @@ export interface MutationRetryAudioTranscriptionArgs {
workspaceId: Scalars['String']['input'];
}
export interface MutationRevokeArgs {
userId: Scalars['String']['input'];
workspaceId: Scalars['String']['input'];
}
export interface MutationRevokeDocUserRolesArgs {
input: RevokeDocUserRoleInput;
}
@@ -2075,11 +2039,6 @@ export interface MutationRevokePublicDocArgs {
workspaceId: Scalars['String']['input'];
}
export interface MutationRevokePublicPageArgs {
docId: Scalars['String']['input'];
workspaceId: Scalars['String']['input'];
}
export interface MutationRevokeUserAccessTokenArgs {
id: Scalars['String']['input'];
}
@@ -2212,10 +2171,6 @@ export interface MutationUploadCommentAttachmentArgs {
workspaceId: Scalars['String']['input'];
}
export interface MutationValidateAppConfigArgs {
updates: Array<UpdateAppConfigInput>;
}
export interface MutationVerifyEmailArgs {
token: Scalars['String']['input'];
}
@@ -2401,8 +2356,6 @@ export interface PublicUserType {
export interface Query {
__typename?: 'Query';
/** @deprecated use currentUser.accessTokens */
accessTokens: Array<AccessToken>;
/** Get workspace detail for admin */
adminWorkspace: Maybe<AdminWorkspace>;
/** List workspaces for admin */
@@ -2416,23 +2369,11 @@ export interface Query {
* @deprecated use Mutation.applyDocUpdates
*/
applyDocUpdates: Scalars['String']['output'];
/** @deprecated use `user.quotaUsage` instead */
collectAllBlobSizes: WorkspaceBlobSizes;
/** Get current user */
currentUser: Maybe<UserType>;
error: ErrorDataUnion;
/** get workspace invitation info */
getInviteInfo: InvitationType;
/**
* Get is admin of workspace
* @deprecated use WorkspaceType[role] instead
*/
isAdmin: Scalars['Boolean']['output'];
/**
* Get is owner of workspace
* @deprecated use WorkspaceType[role] instead
*/
isOwner: Scalars['Boolean']['output'];
/** List all copilot prompts */
listCopilotPrompts: Array<CopilotPromptType>;
prices: Array<SubscriptionPrice>;
@@ -2494,14 +2435,6 @@ export interface QueryGetInviteInfoArgs {
inviteId: Scalars['String']['input'];
}
export interface QueryIsAdminArgs {
workspaceId: Scalars['String']['input'];
}
export interface QueryIsOwnerArgs {
workspaceId: Scalars['String']['input'];
}
export interface QueryPublicUserByIdArgs {
id: Scalars['String']['input'];
}
@@ -2630,6 +2563,12 @@ export interface ReplyUpdateInput {
id: Scalars['ID']['input'];
}
export interface ResponseTooLargeErrorDataType {
__typename?: 'ResponseTooLargeErrorDataType';
limitBytes: Scalars['Int']['output'];
receivedBytes: Scalars['Int']['output'];
}
export interface RevealedAccessToken {
__typename?: 'RevealedAccessToken';
createdAt: Scalars['DateTime']['output'];
@@ -2812,6 +2751,11 @@ export interface SpaceShouldHaveOnlyOneOwnerDataType {
spaceId: Scalars['String']['output'];
}
export interface SsrfBlockedErrorDataType {
__typename?: 'SsrfBlockedErrorDataType';
reason: Scalars['String']['output'];
}
export interface StreamObject {
__typename?: 'StreamObject';
args: Maybe<Scalars['JSON']['output']>;
@@ -3126,11 +3070,6 @@ export interface VersionRejectedDataType {
version: Scalars['String']['output'];
}
export interface WorkspaceBlobSizes {
__typename?: 'WorkspaceBlobSizes';
size: Scalars['SafeInt']['output'];
}
export interface WorkspaceCalendarItemInput {
colorOverride?: InputMaybe<Scalars['String']['input']>;
sortOrder?: InputMaybe<Scalars['Int']['input']>;
@@ -3308,11 +3247,6 @@ export interface WorkspaceType {
public: Scalars['Boolean']['output'];
/** Get public docs of a workspace */
publicDocs: Array<DocType>;
/**
* Get public page of a workspace by page id.
* @deprecated use [WorkspaceType.doc] instead
*/
publicPage: Maybe<DocType>;
/** quota of workspace */
quota: WorkspaceQuotaType;
/** Get recently updated docs of a workspace */
@@ -3378,10 +3312,6 @@ export interface WorkspaceTypePageMetaArgs {
pageId: Scalars['String']['input'];
}
export interface WorkspaceTypePublicPageArgs {
pageId: Scalars['String']['input'];
}
export interface WorkspaceTypeRecentlyUpdatedDocsArgs {
pagination: PaginationInput;
}

View File

@@ -0,0 +1,66 @@
import { describe, expect, test } from 'vitest';
import { isAllowedRedirectTarget } from '../redirect-allowlist';
// Unit tests for `isAllowedRedirectTarget`: same-host and trusted-domain
// acceptance, plus the spoofing shapes it must reject.
describe('redirect allowlist', () => {
  test('allows same hostname', () => {
    expect(
      isAllowedRedirectTarget('https://self.example.com/path', {
        currentHostname: 'self.example.com',
      })
    ).toBe(true);
  });

  test('allows trusted domains and subdomains', () => {
    expect(
      isAllowedRedirectTarget('https://github.com/toeverything/AFFiNE', {
        currentHostname: 'self.example.com',
      })
    ).toBe(true);
    expect(
      isAllowedRedirectTarget('https://sub.github.com/foo', {
        currentHostname: 'self.example.com',
      })
    ).toBe(true);
  });

  test('blocks look-alike domains', () => {
    expect(
      isAllowedRedirectTarget('https://evilgithub.com', {
        currentHostname: 'self.example.com',
      })
    ).toBe(false);
  });

  test('blocks trusted domain used as a prefix of another host', () => {
    // Only a trailing `.github.com` label boundary may match; a trusted name
    // followed by an attacker-controlled suffix must not.
    expect(
      isAllowedRedirectTarget('https://github.com.evil.com/path', {
        currentHostname: 'self.example.com',
      })
    ).toBe(false);
  });

  test('blocks disallowed protocols', () => {
    expect(
      isAllowedRedirectTarget('javascript:alert(1)', {
        currentHostname: 'self.example.com',
      })
    ).toBe(false);
  });

  test('compares hostnames case-insensitively', () => {
    expect(
      isAllowedRedirectTarget('https://SELF.Example.COM/path', {
        currentHostname: 'Self.example.com',
      })
    ).toBe(true);
  });

  test('rejects input that is not an absolute URL', () => {
    expect(
      isAllowedRedirectTarget('not a url', {
        currentHostname: 'self.example.com',
      })
    ).toBe(false);
    expect(
      isAllowedRedirectTarget('/relative/path', {
        currentHostname: 'self.example.com',
      })
    ).toBe(false);
  });

  test('handles port and trailing dot', () => {
    expect(
      isAllowedRedirectTarget('https://github.com:8443', {
        currentHostname: 'self.example.com',
      })
    ).toBe(true);
    expect(
      isAllowedRedirectTarget('https://affine.pro./', {
        currentHostname: 'self.example.com',
      })
    ).toBe(true);
  });

  test('blocks punycode homograph', () => {
    // "а" is Cyrillic small a (U+0430), different from Latin "a"
    expect(
      isAllowedRedirectTarget('https://аffine.pro', {
        currentHostname: 'self.example.com',
      })
    ).toBe(false);
  });
});

View File

@@ -4,6 +4,7 @@ export * from './exhaustmap-with-trailing';
export * from './fractional-indexing';
export * from './merge-updates';
export * from './object-pool';
export * from './redirect-allowlist';
export * from './stable-hash';
export * from './throw-if-aborted';
export * from './yjs-observable';

View File

@@ -0,0 +1,50 @@
/**
 * Domains that redirect-proxy pages may forward users to. Any subdomain of an
 * entry is trusted as well.
 */
export const TRUSTED_REDIRECT_DOMAINS = [
  'google.com',
  'stripe.com',
  'github.com',
  'twitter.com',
  'discord.gg',
  'youtube.com',
  't.me',
  'reddit.com',
  'affine.pro',
].map(d => d.toLowerCase());

/** URL schemes a redirect target may use; every other scheme is rejected. */
export const ALLOWED_REDIRECT_PROTOCOLS = new Set(['http:', 'https:']);

/** Lower-cases a hostname and drops a single trailing root dot. */
function normalizeHostname(hostname: string) {
  return hostname.toLowerCase().replace(/\.$/, '');
}

/** True when `hostname` is exactly `domain` or one of its subdomains. */
function hostnameMatchesDomain(hostname: string, domain: string) {
  return hostname === domain || hostname.endsWith(`.${domain}`);
}

/**
 * Decides whether a redirect-proxy page may navigate to `redirectUri`.
 *
 * The target must parse as an absolute URL with an allowed protocol, and its
 * hostname must either equal `options.currentHostname` or match a trusted
 * domain (exactly or as a subdomain). Hostname comparison is case-insensitive
 * and ignores one trailing root dot. Unparsable input is rejected.
 *
 * @param redirectUri absolute URL to validate
 * @param options.currentHostname hostname of the page performing the redirect
 * @param options.trustedDomains optional replacement for
 *   `TRUSTED_REDIRECT_DOMAINS`; entries are normalized like hostnames
 * @returns `true` when the redirect is allowed
 */
export function isAllowedRedirectTarget(
  redirectUri: string,
  options: {
    currentHostname: string;
    trustedDomains?: readonly string[];
  }
) {
  let target: URL;
  try {
    target = new URL(redirectUri);
  } catch {
    return false;
  }

  if (!ALLOWED_REDIRECT_PROTOCOLS.has(target.protocol)) {
    return false;
  }

  const hostname = normalizeHostname(target.hostname);
  if (hostname === normalizeHostname(options.currentHostname)) {
    return true;
  }

  const trustedDomains = options.trustedDomains ?? TRUSTED_REDIRECT_DOMAINS;
  return trustedDomains.some(domain => {
    const normalized = normalizeHostname(domain);
    // Skip empty entries so they can never act as a wildcard.
    return normalized !== '' && hostnameMatchesDomain(hostname, normalized);
  });
}

View File

@@ -26,7 +26,7 @@
"lodash-es": "^4.17.21",
"nanoid": "^5.1.6",
"rxjs": "^7.8.2",
"uuid": "^11.1.0",
"uuid": "^13.0.0",
"y-protocols": "^1.0.6",
"yjs": "^13.6.27"
},
@@ -36,7 +36,7 @@
"@blocksuite/affine": "workspace:*",
"fake-indexeddb": "^6.0.0",
"idb": "^8.0.0",
"socket.io-client": "^4.8.1",
"socket.io-client": "^4.8.3",
"vitest": "^3.2.4"
},
"peerDependencies": {
@@ -44,6 +44,6 @@
"@affine/graphql": "workspace:*",
"@blocksuite/affine": "workspace:*",
"idb": "^8.0.0",
"socket.io-client": "^4.7.5"
"socket.io-client": "^4.8.3"
}
}

View File

@@ -11,15 +11,40 @@
* @param init Request initialization options
* @returns Promise with the fetch Response
*/
export const affineFetch = async (
  input: RequestInfo | URL,
  init?: RequestInit
): Promise<Response> => {
  // A `Request` object carries its own method and headers; honour them when
  // `init` does not override them.
  const request =
    typeof Request !== 'undefined' && input instanceof Request ? input : null;
  const method = (init?.method ?? request?.method ?? 'GET').toUpperCase();

  // `Headers` accepts every HeadersInit shape (record, tuple list, Headers
  // instance). Object-spreading a Headers instance would silently drop all
  // of its entries.
  const headers = new Headers(init?.headers ?? request?.headers);
  headers.set('x-affine-version', BUILD_CONFIG.appVersion);

  // Only non-GET/HEAD requests carry the CSRF token read from the cookie.
  if (method !== 'GET' && method !== 'HEAD') {
    const csrfToken = getCookieValue(CSRF_COOKIE_NAME);
    if (csrfToken) {
      headers.set('x-affine-csrf-token', csrfToken);
    }
  }

  return fetch(input, { ...init, headers });
};

/** Name of the cookie whose value is echoed in the `x-affine-csrf-token` header. */
const CSRF_COOKIE_NAME = 'affine_csrf_token';

/**
 * Looks up a cookie by exact name in `document.cookie`.
 *
 * @returns the raw (undecoded) value, `''` for a cookie without `=`, or `null`
 *   when the cookie is absent or there is no `document` (non-browser runtime).
 */
function getCookieValue(name: string) {
  if (typeof document === 'undefined') {
    return null;
  }
  const cookies = document.cookie ? document.cookie.split('; ') : [];
  for (const cookie of cookies) {
    const idx = cookie.indexOf('=');
    const key = idx === -1 ? cookie : cookie.slice(0, idx);
    if (key === name) {
      return idx === -1 ? '' : cookie.slice(idx + 1);
    }
  }
  return null;
}

View File

@@ -64,7 +64,7 @@ export function UserDropdown({ isCollapsed }: UserDropdownProps) {
const relative = useRevalidateCurrentUser();
const handleLogout = useCallback(() => {
affineFetch('/api/auth/sign-out')
affineFetch('/api/auth/sign-out', { method: 'POST' })
.then(() => {
toast.success('Logged out successfully');
return relative();

View File

@@ -28,7 +28,9 @@ object AuthInitializer {
.get(server.host + CookieStore.AFFINE_SESSION)
val userIdCookieStr = AFFiNEApp.context().dataStore
.get(server.host + CookieStore.AFFINE_USER_ID)
if (sessionCookieStr.isEmpty() || userIdCookieStr.isEmpty()) {
val csrfCookieStr = AFFiNEApp.context().dataStore
.get(server.host + CookieStore.AFFINE_CSRF_TOKEN)
if (sessionCookieStr.isEmpty() || userIdCookieStr.isEmpty() || csrfCookieStr.isEmpty()) {
Timber.i("[init] user has not signed in yet.")
return@launch
}
@@ -38,6 +40,8 @@ object AuthInitializer {
?: error("Parse session cookie fail:[ cookie = $sessionCookieStr ]"),
Cookie.parse(server, userIdCookieStr)
?: error("Parse user id cookie fail:[ cookie = $userIdCookieStr ]"),
Cookie.parse(server, csrfCookieStr)
?: error("Parse csrf token cookie fail:[ cookie = $csrfCookieStr ]"),
)
CookieStore.saveCookies(server.host, cookies)
FileTree.get()?.checkAndUploadOldLogs(server)
@@ -49,4 +53,4 @@ object AuthInitializer {
})
}
}
}

View File

@@ -43,9 +43,15 @@ class AuthPlugin : Plugin() {
launch(Dispatchers.IO) {
try {
val endpoint = call.getStringEnsure("endpoint")
val csrfToken = CookieStore.getCookie(endpoint.toHttpUrl(), CookieStore.AFFINE_CSRF_TOKEN)
val request = Request.Builder()
.url("$endpoint/api/auth/sign-out")
.get()
.post("".toRequestBody("application/json".toMediaTypeOrNull()))
.apply {
if (csrfToken != null) {
addHeader("x-affine-csrf-token", csrfToken)
}
}
.build()
OkHttp.client.newCall(request).executeAsync().use { response ->
if (response.code >= 400) {

View File

@@ -54,6 +54,7 @@ object CookieStore {
const val AFFINE_SESSION = "affine_session"
const val AFFINE_USER_ID = "affine_user_id"
const val AFFINE_CSRF_TOKEN = "affine_csrf_token"
private val _cookies = ConcurrentHashMap<String, List<Cookie>>()
@@ -68,6 +69,9 @@ object CookieStore {
AFFiNEApp.context().dataStore.set(host + AFFINE_USER_ID, it.toString())
Firebase.crashlytics.setUserId(it.value)
}
cookies.find { it.name == AFFINE_CSRF_TOKEN }?.let {
AFFiNEApp.context().dataStore.set(host + AFFINE_CSRF_TOKEN, it.toString())
}
}
}
@@ -77,4 +81,4 @@ object CookieStore {
.let { _cookies[it] }
?.find { cookie -> cookie.name == name }
?.value
}
}

View File

@@ -26,7 +26,7 @@
"react": "^19.2.1",
"react-dom": "^19.2.1",
"react-router-dom": "^6.30.3",
"uuid": "^11.1.0",
"uuid": "^13.0.0",
"webm-muxer": "^5.0.3"
},
"devDependencies": {

View File

@@ -53,7 +53,7 @@
"@sentry/react": "^9.47.1",
"@toeverything/infra": "workspace:*",
"@types/set-cookie-parser": "^2.4.10",
"@types/uuid": "^10.0.0",
"@types/uuid": "^11.0.0",
"@vitejs/plugin-react-swc": "^3.7.2",
"app-builder-lib": "^26.1.0",
"builder-util-runtime": "^9.5.0",
@@ -73,7 +73,7 @@
"semver": "^7.7.3",
"tree-kill": "^1.2.2",
"ts-node": "^10.9.2",
"uuid": "^11.1.0",
"uuid": "^13.0.0",
"vitest": "^3.2.4",
"zod": "^3.25.76"
},

View File

@@ -88,7 +88,9 @@ async function handleAffineUrl(url: string) {
if (
!method ||
(method !== 'magic-link' && method !== 'oauth') ||
(method !== 'magic-link' &&
method !== 'oauth' &&
method !== 'open-app-signin') ||
!payload
) {
logger.error('Invalid authentication url', url);

View File

@@ -2,35 +2,7 @@ import { app } from 'electron';
import { anotherHost, mainHost } from './constants';
import { openExternalSafely } from './security/open-external';
const extractRedirectTarget = (rawUrl: string) => {
try {
const parsed = new URL(rawUrl);
const redirectUri = parsed.searchParams.get('redirect_uri');
if (redirectUri) {
return redirectUri;
}
if (parsed.hash) {
const hash = parsed.hash.startsWith('#')
? parsed.hash.slice(1)
: parsed.hash;
const queryIndex = hash.indexOf('?');
if (queryIndex !== -1) {
const hashParams = new URLSearchParams(hash.slice(queryIndex + 1));
const hashRedirect = hashParams.get('redirect_uri');
if (hashRedirect) {
return hashRedirect;
}
}
}
return null;
} catch {
return null;
}
};
import { validateRedirectProxyUrl } from './security/redirect-proxy';
app.on('web-contents-created', (_, contents) => {
const isInternalUrl = (url: string) => {
@@ -80,17 +52,18 @@ app.on('web-contents-created', (_, contents) => {
console.error('[security] Failed to open external URL:', error);
});
} else if (url.includes('/redirect-proxy')) {
const redirectTarget = extractRedirectTarget(url);
if (redirectTarget) {
openExternalSafely(redirectTarget).catch(error => {
console.error('[security] Failed to open external URL:', error);
});
} else {
const result = validateRedirectProxyUrl(url);
if (!result.allow) {
console.warn(
'[security] Blocked redirect proxy with missing redirect target:',
url
`[security] Blocked redirect proxy: ${result.reason}`,
result.redirectTarget ?? url
);
return { action: 'deny' };
}
openExternalSafely(result.redirectTarget).catch(error => {
console.error('[security] Failed to open external URL:', error);
});
}
// Prevent creating new window in application
return { action: 'deny' };

View File

@@ -0,0 +1,84 @@
import { isAllowedRedirectTarget } from '@toeverything/infra/utils';
import { buildType, isDev } from '../config';
// Public web origin associated with each desktop build channel.
const API_BASE_BY_BUILD_TYPE: Record<typeof buildType, string> = {
  stable: 'https://app.affine.pro',
  beta: 'https://insider.affine.pro',
  internal: 'https://insider.affine.pro',
  canary: 'https://affine.fail',
};

/**
 * Hostname that redirect targets may always point at ("same site").
 *
 * In development, DEV_SERVER_URL (when set) takes precedence. Otherwise the
 * host of the current build channel's origin is used, falling back to stable.
 */
function resolveCurrentHostnameForRedirectAllowlist() {
  let base =
    API_BASE_BY_BUILD_TYPE[buildType] ?? API_BASE_BY_BUILD_TYPE.stable;
  const devServerBase = process.env.DEV_SERVER_URL;
  if (isDev && devServerBase) {
    base = devServerBase;
  }

  try {
    return new URL(base).hostname;
  } catch {
    // Only a malformed DEV_SERVER_URL can fail here; use the stable host.
    return 'app.affine.pro';
  }
}
/**
 * Pulls the `redirect_uri` parameter out of a redirect-proxy URL.
 *
 * The regular query string is checked first. After that, the query part of
 * a hash route (`#/path?redirect_uri=...`) is checked.
 *
 * @returns the decoded redirect target, or `null` when it is absent/empty or
 *   `rawUrl` cannot be parsed
 */
export function extractRedirectTarget(rawUrl: string) {
  let url: URL;
  try {
    url = new URL(rawUrl);
  } catch {
    return null;
  }

  const fromQuery = url.searchParams.get('redirect_uri');
  if (fromQuery) {
    return fromQuery;
  }

  // `url.hash` is '' when there is no fragment, which yields no '?' below.
  const fragment = url.hash.replace(/^#/, '');
  const queryStart = fragment.indexOf('?');
  if (queryStart === -1) {
    return null;
  }

  const fromHash = new URLSearchParams(fragment.slice(queryStart + 1)).get(
    'redirect_uri'
  );
  return fromHash || null;
}
/**
 * Outcome of validating a redirect-proxy URL.
 *
 * - `allow: true`: `redirectTarget` passed the allowlist and may be opened.
 * - `allow: false`: `reason` says why the URL was rejected. `redirectTarget`
 *   is included only when a target was extracted but is not trusted.
 */
export type RedirectProxyValidationResult =
| {
allow: true;
redirectTarget: string;
}
| {
allow: false;
reason: 'missing_redirect_target' | 'untrusted_redirect_target';
redirectTarget?: string;
};
/**
 * Checks a desktop `/redirect-proxy` URL before anything is opened externally.
 *
 * The URL is rejected when it carries no `redirect_uri`. It is also rejected
 * when the target is neither on the current build's host nor on a trusted
 * domain.
 */
export function validateRedirectProxyUrl(
  rawUrl: string
): RedirectProxyValidationResult {
  const redirectTarget = extractRedirectTarget(rawUrl);
  if (!redirectTarget) {
    return { allow: false, reason: 'missing_redirect_target' };
  }

  const trusted = isAllowedRedirectTarget(redirectTarget, {
    currentHostname: resolveCurrentHostnameForRedirectAllowlist(),
  });

  return trusted
    ? { allow: true, redirectTarget }
    : { allow: false, reason: 'untrusted_redirect_target', redirectTarget };
}

View File

@@ -0,0 +1,117 @@
import * as dns from 'node:dns/promises';
import { BlockList, isIP } from 'node:net';
const ALLOWED_PROTOCOLS = new Set(['http:', 'https:']);
// Addresses a link preview must never connect to.
const BLOCKED_IPS = new BlockList();
// IPv6 addresses must additionally fall inside this allowlist.
const ALLOWED_IPV6 = new BlockList();

/** Drops an IPv6 zone/scope suffix such as `%eth0` from an address string. */
function stripZoneId(address: string) {
  const idx = address.indexOf('%');
  return idx === -1 ? address : address.slice(0, idx);
}

// Use Node's built-in BlockList (Electron 39 ships with Node 22.x).
for (const [network, prefix] of [
  ['0.0.0.0', 8], // "this" network
  ['10.0.0.0', 8], // private
  ['100.64.0.0', 10], // CGNAT
  ['127.0.0.0', 8], // loopback
  ['169.254.0.0', 16], // link-local
  ['172.16.0.0', 12], // private
  ['192.0.0.0', 24], // IETF protocol assignments
  ['192.0.2.0', 24], // TEST-NET-1 (documentation)
  ['192.168.0.0', 16], // private
  ['198.18.0.0', 15], // benchmarking
  ['198.51.100.0', 24], // TEST-NET-2 (documentation)
  ['203.0.113.0', 24], // TEST-NET-3 (documentation)
  ['224.0.0.0', 4], // multicast
  ['240.0.0.0', 4], // reserved (includes broadcast)
] as const) {
  BLOCKED_IPS.addSubnet(network, prefix, 'ipv4');
}

BLOCKED_IPS.addAddress('::', 'ipv6');
BLOCKED_IPS.addAddress('::1', 'ipv6');
BLOCKED_IPS.addSubnet('ff00::', 8, 'ipv6'); // multicast
BLOCKED_IPS.addSubnet('fc00::', 7, 'ipv6'); // unique local
BLOCKED_IPS.addSubnet('fe80::', 10, 'ipv6'); // link-local
// Documentation prefix; it sits inside 2000::/3, so the allowlist alone
// would not reject it.
BLOCKED_IPS.addSubnet('2001:db8::', 32, 'ipv6');
ALLOWED_IPV6.addSubnet('2000::', 3, 'ipv6'); // global unicast
/**
 * Returns the dotted-quad IPv4 tail of an IPv6 literal such as
 * `::ffff:127.0.0.1`, or `null` when the address has no such tail.
 */
function extractEmbeddedIPv4FromIPv6(address: string): string | null {
  const lastColon = address.lastIndexOf(':');
  // With no colon there is no IPv6 part. With no dot after the last colon,
  // the tail cannot be an IPv4 address.
  if (lastColon === -1 || !address.includes('.', lastColon)) {
    return null;
  }
  const tail = address.slice(lastColon + 1);
  return isIP(tail) === 4 ? tail : null;
}
/**
 * Whether a literal IP address must not be contacted.
 *
 * - IPv4 is checked against the blocklist.
 * - IPv6 with a dotted IPv4 tail is judged by that IPv4 part.
 * - Other IPv6 must be global unicast and not blocklisted.
 * - Anything that is not an IP literal counts as blocked.
 */
function isBlockedIpAddress(address: string): boolean {
  const ip = stripZoneId(address);
  switch (isIP(ip)) {
    case 4:
      return BLOCKED_IPS.check(ip, 'ipv4');
    case 6: {
      const embeddedV4 = extractEmbeddedIPv4FromIPv6(ip);
      if (embeddedV4) {
        return isBlockedIpAddress(embeddedV4);
      }
      return !ALLOWED_IPV6.check(ip, 'ipv6') || BLOCKED_IPS.check(ip, 'ipv6');
    }
    default:
      return true;
  }
}
/**
 * Resolves every address for `hostname`. `localhost` and `*.localhost` map
 * to the loopback pair without consulting DNS.
 */
async function resolveHostAddresses(hostname: string): Promise<string[]> {
  const host = hostname.toLowerCase();
  const isLocalhostName = host === 'localhost' || host.endsWith('.localhost');
  if (isLocalhostName) {
    return ['127.0.0.1', '::1'];
  }
  const entries = await dns.lookup(hostname, { all: true, verbatim: true });
  return entries.map(({ address }) => address);
}
/**
 * Parses `rawUrl` and verifies that it is safe to fetch for a link preview.
 *
 * The URL must use http(s), must carry no credentials, and must have a
 * hostname. For an IP-literal host, the literal itself must not be blocked.
 * Otherwise every resolved address must be allowed.
 *
 * @returns the parsed URL and the address the request should connect to
 * @throws Error on any failed check (message states which one)
 */
export async function resolveAndValidateUrlForPreview(
  rawUrl: string
): Promise<{ url: URL; address: string }> {
  let url: URL;
  try {
    url = new URL(rawUrl);
  } catch {
    throw new Error('Invalid URL');
  }
  if (!ALLOWED_PROTOCOLS.has(url.protocol)) {
    throw new Error('Disallowed URL protocol');
  }
  if (url.username || url.password) {
    throw new Error('URL must not include credentials');
  }
  if (!url.hostname) {
    throw new Error('Missing hostname');
  }

  // WHATWG URL keeps the brackets around IPv6 literals (`[::1]`). `isIP`
  // and `dns.lookup` do not understand the brackets, so unwrap them first.
  const host =
    url.hostname.startsWith('[') && url.hostname.endsWith(']')
      ? url.hostname.slice(1, -1)
      : url.hostname;

  if (isIP(host)) {
    if (isBlockedIpAddress(host)) {
      throw new Error('Blocked IP address');
    }
    return { url, address: host };
  }

  const addresses = await resolveHostAddresses(host);
  if (!addresses.length) {
    throw new Error('Unresolvable hostname');
  }
  // Reject if ANY record is unsafe, not just the one that would be used.
  for (const addr of addresses) {
    if (isBlockedIpAddress(addr)) {
      throw new Error('Blocked IP address');
    }
  }
  return { url, address: addresses[0] };
}

View File

@@ -6,6 +6,7 @@ import { isMacOS } from '../../shared/utils';
import { persistentConfig } from '../config-storage/persist';
import { logger } from '../logger';
import { openExternalSafely } from '../security/open-external';
import { resolveAndValidateUrlForPreview } from '../security/url-safety';
import type { WorkbenchViewMeta } from '../shared-state-schema';
import { MenubarStateKey, MenubarStateSchema } from '../shared-state-schema';
import { globalStateStorage } from '../shared-storage/storage';
@@ -37,6 +38,13 @@ import { getOrCreateCustomThemeWindow } from '../windows-manager/custom-theme-wi
import { getChallengeResponse } from './challenge';
import { uiSubjects } from './subject';
// Bookmark metadata with every field unset. getBookmarkDataByLink returns it
// when the link fails preview URL validation or when no numeric tweet status
// id can be extracted. Frozen so the shared instance cannot be mutated.
const EMPTY_OBJECT = Object.freeze({
title: undefined,
description: undefined,
icon: undefined,
image: undefined,
});
const TraySettingsState = {
$: globalStateStorage.watch<MenubarStateSchema>(MenubarStateKey).pipe(
map(v => MenubarStateSchema.parse(v ?? {})),
@@ -127,6 +135,13 @@ export const uiHandlers = {
}
},
getBookmarkDataByLink: async (_, link: string) => {
try {
// Basic validation up-front to prevent SSRF (including redirects).
await resolveAndValidateUrlForPreview(link);
} catch {
return EMPTY_OBJECT;
}
if (
(link.startsWith('https://x.com/') ||
link.startsWith('https://www.x.com/') ||
@@ -135,8 +150,9 @@ export const uiHandlers = {
link.includes('/status/')
) {
// use api.fxtwitter.com
link =
'https://api.fxtwitter.com/status/' + /\/status\/(.*)/.exec(link)?.[1];
const statusId = /\/status\/(\d+)/.exec(link)?.[1];
if (!statusId) return EMPTY_OBJECT;
link = `https://api.fxtwitter.com/status/${statusId}`;
try {
const { tweet } = (await fetch(link).then(res => res.json())) as any;
return {
@@ -161,7 +177,20 @@ export const uiHandlers = {
'User-Agent':
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36',
},
followRedirects: 'follow',
followRedirects: 'manual',
handleRedirects: (_baseUrl: string, forwardedUrl: string) => {
try {
// Only allow http(s) redirects and re-validate before following.
const u = new URL(forwardedUrl);
return u.protocol === 'http:' || u.protocol === 'https:';
} catch {
return false;
}
},
resolveDNSHost: async (url: string) => {
const { address } = await resolveAndValidateUrlForPreview(url);
return address;
},
}).catch(() => {
return {
title: '',

View File

@@ -1,5 +1,5 @@
export interface AuthenticationRequest {
method: 'magic-link' | 'oauth';
method: 'magic-link' | 'oauth' | 'open-app-signin';
payload: Record<string, any>;
server?: string;
}

View File

@@ -0,0 +1,109 @@
import { describe, expect, it, vi } from 'vitest';
type RedirectProxyModule =
  typeof import('../../src/main/security/redirect-proxy');

const ENV_KEYS = ['BUILD_TYPE', 'NODE_ENV', 'DEV_SERVER_URL'] as const;
type EnvValues = Partial<Record<(typeof ENV_KEYS)[number], string>>;

const STABLE_PRODUCTION: EnvValues = {
  BUILD_TYPE: 'stable',
  NODE_ENV: 'production',
};

/** Sets each tracked variable to its value, deleting those left undefined. */
function applyEnv(values: EnvValues) {
  for (const key of ENV_KEYS) {
    const value = values[key];
    if (value === undefined) {
      delete process.env[key];
    } else {
      process.env[key] = value;
    }
  }
}

/**
 * Imports a fresh copy of the redirect-proxy module under `env` and runs
 * `check` against it. The previous values of the tracked variables are
 * restored afterwards, so one test's environment cannot leak into another.
 */
async function withRedirectProxy(
  env: EnvValues,
  check: (mod: RedirectProxyModule) => void
) {
  const previous: EnvValues = Object.fromEntries(
    ENV_KEYS.map(key => [key, process.env[key]])
  );
  try {
    applyEnv(env);
    // Reset the module registry so the config is re-evaluated under `env`.
    vi.resetModules();
    check(await import('../../src/main/security/redirect-proxy'));
  } finally {
    applyEnv(previous);
  }
}

describe('redirect proxy allowlist', () => {
  it('blocks missing redirect_uri', () =>
    withRedirectProxy(STABLE_PRODUCTION, ({ validateRedirectProxyUrl }) => {
      expect(validateRedirectProxyUrl('assets://./redirect-proxy')).toEqual({
        allow: false,
        reason: 'missing_redirect_target',
      });
    }));

  it('blocks untrusted redirect_uri', () =>
    withRedirectProxy(STABLE_PRODUCTION, ({ validateRedirectProxyUrl }) => {
      expect(
        validateRedirectProxyUrl(
          'assets://./redirect-proxy?redirect_uri=https%3A%2F%2Fevil.com%2F'
        )
      ).toEqual({
        allow: false,
        reason: 'untrusted_redirect_target',
        redirectTarget: 'https://evil.com/',
      });
    }));

  it('allows trusted redirect_uri', () =>
    withRedirectProxy(STABLE_PRODUCTION, ({ validateRedirectProxyUrl }) => {
      expect(
        validateRedirectProxyUrl(
          'assets://./redirect-proxy?redirect_uri=https%3A%2F%2Fgithub.com%2Ftoeverything%2FAFFiNE'
        )
      ).toEqual({
        allow: true,
        redirectTarget: 'https://github.com/toeverything/AFFiNE',
      });
    }));

  it('allows current hostname (canary)', () =>
    withRedirectProxy(
      { BUILD_TYPE: 'canary', NODE_ENV: 'production' },
      ({ validateRedirectProxyUrl }) => {
        expect(
          validateRedirectProxyUrl(
            'assets://./redirect-proxy?redirect_uri=https%3A%2F%2Faffine.fail%2Fpricing'
          )
        ).toEqual({
          allow: true,
          redirectTarget: 'https://affine.fail/pricing',
        });
      }
    ));

  it('allows current hostname from DEV_SERVER_URL in development', () =>
    withRedirectProxy(
      {
        BUILD_TYPE: 'stable',
        NODE_ENV: 'development',
        DEV_SERVER_URL: 'http://localhost:8080',
      },
      ({ validateRedirectProxyUrl }) => {
        expect(
          validateRedirectProxyUrl(
            'assets://./redirect-proxy?redirect_uri=http%3A%2F%2Flocalhost%3A1234%2Fauth'
          )
        ).toEqual({
          allow: true,
          redirectTarget: 'http://localhost:1234/auth',
        });
      }
    ));

  it('blocks redirect_uri in hash when untrusted', () =>
    withRedirectProxy(STABLE_PRODUCTION, ({ validateRedirectProxyUrl }) => {
      expect(
        validateRedirectProxyUrl(
          'assets://./redirect-proxy#/foo?redirect_uri=https%3A%2F%2Fevil.com%2F'
        )
      ).toEqual({
        allow: false,
        reason: 'untrusted_redirect_target',
        redirectTarget: 'https://evil.com/',
      });
    }));
});

View File

@@ -27,6 +27,7 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
} else {
call.reject("Failed to sign in")
}
return
}
guard let token = try self.tokenFromCookie(endpoint) else {
@@ -57,6 +58,7 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
} else {
call.reject("Failed to sign in")
}
return
}
guard let token = try self.tokenFromCookie(endpoint) else {
@@ -91,6 +93,7 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
} else {
call.reject("Failed to sign in")
}
return
}
guard let token = try self.tokenFromCookie(endpoint) else {
@@ -109,20 +112,24 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
Task {
do {
let endpoint = try call.getStringEnsure("endpoint")
let csrfToken = try self.csrfTokenFromCookie(endpoint)
let (data, response) = try await self.fetch(endpoint, method: "GET", action: "/api/auth/sign-out", headers: [:], body: nil)
let (data, response) = try await self.fetch(endpoint, method: "POST", action: "/api/auth/sign-out", headers: [
"x-affine-csrf-token": csrfToken,
], body: nil)
if response.statusCode >= 400 {
if let textBody = String(data: data, encoding: .utf8) {
call.reject(textBody)
} else {
call.reject("Failed to sign in")
call.reject("Failed to sign out")
}
return
}
call.resolve(["ok": true])
} catch {
call.reject("Failed to sign in, \(error)", nil, error)
call.reject("Failed to sign out, \(error)", nil, error)
}
}
}
@@ -141,6 +148,16 @@ public class AuthPlugin: CAPPlugin, CAPBridgedPlugin {
}
}
/// Reads the `affine_csrf_token` cookie stored for `endpoint`, if present.
/// - Throws: `AuthError.invalidEndpoint` when `endpoint` is not a valid URL.
/// - Returns: the cookie value, or `nil` when no such cookie is stored.
private func csrfTokenFromCookie(_ endpoint: String) throws -> String? {
    guard let endpointUrl = URL(string: endpoint) else {
        throw AuthError.invalidEndpoint
    }
    let storedCookies = HTTPCookieStorage.shared.cookies(for: endpointUrl) ?? []
    let csrfCookie = storedCookies.first { cookie in
        cookie.name == "affine_csrf_token"
    }
    return csrfCookie?.value
}
private func fetch(_ endpoint: String, method: String, action: String, headers: [String: String?], body: Encodable?) async throws -> (Data, HTTPURLResponse) {
guard let targetUrl = URL(string: "\(endpoint)\(action)") else {
throw AuthError.invalidEndpoint

View File

@@ -125,16 +125,6 @@ export const OnboardingPage = ({
return null;
}
// deprecated
// TODO(@forehalo): remove
if (callbackUrl?.startsWith('/open-app/signin-redirect')) {
const url = new URL(callbackUrl, window.location.origin);
url.searchParams.set('next', 'onboarding');
console.log('redirect to', url.toString());
window.location.assign(url.toString());
return null;
}
if (question) {
return (
<ScrollableLayout

View File

@@ -91,7 +91,7 @@
"semver": "^7.7.3",
"ses": "^1.14.0",
"shiki": "^3.19.0",
"socket.io-client": "^4.8.1",
"socket.io-client": "^4.8.3",
"swr": "^2.3.7",
"tinykeys": "patch:tinykeys@npm%3A2.1.0#~/.yarn/patches/tinykeys-npm-2.1.0-819feeaed0.patch",
"y-protocols": "^1.0.6",

View File

@@ -14,6 +14,23 @@ import { z } from 'zod';
import { supportedClient } from './common';
const supportedProvider = z.nativeEnum(OAuthProviderType);
const CSRF_COOKIE_NAME = 'affine_csrf_token';

/**
 * Reads a cookie by exact name from `document.cookie`.
 *
 * @returns the raw value (`''` for a bare cookie without `=`), or `null` when
 *   the cookie is missing or no `document` exists (non-browser runtime).
 */
function getCookieValue(name: string) {
  if (typeof document === 'undefined') return null;
  const rawCookies = document.cookie;
  if (!rawCookies) return null;
  for (const pair of rawCookies.split('; ')) {
    const eq = pair.indexOf('=');
    const key = eq < 0 ? pair : pair.slice(0, eq);
    if (key !== name) continue;
    return eq < 0 ? '' : pair.slice(eq + 1);
  }
  return null;
}
const oauthParameters = z.object({
provider: supportedProvider,
@@ -36,7 +53,11 @@ export const loader: LoaderFunction = async ({ request }) => {
// sign out first, web only
if (client === 'web') {
await fetch('/api/auth/sign-out');
const csrfToken = getCookieValue(CSRF_COOKIE_NAME);
await fetch('/api/auth/sign-out', {
method: 'POST',
headers: csrfToken ? { 'x-affine-csrf-token': csrfToken } : undefined,
});
}
const paramsParseResult = oauthParameters.safeParse({

View File

@@ -1,15 +1,13 @@
import { useNavigateHelper } from '@affine/core/components/hooks/use-navigate-helper';
import { GraphQLService } from '@affine/core/modules/cloud';
import { AuthService } from '@affine/core/modules/cloud';
import { OpenInAppPage } from '@affine/core/modules/open-in-app/views/open-in-app-page';
import {
appSchemaUrl,
appSchemes,
channelToScheme,
} from '@affine/core/utils/channel';
import type { GetCurrentUserQuery } from '@affine/graphql';
import { getCurrentUserQuery } from '@affine/graphql';
import { useService } from '@toeverything/infra';
import { useCallback, useEffect, useState } from 'react';
import { useCallback, useEffect, useRef, useState } from 'react';
import { useParams, useSearchParams } from 'react-router-dom';
import { AppContainer } from '../../components/app-container';
@@ -49,38 +47,43 @@ const OpenUrl = () => {
/**
* @deprecated
*/
const OpenOAuthJwt = () => {
const [currentUser, setCurrentUser] = useState<
GetCurrentUserQuery['currentUser'] | null
>(null);
const OpenAppSignInRedirect = () => {
const authService = useService(AuthService);
const [params] = useSearchParams();
const graphqlService = useService(GraphQLService);
const triggeredRef = useRef(false);
const [urlToOpen, setUrlToOpen] = useState<string | null>(null);
const maybeScheme = appSchemes.safeParse(params.get('scheme'));
const scheme = maybeScheme.success
? maybeScheme.data
: channelToScheme[BUILD_CONFIG.appBuildType];
const next = params.get('next') || '';
const next = params.get('next') || undefined;
useEffect(() => {
graphqlService
.gql({
query: getCurrentUserQuery,
})
.then(res => {
setCurrentUser(res?.currentUser || null);
if (triggeredRef.current) {
return;
}
triggeredRef.current = true;
authService
.createOpenAppSignInCode()
.then(code => {
const authParams = new URLSearchParams();
authParams.set('method', 'open-app-signin');
authParams.set(
'payload',
JSON.stringify(next ? { code, next } : { code })
);
authParams.set('server', location.origin);
setUrlToOpen(`${scheme}://authentication?${authParams.toString()}`);
})
.catch(console.error);
}, [graphqlService]);
}, [authService, next, scheme]);
if (!currentUser || !currentUser?.token?.sessionToken) {
if (!urlToOpen) {
return <AppContainer fallback />;
}
const urlToOpen = `${scheme}://signin-redirect?token=${
currentUser.token.sessionToken
}&next=${next}`;
return <OpenInAppPage urlToOpen={urlToOpen} />;
};
@@ -91,7 +94,7 @@ export const Component = () => {
if (action === 'url') {
return <OpenUrl />;
} else if (action === 'signin-redirect') {
return <OpenOAuthJwt />;
return <OpenAppSignInRedirect />;
}
return null;
};

View File

@@ -1,21 +1,8 @@
import { DebugLogger } from '@affine/debug';
import { escapeRegExp } from 'lodash-es';
import { isAllowedRedirectTarget } from '@toeverything/infra';
import { type LoaderFunction, Navigate, useLoaderData } from 'react-router-dom';
const trustedDomain = [
'google.com',
'stripe.com',
'github.com',
'twitter.com',
'discord.gg',
'youtube.com',
't.me',
'reddit.com',
'affine.pro',
];
const logger = new DebugLogger('redirect_proxy');
const ALLOWED_PROTOCOLS = new Set(['http:', 'https:']);
/**
* /redirect-proxy page
@@ -31,26 +18,13 @@ export const loader: LoaderFunction = async ({ request }) => {
return { allow: false };
}
try {
const target = new URL(redirectUri);
if (!ALLOWED_PROTOCOLS.has(target.protocol)) {
logger.warn('Blocked redirect with disallowed protocol', target.protocol);
return { allow: false };
}
if (
target.hostname === window.location.hostname ||
trustedDomain.some(domain =>
new RegExp(`(^|\\.)${escapeRegExp(domain)}$`).test(target.hostname)
)
) {
location.href = redirectUri;
return { allow: true };
}
} catch (e) {
logger.error('Failed to parse redirect uri', e);
return { allow: false };
if (
isAllowedRedirectTarget(redirectUri, {
currentHostname: window.location.hostname,
})
) {
location.href = redirectUri;
return { allow: true };
}
logger.warn('Blocked redirect to untrusted domain', redirectUri);

View File

@@ -4,6 +4,24 @@ import { AuthProvider } from '../provider/auth';
import { ServerScope } from '../scopes/server';
import { FetchService } from '../services/fetch';
const CSRF_COOKIE_NAME = 'affine_csrf_token';

/**
 * Finds the first cookie named `name` in `document.cookie`.
 *
 * @returns its raw value, `''` when the cookie has no `=`, or `null` when it
 *   is absent or no `document` exists.
 */
function getCookieValue(name: string) {
  if (typeof document === 'undefined') {
    return null;
  }
  const pairs = document.cookie ? document.cookie.split('; ') : [];
  const match = pairs.find(pair => {
    const eq = pair.indexOf('=');
    return (eq === -1 ? pair : pair.slice(0, eq)) === name;
  });
  if (match === undefined) {
    return null;
  }
  const eq = match.indexOf('=');
  return eq === -1 ? '' : match.slice(eq + 1);
}
export function configureDefaultAuthProvider(framework: Framework) {
framework.scope(ServerScope).override(AuthProvider, resolver => {
const fetchService = resolver.get(FetchService);
@@ -62,7 +80,11 @@ export function configureDefaultAuthProvider(framework: Framework) {
});
},
async signOut() {
await fetchService.fetch('/api/auth/sign-out');
const csrfToken = getCookieValue(CSRF_COOKIE_NAME);
await fetchService.fetch('/api/auth/sign-out', {
method: 'POST',
headers: csrfToken ? { 'x-affine-csrf-token': csrfToken } : undefined,
});
},
};
});

View File

@@ -165,6 +165,32 @@ export class AuthService extends Service {
}
}
/**
 * Requests a one-time code from `/api/auth/open-app/sign-in-code`. The
 * desktop app later passes that code to `signInOpenAppSignInCode`.
 *
 * @returns the non-empty code string from the response body
 * @throws Error when the body does not contain a non-empty string `code`
 */
async createOpenAppSignInCode(): Promise<string> {
  const res = await this.fetchService.fetch(
    '/api/auth/open-app/sign-in-code',
    {
      method: 'POST',
    }
  );
  // Validate instead of casting. A `null` body would otherwise throw a
  // TypeError, and a non-string `code` would leak through typed as a string.
  const body: unknown = await res.json();
  const code =
    typeof body === 'object' && body !== null && 'code' in body
      ? body.code
      : undefined;
  if (typeof code !== 'string' || !code) {
    throw new Error('Missing open-app sign-in code');
  }
  return code;
}
/**
 * Exchanges an open-app sign-in `code` via `/api/auth/open-app/sign-in`,
 * then revalidates the session.
 */
async signInOpenAppSignInCode(code: string) {
  const body = JSON.stringify({ code });
  await this.fetchService.fetch('/api/auth/open-app/sign-in', {
    method: 'POST',
    headers: { 'content-type': 'application/json' },
    body,
  });
  this.session.revalidate();
}
async signInPassword(credential: {
email: string;
password: string;

View File

@@ -146,6 +146,14 @@ export class DesktopApiService extends Service {
await authService.signInOauth(code, state, provider);
break;
}
case 'open-app-signin': {
const code = (payload as { code?: unknown }).code;
if (typeof code !== 'string' || !code) {
throw new Error('Invalid open-app sign-in payload');
}
await authService.signInOpenAppSignInCode(code);
break;
}
}
})().catch(e => {
notify.error({

View File

@@ -9,16 +9,16 @@
"es-CL": 99,
"es": 98,
"fa": 98,
"fr": 100,
"fr": 99,
"hi": 2,
"it-IT": 100,
"it-IT": 99,
"it": 1,
"ja": 98,
"ko": 99,
"nb-NO": 48,
"pl": 100,
"pt-BR": 98,
"ru": 100,
"ru": 99,
"sv-SE": 98,
"uk": 98,
"ur": 2,

Some files were not shown because too many files have changed in this diff Show More