/*
Copyright
*/
package main

import (
	"crypto/tls"
	"encoding/json"
	"errors"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"reflect"
	"regexp"
	"sort"
	"sync"
	"syscall"
	"time"

	log "github.com/Sirupsen/logrus"
	"github.com/codegangsta/negroni"
	"github.com/containous/oxy/cbreaker"
	"github.com/containous/oxy/forward"
	"github.com/containous/oxy/roundrobin"
	"github.com/containous/traefik/middlewares"
	"github.com/containous/traefik/provider"
	"github.com/containous/traefik/types"
	"github.com/gorilla/mux"
	"github.com/mailgun/manners"
)

// oxyLogger is the single logger instance handed to the oxy forwarder,
// round-robin rebalancer and circuit breaker that loadConfig builds.
var oxyLogger = &OxyLogger{}

// Server is the reverse-proxy/load-balancer engine.
//
// Configuration flows: providers -> configurationChan -> listenProviders
// (throttling) -> configurationValidatedChan -> listenConfigurations, which
// builds new routers and starts or updates one HTTP server per entry point.
type Server struct {
	// serverEntryPoints holds the running entry points by name;
	// listenConfigurations modifies it while holding serverLock.
	serverEntryPoints          map[string]serverEntryPoint
	// configurationChan receives raw configuration messages from providers.
	configurationChan          chan types.ConfigMessage
	// configurationValidatedChan carries the throttled messages from
	// listenProviders to listenConfigurations.
	configurationValidatedChan chan types.ConfigMessage
	// signals receives SIGINT/SIGTERM (registered in NewServer).
	signals                    chan os.Signal
	// stopChan is sent to by Stop and awaited by Start.
	stopChan                   chan bool
	// providers lists the enabled providers (filled by configureProviders).
	providers                  []provider.Provider
	// serverLock guards serverEntryPoints and currentConfigurations while
	// listenConfigurations installs a new configuration.
	serverLock                 sync.Mutex
	// currentConfigurations is the last successfully applied configuration,
	// keyed by provider name.
	currentConfigurations      configs
	// globalConfiguration is the static configuration the server was built with.
	globalConfiguration        GlobalConfiguration
	// loggerMiddleware is the access-log middleware installed on every entry
	// point's handler chain.
	loggerMiddleware           *middlewares.Logger
}

// serverEntryPoint pairs an entry point's HTTP server with the handler switcher
// whose router is replaced on every configuration reload.
type serverEntryPoint struct {
	// httpServer is nil until the first configuration has been applied for the
	// entry point; entry points returned by loadConfig never set it.
	httpServer *manners.GracefulServer
	// httpRouter wraps the current router; UpdateHandler swaps it in place.
	httpRouter *middlewares.HandlerSwitcher
}

// NewServer returns a Server ready to Start for the given global configuration.
// Both configuration channels are buffered (capacity 10), SIGINT and SIGTERM
// are routed to the server's signal channel, and the access-log middleware
// writes to globalConfiguration.AccessLogsFile.
func NewServer(globalConfiguration GlobalConfiguration) *Server {
	s := &Server{
		serverEntryPoints:          make(map[string]serverEntryPoint),
		configurationChan:          make(chan types.ConfigMessage, 10),
		configurationValidatedChan: make(chan types.ConfigMessage, 10),
		signals:                    make(chan os.Signal, 1),
		stopChan:                   make(chan bool),
		providers:                  []provider.Provider{},
		currentConfigurations:      make(configs),
		globalConfiguration:        globalConfiguration,
		loggerMiddleware:           middlewares.NewLogger(globalConfiguration.AccessLogsFile),
	}
	signal.Notify(s.signals, syscall.SIGINT, syscall.SIGTERM)
	return s
}

// Start starts the server and blocks until the server is shut down.
//
// The two configuration consumers (listenProviders and listenConfigurations)
// are launched before any provider so provider messages are drained as soon
// as they are produced. The call then blocks until Stop sends on stopChan.
func (server *Server) Start() {
	go server.listenProviders()
	go server.listenConfigurations()
	// Register the enabled providers, then launch each in its own goroutine.
	server.configureProviders()
	server.startProviders()
	// SIGINT/SIGTERM lead to Stop, which releases the wait below.
	go server.listenSignals()
	<-server.stopChan
}

