house/backend/router/router.go
2025-06-14 23:47:44 -04:00

203 lines
4.2 KiB
Go

package router
import (
	"database/sql"
	"errors"
	"fmt"
	"net/http"
	"runtime/debug"
	"time"

	"git.red-panda.pet/pandaware/house/backend/db"
	"github.com/charmbracelet/log"
	"github.com/google/uuid"
)
// IRouter is the route-registration surface implemented by Router.
// Register mounts a Route for an arbitrary HTTP method; the remaining
// methods are per-method shorthands (Router forwards each of them to
// Register with the matching method name).
type IRouter interface {
	// Register mounts route for the given HTTP method at pattern.
	Register(method, pattern string, route Route)

	// Shorthands for the common HTTP methods.
	GET(pattern string, route Route)
	HEAD(pattern string, route Route)
	POST(pattern string, route Route)
	PUT(pattern string, route Route)
	PATCH(pattern string, route Route)
	DELETE(pattern string, route Route)
	OPTIONS(pattern string, route Route)
}
// Router mounts Route implementations on an http.ServeMux under an optional
// path prefix, and gives every request a Context holding its own database
// connection taken from db. The zero value is not usable (mux is nil);
// construct one with NewRouter.
type Router struct {
	// log is the base logger; NewContext derives a per-request logger from it.
	log *log.Logger
	// mux receives every pattern built by Register.
	mux *http.ServeMux
	// db is the pool from which NewContext checks out a connection per request.
	db *sql.DB
	// prefix is prepended to each path passed to Register (read at registration time).
	prefix string
}
// NewRouter returns a Router backed by a fresh http.ServeMux. It logs through
// logger and checks out per-request connections from db. The route prefix
// starts out empty; call SetPrefix before registering routes to change it.
func NewRouter(logger *log.Logger, db *sql.DB) *Router {
	return &Router{
		log: logger,
		mux: http.NewServeMux(),
		db:  db,
	}
}
// SetPrefix sets the path prefix that Register prepends to route paths.
// Register reads the prefix when a route is mounted, so changing it only
// affects routes registered afterwards.
func (router *Router) SetPrefix(prefix string) { router.prefix = prefix }
// NewContext builds the per-request Context for r.
//
// It assigns the request a UUIDv7 id, checks out a dedicated connection from
// the router's pool (waiting at most as long as r's context allows), wires a
// query set onto that connection, and derives a logger prefixed with
// "<METHOD> <escaped path>".
//
// The returned Context owns the connection in ctx.DB. Per database/sql, a
// Conn must be closed to return it to the pool; NewContext does not do that,
// so the caller must release it once the request is done.
//
// If id generation or connection checkout fails, a wrapped error is returned
// and no Context is built.
func (router *Router) NewContext(w http.ResponseWriter, r *http.Request) (*Context, error) {
	router.log.Helper()

	id, err := uuid.NewV7()
	if err != nil {
		return nil, fmt.Errorf("generating request id: %w", err)
	}

	conn, err := router.db.Conn(r.Context())
	if err != nil {
		return nil, fmt.Errorf("acquiring database connection: %w", err)
	}

	// Treat an empty Method as GET (net/http's convention for client
	// requests) so the log prefix is never blank.
	method := r.Method
	if method == "" {
		method = http.MethodGet
	}
	path := r.URL.EscapedPath()

	ctx := &Context{
		id:      id.String(),
		start:   time.Now(),
		resp:    w,
		Request: r,
		log:     router.log.WithPrefix(method + " " + path),
		DB:      conn,
		context: r.Context(),
	}
	ctx.Query = db.New(ctx.DB)
	return ctx, nil
}
func (router *Router) handleError(ctx *Context, err error) {
if err == nil {
return
}
if reqErr, ok := err.(*RequestError); ok {
err = ctx.JSON(reqErr.StatusCode, reqErr)
ctx.log.Warn("error during request", "err", reqErr.Inner)
if err != nil {
panic(err)
}
return
}
panic(err)
}
// All implements IRouter.
func (router *Router) Register(method, path string, route Route) {
router.log.Helper()
routeImpl := any(route)
authorize, implAuthorized := routeImpl.(AuthorizedRoute)
validateBody, implBodyValidation := routeImpl.(ValidatedBodyRoute)
validateParams, implParamValidation := routeImpl.(ValidatedParamRoute)
validateSearch, implSearchValidation := routeImpl.(ValidatedSearchRoute)
pattern := fmt.Sprintf("%s %s%s", method, router.prefix, path)
router.mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
ctx, err := router.NewContext(w, r)
//TODO: don't panic :3
if err != nil {
panic(err)
}
if implAuthorized {
err = authorize.Authorize(ctx)
}
if err != nil {
router.handleError(ctx, err)
return
}
if implSearchValidation {
err = validateSearch.ValidateSearch(ctx)
}
if err != nil {
router.handleError(ctx, err)
return
}
if implParamValidation {
err = validateParams.ValidateParams(ctx)
}
if err != nil {
router.handleError(ctx, err)
return
}
if implBodyValidation {
err = validateBody.ValidateBody(ctx)
}
if err != nil {
router.handleError(ctx, err)
return
}
err = route.Handle(ctx)
router.handleError(ctx, err)
})
}
// GET implements IRouter by registering route for GET requests at pattern.
// With Go 1.22+ ServeMux patterns, a GET route also matches HEAD requests.
func (router *Router) GET(pattern string, route Route) {
	// The method must not carry a trailing space. "GET " produced a
	// double-spaced mux pattern ("GET  /path").
	router.Register(http.MethodGet, pattern, route)
}
// DELETE implements IRouter by registering route for DELETE requests at pattern.
func (router *Router) DELETE(pattern string, route Route) {
	router.Register(http.MethodDelete, pattern, route)
}
// HEAD implements IRouter by registering route for HEAD requests at pattern.
func (router *Router) HEAD(pattern string, route Route) {
	router.Register(http.MethodHead, pattern, route)
}
// OPTIONS implements IRouter by registering route for OPTIONS requests at pattern.
func (router *Router) OPTIONS(pattern string, route Route) {
	router.Register(http.MethodOptions, pattern, route)
}
// POST implements IRouter by registering route for POST requests at pattern.
func (router *Router) POST(pattern string, route Route) {
	router.Register(http.MethodPost, pattern, route)
}
// PUT implements IRouter by registering route for PUT requests at pattern.
func (router *Router) PUT(pattern string, route Route) {
	router.Register(http.MethodPut, pattern, route)
}
// PATCH implements IRouter by registering route for PATCH requests at pattern.
func (router *Router) PATCH(pattern string, route Route) {
	router.Register(http.MethodPatch, pattern, route)
}
// ServeHTTP implements http.Handler by dispatching to the underlying ServeMux.
//
// A panic raised while serving is recovered, logged with its stack trace,
// and answered with a 500 so the client does not receive an empty 200.
// http.ErrAbortHandler is re-raised so net/http can abort the response as it
// documents.
func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	defer func() {
		rec := recover()
		if rec == nil {
			return
		}
		if rec == http.ErrAbortHandler {
			panic(rec)
		}
		router.log.Helper()
		router.log.Error("recovered from panic", "err", rec, "stack", string(debug.Stack()))
		// If the handler had already started writing, the status cannot
		// change any more. net/http then logs a superfluous WriteHeader call
		// and the error text is appended to the partial body.
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
	}()
	router.mux.ServeHTTP(w, r)
}
// Compile-time checks that *Router satisfies the interfaces it is meant to
// implement.
var (
	_ http.Handler = (*Router)(nil)
	_ IRouter      = (*Router)(nil)
)