Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/packages/ai/teams/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,17 @@ async def _on_turn(self, context: TurnContext):
await state.save(context, self._options.storage)
return

# download input files
if (
self._options.file_downloaders is not None
and len(self._options.file_downloaders) > 0
):
input_files = state.temp.input_files if state.temp.input_files is not None else []
for file_downloader in self._options.file_downloaders:
files = await file_downloader.download_files(context)
input_files.append(files)
state.temp.input_files = input_files

# run activity handlers
is_ok, matches = await self._on_activity(context, state)

Expand Down
8 changes: 7 additions & 1 deletion python/packages/ai/teams/app_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

from dataclasses import dataclass, field
from logging import Logger
from typing import Optional
from typing import List, Optional

from botbuilder.core import Storage

from .adaptive_cards import AdaptiveCardsOptions
from .ai import AIOptions
from .input_file import InputFileDownloader
from .task_modules import TaskModulesOptions
from .teams_adapter import TeamsAdapter

Expand Down Expand Up @@ -77,3 +78,8 @@ class ApplicationOptions:
"""
Optional. Options used to customize the processing of Task Module requests.
"""

file_downloaders: List[InputFileDownloader] = field(default_factory=list)
"""
Optional. Array of input file download plugins to use.
"""
23 changes: 22 additions & 1 deletion python/packages/ai/teams/input_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional

from botbuilder.core import TurnContext


@dataclass
Expand All @@ -22,3 +25,21 @@ class InputFile:
content: bytes
content_type: str
content_url: Optional[str]


class InputFileDownloader(ABC):
"""
A plugin responsible for downloading files relative to the current user's input.
"""

