# -*- coding: utf-8 -*-

"""
Fetching functions to retrieve example datasets, using
Scikit-Learn's ``fetch_openml()`` function.
"""


# Author:
# Lilian Boulard <lilian@boulard.fr>
# https://github.com/LilianBoulard

# Future notes:
# - Watch out for ``fetch_openml()`` API modifications:
# as of January 2021, the function is marked as experimental.


import gzip
import json
import sklearn
import warnings
import pandas as pd

from pathlib import Path
from dataclasses import dataclass
from typing import Union, Dict, Any, List

from dirty_cat.utils import Version
from dirty_cat.datasets.utils import get_data_dir

# Directory where the ``.gz`` files containing the
# details on downloaded datasets are stored.
# Note: the tree structure is created by ``fetch_openml()``.
# As of October 2020, this function is annotated as
# ``Experimental`` so the structure might change in future releases.
# This path will be concatenated to the dirty_cat data directory,
# available via the function ``get_data_dir()``.
DETAILS_DIRECTORY = "openml/openml.org/api/v1/json/data/"

# Same as above; for the datasets' features location.
FEATURES_DIRECTORY = "openml/openml.org/api/v1/json/data/features/"

# Same as above; for the datasets' data location.
DATA_DIRECTORY = "openml/openml.org/data/v1/download/"
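# For example (illustrative), the details file for dataset 42125 is
# expected at <data_dir>/openml/openml.org/api/v1/json/data/42125.gz,
# where <data_dir> is the directory returned by ``get_data_dir()``.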

# The IDs of the datasets, from OpenML.
# For each dataset, its URL is constructed as follows:
openml_url = "https://www.openml.org/d/{ID}"
ROAD_SAFETY_ID = 42803
OPEN_PAYMENTS_ID = 42738
MIDWEST_SURVEY_ID = 42805
MEDICAL_CHARGE_ID = 42720
EMPLOYEE_SALARIES_ID = 42125
TRAFFIC_VIOLATIONS_ID = 42132
DRUG_DIRECTORY_ID = 43044
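
# For example (illustrative):
#   openml_url.format(ID=EMPLOYEE_SALARIES_ID)
#   -> 'https://www.openml.org/d/42125'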


@dataclass(unsafe_hash=True)
class Details:
    name: str
    file_id: str
    description: str


@dataclass(unsafe_hash=True)
class Features:
    names: List[str]


@dataclass(unsafe_hash=True)
class DatasetAll:
    name: str
    description: str
    X: pd.DataFrame
    y: pd.Series
    source: str
    path: Path


@dataclass(unsafe_hash=True)
class DatasetInfoOnly:
    name: str
    description: str
    source: str
    target: str
    path: Path
    read_csv_kwargs: Dict[str, Any]


def fetch_openml_dataset(dataset_id: int,
                         data_directory: Path = get_data_dir(),
                         ) -> Dict[str, Any]:
    """
    Gets a dataset from OpenML (https://www.openml.org),
    or from the disk if already downloaded.

    Parameters
    ----------
    dataset_id: int
        The ID of the dataset to fetch.
    data_directory: Path
        Optional. A directory to save the data to.
        By default, the dirty_cat data directory.

    Returns
    -------
    Dict[str, Any]
        A dictionary containing:
          - ``description``: str
              The description of the dataset,
              as gathered from OpenML.
          - ``source``: str
              The dataset's URL from OpenML.
          - ``path``: pathlib.Path
              The local path leading to the dataset,
              saved as a CSV file.
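
    Examples
    --------
    A minimal usage sketch (the first call downloads from OpenML;
    the returned path depends on your local data directory):

    >>> info = fetch_openml_dataset(EMPLOYEE_SALARIES_ID)  # doctest: +SKIP
    >>> sorted(info.keys())  # doctest: +SKIP
    ['description', 'path', 'source']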

    """
    # Make path absolute
    data_directory = data_directory.resolve()

    # Construct the path to the gzip file containing the details on a dataset.
    details_gz_path = data_directory / DETAILS_DIRECTORY / f'{dataset_id}.gz'
    features_gz_path = data_directory / FEATURES_DIRECTORY / f'{dataset_id}.gz'

    if not details_gz_path.is_file() or not features_gz_path.is_file():
        # If the details file or the features file does not exist,
        # download the dataset.
        warnings.warn(
            f"Could not find the dataset {dataset_id} locally. "
            "Downloading it from OpenML; this might take a while... "
            "If it is interrupted, some files might be invalid/incomplete: "
            "if on the following run, the fetching raises errors, you can try "
            f"fixing this issue by deleting the directory {data_directory!s}.",
            UserWarning,
            stacklevel=2,
        )
        _download_and_write_openml_dataset(dataset_id=dataset_id,
                                           data_directory=data_directory)
    details = _get_details(details_gz_path)

    # The file ID is required because the data file is named after this ID,
    # not after the dataset ID.
    file_id = details.file_id
    csv_path = data_directory / f'{details.name}.csv'

    data_gz_path = data_directory / DATA_DIRECTORY / f'{file_id}.gz'

    if not data_gz_path.is_file():
        # This is a double-check.
        # If the data file does not exist, download the dataset.
        _download_and_write_openml_dataset(dataset_id=dataset_id,
                                           data_directory=data_directory)

    if not csv_path.is_file():
        # If the CSV file does not exist, use the dataset
        # downloaded by ``fetch_openml()`` to construct it.
        features = _get_features(features_gz_path)
        _export_gz_data_to_csv(data_gz_path, csv_path, features)

    url = openml_url.format(ID=dataset_id)

    return {
        "description": details.description,
        "source": url,
        "path": csv_path.resolve()
    }


def _download_and_write_openml_dataset(dataset_id: int,
                                       data_directory: Path) -> None:
    """
    Downloads a dataset from OpenML,
    taking care of creating the directories.

    Parameters
    ----------
    dataset_id: int
        The ID of the dataset to download.
    data_directory: Path
        The directory in which the data will be saved.

    Raises
    ------
    ValueError
        If the ID is incorrect (does not exist on OpenML)
    urllib.error.URLError
        If there is no Internet connection.

    """
    from sklearn.datasets import fetch_openml

    fetch_kwargs = {}
    if Version(sklearn.__version__) >= Version('0.22'):
        fetch_kwargs.update({'as_frame': True})

    # The ``fetch_openml()`` function returns a Scikit-Learn ``Bunch`` object,
    # a dictionary-like container whose values are also accessible as
    # attributes.
    # However, we do not want to keep this data in memory:
    # we will read it from the disk later.
    #
    # Raises ``ValueError`` if the ID is incorrect (does not exist on OpenML)
    # and ``urllib.error.URLError`` if there is no Internet connection.
    fetch_openml(
        data_id=dataset_id,
        data_home=str(data_directory),
        **fetch_kwargs
    )


def _read_json_from_gz(compressed_dir_path: Path) -> dict:
    """
    Opens a gzip file, reads its content (JSON expected),
    and returns a dictionary.

    Parameters
    ----------
    compressed_dir_path: Path
        Path to the ``.gz`` file to read.

    Returns
    -------
    dict
        The information contained in the file,
        converted from plain-text JSON.

    """
    if not compressed_dir_path.is_file():
        raise FileNotFoundError(f"Couldn't find file {compressed_dir_path!s}")

    # Read content
    with gzip.open(compressed_dir_path, mode='rt') as gz:
        content = gz.read()

    return json.loads(content)


def _get_details(compressed_dir_path: Path) -> Details:
    """
    Gets useful details from the details file.

    Parameters
    ----------
    compressed_dir_path: Path
        The path to the ``.gz`` file containing the details.

    Returns
    -------
    Details
        A ``Details`` object.

    """
    details = _read_json_from_gz(compressed_dir_path)["data_set_description"]
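    # The details JSON looks like (illustrative sketch, other keys omitted):
    #   {"data_set_description": {"name": ..., "file_id": ...,
    #                             "description": ..., ...}}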
    # We filter out the irrelevant information.
    # If you want to modify this list (to add or remove items)
    # you must also modify the ``Details`` object definition.
    return Details(
        name=details["name"],
        file_id=details["file_id"],
        description=details["description"],
    )


def _get_features(compressed_dir_path: Path) -> Features:
    """
    Gets features that can be inserted in the CSV file,
    the most important of which are the column names.

    Parameters
    ----------
    compressed_dir_path: Path
        Path to the gzip file containing the features.

    Returns
    -------
    Features
        A ``Features`` object.

    """
    raw_features = _read_json_from_gz(compressed_dir_path)["data_features"]
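    # The features JSON looks like (illustrative sketch, other keys omitted):
    #   {"data_features": {"feature": [{"name": ...}, {"name": ...}, ...]}}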
    # We filter out the irrelevant information.
    # If you want to modify this list (to add or remove items)
    # you must also modify the ``Features`` object definition.
    return Features(
        names=[column["name"] for column in raw_features["feature"]]
    )


def _export_gz_data_to_csv(compressed_dir_path: Path,
                           destination_file: Path, features: Features) -> None:
    """
    Reads a gzip file containing ARFF data,
    and writes it to a target CSV.

    Parameters
    ----------
    compressed_dir_path: Path
        Path to the ``.gz`` file containing the ARFF data.
    destination_file: Path
        A CSV file to write to.
    features: Features
        A ``Features`` object containing the first CSV line (the column names).

    """
    atdata_found = False
    with destination_file.open(mode="w", encoding='utf8') as csv:
        with gzip.open(compressed_dir_path, mode="rt", encoding='utf8') as gz:
            csv.write(_features_to_csv_format(features))
            csv.write("\n")
            # We will look at each line of the file until we find
            # "@data": only after this tag is the actual CSV data.
            for line in gz:
                if not atdata_found:
                    if line.lower().startswith("@data"):
                        atdata_found = True
                else:
                    csv.write(line)


def _features_to_csv_format(features: Features) -> str:
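    """Builds the CSV header line; e.g. names ``["a", "b"]`` give ``"a,b"``."""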
    return ",".join(features.names)


def fetch_dataset_as_dataclass(dataset_name: str,
                               dataset_id: int,
                               target: str,
                               read_csv_kwargs: dict,
                               load_dataframe: bool,
                               ) -> Union[DatasetAll, DatasetInfoOnly]:
    """
    Fetches a dataset from OpenML as a dataclass,
    given a dataset identifier, a target column name,
    and additional keyword arguments for `pd.read_csv`.

    If you don't need the dataset to be loaded in memory,
    pass `load_dataframe=False`.

    Returns
    -------
    DatasetAll
        If `load_dataframe=True`

    DatasetInfoOnly
        If `load_dataframe=False`

    """
    info = fetch_openml_dataset(dataset_id)
    if load_dataframe:
        df = pd.read_csv(info['path'], **read_csv_kwargs)
        y = df[target]
        X = df.drop(target, axis='columns')
        dataset = DatasetAll(
            name=dataset_name,
            description=info['description'],
            X=X,
            y=y,
            source=info['source'],
            path=info['path'],
        )
    else:
        dataset = DatasetInfoOnly(
            name=dataset_name,
            description=info['description'],
            source=info['source'],
            target=target,
            path=info['path'],
            read_csv_kwargs=read_csv_kwargs,
        )

    return dataset
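
# Example (sketch): fetch only the dataset information, then load the
# dataframe lazily using the stored ``read_csv`` arguments:
#
#   info = fetch_dataset_as_dataclass(
#       dataset_name='Employee salaries',
#       dataset_id=EMPLOYEE_SALARIES_ID,
#       target='current_annual_salary',
#       read_csv_kwargs={'quotechar': "'", 'escapechar': '\\'},
#       load_dataframe=False,
#   )
#   df = pd.read_csv(info.path, **info.read_csv_kwargs)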


# Dataset fetchers section
# Public API


def fetch_employee_salaries(load_dataframe: bool = True,
                            drop_linked: bool = True,
                            drop_irrelevant: bool = True,
                            ) -> Union[DatasetAll, DatasetInfoOnly]:
    """Fetches the employee_salaries dataset (regression),
    available at https://openml.org/d/42125

    Description of the dataset:
    > Annual salary information including gross pay and overtime pay for
    > all active, permanent employees of Montgomery County, MD paid in
    > calendar year 2016. This information will be published annually
    > each year.

    Parameters
    ----------
    drop_linked: bool (default True)
        Drops columns "2016_gross_pay_received" and "2016_overtime_pay",
        which are closely linked to "current_annual_salary", the target.
    drop_irrelevant: bool (default True)
        Drops column "full_name", which is usually irrelevant to the
        statistical analysis.

    Returns
    -------
    DatasetAll
        If `load_dataframe=True`

    DatasetInfoOnly
        If `load_dataframe=False`

    """
    dataset = fetch_dataset_as_dataclass(
        dataset_name='Employee salaries',
        dataset_id=EMPLOYEE_SALARIES_ID,
        target='current_annual_salary',
        read_csv_kwargs={
            'quotechar': "'",
            'escapechar': '\\',
            'na_values': ['?'],
        },
        load_dataframe=load_dataframe,
    )
    if load_dataframe:
        if drop_linked:
            dataset.X.drop(["2016_gross_pay_received", "2016_overtime_pay"],
                           axis=1, inplace=True)
        if drop_irrelevant:
            dataset.X.drop(["full_name"], axis=1, inplace=True)
    return dataset

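
# Typical usage (sketch): load the dataset and separate the target from
# the features:
#
#   dataset = fetch_employee_salaries()
#   X, y = dataset.X, dataset.y
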
def fetch_road_safety(load_dataframe: bool = True,
                      ) -> Union[DatasetAll, DatasetInfoOnly]:
    """Fetches the road safety dataset (classification),
    available at https://openml.org/d/42803

    Description of the dataset:
    > Data reported to the police about the circumstances of personal injury
    > road accidents in Great Britain from 1979, and the maker and model
    > information of vehicles involved in the respective accident.
    > This version includes data up to 2015.

    Returns
    -------
    DatasetAll
        If `load_dataframe=True`

    DatasetInfoOnly
        If `load_dataframe=False`

    """
    return fetch_dataset_as_dataclass(
        dataset_name='Road safety',
        dataset_id=ROAD_SAFETY_ID,
        target='Sex_of_Driver',
        read_csv_kwargs={
            'na_values': ['?'],
        },
        load_dataframe=load_dataframe,
    )

def fetch_medical_charge(load_dataframe: bool = True
                         ) -> Union[DatasetAll, DatasetInfoOnly]:
    """Fetches the medical charge dataset (regression),
    available at https://openml.org/d/42720

    Description of the dataset:
    > The Inpatient Utilization and Payment Public Use File (Inpatient PUF)
    > provides information on inpatient discharges for Medicare
    > fee-for-service beneficiaries. The Inpatient PUF includes information
    > on utilization, payment (total payment and Medicare payment), and
    > hospital-specific charges for the more than 3,000 U.S. hospitals that
    > receive Medicare Inpatient Prospective Payment System (IPPS) payments.
    > The PUF is organized by hospital and Medicare Severity Diagnosis
    > Related Group (MS-DRG) and covers Fiscal Year (FY) 2011 through
    > FY 2016.

    Returns
    -------
    DatasetAll
        If `load_dataframe=True`

    DatasetInfoOnly
        If `load_dataframe=False`

    """
    return fetch_dataset_as_dataclass(
        dataset_name='Medical charge',
        dataset_id=MEDICAL_CHARGE_ID,
        target='Average_Total_Payments',
        read_csv_kwargs={
            'quotechar': "'",
            'escapechar': '\\',
        },
        load_dataframe=load_dataframe,
    )

def fetch_midwest_survey(load_dataframe: bool = True
                         ) -> Union[DatasetAll, DatasetInfoOnly]:
    """Fetches the midwest survey dataset (classification),
    available at https://openml.org/d/42805

    Description of the dataset:
    > Survey to know if people self-identify as Midwesterners.

    Returns
    -------
    DatasetAll
        If `load_dataframe=True`

    DatasetInfoOnly
        If `load_dataframe=False`

    """
    return fetch_dataset_as_dataclass(
        dataset_name='Midwest survey',
        dataset_id=MIDWEST_SURVEY_ID,
        target='Census_Region',
        read_csv_kwargs={
            'quotechar': "'",
            'escapechar': '\\',
        },
        load_dataframe=load_dataframe,
    )

def fetch_open_payments(load_dataframe: bool = True
                        ) -> Union[DatasetAll, DatasetInfoOnly]:
    """Fetches the open payments dataset (classification),
    available at https://openml.org/d/42738

    Description of the dataset:
    > Payments given by healthcare manufacturing companies to medical
    > doctors or hospitals.

    Returns
    -------
    DatasetAll
        If `load_dataframe=True`

    DatasetInfoOnly
        If `load_dataframe=False`

    """
    return fetch_dataset_as_dataclass(
        dataset_name='Open payments',
        dataset_id=OPEN_PAYMENTS_ID,
        target='status',
        read_csv_kwargs={
            'quotechar': "'",
            'escapechar': '\\',
            'na_values': ['?'],
        },
        load_dataframe=load_dataframe,
    )

def fetch_traffic_violations(load_dataframe: bool = True
                             ) -> Union[DatasetAll, DatasetInfoOnly]:
    """Fetches the traffic violations dataset (classification),
    available at https://openml.org/d/42132

    Description of the dataset:
    > This dataset contains traffic violation information from all
    > electronic traffic violations issued in the Montgomery County, MD.
    > Any information that can be used to uniquely identify the vehicle,
    > the vehicle owner or the officer issuing the violation will not be
    > published.

    Returns
    -------
    DatasetAll
        If `load_dataframe=True`

    DatasetInfoOnly
        If `load_dataframe=False`

    """
    return fetch_dataset_as_dataclass(
        dataset_name='Traffic violations',
        dataset_id=TRAFFIC_VIOLATIONS_ID,
        target='violation_type',
        read_csv_kwargs={
            'quotechar': "'",
            'escapechar': '\\',
            'na_values': ['?'],
        },
        load_dataframe=load_dataframe,
    )

def fetch_drug_directory(load_dataframe: bool = True
                         ) -> Union[DatasetAll, DatasetInfoOnly]:
    """Fetches the drug directory dataset (classification),
    available at https://openml.org/d/43044

    Description of the dataset:
    > Product listing data submitted to the U.S. FDA for all unfinished,
    > unapproved drugs.

    Returns
    -------
    DatasetAll
        If `load_dataframe=True`

    DatasetInfoOnly
        If `load_dataframe=False`

    """
    return fetch_dataset_as_dataclass(
        dataset_name='Drug directory',
        dataset_id=DRUG_DIRECTORY_ID,
        target='PRODUCTTYPENAME',
        read_csv_kwargs={
            'quotechar': "'",
            'escapechar': '\\',
        },
        load_dataframe=load_dataframe,
    )