在PyTorch中,可以使用torch.nn.ModuleList来实现类似于torch.nn.Sequential容器的并行模拟。以下是一个示例代码:
import torch
import torch.nn as nn
class ParallelModule(nn.Module):
def __init__(self, modules):
super(ParallelModule, self).__init__()
self.modules = nn.ModuleList(modules)
def forward(self, x):
outputs = [module(x) for module in self.modules]
return torch.cat(outputs, dim=1)
# 创建两个子模块
module1 = nn.Linear(10, 5)
module2 = nn.Linear(10, 3)
# 创建并行模块
parallel_module = ParallelModule([module1, module2])
# 输入数据
x = torch.randn(2, 10)
# 前向传播
output = parallel_module(x)
print(output)
在上面的示例中,我们创建了两个线性模块module1
和module2
,然后使用ParallelModule
将它们包装在一起。在forward
方法中,我们并行地应用每个子模块到输入x
上,并将它们的输出连接在一起。最后,我们可以通过调用parallel_module(x)
来进行前向传播,并得到最终的输出。
请注意,要使用并行模块进行训练,您需要确保每个子模块的输入和输出维度是匹配的。在上面的示例中,两个子模块的输入和输出维度都是相同的。