main.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package main
  2. import (
  3. "crypto/sha256"
  4. "encoding/hex"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "log"
  9. "net"
  10. "net/http"
  11. "os"
  12. "strings"
  13. "time"
  14. "github.com/miekg/dns"
  15. )
  // Process configuration, read once from the environment when the program
// starts. main refuses to run unless DNS_SERVER, DNS_TSIG_KEY and
// DNS_TSIG_NAME are all non-empty; DNS_KEY_SALT may be left unset.
var (
	// dnsServer is the host:port that TXT lookups and signed updates are
	// sent to, e.g. "127.0.0.1:53".
	dnsServer = os.Getenv("DNS_SERVER")

	// dnsTsigName is the TSIG key name without its trailing dot, e.g.
	// "update-key"; sendMsg appends the dot.
	dnsTsigName = os.Getenv("DNS_TSIG_NAME")

	// dnsTsigKey is the TSIG shared secret paired with dnsTsigName
	// (presumably base64-encoded, as miekg/dns expects — confirm).
	dnsTsigKey = os.Getenv("DNS_TSIG_KEY")

	// dnsKeySalt is mixed into the hash that names each host's key TXT
	// record (see generateTXTName).
	dnsKeySalt = os.Getenv("DNS_KEY_SALT")
)
  // RcodeNameError is the sentinel getTXTRecord returns when the key TXT
// record for a host is not found. updateDNS checks for it with errors.Is and
// treats it as "host not claimed yet".
//
// NOTE(review): Go convention would name this ErrXxx; the name is kept so
// existing references keep compiling.
var (
	RcodeNameError = errors.New("domain does not exist")
)
  // UpdateRequest carries one address-update call in a transport-neutral form.
// It is decoded from a JSON body or filled from query parameters by
// handleUpdate, or built from dyndns-style parameters by handleUpdateDyndns.
type UpdateRequest struct {
	// FQDN is the host name whose A record should change.
	FQDN string `json:"fqdn"`

	// Key is the caller's secret; its SHA-256 hash must match the hash
	// stored in the host's key TXT record.
	Key string `json:"key"`

	// IP is the address to publish. Optional: when empty, update falls back
	// to the requester's address.
	IP string `json:"ip,omitempty"`
}
  30. func main() {
  31. if dnsServer == "" || dnsTsigKey == "" || dnsTsigName == "" {
  32. log.Fatal("Missing required environment variables: DNS_SERVER, DNS_TSIG_KEY, DNS_TSIG_NAME")
  33. }
  34. http.HandleFunc("/update", handleUpdate)
  35. http.HandleFunc("/v3/update", handleUpdateDyndns)
  36. http.HandleFunc("/nic/update", handleUpdateDyndns)
  37. log.Println("Server started on :8085")
  38. log.Fatal(http.ListenAndServe(":8085", nil))
  39. }
  40. func handleUpdate(w http.ResponseWriter, r *http.Request) {
  41. if r.Method != http.MethodPost && r.Method != http.MethodGet {
  42. http.Error(w, "Only POST and GET methods are allowed", http.StatusMethodNotAllowed)
  43. return
  44. }
  45. var req UpdateRequest
  46. if r.Method == http.MethodPost {
  47. if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
  48. http.Error(w, "Invalid request body", http.StatusBadRequest)
  49. return
  50. }
  51. } else if r.Method == http.MethodGet {
  52. req.FQDN = r.URL.Query().Get("fqdn")
  53. req.Key = r.URL.Query().Get("key")
  54. req.IP = r.URL.Query().Get("ip")
  55. }
  56. err := update(r, req)
  57. if err != nil {
  58. http.Error(w, "Failed DNS Update: "+err.Error() , http.StatusInternalServerError)
  59. return
  60. }
  61. w.WriteHeader(http.StatusOK)
  62. fmt.Fprintln(w, "DNS update successful")
  63. }
  64. func handleUpdateDyndns(w http.ResponseWriter, r *http.Request) {
  65. if r.Method != http.MethodGet {
  66. http.Error(w, "Only GET methods are allowed", http.StatusMethodNotAllowed)
  67. return
  68. }
  69. var req UpdateRequest
  70. username, password, ok := r.BasicAuth()
  71. if !ok {
  72. http.Error(w, "User and Password required", http.StatusUnauthorized)
  73. return
  74. }
  75. req.FQDN = r.URL.Query().Get("hostname")
  76. req.IP = r.URL.Query().Get("myip")
  77. req.Key = username+password
  78. err := update(r, req)
  79. if err != nil {
  80. http.Error(w, "dnserr\n"+err.Error(), http.StatusInternalServerError)
  81. return
  82. }
  83. w.WriteHeader(http.StatusOK)
  84. fmt.Fprintln(w, "good")
  85. }
  86. func update(r *http.Request, req UpdateRequest) error {
  87. if req.FQDN == "" || req.Key == "" {
  88. return errors.New("FQDN and Key are required")
  89. }
  90. // Default to the requester's IP if IP is not provided
  91. if req.IP == "" {
  92. req.IP = getRequesterIP(r)
  93. log.Printf("using requester ip %s since no ip was provided\n", req.IP)
  94. }
  95. if net.ParseIP(req.IP) == nil {
  96. return errors.New("Invalid IP address")
  97. }
  98. components := strings.Split(req.FQDN, ".")
  99. zone := strings.Join( components[len(components)-2:], ".")
  100. log.Printf("zone: %s, host: %s", zone, req.FQDN)
  101. // Perform the DNS update asynchronously
  102. err := updateDNS(zone, req.FQDN, req.Key, req.IP)
  103. if err != nil {
  104. log.Printf("Error updating DNS for %s: %v\n", req.FQDN, err)
  105. return err
  106. }
  107. return nil
  108. }
  // getRequesterIP returns the client address used when a request names no IP.
