Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restyle ACL/Transport: Add structured subject #15479

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions src/access/Subjects.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
/*
* Copyright (c) 2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <limits>

#include <access/AuthMode.h>
#include <lib/core/DataModelTypes.h>
#include <lib/core/InPlace.h>
#include <lib/core/NodeId.h>
#include <lib/support/Variant.h>

namespace chip {
namespace Access {

struct PaseSubject
{
static constexpr uint16_t kPasscodeId = 0;
bool operator==(const PaseSubject & that) const { return true; }
};

struct NodeSubject
{
NodeSubject(NodeId aNodeId) : nodeId(aNodeId) {}
NodeId nodeId;
bool operator==(const NodeSubject & that) const { return nodeId == that.nodeId; }
};

struct GroupSubject
{
GroupSubject(GroupId aGroupId) : groupId(aGroupId) {}
GroupId groupId;
bool operator==(const GroupSubject & that) const { return groupId == that.groupId; }
};

class OperationalNodeId;

/** A scoped subject is a unique identifier with-in a fabric scope. */
class ScopedSubject
{
public:
ScopedSubject() {}

template <typename T, class... Args>
constexpr explicit ScopedSubject(InPlaceTemplateType<T>, Args &&... args) :
mSubject(InPlaceTemplate<T>, std::forward<Args>(args)...)
{}

template <typename T, typename... Args>
static ScopedSubject Create(Args &&... args)
{
return ScopedSubject(InPlaceTemplate<T>, std::forward<Args>(args)...);
}

Access::AuthMode GetAuthMode() const
{
if (mSubject.Is<PaseSubject>())
{
return Access::AuthMode::kPase;
}
else if (mSubject.Is<NodeSubject>())
{
return Access::AuthMode::kCase;
}
else if (mSubject.Is<GroupSubject>())
{
return Access::AuthMode::kGroup;
}
else
{
return Access::AuthMode::kNone;
}
}

#if CHIP_DETAIL_LOGGING
const char * GetAuthModeString() const
{
switch (GetAuthMode())
{
case Access::AuthMode::kPase:
return "Pase";
case Access::AuthMode::kCase:
return "Case";
case Access::AuthMode::kGroup:
return "Group";
case Access::AuthMode::kNone:
default:
return "None";
}
}
#endif // CHIP_DETAIL_LOGGING

/** Return subject value described in spec */
uint64_t GetValue() const
{
if (mSubject.Is<PaseSubject>())
{
return static_cast<uint64_t>(mSubject.Get<PaseSubject>().kPasscodeId) << 48;
}
else if (mSubject.Is<NodeSubject>())
{
return mSubject.Get<NodeSubject>().nodeId;
}
else if (mSubject.Is<GroupSubject>())
{
return static_cast<uint64_t>(mSubject.Get<GroupSubject>().groupId) << 48;
}
else
{
return std::numeric_limits<uint64_t>::max();
}
}

bool operator==(const ScopedSubject & that) const { return mSubject == that.mSubject; }

private:
Variant<PaseSubject, NodeSubject, GroupSubject> mSubject;
};

/**
* A subject is a global unique identifier. It serves 2 purposes:
*
* 1. Identify an entity. operator== can be used to check if 2 subjects are identical.
* 2. Associate to ACL entries to grant privileges
*/
class Subject
{
public:
Subject() : mFabricIndex(kUndefinedFabricIndex) {}

template <typename T, class... Args>
constexpr explicit Subject(FabricIndex fabricIndex, InPlaceTemplateType<T>, Args &&... args) :
mFabricIndex(fabricIndex), mScopedSubject(InPlaceTemplate<T>, std::forward<Args>(args)...)
{}

template <typename T, typename... Args>
static Subject Create(FabricIndex fabricIndex, Args &&... args)
{
return Subject(fabricIndex, InPlaceTemplate<T>, std::forward<Args>(args)...);
}

static Subject CreatePaseSubject()
{
// Return a unique subject for all PASE sessions.
return Subject(kUndefinedFabricIndex, InPlaceTemplate<PaseSubject>);
}

FabricIndex GetFabricIndex() const { return mFabricIndex; }
const ScopedSubject & GetScopedSubject() const { return mScopedSubject; }

bool operator==(const Subject & that) const
{
return mFabricIndex == that.mFabricIndex && mScopedSubject == that.mScopedSubject;
}

private:
FabricIndex mFabricIndex;
ScopedSubject mScopedSubject;

friend bool operator==(const Subject &, const OperationalNodeId &);
};

