aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd/server/main.go4
-rw-r--r--pkg/database/sql/user.go18
-rw-r--r--pkg/ext/middleware.go47
-rw-r--r--pkg/ext/responses.go4
-rw-r--r--pkg/service/auth.go5
-rw-r--r--pkg/view/album.go10
-rw-r--r--pkg/view/auth.go13
-rw-r--r--pkg/view/filesystem.go10
-rw-r--r--pkg/view/media.go14
-rw-r--r--pkg/view/settings.go33
-rw-r--r--pkg/view/view.go17
-rw-r--r--templates/album.qtpl4
-rw-r--r--templates/base.qtpl12
-rw-r--r--templates/media.qtpl2
-rw-r--r--templates/mosaic.qtpl6
-rw-r--r--templates/settings.qtpl2
-rw-r--r--templates/user.qtpl4
17 files changed, 126 insertions, 79 deletions
diff --git a/cmd/server/main.go b/cmd/server/main.go
index 58256fa..41b2b4a 100644
--- a/cmd/server/main.go
+++ b/cmd/server/main.go
@@ -72,7 +72,7 @@ func main() {
panic("failed to decode key database: " + err.Error())
}
- r := mux.NewRouter()
+ r := mux.NewRouter().StrictSlash(false)
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.FS(static.Static))))
// repository
@@ -85,7 +85,7 @@ func main() {
// middleware
var (
- authMiddleware = ext.NewAuthMiddleware(baseKey, logger.WithField("context", "auth"))
+ authMiddleware = ext.NewAuthMiddleware(baseKey, logger.WithField("context", "auth"), userRepository)
logMiddleware = ext.NewLogMiddleare(logger.WithField("context", "http"))
initialMiddleware = ext.NewInitialSetupMiddleware(userRepository)
)
diff --git a/pkg/database/sql/user.go b/pkg/database/sql/user.go
index 2ec8622..0c503c2 100644
--- a/pkg/database/sql/user.go
+++ b/pkg/database/sql/user.go
@@ -158,20 +158,16 @@ func (self *UserRepository) Create(ctx context.Context, createUser *repository.C
}
func (self *UserRepository) Update(ctx context.Context, id uint, update *repository.UpdateUser) error {
- user := &User{
- Model: gorm.Model{
- ID: id,
- },
- Username: update.Username,
- Name: update.Name,
- IsAdmin: update.IsAdmin,
- Path: update.Path,
- }
-
result := self.db.
WithContext(ctx).
+ Model(&User{}).
Omit("password").
- Updates(user)
+ Where("id = ?", id).
+ Update("username", update.Username).
+ Update("name", update.Name).
+ Update("is_admin", update.IsAdmin).
+ Update("path", update.Path)
+
if result.Error != nil {
return wrapError(result.Error)
}
diff --git a/pkg/ext/middleware.go b/pkg/ext/middleware.go
index 061cf7c..6a94c4f 100644
--- a/pkg/ext/middleware.go
+++ b/pkg/ext/middleware.go
@@ -20,9 +20,17 @@ func HTML(next http.HandlerFunc) http.HandlerFunc {
}
}
-type LogMiddleware struct {
- entry *logrus.Entry
-}
+type (
+ User string
+
+ LogMiddleware struct {
+ entry *logrus.Entry
+ }
+)
+
+const (
+ UserKey User = "user"
+)
func NewLogMiddleare(log *logrus.Entry) *LogMiddleware {
return &LogMiddleware{
@@ -43,14 +51,20 @@ func (l *LogMiddleware) HTTP(next http.HandlerFunc) http.HandlerFunc {
}
type AuthMiddleware struct {
- key []byte
- entry *logrus.Entry
+ key []byte
+ entry *logrus.Entry
+ userRepository repository.UserRepository
}
-func NewAuthMiddleware(key []byte, log *logrus.Entry) *AuthMiddleware {
+func NewAuthMiddleware(
+ key []byte,
+ log *logrus.Entry,
+ userRepository repository.UserRepository,
+) *AuthMiddleware {
return &AuthMiddleware{
- key: key,
- entry: log.WithField("context", "auth"),
+ key: key,
+ entry: log.WithField("context", "auth"),
+ userRepository: userRepository,
}
}
@@ -82,7 +96,14 @@ func (a *AuthMiddleware) LoggedIn(next http.HandlerFunc) http.HandlerFunc {
http.Redirect(w, r, redirectLogin, http.StatusTemporaryRedirect)
return
}
- r = r.WithContext(context.WithValue(r.Context(), service.TokenKey, token))
+
+ user, err := a.userRepository.Get(r.Context(), token.UserID)
+ if err != nil {
+ a.entry.Error(err)
+ return
+ }
+
+ r = r.WithContext(context.WithValue(r.Context(), UserKey, user))
a.entry.
WithField("userID", token.UserID).
WithField("username", token.Username).
@@ -91,9 +112,9 @@ func (a *AuthMiddleware) LoggedIn(next http.HandlerFunc) http.HandlerFunc {
}
}
-func GetTokenFromCtx(r *http.Request) *service.Token {
- tokenValue := r.Context().Value(service.TokenKey)
- if token, ok := tokenValue.(*service.Token); ok {
+func GetUserFromCtx(r *http.Request) *repository.User {
+ tokenValue := r.Context().Value(UserKey)
+ if token, ok := tokenValue.(*repository.User); ok {
return token
}
return nil
@@ -113,7 +134,7 @@ func (i *InitialSetupMiddleware) Check(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// if user has been set to context it is logged in already
- token := GetTokenFromCtx(r)
+ token := GetUserFromCtx(r)
if token != nil {
next(w, r)
return
diff --git a/pkg/ext/responses.go b/pkg/ext/responses.go
index 34e5f27..d8941e8 100644
--- a/pkg/ext/responses.go
+++ b/pkg/ext/responses.go
@@ -10,12 +10,12 @@ import (
func NotFound(w http.ResponseWriter) {
templates.WritePageTemplate(w, &templates.ErrorPage{
Err: "Not Found",
- })
+ }, false)
}
func InternalServerError(w http.ResponseWriter, err error) {
w.WriteHeader(http.StatusInternalServerError)
templates.WritePageTemplate(w, &templates.ErrorPage{
Err: fmt.Sprintf("Internal Server Error:\n%s", err.Error()),
- })
+ }, false)
}
diff --git a/pkg/service/auth.go b/pkg/service/auth.go
index 2fc06e3..3811965 100644
--- a/pkg/service/auth.go
+++ b/pkg/service/auth.go
@@ -147,15 +147,12 @@ func (u *AuthController) Upsert(
}
type (
- AuthKey string
- Token struct {
+ Token struct {
UserID uint
Username string
}
)
-const TokenKey AuthKey = "token"
-
func ReadToken(data []byte, key []byte) (*Token, error) {
block, err := aes.NewCipher(key)
if err != nil {
diff --git a/pkg/view/album.go b/pkg/view/album.go
index b19e381..e0ee405 100644
--- a/pkg/view/album.go
+++ b/pkg/view/album.go
@@ -30,10 +30,10 @@ func NewAlbumView(
func (self *AlbumView) Index(w http.ResponseWriter, r *http.Request) error {
p := getPagination(r)
- token := ext.GetTokenFromCtx(r)
+ user := ext.GetUserFromCtx(r)
// TODO: optmize call, GetPathFromUserID may no be necessary
- userPath, err := self.userRepository.GetPathFromUserID(r.Context(), token.UserID)
+ userPath, err := self.userRepository.GetPathFromUserID(r.Context(), user.ID)
if err != nil {
return err
}
@@ -91,12 +91,12 @@ func (self *AlbumView) Index(w http.ResponseWriter, r *http.Request) error {
Settings: settings,
}
- templates.WritePageTemplate(w, page)
+ templates.WritePageTemplate(w, page, user.IsAdmin)
return nil
}
func (self *AlbumView) SetMyselfIn(r *ext.Router) {
- r.GET("/album/", self.Index)
- r.POST("/album/", self.Index)
+ r.GET("/album", self.Index)
+ r.POST("/album", self.Index)
}
diff --git a/pkg/view/auth.go b/pkg/view/auth.go
index 8d87035..318d0a3 100644
--- a/pkg/view/auth.go
+++ b/pkg/view/auth.go
@@ -20,8 +20,8 @@ func NewAuthView(userController *service.AuthController) *AuthView {
}
}
-func (v *AuthView) LoginView(w http.ResponseWriter, _ *http.Request) error {
- templates.WritePageTemplate(w, &templates.LoginPage{})
+func (v *AuthView) LoginView(w http.ResponseWriter, r *http.Request) error {
+ templates.WritePageTemplate(w, &templates.LoginPage{}, false)
return nil
}
@@ -46,12 +46,15 @@ func (v *AuthView) Login(w http.ResponseWriter, r *http.Request) error {
)
auth, err := v.userController.Login(r.Context(), username, password)
+ if err != nil {
+ return err
+ }
if errors.Is(err, service.InvalidLogin) {
templates.WritePageTemplate(w, &templates.LoginPage{
Username: r.FormValue("username"),
Err: err.Error(),
- })
+ }, false)
return nil
}
@@ -82,8 +85,8 @@ func Index(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
}
-func (v *AuthView) InitialRegisterView(w http.ResponseWriter, _ *http.Request) error {
- templates.WritePageTemplate(w, &templates.RegisterPage{})
+func (v *AuthView) InitialRegisterView(w http.ResponseWriter, r *http.Request) error {
+ templates.WritePageTemplate(w, &templates.RegisterPage{}, false)
return nil
}
diff --git a/pkg/view/filesystem.go b/pkg/view/filesystem.go
index 24f0ce6..9071ec0 100644
--- a/pkg/view/filesystem.go
+++ b/pkg/view/filesystem.go
@@ -34,10 +34,10 @@ func NewFileSystemView(
func (self *FileSystemView) Index(w http.ResponseWriter, r *http.Request) error {
var (
pathValue = r.FormValue("path")
- token = ext.GetTokenFromCtx(r)
+ user = ext.GetUserFromCtx(r)
)
- page, err := self.fsService.GetPage(r.Context(), token.UserID, pathValue)
+ page, err := self.fsService.GetPage(r.Context(), user.ID, pathValue)
if err != nil {
return err
}
@@ -51,7 +51,7 @@ func (self *FileSystemView) Index(w http.ResponseWriter, r *http.Request) error
Page: page,
ShowMode: settings.ShowMode,
ShowOwner: settings.ShowOwner,
- })
+ }, user.IsAdmin)
return nil
}
@@ -59,6 +59,6 @@ func (self *FileSystemView) Index(w http.ResponseWriter, r *http.Request) error
func (self *FileSystemView) SetMyselfIn(r *ext.Router) {
r.GET("/", self.Index)
r.POST("/", self.Index)
- r.GET("/fs/", self.Index)
- r.POST("/fs/", self.Index)
+ r.GET("/fs", self.Index)
+ r.POST("/fs", self.Index)
}
diff --git a/pkg/view/media.go b/pkg/view/media.go
index 3041998..8a10fe0 100644
--- a/pkg/view/media.go
+++ b/pkg/view/media.go
@@ -71,9 +71,9 @@ func NewMediaView(
func (self *MediaView) Index(w http.ResponseWriter, r *http.Request) error {
p := getPagination(r)
- token := ext.GetTokenFromCtx(r)
+ user := ext.GetUserFromCtx(r)
- userPath, err := self.userRepository.GetPathFromUserID(r.Context(), token.UserID)
+ userPath, err := self.userRepository.GetPathFromUserID(r.Context(), user.ID)
if err != nil {
return err
}
@@ -98,7 +98,7 @@ func (self *MediaView) Index(w http.ResponseWriter, r *http.Request) error {
Settings: settings,
}
- templates.WritePageTemplate(w, page)
+ templates.WritePageTemplate(w, page, user.IsAdmin)
return nil
}
@@ -132,9 +132,9 @@ func (self *MediaView) GetThumbnail(w http.ResponseWriter, r *http.Request) erro
}
func (self *MediaView) SetMyselfIn(r *ext.Router) {
- r.GET("/media/", self.Index)
- r.POST("/media/", self.Index)
+ r.GET("/media", self.Index)
+ r.POST("/media", self.Index)
- r.GET("/media/image/", self.GetImage)
- r.GET("/media/thumbnail/", self.GetThumbnail)
+ r.GET("/media/image", self.GetImage)
+ r.GET("/media/thumbnail", self.GetThumbnail)
}
diff --git a/pkg/view/settings.go b/pkg/view/settings.go
index bf2dca6..cdd7baa 100644
--- a/pkg/view/settings.go
+++ b/pkg/view/settings.go
@@ -39,23 +39,28 @@ func (self *SettingsView) Index(w http.ResponseWriter, r *http.Request) error {
return err
}
+ user := ext.GetUserFromCtx(r)
+
templates.WritePageTemplate(w, &templates.SettingsPage{
Settings: s,
Users: users,
- })
+ }, user.IsAdmin)
return nil
}
func (self *SettingsView) User(w http.ResponseWriter, r *http.Request) error {
- id := r.FormValue("userId")
+ var (
+ id = r.URL.Query().Get("userId")
+ user = ext.GetUserFromCtx(r)
+ )
idValue, err := ParseUint(id)
if err != nil {
return err
}
if idValue == nil {
- templates.WritePageTemplate(w, &templates.UserPage{})
+ templates.WritePageTemplate(w, &templates.UserPage{}, user.IsAdmin)
} else {
user, err := self.userController.Get(r.Context(), *idValue)
if err != nil {
@@ -67,7 +72,7 @@ func (self *SettingsView) User(w http.ResponseWriter, r *http.Request) error {
Username: user.Username,
Path: user.Path,
IsAdmin: user.IsAdmin,
- })
+ }, user.IsAdmin)
}
return nil
@@ -87,7 +92,15 @@ func (self *SettingsView) UpsertUser(w http.ResponseWriter, r *http.Request) err
return err
}
- err = self.userController.Upsert(r.Context(), idValue, username, "", password, isAdmin, path)
+ err = self.userController.Upsert(
+ r.Context(),
+ idValue,
+ username,
+ "",
+ password,
+ isAdmin,
+ path,
+ )
if err != nil {
return err
}
@@ -137,12 +150,12 @@ func (self *SettingsView) Save(w http.ResponseWriter, r *http.Request) error {
}
func (self *SettingsView) SetMyselfIn(r *ext.Router) {
- r.GET("/settings/", self.Index)
- r.POST("/settings/", self.Save)
+ r.GET("/settings", Protect(self.Index))
+ r.POST("/settings", Protect(self.Save))
- r.GET("/users/", self.User)
- r.GET("/users/delete", self.Delete)
- r.POST("/users/", self.UpsertUser)
+ r.GET("/users", Protect(self.User))
+ r.GET("/users/delete", Protect(self.Delete))
+ r.POST("/users", Protect(self.UpsertUser))
}
func ParseUint(id string) (*uint, error) {
diff --git a/pkg/view/view.go b/pkg/view/view.go
index 663738b..f8dfa16 100644
--- a/pkg/view/view.go
+++ b/pkg/view/view.go
@@ -1,7 +1,22 @@
package view
-import "git.sr.ht/~gabrielgio/img/pkg/ext"
+import (
+ "net/http"
+
+ "git.sr.ht/~gabrielgio/img/pkg/ext"
+)
type View interface {
SetMyselfIn(r *ext.Router)
}
+
+func Protect(next ext.ErrorRequestHandler) ext.ErrorRequestHandler {
+ return func(w http.ResponseWriter, r *http.Request) error {
+ user := ext.GetUserFromCtx(r)
+ if !user.IsAdmin {
+ http.NotFound(w, r)
+ return nil
+ }
+ return next(w, r)
+ }
+}
diff --git a/templates/album.qtpl b/templates/album.qtpl
index 835db57..58fc499 100644
--- a/templates/album.qtpl
+++ b/templates/album.qtpl
@@ -23,14 +23,14 @@ func (m *AlbumPage) PreloadAttr() string {
<h1 class="title text-size-1">{%s p.Name %}</h1>
<div class="tags are-large">
{% for _, a := range p.Albums %}
- <a href="/album/?albumId={%s FromUInttoString(&a.ID) %}" class="tag text-size-2">{%s a.Name %}</a>
+ <a href="/album?albumId={%s FromUInttoString(&a.ID) %}" class="tag text-size-2">{%s a.Name %}</a>
{% endfor %}
</div>
<div class="columns">
{%= Mosaic(p.Medias, p.PreloadAttr()) %}
</div>
<div>
- <a href="/album/?albumId={%s FromUInttoString(p.Next.AlbumID) %}&page={%d p.Next.Page %}" class="button is-pulled-right">next</a>
+ <a href="/album?albumId={%s FromUInttoString(p.Next.AlbumID) %}&page={%d p.Next.Page %}" class="button is-pulled-right">next</a>
</div>
{% endfunc %}
diff --git a/templates/base.qtpl b/templates/base.qtpl
index a80803a..30b084e 100644
--- a/templates/base.qtpl
+++ b/templates/base.qtpl
@@ -21,7 +21,7 @@ Page {
Page prints a page implementing Page interface.
-{% func PageTemplate(p Page) %}
+{% func PageTemplate(p Page, isAdmin bool) %}
<html lang="en">
<head>
<meta charset="utf-8">
@@ -33,18 +33,20 @@ Page prints a page implementing Page interface.
<body>
<nav class="navbar">
<div class="navbar-start">
- <a href="/fs/" class="navbar-item text-size-1">
+ <a href="/fs" class="navbar-item text-size-1">
file
</a>
- <a href="/media/" class="navbar-item text-size-1">
+ <a href="/media" class="navbar-item text-size-1">
media
</a>
- <a href="/album/" class="navbar-item text-size-1">
+ <a href="/album" class="navbar-item text-size-1">
album
</a>
- <a href="/settings/" class="navbar-item text-size-1">
+ {% if isAdmin %}
+ <a href="/settings" class="navbar-item text-size-1">
settings
</a>
+ {% endif %}
</div>
</nav>
<div class="container is-fullhd">
diff --git a/templates/media.qtpl b/templates/media.qtpl
index 4251deb..737d03d 100644
--- a/templates/media.qtpl
+++ b/templates/media.qtpl
@@ -22,7 +22,7 @@ func (m *MediaPage) PreloadAttr() string {
{%= Mosaic(p.Medias, p.PreloadAttr()) %}
</div>
<div>
- <a href="/media/?page={%d p.Next.Page %}" class="button is-pulled-right">next</a>
+ <a href="/media?page={%d p.Next.Page %}" class="button is-pulled-right">next</a>
</div>
{% endfunc %}
diff --git a/templates/mosaic.qtpl b/templates/mosaic.qtpl
index 3e6ccf8..18dbcba 100644
--- a/templates/mosaic.qtpl
+++ b/templates/mosaic.qtpl
@@ -8,12 +8,12 @@
{% for _, media := range c %}
<div class="card-image">
{% if media.IsVideo() %}
- <video class="image is-fit" controls muted="true" poster="/media/thumbnail/?path_hash={%s media.PathHash %}" preload="{%s preloadAttr %}">
- <source src="/media/image/?path_hash={%s media.PathHash %}" type="{%s media.MIMEType %}">
+ <video class="image is-fit" controls muted="true" poster="/media/thumbnail?path_hash={%s media.PathHash %}" preload="{%s preloadAttr %}">
+ <source src="/media/image?path_hash={%s media.PathHash %}" type="{%s media.MIMEType %}">
</video>
{% else %}
<figure class="image is-fit">
- <img src="/media/thumbnail/?path_hash={%s media.PathHash %}">
+ <img src="/media/thumbnail?path_hash={%s media.PathHash %}">
</figure>
{% endif %}
</div>
diff --git a/templates/settings.qtpl b/templates/settings.qtpl
index 4439c77..b720a88 100644
--- a/templates/settings.qtpl
+++ b/templates/settings.qtpl
@@ -58,7 +58,7 @@ type SettingsPage struct {
</div>
{% endfor %}
<div class="field">
- <a href="/users/" class="button">create</a>
+ <a href="/users" class="button">create</a>
</div>
</div>
{% endfunc %}
diff --git a/templates/user.qtpl b/templates/user.qtpl
index 6ec783d..6fc3ce6 100644
--- a/templates/user.qtpl
+++ b/templates/user.qtpl
@@ -13,7 +13,7 @@ type UserPage struct {
{% func (p *UserPage) Content() %}
<h1>Initial Setup</h1>
-<form action="/users/" method="post">
+<form action="/users" method="post">
{% if p.ID != nil %}
<input type="hidden" name="userId" value="{%s FromUInttoString(p.ID) %}" />
{% endif %}
@@ -41,7 +41,7 @@ type UserPage struct {
<div class="field">
<label class="label">Is Admin?</label>
<div class="control">
- <input type="checkbox" name="isAdmin" type="password" {% if p.IsAdmin %}checked{% endif %}>
+ <input type="checkbox" name="isAdmin" {% if p.IsAdmin %}checked{% endif %}>
</div>
</div>
<div class="field">