본문 바로가기
ML, DL/pytorch

[Instance Segmentation] Train code

by Wordbe 2019. 7. 25.
728x90
def train_model(self, train_dataset, val_dataset, learning_rate, epochs):
    """Train the network from epoch ``self.epoch + 1`` through ``epochs``.

    Each epoch runs ``train_epoch`` on the training set and ``valid_epoch``
    on the validation set. It then appends both loss tuples to
    ``self.loss_history`` / ``self.val_history``, re-plots the loss curves
    and writes a checkpoint.

    Args:
        train_dataset: dataset wrapped by ``Dataset`` for training.
        val_dataset: dataset wrapped by ``Dataset`` for validation.
        learning_rate: SGD learning rate.
        epochs: index of the last epoch to train (inclusive), not a count.
            If it is not greater than ``self.epoch``, nothing is trained.
    """
    # batch_size=1: the loader yields one image at a time. train_epoch
    # accumulates gradients over config.BATCH_SIZE images per optimizer step.
    # NOTE(review): Dataset(dataset, config, augment=...) signature follows
    # pytorch-mask-rcnn -- confirm against the project's Dataset class.
    train_set = Dataset(train_dataset, self.config, augment=True)
    train_generator = torch.utils.data.DataLoader(
        train_set, batch_size=1, shuffle=True, num_workers=4)
    # The validation loader must exist: valid_epoch() consumes it below.
    # No augmentation and a fixed order keep validation losses comparable
    # from epoch to epoch.
    val_set = Dataset(val_dataset, self.config, augment=False)
    val_generator = torch.utils.data.DataLoader(
        val_set, batch_size=1, shuffle=False, num_workers=4)

    # Optimize only the parameters that are currently trainable.
    # TODO(review): the post elides the parameter groups ("params_dict");
    # pytorch-mask-rcnn additionally applies config.WEIGHT_DECAY to all
    # non-batch-norm parameters here.
    trainable = [p for p in self.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable, lr=learning_rate,
                          momentum=self.config.LEARNING_MOMENTUM)

    for epoch in range(self.epoch + 1, epochs + 1):
        print("Epoch {}/{}.".format(epoch, epochs))

        # Both calls return (total, rpn_class, rpn_bbox, mrcnn_class,
        # mrcnn_bbox, mrcnn_mask) averaged losses.
        train_losses = self.train_epoch(
            train_generator, optimizer, self.config.STEPS_PER_EPOCH)
        val_losses = self.valid_epoch(
            val_generator, self.config.VALIDATION_STEPS)

        # Statistics
        self.loss_history.append(list(train_losses))
        self.val_history.append(list(val_losses))
        visualize.plot_loss(self.loss_history, self.val_history,
                            save=True, log_dir=self.log_dir)

        # Save a checkpoint after every epoch.
        torch.save(self.state_dict(), self.checkpoint_path.format(epoch))

    # Record where training stopped so a later call resumes from there.
    # max() prevents a smaller `epochs` from moving the counter backwards.
    self.epoch = max(self.epoch, epochs)


