From 16f8efc382591433d8db44a531ceaa78ea71d0eb Mon Sep 17 00:00:00 2001 From: Motalleb Fallahnehzad Date: Sat, 5 Aug 2023 14:59:05 +0330 Subject: [PATCH] [Feat] --- config.yaml | 6 ++- lib/rule/rule.go | 13 ++--- lib/utils/utils.go | 119 ++++++++++++++++++++++++++++++++------------- 3 files changed, 96 insertions(+), 42 deletions(-) diff --git a/config.yaml b/config.yaml index b7710be..9574869 100644 --- a/config.yaml +++ b/config.yaml @@ -44,8 +44,10 @@ rules: matcherParams: - (.*\.)?accounts\.ea\.com.* - (.*\.)?signin\.ea.com\.* - raw: - A: "{{ .address }} 60 IN A 50.7.87.85" + resolver: "shecan" + resolverParams: ea.com. +# raw: +# A: "{{ .address }} 60 IN A 50.7.87.85" - name: Ea proxy matcher: regex diff --git a/lib/rule/rule.go b/lib/rule/rule.go index 1534c1e..1ca75b2 100644 --- a/lib/rule/rule.go +++ b/lib/rule/rule.go @@ -10,12 +10,13 @@ import ( // Rule set of rules to find resolver of each request type Rule struct { - Name *string `yaml:"name"` - Matcher string `yaml:"matcher"` - MatcherParams []string `yaml:"matcherParams"` - Resolver *string `yaml:"resolver"` - Raw *map[string]string `yaml:"raw"` - IsBlocked bool `yaml:"isBlocked,alias:blocked"` + Name *string `yaml:"name"` + Matcher string `yaml:"matcher"` + MatcherParams []string `yaml:"matcherParams"` + Resolver *string `yaml:"resolver"` + ResolverParams *string `yaml:"resolverParams"` + Raw *map[string]string `yaml:"raw"` + IsBlocked bool `yaml:"isBlocked,alias:blocked"` } func (r *Rule) String() string { diff --git a/lib/utils/utils.go b/lib/utils/utils.go index 3498db6..91f3254 100644 --- a/lib/utils/utils.go +++ b/lib/utils/utils.go @@ -4,11 +4,12 @@ package utils import ( "bytes" + "github.com/FMotalleb/dns-reverse-proxy-docker/lib/rule" "html/template" "net" + "strings" "github.com/FMotalleb/dns-reverse-proxy-docker/lib/config" - "github.com/FMotalleb/dns-reverse-proxy-docker/lib/provider" "github.com/miekg/dns" "github.com/rs/zerolog/log" ) @@ -17,51 +18,93 @@ import ( // The handler returns a response, which is then written to the ResponseWriter. func HandleRequest(c config.Config, w dns.ResponseWriter, req *dns.Msg) { log.Debug().Msgf("received request to find `%v`", req.Question) - if len(req.Question) == 0 || !allowed(c.Global.AllowTransfer, w, req) { - responseErrorToRequest(w, req) - return + defer recoverDNSResponse(w, req) + if !allowed(c.Global.AllowTransfer, w, req) { + log.Panic().Msgf("received request is not allowed") + } + if len(req.Question) == 0 { + log.Panic().Msgf("received request has no question") } requestHostname := req.Question[0].Name log.Debug().Msgf("received request to find `%s`", requestHostname) r := c.FindRuleFor(requestHostname) - dnsProvider := provider.Provider{} + dnsProvider := *c.GetDefaultProvider() switch { + case r.IsBlocked: + log.Debug().Msgf("blocking `%s`", requestHostname) + log.Panic().Msgf("blocked request") + return case r == nil: - dnsProvider = *c.GetDefaultProvider() + log.Debug().Msgf("no handler found for `%s`, will use default handler", requestHostname) case r.Resolver != nil: dnsProvider = *c.FindProvider(*r.Resolver) + log.Debug().Msgf("handler found for `%s`, will use %v, request: %v", requestHostname, dnsProvider, UnNil(r.ResolverParams, requestHostname)) case r.Raw != nil: - mapper := make(map[string]string, 0) - mapper["address"] = requestHostname - raw := r.GetRaw(dns.TypeToString[req.Question[0].Qtype]) - if raw == nil { - log.Error().Msgf("%s not supported in the config", dns.TypeToString[req.Question[0].Qtype]) - responseErrorToRequest(w, req) - return - } - msg, err := dns.NewRR(formatString(*raw, mapper)) - if err != nil { - log.Debug().Msgf("cannot parse raw response: %v", err) - } - if msg != nil { - result := make([]dns.RR, 0) - result = append(result, msg) - req.Answer = result - log.Info().Msgf("cannot parse raw response: %v", req) - _ = w.WriteMsg(req) + if handleRawResponse(requestHostname, r, req, w) { return } } - log.Debug().Msgf("no handler found for `%s`, will use default handler", requestHostname) + transport := "udp" if _, ok := w.RemoteAddr().(*net.TCPAddr); ok { transport = "tcp" } + if r != nil && r.ResolverParams != nil { + changeRequestAddress(req, *r.ResolverParams) + } resp := dnsProvider.Handle(transport, req) + if r != nil && r.ResolverParams != nil { + changeResponseAddress(resp, requestHostname) + } _ = w.WriteMsg(resp) } -func responseErrorToRequest(w dns.ResponseWriter, r *dns.Msg) { +func handleRawResponse(requestHostname string, r *rule.Rule, req *dns.Msg, w dns.ResponseWriter) bool { + mapper := make(map[string]string, 0) + mapper["address"] = requestHostname + raw := r.GetRaw(dns.TypeToString[req.Question[0].Qtype]) + if raw == nil { + log.Error().Msgf("%s not supported in the config, continue using default handler", dns.TypeToString[req.Question[0].Qtype]) + return false + } + msg, err := dns.NewRR(formatString(*raw, mapper)) + if err != nil { + log.Debug().Msgf("cannot parse raw response: %v", err) + return false + } + if msg != nil { + result := make([]dns.RR, 0) + result = append(result, msg) + req.Answer = result + log.Info().Msgf("cannot parse raw response: %v", req) + _ = w.WriteMsg(req) + return true + } + return false +} +func recoverDNSResponse(w dns.ResponseWriter, req *dns.Msg) { + err := recover() + if err != nil { + log.Error().Msgf("Recovering from: %v", err) + reject(w, req) + } +} +func changeRequestAddress(req *dns.Msg, newAddress string) *dns.Msg { + req.Question[0].Name = newAddress + return req +} +func changeResponseAddress(req *dns.Msg, newAddress string) *dns.Msg { + req.Question[0].Name = newAddress + ans := req.Answer[0] + ansStr := strings.Replace(ans.String(), ans.Header().Name, newAddress, 1) + result, err := dns.NewRR(ansStr) + req.Answer[0] = result + if err != nil { + log.Panic().Msgf("faced an error when tried to change answer \nfrom:%v\nto:%v\nerror:%v", req.Answer[0], ansStr, err) + } + return req +} +func reject(w dns.ResponseWriter, r *dns.Msg) { msg := makeErrorMessage(r) _ = w.WriteMsg(msg) } @@ -98,22 +141,30 @@ func isTransfer(req *dns.Msg) bool { func formatString(text string, hashmap map[string]string) string { tmpl, err := template.New("Mapper").Parse(text) if err != nil { - panic(err) + log.Panic().Msgf("failed to use parse template: %s\nerror: %v", text, err) } writer := bytes.NewBuffer(nil) err = tmpl.Execute(writer, hashmap) if err != nil { - panic(err) + log.Panic().Msgf("failed to use template for %s\nerror: %v\nhashmap:%v", text, err, hashmap) } return writer.String() } // FindFirst non-null value in items -func FindFirst[T any](items ...*T) (t *T) { - for _, item := range items { - if item != nil { - return item - } +// func FindFirst[T any](items ...*T) (t *T) { +// for _, item := range items { +// if item != nil { +// return item +// } +// } +// return nil +// } + +// UnNil will check if value is not nil it will return value but if it was nil returns defaultValue +func UnNil[T any](value *T, defaultValue T) T { + if value != nil { + return *value } - return nil + return defaultValue }