package ext

import (
	"compress/gzip"
	"compress/lzw"
	"errors"
	"io"
	"log/slog"
	"net/http"
	"strconv"
	"strings"

	"git.gabrielgio.me/cerrado/pkg/u"
	"github.com/andybalholm/brotli"
	"github.com/klauspost/compress/zstd"
)

var (
	// invalidParamErr is returned by getWeighedValue when an Accept-Encoding
	// parameter is not of the form "q=<value>".
	invalidParamErr = errors.New("Invalid weighted param")
)

// CompressionResponseWriter is an http.ResponseWriter whose body writes go
// through compressWriter while headers and the status code are forwarded to
// innerWriter.
type CompressionResponseWriter struct {
	innerWriter    http.ResponseWriter
	compressWriter io.Writer
}

// Compress wraps next so that the response body is compressed with the
// encoding the client prefers, as reported by GetCompressionWriter.
//
// Only the value returned by u.FirstOrZero for the Accept-Encoding header is
// inspected (presumably the first header line, or "" when there is none).
// When no supported encoding is selected, next receives the original
// ResponseWriter untouched.
func Compress(next func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		// The representation depends on the request's Accept-Encoding, so
		// caches must be told to key on it.
		w.Header().Add("Vary", "Accept-Encoding")

		if accept, ok := r.Header["Accept-Encoding"]; ok {
			if compress, algo := GetCompressionWriter(u.FirstOrZero(accept), w); algo != "" {
				// Close flushes whatever the compressor still buffers once
				// the handler has returned. Its error is dropped on purpose:
				// nothing useful can be sent to the client at that point.
				defer compress.Close()
				w.Header().Add("Content-Encoding", algo)
				w = &CompressionResponseWriter{
					innerWriter:    w,
					compressWriter: compress,
				}
			}
		}
		next(w, r)
	}
}

// GetCompressionWriter returns a compressing writer on top of inner together
// with the Content-Encoding token it implements. The encoding is the one
// GetCompression picks from header; when that encoding is not one of "br",
// "gzip", "compress" or "zstd" it returns (nil, "").
func GetCompressionWriter(header string, inner io.Writer) (io.WriteCloser, string) {
	c := GetCompression(header)
	switch c {
	case "br":
		return GetBrotliWriter(inner), c
	case "gzip":
		return GetGZIPWriter(inner), c
	case "compress":
		return GetLZWWriter(inner), c
	case "zstd":
		return GetZSTDWriter(inner), c
	default:
		return nil, ""
	}
}

// Header returns the header map of the wrapped ResponseWriter.
func (c *CompressionResponseWriter) Header() http.Header {
	return c.innerWriter.Header()
}

// Write compresses b into the response body. It returns whatever the
// underlying compressing writer returns.
func (c *CompressionResponseWriter) Write(b []byte) (int, error) {
	// A handler may write without calling WriteHeader first, so the stale
	// length has to be dropped here as well.
	c.dropContentLength()
	return c.compressWriter.Write(b)
}

// WriteHeader forwards the status code to the wrapped ResponseWriter.
func (c *CompressionResponseWriter) WriteHeader(statusCode int) {
	c.dropContentLength()
	c.innerWriter.WriteHeader(statusCode)
}

// dropContentLength removes any Content-Length set by the wrapped handler:
// that value describes the uncompressed body and would not match the bytes
// actually sent.
func (c *CompressionResponseWriter) dropContentLength() {
	if c.innerWriter == nil {
		return
	}
	c.innerWriter.Header().Del("Content-Length")
}

// GetCompression parses an Accept-Encoding header value and returns the
// encoding token with the highest weight. Entries without a parameter weigh
// 1; on ties the first entry wins. Entries whose single parameter cannot be
// parsed as a q-value are logged and skipped. It returns "*" when header is
// empty or no entry has a weight above zero.
//
// The returned token is not checked against the encodings this package can
// produce; GetCompressionWriter does that.
func GetCompression(header string) string {
	c := "*"
	q := 0.0

	if header == "" {
		return c
	}

	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
	for _, e := range strings.Split(header, ",") {
		ps := strings.Split(e, ";")

		// Anything that is not exactly "token;param" is treated as having
		// the default weight of 1.
		w := 1.0
		if len(ps) == 2 {
			var err error
			w, err = getWeighedValue(ps[1])
			if err != nil {
				slog.Error(
					"Error parsing weighting from Accept-Encoding",
					"error", err,
				)
				continue
			}
		}

		// Strictly greater: an entry with weight 0 is never selected and
		// earlier entries win ties.
		if w > q {
			q = w
			c = strings.TrimSpace(ps[0])
		}
	}

	return c
}

// GetGZIPWriter returns a gzip writer on top of w using the best compression
// level.
func GetGZIPWriter(w io.Writer) io.WriteCloser {
	// error can be ignored here since it will only err when compression level
	// is not valid
	r, _ := gzip.NewWriterLevel(w, gzip.BestCompression)
	return r
}

// GetBrotliWriter returns a brotli writer on top of w using the best
// compression level.
func GetBrotliWriter(w io.Writer) io.WriteCloser {
	return brotli.NewWriterLevel(w, brotli.BestCompression)
}

// GetZSTDWriter returns a zstd writer on top of w using the library defaults.
func GetZSTDWriter(w io.Writer) io.WriteCloser {
	// error can be ignored here since it will only err when opts are given
	// (presumably invalid ones; none are passed here)
	r, _ := zstd.NewWriter(w)
	return r
}

// GetLZWWriter returns an LZW writer on top of w with LSB bit ordering and a
// literal width of 8 bits.
//
// NOTE(review): this writer backs the "compress" content coding, which is
// specified as the UNIX compress format; the raw stream produced by
// compress/lzw may not be decodable by clients expecting that — confirm.
func GetLZWWriter(w io.Writer) io.WriteCloser {
	return lzw.NewWriter(w, lzw.LSB, 8)
}

// getWeighedValue parses a single Accept-Encoding parameter of the form
// "q=<float>" and returns the float. Whitespace around the name and the
// value is ignored, so " q=0.5 " and "q = 0.5" are both accepted. It returns
// invalidParamErr when part has no "=" or the name is not "q", and the
// strconv error when the value is not a valid float.
func getWeighedValue(part string) (float64, error) {
	name, value, found := strings.Cut(part, "=")
	if !found || strings.TrimSpace(name) != "q" {
		return 0, invalidParamErr
	}

	w, err := strconv.ParseFloat(strings.TrimSpace(value), 64)
	if err != nil {
		return 0, err
	}
	return w, nil
}