Custom mini batch Loader로 원하는 데이터 mini batch에 강제하기
현재 축관 제조공정 중 결함을 감지하는 Object detection Task에서
데이터 셋 3만장 중 결함이 있는 데이터는 3%에 불과합니다.
하지만 학습환경상 최대 설정 batch size가 8이었고, mini batch 안에 결함이 들어갈 이미지가 들어갈 확률은 극도로 작아
결함데이터에 대해 학습하지 못할 가능성이 높았습니다.
따라서 custom mini batch loader를 만들어 이 문제를 해결하려 했습니다.
1) Pytorch 구현 방식
파이토치에서 기본적으로 데이터로더를 쓸 때
불러와지는 것은
- Dataset
- RandomSampler
- BatchSampler
- DataLoader
이 4가지 입니다.
pytorch에서 비복원 추출시 mini batch에 random하게 Sampling넣어 주는 방법은 아래와 같습니다.
- torch.randperm() method를 사용해 idx값을 random순열로 생성하는 방식으로 생성되고
- 표본의 index값을 가지고 있는 Sampler가 BatchSampler에 Iterater로 들어갑니다.
저는 최대한 이 매커니즘을 활용해 구현하려고 합니다.
# torch/utils/data/sampler.py > RandomSampler
for _ in range(self.num_samples // n):
yield from torch.randperm(n, generator=generator).tolist()
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
2) 구현하고 싶은 방식
- mini-batch에 원하는 데이터를 강제할 수 있을 것
- 시간 복잡도를 최대한 낮출 것
- 결함 데이터는 복원 추출로, 정상 데이터는 비복원 추출로 할 것
3) CustomRandomSampler
class CustomRandomSampler(Sampler[int]):
data_source: Sized
replacement: bool
def __init__(self, data_source: Sized,
num_samples: Optional[int] = None, generator=None) -> None:
self.data_source = data_source
self._num_samples = num_samples
self.generator = generator
cat1_imgs = set(data_source.coco.getImgIds(catIds = data_source.coco.getCatIds()[0]))
cat2_imgs = set(data_source.coco.getImgIds(catIds = data_source.coco.getCatIds()[1]))
# 결함이 있는 데이터 중 1, 2 category id에 대해 추출
self.defect_imgs = list(cat2_imgs | cat1_imgs)
# 결함이 없는 데이터
self.other_imgs = list(set(data_source.coco.getImgIds()) - set(cat2_imgs | cat1_imgs))
@property
def num_samples(self) -> int:
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self) -> Iterator[int]:
n_n = len(self.other_imgs)
n_d = len(self.defect_imgs)
if self.generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = self.generator
# 결함이 있는 data일 경우 복원 추출
if self.dtype == 'defect':
for _ in range(self.num_samples // 32):
yield from [self.defect_imgs[i] for i in torch.randint(high=n_d, size=(32,), dtype=torch.int64, generator=generator).tolist()]
yield from [self.defect_imgs[i] for i in torch.randint(high=n_d, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()]
# 결함이 없는 data일 경우 비복원 추출
else:
for _ in range(n_n // n_n):
yield from [self.other_imgs[i] for i in torch.randperm(n_n, generator=generator).tolist()]
yield from [self.other_imgs[i] for i in torch.randperm(n_n, generator=generator).tolist()[:n_n % n_n]]
def __len__(self) -> int:
if self.dtype == 'defect':
return len(self.defect_imgs)
if self.dtype == 'normal':
return len(self.other_imgs)
else:
return len(self.data_source)
def __call__(self, dtype) -> str:
self.dtype = dtype
return self
RandomSampler는 Pytorch알고리즘과 비슷하게 ImageDataset의 idx return하는 iterator형태로 만들었습니다.
하지만 iterator를 호출할 때 결함인지, 정상인지에 따라 다르게 iterator를 리턴하도록 했습니다.
4) CustomBatchSampler
class CustomBatchSampler(Sampler[List[int]]):
'''
fixed_size : mini_batch 중 몇개의 결함 data를 강제로 집어 넣어줄 것인지
'''
def __init__(self, sampler_1: Union[Sampler[int], Iterable[int]],
sampler_2: Union[Sampler[int], Iterable[int]],
batch_size: int, fixed_size : int , drop_last = True) -> None:
self.sampler_1 = sampler_1
self.sampler_2 = sampler_2
self.batch_size = batch_size
self.fixed_size = fixed_size
self.drop_last = drop_last
def __iter__(self) -> Iterator[List[int]]:
batch = [0] * self.batch_size
idx_in_batch = 0
fix_size = 0
normal_sampler = iter(self.sampler_1('normal'))
defect_sampler_iter = iter(self.sampler_2('defect'))
if self.drop_last:
while True:
try:
idx = next(normal_sampler)
# 지정된 횟수까지 복원 추출로 데이터를 강제로 넣어줌
while fix_size < self.fixed_size:
batch[idx_in_batch] = next(defect_sampler_iter)
fix_size += 1
idx_in_batch += 1
batch[idx_in_batch] = idx
idx_in_batch += 1
# mini batch에 데이터가 다 들어갔을 경우
if idx_in_batch == self.batch_size:
random.shuffle(batch)
yield batch
idx_in_batch = 0
fix_size = 0
batch = [0] * self.batch_size
except StopIteration:
break
def __len__(self) -> int:
num_steps = math.ceil(len(self.sampler_2.other_imgs) / (self.batch_size - self.fixed_size))
return num_steps
기존의 BatchSampler의 경우 하나의 Sampler를 받아 iteration을 돌며 mini batch를 return했지만
저희는 DefectRandomSampler와 NormalRandomSampler를 두개의 인자로 받아야 했기 때문에 이에 따라
두개의 Sampler를 Input으로 받도록 설정했습니다.
5) CustomDataLoader
class CustomDataLoader(data.DataLoader):
def __init__(self, *args, **kwargs):
super(CustomDataLoader, self).__init__(*args, **kwargs)
self.collate_fn = collate_fn
def collate_fn(batch):
image_list = []
bboxes_list = []
cls_ids_list = []
for image, bboxes, class_ids in batch:
image_list.append(torch.tensor(image, dtype=torch.float32))
# image에 결함 데이터가 없을경
if bboxes.size == 0:
bboxes_list.append(torch.empty(0, 4, dtype=torch.float32))
cls_ids_list.append(torch.empty(0, 1, dtype=torch.float32))
else:
bboxes_list.append(torch.tensor(bboxes, dtype=torch.float32))
cls_ids_list.append(torch.tensor(class_ids, dtype=torch.float32))
return torch.stack(image_list, dim=0), bboxes_list, cls_ids_list
if __name__=="__main__":
dataset = CustomDataset(data_dir=data_dir )
random_sampler_1 = CustomRandomSampler(dataset)
random_sampler_2 = CustomRandomSampler(dataset)
batchsampler = CustomBatchSampler(random_sampler_1 , random_sampler_2 , batch_size=8 , fixed_size=3)
classification task와 달리 Object Detection은 이미지에 따라 label의 개수가 달라 배열로 return을 할 수 없기때문에
collate_fn을 지정해줘야해서 CustomDataLoader를 만들었습니다.
이번 포스팅에서는 mini-batch내 data를 강제하기 위한 custom code를 구현했습니다.
'Programming > Pytorch' 카테고리의 다른 글
Pytorch Tips(정리중) (0) | 2023.04.03 |
---|