# Source code for cl_gym.backbones.cnn

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN1D(nn.Module):
    """A small 1-D CNN used for time series tasks.

    The network applies two ``Conv1d`` + ReLU stages, optional dropout, a max
    pooling layer, and a two-layer fully connected head.

    Args:
        dropout_prob: Dropout probability applied after the second
            convolution. A falsy value (``0.0``) skips the dropout layer.
        in_channels: Number of features per time step in the input.
        num_classes: Size of the output layer.
        seq_len: Number of time steps the model is built for. It determines
            the input width of the first fully connected layer.

    Raises:
        ValueError: If ``seq_len`` is too short to leave at least one time
            step after the two convolutions and the pooling layer.

    The defaults (``in_channels=18``, ``num_classes=12``, ``seq_len=256``)
    reproduce the original fixed architecture, whose first fully connected
    layer takes ``64 * 62`` inputs.
    """

    def __init__(self, dropout_prob=0.0, *, in_channels=18, num_classes=12,
                 seq_len=256):
        super().__init__()
        conv_channels = 64
        kernel_size = 5
        pool_size = 4
        hidden_units = 200

        # Each unpadded, stride-1 convolution shortens the sequence by
        # (kernel_size - 1); the pooling layer then floor-divides the length.
        pooled_len = (seq_len - 2 * (kernel_size - 1)) // pool_size
        if pooled_len < 1:
            raise ValueError(
                f"seq_len={seq_len} is too short for two convolutions with "
                f"kernel size {kernel_size} followed by pooling of size "
                f"{pool_size}"
            )

        self.dropout_prob = dropout_prob
        # Attribute names are kept as-is so existing state dicts still load.
        self.conv1 = nn.Conv1d(in_channels, conv_channels, kernel_size)
        self.conv2 = nn.Conv1d(conv_channels, conv_channels, kernel_size)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.pool2 = nn.MaxPool1d(pool_size)
        self.relu = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(conv_channels * pooled_len, hidden_units)
        self.fc4 = nn.Linear(hidden_units, num_classes)

    def forward(self, x, task_id=None):
        """Compute the output scores for a batch of sequences.

        Args:
            x: Tensor laid out as ``(batch, seq_len, in_channels)``; it is
                transposed to channels-first before the convolutions.
            task_id: Not used by this model.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # Conv1d expects (batch, channels, length).
        out = torch.transpose(x, 1, 2)
        out = self.relu(self.conv1(out))
        out = self.relu(self.conv2(out))
        if self.dropout_prob:
            out = self.dropout(out)
        out = self.pool2(out)
        # flatten keeps the batch dimension even for an empty batch, where
        # view(batch_size, -1) cannot infer the remaining size.
        out = torch.flatten(out, start_dim=1)
        out = self.relu(self.fc3(out))
        return self.fc4(out)