aboutsummaryrefslogtreecommitdiff
path: root/pkg/ext/compression.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/ext/compression.go')
-rw-r--r--pkg/ext/compression.go142
1 files changed, 142 insertions, 0 deletions
diff --git a/pkg/ext/compression.go b/pkg/ext/compression.go
new file mode 100644
index 0000000..92144b8
--- /dev/null
+++ b/pkg/ext/compression.go
@@ -0,0 +1,142 @@
+package ext
+
+import (
+ "compress/gzip"
+ "compress/lzw"
+ "errors"
+ "io"
+ "log/slog"
+ "net/http"
+ "strconv"
+ "strings"
+
+ "git.gabrielgio.me/cerrado/pkg/u"
+ "github.com/andybalholm/brotli"
+ "github.com/klauspost/compress/zstd"
+)
+
+var (
+ invalidParamErr = errors.New("Invalid weighted param")
+)
+
+type CompressionResponseWriter struct {
+ innerWriter http.ResponseWriter
+ compressWriter io.Writer
+}
+
+func Compress(next func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if accept, ok := r.Header["Accept-Encoding"]; ok {
+ if compress, algo := GetCompressionWriter(u.FirstOrZero(accept), w); algo != "" {
+ defer compress.Close()
+ w.Header().Add("Content-Encoding", algo)
+ w = &CompressionResponseWriter{
+ innerWriter: w,
+ compressWriter: compress,
+ }
+ }
+ }
+ next(w, r)
+ }
+}
+
+func GetCompressionWriter(header string, inner io.Writer) (io.WriteCloser, string) {
+ c := GetCompression(header)
+ switch c {
+ case "br":
+ return GetBrotliWriter(inner), c
+ case "gzip":
+ return GetGZIPWriter(inner), c
+ case "compress":
+ return GetLZWWriter(inner), c
+ case "zstd":
+ return GetZSTDWriter(inner), c
+ default:
+ return nil, ""
+ }
+
+}
+
+func (c *CompressionResponseWriter) Header() http.Header {
+ return c.innerWriter.Header()
+}
+func (c *CompressionResponseWriter) Write(b []byte) (int, error) {
+ return c.compressWriter.Write(b)
+}
+
+func (c *CompressionResponseWriter) WriteHeader(statusCode int) {
+ c.innerWriter.WriteHeader(statusCode)
+}
+
+func GetCompression(header string) string {
+ c := "*"
+ q := 0.0
+
+ if header == "" {
+ return c
+ }
+
+ // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
+ for _, e := range strings.Split(header, ",") {
+ ps := strings.Split(e, ";")
+ if len(ps) == 2 {
+ w, err := getWeighedValue(ps[1])
+ if err != nil {
+ slog.Error(
+ "Error parsing weighting from Accept-Encoding",
+ "error", err,
+ )
+ continue
+ }
+ // gettting weighting value
+ if w > q {
+ q = w
+ c = strings.Trim(ps[0], " ")
+ }
+ } else {
+ if 1 > q {
+ q = 1
+ c = strings.Trim(ps[0], " ")
+ }
+ }
+ }
+
+ return c
+}
+
+func GetGZIPWriter(w io.Writer) io.WriteCloser {
+ // error can be ignored here since it will only err when compression level
+ // is not valid
+ r, _ := gzip.NewWriterLevel(w, gzip.BestCompression)
+ return r
+}
+
+func GetBrotliWriter(w io.Writer) io.WriteCloser {
+ return brotli.NewWriterLevel(w, brotli.BestCompression)
+}
+
+func GetZSTDWriter(w io.Writer) io.WriteCloser {
+ // error can be ignored here since it will only opts are given
+ r, _ := zstd.NewWriter(w)
+ return r
+}
+
+func GetLZWWriter(w io.Writer) io.WriteCloser {
+ return lzw.NewWriter(w, lzw.LSB, 8)
+}
+
+func getWeighedValue(part string) (float64, error) {
+ ps := strings.SplitN(part, "=", 2)
+ if len(ps) != 2 {
+ return 0, invalidParamErr
+ }
+ if name := strings.TrimSpace(ps[0]); name == "q" {
+ w, err := strconv.ParseFloat(ps[1], 64)
+ if err != nil {
+ return 0, err
+ }
+ return w, nil
+ }
+
+ return 0, invalidParamErr
+}