Skip to content

Commit

Permalink
unified request flow && external binding to LDAP
Browse files Browse the repository at this point in the history
  • Loading branch information
regeda committed Oct 8, 2019
1 parent 9f0d712 commit 9f7ed51
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 318 deletions.
53 changes: 17 additions & 36 deletions add.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
package ldap

import (
"errors"
"log"

"gopkg.in/asn1-ber.v1"
ber "gopkg.in/asn1-ber.v1"
)

// Attribute represents an LDAP attribute
Expand Down Expand Up @@ -45,20 +44,26 @@ type AddRequest struct {
Controls []Control
}

func (a AddRequest) encode() *ber.Packet {
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request")
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.DN, "DN"))
func (req *AddRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
attributes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes")
for _, attribute := range a.Attributes {
for _, attribute := range req.Attributes {
attributes.AppendChild(attribute.encode())
}
request.AppendChild(attributes)
return request
pkt.AppendChild(attributes)

envelope.AppendChild(pkt)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}

return nil
}

// Attribute adds an attribute with the given type and values
func (a *AddRequest) Attribute(attrType string, attrVals []string) {
a.Attributes = append(a.Attributes, Attribute{Type: attrType, Vals: attrVals})
func (req *AddRequest) Attribute(attrType string, attrVals []string) {
req.Attributes = append(req.Attributes, Attribute{Type: attrType, Vals: attrVals})
}

// NewAddRequest returns an AddRequest for the given DN, with no attributes
Expand All @@ -72,39 +77,17 @@ func NewAddRequest(dn string, controls []Control) *AddRequest {

// Add performs the given AddRequest
func (l *Conn) Add(addRequest *AddRequest) error {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
packet.AppendChild(addRequest.encode())
if len(addRequest.Controls) > 0 {
packet.AppendChild(encodeControls(addRequest.Controls))
}

l.Debug.PrintPacket(packet)

msgCtx, err := l.sendMessage(packet)
msgCtx, err := l.doRequest(addRequest)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)

l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-msgCtx.responses
if !ok {
return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
}
packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
packet, err := l.readPacket(msgCtx)
if err != nil {
return err
}

if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
return err
}
ber.PrintPacket(packet)
}

if packet.Children[1].Tag == ApplicationAddResponse {
err := GetLDAPError(packet)
if err != nil {
Expand All @@ -113,7 +96,5 @@ func (l *Conn) Add(addRequest *AddRequest) error {
} else {
log.Printf("Unexpected Response: %d", packet.Children[1].Tag)
}

l.Debug.Printf("%d: returning", msgCtx.id)
return nil
}
83 changes: 50 additions & 33 deletions bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"errors"
"fmt"

"gopkg.in/asn1-ber.v1"
ber "gopkg.in/asn1-ber.v1"
)

// SimpleBindRequest represents a username/password bind operation
Expand Down Expand Up @@ -35,13 +35,18 @@ func NewSimpleBindRequest(username string, password string, controls []Control)
}
}

func (bindRequest *SimpleBindRequest) encode() *ber.Packet {
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, bindRequest.Username, "User Name"))
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, bindRequest.Password, "Password"))
func (req *SimpleBindRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Username, "User Name"))
pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.Password, "Password"))

return request
envelope.AppendChild(pkt)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}

return nil
}

// SimpleBind performs the simple bind operation defined in the given request
Expand All @@ -50,41 +55,17 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
}

packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
encodedBindRequest := simpleBindRequest.encode()
packet.AppendChild(encodedBindRequest)
if len(simpleBindRequest.Controls) > 0 {
packet.AppendChild(encodeControls(simpleBindRequest.Controls))
}

if l.Debug {
ber.PrintPacket(packet)
}

msgCtx, err := l.sendMessage(packet)
msgCtx, err := l.doRequest(simpleBindRequest)
if err != nil {
return nil, err
}
defer l.finishMessage(msgCtx)

packetResponse, ok := <-msgCtx.responses
if !ok {
return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
}
packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
packet, err := l.readPacket(msgCtx)
if err != nil {
return nil, err
}

if l.Debug {
if err = addLDAPDescriptions(packet); err != nil {
return nil, err
}
ber.PrintPacket(packet)
}

result := &SimpleBindResult{
Controls: make([]Control, 0),
}
Expand Down Expand Up @@ -133,3 +114,39 @@ func (l *Conn) UnauthenticatedBind(username string) error {
_, err := l.SimpleBind(req)
return err
}

var externalBindRequest = requestFunc(func(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))

saslAuth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "EXTERNAL", "SASL Mech"))
saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "SASL Cred"))

pkt.AppendChild(saslAuth)

