Conditional Image Generation

Associative Memory for Face Generation

We build a simple Dense Associative Memory by storing each image as a row of a memory matrix. The final model is converted into a TensorFlow.js (TFJS) binary that can be served in JavaScript.

In [3]:
!pip install einops equinox jaxtyping gdown

Fetch and parse data

In [5]:
import gdown
from pathlib import Path
force_redownload = True

ppl_id= '1oGCR84KnIwSzTfN85ikWlA4kGpYna75D'
ppl_yaml = 'workshop_people.yaml'
if force_redownload or not Path(ppl_yaml).exists():
    gdown.download(id=ppl_id, output=ppl_yaml, quiet=False)

imgs_id = '15FTnXRLBTL1oBr3GK5CbB8jfSiCZJxba'
imgs_dir = "workshop_headshots"
if force_redownload or not Path(imgs_dir).exists():
    gdown.download_folder(id=imgs_id, output=imgs_dir, quiet=False)
In [6]:
import yaml
import os
import numpy as np
from PIL import Image

with open(ppl_yaml, "r") as fp:
    people = yaml.safe_load(fp)

headshots = [Image.open(os.path.join(imgs_dir, person["headshot"])).convert('RGB') for person in people]
assert all(h.size == headshots[0].size for h in headshots)

imgs = np.stack([np.asarray(h) for h in headshots])
img_shape = imgs.shape[1:]
img_nelements = D = imgs[0].size
Nsamples = len(imgs)
h,w,c = img_shape

Create the Associative Memory

First, let’s import what we need to build our Associative Memory. We use JAX, which makes gradients of complex functions easy to compute. We also write helper functions that convert images to vectors (to store in our Associative Memory) and back to images.

In [9]:
import jax
import jax.random as jr
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
from einops import rearrange
import numpy as np
from sklearn.decomposition import PCA
from PIL import Image
from jaxtyping import Float, Array

import os
from typing import *
import functools as ft
from pathlib import Path
import yaml
import json

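def to_imgs(x):
    """Convert weight vectors back to images"""
    x = rearrange(np.asarray(x), "... (h w c) -> ... h w c", h=h, w=w, c=c)
    # Clip to [0, 1] before casting: out-of-range values would otherwise wrap around in uint8
    x = np.round(np.clip(x, 0., 1.) * 255).astype(np.uint8)
    if x.ndim < 4: return [Image.fromarray(x)]
    return [Image.fromarray(x_) for x_ in x]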

def to_vectors(x):
    """Convert images to weight vectors"""
    return rearrange(x, "... h w c -> ... (h w c)") / 255.
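
The einops pattern "... h w c -> ... (h w c)" is a row-major (C-order) flatten of the last three axes, so it matches numpy's default reshape. The sketch below is a minimal numpy-only version of the round trip, without the PIL step. The helper names to_vectors_np and to_pixels_np are hypothetical, and the data is random toy images rather than the headshots.

```python
import numpy as np

def to_vectors_np(x):
    """(..., h, w, c) uint8 images -> (..., h*w*c) floats in [0, 1] (row-major flatten)."""
    return x.reshape(*x.shape[:-3], -1) / 255.0

def to_pixels_np(v, h, w, c):
    """(..., h*w*c) floats -> (..., h, w, c) uint8, clipping out-of-range values first."""
    x = v.reshape(*v.shape[:-1], h, w, c)
    return np.round(np.clip(x, 0.0, 1.0) * 255).astype(np.uint8)

rng = np.random.default_rng(0)
toy_imgs = rng.integers(0, 256, size=(3, 4, 5, 3), dtype=np.uint8)  # 3 toy 4x5 RGB images
vecs = to_vectors_np(toy_imgs)                                       # one flat vector per image
print(vecs.shape)  # (3, 60)
```

Rounding before the cast avoids off-by-one pixel values from floating-point error. Clipping keeps states that drift outside [0, 1] from wrapping around.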

Next, let’s define our energy function using a simple Equinox module. If you paid attention to the math, you’ll notice that we took a few shortcuts in this implementation. Namely:

  1. We integrated out the hidden memories so that the whole energy function depends only on \(x\) and \(y\), as in \(E_\theta (x, y)\).
  2. We assume that the labels are conditioning labels: that is, they are fixed inputs and we don’t evolve them over time.

Our Associative Memory stores \(N\) (image, label) pairs in the parameters \(\{\theta^\text{image}, \theta^\text{label}\}\) (W and labelW in the code). Each image has dimension \(D\) and each label has dimension \(M\).
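
Reading the implementation below term by term gives the energy it computes. Here \(\theta^\text{image}_\mu\) and \(\theta^\text{label}_\mu\) are the \(\mu\)-th rows of W and labelW, \(\beta\) is beta, and \(\lambda\) is label_strength. The code multiplies the label term by label_strength twice, which is why \(\lambda^2\) appears:

\[
E_\theta(x, y) = -\log \sum_{\mu=1}^{N} \exp\Big(\beta\,\big[-\lVert \theta^\text{image}_\mu - x \rVert^2 + \lambda^2\, \theta^\text{label}_\mu \cdot y\big]\Big)
\]

With one-hot labels, \(\theta^\text{label}_\mu \cdot y\) is 1 for the memory whose label matches \(y\) and 0 for all others. A large \(\lambda\) therefore lets the matching memory dominate the sum.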

In [11]:
class ConditionedAssociativeMemory(eqx.Module):
    W: jax.Array
    labelW: jax.Array
    beta: float

    def __init__(self,
                 Winit:Float[Array, "N D"], # Initialize the image weights
                 labelW_init: Float[Array, "N M"], # Initialize the label weights
                 beta_init=1. # Inverse temp of the hidden units
                 ):
        self.W = jnp.array(Winit)
        self.labelW =  jnp.array(labelW_init)
        self.beta = beta_init

    def __call__(self,
                 x: Float[Array, "D"], # Dynamic image
                 labels: Float[Array, "M"], # Conditioning label
                 label_strength:float=1., # Additionally weight the label, compensates for M << D
                 ):
        """Compute the energy of the memories given a particular label"""
        assert len(x.shape) < 2, "We do not expect a batch dimension"

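        sim = -jnp.sum(jnp.power(self.W - x, 2), -1)  # (N,) negative squared distance to each stored image
        simlabel = label_strength * self.labelW @ labels  # (N,) label similarity
        # Note: label_strength is applied again below, so the effective label weight is label_strength**2
        energy = -jax.nn.logsumexp(self.beta * (sim + label_strength * simlabel))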

        aux={}
        return energy, aux
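
jax.value_and_grad computes this gradient automatically. For intuition, it also has a closed form: \(\nabla_x E = 2\beta\,(x - \sum_\mu p_\mu \theta^\text{image}_\mu)\), where \(p\) is the softmax of the exponents. The sketch below is a minimal numpy re-implementation of ConditionedAssociativeMemory.__call__. It uses numpy rather than JAX, the helper names energy and energy_grad are hypothetical, and the memories are random toy data. It checks that closed form against central finite differences.

```python
import numpy as np

def logsumexp(a):
    """Numerically stable log(sum(exp(a)))."""
    m = np.max(a)
    return float(m + np.log(np.sum(np.exp(a - m))))

def energy(W, labelW, beta, lam, x, y):
    """numpy re-implementation of ConditionedAssociativeMemory.__call__ (energy only)."""
    sim = -np.sum((W - x) ** 2, axis=-1)   # (N,) negative squared distances
    simlabel = lam * labelW @ y             # (N,) label similarities
    return -logsumexp(beta * (sim + lam * simlabel))

def energy_grad(W, labelW, beta, lam, x, y):
    """Closed-form dE/dx = 2*beta*(x - sum_mu p_mu W_mu), with p = softmax of the exponents."""
    logits = beta * (-np.sum((W - x) ** 2, axis=-1) + lam ** 2 * (labelW @ y))
    p = np.exp(logits - logits.max())
    p /= p.sum()
    return 2 * beta * (x - p @ W)

# Toy problem: 4 random memories of dimension 6, one-hot labels, label 2 requested
rng = np.random.default_rng(0)
W_toy = rng.random((4, 6))
labelW_toy = np.eye(4)
x_toy = rng.random(6)
y_toy = labelW_toy[2]
beta, lam = 10.0, 1.0

# Central finite differences along each coordinate axis
eps = 1e-6
fd = np.array([
    (energy(W_toy, labelW_toy, beta, lam, x_toy + eps * e, y_toy)
     - energy(W_toy, labelW_toy, beta, lam, x_toy - eps * e, y_toy)) / (2 * eps)
    for e in np.eye(6)
])
err = np.max(np.abs(fd - energy_grad(W_toy, labelW_toy, beta, lam, x_toy, y_toy)))
print(f"max |finite difference - closed form| = {err:.1e}")
```

The gradient points from the softmax-weighted average of the stored images towards x. A descent step therefore pulls x towards that average.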

Instantiating the energy function

In [13]:
W = to_vectors(imgs)  # (N, D): one stored image per row
labelW = jax.nn.one_hot(jnp.arange(Nsamples), num_classes=Nsamples)  # (N, N): memory i carries label i

label_strength = 20_000  # How much more strongly to weight labels than pixels in the similarity function

# Initialize energy function
am = ConditionedAssociativeMemory(W, labelW, beta_init=10.)

Training PCA

For the demo, we project the stored images onto their first two principal components (PCA). This helps build intuition for how the dynamics converge to different energy minima.

In [15]:
# Train PCA model on the stored images
pca_model = PCA(2).fit(W)
W2 = pca_model.transform(W)
x0 = jnp.array(W[0])
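
With its default whiten=False, sklearn's PCA.transform computes (x - mean_) @ components_.T. energy_and_projection_step below re-implements exactly this formula in JAX, so the projection can be exported along with the dynamics. The sketch below is a minimal numpy version on toy data, with hypothetical names. The components come from an SVD, so they match sklearn's components_ only up to sign.

```python
import numpy as np

rng = np.random.default_rng(0)
X_toy = rng.random((10, 8))                    # 10 toy flattened "images" of dimension 8

mean = X_toy.mean(axis=0)                       # plays the role of pca_model.mean_
U, S, Vt = np.linalg.svd(X_toy - mean, full_matrices=False)
components = Vt[:2]                             # plays the role of pca_model.components_ (up to sign)

def project(x, components, mean):
    """Center, then take the dot product with each principal component."""
    return (x - mean) @ components.T

proj = project(X_toy, components, mean)
print(proj.shape)  # (10, 2)
```

Because X - mean = U S Vt, the projection equals U[:, :2] * S[:2]: the coordinates along the top two principal axes.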

To use this energy function in the frontend, we need a single function that takes one gradient step on the energy and returns the new image, given the current image and a label. It also returns the other quantities shown in the demo: the energy of the current state and the 2D PCA projection of the new state.
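
Since the gradient is 2β(x − m), where m is the softmax-weighted average of the stored images, a step of size α moves x a fraction 2αβ of the way towards m. The quick test below uses β = 10 and α = 0.05, so that fraction equals 1 and a single step lands on m. The sketch below is in numpy under stated assumptions: toy memories, a hypothetical step helper, and a label strength large enough that m is essentially the labelled memory.

```python
import numpy as np

def step(W, labelW, beta, lam, x, label, alpha):
    """One gradient step on the energy using the closed-form gradient (numpy sketch)."""
    y = np.eye(W.shape[0])[label]                                  # one-hot label, like jax.nn.one_hot
    logits = beta * (-np.sum((W - x) ** 2, axis=-1) + lam ** 2 * (labelW @ y))
    p = np.exp(logits - logits.max())
    p /= p.sum()
    m = p @ W                                                      # softmax-weighted average of memories
    return x - alpha * 2 * beta * (x - m), m

rng = np.random.default_rng(1)
W_toy = rng.random((5, 12))
x_start = W_toy[0].copy()                                          # start exactly on memory 0
x_new, m = step(W_toy, np.eye(5), beta=10.0, lam=10.0, x=x_start, label=3, alpha=0.05)
print(np.abs(x_new - W_toy[3]).max())                              # distance to memory 3 after one step
```

The label term dominates here, so m is memory 3 even though the state started on memory 0.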

In [17]:
def energy_and_projection_step(
        am:ConditionedAssociativeMemory, # The Associative Memory
        pcomponents, # The learned principal components of the images
        pmean, # sklearn's PCA centers the images before projecting -- this is the mean it subtracts
        x:Float[Array, "D"], # The dynamic image state
        label:int, # Integer index of the conditioning label (one-hot encoded below)
        alpha:float # The step size down the energy to take
        ):
    n_classes = am.W.shape[0]
    label = jax.nn.one_hot(label, n_classes)
    # `label_strength` is the global defined above; its value is fixed when the function is traced/exported.
    # `energy` is evaluated at the input state x, before the step.
    (energy, aux), grad = jax.value_and_grad(am, argnums=0, has_aux=True)(x, label, label_strength=label_strength)
    x2 = x - alpha * grad  # One gradient-descent step on the energy

    X = x2 - pmean
    proj_x2 = X[None] @ jnp.array(pcomponents).T  # (1, 2) PCA projection of the new state
    return x2, energy, proj_x2

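Quick test of the step function. We then bind the model and the PCA parameters, leaving the image, label, and step size as the only inputs: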

In [19]:
x2, e, proj_x2 = energy_and_projection_step(am, pca_model.components_, pca_model.mean_, x0, 5, 0.05)
stepfn = ft.partial(energy_and_projection_step, am, pca_model.components_, pca_model.mean_)
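
functools.partial fixes the first three arguments of energy_and_projection_step, so stepfn takes only (x, label, alpha). This matches the three input signatures used for export below. Presumably the frontend then iterates it, feeding each new image back in. The sketch below mimics that loop in numpy, using a hypothetical toy_step stand-in (closed-form gradient, no PCA output) and a smaller step size. It checks that the energy never increases while the state approaches the labelled memory.

```python
import functools as ft
import numpy as np

def toy_step(W, beta, lam, x, label, alpha):
    """Hypothetical numpy stand-in for energy_and_projection_step: returns (new_x, energy at x)."""
    logits = beta * (-np.sum((W - x) ** 2, axis=-1) + lam ** 2 * np.eye(W.shape[0])[label])
    shifted = logits - logits.max()
    energy = -(logits.max() + np.log(np.sum(np.exp(shifted))))    # -logsumexp(logits)
    p = np.exp(shifted) / np.sum(np.exp(shifted))                  # softmax weights over memories
    return x - alpha * 2 * beta * (x - p @ W), energy

rng = np.random.default_rng(2)
W_toy = rng.random((5, 12))
stepfn_toy = ft.partial(toy_step, W_toy, 10.0, 10.0)   # bind model parameters, like stepfn

x_state = rng.random(12)                                # arbitrary starting image
energies = []
for _ in range(20):
    x_state, e = stepfn_toy(x_state, 4, 0.01)           # alpha * 2 * beta = 0.2 per step
    energies.append(e)
```

Each step here shrinks the distance to memory 4 by a factor of 0.8. After 20 steps the remaining gap is about 1% of the initial one.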

Porting to TensorFlow.js

Our associative memory now needs to run in the frontend. For this, we convert it to TensorFlow.js (TFJS). JAX’s jax2tf tool turns the JAX function into a TensorFlow function, which we save as a SavedModel and then convert with the tensorflowjs converter.

In [21]:
# Don't upgrade Colab's TF version
!pip install tensorflowjs --no-deps
In [22]:
import tensorflow as tf
import tensorflowjs as tfjs
import tempfile
from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion
from jax.experimental import jax2tf
from typing import *
from pathlib import Path

DType = Any
PolyShape = jax2tf.PolyShape
# (No `Array = Any` alias here: it would shadow jaxtyping's Array imported earlier)
_TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY

class _ReusableSavedModelWrapper(tf.train.Checkpoint):
    """Wraps a function and its parameters for saving to a SavedModel.
    Implements the interface described at
    https://www.tensorflow.org/hub/reusable_saved_models.
    """

    def __init__(self, tf_graph, param_vars):
        """Args:
          tf_graph: a tf.function taking the inputs as positional arguments, which
             can be tuples/lists/dictionaries of np.ndarray or tensors. The function
             may have references to the tf.Variables in `param_vars`.
          param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
             to be saved as the variables of the SavedModel.
        """
        super().__init__()
        # Implement the interface from https://www.tensorflow.org/hub/reusable_saved_models
        self.variables = tf.nest.flatten(param_vars)
        self.trainable_variables = [v for v in self.variables if v.trainable]
        # If you intend to prescribe regularization terms for users of the model,
        # add them as @tf.functions with no inputs to this list. Else drop this.
        self.regularization_losses = []
        self.__call__ = tf_graph

def convert_jax(
    apply_fn: Callable[..., Any],
    *,
    input_signatures: Sequence[Tuple[Sequence[Union[int, None]], DType]],
    model_dir: str,
    polymorphic_shapes: Optional[Sequence[Union[str, PolyShape]]] = None):
    """Converts a JAX function `apply_fn` to a TensorFlow.js model.
    Works with `functools.partial`-style models: the bound parameters are baked into
    the graph as constants, so they cannot be accessed as variables from the frontend.

    Example usage for an arbitrary function:

    ```
    import functools as ft
    import tensorflow as tf

    def energy_and_projection_step(am, pcomponents, pmean, x, label, alpha):
        ...

    fn = ft.partial(energy_and_projection_step, am, pcomponents, pmean)

    convert_jax(
        apply_fn=fn,
        input_signatures=[tf.TensorSpec((D,)), tf.TensorSpec(tuple(), dtype=tf.int32), tf.TensorSpec(tuple(), dtype=tf.float32)],
        model_dir=tfjs_model_dir) # Saves model to binary files that can be loaded by tfjs
    ```

    Note that when using dynamic shapes, an additional argument
    `polymorphic_shapes` should be provided specifying values for the dynamic
    ("polymorphic") dimensions. See here for more details:
    https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion

    This is an adaptation of the original implementation in jax2tf here:
    https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py

    Arguments:
    apply_fn: A JAX function whose positional arguments are all inputs. Unlike
      the original jax2tf example, the model parameters are not passed as a
      separate first argument; bind them beforehand (e.g., with
      `functools.partial`).
    input_signatures: one `tf.TensorSpec` per positional argument of `apply_fn`,
      in order.
    model_dir: Directory where the TensorFlow.js model will be written.
    polymorphic_shapes: If given, passed as the `polymorphic_shapes` argument of
      `jax2tf.convert`, with one entry per argument of `apply_fn`; the matching
      `input_signatures` should have `None` in the polymorphic (dynamic) dimensions.
    """

    tf_fn = jax2tf.convert(
        apply_fn,
        # Gradients must be included as 'PreventGradient' is not supported.
        with_gradient=True,
        polymorphic_shapes=polymorphic_shapes,
        # Do not use TFXLA Ops because these aren't supported by TFjs, but use
        # workarounds instead. More information:
        # https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops
        enable_xla=False)

    # Create tf.Variables for the parameters. If you want more useful variable
    # names, you can use `tree.map_structure_with_path` from the `dm-tree`
    # package.
    # In this demo the parameters are closed over by `apply_fn` (e.g., via
    # functools.partial) and embedded as constants, so no variables are created:
    param_vars = []
    # param_vars = tf.nest.map_structure(
    #     lambda param: tf.Variable(param, trainable=False), params)
    # Do not use TF's jit compilation on the function.
    tf_graph = tf.function(
        lambda *xs: tf_fn(*xs), autograph=False, jit_compile=False)

    # This signature is needed for TensorFlow Serving use.
    signatures = {
        _TF_SERVING_KEY: tf_graph.get_concrete_function(*input_signatures)
    }

    wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
    saved_model_options = tf.saved_model.SaveOptions(
        experimental_custom_gradients=True)

    with tempfile.TemporaryDirectory() as saved_model_dir:
        tf.saved_model.save(
            wrapper,
            saved_model_dir,
            signatures=signatures,
            options=saved_model_options)
        saved_model_conversion.convert_tf_saved_model(saved_model_dir, model_dir, skip_op_check=True)

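Finally, let’s save our energy_and_projection_step function (as stepfn, with the model and PCA parameters bound) as a TFJS model that can be loaded from JavaScript. We also save a config.json describing the stored headshots.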

In [24]:
model_output_dir = Path("hopfield_model")
print(f"Converting to TFJS model in {model_output_dir}")

convert_jax(
    apply_fn=stepfn,
    input_signatures=[tf.TensorSpec((D,)), tf.TensorSpec(tuple(), dtype=tf.int32), tf.TensorSpec(tuple(), dtype=tf.float32)],
    model_dir=model_output_dir)

# Cache the 2D PCA projection of each original headshot for the frontend
print("Caching projection of original headshots")
listW2 = W2.tolist()
for i, person in enumerate(people):
    person['proj2d'] = listW2[i]

# Save model config
config = {}
config['people'] = people
config['model_dir'] = str(model_output_dir)
config['img_size'] = list(img_shape)
config['Nsamples'] = Nsamples
config['nelements'] = img_nelements

print("Saving configuration")

with open(model_output_dir / "config.json", "w") as fp:
    json.dump(config, fp)

print("DONE")
Converting to TFJS model in hopfield_model
WARNING:absl:Importing a function (__inference_internal_grad_fn_2630) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_internal_grad_fn_2898) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
Caching projection of original headshots
Saving configuration
DONE
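
The exported config.json ties the TFJS model to the stored people. The sketch below is a hedged sanity check, using a hypothetical helper that is not part of the original pipeline. It reads such a file back and verifies that its fields agree with each other, e.g., that img_size multiplies out to nelements. It runs on a toy config written to a temporary directory; in the notebook you would pass model_output_dir / "config.json" instead.

```python
import json
import tempfile
from pathlib import Path

REQUIRED_KEYS = {"people", "model_dir", "img_size", "Nsamples", "nelements"}

def load_config(path):
    """Read an exported config.json and check that its fields agree (hypothetical helper)."""
    with open(path) as fp:
        config = json.load(fp)
    missing = REQUIRED_KEYS - set(config)
    if missing:
        raise KeyError(f"config is missing keys: {sorted(missing)}")
    h, w, c = config["img_size"]
    if h * w * c != config["nelements"]:
        raise ValueError("img_size does not multiply out to nelements")
    if len(config["people"]) != config["Nsamples"]:
        raise ValueError("expected one people entry per stored image")
    if any(len(p["proj2d"]) != 2 for p in config["people"]):
        raise ValueError("every person needs a 2D PCA projection")
    return config

# Toy config with a single 2x2 RGB "headshot"
toy = {"people": [{"headshot": "a.png", "proj2d": [0.1, -0.2]}],
       "model_dir": "hopfield_model", "img_size": [2, 2, 3], "Nsamples": 1, "nelements": 12}
with tempfile.TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "config.json"
    cfg_path.write_text(json.dumps(toy))
    cfg = load_config(cfg_path)
print(cfg["img_size"])  # [2, 2, 3]
```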