Speeding up data loading with data_prefetcher

See here for this blog series' home page and related posts.


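This class comes from apex on GitHub and can speed up data loading to some extent. The most effective fix, of course, is still to spend money on better hardware.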

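import torch

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        # Side stream so host-to-device copies can overlap with compute on the current stream
        self.stream = torch.cuda.Stream()
        # ImageNet mean/std scaled to the 0-255 range, shaped to broadcast over NCHW batches
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            # Loader exhausted: next() will return (None, None)
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # non_blocking=True is only asynchronous if the loader uses pin_memory=True
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        # Make the current stream wait until the prefetch copy and normalization are done
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        # Tell the caching allocator these tensors are now in use on the current stream
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target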

Usage: wrap your DataLoader in this class:

prefetcher = data_prefetcher(train_loader)
input, target = prefetcher.next()
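
A single call to `next()` returns only one batch. A training loop would normally keep calling it until it returns `None`, which is how `data_prefetcher` signals that the loader is exhausted. The sketch below shows that loop shape. `FakePrefetcher`, `run_epoch` and the toy loader are hypothetical stand-ins (not part of apex or PyTorch), used only so the example runs without a GPU.

```python
class FakePrefetcher:
    """Hypothetical CPU-only stand-in with the same next() contract as data_prefetcher."""

    def __init__(self, loader):
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            # Same end-of-data signal as data_prefetcher: (None, None)
            return None, None


def run_epoch(prefetcher, train_step):
    """Drain the prefetcher, calling train_step on every batch; return the batch count."""
    n_batches = 0
    input, target = prefetcher.next()
    while input is not None:
        train_step(input, target)
        n_batches += 1
        input, target = prefetcher.next()
    return n_batches


# Toy "loader": three (input, target) pairs
toy_loader = [([1, 2], 0), ([3, 4], 1), ([5, 6], 0)]
seen = []
count = run_epoch(FakePrefetcher(toy_loader), lambda x, y: seen.append((x, y)))
```

With the real class, `FakePrefetcher(toy_loader)` would become `data_prefetcher(train_loader)`. Because the class wraps a one-shot iterator, a new prefetcher has to be created at the start of each epoch.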

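There is also a package, prefetch_generator, that runs a generator in a separate thread so that it works in parallel with, and independently of, model execution: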

pip install prefetch_generator

from prefetch_generator import BackgroundGenerator
for batch in BackgroundGenerator(my_minibatch_iterator):
    doit()

Or, as a decorator:

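from prefetch_generator import background

@background()
def iterate_minibatches(some_param):
    # read_heavy_file, do_helluva_math, fetch_labels and do_pretty_much_anything are placeholders
    while True:
        X = read_heavy_file()
        X = do_helluva_math(X)
        y = fetch_labels()
        do_pretty_much_anything()
        yield X, y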
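
The idea of running a generator in another thread can be sketched with the standard library alone. The code below is not prefetch_generator's implementation, just a minimal illustration of the technique using a worker thread and a bounded queue. `SimpleBackgroundGenerator` and `slow_batches` are hypothetical names made up for this example.

```python
import queue
import threading
import time


class SimpleBackgroundGenerator:
    """Illustrative only: consume `generator` in a worker thread,
    buffering at most `max_prefetch` items ahead of the consumer."""

    _DONE = object()  # sentinel marking the end of the stream

    def __init__(self, generator, max_prefetch=1):
        self._generator = generator
        self._queue = queue.Queue(maxsize=max_prefetch)
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()

    def _worker(self):
        try:
            for item in self._generator:
                self._queue.put(item)  # blocks while the buffer is full
        finally:
            # Always signal the end so the consumer does not wait forever
            self._queue.put(self._DONE)

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()  # blocks until the worker has produced something
        if item is self._DONE:
            raise StopIteration
        return item


def slow_batches(n):
    """Hypothetical producer that simulates slow I/O."""
    for i in range(n):
        time.sleep(0.01)
        yield i


# While the consumer handles one batch, the worker is already producing the next ones
results = [batch * 2 for batch in SimpleBackgroundGenerator(slow_batches(5), max_prefetch=2)]
```

This sketch is deliberately simple: an exception raised inside the producer ends the stream early instead of being re-raised in the consumer, and the object can only be iterated once. Because the worker is a thread, the overlap mostly helps when the producer spends its time on I/O or in code that releases the GIL.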

For more examples, see the official examples.