fix: choose provider correctly (#7081)

fix no provider error in caption generate action
This commit is contained in:
darkskygit
2024-05-27 09:57:39 +00:00
parent 50dcce891b
commit 5ba9e2e9b1
4 changed files with 57 additions and 23 deletions

View File

@@ -34,7 +34,11 @@ import { Config } from '../../fundamentals';
import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotCapability } from './types';
import {
CopilotCapability,
CopilotImageToTextProvider,
CopilotTextToTextProvider,
} from './types';
export interface ChatEvent {
type: 'attachment' | 'message' | 'error';
@@ -71,7 +75,7 @@ export class CopilotController {
const ret: CheckResult = { model: session.model };
if (messageId) {
if (messageId && typeof messageId === 'string') {
const message = await session.getMessageById(messageId);
ret.hasAttachment =
Array.isArray(message.attachments) && !!message.attachments.length;
@@ -80,6 +84,34 @@ export class CopilotController {
return ret;
}
private async chooseTextProvider(
userId: string,
sessionId: string,
messageId?: string
): Promise<CopilotTextToTextProvider | CopilotImageToTextProvider> {
const { hasAttachment, model } = await this.checkRequest(
userId,
sessionId,
messageId
);
let provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
// fallback to image to text if text to text is not available
if (!provider && hasAttachment) {
provider = await this.provider.getProviderByCapability(
CopilotCapability.ImageToText,
model
);
}
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
return provider;
}
private async appendSessionMessage(
sessionId: string,
messageId?: string
@@ -139,18 +171,15 @@ export class CopilotController {
@Param('sessionId') sessionId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const provider = await this.chooseTextProvider(
user.id,
sessionId,
messageId
);
const session = await this.appendSessionMessage(sessionId, messageId);
try {
@@ -187,18 +216,15 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const provider = await this.chooseTextProvider(
user.id,
sessionId,
messageId
);
const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;