CORS with Go and Negroni

There are some pieces that you need to put in every microservice you write: for instance logging, error handling and authentication.

Over the last year, I found myself writing CORS headers over and over. This made me think that I should use a Negroni middleware, since we already use Negroni for other middlewares. I looked online for an existing one and found a bunch, but I was not happy with any of them, so I decided to write my own.

I wanted to write something small (tens of lines of code) that still had a default behavior (every origin is denied unless it matches *.example.com) and allowed overriding that behavior where needed.

The first thing we will need is a CORSMiddleware struct:

type CORSMiddleware struct {
        OriginRule string
}

The struct has a single OriginRule field of type string, which holds either the default rule or a custom one.

Then let’s create a builder that accepts a custom origin. Strictly speaking it is not required, since you could instantiate the CORSMiddleware struct from outside the package, but I wanted to provide a simpler API for it:

func NewOrigin(origin string) *CORSMiddleware {
        return &CORSMiddleware{
                OriginRule: origin,
        }
}

This function returns a pointer to a CORSMiddleware with the given OriginRule. In the same way, I’ve created a builder that returns one with the default origin rule:

func New() *CORSMiddleware {
        return &CORSMiddleware{
                // Anchored so that hosts such as app.example.com.attacker.net are rejected.
                OriginRule: "^.*\\.example\\.com$",
        }
}
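Because regexp.MatchString reports a match anywhere in the string rather than a full match, a rule like `.*\.example\.com` without `^` and `$` anchors also accepts any host that merely contains `.example.com`. The sketch below compares it with an anchored version; `app.example.com.attacker.net` is a made-up hostile origin and the `matches` helper is hypothetical.

```go
package main

import (
	"fmt"
	"regexp"
)

// matches (hypothetical helper) reports whether origin satisfies rule using
// regexp.MatchString; an invalid rule counts as "no match".
func matches(rule, origin string) bool {
	ok, err := regexp.MatchString(rule, origin)
	return err == nil && ok
}

const (
	// Unanchored: matches ".example.com" anywhere in the Origin value.
	unanchored = `.*\.example\.com`
	// Anchored: the whole Origin value must end in ".example.com".
	anchored = `^.*\.example\.com$`
)

func main() {
	origins := []string{
		"https://app.example.com",
		"https://app.example.com.attacker.net", // made-up hostile origin
		"https://example.com",
	}
	for _, o := range origins {
		fmt.Printf("%-40s unanchored=%v anchored=%v\n",
			o, matches(unanchored, o), matches(anchored, o))
	}
}
```

Neither version matches the bare `https://example.com`, which is consistent with a `*.example.com` rule.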

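Due to how CORS works, before sending certain cross-origin requests (for example a PUT, or a request carrying an Authorization header) the browser sends an OPTIONS preflight request, and the server has to answer whether the actual request would be allowed. To decide, we match the Origin header against the rule’s regular expression. Regular expressions are relatively expensive, though, and when the rule is the wildcard (*) the check is pointless. On top of that, a bare * is not a valid Go regular expression, so regexp.MatchString would return an error and the origin would be denied. For those reasons, I wrote the following function to implement this logic: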

func (m *CORSMiddleware) allowedOrigin(origin string) bool {
        if m.OriginRule == "*" {
                return true
        }
        if matched, _ := regexp.MatchString(m.OriginRule, origin); matched {
                return true
        }
        return false
}
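Even with the wildcard shortcut, regexp.MatchString compiles the pattern again on every call. A minimal sketch, assuming we are free to add an unexported `originRe` field (not part of the original struct), compiles the rule once in the builder with regexp.MustCompile:

```go
package main

import (
	"fmt"
	"regexp"
)

type CORSMiddleware struct {
	OriginRule string
	// originRe is an assumption of this sketch: the compiled OriginRule,
	// left nil when the rule is the "*" wildcard.
	originRe *regexp.Regexp
}

// NewOrigin compiles the rule once; regexp.MustCompile panics on an
// invalid pattern, so a bad rule is caught at startup.
func NewOrigin(origin string) *CORSMiddleware {
	m := &CORSMiddleware{OriginRule: origin}
	if origin != "*" {
		m.originRe = regexp.MustCompile(origin)
	}
	return m
}

// New keeps the default *.example.com rule, anchored here as a suggestion.
func New() *CORSMiddleware {
	return NewOrigin(`^.*\.example\.com$`)
}

func (m *CORSMiddleware) allowedOrigin(origin string) bool {
	if m.OriginRule == "*" {
		return true
	}
	// No per-request compilation: reuse the precompiled expression.
	return m.originRe.MatchString(origin)
}

func main() {
	m := NewOrigin(`^https://.*\.example\.com$`)
	fmt.Println(m.allowedOrigin("https://app.example.com")) // true
	fmt.Println(m.allowedOrigin("https://example.org"))     // false
}
```

The catch is that OriginRule stays exported: a CORSMiddleware built as a struct literal would skip the compile step, leave originRe nil and panic on the first non-wildcard check, so unexporting the field is worth considering.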

Since our goal is a Negroni middleware, our type needs to implement Negroni’s Handler interface:

type Handler interface {
  ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}
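The extra `next` parameter is what distinguishes this interface from http.Handler: a middleware calls `next` to continue down the chain, or returns without calling it to end the request there. The sketch below is a simplified model of that dispatch, not Negroni’s actual implementation; `Handler` and `HandlerFunc` mirror Negroni’s types so the code compiles with the standard library alone, and `chain` and `run` are hypothetical helpers.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// Handler mirrors the Negroni interface quoted above.
type Handler interface {
	ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}

// HandlerFunc lets a plain function satisfy Handler.
type HandlerFunc func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)

func (f HandlerFunc) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	f(rw, r, next)
}

// chain wraps final in the middlewares, first one outermost: each middleware
// receives a next function that runs the rest of the chain.
func chain(final http.Handler, mws ...Handler) http.Handler {
	h := final.ServeHTTP
	for i := len(mws) - 1; i >= 0; i-- {
		mw, next := mws[i], h
		h = func(w http.ResponseWriter, r *http.Request) { mw.ServeHTTP(w, r, next) }
	}
	return http.HandlerFunc(h)
}

// run sends one request with the given method through a chain whose single
// middleware stops OPTIONS requests, and reports whether the final handler ran.
func run(method string) bool {
	reached := false
	final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reached = true })
	stopOnOptions := HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		if r.Method == http.MethodOptions {
			return // answered here: the rest of the chain never runs
		}
		next(w, r)
	})
	chain(final, stopOnOptions).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(method, "/", nil))
	return reached
}

func main() {
	fmt.Println(run(http.MethodGet))     // true
	fmt.Println(run(http.MethodOptions)) // false
}
```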

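So we need a ServeHTTP method that matches this signature. When it is called for an OPTIONS (preflight) request, it runs allowedOrigin, sets the response headers accordingly and returns without calling the next handler; for any other request, it sets the CORS headers and passes the request on to next.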

func (m *CORSMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
        origin := r.Header.Get("Origin")
        // The CORS headers depend on the Origin header, so caches must key on it.
        w.Header().Add("Vary", "Origin")
        if r.Method == http.MethodOptions {
                // Preflight: answer it here without calling the next handler.
                if m.allowedOrigin(origin) {
                        w.Header().Set("Access-Control-Allow-Origin", origin)
                        w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PATCH, PUT, DELETE")
                        w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, ResponseType")
                }
                return
        }
        // Actual request: expose the response only to allowed origins instead of
        // sending "*", which would bypass the origin rule. Access-Control-Allow-Headers
        // is only read on preflight responses, so it is not set here.
        if m.allowedOrigin(origin) {
                w.Header().Set("Access-Control-Allow-Origin", origin)
        }
        next(w, r)
}
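To see the middleware in action without starting a server, net/http/httptest can drive ServeHTTP directly. The sketch below copies the middleware so it runs on its own, assuming the variant in which actual (non-OPTIONS) requests also go through allowedOrigin instead of receiving `Access-Control-Allow-Origin: *`; the `probe` helper, the `/items` path and the origins are made up for the example.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"regexp"
)

type CORSMiddleware struct {
	OriginRule string
}

func (m *CORSMiddleware) allowedOrigin(origin string) bool {
	if m.OriginRule == "*" {
		return true
	}
	matched, _ := regexp.MatchString(m.OriginRule, origin)
	return matched
}

func (m *CORSMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	origin := r.Header.Get("Origin")
	w.Header().Add("Vary", "Origin")
	if r.Method == http.MethodOptions {
		// Preflight: answered here (Allow-Headers omitted for brevity).
		if m.allowedOrigin(origin) {
			w.Header().Set("Access-Control-Allow-Origin", origin)
			w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PATCH, PUT, DELETE")
		}
		return
	}
	// Actual request: echo the origin only when the rule allows it.
	if m.allowedOrigin(origin) {
		w.Header().Set("Access-Control-Allow-Origin", origin)
	}
	next(w, r)
}

// probe (hypothetical helper) sends one request through the middleware and
// returns the Access-Control-Allow-Origin value and whether next was reached.
func probe(m *CORSMiddleware, method, origin string) (string, bool) {
	reached := false
	next := func(w http.ResponseWriter, r *http.Request) { reached = true }
	req := httptest.NewRequest(method, "/items", nil)
	req.Header.Set("Origin", origin)
	rec := httptest.NewRecorder()
	m.ServeHTTP(rec, req, next)
	return rec.Header().Get("Access-Control-Allow-Origin"), reached
}

func main() {
	m := &CORSMiddleware{OriginRule: `^https://.*\.example\.com$`}
	cases := []struct{ method, origin string }{
		{http.MethodOptions, "https://app.example.com"},
		{http.MethodOptions, "https://evil.test"},
		{http.MethodGet, "https://app.example.com"},
		{http.MethodGet, "https://evil.test"},
	}
	for _, c := range cases {
		allow, reached := probe(m, c.method, c.origin)
		fmt.Printf("%-7s %-25s allow-origin=%q next=%v\n", c.method, c.origin, allow, reached)
	}
}
```

A denied origin still reaches the handler on an actual request; it is the missing Access-Control-Allow-Origin header that makes the browser withhold the response from the calling page.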

Putting it all together, the resulting code is:

package cors

import (
        "net/http"
        "regexp"
)

type CORSMiddleware struct {
        OriginRule string
}

func NewOrigin(origin string) *CORSMiddleware {
        return &CORSMiddleware{
                OriginRule: origin,
        }
}

func New() *CORSMiddleware {
        return &CORSMiddleware{
                // Anchored so that hosts such as app.example.com.attacker.net are rejected.
                OriginRule: "^.*\\.example\\.com$",
        }
}

func (m *CORSMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
        origin := r.Header.Get("Origin")
        // The CORS headers depend on the Origin header, so caches must key on it.
        w.Header().Add("Vary", "Origin")
        if r.Method == http.MethodOptions {
                // Preflight: answer it here without calling the next handler.
                if m.allowedOrigin(origin) {
                        w.Header().Set("Access-Control-Allow-Origin", origin)
                        w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PATCH, PUT, DELETE")
                        w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, ResponseType")
                }
                return
        }
        // Actual request: expose the response only to allowed origins instead of
        // sending "*", which would bypass the origin rule. Access-Control-Allow-Headers
        // is only read on preflight responses, so it is not set here.
        if m.allowedOrigin(origin) {
                w.Header().Set("Access-Control-Allow-Origin", origin)
        }
        next(w, r)
}

func (m *CORSMiddleware) allowedOrigin(origin string) bool {
        if m.OriginRule == "*" {
                return true
        }
        if matched, _ := regexp.MatchString(m.OriginRule, origin); matched {
                return true
        }
        return false
}