Skip to content

Commit

Permalink
ACL/Transport: Add structured subject
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost committed Feb 22, 2022
1 parent 593b427 commit c39ca5d
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 0 deletions.
201 changes: 201 additions & 0 deletions src/access/Subjects.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Copyright (c) 2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <access/AuthMode.h>
#include <lib/core/DataModelTypes.h>
#include <lib/core/InPlace.h>
#include <lib/core/NodeId.h>
#include <lib/support/Variant.h>

namespace chip {
namespace Access {

struct PaseSubject
{
PaseSubject(uint16_t aPasscodeId) : passcodeId(aPasscodeId) {}
uint16_t passcodeId;
bool operator==(const PaseSubject & that) const { return this->passcodeId == that.passcodeId; }
};

struct NodeSubject
{
NodeSubject(NodeId aNodeId) : nodeId(aNodeId) {}
NodeId nodeId;
bool operator==(const NodeSubject & that) const { return this->nodeId == that.nodeId; }
};

struct GroupSubject
{
GroupSubject(GroupId aGroupId) : groupId(aGroupId) {}
GroupId groupId;
bool operator==(const GroupSubject & that) const { return this->groupId == that.groupId; }
};

class OperationalNodeId;

/** A scoped subject is a unique identifier with-in a fabric scope. */
class ScopedSubject
{
public:
ScopedSubject() {}

template <typename T, class... Args>
constexpr explicit ScopedSubject(InPlaceTemplateType<T>, Args &&... args) :
mSubject(InPlaceTemplate<T>, std::forward<Args>(args)...)
{}

template <typename T, typename... Args>
static ScopedSubject Create(Args &&... args)
{
return ScopedSubject(InPlaceTemplate<T>, std::forward<Args>(args)...);
}

Access::AuthMode GetAuthMode() const
{
if (mSubject.Is<PaseSubject>())
{
return Access::AuthMode::kPase;
}
else if (mSubject.Is<NodeSubject>())
{
return Access::AuthMode::kCase;
}
else if (mSubject.Is<GroupSubject>())
{
return Access::AuthMode::kGroup;
}
else
{
return Access::AuthMode::kNone;
}
}

#if CHIP_DETAIL_LOGGING
const char * GetAuthModeString() const
{
switch (GetAuthMode())
{
case Access::AuthMode::kPase:
return "Pase";
case Access::AuthMode::kCase:
return "Case";
case Access::AuthMode::kGroup:
return "Group";
case Access::AuthMode::kNone:
default:
return "None";
}
}
#endif // CHIP_DETAIL_LOGGING

/** Return subject value described in spec */
uint64_t GetValue() const
{
if (mSubject.Is<PaseSubject>())
{
return static_cast<uint64_t>(mSubject.Get<PaseSubject>().passcodeId) << 48;
}
else if (mSubject.Is<NodeSubject>())
{
return mSubject.Get<NodeSubject>().nodeId;
}
else if (mSubject.Is<GroupSubject>())
{
return static_cast<uint64_t>(mSubject.Get<GroupSubject>().groupId) << 48;
}
else
{
return -1ull;
}
}

bool operator==(const ScopedSubject & that) const { return this->mSubject == that.mSubject; }

private:
Variant<PaseSubject, NodeSubject, GroupSubject> mSubject;
};

/**
* A subject is a global unique identifier. It suite 2 purpose:
*
* 1. Identify an entity. operator== can be used to check if 2 subjects are identical.
* 2. Associate to ACL entries to grant privileges
*/
class Subject
{
public:
Subject() : mFabricIndex(kUndefinedFabricIndex) {}

template <typename T, class... Args>
constexpr explicit Subject(FabricIndex fabricIndex, InPlaceTemplateType<T>, Args &&... args) :
mFabricIndex(fabricIndex), mScopedSubject(InPlaceTemplate<T>, std::forward<Args>(args)...)
{}

template <typename T, typename... Args>
static Subject Create(FabricIndex fabricIndex, Args &&... args)
{
return Subject(fabricIndex, InPlaceTemplate<T>, std::forward<Args>(args)...);
}

FabricIndex GetFabricIndex() const { return mFabricIndex; }
const ScopedSubject & GetScopedSubject() const { return mScopedSubject; }

bool operator==(const Subject & that) const
{
return this->mFabricIndex == that.mFabricIndex && this->mScopedSubject == that.mScopedSubject;
}

private:
FabricIndex mFabricIndex;
ScopedSubject mScopedSubject;

friend bool operator==(const Subject &, const OperationalNodeId &);
};

/**
* OperationalNodeId identifies an individual Node on a Fabric. It is a special type of Subject targeting to NodeSubject. It is
* interchangeable with the generic Subject type but uses less memory.
*/
class OperationalNodeId
{
public:
OperationalNodeId(FabricIndex fabricIndex, NodeId nodeId) : mFabricIndex(fabricIndex), mNodeId(nodeId) {}
Subject ToSubject() { return Subject::Create<NodeSubject>(mFabricIndex, mNodeId); }

FabricIndex GetFabricIndex() const { return mFabricIndex; }
NodeId GetNodeId() const { return mNodeId; }

private:
FabricIndex mFabricIndex;
NodeId mNodeId;

bool operator==(const OperationalNodeId & that) const
{
return this->mFabricIndex == that.mFabricIndex && this->mNodeId == that.mNodeId;
}

friend bool operator==(const Subject &, const OperationalNodeId &);
};

inline bool operator==(const Subject & subject, const OperationalNodeId & node)
{
return subject.mFabricIndex == node.mFabricIndex && subject.mScopedSubject == ScopedSubject::Create<NodeSubject>(node.mNodeId);
}

} // namespace Access
} // namespace chip
7 changes: 7 additions & 0 deletions src/lib/support/logging/CHIPLogging.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,5 +378,12 @@ bool IsCategoryEnabled(uint8_t category);
*/
#define ChipLogFormatMessageType "0x%x"

