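from allennlp.data.iterators import BasicIterator
from allennlp.training.trainer import Trainer

# Reduce memory use with a small batch. The batch size belongs on the iterator;
# Trainer does not accept a batch_size argument.
iterator = BasicIterator(batch_size=2)
iterator.index_with(vocab)  # vocab: the Vocabulary built from your data

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_instances,  # Instances, e.g. train_reader.read(train_path)
    validation_dataset=val_instances,
    num_epochs=10,
    cuda_device=0,
)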
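from apex import amp

# Mixed precision (O1). Run this before constructing the Trainer; apex expects the model on the GPU first.
model = model.cuda(0)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
# Loss scaling also requires backward() to go through amp.scale_loss(loss, optimizer).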
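Part of why O1 needs care: fp16 cannot represent very small values, so tiny gradients can round to zero, and apex counters this with loss scaling. The stdlib-only sketch below (no apex or PyTorch required) round-trips values through IEEE half precision using `struct`'s `e` format to show the effect. The gradient value and the fixed scale of 2**16 are illustrative assumptions; apex chooses and adjusts its loss scale dynamically.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8          # a small gradient value, chosen for illustration
scale = 2.0 ** 16    # illustrative fixed loss scale (apex adjusts its scale dynamically)

unscaled_fp16 = to_fp16(grad)          # below fp16's smallest subnormal: rounds to 0.0
scaled_fp16 = to_fp16(grad * scale)    # ~6.55e-4 is well inside fp16's range
recovered = scaled_fp16 / scale        # unscale in full precision

print(unscaled_fp16)  # 0.0
print(recovered)      # close to 1e-8 (small fp16 rounding error)
```

Without scaling the gradient is lost entirely; with scaling it survives with only a small relative rounding error.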
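from allennlp.common import Params

# Switch to a smaller pretrained model in the config (the exact key depends on your config).
model_parameters = Params.from_file("path/to/model.jsonnet")
model_parameters["model_type"] = "bert_base"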
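Why a smaller model helps: the memory for weights, gradients and optimizer state grows linearly with parameter count. The back-of-the-envelope sketch below assumes fp32 training with an Adam-style optimizer (about 16 bytes per parameter), the commonly cited sizes of roughly 110M parameters for BERT-base and 340M for BERT-large, and that the original config used the large model. Activations, which also grow with batch size, are not counted.

```python
# Rough per-parameter memory for fp32 training with an Adam-style optimizer
# (assumption): 4 bytes weights + 4 bytes gradients + 8 bytes for two moment buffers.
BYTES_PER_PARAM = 4 + 4 + 8

# Approximate, commonly cited parameter counts (assumption for illustration).
PARAM_COUNTS = {
    "bert_base": 110_000_000,
    "bert_large": 340_000_000,
}


def static_training_memory_gb(num_params: int) -> float:
    """Weights + gradients + optimizer state, in decimal gigabytes."""
    return num_params * BYTES_PER_PARAM / 1e9


for name, n in PARAM_COUNTS.items():
    print(f"{name}: ~{static_training_memory_gb(n):.2f} GB before activations")
# bert_base: ~1.76 GB before activations
# bert_large: ~5.44 GB before activations
```

Under these assumptions the base model needs roughly a third of the static memory of the large one.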
Finally, you can move training/fine-tuning to a larger machine (one with more GPU memory) to reduce OOM errors.