diff --git a/apps/sim/app/api/copilot/chat/stop/route.test.ts b/apps/sim/app/api/copilot/chat/stop/route.test.ts index 0ac05257bf6..a87f35a2987 100644 --- a/apps/sim/app/api/copilot/chat/stop/route.test.ts +++ b/apps/sim/app/api/copilot/chat/stop/route.test.ts @@ -29,6 +29,16 @@ const { mockSql: vi.fn((strings: TemplateStringsArray, ...values: unknown[]) => ({ strings, values })), })) +vi.mock('@sim/db/schema', () => ({ + copilotChats: { + id: 'copilotChats.id', + userId: 'copilotChats.userId', + workspaceId: 'copilotChats.workspaceId', + messages: 'copilotChats.messages', + conversationId: 'copilotChats.conversationId', + }, +})) + vi.mock('@sim/db', () => ({ db: { select: mockSelect, @@ -140,6 +150,7 @@ describe('copilot chat stop route', () => { workspaceId: 'ws-1', chatId: 'chat-1', type: 'completed', + streamId: 'stream-1', }) }) }) diff --git a/apps/sim/app/api/copilot/chat/stop/route.ts b/apps/sim/app/api/copilot/chat/stop/route.ts index ad7da7386d1..36d3b8ae43d 100644 --- a/apps/sim/app/api/copilot/chat/stop/route.ts +++ b/apps/sim/app/api/copilot/chat/stop/route.ts @@ -111,6 +111,7 @@ export const POST = withRouteHandler((req: NextRequest) => workspaceId: updated.workspaceId, chatId, type: 'completed', + streamId, }) } diff --git a/apps/sim/app/api/copilot/chat/stream/route.ts b/apps/sim/app/api/copilot/chat/stream/route.ts index 04ba6109a56..7cc61cb6447 100644 --- a/apps/sim/app/api/copilot/chat/stream/route.ts +++ b/apps/sim/app/api/copilot/chat/stream/route.ts @@ -248,6 +248,7 @@ async function handleResumeRequestBody({ events: batchEvents, previewSessions, status: run.status, + ...(run.chatId ? { chatId: run.chatId } : {}), }) } diff --git a/apps/sim/app/api/mothership/events/route.ts b/apps/sim/app/api/mothership/events/route.ts index 420f2be3fdb..bb3e1f278c8 100644 --- a/apps/sim/app/api/mothership/events/route.ts +++ b/apps/sim/app/api/mothership/events/route.ts @@ -27,6 +27,7 @@ const mothershipEventsHandler = createWorkspaceSSE({ send('task_status', { chatId: event.chatId, type: event.type, + ...(event.streamId ? { streamId: event.streamId } : {}), timestamp: Date.now(), }) }) diff --git a/apps/sim/app/workspace/[workspaceId]/home/hooks/use-chat.ts b/apps/sim/app/workspace/[workspaceId]/home/hooks/use-chat.ts index d20ce7ac805..6d2c7527a6b 100644 --- a/apps/sim/app/workspace/[workspaceId]/home/hooks/use-chat.ts +++ b/apps/sim/app/workspace/[workspaceId]/home/hooks/use-chat.ts @@ -188,8 +188,16 @@ const RECONNECT_TAIL_ERROR = const MAX_RECONNECT_ATTEMPTS = 10 const RECONNECT_BASE_DELAY_MS = 1000 const RECONNECT_MAX_DELAY_MS = 30_000 +const STREAM_BATCH_FETCH_TIMEOUT_MS = 10_000 +const STREAM_CHAT_ID_RESOLVE_TIMEOUT_MS = 10_000 +const CHAT_HISTORY_RECOVERY_TIMEOUT_MS = 10_000 +const STOP_REQUEST_TIMEOUT_MS = 15_000 const QUEUED_SEND_HANDOFF_STORAGE_KEY = `${STREAM_STORAGE_KEY}:queued-send-handoff` const QUEUED_SEND_HANDOFF_CLAIM_STORAGE_KEY = `${STREAM_STORAGE_KEY}:queued-send-handoff-claim` +const QUEUED_SEND_HANDOFF_TTL_MS = 5 * 60 * 1000 +const QUEUED_SEND_HANDOFF_CLAIM_TTL_MS = 30_000 +const QUEUED_SEND_HANDOFF_RETRY_BASE_MS = 1000 +const QUEUED_SEND_HANDOFF_RETRY_MAX_MS = 30_000 const logger = createLogger('useChat') @@ -208,7 +216,7 @@ type ActiveTurn = { interface QueuedSendHandoffState { id: string - chatId: string + chatId?: string workspaceId: string supersededStreamId: string | null userMessageId: string @@ -216,15 +224,152 @@ interface QueuedSendHandoffState { fileAttachments?: FileAttachmentForApi[] contexts?: ChatContext[] requestedAt: number + resolveAttempts?: number } interface QueuedSendHandoffSeed { id: string - chatId: string + chatId?: string supersededStreamId: string | null userMessageId?: string } +type QueuedChatMessage = QueuedMessage & { + queuedSendHandoff?: QueuedSendHandoffSeed +} + +interface QueuedSendHandoffClaim { + id: string + ownerId: string + claimedAt: number +} + +interface ActiveQueuedSendHandoffRecovery { + id: string + ownerId: string +} + +function createTimeoutSignal(ms: number): AbortSignal | undefined { + if (typeof AbortSignal !== 'undefined' && typeof AbortSignal.timeout === 'function') { + return AbortSignal.timeout(ms) + } + if (typeof AbortController === 'undefined') return undefined + + const controller = new AbortController() + const timeout = setTimeout(() => { + controller.abort(new Error(`Operation timed out after ${ms}ms`)) + }, ms) + controller.signal.addEventListener('abort', () => clearTimeout(timeout), { once: true }) + return controller.signal +} + +function combineAbortSignals(...signals: (AbortSignal | undefined)[]): AbortSignal | undefined { + const activeSignals = signals.filter((signal): signal is AbortSignal => Boolean(signal)) + if (activeSignals.length === 0) return undefined + if (activeSignals.length === 1) return activeSignals[0] + if (typeof AbortSignal !== 'undefined' && typeof AbortSignal.any === 'function') { + return AbortSignal.any(activeSignals) + } + if (typeof AbortController === 'undefined') return activeSignals[0] + + const controller = new AbortController() + const abortFromSource = (source: AbortSignal) => { + cleanup() + controller.abort(source.reason) + } + const listeners = activeSignals.map((signal) => { + const listener = () => abortFromSource(signal) + signal.addEventListener('abort', listener, { once: true }) + return { signal, listener } + }) + function cleanup() { + for (const { signal, listener } of listeners) { + signal.removeEventListener('abort', listener) + } + } + for (const signal of activeSignals) { + if (signal.aborted) { + abortFromSource(signal) + break + } + } + controller.signal.addEventListener('abort', cleanup, { once: true }) + return controller.signal +} + +function createAbortError(signal: AbortSignal): Error { + const error = new Error(signal.reason ? String(signal.reason) : 'Operation aborted') + error.name = 'AbortError' + return error +} + +async function sleepWithAbort(ms: number, signal?: AbortSignal) { + if (!signal) { + await sleep(ms) + return + } + if (signal.aborted) throw createAbortError(signal) + + let cleanup: (() => void) | undefined + await Promise.race([ + sleep(ms), + new Promise((_, reject) => { + const onAbort = () => reject(createAbortError(signal)) + cleanup = () => signal.removeEventListener('abort', onAbort) + signal.addEventListener('abort', onAbort, { once: true }) + }), + ]).finally(() => cleanup?.()) +} + +function isFileAttachmentForApi(value: unknown): value is FileAttachmentForApi { + if (!isRecord(value)) return false + return ( + typeof value.id === 'string' && + typeof value.key === 'string' && + typeof value.filename === 'string' && + typeof value.media_type === 'string' && + typeof value.size === 'number' && + Number.isFinite(value.size) && + (value.path === undefined || typeof value.path === 'string') + ) +} + +function isChatContext(value: unknown): value is ChatContext { + if (!isRecord(value) || typeof value.kind !== 'string' || typeof value.label !== 'string') { + return false + } + + switch (value.kind) { + case 'past_chat': + return typeof value.chatId === 'string' + case 'workflow': + case 'current_workflow': + return typeof value.workflowId === 'string' + case 'blocks': + return Array.isArray(value.blockIds) && value.blockIds.every((id) => typeof id === 'string') + case 'logs': + return value.executionId === undefined || typeof value.executionId === 'string' + case 'workflow_block': + return typeof value.workflowId === 'string' && typeof value.blockId === 'string' + case 'knowledge': + return value.knowledgeId === undefined || typeof value.knowledgeId === 'string' + case 'table': + return typeof value.tableId === 'string' + case 'file': + return typeof value.fileId === 'string' + case 'folder': + return typeof value.folderId === 'string' + case 'templates': + return value.templateId === undefined || typeof value.templateId === 'string' + case 'docs': + return true + case 'slash_command': + return typeof value.command === 'string' + default: + return false + } +} + function readQueuedSendHandoffState(): QueuedSendHandoffState | null { if (typeof window === 'undefined') return null @@ -233,30 +378,46 @@ function readQueuedSendHandoffState(): QueuedSendHandoffState | null { if (!raw) return null const parsed = JSON.parse(raw) as Partial + const chatId = typeof parsed.chatId === 'string' ? parsed.chatId : undefined + const supersededStreamId = + typeof parsed.supersededStreamId === 'string' ? parsed.supersededStreamId : null if ( typeof parsed?.id !== 'string' || - typeof parsed.chatId !== 'string' || typeof parsed.workspaceId !== 'string' || typeof parsed.userMessageId !== 'string' || typeof parsed.message !== 'string' || - typeof parsed.requestedAt !== 'number' + typeof parsed.requestedAt !== 'number' || + (!chatId && !supersededStreamId) ) { return null } + if (Date.now() - parsed.requestedAt > QUEUED_SEND_HANDOFF_TTL_MS) { + window.sessionStorage.removeItem(QUEUED_SEND_HANDOFF_STORAGE_KEY) + if (readQueuedSendHandoffClaim() === parsed.id) { + window.sessionStorage.removeItem(QUEUED_SEND_HANDOFF_CLAIM_STORAGE_KEY) + } + return null + } return { id: parsed.id, - chatId: parsed.chatId, + ...(chatId ? { chatId } : {}), workspaceId: parsed.workspaceId, - supersededStreamId: - typeof parsed.supersededStreamId === 'string' ? parsed.supersededStreamId : null, + supersededStreamId, userMessageId: parsed.userMessageId, message: parsed.message, ...(Array.isArray(parsed.fileAttachments) - ? { fileAttachments: parsed.fileAttachments as FileAttachmentForApi[] } + ? { fileAttachments: parsed.fileAttachments.filter(isFileAttachmentForApi) } + : {}), + ...(Array.isArray(parsed.contexts) + ? { contexts: parsed.contexts.filter(isChatContext) } : {}), - ...(Array.isArray(parsed.contexts) ? { contexts: parsed.contexts as ChatContext[] } : {}), requestedAt: parsed.requestedAt, + ...(typeof parsed.resolveAttempts === 'number' && + Number.isFinite(parsed.resolveAttempts) && + parsed.resolveAttempts > 0 + ? { resolveAttempts: parsed.resolveAttempts } + : {}), } } catch { return null @@ -279,21 +440,73 @@ function clearQueuedSendHandoffState(expectedId?: string) { window.sessionStorage.removeItem(QUEUED_SEND_HANDOFF_STORAGE_KEY) } -function readQueuedSendHandoffClaim(): string | null { +function readQueuedSendHandoffClaimState(): QueuedSendHandoffClaim | null { if (typeof window === 'undefined') return null - return window.sessionStorage.getItem(QUEUED_SEND_HANDOFF_CLAIM_STORAGE_KEY) + const raw = window.sessionStorage.getItem(QUEUED_SEND_HANDOFF_CLAIM_STORAGE_KEY) + if (!raw) return null + + try { + const parsed = JSON.parse(raw) as Partial + if ( + typeof parsed?.id !== 'string' || + typeof parsed.ownerId !== 'string' || + typeof parsed.claimedAt !== 'number' + ) { + window.sessionStorage.removeItem(QUEUED_SEND_HANDOFF_CLAIM_STORAGE_KEY) + return null + } + if (Date.now() - parsed.claimedAt > QUEUED_SEND_HANDOFF_CLAIM_TTL_MS) { + window.sessionStorage.removeItem(QUEUED_SEND_HANDOFF_CLAIM_STORAGE_KEY) + return null + } + return { id: parsed.id, ownerId: parsed.ownerId, claimedAt: parsed.claimedAt } + } catch { + window.sessionStorage.removeItem(QUEUED_SEND_HANDOFF_CLAIM_STORAGE_KEY) + return null + } } -function writeQueuedSendHandoffClaim(id: string) { - if (typeof window === 'undefined') return - window.sessionStorage.setItem(QUEUED_SEND_HANDOFF_CLAIM_STORAGE_KEY, id) +function readQueuedSendHandoffClaim(): string | null { + return readQueuedSendHandoffClaimState()?.id ?? null +} + +function hasQueuedSendHandoffClaimOwner(id: string, ownerId: string): boolean { + const claim = readQueuedSendHandoffClaimState() + return claim?.id === id && claim.ownerId === ownerId +} + +function queuedSendHandoffClaimRetryDelay(id: string): number | null { + const claim = readQueuedSendHandoffClaimState() + if (!claim || claim.id !== id) return null + const elapsed = Date.now() - claim.claimedAt + return Math.max(0, QUEUED_SEND_HANDOFF_CLAIM_TTL_MS - elapsed + 1) +} + +function queuedSendHandoffResolveRetryDelay(resolveAttempts: number): number { + return Math.min( + QUEUED_SEND_HANDOFF_RETRY_MAX_MS, + QUEUED_SEND_HANDOFF_RETRY_BASE_MS * 2 ** Math.max(0, resolveAttempts - 1) + ) +} + +function writeQueuedSendHandoffClaim(id: string): string { + const ownerId = generateId() + if (typeof window === 'undefined') return ownerId + window.sessionStorage.setItem( + QUEUED_SEND_HANDOFF_CLAIM_STORAGE_KEY, + JSON.stringify({ id, ownerId, claimedAt: Date.now() } satisfies QueuedSendHandoffClaim) + ) + return ownerId } -function clearQueuedSendHandoffClaim(expectedId?: string) { +function clearQueuedSendHandoffClaim(expectedId?: string, expectedOwnerId?: string) { if (typeof window === 'undefined') return if (expectedId) { - const current = readQueuedSendHandoffClaim() - if (current && current !== expectedId) { + const current = readQueuedSendHandoffClaimState() + if ( + current && + (current.id !== expectedId || (expectedOwnerId && current.ownerId !== expectedOwnerId)) + ) { return } } @@ -619,6 +832,7 @@ type StreamBatchResponse = { events: StreamBatchEvent[] previewSessions?: FilePreviewSession[] status: string + chatId?: string } function isRecord(value: unknown): value is Record { @@ -703,7 +917,25 @@ function parseStreamBatchResponse(value: unknown): StreamBatchResponse { events, ...(previewSessions ? { previewSessions } : {}), status: typeof value.status === 'string' ? value.status : 'unknown', + ...(typeof value.chatId === 'string' && value.chatId ? { chatId: value.chatId } : {}), + } +} + +function resolveChatIdFromStreamBatch(batch: StreamBatchResponse): string | undefined { + if (batch.chatId) return batch.chatId + + for (const { event } of batch.events) { + const streamChatId = typeof event.stream?.chatId === 'string' ? event.stream.chatId : undefined + if (streamChatId) return streamChatId + if ( + event.type === MothershipStreamV1EventType.session && + event.payload.kind === MothershipStreamV1SessionKind.chat + ) { + return event.payload.chatId + } } + + return undefined } function toRawPersistedContentBlock(block: ContentBlock): Record | null { @@ -890,6 +1122,28 @@ function isTerminalStreamStatus(status: string | null | undefined): boolean { return TERMINAL_STREAM_STATUSES.has(status ?? '') } +function isAlreadyProcessedStreamCursor( + eventCursor: string | undefined, + currentCursor: string +): boolean { + if (!eventCursor) return false + + const eventSequence = Number(eventCursor) + const currentSequence = Number(currentCursor) + return ( + Number.isFinite(eventSequence) && + Number.isFinite(currentSequence) && + eventSequence <= currentSequence + ) +} + +function buildRecoverySubjectKey( + chatId: string | undefined, + selectedChatId: string | undefined +): string { + return `${chatId ?? ''}:${selectedChatId ?? ''}` +} + const sseEncoder = new TextEncoder() function buildReplayStream(events: StreamBatchEvent[]): ReadableStream { return new ReadableStream({ @@ -1026,6 +1280,18 @@ export interface UseChatOptions { initialActiveResourceId?: string | null } +interface ActiveStreamRecovery { + subjectKey: string + controller: AbortController + promise: Promise +} + +type StopGenerationMode = 'normal' | 'queued-handoff' + +interface StopGenerationOptions { + mode?: StopGenerationMode +} + export function getMothershipUseChatOptions( options: Pick = {} ): UseChatOptions { @@ -1062,6 +1328,7 @@ export function useChat( const [isReconnecting, setIsReconnecting] = useState(false) const [error, setError] = useState(null) const [resolvedChatId, setResolvedChatId] = useState(initialChatId) + const [queuedHandoffRecoveryEpoch, setQueuedHandoffRecoveryEpoch] = useState(0) const [resources, setResources] = useState([]) const [activeResourceId, setActiveResourceId] = useState( options?.initialActiveResourceId ?? null @@ -1075,6 +1342,7 @@ export function useChat( const stopPathRef = useRef(options?.stopPath ?? '/api/mothership/chat/stop') stopPathRef.current = options?.stopPath ?? '/api/mothership/chat/stop' const pendingStopPromiseRef = useRef | null>(null) + const pendingStopModeRef = useRef(null) const workflowIdRef = useRef(options?.workflowId) workflowIdRef.current = options?.workflowId const onToolResultRef = useRef(options?.onToolResult) @@ -1208,8 +1476,8 @@ export function useChat( [removePreviewSession, syncPreviewSessionRefs] ) - const [messageQueue, setMessageQueue] = useState([]) - const messageQueueRef = useRef([]) + const [messageQueue, setMessageQueue] = useState([]) + const messageQueueRef = useRef([]) messageQueueRef.current = messageQueue const queuedMessageDispatchIdsRef = useRef>(new Set()) const queueDispatchActionsRef = useRef([]) @@ -1225,7 +1493,12 @@ export function useChat( reader: ReadableStreamDefaultReader, assistantId: string, expectedGen?: number, - options?: { preserveExistingState?: boolean; suppressWorkflowToolStarts?: boolean } + options?: { + preserveExistingState?: boolean + suppressWorkflowToolStarts?: boolean + targetChatId?: string + shouldContinue?: () => boolean + } ) => Promise<{ sawStreamError: boolean; sawComplete: boolean }> >(async () => ({ sawStreamError: false, sawComplete: false })) const attachToExistingStreamRef = useRef< @@ -1235,13 +1508,23 @@ export function useChat( expectedGen: number initialBatch?: StreamBatchResponse | null afterCursor?: string + targetChatId?: string + shouldContinue?: () => boolean }) => Promise<{ error: boolean; aborted: boolean }> >(async () => ({ error: false, aborted: true })) const retryReconnectRef = useRef< - (opts: { streamId: string; assistantId: string; gen: number }) => Promise + (opts: { + streamId: string + assistantId: string + gen: number + targetChatId?: string + shouldContinue?: () => boolean + }) => Promise >(async () => false) - const finalizeRef = useRef<(options?: { error?: boolean }) => void>(() => {}) - const recoveringQueuedSendHandoffIdRef = useRef(null) + const finalizeRef = useRef<(options?: { error?: boolean; targetChatId?: string }) => void>( + () => {} + ) + const recoveringQueuedSendHandoffRef = useRef(null) const resetEphemeralPreviewState = useCallback( (options?: { removeStreamingResource?: boolean }) => { @@ -1327,6 +1610,7 @@ export function useChat( const streamRequestIdRef = useRef(undefined) const locallyTerminalStreamIdRef = useRef(undefined) const lastCursorRef = useRef('0') + const activeStreamReturnRecoveryRef = useRef(null) const sendingRef = useRef(false) const streamGenRef = useRef(0) const streamingContentRef = useRef('') @@ -1354,6 +1638,23 @@ export function useChat( setIsReconnecting(true) }, []) + const cancelActiveStreamRecovery = useCallback(() => { + const recovery = activeStreamReturnRecoveryRef.current + if (!recovery) return + recovery.controller.abort('superseded_recovery') + activeStreamReturnRecoveryRef.current = null + }, []) + + const cancelActiveStreamReader = useCallback(() => { + const reader = streamReaderRef.current + streamReaderRef.current = null + void reader?.cancel().catch((error) => { + logger.warn('Failed to cancel detached stream reader', { + error: toError(error).message, + }) + }) + }, []) + const resetStreamingBuffers = useCallback(() => { streamingContentRef.current = '' streamingBlocksRef.current = [] @@ -1371,7 +1672,9 @@ export function useChat( }, [resetStreamingBuffers]) const resetHomeChatState = useCallback(() => { + cancelActiveStreamRecovery() streamGenRef.current++ + cancelActiveStreamReader() chatIdRef.current = undefined lastCursorRef.current = '0' locallyTerminalStreamIdRef.current = undefined @@ -1387,7 +1690,36 @@ export function useChat( resetEphemeralPreviewState() setMessageQueue([]) clearQueueDispatchState() - }, [clearActiveTurn, clearQueueDispatchState, resetEphemeralPreviewState, setTransportIdle]) + }, [ + cancelActiveStreamRecovery, + cancelActiveStreamReader, + clearActiveTurn, + clearQueueDispatchState, + resetEphemeralPreviewState, + setTransportIdle, + ]) + + const adoptResolvedChatId = useCallback( + (chatId: string, options?: { replaceHomeHistory?: boolean; invalidateList?: boolean }) => { + const selectedChatId = selectedChatIdRef.current + chatIdRef.current = chatId + if (!selectedChatId || selectedChatId === chatId) { + setResolvedChatId(chatId) + } + if ( + options?.replaceHomeHistory && + !selectedChatId && + !workflowIdRef.current && + typeof window !== 'undefined' + ) { + window.history.replaceState(null, '', `/workspace/${workspaceId}/task/${chatId}`) + } + if (options?.invalidateList) { + queryClient.invalidateQueries({ queryKey: taskKeys.list(workspaceId) }) + } + }, + [queryClient, workspaceId] + ) const { data: chatHistory } = useChatHistory(resolvedChatId) const messages = useMemo(() => { @@ -1545,7 +1877,9 @@ export function useChat( const abandonedChatId = streamOwnerId // Detach the current UI from the old stream without cancelling it on the server. // Reopening that chat later will reconnect through the existing chatHistory flow. + cancelActiveStreamRecovery() streamGenRef.current++ + cancelActiveStreamReader() abortControllerRef.current = null clearActiveTurn() setTransportIdle() @@ -1557,6 +1891,8 @@ export function useChat( return } } + cancelActiveStreamRecovery() + cancelActiveStreamReader() chatIdRef.current = initialChatId lastCursorRef.current = '0' locallyTerminalStreamIdRef.current = undefined @@ -1578,6 +1914,8 @@ export function useChat( clearQueueDispatchState, clearActiveTurn, setTransportIdle, + cancelActiveStreamRecovery, + cancelActiveStreamReader, ]) useEffect(() => { @@ -1655,9 +1993,12 @@ export function useChat( if (shouldReconnectActiveStream && activeStreamId) { const gen = ++streamGenRef.current const abortController = new AbortController() + const previousStreamId = streamIdRef.current ?? activeTurnRef.current?.userMessageId + const reconnectAfterCursor = + previousStreamId === activeStreamId ? lastCursorRef.current || '0' : '0' abortControllerRef.current = abortController streamIdRef.current = activeStreamId - lastCursorRef.current = '0' + lastCursorRef.current = reconnectAfterCursor setTransportReconnecting() const assistantId = getLiveAssistantMessageId(activeStreamId) @@ -1668,21 +2009,34 @@ export function useChat( ? (initialSnapshot.events as StreamBatchEvent[]) : [] - const reconnectResult = - snapshotEvents.length > 0 - ? await attachToExistingStreamRef.current({ - streamId: activeStreamId, - assistantId, - expectedGen: gen, - initialBatch: { - success: true, - events: snapshotEvents, - previewSessions: snapshotPreviewSessions, - status: initialSnapshot?.status ?? 'unknown', - }, - afterCursor: String(snapshotEvents[snapshotEvents.length - 1]?.eventId ?? '0'), - }) - : null + let reconnectResult: Awaited> | null = + null + const replaySnapshotEvents = snapshotEvents.filter( + (entry) => !isAlreadyProcessedStreamCursor(String(entry.eventId), reconnectAfterCursor) + ) + if (replaySnapshotEvents.length > 0) { + try { + reconnectResult = await attachToExistingStreamRef.current({ + streamId: activeStreamId, + assistantId, + expectedGen: gen, + initialBatch: { + success: true, + events: replaySnapshotEvents, + previewSessions: snapshotPreviewSessions, + status: initialSnapshot?.status ?? 'unknown', + }, + afterCursor: reconnectAfterCursor, + targetChatId: chatHistory.id, + }) + } catch (error) { + logger.warn('Snapshot stream reconnect failed; falling back to retry', { + chatId: chatHistory.id, + streamId: activeStreamId, + error: toError(error).message, + }) + } + } const succeeded = reconnectResult !== null @@ -1691,9 +2045,10 @@ export function useChat( streamId: activeStreamId, assistantId, gen, + targetChatId: chatHistory.id, }) if (succeeded && streamGenRef.current === gen && sendingRef.current) { - finalizeRef.current() + finalizeRef.current({ targetChatId: chatHistory.id }) return } if (succeeded && streamGenRef.current === gen) { @@ -1703,7 +2058,7 @@ export function useChat( } if (!succeeded && streamGenRef.current === gen) { try { - finalizeRef.current({ error: true }) + finalizeRef.current({ error: true, targetChatId: chatHistory.id }) } catch { setTransportIdle() abortControllerRef.current = null @@ -1728,9 +2083,21 @@ export function useChat( reader: ReadableStreamDefaultReader, assistantId: string, expectedGen?: number, - options?: { preserveExistingState?: boolean; suppressWorkflowToolStarts?: boolean } + options?: { + preserveExistingState?: boolean + suppressWorkflowToolStarts?: boolean + targetChatId?: string + shouldContinue?: () => boolean + } ) => { const decoder = new TextDecoder() + const isReaderStale = () => + (expectedGen !== undefined && streamGenRef.current !== expectedGen) || + options?.shouldContinue?.() === false + if (isReaderStale()) { + void reader.cancel().catch(() => {}) + return { sawStreamError: false, sawComplete: false } + } streamReaderRef.current = reader let buffer = '' @@ -1886,7 +2253,7 @@ export function useChat( })}` } - const isStale = () => expectedGen !== undefined && streamGenRef.current !== expectedGen + const isStale = isReaderStale let sawStreamError = false let sawCompleteEvent = false let scheduledTextFlushFrame: number | null = null @@ -1899,7 +2266,7 @@ export function useChat( [assistantId, streamRequestId], runningText ) - const activeChatId = chatIdRef.current + const activeChatId = options?.targetChatId ?? chatIdRef.current if (!activeChatId) { const snapshot: Partial = { content: runningText, @@ -2001,10 +2368,14 @@ export function useChat( if (parsed.stream?.streamId) { streamIdRef.current = parsed.stream.streamId } - if (parsed.stream?.cursor) { - lastCursorRef.current = parsed.stream.cursor - } else if (typeof parsed.seq === 'number') { - lastCursorRef.current = String(parsed.seq) + const eventCursor = + parsed.stream?.cursor ?? + (typeof parsed.seq === 'number' ? String(parsed.seq) : undefined) + if (isAlreadyProcessedStreamCursor(eventCursor, lastCursorRef.current)) { + continue + } + if (eventCursor) { + lastCursorRef.current = eventCursor } logger.debug('SSE event received', parsed) @@ -2729,14 +3100,19 @@ export function useChat( processSSEStreamRef.current = processSSEStream const getActiveStreamIdForChat = useCallback( - async (chatId: string): Promise => { + async (chatId: string, signal?: AbortSignal): Promise => { const cached = queryClient.getQueryData(taskKeys.detail(chatId)) if (cached?.activeStreamId) { return cached.activeStreamId } try { - const history = await fetchChatHistory(chatId) + const fetchSignal = combineAbortSignals( + signal, + createTimeoutSignal(CHAT_HISTORY_RECOVERY_TIMEOUT_MS) + ) + const history = await fetchChatHistory(chatId, fetchSignal) + if (signal?.aborted || fetchSignal?.aborted) return null queryClient.setQueryData(taskKeys.detail(chatId), history) return history.activeStreamId ?? null } catch (error) { @@ -2756,11 +3132,15 @@ export function useChat( afterCursor: string, signal?: AbortSignal ): Promise => { + const fetchSignal = combineAbortSignals( + signal, + createTimeoutSignal(STREAM_BATCH_FETCH_TIMEOUT_MS) + ) // boundary-raw-fetch: stream-resume batch endpoint requires dynamic per-request traceparent header propagation that the contract layer does not model, and the response is consumed alongside live SSE tail fetches const response = await fetch( `/api/mothership/chat/stream?streamId=${encodeURIComponent(streamId)}&after=${encodeURIComponent(afterCursor)}&batch=true`, { - signal, + signal: fetchSignal, ...(streamTraceparentRef.current ? { headers: { traceparent: streamTraceparentRef.current } } : {}), @@ -2769,11 +3149,69 @@ export function useChat( if (!response.ok) { throw new Error(`Stream resume batch failed: ${response.status}`) } - const batch = parseStreamBatchResponse(await response.json()) + return parseStreamBatchResponse(await response.json()) + }, + [] + ) + + const resolveChatIdForStream = useCallback( + async ( + streamId: string, + options?: { preferExistingChatId?: boolean; signal?: AbortSignal } + ): Promise => { + if (options?.preferExistingChatId !== false) { + const existingChatId = chatIdRef.current ?? selectedChatIdRef.current + if (existingChatId) return existingChatId + } + + const deadline = Date.now() + STREAM_CHAT_ID_RESOLVE_TIMEOUT_MS + let retryDelayMs = 250 + let lastError: unknown + + while (Date.now() < deadline) { + if (options?.signal?.aborted) throw createAbortError(options.signal) + const remainingMs = Math.max(1, deadline - Date.now()) + try { + const batch = await fetchStreamBatch( + streamId, + '0', + combineAbortSignals( + options?.signal, + createTimeoutSignal(Math.min(remainingMs, STREAM_BATCH_FETCH_TIMEOUT_MS)) + ) + ) + const chatId = resolveChatIdFromStreamBatch(batch) + if (chatId) return chatId + } catch (error) { + lastError = error + if (error instanceof Error && error.name === 'AbortError' && Date.now() >= deadline) { + break + } + } + + await sleepWithAbort( + Math.min(retryDelayMs, Math.max(1, deadline - Date.now())), + options?.signal + ) + retryDelayMs = Math.min(retryDelayMs * 2, 2000) + } + + if (lastError) { + logger.warn('Failed to resolve chat id for stream before timeout', { + streamId, + error: toError(lastError).message, + }) + } + return undefined + }, + [fetchStreamBatch] + ) + + const seedStreamBatchPreviewSessions = useCallback( + (batch: StreamBatchResponse) => { if (Array.isArray(batch.previewSessions) && batch.previewSessions.length > 0) { seedPreviewSessions(batch.previewSessions) } - return batch }, [seedPreviewSessions] ) @@ -2785,15 +3223,26 @@ export function useChat( expectedGen: number initialBatch?: StreamBatchResponse | null afterCursor?: string + targetChatId?: string + shouldContinue?: () => boolean }): Promise<{ error: boolean; aborted: boolean }> => { - const { streamId, assistantId, expectedGen, afterCursor = '0' } = opts + const { + streamId, + assistantId, + expectedGen, + afterCursor = '0', + targetChatId, + shouldContinue, + } = opts let latestCursor = afterCursor let seedEvents = opts.initialBatch?.events ?? [] let streamStatus = opts.initialBatch?.status ?? 'unknown' let suppressSeedWorkflowStarts = seedEvents.length > 0 const isStaleReconnect = () => - streamGenRef.current !== expectedGen || abortControllerRef.current?.signal.aborted === true + streamGenRef.current !== expectedGen || + abortControllerRef.current?.signal.aborted === true || + shouldContinue?.() === false if (isStaleReconnect()) { return { error: false, aborted: true } @@ -2812,8 +3261,13 @@ export function useChat( { preserveExistingState: true, suppressWorkflowToolStarts: suppressSeedWorkflowStarts, + ...(targetChatId ? { targetChatId } : {}), + ...(shouldContinue ? { shouldContinue } : {}), } ) + if (isStaleReconnect()) { + return { error: false, aborted: true } + } latestCursor = String(seedEvents[seedEvents.length - 1]?.eventId ?? latestCursor) lastCursorRef.current = latestCursor seedEvents = [] @@ -2862,7 +3316,11 @@ export function useChat( sseRes.body.getReader(), assistantId, expectedGen, - { preserveExistingState: true } + { + preserveExistingState: true, + ...(targetChatId ? { targetChatId } : {}), + ...(shouldContinue ? { shouldContinue } : {}), + } ) if (liveResult.sawStreamError) { @@ -2887,12 +3345,15 @@ export function useChat( }) const batch = await fetchStreamBatch(streamId, latestCursor, activeAbort.signal) + if (isStaleReconnect()) { + return { error: false, aborted: true } + } + seedStreamBatchPreviewSessions(batch) seedEvents = batch.events streamStatus = batch.status if (batch.events.length > 0) { latestCursor = String(batch.events[batch.events.length - 1].eventId) - lastCursorRef.current = latestCursor } if (batch.events.length === 0 && !isTerminalStreamStatus(batch.status)) { @@ -2918,7 +3379,13 @@ export function useChat( } } }, - [fetchStreamBatch, setTransportIdle, setTransportReconnecting, setTransportStreaming] + [ + fetchStreamBatch, + seedStreamBatchPreviewSessions, + setTransportIdle, + setTransportReconnecting, + setTransportStreaming, + ] ) attachToExistingStreamRef.current = attachToExistingStream @@ -2929,11 +3396,14 @@ export function useChat( gen: number afterCursor: string signal?: AbortSignal + targetChatId?: string + shouldContinue?: () => boolean }): Promise => { - const { streamId, assistantId, gen, afterCursor, signal } = opts + const { streamId, assistantId, gen, afterCursor, signal, targetChatId, shouldContinue } = opts const batch = await fetchStreamBatch(streamId, afterCursor, signal) - if (streamGenRef.current !== gen) return + if (streamGenRef.current !== gen || shouldContinue?.() === false) return + seedStreamBatchPreviewSessions(batch) if (isTerminalStreamStatus(batch.status)) { if (batch.events.length > 0) { @@ -2941,10 +3411,18 @@ export function useChat( buildReplayStream(batch.events).getReader(), assistantId, gen, - { preserveExistingState: true } + { + preserveExistingState: true, + ...(targetChatId ? { targetChatId } : {}), + ...(shouldContinue ? { shouldContinue } : {}), + } ) } - finalizeRef.current(batch.status === 'error' ? { error: true } : undefined) + if (streamGenRef.current !== gen || shouldContinue?.() === false) return + finalizeRef.current({ + ...(batch.status === 'error' ? { error: true } : {}), + ...(targetChatId ? { targetChatId } : {}), + }) return } @@ -2953,27 +3431,49 @@ export function useChat( assistantId, expectedGen: gen, initialBatch: batch, + ...(targetChatId ? { targetChatId } : {}), + ...(shouldContinue ? { shouldContinue } : {}), afterCursor: batch.events.length > 0 ? String(batch.events[batch.events.length - 1].eventId) : afterCursor, }) - if (streamGenRef.current === gen && !reconnectResult.aborted) { - finalizeRef.current(reconnectResult.error ? { error: true } : undefined) - } else if (streamGenRef.current === gen && reconnectResult.aborted && !sendingRef.current) { + if ( + streamGenRef.current === gen && + !reconnectResult.aborted && + shouldContinue?.() !== false + ) { + finalizeRef.current({ + ...(reconnectResult.error ? { error: true } : {}), + ...(targetChatId ? { targetChatId } : {}), + }) + } else if ( + streamGenRef.current === gen && + reconnectResult.aborted && + !sendingRef.current && + shouldContinue?.() !== false + ) { setTransportIdle() } }, - [fetchStreamBatch, attachToExistingStream, setTransportIdle] + [fetchStreamBatch, seedStreamBatchPreviewSessions, attachToExistingStream, setTransportIdle] ) const retryReconnect = useCallback( - async (opts: { streamId: string; assistantId: string; gen: number }): Promise => { - const { streamId, assistantId, gen } = opts + async (opts: { + streamId: string + assistantId: string + gen: number + targetChatId?: string + shouldContinue?: () => boolean + }): Promise => { + const { streamId, assistantId, gen, targetChatId, shouldContinue } = opts const isStaleReconnect = () => - streamGenRef.current !== gen || abortControllerRef.current?.signal.aborted === true + streamGenRef.current !== gen || + abortControllerRef.current?.signal.aborted === true || + shouldContinue?.() === false for (let attempt = 0; attempt <= MAX_RECONNECT_ATTEMPTS; attempt++) { if (isStaleReconnect()) return true @@ -2993,16 +3493,14 @@ export function useChat( if (isStaleReconnect()) return true setTransportReconnecting() - await sleep(delayMs) - if (streamGenRef.current !== gen) { - if (!sendingRef.current) { - setTransportIdle() - } else { - setIsReconnecting(false) + try { + await sleepWithAbort(delayMs, abortControllerRef.current?.signal) + } catch (err) { + if (!(err instanceof Error) || err.name !== 'AbortError') { + throw err } - return true } - if (abortControllerRef.current?.signal.aborted) { + if (isStaleReconnect()) { if (!sendingRef.current) { setTransportIdle() } else { @@ -3019,6 +3517,8 @@ export function useChat( gen, afterCursor: lastCursorRef.current || '0', signal: abortControllerRef.current?.signal, + ...(targetChatId ? { targetChatId } : {}), + ...(shouldContinue ? { shouldContinue } : {}), }) if (streamGenRef.current !== gen) { if (!sendingRef.current) { @@ -3081,68 +3581,236 @@ export function useChat( ) retryReconnectRef.current = retryReconnect - const persistPartialResponse = useCallback( - async (overrides?: { - chatId?: string - streamId?: string - content?: string - blocks?: ContentBlock[] - // `stopGeneration` must snapshot these BEFORE clearActiveTurn() - // nulls the refs, or the fetch sees undefined. - requestId?: string - traceparent?: string - }) => { - const chatId = overrides?.chatId ?? chatIdRef.current - const streamId = overrides?.streamId ?? streamIdRef.current - if (!chatId || !streamId) return - - const content = overrides?.content ?? streamingContentRef.current - const requestId = overrides?.requestId ?? streamRequestIdRef.current - const traceparent = overrides?.traceparent ?? streamTraceparentRef.current + const recoverActiveStreamFromRedis = useCallback( + async (reason: 'pageshow' | 'visible' | 'online'): Promise => { + const startingChatId = chatIdRef.current + const startingSelectedChatId = selectedChatIdRef.current + const chatId = startingChatId ?? startingSelectedChatId + if (!chatId) return + + const subjectKey = buildRecoverySubjectKey(startingChatId, startingSelectedChatId) + const existingRecovery = activeStreamReturnRecoveryRef.current + if (existingRecovery?.subjectKey === subjectKey) { + return existingRecovery.promise + } + if (existingRecovery) { + existingRecovery.controller.abort('replaced_by_new_recovery_subject') + activeStreamReturnRecoveryRef.current = null + } - const sourceBlocks = overrides?.blocks ?? streamingBlocksRef.current - const storedBlocks = sourceBlocks.map((block) => { - const timing = { - ...(typeof block.timestamp === 'number' ? { timestamp: block.timestamp } : {}), - ...(typeof block.endedAt === 'number' ? { endedAt: block.endedAt } : {}), + const recoveryController = new AbortController() + const recovery = (async () => { + const observedGeneration = streamGenRef.current + const isSameRecoverySubject = () => + chatIdRef.current === startingChatId && + selectedChatIdRef.current === startingSelectedChatId && + !recoveryController.signal.aborted + + const cached = queryClient.getQueryData(taskKeys.detail(chatId)) + let streamId = + streamIdRef.current ?? activeTurnRef.current?.userMessageId ?? cached?.activeStreamId + if (!streamId) { + streamId = + (await getActiveStreamIdForChat(chatId, recoveryController.signal)) ?? undefined } - if (block.type === 'tool_call' && block.toolCall) { - const isCancelled = - block.toolCall.status === 'executing' || block.toolCall.status === 'cancelled' - const displayTitle = isCancelled ? 'Stopped by user' : block.toolCall.displayTitle - const display = displayTitle ? { title: displayTitle } : undefined - return { - type: block.type, - content: block.content, - toolCall: { - id: block.toolCall.id, - name: block.toolCall.name, - state: isCancelled ? MothershipStreamV1ToolOutcome.cancelled : block.toolCall.status, - params: block.toolCall.params, - result: block.toolCall.result, - ...(display ? { display } : {}), - calledBy: block.toolCall.calledBy, - }, - ...(block.parentToolCallId ? { parentToolCallId: block.parentToolCallId } : {}), - ...timing, - } + if ( + !isSameRecoverySubject() || + streamGenRef.current !== observedGeneration || + pendingStopPromiseRef.current !== null || + !streamId || + locallyTerminalStreamIdRef.current === streamId + ) { + return } - return { - type: block.type, - content: block.content, - ...(block.subagent ? { lane: 'subagent' } : {}), - ...(block.parentToolCallId ? { parentToolCallId: block.parentToolCallId } : {}), - ...timing, + + const recoveryGen = observedGeneration + 1 + streamGenRef.current = recoveryGen + setTransportReconnecting() + streamIdRef.current = streamId + + const replacedController = abortControllerRef.current + if (replacedController && !replacedController.signal.aborted) { + replacedController.abort('superseded_recovery') } - }) - if (storedBlocks.length > 0) { - storedBlocks.push({ type: 'stopped', content: undefined }) - } + const replacedReader = streamReaderRef.current + streamReaderRef.current = null + void replacedReader?.cancel().catch((error) => { + logger.warn('Failed to cancel superseded stream reader during recovery', { + chatId, + streamId, + error: toError(error).message, + }) + }) + + abortControllerRef.current = recoveryController + + logger.info('Recovering active stream after browser return', { + reason, + chatId, + streamId, + fromGeneration: observedGeneration, + toGeneration: recoveryGen, + }) + + if ( + streamGenRef.current !== recoveryGen || + pendingStopPromiseRef.current !== null || + !isSameRecoverySubject() + ) { + return + } + if (locallyTerminalStreamIdRef.current === streamId) return + + const assistantId = getLiveAssistantMessageId(streamId) + const afterCursor = lastCursorRef.current || '0' + + try { + await resumeOrFinalize({ + streamId, + assistantId, + gen: recoveryGen, + afterCursor, + signal: recoveryController.signal, + targetChatId: chatId, + shouldContinue: isSameRecoverySubject, + }) + } catch (error) { + if (error instanceof Error && error.name === 'AbortError') { + return + } + logger.warn('Active stream recovery failed', { + reason, + chatId, + streamId, + error: toError(error).message, + }) + + const succeeded = await retryReconnectRef.current({ + streamId, + assistantId, + gen: recoveryGen, + targetChatId: chatId, + shouldContinue: isSameRecoverySubject, + }) + if (!succeeded && streamGenRef.current === recoveryGen && isSameRecoverySubject()) { + finalizeRef.current({ error: true, targetChatId: chatId }) + } + } + })() + + activeStreamReturnRecoveryRef.current = { + subjectKey, + controller: recoveryController, + promise: recovery, + } + try { + await recovery + } finally { + if (activeStreamReturnRecoveryRef.current?.promise === recovery) { + activeStreamReturnRecoveryRef.current = null + } + } + }, + [getActiveStreamIdForChat, queryClient, resumeOrFinalize, setTransportReconnecting] + ) + + useEffect(() => { + if (typeof window === 'undefined' || typeof document === 'undefined') return + + const recoverIfChatSelected = (reason: 'pageshow' | 'visible' | 'online') => { + if (!chatIdRef.current && !selectedChatIdRef.current) return + void recoverActiveStreamFromRedis(reason) + } + + const handleVisibilityChange = () => { + if (document.visibilityState === 'visible') { + recoverIfChatSelected('visible') + } + } + + const handlePageShow = () => { + recoverIfChatSelected('pageshow') + } + + const handleOnline = () => { + recoverIfChatSelected('online') + } + + document.addEventListener('visibilitychange', handleVisibilityChange) + window.addEventListener('pageshow', handlePageShow) + window.addEventListener('online', handleOnline) + + return () => { + document.removeEventListener('visibilitychange', handleVisibilityChange) + window.removeEventListener('pageshow', handlePageShow) + window.removeEventListener('online', handleOnline) + } + }, [recoverActiveStreamFromRedis]) + + const persistPartialResponse = useCallback( + async (overrides?: { + chatId?: string + streamId?: string + content?: string + blocks?: ContentBlock[] + // `stopGeneration` must snapshot these BEFORE clearActiveTurn() + // nulls the refs, or the fetch sees undefined. + requestId?: string + traceparent?: string + }) => { + const chatId = overrides?.chatId ?? chatIdRef.current + const streamId = overrides?.streamId ?? streamIdRef.current + if (!chatId || !streamId) return + + const content = overrides?.content ?? streamingContentRef.current + const requestId = overrides?.requestId ?? streamRequestIdRef.current + const traceparent = overrides?.traceparent ?? streamTraceparentRef.current + + const sourceBlocks = overrides?.blocks ?? streamingBlocksRef.current + const storedBlocks = sourceBlocks.map((block) => { + const timing = { + ...(typeof block.timestamp === 'number' ? { timestamp: block.timestamp } : {}), + ...(typeof block.endedAt === 'number' ? { endedAt: block.endedAt } : {}), + } + if (block.type === 'tool_call' && block.toolCall) { + const isCancelled = + block.toolCall.status === 'executing' || block.toolCall.status === 'cancelled' + const displayTitle = isCancelled ? 'Stopped by user' : block.toolCall.displayTitle + const display = displayTitle ? { title: displayTitle } : undefined + return { + type: block.type, + content: block.content, + toolCall: { + id: block.toolCall.id, + name: block.toolCall.name, + state: isCancelled ? MothershipStreamV1ToolOutcome.cancelled : block.toolCall.status, + params: block.toolCall.params, + result: block.toolCall.result, + ...(display ? { display } : {}), + calledBy: block.toolCall.calledBy, + }, + ...(block.parentToolCallId ? { parentToolCallId: block.parentToolCallId } : {}), + ...timing, + } + } + return { + type: block.type, + content: block.content, + ...(block.subagent ? { lane: 'subagent' } : {}), + ...(block.parentToolCallId ? { parentToolCallId: block.parentToolCallId } : {}), + ...timing, + } + }) + + if (storedBlocks.length > 0) { + storedBlocks.push({ type: 'stopped', content: undefined }) + } try { const res = await fetch(stopPathRef.current, { method: 'POST', + signal: createTimeoutSignal(STOP_REQUEST_TIMEOUT_MS), headers: { 'Content-Type': 'application/json', ...(traceparent ? { traceparent } : {}), @@ -3176,8 +3844,8 @@ export function useChat( ) const invalidateChatQueries = useCallback( - (options?: { includeDetail?: boolean }) => { - const activeChatId = chatIdRef.current + (options?: { includeDetail?: boolean; targetChatId?: string }) => { + const activeChatId = options?.targetChatId ?? chatIdRef.current if (options?.includeDetail !== false && activeChatId) { queryClient.invalidateQueries({ queryKey: taskKeys.detail(activeChatId), @@ -3217,8 +3885,45 @@ export function useChat( [] ) + const createQueuedMessage = useCallback( + ( + message: string, + fileAttachments?: FileAttachmentForApi[], + contexts?: ChatContext[] + ): QueuedChatMessage => { + const id = generateId() + const handoffChatId = selectedChatIdRef.current ?? chatIdRef.current + const cachedActiveStreamId = handoffChatId + ? queryClient.getQueryData(taskKeys.detail(handoffChatId))?.activeStreamId + : undefined + const supersededStreamId = + streamIdRef.current || + activeTurnRef.current?.userMessageId || + locallyTerminalStreamIdRef.current || + cachedActiveStreamId || + null + + return { + id, + content: message, + fileAttachments, + contexts, + ...(supersededStreamId || handoffChatId + ? { + queuedSendHandoff: { + id, + ...(handoffChatId ? { chatId: handoffChatId } : {}), + supersededStreamId, + }, + } + : {}), + } + }, + [queryClient] + ) + const finalize = useCallback( - (options?: { error?: boolean }) => { + (options?: { error?: boolean; targetChatId?: string }) => { const isError = !!options?.error const hasQueuedFollowUp = !isError && messageQueueRef.current.length > 0 reconcileTerminalPreviewSessions() @@ -3227,7 +3932,10 @@ export function useChat( clearActiveTurn() setTransportIdle() abortControllerRef.current = null - invalidateChatQueries({ includeDetail: !hasQueuedFollowUp }) + invalidateChatQueries({ + includeDetail: !hasQueuedFollowUp, + ...(options?.targetChatId ? { targetChatId: options.targetChatId } : {}), + }) notifyTurnEnded({ error: isError }) }, [ @@ -3251,21 +3959,21 @@ export function useChat( ) => { if (!message.trim() || !workspaceId) return false const pendingStop = pendingStopOverride ?? pendingStopPromiseRef.current + const pendingStopStreamId = pendingStop + ? queuedSendHandoff?.supersededStreamId || + locallyTerminalStreamIdRef.current || + streamIdRef.current || + activeTurnRef.current?.userMessageId + : undefined - const gen = ++streamGenRef.current let consumedByTranscript = false setError(null) setTransportStreaming() - locallyTerminalStreamIdRef.current = undefined const userMessageId = queuedSendHandoff?.userMessageId ?? generateId() const assistantId = getLiveAssistantMessageId(userMessageId) - streamIdRef.current = userMessageId - lastCursorRef.current = '0' - resetStreamingBuffers() - const storedAttachments: PersistedFileAttachment[] | undefined = fileAttachments && fileAttachments.length > 0 ? fileAttachments.map((f) => ({ @@ -3277,11 +3985,14 @@ export function useChat( })) : undefined - const requestChatId = selectedChatIdRef.current ?? chatIdRef.current - if (queuedSendHandoff) { + let requestChatId = + queuedSendHandoff?.chatId ?? selectedChatIdRef.current ?? chatIdRef.current + const writeQueuedSendHandoff = (chatId?: string) => { + if (!queuedSendHandoff) return + if (!chatId && !queuedSendHandoff.supersededStreamId) return writeQueuedSendHandoffState({ id: queuedSendHandoff.id, - chatId: queuedSendHandoff.chatId, + ...(chatId ? { chatId } : {}), workspaceId, supersededStreamId: queuedSendHandoff.supersededStreamId, userMessageId, @@ -3291,6 +4002,9 @@ export function useChat( requestedAt: Date.now(), }) } + if (queuedSendHandoff) { + writeQueuedSendHandoff(queuedSendHandoff.chatId) + } const messageContexts = contexts?.map((c) => ({ kind: c.kind, label: c.label, @@ -3331,12 +4045,6 @@ export function useChat( content: '', contentBlocks: [], } - activeTurnRef.current = { - userMessageId, - assistantMessageId: assistantId, - optimisticUserMessage, - optimisticAssistantMessage, - } if (requestChatId) { await queryClient.cancelQueries({ queryKey: taskKeys.detail(requestChatId) }) @@ -3395,27 +4103,76 @@ export function useChat( onOptimisticSendApplied?.() consumedByTranscript = true - const abortController = new AbortController() - abortControllerRef.current = abortController - + let gen: number | undefined + let streamTargetChatId: string | undefined try { if (pendingStop) { try { await pendingStop + if (!requestChatId) { + requestChatId = + queuedSendHandoff?.chatId ?? + (queuedSendHandoff ? undefined : selectedChatIdRef.current) ?? + chatIdRef.current + if (!requestChatId && pendingStopStreamId) { + const resolvedChatId = await resolveChatIdForStream(pendingStopStreamId, { + preferExistingChatId: false, + }) + if (resolvedChatId) { + if (!selectedChatIdRef.current || selectedChatIdRef.current === resolvedChatId) { + adoptResolvedChatId(resolvedChatId, { replaceHomeHistory: true }) + } + requestChatId = resolvedChatId + } + } + if (requestChatId) { + writeQueuedSendHandoff(requestChatId) + } + } + if ((queuedSendHandoff || pendingStopStreamId) && !requestChatId) { + throw new Error('Cannot send queued message until the active chat is known.') + } + if ( + queuedSendHandoff && + requestChatId && + selectedChatIdRef.current && + selectedChatIdRef.current !== requestChatId + ) { + throw new Error('Queued message was restored because the selected task changed.') + } if (requestChatId) { await queryClient.cancelQueries({ queryKey: taskKeys.detail(requestChatId) }) } applyOptimisticSend() } catch (err) { + if (queuedSendHandoff) { + clearQueuedSendHandoffClaim(queuedSendHandoff.id) + } rollbackOptimisticSend() - abortControllerRef.current = null - clearActiveTurn() - setTransportIdle() + if (!streamReaderRef.current && !abortControllerRef.current) { + clearActiveTurn() + setTransportIdle() + } setError(err instanceof Error ? err.message : 'Failed to stop the previous response') return false } } + streamTargetChatId = requestChatId + gen = ++streamGenRef.current + locallyTerminalStreamIdRef.current = undefined + streamIdRef.current = userMessageId + lastCursorRef.current = '0' + resetStreamingBuffers() + activeTurnRef.current = { + userMessageId, + assistantMessageId: assistantId, + optimisticUserMessage, + optimisticAssistantMessage, + } + const abortController = new AbortController() + abortControllerRef.current = abortController + const currentActiveId = activeResourceIdRef.current const currentResources = resourcesRef.current const resourceAttachments = @@ -3461,15 +4218,32 @@ export function useChat( typeof errorData.activeStreamId === 'string' ? errorData.activeStreamId : userMessageId + const supersededStreamId = queuedSendHandoff?.supersededStreamId ?? pendingStopStreamId + if (supersededStreamId && conflictStreamId === supersededStreamId) { + rollbackOptimisticSend() + if (streamGenRef.current === gen) { + streamGenRef.current++ + abortController.abort('queued_handoff:superseded_conflict') + abortControllerRef.current = null + clearActiveTurn() + setTransportIdle() + } + setError('Previous response is still shutting down; queued message was restored.') + return false + } streamIdRef.current = conflictStreamId const succeeded = await retryReconnect({ streamId: conflictStreamId, assistantId, gen, + ...(streamTargetChatId ? { targetChatId: streamTargetChatId } : {}), }) if (succeeded) return consumedByTranscript if (streamGenRef.current === gen) { - finalize({ error: true }) + finalize({ + error: true, + ...(streamTargetChatId ? { targetChatId: streamTargetChatId } : {}), + }) } return consumedByTranscript } @@ -3482,17 +4256,24 @@ export function useChat( if (!response.body) throw new Error('No response body') - const streamResult = await processSSEStream(response.body.getReader(), assistantId, gen) + const streamResult = await processSSEStream(response.body.getReader(), assistantId, gen, { + ...(streamTargetChatId ? { targetChatId: streamTargetChatId } : {}), + }) if (streamGenRef.current === gen) { if (streamResult.sawStreamError) { - finalize({ error: true }) + finalize({ + error: true, + ...(streamTargetChatId ? { targetChatId: streamTargetChatId } : {}), + }) return consumedByTranscript } // A live SSE `complete` event is already terminal. Finalize immediately so follow-up // sends do not get spuriously queued behind an already-finished response. if (streamResult.sawComplete) { - finalize() + finalize({ + ...(streamTargetChatId ? { targetChatId: streamTargetChatId } : {}), + }) return consumedByTranscript } @@ -3502,34 +4283,44 @@ export function useChat( gen, afterCursor: lastCursorRef.current || '0', signal: abortController.signal, + ...(streamTargetChatId ? { targetChatId: streamTargetChatId } : {}), }) if (streamGenRef.current === gen && sendingRef.current) { - finalize() + finalize({ + ...(streamTargetChatId ? { targetChatId: streamTargetChatId } : {}), + }) } } } catch (err) { if (err instanceof Error && err.name === 'AbortError') return consumedByTranscript if (isStreamSchemaValidationError(err)) { setError(err.message) - if (streamGenRef.current === gen) { - finalize({ error: true }) + if (gen !== undefined && streamGenRef.current === gen) { + finalize({ + error: true, + ...(streamTargetChatId ? { targetChatId: streamTargetChatId } : {}), + }) } return consumedByTranscript } const activeStreamId = streamIdRef.current - if (activeStreamId && streamGenRef.current === gen) { + if (activeStreamId && gen !== undefined && streamGenRef.current === gen) { const succeeded = await retryReconnect({ streamId: activeStreamId, assistantId, gen, + ...(streamTargetChatId ? { targetChatId: streamTargetChatId } : {}), }) if (succeeded) return consumedByTranscript } setError(err instanceof Error ? err.message : 'Failed to send message') - if (streamGenRef.current === gen) { - finalize({ error: true }) + if (gen !== undefined && streamGenRef.current === gen) { + finalize({ + error: true, + ...(streamTargetChatId ? { targetChatId: streamTargetChatId } : {}), + }) } return consumedByTranscript } @@ -3545,6 +4336,8 @@ export function useChat( retryReconnect, clearActiveTurn, resetStreamingBuffers, + resolveChatIdForStream, + adoptResolvedChatId, setTransportIdle, setTransportStreaming, ] @@ -3554,19 +4347,25 @@ export function useChat( if (!message.trim() || !workspaceId) return if (sendingRef.current) { - const queued: QueuedMessage = { - id: generateId(), - content: message, - fileAttachments, - contexts, - } - setMessageQueue((prev) => [...prev, queued]) + setMessageQueue((prev) => [ + ...prev, + createQueuedMessage(message, fileAttachments, contexts), + ]) + return + } + + if (pendingStopPromiseRef.current) { + setMessageQueue((prev) => [ + ...prev, + createQueuedMessage(message, fileAttachments, contexts), + ]) + void enqueueQueueDispatchRef.current({ type: 'send_head' }) return } await startSendMessage(message, fileAttachments, contexts) }, - [workspaceId, startSendMessage] + [workspaceId, startSendMessage, createQueuedMessage] ) useEffect(() => { if (typeof window === 'undefined') return @@ -3582,13 +4381,136 @@ export function useChat( window.removeEventListener('beforeunload', clearClaim) } }, []) + useEffect(() => { + if (!workspaceId || sendingRef.current || pendingStopPromiseRef.current) return + + let cancelled = false + const handoff = readQueuedSendHandoffState() + if (!handoff || handoff.workspaceId !== workspaceId) return + if (recoveringQueuedSendHandoffRef.current?.id === handoff.id) return + const claimRetryDelayMs = queuedSendHandoffClaimRetryDelay(handoff.id) + if (claimRetryDelayMs !== null) { + const retryTimer = window.setTimeout(() => { + setQueuedHandoffRecoveryEpoch((epoch) => epoch + 1) + }, claimRetryDelayMs) + return () => window.clearTimeout(retryTimer) + } + + if (handoff.chatId) { + if (selectedChatIdRef.current && selectedChatIdRef.current !== handoff.chatId) return + adoptResolvedChatId(handoff.chatId, { replaceHomeHistory: true }) + return + } + + if (!handoff.supersededStreamId) return + + const claimOwnerId = writeQueuedSendHandoffClaim(handoff.id) + recoveringQueuedSendHandoffRef.current = { id: handoff.id, ownerId: claimOwnerId } + const effectAbortController = new AbortController() + let shouldRetry = false + void (async () => { + const chatId = await resolveChatIdForStream(handoff.supersededStreamId as string, { + preferExistingChatId: false, + signal: effectAbortController.signal, + }) + if (!chatId) { + shouldRetry = true + return + } + if (cancelled) return + const currentHandoff = readQueuedSendHandoffState() + if ( + !currentHandoff || + currentHandoff.id !== handoff.id || + currentHandoff.workspaceId !== workspaceId || + currentHandoff.userMessageId !== handoff.userMessageId || + currentHandoff.supersededStreamId !== handoff.supersededStreamId || + currentHandoff.chatId || + !hasQueuedSendHandoffClaimOwner(handoff.id, claimOwnerId) + ) { + return + } + writeQueuedSendHandoffState({ + ...currentHandoff, + chatId, + requestedAt: Date.now(), + }) + setQueuedHandoffRecoveryEpoch((epoch) => epoch + 1) + if (!selectedChatIdRef.current || selectedChatIdRef.current === chatId) { + adoptResolvedChatId(chatId, { replaceHomeHistory: true, invalidateList: true }) + } + })() + .catch((error) => { + if (error instanceof Error && error.name === 'AbortError') return + logger.warn('Failed to resolve queued send handoff chat id', { + handoffId: handoff.id, + streamId: handoff.supersededStreamId, + error: toError(error).message, + }) + }) + .finally(async () => { + if ( + shouldRetry && + !cancelled && + recoveringQueuedSendHandoffRef.current?.id === handoff.id && + recoveringQueuedSendHandoffRef.current.ownerId === claimOwnerId + ) { + const currentHandoff = readQueuedSendHandoffState() + if (currentHandoff?.id === handoff.id && !currentHandoff.chatId) { + const resolveAttempts = (currentHandoff.resolveAttempts ?? 0) + 1 + writeQueuedSendHandoffState({ ...currentHandoff, resolveAttempts }) + try { + await sleepWithAbort( + queuedSendHandoffResolveRetryDelay(resolveAttempts), + effectAbortController.signal + ) + } catch (error) { + if (error instanceof Error && error.name === 'AbortError') return + logger.warn('Failed to back off queued send handoff recovery', { + handoffId: handoff.id, + error: toError(error).message, + }) + return + } + if ( + !cancelled && + recoveringQueuedSendHandoffRef.current?.id === handoff.id && + recoveringQueuedSendHandoffRef.current.ownerId === claimOwnerId + ) { + recoveringQueuedSendHandoffRef.current = null + clearQueuedSendHandoffClaim(handoff.id, claimOwnerId) + setQueuedHandoffRecoveryEpoch((epoch) => epoch + 1) + } + return + } + } + if ( + recoveringQueuedSendHandoffRef.current?.id === handoff.id && + recoveringQueuedSendHandoffRef.current.ownerId === claimOwnerId + ) { + recoveringQueuedSendHandoffRef.current = null + } + clearQueuedSendHandoffClaim(handoff.id, claimOwnerId) + }) + return () => { + cancelled = true + effectAbortController.abort('cleanup:queued_handoff_recovery') + if ( + recoveringQueuedSendHandoffRef.current?.id === handoff.id && + recoveringQueuedSendHandoffRef.current.ownerId === claimOwnerId + ) { + recoveringQueuedSendHandoffRef.current = null + } + clearQueuedSendHandoffClaim(handoff.id, claimOwnerId) + } + }, [workspaceId, queuedHandoffRecoveryEpoch, adoptResolvedChatId, resolveChatIdForStream]) useEffect(() => { if (!workspaceId || !chatHistory || sendingRef.current || pendingStopPromiseRef.current) return const handoff = readQueuedSendHandoffState() if (!handoff) return if (handoff.workspaceId !== workspaceId || handoff.chatId !== chatHistory.id) return - if (recoveringQueuedSendHandoffIdRef.current === handoff.id) return + if (recoveringQueuedSendHandoffRef.current?.id === handoff.id) return if (readQueuedSendHandoffClaim() === handoff.id) return if ( @@ -3610,8 +4532,8 @@ export function useChat( return } - recoveringQueuedSendHandoffIdRef.current = handoff.id - writeQueuedSendHandoffClaim(handoff.id) + const claimOwnerId = writeQueuedSendHandoffClaim(handoff.id) + recoveringQueuedSendHandoffRef.current = { id: handoff.id, ownerId: claimOwnerId } void startSendMessage( handoff.message, handoff.fileAttachments, @@ -3625,12 +4547,15 @@ export function useChat( userMessageId: handoff.userMessageId, } ).finally(() => { - if (recoveringQueuedSendHandoffIdRef.current === handoff.id) { - recoveringQueuedSendHandoffIdRef.current = null + if ( + recoveringQueuedSendHandoffRef.current?.id === handoff.id && + recoveringQueuedSendHandoffRef.current.ownerId === claimOwnerId + ) { + recoveringQueuedSendHandoffRef.current = null } - clearQueuedSendHandoffClaim(handoff.id) + clearQueuedSendHandoffClaim(handoff.id, claimOwnerId) }) - }, [workspaceId, chatHistory, startSendMessage]) + }, [workspaceId, chatHistory, queuedHandoffRecoveryEpoch, startSendMessage]) const cancelActiveWorkflowExecutions = useCallback(() => { const execState = useExecutionStore.getState() const consoleStore = useTerminalConsoleStore.getState() @@ -3677,172 +4602,358 @@ export function useChat( } }, [executionStream]) - const stopGeneration = useCallback(async () => { - if (pendingStopPromiseRef.current) { - return pendingStopPromiseRef.current - } + const stopGeneration = useCallback( + async (options?: StopGenerationOptions) => { + const mode = options?.mode ?? 'normal' + if (pendingStopPromiseRef.current) { + if (mode === 'queued-handoff' && pendingStopModeRef.current !== 'queued-handoff') { + throw new Error('Previous response is already stopping; queued message was restored.') + } + return pendingStopPromiseRef.current + } - const wasSending = sendingRef.current - const activeChatId = chatIdRef.current - const sid = - streamIdRef.current || - activeTurnRef.current?.userMessageId || - queryClient.getQueryData(taskKeys.detail(chatIdRef.current)) - ?.activeStreamId || - undefined - // Snapshot the active assistant message id BEFORE clearActiveTurn() - // nulls the ref. Used below to restrict markMessageStopped to the - // in-flight turn only — historical messages from the chat history - // also lack `endedAt` on their legacy blocks (pre-timing-fields), - // and without this gate we'd corrupt them with cancelled markers. - const activeAssistantMessageId = - activeTurnRef.current?.assistantMessageId ?? - (sid ? getLiveAssistantMessageId(sid) : undefined) - const stopContentSnapshot = streamingContentRef.current - const stopNow = Date.now() - const stopBlocksSnapshot = streamingBlocksRef.current.map((block) => ({ - ...block, - ...(block.options ? { options: [...block.options] } : {}), - ...(block.toolCall ? { toolCall: { ...block.toolCall } } : {}), - ...(block.endedAt === undefined ? { endedAt: stopNow } : {}), - })) - // Snapshot BEFORE clearActiveTurn() nulls the refs. Both - // persistPartialResponse and the abort/stop fetches run inside - // stopBarrier below, after several awaits — the refs are long - // gone by the time the fetches serialize their headers. - const stopRequestIdSnapshot = streamRequestIdRef.current - const stopTraceparentSnapshot = streamTraceparentRef.current - - locallyTerminalStreamIdRef.current = sid - streamGenRef.current++ - clearActiveTurn() - streamReaderRef.current?.cancel().catch(() => {}) - streamReaderRef.current = null - abortControllerRef.current?.abort('user_stop:client_stopGeneration') - abortControllerRef.current = null - setTransportIdle() + let resolveStopOperation!: () => void + let rejectStopOperation!: (error: unknown) => void + const stopOperation = new Promise((resolve, reject) => { + resolveStopOperation = resolve + rejectStopOperation = reject + }) + stopOperation.catch(() => {}) + pendingStopPromiseRef.current = stopOperation + pendingStopModeRef.current = mode + + const wasSending = sendingRef.current + let activeChatId = chatIdRef.current ?? selectedChatIdRef.current + const sid = + streamIdRef.current || + activeTurnRef.current?.userMessageId || + (activeChatId + ? queryClient.getQueryData(taskKeys.detail(activeChatId))?.activeStreamId + : undefined) || + undefined + + const activeAssistantMessageId = + activeTurnRef.current?.assistantMessageId ?? + (sid ? getLiveAssistantMessageId(sid) : undefined) + const initialStopRequestIdSnapshot = streamRequestIdRef.current + const initialStopTraceparentSnapshot = streamTraceparentRef.current + + try { + if (mode === 'queued-handoff' && !activeChatId && !sid) { + throw new Error('Cannot send queued message until the active chat is known.') + } + } catch (err) { + if (pendingStopPromiseRef.current === stopOperation) { + pendingStopPromiseRef.current = null + pendingStopModeRef.current = null + } + setError(err instanceof Error ? err.message : 'Failed to stop the previous response') + rejectStopOperation(err) + throw err + } - if (activeChatId) { - await queryClient.cancelQueries({ queryKey: taskKeys.detail(activeChatId) }) - upsertTaskChatHistory(activeChatId, (current) => ({ - ...current, - messages: current.messages.map((message) => - activeAssistantMessageId && message.id === activeAssistantMessageId - ? markMessageStopped(message) - : message - ), + const stopContentSnapshot = streamingContentRef.current + const stopNow = Date.now() + const stopBlocksSnapshot = streamingBlocksRef.current.map((block) => ({ + ...block, + ...(block.options ? { options: [...block.options] } : {}), + ...(block.toolCall ? { toolCall: { ...block.toolCall } } : {}), + ...(block.endedAt === undefined ? { endedAt: stopNow } : {}), })) - } else { - setPendingMessages((prev) => - prev.map((msg) => { - const hasExecutingTool = msg.contentBlocks?.some( - (block) => block.toolCall?.status === 'executing' - ) - const hasOpenBlock = msg.contentBlocks?.some((block) => block.endedAt === undefined) - if (!hasExecutingTool && !hasOpenBlock) { - return msg - } - const updatedBlocks = (msg.contentBlocks ?? []).map((block) => { - const stamped = block.endedAt === undefined ? { ...block, endedAt: stopNow } : block - if (stamped.toolCall?.status !== 'executing') { - return stamped - } - return { - ...stamped, - toolCall: { - ...stamped.toolCall, - status: 'cancelled' as const, - displayTitle: 'Stopped by user', - }, - } - }) - updatedBlocks.push({ type: 'stopped' as const }) - return { ...msg, contentBlocks: updatedBlocks } - }) - ) - } + const stopRequestIdSnapshot = streamRequestIdRef.current ?? initialStopRequestIdSnapshot + const stopTraceparentSnapshot = streamTraceparentRef.current ?? initialStopTraceparentSnapshot - // Cancel active run-tool executions before waiting for the server-side stream - // shutdown barrier; otherwise the abort settle can sit behind tool execution teardown. - cancelActiveWorkflowExecutions() + locallyTerminalStreamIdRef.current = sid + streamGenRef.current++ + clearActiveTurn() + streamReaderRef.current?.cancel().catch(() => {}) + streamReaderRef.current = null + abortControllerRef.current?.abort('user_stop:client_stopGeneration') + abortControllerRef.current = null + setTransportIdle() - const stopBarrier = (async () => { try { - if (wasSending && !chatIdRef.current) { - const start = Date.now() - while (!chatIdRef.current && Date.now() - start < 3000) { - await new Promise((r) => setTimeout(r, 50)) - } + if (activeChatId) { + await queryClient.cancelQueries({ queryKey: taskKeys.detail(activeChatId) }) + upsertTaskChatHistory(activeChatId, (current) => ({ + ...current, + messages: current.messages.map((message) => + activeAssistantMessageId && message.id === activeAssistantMessageId + ? markMessageStopped(message) + : message + ), + })) + } else { + setPendingMessages((prev) => + prev.map((msg) => { + const hasExecutingTool = msg.contentBlocks?.some( + (block) => block.toolCall?.status === 'executing' + ) + const hasOpenBlock = msg.contentBlocks?.some((block) => block.endedAt === undefined) + if (!hasExecutingTool && !hasOpenBlock) { + return msg + } + const updatedBlocks = (msg.contentBlocks ?? []).map((block) => { + const stamped = block.endedAt === undefined ? { ...block, endedAt: stopNow } : block + if (stamped.toolCall?.status !== 'executing') { + return stamped + } + return { + ...stamped, + toolCall: { + ...stamped.toolCall, + status: 'cancelled' as const, + displayTitle: 'Stopped by user', + }, + } + }) + updatedBlocks.push({ type: 'stopped' as const }) + return { ...msg, contentBlocks: updatedBlocks } + }) + ) + } + } catch (err) { + if (sid && locallyTerminalStreamIdRef.current === sid) { + locallyTerminalStreamIdRef.current = undefined + } + if (pendingStopPromiseRef.current === stopOperation) { + pendingStopPromiseRef.current = null + pendingStopModeRef.current = null } + setError(err instanceof Error ? err.message : 'Failed to stop the previous response') + rejectStopOperation(err) + throw err + } - const resolvedChatId = chatIdRef.current - const abortPromise = sid - ? (async () => { - // boundary-raw-fetch: stream-abort endpoint requires propagating the snapshotted traceparent header from the in-flight stream and has no contract authored yet - const res = await fetch('/api/mothership/chat/abort', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - ...(stopTraceparentSnapshot ? { traceparent: stopTraceparentSnapshot } : {}), - }, - body: JSON.stringify({ - streamId: sid, - ...(resolvedChatId ? { chatId: resolvedChatId } : {}), - }), + // Cancel active run-tool executions before waiting for the server-side stream + // shutdown barrier; otherwise the abort settle can sit behind tool execution teardown. + cancelActiveWorkflowExecutions() + + let abortSucceeded = false + const stopBarrier = (async () => { + let stopSucceeded = false + try { + let resolvedChatId = activeChatId ?? chatIdRef.current + let abortSettled = false + const postAbortRequest = async (chatId?: string): Promise => { + if (!sid) return true + // boundary-raw-fetch: stream-abort endpoint requires propagating the snapshotted traceparent header from the in-flight stream and has no contract authored yet + const res = await fetch('/api/mothership/chat/abort', { + method: 'POST', + signal: createTimeoutSignal(STOP_REQUEST_TIMEOUT_MS), + headers: { + 'Content-Type': 'application/json', + ...(stopTraceparentSnapshot ? { traceparent: stopTraceparentSnapshot } : {}), + }, + body: JSON.stringify({ + streamId: sid, + ...(chatId ? { chatId } : {}), + }), + }) + const payload: unknown = await res.json().catch(() => null) + if (isRecord(payload) && payload.aborted === true) { + abortSucceeded = true + } + if (!res.ok) { + if (isRecord(payload) && payload.settled === false) { + return false + } + throw new Error( + isRecord(payload) && typeof payload.error === 'string' + ? payload.error + : 'Failed to abort previous response' + ) + } + abortSucceeded = true + return isRecord(payload) && payload.settled === true + } + const abortPromise = sid + ? postAbortRequest(resolvedChatId).then((settled) => { + abortSettled = settled }) - if (!res.ok) { - const payload = await res.json().catch(() => null) + : Promise.resolve() + + let stopFailure: unknown + let abortFailure: unknown + try { + if (mode === 'queued-handoff' && !resolvedChatId && sid) { + resolvedChatId = await resolveChatIdForStream(sid, { + preferExistingChatId: false, + }) + if (!resolvedChatId) { + throw new Error('Cannot send queued message until the active chat is known.') + } + if ( + pendingStopPromiseRef.current !== stopOperation || + locallyTerminalStreamIdRef.current !== sid + ) { throw new Error( - typeof payload?.error === 'string' - ? payload.error - : 'Failed to abort previous response' + 'Previous response stop was superseded; queued message was restored.' ) } - })() - : Promise.resolve() - - if (wasSending && resolvedChatId) { - await persistPartialResponse({ - chatId: resolvedChatId, - streamId: sid, - content: stopContentSnapshot, - blocks: stopBlocksSnapshot, - requestId: stopRequestIdSnapshot, - traceparent: stopTraceparentSnapshot, + activeChatId = resolvedChatId + if (!selectedChatIdRef.current || selectedChatIdRef.current === resolvedChatId) { + adoptResolvedChatId(resolvedChatId, { replaceHomeHistory: true }) + } + } + + if (wasSending && resolvedChatId) { + await persistPartialResponse({ + chatId: resolvedChatId, + streamId: sid, + content: stopContentSnapshot, + blocks: stopBlocksSnapshot, + requestId: stopRequestIdSnapshot, + traceparent: stopTraceparentSnapshot, + }) + } + } catch (err) { + stopFailure = err + } + + try { + await abortPromise + } catch (err) { + abortFailure = err + } + if (sid && resolvedChatId && !abortSettled) { + try { + const retrySettled = await postAbortRequest(resolvedChatId) + abortSettled = retrySettled + abortFailure = retrySettled + ? undefined + : new Error('Previous response is still shutting down.') + } catch (err) { + abortFailure = err + } + } + + if (stopFailure || abortFailure) throw stopFailure ?? abortFailure + if (wasSending && resolvedChatId) { + activeChatId = resolvedChatId + } + stopSucceeded = true + } finally { + invalidateChatQueries({ + includeDetail: mode !== 'queued-handoff' || !stopSucceeded, }) + resetEphemeralPreviewState({ removeStreamingResource: true }) } + })() - await abortPromise + try { + await stopBarrier + notifyTurnEnded({ + error: false, + skipQueueDispatch: mode === 'queued-handoff', + }) + resolveStopOperation() + } catch (err) { + if (sid && !abortSucceeded && locallyTerminalStreamIdRef.current === sid) { + locallyTerminalStreamIdRef.current = undefined + } + if (activeChatId) { + invalidateChatQueries() + } + setError(err instanceof Error ? err.message : 'Failed to stop the previous response') + rejectStopOperation(err) + throw err } finally { - invalidateChatQueries() - resetEphemeralPreviewState({ removeStreamingResource: true }) + if (pendingStopPromiseRef.current === stopOperation) { + pendingStopPromiseRef.current = null + pendingStopModeRef.current = null + } } - })() + }, + [ + cancelActiveWorkflowExecutions, + invalidateChatQueries, + notifyTurnEnded, + persistPartialResponse, + queryClient, + resolveChatIdForStream, + resetEphemeralPreviewState, + upsertTaskChatHistory, + adoptResolvedChatId, + clearActiveTurn, + setTransportIdle, + workspaceId, + ] + ) - pendingStopPromiseRef.current = stopBarrier - try { - await stopBarrier - // Dispatch queued follow-ups after Stop resolves. - notifyTurnEnded({ error: false }) - } catch (err) { - setError(err instanceof Error ? err.message : 'Failed to stop the previous response') - throw err - } finally { - if (pendingStopPromiseRef.current === stopBarrier) { - pendingStopPromiseRef.current = null + const dispatchQueuedMessage = useCallback( + async ( + msg: QueuedChatMessage, + options: { + epoch: number + pendingStop?: Promise | null + queuedSendHandoff?: QueuedSendHandoffSeed } - } - }, [ - cancelActiveWorkflowExecutions, - invalidateChatQueries, - notifyTurnEnded, - persistPartialResponse, - queryClient, - resetEphemeralPreviewState, - upsertTaskChatHistory, - clearActiveTurn, - setTransportIdle, - ]) + ) => { + if (queuedMessageDispatchIdsRef.current.has(msg.id)) { + return + } + queuedMessageDispatchIdsRef.current.add(msg.id) + + let originalIndex = messageQueueRef.current.findIndex((queued) => queued.id === msg.id) + if (originalIndex === -1) { + queuedMessageDispatchIdsRef.current.delete(msg.id) + return + } + + let removedFromQueue = false + const removeQueuedMessage = () => { + if (removedFromQueue || options.epoch !== queueDispatchEpochRef.current) { + return + } + removedFromQueue = true + setMessageQueue((prev) => prev.filter((queued) => queued.id !== msg.id)) + } + + const restoreQueuedMessage = (handoff?: QueuedSendHandoffSeed) => { + if (!handoff) { + clearQueuedSendHandoffState(msg.id) + } + clearQueuedSendHandoffClaim(msg.id) + if (!removedFromQueue || options.epoch !== queueDispatchEpochRef.current) { + return + } + setMessageQueue((prev) => { + if (prev.some((queued) => queued.id === msg.id)) return prev + const next = [...prev] + next.splice(Math.min(originalIndex, next.length), 0, msg) + return next + }) + } + + const activeQueuedSendHandoff = options.queuedSendHandoff ?? msg.queuedSendHandoff + try { + const currentIndex = messageQueueRef.current.findIndex((queued) => queued.id === msg.id) + if (currentIndex === -1) { + return + } + originalIndex = currentIndex + + const consumed = await startSendMessage( + msg.content, + msg.fileAttachments, + msg.contexts, + options.pendingStop, + removeQueuedMessage, + activeQueuedSendHandoff + ) + + if (!consumed) { + restoreQueuedMessage(activeQueuedSendHandoff) + } + } catch { + restoreQueuedMessage(activeQueuedSendHandoff) + } finally { + queuedMessageDispatchIdsRef.current.delete(msg.id) + } + }, + [startSendMessage] + ) const runQueueDispatchLoop = useCallback(async () => { if (queueDispatchTaskRef.current) { @@ -3861,47 +4972,7 @@ export function useChat( const msg = messageQueueRef.current[0] if (!msg) continue - let originalIndex = 0 - let removedFromQueue = false - const removeQueuedMessage = () => { - if (removedFromQueue || action.epoch !== queueDispatchEpochRef.current) { - return - } - removedFromQueue = true - setMessageQueue((prev) => prev.filter((queued) => queued.id !== msg.id)) - } - - try { - const currentIndex = messageQueueRef.current.findIndex((queued) => queued.id === msg.id) - if (currentIndex !== -1) { - originalIndex = currentIndex - } - - const consumed = await startSendMessage( - msg.content, - msg.fileAttachments, - msg.contexts, - undefined, - removeQueuedMessage - ) - if (!consumed && removedFromQueue && action.epoch === queueDispatchEpochRef.current) { - setMessageQueue((prev) => { - if (prev.some((queued) => queued.id === msg.id)) return prev - const next = [...prev] - next.splice(Math.min(originalIndex, next.length), 0, msg) - return next - }) - } - } catch { - if (removedFromQueue && action.epoch === queueDispatchEpochRef.current) { - setMessageQueue((prev) => { - if (prev.some((queued) => queued.id === msg.id)) return prev - const next = [...prev] - next.splice(Math.min(originalIndex, next.length), 0, msg) - return next - }) - } - } + await dispatchQueuedMessage(msg, { epoch: action.epoch }) } })() @@ -3915,7 +4986,7 @@ export function useChat( void queueDispatchLoopRef.current() } }) - }, [startSendMessage]) + }, [dispatchQueuedMessage]) queueDispatchLoopRef.current = runQueueDispatchLoop const enqueueQueueDispatch = useCallback((action: QueueDispatchActionInput) => { @@ -3933,84 +5004,49 @@ export function useChat( const sendQueuedMessageImmediately = useCallback( async (id: string) => { - const epoch = queueDispatchEpochRef.current - const initialIndex = messageQueueRef.current.findIndex((m) => m.id === id) - if (initialIndex === -1) return - const msg = messageQueueRef.current[initialIndex] - - if (queuedMessageDispatchIdsRef.current.has(msg.id)) { - return - } - queuedMessageDispatchIdsRef.current.add(msg.id) + const msg = messageQueueRef.current.find((queued) => queued.id === id) + if (!msg) return + if (queuedMessageDispatchIdsRef.current.has(msg.id)) return // Explicit queue sends should supersede any older auto-drain work scheduled by finalize(). queueDispatchActionsRef.current = queueDispatchActionsRef.current.filter( (queuedAction) => queuedAction.type !== 'send_head' ) - let originalIndex = initialIndex - let removedFromQueue = false - const removeQueuedMessage = () => { - if (removedFromQueue || epoch !== queueDispatchEpochRef.current) { - return - } - removedFromQueue = true - setMessageQueue((prev) => prev.filter((queued) => queued.id !== msg.id)) - } - const restoreQueuedMessage = () => { - if (!removedFromQueue || epoch !== queueDispatchEpochRef.current) { - return - } - setMessageQueue((prev) => { - if (prev.some((queued) => queued.id === msg.id)) return prev - const next = [...prev] - next.splice(Math.min(originalIndex, next.length), 0, msg) - return next - }) - } - - try { - const currentIndex = messageQueueRef.current.findIndex((queued) => queued.id === msg.id) - if (currentIndex === -1) { - return - } - - originalIndex = currentIndex - - const queuedSendHandoff = - sendingRef.current && workspaceId - ? { + const queuedSendHandoff = + msg.queuedSendHandoff ?? + ((sendingRef.current || pendingStopPromiseRef.current) && workspaceId + ? (() => { + const handoffChatId = selectedChatIdRef.current ?? chatIdRef.current + const cachedActiveStreamId = handoffChatId + ? queryClient.getQueryData(taskKeys.detail(handoffChatId)) + ?.activeStreamId + : undefined + return { id: msg.id, - chatId: selectedChatIdRef.current ?? chatIdRef.current ?? '', + ...(handoffChatId ? { chatId: handoffChatId } : {}), supersededStreamId: streamIdRef.current || activeTurnRef.current?.userMessageId || - queryClient.getQueryData( - taskKeys.detail(selectedChatIdRef.current ?? chatIdRef.current) - )?.activeStreamId || + cachedActiveStreamId || null, } - : undefined - const pendingStop = sendingRef.current ? stopGeneration() : pendingStopPromiseRef.current - const consumed = await startSendMessage( - msg.content, - msg.fileAttachments, - msg.contexts, - pendingStop, - removeQueuedMessage, - queuedSendHandoff?.chatId ? queuedSendHandoff : undefined - ) + })() + : undefined) - if (!consumed) { - restoreQueuedMessage() - } - } catch { - restoreQueuedMessage() - } finally { - queuedMessageDispatchIdsRef.current.delete(msg.id) - } + const pendingStop = sendingRef.current + ? stopGeneration({ + mode: 'queued-handoff', + }) + : pendingStopPromiseRef.current + + await dispatchQueuedMessage(msg, { + epoch: queueDispatchEpochRef.current, + pendingStop, + queuedSendHandoff, + }) }, - [startSendMessage, stopGeneration] + [dispatchQueuedMessage, queryClient, stopGeneration, workspaceId] ) const sendNow = useCallback( @@ -4031,14 +5067,21 @@ export function useChat( useEffect(() => { return () => { + cancelActiveStreamRecovery() clearQueueDispatchState() - streamReaderRef.current = null - abortControllerRef.current = null streamGenRef.current++ + cancelActiveStreamReader() + abortControllerRef.current?.abort('unmount:client_cleanup') + abortControllerRef.current = null clearActiveTurn() sendingRef.current = false } - }, [clearQueueDispatchState, clearActiveTurn]) + }, [ + cancelActiveStreamRecovery, + cancelActiveStreamReader, + clearQueueDispatchState, + clearActiveTurn, + ]) return { messages, diff --git a/apps/sim/hooks/use-task-events.test.ts b/apps/sim/hooks/use-task-events.test.ts index 2e68175b935..e81edbca6dd 100644 --- a/apps/sim/hooks/use-task-events.test.ts +++ b/apps/sim/hooks/use-task-events.test.ts @@ -9,14 +9,46 @@ import { handleTaskStatusEvent } from '@/hooks/use-task-events' describe('handleTaskStatusEvent', () => { const queryClient = { + getQueryData: vi.fn(), invalidateQueries: vi.fn().mockResolvedValue(undefined), - } satisfies Pick + removeQueries: vi.fn(), + } satisfies Pick beforeEach(() => { vi.clearAllMocks() + queryClient.getQueryData.mockReturnValue(undefined) }) - it('invalidates only the task list for completed task events', () => { + it('invalidates the task list and detail for completed task events', () => { + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'completed', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(2) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.detail('chat-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('keeps completed task detail when an unkeyed completion races an active stream', () => { + queryClient.getQueryData.mockReturnValue({ + id: 'chat-1', + title: null, + messages: [{ id: 'new-stream' }, { id: 'live-assistant:new-stream' }], + activeStreamId: 'new-stream', + resources: [], + }) + handleTaskStatusEvent( queryClient, 'ws-1', @@ -31,15 +63,333 @@ describe('handleTaskStatusEvent', () => { expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ queryKey: taskKeys.list('ws-1'), }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() }) - it('keeps list invalidation only for non-completed task events', () => { + it('keeps completed task detail when a newer optimistic stream is active', () => { + queryClient.getQueryData.mockReturnValue({ + id: 'chat-1', + title: null, + messages: [{ id: 'old-stream' }, { id: 'new-stream' }], + activeStreamId: 'new-stream', + resources: [], + }) + + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'completed', + streamId: 'old-stream', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(1) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('keeps completed task detail when only a newer optimistic stream is cached', () => { + queryClient.getQueryData.mockReturnValue({ + id: 'chat-1', + title: null, + messages: [{ id: 'new-stream' }, { id: 'live-assistant:new-stream' }], + activeStreamId: 'new-stream', + resources: [], + }) + + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'completed', + streamId: 'old-stream', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(1) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('invalidates completed task detail when the active stream disagreement is only stale cache', () => { + queryClient.getQueryData.mockReturnValue({ + id: 'chat-1', + title: null, + messages: [{ id: 'new-stream' }, { id: 'old-stream' }], + activeStreamId: 'new-stream', + resources: [], + }) + + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'completed', + streamId: 'old-stream', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(2) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.detail('chat-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('invalidates completed task detail when a missing stream may be newer server state', () => { + queryClient.getQueryData.mockReturnValue({ + id: 'chat-1', + title: null, + messages: [{ id: 'old-stream' }], + activeStreamId: 'old-stream', + resources: [], + }) + + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'completed', + streamId: 'new-stream', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(2) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.detail('chat-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('invalidates completed task detail when the completed stream is active', () => { + queryClient.getQueryData.mockReturnValue({ + id: 'chat-1', + title: null, + messages: [], + activeStreamId: 'stream-1', + resources: [], + }) + + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'completed', + streamId: 'stream-1', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(2) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.detail('chat-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('invalidates the task list and detail for metadata-changing task events', () => { + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'renamed', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(2) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.detail('chat-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('invalidates the task list and removes detail cache for deleted task events', () => { + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'deleted', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(1) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.removeQueries).toHaveBeenCalledTimes(1) + expect(queryClient.removeQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.detail('chat-1'), + }) + }) + + it('invalidates the task list and detail for started task events', () => { + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'started', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(2) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.detail('chat-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('keeps started task detail when an unkeyed started event races an active stream', () => { + queryClient.getQueryData.mockReturnValue({ + id: 'chat-1', + title: null, + messages: [{ id: 'new-stream' }, { id: 'live-assistant:new-stream' }], + activeStreamId: 'new-stream', + resources: [], + }) + + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'started', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(1) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('keeps started task detail when the started stream is already active', () => { + queryClient.getQueryData.mockReturnValue({ + id: 'chat-1', + title: null, + messages: [{ id: 'stream-1' }], + activeStreamId: 'stream-1', + resources: [], + }) + + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'started', + streamId: 'stream-1', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(1) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('keeps started task detail when a stale started stream is older than the active stream', () => { + queryClient.getQueryData.mockReturnValue({ + id: 'chat-1', + title: null, + messages: [{ id: 'old-stream' }, { id: 'new-stream' }], + activeStreamId: 'new-stream', + resources: [], + }) + + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'started', + streamId: 'old-stream', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(1) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('invalidates started task detail when a missing stream may be newer server state', () => { + queryClient.getQueryData.mockReturnValue({ + id: 'chat-1', + title: null, + messages: [{ id: 'old-stream' }], + activeStreamId: 'old-stream', + resources: [], + }) + handleTaskStatusEvent( queryClient, 'ws-1', JSON.stringify({ chatId: 'chat-1', type: 'started', + streamId: 'new-stream', + timestamp: Date.now(), + }) + ) + + expect(queryClient.invalidateQueries).toHaveBeenCalledTimes(2) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.list('ws-1'), + }) + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ + queryKey: taskKeys.detail('chat-1'), + }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() + }) + + it('keeps list invalidation only for unknown task event types', () => { + handleTaskStatusEvent( + queryClient, + 'ws-1', + JSON.stringify({ + chatId: 'chat-1', + type: 'archived', timestamp: Date.now(), }) ) @@ -48,11 +398,13 @@ describe('handleTaskStatusEvent', () => { expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ queryKey: taskKeys.list('ws-1'), }) + expect(queryClient.removeQueries).not.toHaveBeenCalled() }) it('does not invalidate when task event payload is invalid', () => { handleTaskStatusEvent(queryClient, 'ws-1', '{') expect(queryClient.invalidateQueries).not.toHaveBeenCalled() + expect(queryClient.removeQueries).not.toHaveBeenCalled() }) }) diff --git a/apps/sim/hooks/use-task-events.ts b/apps/sim/hooks/use-task-events.ts index 0d6d4f7d0e9..b9a5216dad4 100644 --- a/apps/sim/hooks/use-task-events.ts +++ b/apps/sim/hooks/use-task-events.ts @@ -2,13 +2,68 @@ import { useEffect } from 'react' import { createLogger } from '@sim/logger' import type { QueryClient } from '@tanstack/react-query' import { useQueryClient } from '@tanstack/react-query' -import { taskKeys } from '@/hooks/queries/tasks' +import { getLiveAssistantMessageId } from '@/lib/copilot/chat/effective-transcript' +import { type TaskChatHistory, taskKeys } from '@/hooks/queries/tasks' const logger = createLogger('TaskEvents') +const TASK_STATUS_TYPES = ['started', 'completed', 'created', 'deleted', 'renamed'] as const +type TaskStatusEventType = (typeof TASK_STATUS_TYPES)[number] +const TASK_STATUS_TYPE_SET = new Set(TASK_STATUS_TYPES) + interface TaskStatusEventPayload { chatId?: string - type?: 'started' | 'completed' | 'created' | 'deleted' | 'renamed' + type?: TaskStatusEventType + streamId?: string +} + +const DETAIL_INVALIDATING_TASK_STATUS_TYPES = new Set([ + 'started', + 'completed', + 'renamed', +]) + +function isTaskStatusEventType(value: unknown): value is TaskStatusEventType { + return typeof value === 'string' && TASK_STATUS_TYPE_SET.has(value) +} + +function isLocalOptimisticActiveStream(current: TaskChatHistory | undefined) { + if (!current?.activeStreamId) return false + const liveAssistantId = getLiveAssistantMessageId(current.activeStreamId) + return current.messages.some((message) => message.id === liveAssistantId) +} + +/** + * Returns true when the cached active stream is known to be later in the + * chronological transcript than the stream that emitted this status event. + * If either stream is absent from the transcript, callers should refetch + * instead of inferring order from incomplete cache state. + */ +function hasNewerKnownActiveStream(current: TaskChatHistory | undefined, streamId: string) { + if (!current?.activeStreamId || current.activeStreamId === streamId) return false + + const activeIndex = current.messages.findIndex((message) => message.id === current.activeStreamId) + const eventStreamIndex = current.messages.findIndex((message) => message.id === streamId) + if (activeIndex === -1) return false + if (eventStreamIndex === -1) return false + return activeIndex > eventStreamIndex +} + +function shouldSkipDetailInvalidationForStreamEvent( + current: TaskChatHistory | undefined, + payload: TaskStatusEventPayload +) { + if (payload.type !== 'started' && payload.type !== 'completed') return false + if (!current?.activeStreamId) return false + if (!payload.streamId) return isLocalOptimisticActiveStream(current) + if (payload.type === 'started' && current.activeStreamId === payload.streamId) return true + if (current.activeStreamId === payload.streamId) return false + if (hasNewerKnownActiveStream(current, payload.streamId)) return true + return ( + payload.type === 'completed' && + isLocalOptimisticActiveStream(current) && + !current.messages.some((message) => message.id === payload.streamId) + ) } function parseTaskStatusEventPayload(data: unknown): TaskStatusEventPayload | null { @@ -30,14 +85,13 @@ function parseTaskStatusEventPayload(data: unknown): TaskStatusEventPayload | nu return { ...(typeof record.chatId === 'string' ? { chatId: record.chatId } : {}), - ...(typeof record.type === 'string' - ? { type: record.type as TaskStatusEventPayload['type'] } - : {}), + ...(isTaskStatusEventType(record.type) ? { type: record.type } : {}), + ...(typeof record.streamId === 'string' ? { streamId: record.streamId } : {}), } } export function handleTaskStatusEvent( - queryClient: Pick, + queryClient: Pick, workspaceId: string, data: unknown ): void { @@ -48,6 +102,20 @@ export function handleTaskStatusEvent( } queryClient.invalidateQueries({ queryKey: taskKeys.list(workspaceId) }) + if (!payload.chatId) return + if (payload.type === 'deleted') { + queryClient.removeQueries({ queryKey: taskKeys.detail(payload.chatId) }) + return + } + if (payload.type === 'started' || payload.type === 'completed') { + const current = queryClient.getQueryData(taskKeys.detail(payload.chatId)) + if (shouldSkipDetailInvalidationForStreamEvent(current, payload)) { + return + } + } + if (payload.type && DETAIL_INVALIDATING_TASK_STATUS_TYPES.has(payload.type)) { + queryClient.invalidateQueries({ queryKey: taskKeys.detail(payload.chatId) }) + } } /** diff --git a/apps/sim/lib/copilot/chat/post.ts b/apps/sim/lib/copilot/chat/post.ts index 21d94e56bb5..a745f209c9e 100644 --- a/apps/sim/lib/copilot/chat/post.ts +++ b/apps/sim/lib/copilot/chat/post.ts @@ -329,6 +329,7 @@ async function persistUserMessage(params: { workspaceId, chatId, type: 'started', + streamId: userMessageId, }) } @@ -430,6 +431,7 @@ function buildOnComplete(params: { workspaceId, chatId, type: 'completed', + streamId: userMessageId, }) } } catch (error) { @@ -461,6 +463,7 @@ function buildOnError(params: { workspaceId, chatId, type: 'completed', + streamId: userMessageId, }) } } catch (error) { diff --git a/apps/sim/lib/copilot/tasks.ts b/apps/sim/lib/copilot/tasks.ts index 5828a711cb4..db6594ebf28 100644 --- a/apps/sim/lib/copilot/tasks.ts +++ b/apps/sim/lib/copilot/tasks.ts @@ -13,6 +13,7 @@ interface TaskStatusEvent { workspaceId: string chatId: string type: 'started' | 'completed' | 'created' | 'deleted' | 'renamed' + streamId?: string } const channel = diff --git a/apps/sim/lib/mothership/inbox/executor.ts b/apps/sim/lib/mothership/inbox/executor.ts index b738a6a37ae..52236d9b959 100644 --- a/apps/sim/lib/mothership/inbox/executor.ts +++ b/apps/sim/lib/mothership/inbox/executor.ts @@ -131,11 +131,14 @@ export async function executeInboxTask(taskId: string): Promise { }) } + const userMessageId = generateId() + if (chatId) { taskPubSub?.publishStatusChanged({ workspaceId: ws.id, chatId, type: 'started', + streamId: userMessageId, }) } @@ -178,7 +181,6 @@ export async function executeInboxTask(taskId: string): Promise { } const messageContent = formatEmailAsMessage(truncatedTask, attachments) - const userMessageId = generateId() const requestPayload: Record = { message: messageContent, userId, @@ -244,6 +246,7 @@ export async function executeInboxTask(taskId: string): Promise { workspaceId: ws.id, chatId, type: 'completed', + streamId: userMessageId, }) }