Skip to content

Commit

Permalink
Cost ticket for query
Browse files Browse the repository at this point in the history
  • Loading branch information
taoso committed Apr 10, 2024
1 parent a9d9d02 commit ff8d7cb
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
9 changes: 8 additions & 1 deletion cmd/zns/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ var tlsKey string
var tlsHosts string
var listen string
var upstream string
var dbPath string

func main() {
flag.StringVar(&tlsCert, "tls-cert", "", "tls cert file path")
flag.StringVar(&tlsKey, "tls-key", "", "tls key file path")
flag.StringVar(&tlsHosts, "tls-hosts", "", "tls host name")
flag.StringVar(&listen, "listen", ":443", "listen addr")
flag.StringVar(&upstream, "upstream", "https://doh.pub/dns-query", "DoH upstream URL")
flag.StringVar(&dbPath, "db", "", "sqlite database file path")

flag.Parse()

Expand All @@ -50,8 +52,13 @@ func main() {
panic(err)
}

repo := zns.NewTicketRepo(dbPath)
repo.New("foo", 2048, "pay-1")

h := zns.Handler{Upstream: upstream, Repo: repo}

mux := http.NewServeMux()
mux.Handle("/dns/{name}", zns.Handler{Upstream: upstream})
mux.Handle("/dns/{token}", h)

if err = http.Serve(lnTLS, mux); err != nil {
log.Fatal(err)
Expand Down
39 changes: 30 additions & 9 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,32 @@ import (

type Handler struct {
Upstream string
Repo TicketRepo
}

func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var body []byte
var err error
token := r.PathValue("token")
if token == "" {
http.Error(w, "invalid token", http.StatusUnauthorized)
return
}

ts, err := h.Repo.List(token, 1)
if err != nil {
http.Error(w, "invalid token", http.StatusInternalServerError)
return
}
if len(ts) == 0 || ts[0].Bytes <= 0 {
http.Error(w, "invalid token", http.StatusUnauthorized)
return
}

var question []byte
if r.Method == http.MethodGet {
q := r.URL.Query().Get("dns")
body, err = base64.RawURLEncoding.DecodeString(q)
question, err = base64.RawURLEncoding.DecodeString(q)
} else {
body, err = io.ReadAll(r.Body)
question, err = io.ReadAll(r.Body)
r.Body.Close()
}
if err != nil {
Expand All @@ -31,7 +47,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

var m dns.Msg
if err := m.Unpack(body); err != nil {
if err := m.Unpack(question); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
Expand Down Expand Up @@ -70,23 +86,28 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.Extra = append(m.Extra, opt)
}

if body, err = m.Pack(); err != nil {
if question, err = m.Pack(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

resp, err := http.Post(h.Upstream, "application/dns-message", bytes.NewReader(body))
resp, err := http.Post(h.Upstream, "application/dns-message", bytes.NewReader(question))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()

body, err = io.ReadAll(resp.Body)
answer, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

w.Write(body)
if err = h.Repo.Cost(token, len(question)+len(answer)); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}

w.Write(answer)
}

0 comments on commit ff8d7cb

Please sign in to comment.