# Source code for ketl.extractor.Extractor

import urllib.parse as up
import asyncio

from abc import abstractmethod
from datetime import datetime, timedelta
from dataclasses import dataclass
from ftplib import FTP
from functools import partial
from pathlib import Path
from typing import List, Union, Optional, Iterator
from urllib.parse import quote
from multiprocessing.pool import Pool

from furl import furl
from smart_open import open as smart_open
from tqdm import tqdm
from sqlalchemy.orm import defer, Query
from more_itertools import chunked

from ketl.db.models import API, CachedFile, ExpectedFile, Source
from ketl.db.settings import get_session
from ketl.utils.file_utils import file_hash


class BaseExtractor:
    """
    Common interface for extractors; concrete subclasses implement :meth:`extract`.

    NOTE(review): this class does not use ``ABCMeta``/``ABC``, so ``@abstractmethod`` is not
    enforced when instantiating; a subclass that forgets to override :meth:`extract` only fails
    when the method is called.
    """

    @abstractmethod
    def extract(self) -> List[Path]:
        """
        Fetch the data this extractor is responsible for.

        :return: a list of paths.
        :raises NotImplementedError: always, in the base class.
        """
        message = 'extract not implemented in the base class'
        raise NotImplementedError(message)
class DefaultExtractor(BaseExtractor):
    """
    The default extractor can fetch files from an FTP server or any location that is openable via
    smart_open. It is up to the user to provide any credentials that are required to access the
    desired resources.
    """

    # number of bytes read from the source and written to the target per chunk
    BLOCK_SIZE = 16384
    # number of rows fetched/processed per batch when walking the cached-file queries
    BATCH_SIZE = 10000

    def __init__(self, api_config: Union[API, int, str], skip_existing_files: bool = False,
                 overwrite_on_extract=True, show_progress: bool = False, concurrency: str = 'sync',
                 on_disk_check='full', expected_file_generation='incremental'):
        """
        Build an extractor for a single API.

        :param api_config: the :class:`API` itself, its database id (``int``) or its name (``str``).
        :param skip_existing_files: only download the cached files returned by
            ``API.cached_files_on_disk(missing=True, ...)``.
        :param overwrite_on_extract: accepted for backwards compatibility; not used by this class.
        :param show_progress: whether to show a tqdm progress bar for each file.
        :param concurrency: ``'sync'``, ``'async'`` (not implemented) or ``'multiprocess'``.
        :param on_disk_check: when ``'hash'``, existing files are checked with ``use_hash=True``.
        :param expected_file_generation: ``'full'`` or ``'incremental'``.
        :raises TypeError: if ``api_config`` is not an API, an int or a str.
        """
        # exact type checks on purpose: e.g. bool (a subclass of int) is not treated as an API id
        if type(api_config) is int:
            self.api = get_session().query(API).filter(API.id == api_config).one()
        elif type(api_config) is str:
            self.api = get_session().query(API).filter(API.name == api_config).one()
        elif isinstance(api_config, API):
            self.api = api_config
        else:
            # previously self.api was left unset and the constructor failed with an AttributeError
            raise TypeError(f'api_config must be an API, int or str, not {type(api_config).__name__}')

        self.headers = {}
        self.auth = None
        self.auth_token = None
        self.skip_existing_files = skip_existing_files
        self.show_progress = show_progress
        self.concurrency = concurrency
        self.on_disk_check = on_disk_check
        self.expected_file_generation = expected_file_generation

        if self.api.creds:
            details = self.api.creds.creds_details
            cookie = details.get('cookie', None)
            if cookie:
                self.headers['Cookie'] = cookie['name'] + '=' + cookie['value']
            # passed through to smart_open's transport_params by _fetch_generic_file
            self.auth = details.get('auth', None)
            self.auth_token = details.get('auth_token', None)
            if self.auth_token:
                self.headers[self.auth_token['header']] = self.auth_token['token']

    def extract(self) -> List[Path]:
        """
        Run the extractor.

        Attempts to minimize the amount of repeated work by checking which cached files actually
        exist, whether on disk or in the database, and batching downloads. Optionally distributes
        the work across processes if the `concurrency` parameter is set to `multiprocess`.

        :return: a list of paths corresponding to all the :class:`ExpectedFile` s that the
            extractor's API is responsible for.
        :raises NotImplementedError: if ``concurrency`` is ``'async'``.
        :raises ValueError: if ``concurrency`` or ``expected_file_generation`` is not recognized.
        """
        session = get_session()

        # depending on whether we are skipping files known to be on disk, we query either only
        # the missing cached files or all of them; the query is consumed in batches below
        if self.skip_existing_files:
            kwargs = {'missing': True, 'use_hash': self.on_disk_check == 'hash'}
            data_iterator: Query = self.api.cached_files_on_disk(**kwargs)
        else:
            data_iterator: Query = self.api.cached_files
        data_iterator = data_iterator.options(defer(CachedFile.meta))

        collected_results = self._download_cached_files(data_iterator)
        session.bulk_update_mappings(CachedFile, collected_results)
        session.commit()

        self._generate_expected_files(session, data_iterator)
        return [Path(ef.path) for ef in self.api.expected_files]

    def _download_cached_files(self, data_iterator: Query) -> List[dict]:
        """
        Download the cached files yielded by ``data_iterator``, ``BATCH_SIZE`` at a time.

        :param data_iterator: a query over the :class:`CachedFile` s to fetch.
        :return: ``CachedFile`` update mappings (see :meth:`get_file`) for the files that were
            actually downloaded.
        :raises NotImplementedError: if ``concurrency`` is ``'async'``.
        :raises ValueError: if ``concurrency`` is not a recognized strategy.
        """
        collected_results: List[dict] = []
        for batch in tqdm(chunked(data_iterator, self.BATCH_SIZE)):  # type: List[CachedFile]
            # positional arguments for get_file; the last one is show_progress
            get_file_args = [(
                cached_file.id, cached_file.full_url, cached_file.full_path,
                cached_file.refresh_interval, cached_file.url_params, self.show_progress
            ) for cached_file in batch]

            if self.concurrency == 'sync':
                results = [self.get_file(*args) for args in get_file_args]
            elif self.concurrency == 'async':  # pragma: no cover
                raise NotImplementedError('Async downloads not yet implemented.')  # pragma: no cover
            elif self.concurrency == 'multiprocess':
                results = []
                if get_file_args:
                    with Pool() as pool:
                        results = pool.starmap(self.get_file, get_file_args)
                        # Pool.join() raises "Pool is still running" unless close() is called first
                        pool.close()
                        pool.join()
            else:
                raise ValueError(f'Unknown concurrency strategy: {self.concurrency!r}')

            # get_file returns None for files that were skipped or could not be downloaded
            collected_results.extend(result for result in results if result)
        return collected_results

    def _generate_expected_files(self, session, data_iterator: Query) -> None:
        """
        Insert or update the :class:`ExpectedFile` rows produced by preprocessing cached files.

        :param session: the database session used for the bulk insert/update and commit.
        :param data_iterator: the download query, reused when ``expected_file_generation`` is
            ``'incremental'``.
        :raises ValueError: if ``expected_file_generation`` is not ``'full'`` or ``'incremental'``.
        """
        q: Query = session.query(
            ExpectedFile.path, ExpectedFile.cached_file_id, ExpectedFile.id
        ).join(
            CachedFile, ExpectedFile.cached_file_id == CachedFile.id
        ).join(
            Source, CachedFile.source_id == Source.id
        ).filter(
            Source.api_config_id == self.api.id
        )
        # (path, cached_file_id) -> ExpectedFile.id for the expected files this API already has
        current_files = {
            (path, cached_file_id): expected_file_id
            for path, cached_file_id, expected_file_id in q.yield_per(self.BATCH_SIZE)
        }

        if self.expected_file_generation == 'full':
            cached_file_iterator = self.api.cached_files.options(defer(CachedFile.meta))
        elif self.expected_file_generation == 'incremental':
            # NOTE(review): this re-executes the download query after the commit in extract();
            # if cached_files_on_disk(missing=True) no longer returns freshly downloaded files,
            # they are not preprocessed here — verify against the API model.
            cached_file_iterator = data_iterator
        else:
            raise ValueError('Unspecified expected file generation strategy')

        new_expected_files: List[dict] = []
        updated_expected_files: List[dict] = []
        with tqdm(total=cached_file_iterator.count()) as bar:
            for source_file in cached_file_iterator:
                ef = source_file.preprocess()
                if ef:
                    key = (ef['path'], ef['cached_file_id'])
                    if key in current_files:
                        updated_expected_files.append({'id': current_files[key], **ef})
                    else:
                        new_expected_files.append(ef)
                bar.update(1)

        session.bulk_insert_mappings(ExpectedFile, new_expected_files)
        session.bulk_update_mappings(ExpectedFile, updated_expected_files)
        session.commit()

    @classmethod
    def _fetch_ftp_file(cls, source_url: str, target_file: Path, refresh_interval: timedelta,
                        show_progress=False, force_download=False) -> bool:
        """
        Fetch a file from an FTP server, logging in anonymously.

        :param source_url: the source url to fetch.
        :param target_file: the location to which the file is to be downloaded.
        :param refresh_interval: maximum age of the file if it exists.
        :param show_progress: whether to show a tqdm progress bar.
        :param force_download: force downloads regardless of file existence.
        :return: whether the file was downloaded.
        """
        parsed_url = up.urlparse(source_url)
        # the FTP context manager sends QUIT and closes the connection even if the transfer fails
        with FTP(parsed_url.hostname) as ftp:
            ftp.login()
            total_size = ftp.size(parsed_url.path)
            if not (cls._requires_update(target_file, total_size, refresh_interval) or force_download):
                return False
            bar = tqdm(total=total_size, unit='B', unit_scale=True) if show_progress else None
            target_file.parent.mkdir(exist_ok=True, parents=True)
            try:
                with open(target_file.as_posix(), 'wb') as f:
                    ftp.retrbinary(f'RETR {parsed_url.path}', partial(cls._ftp_writer, f, bar=bar),
                                   blocksize=cls.BLOCK_SIZE)
            finally:
                if bar:
                    bar.close()  # pragma: no cover
        return True

    @classmethod
    def _fetch_generic_file(cls, source_url: str, target_file: Path, refresh_interval: timedelta,
                            url_params=None, headers=None, auth=None, show_progress=False,
                            force_download=False) -> bool:
        """
        Fetch a file from any scheme which is smart_open-able (e.g. https://, s3://, etc).

        :param source_url: the source url to fetch.
        :param target_file: the location to which the file is to be downloaded.
        :param refresh_interval: maximum age of the file if it exists.
        :param url_params: optional query parameters.
        :param headers: optional headers.
        :param auth: optional authorization parameters, merged into smart_open's transport_params.
        :param show_progress: whether to show a tqdm progress bar.
        :param force_download: force downloads regardless of file existence.
        :return: whether the file was downloaded.
        """
        transport_params = {}
        url = furl(source_url)
        if headers:
            transport_params['headers'] = headers
        if auth:
            transport_params.update(auth)
        if url_params:
            url.add(url_params)

        # tragic hack that is necessitated by s3's failure to properly conform to http spec
        # c.f. https://forums.aws.amazon.com/thread.jspa?threadID=55746
        url_to_fetch = cls._handle_s3_urls(url)

        updated = False
        with smart_open(url_to_fetch, 'rb', ignore_ext=True, transport_params=transport_params) as r:
            # -1 signals "size unknown" to _requires_update
            total_size = getattr(r, 'content_length', -1)
            if cls._requires_update(target_file, total_size, refresh_interval) or force_download:
                target_file.parent.mkdir(exist_ok=True, parents=True)
                bar = tqdm(total=total_size, unit='B', unit_scale=True) if show_progress else None
                with open(target_file.as_posix(), 'wb') as f:
                    # this is actually identical to shutil.copyfileobj(r.raw, f)
                    # but with tqdm injected to show progress (_generic_writer closes the bar)
                    cls._generic_writer(r, f, block_size=cls.BLOCK_SIZE, bar=bar)
                updated = True
        return updated

    @staticmethod
    def _handle_s3_urls(url: furl):
        """
        Munge S3 URLs so that filenames that include e.g. hashes and ampersands are downloadable.

        :param url: the URL.
        :return: a URL adjusted for the presence of hashes and ampersands; non-S3 URLs are
            returned unchanged as strings.
        """
        url_to_fetch = url.url
        if url.scheme in {'s3', 's3a'} or url.host == 's3.amazonaws.com':
            url_to_fetch = f'{url.scheme}://{url.host}{quote(str(url.path))}'
            if str(url.fragment) != '':
                url_to_fetch += quote(f'#{url.fragment}', safe='%')
            if str(url.query) != '':
                url_to_fetch += f'&{url.query}'
        return url_to_fetch

    @staticmethod
    def _ftp_writer(dest, block, bar=None):
        """
        Write a data block to a data destination.

        :param dest: an open file object.
        :param block: data.
        :param bar: optional bar for displaying progress.
        :return: None
        """
        if bar:
            bar.update(len(block))
        dest.write(block)

    @staticmethod
    def _requires_update(target_file: Path, total_size: int, time_delta: timedelta = None) -> bool:
        """
        Check whether the file should be (re-)downloaded.

        A file is kept if it exists and either its size matches ``total_size``, or the remote
        size is unknown (``-1``) and the local copy is younger than ``time_delta``.

        :param target_file: the location to which the file is to be downloaded.
        :param total_size: the size of the remote file, or -1 if it is unknown.
        :param time_delta: the maximum age of the file; ``None`` means no maximum age, so an
            existing file of unknown remote size is kept (this used to raise ``TypeError``).
        :return: whether or not we should download the file.
        """
        if not target_file.exists():
            return True
        stat = target_file.stat()
        if stat.st_size == total_size:
            return False
        if total_size == -1:
            if time_delta is None:
                return False
            age = datetime.now() - datetime.fromtimestamp(stat.st_mtime)
            return not age < time_delta
        return True

    @staticmethod
    def _generic_writer(source, target, block_size=16384, bar=None):
        """
        Copy data from a source stream to a target stream, then close the progress bar.

        :param source: a readable stream.
        :param target: a writable stream.
        :param block_size: size of each block read from ``source``.
        :param bar: optional bar for showing progress; closed once the copy finishes.
        :return: None
        """
        while True:
            buf = source.read(block_size)
            if bar:
                bar.update(len(buf))
            if not buf:
                break
            target.write(buf)
        if bar:
            bar.close()

    def get_file(self, cached_file_id: int, source_url: str, target_file: Path,
                 refresh_interval: timedelta, url_params=None, show_progress=False,
                 force_download=False) -> Optional[dict]:
        """
        Download a file either using the FTP downloader or the generic downloader.

        Errors are deliberately not propagated: they are printed and ``None`` is returned so that
        one bad file does not abort a whole batch.

        :param cached_file_id: the id of the cached file.
        :param source_url: the url from which to get the file.
        :param target_file: the path to which the file should be downloaded.
        :param refresh_interval: the maximum age of the file.
        :param url_params: optional query parameters (ignored for FTP urls).
        :param show_progress: whether to show a tqdm progress bar.
        :param force_download: whether to force download regardless of file presence.
        :return: a dict with the updated data for the cached file, or None if the file was not
            downloaded (up to date or failed).
        """
        try:
            parsed_url = up.urlparse(source_url)
            if parsed_url.scheme == 'ftp':
                result = self._fetch_ftp_file(source_url, target_file, refresh_interval,
                                              show_progress=show_progress,
                                              force_download=force_download)
            else:
                result = self._fetch_generic_file(source_url, target_file, refresh_interval,
                                                  url_params, headers=self.headers, auth=self.auth,
                                                  show_progress=show_progress,
                                                  force_download=force_download)
            if result:
                return {
                    'id': cached_file_id,
                    'hash': file_hash(target_file).hexdigest(),
                    'last_download': datetime.now(),
                    'size': target_file.stat().st_size
                }
            else:
                return None
        except Exception as ex:
            print(f'Could not download {source_url}: {ex}')
            return None

    @staticmethod
    def _update_file_cache(source_file: CachedFile, target_file: Path):
        """
        Deprecated: record a download on a single cached file and commit immediately.

        :param source_file: the cached file whose hash, download time and size are updated.
        :param target_file: the downloaded file on disk.
        :return: None
        """
        session = get_session()
        source_file.hash = file_hash(target_file).hexdigest()
        source_file.last_download = datetime.now()
        source_file.size = target_file.stat().st_size
        session.add(source_file)
        session.commit()