Skip to content

Commit

Permalink
Merge pull request #33 from kaylieee/main
Browse files Browse the repository at this point in the history
Add Image Input Support
  • Loading branch information
kaylieee authored Jun 10, 2024
2 parents 84cf569 + 6a7134f commit af52e21
Show file tree
Hide file tree
Showing 25 changed files with 893 additions and 234 deletions.
2 changes: 1 addition & 1 deletion gui/assistant_dialogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def load_completion_settings(self, text_completion_config):
self.responseFormatComboBox.setCurrentText(completion_settings.get('response_format', 'text'))
self.temperatureSlider.setValue(completion_settings.get('temperature', 1.0) * 100)
self.topPSlider.setValue(completion_settings.get('top_p', 1.0) * 100)
self.maxMessagesEdit.setValue(completion_settings.get('max_text_messages', 10))
self.maxMessagesEdit.setValue(completion_settings.get('max_text_messages', 50))
else:
# Apply default settings if no config is found
self.useDefaultSettingsCheckBox.setChecked(True)
Expand Down
133 changes: 107 additions & 26 deletions gui/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
# This software uses the PySide6 library, which is licensed under the GNU Lesser General Public License (LGPL).
# For more details on PySide6's license, see <https://www.qt.io/licensing>

from PySide6.QtWidgets import QWidget, QVBoxLayout, QTextEdit
from PySide6.QtGui import QFont, QTextCursor,QDesktopServices, QMouseEvent, QGuiApplication, QPalette
from PySide6.QtCore import Qt, QUrl

import html, os, re, subprocess, sys
import base64
from typing import List

from azure.ai.assistant.management.assistant_config_manager import AssistantConfigManager
from azure.ai.assistant.management.message import ConversationMessage
from azure.ai.assistant.management.logger_module import logger

from PySide6.QtWidgets import QWidget, QVBoxLayout, QTextEdit, QMessageBox
from PySide6.QtGui import QFont, QTextCursor,QDesktopServices, QMouseEvent, QGuiApplication, QPalette, QImage
from PySide6.QtCore import Qt, QUrl, QMimeData, QIODevice, QBuffer
from bs4 import BeautifulSoup

import html, os, re, subprocess, sys, tempfile
import base64, random, time
from typing import List


class ConversationInputView(QTextEdit):
PLACEHOLDER_TEXT = "Message Assistant..."
Expand All @@ -24,6 +25,7 @@ def __init__(self, parent, main_window):
super().__init__(parent)
self.main_window = main_window # Store a reference to the main window
self.setInitialPlaceholderText()
self.image_file_paths = {} # Dictionary to track image file paths

def setInitialPlaceholderText(self):
self.setText(self.PLACEHOLDER_TEXT)
Expand All @@ -39,8 +41,23 @@ def keyPressEvent(self, event):
if self.toPlainText() == self.PLACEHOLDER_TEXT and not event.text().isspace():
self.clear()

cursor = self.textCursor()
if event.key() in (Qt.Key_Delete, Qt.Key_Backspace):
# Check if the cursor is positioned at an image
cursor_pos = cursor.position()
cursor.movePosition(QTextCursor.Left, QTextCursor.KeepAnchor)
if cursor.charFormat().isImageFormat():
logger.debug("Image found at cursor position, deleting image...")
html_before = self.toHtml()
cursor.removeSelectedText()
html_after = self.toHtml()
self.check_for_deleted_images(html_before, html_after)
else:
# Let the parent class handle other delete/backspace operations
cursor.setPosition(cursor_pos)
super().keyPressEvent(event)
# Check if Enter key is pressed
if event.key() == Qt.Key_Return and not event.modifiers():
elif event.key() == Qt.Key_Return and not event.modifiers():
# Call on_user_input on the main window reference
self.main_window.on_user_input_complete(self.toPlainText())
self.clear()
Expand All @@ -52,13 +69,42 @@ def keyPressEvent(self, event):
# Let the parent class handle all other key events
super().keyPressEvent(event)