envelope.AppendChild(pkt)

return nil
})

// ExternalBind performs SASL/EXTERNAL authentication.
//
// Use ldap.DialURL("ldapi://") to connect to the Unix socket before ExternalBind.
//
// See https://tools.ietf.org/html/rfc4422#appendix-A
func (l *Conn) ExternalBind() error {
msgCtx, err := l.doRequest(externalBindRequest)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)

packet, err := l.readPacket(msgCtx)
if err != nil {
return err
}

return GetLDAPError(packet)
}
17 changes: 9 additions & 8 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,22 @@ import (
// Client knows how to interact with an LDAP server
type Client interface {
Start()
StartTLS(config *tls.Config) error
StartTLS(*tls.Config) error
Close()
SetTimeout(time.Duration)

Bind(username, password string) error
SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error)
UnauthenticatedBind(username string) error
SimpleBind(*SimpleBindRequest) (*SimpleBindResult, error)

Add(addRequest *AddRequest) error
Del(delRequest *DelRequest) error
Modify(modifyRequest *ModifyRequest) error
ModifyDN(modifyDNRequest *ModifyDNRequest) error
Add(*AddRequest) error
Del(*DelRequest) error
Modify(*ModifyRequest) error
ModifyDN(*ModifyDNRequest) error

Compare(dn, attribute, value string) (bool, error)
PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error)
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)

Search(searchRequest *SearchRequest) (*SearchResult, error)
Search(*SearchRequest) (*SearchResult, error)
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
}
55 changes: 26 additions & 29 deletions compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,53 +20,50 @@
package ldap

import (
"errors"
"fmt"

"gopkg.in/asn1-ber.v1"
ber "gopkg.in/asn1-ber.v1"
)

// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise
// false with any error that occurs if any.
func (l *Conn) Compare(dn, attribute, value string) (bool, error) {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
// CompareRequest represents an LDAP CompareRequest operation.
type CompareRequest struct {
DN string
Attribute string
Value string
}

request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request")
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, dn, "DN"))
func (req *CompareRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))

ava := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "AttributeValueAssertion")
ava.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "AttributeDesc"))
ava.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "AssertionValue"))
request.AppendChild(ava)
packet.AppendChild(request)
ava.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Attribute, "AttributeDesc"))
ava.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Value, "AssertionValue"))

l.Debug.PrintPacket(packet)
pkt.AppendChild(ava)

msgCtx, err := l.sendMessage(packet)
envelope.AppendChild(pkt)

return nil
}

// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise
// false with any error that occurs if any.
func (l *Conn) Compare(dn, attribute, value string) (bool, error) {
msgCtx, err := l.doRequest(&CompareRequest{
DN: dn,
Attribute: attribute,
Value: value})
if err != nil {
return false, err
}
defer l.finishMessage(msgCtx)

l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-msgCtx.responses
if !ok {
return false, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
}
packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
packet, err := l.readPacket(msgCtx)
if err != nil {
return false, err
}

if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
return false, err
}
ber.PrintPacket(packet)
}

if packet.Children[1].Tag == ApplicationCompareResponse {
err := GetLDAPError(packet)

Expand Down
14 changes: 10 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"sync/atomic"
"time"

"gopkg.in/asn1-ber.v1"
ber "gopkg.in/asn1-ber.v1"
)

const (
Expand Down Expand Up @@ -140,7 +140,6 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
// or ldap:// specified as protocol. On success a new Conn for the connection
// is returned.
func DialURL(addr string) (*Conn, error) {

lurl, err := url.Parse(addr)
if err != nil {
return nil, NewError(ErrorNetwork, err)
Expand All @@ -154,6 +153,11 @@ func DialURL(addr string) (*Conn, error) {
}

switch lurl.Scheme {
case "ldapi":
if lurl.Path == "" || lurl.Path == "/" {
lurl.Path = "/var/run/slapd/ldapi"
}
return Dial("unix", lurl.Path)
case "ldap":
if port == "" {
port = DefaultLdapPort
Expand Down Expand Up @@ -490,11 +494,13 @@ func (l *Conn) reader() {
// A read error is expected here if we are closing the connection...
if !l.IsClosing() {
l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
l.Debug.Printf("reader error: %s", err.Error())
l.Debug.Printf("reader error: %s", err)
}
return
}
addLDAPDescriptions(packet)
if err := addLDAPDescriptions(packet); err != nil {
l.Debug.Printf("descriptions error: %s", err)
}
if len(packet.Children) == 0 {
l.Debug.Printf("Received bad ldap packet")
continue
Expand Down
Loading

0 comments on commit 9f7ed51

Please sign in to comment.