Let's Build a Codec!

By Mark Waks, AKA Justin du Coeur | November 8, 2019

Introduction

So there I am -- trying to add Redis (one of the relatively straightforward and useful in-memory cache systems) to my Scala / Play application.  I've settled on rediscala as the client layer: it is non-blocking, efficient, and nicely straightforward.

There's only one issue: the protocol they expose for caching data is a bit basic -- just a simple typeclass definition for serializing and deserializing, and not much more.  If you look at their example application, it's not exactly regular, using ad hoc delimiters for data structures, and serializing numbers as Strings, which is a little suspicious.  (What happens if you have several of these in a row, and the digits run into each other?  It only works if you add delimiters, which is kind of a waste if you are just storing a one-byte number.)
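To make the "digits run into each other" problem concrete, here is a tiny sketch. It is not taken from the rediscala example; `DigitAmbiguity`, `encode` and `encodeFixed` are made-up names, and the fixed-width variant only handles values from 0 to 99.

```scala
object DigitAmbiguity {
  // Naive approach: write each number as its decimal digits, with no delimiter.
  def encode(numbers: Seq[Int]): String =
    numbers.map(_.toString).mkString

  // Fixed-width alternative (illustration only, values 0 to 99): every number
  // takes exactly two characters, so the element boundaries are always known.
  def encodeFixed(numbers: Seq[Int]): String =
    numbers.map(n => f"$n%02d").mkString
}
```

With the naive version, `Seq(12, 3)` and `Seq(1, 23)` produce the same output, so a reader cannot tell them apart. Once every element has a known width, the ambiguity goes away without any delimiters.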

I want things to be a bit more regular and efficient for my application, and I want to lay down clear guidelines for folks to use going forward.  Which means it's time to write a codec! (Coder/Decoder -- something that serializes and deserializes to a format. This term is more often used at the binary level, but I find it a convenient shorthand for the general concept.)

This article is aimed at intermediate Scala engineers, who know the language but are still learning the idioms, and is going to work through the development of a simple codec from scratch.  Along the way, we're going to talk about:

  • How to design and build typeclasses
  • How to create typeclass instances for primitive types, more interesting classes, and collections
  • Using implicit classes to make syntax prettier
  • A little bit of codec basics
  • And other such fun stuff, with nice practical examples.  

This doesn't claim to be the one true way to manage serialization -- you should always examine your own needs, and use something well-suited to them.  But I hope it shows how to break down a problem with typeclasses, and that this stuff isn't terribly hard.

General Approach

For this codec, I decided to go delimiter-free, mostly so that numeric values can pack reasonably efficiently.  For the same reason, it is binary-oriented instead of String-oriented: rediscala puts everything into ByteStrings, and the underlying Redis system is being used in binary mode, so there is no reason to stringify things.

That, in turn, means that we want to know how long each element is.  There are generally two approaches to a serialization format: either the reader knows each element's length up front (because the type is fixed-size, or because the length is written ahead of the data), or you need delimiters to tell you when you have gotten to the end of an element.

(The distinction doesn't have to be hard-and-fast, but you should generally focus in one direction or the other and make sure that it makes sense for a mix of data.)

So in this code we're going to let fixed-size types (Int, Long, Short, etc) just be fixed-size, and anything that is variable in size will be preceded by its length.  That means that we can just deserialize from head to tail in a single pass, nice and quickly, without needing delimiters.
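As a rough illustration of that layout, here is a stand-alone sketch that uses only `java.nio.ByteBuffer` rather than rediscala or Akka. The record shape (an age followed by a name) and the names `LayoutSketch`, `encode` and `decode` are assumptions made for this example.

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets.UTF_8

object LayoutSketch {
  // Layout: [age: 2 bytes][name length: 2 bytes][name: that many UTF-8 bytes]
  def encode(age: Short, name: String): Array[Byte] = {
    val nameBytes = name.getBytes(UTF_8)
    require(nameBytes.length <= Short.MaxValue, "name too long for a Short length")
    val buf = ByteBuffer.allocate(2 + 2 + nameBytes.length)
    buf.putShort(age)                      // fixed-size: no length needed
    buf.putShort(nameBytes.length.toShort) // variable-size: length goes first
    buf.put(nameBytes)
    buf.array()
  }

  // Decode from head to tail in a single pass; nothing is scanned for delimiters.
  def decode(bytes: Array[Byte]): (Short, String) = {
    val buf = ByteBuffer.wrap(bytes)
    val age = buf.getShort()
    val len = buf.getShort()
    val nameBytes = new Array[Byte](len)
    buf.get(nameBytes)
    (age, new String(nameBytes, UTF_8))
  }
}
```

A three-letter ASCII name costs 2 + 2 + 3 = 7 bytes here, and the decoder always knows exactly how many bytes to read next.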

FixedLenFormatter (first try)

Step one is to define a first-draft API.  We're going to do this with a couple of typeclasses.

For those who are new to the idea of typeclasses, but who have some experience with object-oriented programming, think of a typeclass as a sort of interface.

Like a Java interface or a normal inherited Scala trait, it defines some abstract behavior that you want to provide for a bunch of classes.  But unlike the OO approach, you don't make those classes inherit from the interface -- instead, you provide a typeclass instance that says how to implement this interface for this class.

A full discussion of why typeclasses are great is outside the scope of this article, but this will show you a couple of simple typeclasses, and how to implement them for several types.
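If the idea is brand new, a minimal example that has nothing to do with serialization may help. `Describe` and `explain` below are invented for illustration; the point is only that `Int` and `String` gain behavior without inheriting from anything.

```scala
object TypeclassSketch {
  // The typeclass: an interface parameterized by the type it talks about.
  trait Describe[T] {
    def describe(value: T): String
  }

  // Instances live outside the described types; Int and String are untouched.
  implicit val describeInt: Describe[Int] = new Describe[Int] {
    def describe(value: Int): String = s"the integer $value"
  }
  implicit val describeString: Describe[String] = new Describe[String] {
    def describe(value: String): String = s"the string '$value'"
  }

  // Works for any T that has a Describe instance in implicit scope.
  def explain[T](value: T)(implicit d: Describe[T]): String =
    d.describe(value)
}
```

After `import TypeclassSketch._`, calling `explain(3)` picks `describeInt` and `explain("hi")` picks `describeString`; calling it with a type that has no instance is a compile error rather than a runtime failure.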

Each typeclass should define one set of related functions.  Rediscala gives us a starting point, two typeclasses that we are required to provide for each type.  Those are (simplified):

trait ByteStringSerializer[K] { self =>
  def serialize(data: K): ByteString
}
trait ByteStringDeserializer[T] { self =>
  def deserialize(bs: ByteString): T
}
trait ByteStringFormatter[T] 
  extends ByteStringSerializer[T] with ByteStringDeserializer[T]

That is, we need to provide a serialize function that takes a value and produces a ByteString, and a deserialize function that takes a ByteString and produces a value. Okay, that's straightforward enough.

We want to think in terms of serialization with well-defined lengths, so let's try this:

trait FixedLenFormatter[T] extends ByteStringFormatter[T] {
  // Given a ByteString, pull off a [[T]] from the front, and return that and
  // the number of bytes we pulled.
  def deserializeWithLen(bs: ByteString): (T, Int)
  // Push the given [[T]] onto the front of the (mutable) ByteStringBuilder.
  def serializeInto(builder: ByteStringBuilder, data: T): ByteStringBuilder
  // Adapt our typeclass to what Rediscala expects:
  override def deserialize(bs: ByteString): T = deserializeWithLen(bs)._1
  override def serialize(data: T): ByteString = {
    val builder = new ByteStringBuilder()
    serializeInto(builder, data)
    builder.result()
  }
}

That makes sense at first blush: we are providing Rediscala's ByteStringFormatter for each type we want, by redirecting from that to the functions we want to provide. deserializeWithLen() needs to say how many bytes this took (so we can skip to the next element), and serializeInto() uses a nicely-efficient ByteStringBuilder to build up the value we will be serializing.

(ByteStringBuilder is a mutable data structure, and should be used with care -- anything mutable is potentially dangerous in a multi-threaded program, and can make code harder to reason about. But it's appropriate for use in strictly single-threaded operations of limited scope like this: note how its lifespan is contained within Rediscala's call to serialize(), and it doesn't leak outside this limited API.)

Our First Few Instances

Now let's try using that on a couple of data types, to see how well it works in practice.  We're going to use Short as a building-block, so let's start with that:

object FixedLenFormatter {
  implicit object ShortFormatter extends FixedLenFormatter[Short] {
    override def deserializeWithLen(bs: ByteString): (Short, Int) = {
      val s: Short = bs.take(2).asByteBuffer.getShort(0)
      (s, 2)
    }
    override def serializeInto(
      builder: ByteStringBuilder, 
      data: Short
    ): ByteStringBuilder =
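      // putShort() takes an implicit java.nio.ByteOrder; it needs to be
      // BIG_ENDIAN to match the ByteBuffer default used when reading above.
      builder.putShort(data)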
  }
}

Not too bad. Now String makes use of that. Note that we're limiting Strings to 32k bytes of UTF-8 (which can be fewer than 32k characters) -- that should be plenty for our application's needs, and it saves a couple of bytes per String, since the length can be a Short:

implicit object StringFormatter extends FixedLenFormatter[String] {
  override def deserializeWithLen(bs: ByteString): (String, Int) = {
    val (len, lenLen) = implicitly[FixedLenFormatter[Short]].deserializeWithLen(bs)
    val str = bs.drop(lenLen).take(len).decodeString(StandardCharsets.UTF_8)
    (str, lenLen + len)
  }
  override def serializeInto(
    builder: ByteStringBuilder,
    str: String
  ): ByteStringBuilder = {
    val bytes: Array[Byte] = str.getBytes(StandardCharsets.UTF_8)
    val len = bytes.size
    // 32767 is the maximum bytes that can be sized by a signed Short. Ideally,
    // we would use a tagged String type that enforces the short length all over,
    // but we're not going to get into that here.
    if (len > 32767) throw new Exception("Overlong String!")
    implicitly[FixedLenFormatter[Short]].serializeInto(builder, len.toShort)
    builder.putBytes(bytes)
  }
}

Hmm. It's okay, but not great. Having to call drop() at the call site every time we call a deserialize function is going to get old, really fast, and that implicitly call (which summons the FixedLenFormatter for Short) is pretty boilerplatey. But it's a start.

When in Doubt, Refactor!

Let's tweak our API, to do that drop() step inside deserializeWithLen(), instead of at the call site, and tweak the signature accordingly:

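trait FixedLenFormatter[T] extends ByteStringFormatter[T] {
  // Given a ByteString, pull off a [[T]] from the front, and return that and
  // the rest of the ByteString.
  def deserializeWithLen(bs: ByteString): (T, ByteString)
  // serializeInto(), deserialize() and serialize() are unchanged.
}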

Now, deserializeWithLen() returns the start of the next bit of ByteString -- it's basically consuming some bytes from the front, and returning the next part. So our Short instance changes like this:

override def deserializeWithLen(bs: ByteString): (Short, ByteString) = {
  val s: Short = bs.take(2).asByteBuffer.getShort(0)
  (s, bs.drop(2))
}

We've just tweaked this to do the drop() step as part of deserializing, so the caller doesn't have to do it. And the String instance becomes a tad simpler and cleaner:

implicit object StringFormatter extends FixedLenFormatter[String] {
  override def deserializeWithLen(bs: ByteString): (String, ByteString) = {
    val (len, nextPos) = implicitly[FixedLenFormatter[Short]].deserializeWithLen(bs)
    val str = nextPos.take(len).decodeString(StandardCharsets.UTF_8)
    (str, nextPos.drop(len))
  }
}

That feels like we're moving in the right direction.
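To play with the "consume from the front, hand back the rest" shape without pulling in Akka, here is a stand-alone sketch. `Vector[Byte]` stands in for `ByteString`, and the function names are assumptions for this example, not part of the real formatter.

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets.UTF_8

object RestReturningSketch {
  // Assumption: Vector[Byte] plays the role of akka's ByteString here.
  type Bytes = Vector[Byte]

  // Each pull consumes bytes from the front and returns the remainder.
  def pullShort(bytes: Bytes): (Short, Bytes) = {
    val (head, rest) = bytes.splitAt(2)
    (ByteBuffer.wrap(head.toArray).getShort(), rest)
  }

  def pullString(bytes: Bytes): (String, Bytes) = {
    val (len, afterLen) = pullShort(bytes)
    val (strBytes, rest) = afterLen.splitAt(len)
    (new String(strBytes.toArray, UTF_8), rest)
  }

  def pushShort(value: Short): Bytes =
    ByteBuffer.allocate(2).putShort(value).array().toVector

  def pushString(value: String): Bytes = {
    val strBytes = value.getBytes(UTF_8)
    require(strBytes.length <= Short.MaxValue, "Overlong String!")
    pushShort(strBytes.length.toShort) ++ strBytes.toVector
  }
}
```

Because every pull returns the remaining bytes, the caller threads that value into the next pull and never calls `drop()` itself.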

Improving Ergonomics with Implicit Classes

Having to call implicitly[FixedLenFormatter[Short]] over and over again kind of sucks -- that's a lot of characters of boilerplate.  Can we do better?

The problem, really, is that we don't want to be thinking in terms of the formatter -- we just want to be pushing objects onto the end of the ByteStringBuilder when we are serializing, and pulling them off the front of the ByteString when we are deserializing.  We want new methods on those types, that just do the right thing.

Fortunately, Scala has an answer for this -- implicit classes.  The clunky term "implicit class" basically means "a block of extension methods for some other class".  (It will be replaced in Scala 3 by real extension methods, which should be a bit easier to use.) So the functionality we want for ByteString can be expressed like this:

implicit class RichByteString(bs: ByteString) {
  /**
    * Pull a T from the front of this [[ByteString]].
    *
    * @tparam T the type that we are deserializing from this [[ByteString]]
    * @return the pulled value, and the [[ByteString]] advanced to the next location
    */
  def pull[T : FixedLenFormatter]: (T, ByteString) =
    implicitly[FixedLenFormatter[T]].deserializeWithLen(bs)
}

So this is effectively adding a new ByteString.pull[T]() method for any type T that has a FixedLenFormatter defined. We are finding the FixedLenFormatter for T (that's what implicitly does, remember), and then calling it on our ByteString. Exactly the same functionality, but expressed in a nicer, quasi-OO way, hiding the implicitly under the hood.

Doing the same thing for ByteStringBuilder, we get:

implicit class RichByteStringBuilder(builder: ByteStringBuilder) {
  /**
    * Push a [[T]] onto the end of this [[ByteStringBuilder]].
    *
    * @param t the value being pushed
    * @tparam T the type of that value, which must have an implicit [[FixedLenFormatter]]
    * @return the original [[ByteStringBuilder]] (which is mutable, so this is just for possible chaining)
    */
  def push[T : FixedLenFormatter](t: T): ByteStringBuilder =
    implicitly[FixedLenFormatter[T]].serializeInto(builder, t)
}

Given that, our StringFormatter now becomes cleaner:

implicit object StringFormatter extends FixedLenFormatter[String] {
  override def deserializeWithLen(bs: ByteString): (String, ByteString) = {
    val (len, nextPos) = bs.pull[Short]
    val str = nextPos.take(len).decodeString(StandardCharsets.UTF_8)
    (str, nextPos.drop(len))
  }
}

It's a little thing, but we're going to be doing a lot of these push and pull operations, so making it a little cleaner each time adds up quickly, and makes the code easier to read.
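The same extension-method trick can be tried in isolation. This sketch again substitutes `Vector[Byte]` for `ByteString` so that it runs with nothing but the standard library; `Decoder`, `RichBytes` and the two instances are hypothetical names chosen for the example.

```scala
import java.nio.ByteBuffer

object PullSyntaxSketch {
  // Assumption: Vector[Byte] plays the role of akka's ByteString here.
  type Bytes = Vector[Byte]

  trait Decoder[T] {
    def decode(bytes: Bytes): (T, Bytes)
  }

  implicit val shortDecoder: Decoder[Short] = new Decoder[Short] {
    def decode(bytes: Bytes): (Short, Bytes) = {
      val (head, rest) = bytes.splitAt(2)
      (ByteBuffer.wrap(head.toArray).getShort(), rest)
    }
  }

  implicit val byteDecoder: Decoder[Byte] = new Decoder[Byte] {
    def decode(bytes: Bytes): (Byte, Bytes) = (bytes.head, bytes.tail)
  }

  // Extension method: any Bytes value gains pull[T] for every T with a Decoder.
  implicit class RichBytes(private val bytes: Bytes) {
    def pull[T: Decoder]: (T, Bytes) =
      implicitly[Decoder[T]].decode(bytes)
  }
}
```

With `import PullSyntaxSketch._` in scope, `Vector[Byte](0, 5, 9).pull[Short]` reads the first two bytes as a big-endian Short and returns the single remaining byte alongside it.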

Higher-level Types

Okay, we've got a couple of building blocks.  How do we serialize and deserialize more interesting classes?

Let's say we have a simple case class, like this:

case class User(
  displayName: String,
  handle: String,
  age: Short
)

The serialization for that would look something like this:

object User {
  implicit val cacheFormatter: FixedLenFormatter[User] = new FixedLenFormatter[User] {
    override def deserializeWithLen(bs: ByteString): (User, ByteString) = {
      val (((displayName, handle), age), next) =
        bs.pull[String]
          .pull[String]
          .pull[Short]
      (User(displayName, handle, age), next)
    }
    override def serializeInto(
      builder: ByteStringBuilder,
      data: User
    ): ByteStringBuilder =
      builder
        .push(data.displayName)
        .push(data.handle)
        .push(data.age)
  }
}

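That's not bad -- just a series of pushes to serialize, and pulls to deserialize. (Chaining pull this way needs one more extension method that isn't shown above: a pull on the (value, ByteString) pair returned by the previous pull, which reads from the remaining bytes and nests the results.) The nested tuples when we pull aren't quite ideal -- if we were getting more serious about this, we might try using Shapeless' HLists instead, which are a lot like arbitrarily-nested Tuple2s. (And will probably become standard in Scala 3.) But this is good enough for simple use.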
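The chained `pull` calls in the User formatter only compile if `pull` is also available on the `(value, ByteString)` pair that the previous `pull` returned. The article does not show that piece, so the `RichPullResult` class below is a guess at its shape, written against `Vector[Byte]` so it can run without Akka.

```scala
import java.nio.ByteBuffer

object ChainedPullSketch {
  // Assumption: Vector[Byte] plays the role of akka's ByteString here.
  type Bytes = Vector[Byte]

  trait Decoder[T] { def decode(bytes: Bytes): (T, Bytes) }

  implicit val shortDecoder: Decoder[Short] = new Decoder[Short] {
    def decode(bytes: Bytes): (Short, Bytes) = {
      val (head, rest) = bytes.splitAt(2)
      (ByteBuffer.wrap(head.toArray).getShort(), rest)
    }
  }

  implicit val byteDecoder: Decoder[Byte] = new Decoder[Byte] {
    def decode(bytes: Bytes): (Byte, Bytes) = (bytes.head, bytes.tail)
  }

  implicit class RichBytes(private val bytes: Bytes) {
    def pull[T: Decoder]: (T, Bytes) = implicitly[Decoder[T]].decode(bytes)
  }

  // Hypothetical helper: pull again from the remainder of an earlier pull,
  // pairing the new value with everything pulled so far.
  implicit class RichPullResult[A](private val prev: (A, Bytes)) {
    def pull[T: Decoder]: ((A, T), Bytes) = {
      val (a, rest) = prev
      val (t, next) = implicitly[Decoder[T]].decode(rest)
      ((a, t), next)
    }
  }
}
```

Each additional `pull` wraps the previous results in one more `Tuple2`, which is where the `(((displayName, handle), age), next)` pattern comes from.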

Typeclass Instances for Collections

Okay -- what if we want to serialize Lists?  Do we need lots of little typeclass instances for each different kind of List?  Fortunately, no, we can build a single instance that covers all sorts of Lists at once:

implicit def ListFormatter[T : FixedLenFormatter]: FixedLenFormatter[List[T]] = new FixedLenFormatter[List[T]] {
  override def deserializeWithLen(bs: ByteString): (List[T], ByteString) = {
    // At the beginning of each serialized List, we say how many elements it has:
    val (n, next) = bs.pull[Short]
    // And now, we read that many elements in. We are building up a List of the
    // results, and each time we return the "next" pointer to the remaining
    // serialized data:
    val (ts, endPointer) = (0 until n).foldLeft((List.empty[T], next)) {
      case ((ts, curByteString), _) =>
        val (t, nextByteString) = curByteString.pull[T]
        (t :: ts, nextByteString)
    }
    // Since the List gets prepended each time as we build it up, we need to 
    // reverse it to get everything into the right order at the end:
    (ts.reverse, endPointer)
  }
  override def serializeInto(
    builder: ByteStringBuilder,
    ts: List[T]
  ): ByteStringBuilder = {
    // Start with the length of the List. (Beware: toShort silently wraps
    // around if the List has more than 32767 elements.)
    builder.push(ts.length.toShort)
    // Then push each element onto the builder:
    for {
      t <- ts
    } builder.push(t)
    builder
  }
}

Note that this is an implicit def, not an implicit val -- that's because it takes a type parameter, so it can't be a val. But in exchange for that, this now works for any type T that has a FixedLenFormatter, and any sort of List -- lots of types, all in one fell swoop.
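To see the "one fell swoop" part in action, here is a stand-alone sketch of the same pattern. It uses `Vector[Byte]` instead of `ByteString`, and `Codec`, `listCodec` and `roundTrip` are illustrative names rather than the article's API.

```scala
import java.nio.ByteBuffer

object ListInstanceSketch {
  // Assumption: Vector[Byte] plays the role of akka's ByteString here.
  type Bytes = Vector[Byte]

  trait Codec[T] {
    def encode(value: T): Bytes
    def decode(bytes: Bytes): (T, Bytes)
  }

  implicit val shortCodec: Codec[Short] = new Codec[Short] {
    def encode(value: Short): Bytes =
      ByteBuffer.allocate(2).putShort(value).array().toVector
    def decode(bytes: Bytes): (Short, Bytes) = {
      val (head, rest) = bytes.splitAt(2)
      (ByteBuffer.wrap(head.toArray).getShort(), rest)
    }
  }

  // One definition yields a Codec[List[T]] for every T that has a Codec.
  implicit def listCodec[T](implicit elem: Codec[T]): Codec[List[T]] =
    new Codec[List[T]] {
      def encode(values: List[T]): Bytes = {
        require(values.length <= Short.MaxValue, "List too long for a Short count")
        val header = shortCodec.encode(values.length.toShort)
        values.foldLeft(header)((acc, v) => acc ++ elem.encode(v))
      }
      def decode(bytes: Bytes): (List[T], Bytes) = {
        val (n, afterCount) = shortCodec.decode(bytes)
        val (reversed, rest) =
          (0 until n).foldLeft((List.empty[T], afterCount)) {
            case ((acc, cur), _) =>
              val (t, next) = elem.decode(cur)
              (t :: acc, next)
          }
        (reversed.reverse, rest)
      }
    }

  def roundTrip[T](value: T)(implicit codec: Codec[T]): T =
    codec.decode(codec.encode(value))._1
}
```

Because the compiler applies `listCodec` recursively, a `List[List[Short]]` round-trips with no extra instance: it builds the outer codec from the inner one, and the inner one from `shortCodec`.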

Put the Toys Together

Okay, one more instance.  Say we have a Room, that contains a bunch of Users:

case class Room(members: List[User])

The FixedLenFormatter for that is nice and easy:

object Room {
  implicit val cacheFormatter: FixedLenFormatter[Room] = new FixedLenFormatter[Room] {
    override def deserializeWithLen(bs: ByteString): (Room, ByteString) = {
      val (members, next) = bs.pull[List[User]]
      (Room(members), next)
    }
    override def serializeInto(
      builder: ByteStringBuilder,
      data: Room
    ): ByteStringBuilder =
      builder
        .push(data.members)
  }
}

That's all we need! We know how to serialize User, and we know how to serialize List[T], so serializing a List[User] just works!

Conclusion

This is still just a toy -- the real system has FixedLenFormatters for more primitives such as Boolean, Option (which uses Boolean under the hood), Enumeratum enumerations, and so on, and lots of case classes.  But it's all built from the same patterns shown above, so extending this is left as an exercise for the reader.  Seriously: building this out is a good way to get more comfortable with typeclasses.
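As a starting point for that exercise, here is one way Boolean and Option instances might look. The real system's encoding is not shown in this article, so the one-byte flag below is an assumption, and the sketch again uses `Vector[Byte]` and a made-up `Codec` trait so that it runs without Akka.

```scala
object OptionInstanceSketch {
  // Assumption: Vector[Byte] plays the role of akka's ByteString here.
  type Bytes = Vector[Byte]

  trait Codec[T] {
    def encode(value: T): Bytes
    def decode(bytes: Bytes): (T, Bytes)
  }

  implicit val byteCodec: Codec[Byte] = new Codec[Byte] {
    def encode(value: Byte): Bytes = Vector(value)
    def decode(bytes: Bytes): (Byte, Bytes) = (bytes.head, bytes.tail)
  }

  // Assumed encoding: a Boolean is a single byte, 1 for true and 0 for false.
  implicit val booleanCodec: Codec[Boolean] = new Codec[Boolean] {
    def encode(value: Boolean): Bytes = Vector(if (value) 1.toByte else 0.toByte)
    def decode(bytes: Bytes): (Boolean, Bytes) = (bytes.head != 0, bytes.tail)
  }

  // Option[T]: a Boolean presence flag, followed by the value only if present.
  implicit def optionCodec[T](implicit elem: Codec[T]): Codec[Option[T]] =
    new Codec[Option[T]] {
      def encode(value: Option[T]): Bytes = value match {
        case Some(t) => booleanCodec.encode(true) ++ elem.encode(t)
        case None    => booleanCodec.encode(false)
      }
      def decode(bytes: Bytes): (Option[T], Bytes) = {
        val (present, rest) = booleanCodec.decode(bytes)
        if (present) {
          val (t, next) = elem.decode(rest)
          (Some(t), next)
        } else (None, rest)
      }
    }
}
```

A `None` costs a single byte on the wire, and a `Some` costs one byte more than the value it wraps.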

A really serious system, intended as a library for general use, would probably go further -- in particular, you can eliminate most of the boilerplate around typeclass instances for case classes by using Shapeless' "generic" mechanisms, which give you instances more or less for free for well-behaved types.  Much of this should be built into Scala 3 -- look into automatic typeclass derivation in Dotty if you're interested in the details.

There's one very important caveat: we're not doing anything about schema evolution yet.  This matters whenever you build a codec -- what happens when your types change? If you might need to be reading old data with a newer program, this becomes a really important question.  It's outside the scope of this article, but I recommend looking into serious serialization formats like Protobuf if you are interested: they spend a good deal of effort to make this work properly.

I hope this has been useful for you.  Typeclasses are a key tool for serious Scala programming, and while they take a bit of practice, they shouldn't be hard once you've worked a bit in them.  I encourage you to play with this code, and to come up with some similar problems on your own. Pretty much any time when you say "I want to be able to do X with lots of different types", you are probably looking for a typeclass: try it out!

And remember: you shouldn't burn your brain too much trying to make it all perfect from the beginning -- that's why I outlined the way this system was really built, wrong turns and all.  The great thing about Scala is that refactoring and tweaking is pretty easy to do without breaking stuff, so make a stab, see how it works, and evolve it until it's what you really want.
