From ae10e121875982d6956d6bff453544cc59a75616 Mon Sep 17 00:00:00 2001 From: Gabriel Arakaki Giovanini Date: Tue, 12 Sep 2023 18:37:30 +0200 Subject: feat: Add admin control Now only admins can access settings. --- pkg/database/sql/user.go | 18 +++++++----------- pkg/ext/middleware.go | 47 ++++++++++++++++++++++++++++++++++------------- pkg/ext/responses.go | 4 ++-- pkg/service/auth.go | 5 +---- pkg/view/album.go | 10 +++++----- pkg/view/auth.go | 13 ++++++++----- pkg/view/filesystem.go | 10 +++++----- pkg/view/media.go | 14 +++++++------- pkg/view/settings.go | 33 +++++++++++++++++++++++---------- pkg/view/view.go | 17 ++++++++++++++++- 10 files changed, 108 insertions(+), 63 deletions(-) (limited to 'pkg') diff --git a/pkg/database/sql/user.go b/pkg/database/sql/user.go index 2ec8622..0c503c2 100644 --- a/pkg/database/sql/user.go +++ b/pkg/database/sql/user.go @@ -158,20 +158,16 @@ func (self *UserRepository) Create(ctx context.Context, createUser *repository.C } func (self *UserRepository) Update(ctx context.Context, id uint, update *repository.UpdateUser) error { - user := &User{ - Model: gorm.Model{ - ID: id, - }, - Username: update.Username, - Name: update.Name, - IsAdmin: update.IsAdmin, - Path: update.Path, - } - result := self.db. WithContext(ctx). + Model(&User{}). Omit("password"). - Updates(user) + Where("id = ?", id). + Update("username", update.Username). + Update("name", update.Name). + Update("is_admin", update.IsAdmin). + Update("path", update.Path) + if result.Error != nil { return wrapError(result.Error) } diff --git a/pkg/ext/middleware.go b/pkg/ext/middleware.go index 061cf7c..6a94c4f 100644 --- a/pkg/ext/middleware.go +++ b/pkg/ext/middleware.go @@ -20,9 +20,17 @@ func HTML(next http.HandlerFunc) http.HandlerFunc { } } -type LogMiddleware struct { - entry *logrus.Entry -} +type ( + User string + + LogMiddleware struct { + entry *logrus.Entry + } +) + +const ( + UserKey User = "user" +) func NewLogMiddleare(log *logrus.Entry) *LogMiddleware { return &LogMiddleware{ @@ -43,14 +51,20 @@ func (l *LogMiddleware) HTTP(next http.HandlerFunc) http.HandlerFunc { } type AuthMiddleware struct { - key []byte - entry *logrus.Entry + key []byte + entry *logrus.Entry + userRepository repository.UserRepository } -func NewAuthMiddleware(key []byte, log *logrus.Entry) *AuthMiddleware { +func NewAuthMiddleware( + key []byte, + log *logrus.Entry, + userRepository repository.UserRepository, +) *AuthMiddleware { return &AuthMiddleware{ - key: key, - entry: log.WithField("context", "auth"), + key: key, + entry: log.WithField("context", "auth"), + userRepository: userRepository, } } @@ -82,7 +96,14 @@ func (a *AuthMiddleware) LoggedIn(next http.HandlerFunc) http.HandlerFunc { http.Redirect(w, r, redirectLogin, http.StatusTemporaryRedirect) return } - r = r.WithContext(context.WithValue(r.Context(), service.TokenKey, token)) + + user, err := a.userRepository.Get(r.Context(), token.UserID) + if err != nil { + a.entry.Error(err) + return + } + + r = r.WithContext(context.WithValue(r.Context(), UserKey, user)) a.entry. WithField("userID", token.UserID). WithField("username", token.Username). @@ -91,9 +112,9 @@ func (a *AuthMiddleware) LoggedIn(next http.HandlerFunc) http.HandlerFunc { } } -func GetTokenFromCtx(r *http.Request) *service.Token { - tokenValue := r.Context().Value(service.TokenKey) - if token, ok := tokenValue.(*service.Token); ok { +func GetUserFromCtx(r *http.Request) *repository.User { + tokenValue := r.Context().Value(UserKey) + if token, ok := tokenValue.(*repository.User); ok { return token } return nil @@ -113,7 +134,7 @@ func (i *InitialSetupMiddleware) Check(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // if user has been set to context it is logged in already - token := GetTokenFromCtx(r) + token := GetUserFromCtx(r) if token != nil { next(w, r) return diff --git a/pkg/ext/responses.go b/pkg/ext/responses.go index 34e5f27..d8941e8 100644 --- a/pkg/ext/responses.go +++ b/pkg/ext/responses.go @@ -10,12 +10,12 @@ import ( func NotFound(w http.ResponseWriter) { templates.WritePageTemplate(w, &templates.ErrorPage{ Err: "Not Found", - }) + }, false) } func InternalServerError(w http.ResponseWriter, err error) { w.WriteHeader(http.StatusInternalServerError) templates.WritePageTemplate(w, &templates.ErrorPage{ Err: fmt.Sprintf("Internal Server Error:\n%s", err.Error()), - }) + }, false) } diff --git a/pkg/service/auth.go b/pkg/service/auth.go index 2fc06e3..3811965 100644 --- a/pkg/service/auth.go +++ b/pkg/service/auth.go @@ -147,15 +147,12 @@ func (u *AuthController) Upsert( } type ( - AuthKey string - Token struct { + Token struct { UserID uint Username string } ) -const TokenKey AuthKey = "token" - func ReadToken(data []byte, key []byte) (*Token, error) { block, err := aes.NewCipher(key) if err != nil { diff --git a/pkg/view/album.go b/pkg/view/album.go index b19e381..e0ee405 100644 --- a/pkg/view/album.go +++ b/pkg/view/album.go @@ -30,10 +30,10 @@ func NewAlbumView( func (self *AlbumView) Index(w http.ResponseWriter, r *http.Request) error { p := getPagination(r) - token := ext.GetTokenFromCtx(r) + user := ext.GetUserFromCtx(r) // TODO: optmize call, GetPathFromUserID may no be necessary - userPath, err := self.userRepository.GetPathFromUserID(r.Context(), token.UserID) + userPath, err := self.userRepository.GetPathFromUserID(r.Context(), user.ID) if err != nil { return err } @@ -91,12 +91,12 @@ func (self *AlbumView) Index(w http.ResponseWriter, r *http.Request) error { Settings: settings, } - templates.WritePageTemplate(w, page) + templates.WritePageTemplate(w, page, user.IsAdmin) return nil } func (self *AlbumView) SetMyselfIn(r *ext.Router) { - r.GET("/album/", self.Index) - r.POST("/album/", self.Index) + r.GET("/album", self.Index) + r.POST("/album", self.Index) } diff --git a/pkg/view/auth.go b/pkg/view/auth.go index 8d87035..318d0a3 100644 --- a/pkg/view/auth.go +++ b/pkg/view/auth.go @@ -20,8 +20,8 @@ func NewAuthView(userController *service.AuthController) *AuthView { } } -func (v *AuthView) LoginView(w http.ResponseWriter, _ *http.Request) error { - templates.WritePageTemplate(w, &templates.LoginPage{}) +func (v *AuthView) LoginView(w http.ResponseWriter, r *http.Request) error { + templates.WritePageTemplate(w, &templates.LoginPage{}, false) return nil } @@ -46,12 +46,15 @@ func (v *AuthView) Login(w http.ResponseWriter, r *http.Request) error { ) auth, err := v.userController.Login(r.Context(), username, password) + if err != nil { + return err + } if errors.Is(err, service.InvalidLogin) { templates.WritePageTemplate(w, &templates.LoginPage{ Username: r.FormValue("username"), Err: err.Error(), - }) + }, false) return nil } @@ -82,8 +85,8 @@ func Index(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) } -func (v *AuthView) InitialRegisterView(w http.ResponseWriter, _ *http.Request) error { - templates.WritePageTemplate(w, &templates.RegisterPage{}) +func (v *AuthView) InitialRegisterView(w http.ResponseWriter, r *http.Request) error { + templates.WritePageTemplate(w, &templates.RegisterPage{}, false) return nil } diff --git a/pkg/view/filesystem.go b/pkg/view/filesystem.go index 24f0ce6..9071ec0 100644 --- a/pkg/view/filesystem.go +++ b/pkg/view/filesystem.go @@ -34,10 +34,10 @@ func NewFileSystemView( func (self *FileSystemView) Index(w http.ResponseWriter, r *http.Request) error { var ( pathValue = r.FormValue("path") - token = ext.GetTokenFromCtx(r) + user = ext.GetUserFromCtx(r) ) - page, err := self.fsService.GetPage(r.Context(), token.UserID, pathValue) + page, err := self.fsService.GetPage(r.Context(), user.ID, pathValue) if err != nil { return err } @@ -51,7 +51,7 @@ func (self *FileSystemView) Index(w http.ResponseWriter, r *http.Request) error Page: page, ShowMode: settings.ShowMode, ShowOwner: settings.ShowOwner, - }) + }, user.IsAdmin) return nil } @@ -59,6 +59,6 @@ func (self *FileSystemView) Index(w http.ResponseWriter, r *http.Request) error func (self *FileSystemView) SetMyselfIn(r *ext.Router) { r.GET("/", self.Index) r.POST("/", self.Index) - r.GET("/fs/", self.Index) - r.POST("/fs/", self.Index) + r.GET("/fs", self.Index) + r.POST("/fs", self.Index) } diff --git a/pkg/view/media.go b/pkg/view/media.go index 3041998..8a10fe0 100644 --- a/pkg/view/media.go +++ b/pkg/view/media.go @@ -71,9 +71,9 @@ func NewMediaView( func (self *MediaView) Index(w http.ResponseWriter, r *http.Request) error { p := getPagination(r) - token := ext.GetTokenFromCtx(r) + user := ext.GetUserFromCtx(r) - userPath, err := self.userRepository.GetPathFromUserID(r.Context(), token.UserID) + userPath, err := self.userRepository.GetPathFromUserID(r.Context(), user.ID) if err != nil { return err } @@ -98,7 +98,7 @@ func (self *MediaView) Index(w http.ResponseWriter, r *http.Request) error { Settings: settings, } - templates.WritePageTemplate(w, page) + templates.WritePageTemplate(w, page, user.IsAdmin) return nil } @@ -132,9 +132,9 @@ func (self *MediaView) GetThumbnail(w http.ResponseWriter, r *http.Request) erro } func (self *MediaView) SetMyselfIn(r *ext.Router) { - r.GET("/media/", self.Index) - r.POST("/media/", self.Index) + r.GET("/media", self.Index) + r.POST("/media", self.Index) - r.GET("/media/image/", self.GetImage) - r.GET("/media/thumbnail/", self.GetThumbnail) + r.GET("/media/image", self.GetImage) + r.GET("/media/thumbnail", self.GetThumbnail) } diff --git a/pkg/view/settings.go b/pkg/view/settings.go index bf2dca6..cdd7baa 100644 --- a/pkg/view/settings.go +++ b/pkg/view/settings.go @@ -39,23 +39,28 @@ func (self *SettingsView) Index(w http.ResponseWriter, r *http.Request) error { return err } + user := ext.GetUserFromCtx(r) + templates.WritePageTemplate(w, &templates.SettingsPage{ Settings: s, Users: users, - }) + }, user.IsAdmin) return nil } func (self *SettingsView) User(w http.ResponseWriter, r *http.Request) error { - id := r.FormValue("userId") + var ( + id = r.URL.Query().Get("userId") + user = ext.GetUserFromCtx(r) + ) idValue, err := ParseUint(id) if err != nil { return err } if idValue == nil { - templates.WritePageTemplate(w, &templates.UserPage{}) + templates.WritePageTemplate(w, &templates.UserPage{}, user.IsAdmin) } else { user, err := self.userController.Get(r.Context(), *idValue) if err != nil { @@ -67,7 +72,7 @@ func (self *SettingsView) User(w http.ResponseWriter, r *http.Request) error { Username: user.Username, Path: user.Path, IsAdmin: user.IsAdmin, - }) + }, user.IsAdmin) } return nil @@ -87,7 +92,15 @@ func (self *SettingsView) UpsertUser(w http.ResponseWriter, r *http.Request) err return err } - err = self.userController.Upsert(r.Context(), idValue, username, "", password, isAdmin, path) + err = self.userController.Upsert( + r.Context(), + idValue, + username, + "", + password, + isAdmin, + path, + ) if err != nil { return err } @@ -137,12 +150,12 @@ func (self *SettingsView) Save(w http.ResponseWriter, r *http.Request) error { } func (self *SettingsView) SetMyselfIn(r *ext.Router) { - r.GET("/settings/", self.Index) - r.POST("/settings/", self.Save) + r.GET("/settings", Protect(self.Index)) + r.POST("/settings", Protect(self.Save)) - r.GET("/users/", self.User) - r.GET("/users/delete", self.Delete) - r.POST("/users/", self.UpsertUser) + r.GET("/users", Protect(self.User)) + r.GET("/users/delete", Protect(self.Delete)) + r.POST("/users", Protect(self.UpsertUser)) } func ParseUint(id string) (*uint, error) { diff --git a/pkg/view/view.go b/pkg/view/view.go index 663738b..f8dfa16 100644 --- a/pkg/view/view.go +++ b/pkg/view/view.go @@ -1,7 +1,22 @@ package view -import "git.sr.ht/~gabrielgio/img/pkg/ext" +import ( + "net/http" + + "git.sr.ht/~gabrielgio/img/pkg/ext" +) type View interface { SetMyselfIn(r *ext.Router) } + +func Protect(next ext.ErrorRequestHandler) ext.ErrorRequestHandler { + return func(w http.ResponseWriter, r *http.Request) error { + user := ext.GetUserFromCtx(r) + if !user.IsAdmin { + http.NotFound(w, r) + return nil + } + return next(w, r) + } +} -- cgit v1.2.3