aboutsummaryrefslogtreecommitdiff
path: root/pkg/ext
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/ext')
-rw-r--r--pkg/ext/compression.go142
-rw-r--r--pkg/ext/compression_test.go42
-rw-r--r--pkg/ext/mime.go24
3 files changed, 208 insertions, 0 deletions
diff --git a/pkg/ext/compression.go b/pkg/ext/compression.go
new file mode 100644
index 0000000..92144b8
--- /dev/null
+++ b/pkg/ext/compression.go
@@ -0,0 +1,142 @@
+package ext
+
+import (
+ "compress/gzip"
+ "compress/lzw"
+ "errors"
+ "io"
+ "log/slog"
+ "net/http"
+ "strconv"
+ "strings"
+
+ "git.gabrielgio.me/cerrado/pkg/u"
+ "github.com/andybalholm/brotli"
+ "github.com/klauspost/compress/zstd"
+)
+
+var (
+ invalidParamErr = errors.New("Invalid weighted param")
+)
+
+type CompressionResponseWriter struct {
+ innerWriter http.ResponseWriter
+ compressWriter io.Writer
+}
+
+func Compress(next func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if accept, ok := r.Header["Accept-Encoding"]; ok {
+ if compress, algo := GetCompressionWriter(u.FirstOrZero(accept), w); algo != "" {
+ defer compress.Close()
+ w.Header().Add("Content-Encoding", algo)
+ w = &CompressionResponseWriter{
+ innerWriter: w,
+ compressWriter: compress,
+ }
+ }
+ }
+ next(w, r)
+ }
+}
+
+func GetCompressionWriter(header string, inner io.Writer) (io.WriteCloser, string) {
+ c := GetCompression(header)
+ switch c {
+ case "br":
+ return GetBrotliWriter(inner), c
+ case "gzip":
+ return GetGZIPWriter(inner), c
+ case "compress":
+ return GetLZWWriter(inner), c
+ case "zstd":
+ return GetZSTDWriter(inner), c
+ default:
+ return nil, ""
+ }
+
+}
+
+func (c *CompressionResponseWriter) Header() http.Header {
+ return c.innerWriter.Header()
+}
+func (c *CompressionResponseWriter) Write(b []byte) (int, error) {
+ return c.compressWriter.Write(b)
+}
+
+func (c *CompressionResponseWriter) WriteHeader(statusCode int) {
+ c.innerWriter.WriteHeader(statusCode)
+}
+
+func GetCompression(header string) string {
+ c := "*"
+ q := 0.0
+
+ if header == "" {
+ return c
+ }
+
+ // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
+ for _, e := range strings.Split(header, ",") {
+ ps := strings.Split(e, ";")
+ if len(ps) == 2 {
+ w, err := getWeighedValue(ps[1])
+ if err != nil {
+ slog.Error(
+ "Error parsing weighting from Accept-Encoding",
+ "error", err,
+ )
+ continue
+ }
+ // gettting weighting value
+ if w > q {
+ q = w
+ c = strings.Trim(ps[0], " ")
+ }
+ } else {
+ if 1 > q {
+ q = 1
+ c = strings.Trim(ps[0], " ")
+ }
+ }
+ }
+
+ return c
+}
+
+func GetGZIPWriter(w io.Writer) io.WriteCloser {
+ // error can be ignored here since it will only err when compression level
+ // is not valid
+ r, _ := gzip.NewWriterLevel(w, gzip.BestCompression)
+ return r
+}
+
+func GetBrotliWriter(w io.Writer) io.WriteCloser {
+ return brotli.NewWriterLevel(w, brotli.BestCompression)
+}
+
+func GetZSTDWriter(w io.Writer) io.WriteCloser {
+ // error can be ignored here since it will only opts are given
+ r, _ := zstd.NewWriter(w)
+ return r
+}
+
+func GetLZWWriter(w io.Writer) io.WriteCloser {
+ return lzw.NewWriter(w, lzw.LSB, 8)
+}
+
+func getWeighedValue(part string) (float64, error) {
+ ps := strings.SplitN(part, "=", 2)
+ if len(ps) != 2 {
+ return 0, invalidParamErr
+ }
+ if name := strings.TrimSpace(ps[0]); name == "q" {
+ w, err := strconv.ParseFloat(ps[1], 64)
+ if err != nil {
+ return 0, err
+ }
+ return w, nil
+ }
+
+ return 0, invalidParamErr
+}
diff --git a/pkg/ext/compression_test.go b/pkg/ext/compression_test.go
new file mode 100644
index 0000000..6424378
--- /dev/null
+++ b/pkg/ext/compression_test.go
@@ -0,0 +1,42 @@
+// go:build unit
+package ext
+
+import "testing"
+
+func TestGetCompression(t *testing.T) {
+ testCases := []struct {
+ name string
+ header string
+ compression string
+ }{
+ {
+ name: "Empty",
+ header: "",
+ compression: "*",
+ },
+ {
+ name: "Weighted",
+ header: "gzip;q=1.0, *;q=0.5",
+ compression: "gzip",
+ },
+ {
+ name: "Mixed",
+ header: "deflate, gzip;q=1.0, *;q=0.5",
+ compression: "deflate",
+ },
+ {
+ name: "Not weighted",
+ header: "zstd, deflate, gzip",
+ compression: "zstd",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ got := GetCompression(tc.header)
+ if got != tc.compression {
+ t.Errorf("Wrong compression returned: got %s want %s", got, tc.compression)
+ }
+ })
+ }
+}
diff --git a/pkg/ext/mime.go b/pkg/ext/mime.go
new file mode 100644
index 0000000..6da66e3
--- /dev/null
+++ b/pkg/ext/mime.go
@@ -0,0 +1,24 @@
+package ext
+
+import "net/http"
+
+type ContentType = string
+
+const (
+ TextHTML ContentType = "text/html"
+)
+
+func Html(next func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ next(w, r)
+ }
+}
+
+func SetHTML(w http.ResponseWriter) {
+ SetMIME(w, TextHTML)
+
+}
+
+func SetMIME(w http.ResponseWriter, mime ContentType) {
+ w.Header().Add("Content-Type", mime)
+}