mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-02-13 12:55:00 +00:00
feat(core): add optionalModels field in AiPrompt and support the front-end modelId param (#12224)
Close [AI-116](https://linear.app/affine-design/issue/AI-116) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added support for specifying alternative AI models in chat prompts, enabling users to select from multiple available models. - Expanded AI model options with new additions: 'gpt-4.1', 'o3', and 'claude-3-5-sonnet-20241022'. - **Enhancements** - Users can now optionally choose a specific AI model during chat interactions. - Prompts and chat sessions reflect and support selection of alternative models where applicable. - **Bug Fixes** - Improved handling of prompt configuration defaults for better reliability. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -62,7 +62,7 @@ export interface ChatEvent {
|
||||
}
|
||||
|
||||
type CheckResult = {
|
||||
model: string | undefined;
|
||||
model: string;
|
||||
hasAttachment?: boolean;
|
||||
};
|
||||
|
||||
@@ -94,7 +94,8 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
private async checkRequest(
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
messageId?: string
|
||||
messageId?: string,
|
||||
modelId?: string
|
||||
): Promise<CheckResult> {
|
||||
await this.chatSession.checkQuota(userId);
|
||||
const session = await this.chatSession.get(sessionId);
|
||||
@@ -102,7 +103,13 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
|
||||
const ret: CheckResult = { model: session.model };
|
||||
const ret: CheckResult = {
|
||||
model: session.model,
|
||||
};
|
||||
|
||||
if (modelId && session.optionalModels.includes(modelId)) {
|
||||
ret.model = modelId;
|
||||
}
|
||||
|
||||
if (messageId && typeof messageId === 'string') {
|
||||
const message = await session.getMessageById(messageId);
|
||||
@@ -116,13 +123,16 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
private async chooseTextProvider(
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
messageId?: string
|
||||
): Promise<CopilotTextProvider> {
|
||||
messageId?: string,
|
||||
modelId?: string
|
||||
): Promise<{ provider: CopilotTextProvider; model: string }> {
|
||||
const { hasAttachment, model } = await this.checkRequest(
|
||||
userId,
|
||||
sessionId,
|
||||
messageId
|
||||
messageId,
|
||||
modelId
|
||||
);
|
||||
|
||||
let provider = await this.provider.getProviderByCapability(
|
||||
CopilotCapability.TextToText,
|
||||
{ model }
|
||||
@@ -138,7 +148,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
throw new NoCopilotProviderAvailable();
|
||||
}
|
||||
|
||||
return provider;
|
||||
return { provider, model };
|
||||
}
|
||||
|
||||
private async appendSessionMessage(
|
||||
@@ -182,13 +192,17 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
const webSearch = Array.isArray(params.webSearch)
|
||||
? Boolean(params.webSearch[0])
|
||||
: Boolean(params.webSearch);
|
||||
const modelId = Array.isArray(params.modelId)
|
||||
? params.modelId[0]
|
||||
: params.modelId;
|
||||
|
||||
delete params.messageId;
|
||||
delete params.retry;
|
||||
delete params.reasoning;
|
||||
delete params.webSearch;
|
||||
delete params.modelId;
|
||||
|
||||
return { messageId, retry, reasoning, webSearch, params };
|
||||
return { messageId, retry, reasoning, webSearch, modelId, params };
|
||||
}
|
||||
|
||||
private getSignal(req: Request) {
|
||||
@@ -236,13 +250,14 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
const info: any = { sessionId, params };
|
||||
|
||||
try {
|
||||
const { messageId, retry, reasoning, webSearch } =
|
||||
const { messageId, retry, reasoning, webSearch, modelId } =
|
||||
this.prepareParams(params);
|
||||
|
||||
const provider = await this.chooseTextProvider(
|
||||
const { provider, model } = await this.chooseTextProvider(
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId
|
||||
messageId,
|
||||
modelId
|
||||
);
|
||||
|
||||
const [latestMessage, session] = await this.appendSessionMessage(
|
||||
@@ -251,8 +266,8 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
retry
|
||||
);
|
||||
|
||||
info.model = session.model;
|
||||
metrics.ai.counter('chat_calls').add(1, { model: session.model });
|
||||
info.model = model;
|
||||
metrics.ai.counter('chat_calls').add(1, { model });
|
||||
|
||||
if (latestMessage) {
|
||||
params = Object.assign({}, params, latestMessage.params, {
|
||||
@@ -264,7 +279,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
const finalMessage = session.finish(params);
|
||||
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
|
||||
|
||||
const content = await provider.generateText(finalMessage, session.model, {
|
||||
const content = await provider.generateText(finalMessage, model, {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
@@ -302,13 +317,14 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
const info: any = { sessionId, params, throwInStream: false };
|
||||
|
||||
try {
|
||||
const { messageId, retry, reasoning, webSearch } =
|
||||
const { messageId, retry, reasoning, webSearch, modelId } =
|
||||
this.prepareParams(params);
|
||||
|
||||
const provider = await this.chooseTextProvider(
|
||||
const { provider, model } = await this.chooseTextProvider(
|
||||
user.id,
|
||||
sessionId,
|
||||
messageId
|
||||
messageId,
|
||||
modelId
|
||||
);
|
||||
|
||||
const [latestMessage, session] = await this.appendSessionMessage(
|
||||
@@ -317,8 +333,8 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
retry
|
||||
);
|
||||
|
||||
info.model = session.model;
|
||||
metrics.ai.counter('chat_stream_calls').add(1, { model: session.model });
|
||||
info.model = model;
|
||||
metrics.ai.counter('chat_stream_calls').add(1, { model });
|
||||
|
||||
if (latestMessage) {
|
||||
params = Object.assign({}, params, latestMessage.params, {
|
||||
@@ -332,7 +348,7 @@ export class CopilotController implements BeforeApplicationShutdown {
|
||||
info.finalMessage = finalMessage.filter(m => m.role !== 'system');
|
||||
|
||||
const source$ = from(
|
||||
provider.generateTextStream(finalMessage, session.model, {
|
||||
provider.generateTextStream(finalMessage, model, {
|
||||
...session.config.promptConfig,
|
||||
signal: this.getSignal(req),
|
||||
user: user.id,
|
||||
|
||||
@@ -41,6 +41,7 @@ export class ChatPrompt {
|
||||
options.name,
|
||||
options.action || undefined,
|
||||
options.model,
|
||||
options.optionalModels,
|
||||
options.config,
|
||||
options.messages
|
||||
);
|
||||
@@ -50,6 +51,7 @@ export class ChatPrompt {
|
||||
public readonly name: string,
|
||||
public readonly action: string | undefined,
|
||||
public readonly model: string,
|
||||
public readonly optionalModels: string[],
|
||||
public readonly config: PromptConfig | undefined,
|
||||
private readonly messages: PromptMessage[]
|
||||
) {
|
||||
|
||||
@@ -5,8 +5,15 @@ import { PromptConfig, PromptMessage } from '../providers';
|
||||
|
||||
type Prompt = Omit<
|
||||
AiPrompt,
|
||||
'id' | 'createdAt' | 'updatedAt' | 'modified' | 'action' | 'config'
|
||||
| 'id'
|
||||
| 'createdAt'
|
||||
| 'updatedAt'
|
||||
| 'modified'
|
||||
| 'action'
|
||||
| 'config'
|
||||
| 'optionalModels'
|
||||
> & {
|
||||
optionalModels?: string[];
|
||||
action?: string;
|
||||
messages: PromptMessage[];
|
||||
config?: PromptConfig;
|
||||
@@ -1037,7 +1044,13 @@ Finally, please only send us the content of your continuation in Markdown Format
|
||||
const chat: Prompt[] = [
|
||||
{
|
||||
name: 'Chat With AFFiNE AI',
|
||||
model: 'o4-mini',
|
||||
model: 'gpt-4.1',
|
||||
optionalModels: [
|
||||
'o3',
|
||||
'o4-mini',
|
||||
'claude-3-7-sonnet-20250219',
|
||||
'claude-3-5-sonnet-20241022',
|
||||
],
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
@@ -1161,14 +1174,15 @@ export async function refreshPrompts(db: PrismaClient) {
|
||||
create: {
|
||||
name: prompt.name,
|
||||
action: prompt.action,
|
||||
config: prompt.config || undefined,
|
||||
config: prompt.config ?? undefined,
|
||||
model: prompt.model,
|
||||
optionalModels: prompt.optionalModels,
|
||||
messages: {
|
||||
create: prompt.messages.map((message, idx) => ({
|
||||
idx,
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
params: message.params || undefined,
|
||||
params: message.params ?? undefined,
|
||||
})),
|
||||
},
|
||||
},
|
||||
@@ -1177,6 +1191,7 @@ export async function refreshPrompts(db: PrismaClient) {
|
||||
action: prompt.action,
|
||||
config: prompt.config ?? undefined,
|
||||
model: prompt.model,
|
||||
optionalModels: prompt.optionalModels,
|
||||
updatedAt: new Date(),
|
||||
messages: {
|
||||
deleteMany: {},
|
||||
@@ -1184,7 +1199,7 @@ export async function refreshPrompts(db: PrismaClient) {
|
||||
idx,
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
params: message.params || undefined,
|
||||
params: message.params ?? undefined,
|
||||
})),
|
||||
},
|
||||
},
|
||||
|
||||
@@ -64,6 +64,7 @@ export class PromptService implements OnApplicationBootstrap {
|
||||
name: true,
|
||||
action: true,
|
||||
model: true,
|
||||
optionalModels: true,
|
||||
config: true,
|
||||
messages: {
|
||||
select: {
|
||||
|
||||
@@ -34,7 +34,10 @@ export class AnthropicProvider
|
||||
{
|
||||
override readonly type = CopilotProviderType.Anthropic;
|
||||
override readonly capabilities = [CopilotCapability.TextToText];
|
||||
override readonly models = ['claude-3-7-sonnet-20250219'];
|
||||
override readonly models = [
|
||||
'claude-3-7-sonnet-20250219',
|
||||
'claude-3-5-sonnet-20241022',
|
||||
];
|
||||
|
||||
private readonly MAX_STEPS = 20;
|
||||
|
||||
|
||||
@@ -74,6 +74,7 @@ export class OpenAIProvider
|
||||
'gpt-4.1-2025-04-14',
|
||||
'gpt-4.1-mini',
|
||||
'o1',
|
||||
'o3',
|
||||
'o4-mini',
|
||||
// embeddings
|
||||
'text-embedding-3-large',
|
||||
|
||||
@@ -110,12 +110,12 @@ export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;
|
||||
export interface CopilotTextToTextProvider extends CopilotProvider {
|
||||
generateText(
|
||||
messages: PromptMessage[],
|
||||
model?: string,
|
||||
model: string,
|
||||
options?: CopilotChatOptions
|
||||
): Promise<string>;
|
||||
generateTextStream(
|
||||
messages: PromptMessage[],
|
||||
model?: string,
|
||||
model: string,
|
||||
options?: CopilotChatOptions
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
@@ -136,7 +136,7 @@ export interface CopilotTextToImageProvider extends CopilotProvider {
|
||||
): Promise<Array<string>>;
|
||||
generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model?: string,
|
||||
model: string,
|
||||
options?: CopilotImageOptions
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
@@ -145,12 +145,12 @@ export interface CopilotImageToTextProvider extends CopilotProvider {
|
||||
generateText(
|
||||
messages: PromptMessage[],
|
||||
model: string,
|
||||
options?: CopilotChatOptions
|
||||
options: CopilotChatOptions
|
||||
): Promise<string>;
|
||||
generateTextStream(
|
||||
messages: PromptMessage[],
|
||||
model: string,
|
||||
options?: CopilotChatOptions
|
||||
options: CopilotChatOptions
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ export interface CopilotImageToImageProvider extends CopilotProvider {
|
||||
): Promise<Array<string>>;
|
||||
generateImagesStream(
|
||||
messages: PromptMessage[],
|
||||
model?: string,
|
||||
model: string,
|
||||
options?: CopilotImageOptions
|
||||
): AsyncIterable<string>;
|
||||
}
|
||||
|
||||
@@ -45,6 +45,10 @@ export class ChatSession implements AsyncDisposable {
|
||||
return this.state.prompt.model;
|
||||
}
|
||||
|
||||
get optionalModels() {
|
||||
return this.state.prompt.optionalModels;
|
||||
}
|
||||
|
||||
get config() {
|
||||
const {
|
||||
sessionId,
|
||||
|
||||
Reference in New Issue
Block a user