728x90
def train_model(self, train_dataset, val_dataset, learning_rate, epochs,
                momentum=0.9, params=None,
                checkpoint_path="checkpoint_epoch{:04d}.pth"):
    """Train from epoch ``self.epoch + 1`` through ``epochs`` (inclusive).

    Every epoch runs ``self.config.STEPS_PER_EPOCH`` training steps and
    ``self.config.VALIDATION_STEPS`` validation steps. It then appends both
    loss tuples to ``self.loss_history`` / ``self.val_history``, plots them
    and writes a checkpoint.

    Args:
        train_dataset: Raw training dataset, wrapped by ``Dataset``.
        val_dataset: Raw validation dataset, wrapped by ``Dataset``.
        learning_rate: SGD learning rate.
        epochs: Last epoch to train (1-based, inclusive).
        momentum: SGD momentum.
        params: Parameters or param-group dicts handed to SGD. Defaults to
            every parameter with ``requires_grad=True``.
        checkpoint_path: ``str.format`` template that receives the epoch
            number and names the checkpoint file written after each epoch.
    """
    import torch

    # NOTE(review): the post elides the remaining Dataset arguments; this
    # assumes the upstream pytorch-mask-rcnn signature
    # Dataset(dataset, config, augment=...) -- confirm against the project.
    train_set = Dataset(train_dataset, self.config, augment=True)
    train_generator = torch.utils.data.DataLoader(
        train_set, batch_size=1, shuffle=True, num_workers=4)
    # The validation loader must exist: the loop below hands it to
    # valid_epoch every epoch.
    val_set = Dataset(val_dataset, self.config, augment=False)
    val_generator = torch.utils.data.DataLoader(
        val_set, batch_size=1, shuffle=True, num_workers=4)

    if params is None:
        params = [p for p in self.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=learning_rate, momentum=momentum)

    for epoch in range(self.epoch + 1, epochs + 1):
        print("Epoch {}/{}.".format(epoch, epochs))

        # Training: six losses in the order
        # (total, rpn_class, rpn_bbox, mrcnn_class, mrcnn_bbox, mrcnn_mask).
        train_losses = self.train_epoch(
            train_generator, optimizer, self.config.STEPS_PER_EPOCH)

        # Validation: the same six losses.
        val_losses = self.valid_epoch(
            val_generator, self.config.VALIDATION_STEPS)

        # Statistics
        self.loss_history.append(list(train_losses))
        self.val_history.append(list(val_losses))
        visualize_plot_loss(self.loss_history, self.val_history)

        # Save, then advance self.epoch so an interrupted run resumes after
        # the last checkpoint instead of repeating finished epochs.
        torch.save(self.state_dict(), checkpoint_path.format(epoch))
        self.epoch = epoch
def train_epoch(self, datagenerator, optimizer, steps):
    """Run at most ``steps`` training batches and return the mean losses.

    Gradients are accumulated over ``self.config.BATCH_SIZE`` loader batches
    before each optimizer step. The accumulated gradient is clipped to an L2
    norm of 5.0 once, right before the step; clipping after every backward
    would clip partial sums instead of the full gradient.

    Args:
        datagenerator: Iterable of 7-item batches
            (images, image_metas, rpn_match, rpn_bbox, gt_class_ids,
            gt_boxes, gt_masks).
        optimizer: Optimizer over this model's parameters.
        steps: Maximum number of batches to consume.

    Returns:
        Six floats, each averaged over the batches actually run:
        (total, rpn_class, rpn_bbox, mrcnn_class, mrcnn_bbox, mrcnn_mask).
    """
    import itertools
    import torch

    sums = [0.0] * 6
    done = 0
    pending = 0  # backward passes accumulated since the last optimizer step
    # Discard gradients left over from a previous partial accumulation.
    optimizer.zero_grad()

    # islice honours `steps` without pulling an extra batch from the loader.
    for inputs in itertools.islice(datagenerator, steps):
        (rpn_class_loss, rpn_bbox_loss, mrcnn_class_loss,
         mrcnn_bbox_loss, mrcnn_mask_loss) = self._batch_losses(inputs)
        loss = (rpn_class_loss + rpn_bbox_loss + mrcnn_class_loss
                + mrcnn_bbox_loss + mrcnn_mask_loss)

        # Backpropagation with gradient accumulation
        loss.backward()
        pending += 1
        if pending >= self.config.BATCH_SIZE:
            torch.nn.utils.clip_grad_norm_(self.parameters(), 5.0)
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        # Statistics: .item() yields plain floats, so no graph is retained.
        done += 1
        values = [loss.item(), rpn_class_loss.item(), rpn_bbox_loss.item(),
                  mrcnn_class_loss.item(), mrcnn_bbox_loss.item(),
                  mrcnn_mask_loss.item()]
        sums = [total + value for total, value in zip(sums, values)]

        # Progress
        printProgressBar(done, steps,
                         prefix="\t{}/{}".format(done, steps),
                         suffix="Complete - loss: {:.5f}".format(values[0]),
                         length=10)

    return tuple(total / max(done, 1) for total in sums)


