Skip to content

Commit

Permalink
add http handler
Browse files Browse the repository at this point in the history
  • Loading branch information
taoso committed Apr 9, 2024
1 parent a0c5d4c commit b133eb3
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 71 deletions.
75 changes: 4 additions & 71 deletions cmd/zns/main.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,29 @@
package main

import (
"bytes"
"crypto/tls"
"encoding/base64"
"flag"
"io"
"log"
"net"
"net/http"
"net/netip"
"os"
"strings"

"github.com/miekg/dns"
"github.com/taoso/zns"
"golang.org/x/crypto/acme/autocert"
)

var tlsCert string
var tlsKey string
var tlsHosts string
var listen string
var upstream string

func main() {
flag.StringVar(&tlsCert, "tls-cert", "", "tls cert file path")
flag.StringVar(&tlsKey, "tls-key", "", "tls key file path")
flag.StringVar(&tlsHosts, "tls-hosts", "", "tls host name")
flag.StringVar(&listen, "listen", ":443", "listen addr")
flag.StringVar(&upstream, "upstream", "https://doh.pub/dns-query", "DoH upstream URL")

flag.Parse()

Expand Down Expand Up @@ -54,71 +51,7 @@ func main() {
}

mux := http.NewServeMux()
mux.HandleFunc("/dns/{name}", func(w http.ResponseWriter, r *http.Request) {
var body []byte
var err error
if r.Method == http.MethodGet {
q := r.URL.Query().Get("dns")
body, err = base64.RawURLEncoding.DecodeString(q)
} else {
body, err = io.ReadAll(r.Body)
r.Body.Close()
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

var m dns.Msg
if err := m.Unpack(body); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

if m.IsEdns0() == nil {
ip, err := netip.ParseAddrPort(r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
addr := ip.Addr()
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
ecs := &dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET}
var bits int
if addr.Is4() {
bits = 24
ecs.Family = 1
} else {
bits = 48
ecs.Family = 2
}
ecs.SourceNetmask = uint8(bits)
p := netip.PrefixFrom(addr, bits)
ecs.Address = net.IP(p.Masked().Addr().AsSlice())
opt.Option = append(opt.Option, ecs)
m.Extra = append(m.Extra, opt)
}

if body, err = m.Pack(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

resp, err := http.Post("https://doh.pub/dns-query", "application/dns-message", bytes.NewReader(body))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()

body, err = io.ReadAll(resp.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

w.Write(body)
})
mux.Handle("/dns/{name}", zns.Handler{Upstream: upstream})

if err = http.Serve(lnTLS, mux); err != nil {
log.Fatal(err)
Expand Down
82 changes: 82 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package zns

import (
"bytes"
"encoding/base64"
"io"
"net"
"net/http"
"net/netip"

"github.com/miekg/dns"
)

type Handler struct {
Upstream string
}

func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var body []byte
var err error
if r.Method == http.MethodGet {
q := r.URL.Query().Get("dns")
body, err = base64.RawURLEncoding.DecodeString(q)
} else {
body, err = io.ReadAll(r.Body)
r.Body.Close()
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

var m dns.Msg
if err := m.Unpack(body); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

if m.IsEdns0() == nil {
ip, err := netip.ParseAddrPort(r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
addr := ip.Addr()
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
ecs := &dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET}
var bits int
if addr.Is4() {
bits = 24
ecs.Family = 1
} else {
bits = 48
ecs.Family = 2
}
ecs.SourceNetmask = uint8(bits)
p := netip.PrefixFrom(addr, bits)
ecs.Address = net.IP(p.Masked().Addr().AsSlice())
opt.Option = append(opt.Option, ecs)
m.Extra = append(m.Extra, opt)
}

if body, err = m.Pack(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

resp, err := http.Post(h.Upstream, "application/dns-message", bytes.NewReader(body))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()

body, err = io.ReadAll(resp.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

w.Write(body)
}

0 comments on commit b133eb3

Please sign in to comment.