3fa42edb73
* fix import path for v7 find ./ -name "*.go" | xargs sed -i -e 's|"github.com/oauth2-proxy/oauth2-proxy|"github.com/oauth2-proxy/oauth2-proxy/v7|' * fix module path * go mod tidy * fix installation docs * update CHANGELOG * Update CHANGELOG.md Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk> Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
137 lines
3.5 KiB
Go
137 lines
3.5 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
|
)
|
|
|
|
// Server represents an HTTP server
|
|
type Server struct {
|
|
Handler http.Handler
|
|
Opts *options.Options
|
|
stop chan struct{} // channel for waiting shutdown
|
|
}
|
|
|
|
// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
|
|
func (s *Server) ListenAndServe() {
|
|
if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" {
|
|
s.ServeHTTPS()
|
|
} else {
|
|
s.ServeHTTP()
|
|
}
|
|
}
|
|
|
|
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
|
|
func (s *Server) ServeHTTP() {
|
|
HTTPAddress := s.Opts.HTTPAddress
|
|
var scheme string
|
|
|
|
i := strings.Index(HTTPAddress, "://")
|
|
if i > -1 {
|
|
scheme = HTTPAddress[0:i]
|
|
}
|
|
|
|
var networkType string
|
|
switch scheme {
|
|
case "", "http":
|
|
networkType = "tcp"
|
|
default:
|
|
networkType = scheme
|
|
}
|
|
|
|
slice := strings.SplitN(HTTPAddress, "//", 2)
|
|
listenAddr := slice[len(slice)-1]
|
|
|
|
listener, err := net.Listen(networkType, listenAddr)
|
|
if err != nil {
|
|
logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err)
|
|
}
|
|
logger.Printf("HTTP: listening on %s", listenAddr)
|
|
s.serve(listener)
|
|
logger.Printf("HTTP: closing %s", listener.Addr())
|
|
}
|
|
|
|
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
|
|
func (s *Server) ServeHTTPS() {
|
|
addr := s.Opts.HTTPSAddress
|
|
config := &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
MaxVersion: tls.VersionTLS12,
|
|
}
|
|
if config.NextProtos == nil {
|
|
config.NextProtos = []string{"http/1.1"}
|
|
}
|
|
|
|
var err error
|
|
config.Certificates = make([]tls.Certificate, 1)
|
|
config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile)
|
|
if err != nil {
|
|
logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err)
|
|
}
|
|
|
|
ln, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err)
|
|
}
|
|
logger.Printf("HTTPS: listening on %s", ln.Addr())
|
|
|
|
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
|
|
s.serve(tlsListener)
|
|
logger.Printf("HTTPS: closing %s", tlsListener.Addr())
|
|
}
|
|
|
|
func (s *Server) serve(listener net.Listener) {
|
|
srv := &http.Server{Handler: s.Handler}
|
|
|
|
// See https://golang.org/pkg/net/http/#Server.Shutdown
|
|
idleConnsClosed := make(chan struct{})
|
|
go func() {
|
|
<-s.stop // wait notification for stopping server
|
|
|
|
// We received an interrupt signal, shut down.
|
|
if err := srv.Shutdown(context.Background()); err != nil {
|
|
// Error from closing listeners, or context timeout:
|
|
logger.Printf("HTTP server Shutdown: %v", err)
|
|
}
|
|
close(idleConnsClosed)
|
|
}()
|
|
|
|
err := srv.Serve(listener)
|
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
logger.Errorf("ERROR: http.Serve() - %s", err)
|
|
}
|
|
<-idleConnsClosed
|
|
}
|
|
|
|
// tcpKeepAliveListener wraps a *net.TCPListener so that its Accept method can
// enable TCP keep-alive on every accepted connection. ServeHTTPS places it
// underneath the TLS listener so dead TCP connections (e.g. closing laptop
// mid-download) eventually go away.
type tcpKeepAliveListener struct {
	*net.TCPListener
}
|
|
|
|
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
|
tc, err := ln.AcceptTCP()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = tc.SetKeepAlive(true)
|
|
if err != nil {
|
|
logger.Printf("Error setting Keep-Alive: %v", err)
|
|
}
|
|
err = tc.SetKeepAlivePeriod(3 * time.Minute)
|
|
if err != nil {
|
|
logger.Printf("Error setting Keep-Alive period: %v", err)
|
|
}
|
|
return tc, nil
|
|
}
|