def insertFromMimeData(self, mimeData):
if mimeData.hasText():
def insertFromMimeData(self, mimeData: QMimeData):
IMAGE_FORMATS = ('.png', '.jpg', '.jpeg', '.gif', '.webp')
if mimeData.hasImage():
image = QImage(mimeData.imageData())
if not image.isNull():
logger.debug("Inserting image from clipboard...")
temp_dir = tempfile.gettempdir()
mime_file_name = self.generate_unique_filename("image.png")
temp_file_path = os.path.join(temp_dir, mime_file_name)
image.save(temp_file_path)
self.add_image_thumbnail(image, temp_file_path)
self.main_window.add_image_to_selected_thread(temp_file_path)
elif mimeData.hasUrls():
logger.debug("Inserting image from URL...")
for url in mimeData.urls():
if url.isLocalFile():
file_path = url.toLocalFile()
logger.debug(f"Local file path: {file_path}")
if file_path.lower().endswith(IMAGE_FORMATS):
image = QImage(file_path)
if not image.isNull():
self.add_image_thumbnail(image, file_path)
self.main_window.add_image_to_selected_thread(file_path)
else:
logger.error(f"Could not load image from file: {file_path}")
else:
logger.warning(f"Unsupported file type: {file_path}")
QMessageBox.warning(self, "Error", "Unsupported file type. Please only upload image files.")
else:
super().insertFromMimeData(mimeData)
elif mimeData.hasText():
text = mimeData.text()
# Convert URL to local file path
fileUrl = QUrl(text)
if fileUrl.isLocalFile():
file_path = fileUrl.toLocalFile()
file_url = QUrl(text)
if file_url.isLocalFile():
file_path = file_url.toLocalFile()
if os.path.isfile(file_path):
try:
with open(file_path, 'r') as file:
Expand All @@ -71,6 +117,41 @@ def insertFromMimeData(self, mimeData):
else:
# If it's not a file URL, proceed with the default paste operation
super().insertFromMimeData(mimeData)
else:
super().insertFromMimeData(mimeData)

def generate_unique_filename(self, base_name):
    """Return ``base_name`` with a timestamp and a random token before its extension.

    The result has the form ``<stem>_<unix-seconds>_<32 hex chars><ext>``.

    Args:
        base_name: File name (optionally with an extension) to derive from.

    Returns:
        The derived file name as a string; no file is created.
    """
    # Local import so this block does not depend on a file-level import.
    import uuid

    name, ext = os.path.splitext(base_name)
    # A second-resolution timestamp plus a 4-digit random number collides
    # readily when several images are pasted within the same second; a UUID
    # token makes a clash practically impossible.
    return f"{name}_{int(time.time())}_{uuid.uuid4().hex}{ext}"

def add_image_thumbnail(self, image: QImage, file_path: str):
    """Insert a thumbnail of ``image`` at the text cursor and track it by ``file_path``.

    The image is scaled to fit within 100x100 pixels (aspect ratio kept),
    encoded as a PNG data URI, and inserted as an ``<img>`` tag whose ``alt``
    attribute carries the file path. The generated markup is stored in
    ``self.image_file_paths`` under the unescaped ``file_path``.

    Args:
        image: The full-size image to build the thumbnail from.
        file_path: Path of the image file the thumbnail represents.
    """
    image_thumbnail = image.scaled(100, 100, Qt.KeepAspectRatio)  # Resize to fit 100x100 pixels
    buffer = QBuffer()
    buffer.open(QIODevice.WriteOnly)
    image_thumbnail.save(buffer, "PNG")
    base64_data = buffer.data().toBase64().data().decode()

    # Escape the path before placing it in an attribute: a quote, '&' or '<'
    # in the file name would otherwise break the tag and the alt-based lookup.
    escaped_path = html.escape(file_path, quote=True)
    # Named img_html (not html) so the html module is not shadowed.
    img_html = f'<img src="data:image/png;base64,{base64_data}" alt="{escaped_path}" />'

    cursor = self.textCursor()
    cursor.insertHtml(img_html)
    self.image_file_paths[file_path] = img_html

