diff --git a/main.go b/main.go index c8d1fb8..8e29185 100644 --- a/main.go +++ b/main.go @@ -1,16 +1,10 @@ package main import ( - "bufio" - "errors" "fmt" "html/template" "log" - "net" "net/http" - "net/url" - "strings" - "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -48,12 +42,15 @@ func main() { // Create a new router r := chi.NewRouter() + + // Middleware r.Use(middleware.RealIP) r.Use(middleware.Logger) r.Use(middleware.Recoverer) r.Use(middleware.Compress(5)) r.Use(middleware.Heartbeat("/ping")) + // Routes r.Get("/", IndexHandler) r.Get("/api", ApiHandler) r.Get("/donate", DonateHandler) @@ -62,8 +59,10 @@ func main() { r.Get("/terms", TermsHandler) r.Post("/add", AddFeedHandler) + // Static files r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) + // 404 and 405 handlers r.NotFound(NotFoundHandler) r.MethodNotAllowed(MethodNotAllowedHandler) @@ -71,86 +70,6 @@ func main() { http.ListenAndServe("127.0.0.1:8000", r) } -func scrapeBadURLs() { - // TODO: We should only scrape the bad URLs if the file has been updated - // TODO: Use brotli compression https://gitlab.com/malware-filter/urlhaus-filter#compressed-version - filterListURLs := []string{ - "https://malware-filter.gitlab.io/malware-filter/phishing-filter-dnscrypt-blocked-names.txt", - "https://malware-filter.gitlab.io/malware-filter/urlhaus-filter-dnscrypt-blocked-names-online.txt", - } - - // Scrape the bad URLs - badURLs := []BadURLs{} - for _, url := range filterListURLs { - // Check if we have scraped the bad URLs in the last 24 hours - var meta BadURLsMeta - db.Where("url = ?", url).First(&meta) - if time.Since(meta.LastScraped).Hours() < 24 { - log.Printf("%s was last scraped %.1f hours ago\n", url, time.Since(meta.LastScraped).Hours()) - continue - } - - // Create the meta if it doesn't exist - if meta.ID == 0 { - meta = BadURLsMeta{URL: url} - db.Create(&meta) - } - - // Update the last scraped time - db.Model(&meta).Update("last_scraped", time.Now()) - - // Get the filter list - resp, err := http.Get(url) - if err != nil { - log.Println("Failed to get filter list:", err) - continue - } - defer resp.Body.Close() - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "#") { - log.Println("Comment:", line) - continue - } - - // Skip the URL if it already exists in the database - var count int64 - db.Model(&BadURLs{}).Where("url = ?", line).Count(&count) - if count > 0 { - log.Println("URL already exists:", line) - continue - } - - // Add the bad URL to the list - badURLs = append(badURLs, BadURLs{URL: line, Active: true}) - } - - if err := scanner.Err(); err != nil { - log.Println("Failed to scan filter list:", err) - } - } - - if len(badURLs) == 0 { - log.Println("No new URLs found in", len(filterListURLs), "filter lists") - return - } - - // Log how many bad URLs we found - log.Println("Found", len(badURLs), "bad URLs") - - // Mark all the bad URLs as inactive if we have any in the database - var count int64 - db.Model(&BadURLs{}).Count(&count) - if count > 0 { - db.Model(&BadURLs{}).Update("active", false) - } - - // Save the bad URLs to the database - db.Create(&badURLs) -} - func renderPage(w http.ResponseWriter, title, description, keywords, author, url, templateName string) { data := TemplateData{ Title: title, @@ -169,232 +88,3 @@ func renderPage(w http.ResponseWriter, title, description, keywords, author, url } t.ExecuteTemplate(w, "base", data) } - -func NotFoundHandler(w http.ResponseWriter, r *http.Request) { - data := TemplateData{ - Request: r, - } - data.GetDatabaseSizeAndFeedCount() - t, err := template.ParseFiles("templates/base.tmpl", "templates/404.tmpl") - if err != nil { - http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusNotFound) - t.ExecuteTemplate(w, "base", data) -} - -func MethodNotAllowedHandler(w http.ResponseWriter, r *http.Request) { - data := TemplateData{ - Request: r, - } - data.GetDatabaseSizeAndFeedCount() - t, err := template.ParseFiles("templates/base.tmpl", "templates/405.tmpl") - if err != nil { - http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusMethodNotAllowed) - t.ExecuteTemplate(w, "base", data) -} - -func IndexHandler(w http.ResponseWriter, _ *http.Request) { - renderPage(w, "FeedVault", "FeedVault - A feed archive", "RSS, Atom, Feed, Archive", "TheLovinator", "http://localhost:8000/", "index") -} - -func ApiHandler(w http.ResponseWriter, _ *http.Request) { - renderPage(w, "API", "API Page", "api, page", "TheLovinator", "http://localhost:8000/api", "api") -} -func AboutHandler(w http.ResponseWriter, _ *http.Request) { - renderPage(w, "About", "About Page", "about, page", "TheLovinator", "http://localhost:8000/about", "about") -} - -func DonateHandler(w http.ResponseWriter, _ *http.Request) { - renderPage(w, "Donate", "Donate Page", "donate, page", "TheLovinator", "http://localhost:8000/donate", "donate") -} - -func FeedsHandler(w http.ResponseWriter, _ *http.Request) { - renderPage(w, "Feeds", "Feeds Page", "feeds, page", "TheLovinator", "http://localhost:8000/feeds", "feeds") -} - -func PrivacyHandler(w http.ResponseWriter, _ *http.Request) { - renderPage(w, "Privacy", "Privacy Page", "privacy, page", "TheLovinator", "http://localhost:8000/privacy", "privacy") -} - -func TermsHandler(w http.ResponseWriter, _ *http.Request) { - renderPage(w, "Terms", "Terms and Conditions Page", "terms, page", "TheLovinator", "http://localhost:8000/terms", "terms") -} - -// Run some simple validation on the URL -func validateURL(feed_url string) error { - // Check if URL starts with http or https - if !strings.HasPrefix(feed_url, "http://") && !strings.HasPrefix(feed_url, "https://") { - return errors.New("URL must start with http:// or https://") - } - - // Parse a url into a URL structure - u, err := url.Parse(feed_url) - if err != nil { - return errors.New("failed to parse URL") - } - - // Get the domain from the URL - domain := u.Hostname() - domain = strings.TrimSpace(domain) - if domain == "" { - return errors.New("URL does not contain a domain") - } - - // Don't allow IP address URLs - ip := net.ParseIP(domain) - if ip != nil { - return errors.New("IP address URLs are not allowed") - } - - // Don't allow local URLs (e.g. router URLs) - // Taken from https://github.com/uBlockOrigin/uAssets/blob/master/filters/lan-block.txt - // https://github.com/gwarser/filter-lists - localURLs := []string{ - "[::]", - "[::1]", - "airbox.home", - "airport", - "arcor.easybox", - "aterm.me", - "bthomehub.home", - "bthub.home", - "congstar.box", - "connect.box", - "console.gl-inet.com", - "easy.box", - "etxr", - "fire.walla", - "fritz.box", - "fritz.nas", - "fritz.repeater", - "giga.cube", - "hi.link", - "hitronhub.home", - "home.arpa", - "homerouter.cpe", - "host.docker.internal", - "huaweimobilewifi.com", - "localbattle.net", - "localhost", - "mobile.hotspot", - "myfritz.box", - "ntt.setup", - "pi.hole", - "plex.direct", - "repeater.asus.com", - "router.asus.com", - "routerlogin.com", - "routerlogin.net", - "samsung.router", - "speedport.ip", - "steamloopback.host", - "tplinkap.net", - "tplinkeap.net", - "tplinkmodem.net", - "tplinkplclogin.net", - "tplinkrepeater.net", - "tplinkwifi.net", - "web.setup.home", - "web.setup", - } - for _, localURL := range localURLs { - if strings.Contains(domain, localURL) { - return errors.New("local URLs are not allowed") - } - } - - // Check if the domain is in BadURLs - var count int64 - db.Model(&BadURLs{}).Where("url = ?", domain).Count(&count) - if count > 0 { - return errors.New("URL is in the bad URLs list") - } - - // Don't allow URLs that end with .local - if strings.HasSuffix(domain, ".local") { - return errors.New("URLs ending with .local are not allowed") - } - - // Don't allow URLs that end with .onion - if strings.HasSuffix(domain, ".onion") { - return errors.New("URLs ending with .onion are not allowed") - } - - // Don't allow URLs that end with .home.arpa - if strings.HasSuffix(domain, ".home.arpa") { - return errors.New("URLs ending with .home.arpa are not allowed") - } - - // Don't allow URLs that end with .internal - // Docker uses host.docker.internal - if strings.HasSuffix(domain, ".internal") { - return errors.New("URLs ending with .internal are not allowed") - } - - // Don't allow URLs that end with .localdomain - if strings.HasSuffix(domain, ".localdomain") { - return errors.New("URLs ending with .localdomain are not allowed") - } - - // Check if the domain is resolvable - _, err = net.LookupIP(domain) - if err != nil { - return errors.New("failed to resolve domain") - } - - // Check if the URL is reachable - _, err = http.Get(feed_url) - if err != nil { - return errors.New("failed to reach URL") - } - - return nil -} - -func AddFeedHandler(w http.ResponseWriter, r *http.Request) { - var parseErrors []ParseResult - - // Parse the form and get the URLs - r.ParseForm() - urls := r.Form.Get("urls") - if urls == "" { - http.Error(w, "No URLs provided", http.StatusBadRequest) - return - } - - for _, feed_url := range strings.Split(urls, "\n") { - // TODO: Try to upgrade to https if http is provided - - // Validate the URL - err := validateURL(feed_url) - if err != nil { - parseErrors = append(parseErrors, ParseResult{FeedURL: feed_url, Msg: err.Error(), IsError: true}) - continue - } - - // "Add" the feed to the database - log.Println("Adding feed:", feed_url) - } - - // Render the index page with the parse errors - data := TemplateData{ - Title: "FeedVault", - Description: "FeedVault - A feed archive", - Keywords: "RSS, Atom, Feed, Archive", - ParseErrors: parseErrors, - } - data.GetDatabaseSizeAndFeedCount() - - t, err := template.ParseFiles("templates/base.tmpl", "templates/index.tmpl") - if err != nil { - http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError) - return - } - t.ExecuteTemplate(w, "base", data) - -} diff --git a/main_test.go b/main_test.go index 3699bfa..62e421a 100644 --- a/main_test.go +++ b/main_test.go @@ -1,8 +1,6 @@ package main import ( - "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -200,213 +198,3 @@ func TestLocalURLs(t *testing.T) { assert.Equal(t, "local URLs are not allowed", err.Error()) } } - -func TestIndexHandler(t *testing.T) { - // Create a request to pass to our handler. - req, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) - } - - // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. - rr := httptest.NewRecorder() - handler := http.HandlerFunc(IndexHandler) - - // Our handlers satisfy http.Handler, so we can call their ServeHTTP method - // directly and pass in our Request and ResponseRecorder. - handler.ServeHTTP(rr, req) - - // Check the status code is what we expect. - if status := rr.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } - - // Check the response contains the expected string. - shouldContain := "Input the URLs of the feeds you wish to archive below. You can add as many as needed, and access them through the website or API. Alternatively, include links to .opml files, and the feeds within will be archived." - body := rr.Body.String() - if !assert.Contains(t, body, shouldContain) { - t.Errorf("handler returned unexpected body: got %v want %v", - body, shouldContain) - } -} - -func TestApiHandler(t *testing.T) { - // Create a request to pass to our handler. - req, err := http.NewRequest("GET", "/api", nil) - if err != nil { - t.Fatal(err) - } - - // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. - rr := httptest.NewRecorder() - handler := http.HandlerFunc(ApiHandler) - - // Our handlers satisfy http.Handler, so we can call their ServeHTTP method - // directly and pass in our Request and ResponseRecorder. - handler.ServeHTTP(rr, req) - - // Check the status code is what we expect. - if status := rr.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } - - // Check the response contains the expected string. - shouldContain := "

Here be dragons.

" - body := rr.Body.String() - if !assert.Contains(t, body, shouldContain) { - t.Errorf("handler returned unexpected body: got %v want %v", - body, shouldContain) - } -} - -func TestTermsHandler(t *testing.T) { - // Create a request to pass to our handler. - req, err := http.NewRequest("GET", "/terms", nil) - if err != nil { - t.Fatal(err) - } - - // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. - rr := httptest.NewRecorder() - handler := http.HandlerFunc(TermsHandler) - - // Our handlers satisfy http.Handler, so we can call their ServeHTTP method - // directly and pass in our Request and ResponseRecorder. - handler.ServeHTTP(rr, req) - - // Check the status code is what we expect. - if status := rr.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } - - // Check the response contains the expected string. - shouldContain := "Terms of Service" - body := rr.Body.String() - if !assert.Contains(t, body, shouldContain) { - t.Errorf("handler returned unexpected body: got %v want %v", - body, shouldContain) - } -} - -func TestPrivacyHandler(t *testing.T) { - // Create a request to pass to our handler. - req, err := http.NewRequest("GET", "/privacy", nil) - if err != nil { - t.Fatal(err) - } - - // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. - rr := httptest.NewRecorder() - handler := http.HandlerFunc(PrivacyHandler) - - // Our handlers satisfy http.Handler, so we can call their ServeHTTP method - // directly and pass in our Request and ResponseRecorder. - handler.ServeHTTP(rr, req) - - // Check the status code is what we expect. - if status := rr.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } - - // Check the response contains the expected string. - shouldContain := "Privacy Policy" - body := rr.Body.String() - if !assert.Contains(t, body, shouldContain) { - t.Errorf("handler returned unexpected body: got %v want %v", - body, shouldContain) - } -} - -func TestNotFoundHandler(t *testing.T) { - // Create a request to pass to our handler. - req, err := http.NewRequest("GET", "/notfound", nil) - if err != nil { - t.Fatal(err) - } - - // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. - rr := httptest.NewRecorder() - handler := http.HandlerFunc(NotFoundHandler) - - // Our handlers satisfy http.Handler, so we can call their ServeHTTP method - // directly and pass in our Request and ResponseRecorder. - handler.ServeHTTP(rr, req) - - // Check the status code is what we expect. - if status := rr.Code; status != http.StatusNotFound { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusNotFound) - } - - // Check the response contains the expected string. - shouldContain := "

404 - Page not found

" - body := rr.Body.String() - if !assert.Contains(t, body, shouldContain) { - t.Errorf("handler returned unexpected body: got %v want %v", - body, shouldContain) - } -} - -func TestMethodNotAllowedHandler(t *testing.T) { - // Create a request to pass to our handler. - req, err := http.NewRequest("GET", "/api", nil) - if err != nil { - t.Fatal(err) - } - - // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. - rr := httptest.NewRecorder() - handler := http.HandlerFunc(MethodNotAllowedHandler) - - // Our handlers satisfy http.Handler, so we can call their ServeHTTP method - // directly and pass in our Request and ResponseRecorder. - handler.ServeHTTP(rr, req) - - // Check the status code is what we expect. - if status := rr.Code; status != http.StatusMethodNotAllowed { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusMethodNotAllowed) - } - - // Check the response contains the expected string. - shouldContain := "

405 - Method Not Allowed

" - body := rr.Body.String() - if !assert.Contains(t, body, shouldContain) { - t.Errorf("handler returned unexpected body: got %v want %v", - body, shouldContain) - } -} - -func TestDonateHandler(t *testing.T) { - // Create a request to pass to our handler. - req, err := http.NewRequest("GET", "/donate", nil) - if err != nil { - t.Fatal(err) - } - - // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. - rr := httptest.NewRecorder() - handler := http.HandlerFunc(DonateHandler) - - // Our handlers satisfy http.Handler, so we can call their ServeHTTP method - // directly and pass in our Request and ResponseRecorder. - handler.ServeHTTP(rr, req) - - // Check the status code is what we expect. - if status := rr.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } - - // Check the response contains the expected string. - shouldContain := "tl;dr: GitHub Sponsors" - body := rr.Body.String() - if !assert.Contains(t, body, shouldContain) { - t.Errorf("handler returned unexpected body: got %v want %v", - body, shouldContain) - } -} diff --git a/db.go b/stats.go similarity index 56% rename from db.go rename to stats.go index 8a11aa2..33f18f5 100644 --- a/db.go +++ b/stats.go @@ -17,18 +17,23 @@ func GetDBSize() string { // Get the file size in bytes fileSize := fileInfo.Size() - // Convert to human readable size and append the unit (KB, MB, GB) + // Convert to human readable size and append the unit (KiB, MiB, GiB, TiB) var size float64 if fileSize < 1024*1024 { size = float64(fileSize) / 1024 - return fmt.Sprintf("%.2f KB", size) + return fmt.Sprintf("%.2f KiB", size) } if fileSize < 1024*1024*1024 { size = float64(fileSize) / (1024 * 1024) - return fmt.Sprintf("%.2f MB", size) + return fmt.Sprintf("%.2f MiB", size) } - size = float64(fileSize) / (1024 * 1024 * 1024) - return fmt.Sprintf("%.2f GB", size) + if fileSize < 1024*1024*1024*1024 { + size = float64(fileSize) / (1024 * 1024 * 1024) + return fmt.Sprintf("%.2f GiB", size) + } + + size = float64(fileSize) / (1024 * 1024 * 1024 * 1024) + return fmt.Sprintf("%.2f TiB", size) } diff --git a/validate.go b/validate.go new file mode 100644 index 0000000..663c6cb --- /dev/null +++ b/validate.go @@ -0,0 +1,223 @@ +package main + +import ( + "bufio" + "errors" + "log" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +func scrapeBadURLs() { + // TODO: We should only scrape the bad URLs if the file has been updated + // TODO: Use brotli compression https://gitlab.com/malware-filter/urlhaus-filter#compressed-version + filterListURLs := []string{ + "https://malware-filter.gitlab.io/malware-filter/phishing-filter-dnscrypt-blocked-names.txt", + "https://malware-filter.gitlab.io/malware-filter/urlhaus-filter-dnscrypt-blocked-names-online.txt", + } + + // Scrape the bad URLs + badURLs := []BadURLs{} + for _, url := range filterListURLs { + // Check if we have scraped the bad URLs in the last 24 hours + var meta BadURLsMeta + db.Where("url = ?", url).First(&meta) + if time.Since(meta.LastScraped).Hours() < 24 { + log.Printf("%s was last scraped %.1f hours ago\n", url, time.Since(meta.LastScraped).Hours()) + continue + } + + // Create the meta if it doesn't exist + if meta.ID == 0 { + meta = BadURLsMeta{URL: url} + db.Create(&meta) + } + + // Update the last scraped time + db.Model(&meta).Update("last_scraped", time.Now()) + + // Get the filter list + resp, err := http.Get(url) + if err != nil { + log.Println("Failed to get filter list:", err) + continue + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "#") { + log.Println("Comment:", line) + continue + } + + // Skip the URL if it already exists in the database + var count int64 + db.Model(&BadURLs{}).Where("url = ?", line).Count(&count) + if count > 0 { + log.Println("URL already exists:", line) + continue + } + + // Add the bad URL to the list + badURLs = append(badURLs, BadURLs{URL: line, Active: true}) + } + + if err := scanner.Err(); err != nil { + log.Println("Failed to scan filter list:", err) + } + } + + if len(badURLs) == 0 { + log.Println("No new URLs found in", len(filterListURLs), "filter lists") + return + } + + // Log how many bad URLs we found + log.Println("Found", len(badURLs), "bad URLs") + + // Mark all the bad URLs as inactive if we have any in the database + var count int64 + db.Model(&BadURLs{}).Count(&count) + if count > 0 { + db.Model(&BadURLs{}).Update("active", false) + } + + // Save the bad URLs to the database + db.Create(&badURLs) +} + +// Run some simple validation on the URL +func validateURL(feed_url string) error { + // Check if URL starts with http or https + if !strings.HasPrefix(feed_url, "http://") && !strings.HasPrefix(feed_url, "https://") { + return errors.New("URL must start with http:// or https://") + } + + // Parse a url into a URL structure + u, err := url.Parse(feed_url) + if err != nil { + return errors.New("failed to parse URL") + } + + // Get the domain from the URL + domain := u.Hostname() + domain = strings.TrimSpace(domain) + if domain == "" { + return errors.New("URL does not contain a domain") + } + + // Don't allow IP address URLs + ip := net.ParseIP(domain) + if ip != nil { + return errors.New("IP address URLs are not allowed") + } + + // Don't allow local URLs (e.g. router URLs) + // Taken from https://github.com/uBlockOrigin/uAssets/blob/master/filters/lan-block.txt + // https://github.com/gwarser/filter-lists + localURLs := []string{ + "[::]", + "[::1]", + "airbox.home", + "airport", + "arcor.easybox", + "aterm.me", + "bthomehub.home", + "bthub.home", + "congstar.box", + "connect.box", + "console.gl-inet.com", + "easy.box", + "etxr", + "fire.walla", + "fritz.box", + "fritz.nas", + "fritz.repeater", + "giga.cube", + "hi.link", + "hitronhub.home", + "home.arpa", + "homerouter.cpe", + "host.docker.internal", + "huaweimobilewifi.com", + "localbattle.net", + "localhost", + "mobile.hotspot", + "myfritz.box", + "ntt.setup", + "pi.hole", + "plex.direct", + "repeater.asus.com", + "router.asus.com", + "routerlogin.com", + "routerlogin.net", + "samsung.router", + "speedport.ip", + "steamloopback.host", + "tplinkap.net", + "tplinkeap.net", + "tplinkmodem.net", + "tplinkplclogin.net", + "tplinkrepeater.net", + "tplinkwifi.net", + "web.setup.home", + "web.setup", + } + for _, localURL := range localURLs { + if strings.Contains(domain, localURL) { + return errors.New("local URLs are not allowed") + } + } + + // Check if the domain is in BadURLs + var count int64 + db.Model(&BadURLs{}).Where("url = ?", domain).Count(&count) + if count > 0 { + return errors.New("URL is in the bad URLs list") + } + + // Don't allow URLs that end with .local + if strings.HasSuffix(domain, ".local") { + return errors.New("URLs ending with .local are not allowed") + } + + // Don't allow URLs that end with .onion + if strings.HasSuffix(domain, ".onion") { + return errors.New("URLs ending with .onion are not allowed") + } + + // Don't allow URLs that end with .home.arpa + if strings.HasSuffix(domain, ".home.arpa") { + return errors.New("URLs ending with .home.arpa are not allowed") + } + + // Don't allow URLs that end with .internal + // Docker uses host.docker.internal + if strings.HasSuffix(domain, ".internal") { + return errors.New("URLs ending with .internal are not allowed") + } + + // Don't allow URLs that end with .localdomain + if strings.HasSuffix(domain, ".localdomain") { + return errors.New("URLs ending with .localdomain are not allowed") + } + + // Check if the domain is resolvable + _, err = net.LookupIP(domain) + if err != nil { + return errors.New("failed to resolve domain") + } + + // Check if the URL is reachable + _, err = http.Get(feed_url) + if err != nil { + return errors.New("failed to reach URL") + } + + return nil +} diff --git a/views.go b/views.go new file mode 100644 index 0000000..4653f1b --- /dev/null +++ b/views.go @@ -0,0 +1,107 @@ +package main + +import ( + "fmt" + "html/template" + "log" + "net/http" + "strings" +) + +func NotFoundHandler(w http.ResponseWriter, r *http.Request) { + data := TemplateData{ + Request: r, + } + data.GetDatabaseSizeAndFeedCount() + t, err := template.ParseFiles("templates/base.tmpl", "templates/404.tmpl") + if err != nil { + http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusNotFound) + t.ExecuteTemplate(w, "base", data) +} + +func MethodNotAllowedHandler(w http.ResponseWriter, r *http.Request) { + data := TemplateData{ + Request: r, + } + data.GetDatabaseSizeAndFeedCount() + t, err := template.ParseFiles("templates/base.tmpl", "templates/405.tmpl") + if err != nil { + http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusMethodNotAllowed) + t.ExecuteTemplate(w, "base", data) +} + +func IndexHandler(w http.ResponseWriter, _ *http.Request) { + renderPage(w, "FeedVault", "FeedVault - A feed archive", "RSS, Atom, Feed, Archive", "TheLovinator", "http://localhost:8000/", "index") +} + +func ApiHandler(w http.ResponseWriter, _ *http.Request) { + renderPage(w, "API", "API Page", "api, page", "TheLovinator", "http://localhost:8000/api", "api") +} +func AboutHandler(w http.ResponseWriter, _ *http.Request) { + renderPage(w, "About", "About Page", "about, page", "TheLovinator", "http://localhost:8000/about", "about") +} + +func DonateHandler(w http.ResponseWriter, _ *http.Request) { + renderPage(w, "Donate", "Donate Page", "donate, page", "TheLovinator", "http://localhost:8000/donate", "donate") +} + +func FeedsHandler(w http.ResponseWriter, _ *http.Request) { + renderPage(w, "Feeds", "Feeds Page", "feeds, page", "TheLovinator", "http://localhost:8000/feeds", "feeds") +} + +func PrivacyHandler(w http.ResponseWriter, _ *http.Request) { + renderPage(w, "Privacy", "Privacy Page", "privacy, page", "TheLovinator", "http://localhost:8000/privacy", "privacy") +} + +func TermsHandler(w http.ResponseWriter, _ *http.Request) { + renderPage(w, "Terms", "Terms and Conditions Page", "terms, page", "TheLovinator", "http://localhost:8000/terms", "terms") +} + +func AddFeedHandler(w http.ResponseWriter, r *http.Request) { + var parseErrors []ParseResult + + // Parse the form and get the URLs + r.ParseForm() + urls := r.Form.Get("urls") + if urls == "" { + http.Error(w, "No URLs provided", http.StatusBadRequest) + return + } + + for _, feed_url := range strings.Split(urls, "\n") { + // TODO: Try to upgrade to https if http is provided + + // Validate the URL + err := validateURL(feed_url) + if err != nil { + parseErrors = append(parseErrors, ParseResult{FeedURL: feed_url, Msg: err.Error(), IsError: true}) + continue + } + + // "Add" the feed to the database + log.Println("Adding feed:", feed_url) + } + + // Render the index page with the parse errors + data := TemplateData{ + Title: "FeedVault", + Description: "FeedVault - A feed archive", + Keywords: "RSS, Atom, Feed, Archive", + ParseErrors: parseErrors, + } + data.GetDatabaseSizeAndFeedCount() + + t, err := template.ParseFiles("templates/base.tmpl", "templates/index.tmpl") + if err != nil { + http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError) + return + } + t.ExecuteTemplate(w, "base", data) + +} diff --git a/views_test.go b/views_test.go new file mode 100644 index 0000000..18ff52f --- /dev/null +++ b/views_test.go @@ -0,0 +1,219 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIndexHandler(t *testing.T) { + // Create a request to pass to our handler. + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. + rr := httptest.NewRecorder() + handler := http.HandlerFunc(IndexHandler) + + // Our handlers satisfy http.Handler, so we can call their ServeHTTP method + // directly and pass in our Request and ResponseRecorder. + handler.ServeHTTP(rr, req) + + // Check the status code is what we expect. + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + // Check the response contains the expected string. + shouldContain := "Input the URLs of the feeds you wish to archive below. You can add as many as needed, and access them through the website or API. Alternatively, include links to .opml files, and the feeds within will be archived." + body := rr.Body.String() + if !assert.Contains(t, body, shouldContain) { + t.Errorf("handler returned unexpected body: got %v want %v", + body, shouldContain) + } +} + +func TestApiHandler(t *testing.T) { + // Create a request to pass to our handler. + req, err := http.NewRequest("GET", "/api", nil) + if err != nil { + t.Fatal(err) + } + + // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. + rr := httptest.NewRecorder() + handler := http.HandlerFunc(ApiHandler) + + // Our handlers satisfy http.Handler, so we can call their ServeHTTP method + // directly and pass in our Request and ResponseRecorder. + handler.ServeHTTP(rr, req) + + // Check the status code is what we expect. + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + // Check the response contains the expected string. + shouldContain := "

Here be dragons.

" + body := rr.Body.String() + if !assert.Contains(t, body, shouldContain) { + t.Errorf("handler returned unexpected body: got %v want %v", + body, shouldContain) + } +} + +func TestTermsHandler(t *testing.T) { + // Create a request to pass to our handler. + req, err := http.NewRequest("GET", "/terms", nil) + if err != nil { + t.Fatal(err) + } + + // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. + rr := httptest.NewRecorder() + handler := http.HandlerFunc(TermsHandler) + + // Our handlers satisfy http.Handler, so we can call their ServeHTTP method + // directly and pass in our Request and ResponseRecorder. + handler.ServeHTTP(rr, req) + + // Check the status code is what we expect. + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + // Check the response contains the expected string. + shouldContain := "Terms of Service" + body := rr.Body.String() + if !assert.Contains(t, body, shouldContain) { + t.Errorf("handler returned unexpected body: got %v want %v", + body, shouldContain) + } +} + +func TestPrivacyHandler(t *testing.T) { + // Create a request to pass to our handler. + req, err := http.NewRequest("GET", "/privacy", nil) + if err != nil { + t.Fatal(err) + } + + // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. + rr := httptest.NewRecorder() + handler := http.HandlerFunc(PrivacyHandler) + + // Our handlers satisfy http.Handler, so we can call their ServeHTTP method + // directly and pass in our Request and ResponseRecorder. + handler.ServeHTTP(rr, req) + + // Check the status code is what we expect. + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + // Check the response contains the expected string. + shouldContain := "Privacy Policy" + body := rr.Body.String() + if !assert.Contains(t, body, shouldContain) { + t.Errorf("handler returned unexpected body: got %v want %v", + body, shouldContain) + } +} + +func TestNotFoundHandler(t *testing.T) { + // Create a request to pass to our handler. + req, err := http.NewRequest("GET", "/notfound", nil) + if err != nil { + t.Fatal(err) + } + + // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. + rr := httptest.NewRecorder() + handler := http.HandlerFunc(NotFoundHandler) + + // Our handlers satisfy http.Handler, so we can call their ServeHTTP method + // directly and pass in our Request and ResponseRecorder. + handler.ServeHTTP(rr, req) + + // Check the status code is what we expect. + if status := rr.Code; status != http.StatusNotFound { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusNotFound) + } + + // Check the response contains the expected string. + shouldContain := "

404 - Page not found

" + body := rr.Body.String() + if !assert.Contains(t, body, shouldContain) { + t.Errorf("handler returned unexpected body: got %v want %v", + body, shouldContain) + } +} + +func TestMethodNotAllowedHandler(t *testing.T) { + // Create a request to pass to our handler. + req, err := http.NewRequest("GET", "/api", nil) + if err != nil { + t.Fatal(err) + } + + // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. + rr := httptest.NewRecorder() + handler := http.HandlerFunc(MethodNotAllowedHandler) + + // Our handlers satisfy http.Handler, so we can call their ServeHTTP method + // directly and pass in our Request and ResponseRecorder. + handler.ServeHTTP(rr, req) + + // Check the status code is what we expect. + if status := rr.Code; status != http.StatusMethodNotAllowed { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusMethodNotAllowed) + } + + // Check the response contains the expected string. + shouldContain := "

405 - Method Not Allowed

" + body := rr.Body.String() + if !assert.Contains(t, body, shouldContain) { + t.Errorf("handler returned unexpected body: got %v want %v", + body, shouldContain) + } +} + +func TestDonateHandler(t *testing.T) { + // Create a request to pass to our handler. + req, err := http.NewRequest("GET", "/donate", nil) + if err != nil { + t.Fatal(err) + } + + // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. + rr := httptest.NewRecorder() + handler := http.HandlerFunc(DonateHandler) + + // Our handlers satisfy http.Handler, so we can call their ServeHTTP method + // directly and pass in our Request and ResponseRecorder. + handler.ServeHTTP(rr, req) + + // Check the status code is what we expect. + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + // Check the response contains the expected string. + shouldContain := "tl;dr: GitHub Sponsors" + body := rr.Body.String() + if !assert.Contains(t, body, shouldContain) { + t.Errorf("handler returned unexpected body: got %v want %v", + body, shouldContain) + } +}