Mirror of https://github.com/toeverything/AFFiNE.git (synced 2026-02-07 18:13:43 +00:00)

Compare commits: 45 commits
| SHA1 |
| --- |
| 22187f964a |
| cf7b026832 |
| e6818b4f14 |
| aab9925aa1 |
| 86218d87c2 |
| de4084495b |
| 13a2562282 |
| 556956ced2 |
| bf6c9a5955 |
| 9ef8829ef1 |
| de91027852 |
| 7235779b02 |
| ba356f4412 |
| 602d932065 |
| 8dfa601771 |
| 481a2269f8 |
| 555f203be6 |
| 5c1f78afd4 |
| d6ad7d566f |
| b79d13bcc8 |
| a0ce75c902 |
| e8285289fe |
| cc7740d8d3 |
| 61870c04d0 |
| 10df1fb4b7 |
| 0bc09a9333 |
| f0d127fa29 |
| fc729d6a32 |
| ef7ba273ab |
| b8b30e79e5 |
| 2a6ea3c9c6 |
| c62d79ab14 |
| 27d0fc5108 |
| 40e381e272 |
| 15e99c7819 |
| 3870801ebb |
| 0957c30e74 |
| 90e4a9b181 |
| 1997f24414 |
| 3f8fe5cfae |
| 8c4a42f0e6 |
| 4d484ea814 |
| 3bbb657a78 |
| 39acb51d87 |
| d72dbe682c |
.github/helm/affine/templates/ingress.yaml (vendored, 7 changed lines)
@@ -74,4 +74,11 @@ spec:
             name: affine-web
             port:
               number: {{ .Values.web.service.port }}
+      - path: /js/worker.(.+).js
+        pathType: ImplementationSpecific
+        backend:
+          service:
+            name: affine-web
+            port:
+              number: {{ .Values.web.service.port }}
 {{- end }}
.github/workflows/workers.yml (vendored, 2 changed lines)
@@ -15,7 +15,7 @@ jobs:
     steps:
       - uses: actions/checkout@v4
       - name: Publish
-        uses: cloudflare/wrangler-action@v3.6.1
+        uses: cloudflare/wrangler-action@v3.7.0
        with:
          apiToken: ${{ secrets.CF_API_TOKEN }}
          accountId: ${{ secrets.CF_ACCOUNT_ID }}
.yarn/patches/yjs-npm-13.6.18-ad0d5f7c43.patch (new file, 39 lines)
@@ -0,0 +1,39 @@
diff --git a/dist/yjs.cjs b/dist/yjs.cjs
index d2dc06ae11a6eb44f8c8445d4298c0e89c3e4da2..a30ab04fa9f3b77666939caa88335c68c40f194c 100644
--- a/dist/yjs.cjs
+++ b/dist/yjs.cjs
@@ -414,7 +414,7 @@ const equalDeleteSets = (ds1, ds2) => {
 */

-const generateNewClientId = random__namespace.uint32;
+const generateNewClientId = random__namespace.uint53;

/**
 * @typedef {Object} DocOpts
diff --git a/dist/yjs.mjs b/dist/yjs.mjs
index 20c9e58c32bcb6bc714200a2561fd1f542c49523..14267e5e36d9781ca3810d5b70ff8c051dac779e 100644
--- a/dist/yjs.mjs
+++ b/dist/yjs.mjs
@@ -378,7 +378,7 @@ const equalDeleteSets = (ds1, ds2) => {
 */

-const generateNewClientId = random.uint32;
+const generateNewClientId = random.uint53;

/**
 * @typedef {Object} DocOpts
diff --git a/src/utils/Doc.js b/src/utils/Doc.js
index 62643617c86e57c64dd9babdb792fa8888357ec0..4df5048ab12af1ae0f1154da67f06dce1fda7b49 100644
--- a/src/utils/Doc.js
+++ b/src/utils/Doc.js
@@ -20,7 +20,7 @@ import * as map from 'lib0/map'
 import * as array from 'lib0/array'
 import * as promise from 'lib0/promise'

-export const generateNewClientId = random.uint32
+export const generateNewClientId = random.uint53
+
/**
 * @typedef {Object} DocOpts
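Why this patch matters: Yjs identifies each collaborating client by a random client ID, so two clients drawing the same ID is harmful. Widening generateNewClientId from 32 to 53 random bits shrinks the birthday-collision odds by orders of magnitude. A rough, self-contained sketch of the effect (illustrative TypeScript, not part of the patch; assumes IDs are drawn uniformly at random):

```ts
// Approximate birthday-collision probability for n clients choosing
// uniformly from 2^bits possible client IDs:
// p ≈ 1 - exp(-n * (n - 1) / 2^(bits + 1)).
function collisionProbability(n: number, bits: number): number {
  return 1 - Math.exp(-(n * (n - 1)) / 2 ** (bits + 1));
}

// With 10,000 concurrent clients:
console.log(collisionProbability(10_000, 32)); // ~0.0116, noticeable at uint32
console.log(collisionProbability(10_000, 53)); // ~5.6e-9, negligible at uint53
```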
Cargo.lock (generated, 21 changed lines)
@@ -993,14 +993,15 @@ dependencies = [

 [[package]]
 name = "napi"
-version = "3.0.0-alpha.2"
+version = "3.0.0-alpha.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "99d38fbf4cbfd7d2785d153f4dcce374d515d3dabd688504dd9093f8135829d0"
+checksum = "9e1c3a7423adc069939192859f1c5b1e6b576d662a183a70839f5b098dd807ca"
 dependencies = [
  "anyhow",
  "bitflags 2.5.0",
  "chrono",
+ "ctor",
  "napi-build",
  "napi-sys",
  "once_cell",
  "serde",
@@ -1015,9 +1016,9 @@ checksum = "e1c0f5d67ee408a4685b61f5ab7e58605c8ae3f2b4189f0127d804ff13d5560a"

 [[package]]
 name = "napi-derive"
-version = "3.0.0-alpha.1"
+version = "3.0.0-alpha.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c230c813bfd4d6c7aafead3c075b37f0cf7fecb38be8f4cf5cfcee0b2c273ad0"
+checksum = "9f728c2fc73c9be638b4fc65de1f15309246a1c2d355cb1508fc26a4a265873f"
 dependencies = [
  "cfg-if",
  "convert_case",
@@ -1029,9 +1030,9 @@ dependencies = [

 [[package]]
 name = "napi-derive-backend"
-version = "2.0.0-alpha.1"
+version = "2.0.0-alpha.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4370cc24c2e58d0f3393527b282eb00f1158b304248f549e1ec81bd2927db5fe"
+checksum = "665de86dea7d1bf1ea6628cb8544edb5008f73e15b5bf5c69e54211c19988b3b"
 dependencies = [
  "convert_case",
  "once_cell",
@@ -1555,9 +1556,9 @@ dependencies = [

 [[package]]
 name = "serde_json"
-version = "1.0.117"
+version = "1.0.120"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3"
+checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5"
 dependencies = [
  "itoa",
  "ryu",
@@ -2178,9 +2179,9 @@ checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"

 [[package]]
 name = "uuid"
-version = "1.9.0"
+version = "1.9.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3ea73390fe27785838dcbf75b91b1d84799e28f1ce71e6f372a5dc2200c80de5"
+checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439"
 dependencies = [
  "getrandom",
  "rand",
SECURITY.md
@@ -6,8 +6,8 @@ We recommend users to always use the latest major version. Security updates will

 | Version | Supported |
 | --------------- | ------------------ |
-| 0.14.x (stable) | :white_check_mark: |
-| < 0.14.x        | :x:                |
+| 0.15.x (stable) | :white_check_mark: |
+| < 0.15.x        | :x:                |

 ## Reporting a Vulnerability
@@ -9,7 +9,7 @@
   "devDependencies": {
     "nodemon": "^3.1.0",
     "serve": "^14.2.1",
-    "typedoc": "^0.25.13"
+    "typedoc": "^0.26.0"
   },
   "nodemonConfig": {
     "watch": [
@@ -59,8 +59,8 @@
     "@faker-js/faker": "^8.4.1",
     "@istanbuljs/schema": "^0.1.3",
     "@magic-works/i18n-codegen": "^0.6.0",
-    "@nx/vite": "19.2.3",
-    "@playwright/test": "^1.44.0",
+    "@nx/vite": "19.4.1",
+    "@playwright/test": "=1.44.1",
     "@taplo/cli": "^0.7.0",
     "@testing-library/react": "^16.0.0",
     "@toeverything/infra": "workspace:*",
@@ -75,7 +75,7 @@
     "@vitest/coverage-istanbul": "1.6.0",
     "@vitest/ui": "1.6.0",
     "cross-env": "^7.0.3",
-    "electron": "^30.1.1",
+    "electron": "~30.1.0",
     "eslint": "^8.57.0",
     "eslint-config-prettier": "^9.1.0",
     "eslint-plugin-import-x": "^0.5.0",
@@ -95,7 +95,7 @@
     "nanoid": "^5.0.7",
     "nx": "^19.0.0",
     "nyc": "^17.0.0",
-    "oxlint": "0.5.0",
+    "oxlint": "0.5.2",
     "prettier": "^3.2.5",
     "semver": "^7.6.0",
     "serve": "^14.2.1",
packages/backend/native/index.d.ts (vendored, 12 changed lines)
@@ -1,20 +1,20 @@
 /* auto-generated by NAPI-RS */
 /* eslint-disable */
-export class Tokenizer {
+export declare class Tokenizer {
   count(content: string, allowedSpecial?: Array<string> | undefined | null): number
 }

-export function fromModelName(modelName: string): Tokenizer | null
+export declare function fromModelName(modelName: string): Tokenizer | null

-export function getMime(input: Uint8Array): string
+export declare function getMime(input: Uint8Array): string

 /**
  * Merge updates in form like `Y.applyUpdate(doc, update)` way and return the
  * result binary.
  */
-export function mergeUpdatesInApplyWay(updates: Array<Buffer>): Buffer
+export declare function mergeUpdatesInApplyWay(updates: Array<Buffer>): Buffer

-export function mintChallengeResponse(resource: string, bits?: number | undefined | null): Promise<string>
+export declare function mintChallengeResponse(resource: string, bits?: number | undefined | null): Promise<string>

-export function verifyChallengeResponse(response: string, bits: number, resource: string): Promise<boolean>
+export declare function verifyChallengeResponse(response: string, bits: number, resource: string): Promise<boolean>
@@ -33,12 +33,12 @@
     "build:debug": "napi build"
   },
   "devDependencies": {
-    "@napi-rs/cli": "3.0.0-alpha.55",
+    "@napi-rs/cli": "3.0.0-alpha.56",
     "lib0": "^0.2.93",
     "nx": "^19.0.0",
     "nx-cloud": "^19.0.0",
     "tiktoken": "^1.0.15",
     "tinybench": "^2.8.0",
-    "yjs": "^13.6.14"
+    "yjs": "patch:yjs@npm%3A13.6.18#~/.yarn/patches/yjs-npm-13.6.18-ad0d5f7c43.patch"
   }
 }
@@ -0,0 +1,3 @@
-- AlterTable
ALTER TABLE "user_subscriptions" ALTER COLUMN "stripe_subscription_id" DROP NOT NULL,
ALTER COLUMN "end" DROP NOT NULL;
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "ai_sessions_metadata" ADD COLUMN "parent_session_id" VARCHAR(36);
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "ai_prompts_metadata" ADD COLUMN "config" JSON;
@@ -21,7 +21,7 @@
   "dependencies": {
     "@apollo/server": "^4.10.2",
     "@aws-sdk/client-s3": "^3.552.0",
-    "@fal-ai/serverless-client": "^0.10.2",
+    "@fal-ai/serverless-client": "^0.12.0",
     "@google-cloud/opentelemetry-cloud-monitoring-exporter": "^0.18.0",
     "@google-cloud/opentelemetry-cloud-trace-exporter": "^2.2.0",
     "@google-cloud/opentelemetry-resource-util": "^2.2.0",
@@ -35,7 +35,7 @@
     "@nestjs/platform-socket.io": "^10.3.7",
     "@nestjs/schedule": "^4.0.1",
     "@nestjs/serve-static": "^4.0.2",
-    "@nestjs/throttler": "5.1.2",
+    "@nestjs/throttler": "5.2.0",
     "@nestjs/websockets": "^10.3.7",
     "@node-rs/argon2": "^1.8.0",
     "@node-rs/crc32": "^1.10.0",
@@ -46,11 +46,11 @@
     "@opentelemetry/exporter-zipkin": "^1.25.0",
     "@opentelemetry/host-metrics": "^0.35.2",
     "@opentelemetry/instrumentation": "^0.52.0",
-    "@opentelemetry/instrumentation-graphql": "^0.41.0",
+    "@opentelemetry/instrumentation-graphql": "^0.42.0",
     "@opentelemetry/instrumentation-http": "^0.52.0",
-    "@opentelemetry/instrumentation-ioredis": "^0.41.0",
-    "@opentelemetry/instrumentation-nestjs-core": "^0.38.0",
-    "@opentelemetry/instrumentation-socket.io": "^0.40.0",
+    "@opentelemetry/instrumentation-ioredis": "^0.42.0",
+    "@opentelemetry/instrumentation-nestjs-core": "^0.39.0",
+    "@opentelemetry/instrumentation-socket.io": "^0.41.0",
     "@opentelemetry/resources": "^1.25.0",
     "@opentelemetry/sdk-metrics": "^1.25.0",
     "@opentelemetry/sdk-node": "^0.52.0",
@@ -95,7 +95,7 @@
     "ts-node": "^10.9.2",
     "typescript": "^5.4.5",
     "ws": "^8.16.0",
-    "yjs": "^13.6.14",
+    "yjs": "patch:yjs@npm%3A13.6.18#~/.yarn/patches/yjs-npm-13.6.18-ad0d5f7c43.patch",
     "zod": "^3.22.4"
   },
   "devDependencies": {
@@ -377,14 +377,14 @@ model UserSubscription {
   plan      String @db.VarChar(20)
   // yearly/monthly
   recurring String @db.VarChar(20)
-  // subscription.id
-  stripeSubscriptionId String  @unique @map("stripe_subscription_id")
+  // subscription.id, null for lifetime payment
+  stripeSubscriptionId String? @unique @map("stripe_subscription_id")
   // subscription.status, active/past_due/canceled/unpaid...
   status String @db.VarChar(20)
   // subscription.current_period_start
   start DateTime @map("start") @db.Timestamptz(6)
-  // subscription.current_period_end
-  end DateTime  @map("end") @db.Timestamptz(6)
+  // subscription.current_period_end, null for lifetime payment
+  end DateTime? @map("end") @db.Timestamptz(6)
   // subscription.billing_cycle_anchor
   nextBillAt DateTime? @map("next_bill_at") @db.Timestamptz(6)
   // subscription.canceled_at
@@ -457,6 +457,7 @@ model AiPrompt {
   // it is only used in the frontend and does not affect the backend
   action    String?  @db.VarChar
   model     String   @db.VarChar
+  config    Json?    @db.Json
   createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)

   messages AiPromptMessage[]
@@ -481,15 +482,17 @@ model AiSessionMessage {
 }

 model AiSession {
-  id          String    @id @default(uuid()) @db.VarChar(36)
-  userId      String    @map("user_id") @db.VarChar(36)
-  workspaceId String    @map("workspace_id") @db.VarChar(36)
-  docId       String    @map("doc_id") @db.VarChar(36)
-  promptName  String    @map("prompt_name") @db.VarChar(32)
-  messageCost Int       @default(0)
-  tokenCost   Int       @default(0)
-  createdAt   DateTime  @default(now()) @map("created_at") @db.Timestamptz(6)
-  deletedAt   DateTime? @map("deleted_at") @db.Timestamptz(6)
+  id              String    @id @default(uuid()) @db.VarChar(36)
+  userId          String    @map("user_id") @db.VarChar(36)
+  workspaceId     String    @map("workspace_id") @db.VarChar(36)
+  docId           String    @map("doc_id") @db.VarChar(36)
+  promptName      String    @map("prompt_name") @db.VarChar(32)
+  // the session id of the parent session if this session is a forked session
+  parentSessionId String?   @map("parent_session_id") @db.VarChar(36)
+  messageCost     Int       @default(0)
+  tokenCost       Int       @default(0)
+  createdAt       DateTime  @default(now()) @map("created_at") @db.Timestamptz(6)
+  deletedAt       DateTime? @map("deleted_at") @db.Timestamptz(6)

   user   User     @relation(fields: [userId], references: [id], onDelete: Cascade)
   prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade)
@@ -155,6 +155,25 @@ export const Quotas: Quota[] = [
       copilotActionLimit: 10,
     },
   },
+  {
+    feature: QuotaType.LifetimeProPlanV1,
+    type: FeatureKind.Quota,
+    version: 1,
+    configs: {
+      // quota name
+      name: 'Lifetime Pro',
+      // single blob limit 100MB
+      blobLimit: 100 * OneMB,
+      // total blob limit 1TB
+      storageQuota: 1024 * OneGB,
+      // history period of validity 30 days
+      historyPeriod: 30 * OneDay,
+      // member limit 10
+      memberLimit: 10,
+      // copilot action limit 10
+      copilotActionLimit: 10,
+    },
+  },
 ];

 export function getLatestQuota(type: QuotaType) {
@@ -165,6 +184,7 @@ export function getLatestQuota(type: QuotaType) {

 export const FreePlan = getLatestQuota(QuotaType.FreePlanV1);
 export const ProPlan = getLatestQuota(QuotaType.ProPlanV1);
+export const LifetimeProPlan = getLatestQuota(QuotaType.LifetimeProPlanV1);

 export const Quota_FreePlanV1_1 = {
   feature: Quotas[5].feature,
@@ -3,7 +3,6 @@ import { PrismaClient } from '@prisma/client';

 import type { EventPayload } from '../../fundamentals';
 import { OnEvent, PrismaTransaction } from '../../fundamentals';
-import { SubscriptionPlan } from '../../plugins/payment/types';
 import { FeatureManagementService } from '../features/management';
 import { FeatureKind } from '../features/types';
 import { QuotaConfig } from './quota';
@@ -152,15 +151,18 @@ export class QuotaService {
   async onSubscriptionUpdated({
     userId,
     plan,
+    recurring,
   }: EventPayload<'user.subscription.activated'>) {
     switch (plan) {
-      case SubscriptionPlan.AI:
+      case 'ai':
         await this.feature.addCopilot(userId, 'subscription activated');
         break;
-      case SubscriptionPlan.Pro:
+      case 'pro':
         await this.switchUserQuota(
           userId,
-          QuotaType.ProPlanV1,
+          recurring === 'lifetime'
+            ? QuotaType.LifetimeProPlanV1
+            : QuotaType.ProPlanV1,
           'subscription activated'
         );
         break;
@@ -175,16 +177,22 @@ export class QuotaService {
     plan,
   }: EventPayload<'user.subscription.canceled'>) {
     switch (plan) {
-      case SubscriptionPlan.AI:
+      case 'ai':
         await this.feature.removeCopilot(userId);
         break;
-      case SubscriptionPlan.Pro:
-        await this.switchUserQuota(
-          userId,
-          QuotaType.FreePlanV1,
-          'subscription canceled'
-        );
+      case 'pro': {
+        // edge case: when user switch from recurring Pro plan to `Lifetime` plan,
+        // a subscription canceled event will be triggered because `Lifetime` plan is not subscription based
+        const quota = await this.getUserQuota(userId);
+        if (quota.feature.name !== QuotaType.LifetimeProPlanV1) {
+          await this.switchUserQuota(
+            userId,
+            QuotaType.FreePlanV1,
+            'subscription canceled'
+          );
+        }
         break;
+      }
       default:
         break;
     }
@@ -17,6 +17,7 @@ import { ByteUnit, OneDay, OneKB } from './constant';
 export enum QuotaType {
   FreePlanV1 = 'free_plan_v1',
   ProPlanV1 = 'pro_plan_v1',
+  LifetimeProPlanV1 = 'lifetime_pro_plan_v1',
   // only for test, smaller quota
   RestrictedPlanV1 = 'restricted_plan_v1',
 }
@@ -25,6 +26,7 @@ const quotaPlan = z.object({
   feature: z.enum([
     QuotaType.FreePlanV1,
     QuotaType.ProPlanV1,
+    QuotaType.LifetimeProPlanV1,
     QuotaType.RestrictedPlanV1,
   ]),
   configs: z.object({
@@ -0,0 +1,14 @@
import { PrismaClient } from '@prisma/client';

import { QuotaType } from '../../core/quota';
import { upsertLatestQuotaVersion } from './utils/user-quotas';

export class LifetimeProQuota1719917815802 {
  // do the migration
  static async up(db: PrismaClient) {
    await upsertLatestQuotaVersion(db, QuotaType.LifetimeProPlanV1);
  }

  // revert the migration
  static async down(_db: PrismaClient) {}
}
@@ -0,0 +1,13 @@
import { PrismaClient } from '@prisma/client';

import { refreshPrompts } from './utils/prompts';

export class UpdatePrompts1720413813993 {
  // do the migration
  static async up(db: PrismaClient) {
    await refreshPrompts(db);
  }

  // revert the migration
  static async down(_db: PrismaClient) {}
}
@@ -6,10 +6,20 @@ type PromptMessage = {
   params?: Record<string, string | string[]>;
 };

+type PromptConfig = {
+  jsonMode?: boolean;
+  frequencyPenalty?: number;
+  presencePenalty?: number;
+  temperature?: number;
+  topP?: number;
+  maxTokens?: number;
+};
+
 type Prompt = {
   name: string;
   action?: string;
   model: string;
+  config?: PromptConfig;
   messages: PromptMessage[];
 };
@@ -465,6 +475,7 @@ content: {{content}}`,
     name: 'workflow:presentation:step1',
     action: 'workflow:presentation:step1',
     model: 'gpt-4o',
+    config: { temperature: 0.7 },
     messages: [
       {
         role: 'system',
@@ -481,10 +492,16 @@ content: {{content}}`,
     name: 'workflow:presentation:step2',
     action: 'workflow:presentation:step2',
     model: 'gpt-4o',
+    config: {
+      frequencyPenalty: 0.5,
+      presencePenalty: 0.5,
+      temperature: 0.2,
+      topP: 0.75,
+    },
     messages: [
       {
         role: 'system',
-        content: `You are a PPT creator. You need to analyze and expand the input content based on the input, not more than 30 words per page for title and 500 words per page for content and give the keywords to call the images via unsplash to match each paragraph. Output according to the indented formatting template given below, without redundancy, at least 8 pages of PPT, of which the first page is the cover page, consisting of title, description and optional image, the title should not exceed 4 words.\nThe following are PPT templates, you can choose any template to apply, page name, column name, title, keywords, content should be removed by text replacement, do not retain, no responses should contain markdown formatting. Keywords need to be generic enough for broad, mass categorization. The output ignores template titles like template1 and template2. The first template is allowed to be used only once and as a cover, please strictly follow the template's ND-JSON field, format and my requirements, or penalties will be applied:\n{"page":1,"type":"name","content":"page name"}\n{"page":1,"type":"title","content":"title"}\n{"page":1,"type":"content","content":"keywords"}\n{"page":1,"type":"content","content":"description"}\n{"page":2,"type":"name","content":"page name"}\n{"page":2,"type":"title","content":"section name"}\n{"page":2,"type":"content","content":"keywords"}\n{"page":2,"type":"content","content":"description"}\n{"page":2,"type":"title","content":"section name"}\n{"page":2,"type":"content","content":"keywords"}\n{"page":2,"type":"content","content":"description"}\n{"page":3,"type":"name","content":"page name"}\n{"page":3,"type":"title","content":"section name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}\n{"page":3,"type":"title","content":"section name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}\n{"page":3,"type":"title","content":"section name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}`,
+        content: `You are the creator of the mind map. You need to analyze and expand on the input and output it according to the indentation formatting template given below without redundancy.\nBelow is an example of indentation for a mind map, the title and content needs to be removed by text replacement and not retained. Please strictly adhere to the hierarchical indentation of the template and my requirements, bold, headings and other formatting (e.g. #, **) are not allowed, a maximum of five levels of indentation is allowed, and the last node of each node should make a judgment on whether to make a detailed statement or not based on the topic:\nexmaple:\n- {topic}\n  - {Level 1}\n    - {Level 2}\n      - {Level 3}\n        - {Level 4}\n  - {Level 1}\n    - {Level 2}\n      - {Level 3}\n  - {Level 1}\n    - {Level 2}\n      - {Level 3}`,
       },
       {
         role: 'assistant',
@@ -496,26 +513,6 @@ content: {{content}}`,
       },
     ],
   },
-  {
-    name: 'workflow:presentation:step4',
-    action: 'workflow:presentation:step4',
-    model: 'gpt-4o',
-    messages: [
-      {
-        role: 'system',
-        content:
-          "You are a ND-JSON text format checking model with very strict formatting requirements, and you need to optimize the input so that it fully conforms to the template's indentation format and output.\nPage names, section names, titles, keywords, and content should be removed via text replacement and not retained. The first template is only allowed to be used once and as a cover, please strictly adhere to the template's hierarchical indentation and my requirement that bold, headings, and other formatting (e.g., #, **, ```) are not allowed or penalties will be applied, no responses should contain markdown formatting.",
-      },
-      {
-        role: 'assistant',
-        content: `You are a PPT creator. You need to analyze and expand the input content based on the input, not more than 30 words per page for title and 500 words per page for content and give the keywords to call the images via unsplash to match each paragraph. Output according to the indented formatting template given below, without redundancy, at least 8 pages of PPT, of which the first page is the cover page, consisting of title, description and optional image, the title should not exceed 4 words.\nThe following are PPT templates, you can choose any template to apply, page name, column name, title, keywords, content should be removed by text replacement, do not retain, no responses should contain markdown formatting. Keywords need to be generic enough for broad, mass categorization. The output ignores template titles like template1 and template2. The first template is allowed to be used only once and as a cover, please strictly follow the template's ND-JSON field, format and my requirements, or penalties will be applied:\n{"page":1,"type":"name","content":"page name"}\n{"page":1,"type":"title","content":"title"}\n{"page":1,"type":"content","content":"keywords"}\n{"page":1,"type":"content","content":"description"}\n{"page":2,"type":"name","content":"page name"}\n{"page":2,"type":"title","content":"section name"}\n{"page":2,"type":"content","content":"keywords"}\n{"page":2,"type":"content","content":"description"}\n{"page":2,"type":"title","content":"section name"}\n{"page":2,"type":"content","content":"keywords"}\n{"page":2,"type":"content","content":"description"}\n{"page":3,"type":"name","content":"page name"}\n{"page":3,"type":"title","content":"section name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}\n{"page":3,"type":"title","content":"section name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}\n{"page":3,"type":"title","content":"section name"}\n{"page":3,"type":"content","content":"keywords"}\n{"page":3,"type":"content","content":"description"}`,
-      },
-      {
-        role: 'user',
-        content: '{{content}}',
-      },
-    ],
-  },
   {
     name: 'Create headings',
     action: 'Create headings',
@@ -685,6 +682,7 @@ export async function refreshPrompts(db: PrismaClient) {
       create: {
         name: prompt.name,
         action: prompt.action,
+        config: prompt.config,
        model: prompt.model,
        messages: {
          create: prompt.messages.map((message, idx) => ({
@@ -408,6 +408,10 @@ export const USER_FRIENDLY_ERRORS = {
     args: { plan: 'string', recurring: 'string' },
     message: 'You are trying to access a unknown subscription plan.',
   },
+  cant_update_lifetime_subscription: {
+    type: 'action_forbidden',
+    message: 'You cannot update a lifetime subscription.',
+  },

   // Copilot errors
   copilot_session_not_found: {
@@ -440,7 +444,8 @@ export const USER_FRIENDLY_ERRORS = {
   },
   copilot_message_not_found: {
     type: 'resource_not_found',
-    message: `Copilot message not found.`,
+    args: { messageId: 'string' },
+    message: ({ messageId }) => `Copilot message ${messageId} not found.`,
   },
   copilot_prompt_not_found: {
     type: 'resource_not_found',
@@ -350,6 +350,12 @@ export class SubscriptionPlanNotFound extends UserFriendlyError {
   }
 }

+export class CantUpdateLifetimeSubscription extends UserFriendlyError {
+  constructor(message?: string) {
+    super('action_forbidden', 'cant_update_lifetime_subscription', message);
+  }
+}
+
 export class CopilotSessionNotFound extends UserFriendlyError {
   constructor(message?: string) {
     super('resource_not_found', 'copilot_session_not_found', message);
@@ -391,10 +397,14 @@ export class CopilotActionTaken extends UserFriendlyError {
     super('action_forbidden', 'copilot_action_taken', message);
   }
 }
+
+@ObjectType()
+class CopilotMessageNotFoundDataType {
+  @Field() messageId!: string;
+}

 export class CopilotMessageNotFound extends UserFriendlyError {
-  constructor(message?: string) {
-    super('resource_not_found', 'copilot_message_not_found', message);
+  constructor(args: CopilotMessageNotFoundDataType, message?: string | ((args: CopilotMessageNotFoundDataType) => string)) {
+    super('resource_not_found', 'copilot_message_not_found', message, args);
   }
 }
@@ -517,6 +527,7 @@ export enum ErrorNames {
   SAME_SUBSCRIPTION_RECURRING,
   CUSTOMER_PORTAL_CREATE_FAILED,
   SUBSCRIPTION_PLAN_NOT_FOUND,
+  CANT_UPDATE_LIFETIME_SUBSCRIPTION,
   COPILOT_SESSION_NOT_FOUND,
   COPILOT_SESSION_DELETED,
   NO_COPILOT_PROVIDER_AVAILABLE,
@@ -542,5 +553,5 @@ registerEnumType(ErrorNames, {
 export const ErrorDataUnionType = createUnionType({
   name: 'ErrorDataUnion',
   types: () =>
-    [UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidPasswordLengthDataType, WorkspaceNotFoundDataType, NotInWorkspaceDataType, WorkspaceAccessDeniedDataType, WorkspaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const,
+    [UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidPasswordLengthDataType, WorkspaceNotFoundDataType, NotInWorkspaceDataType, WorkspaceAccessDeniedDataType, WorkspaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const,
 });
@@ -138,9 +138,8 @@ export class CopilotController {
     const messageId = Array.isArray(params.messageId)
       ? params.messageId[0]
       : params.messageId;
-    const jsonMode = String(params.jsonMode).toLowerCase() === 'true';
     delete params.messageId;
-    return { messageId, jsonMode, params };
+    return { messageId, params };
   }

   private getSignal(req: Request) {
@@ -167,7 +166,7 @@ export class CopilotController {
     @Param('sessionId') sessionId: string,
     @Query() params: Record<string, string | string[]>
   ): Promise<string> {
-    const { messageId, jsonMode } = this.prepareParams(params);
+    const { messageId } = this.prepareParams(params);
     const provider = await this.chooseTextProvider(
       user.id,
       sessionId,
@@ -180,7 +179,11 @@ export class CopilotController {
     const content = await provider.generateText(
       session.finish(params),
       session.model,
-      { jsonMode, signal: this.getSignal(req), user: user.id }
+      {
+        ...session.config.promptConfig,
+        signal: this.getSignal(req),
+        user: user.id,
+      }
     );

     session.push({
@@ -204,7 +207,7 @@ export class CopilotController {
     @Query() params: Record<string, string>
   ): Promise<Observable<ChatEvent>> {
     try {
-      const { messageId, jsonMode } = this.prepareParams(params);
+      const { messageId } = this.prepareParams(params);
       const provider = await this.chooseTextProvider(
         user.id,
         sessionId,
@@ -215,7 +218,7 @@ export class CopilotController {

       return from(
         provider.generateTextStream(session.finish(params), session.model, {
-          jsonMode,
+          ...session.config.promptConfig,
           signal: this.getSignal(req),
           user: user.id,
         })
@@ -256,7 +259,7 @@ export class CopilotController {
     @Query() params: Record<string, string>
   ): Promise<Observable<ChatEvent>> {
     try {
-      const { messageId, jsonMode } = this.prepareParams(params);
+      const { messageId } = this.prepareParams(params);
       const session = await this.appendSessionMessage(sessionId, messageId);
       const latestMessage = session.stashMessages.findLast(
         m => m.role === 'user'
@@ -269,7 +272,7 @@ export class CopilotController {

       return from(
         this.workflow.runGraph(params, session.model, {
-          jsonMode,
+          ...session.config.promptConfig,
           signal: this.getSignal(req),
           user: user.id,
         })
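Note on the controller changes above: generation options are now seeded from the prompt's stored config rather than a per-request jsonMode query flag, and request-scoped fields are spread last so the stored config can never override them. A standalone sketch of that merge order (illustrative TypeScript; names are simplified from the diff):

```ts
// Illustrative only: shows why the spread order in the controller matters.
type PromptConfig = { jsonMode?: boolean; temperature?: number };

function buildOptions(
  promptConfig: PromptConfig | undefined,
  signal: AbortSignal,
  user: string
) {
  // Stored prompt config first, request-scoped fields last: a prompt
  // config can tune generation, but cannot clobber the abort signal
  // or the attributed user.
  return { ...promptConfig, signal, user };
}

const ctl = new AbortController();
console.log(buildOptions({ jsonMode: true, temperature: 0.2 }, ctl.signal, 'user-1'));
// => { jsonMode: true, temperature: 0.2, signal: AbortSignal, user: 'user-1' }
```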
@@ -5,6 +5,8 @@ import Mustache from 'mustache';

 import {
   getTokenEncoder,
+  PromptConfig,
+  PromptConfigSchema,
   PromptMessage,
   PromptMessageSchema,
   PromptParams,
@@ -35,14 +37,16 @@ export class ChatPrompt {
   private readonly templateParams: PromptParams = {};

   static createFromPrompt(
-    options: Omit<AiPrompt, 'id' | 'createdAt'> & {
+    options: Omit<AiPrompt, 'id' | 'createdAt' | 'config'> & {
       messages: PromptMessage[];
+      config: PromptConfig | undefined;
     }
   ) {
     return new ChatPrompt(
       options.name,
       options.action || undefined,
       options.model,
+      options.config,
       options.messages
     );
   }
@@ -51,6 +55,7 @@ export class ChatPrompt {
     public readonly name: string,
     public readonly action: string | undefined,
     public readonly model: string,
+    public readonly config: PromptConfig | undefined,
     private readonly messages: PromptMessage[]
   ) {
     this.encoder = getTokenEncoder(model);
@@ -154,6 +159,7 @@ export class PromptService {
       name: true,
       action: true,
       model: true,
+      config: true,
       messages: {
         select: {
           role: true,
@@ -185,6 +191,7 @@ export class PromptService {
       name: true,
       action: true,
      model: true,
+     config: true,
      messages: {
        select: {
          role: true,
@@ -199,9 +206,11 @@ export class PromptService {
     });

     const messages = PromptMessageSchema.array().safeParse(prompt?.messages);
-    if (prompt && messages.success) {
+    const config = PromptConfigSchema.safeParse(prompt?.config);
+    if (prompt && messages.success && config.success) {
       const chatPrompt = ChatPrompt.createFromPrompt({
         ...prompt,
+        config: config.data,
         messages: messages.data,
       });
       this.cache.set(name, chatPrompt);
@@ -210,12 +219,18 @@ export class PromptService {
     return null;
   }

-  async set(name: string, model: string, messages: PromptMessage[]) {
+  async set(
+    name: string,
+    model: string,
+    messages: PromptMessage[],
+    config?: PromptConfig | null
+  ) {
     return await this.db.aiPrompt
       .create({
         data: {
           name,
           model,
+          config: config || undefined,
           messages: {
             create: messages.map((m, idx) => ({
               idx,
@@ -229,10 +244,11 @@ export class PromptService {
       .then(ret => ret.id);
   }

-  async update(name: string, messages: PromptMessage[]) {
+  async update(name: string, messages: PromptMessage[], config?: PromptConfig) {
     const { id } = await this.db.aiPrompt.update({
       where: { name },
       data: {
+        config: config || undefined,
         messages: {
           // cleanup old messages
           deleteMany: {},
@@ -125,21 +125,6 @@ export class OpenAIProvider
     });
   }

-  private extractOptionFromMessages(
-    messages: PromptMessage[],
-    options: CopilotChatOptions
-  ) {
-    const params: Record<string, string | string[]> = {};
-    for (const message of messages) {
-      if (message.params) {
-        Object.assign(params, message.params);
-      }
-    }
-    if (params.jsonMode && options) {
-      options.jsonMode = String(params.jsonMode).toLowerCase() === 'true';
-    }
-  }
-
   protected checkParams({
     messages,
     embeddings,
@@ -155,7 +140,6 @@ export class OpenAIProvider
       throw new CopilotPromptInvalid(`Invalid model: ${model}`);
     }
     if (Array.isArray(messages) && messages.length > 0) {
-      this.extractOptionFromMessages(messages, options);
       if (
         messages.some(
           m =>
@@ -257,7 +241,9 @@ export class OpenAIProvider
         stream: true,
         messages: this.chatToGPTMessage(messages),
         model: model,
-        temperature: options.temperature || 0,
+        frequency_penalty: options.frequencyPenalty || 0,
+        presence_penalty: options.presencePenalty || 0,
+        temperature: options.temperature || 0.5,
         max_tokens: options.maxTokens || 4096,
         response_format: {
           type: options.jsonMode ? 'json_object' : 'text',
@@ -27,7 +27,7 @@ import {
   FileUpload,
   MutexService,
   Throttle,
-  TooManyRequestsException,
+  TooManyRequest,
 } from '../../fundamentals';
 import { PromptService } from './prompt';
 import { ChatSessionService } from './session';
@@ -60,6 +60,24 @@ class CreateChatSessionInput {
   promptName!: string;
 }

+@InputType()
+class ForkChatSessionInput {
+  @Field(() => String)
+  workspaceId!: string;
+
+  @Field(() => String)
+  docId!: string;
+
+  @Field(() => String)
+  sessionId!: string;
+
+  @Field(() => String, {
+    description:
+      'Identify a message in the array and keep it with all previous messages into a forked session.',
+  })
+  latestMessageId!: string;
+}
+
 @InputType()
 class DeleteSessionInput {
   @Field(() => String)
@@ -109,6 +127,10 @@ class QueryChatHistoriesInput implements Partial<ListHistoriesOptions> {

 @ObjectType('ChatMessage')
 class ChatMessageType implements Partial<ChatMessage> {
+  // id will be null if message is a prompt message
+  @Field(() => ID, { nullable: true })
+  id!: string;
+
   @Field(() => String)
   role!: 'system' | 'assistant' | 'user';
@@ -161,6 +183,25 @@ registerEnumType(AiPromptRole, {
   name: 'CopilotPromptMessageRole',
 });

+@InputType('CopilotPromptConfigInput')
+@ObjectType()
+class CopilotPromptConfigType {
+  @Field(() => Boolean, { nullable: true })
+  jsonMode!: boolean | null;
+
+  @Field(() => Number, { nullable: true })
+  frequencyPenalty!: number | null;
+
+  @Field(() => Number, { nullable: true })
+  presencePenalty!: number | null;
+
+  @Field(() => Number, { nullable: true })
+  temperature!: number | null;
+
+  @Field(() => Number, { nullable: true })
+  topP!: number | null;
+}
+
 @InputType('CopilotPromptMessageInput')
 @ObjectType()
 class CopilotPromptMessageType {
@@ -187,6 +228,9 @@ class CopilotPromptType {
   @Field(() => String, { nullable: true })
   action!: string | null;

+  @Field(() => CopilotPromptConfigType, { nullable: true })
+  config!: CopilotPromptConfigType | null;
+
   @Field(() => [CopilotPromptMessageType])
   messages!: CopilotPromptMessageType[];
 }
@@ -251,12 +295,7 @@ export class CopilotResolver {
     @Parent() copilot: CopilotType,
     @CurrentUser() user: CurrentUser,
     @Args('docId', { nullable: true }) docId?: string,
-    @Args({
-      name: 'options',
-      type: () => QueryChatHistoriesInput,
-      nullable: true,
-    })
-    options?: QueryChatHistoriesInput
+    @Args('options', { nullable: true }) options?: QueryChatHistoriesInput
   ) {
     const workspaceId = copilot.workspaceId;
     if (!workspaceId) {
@@ -301,7 +340,7 @@ export class CopilotResolver {
     const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
     await using lock = await this.mutex.lock(lockFlag);
     if (!lock) {
-      return new TooManyRequestsException('Server is busy');
+      return new TooManyRequest('Server is busy');
     }

     await this.chatSession.checkQuota(user.id);
@@ -313,6 +352,34 @@ export class CopilotResolver {
     return session;
   }

+  @Mutation(() => String, {
+    description: 'Create a chat session',
+  })
+  async forkCopilotSession(
+    @CurrentUser() user: CurrentUser,
+    @Args({ name: 'options', type: () => ForkChatSessionInput })
+    options: ForkChatSessionInput
+  ) {
+    await this.permissions.checkCloudPagePermission(
+      options.workspaceId,
+      options.docId,
+      user.id
+    );
+    const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
+    await using lock = await this.mutex.lock(lockFlag);
+    if (!lock) {
+      return new TooManyRequest('Server is busy');
+    }
+
+    await this.chatSession.checkQuota(user.id);
+
+    const session = await this.chatSession.fork({
+      ...options,
+      userId: user.id,
+    });
+    return session;
+  }
+
   @Mutation(() => [String], {
     description: 'Cleanup sessions',
   })
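The forked session id comes back as the mutation's String payload. A minimal client-side sketch of invoking the new mutation (the /graphql endpoint, fetch transport, and auth handling are assumptions for illustration, not part of this diff; the mutation name, input type, and fields come from the resolver above):

```ts
// Hypothetical caller for the forkCopilotSession mutation added above.
const FORK_COPILOT_SESSION = `
  mutation ForkCopilotSession($options: ForkChatSessionInput!) {
    forkCopilotSession(options: $options)
  }
`;

async function forkSession(
  workspaceId: string,
  docId: string,
  sessionId: string,
  latestMessageId: string
): Promise<string> {
  const res = await fetch('/graphql', {
    method: 'POST',
    headers: { 'content-type': 'application/json' },
    body: JSON.stringify({
      query: FORK_COPILOT_SESSION,
      variables: { options: { workspaceId, docId, sessionId, latestMessageId } },
    }),
  });
  const { data } = await res.json();
  // The resolver returns the id of the newly created forked session.
  return data.forkCopilotSession;
}
```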
@@ -332,7 +399,7 @@ export class CopilotResolver {
     const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
     await using lock = await this.mutex.lock(lockFlag);
     if (!lock) {
-      return new TooManyRequestsException('Server is busy');
+      return new TooManyRequest('Server is busy');
     }

     return await this.chatSession.cleanup({
@@ -352,7 +419,7 @@ export class CopilotResolver {
     const lockFlag = `${COPILOT_LOCKER}:message:${user?.id}:${options.sessionId}`;
     await using lock = await this.mutex.lock(lockFlag);
     if (!lock) {
-      return new TooManyRequestsException('Server is busy');
+      return new TooManyRequest('Server is busy');
     }
     const session = await this.chatSession.get(options.sessionId);
     if (!session || session.config.userId !== user.id) {
@@ -417,6 +484,9 @@ class CreateCopilotPromptInput {
   @Field(() => String, { nullable: true })
   action!: string | null;

+  @Field(() => CopilotPromptConfigType, { nullable: true })
+  config!: CopilotPromptConfigType | null;
+
   @Field(() => [CopilotPromptMessageType])
   messages!: CopilotPromptMessageType[];
 }
@@ -440,7 +510,12 @@ export class PromptsManagementResolver {
     @Args({ type: () => CreateCopilotPromptInput, name: 'input' })
     input: CreateCopilotPromptInput
   ) {
-    await this.promptService.set(input.name, input.model, input.messages);
+    await this.promptService.set(
+      input.name,
+      input.model,
+      input.messages,
+      input.config
+    );
     return this.promptService.get(input.name);
   }
|
||||
ChatHistory,
|
||||
ChatMessage,
|
||||
ChatMessageSchema,
|
||||
ChatSessionForkOptions,
|
||||
ChatSessionOptions,
|
||||
ChatSessionState,
|
||||
getTokenEncoder,
|
||||
@@ -48,10 +49,10 @@ export class ChatSession implements AsyncDisposable {
|
||||
userId,
|
||||
workspaceId,
|
||||
docId,
|
||||
prompt: { name: promptName },
|
||||
prompt: { name: promptName, config: promptConfig },
|
||||
} = this.state;
|
||||
|
||||
return { sessionId, userId, workspaceId, docId, promptName };
|
||||
return { sessionId, userId, workspaceId, docId, promptName, promptConfig };
|
||||
}
|
||||
|
||||
get stashMessages() {
|
||||
@@ -81,7 +82,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
async getMessageById(messageId: string) {
|
||||
const message = await this.messageCache.get(messageId);
|
||||
if (!message || message.sessionId !== this.state.sessionId) {
|
||||
throw new CopilotMessageNotFound();
|
||||
throw new CopilotMessageNotFound({ messageId });
|
||||
}
|
||||
return message;
|
||||
}
|
||||
@@ -89,7 +90,7 @@ export class ChatSession implements AsyncDisposable {
|
||||
async pushByMessageId(messageId: string) {
|
||||
const message = await this.messageCache.get(messageId);
|
||||
if (!message || message.sessionId !== this.state.sessionId) {
|
||||
throw new CopilotMessageNotFound();
|
||||
throw new CopilotMessageNotFound({ messageId });
|
||||
}
|
||||
|
||||
this.push({
|
||||
@@ -200,6 +201,7 @@ export class ChatSessionService {
|
||||
workspaceId: state.workspaceId,
|
||||
docId: state.docId,
|
||||
prompt: { action: { equals: null } },
|
||||
parentSessionId: state.parentSessionId,
|
||||
},
|
||||
select: { id: true, deletedAt: true },
|
||||
})) || {};
|
||||
@@ -252,6 +254,7 @@ export class ChatSessionService {
|
||||
// connect
|
||||
userId: state.userId,
|
||||
promptName: state.prompt.name,
|
||||
parentSessionId: state.parentSessionId,
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -271,8 +274,9 @@ export class ChatSessionService {
|
||||
userId: true,
|
||||
workspaceId: true,
|
||||
docId: true,
|
||||
parentSessionId: true,
|
||||
messages: {
|
||||
select: { role: true, content: true, createdAt: true },
|
||||
select: { id: true, role: true, content: true, createdAt: true },
|
||||
orderBy: { createdAt: 'asc' },
|
||||
},
|
||||
promptName: true,
|
||||
@@ -291,6 +295,7 @@ export class ChatSessionService {
|
||||
userId: session.userId,
|
||||
workspaceId: session.workspaceId,
|
||||
docId: session.docId,
|
||||
parentSessionId: session.parentSessionId,
|
||||
prompt,
|
||||
messages: messages.success ? messages.data : [],
|
||||
};
|
||||
@@ -380,22 +385,42 @@ export class ChatSessionService {
|
||||
return await this.db.aiSession
|
||||
.findMany({
|
||||
where: {
|
||||
userId,
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
prompt: {
|
||||
action: options?.action ? { not: null } : null,
|
||||
},
|
||||
id: options?.sessionId ? { equals: options.sessionId } : undefined,
|
||||
deletedAt: null,
|
||||
OR: [
|
||||
{
|
||||
userId,
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
id: options?.sessionId
|
||||
? { equals: options.sessionId }
|
||||
: undefined,
|
||||
deletedAt: null,
|
||||
},
|
||||
...(options?.action
|
||||
? []
|
||||
: [
|
||||
{
|
||||
userId: { not: userId },
|
||||
workspaceId: workspaceId,
|
||||
docId: workspaceId === docId ? undefined : docId,
|
||||
id: options?.sessionId
|
||||
? { equals: options.sessionId }
|
||||
: undefined,
|
||||
// should only find forked session
|
||||
parentSessionId: { not: null },
|
||||
deletedAt: null,
|
||||
},
|
||||
]),
|
||||
],
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
userId: true,
|
||||
promptName: true,
|
||||
tokenCost: true,
|
||||
createdAt: true,
|
||||
messages: {
|
||||
select: {
|
||||
id: true,
|
||||
role: true,
|
||||
content: true,
|
||||
attachments: true,
|
||||
@@ -414,15 +439,30 @@ export class ChatSessionService {
|
||||
.then(sessions =>
|
||||
Promise.all(
|
||||
sessions.map(
|
||||
async ({ id, promptName, tokenCost, messages, createdAt }) => {
|
||||
async ({
|
||||
id,
|
||||
userId: uid,
|
||||
promptName,
|
||||
tokenCost,
|
||||
messages,
|
||||
createdAt,
|
||||
}) => {
|
||||
try {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
if (
|
||||
// filter out the user's session that not match the action option
|
||||
(uid === userId && !!options?.action !== !!prompt.action) ||
|
||||
// filter out the non chat session from other user
|
||||
(uid !== userId && !!prompt.action)
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const ret = ChatMessageSchema.array().safeParse(messages);
|
||||
if (ret.success) {
|
||||
const prompt = await this.prompt.get(promptName);
|
||||
if (!prompt) {
|
||||
throw new CopilotPromptNotFound({ name: promptName });
|
||||
}
|
||||
|
||||
// render system prompt
|
||||
const preload = withPrompt
|
||||
? prompt
|
||||
@@ -430,7 +470,8 @@ export class ChatSessionService {
|
||||
.filter(({ role }) => role !== 'system')
|
||||
: [];
|
||||
|
||||
// `createdAt` is required for history sorting in frontend, let's fake the creating time of prompt messages
|
||||
// `createdAt` is required for history sorting in frontend
|
||||
// let's fake the creating time of prompt messages
|
||||
(preload as ChatMessage[]).forEach((msg, i) => {
|
||||
msg.createdAt = new Date(
|
||||
createdAt.getTime() - preload.length - i - 1
|
||||
@@ -495,9 +536,39 @@ export class ChatSessionService {
|
||||
sessionId,
|
||||
prompt,
|
||||
messages: [],
|
||||
// when client create chat session, we always find root session
|
||||
parentSessionId: null,
|
||||
});
|
||||
}
|
||||
|
||||
async fork(options: ChatSessionForkOptions): Promise<string> {
|
||||
const state = await this.getSession(options.sessionId);
|
||||
if (!state) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
const lastMessageIdx = state.messages.findLastIndex(
|
||||
({ id, role }) =>
|
||||
role === AiPromptRole.assistant && id === options.latestMessageId
|
||||
);
|
||||
if (lastMessageIdx < 0) {
|
||||
throw new CopilotMessageNotFound({ messageId: options.latestMessageId });
|
||||
}
|
||||
const messages = state.messages
|
||||
.slice(0, lastMessageIdx + 1)
|
||||
.map(m => ({ ...m, id: undefined }));
|
||||
|
||||
const forkedState = {
|
||||
...state,
|
||||
sessionId: randomUUID(),
|
||||
messages: [],
|
||||
parentSessionId: options.sessionId,
|
||||
};
|
||||
// create session
|
||||
await this.setSession(forkedState);
|
||||
// save message
|
||||
return await this.setSession({ ...forkedState, messages });
|
||||
}
|
||||
|
||||
async cleanup(
|
||||
options: Omit<ChatSessionOptions, 'promptName'> & { sessionIds: string[] }
|
||||
) {
|
||||
|
||||
@@ -63,7 +63,22 @@ export type PromptMessage = z.infer<typeof PromptMessageSchema>;

 export type PromptParams = NonNullable<PromptMessage['params']>;

+export const PromptConfigStrictSchema = z.object({
+  jsonMode: z.boolean().nullable().optional(),
+  frequencyPenalty: z.number().nullable().optional(),
+  presencePenalty: z.number().nullable().optional(),
+  temperature: z.number().nullable().optional(),
+  topP: z.number().nullable().optional(),
+  maxTokens: z.number().nullable().optional(),
+});
+
+export const PromptConfigSchema =
+  PromptConfigStrictSchema.nullable().optional();
+
+export type PromptConfig = z.infer<typeof PromptConfigSchema>;
+
 export const ChatMessageSchema = PromptMessageSchema.extend({
+  id: z.string().optional(),
   createdAt: z.date(),
 }).strict();
@@ -98,10 +113,17 @@ export interface ChatSessionOptions {
   promptName: string;
 }

+export interface ChatSessionForkOptions
+  extends Omit<ChatSessionOptions, 'promptName'> {
+  sessionId: string;
+  latestMessageId: string;
+}
+
 export interface ChatSessionState
   extends Omit<ChatSessionOptions, 'promptName'> {
   // connect ids
   sessionId: string;
+  parentSessionId: string | null;
   // states
   prompt: ChatPrompt;
   messages: ChatMessage[];
@@ -136,11 +158,9 @@ const CopilotProviderOptionsSchema = z.object({
   user: z.string().optional(),
 });

-const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({
-  jsonMode: z.boolean().optional(),
-  temperature: z.number().optional(),
-  maxTokens: z.number().optional(),
-}).optional();
+const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge(
+  PromptConfigStrictSchema
+).optional();

 export type CopilotChatOptions = z.infer<typeof CopilotChatOptionsSchema>;
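PromptConfigStrictSchema above drives validation of stored prompt configs. A standalone sketch of how such a schema behaves (assumes zod is installed; the shape mirrors the diff):

```ts
import { z } from 'zod';

// Same shape as PromptConfigStrictSchema in the diff above.
const PromptConfigStrictSchema = z.object({
  jsonMode: z.boolean().nullable().optional(),
  frequencyPenalty: z.number().nullable().optional(),
  presencePenalty: z.number().nullable().optional(),
  temperature: z.number().nullable().optional(),
  topP: z.number().nullable().optional(),
  maxTokens: z.number().nullable().optional(),
});

// Every field is optional, so an empty object parses…
console.log(PromptConfigStrictSchema.safeParse({}).success); // true
// …a partial config parses…
console.log(PromptConfigStrictSchema.safeParse({ temperature: 0.7 }).success); // true
// …but a wrong value type does not.
console.log(PromptConfigStrictSchema.safeParse({ temperature: 'hot' }).success); // false
```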
@@ -1,6 +1,6 @@
 import { NodeExecutorType } from './executor';
 import type { WorkflowGraphs } from './types';
-import { WorkflowNodeState, WorkflowNodeType } from './types';
+import { WorkflowNodeType } from './types';

 export const WorkflowGraphList: WorkflowGraphs = [
   {
@@ -21,43 +21,6 @@ export const WorkflowGraphList: WorkflowGraphs = [
       nodeType: WorkflowNodeType.Basic,
       type: NodeExecutorType.ChatText,
       promptName: 'workflow:presentation:step2',
-      edges: ['step3'],
-    },
-    {
-      id: 'step3',
-      name: 'Step 3: format presentation if needed',
-      nodeType: WorkflowNodeType.Decision,
-      condition: (nodeIds: string[], params: WorkflowNodeState) => {
-        const lines = params.content?.split('\n') || [];
-        return nodeIds[
-          Number(
-            !lines.some(line => {
-              try {
-                if (line.trim()) {
-                  JSON.parse(line);
-                }
-                return false;
-              } catch {
-                return true;
-              }
-            })
-          )
-        ];
-      },
-      edges: ['step4', 'step5'],
-    },
-    {
-      id: 'step4',
-      name: 'Step 4: format presentation',
-      nodeType: WorkflowNodeType.Basic,
-      type: NodeExecutorType.ChatText,
-      promptName: 'workflow:presentation:step4',
-      edges: ['step5'],
-    },
-    {
+      edges: ['step5'],
+    },
+    {
       id: 'step5',
       name: 'Step 5: finish',
       nodeType: WorkflowNodeType.Nope,
       edges: [],
     },
   ],
@@ -1,6 +1,10 @@
 import type { Stripe } from 'stripe';

-import { defineStartupConfig, ModuleConfig } from '../../fundamentals/config';
+import {
+  defineRuntimeConfig,
+  defineStartupConfig,
+  ModuleConfig,
+} from '../../fundamentals/config';

 export interface PaymentStartupConfig {
   stripe?: {
@@ -11,10 +15,20 @@ export interface PaymentStartupConfig {
   } & Stripe.StripeConfig;
 }

+export interface PaymentRuntimeConfig {
+  showLifetimePrice: boolean;
+}
+
 declare module '../config' {
   interface PluginsConfig {
-    payment: ModuleConfig<PaymentStartupConfig>;
+    payment: ModuleConfig<PaymentStartupConfig, PaymentRuntimeConfig>;
   }
 }

 defineStartupConfig('plugins.payment', {});
+defineRuntimeConfig('plugins.payment', {
+  showLifetimePrice: {
+    desc: 'Whether enable lifetime price and allow user to pay for it.',
+    default: false,
+  },
+});
@@ -53,12 +53,15 @@ class SubscriptionPrice {

  @Field(() => Int, { nullable: true })
  yearlyAmount?: number | null;

  @Field(() => Int, { nullable: true })
  lifetimeAmount?: number | null;
}

@ObjectType('UserSubscription')
export class UserSubscriptionType implements Partial<UserSubscription> {
  @Field({ name: 'id' })
  stripeSubscriptionId!: string;
  @Field(() => String, { name: 'id', nullable: true })
  stripeSubscriptionId!: string | null;

  @Field(() => SubscriptionPlan, {
    description:
@@ -75,8 +78,8 @@ export class UserSubscriptionType implements Partial<UserSubscription> {
  @Field(() => Date)
  start!: Date;

  @Field(() => Date)
  end!: Date;
  @Field(() => Date, { nullable: true })
  end!: Date | null;

  @Field(() => Date, { nullable: true })
  trialStart?: Date | null;
@@ -187,11 +190,19 @@ export class SubscriptionResolver {

    const monthlyPrice = prices.find(p => p.recurring?.interval === 'month');
    const yearlyPrice = prices.find(p => p.recurring?.interval === 'year');
    const lifetimePrice = prices.find(
      p =>
        // asserted before
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        decodeLookupKey(p.lookup_key!)[1] === SubscriptionRecurring.Lifetime
    );
    const currency = monthlyPrice?.currency ?? yearlyPrice?.currency ?? 'usd';

    return {
      currency,
      amount: monthlyPrice?.unit_amount,
      yearlyAmount: yearlyPrice?.unit_amount,
      lifetimeAmount: lifetimePrice?.unit_amount,
    };
  }

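`decodeLookupKey` itself is outside this diff; judging by the lookup keys built in the tests below (`PRO_MONTHLY`, `PRO_LIFETIME`, `PRO_EA_YEARLY`), a key encodes `${plan}_${recurring}` with an optional `_${variant}` tail, so index `[1]` above is the recurring part. A hedged sketch of that assumed inverse:

// Assumption, not the real implementation: invert the
// `${plan}_${recurring}[_${variant}]` encoding used by the test constants.
function decodeLookupKey(
  key: string
): [plan: string, recurring: string, variant?: string] {
  const [plan, recurring, variant] = key.split('_');
  return [plan, recurring, variant];
}

decodeLookupKey('pro_lifetime'); // ['pro', 'lifetime', undefined]
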
@@ -3,7 +3,6 @@ import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { OnEvent as RawOnEvent } from '@nestjs/event-emitter';
import type {
  Prisma,
  User,
  UserInvoice,
  UserStripeCustomer,
@@ -16,6 +15,7 @@ import { CurrentUser } from '../../core/auth';
import { EarlyAccessType, FeatureManagementService } from '../../core/features';
import {
  ActionForbidden,
  CantUpdateLifetimeSubscription,
  Config,
  CustomerPortalCreateFailed,
  EventEmitter,
@@ -121,8 +121,14 @@ export class SubscriptionService {
    });
  }

    const lifetimePriceEnabled = await this.config.runtime.fetch(
      'plugins.payment/showLifetimePrice'
    );

    const list = await this.stripe.prices.list({
      active: true,
      // only list recurring prices if lifetime price is not enabled
      ...(lifetimePriceEnabled ? {} : { type: 'recurring' }),
    });

    return list.data.filter(price => {
@@ -131,7 +137,11 @@ export class SubscriptionService {
      }

      const [plan, recurring, variant] = decodeLookupKey(price.lookup_key);
      if (recurring === SubscriptionRecurring.Monthly) {
      // no variant price should be used for monthly or lifetime subscription
      if (
        recurring === SubscriptionRecurring.Monthly ||
        recurring === SubscriptionRecurring.Lifetime
      ) {
        return !variant;
      }

@@ -184,7 +194,12 @@ export class SubscriptionService {
      },
    });

    if (currentSubscription) {
    if (
      currentSubscription &&
      // do not allow re-subscribing unless the new recurring is `Lifetime`
      (currentSubscription.recurring === recurring ||
        recurring !== SubscriptionRecurring.Lifetime)
    ) {
      throw new SubscriptionAlreadyExists({ plan });
    }

@@ -224,8 +239,19 @@ export class SubscriptionService {
      tax_id_collection: {
        enabled: true,
      },
      // discount
      ...(discounts.length ? { discounts } : { allow_promotion_codes: true }),
      mode: 'subscription',
      // mode: 'subscription' or 'payment' for lifetime
      ...(recurring === SubscriptionRecurring.Lifetime
        ? {
            mode: 'payment',
            invoice_creation: {
              enabled: true,
            },
          }
        : {
            mode: 'subscription',
          }),
      success_url: redirectUrl,
      customer: customer.stripeCustomerId,
      customer_update: {
@@ -264,6 +290,12 @@ export class SubscriptionService {
      throw new SubscriptionNotExists({ plan });
    }

    if (!subscriptionInDB.stripeSubscriptionId) {
      throw new CantUpdateLifetimeSubscription(
        'Lifetime subscription cannot be canceled.'
      );
    }

    if (subscriptionInDB.canceledAt) {
      throw new SubscriptionHasBeenCanceled();
    }
@@ -315,6 +347,12 @@ export class SubscriptionService {
      throw new SubscriptionNotExists({ plan });
    }

    if (!subscriptionInDB.stripeSubscriptionId || !subscriptionInDB.end) {
      throw new CantUpdateLifetimeSubscription(
        'Lifetime subscription cannot be resumed.'
      );
    }

    if (!subscriptionInDB.canceledAt) {
      throw new SubscriptionHasBeenCanceled();
    }
@@ -368,6 +406,12 @@ export class SubscriptionService {
      throw new SubscriptionNotExists({ plan });
    }

    if (!subscriptionInDB.stripeSubscriptionId) {
      throw new CantUpdateLifetimeSubscription(
        'Can not update lifetime subscription.'
      );
    }

    if (subscriptionInDB.canceledAt) {
      throw new SubscriptionHasBeenCanceled();
    }
@@ -422,60 +466,12 @@ export class SubscriptionService {
    }
  }

  @OnStripeEvent('customer.subscription.created')
  @OnStripeEvent('customer.subscription.updated')
  async onSubscriptionChanges(subscription: Stripe.Subscription) {
    subscription = await this.stripe.subscriptions.retrieve(subscription.id);
    if (subscription.status === 'active') {
      const user = await this.retrieveUserFromCustomer(
        typeof subscription.customer === 'string'
          ? subscription.customer
          : subscription.customer.id
      );

      await this.saveSubscription(user, subscription);
    } else {
      await this.onSubscriptionDeleted(subscription);
    }
  }

  @OnStripeEvent('customer.subscription.deleted')
  async onSubscriptionDeleted(subscription: Stripe.Subscription) {
    const user = await this.retrieveUserFromCustomer(
      typeof subscription.customer === 'string'
        ? subscription.customer
        : subscription.customer.id
    );

    const [plan] = this.decodePlanFromSubscription(subscription);
    this.event.emit('user.subscription.canceled', {
      userId: user.id,
      plan,
    });

    await this.db.userSubscription.deleteMany({
      where: {
        stripeSubscriptionId: subscription.id,
      },
    });
  }

  @OnStripeEvent('invoice.paid')
  async onInvoicePaid(stripeInvoice: Stripe.Invoice) {
    stripeInvoice = await this.stripe.invoices.retrieve(stripeInvoice.id);
    await this.saveInvoice(stripeInvoice);

    const line = stripeInvoice.lines.data[0];

    if (!line.price || line.price.type !== 'recurring') {
      throw new Error('Unknown invoice with no recurring price');
    }
  }

  @OnStripeEvent('invoice.created')
  @OnStripeEvent('invoice.updated')
  @OnStripeEvent('invoice.finalization_failed')
  @OnStripeEvent('invoice.payment_failed')
  async saveInvoice(stripeInvoice: Stripe.Invoice) {
  @OnStripeEvent('invoice.payment_succeeded')
  async saveInvoice(stripeInvoice: Stripe.Invoice, event: string) {
    stripeInvoice = await this.stripe.invoices.retrieve(stripeInvoice.id);
    if (!stripeInvoice.customer) {
      throw new Error('Unexpected invoice with no customer');
@@ -487,12 +483,6 @@ export class SubscriptionService {
        : stripeInvoice.customer.id
    );

    const invoice = await this.db.userInvoice.findUnique({
      where: {
        stripeInvoiceId: stripeInvoice.id,
      },
    });

    const data: Partial<UserInvoice> = {
      currency: stripeInvoice.currency,
      amount: stripeInvoice.total,
@@ -524,39 +514,135 @@ export class SubscriptionService {
      }
    }

    // update invoice
    if (invoice) {
      await this.db.userInvoice.update({
        where: {
          stripeInvoiceId: stripeInvoice.id,
    // create invoice
    const price = stripeInvoice.lines.data[0].price;

    if (!price) {
      throw new Error('Unexpected invoice with no price');
    }

    if (!price.lookup_key) {
      throw new Error('Unexpected subscription with no key');
    }

    const [plan, recurring] = decodeLookupKey(price.lookup_key);

    const invoice = await this.db.userInvoice.upsert({
      where: {
        stripeInvoiceId: stripeInvoice.id,
      },
      update: data,
      create: {
        userId: user.id,
        stripeInvoiceId: stripeInvoice.id,
        plan,
        recurring,
        reason: stripeInvoice.billing_reason ?? 'contact support',
        ...(data as any),
      },
    });

    // handle one-time payment; no subscription is created by Stripe
    if (
      event === 'invoice.payment_succeeded' &&
      recurring === SubscriptionRecurring.Lifetime &&
      stripeInvoice.status === 'paid'
    ) {
      await this.saveLifetimeSubscription(user, invoice);
    }
  }

  async saveLifetimeSubscription(user: User, invoice: UserInvoice) {
    // cancel previous non-lifetime subscription
    const savedSubscription = await this.db.userSubscription.findUnique({
      where: {
        userId_plan: {
          userId: user.id,
          plan: SubscriptionPlan.Pro,
        },
      },
    });

    if (savedSubscription && savedSubscription.stripeSubscriptionId) {
      await this.db.userSubscription.update({
        where: {
          id: savedSubscription.id,
        },
        data: {
          stripeScheduleId: null,
          stripeSubscriptionId: null,
          status: SubscriptionStatus.Active,
          recurring: SubscriptionRecurring.Lifetime,
          end: null,
        },
        data,
      });

      await this.stripe.subscriptions.cancel(
        savedSubscription.stripeSubscriptionId,
        {
          prorate: true,
        }
      );
    } else {
      // create invoice
      const price = stripeInvoice.lines.data[0].price;

      if (!price || price.type !== 'recurring') {
        throw new Error('Unexpected invoice with no recurring price');
      }

      if (!price.lookup_key) {
        throw new Error('Unexpected subscription with no key');
      }

      const [plan, recurring] = decodeLookupKey(price.lookup_key);

      await this.db.userInvoice.create({
      await this.db.userSubscription.create({
        data: {
          userId: user.id,
          stripeInvoiceId: stripeInvoice.id,
          plan,
          recurring,
          reason: stripeInvoice.billing_reason ?? 'contact support',
          ...(data as any),
          stripeSubscriptionId: null,
          plan: invoice.plan,
          recurring: invoice.recurring,
          end: null,
          start: new Date(),
          status: SubscriptionStatus.Active,
          nextBillAt: null,
        },
      });
    }

    this.event.emit('user.subscription.activated', {
      userId: user.id,
      plan: invoice.plan as SubscriptionPlan,
      recurring: SubscriptionRecurring.Lifetime,
    });
  }

  @OnStripeEvent('customer.subscription.created')
  @OnStripeEvent('customer.subscription.updated')
  async onSubscriptionChanges(subscription: Stripe.Subscription) {
    subscription = await this.stripe.subscriptions.retrieve(subscription.id);
    if (subscription.status === 'active') {
      const user = await this.retrieveUserFromCustomer(
        typeof subscription.customer === 'string'
          ? subscription.customer
          : subscription.customer.id
      );

      await this.saveSubscription(user, subscription);
    } else {
      await this.onSubscriptionDeleted(subscription);
    }
  }

  @OnStripeEvent('customer.subscription.deleted')
  async onSubscriptionDeleted(subscription: Stripe.Subscription) {
    const user = await this.retrieveUserFromCustomer(
      typeof subscription.customer === 'string'
        ? subscription.customer
        : subscription.customer.id
    );

    const [plan, recurring] = this.decodePlanFromSubscription(subscription);

    this.event.emit('user.subscription.canceled', {
      userId: user.id,
      plan,
      recurring,
    });

    await this.db.userSubscription.deleteMany({
      where: {
        stripeSubscriptionId: subscription.id,
      },
    });
  }

  private async saveSubscription(
@@ -576,6 +662,7 @@ export class SubscriptionService {
    this.event.emit('user.subscription.activated', {
      userId: user.id,
      plan,
      recurring,
    });

    let nextBillAt: Date | null = null;
@@ -600,44 +687,21 @@ export class SubscriptionService {
        : null,
      stripeSubscriptionId: subscription.id,
      plan,
      recurring,
      status: subscription.status,
      stripeScheduleId: subscription.schedule as string | null,
    };

    const currentSubscription = await this.db.userSubscription.findUnique({
    return await this.db.userSubscription.upsert({
      where: {
        userId_plan: {
          userId: user.id,
          plan,
        },
        stripeSubscriptionId: subscription.id,
      },
      update: commonData,
      create: {
        userId: user.id,
        recurring,
        ...commonData,
      },
    });

    if (currentSubscription) {
      const update: Prisma.UserSubscriptionUpdateInput = {
        ...commonData,
      };

      // a schedule exists; update the recurring to the scheduled one
      if (update.stripeScheduleId) {
        delete update.recurring;
      }

      return await this.db.userSubscription.update({
        where: {
          id: currentSubscription.id,
        },
        data: update,
      });
    } else {
      return await this.db.userSubscription.create({
        data: {
          userId: user.id,
          ...commonData,
        },
      });
    }
  }

  private async getOrCreateCustomer(
@@ -749,6 +813,16 @@ export class SubscriptionService {
    recurring: SubscriptionRecurring,
    variant?: SubscriptionPriceVariant
  ): Promise<string> {
    if (recurring === SubscriptionRecurring.Lifetime) {
      const lifetimePriceEnabled = await this.config.runtime.fetch(
        'plugins.payment/showLifetimePrice'
      );

      if (!lifetimePriceEnabled) {
        throw new ActionForbidden();
      }
    }

    const prices = await this.stripe.prices.list({
      lookup_keys: [encodeLookupKey(plan, recurring, variant)],
    });

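Taken together, the service changes above treat a lifetime purchase as a one-time payment: checkout runs in `mode: 'payment'` with `invoice_creation` enabled, Stripe therefore never creates a subscription object, and activation is driven by the `invoice.payment_succeeded` event, where `saveInvoice` hands off to `saveLifetimeSubscription`. A condensed, non-authoritative restatement of that gate (helpers are stand-ins for the service methods above):

type InvoiceLike = { status: string; lookupKey: string };

async function maybeActivateLifetime(
  invoice: InvoiceLike,
  event: string,
  saveLifetimeSubscription: (invoice: InvoiceLike) => Promise<void>
) {
  const [, recurring] = invoice.lookupKey.split('_');
  if (
    event === 'invoice.payment_succeeded' && // only the final success event
    recurring === 'lifetime' && // one-time price, no Stripe subscription
    invoice.status === 'paid'
  ) {
    // activates the plan with stripeSubscriptionId / end set to null
    await saveLifetimeSubscription(invoice);
  }
}
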
@@ -5,6 +5,7 @@ import type { Payload } from '../../fundamentals/event/def';
export enum SubscriptionRecurring {
  Monthly = 'monthly',
  Yearly = 'yearly',
  Lifetime = 'lifetime',
}

export enum SubscriptionPlan {
@@ -46,10 +47,12 @@ declare module '../../fundamentals/event/def' {
      activated: Payload<{
        userId: User['id'];
        plan: SubscriptionPlan;
        recurring: SubscriptionRecurring;
      }>;
      canceled: Payload<{
        userId: User['id'];
        plan: SubscriptionPlan;
        recurring: SubscriptionRecurring;
      }>;
    };
}

@@ -45,9 +45,16 @@ export class StripeWebhook {
    setImmediate(() => {
      // handle duplicated events?
      // see https://stripe.com/docs/webhooks#handle-duplicate-events
      this.event.emitAsync(event.type, event.data.object).catch(e => {
        this.logger.error('Failed to handle Stripe Webhook event.', e);
      });
      this.event
        .emitAsync(
          event.type,
          event.data.object,
          // pass the event type along so listeners that handle multiple
          // events know exactly which one fired
          event.type
        )
        .catch(e => {
          this.logger.error('Failed to handle Stripe Webhook event.', e);
        });
    });
  } catch (err: any) {
    throw new InternalServerError(err.message);

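The webhook now forwards `event.type` a second time, as a payload argument, so a listener bound to several Stripe events can tell which one fired; this is what lets `saveInvoice` above react only to `invoice.payment_succeeded`. A minimal sketch of the consume side (the listener signature is assumed from `saveInvoice`):

// One handler registered for several events, branching on the forwarded type.
async function onInvoiceEvent(invoice: object, eventType: string) {
  if (eventType === 'invoice.payment_succeeded') {
    // only here may a paid lifetime invoice activate a subscription
  }
}
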
@@ -11,6 +11,7 @@ type ChatMessage {
  attachments: [String!]
  content: String!
  createdAt: DateTime!
  id: ID
  params: JSON
  role: String!
}
@@ -39,6 +40,10 @@ type CopilotHistories {
  tokens: Int!
}

type CopilotMessageNotFoundDataType {
  messageId: String!
}

enum CopilotModels {
  DallE3
  Gpt4Omni
@@ -52,6 +57,22 @@ enum CopilotModels {
  TextModerationStable
}

input CopilotPromptConfigInput {
  frequencyPenalty: Int
  jsonMode: Boolean
  presencePenalty: Int
  temperature: Int
  topP: Int
}

type CopilotPromptConfigType {
  frequencyPenalty: Int
  jsonMode: Boolean
  presencePenalty: Int
  temperature: Int
  topP: Int
}

input CopilotPromptMessageInput {
  content: String!
  params: JSON
@@ -76,6 +97,7 @@ type CopilotPromptNotFoundDataType {

type CopilotPromptType {
  action: String
  config: CopilotPromptConfigType
  messages: [CopilotPromptMessageType!]!
  model: CopilotModels!
  name: String!
@@ -118,6 +140,7 @@ input CreateCheckoutSessionInput {

input CreateCopilotPromptInput {
  action: String
  config: CopilotPromptConfigInput
  messages: [CopilotPromptMessageInput!]!
  model: CopilotModels!
  name: String!
@@ -175,7 +198,7 @@ enum EarlyAccessType {
  App
}

union ErrorDataUnion = BlobNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInWorkspaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | VersionRejectedDataType | WorkspaceAccessDeniedDataType | WorkspaceNotFoundDataType | WorkspaceOwnerNotFoundDataType
union ErrorDataUnion = BlobNotFoundDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInWorkspaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | VersionRejectedDataType | WorkspaceAccessDeniedDataType | WorkspaceNotFoundDataType | WorkspaceOwnerNotFoundDataType

enum ErrorNames {
  ACCESS_DENIED
@@ -184,6 +207,7 @@ enum ErrorNames {
  BLOB_NOT_FOUND
  BLOB_QUOTA_EXCEEDED
  CANT_CHANGE_WORKSPACE_OWNER
  CANT_UPDATE_LIFETIME_SUBSCRIPTION
  COPILOT_ACTION_TAKEN
  COPILOT_FAILED_TO_CREATE_MESSAGE
  COPILOT_FAILED_TO_GENERATE_TEXT
@@ -252,6 +276,17 @@ enum FeatureType {
  UnlimitedWorkspace
}

input ForkChatSessionInput {
  docId: String!

  """
  Identify a message in the array and keep it, with all previous messages, in the forked session.
  """
  latestMessageId: String!
  sessionId: String!
  workspaceId: String!
}

type HumanReadableQuotaType {
  blobLimit: String!
  copilotActionLimit: String
@@ -399,6 +434,9 @@ type Mutation {

  """Delete a user account"""
  deleteUser(id: String!): DeleteAccount!
  deleteWorkspace(id: String!): Boolean!

  """Create a chat session"""
  forkCopilotSession(options: ForkChatSessionInput!): String!
  invite(email: String!, permission: Permission!, sendInviteMail: Boolean, workspaceId: String!): String!
  leaveWorkspace(sendLeaveMail: Boolean, workspaceId: String!, workspaceName: String!): Boolean!
  publishPage(mode: PublicPageMode = Page, pageId: String!, workspaceId: String!): WorkspacePage!
@@ -638,12 +676,14 @@ type SubscriptionPlanNotFoundDataType {
type SubscriptionPrice {
  amount: Int
  currency: String!
  lifetimeAmount: Int
  plan: SubscriptionPlan!
  type: String!
  yearlyAmount: Int
}

enum SubscriptionRecurring {
  Lifetime
  Monthly
  Yearly
}
@@ -714,8 +754,8 @@ type UserQuotaHumanReadable {
type UserSubscription {
  canceledAt: DateTime
  createdAt: DateTime!
  end: DateTime!
  id: String!
  end: DateTime
  id: String
  nextBillAt: DateTime

  """

@@ -36,6 +36,7 @@ import {
  chatWithWorkflow,
  createCopilotMessage,
  createCopilotSession,
  forkCopilotSession,
  getHistories,
  MockCopilotTestProvider,
  sse2array,
@@ -96,7 +97,7 @@ test.beforeEach(async t => {
  ]);

  for (const p of prompts) {
    await prompt.set(p.name, p.model, p.messages);
    await prompt.set(p.name, p.model, p.messages, p.config);
  }
});

@@ -164,6 +165,123 @@ test('should create session correctly', async t => {
  }
});

test('should fork session correctly', async t => {
  const { app } = t.context;

  const assertForkSession = async (
    token: string,
    workspaceId: string,
    sessionId: string,
    lastMessageId: string,
    error: string,
    asserter = async (x: any) => {
      const forkedSessionId = await x;
      t.truthy(forkedSessionId, error);
      return forkedSessionId;
    }
  ) =>
    await asserter(
      forkCopilotSession(
        app,
        token,
        workspaceId,
        randomUUID(),
        sessionId,
        lastMessageId
      )
    );

  // prepare session
  const { id } = await createWorkspace(app, token);
  const sessionId = await createCopilotSession(
    app,
    token,
    id,
    randomUUID(),
    promptName
  );

  let forkedSessionId: string;
  // should be able to fork session
  {
    for (let i = 0; i < 3; i++) {
      const messageId = await createCopilotMessage(app, token, sessionId);
      await chatWithText(app, token, sessionId, messageId);
    }
    const histories = await getHistories(app, token, { workspaceId: id });
    const latestMessageId = histories[0].messages.findLast(
      m => m.role === 'assistant'
    )?.id;
    t.truthy(latestMessageId, 'should find last message id');

    // should be able to fork session
    forkedSessionId = await assertForkSession(
      token,
      id,
      sessionId,
      latestMessageId!,
      'should be able to fork session with cloud workspace that user can access'
    );
  }

  {
    const {
      token: { token: newToken },
    } = await signUp(app, 'test', 'test@affine.pro', '123456');
    await assertForkSession(
      newToken,
      id,
      sessionId,
      randomUUID(),
      '',
      async x => {
        await t.throwsAsync(
          x,
          { instanceOf: Error },
          'should not be able to fork session with cloud workspace that user cannot access'
        );
      }
    );

    const inviteId = await inviteUser(
      app,
      token,
      id,
      'test@affine.pro',
      'Admin'
    );
    await acceptInviteById(app, id, inviteId, false);
    await assertForkSession(
      newToken,
      id,
      sessionId,
      randomUUID(),
      '',
      async x => {
        await t.throwsAsync(
          x,
          { instanceOf: Error },
          'should not be able to fork a root session from another user'
        );
      }
    );

    const histories = await getHistories(app, token, { workspaceId: id });
    const latestMessageId = histories
      .find(h => h.sessionId === forkedSessionId)
      ?.messages.findLast(m => m.role === 'assistant')?.id;
    t.truthy(latestMessageId, 'should find latest message id');

    await assertForkSession(
      newToken,
      id,
      forkedSessionId,
      latestMessageId!,
      'should be able to fork a forked session created by another user'
    );
  }
});

test('should be able to use test provider', async t => {
  const { app } = t.context;

@@ -261,7 +379,7 @@ test('should be able to chat with api by workflow', async t => {
  const ret = await chatWithWorkflow(app, token, sessionId, messageId);
  t.is(
    array2sse(sse2array(ret).filter(e => e.event !== 'event')),
    textToEventStream(['generate text to text stream'], messageId),
    textToEventStream('generate text to text stream', messageId),
    'should be able to chat with workflow'
  );
});

@@ -208,11 +208,13 @@ test('should be able to manage chat session', async t => {
    { role: 'system', content: 'hello {{word}}' },
  ]);

  const params = { word: 'world' };
  const commonParams = { docId: 'test', workspaceId: 'test' };

  const sessionId = await session.create({
    docId: 'test',
    workspaceId: 'test',
    userId,
    promptName: 'prompt',
    ...commonParams,
  });
  t.truthy(sessionId, 'should create session');

@@ -221,8 +223,6 @@ test('should be able to manage chat session', async t => {
  t.is(s.config.promptName, 'prompt', 'should have prompt name');
  t.is(s.model, 'model', 'should have model');

  const params = { word: 'world' };

  s.push({ role: 'user', content: 'hello', createdAt: new Date() });
  // @ts-expect-error
  const finalMessages = s.finish(params).map(({ createdAt: _, ...m }) => m);
@@ -239,19 +239,112 @@ test('should be able to manage chat session', async t => {
  const s1 = (await session.get(sessionId))!;
  t.deepEqual(
    // @ts-expect-error
    s1.finish(params).map(({ createdAt: _, ...m }) => m),
    s1.finish(params).map(({ id: _, createdAt: __, ...m }) => m),
    finalMessages,
    'should be the same as the previous messages'
  );
  t.deepEqual(
    // @ts-expect-error
    s1.finish({}).map(({ createdAt: _, ...m }) => m),
    s1.finish({}).map(({ id: _, createdAt: __, ...m }) => m),
    [
      { content: 'hello ', params: {}, role: 'system' },
      { content: 'hello', role: 'user' },
    ],
    'should generate a different message with other params'
  );

  // should get the main session after fork when re-creating a chat session for the same docId and workspaceId
  {
    const newSessionId = await session.create({
      userId,
      promptName: 'prompt',
      ...commonParams,
    });
    t.is(newSessionId, sessionId, 'should get same session id');
  }
});

test('should be able to fork chat session', async t => {
  const { prompt, session } = t.context;

  await prompt.set('prompt', 'model', [
    { role: 'system', content: 'hello {{word}}' },
  ]);

  const params = { word: 'world' };
  const commonParams = { docId: 'test', workspaceId: 'test' };
  // create session
  const sessionId = await session.create({
    userId,
    promptName: 'prompt',
    ...commonParams,
  });
  const s = (await session.get(sessionId))!;
  s.push({ role: 'user', content: 'hello', createdAt: new Date() });
  s.push({ role: 'assistant', content: 'world', createdAt: new Date() });
  s.push({ role: 'user', content: 'aaa', createdAt: new Date() });
  s.push({ role: 'assistant', content: 'bbb', createdAt: new Date() });
  await s.save();

  // fork session
  const s1 = (await session.get(sessionId))!;
  // @ts-expect-error
  const latestMessageId = s1.finish({}).find(m => m.role === 'assistant')!.id;
  const forkedSessionId = await session.fork({
    userId,
    sessionId,
    latestMessageId,
    ...commonParams,
  });
  t.not(sessionId, forkedSessionId, 'should fork a new session');

  // check forked session messages
  {
    const s2 = (await session.get(forkedSessionId))!;

    const finalMessages = s2
      .finish(params) // @ts-expect-error
      .map(({ id: _, createdAt: __, ...m }) => m);
    t.deepEqual(
      finalMessages,
      [
        { role: 'system', content: 'hello world', params },
        { role: 'user', content: 'hello' },
        { role: 'assistant', content: 'world' },
      ],
      'should generate the final message'
    );
  }

  // check original session messages
  {
    const s3 = (await session.get(sessionId))!;

    const finalMessages = s3
      .finish(params) // @ts-expect-error
      .map(({ id: _, createdAt: __, ...m }) => m);
    t.deepEqual(
      finalMessages,
      [
        { role: 'system', content: 'hello world', params },
        { role: 'user', content: 'hello' },
        { role: 'assistant', content: 'world' },
        { role: 'user', content: 'aaa' },
        { role: 'assistant', content: 'bbb' },
      ],
      'should generate the final message'
    );
  }

  // should get the main session after fork when re-creating a chat session for the same docId and workspaceId
  {
    const newSessionId = await session.create({
      userId,
      promptName: 'prompt',
      ...commonParams,
    });
    t.is(newSessionId, sessionId, 'should get same session id');
  }
});

test('should be able to process message id', async t => {
@@ -583,7 +676,7 @@ test.skip('should be able to preview workflow', async t => {
  registerCopilotProvider(OpenAIProvider);

  for (const p of prompts) {
    await prompt.set(p.name, p.model, p.messages);
    await prompt.set(p.name, p.model, p.messages, p.config);
  }

  let result = '';
@@ -633,7 +726,7 @@ test('should be able to run pre defined workflow', async t => {
  const { graph, prompts, callCount, input, params, result } = testCase;
  console.log('running workflow test:', graph.name);
  for (const p of prompts) {
    await prompt.set(p.name, p.model, p.messages);
    await prompt.set(p.name, p.model, p.messages, p.config);
  }

  for (const [idx, i] of input.entries()) {
@@ -680,7 +773,7 @@ test('should be able to run workflow', async t => {
  const executor = Sinon.spy(executors.text, 'next');

  for (const p of prompts) {
    await prompt.set(p.name, p.model, p.messages);
    await prompt.set(p.name, p.model, p.messages, p.config);
  }

  const graphName = 'presentation';
@@ -699,9 +792,7 @@ test('should be able to run workflow', async t => {
  }
  t.assert(result, 'generate text to text stream');

  // the presentation workflow has a condition node whose check was always false,
  // so the last 2 nodes were never executed
  const callCount = graph!.graph.length - 2;
  const callCount = graph!.graph.length;
  t.is(
    executor.callCount,
    callCount,
@@ -717,7 +808,7 @@ test('should be able to run workflow', async t => {

  t.is(
    params.args[1].content,
    'generate text to text stream',
    'apple company',
    'graph params should correct'
  );
  t.is(

@@ -14,7 +14,7 @@ import {
  FeatureManagementService,
} from '../../src/core/features';
import { EventEmitter } from '../../src/fundamentals';
import { ConfigModule } from '../../src/fundamentals/config';
import { Config, ConfigModule } from '../../src/fundamentals/config';
import {
  CouponType,
  encodeLookupKey,
@@ -84,6 +84,7 @@ test.afterEach.always(async t => {

const PRO_MONTHLY = `${SubscriptionPlan.Pro}_${SubscriptionRecurring.Monthly}`;
const PRO_YEARLY = `${SubscriptionPlan.Pro}_${SubscriptionRecurring.Yearly}`;
const PRO_LIFETIME = `${SubscriptionPlan.Pro}_${SubscriptionRecurring.Lifetime}`;
const PRO_EA_YEARLY = `${SubscriptionPlan.Pro}_${SubscriptionRecurring.Yearly}_${SubscriptionPriceVariant.EA}`;
const AI_YEARLY = `${SubscriptionPlan.AI}_${SubscriptionRecurring.Yearly}`;
const AI_YEARLY_EA = `${SubscriptionPlan.AI}_${SubscriptionRecurring.Yearly}_${SubscriptionPriceVariant.EA}`;
@@ -105,6 +106,11 @@ const PRICES = {
    currency: 'usd',
    lookup_key: PRO_YEARLY,
  },
  [PRO_LIFETIME]: {
    unit_amount: 49900,
    currency: 'usd',
    lookup_key: PRO_LIFETIME,
  },
  [PRO_EA_YEARLY]: {
    recurring: {
      interval: 'year',
@@ -170,10 +176,9 @@ test('should list normal price for unauthenticated user', async t => {

  const prices = await service.listPrices();

  t.is(prices.length, 3);
  t.deepEqual(
    new Set(prices.map(p => p.lookup_key)),
    new Set([PRO_MONTHLY, PRO_YEARLY, AI_YEARLY])
    new Set([PRO_MONTHLY, PRO_YEARLY, PRO_LIFETIME, AI_YEARLY])
  );
});

@@ -190,10 +195,9 @@ test('should list normal prices for authenticated user', async t => {

  const prices = await service.listPrices(u1);

  t.is(prices.length, 3);
  t.deepEqual(
    new Set(prices.map(p => p.lookup_key)),
    new Set([PRO_MONTHLY, PRO_YEARLY, AI_YEARLY])
    new Set([PRO_MONTHLY, PRO_YEARLY, PRO_LIFETIME, AI_YEARLY])
  );
});

@@ -210,10 +214,9 @@ test('should list early access prices for pro ea user', async t => {

  const prices = await service.listPrices(u1);

  t.is(prices.length, 3);
  t.deepEqual(
    new Set(prices.map(p => p.lookup_key)),
    new Set([PRO_MONTHLY, PRO_EA_YEARLY, AI_YEARLY])
    new Set([PRO_MONTHLY, PRO_LIFETIME, PRO_EA_YEARLY, AI_YEARLY])
  );
});

@@ -246,10 +249,9 @@ test('should list normal prices for pro ea user with old subscriptions', async t

  const prices = await service.listPrices(u1);

  t.is(prices.length, 3);
  t.deepEqual(
    new Set(prices.map(p => p.lookup_key)),
    new Set([PRO_MONTHLY, PRO_YEARLY, AI_YEARLY])
    new Set([PRO_MONTHLY, PRO_YEARLY, PRO_LIFETIME, AI_YEARLY])
  );
});

@@ -266,10 +268,9 @@ test('should list early access prices for ai ea user', async t => {

  const prices = await service.listPrices(u1);

  t.is(prices.length, 3);
  t.deepEqual(
    new Set(prices.map(p => p.lookup_key)),
    new Set([PRO_MONTHLY, PRO_YEARLY, AI_YEARLY_EA])
    new Set([PRO_MONTHLY, PRO_YEARLY, PRO_LIFETIME, AI_YEARLY_EA])
  );
});

@@ -286,10 +287,9 @@ test('should list early access prices for pro and ai ea user', async t => {

  const prices = await service.listPrices(u1);

  t.is(prices.length, 3);
  t.deepEqual(
    new Set(prices.map(p => p.lookup_key)),
    new Set([PRO_MONTHLY, PRO_EA_YEARLY, AI_YEARLY_EA])
    new Set([PRO_MONTHLY, PRO_LIFETIME, PRO_EA_YEARLY, AI_YEARLY_EA])
  );
});

@@ -322,10 +322,9 @@ test('should list normal prices for ai ea user with old subscriptions', async t

  const prices = await service.listPrices(u1);

  t.is(prices.length, 3);
  t.deepEqual(
    new Set(prices.map(p => p.lookup_key)),
    new Set([PRO_MONTHLY, PRO_YEARLY, AI_YEARLY])
    new Set([PRO_MONTHLY, PRO_YEARLY, PRO_LIFETIME, AI_YEARLY])
  );
});

@@ -458,6 +457,22 @@ test('should get correct pro plan price for checking out', async t => {
      coupon: undefined,
    });
  }

  // any user, lifetime recurring
  {
    feature.isEarlyAccessUser.resolves(false);
    // @ts-expect-error stub
    subListStub.resolves({ data: [] });
    const ret = await getAvailablePrice(
      customer,
      SubscriptionPlan.Pro,
      SubscriptionRecurring.Lifetime
    );
    t.deepEqual(ret, {
      price: PRO_LIFETIME,
      coupon: undefined,
    });
  }
});

test('should get correct ai plan price for checking out', async t => {
@@ -639,6 +654,7 @@ test('should be able to create subscription', async t => {
    emitStub.calledOnceWith('user.subscription.activated', {
      userId: u1.id,
      plan: SubscriptionPlan.Pro,
      recurring: SubscriptionRecurring.Monthly,
    })
  );

@@ -674,6 +690,7 @@ test('should be able to update subscription', async t => {
    emitStub.calledOnceWith('user.subscription.activated', {
      userId: u1.id,
      plan: SubscriptionPlan.Pro,
      recurring: SubscriptionRecurring.Monthly,
    })
  );

@@ -706,6 +723,7 @@ test('should be able to delete subscription', async t => {
    emitStub.calledOnceWith('user.subscription.canceled', {
      userId: u1.id,
      plan: SubscriptionPlan.Pro,
      recurring: SubscriptionRecurring.Monthly,
    })
  );

@@ -749,6 +767,7 @@ test('should be able to cancel subscription', async t => {
    emitStub.calledOnceWith('user.subscription.activated', {
      userId: u1.id,
      plan: SubscriptionPlan.Pro,
      recurring: SubscriptionRecurring.Monthly,
    })
  );

@@ -785,6 +804,7 @@ test('should be able to resume subscription', async t => {
    emitStub.calledOnceWith('user.subscription.activated', {
      userId: u1.id,
      plan: SubscriptionPlan.Pro,
      recurring: SubscriptionRecurring.Monthly,
    })
  );

@@ -929,3 +949,159 @@ test('should operate with latest subscription status', async t => {
  t.deepEqual(stub.firstCall.args[1], sub);
  t.deepEqual(stub.secondCall.args[1], sub);
});

// ============== Lifetime Subscription ===============
const invoice: Stripe.Invoice = {
  id: 'in_xxx',
  object: 'invoice',
  amount_paid: 49900,
  total: 49900,
  customer: 'cus_1',
  currency: 'usd',
  status: 'paid',
  lines: {
    data: [
      {
        // @ts-expect-error stub
        price: PRICES[PRO_LIFETIME],
      },
    ],
  },
};

test('should not be able to checkout for lifetime recurring if not enabled', async t => {
  const { service, stripe, u1 } = t.context;

  Sinon.stub(stripe.subscriptions, 'list').resolves({ data: [] } as any);
  await t.throwsAsync(
    () =>
      service.createCheckoutSession({
        user: u1,
        plan: SubscriptionPlan.Pro,
        recurring: SubscriptionRecurring.Lifetime,
        redirectUrl: '',
        idempotencyKey: '',
      }),
    { message: 'You are not allowed to perform this action.' }
  );
});

test('should be able to checkout for lifetime recurring', async t => {
  const { service, stripe, u1, app } = t.context;
  const config = app.get(Config);
  await config.runtime.set('plugins.payment/showLifetimePrice', true);

  Sinon.stub(stripe.subscriptions, 'list').resolves({ data: [] } as any);
  Sinon.stub(stripe.prices, 'list').resolves({
    data: [PRICES[PRO_LIFETIME]],
  } as any);
  const sessionStub = Sinon.stub(stripe.checkout.sessions, 'create');

  await service.createCheckoutSession({
    user: u1,
    plan: SubscriptionPlan.Pro,
    recurring: SubscriptionRecurring.Lifetime,
    redirectUrl: '',
    idempotencyKey: '',
  });

  t.true(sessionStub.calledOnce);
});

test('should be able to subscribe to lifetime recurring', async t => {
  // a lifetime payment isn't a subscription, so the creation is triggered by the invoice payment event
  const { service, stripe, db, u1, event } = t.context;

  const emitStub = Sinon.stub(event, 'emit');
  Sinon.stub(stripe.invoices, 'retrieve').resolves(invoice as any);
  await service.saveInvoice(invoice, 'invoice.payment_succeeded');

  const subInDB = await db.userSubscription.findFirst({
    where: { userId: u1.id },
  });

  t.true(
    emitStub.calledOnceWith('user.subscription.activated', {
      userId: u1.id,
      plan: SubscriptionPlan.Pro,
      recurring: SubscriptionRecurring.Lifetime,
    })
  );
  t.is(subInDB?.plan, SubscriptionPlan.Pro);
  t.is(subInDB?.recurring, SubscriptionRecurring.Lifetime);
  t.is(subInDB?.status, SubscriptionStatus.Active);
  t.is(subInDB?.stripeSubscriptionId, null);
});

test('should be able to subscribe to lifetime recurring with old subscription', async t => {
  const { service, stripe, db, u1, event } = t.context;

  await db.userSubscription.create({
    data: {
      userId: u1.id,
      stripeSubscriptionId: 'sub_1',
      plan: SubscriptionPlan.Pro,
      recurring: SubscriptionRecurring.Monthly,
      status: SubscriptionStatus.Active,
      start: new Date(),
      end: new Date(),
    },
  });

  const emitStub = Sinon.stub(event, 'emit');
  Sinon.stub(stripe.invoices, 'retrieve').resolves(invoice as any);
  Sinon.stub(stripe.subscriptions, 'cancel').resolves(sub as any);
  await service.saveInvoice(invoice, 'invoice.payment_succeeded');

  const subInDB = await db.userSubscription.findFirst({
    where: { userId: u1.id },
  });

  t.true(
    emitStub.calledOnceWith('user.subscription.activated', {
      userId: u1.id,
      plan: SubscriptionPlan.Pro,
      recurring: SubscriptionRecurring.Lifetime,
    })
  );
  t.is(subInDB?.plan, SubscriptionPlan.Pro);
  t.is(subInDB?.recurring, SubscriptionRecurring.Lifetime);
  t.is(subInDB?.status, SubscriptionStatus.Active);
  t.is(subInDB?.stripeSubscriptionId, null);
});

test('should not be able to update lifetime recurring', async t => {
  const { service, db, u1 } = t.context;

  await db.userSubscription.create({
    data: {
      userId: u1.id,
      plan: SubscriptionPlan.Pro,
      recurring: SubscriptionRecurring.Lifetime,
      status: SubscriptionStatus.Active,
      start: new Date(),
      end: new Date(),
    },
  });

  await t.throwsAsync(
    () => service.cancelSubscription('', u1.id, SubscriptionPlan.Pro),
    { message: 'Lifetime subscription cannot be canceled.' }
  );

  await t.throwsAsync(
    () =>
      service.updateSubscriptionRecurring(
        '',
        u1.id,
        SubscriptionPlan.Pro,
        SubscriptionRecurring.Monthly
      ),
    { message: 'Can not update lifetime subscription.' }
  );

  await t.throwsAsync(
    () => service.resumeCanceledSubscription('', u1.id, SubscriptionPlan.Pro),
    { message: 'Lifetime subscription cannot be resumed.' }
  );
});

@@ -17,6 +17,7 @@ import {
  CopilotTextToEmbeddingProvider,
  CopilotTextToImageProvider,
  CopilotTextToTextProvider,
  PromptConfig,
  PromptMessage,
} from '../../src/plugins/copilot/types';
import { NodeExecutorType } from '../../src/plugins/copilot/workflow/executor';
@@ -174,6 +175,35 @@ export async function createCopilotSession(
  return res.body.data.createCopilotSession;
}

export async function forkCopilotSession(
  app: INestApplication,
  userToken: string,
  workspaceId: string,
  docId: string,
  sessionId: string,
  latestMessageId: string
): Promise<string> {
  const res = await request(app.getHttpServer())
    .post(gql)
    .auth(userToken, { type: 'bearer' })
    .set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
    .send({
      query: `
        mutation forkCopilotSession($options: ForkChatSessionInput!) {
          forkCopilotSession(options: $options)
        }
      `,
      variables: {
        options: { workspaceId, docId, sessionId, latestMessageId },
      },
    })
    .expect(200);

  handleGraphQLError(res);

  return res.body.data.forkCopilotSession;
}

export async function createCopilotMessage(
  app: INestApplication,
  userToken: string,
@@ -286,6 +316,7 @@ export function textToEventStream(
}

type ChatMessage = {
  id?: string;
  role: string;
  content: string;
  attachments: string[] | null;
@@ -333,6 +364,7 @@ export async function getHistories(
        action
        createdAt
        messages {
          id
          role
          content
          attachments
@@ -352,7 +384,12 @@ export async function getHistories(
  return res.body.data.currentUser?.copilot?.histories || [];
}

type Prompt = { name: string; model: string; messages: PromptMessage[] };
type Prompt = {
  name: string;
  model: string;
  messages: PromptMessage[];
  config?: PromptConfig;
};
type WorkflowTestCase = {
  graph: WorkflowGraph;
  prompts: Prompt[];

@@ -145,7 +145,14 @@ export function handleGraphQLError(resp: Response) {
  if (errors) {
    const cause = errors[0];
    const stacktrace = cause.extensions?.stacktrace;
    throw new Error(stacktrace ? stacktrace.join('\n') : cause.message, cause);
    throw new Error(
      stacktrace
        ? Array.isArray(stacktrace)
          ? stacktrace.join('\n')
          : String(stacktrace)
        : cause.message,
      cause
    );
  }
}

4
packages/common/env/package.json
vendored
@@ -3,8 +3,8 @@
  "private": true,
  "type": "module",
  "devDependencies": {
    "@blocksuite/global": "0.15.0-canary-202406280242-21d8d15",
    "@blocksuite/store": "0.15.0-canary-202406280242-21d8d15",
    "@blocksuite/global": "0.16.0-canary-202407050348-4620c21",
    "@blocksuite/store": "0.16.0-canary-202407050348-4620c21",
    "react": "18.3.1",
    "react-dom": "18.3.1",
    "vitest": "1.6.0"

1
packages/common/env/src/global.ts
vendored
@@ -24,6 +24,7 @@ export const runtimeFlagsSchema = z.object({
  enablePayment: z.boolean(),
  enablePageHistory: z.boolean(),
  enableExperimentalFeature: z.boolean(),
  enableInfoModal: z.boolean(),
  allowLocalWorkspace: z.boolean(),
  // this is for the electron app
  serverUrlPrefix: z.string(),

@@ -2,6 +2,7 @@
  "name": "@toeverything/infra",
  "type": "module",
  "private": true,
  "sideEffects": false,
  "exports": {
    "./blocksuite": "./src/blocksuite/index.ts",
    "./storage": "./src/storage/index.ts",
@@ -13,26 +14,30 @@
    "@affine/debug": "workspace:*",
    "@affine/env": "workspace:*",
    "@affine/templates": "workspace:*",
    "@blocksuite/blocks": "0.15.0-canary-202406280242-21d8d15",
    "@blocksuite/global": "0.15.0-canary-202406280242-21d8d15",
    "@blocksuite/store": "0.15.0-canary-202406280242-21d8d15",
    "@blocksuite/blocks": "0.16.0-canary-202407050348-4620c21",
    "@blocksuite/global": "0.16.0-canary-202407050348-4620c21",
    "@blocksuite/store": "0.16.0-canary-202407050348-4620c21",
    "@datastructures-js/binary-search-tree": "^5.3.2",
    "foxact": "^0.2.33",
    "fuse.js": "^7.0.0",
    "graphemer": "^1.4.0",
    "idb": "^8.0.0",
    "jotai": "^2.8.0",
    "jotai-effect": "^1.0.0",
    "lodash-es": "^4.17.21",
    "nanoid": "^5.0.7",
    "react": "18.3.1",
    "yjs": "^13.6.14",
    "yjs": "patch:yjs@npm%3A13.6.18#~/.yarn/patches/yjs-npm-13.6.18-ad0d5f7c43.patch",
    "zod": "^3.22.4"
  },
  "devDependencies": {
    "@affine-test/fixtures": "workspace:*",
    "@affine/templates": "workspace:*",
    "@blocksuite/block-std": "0.15.0-canary-202406280242-21d8d15",
    "@blocksuite/presets": "0.15.0-canary-202406280242-21d8d15",
    "@blocksuite/block-std": "0.16.0-canary-202407050348-4620c21",
    "@blocksuite/presets": "0.16.0-canary-202407050348-4620c21",
    "@testing-library/react": "^16.0.0",
    "async-call-rpc": "^6.4.0",
    "fake-indexeddb": "^6.0.0",
    "react": "^18.2.0",
    "rxjs": "^7.8.1",
    "vite": "^5.2.8",

@@ -92,6 +92,7 @@ export function setupEditorFlags(docCollection: DocCollection) {
    // override this flag in app settings
    // TODO(@eyhn): need a better way to manage block suite flags
    docCollection.awarenessStore.setFlag('enable_synced_doc_block', true);
    docCollection.awarenessStore.setFlag('enable_edgeless_text', true);
  } catch (err) {
    logger.error('syncEditorFlags', err);
  }

@@ -6,5 +6,6 @@ export * from './error';
export { createEvent, OnEvent } from './event';
export { Framework } from './framework';
export { createIdentifier } from './identifier';
export type { FrameworkProvider, ResolveOptions } from './provider';
export type { ResolveOptions } from './provider';
export { FrameworkProvider } from './provider';
export type { GeneralIdentifier } from './types';

@@ -9,6 +9,12 @@ export const FrameworkStackContext = React.createContext<FrameworkProvider[]>([
  Framework.EMPTY.provider(),
]);

export function useFramework(): FrameworkProvider {
  const stack = useContext(FrameworkStackContext);

  return stack[stack.length - 1]; // never null, because of the default value
}

export function useService<T extends Service>(
  identifier: GeneralIdentifier<T>
): T {

@@ -18,6 +18,7 @@ import {
  timer,
} from 'rxjs';

import { MANUALLY_STOP } from '../utils';
import type { LiveData } from './livedata';

/**
@@ -107,7 +108,8 @@ export function fromPromise<T>(
      .catch(error => {
        subscriber.error(error);
      });
    return () => abortController.abort('Aborted');

    return () => abortController.abort(MANUALLY_STOP);
  });
}

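Replacing the string reason `'Aborted'` with the shared `MANUALLY_STOP` sentinel lets consumers tell a deliberate unsubscribe apart from a real failure by identity instead of string matching. A sketch of the assumed consumer-side handling (not code from this diff):

import { MANUALLY_STOP } from '../utils';

async function runAbortable(work: Promise<unknown>) {
  try {
    await work;
  } catch (err) {
    if (err === MANUALLY_STOP) return; // the caller unsubscribed; not an error
    throw err;
  }
}
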
@@ -29,6 +29,23 @@ export class DocRecordList extends Entity {
    []
  );

  public readonly trashDocs$ = LiveData.from<DocRecord[]>(
    this.store.watchTrashDocIds().pipe(
      map(ids =>
        ids.map(id => {
          const exists = this.pool.get(id);
          if (exists) {
            return exists;
          }
          const record = this.framework.createEntity(DocRecord, { id });
          this.pool.set(id, record);
          return record;
        })
      )
    ),
    []
  );

  public readonly isReady$ = LiveData.from(
    this.store.watchDocListReady(),
    false

@@ -13,7 +13,6 @@ export type DocMode = 'edgeless' | 'page';
 */
export class DocRecord extends Entity<{ id: string }> {
  id: string = this.props.id;
  meta: Partial<DocMeta> | null = null;
  constructor(private readonly docsStore: DocsStore) {
    super();
  }
@@ -59,5 +58,6 @@ export class DocRecord extends Entity<{ id: string }> {
  }

  title$ = this.meta$.map(meta => meta.title ?? '');

  trash$ = this.meta$.map(meta => meta.trash ?? false);
}

@@ -1,6 +1,10 @@
import { Unreachable } from '@affine/env/constant';

import { Service } from '../../../framework';
import { initEmptyPage } from '../../../initialization';
import { ObjectPool } from '../../../utils';
import type { Doc } from '../entities/doc';
import type { DocMode } from '../entities/record';
import { DocRecordList } from '../entities/record-list';
import { DocScope } from '../scopes/doc';
import type { DocsStore } from '../stores/docs';
@@ -46,4 +50,22 @@ export class DocsService extends Service {

    return { doc: obj, release };
  }

  createDoc(
    options: {
      mode?: DocMode;
      title?: string;
    } = {}
  ) {
    const doc = this.store.createBlockSuiteDoc();
    initEmptyPage(doc, options.title);
    const docRecord = this.list.doc$(doc.id).value;
    if (!docRecord) {
      throw new Unreachable();
    }
    if (options.mode) {
      docRecord.setMode(options.mode);
    }
    return docRecord;
  }
}

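A hypothetical call site for the new `createDoc` API above; the wiring is illustrative rather than taken from the diff:

import type { DocsService } from '../services/doc'; // import path illustrative

// createDoc creates the BlockSuite doc, seeds it via initEmptyPage, and
// returns the matching DocRecord; the mode is applied when provided.
function createEdgelessDoc(docsService: DocsService): string {
  const record = docsService.createDoc({ mode: 'edgeless', title: 'Roadmap' });
  return record.id; // id of the newly created doc
}
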
@@ -18,6 +18,10 @@ export class DocsStore extends Store {
    return this.workspaceService.workspace.docCollection.getDoc(id);
  }

  createBlockSuiteDoc() {
    return this.workspaceService.workspace.docCollection.createDoc();
  }

  watchDocIds() {
    return new Observable<string[]>(subscriber => {
      const emit = () => {
@@ -37,7 +41,29 @@ export class DocsStore extends Store {
      return () => {
        dispose();
      };
    }).pipe(distinctUntilChanged((p, c) => isEqual(p, c)));
    });
  }

  watchTrashDocIds() {
    return new Observable<string[]>(subscriber => {
      const emit = () => {
        subscriber.next(
          this.workspaceService.workspace.docCollection.meta.docMetas
            .map(v => (v.trash ? v.id : null))
            .filter(Boolean) as string[]
        );
      };

      emit();

      const dispose =
        this.workspaceService.workspace.docCollection.meta.docMetaUpdated.on(
          emit
        ).dispose;
      return () => {
        dispose();
      };
    });
  }

  watchDocMeta(id: string) {

@@ -3,6 +3,7 @@ import type { Doc as YDoc } from 'yjs';
import { Entity } from '../../../framework';
import { AwarenessEngine, BlobEngine, DocEngine } from '../../../sync';
import { throwIfAborted } from '../../../utils';
import { WorkspaceEngineBeforeStart } from '../events';
import type { WorkspaceEngineProvider } from '../providers/flavour';
import type { WorkspaceService } from '../services/workspace';

@@ -33,6 +34,7 @@ export class WorkspaceEngine extends Entity<{
  }

  start() {
    this.eventBus.emit(WorkspaceEngineBeforeStart, this);
    this.doc.start();
    this.awareness.connect(this.workspaceService.workspace.awareness);
    this.blob.start();

@@ -0,0 +1,6 @@
import { createEvent } from '../../../framework';
import type { WorkspaceEngine } from '../entities/engine';

export const WorkspaceEngineBeforeStart = createEvent<WorkspaceEngine>(
  'WorkspaceEngineBeforeStart'
);
@@ -1,5 +1,6 @@
export type { WorkspaceProfileInfo } from './entities/profile';
export { Workspace } from './entities/workspace';
export { WorkspaceEngineBeforeStart } from './events';
export { globalBlockSuiteSchema } from './global-schema';
export type { WorkspaceMetadata } from './metadata';
export type { WorkspaceOpenOptions } from './open-options';

@@ -1,4 +0,0 @@
import { createORMClientType } from '../core';
import { AFFiNE_DB_SCHEMA } from './schema';

export const ORMClient = createORMClientType(AFFiNE_DB_SCHEMA);
@@ -1,21 +0,0 @@
import { ORMClient } from './client';

// The ORM hooks are used to define the transformers that will be applied to entities when they are loaded from the data providers.
// All transformations are done in memory; none of the data under the hood will be changed.
//
// for example:
//   data in providers: { color: 'red' }
//   hook: { color: 'red' } => { color: '#FF0000' }
//
// ORMClient.defineHook(
//   'demo',
//   'deprecate color field and introduce colors field',
//   {
//     deserialize(tag) {
//       tag.color = stringToHex(tag.color)
//       return tag;
//     },
//   }
// );

export { ORMClient };
@@ -1,3 +0,0 @@
import './hooks';

export { ORMClient } from './client';
@@ -1,17 +0,0 @@
import type { DBSchemaBuilder } from '../core';
// import { f } from './core';

export const AFFiNE_DB_SCHEMA = {
  // demo: {
  //   id: f.string().primaryKey().optional().default(nanoid),
  //   name: f.string(),
  //   // v1
  //   // color: f.string(),
  //   // v2, without a data-level breaking change
  //   /**
  //    * @deprecated use [colors]
  //    */
  //   color: f.string().optional(), // <= mark as optional since newly created records might only have the [colors] field
  //   colors: f.json<string[]>().optional(), // <= mark as optional since old records might only have the [color] field
  // },
} as const satisfies DBSchemaBuilder;
@@ -1,18 +1,12 @@
import { nanoid } from 'nanoid';
import {
  afterEach,
  beforeEach,
  describe,
  expect,
  test as t,
  type TestAPI,
} from 'vitest';
import { beforeEach, describe, expect, test as t, type TestAPI } from 'vitest';

import {
  createORMClientType,
  createORMClient,
  type DBSchemaBuilder,
  f,
  MemoryORMAdapter,
  type ORMClient,
  Table,
} from '../';

@@ -24,18 +18,12 @@ const TEST_SCHEMA = {
  },
} satisfies DBSchemaBuilder;

const Client = createORMClientType(TEST_SCHEMA);
type Context = {
  client: InstanceType<typeof Client>;
  client: ORMClient<typeof TEST_SCHEMA>;
};

beforeEach<Context>(async t => {
  t.client = new Client(new MemoryORMAdapter());
  await t.client.connect();
});

afterEach<Context>(async t => {
  await t.client.disconnect();
  t.client = createORMClient(TEST_SCHEMA, MemoryORMAdapter);
});

const test = t as TestAPI<Context>;

@@ -1,19 +1,13 @@
import { nanoid } from 'nanoid';
import {
  afterEach,
  beforeEach,
  describe,
  expect,
  test as t,
  type TestAPI,
} from 'vitest';
import { beforeEach, describe, expect, test as t, type TestAPI } from 'vitest';

import {
  createORMClientType,
  createORMClient,
  type DBSchemaBuilder,
  type Entity,
  f,
  MemoryORMAdapter,
  type ORMClient,
} from '../';

const TEST_SCHEMA = {
@@ -29,30 +23,23 @@ const TEST_SCHEMA = {
  },
} satisfies DBSchemaBuilder;

const Client = createORMClientType(TEST_SCHEMA);

// define the hooks
Client.defineHook('tags', 'migrate field `color` to field `colors`', {
  deserialize(data) {
    if (!data.colors && data.color) {
      data.colors = [data.color];
    }

    return data;
  },
});

type Context = {
  client: InstanceType<typeof Client>;
  client: ORMClient<typeof TEST_SCHEMA>;
};

beforeEach<Context>(async t => {
  t.client = new Client(new MemoryORMAdapter());
  await t.client.connect();
});
  t.client = createORMClient(TEST_SCHEMA, MemoryORMAdapter);

afterEach<Context>(async t => {
  await t.client.disconnect();
  // define the hooks
  t.client.defineHook('tags', 'migrate field `color` to field `colors`', {
    deserialize(data) {
      if (!data.colors && data.color) {
        data.colors = [data.color];
      }

      return data;
    },
  });
});

const test = t as TestAPI<Context>;
@@ -1,12 +1,21 @@
import { nanoid } from 'nanoid';
import { describe, expect, test } from 'vitest';

import { createORMClientType, f, MemoryORMAdapter } from '../';
import {
  createORMClient,
  type DBSchemaBuilder,
  f,
  MemoryORMAdapter,
} from '../';

function createClient<Schema extends DBSchemaBuilder>(schema: Schema) {
  return createORMClient(schema, MemoryORMAdapter);
}

describe('Schema validations', () => {
  test('primary key must be set', () => {
    expect(() =>
      createORMClientType({
      createClient({
        tags: {
          id: f.string(),
          name: f.string(),
@@ -19,7 +28,7 @@ describe('Schema validations', () => {

  test('primary key must be unique', () => {
    expect(() =>
      createORMClientType({
      createClient({
        tags: {
          id: f.string().primaryKey(),
          name: f.string().primaryKey(),
@@ -32,7 +41,7 @@ describe('Schema validations', () => {

  test('primary key should not be optional without default value', () => {
    expect(() =>
      createORMClientType({
      createClient({
        tags: {
          id: f.string().primaryKey().optional(),
          name: f.string(),
@@ -45,7 +54,7 @@ describe('Schema validations', () => {

  test('primary key can be optional with default value', async () => {
    expect(() =>
      createORMClientType({
      createClient({
        tags: {
          id: f.string().primaryKey().optional().default(nanoid),
          name: f.string(),
@@ -56,20 +65,18 @@ describe('Schema validations', () => {
});

describe('Entity validations', () => {
  const Client = createORMClientType({
    tags: {
      id: f.string().primaryKey().default(nanoid),
      name: f.string(),
      color: f.string(),
    },
  });

  function createClient() {
    return new Client(new MemoryORMAdapter());
  function createTagsClient() {
    return createClient({
      tags: {
        id: f.string().primaryKey().default(nanoid),
        name: f.string(),
        color: f.string(),
      },
    });
  }

  test('should not update primary key', () => {
    const client = createClient();
    const client = createTagsClient();

    const tag = client.tags.create({
      name: 'tag',
@@ -83,7 +90,7 @@ describe('Entity validations', () => {
  });

  test('should throw when trying to create entity with missing required field', () => {
    const client = createClient();
    const client = createTagsClient();

    // @ts-expect-error test
    expect(() => client.tags.create({ name: 'test' })).toThrow(
@@ -92,7 +99,7 @@ describe('Entity validations', () => {
  });

  test('should throw when trying to create entity with extra field', () => {
    const client = createClient();
    const client = createTagsClient();

    expect(() =>
      // @ts-expect-error test
@@ -101,34 +108,28 @@ describe('Entity validations', () => {
  });

  test('should throw when trying to create entity with unexpected field type', () => {
    const client = createClient();
    const client = createTagsClient();

    expect(() =>
      // @ts-expect-error test
      client.tags.create({ name: 'test', color: 123 })
    ).toThrow(
    // @ts-expect-error test
    expect(() => client.tags.create({ name: 'test', color: 123 })).toThrow(
      "[Table(tags)]: Field 'color' type mismatch. Expected type 'string' but got 'number'."
    );

    expect(() =>
      // @ts-expect-error test
      client.tags.create({ name: 'test', color: [123] })
    ).toThrow(
    // @ts-expect-error test
    expect(() => client.tags.create({ name: 'test', color: [123] })).toThrow(
      "[Table(tags)]: Field 'color' type mismatch. Expected type 'string' but got 'json'"
    );
  });

  test('should be able to assign `null` to json field', () => {
    expect(() => {
      const Client = createORMClientType({
      const client = createClient({
        tags: {
          id: f.string().primaryKey().default(nanoid),
          info: f.json(),
        },
      });

      const client = new Client(new MemoryORMAdapter());

      const tag = client.tags.create({ info: null });

      expect(tag.info).toBe(null);
@@ -14,9 +14,10 @@ import { DocEngine } from '../../../sync';
import { MiniSyncServer } from '../../../sync/doc/__tests__/utils';
import { MemoryStorage } from '../../../sync/doc/storage';
import {
  createORMClientType,
  createORMClient,
  type DBSchemaBuilder,
  f,
  type ORMClient,
  YjsDBAdapter,
} from '../';

@@ -29,27 +30,14 @@ const TEST_SCHEMA = {
  },
} satisfies DBSchemaBuilder;

const Client = createORMClientType(TEST_SCHEMA);

// define the hooks
Client.defineHook('tags', 'migrate field `color` to field `colors`', {
  deserialize(data) {
    if (!data.colors && data.color) {
      data.colors = [data.color];
    }

    return data;
  },
});

type Context = {
  server: MiniSyncServer;
  user1: {
    client: InstanceType<typeof Client>;
    client: ORMClient<typeof TEST_SCHEMA>;
    engine: DocEngine;
  };
  user2: {
    client: InstanceType<typeof Client>;
    client: ORMClient<typeof TEST_SCHEMA>;
    engine: DocEngine;
  };
};
@@ -60,16 +48,25 @@ function createEngine(server: MiniSyncServer) {

async function createClient(server: MiniSyncServer, clientId: number) {
  const engine = createEngine(server);
  const client = new Client(
    new YjsDBAdapter({
      getDoc(guid: string) {
        const doc = new Doc({ guid });
        doc.clientID = clientId;
        engine.addDoc(doc);
        return doc;
      },
    })
  );
  const client = createORMClient(TEST_SCHEMA, YjsDBAdapter, {
    getDoc(guid: string) {
      const doc = new Doc({ guid });
      doc.clientID = clientId;
      engine.addDoc(doc);
      return doc;
    },
  });

  // define the hooks
  client.defineHook('tags', 'migrate field `color` to field `colors`', {
    deserialize(data) {
      if (!data.colors && data.color) {
        data.colors = [data.color];
      }

      return data;
    },
  });

  return {
    engine,
@@ -85,14 +82,10 @@ beforeEach<Context>(async t => {
  t.user2 = await createClient(t.server, 2);

  t.user1.engine.start();
  await t.user1.client.connect();
  t.user2.engine.start();
  await t.user2.client.connect();
});

afterEach<Context>(async t => {
  t.user1.client.disconnect();
  t.user2.client.disconnect();
  t.user1.engine.stop();
  t.user2.engine.stop();
});
@@ -1,20 +1,14 @@
import { nanoid } from 'nanoid';
import {
  afterEach,
  beforeEach,
  describe,
  expect,
  test as t,
  type TestAPI,
} from 'vitest';
import { beforeEach, describe, expect, test as t, type TestAPI } from 'vitest';
import { Doc } from 'yjs';

import {
  createORMClientType,
  createORMClient,
  type DBSchemaBuilder,
  type DocProvider,
  type Entity,
  f,
  type ORMClient,
  Table,
  YjsDBAdapter,
} from '../';
@@ -33,18 +27,12 @@ const docProvider: DocProvider = {
  },
};

const Client = createORMClientType(TEST_SCHEMA);
type Context = {
  client: InstanceType<typeof Client>;
  client: ORMClient<typeof TEST_SCHEMA>;
};

beforeEach<Context>(async t => {
  t.client = new Client(new YjsDBAdapter(docProvider));
  await t.client.connect();
});

afterEach<Context>(async t => {
  await t.client.disconnect();
  t.client = createORMClient(TEST_SCHEMA, YjsDBAdapter, docProvider);
});

const test = t as TestAPI<Context>;
@@ -223,15 +211,13 @@ describe('ORM entity CRUD', () => {
  });

  test('can not use reserved keyword as field name', () => {
    const Client = createORMClientType({
    const schema = {
      tags: {
        $$KEY: f.string().primaryKey().default(nanoid),
      },
    });
    };

    expect(() =>
      new Client(new YjsDBAdapter(docProvider)).connect()
    ).rejects.toThrow(
    expect(() => createORMClient(schema, YjsDBAdapter, docProvider)).toThrow(
      "[Table(tags)]: Field '$$KEY' is reserved keyword and can't be used"
    );
  });
@@ -1,16 +1,7 @@
import type { DBSchemaBuilder } from '../../schema';
import type { DBAdapter } from '../types';
import { MemoryTableAdapter } from './table';

export class MemoryORMAdapter implements DBAdapter {
  connect(_db: DBSchemaBuilder): Promise<void> {
    return Promise.resolve();
  }

  disconnect(_db: DBSchemaBuilder): Promise<void> {
    return Promise.resolve();
  }

  table(tableName: string) {
    return new MemoryTableAdapter(tableName);
  }
@@ -1,4 +1,4 @@
import type { DBSchemaBuilder, TableSchemaBuilder } from '../schema';
import type { TableSchemaBuilder } from '../schema';

export interface Key {
  toString(): string;
@@ -21,8 +21,5 @@ export interface TableAdapter<K extends Key = any, T = unknown> {
}

export interface DBAdapter {
  connect(db: DBSchemaBuilder): Promise<void>;
  disconnect(db: DBSchemaBuilder): Promise<void>;

  table(tableName: string): TableAdapter;
}
@@ -11,25 +11,16 @@ export interface DocProvider {

export class YjsDBAdapter implements DBAdapter {
  tables: Map<string, TableAdapter> = new Map();
  constructor(private readonly provider: DocProvider) {}

  connect(db: DBSchemaBuilder): Promise<void> {
  constructor(
    db: DBSchemaBuilder,
    private readonly provider: DocProvider
  ) {
    for (const [tableName, table] of Object.entries(db)) {
      validators.validateYjsTableSchema(tableName, table);
      const doc = this.provider.getDoc(tableName);

      this.tables.set(tableName, new YjsTableAdapter(tableName, doc));
    }

    return Promise.resolve();
  }

  disconnect(_db: DBSchemaBuilder): Promise<void> {
    this.tables.forEach(table => {
      table.dispose();
    });
    this.tables.clear();
    return Promise.resolve();
  }

  table(tableName: string) {
@@ -1,10 +1,10 @@
import { type DBAdapter, type Hook } from './adapters';
import type { DBSchemaBuilder } from './schema';
import { type CreateEntityInput, Table, type TableMap } from './table';
import { Table, type TableMap } from './table';
import { validators } from './validators';

export class ORMClient {
  static hooksMap: Map<string, Hook<any>[]> = new Map();
class RawORMClient {
  hooksMap: Map<string, Hook<any>[]> = new Map();
  private readonly tables = new Map<string, Table<any>>();
  constructor(
    protected readonly db: DBSchemaBuilder,
@@ -17,7 +17,7 @@ export class ORMClient {
      if (!table) {
        table = new Table(this.adapter, tableName, {
          schema: tableSchema,
          hooks: ORMClient.hooksMap.get(tableName),
          hooks: this.hooksMap.get(tableName),
        });
        this.tables.set(tableName, table);
      }
@@ -27,7 +27,7 @@ export class ORMClient {
    });
  }

  static defineHook(tableName: string, _desc: string, hook: Hook<any>) {
  defineHook(tableName: string, _desc: string, hook: Hook<any>) {
    let hooks = this.hooksMap.get(tableName);
    if (!hooks) {
      hooks = [];
@@ -36,48 +36,30 @@ export class ORMClient {

    hooks.push(hook);
  }

  async connect() {
    await this.adapter.connect(this.db);
  }

  async disconnect() {
    await this.adapter.disconnect(this.db);
  }
}

export function createORMClientType<Schema extends DBSchemaBuilder>(
  db: Schema
): ORMClientWithTablesClass<Schema> {
export function createORMClient<
  const Schema extends DBSchemaBuilder,
  AdapterConstructor extends new (...args: any[]) => DBAdapter,
  AdapterConstructorParams extends
    any[] = ConstructorParameters<AdapterConstructor> extends [
      DBSchemaBuilder,
      ...infer Args,
    ]
      ? Args
      : never,
>(
  db: Schema,
  adapter: AdapterConstructor,
  ...args: AdapterConstructorParams
): ORMClient<Schema> {
  Object.entries(db).forEach(([tableName, schema]) => {
    validators.validateTableSchema(tableName, schema);
  });

  class ORMClientWithTables extends ORMClient {
    constructor(adapter: DBAdapter) {
      super(db, adapter);
    }
  }

  return ORMClientWithTables as {
    new (
      ...args: ConstructorParameters<typeof ORMClientWithTables>
    ): ORMClient & TableMap<Schema>;

    defineHook<TableName extends keyof Schema>(
      tableName: TableName,
      desc: string,
      hook: Hook<CreateEntityInput<Schema[TableName]>>
    ): void;
  };
  return new RawORMClient(db, new adapter(db, ...args)) as TableMap<Schema> &
    RawORMClient;
}

export type ORMClientWithTablesClass<Schema extends DBSchemaBuilder> = {
  new (adapter: DBAdapter): TableMap<Schema> & ORMClient;

  defineHook<TableName extends keyof Schema>(
    tableName: TableName,
    desc: string,
    hook: Hook<CreateEntityInput<Schema[TableName]>>
  ): void;
};
export type ORMClient<Schema extends DBSchemaBuilder> = RawORMClient &
  TableMap<Schema>;
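In short, this hunk replaces the `createORMClientType` class factory with a direct `createORMClient(schema, AdapterConstructor, ...adapterArgs)` call, and moves `defineHook` from a static to an instance method. A minimal sketch of the new call shape (the `tags` schema here is illustrative only, not from this PR):

```typescript
import { nanoid } from 'nanoid';

import { createORMClient, f, MemoryORMAdapter } from './';

// Illustrative schema; any table shape built with `f` works the same way.
const schema = {
  tags: {
    id: f.string().primaryKey().default(nanoid),
    name: f.string(),
  },
};

// The adapter instance is constructed internally from the given
// constructor and trailing args, so no `new Client(...)` step remains.
const client = createORMClient(schema, MemoryORMAdapter);

// Hooks are now registered per instance instead of per generated class.
client.defineHook('tags', 'example hook', {
  deserialize(data) {
    return data;
  },
});
```
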
@@ -1 +0,0 @@
export * from './affine';
@@ -4,6 +4,7 @@ import { difference } from 'lodash-es';

import { LiveData } from '../../livedata';
import type { Memento } from '../../storage';
import { MANUALLY_STOP } from '../../utils';
import { BlobStorageOverCapacity } from './error';

const logger = new DebugLogger('affine:blob-engine');
@@ -70,7 +71,7 @@ export class BlobEngine {
  }

  stop() {
    this.abort?.abort();
    this.abort?.abort(MANUALLY_STOP);
    this.abort = null;
  }

@@ -23,6 +23,7 @@ export {
} from './storage';

export class DocEngine {
  readonly clientId: string;
  localPart: DocEngineLocalPart;
  remotePart: DocEngineRemotePart | null;

@@ -80,11 +81,11 @@ export class DocEngine {
    storage: DocStorage,
    private readonly server?: DocServer | null
  ) {
    const clientId = nanoid();
    this.clientId = nanoid();
    this.storage = new DocStorageInner(storage);
    this.localPart = new DocEngineLocalPart(clientId, this.storage);
    this.localPart = new DocEngineLocalPart(this.clientId, this.storage);
    this.remotePart = this.server
      ? new DocEngineRemotePart(clientId, this.storage, this.server)
      ? new DocEngineRemotePart(this.clientId, this.storage, this.server)
      : null;
  }

24 packages/common/infra/src/sync/doc/old-id.md Normal file
@@ -0,0 +1,24 @@
AFFiNE still has a lot of data stored under the old ID format. This note records where each ID format is used so it is not forgotten.

## Old ID Format

The format is:

- `{workspace-id}:space:{nanoid}` Common
- `{workspace-id}:space:page:{nanoid}`

> Note: sometimes the `workspace-id` is not the same as the current workspace id.

## Usage

- Local Storage
  - indexeddb: Both new and old IDs coexist
  - sqlite: Both new and old IDs coexist
  - server-clock: Only new IDs are stored
  - sync-metadata: Both new and old IDs coexist
- Server Storage
  - Only stores new IDs but accepts writes using old IDs
- Protocols
  - When the client submits an update, both new and old IDs are used.
  - When the server broadcasts updates sent by other clients, both new and old IDs are used.
  - When the server responds to `client-pre-sync` (listing all updated doc ids), only new IDs are used.
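A minimal sketch (a hypothetical helper, not part of this change) of recognizing the old format described above and extracting its trailing `nanoid`:

```typescript
// Hypothetical helper for illustration only; the real normalization
// logic in the codebase may differ.
const OLD_DOC_ID = /^[^:]+:space:(?:page:)?([^:]+)$/;

function normalizeDocId(raw: string): string {
  const match = OLD_DOC_ID.exec(raw);
  // Old-format IDs reduce to the trailing nanoid; new IDs pass through.
  return match ? match[1] : raw;
}

normalizeDocId('ws-1:space:page:abc123'); // => 'abc123'
normalizeDocId('ws-1:space:abc123'); // => 'abc123'
normalizeDocId('abc123'); // => 'abc123'
```
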
@@ -4,3 +4,11 @@ export type { BlobStatus, BlobStorage } from './blob/blob';
export { BlobEngine, EmptyBlobStorage } from './blob/blob';
export { BlobStorageOverCapacity } from './blob/error';
export * from './doc';
export * from './indexer';
export {
  IndexedDBIndex,
  IndexedDBIndexStorage,
} from './indexer/impl/indexeddb';
export { MemoryIndex, MemoryIndexStorage } from './indexer/impl/memory';
export * from './job';
export { IndexedDBJobQueue } from './job/impl/indexeddb';
147 packages/common/infra/src/sync/indexer/README.md Normal file
@@ -0,0 +1,147 @@
# index

Search engine abstraction layer for AFFiNE.

## Usage

1. Define a schema

First, define the shape of the data. The following field types are currently supported:

- 'Integer'
- 'Boolean'
- 'FullText': for full-text search; values are tokenized and stemmed.
- 'String': for exact-match search, e.g. tags, ids.

```typescript
const schema = defineSchema({
  title: 'FullText',
  tag: 'String',
  size: 'Integer',
});
```

> **Array type**
> All types can contain one or more values, so each field can store an array. This design conforms to the usual conventions of search engine APIs, such as Elasticsearch: https://www.elastic.co/guide/en/elasticsearch/reference/current/array.html

2. Pick a backend

Two backends are currently available:

- `MemoryIndex`: in-memory indexer, useful for testing.
- `IndexedDBIndex`: persistent indexer backed by IndexedDB.

> **Underlying Data Table**
> Some backends need to maintain underlying data tables, including table creation and migration. This is executed silently the first time the indexer is invoked; callers do not need to worry about these details.
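For illustration, both backends share the same constructor shape (as exercised in `black-box.spec.ts`, linked at the end of this README), so picking one is a one-line decision:

```typescript
import { IndexedDBIndex } from './impl/indexeddb';
import { MemoryIndex } from './impl/memory';

// Both implement the same `Index` interface, so tests can swap
// MemoryIndex in for the persistent IndexedDBIndex.
const index = new IndexedDBIndex(schema);
const testIndex = new MemoryIndex(schema);
```
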
3. Write data

To write data to the index, first start a write batch with `await index.write()`, then complete the batch with `await writer.commit()`.

> **Transactional**
> Typically, the indexer does not provide transactional guarantees; reliable locking logic needs to be implemented at a higher level.

```typescript
const index = new IndexedDBIndex(schema);

const writer = await index.write();
writer.insert(
  Document.from('id', {
    title: 'hello world',
    tag: ['doc', 'page'],
    size: '100',
  })
);
await writer.commit();
```

4. Search data

To search for content in the index, you use a small **query language**. Here are some examples:

```typescript
// match title == 'hello world'
{
  type: 'match',
  field: 'title',
  match: 'hello world',
}

// match title == 'hello world' && tag == 'doc'
{
  type: 'boolean',
  occur: 'must',
  queries: [
    {
      type: 'match',
      field: 'title',
      match: 'hello world',
    },
    {
      type: 'match',
      field: 'tag',
      match: 'doc',
    },
  ],
}
```

There are two ways to run a query, `index.search()` and `index.aggregate()`:

- **search**: returns each matched node plus pagination information.
- **aggregate**: groups all matched results into buckets by a given field, and returns the count and score of the items in each bucket.

Examples:

```typescript
const result = await index.search({
  type: 'match',
  field: 'title',
  match: 'hello world',
});
// result = {
//   nodes: [
//     {
//       id: '1',
//       score: 1,
//     },
//   ],
//   pagination: {
//     count: 1,
//     hasMore: false,
//     limit: 10,
//     skip: 0,
//   },
// }
```

```typescript
const result = await index.aggregate(
  {
    type: 'match',
    field: 'title',
    match: 'affine',
  },
  'tag'
);
// result = {
//   buckets: [
//     { key: 'motorcycle', count: 2, score: 1 },
//     { key: 'bike', count: 1, score: 1 },
//     { key: 'airplane', count: 1, score: 1 },
//   ],
//   pagination: {
//     count: 3,
//     hasMore: false,
//     limit: 10,
//     skip: 0,
//   },
// }
```

More uses:

[black-box.spec.ts](./__tests__/black-box.spec.ts)
@@ -0,0 +1,554 @@
/**
 * @vitest-environment happy-dom
 */
import 'fake-indexeddb/auto';

import { map } from 'rxjs';
import { beforeEach, describe, expect, test, vitest } from 'vitest';

import { defineSchema, Document, type Index } from '..';
import { IndexedDBIndex } from '../impl/indexeddb';
import { MemoryIndex } from '../impl/memory';

const schema = defineSchema({
  title: 'FullText',
  tag: 'String',
  size: 'Integer',
});

let index: Index<typeof schema> = null!;

describe.each([
  { name: 'memory', backend: MemoryIndex },
  { name: 'idb', backend: IndexedDBIndex },
])('index tests($name)', ({ backend }) => {
  async function writeData(
    data: Record<
      string,
      Partial<Record<keyof typeof schema, string | string[]>>
    >
  ) {
    const writer = await index.write();
    for (const [id, item] of Object.entries(data)) {
      const doc = new Document(id);
      for (const [key, value] of Object.entries(item)) {
        if (Array.isArray(value)) {
          for (const v of value) {
            doc.insert(key, v);
          }
        } else {
          doc.insert(key, value);
        }
      }
      writer.insert(doc);
    }
    await writer.commit();
  }

  beforeEach(async () => {
    index = new backend(schema);
    await index.clear();
  });

  test('basic', async () => {
    await writeData({
      '1': {
        title: 'hello world',
      },
    });

    const result = await index.search({
      type: 'match',
      field: 'title',
      match: 'hello world',
    });

    expect(result).toEqual({
      nodes: [
        {
          id: '1',
          score: expect.anything(),
        },
      ],
      pagination: {
        count: 1,
        hasMore: false,
        limit: expect.anything(),
        skip: 0,
      },
    });
  });

  test('basic integer', async () => {
    await writeData({
      '1': {
        title: 'hello world',
        size: '100',
      },
    });

    const result = await index.search({
      type: 'match',
      field: 'size',
      match: '100',
    });

    expect(result).toEqual({
      nodes: [
        {
          id: '1',
          score: expect.anything(),
        },
      ],
      pagination: {
        count: 1,
        hasMore: false,
        limit: expect.anything(),
        skip: 0,
      },
    });
  });

  test('fuzz', async () => {
    await writeData({
      '1': {
        title: 'hello world',
      },
    });
    const result = await index.search({
      type: 'match',
      field: 'title',
      match: 'hell',
    });

    expect(result).toEqual({
      nodes: [
        {
          id: '1',
          score: expect.anything(),
        },
      ],
      pagination: {
        count: 1,
        hasMore: false,
        limit: expect.anything(),
        skip: 0,
      },
    });
  });

  test('highlight', async () => {
    await writeData({
      '1': {
        title: 'hello world',
        size: '100',
      },
    });

    const result = await index.search(
      {
        type: 'match',
        field: 'title',
        match: 'hello',
      },
      {
        highlights: [
          {
            field: 'title',
            before: '<b>',
            end: '</b>',
          },
        ],
      }
    );

    expect(result).toEqual({
      nodes: expect.arrayContaining([
        {
          id: '1',
          score: expect.anything(),
          highlights: {
            title: [expect.stringContaining('<b>hello</b>')],
          },
        },
      ]),
      pagination: {
        count: 1,
        hasMore: false,
        limit: expect.anything(),
        skip: 0,
      },
    });
  });

  test('fields', async () => {
    await writeData({
      '1': {
        title: 'hello world',
        tag: ['car', 'bike'],
      },
    });

    const result = await index.search(
      {
        type: 'match',
        field: 'title',
        match: 'hello',
      },
      {
        fields: ['title', 'tag'],
      }
    );

    expect(result.nodes[0].fields).toEqual({
      title: 'hello world',
      tag: expect.arrayContaining(['bike', 'car']),
    });
  });

  test('pagination', async () => {
    await writeData(
      Array.from({ length: 100 }).reduce((acc: any, _, i) => {
        acc['apple' + i] = {
          tag: ['apple'],
        };
        return acc;
      }, {}) as any
    );

    const result = await index.search(
      {
        type: 'match',
        field: 'tag',
        match: 'apple',
      },
      {
        pagination: {
          skip: 0,
          limit: 10,
        },
      }
    );

    expect(result).toEqual({
      nodes: expect.arrayContaining(
        Array.from({ length: 10 }).fill({
          id: expect.stringContaining('apple'),
          score: expect.anything(),
        })
      ),
      pagination: {
        count: 100,
        hasMore: true,
        limit: 10,
        skip: 0,
      },
    });

    const result2 = await index.search(
      {
        type: 'match',
        field: 'tag',
        match: 'apple',
      },
      {
        pagination: {
          skip: 10,
          limit: 10,
        },
      }
    );

    expect(result2).toEqual({
      nodes: expect.arrayContaining(
        Array.from({ length: 10 }).fill({
          id: expect.stringContaining('apple'),
          score: expect.anything(),
        })
      ),
      pagination: {
        count: 100,
        hasMore: true,
        limit: 10,
        skip: 10,
      },
    });
  });

  test('aggr', async () => {
    await writeData({
      '1': {
        title: 'hello world',
        tag: ['car', 'bike'],
      },
      affine1: {
        title: 'affine',
        tag: ['motorcycle', 'bike'],
      },
      affine2: {
        title: 'affine',
        tag: ['motorcycle', 'airplane'],
      },
    });

    const result = await index.aggregate(
      {
        type: 'match',
        field: 'title',
        match: 'affine',
      },
      'tag'
    );

    expect(result).toEqual({
      buckets: expect.arrayContaining([
        { key: 'motorcycle', count: 2, score: expect.anything() },
        { key: 'bike', count: 1, score: expect.anything() },
        { key: 'airplane', count: 1, score: expect.anything() },
      ]),
      pagination: {
        count: 3,
        hasMore: false,
        limit: expect.anything(),
        skip: 0,
      },
    });
  });

  test('hits', async () => {
    await writeData(
      Array.from({ length: 100 }).reduce((acc: any, _, i) => {
        acc['apple' + i] = {
          title: 'apple',
          tag: ['apple', 'fruit'],
        };
        return acc;
      }, {}) as any
    );
    const result = await index.aggregate(
      {
        type: 'match',
        field: 'title',
        match: 'apple',
      },
      'tag',
      {
        hits: {
          pagination: {
            skip: 0,
            limit: 5,
          },
          highlights: [
            {
              field: 'title',
              before: '<b>',
              end: '</b>',
            },
          ],
          fields: ['title', 'tag'],
        },
      }
    );

    expect(result).toEqual({
      buckets: expect.arrayContaining([
        {
          key: 'apple',
          count: 100,
          score: expect.anything(),
          hits: {
            pagination: {
              count: 100,
              hasMore: true,
              limit: 5,
              skip: 0,
            },
            nodes: expect.arrayContaining(
              Array.from({ length: 5 }).fill({
                id: expect.stringContaining('apple'),
                score: expect.anything(),
                highlights: {
                  title: [expect.stringContaining('<b>apple</b>')],
                },
                fields: {
                  title: expect.stringContaining('apple'),
                  tag: expect.arrayContaining(['apple', 'fruit']),
                },
              })
            ),
          },
        },
        {
          key: 'fruit',
          count: 100,
          score: expect.anything(),
          hits: {
            pagination: {
              count: 100,
              hasMore: true,
              limit: 5,
              skip: 0,
            },
            nodes: expect.arrayContaining(
              Array.from({ length: 5 }).fill({
                id: expect.stringContaining('apple'),
                score: expect.anything(),
                highlights: {
                  title: [expect.stringContaining('<b>apple</b>')],
                },
                fields: {
                  title: expect.stringContaining('apple'),
                  tag: expect.arrayContaining(['apple', 'fruit']),
                },
              })
            ),
          },
        },
      ]),
      pagination: {
        count: 2,
        hasMore: false,
        limit: expect.anything(),
        skip: 0,
      },
    });
  });

  test('exists', async () => {
    await writeData({
      '1': {
        title: 'hello world',
        tag: '111',
      },
      '2': {
        tag: '222',
      },
      '3': {
        title: 'hello world',
        tag: '333',
      },
    });

    const result = await index.search({
      type: 'exists',
      field: 'title',
    });

    expect(result).toEqual({
      nodes: expect.arrayContaining([
        {
          id: '1',
          score: expect.anything(),
        },
        {
          id: '3',
          score: expect.anything(),
        },
      ]),
      pagination: {
        count: 2,
        hasMore: false,
        limit: expect.anything(),
        skip: 0,
      },
    });
  });

  test('subscribe', async () => {
    await writeData({
      '1': {
        title: 'hello world',
      },
    });

    let value = null as any;
    index
      .search$({
        type: 'match',
        field: 'title',
        match: 'hello world',
      })
      .pipe(map(v => (value = v)))
      .subscribe();

    await vitest.waitFor(
      () => {
        expect(value).toEqual({
          nodes: [
            {
              id: '1',
              score: expect.anything(),
            },
          ],
          pagination: {
            count: 1,
            hasMore: false,
            limit: expect.anything(),
            skip: 0,
          },
        });
      },
      {
        timeout: 5000,
      }
    );

    await writeData({
      '2': {
        title: 'hello world',
      },
    });

    await vitest.waitFor(
      () => {
        expect(value).toEqual({
          nodes: [
            {
              id: '1',
              score: expect.anything(),
            },
            {
              id: '2',
              score: expect.anything(),
            },
          ],
          pagination: {
            count: 2,
            hasMore: false,
            limit: expect.anything(),
            skip: 0,
          },
        });
      },
      {
        timeout: 5000,
      }
    );

    const writer = await index.write();
    writer.delete('1');
    await writer.commit();

    await vitest.waitFor(
      () => {
        expect(value).toEqual({
          nodes: [
            {
              id: '2',
              score: expect.anything(),
            },
          ],
          pagination: {
            count: 1,
            hasMore: false,
            limit: expect.anything(),
            skip: 0,
          },
        });
      },
      {
        timeout: 5000,
      }
    );
  });
});
48 packages/common/infra/src/sync/indexer/document.ts Normal file
@@ -0,0 +1,48 @@
import type { Schema } from './schema';

export class Document<S extends Schema = any> {
  constructor(public readonly id: string) {}

  fields = new Map<keyof S, string[]>();

  public insert<F extends keyof S>(field: F, value: string | string[]) {
    const values = this.fields.get(field) ?? [];
    if (Array.isArray(value)) {
      values.push(...value);
    } else {
      values.push(value);
    }
    this.fields.set(field, values);
  }

  get<F extends keyof S>(field: F): string[] | string | undefined {
    const values = this.fields.get(field);
    if (values === undefined) {
      return undefined;
    } else if (values.length === 1) {
      return values[0];
    } else {
      return values;
    }
  }

  static from<S extends Schema>(
    id: string,
    map:
      | Partial<Record<keyof S, string | string[]>>
      | Map<keyof S, string | string[]>
  ): Document<S> {
    const doc = new Document(id);

    if (map instanceof Map) {
      for (const [key, value] of map) {
        doc.insert(key, value);
      }
    } else {
      for (const key in map) {
        doc.insert(key, map[key] as string | string[]);
      }
    }
    return doc;
  }
}
1 packages/common/infra/src/sync/indexer/field-type.ts Normal file
@@ -0,0 +1 @@
export type FieldType = 'Integer' | 'FullText' | 'String' | 'Boolean';
@@ -0,0 +1,9 @@
import { expect, test } from 'vitest';

import { bm25 } from '../bm25';

test('bm25', () => {
  expect(bm25(1, 1, 10, 10, 15)).toEqual(3.2792079793859643);
  expect(bm25(2, 1, 10, 10, 15) > bm25(1, 1, 10, 10, 15)).toBeTruthy();
  expect(bm25(1, 1, 10, 10, 15) > bm25(2, 1, 10, 100, 15)).toBeTruthy();
});
@@ -0,0 +1,32 @@
import { expect, test } from 'vitest';

import { highlighter } from '../highlighter';

test('highlighter', () => {
  expect(highlighter('0123456789', '<b>', '</b>', [[3, 5]])).toEqual(
    '012<b>34</b>56789'
  );

  expect(
    highlighter(
      '012345678901234567890123456789012345678901234567890123456789',
      '<b>',
      '</b>',
      [[59, 60]]
    )
  ).toEqual('...0123456789012345678901234567890123456789012345678<b>9</b>');

  expect(
    highlighter(
      '012345678901234567890123456789012345678901234567890123456789',
      '<b>',
      '</b>',
      [
        [10, 11],
        [49, 51],
      ]
    )
  ).toEqual(
    '0123456789<b>0</b>12345678901234567890123456789012345678<b>9</b>...'
  );
});
@@ -0,0 +1,128 @@
import { expect, test } from 'vitest';

import { GeneralTokenizer } from '../tokenizer';

test('tokenizer', () => {
  {
    const tokens = new GeneralTokenizer().tokenize('hello world,\n AFFiNE');

    expect(tokens).toEqual([
      { term: 'hello', start: 0, end: 5 },
      { term: 'world', start: 7, end: 12 },
      { term: 'affine', start: 15, end: 21 },
    ]);
  }

  {
    const tokens = new GeneralTokenizer().tokenize('你好世界，阿芬');

    expect(tokens).toEqual([
      {
        end: 2,
        start: 0,
        term: '你好',
      },
      {
        end: 3,
        start: 1,
        term: '好世',
      },
      {
        end: 4,
        start: 2,
        term: '世界',
      },
      {
        end: 7,
        start: 5,
        term: '阿芬',
      },
    ]);
  }

  {
    const tokens = new GeneralTokenizer().tokenize('1阿2芬');

    expect(tokens).toEqual([
      { term: '1', start: 0, end: 1 },
      { term: '阿', start: 1, end: 2 },
      { term: '2', start: 2, end: 3 },
      { term: '芬', start: 3, end: 4 },
    ]);
  }

  {
    const tokens = new GeneralTokenizer().tokenize('안녕하세요 세계');

    expect(tokens).toEqual([
      {
        end: 2,
        start: 0,
        term: '안녕',
      },
      {
        end: 3,
        start: 1,
        term: '녕하',
      },
      {
        end: 4,
        start: 2,
        term: '하세',
      },
      {
        end: 5,
        start: 3,
        term: '세요',
      },
      {
        end: 8,
        start: 6,
        term: '세계',
      },
    ]);
  }

  {
    const tokens = new GeneralTokenizer().tokenize('ハローワールド');

    expect(tokens).toEqual([
      { term: 'ハロ', start: 0, end: 2 },
      { term: 'ロー', start: 1, end: 3 },
      { term: 'ーワ', start: 2, end: 4 },
      { term: 'ワー', start: 3, end: 5 },
      { term: 'ール', start: 4, end: 6 },
      { term: 'ルド', start: 5, end: 7 },
    ]);
  }

  {
    const tokens = new GeneralTokenizer().tokenize('はろーわーるど');

    expect(tokens).toEqual([
      { term: 'はろ', start: 0, end: 2 },
      { term: 'ろー', start: 1, end: 3 },
      { term: 'ーわ', start: 2, end: 4 },
      { term: 'わー', start: 3, end: 5 },
      { term: 'ーる', start: 4, end: 6 },
      { term: 'るど', start: 5, end: 7 },
    ]);
  }

  {
    const tokens = new GeneralTokenizer().tokenize('👋1️⃣🚪👋🏿');

    expect(tokens).toEqual([
      { term: '👋', start: 0, end: 2 },
      { term: '1️⃣', start: 2, end: 5 },
      { term: '🚪', start: 5, end: 7 },
      { term: '👋🏿', start: 7, end: 11 },
    ]);
  }

  {
    const tokens = new GeneralTokenizer().tokenize('1️');

    expect(tokens).toEqual([{ term: '1️', start: 0, end: 2 }]);
  }
});
@@ -0,0 +1,62 @@
/**
 * Parameters of the BM25+ scoring algorithm. Customizing these is almost never
 * necessary, and finetuning them requires an understanding of the BM25 scoring
 * model.
 *
 * Some information about BM25 (and BM25+) can be found at these links:
 *
 * - https://en.wikipedia.org/wiki/Okapi_BM25
 * - https://opensourceconnections.com/blog/2015/10/16/bm25-the-next-generation-of-lucene-relevation/
 */
export type BM25Params = {
  /** Term frequency saturation point.
   *
   * Recommended values are between `1.2` and `2`. Higher values increase the
   * difference in score between documents with higher and lower term
   * frequencies. Setting this to `0` or a negative value is invalid. Defaults
   * to `1.2`
   */
  k: number;

  /**
   * Length normalization impact.
   *
   * Recommended values are around `0.75`. Higher values increase the weight
   * that field length has on scoring. Setting this to `0` (not recommended)
   * means that the field length has no effect on scoring. Negative values are
   * invalid. Defaults to `0.7`.
   */
  b: number;

  /**
   * BM25+ frequency normalization lower bound (usually called δ).
   *
   * Recommended values are between `0.5` and `1`. Increasing this parameter
   * increases the minimum relevance of one occurrence of a search term
   * regardless of its (possibly very long) field length. Negative values are
   * invalid. Defaults to `0.5`.
   */
  d: number;
};

const defaultBM25params: BM25Params = { k: 1.2, b: 0.7, d: 0.5 };

export const bm25 = (
  termFreq: number,
  matchingCount: number,
  totalCount: number,
  fieldLength: number,
  avgFieldLength: number,
  bm25params: BM25Params = defaultBM25params
): number => {
  const { k, b, d } = bm25params;
  const invDocFreq = Math.log(
    1 + (totalCount - matchingCount + 0.5) / (matchingCount + 0.5)
  );
  return (
    invDocFreq *
    (d +
      (termFreq * (k + 1)) /
        (termFreq + k * (1 - b + (b * fieldLength) / avgFieldLength)))
  );
};
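As a sanity check on the constant asserted in `__tests__/bm25.spec.ts` above, the defaults `k = 1.2`, `b = 0.7`, `d = 0.5` give, for `bm25(1, 1, 10, 10, 15)` (my arithmetic, not from the PR):

```typescript
// invDocFreq = ln(1 + (10 - 1 + 0.5) / (1 + 0.5)) = ln(7.333…) ≈ 1.992430
// lengthNorm = k * (1 - b + b * 10 / 15) = 1.2 * 0.7666…     ≈ 0.92
// tfTerm     = (1 * (k + 1)) / (1 + 0.92) = 2.2 / 1.92        ≈ 1.145833
// score      = 1.992430 * (d + 1.145833)                      ≈ 3.279208
// …which matches the asserted 3.2792079793859643.
```
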
@@ -0,0 +1,465 @@
import {
  type DBSchema,
  type IDBPDatabase,
  type IDBPTransaction,
  openDB,
  type StoreNames,
} from 'idb';

import {
  type AggregateOptions,
  type AggregateResult,
  Document,
  type Query,
  type Schema,
  type SearchOptions,
  type SearchResult,
} from '../../';
import { highlighter } from './highlighter';
import {
  BooleanInvertedIndex,
  FullTextInvertedIndex,
  IntegerInvertedIndex,
  type InvertedIndex,
  StringInvertedIndex,
} from './inverted-index';
import { Match } from './match';

export interface IndexDB extends DBSchema {
  kvMetadata: {
    key: string;
    value: {
      key: string;
      value: any;
    };
  };
  records: {
    key: number;
    value: {
      id: string;
      data: Map<string, string[]>;
    };
    indexes: { id: string };
  };
  invertedIndex: {
    key: number;
    value: {
      nid: number;
      pos?: {
        i: number /* index */;
        l: number /* length */;
        rs: [number, number][] /* ranges: [start, end] */;
      };
      key: ArrayBuffer;
    };
    indexes: { key: ArrayBuffer; nid: number };
  };
}

export type DataStructRWTransaction = IDBPTransaction<
  IndexDB,
  ArrayLike<StoreNames<IndexDB>>,
  'readwrite'
>;

export type DataStructROTransaction = IDBPTransaction<
  IndexDB,
  ArrayLike<StoreNames<IndexDB>>,
  'readonly' | 'readwrite'
>;

export class DataStruct {
  private initializePromise: Promise<void> | null = null;
  database: IDBPDatabase<IndexDB> = null as any;
  invertedIndex = new Map<string, InvertedIndex>();

  constructor(
    readonly databaseName: string,
    schema: Schema
  ) {
    for (const [key, type] of Object.entries(schema)) {
      if (type === 'String') {
        this.invertedIndex.set(key, new StringInvertedIndex(key));
      } else if (type === 'Integer') {
        this.invertedIndex.set(key, new IntegerInvertedIndex(key));
      } else if (type === 'FullText') {
        this.invertedIndex.set(key, new FullTextInvertedIndex(key));
      } else if (type === 'Boolean') {
        this.invertedIndex.set(key, new BooleanInvertedIndex(key));
      } else {
        throw new Error(`Field type '${type}' not supported`);
      }
    }
  }

  async insert(trx: DataStructRWTransaction, document: Document) {
    const exists = await trx
      .objectStore('records')
      .index('id')
      .get(document.id);

    if (exists) {
      throw new Error('Document already exists');
    }

    const nid = await trx.objectStore('records').add({
      id: document.id,
      data: new Map(document.fields as Map<string, string[]>),
    });

    for (const [key, values] of document.fields) {
      const iidx = this.invertedIndex.get(key as string);
      if (!iidx) {
        throw new Error(
          `Inverted index '${key.toString()}' not found, document not match schema`
        );
      }
      await iidx.insert(trx, nid, values);
    }
  }

  async delete(trx: DataStructRWTransaction, id: string) {
    const nid = await trx.objectStore('records').index('id').getKey(id);

    if (nid) {
      await trx.objectStore('records').delete(nid);
    }

    const indexIds = await trx
      .objectStore('invertedIndex')
      .index('nid')
      .getAllKeys(nid);
    for (const indexId of indexIds) {
      await trx.objectStore('invertedIndex').delete(indexId);
    }
  }

  async batchWrite(
    trx: DataStructRWTransaction,
    deletes: string[],
    inserts: Document[]
  ) {
    for (const del of deletes) {
      await this.delete(trx, del);
    }
    for (const inst of inserts) {
      await this.insert(trx, inst);
    }
  }

  async matchAll(trx: DataStructROTransaction): Promise<Match> {
    const allNids = await trx.objectStore('records').getAllKeys();
    const match = new Match();

    for (const nid of allNids) {
      match.addScore(nid, 1);
    }
    return match;
  }

  private async queryRaw(
    trx: DataStructROTransaction,
    query: Query<any>
  ): Promise<Match> {
    if (query.type === 'match') {
      const iidx = this.invertedIndex.get(query.field as string);
      if (!iidx) {
        throw new Error(`Field '${query.field as string}' not found`);
      }
      return await iidx.match(trx, query.match);
    } else if (query.type === 'boolean') {
      const weights = [];
      for (const q of query.queries) {
        weights.push(await this.queryRaw(trx, q));
      }
      if (query.occur === 'must') {
        return weights.reduce((acc, w) => acc.and(w));
      } else if (query.occur === 'must_not') {
        const total = weights.reduce((acc, w) => acc.and(w));
        return (await this.matchAll(trx)).exclude(total);
      } else if (query.occur === 'should') {
        return weights.reduce((acc, w) => acc.or(w));
      }
    } else if (query.type === 'all') {
      return await this.matchAll(trx);
    } else if (query.type === 'boost') {
      return (await this.queryRaw(trx, query.query)).boost(query.boost);
    } else if (query.type === 'exists') {
      const iidx = this.invertedIndex.get(query.field as string);
      if (!iidx) {
        throw new Error(`Field '${query.field as string}' not found`);
      }
      return await iidx.all(trx);
    }
    throw new Error(`Query type '${query.type}' not supported`);
  }

  private async query(
    trx: DataStructROTransaction,
    query: Query<any>
  ): Promise<Match> {
    const match = await this.queryRaw(trx, query);
    const filteredMatch = match.asyncFilter(async nid => {
      const record = await trx.objectStore('records').getKey(nid);
      return record !== undefined;
    });
    return filteredMatch;
  }

  async clear(trx: DataStructRWTransaction) {
    await trx.objectStore('records').clear();
    await trx.objectStore('invertedIndex').clear();
    await trx.objectStore('kvMetadata').clear();
  }

  async search(
    trx: DataStructROTransaction,
    query: Query<any>,
    options: SearchOptions<any>
  ): Promise<SearchResult<any, any>> {
    const pagination = {
      skip: options.pagination?.skip ?? 0,
      limit: options.pagination?.limit ?? 100,
    };

    const match = await this.query(trx, query);

    const nids = match
      .toArray()
      .slice(pagination.skip, pagination.skip + pagination.limit);

    const nodes = [];
    for (const nid of nids) {
      nodes.push(await this.resultNode(trx, match, nid, options));
    }

    return {
      pagination: {
        count: match.size(),
        hasMore: match.size() > pagination.limit + pagination.skip,
        limit: pagination.limit,
        skip: pagination.skip,
      },
      nodes: nodes,
    };
  }

  async aggregate(
    trx: DataStructROTransaction,
    query: Query<any>,
    field: string,
    options: AggregateOptions<any>
  ): Promise<AggregateResult<any, any>> {
    const pagination = {
      skip: options.pagination?.skip ?? 0,
      limit: options.pagination?.limit ?? 100,
    };

    const hitPagination = options.hits
      ? {
          skip: options.hits.pagination?.skip ?? 0,
          limit: options.hits.pagination?.limit ?? 3,
        }
      : {
          skip: 0,
          limit: 0,
        };

    const match = await this.query(trx, query);

    const nids = match.toArray();

    const buckets: {
      key: string;
      nids: number[];
      hits: SearchResult<any, any>['nodes'];
    }[] = [];

    for (const nid of nids) {
      const values = (await trx.objectStore('records').get(nid))?.data.get(
        field
      );
      for (const value of values ?? []) {
        let bucket;
        let bucketIndex = buckets.findIndex(b => b.key === value);
        if (bucketIndex === -1) {
          bucket = { key: value, nids: [], hits: [] };
          buckets.push(bucket);
          bucketIndex = buckets.length - 1;
        } else {
          bucket = buckets[bucketIndex];
        }

        if (
          bucketIndex >= pagination.skip &&
          bucketIndex < pagination.skip + pagination.limit
        ) {
          bucket.nids.push(nid);
          if (
            bucket.nids.length - 1 >= hitPagination.skip &&
            bucket.nids.length - 1 < hitPagination.skip + hitPagination.limit
          ) {
            bucket.hits.push(
              await this.resultNode(trx, match, nid, options.hits ?? {})
            );
          }
        }
      }
    }

    return {
      buckets: buckets
        .slice(pagination.skip, pagination.skip + pagination.limit)
        .map(bucket => {
          const result = {
            key: bucket.key,
            score: match.getScore(bucket.nids[0]),
            count: bucket.nids.length,
          } as AggregateResult<any, any>['buckets'][number];

          if (options.hits) {
            (result as any).hits = {
              pagination: {
                count: bucket.nids.length,
                hasMore:
                  bucket.nids.length > hitPagination.limit + hitPagination.skip,
                limit: hitPagination.limit,
                skip: hitPagination.skip,
              },
              nodes: bucket.hits,
            } as SearchResult<any, any>;
          }

          return result;
        }),
      pagination: {
        count: buckets.length,
        hasMore: buckets.length > pagination.limit + pagination.skip,
        limit: pagination.limit,
        skip: pagination.skip,
      },
    };
  }

  async getAll(
    trx: DataStructROTransaction,
    ids: string[]
  ): Promise<Document[]> {
    const docs = [];
    for (const id of ids) {
      const record = await trx.objectStore('records').index('id').get(id);
      if (record) {
        docs.push(Document.from(record.id, record.data));
      }
    }

    return docs;
  }

  async has(trx: DataStructROTransaction, id: string): Promise<boolean> {
    const nid = await trx.objectStore('records').index('id').getKey(id);
    return nid !== undefined;
  }

  async readonly() {
    await this.ensureInitialized();
    return this.database.transaction(
      ['records', 'invertedIndex', 'kvMetadata'],
      'readonly'
    );
  }

  async readwrite() {
    await this.ensureInitialized();
    return this.database.transaction(
      ['records', 'invertedIndex', 'kvMetadata'],
      'readwrite'
    );
  }

  private async ensureInitialized() {
    if (this.database) {
      return;
    }
    this.initializePromise ??= this.initialize();
    await this.initializePromise;
  }

  private async initialize() {
    this.database = await openDB<IndexDB>(this.databaseName, 1, {
      upgrade(database) {
        database.createObjectStore('kvMetadata', {
          keyPath: 'key',
        });
        const recordsStore = database.createObjectStore('records', {
          autoIncrement: true,
        });
        recordsStore.createIndex('id', 'id', {
          unique: true,
        });
        const invertedIndexStore = database.createObjectStore('invertedIndex', {
          autoIncrement: true,
        });
        invertedIndexStore.createIndex('key', 'key', { unique: false });
        invertedIndexStore.createIndex('nid', 'nid', { unique: false });
      },
    });
  }

  private async resultNode(
    trx: DataStructROTransaction,
    match: Match,
    nid: number,
    options: SearchOptions<any>
  ): Promise<SearchResult<any, any>['nodes'][number]> {
    const record = await trx.objectStore('records').get(nid);
    if (!record) {
      throw new Error(`Record not found for nid ${nid}`);
    }

    const node = {
      id: record.id,
      score: match.getScore(nid),
    } as any;

    if (options.fields) {
      const fields = {} as Record<string, string | string[]>;
      for (const field of options.fields as string[]) {
        fields[field] = record.data.get(field) ?? [''];
        if (fields[field].length === 1) {
          fields[field] = fields[field][0];
        }
      }
      node.fields = fields;
    }

    if (options.highlights) {
      const highlights = {} as Record<string, string[]>;
      for (const { field, before, end } of options.highlights) {
        const highlightValues = match.getHighlighters(nid, field);
        if (highlightValues) {
          const rawValues = record.data.get(field) ?? [];
          highlights[field] = Array.from(highlightValues)
            .map(([index, ranges]) => {
              const raw = rawValues[index];

              if (raw) {
                return (
                  highlighter(raw, before, end, ranges, {
                    maxPrefix: 20,
                    maxLength: 50,
                  }) ?? ''
                );
              }

              return '';
            })
            .filter(Boolean);
        }
      }
      node.highlights = highlights;
    }

    return node;
  }
}
@@ -0,0 +1,77 @@
export function highlighter(
  originText: string,
  before: string,
  after: string,
  matches: [number, number][],
  {
    maxLength = 50,
    maxPrefix = 20,
  }: { maxLength?: number; maxPrefix?: number } = {}
) {
  const merged = mergeRanges(matches);

  if (merged.length === 0) {
    return null;
  }

  const firstMatch = merged[0][0];
  const start = Math.max(
    0,
    Math.min(firstMatch - maxPrefix, originText.length - maxLength)
  );
  const end = Math.min(start + maxLength, originText.length);
  const text = originText.substring(start, end);

  let result = '';

  let pointer = 0;
  for (const match of merged) {
    const matchStart = match[0] - start;
    const matchEnd = match[1] - start;
    if (matchStart >= text.length) {
      break;
    }
    result += text.substring(pointer, matchStart);
    pointer = matchStart;
    const highlighted = text.substring(matchStart, matchEnd);

    if (highlighted.length === 0) {
      continue;
    }

    result += `${before}${highlighted}${after}`;
    pointer = matchEnd;
  }
  result += text.substring(pointer);

  if (start > 0) {
    result = `...${result}`;
  }

  if (end < originText.length) {
    result = `${result}...`;
  }

  return result;
}

function mergeRanges(intervals: [number, number][]) {
  if (intervals.length === 0) return [];

  intervals.sort((a, b) => a[0] - b[0]);

  const merged = [intervals[0]];

  for (let i = 1; i < intervals.length; i++) {
    const last = merged[merged.length - 1];
    const current = intervals[i];

    if (current[0] <= last[1]) {
      last[1] = Math.max(last[1], current[1]);
    } else {
      merged.push(current);
    }
  }

  return merged;
}
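A quick usage sketch of the `highlighter` helper above, on made-up input (the exact window depends on `maxPrefix`/`maxLength`; ranges are `[start, end)` offsets into the original string, and overlapping ranges are merged by `mergeRanges` before markup is applied):

```ts
const snippet = highlighter(
  'The quick brown fox jumps over the lazy dog',
  '<b>',
  '</b>',
  [
    [4, 9], // 'quick'
    [10, 15], // 'brown'
  ],
  { maxPrefix: 2, maxLength: 20 }
);
// => '...e <b>quick</b> <b>brown</b> fox ju...'
// A 20-char window starting 2 chars before the first match, with '...'
// marking both truncated sides.
```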
171
packages/common/infra/src/sync/indexer/impl/indexeddb/index.ts
Normal file
@@ -0,0 +1,171 @@
import type { Observable } from 'rxjs';
import { from, merge, of, Subject, throttleTime } from 'rxjs';

import { exhaustMapWithTrailing } from '../../../../utils/exhaustmap-with-trailing';
import {
  type AggregateOptions,
  type AggregateResult,
  type Document,
  type Index,
  type IndexStorage,
  type IndexWriter,
  type Query,
  type Schema,
  type SearchOptions,
  type SearchResult,
} from '../../';
import { DataStruct, type DataStructRWTransaction } from './data-struct';

export class IndexedDBIndex<S extends Schema> implements Index<S> {
  data: DataStruct = new DataStruct(this.databaseName, this.schema);
  broadcast$ = new Subject();

  constructor(
    private readonly schema: S,
    private readonly databaseName: string = 'indexer'
  ) {
    const channel = new BroadcastChannel(this.databaseName + ':indexer');
    channel.onmessage = () => {
      this.broadcast$.next(1);
    };
  }

  async get(id: string): Promise<Document<S> | null> {
    return (await this.getAll([id]))[0] ?? null;
  }

  async getAll(ids: string[]): Promise<Document<S>[]> {
    const trx = await this.data.readonly();
    return this.data.getAll(trx, ids);
  }

  async write(): Promise<IndexWriter<S>> {
    return new IndexedDBIndexWriter(this.data, await this.data.readwrite());
  }

  async has(id: string): Promise<boolean> {
    const trx = await this.data.readonly();
    return this.data.has(trx, id);
  }

  async search(
    query: Query<any>,
    options: SearchOptions<any> = {}
  ): Promise<SearchResult<any, SearchOptions<any>>> {
    const trx = await this.data.readonly();
    return this.data.search(trx, query, options);
  }

  search$(
    query: Query<any>,
    options: SearchOptions<any> = {}
  ): Observable<SearchResult<any, SearchOptions<any>>> {
    return merge(of(1), this.broadcast$).pipe(
      throttleTime(500, undefined, { leading: true, trailing: true }),
      exhaustMapWithTrailing(() => {
        return from(
          (async () => {
            const trx = await this.data.readonly();
            return this.data.search(trx, query, options);
          })()
        );
      })
    );
  }

  async aggregate(
    query: Query<any>,
    field: string,
    options: AggregateOptions<any> = {}
  ): Promise<AggregateResult<any, AggregateOptions<any>>> {
    const trx = await this.data.readonly();
    return this.data.aggregate(trx, query, field, options);
  }

  aggregate$(
    query: Query<any>,
    field: string,
    options: AggregateOptions<any> = {}
  ): Observable<AggregateResult<S, AggregateOptions<any>>> {
    return merge(of(1), this.broadcast$).pipe(
      throttleTime(500, undefined, { leading: true, trailing: true }),
      exhaustMapWithTrailing(() => {
        return from(
          (async () => {
            const trx = await this.data.readonly();
            return this.data.aggregate(trx, query, field, options);
          })()
        );
      })
    );
  }

  async clear(): Promise<void> {
    const trx = await this.data.readwrite();
    return this.data.clear(trx);
  }
}

export class IndexedDBIndexWriter<S extends Schema> implements IndexWriter<S> {
  inserts: Document[] = [];
  deletes: string[] = [];
  channel = new BroadcastChannel(this.data.databaseName + ':indexer');

  constructor(
    private readonly data: DataStruct,
    private readonly trx: DataStructRWTransaction
  ) {}

  async get(id: string): Promise<Document<S> | null> {
    return (await this.getAll([id]))[0] ?? null;
  }

  async getAll(ids: string[]): Promise<Document<S>[]> {
    const trx = await this.data.readonly();
    return this.data.getAll(trx, ids);
  }

  insert(document: Document): void {
    this.inserts.push(document);
  }
  delete(id: string): void {
    this.deletes.push(id);
  }
  put(document: Document): void {
    this.delete(document.id);
    this.insert(document);
  }

  async commit(): Promise<void> {
    await this.data.batchWrite(this.trx, this.deletes, this.inserts);
    this.channel.postMessage(1);
  }

  rollback(): void {}

  has(id: string): Promise<boolean> {
    return this.data.has(this.trx, id);
  }

  async search(
    query: Query<any>,
    options: SearchOptions<any> = {}
  ): Promise<SearchResult<any, SearchOptions<any>>> {
    return this.data.search(this.trx, query, options);
  }

  async aggregate(
    query: Query<any>,
    field: string,
    options: AggregateOptions<any> = {}
  ): Promise<AggregateResult<any, AggregateOptions<any>>> {
    return this.data.aggregate(this.trx, query, field, options);
  }
}

export class IndexedDBIndexStorage implements IndexStorage {
  constructor(private readonly databaseName: string) {}
  getIndex<S extends Schema>(name: string, s: S): Index<S> {
    return new IndexedDBIndex(s, this.databaseName + ':' + name);
  }
}
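For orientation, a hedged sketch of the intended write flow through `IndexedDBIndexWriter` (the storage name, index name, schema, and document values are invented; `Document.from(id, fields)` is assumed from its use in `data-struct.ts`):

```ts
const storage = new IndexedDBIndexStorage('affine-indexer');
const index = storage.getIndex('docs', { title: 'FullText' });

const writer = await index.write();
writer.put(Document.from('doc-1', new Map([['title', ['hello world']]])));
writer.delete('doc-2');
// commit() applies the buffered deletes + inserts via batchWrite inside the
// open readwrite transaction, then posts to BroadcastChannel so search$
// subscriptions in other tabs re-run.
await writer.commit();
```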
@@ -0,0 +1,429 @@
|
||||
import { bm25 } from './bm25';
|
||||
import type {
|
||||
DataStructROTransaction,
|
||||
DataStructRWTransaction,
|
||||
} from './data-struct';
|
||||
import { Match } from './match';
|
||||
import { GeneralTokenizer, type Token } from './tokenizer';
|
||||
|
||||
export interface InvertedIndex {
|
||||
fieldKey: string;
|
||||
|
||||
match(trx: DataStructROTransaction, term: string): Promise<Match>;
|
||||
|
||||
all(trx: DataStructROTransaction): Promise<Match>;
|
||||
|
||||
insert(
|
||||
trx: DataStructRWTransaction,
|
||||
id: number,
|
||||
terms: string[]
|
||||
): Promise<void>;
|
||||
}
|
||||
|
||||
export class StringInvertedIndex implements InvertedIndex {
|
||||
constructor(readonly fieldKey: string) {}
|
||||
|
||||
async match(trx: DataStructROTransaction, term: string): Promise<Match> {
|
||||
const objs = await trx
|
||||
.objectStore('invertedIndex')
|
||||
.index('key')
|
||||
.getAll(InvertedIndexKey.forString(this.fieldKey, term).buffer());
|
||||
const match = new Match();
|
||||
for (const obj of objs) {
|
||||
match.addScore(obj.nid, 1);
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
async all(trx: DataStructROTransaction): Promise<Match> {
|
||||
const objs = await trx
|
||||
.objectStore('invertedIndex')
|
||||
.index('key')
|
||||
.getAll(
|
||||
IDBKeyRange.bound(
|
||||
InvertedIndexKey.forPrefix(this.fieldKey).buffer(),
|
||||
InvertedIndexKey.forPrefix(this.fieldKey).add1().buffer()
|
||||
)
|
||||
);
|
||||
|
||||
const set = new Set<number>();
|
||||
for (const obj of objs) {
|
||||
set.add(obj.nid);
|
||||
}
|
||||
|
||||
const match = new Match();
|
||||
for (const nid of set) {
|
||||
match.addScore(nid, 1);
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
async insert(trx: DataStructRWTransaction, id: number, terms: string[]) {
|
||||
for (const term of terms) {
|
||||
await trx.objectStore('invertedIndex').add({
|
||||
key: InvertedIndexKey.forString(this.fieldKey, term).buffer(),
|
||||
nid: id,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class IntegerInvertedIndex implements InvertedIndex {
|
||||
constructor(readonly fieldKey: string) {}
|
||||
|
||||
async match(trx: DataStructROTransaction, term: string): Promise<Match> {
|
||||
const objs = await trx
|
||||
.objectStore('invertedIndex')
|
||||
.index('key')
|
||||
.getAll(InvertedIndexKey.forInt64(this.fieldKey, BigInt(term)).buffer());
|
||||
const match = new Match();
|
||||
for (const obj of objs) {
|
||||
match.addScore(obj.nid, 1);
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line sonarjs/no-identical-functions
|
||||
async all(trx: DataStructROTransaction): Promise<Match> {
|
||||
const objs = await trx
|
||||
.objectStore('invertedIndex')
|
||||
.index('key')
|
||||
.getAll(
|
||||
IDBKeyRange.bound(
|
||||
InvertedIndexKey.forPrefix(this.fieldKey).buffer(),
|
||||
InvertedIndexKey.forPrefix(this.fieldKey).add1().buffer()
|
||||
)
|
||||
);
|
||||
|
||||
const set = new Set<number>();
|
||||
for (const obj of objs) {
|
||||
set.add(obj.nid);
|
||||
}
|
||||
|
||||
const match = new Match();
|
||||
for (const nid of set) {
|
||||
match.addScore(nid, 1);
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
async insert(trx: DataStructRWTransaction, id: number, terms: string[]) {
|
||||
for (const term of terms) {
|
||||
await trx.objectStore('invertedIndex').add({
|
||||
key: InvertedIndexKey.forInt64(this.fieldKey, BigInt(term)).buffer(),
|
||||
nid: id,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class BooleanInvertedIndex implements InvertedIndex {
|
||||
constructor(readonly fieldKey: string) {}
|
||||
|
||||
// eslint-disable-next-line sonarjs/no-identical-functions
|
||||
async all(trx: DataStructROTransaction): Promise<Match> {
|
||||
const objs = await trx
|
||||
.objectStore('invertedIndex')
|
||||
.index('key')
|
||||
.getAll(
|
||||
IDBKeyRange.bound(
|
||||
InvertedIndexKey.forPrefix(this.fieldKey).buffer(),
|
||||
InvertedIndexKey.forPrefix(this.fieldKey).add1().buffer()
|
||||
)
|
||||
);
|
||||
|
||||
const set = new Set<number>();
|
||||
for (const obj of objs) {
|
||||
set.add(obj.nid);
|
||||
}
|
||||
|
||||
const match = new Match();
|
||||
for (const nid of set) {
|
||||
match.addScore(nid, 1);
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
async match(trx: DataStructROTransaction, term: string): Promise<Match> {
|
||||
const objs = await trx
|
||||
.objectStore('invertedIndex')
|
||||
.index('key')
|
||||
.getAll(
|
||||
InvertedIndexKey.forBoolean(this.fieldKey, term === 'true').buffer()
|
||||
);
|
||||
const match = new Match();
|
||||
for (const obj of objs) {
|
||||
match.addScore(obj.nid, 1);
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
async insert(trx: DataStructRWTransaction, id: number, terms: string[]) {
|
||||
for (const term of terms) {
|
||||
await trx.objectStore('invertedIndex').add({
|
||||
key: InvertedIndexKey.forBoolean(
|
||||
this.fieldKey,
|
||||
term === 'true'
|
||||
).buffer(),
|
||||
nid: id,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class FullTextInvertedIndex implements InvertedIndex {
|
||||
constructor(readonly fieldKey: string) {}
|
||||
|
||||
async match(trx: DataStructROTransaction, term: string): Promise<Match> {
|
||||
const queryTokens = new GeneralTokenizer().tokenize(term);
|
||||
const matched = new Map<
|
||||
number,
|
||||
{
|
||||
score: number[];
|
||||
positions: Map<number, [number, number][]>;
|
||||
}
|
||||
>();
|
||||
for (const token of queryTokens) {
|
||||
const key = InvertedIndexKey.forString(this.fieldKey, token.term);
|
||||
const objs = await trx
|
||||
.objectStore('invertedIndex')
|
||||
.index('key')
|
||||
.getAll(
|
||||
IDBKeyRange.bound(key.buffer(), key.add1().buffer(), false, true)
|
||||
);
|
||||
const submatched: {
|
||||
nid: number;
|
||||
score: number;
|
||||
position: {
|
||||
index: number;
|
||||
ranges: [number, number][];
|
||||
};
|
||||
}[] = [];
|
||||
for (const obj of objs) {
|
||||
const key = InvertedIndexKey.fromBuffer(obj.key);
|
||||
const originTokenTerm = key.asString();
|
||||
const matchLength = token.term.length;
|
||||
const position = obj.pos ?? {
|
||||
i: 0,
|
||||
l: 0,
|
||||
rs: [],
|
||||
};
|
||||
const termFreq = position.rs.length;
|
||||
const totalCount = objs.length;
|
||||
const avgFieldLength =
|
||||
(
|
||||
await trx
|
||||
.objectStore('kvMetadata')
|
||||
.get(`full-text:avg-field-length:${this.fieldKey}`)
|
||||
)?.value ?? 0;
|
||||
const fieldLength = position.l;
|
||||
const score =
|
||||
bm25(termFreq, 1, totalCount, fieldLength, avgFieldLength) *
|
||||
(matchLength / originTokenTerm.length);
|
||||
const match = {
|
||||
score,
|
||||
positions: new Map(),
|
||||
};
|
||||
const ranges = match.positions.get(position.i) || [];
|
||||
ranges.push(
|
||||
...position.rs.map(([start, _end]) => [start, start + matchLength])
|
||||
);
|
||||
match.positions.set(position.i, ranges);
|
||||
submatched.push({
|
||||
nid: obj.nid,
|
||||
score,
|
||||
position: {
|
||||
index: position.i,
|
||||
ranges: position.rs.map(([start, _end]) => [
|
||||
start,
|
||||
start + matchLength,
|
||||
]),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// normalize score
|
||||
const maxScore = submatched.reduce((acc, s) => Math.max(acc, s.score), 0);
|
||||
const minScore = submatched.reduce((acc, s) => Math.min(acc, s.score), 1);
|
||||
for (const { nid, score, position } of submatched) {
|
||||
const normalizedScore = (score - minScore) / (maxScore - minScore);
|
||||
const match = matched.get(nid) || {
|
||||
score: [] as number[],
|
||||
positions: new Map(),
|
||||
};
|
||||
match.score.push(normalizedScore);
|
||||
const ranges = match.positions.get(position.index) || [];
|
||||
ranges.push(...position.ranges);
|
||||
match.positions.set(position.index, ranges);
|
||||
matched.set(nid, match);
|
||||
}
|
||||
}
|
||||
const match = new Match();
|
||||
for (const [nid, { score, positions }] of matched) {
|
||||
match.addScore(
|
||||
nid,
|
||||
score.reduce((acc, s) => acc + s, 0)
|
||||
);
|
||||
|
||||
for (const [index, ranges] of positions) {
|
||||
match.addHighlighter(nid, this.fieldKey, index, ranges);
|
||||
}
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line sonarjs/no-identical-functions
|
||||
async all(trx: DataStructROTransaction): Promise<Match> {
|
||||
const objs = await trx
|
||||
.objectStore('invertedIndex')
|
||||
.index('key')
|
||||
.getAll(
|
||||
IDBKeyRange.bound(
|
||||
InvertedIndexKey.forPrefix(this.fieldKey).buffer(),
|
||||
InvertedIndexKey.forPrefix(this.fieldKey).add1().buffer()
|
||||
)
|
||||
);
|
||||
|
||||
const set = new Set<number>();
|
||||
for (const obj of objs) {
|
||||
set.add(obj.nid);
|
||||
}
|
||||
|
||||
const match = new Match();
|
||||
for (const nid of set) {
|
||||
match.addScore(nid, 1);
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
async insert(trx: DataStructRWTransaction, id: number, terms: string[]) {
|
||||
for (let i = 0; i < terms.length; i++) {
|
||||
const tokenMap = new Map<string, Token[]>();
|
||||
const originString = terms[i];
|
||||
|
||||
const tokens = new GeneralTokenizer().tokenize(originString);
|
||||
|
||||
for (const token of tokens) {
|
||||
const tokens = tokenMap.get(token.term) || [];
|
||||
tokens.push(token);
|
||||
tokenMap.set(token.term, tokens);
|
||||
}
|
||||
|
||||
for (const [term, tokens] of tokenMap) {
|
||||
await trx.objectStore('invertedIndex').add({
|
||||
key: InvertedIndexKey.forString(this.fieldKey, term).buffer(),
|
||||
nid: id,
|
||||
pos: {
|
||||
l: originString.length,
|
||||
i: i,
|
||||
rs: tokens.map(token => [token.start, token.end]),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const kvMetadataStore = trx.objectStore('kvMetadata');
|
||||
// update avg-field-length
|
||||
const totalCount =
|
||||
(await kvMetadataStore.get(`full-text:field-count:${this.fieldKey}`))
|
||||
?.value ?? 0;
|
||||
const avgFieldLength =
|
||||
(
|
||||
await kvMetadataStore.get(
|
||||
`full-text:avg-field-length:${this.fieldKey}`
|
||||
)
|
||||
)?.value ?? 0;
|
||||
await kvMetadataStore.put({
|
||||
key: `full-text:field-count:${this.fieldKey}`,
|
||||
value: totalCount + 1,
|
||||
});
|
||||
await kvMetadataStore.put({
|
||||
key: `full-text:avg-field-length:${this.fieldKey}`,
|
||||
value:
|
||||
avgFieldLength +
|
||||
(terms.reduce((acc, term) => acc + term.length, 0) - avgFieldLength) /
|
||||
(totalCount + 1),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class InvertedIndexKey {
|
||||
constructor(
|
||||
readonly field: ArrayBuffer,
|
||||
readonly value: ArrayBuffer,
|
||||
readonly gap: ArrayBuffer = new Uint8Array([58])
|
||||
) {}
|
||||
|
||||
asString() {
|
||||
return new TextDecoder().decode(this.value);
|
||||
}
|
||||
|
||||
asInt64() {
|
||||
return new DataView(this.value).getBigInt64(0, false); /* big-endian */
|
||||
}
|
||||
|
||||
add1() {
|
||||
if (this.value.byteLength > 0) {
|
||||
const bytes = new Uint8Array(this.value.slice(0));
|
||||
let carry = 1;
|
||||
for (let i = bytes.length - 1; i >= 0 && carry > 0; i--) {
|
||||
const sum = bytes[i] + carry;
|
||||
bytes[i] = sum % 256;
|
||||
carry = sum >> 8;
|
||||
}
|
||||
return new InvertedIndexKey(this.field, bytes);
|
||||
} else {
|
||||
return new InvertedIndexKey(
|
||||
this.field,
|
||||
new ArrayBuffer(0),
|
||||
new Uint8Array([59])
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
static forPrefix(field: string) {
|
||||
return new InvertedIndexKey(
|
||||
new TextEncoder().encode(field),
|
||||
new ArrayBuffer(0)
|
||||
);
|
||||
}
|
||||
|
||||
static forString(field: string, value: string) {
|
||||
return new InvertedIndexKey(
|
||||
new TextEncoder().encode(field),
|
||||
new TextEncoder().encode(value)
|
||||
);
|
||||
}
|
||||
|
||||
static forBoolean(field: string, value: boolean) {
|
||||
const bytes = new Uint8Array(1);
|
||||
bytes.set([value ? 1 : 0]);
|
||||
return new InvertedIndexKey(new TextEncoder().encode(field), bytes);
|
||||
}
|
||||
|
||||
static forInt64(field: string, value: bigint) {
|
||||
const bytes = new ArrayBuffer(8);
|
||||
new DataView(bytes).setBigInt64(0, value, false); /* big-endian */
|
||||
return new InvertedIndexKey(new TextEncoder().encode(field), bytes);
|
||||
}
|
||||
|
||||
buffer() {
|
||||
const tmp = new Uint8Array(
|
||||
this.field.byteLength + (this.value?.byteLength ?? 0) + 1
|
||||
);
|
||||
tmp.set(new Uint8Array(this.field), 0);
|
||||
tmp.set(new Uint8Array(this.gap), this.field.byteLength);
|
||||
if (this.value.byteLength > 0) {
|
||||
tmp.set(new Uint8Array(this.value), this.field.byteLength + 1);
|
||||
}
|
||||
return tmp.buffer;
|
||||
}
|
||||
|
||||
static fromBuffer(buffer: ArrayBuffer) {
|
||||
const array = new Uint8Array(buffer);
|
||||
const fieldLength = array.indexOf(58);
|
||||
const field = array.slice(0, fieldLength);
|
||||
const value = array.slice(fieldLength + 1);
|
||||
return new InvertedIndexKey(field, value);
|
||||
}
|
||||
}
|
||||
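A note on the key encoding above: a posting key is laid out as `utf8(field) ++ 0x3A (':') ++ encoded value`, and `add1()` increments the value bytes (or, for an empty value, bumps the gap byte to `0x3B`, `';'`) to form the exclusive upper bound of a range scan. A short sketch:

```ts
// Every posting for field 'title' sorts between "title:" and "title;",
// so a prefix scan is a plain IDBKeyRange over the binary keys.
const lower = InvertedIndexKey.forPrefix('title').buffer(); // "title:"
const upper = InvertedIndexKey.forPrefix('title').add1().buffer(); // "title;"
const range = IDBKeyRange.bound(lower, upper);
```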
127
packages/common/infra/src/sync/indexer/impl/indexeddb/match.ts
Normal file
@@ -0,0 +1,127 @@
export class Match {
  scores = new Map<number, number>();
  /**
   * nid -> field -> index(multi value field) -> [start, end][]
   */
  highlighters = new Map<
    number,
    Map<string, Map<number, [number, number][]>>
  >();

  constructor() {}

  size() {
    return this.scores.size;
  }

  getScore(id: number) {
    return this.scores.get(id) ?? 0;
  }

  addScore(id: number, score: number) {
    const currentScore = this.scores.get(id) || 0;
    this.scores.set(id, currentScore + score);
  }

  getHighlighters(id: number, field: string) {
    return this.highlighters.get(id)?.get(field);
  }

  addHighlighter(
    id: number,
    field: string,
    index: number,
    newRanges: [number, number][]
  ) {
    const fields =
      this.highlighters.get(id) ||
      new Map<string, Map<number, [number, number][]>>();
    const values = fields.get(field) || new Map<number, [number, number][]>();
    const ranges = values.get(index) || [];
    ranges.push(...newRanges);
    values.set(index, ranges);
    fields.set(field, values);
    this.highlighters.set(id, fields);
  }

  and(other: Match) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      if (other.scores.has(id)) {
        newWeight.addScore(id, score + (other.scores.get(id) ?? 0));
        newWeight.copyExtData(this, id);
        newWeight.copyExtData(other, id);
      }
    }
    return newWeight;
  }

  or(other: Match) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      newWeight.addScore(id, score);
      newWeight.copyExtData(this, id);
    }
    for (const [id, score] of other.scores) {
      newWeight.addScore(id, score);
      newWeight.copyExtData(other, id);
    }
    return newWeight;
  }

  exclude(other: Match) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      if (!other.scores.has(id)) {
        newWeight.addScore(id, score);
        newWeight.copyExtData(this, id);
      }
    }
    return newWeight;
  }

  boost(boost: number) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      newWeight.addScore(id, score * boost);
      newWeight.copyExtData(this, id);
    }
    return newWeight;
  }

  toArray() {
    return Array.from(this.scores.entries())
      .sort((a, b) => b[1] - a[1])
      .map(e => e[0]);
  }

  filter(predicate: (id: number) => boolean) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      if (predicate(id)) {
        newWeight.addScore(id, score);
        newWeight.copyExtData(this, id);
      }
    }
    return newWeight;
  }

  async asyncFilter(predicate: (id: number) => Promise<boolean>) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      if (await predicate(id)) {
        newWeight.addScore(id, score);
        newWeight.copyExtData(this, id);
      }
    }
    return newWeight;
  }

  private copyExtData(from: Match, id: number) {
    for (const [field, values] of from.highlighters.get(id) ?? []) {
      for (const [index, ranges] of values) {
        this.addHighlighter(id, field, index, ranges);
      }
    }
  }
}
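These combinators are the set algebra the boolean query evaluator composes; a small sketch with illustrative ids and scores:

```ts
const a = new Match();
a.addScore(1, 0.5);
a.addScore(2, 0.2);

const b = new Match();
b.addScore(2, 0.4);

a.and(b).getScore(2);   // 0.6 -- 'must': shared ids, scores summed
a.or(b).toArray();      // [2, 1] -- 'should': union, sorted by summed score
a.exclude(b).size();    // 1 -- 'must_not': id 2 removed
a.boost(2).getScore(1); // 1 -- 'boost': every score scaled
```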
@@ -0,0 +1,162 @@
import Graphemer from 'graphemer';

export interface Tokenizer {
  tokenize(text: string): Token[];
}

export interface Token {
  term: string;
  start: number;
  end: number;
}

export class SimpleTokenizer implements Tokenizer {
  tokenize(text: string): Token[] {
    const tokens: Token[] = [];
    let start = 0;
    let end = 0;
    let inWord = false;
    for (let i = 0; i < text.length; i++) {
      const c = text[i];
      if (c.match(/[\n\r\p{Z}\p{P}]/u)) {
        if (inWord) {
          end = i;
          tokens.push({
            term: text.substring(start, end).toLowerCase(),
            start,
            end,
          });
          inWord = false;
        }
      } else {
        if (!inWord) {
          start = i;
          end = i;
          inWord = true;
        }
      }
    }
    if (inWord) {
      tokens.push({
        term: text.substring(start).toLowerCase(),
        start,
        end: text.length,
      });
    }
    return tokens;
  }
}

export class NGramTokenizer implements Tokenizer {
  constructor(private readonly n: number) {}

  tokenize(text: string): Token[] {
    const splitted: Token[] = [];
    for (let i = 0; i < text.length; ) {
      const nextBreak = Graphemer.nextBreak(text, i);
      const c = text.substring(i, nextBreak);

      splitted.push({
        term: c,
        start: i,
        end: nextBreak,
      });

      i = nextBreak;
    }
    const tokens: Token[] = [];
    for (let i = 0; i < splitted.length - this.n + 1; i++) {
      tokens.push(
        splitted.slice(i, i + this.n).reduce(
          (acc, t) => ({
            term: acc.term + t.term,
            start: Math.min(acc.start, t.start),
            end: Math.max(acc.end, t.end),
          }),
          { term: '', start: Infinity, end: -Infinity }
        )
      );
    }
    return tokens;
  }
}

export class GeneralTokenizer implements Tokenizer {
  constructor() {}

  tokenizeWord(word: string, lang: string): Token[] {
    if (lang === 'en') {
      return [{ term: word.toLowerCase(), start: 0, end: word.length }];
    } else if (lang === 'cjk') {
      if (word.length < 3) {
        return [{ term: word, start: 0, end: word.length }];
      }
      return new NGramTokenizer(2).tokenize(word);
    } else if (lang === 'emoji') {
      return new NGramTokenizer(1).tokenize(word);
    } else if (lang === '-') {
      return [];
    }

    throw new Error('Not implemented');
  }

  testLang(c: string): string {
    if (c.match(/[\p{Emoji}]/u)) {
      return 'emoji';
    } else if (c.match(/[\p{sc=Han}\p{scx=Hira}\p{scx=Kana}\p{sc=Hang}]/u)) {
      return 'cjk';
    } else if (c.match(/[\n\r\p{Z}\p{P}]/u)) {
      return '-';
    } else {
      return 'en';
    }
  }

  tokenize(text: string): Token[] {
    const tokens: Token[] = [];
    let start = 0;
    let end = 0;
    let lang: string | null = null;

    for (let i = 0; i < text.length; ) {
      const nextBreak = Graphemer.nextBreak(text, i);
      const c = text.substring(i, nextBreak);

      const l = this.testLang(c);
      if (lang !== l) {
        if (lang !== null) {
          end = i;
          tokens.push(
            ...this.tokenizeWord(text.substring(start, end), lang).map(
              token => ({
                ...token,
                start: token.start + start,
                end: token.end + start,
              })
            )
          );
        }

        start = i;
        end = i;
        lang = l;
      }

      i = nextBreak;
    }
    if (lang !== null) {
      tokens.push(
        ...this.tokenizeWord(text.substring(start, text.length), lang).map(
          token => ({
            ...token,
            start: token.start + start,
            end: token.end + start,
          })
        )
      );
    }

    return tokens;
  }
}
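A hedged sketch of what `GeneralTokenizer` produces for mixed-script input, following the rules above (lowercased runs for Latin text, 2-grams for CJK runs of three or more characters; offsets are UTF-16 indices into the input):

```ts
const tokens = new GeneralTokenizer().tokenize('Hello 世界真大');
// Expected under the rules above:
// { term: 'hello', start: 0, end: 5 }  (Latin run, lowercased)
// { term: '世界', start: 6, end: 8 }    (CJK run of 4 chars -> bigrams)
// { term: '界真', start: 7, end: 9 }
// { term: '真大', start: 8, end: 10 }
// The space between the runs maps to lang '-' and yields no token.
```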
@@ -0,0 +1,282 @@
import {
  type AggregateOptions,
  type AggregateResult,
  Document,
  type Query,
  type Schema,
  type SearchOptions,
  type SearchResult,
} from '../../';
import {
  BooleanInvertedIndex,
  FullTextInvertedIndex,
  IntegerInvertedIndex,
  type InvertedIndex,
  StringInvertedIndex,
} from './inverted-index';
import { Match } from './match';

type DataRecord = {
  id: string;
  data: Map<string, string[]>;
  deleted: boolean;
};

export class DataStruct {
  records: DataRecord[] = [];

  idMap = new Map<string, number>();

  invertedIndex = new Map<string, InvertedIndex>();

  constructor(schema: Schema) {
    for (const [key, type] of Object.entries(schema)) {
      if (type === 'String') {
        this.invertedIndex.set(key, new StringInvertedIndex(key));
      } else if (type === 'Integer') {
        this.invertedIndex.set(key, new IntegerInvertedIndex(key));
      } else if (type === 'FullText') {
        this.invertedIndex.set(key, new FullTextInvertedIndex(key));
      } else if (type === 'Boolean') {
        this.invertedIndex.set(key, new BooleanInvertedIndex(key));
      } else {
        throw new Error(`Field type '${type}' not supported`);
      }
    }
  }

  getAll(ids: string[]): Document[] {
    return ids
      .map(id => {
        const nid = this.idMap.get(id);
        if (nid === undefined) {
          return undefined;
        }
        return Document.from(id, this.records[nid].data);
      })
      .filter((v): v is Document => v !== undefined);
  }

  insert(document: Document) {
    if (this.idMap.has(document.id)) {
      throw new Error('Document already exists');
    }

    this.records.push({
      id: document.id,
      data: document.fields as Map<string, string[]>,
      deleted: false,
    });

    const nid = this.records.length - 1;
    this.idMap.set(document.id, nid);
    for (const [key, values] of document.fields) {
      for (const value of values) {
        const iidx = this.invertedIndex.get(key as string);
        if (!iidx) {
          throw new Error(
            `Inverted index '${key.toString()}' not found, document not match schema`
          );
        }
        iidx.insert(nid, value);
      }
    }
  }

  delete(id: string) {
    const nid = this.idMap.get(id);
    if (nid === undefined) {
      throw new Error('Document not found');
    }

    this.records[nid].deleted = true;
    this.records[nid].data = new Map();
  }

  matchAll(): Match {
    const weight = new Match();
    for (let i = 0; i < this.records.length; i++) {
      weight.addScore(i, 1);
    }
    return weight;
  }

  clear() {
    this.records = [];
    this.idMap.clear();
    this.invertedIndex.forEach(v => v.clear());
  }

  private queryRaw(query: Query<any>): Match {
    if (query.type === 'match') {
      const iidx = this.invertedIndex.get(query.field as string);
      if (!iidx) {
        throw new Error(`Field '${query.field as string}' not found`);
      }
      return iidx.match(query.match);
    } else if (query.type === 'boolean') {
      const weights = query.queries.map(q => this.queryRaw(q));
      if (query.occur === 'must') {
        return weights.reduce((acc, w) => acc.and(w));
      } else if (query.occur === 'must_not') {
        const total = weights.reduce((acc, w) => acc.and(w));
        return this.matchAll().exclude(total);
      } else if (query.occur === 'should') {
        return weights.reduce((acc, w) => acc.or(w));
      }
    } else if (query.type === 'all') {
      return this.matchAll();
    } else if (query.type === 'boost') {
      return this.queryRaw(query.query).boost(query.boost);
    } else if (query.type === 'exists') {
      const iidx = this.invertedIndex.get(query.field as string);
      if (!iidx) {
        throw new Error(`Field '${query.field as string}' not found`);
      }
      return iidx.all();
    }
    throw new Error(`Query type '${query.type}' not supported`);
  }

  query(query: Query<any>): Match {
    return this.queryRaw(query).filter(id => !this.records[id].deleted);
  }

  search(
    query: Query<any>,
    options: SearchOptions<any> = {}
  ): SearchResult<any, any> {
    const pagination = {
      skip: options.pagination?.skip ?? 0,
      limit: options.pagination?.limit ?? 100,
    };

    const match = this.query(query);

    const nids = match
      .toArray()
      .slice(pagination.skip, pagination.skip + pagination.limit);

    return {
      pagination: {
        count: match.size(),
        hasMore: match.size() > pagination.limit + pagination.skip,
        limit: pagination.limit,
        skip: pagination.skip,
      },
      nodes: nids.map(nid => this.resultNode(match, nid, options)),
    };
  }

  aggregate(
    query: Query<any>,
    field: string,
    options: AggregateOptions<any> = {}
  ): AggregateResult<any, any> {
    const pagination = {
      skip: options.pagination?.skip ?? 0,
      limit: options.pagination?.limit ?? 100,
    };

    const match = this.query(query);

    const nids = match.toArray();

    const buckets: { key: string; nids: number[] }[] = [];

    for (const nid of nids) {
      for (const value of this.records[nid].data.get(field) ?? []) {
        let bucket = buckets.find(b => b.key === value);
        if (!bucket) {
          bucket = { key: value, nids: [] };
          buckets.push(bucket);
        }
        bucket.nids.push(nid);
      }
    }

    return {
      buckets: buckets
        .slice(pagination.skip, pagination.skip + pagination.limit)
        .map(bucket => {
          const result = {
            key: bucket.key,
            score: match.getScore(bucket.nids[0]),
            count: bucket.nids.length,
          } as AggregateResult<any, any>['buckets'][number];

          if (options.hits) {
            const hitsOptions = options.hits;
            const pagination = {
              skip: options.hits.pagination?.skip ?? 0,
              limit: options.hits.pagination?.limit ?? 3,
            };

            const hits = bucket.nids.slice(
              pagination.skip,
              pagination.skip + pagination.limit
            );

            (result as any).hits = {
              pagination: {
                count: bucket.nids.length,
                hasMore:
                  bucket.nids.length > pagination.limit + pagination.skip,
                limit: pagination.limit,
                skip: pagination.skip,
              },
              nodes: hits.map(nid => this.resultNode(match, nid, hitsOptions)),
            } as SearchResult<any, any>;
          }

          return result;
        }),
      pagination: {
        count: buckets.length,
        hasMore: buckets.length > pagination.limit + pagination.skip,
        limit: pagination.limit,
        skip: pagination.skip,
      },
    };
  }

  has(id: string): boolean {
    return this.idMap.has(id);
  }

  private resultNode(
    match: Match,
    nid: number,
    options: SearchOptions<any>
  ): SearchResult<any, any>['nodes'][number] {
    const node = {
      id: this.records[nid].id,
      score: match.getScore(nid),
    } as any;

    if (options.fields) {
      const fields = {} as Record<string, string | string[]>;
      for (const field of options.fields as string[]) {
        fields[field] = this.records[nid].data.get(field) ?? [''];
        if (fields[field].length === 1) {
          fields[field] = fields[field][0];
        }
      }
      node.fields = fields;
    }

    if (options.highlights) {
      const highlights = {} as Record<string, string[]>;
      for (const { field, before, end } of options.highlights) {
        highlights[field] = match
          .getHighlighters(nid, field)
          .flatMap(highlighter => {
            return highlighter(before, end);
          });
      }
      node.highlights = highlights;
    }

    return node;
  }
}
141
packages/common/infra/src/sync/indexer/impl/memory/index.ts
Normal file
@@ -0,0 +1,141 @@
import { map, merge, type Observable, of, Subject, throttleTime } from 'rxjs';

import type {
  AggregateOptions,
  AggregateResult,
  Document,
  Index,
  IndexStorage,
  IndexWriter,
  Query,
  Schema,
  SearchOptions,
  SearchResult,
} from '../../';
import { DataStruct } from './data-struct';

export class MemoryIndex<S extends Schema> implements Index<S> {
  private readonly data: DataStruct = new DataStruct(this.schema);
  broadcast$ = new Subject<number>();

  constructor(private readonly schema: Schema) {}

  write(): Promise<IndexWriter<S>> {
    return Promise.resolve(new MemoryIndexWriter(this.data, this.broadcast$));
  }

  async get(id: string): Promise<Document<S> | null> {
    return (await this.getAll([id]))[0] ?? null;
  }

  getAll(ids: string[]): Promise<Document<S>[]> {
    return Promise.resolve(this.data.getAll(ids));
  }

  has(id: string): Promise<boolean> {
    return Promise.resolve(this.data.has(id));
  }

  async search(
    query: Query<any>,
    options: SearchOptions<any> = {}
  ): Promise<SearchResult<any, any>> {
    return this.data.search(query, options);
  }

  search$(
    query: Query<any>,
    options: SearchOptions<any> = {}
  ): Observable<SearchResult<any, any>> {
    return merge(of(1), this.broadcast$).pipe(
      throttleTime(500, undefined, { leading: false, trailing: true }),
      map(() => this.data.search(query, options))
    );
  }

  async aggregate(
    query: Query<any>,
    field: string,
    options: AggregateOptions<any> = {}
  ): Promise<AggregateResult<any, any>> {
    return this.data.aggregate(query, field, options);
  }

  aggregate$(
    query: Query<any>,
    field: string,
    options: AggregateOptions<any> = {}
  ): Observable<AggregateResult<S, AggregateOptions<any>>> {
    return merge(of(1), this.broadcast$).pipe(
      throttleTime(500, undefined, { leading: false, trailing: true }),
      map(() => this.data.aggregate(query, field, options))
    );
  }

  clear(): Promise<void> {
    this.data.clear();
    return Promise.resolve();
  }
}

export class MemoryIndexWriter<S extends Schema> implements IndexWriter<S> {
  inserts: Document[] = [];
  deletes: string[] = [];

  constructor(
    private readonly data: DataStruct,
    private readonly broadcast$: Subject<number>
  ) {}

  async get(id: string): Promise<Document<S> | null> {
    return (await this.getAll([id]))[0] ?? null;
  }

  getAll(ids: string[]): Promise<Document<S>[]> {
    return Promise.resolve(this.data.getAll(ids));
  }

  insert(document: Document): void {
    this.inserts.push(document);
  }
  delete(id: string): void {
    this.deletes.push(id);
  }
  put(document: Document): void {
    this.delete(document.id);
    this.insert(document);
  }
  async search(
    query: Query<any>,
    options: SearchOptions<any> = {}
  ): Promise<SearchResult<any, any>> {
    return this.data.search(query, options);
  }
  async aggregate(
    query: Query<any>,
    field: string,
    options: AggregateOptions<any> = {}
  ): Promise<AggregateResult<any, any>> {
    return this.data.aggregate(query, field, options);
  }
  commit(): Promise<void> {
    for (const del of this.deletes) {
      this.data.delete(del);
    }
    for (const inst of this.inserts) {
      this.data.insert(inst);
    }
    this.broadcast$.next(1);
    return Promise.resolve();
  }
  rollback(): void {}
  has(id: string): Promise<boolean> {
    return Promise.resolve(this.data.has(id));
  }
}

export class MemoryIndexStorage implements IndexStorage {
  getIndex<S extends Schema>(_: string, schema: S): Index<S> {
    return new MemoryIndex(schema);
  }
}
@@ -0,0 +1,220 @@
import Fuse from 'fuse.js';

import { Match } from './match';

export interface InvertedIndex {
  fieldKey: string;

  match(term: string): Match;

  all(): Match;

  insert(id: number, term: string): void;

  clear(): void;
}

export class StringInvertedIndex implements InvertedIndex {
  index: Map<string, number[]> = new Map();

  constructor(readonly fieldKey: string) {}

  match(term: string): Match {
    const match = new Match();

    for (const id of this.index.get(term) ?? []) {
      match.addScore(id, 1);
    }

    return match;
  }

  all(): Match {
    const match = new Match();

    for (const [_term, ids] of this.index) {
      for (const id of ids) {
        if (match.getScore(id) === 0) {
          match.addScore(id, 1);
        }
      }
    }

    return match;
  }

  insert(id: number, term: string): void {
    const ids = this.index.get(term) ?? [];
    ids.push(id);
    this.index.set(term, ids);
  }

  clear(): void {
    this.index.clear();
  }
}

export class IntegerInvertedIndex implements InvertedIndex {
  index: Map<string, number[]> = new Map();

  constructor(readonly fieldKey: string) {}

  // eslint-disable-next-line sonarjs/no-identical-functions
  match(term: string): Match {
    const match = new Match();

    for (const id of this.index.get(term) ?? []) {
      match.addScore(id, 1);
    }

    return match;
  }

  // eslint-disable-next-line sonarjs/no-identical-functions
  all(): Match {
    const match = new Match();

    for (const [_term, ids] of this.index) {
      for (const id of ids) {
        if (match.getScore(id) === 0) {
          match.addScore(id, 1);
        }
      }
    }

    return match;
  }

  // eslint-disable-next-line sonarjs/no-identical-functions
  insert(id: number, term: string): void {
    const ids = this.index.get(term) ?? [];
    ids.push(id);
    this.index.set(term, ids);
  }

  clear(): void {
    this.index.clear();
  }
}

export class BooleanInvertedIndex implements InvertedIndex {
  index: Map<boolean, number[]> = new Map();

  constructor(readonly fieldKey: string) {}

  // eslint-disable-next-line sonarjs/no-identical-functions
  match(term: string): Match {
    const match = new Match();

    for (const id of this.index.get(term === 'true') ?? []) {
      match.addScore(id, 1);
    }

    return match;
  }

  // eslint-disable-next-line sonarjs/no-identical-functions
  all(): Match {
    const match = new Match();

    for (const [_term, ids] of this.index) {
      for (const id of ids) {
        if (match.getScore(id) === 0) {
          match.addScore(id, 1);
        }
      }
    }

    return match;
  }

  // eslint-disable-next-line sonarjs/no-identical-functions
  insert(id: number, term: string): void {
    const ids = this.index.get(term === 'true') ?? [];
    ids.push(id);
    this.index.set(term === 'true', ids);
  }

  clear(): void {
    this.index.clear();
  }
}

export class FullTextInvertedIndex implements InvertedIndex {
  records = [] as { id: number; v: string }[];
  index = Fuse.createIndex(['v'], [] as { id: number; v: string }[]);

  constructor(readonly fieldKey: string) {}

  match(term: string): Match {
    const searcher = new Fuse(
      this.records,
      {
        includeScore: true,
        includeMatches: true,
        shouldSort: true,
        keys: ['v'],
      },
      this.index
    );
    const result = searcher.search(term);

    const match = new Match();

    for (const value of result) {
      match.addScore(value.item.id, 1 - (value.score ?? 1));

      match.addHighlighter(value.item.id, this.fieldKey, (before, after) => {
        const matches = value.matches;
        if (!matches || matches.length === 0) {
          return [''];
        }

        const firstMatch = matches[0];

        const text = firstMatch.value;
        if (!text) {
          return [''];
        }

        let result = '';
        let pointer = 0;
        for (const match of matches) {
          for (const [start, end] of match.indices) {
            result += text.substring(pointer, start);
            result += `${before}${text.substring(start, end + 1)}${after}`;
            pointer = end + 1;
          }
        }
        result += text.substring(pointer);

        return [result];
      });
    }

    return match;
  }

  // eslint-disable-next-line sonarjs/no-identical-functions
  all(): Match {
    const match = new Match();

    for (const { id } of this.records) {
      if (match.getScore(id) === 0) {
        match.addScore(id, 1);
      }
    }

    return match;
  }

  insert(id: number, term: string): void {
    this.index.add({ id, v: term });
    this.records.push({ id, v: term });
  }

  clear(): void {
    this.records = [];
    this.index = Fuse.createIndex(['v'], [] as { id: number; v: string }[]);
  }
}
108
packages/common/infra/src/sync/indexer/impl/memory/match.ts
Normal file
@@ -0,0 +1,108 @@
export class Match {
  scores = new Map<number, number>();
  highlighters = new Map<
    number,
    Map<string, ((before: string, after: string) => string[])[]>
  >();

  constructor() {}

  size() {
    return this.scores.size;
  }

  getScore(id: number) {
    return this.scores.get(id) ?? 0;
  }

  addScore(id: number, score: number) {
    const currentScore = this.scores.get(id) || 0;
    this.scores.set(id, currentScore + score);
  }

  getHighlighters(id: number, field: string) {
    return this.highlighters.get(id)?.get(field) ?? [];
  }

  addHighlighter(
    id: number,
    field: string,
    highlighter: (before: string, after: string) => string[]
  ) {
    const fields = this.highlighters.get(id) || new Map();
    const highlighters = fields.get(field) || [];
    highlighters.push(highlighter);
    fields.set(field, highlighters);
    this.highlighters.set(id, fields);
  }

  and(other: Match) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      if (other.scores.has(id)) {
        newWeight.addScore(id, score + (other.scores.get(id) ?? 0));
        newWeight.copyExtData(this, id);
        newWeight.copyExtData(other, id);
      }
    }
    return newWeight;
  }

  or(other: Match) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      newWeight.addScore(id, score);
      newWeight.copyExtData(this, id);
    }
    for (const [id, score] of other.scores) {
      newWeight.addScore(id, score);
      newWeight.copyExtData(other, id);
    }
    return newWeight;
  }

  exclude(other: Match) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      if (!other.scores.has(id)) {
        newWeight.addScore(id, score);
        newWeight.copyExtData(this, id);
      }
    }
    return newWeight;
  }

  boost(boost: number) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      newWeight.addScore(id, score * boost);
      newWeight.copyExtData(this, id);
    }
    return newWeight;
  }

  toArray() {
    return Array.from(this.scores.entries())
      .sort((a, b) => b[1] - a[1])
      .map(e => e[0]);
  }

  filter(predicate: (id: number) => boolean) {
    const newWeight = new Match();
    for (const [id, score] of this.scores) {
      if (predicate(id)) {
        newWeight.addScore(id, score);
        newWeight.copyExtData(this, id);
      }
    }
    return newWeight;
  }

  private copyExtData(from: Match, id: number) {
    for (const [field, highlighters] of from.highlighters.get(id) ?? []) {
      for (const highlighter of highlighters) {
        this.addHighlighter(id, field, highlighter);
      }
    }
  }
}
6
packages/common/infra/src/sync/indexer/index.ts
Normal file
@@ -0,0 +1,6 @@
export * from './document';
export * from './field-type';
export * from './indexer';
export * from './query';
export * from './schema';
export * from './searcher';
41
packages/common/infra/src/sync/indexer/indexer.ts
Normal file
@@ -0,0 +1,41 @@
import type { Document } from './document';
import type { Schema } from './schema';
import type { Searcher, Subscriber } from './searcher';

export interface Index<S extends Schema>
  extends IndexReader<S>,
    Searcher<S>,
    Subscriber<S> {
  write(): Promise<IndexWriter<S>>;

  clear(): Promise<void>;
}

export interface IndexWriter<S extends Schema>
  extends IndexReader<S>,
    Searcher<S> {
  insert(document: Document<S>): void;

  put(document: Document<S>): void;

  delete(id: string): void;

  // TODO(@eyhn)
  // deleteByQuery(query: Query<S>): void;

  commit(): Promise<void>;

  rollback(): void;
}

export interface IndexReader<S extends Schema> {
  get(id: string): Promise<Document<S> | null>;

  getAll(ids: string[]): Promise<Document<S>[]>;

  has(id: string): Promise<boolean>;
}

export interface IndexStorage {
  getIndex<S extends Schema>(name: string, schema: S): Index<S>;
}
35
packages/common/infra/src/sync/indexer/query.ts
Normal file
@@ -0,0 +1,35 @@
import type { Schema } from './schema';

export type MatchQuery<S extends Schema> = {
  type: 'match';
  field: keyof S;
  match: string;
};

export type BoostQuery = {
  type: 'boost';
  query: Query<any>;
  boost: number;
};

export type BooleanQuery<S extends Schema> = {
  type: 'boolean';
  occur: 'should' | 'must' | 'must_not';
  queries: Query<S>[];
};

export type ExistsQuery<S extends Schema> = {
  type: 'exists';
  field: keyof S;
};

export type AllQuery = {
  type: 'all';
};

export type Query<S extends Schema> =
  | BooleanQuery<S>
  | MatchQuery<S>
  | AllQuery
  | ExistsQuery<S>
  | BoostQuery;
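A small sketch of how these query types compose (the schema and field names are hypothetical; note that boolean fields are matched against the string `'true'`/`'false'`, as the `BooleanInvertedIndex` implementations above expect):

```ts
// "title must match 'affine' (boosted 2x) and the document must not be trashed"
const query: Query<{ title: 'FullText'; trash: 'Boolean' }> = {
  type: 'boolean',
  occur: 'must',
  queries: [
    {
      type: 'boost',
      query: { type: 'match', field: 'title', match: 'affine' },
      boost: 2,
    },
    {
      type: 'boolean',
      occur: 'must_not',
      queries: [{ type: 'match', field: 'trash', match: 'true' }],
    },
  ],
};
```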
7
packages/common/infra/src/sync/indexer/schema.ts
Normal file
@@ -0,0 +1,7 @@
import type { FieldType } from './field-type';

export type Schema = Record<string, FieldType>;

export function defineSchema<T extends Schema>(schema: T): T {
  return schema;
}
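`defineSchema` is an identity helper that preserves the literal field types for later type-level use. A hypothetical schema (the field names are illustrative; `'FullText'`, `'String'`, and `'Boolean'` are the literals the `DataStruct` constructors above dispatch on, assumed here to be members of `FieldType`):

```ts
const pageSchema = defineSchema({
  title: 'FullText',
  tag: 'String',
  trash: 'Boolean',
});
// typeof pageSchema keeps the literals, so HighlightAbleField<typeof pageSchema>
// resolves to 'title' in searcher.ts below.
```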
83
packages/common/infra/src/sync/indexer/searcher.ts
Normal file
@@ -0,0 +1,83 @@
import type { Observable } from 'rxjs';

import type { Query } from './query';
import type { Schema } from './schema';

type HighlightAbleField<S extends Schema> = {
  [K in keyof S]: S[K] extends 'FullText' ? K : never;
}[keyof S];

export interface Searcher<S extends Schema = any> {
  search<const O extends SearchOptions<S>>(
    query: Query<S>,
    options?: O
  ): Promise<SearchResult<S, O>>;
  aggregate<const O extends AggregateOptions<S>>(
    query: Query<S>,
    field: keyof S,
    options?: O
  ): Promise<AggregateResult<S, O>>;
}

export interface Subscriber<S extends Schema = any> {
  search$<const O extends SearchOptions<S>>(
    query: Query<S>,
    options?: O
  ): Observable<SearchResult<S, O>>;
  aggregate$<const O extends AggregateOptions<S>>(
    query: Query<S>,
    field: keyof S,
    options?: O
  ): Observable<AggregateResult<S, O>>;
}

type ResultPagination = {
  count: number;
  limit: number;
  skip: number;
  hasMore: boolean;
};

type PaginationOption = {
  limit?: number;
  skip?: number;
};

export type SearchOptions<S extends Schema> = {
  pagination?: PaginationOption;
  highlights?: {
    field: HighlightAbleField<S>;
    before: string;
    end: string;
  }[];
  fields?: (keyof S)[];
};

export type SearchResult<S extends Schema, O extends SearchOptions<S>> = {
  pagination: ResultPagination;
  nodes: ({
    id: string;
    score: number;
  } & (O['fields'] extends any[]
    ? { fields: { [key in O['fields'][number]]: string | string[] } }
    : unknown) &
    (O['highlights'] extends any[]
      ? { highlights: { [key in O['highlights'][number]['field']]: string[] } }
      : unknown))[];
};

export interface AggregateOptions<S extends Schema> {
  pagination?: PaginationOption;
  hits?: SearchOptions<S>;
}

export type AggregateResult<S extends Schema, O extends AggregateOptions<S>> = {
  pagination: ResultPagination;
  buckets: ({
    key: string;
    score: number;
    count: number;
  } & (O['hits'] extends object
    ? { hits: SearchResult<S, O['hits']> }
    : unknown))[];
};
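The `const O` type parameters make the result type track the options object literally. A sketch of what that buys a caller, given some `searcher: Searcher<typeof pageSchema>` (the instance and values are hypothetical):

```ts
const result = await searcher.search(
  { type: 'match', field: 'title', match: 'affine' },
  {
    fields: ['title'],
    highlights: [{ field: 'title', before: '<b>', end: '</b>' }],
  }
);
for (const node of result.nodes) {
  // Because O is inferred as a literal type, fields.title and
  // highlights.title are statically known to exist for these options.
  console.log(node.score, node.fields.title, node.highlights.title);
}
```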
47
packages/common/infra/src/sync/job/README.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# job
|
||||
|
||||
Job system abstraction for AFFiNE. Currently, only `IndexedDBJobQueue` is implemented; more backends will be implemented in the future.
|
||||
|
||||
Run background jobs in browser & distributed environment. `runners` can consume tasks simultaneously without additional communication.
|
||||
|
||||
# Basic Usage
|
||||
|
||||
```ts
|
||||
const queue = new IndexedDBJobQueue('my-queue');
|
||||
|
||||
await queue.enqueue([
|
||||
{
|
||||
batchKey: '1',
|
||||
payload: { a: 'hello' },
|
||||
},
|
||||
{
|
||||
batchKey: '2',
|
||||
payload: { a: 'world' },
|
||||
},
|
||||
]);
|
||||
|
||||
const runner = new JobRunner(queue, job => {
|
||||
console.log(job);
|
||||
});
|
||||
|
||||
runner.start();
|
||||
|
||||
// Output:
|
||||
// { batchKey: '1', payload: { a: 'hello' } }
|
||||
// { batchKey: '2', payload: { a: 'world' } }
|
||||
```
|
||||
|
||||
## `batchKey`
|
||||
|
||||
Each job has a `batchKey`, and jobs with the same `batchKey` are handed over to one `runner` for execution at once.
|
||||
Additionally, if there are ongoing jobs with the same `batchKey`, other `runners` will not take on jobs with this `batchKey`, ensuring exclusive resource locking.
|
||||
|
||||
> In the future, `batchKey` will be used to implement priority.
|
||||
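For illustration (an aside, not part of the committed README), a minimal sketch that drives the queue API directly; the `'doc:1'` key is hypothetical:

```ts
// Sketch: jobs sharing a batchKey are delivered together by one accept() call.
await queue.enqueue([
  { batchKey: 'doc:1', payload: { a: 'first' } },
  { batchKey: 'doc:1', payload: { a: 'second' } },
]);

const batch = await queue.accept(); // both 'doc:1' jobs, as one array
// While this batch is in flight, accept() elsewhere yields no 'doc:1' jobs
// until complete()/return() is called or the 30-second timeout elapses.
await queue.complete(batch!);
```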
## `timeout`

If a job runs for more than 30 seconds, it is considered timed out and is reassigned to another `runner`.

## Error Handling

If an error is thrown during job execution, the error is logged, but the job is still considered complete.
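As an aside (not in the README itself), a sketch of these semantics; the error text is illustrative:

```ts
// Sketch: a throwing callback. Per the semantics above, the error is logged
// and the job is treated as complete rather than retried.
const runner = new JobRunner(queue, () => {
  throw new Error('boom');
});
runner.start();
```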
231
packages/common/infra/src/sync/job/__tests__/black-box.spec.ts
Normal file
@@ -0,0 +1,231 @@
/**
 * @vitest-environment happy-dom
 */
import 'fake-indexeddb/auto';

import { afterEach, beforeEach, describe, expect, test, vitest } from 'vitest';

import { IndexedDBJobQueue } from '../impl/indexeddb';
import type { JobQueue } from '../queue';

let queue: JobQueue<{
  a: string;
}> = null!;

describe.each([{ name: 'idb', backend: IndexedDBJobQueue }])(
  'impl tests($name)',
  ({ backend }) => {
    beforeEach(async () => {
      queue = new backend();

      await queue.clear();

      vitest.useFakeTimers({
        toFake: ['Date'],
      });
    });

    afterEach(() => {
      vitest.useRealTimers();
    });

    test('basic', async () => {
      await queue.enqueue([
        {
          batchKey: '1',
          payload: { a: 'hello' },
        },
        {
          batchKey: '2',
          payload: { a: 'world' },
        },
      ]);
      const job1 = await queue.accept();
      const job2 = await queue.accept();

      expect([job1!, job2!]).toEqual([
        [
          {
            id: expect.any(String),
            batchKey: '1',
            payload: { a: 'hello' },
          },
        ],
        [
          {
            id: expect.any(String),
            batchKey: '2',
            payload: { a: 'world' },
          },
        ],
      ]);

      const job3 = await queue.accept();
      expect(job3).toBeNull();

      await queue.return(job1!);
      await queue.return(job2!);
    });

    test('batch', async () => {
      await queue.enqueue([
        {
          batchKey: '1',
          payload: { a: 'hello' },
        },
        {
          batchKey: '1',
          payload: { a: 'world' },
        },
      ]);
      const job1 = await queue.accept();

      expect(job1).toEqual(
        expect.arrayContaining([
          {
            id: expect.any(String),
            batchKey: '1',
            payload: { a: 'hello' },
          },
          {
            id: expect.any(String),
            batchKey: '1',
            payload: { a: 'world' },
          },
        ])
      );
    });

    test('timeout', async () => {
      await queue.enqueue([
        {
          batchKey: '1',
          payload: { a: 'hello' },
        },
      ]);
      {
        const job = await queue.accept();

        expect(job).toEqual([
          {
            id: expect.any(String),
            batchKey: '1',
            payload: { a: 'hello' },
          },
        ]);
      }

      {
        const job = await queue.accept();

        expect(job).toBeNull();
      }

      vitest.advanceTimersByTime(1000 * 60 * 60);

      {
        const job = await queue.accept();

        expect(job).toEqual([
          {
            id: expect.any(String),
            batchKey: '1',
            payload: { a: 'hello' },
          },
        ]);
      }
    });

    test('waitForAccept', async () => {
      const abort = new AbortController();

      let result = null as any;
      queue.waitForAccept(abort.signal).then(jobs => (result = jobs));

      await new Promise(resolve => setTimeout(resolve, 100));

      expect(result).toBeNull();

      await queue.enqueue([
        {
          batchKey: '1',
          payload: { a: 'hello' },
        },
      ]);

      await vitest.waitFor(() => {
        expect(result).toEqual([
          {
            id: expect.any(String),
            batchKey: '1',
            payload: { a: 'hello' },
          },
        ]);
      });
    });

    test('waitForAccept race', async () => {
      const abort = new AbortController();

      let result1 = null as any;
      let result2 = null as any;
      queue.waitForAccept(abort.signal).then(jobs => (result1 = jobs));
      queue.waitForAccept(abort.signal).then(jobs => (result2 = jobs));

      await new Promise(resolve => setTimeout(resolve, 100));

      expect(result1).toBeNull();
      expect(result2).toBeNull();

      await queue.enqueue([
        {
          batchKey: '1',
          payload: { a: 'hello' },
        },
      ]);

      await new Promise(resolve => setTimeout(resolve, 100));

      expect([result1, result2]).toEqual(
        expect.arrayContaining([
          [
            {
              id: expect.any(String),
              batchKey: '1',
              payload: { a: 'hello' },
            },
          ],
          null,
        ])
      );

      await queue.enqueue([
        {
          batchKey: '2',
          payload: { a: 'world' },
        },
      ]);

      await vitest.waitFor(() => {
        expect([result1, result2]).toEqual(
          expect.arrayContaining([
            [
              {
                id: expect.any(String),
                batchKey: '1',
                payload: { a: 'hello' },
              },
            ],
            [
              {
                id: expect.any(String),
                batchKey: '2',
                payload: { a: 'world' },
              },
            ],
          ])
        );
      });
    });
  }
);
248
packages/common/infra/src/sync/job/impl/indexeddb/index.ts
Normal file
@@ -0,0 +1,248 @@
import type { DBSchema, IDBPDatabase } from 'idb';
import { openDB } from 'idb';
import { merge, Observable, of, throttleTime } from 'rxjs';

import { fromPromise } from '../../../../livedata';
import { throwIfAborted } from '../../../../utils';
import { exhaustMapWithTrailing } from '../../../../utils/exhaustmap-with-trailing';
import type { Job, JobParams, JobQueue } from '../../';

interface IndexDB extends DBSchema {
  jobs: {
    key: number;
    value: JobRecord;
    indexes: {
      batchKey: string;
    };
  };
}

interface JobRecord {
  batchKey: string;
  startTime: number | null;
  payload: any;
}

export class IndexedDBJobQueue<J> implements JobQueue<J> {
  database: IDBPDatabase<IndexDB> = null as any;
  broadcast = new BroadcastChannel('idb-job-queue:' + this.databaseName);

  constructor(private readonly databaseName: string = 'jobs') {}

  async enqueue(jobs: JobParams[]): Promise<void> {
    await this.ensureInitialized();
    const trx = this.database.transaction(['jobs'], 'readwrite');

    for (const job of jobs) {
      await trx.objectStore('jobs').add({
        batchKey: job.batchKey,
        payload: job.payload,
        startTime: null,
      });
    }

    trx.commit();

    // send broadcast to notify new jobs
    this.broadcast.postMessage('new-jobs');
  }

  async accept(): Promise<Job[] | null> {
    await this.ensureInitialized();
    const jobs = [];
    const trx = this.database.transaction(['jobs'], 'readwrite');

    // if no priority jobs

    if (jobs.length === 0) {
      const batchKeys = trx.objectStore('jobs').index('batchKey').iterate();

      let currentBatchKey: string = null as any;
      let currentBatchJobs = [];
      let skipCurrentBatch = false;

      for await (const item of batchKeys) {
        if (item.value.batchKey !== currentBatchKey) {
          if (!skipCurrentBatch && currentBatchJobs.length > 0) {
            break;
          }

          currentBatchKey = item.value.batchKey;
          currentBatchJobs = [];
          skipCurrentBatch = false;
        }
        if (skipCurrentBatch) {
          continue;
        }
        if (this.isAcceptable(item.value)) {
          currentBatchJobs.push({
            id: item.primaryKey,
            job: item.value,
          });
        } else {
          skipCurrentBatch = true;
        }
      }

      if (skipCurrentBatch === false && currentBatchJobs.length > 0) {
        jobs.push(...currentBatchJobs);
      }
    }

    for (const { id, job } of jobs) {
      const startTime = Date.now();
      await trx.objectStore('jobs').put({ ...job, startTime }, id);
    }

    if (jobs.length === 0) {
      return null;
    }

    return jobs.map(({ id, job }) => ({
      id: id.toString(),
      batchKey: job.batchKey,
      payload: job.payload,
    }));
  }

  async waitForAccept(signal: AbortSignal): Promise<Job<J>[]> {
    const broadcast = new BroadcastChannel(
      'idb-job-queue:' + this.databaseName
    );

    try {
      let deferred = Promise.withResolvers<void>();

      broadcast.onmessage = () => {
        deferred.resolve();
      };

      while (throwIfAborted(signal)) {
        const jobs = await this.accept();
        if (jobs !== null) {
          return jobs;
        }

        await Promise.race([
          deferred.promise,
          new Promise(resolve => {
            setTimeout(resolve, 5000);
          }),
          new Promise((_, reject) => {
            // exit if manually stopped
            if (signal?.aborted) {
              reject(signal.reason);
            }
            signal?.addEventListener('abort', () => {
              reject(signal.reason);
            });
          }),
        ]);
        deferred = Promise.withResolvers<void>();
      }
      return [];
    } finally {
      broadcast.close();
    }
  }

  async complete(jobs: Job[]): Promise<void> {
    await this.ensureInitialized();
    const trx = this.database.transaction(['jobs'], 'readwrite');

    for (const { id } of jobs) {
      await trx
        .objectStore('jobs')
        .delete(typeof id === 'string' ? parseInt(id) : id);
    }

    trx.commit();
    this.broadcast.postMessage('job-completed');
  }

  async return(jobs: Job[], retry: boolean = false): Promise<void> {
    await this.ensureInitialized();
    const trx = this.database.transaction(['jobs'], 'readwrite');

    for (const { id } of jobs) {
      if (retry) {
        const nid = typeof id === 'string' ? parseInt(id) : id;
        const job = await trx.objectStore('jobs').get(nid);
        if (job) {
          await trx.objectStore('jobs').put({ ...job, startTime: null }, nid);
        }
      } else {
        await trx
          .objectStore('jobs')
          .delete(typeof id === 'string' ? parseInt(id) : id);
      }
    }

    trx.commit();

    this.broadcast.postMessage('job-completed');
  }

  async clear(): Promise<void> {
    await this.ensureInitialized();
    const trx = this.database.transaction(['jobs'], 'readwrite');
    await trx.objectStore('jobs').clear();
  }

  private async ensureInitialized(): Promise<void> {
    if (!this.database) {
      await this.initialize();
    }
  }

  private async initialize(): Promise<void> {
    if (this.database) {
      return;
    }
    this.database = await openDB(this.databaseName, 1, {
      upgrade(database) {
        const jobs = database.createObjectStore('jobs', {
          autoIncrement: true,
        });
        jobs.createIndex('batchKey', 'batchKey');
      },
    });
  }

  TIMEOUT = 1000 * 30 /* 30 seconds */;

  private isTimeout(job: JobRecord) {
    return job.startTime !== null && job.startTime + this.TIMEOUT < Date.now();
  }

  private isAcceptable(job: JobRecord) {
    return job.startTime === null || this.isTimeout(job);
  }

  get status$() {
    return merge(
      of(1),
      new Observable(subscriber => {
        const broadcast = new BroadcastChannel(
          'idb-job-queue:' + this.databaseName
        );

        broadcast.onmessage = () => {
          subscriber.next(1);
        };
        return () => {
          broadcast.close();
        };
      })
    ).pipe(
      throttleTime(300, undefined, { leading: true, trailing: true }),
      exhaustMapWithTrailing(() =>
        fromPromise(async () => {
          const trx = this.database.transaction(['jobs'], 'readonly');
          const remaining = await trx.objectStore('jobs').count();
          return { remaining };
        })
      )
    );
  }
}
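An aside (not part of the diff): `status$` above exposes queue progress as an RxJS observable, re-counting the remaining jobs on every broadcast, throttled to 300 ms. A minimal subscription sketch, assuming an existing `queue` instance:

```ts
// Sketch: observe how many jobs remain in the queue.
const subscription = queue.status$.subscribe(({ remaining }) => {
  console.log(`${remaining} job(s) remaining`);
});

// Later, stop listening (this also closes the underlying BroadcastChannel):
subscription.unsubscribe();
```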
2
packages/common/infra/src/sync/job/index.ts
Normal file
@@ -0,0 +1,2 @@
export * from './queue';
export * from './runner';
Some files were not shown because too many files have changed in this diff.