def check_for_deleted_images(self, html_before: str, html_after: str):
    """Detect images removed between two HTML snapshots and untrack them.

    Images are identified by the ``alt`` attribute of their ``<img>`` tags.
    Every alt value present in ``html_before`` but absent from ``html_after``
    that is also a key of ``self.image_file_paths`` is deleted from that
    dictionary and passed to
    ``self.main_window.remove_image_from_selected_thread``.

    Args:
        html_before: Document HTML captured before the edit.
        html_after: Document HTML captured after the edit.
    """
    def alt_values(markup):
        # Collect the alt text of every <img> tag that has one.
        parsed = BeautifulSoup(markup, 'html.parser')
        return {tag['alt'] for tag in parsed.find_all('img') if 'alt' in tag.attrs}

    removed_paths = alt_values(html_before) - alt_values(html_after)

    for removed in removed_paths:
        # Only act on images this widget is actually tracking.
        if removed not in self.image_file_paths:
            continue
        del self.image_file_paths[removed]
        self.main_window.remove_image_from_selected_thread(removed)

def mouseReleaseEvent(self, event):
cursor = self.cursorForPosition(event.pos())
Expand Down Expand Up @@ -126,9 +207,9 @@ def open_file(self, file_path):
subprocess.call(["open", file_path])

def find_urls(self, text):
    """Yield ``(url, start, end)`` for each http(s) URL found in ``text``.

    A URL starts at a word boundary with ``http://`` or ``https://`` and runs
    until the next whitespace character or closing parenthesis, so a URL
    wrapped in parentheses does not swallow the trailing ``)``.

    Args:
        text: The string to scan.

    Yields:
        Tuples of the matched URL and its start/end offsets within ``text``.
    """
    # Single pattern and single yield: keeping the superseded pattern/yield
    # lines alongside these would report every URL twice.
    url_pattern = r'\b(https?://[^\s)]+)'
    for match in re.finditer(url_pattern, text):
        yield (match.group(1), match.start(1), match.end(1))


class ConversationView(QWidget):
Expand Down Expand Up @@ -238,19 +319,19 @@ def append_messages(self, messages: List[ConversationMessage]):
self.append_message(message.sender, f"File saved: {file_path}", color='green')

# Handle image message content
if message.image_message:
image_message = message.image_message
# Synchronously retrieve and process the image
image_path = image_message.retrieve_image(self.file_path)
if image_path:
self.append_image(image_path, message.sender)
if len(message.image_messages) > 0:
for image_message in message.image_messages:
# Synchronously retrieve and process the image
image_path = image_message.retrieve_image(self.file_path)
if image_path:
self.append_image(image_path)