// Stop closes every entry point's HTTP server with manners' BlockingClose and
// then unblocks Start by sending on stopChan.
func (server *Server) Stop() {
	// listenConfigurations adds entries to serverEntryPoints under serverLock,
	// so take a snapshot under the same lock. The servers are closed outside
	// the lock because BlockingClose blocks until the server is done.
	server.serverLock.Lock()
	httpServers := make([]*manners.GracefulServer, 0, len(server.serverEntryPoints))
	for _, entryPoint := range server.serverEntryPoints {
		if entryPoint.httpServer != nil {
			httpServers = append(httpServers, entryPoint.httpServer)
		}
	}
	server.serverLock.Unlock()

	for _, httpServer := range httpServers {
		httpServer.BlockingClose()
	}
	server.stopChan <- true
}

// Close destroys the server. It stops OS signal delivery, closes the internal
// channels and closes the access-log middleware.
//
// NOTE(review): a provider or throttling goroutine that still sends on
// configurationChan / configurationValidatedChan after Close will panic
// (send on closed channel); callers presumably Stop first — confirm.
func (server *Server) Close() {
	// Deregister from os/signal before closing the channel it delivers to.
	// Otherwise a SIGINT/SIGTERM arriving later would make the signal package
	// send on a closed channel and panic.
	signal.Stop(server.signals)
	close(server.configurationChan)
	close(server.configurationValidatedChan)
	close(server.signals)
	close(server.stopChan)
	server.loggerMiddleware.Close()
}

// listenProviders reads raw configuration messages from configurationChan and
// forwards them to configurationValidatedChan, throttled by
// globalConfiguration.ProvidersThrottleDuration.
//
// A message that arrives more than the throttle duration after the previous
// one (from any provider) is forwarded immediately. Otherwise it is parked as
// the pending configuration of its provider. Once a full throttle period
// passes without any new message, every pending configuration (the newest
// one per provider) is forwarded, so a burst never loses another provider's
// update. The loop ends when configurationChan is closed.
func (server *Server) listenProviders() {
	throttle := time.Duration(server.globalConfiguration.ProvidersThrottleDuration)

	// mu guards lastReceived and pending, which are shared with the delayed
	// flush goroutines started below.
	var mu sync.Mutex
	lastReceived := time.Unix(0, 0)
	// pending holds, per provider name, the newest message not yet forwarded.
	pending := make(map[string]types.ConfigMessage)

	// drainPending empties pending and returns its messages. mu must be held.
	drainPending := func() []types.ConfigMessage {
		msgs := make([]types.ConfigMessage, 0, len(pending))
		for name, msg := range pending {
			msgs = append(msgs, msg)
			delete(pending, name)
		}
		return msgs
	}

	for configMsg := range server.configurationChan {
		jsonConf, _ := json.Marshal(configMsg.Configuration)
		log.Debugf("Configuration receveived from provider %s: %s", configMsg.ProviderName, string(jsonConf))

		// Per-iteration copies: the closure below must not observe later messages.
		receivedAt := time.Now()
		providerName := configMsg.ProviderName

		mu.Lock()
		quiet := receivedAt.After(lastReceived.Add(throttle))
		lastReceived = receivedAt
		pending[providerName] = configMsg
		var ready []types.ConfigMessage
		if quiet {
			ready = drainPending()
		}
		mu.Unlock()

		if quiet {
			log.Debugf("Last %s config received more than %s, OK", providerName, server.globalConfiguration.ProvidersThrottleDuration)
			// last config received more than n s ago
			for _, msg := range ready {
				server.configurationValidatedChan <- msg
			}
			continue
		}

		log.Debugf("Last %s config received less than %s, waiting...", providerName, server.globalConfiguration.ProvidersThrottleDuration)
		go func() {
			<-time.After(throttle)
			mu.Lock()
			var ready []types.ConfigMessage
			// Only the goroutine of the most recent message flushes. Any newer
			// message has reset the quiet period and handles the flush itself,
			// either immediately or through its own goroutine.
			if lastReceived.Equal(receivedAt) {
				ready = drainPending()
			}
			mu.Unlock()
			if len(ready) > 0 {
				log.Debugf("Waited for %s config, OK", providerName)
			}
			for _, msg := range ready {
				server.configurationValidatedChan <- msg
			}
		}()
	}
}

