// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package codeformat
import (
"bytes"
"errors"
"io"
"os"
"sort"
"strings"
)
var (
	// importPackageGroupOrders ranks the import groups that have a fixed
	// position in a formatted import block; a lower rank is written earlier.
	// The "" key is the internal group (paths without a domain element); every
	// other key is matched as a package-path prefix. Groups that are not listed
	// here are written after all listed ones.
	importPackageGroupOrders = map[string]int{
		"":                     1, // internal
		"code.gitea.io/gitea/": 2,
	}

	// errInvalidCommentBetweenImports is returned when a line inside the import
	// block is a standalone comment.
	errInvalidCommentBetweenImports = errors.New("comments between imported packages are invalid, please move comments to the end of the package line")
)
// importBlockBegin marks the opening of a parenthesized import block; the
// block body starts right after it.
var importBlockBegin = []byte("\nimport (\n")

// importBlockEnd marks the closing parenthesis of the import block, which is
// expected on its own line.
var importBlockEnd = []byte("\n)")
// importLineParsed describes a single line of an import block.
type importLineParsed struct {
	group   string // section the import is written to; "" is the internal group
	pkg     string // import path, without the surrounding quotes
	content string // the whole trimmed line, including any alias or trailing comment
}
// parseImportLine parses one trimmed line of an import block, such as
// `"fmt"`, `alias "example.com/pkg"` or `_ "example.com/pkg"`.
//
// The package path is the text between the first pair of double quotes.
// The returned group selects the blank-line separated section the import is
// written to:
//   - "" (internal) when the first path element has no dot, e.g. "fmt";
//   - "domain-package" when the first path element contains a dot, e.g.
//     "github.com/x/y" or a bare "example.com";
//   - a key of importPackageGroupOrders when the path starts with that key.
//     If several keys match, the longest one wins.
//
// An error is returned when the line does not contain a quoted path.
func parseImportLine(line string) (*importLineParsed, error) {
	il := &importLineParsed{content: line}

	p1 := strings.IndexByte(line, '"')
	if p1 == -1 {
		return nil, errors.New("invalid import line: " + line)
	}
	p1++
	p := strings.IndexByte(line[p1:], '"')
	if p == -1 {
		return nil, errors.New("invalid import line: " + line)
	}
	il.pkg = line[p1 : p1+p]

	// Only the first path element decides whether the path is domain-based.
	// A path without any slash is a single element, so a dot in it still
	// counts. The previous "dot before slash" test missed that case.
	firstElem := il.pkg
	if i := strings.IndexByte(firstElem, '/'); i != -1 {
		firstElem = firstElem[:i]
	}
	if strings.Contains(firstElem, ".") {
		il.group = "domain-package"
	}

	// Map iteration order is random, so choose the longest matching prefix.
	// This keeps the result deterministic if overlapping groups are ever
	// configured.
	matched := ""
	for groupName := range importPackageGroupOrders {
		if groupName == "" {
			continue // the internal group is the default, not a prefix
		}
		if strings.HasPrefix(il.pkg, groupName) && len(groupName) > len(matched) {
			matched = groupName
		}
	}
	if matched != "" {
		il.group = matched
	}
	return il, nil
}
// importLineGroup holds the parsed lines that belong to a single import group.
type importLineGroup []*importLineParsed

