Check if the domain is on a blocklist

This commit is contained in:
Joakim Hellsén 2024-02-04 21:23:21 +01:00
commit 11035584af
4 changed files with 151 additions and 11 deletions

View file

@ -7,6 +7,7 @@
"aterm", "aterm",
"blocklist", "blocklist",
"blocklists", "blocklists",
"brotli",
"bthomehub", "bthomehub",
"bthub", "bthub",
"chartboost", "chartboost",

114
main.go
View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"bufio"
"errors" "errors"
"fmt" "fmt"
"html/template" "html/template"
@ -9,6 +10,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
@ -16,24 +18,32 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
func main() { var db *gorm.DB
log.Println("Starting FeedVault...")
db, err := gorm.Open(sqlite.Open("feedvault.db"), &gorm.Config{}) // Initialize the database
func init() {
var err error
db, err = gorm.Open(sqlite.Open("feedvault.db"), &gorm.Config{})
if err != nil { if err != nil {
panic("Failed to connect to database") panic("Failed to connect to database")
} }
if db == nil {
sqlDB, err := db.DB() panic("db nil")
if err != nil {
panic("Failed to get database connection")
} }
defer sqlDB.Close() log.Println("Connected to database")
// Migrate the schema // Migrate the schema
err = db.AutoMigrate(&Feed{}, &Item{}, &Person{}, &Image{}, &Enclosure{}, &DublinCoreExtension{}, &ITunesFeedExtension{}, &ITunesItemExtension{}, &ITunesCategory{}, &ITunesOwner{}, &Extension{}) err = db.AutoMigrate(&BadURLsMeta{}, &BadURLs{}, &Feed{}, &Item{}, &Person{}, &Image{}, &Enclosure{}, &DublinCoreExtension{}, &ITunesFeedExtension{}, &ITunesItemExtension{}, &ITunesCategory{}, &ITunesOwner{}, &Extension{})
if err != nil { if err != nil {
panic("Failed to migrate the database") panic("Failed to migrate the database")
} }
}
func main() {
log.Println("Starting FeedVault...")
// Scrape the bad URLs in the background
scrapeBadURLs()
// Create a new router // Create a new router
r := chi.NewRouter() r := chi.NewRouter()
@ -59,6 +69,85 @@ func main() {
http.ListenAndServe("127.0.0.1:8000", r) http.ListenAndServe("127.0.0.1:8000", r)
} }
// scrapeBadURLs downloads the configured malware/phishing filter lists and
// stores any new entries in the BadURLs table. Each list is refreshed at
// most once every 24 hours, tracked per list URL in BadURLsMeta. Entries
// from the newest scrape are Active; the previous generation is deactivated.
func scrapeBadURLs() {
	// TODO: We should only scrape the bad URLs if the file has been updated
	filterListURLs := []string{
		"https://malware-filter.gitlab.io/malware-filter/phishing-filter-dnscrypt-blocked-names.txt",
		"https://malware-filter.gitlab.io/malware-filter/urlhaus-filter-dnscrypt-blocked-names-online.txt",
	}

	// One shared client with a timeout so a stalled download cannot hang
	// this goroutine forever (http.Get has no timeout at all).
	client := &http.Client{Timeout: 30 * time.Second}

	var badURLs []BadURLs
	seen := make(map[string]struct{}) // dedupe entries that appear in both lists
	for _, url := range filterListURLs {
		// Skip lists scraped within the last 24 hours.
		var meta BadURLsMeta
		db.Where("url = ?", url).First(&meta)
		if time.Since(meta.LastScraped).Hours() < 24 {
			log.Printf("%s was last scraped %.1f hours ago\n", url, time.Since(meta.LastScraped).Hours())
			continue
		}

		// Create the meta row if it doesn't exist yet.
		if meta.ID == 0 {
			meta = BadURLsMeta{URL: url}
			db.Create(&meta)
		}

		entries, err := scrapeFilterList(client, url)
		if err != nil {
			log.Println("Failed to scrape filter list:", err)
			continue
		}

		// Stamp last_scraped only after a successful download, so a failed
		// fetch is retried on the next run instead of waiting 24 hours.
		db.Model(&meta).Update("last_scraped", time.Now())

		for _, entry := range entries {
			if _, dup := seen[entry.URL]; dup {
				continue
			}
			seen[entry.URL] = struct{}{}
			badURLs = append(badURLs, entry)
		}
	}

	if len(badURLs) == 0 {
		log.Println("No new URLs found in", len(filterListURLs), "filter lists")
		return
	}
	log.Println("Found", len(badURLs), "bad URLs")

	// Deactivate the previous generation before inserting the new one.
	// GORM refuses a global UPDATE with no WHERE clause (ErrMissingWhereClause),
	// so restrict the update to the rows that are currently active.
	var count int64
	db.Model(&BadURLs{}).Count(&count)
	if count > 0 {
		db.Model(&BadURLs{}).Where("active = ?", true).Update("active", false)
	}

	// Save the new bad URLs to the database.
	db.Create(&badURLs)
}

// scrapeFilterList downloads one filter list and returns the entries that
// are not already present in the BadURLs table. Extracted as a helper so
// each response body is closed when the list is done, instead of a defer
// inside the caller's loop holding every body open until function return.
func scrapeFilterList(client *http.Client, url string) ([]BadURLs, error) {
	resp, err := client.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status %s for %s", resp.Status, url)
	}

	var entries []BadURLs
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		// Skip blank lines and comments; logging each one would flood the
		// log on lists with tens of thousands of lines.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		// Skip the URL if it already exists in the database.
		var count int64
		db.Model(&BadURLs{}).Where("url = ?", line).Count(&count)
		if count > 0 {
			continue
		}
		entries = append(entries, BadURLs{URL: line, Active: true})
	}
	if err := scanner.Err(); err != nil {
		// Keep whatever was read before the error; the caller logs it.
		log.Println("Failed to scan filter list:", err)
	}
	return entries, nil
}
func renderPage(w http.ResponseWriter, title, description, keywords, author, url, templateName string) { func renderPage(w http.ResponseWriter, title, description, keywords, author, url, templateName string) {
data := TemplateData{ data := TemplateData{
Title: title, Title: title,
@ -208,6 +297,13 @@ func validateURL(feed_url string) error {
} }
} }
// Check if the domain is in BadURLs
var count int64
db.Model(&BadURLs{}).Where("url = ?", domain).Count(&count)
if count > 0 {
return errors.New("URL is in the bad URLs list")
}
// Don't allow URLs that end with .local // Don't allow URLs that end with .local
if strings.HasSuffix(domain, ".local") { if strings.HasSuffix(domain, ".local") {
return errors.New("URLs ending with .local are not allowed") return errors.New("URLs ending with .local are not allowed")

View file

@ -224,8 +224,39 @@ func TestIndexHandler(t *testing.T) {
// Check the response contains the expected string. // Check the response contains the expected string.
shouldContain := "Input the URLs of the feeds you wish to archive below. You can add as many as needed, and access them through the website or API. Alternatively, include links to .opml files, and the feeds within will be archived." shouldContain := "Input the URLs of the feeds you wish to archive below. You can add as many as needed, and access them through the website or API. Alternatively, include links to .opml files, and the feeds within will be archived."
if rr.Body.String() != shouldContain { body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v", t.Errorf("handler returned unexpected body: got %v want %v",
rr.Body.String(), shouldContain) body, shouldContain)
}
}
func TestApiHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/api", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(ApiHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "<p>Here be dragons.</p>"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
} }
} }

View file

@ -178,3 +178,15 @@ type ParseResult struct {
func (d *TemplateData) GetDatabaseSizeAndFeedCount() { func (d *TemplateData) GetDatabaseSizeAndFeedCount() {
d.DatabaseSize = GetDBSize() d.DatabaseSize = GetDBSize()
} }
// BadURLs is one blocklisted domain scraped from a filter list; validateURL
// rejects feed URLs whose domain matches an entry in this table.
type BadURLs struct {
	gorm.Model
	// URL is one entry (line) from a filter list.
	URL string `json:"url"`
	// Active is set true when the entry is created by a scrape; scrapeBadURLs
	// flips existing rows to false before inserting a new batch.
	Active bool `json:"active"`
}
// BadURLsMeta records, per filter-list URL, when that list was last scraped
// so scrapeBadURLs refreshes each list at most once every 24 hours.
type BadURLsMeta struct {
	gorm.Model
	// URL of the filter list itself (not an individual blocked domain).
	URL string `json:"url"`
	// LastScraped is when this list was last fetched.
	LastScraped time.Time `json:"lastScraped"`
}