Refactor files

This commit is contained in:
Joakim Hellsén 2024-02-05 00:09:21 +01:00
commit f2f1a08687
6 changed files with 564 additions and 532 deletions

320
main.go
View file

@ -1,16 +1,10 @@
package main package main
import ( import (
"bufio"
"errors"
"fmt" "fmt"
"html/template" "html/template"
"log" "log"
"net"
"net/http" "net/http"
"net/url"
"strings"
"time"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
@ -48,12 +42,15 @@ func main() {
// Create a new router // Create a new router
r := chi.NewRouter() r := chi.NewRouter()
// Middleware
r.Use(middleware.RealIP) r.Use(middleware.RealIP)
r.Use(middleware.Logger) r.Use(middleware.Logger)
r.Use(middleware.Recoverer) r.Use(middleware.Recoverer)
r.Use(middleware.Compress(5)) r.Use(middleware.Compress(5))
r.Use(middleware.Heartbeat("/ping")) r.Use(middleware.Heartbeat("/ping"))
// Routes
r.Get("/", IndexHandler) r.Get("/", IndexHandler)
r.Get("/api", ApiHandler) r.Get("/api", ApiHandler)
r.Get("/donate", DonateHandler) r.Get("/donate", DonateHandler)
@ -62,8 +59,10 @@ func main() {
r.Get("/terms", TermsHandler) r.Get("/terms", TermsHandler)
r.Post("/add", AddFeedHandler) r.Post("/add", AddFeedHandler)
// Static files
r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
// 404 and 405 handlers
r.NotFound(NotFoundHandler) r.NotFound(NotFoundHandler)
r.MethodNotAllowed(MethodNotAllowedHandler) r.MethodNotAllowed(MethodNotAllowedHandler)
@ -71,86 +70,6 @@ func main() {
http.ListenAndServe("127.0.0.1:8000", r) http.ListenAndServe("127.0.0.1:8000", r)
} }
func scrapeBadURLs() {
// TODO: We should only scrape the bad URLs if the file has been updated
// TODO: Use brotli compression https://gitlab.com/malware-filter/urlhaus-filter#compressed-version
filterListURLs := []string{
"https://malware-filter.gitlab.io/malware-filter/phishing-filter-dnscrypt-blocked-names.txt",
"https://malware-filter.gitlab.io/malware-filter/urlhaus-filter-dnscrypt-blocked-names-online.txt",
}
// Scrape the bad URLs
badURLs := []BadURLs{}
for _, url := range filterListURLs {
// Check if we have scraped the bad URLs in the last 24 hours
var meta BadURLsMeta
db.Where("url = ?", url).First(&meta)
if time.Since(meta.LastScraped).Hours() < 24 {
log.Printf("%s was last scraped %.1f hours ago\n", url, time.Since(meta.LastScraped).Hours())
continue
}
// Create the meta if it doesn't exist
if meta.ID == 0 {
meta = BadURLsMeta{URL: url}
db.Create(&meta)
}
// Update the last scraped time
db.Model(&meta).Update("last_scraped", time.Now())
// Get the filter list
resp, err := http.Get(url)
if err != nil {
log.Println("Failed to get filter list:", err)
continue
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "#") {
log.Println("Comment:", line)
continue
}
// Skip the URL if it already exists in the database
var count int64
db.Model(&BadURLs{}).Where("url = ?", line).Count(&count)
if count > 0 {
log.Println("URL already exists:", line)
continue
}
// Add the bad URL to the list
badURLs = append(badURLs, BadURLs{URL: line, Active: true})
}
if err := scanner.Err(); err != nil {
log.Println("Failed to scan filter list:", err)
}
}
if len(badURLs) == 0 {
log.Println("No new URLs found in", len(filterListURLs), "filter lists")
return
}
// Log how many bad URLs we found
log.Println("Found", len(badURLs), "bad URLs")
// Mark all the bad URLs as inactive if we have any in the database
var count int64
db.Model(&BadURLs{}).Count(&count)
if count > 0 {
db.Model(&BadURLs{}).Update("active", false)
}
// Save the bad URLs to the database
db.Create(&badURLs)
}
func renderPage(w http.ResponseWriter, title, description, keywords, author, url, templateName string) { func renderPage(w http.ResponseWriter, title, description, keywords, author, url, templateName string) {
data := TemplateData{ data := TemplateData{
Title: title, Title: title,
@ -169,232 +88,3 @@ func renderPage(w http.ResponseWriter, title, description, keywords, author, url
} }
t.ExecuteTemplate(w, "base", data) t.ExecuteTemplate(w, "base", data)
} }
func NotFoundHandler(w http.ResponseWriter, r *http.Request) {
data := TemplateData{
Request: r,
}
data.GetDatabaseSizeAndFeedCount()
t, err := template.ParseFiles("templates/base.tmpl", "templates/404.tmpl")
if err != nil {
http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNotFound)
t.ExecuteTemplate(w, "base", data)
}
func MethodNotAllowedHandler(w http.ResponseWriter, r *http.Request) {
data := TemplateData{
Request: r,
}
data.GetDatabaseSizeAndFeedCount()
t, err := template.ParseFiles("templates/base.tmpl", "templates/405.tmpl")
if err != nil {
http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusMethodNotAllowed)
t.ExecuteTemplate(w, "base", data)
}
func IndexHandler(w http.ResponseWriter, _ *http.Request) {
renderPage(w, "FeedVault", "FeedVault - A feed archive", "RSS, Atom, Feed, Archive", "TheLovinator", "http://localhost:8000/", "index")
}
func ApiHandler(w http.ResponseWriter, _ *http.Request) {
renderPage(w, "API", "API Page", "api, page", "TheLovinator", "http://localhost:8000/api", "api")
}
func AboutHandler(w http.ResponseWriter, _ *http.Request) {
renderPage(w, "About", "About Page", "about, page", "TheLovinator", "http://localhost:8000/about", "about")
}
func DonateHandler(w http.ResponseWriter, _ *http.Request) {
renderPage(w, "Donate", "Donate Page", "donate, page", "TheLovinator", "http://localhost:8000/donate", "donate")
}
func FeedsHandler(w http.ResponseWriter, _ *http.Request) {
renderPage(w, "Feeds", "Feeds Page", "feeds, page", "TheLovinator", "http://localhost:8000/feeds", "feeds")
}
func PrivacyHandler(w http.ResponseWriter, _ *http.Request) {
renderPage(w, "Privacy", "Privacy Page", "privacy, page", "TheLovinator", "http://localhost:8000/privacy", "privacy")
}
func TermsHandler(w http.ResponseWriter, _ *http.Request) {
renderPage(w, "Terms", "Terms and Conditions Page", "terms, page", "TheLovinator", "http://localhost:8000/terms", "terms")
}
// Run some simple validation on the URL
func validateURL(feed_url string) error {
// Check if URL starts with http or https
if !strings.HasPrefix(feed_url, "http://") && !strings.HasPrefix(feed_url, "https://") {
return errors.New("URL must start with http:// or https://")
}
// Parse a url into a URL structure
u, err := url.Parse(feed_url)
if err != nil {
return errors.New("failed to parse URL")
}
// Get the domain from the URL
domain := u.Hostname()
domain = strings.TrimSpace(domain)
if domain == "" {
return errors.New("URL does not contain a domain")
}
// Don't allow IP address URLs
ip := net.ParseIP(domain)
if ip != nil {
return errors.New("IP address URLs are not allowed")
}
// Don't allow local URLs (e.g. router URLs)
// Taken from https://github.com/uBlockOrigin/uAssets/blob/master/filters/lan-block.txt
// https://github.com/gwarser/filter-lists
localURLs := []string{
"[::]",
"[::1]",
"airbox.home",
"airport",
"arcor.easybox",
"aterm.me",
"bthomehub.home",
"bthub.home",
"congstar.box",
"connect.box",
"console.gl-inet.com",
"easy.box",
"etxr",
"fire.walla",
"fritz.box",
"fritz.nas",
"fritz.repeater",
"giga.cube",
"hi.link",
"hitronhub.home",
"home.arpa",
"homerouter.cpe",
"host.docker.internal",
"huaweimobilewifi.com",
"localbattle.net",
"localhost",
"mobile.hotspot",
"myfritz.box",
"ntt.setup",
"pi.hole",
"plex.direct",
"repeater.asus.com",
"router.asus.com",
"routerlogin.com",
"routerlogin.net",
"samsung.router",
"speedport.ip",
"steamloopback.host",
"tplinkap.net",
"tplinkeap.net",
"tplinkmodem.net",
"tplinkplclogin.net",
"tplinkrepeater.net",
"tplinkwifi.net",
"web.setup.home",
"web.setup",
}
for _, localURL := range localURLs {
if strings.Contains(domain, localURL) {
return errors.New("local URLs are not allowed")
}
}
// Check if the domain is in BadURLs
var count int64
db.Model(&BadURLs{}).Where("url = ?", domain).Count(&count)
if count > 0 {
return errors.New("URL is in the bad URLs list")
}
// Don't allow URLs that end with .local
if strings.HasSuffix(domain, ".local") {
return errors.New("URLs ending with .local are not allowed")
}
// Don't allow URLs that end with .onion
if strings.HasSuffix(domain, ".onion") {
return errors.New("URLs ending with .onion are not allowed")
}
// Don't allow URLs that end with .home.arpa
if strings.HasSuffix(domain, ".home.arpa") {
return errors.New("URLs ending with .home.arpa are not allowed")
}
// Don't allow URLs that end with .internal
// Docker uses host.docker.internal
if strings.HasSuffix(domain, ".internal") {
return errors.New("URLs ending with .internal are not allowed")
}
// Don't allow URLs that end with .localdomain
if strings.HasSuffix(domain, ".localdomain") {
return errors.New("URLs ending with .localdomain are not allowed")
}
// Check if the domain is resolvable
_, err = net.LookupIP(domain)
if err != nil {
return errors.New("failed to resolve domain")
}
// Check if the URL is reachable
_, err = http.Get(feed_url)
if err != nil {
return errors.New("failed to reach URL")
}
return nil
}
func AddFeedHandler(w http.ResponseWriter, r *http.Request) {
var parseErrors []ParseResult
// Parse the form and get the URLs
r.ParseForm()
urls := r.Form.Get("urls")
if urls == "" {
http.Error(w, "No URLs provided", http.StatusBadRequest)
return
}
for _, feed_url := range strings.Split(urls, "\n") {
// TODO: Try to upgrade to https if http is provided
// Validate the URL
err := validateURL(feed_url)
if err != nil {
parseErrors = append(parseErrors, ParseResult{FeedURL: feed_url, Msg: err.Error(), IsError: true})
continue
}
// "Add" the feed to the database
log.Println("Adding feed:", feed_url)
}
// Render the index page with the parse errors
data := TemplateData{
Title: "FeedVault",
Description: "FeedVault - A feed archive",
Keywords: "RSS, Atom, Feed, Archive",
ParseErrors: parseErrors,
}
data.GetDatabaseSizeAndFeedCount()
t, err := template.ParseFiles("templates/base.tmpl", "templates/index.tmpl")
if err != nil {
http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError)
return
}
t.ExecuteTemplate(w, "base", data)
}

View file

@ -1,8 +1,6 @@
package main package main
import ( import (
"net/http"
"net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -200,213 +198,3 @@ func TestLocalURLs(t *testing.T) {
assert.Equal(t, "local URLs are not allowed", err.Error()) assert.Equal(t, "local URLs are not allowed", err.Error())
} }
} }
func TestIndexHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(IndexHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "Input the URLs of the feeds you wish to archive below. You can add as many as needed, and access them through the website or API. Alternatively, include links to .opml files, and the feeds within will be archived."
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestApiHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/api", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(ApiHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "<p>Here be dragons.</p>"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestTermsHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/terms", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(TermsHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "Terms of Service"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestPrivacyHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/privacy", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(PrivacyHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "Privacy Policy"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestNotFoundHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/notfound", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(NotFoundHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusNotFound {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusNotFound)
}
// Check the response contains the expected string.
shouldContain := "<h2>404 - Page not found</h2>"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestMethodNotAllowedHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/api", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(MethodNotAllowedHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusMethodNotAllowed {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusMethodNotAllowed)
}
// Check the response contains the expected string.
shouldContain := "<h2>405 - Method Not Allowed</h2>"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestDonateHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/donate", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(DonateHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "tl;dr: <a href=\"https://github.com/sponsors/TheLovinator1\">GitHub Sponsors</a>"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}

