Skip to content

Commit d2e19cd

Browse files
Enable support for a plethora of more DNS types. (#24)
1 parent 17a809e commit d2e19cd

File tree

2 files changed

+31
-29
lines changed

2 files changed

+31
-29
lines changed

internal/dns/dns.go

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -240,29 +240,12 @@ func HandleRemoteDNSQuery(ctx context.Context, request mcp.CallToolRequest, conf
240240
}
241241

242242
// ConvertToQType converts a string record type to the corresponding DNS query type.
243+
// This function supports all DNS record types available in the miekg/dns package.
243244
func ConvertToQType(recordType string) (uint16, error) {
244-
switch recordType {
245-
case "A":
246-
return dns.TypeA, nil
247-
case "AAAA":
248-
return dns.TypeAAAA, nil
249-
case "CNAME":
250-
return dns.TypeCNAME, nil
251-
case "MX":
252-
return dns.TypeMX, nil
253-
case "NS":
254-
return dns.TypeNS, nil
255-
case "PTR":
256-
return dns.TypePTR, nil
257-
case "SOA":
258-
return dns.TypeSOA, nil
259-
case "SRV":
260-
return dns.TypeSRV, nil
261-
case "TXT":
262-
return dns.TypeTXT, nil
263-
default:
264-
return 0, fmt.Errorf("unsupported record type: %s", recordType)
245+
if qtype, exists := dns.StringToType[recordType]; exists {
246+
return qtype, nil
265247
}
248+
return 0, fmt.Errorf("unsupported record type %q", recordType)
266249
}
267250

268251
// createDNSResponse creates a JSON-serializable map from a DNS message.
@@ -306,6 +289,10 @@ func createDNSResponse(response *dns.Msg) map[string]any {
306289
if rec, ok := a.(*dns.AAAA); ok {
307290
data = rec.AAAA.String()
308291
}
292+
case dns.TypeCAA:
293+
if rec, ok := a.(*dns.CAA); ok {
294+
data = fmt.Sprintf("%d %s %q", rec.Flag, rec.Tag, rec.Value)
295+
}
309296
case dns.TypeCNAME:
310297
if rec, ok := a.(*dns.CNAME); ok {
311298
data = rec.Target

internal/server/server.go

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package server
22

33
import (
44
"context"
5+
"slices"
56
"time"
67

78
"github.com/mark3labs/mcp-go/mcp"
89
"github.com/mark3labs/mcp-go/server"
9-
"github.com/patrickdappollonio/mcp-domaintools/internal/dns"
10+
"github.com/miekg/dns"
11+
internaldns "github.com/patrickdappollonio/mcp-domaintools/internal/dns"
1012
"github.com/patrickdappollonio/mcp-domaintools/internal/http_ping"
1113
"github.com/patrickdappollonio/mcp-domaintools/internal/ping"
1214
"github.com/patrickdappollonio/mcp-domaintools/internal/resolver"
@@ -16,7 +18,7 @@ import (
1618

1719
// DomainToolsConfig contains configuration for the domain tools.
1820
type DomainToolsConfig struct {
19-
QueryConfig *dns.QueryConfig
21+
QueryConfig *internaldns.QueryConfig
2022
WhoisConfig *whois.Config
2123
ResolverConfig *resolver.Config
2224
PingConfig *ping.Config
@@ -25,6 +27,16 @@ type DomainToolsConfig struct {
2527
Version string
2628
}
2729

30+
// getDNSRecordTypes returns a sorted slice of all DNS record type names.
31+
func getDNSRecordTypes() []string {
32+
var recordTypes []string
33+
for recordType := range dns.StringToType {
34+
recordTypes = append(recordTypes, recordType)
35+
}
36+
slices.Sort(recordTypes)
37+
return recordTypes
38+
}
39+
2840
// SetupTools creates and configures the domain query tools.
2941
func SetupTools(config *DomainToolsConfig) (*server.MCPServer, error) {
3042
// Create a new MCP server
@@ -34,6 +46,9 @@ func SetupTools(config *DomainToolsConfig) (*server.MCPServer, error) {
3446
server.WithRecovery(),
3547
)
3648

49+
// Get all available DNS record types for the enum
50+
dnsRecordTypes := getDNSRecordTypes()
51+
3752
// Initialize resolver config if not provided
3853
if config.ResolverConfig == nil {
3954
config.ResolverConfig = &resolver.Config{
@@ -74,8 +89,8 @@ func SetupTools(config *DomainToolsConfig) (*server.MCPServer, error) {
7489
),
7590
mcp.WithString("record_type",
7691
mcp.Required(),
77-
mcp.Description("The type of DNS record to query; defaults to A"),
78-
mcp.Enum("A", "AAAA", "CNAME", "MX", "NS", "PTR", "SOA", "SRV", "TXT"),
92+
mcp.Description("The type of DNS record to query (supports all standard DNS record types); defaults to A"),
93+
mcp.Enum(dnsRecordTypes...),
7994
mcp.DefaultString("A"),
8095
),
8196
)
@@ -89,8 +104,8 @@ func SetupTools(config *DomainToolsConfig) (*server.MCPServer, error) {
89104
),
90105
mcp.WithString("record_type",
91106
mcp.Required(),
92-
mcp.Description("The type of DNS record to query; defaults to A"),
93-
mcp.Enum("A", "AAAA", "CNAME", "MX", "NS", "PTR", "SOA", "SRV", "TXT"),
107+
mcp.Description("The type of DNS record to query (supports all standard DNS record types); defaults to A"),
108+
mcp.Enum(dnsRecordTypes...),
94109
mcp.DefaultString("A"),
95110
),
96111
)
@@ -173,11 +188,11 @@ func SetupTools(config *DomainToolsConfig) (*server.MCPServer, error) {
173188

174189
// Create handler wrappers
175190
localDNSHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
176-
return dns.HandleLocalDNSQuery(ctx, request, config.QueryConfig)
191+
return internaldns.HandleLocalDNSQuery(ctx, request, config.QueryConfig)
177192
}
178193

179194
remoteDNSHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
180-
return dns.HandleRemoteDNSQuery(ctx, request, config.QueryConfig)
195+
return internaldns.HandleRemoteDNSQuery(ctx, request, config.QueryConfig)
181196
}
182197

183198
whoisHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {

0 commit comments

Comments
 (0)