main.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. package main
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/miekg/dns"
)
  16. // Environment variables for the DNS server and TSIG key
  17. var (
  18. dnsServer = os.Getenv("DNS_SERVER") // e.g., "127.0.0.1:53"
  19. dnsTsigKey = os.Getenv("DNS_TSIG_KEY")
  20. dnsTsigName = os.Getenv("DNS_TSIG_NAME") // e.g., "update-key"
  21. dnsKeySalt = os.Getenv("DNS_KEY_SALT") // salt for TXT names
  22. )
  23. var RcodeNameError = errors.New("domain does not exist")
  24. // UpdateRequest represents the structure of the incoming HTTP request
  25. type UpdateRequest struct {
  26. FQDN string `json:"fqdn"`
  27. Key string `json:"key"`
  28. IP string `json:"ip,omitempty"`
  29. }
  30. func main() {
  31. if dnsServer == "" || dnsTsigKey == "" || dnsTsigName == "" {
  32. log.Fatal("Missing required environment variables: DNS_SERVER, DNS_TSIG_KEY, DNS_TSIG_NAME")
  33. }
  34. http.HandleFunc("/update", handleUpdate)
  35. log.Println("Server started on :8085")
  36. log.Fatal(http.ListenAndServe(":8085", nil))
  37. }
  38. func handleUpdate(w http.ResponseWriter, r *http.Request) {
  39. if r.Method != http.MethodPost && r.Method != http.MethodGet {
  40. http.Error(w, "Only POST and GET methods are allowed", http.StatusMethodNotAllowed)
  41. return
  42. }
  43. var req UpdateRequest
  44. if r.Method == http.MethodPost {
  45. if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
  46. http.Error(w, "Invalid request body", http.StatusBadRequest)
  47. return
  48. }
  49. } else if r.Method == http.MethodGet {
  50. req.FQDN = r.URL.Query().Get("fqdn")
  51. req.Key = r.URL.Query().Get("key")
  52. req.IP = r.URL.Query().Get("ip")
  53. }
  54. if req.FQDN == "" || req.Key == "" {
  55. http.Error(w, "FQDN and Key are required", http.StatusBadRequest)
  56. return
  57. }
  58. // Default to the requester's IP if IP is not provided
  59. if req.IP == "" {
  60. req.IP = getRequesterIP(r)
  61. log.Printf("using requester ip %s since no ip was provided\n", req.IP)
  62. }
  63. if net.ParseIP(req.IP) == nil {
  64. http.Error(w, "Invalid IP address", http.StatusBadRequest)
  65. return
  66. }
  67. components := strings.Split(req.FQDN, ".")
  68. zone := strings.Join( components[len(components)-2:], ".")
  69. log.Printf("zone: %s, host: %s", zone, req.FQDN)
  70. // Perform the DNS update asynchronously
  71. err := updateDNS(zone, req.FQDN, req.Key, req.IP)
  72. if err != nil {
  73. log.Printf("Error updating DNS for %s: %v\n", req.FQDN, err)
  74. w.WriteHeader(http.StatusInternalServerError)
  75. fmt.Fprintf(w, "Failed to update DNS: %v\n", err)
  76. return
  77. }
  78. w.WriteHeader(http.StatusOK)
  79. fmt.Fprintln(w, "DNS update successful")
  80. }
  81. func getRequesterIP(r *http.Request) string {
  82. xForwardedFor := r.Header.Get("X-Forwarded-For")
  83. if xForwardedFor != "" {
  84. return strings.Split(xForwardedFor, ",")[0]
  85. }
  86. ip, _, _ := net.SplitHostPort(r.RemoteAddr)
  87. return ip
  88. }
  89. func updateDNS(zone, fqdn, key, ip string) error {
  90. // Generate the hashed and salted TXT name
  91. txtName := generateTXTName(fqdn)
  92. // Validate the key using the DNS TXT record
  93. savedKeyHash, err := getTXTRecord(txtName)
  94. log.Printf("updating with txt record %s\n", txtName)
  95. if err != nil {
  96. if !errors.Is(err, RcodeNameError) {
  97. return fmt.Errorf("failed to query TXT record: %w", err)
  98. }
  99. // Create the TXT record if it doesn't exist
  100. hash := hashKey(key)
  101. if err := createTXTRecord(zone, txtName, hash); err != nil {
  102. return fmt.Errorf("failed to create TXT record: %w", err)
  103. }
  104. savedKeyHash = hash
  105. }
  106. if savedKeyHash != hashKey(key) {
  107. return errors.New("authentication failed: invalid key")
  108. }
  109. // Perform the DNS A record update
  110. return updateARecord(zone, fqdn, ip)
  111. }
  112. func getTXTRecord(txtName string) (string, error) {
  113. msg := new(dns.Msg)
  114. msg.SetQuestion(txtName+".", dns.TypeTXT)
  115. client := &dns.Client{}
  116. log.Printf("sending txt request for %s to %s", txtName, dnsServer)
  117. resp, _, err := client.Exchange(msg, dnsServer)
  118. if err != nil {
  119. return "", err
  120. }
  121. if len(resp.Answer) == 0 {
  122. return "", RcodeNameError
  123. }
  124. txt := resp.Answer[0].(*dns.TXT)
  125. return strings.Join(txt.Txt, ""), nil
  126. }
  127. func createTXTRecord(zone, txtName, value string) error {
  128. msg := new(dns.Msg)
  129. log.Printf("creating txt record for %s with value %s\n", txtName, value)
  130. msg.SetUpdate(zone+".")
  131. msg.Insert([]dns.RR{
  132. &dns.TXT{
  133. Hdr: dns.RR_Header{
  134. Name: txtName+".",
  135. Rrtype: dns.TypeTXT,
  136. Class: dns.ClassINET,
  137. Ttl: 3600*24, // keys don't really change much
  138. },
  139. Txt: []string{value},
  140. },
  141. })
  142. return sendMsg(msg)
  143. }
  144. func updateARecord(zone, fqdn, ip string) error {
  145. msg := new(dns.Msg)
  146. msg.SetUpdate(zone+".")
  147. msg.RemoveName([]dns.RR{
  148. &dns.A{
  149. Hdr: dns.RR_Header{
  150. Name: fqdn+".",
  151. Rrtype: dns.TypeA,
  152. Class: dns.ClassANY,
  153. Ttl: 0,
  154. },
  155. },
  156. })
  157. msg.Insert([]dns.RR{
  158. &dns.A{
  159. Hdr: dns.RR_Header{
  160. Name: fqdn+".",
  161. Rrtype: dns.TypeA,
  162. Class: dns.ClassINET,
  163. Ttl: 60,
  164. },
  165. A: net.ParseIP(ip),
  166. },
  167. })
  168. return sendMsg(msg)
  169. }
  170. func sendMsg(msg *dns.Msg) error {
  171. client := &dns.Client{}
  172. signame := dnsTsigName+"."
  173. client.TsigSecret = map[string]string{signame: dnsTsigKey}
  174. msg.SetTsig(signame, dns.HmacSHA512, 300, time.Now().Unix())
  175. res, _, err := client.Exchange(msg, dnsServer)
  176. if (err != nil) {
  177. return err
  178. }
  179. if (res.MsgHdr.Rcode != dns.RcodeSuccess) {
  180. return fmt.Errorf("Failure from DNS server: %s", dns.RcodeToString[res.MsgHdr.Rcode])
  181. }
  182. return nil
  183. }
  184. func hashKey(key string) string {
  185. h := sha256.Sum256([]byte(key))
  186. return hex.EncodeToString(h[:])
  187. }
  188. func generateTXTName(fqdn string) string {
  189. saltedFQDN := fmt.Sprintf("%s%s", fqdn, dnsKeySalt)
  190. hostname := strings.SplitN(fqdn, ".", 2)[0]
  191. domain := strings.Join(strings.SplitN(fqdn, ".", 2)[1:], ".")
  192. hash := sha256.Sum256([]byte(saltedFQDN))
  193. return fmt.Sprintf("%s_%s.%s", hostname, hex.EncodeToString(hash[:8]), domain) // hostname_hash.domain
  194. }