View file

@ -17,18 +17,23 @@ func GetDBSize() string {
// Get the file size in bytes // Get the file size in bytes
fileSize := fileInfo.Size() fileSize := fileInfo.Size()
// Convert to human readable size and append the unit (KB, MB, GB) // Convert to human readable size and append the unit (KiB, MiB, GiB, TiB)
var size float64 var size float64
if fileSize < 1024*1024 { if fileSize < 1024*1024 {
size = float64(fileSize) / 1024 size = float64(fileSize) / 1024
return fmt.Sprintf("%.2f KB", size) return fmt.Sprintf("%.2f KiB", size)
} }
if fileSize < 1024*1024*1024 { if fileSize < 1024*1024*1024 {
size = float64(fileSize) / (1024 * 1024) size = float64(fileSize) / (1024 * 1024)
return fmt.Sprintf("%.2f MB", size) return fmt.Sprintf("%.2f MiB", size)
} }
if fileSize < 1024*1024*1024*1024 {
size = float64(fileSize) / (1024 * 1024 * 1024) size = float64(fileSize) / (1024 * 1024 * 1024)
return fmt.Sprintf("%.2f GB", size) return fmt.Sprintf("%.2f GiB", size)
}
size = float64(fileSize) / (1024 * 1024 * 1024 * 1024)
return fmt.Sprintf("%.2f TiB", size)
} }