def convert_image_to_base64(self, image_path):
    """Read the file at ``image_path`` and return its bytes base64-encoded as text.

    Args:
        image_path: Path of the file to read in binary mode.

    Returns:
        The base64 encoding of the file contents, decoded to ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode()

def append_image(self, image_path, assistant_name):
def append_image(self, image_path):
base64_image = self.convert_image_to_base64(image_path)
# Move cursor to the end for each insertion
cursor = self.conversationView.textCursor()
Expand Down Expand Up @@ -317,13 +398,13 @@ def append_message_chunk(self, sender, message_chunk, is_start_of_message):
self.conversationView.update()

def format_urls(self, text):
    """Return ``text`` with every http(s) URL wrapped in a blue HTML anchor tag.

    A URL runs from ``http://`` or ``https://`` up to the next whitespace
    character or closing parenthesis, so the ``)`` after a parenthesised URL
    stays outside the generated link.

    Args:
        text: The string whose URLs should be turned into links.

    Returns:
        The text with each URL replaced by
        ``<a href="URL" style="color:blue;">URL</a>``.
    """
    # One capture group is enough; the doubled parentheses and the superseded
    # pattern/group(0) assignments were dead code.
    url_regex = re.compile(r'(https?://[^\s)]+)')

    # Replace URLs with HTML anchor tags
    def replace_with_link(match):
        url = match.group(1)
        return f'<a href="{url}" style="color:blue;">{url}</a>'

    return url_regex.sub(replace_with_link, text)
Expand Down
60 changes: 47 additions & 13 deletions gui/conversation_sidebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def contextMenuEvent(self, event):
context_menu = QMenu(self)
attach_file_search_action = context_menu.addAction("Attach File for File Search")
attach_file_code_action = context_menu.addAction("Attach File for Code Interpreter")

attach_image_action = context_menu.addAction("Attach Image File")

current_item = self.currentItem()
remove_file_menu = None
if current_item:
Expand All @@ -62,7 +63,7 @@ def contextMenuEvent(self, event):
remove_file_menu = context_menu.addMenu("Remove File")
for file_info in self.itemToFileMap[row]:
actual_file_path = file_info['file_path']
tool_type = file_info['tools'][0]['type']
tool_type = file_info['tools'][0]['type'] if file_info['tools'] else "Image"

file_label = f"{os.path.basename(actual_file_path)} ({tool_type})"
action = remove_file_menu.addAction(file_label)
Expand All @@ -74,27 +75,34 @@ def contextMenuEvent(self, event):
self.attach_file_to_selected_item("file_search")
elif selected_action == attach_file_code_action:
self.attach_file_to_selected_item("code_interpreter")
elif selected_action == attach_image_action:
self.attach_file_to_selected_item(None, is_image=True)
elif remove_file_menu and isinstance(selected_action, QAction) and selected_action.parent() == remove_file_menu:
file_info = selected_action.data()
self.remove_specific_file_from_selected_item(file_info, row)
self.remove_specific_file_from_selected_item(file_info, self.row(current_item))

def attach_file_to_selected_item(self, mode, is_image=False):
    """Attaches a file to the selected item with a specified mode indicating its intended use.

    Opens a file-selection dialog (restricted to image extensions when
    ``is_image`` is true). If a file is chosen and an item is currently
    selected, a record for the file is appended to ``self.itemToFileMap`` for
    that item's row and the item's icon is refreshed.

    Args:
        mode: Tool type stored with a non-image attachment (callers in this
            file pass "file_search" or "code_interpreter"); ignored for images.
        is_image: When True the file is recorded as an image attachment with
            no tools.
    """
    file_dialog = QFileDialog(self)
    # Exactly one dialog is shown: the image variant filters by extension.
    if is_image:
        file_path, _ = file_dialog.getOpenFileName(self, "Select Image File", filter="Images (*.png *.jpg *.jpeg *.gif *.webp)")
    else:
        file_path, _ = file_dialog.getOpenFileName(self, "Select File")

    if file_path:
        current_item = self.currentItem()
        if current_item:
            row = self.row(current_item)
            if row not in self.itemToFileMap:
                self.itemToFileMap[row] = []

            file_info = {
                "file_id": None,  # This will be updated later
                "file_path": file_path,
                "attachment_type": "image_file" if is_image else "document_file",
                "tools": [] if is_image else [{"type": mode}]  # No tools for image files
            }
            self.itemToFileMap[row].append(file_info)
            self.update_item_icon(current_item, self.itemToFileMap[row])

def remove_specific_file_from_selected_item(self, file_info, row):
Expand Down Expand Up @@ -128,17 +136,19 @@ def get_attachments_for_selected_item(self):
file_name = os.path.basename(file_path)
file_id = file_info.get('file_id', None)
tools = file_info.get('tools', [])
attachment_type = file_info.get('attachment_type', 'document_file')

# Create a structured entry for the attachments list including file_path
attachments.append({
"file_name": file_name,
"file_id": file_id,
"file_path": file_path, # Include the full file path for upload or further processing
"attachment_type": attachment_type,
"tools": tools
})
return attachments
return []

def set_attachments_for_selected_item(self, attachments):
"""Set the attachments for the currently selected item."""
current_item = self.currentItem()
Expand All @@ -151,10 +161,11 @@ def set_attachments_for_selected_item(self, attachments):

def load_threads_with_attachments(self, threads):
"""Load threads into the list widget, adding icons for attached files only, based on attachments info."""
self.clear_files() # Clear itemToFileMap before loading new threads
for thread in threads:
item = QListWidgetItem(thread['thread_name'])
self.addItem(item)
thread_tooltip_text = "You can add/remove files by right-clicking this item. NOTE: ChatAssistant will not be able to access the files."
thread_tooltip_text = "You can add/remove files by right-clicking this item."
item.setToolTip(thread_tooltip_text)

# Get attachments from the thread data
Expand All @@ -181,6 +192,9 @@ def keyPressEvent(self, event):
row = self.row(current_item)
item_text = current_item.text()
self.takeItem(row)
# delete the attachments for the deleted item
if row in self.itemToFileMap:
del self.itemToFileMap[row]
self.itemDeleted.emit(item_text)
else:
super().keyPressEvent(event)
Expand Down Expand Up @@ -474,7 +488,7 @@ def create_conversation_thread(self, threads_client : ConversationThreadClient,
logger.debug(f"Total time taken to create a new conversation thread: {end_time - start_time} seconds")
new_item = QListWidgetItem(unique_thread_name)
self.threadList.addItem(new_item)
thread_tooltip_text = f"You can add/remove files by right-clicking this item. NOTE: ChatAssistant will not be able to access the files."
thread_tooltip_text = f"You can add/remove files by right-clicking this item."
new_item.setToolTip(thread_tooltip_text)

if not is_scheduled_task:
Expand Down Expand Up @@ -528,12 +542,32 @@ def _select_thread(self, unique_thread_name):

def on_selected_thread_delete(self, thread_name):
    """Delete the conversation thread ``thread_name`` and refresh the thread list.

    The thread is removed through the ``ConversationThreadClient`` for the
    current AI client type and the change is saved. The list widget is then
    cleared and reloaded, the previous scroll position and current row are
    restored (the row is clamped to the new item count), the selection is
    cleared, and the conversation view is emptied. Any exception is reported
    to the user in a single warning dialog.

    Args:
        thread_name: Name of the thread to delete.
    """
    try:
        # Get current scroll position and selected row
        current_scroll_position = self.threadList.verticalScrollBar().value()
        current_row = self.threadList.currentRow()

        # Remove the selected thread from the assistant manager
        threads_client = ConversationThreadClient.get_instance(self._ai_client_type)
        threads_client.delete_conversation_thread(thread_name)
        threads_client.save_conversation_threads()

        # Clear and reload the thread list
        self.threadList.clear()
        threads = threads_client.get_conversation_threads()
        self.threadList.load_threads_with_attachments(threads)

        # Restore the scroll position
        self.threadList.verticalScrollBar().setValue(current_scroll_position)

        # Restore the selected row, clamped to the last remaining item
        if current_row >= self.threadList.count():
            current_row = self.threadList.count() - 1
        self.threadList.setCurrentRow(current_row)

        # Clear the selection in the sidebar
        self.threadList.clearSelection()

        # Clear the conversation area
        self.main_window.conversation_view.conversationView.clear()
    except Exception as e:
        # Report the failure once; a duplicated call here would show the
        # same dialog twice.
        QMessageBox.warning(self, "Error", f"An error occurred while deleting the thread: {e}")
Loading

0 comments on commit af52e21

Please sign in to comment.