2023-06-07 16:58:04 +03:00
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2021, 2022 Oracle.  All rights reserved.
 *
 * The AUTH_TLS credential is used only to probe a remote peer
 * for RPC-over-TLS support.
 */
# include <linux/types.h>
# include <linux/module.h>
# include <linux/sunrpc/clnt.h>
/*
 * Marker a peer places in the body of its AUTH_NULL reply verifier to
 * advertise RPC-over-TLS support (checked by tls_validate()).
 *
 * starttls_len is derived from the literal so the two can never drift
 * apart; it excludes the terminating NUL, which is not sent on the wire.
 */
static const char starttls_token[] = "STARTTLS";
static const size_t starttls_len = sizeof(starttls_token) - 1;
/*
 * Forward declarations: both singletons are defined at the bottom of
 * this file, but the functions below already take their addresses.
 */
static struct rpc_cred tls_cred;
static struct rpc_auth tls_auth;
/*
 * Argument encoder for the probe call. The probe carries no call
 * arguments, so nothing is written to @xdr.
 */
static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
			     const void *obj)
{
	/* intentionally empty: no arguments to encode */
}

/*
 * Result decoder for the probe reply. No result data is consumed.
 *
 * Return: always 0.
 */
static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
			    void *obj)
{
	/* intentionally empty: no results to decode */
	return 0;
}

/*
 * Procedure descriptor used by tls_probe(). Only the XDR hooks are
 * set; every other member (including p_proc) is zero-initialized.
 */
static const struct rpc_procinfo rpcproc_tls_probe = {
	.p_decode	= tls_decode_probe,
	.p_encode	= tls_encode_probe,
};
static void rpc_tls_probe_call_prepare ( struct rpc_task * task , void * data )
{
task - > tk_flags & = ~ RPC_TASK_NO_RETRANS_TIMEOUT ;
rpc_call_start ( task ) ;
}
static void rpc_tls_probe_call_done ( struct rpc_task * task , void * data )
{
}
static const struct rpc_call_ops rpc_tls_probe_ops = {
. rpc_call_prepare = rpc_tls_probe_call_prepare ,
. rpc_call_done = rpc_tls_probe_call_done ,
} ;
static int tls_probe ( struct rpc_clnt * clnt )
{
struct rpc_message msg = {
. rpc_proc = & rpcproc_tls_probe ,
} ;
struct rpc_task_setup task_setup_data = {
. rpc_client = clnt ,
. rpc_message = & msg ,
. rpc_op_cred = & tls_cred ,
. callback_ops = & rpc_tls_probe_ops ,
. flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN ,
} ;
struct rpc_task * task ;
int status ;
task = rpc_run_task ( & task_setup_data ) ;
if ( IS_ERR ( task ) )
return PTR_ERR ( task ) ;
status = task - > tk_status ;
rpc_put_task ( task ) ;
return status ;
}
static struct rpc_auth * tls_create ( const struct rpc_auth_create_args * args ,
struct rpc_clnt * clnt )
{
refcount_inc ( & tls_auth . au_count ) ;
return & tls_auth ;
}
static void tls_destroy ( struct rpc_auth * auth )
{
}
/*
 * rpc_authops::lookup_cred hook. Regardless of @acred and @flags, the
 * single static tls_cred is returned, with a reference taken via
 * get_rpccred().
 */
static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
					struct auth_cred *acred, int flags)
{
	struct rpc_cred *shared = &tls_cred;

	return get_rpccred(shared);
}

/*
 * rpc_credops::crdestroy hook. tls_cred is statically allocated, so
 * there is nothing to free.
 */
static void tls_destroy_cred(struct rpc_cred *cred)
{
}

/*
 * rpc_credops::crmatch hook. Any caller credential is treated as a
 * match for the shared tls_cred.
 *
 * Return: always 1.
 */
static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
{
	const int matches = 1;

	return matches;
}
static int tls_marshal ( struct rpc_task * task , struct xdr_stream * xdr )
{
__be32 * p ;
p = xdr_reserve_space ( xdr , 4 * XDR_UNIT ) ;
if ( ! p )
return - EMSGSIZE ;
/* Credential */
* p + + = rpc_auth_tls ;
* p + + = xdr_zero ;
/* Verifier */
* p + + = rpc_auth_null ;
* p = xdr_zero ;
return 0 ;
}
static int tls_refresh ( struct rpc_task * task )
{
set_bit ( RPCAUTH_CRED_UPTODATE , & task - > tk_rqstp - > rq_cred - > cr_flags ) ;
return 0 ;
}
static int tls_validate ( struct rpc_task * task , struct xdr_stream * xdr )
{
__be32 * p ;
void * str ;
p = xdr_inline_decode ( xdr , XDR_UNIT ) ;
if ( ! p )
return - EIO ;
if ( * p ! = rpc_auth_null )
return - EIO ;
if ( xdr_stream_decode_opaque_inline ( xdr , & str , starttls_len ) ! = starttls_len )
2023-09-06 23:05:26 +03:00
return - EPROTONOSUPPORT ;
2023-06-07 16:58:04 +03:00
if ( memcmp ( str , starttls_token , starttls_len ) )
2023-09-06 23:05:26 +03:00
return - EPROTONOSUPPORT ;
2023-06-07 16:58:04 +03:00
return 0 ;
}
/*
 * Authentication-flavor operations for RPC_AUTH_TLS. This table has
 * external linkage. Its ping hook issues the TLS probe call.
 *
 * NOTE(review): au_name is "NULL" although the flavor is AUTH_TLS; this
 * looks like a leftover from the AUTH_NULL template - confirm whether
 * anything depends on the name before changing it.
 */
const struct rpc_authops authtls_ops = {
	.owner		= THIS_MODULE,
	.au_name	= "NULL",
	.au_flavor	= RPC_AUTH_TLS,
	.ping		= tls_probe,
	.create		= tls_create,
	.lookup_cred	= tls_lookup_cred,
	.destroy	= tls_destroy,
};
/*
 * The single rpc_auth instance shared by every AUTH_TLS client (see
 * tls_create()). Its slack and alignment sizes reuse the AUTH_NULL
 * constants, and its reference count starts at 1.
 */
static struct rpc_auth tls_auth = {
	.au_flavor	= RPC_AUTH_TLS,
	.au_ops		= &authtls_ops,
	.au_count	= REFCOUNT_INIT(1),
	.au_cslack	= NUL_CALLSLACK,
	.au_rslack	= NUL_REPLYSLACK,
	.au_ralign	= NUL_REPLYSLACK,
	.au_verfsize	= NUL_REPLYSLACK,
};
/*
 * Credential operations for tls_cred. Request wrapping and reply
 * unwrapping use the generic rpcauth encode/decode helpers.
 */
static const struct rpc_credops tls_credops = {
	.cr_name	= "AUTH_TLS",
	.crmatch	= tls_match,
	.crrefresh	= tls_refresh,
	.crmarshal	= tls_marshal,
	.crvalidate	= tls_validate,
	.crwrap_req	= rpcauth_wrap_req_encode,
	.crunwrap_resp	= rpcauth_unwrap_resp_decode,
	.crdestroy	= tls_destroy_cred,
};
static struct rpc_cred tls_cred = {
. cr_lru = LIST_HEAD_INIT ( tls_cred . cr_lru ) ,
. cr_auth = & tls_auth ,
. cr_ops = & tls_credops ,
. cr_count = REFCOUNT_INIT ( 2 ) ,
. cr_flags = 1UL < < RPCAUTH_CRED_UPTODATE ,
} ;