From c888c933a930ee2ba4e7bb0bf6678aaf45a9778a Mon Sep 17 00:00:00 2001
From: Thomas Desveaux <thomas.desveaux@dont-nod.com>
Date: Tue, 4 Jun 2024 08:45:56 +0200
Subject: [PATCH] Fix NuGet Package API for $filter with Id equality  (#31188)

Fixes issue when running `choco info pkgname` where `pkgname` is also a
substring of another package Id.

Relates to #31168

---

This might fix the issue linked, but I'd like to test it with more choco
commands before closing the issue in case I find other problems if
that's ok.

---------

Co-authored-by: KN4CK3R <admin@oldschoolhack.me>
---
 routers/api/packages/nuget/nuget.go          |  48 +++++----
 tests/integration/api_packages_nuget_test.go | 102 ++++++++++++++++---
 2 files changed, 115 insertions(+), 35 deletions(-)

diff --git a/routers/api/packages/nuget/nuget.go b/routers/api/packages/nuget/nuget.go
index 26b0ae226e..3633d0d007 100644
--- a/routers/api/packages/nuget/nuget.go
+++ b/routers/api/packages/nuget/nuget.go
@@ -96,20 +96,34 @@ func FeedCapabilityResource(ctx *context.Context) {
 	xmlResponse(ctx, http.StatusOK, Metadata)
 }
 
-var searchTermExtract = regexp.MustCompile(`'([^']+)'`)
+var (
+	searchTermExtract = regexp.MustCompile(`'([^']+)'`)
+	searchTermExact   = regexp.MustCompile(`\s+eq\s+'`)
+)
 
-func getSearchTerm(ctx *context.Context) string {
+func getSearchTerm(ctx *context.Context) packages_model.SearchValue {
 	searchTerm := strings.Trim(ctx.FormTrim("searchTerm"), "'")
-	if searchTerm == "" {
-		// $filter contains a query like:
-		// (((Id ne null) and substringof('microsoft',tolower(Id)))
-		// We don't support these queries, just extract the search term.
-		match := searchTermExtract.FindStringSubmatch(ctx.FormTrim("$filter"))
-		if len(match) == 2 {
-			searchTerm = strings.TrimSpace(match[1])
+	if searchTerm != "" {
+		return packages_model.SearchValue{
+			Value:      searchTerm,
+			ExactMatch: false,
 		}
 	}
-	return searchTerm
+
+	// $filter contains a query like:
+	// (((Id ne null) and substringof('microsoft',tolower(Id)))
+	// https://www.odata.org/documentation/odata-version-2-0/uri-conventions/ section 4.5
+	// We don't support these queries, just extract the search term.
+	filter := ctx.FormTrim("$filter")
+	match := searchTermExtract.FindStringSubmatch(filter)
+	if len(match) == 2 {
+		return packages_model.SearchValue{
+			Value:      strings.TrimSpace(match[1]),
+			ExactMatch: searchTermExact.MatchString(filter),
+		}
+	}
+
+	return packages_model.SearchValue{}
 }
 
 // https://github.com/NuGet/NuGet.Client/blob/dev/src/NuGet.Core/NuGet.Protocol/LegacyFeed/V2FeedQueryBuilder.cs
@@ -118,11 +132,9 @@ func SearchServiceV2(ctx *context.Context) {
 	paginator := db.NewAbsoluteListOptions(skip, take)
 
 	pvs, total, err := packages_model.SearchLatestVersions(ctx, &packages_model.PackageSearchOptions{
-		OwnerID: ctx.Package.Owner.ID,
-		Type:    packages_model.TypeNuGet,
-		Name: packages_model.SearchValue{
-			Value: getSearchTerm(ctx),
-		},
+		OwnerID:    ctx.Package.Owner.ID,
+		Type:       packages_model.TypeNuGet,
+		Name:       getSearchTerm(ctx),
 		IsInternal: optional.Some(false),
 		Paginator:  paginator,
 	})
@@ -169,10 +181,8 @@ func SearchServiceV2(ctx *context.Context) {
 // http://docs.oasis-open.org/odata/odata/v4.0/errata03/os/complete/part2-url-conventions/odata-v4.0-errata03-os-part2-url-conventions-complete.html#_Toc453752351
 func SearchServiceV2Count(ctx *context.Context) {
 	count, err := nuget_model.CountPackages(ctx, &packages_model.PackageSearchOptions{
-		OwnerID: ctx.Package.Owner.ID,
-		Name: packages_model.SearchValue{
-			Value: getSearchTerm(ctx),
-		},
+		OwnerID:    ctx.Package.Owner.ID,
+		Name:       getSearchTerm(ctx),
 		IsInternal: optional.Some(false),
 	})
 	if err != nil {
diff --git a/tests/integration/api_packages_nuget_test.go b/tests/integration/api_packages_nuget_test.go
index 83947ff967..630b4de3f9 100644
--- a/tests/integration/api_packages_nuget_test.go
+++ b/tests/integration/api_packages_nuget_test.go
@@ -429,22 +429,33 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`)
 
 	t.Run("SearchService", func(t *testing.T) {
 		cases := []struct {
-			Query           string
-			Skip            int
-			Take            int
-			ExpectedTotal   int64
-			ExpectedResults int
+			Query              string
+			Skip               int
+			Take               int
+			ExpectedTotal      int64
+			ExpectedResults    int
+			ExpectedExactMatch bool
 		}{
-			{"", 0, 0, 1, 1},
-			{"", 0, 10, 1, 1},
-			{"gitea", 0, 10, 0, 0},
-			{"test", 0, 10, 1, 1},
-			{"test", 1, 10, 1, 0},
+			{"", 0, 0, 4, 4, false},
+			{"", 0, 10, 4, 4, false},
+			{"gitea", 0, 10, 0, 0, false},
+			{"test", 0, 10, 1, 1, false},
+			{"test", 1, 10, 1, 0, false},
+			{"almost.similar", 0, 0, 3, 3, true},
 		}
 
-		req := NewRequestWithBody(t, "PUT", url, createPackage(packageName, "1.0.99")).
-			AddBasicAuth(user.Name)
-		MakeRequest(t, req, http.StatusCreated)
+		fakePackages := []string{
+			packageName,
+			"almost.similar.dependency",
+			"almost.similar",
+			"almost.similar.dependant",
+		}
+
+		for _, fakePackageName := range fakePackages {
+			req := NewRequestWithBody(t, "PUT", url, createPackage(fakePackageName, "1.0.99")).
+				AddBasicAuth(user.Name)
+			MakeRequest(t, req, http.StatusCreated)
+		}
 
 		t.Run("v2", func(t *testing.T) {
 			t.Run("Search()", func(t *testing.T) {
@@ -491,6 +502,63 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`)
 				}
 			})
 
+			t.Run("Packages()", func(t *testing.T) {
+				defer tests.PrintCurrentTest(t)()
+
+				t.Run("substringof", func(t *testing.T) {
+					defer tests.PrintCurrentTest(t)()
+
+					for i, c := range cases {
+						req := NewRequest(t, "GET", fmt.Sprintf("%s/Packages()?$filter=substringof('%s',tolower(Id))&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
+							AddBasicAuth(user.Name)
+						resp := MakeRequest(t, req, http.StatusOK)
+
+						var result FeedResponse
+						decodeXML(t, resp, &result)
+
+						assert.Equal(t, c.ExpectedTotal, result.Count, "case %d: unexpected total hits", i)
+						assert.Len(t, result.Entries, c.ExpectedResults, "case %d: unexpected result count", i)
+
+						req = NewRequest(t, "GET", fmt.Sprintf("%s/Packages()/$count?$filter=substringof('%s',tolower(Id))&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
+							AddBasicAuth(user.Name)
+						resp = MakeRequest(t, req, http.StatusOK)
+
+						assert.Equal(t, strconv.FormatInt(c.ExpectedTotal, 10), resp.Body.String(), "case %d: unexpected total hits", i)
+					}
+				})
+
+				t.Run("IdEq", func(t *testing.T) {
+					defer tests.PrintCurrentTest(t)()
+
+					for i, c := range cases {
+						if c.Query == "" {
+							// Ignore the `tolower(Id) eq ''` as it's unlikely to happen
+							continue
+						}
+						req := NewRequest(t, "GET", fmt.Sprintf("%s/Packages()?$filter=(tolower(Id) eq '%s')&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
+							AddBasicAuth(user.Name)
+						resp := MakeRequest(t, req, http.StatusOK)
+
+						var result FeedResponse
+						decodeXML(t, resp, &result)
+
+						expectedCount := 0
+						if c.ExpectedExactMatch {
+							expectedCount = 1
+						}
+
+						assert.Equal(t, int64(expectedCount), result.Count, "case %d: unexpected total hits", i)
+						assert.Len(t, result.Entries, expectedCount, "case %d: unexpected result count", i)
+
+						req = NewRequest(t, "GET", fmt.Sprintf("%s/Packages()/$count?$filter=(tolower(Id) eq '%s')&$skip=%d&$top=%d", url, c.Query, c.Skip, c.Take)).
+							AddBasicAuth(user.Name)
+						resp = MakeRequest(t, req, http.StatusOK)
+
+						assert.Equal(t, strconv.FormatInt(int64(expectedCount), 10), resp.Body.String(), "case %d: unexpected total hits", i)
+					}
+				})
+			})
+
 			t.Run("Next", func(t *testing.T) {
 				req := NewRequest(t, "GET", fmt.Sprintf("%s/Search()?searchTerm='test'&$skip=0&$top=1", url)).
 					AddBasicAuth(user.Name)
@@ -548,9 +616,11 @@ AAAjQmxvYgAAAGm7ENm9SGxMtAFVvPUsPJTF6PbtAAAAAFcVogEJAAAAAQAAAA==`)
 			})
 		})
 
-		req = NewRequest(t, "DELETE", fmt.Sprintf("%s/%s/%s", url, packageName, "1.0.99")).
-			AddBasicAuth(user.Name)
-		MakeRequest(t, req, http.StatusNoContent)
+		for _, fakePackageName := range fakePackages {
+			req := NewRequest(t, "DELETE", fmt.Sprintf("%s/%s/%s", url, fakePackageName, "1.0.99")).
+				AddBasicAuth(user.Name)
+			MakeRequest(t, req, http.StatusNoContent)
+		}
 	})
 
 	t.Run("RegistrationService", func(t *testing.T) {