/**
* Logging helpers for Subject
*/
#define ChipLogFormatSubject "< %u, %s, " ChipLogFormatX64 ">"
#define ChipLogValueSubject(subject) \
subject.GetFabricIndex(), subject.GetScopedSubject().GetAuthModeString(), ChipLogValueX64(subject.GetScopedSubject().GetValue())

} // namespace Logging
} // namespace chip
5 changes: 5 additions & 0 deletions src/transport/GroupSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ class GroupSession : public Session
const char * GetSessionTypeString() const override { return "secure"; };
#endif

Access::Subject GetSubject() const override
{
return Access::Subject::Create<Access::GroupSubject>(GetFabricIndex(), mGroupId);
}

Access::SubjectDescriptor GetSubjectDescriptor() const override
{
Access::SubjectDescriptor subjectDescriptor;
Expand Down
17 changes: 17 additions & 0 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,23 @@
namespace chip {
namespace Transport {

Access::Subject SecureSession::GetSubject() const
{
if (IsOperationalNodeId(mPeerNodeId))
{
return Access::Subject::Create<Access::NodeSubject>(GetFabricIndex(), mPeerNodeId);
}
else if (IsPAKEKeyId(mPeerNodeId))
{
return Access::Subject::Create<Access::PaseSubject>(GetFabricIndex(), static_cast<uint16_t>(mPeerNodeId >> 48));
}
else
{
VerifyOrDie(false);
return Access::Subject();
}
}

Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const
{
Access::SubjectDescriptor subjectDescriptor;
Expand Down
1 change: 1 addition & 0 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class SecureSession : public Session
const char * GetSessionTypeString() const override { return "secure"; };
#endif

Access::Subject GetSubject() const override;
Access::SubjectDescriptor GetSubjectDescriptor() const override;

bool RequireMRP() const override { return GetPeerAddress().GetTransportType() == Transport::Type::kUdp; }
Expand Down
2 changes: 2 additions & 0 deletions src/transport/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <access/Subjects.h>
#include <credentials/FabricTable.h>
#include <lib/core/CHIPConfig.h>
#include <messaging/ReliableMessageProtocolConfig.h>
Expand Down Expand Up @@ -63,6 +64,7 @@ class Session
virtual void Retain() {}
virtual void Release() {}

virtual Access::Subject GetSubject() const = 0;
virtual Access::SubjectDescriptor GetSubjectDescriptor() const = 0;
virtual bool RequireMRP() const = 0;
virtual const ReliableMessageProtocolConfig & GetMRPConfig() const = 0;
Expand Down
1 change: 1 addition & 0 deletions src/transport/UnauthenticatedSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class UnauthenticatedSession : public Session, public ReferenceCounted<Unauthent
void Retain() override { ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>::Retain(); }
void Release() override { ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter, 0>::Release(); }

Access::Subject GetSubject() const override { return Access::Subject(); }
Access::SubjectDescriptor GetSubjectDescriptor() const override
{
return Access::SubjectDescriptor(); // return an empty ISD for unauthenticated session.
Expand Down

0 comments on commit c39ca5d

Please sign in to comment.