2019-10-15 14:39:51 +01:00
// Copyright 2019 The Gitea Authors. All rights reserved.
2022-11-27 13:20:29 -05:00
// SPDX-License-Identifier: MIT
2019-10-15 14:39:51 +01:00
// This code is heavily inspired by the archived gofacebook/gracenet/net.go handler
2021-08-24 11:47:09 -05:00
//go:build !windows
2019-10-15 14:39:51 +01:00
package graceful
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"sync"
"code.gitea.io/gitea/modules/log"
2019-11-24 02:11:24 +00:00
"code.gitea.io/gitea/modules/setting"
2020-08-11 21:05:34 +01:00
"code.gitea.io/gitea/modules/util"
2019-10-15 14:39:51 +01:00
)
// listenFDs is the environment variable that carries the number of listener
// file descriptors inherited from the parent process.
const listenFDs = "LISTEN_FDS"

// startFD is the first inherited descriptor: descriptors startFD through
// startFD+N-1 are treated as listeners, where N is the value of listenFDs.
const startFD = 3

// unlinkFDs is the environment variable holding a comma-separated list of
// 0-based indexes (relative to startFD) of inherited listeners whose socket
// files should be unlinked when they are closed.
const unlinkFDs = "GITEA_UNLINK_FDS"
// originalWD is the working directory captured once at package
// initialisation, so that the directory the process started in is known later.
// If os.Getwd fails the error is deliberately ignored and originalWD is "".
var originalWD = func() string {
	wd, _ := os.Getwd()
	return wd
}()
// Package-level listener bookkeeping. Within this file the slices below are
// only read or written while holding mutex.
var (
	// once ensures inherited descriptors are processed a single time.
	once = sync.Once{}
	// mutex guards the four listener slices below.
	mutex = sync.Mutex{}

	// providedListeners are inherited listeners not yet handed out;
	// providedListenersToUnlink[i] is the unlink-on-close flag for
	// providedListeners[i] (honoured for unix sockets, see GetListenerUnix).
	providedListeners         = []net.Listener{}
	providedListenersToUnlink = []bool{}

	// activeListeners are listeners handed out to callers; the flag at the
	// same index of activeListenersToUnlink says whether it needs unlinking.
	activeListeners         = []net.Listener{}
	activeListenersToUnlink = []bool{}
)
func getProvidedFDs ( ) ( savedErr error ) {
// Only inherit the provided FDS once but we will save the error so that repeated calls to this function will return the same error
once . Do ( func ( ) {
mutex . Lock ( )
defer mutex . Unlock ( )
numFDs := os . Getenv ( listenFDs )
if numFDs == "" {
return
}
n , err := strconv . Atoi ( numFDs )
if err != nil {
2022-10-24 21:29:17 +02:00
savedErr = fmt . Errorf ( "%s is not a number: %s. Err: %w" , listenFDs , numFDs , err )
2019-10-15 14:39:51 +01:00
return
}
2022-08-14 05:31:33 +08:00
fdsToUnlinkStr := strings . Split ( os . Getenv ( unlinkFDs ) , "," )
providedListenersToUnlink = make ( [ ] bool , n )
for _ , fdStr := range fdsToUnlinkStr {
i , err := strconv . Atoi ( fdStr )
if err != nil || i < 0 || i >= n {
continue
}
providedListenersToUnlink [ i ] = true
}
2019-10-15 14:39:51 +01:00
for i := startFD ; i < n + startFD ; i ++ {
file := os . NewFile ( uintptr ( i ) , fmt . Sprintf ( "listener_FD%d" , i ) )
l , err := net . FileListener ( file )
if err == nil {
// Close the inherited file if it's a listener
if err = file . Close ( ) ; err != nil {
savedErr = fmt . Errorf ( "error closing provided socket fd %d: %s" , i , err )
return
}
providedListeners = append ( providedListeners , l )
continue
}
// If needed we can handle packetconns here.
2022-10-24 21:29:17 +02:00
savedErr = fmt . Errorf ( "Error getting provided socket fd %d: %w" , i , err )
2019-10-15 14:39:51 +01:00
return
}
} )
return savedErr
}
// CloseProvidedListeners closes all unused provided listeners.
func CloseProvidedListeners ( ) error {
mutex . Lock ( )
defer mutex . Unlock ( )
var returnableError error
for _ , l := range providedListeners {
err := l . Close ( )
if err != nil {
log . Error ( "Error in closing unused provided listener: %v" , err )
if returnableError != nil {
2022-10-24 21:29:17 +02:00
returnableError = fmt . Errorf ( "%v & %w" , returnableError , err )
2019-10-15 14:39:51 +01:00
} else {
returnableError = err
}
}
}
providedListeners = [ ] net . Listener { }
return returnableError
}
// GetListener obtains a listener for the local network address. The network must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It
// returns an provided net.Listener for the matching network and address, or
// creates a new one using net.Listen.
func GetListener ( network , address string ) ( net . Listener , error ) {
// Add a deferral to say that we've tried to grab a listener
2019-12-15 09:51:28 +00:00
defer GetManager ( ) . InformCleanup ( )
2019-10-15 14:39:51 +01:00
switch network {
case "tcp" , "tcp4" , "tcp6" :
tcpAddr , err := net . ResolveTCPAddr ( network , address )
if err != nil {
return nil , err
}
return GetListenerTCP ( network , tcpAddr )
case "unix" , "unixpacket" :
unixAddr , err := net . ResolveUnixAddr ( network , address )
if err != nil {
return nil , err
}
return GetListenerUnix ( network , unixAddr )
default :
return nil , net . UnknownNetworkError ( network )
}
}
// GetListenerTCP announces on the local network address. The network must be:
// "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the
// matching network and address, or creates a new one using net.ListenTCP.
func GetListenerTCP ( network string , address * net . TCPAddr ) ( * net . TCPListener , error ) {
if err := getProvidedFDs ( ) ; err != nil {
return nil , err
}
mutex . Lock ( )
defer mutex . Unlock ( )
// look for a provided listener
for i , l := range providedListeners {
if isSameAddr ( l . Addr ( ) , address ) {
providedListeners = append ( providedListeners [ : i ] , providedListeners [ i + 1 : ] ... )
2022-08-14 05:31:33 +08:00
needsUnlink := providedListenersToUnlink [ i ]
providedListenersToUnlink = append ( providedListenersToUnlink [ : i ] , providedListenersToUnlink [ i + 1 : ] ... )
2019-10-15 14:39:51 +01:00
activeListeners = append ( activeListeners , l )
2022-08-14 05:31:33 +08:00
activeListenersToUnlink = append ( activeListenersToUnlink , needsUnlink )
2019-10-15 14:39:51 +01:00
return l . ( * net . TCPListener ) , nil
}
}
// no provided listener for this address -> make a fresh listener
l , err := net . ListenTCP ( network , address )
if err != nil {
return nil , err
}
activeListeners = append ( activeListeners , l )
2022-08-14 05:31:33 +08:00
activeListenersToUnlink = append ( activeListenersToUnlink , false )
2019-10-15 14:39:51 +01:00
return l , nil
}
// GetListenerUnix announces on the local network address. The network must be:
// "unix" or "unixpacket". It returns a provided net.Listener for the
// matching network and address, or creates a new one using net.ListenUnix.
func GetListenerUnix ( network string , address * net . UnixAddr ) ( * net . UnixListener , error ) {
if err := getProvidedFDs ( ) ; err != nil {
return nil , err
}
mutex . Lock ( )
defer mutex . Unlock ( )
// look for a provided listener
for i , l := range providedListeners {
if isSameAddr ( l . Addr ( ) , address ) {
providedListeners = append ( providedListeners [ : i ] , providedListeners [ i + 1 : ] ... )
2022-08-14 05:31:33 +08:00
needsUnlink := providedListenersToUnlink [ i ]
providedListenersToUnlink = append ( providedListenersToUnlink [ : i ] , providedListenersToUnlink [ i + 1 : ] ... )
activeListenersToUnlink = append ( activeListenersToUnlink , needsUnlink )
2019-10-15 14:39:51 +01:00
activeListeners = append ( activeListeners , l )
2019-11-24 02:11:24 +00:00
unixListener := l . ( * net . UnixListener )
2022-08-14 05:31:33 +08:00
if needsUnlink {
unixListener . SetUnlinkOnClose ( true )
}
2019-11-24 02:11:24 +00:00
return unixListener , nil
2019-10-15 14:39:51 +01:00
}
}
// make a fresh listener
2020-08-11 21:05:34 +01:00
if err := util . Remove ( address . Name ) ; err != nil && ! os . IsNotExist ( err ) {
2022-10-24 21:29:17 +02:00
return nil , fmt . Errorf ( "Failed to remove unix socket %s: %w" , address . Name , err )
2019-11-24 02:11:24 +00:00
}
2019-10-15 14:39:51 +01:00
l , err := net . ListenUnix ( network , address )
if err != nil {
return nil , err
}
2019-11-24 02:11:24 +00:00
fileMode := os . FileMode ( setting . UnixSocketPermission )
if err = os . Chmod ( address . Name , fileMode ) ; err != nil {
2022-10-24 21:29:17 +02:00
return nil , fmt . Errorf ( "Failed to set permission of unix socket to %s: %w" , fileMode . String ( ) , err )
2019-11-24 02:11:24 +00:00
}
2019-10-15 14:39:51 +01:00
activeListeners = append ( activeListeners , l )
2022-08-14 05:31:33 +08:00
activeListenersToUnlink = append ( activeListenersToUnlink , true )
2019-10-15 14:39:51 +01:00
return l , nil
}
// isSameAddr reports whether a1 and a2 denote the same listening address.
// Addresses on different networks never match. Otherwise they match if their
// string forms are identical, or if they are identical after stripping a
// leading "[::]" and then a leading "0.0.0.0". This lets an IPv6 wildcard
// address compare equal to its IPv4 counterpart (a common case when
// listening on localhost).
func isSameAddr(a1, a2 net.Addr) bool {
	if a1.Network() != a2.Network() {
		return false
	}

	first, second := a1.String(), a2.String()
	if first == second {
		return true
	}

	normalise := func(s string) string {
		s = strings.TrimPrefix(s, "[::]")
		return strings.TrimPrefix(s, "0.0.0.0")
	}
	return normalise(first) == normalise(second)
}
func getActiveListeners ( ) [ ] net . Listener {
mutex . Lock ( )
defer mutex . Unlock ( )
listeners := make ( [ ] net . Listener , len ( activeListeners ) )
copy ( listeners , activeListeners )
return listeners
}
2022-08-14 05:31:33 +08:00
// getActiveListenersToUnlink returns a copy of activeListenersToUnlink taken
// while holding mutex. Index i holds the unlink flag for the i-th active
// listener. The result is never nil.
func getActiveListenersToUnlink() []bool {
	mutex.Lock()
	defer mutex.Unlock()
	return append(make([]bool, 0, len(activeListenersToUnlink)), activeListenersToUnlink...)
}