diff --git a/net/sctp/socket.c b/net/sctp/socket.c index 68691d2b3167..bf089e59c792 100644 --- a/net/sctp/socket.c +++ b/net/sctp/socket.c @@ -1606,6 +1606,61 @@ static int sctp_error(struct sock *sk, int flags, int err) static int sctp_msghdr_parse(const struct msghdr *msg, struct sctp_cmsgs *cmsgs); +static int sctp_sendmsg_parse(struct sock *sk, struct sctp_cmsgs *cmsgs, + struct sctp_sndrcvinfo *srinfo, + const struct msghdr *msg, size_t msg_len) +{ + __u16 sflags; + int err; + + if (sctp_sstate(sk, LISTENING) && sctp_style(sk, TCP)) + return -EPIPE; + + if (msg_len > sk->sk_sndbuf) + return -EMSGSIZE; + + memset(cmsgs, 0, sizeof(*cmsgs)); + err = sctp_msghdr_parse(msg, cmsgs); + if (err) { + pr_debug("%s: msghdr parse err:%x\n", __func__, err); + return err; + } + + memset(srinfo, 0, sizeof(*srinfo)); + if (cmsgs->srinfo) { + srinfo->sinfo_stream = cmsgs->srinfo->sinfo_stream; + srinfo->sinfo_flags = cmsgs->srinfo->sinfo_flags; + srinfo->sinfo_ppid = cmsgs->srinfo->sinfo_ppid; + srinfo->sinfo_context = cmsgs->srinfo->sinfo_context; + srinfo->sinfo_assoc_id = cmsgs->srinfo->sinfo_assoc_id; + srinfo->sinfo_timetolive = cmsgs->srinfo->sinfo_timetolive; + } + + if (cmsgs->sinfo) { + srinfo->sinfo_stream = cmsgs->sinfo->snd_sid; + srinfo->sinfo_flags = cmsgs->sinfo->snd_flags; + srinfo->sinfo_ppid = cmsgs->sinfo->snd_ppid; + srinfo->sinfo_context = cmsgs->sinfo->snd_context; + srinfo->sinfo_assoc_id = cmsgs->sinfo->snd_assoc_id; + } + + sflags = srinfo->sinfo_flags; + if (!sflags && msg_len) + return 0; + + if (sctp_style(sk, TCP) && (sflags & (SCTP_EOF | SCTP_ABORT))) + return -EINVAL; + + if (((sflags & SCTP_EOF) && msg_len > 0) || + (!(sflags & (SCTP_EOF | SCTP_ABORT)) && msg_len == 0)) + return -EINVAL; + + if ((sflags & SCTP_ADDR_OVER) && !msg->msg_name) + return -EINVAL; + + return 0; +} + static int sctp_sendmsg_new_asoc(struct sock *sk, __u16 sflags, struct sctp_cmsgs *cmsgs, union sctp_addr *daddr, @@ -1839,39 +1894,23 @@ static union sctp_addr *sctp_sendmsg_get_daddr(struct sock *sk, static int sctp_sendmsg(struct sock *sk, struct msghdr *msg, size_t msg_len) { - struct sctp_sock *sp; - struct sctp_endpoint *ep; + struct sctp_endpoint *ep = sctp_sk(sk)->ep; struct sctp_association *new_asoc = NULL, *asoc = NULL; struct sctp_transport *transport, *chunk_tp; - struct sctp_sndrcvinfo default_sinfo; - struct sctp_sndrcvinfo *sinfo; - struct sctp_initmsg *sinit; + struct sctp_sndrcvinfo _sinfo, *sinfo; sctp_assoc_t associd = 0; struct sctp_cmsgs cmsgs = { NULL }; - bool fill_sinfo_ttl = false; __u16 sinfo_flags = 0; union sctp_addr *daddr; int err; - err = 0; - sp = sctp_sk(sk); - ep = sp->ep; - - pr_debug("%s: sk:%p, msg:%p, msg_len:%zu ep:%p\n", __func__, sk, - msg, msg_len, ep); - - /* We cannot send a message over a TCP-style listening socket. */ - if (sctp_style(sk, TCP) && sctp_sstate(sk, LISTENING)) { - err = -EPIPE; + /* Parse and get snd_info */ + err = sctp_sendmsg_parse(sk, &cmsgs, &_sinfo, msg, msg_len); + if (err) goto out_nounlock; - } - /* Parse out the SCTP CMSGs. */ - err = sctp_msghdr_parse(msg, &cmsgs); - if (err) { - pr_debug("%s: msghdr parse err:%x\n", __func__, err); - goto out_nounlock; - } + sinfo = &_sinfo; + sinfo_flags = sinfo->sinfo_flags; /* Get daddr from msg */ daddr = sctp_sendmsg_get_daddr(sk, msg, &cmsgs); @@ -1880,58 +1919,6 @@ static int sctp_sendmsg(struct sock *sk, struct msghdr *msg, size_t msg_len) goto out_nounlock; } - sinit = cmsgs.init; - if (cmsgs.sinfo != NULL) { - memset(&default_sinfo, 0, sizeof(default_sinfo)); - default_sinfo.sinfo_stream = cmsgs.sinfo->snd_sid; - default_sinfo.sinfo_flags = cmsgs.sinfo->snd_flags; - default_sinfo.sinfo_ppid = cmsgs.sinfo->snd_ppid; - default_sinfo.sinfo_context = cmsgs.sinfo->snd_context; - default_sinfo.sinfo_assoc_id = cmsgs.sinfo->snd_assoc_id; - - sinfo = &default_sinfo; - fill_sinfo_ttl = true; - } else { - sinfo = cmsgs.srinfo; - } - /* Did the user specify SNDINFO/SNDRCVINFO? */ - if (sinfo) { - sinfo_flags = sinfo->sinfo_flags; - associd = sinfo->sinfo_assoc_id; - } - - pr_debug("%s: msg_len:%zu, sinfo_flags:0x%x\n", __func__, - msg_len, sinfo_flags); - - /* SCTP_EOF or SCTP_ABORT cannot be set on a TCP-style socket. */ - if (sctp_style(sk, TCP) && (sinfo_flags & (SCTP_EOF | SCTP_ABORT))) { - err = -EINVAL; - goto out_nounlock; - } - - /* If SCTP_EOF is set, no data can be sent. Disallow sending zero - * length messages when SCTP_EOF|SCTP_ABORT is not set. - * If SCTP_ABORT is set, the message length could be non zero with - * the msg_iov set to the user abort reason. - */ - if (((sinfo_flags & SCTP_EOF) && (msg_len > 0)) || - (!(sinfo_flags & (SCTP_EOF|SCTP_ABORT)) && (msg_len == 0))) { - err = -EINVAL; - goto out_nounlock; - } - - /* If SCTP_ADDR_OVER is set, there must be an address - * specified in msg_name. - */ - if ((sinfo_flags & SCTP_ADDR_OVER) && (!msg->msg_name)) { - err = -EINVAL; - goto out_nounlock; - } - - transport = NULL; - - pr_debug("%s: about to look up association\n", __func__); - lock_sock(sk); /* If a msg_name has been specified, assume this is to be used. */ @@ -1964,36 +1951,15 @@ static int sctp_sendmsg(struct sock *sk, struct msghdr *msg, size_t msg_len) new_asoc = asoc; } - /* ASSERT: we have a valid association at this point. */ - pr_debug("%s: we have a valid association\n", __func__); + if (!cmsgs.srinfo && !cmsgs.sinfo) { + sinfo->sinfo_stream = asoc->default_stream; + sinfo->sinfo_ppid = asoc->default_ppid; + sinfo->sinfo_context = asoc->default_context; + sinfo->sinfo_assoc_id = sctp_assoc2id(asoc); + } - if (!sinfo) { - /* If the user didn't specify SNDINFO/SNDRCVINFO, make up - * one with some defaults. - */ - memset(&default_sinfo, 0, sizeof(default_sinfo)); - default_sinfo.sinfo_stream = asoc->default_stream; - default_sinfo.sinfo_flags = asoc->default_flags; - default_sinfo.sinfo_ppid = asoc->default_ppid; - default_sinfo.sinfo_context = asoc->default_context; - default_sinfo.sinfo_timetolive = asoc->default_timetolive; - default_sinfo.sinfo_assoc_id = sctp_assoc2id(asoc); - - sinfo = &default_sinfo; - } else if (fill_sinfo_ttl) { - /* In case SNDINFO was specified, we still need to fill - * it with a default ttl from the assoc here. - */ + if (!cmsgs.srinfo) sinfo->sinfo_timetolive = asoc->default_timetolive; - } - - /* API 7.1.7, the sndbuf size per association bounds the - * maximum size of data that can be sent in a single send call. - */ - if (msg_len > sk->sk_sndbuf) { - err = -EMSGSIZE; - goto out_free; - } /* If an address is passed with the sendto/sendmsg call, it is used * to override the primary destination address in the TCP model, or