Custom mini batch Loader로 원하는 데이터 mini batch에 강제하기

현재 축관 제조공정 중 결함을 감지하는 Object detection Task에서

데이터 셋 3만장 중 결함이 있는 데이터는 3%에 불과합니다.

하지만 학습환경상 최대 설정 batch size가 8이었고, mini batch 안에 결함이 들어갈 이미지가 들어갈 확률은 극도로 작아

결함데이터에 대해 학습하지 못할 가능성이 높았습니다.

따라서 custom mini batch loader를 만들어 이 문제를 해결하려 했습니다.

1) Pytorch 구현 방식

파이토치에서 기본적으로 데이터로더를 쓸 때

불러와지는 것은

  1. Dataset
  2. RandomSampler
  3. BatchSampler
  4. DataLoader

이 4가지 입니다.

pytorch에서 비복원 추출시 mini batch에 random하게 Sampling넣어 주는 방법은 아래와 같습니다.

 

  1. torch.randperm() method를 사용해 idx값을 random순열로 생성하는 방식으로 생성되고
  2. 표본의 index값을 가지고 있는 Sampler가 BatchSampler에 Iterater로 들어갑니다.

저는 최대한 이 매커니즘을 활용해 구현하려고 합니다.

# torch/utils/data/sampler.py > RandomSampler
for _ in range(self.num_samples // n):
    yield from torch.randperm(n, generator=generator).tolist()
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

 

2) 구현하고 싶은 방식

  1. mini-batch에 원하는 데이터를 강제할 수 있을 것
  2. 시간 복잡도를 최대한 낮출 것
  3. 결함 데이터는 복원 추출로, 정상 데이터는 비복원 추출로 할 것

3) CustomRandomSampler

class CustomRandomSampler(Sampler[int]):
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, 
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self._num_samples = num_samples
        self.generator = generator
        cat1_imgs = set(data_source.coco.getImgIds(catIds = data_source.coco.getCatIds()[0]))
        cat2_imgs = set(data_source.coco.getImgIds(catIds = data_source.coco.getCatIds()[1]))
        # 결함이 있는 데이터 중 1, 2 category id에 대해 추출
        self.defect_imgs = list(cat2_imgs | cat1_imgs)

        # 결함이 없는 데이터
        self.other_imgs = list(set(data_source.coco.getImgIds()) - set(cat2_imgs | cat1_imgs))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n_n = len(self.other_imgs)
        n_d = len(self.defect_imgs)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        # 결함이 있는 data일 경우 복원 추출
        if self.dtype == 'defect':
            for _ in range(self.num_samples // 32):
                yield from [self.defect_imgs[i] for i in torch.randint(high=n_d, size=(32,), dtype=torch.int64,         generator=generator).tolist()]
            yield from [self.defect_imgs[i] for i in torch.randint(high=n_d, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()]

        # 결함이 없는 data일 경우 비복원 추출
        else:
            for _ in range(n_n // n_n):
                yield from [self.other_imgs[i] for i in torch.randperm(n_n, generator=generator).tolist()]
            yield from [self.other_imgs[i] for i in torch.randperm(n_n, generator=generator).tolist()[:n_n % n_n]]

    def __len__(self) -> int:
        if self.dtype == 'defect':
            return len(self.defect_imgs)
        if self.dtype == 'normal':
            return len(self.other_imgs)
        else:
            return len(self.data_source)


    def __call__(self, dtype) -> str:
        self.dtype = dtype 
        return self

RandomSampler는 Pytorch알고리즘과 비슷하게 ImageDataset의 idx return하는 iterator형태로 만들었습니다.

하지만 iterator를 호출할 때 결함인지, 정상인지에 따라 다르게 iterator를 리턴하도록 했습니다.

 

 

4) CustomBatchSampler

class CustomBatchSampler(Sampler[List[int]]):
    '''
    fixed_size : mini_batch 중 몇개의 결함 data를 강제로 집어 넣어줄 것인지 
    '''
    def __init__(self, sampler_1: Union[Sampler[int], Iterable[int]],
                     sampler_2: Union[Sampler[int], Iterable[int]],
                   batch_size: int, fixed_size : int , drop_last = True) -> None:
        self.sampler_1 = sampler_1
        self.sampler_2 = sampler_2
        self.batch_size = batch_size
        self.fixed_size = fixed_size
        self.drop_last = drop_last
    def __iter__(self) -> Iterator[List[int]]:
        batch = [0] * self.batch_size
        idx_in_batch = 0
        fix_size = 0 
        normal_sampler = iter(self.sampler_1('normal'))
        defect_sampler_iter = iter(self.sampler_2('defect'))
        if self.drop_last:
            while True:
                try:
                    idx = next(normal_sampler)

                    # 지정된 횟수까지 복원 추출로 데이터를 강제로 넣어줌
                    while fix_size < self.fixed_size:
                        batch[idx_in_batch] = next(defect_sampler_iter)
                        fix_size += 1
                        idx_in_batch += 1
                    batch[idx_in_batch] = idx
                    idx_in_batch += 1

                    # mini batch에 데이터가 다 들어갔을 경우 
                    if idx_in_batch == self.batch_size:
                        random.shuffle(batch)
                        yield batch
                        idx_in_batch = 0
                        fix_size = 0 
                        batch = [0] * self.batch_size
                except StopIteration:
                    break
    def __len__(self) -> int:
        num_steps = math.ceil(len(self.sampler_2.other_imgs) /  (self.batch_size - self.fixed_size))
        return num_steps

기존의 BatchSampler의 경우 하나의 Sampler를 받아 iteration을 돌며 mini batch를 return했지만 

저희는 DefectRandomSampler와 NormalRandomSampler를 두개의 인자로 받아야 했기 때문에 이에 따라 

두개의 Sampler를 Input으로 받도록 설정했습니다.

 

5) CustomDataLoader

class CustomDataLoader(data.DataLoader):
    def __init__(self, *args, **kwargs):
        super(CustomDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = collate_fn


def collate_fn(batch):
    image_list = []
    bboxes_list = []
    cls_ids_list = []

    for image, bboxes, class_ids in batch:
        image_list.append(torch.tensor(image, dtype=torch.float32))

        # image에 결함 데이터가 없을경
        if bboxes.size == 0:
            bboxes_list.append(torch.empty(0, 4, dtype=torch.float32))
            cls_ids_list.append(torch.empty(0, 1, dtype=torch.float32))
        else:
            bboxes_list.append(torch.tensor(bboxes, dtype=torch.float32))
            cls_ids_list.append(torch.tensor(class_ids, dtype=torch.float32))

    return torch.stack(image_list, dim=0), bboxes_list, cls_ids_list



if __name__=="__main__":
  dataset = CustomDataset(data_dir=data_dir )
  random_sampler_1 = CustomRandomSampler(dataset)
  random_sampler_2 = CustomRandomSampler(dataset)
  batchsampler = CustomBatchSampler(random_sampler_1 , random_sampler_2 , batch_size=8 , fixed_size=3)

classification task와 달리 Object Detection은 이미지에 따라 label의 개수가 달라 배열로 return을 할 수 없기때문에

collate_fn을 지정해줘야해서 CustomDataLoader를 만들었습니다.

 


이번 포스팅에서는 mini-batch내 data를 강제하기 위한 custom code를 구현했습니다.

'Programming > Pytorch' 카테고리의 다른 글

Pytorch Tips(정리중)  (0) 2023.04.03