223
validate.go Normal file
View file

@@ -0,0 +1,223 @@
package main
import (
"bufio"
"errors"
"log"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// scrapeBadURLs downloads the configured malware/phishing filter lists and
// stores any domains we have not seen before in the database. Each list is
// fetched at most once per 24 hours (tracked per-list in BadURLsMeta). When
// new entries are found, every existing BadURLs row is first marked inactive
// so that only the latest scrape remains "active".
func scrapeBadURLs() {
	// TODO: We should only scrape the bad URLs if the file has been updated
	// TODO: Use brotli compression https://gitlab.com/malware-filter/urlhaus-filter#compressed-version
	filterListURLs := []string{
		"https://malware-filter.gitlab.io/malware-filter/phishing-filter-dnscrypt-blocked-names.txt",
		"https://malware-filter.gitlab.io/malware-filter/urlhaus-filter-dnscrypt-blocked-names-online.txt",
	}

	// Collect new bad URLs from every list. The per-list work lives in a
	// helper so each HTTP response body is closed as soon as that list is
	// processed — a defer inside this loop would keep every body open until
	// the whole function returned.
	var badURLs []BadURLs
	for _, listURL := range filterListURLs {
		badURLs = scrapeFilterList(listURL, badURLs)
	}

	if len(badURLs) == 0 {
		log.Println("No new URLs found in", len(filterListURLs), "filter lists")
		return
	}

	// Log how many bad URLs we found
	log.Println("Found", len(badURLs), "bad URLs")

	// Mark all the bad URLs as inactive if we have any in the database
	var count int64
	db.Model(&BadURLs{}).Count(&count)
	if count > 0 {
		db.Model(&BadURLs{}).Update("active", false)
	}

	// Save the bad URLs to the database
	db.Create(&badURLs)
}

// scrapeFilterList fetches one filter list (skipping it when it was already
// scraped within the last 24 hours), appends every listed domain that is not
// yet in the database to badURLs, and returns the possibly-grown slice.
// The loop variable is named listURL to avoid shadowing the net/url package
// imported by this file.
func scrapeFilterList(listURL string, badURLs []BadURLs) []BadURLs {
	// Check if we have scraped this list in the last 24 hours.
	var meta BadURLsMeta
	db.Where("url = ?", listURL).First(&meta)
	if time.Since(meta.LastScraped).Hours() < 24 {
		log.Printf("%s was last scraped %.1f hours ago\n", listURL, time.Since(meta.LastScraped).Hours())
		return badURLs
	}

	// Create the meta row if it doesn't exist.
	if meta.ID == 0 {
		meta = BadURLsMeta{URL: listURL}
		db.Create(&meta)
	}

	// Update the last scraped time
	db.Model(&meta).Update("last_scraped", time.Now())

	// Get the filter list.
	// NOTE(review): http.Get uses the default client, which has no timeout —
	// consider an http.Client with a Timeout here.
	resp, err := http.Get(listURL)
	if err != nil {
		log.Println("Failed to get filter list:", err)
		return badURLs
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())

		// Skip blank lines — the previous version appended them as bad URLs.
		if line == "" {
			continue
		}
		if strings.HasPrefix(line, "#") {
			log.Println("Comment:", line)
			continue
		}

		// Skip the URL if it already exists in the database
		var count int64
		db.Model(&BadURLs{}).Where("url = ?", line).Count(&count)
		if count > 0 {
			log.Println("URL already exists:", line)
			continue
		}

		// Add the bad URL to the list
		badURLs = append(badURLs, BadURLs{URL: line, Active: true})
	}
	if err := scanner.Err(); err != nil {
		log.Println("Failed to scan filter list:", err)
	}

	return badURLs
}
// validateURL runs some simple validation on a feed URL before it is accepted:
// it rejects non-HTTP(S) schemes, unparseable URLs, bare IP addresses,
// well-known LAN/router hostnames, domains present in the scraped BadURLs
// table, reserved suffixes (.local, .onion, .home.arpa, .internal,
// .localdomain), and URLs that cannot be resolved or reached.
// It returns nil when the URL passes every check.
func validateURL(feedURL string) error {
	// Check if URL starts with http or https
	if !strings.HasPrefix(feedURL, "http://") && !strings.HasPrefix(feedURL, "https://") {
		return errors.New("URL must start with http:// or https://")
	}

	// Parse a url into a URL structure
	u, err := url.Parse(feedURL)
	if err != nil {
		return errors.New("failed to parse URL")
	}

	// Get the domain from the URL
	domain := strings.TrimSpace(u.Hostname())
	if domain == "" {
		return errors.New("URL does not contain a domain")
	}

	// Don't allow IP address URLs
	if net.ParseIP(domain) != nil {
		return errors.New("IP address URLs are not allowed")
	}

	// Don't allow local URLs (e.g. router URLs)
	// Taken from https://github.com/uBlockOrigin/uAssets/blob/master/filters/lan-block.txt
	// https://github.com/gwarser/filter-lists
	localURLs := []string{
		"[::]",
		"[::1]",
		"airbox.home",
		"airport",
		"arcor.easybox",
		"aterm.me",
		"bthomehub.home",
		"bthub.home",
		"congstar.box",
		"connect.box",
		"console.gl-inet.com",
		"easy.box",
		"etxr",
		"fire.walla",
		"fritz.box",
		"fritz.nas",
		"fritz.repeater",
		"giga.cube",
		"hi.link",
		"hitronhub.home",
		"home.arpa",
		"homerouter.cpe",
		"host.docker.internal",
		"huaweimobilewifi.com",
		"localbattle.net",
		"localhost",
		"mobile.hotspot",
		"myfritz.box",
		"ntt.setup",
		"pi.hole",
		"plex.direct",
		"repeater.asus.com",
		"router.asus.com",
		"routerlogin.com",
		"routerlogin.net",
		"samsung.router",
		"speedport.ip",
		"steamloopback.host",
		"tplinkap.net",
		"tplinkeap.net",
		"tplinkmodem.net",
		"tplinkplclogin.net",
		"tplinkrepeater.net",
		"tplinkwifi.net",
		"web.setup.home",
		"web.setup",
	}
	for _, localURL := range localURLs {
		if strings.Contains(domain, localURL) {
			return errors.New("local URLs are not allowed")
		}
	}

	// Check if the domain is in BadURLs
	var count int64
	db.Model(&BadURLs{}).Where("url = ?", domain).Count(&count)
	if count > 0 {
		return errors.New("URL is in the bad URLs list")
	}

	// Don't allow reserved/private suffixes. .internal is used by Docker
	// (host.docker.internal); .home.arpa is the RFC 8375 home-network zone.
	// Checked in the same order as before so error messages are unchanged.
	for _, suffix := range []string{".local", ".onion", ".home.arpa", ".internal", ".localdomain"} {
		if strings.HasSuffix(domain, suffix) {
			return errors.New("URLs ending with " + suffix + " are not allowed")
		}
	}

	// Check if the domain is resolvable
	if _, err := net.LookupIP(domain); err != nil {
		return errors.New("failed to resolve domain")
	}

	// Check if the URL is reachable. Close the response body so the
	// underlying connection can be reused (it was previously leaked).
	resp, err := http.Get(feedURL)
	if err != nil {
		return errors.New("failed to reach URL")
	}
	resp.Body.Close()

	return nil
}

