Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
LIBDNS_DA_TEST_ZONE=domain.com.
LIBDNS_DA_TEST_ZONE=domain.com
LIBDNS_DA_NON_ROOT_TEST_ZONE=test.domain.com
LIBDNS_DA_TEST_SERVER_URL=https://da.domain.com:2222
LIBDNS_DA_TEST_INSECURE_SERVER_URL=https://1.1.1.1:2222
LIBDNS_DA_TEST_USER=admin
Expand Down
115 changes: 110 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,9 @@ func (p *Provider) getZoneRecords(ctx context.Context, zone string) ([]libdns.Re
if err != nil {
switch err {
case ErrUnsupported:
rr := libDnsRecord.RR()
p.getLogger().Warn("Unsupported record conversion",
zap.String("type", rr.Type),
zap.String("name", rr.Name))
zap.String("type", respData.Records[i].Type),
zap.String("name", respData.Records[i].Name))
continue
default:
return nil, err
Expand Down Expand Up @@ -140,7 +139,6 @@ func (p *Provider) appendZoneRecord(ctx context.Context, zone string, record lib
return nil, err
}

rr.Data = fmt.Sprintf("name=%v&value=%v", rr.Name, rr.Data)
return &rr, nil
}

Expand Down Expand Up @@ -200,7 +198,6 @@ func (p *Provider) setZoneRecord(ctx context.Context, zone string, record libdns
return nil, err
}

rr.Data = fmt.Sprintf("name=%v&value=%v", rr.Name, rr.Data)
return &rr, nil
}

Expand Down Expand Up @@ -298,3 +295,111 @@ func (p *Provider) executeRequest(ctx context.Context, method, url string) error

return nil
}

func (p *Provider) getDomainList(ctx context.Context) ([]string, error) {
reqURL, err := url.Parse(p.ServerURL)
if err != nil {
p.getLogger().Error("Failed to parse server URL", zap.Error(err))
return nil, err
}

reqURL.Path = "/CMD_API_SHOW_DOMAINS"

queryString := make(url.Values)
queryString.Set("json", "yes")

reqURL.RawQuery = queryString.Encode()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
if err != nil {
p.getLogger().Error("Failed to build new request", zap.Error(err))
return nil, err
}

req.SetBasicAuth(p.User, p.LoginKey)

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: p.InsecureRequests,
},
}}

resp, err := client.Do(req)
if err != nil {
p.getLogger().Error("Failed to execute request", zap.Error(err))
return nil, err
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
p.getLogger().Error("Failed to close response body", zap.Error(err))
}
}(resp.Body)

if resp.StatusCode != http.StatusOK {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
p.getLogger().Error("Failed to read response body", zap.Error(err))
return nil, err
}

bodyString := string(bodyBytes)

p.getLogger().Error("API returned a non-200 status code",
zap.Int("status_code", resp.StatusCode),
zap.String("body", bodyString))

return nil, fmt.Errorf("api request failed with status code %d", resp.StatusCode)
}

var respData daDomainList
err = json.NewDecoder(resp.Body).Decode(&respData)
if err != nil {
p.getLogger().Error("Failed to decode JSON response", zap.Error(err))
return nil, err
}

return respData, nil
}

func (p *Provider) findManageableZone(ctx context.Context, requestedZone string) (string, error) {
p.getLogger().Debug("findManageableZone called", zap.String("zone", requestedZone))

// Remove trailing dot if present
requestedZone = strings.TrimSuffix(requestedZone, ".")

// Get list of domains we can manage
domains, err := p.getDomainList(ctx)
if err != nil {
return "", fmt.Errorf("failed to get domain list: %v", err)
}

p.getLogger().Debug("Available domains", zap.Strings("domains", domains))

// Try the requested zone first (exact match)
for _, domain := range domains {
if strings.EqualFold(requestedZone, domain) {
p.getLogger().Debug("Found exact match", zap.String("domain", domain))
return domain, nil
}
}

// If no exact match, traverse backwards through the FQDN to find parent zones
parts := strings.Split(requestedZone, ".")
for i := 1; i < len(parts); i++ {
candidateZone := strings.Join(parts[i:], ".")
p.getLogger().Debug("Checking candidate zone", zap.String("candidate_zone", candidateZone))

for _, domain := range domains {
if strings.EqualFold(candidateZone, domain) {
p.getLogger().Debug("Found manageable parent zone",
zap.String("parent_zone", domain),
zap.String("requested_zone", requestedZone))
return domain, nil
}
}
}

return "", fmt.Errorf("no manageable zone found for %s in available domains: %v", requestedZone, domains)
}
2 changes: 2 additions & 0 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,5 @@ type daResponse struct {
Success string `json:"success,omitempty"`
Result string `json:"result,omitempty"`
}

