aboutsummaryrefslogtreecommitdiff
path: root/pkg/ext
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/ext')
-rw-r--r--pkg/ext/auth.go72
-rw-r--r--pkg/ext/auth_test.go40
-rw-r--r--pkg/ext/gorm_logger.go58
-rw-r--r--pkg/ext/middleware.go89
-rw-r--r--pkg/ext/responses.go50
-rw-r--r--pkg/ext/router.go51
6 files changed, 360 insertions, 0 deletions
diff --git a/pkg/ext/auth.go b/pkg/ext/auth.go
new file mode 100644
index 0000000..d9fbfba
--- /dev/null
+++ b/pkg/ext/auth.go
@@ -0,0 +1,72 @@
+package ext
+
+import (
+ "bytes"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "encoding/gob"
+ "fmt"
+ "io"
+)
+
+type Token struct {
+ UserID uint
+ Username string
+}
+
+var nonce []byte
+
+func init() {
+ nonce = make([]byte, 12)
+ if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+ fmt.Println("Erro while generating nonce " + err.Error())
+ panic(1)
+ }
+}
+
+func ReadToken(data []byte, key []byte) (*Token, error) {
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, err
+ }
+
+ aesgcm, err := cipher.NewGCM(block)
+ if err != nil {
+ panic(err.Error())
+ }
+
+ plaintext, err := aesgcm.Open(nil, nonce, data, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ r := bytes.NewReader(plaintext)
+ var token Token
+ dec := gob.NewDecoder(r)
+ if err = dec.Decode(&token); err != nil {
+ return nil, err
+ }
+ return &token, nil
+}
+
+func WriteToken(token *Token, key []byte) ([]byte, error) {
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, err
+ }
+
+ aesgcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return nil, err
+ }
+
+ var buffer bytes.Buffer
+ enc := gob.NewEncoder(&buffer)
+ if err := enc.Encode(token); err != nil {
+ return nil, err
+ }
+
+ ciphertext := aesgcm.Seal(nil, nonce, buffer.Bytes(), nil)
+ return ciphertext, nil
+}
diff --git a/pkg/ext/auth_test.go b/pkg/ext/auth_test.go
new file mode 100644
index 0000000..dc72a0c
--- /dev/null
+++ b/pkg/ext/auth_test.go
@@ -0,0 +1,40 @@
+//go:build unit
+
+package ext
+
+import (
+ "testing"
+
+ "git.sr.ht/~gabrielgio/img/pkg/testkit"
+)
+
+func TestReadWriteToken(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ key []byte
+ token *Token
+ }{
+ {
+ name: "Normal write",
+ key: []byte("AES256Key-32Characters1234567890"),
+ token: &Token{
+ UserID: 3,
+ Username: "username",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ data, err := WriteToken(tc.token, tc.key)
+ testkit.TestFatalError(t, "WriteToken", err)
+
+ token, err := ReadToken(data, tc.key)
+ testkit.TestFatalError(t, "ReadToken", err)
+
+ testkit.TestValue(t, "ReadWriteToken", token, tc.token)
+ })
+ }
+}
diff --git a/pkg/ext/gorm_logger.go b/pkg/ext/gorm_logger.go
new file mode 100644
index 0000000..bfb26d2
--- /dev/null
+++ b/pkg/ext/gorm_logger.go
@@ -0,0 +1,58 @@
+package ext
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "gorm.io/gorm/logger"
+ "gorm.io/gorm/utils"
+)
+
+type Log struct {
+ logrus *logrus.Entry
+}
+
+func getFullMsg(msg string, data ...interface{}) string {
+ return fmt.Sprintf(msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
+}
+
+func (self *Log) LogMode(log logger.LogLevel) logger.Interface {
+ return self
+}
+
+func (self *Log) Info(ctx context.Context, msg string, data ...interface{}) {
+ fullMsg := getFullMsg(msg, data)
+ self.logrus.
+ WithContext(ctx).
+ Info(fullMsg)
+}
+
+func (self *Log) Warn(ctx context.Context, msg string, data ...interface{}) {
+ fullMsg := getFullMsg(msg, data)
+ self.logrus.
+ WithContext(ctx).
+ Warn(fullMsg)
+}
+func (self *Log) Error(ctx context.Context, msg string, data ...interface{}) {
+ fullMsg := getFullMsg(msg, data)
+ self.logrus.
+ WithContext(ctx).
+ Error(fullMsg)
+}
+
+func (self *Log) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
+ elapsed := time.Since(begin)
+ sql, _ := fc()
+ self.logrus.
+ WithContext(ctx).
+ WithField("time", elapsed).
+ Printf(sql)
+}
+
+func Wraplog(log *logrus.Entry) *Log {
+ return &Log{
+ logrus: log,
+ }
+}
diff --git a/pkg/ext/middleware.go b/pkg/ext/middleware.go
new file mode 100644
index 0000000..771c0ac
--- /dev/null
+++ b/pkg/ext/middleware.go
@@ -0,0 +1,89 @@
+package ext
+
+import (
+ "encoding/base64"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/valyala/fasthttp"
+)
+
+func HTML(next fasthttp.RequestHandler) fasthttp.RequestHandler {
+ return func(ctx *fasthttp.RequestCtx) {
+ ctx.Response.Header.SetContentType("text/html")
+ next(ctx)
+ }
+}
+
+type LogMiddleware struct {
+ entry *logrus.Entry
+}
+
+func NewLogMiddleare(log *logrus.Entry) *LogMiddleware {
+ return &LogMiddleware{
+ entry: log,
+ }
+}
+
+func (l *LogMiddleware) HTTP(next fasthttp.RequestHandler) fasthttp.RequestHandler {
+ return func(ctx *fasthttp.RequestCtx) {
+ start := time.Now()
+ next(ctx)
+ elapsed := time.Since(start)
+ l.entry.
+ WithField("time", elapsed).
+ WithField("code", ctx.Response.StatusCode()).
+ WithField("path", string(ctx.Path())).
+ WithField("bytes", len(ctx.Response.Body())).
+ Info(string(ctx.Request.Header.Method()))
+ }
+}
+
+type AuthMiddleware struct {
+ key []byte
+ entry *logrus.Entry
+}
+
+func NewAuthMiddleware(key []byte, log *logrus.Entry) *AuthMiddleware {
+ return &AuthMiddleware{
+ key: key,
+ entry: log.WithField("context", "auth"),
+ }
+}
+
+func (a *AuthMiddleware) LoggedIn(next fasthttp.RequestHandler) fasthttp.RequestHandler {
+ return func(ctx *fasthttp.RequestCtx) {
+ path := string(ctx.Path())
+ if path == "/login" {
+ next(ctx)
+ return
+ }
+
+ redirectLogin := "/login?redirect=" + path
+ authBase64 := ctx.Request.Header.Cookie("auth")
+ if authBase64 == nil {
+ a.entry.Info("No auth provided")
+ ctx.Redirect(redirectLogin, 307)
+ return
+ }
+
+ auth, err := base64.StdEncoding.DecodeString(string(authBase64))
+ if err != nil {
+ a.entry.Error(err)
+ return
+ }
+
+ token, err := ReadToken(auth, a.key)
+ if err != nil {
+ a.entry.Error(err)
+ ctx.Redirect(redirectLogin, 307)
+ return
+ }
+ ctx.SetUserValue("token", token)
+ a.entry.
+ WithField("userID", token.UserID).
+ WithField("username", token.Username).
+ Info("user recognized")
+ next(ctx)
+ }
+}
diff --git a/pkg/ext/responses.go b/pkg/ext/responses.go
new file mode 100644
index 0000000..7354395
--- /dev/null
+++ b/pkg/ext/responses.go
@@ -0,0 +1,50 @@
+package ext
+
+import (
+ "bytes"
+ "fmt"
+
+ "github.com/valyala/fasthttp"
+
+ "git.sr.ht/~gabrielgio/img"
+)
+
+var (
+ ContentTypeJSON = []byte("application/json")
+ ContentTypeHTML = []byte("text/html")
+ ContentTypeMARKDOWN = []byte("text/markdown")
+ ContentTypeJPEG = []byte("image/jpeg")
+)
+
+func NotFoundHTML(ctx *fasthttp.RequestCtx) {
+ ctx.Response.Header.SetContentType("text/html")
+ //nolint:errcheck
+ img.Render(ctx, "error.html", &img.HTMLView[string]{
+ Data: "NotFound",
+ })
+}
+
+func NotFound(ctx *fasthttp.RequestCtx) {
+ ctx.Response.SetStatusCode(404)
+ ct := ctx.Response.Header.ContentType()
+ if bytes.Equal(ct, ContentTypeHTML) {
+ NotFoundHTML(ctx)
+ }
+}
+
+func InternalServerError(ctx *fasthttp.RequestCtx, err error) {
+ ctx.Response.Header.SetContentType("text/html")
+ message := fmt.Sprintf("Internal Server Error:\n%+v", err)
+ //nolint:errcheck
+ respErr := img.Render(ctx, "error.html", &img.HTMLView[string]{
+ Data: message,
+ })
+
+ if respErr != nil {
+ fmt.Println(respErr.Error())
+ }
+}
+
+func NoContent(ctx *fasthttp.RequestCtx) {
+ ctx.Response.SetStatusCode(204)
+}
diff --git a/pkg/ext/router.go b/pkg/ext/router.go
new file mode 100644
index 0000000..74f0a95
--- /dev/null
+++ b/pkg/ext/router.go
@@ -0,0 +1,51 @@
+package ext
+
+import (
+ "github.com/fasthttp/router"
+ "github.com/valyala/fasthttp"
+)
+
+type (
+ Router struct {
+ middlewares []Middleware
+ fastRouter *router.Router
+ }
+ Middleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler
+ ErrorRequestHandler func(ctx *fasthttp.RequestCtx) error
+)
+
+func NewRouter(nestedRouter *router.Router) *Router {
+ return &Router{
+ fastRouter: nestedRouter,
+ }
+}
+
+func (self *Router) AddMiddleware(middleware Middleware) {
+ self.middlewares = append(self.middlewares, middleware)
+}
+
+func wrapError(next ErrorRequestHandler) fasthttp.RequestHandler {
+ return func(ctx *fasthttp.RequestCtx) {
+ if err := next(ctx); err != nil {
+ ctx.Response.SetStatusCode(500)
+ InternalServerError(ctx, err)
+ }
+ }
+}
+
+func (self *Router) run(next ErrorRequestHandler) fasthttp.RequestHandler {
+ return func(ctx *fasthttp.RequestCtx) {
+ req := wrapError(next)
+ for _, r := range self.middlewares {
+ req = r(req)
+ }
+ req(ctx)
+ }
+}
+
+func (self *Router) GET(path string, handler ErrorRequestHandler) {
+ self.fastRouter.GET(path, self.run(handler))
+}
+func (self *Router) POST(path string, handler ErrorRequestHandler) {
+ self.fastRouter.POST(path, self.run(handler))
+}