Skip to content

Commit

Permalink
Merge pull request #12 from FMotalleb/feat-provider-params
Browse files Browse the repository at this point in the history
[Feat]<ResolverParams>
  • Loading branch information
FMotalleb authored Aug 5, 2023
2 parents ef370cf + 16f8efc commit 2917695
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 42 deletions.
6 changes: 4 additions & 2 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ rules:
matcherParams:
- (.*\.)?accounts\.ea\.com.*
- (.*\.)?signin\.ea.com\.*
raw:
A: "{{ .address }} 60 IN A 50.7.87.85"
resolver: "shecan"
resolverParams: ea.com.
# raw:
# A: "{{ .address }} 60 IN A 50.7.87.85"

- name: Ea proxy
matcher: regex
Expand Down
13 changes: 7 additions & 6 deletions lib/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import (

// Rule set of rules to find resolver of each request
type Rule struct {
Name *string `yaml:"name"`
Matcher string `yaml:"matcher"`
MatcherParams []string `yaml:"matcherParams"`
Resolver *string `yaml:"resolver"`
Raw *map[string]string `yaml:"raw"`
IsBlocked bool `yaml:"isBlocked,alias:blocked"`
Name *string `yaml:"name"`
Matcher string `yaml:"matcher"`
MatcherParams []string `yaml:"matcherParams"`
Resolver *string `yaml:"resolver"`
ResolverParams *string `yaml:"resolverParams"`
Raw *map[string]string `yaml:"raw"`
IsBlocked bool `yaml:"isBlocked,alias:blocked"`
}

func (r *Rule) String() string {
Expand Down
119 changes: 85 additions & 34 deletions lib/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ package utils

import (
"bytes"
"github.com/FMotalleb/dns-reverse-proxy-docker/lib/rule"
"html/template"
"net"
"strings"

"github.com/FMotalleb/dns-reverse-proxy-docker/lib/config"
"github.com/FMotalleb/dns-reverse-proxy-docker/lib/provider"
"github.com/miekg/dns"
"github.com/rs/zerolog/log"
)
Expand All @@ -17,51 +18,93 @@ import (
// The handler returns a response, which is then written to the ResponseWriter.
func HandleRequest(c config.Config, w dns.ResponseWriter, req *dns.Msg) {
log.Debug().Msgf("received request to find `%v`", req.Question)
if len(req.Question) == 0 || !allowed(c.Global.AllowTransfer, w, req) {
responseErrorToRequest(w, req)
return
defer recoverDNSResponse(w, req)
if !allowed(c.Global.AllowTransfer, w, req) {
log.Panic().Msgf("received request is not allowed")
}
if len(req.Question) == 0 {
log.Panic().Msgf("received request has no question")
}
requestHostname := req.Question[0].Name
log.Debug().Msgf("received request to find `%s`", requestHostname)
r := c.FindRuleFor(requestHostname)
dnsProvider := provider.Provider{}
dnsProvider := *c.GetDefaultProvider()
switch {
case r.IsBlocked:
log.Debug().Msgf("blocking `%s`", requestHostname)
log.Panic().Msgf("blocked request")
return
case r == nil:
dnsProvider = *c.GetDefaultProvider()
log.Debug().Msgf("no handler found for `%s`, will use default handler", requestHostname)
case r.Resolver != nil:
dnsProvider = *c.FindProvider(*r.Resolver)
log.Debug().Msgf("handler found for `%s`, will use %v, request: %v", requestHostname, dnsProvider, UnNil(r.ResolverParams, requestHostname))
case r.Raw != nil:
mapper := make(map[string]string, 0)
mapper["address"] = requestHostname
raw := r.GetRaw(dns.TypeToString[req.Question[0].Qtype])
if raw == nil {
log.Error().Msgf("%s not supported in the config", dns.TypeToString[req.Question[0].Qtype])
responseErrorToRequest(w, req)
return
}
msg, err := dns.NewRR(formatString(*raw, mapper))
if err != nil {
log.Debug().Msgf("cannot parse raw response: %v", err)
}
if msg != nil {
result := make([]dns.RR, 0)
result = append(result, msg)
req.Answer = result
log.Info().Msgf("cannot parse raw response: %v", req)
_ = w.WriteMsg(req)
if handleRawResponse(requestHostname, r, req, w) {
return
}
}
log.Debug().Msgf("no handler found for `%s`, will use default handler", requestHostname)

transport := "udp"
if _, ok := w.RemoteAddr().(*net.TCPAddr); ok {
transport = "tcp"
}
if r != nil && r.ResolverParams != nil {
changeRequestAddress(req, *r.ResolverParams)
}
resp := dnsProvider.Handle(transport, req)
if r != nil && r.ResolverParams != nil {
changeResponseAddress(resp, requestHostname)
}
_ = w.WriteMsg(resp)
}

func responseErrorToRequest(w dns.ResponseWriter, r *dns.Msg) {
func handleRawResponse(requestHostname string, r *rule.Rule, req *dns.Msg, w dns.ResponseWriter) bool {
mapper := make(map[string]string, 0)
mapper["address"] = requestHostname
raw := r.GetRaw(dns.TypeToString[req.Question[0].Qtype])
if raw == nil {
log.Error().Msgf("%s not supported in the config, continue using default handler", dns.TypeToString[req.Question[0].Qtype])
return false
}
msg, err := dns.NewRR(formatString(*raw, mapper))
if err != nil {
log.Debug().Msgf("cannot parse raw response: %v", err)
return false
}
if msg != nil {
result := make([]dns.RR, 0)
result = append(result, msg)
req.Answer = result
log.Info().Msgf("cannot parse raw response: %v", req)
_ = w.WriteMsg(req)
return true
}
return false
}
func recoverDNSResponse(w dns.ResponseWriter, req *dns.Msg) {
err := recover()
if err != nil {
log.Error().Msgf("Recovering from: %v", err)
reject(w, req)
}
}
func changeRequestAddress(req *dns.Msg, newAddress string) *dns.Msg {
req.Question[0].Name = newAddress
return req
}
func changeResponseAddress(req *dns.Msg, newAddress string) *dns.Msg {
req.Question[0].Name = newAddress
ans := req.Answer[0]
ansStr := strings.Replace(ans.String(), ans.Header().Name, newAddress, 1)
result, err := dns.NewRR(ansStr)
req.Answer[0] = result
if err != nil {
log.Panic().Msgf("faced an error when tried to change answer \nfrom:%v\nto:%v\nerror:%v", req.Answer[0], ansStr, err)
}
return req
}
func reject(w dns.ResponseWriter, r *dns.Msg) {
msg := makeErrorMessage(r)
_ = w.WriteMsg(msg)
}
Expand Down Expand Up @@ -98,22 +141,30 @@ func isTransfer(req *dns.Msg) bool {
func formatString(text string, hashmap map[string]string) string {
tmpl, err := template.New("Mapper").Parse(text)
if err != nil {
panic(err)
log.Panic().Msgf("failed to use parse template: %s\nerror: %v", text, err)
}
writer := bytes.NewBuffer(nil)
err = tmpl.Execute(writer, hashmap)
if err != nil {
panic(err)
log.Panic().Msgf("failed to use template for %s\nerror: %v\nhashmap:%v", text, err, hashmap)
}
return writer.String()
}

// FindFirst non-null value in items
func FindFirst[T any](items ...*T) (t *T) {
for _, item := range items {
if item != nil {
return item
}
// func FindFirst[T any](items ...*T) (t *T) {
// for _, item := range items {
// if item != nil {
// return item
// }
// }
// return nil
// }

// UnNil will check if value is not nil it will return value but if it was nil returns defaultValue
func UnNil[T any](value *T, defaultValue T) T {
if value != nil {
return *value
}
return nil
return defaultValue
}

0 comments on commit 2917695

Please sign in to comment.