//
// When X-Forwarded-For is present, its first (left-most) entry is returned
// with surrounding whitespace removed. That header is client-controlled
// unless a trusted proxy overwrites it. This grants nothing extra, since the
// caller can pass any IP explicitly anyway. Otherwise the host part of
// r.RemoteAddr is returned. If RemoteAddr has no port, it is returned as-is
// instead of "".
func getRequesterIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		if comma := strings.IndexByte(xff, ','); comma >= 0 {
			first = xff[:comma]
		}
		return strings.TrimSpace(first)
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
  117. func updateDNS(zone, fqdn, key, ip string) error {
  118. // Generate the hashed and salted TXT name
  119. txtName := generateTXTName(fqdn)
  120. // Validate the key using the DNS TXT record
  121. savedKeyHash, err := getTXTRecord(txtName)
  122. log.Printf("updating with txt record %s\n", txtName)
  123. if err != nil {
  124. if !errors.Is(err, RcodeNameError) {
  125. return fmt.Errorf("failed to query TXT record: %w", err)
  126. }
  127. // Create the TXT record if it doesn't exist
  128. hash := hashKey(key)
  129. if err := createTXTRecord(zone, txtName, hash); err != nil {
  130. return fmt.Errorf("failed to create TXT record: %w", err)
  131. }
  132. savedKeyHash = hash
  133. }
  134. if savedKeyHash != hashKey(key) {
  135. return errors.New("authentication failed: invalid key")
  136. }
  137. // Perform the DNS A record update
  138. return updateARecord(zone, fqdn, ip)
  139. }
  140. func getTXTRecord(txtName string) (string, error) {
  141. msg := new(dns.Msg)
  142. msg.SetQuestion(txtName+".", dns.TypeTXT)
  143. client := &dns.Client{}
  144. log.Printf("sending txt request for %s to %s", txtName, dnsServer)
  145. resp, _, err := client.Exchange(msg, dnsServer)
  146. if err != nil {
  147. return "", err
  148. }
  149. if len(resp.Answer) == 0 {
  150. return "", RcodeNameError
  151. }
  152. txt := resp.Answer[0].(*dns.TXT)
  153. return strings.Join(txt.Txt, ""), nil
  154. }
  155. func createTXTRecord(zone, txtName, value string) error {
  156. msg := new(dns.Msg)
  157. log.Printf("creating txt record for %s with value %s\n", txtName, value)
  158. msg.SetUpdate(zone+".")
  159. msg.Insert([]dns.RR{
  160. &dns.TXT{
  161. Hdr: dns.RR_Header{
  162. Name: txtName+".",
  163. Rrtype: dns.TypeTXT,
  164. Class: dns.ClassINET,
  165. Ttl: 3600*24, // keys don't really change much
  166. },
  167. Txt: []string{value},
  168. },
  169. })
  170. return sendMsg(msg)
  171. }
  172. func updateARecord(zone, fqdn, ip string) error {
  173. msg := new(dns.Msg)
  174. msg.SetUpdate(zone+".")
  175. msg.RemoveName([]dns.RR{
  176. &dns.A{
  177. Hdr: dns.RR_Header{
  178. Name: fqdn+".",
  179. Rrtype: dns.TypeA,
  180. Class: dns.ClassANY,
  181. Ttl: 0,
  182. },
  183. },
  184. })
  185. msg.Insert([]dns.RR{
  186. &dns.A{
  187. Hdr: dns.RR_Header{
  188. Name: fqdn+".",
  189. Rrtype: dns.TypeA,
  190. Class: dns.ClassINET,
  191. Ttl: 60,
  192. },
  193. A: net.ParseIP(ip),
  194. },
  195. })
  196. return sendMsg(msg)
  197. }
  198. func sendMsg(msg *dns.Msg) error {
  199. client := &dns.Client{}
  200. signame := dnsTsigName+"."
  201. client.TsigSecret = map[string]string{signame: dnsTsigKey}
  202. msg.SetTsig(signame, dns.HmacSHA512, 300, time.Now().Unix())
  203. res, _, err := client.Exchange(msg, dnsServer)
  204. if (err != nil) {
  205. return err
  206. }
  207. if (res.MsgHdr.Rcode != dns.RcodeSuccess) {
  208. return fmt.Errorf("Failure from DNS server: %s", dns.RcodeToString[res.MsgHdr.Rcode])
  209. }
  210. return nil
  211. }
  // hashKey returns the SHA-256 digest of key as 64 lowercase hex characters.
// This is the value stored in, and compared against, a host's key TXT record.
func hashKey(key string) string {
	digest := sha256.Sum256([]byte(key))
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return string(encoded)
}
  216. func generateTXTName(fqdn string) string {
  217. saltedFQDN := fmt.Sprintf("%s%s", fqdn, dnsKeySalt)
  218. hostname := strings.SplitN(fqdn, ".", 2)[0]
  219. domain := strings.Join(strings.SplitN(fqdn, ".", 2)[1:], ".")
  220. hash := sha256.Sum256([]byte(saltedFQDN))
  221. return fmt.Sprintf("%s_%s.%s", hostname, hex.EncodeToString(hash[:8]), domain) // hostname_hash.domain
  222. }