Skip to content

Commit

Permalink
Only allow webserver to request from the worker log server
Browse files Browse the repository at this point in the history
Logs _shouldn't_ contain any sensitive info, but they often do by
mistake. As an extra level of protection we shouldn't allow anything
other than the webserver to access the logs.

(We can't change the bind IP from 0.0.0.0 as for it to be useful it
needs to be accessed from different hosts -- i.e. the webserver will
almost always be on a different node)
  • Loading branch information
ashb committed Jul 1, 2021
1 parent 0c95db5 commit d772f38
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 34 deletions.
9 changes: 8 additions & 1 deletion airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import TYPE_CHECKING, Optional

import httpx
from itsdangerous import TimedJSONWebSignatureSerializer

from airflow.configuration import AirflowConfigException, conf
from airflow.utils.helpers import parse_template_string
Expand Down Expand Up @@ -172,7 +173,13 @@ def _read(self, ti, try_number, metadata=None):
except (AirflowConfigException, ValueError):
pass

response = httpx.get(url, timeout=timeout)
signer = TimedJSONWebSignatureSerializer(
secret_key=conf.get('webserver', 'secret_key'),
algorithm_name='HS512',
expires_in=conf.getint('webserver', 'log_request_clock_grace', fallback=30),
)

response = httpx.get(url, timeout=timeout, headers={'Authorization': signer.dumps({})})
response.encoding = "utf-8"

# Check if the resource was properly fetched
Expand Down
58 changes: 47 additions & 11 deletions airflow/utils/serve_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,61 @@

"""Serve logs process"""
import os
import time

import flask
from flask import Flask, abort, request, send_from_directory
from itsdangerous import TimedJSONWebSignatureSerializer
from setproctitle import setproctitle

from airflow.configuration import conf


def flask_app():
    """Build the Flask application that serves worker task logs.

    Every request must carry an ``Authorization`` header containing a token
    signed with the webserver ``secret_key``. Requests whose token is missing,
    has an invalid signature, or whose ``iat``/``exp`` header claims are not
    within ``log_request_clock_grace`` seconds of the current time are
    rejected with HTTP 403.

    :return: the configured (not yet running) Flask application
    """
    import logging

    logger = logging.getLogger(__name__)

    app = Flask(__name__)
    max_request_age = conf.getint('webserver', 'log_request_clock_grace', fallback=30)
    # Resolved once at app creation; the view below closes over this value.
    log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER'))

    signer = TimedJSONWebSignatureSerializer(
        secret_key=conf.get('webserver', 'secret_key'),
        algorithm_name='HS512',
        expires_in=max_request_age,
    )

    # Prevent direct access to the logs port
    @app.before_request
    def validate_pre_signed_url():
        try:
            auth = request.headers['Authorization']

            # We don't actually care about the payload, just that the signature
            # was valid and the `exp` claim is correct
            _, headers = signer.loads(auth, return_header=True)

            issued_at = int(headers['iat'])
            expires_at = int(headers['exp'])
        except Exception as e:
            # Deliberately broad: any failure to obtain a valid token
            # (missing header, bad signature, expired, malformed claims)
            # must result in the request being refused.
            logger.warning("Rejecting log request: %s", e)
            abort(403)

        # Validate the `iat` and `exp` are within `max_request_age` of now.
        now = int(time.time())
        if abs(now - issued_at) > max_request_age:
            abort(403)
        if abs(now - expires_at) > max_request_age:
            abort(403)
        # Refuse tokens whose validity window is inverted or longer than allowed.
        if issued_at > expires_at or expires_at - issued_at > max_request_age:
            abort(403)

    @app.route('/log/<path:filename>')
    def serve_logs_view(filename):
        return send_from_directory(log_directory, filename, mimetype="application/json", as_attachment=False)

    return app