type daDomainList []string
148 changes: 140 additions & 8 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,15 @@ func (p *Provider) caller() string {

// GetRecords lists all the records in the zone.
func (p *Provider) GetRecords(ctx context.Context, zone string) ([]libdns.Record, error) {
zone = strings.TrimSuffix(zone, ".")
p.getLogger().Debug("GetRecords called",
zap.String("zone", zone))

records, err := p.getZoneRecords(ctx, zone)
managedZone, err := p.findManageableZone(ctx, zone)
if err != nil {
return nil, err
}

records, err := p.getZoneRecords(ctx, managedZone)
if err != nil {
return nil, err
}
Expand All @@ -76,11 +82,36 @@ func (p *Provider) GetRecords(ctx context.Context, zone string) ([]libdns.Record

// AppendRecords adds records to the zone. It returns the records that were added.
func (p *Provider) AppendRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
zone = strings.TrimSuffix(zone, ".")
p.getLogger().Debug("AppendRecords called",
zap.String("zone", zone),
zap.Int("record_count", len(records)))

managedZone, err := p.findManageableZone(ctx, zone)
if err != nil {
return nil, err
}

if zone != managedZone {
p.getLogger().Debug("Using managed zone",
zap.String("managed_zone", managedZone),
zap.String("requested_zone", zone))
}

var created []libdns.Record
for _, rec := range records {
result, err := p.appendZoneRecord(ctx, zone, rec)
// Adjust record name if managedZone differs from requested zone
adjustedRecord := rec
if managedZone != strings.TrimSuffix(zone, ".") {
adjustedRecord = p.adjustRecordForZone(rec, zone, managedZone)
}

adjustedRR := adjustedRecord.RR()
p.getLogger().Debug("Creating record",
zap.String("name", adjustedRR.Name),
zap.String("type", adjustedRR.Type),
zap.String("value", adjustedRR.Data))

result, err := p.appendZoneRecord(ctx, managedZone, adjustedRecord)
if err != nil {
return nil, err
}
Expand All @@ -93,13 +124,38 @@ func (p *Provider) AppendRecords(ctx context.Context, zone string, records []lib
// SetRecords sets the records in the zone, either by updating existing records or creating new ones.
// It returns the updated records.
func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
zone = strings.TrimSuffix(zone, ".")
p.getLogger().Debug("SetRecords called",
zap.String("zone", zone),
zap.Int("record_count", len(records)))

managedZone, err := p.findManageableZone(ctx, zone)
if err != nil {
return nil, err
}

if zone != managedZone {
p.getLogger().Debug("Using managed zone",
zap.String("managed_zone", managedZone),
zap.String("requested_zone", zone))
}

var updated []libdns.Record
var errors []error

for _, rec := range records {
result, err := p.setZoneRecord(ctx, zone, rec)
// Adjust record name if managedZone differs from requested zone
adjustedRecord := rec
if managedZone != strings.TrimSuffix(zone, ".") {
adjustedRecord = p.adjustRecordForZone(rec, zone, managedZone)
}

adjustedRR := adjustedRecord.RR()
p.getLogger().Debug("Creating record",
zap.String("name", adjustedRR.Name),
zap.String("type", adjustedRR.Type),
zap.String("value", adjustedRR.Data))

result, err := p.setZoneRecord(ctx, managedZone, adjustedRecord)
if err != nil {
errors = append(errors, err)
continue
Expand All @@ -121,11 +177,36 @@ func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns

// DeleteRecords deletes the records from the zone. It returns the records that were deleted.
func (p *Provider) DeleteRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
zone = strings.TrimSuffix(zone, ".")
p.getLogger().Debug("DeleteRecords called",
zap.String("zone", zone),
zap.Int("record_count", len(records)))

managedZone, err := p.findManageableZone(ctx, zone)
if err != nil {
return nil, err
}

if zone != managedZone {
p.getLogger().Debug("Using managed zone",
zap.String("managed_zone", managedZone),
zap.String("requested_zone", zone))
}

var deleted []libdns.Record
for _, rec := range records {
result, err := p.deleteZoneRecord(ctx, zone, rec)
// Adjust record name if managedZone differs from requested zone
adjustedRecord := rec
if managedZone != strings.TrimSuffix(zone, ".") {
adjustedRecord = p.adjustRecordForZone(rec, zone, managedZone)
}

adjustedRR := adjustedRecord.RR()
p.getLogger().Debug("Deleting record",
zap.String("name", adjustedRR.Name),
zap.String("type", adjustedRR.Type),
zap.String("value", adjustedRR.Data))

result, err := p.deleteZoneRecord(ctx, managedZone, adjustedRecord)
if err != nil {
return nil, err
}
Expand All @@ -135,6 +216,57 @@ func (p *Provider) DeleteRecords(ctx context.Context, zone string, records []lib
return deleted, nil
}

// adjustRecordForZone adjusts the record name when the managed zone differs from the requested zone
func (p *Provider) adjustRecordForZone(record libdns.Record, requestedZone, managedZone string) libdns.Record {
requestedZone = strings.TrimSuffix(requestedZone, ".")
managedZone = strings.TrimSuffix(managedZone, ".")

// Calculate the subdomain portion that was stripped during zone detection
// Example: requestedZone="test.domain.com", managedZone="domain.com" -> subdomain="test"
if !strings.HasSuffix(requestedZone, managedZone) {
return record // Safety check - shouldn't happen with proper zone detection
}

var subdomain string
if requestedZone == managedZone {
subdomain = ""
} else {
subdomain = strings.TrimSuffix(requestedZone, "."+managedZone)
}

if subdomain == "" {
return record
}

rr := record.RR()

// Check if the record name has already been adjusted by seeing if it already ends with the subdomain
if strings.HasSuffix(rr.Name, "."+subdomain) {
p.getLogger().Debug("Record name already adjusted, skipping",
zap.String("name", rr.Name),
zap.String("subdomain", subdomain))
return record
}

// Adjust the record name to include the subdomain
// Example: "_acme-challenge.libdns" -> "_acme-challenge.libdns.test"
adjustedName := rr.Name + "." + subdomain

p.getLogger().Debug("Adjusting record name",
zap.String("original_name", rr.Name),
zap.String("adjusted_name", adjustedName),
zap.String("subdomain", subdomain))

adjustedRR := &libdns.RR{
Type: rr.Type,
Name: adjustedName,
Data: rr.Data,
TTL: rr.TTL,
}

return adjustedRR
}

// Interface guards
var (
_ libdns.RecordGetter = (*Provider)(nil)
Expand Down
Loading