Skip to content

Commit

Permalink
feat: support image in execution output
Browse files Browse the repository at this point in the history
closes #191
  • Loading branch information
pionxzh committed Nov 22, 2023
1 parent a54b20c commit 38c65b9
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 33 deletions.
64 changes: 52 additions & 12 deletions src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,25 @@ interface CiteMetadata {
}

interface MessageMeta {
command: 'click' | 'search' | 'quote' | 'quote_lines' | 'scroll' & (string & {})
aggregate_result?: {
code: string
end_time: number
jupyter_messages: unknown[]
messages: Array<{
image_url: string
message_type: 'image'
sender: 'server'
time: number
width: number
height: number
}>
run_id: string
start_time: number
status: 'success' | 'error' & (string & {})
update_time: number
}
args: unknown
command: 'click' | 'search' | 'quote' | 'quote_lines' | 'scroll' & (string & {})
finish_details?: {
stop: string
type: 'stop' | 'interrupted' & (string & {})
Expand Down Expand Up @@ -224,18 +241,38 @@ async function replaceImageAssets(conversation: ApiConversation): Promise<void>
return node.message.content.parts.filter(isMultiModalInputImage)
})

await Promise.all(imageAssets.map(async (asset) => {
try {
const newAssetPointer = await fetchImageFromPointer(asset.asset_pointer)
if (newAssetPointer) asset.asset_pointer = newAssetPointer
}
catch (error) {
console.error('Failed to fetch image asset', error)
}
}))
const executionOutputs = Object.values(conversation.mapping).flatMap((node) => {
if (!node.message) return []
if (node.message.content.content_type !== 'execution_output') return []
if (!node.message.metadata?.aggregate_result?.messages) return []

return node.message.metadata.aggregate_result.messages
.filter(msg => msg.message_type === 'image')
})

await Promise.all([
...imageAssets.map(async (asset) => {
try {
const newAssetPointer = await fetchImageFromPointer(asset.asset_pointer)
if (newAssetPointer) asset.asset_pointer = newAssetPointer
}
catch (error) {
console.error('Failed to fetch image asset', error)
}
}),
...executionOutputs.map(async (msg) => {
try {
const newImageUrl = await fetchImageFromPointer(msg.image_url)
if (newImageUrl) msg.image_url = newImageUrl
}
catch (error) {
console.error('Failed to fetch image asset', error)
}
}),
])
}

