Source code for parallel_write.writer

"""Parallel writer."""

from concurrent.futures import ThreadPoolExecutor, as_completed


class Writer:
    """Writes to given file handles `files` in parallel.

    Attributes that the class does not define itself are proxied to the
    underlying file objects: callables are executed on every file through a
    thread pool, plain attributes are read from every file. Unless a method
    is listed in `ignore_diff`, all files must agree on the result.
    """

    def __init__(self, files, max_workers=None, ignore_diff=("fileno", "seekable")):
        """Parallel writer.

        files: list of the input file handles to write to.
        max_workers: the maximum number of workers in the threadpool.
            Defaults to the length of `files`.
        ignore_diff: the collection of method names where different results
            are accepted. The default is `("fileno", "seekable")`. Fileno
            will always differ, seekable may differ when using various
            protocol implementations.
        """
        if max_workers is None:
            max_workers = len(files)
        self._files = files
        self._executor = ThreadPoolExecutor(max_workers)
        self._iterators = []
        # The default is an immutable tuple, so storing it as-is cannot leak
        # mutations between instances.
        self._ignore_diff = ignore_diff

    @staticmethod
    def _agreed(values, what):
        """Return the value shared by every item of the non-empty `values`.

        Raises AssertionError when the items are not all equal. The error is
        raised explicitly (not via an `assert` statement) so that the check
        is still performed when Python runs with `-O`.
        """
        first = values[0]
        if values.count(first) != len(values):
            raise AssertionError(
                "{} differs between files: {!r}".format(what, values)
            )
        return first

    def __getattr__(self, attr):
        """Proxy the methods/properties to the underlying file objects.

        Returns a function for callable attributes and the shared value for
        non-callable ones. Raises AttributeError when there are no files to
        proxy to, and AssertionError when the files disagree.
        """
        # __getattr__ only runs after normal lookup failed. If `_files` itself
        # is missing (instance built without __init__, e.g. by copy/pickle),
        # reading `self._files` would re-enter this method forever, so look in
        # the instance dict directly.
        files = self.__dict__.get("_files")
        if not files:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, attr)
            )

        if callable(getattr(files[0], attr)):
            # if the attribute is a callable, return a function which evaluates
            # the given method on all files in parallel and expects that
            # they return the same value
            # for eg. read(), write(), seek() etc.
            def _do(*args, **kwargs):
                targets = self._files
                pending = {
                    self._executor.submit(getattr(f, attr), *args, **kwargs): idx
                    for idx, f in enumerate(targets)
                }
                # Futures finish in arbitrary order; store each result at the
                # index of its file so the returned value is always the one
                # of the first file, not of the fastest one.
                results = [None] * len(targets)
                for future in as_completed(pending):
                    results[pending[future]] = future.result()
                # only check methods which must return the same value
                if attr in self._ignore_diff:
                    return results[0]
                return self._agreed(results, "result of {}()".format(attr))

            return _do

        # if the attribute is not a callable (property), collect the values,
        # check for equality and return it
        # for eg. mode, name, closed etc.
        values = [getattr(f, attr) for f in files]
        return self._agreed(values, "attribute {!r}".format(attr))

    def __enter__(self):
        """Context manager enter."""
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Context manager exit: close every underlying file."""
        self.close()

    def __iter__(self):
        """Initialize the iterators, one per underlying file."""
        self._iterators = [iter(f) for f in self._files]
        return self

    def __next__(self):
        """Read the next item from every iterator and return the shared value.

        Raises StopIteration as soon as one of the iterators is exhausted (or
        when there are no iterators at all), and AssertionError when the
        files yield different items.
        """
        items = [next(i) for i in self._iterators]
        if not items:
            raise StopIteration
        return self._agreed(items, "iterated item")