!pip install einops equinox jaxtyping gdown
Conditional Image Generation
Associative Memory for Face Generation
We build a simple Dense Associative Memory by inserting images into a memory matrix. The final model is converted into a TFJS binary that can be served from JavaScript.
Fetch and parse data
In [5]:
import gdown
from pathlib import Path
force_redownload = True
ppl_id= '1oGCR84KnIwSzTfN85ikWlA4kGpYna75D'
ppl_yaml = 'workshop_people.yaml'
if force_redownload or not Path(ppl_yaml).exists():
    gdown.download(id=ppl_id, output=ppl_yaml, quiet=False)
imgs_id = '15FTnXRLBTL1oBr3GK5CbB8jfSiCZJxba'
imgs_dir = "workshop_headshots"
if force_redownload or not Path(imgs_dir).exists():
    gdown.download_folder(id=imgs_id, output=imgs_dir, quiet=False)
In [6]:
import yaml
import os
import numpy as np
from PIL import Image
with open(ppl_yaml, "r") as fp:
    people = yaml.safe_load(fp)
headshots = [Image.open(os.path.join(imgs_dir, person["headshot"])).convert('RGB') for person in people]
assert all(h.size == headshots[0].size for h in headshots)
imgs = np.stack([np.asarray(h) for h in headshots])
img_shape = imgs.shape[1:]
img_nelements = D = imgs[0].size
Nsamples = len(imgs)
h, w, c = img_shape
Create the Associative Memory
First, let’s import what we need to build our Associative Memory (we will use JAX, which is well suited to managing complex gradient computations) and write some helper functions that convert images to vectors (to store in our Associative Memory) and back to images.
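The image-to-vector conversion described here amounts to a row-major reshape plus a rescale to the range [0, 1]. The following is a minimal NumPy sketch of that round trip, not the notebook's own helpers: the 2×3 RGB image and its pixel values are made up for illustration, and rounding is used when converting back to `uint8` to guard against floating-point truncation.

```python
import numpy as np

# Hypothetical tiny image: height 2, width 3, 3 color channels
h, w, c = 2, 3, 3
img = np.arange(h * w * c, dtype=np.uint8).reshape(h, w, c)  # pixel values 0..17

# Image -> vector: flatten (h, w, c) into (h*w*c,) and scale to [0, 1]
vec = img.reshape(-1) / 255.0

# Vector -> image: undo the scaling, round, and restore the (h, w, c) shape
restored = np.uint8(np.round(vec.reshape(h, w, c) * 255))

print(vec.shape)                       # flattened length is h*w*c
print(np.array_equal(restored, img))   # the round trip recovers the image
```

Flattening in row-major order groups the axes the same way as the einops pattern `"h w c -> (h w c)"`, so a single pixel's channels stay adjacent in the vector.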
In [9]:
import jax
import jax.random as jr
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import optax
from einops import rearrange
import numpy as np
from sklearn.decomposition import PCA
from PIL import Image
from jaxtyping import Float, Array
import os
from typing import *
import functools as ft
from pathlib import Path
import yaml
import json
import numpy as np
def to_imgs(x):
    """Convert weight vectors back to images"""
    x = rearrange(x, "... (h w c) -> ... h w c", h=h, w=w, c=c)
    if len(x.shape) < 4: return [Image.fromarray(np.uint8(x*255))]
    return [Image.fromarray(np.uint8(x_*255)) for x_ in x]
def to_vectors(x):
    """Convert images to weight vectors"""
    return rearrange(x, "... h w c -> ... (h w c)") / 255.
Next, let’s define our energy function using a simple Equinox module. If you paid attention to the math, you’ll notice that we took a few shortcuts in this implementation. Namely:
- We integrated out our hidden memories so that the whole energy function depends only on \(x\) and \(y\), as in \(E_\theta (x, y)\).
- We assume that the labels are conditioning labels: that is, we don’t need to evolve them over time.
Our Associative Memory stores \(N\) (image-label) pairs in the \(\{\theta^\text{image}, \theta^\text{label}\}\) parameters, where each image is of dimension \(D\) and each label is of dimension \(M\).
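Written out, the energy computed by the module below takes roughly the following form. The notation is an assumption made for this sketch: \(\mu\) indexes the \(N\) stored pairs, \(\beta\) is the inverse temperature (`beta`), and \(\lambda\) stands for the overall weight placed on the label term (`label_strength`).

```latex
E_\theta(x, y) = -\log \sum_{\mu=1}^{N} \exp\!\Big( \beta \big( -\lVert \theta^{\text{image}}_\mu - x \rVert^2 + \lambda \, \theta^{\text{label}}_\mu \cdot y \big) \Big)

% Gradient with respect to the image state, with softmax weights p_\mu:
p_\mu = \frac{\exp\!\big( \beta ( -\lVert \theta^{\text{image}}_\mu - x \rVert^2 + \lambda \, \theta^{\text{label}}_\mu \cdot y ) \big)}
             {\sum_{\nu=1}^{N} \exp\!\big( \beta ( -\lVert \theta^{\text{image}}_\nu - x \rVert^2 + \lambda \, \theta^{\text{label}}_\nu \cdot y ) \big)},
\qquad
\nabla_x E_\theta(x, y) = -2\beta \Big( \sum_{\mu=1}^{N} p_\mu \, \theta^{\text{image}}_\mu - x \Big)
```

A gradient-descent step with step size \(\alpha\) therefore moves \(x\) toward the softmax-weighted average of the stored images, \(x \leftarrow x + 2\alpha\beta \big( \sum_\mu p_\mu \theta^{\text{image}}_\mu - x \big)\). When \(\lambda\) is large, the weights \(p_\mu\) concentrate on the memory whose label matches \(y\).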
In [11]:
class ConditionedAssociativeMemory(eqx.Module):
    W: jax.Array
    labelW: jax.Array
    beta: float
    def __init__(self,
                 Winit: Float[Array, "N D"],  # Initialize the image weights
                 labelW_init: Float[Array, "N M"],  # Initialize the label weights
                 beta_init=1.  # Inverse temp of the hidden units
                 ):
        self.W = jnp.array(Winit)
        self.labelW = jnp.array(labelW_init)
        self.beta = beta_init
    def __call__(self,
                 x: Float[Array, "D"],  # Dynamic image
                 labels: Float[Array, "M"],  # Conditioning label
                 label_strength: float = 1.,  # Additionally weight the label, compensates for M << D
                 ):
        """Compute the energy of the memories given a particular label"""
        assert len(x.shape) < 2, "We do not expect a batch dimension"
        # Similarity to each stored image: negative squared Euclidean distance
        sim = -jnp.sum(jnp.power(self.W - x, 2), -1)
        # Similarity to each stored label, weighted by label_strength (applied once, here)
        simlabel = label_strength * self.labelW @ labels
        energy = -jax.nn.logsumexp(self.beta * (sim + simlabel))
        aux = {}
        return energy, aux
Using the energy function…
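Before plugging in the headshots, a toy-sized example can illustrate how the label steers the descent. The sketch below is a standalone stand-in for the module, not the notebook's class: `toy_energy`, the three 2-D "images", the step size, and the label weight of 50 are all made-up values chosen so the behavior is easy to follow.

```python
import jax
import jax.numpy as jnp

def toy_energy(x, W, labelW, label_onehot, beta=1.0, label_strength=50.0):
    """Energy of state x given a one-hot conditioning label (illustrative stand-in)."""
    sim = -jnp.sum((W - x) ** 2, axis=-1)                # negative squared distance to each memory
    simlabel = label_strength * (labelW @ label_onehot)  # large bonus for the matching label
    return -jax.nn.logsumexp(beta * (sim + simlabel))

# Three hypothetical stored "images" in 2-D, each with its own one-hot label
W = jnp.array([[0.0, 0.0], [1.0, 1.0], [1.0, 0.0]])
labelW = jnp.eye(3)
label = jax.nn.one_hot(1, 3)  # condition on memory 1

x = jnp.array([0.5, 0.2])     # arbitrary starting state
e_start = toy_energy(x, W, labelW, label)

# Plain gradient descent on the energy with respect to x
step = jax.jit(jax.value_and_grad(toy_energy))
for _ in range(50):
    _, grad = step(x, W, labelW, label)
    x = x - 0.1 * grad

e_end = toy_energy(x, W, labelW, label)
print(x, float(e_start), float(e_end))
```

With a label weight this large, the softmax over memories is dominated by the labeled one, so each step pulls the state a fixed fraction of the way toward `W[1]` and the energy decreases along the way.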
In [13]:
W = to_vectors(imgs)
labels = jnp.arange(Nsamples)
labelW = jax.nn.one_hot(jnp.arange(Nsamples), num_classes=Nsamples)
label_strength=20_000 # How much more strongly to weight labels than pixels in the similarity functions
label_beta=10.
device="cpu"
# Initialize energy function
am = ConditionedAssociativeMemory(W, labelW, beta_init=10.)
Training PCA
For the demo, we will project the stored images onto their first two principal components using PCA. This builds intuition for how the dynamics converge to different energy minima.
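The projection itself is simple enough to write by hand: center the data with the mean, then take dot products with the top principal directions. The sketch below does this with a NumPy SVD on random stand-in data (six hypothetical flattened images of dimension 10). It is meant to illustrate the arithmetic only. scikit-learn's `PCA` (without whitening) applies the same center-then-project transform, though the signs of its components may differ.

```python
import numpy as np

# Hypothetical stand-in for the stored images: 6 samples, 10 features each
rng = np.random.default_rng(0)
X = rng.normal(size=(6, 10)) * np.linspace(3.0, 0.1, 10)

# 1. Center the data with the per-feature mean
mean = X.mean(axis=0)
Xc = X - mean

# 2. The principal directions are the right singular vectors of the centered data
_, singular_values, Vt = np.linalg.svd(Xc, full_matrices=False)
components = Vt[:2]       # top-2 directions, shape (2, D)

# 3. Project: subtract the mean, then multiply by the transposed components
proj = Xc @ components.T  # shape (6, 2)

print(proj.shape)
```

Keeping `mean` and `components` around is what allows a new state to be projected later with nothing more than a subtraction and a matrix product.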
In [15]:
# Train PCA model on the stored images
pca_model = PCA(2).fit(W)
W2 = pca_model.transform(W)
x0 = jnp.array(W[0])
To use this energy function in the frontend, we need a single function that returns the new image given the current image and a label. We also expose other quantities we want to show in the demo, e.g., the true energy value and the PCA-projected state.
In [17]:
def energy_and_projection_step(
    am: ConditionedAssociativeMemory,  # The Associative Memory
    pcomponents,  # The learned principal components of the images
    pmean,  # sklearn's PCA centers the images before projecting; this is the mean it subtracts
    x: Float[Array, "D"],  # The dynamic image state
    label: int,  # Index of the conditioning label (one-hot encoded below)
    alpha: float  # The step size to take down the energy
):
    n_classes = am.W.shape[0]
    label = jax.nn.one_hot(label, n_classes)
    (energy, aux), grad = jax.value_and_grad(am, argnums=0, has_aux=True)(x, label, label_strength=label_strength)
    x2 = x - alpha * grad  # One gradient-descent step on the energy
    X = x2 - pmean  # Center the new state with the PCA mean
    proj_x2 = X[None] @ jnp.array(pcomponents).T  # Project onto the principal components
    return x2, energy, proj_x2
Quick test of the energy function:
In [19]:
x2, e, proj_x2 = energy_and_projection_step(am, pca_model.components_, pca_model.mean_, x0, 5, 0.05)
stepfn = ft.partial(energy_and_projection_step, am, pca_model.components_, pca_model.mean_)
Porting to TensorFlow.js
Our associative memory now needs to be ported to the frontend. For this, we will convert our model to TensorFlow.js using jax2tf, JAX’s built-in (experimental) conversion tool for TensorFlow.
In [21]:
# Don't upgrade Colab's TF version
!pip install tensorflowjs --no-deps
Requirement already satisfied: tensorflowjs in /usr/local/lib/python3.10/dist-packages (4.20.0)
In [22]:
import tensorflow as tf
import tensorflowjs as tfjs
import tempfile
from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion
from jax.experimental import jax2tf
from typing import *
from pathlib import Path
DType = Any
PolyShape = jax2tf.PolyShape
Array = Any
_TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
class _ReusableSavedModelWrapper(tf.train.Checkpoint):
    """Wraps a function and its parameters for saving to a SavedModel.
    Implements the interface described at
    https://www.tensorflow.org/hub/reusable_saved_models.
    """
    def __init__(self, tf_graph, param_vars):
        """Args:
        tf_graph: a tf.function taking one argument (the inputs), which can
          be tuples/lists/dictionaries of np.ndarray or tensors. The function
          may have references to the tf.Variables in `param_vars`.
        param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
          to be saved as the variables of the SavedModel.
        """
        super().__init__()
        # Implement the interface from https://www.tensorflow.org/hub/reusable_saved_models
        self.variables = tf.nest.flatten(param_vars)
        self.trainable_variables = [v for v in self.variables if v.trainable]
        # If you intend to prescribe regularization terms for users of the model,
        # add them as @tf.functions with no inputs to this list. Else drop this.
        self.regularization_losses = []
        self.__call__ = tf_graph
def convert_jax(
    apply_fn: Callable[..., Any],
    *,
    input_signatures: Sequence[Tuple[Sequence[Union[int, None]], DType]],
    model_dir: str,
    polymorphic_shapes: Optional[Sequence[Union[str, PolyShape]]] = None):
    """Converts a JAX function `apply_fn` to a TensorflowJS model.
    Works with `functools.partial` style models if we don't need to access the variables in the frontend.
    Example usage for an arbitrary function:
    ```
    import functools as ft
    import tensorflow as tf
    def energy_and_projection_step(am, pcomponents, x:jnp.array, label:int, alpha:float):
      ...
    fn = ft.partial(energy_and_projection_step, am, jnp.array(pcomponents))
    convert_jax(
      apply_fn=fn,
      input_signatures=[tf.TensorSpec((D,)), tf.TensorSpec(tuple(), dtype=tf.int32), tf.TensorSpec(tuple(), dtype=tf.float32)],
      model_dir=tfjs_model_dir) # Saves model to binary files that can be loaded by tfjs
    ```
    Note that when using dynamic shapes, an additional argument
    `polymorphic_shapes` should be provided specifying values for the dynamic
    ("polymorphic") dimensions. See here for more details:
    https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion
    This is an adaptation of the original implementation in jax2tf here:
    https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py
    Arguments:
      apply_fn: A JAX function of one or more input arguments. Unlike the
        original implementation, no parameters are passed separately, so any
        model parameters must already be bound (e.g., via `functools.partial`).
      input_signatures: the input signatures for the arguments of `apply_fn`,
        one per argument. A signature must be a `tensorflow.TensorSpec`
        instance, or a (nested) tuple/list/dictionary thereof with a structure
        matching the corresponding argument of `apply_fn`.
      model_dir: Directory where the TensorflowJS model will be written to.
      polymorphic_shapes: If given then it will be passed as the
        `polymorphic_shapes` argument of `jax2tf.convert`. In this case, the
        matching `input_signatures` should have `None` in the polymorphic
        (dynamic) dimensions.
    """
    tf_fn = jax2tf.convert(
        apply_fn,
        # Gradients must be included as 'PreventGradient' is not supported.
        with_gradient=True,
        polymorphic_shapes=polymorphic_shapes,
        # Do not use TFXLA Ops because these aren't supported by TFjs, but use
        # workarounds instead. More information:
        # https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops
        enable_xla=False)
    # Create tf.Variables for the parameters. If you want more useful variable
    # names, you can use `tree.map_structure_with_path` from the `dm-tree`
    # package.
    # For this demo we presume that the variables are inaccessible:
    param_vars = []
    # param_vars = tf.nest.map_structure(
    #     lambda param: tf.Variable(param, trainable=False), params)
    # Do not use TF's jit compilation on the function.
    tf_graph = tf.function(
        lambda *xs: tf_fn(*xs), autograph=False, jit_compile=False)
    # This signature is needed for TensorFlow Serving use.
    signatures = {
        _TF_SERVING_KEY: tf_graph.get_concrete_function(*input_signatures)
    }
    wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
    saved_model_options = tf.saved_model.SaveOptions(
        experimental_custom_gradients=True)
    # The SavedModel is only an intermediate artifact, so write it to a temporary directory
    with tempfile.TemporaryDirectory() as saved_model_dir:
        tf.saved_model.save(
            wrapper,
            saved_model_dir,
            signatures=signatures,
            options=saved_model_options)
        saved_model_conversion.convert_tf_saved_model(saved_model_dir, model_dir, skip_op_check=True)
Finally, let’s save our energy_and_projection_step function in a TFJS binary that can be loaded from JavaScript.
In [24]:
model_output_dir = Path("hopfield_model")
print(f"Converting to TFJS model in {model_output_dir}")
convert_jax(
    apply_fn=stepfn,
    input_signatures=[tf.TensorSpec((D,)), tf.TensorSpec(tuple(), dtype=tf.int32), tf.TensorSpec(tuple(), dtype=tf.float32)],
    model_dir=model_output_dir)
# Save model config
print("Caching projection of original headshots")
listW2 = W2.tolist()
for i, person in enumerate(people):
    person['proj2d'] = listW2[i]
config = {}
config['people'] = people
config['model_dir'] = str(model_output_dir)
config['img_size'] = list(img_shape)
config['Nsamples'] = Nsamples
config['nelements'] = img_nelements
print("Saving configuration")
with open(model_output_dir / "config.json", "w") as fp:
    json.dump(config, fp)
print("DONE")
Converting to TFJS model in hopfield_model
WARNING:absl:Importing a function (__inference_internal_grad_fn_2630) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_internal_grad_fn_2898) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
Caching projection of original headshots
Saving configuration
DONE
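As a quick sanity check on what the frontend receives, the configuration can be read back and inspected. The sketch below rebuilds a small config with the same keys as the cell above and round-trips it through JSON in a temporary directory. The two people, their `name` field, the projections, and the 64×64×3 image size are hypothetical placeholders, not values from the workshop data.

```python
import json
import math
import tempfile
from pathlib import Path

# Hypothetical stand-ins for the parsed YAML entries and their 2-D PCA projections
people = [
    {"name": "Person A", "headshot": "a.png"},
    {"name": "Person B", "headshot": "b.png"},
]
projections = [[0.1, -0.2], [0.3, 0.4]]
for person, proj in zip(people, projections):
    person["proj2d"] = proj

img_size = [64, 64, 3]  # assumed (height, width, channels)
config = {
    "people": people,
    "model_dir": "hopfield_model",
    "img_size": img_size,
    "Nsamples": len(people),
    "nelements": math.prod(img_size),
}

# Write the config to disk and load it again, as a consumer of the file would
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.json"
    path.write_text(json.dumps(config))
    loaded = json.loads(path.read_text())

print(loaded["nelements"], [p["proj2d"] for p in loaded["people"]])
```

Everything in the config is plain lists, strings, and numbers, which is why the projections are converted with `tolist()` before saving: NumPy arrays are not JSON-serializable.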