||
- package main
- import (
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "log"
- "net"
- "net/http"
- "os"
- "strings"
- "time"
- "github.com/miekg/dns"
- )
- // Environment variables for the DNS server and TSIG key
- var (
- dnsServer = os.Getenv("DNS_SERVER") // e.g., "127.0.0.1:53"
- dnsTsigKey = os.Getenv("DNS_TSIG_KEY")
- dnsTsigName = os.Getenv("DNS_TSIG_NAME") // e.g., "update-key"
- dnsKeySalt = os.Getenv("DNS_KEY_SALT") // salt for TXT names
- )
- var RcodeNameError = errors.New("domain does not exist")
- // UpdateRequest represents the structure of the incoming HTTP request
- type UpdateRequest struct {
- FQDN string `json:"fqdn"`
- Key string `json:"key"`
- IP string `json:"ip,omitempty"`
- }
- func main() {
- if dnsServer == "" || dnsTsigKey == "" || dnsTsigName == "" {
- log.Fatal("Missing required environment variables: DNS_SERVER, DNS_TSIG_KEY, DNS_TSIG_NAME")
- }
- http.HandleFunc("/update", handleUpdate)
- http.HandleFunc("/v3/update", handleUpdateDyndns)
- http.HandleFunc("/nic/update", handleUpdateDyndns)
- log.Println("Server started on :8085")
- log.Fatal(http.ListenAndServe(":8085", nil))
- }
- func handleUpdate(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodPost && r.Method != http.MethodGet {
- http.Error(w, "Only POST and GET methods are allowed", http.StatusMethodNotAllowed)
- return
- }
- var req UpdateRequest
- if r.Method == http.MethodPost {
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- http.Error(w, "Invalid request body", http.StatusBadRequest)
- return
- }
- } else if r.Method == http.MethodGet {
- req.FQDN = r.URL.Query().Get("fqdn")
- req.Key = r.URL.Query().Get("key")
- req.IP = r.URL.Query().Get("ip")
- }
- err := update(r, req)
- if err != nil {
- http.Error(w, "Failed DNS Update: "+err.Error() , http.StatusInternalServerError)
- return
- }
- w.WriteHeader(http.StatusOK)
- fmt.Fprintln(w, "DNS update successful")
- }
- func handleUpdateDyndns(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodGet {
- http.Error(w, "Only GET methods are allowed", http.StatusMethodNotAllowed)
- return
- }
- var req UpdateRequest
- username, password, ok := r.BasicAuth()
- if !ok {
- http.Error(w, "User and Password required", http.StatusUnauthorized)
- return
- }
- req.FQDN = r.URL.Query().Get("hostname")
- req.IP = r.URL.Query().Get("myip")
- req.Key = username+password
- err := update(r, req)
- if err != nil {
- http.Error(w, "dnserr\n"+err.Error(), http.StatusInternalServerError)
- return
- }
- w.WriteHeader(http.StatusOK)
- fmt.Fprintln(w, "good")
- }
- func update(r *http.Request, req UpdateRequest) error {
- if req.FQDN == "" || req.Key == "" {
- return errors.New("FQDN and Key are required")
- }
- // Default to the requester's IP if IP is not provided
- if req.IP == "" {
- req.IP = getRequesterIP(r)
- log.Printf("using requester ip %s since no ip was provided\n", req.IP)
- }
- if net.ParseIP(req.IP) == nil {
- return errors.New("Invalid IP address")
- }
- components := strings.Split(req.FQDN, ".")
- zone := strings.Join( components[len(components)-2:], ".")
- log.Printf("zone: %s, host: %s", zone, req.FQDN)
- // Perform the DNS update asynchronously
- err := updateDNS(zone, req.FQDN, req.Key, req.IP)
- if err != nil {
- log.Printf("Error updating DNS for %s: %v\n", req.FQDN, err)
- return err
- }
- return nil
- }
- func getRequesterIP(r *http.Request) string {
- xForwardedFor := r.Header.Get("X-Forwarded-For")
- if xForwardedFor != "" {
- return strings.Split(xForwardedFor, ",")[0]
- }
- ip, _, _ := net.SplitHostPort(r.RemoteAddr)
- return ip
- }
- func updateDNS(zone, fqdn, key, ip string) error {
- // Generate the hashed and salted TXT name
- txtName := generateTXTName(fqdn)
- // Validate the key using the DNS TXT record
- savedKeyHash, err := getTXTRecord(txtName)
- log.Printf("updating with txt record %s\n", txtName)
- if err != nil {
- if !errors.Is(err, RcodeNameError) {
- return fmt.Errorf("failed to query TXT record: %w", err)
- }
- // Create the TXT record if it doesn't exist
- hash := hashKey(key)
- if err := createTXTRecord(zone, txtName, hash); err != nil {
- return fmt.Errorf("failed to create TXT record: %w", err)
- }
- savedKeyHash = hash
- }
- if savedKeyHash != hashKey(key) {
- return errors.New("authentication failed: invalid key")
- }
- // Perform the DNS A record update
- return updateARecord(zone, fqdn, ip)
- }
- func getTXTRecord(txtName string) (string, error) {
- msg := new(dns.Msg)
- msg.SetQuestion(txtName+".", dns.TypeTXT)
- client := &dns.Client{}
- log.Printf("sending txt request for %s to %s", txtName, dnsServer)
- resp, _, err := client.Exchange(msg, dnsServer)
- if err != nil {
- return "", err
- }
- if len(resp.Answer) == 0 {
- return "", RcodeNameError
- }
- txt := resp.Answer[0].(*dns.TXT)
- return strings.Join(txt.Txt, ""), nil
- }
- func createTXTRecord(zone, txtName, value string) error {
- msg := new(dns.Msg)
- log.Printf("creating txt record for %s with value %s\n", txtName, value)
- msg.SetUpdate(zone+".")
- msg.Insert([]dns.RR{
- &dns.TXT{
- Hdr: dns.RR_Header{
- Name: txtName+".",
- Rrtype: dns.TypeTXT,
- Class: dns.ClassINET,
- Ttl: 3600*24, // keys don't really change much
- },
- Txt: []string{value},
- },
- })
- return sendMsg(msg)
- }
- func updateARecord(zone, fqdn, ip string) error {
- msg := new(dns.Msg)
- msg.SetUpdate(zone+".")
- msg.RemoveName([]dns.RR{
- &dns.A{
- Hdr: dns.RR_Header{
- Name: fqdn+".",
- Rrtype: dns.TypeA,
- Class: dns.ClassANY,
- Ttl: 0,
- },
- },
- })
- msg.Insert([]dns.RR{
- &dns.A{
- Hdr: dns.RR_Header{
- Name: fqdn+".",
- Rrtype: dns.TypeA,
- Class: dns.ClassINET,
- Ttl: 60,
- },
- A: net.ParseIP(ip),
- },
- })
- return sendMsg(msg)
- }
- func sendMsg(msg *dns.Msg) error {
- client := &dns.Client{}
- signame := dnsTsigName+"."
- client.TsigSecret = map[string]string{signame: dnsTsigKey}
- msg.SetTsig(signame, dns.HmacSHA512, 300, time.Now().Unix())
- res, _, err := client.Exchange(msg, dnsServer)
- if (err != nil) {
- return err
- }
- if (res.MsgHdr.Rcode != dns.RcodeSuccess) {
- return fmt.Errorf("Failure from DNS server: %s", dns.RcodeToString[res.MsgHdr.Rcode])
- }
- return nil
- }
- func hashKey(key string) string {
- h := sha256.Sum256([]byte(key))
- return hex.EncodeToString(h[:])
- }
- func generateTXTName(fqdn string) string {
- saltedFQDN := fmt.Sprintf("%s%s", fqdn, dnsKeySalt)
- hostname := strings.SplitN(fqdn, ".", 2)[0]
- domain := strings.Join(strings.SplitN(fqdn, ".", 2)[1:], ".")
- hash := sha256.Sum256([]byte(saltedFQDN))
- return fmt.Sprintf("%s_%s.%s", hostname, hex.EncodeToString(hash[:8]), domain) // hostname_hash.domain
- }
|