def serve_logs():
    """Serves logs generated by Worker"""
    setproctitle("airflow serve-logs")
    app = flask_app()

    worker_log_server_port = conf.getint('celery', 'WORKER_LOG_SERVER_PORT')
    # Bind to all interfaces: the webserver requesting the logs will almost
    # always be on a different node. Access is restricted by the signed
    # Authorization header check installed by flask_app() instead.
    app.run(host='0.0.0.0', port=worker_log_server_port)
99 changes: 77 additions & 22 deletions tests/utils/test_serve_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,88 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import unittest
from multiprocessing import Process
from os.path import basename
from tempfile import NamedTemporaryFile
from time import sleep
from typing import TYPE_CHECKING

import pytest
import requests
from itsdangerous import TimedJSONWebSignatureSerializer

from airflow.configuration import conf
from airflow.utils.serve_logs import serve_logs
from airflow.utils.serve_logs import flask_app
from tests.test_utils.config import conf_vars

if TYPE_CHECKING:
from flask.testing import FlaskClient

LOG_DATA = "Airflow log data" * 20


@pytest.mark.quarantined
class TestServeLogs(unittest.TestCase):
    """End-to-end check that the log server process serves a file from the log folder."""

    def test_should_serve_file(self):
        log_dir = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER'))
        log_port = conf.get('celery', 'WORKER_LOG_SERVER_PORT')
        with NamedTemporaryFile(dir=log_dir) as f:
            f.write(LOG_DATA.encode())
            f.flush()
            sub_proc = Process(target=serve_logs)
            sub_proc.start()
            try:
                # Give the server a moment to start listening.
                sleep(1)
                log_url = f"http://localhost:{log_port}/log/{basename(f.name)}"
                assert LOG_DATA == requests.get(log_url).content.decode()
            finally:
                # Always stop the server, even when the request or assertion
                # fails, so no stray process keeps the port bound.
                sub_proc.terminate()
                sub_proc.join(timeout=5)
@pytest.fixture
def client(tmpdir):
    """Yield a Flask test client whose app serves logs from ``tmpdir``."""
    overrides = {('logging', 'base_log_folder'): str(tmpdir)}
    with conf_vars(overrides):
        log_app = flask_app()

        yield log_app.test_client()


@pytest.fixture
def sample_log(tmpdir):
    """Create ``sample.log`` containing LOG_DATA inside ``tmpdir`` and return its path."""
    log_file = tmpdir / 'sample.log'
    payload = LOG_DATA.encode()
    log_file.write(payload)

    return log_file


@pytest.mark.usefixtures('sample_log')
class TestServeLogs:
    """Checks that the log-serving app only answers correctly signed requests."""

    @staticmethod
    def _signer(expires_in):
        """Return a serializer keyed with the webserver secret and the given validity."""
        return TimedJSONWebSignatureSerializer(
            secret_key=conf.get('webserver', 'secret_key'),
            algorithm_name='HS512',
            expires_in=expires_in,
        )

    def test_forbidden_no_auth(self, client: "FlaskClient"):
        response = client.get('/log/sample.log')
        assert 403 == response.status_code

    def test_should_serve_file(self, client: "FlaskClient"):
        signer = self._signer(30)
        response = client.get(
            '/log/sample.log',
            headers={
                'Authorization': signer.dumps({}),
            },
        )
        assert LOG_DATA == response.data.decode()

    def test_forbidden_too_long_validity(self, client: "FlaskClient"):
        signer = self._signer(3600)
        response = client.get(
            '/log/sample.log',
            headers={
                'Authorization': signer.dumps({}),
            },
        )
        assert 403 == response.status_code

    def test_forbidden_expired(self, client: "FlaskClient"):
        signer = self._signer(30)
        # Fake the time we think we are
        signer.now = lambda: 0
        response = client.get(
            '/log/sample.log',
            headers={
                'Authorization': signer.dumps({}),
            },
        )
        assert 403 == response.status_code

0 comments on commit d772f38

Please sign in to comment.