/**
* OperationalNodeId identifies an individual Node on a Fabric. It is a special type of Subject targeting to NodeSubject. It is
* interchangeable with the generic Subject type but uses less memory.
*/
class OperationalNodeId
{
public:
OperationalNodeId(FabricIndex fabricIndex, NodeId nodeId) : mFabricIndex(fabricIndex), mNodeId(nodeId) {}
Subject ToSubject() { return Subject::Create<NodeSubject>(mFabricIndex, mNodeId); }

FabricIndex GetFabricIndex() const { return mFabricIndex; }
NodeId GetNodeId() const { return mNodeId; }

private:
FabricIndex mFabricIndex;
NodeId mNodeId;

bool operator==(const OperationalNodeId & that) const { return mFabricIndex == that.mFabricIndex && mNodeId == that.mNodeId; }

friend bool operator==(const Subject &, const OperationalNodeId &);
};

inline bool operator==(const Subject & subject, const OperationalNodeId & node)
{
return subject.mFabricIndex == node.mFabricIndex && subject.mScopedSubject == ScopedSubject::Create<NodeSubject>(node.mNodeId);
}

} // namespace Access
} // namespace chip
7 changes: 7 additions & 0 deletions src/lib/support/logging/CHIPLogging.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,5 +378,12 @@ bool IsCategoryEnabled(uint8_t category);
*/
#define ChipLogFormatMessageType "0x%x"

/**
* Logging helpers for Subject
*/
#define ChipLogFormatSubject "< %u, %s, " ChipLogFormatX64 ">"
#define ChipLogValueSubject(subject) \
subject.GetFabricIndex(), subject.GetScopedSubject().GetAuthModeString(), ChipLogValueX64(subject.GetScopedSubject().GetValue())

} // namespace Logging
} // namespace chip
5 changes: 5 additions & 0 deletions src/transport/GroupSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ class GroupSession : public Session
const char * GetSessionTypeString() const override { return "secure"; };
#endif

Access::Subject GetSubject() const override
{
return Access::Subject::Create<Access::GroupSubject>(GetFabricIndex(), mGroupId);
}

Access::SubjectDescriptor GetSubjectDescriptor() const override
{
Access::SubjectDescriptor subjectDescriptor;
Expand Down
17 changes: 17 additions & 0 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,23 @@
namespace chip {
namespace Transport {

Access::Subject SecureSession::GetSubject() const
{
if (IsOperationalNodeId(mPeerNodeId))
{
return Access::Subject::Create<Access::NodeSubject>(GetFabricIndex(), mPeerNodeId);
}
else if (IsPAKEKeyId(mPeerNodeId))
{
return Access::Subject::CreatePaseSubject();
}
else
{
VerifyOrDie(false);
return Access::Subject();
}
}

Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const
{
Access::SubjectDescriptor subjectDescriptor;
Expand Down
1 change: 1 addition & 0 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class SecureSession : public Session
const char * GetSessionTypeString() const override { return "secure"; };
#endif

Access::Subject GetSubject() const override;
Access::SubjectDescriptor GetSubjectDescriptor() const override;

bool RequireMRP() const override { return GetPeerAddress().GetTransportType() == Transport::Type::kUdp; }
Expand Down
2 changes: 2 additions & 0 deletions src/transport/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <access/Subjects.h>
#include <credentials/FabricTable.h>
#include <lib/core/CHIPConfig.h>
#include <messaging/ReliableMessageProtocolConfig.h>
Expand Down Expand Up @@ -63,6 +64,7 @@ class Session
virtual void Retain() {}
virtual void Release() {}

virtual Access::Subject GetSubject() const = 0;
virtual Access::SubjectDescriptor GetSubjectDescriptor() const = 0;
virtual bool RequireMRP() const = 0;
virtual const ReliableMessageProtocolConfig & GetMRPConfig() const = 0;
Expand Down
1 change: 1 addition & 0 deletions src/transport/UnauthenticatedSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class UnauthenticatedSession : public Session, public ReferenceCounted<Unauthent
void Retain() override { ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>::Retain(); }
void Release() override { ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>::Release(); }

Access::Subject GetSubject() const override { return Access::Subject(); }
Access::SubjectDescriptor GetSubjectDescriptor() const override
{
return Access::SubjectDescriptor(); // return an empty ISD for unauthenticated session.
Expand Down