// listenConfigurations applies each throttled message received on
// configurationValidatedChan.
//
// A message is ignored when it carries no configuration, or when it equals the
// configuration already applied for its provider. Otherwise it is merged into
// a copy of the current per-provider configurations and loadConfig builds
// fresh routers from the result. On success, each entry point either gets its
// HTTP server prepared and started (first configuration) or has the handler
// behind its switcher replaced in place. On a loadConfig error the running
// configuration is kept. The loop ends when configurationValidatedChan is
// closed (see Close), instead of spinning on zero values.
func (server *Server) listenConfigurations() {
	for configMsg := range server.configurationValidatedChan {
		if configMsg.Configuration == nil {
			log.Info("Skipping empty Configuration")
			continue
		}
		if reflect.DeepEqual(server.currentConfigurations[configMsg.ProviderName], configMsg.Configuration) {
			log.Info("Skipping same configuration")
			continue
		}

		// Copy configurations to a new map so the current ones stay untouched if loadConfig fails.
		newConfigurations := make(configs, len(server.currentConfigurations)+1)
		for providerName, configuration := range server.currentConfigurations {
			newConfigurations[providerName] = configuration
		}
		newConfigurations[configMsg.ProviderName] = configMsg.Configuration

		newServerEntryPoints, err := server.loadConfig(newConfigurations, server.globalConfiguration)
		if err != nil {
			log.Error("Error loading new configuration, aborted ", err)
			continue
		}

		server.serverLock.Lock()
		for entryPointName, newServerEntryPoint := range newServerEntryPoints {
			currentServerEntryPoint := server.serverEntryPoints[entryPointName]
			if currentServerEntryPoint.httpServer == nil {
				// First configuration for this entry point: create and start its server.
				newsrv, err := server.prepareServer(newServerEntryPoint.httpRouter, server.globalConfiguration.EntryPoints[entryPointName], nil, server.loggerMiddleware, metrics)
				if err != nil {
					log.Fatal("Error preparing server: ", err)
				}
				go server.startServer(newsrv, server.globalConfiguration)
				currentServerEntryPoint.httpServer = newsrv
				currentServerEntryPoint.httpRouter = newServerEntryPoint.httpRouter
				server.serverEntryPoints[entryPointName] = currentServerEntryPoint
			} else {
				// Server already running: swap the router behind its HandlerSwitcher.
				currentServerEntryPoint.httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler())
			}
			log.Infof("Created new Handler: %p", newServerEntryPoint.httpRouter.GetHandler())
		}
		server.currentConfigurations = newConfigurations
		server.serverLock.Unlock()
	}
}

// configureProviders appends to server.providers, in a fixed order, every
// provider that is enabled (non-nil) in the global configuration. The Web
// provider additionally receives a back-reference to this server.
//
// Each provider is checked individually, rather than through a slice of
// interfaces, because a typed nil pointer stored in an interface is non-nil.
func (server *Server) configureProviders() {
	cfg := &server.globalConfiguration
	enabled := server.providers

	if cfg.Docker != nil {
		enabled = append(enabled, cfg.Docker)
	}
	if cfg.Marathon != nil {
		enabled = append(enabled, cfg.Marathon)
	}
	if cfg.File != nil {
		enabled = append(enabled, cfg.File)
	}
	if cfg.Web != nil {
		cfg.Web.server = server
		enabled = append(enabled, cfg.Web)
	}
	if cfg.Consul != nil {
		enabled = append(enabled, cfg.Consul)
	}
	if cfg.ConsulCatalog != nil {
		enabled = append(enabled, cfg.ConsulCatalog)
	}
	if cfg.Etcd != nil {
		enabled = append(enabled, cfg.Etcd)
	}
	if cfg.Zookeeper != nil {
		enabled = append(enabled, cfg.Zookeeper)
	}
	if cfg.Boltdb != nil {
		enabled = append(enabled, cfg.Boltdb)
	}

	server.providers = enabled
}

