diff options
Diffstat (limited to 'pkg/ext')
-rw-r--r-- | pkg/ext/auth.go | 85 | ||||
-rw-r--r-- | pkg/ext/compression.go | 31 | ||||
-rw-r--r-- | pkg/ext/log.go | 4 | ||||
-rw-r--r-- | pkg/ext/request.go | 14 | ||||
-rw-r--r-- | pkg/ext/router.go | 42 |
5 files changed, 156 insertions, 20 deletions
diff --git a/pkg/ext/auth.go b/pkg/ext/auth.go new file mode 100644 index 0000000..ef126ec --- /dev/null +++ b/pkg/ext/auth.go @@ -0,0 +1,85 @@ +package ext + +import ( + "context" + "encoding/base64" + "errors" + "log/slog" + "net/http" + + serverconfig "git.gabrielgio.me/cerrado/pkg/config" +) + +type authService interface { + ValidateToken(token []byte) (bool, error) +} + +func DisableAuthentication(next HandlerFunc) HandlerFunc { + return func(w http.ResponseWriter, r *Request) { + ctx := r.Context() + ctx = context.WithValue(ctx, "disableAuthentication", true) + r.Request = r.WithContext(ctx) + next(w, r) + } +} + +func VerifyRespository( + config *serverconfig.ConfigurationRepository, +) func(next HandlerFunc) HandlerFunc { + return func(next HandlerFunc) HandlerFunc { + return func(w http.ResponseWriter, r *Request) { + name := r.PathValue("name") + if name != "" { + repo := config.GetByName(name) + if repo != nil && !repo.Public && !IsLoggedIn(r.Context()) { + NotFound(w, r) + return + } + } + + next(w, r) + } + } +} + +func Authenticate(auth authService) func(next HandlerFunc) HandlerFunc { + return func(next HandlerFunc) HandlerFunc { + return func(w http.ResponseWriter, r *Request) { + cookie, err := r.Cookie("auth") + if err != nil { + if !errors.Is(err, http.ErrNoCookie) { + slog.Error("Error loading cookie", "error", err) + } + + next(w, r) + return + } + + value, err := base64.StdEncoding.DecodeString(cookie.Value) + if err != nil { + slog.Error("Error decoding", "error", err) + next(w, r) + return + } + + valid, err := auth.ValidateToken(value) + if err != nil { + slog.Error("Error validating token", "error", err, "cookie", cookie.Value) + next(w, r) + return + } + + ctx := r.Context() + ctx = context.WithValue(ctx, "logged", valid) + r.Request = r.WithContext(ctx) + + slog.Info("Validated token", "valid?", valid) + next(w, r) + } + } +} + +func IsLoggedIn(ctx context.Context) bool { + t, ok := ctx.Value("logged").(bool) + return ok && t +} diff --git a/pkg/ext/compression.go b/pkg/ext/compression.go index 6c7a219..d3a3df1 100644 --- a/pkg/ext/compression.go +++ b/pkg/ext/compression.go @@ -15,18 +15,37 @@ import ( "github.com/klauspost/compress/zstd" ) -var ( - errInvalidParam = errors.New("Invalid weighted param") -) +var errInvalidParam = errors.New("Invalid weighted param") type CompressionResponseWriter struct { innerWriter http.ResponseWriter compressWriter io.Writer } -func Compress(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func Compress(next HandlerFunc) HandlerFunc { + return func(w http.ResponseWriter, r *Request) { + // TODO: hand this better + if strings.HasSuffix(r.URL.Path, ".tar.gz") { + next(w, r) + return + } + + if accept, ok := r.Header["Accept-Encoding"]; ok { + if compress, algo := GetCompressionWriter(u.FirstOrZero(accept), w); algo != "" { + defer compress.Close() + w.Header().Add("Content-Encoding", algo) + w = &CompressionResponseWriter{ + innerWriter: w, + compressWriter: compress, + } + } + } + next(w, r) + } +} +func Decompress(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { // TODO: hand this better if strings.HasSuffix(r.URL.Path, ".tar.gz") { next(w, r) @@ -61,12 +80,12 @@ func GetCompressionWriter(header string, inner io.Writer) (io.WriteCloser, strin default: return nil, "" } - } func (c *CompressionResponseWriter) Header() http.Header { return c.innerWriter.Header() } + func (c *CompressionResponseWriter) Write(b []byte) (int, error) { return c.compressWriter.Write(b) } diff --git a/pkg/ext/log.go b/pkg/ext/log.go index 8e68134..e0ad89f 100644 --- a/pkg/ext/log.go +++ b/pkg/ext/log.go @@ -39,8 +39,8 @@ func wrap(w http.ResponseWriter) *statusWraper { } } -func Log(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func Log(next HandlerFunc) HandlerFunc { + return func(w http.ResponseWriter, r *Request) { t := time.Now() s := wrap(w) next(s, r) diff --git a/pkg/ext/request.go b/pkg/ext/request.go new file mode 100644 index 0000000..d1593b2 --- /dev/null +++ b/pkg/ext/request.go @@ -0,0 +1,14 @@ +package ext + +import ( + "io" + "net/http" +) + +type Request struct { + *http.Request +} + +func (r *Request) ReadBody() io.ReadCloser { + return r.Body +} diff --git a/pkg/ext/router.go b/pkg/ext/router.go index 96da1c9..bbbffa1 100644 --- a/pkg/ext/router.go +++ b/pkg/ext/router.go @@ -3,10 +3,12 @@ package ext import ( "errors" "fmt" + "log/slog" "net/http" "git.gabrielgio.me/cerrado/pkg/service" "git.gabrielgio.me/cerrado/templates" + "github.com/go-git/go-git/v5/plumbing" ) type ( @@ -14,8 +16,9 @@ type ( middlewares []Middleware router *http.ServeMux } - Middleware func(next http.HandlerFunc) http.HandlerFunc - ErrorRequestHandler func(w http.ResponseWriter, r *http.Request) error + HandlerFunc func(http.ResponseWriter, *Request) + Middleware func(next HandlerFunc) HandlerFunc + ErrorRequestHandler func(w http.ResponseWriter, r *Request) error ) func NewRouter() *Router { @@ -23,6 +26,7 @@ func NewRouter() *Router { router: http.NewServeMux(), } } + func (r *Router) Handler() http.Handler { return r.router } @@ -31,13 +35,15 @@ func (r *Router) AddMiddleware(middleware Middleware) { r.middlewares = append(r.middlewares, middleware) } -func wrapError(next ErrorRequestHandler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func wrapError(next ErrorRequestHandler) HandlerFunc { + return func(w http.ResponseWriter, r *Request) { if err := next(w, r); err != nil { - if errors.Is(err, service.ErrRepositoryNotFound) { - NotFound(w) + if errors.Is(err, service.ErrRepositoryNotFound) || + errors.Is(err, plumbing.ErrReferenceNotFound) { + NotFound(w, r) } else { - InternalServerError(w, err) + slog.Error("Internal Server Error", "error", err) + InternalServerError(w, r, err) } } } @@ -49,7 +55,7 @@ func (r *Router) run(next ErrorRequestHandler) http.HandlerFunc { for _, r := range r.middlewares { req = r(req) } - req(w, re) + req(w, &Request{Request: re}) } } @@ -57,16 +63,28 @@ func (r *Router) HandleFunc(path string, handler ErrorRequestHandler) { r.router.HandleFunc(path, r.run(handler)) } -func NotFound(w http.ResponseWriter) { +func NotFound(w http.ResponseWriter, r *Request) { w.WriteHeader(http.StatusNotFound) templates.WritePageTemplate(w, &templates.ErrorPage{ Message: "Not Found", - }) + }, r.Context()) +} + +func BadRequest(w http.ResponseWriter, r *Request, msg string) { + w.WriteHeader(http.StatusBadRequest) + templates.WritePageTemplate(w, &templates.ErrorPage{ + Message: msg, + }, r.Context()) +} + +func Redirect(w http.ResponseWriter, location string) { + w.Header().Add("location", location) + w.WriteHeader(http.StatusTemporaryRedirect) } -func InternalServerError(w http.ResponseWriter, err error) { +func InternalServerError(w http.ResponseWriter, r *Request, err error) { w.WriteHeader(http.StatusInternalServerError) templates.WritePageTemplate(w, &templates.ErrorPage{ Message: fmt.Sprintf("Internal Server Error:\n%s", err.Error()), - }) + }, r.Context()) } |