203 lines
4.2 KiB
Go
203 lines
4.2 KiB
Go
package router
|
|
|
|
import (
	"database/sql"
	"errors"
	"fmt"
	"net/http"
	"time"

	"git.red-panda.pet/pandaware/house/backend/db"
	"github.com/charmbracelet/log"
	"github.com/google/uuid"
)
|
|
|
|
type IRouter interface {
|
|
Register(method, pattern string, route Route)
|
|
GET(pattern string, route Route)
|
|
POST(pattern string, route Route)
|
|
HEAD(pattern string, route Route)
|
|
PUT(pattern string, route Route)
|
|
PATCH(pattern string, route Route)
|
|
DELETE(pattern string, route Route)
|
|
OPTIONS(pattern string, route Route)
|
|
}
|
|
|
|
type Router struct {
|
|
log *log.Logger
|
|
mux *http.ServeMux
|
|
db *sql.DB
|
|
prefix string
|
|
}
|
|
|
|
func NewRouter(logger *log.Logger, db *sql.DB) *Router {
|
|
r := new(Router)
|
|
r.mux = http.NewServeMux()
|
|
r.log = logger
|
|
r.db = db
|
|
|
|
return r
|
|
}
|
|
|
|
func (router *Router) SetPrefix(prefix string) {
|
|
router.prefix = prefix
|
|
}
|
|
|
|
func (router *Router) NewContext(w http.ResponseWriter, r *http.Request) (*Context, error) {
|
|
router.log.Helper()
|
|
|
|
ctx := new(Context)
|
|
id, err := uuid.NewV7()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
conn, err := router.db.Conn(r.Context())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
method := r.Method
|
|
if method == "" {
|
|
method = "GET"
|
|
}
|
|
path := r.URL.EscapedPath()
|
|
|
|
ctx.id = id.String()
|
|
ctx.start = time.Now()
|
|
ctx.resp = w
|
|
ctx.Request = r
|
|
ctx.log = router.log.WithPrefix(method + " " + path)
|
|
ctx.DB = conn
|
|
ctx.Query = db.New(ctx.DB)
|
|
ctx.context = r.Context()
|
|
|
|
return ctx, nil
|
|
}
|
|
|
|
func (router *Router) handleError(ctx *Context, err error) {
|
|
if err == nil {
|
|
return
|
|
}
|
|
|
|
if reqErr, ok := err.(*RequestError); ok {
|
|
err = ctx.JSON(reqErr.StatusCode, reqErr)
|
|
ctx.log.Warn("error during request", "err", reqErr.Inner)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
panic(err)
|
|
}
|
|
|
|
// All implements IRouter.
|
|
func (router *Router) Register(method, path string, route Route) {
|
|
router.log.Helper()
|
|
|
|
routeImpl := any(route)
|
|
|
|
authorize, implAuthorized := routeImpl.(AuthorizedRoute)
|
|
validateBody, implBodyValidation := routeImpl.(ValidatedBodyRoute)
|
|
validateParams, implParamValidation := routeImpl.(ValidatedParamRoute)
|
|
validateSearch, implSearchValidation := routeImpl.(ValidatedSearchRoute)
|
|
|
|
pattern := fmt.Sprintf("%s %s%s", method, router.prefix, path)
|
|
|
|
router.mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
|
|
ctx, err := router.NewContext(w, r)
|
|
//TODO: don't panic :3
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if implAuthorized {
|
|
err = authorize.Authorize(ctx)
|
|
}
|
|
|
|
if err != nil {
|
|
router.handleError(ctx, err)
|
|
return
|
|
}
|
|
|
|
if implSearchValidation {
|
|
err = validateSearch.ValidateSearch(ctx)
|
|
}
|
|
|
|
if err != nil {
|
|
router.handleError(ctx, err)
|
|
return
|
|
}
|
|
|
|
if implParamValidation {
|
|
err = validateParams.ValidateParams(ctx)
|
|
}
|
|
|
|
if err != nil {
|
|
router.handleError(ctx, err)
|
|
return
|
|
}
|
|
|
|
if implBodyValidation {
|
|
err = validateBody.ValidateBody(ctx)
|
|
}
|
|
|
|
if err != nil {
|
|
router.handleError(ctx, err)
|
|
return
|
|
}
|
|
|
|
err = route.Handle(ctx)
|
|
router.handleError(ctx, err)
|
|
})
|
|
}
|
|
|
|
// GET implements IRouter.
|
|
func (router *Router) GET(pattern string, route Route) {
|
|
router.Register("GET ", pattern, route)
|
|
}
|
|
|
|
// DELETE implements IRouter.
|
|
func (router *Router) DELETE(pattern string, route Route) {
|
|
router.Register("DELETE", pattern, route)
|
|
}
|
|
|
|
// HEAD implements IRouter.
|
|
func (router *Router) HEAD(pattern string, route Route) {
|
|
router.Register("HEAD", pattern, route)
|
|
}
|
|
|
|
// OPTIONS implements IRouter.
|
|
func (router *Router) OPTIONS(pattern string, route Route) {
|
|
router.Register("OPTIONS", pattern, route)
|
|
}
|
|
|
|
// POST implements IRouter.
|
|
func (router *Router) POST(pattern string, route Route) {
|
|
router.Register("POST", pattern, route)
|
|
}
|
|
|
|
// PUT implements IRouter.
|
|
func (router *Router) PUT(pattern string, route Route) {
|
|
router.Register("PUT", pattern, route)
|
|
}
|
|
|
|
// PATCH implements IRouter.
|
|
func (router *Router) PATCH(pattern string, route Route) {
|
|
router.Register("PATCH", pattern, route)
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler.
|
|
func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
router.log.Helper()
|
|
router.log.Warn("recovered from panic", "err", rec)
|
|
}
|
|
}()
|
|
|
|
router.mux.ServeHTTP(w, r)
|
|
}
|
|
|
|
var _ http.Handler = new(Router)
|
|
var _ IRouter = new(Router)
|