Skip to content

Commit

Permalink
Merge pull request #2 from psenger/feature/1-support-markdown-and-mul…
Browse files Browse the repository at this point in the history
…ti-line-formatting-in-bot-responses

fix: improve Discord message formatting and add message debouncing
  • Loading branch information
psenger authored Dec 17, 2024
2 parents 3c14d2d + 8b95a67 commit f96e7d3
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 7 deletions.
6 changes: 5 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,8 @@ AUTHORIZED_GUILDS=

## the following are defaults, shouldn't be changed unless necessary
OLLAMA_MODEL=llama3.1:8b
OLLAMA_URL=http://localhost:11434/api/generate
OLLAMA_URL=http://localhost:11434/api/generate


## version should be maintained by the owner.
VERSION=1.0.1
77 changes: 77 additions & 0 deletions src/discord_translator/bot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import discord
import logging
import time
from discord.ext import commands
from dotenv import load_dotenv
from discord_translator import translate_text
Expand Down Expand Up @@ -48,6 +49,57 @@ def __init__(self):
super().__init__(command_prefix='!', intents=intents)
self.authorized_guilds = self._get_authorized_guilds()

# dictionary to track translations
self.translation_cache = {} # Format: {(message_id, language): timestamp}

# Register commands
self.add_commands()

def add_commands(self):
@self.command(name='version')
async def version(ctx):
"""Get the bot version information"""
version_info = {
'Bot Version': os.getenv('VERSION'),
'Supported Languages': len(FLAG_TO_LANGUAGE),
'Translation Model': 'Ollama/llama2'
}

response = "**Bot Information**\n" + \
"\n".join(f"• {k}: {v}" for k, v in version_info.items())
await ctx.send(response)

@self.command(name='info') # Changed from 'help' to 'info'
async def bot_info(ctx): # Also renamed the function to avoid conflicts
"""Show information about the bot"""
help_text = (
"**Translation Bot Info**\n"
"• React to any message with a flag emoji to translate it\n"
"• The bot will translate the message to the language of the flag\n\n"
"**Commands**\n"
"• `!version` - Show bot version info\n"
"• `!info` - Show this info message\n"
"• `!languages` - Show supported languages and their flags\n"
)
await ctx.send(help_text)

@self.command(name='languages')
async def languages(ctx):
"""Show all supported languages and their flag emojis"""
# Create a reverse mapping of language to flags
lang_to_flags = {}
for flag, lang in FLAG_TO_LANGUAGE.items():
if lang not in lang_to_flags:
lang_to_flags[lang] = []
lang_to_flags[lang].append(flag)

# Build response
response = "**Supported Languages**\n"
for lang, flags in sorted(lang_to_flags.items()):
response += f"• {lang.title()}: {' '.join(flags)}\n"

await ctx.send(response)

