package middleware

import (
	"net/http"
	"strings"
)

// RouteHeaders returns an empty HeaderRouter: a small header-based router that
// lets you direct the flow of a request through a middleware stack based on
// the value of a request header.
//
// For example, let's say you'd like to set up multiple routers depending on
// the request Host header. You could then do something like this:
//
//	r := chi.NewRouter()
//	rSubdomain := chi.NewRouter()
//
//	r.Use(middleware.RouteHeaders().
//		Route("Host", "example.com", middleware.New(r)).
//		Route("Host", "*.example.com", middleware.New(rSubdomain)).
//		Handler)
//
//	r.Get("/", h)
//	rSubdomain.Get("/", h2)
//
// As another example, imagine you want to set up multiple CORS handlers, where
// for your origin servers you allow authorized requests, but for third-party
// public requests, authorization is disabled:
//
//	r := chi.NewRouter()
//
//	r.Use(middleware.RouteHeaders().
//		Route("Origin", "https://app.skyweaver.net", cors.Handler(cors.Options{
//			AllowedOrigins:   []string{"https://api.skyweaver.net"},
//			AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
//			AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type"},
//			AllowCredentials: true, // <----------<<< allow credentials
//		})).
//		Route("Origin", "*", cors.Handler(cors.Options{
//			AllowedOrigins:   []string{"*"},
//			AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
//			AllowedHeaders:   []string{"Accept", "Content-Type"},
//			AllowCredentials: false, // <----------<<< do not allow credentials
//		})).
//		Handler)
func RouteHeaders() HeaderRouter {
	return make(HeaderRouter)
}

// HeaderRouter maps a lower-cased request header name to the routes
// registered for that header, in registration order. The special key "*"
// holds the default route installed by RouteDefault.
type HeaderRouter map[string][]HeaderRoute

// Route registers middlewareHandler to run for requests whose header named
// header has a value matching match. match is compiled with NewPattern, so it
// may be an exact value or contain a '*' wildcard. Both the header name and
// the pattern are lower-cased, making the match case-insensitive.
//
// It returns hr so calls can be chained.
func (hr HeaderRouter) Route(header, match string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
	header = strings.ToLower(header)

	// Handler lower-cases the incoming header value before matching, so the
	// pattern has to be lower-cased too; otherwise a pattern containing an
	// upper-case letter could never match.
	route := HeaderRoute{
		MatchOne:   NewPattern(strings.ToLower(match)),
		Middleware: middlewareHandler,
	}

	// append copes with the nil slice of a header seen for the first time.
	hr[header] = append(hr[header], route)
	return hr
}

// RouteAny registers middlewareHandler to run for requests whose header named
// header has a value matching at least one of the patterns in match. Each
// entry is compiled with NewPattern. Both the header name and the patterns are
// lower-cased, making the match case-insensitive.
//
// It returns hr so calls can be chained.
func (hr HeaderRouter) RouteAny(header string, match []string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
	header = strings.ToLower(header)

	patterns := make([]Pattern, 0, len(match))
	for _, m := range match {
		// Handler lower-cases the incoming header value before matching, so
		// each pattern has to be lower-cased too or it could never match.
		patterns = append(patterns, NewPattern(strings.ToLower(m)))
	}

	// append copes with the nil slice of a header seen for the first time.
	hr[header] = append(hr[header], HeaderRoute{MatchAny: patterns, Middleware: middlewareHandler})
	return hr
}

// RouteDefault installs handler as the fallback middleware, stored under the
// "*" key. Any default set earlier is replaced. It returns hr so calls can be
// chained.
func (hr HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) HeaderRouter {
	fallback := HeaderRoute{Middleware: handler}
	hr["*"] = []HeaderRoute{fallback}
	return hr
}

// Handler is the middleware entry point. For each request it looks for the
// first registered route whose header is present and whose pattern matches the
// lower-cased header value, and serves the request through that route's
// middleware wrapped around next. If nothing matches, the "*" default route is
// used when one is set; otherwise the request goes straight to next.
func (hr HeaderRouter) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if len(hr) == 0 {
			// No routes set: pass straight through. The return is required,
			// otherwise execution falls through to the default-route lookup
			// below and next is served a second time for the same request.
			next.ServeHTTP(w, r)
			return
		}

		// Find the first matching header route and hand the request to it.
		// hr is a map, so the order in which different headers are tried is
		// unspecified; routes of a single header are tried in registration
		// order.
		for header, matchers := range hr {
			headerValue := r.Header.Get(header)
			if headerValue == "" {
				continue
			}
			headerValue = strings.ToLower(headerValue)
			for _, matcher := range matchers {
				if matcher.IsMatch(headerValue) {
					matcher.Middleware(next).ServeHTTP(w, r)
					return
				}
			}
		}

		// No match: fall back to the "*" default route if one is usable.
		// The length check guards against an empty slice stored under "*".
		defaults, ok := hr["*"]
		if !ok || len(defaults) == 0 || defaults[0].Middleware == nil {
			next.ServeHTTP(w, r)
			return
		}
		defaults[0].Middleware(next).ServeHTTP(w, r)
	})
}

// HeaderRoute pairs a rule for matching a header value with the middleware to
// run when that rule matches.
type HeaderRoute struct {
	// MatchAny, when non-empty, is the list of patterns of which at least one
	// must match; IsMatch consults it instead of MatchOne.
	MatchAny   []Pattern
	// MatchOne is the single pattern IsMatch consults when MatchAny is empty.
	MatchOne   Pattern
	// Middleware wraps the next handler for requests routed here.
	Middleware func(next http.Handler) http.Handler
}

// IsMatch reports whether value satisfies this route. When MatchAny holds at
// least one pattern, value must match one of them and MatchOne is ignored;
// otherwise the result is that of MatchOne alone.
func (r HeaderRoute) IsMatch(value string) bool {
	if len(r.MatchAny) == 0 {
		return r.MatchOne.Match(value)
	}
	for _, candidate := range r.MatchAny {
		if candidate.Match(value) {
			return true
		}
	}
	return false
}

// Pattern is a compiled header-value matcher: either an exact string, or a
// literal prefix and suffix separated by a single '*' wildcard.
type Pattern struct {
	prefix   string // literal text before the wildcard, or the whole value when there is none
	suffix   string // literal text after the wildcard; empty when there is none
	wildcard bool   // true when the source value contained a '*'
}

// NewPattern compiles value into a Pattern. The first '*' in value is treated
// as a wildcard standing for any (possibly empty) run of characters; any
// further '*' characters remain literal text in the suffix. A value without
// '*' produces an exact-match pattern.
func NewPattern(value string) Pattern {
	i := strings.IndexByte(value, '*')
	if i < 0 {
		return Pattern{prefix: value}
	}
	return Pattern{prefix: value[:i], suffix: value[i+1:], wildcard: true}
}

// Match reports whether v satisfies the pattern. Without a wildcard v must
// equal the pattern exactly; with one, v must start with the prefix and end
// with the suffix. The comparison is case-sensitive.
func (p Pattern) Match(v string) bool {
	if !p.wildcard {
		return p.prefix == v
	}
	// The length check stops the prefix and suffix from overlapping inside v
	// (for example "ab*ba" must not match "aba").
	return len(v) >= len(p.prefix)+len(p.suffix) &&
		strings.HasPrefix(v, p.prefix) &&
		strings.HasSuffix(v, p.suffix)
}