This is a rough note to help me think about next steps of a project I'm working on. I'm hoping to come back and update or follow up on this post once the project is more figured out!

I recently took a course on program analysis that had an open-ended final project. My friend and I built a static analysis to infer PyTorch tensor shapes in the hopes of creating a useful tool for developers to catch shape mismatches early. What we submitted was a good proof of concept that works for a functional subset of Python/PyTorch, but can't yet analyze most PyTorch programs. I've been thinking about how to extend our analysis into a big boy analysis that can handle real Python$^\text{TM}$.

Looking around the dev tool landscape, I don't see many analyses over languages like Python that have made it out of academia. This made me a little worried that dynamic languages are just too hard.

Dynamic languages by definition have less information available at compile time, which makes it harder to write good analyses over them. As I work on turning our proof of concept into a useful tool, I want to think more about why languages like Python are so hard to analyze and what simplifying assumptions make them tractable, and then apply these thoughts to our analysis.

Our analysis

At a high level, our analysis allows users to annotate tensor parameters with shapes, and then performs a dataflow analysis to infer tensor shapes at each location in a function. Given this function,

def concat(x: T["m a"], y: T["m b"], z: T["m c"]):
    w = torch.concat((x, y, z), dim=1)

our analysis infers that w is a tensor with shape $[m, a+b+c]$.
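To make that inference concrete, here's a toy sketch of what a shape rule for torch.concat could look like. It's hypothetical: shapes are tuples of dimension-name strings, and the concatenated dimension is built by joining names with "+", which isn't necessarily how our analysis represents dimensions internally.

```python
def concat_shape(shapes: list[tuple[str, ...]], dim: int) -> tuple[str, ...]:
    """Hypothetical shape rule: infer the output shape of concatenating
    tensors (given as symbolic shapes) along `dim`."""
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        raise ValueError("all inputs must have the same rank")
    dim = dim % rank  # normalize negative dims like -1
    out = []
    for i in range(rank):
        dims = [s[i] for s in shapes]
        if i == dim:
            # The concatenated dimension is the symbolic sum of the inputs.
            out.append("+".join(dims))
        elif all(d == dims[0] for d in dims):
            # Every other dimension must agree across all inputs.
            out.append(dims[0])
        else:
            raise ValueError(f"dimension {i} mismatch: {dims}")
    return tuple(out)

print(concat_shape([("m", "a"), ("m", "b"), ("m", "c")], dim=1))  # ('m', 'a+b+c')
```

The non-concatenated dimensions double as constraints: concatenating T[m a] with T[n b] along dim 1 would raise, which is the kind of mismatch we'd like to report early.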

Going forward, I'll write T[m a] for a tensor with shape $m \times a$, where $m$ and $a$ are "dimensional variables".


Why are dynamic languages tricky to analyze?

I'm going to focus on several important questions that analysts like to be able to answer about programs. For each, I'll meditate on how dynamism gets in the way, and what tradeoffs can make answering these questions tractable again.

Here are the questions / problems:

  • Typing: What type is this variable? What attributes and methods does it have?
  • Determining Intraprocedural Control-Flow Graphs: What statements can execute after this one?
  • Heap analysis: What variables can reference this object instance, and what values does it hold?

Typing

  • "What type is this object? What attributes and methods does it have?"

Probably the most common analyses are those that attempt to determine the type of each variable, ex. whether it's a primitive type or a more complex struct/class with a set of fields and methods (with their associated types).

Typing is important for downstream analyses that need to ex. match callsites to specific method definitions (method resolution).

In a statically-typed language, each variable has exactly one type, and it's known at compile time, so we get this essentially for free.

How do types work in Python? Each variable refers to an object, which is an instance of a class.

The same variable can refer to objects of different classes over different execution paths:

def foo(a: bool):
	x = 5
	if a:
		x = "no longer an int"
	print(type(x)) # might be <class 'int'> or <class 'str'>

And unlike ex. structs in Rust, classes aren't a static container of attributes and methods that is fixed at compile-time:

class Foo: pass

def bar(a: bool):
	b = Foo()
	if a:
		b.woo = lambda x: x + 2
	
	b.woo(12) # this might dispatch to the woo added above, or to a woo
	          # attached elsewhere (ex. patched onto the Foo class); if
	          # neither exists, lookup fails (falling back to __getattr__
	          # if Foo defined one) and raises AttributeError

In Python, the set of attributes and methods an object instance has is created and mutated at runtime. Method and attribute lookup happens at runtime and can raise exceptions.
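Here's a plain-Python sketch of that runtime behavior (no analysis involved), using a lightly modified bar that returns the call's result. The class-level patch near the end is just one illustrative way a woo could get attached from somewhere else.

```python
class Foo:
    pass

def bar(a: bool) -> int:
    b = Foo()
    if a:
        # Attach a function to this one instance at runtime.
        b.woo = lambda x: x + 2
    # The attribute lookup happens here, at runtime, and can fail.
    return b.woo(12)

print(bar(True))  # 14

try:
    bar(False)
except AttributeError as err:
    print(err)  # 'Foo' object has no attribute 'woo'

# Patching the class changes what bar(False) does, without touching bar.
Foo.woo = lambda self, x: x * 10
print(bar(False))  # 120 (class-level woo, bound as a method)
print(bar(True))   # 14 (the instance attribute shadows the class one)
```

The same call site b.woo(12) resolves to three different outcomes depending on path and on code that runs elsewhere, which is exactly what makes a static answer hard.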

How do type checkers for dynamic languages work?

Let's focus on mypy, a popular static type checker for Python.

mypy assumes each variable keeps the type inferred from its first assignment, unless it's explicitly annotated with a broader type. This means that it will error on our program from before:

def foo(a: bool):
	x = 5
	if a:
		x = "no longer an int" # error: Incompatible types in assignment
	print(type(x)) 

even though this is valid Python. mypy is making the assumption that this type of dynamism is typically an error unless explicitly annotated:

def foo(a: bool):
	x: int | str = 5 # we can add annotations to recover the dynamism
	# unchanged...

# Success: no issues found

from typing import Any

def foo(a: bool):
	x: Any = 5 # there's also a catch-all 'Any' type
	# unchanged...

# Success: no issues found

mypy also disallows runtime additions to classes, assuming they're a static collection of fields and methods (hmm, that sounds familiar):

class Foo: pass

def bar(a: bool):
	b = Foo()
	if a:
		b.woo = lambda x: x + 2 # error: "Foo" has no attribute "woo"
	
	b.woo(12)

And we can again recover the dynamism with explicit typing[1]:

from typing import Callable

class Foo:
	woo: Callable[[int], int]

def bar(a: bool):
	# unchanged...

# Success: no issues found

Essentially, mypy accepts a subset of the total space of valid Python programs as a design decision. It deals with the annoying (from an analysis perspective) bits of a dynamic language by leaving them out of the semantics unless they're explicitly annotated. This obviously makes analysis much easier! And with the easy escape hatches of annotations and gradual typing, tools like mypy have chosen a decent tradeoff.

Going forward, I'm going to assume we have a mypy-style type-checked program, which makes the language we're operating on downstream almost as well behaved as a statically typed language (because we're assuming it is statically typed).

Control-Flow Analysis

  • "After this statement, what statements could execute next?"

It's often useful to transform functions into an intermediate representation called a Control-Flow Graph (CFG): a directed graph whose nodes are basic blocks, straight-line sequences of statements with no jumps into or out of their middle. Consider the function below:

def foo(a: int):
	if a > 5:
		print("a > 5")
	else:
		print("a <= 5")

This function would transform into a representation like this:

block0:
	if a > 5 jmp block1 else jmp block2
block1:
	print("a > 5")
	jmp block3
block2:
	print("a <= 5")
	jmp block3
block3:
	return None

Each block is a series of statements, ending with a terminator, which can be either a conditional/unconditional jump, or a return.

Having a representation like this is nice because it provides a standardized way to reason about control flow, rather than handling if statements, while loops, for loops, etc. each in their own way. Answering the question "What statements can execute after statement X?" is now easy. If X is not the last statement in its block, only the next statement in that block can execute next. If X is the terminator at the end of a block, the first statement of any block it may jump to could execute next.
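A minimal sketch of that query, assuming a hypothetical Block type where the last entry of stmts is the terminator and succs lists the blocks it may jump to. Statements are plain strings here, since only their positions matter.

```python
from dataclasses import dataclass, field

@dataclass
class Block:
    """A basic block: statements run in order; the last one is the terminator."""
    stmts: list[str]
    succs: list[str] = field(default_factory=list)  # blocks the terminator may jump to

def next_statements(cfg: dict[str, Block], block: str, idx: int) -> list[tuple[str, int]]:
    """(block, index) of every statement that may run right after cfg[block].stmts[idx]."""
    b = cfg[block]
    if idx + 1 < len(b.stmts):
        # Not the terminator: only the next statement in the same block can follow.
        return [(block, idx + 1)]
    # The terminator: the first statement of any successor block can follow.
    return [(s, 0) for s in b.succs]

# The CFG for foo above.
cfg = {
    "block0": Block(["if a > 5 jmp block1 else jmp block2"], ["block1", "block2"]),
    "block1": Block(['print("a > 5")', "jmp block3"], ["block3"]),
    "block2": Block(['print("a <= 5")', "jmp block3"], ["block3"]),
    "block3": Block(["return None"]),
}

print(next_statements(cfg, "block0", 0))  # [('block1', 0), ('block2', 0)]
print(next_statements(cfg, "block1", 0))  # [('block1', 1)]
print(next_statements(cfg, "block3", 0))  # []
```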

So is converting a function to a CFG harder in Python than ex. Rust?

Ignoring exception-based control flow, constructing a CFG isn't that much trickier for Python programs than Rust programs. However, when doing analysis over the CFG, we might care about determining in what cases a jump condition is truthy vs. falsy.

For the purposes of the tensor shape inference analysis I'm working on, I'm curious whether ignoring conditions entirely, and assuming any path through the CFG could be taken, is a viable approach.

Because our analysis is a dataflow analysis, it maintains a mapping from variables to the shapes they could have at each statement, and joins (unions) the possible shapes wherever multiple control-flow paths merge. Here is an example for our analysis:

def foo(a: T["b d"], flag: bool):
	if flag:
		a = torch.flatten(a)
		# a has shape [b*d]
	else:
		a = a * 2
		# a has shape [b, d] (unchanged)
	# here a can have shape [b*d] or [b, d], represented as {[b*d], [b, d]}
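A minimal sketch of the join at the end of the if/else, assuming a hypothetical representation where the abstract state maps each variable name to a frozenset of shape tuples, with the flattened dimension written as "b*d". It simplifies one case: a variable assigned on only one branch just keeps that branch's shapes instead of being tracked as possibly undefined.

```python
Shape = tuple[str, ...]
State = dict[str, frozenset[Shape]]

def join(s1: State, s2: State) -> State:
    """Merge the abstract states flowing in from two control-flow edges by
    taking the union of each variable's possible shapes."""
    return {
        var: s1.get(var, frozenset()) | s2.get(var, frozenset())
        for var in s1.keys() | s2.keys()
    }

then_state: State = {"a": frozenset({("b*d",)})}   # after torch.flatten(a)
else_state: State = {"a": frozenset({("b", "d")})}  # after a * 2
merged = join(then_state, else_state)
print(sorted(merged["a"]))  # [('b', 'd'), ('b*d',)]
```

Union is the natural join here because we're tracking every shape a variable may have, not a single most-precise one.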

We're explicitly ignoring the flag condition and assuming that either branch could execute. Then, when inferring future shapes, we run our inference model on each possible shape a variable could have, accepting the ones where dimension constraints are met. If none meet dimension constraints, we throw an error. You could also imagine throwing an error if any of the possible shapes don't satisfy a dimension constraint, but we believed this would lead to a lot of false positives.
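Here's a sketch of that filtering step. The matmul_rule and ShapeError names are hypothetical stand-ins for one of our shape rules, and the candidate set mirrors the flatten example above, with b*d standing for the flattened dimension and a hypothetical [d, k] weight on the other side.

```python
from typing import Callable

Shape = tuple[str, ...]

class ShapeError(Exception):
    """Raised when a shape rule's dimension constraints aren't met."""

def matmul_rule(x: Shape, w: Shape) -> Shape:
    """Hypothetical rule for a 2-D x @ w: the inner dimensions must match."""
    if len(x) != 2 or len(w) != 2 or x[1] != w[0]:
        raise ShapeError(f"cannot matmul {x} with {w}")
    return (x[0], w[1])

def apply_rule(rule: Callable[..., Shape], candidates: set[Shape], *args: Shape) -> set[Shape]:
    """Run `rule` on each candidate shape, keeping only results whose constraints hold."""
    results: set[Shape] = set()
    for shape in candidates:
        try:
            results.add(rule(shape, *args))
        except ShapeError:
            pass  # this candidate violates a constraint, so drop it
    if not results:
        # No candidate survived: this is the case we report as an error.
        raise ShapeError(f"no candidate in {candidates} satisfies the constraints")
    return results

# `a` may be [b*d] or [b, d]; only [b, d] can be multiplied by a [d, k] weight.
print(apply_rule(matmul_rule, {("b*d",), ("b", "d")}, ("d", "k")))  # {('b', 'k')}
```

Dropping failed candidates instead of erroring on them is what keeps false positives down, at the cost of the missed errors discussed next.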

This approach has a potential issue: the set of possible shapes for a variable could grow very large, and since we only error when no candidate satisfies a constraint, we might miss true errors. On the other hand, it probably won't produce many false positives.

While the true test will be applying our analysis to many real PyTorch programs, my hunch (hope?) is that this won't be too big of a problem in practice, since (1) most branches in tensor code change tensor values, not tensor shapes, and (2) our strategy of rejecting inferred shapes that don't meet constraints will keep the set of possible shapes relatively small.

Class & Heap Analysis

  • "What object instance does this variable reference?" and "What concrete value might this object take?"

A common pattern in PyTorch code is creating a class that represents a neural network, where the constructor parameters determine the layer sizes:

class TwoLayerMLP(nn.Module):
    def __init__(self, in_dim: D["in"], hidden_dim: D["hidden"], out_dim: D["out"]):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: T["b in"]) -> T["b out"]:
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def main():
    mlp = TwoLayerMLP(50, 100, 20)
    x = torch.zeros((100, 50))
    y = mlp(x)

We'd like to be able to model specific instances of user-defined classes like TwoLayerMLP, as well as objects like dictionaries and lists. To do this, we'll need a way to understand how the shapes within a class interact, as well as a model of the heap as a store of objects that variables reference.

In Python, everything is an object (an instance of some class); I'm going to focus on user-defined classes, as they're the most interesting.

We can introduce a new class-level analysis that will process the methods of a class. Just as type-checking a class yields a static set of fields and methods with known types, running our analysis over a class will associate relevant fields with shapes and dimensional variables, expressed in terms of the dimensional variables from the __init__ parameter annotations (in, hidden, out for TwoLayerMLP).

This relies on an important assumption in the same spirit as mypy's: all shapes can be inferred from the body of __init__ or from explicit class field annotations. This lets us significantly simplify our analysis and keep it context-insensitive.

Context-insensitive means we can analyze the body of each method/function only once, and will not have to run an analysis for each set of arguments we encounter at different callsites.

To elucidate this approach, and show why we can get away with analyzing each method only once, let's simulate the class-level analysis on TwoLayerMLP.

  1. From the typecheck phase, we know that TwoLayerMLP has the following fields and associated types (omitting methods):
    • fc1: nn.Linear
    • fc2: nn.Linear
  2. We run something very similar to our function-level analysis over the __init__ method, producing shape and dimensional variable inferences for all relevant variables.
  3. When encountering an nn.Linear constructor, if it's not in our cache of class analyses, we run a class-level analysis on it before continuing. The result of that analysis is that nn.Linear has two fields with associated shapes derived from the parameters: weight: T[out_features, in_features] and bias: T[out_features].
  4. Returning to our analysis over TwoLayerMLP, we construct a copy of the nn.Linear class with TwoLayerMLP's dimensional variables substituted for in_features, out_features
    • for fc1, this yields an nn.Linear object with weight: T[hidden, in], bias: T[hidden]
    • for fc2, this yields an nn.Linear object with weight: T[out, hidden], bias: T[out]
  5. After __init__ is analyzed, each relevant field will have a shape or dimensional variable associated with it.
  6. Then each method can be checked for internal consistency, pulling shapes from the __init__ analysis for attribute accesses.
  7. In main(), when mlp is constructed, the concrete values 50, 100, 20 are recursively substituted for the dimensional variables in, hidden, out in each field and method signature of TwoLayerMLP.[2]
  8. On the mlp(x) call in main(), we do our typical signature check against the instantiated forward signature (x: T[b 50]) -> T[b 20].
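A minimal sketch of steps 3 through 7, assuming a hypothetical Summary type that records a class's dimensional variables and its fields' symbolic shapes; method signatures are stored as extra entries like "forward:x" purely for brevity. Dimensions are strings, so the concrete sizes are just "50", "100", "20". nn.Linear's summary uses the [out_features, in_features] weight layout from step 3.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Union

Shape = tuple[str, ...]

@dataclass
class Summary:
    """Hypothetical class-level analysis result: the constructor's dimensional
    variables plus each field's symbolic shape (or a nested Summary)."""
    dims: tuple[str, ...]
    fields: dict[str, Union[Shape, Summary]]

def instantiate(cls: Summary, args: tuple[str, ...]) -> Summary:
    """Substitute call-site arguments for cls.dims, recursing into nested
    summaries. Builds a new Summary and never modifies cls, so each class's
    summary is computed once and reused at every constructor call."""
    env = dict(zip(cls.dims, args))

    def sub(value: Union[Shape, Summary]) -> Union[Shape, Summary]:
        if isinstance(value, Summary):
            return Summary(value.dims, {k: sub(v) for k, v in value.fields.items()})
        # Dims not bound in env (like the batch dim "b") stay symbolic.
        return tuple(env.get(d, d) for d in value)

    return Summary((), {name: sub(v) for name, v in cls.fields.items()})

# Step 3: nn.Linear's summary.
LINEAR = Summary(
    ("in_features", "out_features"),
    {"weight": ("out_features", "in_features"), "bias": ("out_features",)},
)

# Steps 2-5: TwoLayerMLP's summary, with nn.Linear instantiated in terms of its dims.
MLP = Summary(
    ("in", "hidden", "out"),
    {
        "fc1": instantiate(LINEAR, ("in", "hidden")),
        "fc2": instantiate(LINEAR, ("hidden", "out")),
        "forward:x": ("b", "in"),
        "forward:return": ("b", "out"),
    },
)

# Step 7: TwoLayerMLP(50, 100, 20) in main().
mlp = instantiate(MLP, ("50", "100", "20"))
print(mlp.fields["fc1"].fields["weight"])  # ('100', '50')
print(mlp.fields["forward:x"])             # ('b', '50')
```

For step 8, the instantiated forward:x shape ('b', '50') is what the argument torch.zeros((100, 50)) would be checked against, binding b to 100.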

This approach has the benefits of being nicely compositional (fast!) and relying mostly on the function-level analysis that is already built. What it's missing is the ability to deal with shape mutations caused by methods other than __init__. I'm hoping that this isn't too big of a deal in practice, and that something similar to this is a good start.

A note on handling dicts/lists

The analysis above is designed with the idea that classes get initialized with some dimensional variables and then have a set of methods that can be parameterized by those dimensional variables. Yet there are also built-in and user-defined collections that don't act like this and will need to be treated differently.

[1] Our woo annotation tells mypy to trust us that .woo on a Foo instance will resolve correctly, yet unless we monkey-patch Foo at runtime to have a woo callable, we'll still get an AttributeError at runtime when we call bar(a=False).

[2] Equivalently, you can store a reference to the original analysis along with a mapping from the class's dimensional variables to the call's argument dimensions, and apply the mapping at callsites.