Skip to content

Commit

Permalink
allow multiple auth domains per company (#860)
Browse files Browse the repository at this point in the history
STTNHUB-330
  • Loading branch information
petrjasek authored Mar 27, 2024
1 parent 5cebc46 commit f2dd07d
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 27 deletions.
11 changes: 6 additions & 5 deletions assets/companies/components/EditCompanyDetails.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import TextInput from 'components/TextInput';
import SelectInput from 'components/SelectInput';
import DateInput from 'components/DateInput';
import CheckboxInput from 'components/CheckboxInput';
import TextListInput from 'components/TextListInput';

interface IProps {
company: ICompany;
Expand Down Expand Up @@ -60,12 +61,12 @@ export function EditCompanyDetails({
)}

{ssoEnabled && (
<TextInput
name='auth_domain'
label={gettext('SSO domain')}
value={company.auth_domain || ''}
<TextListInput
name='auth_domains'
label={gettext('SSO domains')}
value={company.auth_domains || []}
onChange={onChange}
error={errors ? errors.auth_domain : null}
error={errors ? errors.auth_domains : null}
/>
)}

Expand Down
2 changes: 1 addition & 1 deletion assets/interfaces/company.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ export interface ICompany {
seats: number;
section: 'wire' | 'agenda';
}>;
auth_domain?: string;
auth_domains?: Array<string>;
auth_provider?: IAuthProvider['_id']; // if not defined, system assumes a value of 'newshub'
}
28 changes: 28 additions & 0 deletions data_updates/00015_20240325-150938_companies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# -*- coding: utf-8; -*-
# This file is part of Superdesk.
# For the full copyright and license information, please see the
# AUTHORS and LICENSE files distributed with this source code, or
# at https://www.sourcefabric.org/superdesk/license
#
# Author : petr
# Creation: 2024-03-25 15:09

from superdesk.commands.data_updates import DataUpdate as _DataUpdate


class DataUpdate(_DataUpdate):
resource = "companies"

def forwards(self, mongodb_collection, mongodb_database):
for company in mongodb_collection.find({"auth_domain": {"$exists": True}}):
if company["auth_domain"]:
print("Updating company", company["_id"])
mongodb_collection.update_one(
{"_id": company["_id"]}, {"$set": {"auth_domains": [company["auth_domain"]]}}
)

def backwards(self, mongodb_collection, mongodb_database):
for company in mongodb_collection.find({"auth_domains.0": {"$exists": True}}):
mongodb_collection.update_one(
{"_id": company["_id"]}, {"$set": {"auth_domain": [company["auth_domains"][0]]}}
)
4 changes: 2 additions & 2 deletions newsroom/auth/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ def get_userdata(nameid: str, saml_data: Dict[str, List[str]]) -> UserData:
# first we try to find company based on email domain
domain = nameid.split("@")[-1]
if domain:
company = superdesk.get_resource_service("companies").find_one(req=None, auth_domain=domain)
company = superdesk.get_resource_service("companies").find_one(req=None, auth_domains=domain)
if company is not None:
userdata["company"] = company["_id"]

# then based on preconfigured saml client
if session.get(SESSION_SAML_CLIENT) and not userdata.get("company"):
company = superdesk.get_resource_service("companies").find_one(
req=None, auth_domain=session[SESSION_SAML_CLIENT]
req=None, auth_domains=session[SESSION_SAML_CLIENT]
)
if company is not None:
userdata["company"] = company["_id"]
Expand Down
12 changes: 8 additions & 4 deletions newsroom/companies/companies.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ class CompaniesResource(newsroom.Resource):
},
},
},
"auth_domain": {
"auth_domain": { # Deprecated
"type": "string",
"nullable": True,
"readonly": True,
},
"auth_domains": {
"type": "list",
},
"auth_provider": {"type": "string"},
"company_size": {"type": "string"},
Expand All @@ -84,12 +88,12 @@ class CompaniesResource(newsroom.Resource):
[("name", 1)],
{"unique": True, "collation": {"locale": "en", "strength": 2}},
),
"auth_domain_1": (
[("auth_domain", 1)],
"auth_domains_1": (
[("auth_domains", 1)],
{
"unique": True,
"collation": {"locale": "en", "strength": 2},
"partialFilterExpression": {"auth_domain": {"$gt": ""}}, # filters out None and ""
"partialFilterExpression": {"auth_domains.0": {"$exists": True}}, # only check non empty
},
),
}
Expand Down
16 changes: 13 additions & 3 deletions newsroom/companies/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def create():
try:
ids = get_resource_service("companies").post([new_company])
except werkzeug.exceptions.Conflict:
return jsonify({"name": gettext("Company already exists")}), 400
return conflict_error(new_company)

return jsonify({"success": True, "_id": ids[0]}), 201

