DmitriyLewen f1f4c3e8e6
refactor(nvd): use API instead of JSON feeds (#258)
Signed-off-by: knqyf263 <knqyf263@gmail.com>
Co-authored-by: knqyf263 <knqyf263@gmail.com>
2023-12-18 20:34:54 +04:00

233 lines
6.7 KiB
Go

package nvd
import (
"encoding/json"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"time"
"golang.org/x/xerrors"
"github.com/aquasecurity/vuln-list-update/utils"
)
// Defaults and fixed values used by the NVD API updater.
const (
// retry is the default for Updater.retry, the upper bound of the retry loop in fetchURL.
retry = 50
// url20 is the default base URL (NVD CVE API 2.0) that query parameters are added to.
url20 = "https://services.nvd.nist.gov/rest/json/cves/2.0/"
// apiDir is the directory name passed to the last-updated helpers and joined
// onto utils.VulnListDir() when CVE files are saved.
apiDir = "api"
// nvdTimeFormat is the layout (no zone designator) used to format the
// lastModStartDate / lastModEndDate query parameters.
nvdTimeFormat = "2006-01-02T15:04:05"
// maxResultsPerPage is the default page size sent as resultsPerPage.
maxResultsPerPage = 2000
// retryAfter is the default pause after a 403 (rate limit) response.
retryAfter = 30 * time.Second
// apiKeyEnvName is the environment variable NewUpdater reads the API key from.
apiKeyEnvName = "NVD_API_KEY"
)
// Option mutates an Updater; NewUpdater applies options in the order they are passed.
type Option func(*Updater)
// WithLastModEndDate returns an Option that sets the Updater's
// lastModEndDate field to the given time.
func WithLastModEndDate(lastModEndDate time.Time) Option {
	apply := func(target *Updater) {
		target.lastModEndDate = lastModEndDate
	}
	return apply
}
// WithMaxResultsPerPage returns an Option that sets the page size the
// Updater requests (and advances its start index by).
func WithMaxResultsPerPage(maxResultsPerPage int) Option {
	apply := func(target *Updater) {
		target.maxResultsPerPage = maxResultsPerPage
	}
	return apply
}
// WithBaseURL returns an Option that replaces the endpoint the Updater
// sends its requests to.
func WithBaseURL(url string) Option {
	apply := func(target *Updater) {
		target.baseURL = url
	}
	return apply
}
// WithRetry returns an Option that sets the Updater's retry count.
func WithRetry(retry int) Option {
	apply := func(target *Updater) {
		target.retry = retry
	}
	return apply
}
// WithRetryAfter returns an Option that sets how long the Updater pauses
// after a rate-limit (403) response.
func WithRetryAfter(retryAfter time.Duration) Option {
	apply := func(target *Updater) {
		target.retryAfter = retryAfter
	}
	return apply
}
// Updater fetches CVE entries from the NVD API page by page and saves them
// below utils.VulnListDir().
type Updater struct {
// baseURL is the endpoint that the paging and date-range query parameters are added to.
baseURL string
// apiKey, when non-empty, is sent in the "apiKey" request header.
apiKey string
// maxResultsPerPage is sent as resultsPerPage and is also the step of the start index.
maxResultsPerPage int
// retry bounds the request loop in fetchURL (the loop runs while i <= retry).
retry int
// retryAfter is the pause after a 403 response before the next attempt.
retryAfter time.Duration
lastModEndDate time.Time // time.Now() by default
}
// NewUpdater builds an Updater from the package defaults, reading the API
// key from the NVD_API_KEY environment variable and using the current UTC
// time as the end date, and then applies opts in the order given.
func NewUpdater(opts ...Option) *Updater {
	updater := new(Updater)
	updater.baseURL = url20
	updater.apiKey = os.Getenv(apiKeyEnvName)
	updater.maxResultsPerPage = maxResultsPerPage
	updater.retry = retry
	updater.retryAfter = retryAfter
	updater.lastModEndDate = time.Now().UTC()

	for i := range opts {
		opts[i](updater)
	}
	return updater
}
// Update downloads, page by page, every CVE in the time intervals returned
// by TimeIntervals(u.lastModEndDate) and saves them, then stores
// u.lastModEndDate as the new last-updated date.
//
// It returns an error if the configured page size is not positive, if the
// intervals cannot be built, if any page cannot be fetched or saved, or if
// the last-updated date cannot be written. The last-updated date is only
// written after every page has been saved.
func (u Updater) Update() error {
	// The paging loop advances by maxResultsPerPage; a zero or negative step
	// would never reach totalResults and the loop would not terminate.
	if u.maxResultsPerPage <= 0 {
		return xerrors.Errorf("invalid max results per page: %d", u.maxResultsPerPage)
	}

	intervals, err := TimeIntervals(u.lastModEndDate)
	if err != nil {
		return xerrors.Errorf("unable to build time intervals: %w", err)
	}

	for _, interval := range intervals {
		slog.Info("Fetching NVD entries...", slog.String("start", interval.LastModStartDate),
			slog.String("end", interval.LastModEndDate))

		// The real total is only known after the first response, so start
		// with a dummy value that lets the loop run once.
		totalResults := 1
		for startIndex := 0; startIndex < totalResults; startIndex += u.maxResultsPerPage {
			if totalResults, err = u.saveEntry(interval, startIndex); err != nil {
				return xerrors.Errorf("unable to save entry CVEs for %q: %w", interval, err)
			}
			slog.Info("Fetched NVD entries", slog.Int("total", totalResults), slog.Int("start_index", startIndex))
		}
	}

	// Update last_updated.json at the end.
	if err = utils.SetLastUpdatedDate(apiDir, u.lastModEndDate); err != nil {
		return xerrors.Errorf("unable to update last_updated.json file: %w", err)
	}
	return nil
}
// saveEntry fetches one page of results for interval beginning at
// startIndex, writes every returned CVE with utils.SaveCVEPerYear, and
// returns the TotalResults value of the response.
// On any failure it returns 0 and a wrapped error.
func (u Updater) saveEntry(interval TimeInterval, startIndex int) (int, error) {
	pageURL, err := urlWithParams(u.baseURL, startIndex, u.maxResultsPerPage, interval)
	if err != nil {
		return 0, xerrors.Errorf("unable to get url with query parameters: %w", err)
	}

	page, err := u.fetchEntry(pageURL)
	if err != nil {
		return 0, xerrors.Errorf("unable to get entry for %q: %w", pageURL, err)
	}

	for _, item := range page.Vulnerabilities {
		targetDir := filepath.Join(utils.VulnListDir(), apiDir)
		saveErr := utils.SaveCVEPerYear(targetDir, item.Cve.ID, item.Cve)
		if saveErr != nil {
			return 0, xerrors.Errorf("unable to write %s: %w", item.Cve.ID, saveErr)
		}
	}

	return page.TotalResults, nil
}
// fetchEntry requests url through fetchURL and decodes the JSON response
// body into an Entry. The body is closed before returning.
// A zero Entry is returned together with any error.
func (u Updater) fetchEntry(url string) (Entry, error) {
	body, err := u.fetchURL(url)
	switch {
	case err != nil:
		return Entry{}, xerrors.Errorf("unable to fetch: %w", err)
	case body == nil:
		// Guard against a nil reader being returned without an error.
		return Entry{}, xerrors.Errorf("unable to get entry from %q", url)
	}
	defer body.Close()

	var decoded Entry
	decodeErr := json.NewDecoder(body).Decode(&decoded)
	if decodeErr != nil {
		return Entry{}, xerrors.Errorf("unable to decode response for %q: %w", url, decodeErr)
	}
	return decoded, nil
}
// fetchURL sends GET requests to url until one returns 200 OK and hands back
// that response's body; the caller is responsible for closing it.
//
// The loop runs while i <= u.retry. Per attempt:
//   - a transport error is logged and the request is retried immediately;
//   - 403 is treated as the NVD rate limit: sleep u.retryAfter, then retry;
//   - 408, 502, 503 and 504 are retried after sleeping i seconds;
//   - any other non-200 status aborts with an error.
//
// Every response body that is not returned to the caller is drained and
// closed, so retries do not leak connections. An error is returned when the
// retry limit is exhausted.
//
// NOTE(review): the zero-value http.Client has no timeout, so a stalled
// connection blocks the attempt indefinitely.
func (u Updater) fetchURL(url string) (io.ReadCloser, error) {
	var c http.Client

	// discard releases a response we are not going to return. Draining (up to
	// a bound) before closing lets the transport reuse the connection; both
	// steps are best effort, so their errors are deliberately ignored.
	discard := func(body io.ReadCloser) {
		_, _ = io.Copy(io.Discard, io.LimitReader(body, 1<<20))
		_ = body.Close()
	}

	for i := 0; i <= u.retry; i++ {
		req, err := http.NewRequest(http.MethodGet, url, nil)
		if err != nil {
			return nil, xerrors.Errorf("unable to build request for %q: %w", url, err)
		}
		if u.apiKey != "" {
			req.Header.Set("apiKey", u.apiKey)
		}

		resp, err := c.Do(req)
		if err != nil {
			slog.Error("Response error. Try to get the entry again.", slog.String("error", err.Error()))
			continue
		}

		switch resp.StatusCode {
		case http.StatusForbidden:
			discard(resp.Body)
			slog.Error("NVD rate limit. Wait to gain access.")
			// NVD limits:
			// Without API key: 5 requests / 30 seconds window
			// With API key: 50 requests / 30 seconds window
			time.Sleep(u.retryAfter)
			continue
		case http.StatusServiceUnavailable, http.StatusRequestTimeout, http.StatusBadGateway, http.StatusGatewayTimeout:
			discard(resp.Body)
			slog.Error("NVD API is unstable. Try to fetch URL again.", slog.String("status_code", resp.Status))
			// The NVD API is unstable; back off a little longer on each attempt.
			time.Sleep(time.Duration(i) * time.Second)
			continue
		case http.StatusOK:
			return resp.Body, nil
		default:
			discard(resp.Body)
			return nil, xerrors.Errorf("unexpected status code: %s", resp.Status)
		}
	}
	return nil, xerrors.Errorf("unable to fetch url. Retry limit exceeded.")
}
// TimeIntervals builds the list of query windows between the stored
// last-updated date (read via utils.GetLastUpdatedDate) and endTime.
// The NVD API doesn't allow more than 120 days per request, so the range is
// cut into consecutive 120-day windows followed by one final window that
// ends at endTime. Window bounds are formatted with nvdTimeFormat.
func TimeIntervals(endTime time.Time) ([]TimeInterval, error) {
	cursor, err := utils.GetLastUpdatedDate(apiDir)
	if err != nil {
		return nil, xerrors.Errorf("unable to get lastUpdatedDate: %w", err)
	}

	const window = 120 * 24 * time.Hour

	var windows []TimeInterval
	for {
		remainingDays := endTime.Sub(cursor).Hours() / 24
		if !(remainingDays > 120) {
			break
		}
		next := cursor.Add(window)
		windows = append(windows, TimeInterval{
			LastModStartDate: cursor.Format(nvdTimeFormat),
			LastModEndDate:   next.Format(nvdTimeFormat),
		})
		cursor = next
	}

	// The remainder (120 days or less) becomes the last window.
	last := TimeInterval{
		LastModStartDate: cursor.Format(nvdTimeFormat),
		LastModEndDate:   endTime.Format(nvdTimeFormat),
	}
	return append(windows, last), nil
}
// urlWithParams returns baseUrl with the lastModStartDate, lastModEndDate,
// startIndex and resultsPerPage query parameters set. Query parameters
// already present in baseUrl are kept unless they share one of those names.
// It fails if baseUrl cannot be parsed or the encoded query cannot be
// unescaped.
func urlWithParams(baseUrl string, startIndex, resultsPerPage int, interval TimeInterval) (string, error) {
	parsed, err := url.Parse(baseUrl)
	if err != nil {
		return "", xerrors.Errorf("unable to parse %q base url: %w", baseUrl, err)
	}

	params := parsed.Query()
	for _, kv := range [][2]string{
		{"lastModStartDate", interval.LastModStartDate},
		{"lastModEndDate", interval.LastModEndDate},
		{"startIndex", strconv.Itoa(startIndex)},
		{"resultsPerPage", strconv.Itoa(resultsPerPage)},
	} {
		params.Set(kv[0], kv[1])
	}

	// The NVD API doesn't work with an escaped `:`, so the encoded query is
	// unescaped again before it is stored as the raw query.
	// The NVD docs (https://nvd.nist.gov/developers/vulnerabilities) only ask
	// for the "+" of a positive zone offset to be sent as "%2B".
	// NOTE(review): unescaping would also undo that "%2B"; this appears to
	// rely on the interval strings carrying no zone offset. Confirm.
	rawQuery, err := url.QueryUnescape(params.Encode())
	if err != nil {
		return "", xerrors.Errorf("unable to decode query params: %w", err)
	}
	parsed.RawQuery = rawQuery

	return parsed.String(), nil
}