107
views.go Normal file
View file

@@ -0,0 +1,107 @@
package main
import (
"fmt"
"html/template"
"log"
"net/http"
"strings"
)
// NotFoundHandler renders the 404 page using the shared base template.
// It is registered as the router's catch-all for unknown paths.
func NotFoundHandler(w http.ResponseWriter, r *http.Request) {
	data := TemplateData{
		Request: r,
	}
	data.GetDatabaseSizeAndFeedCount()

	t, err := template.ParseFiles("templates/base.tmpl", "templates/404.tmpl")
	if err != nil {
		http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusNotFound)
	// The status line is already written, so a template failure at this
	// point can only be logged, not turned into an error response.
	if err := t.ExecuteTemplate(w, "base", data); err != nil {
		log.Println("Failed to execute 404 template:", err)
	}
}
// MethodNotAllowedHandler renders the 405 page using the shared base template.
// It is registered as the router's handler for unsupported HTTP methods.
func MethodNotAllowedHandler(w http.ResponseWriter, r *http.Request) {
	data := TemplateData{
		Request: r,
	}
	data.GetDatabaseSizeAndFeedCount()

	t, err := template.ParseFiles("templates/base.tmpl", "templates/405.tmpl")
	if err != nil {
		http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusMethodNotAllowed)
	// The status line is already written, so a template failure at this
	// point can only be logged, not turned into an error response.
	if err := t.ExecuteTemplate(w, "base", data); err != nil {
		log.Println("Failed to execute 405 template:", err)
	}
}
// The handlers below are thin wrappers that render a static page via
// renderPage with per-page metadata (title, description, keywords, author,
// canonical URL, template name).
// NOTE(review): the canonical URLs are hard-coded to http://localhost:8000 —
// derive them from configuration or the request host before deploying.

