// SPDX-License-Identifier: GPL-2.0-only /* * Copyright (c) 2021, 2022 Oracle. All rights reserved. * * The AUTH_TLS credential is used only to probe a remote peer * for RPC-over-TLS support. */ #include <linux/types.h> #include <linux/module.h> #include <linux/sunrpc/clnt.h> static const char *starttls_token = "STARTTLS"; static const size_t starttls_len = 8; static struct rpc_auth tls_auth; static struct rpc_cred tls_cred; static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr, const void *obj) { } static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr, void *obj) { return 0; } static const struct rpc_procinfo rpcproc_tls_probe = { .p_encode = tls_encode_probe, .p_decode = tls_decode_probe, }; static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data) { task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT; rpc_call_start(task); } static void rpc_tls_probe_call_done(struct rpc_task *task, void *data) { } static const struct rpc_call_ops rpc_tls_probe_ops = { .rpc_call_prepare = rpc_tls_probe_call_prepare, .rpc_call_done = rpc_tls_probe_call_done, }; static int tls_probe(struct rpc_clnt *clnt) { struct rpc_message msg = { .rpc_proc = &rpcproc_tls_probe, }; struct rpc_task_setup task_setup_data = { .rpc_client = clnt, .rpc_message = &msg, .rpc_op_cred = &tls_cred, .callback_ops = &rpc_tls_probe_ops, .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN, }; struct rpc_task *task; int status; task = rpc_run_task(&task_setup_data); if (IS_ERR(task)) return PTR_ERR(task); status = task->tk_status; rpc_put_task(task); return status; } static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { refcount_inc(&tls_auth.au_count); return &tls_auth; } static void tls_destroy(struct rpc_auth *auth) { } static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) { return get_rpccred(&tls_cred); } static void tls_destroy_cred(struct rpc_cred *cred) { } static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags) { return 1; } static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr) { __be32 *p; p = xdr_reserve_space(xdr, 4 * XDR_UNIT); if (!p) return -EMSGSIZE; /* Credential */ *p++ = rpc_auth_tls; *p++ = xdr_zero; /* Verifier */ *p++ = rpc_auth_null; *p = xdr_zero; return 0; } static int tls_refresh(struct rpc_task *task) { set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); return 0; } static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr) { __be32 *p; void *str; p = xdr_inline_decode(xdr, XDR_UNIT); if (!p) return -EIO; if (*p != rpc_auth_null) return -EIO; if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len) return -EIO; if (memcmp(str, starttls_token, starttls_len)) return -EIO; return 0; } const struct rpc_authops authtls_ops = { .owner = THIS_MODULE, .au_flavor = RPC_AUTH_TLS, .au_name = "NULL", .create = tls_create, .destroy = tls_destroy, .lookup_cred = tls_lookup_cred, .ping = tls_probe, }; static struct rpc_auth tls_auth = { .au_cslack = NUL_CALLSLACK, .au_rslack = NUL_REPLYSLACK, .au_verfsize = NUL_REPLYSLACK, .au_ralign = NUL_REPLYSLACK, .au_ops = &authtls_ops, .au_flavor = RPC_AUTH_TLS, .au_count = REFCOUNT_INIT(1), }; static const struct rpc_credops tls_credops = { .cr_name = "AUTH_TLS", .crdestroy = tls_destroy_cred, .crmatch = tls_match, .crmarshal = tls_marshal, .crwrap_req = rpcauth_wrap_req_encode, .crrefresh = tls_refresh, .crvalidate = tls_validate, .crunwrap_resp = rpcauth_unwrap_resp_decode, }; static struct rpc_cred tls_cred = { .cr_lru = LIST_HEAD_INIT(tls_cred.cr_lru), .cr_auth = &tls_auth, .cr_ops = &tls_credops, .cr_count = REFCOUNT_INIT(2), .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE, };