---
id: wiki-2026-0508-gnn
title: Graph Neural Networks (GNN)
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [GNN, graph neural network, GCN, GAT, message passing, PyG, DGL]
duplicate_of: none
source_trust_level: A
confidence_score: 0.97
verification_status: applied
tags: [machine-learning, gnn, graph-neural-network, gcn, gat, message-passing, pyg]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch Geometric / DGL
---

# Graph Neural Networks (GNN)

## One-line summary

> **"Message passing over graphs."** GNNs learn from node, edge, and global (graph-level) features. Key architectures: GCN (Kipf & Welling 2017), GAT, GraphSAGE, and GIN, all instances of the message-passing framework. Applications: social networks, drug discovery, molecules (AlphaFold), traffic, and graph reasoning with LLMs.

## Core

### Tasks

- **Node classification**: predict a label for each node.
- **Link prediction**: score the likelihood that an edge exists.
- **Graph classification**: predict a label for an entire graph.
- **Graph regression**: predict a continuous value for an entire graph.
- **Generation**: generative models that produce new graphs.

### Layer families

- **GCN** (Kipf & Welling 2017): derived from spectral graph convolution, computed as message passing.
- **GAT**: attention-weighted neighbor aggregation.
- **GraphSAGE**: aggregation over a sampled neighborhood.
- **GIN** (Xu et al. 2019): as expressive as the 1-WL graph isomorphism test, the upper bound for standard message-passing GNNs.
- **Transformer-based**: GraphTransformer, Graphormer.
- **Message Passing Neural Network** (the general framework).

### Modern directions

- **Geometric deep learning** (Bronstein et al.).
- **Equivariant GNNs** (E(3), SE(3)).
- **AlphaFold 3** (geometric deep learning).
- **GNN + LLM** (graph reasoning).

### Applications

1. **Social networks**: fraud detection, recommendation.
2. **Molecules**: drug and materials discovery.
3. **Knowledge graphs**: reasoning.
4. **Traffic**: ETA prediction.
5. **Recommender systems**.
6. **Combinatorial optimization** (TSP, scheduling).
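The message-passing idea behind GCN can be made concrete with the layer rule from Kipf & Welling (2017), $H^{(l+1)} = \sigma(\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}H^{(l)}W^{(l)})$, where $\hat{A} = A + I$ adds self-loops and $\hat{D}$ is the degree matrix of $\hat{A}$. The sketch below is a dense, forward-only NumPy illustration on a hypothetical 3-node path graph; the helper names `normalized_adjacency` and `gcn_layer` are made up for this example and are not part of PyG or DGL.

```python
import numpy as np


def normalized_adjacency(edges, n_nodes):
    """Symmetric GCN normalization D^-1/2 (A + I) D^-1/2 for an undirected graph."""
    a = np.zeros((n_nodes, n_nodes))
    for u, v in edges:
        a[u, v] = 1.0
        a[v, u] = 1.0
    a_hat = a + np.eye(n_nodes)        # self-loops: each node keeps its own features
    deg = a_hat.sum(axis=1)            # degree including the self-loop
    d_inv_sqrt = np.diag(deg ** -0.5)
    return d_inv_sqrt @ a_hat @ d_inv_sqrt


def gcn_layer(a_norm, h, w):
    """One GCN layer: aggregate normalized neighbor features, project, apply ReLU."""
    return np.maximum(a_norm @ h @ w, 0.0)


# Hypothetical path graph 0 - 1 - 2 with 2-dimensional node features
edges = [(0, 1), (1, 2)]
h = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
w = np.eye(2)                          # identity weights expose the aggregation itself
a_norm = normalized_adjacency(edges, 3)
out = gcn_layer(a_norm, h, w)
print(out.round(3))
```

Dense matrices only scale to small graphs; PyG's `GCNConv` applies the same self-loop and symmetric normalization by default but operates on a sparse `edge_index`.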
## 💻 Patterns

### GCN (PyG)

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self, in_feat, hidden, n_classes):
        super().__init__()
        self.conv1 = GCNConv(in_feat, hidden)
        self.conv2 = GCNConv(hidden, n_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.conv2(x, edge_index)  # raw logits per node
```

### GAT (attention)

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv


class GAT(torch.nn.Module):
    def __init__(self, in_feat, hidden, n_classes, n_heads=8):
        super().__init__()
        # n_heads attention heads, concatenated -> hidden * n_heads features
        self.conv1 = GATConv(in_feat, hidden, heads=n_heads, dropout=0.6)
        # single output head, averaged instead of concatenated
        self.conv2 = GATConv(hidden * n_heads, n_classes, heads=1, concat=False)

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)
```

### GraphSAGE (sampling)

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


class GraphSAGE(torch.nn.Module):
    def __init__(self, in_feat, hidden, out_feat):
        super().__init__()
        self.conv1 = SAGEConv(in_feat, hidden, aggr='mean')
        self.conv2 = SAGEConv(hidden, out_feat, aggr='mean')

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)
```

### Custom MessagePassing

```python
import torch
from torch_geometric.nn import MessagePassing


class CustomConv(MessagePassing):
    def __init__(self, in_feat, out_feat):
        super().__init__(aggr='mean')  # average incoming messages
        self.lin = torch.nn.Linear(in_feat, out_feat)

    def forward(self, x, edge_index):
        x = self.lin(x)
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j  # features of the neighbor (source node j)

    def update(self, aggr_out):
        return aggr_out  # identity update: use the aggregated neighborhood as-is
```

### Graph classification (read-out)

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool


class GraphClassifier(torch.nn.Module):
    def __init__(self, in_feat, n_classes, hidden=64):
        super().__init__()
        self.conv1 = GCNConv(in_feat, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.classifier = torch.nn.Linear(hidden, n_classes)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # node embeddings -> one graph-level vector per graph
        return self.classifier(x)
```

### Link prediction

```python
import torch.nn as nn


class LinkPredictor(nn.Module):
    def __init__(self, in_feat, hidden, emb_dim):
        super().__init__()
        # GCN class from the "GCN (PyG)" pattern above, used as the node encoder
        self.encoder = GCN(in_feat, hidden, emb_dim)

    def decode(self, src, dst):
        return (src * dst).sum(dim=-1)  # dot-product score (logit) per candidate edge

    def forward(self, x, edge_index, edge_label_index):
        z = self.encoder(x, edge_index)
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        return self.decode(src, dst)
```

### Sampling for large graphs (NeighborLoader)

```python
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader

# data: a torch_geometric Data object; model: e.g. the GraphSAGE above
# sample up to 15 neighbors at hop 1 and 10 at hop 2 for each seed node
loader = NeighborLoader(data, num_neighbors=[15, 10], batch_size=128,
                        input_nodes=data.train_mask)

for batch in loader:
    out = model(batch.x, batch.edge_index)
    # the first batch.batch_size nodes are the seed nodes; the rest are sampled neighbors
    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
```

### Heterogeneous (HeteroData)

```python
from torch_geometric.data import HeteroData
from torch_geometric.nn import to_hetero

data = HeteroData()
data['user'].x = user_feats    # [num_users, user_feat_dim]
data['movie'].x = movie_feats  # [num_movies, movie_feat_dim]
data['user', 'rates', 'movie'].edge_index = rate_edges  # [2, num_ratings]

# convert a homogeneous GNN into one with separate parameters per node and edge type
model = to_hetero(model, data.metadata())
```

### Equivariant GNN (E(n)-EGNN)

```python
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing


class EGNN(MessagePassing):
    """Simplified E(n)-equivariant layer, after Satorras et al. (2021)."""

    def __init__(self, dim):
        super().__init__(aggr='mean')
        self.edge_mlp = nn.Sequential(nn.Linear(2 * dim + 1, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.coord_mlp = nn.Linear(dim, 1)
        self.node_mlp = nn.Sequential(nn.Linear(2 * dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, x, pos, edge_index):
        # aggregate [edge message | coordinate update] in a single propagate call
        out = self.propagate(edge_index, x=x, pos=pos)
        k = pos.size(-1)
        m, d_pos = out[:, :-k], out[:, -k:]
        x = x + self.node_mlp(torch.cat([x, m], dim=-1))
        return x, pos + d_pos

    def message(self, x_i, x_j, pos_i, pos_j):
        rel_pos = pos_i - pos_j
        dist = (rel_pos ** 2).sum(-1, keepdim=True)  # invariant squared distance
        edge_feat = self.edge_mlp(torch.cat([x_i, x_j, dist], dim=-1))
        coord_msg = rel_pos * self.coord_mlp(edge_feat)  # equivariant: scales relative position
        return torch.cat([edge_feat, coord_msg], dim=-1)
```

### Drug discovery (molecule)

```python
from torch_geometric.datasets import MoleculeNet

dataset = MoleculeNet(root='data', name='ESOL')
# each graph: atom-level node features + bond edges; target: aqueous solubility
```

### Knowledge graph (TransE)

```python
import torch.nn as nn


class TransE(nn.Module):
    def __init__(self, n_entities, n_relations, dim):
        super().__init__()
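        self.entity_emb = nn.Embedding(n_entities, dim)
        self.relation_emb = nn.Embedding(n_relations, dim)

    def score(self, h, r, t):
        # a true triple should satisfy h + r ≈ t, so a smaller distance gives a higher score
        return -(self.entity_emb(h) + self.relation_emb(r) - self.entity_emb(t)).norm(dim=-1)
```

### Graph Transformer (Graphormer)

```python
import torch.nn as nn


class GraphTransformer(nn.Module):
    def __init__(self, dim, max_dist, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        # one learnable scalar bias per (shortest-path distance, head)
        self.spatial_bias = nn.Embedding(max_dist, n_heads)

    def forward(self, x, spatial_dist):
        # x: [B, N, dim]; spatial_dist: [B, N, N] integer shortest-path distances
        B, N, _ = x.shape
        dist = spatial_dist.clamp(max=self.spatial_bias.num_embeddings - 1)
        bias = self.spatial_bias(dist)                                   # [B, N, N, heads]
        bias = bias.permute(0, 3, 1, 2).reshape(B * self.n_heads, N, N)  # [B * heads, N, N]
        # a float attn_mask is added to the attention logits: attention with spatial bias
        attn_out, _ = self.attn(x, x, x, attn_mask=bias)
        return attn_out
```

### GNN explainer

```python
from torch_geometric.explain import Explainer, GNNExplainer

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',     # explain the model's own prediction
    node_mask_type='attributes',  # learn a mask over node features
    edge_mask_type='object',      # learn a mask over edges
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',        # the model returns logits
    ),
)
# with explanation_type='model' the target is inferred from the model's prediction;
# pass index=<node id> to explain a single node
explanation = explainer(data.x, data.edge_index)
```

## Decision criteria

| Situation | Architecture |
|---|---|
| Default | GCN |
| Heterogeneous graph | HeteroData + GAT |
| Large graph | GraphSAGE + sampling |
| Maximum expressiveness | GIN |
| Spatial / molecular | EGNN / SchNet |
| Graph-level prediction | + global pooling |
| Knowledge graph | TransE / RotatE |
| Long-range dependencies | GraphTransformer / Graphormer |

**Default**: PyG with a GCN/GAT baseline, neighbor sampling for large graphs, EGNN for geometric data, and an explainer for interpretability.

## 🔗 Graph

- Parent: [[Deep-Learning]] · [[Graph_Theory|Graph-Theory]]
- Variants: [[GCN]] · [[GAT]] · [[GIN]]
- Applications: [[Recommender-Systems]] · [[Knowledge-Graphs]]
- Adjacent: [[AlphaFold]]

## 🤖 LLM usage

**When**: graph-structured data such as social networks, molecules, and knowledge graphs.

**When not**: sequences or images (use a Transformer or CNN instead).

## ❌ Anti-patterns

- **Over-smoothing** (deep GNNs): node representations converge to nearly identical vectors as layers are stacked.
- **No batching or sampling for large graphs**: full-graph training runs out of memory (OOM).
- **Ignoring edge features**: information carried on edges is lost.
- **Always defaulting to attention**: simpler aggregators are sometimes better.
- **No handling for many classes**: long-tail label distributions leave rare classes under-learned.

## 🧪 Verification / Duplicates

- Verified (Kipf GCN 2017, Xu GIN 2019, PyG/DGL docs, AlphaFold).
- Confidence: A.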
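As a quick numerical check of the over-smoothing anti-pattern, this NumPy sketch stacks plain mean-aggregation steps (each node averages itself and its neighbors, with no learned weights or nonlinearities) on a hypothetical 4-node path graph and records how far node features sit from their average. The helper names are illustrative, not library APIs. On a connected graph the spread shrinks toward zero as depth grows, which is the mechanism that makes node representations in very deep GNNs hard to tell apart.

```python
import numpy as np


def mean_aggregate_matrix(edges, n_nodes):
    """Row-normalized (A + I): each node averages itself and its neighbors."""
    a = np.eye(n_nodes)
    for u, v in edges:
        a[u, v] = 1.0
        a[v, u] = 1.0
    return a / a.sum(axis=1, keepdims=True)


def feature_spread(h):
    """Largest absolute deviation of any node feature from the mean over nodes."""
    return np.abs(h - h.mean(axis=0)).max()


edges = [(0, 1), (1, 2), (2, 3)]   # hypothetical connected path graph 0-1-2-3
p = mean_aggregate_matrix(edges, 4)
h = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0],
              [0.0, 0.0]])

spreads = []
for depth in range(33):
    if depth in (0, 2, 8, 32):
        spreads.append((depth, feature_spread(h)))
    h = p @ h                      # one more "layer" of mean aggregation

for depth, s in spreads:
    print(f"layers={depth:2d}  max deviation from mean={s:.6f}")
```

Real GNN layers add weights and nonlinearities, so this is an idealized illustration rather than a measurement of any specific model.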
## 🕓 Changelog

| Date | Change |
|---|---|
| 2026-04-26 | GNN auto |
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: GCN/GAT/SAGE plus PyG, hetero, EGNN, link prediction, and explainer code |