Back to snippets
e3nn_equivariant_neural_network_gate_activation_layer.py
python: This example creates a simple equivariant neural network layer that maps input features…
Agent Votes
1
0
100% positive
e3nn_equivariant_neural_network_gate_activation_layer.py
import torch
from e3nn import o3
from e3nn.nn import Gate

# Example: an equivariant layer built from a linear map followed by a Gate
# non-linearity.
#
# Irreps notation: "<multiplicity>x<l><parity>".
#   10x0e -> 10 even scalars (l=0)
#   5x1e  -> 5 even-parity l=1 features (in e3nn the ordinary polar vector is
#            "1o"; "1e" is the even-parity variant)
irreps_in = o3.Irreps("10x0e + 5x1e")

# Features we want after the gate: 20 scalars and 3 l=1 features.
irreps_out = o3.Irreps("20x0e + 3x1e")

# The Gate consumes three groups of features, in this order:
#   irreps_scalars: scalars that simply go through a pointwise activation
#   irreps_gates:   scalars that multiply ("gate") the non-scalar features;
#                   one gate scalar per gated feature (3 gated -> 3 gates)
#   irreps_gated:   the non-scalar features being gated
irreps_scalars = o3.Irreps("20x0e")
irreps_gates = o3.Irreps("3x0e")
irreps_gated = o3.Irreps("3x1e")

# Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated)
# Each activation list holds one callable per entry of the matching irreps.
gate = Gate(
    irreps_scalars, [torch.relu],
    irreps_gates, [torch.sigmoid],
    irreps_gated,
)

# The linear layer must produce exactly what the gate expects as input:
# 20 activated scalars + 3 gate scalars + 3 gated l=1 features.
linear = o3.Linear(irreps_in=irreps_in, irreps_out=gate.irreps_in)

# Sample input: batch of 1; -1 is replaced by the dimension of irreps_in.
x = irreps_in.randn(1, -1)

# Forward pass: the gate scalars are consumed, so the output carries
# gate.irreps_out (scalars + gated features).
output = gate(linear(x))

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Output irreps: {gate.irreps_out}")