GoMLX: ML in Go without Python



In the previous post I talked about running ML inference in Go through a Python sidecar process. In this post, let's see how we can accomplish the same tasks without using Python at all.

How ML models are implemented

Let's start with a brief overview of how ML models are implemented under the hood [1]. The model is typically written in Python, using one of the ML frameworks like TensorFlow, JAX or PyTorch. The framework takes care of at least two high-level concerns for developers:

  • An expressive way to describe the model architecture, including auto-differentiation for training.
  • Efficient implementation of computational primitives on common HW: CPUs, GPUs and TPUs.

In between these two concerns there exists a standardized model definition format (or several) that helps multiple tools interoperate. While it's by no means the only solution [2], let's look at the OpenXLA stack as a way to run models on diverse hardware:

OpenXLA architectural diagram, with a gopher
  • The top layer consists of the frameworks that provide high-level primitives to define ML models and translate them to a common interchange format called StableHLO (where "HLO" stands for High-Level Operations). I've added the gopher on the very right - it will soon become clear why.
  • The bottom layer is the HW that executes these models efficiently.
  • In the middle is the OpenXLA system, which includes two major components: the XLA compiler translating HLO to HW machine code, and PJRT - the runtime component responsible for managing HW devices, moving data (tensors) between the host CPU and these devices, executing tasks, sharding and so on.
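To make the idea of an interchange format a bit more concrete, here is a toy sketch in plain Go (not GoMLX code): a framework-side builder records operations in a graph, and a "lowering" pass emits them as an SSA-style listing whose op names are borrowed from StableHLO. The names Node, Param, Add, Mul and Lower are hypothetical, and the output only loosely resembles StableHLO's textual form; it is not valid input for XLA.

```go
package main

import (
	"fmt"
	"strings"
)

// Node is a toy graph node: an op name plus operands. Parameters use the op
// "param" and carry a name instead of operands.
type Node struct {
	op   string
	args []*Node
	name string
}

func Param(name string) *Node { return &Node{op: "param", name: name} }
func Add(a, b *Node) *Node    { return &Node{op: "stablehlo.add", args: []*Node{a, b}} }
func Mul(a, b *Node) *Node    { return &Node{op: "stablehlo.multiply", args: []*Node{a, b}} }

// Lower walks the graph depth-first, emitting operands before their users,
// and gives every computed node an SSA value name (%0, %1, ...). Nodes that
// are shared by several users are emitted only once.
func Lower(out *Node) string {
	var sb strings.Builder
	ids := map[*Node]string{}
	next := 0
	var visit func(n *Node) string
	visit = func(n *Node) string {
		if id, ok := ids[n]; ok {
			return id
		}
		if n.op == "param" {
			ids[n] = "%" + n.name
			return ids[n]
		}
		operands := make([]string, len(n.args))
		for i, a := range n.args {
			operands[i] = visit(a)
		}
		id := fmt.Sprintf("%%%d", next)
		next++
		ids[n] = id
		fmt.Fprintf(&sb, "%s = %s %s : tensor<f32>\n", id, n.op, strings.Join(operands, ", "))
		return id
	}
	result := visit(out)
	fmt.Fprintf(&sb, "return %s\n", result)
	return sb.String()
}

func main() {
	// f(x, y) = (x + y) * x. Prints:
	//   %0 = stablehlo.add %x, %y : tensor<f32>
	//   %1 = stablehlo.multiply %0, %x : tensor<f32>
	//   return %1
	x, y := Param("x"), Param("y")
	fmt.Print(Lower(Mul(Add(x, y), x)))
}
```

The point of the split is that everything below the listing (compilation, device management) never needs to know which frontend produced it.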

There's a huge amount of complexity hidden by the bottom layers of this diagram: efficient compilation and code generation for diverse HW (including the use of fixed blocks and libraries like cuDNN), runtime management, etc. All of this is something one shouldn't try to re-implement unless there's a really, really good reason to do so. And the best part? There's no Python there; this is C and C++. Python only exists on the upper layer, in the high-level ML frameworks.

GoMLX

GoMLX is a relatively new Go package for ML that deserves some attention. GoMLX slots in as one of the frameworks, exactly where the Gopher is in the diagram above [3]. This is absolutely the right approach to the problem. There's no point in re-implementing the low-level primitives - whatever works for TF and JAX will work for Go as well! Google, NVIDIA, Intel and several other companies invest huge resources into these systems, and it's a good idea to benefit from these efforts.

In this post I will showcase re-implementations of some of the samples from the previous post, but with no Python in sight. But first, a few words about what GoMLX does.

GoMLX should be familiar if you've used one of the popular Python ML frameworks. You build a computational graph representing your model - the usual operations are supported and sufficient to implement anything from linear regression to cutting-edge transformers. Since GoMLX wraps XLA, it has access to all the same building blocks TF and JAX use (and it adds its own higher-level primitives, similarly to the Python frameworks).

GoMLX supports automatic differentiation to create the backward propagation operations required to update weights in training. It also provides many helpers for training and keeping track of progress, as well as Jupyter notebook support.
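GoMLX builds the backward operations as part of the XLA computation graph. As a reminder of what reverse-mode automatic differentiation computes, here is a minimal plain-Go sketch on scalar values; it is not GoMLX's API, and Value, Const, Add, Mul and Backward are hypothetical names. Each node records the local derivative with respect to its inputs, and Backward applies the chain rule from the output back to the inputs.

```go
package main

import "fmt"

// Value is a scalar node in a toy computational graph. Besides its value, it
// records its parents and the local derivative with respect to each of them.
type Value struct {
	data       float64
	grad       float64 // d(output)/d(this), filled in by Backward
	parents    []*Value
	localGrads []float64 // localGrads[i] = d(this)/d(parents[i])
}

func Const(x float64) *Value { return &Value{data: x} }

func Add(a, b *Value) *Value {
	return &Value{data: a.data + b.data, parents: []*Value{a, b}, localGrads: []float64{1, 1}}
}

func Mul(a, b *Value) *Value {
	return &Value{data: a.data * b.data, parents: []*Value{a, b}, localGrads: []float64{b.data, a.data}}
}

// Backward applies the chain rule from out back to every node it depends on,
// visiting nodes in reverse topological order (reverse-mode autodiff).
func Backward(out *Value) {
	var order []*Value
	seen := map[*Value]bool{}
	var visit func(v *Value)
	visit = func(v *Value) {
		if seen[v] {
			return
		}
		seen[v] = true
		for _, p := range v.parents {
			visit(p)
		}
		order = append(order, v)
	}
	visit(out)
	out.grad = 1
	for i := len(order) - 1; i >= 0; i-- {
		v := order[i]
		for j, p := range v.parents {
			p.grad += v.localGrads[j] * v.grad
		}
	}
}

func main() {
	// f(w, b) = (w*x + b)^2 with x fixed at 3.
	x, w, b := Const(3), Const(2), Const(1)
	y := Add(Mul(w, x), b)
	f := Mul(y, y)
	Backward(f)
	fmt.Println(f.data, w.grad, b.grad) // 49 42 14
}
```

The gradients match the analytic ones: df/dw = 2(wx+b)·x = 42 and df/db = 2(wx+b) = 14. A training step would then move w and b against these gradients. Gradients accumulate, so they must be reset before another backward pass.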

An image model for the CIFAR-10 dataset with GoMLX

In the previous post we built a CNN (convolutional neural network) model using TF+Keras in Python, and ran its inference in a sidecar process we could control from Go.

Here, let's build a similar model in Go, without using Python at all; we'll be training it on the same CIFAR-10 dataset we've used before.

CIFAR-10 dataset sample

The full code for this sample is here; it is heavily based on GoMLX's own example, with some modifications for simplicity and clarity. Here's the code defining the model graph:

func C10ConvModel(mlxctx *mlxcontext.Context, spec any, inputs []*graph.Node) []*graph.Node {
  batchedImages := inputs[0]
  g := batchedImages.Graph()
  dtype := batchedImages.DType()
  batchSize := batchedImages.Shape().Dimensions[0]
  logits := batchedImages

  layerIdx := 0
  nextCtx := func(name string) *mlxcontext.Context {
    newCtx := mlxctx.Inf("%03d_%s", layerIdx, name)
    layerIdx++
    return newCtx
  }

  // Convolution / activation layers
  logits = layers.Convolution(nextCtx("conv"), logits).Filters(32).KernelSize(3).PadSame().Done()
  logits.AssertDims(batchSize, 32, 32, 32)
  logits = activations.Relu(logits)
  logits = layers.Convolution(nextCtx("conv"), logits).Filters(32).KernelSize(3).PadSame().Done()
  logits = activations.Relu(logits)
  logits = graph.MaxPool(logits).Window(2).Done()
  logits = layers.DropoutNormalize(nextCtx("dropout"), logits, graph.Scalar(g, dtype, 0.3), true)
  logits.AssertDims(batchSize, 16, 16, 32)

  logits = layers.Convolution(nextCtx("conv"), logits).Filters(64).KernelSize(3).PadSame().Done()
  logits.AssertDims(batchSize, 16, 16, 64)
  logits = activations.Relu(logits)
  logits = layers.Convolution(nextCtx("conv"), logits).Filters(64).KernelSize(3).PadSame().Done()
  logits.AssertDims(batchSize, 16, 16, 64)
  logits = activations.Relu(logits)
  logits = graph.MaxPool(logits).Window(2).Done()
  logits = layers.DropoutNormalize(nextCtx("dropout"), logits, graph.Scalar(g, dtype, 0.5), true)
  logits.AssertDims(batchSize, 8, 8, 64)

  logits = layers.Convolution(nextCtx("conv"), logits).Filters(128).KernelSize(3).PadSame().Done()
  logits.AssertDims(batchSize, 8, 8, 128)
  logits = activations.Relu(logits)
  logits = layers.Convolution(nextCtx("conv"), logits).Filters(128).KernelSize(3).PadSame().Done()
  logits.AssertDims(batchSize, 8, 8, 128)
  logits = activations.Relu(logits)
  logits = graph.MaxPool(logits).Window(2).Done()
  logits = layers.DropoutNormalize(nextCtx("dropout"), logits, graph.Scalar(g, dtype, 0.5), true)
  logits.AssertDims(batchSize, 4, 4, 128)

  // Flatten logits, and apply dense layer
  logits = graph.Reshape(logits, batchSize, -1)
  logits = layers.Dense(nextCtx("dense"), logits, true, 128)
  logits = activations.Relu(logits)
  logits = layers.DropoutNormalize(nextCtx("dropout"), logits, graph.Scalar(g, dtype, 0.5), true)
  numClasses := 10
  logits = layers.Dense(nextCtx("dense"), logits, true, numClasses)
  return []*graph.Node{logits}
}
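The AssertDims calls document how the activations' shape evolves. Here is a small plain-Go trace of that arithmetic (the helper names are hypothetical): PadSame keeps the spatial size, and each max-pool with a window of 2 halves it, as the asserted 32 → 16 → 8 → 4 progression confirms. After three blocks, Reshape(logits, batchSize, -1) therefore yields 4·4·128 = 2048 features per image for the first Dense layer.

```go
package main

import "fmt"

// Shape is the (height, width, channels) shape of a single image; the batch
// dimension is left out because no layer in the model changes it.
type Shape struct{ H, W, C int }

// convSame: a stride-1 convolution with "same" padding keeps H and W and
// sets the channel count to the number of filters.
func convSame(s Shape, filters int) Shape { return Shape{s.H, s.W, filters} }

// maxPool: pooling with a window of w and a stride of w divides H and W by w.
func maxPool(s Shape, w int) Shape { return Shape{s.H / w, s.W / w, s.C} }

// trace follows C10ConvModel's three conv-conv-pool blocks. Dropout does not
// change shapes, so it is skipped.
func trace(s Shape, filters []int) Shape {
	for _, f := range filters {
		s = maxPool(convSame(convSame(s, f), f), 2)
		fmt.Printf("after %3d-filter block: %dx%dx%d\n", f, s.H, s.W, s.C)
	}
	return s
}

func main() {
	out := trace(Shape{32, 32, 3}, []int{32, 64, 128}) // 32x32 RGB input
	// flattened features per image: 2048
	fmt.Println("flattened features per image:", out.H*out.W*out.C)
}
```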

As you might expect, the Go code is longer and more explicit (nodes are threaded explicitly between builder calls, instead of being magically accumulated). It's not hard to envision a Keras-like high-level library on top of this.
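As a rough idea of what such a library could look like, here is a sketch of a Keras-style Sequential combinator using Go generics. Layer and Sequential are hypothetical names, not part of GoMLX; in a GoMLX-based version T would be *graph.Node, while here the "layers" operate on ints so the sketch runs on its own.

```go
package main

import "fmt"

// Layer is any function from a value to a value of the same type.
type Layer[T any] func(T) T

// Sequential composes layers left to right, like Keras' Sequential model.
// With no layers it returns the identity function.
func Sequential[T any](layers ...Layer[T]) Layer[T] {
	return func(x T) T {
		for _, l := range layers {
			x = l(x)
		}
		return x
	}
}

func main() {
	// Stand-in "layers" operating on ints, just to show the composition.
	double := Layer[int](func(x int) int { return 2 * x })
	inc := Layer[int](func(x int) int { return x + 1 })
	model := Sequential(double, inc, double)
	fmt.Println(model(5)) // (5*2 + 1) * 2 = 22
}
```

In GoMLX, each layer also needs its own sub-context for its variables (what nextCtx does in C10ConvModel), so a real wrapper would have to manage that as well.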

Here's a snippet from the classifier (inference):

func main() {
  flagCheckpoint := flag.String("checkpoint", "", "Directory to load checkpoint from")
  flag.Parse()

  mlxctx := mlxcontext.New()
  backend := backends.New()

  _, err := checkpoints.Load(mlxctx).Dir(*flagCheckpoint).Done()
  if err != nil {
    panic(err)
  }
  mlxctx = mlxctx.Reuse() // helps sanity check the loaded context
  exec := mlxcontext.NewExec(backend, mlxctx.In("model"), func(mlxctx *mlxcontext.Context, image *graph.Node) *graph.Node {
    // Convert our image to a tensor with batch dimension of size 1, and pass
    // it to the C10ConvModel graph.
    image = graph.ExpandAxes(image, 0) // Create a batch dimension of size 1.
    logits := cnnmodel.C10ConvModel(mlxctx, nil, []*graph.Node{image})[0]
    // Take the class with highest logit value, then remove the batch dimension.
    choice := graph.ArgMax(logits, -1, dtypes.Int32)
    return graph.Reshape(choice)
  })

  // classify takes a 32x32 image and returns a CIFAR-10 classification according
  // to the model. Use C10Labels to convert the returned class to a string
  // name. The returned class is from 0 to 9.
  classify := func(img image.Image) int32 {
    input := images.ToTensor(dtypes.Float32).Single(img)
    outputs := exec.Call(input)
    classID := tensors.ToScalar[int32](outputs[0])
    return classID
  }

  // ...

Now classify is a function that takes an image.Image and runs it through the network, returning the index of the most likely label out of the list of CIFAR-10 labels.
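The post-processing happens inside the graph passed to NewExec: ExpandAxes adds a batch dimension of size 1, ArgMax picks the class with the highest logit, and Reshape drops the batch dimension again. Spelled out in plain Go for a single image's logits, that last step looks like the sketch below. The cifar10Labels array is a stand-in for the sample's C10Labels (assumed to use the standard CIFAR-10 class order), and the logits are made up.

```go
package main

import "fmt"

// cifar10Labels lists the CIFAR-10 classes in the dataset's standard order.
var cifar10Labels = [10]string{
	"airplane", "automobile", "bird", "cat", "deer",
	"dog", "frog", "horse", "ship", "truck",
}

// argMax returns the index of the largest value (the first one on ties),
// like graph.ArgMax over the last axis for a single image. logits must be
// non-empty.
func argMax(logits []float32) int32 {
	best := 0
	for i, v := range logits {
		if v > logits[best] {
			best = i
		}
	}
	return int32(best)
}

func main() {
	// Hypothetical logits for one image (batch dimension already removed).
	logits := []float32{0.1, -1.2, 0.3, 2.5, 0.0, 1.9, -0.4, 0.2, 0.7, -2.0}
	class := argMax(logits)
	fmt.Println(class, cifar10Labels[class]) // 3 cat
}
```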

The README file in the sample explains how to run it locally on a GPU; the model trains and runs successfully, with results similar to the TF+Keras model we trained in Python earlier.

Gemma2 with GoMLX

For a (much) more involved example, GoMLX has a full implementation of Gemma2 inference. The model implementation itself is in the transformers package. It should look fairly familiar if you've seen a transformer implementation in another language.

The official example in that repository shows how to run it with weights downloaded from HuggingFace; since I've already downloaded the Gemma2 weights from Kaggle for the previous post, here's a simple adaptation:

var (
  flagDataDir   = flag.String("data", "", "dir with converted weights")
  flagVocabFile = flag.String("vocab", "", "tokenizer vocabulary file")
)

func main() {
  flag.Parse()
  ctx := context.New()

  // Load model weights from the checkpoint downloaded from Kaggle.
  err := kaggle.ReadConvertedWeights(ctx, *flagDataDir)
  if err != nil {
    log.Fatal(err)
  }

  // Load tokenizer vocabulary.
  vocab, err := sentencepiece.NewFromPath(*flagVocabFile)
  if err != nil {
    log.Fatal(err)
  }

  // Create a Gemma sampler and start sampling tokens.
  sampler, err := samplers.New(backends.New(), ctx, vocab, 256)
  if err != nil {
    log.Fatalf("%+v", err)
  }

  start := time.Now()
  output, err := sampler.Sample([]string{
    "Are bees and wasps similar?",
  })
  if err != nil {
    log.Fatalf("%+v", err)
  }
  fmt.Printf("\tElapsed time: %s\n", time.Since(start))
  fmt.Printf("Generated text:\n%s\n", strings.Join(output, "\n\n"))
}

The complete code together with installation and setup instructions is here.

gomlx/gemma demonstrates that GoMLX has sufficiently advanced capabilities to run a real production-grade open LLM, without Python in the loop.
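The sampler.Sample call in the Gemma2 program above hides the autoregressive loop that produces text one token at a time: run the model on the tokens so far, pick the next token, append it, repeat. The plain-Go sketch below shows the simplest (greedy) version of that loop. nextTokenFn and greedyDecode are hypothetical names, the "model" is a toy, and this is not how the gomlx/gemma sampler is implemented; real samplers usually add things like temperature sampling and caching of attention keys/values.

```go
package main

import "fmt"

// nextTokenFn is a stand-in for a model forward pass: given the tokens so
// far, it returns one score (logit) per vocabulary entry.
type nextTokenFn func(tokens []int) []float64

// greedyDecode appends the highest-scoring token one step at a time until
// eos is produced or maxNew tokens have been generated. It returns only the
// newly generated tokens.
func greedyDecode(model nextTokenFn, prompt []int, eos, maxNew int) []int {
	tokens := append([]int(nil), prompt...)
	for i := 0; i < maxNew; i++ {
		logits := model(tokens)
		best := 0
		for j, l := range logits {
			if l > logits[best] {
				best = j
			}
		}
		if best == eos {
			break
		}
		tokens = append(tokens, best)
	}
	return tokens[len(prompt):]
}

func main() {
	// Toy "model" over a 5-token vocabulary: it predicts last+1, and token 4
	// plays the role of end-of-sequence.
	toy := func(tokens []int) []float64 {
		logits := make([]float64, 5)
		logits[(tokens[len(tokens)-1]+1)%5] = 1
		return logits
	}
	fmt.Println(greedyDecode(toy, []int{1}, 4, 10)) // [2 3]
}
```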

Summary

The previous post discussed some options for incorporating ML inference into a Go project via a minimal Python sidecar process. Here, we take it a step further and implement ML inference in Go without using Python. We do so by leveraging GoMLX, which itself relies on XLA and PJRT to do the heavy lifting.

If we strip down a framework like TensorFlow to its layers, GoMLX reuses the bottom layers (which is where most of the magic lies), and replaces the model builder library with a Go variant.

Since GoMLX is still a relatively new project, it may be a little risky for production use at this point. That said, I find this direction very promising and will be following the project's development with interest.

Code

The full code for the samples in this post is on GitHub.


[1] This assumes you know the basics of neural network graphs, their training, etc. If not, check out this post and some of my other posts in the Machine Learning category.
[2] It's likely the most common production solution, and pretty much the only way to access Google's TPUs.
[3] It does so by including Go bindings for both XLA and PJRT; these are wrapped in higher-level APIs for users.
