diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index cab8f3491c2b..7eb6a8754f19 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -355,17 +355,21 @@ static inline void sk_psock_restore_proto(struct sock *sk,
 					  struct sk_psock *psock)
 {
 	sk->sk_prot->unhash = psock->saved_unhash;
-	sk->sk_write_space = psock->saved_write_space;
 
 	if (psock->sk_proto) {
 		struct inet_connection_sock *icsk = inet_csk(sk);
 		bool has_ulp = !!icsk->icsk_ulp_data;
 
-		if (has_ulp)
-			tcp_update_ulp(sk, psock->sk_proto);
-		else
+		if (has_ulp) {
+			tcp_update_ulp(sk, psock->sk_proto,
+				       psock->saved_write_space);
+		} else {
 			sk->sk_prot = psock->sk_proto;
+			sk->sk_write_space = psock->saved_write_space;
+		}
 		psock->sk_proto = NULL;
+	} else {
+		sk->sk_write_space = psock->saved_write_space;
 	}
 }
 
diff --git a/include/net/tcp.h b/include/net/tcp.h
index b2367cfe0bda..830c89db1245 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -2132,7 +2132,8 @@ struct tcp_ulp_ops {
 	/* initialize ulp */
 	int (*init)(struct sock *sk);
 	/* update ulp */
-	void (*update)(struct sock *sk, struct proto *p);
+	void (*update)(struct sock *sk, struct proto *p,
+		       void (*write_space)(struct sock *sk));
 	/* cleanup ulp */
 	void (*release)(struct sock *sk);
 	/* diagnostic */
@@ -2147,7 +2148,8 @@ void tcp_unregister_ulp(struct tcp_ulp_ops *type);
 int tcp_set_ulp(struct sock *sk, const char *name);
 void tcp_get_available_ulp(char *buf, size_t len);
 void tcp_cleanup_ulp(struct sock *sk);
-void tcp_update_ulp(struct sock *sk, struct proto *p);
+void tcp_update_ulp(struct sock *sk, struct proto *p,
+		    void (*write_space)(struct sock *sk));
 
 #define MODULE_ALIAS_TCP_ULP(name)				\
 	__MODULE_INFO(alias, alias_userspace, name);		\
diff --git a/net/ipv4/tcp_ulp.c b/net/ipv4/tcp_ulp.c
index 4849edb62d52..9168645b760e 100644
--- a/net/ipv4/tcp_ulp.c
+++ b/net/ipv4/tcp_ulp.c
@@ -96,17 +96,19 @@ void tcp_get_available_ulp(char *buf, size_t maxlen)
 	rcu_read_unlock();
 }
 
-void tcp_update_ulp(struct sock *sk, struct proto *proto)
+void tcp_update_ulp(struct sock *sk, struct proto *proto,
+		    void (*write_space)(struct sock *sk))
 {
 	struct inet_connection_sock *icsk = inet_csk(sk);
 
 	if (!icsk->icsk_ulp_ops) {
+		sk->sk_write_space = write_space;
 		sk->sk_prot = proto;
 		return;
 	}
 
 	if (icsk->icsk_ulp_ops->update)
-		icsk->icsk_ulp_ops->update(sk, proto);
+		icsk->icsk_ulp_ops->update(sk, proto, write_space);
 }
 
 void tcp_cleanup_ulp(struct sock *sk)
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 82d0beed8f07..7aba4ee77aba 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -798,15 +798,19 @@ out:
 	return rc;
 }
 
-static void tls_update(struct sock *sk, struct proto *p)
+static void tls_update(struct sock *sk, struct proto *p,
+		       void (*write_space)(struct sock *sk))
 {
 	struct tls_context *ctx;
 
 	ctx = tls_get_ctx(sk);
-	if (likely(ctx))
+	if (likely(ctx)) {
+		ctx->sk_write_space = write_space;
 		ctx->sk_proto = p;
-	else
+	} else {
 		sk->sk_prot = p;
+		sk->sk_write_space = write_space;
+	}
 }
 
 static int tls_get_info(const struct sock *sk, struct sk_buff *skb)