mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-12 04:18:54 +00:00
feat: improve copilot (#2758)
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
import { Button, Input } from '@affine/component';
|
||||
import { rootStore } from '@toeverything/plugin-infra/manager';
|
||||
import type { PluginUIAdapter } from '@toeverything/plugin-infra/type';
|
||||
import { Provider, useAtom, useAtomValue, useSetAtom } from 'jotai';
|
||||
import { Provider, useAtomValue, useSetAtom } from 'jotai';
|
||||
import type { ReactElement } from 'react';
|
||||
import { Fragment, StrictMode, useState } from 'react';
|
||||
import { StrictMode, Suspense, useCallback, useState } from 'react';
|
||||
import { createRoot } from 'react-dom/client';
|
||||
|
||||
import { Conversation } from '../core/components/conversation';
|
||||
import { Divider } from '../core/components/divider';
|
||||
import { ConversationList } from '../core/components/conversation-list';
|
||||
import { FollowingUp } from '../core/components/following-up';
|
||||
import { openAIApiKeyAtom, useChatAtoms } from '../core/hooks';
|
||||
import { detailContentActionsStyle, detailContentStyle } from './index.css';
|
||||
|
||||
if (typeof window === 'undefined') {
|
||||
import('@blocksuite/blocks').then(({ FormatQuickBar }) => {
|
||||
@@ -54,25 +55,16 @@ if (typeof window === 'undefined') {
|
||||
});
|
||||
}
|
||||
|
||||
const DetailContentImpl = () => {
|
||||
const Actions = () => {
|
||||
const { conversationAtom, followingUpAtoms } = useChatAtoms();
|
||||
const call = useSetAtom(conversationAtom);
|
||||
const questions = useAtomValue(followingUpAtoms.questionsAtom);
|
||||
const generateFollowingUp = useSetAtom(followingUpAtoms.generateChatAtom);
|
||||
const [input, setInput] = useState('');
|
||||
const { conversationAtom } = useChatAtoms();
|
||||
const [conversations, call] = useAtom(conversationAtom);
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
width: '300px',
|
||||
}}
|
||||
>
|
||||
{conversations.map((message, idx) => {
|
||||
return (
|
||||
<Fragment key={idx}>
|
||||
<Conversation text={message.text} />
|
||||
<Divider />
|
||||
</Fragment>
|
||||
);
|
||||
})}
|
||||
<div>
|
||||
<>
|
||||
<FollowingUp questions={questions} />
|
||||
<div className={detailContentActionsStyle}>
|
||||
<Input
|
||||
value={input}
|
||||
onChange={text => {
|
||||
@@ -80,13 +72,28 @@ const DetailContentImpl = () => {
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
onClick={() => {
|
||||
void call(input);
|
||||
}}
|
||||
onClick={useCallback(async () => {
|
||||
await call(input);
|
||||
await generateFollowingUp();
|
||||
}, [call, generateFollowingUp, input])}
|
||||
>
|
||||
send
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const DetailContentImpl = () => {
|
||||
const { conversationAtom } = useChatAtoms();
|
||||
const conversations = useAtomValue(conversationAtom);
|
||||
|
||||
return (
|
||||
<div className={detailContentStyle}>
|
||||
<ConversationList conversations={conversations} />
|
||||
<Suspense fallback="generating follow-up question">
|
||||
<Actions />
|
||||
</Suspense>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -19,7 +19,7 @@ export const HeaderItem: PluginUIAdapter['headerItem'] = ({
|
||||
direction: 'horizontal',
|
||||
first: 'editor',
|
||||
second: 'com.affine.copilot',
|
||||
splitPercentage: 80,
|
||||
splitPercentage: 70,
|
||||
};
|
||||
} else {
|
||||
return 'editor';
|
||||
|
||||
15
plugins/copilot/src/UI/index.css.ts
Normal file
15
plugins/copilot/src/UI/index.css.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import { style } from '@vanilla-extract/css';
|
||||
|
||||
export const detailContentStyle = style({
|
||||
backgroundColor: 'rgba(0, 0, 0, 0.04)',
|
||||
height: '100%',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
|
||||
overflow: 'auto',
|
||||
|
||||
paddingLeft: '9px',
|
||||
paddingRight: '9px',
|
||||
});
|
||||
|
||||
export const detailContentActionsStyle = style({});
|
||||
@@ -1,16 +1,17 @@
|
||||
import { ConversationChain } from 'langchain/chains';
|
||||
import { ConversationChain, LLMChain } from 'langchain/chains';
|
||||
import { ChatOpenAI } from 'langchain/chat_models/openai';
|
||||
import { BufferMemory } from 'langchain/memory';
|
||||
import {
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
PromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
} from 'langchain/prompts';
|
||||
import { type LLMResult } from 'langchain/schema';
|
||||
|
||||
import { IndexedDBChatMessageHistory } from './langchain/message-history';
|
||||
import { chatPrompt } from './prompts';
|
||||
import { chatPrompt, followupQuestionPrompt } from './prompts';
|
||||
|
||||
declare global {
|
||||
interface WindowEventMap {
|
||||
@@ -22,13 +23,24 @@ declare global {
|
||||
export async function createChatAI(
|
||||
room: string,
|
||||
openAIApiKey: string
|
||||
): Promise<ConversationChain> {
|
||||
): Promise<{
|
||||
conversationChain: ConversationChain;
|
||||
followupChain: LLMChain<string>;
|
||||
chatHistory: IndexedDBChatMessageHistory;
|
||||
}> {
|
||||
if (!openAIApiKey) {
|
||||
console.warn('OpenAI API key not set, chat will not work');
|
||||
}
|
||||
const followup = new ChatOpenAI({
|
||||
streaming: false,
|
||||
modelName: 'gpt-3.5-turbo',
|
||||
temperature: 0.5,
|
||||
openAIApiKey: openAIApiKey,
|
||||
});
|
||||
|
||||
const chat = new ChatOpenAI({
|
||||
streaming: true,
|
||||
modelName: 'gpt-4',
|
||||
modelName: 'gpt-3.5-turbo',
|
||||
temperature: 0.5,
|
||||
openAIApiKey: openAIApiKey,
|
||||
callbacks: [
|
||||
@@ -77,13 +89,32 @@ export async function createChatAI(
|
||||
HumanMessagePromptTemplate.fromTemplate('{input}'),
|
||||
]);
|
||||
|
||||
return new ConversationChain({
|
||||
const followupPromptTemplate = new PromptTemplate({
|
||||
template: followupQuestionPrompt,
|
||||
inputVariables: ['human_conversation', 'ai_conversation'],
|
||||
});
|
||||
|
||||
const followupChain = new LLMChain({
|
||||
llm: followup,
|
||||
prompt: followupPromptTemplate,
|
||||
memory: undefined,
|
||||
});
|
||||
|
||||
const chatHistory = new IndexedDBChatMessageHistory(room);
|
||||
|
||||
const conversationChain = new ConversationChain({
|
||||
memory: new BufferMemory({
|
||||
returnMessages: true,
|
||||
memoryKey: 'history',
|
||||
chatHistory: new IndexedDBChatMessageHistory(room),
|
||||
chatHistory,
|
||||
}),
|
||||
prompt: chatPromptTemplate,
|
||||
llm: chat,
|
||||
});
|
||||
|
||||
return {
|
||||
conversationChain,
|
||||
followupChain,
|
||||
chatHistory,
|
||||
} as const;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
import { style } from '@vanilla-extract/css';
|
||||
|
||||
export const conversationListStyle = style({
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: '30px',
|
||||
});
|
||||
@@ -0,0 +1,22 @@
|
||||
import type { BaseChatMessage } from 'langchain/schema';
|
||||
|
||||
import { Conversation } from '../conversation';
|
||||
import { conversationListStyle } from './index.css';
|
||||
|
||||
export type ConversationListProps = {
|
||||
conversations: BaseChatMessage[];
|
||||
};
|
||||
|
||||
export const ConversationList = (props: ConversationListProps) => {
|
||||
return (
|
||||
<div className={conversationListStyle}>
|
||||
{props.conversations.map((conversation, idx) => (
|
||||
<Conversation
|
||||
type={conversation._getType()}
|
||||
text={conversation.text}
|
||||
key={idx}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,19 +0,0 @@
|
||||
import { marked } from 'marked';
|
||||
import { type ReactElement, useMemo } from 'react';
|
||||
|
||||
export interface ConversationProps {
|
||||
text: string;
|
||||
}
|
||||
|
||||
export const Conversation = (props: ConversationProps): ReactElement => {
|
||||
const html = useMemo(() => marked.parse(props.text), [props.text]);
|
||||
return (
|
||||
<div>
|
||||
<div
|
||||
dangerouslySetInnerHTML={{
|
||||
__html: html,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,15 @@
|
||||
import { style } from '@vanilla-extract/css';
|
||||
|
||||
export const conversationStyle = style({
|
||||
padding: '10px 18px',
|
||||
});
|
||||
|
||||
export const aiMessageStyle = style({
|
||||
backgroundColor: 'rgba(207, 252, 255, 0.3)',
|
||||
borderRadius: '18px 18px 18px 2px',
|
||||
});
|
||||
|
||||
export const humanMessageStyle = style({
|
||||
borderRadius: '18px 18px 2px 18px',
|
||||
backgroundColor: 'white',
|
||||
});
|
||||
42
plugins/copilot/src/core/components/conversation/index.tsx
Normal file
42
plugins/copilot/src/core/components/conversation/index.tsx
Normal file
@@ -0,0 +1,42 @@
|
||||
import { clsx } from 'clsx';
|
||||
import type { MessageType } from 'langchain/schema';
|
||||
import { marked } from 'marked';
|
||||
import { gfmHeadingId } from 'marked-gfm-heading-id';
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-expect-error
|
||||
import { mangle } from 'marked-mangle';
|
||||
import { type ReactElement, useMemo } from 'react';
|
||||
|
||||
import {
|
||||
aiMessageStyle,
|
||||
conversationStyle,
|
||||
humanMessageStyle,
|
||||
} from './index.css';
|
||||
|
||||
marked.use(
|
||||
gfmHeadingId({
|
||||
prefix: 'affine-',
|
||||
})
|
||||
);
|
||||
|
||||
marked.use(mangle());
|
||||
|
||||
export interface ConversationProps {
|
||||
type: MessageType;
|
||||
text: string;
|
||||
}
|
||||
|
||||
export const Conversation = (props: ConversationProps): ReactElement => {
|
||||
const html = useMemo(() => marked.parse(props.text), [props.text]);
|
||||
return (
|
||||
<div
|
||||
className={clsx(conversationStyle, {
|
||||
[aiMessageStyle]: props.type === 'ai',
|
||||
[humanMessageStyle]: props.type === 'human',
|
||||
})}
|
||||
dangerouslySetInnerHTML={{
|
||||
__html: html,
|
||||
}}
|
||||
></div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,17 @@
|
||||
import { style } from '@vanilla-extract/css';
|
||||
|
||||
export const followingUpStyle = style({
|
||||
display: 'flex',
|
||||
flexDirection: 'row',
|
||||
flexWrap: 'wrap',
|
||||
gap: '10px',
|
||||
alignItems: 'flex-start',
|
||||
});
|
||||
|
||||
export const questionStyle = style({
|
||||
backgroundColor: 'white',
|
||||
borderRadius: '8px',
|
||||
color: '#8E8D91',
|
||||
padding: '2px 10px',
|
||||
cursor: 'pointer',
|
||||
});
|
||||
19
plugins/copilot/src/core/components/following-up/index.tsx
Normal file
19
plugins/copilot/src/core/components/following-up/index.tsx
Normal file
@@ -0,0 +1,19 @@
|
||||
import type { ReactElement } from 'react';
|
||||
|
||||
import { followingUpStyle, questionStyle } from './index.css';
|
||||
|
||||
export type FollowingUpProps = {
|
||||
questions: string[];
|
||||
};
|
||||
|
||||
export const FollowingUp = (props: FollowingUpProps): ReactElement => {
|
||||
return (
|
||||
<div className={followingUpStyle}>
|
||||
{props.questions.map((question, index) => (
|
||||
<div className={questionStyle} key={index}>
|
||||
{question}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,6 +1,9 @@
|
||||
import type { IndexedDBChatMessageHistory } from '@affine/copilot/core/langchain/message-history';
|
||||
import { atom, useAtomValue } from 'jotai';
|
||||
import { atomFamily } from 'jotai/utils';
|
||||
import { atomWithDefault } from 'jotai/utils';
|
||||
import { atomWithStorage } from 'jotai/utils';
|
||||
import type { WritableAtom } from 'jotai/vanilla';
|
||||
import type { LLMChain } from 'langchain/chains';
|
||||
import { type ConversationChain } from 'langchain/chains';
|
||||
import { type BufferMemory } from 'langchain/memory';
|
||||
import {
|
||||
@@ -8,9 +11,12 @@ import {
|
||||
type BaseChatMessage,
|
||||
HumanChatMessage,
|
||||
} from 'langchain/schema';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { createChatAI } from '../chat';
|
||||
|
||||
const followupResponseSchema = z.array(z.string());
|
||||
|
||||
export const openAIApiKeyAtom = atomWithStorage<string | null>(
|
||||
'com.affine.copilot.openai.token',
|
||||
null
|
||||
@@ -19,12 +25,24 @@ export const openAIApiKeyAtom = atomWithStorage<string | null>(
|
||||
export const chatAtom = atom(async get => {
|
||||
const openAIApiKey = get(openAIApiKeyAtom);
|
||||
if (!openAIApiKey) {
|
||||
return null;
|
||||
throw new Error('OpenAI API key not set, chat will not work');
|
||||
}
|
||||
return createChatAI('default-copilot', openAIApiKey);
|
||||
});
|
||||
|
||||
const conversationAtomFamily = atomFamily((chat: ConversationChain | null) => {
|
||||
const conversationWeakMap = new WeakMap<
|
||||
ConversationChain,
|
||||
WritableAtom<BaseChatMessage[], [string], Promise<void>>
|
||||
>();
|
||||
|
||||
const getConversationAtom = (chat: ConversationChain) => {
|
||||
if (conversationWeakMap.has(chat)) {
|
||||
return conversationWeakMap.get(chat) as WritableAtom<
|
||||
BaseChatMessage[],
|
||||
[string],
|
||||
Promise<void>
|
||||
>;
|
||||
}
|
||||
const conversationBaseAtom = atom<BaseChatMessage[]>([]);
|
||||
conversationBaseAtom.onMount = setAtom => {
|
||||
if (!chat) {
|
||||
@@ -52,7 +70,7 @@ const conversationAtomFamily = atomFamily((chat: ConversationChain | null) => {
|
||||
};
|
||||
};
|
||||
|
||||
return atom<BaseChatMessage[], [string], Promise<void>>(
|
||||
const conversationAtom = atom<BaseChatMessage[], [string], Promise<void>>(
|
||||
get => get(conversationBaseAtom),
|
||||
async (get, set, input) => {
|
||||
if (!chat) {
|
||||
@@ -73,14 +91,75 @@ const conversationAtomFamily = atomFamily((chat: ConversationChain | null) => {
|
||||
});
|
||||
}
|
||||
);
|
||||
});
|
||||
conversationWeakMap.set(chat, conversationAtom);
|
||||
return conversationAtom;
|
||||
};
|
||||
|
||||
const followingUpWeakMap = new WeakMap<
|
||||
LLMChain<string>,
|
||||
{
|
||||
questionsAtom: ReturnType<typeof atomWithDefault<Promise<string[]>>>;
|
||||
generateChatAtom: WritableAtom<null, [], void>;
|
||||
}
|
||||
>();
|
||||
|
||||
const getFollowingUpAtoms = (
|
||||
followupLLMChain: LLMChain<string>,
|
||||
chatHistory: IndexedDBChatMessageHistory
|
||||
) => {
|
||||
if (followingUpWeakMap.has(followupLLMChain)) {
|
||||
return followingUpWeakMap.get(followupLLMChain) as {
|
||||
questionsAtom: ReturnType<typeof atomWithDefault<Promise<string[]>>>;
|
||||
generateChatAtom: WritableAtom<null, [], void>;
|
||||
};
|
||||
}
|
||||
const baseAtom = atomWithDefault<Promise<string[]>>(async () => {
|
||||
return chatHistory?.getFollowingUp() ?? [];
|
||||
});
|
||||
const setAtom = atom<null, [], void>(null, async (get, set) => {
|
||||
if (!followupLLMChain || !chatHistory) {
|
||||
throw new Error('followupLLMChain not set');
|
||||
}
|
||||
const messages = await chatHistory.getMessages();
|
||||
const aiMessage = messages.findLast(
|
||||
message => message._getType() === 'ai'
|
||||
)?.text;
|
||||
const humanMessage = messages.findLast(
|
||||
message => message._getType() === 'human'
|
||||
)?.text;
|
||||
const response = await followupLLMChain.call({
|
||||
ai_conversation: aiMessage,
|
||||
human_conversation: humanMessage,
|
||||
});
|
||||
const followingUp = JSON.parse(response.text);
|
||||
followupResponseSchema.parse(followingUp);
|
||||
set(baseAtom, followingUp);
|
||||
chatHistory.saveFollowingUp(followingUp).catch(() => {
|
||||
console.error('failed to save followup');
|
||||
});
|
||||
});
|
||||
followingUpWeakMap.set(followupLLMChain, {
|
||||
questionsAtom: baseAtom,
|
||||
generateChatAtom: setAtom,
|
||||
});
|
||||
return {
|
||||
questionsAtom: baseAtom,
|
||||
generateChatAtom: setAtom,
|
||||
};
|
||||
};
|
||||
|
||||
export function useChatAtoms(): {
|
||||
conversationAtom: ReturnType<typeof conversationAtomFamily>;
|
||||
conversationAtom: ReturnType<typeof getConversationAtom>;
|
||||
followingUpAtoms: ReturnType<typeof getFollowingUpAtoms>;
|
||||
} {
|
||||
const chat = useAtomValue(chatAtom);
|
||||
const conversationAtom = conversationAtomFamily(chat);
|
||||
const conversationAtom = getConversationAtom(chat.conversationChain);
|
||||
const followingUpAtoms = getFollowingUpAtoms(
|
||||
chat.followupChain,
|
||||
chat.chatHistory
|
||||
);
|
||||
return {
|
||||
conversationAtom,
|
||||
followingUpAtoms,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -23,25 +23,42 @@ interface ChatMessageDBV1 extends DBSchema {
|
||||
};
|
||||
}
|
||||
|
||||
interface ChatMessageDBV2 extends ChatMessageDBV1 {
|
||||
followingUp: {
|
||||
key: string;
|
||||
value: {
|
||||
/**
|
||||
* ID of the chat
|
||||
*/
|
||||
id: string;
|
||||
question: string[];
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
export const conversationHistoryDBName = 'affine-copilot-chat';
|
||||
|
||||
export class IndexedDBChatMessageHistory extends BaseChatMessageHistory {
|
||||
public id: string;
|
||||
private messages: BaseChatMessage[] = [];
|
||||
|
||||
private readonly dbPromise: Promise<IDBPDatabase<ChatMessageDBV1>>;
|
||||
private readonly dbPromise: Promise<IDBPDatabase<ChatMessageDBV2>>;
|
||||
private readonly initPromise: Promise<void>;
|
||||
|
||||
constructor(id: string) {
|
||||
super();
|
||||
this.id = id;
|
||||
this.messages = [];
|
||||
this.dbPromise = openDB<ChatMessageDBV1>('affine-copilot-chat', 1, {
|
||||
this.dbPromise = openDB<ChatMessageDBV2>('affine-copilot-chat', 2, {
|
||||
upgrade(database, oldVersion) {
|
||||
if (oldVersion === 0) {
|
||||
database.createObjectStore('chat', {
|
||||
keyPath: 'id',
|
||||
});
|
||||
} else if (oldVersion === 1) {
|
||||
database.createObjectStore('followingUp', {
|
||||
keyPath: 'id',
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
@@ -70,6 +87,31 @@ export class IndexedDBChatMessageHistory extends BaseChatMessageHistory {
|
||||
});
|
||||
}
|
||||
|
||||
public async saveFollowingUp(question: string[]): Promise<void> {
|
||||
await this.initPromise;
|
||||
const db = await this.dbPromise;
|
||||
const t = db
|
||||
.transaction('followingUp', 'readwrite')
|
||||
.objectStore('followingUp');
|
||||
await t.put({
|
||||
id: this.id,
|
||||
question,
|
||||
});
|
||||
}
|
||||
|
||||
public async getFollowingUp(): Promise<string[]> {
|
||||
await this.initPromise;
|
||||
const db = await this.dbPromise;
|
||||
const t = db
|
||||
.transaction('followingUp', 'readonly')
|
||||
.objectStore('followingUp');
|
||||
const chat = await t.get(this.id);
|
||||
if (chat != null) {
|
||||
return chat.question;
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
protected async addMessage(message: BaseChatMessage): Promise<void> {
|
||||
await this.initPromise;
|
||||
this.messages.push(message);
|
||||
@@ -104,6 +146,6 @@ export class IndexedDBChatMessageHistory extends BaseChatMessageHistory {
|
||||
}
|
||||
|
||||
async getMessages(): Promise<BaseChatMessage[]> {
|
||||
return await this.initPromise.then(() => this.messages);
|
||||
return this.initPromise.then(() => this.messages);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,8 +14,19 @@ Keep your answers short and impersonal.
|
||||
The user works in an app called AFFiNE, which has a concept for an editor, a page for a single document, workspace for a collection of documents.
|
||||
The active document is the markdown file the user is looking at.
|
||||
Use Markdown formatting in your answers.
|
||||
Wrap your answers into triple backticks.
|
||||
You can only give one reply for each conversation turn.
|
||||
You should always generate short suggestions for the next user turns that are relevant to the conversation and not offensive.
|
||||
You should reply to the users within 150 characters.
|
||||
`;
|
||||
|
||||
export const followupQuestionPrompt = `Rules you must follow:
|
||||
- You only respond in JSON format
|
||||
- Read the following conversation between AI and Human and generate at most 3 follow-up messages or questions the Human can ask
|
||||
- Your response MUST be a valid JSON array of strings like this: ["some question", "another question"]
|
||||
- Each message in your response should be concise, no more than 15 words
|
||||
- You MUST reply in the same written language as the conversation
|
||||
- Don't output anything other text
|
||||
The conversation is inside triple quotes:
|
||||
\`\`\`
|
||||
Human: {human_conversation}
|
||||
AI: {ai_conversation}
|
||||
\`\`\`
|
||||
`;
|
||||
|
||||
Reference in New Issue
Block a user