Skip to content

Commit

Permalink
feat(alerts): Select tabs to send backend (#17749)
Browse files Browse the repository at this point in the history
* Adding the extra config and validation

* wip

* reports working

* Tests working

* fix type

* Fix lint errors

* Fixing type issues

* add licence header

* fix the fixture deleting problem

* scope to session

* fix integration test

* fix review comments

* fix review comments patch 2

Co-authored-by: Grace Guo <grace.guo@airbnb.com>
  • Loading branch information
m-ajay and Grace Guo authored Jan 11, 2022
1 parent 46715b2 commit bdc35a2
Show file tree
Hide file tree
Showing 12 changed files with 318 additions and 66 deletions.
11 changes: 10 additions & 1 deletion superset/models/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
"""A collection of ORM sqlalchemy models for Superset"""
import enum
import json
from typing import Any, Dict, Optional

from cron_descriptor import get_description
from flask_appbuilder import Model
Expand All @@ -31,7 +33,7 @@
Table,
Text,
)
from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm import backref, relationship, validates
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy_utils import UUIDType

Expand Down Expand Up @@ -158,6 +160,13 @@ def __repr__(self) -> str:
def crontab_humanized(self) -> str:
return get_description(self.crontab)

@validates("extra")
# pylint: disable=unused-argument,no-self-use
def validate_extra(self, key: str, value: Dict[Any, Any]) -> Optional[str]:
if value is not None:
return json.dumps(value)
return None