// IndexHandler serves GET / with the "index" template.
func IndexHandler(w http.ResponseWriter, _ *http.Request) {
	renderPage(w, "FeedVault", "FeedVault - A feed archive", "RSS, Atom, Feed, Archive", "TheLovinator", "http://localhost:8000/", "index")
}

// ApiHandler serves GET /api with the "api" template.
func ApiHandler(w http.ResponseWriter, _ *http.Request) {
	renderPage(w, "API", "API Page", "api, page", "TheLovinator", "http://localhost:8000/api", "api")
}

// AboutHandler serves the about page with the "about" template.
func AboutHandler(w http.ResponseWriter, _ *http.Request) {
	renderPage(w, "About", "About Page", "about, page", "TheLovinator", "http://localhost:8000/about", "about")
}

// DonateHandler serves GET /donate with the "donate" template.
func DonateHandler(w http.ResponseWriter, _ *http.Request) {
	renderPage(w, "Donate", "Donate Page", "donate, page", "TheLovinator", "http://localhost:8000/donate", "donate")
}

// FeedsHandler serves GET /feeds with the "feeds" template.
func FeedsHandler(w http.ResponseWriter, _ *http.Request) {
	renderPage(w, "Feeds", "Feeds Page", "feeds, page", "TheLovinator", "http://localhost:8000/feeds", "feeds")
}

