
Inject DB connections in Golang gRPC API
March 21, 2022
One of the first issues I had to solve when I started using gRPC was how to inject a DB connection pool into the functions handling requests. Injection is needed because creating a new SQL connection for every gRPC request (and tearing it down at the end) wastes a great deal of resources. This approach could also limit the scalability of the API, since the database probably accepts only a limited number of concurrent connections.
There are different ways of doing this, and some people would consider this solution "dirty", since it leverages Go's `context` to pass the SQL connection pool to the function.
Despite this, I (and many others) do not see any practical issue with this practice.
If you see practical issues, let me know!
The first element you will need is a connection pool in your `main` function.
In my case, I'll be using GORM since that is what I usually use, but any other library would be similar:
```go
dbSession, err := gorm.Open(mysql.Open("[DBC_STRING]"))
if err != nil {
	panic(err)
}
```
You will need to substitute `[DBC_STRING]` with the correct value for your environment or with a function that returns it.
You can now move on to defining a Go context key. This step is not strictly needed, since `context.WithValue` also accepts plain strings as keys, but that is discouraged because of potential collisions between packages.
```go
type contextKey string

const (
	DBSession contextKey = "dbSession"
)
```
If you create an application composed of many packages, you should move the `contextKey` definition and the constant definitions to a separate package to prevent circular dependencies.
You can now proceed to define the two interceptors: one for unary calls and the other for streaming calls.
```go
// DBUnaryServerInterceptor stores session in the context of every unary call.
func DBUnaryServerInterceptor(session *gorm.DB) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		return handler(context.WithValue(ctx, DBSession, session), req)
	}
}

// DBStreamServerInterceptor does the same for streaming calls by wrapping the
// stream so that its Context() method returns the enriched context.
func DBStreamServerInterceptor(session *gorm.DB) grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		wrapped := grpc_middleware.WrapServerStream(stream)
		wrapped.WrappedContext = context.WithValue(stream.Context(), DBSession, session)
		return handler(srv, wrapped)
	}
}
```
These interceptors store the GORM connection pool in the context under the previously declared key.
The same code also works with `sql.DB` connections if you prefer not to use GORM: just replace `*gorm.DB` with `*sql.DB`.
You can now wire these interceptors into your gRPC server. To do so, you will need to define a gRPC server or, if you already have one, extend its definition with the following:
```go
gs := grpc.NewServer(
	grpc.ChainStreamInterceptor(
		DBStreamServerInterceptor(dbSession),
	),
	grpc.ChainUnaryInterceptor(
		DBUnaryServerInterceptor(dbSession),
	),
)
```
As their names suggest, `grpc.ChainStreamInterceptor` and `grpc.ChainUnaryInterceptor` are meant to chain multiple interceptors, so if you are already using them, you will just need to add these interceptors to the lists.
You have now ensured that the DB connection pool is always injected into the `context`.
To use the DB connection pool, you will need to extract it from the `context`, as follows:
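```go
dbSession, ok := ctx.Value(DBSession).(*gorm.DB)
if !ok || dbSession == nil {
	return nil, status.Error(codes.Internal, "no database connection found")
}
```
I always use the two-value ("comma ok") form of the type assertion and check that `dbSession` is not `nil` to ensure that everything worked properly.
The risk here is that `dbSession` was not properly injected. If the interceptor did not run at all, `ctx.Value` returns `nil`, and a single-value assertion such as `ctx.Value(DBSession).(*gorm.DB)` would itself panic before any `nil` check is reached; if a `nil` pointer was injected, using it later could cause an `invalid memory address or nil pointer dereference` error.
Either way, an unrecovered panic in a handler would crash your gRPC server. The snippet assumes `status` and `codes` are imported from `google.golang.org/grpc/status` and `google.golang.org/grpc/codes`.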
If you are using `sql.DB` instead of GORM, replace `*gorm.DB` with `*sql.DB`.
Putting all the code together, you'll end up with something like this:
```go
package main

import (
	"context"
	"crypto/tls"
	"log"
	"net/http"

	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
	"google.golang.org/grpc"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
)

type contextKey string

const (
	DBSession contextKey = "dbSession"
)

func main() {
	// Open the connection pool once and share it across all requests.
	dbSession, err := gorm.Open(mysql.Open("[DBC_STRING]"))
	if err != nil {
		panic(err)
	}

	gs := grpc.NewServer(
		grpc.ChainStreamInterceptor(
			DBStreamServerInterceptor(dbSession),
		),
		grpc.ChainUnaryInterceptor(
			DBUnaryServerInterceptor(dbSession),
		),
	)
	// Register your gRPC services on gs here.

	// Example key pair from Go's crypto/tls documentation: NOT safe to use (see below).
	certPem := []byte(`-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`)
	keyPem := []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)
	cert, err := tls.X509KeyPair(certPem, keyPem)
	if err != nil {
		log.Fatal(err)
	}

	// *grpc.Server implements http.Handler (marked experimental in grpc-go),
	// so net/http can serve it over TLS, where HTTP/2 is negotiated.
	server := &http.Server{
		Addr:      ":8433",
		TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}},
		Handler:   gs,
	}
	if err := server.ListenAndServeTLS("", ""); err != nil {
		panic(err)
	}
}

// DBUnaryServerInterceptor stores session in the context of every unary call.
func DBUnaryServerInterceptor(session *gorm.DB) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		return handler(context.WithValue(ctx, DBSession, session), req)
	}
}

// DBStreamServerInterceptor does the same for streaming calls.
func DBStreamServerInterceptor(session *gorm.DB) grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		wrapped := grpc_middleware.WrapServerStream(stream)
		wrapped.WrappedContext = context.WithValue(stream.Context(), DBSession, session)
		return handler(srv, wrapped)
	}
}
```
Be aware that this code will not run as is: you will need to replace `[DBC_STRING]` with the correct DB connection string to make it work.
Also, the TLS certificate used here is publicly available in Go's documentation (and its validity period ended in October 2018), so it must not be considered safe.
Therefore, before running this code in any environment, at least replace the certificate with one whose private key is not publicly available.