// startProviders logs each configured provider (type and JSON settings) and
// runs its Provide method in a dedicated goroutine. Each provider publishes
// its configurations on configurationChan. A Provide error is only logged.
func (server *Server) startProviders() {
	for _, p := range server.providers {
		jsonConf, _ := json.Marshal(p)
		log.Infof("Starting provider %v %s", reflect.TypeOf(p), jsonConf)
		// Pass the provider as an argument so each goroutine gets its own value.
		go func(prov provider.Provider) {
			if err := prov.Provide(server.configurationChan); err != nil {
				log.Errorf("Error starting provider %s", err)
			}
		}(p)
	}
}

// listenSignals waits for the first SIGINT/SIGTERM delivered on
// server.signals and stops the server.
//
// If the channel is closed instead (Close was called), it returns without
// calling Stop. Stop would otherwise send on the already-closed stopChan and
// panic.
func (server *Server) listenSignals() {
	sig, ok := <-server.signals
	if !ok {
		return
	}
	log.Infof("I have to go... %+v", sig)
	log.Info("Stopping server")
	server.Stop()
}

// createTLSConfig builds a TLS config that can terminate HTTPS for multiple
// domains using SNI. It loads every certificate/key pair listed in tlsOption.
//
// It returns (nil, nil) when tlsOption is nil or lists no certificates; the
// resulting server then serves plain HTTP (see startServer). It returns the
// first key-pair loading error otherwise.
func (server *Server) createTLSConfig(tlsOption *TLS) (*tls.Config, error) {
	if tlsOption == nil || len(tlsOption.Certificates) == 0 {
		return nil, nil
	}

	config := &tls.Config{
		// Only HTTP/1.1 is advertised as an application protocol.
		NextProtos:   []string{"http/1.1"},
		Certificates: make([]tls.Certificate, 0, len(tlsOption.Certificates)),
	}
	for _, v := range tlsOption.Certificates {
		cert, err := tls.LoadX509KeyPair(v.CertFile, v.KeyFile)
		if err != nil {
			return nil, err
		}
		config.Certificates = append(config.Certificates, cert)
	}
	// BuildNameToCertificate parses the CommonName and SubjectAlternateName fields
	// in each certificate and populates the config.NameToCertificate map.
	config.BuildNameToCertificate()
	return config, nil
}

// startServer runs srv until it stops: over TLS when srv.TLSConfig is set,
// plain HTTP otherwise. Any error returned by the listen call terminates the
// process via log.Fatal. The globalConfiguration parameter is currently unused.
func (server *Server) startServer(srv *manners.GracefulServer, globalConfiguration GlobalConfiguration) {
	log.Info("Starting server on ", srv.Addr)

	var err error
	if srv.TLSConfig == nil {
		err = srv.ListenAndServe()
	} else {
		err = srv.ListenAndServeTLSWithConfig(srv.TLSConfig)
	}
	if err != nil {
		log.Fatal("Error creating server: ", err)
	}

	log.Info("Server stopped")
}

// prepareServer builds (without starting) the graceful HTTP server for
// entryPoint.
//
// Requests go through middlewares, in order, and then router. TLS is enabled
// when the entry point configures certificates. When oldServer is non-nil, the
// new server takes over its listener via HijackListener instead of being
// created fresh.
//
// Errors from the TLS setup or the listener hijack are logged and returned to
// the caller. Exiting the process is the caller's decision.
func (server *Server) prepareServer(router http.Handler, entryPoint *EntryPoint, oldServer *manners.GracefulServer, middlewares ...negroni.Handler) (*manners.GracefulServer, error) {
	log.Info("Preparing server")
	// middlewares
	chain := negroni.New()
	for _, middleware := range middlewares {
		chain.Use(middleware)
	}
	chain.UseHandler(router)

	tlsConfig, err := server.createTLSConfig(entryPoint.TLS)
	if err != nil {
		log.Errorf("Error creating TLS config %s", err)
		return nil, err
	}

	httpServer := &http.Server{
		Addr:      entryPoint.Address,
		Handler:   chain,
		TLSConfig: tlsConfig,
	}
	if oldServer == nil {
		return manners.NewWithServer(httpServer), nil
	}

	gracefulServer, err := oldServer.HijackListener(httpServer, tlsConfig)
	if err != nil {
		log.Errorf("Error hijacking server %s", err)
		return nil, err
	}
	return gracefulServer, nil
}