// PrivacyHandler serves GET /privacy with the "privacy" template.
func PrivacyHandler(w http.ResponseWriter, _ *http.Request) {
	renderPage(w, "Privacy", "Privacy Page", "privacy, page", "TheLovinator", "http://localhost:8000/privacy", "privacy")
}

// TermsHandler serves GET /terms with the "terms" template.
func TermsHandler(w http.ResponseWriter, _ *http.Request) {
	renderPage(w, "Terms", "Terms and Conditions Page", "terms, page", "TheLovinator", "http://localhost:8000/terms", "terms")
}
// AddFeedHandler handles POST /add. It reads newline-separated feed URLs from
// the "urls" form field, validates each one with validateURL, and re-renders
// the index page with any validation errors collected in ParseErrors.
func AddFeedHandler(w http.ResponseWriter, r *http.Request) {
	var parseErrors []ParseResult

	// Parse the form and get the URLs. The parse error was previously ignored.
	if err := r.ParseForm(); err != nil {
		http.Error(w, "Failed to parse form", http.StatusBadRequest)
		return
	}
	urls := r.Form.Get("urls")
	if urls == "" {
		http.Error(w, "No URLs provided", http.StatusBadRequest)
		return
	}

	for _, feedURL := range strings.Split(urls, "\n") {
		// Browsers submit textarea content with \r\n line endings; trim the
		// stray \r (and any surrounding whitespace) so validation sees a
		// clean URL, and skip blank lines entirely.
		feedURL = strings.TrimSpace(feedURL)
		if feedURL == "" {
			continue
		}

		// TODO: Try to upgrade to https if http is provided

		// Validate the URL
		if err := validateURL(feedURL); err != nil {
			parseErrors = append(parseErrors, ParseResult{FeedURL: feedURL, Msg: err.Error(), IsError: true})
			continue
		}

		// "Add" the feed to the database
		log.Println("Adding feed:", feedURL)
	}

	// Render the index page with the parse errors
	data := TemplateData{
		Title:       "FeedVault",
		Description: "FeedVault - A feed archive",
		Keywords:    "RSS, Atom, Feed, Archive",
		ParseErrors: parseErrors,
	}
	data.GetDatabaseSizeAndFeedCount()

	t, err := template.ParseFiles("templates/base.tmpl", "templates/index.tmpl")
	if err != nil {
		http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError)
		return
	}
	if err := t.ExecuteTemplate(w, "base", data); err != nil {
		log.Println("Failed to execute index template:", err)
	}
}

