diff --git a/.vscode/settings.json b/.vscode/settings.json index 0d84518..875b190 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,6 +7,7 @@ "aterm", "blocklist", "blocklists", + "brotli", "bthomehub", "bthub", "chartboost", diff --git a/main.go b/main.go index c55a87e..9a3315f 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "errors" "fmt" "html/template" @@ -9,6 +10,7 @@ import ( "net/http" "net/url" "strings" + "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -16,24 +18,32 @@ import ( "gorm.io/gorm" ) -func main() { - log.Println("Starting FeedVault...") - db, err := gorm.Open(sqlite.Open("feedvault.db"), &gorm.Config{}) +var db *gorm.DB + +// Initialize the database +func init() { + var err error + db, err = gorm.Open(sqlite.Open("feedvault.db"), &gorm.Config{}) if err != nil { panic("Failed to connect to database") } - - sqlDB, err := db.DB() - if err != nil { - panic("Failed to get database connection") + if db == nil { + panic("db nil") } - defer sqlDB.Close() + log.Println("Connected to database") // Migrate the schema - err = db.AutoMigrate(&Feed{}, &Item{}, &Person{}, &Image{}, &Enclosure{}, &DublinCoreExtension{}, &ITunesFeedExtension{}, &ITunesItemExtension{}, &ITunesCategory{}, &ITunesOwner{}, &Extension{}) + err = db.AutoMigrate(&BadURLsMeta{}, &BadURLs{}, &Feed{}, &Item{}, &Person{}, &Image{}, &Enclosure{}, &DublinCoreExtension{}, &ITunesFeedExtension{}, &ITunesItemExtension{}, &ITunesCategory{}, &ITunesOwner{}, &Extension{}) if err != nil { panic("Failed to migrate the database") } +} + +func main() { + log.Println("Starting FeedVault...") + + // Scrape the bad URLs in the background + scrapeBadURLs() // Create a new router r := chi.NewRouter() @@ -59,6 +69,85 @@ func main() { http.ListenAndServe("127.0.0.1:8000", r) } +func scrapeBadURLs() { + // TODO: We should only scrape the bad URLs if the file has been updated + filterListURLs := []string{ + 
"https://malware-filter.gitlab.io/malware-filter/phishing-filter-dnscrypt-blocked-names.txt", + "https://malware-filter.gitlab.io/malware-filter/urlhaus-filter-dnscrypt-blocked-names-online.txt", + } + + // Scrape the bad URLs + badURLs := []BadURLs{} + for _, url := range filterListURLs { + // Check if we have scraped the bad URLs in the last 24 hours + var meta BadURLsMeta + db.Where("url = ?", url).First(&meta) + if time.Since(meta.LastScraped).Hours() < 24 { + log.Printf("%s was last scraped %.1f hours ago\n", url, time.Since(meta.LastScraped).Hours()) + continue + } + + // Create the meta if it doesn't exist + if meta.ID == 0 { + meta = BadURLsMeta{URL: url} + db.Create(&meta) + } + + // Update the last scraped time + db.Model(&meta).Update("last_scraped", time.Now()) + + // Get the filter list + resp, err := http.Get(url) + if err != nil { + log.Println("Failed to get filter list:", err) + continue + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if line == "" || strings.HasPrefix(line, "#") { + log.Println("Comment:", line) + continue + } + + // Skip the URL if it already exists in the database + var count int64 + db.Model(&BadURLs{}).Where("url = ?", line).Count(&count) + if count > 0 { + log.Println("URL already exists:", line) + continue + } + + // Add the bad URL to the list + badURLs = append(badURLs, BadURLs{URL: line, Active: true}) + } + + if err := scanner.Err(); err != nil { + log.Println("Failed to scan filter list:", err) + } + } + + if len(badURLs) == 0 { + log.Println("No new URLs found in", len(filterListURLs), "filter lists") + return + } + + // Log how many bad URLs we found + log.Println("Found", len(badURLs), "bad URLs") + + // Mark all the bad URLs as inactive if we have any in the database + var count int64 + db.Model(&BadURLs{}).Count(&count) + if count > 0 { + db.Model(&BadURLs{}).Where("active = ?", true).Update("active", false) + } + + // Save the bad URLs to the database in batches (SQLite limits bind variables per statement) + db.CreateInBatches(&badURLs, 500) +} + 
func renderPage(w http.ResponseWriter, title, description, keywords, author, url, templateName string) { data := TemplateData{ Title: title, @@ -208,6 +297,13 @@ func validateURL(feed_url string) error { } } + // Check if the domain is an active entry in BadURLs + var count int64 + db.Model(&BadURLs{}).Where("url = ? AND active = ?", domain, true).Count(&count) + if count > 0 { + return errors.New("URL is in the bad URLs list") + } + // Don't allow URLs that end with .local if strings.HasSuffix(domain, ".local") { return errors.New("URLs ending with .local are not allowed") diff --git a/main_test.go b/main_test.go index ba94003..4e51aa2 100644 --- a/main_test.go +++ b/main_test.go @@ -224,8 +224,39 @@ func TestIndexHandler(t *testing.T) { // Check the response contains the expected string. shouldContain := "Input the URLs of the feeds you wish to archive below. You can add as many as needed, and access them through the website or API. Alternatively, include links to .opml files, and the feeds within will be archived." - if rr.Body.String() != shouldContain { + body := rr.Body.String() + if !assert.Contains(t, body, shouldContain) { t.Errorf("handler returned unexpected body: got %v want %v", - rr.Body.String(), shouldContain) + body, shouldContain) + } +} + +func TestApiHandler(t *testing.T) { + // Create a request to pass to our handler. + req, err := http.NewRequest("GET", "/api", nil) + if err != nil { + t.Fatal(err) + } + + // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. + rr := httptest.NewRecorder() + handler := http.HandlerFunc(ApiHandler) + + // Our handlers satisfy http.Handler, so we can call their ServeHTTP method + // directly and pass in our Request and ResponseRecorder. + handler.ServeHTTP(rr, req) + + // Check the status code is what we expect. + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + // Check the response contains the expected string. 
+ shouldContain := `
+
+Here be dragons.
+
+` + body := rr.Body.String() + if !assert.Contains(t, body, shouldContain) { + t.Errorf("handler returned unexpected body: got %v want %v", + body, shouldContain) + } } diff --git a/models.go b/models.go index dc1db16..e8dc6dc 100644 --- a/models.go +++ b/models.go @@ -178,3 +178,15 @@ type ParseResult struct { func (d *TemplateData) GetDatabaseSizeAndFeedCount() { d.DatabaseSize = GetDBSize() } + +type BadURLs struct { + gorm.Model + URL string `json:"url"` + Active bool `json:"active"` +} + +type BadURLsMeta struct { + gorm.Model + URL string `json:"url"` + LastScraped time.Time `json:"lastScraped"` +}