diff options
| -rw-r--r-- | cmd/server/main.go | 4 | ||||
| -rw-r--r-- | pkg/database/sql/user.go | 18 | ||||
| -rw-r--r-- | pkg/ext/middleware.go | 47 | ||||
| -rw-r--r-- | pkg/ext/responses.go | 4 | ||||
| -rw-r--r-- | pkg/service/auth.go | 5 | ||||
| -rw-r--r-- | pkg/view/album.go | 10 | ||||
| -rw-r--r-- | pkg/view/auth.go | 13 | ||||
| -rw-r--r-- | pkg/view/filesystem.go | 10 | ||||
| -rw-r--r-- | pkg/view/media.go | 14 | ||||
| -rw-r--r-- | pkg/view/settings.go | 33 | ||||
| -rw-r--r-- | pkg/view/view.go | 17 | ||||
| -rw-r--r-- | templates/album.qtpl | 4 | ||||
| -rw-r--r-- | templates/base.qtpl | 12 | ||||
| -rw-r--r-- | templates/media.qtpl | 2 | ||||
| -rw-r--r-- | templates/mosaic.qtpl | 6 | ||||
| -rw-r--r-- | templates/settings.qtpl | 2 | ||||
| -rw-r--r-- | templates/user.qtpl | 4 | 
17 files changed, 126 insertions, 79 deletions
| diff --git a/cmd/server/main.go b/cmd/server/main.go index 58256fa..41b2b4a 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -72,7 +72,7 @@ func main() {  		panic("failed to decode key database: " + err.Error())  	} -	r := mux.NewRouter() +	r := mux.NewRouter().StrictSlash(false)  	r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.FS(static.Static))))  	// repository @@ -85,7 +85,7 @@ func main() {  	// middleware  	var ( -		authMiddleware    = ext.NewAuthMiddleware(baseKey, logger.WithField("context", "auth")) +		authMiddleware    = ext.NewAuthMiddleware(baseKey, logger.WithField("context", "auth"), userRepository)  		logMiddleware     = ext.NewLogMiddleare(logger.WithField("context", "http"))  		initialMiddleware = ext.NewInitialSetupMiddleware(userRepository)  	) diff --git a/pkg/database/sql/user.go b/pkg/database/sql/user.go index 2ec8622..0c503c2 100644 --- a/pkg/database/sql/user.go +++ b/pkg/database/sql/user.go @@ -158,20 +158,16 @@ func (self *UserRepository) Create(ctx context.Context, createUser *repository.C  }  func (self *UserRepository) Update(ctx context.Context, id uint, update *repository.UpdateUser) error { -	user := &User{ -		Model: gorm.Model{ -			ID: id, -		}, -		Username: update.Username, -		Name:     update.Name, -		IsAdmin:  update.IsAdmin, -		Path:     update.Path, -	} -  	result := self.db.  		WithContext(ctx). +		Model(&User{}).  		Omit("password"). -		Updates(user) +		Where("id = ?", id). +		Update("username", update.Username). +		Update("name", update.Name). +		Update("is_admin", update.IsAdmin). +		Update("path", update.Path) +  	if result.Error != nil {  		return wrapError(result.Error)  	} diff --git a/pkg/ext/middleware.go b/pkg/ext/middleware.go index 061cf7c..6a94c4f 100644 --- a/pkg/ext/middleware.go +++ b/pkg/ext/middleware.go @@ -20,9 +20,17 @@ func HTML(next http.HandlerFunc) http.HandlerFunc {  	}  } -type LogMiddleware struct { -	entry *logrus.Entry -} +type ( +	User string + +	LogMiddleware struct { +		entry *logrus.Entry +	} +) + +const ( +	UserKey User = "user" +)  func NewLogMiddleare(log *logrus.Entry) *LogMiddleware {  	return &LogMiddleware{ @@ -43,14 +51,20 @@ func (l *LogMiddleware) HTTP(next http.HandlerFunc) http.HandlerFunc {  }  type AuthMiddleware struct { -	key   []byte -	entry *logrus.Entry +	key            []byte +	entry          *logrus.Entry +	userRepository repository.UserRepository  } -func NewAuthMiddleware(key []byte, log *logrus.Entry) *AuthMiddleware { +func NewAuthMiddleware( +	key []byte, +	log *logrus.Entry, +	userRepository repository.UserRepository, +) *AuthMiddleware {  	return &AuthMiddleware{ -		key:   key, -		entry: log.WithField("context", "auth"), +		key:            key, +		entry:          log.WithField("context", "auth"), +		userRepository: userRepository,  	}  } @@ -82,7 +96,14 @@ func (a *AuthMiddleware) LoggedIn(next http.HandlerFunc) http.HandlerFunc {  			http.Redirect(w, r, redirectLogin, http.StatusTemporaryRedirect)  			return  		} -		r = r.WithContext(context.WithValue(r.Context(), service.TokenKey, token)) + +		user, err := a.userRepository.Get(r.Context(), token.UserID) +		if err != nil { +			a.entry.Error(err) +			return +		} + +		r = r.WithContext(context.WithValue(r.Context(), UserKey, user))  		a.entry.  			WithField("userID", token.UserID).  			WithField("username", token.Username). @@ -91,9 +112,9 @@ func (a *AuthMiddleware) LoggedIn(next http.HandlerFunc) http.HandlerFunc {  	}  } -func GetTokenFromCtx(r *http.Request) *service.Token { -	tokenValue := r.Context().Value(service.TokenKey) -	if token, ok := tokenValue.(*service.Token); ok { +func GetUserFromCtx(r *http.Request) *repository.User { +	tokenValue := r.Context().Value(UserKey) +	if token, ok := tokenValue.(*repository.User); ok {  		return token  	}  	return nil @@ -113,7 +134,7 @@ func (i *InitialSetupMiddleware) Check(next http.HandlerFunc) http.HandlerFunc {  	return func(w http.ResponseWriter, r *http.Request) {  		// if user has been set to context it is logged in already -		token := GetTokenFromCtx(r) +		token := GetUserFromCtx(r)  		if token != nil {  			next(w, r)  			return diff --git a/pkg/ext/responses.go b/pkg/ext/responses.go index 34e5f27..d8941e8 100644 --- a/pkg/ext/responses.go +++ b/pkg/ext/responses.go @@ -10,12 +10,12 @@ import (  func NotFound(w http.ResponseWriter) {  	templates.WritePageTemplate(w, &templates.ErrorPage{  		Err: "Not Found", -	}) +	}, false)  }  func InternalServerError(w http.ResponseWriter, err error) {  	w.WriteHeader(http.StatusInternalServerError)  	templates.WritePageTemplate(w, &templates.ErrorPage{  		Err: fmt.Sprintf("Internal Server Error:\n%s", err.Error()), -	}) +	}, false)  } diff --git a/pkg/service/auth.go b/pkg/service/auth.go index 2fc06e3..3811965 100644 --- a/pkg/service/auth.go +++ b/pkg/service/auth.go @@ -147,15 +147,12 @@ func (u *AuthController) Upsert(  }  type ( -	AuthKey string -	Token   struct { +	Token struct {  		UserID   uint  		Username string  	}  ) -const TokenKey AuthKey = "token" -  func ReadToken(data []byte, key []byte) (*Token, error) {  	block, err := aes.NewCipher(key)  	if err != nil { diff --git a/pkg/view/album.go b/pkg/view/album.go index b19e381..e0ee405 100644 --- a/pkg/view/album.go +++ b/pkg/view/album.go @@ -30,10 +30,10 @@ func NewAlbumView(  func (self *AlbumView) Index(w http.ResponseWriter, r *http.Request) error {  	p := getPagination(r) -	token := ext.GetTokenFromCtx(r) +	user := ext.GetUserFromCtx(r)  	// TODO: optmize call, GetPathFromUserID may no be necessary -	userPath, err := self.userRepository.GetPathFromUserID(r.Context(), token.UserID) +	userPath, err := self.userRepository.GetPathFromUserID(r.Context(), user.ID)  	if err != nil {  		return err  	} @@ -91,12 +91,12 @@ func (self *AlbumView) Index(w http.ResponseWriter, r *http.Request) error {  		Settings: settings,  	} -	templates.WritePageTemplate(w, page) +	templates.WritePageTemplate(w, page, user.IsAdmin)  	return nil  }  func (self *AlbumView) SetMyselfIn(r *ext.Router) { -	r.GET("/album/", self.Index) -	r.POST("/album/", self.Index) +	r.GET("/album", self.Index) +	r.POST("/album", self.Index)  } diff --git a/pkg/view/auth.go b/pkg/view/auth.go index 8d87035..318d0a3 100644 --- a/pkg/view/auth.go +++ b/pkg/view/auth.go @@ -20,8 +20,8 @@ func NewAuthView(userController *service.AuthController) *AuthView {  	}  } -func (v *AuthView) LoginView(w http.ResponseWriter, _ *http.Request) error { -	templates.WritePageTemplate(w, &templates.LoginPage{}) +func (v *AuthView) LoginView(w http.ResponseWriter, r *http.Request) error { +	templates.WritePageTemplate(w, &templates.LoginPage{}, false)  	return nil  } @@ -46,12 +46,15 @@ func (v *AuthView) Login(w http.ResponseWriter, r *http.Request) error {  	)  	auth, err := v.userController.Login(r.Context(), username, password) +	if err != nil { +		return err +	}  	if errors.Is(err, service.InvalidLogin) {  		templates.WritePageTemplate(w, &templates.LoginPage{  			Username: r.FormValue("username"),  			Err:      err.Error(), -		}) +		}, false)  		return nil  	} @@ -82,8 +85,8 @@ func Index(w http.ResponseWriter, r *http.Request) {  	http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)  } -func (v *AuthView) InitialRegisterView(w http.ResponseWriter, _ *http.Request) error { -	templates.WritePageTemplate(w, &templates.RegisterPage{}) +func (v *AuthView) InitialRegisterView(w http.ResponseWriter, r *http.Request) error { +	templates.WritePageTemplate(w, &templates.RegisterPage{}, false)  	return nil  } diff --git a/pkg/view/filesystem.go b/pkg/view/filesystem.go index 24f0ce6..9071ec0 100644 --- a/pkg/view/filesystem.go +++ b/pkg/view/filesystem.go @@ -34,10 +34,10 @@ func NewFileSystemView(  func (self *FileSystemView) Index(w http.ResponseWriter, r *http.Request) error {  	var (  		pathValue = r.FormValue("path") -		token     = ext.GetTokenFromCtx(r) +		user      = ext.GetUserFromCtx(r)  	) -	page, err := self.fsService.GetPage(r.Context(), token.UserID, pathValue) +	page, err := self.fsService.GetPage(r.Context(), user.ID, pathValue)  	if err != nil {  		return err  	} @@ -51,7 +51,7 @@ func (self *FileSystemView) Index(w http.ResponseWriter, r *http.Request) error  		Page:      page,  		ShowMode:  settings.ShowMode,  		ShowOwner: settings.ShowOwner, -	}) +	}, user.IsAdmin)  	return nil  } @@ -59,6 +59,6 @@ func (self *FileSystemView) Index(w http.ResponseWriter, r *http.Request) error  func (self *FileSystemView) SetMyselfIn(r *ext.Router) {  	r.GET("/", self.Index)  	r.POST("/", self.Index) -	r.GET("/fs/", self.Index) -	r.POST("/fs/", self.Index) +	r.GET("/fs", self.Index) +	r.POST("/fs", self.Index)  } diff --git a/pkg/view/media.go b/pkg/view/media.go index 3041998..8a10fe0 100644 --- a/pkg/view/media.go +++ b/pkg/view/media.go @@ -71,9 +71,9 @@ func NewMediaView(  func (self *MediaView) Index(w http.ResponseWriter, r *http.Request) error {  	p := getPagination(r) -	token := ext.GetTokenFromCtx(r) +	user := ext.GetUserFromCtx(r) -	userPath, err := self.userRepository.GetPathFromUserID(r.Context(), token.UserID) +	userPath, err := self.userRepository.GetPathFromUserID(r.Context(), user.ID)  	if err != nil {  		return err  	} @@ -98,7 +98,7 @@ func (self *MediaView) Index(w http.ResponseWriter, r *http.Request) error {  		Settings: settings,  	} -	templates.WritePageTemplate(w, page) +	templates.WritePageTemplate(w, page, user.IsAdmin)  	return nil  } @@ -132,9 +132,9 @@ func (self *MediaView) GetThumbnail(w http.ResponseWriter, r *http.Request) erro  }  func (self *MediaView) SetMyselfIn(r *ext.Router) { -	r.GET("/media/", self.Index) -	r.POST("/media/", self.Index) +	r.GET("/media", self.Index) +	r.POST("/media", self.Index) -	r.GET("/media/image/", self.GetImage) -	r.GET("/media/thumbnail/", self.GetThumbnail) +	r.GET("/media/image", self.GetImage) +	r.GET("/media/thumbnail", self.GetThumbnail)  } diff --git a/pkg/view/settings.go b/pkg/view/settings.go index bf2dca6..cdd7baa 100644 --- a/pkg/view/settings.go +++ b/pkg/view/settings.go @@ -39,23 +39,28 @@ func (self *SettingsView) Index(w http.ResponseWriter, r *http.Request) error {  		return err  	} +	user := ext.GetUserFromCtx(r) +  	templates.WritePageTemplate(w, &templates.SettingsPage{  		Settings: s,  		Users:    users, -	}) +	}, user.IsAdmin)  	return nil  }  func (self *SettingsView) User(w http.ResponseWriter, r *http.Request) error { -	id := r.FormValue("userId") +	var ( +		id   = r.URL.Query().Get("userId") +		user = ext.GetUserFromCtx(r) +	)  	idValue, err := ParseUint(id)  	if err != nil {  		return err  	}  	if idValue == nil { -		templates.WritePageTemplate(w, &templates.UserPage{}) +		templates.WritePageTemplate(w, &templates.UserPage{}, user.IsAdmin)  	} else {  		user, err := self.userController.Get(r.Context(), *idValue)  		if err != nil { @@ -67,7 +72,7 @@ func (self *SettingsView) User(w http.ResponseWriter, r *http.Request) error {  			Username: user.Username,  			Path:     user.Path,  			IsAdmin:  user.IsAdmin, -		}) +		}, user.IsAdmin)  	}  	return nil @@ -87,7 +92,15 @@ func (self *SettingsView) UpsertUser(w http.ResponseWriter, r *http.Request) err  		return err  	} -	err = self.userController.Upsert(r.Context(), idValue, username, "", password, isAdmin, path) +	err = self.userController.Upsert( +		r.Context(), +		idValue, +		username, +		"", +		password, +		isAdmin, +		path, +	)  	if err != nil {  		return err  	} @@ -137,12 +150,12 @@ func (self *SettingsView) Save(w http.ResponseWriter, r *http.Request) error {  }  func (self *SettingsView) SetMyselfIn(r *ext.Router) { -	r.GET("/settings/", self.Index) -	r.POST("/settings/", self.Save) +	r.GET("/settings", Protect(self.Index)) +	r.POST("/settings", Protect(self.Save)) -	r.GET("/users/", self.User) -	r.GET("/users/delete", self.Delete) -	r.POST("/users/", self.UpsertUser) +	r.GET("/users", Protect(self.User)) +	r.GET("/users/delete", Protect(self.Delete)) +	r.POST("/users", Protect(self.UpsertUser))  }  func ParseUint(id string) (*uint, error) { diff --git a/pkg/view/view.go b/pkg/view/view.go index 663738b..f8dfa16 100644 --- a/pkg/view/view.go +++ b/pkg/view/view.go @@ -1,7 +1,22 @@  package view -import "git.sr.ht/~gabrielgio/img/pkg/ext" +import ( +	"net/http" + +	"git.sr.ht/~gabrielgio/img/pkg/ext" +)  type View interface {  	SetMyselfIn(r *ext.Router)  } + +func Protect(next ext.ErrorRequestHandler) ext.ErrorRequestHandler { +	return func(w http.ResponseWriter, r *http.Request) error { +		user := ext.GetUserFromCtx(r) +		if !user.IsAdmin { +			http.NotFound(w, r) +			return nil +		} +		return next(w, r) +	} +} diff --git a/templates/album.qtpl b/templates/album.qtpl index 835db57..58fc499 100644 --- a/templates/album.qtpl +++ b/templates/album.qtpl @@ -23,14 +23,14 @@ func (m *AlbumPage) PreloadAttr() string {  <h1 class="title text-size-1">{%s p.Name %}</h1>  <div class="tags are-large">  {% for _, a := range p.Albums %} -  <a href="/album/?albumId={%s FromUInttoString(&a.ID) %}" class="tag text-size-2">{%s a.Name %}</a> +  <a href="/album?albumId={%s FromUInttoString(&a.ID) %}" class="tag text-size-2">{%s a.Name %}</a>  {% endfor %}  </div>  <div class="columns">  {%= Mosaic(p.Medias, p.PreloadAttr()) %}  </div>  <div> -    <a href="/album/?albumId={%s FromUInttoString(p.Next.AlbumID) %}&page={%d p.Next.Page %}" class="button is-pulled-right">next</a> +    <a href="/album?albumId={%s FromUInttoString(p.Next.AlbumID) %}&page={%d p.Next.Page %}" class="button is-pulled-right">next</a>  </div>  {% endfunc %} diff --git a/templates/base.qtpl b/templates/base.qtpl index a80803a..30b084e 100644 --- a/templates/base.qtpl +++ b/templates/base.qtpl @@ -21,7 +21,7 @@ Page {  Page prints a page implementing Page interface. -{% func PageTemplate(p Page) %} +{% func PageTemplate(p Page, isAdmin bool) %}  <html lang="en">      <head>          <meta charset="utf-8"> @@ -33,18 +33,20 @@ Page prints a page implementing Page interface.      <body>          <nav class="navbar">              <div class="navbar-start"> -                <a href="/fs/" class="navbar-item text-size-1"> +                <a href="/fs" class="navbar-item text-size-1">                      file                  </a> -                <a href="/media/" class="navbar-item text-size-1"> +                <a href="/media" class="navbar-item text-size-1">                      media                  </a> -                <a href="/album/" class="navbar-item text-size-1"> +                <a href="/album" class="navbar-item text-size-1">                      album                  </a> -                <a href="/settings/" class="navbar-item text-size-1"> +                {% if isAdmin  %} +                <a href="/settings" class="navbar-item text-size-1">                      settings                  </a> +                {% endif %}              </div>          </nav>          <div class="container is-fullhd"> diff --git a/templates/media.qtpl b/templates/media.qtpl index 4251deb..737d03d 100644 --- a/templates/media.qtpl +++ b/templates/media.qtpl @@ -22,7 +22,7 @@ func (m *MediaPage) PreloadAttr() string {  {%= Mosaic(p.Medias, p.PreloadAttr()) %}  </div>  <div> -    <a href="/media/?page={%d p.Next.Page %}" class="button is-pulled-right">next</a> +    <a href="/media?page={%d p.Next.Page %}" class="button is-pulled-right">next</a>  </div>  {% endfunc %} diff --git a/templates/mosaic.qtpl b/templates/mosaic.qtpl index 3e6ccf8..18dbcba 100644 --- a/templates/mosaic.qtpl +++ b/templates/mosaic.qtpl @@ -8,12 +8,12 @@      {% for _, media := range c %}      <div class="card-image">         {% if media.IsVideo() %} -       <video class="image is-fit" controls muted="true" poster="/media/thumbnail/?path_hash={%s media.PathHash %}" preload="{%s preloadAttr %}"> -           <source src="/media/image/?path_hash={%s media.PathHash %}" type="{%s media.MIMEType %}"> +       <video class="image is-fit" controls muted="true" poster="/media/thumbnail?path_hash={%s media.PathHash %}" preload="{%s preloadAttr %}"> +           <source src="/media/image?path_hash={%s media.PathHash %}" type="{%s media.MIMEType %}">         </video>         {% else %}          <figure class="image is-fit"> -            <img src="/media/thumbnail/?path_hash={%s media.PathHash %}"> +            <img src="/media/thumbnail?path_hash={%s media.PathHash %}">          </figure>          {% endif %}      </div> diff --git a/templates/settings.qtpl b/templates/settings.qtpl index 4439c77..b720a88 100644 --- a/templates/settings.qtpl +++ b/templates/settings.qtpl @@ -58,7 +58,7 @@ type SettingsPage struct {      </div>      {% endfor %}      <div class="field"> -        <a href="/users/" class="button">create</a> +        <a href="/users" class="button">create</a>      </div>  </div>  {% endfunc %} diff --git a/templates/user.qtpl b/templates/user.qtpl index 6ec783d..6fc3ce6 100644 --- a/templates/user.qtpl +++ b/templates/user.qtpl @@ -13,7 +13,7 @@ type UserPage struct {  {% func (p *UserPage) Content() %}  <h1>Initial Setup</h1> -<form action="/users/" method="post"> +<form action="/users" method="post">      {% if p.ID != nil %}       <input type="hidden" name="userId" value="{%s FromUInttoString(p.ID) %}" />      {% endif %} @@ -41,7 +41,7 @@ type UserPage struct {      <div class="field">          <label class="label">Is Admin?</label>          <div class="control"> -            <input type="checkbox" name="isAdmin" type="password" {% if p.IsAdmin %}checked{% endif %}> +            <input type="checkbox" name="isAdmin" {% if p.IsAdmin %}checked{% endif %}>          </div>      </div>      <div class="field"> | 