219
views_test.go Normal file
View file

@@ -0,0 +1,219 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIndexHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(IndexHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "Input the URLs of the feeds you wish to archive below. You can add as many as needed, and access them through the website or API. Alternatively, include links to .opml files, and the feeds within will be archived."
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestApiHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/api", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(ApiHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "<p>Here be dragons.</p>"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestTermsHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/terms", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(TermsHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "Terms of Service"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestPrivacyHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/privacy", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(PrivacyHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "Privacy Policy"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestNotFoundHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/notfound", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(NotFoundHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusNotFound {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusNotFound)
}
// Check the response contains the expected string.
shouldContain := "<h2>404 - Page not found</h2>"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestMethodNotAllowedHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/api", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(MethodNotAllowedHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusMethodNotAllowed {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusMethodNotAllowed)
}
// Check the response contains the expected string.
shouldContain := "<h2>405 - Method Not Allowed</h2>"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}
func TestDonateHandler(t *testing.T) {
// Create a request to pass to our handler.
req, err := http.NewRequest("GET", "/donate", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(DonateHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response contains the expected string.
shouldContain := "tl;dr: <a href=\"https://github.com/sponsors/TheLovinator1\">GitHub Sponsors</a>"
body := rr.Body.String()
if !assert.Contains(t, body, shouldContain) {
t.Errorf("handler returned unexpected body: got %v want %v",
body, shouldContain)
}
}