// buildEntryPoints returns one serverEntryPoint per configured entry point.
// Each wraps a fresh default router in a HandlerSwitcher; httpServer is left
// nil.
func (server *Server) buildEntryPoints(globalConfiguration GlobalConfiguration) map[string]serverEntryPoint {
	entryPoints := make(map[string]serverEntryPoint, len(globalConfiguration.EntryPoints))
	for name := range globalConfiguration.EntryPoints {
		switcher := middlewares.NewHandlerSwitcher(server.buildDefaultHTTPRouter())
		entryPoints[name] = serverEntryPoint{httpRouter: switcher}
	}
	return entryPoints
}

// loadConfig builds a fresh set of entry points, each with its own gorilla/mux
// router, from the global configuration and the per-provider configurations.
//
// Each frontend gets one route per entry point it is attached to; it uses the
// global DefaultEntryPoints when it names none. That route either redirects
// (when the entry point has a Redirect) or forwards to the frontend's backend.
// Backend handlers are cached by backend name in a map shared across all
// providers, so frontends naming the same backend share one load balancer.
//
// loadConfig returns an error, and nothing is installed, when:
//   - a frontend references an undefined entry point or backend,
//   - a route rule cannot be resolved,
//   - a server URL does not parse, or
//   - a forwarder or load balancer cannot be created.
// Errors that mux reports on a finished route are only logged.
//
// NOTE(review): frontends without entry points are mutated in place to receive
// DefaultEntryPoints. Backends with an invalid load-balancer method are mutated
// in place to "wrr".
func (server *Server) loadConfig(configurations configs, globalConfiguration GlobalConfiguration) (map[string]serverEntryPoint, error) {
	serverEntryPoints := server.buildEntryPoints(globalConfiguration)
	redirectHandlers := make(map[string]http.Handler)

	backends := map[string]http.Handler{}
	for _, configuration := range configurations {
		for _, frontendName := range sortedFrontendNamesForConfig(configuration) {
			frontend := configuration.Frontends[frontendName]

			log.Debugf("Creating frontend %s", frontendName)
			fwd, err := forward.New(forward.Logger(oxyLogger), forward.PassHostHeader(frontend.PassHostHeader))
			if err != nil {
				return nil, err
			}
			// default endpoints if not defined in frontends
			if len(frontend.EntryPoints) == 0 {
				frontend.EntryPoints = globalConfiguration.DefaultEntryPoints
			}
			for _, entryPointName := range frontend.EntryPoints {
				log.Debugf("Wiring frontend %s to entryPoint %s", frontendName, entryPointName)
				if _, ok := serverEntryPoints[entryPointName]; !ok {
					return nil, errors.New("Undefined entrypoint: " + entryPointName)
				}
				newRoute := serverEntryPoints[entryPointName].httpRouter.GetHandler().NewRoute().Name(frontendName)
				for routeName, route := range frontend.Routes {
					log.Debugf("Creating route %s %s:%s", routeName, route.Rule, route.Value)
					matchedRoute, err := getRoute(newRoute, route.Rule, route.Value)
					if err != nil {
						return nil, err
					}
					newRoute = matchedRoute
				}
				entryPoint := globalConfiguration.EntryPoints[entryPointName]
				if entryPoint.Redirect != nil {
					// Redirect handlers are built once per entry point and reused.
					if redirectHandlers[entryPointName] != nil {
						newRoute.Handler(redirectHandlers[entryPointName])
					} else if handler, err := server.loadEntryPointConfig(entryPointName, entryPoint); err != nil {
						return nil, err
					} else {
						newRoute.Handler(handler)
						redirectHandlers[entryPointName] = handler
					}
				} else {
					if backends[frontend.Backend] == nil {
						log.Debugf("Creating backend %s", frontend.Backend)
						backend := configuration.Backends[frontend.Backend]
						if backend == nil {
							return nil, errors.New("Undefined backend: " + frontend.Backend)
						}
						rr, err := roundrobin.New(fwd)
						if err != nil {
							return nil, err
						}
						lbMethod, err := types.NewLoadBalancerMethod(backend.LoadBalancer)
						if err != nil {
							// Missing or invalid method: fall back to wrr, both in the stored
							// configuration and for the balancer built below.
							backend.LoadBalancer = &types.LoadBalancer{Method: "wrr"}
							lbMethod = types.Wrr
						}
						var lb http.Handler
						switch lbMethod {
						case types.Drr:
							log.Debugf("Creating load-balancer drr")
							rebalancer, err := roundrobin.NewRebalancer(rr, roundrobin.RebalancerLogger(oxyLogger))
							if err != nil {
								return nil, err
							}
							lb = rebalancer
							for serverName, backendServer := range backend.Servers {
								serverURL, err := url.Parse(backendServer.URL)
								if err != nil {
									return nil, err
								}
								log.Debugf("Creating server %s at %s with weight %d", serverName, serverURL.String(), backendServer.Weight)
								if err := rebalancer.UpsertServer(serverURL, roundrobin.Weight(backendServer.Weight)); err != nil {
									return nil, err
								}
							}
						case types.Wrr:
							log.Debugf("Creating load-balancer wrr")
							lb = rr
							for serverName, backendServer := range backend.Servers {
								serverURL, err := url.Parse(backendServer.URL)
								if err != nil {
									return nil, err
								}
								log.Debugf("Creating server %s at %s with weight %d", serverName, serverURL.String(), backendServer.Weight)
								if err := rr.UpsertServer(serverURL, roundrobin.Weight(backendServer.Weight)); err != nil {
									return nil, err
								}
							}
						default:
							// Never install a nil handler for an unexpected method value.
							return nil, errors.New("Unsupported load-balancer method for backend: " + frontend.Backend)
						}
						backendHandler := negroni.New()
						if backend.CircuitBreaker != nil {
							log.Debugf("Creating circuit breaker %s", backend.CircuitBreaker.Expression)
							backendHandler.Use(middlewares.NewCircuitBreaker(lb, backend.CircuitBreaker.Expression, cbreaker.Logger(oxyLogger)))
						} else {
							backendHandler.UseHandler(lb)
						}
						backends[frontend.Backend] = backendHandler
					} else {
						log.Debugf("Reusing backend %s", frontend.Backend)
					}
					server.wireFrontendBackend(frontend.Routes, newRoute, backends[frontend.Backend])
				}
				if err := newRoute.GetError(); err != nil {
					log.Errorf("Error building route: %s", err)
				}
			}
		}
	}
	return serverEntryPoints, nil
}

