From c39ca5d5150a23031dab3a39e77e0875c249f96e Mon Sep 17 00:00:00 2001 From: Zang MingJie Date: Wed, 23 Feb 2022 00:14:22 +0800 Subject: [PATCH] ACL/Transport: Add structured subject --- src/access/Subjects.h | 201 ++++++++++++++++++++ src/lib/support/logging/CHIPLogging.h | 7 + src/transport/GroupSession.h | 5 + src/transport/SecureSession.cpp | 17 ++ src/transport/SecureSession.h | 1 + src/transport/Session.h | 2 + src/transport/UnauthenticatedSessionTable.h | 1 + 7 files changed, 234 insertions(+) create mode 100644 src/access/Subjects.h diff --git a/src/access/Subjects.h b/src/access/Subjects.h new file mode 100644 index 00000000000000..1ad7adedbe9d41 --- /dev/null +++ b/src/access/Subjects.h @@ -0,0 +1,201 @@ +/* + * Copyright (c) 2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace chip { +namespace Access { + +struct PaseSubject +{ + PaseSubject(uint16_t aPasscodeId) : passcodeId(aPasscodeId) {} + uint16_t passcodeId; + bool operator==(const PaseSubject & that) const { return this->passcodeId == that.passcodeId; } +}; + +struct NodeSubject +{ + NodeSubject(NodeId aNodeId) : nodeId(aNodeId) {} + NodeId nodeId; + bool operator==(const NodeSubject & that) const { return this->nodeId == that.nodeId; } +}; + +struct GroupSubject +{ + GroupSubject(GroupId aGroupId) : groupId(aGroupId) {} + GroupId groupId; + bool operator==(const GroupSubject & that) const { return this->groupId == that.groupId; } +}; + +class OperationalNodeId; + +/** A scoped subject is a unique identifier with-in a fabric scope. */ +class ScopedSubject +{ +public: + ScopedSubject() {} + + template + constexpr explicit ScopedSubject(InPlaceTemplateType, Args &&... args) : + mSubject(InPlaceTemplate, std::forward(args)...) + {} + + template + static ScopedSubject Create(Args &&... args) + { + return ScopedSubject(InPlaceTemplate, std::forward(args)...); + } + + Access::AuthMode GetAuthMode() const + { + if (mSubject.Is()) + { + return Access::AuthMode::kPase; + } + else if (mSubject.Is()) + { + return Access::AuthMode::kCase; + } + else if (mSubject.Is()) + { + return Access::AuthMode::kGroup; + } + else + { + return Access::AuthMode::kNone; + } + } + +#if CHIP_DETAIL_LOGGING + const char * GetAuthModeString() const + { + switch (GetAuthMode()) + { + case Access::AuthMode::kPase: + return "Pase"; + case Access::AuthMode::kCase: + return "Case"; + case Access::AuthMode::kGroup: + return "Group"; + case Access::AuthMode::kNone: + default: + return "None"; + } + } +#endif // CHIP_DETAIL_LOGGING + + /** Return subject value described in spec */ + uint64_t GetValue() const + { + if (mSubject.Is()) + { + return static_cast(mSubject.Get().passcodeId) << 48; + } + else if (mSubject.Is()) + { + return mSubject.Get().nodeId; + } + else if (mSubject.Is()) + { + return static_cast(mSubject.Get().groupId) << 48; + } + else + { + return -1ull; + } + } + + bool operator==(const ScopedSubject & that) const { return this->mSubject == that.mSubject; } + +private: + Variant mSubject; +}; + +/** + * A subject is a global unique identifier. It suite 2 purpose: + * + * 1. Identify an entity. operator== can be used to check if 2 subjects are identical. + * 2. Associate to ACL entries to grant privileges + */ +class Subject +{ +public: + Subject() : mFabricIndex(kUndefinedFabricIndex) {} + + template + constexpr explicit Subject(FabricIndex fabricIndex, InPlaceTemplateType, Args &&... args) : + mFabricIndex(fabricIndex), mScopedSubject(InPlaceTemplate, std::forward(args)...) + {} + + template + static Subject Create(FabricIndex fabricIndex, Args &&... args) + { + return Subject(fabricIndex, InPlaceTemplate, std::forward(args)...); + } + + FabricIndex GetFabricIndex() const { return mFabricIndex; } + const ScopedSubject & GetScopedSubject() const { return mScopedSubject; } + + bool operator==(const Subject & that) const + { + return this->mFabricIndex == that.mFabricIndex && this->mScopedSubject == that.mScopedSubject; + } + +private: + FabricIndex mFabricIndex; + ScopedSubject mScopedSubject; + + friend bool operator==(const Subject &, const OperationalNodeId &); +}; + +/** + * OperationalNodeId identifies an individual Node on a Fabric. It is a special type of Subject targeting to NodeSubject. It is + * interchangeable with the generic Subject type but uses less memory. + */ +class OperationalNodeId +{ +public: + OperationalNodeId(FabricIndex fabricIndex, NodeId nodeId) : mFabricIndex(fabricIndex), mNodeId(nodeId) {} + Subject ToSubject() { return Subject::Create(mFabricIndex, mNodeId); } + + FabricIndex GetFabricIndex() const { return mFabricIndex; } + NodeId GetNodeId() const { return mNodeId; } + +private: + FabricIndex mFabricIndex; + NodeId mNodeId; + + bool operator==(const OperationalNodeId & that) const + { + return this->mFabricIndex == that.mFabricIndex && this->mNodeId == that.mNodeId; + } + + friend bool operator==(const Subject &, const OperationalNodeId &); +}; + +inline bool operator==(const Subject & subject, const OperationalNodeId & node) +{ + return subject.mFabricIndex == node.mFabricIndex && subject.mScopedSubject == ScopedSubject::Create(node.mNodeId); +} + +} // namespace Access +} // namespace chip diff --git a/src/lib/support/logging/CHIPLogging.h b/src/lib/support/logging/CHIPLogging.h index b511cb9d834167..fcf2e1e12f8295 100644 --- a/src/lib/support/logging/CHIPLogging.h +++ b/src/lib/support/logging/CHIPLogging.h @@ -378,5 +378,12 @@ bool IsCategoryEnabled(uint8_t category); */ #define ChipLogFormatMessageType "0x%x" +/** + * Logging helpers for Subject + */ +#define ChipLogFormatSubject "< %u, %s, " ChipLogFormatX64 ">" +#define ChipLogValueSubject(subject) \ + subject.GetFabricIndex(), subject.GetScopedSubject().GetAuthModeString(), ChipLogValueX64(subject.GetScopedSubject().GetValue()) + } // namespace Logging } // namespace chip diff --git a/src/transport/GroupSession.h b/src/transport/GroupSession.h index b90d966a528f70..e90165a90983c9 100644 --- a/src/transport/GroupSession.h +++ b/src/transport/GroupSession.h @@ -38,6 +38,11 @@ class GroupSession : public Session const char * GetSessionTypeString() const override { return "secure"; }; #endif + Access::Subject GetSubject() const override + { + return Access::Subject::Create(GetFabricIndex(), mGroupId); + } + Access::SubjectDescriptor GetSubjectDescriptor() const override { Access::SubjectDescriptor subjectDescriptor; diff --git a/src/transport/SecureSession.cpp b/src/transport/SecureSession.cpp index e2688869bc509a..f963423bed0582 100644 --- a/src/transport/SecureSession.cpp +++ b/src/transport/SecureSession.cpp @@ -20,6 +20,23 @@ namespace chip { namespace Transport { +Access::Subject SecureSession::GetSubject() const +{ + if (IsOperationalNodeId(mPeerNodeId)) + { + return Access::Subject::Create(GetFabricIndex(), mPeerNodeId); + } + else if (IsPAKEKeyId(mPeerNodeId)) + { + return Access::Subject::Create(GetFabricIndex(), static_cast(mPeerNodeId >> 48)); + } + else + { + VerifyOrDie(false); + return Access::Subject(); + } +} + Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const { Access::SubjectDescriptor subjectDescriptor; diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h index 04094961ae6c6e..621e367624b170 100644 --- a/src/transport/SecureSession.h +++ b/src/transport/SecureSession.h @@ -83,6 +83,7 @@ class SecureSession : public Session const char * GetSessionTypeString() const override { return "secure"; }; #endif + Access::Subject GetSubject() const override; Access::SubjectDescriptor GetSubjectDescriptor() const override; bool RequireMRP() const override { return GetPeerAddress().GetTransportType() == Transport::Type::kUdp; } diff --git a/src/transport/Session.h b/src/transport/Session.h index 99be7b2315a65b..8f4a714407f514 100644 --- a/src/transport/Session.h +++ b/src/transport/Session.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -63,6 +64,7 @@ class Session virtual void Retain() {} virtual void Release() {} + virtual Access::Subject GetSubject() const = 0; virtual Access::SubjectDescriptor GetSubjectDescriptor() const = 0; virtual bool RequireMRP() const = 0; virtual const ReliableMessageProtocolConfig & GetMRPConfig() const = 0; diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index 213afb1661235a..86fc00099bd988 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -74,6 +74,7 @@ class UnauthenticatedSession : public Session, public ReferenceCounted::Retain(); } void Release() override { ReferenceCounted::Release(); } + Access::Subject GetSubject() const override { return Access::Subject(); } Access::SubjectDescriptor GetSubjectDescriptor() const override { return Access::SubjectDescriptor(); // return an empty ISD for unauthenticated session.