Go HTTP handler with context

One thing I find missing from Go's net/http package is some sort of context passed through the request handler chain. There is a context.Context attached to every *http.Request, but storing and reading values through context.WithValue and ctx.Value, which take and return untyped any values, is cumbersome and not very 'clean'.

This article shows how you can make a new type of handler with a context, and how this unlocks all sorts of other benefits.

Context

First, let's define our context. This struct wraps the standard http.ResponseWriter and *http.Request duo and supplies additional utility methods.

The magic happens when Context implements http.ResponseWriter. Because a *Context can be passed anywhere a ResponseWriter is expected, it travels through the handler chain and keeps state along the way.

You can also implement context.Context if you like, but a simple getter method for the request's context works fine.

// server/context.go
package server

import (
    ...
)

type Context struct {
    R         *http.Request
    // Context implements ResponseWriter, so we
    // should always pass this instead of w.
    w         http.ResponseWriter
    status    int
    err       error
    bytes     int
    createdAt time.Time
}

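func newContext(w http.ResponseWriter, r *http.Request) *Context {
    // status defaults to 200, which is what net/http sends
    // if WriteHeader is never called.
    return &Context{R: r, w: w, status: http.StatusOK, createdAt: time.Now()}
}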

func ReuseOrNewContext(w http.ResponseWriter, r *http.Request) *Context {
    if ctx, ok := w.(*Context); ok {
        // r may have changed with shallow copy in r.WithContext().
        // The request is still the same, but we need to update the reference.
        ctx.R = r
        return ctx
    }
    return newContext(w, r)
}

// Implement http.ResponseWriter
func (c *Context) WriteHeader(code int) {
    c.status = code
    c.w.WriteHeader(code)
}

// Implement http.ResponseWriter
func (c *Context) Write(b []byte) (int, error) {
    // No status bookkeeping needed: newContext already defaults
    // status to 200, matching net/http's implicit status.
    n, err := c.w.Write(b)
    c.bytes += n // count only what the underlying writer accepted
    return n, err
}

// Implement http.ResponseWriter
func (c *Context) Header() http.Header {
    return c.w.Header()
}

Handler

Our new Handler type is quite simple:

// server/handler.go
package server

import "net/http"

type Handler func(ctx *Context) error

// Implement http.Handler
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    ctx := ReuseOrNewContext(w, r) // Context is now passed down through http.Handlers
    _ = h(ctx)
}

// Since our Handler implements http.Handler, any http middleware will work.
type Middleware func(h http.Handler) http.Handler

Note that ServeHTTP ignores the returned error. This is fine because ctx.Error() has already written the error to the response, as we will see later. The error is returned for two reasons: it saves a line per early return (you return ctx.Error(...) instead of writing the error and then returning), and it lets middleware inspect the error.

Additionally, the Handler is always responsible for writing the response. The returned error is only a convenience and should always come from ctx.Error(). Forgetting this results in silent failures: ServeHTTP discards the return value and Next() only reports errors recorded on the Context, so a logging middleware built on Next(), like the one below, will not see a bare returned error either.

Mux

Our Mux wraps a chi.Mux (or any other mux, including the standard library's http.ServeMux). The only differences are that it applies middleware and that its Handle method accepts our Handler type instead of an http.HandlerFunc.

// server/mux.go
package server

import (
    "net/http"

    "github.com/go-chi/chi/v5"
    chiMiddleware "github.com/go-chi/chi/v5/middleware"
    "github.com/go-chi/cors"
)

type Mux struct {
    mux         *chi.Mux
    middlewares []Middleware
}

func NewMux(middlewares ...Middleware) *Mux {
    mux := chi.NewMux()

    mux.Use(chiMiddleware.StripSlashes)
    // Permissive CORS: allows credentialed requests from any http or https origin.
    mux.Use(cors.Handler(cors.Options{
        AllowedOrigins:   []string{"https://*", "http://*"},
        AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
        AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-Admin-Key"},
        ExposedHeaders:   []string{"Content-Length", "Content-Type"},
        AllowCredentials: true,
    }))

    return &Mux{
        mux:         mux,
        middlewares: middlewares,
    }
}

// applyMiddlewares wraps h in the mux-wide middlewares, then in the
// route-specific ones. Later middlewares wrap earlier ones, so the
// last one applied is outermost and runs first.
func (m *Mux) applyMiddlewares(h http.Handler, middlewares ...Middleware) http.Handler {
    for _, mw := range m.middlewares {
        h = mw(h)
    }
    for _, mw := range middlewares {
        h = mw(h)
    }
    return h
}

// Handle registers a new Handler for the given method/pattern.
func (m *Mux) Handle(method string, pattern string, h Handler, middlewares ...Middleware) {
    // Build the middleware chain once at registration, not on every request.
    m.mux.Method(method, pattern, m.applyMiddlewares(h, middlewares...))
}

// Mount mounts a http.Handler to the given pattern.
func (m *Mux) Mount(pattern string, h http.Handler, middleware ...Middleware) {
    h = m.applyMiddlewares(h, middleware...)
    m.mux.Mount(pattern, h)
}

// Implement http.Handler
func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    m.mux.ServeHTTP(ReuseOrNewContext(w, r), r)
}

Middleware

One major benefit of this new Handler type is context-aware middleware. Context can keep track of all sorts of state that middleware can read and write, for example auth credentials set by an authentication middleware for later handlers to use.

Here is a simple logger middleware that logs any errors:

// middleware/logger.go
package middleware

import (
    "log"
    "net/http"
    "time"

    "app/server"
)

func Logger(h http.Handler) http.Handler {
    return server.Handler(func(ctx *server.Context) error {
        start := time.Now()
        err := ctx.Next(h) // see implementation below

        if err != nil {
            log.Printf(
                "request failed. err=%s, time=%s, status=%d",
                err.Error(),
                time.Since(start).String(),
                ctx.Status(),
            )
        }

        return err
    })
}

// server/context.go

// Next calls the given http.Handler with the current context
// and returns any error set to the context.
func (c *Context) Next(h http.Handler) error {
    h.ServeHTTP(c, c.R)
    return c.GetError()
}

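Middleware should never wrap the Context in its own http.ResponseWriter and pass that down. The next Handler would fail the *Context check in ReuseOrNewContext and create a fresh Context, so any state set downstream (such as an error recorded by ctx.Error()) would never reach the Context seen by the middleware above it.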

Example

Simple example using the Mux, Handler, and Middleware together:

package main

import (
    "log"
    "net/http"

    "app/middleware"
    "app/server"
)

func main() {
    // Logger applies to all requests handled by this mux
    mux := server.NewMux(middleware.Logger)

    mux.Handle("GET", "/", func(ctx *server.Context) error {
        return ctx.String("Hello world!")
    })

    log.Fatal(http.ListenAndServe(":8080", mux))
}

Testing

Another great benefit of this pattern is testing. Because the Context records the status, error, and bytes written as the request passes through, it works like an extended httptest.ResponseRecorder, and you can wrap a real recorder in it to get both.

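r := httptest.NewRequest(testMethod, testPath, testBody)
ctx := server.ReuseOrNewContext(httptest.NewRecorder(), r)

testMux.ServeHTTP(ctx, r)

if err := ctx.GetError(); err != nil {
    t.Errorf("unexpected error: %v", err)
}
if ctx.Status() != http.StatusOK {
    t.Errorf("status = %d, want %d", ctx.Status(), http.StatusOK)
}
if ctx.Bytes() != n {
    t.Errorf("bytes = %d, want %d", ctx.Bytes(), n)
}
...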

Context helpers

Here's a list of helper methods used on the Context. These can be tailored to your project, which is what makes this pattern so powerful.

// ReadJSON decodes the JSON request body into dest.
func (c *Context) ReadJSON(dest any) error {
    return json.NewDecoder(c.R.Body).Decode(dest)
}

// JSON responds to the request with the JSON encoding of data.
// A status set with WriteHeader has already been sent, so calling
// WriteHeader again here would only trigger a superfluous-call warning.
func (c *Context) JSON(data any) error {
    // Has no effect if WriteHeader was already called.
    c.Header().Set("Content-Type", "application/json")
    return json.NewEncoder(c).Encode(data)
}

func (c *Context) PathValue(key string) string {
    return c.R.PathValue(key)
}

// PathValueInt parses the path value as an integer and responds
// with 400 Bad Request if it is not one.
func (c *Context) PathValueInt(key string) (int, error) {
    s := c.R.PathValue(key)
    n, err := strconv.Atoi(s)
    if err != nil {
        return 0, c.Error(fmt.Errorf("invalid integer literal: %s", s), http.StatusBadRequest)
    }
    return n, nil
}

// To be used instead of 'return nil' in handlers for clarity.
func (c *Context) Ok() error {
    return nil
}

func (c *Context) Context() context.Context {
    return c.R.Context()
}

func (c *Context) String(s string) error {
    _, err := c.Write([]byte(s))
    return err
}

// Error writes an error message to the response writer and sets the status.
// Returns same error for convenience.
func (c *Context) Error(err error, status int) error {
    c.status = status
    c.err = err
    http.Error(c, err.Error(), status)
    return err
}

func (c *Context) GetError() error {
    return c.err
}

func (c *Context) ServeFile(filepath string) error {
    http.ServeFile(c, c.R, filepath)
    return nil
}

func (c *Context) SetCookie(cookie *http.Cookie) {
    http.SetCookie(c, cookie)
}

func (c *Context) Redirect(url string) error {
    http.Redirect(c, c.R, url, http.StatusSeeOther)
    return nil
}

func (c *Context) QueryParam(key string) string {
    return c.R.URL.Query().Get(key)
}

func (c *Context) Status() int {
    // If using httptest.ResponseRecorder, get status from it
    if recorder, ok := c.w.(*httptest.ResponseRecorder); ok {
        return recorder.Code
    }
    return c.status
}

func (c *Context) HeaderValue(key string) string {
    return c.R.Header.Get(key)
}

// Lifetime returns the duration since creation.
func (c *Context) Lifetime() time.Duration {
    return time.Since(c.createdAt)
}

func (c *Context) Bytes() int {
    return c.bytes
}

Limitations

There are many benefits of this new Handler/Context pattern, but there are also a few limitations to be aware of:

  1. As mentioned, a Handler must always return its errors through ctx.Error(); a bare returned error is silently dropped.
  2. If any middleware replaces the Context with some other http.ResponseWriter, the state breaks at that point: handlers below it get a fresh Context, so state set above it is not visible to them, and errors or values they record never reach the middlewares above.
  3. Boilerplate. While this really isn't a lot of code, there is no good way to extract it into its own package. The usefulness of Context comes from tailoring it to each project, so the code ends up copy-pasted between projects.