def valid_epoch(self, datagenerator, steps):
    """Evaluate at most ``steps`` batches without updating any weights.

    Returns the same six mean losses as ``train_epoch``.
    """
    import itertools
    import torch

    sums = [0.0] * 6
    done = 0
    # No backward pass here, so skip building the autograd graph.
    with torch.no_grad():
        for inputs in itertools.islice(datagenerator, steps):
            losses = self._batch_losses(inputs)
            loss = losses[0] + losses[1] + losses[2] + losses[3] + losses[4]

            done += 1
            values = [loss.item()] + [part.item() for part in losses]
            sums = [total + value for total, value in zip(sums, values)]

            printProgressBar(done, steps,
                             prefix="\t{}/{}".format(done, steps),
                             suffix="Complete - val_loss: {:.5f}".format(values[0]),
                             length=10)

    return tuple(total / max(done, 1) for total in sums)


# The write-up names the validation pass `val_epoch`; keep both names.
val_epoch = valid_epoch


def _batch_losses(self, inputs):
    """Run the training-mode forward pass on one loader batch.

    Returns the five loss tensors produced by ``compute_losses``:
    (rpn_class, rpn_bbox, mrcnn_class, mrcnn_bbox, mrcnn_mask).
    """
    (images, image_metas, rpn_match, rpn_bbox,
     gt_class_ids, gt_boxes, gt_masks) = inputs[:7]

    # predict() receives the image metadata as a numpy array.
    image_metas = image_metas.numpy()

    # Variable() wrappers are unnecessary since PyTorch 0.4: tensors carry
    # autograd state themselves.
    if self.config.GPU_COUNT:
        images = images.cuda()
        rpn_match = rpn_match.cuda()
        rpn_bbox = rpn_bbox.cuda()
        gt_class_ids = gt_class_ids.cuda()
        gt_boxes = gt_boxes.cuda()
        gt_masks = gt_masks.cuda()

    # NOTE(review): output/argument order below follows upstream
    # pytorch-mask-rcnn's predict(mode='training') and compute_losses --
    # confirm against the project.
    (rpn_class_logits, rpn_pred_bbox, target_class_ids, mrcnn_class_logits,
     target_deltas, mrcnn_bbox, target_mask, mrcnn_mask) = self.predict(
        [images, image_metas, gt_class_ids, gt_boxes, gt_masks],
        mode='training')

    return compute_losses(rpn_match, rpn_bbox, rpn_class_logits,
                          rpn_pred_bbox, target_class_ids, mrcnn_class_logits,
                          target_deltas, mrcnn_bbox, target_mask, mrcnn_mask)
728x90
'ML, DL > pytorch' 카테고리의 다른 글

| 글 | 날짜 |
|---|---|
| [Pytorch] 1. 파이토치를 써야하는 이유 & 텐서란 (588) | 2020.02.24 |
| GAN (0) | 2019.07.16 |
| [Pytorch] 데이터 불러오기 및 처리 (261) | 2019.07.11 |
댓글