diff --git a/main.go b/main.go index 7979d46..f28486b 100644 --- a/main.go +++ b/main.go @@ -40,7 +40,8 @@ func Start(logger logrus.FieldLogger) error { defer close(sigc) log.Infoln("Start watching periodically for changes!") - go w.Run(time.NewTicker(viper.GetDuration("update-time")*time.Second), chClose, chErr) + t := time.NewTicker(viper.GetDuration("update-time")*time.Second) + go w.Run(t, chClose, chErr) for { select { diff --git a/pkg/providers/ovh/main.go b/pkg/providers/ovh/main.go index 1833977..a38d64c 100644 --- a/pkg/providers/ovh/main.go +++ b/pkg/providers/ovh/main.go @@ -18,8 +18,8 @@ type ovh struct { // NewOVH returns a new instance of the OVH provider func NewOVH(logger logrus.FieldLogger) (providers.Provider, error) { var ovhConfig utils.ProviderConfig - if c, ok := viper.GetStringMap("provider")["ovh"].(utils.ProviderConfig); ok { - ovhConfig = c + if c, ok := viper.GetStringMap("providers")["ovh"]; ok { + ovhConfig = c.(map[string]interface{}) } else { return nil, utils.ErrNilOvhConfig } @@ -52,6 +52,10 @@ func (ovh *ovh) UpdateIP(subdomain, ip string) error { req.SetBasicAuth(ovh.ovhConfig["username"].(string), ovh.ovhConfig["password"].(string)) // * perform GET request + logger.WithFields(logrus.Fields{ + "subdomain": subdomain, + "new-ip": ip, + }).Debugln("calling OVH DynHost to update subdomain IP") c := new(http.Client) resp, err := c.Do(req) if err != nil { @@ -63,5 +67,6 @@ func (ovh *ovh) UpdateIP(subdomain, ip string) error { return utils.ErrWrongStatusCode } + return nil } diff --git a/pkg/subdomain/main.go b/pkg/subdomain/main.go index 5ae92be..152be3a 100644 --- a/pkg/subdomain/main.go +++ b/pkg/subdomain/main.go @@ -2,24 +2,32 @@ package subdomain import ( "net" + "net/http" h "net/http" + "net/http/httptrace" + "strings" + "time" "github.com/datahearth/ddnsclient/pkg/utils" "github.com/sirupsen/logrus" ) -// HTTP is the base interface to interact with websites -type Subdomain interface { - CheckIPAddr(srvIP string) (bool, error) - GetSubdomainIP() string - retrieveSubdomainIP() error -} - -type subdomain struct { - logger logrus.FieldLogger - subdomainAddr string - ip string -} +type ( + PendingSubdomains map[time.Time]Subdomain + subdomain struct { + logger logrus.FieldLogger + subdomainAddr string + ip string + } + Subdomain interface { + CheckIPAddr(srvIP string) (bool, error) + GetSubdomainIP() string + retrieveSubdomainIP() error + GetSubdomainAddr() string + SubIsPending(sbs PendingSubdomains) bool + FindSubdomain(sbs PendingSubdomains) Subdomain + } +) // NewSubdomain instanciate a new http implementation func NewSubdomain(logger logrus.FieldLogger, subdomainAddr string) (Subdomain, error) { @@ -37,33 +45,61 @@ func NewSubdomain(logger logrus.FieldLogger, subdomainAddr string) (Subdomain, e // RetrieveSubdomainIP will retrieve the subdomain IP with a HEAD request func (sd *subdomain) retrieveSubdomainIP() error { + var remoteAddr string logger := sd.logger.WithField("component", "retrieve-subdomain-ip") - resp, err := h.Head(sd.subdomainAddr) + // * create HEAD request + req, err := http.NewRequest("HEAD", "https://"+sd.subdomainAddr, nil) if err != nil { - logger.WithError(err).WithField("subdomain", sd.subdomainAddr).Errorln(utils.ErrHeadRemoteIP.Error()) - return utils.ErrHeadRemoteIP + return err } - if resp.StatusCode != 200 { - logger.WithField("status-code", resp.StatusCode).Errorln(utils.ErrWrongStatusCode.Error()) + // * create a trace to get server remote address + trace := &httptrace.ClientTrace{ + GotConn: func(gci httptrace.GotConnInfo) { + remoteAddr = gci.Conn.RemoteAddr().String() + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + // * create a client to perform the request + client := new(h.Client) + client.Timeout = 5 * time.Second + + // * perform the request + resp, err := client.Do(req) + if err != nil { + // todo: ignoring errors is bad. Implement a solution to scrape 100% of the time remote addr + logger.WithError(err).WithFields(logrus.Fields{ + "subdomain": sd.subdomainAddr, + }).Errorln(utils.ErrHeadRemoteIP.Error()) + sd.ip = "" + return nil + } + if resp.StatusCode != 200 && remoteAddr == "" { + logger.WithFields(logrus.Fields{ + "status-code": resp.StatusCode, + "subdomain": sd.subdomainAddr, + }).Errorln(utils.ErrWrongStatusCode.Error()) return utils.ErrWrongStatusCode } - host, _, err := net.SplitHostPort(resp.Request.RemoteAddr) - if err != nil { - logger.WithError(err).WithField("remote-address", resp.Request.RemoteAddr).Errorln() - return utils.ErrSplitAddr + // * check if remote address contains a port + if strings.Contains(remoteAddr, ":") { + remoteAddr, _, err = net.SplitHostPort(remoteAddr) + if err != nil { + logger.WithError(err).WithField("remote-address", remoteAddr).Errorln(utils.ErrSplitAddr.Error()) + return utils.ErrSplitAddr + } } - sd.ip = host + sd.ip = remoteAddr return nil } -// CheckIPAddr will compare the srvIP passed in parameter and the subIP retrieved from the head request +// CheckIPAddr will compare the server IP and the subdomain IP func (sd *subdomain) CheckIPAddr(srvIP string) (bool, error) { if err := sd.retrieveSubdomainIP(); err != nil { - sd.logger.WithError(err).WithField("component", "check-ip-address").Errorln("failed to retrieve subdomain ip address") return false, err } @@ -74,6 +110,12 @@ func (sd *subdomain) CheckIPAddr(srvIP string) (bool, error) { return true, nil } +// GetSubdomainIP returns the subdomain IP func (sd *subdomain) GetSubdomainIP() string { return sd.ip } + +// GetSubdomainAddr returns the subdomain address +func (sd *subdomain) GetSubdomainAddr() string { + return sd.subdomainAddr +} diff --git a/pkg/subdomain/pending.go b/pkg/subdomain/pending.go new file mode 100644 index 0000000..0ec230d --- /dev/null +++ b/pkg/subdomain/pending.go @@ -0,0 +1,52 @@ +package subdomain + +import "time" + +// SubIsPending check if the current subdomain is waiting the DNS propagation. +func (sb *subdomain) SubIsPending(sbs PendingSubdomains) bool { + for _, sub := range sbs { + if sb == sub { + return true + } + } + + return false +} + +// CheckPendingSubdomains check if any pending subdomains are waiting to be restored. +// If so, it/they will be returned as a slice. +// If not, it returns nil. +func CheckPendingSubdomains(sbs PendingSubdomains, now time.Time) PendingSubdomains { + delSbs := make(PendingSubdomains) + for t, sb := range sbs { + if t.Add(5 * time.Minute).Before(now) { + delSbs[t] = sb + } + } + + if len(delSbs) < 1 { + return nil + } + + return delSbs +} + +// FindSubdomain returns a subdomain found in the pending map of subdomain. +// If not found, it returns nil. +func (sb *subdomain) FindSubdomain(sbs PendingSubdomains) Subdomain { + for _, sub := range sbs { + if sub == sb { + return sb + } + } + + return nil +} + +func DeletePendingSubdomains(delSbs PendingSubdomains, pending PendingSubdomains) PendingSubdomains { + for t := range delSbs { + delete(pending, t) + } + + return pending +} diff --git a/pkg/utils/types.go b/pkg/utils/types.go index 1561a8f..8415396 100644 --- a/pkg/utils/types.go +++ b/pkg/utils/types.go @@ -1,6 +1,8 @@ package utils -import "errors" +import ( + "errors" +) // * Errors var ( diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 35038a7..9350c3c 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -56,8 +56,8 @@ func SetupLogger(logger *logrus.Logger) { func AggregateSubdomains(subdomains []string, domain string) []string { agdSub := make([]string, len(subdomains)) - for _, sd := range subdomains { - agdSub = append(agdSub, sd+"."+domain) + for i, sd := range subdomains { + agdSub[i] = sd + "." + domain } return agdSub diff --git a/pkg/watcher/watcher.go b/pkg/watcher/watcher.go index 624164a..96a8438 100644 --- a/pkg/watcher/watcher.go +++ b/pkg/watcher/watcher.go @@ -1,6 +1,7 @@ package watcher import ( + "fmt" "time" "github.com/datahearth/ddnsclient/pkg/providers" @@ -15,11 +16,13 @@ type Watcher interface { } type watcher struct { - logger logrus.FieldLogger - provider providers.Provider - subdomains []subdomain.Subdomain - domain string - webIP string + logger logrus.FieldLogger + provider providers.Provider + subdomains []subdomain.Subdomain + domain string + webIP string + firstRun bool + pendingSubdomains subdomain.PendingSubdomains } func NewWatcher(logger logrus.FieldLogger, provider providers.Provider, webIP string) (Watcher, error) { @@ -35,28 +38,45 @@ func NewWatcher(logger logrus.FieldLogger, provider providers.Provider, webIP st logger = logger.WithField("pkg", "watcher") domain := viper.GetStringMap("watcher")["domain"].(string) - sbs := utils.AggregateSubdomains(viper.GetStringMap("watcher")["subdomains"].([]string), domain) + var sbs []string + if sb, ok := viper.GetStringMap("watcher")["subdomains"].([]interface{}); ok { + for _, v := range sb { + sbs = append(sbs, fmt.Sprint(v)) + } + } + + sbs = utils.AggregateSubdomains(sbs, domain) subdomains := make([]subdomain.Subdomain, len(sbs)) - for _, sd := range sbs { + for i, sd := range sbs { sub, err := subdomain.NewSubdomain(logger, sd) if err != nil { return nil, err } - subdomains = append(subdomains, sub) + subdomains[i] = sub } return &watcher{ - logger: logger, - provider: provider, - domain: domain, - subdomains: subdomains, - webIP: webIP, + logger: logger, + provider: provider, + domain: domain, + subdomains: subdomains, + webIP: webIP, + firstRun: true, + pendingSubdomains: make(map[time.Time]subdomain.Subdomain), }, nil } func (w *watcher) Run(t *time.Ticker, chClose chan struct{}, chErr chan error) { - logger := w.logger.WithField("component", "run") + logger := w.logger.WithField("component", "Run") + + go w.checkPendingSubdomains(chClose) + if w.firstRun { + if err := w.runDDNSCheck(); err != nil { + chErr <- err + } + w.firstRun = false + } for { select { @@ -65,40 +85,75 @@ func (w *watcher) Run(t *time.Ticker, chClose chan struct{}, chErr chan error) { logger.Infoln("Close watcher channel triggered. Ticker stoped") return case <-t.C: - logger.Infoln("Starting DDNS check") - srvIP, err := utils.RetrieveServerIP(w.webIP) - if err != nil { + if err := w.runDDNSCheck(); err != nil { chErr <- err - continue - } - - logger.WithField("server-ip", srvIP).Debugln("Server IP retrieved. Checking subdomains...") - for _, sd := range w.subdomains { - ok, err := sd.CheckIPAddr(srvIP) - if err != nil { - logger.WithError(err).WithField("server-ip", srvIP).Errorln("failed to check ip addresses") - chErr <- err - continue - } - if !ok { - subIP := sd.GetSubdomainIP() - logger.WithFields(logrus.Fields{ - "server-ip": srvIP, - "subdomain-ip": subIP, - }).Infoln("IP addresses doesn't match. Updating subdomain's ip...") - if err := w.provider.UpdateIP(subIP, srvIP); err != nil { - logger.WithError(err).WithFields(logrus.Fields{ - "server-ip": srvIP, - "subdomain-ip": subIP, - }).Errorln("failed to update subdomain's ip") - chErr <- err - continue - } - logger.WithFields(logrus.Fields{ - "server-ip": srvIP, - "subdomain-ip": subIP, - }).Infoln("Subdomain updated successfully!") - } + } + } + } +} + +func (w *watcher) runDDNSCheck() error { + logger := w.logger.WithField("component", "runDDNSCheck") + logger.Infoln("Starting DDNS check...") + srvIP, err := utils.RetrieveServerIP(w.webIP) + if err != nil { + return err + } + + for _, sb := range w.subdomains { + if sb.SubIsPending(w.pendingSubdomains) { + continue + } + + logger.Debugf("Checking subdomain %s...\n", sb.GetSubdomainAddr()) + ok, err := sb.CheckIPAddr(srvIP) + if err != nil { + return err + } + subAddr := sb.GetSubdomainAddr() + if !ok { + logger.WithFields(logrus.Fields{ + "server-ip": srvIP, + "subdomain-address": subAddr, + }).Infoln("IP addresses doesn't match. Updating subdomain's ip...") + if err := w.provider.UpdateIP(subAddr, srvIP); err != nil { + logger.WithError(err).WithFields(logrus.Fields{ + "server-ip": srvIP, + "subdomain-address": subAddr, + }).Errorln("failed to update subdomain's ip") + return err + } + logger.WithFields(logrus.Fields{ + "server-ip": srvIP, + "subdomain-address": subAddr, + }).Infoln("Subdomain's ip updated! Removing from checks for 5 mins") + + w.pendingSubdomains[time.Now()] = sb + + continue + } + logger.Debugf("%s is up to date. \n", subAddr) + } + + logger.Infoln("DDNS check finished") + return nil +} + +func (w *watcher) checkPendingSubdomains(chClose chan struct{}) { + logger := w.logger.WithField("component", "checkPendingSubdomains") + t := time.NewTicker(time.Second * 10) + + logger.Debugln("Start checking for pending subdomains...") + for { + select { + case <-chClose: + logger.Debugln("Close pending subdomains") + return + case <-t.C: + logger.Debugln("Checking pending subdomains...") + if delSbs := subdomain.CheckPendingSubdomains(w.pendingSubdomains, time.Now()); delSbs != nil { + w.pendingSubdomains = subdomain.DeletePendingSubdomains(delSbs, w.pendingSubdomains) + logger.Debugln("Pendings subdomains found. Cleaned.") } } }