2020-07-26 04:50:18 +01:00
package header
import (
"encoding/base64"
"fmt"
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
)
// Injector adds configured header values to an http.Header for a given
// session.
type Injector interface {
	// Inject adds this injector's header values to header, reading any
	// session-derived values from session.
	Inject(header http.Header, session *sessionsapi.SessionState)
}
// injector is the Injector implementation built by NewInjector. It holds one
// valueInjector per configured header value; Inject applies them in slice
// order.
type injector struct {
	// valueInjectors are applied sequentially by Inject.
	valueInjectors []valueInjector
}
// Inject runs each configured value injector, in order, against header and
// session. It implements Injector.
func (i injector) Inject(header http.Header, session *sessionsapi.SessionState) {
	for idx := range i.valueInjectors {
		i.valueInjectors[idx].inject(header, session)
	}
}
// NewInjector builds an Injector from the configured headers. Every value of
// every header becomes its own value injector, and Inject applies them in the
// order they appear in headers.
//
// If any value cannot be turned into an injector (see newValueinjector), an
// error naming the offending header is returned. The underlying cause is
// wrapped, so it remains reachable via errors.Is / errors.As.
func NewInjector(headers []options.Header) (Injector, error) {
	// len(headers) is a lower bound: each header contributes one entry per value.
	injectors := make([]valueInjector, 0, len(headers))
	for _, header := range headers {
		for _, value := range header.Values {
			// Named vi (not "injector") so the type name injector is not shadowed.
			vi, err := newValueinjector(header.Name, value)
			if err != nil {
				return nil, fmt.Errorf("error building injector for header %q: %w", header.Name, err)
			}
			injectors = append(injectors, vi)
		}
	}
	return &injector{valueInjectors: injectors}, nil
}
// valueInjector adds the header value(s) produced by a single configured
// header value to header, consulting session when needed.
type valueInjector interface {
	inject(header http.Header, session *sessionsapi.SessionState)
}
// newValueinjector builds the valueInjector for one header value.
//
// Exactly one of value.SecretSource and value.ClaimSource must be set:
//   - only SecretSource: a static secret value is injected
//   - only ClaimSource: a session-claim-derived value is injected
//
// An error naming the header is returned when both are set or when neither
// is set. The two cases get distinct messages so the configuration problem
// is reported accurately.
func newValueinjector(name string, value options.HeaderValue) (valueInjector, error) {
	hasSecret := value.SecretSource != nil
	hasClaim := value.ClaimSource != nil

	switch {
	case hasSecret && hasClaim:
		return nil, fmt.Errorf("header %q value has multiple entries: only one entry per value is allowed", name)
	case hasSecret:
		return newSecretInjector(name, value.SecretSource)
	case hasClaim:
		return newClaimInjector(name, value.ClaimSource)
	default:
		// Previously this case fell through to the "multiple entries" error,
		// which misdescribed an empty value.
		return nil, fmt.Errorf("header %q value has no entries: exactly one entry per value is required", name)
	}
}
// injectorFunc adapts a plain function to the valueInjector interface
// (see its inject method).
type injectorFunc struct {
	// injectFunc is invoked on every inject call.
	injectFunc func(http.Header, *sessionsapi.SessionState)
}
// inject forwards header and session unchanged to the wrapped function,
// satisfying valueInjector.
func (i *injectorFunc) inject(header http.Header, session *sessionsapi.SessionState) {
	fn := i.injectFunc
	fn(header, session)
}
// newInjectorFunc wraps injectFunc in an injectorFunc so it can be used
// wherever a valueInjector is expected.
func newInjectorFunc(injectFunc func(header http.Header, session *sessionsapi.SessionState)) valueInjector {
	adapter := &injectorFunc{}
	adapter.injectFunc = injectFunc
	return adapter
}
// newSecretInjector loads the secret described by source once, at
// construction time. It returns a valueInjector that adds that fixed value
// under header name on every call; the session is not consulted.
//
// An error is returned if the secret cannot be loaded. The underlying cause
// is wrapped so callers can inspect it with errors.Is / errors.As.
func newSecretInjector(name string, source *options.SecretSource) (valueInjector, error) {
	value, err := util.GetSecretValue(source)
	if err != nil {
		return nil, fmt.Errorf("error getting secret value: %w", err)
	}
	// Convert once here rather than on every injected request.
	headerValue := string(value)
	return newInjectorFunc(func(header http.Header, _ *sessionsapi.SessionState) {
		header.Add(name, headerValue)
	}), nil
}
func newClaimInjector ( name string , source * options . ClaimSource ) ( valueInjector , error ) {
switch {
case source . BasicAuthPassword != nil :
password , err := util . GetSecretValue ( source . BasicAuthPassword )
if err != nil {
return nil , fmt . Errorf ( "error loading basicAuthPassword: %v" , err )
}
return newInjectorFunc ( func ( header http . Header , session * sessionsapi . SessionState ) {
2020-10-03 18:57:25 +01:00
claimValues := session . GetClaim ( source . Claim )
for _ , claim := range claimValues {
if claim == "" {
continue
}
auth := claim + ":" + string ( password )
header . Add ( name , "Basic " + base64 . StdEncoding . EncodeToString ( [ ] byte ( auth ) ) )
2020-07-26 04:50:18 +01:00
}
} ) , nil
case source . Prefix != "" :
return newInjectorFunc ( func ( header http . Header , session * sessionsapi . SessionState ) {
2020-10-03 18:57:25 +01:00
claimValues := session . GetClaim ( source . Claim )
for _ , claim := range claimValues {
if claim == "" {
continue
}
header . Add ( name , source . Prefix + claim )
2020-07-26 04:50:18 +01:00
}
} ) , nil
default :
return newInjectorFunc ( func ( header http . Header , session * sessionsapi . SessionState ) {
2020-10-03 18:57:25 +01:00
claimValues := session . GetClaim ( source . Claim )
for _ , claim := range claimValues {
if claim == "" {
continue
}
header . Add ( name , claim )
2020-07-26 04:50:18 +01:00
}
} ) , nil
}
}