// Package middleware provides HTTP middleware that validates the
// "paterissa_session_token" session cookie against the cookies table.
package middleware

import (
	"context"
	"database/sql"
	"log"
	"net/http"
	"time"
)

// sessionCookieName is the name of the cookie carrying the session token.
const sessionCookieName = "paterissa_session_token"

// sessionQuery fetches the session row whose content equals the token.
// NOTE(review): SELECT * is scanned positionally into four values
// (id, content, user id, expiration); this assumes the cookies table has
// exactly those columns in that order — TODO confirm, and prefer listing the
// columns explicitly once their names are known.
const sessionQuery = "SELECT * FROM cookies WHERE content = $1;"

// AuthMiddleware holds the dependencies of the session-cookie middleware.
type AuthMiddleware struct {
	// Err receives log output from Resolve when a session lookup fails.
	Err *log.Logger
	// Db is the database queried for sessions in the cookies table.
	Db *sql.DB
}

// CheckAndInvalidate wraps next so that a request carrying an unknown or
// expired session cookie has that cookie cleared and is redirected to "/"
// with 302 Found. Requests without the cookie, or whose session has not yet
// expired, are passed to next unchanged.
//
// NOTE(review): any database error is treated the same as an unknown session,
// so a transient DB failure also clears the client's cookie.
func (auth *AuthMiddleware) CheckAndInvalidate(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie(sessionCookieName)
		if err != nil {
			// No session cookie: nothing to validate or invalidate.
			next.ServeHTTP(w, r)
			return
		}

		expiration, err := auth.lookupSessionExpiration(r.Context(), cookie.Value)
		if err != nil || time.Now().After(expiration) {
			clearSessionCookie(w)
			http.Redirect(w, r, "/", http.StatusFound)
			return
		}

		next.ServeHTTP(w, r)
	}
}

// Resolve wraps next so that it only runs for requests whose session cookie
// matches a row in the cookies table with an expiration that has not passed.
// Otherwise it responds 401 Unauthorized with a plain-text body:
//   - "Unauthorized" when the cookie is missing;
//   - "Unauthorized" when the lookup fails or finds no row (the error is logged
//     to auth.Err and the cookie is cleared);
//   - "Expired" when the session's expiration has passed (the cookie is cleared).
func (auth *AuthMiddleware) Resolve(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie(sessionCookieName)
		if err != nil {
			writeUnauthorized(w, "Unauthorized")
			return
		}

		expiration, err := auth.lookupSessionExpiration(r.Context(), cookie.Value)
		if err != nil {
			auth.Err.Printf("Could not retrieve cookie from DB: %v\n", err)
			clearSessionCookie(w)
			writeUnauthorized(w, "Unauthorized")
			return
		}

		if time.Now().After(expiration) {
			clearSessionCookie(w)
			writeUnauthorized(w, "Expired")
			return
		}

		next.ServeHTTP(w, r)
	}
}

// lookupSessionExpiration returns the expiration stored for the session whose
// content equals token. It returns sql.ErrNoRows when no session matches, or
// whatever error the database reports.
func (auth *AuthMiddleware) lookupSessionExpiration(ctx context.Context, token string) (time.Time, error) {
	var (
		id         int
		content    string
		userID     int
		expiration time.Time
	)
	// QueryRowContext avoids the extra Prepare round trip per request and
	// abandons the query if the client's request context is cancelled.
	row := auth.Db.QueryRowContext(ctx, sessionQuery, token)
	if err := row.Scan(&id, &content, &userID, &expiration); err != nil {
		return time.Time{}, err
	}
	return expiration, nil
}

// clearSessionCookie tells the client to delete the session cookie
// (MaxAge < 0 is sent as "Max-Age=0").
func clearSessionCookie(w http.ResponseWriter) {
	http.SetCookie(w, &http.Cookie{
		Name:     sessionCookieName,
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
	})
}

// writeUnauthorized sends a 401 status followed by msg as the body.
// The status must be written before any body bytes: once the body is written
// the status is fixed (at 200), which is why the previous Write-then-Redirect
// sequence never actually produced a 401.
func writeUnauthorized(w http.ResponseWriter, msg string) {
	w.WriteHeader(http.StatusUnauthorized)
	// A write error here means the client went away; there is nothing to recover.
	_, _ = w.Write([]byte(msg))
}