// wireFrontendBackend sets handler as newRoute's handler. If any of routes
// uses a "PathStrip" or "PathPrefixStrip" rule, handler is first wrapped in
// middlewares.StripPrefix with that route's Value as Prefix. With several
// such routes, whichever is visited last in map iteration order wins.
func (server *Server) wireFrontendBackend(routes map[string]types.Route, newRoute *mux.Route, handler http.Handler) {
	stripped := false
	for _, r := range routes {
		switch r.Rule {
		case "PathStrip", "PathPrefixStrip":
			newRoute.Handler(&middlewares.StripPrefix{
				Prefix:  r.Value,
				Handler: handler,
			})
			stripped = true
		}
	}
	if stripped {
		return
	}
	newRoute.Handler(handler)
}

// loadEntryPointConfig builds the redirect handler for entryPoint: a rewrite
// middleware driven by a regex and its replacement.
//
// When Redirect.EntryPoint names a target entry point, the configured
// Regex/Replacement are ignored. A fixed regex captures the host ($1) and
// the rest of the URL ($2), dropping any port. The replacement rebuilds the
// URL with the target's scheme (https if the target has TLS, otherwise http)
// and the target's ":port", extracted from its Address. It returns an error
// when:
//   - the target entry point is unknown,
//   - the target's Address has no ":port", or
//   - the rewrite middleware rejects the regex.
func (server *Server) loadEntryPointConfig(entryPointName string, entryPoint *EntryPoint) (http.Handler, error) {
	redirect := entryPoint.Redirect
	regex, replacement := redirect.Regex, redirect.Replacement

	if target := redirect.EntryPoint; len(target) > 0 {
		regex = "^(?:https?:\\/\\/)?([\\da-z\\.-]+)(?::\\d+)?(.*)$"
		targetEntryPoint := server.globalConfiguration.EntryPoints[target]
		if targetEntryPoint == nil {
			return nil, errors.New("Unknown entrypoint " + target)
		}
		scheme := "http"
		if targetEntryPoint.TLS != nil {
			scheme = "https"
		}
		port := regexp.MustCompile("(:\\d+)").FindString(targetEntryPoint.Address)
		if port == "" {
			return nil, errors.New("Bad Address format: " + targetEntryPoint.Address)
		}
		replacement = scheme + "://$1" + port + "$2"
	}

	rewrite, err := middlewares.NewRewrite(regex, replacement, true)
	if err != nil {
		return nil, err
	}
	log.Debugf("Creating entryPoint redirect %s -> %s : %s -> %s", entryPointName, entryPoint.Redirect.EntryPoint, regex, replacement)
	handler := negroni.New()
	handler.Use(rewrite)
	return handler, nil
}

