2017-10-10 15:50:03 +03:00
package whitelist
import (
2018-04-23 17:20:05 +03:00
"errors"
2017-10-10 15:50:03 +03:00
"fmt"
"net"
2018-03-23 19:40:04 +03:00
"net/http"
2018-04-23 17:20:05 +03:00
"strings"
2017-10-10 15:50:03 +03:00
)
2018-03-23 19:40:04 +03:00
// XForwardedFor is the name of the HTTP header that IsAuthorized consults for
// proxied client addresses when X-Forwarded-For checking is enabled.
const XForwardedFor = "X-Forwarded-For"
2017-10-10 15:50:03 +03:00
// IP checks addresses against a white list made of individual IP addresses
// and CIDR networks.
type IP struct {
	// whiteListsIPs holds the white-list entries that parsed as single IPs.
	whiteListsIPs []*net.IP
	// whiteListsNet holds the white-list entries that parsed as CIDR networks.
	whiteListsNet []*net.IPNet
	// insecure disables checking entirely: every address is accepted.
	insecure bool
	// useXForwardedFor makes IsAuthorized also check the addresses listed in
	// the X-Forwarded-For request header.
	useXForwardedFor bool
}
2018-03-23 19:40:04 +03:00
// NewIP builds a new IP given a list of CIDR-Strings to white list
func NewIP ( whiteList [ ] string , insecure bool , useXForwardedFor bool ) ( * IP , error ) {
if len ( whiteList ) == 0 && ! insecure {
2018-03-08 17:08:03 +03:00
return nil , errors . New ( "no white list provided" )
2017-10-10 15:50:03 +03:00
}
2018-03-23 19:40:04 +03:00
ip := IP {
insecure : insecure ,
useXForwardedFor : useXForwardedFor ,
}
2017-10-10 15:50:03 +03:00
2017-10-16 13:46:03 +03:00
if ! insecure {
2018-03-23 19:40:04 +03:00
for _ , ipMask := range whiteList {
if ipAddr := net . ParseIP ( ipMask ) ; ipAddr != nil {
2017-10-16 13:46:03 +03:00
ip . whiteListsIPs = append ( ip . whiteListsIPs , & ipAddr )
} else {
2018-03-23 19:40:04 +03:00
_ , ipAddr , err := net . ParseCIDR ( ipMask )
2017-10-16 13:46:03 +03:00
if err != nil {
2018-03-23 19:40:04 +03:00
return nil , fmt . Errorf ( "parsing CIDR white list %s: %v" , ipAddr , err )
2017-10-16 13:46:03 +03:00
}
2018-03-23 19:40:04 +03:00
ip . whiteListsNet = append ( ip . whiteListsNet , ipAddr )
2017-10-10 15:50:03 +03:00
}
}
}
return & ip , nil
}
2018-03-23 19:40:04 +03:00
// IsAuthorized checks if provided request is authorized by the white list
2018-04-23 17:20:05 +03:00
func ( ip * IP ) IsAuthorized ( req * http . Request ) error {
2017-10-16 13:46:03 +03:00
if ip . insecure {
2018-04-23 17:20:05 +03:00
return nil
2017-10-16 13:46:03 +03:00
}
2018-04-23 17:20:05 +03:00
var invalidMatches [ ] string
2018-03-23 19:40:04 +03:00
if ip . useXForwardedFor {
xFFs := req . Header [ XForwardedFor ]
2018-04-23 17:20:05 +03:00
if len ( xFFs ) > 0 {
2018-03-23 19:40:04 +03:00
for _ , xFF := range xFFs {
2018-05-30 10:26:03 +03:00
xffs := strings . Split ( xFF , "," )
for _ , xff := range xffs {
ok , err := ip . contains ( parseHost ( xff ) )
if err != nil {
return err
}
2018-03-23 19:40:04 +03:00
2018-05-30 10:26:03 +03:00
if ok {
return nil
}
2018-04-23 17:20:05 +03:00
2018-05-30 10:26:03 +03:00
invalidMatches = append ( invalidMatches , xff )
}
2018-03-23 19:40:04 +03:00
}
}
}
host , _ , err := net . SplitHostPort ( req . RemoteAddr )
if err != nil {
2018-04-23 17:20:05 +03:00
return err
2018-03-23 19:40:04 +03:00
}
2018-04-23 17:20:05 +03:00
ok , err := ip . contains ( host )
if err != nil {
return err
}
if ! ok {
invalidMatches = append ( invalidMatches , req . RemoteAddr )
return fmt . Errorf ( "%q matched none of the white list" , strings . Join ( invalidMatches , ", " ) )
}
return nil
2018-03-23 19:40:04 +03:00
}
// contains checks if provided address is in the white list
2018-04-23 17:20:05 +03:00
func ( ip * IP ) contains ( addr string ) ( bool , error ) {
2018-03-23 19:40:04 +03:00
ipAddr , err := parseIP ( addr )
2017-10-10 15:50:03 +03:00
if err != nil {
2018-04-23 17:20:05 +03:00
return false , fmt . Errorf ( "unable to parse address: %s: %s" , addr , err )
2017-10-10 15:50:03 +03:00
}
2018-04-23 17:20:05 +03:00
return ip . ContainsIP ( ipAddr ) , nil
2017-10-10 15:50:03 +03:00
}
// ContainsIP checks if provided address is in the white list
2018-04-23 17:20:05 +03:00
func ( ip * IP ) ContainsIP ( addr net . IP ) bool {
2017-10-16 13:46:03 +03:00
if ip . insecure {
2018-04-23 17:20:05 +03:00
return true
2017-10-16 13:46:03 +03:00
}
2017-10-10 15:50:03 +03:00
for _ , whiteListIP := range ip . whiteListsIPs {
if whiteListIP . Equal ( addr ) {
2018-04-23 17:20:05 +03:00
return true
2017-10-10 15:50:03 +03:00
}
}
for _ , whiteListNet := range ip . whiteListsNet {
if whiteListNet . Contains ( addr ) {
2018-04-23 17:20:05 +03:00
return true
2017-10-10 15:50:03 +03:00
}
}
2018-04-23 17:20:05 +03:00
return false
2017-10-10 15:50:03 +03:00
}
2018-03-23 19:40:04 +03:00
// parseIP converts the textual address addr into a net.IP. It returns an
// error if addr is not a valid IPv4 or IPv6 address.
func parseIP(addr string) (net.IP, error) {
	if parsed := net.ParseIP(addr); parsed != nil {
		return parsed, nil
	}
	return nil, fmt.Errorf("can't parse IP from address %s", addr)
}
2018-03-23 19:40:04 +03:00
// parseHost returns the host part of addr when it is in "host:port" form
// (including "[ipv6]:port"). Otherwise it returns addr unchanged.
func parseHost(addr string) string {
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	return addr
}