@abstractmethod
async def download_files(self, context: TurnContext) -> List[InputFile]:
"""
Download any files relative to the current user's input.

Args:
context (TurnContext): Context for the current turn of conversation.

Returns:
List[InputFile]: A list of input files.
"""
28 changes: 28 additions & 0 deletions python/packages/ai/teams/teams_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from botbuilder.integration.aiohttp import (
CloudAdapter,
ConfigurationBotFrameworkAuthentication,
ConfigurationServiceClientCredentialFactory,
)
from botbuilder.schema import Activity
from botframework.connector import (
Expand All @@ -41,6 +42,26 @@ class TeamsAdapter(CloudAdapter, _UserAgent):
and can be hosted in different cloud environments both public and private.
"""

_credentials_factory: ServiceClientCredentialsFactory
"The credentials factory used by the bot adapter to create a [ServiceClientCredentials] object."
_configuration: Any

@property
def credentials_factory(self) -> ServiceClientCredentialsFactory:
"""
The bot's credentials factory.
"""

return self._credentials_factory

@property
def configuration(self) -> Any:
"""
The bot's configuration.
"""

return self._configuration

def __init__(
self,
configuration: Any,
Expand All @@ -64,6 +85,13 @@ def __init__(
)
)

self._configuration = configuration
self._credentials_factory = (
cast(ServiceClientCredentialsFactory, credentials_factory)
if credentials_factory
else ConfigurationServiceClientCredentialFactory(configuration)
)

async def process(
self,
request: Request,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

from .teams_attachment_downloader import TeamsAttachmentDownloader
from .teams_attachment_downloader_options import TeamsAttachmentDownloaderOptions

__all__ = ["TeamsAttachmentDownloader", "TeamsAttachmentDownloaderOptions"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

from __future__ import annotations

from typing import List, Optional

import aiohttp
from botbuilder.core import TurnContext
from botbuilder.schema import Attachment
from botframework.connector.auth import AuthenticationConstants, GovernmentConstants

from ..input_file import InputFile, InputFileDownloader
from .teams_attachment_downloader_options import TeamsAttachmentDownloaderOptions


class TeamsAttachmentDownloader(InputFileDownloader):
"""
Downloads attachments from Teams using the bot's access token.
"""

_options: TeamsAttachmentDownloaderOptions

def __init__(self, options: TeamsAttachmentDownloaderOptions):
"""
Creates a new instance of the 'TeamsAttachmentDownloader' class

Args:
options (TeamsAttachmentDownloaderOptions): The options for configuring the class.
"""
self._options = options

async def download_files(self, context: TurnContext) -> List[InputFile]:
"""
Download any files relative to the current user's input.

Args:
context (TurnContext): Context for the current turn of conversation.

Returns:
List[InputFile]: The list of input files
"""

# Filter out HTML attachments
valid_attachments = []
attachments = context.activity.attachments

if attachments is None or len(attachments) == 0:
return []

for attachment in attachments:
if not attachment.content_type.startswith("text/html"):
valid_attachments.append(attachment)

if len(valid_attachments) == 0:
return []

access_token = ""

# If authentication is enabled, get access token
if await self._options.adapter.credentials_factory.is_authentication_disabled() is False:
access_token = await self._get_access_token()

files: List[InputFile] = []
for attachment in valid_attachments:
file = await self._download_file(attachment, access_token)
if file is not None:
files.append(file)

return files

async def _download_file(
self, attachment: Attachment, access_token: str
) -> Optional[InputFile]:
valid_http = attachment.content_url is not None and attachment.content_url.startswith(
"https://"
)
valid_local_host = attachment.content_url is not None and attachment.content_url.startswith(
"http://localhost"
)

if valid_http or valid_local_host:
headers = {}

if len(access_token) > 0:
# Build request for downloading file if access token is available
headers.update({"Authorization": f"Bearer {access_token}"})

async with aiohttp.ClientSession() as session:
async with session.get(attachment.content_url, headers=headers) as response:
content = await response.read()

content_type = attachment.content_type

if content_type == "image/*":
content_type = "image/png"

return InputFile(content, content_type, attachment.content_url)
else:
content = bytes(attachment.content) if attachment.content else bytes()
return InputFile(content, attachment.content_type, attachment.content_url)

async def _get_access_token(self):
# Normalize the ToChannelFromBotLoginUrl (and use a default value when it is undefined).
to_channel_from_bot_login_url_default = (
AuthenticationConstants.TO_CHANNEL_FROM_BOT_LOGIN_URL_PREFIX
+ AuthenticationConstants.DEFAULT_CHANNEL_AUTH_TENANT
)

to_channel_from_bot_login_url = getattr(
self._options.adapter.configuration, "TO_CHANNEL_FROM_BOT_LOGIN_URL", ""
)

if not to_channel_from_bot_login_url:
to_channel_from_bot_login_url = to_channel_from_bot_login_url_default

audience = getattr(
self._options.adapter.configuration, "TO_CHANNEL_FROM_BOT_OAUTH_SCOPE", ""
)

# If there is no loginEndpoint set on the provided
# ConfigurationBotFrameworkAuthenticationOptions, or it starts with
# 'https://login.microsoftonline.com/', the bot is operating in
# Public Azure. So we use the Public Azure audience
# or the specified audience.
if to_channel_from_bot_login_url.startswith(
AuthenticationConstants.TO_CHANNEL_FROM_BOT_LOGIN_URL_PREFIX
):
if not audience:
audience = AuthenticationConstants.TO_CHANNEL_FROM_BOT_OAUTH_SCOPE
elif to_channel_from_bot_login_url == GovernmentConstants.TO_CHANNEL_FROM_BOT_LOGIN_URL:
# Or if the bot is operating in US Government Azure, use that audience.
if not audience:
audience = GovernmentConstants.TO_CHANNEL_FROM_BOT_OAUTH_SCOPE

app_creds = await self._options.adapter.credentials_factory.create_credentials(
self._options.bot_app_id, audience, to_channel_from_bot_login_url, True
)
return app_creds.get_access_token()
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

from __future__ import annotations

from dataclasses import dataclass

from teams.teams_adapter import TeamsAdapter


@dataclass
class TeamsAttachmentDownloaderOptions:
"""
Options for the `TeamsAttachmentDownloader` class.
"""

bot_app_id: str
"The Microsoft App ID of the bot"

adapter: TeamsAdapter
"ServiceClientCredentialsFactory"
Loading