+ <>
+
+
{
@@ -80,13 +72,28 @@ const DetailContentImpl = () => {
}}
/>
+ >
+ );
+};
+
+const DetailContentImpl = () => {
+ const { conversationAtom } = useChatAtoms();
+ const conversations = useAtomValue(conversationAtom);
+
+ return (
+
);
};
diff --git a/plugins/copilot/src/UI/header-item.tsx b/plugins/copilot/src/UI/header-item.tsx
index 054573707c..c6f32b7ee1 100644
--- a/plugins/copilot/src/UI/header-item.tsx
+++ b/plugins/copilot/src/UI/header-item.tsx
@@ -19,7 +19,7 @@ export const HeaderItem: PluginUIAdapter['headerItem'] = ({
direction: 'horizontal',
first: 'editor',
second: 'com.affine.copilot',
- splitPercentage: 80,
+ splitPercentage: 70,
};
} else {
return 'editor';
diff --git a/plugins/copilot/src/UI/index.css.ts b/plugins/copilot/src/UI/index.css.ts
new file mode 100644
index 0000000000..53d3ae65c0
--- /dev/null
+++ b/plugins/copilot/src/UI/index.css.ts
@@ -0,0 +1,15 @@
+import { style } from '@vanilla-extract/css';
+
+export const detailContentStyle = style({
+ backgroundColor: 'rgba(0, 0, 0, 0.04)',
+ height: '100%',
+ display: 'flex',
+ flexDirection: 'column',
+
+ overflow: 'auto',
+
+ paddingLeft: '9px',
+ paddingRight: '9px',
+});
+
+export const detailContentActionsStyle = style({});
diff --git a/plugins/copilot/src/core/chat.ts b/plugins/copilot/src/core/chat.ts
index 3142c94c23..555c630ab3 100644
--- a/plugins/copilot/src/core/chat.ts
+++ b/plugins/copilot/src/core/chat.ts
@@ -1,16 +1,17 @@
-import { ConversationChain } from 'langchain/chains';
+import { ConversationChain, LLMChain } from 'langchain/chains';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { BufferMemory } from 'langchain/memory';
import {
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
+ PromptTemplate,
SystemMessagePromptTemplate,
} from 'langchain/prompts';
import { type LLMResult } from 'langchain/schema';
import { IndexedDBChatMessageHistory } from './langchain/message-history';
-import { chatPrompt } from './prompts';
+import { chatPrompt, followupQuestionPrompt } from './prompts';
declare global {
interface WindowEventMap {
@@ -22,13 +23,24 @@ declare global {
export async function createChatAI(
room: string,
openAIApiKey: string
-): Promise
{
+): Promise<{
+ conversationChain: ConversationChain;
+ followupChain: LLMChain;
+ chatHistory: IndexedDBChatMessageHistory;
+}> {
if (!openAIApiKey) {
console.warn('OpenAI API key not set, chat will not work');
}
+ const followup = new ChatOpenAI({
+ streaming: false,
+ modelName: 'gpt-3.5-turbo',
+ temperature: 0.5,
+ openAIApiKey: openAIApiKey,
+ });
+
const chat = new ChatOpenAI({
streaming: true,
- modelName: 'gpt-4',
+ modelName: 'gpt-3.5-turbo',
temperature: 0.5,
openAIApiKey: openAIApiKey,
callbacks: [
@@ -77,13 +89,32 @@ export async function createChatAI(
HumanMessagePromptTemplate.fromTemplate('{input}'),
]);
- return new ConversationChain({
+ const followupPromptTemplate = new PromptTemplate({
+ template: followupQuestionPrompt,
+ inputVariables: ['human_conversation', 'ai_conversation'],
+ });
+
+ const followupChain = new LLMChain({
+ llm: followup,
+ prompt: followupPromptTemplate,
+ memory: undefined,
+ });
+
+ const chatHistory = new IndexedDBChatMessageHistory(room);
+
+ const conversationChain = new ConversationChain({
memory: new BufferMemory({
returnMessages: true,
memoryKey: 'history',
- chatHistory: new IndexedDBChatMessageHistory(room),
+ chatHistory,
}),
prompt: chatPromptTemplate,
llm: chat,
});
+
+ return {
+ conversationChain,
+ followupChain,
+ chatHistory,
+ } as const;
}
diff --git a/plugins/copilot/src/core/components/conversation-list/index.css.ts b/plugins/copilot/src/core/components/conversation-list/index.css.ts
new file mode 100644
index 0000000000..02c837150c
--- /dev/null
+++ b/plugins/copilot/src/core/components/conversation-list/index.css.ts
@@ -0,0 +1,7 @@
+import { style } from '@vanilla-extract/css';
+
+export const conversationListStyle = style({
+ display: 'flex',
+ flexDirection: 'column',
+ gap: '30px',
+});
diff --git a/plugins/copilot/src/core/components/conversation-list/index.tsx b/plugins/copilot/src/core/components/conversation-list/index.tsx
new file mode 100644
index 0000000000..c3bbaa1bb7
--- /dev/null
+++ b/plugins/copilot/src/core/components/conversation-list/index.tsx
@@ -0,0 +1,22 @@
+import type { BaseChatMessage } from 'langchain/schema';
+
+import { Conversation } from '../conversation';
+import { conversationListStyle } from './index.css';
+
+export type ConversationListProps = {
+ conversations: BaseChatMessage[];
+};
+
+export const ConversationList = (props: ConversationListProps) => {
+ return (
+
+ {props.conversations.map((conversation, idx) => (
+
+ ))}
+
+ );
+};
diff --git a/plugins/copilot/src/core/components/conversation.tsx b/plugins/copilot/src/core/components/conversation.tsx
deleted file mode 100644
index 18411604bf..0000000000
--- a/plugins/copilot/src/core/components/conversation.tsx
+++ /dev/null
@@ -1,19 +0,0 @@
-import { marked } from 'marked';
-import { type ReactElement, useMemo } from 'react';
-
-export interface ConversationProps {
- text: string;
-}
-
-export const Conversation = (props: ConversationProps): ReactElement => {
- const html = useMemo(() => marked.parse(props.text), [props.text]);
- return (
-
- );
-};
diff --git a/plugins/copilot/src/core/components/conversation/index.css.ts b/plugins/copilot/src/core/components/conversation/index.css.ts
new file mode 100644
index 0000000000..cb09dbf486
--- /dev/null
+++ b/plugins/copilot/src/core/components/conversation/index.css.ts
@@ -0,0 +1,15 @@
+import { style } from '@vanilla-extract/css';
+
+export const conversationStyle = style({
+ padding: '10px 18px',
+});
+
+export const aiMessageStyle = style({
+ backgroundColor: 'rgba(207, 252, 255, 0.3)',
+ borderRadius: '18px 18px 18px 2px',
+});
+
+export const humanMessageStyle = style({
+ borderRadius: '18px 18px 2px 18px',
+ backgroundColor: 'white',
+});
diff --git a/plugins/copilot/src/core/components/conversation/index.tsx b/plugins/copilot/src/core/components/conversation/index.tsx
new file mode 100644
index 0000000000..17c8d312d2
--- /dev/null
+++ b/plugins/copilot/src/core/components/conversation/index.tsx
@@ -0,0 +1,42 @@
+import { clsx } from 'clsx';
+import type { MessageType } from 'langchain/schema';
+import { marked } from 'marked';
+import { gfmHeadingId } from 'marked-gfm-heading-id';
+// eslint-disable-next-line @typescript-eslint/ban-ts-comment
+// @ts-expect-error
+import { mangle } from 'marked-mangle';
+import { type ReactElement, useMemo } from 'react';
+
+import {
+ aiMessageStyle,
+ conversationStyle,
+ humanMessageStyle,
+} from './index.css';
+
+marked.use(
+ gfmHeadingId({
+ prefix: 'affine-',
+ })
+);
+
+marked.use(mangle());
+
+export interface ConversationProps {
+ type: MessageType;
+ text: string;
+}
+
+export const Conversation = (props: ConversationProps): ReactElement => {
+ const html = useMemo(() => marked.parse(props.text), [props.text]);
+ return (
+
+ );
+};
diff --git a/plugins/copilot/src/core/components/following-up/index.css.ts b/plugins/copilot/src/core/components/following-up/index.css.ts
new file mode 100644
index 0000000000..0205999766
--- /dev/null
+++ b/plugins/copilot/src/core/components/following-up/index.css.ts
@@ -0,0 +1,17 @@
+import { style } from '@vanilla-extract/css';
+
+export const followingUpStyle = style({
+ display: 'flex',
+ flexDirection: 'row',
+ flexWrap: 'wrap',
+ gap: '10px',
+ alignItems: 'flex-start',
+});
+
+export const questionStyle = style({
+ backgroundColor: 'white',
+ borderRadius: '8px',
+ color: '#8E8D91',
+ padding: '2px 10px',
+ cursor: 'pointer',
+});
diff --git a/plugins/copilot/src/core/components/following-up/index.tsx b/plugins/copilot/src/core/components/following-up/index.tsx
new file mode 100644
index 0000000000..10afe03c76
--- /dev/null
+++ b/plugins/copilot/src/core/components/following-up/index.tsx
@@ -0,0 +1,19 @@
+import type { ReactElement } from 'react';
+
+import { followingUpStyle, questionStyle } from './index.css';
+
+export type FollowingUpProps = {
+ questions: string[];
+};
+
+export const FollowingUp = (props: FollowingUpProps): ReactElement => {
+ return (
+
+ {props.questions.map((question, index) => (
+
+ {question}
+
+ ))}
+
+ );
+};
diff --git a/plugins/copilot/src/core/hooks/index.ts b/plugins/copilot/src/core/hooks/index.ts
index 81fbfeb5ca..02772d6e1c 100644
--- a/plugins/copilot/src/core/hooks/index.ts
+++ b/plugins/copilot/src/core/hooks/index.ts
@@ -1,6 +1,9 @@
+import type { IndexedDBChatMessageHistory } from '@affine/copilot/core/langchain/message-history';
import { atom, useAtomValue } from 'jotai';
-import { atomFamily } from 'jotai/utils';
+import { atomWithDefault } from 'jotai/utils';
import { atomWithStorage } from 'jotai/utils';
+import type { WritableAtom } from 'jotai/vanilla';
+import type { LLMChain } from 'langchain/chains';
import { type ConversationChain } from 'langchain/chains';
import { type BufferMemory } from 'langchain/memory';
import {
@@ -8,9 +11,12 @@ import {
type BaseChatMessage,
HumanChatMessage,
} from 'langchain/schema';
+import { z } from 'zod';
import { createChatAI } from '../chat';
+const followupResponseSchema = z.array(z.string());
+
export const openAIApiKeyAtom = atomWithStorage(
'com.affine.copilot.openai.token',
null
@@ -19,12 +25,24 @@ export const openAIApiKeyAtom = atomWithStorage(
export const chatAtom = atom(async get => {
const openAIApiKey = get(openAIApiKeyAtom);
if (!openAIApiKey) {
- return null;
+ throw new Error('OpenAI API key not set, chat will not work');
}
return createChatAI('default-copilot', openAIApiKey);
});
-const conversationAtomFamily = atomFamily((chat: ConversationChain | null) => {
+const conversationWeakMap = new WeakMap<
+ ConversationChain,
+ WritableAtom>
+>();
+
+const getConversationAtom = (chat: ConversationChain) => {
+ if (conversationWeakMap.has(chat)) {
+ return conversationWeakMap.get(chat) as WritableAtom<
+ BaseChatMessage[],
+ [string],
+ Promise
+ >;
+ }
const conversationBaseAtom = atom([]);
conversationBaseAtom.onMount = setAtom => {
if (!chat) {
@@ -52,7 +70,7 @@ const conversationAtomFamily = atomFamily((chat: ConversationChain | null) => {
};
};
- return atom>(
+ const conversationAtom = atom>(
get => get(conversationBaseAtom),
async (get, set, input) => {
if (!chat) {
@@ -73,14 +91,75 @@ const conversationAtomFamily = atomFamily((chat: ConversationChain | null) => {
});
}
);
-});
+ conversationWeakMap.set(chat, conversationAtom);
+ return conversationAtom;
+};
+
+const followingUpWeakMap = new WeakMap<
+ LLMChain,
+ {
+ questionsAtom: ReturnType>>;
+ generateChatAtom: WritableAtom;
+ }
+>();
+
+const getFollowingUpAtoms = (
+ followupLLMChain: LLMChain,
+ chatHistory: IndexedDBChatMessageHistory
+) => {
+ if (followingUpWeakMap.has(followupLLMChain)) {
+ return followingUpWeakMap.get(followupLLMChain) as {
+ questionsAtom: ReturnType>>;
+ generateChatAtom: WritableAtom;
+ };
+ }
+ const baseAtom = atomWithDefault>(async () => {
+ return chatHistory?.getFollowingUp() ?? [];
+ });
+ const setAtom = atom(null, async (get, set) => {
+ if (!followupLLMChain || !chatHistory) {
+ throw new Error('followupLLMChain not set');
+ }
+ const messages = await chatHistory.getMessages();
+ const aiMessage = messages.findLast(
+ message => message._getType() === 'ai'
+ )?.text;
+ const humanMessage = messages.findLast(
+ message => message._getType() === 'human'
+ )?.text;
+ const response = await followupLLMChain.call({
+ ai_conversation: aiMessage,
+ human_conversation: humanMessage,
+ });
+ const followingUp = JSON.parse(response.text);
+ followupResponseSchema.parse(followingUp);
+ set(baseAtom, followingUp);
+ chatHistory.saveFollowingUp(followingUp).catch(() => {
+ console.error('failed to save followup');
+ });
+ });
+ followingUpWeakMap.set(followupLLMChain, {
+ questionsAtom: baseAtom,
+ generateChatAtom: setAtom,
+ });
+ return {
+ questionsAtom: baseAtom,
+ generateChatAtom: setAtom,
+ };
+};
export function useChatAtoms(): {
- conversationAtom: ReturnType;
+ conversationAtom: ReturnType;
+ followingUpAtoms: ReturnType;
} {
const chat = useAtomValue(chatAtom);
- const conversationAtom = conversationAtomFamily(chat);
+ const conversationAtom = getConversationAtom(chat.conversationChain);
+ const followingUpAtoms = getFollowingUpAtoms(
+ chat.followupChain,
+ chat.chatHistory
+ );
return {
conversationAtom,
+ followingUpAtoms,
};
}
diff --git a/plugins/copilot/src/core/langchain/message-history.ts b/plugins/copilot/src/core/langchain/message-history.ts
index cc8141ceb8..072f96afac 100644
--- a/plugins/copilot/src/core/langchain/message-history.ts
+++ b/plugins/copilot/src/core/langchain/message-history.ts
@@ -23,25 +23,42 @@ interface ChatMessageDBV1 extends DBSchema {
};
}
+interface ChatMessageDBV2 extends ChatMessageDBV1 {
+ followingUp: {
+ key: string;
+ value: {
+ /**
+ * ID of the chat
+ */
+ id: string;
+ question: string[];
+ };
+ };
+}
+
export const conversationHistoryDBName = 'affine-copilot-chat';
export class IndexedDBChatMessageHistory extends BaseChatMessageHistory {
public id: string;
private messages: BaseChatMessage[] = [];
- private readonly dbPromise: Promise>;
+ private readonly dbPromise: Promise>;
private readonly initPromise: Promise;
constructor(id: string) {
super();
this.id = id;
this.messages = [];
- this.dbPromise = openDB('affine-copilot-chat', 1, {
+ this.dbPromise = openDB('affine-copilot-chat', 2, {
upgrade(database, oldVersion) {
if (oldVersion === 0) {
database.createObjectStore('chat', {
keyPath: 'id',
});
+ } else if (oldVersion === 1) {
+ database.createObjectStore('followingUp', {
+ keyPath: 'id',
+ });
}
},
});
@@ -70,6 +87,31 @@ export class IndexedDBChatMessageHistory extends BaseChatMessageHistory {
});
}
+ public async saveFollowingUp(question: string[]): Promise {
+ await this.initPromise;
+ const db = await this.dbPromise;
+ const t = db
+ .transaction('followingUp', 'readwrite')
+ .objectStore('followingUp');
+ await t.put({
+ id: this.id,
+ question,
+ });
+ }
+
+ public async getFollowingUp(): Promise {
+ await this.initPromise;
+ const db = await this.dbPromise;
+ const t = db
+ .transaction('followingUp', 'readonly')
+ .objectStore('followingUp');
+ const chat = await t.get(this.id);
+ if (chat != null) {
+ return chat.question;
+ }
+ return [];
+ }
+
protected async addMessage(message: BaseChatMessage): Promise {
await this.initPromise;
this.messages.push(message);
@@ -104,6 +146,6 @@ export class IndexedDBChatMessageHistory extends BaseChatMessageHistory {
}
async getMessages(): Promise {
- return await this.initPromise.then(() => this.messages);
+ return this.initPromise.then(() => this.messages);
}
}
diff --git a/plugins/copilot/src/core/prompts/index.ts b/plugins/copilot/src/core/prompts/index.ts
index efc53a36ad..0e745735c7 100644
--- a/plugins/copilot/src/core/prompts/index.ts
+++ b/plugins/copilot/src/core/prompts/index.ts
@@ -14,8 +14,19 @@ Keep your answers short and impersonal.
The user works in an app called AFFiNE, which has a concept for an editor, a page for a single document, workspace for a collection of documents.
The active document is the markdown file the user is looking at.
Use Markdown formatting in your answers.
-Wrap your answers into triple backticks.
You can only give one reply for each conversation turn.
-You should always generate short suggestions for the next user turns that are relevant to the conversation and not offensive.
-You should reply to the users within 150 characters.
+`;
+
+export const followupQuestionPrompt = `Rules you must follow:
+- You only respond in JSON format
+- Read the following conversation between AI and Human and generate at most 3 follow-up messages or questions the Human can ask
+- Your response MUST be a valid JSON array of strings like this: ["some question", "another question"]
+- Each message in your response should be concise, no more than 15 words
+- You MUST reply in the same written language as the conversation
+- Don't output anything other text
+The conversation is inside triple quotes:
+\`\`\`
+Human: {human_conversation}
+AI: {ai_conversation}
+\`\`\`
`;
diff --git a/yarn.lock b/yarn.lock
index ea9ae325a9..add4ab6e56 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -130,10 +130,13 @@ __metadata:
"@types/react-dom": ^18.2.4
idb: ^7.1.1
jotai: ^2.1.1
- langchain: ^0.0.90
+ langchain: ^0.0.92
marked: ^5.0.4
+ marked-gfm-heading-id: ^3.0.4
+ marked-mangle: ^1.0.1
react: 18.3.0-canary-16d053d59-20230506
react-dom: 18.3.0-canary-16d053d59-20230506
+ zod: ^3.21.4
peerDependencies:
react: "*"
react-dom: "*"
@@ -300,6 +303,7 @@ __metadata:
husky: ^8.0.3
lint-staged: ^13.2.2
madge: ^6.1.0
+ marked-mangle: ^1.0.1
msw: ^1.2.1
nanoid: ^4.0.2
nx: 16.3.2
@@ -16830,6 +16834,13 @@ __metadata:
languageName: node
linkType: hard
+"github-slugger@npm:^2.0.0":
+ version: 2.0.0
+ resolution: "github-slugger@npm:2.0.0"
+ checksum: 250375cde2058f21454872c2c79f72c4637340c30c51ff158ca4ec71cbc478f33d54477d787a662f9207aeb095a2060f155bc01f15329ba8a5fb6698e0fc81f8
+ languageName: node
+ linkType: hard
+
"glob-parent@npm:^5.1.2, glob-parent@npm:~5.1.2":
version: 5.1.2
resolution: "glob-parent@npm:5.1.2"
@@ -19787,9 +19798,9 @@ __metadata:
languageName: node
linkType: hard
-"langchain@npm:^0.0.90":
- version: 0.0.90
- resolution: "langchain@npm:0.0.90"
+"langchain@npm:^0.0.92":
+ version: 0.0.92
+ resolution: "langchain@npm:0.0.92"
dependencies:
"@anthropic-ai/sdk": ^0.4.3
ansi-styles: ^5.0.0
@@ -19835,7 +19846,7 @@ __metadata:
cohere-ai: ^5.0.2
d3-dsv: ^2.0.0
epub2: ^3.0.1
- faiss-node: ^0.2.0
+ faiss-node: ^0.2.1
google-auth-library: ^8.8.0
hnswlib-node: ^1.4.2
html-to-text: ^9.0.5
@@ -19947,7 +19958,7 @@ __metadata:
optional: true
weaviate-ts-client:
optional: true
- checksum: cbd35a7397b30ad265bf4ce992818ee03a409c33dde7a7a1e17c76e7e875768dbc79ff06618e42b3c688081ff8b0ff5eb34b532a2994d078578de67937a6452b
+ checksum: a9fb6cdf48e971f9327bcf5eb2322eb9b9add63413ecacd90bb064feab40d3ba6024086ccade59c0285dd57249406dda9ba7847cdb02e08312e147618165b49e
languageName: node
linkType: hard
@@ -20760,6 +20771,26 @@ __metadata:
languageName: node
linkType: hard
+"marked-gfm-heading-id@npm:^3.0.4":
+ version: 3.0.4
+ resolution: "marked-gfm-heading-id@npm:3.0.4"
+ dependencies:
+ github-slugger: ^2.0.0
+ peerDependencies:
+ marked: ^4 || ^5
+ checksum: 8c65d0fe0f59291d6f4529f49a4af00cad5787a72a59b9397f82d6ab58b1affbf53678f7efdca77fa75c32e80bbdab1c53c2aa5d7ece3e821d2fe35a76fb4930
+ languageName: node
+ linkType: hard
+
+"marked-mangle@npm:^1.0.1":
+ version: 1.0.1
+ resolution: "marked-mangle@npm:1.0.1"
+ peerDependencies:
+ marked: ^4 || ^5
+ checksum: 7ebdd69eb41907a865141e7cce1cd8c5c24dd6955c63aaf5299710aba2935abdfbca1b1ea0ff30cf527ff20a940c3b5a9311fca9abb84f4d294767b0435213fe
+ languageName: node
+ linkType: hard
+
"marked@npm:^4.2.12":
version: 4.3.0
resolution: "marked@npm:4.3.0"