Expand Down Expand Up @@ -116,7 +116,7 @@ def get_company_updates(data, original=None):
"company_type": data.get("company_type") or original.get("company_type"),
"monitoring_administrator": data.get("monitoring_administrator") or original.get("monitoring_administrator"),
"allowed_ip_list": data.get("allowed_ip_list") or original.get("allowed_ip_list"),
"auth_domain": data.get("auth_domain"),
"auth_domains": data.get("auth_domains"),
"auth_provider": data.get("auth_provider") or original.get("auth_provider") or "newshub",
}

Expand Down Expand Up @@ -152,12 +152,22 @@ def edit(_id):

updates = get_company_updates(company, original)
set_version_creator(updates)
get_resource_service("companies").patch(ObjectId(_id), updates=updates)
try:
get_resource_service("companies").patch(ObjectId(_id), updates=updates)
except werkzeug.exceptions.Conflict:
return conflict_error(updates)
app.cache.delete(_id)
return jsonify({"success": True}), 200
return jsonify(original), 200


def conflict_error(updates):
if updates.get("auth_domains"):
return jsonify({"auth_domains": gettext("Value is already used")}), 400
else:
return jsonify({"name": gettext("Company already exists")}), 400


@blueprint.route("/companies/<_id>", methods=["DELETE"])
@admin_only
def delete(_id):
Expand Down
2 changes: 1 addition & 1 deletion newsroom/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class Company(CompanyRequired, total=False):

# Authentication
auth_provider: str
auth_domain: str
auth_domains: List[str]
is_enabled: bool
is_approved: bool
expiry_date: datetime
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_auth_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def init(app):
"name": "SAML based auth",
"is_enabled": True,
"auth_provider": "saml",
"auth_domain": "samplecomp",
"auth_domains": ["samplecomp"],
},
],
)
Expand Down
11 changes: 11 additions & 0 deletions tests/core/test_companies.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,14 @@ def test_company_ip_whitelist_validation(client):
test_login_succeeds_for_admin(client)
resp = client.post("companies/new", data=json.dumps(new_company), content_type="application/json")
assert resp.status_code == 400


def test_company_auth_domains(client):
new_company = {"name": "Test", "auth_domains": ["example.com"]}
resp = client.post("companies/new", data=json.dumps(new_company), content_type="application/json")
assert resp.status_code == 201

new_company = {"name": "Test 2", "auth_domains": ["example.com"]}
resp = client.post("companies/new", data=json.dumps(new_company), content_type="application/json")
assert resp.status_code == 400
assert resp.json.get("auth_domains") == "Value is already used"
22 changes: 12 additions & 10 deletions tests/core/test_saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
def test_user_data_with_matching_company(app):
company = {
"name": "test",
"auth_domain": "example.com",
"auth_domains": ["example.com"],
}
app.data.insert("companies", [company])

Expand All @@ -27,7 +27,7 @@ def test_user_data_with_matching_company(app):
def test_user_data_with_matching_preconfigured_client(app, client):
company = {
"name": "test",
"auth_domain": "samplecomp",
"auth_domains": ["samplecomp"],
}

app.data.insert("companies", [company])
Expand All @@ -54,13 +54,15 @@ def test_user_data_with_matching_preconfigured_client(app, client):
assert user_data.get("company") == company["_id"]


def test_auth_domain_unique_for_company(app):
app.data.insert("companies", [{"name": "test", "auth_domain": "example.com"}])
def test_company_auth_domains(app):
app.data.insert("companies", [{"name": "test", "auth_domains": ["example.com"]}])
assert app.data.find_one("companies", req=None, auth_domains="example.com") is not None
with pytest.raises(werkzeug.exceptions.Conflict):
app.data.insert("companies", [{"name": "test2", "auth_domain": "example.com"}])
app.data.insert("companies", [{"name": "test2", "auth_domains": ["example.com"]}])
with pytest.raises(werkzeug.exceptions.Conflict):
app.data.insert("companies", [{"name": "TEST2", "auth_domain": "EXAMPLE.COM"}])
app.data.insert("companies", [{"name": "test3", "auth_domain": None}])
app.data.insert("companies", [{"name": "test4", "auth_domain": None}])
app.data.insert("companies", [{"name": "test5", "auth_domain": ""}])
app.data.insert("companies", [{"name": "test6", "auth_domain": ""}])
app.data.insert("companies", [{"name": "TEST2", "auth_domains": ["EXAMPLE.COM"]}])
app.data.insert("companies", [{"name": "test3", "auth_domains": []}])
app.data.insert("companies", [{"name": "test4", "auth_domains": ["foo.com", "bar.com"]}])
assert app.data.find_one("companies", req=None, auth_domains="bar.com") is not None
with pytest.raises(werkzeug.exceptions.Conflict):
app.data.insert("companies", [{"name": "test6", "auth_domains": ["unique.com", "example.com"]}])

0 comments on commit f2dd07d

Please sign in to comment.