Speeding up data loading with data_prefetcher
Contact me
- Blog -> https://cugtyt.github.io/blog/index
- Email -> cugtyt@qq.com
- GitHub -> Cugtyt@GitHub
The home page for this blog series and related material can be found here.
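This class comes from apex@GitHub. It can speed up data loading to some extent, although the most effective fix is, of course, to spend money on better hardware.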
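import torch

class data_prefetcher():
    """Wrap a DataLoader and copy the next batch to the GPU on a side stream."""

    def __init__(self, loader):
        self.loader = iter(loader)
        # Side stream so host-to-device copies can overlap with compute.
        self.stream = torch.cuda.Stream()
        # Mean/std are scaled by 255, which assumes inputs are in the 0-255 range.
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            # Loader exhausted: signal the end of the epoch with None.
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # non_blocking=True is only asynchronous when the source is pinned memory.
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # Convert to float and normalize in place on the GPU.
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        # Make the compute stream wait until the prefetch copy has finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        # Tell the caching allocator the tensors are now used on the compute stream.
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target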
Usage: wrap the DataLoader in this class.
prefetcher = data_prefetcher(train_loader)
input, target = prefetcher.next()
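Because `next()` returns `(None, None)` once the underlying loader is exhausted, an epoch is usually driven by a `while` loop rather than a `for` loop. The sketch below illustrates only that loop shape. `FakePrefetcher` and `run_epoch` are hypothetical stand-ins that are not part of apex: they use plain Python lists instead of CUDA tensors and mimic the `next()` contract of `data_prefetcher`, so the example runs without a GPU.

```python
class FakePrefetcher:
    """Hypothetical stand-in that mimics data_prefetcher.next() without CUDA."""

    def __init__(self, loader):
        self.loader = iter(loader)
        self.preload()

    def preload(self):
        # Fetch one batch ahead; mark exhaustion with None, like data_prefetcher.
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None

    def next(self):
        input, target = self.next_input, self.next_target
        self.preload()
        return input, target


def run_epoch(loader):
    """Consume every batch and return how many were seen."""
    prefetcher = FakePrefetcher(loader)
    input, target = prefetcher.next()
    steps = 0
    while input is not None:
        steps += 1
        # A real training step (forward, backward, optimizer) would go here.
        input, target = prefetcher.next()
    return steps


# Three toy (input, target) batches standing in for a DataLoader.
batches = [([1, 2], [0, 1]), ([3, 4], [1, 0]), ([5, 6], [0, 0])]
num_steps = run_epoch(batches)
```

With the real class, the same loop applies with `data_prefetcher(train_loader)` in place of `FakePrefetcher`. A new prefetcher has to be created at the start of each epoch, since the wrapped iterator is consumed after one pass.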
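There is also a package, prefetch_generator, that runs a generator in a separate thread so that it works independently of, and in parallel with, the model: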
pip install prefetch_generator
from prefetch_generator import BackgroundGenerator

for batch in BackgroundGenerator(my_minibatch_iterator):
    doit()
Or:
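from prefetch_generator import background

# The functions called below are placeholders for your own loading code.
@background()
def iterate_minibatches(some_param):
    while True:
        X = read_heavy_file()
        X = do_helluva_math(X)
        y = download_labels()
        do_pretty_much_anything()
        yield X, y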
For more examples, see the official examples.
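To make the idea of "a generator working in another thread" concrete, here is a minimal sketch built only on the standard library. It is not the prefetch_generator implementation: `SimpleBackgroundGenerator`, its `max_prefetch` argument, and `slow_batches` are hypothetical names chosen for illustration. A worker thread pulls items from the wrapped generator into a bounded queue, and the consumer reads from that queue.

```python
import queue
import threading


class SimpleBackgroundGenerator:
    """Hypothetical minimal version of the idea; not the prefetch_generator code."""

    _END = object()  # sentinel marking the end of the wrapped generator

    def __init__(self, generator, max_prefetch=1):
        # A bounded queue limits how far ahead the worker may run.
        self.queue = queue.Queue(maxsize=max_prefetch)
        self.generator = generator
        self.thread = threading.Thread(target=self._worker, daemon=True)
        self.thread.start()

    def _worker(self):
        # Runs in the background thread: produce items until exhausted.
        for item in self.generator:
            self.queue.put(item)  # blocks while the queue is full
        self.queue.put(self._END)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()  # blocks until the worker has an item ready
        if item is self._END:
            raise StopIteration
        return item


def slow_batches(n):
    # Stand-in for an expensive loader; yields the squares 0, 1, 4, ...
    for i in range(n):
        yield i * i


results = list(SimpleBackgroundGenerator(slow_batches(5), max_prefetch=2))
```

This sketch does not forward exceptions raised inside the worker and should not be iterated again after it is exhausted. Because of Python's global interpreter lock, a thread-based approach like this helps mainly when producing a batch spends its time in I/O or in code that releases the lock.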