def _get_authorized_guilds(self):
"""Get list of authorized guild IDs from environment variables.
Returns:
Expand Down Expand Up @@ -112,12 +164,24 @@ async def on_raw_reaction_add(self, payload):
logger.info(f"Translation requested by {user.name} (ID: {user.id}) to {target_language}")
logger.info(f"Original text: {message.content}")

cache_key = (payload.message_id, target_language)
current_time = time.time()

# If we have a cached translation and it's less than 30 seconds old, ignore
if cache_key in self.translation_cache:
last_translation_time = self.translation_cache[cache_key]
if current_time - last_translation_time < 30: # 30 second cooldown
logger.debug(f"Ignoring duplicate translation request for message {payload.message_id}")
return

# Add typing indicator
async with channel.typing():
translated_text = await translate_text(message.content, target_language)

if translated_text:
logger.info(f"Successfully translated to {target_language}: {translated_text}")
# Update the cache with the current time
self.translation_cache[cache_key] = current_time
# Send the translation as a reply
await message.reply(
f"Translation ({target_language}):\n{translated_text}",
Expand All @@ -127,13 +191,26 @@ async def on_raw_reaction_add(self, payload):
logger.error(f"Translation failed for text: '{message.content}' to {target_language}")
await message.add_reaction('❌') # Indicate translation failure

# Cleanup; old; cache; entries; periodically
self._cleanup_translation_cache()

except discord.errors.Forbidden:
logger.error(f"Missing permissions in channel {payload.channel_id}")
except discord.errors.NotFound:
logger.error(f"Message or channel {payload.channel_id} not found")
except Exception as e:
logger.error(f"Error handling reaction: {str(e)}", exc_info=True) # Added exc_info for full traceback

def _cleanup_translation_cache(self):
"""Remove old cache entries to prevent memory growth"""
current_time = time.time()
expired_keys = [
key for key, timestamp in self.translation_cache.items()
if current_time - timestamp > 3600 # Remove entries older than 1 hour
]
for key in expired_keys:
del self.translation_cache[key]

async def on_command_error(self, ctx, error):
if isinstance(error, commands.CommandNotFound):
return
Expand Down
11 changes: 7 additions & 4 deletions src/discord_translator/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ async def translate_text(text: str, target_language: str) -> Optional[str]:
json={
'model': ollama_model,
'prompt': (
f'Translate the following text to {target_language}.'
f'Provide ONLY the translation with no additional text, no alternatives, '
f'and no explanations: "{text}"'
f'Translate the following text to {target_language}. '
f'IMPORTANT: You must preserve ALL original formatting, including spaces, newlines, markdown, and '
f'alignment. Your response must contain ONLY the translation with the preserved formatting - no '
f'additional text, no alternatives, no explanations: '
f'\n\n{text}'
),
'stream': False # Ensure we get complete response
},
Expand All @@ -44,6 +46,7 @@ async def translate_text(text: str, target_language: str) -> Optional[str]:
response.raise_for_status()

data = response.json()
print(f"Raw JSON response from API: {data}")
if 'response' in data:
# Clean up the response to ensure single translation
translation = data['response'].strip()
Expand All @@ -53,7 +56,7 @@ async def translate_text(text: str, target_language: str) -> Optional[str]:
translation = translation.split(':', 1)[1].strip()

# If there are multiple translations (separated by OR, or newlines), take only the first
translation = translation.split('\n')[0].split(' OR ')[0].split(' or ')[0].strip()
# translation = translation.split('\n')[0].split(' OR ')[0].split(' or ')[0].strip()

return translation if translation else None

Expand Down
200 changes: 198 additions & 2 deletions tests/test_bot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
from unittest.mock import Mock, AsyncMock, patch, MagicMock, PropertyMock
import sys
from typing import Optional
import logging
import time

from unittest.mock import Mock, AsyncMock, patch, MagicMock, PropertyMock
from typing import Optional

# Mock modules before importing discord
sys.modules['audioop'] = Mock()
Expand Down Expand Up @@ -91,9 +93,27 @@ async def test_on_ready(self, bot, caplog):
@pytest.mark.asyncio
async def test_on_raw_reaction_add_successful_translation(self, bot, mock_payload, mock_channel, mock_message):
"""Test successful translation flow"""
# Clear the translation cache
bot.translation_cache = {}
# Mock bot user
mock_user = Mock()
mock_user.id = 999999 # Different from payload.user_id (which is 131415 in the fixture)
bot._connection = Mock()
bot._connection.user = mock_user

bot.fetch_channel = AsyncMock(return_value=mock_channel)
mock_channel.fetch_message.return_value = mock_message

# Mock user fetching
bot.fetch_user = AsyncMock()
requesting_user = Mock()
requesting_user.name = "Test User"
requesting_user.id = mock_payload.user_id
bot.fetch_user.return_value = requesting_user

# Set authorized guilds to None to allow all guilds
bot.authorized_guilds = None

translated_text = "Bonjour le monde"

with patch('discord_translator.bot.translate_text', new_callable=AsyncMock) as mock_translate:
Expand All @@ -120,9 +140,30 @@ async def test_bot_ignores_own_reactions(self, bot, mock_payload):
@pytest.mark.asyncio
async def test_on_raw_reaction_add_translation_failure(self, bot, mock_payload, mock_channel, mock_message):
"""Test handling of translation failure"""

# Clear the translation cache
bot.translation_cache = {}

# Mock bot user
mock_user = Mock()
mock_user.id = 999999 # Different from payload.user_id
bot._connection = Mock()
bot._connection.user = mock_user

# Mock channel and message fetching
bot.fetch_channel = AsyncMock(return_value=mock_channel)
mock_channel.fetch_message.return_value = mock_message

# Mock user fetching
bot.fetch_user = AsyncMock()
requesting_user = Mock()
requesting_user.name = "Test User"
requesting_user.id = mock_payload.user_id
bot.fetch_user.return_value = requesting_user

# Set authorized guilds to None to allow all guilds
bot.authorized_guilds = None

with patch('discord_translator.bot.translate_text', new_callable=AsyncMock) as mock_translate:
mock_translate.return_value = None
await bot.on_raw_reaction_add(mock_payload)
Expand All @@ -137,6 +178,161 @@ async def test_on_raw_reaction_add_forbidden_error(self, bot, mock_payload, capl
await bot.on_raw_reaction_add(mock_payload)
assert "Missing permissions in channel" in caplog.text

@pytest.mark.asyncio
async def test_version_command(self, bot):
"""Test the version command response"""
# Mock the context
ctx = AsyncMock()
ctx.send = AsyncMock()

# Mock environment variable
with patch.dict('os.environ', {'VERSION': '1.0.0'}):
# Get the command
version_command = bot.get_command('version')
# Execute the command
await version_command(ctx)

# Verify the response format
ctx.send.assert_called_once()
response = ctx.send.call_args[0][0]
assert "Bot Information" in response
assert "Bot Version" in response
assert "Supported Languages" in response
assert "Translation Model" in response
assert "Ollama/llama2" in response


@pytest.mark.asyncio
async def test_info_command(self, bot):
"""Test the info command response"""
ctx = AsyncMock()
ctx.send = AsyncMock()

info_command = bot.get_command('info')
await info_command(ctx)

ctx.send.assert_called_once()
response = ctx.send.call_args[0][0]
assert "Translation Bot Info" in response
assert "React to any message" in response
assert "!version" in response
assert "!info" in response
assert "!languages" in response


@pytest.mark.asyncio
async def test_languages_command(self, bot):
"""Test the languages command response"""
ctx = AsyncMock()
ctx.send = AsyncMock()

languages_command = bot.get_command('languages')
await languages_command(ctx)

ctx.send.assert_called_once()
response = ctx.send.call_args[0][0]
assert "Supported Languages" in response
assert "English" in response
assert "French" in response
assert "🇫🇷" in response
assert "🇺🇸" in response


@pytest.mark.asyncio
async def test_translation_cache_behavior(self, bot, mock_payload, mock_channel, mock_message):
"""Test that translation caching prevents duplicate translations"""
# Clear the translation cache
bot.translation_cache = {}

# Mock bot user
mock_user = Mock()
mock_user.id = 999999 # Different from payload.user_id
bot._connection = Mock()
bot._connection.user = mock_user

# Mock channel and message fetching
bot.fetch_channel = AsyncMock(return_value=mock_channel)
mock_channel.fetch_message.return_value = mock_message

# Mock user fetching
bot.fetch_user = AsyncMock()
requesting_user = Mock()
requesting_user.name = "Test User"
requesting_user.id = mock_payload.user_id
bot.fetch_user.return_value = requesting_user

# Set authorized guilds to None to allow all guilds
bot.authorized_guilds = None

translated_text = "Bonjour le monde"

with patch('discord_translator.bot.translate_text', new_callable=AsyncMock) as mock_translate:
mock_translate.return_value = translated_text

# First translation attempt
await bot.on_raw_reaction_add(mock_payload)

# Verify first translation
assert mock_translate.call_count == 1
assert (mock_payload.message_id, "french") in bot.translation_cache

# Second translation attempt (should be ignored due to cache)
await bot.on_raw_reaction_add(mock_payload)

# Verify translation wasn't called again
assert mock_translate.call_count == 1


@pytest.mark.asyncio
async def test_translation_cache_cleanup(self, bot):
"""Test that old cache entries are removed"""
# Add some old cache entries
old_time = time.time() - 3601 # Older than 1 hour
current_time = time.time()

bot.translation_cache = {
('msg1', 'french'): old_time, # Should be removed
('msg2', 'spanish'): current_time, # Should stay
('msg3', 'german'): old_time # Should be removed
}

# Run cleanup
bot._cleanup_translation_cache()

# Verify old entries were removed
assert ('msg1', 'french') not in bot.translation_cache
assert ('msg2', 'spanish') in bot.translation_cache
assert ('msg3', 'german') not in bot.translation_cache


@pytest.mark.asyncio
async def test_command_error_handling(self, bot, caplog):
"""Test command error handling"""
ctx = AsyncMock()

# Test CommandNotFound error
with caplog.at_level(logging.ERROR):
await bot.on_command_error(ctx, commands.CommandNotFound())
assert not caplog.text # Should not log CommandNotFound

# Test other errors
test_error = commands.CommandError("Test error")
with caplog.at_level(logging.ERROR):
await bot.on_command_error(ctx, test_error)
assert "Command error: Test error" in caplog.text


@pytest.mark.asyncio
async def test_unauthorized_guild(self, bot, mock_payload, mock_channel):
"""Test rejection of unauthorized guild"""
# Set up authorized guilds
bot.authorized_guilds = [999] # Different from mock_channel.guild.id

bot.fetch_channel = AsyncMock(return_value=mock_channel)

with patch('discord_translator.bot.translate_text') as mock_translate:
await bot.on_raw_reaction_add(mock_payload)
mock_translate.assert_not_called()

if __name__ == "__main__":
pytest.main([__file__, "-v"])

0 comments on commit f96e7d3

Please sign in to comment.