def train_epoch(self, datagenerator, optimizer, steps):
    """Train for one epoch of at most ``steps`` images.

    The loader yields one image per iteration. Gradients are accumulated and
    ``optimizer.step()`` runs once every ``self.config.BATCH_SIZE`` images,
    plus once at the end for any leftover partial batch.

    Args:
        datagenerator: iterable (e.g. a DataLoader) whose items hold, in
            order: images, image_metas, rpn_match, rpn_bbox, gt_class_ids,
            gt_boxes, gt_masks.
        optimizer: torch optimizer over the trainable parameters.
        steps: maximum number of images to process in this epoch.

    Returns:
        Tuple of six floats: total, rpn_class, rpn_bbox, mrcnn_class,
        mrcnn_bbox and mrcnn_mask loss. Each is averaged over the images
        actually processed; all are 0.0 if none were.
    """
    import itertools

    loss_sums = [0.0] * 6
    num_steps = 0
    batch_count = 0
    # Start from clean gradients so nothing leaks in from a previous call.
    optimizer.zero_grad()

    # islice stops after `steps` images without fetching an extra batch.
    for step, inputs in enumerate(itertools.islice(datagenerator, steps)):
        (images, image_metas, rpn_match, rpn_bbox,
         gt_class_ids, gt_boxes, gt_masks) = inputs[:7]

        # predict() receives image_metas as a numpy array.
        image_metas = image_metas.numpy()

        # Since PyTorch 0.4 tensors carry autograd state themselves, so the
        # old Variable(...) wrappers are unnecessary.
        if self.config.GPU_COUNT:
            images = images.cuda()
            rpn_match = rpn_match.cuda()
            rpn_bbox = rpn_bbox.cuda()
            gt_class_ids = gt_class_ids.cuda()
            gt_boxes = gt_boxes.cuda()
            gt_masks = gt_masks.cuda()

        # Run object detection in training mode.
        # NOTE(review): output order follows pytorch-mask-rcnn's
        # predict(mode='training') and compute_losses() -- confirm.
        (rpn_class_logits, rpn_pred_bbox, target_class_ids,
         mrcnn_class_logits, target_deltas, mrcnn_bbox,
         target_mask, mrcnn_mask) = self.predict(
            [images, image_metas, gt_class_ids, gt_boxes, gt_masks],
            mode='training')

        # Compute losses
        (rpn_class_loss, rpn_bbox_loss, mrcnn_class_loss,
         mrcnn_bbox_loss, mrcnn_mask_loss) = compute_losses(
            rpn_match, rpn_bbox, rpn_class_logits, rpn_pred_bbox,
            target_class_ids, mrcnn_class_logits, target_deltas,
            mrcnn_bbox, target_mask, mrcnn_mask)
        loss = (rpn_class_loss + rpn_bbox_loss + mrcnn_class_loss
                + mrcnn_bbox_loss + mrcnn_mask_loss)

        # Backpropagation: accumulate gradients, step every BATCH_SIZE images.
        loss.backward()
        batch_count += 1
        if batch_count % self.config.BATCH_SIZE == 0:
            # Clip once on the fully accumulated gradient, right before the
            # update. Clipping after every backward() would repeatedly
            # shrink the partial sum.
            torch.nn.utils.clip_grad_norm_(self.parameters(), 5.0)
            optimizer.step()
            optimizer.zero_grad()
            batch_count = 0

        # .item() replaces the pre-0.4 `.data.cpu()[0]`, which fails on
        # 0-dim loss tensors.
        step_losses = [loss.item(), rpn_class_loss.item(),
                       rpn_bbox_loss.item(), mrcnn_class_loss.item(),
                       mrcnn_bbox_loss.item(), mrcnn_mask_loss.item()]

        # Progress
        suffix = ("Complete - loss: {:.5f} - rpn_class_loss: {:.5f} - "
                  "rpn_bbox_loss: {:.5f} - mrcnn_class_loss: {:.5f} - "
                  "mrcnn_bbox_loss: {:.5f} - mrcnn_mask_loss: {:.5f}"
                  ).format(*step_losses)
        printProgressBar(step + 1, steps,
                         prefix="\t{}/{}".format(step + 1, steps),
                         suffix=suffix, length=10)

        # Statistics
        loss_sums = [total + value for total, value in zip(loss_sums, step_losses)]
        num_steps += 1

    # Apply a leftover partial batch instead of silently discarding it.
    if batch_count:
        torch.nn.utils.clip_grad_norm_(self.parameters(), 5.0)
        optimizer.step()
        optimizer.zero_grad()

    # Return only after the whole epoch; the original returned inside the loop.
    if num_steps == 0:
        return (0.0,) * 6
    return tuple(total / num_steps for total in loss_sums)

def valid_epoch(): 위의 train_epoch와 구조가 비슷함 (검증 단계이므로 loss.backward()와 optimizer.step()은 하지 않음)

728x90

'ML, DL > pytorch' 카테고리의 다른 글

[Pytorch] 1. 파이토치를 써야하는 이유 & 텐서란  (588) 2020.02.24
GAN  (0) 2019.07.16
[Pytorch] 데이터 불러오기 및 처리  (261) 2019.07.11

댓글