diff options
| -rw-r--r-- | main.go | 24 | ||||
| -rw-r--r-- | pkg/service/auth.go | 117 | ||||
| -rw-r--r-- | pkg/service/auth_test.go | 119 | 
3 files changed, 244 insertions, 16 deletions
@@ -2,8 +2,6 @@ package main  import (  	"context" -	"crypto/rand" -	"encoding/base64"  	"flag"  	"fmt"  	"log/slog" @@ -12,7 +10,6 @@ import (  	"time"  	"github.com/alecthomas/chroma/v2/styles" -	"golang.org/x/crypto/bcrypt"  	"git.gabrielgio.me/cerrado/pkg/config"  	"git.gabrielgio.me/cerrado/pkg/handler" @@ -21,9 +18,6 @@ import (  )  func main() { -	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) -	defer stop() -  	if len(os.Args) == 4 && os.Args[1] == "hash" {  		err := hash(os.Args[2], os.Args[3])  		if err != nil { @@ -42,36 +36,34 @@ func main() {  		return  	} -	if err := run(ctx); err != nil { +	if err := run(); err != nil {  		slog.Error("Server", "error", err)  		os.Exit(1)  	}  }  func hash(username string, password string) error { -	passphrase := fmt.Sprintf("%s:%s", username, password) -	bytes, err := bcrypt.GenerateFromPassword([]byte(passphrase), 14) +	hash, err := service.GenerateHash(username, password)  	if err != nil {  		return err  	} -	fmt.Println(string(bytes)) +	fmt.Println(hash)  	return nil  }  func key() error { -	key := make([]byte, 64) - -	_, err := rand.Read(key) +	en, err := service.GenerateAesKey()  	if err != nil {  		return err  	} - -	en := base64.StdEncoding.EncodeToString(key)  	fmt.Println(en)  	return nil  } -func run(ctx context.Context) error { +func run() error { +	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) +	defer stop() +  	configPath := flag.String("config", "/etc/cerrado.scfg", "File path for the configuration file")  	flag.Parse() diff --git a/pkg/service/auth.go b/pkg/service/auth.go new file mode 100644 index 0000000..1fbf4b6 --- /dev/null +++ b/pkg/service/auth.go @@ -0,0 +1,117 @@ +package service + +import ( +	"bytes" +	"crypto/aes" +	"crypto/cipher" +	"crypto/rand" +	"encoding/base64" +	"fmt" +	"io" + +	"golang.org/x/crypto/bcrypt" +) + +type ( +	AuthService struct { +		authRepository authRepository +	} + +	authRepository interface { +		GetPassphrase() []byte +		GetBase64AesKey() []byte +	} +) + +var tokenSeed = []byte("cerrado") + +func (a *AuthService) CheckAuth(username, password string) bool { +	passphrase := a.authRepository.GetPassphrase() +	pass := []byte(fmt.Sprintf("%s:%s", username, password)) + +	err := bcrypt.CompareHashAndPassword(passphrase, pass) + +	return err == nil +} + +func (a *AuthService) IssueToken() ([]byte, error) { +	// TODO: do this block only once +	base := a.authRepository.GetBase64AesKey() + +	dbuf, err := base64.StdEncoding.DecodeString(string(base)) +	if err != nil { +		return nil, err +	} + +	block, err := aes.NewCipher(dbuf) +	if err != nil { +		return nil, err +	} + +	gcm, err := cipher.NewGCM(block) +	if err != nil { +		return nil, err +	} + +	nonce := make([]byte, gcm.NonceSize()) +	if _, err := io.ReadFull(rand.Reader, nonce); err != nil { +		return nil, err +	} + +	ciphertext := gcm.Seal(nonce, nonce, tokenSeed, nil) + +	return ciphertext, nil +} + +func (a *AuthService) ValidateToken(token []byte) (bool, error) { +	base := a.authRepository.GetBase64AesKey() + +	dbuf, err := base64.StdEncoding.DecodeString(string(base)) +	if err != nil { +		return false, err +	} + +	block, err := aes.NewCipher(dbuf) +	if err != nil { +		return false, err +	} + +	gcm, err := cipher.NewGCM(block) +	if err != nil { +		return false, err +	} + +	nonceSize := gcm.NonceSize() +	if len(token) < nonceSize { +		return false, fmt.Errorf("ciphertext too short") +	} + +	nonce, ciphertext := token[:nonceSize], token[nonceSize:] +	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) +	if err != nil { +		return false, err +	} + +	return bytes.Equal(tokenSeed, plaintext), nil +} + +func GenerateHash(username, password string) (string, error) { +	passphrase := fmt.Sprintf("%s:%s", username, password) +	bytes, err := bcrypt.GenerateFromPassword([]byte(passphrase), 14) +	if err != nil { +		return "", err +	} + +	return string(bytes), nil +} + +func GenerateAesKey() (string, error) { +	key := make([]byte, 32) + +	_, err := rand.Read(key) +	if err != nil { +		return "", err +	} + +	return base64.StdEncoding.EncodeToString(key), nil +} diff --git a/pkg/service/auth_test.go b/pkg/service/auth_test.go new file mode 100644 index 0000000..06bf76f --- /dev/null +++ b/pkg/service/auth_test.go @@ -0,0 +1,119 @@ +// go:build unit + +package service + +import ( +	"testing" +) + +func TestCheck(t *testing.T) { +	testCases := []struct { +		name       string +		passphrase []byte +		username   string +		password   string +		wantError  bool +	}{ +		{ +			name:       "generated", +			passphrase: nil, +			username:   "gabrielgio", +			password:   "adminadmin", +			wantError:  false, +		}, +		{ +			name:       "static", +			passphrase: []byte("$2a$14$W2yT0E6Zm8nTecqipHUQGOLC6PvNjIQqpQTW/MZmD5oqDfaBJnBV6"), +			username:   "gabrielgio", +			password:   "adminadmin", +			wantError:  false, +		}, +		{ +			name:       "error", +			passphrase: []byte("This is not a valid hash"), +			username:   "gabrielgio", +			password:   "adminadmin", +			wantError:  true, +		}, +	} + +	for _, tc := range testCases { +		t.Run(tc.name, func(t *testing.T) { +			mock := &mockAuthRepository{ +				username:   tc.username, +				password:   tc.password, +				passphrase: tc.passphrase, +			} + +			service := AuthService{authRepository: mock} + +			if service.CheckAuth(tc.username, tc.password) == tc.wantError { +				t.Errorf("Invalid result, wanted %t got %t", tc.wantError, !tc.wantError) +			} +		}) +	} +} + +func TestValidate(t *testing.T) { +	testCases := []struct { +		name   string +		aesKey []byte +	}{ +		{ +			name:   "generated", +			aesKey: nil, +		}, +		{ +			name:   "static", +			aesKey: []byte("RTGkmunKmi5agh7jaqENunG2zI/godnkqhHaHyX/AVg="), +		}, +	} + +	for _, tc := range testCases { +		t.Run(tc.name, func(t *testing.T) { +			mock := &mockAuthRepository{ +				aesKey: tc.aesKey, +			} + +			service := AuthService{authRepository: mock} + +			token, err := service.IssueToken() +			if err != nil { +				t.Fatalf("Error issuing token: %s", err.Error()) +			} + +			v, err := service.ValidateToken(token) +			if err != nil { +				t.Fatalf("Error validating token: %s", err.Error()) +			} + +			if !v { +				t.Error("Invalid token generated") +			} +		}) +	} +} + +type mockAuthRepository struct { +	username   string +	password   string +	passphrase []byte + +	aesKey []byte +} + +func (m *mockAuthRepository) GetPassphrase() []byte { +	if m.passphrase == nil { +		hash, _ := GenerateHash(m.username, m.password) +		m.passphrase = []byte(hash) +	} +	return m.passphrase +} + +func (m *mockAuthRepository) GetBase64AesKey() []byte { +	if m.aesKey == nil { +		key, _ := GenerateAesKey() +		m.aesKey = []byte(key) +	} +	return m.aesKey +}  | 