// buildDefaultHTTPRouter returns an empty mux router with strict-slash
// handling enabled and notFoundHandler serving unmatched requests.
func (server *Server) buildDefaultHTTPRouter() *mux.Router {
	r := mux.NewRouter().StrictSlash(true)
	r.NotFoundHandler = http.HandlerFunc(notFoundHandler)
	return r
}

// getRoute applies a frontend rule to a route by calling the method of `any`
// whose name is the rule, via reflection. For example, rule "Path" with value
// "/foo" calls any.Path("/foo"). "PathStrip" and "PathPrefixStrip" are mapped
// to "Path" and "PathPrefix"; the prefix stripping itself is done in
// wireFrontendBackend.
//
// Rules come from provider configuration, so the method is validated before
// it is called. It must exist, return exactly one *mux.Route, and accept the
// given values (count and types, variadic methods included). Any mismatch
// yields an error instead of a reflect panic.
func getRoute(any interface{}, rule string, value ...interface{}) (*mux.Route, error) {
	switch rule {
	case "PathStrip":
		rule = "Path"
	case "PathPrefixStrip":
		rule = "PathPrefix"
	}

	method := reflect.ValueOf(any).MethodByName(rule)
	if !method.IsValid() {
		return nil, errors.New("Method not found: " + rule)
	}

	methodType := method.Type()
	if methodType.NumOut() != 1 || methodType.Out(0) != reflect.TypeOf((*mux.Route)(nil)) {
		return nil, errors.New("Method does not build a route: " + rule)
	}

	numIn := methodType.NumIn()
	variadic := methodType.IsVariadic()
	if (variadic && len(value) < numIn-1) || (!variadic && len(value) != numIn) {
		return nil, errors.New("Wrong number of arguments for method: " + rule)
	}

	inputs := make([]reflect.Value, len(value))
	for i, v := range value {
		var want reflect.Type
		if variadic && i >= numIn-1 {
			want = methodType.In(numIn - 1).Elem()
		} else {
			want = methodType.In(i)
		}
		in := reflect.ValueOf(v)
		if !in.IsValid() || !in.Type().AssignableTo(want) {
			return nil, errors.New("Invalid argument for method: " + rule)
		}
		inputs[i] = in
	}
	return method.Call(inputs)[0].Interface().(*mux.Route), nil
}

// sortedFrontendNamesForConfig returns the frontend names of configuration in
// ascending order. This gives loadConfig a deterministic processing order
// despite Go's random map iteration. The result is never nil.
func sortedFrontendNamesForConfig(configuration *types.Configuration) []string {
	names := make([]string, 0, len(configuration.Frontends))
	for name := range configuration.Frontends {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}