package router

import (
	"database/sql"
	"errors"
	"fmt"
	"net/http"
	"time"

	"git.red-panda.pet/pandaware/house/backend/db"
	"github.com/charmbracelet/log"
	"github.com/google/uuid"
)

// IRouter is the route-registration surface implemented by Router. Register
// mounts a Route for an explicit HTTP method; each verb method (GET, POST,
// ...) mounts a Route for that fixed method.
type IRouter interface {
	Register(method, pattern string, route Route)
	GET(pattern string, route Route)
	POST(pattern string, route Route)
	HEAD(pattern string, route Route)
	PUT(pattern string, route Route)
	PATCH(pattern string, route Route)
	DELETE(pattern string, route Route)
	OPTIONS(pattern string, route Route)
}

// Router dispatches requests through an http.ServeMux and builds a
// per-request *Context for every matched route.
type Router struct {
	// log is the base logger; each request gets a copy prefixed with its
	// method and path.
	log *log.Logger
	// mux holds the registered "METHOD path" patterns.
	mux *http.ServeMux
	// db is the pool from which each request checks out a dedicated conn.
	db *sql.DB
	// prefix is prepended to every path passed to Register.
	prefix string
}

// NewRouter returns a Router with an empty ServeMux that logs through logger
// and takes per-request connections from db. The path prefix starts empty;
// see SetPrefix.
func NewRouter(logger *log.Logger, db *sql.DB) *Router {
	return &Router{
		log: logger,
		mux: http.NewServeMux(),
		db:  db,
	}
}

// SetPrefix sets the path prefix prepended to every path passed to Register
// after this call. Routes registered earlier keep the prefix they were
// registered with.
func (router *Router) SetPrefix(prefix string) {
	router.prefix = prefix
}

// NewContext builds the per-request Context for r: a UUIDv7 request ID, the
// start time, a logger prefixed with "METHOD path", and a dedicated
// connection checked out of the router's pool, with db.New queries bound to
// that connection.
//
// The connection is stored in ctx.DB. It is only returned to the pool when
// closed, so the caller is responsible for closing it once the request is
// finished.
func (router *Router) NewContext(w http.ResponseWriter, r *http.Request) (*Context, error) {
	router.log.Helper()

	id, err := uuid.NewV7()
	if err != nil {
		return nil, fmt.Errorf("generating request id: %w", err)
	}

	conn, err := router.db.Conn(r.Context())
	if err != nil {
		return nil, fmt.Errorf("acquiring database connection: %w", err)
	}

	// Hand-built requests may leave Method empty; net/http treats that as GET.
	method := r.Method
	if method == "" {
		method = http.MethodGet
	}
	path := r.URL.EscapedPath()

	ctx := new(Context)
	ctx.id = id.String()
	ctx.start = time.Now()
	ctx.resp = w
	ctx.Request = r
	ctx.log = router.log.WithPrefix(method + " " + path)
	ctx.DB = conn
	ctx.Query = db.New(ctx.DB)
	ctx.context = r.Context()
	return ctx, nil
}

// handleError reports err for the current request. A nil err is a no-op.
//
// If a *RequestError appears anywhere in err's chain (found with errors.As,
// so wrapped errors are recognised), it is written as JSON with its
// StatusCode, and its Inner error is logged at warn level. Any other error,
// or a failure to write that JSON response, is re-raised as a panic for
// ServeHTTP to recover.
func (router *Router) handleError(ctx *Context, err error) {
	if err == nil {
		return
	}

	var reqErr *RequestError
	if !errors.As(err, &reqErr) {
		panic(err)
	}

	writeErr := ctx.JSON(reqErr.StatusCode, reqErr)
	ctx.log.Warn("error during request", "err", reqErr.Inner)
	if writeErr != nil {
		panic(writeErr)
	}
}

// Register implements IRouter.
func (router *Router) Register(method, path string, route Route) { router.log.Helper() routeImpl := any(route) authorize, implAuthorized := routeImpl.(AuthorizedRoute) validateBody, implBodyValidation := routeImpl.(ValidatedBodyRoute) validateParams, implParamValidation := routeImpl.(ValidatedParamRoute) validateSearch, implSearchValidation := routeImpl.(ValidatedSearchRoute) pattern := fmt.Sprintf("%s %s%s", method, router.prefix, path) router.mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { ctx, err := router.NewContext(w, r) //TODO: don't panic :3 if err != nil { panic(err) } if implAuthorized { err = authorize.Authorize(ctx) } if err != nil { router.handleError(ctx, err) return } if implSearchValidation { err = validateSearch.ValidateSearch(ctx) } if err != nil { router.handleError(ctx, err) return } if implParamValidation { err = validateParams.ValidateParams(ctx) } if err != nil { router.handleError(ctx, err) return } if implBodyValidation { err = validateBody.ValidateBody(ctx) } if err != nil { router.handleError(ctx, err) return } err = route.Handle(ctx) router.handleError(ctx, err) }) } // GET implements IRouter. func (router *Router) GET(pattern string, route Route) { router.Register("GET ", pattern, route) } // DELETE implements IRouter. func (router *Router) DELETE(pattern string, route Route) { router.Register("DELETE", pattern, route) } // HEAD implements IRouter. func (router *Router) HEAD(pattern string, route Route) { router.Register("HEAD", pattern, route) } // OPTIONS implements IRouter. func (router *Router) OPTIONS(pattern string, route Route) { router.Register("OPTIONS", pattern, route) } // POST implements IRouter. func (router *Router) POST(pattern string, route Route) { router.Register("POST", pattern, route) } // PUT implements IRouter. func (router *Router) PUT(pattern string, route Route) { router.Register("PUT", pattern, route) } // PATCH implements IRouter. 
func (router *Router) PATCH(pattern string, route Route) {
	router.Register(http.MethodPatch, pattern, route)
}

// ServeHTTP implements http.Handler.
//
// It dispatches r through the registered routes. A panic raised while
// handling the request is recovered, logged, and answered with a plain-text
// 500 so the client does not receive an empty 200. http.ErrAbortHandler is
// re-raised so net/http can abort the response as that sentinel intends.
func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	defer func() {
		rec := recover()
		if rec == nil {
			return
		}
		if rec == http.ErrAbortHandler {
			panic(rec)
		}

		router.log.Helper()
		router.log.Warn("recovered from panic", "err", rec)

		// If the handler already sent headers, the status can no longer
		// change (net/http logs a superfluous WriteHeader call) and the
		// error text is appended to what was already written.
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
	}()

	router.mux.ServeHTTP(w, r)
}

// Compile-time checks that *Router satisfies both interfaces.
var (
	_ http.Handler = (*Router)(nil)
	_ IRouter      = (*Router)(nil)
)