|
|
@@ -0,0 +1,224 @@
|
|
|
+package main
|
|
|
+
|
|
|
+import (
|
|
|
+ "crypto/sha256"
|
|
|
+ "encoding/hex"
|
|
|
+ "encoding/json"
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
+ "log"
|
|
|
+ "net"
|
|
|
+ "net/http"
|
|
|
+ "os"
|
|
|
+ "strings"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/miekg/dns"
|
|
|
+)
|
|
|
+
|
|
|
+// Environment variables for the DNS server and TSIG key
|
|
|
+var (
|
|
|
+ dnsServer = os.Getenv("DNS_SERVER") // e.g., "127.0.0.1:53"
|
|
|
+ dnsTsigKey = os.Getenv("DNS_TSIG_KEY")
|
|
|
+ dnsTsigName = os.Getenv("DNS_TSIG_NAME") // e.g., "update-key"
|
|
|
+ dnsKeySalt = os.Getenv("DNS_KEY_SALT") // salt for TXT names
|
|
|
+)
|
|
|
+
|
|
|
+var RcodeNameError = errors.New("domain does not exist")
|
|
|
+
|
|
|
+// UpdateRequest represents the structure of the incoming HTTP request
|
|
|
+type UpdateRequest struct {
|
|
|
+ FQDN string `json:"fqdn"`
|
|
|
+ Key string `json:"key"`
|
|
|
+ IP string `json:"ip,omitempty"`
|
|
|
+}
|
|
|
+
|
|
|
+func main() {
|
|
|
+ if dnsServer == "" || dnsTsigKey == "" || dnsTsigName == "" {
|
|
|
+ log.Fatal("Missing required environment variables: DNS_SERVER, DNS_TSIG_KEY, DNS_TSIG_NAME")
|
|
|
+ }
|
|
|
+
|
|
|
+ http.HandleFunc("/update", handleUpdate)
|
|
|
+ log.Println("Server started on :8085")
|
|
|
+ log.Fatal(http.ListenAndServe(":8085", nil))
|
|
|
+}
|
|
|
+
|
|
|
+func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
|
|
+ if r.Method != http.MethodPost && r.Method != http.MethodGet {
|
|
|
+ http.Error(w, "Only POST and GET methods are allowed", http.StatusMethodNotAllowed)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ var req UpdateRequest
|
|
|
+ if r.Method == http.MethodPost {
|
|
|
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
|
+ http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ } else if r.Method == http.MethodGet {
|
|
|
+ req.FQDN = r.URL.Query().Get("fqdn")
|
|
|
+ req.Key = r.URL.Query().Get("key")
|
|
|
+ req.IP = r.URL.Query().Get("ip")
|
|
|
+ }
|
|
|
+
|
|
|
+ if req.FQDN == "" || req.Key == "" {
|
|
|
+ http.Error(w, "FQDN and Key are required", http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Default to the requester's IP if IP is not provided
|
|
|
+ if req.IP == "" {
|
|
|
+ req.IP = getRequesterIP(r)
|
|
|
+ log.Printf("using requester ip %s since no ip was provided\n", req.IP)
|
|
|
+ }
|
|
|
+
|
|
|
+ if net.ParseIP(req.IP) == nil {
|
|
|
+ http.Error(w, "Invalid IP address", http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ components := strings.Split(req.FQDN, ".")
|
|
|
+ zone := strings.Join( components[len(components)-2:], ".")
|
|
|
+ log.Printf("zone: %s, host: %s", zone, req.FQDN)
|
|
|
+
|
|
|
+ // Perform the DNS update asynchronously
|
|
|
+ err := updateDNS(zone, req.FQDN, req.Key, req.IP)
|
|
|
+ if err != nil {
|
|
|
+ log.Printf("Error updating DNS for %s: %v\n", req.FQDN, err)
|
|
|
+ w.WriteHeader(http.StatusInternalServerError)
|
|
|
+ fmt.Fprintf(w, "Failed to update DNS: %v\n", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ w.WriteHeader(http.StatusOK)
|
|
|
+ fmt.Fprintln(w, "DNS update successful")
|
|
|
+}
|
|
|
+
|
|
|
+func getRequesterIP(r *http.Request) string {
|
|
|
+ xForwardedFor := r.Header.Get("X-Forwarded-For")
|
|
|
+ if xForwardedFor != "" {
|
|
|
+ return strings.Split(xForwardedFor, ",")[0]
|
|
|
+ }
|
|
|
+ ip, _, _ := net.SplitHostPort(r.RemoteAddr)
|
|
|
+ return ip
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func updateDNS(zone, fqdn, key, ip string) error {
|
|
|
+ // Generate the hashed and salted TXT name
|
|
|
+ txtName := generateTXTName(fqdn)
|
|
|
+
|
|
|
+ // Validate the key using the DNS TXT record
|
|
|
+ savedKeyHash, err := getTXTRecord(txtName)
|
|
|
+ log.Printf("updating with txt record %s\n", txtName)
|
|
|
+ if err != nil {
|
|
|
+ if !errors.Is(err, RcodeNameError) {
|
|
|
+ return fmt.Errorf("failed to query TXT record: %w", err)
|
|
|
+ }
|
|
|
+ // Create the TXT record if it doesn't exist
|
|
|
+ hash := hashKey(key)
|
|
|
+ if err := createTXTRecord(zone, txtName, hash); err != nil {
|
|
|
+ return fmt.Errorf("failed to create TXT record: %w", err)
|
|
|
+ }
|
|
|
+ savedKeyHash = hash
|
|
|
+ }
|
|
|
+
|
|
|
+ if savedKeyHash != hashKey(key) {
|
|
|
+ return errors.New("authentication failed: invalid key")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Perform the DNS A record update
|
|
|
+ return updateARecord(zone, fqdn, ip)
|
|
|
+}
|
|
|
+
|
|
|
+func getTXTRecord(txtName string) (string, error) {
|
|
|
+ msg := new(dns.Msg)
|
|
|
+ msg.SetQuestion(txtName+".", dns.TypeTXT)
|
|
|
+ client := &dns.Client{}
|
|
|
+ log.Printf("sending txt request for %s to %s", txtName, dnsServer)
|
|
|
+ resp, _, err := client.Exchange(msg, dnsServer)
|
|
|
+ if err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+ if len(resp.Answer) == 0 {
|
|
|
+ return "", RcodeNameError
|
|
|
+ }
|
|
|
+ txt := resp.Answer[0].(*dns.TXT)
|
|
|
+ return strings.Join(txt.Txt, ""), nil
|
|
|
+}
|
|
|
+
|
|
|
+func createTXTRecord(zone, txtName, value string) error {
|
|
|
+ msg := new(dns.Msg)
|
|
|
+ log.Printf("creating txt record for %s with value %s\n", txtName, value)
|
|
|
+ msg.SetUpdate(zone+".")
|
|
|
+ msg.Insert([]dns.RR{
|
|
|
+ &dns.TXT{
|
|
|
+ Hdr: dns.RR_Header{
|
|
|
+ Name: txtName+".",
|
|
|
+ Rrtype: dns.TypeTXT,
|
|
|
+ Class: dns.ClassINET,
|
|
|
+ Ttl: 3600*24, // keys don't really change much
|
|
|
+ },
|
|
|
+ Txt: []string{value},
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ return sendMsg(msg)
|
|
|
+}
|
|
|
+
|
|
|
+func updateARecord(zone, fqdn, ip string) error {
|
|
|
+ msg := new(dns.Msg)
|
|
|
+ msg.SetUpdate(zone+".")
|
|
|
+ msg.RemoveName([]dns.RR{
|
|
|
+ &dns.A{
|
|
|
+ Hdr: dns.RR_Header{
|
|
|
+ Name: fqdn+".",
|
|
|
+ Rrtype: dns.TypeA,
|
|
|
+ Class: dns.ClassANY,
|
|
|
+ Ttl: 0,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ })
|
|
|
+ msg.Insert([]dns.RR{
|
|
|
+ &dns.A{
|
|
|
+ Hdr: dns.RR_Header{
|
|
|
+ Name: fqdn+".",
|
|
|
+ Rrtype: dns.TypeA,
|
|
|
+ Class: dns.ClassINET,
|
|
|
+ Ttl: 60,
|
|
|
+ },
|
|
|
+ A: net.ParseIP(ip),
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ return sendMsg(msg)
|
|
|
+}
|
|
|
+
|
|
|
+func sendMsg(msg *dns.Msg) error {
|
|
|
+ client := &dns.Client{}
|
|
|
+ signame := dnsTsigName+"."
|
|
|
+ client.TsigSecret = map[string]string{signame: dnsTsigKey}
|
|
|
+ msg.SetTsig(signame, dns.HmacSHA512, 300, time.Now().Unix())
|
|
|
+
|
|
|
+ res, _, err := client.Exchange(msg, dnsServer)
|
|
|
+ if (err != nil) {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if (res.MsgHdr.Rcode != dns.RcodeSuccess) {
|
|
|
+ return fmt.Errorf("Failure from DNS server: %s", dns.RcodeToString[res.MsgHdr.Rcode])
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func hashKey(key string) string {
|
|
|
+ h := sha256.Sum256([]byte(key))
|
|
|
+ return hex.EncodeToString(h[:])
|
|
|
+}
|
|
|
+
|
|
|
+func generateTXTName(fqdn string) string {
|
|
|
+ saltedFQDN := fmt.Sprintf("%s%s", fqdn, dnsKeySalt)
|
|
|
+ hostname := strings.SplitN(fqdn, ".", 2)[0]
|
|
|
+ domain := strings.Join(strings.SplitN(fqdn, ".", 2)[1:], ".")
|
|
|
+ hash := sha256.Sum256([]byte(saltedFQDN))
|
|
|
+ return fmt.Sprintf("%s_%s.%s", hostname, hex.EncodeToString(hash[:8]), domain) // hostname_hash.domain
|
|
|
+}
|
|
|
+
|