// importLineGroupMap maps a group name (as assigned by parseImportLine) to the
// lines in that group.
type importLineGroupMap map[string]importLineGroup
// formatGoImports rewrites the first parenthesized import block of a Go source
// file according to the project's grouping rules.
//
// Every non-blank line of the block is parsed with parseImportLine and bucketed
// by its group. Group ordering:
//   - groups ranked in importPackageGroupOrders come first, by rank;
//   - unranked groups follow, alphabetically.
//
// Inside a group, imports are sorted by path (ties broken by the full line).
// Regular imports are written before blank imports (`_ "pkg"`), with an empty
// line between the two kinds. Groups are separated by an empty line.
//
// It returns (nil, nil) when there is no complete, non-empty import block to
// format. It returns errInvalidCommentBetweenImports when a line inside the
// block is a standalone comment, and parseImportLine's error for an
// unparsable line.
func formatGoImports(contentBytes []byte) ([]byte, error) {
	p1 := bytes.Index(contentBytes, importBlockBegin)
	if p1 == -1 {
		return nil, nil
	}
	p1 += len(importBlockBegin)

	// In an empty block ("import (\n)"), the only newline before ")" belongs
	// to importBlockBegin. Searching for importBlockEnd from p1 would then run
	// past the block and latch onto some later "\n)". There is nothing to sort
	// anyway, so stop here.
	if bytes.HasPrefix(bytes.TrimLeft(contentBytes[p1:], " \t"), []byte(")")) {
		return nil, nil
	}
	p := bytes.Index(contentBytes[p1:], importBlockEnd)
	if p == -1 {
		return nil, nil
	}
	p2 := p1 + p

	// Bucket every non-blank line of the block into its import group.
	importGroups := importLineGroupMap{}
	r := bytes.NewBuffer(contentBytes[p1:p2])
	for eof := false; !eof; {
		line, err := r.ReadString('\n')
		eof = err == io.EOF
		if err != nil && !eof {
			return nil, err
		}
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		// Standalone comments are rejected; the error asks the author to move
		// them to the end of an import line instead.
		if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "/*") {
			return nil, errInvalidCommentBetweenImports
		}
		importLine, err := parseImportLine(line)
		if err != nil {
			return nil, err
		}
		importGroups[importLine.group] = append(importGroups[importLine.group], importLine)
	}

	groupNames := make([]string, 0, len(importGroups))
	for groupName, importLines := range importGroups {
		groupNames = append(groupNames, groupName)
		sort.Slice(importLines, func(i, j int) bool {
			a, b := importLines[i], importLines[j]
			if a.pkg != b.pkg {
				return a.pkg < b.pkg
			}
			// sort.Slice is not stable. Tie-break on the whole line so that
			// one package imported under several names always comes out in
			// the same order.
			return a.content < b.content
		})
	}
	sort.Slice(groupNames, func(i, j int) bool {
		n1, n2 := groupNames[i], groupNames[j]
		o1, o2 := importPackageGroupOrders[n1], importPackageGroupOrders[n2]
		switch {
		case o1 != 0 && o2 != 0:
			return o1 < o2
		case o1 == 0 && o2 == 0:
			return n1 < n2
		default:
			return o1 != 0 // ranked groups precede unranked ones
		}
	})

	var formattedBlock bytes.Buffer
	for _, groupName := range groupNames {
		hasNormalImports, hasDummyImports := false, false
		// Regular imports come first...
		for _, importLine := range importGroups[groupName] {
			if strings.HasPrefix(importLine.content, "_") {
				hasDummyImports = true
				continue
			}
			formattedBlock.WriteString("\t" + importLine.content + "\n")
			hasNormalImports = true
		}
		// ...then blank imports (_ "pkg"), separated from them by an empty line.
		if hasDummyImports {
			if hasNormalImports {
				formattedBlock.WriteString("\n")
			}
			for _, importLine := range importGroups[groupName] {
				if strings.HasPrefix(importLine.content, "_") {
					formattedBlock.WriteString("\t" + importLine.content + "\n")
				}
			}
		}
		formattedBlock.WriteString("\n") // empty line between groups
	}
	formattedBlockBytes := bytes.TrimRight(formattedBlock.Bytes(), "\n")

	formattedBytes := make([]byte, 0, p1+len(formattedBlockBytes)+len(contentBytes)-p2)
	formattedBytes = append(formattedBytes, contentBytes[:p1]...)
	formattedBytes = append(formattedBytes, formattedBlockBytes...)
	formattedBytes = append(formattedBytes, contentBytes[p2:]...)
	return formattedBytes, nil
}
// FormatGoImports format the imports by our rules (see unit tests)
2022-10-16 11:11:17 +03:00
func FormatGoImports ( file string , doWriteFile bool ) error {
2021-11-17 15:34:35 +03:00
f , err := os . Open ( file )
if err != nil {
return err
}
var contentBytes [ ] byte
{
defer f . Close ( )
contentBytes , err = io . ReadAll ( f )
if err != nil {
return err
}
}
formattedBytes , err := formatGoImports ( contentBytes )
if err != nil {
return err
}
if formattedBytes == nil {
return nil
}
if bytes . Equal ( contentBytes , formattedBytes ) {
return nil
}
2022-02-06 18:44:30 +03:00
if doWriteFile {
f , err = os . OpenFile ( file , os . O_TRUNC | os . O_WRONLY , 0 o644 )
if err != nil {
return err
}
defer f . Close ( )
_ , err = f . Write ( formattedBytes )
2021-11-17 15:34:35 +03:00
return err
}
2022-02-06 18:44:30 +03:00
2021-11-17 15:34:35 +03:00
return err
}