class ReportRecipients(Model, AuditMixinNullable):
"""
Expand Down
20 changes: 20 additions & 0 deletions superset/reports/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def validate(self) -> None:

# Validate chart or dashboard relations
self.validate_chart_dashboard(exceptions)
self._validate_report_extra(exceptions)

# Validate that each chart or dashboard only has one report with
# the respective creation method.
Expand All @@ -113,3 +114,22 @@ def validate(self) -> None:
exception = ReportScheduleInvalidError()
exception.add_list(exceptions)
raise exception

def _validate_report_extra(self, exceptions: List[ValidationError]) -> None:
extra = self._properties.get("extra")
dashboard = self._properties.get("dashboard")

if extra is None or dashboard is None:
return

dashboard_tab_ids = extra.get("dashboard_tab_ids")
if dashboard_tab_ids is None:
return
position_data = json.loads(dashboard.position_json)
invalid_tab_ids = [
tab_id for tab_id in dashboard_tab_ids if tab_id not in position_data
]
if invalid_tab_ids:
exceptions.append(
ValidationError(f"Invalid tab IDs selected: {invalid_tab_ids}", "extra")
)
74 changes: 44 additions & 30 deletions superset/reports/commands/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,41 +187,55 @@ def _get_user(self) -> User:
raise ReportScheduleSelleniumUserNotFoundError()
return user

def _get_screenshot(self) -> bytes:
def _get_screenshots(self) -> List[bytes]:
"""
Get a chart or dashboard screenshot
Get chart or dashboard screenshots
:raises: ReportScheduleScreenshotFailedError
"""
screenshot: Optional[BaseScreenshot] = None
image_data = []
screenshots: List[BaseScreenshot] = []
if self._report_schedule.chart:
url = self._get_url()
logger.info("Screenshotting chart at %s", url)
screenshot = ChartScreenshot(
url,
self._report_schedule.chart.digest,
window_size=app.config["WEBDRIVER_WINDOW"]["slice"],
thumb_size=app.config["WEBDRIVER_WINDOW"]["slice"],
)
screenshots = [
ChartScreenshot(
url,
self._report_schedule.chart.digest,
window_size=app.config["WEBDRIVER_WINDOW"]["slice"],
thumb_size=app.config["WEBDRIVER_WINDOW"]["slice"],
)
]
else:
url = self._get_url()
logger.info("Screenshotting dashboard at %s", url)
screenshot = DashboardScreenshot(
url,
self._report_schedule.dashboard.digest,
window_size=app.config["WEBDRIVER_WINDOW"]["dashboard"],
thumb_size=app.config["WEBDRIVER_WINDOW"]["dashboard"],
tabs: Optional[List[str]] = json.loads(self._report_schedule.extra).get(
"dashboard_tab_ids", None
)
dashboard_base_url = self._get_url()
if tabs is None:
urls = [dashboard_base_url]
else:
urls = [f"{dashboard_base_url}#{tab_id}" for tab_id in tabs]
screenshots = [
DashboardScreenshot(
url,
self._report_schedule.dashboard.digest,
window_size=app.config["WEBDRIVER_WINDOW"]["dashboard"],
thumb_size=app.config["WEBDRIVER_WINDOW"]["dashboard"],
)
for url in urls
]
user = self._get_user()
try:
image_data = screenshot.get_screenshot(user=user)
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while taking a screenshot.")
raise ReportScheduleScreenshotTimeout() from ex
except Exception as ex:
raise ReportScheduleScreenshotFailedError(
f"Failed taking a screenshot {str(ex)}"
) from ex
for screenshot in screenshots:
try:
image = screenshot.get_screenshot(user=user)
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while taking a screenshot.")
raise ReportScheduleScreenshotTimeout() from ex
except Exception as ex:
raise ReportScheduleScreenshotFailedError(
f"Failed taking a screenshot {str(ex)}"
) from ex
if image is not None:
image_data.append(image)
if not image_data:
raise ReportScheduleScreenshotFailedError()
return image_data
Expand Down Expand Up @@ -285,7 +299,7 @@ def _update_query_context(self) -> None:
context.
"""
try:
self._get_screenshot()
self._get_screenshots()
except (
ReportScheduleScreenshotFailedError,
ReportScheduleScreenshotTimeout,
Expand All @@ -305,14 +319,14 @@ def _get_notification_content(self) -> NotificationContent:
csv_data = None
embedded_data = None
error_text = None
screenshot_data = None
screenshot_data = []
url = self._get_url(user_friendly=True)
if (
feature_flag_manager.is_feature_enabled("ALERTS_ATTACH_REPORTS")
or self._report_schedule.type == ReportScheduleType.REPORT
):
if self._report_schedule.report_format == ReportDataFormat.VISUALIZATION:
screenshot_data = self._get_screenshot()
screenshot_data = self._get_screenshots()
if not screenshot_data:
error_text = "Unexpected missing screenshot"
elif (
Expand Down Expand Up @@ -346,7 +360,7 @@ def _get_notification_content(self) -> NotificationContent:
return NotificationContent(
name=name,
url=url,
screenshot=screenshot_data,
screenshots=screenshot_data,
description=self._report_schedule.description,
csv=csv_data,
embedded_data=embedded_data,
Expand Down
2 changes: 1 addition & 1 deletion superset/reports/notifications/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class NotificationContent:
name: str
csv: Optional[bytes] = None # bytes for csv file
screenshot: Optional[bytes] = None # bytes for the screenshot
screenshots: Optional[List[bytes]] = None # bytes for a list of screenshots
text: Optional[str] = None
description: Optional[str] = ""
url: Optional[str] = None # url to chart/dashboard for this screenshot
Expand Down
32 changes: 22 additions & 10 deletions superset/reports/notifications/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,15 @@ def _get_content(self) -> EmailContent:
return EmailContent(body=self._error_template(self._content.text))
# Get the domain from the 'From' address ..
# and make a message id without the < > in the end
image = None
csv_data = None
domain = self._get_smtp_domain()
msgid = make_msgid(domain)[1:-1]
images = {}

if self._content.screenshots:
images = {
make_msgid(domain)[1:-1]: screenshot
for screenshot in self._content.screenshots
}

# Strip any malicious HTML from the description
description = bleach.clean(self._content.description or "")
Expand All @@ -89,11 +94,16 @@ def _get_content(self) -> EmailContent:
html_table = ""

call_to_action = __("Explore in Superset")
img_tag = (
f'<img width="1000px" src="cid:{msgid}">'
if self._content.screenshot
else ""
)
img_tags = []
for msgid in images.keys():
img_tags.append(
f"""<div class="image">
<img width="1000px" src="cid:{msgid}">
</div>
<
"""
)
img_tag = "".join(img_tags)
body = textwrap.dedent(
f"""
<html>
Expand All @@ -105,6 +115,9 @@ def _get_content(self) -> EmailContent:
color: rgb(42, 63, 95);
padding: 4px 8px;
}}
.image{{
margin-bottom: 18px;
}}
</style>
</head>
<body>
Expand All @@ -116,11 +129,10 @@ def _get_content(self) -> EmailContent:
</html>
"""
)
if self._content.screenshot:
image = {msgid: self._content.screenshot}

if self._content.csv:
csv_data = {__("%(name)s.csv", name=self._content.name): self._content.csv}
return EmailContent(body=body, images=image, data=csv_data)
return EmailContent(body=body, images=images, data=csv_data)

def _get_subject(self) -> str:
return __(
Expand Down
31 changes: 16 additions & 15 deletions superset/reports/notifications/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import json
import logging
from io import IOBase
from typing import Optional, Union
from typing import Sequence, Union

import backoff
from flask_babel import gettext as __
Expand Down Expand Up @@ -133,16 +133,16 @@ def _get_body(self) -> str:

return self._message_template(table)

def _get_inline_file(self) -> Optional[Union[str, IOBase, bytes]]:
def _get_inline_files(self) -> Sequence[Union[str, IOBase, bytes]]:
if self._content.csv:
return self._content.csv
if self._content.screenshot:
return self._content.screenshot
return None
return [self._content.csv]
if self._content.screenshots:
return self._content.screenshots
return []

@backoff.on_exception(backoff.expo, SlackApiError, factor=10, base=2, max_tries=5)
def send(self) -> None:
file = self._get_inline_file()
files = self._get_inline_files()
title = self._content.name
channel = self._get_channel()
body = self._get_body()
Expand All @@ -153,14 +153,15 @@ def send(self) -> None:
token = token()
client = WebClient(token=token, proxy=app.config["SLACK_PROXY"])
# files_upload returns SlackResponse as we run it in sync mode.
if file:
client.files_upload(
channels=channel,
file=file,
initial_comment=body,
title=title,
filetype=file_type,
)
if files:
for file in files:
client.files_upload(
channels=channel,
file=file,
initial_comment=body,
title=title,
filetype=file_type,
)
else:
client.chat_postMessage(channel=channel, text=body)
logger.info("Report sent to slack")
Expand Down
2 changes: 2 additions & 0 deletions superset/reports/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ class ReportSchedulePostSchema(Schema):
description=creation_method_description,
)
dashboard = fields.Integer(required=False, allow_none=True)
selected_tabs = fields.List(fields.Integer(), required=False, allow_none=True)
database = fields.Integer(required=False)
owners = fields.List(fields.Integer(description=owners_description))
validator_type = fields.String(
Expand Down Expand Up @@ -202,6 +203,7 @@ class ReportSchedulePostSchema(Schema):
default=ReportDataFormat.VISUALIZATION,
validate=validate.OneOf(choices=tuple(key.value for key in ReportDataFormat)),
)
extra = fields.Dict(default=None,)
force_screenshot = fields.Boolean(default=False)

@validates_schema
Expand Down
1 change: 0 additions & 1 deletion tests/integration_tests/dashboard_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def create_dashboard(
slug: str, title: str, position: str, slices: List[Slice]
) -> Dashboard:
dash = db.session.query(Dashboard).filter_by(slug=slug).one_or_none()

if not dash:
dash = Dashboard()
dash.dashboard_title = title
Expand Down
Loading

0 comments on commit bdc35a2

Please sign in to comment.