Compare commits

..

2 Commits

Author SHA1 Message Date
DarkSky
ad6470db82 feat: improve downgrade check 2026-03-22 22:09:15 +08:00
DarkSky
adf8955e3f feat: improve subscription sync 2026-03-20 05:28:05 +08:00
86 changed files with 2560 additions and 3404 deletions

View File

@@ -269,13 +269,10 @@ jobs:
- name: Run playground build
run: yarn workspace @blocksuite/playground build
- name: Run integration browser tests
timeout-minutes: 10
run: yarn workspace @blocksuite/integration-test test:unit
- name: Run cross-platform playwright tests
timeout-minutes: 10
run: yarn workspace @affine-test/blocksuite test "cross-platform/" --forbid-only
- name: Run playwright tests
run: |
yarn workspace @blocksuite/integration-test test:unit
yarn workspace @affine-test/blocksuite test "cross-platform/" --forbid-only
- name: Upload test results
if: always()

33
Cargo.lock generated
View File

@@ -92,9 +92,6 @@ dependencies = [
"napi-derive",
"objc2",
"objc2-foundation",
"ogg",
"opus-codec",
"rand 0.9.2",
"rubato",
"screencapturekit",
"symphonia",
@@ -624,8 +621,6 @@ dependencies = [
"cexpr",
"clang-sys",
"itertools 0.13.0",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
@@ -1088,15 +1083,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9"
[[package]]
name = "cmake"
version = "0.1.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d"
dependencies = [
"cc",
]
[[package]]
name = "cobs"
version = "0.3.0"
@@ -4008,15 +3994,6 @@ dependencies = [
"cc",
]
[[package]]
name = "ogg"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdab8dcd8d4052eaacaf8fb07a3ccd9a6e26efadb42878a413c68fc4af1dee2b"
dependencies = [
"byteorder",
]
[[package]]
name = "once_cell"
version = "1.21.4"
@@ -4041,16 +4018,6 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "opus-codec"
version = "0.1.2"
source = "git+https://github.com/toeverything/opus-codec?rev=c2afef2#c2afef20773c3afb06395a26a4f054ca90ba9078"
dependencies = [
"bindgen",
"cmake",
"pkg-config",
]
[[package]]
name = "ordered-float"
version = "5.1.0"

View File

@@ -76,7 +76,6 @@ resolver = "3"
notify = { version = "8", features = ["serde"] }
objc2 = "0.6"
objc2-foundation = "0.3"
ogg = "0.9"
once_cell = "1"
ordered-float = "5"
parking_lot = "0.12"

View File

@@ -183,32 +183,6 @@ function createTextFootnoteDefinition(content: string): string {
});
}
function parseFootnoteDefLine(line: string): {
identifier: string;
content: string;
} | null {
if (!line.startsWith('[^')) return null;
const closeBracketIndex = line.indexOf(']:', 2);
if (closeBracketIndex <= 2) return null;
const identifier = line.slice(2, closeBracketIndex);
if (!identifier || identifier.includes(']')) return null;
let contentStart = closeBracketIndex + 2;
while (
contentStart < line.length &&
(line[contentStart] === ' ' || line[contentStart] === '\t')
) {
contentStart += 1;
}
return {
identifier,
content: line.slice(contentStart),
};
}
function extractObsidianFootnotes(markdown: string): {
content: string;
footnotes: string[];
@@ -219,14 +193,14 @@ function extractObsidianFootnotes(markdown: string): {
for (let index = 0; index < lines.length; index += 1) {
const line = lines[index];
const definition = parseFootnoteDefLine(line);
if (!definition) {
const match = line.match(/^\[\^([^\]]+)\]:\s*(.*)$/);
if (!match) {
output.push(line);
continue;
}
const { identifier } = definition;
const contentLines = [definition.content];
const identifier = match[1];
const contentLines = [match[2]];
while (index + 1 < lines.length) {
const nextLine = lines[index + 1];
@@ -418,119 +392,49 @@ function parseObsidianAttach(value: string): ObsidianAttachmentEmbed | null {
}
}
function parseWikiLinkAt(
source: string,
startIdx: number,
embedded: boolean
): {
raw: string;
rawTarget: string;
rawAlias?: string;
endIdx: number;
} | null {
const opener = embedded ? '![[' : '[[';
if (!source.startsWith(opener, startIdx)) return null;
const contentStart = startIdx + opener.length;
const closeIndex = source.indexOf(']]', contentStart);
if (closeIndex === -1) return null;
const inner = source.slice(contentStart, closeIndex);
const separatorIdx = inner.indexOf('|');
const rawTarget = separatorIdx === -1 ? inner : inner.slice(0, separatorIdx);
const rawAlias =
separatorIdx === -1 ? undefined : inner.slice(separatorIdx + 1);
if (
rawTarget.length === 0 ||
rawTarget.includes(']') ||
rawTarget.includes('|') ||
rawAlias?.includes(']')
) {
return null;
}
return {
raw: source.slice(startIdx, closeIndex + 2),
rawTarget,
rawAlias,
endIdx: closeIndex + 2,
};
}
function replaceWikiLinks(
source: string,
embedded: boolean,
replacer: (match: {
raw: string;
rawTarget: string;
rawAlias?: string;
}) => string
): string {
const opener = embedded ? '![[' : '[[';
let cursor = 0;
let output = '';
while (cursor < source.length) {
const matchStart = source.indexOf(opener, cursor);
if (matchStart === -1) {
output += source.slice(cursor);
break;
}
output += source.slice(cursor, matchStart);
const match = parseWikiLinkAt(source, matchStart, embedded);
if (!match) {
output += source.slice(matchStart, matchStart + opener.length);
cursor = matchStart + opener.length;
continue;
}
output += replacer(match);
cursor = match.endIdx;
}
return output;
}
function preprocessObsidianEmbeds(
markdown: string,
filePath: string,
pageLookupMap: ReadonlyMap<string, string>,
pathBlobIdMap: ReadonlyMap<string, string>
): string {
return replaceWikiLinks(markdown, true, ({ raw, rawTarget, rawAlias }) => {
const targetPageId = resolvePageIdFromLookup(
pageLookupMap,
rawTarget,
filePath
);
if (targetPageId) {
return `[[${rawTarget}${rawAlias ? `|${rawAlias}` : ''}]]`;
return markdown.replace(
/!\[\[([^\]|]+)(?:\|([^\]]+))?\]\]/g,
(match, rawTarget: string, rawAlias?: string) => {
const targetPageId = resolvePageIdFromLookup(
pageLookupMap,
rawTarget,
filePath
);
if (targetPageId) {
return `[[${rawTarget}${rawAlias ? `|${rawAlias}` : ''}]]`;
}
const { path } = parseObsidianTarget(rawTarget);
if (!path) {
return match;
}
const assetPath = getImageFullPath(filePath, path);
const encodedPath = encodeMarkdownPath(assetPath);
if (isImageAssetPath(path)) {
const alt = getEmbedLabel(rawAlias, path, false);
return `![${escapeMarkdownLabel(alt)}](${encodedPath})`;
}
const label = getEmbedLabel(rawAlias, path, true);
const blobId = pathBlobIdMap.get(assetPath);
if (!blobId) return `[${escapeMarkdownLabel(label)}](${encodedPath})`;
const extension = path.split('.').at(-1)?.toLowerCase() ?? '';
return createObsidianAttach({
blobId,
fileName: basename(path),
fileType: extMimeMap.get(extension) ?? '',
});
}
const { path } = parseObsidianTarget(rawTarget);
if (!path) return raw;
const assetPath = getImageFullPath(filePath, path);
const encodedPath = encodeMarkdownPath(assetPath);
if (isImageAssetPath(path)) {
const alt = getEmbedLabel(rawAlias, path, false);
return `![${escapeMarkdownLabel(alt)}](${encodedPath})`;
}
const label = getEmbedLabel(rawAlias, path, true);
const blobId = pathBlobIdMap.get(assetPath);
if (!blobId) return `[${escapeMarkdownLabel(label)}](${encodedPath})`;
const extension = path.split('.').at(-1)?.toLowerCase() ?? '';
return createObsidianAttach({
blobId,
fileName: basename(path),
fileType: extMimeMap.get(extension) ?? '',
});
});
);
}
function preprocessObsidianMarkdown(
@@ -617,31 +521,21 @@ export const obsidianWikilinkToDeltaMatcher = MarkdownASTToDeltaExtension({
}
const nodeContent = textNode.value;
const wikilinkRegex = /\[\[([^\]|]+)(?:\|([^\]]+))?\]\]/g;
const deltas: DeltaInsert<AffineTextAttributes>[] = [];
let cursor = 0;
while (cursor < nodeContent.length) {
const matchStart = nodeContent.indexOf('[[', cursor);
if (matchStart === -1) {
deltas.push({ insert: nodeContent.substring(cursor) });
break;
}
let lastProcessedIndex = 0;
let linkMatch;
if (matchStart > cursor) {
while ((linkMatch = wikilinkRegex.exec(nodeContent)) !== null) {
if (linkMatch.index > lastProcessedIndex) {
deltas.push({
insert: nodeContent.substring(cursor, matchStart),
insert: nodeContent.substring(lastProcessedIndex, linkMatch.index),
});
}
const linkMatch = parseWikiLinkAt(nodeContent, matchStart, false);
if (!linkMatch) {
deltas.push({ insert: '[[' });
cursor = matchStart + 2;
continue;
}
const targetPageName = linkMatch.rawTarget.trim();
const alias = linkMatch.rawAlias?.trim();
const targetPageName = linkMatch[1].trim();
const alias = linkMatch[2]?.trim();
const currentFilePath = context.configs.get(FULL_FILE_PATH_KEY);
const targetPageId = resolvePageIdFromLookup(
{ get: key => context.configs.get(`obsidian:pageId:${key}`) },
@@ -666,10 +560,14 @@ export const obsidianWikilinkToDeltaMatcher = MarkdownASTToDeltaExtension({
},
});
} else {
deltas.push({ insert: linkMatch.raw });
deltas.push({ insert: linkMatch[0] });
}
cursor = linkMatch.endIdx;
lastProcessedIndex = wikilinkRegex.lastIndex;
}
if (lastProcessedIndex < nodeContent.length) {
deltas.push({ insert: nodeContent.substring(lastProcessedIndex) });
}
return deltas;

View File

@@ -2,10 +2,12 @@ import {
acceptInviteByInviteIdMutation,
approveWorkspaceTeamMemberMutation,
createInviteLinkMutation,
deleteBlobMutation,
getInviteInfoQuery,
getMembersByWorkspaceIdQuery,
inviteByEmailsMutation,
leaveWorkspaceMutation,
releaseDeletedBlobsMutation,
revokeMemberPermissionMutation,
WorkspaceInviteLinkExpireTime,
WorkspaceMemberStatus,
@@ -13,6 +15,11 @@ import {
import { faker } from '@faker-js/faker';
import { Models } from '../../../models';
import { FeatureConfigs } from '../../../models/common/feature';
import {
SubscriptionPlan,
SubscriptionRecurring,
} from '../../../plugins/payment/types';
import { Mockers } from '../../mocks';
import { app, e2e } from '../test';
@@ -81,6 +88,175 @@ e2e('should invite a user', async t => {
t.is(getInviteInfo2.status, WorkspaceMemberStatus.Accepted);
});
e2e('should re-check seat when accepting an email invitation', async t => {
const { owner, workspace } = await createWorkspace();
const member = await app.create(Mockers.User);
await app.create(Mockers.TeamWorkspace, {
id: workspace.id,
quantity: 4,
});
await app.create(Mockers.WorkspaceUser, {
workspaceId: workspace.id,
userId: (await app.create(Mockers.User)).id,
});
await app.create(Mockers.WorkspaceUser, {
workspaceId: workspace.id,
userId: (await app.create(Mockers.User)).id,
});
await app.login(owner);
const invite = await app.gql({
query: inviteByEmailsMutation,
variables: {
emails: [member.email],
workspaceId: workspace.id,
},
});
await app.eventBus.emitAsync('workspace.members.allocateSeats', {
workspaceId: workspace.id,
quantity: 4,
});
await app.models.workspaceFeature.remove(workspace.id, 'team_plan_v1');
await app.login(member);
await t.throwsAsync(
app.gql({
query: acceptInviteByInviteIdMutation,
variables: {
workspaceId: workspace.id,
inviteId: invite.inviteMembers[0].inviteId!,
},
})
);
const { getInviteInfo } = await app.gql({
query: getInviteInfoQuery,
variables: {
inviteId: invite.inviteMembers[0].inviteId!,
},
});
t.is(getInviteInfo.status, WorkspaceMemberStatus.Pending);
});
e2e.serial(
'should block accepting pending invitations in readonly mode and recover after blob cleanup',
async t => {
const { owner, workspace } = await createWorkspace();
const member = await app.create(Mockers.User);
const freeStorageQuota = FeatureConfigs.free_plan_v1.configs.storageQuota;
const lifetimeStorageQuota =
FeatureConfigs.lifetime_pro_plan_v1.configs.storageQuota;
FeatureConfigs.free_plan_v1.configs.storageQuota = 1;
FeatureConfigs.lifetime_pro_plan_v1.configs.storageQuota = 2;
t.teardown(() => {
FeatureConfigs.free_plan_v1.configs.storageQuota = freeStorageQuota;
FeatureConfigs.lifetime_pro_plan_v1.configs.storageQuota =
lifetimeStorageQuota;
});
await app.models.userFeature.switchQuota(
owner.id,
'lifetime_pro_plan_v1',
'test setup'
);
await app.login(owner);
const invite = await app.gql({
query: inviteByEmailsMutation,
variables: {
emails: [member.email],
workspaceId: workspace.id,
},
});
await app.models.blob.upsert({
workspaceId: workspace.id,
key: 'overflow-blob',
mime: 'application/octet-stream',
size: 2,
status: 'completed',
uploadId: null,
});
await app.eventBus.emitAsync('user.subscription.canceled', {
userId: owner.id,
plan: SubscriptionPlan.Pro,
recurring: SubscriptionRecurring.Lifetime,
});
t.true(
await app.models.workspaceFeature.has(
workspace.id,
'quota_exceeded_readonly_workspace_v1'
)
);
await app.login(member);
await t.throwsAsync(
app.gql({
query: acceptInviteByInviteIdMutation,
variables: {
workspaceId: workspace.id,
inviteId: invite.inviteMembers[0].inviteId!,
},
})
);
const { getInviteInfo: pendingInvite } = await app.gql({
query: getInviteInfoQuery,
variables: {
inviteId: invite.inviteMembers[0].inviteId!,
},
});
t.is(pendingInvite.status, WorkspaceMemberStatus.Pending);
await app.login(owner);
await app.gql({
query: deleteBlobMutation,
variables: {
workspaceId: workspace.id,
key: 'overflow-blob',
permanently: false,
},
});
await app.gql({
query: releaseDeletedBlobsMutation,
variables: {
workspaceId: workspace.id,
},
});
t.false(
await app.models.workspaceFeature.has(
workspace.id,
'quota_exceeded_readonly_workspace_v1'
)
);
await app.login(member);
await app.gql({
query: acceptInviteByInviteIdMutation,
variables: {
workspaceId: workspace.id,
inviteId: invite.inviteMembers[0].inviteId!,
},
});
const { getInviteInfo: acceptedInvite } = await app.gql({
query: getInviteInfoQuery,
variables: {
inviteId: invite.inviteMembers[0].inviteId!,
},
});
t.is(acceptedInvite.status, WorkspaceMemberStatus.Accepted);
}
);
e2e('should leave a workspace', async t => {
const { owner, workspace } = await createWorkspace();
const u2 = await app.create(Mockers.User);

View File

@@ -1,10 +1,18 @@
import {
getInviteInfoQuery,
inviteByEmailsMutation,
publishPageMutation,
revokeMemberPermissionMutation,
revokePublicPageMutation,
WorkspaceMemberStatus,
} from '@affine/graphql';
import { QuotaService } from '../../../core/quota/service';
import { WorkspaceRole } from '../../../models';
import {
SubscriptionPlan,
SubscriptionRecurring,
} from '../../../plugins/payment/types';
import { Mockers } from '../../mocks';
import { app, e2e } from '../test';
@@ -54,6 +62,42 @@ const getInvitationInfo = async (inviteId: string) => {
return result.getInviteInfo;
};
const publishDoc = async (workspaceId: string, docId: string) => {
const { publishDoc } = await app.gql({
query: publishPageMutation,
variables: {
workspaceId,
pageId: docId,
},
});
return publishDoc;
};
const revokePublicDoc = async (workspaceId: string, docId: string) => {
const { revokePublicDoc } = await app.gql({
query: revokePublicPageMutation,
variables: {
workspaceId,
pageId: docId,
},
});
return revokePublicDoc;
};
const revokeMember = async (workspaceId: string, userId: string) => {
const { revokeMember } = await app.gql({
query: revokeMemberPermissionMutation,
variables: {
workspaceId,
userId,
},
});
return revokeMember;
};
e2e('should set new invited users to AllocatingSeat', async t => {
const { owner, workspace } = await createTeamWorkspace();
await app.login(owner);
@@ -165,3 +209,136 @@ e2e('should set all rests to NeedMoreSeat', async t => {
WorkspaceMemberStatus.NeedMoreSeat
);
});
e2e(
'should cleanup non-accepted members when team workspace is downgraded',
async t => {
const { workspace } = await createTeamWorkspace();
const pending = await app.create(Mockers.User);
await app.create(Mockers.WorkspaceUser, {
userId: pending.id,
workspaceId: workspace.id,
status: WorkspaceMemberStatus.Pending,
});
const allocating = await app.create(Mockers.User);
await app.create(Mockers.WorkspaceUser, {
userId: allocating.id,
workspaceId: workspace.id,
status: WorkspaceMemberStatus.AllocatingSeat,
source: 'Email',
});
const underReview = await app.create(Mockers.User);
await app.create(Mockers.WorkspaceUser, {
userId: underReview.id,
workspaceId: workspace.id,
status: WorkspaceMemberStatus.UnderReview,
});
await app.eventBus.emitAsync('workspace.subscription.canceled', {
workspaceId: workspace.id,
plan: SubscriptionPlan.Team,
recurring: SubscriptionRecurring.Monthly,
});
const [members] = await app.models.workspaceUser.paginate(workspace.id, {
first: 20,
offset: 0,
});
t.deepEqual(
members.map(member => member.status),
[
WorkspaceMemberStatus.Accepted,
WorkspaceMemberStatus.Accepted,
WorkspaceMemberStatus.Accepted,
]
);
t.false(await app.models.workspace.isTeamWorkspace(workspace.id));
}
);
e2e(
'should demote accepted admins and keep workspace writable when downgrade stays within owner quota',
async t => {
const { workspace, owner, admin } = await createTeamWorkspace();
await app.eventBus.emitAsync('workspace.subscription.canceled', {
workspaceId: workspace.id,
plan: SubscriptionPlan.Team,
recurring: SubscriptionRecurring.Monthly,
});
t.false(await app.models.workspace.isTeamWorkspace(workspace.id));
t.false(
await app.models.workspaceFeature.has(
workspace.id,
'quota_exceeded_readonly_workspace_v1'
)
);
t.is(
(await app.models.workspaceUser.get(workspace.id, admin.id))?.type,
WorkspaceRole.Collaborator
);
await app.login(owner);
await t.notThrowsAsync(publishDoc(workspace.id, 'doc-1'));
}
);
e2e(
'should enter readonly mode on over-quota team downgrade and recover through cleanup actions',
async t => {
const { workspace, owner, admin } = await createTeamWorkspace(20);
const extraMembers = await Promise.all(
Array.from({ length: 8 }).map(async () => {
const member = await app.create(Mockers.User);
await app.create(Mockers.WorkspaceUser, {
workspaceId: workspace.id,
userId: member.id,
});
return member;
})
);
await app.login(owner);
await publishDoc(workspace.id, 'published-doc');
await app.eventBus.emitAsync('workspace.subscription.canceled', {
workspaceId: workspace.id,
plan: SubscriptionPlan.Team,
recurring: SubscriptionRecurring.Monthly,
});
t.false(await app.models.workspace.isTeamWorkspace(workspace.id));
t.true(
await app.models.workspaceFeature.has(
workspace.id,
'quota_exceeded_readonly_workspace_v1'
)
);
t.is(
(await app.models.workspaceUser.get(workspace.id, admin.id))?.type,
WorkspaceRole.Collaborator
);
await t.throwsAsync(publishDoc(workspace.id, 'blocked-doc'));
await t.notThrowsAsync(revokePublicDoc(workspace.id, 'published-doc'));
const quota = await app
.get(QuotaService)
.getWorkspaceQuotaWithUsage(workspace.id);
for (const member of extraMembers.slice(0, quota.overcapacityMemberCount)) {
await revokeMember(workspace.id, member.id);
}
t.false(
await app.models.workspaceFeature.has(
workspace.id,
'quota_exceeded_readonly_workspace_v1'
)
);
}
);

View File

@@ -65,6 +65,28 @@ test('should transfer workespace owner', async t => {
const owner2 = await models.workspaceUser.getOwner(workspace.id);
t.is(owner2.id, user2.id);
const oldOwnerRole = await models.workspaceUser.get(workspace.id, user.id);
t.is(oldOwnerRole?.type, WorkspaceRole.Collaborator);
});
test('should keep old owner as admin when transferring a team workspace', async t => {
const [user, user2] = await module.create(Mockers.User, 2);
const workspace = await module.create(Mockers.Workspace, {
owner: { id: user.id },
});
await module.create(Mockers.TeamWorkspace, {
id: workspace.id,
quantity: 10,
});
await module.create(Mockers.WorkspaceUser, {
workspaceId: workspace.id,
userId: user2.id,
});
await models.workspaceUser.setOwner(workspace.id, user2.id);
const oldOwnerRole = await models.workspaceUser.get(workspace.id, user.id);
t.is(oldOwnerRole?.type, WorkspaceRole.Admin);
});
test('should throw if transfer owner to non-active member', async t => {

View File

@@ -14,6 +14,7 @@ import { ServerFeature } from '../../core/config/types';
import { Models } from '../../models';
import { OAuthProviderName } from '../../plugins/oauth/config';
import { OAuthProviderFactory } from '../../plugins/oauth/factory';
import { GithubOAuthProvider } from '../../plugins/oauth/providers/github';
import { GoogleOAuthProvider } from '../../plugins/oauth/providers/google';
import { OIDCProvider } from '../../plugins/oauth/providers/oidc';
import { OAuthService } from '../../plugins/oauth/service';
@@ -38,6 +39,10 @@ test.before(async t => {
clientId: 'google-client-id',
clientSecret: 'google-client-secret',
},
github: {
clientId: 'github-client-id',
clientSecret: 'github-client-secret',
},
oidc: {
clientId: '',
clientSecret: '',
@@ -293,7 +298,7 @@ test('should be able to get registered oauth providers', async t => {
const providers = oauth.availableOAuthProviders();
t.deepEqual(providers, [OAuthProviderName.Google]);
t.deepEqual(providers, [OAuthProviderName.Google, OAuthProviderName.GitHub]);
});
test('should throw if code is missing in callback uri', async t => {
@@ -441,6 +446,24 @@ function mockOAuthProvider(
return clientNonce;
}
function mockGithubOAuthProvider(
app: TestingApp,
clientNonce: string = randomUUID()
) {
const provider = app.get(GithubOAuthProvider);
const oauth = app.get(OAuthService);
Sinon.stub(oauth, 'isValidState').resolves(true);
Sinon.stub(oauth, 'getOAuthState').resolves({
provider: OAuthProviderName.GitHub,
clientNonce,
});
Sinon.stub(provider, 'getToken').resolves({ accessToken: '1' });
return { provider, clientNonce };
}
function mockOidcProvider(
provider: OIDCProvider,
{
@@ -645,6 +668,76 @@ test('should be able to fullfil user with oauth sign in', async t => {
t.is(account!.user.id, u3.id);
});
test('github oauth should resolve private email from emails api', async t => {
const { app, db } = t.context;
const email = 'github-private@affine.pro';
const { clientNonce, provider } = mockGithubOAuthProvider(app);
const fetchJson = Sinon.stub(provider as any, 'fetchJson');
fetchJson.onFirstCall().resolves({
login: 'github-user',
email: null,
avatar_url: 'avatar',
name: 'DarkSky',
});
fetchJson.onSecondCall().resolves([
{ email: 'unverified@affine.pro', primary: true, verified: false },
{ email, primary: false, verified: true },
]);
await app
.POST('/api/oauth/callback')
.send({ code: '1', state: '1', client_nonce: clientNonce })
.expect(HttpStatus.OK);
const sessionUser = await currentUser(app);
t.truthy(sessionUser);
t.is(sessionUser!.email, email);
const user = await db.user.findFirst({
select: {
email: true,
connectedAccounts: true,
},
where: {
email,
},
});
t.truthy(user);
t.is(user!.connectedAccounts[0].provider, OAuthProviderName.GitHub);
t.is(user!.connectedAccounts[0].providerAccountId, 'github-user');
});
test('github oauth should reject responses without a verified email', async t => {
const { app } = t.context;
const provider = app.get(GithubOAuthProvider);
const fetchJson = Sinon.stub(provider as any, 'fetchJson');
fetchJson.onFirstCall().resolves({
login: 'github-user',
email: null,
avatar_url: 'avatar',
name: 'DarkSky',
});
fetchJson
.onSecondCall()
.resolves([
{ email: 'private@affine.pro', primary: true, verified: false },
]);
const error = await t.throwsAsync(
provider.getUser(
{ accessToken: 'token' },
{ token: 'state', provider: OAuthProviderName.GitHub }
)
);
t.true(error instanceof InvalidOauthResponse);
});
test('oidc should accept email from id token when userinfo email is missing', async t => {
const { app } = t.context;

View File

@@ -11,6 +11,7 @@ import { ConfigFactory, ConfigModule } from '../../base/config';
import { CurrentUser } from '../../core/auth';
import { AuthService } from '../../core/auth/service';
import { EarlyAccessType, FeatureService } from '../../core/features';
import { SubscriptionCronJobs } from '../../plugins/payment/cron';
import { SubscriptionService } from '../../plugins/payment/service';
import { StripeFactory } from '../../plugins/payment/stripe';
import {
@@ -871,6 +872,34 @@ test('should be able to cancel subscription', async t => {
t.truthy(subInDB.canceledAt);
});
test('should reconcile canceled stripe subscriptions and revoke local entitlement', async t => {
const { app, db, event, service, stripe, u1 } = t.context;
const cron = app.get(SubscriptionCronJobs);
await service.saveStripeSubscription(sub);
event.emit.resetHistory();
stripe.subscriptions.retrieve.resolves({
...sub,
status: SubscriptionStatus.Canceled,
} as any);
await cron.reconcileStripeSubscriptions();
const subInDB = await db.subscription.findFirst({
where: { targetId: u1.id, stripeSubscriptionId: sub.id },
});
t.is(subInDB, null);
t.true(
event.emit.calledWith('user.subscription.canceled', {
userId: u1.id,
plan: SubscriptionPlan.Pro,
recurring: SubscriptionRecurring.Monthly,
})
);
});
test('should be able to resume subscription', async t => {
const { service, db, u1, stripe } = t.context;

View File

@@ -111,20 +111,3 @@ test('delete', async t => {
await t.throwsAsync(() => fs.access(join(config.path, provider.bucket, key)));
});
test('rejects unsafe object keys', async t => {
const provider = createProvider();
await t.throwsAsync(() => provider.put('../escape', Buffer.from('nope')));
await t.throwsAsync(() => provider.get('nested/../escape'));
await t.throwsAsync(() => provider.head('./escape'));
t.throws(() => provider.delete('nested//escape'));
});
test('rejects unsafe list prefixes', async t => {
const provider = createProvider();
await t.throwsAsync(() => provider.list('../escape'));
await t.throwsAsync(() => provider.list('nested/../../escape'));
await t.throwsAsync(() => provider.list('/absolute'));
});

View File

@@ -25,47 +25,9 @@ import {
} from './provider';
import { autoMetadata, toBuffer } from './utils';
function normalizeStorageKey(key: string): string {
const normalized = key.replaceAll('\\', '/');
const segments = normalized.split('/');
if (
!normalized ||
normalized.startsWith('/') ||
segments.some(segment => !segment || segment === '.' || segment === '..')
) {
throw new Error(`Invalid storage key: ${key}`);
}
return segments.join('/');
}
function normalizeStoragePrefix(prefix: string): string {
const normalized = prefix.replaceAll('\\', '/');
if (!normalized) {
return normalized;
}
if (normalized.startsWith('/')) {
throw new Error(`Invalid storage prefix: ${prefix}`);
}
const segments = normalized.split('/');
const lastSegment = segments.pop();
if (
lastSegment === undefined ||
segments.some(segment => !segment || segment === '.' || segment === '..') ||
lastSegment === '.' ||
lastSegment === '..'
) {
throw new Error(`Invalid storage prefix: ${prefix}`);
}
if (lastSegment === '') {
return `${segments.join('/')}/`;
}
return [...segments, lastSegment].join('/');
function escapeKey(key: string): string {
// avoid '../' and './' in key
return key.replace(/\.?\.[/\\]/g, '%');
}
export interface FsStorageConfig {
@@ -95,7 +57,7 @@ export class FsStorageProvider implements StorageProvider {
body: BlobInputType,
metadata: PutObjectMetadata = {}
): Promise<void> {
key = normalizeStorageKey(key);
key = escapeKey(key);
const blob = await toBuffer(body);
// write object
@@ -106,7 +68,6 @@ export class FsStorageProvider implements StorageProvider {
}
async head(key: string) {
key = normalizeStorageKey(key);
const metadata = this.readMetadata(key);
if (!metadata) {
this.logger.verbose(`Object \`${key}\` not found`);
@@ -119,7 +80,7 @@ export class FsStorageProvider implements StorageProvider {
body?: Readable;
metadata?: GetObjectMetadata;
}> {
key = normalizeStorageKey(key);
key = escapeKey(key);
try {
const metadata = this.readMetadata(key);
@@ -144,7 +105,7 @@ export class FsStorageProvider implements StorageProvider {
// read dir recursively and filter out '.metadata.json' files
let dir = this.path;
if (prefix) {
prefix = normalizeStoragePrefix(prefix);
prefix = escapeKey(prefix);
const parts = prefix.split(/[/\\]/);
// for prefix `a/b/c`, move `a/b` to dir and `c` to key prefix
if (parts.length > 1) {
@@ -191,7 +152,7 @@ export class FsStorageProvider implements StorageProvider {
}
delete(key: string): Promise<void> {
key = normalizeStorageKey(key);
key = escapeKey(key);
try {
rmSync(this.join(key), { force: true });

View File

@@ -1,3 +1,5 @@
import { randomUUID } from 'node:crypto';
import test from 'ava';
import { createTestingModule, TestingModule } from '../../../__tests__/utils';
@@ -10,11 +12,13 @@ import {
} from '../../../models';
import { DocAccessController } from '../doc';
import { PermissionModule } from '../index';
import { WorkspacePolicyService } from '../policy';
import { DocRole, mapDocRoleToPermissions } from '../types';
let module: TestingModule;
let models: Models;
let ac: DocAccessController;
let policy: WorkspacePolicyService;
let user: User;
let ws: Workspace;
@@ -22,11 +26,12 @@ test.before(async () => {
module = await createTestingModule({ imports: [PermissionModule] });
models = module.get<Models>(Models);
ac = module.get(DocAccessController);
policy = module.get(WorkspacePolicyService);
});
test.beforeEach(async () => {
await module.initTestingDB();
user = await models.user.create({ email: 'u1@affine.pro' });
user = await models.user.create({ email: `${randomUUID()}@affine.pro` });
ws = await models.workspace.create(user.id);
});
@@ -45,7 +50,7 @@ test('should get null role', async t => {
});
test('should return null if workspace role is not accepted', async t => {
const u2 = await models.user.create({ email: 'u2@affine.pro' });
const u2 = await models.user.create({ email: `${randomUUID()}@affine.pro` });
await models.workspaceUser.set(ws.id, u2.id, WorkspaceRole.Collaborator, {
status: WorkspaceMemberStatus.UnderReview,
});
@@ -162,7 +167,7 @@ test('should assert action', async t => {
)
);
const u2 = await models.user.create({ email: 'u2@affine.pro' });
const u2 = await models.user.create({ email: `${randomUUID()}@affine.pro` });
await t.throwsAsync(
ac.assert(
@@ -184,3 +189,37 @@ test('should assert action', async t => {
)
);
});
test('should apply readonly doc restrictions while keeping cleanup actions', async t => {
for (let index = 0; index < 10; index++) {
const member = await models.user.create({
email: `${randomUUID()}@affine.pro`,
});
await models.workspaceUser.set(
ws.id,
member.id,
WorkspaceRole.Collaborator,
{
status: WorkspaceMemberStatus.Accepted,
}
);
}
await policy.reconcileWorkspaceQuotaState(ws.id);
const { permissions } = await ac.role({
workspaceId: ws.id,
docId: 'doc1',
userId: user.id,
});
t.false(permissions['Doc.Update']);
t.false(permissions['Doc.Publish']);
t.false(permissions['Doc.Duplicate']);
t.false(permissions['Doc.Comments.Create']);
t.false(permissions['Doc.Comments.Update']);
t.false(permissions['Doc.Comments.Resolve']);
t.true(permissions['Doc.Read']);
t.true(permissions['Doc.Delete']);
t.true(permissions['Doc.Trash']);
t.true(permissions['Doc.TransferOwner']);
});

View File

@@ -0,0 +1,198 @@
import { randomUUID } from 'node:crypto';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
import {
createTestingModule,
type TestingModule,
} from '../../../__tests__/utils';
import { SpaceAccessDenied } from '../../../base';
import {
Models,
User,
Workspace,
WorkspaceMemberStatus,
WorkspaceRole,
} from '../../../models';
import { QuotaService } from '../../quota/service';
import { PermissionModule } from '../index';
import { WorkspacePolicyService } from '../policy';
interface Context {
module: TestingModule;
models: Models;
policy: WorkspacePolicyService;
}
const test = ava as TestFn<Context>;
const READONLY_FEATURE = 'quota_exceeded_readonly_workspace_v1' as const;
type WorkspaceQuotaSnapshot = Awaited<
ReturnType<QuotaService['getWorkspaceQuotaWithUsage']>
> & {
ownerQuota?: string;
};
async function addAcceptedMembers(
models: Models,
workspaceId: string,
count: number
) {
for (let index = 0; index < count; index++) {
const member = await models.user.create({
email: `${randomUUID()}@affine.pro`,
});
await models.workspaceUser.set(
workspaceId,
member.id,
WorkspaceRole.Collaborator,
{
status: WorkspaceMemberStatus.Accepted,
}
);
}
}
// Fresh owner + workspace per test; recreated in beforeEach so every test
// starts from a clean database state.
let owner: User;
let workspace: Workspace;
test.before(async t => {
  const module = await createTestingModule({ imports: [PermissionModule] });
  t.context.module = module;
  t.context.models = module.get(Models);
  t.context.policy = module.get(WorkspacePolicyService);
});
test.beforeEach(async t => {
  // Drop any Sinon stubs left behind by the previous test before reseeding.
  Sinon.restore();
  await t.context.module.initTestingDB();
  owner = await t.context.models.user.create({
    email: `${randomUUID()}@affine.pro`,
  });
  workspace = await t.context.models.workspace.create(owner.id);
});
test.after.always(async t => {
  await t.context.module.close();
});
test('should keep owned workspace writable when quota is within limit', async t => {
  const { policy, models } = t.context;
  // A workspace within its quota must not be flagged readonly by reconcile.
  const { isReadonly, readonlyReasons } =
    await policy.reconcileWorkspaceQuotaState(workspace.id);
  t.false(isReadonly);
  t.deepEqual(readonlyReasons, []);
  // And the recovery-mode feature flag must stay unset.
  const flagged = await models.workspaceFeature.has(
    workspace.id,
    READONLY_FEATURE
  );
  t.false(flagged);
});
test('should enter readonly mode when fallback owner member quota overflows', async t => {
  const { models, policy } = t.context;
  // Ten accepted collaborators push the fallback-quota workspace over its
  // member limit.
  await addAcceptedMembers(models, workspace.id, 10);
  const state = await policy.reconcileWorkspaceQuotaState(workspace.id);
  t.true(state.isReadonly);
  // Member overflow is recoverable by removing members, not by freeing blobs.
  t.true(state.canRecoverByRemovingMembers);
  t.false(state.canRecoverByDeletingBlobs);
  t.deepEqual(state.readonlyReasons, ['member_overflow']);
  const flagged = await models.workspaceFeature.has(
    workspace.id,
    READONLY_FEATURE
  );
  t.true(flagged);
  // Invitations are refused while the workspace is in recovery mode.
  await t.throwsAsync(policy.assertCanInviteMembers(workspace.id), {
    instanceOf: SpaceAccessDenied,
  });
});
test('should enter readonly mode when fallback owner storage quota overflows', async t => {
  // Stub the policy's private `quota` service (reached via Reflect.get since
  // the field is not public) so we can script an overflowing snapshot.
  const quota = Sinon.stub(
    Reflect.get(t.context.policy, 'quota') as QuotaService,
    'getWorkspaceQuotaWithUsage'
  );
  // `ownerQuota` set marks the fallback-owner-quota path; usage (2) exceeds
  // the storage quota (1) to trigger storage_overflow.
  quota.resolves({
    name: 'Free',
    blobLimit: 1,
    storageQuota: 1,
    usedStorageQuota: 2,
    historyPeriod: 1,
    memberLimit: 3,
    memberCount: 1,
    overcapacityMemberCount: 0,
    usedSize: 2,
    ownerQuota: owner.id,
  } satisfies WorkspaceQuotaSnapshot);
  const state = await t.context.policy.reconcileWorkspaceQuotaState(
    workspace.id
  );
  t.true(state.isReadonly);
  // Storage overflow is recoverable by freeing blobs, not by removing members.
  t.false(state.canRecoverByRemovingMembers);
  t.true(state.canRecoverByDeletingBlobs);
  t.deepEqual(state.readonlyReasons, ['storage_overflow']);
  // Reconcile must have attached the recovery-mode feature flag.
  t.true(
    await t.context.models.workspaceFeature.has(workspace.id, READONLY_FEATURE)
  );
});
test('should leave readonly mode after workspace usage recovers', async t => {
  // Stub the policy's private `quota` service (reached via Reflect.get since
  // the field is not public) so usage can be scripted across calls.
  const quota = Sinon.stub(
    Reflect.get(t.context.policy, 'quota') as QuotaService,
    'getWorkspaceQuotaWithUsage'
  );
  // The three stubbed snapshots only differ in storage usage, so build them
  // from one factory instead of repeating the full literal. `ownerQuota` set
  // keeps the workspace on the fallback-owner-quota path; the storage quota
  // is 1, so usage 2 overflows and usage 0 is recovered.
  const snapshotWithUsage = (used: number) =>
    ({
      name: 'Free',
      blobLimit: 1,
      storageQuota: 1,
      usedStorageQuota: used,
      historyPeriod: 1,
      memberLimit: 3,
      memberCount: 1,
      overcapacityMemberCount: 0,
      usedSize: used,
      ownerQuota: owner.id,
    }) satisfies WorkspaceQuotaSnapshot;
  // First reconcile sees an overflow; subsequent reads see recovered usage.
  quota.onFirstCall().resolves(snapshotWithUsage(2));
  quota.onSecondCall().resolves(snapshotWithUsage(0));
  quota.onThirdCall().resolves(snapshotWithUsage(0));
  await t.context.policy.reconcileWorkspaceQuotaState(workspace.id);
  // Overflow flagged the workspace readonly.
  t.true(
    await t.context.models.workspaceFeature.has(workspace.id, READONLY_FEATURE)
  );
  const recovered = await t.context.policy.reconcileWorkspaceQuotaState(
    workspace.id
  );
  // Recovery clears both the computed state and the feature flag.
  t.false(recovered.isReadonly);
  t.deepEqual(recovered.readonlyReasons, []);
  t.false(
    await t.context.models.workspaceFeature.has(workspace.id, READONLY_FEATURE)
  );
  // Invitations are allowed again once out of recovery mode.
  await t.notThrowsAsync(t.context.policy.assertCanInviteMembers(workspace.id));
});

View File

@@ -1,3 +1,5 @@
import { randomUUID } from 'node:crypto';
import test from 'ava';
import { createTestingModule, TestingModule } from '../../../__tests__/utils';
@@ -9,24 +11,27 @@ import {
WorkspaceRole,
} from '../../../models';
import { PermissionModule } from '../index';
import { WorkspacePolicyService } from '../policy';
import { mapWorkspaceRoleToPermissions } from '../types';
import { WorkspaceAccessController } from '../workspace';
let module: TestingModule;
let models: Models;
let ac: WorkspaceAccessController;
let policy: WorkspacePolicyService;
let user: User;
let ws: Workspace;
test.before(async () => {
module = await createTestingModule({ imports: [PermissionModule] });
models = module.get<Models>(Models);
ac = new WorkspaceAccessController(models);
ac = module.get(WorkspaceAccessController);
policy = module.get(WorkspacePolicyService);
});
test.beforeEach(async () => {
await module.initTestingDB();
user = await models.user.create({ email: 'u1@affine.pro' });
user = await models.user.create({ email: `${randomUUID()}@affine.pro` });
ws = await models.workspace.create(user.id);
});
@@ -44,7 +49,7 @@ test('should get null role', async t => {
});
test('should return null if role is not accepted', async t => {
const u2 = await models.user.create({ email: 'u2@affine.pro' });
const u2 = await models.user.create({ email: `${randomUUID()}@affine.pro` });
await models.workspaceUser.set(ws.id, u2.id, WorkspaceRole.Collaborator, {
status: WorkspaceMemberStatus.UnderReview,
});
@@ -183,3 +188,38 @@ test('should assert action', async t => {
)
);
});
test('should apply readonly workspace restrictions while keeping cleanup actions', async t => {
  // Overflow the member quota: ten accepted collaborators on this workspace.
  for (let added = 0; added < 10; added++) {
    const member = await models.user.create({
      email: `${randomUUID()}@affine.pro`,
    });
    await models.workspaceUser.set(
      ws.id,
      member.id,
      WorkspaceRole.Collaborator,
      { status: WorkspaceMemberStatus.Accepted }
    );
  }
  await policy.reconcileWorkspaceQuotaState(ws.id);
  const { permissions } = await ac.role({
    workspaceId: ws.id,
    userId: user.id,
  });
  // Write-style actions are revoked while the workspace is in recovery mode.
  for (const action of [
    'Workspace.CreateDoc',
    'Workspace.Settings.Update',
    'Workspace.Properties.Create',
    'Workspace.Properties.Update',
    'Workspace.Properties.Delete',
    'Workspace.Blobs.Write',
  ] as const) {
    t.false(permissions[action]);
  }
  // Read/cleanup actions stay available so the owner can recover the space.
  for (const action of [
    'Workspace.Read',
    'Workspace.Sync',
    'Workspace.Users.Manage',
    'Workspace.Blobs.List',
    'Workspace.TransferOwner',
    'Workspace.Payment.Manage',
  ] as const) {
    t.true(permissions[action]);
  }
});

View File

@@ -3,6 +3,7 @@ import { Injectable } from '@nestjs/common';
import { DocActionDenied } from '../../base';
import { Models } from '../../models';
import { AccessController, getAccessController } from './controller';
import { WorkspacePolicyService } from './policy';
import type { Resource } from './resource';
import {
DocAction,
@@ -15,13 +16,19 @@ import { WorkspaceAccessController } from './workspace';
@Injectable()
export class DocAccessController extends AccessController<'doc'> {
protected readonly type = 'doc';
constructor(private readonly models: Models) {
constructor(
private readonly models: Models,
private readonly policy: WorkspacePolicyService
) {
super();
}
async role(resource: Resource<'doc'>) {
const role = await this.getRole(resource);
const permissions = mapDocRoleToPermissions(role);
const permissions = await this.policy.applyDocPermissions(
resource.workspaceId,
mapDocRoleToPermissions(role)
);
const sharingAllowed = await this.models.workspace.allowSharing(
resource.workspaceId
);

View File

@@ -1,22 +1,29 @@
import { Module } from '@nestjs/common';
import { QuotaService } from '../quota/service';
import { StorageModule } from '../storage';
import { AccessControllerBuilder } from './builder';
import { DocAccessController } from './doc';
import { EventsListener } from './event';
import { WorkspacePolicyService } from './policy';
import { WorkspaceAccessController } from './workspace';
@Module({
imports: [StorageModule],
providers: [
QuotaService,
WorkspaceAccessController,
DocAccessController,
AccessControllerBuilder,
EventsListener,
WorkspacePolicyService,
],
exports: [AccessControllerBuilder],
exports: [AccessControllerBuilder, WorkspacePolicyService],
})
export class PermissionModule {}
export { AccessControllerBuilder as AccessController } from './builder';
export { WorkspacePolicyService } from './policy';
export {
DOC_ACTIONS,
type DocAction,

View File

@@ -0,0 +1,328 @@
import { Injectable } from '@nestjs/common';
import { DocActionDenied, OnEvent, SpaceAccessDenied } from '../../base';
import { Models, WorkspaceRole } from '../../models';
import { QuotaService } from '../quota/service';
import { getAccessController } from './controller';
import type { Resource } from './resource';
import {
type DocAction,
type DocActionPermissions,
mapWorkspaceRoleToPermissions,
type WorkspaceAction,
type WorkspaceActionPermissions,
} from './types';
/** Why a workspace has been forced into readonly recovery mode. */
export type WorkspaceReadonlyReason = 'member_overflow' | 'storage_overflow';
// Quota snapshot as returned by QuotaService, widened with the optional
// `ownerQuota` marker that is present when the workspace bills against its
// owner's personal quota.
type WorkspaceQuotaSnapshot = Awaited<
  ReturnType<QuotaService['getWorkspaceQuotaWithUsage']>
> & {
  ownerQuota?: string;
};
/** Aggregated quota/readonly state computed for a single workspace. */
export type WorkspaceState = {
  isTeamWorkspace: boolean;
  // true when any readonly reason applies
  isReadonly: boolean;
  readonlyReasons: WorkspaceReadonlyReason[];
  // readonly caused by 'member_overflow' — removing members lifts it
  canRecoverByRemovingMembers: boolean;
  // readonly caused by 'storage_overflow' — deleting blobs lifts it
  canRecoverByDeletingBlobs: boolean;
  // workspace is billed against the owner's personal quota
  usesFallbackOwnerQuota: boolean;
};
// Workspace-level write actions revoked while in recovery mode.
const READONLY_WORKSPACE_ACTIONS: WorkspaceAction[] = [
  'Workspace.CreateDoc',
  'Workspace.Settings.Update',
  'Workspace.Properties.Create',
  'Workspace.Properties.Update',
  'Workspace.Properties.Delete',
  'Workspace.Blobs.Write',
];
// Doc-level write actions revoked while the owning workspace is in
// recovery mode.
const READONLY_DOC_ACTIONS: DocAction[] = [
  'Doc.Update',
  'Doc.Duplicate',
  'Doc.Publish',
  'Doc.Comments.Create',
  'Doc.Comments.Update',
  'Doc.Comments.Resolve',
];
// Feature flag that materializes recovery mode on a workspace.
const READONLY_WORKSPACE_FEATURE =
  'quota_exceeded_readonly_workspace_v1' as const;
// Minimal structural view of the workspace access controller, used for the
// raw (readonly-overlay-free) role checks in this module.
type WorkspaceRoleChecker = {
  getRole(resource: Resource<'ws'>): Promise<WorkspaceRole | null>;
  docRoles(
    resource: Resource<'ws'>,
    docIds: string[]
  ): Promise<Array<{ role: unknown; permissions: Record<DocAction, boolean> }>>;
};
// Register the blob-usage event (emitted by blob storage) so the @OnEvent
// handler in this module type-checks.
declare global {
  interface Events {
    'workspace.blobs.updated': {
      workspaceId: string;
    };
  }
}
/**
 * Quota recovery policy for workspaces.
 *
 * A workspace that bills against its owner's personal quota (the fallback
 * path, marked by `ownerQuota` on the quota snapshot) is forced into a
 * readonly "recovery mode" when that quota overflows — too many accepted
 * members, or more storage used than allowed. Recovery mode is materialized
 * as the `quota_exceeded_readonly_workspace_v1` feature flag, which this
 * service keeps in sync with live usage and which the access controllers
 * consult when computing effective permissions.
 */
@Injectable()
export class WorkspacePolicyService {
  constructor(
    private readonly models: Models,
    private readonly quota: QuotaService
  ) {}

  /**
   * Computes the live quota state of a workspace.
   *
   * Readonly reasons are only produced on the fallback-owner-quota path
   * and never for workspaces carrying the `unlimited_workspace` feature.
   */
  async getWorkspaceState(workspaceId: string): Promise<WorkspaceState> {
    const [isTeamWorkspace, isUnlimitedWorkspace, quota] = await Promise.all([
      this.models.workspace.isTeamWorkspace(workspaceId),
      this.models.workspaceFeature.has(workspaceId, 'unlimited_workspace'),
      this.quota.getWorkspaceQuotaWithUsage(workspaceId),
    ]);
    // The base return type does not declare `ownerQuota`; widen to read it.
    const quotaSnapshot = quota as WorkspaceQuotaSnapshot;
    const readonlyReasons: WorkspaceReadonlyReason[] = [];
    const usesFallbackOwnerQuota =
      !!quotaSnapshot.ownerQuota && !isUnlimitedWorkspace;
    if (usesFallbackOwnerQuota && quotaSnapshot.overcapacityMemberCount > 0) {
      readonlyReasons.push('member_overflow');
    }
    if (
      usesFallbackOwnerQuota &&
      quotaSnapshot.usedStorageQuota > quotaSnapshot.storageQuota
    ) {
      readonlyReasons.push('storage_overflow');
    }
    return {
      isTeamWorkspace,
      isReadonly: readonlyReasons.length > 0,
      readonlyReasons,
      canRecoverByRemovingMembers: readonlyReasons.includes('member_overflow'),
      canRecoverByDeletingBlobs: readonlyReasons.includes('storage_overflow'),
      usesFallbackOwnerQuota,
    };
  }

  /** Re-evaluates quota state for every workspace the user owns. */
  async reconcileOwnedWorkspaces(userId: string) {
    const workspaces = await this.models.workspaceUser.getUserActiveRoles(
      userId,
      { role: WorkspaceRole.Owner }
    );
    await Promise.all(
      workspaces.map(({ workspaceId }) =>
        this.reconcileWorkspaceQuotaState(workspaceId)
      )
    );
  }

  /**
   * Syncs the readonly feature flag with the computed quota state, then
   * returns that state.
   *
   * NOTE(review): the check-then-write is not transactional; concurrent
   * reconciles could race on the flag — confirm feature add/remove are
   * idempotent.
   */
  async reconcileWorkspaceQuotaState(workspaceId: string) {
    const [state, isReadonlyFeatureEnabled] = await Promise.all([
      this.getWorkspaceState(workspaceId),
      this.models.workspaceFeature.has(workspaceId, READONLY_WORKSPACE_FEATURE),
    ]);
    if (state.isReadonly && !isReadonlyFeatureEnabled) {
      await this.models.workspaceFeature.add(
        workspaceId,
        READONLY_WORKSPACE_FEATURE,
        `workspace recovery mode: ${state.readonlyReasons.join(',')}`
      );
    } else if (!state.isReadonly && isReadonlyFeatureEnabled) {
      await this.models.workspaceFeature.remove(
        workspaceId,
        READONLY_WORKSPACE_FEATURE
      );
    }
    return state;
  }

  /**
   * Whether the workspace is currently in readonly recovery mode.
   *
   * The feature flag is the cheap first check; when the flag is set but the
   * recomputed state shows usage has recovered, the flag is lazily cleared
   * and `false` is returned.
   */
  async isWorkspaceReadonly(workspaceId: string) {
    const hasReadonlyFeature = await this.models.workspaceFeature.has(
      workspaceId,
      READONLY_WORKSPACE_FEATURE
    );
    if (!hasReadonlyFeature) {
      return false;
    }
    const state = await this.getWorkspaceState(workspaceId);
    if (!state.isReadonly) {
      await this.models.workspaceFeature.remove(
        workspaceId,
        READONLY_WORKSPACE_FEATURE
      );
      return false;
    }
    return true;
  }

  /**
   * Returns `permissions` with the readonly-restricted workspace actions
   * forced to false when the workspace is in recovery mode; otherwise the
   * input object is returned as-is (not copied).
   */
  async applyWorkspacePermissions(
    workspaceId: string,
    permissions: WorkspaceActionPermissions
  ) {
    if (!(await this.isWorkspaceReadonly(workspaceId))) {
      return permissions;
    }
    const next = { ...permissions };
    READONLY_WORKSPACE_ACTIONS.forEach(action => {
      next[action] = false;
    });
    return next;
  }

  /** Doc-level counterpart of {@link applyWorkspacePermissions}. */
  async applyDocPermissions(
    workspaceId: string,
    permissions: DocActionPermissions
  ) {
    if (!(await this.isWorkspaceReadonly(workspaceId))) {
      return permissions;
    }
    const next = { ...permissions };
    READONLY_DOC_ACTIONS.forEach(action => {
      next[action] = false;
    });
    return next;
  }

  /**
   * Throws {@link SpaceAccessDenied} when `action` is readonly-restricted
   * and the workspace is in recovery mode. Unrestricted actions always pass.
   */
  async assertWorkspaceActionAllowed(
    workspaceId: string,
    action: WorkspaceAction
  ) {
    if (
      READONLY_WORKSPACE_ACTIONS.includes(action) &&
      (await this.isWorkspaceReadonly(workspaceId))
    ) {
      throw new SpaceAccessDenied({ spaceId: workspaceId });
    }
  }

  /** Doc-level counterpart of {@link assertWorkspaceActionAllowed}. */
  async assertDocActionAllowed(
    workspaceId: string,
    docId: string,
    action: DocAction
  ) {
    if (
      READONLY_DOC_ACTIONS.includes(action) &&
      (await this.isWorkspaceReadonly(workspaceId))
    ) {
      throw new DocActionDenied({
        action,
        docId,
        spaceId: workspaceId,
      });
    }
  }

  /**
   * Raw role check: throws unless the user's workspace role grants
   * `action`. This deliberately bypasses the readonly overlay — it maps the
   * role to permissions directly — so recovery actions stay available.
   */
  async assertWorkspaceRoleAction(
    userId: string,
    workspaceId: string,
    action: WorkspaceAction
  ) {
    // NOTE(review): the double cast hides the controller's real type;
    // consider exposing getRole on its public surface instead.
    const checker = getAccessController(
      'ws'
    ) as unknown as WorkspaceRoleChecker;
    const role = await checker.getRole({ userId, workspaceId });
    const permissions = mapWorkspaceRoleToPermissions(role);
    if (!permissions[action]) {
      throw new SpaceAccessDenied({ spaceId: workspaceId });
    }
  }

  /** Raw doc role check, mirroring {@link assertWorkspaceRoleAction}. */
  async assertDocRoleAction(
    userId: string,
    workspaceId: string,
    docId: string,
    action: DocAction
  ) {
    const checker = getAccessController(
      'ws'
    ) as unknown as WorkspaceRoleChecker;
    const [role] = await checker.docRoles({ userId, workspaceId }, [docId]);
    if (!role?.permissions[action]) {
      throw new DocActionDenied({
        action,
        docId,
        spaceId: workspaceId,
      });
    }
  }

  /** Uploading blobs is a restricted write: blocked while readonly. */
  async assertCanUploadBlob(workspaceId: string) {
    await this.assertWorkspaceActionAllowed(
      workspaceId,
      'Workspace.Blobs.Write'
    );
  }

  /**
   * Deleting blobs is a recovery action: allowed while readonly, but the
   * user's role must still grant blob writes.
   */
  async assertCanDeleteBlob(userId: string, workspaceId: string) {
    await this.assertWorkspaceRoleAction(
      userId,
      workspaceId,
      'Workspace.Blobs.Write'
    );
  }

  /** New invitations are refused while the workspace is in recovery mode. */
  async assertCanInviteMembers(workspaceId: string) {
    if (await this.isWorkspaceReadonly(workspaceId)) {
      throw new SpaceAccessDenied({ spaceId: workspaceId });
    }
  }

  /**
   * Removing members is a recovery action: gated only by the caller's role
   * (revoking an admin requires the administrators permission).
   */
  async assertCanRevokeMember(
    userId: string,
    workspaceId: string,
    role: WorkspaceRole
  ) {
    await this.assertWorkspaceRoleAction(
      userId,
      workspaceId,
      role === WorkspaceRole.Admin
        ? 'Workspace.Administrators.Manage'
        : 'Workspace.Users.Manage'
    );
  }

  /** Publishing is a restricted doc write: blocked while readonly. */
  async assertCanPublishDoc(workspaceId: string, docId: string) {
    await this.assertDocActionAllowed(workspaceId, docId, 'Doc.Publish');
  }

  /** Unpublishing is a recovery action: only the raw doc role is checked. */
  async assertCanUnpublishDoc(
    userId: string,
    workspaceId: string,
    docId: string
  ) {
    await this.assertDocRoleAction(userId, workspaceId, docId, 'Doc.Publish');
  }

  // Keep the readonly flag current whenever membership, ownership, or blob
  // usage changes.
  @OnEvent('workspace.members.updated')
  async onWorkspaceMembersUpdated({
    workspaceId,
  }: Events['workspace.members.updated']) {
    await this.reconcileWorkspaceQuotaState(workspaceId);
  }

  @OnEvent('workspace.owner.changed')
  async onWorkspaceOwnerChanged({
    workspaceId,
  }: Events['workspace.owner.changed']) {
    await this.reconcileWorkspaceQuotaState(workspaceId);
  }

  @OnEvent('workspace.blobs.updated')
  async onWorkspaceBlobsUpdated({
    workspaceId,
  }: Events['workspace.blobs.updated']) {
    await this.reconcileWorkspaceQuotaState(workspaceId);
  }
}

View File

@@ -3,6 +3,7 @@ import { Injectable } from '@nestjs/common';
import { SpaceAccessDenied } from '../../base';
import { DocRole, Models } from '../../models';
import { AccessController } from './controller';
import { WorkspacePolicyService } from './policy';
import type { Resource } from './resource';
import {
fixupDocRole,
@@ -17,7 +18,10 @@ import {
export class WorkspaceAccessController extends AccessController<'ws'> {
protected readonly type = 'ws';
constructor(private readonly models: Models) {
constructor(
private readonly models: Models,
private readonly policy: WorkspacePolicyService
) {
super();
}
@@ -37,7 +41,10 @@ export class WorkspaceAccessController extends AccessController<'ws'> {
return {
role,
permissions: mapWorkspaceRoleToPermissions(role),
permissions: await this.policy.applyWorkspacePermissions(
resource.workspaceId,
mapWorkspaceRoleToPermissions(role)
),
};
}

View File

@@ -1,6 +1,5 @@
import { Module } from '@nestjs/common';
import { PermissionModule } from '../permission';
import { StorageModule } from '../storage';
import { QuotaResolver } from './resolver';
import { QuotaService } from './service';
@@ -12,7 +11,7 @@ import { QuotaService } from './service';
* - quota statistics
*/
@Module({
imports: [StorageModule, PermissionModule],
imports: [StorageModule],
providers: [QuotaService, QuotaResolver],
exports: [QuotaService],
})

View File

@@ -20,7 +20,10 @@ type UserQuotaWithUsage = Omit<UserQuotaType, 'humanReadable'>;
type WorkspaceQuota = Omit<BaseWorkspaceQuota, 'seatQuota'> & {
ownerQuota?: string;
};
type WorkspaceQuotaWithUsage = Omit<WorkspaceQuotaType, 'humanReadable'>;
export type WorkspaceQuotaWithUsage = Omit<
WorkspaceQuotaType,
'humanReadable'
> & { ownerQuota?: string };
@Injectable()
export class QuotaService {

View File

@@ -26,6 +26,9 @@ declare global {
workspaceId: string;
key: string;
};
'workspace.blobs.updated': {
workspaceId: string;
};
}
}
@@ -255,6 +258,9 @@ export class WorkspaceBlobStorage {
await this.provider.delete(`${workspaceId}/${key}`);
}
await this.models.blob.delete(workspaceId, key, permanently);
if (!permanently) {
await this.event.emitAsync('workspace.blobs.updated', { workspaceId });
}
}
async release(workspaceId: string) {
@@ -270,6 +276,8 @@ export class WorkspaceBlobStorage {
this.logger.log(
`released ${deletedBlobs.length} blobs for workspace ${workspaceId}`
);
await this.event.emitAsync('workspace.blobs.updated', { workspaceId });
}
async totalSize(workspaceId: string) {

View File

@@ -624,6 +624,7 @@ export class SpaceSyncGateway
const { spaceType, spaceId, docId, update } = message;
const adapter = this.selectAdapter(client, spaceType);
// Quota recovery mode is intentionally not applied to sync in this phase.
// TODO(@forehalo): enable after frontend supporting doc revert
// await this.ac.user(user.id).doc(spaceId, docId).assert('Doc.Update');
const timestamp = await adapter.push(

View File

@@ -25,7 +25,7 @@ import {
} from '../../../base';
import { Models } from '../../../models';
import { CurrentUser } from '../../auth';
import { AccessController } from '../../permission';
import { AccessController, WorkspacePolicyService } from '../../permission';
import { QuotaService } from '../../quota';
import { WorkspaceBlobStorage } from '../../storage';
import {
@@ -126,6 +126,7 @@ export class WorkspaceBlobResolver {
logger = new Logger(WorkspaceBlobResolver.name);
constructor(
private readonly ac: AccessController,
private readonly policy: WorkspacePolicyService,
private readonly quota: QuotaService,
private readonly storage: WorkspaceBlobStorage,
private readonly models: Models
@@ -466,10 +467,7 @@ export class WorkspaceBlobResolver {
return false;
}
await this.ac
.user(user.id)
.workspace(workspaceId)
.assert('Workspace.Blobs.Write');
await this.policy.assertCanDeleteBlob(user.id, workspaceId);
await this.storage.delete(workspaceId, key, permanently);
@@ -481,10 +479,7 @@ export class WorkspaceBlobResolver {
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string
) {
await this.ac
.user(user.id)
.workspace(workspaceId)
.assert('Workspace.Blobs.Write');
await this.policy.assertCanDeleteBlob(user.id, workspaceId);
await this.storage.release(workspaceId);

View File

@@ -38,6 +38,7 @@ import {
DOC_ACTIONS,
DocAction,
DocRole,
WorkspacePolicyService,
} from '../../permission';
import { PublicUserType, WorkspaceUserType } from '../../user';
import { WorkspaceType } from '../types';
@@ -295,6 +296,7 @@ export class WorkspaceDocResolver {
*/
private readonly prisma: PrismaClient,
private readonly ac: AccessController,
private readonly policy: WorkspacePolicyService,
private readonly models: Models,
private readonly cache: Cache
) {}
@@ -437,7 +439,7 @@ export class WorkspaceDocResolver {
throw new ExpectToRevokePublicDoc('Expect doc not to be workspace');
}
await this.ac.user(user.id).doc(workspaceId, docId).assert('Doc.Publish');
await this.policy.assertCanUnpublishDoc(user.id, workspaceId, docId);
const doc = await this.models.doc.unpublish(workspaceId, docId);

View File

@@ -36,7 +36,11 @@ import {
} from '../../../base';
import { Models } from '../../../models';
import { CurrentUser, Public } from '../../auth';
import { AccessController, WorkspaceRole } from '../../permission';
import {
AccessController,
WorkspacePolicyService,
WorkspaceRole,
} from '../../permission';
import { QuotaService } from '../../quota';
import { UserType } from '../../user';
import { validators } from '../../utils/validators';
@@ -64,6 +68,7 @@ export class WorkspaceMemberResolver {
private readonly ac: AccessController,
private readonly models: Models,
private readonly mutex: RequestMutex,
private readonly policy: WorkspacePolicyService,
private readonly workspaceService: WorkspaceService,
private readonly quota: QuotaService
) {}
@@ -304,10 +309,11 @@ export class WorkspaceMemberResolver {
@CurrentUser() user: CurrentUser,
@Args('workspaceId') workspaceId: string
) {
await this.ac
.user(user.id)
.workspace(workspaceId)
.assert('Workspace.Users.Manage');
await this.policy.assertWorkspaceRoleAction(
user.id,
workspaceId,
'Workspace.Users.Manage'
);
const cacheId = `workspace:inviteLink:${workspaceId}`;
return await this.cache.delete(cacheId);
@@ -359,6 +365,7 @@ export class WorkspaceMemberResolver {
role.id,
me.id
);
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
}
return true;
} else {
@@ -453,14 +460,7 @@ export class WorkspaceMemberResolver {
throw new MemberNotFoundInSpace({ spaceId: workspaceId });
}
await this.ac
.user(me.id)
.workspace(workspaceId)
.assert(
role.type === WorkspaceRole.Admin
? 'Workspace.Administrators.Manage'
: 'Workspace.Users.Manage'
);
await this.policy.assertCanRevokeMember(me.id, workspaceId, role.type);
await this.models.workspaceUser.delete(workspaceId, userId);
@@ -480,6 +480,7 @@ export class WorkspaceMemberResolver {
this.event.emit('workspace.members.updated', {
workspaceId,
});
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
return true;
}
@@ -580,11 +581,20 @@ export class WorkspaceMemberResolver {
this.event.emit('workspace.members.updated', {
workspaceId,
});
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
return true;
}
private async acceptInvitationByEmail(role: WorkspaceUserRole) {
await this.policy.assertCanInviteMembers(role.workspaceId);
const hasSeat = await this.quota.tryCheckSeat(role.workspaceId, true);
if (!hasSeat) {
throw new NoMoreSeat({ spaceId: role.workspaceId });
}
await this.models.workspaceUser.setStatus(
role.workspaceId,
role.userId,
@@ -596,6 +606,7 @@ export class WorkspaceMemberResolver {
(await this.models.workspaceUser.getOwner(role.workspaceId)).id,
role.id
);
await this.policy.reconcileWorkspaceQuotaState(role.workspaceId);
}
private async acceptInvitationByLink(
@@ -603,6 +614,8 @@ export class WorkspaceMemberResolver {
workspaceId: string,
inviterId: string
) {
await this.policy.assertCanInviteMembers(workspaceId);
let inviter = await this.models.user.getPublicUser(inviterId);
if (!inviter) {
inviter = await this.models.workspaceUser.getOwner(workspaceId);

View File

@@ -1,12 +0,0 @@
import { ModuleRef } from '@nestjs/core';
import { PrismaClient } from '@prisma/client';
import { IndexerService } from '../../plugins/indexer';
export class RebuildManticoreMixedScriptIndexes1763800000000 {
static async up(_db: PrismaClient, ref: ModuleRef) {
await ref.get(IndexerService, { strict: false }).rebuildManticoreIndexes();
}
static async down(_db: PrismaClient) {}
}

View File

@@ -3,4 +3,3 @@ export * from './1703756315970-unamed-account';
export * from './1721299086340-refresh-unnamed-user';
export * from './1745211351719-create-indexer-tables';
export * from './1751966744168-correct-session-update-time';
export * from './1763800000000-rebuild-manticore-mixed-script-indexes';

View File

@@ -53,6 +53,7 @@ export enum Feature {
// workspace
UnlimitedWorkspace = 'unlimited_workspace',
TeamPlan = 'team_plan_v1',
QuotaExceededReadonlyWorkspace = 'quota_exceeded_readonly_workspace_v1',
}
// TODO(@forehalo): may merge `FeatureShapes` and `FeatureConfigs`?
@@ -66,6 +67,7 @@ export const FeaturesShapes = {
pro_plan_v1: UserPlanQuotaConfig,
lifetime_pro_plan_v1: UserPlanQuotaConfig,
team_plan_v1: WorkspaceQuotaConfig,
quota_exceeded_readonly_workspace_v1: EMPTY_CONFIG,
} satisfies Record<Feature, z.ZodObject<any>>;
export type UserFeatureName = keyof Pick<
@@ -80,7 +82,9 @@ export type UserFeatureName = keyof Pick<
>;
export type WorkspaceFeatureName = keyof Pick<
typeof FeaturesShapes,
'unlimited_workspace' | 'team_plan_v1'
| 'unlimited_workspace'
| 'team_plan_v1'
| 'quota_exceeded_readonly_workspace_v1'
>;
export type FeatureName = UserFeatureName | WorkspaceFeatureName;
@@ -162,6 +166,7 @@ export const FeatureConfigs: {
team_plan_v1: TeamFeature,
early_access: WhitelistFeature,
unlimited_workspace: EmptyFeature,
quota_exceeded_readonly_workspace_v1: EmptyFeature,
unlimited_copilot: EmptyFeature,
ai_early_access: EmptyFeature,
administrator: EmptyFeature,

View File

@@ -36,7 +36,8 @@ export class WorkspaceUserModel extends BaseModel {
/**
* Set or update the [Owner] of a workspace.
* The old [Owner] will be changed to [Admin] if there is already an [Owner].
* The old [Owner] will be changed to [Admin] for team workspace and
* [Collaborator] for owned workspace if there is already an [Owner].
*/
@Transactional()
async setOwner(workspaceId: string, userId: string) {
@@ -63,12 +64,18 @@ export class WorkspaceUserModel extends BaseModel {
throw new NewOwnerIsNotActiveMember();
}
const fallbackRole = (await this.models.workspace.isTeamWorkspace(
workspaceId
))
? WorkspaceRole.Admin
: WorkspaceRole.Collaborator;
await this.db.workspaceUserRole.update({
where: {
id: oldOwner.id,
},
data: {
type: WorkspaceRole.Admin,
type: fallbackRole,
},
});
await this.db.workspaceUserRole.update({
@@ -201,6 +208,25 @@ export class WorkspaceUserModel extends BaseModel {
});
}
async deleteNonAccepted(workspaceId: string) {
return await this.db.workspaceUserRole.deleteMany({
where: { workspaceId, status: { not: WorkspaceMemberStatus.Accepted } },
});
}
async demoteAcceptedAdmins(workspaceId: string) {
return await this.db.workspaceUserRole.updateMany({
where: {
workspaceId,
status: WorkspaceMemberStatus.Accepted,
type: WorkspaceRole.Admin,
},
data: {
type: WorkspaceRole.Collaborator,
},
});
}
async get(workspaceId: string, userId: string) {
return await this.db.workspaceUserRole.findUnique({
where: {

View File

@@ -37,7 +37,10 @@ import {
UserFriendlyError,
} from '../../../base';
import { CurrentUser } from '../../../core/auth';
import { AccessController } from '../../../core/permission';
import {
AccessController,
WorkspacePolicyService,
} from '../../../core/permission';
import {
ContextBlob,
ContextCategories,
@@ -408,6 +411,7 @@ export class CopilotContextRootResolver {
export class CopilotContextResolver {
constructor(
private readonly ac: AccessController,
private readonly policy: WorkspacePolicyService,
private readonly models: Models,
private readonly mutex: RequestMutex,
private readonly context: CopilotContextService,
@@ -667,6 +671,7 @@ export class CopilotContextResolver {
const blobId = createHash('sha256').update(buffer).digest('base64url');
const { filename, mimetype } = content;
await this.policy.assertCanUploadBlob(session.workspaceId);
await this.storage.put(user.id, session.workspaceId, blobId, buffer);
const file = await session.addFile(
blobId,

View File

@@ -37,7 +37,11 @@ import {
import { CurrentUser } from '../../core/auth';
import { Admin } from '../../core/common';
import { DocReader } from '../../core/doc';
import { AccessController, DocAction } from '../../core/permission';
import {
AccessController,
DocAction,
WorkspacePolicyService,
} from '../../core/permission';
import { UserType } from '../../core/user';
import type { ListSessionOptions, UpdateChatSession } from '../../models';
import { processImage } from '../../native';
@@ -378,6 +382,7 @@ export class CopilotResolver {
constructor(
private readonly ac: AccessController,
private readonly mutex: RequestMutex,
private readonly policy: WorkspacePolicyService,
private readonly prompt: PromptService,
private readonly chatSession: ChatSessionService,
private readonly storage: CopilotStorage,
@@ -778,6 +783,10 @@ export class CopilotResolver {
delete options.blob;
delete options.blobs;
if (blobs.length) {
await this.policy.assertCanUploadBlob(workspaceId);
}
for (const blob of blobs) {
const uploaded = await this.storage.handleUpload(user.id, blob);
const detectedMime =

View File

@@ -24,7 +24,10 @@ import {
UserFriendlyError,
} from '../../../base';
import { CurrentUser } from '../../../core/auth';
import { AccessController } from '../../../core/permission';
import {
AccessController,
WorkspacePolicyService,
} from '../../../core/permission';
import { WorkspaceType } from '../../../core/workspaces';
import { COPILOT_LOCKER } from '../resolver';
import { MAX_EMBEDDABLE_SIZE } from '../utils';
@@ -72,6 +75,7 @@ export class CopilotWorkspaceEmbeddingConfigResolver {
constructor(
private readonly ac: AccessController,
private readonly mutex: Mutex,
private readonly policy: WorkspacePolicyService,
private readonly copilotWorkspace: CopilotWorkspaceService
) {}
@@ -215,10 +219,11 @@ export class CopilotWorkspaceEmbeddingConfigResolver {
@Args('fileId', { type: () => String })
fileId: string
): Promise<boolean> {
await this.ac
.user(user.id)
.workspace(workspaceId)
.assert('Workspace.Settings.Update');
await this.policy.assertWorkspaceRoleAction(
user.id,
workspaceId,
'Workspace.Settings.Update'
);
return await this.copilotWorkspace.removeFile(workspaceId, fileId);
}

View File

@@ -4,75 +4,6 @@ The actual snapshot is saved in `manticoresearch.spec.ts.snap`.
Generated by [AVA](https://avajs.dev).
## should search doc title match chinese word segmentation
> Snapshot 1
[
{
_id: '5373363211628325828',
_source: {
doc_id: 'doc-chinese',
workspace_id: 'workspace-test-doc-title-chinese',
},
fields: {
doc_id: [
'doc-chinese',
],
title: [
'AFFiNE 是一个基于云端的笔记应用',
],
},
highlights: undefined,
},
]
## should search block content match korean ngram
> Snapshot 1
[
{
_id: '1227635764506850985',
_source: {
doc_id: 'doc-korean',
workspace_id: 'workspace-test-block-content-korean',
},
fields: {
block_id: [
'block-korean',
],
content: [
'다람쥐 헌 쳇바퀴에 타고파',
],
},
highlights: undefined,
},
]
## should search block content match japanese kana ngram
> Snapshot 1
[
{
_id: '381498385699454292',
_source: {
doc_id: 'doc-japanese',
workspace_id: 'workspace-test-block-content-japanese',
},
fields: {
block_id: [
'block-japanese',
],
content: [
'いろはにほへと ちりぬるを',
],
},
highlights: undefined,
},
]
## should write document work
> Snapshot 1
@@ -958,7 +889,7 @@ Generated by [AVA](https://avajs.dev).
> Snapshot 1
{
equals: {
term: {
workspace_id: 'workspaceId1',
},
}
@@ -966,7 +897,7 @@ Generated by [AVA](https://avajs.dev).
> Snapshot 2
{
equals: {
term: {
workspace_id: 'workspaceId1',
},
}

View File

@@ -33,8 +33,8 @@ const user = await module.create(Mockers.User);
const workspace = await module.create(Mockers.Workspace);
test.before(async () => {
await searchProvider.recreateTable(SearchTable.block, blockSQL);
await searchProvider.recreateTable(SearchTable.doc, docSQL);
await searchProvider.createTable(SearchTable.block, blockSQL);
await searchProvider.createTable(SearchTable.doc, docSQL);
await searchProvider.write(
SearchTable.block,
@@ -163,135 +163,6 @@ test('should provider is manticoresearch', t => {
t.is(searchProvider.type, SearchProviderType.Manticoresearch);
});
test('should search doc title match chinese word segmentation', async t => {
const workspaceId = 'workspace-test-doc-title-chinese';
const docId = 'doc-chinese';
const title = 'AFFiNE 是一个基于云端的笔记应用';
await searchProvider.write(
SearchTable.doc,
[
{
workspace_id: workspaceId,
doc_id: docId,
title,
},
],
{
refresh: true,
}
);
const result = await searchProvider.search(SearchTable.doc, {
_source: ['workspace_id', 'doc_id'],
query: {
bool: {
must: [
{ term: { workspace_id: { value: workspaceId } } },
{ match: { title: '笔记' } },
],
},
},
fields: ['doc_id', 'title'],
sort: ['_score'],
});
t.true(result.total >= 1);
t.snapshot(
result.nodes
.filter(node => node._source.doc_id === docId)
.map(node => omit(node, ['_score']))
);
});
test('should search block content match korean ngram', async t => {
const workspaceId = 'workspace-test-block-content-korean';
const docId = 'doc-korean';
const blockId = 'block-korean';
const content = '다람쥐 헌 쳇바퀴에 타고파';
await searchProvider.write(
SearchTable.block,
[
{
workspace_id: workspaceId,
doc_id: docId,
block_id: blockId,
content,
flavour: 'affine:paragraph',
},
],
{
refresh: true,
}
);
const result = await searchProvider.search(SearchTable.block, {
_source: ['workspace_id', 'doc_id'],
query: {
bool: {
must: [
{ term: { workspace_id: { value: workspaceId } } },
{ match: { content: '쥐' } },
],
},
},
fields: ['block_id', 'content'],
sort: ['_score'],
});
t.true(result.total >= 1);
t.snapshot(
result.nodes
.filter(node => node.fields.block_id?.[0] === blockId)
.map(node => omit(node, ['_score']))
);
});
test('should search block content match japanese kana ngram', async t => {
const workspaceId = 'workspace-test-block-content-japanese';
const docId = 'doc-japanese';
const blockId = 'block-japanese';
const content = 'いろはにほへと ちりぬるを';
await searchProvider.write(
SearchTable.block,
[
{
workspace_id: workspaceId,
doc_id: docId,
block_id: blockId,
content,
flavour: 'affine:paragraph',
},
],
{
refresh: true,
}
);
const result = await searchProvider.search(SearchTable.block, {
_source: ['workspace_id', 'doc_id'],
query: {
bool: {
must: [
{ term: { workspace_id: { value: workspaceId } } },
{ match: { content: 'へ' } },
],
},
},
fields: ['block_id', 'content'],
sort: ['_score'],
});
t.true(result.total >= 1);
t.snapshot(
result.nodes
.filter(node => node.fields.block_id?.[0] === blockId)
.map(node => omit(node, ['_score']))
);
});
// #region write
test('should write document work', async t => {
@@ -318,7 +189,7 @@ test('should write document work', async t => {
let result = await searchProvider.search(SearchTable.block, {
_source: ['workspace_id', 'doc_id'],
query: { term: { doc_id: { value: docId } } },
query: { match: { doc_id: docId } },
fields: [
'flavour',
'flavour_indexed',
@@ -361,7 +232,7 @@ test('should write document work', async t => {
result = await searchProvider.search(SearchTable.block, {
_source: ['workspace_id', 'doc_id'],
query: { term: { doc_id: { value: docId } } },
query: { match: { doc_id: docId } },
fields: ['flavour', 'block_id', 'content', 'ref_doc_id'],
sort: ['_score'],
});
@@ -392,7 +263,7 @@ test('should write document work', async t => {
result = await searchProvider.search(SearchTable.block, {
_source: ['workspace_id', 'doc_id'],
query: { term: { doc_id: { value: docId } } },
query: { match: { doc_id: docId } },
fields: ['flavour', 'block_id', 'content', 'ref_doc_id'],
sort: ['_score'],
});
@@ -448,8 +319,8 @@ test('should handle ref_doc_id as string[]', async t => {
query: {
bool: {
must: [
{ term: { workspace_id: { value: workspaceId } } },
{ term: { doc_id: { value: docId } } },
{ match: { workspace_id: workspaceId } },
{ match: { doc_id: docId } },
],
},
},
@@ -500,8 +371,8 @@ test('should handle ref_doc_id as string[]', async t => {
query: {
bool: {
must: [
{ term: { workspace_id: { value: workspaceId } } },
{ term: { doc_id: { value: docId } } },
{ match: { workspace_id: workspaceId } },
{ match: { doc_id: docId } },
],
},
},
@@ -545,8 +416,8 @@ test('should handle content as string[]', async t => {
query: {
bool: {
must: [
{ term: { workspace_id: { value: workspaceId } } },
{ term: { doc_id: { value: docId } } },
{ match: { workspace_id: workspaceId } },
{ match: { doc_id: docId } },
],
},
},
@@ -584,8 +455,8 @@ test('should handle content as string[]', async t => {
query: {
bool: {
must: [
{ term: { workspace_id: { value: workspaceId } } },
{ term: { doc_id: { value: docId } } },
{ match: { workspace_id: workspaceId } },
{ match: { doc_id: docId } },
],
},
},
@@ -626,8 +497,8 @@ test('should handle blob as string[]', async t => {
query: {
bool: {
must: [
{ term: { workspace_id: { value: workspaceId } } },
{ term: { doc_id: { value: docId } } },
{ match: { workspace_id: workspaceId } },
{ match: { doc_id: docId } },
],
},
},
@@ -663,8 +534,8 @@ test('should handle blob as string[]', async t => {
query: {
bool: {
must: [
{ term: { workspace_id: { value: workspaceId } } },
{ term: { doc_id: { value: docId } } },
{ match: { workspace_id: workspaceId } },
{ match: { doc_id: docId } },
],
},
},
@@ -700,8 +571,8 @@ test('should handle blob as string[]', async t => {
query: {
bool: {
must: [
{ term: { workspace_id: { value: workspaceId } } },
{ term: { doc_id: { value: docId } } },
{ match: { workspace_id: workspaceId } },
{ match: { doc_id: docId } },
],
},
},
@@ -811,10 +682,8 @@ test('should search query all and get next cursor work', async t => {
'id',
],
query: {
term: {
workspace_id: {
value: workspaceId,
},
match: {
workspace_id: workspaceId,
},
},
fields: ['flavour', 'workspace_id', 'doc_id', 'block_id'],
@@ -839,10 +708,8 @@ test('should search query all and get next cursor work', async t => {
'id',
],
query: {
term: {
workspace_id: {
value: workspaceId,
},
match: {
workspace_id: workspaceId,
},
},
fields: ['flavour', 'workspace_id', 'doc_id', 'block_id'],
@@ -867,10 +734,8 @@ test('should search query all and get next cursor work', async t => {
'id',
],
query: {
term: {
workspace_id: {
value: workspaceId,
},
match: {
workspace_id: workspaceId,
},
},
fields: ['flavour', 'workspace_id', 'doc_id', 'block_id'],
@@ -915,20 +780,16 @@ test('should filter by workspace_id work', async t => {
bool: {
must: [
{
term: {
workspace_id: {
value: workspaceId,
},
match: {
workspace_id: workspaceId,
},
},
{
bool: {
must: [
{
term: {
doc_id: {
value: docId,
},
match: {
doc_id: docId,
},
},
],

View File

@@ -8,12 +8,11 @@ import { createModule } from '../../../__tests__/create-module';
import { Mockers } from '../../../__tests__/mocks';
import { ConfigModule } from '../../../base/config';
import { ServerConfigModule } from '../../../core/config';
import { Models } from '../../../models';
import { SearchProviderFactory } from '../factory';
import { IndexerModule, IndexerService } from '../index';
import { ManticoresearchProvider } from '../providers';
import { UpsertDoc } from '../service';
import { blockSQL, docSQL, SearchTable } from '../tables';
import { SearchTable } from '../tables';
import {
AggregateInput,
SearchInput,
@@ -36,7 +35,6 @@ const module = await createModule({
const indexerService = module.get(IndexerService);
const searchProviderFactory = module.get(SearchProviderFactory);
const manticoresearch = module.get(ManticoresearchProvider);
const models = module.get(Models);
const user = await module.create(Mockers.User);
const workspace = await module.create(Mockers.Workspace, {
snapshot: true,
@@ -52,8 +50,7 @@ test.after.always(async () => {
});
test.before(async () => {
await manticoresearch.recreateTable(SearchTable.block, blockSQL);
await manticoresearch.recreateTable(SearchTable.doc, docSQL);
await indexerService.createTables();
});
test.afterEach.always(async () => {
@@ -2314,29 +2311,3 @@ test('should search docs by keyword work', async t => {
});
// #endregion
test('should rebuild manticore indexes and requeue workspaces', async t => {
const workspace1 = await module.create(Mockers.Workspace, {
indexed: true,
});
const workspace2 = await module.create(Mockers.Workspace, {
indexed: true,
});
const queueCount = module.queue.count('indexer.indexWorkspace');
await indexerService.rebuildManticoreIndexes();
const queuedWorkspaceIds = new Set(
module.queue.add
.getCalls()
.filter(call => call.args[0] === 'indexer.indexWorkspace')
.slice(queueCount)
.map(call => call.args[1].workspaceId)
);
t.true(queuedWorkspaceIds.has(workspace1.id));
t.true(queuedWorkspaceIds.has(workspace2.id));
t.is((await models.workspace.get(workspace1.id))?.indexed, false);
t.is((await models.workspace.get(workspace2.id))?.indexed, false);
});

View File

@@ -38,17 +38,6 @@ const SupportIndexedAttributes = [
'parent_block_id',
];
const SupportExactTermFields = new Set([
'workspace_id',
'doc_id',
'block_id',
'flavour',
'parent_flavour',
'parent_block_id',
'created_by_user_id',
'updated_by_user_id',
]);
const ConvertEmptyStringToNullValueFields = new Set([
'ref_doc_id',
'ref',
@@ -66,20 +55,23 @@ export class ManticoresearchProvider extends ElasticsearchProvider {
table: SearchTable,
mapping: string
): Promise<void> {
const text = await this.#executeSQL(mapping);
const url = `${this.config.provider.endpoint}/cli`;
const response = await fetch(url, {
method: 'POST',
body: mapping,
headers: {
'Content-Type': 'text/plain',
},
});
// manticoresearch cli response is not json, so we need to handle it manually
const text = (await response.text()).trim();
if (!response.ok) {
this.logger.error(`failed to create table ${table}, response: ${text}`);
throw new InternalServerError();
}
this.logger.log(`created table ${table}, response: ${text}`);
}
async dropTable(table: SearchTable): Promise<void> {
const text = await this.#executeSQL(`DROP TABLE IF EXISTS ${table}`);
this.logger.log(`dropped table ${table}, response: ${text}`);
}
async recreateTable(table: SearchTable, mapping: string): Promise<void> {
await this.dropTable(table);
await this.createTable(table, mapping);
}
override async write(
table: SearchTable,
documents: Record<string, unknown>[],
@@ -260,12 +252,6 @@ export class ManticoresearchProvider extends ElasticsearchProvider {
// 1750389254 => new Date(1750389254 * 1000)
return new Date(value * 1000);
}
if (value && typeof value === 'string') {
const timestamp = Date.parse(value);
if (!Number.isNaN(timestamp)) {
return new Date(timestamp);
}
}
return value;
}
@@ -316,10 +302,8 @@ export class ManticoresearchProvider extends ElasticsearchProvider {
// workspace_id: 'workspaceId1'
// }
// }
let termField = options?.termMappingField ?? 'term';
let field = Object.keys(query.term)[0];
let termField =
options?.termMappingField ??
(SupportExactTermFields.has(field) ? 'equals' : 'term');
let value = query.term[field];
if (typeof value === 'object' && 'value' in value) {
if ('boost' in value) {
@@ -448,28 +432,4 @@ export class ManticoresearchProvider extends ElasticsearchProvider {
}
return value;
}
async #executeSQL(sql: string) {
const url = `${this.config.provider.endpoint}/cli`;
const headers: Record<string, string> = {
'Content-Type': 'text/plain',
};
if (this.config.provider.apiKey) {
headers.Authorization = `ApiKey ${this.config.provider.apiKey}`;
} else if (this.config.provider.password) {
headers.Authorization = `Basic ${Buffer.from(`${this.config.provider.username}:${this.config.provider.password}`).toString('base64')}`;
}
const response = await fetch(url, {
method: 'POST',
body: sql,
headers,
});
const text = (await response.text()).trim();
if (!response.ok) {
this.logger.error(`failed to execute SQL "${sql}", response: ${text}`);
throw new InternalServerError();
}
return text;
}
}

View File

@@ -14,7 +14,6 @@ import {
AggregateQueryDSL,
BaseQueryDSL,
HighlightDSL,
ManticoresearchProvider,
OperationOptions,
SearchNode,
SearchProvider,
@@ -131,63 +130,6 @@ export class IndexerService {
}
}
async rebuildManticoreIndexes() {
let searchProvider: SearchProvider | undefined;
try {
searchProvider = this.factory.get();
} catch (err) {
if (err instanceof SearchProviderNotFound) {
this.logger.debug('No search provider found, skip rebuilding tables');
return;
}
throw err;
}
if (!(searchProvider instanceof ManticoresearchProvider)) {
this.logger.debug(
`Search provider ${searchProvider.type} does not need manticore rebuild`
);
return;
}
const mappings = SearchTableMappingStrings[searchProvider.type];
for (const table of Object.keys(mappings) as SearchTable[]) {
await searchProvider.recreateTable(table, mappings[table]);
}
let lastWorkspaceSid = 0;
while (true) {
const workspaces = await this.models.workspace.list(
{ sid: { gt: lastWorkspaceSid } },
{ id: true, sid: true },
100
);
if (!workspaces.length) {
break;
}
for (const workspace of workspaces) {
await this.models.workspace.update(
workspace.id,
{ indexed: false },
false
);
await this.queue.add(
'indexer.indexWorkspace',
{
workspaceId: workspace.id,
},
{
jobId: `indexWorkspace/${workspace.id}`,
priority: 100,
}
);
}
lastWorkspaceSid = workspaces[workspaces.length - 1].sid;
}
}
async write<T extends SearchTable>(
table: T,
documents: UpsertTypeByTable<T>[],

View File

@@ -150,8 +150,6 @@ CREATE TABLE IF NOT EXISTS block (
updated_at timestamp
)
morphology = 'jieba_chinese, lemmatize_en_all, lemmatize_de_all, lemmatize_ru_all, libstemmer_ar, libstemmer_ca, stem_cz, libstemmer_da, libstemmer_nl, libstemmer_fi, libstemmer_fr, libstemmer_el, libstemmer_hi, libstemmer_hu, libstemmer_id, libstemmer_ga, libstemmer_it, libstemmer_lt, libstemmer_ne, libstemmer_no, libstemmer_pt, libstemmer_ro, libstemmer_es, libstemmer_sv, libstemmer_ta, libstemmer_tr'
charset_table = 'non_cjk, chinese'
ngram_len = '1'
ngram_chars = 'U+1100..U+11FF, U+3130..U+318F, U+A960..U+A97F, U+AC00..U+D7AF, U+D7B0..U+D7FF, U+3040..U+30FF, U+0E00..U+0E7F'
charset_table = 'non_cjk, cjk'
index_field_lengths = '1'
`;

View File

@@ -109,8 +109,6 @@ CREATE TABLE IF NOT EXISTS doc (
updated_at timestamp
)
morphology = 'jieba_chinese, lemmatize_en_all, lemmatize_de_all, lemmatize_ru_all, libstemmer_ar, libstemmer_ca, stem_cz, libstemmer_da, libstemmer_nl, libstemmer_fi, libstemmer_fr, libstemmer_el, libstemmer_hi, libstemmer_hu, libstemmer_id, libstemmer_ga, libstemmer_it, libstemmer_lt, libstemmer_ne, libstemmer_no, libstemmer_pt, libstemmer_ro, libstemmer_es, libstemmer_sv, libstemmer_ta, libstemmer_tr'
charset_table = 'non_cjk, chinese'
ngram_len = '1'
ngram_chars = 'U+1100..U+11FF, U+3130..U+318F, U+A960..U+A97F, U+AC00..U+D7AF, U+D7B0..U+D7FF, U+3040..U+30FF, U+0E00..U+0E7F'
charset_table = 'non_cjk, cjk'
index_field_lengths = '1'
`;

View File

@@ -16,6 +16,7 @@ import {
UserFriendlyError,
WorkspaceLicenseAlreadyExists,
} from '../../base';
import { WorkspacePolicyService } from '../../core/permission';
import { Models } from '../../models';
import {
SubscriptionPlan,
@@ -59,7 +60,8 @@ export class LicenseService {
private readonly db: PrismaClient,
private readonly event: EventBus,
private readonly models: Models,
private readonly crypto: CryptoHelper
private readonly crypto: CryptoHelper,
private readonly policy: WorkspacePolicyService
) {}
@OnEvent('workspace.subscription.activated')
@@ -83,6 +85,7 @@ export class LicenseService {
workspaceId,
quantity,
});
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
break;
default:
break;
@@ -96,7 +99,10 @@ export class LicenseService {
}: Events['workspace.subscription.canceled']) {
switch (plan) {
case SubscriptionPlan.SelfHostedTeam:
await this.models.workspaceUser.deleteNonAccepted(workspaceId);
await this.models.workspaceUser.demoteAcceptedAdmins(workspaceId);
await this.models.workspaceFeature.remove(workspaceId, 'team_plan_v1');
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
break;
default:
break;

View File

@@ -1,6 +1,6 @@
import { Injectable } from '@nestjs/common';
import { URLHelper } from '../../../base';
import { InvalidOauthResponse, URLHelper } from '../../../base';
import { OAuthProviderName } from '../config';
import type { OAuthState } from '../types';
import { OAuthAccount, OAuthProvider, Tokens } from './def';
@@ -13,11 +13,17 @@ interface AuthTokenResponse {
export interface UserInfo {
login: string;
email: string;
email: string | null;
avatar_url: string;
name: string;
}
interface UserEmailInfo {
email: string;
primary: boolean;
verified: boolean;
}
@Injectable()
export class GithubOAuthProvider extends OAuthProvider {
provider = OAuthProviderName.GitHub;
@@ -30,7 +36,7 @@ export class GithubOAuthProvider extends OAuthProvider {
return `https://github.com/login/oauth/authorize?${this.url.stringify({
client_id: this.config.clientId,
redirect_uri: this.url.link('/oauth/callback'),
scope: 'user',
scope: 'read:user user:email',
...this.config.args,
state,
})}`;
@@ -56,16 +62,36 @@ export class GithubOAuthProvider extends OAuthProvider {
async getUser(tokens: Tokens, _state: OAuthState): Promise<OAuthAccount> {
const user = await this.fetchJson<UserInfo>('https://api.github.com/user', {
method: 'GET',
headers: {
Authorization: `Bearer ${tokens.accessToken}`,
},
headers: { Authorization: `Bearer ${tokens.accessToken}` },
});
const email = user.email ?? (await this.getVerifiedEmail(tokens));
if (!email) {
throw new InvalidOauthResponse({
reason: 'GitHub account did not have a verified email address.',
});
}
return {
id: user.login,
avatarUrl: user.avatar_url,
email: user.email,
email,
name: user.name,
};
}
private async getVerifiedEmail(tokens: Tokens) {
const emails = await this.fetchJson<UserEmailInfo[]>(
'https://api.github.com/user/emails',
{
method: 'GET',
headers: { Authorization: `Bearer ${tokens.accessToken}` },
}
);
return (
emails.find(email => email.primary && email.verified)?.email ??
emails.find(email => email.verified)?.email
);
}
}

View File

@@ -1,4 +1,4 @@
import { Injectable } from '@nestjs/common';
import { Injectable, Logger } from '@nestjs/common';
import { Cron, CronExpression } from '@nestjs/schedule';
import { PrismaClient, Provider } from '@prisma/client';
@@ -18,6 +18,7 @@ declare global {
'nightly.cleanExpiredOnetimeSubscriptions': {};
'nightly.notifyAboutToExpireWorkspaceSubscriptions': {};
'nightly.reconcileRevenueCatSubscriptions': {};
'nightly.reconcileStripeSubscriptions': {};
'nightly.reconcileStripeRefunds': {};
'nightly.revenuecat.syncUser': { userId: string };
}
@@ -25,6 +26,8 @@ declare global {
@Injectable()
export class SubscriptionCronJobs {
private readonly logger = new Logger(SubscriptionCronJobs.name);
constructor(
private readonly db: PrismaClient,
private readonly event: EventBus,
@@ -61,6 +64,12 @@ export class SubscriptionCronJobs {
{ jobId: 'nightly-payment-reconcile-revenuecat-subscriptions' }
);
await this.queue.add(
'nightly.reconcileStripeSubscriptions',
{},
{ jobId: 'nightly-payment-reconcile-stripe-subscriptions' }
);
await this.queue.add(
'nightly.reconcileStripeRefunds',
{},
@@ -202,6 +211,48 @@ export class SubscriptionCronJobs {
await this.rcHandler.syncAppUser(payload.userId);
}
@OnJob('nightly.reconcileStripeSubscriptions')
async reconcileStripeSubscriptions() {
const stripe = this.stripeFactory.stripe;
const subs = await this.db.subscription.findMany({
where: {
provider: Provider.stripe,
stripeSubscriptionId: { not: null },
status: {
in: [
SubscriptionStatus.Active,
SubscriptionStatus.Trialing,
SubscriptionStatus.PastDue,
],
},
},
select: { stripeSubscriptionId: true },
});
const subscriptionIds = Array.from(
new Set(
subs
.map(sub => sub.stripeSubscriptionId)
.filter((id): id is string => !!id)
)
);
for (const subscriptionId of subscriptionIds) {
try {
const subscription = await stripe.subscriptions.retrieve(
subscriptionId,
{ expand: ['customer'] }
);
await this.subscription.saveStripeSubscription(subscription);
} catch (e) {
this.logger.error(
`Failed to reconcile stripe subscription ${subscriptionId}`,
e
);
}
}
}
@OnJob('nightly.reconcileStripeRefunds')
async reconcileStripeRefunds() {
const stripe = this.stripeFactory.stripe;

View File

@@ -1,6 +1,7 @@
import { Injectable } from '@nestjs/common';
import { EventBus, OnEvent } from '../../base';
import { WorkspacePolicyService } from '../../core/permission';
import { WorkspaceService } from '../../core/workspaces';
import { Models } from '../../models';
import { SubscriptionPlan, SubscriptionRecurring } from './types';
@@ -10,7 +11,8 @@ export class PaymentEventHandlers {
constructor(
private readonly workspace: WorkspaceService,
private readonly models: Models,
private readonly event: EventBus
private readonly event: EventBus,
private readonly policy: WorkspacePolicyService
) {}
@OnEvent('workspace.subscription.activated')
@@ -40,6 +42,7 @@ export class PaymentEventHandlers {
// we only send emails when the team workspace is activated
await this.workspace.sendTeamWorkspaceUpgradedEmail(workspaceId);
}
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
break;
}
default:
@@ -54,7 +57,10 @@ export class PaymentEventHandlers {
}: Events['workspace.subscription.canceled']) {
switch (plan) {
case SubscriptionPlan.Team:
await this.models.workspaceUser.deleteNonAccepted(workspaceId);
await this.models.workspaceUser.demoteAcceptedAdmins(workspaceId);
await this.models.workspaceFeature.remove(workspaceId, 'team_plan_v1');
await this.policy.reconcileWorkspaceQuotaState(workspaceId);
break;
default:
break;
@@ -81,6 +87,7 @@ export class PaymentEventHandlers {
recurring === 'lifetime' ? 'lifetime_pro_plan_v1' : 'pro_plan_v1',
'subscription activated'
);
await this.policy.reconcileOwnedWorkspaces(userId);
break;
default:
break;
@@ -105,6 +112,7 @@ export class PaymentEventHandlers {
'free_plan_v1',
'lifetime subscription canceled'
);
await this.policy.reconcileOwnedWorkspaces(userId);
break;
}
@@ -121,6 +129,7 @@ export class PaymentEventHandlers {
'free_plan_v1',
'subscription canceled'
);
await this.policy.reconcileOwnedWorkspaces(userId);
}
break;
}

View File

@@ -433,9 +433,7 @@ export const NbStoreNativeDBApis: NativeDBApis = {
id: string,
docId: string
): Promise<DocIndexedClock | null> {
return NbStore.getDocIndexedClock({ id, docId }).then(clock =>
clock ? { ...clock, timestamp: new Date(clock.timestamp) } : null
);
return NbStore.getDocIndexedClock({ id, docId });
},
setDocIndexedClock: function (
id: string,

View File

@@ -13,19 +13,6 @@ import type { FrameworkProvider } from '@toeverything/infra';
import { getCurrentWorkspace, isAiEnabled } from './utils';
const logger = new DebugLogger('electron-renderer:recording');
const RECORDING_PROCESS_RETRY_MS = 1000;
const NATIVE_RECORDING_MIME_TYPE = 'audio/ogg';
type ProcessingRecordingStatus = {
id: number;
status: 'processing';
appName?: string;
blockCreationStatus?: undefined;
filepath: string;
startTime: number;
};
type WorkspaceHandle = NonNullable<ReturnType<typeof getCurrentWorkspace>>;
async function readRecordingFile(filepath: string) {
if (apis?.recording?.readRecordingFile) {
@@ -58,217 +45,118 @@ async function saveRecordingBlob(blobEngine: BlobEngine, filepath: string) {
logger.debug('Saving recording', filepath);
const opusBuffer = await readRecordingFile(filepath);
const blob = new Blob([opusBuffer], {
type: NATIVE_RECORDING_MIME_TYPE,
type: 'audio/mp4',
});
const blobId = await blobEngine.set(blob);
logger.debug('Recording saved', blobId);
return { blob, blobId };
}
function shouldProcessRecording(
status: unknown
): status is ProcessingRecordingStatus {
return (
!!status &&
typeof status === 'object' &&
'status' in status &&
status.status === 'processing' &&
'filepath' in status &&
typeof status.filepath === 'string' &&
!('blockCreationStatus' in status && status.blockCreationStatus)
);
}
async function createRecordingDoc(
frameworkProvider: FrameworkProvider,
workspace: WorkspaceHandle['workspace'],
status: ProcessingRecordingStatus
) {
const docsService = workspace.scope.get(DocsService);
const aiEnabled = isAiEnabled(frameworkProvider);
const recordingFilepath = status.filepath;
const timestamp = i18nTime(status.startTime, {
absolute: {
accuracy: 'minute',
noYear: true,
},
});
await new Promise<void>((resolve, reject) => {
const docProps: DocProps = {
onStoreLoad: (doc, { noteId }) => {
void (async () => {
// it takes a while to save the blob, so we show the attachment first
const { blobId, blob } = await saveRecordingBlob(
doc.workspace.blobSync,
recordingFilepath
);
// name + timestamp(readable) + extension
const attachmentName =
(status.appName ?? 'System Audio') + ' ' + timestamp + '.opus';
const attachmentId = doc.addBlock(
'affine:attachment',
{
name: attachmentName,
type: NATIVE_RECORDING_MIME_TYPE,
size: blob.size,
sourceId: blobId,
embed: true,
},
noteId
);
const model = doc.getBlock(attachmentId)
?.model as AttachmentBlockModel;
if (!aiEnabled) {
return;
}
using currentWorkspace = getCurrentWorkspace(frameworkProvider);
if (!currentWorkspace) {
return;
}
const { workspace } = currentWorkspace;
using audioAttachment = workspace.scope
.get(AudioAttachmentService)
.get(model);
audioAttachment?.obj
.transcribe()
.then(() => {
track.doc.editor.audioBlock.transcribeRecording({
type: 'Meeting record',
method: 'success',
option: 'Auto transcribing',
});
})
.catch(err => {
logger.error('Failed to transcribe recording', err);
});
})().then(resolve, reject);
},
};
const page = docsService.createDoc({
docProps,
title:
'Recording ' + (status.appName ?? 'System Audio') + ' ' + timestamp,
primaryMode: 'page',
});
workspace.scope.get(WorkbenchService).workbench.openDoc(page.id);
});
}
export function setupRecordingEvents(frameworkProvider: FrameworkProvider) {
let pendingStatus: ProcessingRecordingStatus | null = null;
let retryTimer: ReturnType<typeof setTimeout> | null = null;
let processingStatusId: number | null = null;
const clearRetry = () => {
if (retryTimer !== null) {
clearTimeout(retryTimer);
retryTimer = null;
}
};
const clearPending = (id?: number) => {
if (id === undefined || pendingStatus?.id === id) {
pendingStatus = null;
clearRetry();
}
if (id === undefined || processingStatusId === id) {
processingStatusId = null;
}
};
const scheduleRetry = () => {
if (!pendingStatus || retryTimer !== null) {
return;
}
retryTimer = setTimeout(() => {
retryTimer = null;
void processPendingStatus().catch(console.error);
}, RECORDING_PROCESS_RETRY_MS);
};
const processPendingStatus = async () => {
const status = pendingStatus;
if (!status || processingStatusId === status.id) {
return;
}
let isActiveTab = false;
try {
isActiveTab = !!(await apis?.ui.isActiveTab());
} catch (error) {
logger.error('Failed to probe active recording tab', error);
scheduleRetry();
return;
}
if (!isActiveTab) {
scheduleRetry();
return;
}
using currentWorkspace = getCurrentWorkspace(frameworkProvider);
if (!currentWorkspace) {
// Workspace can lag behind the post-recording status update for a short
// time; keep retrying instead of permanently failing the import.
scheduleRetry();
return;
}
processingStatusId = status.id;
try {
await createRecordingDoc(
frameworkProvider,
currentWorkspace.workspace,
status
);
await apis?.recording.setRecordingBlockCreationStatus(
status.id,
'success'
);
clearPending(status.id);
} catch (error) {
logger.error('Failed to create recording block', error);
try {
await apis?.recording.setRecordingBlockCreationStatus(
status.id,
'failed',
error instanceof Error ? error.message : undefined
);
} finally {
clearPending(status.id);
}
} finally {
if (pendingStatus?.id === status.id) {
processingStatusId = null;
scheduleRetry();
}
}
};
events?.recording.onRecordingStatusChanged(status => {
if (shouldProcessRecording(status)) {
pendingStatus = status;
clearRetry();
void processPendingStatus().catch(console.error);
return;
}
(async () => {
if ((await apis?.ui.isActiveTab()) && status?.status === 'ready') {
using currentWorkspace = getCurrentWorkspace(frameworkProvider);
if (!currentWorkspace) {
// maybe the workspace is not ready yet, eg. for shared workspace view
await apis?.recording.handleBlockCreationFailed(status.id);
return;
}
const { workspace } = currentWorkspace;
const docsService = workspace.scope.get(DocsService);
const aiEnabled = isAiEnabled(frameworkProvider);
if (!status) {
clearPending();
return;
}
const timestamp = i18nTime(status.startTime, {
absolute: {
accuracy: 'minute',
noYear: true,
},
});
if (pendingStatus?.id === status.id) {
clearPending(status.id);
}
const docProps: DocProps = {
onStoreLoad: (doc, { noteId }) => {
(async () => {
if (status.filepath) {
// it takes a while to save the blob, so we show the attachment first
const { blobId, blob } = await saveRecordingBlob(
doc.workspace.blobSync,
status.filepath
);
// name + timestamp(readable) + extension
const attachmentName =
(status.appName ?? 'System Audio') +
' ' +
timestamp +
'.opus';
// add size and sourceId to the attachment later
const attachmentId = doc.addBlock(
'affine:attachment',
{
name: attachmentName,
type: 'audio/opus',
size: blob.size,
sourceId: blobId,
embed: true,
},
noteId
);
const model = doc.getBlock(attachmentId)
?.model as AttachmentBlockModel;
if (!aiEnabled) {
return;
}
using currentWorkspace = getCurrentWorkspace(frameworkProvider);
if (!currentWorkspace) {
return;
}
const { workspace } = currentWorkspace;
using audioAttachment = workspace.scope
.get(AudioAttachmentService)
.get(model);
audioAttachment?.obj
.transcribe()
.then(() => {
track.doc.editor.audioBlock.transcribeRecording({
type: 'Meeting record',
method: 'success',
option: 'Auto transcribing',
});
})
.catch(err => {
logger.error('Failed to transcribe recording', err);
});
} else {
throw new Error('No attachment model found');
}
})()
.then(async () => {
await apis?.recording.handleBlockCreationSuccess(status.id);
})
.catch(error => {
logger.error('Failed to transcribe recording', error);
return apis?.recording.handleBlockCreationFailed(
status.id,
error
);
})
.catch(error => {
console.error('unknown error', error);
});
},
};
const page = docsService.createDoc({
docProps,
title:
'Recording ' + (status.appName ?? 'System Audio') + ' ' + timestamp,
primaryMode: 'page',
});
workspace.scope.get(WorkbenchService).workbench.openDoc(page.id);
}
})().catch(console.error);
});
}

View File

@@ -1,17 +1,28 @@
import { Button } from '@affine/component';
import { useAsyncCallback } from '@affine/core/components/hooks/affine-async-hooks';
import { appIconMap } from '@affine/core/utils';
import {
createStreamEncoder,
encodeRawBufferToOpus,
type OpusStreamEncoder,
} from '@affine/core/utils/opus-encoding';
import { apis, events } from '@affine/electron-api';
import { useI18n } from '@affine/i18n';
import track from '@affine/track';
import { useEffect, useMemo, useRef, useState } from 'react';
import { useEffect, useMemo, useState } from 'react';
import * as styles from './styles.css';
type Status = {
id: number;
status: 'new' | 'recording' | 'processing' | 'ready';
blockCreationStatus?: 'success' | 'failed';
status:
| 'new'
| 'recording'
| 'paused'
| 'stopped'
| 'ready'
| 'create-block-success'
| 'create-block-failed';
appName?: string;
appGroupId?: number;
icon?: Buffer;
@@ -47,7 +58,6 @@ const appIcon = appIconMap[BUILD_CONFIG.appBuildType];
export function Recording() {
const status = useRecordingStatus();
const trackedNewRecordingIdsRef = useRef<Set<number>>(new Set());
const t = useI18n();
const textElement = useMemo(() => {
@@ -56,19 +66,14 @@ export function Recording() {
}
if (status.status === 'new') {
return t['com.affine.recording.new']();
} else if (
status.status === 'ready' &&
status.blockCreationStatus === 'success'
) {
} else if (status.status === 'create-block-success') {
return t['com.affine.recording.success.prompt']();
} else if (
status.status === 'ready' &&
status.blockCreationStatus === 'failed'
) {
} else if (status.status === 'create-block-failed') {
return t['com.affine.recording.failed.prompt']();
} else if (
status.status === 'recording' ||
status.status === 'processing'
status.status === 'ready' ||
status.status === 'stopped'
) {
if (status.appName) {
return t['com.affine.recording.recording']({
@@ -100,16 +105,106 @@ export function Recording() {
await apis?.recording?.stopRecording(status.id);
}, [status]);
useEffect(() => {
if (!status || status.status !== 'new') return;
if (trackedNewRecordingIdsRef.current.has(status.id)) return;
const handleProcessStoppedRecording = useAsyncCallback(
async (currentStreamEncoder?: OpusStreamEncoder) => {
let id: number | undefined;
try {
const result = await apis?.recording?.getCurrentRecording();
trackedNewRecordingIdsRef.current.add(status.id);
track.popup.$.recordingBar.toggleRecordingBar({
type: 'Meeting record',
appName: status.appName || 'System Audio',
if (!result) {
return;
}
id = result.id;
const { filepath, sampleRate, numberOfChannels } = result;
if (!filepath || !sampleRate || !numberOfChannels) {
return;
}
const [buffer] = await Promise.all([
currentStreamEncoder
? currentStreamEncoder.finish()
: encodeRawBufferToOpus({
filepath,
sampleRate,
numberOfChannels,
}),
new Promise<void>(resolve => {
setTimeout(() => {
resolve();
}, 500); // wait at least 500ms for better user experience
}),
]);
await apis?.recording.readyRecording(result.id, buffer);
} catch (error) {
console.error('Failed to stop recording', error);
await apis?.popup?.dismissCurrentRecording();
if (id) {
await apis?.recording.removeRecording(id);
}
}
},
[]
);
useEffect(() => {
let removed = false;
let currentStreamEncoder: OpusStreamEncoder | undefined;
apis?.recording
.getCurrentRecording()
.then(status => {
if (status) {
return handleRecordingStatusChanged(status);
}
return;
})
.catch(console.error);
const handleRecordingStatusChanged = async (status: Status) => {
if (removed) {
return;
}
if (status?.status === 'new') {
track.popup.$.recordingBar.toggleRecordingBar({
type: 'Meeting record',
appName: status.appName || 'System Audio',
});
}
if (
status?.status === 'recording' &&
status.sampleRate &&
status.numberOfChannels &&
(!currentStreamEncoder || currentStreamEncoder.id !== status.id)
) {
currentStreamEncoder?.close();
currentStreamEncoder = createStreamEncoder(status.id, {
sampleRate: status.sampleRate,
numberOfChannels: status.numberOfChannels,
});
currentStreamEncoder.poll().catch(console.error);
}
if (status?.status === 'stopped') {
handleProcessStoppedRecording(currentStreamEncoder);
currentStreamEncoder = undefined;
}
};
// allow processing stopped event in tray menu as well:
const unsubscribe = events?.recording.onRecordingStatusChanged(status => {
if (status) {
handleRecordingStatusChanged(status).catch(console.error);
}
});
}, [status]);
return () => {
removed = true;
unsubscribe?.();
currentStreamEncoder?.close();
};
}, [handleProcessStoppedRecording]);
const handleStartRecording = useAsyncCallback(async () => {
if (!status) {
@@ -154,10 +249,7 @@ export function Recording() {
{t['com.affine.recording.stop']()}
</Button>
);
} else if (
status.status === 'processing' ||
(status.status === 'ready' && !status.blockCreationStatus)
) {
} else if (status.status === 'stopped' || status.status === 'ready') {
return (
<Button
variant="error"
@@ -166,19 +258,13 @@ export function Recording() {
disabled
/>
);
} else if (
status.status === 'ready' &&
status.blockCreationStatus === 'success'
) {
} else if (status.status === 'create-block-success') {
return (
<Button variant="primary" onClick={handleDismiss}>
{t['com.affine.recording.success.button']()}
</Button>
);
} else if (
status.status === 'ready' &&
status.blockCreationStatus === 'failed'
) {
} else if (status.status === 'create-block-failed') {
return (
<>
<Button variant="plain" onClick={handleDismiss}>

View File

@@ -5,7 +5,6 @@ import { parseUniversalId } from '@affine/nbstore';
import fs from 'fs-extra';
import { nanoid } from 'nanoid';
import { isPathInsideBase } from '../../shared/utils';
import { logger } from '../logger';
import { mainRPC } from '../main-rpc';
import { getDocStoragePool } from '../nbstore';
@@ -39,6 +38,31 @@ export interface SelectDBFileLocationResult {
canceled?: boolean;
}
// provide a backdoor to set dialog path for testing in playwright
export interface FakeDialogResult {
canceled?: boolean;
filePath?: string;
filePaths?: string[];
}
// result will be used in the next call to showOpenDialog
// if it is being read once, it will be reset to undefined
let fakeDialogResult: FakeDialogResult | undefined = undefined;
function getFakedResult() {
const result = fakeDialogResult;
fakeDialogResult = undefined;
return result;
}
export function setFakeDialogResult(result: FakeDialogResult | undefined) {
fakeDialogResult = result;
// for convenience, we will fill filePaths with filePath if it is not set
if (result?.filePaths === undefined && result?.filePath !== undefined) {
result.filePaths = [result.filePath];
}
}
const extension = 'affine';
function getDefaultDBFileName(name: string, id: string) {
@@ -63,33 +87,12 @@ async function isSameFilePath(sourcePath: string, targetPath: string) {
return true;
}
const [resolvedSourcePath, resolvedTargetPath] = await Promise.all([
const [sourceRealPath, targetRealPath] = await Promise.all([
resolveExistingPath(sourcePath),
resolveExistingPath(targetPath),
]);
return !!resolvedSourcePath && resolvedSourcePath === resolvedTargetPath;
}
async function normalizeImportDBPath(selectedPath: string) {
if (!(await fs.pathExists(selectedPath))) {
return null;
}
const [normalizedPath, workspacesBasePath] = await Promise.all([
resolveExistingPath(selectedPath),
resolveExistingPath(await getWorkspacesBasePath()),
]);
const resolvedSelectedPath = normalizedPath ?? resolve(selectedPath);
const resolvedWorkspacesBasePath =
workspacesBasePath ?? resolve(await getWorkspacesBasePath());
if (isPathInsideBase(resolvedWorkspacesBasePath, resolvedSelectedPath)) {
logger.warn('loadDBFile: db file in app data dir');
return null;
}
return resolvedSelectedPath;
return !!sourceRealPath && sourceRealPath === targetRealPath;
}
/**
@@ -110,26 +113,29 @@ export async function saveDBFileAs(
await pool.connect(universalId, dbPath);
await pool.checkpoint(universalId); // make sure all changes (WAL) are written to db
const fakedResult = getFakedResult();
if (!dbPath) {
return {
error: 'DB_FILE_PATH_INVALID',
};
}
const ret = await mainRPC.showSaveDialog({
properties: ['showOverwriteConfirmation'],
title: 'Save Workspace',
showsTagField: false,
buttonLabel: 'Save',
filters: [
{
extensions: [extension],
name: '',
},
],
defaultPath: getDefaultDBFileName(name, id),
message: 'Save Workspace as a SQLite Database file',
});
const ret =
fakedResult ??
(await mainRPC.showSaveDialog({
properties: ['showOverwriteConfirmation'],
title: 'Save Workspace',
showsTagField: false,
buttonLabel: 'Save',
filters: [
{
extensions: [extension],
name: '',
},
],
defaultPath: getDefaultDBFileName(name, id),
message: 'Save Workspace as a SQLite Database file',
}));
const filePath = ret.filePath;
if (ret.canceled || !filePath) {
@@ -154,9 +160,11 @@ export async function saveDBFileAs(
}
}
logger.log('saved', filePath);
mainRPC.showItemInFolder(filePath).catch(err => {
console.error(err);
});
if (!fakedResult) {
mainRPC.showItemInFolder(filePath).catch(err => {
console.error(err);
});
}
return { filePath };
} catch (err) {
logger.error('saveDBFileAs', err);
@@ -168,13 +176,15 @@ export async function saveDBFileAs(
export async function selectDBFileLocation(): Promise<SelectDBFileLocationResult> {
try {
const ret = await mainRPC.showOpenDialog({
properties: ['openDirectory'],
title: 'Set Workspace Storage Location',
buttonLabel: 'Select',
defaultPath: await mainRPC.getPath('documents'),
message: "Select a location to store the workspace's database file",
});
const ret =
getFakedResult() ??
(await mainRPC.showOpenDialog({
properties: ['openDirectory'],
title: 'Set Workspace Storage Location',
buttonLabel: 'Select',
defaultPath: await mainRPC.getPath('documents'),
message: "Select a location to store the workspace's database file",
}));
const dir = ret.filePaths?.[0];
if (ret.canceled || !dir) {
return {
@@ -204,29 +214,39 @@ export async function selectDBFileLocation(): Promise<SelectDBFileLocationResult
* update the local workspace id list and then connect to it.
*
*/
export async function loadDBFile(): Promise<LoadDBFileResult> {
export async function loadDBFile(
dbFilePath?: string
): Promise<LoadDBFileResult> {
try {
const ret = await mainRPC.showOpenDialog({
properties: ['openFile'],
title: 'Load Workspace',
buttonLabel: 'Load',
filters: [
{
name: 'SQLite Database',
// do we want to support other file format?
extensions: ['db', 'affine'],
},
],
message: 'Load Workspace from a AFFiNE file',
});
const selectedPath = ret.filePaths?.[0];
if (ret.canceled || !selectedPath) {
const provided =
getFakedResult() ??
(dbFilePath
? { filePath: dbFilePath, filePaths: [dbFilePath], canceled: false }
: undefined);
const ret =
provided ??
(await mainRPC.showOpenDialog({
properties: ['openFile'],
title: 'Load Workspace',
buttonLabel: 'Load',
filters: [
{
name: 'SQLite Database',
// do we want to support other file format?
extensions: ['db', 'affine'],
},
],
message: 'Load Workspace from a AFFiNE file',
}));
const originalPath = ret.filePaths?.[0];
if (ret.canceled || !originalPath) {
logger.info('loadDBFile canceled');
return { canceled: true };
}
const originalPath = await normalizeImportDBPath(selectedPath);
if (!originalPath) {
// the imported file should not be in app data dir
if (originalPath.startsWith(await getWorkspacesBasePath())) {
logger.warn('loadDBFile: db file in app data dir');
return { error: 'DB_FILE_PATH_INVALID' };
}
@@ -279,26 +299,22 @@ async function cpV1DBFile(
}
const connection = new SqliteConnection(originalPath);
try {
if (!(await connection.validateImportSchema())) {
return { error: 'DB_FILE_INVALID' };
}
const internalFilePath = await getWorkspaceDBPath('workspace', workspaceId);
await fs.ensureDir(parse(internalFilePath).dir);
await connection.vacuumInto(internalFilePath);
logger.info(`loadDBFile, vacuum: ${originalPath} -> ${internalFilePath}`);
await storeWorkspaceMeta(workspaceId, {
id: workspaceId,
mainDBPath: internalFilePath,
});
return {
workspaceId,
};
} finally {
await connection.close();
if (!(await connection.validateImportSchema())) {
return { error: 'DB_FILE_INVALID' };
}
const internalFilePath = await getWorkspaceDBPath('workspace', workspaceId);
await fs.ensureDir(parse(internalFilePath).dir);
await connection.vacuumInto(internalFilePath);
logger.info(`loadDBFile, vacuum: ${originalPath} -> ${internalFilePath}`);
await storeWorkspaceMeta(workspaceId, {
id: workspaceId,
mainDBPath: internalFilePath,
});
return {
workspaceId,
};
}

View File

@@ -1,8 +1,13 @@
import { loadDBFile, saveDBFileAs, selectDBFileLocation } from './dialog';
import {
loadDBFile,
saveDBFileAs,
selectDBFileLocation,
setFakeDialogResult,
} from './dialog';
export const dialogHandlers = {
loadDBFile: async () => {
return loadDBFile();
loadDBFile: async (dbFilePath?: string) => {
return loadDBFile(dbFilePath);
},
saveDBFileAs: async (universalId: string, name: string) => {
return saveDBFileAs(universalId, name);
@@ -10,4 +15,9 @@ export const dialogHandlers = {
selectDBFileLocation: async () => {
return selectDBFileLocation();
},
setFakeDialogResult: async (
result: Parameters<typeof setFakeDialogResult>[0]
) => {
return setFakeDialogResult(result);
},
};

View File

@@ -1,18 +1,13 @@
import path from 'node:path';
import { DocStorage, ValidationResult } from '@affine/native';
import { DocStorage } from '@affine/native';
import {
parseUniversalId,
universalId as generateUniversalId,
} from '@affine/nbstore';
import fs from 'fs-extra';
import { nanoid } from 'nanoid';
import { applyUpdate, Doc as YDoc } from 'yjs';
import {
normalizeWorkspaceIdForPath,
resolveExistingPathInBase,
} from '../../shared/utils';
import { logger } from '../logger';
import { getDocStoragePool } from '../nbstore';
import { ensureSQLiteDisconnected } from '../nbstore/v1/ensure-db';
@@ -23,7 +18,6 @@ import {
getSpaceBasePath,
getSpaceDBPath,
getWorkspaceBasePathV1,
getWorkspaceDBPath,
getWorkspaceMeta,
} from './meta';
@@ -64,7 +58,7 @@ export async function trashWorkspace(universalId: string) {
const dbPath = await getSpaceDBPath(peer, type, id);
const basePath = await getDeletedWorkspacesBasePath();
const movedPath = path.join(basePath, normalizeWorkspaceIdForPath(id));
const movedPath = path.join(basePath, `${id}`);
try {
const storage = new DocStorage(dbPath);
if (await storage.validate()) {
@@ -264,88 +258,12 @@ export async function getDeletedWorkspaces() {
};
}
async function importLegacyWorkspaceDb(
originalPath: string,
workspaceId: string
) {
const { SqliteConnection } = await import('@affine/native');
const validationResult = await SqliteConnection.validate(originalPath);
if (validationResult !== ValidationResult.Valid) {
return {};
}
const connection = new SqliteConnection(originalPath);
if (!(await connection.validateImportSchema())) {
return {};
}
const internalFilePath = await getWorkspaceDBPath('workspace', workspaceId);
await fs.ensureDir(path.parse(internalFilePath).dir);
await connection.vacuumInto(internalFilePath);
logger.info(
`recoverBackupWorkspace, vacuum: ${originalPath} -> ${internalFilePath}`
);
await storeWorkspaceMeta(workspaceId, {
id: workspaceId,
mainDBPath: internalFilePath,
});
return {
workspaceId,
};
}
async function importWorkspaceDb(originalPath: string) {
const workspaceId = nanoid(10);
let storage = new DocStorage(originalPath);
if (!(await storage.validate())) {
return await importLegacyWorkspaceDb(originalPath, workspaceId);
}
if (!(await storage.validateImportSchema())) {
return {};
}
const internalFilePath = await getSpaceDBPath(
'local',
'workspace',
workspaceId
);
await fs.ensureDir(path.parse(internalFilePath).dir);
await storage.vacuumInto(internalFilePath);
logger.info(
`recoverBackupWorkspace, vacuum: ${originalPath} -> ${internalFilePath}`
);
storage = new DocStorage(internalFilePath);
await storage.setSpaceId(workspaceId);
return {
workspaceId,
};
}
export async function deleteBackupWorkspace(id: string) {
const basePath = await getDeletedWorkspacesBasePath();
const workspacePath = path.join(basePath, normalizeWorkspaceIdForPath(id));
const workspacePath = path.join(basePath, id);
await fs.rmdir(workspacePath, { recursive: true });
logger.info(
'deleteBackupWorkspace',
`Deleted backup workspace: ${workspacePath}`
);
}
export async function recoverBackupWorkspace(id: string) {
const basePath = await getDeletedWorkspacesBasePath();
const workspacePath = path.join(basePath, normalizeWorkspaceIdForPath(id));
const dbPath = await resolveExistingPathInBase(
basePath,
path.join(workspacePath, 'storage.db'),
{ label: 'backup workspace filepath' }
);
return await importWorkspaceDb(dbPath);
}

View File

@@ -4,7 +4,6 @@ import {
deleteWorkspace,
getDeletedWorkspaces,
listLocalWorkspaceIds,
recoverBackupWorkspace,
trashWorkspace,
} from './handlers';
@@ -20,6 +19,5 @@ export const workspaceHandlers = {
return getDeletedWorkspaces();
},
deleteBackupWorkspace: async (id: string) => deleteBackupWorkspace(id),
recoverBackupWorkspace: async (id: string) => recoverBackupWorkspace(id),
listLocalWorkspaceIds: async () => listLocalWorkspaceIds(),
};

View File

@@ -2,7 +2,7 @@ import path from 'node:path';
import { type SpaceType } from '@affine/nbstore';
import { normalizeWorkspaceIdForPath } from '../../shared/utils';
import { isWindows } from '../../shared/utils';
import { mainRPC } from '../main-rpc';
import type { WorkspaceMeta } from '../type';
@@ -24,11 +24,10 @@ export async function getWorkspaceBasePathV1(
spaceType: SpaceType,
workspaceId: string
) {
const safeWorkspaceId = normalizeWorkspaceIdForPath(workspaceId);
return path.join(
await getAppDataPath(),
spaceType === 'userspace' ? 'userspaces' : 'workspaces',
safeWorkspaceId
isWindows() ? workspaceId.replace(':', '_') : workspaceId
);
}
@@ -53,11 +52,10 @@ export async function getSpaceDBPath(
spaceType: SpaceType,
id: string
) {
const safeId = normalizeWorkspaceIdForPath(id);
return path.join(
await getSpaceBasePath(spaceType),
escapeFilename(peer),
safeId,
id,
'storage.db'
);
}

View File

@@ -5,46 +5,24 @@ import { app, net, protocol, session } from 'electron';
import cookieParser from 'set-cookie-parser';
import { anotherHost, mainHost } from '../shared/internal-origin';
import {
isPathInsideBase,
isWindows,
resolveExistingPathInBase,
resolvePathInBase,
resourcesPath,
} from '../shared/utils';
import { isWindows, resourcesPath } from '../shared/utils';
import { buildType, isDev } from './config';
import { logger } from './logger';
const webStaticDir = join(resourcesPath, 'web-static');
const devServerBase = process.env.DEV_SERVER_URL;
const localWhiteListDirs = [
path.resolve(app.getPath('sessionData')),
path.resolve(app.getPath('temp')),
path.resolve(app.getPath('sessionData')).toLowerCase(),
path.resolve(app.getPath('temp')).toLowerCase(),
];
function isPathInWhiteList(filepath: string) {
const lowerFilePath = filepath.toLowerCase();
return localWhiteListDirs.some(whitelistDir =>
isPathInsideBase(whitelistDir, filepath, {
caseInsensitive: isWindows(),
})
lowerFilePath.startsWith(whitelistDir)
);
}
async function resolveWhitelistedLocalPath(filepath: string) {
for (const whitelistDir of localWhiteListDirs) {
try {
return await resolveExistingPathInBase(whitelistDir, filepath, {
caseInsensitive: isWindows(),
label: 'filepath',
});
} catch {
continue;
}
}
throw new Error('Invalid filepath');
}
const apiBaseByBuildType: Record<typeof buildType, string> = {
stable: 'https://app.affine.pro',
beta: 'https://insider.affine.pro',
@@ -116,14 +94,15 @@ async function handleFileRequest(request: Request) {
// for relative path, load the file in resources
if (!isAbsolutePath) {
if (urlObject.pathname.split('/').at(-1)?.includes('.')) {
const decodedPath = decodeURIComponent(urlObject.pathname).replace(
/^\/+/,
''
);
filepath = resolvePathInBase(webStaticDir, decodedPath, {
caseInsensitive: isWindows(),
label: 'filepath',
});
// Sanitize pathname to prevent path traversal attacks
const decodedPath = decodeURIComponent(urlObject.pathname);
const normalizedPath = join(webStaticDir, decodedPath).normalize();
if (!normalizedPath.startsWith(webStaticDir)) {
// Attempted path traversal - reject by using empty path
filepath = join(webStaticDir, '');
} else {
filepath = normalizedPath;
}
} else {
// else, fallback to load the index.html instead
filepath = join(webStaticDir, 'index.html');
@@ -134,10 +113,10 @@ async function handleFileRequest(request: Request) {
if (isWindows()) {
filepath = path.resolve(filepath.replace(/^\//, ''));
}
// security check if the filepath is within app.getPath('sessionData')
if (urlObject.host !== 'local-file' || !isPathInWhiteList(filepath)) {
throw new Error('Invalid filepath');
}
filepath = await resolveWhitelistedLocalPath(filepath);
}
return net.fetch(pathToFileURL(filepath).toString(), clonedRequest);
}

View File

@@ -1,10 +1,11 @@
/* oxlint-disable no-var-requires */
import { execSync } from 'node:child_process';
import { createHash } from 'node:crypto';
import fsp from 'node:fs/promises';
import path from 'node:path';
// Should not load @affine/native for unsupported platforms
import type * as NativeModuleType from '@affine/native';
import type { ShareableContent as ShareableContentType } from '@affine/native';
import { app, systemPreferences } from 'electron';
import fs from 'fs-extra';
import { debounce } from 'lodash-es';
@@ -19,12 +20,7 @@ import {
} from 'rxjs';
import { filter, map, shareReplay } from 'rxjs/operators';
import {
isMacOS,
isWindows,
resolveExistingPathInBase,
shallowEqual,
} from '../../shared/utils';
import { isMacOS, isWindows, shallowEqual } from '../../shared/utils';
import { beforeAppQuit } from '../cleanup';
import { logger } from '../logger';
import {
@@ -36,7 +32,12 @@ import { getMainWindow } from '../windows-manager';
import { popupManager } from '../windows-manager/popup';
import { isAppNameAllowed } from './allow-list';
import { recordingStateMachine } from './state-machine';
import type { AppGroupInfo, RecordingStatus, TappableAppInfo } from './types';
import type {
AppGroupInfo,
Recording,
RecordingStatus,
TappableAppInfo,
} from './types';
export const MeetingsSettingsState = {
$: globalStateStorage.watch<MeetingSettingsSchema>(MeetingSettingsKey).pipe(
@@ -55,12 +56,7 @@ export const MeetingsSettingsState = {
},
};
type Subscriber = {
unsubscribe: () => void;
};
const subscribers: Subscriber[] = [];
let appStateSubscribers: Subscriber[] = [];
// recordings are saved in the app data directory
// may need a way to clean up old recordings
@@ -69,29 +65,10 @@ export const SAVED_RECORDINGS_DIR = path.join(
'recordings'
);
type NativeModule = typeof NativeModuleType;
type ShareableContentType = InstanceType<NativeModule['ShareableContent']>;
type ShareableContentStatic = NativeModule['ShareableContent'];
let shareableContent: ShareableContentType | null = null;
function getNativeModule(): NativeModule {
return require('@affine/native') as NativeModule;
}
function cleanup() {
const nativeId = recordingStateMachine.status?.nativeId;
if (nativeId) cleanupAbandonedNativeRecording(nativeId);
recordingStatus$.next(null);
shareableContent = null;
appStateSubscribers.forEach(subscriber => {
try {
subscriber.unsubscribe();
} catch {
// ignore unsubscribe error
}
});
appStateSubscribers = [];
subscribers.forEach(subscriber => {
try {
subscriber.unsubscribe();
@@ -99,9 +76,6 @@ function cleanup() {
// ignore unsubscribe error
}
});
subscribers.length = 0;
applications$.next([]);
appGroups$.next([]);
}
beforeAppQuit(() => {
@@ -113,21 +87,18 @@ export const appGroups$ = new BehaviorSubject<AppGroupInfo[]>([]);
export const updateApplicationsPing$ = new Subject<number>();
// There should be only one active recording at a time; state is managed by the state machine
export const recordingStatus$ = recordingStateMachine.status$;
// recording id -> recording
// recordings will be saved in memory before consumed and created as an audio block to user's doc
const recordings = new Map<number, Recording>();
function isRecordingSettled(
status: RecordingStatus | null | undefined
): status is RecordingStatus & {
status: 'ready';
blockCreationStatus: 'success' | 'failed';
} {
return status?.status === 'ready' && status.blockCreationStatus !== undefined;
}
// there should be only one active recording at a time
// We'll now use recordingStateMachine.status$ instead of our own BehaviorSubject
export const recordingStatus$ = recordingStateMachine.status$;
function createAppGroup(processGroupId: number): AppGroupInfo | undefined {
// MUST require dynamically to avoid loading @affine/native for unsupported platforms
const SC: ShareableContentStatic = getNativeModule().ShareableContent;
const SC: typeof ShareableContentType =
require('@affine/native').ShareableContent;
const groupProcess = SC?.applicationWithProcessId(processGroupId);
if (!groupProcess) {
return;
@@ -203,13 +174,9 @@ function setupNewRunningAppGroup() {
});
const debounceStartRecording = debounce((appGroup: AppGroupInfo) => {
const currentGroup = appGroups$.value.find(
group => group.processGroupId === appGroup.processGroupId
);
if (currentGroup?.isRunning) {
startRecording(currentGroup).catch(err => {
logger.error('failed to start recording', err);
});
// check if the app is running again
if (appGroup.isRunning) {
startRecording(appGroup);
}
}, 1000);
@@ -233,7 +200,8 @@ function setupNewRunningAppGroup() {
if (
!recordingStatus ||
recordingStatus.status === 'new' ||
isRecordingSettled(recordingStatus)
recordingStatus.status === 'create-block-success' ||
recordingStatus.status === 'create-block-failed'
) {
if (MeetingsSettingsState.value.recordingMode === 'prompt') {
newRecording(currentGroup);
@@ -258,7 +226,7 @@ function setupNewRunningAppGroup() {
removeRecording(recordingStatus.id);
}
// if the watched app stops while we are recording it,
// if the recording is stopped and we are recording it,
// we should stop the recording
if (
recordingStatus?.status === 'recording' &&
@@ -274,28 +242,100 @@ function setupNewRunningAppGroup() {
);
}
function getSanitizedAppId(bundleIdentifier?: string) {
if (!bundleIdentifier) {
return 'unknown';
}
return isWindows()
? createHash('sha256')
.update(bundleIdentifier)
.digest('hex')
.substring(0, 8)
: bundleIdentifier;
}
export function createRecording(status: RecordingStatus) {
let recording = recordings.get(status.id);
if (recording) {
return recording;
}
const appId = getSanitizedAppId(status.appGroup?.bundleIdentifier);
const bufferedFilePath = path.join(
SAVED_RECORDINGS_DIR,
`${appId}-${status.id}-${status.startTime}.raw`
);
fs.ensureDirSync(SAVED_RECORDINGS_DIR);
const file = fs.createWriteStream(bufferedFilePath);
function tapAudioSamples(err: Error | null, samples: Float32Array) {
const recordingStatus = recordingStatus$.getValue();
if (
!recordingStatus ||
recordingStatus.id !== status.id ||
recordingStatus.status === 'paused'
) {
return;
}
if (err) {
logger.error('failed to get audio samples', err);
} else {
// Writing raw Float32Array samples directly to file
// For stereo audio, samples are interleaved [L,R,L,R,...]
file.write(Buffer.from(samples.buffer));
}
}
// MUST require dynamically to avoid loading @affine/native for unsupported platforms
const SC: typeof ShareableContentType =
require('@affine/native').ShareableContent;
const stream = status.app
? SC.tapAudio(status.app.processId, tapAudioSamples)
: SC.tapGlobalAudio(null, tapAudioSamples);
recording = {
id: status.id,
startTime: status.startTime,
app: status.app,
appGroup: status.appGroup,
file,
session: stream,
};
recordings.set(status.id, recording);
return recording;
}
export async function getRecording(id: number) {
const recording = recordingStateMachine.status;
if (!recording || recording.id !== id) {
const recording = recordings.get(id);
if (!recording) {
logger.error(`Recording ${id} not found`);
return;
}
const rawFilePath = String(recording.file.path);
return {
id,
appGroup: recording.appGroup,
app: recording.app,
startTime: recording.startTime,
filepath: recording.filepath,
sampleRate: recording.sampleRate,
numberOfChannels: recording.numberOfChannels,
filepath: rawFilePath,
sampleRate: recording.session.sampleRate,
numberOfChannels: recording.session.channels,
};
}
// recording popup status
// new: waiting for user confirmation
// recording: native recording is ongoing
// processing: native stop or renderer import/transcription is ongoing
// ready + blockCreationStatus: post-processing finished
// new: recording is started, popup is shown
// recording: recording is started, popup is shown
// stopped: recording is stopped, popup showing processing status
// create-block-success: recording is ready, show "open app" button
// create-block-failed: recording is failed, show "failed to save" button
// null: hide popup
function setupRecordingListeners() {
subscribers.push(
@@ -310,21 +350,36 @@ function setupRecordingListeners() {
});
}
if (isRecordingSettled(status)) {
if (status?.status === 'recording') {
let recording = recordings.get(status.id);
// create a recording if not exists
if (!recording) {
recording = createRecording(status);
}
} else if (status?.status === 'stopped') {
const recording = recordings.get(status.id);
if (recording) {
recording.session.stop();
}
} else if (
status?.status === 'create-block-success' ||
status?.status === 'create-block-failed'
) {
// show the popup for 10s
setTimeout(
() => {
const currentStatus = recordingStatus$.value;
// check again if current status is still ready
if (
isRecordingSettled(currentStatus) &&
currentStatus.id === status.id
(recordingStatus$.value?.status === 'create-block-success' ||
recordingStatus$.value?.status === 'create-block-failed') &&
recordingStatus$.value.id === status.id
) {
popup.hide().catch(err => {
logger.error('failed to hide recording popup', err);
});
}
},
status.blockCreationStatus === 'failed' ? 30_000 : 10_000
status?.status === 'create-block-failed' ? 30_000 : 10_000
);
} else if (!status) {
// status is removed, we should hide the popup
@@ -345,7 +400,9 @@ function getAllApps(): TappableAppInfo[] {
}
// MUST require dynamically to avoid loading @affine/native for unsupported platforms
const { ShareableContent } = getNativeModule();
const { ShareableContent } = require('@affine/native') as {
ShareableContent: typeof ShareableContentType;
};
const apps = ShareableContent.applications().map(app => {
try {
@@ -376,8 +433,12 @@ function getAllApps(): TappableAppInfo[] {
return filteredApps;
}
type Subscriber = {
unsubscribe: () => void;
};
function setupMediaListeners() {
const ShareableContent = getNativeModule().ShareableContent;
const ShareableContent = require('@affine/native').ShareableContent;
applications$.next(getAllApps());
subscribers.push(
interval(3000).subscribe(() => {
@@ -393,6 +454,8 @@ function setupMediaListeners() {
})
);
let appStateSubscribers: Subscriber[] = [];
subscribers.push(
applications$.subscribe(apps => {
appStateSubscribers.forEach(subscriber => {
@@ -421,6 +484,15 @@ function setupMediaListeners() {
});
appStateSubscribers = _appStateSubscribers;
return () => {
_appStateSubscribers.forEach(subscriber => {
try {
subscriber.unsubscribe();
} catch {
// ignore unsubscribe error
}
});
};
})
);
}
@@ -430,7 +502,7 @@ function askForScreenRecordingPermission() {
return false;
}
try {
const ShareableContent = getNativeModule().ShareableContent;
const ShareableContent = require('@affine/native').ShareableContent;
// this will trigger the permission prompt
new ShareableContent();
return true;
@@ -447,7 +519,7 @@ export function setupRecordingFeature() {
}
try {
const ShareableContent = getNativeModule().ShareableContent;
const ShareableContent = require('@affine/native').ShareableContent;
if (!shareableContent) {
shareableContent = new ShareableContent();
setupMediaListeners();
@@ -465,6 +537,7 @@ export function setupRecordingFeature() {
}
export function disableRecordingFeature() {
recordingStatus$.next(null);
cleanup();
}
@@ -485,175 +558,222 @@ export function newRecording(
});
}
export async function startRecording(
export function startRecording(
appGroup?: AppGroupInfo | number
): Promise<RecordingStatus | null> {
const previousState = recordingStateMachine.status;
const state = recordingStateMachine.dispatch({
type: 'START_RECORDING',
appGroup: normalizeAppGroupInfo(appGroup),
});
): RecordingStatus | null {
const state = recordingStateMachine.dispatch(
{
type: 'START_RECORDING',
appGroup: normalizeAppGroupInfo(appGroup),
},
false
);
if (!state || state.status !== 'recording' || state === previousState) {
return state;
if (state?.status === 'recording') {
createRecording(state);
}
let nativeId: string | undefined;
recordingStateMachine.status$.next(state);
try {
fs.ensureDirSync(SAVED_RECORDINGS_DIR);
return state;
}
const meta = getNativeModule().startRecording({
appProcessId: state.app?.processId,
outputDir: SAVED_RECORDINGS_DIR,
format: 'opus',
id: String(state.id),
});
nativeId = meta.id;
export function pauseRecording(id: number) {
return recordingStateMachine.dispatch({ type: 'PAUSE_RECORDING', id });
}
const filepath = await assertRecordingFilepath(meta.filepath);
const nextState = recordingStateMachine.dispatch({
type: 'ATTACH_NATIVE_RECORDING',
id: state.id,
nativeId: meta.id,
startTime: meta.startedAt ?? state.startTime,
filepath,
sampleRate: meta.sampleRate,
numberOfChannels: meta.channels,
});
if (!nextState || nextState.nativeId !== meta.id) {
throw new Error('Failed to attach native recording metadata');
}
return nextState;
} catch (error) {
if (nativeId) {
cleanupAbandonedNativeRecording(nativeId);
}
logger.error('failed to start recording', error);
return setRecordingBlockCreationStatus(
state.id,
'failed',
error instanceof Error ? error.message : undefined
);
}
export function resumeRecording(id: number) {
return recordingStateMachine.dispatch({ type: 'RESUME_RECORDING', id });
}
export async function stopRecording(id: number) {
const recording = recordingStateMachine.status;
if (!recording || recording.id !== id) {
const recording = recordings.get(id);
if (!recording) {
logger.error(`stopRecording: Recording ${id} not found`);
return;
}
if (!recording.nativeId) {
logger.error(`stopRecording: Recording ${id} missing native id`);
if (!recording.file.path) {
logger.error(`Recording ${id} has no file path`);
return;
}
const processingState = recordingStateMachine.dispatch({
type: 'STOP_RECORDING',
id,
});
if (
!processingState ||
processingState.id !== id ||
processingState.status !== 'processing'
) {
return serializeRecordingStatus(processingState ?? recording);
const { file, session: stream } = recording;
// First stop the audio stream to prevent more data coming in
try {
stream.stop();
} catch (err) {
logger.error('Failed to stop audio stream', err);
}
// End the file with a timeout
file.end();
try {
const artifact = getNativeModule().stopRecording(recording.nativeId);
const filepath = await assertRecordingFilepath(artifact.filepath);
const readyStatus = recordingStateMachine.dispatch({
type: 'ATTACH_RECORDING_ARTIFACT',
await Promise.race([
new Promise<void>((resolve, reject) => {
file.on('finish', () => {
// check if the file is empty
const stats = fs.statSync(file.path);
if (stats.size === 0) {
reject(new Error('Recording is empty'));
return;
}
resolve();
});
file.on('error', err => {
reject(err);
});
}),
new Promise<never>((_, reject) =>
setTimeout(() => reject(new Error('File writing timeout')), 10000)
),
]);
const recordingStatus = recordingStateMachine.dispatch({
type: 'STOP_RECORDING',
id,
filepath,
sampleRate: artifact.sampleRate,
numberOfChannels: artifact.channels,
});
if (!readyStatus) {
logger.error('No recording status to save');
return;
}
getMainWindow()
.then(mainWindow => {
if (mainWindow) {
mainWindow.show();
}
})
.catch(err => {
logger.error('failed to bring up the window', err);
});
return serializeRecordingStatus(readyStatus);
} catch (error: unknown) {
logger.error('Failed to stop recording', error);
const recordingStatus = await setRecordingBlockCreationStatus(
id,
'failed',
error instanceof Error ? error.message : undefined
);
if (!recordingStatus) {
logger.error('No recording status to stop');
return;
}
return serializeRecordingStatus(recordingStatus);
} catch (error: unknown) {
logger.error('Failed to stop recording', error);
const recordingStatus = recordingStateMachine.dispatch({
type: 'CREATE_BLOCK_FAILED',
id,
error: error instanceof Error ? error : undefined,
});
if (!recordingStatus) {
logger.error('No recording status to stop');
return;
}
return serializeRecordingStatus(recordingStatus);
} finally {
// Clean up the file stream if it's still open
if (!file.closed) {
file.destroy();
}
}
}
async function assertRecordingFilepath(filepath: string) {
return await resolveExistingPathInBase(SAVED_RECORDINGS_DIR, filepath, {
caseInsensitive: isWindows(),
label: 'recording filepath',
});
export async function getRawAudioBuffers(
id: number,
cursor?: number
): Promise<{
buffer: Buffer;
nextCursor: number;
}> {
const recording = recordings.get(id);
if (!recording) {
throw new Error(`getRawAudioBuffers: Recording ${id} not found`);
}
const start = cursor ?? 0;
const file = await fsp.open(recording.file.path, 'r');
const stats = await file.stat();
const buffer = Buffer.alloc(stats.size - start);
const result = await file.read(buffer, 0, buffer.length, start);
await file.close();
return {
buffer,
nextCursor: start + result.bytesRead,
};
}
function assertRecordingFilepath(filepath: string) {
const normalizedPath = path.normalize(filepath);
const normalizedBase = path.normalize(SAVED_RECORDINGS_DIR + path.sep);
if (!normalizedPath.toLowerCase().startsWith(normalizedBase.toLowerCase())) {
throw new Error('Invalid recording filepath');
}
return normalizedPath;
}
export async function readRecordingFile(filepath: string) {
const normalizedPath = await assertRecordingFilepath(filepath);
const normalizedPath = assertRecordingFilepath(filepath);
return fsp.readFile(normalizedPath);
}
function cleanupAbandonedNativeRecording(nativeId: string) {
try {
const artifact = getNativeModule().stopRecording(nativeId);
void assertRecordingFilepath(artifact.filepath)
.then(filepath => {
fs.removeSync(filepath);
})
.catch(error => {
logger.error('failed to validate abandoned recording filepath', error);
});
} catch (error) {
logger.error('failed to cleanup abandoned native recording', error);
export async function readyRecording(id: number, buffer: Buffer) {
logger.info('readyRecording', id);
const recordingStatus = recordingStatus$.value;
const recording = recordings.get(id);
if (!recordingStatus || recordingStatus.id !== id || !recording) {
logger.error(`readyRecording: Recording ${id} not found`);
return;
}
const rawFilePath = String(recording.file.path);
const filepath = rawFilePath.replace('.raw', '.opus');
if (!filepath) {
logger.error(`readyRecording: Recording ${id} has no filepath`);
return;
}
await fs.writeFile(filepath, buffer);
// can safely remove the raw file now
logger.info('remove raw file', rawFilePath);
if (rawFilePath) {
try {
await fs.unlink(rawFilePath);
} catch (err) {
logger.error('failed to remove raw file', err);
}
}
// Update the status through the state machine
recordingStateMachine.dispatch({
type: 'SAVE_RECORDING',
id,
filepath,
});
// bring up the window
getMainWindow()
.then(mainWindow => {
if (mainWindow) {
mainWindow.show();
}
})
.catch(err => {
logger.error('failed to bring up the window', err);
});
}
export async function setRecordingBlockCreationStatus(
id: number,
status: 'success' | 'failed',
errorMessage?: string
) {
return recordingStateMachine.dispatch({
type: 'SET_BLOCK_CREATION_STATUS',
export async function handleBlockCreationSuccess(id: number) {
recordingStateMachine.dispatch({
type: 'CREATE_BLOCK_SUCCESS',
id,
status,
errorMessage,
});
}
export async function handleBlockCreationFailed(id: number, error?: Error) {
recordingStateMachine.dispatch({
type: 'CREATE_BLOCK_FAILED',
id,
error,
});
}
export function removeRecording(id: number) {
recordings.delete(id);
recordingStateMachine.dispatch({ type: 'REMOVE_RECORDING', id });
}
export interface SerializedRecordingStatus {
id: number;
status: RecordingStatus['status'];
blockCreationStatus?: RecordingStatus['blockCreationStatus'];
appName?: string;
// if there is no app group, it means the recording is for system audio
appGroupId?: number;
@@ -667,17 +787,18 @@ export interface SerializedRecordingStatus {
export function serializeRecordingStatus(
status: RecordingStatus
): SerializedRecordingStatus | null {
const recording = recordings.get(status.id);
return {
id: status.id,
status: status.status,
blockCreationStatus: status.blockCreationStatus,
appName: status.appGroup?.name,
appGroupId: status.appGroup?.processGroupId,
icon: status.appGroup?.icon,
startTime: status.startTime,
filepath: status.filepath,
sampleRate: status.sampleRate,
numberOfChannels: status.numberOfChannels,
filepath:
status.filepath ?? (recording ? String(recording.file.path) : undefined),
sampleRate: recording?.session.sampleRate,
numberOfChannels: recording?.session.channels,
};
}

View File

@@ -2,9 +2,11 @@
// Should not load @affine/native for unsupported platforms
import path from 'node:path';
import { shell } from 'electron';
import { isMacOS, resolvePathInBase } from '../../shared/utils';
import { isMacOS } from '../../shared/utils';
import { openExternalSafely } from '../security/open-external';
import type { NamespaceHandlers } from '../type';
import {
@@ -12,14 +14,18 @@ import {
checkMeetingPermissions,
checkRecordingAvailable,
disableRecordingFeature,
getRawAudioBuffers,
getRecording,
handleBlockCreationFailed,
handleBlockCreationSuccess,
pauseRecording,
readRecordingFile,
readyRecording,
recordingStatus$,
removeRecording,
SAVED_RECORDINGS_DIR,
type SerializedRecordingStatus,
serializeRecordingStatus,
setRecordingBlockCreationStatus,
setupRecordingFeature,
startRecording,
stopRecording,
@@ -39,19 +45,27 @@ export const recordingHandlers = {
startRecording: async (_, appGroup?: AppGroupInfo | number) => {
return startRecording(appGroup);
},
pauseRecording: async (_, id: number) => {
return pauseRecording(id);
},
stopRecording: async (_, id: number) => {
return stopRecording(id);
},
getRawAudioBuffers: async (_, id: number, cursor?: number) => {
return getRawAudioBuffers(id, cursor);
},
readRecordingFile: async (_, filepath: string) => {
return readRecordingFile(filepath);
},
setRecordingBlockCreationStatus: async (
_,
id: number,
status: 'success' | 'failed',
errorMessage?: string
) => {
return setRecordingBlockCreationStatus(id, status, errorMessage);
// save the encoded recording buffer to the file system
readyRecording: async (_, id: number, buffer: Uint8Array) => {
return readyRecording(id, Buffer.from(buffer));
},
handleBlockCreationSuccess: async (_, id: number) => {
return handleBlockCreationSuccess(id);
},
handleBlockCreationFailed: async (_, id: number, error?: Error) => {
return handleBlockCreationFailed(id, error);
},
removeRecording: async (_, id: number) => {
return removeRecording(id);
@@ -86,10 +100,15 @@ export const recordingHandlers = {
return false;
},
showSavedRecordings: async (_, subpath?: string) => {
const directory = resolvePathInBase(SAVED_RECORDINGS_DIR, subpath ?? '', {
label: 'directory',
});
return shell.showItemInFolder(directory);
const normalizedDir = path.normalize(
path.join(SAVED_RECORDINGS_DIR, subpath ?? '')
);
const normalizedBase = path.normalize(SAVED_RECORDINGS_DIR);
if (!normalizedDir.startsWith(normalizedBase)) {
throw new Error('Invalid directory');
}
return shell.showItemInFolder(normalizedDir);
},
} satisfies NamespaceHandlers;

View File

@@ -13,31 +13,25 @@ export type RecordingEvent =
type: 'START_RECORDING';
appGroup?: AppGroupInfo;
}
| {
type: 'ATTACH_NATIVE_RECORDING';
id: number;
nativeId: string;
startTime: number;
filepath: string;
sampleRate: number;
numberOfChannels: number;
}
| { type: 'PAUSE_RECORDING'; id: number }
| { type: 'RESUME_RECORDING'; id: number }
| {
type: 'STOP_RECORDING';
id: number;
}
| {
type: 'ATTACH_RECORDING_ARTIFACT';
type: 'SAVE_RECORDING';
id: number;
filepath: string;
sampleRate?: number;
numberOfChannels?: number;
}
| {
type: 'SET_BLOCK_CREATION_STATUS';
type: 'CREATE_BLOCK_FAILED';
id: number;
error?: Error;
}
| {
type: 'CREATE_BLOCK_SUCCESS';
id: number;
status: 'success' | 'failed';
errorMessage?: string;
}
| { type: 'REMOVE_RECORDING'; id: number };
@@ -80,26 +74,23 @@ export class RecordingStateMachine {
case 'START_RECORDING':
newStatus = this.handleStartRecording(event.appGroup);
break;
case 'ATTACH_NATIVE_RECORDING':
newStatus = this.handleAttachNativeRecording(event);
case 'PAUSE_RECORDING':
newStatus = this.handlePauseRecording();
break;
case 'RESUME_RECORDING':
newStatus = this.handleResumeRecording();
break;
case 'STOP_RECORDING':
newStatus = this.handleStopRecording(event.id);
break;
case 'ATTACH_RECORDING_ARTIFACT':
newStatus = this.handleAttachRecordingArtifact(
event.id,
event.filepath,
event.sampleRate,
event.numberOfChannels
);
case 'SAVE_RECORDING':
newStatus = this.handleSaveRecording(event.id, event.filepath);
break;
case 'SET_BLOCK_CREATION_STATUS':
newStatus = this.handleSetBlockCreationStatus(
event.id,
event.status,
event.errorMessage
);
case 'CREATE_BLOCK_SUCCESS':
newStatus = this.handleCreateBlockSuccess(event.id);
break;
case 'CREATE_BLOCK_FAILED':
newStatus = this.handleCreateBlockFailed(event.id, event.error);
break;
case 'REMOVE_RECORDING':
this.handleRemoveRecording(event.id);
@@ -142,7 +133,7 @@ export class RecordingStateMachine {
const currentStatus = this.recordingStatus$.value;
if (
currentStatus?.status === 'recording' ||
currentStatus?.status === 'processing'
currentStatus?.status === 'stopped'
) {
logger.error(
'Cannot start a new recording if there is already a recording'
@@ -169,31 +160,46 @@ export class RecordingStateMachine {
}
/**
* Attach native recording metadata to the current recording
* Handle the PAUSE_RECORDING event
*/
private handleAttachNativeRecording(
event: Extract<RecordingEvent, { type: 'ATTACH_NATIVE_RECORDING' }>
): RecordingStatus | null {
private handlePauseRecording(): RecordingStatus | null {
const currentStatus = this.recordingStatus$.value;
if (!currentStatus || currentStatus.id !== event.id) {
logger.error(`Recording ${event.id} not found for native attachment`);
return currentStatus;
if (!currentStatus) {
logger.error('No active recording to pause');
return null;
}
if (currentStatus.status !== 'recording') {
logger.error(
`Cannot attach native metadata when recording is in ${currentStatus.status} state`
);
logger.error(`Cannot pause recording in ${currentStatus.status} state`);
return currentStatus;
}
return {
...currentStatus,
nativeId: event.nativeId,
startTime: event.startTime,
filepath: event.filepath,
sampleRate: event.sampleRate,
numberOfChannels: event.numberOfChannels,
status: 'paused',
};
}
/**
* Handle the RESUME_RECORDING event
*/
private handleResumeRecording(): RecordingStatus | null {
const currentStatus = this.recordingStatus$.value;
if (!currentStatus) {
logger.error('No active recording to resume');
return null;
}
if (currentStatus.status !== 'paused') {
logger.error(`Cannot resume recording in ${currentStatus.status} state`);
return currentStatus;
}
return {
...currentStatus,
status: 'recording',
};
}
@@ -208,25 +214,26 @@ export class RecordingStateMachine {
return currentStatus;
}
if (currentStatus.status !== 'recording') {
if (
currentStatus.status !== 'recording' &&
currentStatus.status !== 'paused'
) {
logger.error(`Cannot stop recording in ${currentStatus.status} state`);
return currentStatus;
}
return {
...currentStatus,
status: 'processing',
status: 'stopped',
};
}
/**
* Attach the encoded artifact once native stop completes
* Handle the SAVE_RECORDING event
*/
private handleAttachRecordingArtifact(
private handleSaveRecording(
id: number,
filepath: string,
sampleRate?: number,
numberOfChannels?: number
filepath: string
): RecordingStatus | null {
const currentStatus = this.recordingStatus$.value;
@@ -235,54 +242,51 @@ export class RecordingStateMachine {
return currentStatus;
}
if (currentStatus.status !== 'processing') {
logger.error(`Cannot attach artifact in ${currentStatus.status} state`);
return currentStatus;
}
return {
...currentStatus,
status: 'ready',
filepath,
sampleRate: sampleRate ?? currentStatus.sampleRate,
numberOfChannels: numberOfChannels ?? currentStatus.numberOfChannels,
};
}
/**
* Set the renderer-side block creation result
* Handle the CREATE_BLOCK_SUCCESS event
*/
private handleSetBlockCreationStatus(
id: number,
status: 'success' | 'failed',
errorMessage?: string
): RecordingStatus | null {
private handleCreateBlockSuccess(id: number): RecordingStatus | null {
const currentStatus = this.recordingStatus$.value;
if (!currentStatus || currentStatus.id !== id) {
logger.error(`Recording ${id} not found for block creation status`);
logger.error(`Recording ${id} not found for create-block-success`);
return currentStatus;
}
if (currentStatus.status === 'new') {
logger.error(`Cannot settle recording ${id} before it starts`);
return currentStatus;
}
if (
currentStatus.status === 'ready' &&
currentStatus.blockCreationStatus !== undefined
) {
return currentStatus;
}
if (errorMessage) {
logger.error(`Recording ${id} create block failed: ${errorMessage}`);
}
return {
...currentStatus,
status: 'ready',
blockCreationStatus: status,
status: 'create-block-success',
};
}
/**
* Handle the CREATE_BLOCK_FAILED event
*/
private handleCreateBlockFailed(
id: number,
error?: Error
): RecordingStatus | null {
const currentStatus = this.recordingStatus$.value;
if (!currentStatus || currentStatus.id !== id) {
logger.error(`Recording ${id} not found for create-block-failed`);
return currentStatus;
}
if (error) {
logger.error(`Recording ${id} create block failed:`, error);
}
return {
...currentStatus,
status: 'create-block-failed',
};
}

View File

@@ -1,35 +1,88 @@
# Recording State Transitions
The desktop recording flow now has a single linear engine state and a separate post-process result.
This document visualizes the possible state transitions in the recording system.
## Engine states
## States
- `inactive`: no active recording
- `new`: app detected, waiting for user confirmation
- `recording`: native capture is running
- `processing`: native capture has stopped and the artifact is being imported
- `ready`: post-processing has finished
The recording system has the following states:
## Post-process result
- **inactive**: No active recording (null state)
- **new**: A new recording has been detected but not yet started
- **recording**: Audio is being recorded
- **paused**: Recording is temporarily paused
- **stopped**: Recording has been stopped and is processing
- **ready**: Recording is processed and ready for use
`ready` recordings may carry `blockCreationStatus`:
## Transitions
- `success`: the recording block was created successfully
- `failed`: the artifact was saved, but block creation/import failed
## State flow
```text
inactive -> new -> recording -> processing -> ready
^ |
| |
+------ start ---------+
```
┌───────────┐ ┌───────┐
│ │ │ │
│ inactive │◀───────────────│ ready │
│ │ │ │
└─────┬─────┘ └───┬───┘
│ │
│ NEW_RECORDING │
┌───────────┐
│ │ │
│ new │ │
│ │ │
└─────┬─────┘ │
│ │
│ START_RECORDING │
▼ │
┌───────────┐ │
│ │ STOP_RECORDING│
│ recording │─────────────────┐ │
│ │◀────────────┐ │ │
└─────┬─────┘ │ │ │
│ │ │ │
│ PAUSE_RECORDING │ │ │
▼ │ │ │
┌───────────┐ │ │ │
│ │ │ │ │
│ paused │ │ │ │
│ │ │ │ │
└─────┬─────┘ │ │ │
│ │ │ │
│ RESUME_RECORDING │ │ │
└───────────────────┘ │ │
│ │
▼ │
┌───────────┐
│ │
│ stopped │
│ │
└─────┬─────┘
│ SAVE_RECORDING
┌───────────┐
│ │
│ ready │
│ │
└───────────┘
```
- `START_RECORDING` creates or reuses a pending `new` recording.
- `ATTACH_NATIVE_RECORDING` fills in native metadata while staying in `recording`.
- `STOP_RECORDING` moves the flow to `processing`.
- `ATTACH_RECORDING_ARTIFACT` attaches the finalized `.opus` artifact while staying in `processing`.
- `SET_BLOCK_CREATION_STATUS` settles the flow as `ready`.
## Events
Only one recording can be active at a time. A new recording can start only after the previous one has been removed or its `ready` result has been settled.
The following events trigger state transitions:
- `NEW_RECORDING`: Create a new recording when an app starts or is detected
- `START_RECORDING`: Start recording audio
- `PAUSE_RECORDING`: Pause the current recording
- `RESUME_RECORDING`: Resume a paused recording
- `STOP_RECORDING`: Stop the current recording
- `SAVE_RECORDING`: Save and finalize a recording
- `REMOVE_RECORDING`: Delete a recording
## Error Handling
Invalid state transitions are logged and prevented. For example:
- Cannot start a new recording when one is already in progress
- Cannot pause a recording that is not in the 'recording' state
- Cannot resume a recording that is not in the 'paused' state
Each transition function in the state machine validates the current state before allowing a transition.

View File

@@ -1,4 +1,6 @@
import type { ApplicationInfo } from '@affine/native';
import type { WriteStream } from 'node:fs';
import type { ApplicationInfo, AudioCaptureSession } from '@affine/native';
export interface TappableAppInfo {
info: ApplicationInfo;
@@ -18,19 +20,38 @@ export interface AppGroupInfo {
isRunning: boolean;
}
export interface Recording {
id: number;
// the app may not be available if the user choose to record system audio
app?: TappableAppInfo;
appGroup?: AppGroupInfo;
// the buffered file that is being recorded streamed to
file: WriteStream;
session: AudioCaptureSession;
startTime: number;
filepath?: string; // the filepath of the recording (only available when status is ready)
}
export interface RecordingStatus {
id: number; // corresponds to the recording id
// an app group is detected and waiting for user confirmation
// recording: the native recorder is running
// processing: recording has stopped and the artifact is being prepared/imported
// ready: the post-processing result has been settled
status: 'new' | 'recording' | 'processing' | 'ready';
// the status of the recording in a linear state machine
// new: an new app group is listening. note, if there are any active recording, the current recording will not change
// recording: the recording is ongoing
// paused: the recording is paused
// stopped: the recording is stopped (processing audio file for use in the editor)
// ready: the recording is ready to be used
// create-block-success: the recording is successfully created as a block
// create-block-failed: creating block failed
status:
| 'new'
| 'recording'
| 'paused'
| 'stopped'
| 'ready'
| 'create-block-success'
| 'create-block-failed';
app?: TappableAppInfo;
appGroup?: AppGroupInfo;
startTime: number; // 0 means not started yet
filepath?: string; // encoded file path
nativeId?: string;
sampleRate?: number;
numberOfChannels?: number;
blockCreationStatus?: 'success' | 'failed';
}

View File

@@ -160,7 +160,11 @@ class TrayState implements Disposable {
const recordingStatus = recordingStatus$.value;
if (!recordingStatus || recordingStatus.status !== 'recording') {
if (
!recordingStatus ||
(recordingStatus?.status !== 'paused' &&
recordingStatus?.status !== 'recording')
) {
const appMenuItems = runningAppGroups.map(appGroup => ({
label: appGroup.name,
icon: appGroup.icon || undefined,
@@ -168,9 +172,7 @@ class TrayState implements Disposable {
logger.info(
`User action: Start Recording Meeting (${appGroup.name})`
);
startRecording(appGroup).catch(err => {
logger.error('Failed to start recording:', err);
});
startRecording(appGroup);
},
}));
@@ -186,9 +188,7 @@ class TrayState implements Disposable {
logger.info(
'User action: Start Recording Meeting (System audio)'
);
startRecording().catch(err => {
logger.error('Failed to start recording:', err);
});
startRecording();
},
},
...appMenuItems,
@@ -201,7 +201,7 @@ class TrayState implements Disposable {
? `Recording (${recordingStatus.appGroup?.name})`
: 'Recording';
// recording is active
// recording is either started or paused
items.push(
{
label: recordingLabel,

View File

@@ -1,5 +1,4 @@
import { realpath } from 'node:fs/promises';
import { isAbsolute, join, relative, resolve, sep } from 'node:path';
import { join } from 'node:path';
import type { EventBasedChannel } from 'async-call-rpc';
@@ -48,130 +47,6 @@ export class MessageEventChannel implements EventBasedChannel {
export const resourcesPath = join(__dirname, `../resources`);
function normalizeComparedPath(path: string, caseInsensitive: boolean) {
return caseInsensitive ? path.toLowerCase() : path;
}
export function isPathInsideBase(
basePath: string,
targetPath: string,
options: { caseInsensitive?: boolean } = {}
) {
const { caseInsensitive = false } = options;
const normalizedBase = normalizeComparedPath(
resolve(basePath),
caseInsensitive
);
const normalizedTarget = normalizeComparedPath(
resolve(targetPath),
caseInsensitive
);
const rel = relative(normalizedBase, normalizedTarget);
return (
rel === '' ||
(!isAbsolute(rel) && rel !== '..' && !rel.startsWith(`..${sep}`))
);
}
export function resolvePathInBase(
basePath: string,
targetPath: string,
options: { caseInsensitive?: boolean; label?: string } = {}
) {
const resolvedBase = resolve(basePath);
const resolvedTarget = resolve(resolvedBase, targetPath);
if (!isPathInsideBase(resolvedBase, resolvedTarget, options)) {
throw new Error(
options.label ? `Invalid ${options.label}` : 'Invalid path'
);
}
return resolvedTarget;
}
export async function resolveExistingPath(targetPath: string) {
try {
return await realpath(targetPath);
} catch (error) {
const code = (error as NodeJS.ErrnoException).code;
if (code === 'ENOENT' || code === 'ENOTDIR') {
return resolve(targetPath);
}
throw error;
}
}
export async function resolveExistingPathInBase(
basePath: string,
targetPath: string,
options: { caseInsensitive?: boolean; label?: string } = {}
) {
const [resolvedBase, resolvedTarget] = await Promise.all([
resolveExistingPath(basePath),
resolveExistingPath(targetPath),
]);
if (!isPathInsideBase(resolvedBase, resolvedTarget, options)) {
throw new Error(
options.label ? `Invalid ${options.label}` : 'Invalid path'
);
}
return resolvedTarget;
}
export function assertPathComponent(
value: string,
label: string = 'path component'
) {
const hasControlChar = Array.from(value).some(
character => character.charCodeAt(0) < 0x20
);
if (
!value ||
value === '.' ||
value === '..' ||
/[/\\]/.test(value) ||
hasControlChar
) {
throw new Error(`Invalid ${label}`);
}
return value;
}
export function normalizeWorkspaceIdForPath(
value: string,
options: { windows?: boolean; label?: string } = {}
) {
const { windows = isWindows(), label = 'workspace id' } = options;
const safeValue = assertPathComponent(value, label);
if (!windows) {
return safeValue;
}
const windowsReservedChars = new Set(['<', '>', ':', '"', '|', '?', '*']);
let normalized = '';
for (const character of safeValue) {
normalized += windowsReservedChars.has(character) ? '_' : character;
}
while (normalized.endsWith('.') || normalized.endsWith(' ')) {
normalized = normalized.slice(0, -1);
}
if (!normalized || normalized === '.' || normalized === '..') {
throw new Error(`Invalid ${label}`);
}
return normalized;
}
// credit: https://github.com/facebook/fbjs/blob/main/packages/fbjs/src/core/shallowEqual.js
export function shallowEqual<T>(objA: T, objB: T) {
if (Object.is(objA, objB)) {

View File

@@ -21,11 +21,6 @@ const docSetSpaceId = vi.fn();
const sqliteValidate = vi.fn();
const sqliteValidateImportSchema = vi.fn();
const sqliteVacuumInto = vi.fn();
const sqliteClose = vi.fn();
const showOpenDialog = vi.fn();
const showSaveDialog = vi.fn();
const showItemInFolder = vi.fn(async () => undefined);
const getPath = vi.fn();
vi.doMock('nanoid', () => ({
nanoid: () => 'workspace-1',
@@ -75,10 +70,6 @@ vi.doMock('@affine/native', () => {
vacuumInto(path: string) {
return sqliteVacuumInto(this.path, path);
}
close() {
return sqliteClose(this.path);
}
},
};
});
@@ -93,10 +84,7 @@ vi.doMock('@affine/electron/helper/nbstore', () => ({
vi.doMock('@affine/electron/helper/main-rpc', () => ({
mainRPC: {
getPath,
showItemInFolder,
showOpenDialog,
showSaveDialog,
showItemInFolder: vi.fn(),
},
}));
@@ -138,11 +126,12 @@ describe('dialog export', () => {
realpath.mockImplementation(async path => path);
getSpaceDBPath.mockResolvedValue(dbPath);
move.mockResolvedValue(undefined);
showSaveDialog.mockResolvedValue({ canceled: false, filePath: exportPath });
const { saveDBFileAs } =
const { saveDBFileAs, setFakeDialogResult } =
await import('@affine/electron/helper/dialog/dialog');
setFakeDialogResult({ filePath: exportPath });
const result = await saveDBFileAs(id, 'My Space');
expect(result).toEqual({ filePath: exportPath });
@@ -162,11 +151,12 @@ describe('dialog export', () => {
pathExists.mockResolvedValue(false);
getSpaceDBPath.mockResolvedValue(dbPath);
showSaveDialog.mockResolvedValue({ canceled: false, filePath: dbPath });
const { saveDBFileAs } =
const { saveDBFileAs, setFakeDialogResult } =
await import('@affine/electron/helper/dialog/dialog');
setFakeDialogResult({ filePath: dbPath });
const result = await saveDBFileAs(id, 'My Space');
expect(result).toEqual({ error: 'DB_FILE_PATH_INVALID' });
@@ -184,11 +174,12 @@ describe('dialog export', () => {
path === exportPath ? dbPath : path
);
getSpaceDBPath.mockResolvedValue(dbPath);
showSaveDialog.mockResolvedValue({ canceled: false, filePath: exportPath });
const { saveDBFileAs } =
const { saveDBFileAs, setFakeDialogResult } =
await import('@affine/electron/helper/dialog/dialog');
setFakeDialogResult({ filePath: exportPath });
const result = await saveDBFileAs(id, 'My Space');
expect(result).toEqual({ error: 'DB_FILE_PATH_INVALID' });
@@ -202,12 +193,6 @@ describe('dialog import', () => {
const originalPath = '/tmp/import.affine';
const internalPath = '/app/workspaces/local/workspace-1/storage.db';
pathExists.mockResolvedValue(true);
realpath.mockImplementation(async path => path);
showOpenDialog.mockResolvedValue({
canceled: false,
filePaths: [originalPath],
});
getWorkspacesBasePath.mockResolvedValue('/app/workspaces');
getSpaceDBPath.mockResolvedValue(internalPath);
docValidate.mockResolvedValue(true);
@@ -216,9 +201,11 @@ describe('dialog import', () => {
docSetSpaceId.mockResolvedValue(undefined);
ensureDir.mockResolvedValue(undefined);
const { loadDBFile } =
const { loadDBFile, setFakeDialogResult } =
await import('@affine/electron/helper/dialog/dialog');
setFakeDialogResult({ filePath: originalPath });
const result = await loadDBFile();
expect(result).toEqual({ workspaceId: 'workspace-1' });
@@ -232,19 +219,15 @@ describe('dialog import', () => {
test('loadDBFile rejects v2 imports with unexpected schema objects', async () => {
const originalPath = '/tmp/import.affine';
pathExists.mockResolvedValue(true);
realpath.mockImplementation(async path => path);
showOpenDialog.mockResolvedValue({
canceled: false,
filePaths: [originalPath],
});
getWorkspacesBasePath.mockResolvedValue('/app/workspaces');
docValidate.mockResolvedValue(true);
docValidateImportSchema.mockResolvedValue(false);
const { loadDBFile } =
const { loadDBFile, setFakeDialogResult } =
await import('@affine/electron/helper/dialog/dialog');
setFakeDialogResult({ filePath: originalPath });
const result = await loadDBFile();
expect(result).toEqual({ error: 'DB_FILE_INVALID' });
@@ -256,12 +239,6 @@ describe('dialog import', () => {
const originalPath = '/tmp/import-v1.affine';
const internalPath = '/app/workspaces/workspace-1/storage.db';
pathExists.mockResolvedValue(true);
realpath.mockImplementation(async path => path);
showOpenDialog.mockResolvedValue({
canceled: false,
filePaths: [originalPath],
});
getWorkspacesBasePath.mockResolvedValue('/app/workspaces');
getWorkspaceDBPath.mockResolvedValue(internalPath);
docValidate.mockResolvedValue(false);
@@ -270,9 +247,11 @@ describe('dialog import', () => {
sqliteVacuumInto.mockResolvedValue(undefined);
ensureDir.mockResolvedValue(undefined);
const { loadDBFile } =
const { loadDBFile, setFakeDialogResult } =
await import('@affine/electron/helper/dialog/dialog');
setFakeDialogResult({ filePath: originalPath });
const result = await loadDBFile();
expect(result).toEqual({ workspaceId: 'workspace-1' });
@@ -284,57 +263,6 @@ describe('dialog import', () => {
id: 'workspace-1',
mainDBPath: internalPath,
});
expect(sqliteClose).toHaveBeenCalledWith(originalPath);
expect(copy).not.toHaveBeenCalled();
});
test('loadDBFile closes v1 connection when schema validation fails', async () => {
const originalPath = '/tmp/import-v1-invalid.affine';
pathExists.mockResolvedValue(true);
realpath.mockImplementation(async path => path);
showOpenDialog.mockResolvedValue({
canceled: false,
filePaths: [originalPath],
});
getWorkspacesBasePath.mockResolvedValue('/app/workspaces');
docValidate.mockResolvedValue(false);
sqliteValidate.mockResolvedValue('Valid');
sqliteValidateImportSchema.mockResolvedValue(false);
const { loadDBFile } =
await import('@affine/electron/helper/dialog/dialog');
const result = await loadDBFile();
expect(result).toEqual({ error: 'DB_FILE_INVALID' });
expect(sqliteClose).toHaveBeenCalledWith(originalPath);
expect(sqliteVacuumInto).not.toHaveBeenCalled();
});
test('loadDBFile rejects normalized paths inside app data', async () => {
const selectedPath = '/tmp/import.affine';
const normalizedPath = '/app/workspaces/local/existing/storage.db';
pathExists.mockResolvedValue(true);
realpath.mockImplementation(async path => {
if (path === selectedPath) {
return normalizedPath;
}
return path;
});
showOpenDialog.mockResolvedValue({
canceled: false,
filePaths: [selectedPath],
});
getWorkspacesBasePath.mockResolvedValue('/app/workspaces');
const { loadDBFile } =
await import('@affine/electron/helper/dialog/dialog');
const result = await loadDBFile();
expect(result).toEqual({ error: 'DB_FILE_PATH_INVALID' });
expect(docValidate).not.toHaveBeenCalled();
});
});

View File

@@ -1,107 +0,0 @@
import { randomUUID } from 'node:crypto';
import fs from 'node:fs/promises';
import os from 'node:os';
import path from 'node:path';
import { afterEach, describe, expect, test } from 'vitest';
import {
assertPathComponent,
normalizeWorkspaceIdForPath,
resolveExistingPathInBase,
resolvePathInBase,
} from '../../src/shared/utils';
// Per-run scratch directory; a random suffix keeps parallel runs isolated.
const tmpDir = path.join(os.tmpdir(), `affine-electron-utils-${randomUUID()}`);

afterEach(async () => {
  // Best-effort cleanup; force ignores a directory that was never created.
  await fs.rm(tmpDir, { recursive: true, force: true });
});

// Tests for the path-containment helpers used by the Electron helper process
// to keep workspace/recording file access inside their base directories.
describe('path guards', () => {
  // '../recordings-evil' shares the 'recordings' prefix with the base dir —
  // a naive startsWith() check would wrongly accept it.
  test('resolvePathInBase blocks sibling-prefix escapes', () => {
    const baseDir = path.join(tmpDir, 'recordings');
    expect(() =>
      resolvePathInBase(baseDir, '../recordings-evil/file.opus', {
        label: 'directory',
      })
    ).toThrow('Invalid directory');
  });

  // Symlinks are not generally creatable without privileges on Windows.
  test.runIf(process.platform !== 'win32')(
    'resolveExistingPathInBase rejects symlink escapes',
    async () => {
      const baseDir = path.join(tmpDir, 'recordings');
      const outsideDir = path.join(tmpDir, 'outside');
      const outsideFile = path.join(outsideDir, 'secret.txt');
      const linkPath = path.join(baseDir, '1234567890abcdef.blob');
      await fs.mkdir(baseDir, { recursive: true });
      await fs.mkdir(outsideDir, { recursive: true });
      await fs.writeFile(outsideFile, 'secret');
      // Link lives inside the base dir but its target escapes it.
      await fs.symlink(outsideFile, linkPath);
      await expect(
        resolveExistingPathInBase(baseDir, linkPath, {
          label: 'recording filepath',
        })
      ).rejects.toThrow('Invalid recording filepath');
    }
  );

  // A not-yet-created descendant cannot be realpath'd; the helper should
  // fall back to the lexically resolved path instead of failing.
  test('resolveExistingPathInBase falls back for missing descendants', async () => {
    const baseDir = path.join(tmpDir, 'recordings');
    await fs.mkdir(baseDir, { recursive: true });
    const missingPath = path.join(
      await fs.realpath(baseDir),
      'pending',
      'recording.opus'
    );
    await expect(
      resolveExistingPathInBase(baseDir, missingPath, {
        label: 'recording filepath',
      })
    ).resolves.toBe(path.resolve(missingPath));
  });

  // A self-referential symlink makes realpath fail with ELOOP (not ENOENT);
  // that error must be propagated, not masked by the missing-path fallback.
  test.runIf(process.platform !== 'win32')(
    'resolveExistingPathInBase preserves non-missing realpath errors',
    async () => {
      const baseDir = path.join(tmpDir, 'recordings');
      const loopPath = path.join(baseDir, 'loop.opus');
      await fs.mkdir(baseDir, { recursive: true });
      await fs.symlink(path.basename(loopPath), loopPath);
      await expect(
        resolveExistingPathInBase(baseDir, loopPath, {
          label: 'recording filepath',
        })
      ).rejects.toMatchObject({ code: 'ELOOP' });
    }
  );

  // Ids containing traversal or separators are never valid path components.
  test.each(['../../escape', 'nested/id'])(
    'assertPathComponent rejects invalid workspace id %s',
    input => {
      expect(() => assertPathComponent(input, 'workspace id')).toThrow(
        'Invalid workspace id'
      );
    }
  );

  // Windows-reserved characters (: * ? and trailing dot) are replaced with
  // underscores; already-safe ids pass through unchanged.
  test.each([
    { input: 'legacy:id*with?reserved.', expected: 'legacy_id_with_reserved' },
    { input: 'safe-workspace', expected: 'safe-workspace' },
  ])(
    'normalizeWorkspaceIdForPath maps $input to $expected on Windows',
    ({ input, expected }) => {
      expect(normalizeWorkspaceIdForPath(input, { windows: true })).toBe(
        expected
      );
    }
  );
});

View File

@@ -1,256 +0,0 @@
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest';
// Shared vi.fn() handles so individual tests can program and inspect the
// mocked Electron / workspace APIs wired into the module mocks below.
const isActiveTab = vi.fn();
const readRecordingFile = vi.fn();
const setRecordingBlockCreationStatus = vi.fn();
const getCurrentWorkspace = vi.fn();
const isAiEnabled = vi.fn();
const transcribeRecording = vi.fn();

// Captures the handler the module under test registers via
// events.recording.onRecordingStatusChanged; tests call it directly to
// simulate recording-status events coming from the main process.
let onRecordingStatusChanged:
  | ((
      status: {
        id: number;
        status: 'processing';
        appName?: string;
        filepath?: string;
        startTime: number;
        blockCreationStatus?: 'success' | 'failed';
      } | null
    ) => void)
  | undefined;

// The service classes are only used as DI tokens (matched by class name in
// the scope stub of createWorkspaceRef), so empty stubs are sufficient.
vi.mock('@affine/core/modules/doc', () => ({
  DocsService: class DocsService {},
}));

vi.mock('@affine/core/modules/media/services/audio-attachment', () => ({
  AudioAttachmentService: class AudioAttachmentService {},
}));

vi.mock('@affine/core/modules/workbench', () => ({
  WorkbenchService: class WorkbenchService {},
}));

vi.mock('@affine/debug', () => ({
  DebugLogger: class DebugLogger {
    debug = vi.fn();
    error = vi.fn();
  },
}));

// Electron bridge: exposes the spies above and records the status handler,
// returning an unsubscribe function that clears it again.
vi.mock('@affine/electron-api', () => ({
  apis: {
    ui: {
      isActiveTab,
    },
    recording: {
      readRecordingFile,
      setRecordingBlockCreationStatus,
    },
  },
  events: {
    recording: {
      onRecordingStatusChanged: vi.fn(
        (handler: typeof onRecordingStatusChanged) => {
          onRecordingStatusChanged = handler;
          return () => {
            onRecordingStatusChanged = undefined;
          };
        }
      ),
    },
  },
}));

vi.mock('@affine/i18n', () => ({
  i18nTime: vi.fn(() => 'Jan 1 09:00'),
}));

vi.mock('@affine/track', () => ({
  default: {
    doc: {
      editor: {
        audioBlock: {
          transcribeRecording,
        },
      },
    },
  },
}));

vi.mock('../../../electron-renderer/src/app/effects/utils', () => ({
  getCurrentWorkspace,
  isAiEnabled,
}));
/**
 * Build a fake workspace reference whose DI scope serves the three services
 * the recording effect resolves (DocsService, WorkbenchService,
 * AudioAttachmentService), plus spies for doc creation, blob storage, block
 * insertion and navigation so tests can assert on the created attachment.
 */
function createWorkspaceRef() {
  const blobSet = vi.fn(async () => 'blob-1');
  const addBlock = vi.fn(() => 'attachment-1');
  const getBlock = vi.fn(() => ({ model: { id: 'attachment-1' } }));
  const openDoc = vi.fn();

  type MockDoc = {
    workspace: {
      blobSync: {
        set: typeof blobSet;
      };
    };
    addBlock: typeof addBlock;
    getBlock: typeof getBlock;
  };
  type MockDocProps = {
    onStoreLoad: (doc: MockDoc, meta: { noteId: string }) => void;
  };

  // createDoc fires onStoreLoad on a microtask to mimic the asynchronous
  // store-load lifecycle of a real doc.
  const createDoc = vi.fn(({ docProps }: { docProps: MockDocProps }) => {
    queueMicrotask(() => {
      docProps.onStoreLoad(
        {
          workspace: { blobSync: { set: blobSet } },
          addBlock,
          getBlock,
        },
        { noteId: 'note-1' }
      );
    });
    return { id: 'doc-1' };
  });

  // Minimal DI scope keyed by the (stubbed) service class names.
  const scope = {
    get(token: { name?: string }) {
      switch (token.name) {
        case 'DocsService':
          return { createDoc };
        case 'WorkbenchService':
          return { workbench: { openDoc } };
        case 'AudioAttachmentService':
          return {
            get: () => ({
              obj: {
                transcribe: vi.fn(async () => undefined),
              },
              [Symbol.dispose]: vi.fn(),
            }),
          };
        default:
          throw new Error(`Unexpected token: ${token.name}`);
      }
    },
  };

  const dispose = vi.fn();
  return {
    ref: {
      workspace: { scope },
      dispose,
      [Symbol.dispose]: dispose,
    },
    createDoc,
    openDoc,
    blobSet,
    addBlock,
    getBlock,
  };
}
describe('recording effect', () => {
  beforeEach(() => {
    // Fake timers drive the effect's retry loop deterministically; modules
    // are reset so each test re-registers the status handler on import.
    vi.useFakeTimers();
    vi.clearAllMocks();
    vi.resetModules();
    onRecordingStatusChanged = undefined;
    readRecordingFile.mockResolvedValue(new Uint8Array([1, 2, 3]).buffer);
    setRecordingBlockCreationStatus.mockResolvedValue(undefined);
    isAiEnabled.mockReturnValue(false);
  });

  afterEach(() => {
    // Drain pending retry timers before restoring real timers so nothing
    // fires after the test ends.
    vi.runOnlyPendingTimers();
    vi.useRealTimers();
  });

  test('retries processing until the active tab has a workspace', async () => {
    const workspace = createWorkspaceRef();
    // First probe: inactive tab; first workspace lookup: none — two failed
    // attempts before the third succeeds.
    isActiveTab.mockResolvedValueOnce(false).mockResolvedValue(true);
    getCurrentWorkspace
      .mockReturnValueOnce(undefined)
      .mockReturnValue(workspace.ref);
    const { setupRecordingEvents } =
      await import('../../../electron-renderer/src/app/effects/recording');
    setupRecordingEvents({} as never);
    // Simulate a 'processing' status event from the main process.
    onRecordingStatusChanged?.({
      id: 7,
      status: 'processing',
      appName: 'Zoom',
      filepath: '/tmp/meeting.opus',
      startTime: 1000,
    });
    await Promise.resolve();
    // Attempt 1 (inactive tab): nothing created yet.
    expect(workspace.createDoc).not.toHaveBeenCalled();
    expect(setRecordingBlockCreationStatus).not.toHaveBeenCalled();
    await vi.advanceTimersByTimeAsync(1000);
    // Attempt 2 (no workspace): still nothing.
    expect(workspace.createDoc).not.toHaveBeenCalled();
    expect(setRecordingBlockCreationStatus).not.toHaveBeenCalled();
    await vi.advanceTimersByTimeAsync(1000);
    // Attempt 3 succeeds: doc created/opened, ogg blob stored, attachment
    // block added, and success reported exactly once.
    expect(workspace.createDoc).toHaveBeenCalledTimes(1);
    expect(workspace.openDoc).toHaveBeenCalledWith('doc-1');
    expect(workspace.blobSet).toHaveBeenCalledTimes(1);
    const [savedBlob] = workspace.blobSet.mock.calls[0] ?? [];
    expect(savedBlob).toBeInstanceOf(Blob);
    expect((savedBlob as Blob).type).toBe('audio/ogg');
    expect(workspace.addBlock).toHaveBeenCalledWith(
      'affine:attachment',
      expect.objectContaining({ type: 'audio/ogg' }),
      'note-1'
    );
    expect(setRecordingBlockCreationStatus).toHaveBeenCalledWith(7, 'success');
    expect(setRecordingBlockCreationStatus).not.toHaveBeenCalledWith(
      7,
      'failed',
      expect.anything()
    );
  });

  // A rejected isActiveTab probe must count as a retryable failure, not a
  // fatal error.
  test('retries when the active-tab probe rejects', async () => {
    const workspace = createWorkspaceRef();
    isActiveTab
      .mockRejectedValueOnce(new Error('probe failed'))
      .mockResolvedValue(true);
    getCurrentWorkspace.mockReturnValue(workspace.ref);
    const { setupRecordingEvents } =
      await import('../../../electron-renderer/src/app/effects/recording');
    setupRecordingEvents({} as never);
    onRecordingStatusChanged?.({
      id: 9,
      status: 'processing',
      appName: 'Meet',
      filepath: '/tmp/meeting.opus',
      startTime: 1000,
    });
    await Promise.resolve();
    expect(workspace.createDoc).not.toHaveBeenCalled();
    expect(setRecordingBlockCreationStatus).not.toHaveBeenCalled();
    await vi.advanceTimersByTimeAsync(1000);
    expect(workspace.createDoc).toHaveBeenCalledTimes(1);
    expect(setRecordingBlockCreationStatus).toHaveBeenCalledWith(9, 'success');
  });
});

View File

@@ -1,116 +0,0 @@
import { describe, expect, test, vi } from 'vitest';
vi.mock('../../src/main/logger', () => ({
logger: {
error: vi.fn(),
info: vi.fn(),
},
}));
import { RecordingStateMachine } from '../../src/main/recording/state-machine';
/**
 * Start a recording on the given state machine and attach native capture
 * metadata to it, returning the pending recording entry.
 */
function createAttachedRecording(stateMachine: RecordingStateMachine) {
  // Kick off a new recording; dispatch returns the pending entry.
  const started = stateMachine.dispatch({ type: 'START_RECORDING' });
  // Attach the native session's metadata to that entry.
  stateMachine.dispatch({
    type: 'ATTACH_NATIVE_RECORDING',
    id: started!.id,
    nativeId: 'native-1',
    startTime: 100,
    filepath: '/tmp/recording.opus',
    sampleRate: 48000,
    numberOfChannels: 2,
  });
  return started!;
}
describe('RecordingStateMachine', () => {
  // Happy path: recording -> processing (stop) -> processing (artifact
  // attached) -> ready (block created).
  test('transitions from recording to ready after artifact import and block creation', () => {
    const stateMachine = new RecordingStateMachine();
    const pending = createAttachedRecording(stateMachine);
    expect(pending?.status).toBe('recording');
    const processing = stateMachine.dispatch({
      type: 'STOP_RECORDING',
      id: pending.id,
    });
    expect(processing?.status).toBe('processing');
    const artifactAttached = stateMachine.dispatch({
      type: 'ATTACH_RECORDING_ARTIFACT',
      id: pending.id,
      filepath: '/tmp/recording.opus',
      sampleRate: 48000,
      numberOfChannels: 2,
    });
    // Attaching the artifact records the filepath but stays 'processing'
    // until the block is created.
    expect(artifactAttached).toMatchObject({
      status: 'processing',
      filepath: '/tmp/recording.opus',
    });
    const ready = stateMachine.dispatch({
      type: 'SET_BLOCK_CREATION_STATUS',
      id: pending.id,
      status: 'success',
    });
    expect(ready).toMatchObject({
      status: 'ready',
      blockCreationStatus: 'success',
    });
  });

  // The stop artifact may omit audio metadata; the values attached at
  // ATTACH_NATIVE_RECORDING time must survive.
  test('keeps native audio metadata when stop artifact omits it', () => {
    const stateMachine = new RecordingStateMachine();
    const pending = createAttachedRecording(stateMachine);
    stateMachine.dispatch({ type: 'STOP_RECORDING', id: pending.id });
    const artifactAttached = stateMachine.dispatch({
      type: 'ATTACH_RECORDING_ARTIFACT',
      id: pending.id,
      filepath: '/tmp/recording.opus',
    });
    expect(artifactAttached).toMatchObject({
      sampleRate: 48000,
      numberOfChannels: 2,
    });
  });

  // Both success and failure settle into 'ready', and a subsequent recording
  // starts fresh (new id, no inherited blockCreationStatus).
  test.each([
    { status: 'success' as const, errorMessage: undefined },
    { status: 'failed' as const, errorMessage: 'native start failed' },
  ])(
    'settles recordings into ready state with blockCreationStatus=$status',
    ({ status, errorMessage }) => {
      const stateMachine = new RecordingStateMachine();
      const pending = stateMachine.dispatch({
        type: 'START_RECORDING',
      });
      expect(pending?.status).toBe('recording');
      const settled = stateMachine.dispatch({
        type: 'SET_BLOCK_CREATION_STATUS',
        id: pending!.id,
        status,
        errorMessage,
      });
      expect(settled).toMatchObject({
        status: 'ready',
        blockCreationStatus: status,
      });
      const next = stateMachine.dispatch({
        type: 'START_RECORDING',
      });
      expect(next?.id).toBeGreaterThan(pending!.id);
      expect(next?.status).toBe('recording');
      expect(next?.blockCreationStatus).toBeUndefined();
    }
  );
});

View File

@@ -131,52 +131,4 @@ describe('workspace db management', () => {
)
).toBe(false);
});
// Path-traversal guards: an id containing '../' must be rejected before any
// filesystem mutation, so the decoy directory created outside the workspace
// root must still exist afterwards.
test('rejects unsafe ids when deleting a workspace', async () => {
  const { deleteWorkspace } =
    await import('@affine/electron/helper/workspace/handlers');
  const outsideDir = path.join(tmpDir, 'outside-delete-target');
  await fs.ensureDir(outsideDir);
  await expect(
    deleteWorkspace(
      universalId({
        peer: 'local',
        type: 'workspace',
        id: '../../outside-delete-target',
      })
    )
  ).rejects.toThrow('Invalid workspace id');
  // The traversal target was never touched.
  expect(await fs.pathExists(outsideDir)).toBe(true);
});

// Same guard for backup deletion, which takes a bare id string.
test('rejects unsafe ids when deleting backup workspaces', async () => {
  const { deleteBackupWorkspace } =
    await import('@affine/electron/helper/workspace/handlers');
  const outsideDir = path.join(tmpDir, 'outside-backup-target');
  await fs.ensureDir(outsideDir);
  await expect(
    deleteBackupWorkspace('../../outside-backup-target')
  ).rejects.toThrow('Invalid workspace id');
  expect(await fs.pathExists(outsideDir)).toBe(true);
});

// Same guard for backup recovery.
test('rejects unsafe ids when recovering backup workspaces', async () => {
  const { recoverBackupWorkspace } =
    await import('@affine/electron/helper/workspace/handlers');
  const outsideDir = path.join(tmpDir, 'outside-recover-target');
  await fs.ensureDir(outsideDir);
  await expect(
    recoverBackupWorkspace('../../outside-recover-target')
  ).rejects.toThrow('Invalid workspace id');
  expect(await fs.pathExists(outsideDir)).toBe(true);
});
});

View File

@@ -77,8 +77,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/RevenueCat/purchases-ios-spm.git",
"state" : {
"revision" : "abb0d68c3e7ba97b16ab51c38fcaca16b0e358c8",
"version" : "5.66.0"
"revision" : "2913a336eb37dc06795cdbaa5b5de330b6707669",
"version" : "5.65.0"
}
},
{
@@ -113,8 +113,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-collections",
"state" : {
"revision" : "6675bc0ff86e61436e615df6fc5174e043e57924",
"version" : "1.4.1"
"revision" : "8d9834a6189db730f6264db7556a7ffb751e99ee",
"version" : "1.4.0"
}
},
{

View File

@@ -17,7 +17,7 @@ let package = Package(
],
dependencies: [
.package(path: "../AffineResources"),
.package(url: "https://github.com/RevenueCat/purchases-ios-spm.git", from: "5.66.0"),
.package(url: "https://github.com/RevenueCat/purchases-ios-spm.git", from: "5.60.0"),
],
targets: [
.target(

View File

@@ -17,7 +17,7 @@ let package = Package(
.package(path: "../AffineGraphQL"),
.package(path: "../AffineResources"),
.package(url: "https://github.com/apollographql/apollo-ios.git", from: "1.23.0"),
.package(url: "https://github.com/apple/swift-collections.git", from: "1.4.1"),
.package(url: "https://github.com/apple/swift-collections.git", from: "1.4.0"),
.package(url: "https://github.com/SnapKit/SnapKit.git", from: "5.7.1"),
.package(url: "https://github.com/SwifterSwift/SwifterSwift.git", from: "6.2.0"),
.package(url: "https://github.com/Recouse/EventSource.git", from: "0.1.7"),

View File

@@ -433,9 +433,7 @@ export const NbStoreNativeDBApis: NativeDBApis = {
id: string,
docId: string
): Promise<DocIndexedClock | null> {
return NbStore.getDocIndexedClock({ id, docId }).then(clock =>
clock ? { ...clock, timestamp: new Date(clock.timestamp) } : null
);
return NbStore.getDocIndexedClock({ id, docId });
},
setDocIndexedClock: function (
id: string,

View File

@@ -187,6 +187,7 @@ export const DayPicker = memo(function DayPicker(
{/* Weeks in month */}
{matrix.map((week, i) => {
return (
// eslint-disable-next-line react/no-array-index-key
<div key={i} className={clsx(styles.monthViewRow)}>
{week.map(cell => {
const dateValue = cell.date.format(format);

View File

@@ -126,8 +126,10 @@ export const MonthPicker = memo(function MonthPicker(
const Body = useMemo(() => {
return (
<div className={styles.yearViewBody}>
{/* eslint-disable-next-line react/no-array-index-key */}
{matrix.map((row, i) => {
return (
// eslint-disable-next-line react/no-array-index-key
<div key={i} className={styles.yearViewRow}>
{row.map(month => {
const monthValue = month.format('YYYY-MM');

View File

@@ -81,7 +81,7 @@ const BackupWorkspaceItem = ({ item }: { item: BackupWorkspaceItem }) => {
const handleImport = useAsyncCallback(async () => {
setImporting(true);
track.$.settingsPanel.archivedWorkspaces.recoverArchivedWorkspace();
const workspaceId = await backupService.recoverBackupWorkspace(item.id);
const workspaceId = await backupService.recoverBackupWorkspace(item.dbPath);
if (!workspaceId) {
setImporting(false);
return;
@@ -102,7 +102,7 @@ const BackupWorkspaceItem = ({ item }: { item: BackupWorkspaceItem }) => {
});
setMenuOpen(false);
setImporting(false);
}, [backupService, item.id, jumpToPage, t]);
}, [backupService, item.dbPath, jumpToPage, t]);
const handleDelete = useCallback(
(backupWorkspaceId: string) => {

View File

@@ -47,11 +47,9 @@ export class BackupService extends Service {
)
);
async recoverBackupWorkspace(backupWorkspaceId: string) {
async recoverBackupWorkspace(dbPath: string) {
const result =
await this.desktopApiService.handler.workspace.recoverBackupWorkspace(
backupWorkspaceId
);
await this.desktopApiService.handler.dialog.loadDBFile(dbPath);
if (result.workspaceId) {
_addLocalWorkspace(result.workspaceId);
this.workspacesService.list.revalidate();

View File

@@ -414,3 +414,98 @@ export async function encodeAudioBlobToOpusSlices(
await audioContext.close();
}
}
/**
 * Create a pull-based Opus stream encoder for an in-progress native
 * recording.
 *
 * The encoder repeatedly pulls raw interleaved f32 PCM buffers from the main
 * process via `apis.recording.getRawAudioBuffers`, feeds them to an Opus
 * encoder, and on `finish()` muxes the accumulated encoded chunks into an
 * MP4 container.
 *
 * @param recordingId id of the native recording session to pull buffers from
 * @param codecs      PCM/encoder parameters: sample rate, channel count and
 *                    an optional target bitrate
 */
export const createStreamEncoder = (
  recordingId: number,
  codecs: {
    sampleRate: number;
    numberOfChannels: number;
    targetBitrate?: number;
  }
) => {
  const { encoder, encodedChunks } = createOpusEncoder({
    sampleRate: codecs.sampleRate,
    numberOfChannels: codecs.numberOfChannels,
    bitrate: codecs.targetBitrate,
  });

  // Wrap a raw interleaved f32 PCM buffer in a WebCodecs AudioData frame.
  // NOTE(review): every frame is stamped with timestamp 0 — presumably the
  // muxer only relies on chunk order; confirm if seeking ever looks wrong.
  const toAudioData = (buffer: Uint8Array) => {
    // Each sample in f32 format is 4 bytes.
    const BYTES_PER_SAMPLE = 4;
    return new AudioData({
      format: 'f32',
      sampleRate: codecs.sampleRate,
      numberOfChannels: codecs.numberOfChannels,
      numberOfFrames:
        buffer.length / BYTES_PER_SAMPLE / codecs.numberOfChannels,
      timestamp: 0,
      data: toArrayBuffer(buffer),
    });
  };

  // Read cursor into the main-process audio buffer; advanced by `next`.
  let cursor = 0;
  let isClosed = false;

  // Pull the next pending chunk (if any) and feed it to the encoder.
  const next = async () => {
    if (!apis) {
      throw new Error('Electron API is not available');
    }
    if (isClosed) {
      return;
    }
    const { buffer, nextCursor } = await apis.recording.getRawAudioBuffers(
      recordingId,
      cursor
    );
    // Re-check isClosed: close() may have run while awaiting the IPC call.
    if (isClosed || cursor === nextCursor) {
      return;
    }
    cursor = nextCursor;
    logger.debug('Encoding next chunk', cursor, nextCursor);
    encoder.encode(toAudioData(buffer));
  };

  // Poll for new chunks once per second until closed.
  // Implemented as a loop rather than recursive `await poll()` calls so a
  // long-running recording does not accumulate an unbounded async call chain.
  const poll = async () => {
    while (!isClosed) {
      logger.debug('Polling next chunk');
      await next();
      await new Promise(resolve => setTimeout(resolve, 1000));
    }
  };

  const close = () => {
    if (isClosed) {
      return;
    }
    isClosed = true;
    return encoder.close();
  };

  return {
    id: recordingId,
    next,
    poll,
    flush: () => {
      return encoder.flush();
    },
    close,
    finish: async () => {
      logger.debug('Finishing encoding');
      // Drain any remaining raw buffers, then flush the encoder so pending
      // frames land in `encodedChunks` before the codec is closed — closing
      // without flushing would silently drop the tail of the recording.
      await next();
      await encoder.flush();
      close();
      const buffer = muxToMp4(encodedChunks, {
        sampleRate: codecs.sampleRate,
        numberOfChannels: codecs.numberOfChannels,
        bitrate: codecs.targetBitrate,
      });
      return buffer;
    },
    [Symbol.dispose]: () => {
      close();
    },
  };
};

export type OpusStreamEncoder = ReturnType<typeof createStreamEncoder>;

View File

@@ -40,37 +40,6 @@ export declare function decodeAudio(buf: Uint8Array, destSampleRate?: number | u
/** Decode audio file into a Float32Array */
export declare function decodeAudioSync(buf: Uint8Array, destSampleRate?: number | undefined | null, filename?: string | undefined | null): Float32Array
/** Metadata describing a finished on-disk recording returned by `stopRecording`. */
export interface RecordingArtifact {
  /** Recording session id (matches `RecordingSessionMeta.id`). */
  id: string
  /** Absolute path of the written audio file. */
  filepath: string
  sampleRate: number
  channels: number
  durationMs: number
  /** File size in bytes. */
  size: number
}
/** Metadata for a recording session that has just been started. */
export interface RecordingSessionMeta {
  id: string
  /** Absolute path the recording is being written to. */
  filepath: string
  sampleRate: number
  channels: number
  /** Start timestamp — presumably unix epoch; confirm units with the native impl. */
  startedAt: number
}
/** Options accepted by `startRecording`. */
export interface RecordingStartOptions {
  /** Capture a single app by process id; omit for global capture. */
  appProcessId?: number
  /** Process ids to exclude from a global capture. */
  excludeProcessIds?: Array<number>
  /** Directory the recording file is written into. */
  outputDir: string
  format?: string
  sampleRate?: number
  channels?: number
  /** Caller-supplied session id; generated when omitted — confirm with impl. */
  id?: string
}
/** Begin a native audio recording and return its session metadata. */
export declare function startRecording(opts: RecordingStartOptions): RecordingSessionMeta
/** Stop the recording with the given id and return the finished artifact. */
export declare function stopRecording(id: string): RecordingArtifact
export interface MermaidRenderOptions {
theme?: string
fontFamily?: string

View File

@@ -579,8 +579,6 @@ module.exports.AudioCaptureSession = nativeBinding.AudioCaptureSession
module.exports.ShareableContent = nativeBinding.ShareableContent
module.exports.decodeAudio = nativeBinding.decodeAudio
module.exports.decodeAudioSync = nativeBinding.decodeAudioSync
module.exports.startRecording = nativeBinding.startRecording
module.exports.stopRecording = nativeBinding.stopRecording
module.exports.mintChallengeResponse = nativeBinding.mintChallengeResponse
module.exports.renderMermaidSvg = nativeBinding.renderMermaidSvg
module.exports.renderTypstSvg = nativeBinding.renderTypstSvg

View File

@@ -12,15 +12,11 @@ harness = false
name = "mix_audio_samples"
[dependencies]
crossbeam-channel = { workspace = true }
napi = { workspace = true, features = ["napi4"] }
napi-derive = { workspace = true, features = ["type-def"] }
ogg = { workspace = true }
opus-codec = { git = "https://github.com/toeverything/opus-codec", rev = "c2afef2" }
rand = { workspace = true }
rubato = { workspace = true }
symphonia = { workspace = true, features = ["all", "opt-simd"] }
thiserror = { workspace = true }
napi = { workspace = true, features = ["napi4"] }
napi-derive = { workspace = true, features = ["type-def"] }
rubato = { workspace = true }
symphonia = { workspace = true, features = ["all", "opt-simd"] }
thiserror = { workspace = true }
[target.'cfg(target_os = "macos")'.dependencies]
block2 = { workspace = true }
@@ -34,9 +30,10 @@ screencapturekit = { workspace = true }
uuid = { workspace = true, features = ["v4"] }
[target.'cfg(target_os = "windows")'.dependencies]
cpal = { workspace = true }
windows = { workspace = true }
windows-core = { workspace = true }
cpal = { workspace = true }
crossbeam-channel = { workspace = true }
windows = { workspace = true }
windows-core = { workspace = true }
[dev-dependencies]
criterion2 = { workspace = true }

View File

@@ -1,31 +0,0 @@
use std::sync::Arc;

use crossbeam_channel::Sender;
use napi::{
  bindgen_prelude::Float32Array,
  threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
};

/// Internal callback abstraction so audio taps can target JS or native
/// pipelines.
#[derive(Clone)]
pub enum AudioCallback {
  Js(Arc<ThreadsafeFunction<Float32Array, ()>>),
  Channel(Sender<Vec<f32>>),
}

impl AudioCallback {
  /// Deliver one chunk of samples to the configured sink without ever
  /// blocking the realtime audio thread: JS delivery is fire-and-forget and
  /// channel delivery drops the chunk when the receiver is backed up.
  pub fn call(&self, samples: Vec<f32>) {
    match self {
      AudioCallback::Js(js_sink) => {
        // Errors from the threadsafe call are intentionally discarded.
        let _ = js_sink.call(Ok(samples.into()), ThreadsafeFunctionCallMode::NonBlocking);
      }
      AudioCallback::Channel(native_sink) => {
        // try_send never blocks; a full queue means the chunk is dropped.
        let _ = native_sink.try_send(samples);
      }
    }
  }
}

View File

@@ -8,6 +8,4 @@ pub mod windows;
#[cfg(target_os = "windows")]
pub use windows::*;
pub mod audio_callback;
pub mod audio_decoder;
pub mod recording;

View File

@@ -34,7 +34,6 @@ use screencapturekit::shareable_content::SCShareableContent;
use uuid::Uuid;
use crate::{
audio_callback::AudioCallback,
error::CoreAudioError,
pid::{audio_process_list, get_process_property},
tap_audio::{AggregateDeviceManager, AudioCaptureSession},
@@ -652,9 +651,10 @@ impl ShareableContent {
Ok(false)
}
pub(crate) fn tap_audio_with_callback(
#[napi]
pub fn tap_audio(
process_id: u32,
audio_stream_callback: AudioCallback,
audio_stream_callback: ThreadsafeFunction<napi::bindgen_prelude::Float32Array, ()>,
) -> Result<AudioCaptureSession> {
let app = ShareableContent::applications()?
.into_iter()
@@ -668,10 +668,13 @@ impl ShareableContent {
));
}
// Convert ThreadsafeFunction to Arc<ThreadsafeFunction>
let callback_arc = Arc::new(audio_stream_callback);
// Use AggregateDeviceManager instead of AggregateDevice directly
// This provides automatic default device change detection
let mut device_manager = AggregateDeviceManager::new(&app)?;
device_manager.start_capture(audio_stream_callback)?;
device_manager.start_capture(callback_arc)?;
let boxed_manager = Box::new(device_manager);
Ok(AudioCaptureSession::new(boxed_manager))
} else {
@@ -683,16 +686,9 @@ impl ShareableContent {
}
#[napi]
pub fn tap_audio(
process_id: u32,
audio_stream_callback: ThreadsafeFunction<napi::bindgen_prelude::Float32Array, ()>,
) -> Result<AudioCaptureSession> {
ShareableContent::tap_audio_with_callback(process_id, AudioCallback::Js(Arc::new(audio_stream_callback)))
}
pub(crate) fn tap_global_audio_with_callback(
pub fn tap_global_audio(
excluded_processes: Option<Vec<&ApplicationInfo>>,
audio_stream_callback: AudioCallback,
audio_stream_callback: ThreadsafeFunction<napi::bindgen_prelude::Float32Array, ()>,
) -> Result<AudioCaptureSession> {
let excluded_object_ids = excluded_processes
.unwrap_or_default()
@@ -700,21 +696,13 @@ impl ShareableContent {
.map(|app| app.object_id)
.collect::<Vec<_>>();
// Convert ThreadsafeFunction to Arc<ThreadsafeFunction>
let callback_arc = Arc::new(audio_stream_callback);
// Use the new AggregateDeviceManager for automatic device adaptation
let mut device_manager = AggregateDeviceManager::new_global(&excluded_object_ids)?;
device_manager.start_capture(audio_stream_callback)?;
device_manager.start_capture(callback_arc)?;
let boxed_manager = Box::new(device_manager);
Ok(AudioCaptureSession::new(boxed_manager))
}
#[napi]
pub fn tap_global_audio(
excluded_processes: Option<Vec<&ApplicationInfo>>,
audio_stream_callback: ThreadsafeFunction<napi::bindgen_prelude::Float32Array, ()>,
) -> Result<AudioCaptureSession> {
ShareableContent::tap_global_audio_with_callback(
excluded_processes,
AudioCallback::Js(Arc::new(audio_stream_callback)),
)
}
}

View File

@@ -20,13 +20,15 @@ use coreaudio::sys::{
kAudioObjectPropertyElementMain, kAudioObjectPropertyScopeGlobal, kAudioObjectSystemObject, kAudioSubDeviceUIDKey,
kAudioSubTapUIDKey,
};
use napi::bindgen_prelude::Result;
use napi::{
bindgen_prelude::{Float32Array, Result, Status},
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
};
use napi_derive::napi;
use objc2::runtime::AnyObject;
use crate::{
audio_buffer::InputAndOutputAudioBufferList,
audio_callback::AudioCallback,
ca_tap_description::CATapDescription,
cf_types::CFDictionaryBuilder,
device::get_device_uid,
@@ -218,7 +220,7 @@ impl AggregateDevice {
/// Implementation for the AggregateDevice to start processing audio
pub fn start(
&mut self,
audio_stream_callback: AudioCallback,
audio_stream_callback: Arc<ThreadsafeFunction<Float32Array, (), Float32Array, Status, true>>,
// Add original_audio_stats to ensure consistent target rate
original_audio_stats: AudioStats,
) -> Result<AudioTapStream> {
@@ -273,8 +275,8 @@ impl AggregateDevice {
return kAudioHardwareBadStreamError as i32;
};
// Send the processed audio data to the configured sink
audio_stream_callback.call(mixed_samples);
// Send the processed audio data to JavaScript
audio_stream_callback.call(Ok(mixed_samples.into()), ThreadsafeFunctionCallMode::NonBlocking);
kAudioHardwareNoError as i32
},
@@ -525,7 +527,7 @@ pub struct AggregateDeviceManager {
app_id: Option<AudioObjectID>,
excluded_processes: Vec<AudioObjectID>,
active_stream: Option<Arc<std::sync::Mutex<Option<AudioTapStream>>>>,
audio_callback: Option<AudioCallback>,
audio_callback: Option<Arc<ThreadsafeFunction<Float32Array, (), Float32Array, Status, true>>>,
original_audio_stats: Option<AudioStats>,
}
@@ -563,7 +565,10 @@ impl AggregateDeviceManager {
}
/// This sets up the initial stream and listeners.
pub fn start_capture(&mut self, audio_stream_callback: AudioCallback) -> Result<()> {
pub fn start_capture(
&mut self,
audio_stream_callback: Arc<ThreadsafeFunction<Float32Array, (), Float32Array, Status, true>>,
) -> Result<()> {
// Store the callback for potential device switch later
self.audio_callback = Some(audio_stream_callback.clone());

View File

@@ -1,942 +0,0 @@
use std::{
fs,
io::{BufWriter, Write},
path::PathBuf,
sync::{LazyLock, Mutex},
thread::{self, JoinHandle},
time::{SystemTime, UNIX_EPOCH},
};
use crossbeam_channel::{Receiver, Sender, bounded};
use napi::{Error, Status, bindgen_prelude::Result};
use napi_derive::napi;
use ogg::writing::{PacketWriteEndInfo, PacketWriter};
use opus_codec::{Application, Channels, Encoder, FrameSize, SampleRate as OpusSampleRate};
use rubato::Resampler;
#[cfg(any(target_os = "macos", target_os = "windows"))]
use crate::audio_callback::AudioCallback;
#[cfg(target_os = "macos")]
use crate::macos::screen_capture_kit::{ApplicationInfo, ShareableContent};
#[cfg(target_os = "windows")]
use crate::windows::screen_capture_kit::ShareableContent;
/// Opus encodes at a fixed 48 kHz; captured audio is resampled to this rate.
const ENCODE_SAMPLE_RATE: OpusSampleRate = OpusSampleRate::Hz48000;
/// Upper bound (bytes) for a single encoded Opus packet buffer.
const MAX_PACKET_SIZE: usize = 4096;
/// Number of frames handed to the rubato resampler per processing chunk.
const RESAMPLER_INPUT_CHUNK: usize = 1024;

/// Local result alias for recording operations.
type RecordingResult<T> = std::result::Result<T, RecordingError>;
/// Options accepted by `start_recording` (exposed to JS via napi).
#[napi(object)]
pub struct RecordingStartOptions {
  /// Capture a single app by process id; when unset, global capture is used.
  pub app_process_id: Option<u32>,
  /// Process ids to exclude from a global capture.
  pub exclude_process_ids: Option<Vec<u32>>,
  /// Directory the recording file is written into.
  pub output_dir: String,
  /// Output format name; validated elsewhere — TODO confirm accepted values.
  pub format: Option<String>,
  pub sample_rate: Option<u32>,
  pub channels: Option<u32>,
  /// Caller-supplied session id; presumably generated when omitted — confirm.
  pub id: Option<String>,
}

/// Metadata returned when a recording session is started.
#[napi(object)]
pub struct RecordingSessionMeta {
  pub id: String,
  /// Absolute path the recording is being written to.
  pub filepath: String,
  pub sample_rate: u32,
  pub channels: u32,
  /// Start timestamp — presumably unix epoch; confirm units at the call site.
  pub started_at: i64,
}

/// Final artifact info returned by `stop_recording`.
#[napi(object)]
pub struct RecordingArtifact {
  pub id: String,
  pub filepath: String,
  pub sample_rate: u32,
  pub channels: u32,
  pub duration_ms: i64,
  /// Encoded file size in bytes.
  pub size: i64,
}
/// Error taxonomy for the recording pipeline. Each variant maps to a stable
/// machine-readable code (see `code`) that prefixes the JS-visible message.
#[derive(Debug, thiserror::Error)]
enum RecordingError {
  #[error("unsupported platform")]
  UnsupportedPlatform,
  #[error("invalid output directory")]
  InvalidOutputDir,
  #[error("invalid channel count {0}")]
  InvalidChannels(u32),
  #[error("invalid format {0}")]
  InvalidFormat(String),
  #[error("io error: {0}")]
  Io(#[from] std::io::Error),
  #[error("encoding error: {0}")]
  Encoding(String),
  #[error("recording not found")]
  NotFound,
  #[error("empty recording")]
  Empty,
  #[error("start failure: {0}")]
  Start(String),
  #[error("join failure")]
  Join,
}

impl RecordingError {
  /// Stable code string used by JS callers to discriminate failures.
  fn code(&self) -> &'static str {
    match self {
      RecordingError::UnsupportedPlatform => "unsupported-platform",
      RecordingError::InvalidOutputDir => "invalid-output-dir",
      RecordingError::InvalidChannels(_) => "invalid-channels",
      RecordingError::InvalidFormat(_) => "invalid-format",
      RecordingError::Io(_) => "io-error",
      RecordingError::Encoding(_) => "encoding-error",
      RecordingError::NotFound => "not-found",
      RecordingError::Empty => "empty-recording",
      RecordingError::Start(_) => "start-failure",
      RecordingError::Join => "join-failure",
    }
  }
}

impl From<RecordingError> for Error {
  /// Surface as a generic napi failure formatted as "<code>: <message>".
  fn from(err: RecordingError) -> Self {
    Error::new(Status::GenericFailure, format!("{}: {}", err.code(), err))
  }
}
/// Streaming resampler over interleaved f32 audio, built on
/// `rubato::FastFixedIn`. Input is deinterleaved into per-channel FIFOs and
/// processed in fixed `RESAMPLER_INPUT_CHUNK`-frame chunks.
struct InterleavedResampler {
  resampler: rubato::FastFixedIn<f32>,
  channels: usize,
  // One queue of pending (deinterleaved) samples per channel.
  fifo: Vec<Vec<f32>>,
  // False until the first output block; that first block is discarded in
  // append_blocks — presumably to drop the resampler's warm-up transient.
  // TODO(review): confirm the intentional one-block latency.
  warmed: bool,
}

impl InterleavedResampler {
  /// Build a `from_sr` -> `to_sr` resampler for `channels` channels using
  /// linear interpolation with a fixed input chunk size.
  fn new(from_sr: u32, to_sr: u32, channels: usize) -> RecordingResult<Self> {
    let ratio = to_sr as f64 / from_sr as f64;
    let resampler = rubato::FastFixedIn::<f32>::new(
      ratio,
      1.0,
      rubato::PolynomialDegree::Linear,
      RESAMPLER_INPUT_CHUNK,
      channels,
    )
    .map_err(|e| RecordingError::Encoding(format!("resampler init failed: {e}")))?;
    Ok(Self {
      resampler,
      channels,
      fifo: vec![Vec::<f32>::new(); channels],
      warmed: false,
    })
  }

  /// Re-interleave `blocks` (one Vec per channel) into `out`. The very first
  /// call only flips `warmed` and emits nothing (see struct-level note).
  fn append_blocks(&mut self, blocks: Vec<Vec<f32>>, out: &mut Vec<f32>) {
    if blocks.is_empty() || blocks.len() != self.channels {
      return;
    }
    if !self.warmed {
      self.warmed = true;
      return;
    }
    let out_len = blocks[0].len();
    for i in 0..out_len {
      for channel in blocks.iter().take(self.channels) {
        out.push(channel[i]);
      }
    }
  }

  /// Feed interleaved input and return whatever interleaved resampled output
  /// became available (only whole RESAMPLER_INPUT_CHUNK chunks are processed;
  /// the remainder stays queued in the FIFOs).
  fn feed(&mut self, interleaved: &[f32]) -> RecordingResult<Vec<f32>> {
    // Deinterleave the input into the per-channel queues.
    for frame in interleaved.chunks(self.channels) {
      for (idx, sample) in frame.iter().enumerate() {
        if let Some(channel_fifo) = self.fifo.get_mut(idx) {
          channel_fifo.push(*sample);
        }
      }
    }
    let mut out = Vec::new();
    // Process as many full chunks as the queues currently hold.
    while self.fifo.first().map(|q| q.len()).unwrap_or(0) >= RESAMPLER_INPUT_CHUNK {
      let mut chunk: Vec<Vec<f32>> = Vec::with_capacity(self.channels);
      for channel in &mut self.fifo {
        let take: Vec<f32> = channel.drain(..RESAMPLER_INPUT_CHUNK).collect();
        chunk.push(take);
      }
      let blocks = self
        .resampler
        .process(&chunk, None)
        .map_err(|e| RecordingError::Encoding(format!("resampler process failed: {e}")))?;
      self.append_blocks(blocks, &mut out);
    }
    Ok(out)
  }

  /// Flush remaining queued samples (as a partial chunk) plus the
  /// resampler's internal delay line; call once at end of stream.
  fn finalize(&mut self) -> RecordingResult<Vec<f32>> {
    let mut out = Vec::new();
    let has_pending = self.fifo.first().map(|q| !q.is_empty()).unwrap_or(false);
    if has_pending {
      let mut chunk: Vec<Vec<f32>> = Vec::with_capacity(self.channels);
      for channel in &mut self.fifo {
        chunk.push(std::mem::take(channel));
      }
      let blocks = self
        .resampler
        .process_partial(Some(&chunk), None)
        .map_err(|e| RecordingError::Encoding(format!("resampler finalize failed: {e}")))?;
      self.append_blocks(blocks, &mut out);
    }
    // Drain samples still buffered inside the resampler itself.
    let delayed = self
      .resampler
      .process_partial::<Vec<f32>>(None, None)
      .map_err(|e| RecordingError::Encoding(format!("resampler drain failed: {e}")))?;
    self.append_blocks(delayed, &mut out);
    Ok(out)
  }
}
/// Maps a raw channel count onto the Opus channel layout.
///
/// Only mono (1) and stereo (2) are supported; anything else is rejected
/// with `RecordingError::InvalidChannels`.
fn normalize_channel_count(channels: u32) -> RecordingResult<Channels> {
  if channels == 1 {
    Ok(Channels::Mono)
  } else if channels == 2 {
    Ok(Channels::Stereo)
  } else {
    Err(RecordingError::InvalidChannels(channels))
  }
}
/// Converts an interleaved f32 sample buffer from `source_channels` to
/// `target_channels`.
///
/// Supported conversions: mono -> stereo (duplicate each sample), N -> mono
/// (average each frame), and N -> stereo (average to mono, then duplicate).
/// Identical channel counts return a plain copy. Any other combination is
/// rejected.
///
/// Errors when either channel count is zero, when the buffer length is not a
/// whole number of frames, or when the conversion is unsupported.
fn convert_interleaved_channels(
  samples: &[f32],
  source_channels: usize,
  target_channels: usize,
) -> RecordingResult<Vec<f32>> {
  if source_channels == 0 || target_channels == 0 {
    return Err(RecordingError::Encoding("channel count must be positive".into()));
  }
  if !samples.len().is_multiple_of(source_channels) {
    return Err(RecordingError::Encoding("invalid interleaved sample buffer".into()));
  }
  if source_channels == target_channels {
    return Ok(samples.to_vec());
  }
  let frame_count = samples.len() / source_channels;
  let mut converted = Vec::with_capacity(frame_count * target_channels);
  // Equal counts (1, 1) and (2, 2) were handled by the early return above,
  // so no arms are needed for them here.
  match (source_channels, target_channels) {
    (1, 2) => {
      // Mono -> stereo: duplicate each sample into both channels.
      for &sample in samples {
        converted.push(sample);
        converted.push(sample);
      }
    }
    (_, 1) => {
      // Downmix to mono: average all channels of each frame.
      for frame in samples.chunks(source_channels) {
        let sum: f32 = frame.iter().copied().sum();
        converted.push(sum / source_channels as f32);
      }
    }
    (_, 2) => {
      // Downmix to mono first, then duplicate into both stereo channels.
      for frame in samples.chunks(source_channels) {
        let mono = frame.iter().copied().sum::<f32>() / source_channels as f32;
        converted.push(mono);
        converted.push(mono);
      }
    }
    _ => {
      return Err(RecordingError::Encoding(format!(
        "unsupported channel conversion: {source_channels} -> {target_channels}"
      )));
    }
  }
  Ok(converted)
}
/// Streams f32 PCM into an Ogg/Opus file: channel conversion -> optional
/// resample to the encoder rate -> 20 ms Opus frames -> Ogg packets.
struct OggOpusWriter {
  writer: PacketWriter<'static, BufWriter<fs::File>>, // Ogg packet/page muxer
  encoder: Encoder,     // Opus encoder running at `sample_rate`
  frame_samples: usize, // samples per channel in one 20 ms frame
  pending: Vec<f32>,    // interleaved samples not yet forming a full frame
  pending_packet: Option<Vec<u8>>, // last encoded packet, held back so the
  // final packet can be written with the end-of-stream flag
  pending_packet_granule_position: u64, // granule pos for `pending_packet`
  granule_position: u64, // running granule position (starts at pre-skip)
  samples_written: u64,  // total encoded samples per channel
  source_channels: usize, // channel count of the incoming PCM
  channels: Channels,    // encoding layout (mono/stereo)
  sample_rate: OpusSampleRate,
  resampler: Option<InterleavedResampler>, // present when source rate differs
  filepath: PathBuf,
  stream_serial: u32, // Ogg logical stream serial number
}
impl OggOpusWriter {
  /// Creates the output file, writes the OpusHead/OpusTags headers, and sets
  /// up the encoder plus an optional resampler when the capture rate differs
  /// from the fixed encode rate.
  fn new(
    filepath: PathBuf,
    source_sample_rate: u32,
    source_channels: u32,
    encoding_channels: u32,
  ) -> RecordingResult<Self> {
    let source_channels =
      usize::try_from(source_channels).map_err(|_| RecordingError::InvalidChannels(source_channels))?;
    let channels = normalize_channel_count(encoding_channels)?;
    let sample_rate = ENCODE_SAMPLE_RATE;
    let mut encoder =
      Encoder::new(sample_rate, channels, Application::Audio).map_err(|e| RecordingError::Encoding(e.to_string()))?;
    // The encoder lookahead becomes the OpusHead pre-skip: decoders drop
    // this many leading samples.
    let pre_skip = u16::try_from(
      encoder
        .lookahead()
        .map_err(|e| RecordingError::Encoding(e.to_string()))?,
    )
    .map_err(|_| RecordingError::Encoding("invalid encoder lookahead".into()))?;
    let resampler = if source_sample_rate != sample_rate.as_i32() as u32 {
      Some(InterleavedResampler::new(
        source_sample_rate,
        sample_rate.as_i32() as u32,
        channels.as_usize(),
      )?)
    } else {
      None
    };
    if let Some(parent) = filepath.parent() {
      fs::create_dir_all(parent)?;
    }
    let file = fs::File::create(&filepath)?;
    let mut writer = PacketWriter::new(BufWriter::new(file));
    let stream_serial: u32 = rand::random();
    write_opus_headers(&mut writer, stream_serial, channels, sample_rate, pre_skip)?;
    let frame_samples = FrameSize::Ms20.samples(sample_rate);
    Ok(Self {
      writer,
      encoder,
      frame_samples,
      pending: Vec::new(),
      pending_packet: None,
      pending_packet_granule_position: 0,
      // Granule positions are offset by the pre-skip per RFC 7845.
      granule_position: u64::from(pre_skip),
      samples_written: 0,
      source_channels,
      channels,
      sample_rate,
      resampler,
      filepath,
      stream_serial,
    })
  }
  /// Accepts interleaved PCM in the source layout, converts channels,
  /// resamples if needed, and encodes every complete 20 ms frame.
  fn push_samples(&mut self, samples: &[f32]) -> RecordingResult<()> {
    let normalized = convert_interleaved_channels(samples, self.source_channels, self.channels.as_usize())?;
    let mut processed = if let Some(resampler) = &mut self.resampler {
      resampler.feed(&normalized)?
    } else {
      normalized
    };
    if processed.is_empty() {
      return Ok(());
    }
    self.pending.append(&mut processed);
    let frame_len = self.frame_samples * self.channels.as_usize();
    while self.pending.len() >= frame_len {
      let frame: Vec<f32> = self.pending.drain(..frame_len).collect();
      self.encode_frame(frame, self.frame_samples, PacketWriteEndInfo::NormalPacket)?;
    }
    Ok(())
  }
  /// Encodes one frame and writes packets one step behind: each new packet
  /// releases the previously held one, so the very last packet can be
  /// written with the end-of-stream flag.
  ///
  /// `samples_in_frame` is the count of *real* samples per channel; a padded
  /// final frame advances the granule position by less than a full frame,
  /// which is how Ogg/Opus signals end trimming.
  fn encode_frame(&mut self, frame: Vec<f32>, samples_in_frame: usize, end: PacketWriteEndInfo) -> RecordingResult<()> {
    let mut out = vec![0u8; MAX_PACKET_SIZE];
    let encoded = self
      .encoder
      .encode_float(&frame, &mut out)
      .map_err(|e| RecordingError::Encoding(e.to_string()))?;
    self.granule_position += samples_in_frame as u64;
    self.samples_written += samples_in_frame as u64;
    let packet = out[..encoded].to_vec();
    // Swap the new packet in; flush the previously buffered one (if any).
    if let Some(previous_packet) = self.pending_packet.replace(packet) {
      self
        .writer
        .write_packet(
          previous_packet,
          self.stream_serial,
          PacketWriteEndInfo::NormalPacket,
          self.pending_packet_granule_position,
        )
        .map_err(|e| RecordingError::Encoding(format!("failed to write packet: {e}")))?;
    }
    self.pending_packet_granule_position = self.granule_position;
    if end == PacketWriteEndInfo::EndStream {
      // The caller marked this frame as final: flush it immediately with the
      // end-of-stream flag instead of holding it back.
      let final_packet = self
        .pending_packet
        .take()
        .ok_or_else(|| RecordingError::Encoding("missing final packet".into()))?;
      self
        .writer
        .write_packet(
          final_packet,
          self.stream_serial,
          PacketWriteEndInfo::EndStream,
          self.pending_packet_granule_position,
        )
        .map_err(|e| RecordingError::Encoding(format!("failed to write packet: {e}")))?;
    }
    Ok(())
  }
  /// Flushes the resampler tail and any partial frame, closes the Ogg
  /// stream, and returns the finished artifact.
  ///
  /// A recording with no encoded samples deletes the file and reports
  /// `RecordingError::Empty`.
  fn finish(mut self) -> RecordingResult<RecordingArtifact> {
    if let Some(resampler) = &mut self.resampler {
      let mut flushed = resampler.finalize()?;
      self.pending.append(&mut flushed);
    }
    let frame_len = self.frame_samples * self.channels.as_usize();
    if !self.pending.is_empty() {
      // Zero-pad the trailing partial frame to a full 20 ms frame, but only
      // advance the granule by the real sample count (end trimming).
      let mut frame = self.pending.clone();
      let samples_in_frame = frame.len() / self.channels.as_usize();
      frame.resize(frame_len, 0.0);
      self.encode_frame(frame, samples_in_frame, PacketWriteEndInfo::EndStream)?;
      self.pending.clear();
    }
    if self.samples_written == 0 {
      fs::remove_file(&self.filepath).ok();
      return Err(RecordingError::Empty);
    }
    // If no padded final frame was emitted above, the last held-back packet
    // still needs to go out with the end-of-stream flag.
    if let Some(final_packet) = self.pending_packet.take() {
      self
        .writer
        .write_packet(
          final_packet,
          self.stream_serial,
          PacketWriteEndInfo::EndStream,
          self.pending_packet_granule_position,
        )
        .map_err(|e| RecordingError::Encoding(format!("failed to finish stream: {e}")))?;
    }
    self.writer.inner_mut().flush()?;
    let size = fs::metadata(&self.filepath)?.len() as i64;
    let duration_ms = (self.samples_written * 1000) as i64 / self.sample_rate.as_i32() as i64;
    Ok(RecordingArtifact {
      id: String::new(), // filled in by the worker with the session id
      filepath: self.filepath.to_string_lossy().to_string(),
      sample_rate: self.sample_rate.as_i32() as u32,
      channels: self.channels.as_usize() as u32,
      duration_ms,
      size,
    })
  }
}
/// Writes the two mandatory Ogg/Opus header packets (RFC 7845): the 19-byte
/// "OpusHead" identification header and an "OpusTags" comment header, each
/// ending its own Ogg page at granule position 0.
fn write_opus_headers(
  writer: &mut PacketWriter<'static, BufWriter<fs::File>>,
  stream_serial: u32,
  channels: Channels,
  sample_rate: OpusSampleRate,
  pre_skip: u16,
) -> RecordingResult<()> {
  let mut opus_head = Vec::with_capacity(19);
  opus_head.extend_from_slice(b"OpusHead");
  opus_head.push(1); // version
  opus_head.push(channels.as_usize() as u8);
  opus_head.extend_from_slice(&pre_skip.to_le_bytes());
  // Informational "input sample rate" field; here the encode rate is used.
  opus_head.extend_from_slice(&(sample_rate.as_i32() as u32).to_le_bytes());
  opus_head.extend_from_slice(&0i16.to_le_bytes()); // output gain
  opus_head.push(0); // channel mapping family 0 (mono/stereo)
  writer
    .write_packet(opus_head, stream_serial, PacketWriteEndInfo::EndPage, 0)
    .map_err(|e| RecordingError::Encoding(format!("failed to write OpusHead: {e}")))?;
  let vendor = b"AFFiNE Native";
  let mut opus_tags = Vec::new();
  opus_tags.extend_from_slice(b"OpusTags");
  opus_tags.extend_from_slice(&(vendor.len() as u32).to_le_bytes());
  opus_tags.extend_from_slice(vendor);
  opus_tags.extend_from_slice(&0u32.to_le_bytes()); // user comment list length
  writer
    .write_packet(opus_tags, stream_serial, PacketWriteEndInfo::EndPage, 0)
    .map_err(|e| RecordingError::Encoding(format!("failed to write OpusTags: {e}")))?;
  Ok(())
}
/// Platform-specific audio capture session; only the variant for the
/// compiled target OS exists.
enum PlatformCapture {
  #[cfg(target_os = "macos")]
  Mac(crate::macos::tap_audio::AudioCaptureSession),
  #[cfg(target_os = "windows")]
  Windows(crate::windows::audio_capture::AudioCaptureSession),
}
impl PlatformCapture {
  /// Stops the underlying OS capture session.
  ///
  /// The `unreachable_patterns` arm only matters on builds where neither
  /// platform variant is compiled in.
  fn stop(&mut self) -> Result<()> {
    match self {
      #[cfg(target_os = "macos")]
      PlatformCapture::Mac(session) => session.stop(),
      #[cfg(target_os = "windows")]
      PlatformCapture::Windows(session) => session.stop(),
      #[allow(unreachable_patterns)]
      _ => Err(RecordingError::UnsupportedPlatform.into()),
    }
  }
}
/// Messages sent to the recording controller thread.
enum ControlMessage {
  // Stop the capture; the final artifact (or error) is sent back on the
  // embedded reply channel.
  Stop(Sender<RecordingResult<RecordingArtifact>>),
}
/// Bookkeeping for the currently running recording session.
struct ActiveRecording {
  id: String,                      // sanitized session id
  control_tx: Sender<ControlMessage>, // channel to the controller thread
  controller: Option<JoinHandle<()>>, // taken exactly once when joined
}
/// The single in-flight recording, if any; only one may be active at a time.
static ACTIVE_RECORDING: LazyLock<Mutex<Option<ActiveRecording>>> = LazyLock::new(|| Mutex::new(None));
/// Serializes concurrent `start_recording` calls so the active-slot check
/// and controller spawn cannot race each other.
static START_RECORDING_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
/// Milliseconds since the Unix epoch, or 0 if the system clock reads as
/// earlier than the epoch.
fn now_millis() -> i64 {
  match SystemTime::now().duration_since(UNIX_EPOCH) {
    Ok(elapsed) => elapsed.as_millis() as i64,
    Err(_) => 0,
  }
}
/// Generates a fresh recording id: epoch-millis timestamp plus a random
/// 32-bit hex nonce.
fn new_recording_id() -> String {
  let timestamp = now_millis();
  let nonce: u32 = rand::random();
  format!("{timestamp}-{nonce:08x}")
}
/// Produces a filesystem-safe recording id.
///
/// Uses the provided id (or generates one), keeping only ASCII letters,
/// digits, `-`, and `_`. Falls back to a generated id when nothing survives
/// the filter.
fn sanitize_id(id: Option<String>) -> String {
  let candidate = id.unwrap_or_else(new_recording_id);
  let cleaned: String = candidate
    .chars()
    .filter(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_'))
    .collect();
  if cleaned.is_empty() {
    new_recording_id()
  } else {
    cleaned
  }
}
/// Validates and prepares the output directory.
///
/// The path must be absolute; it is created if missing and returned in
/// canonical form. Any canonicalization failure maps to `InvalidOutputDir`.
fn validate_output_dir(path: &str) -> Result<PathBuf> {
  let dir = PathBuf::from(path);
  if !dir.is_absolute() {
    return Err(RecordingError::InvalidOutputDir.into());
  }
  fs::create_dir_all(&dir)?;
  dir
    .canonicalize()
    .map_err(|_| RecordingError::InvalidOutputDir.into())
}
#[cfg(target_os = "macos")]
/// Resolves process ids to `ApplicationInfo` entries for exclusion from the
/// global audio tap. An empty id list short-circuits to an empty result
/// without enumerating applications.
fn build_excluded_refs(ids: &[u32]) -> Result<Vec<ApplicationInfo>> {
  if ids.is_empty() {
    return Ok(Vec::new());
  }
  let excluded: Vec<ApplicationInfo> = ShareableContent::applications()?
    .into_iter()
    .filter(|app| ids.contains(&(app.process_id as u32)))
    .collect();
  Ok(excluded)
}
/// Starts the platform capture session, forwarding audio chunks into `tx`.
///
/// Returns the session together with its native sample rate and channel
/// count. On unsupported platforms this always fails.
fn start_capture(opts: &RecordingStartOptions, tx: Sender<Vec<f32>>) -> Result<(PlatformCapture, u32, u32)> {
  #[cfg(target_os = "macos")]
  {
    let callback = AudioCallback::Channel(tx);
    // Tap a single app's audio when a process id is given; otherwise tap
    // global audio minus any excluded processes.
    let session = if let Some(app_id) = opts.app_process_id {
      ShareableContent::tap_audio_with_callback(app_id, callback)?
    } else {
      let excluded_apps = build_excluded_refs(opts.exclude_process_ids.as_deref().unwrap_or(&[]))?;
      let excluded_refs: Vec<&ApplicationInfo> = excluded_apps.iter().collect();
      ShareableContent::tap_global_audio_with_callback(Some(excluded_refs), callback)?
    };
    // Clamp guards against non-positive reported rates before the u32 cast.
    let sample_rate = session.get_sample_rate()?.round().clamp(1.0, f64::MAX) as u32;
    let channels = session.get_channels()?;
    Ok((PlatformCapture::Mac(session), sample_rate, channels))
  }
  #[cfg(target_os = "windows")]
  {
    let callback = AudioCallback::Channel(tx);
    // Windows captures global audio; the process id is passed through but
    // per-app tapping is not supported there.
    let session =
      ShareableContent::tap_audio_with_callback(opts.app_process_id.unwrap_or(0), callback, opts.sample_rate)?;
    let sample_rate = session.get_sample_rate().round() as u32;
    let channels = session.get_channels();
    return Ok((PlatformCapture::Windows(session), sample_rate, channels));
  }
  #[cfg(not(any(target_os = "macos", target_os = "windows")))]
  {
    let _ = opts;
    let _ = tx;
    Err(RecordingError::UnsupportedPlatform.into())
  }
}
/// Spawns the encoder thread: drains PCM chunks from `rx` until the channel
/// closes (all senders dropped), then finalizes the Ogg/Opus file and stamps
/// the artifact with the session id.
fn spawn_worker(
  id: String,
  filepath: PathBuf,
  rx: Receiver<Vec<f32>>,
  source_sample_rate: u32,
  source_channels: u32,
  encoding_channels: u32,
) -> JoinHandle<std::result::Result<RecordingArtifact, RecordingError>> {
  thread::spawn(move || {
    let mut writer = OggOpusWriter::new(filepath.clone(), source_sample_rate, source_channels, encoding_channels)?;
    // Ends when every Sender clone has been dropped (capture stopped).
    for chunk in rx {
      writer.push_samples(&chunk)?;
    }
    let mut artifact = writer.finish()?;
    artifact.id = id;
    Ok(artifact)
  })
}
/// Spawns the controller thread that owns the capture session, the encoder
/// worker, and the audio channel between them.
///
/// Returns a one-shot receiver that reports either the negotiated encoding
/// channel count or a start error, plus the control channel and the
/// controller's join handle.
fn spawn_recording_controller(
  id: String,
  filepath: PathBuf,
  opts: RecordingStartOptions,
) -> (Receiver<RecordingResult<u32>>, Sender<ControlMessage>, JoinHandle<()>) {
  let (started_tx, started_rx) = bounded(1);
  let (control_tx, control_rx) = bounded(1);
  let controller = thread::spawn(move || {
    // Bounded audio queue between the capture callback and the encoder.
    let (tx, rx) = bounded::<Vec<f32>>(32);
    let (mut capture, capture_rate, capture_channels) = match start_capture(&opts, tx.clone()) {
      Ok(capture) => capture,
      Err(error) => {
        let _ = started_tx.send(Err(RecordingError::Start(error.to_string())));
        return;
      }
    };
    // Use the caller-requested channel count if valid; otherwise derive it
    // from the capture (stereo when the source has more than one channel).
    let encoding_channels = match opts.channels {
      Some(channels) => match normalize_channel_count(channels) {
        Ok(_) => channels,
        Err(error) => {
          let _ = started_tx.send(Err(error));
          return;
        }
      },
      None => {
        if capture_channels == 0 {
          let _ = started_tx.send(Err(RecordingError::InvalidChannels(capture_channels)));
          return;
        }
        if capture_channels > 1 { 2 } else { 1 }
      }
    };
    // Keep one Sender alive here; dropping it later closes the worker's
    // receive loop and lets it finalize the file.
    let mut audio_tx = Some(tx);
    let mut worker = Some(spawn_worker(
      id,
      filepath,
      rx,
      capture_rate,
      capture_channels,
      encoding_channels,
    ));
    if started_tx.send(Ok(encoding_channels)).is_err() {
      // The caller went away before the handshake completed: tear everything
      // down and discard the worker result.
      let _ = capture.stop();
      drop(audio_tx.take());
      if let Some(handle) = worker.take() {
        let _ = handle.join();
      }
      return;
    }
    while let Ok(message) = control_rx.recv() {
      match message {
        ControlMessage::Stop(reply_tx) => {
          let result = match capture.stop() {
            Ok(()) => {
              // Closing the audio channel ends the worker loop; join it to
              // collect the finished artifact.
              drop(audio_tx.take());
              match worker.take() {
                Some(handle) => match handle.join() {
                  Ok(result) => result,
                  Err(_) => Err(RecordingError::Join),
                },
                None => Err(RecordingError::Join),
              }
            }
            Err(error) => Err(RecordingError::Start(error.to_string())),
          };
          let _ = reply_tx.send(result);
          // Exit only once the worker has been consumed; a failed
          // capture.stop() keeps the session alive for another Stop attempt.
          if worker.is_none() {
            break;
          }
        }
      }
    }
    // Control channel closed without a successful stop: best-effort cleanup.
    if let Some(handle) = worker.take() {
      let _ = capture.stop();
      drop(audio_tx.take());
      let _ = handle.join();
    }
  });
  (started_rx, control_tx, controller)
}
/// Best-effort teardown of a controller that could not be registered:
/// request a stop, wait for (and discard) the reply, then join the thread.
/// Every step ignores failure so cleanup itself can never panic.
fn cleanup_recording_controller(control_tx: &Sender<ControlMessage>, controller: JoinHandle<()>) {
  let (ack_tx, ack_rx) = bounded(1);
  let _ = control_tx.send(ControlMessage::Stop(ack_tx));
  let _ = ack_rx.recv();
  let _ = controller.join();
}
/// Removes and returns the active recording if its id matches `id`.
///
/// On an id mismatch the recording is put back into the slot and `NotFound`
/// is returned; an empty slot also yields `NotFound`.
fn take_active_recording(id: &str) -> RecordingResult<ActiveRecording> {
  let mut slot = ACTIVE_RECORDING
    .lock()
    .map_err(|_| RecordingError::Start("lock poisoned".into()))?;
  match slot.take() {
    Some(recording) if recording.id == id => Ok(recording),
    Some(recording) => {
      // Not ours: restore the slot untouched.
      *slot = Some(recording);
      Err(RecordingError::NotFound)
    }
    None => Err(RecordingError::NotFound),
  }
}
/// Joins the controller thread of a taken recording, mapping a panicked
/// thread to `RecordingError::Join`. A missing handle is a no-op.
fn join_active_recording(mut recording: ActiveRecording) -> RecordingResult<()> {
  match recording.controller.take() {
    Some(handle) => handle.join().map_err(|_| RecordingError::Join),
    None => Ok(()),
  }
}
/// Starts a new recording session (napi entry point).
///
/// Validates the options, serializes concurrent starts, spawns the
/// controller, and registers the session in the global active-recording
/// slot. Fails if a recording is already active.
#[napi]
pub fn start_recording(opts: RecordingStartOptions) -> Result<RecordingSessionMeta> {
  // Only the "opus" container format is supported.
  if let Some(fmt) = opts.format.as_deref()
    && !fmt.eq_ignore_ascii_case("opus")
  {
    return Err(RecordingError::InvalidFormat(fmt.to_string()).into());
  }
  if let Some(channels) = opts.channels {
    normalize_channel_count(channels)?;
  }
  // Held for the remainder of the function so two starts cannot interleave
  // between the active-slot check and registration.
  let _start_lock = START_RECORDING_LOCK
    .lock()
    .map_err(|_| RecordingError::Start("lock poisoned".into()))?;
  let output_dir = validate_output_dir(&opts.output_dir)?;
  let id = sanitize_id(opts.id.clone());
  {
    let recording = ACTIVE_RECORDING
      .lock()
      .map_err(|_| RecordingError::Start("lock poisoned".into()))?;
    if recording.is_some() {
      return Err(RecordingError::Start("recording already active".into()).into());
    }
  }
  let filepath = output_dir.join(format!("{id}.opus"));
  if filepath.exists() {
    fs::remove_file(&filepath)?;
  }
  let (started_rx, control_tx, controller) = spawn_recording_controller(id.clone(), filepath.clone(), opts);
  // Block until the controller reports the negotiated channel count or a
  // startup error.
  let encoding_channels = started_rx
    .recv()
    .map_err(|_| RecordingError::Start("failed to start recording controller".into()))??;
  let meta = RecordingSessionMeta {
    id: id.clone(),
    filepath: filepath.to_string_lossy().to_string(),
    sample_rate: ENCODE_SAMPLE_RATE.as_i32() as u32,
    channels: encoding_channels,
    started_at: now_millis(),
  };
  let mut recording = match ACTIVE_RECORDING.lock() {
    Ok(recording) => recording,
    Err(_) => {
      // Cannot register the session: stop the controller before bailing.
      cleanup_recording_controller(&control_tx, controller);
      return Err(RecordingError::Start("lock poisoned".into()).into());
    }
  };
  // Re-check: the slot may have been filled while this thread waited.
  if recording.is_some() {
    cleanup_recording_controller(&control_tx, controller);
    return Err(RecordingError::Start("recording already active".into()).into());
  }
  *recording = Some(ActiveRecording {
    id,
    control_tx,
    controller: Some(controller),
  });
  Ok(meta)
}
/// Stops the recording with the given `id` and returns the finished
/// artifact (napi entry point).
#[napi]
pub fn stop_recording(id: String) -> Result<RecordingArtifact> {
  // Clone the control sender so the global lock is not held across the
  // blocking stop round-trip below.
  let control_tx = {
    let recording = ACTIVE_RECORDING
      .lock()
      .map_err(|_| RecordingError::Start("lock poisoned".into()))?;
    let active = recording.as_ref().ok_or(RecordingError::NotFound)?;
    if active.id != id {
      return Err(RecordingError::NotFound.into());
    }
    active.control_tx.clone()
  };
  let (reply_tx, reply_rx) = bounded(1);
  if control_tx.send(ControlMessage::Stop(reply_tx)).is_err() {
    // Controller already gone: clear the slot and join it if still ours.
    if let Ok(recording) = take_active_recording(&id) {
      let _ = join_active_recording(recording);
    }
    return Err(RecordingError::Join.into());
  }
  let response = match reply_rx.recv() {
    Ok(response) => response,
    Err(_) => {
      if let Ok(recording) = take_active_recording(&id) {
        let _ = join_active_recording(recording);
      }
      return Err(RecordingError::Join.into());
    }
  };
  let artifact = match response {
    Ok(artifact) => artifact,
    Err(RecordingError::Start(message)) => {
      // capture.stop() failed: the controller keeps the session alive so a
      // later stop can retry, so the active slot is deliberately NOT cleared.
      return Err(RecordingError::Start(message).into());
    }
    Err(error) => {
      // Terminal failure: tear the session down before propagating.
      if let Ok(recording) = take_active_recording(&id) {
        let _ = join_active_recording(recording);
      }
      return Err(error.into());
    }
  };
  let active_recording = take_active_recording(&id)?;
  join_active_recording(active_recording)?;
  Ok(artifact)
}
#[cfg(test)]
mod tests {
  use std::{env, fs::File, path::PathBuf};
  use ogg::PacketReader;
  use super::{OggOpusWriter, convert_interleaved_channels};
  // Unique temp path per test so parallel test runs do not collide.
  fn temp_recording_path() -> PathBuf {
    env::temp_dir().join(format!("affine-recording-test-{}.opus", rand::random::<u64>()))
  }
  #[test]
  fn finish_marks_last_audio_packet_as_end_of_stream() {
    let path = temp_recording_path();
    // Exactly one 20 ms stereo frame at 48 kHz (960 samples per channel),
    // so no resampling and no padding are involved.
    let samples = vec![0.0f32; 960 * 2];
    let artifact = {
      let mut writer = OggOpusWriter::new(path.clone(), 48_000, 2, 2).expect("create writer");
      writer.push_samples(&samples).expect("push samples");
      writer.finish().expect("finish writer")
    };
    assert_eq!(artifact.filepath, path.to_string_lossy());
    assert!(artifact.size > 0);
    assert_eq!(artifact.sample_rate, 48_000);
    assert_eq!(artifact.channels, 2);
    let mut reader = PacketReader::new(File::open(&path).expect("open opus file"));
    let mut packets = Vec::new();
    while let Some(packet) = reader.read_packet().expect("read packet") {
      packets.push(packet);
    }
    // Expect exactly OpusHead, OpusTags, and one audio packet flagged as
    // the end of the logical stream.
    assert_eq!(packets.len(), 3);
    assert_eq!(&packets[0].data[..8], b"OpusHead");
    assert_eq!(&packets[1].data[..8], b"OpusTags");
    assert!(!packets[2].data.is_empty());
    assert!(packets[2].last_in_stream());
    std::fs::remove_file(path).ok();
  }
  #[test]
  fn finish_flushes_short_resampled_recordings() {
    let path = temp_recording_path();
    // 44.1 kHz input shorter than one resampler chunk: everything must be
    // flushed through the resampler's finalize path.
    let samples = vec![0.25f32; 512 * 2];
    let artifact = {
      let mut writer = OggOpusWriter::new(path.clone(), 44_100, 2, 2).expect("create writer");
      writer.push_samples(&samples).expect("push samples");
      writer.finish().expect("finish writer")
    };
    assert!(artifact.size > 0);
    assert!(artifact.duration_ms > 0);
    let mut reader = PacketReader::new(File::open(&path).expect("open opus file"));
    let mut packets = Vec::new();
    while let Some(packet) = reader.read_packet().expect("read packet") {
      packets.push(packet);
    }
    assert_eq!(packets.len(), 3);
    assert!(packets[2].last_in_stream());
    std::fs::remove_file(path).ok();
  }
  #[test]
  fn converts_interleaved_channels_before_encoding() {
    // Mono -> stereo duplicates; N -> mono averages; N -> stereo averages
    // then duplicates.
    assert_eq!(
      convert_interleaved_channels(&[1.0, 2.0], 1, 2).expect("mono to stereo"),
      vec![1.0, 1.0, 2.0, 2.0]
    );
    assert_eq!(
      convert_interleaved_channels(&[1.0, 3.0, 5.0, 7.0], 2, 1).expect("stereo to mono"),
      vec![2.0, 6.0]
    );
    assert_eq!(
      convert_interleaved_channels(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2).expect("surround to stereo"),
      vec![3.0, 3.0, 4.0, 4.0]
    );
  }
}

View File

@@ -13,12 +13,14 @@ use cpal::{
traits::{DeviceTrait, HostTrait, StreamTrait},
};
use crossbeam_channel::unbounded;
use napi::{Error, Status, bindgen_prelude::Result};
use napi::{
Error, Status,
bindgen_prelude::{Float32Array, Result},
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
};
use napi_derive::napi;
use rubato::{FastFixedIn, PolynomialDegree, Resampler};
use crate::audio_callback::AudioCallback;
const RESAMPLER_INPUT_CHUNK: usize = 1024; // samples per channel
const TARGET_FRAME_SIZE: usize = 1024; // frame size returned to JS (in mono samples)
@@ -214,10 +216,7 @@ impl Drop for AudioCaptureSession {
}
}
pub fn start_recording(
audio_buffer_callback: AudioCallback,
target_sample_rate: Option<SampleRate>,
) -> Result<AudioCaptureSession> {
pub fn start_recording(audio_buffer_callback: ThreadsafeFunction<Float32Array, ()>) -> Result<AudioCaptureSession> {
let available_hosts = cpal::available_hosts();
let host_id = available_hosts
.first()
@@ -241,7 +240,7 @@ pub fn start_recording(
let mic_sample_rate = mic_config.sample_rate();
let lb_sample_rate = lb_config.sample_rate();
let target_rate = target_sample_rate.unwrap_or(SampleRate(mic_sample_rate.min(lb_sample_rate).0));
let target_rate = SampleRate(mic_sample_rate.min(lb_sample_rate).0);
let mic_channels = mic_config.channels();
let lb_channels = lb_config.channels();
@@ -333,7 +332,7 @@ pub fn start_recording(
let lb_chunk: Vec<f32> = post_lb.drain(..TARGET_FRAME_SIZE).collect();
let mixed = mix(&mic_chunk, &lb_chunk);
if !mixed.is_empty() {
audio_buffer_callback.call(mixed);
let _ = audio_buffer_callback.call(Ok(mixed.clone().into()), ThreadsafeFunctionCallMode::NonBlocking);
}
}

View File

@@ -10,7 +10,6 @@ use std::{
time::Duration,
};
use cpal::SampleRate;
use napi::{
bindgen_prelude::{Buffer, Error, Result, Status},
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
@@ -28,7 +27,6 @@ use windows::Win32::System::{
};
// Import the function from microphone_listener
use crate::audio_callback::AudioCallback;
use crate::windows::microphone_listener::is_process_actively_using_microphone;
// Type alias to match macOS API
@@ -216,15 +214,6 @@ impl ShareableContent {
}
}
pub(crate) fn tap_audio_with_callback(
_process_id: u32,
audio_stream_callback: AudioCallback,
target_sample_rate: Option<u32>,
) -> Result<AudioCaptureSession> {
let target = target_sample_rate.map(SampleRate);
crate::windows::audio_capture::start_recording(audio_stream_callback, target)
}
#[napi]
pub fn tap_audio(
_process_id: u32, // Currently unused - Windows captures global audio
@@ -232,18 +221,7 @@ impl ShareableContent {
) -> Result<AudioCaptureSession> {
// On Windows with CPAL, we capture global audio (mic + loopback)
// since per-application audio tapping isn't supported the same way as macOS
ShareableContent::tap_audio_with_callback(_process_id, AudioCallback::Js(Arc::new(audio_stream_callback)), None)
}
pub(crate) fn tap_global_audio_with_callback(
_excluded_processes: Option<Vec<&ApplicationInfo>>,
audio_stream_callback: AudioCallback,
target_sample_rate: Option<u32>,
) -> Result<AudioCaptureSession> {
let target = target_sample_rate.map(SampleRate);
// Delegate to audio_capture::start_recording which handles mixing mic +
// loopback
crate::windows::audio_capture::start_recording(audio_stream_callback, target)
crate::windows::audio_capture::start_recording(audio_stream_callback)
}
#[napi]
@@ -251,11 +229,9 @@ impl ShareableContent {
_excluded_processes: Option<Vec<&ApplicationInfo>>,
audio_stream_callback: ThreadsafeFunction<napi::bindgen_prelude::Float32Array, ()>,
) -> Result<AudioCaptureSession> {
ShareableContent::tap_global_audio_with_callback(
_excluded_processes,
AudioCallback::Js(Arc::new(audio_stream_callback)),
None,
)
// Delegate to audio_capture::start_recording which handles mixing mic +
// loopback
crate::windows::audio_capture::start_recording(audio_stream_callback)
}
#[napi]

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "1.94.0"
channel = "1.93.1"
profile = "default"

View File

@@ -14,7 +14,6 @@ import {
import { createLocalWorkspace } from '@affine-test/kit/utils/workspace';
import { expect } from '@playwright/test';
import fs from 'fs-extra';
import type { ElectronApplication } from 'playwright';
declare global {
interface Window {
@@ -22,32 +21,6 @@ declare global {
}
}
async function mockNextSaveDialog(
electronApp: ElectronApplication,
filePath: string
) {
await electronApp.evaluate(({ dialog }, mockedFilePath) => {
const original = dialog.showSaveDialog.bind(dialog);
dialog.showSaveDialog = async () => {
dialog.showSaveDialog = original;
return { canceled: false, filePath: mockedFilePath };
};
}, filePath);
}
async function mockNextOpenDialog(
electronApp: ElectronApplication,
filePath: string
) {
await electronApp.evaluate(({ dialog }, mockedFilePath) => {
const original = dialog.showOpenDialog.bind(dialog);
dialog.showOpenDialog = async () => {
dialog.showOpenDialog = original;
return { canceled: false, filePaths: [mockedFilePath] };
};
}, filePath);
}
test('check workspace has a DB file', async ({ appInfo, workspace }) => {
const w = await workspace.current();
const dbPath = path.join(
@@ -61,7 +34,7 @@ test('check workspace has a DB file', async ({ appInfo, workspace }) => {
expect(await fs.exists(dbPath)).toBe(true);
});
test('export then add', async ({ electronApp, page, appInfo, workspace }) => {
test('export then add', async ({ page, appInfo, workspace }) => {
await clickNewPageButton(page);
const w = await workspace.current();
@@ -87,7 +60,11 @@ test('export then add', async ({ electronApp, page, appInfo, workspace }) => {
const tmpPath = path.join(appInfo.sessionData, w.meta.id + '-tmp.db');
// export db file to tmp folder
await mockNextSaveDialog(electronApp, tmpPath);
await page.evaluate(tmpPath => {
return window.__apis?.dialog.setFakeDialogResult({
filePath: tmpPath,
});
}, tmpPath);
await page.getByTestId('workspace-setting:storage').click();
await page.getByTestId('export-affine-backup').click();
@@ -101,7 +78,11 @@ test('export then add', async ({ electronApp, page, appInfo, workspace }) => {
// in the codebase
await clickSideBarCurrentWorkspaceBanner(page);
await mockNextOpenDialog(electronApp, tmpPath);
await page.evaluate(tmpPath => {
return window.__apis?.dialog.setFakeDialogResult({
filePath: tmpPath,
});
}, tmpPath);
// load the db file
await page.getByTestId('add-workspace').click();