diff options
Diffstat (limited to 'pkg/ext')
-rw-r--r-- | pkg/ext/auth.go | 72 | ||||
-rw-r--r-- | pkg/ext/auth_test.go | 40 | ||||
-rw-r--r-- | pkg/ext/gorm_logger.go | 58 | ||||
-rw-r--r-- | pkg/ext/middleware.go | 89 | ||||
-rw-r--r-- | pkg/ext/responses.go | 50 | ||||
-rw-r--r-- | pkg/ext/router.go | 51 |
6 files changed, 360 insertions, 0 deletions
diff --git a/pkg/ext/auth.go b/pkg/ext/auth.go new file mode 100644 index 0000000..d9fbfba --- /dev/null +++ b/pkg/ext/auth.go @@ -0,0 +1,72 @@ +package ext + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/gob" + "fmt" + "io" +) + +type Token struct { + UserID uint + Username string +} + +var nonce []byte + +func init() { + nonce = make([]byte, 12) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + fmt.Println("Erro while generating nonce " + err.Error()) + panic(1) + } +} + +func ReadToken(data []byte, key []byte) (*Token, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + panic(err.Error()) + } + + plaintext, err := aesgcm.Open(nil, nonce, data, nil) + if err != nil { + return nil, err + } + + r := bytes.NewReader(plaintext) + var token Token + dec := gob.NewDecoder(r) + if err = dec.Decode(&token); err != nil { + return nil, err + } + return &token, nil +} + +func WriteToken(token *Token, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + var buffer bytes.Buffer + enc := gob.NewEncoder(&buffer) + if err := enc.Encode(token); err != nil { + return nil, err + } + + ciphertext := aesgcm.Seal(nil, nonce, buffer.Bytes(), nil) + return ciphertext, nil +} diff --git a/pkg/ext/auth_test.go b/pkg/ext/auth_test.go new file mode 100644 index 0000000..dc72a0c --- /dev/null +++ b/pkg/ext/auth_test.go @@ -0,0 +1,40 @@ +//go:build unit + +package ext + +import ( + "testing" + + "git.sr.ht/~gabrielgio/img/pkg/testkit" +) + +func TestReadWriteToken(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + key []byte + token *Token + }{ + { + name: "Normal write", + key: []byte("AES256Key-32Characters1234567890"), + token: &Token{ + UserID: 3, + Username: "username", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data, err := WriteToken(tc.token, tc.key) + testkit.TestFatalError(t, "WriteToken", err) + + token, err := ReadToken(data, tc.key) + testkit.TestFatalError(t, "ReadToken", err) + + testkit.TestValue(t, "ReadWriteToken", token, tc.token) + }) + } +} diff --git a/pkg/ext/gorm_logger.go b/pkg/ext/gorm_logger.go new file mode 100644 index 0000000..bfb26d2 --- /dev/null +++ b/pkg/ext/gorm_logger.go @@ -0,0 +1,58 @@ +package ext + +import ( + "context" + "fmt" + "time" + + "github.com/sirupsen/logrus" + "gorm.io/gorm/logger" + "gorm.io/gorm/utils" +) + +type Log struct { + logrus *logrus.Entry +} + +func getFullMsg(msg string, data ...interface{}) string { + return fmt.Sprintf(msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) +} + +func (self *Log) LogMode(log logger.LogLevel) logger.Interface { + return self +} + +func (self *Log) Info(ctx context.Context, msg string, data ...interface{}) { + fullMsg := getFullMsg(msg, data) + self.logrus. + WithContext(ctx). + Info(fullMsg) +} + +func (self *Log) Warn(ctx context.Context, msg string, data ...interface{}) { + fullMsg := getFullMsg(msg, data) + self.logrus. + WithContext(ctx). + Warn(fullMsg) +} +func (self *Log) Error(ctx context.Context, msg string, data ...interface{}) { + fullMsg := getFullMsg(msg, data) + self.logrus. + WithContext(ctx). + Error(fullMsg) +} + +func (self *Log) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + elapsed := time.Since(begin) + sql, _ := fc() + self.logrus. + WithContext(ctx). + WithField("time", elapsed). + Printf(sql) +} + +func Wraplog(log *logrus.Entry) *Log { + return &Log{ + logrus: log, + } +} diff --git a/pkg/ext/middleware.go b/pkg/ext/middleware.go new file mode 100644 index 0000000..771c0ac --- /dev/null +++ b/pkg/ext/middleware.go @@ -0,0 +1,89 @@ +package ext + +import ( + "encoding/base64" + "time" + + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" +) + +func HTML(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + ctx.Response.Header.SetContentType("text/html") + next(ctx) + } +} + +type LogMiddleware struct { + entry *logrus.Entry +} + +func NewLogMiddleare(log *logrus.Entry) *LogMiddleware { + return &LogMiddleware{ + entry: log, + } +} + +func (l *LogMiddleware) HTTP(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + start := time.Now() + next(ctx) + elapsed := time.Since(start) + l.entry. + WithField("time", elapsed). + WithField("code", ctx.Response.StatusCode()). + WithField("path", string(ctx.Path())). + WithField("bytes", len(ctx.Response.Body())). + Info(string(ctx.Request.Header.Method())) + } +} + +type AuthMiddleware struct { + key []byte + entry *logrus.Entry +} + +func NewAuthMiddleware(key []byte, log *logrus.Entry) *AuthMiddleware { + return &AuthMiddleware{ + key: key, + entry: log.WithField("context", "auth"), + } +} + +func (a *AuthMiddleware) LoggedIn(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + if path == "/login" { + next(ctx) + return + } + + redirectLogin := "/login?redirect=" + path + authBase64 := ctx.Request.Header.Cookie("auth") + if authBase64 == nil { + a.entry.Info("No auth provided") + ctx.Redirect(redirectLogin, 307) + return + } + + auth, err := base64.StdEncoding.DecodeString(string(authBase64)) + if err != nil { + a.entry.Error(err) + return + } + + token, err := ReadToken(auth, a.key) + if err != nil { + a.entry.Error(err) + ctx.Redirect(redirectLogin, 307) + return + } + ctx.SetUserValue("token", token) + a.entry. + WithField("userID", token.UserID). + WithField("username", token.Username). + Info("user recognized") + next(ctx) + } +} diff --git a/pkg/ext/responses.go b/pkg/ext/responses.go new file mode 100644 index 0000000..7354395 --- /dev/null +++ b/pkg/ext/responses.go @@ -0,0 +1,50 @@ +package ext + +import ( + "bytes" + "fmt" + + "github.com/valyala/fasthttp" + + "git.sr.ht/~gabrielgio/img" +) + +var ( + ContentTypeJSON = []byte("application/json") + ContentTypeHTML = []byte("text/html") + ContentTypeMARKDOWN = []byte("text/markdown") + ContentTypeJPEG = []byte("image/jpeg") +) + +func NotFoundHTML(ctx *fasthttp.RequestCtx) { + ctx.Response.Header.SetContentType("text/html") + //nolint:errcheck + img.Render(ctx, "error.html", &img.HTMLView[string]{ + Data: "NotFound", + }) +} + +func NotFound(ctx *fasthttp.RequestCtx) { + ctx.Response.SetStatusCode(404) + ct := ctx.Response.Header.ContentType() + if bytes.Equal(ct, ContentTypeHTML) { + NotFoundHTML(ctx) + } +} + +func InternalServerError(ctx *fasthttp.RequestCtx, err error) { + ctx.Response.Header.SetContentType("text/html") + message := fmt.Sprintf("Internal Server Error:\n%+v", err) + //nolint:errcheck + respErr := img.Render(ctx, "error.html", &img.HTMLView[string]{ + Data: message, + }) + + if respErr != nil { + fmt.Println(respErr.Error()) + } +} + +func NoContent(ctx *fasthttp.RequestCtx) { + ctx.Response.SetStatusCode(204) +} diff --git a/pkg/ext/router.go b/pkg/ext/router.go new file mode 100644 index 0000000..74f0a95 --- /dev/null +++ b/pkg/ext/router.go @@ -0,0 +1,51 @@ +package ext + +import ( + "github.com/fasthttp/router" + "github.com/valyala/fasthttp" +) + +type ( + Router struct { + middlewares []Middleware + fastRouter *router.Router + } + Middleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler + ErrorRequestHandler func(ctx *fasthttp.RequestCtx) error +) + +func NewRouter(nestedRouter *router.Router) *Router { + return &Router{ + fastRouter: nestedRouter, + } +} + +func (self *Router) AddMiddleware(middleware Middleware) { + self.middlewares = append(self.middlewares, middleware) +} + +func wrapError(next ErrorRequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + if err := next(ctx); err != nil { + ctx.Response.SetStatusCode(500) + InternalServerError(ctx, err) + } + } +} + +func (self *Router) run(next ErrorRequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + req := wrapError(next) + for _, r := range self.middlewares { + req = r(req) + } + req(ctx) + } +} + +func (self *Router) GET(path string, handler ErrorRequestHandler) { + self.fastRouter.GET(path, self.run(handler)) +} +func (self *Router) POST(path string, handler ErrorRequestHandler) { + self.fastRouter.POST(path, self.run(handler)) +} |