export async function fetchConversation(chatId: string): Promise<ApiConversationWithId> {
export async function fetchConversation(chatId: string, shouldReplaceAssets: boolean): Promise<ApiConversationWithId> {
if (chatId.startsWith('__share__')) {
const id = chatId.replace('__share__', '')
const shareConversation = getConversationFromSharePage() as ApiConversation
Expand All @@ -249,7 +286,10 @@ export async function fetchConversation(chatId: string): Promise<ApiConversation

const url = conversationApi(chatId)
const conversation = await fetchApi<ApiConversation>(url)
await replaceImageAssets(conversation)

if (shouldReplaceAssets) {
await replaceImageAssets(conversation)
}

return {
id: chatId,
Expand Down
28 changes: 22 additions & 6 deletions src/exporter/html.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export async function exportToHtml(fileNameFormat: string, metaList: ExportMeta[
const userAvatar = await getUserAvatar()

const chatId = await getCurrentChatId()
const rawConversation = await fetchConversation(chatId)
const rawConversation = await fetchConversation(chatId, true)
const conversationChoices = getConversationChoice()
const conversation = processConversation(rawConversation, conversationChoices)
const html = conversationToHtml(conversation, userAvatar, metaList)
Expand Down Expand Up @@ -83,10 +83,20 @@ function conversationToHtml(conversation: ConversationResult, avatar: string, me
if (message.recipient !== 'all') return null

// Skip tool's intermediate message.
//
// HACK: we special case the content_type 'multimodal_text' here because it is used by
// the dalle tool to return the image result, and we do want to show that.
if (message.author.role === 'tool' && message.content.content_type !== 'multimodal_text') return null
if (message.author.role === 'tool') {
if (
// HACK: we special case the content_type 'multimodal_text' here because it is used by
// the dalle tool to return the image result, and we do want to show that.
message.content.content_type !== 'multimodal_text'
// Also keep code execution results that contain at least one image.
&& !(
message.content.content_type === 'execution_output'
&& message.metadata?.aggregate_result?.messages?.some(msg => msg.message_type === 'image')
)
) {
return null
}
}

const author = transformAuthor(message.author)
const model = message?.metadata?.model_slug === 'gpt-4' ? 'GPT-4' : 'GPT-3'
Expand Down Expand Up @@ -221,8 +231,14 @@ function transformContent(
case 'text':
return postProcess(content.parts?.join('\n') || '')
case 'code':
return postProcess(`Code:\n\`\`\`\n${content.text}\n\`\`\`` || '')
return `Code:\n\`\`\`\n${content.text}\n\`\`\`` || ''
case 'execution_output':
if (metadata?.aggregate_result?.messages) {
return metadata.aggregate_result.messages
.filter(msg => msg.message_type === 'image')
.map(msg => `<img src="${msg.image_url}" height="${msg.height}" width="${msg.width}" />`)
.join('\n')
}
return postProcess(`Result:\n\`\`\`\n${content.text}\n\`\`\`` || '')
case 'tether_quote':
return postProcess(`> ${content.title || content.text || ''}`)
Expand Down
2 changes: 1 addition & 1 deletion src/exporter/json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export async function exportToJson(fileNameFormat: string, options: { officialFo
}

const chatId = await getCurrentChatId()
const rawConversation = await fetchConversation(chatId)
const rawConversation = await fetchConversation(chatId, false)
const conversationChoices = getConversationChoice()
const conversation = processConversation(rawConversation, conversationChoices)

Expand Down
28 changes: 22 additions & 6 deletions src/exporter/markdown.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export async function exportToMarkdown(fileNameFormat: string, metaList: ExportM
}

const chatId = await getCurrentChatId()
const rawConversation = await fetchConversation(chatId)
const rawConversation = await fetchConversation(chatId, true)
const conversationChoices = getConversationChoice()
const conversation = processConversation(rawConversation, conversationChoices)
const markdown = conversationToMarkdown(conversation, metaList)
Expand Down Expand Up @@ -99,10 +99,20 @@ function conversationToMarkdown(conversation: ConversationResult, metaList?: Exp
if (message.recipient !== 'all') return null

// Skip tool's intermediate message.
//
// HACK: we special case the content_type 'multimodal_text' here because it is used by
// the dalle tool to return the image result, and we do want to show that.
if (message.author.role === 'tool' && message.content.content_type !== 'multimodal_text') return null
if (message.author.role === 'tool') {
if (
// HACK: we special case the content_type 'multimodal_text' here because it is used by
// the dalle tool to return the image result, and we do want to show that.
message.content.content_type !== 'multimodal_text'
// Also keep code execution results that contain at least one image.
&& !(
message.content.content_type === 'execution_output'
&& message.metadata?.aggregate_result?.messages?.some(msg => msg.message_type === 'image')
)
) {
return null
}
}

const timestamp = message?.create_time ?? ''
const showTimestamp = enableTimestamp && timeStampHtml && timestamp
Expand Down Expand Up @@ -203,8 +213,14 @@ function transformContent(
case 'text':
return postProcess(content.parts?.join('\n') || '')
case 'code':
return postProcess(`Code:\n\`\`\`\n${content.text}\n\`\`\`` || '')
return `Code:\n\`\`\`\n${content.text}\n\`\`\`` || ''
case 'execution_output':
if (metadata?.aggregate_result?.messages) {
return metadata.aggregate_result.messages
.filter(msg => msg.message_type === 'image')
.map(msg => `![image](${msg.image_url})`)
.join('\n')
}
return postProcess(`Result:\n\`\`\`\n${content.text}\n\`\`\`` || '')
case 'tether_quote':
return postProcess(`> ${content.title || content.text || ''}`)
Expand Down
29 changes: 24 additions & 5 deletions src/exporter/text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ export async function exportToText() {
}

const chatId = await getCurrentChatId()
const rawConversation = await fetchConversation(chatId)
// All images in the text output are replaced with `[image]`,
// so we don't need to waste time downloading them.
const rawConversation = await fetchConversation(chatId, false)

const conversationChoices = getConversationChoice()
const { conversationNodes } = processConversation(rawConversation, conversationChoices)
const text = conversationNodes
Expand All @@ -34,10 +37,20 @@ function transformMessage(message?: ConversationNodeMessage) {
if (message.recipient !== 'all') return null

// Skip tool's intermediate message.
//
// HACK: we special case the content_type 'multimodal_text' here because it is used by
// the dalle tool to return the image result, and we do want to show that.
if (message.author.role === 'tool' && message.content.content_type !== 'multimodal_text') return null
if (message.author.role === 'tool') {
if (
// HACK: we special case the content_type 'multimodal_text' here because it is used by
// the dalle tool to return the image result, and we do want to show that.
message.content.content_type !== 'multimodal_text'
// Also keep code execution results that contain at least one image.
&& !(
message.content.content_type === 'execution_output'
&& message.metadata?.aggregate_result?.messages?.some(msg => msg.message_type === 'image')
)
) {
return null
}
}

const author = transformAuthor(message.author)
let content = transformContent(message.content, message.metadata)
Expand Down Expand Up @@ -65,6 +78,12 @@ function transformContent(
case 'code':
return content.text || ''
case 'execution_output':
if (metadata?.aggregate_result?.messages) {
return metadata.aggregate_result.messages
.filter(msg => msg.message_type === 'image')
.map(() => '[image]')
.join('\n')
}
return content.text || ''
case 'tether_quote':
return `> ${content.title || content.text || ''}`
Expand Down
2 changes: 1 addition & 1 deletion src/main.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ function main() {
if (!currentChatId || currentChatId === chatId) return
chatId = currentChatId

const rawConversation = await fetchConversation(chatId)
const rawConversation = await fetchConversation(chatId, false)
const conversationChoices = getConversationChoice()
const { conversationNodes } = processConversation(rawConversation, conversationChoices)

Expand Down
4 changes: 2 additions & 2 deletions src/ui/ExportDialog.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ const DialogContent: FC<DialogContentProps> = ({ format }) => {
selected.forEach(({ id, title }) => {
requestQueue.add({
name: title,
request: () => fetchConversation(id),
request: () => fetchConversation(id, exportType !== 'JSON'),
})
})

requestQueue.start()
}, [disabled, selected, requestQueue])
}, [disabled, selected, requestQueue, exportType])

const exportAllFromLocal = useCallback(() => {
if (disabled) return
Expand Down

0 comments on commit 38c65b9

Please sign in to comment.