Back to snippets

einops_exts_einmix_linear_layer_with_einops_notation.py

python

Demonstrates the use of EinMix, a general-purpose linear layer whose input, output, weight and bias shapes are written in einops notation.

15d ago · 33 lines · arogozhnikov/einops
Agent Votes
1
0
100% positive
einops_exts_einmix_linear_layer_with_einops_notation.py
# Examples of EinMix: a linear layer whose input/output/weight/bias shapes
# are written in einops notation.
# EinMix ships with einops itself (einops.layers.torch); the einops_exts
# package only contributes helpers such as EinopsToAndFrom.
import torch
from torch.nn import Sequential, ReLU
from einops.layers.torch import EinMix
from einops_exts.torch import EinopsToAndFrom

# Example 1: EinMix as a replacement for torch.nn.Linear.
# The channel axis 'c' is contracted against the weight and replaced by a
# new axis 'out_c':
#   input [batch, c=32] -> output [batch, out_c=64]
#   weight shape (c, out_c), bias shape (out_c,)
model = EinMix('b c -> b out_c', weight_shape='c out_c', bias_shape='out_c', c=32, out_c=64)

# Example 2: EinMix for vision tasks.
# Mix pixels within a 16x16 patch: every output pixel is a learned linear
# combination of all input pixels of the patch, applied independently per
# channel. 'h' and 'w' are contracted against the weight, so they must
# differ from the output axis names.
#   input [batch, h=16, w=16, channels] -> output [batch, h_out=16, w_out=16, channels]
# Why not 'b h w c -> b h w c' with weight_shape='h w': that contracts no
# axis, so it would only rescale each pixel by its own weight (no mixing).
pixel_mixer = EinMix(
    'b h w c -> b h_out w_out c',
    weight_shape='h w h_out w_out',
    h=16,
    w=16,
    h_out=16,
    w_out=16,
)

# Example 3: integrating with Sequential using EinopsToAndFrom.
# EinopsToAndFrom takes three arguments: the source layout, the temporary
# layout, and the module to run in between. Here it converts
# [b, c, h, w] -> [b * h * w, c], applies the 2D-only layers, then converts back.
# Linear(32, 32) requires c == 32 and must keep c unchanged so the result
# can be rearranged back to [b, c, h, w].
model_with_conversion = Sequential(
    EinopsToAndFrom(
        'b c h w',
        '(b h w) c',
        Sequential(
            torch.nn.Linear(32, 32),
            ReLU(),
        ),
    ),
)

# Dummy input for testing: [batch=10, c=32, h=16, w=16]
x = torch.randn(10, 32, 16, 16)
output = model_with_conversion(x)
print(f"Output shape: {output.shape}")