Описание проекта
Компания «Чётенькое такси» собрала исторические данные о заказах такси в аэропортах. Чтобы привлекать больше водителей в период пиковой нагрузки, нужно спрогнозировать количество заказов такси на следующий час. Постройте модель для такого предсказания.
Значение метрики RMSE на тестовой выборке должно быть не больше 48.
Вам нужно:
- Загрузить данные и выполнить их ресемплирование по одному часу.
- Проанализировать данные.
- Обучить разные модели с различными гиперпараметрами. Сделать тестовую выборку размером 10% от исходных данных.
- Проверить данные на тестовой выборке и сделать выводы.
Данные лежат в файле taxi.csv
. Количество заказов находится в столбце num_orders
(от англ. number of orders, «число заказов»).
Подготовка¶
Подготовка тетради¶
# Базовые библиотеки
import pandas as pd # Датафреймы
import numpy as np # Математика для массивов
from math import factorial # Факториалы
from scipy import stats as st # Статистика
import os # Библиотека для оптимизации чтения данных из файла
import time # Расчет времени выполнения функций
# Pipeline (пайплайн)
from sklearn.pipeline import(
Pipeline, # Pipeline с ручным вводом названий шагов.
make_pipeline # Pipeline с автоматическим названием шагов.
)
# Функция для поддержки экспериментальной функции HavingGridSearchSV
from sklearn.experimental import enable_halving_search_cv
# Ускоренная автоматизация поиска лучших моделей и их параметров
from sklearn.model_selection import HalvingGridSearchCV
# Ускоренная автоматизация рандомного поиска лучших моделей и их параметров
from sklearn.model_selection import HalvingRandomSearchCV
# Автоматизация раздельного декодирования признаков
from sklearn.compose import(
make_column_selector,
make_column_transformer,
ColumnTransformer
)
# Обработка данных для машинного обучения
# Стандартизация данных
import re
#! pip install sklearn.preprocessing
from sklearn.preprocessing import(
OneHotEncoder, # Создание отдельных столбцов для каждого категориального значения, drop='first' (удаление первого столбца против dummy-ловушки), sparse=False (?)
OrdinalEncoder, # Кодирование порядковых категориальных признаков
#TargetEncoder, # Кодирование категорий на основе таргетов (ошибка, модуль не найден)
LabelEncoder,
StandardScaler,
MinMaxScaler
)
# Кодирование категорий на основе таргетов
#!pip install -pU category_encoders
#from category_encoders.target_encoder import TargetEncoder
# Другие функции предобработки данных
from sklearn.impute import KNNImputer # Заполнение пропусков в данных методом k-блжиайших соседей.
from sklearn.utils import shuffle # Перемешивание данных для уравновешивания их в разных выборках
from statsmodels.stats.outliers_influence import variance_inflation_factor # Коэффициент инфляции дисперсии (5 и более - признак коррелирует со всеми остальными, его можно удалить и выразить через другие признаки)
from sklearn.model_selection import(
GridSearchCV, # Поиск гиперпараметров по сетке (GridSearch)
train_test_split, # Разделение выборок с целевыми и нецелевыми признаками на обучающую и тестовую
TimeSeriesSplit, # Разделение обучающей выборки при кросс-валидации (!) временных рядов
validation_curve,
StratifiedKFold, # Кроссвалидация с указанием количества фолдов (частей, на которые будет разбита обучающая выборка, одна из которых будет участвовать в валидации)
KFold, # Кроссвалидация
cross_val_score # Оценка качества модели на кроссвалидации
)
# Различные модели машинного обучения (в данном проекте требуется регрессия)
# (есть разбор на https://russianblogs.com/article/83691573909/)
# Линейная модель
from sklearn.linear_model import(
#LogisticRegression, # Линейная классификация
LinearRegression, # Линейная регрессия
Ridge , # Линейная регрессия. "Хребтовая" регрессия (метод наименьших квадратов)
BayesianRidge , # Линейная регрессия. Байесовская "хребтовая" регрессия (максимизации предельного логарифмического правдоподобия)
SGDRegressor # Линейная регрессия. SGD - Стохастический градиентный спуск (минимизирует регуляризованные эмпирические потери за счет стохастического градиентного спуска)
)
# Решающее дерево
from sklearn.tree import(
#DecisionTreeClassifier, # Решающее дерево. Классификация
DecisionTreeRegressor # Решающее дерево. Регрессия
)
# Случайный лес
from sklearn.ensemble import(
#RandomForestClassifier, # Случайный лес. Классификация
RandomForestRegressor # Случайный лес. Регрессия
)
# Машина опорных векторов
from sklearn.svm import(
SVR # # Линейная модель. Регрессия с использованием опорных векторов
)
# Нейронная сеть
from sklearn.neural_network import(
MLPRegressor # Нейронная сеть. Регрессия
)
# CatBoost (made in Yandex)
from catboost import(
CatBoostRegressor # CatBoost (Яндекс). Регрессия
)
# LightGBM
from lightgbm import(
LGBMRegressor # LightGBM. Регрессия
)
# Метрики (Показатели качества моделей)
from sklearn.metrics import(
# Метрики для моделей регрессии
mean_absolute_error, # MAE, Средняя абсолютная ошибка (не чувствительная к выбросам)
mean_absolute_percentage_error, # MAPE, Средняя абсолютная ошибка в % (универсальная в %)
mean_squared_error, # MSE, Средняя квадратичная ошибка (дисперсия, чувствительная к выбросам), RMSE (сигма) = mean_squared_error(test_y, preds, squared=False)
r2_score, # R^2, Коэффициент детерминации (универсальная в %, чувствительная к выбросам, может быть отрицательной и возвращать NaN)
# Другое
make_scorer, # Функция для использования собственных функций в параметре scoring функции HalvingGridSearchCV
ConfusionMatrixDisplay
)
# Анализ сезонности и трендов
from statsmodels.tsa.seasonal import seasonal_decompose
# Визуализация графиков
import seaborn as sns
import matplotlib
%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib import rcParams, rcParamsDefault
from pandas.plotting import scatter_matrix
# Для поиска совпадений
# в названиях населённых пунктов
from fuzzywuzzy import fuzz
from fuzzywuzzy import process
/opt/conda/lib/python3.9/site-packages/fuzzywuzzy/fuzz.py:11: UserWarning: Using slow pure-python SequenceMatcher. Install python-Levenshtein to remove this warning warnings.warn('Using slow pure-python SequenceMatcher. Install python-Levenshtein to remove this warning')
# Отображение всех столбцов таблицы
pd.set_option('display.max_columns', None)
# Обязательно для нормального отображения графиков plt
rcParams['figure.figsize'] = 10, 6
%config InlineBackend.figure_format = 'svg'
# Дополнительно и не обязательно для декорирования графиков
factor = .8
default_dpi = rcParamsDefault['figure.dpi']
rcParams['figure.dpi'] = default_dpi * factor
# Глобальная переменная
# для функций со случайными значениями
STATE = 42
Загрузка данных и их предварительный анализ¶
# Загрузка данных
def read_csv_file(path1, path2):
if os.path.exists(path1):
data = pd.read_csv(
path1,
index_col=[0],
parse_dates=[0]
)
elif os.path.exists(path12):
data = pd.read_csv(
path2,
index_col=[0],
parse_dates=[0])
else:
print('Файл не найден')
return data
data = read_csv_file(
'/datasets/taxi.csv',
'datasets/taxi.csv'
)
# Анализ датафрейма
print(data.info())
data.head()
<class 'pandas.core.frame.DataFrame'> DatetimeIndex: 26496 entries, 2018-03-01 00:00:00 to 2018-08-31 23:50:00 Data columns (total 1 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 num_orders 26496 non-null int64 dtypes: int64(1) memory usage: 414.0 KB None
num_orders | |
---|---|
datetime | |
2018-03-01 00:00:00 | 9 |
2018-03-01 00:10:00 | 14 |
2018-03-01 00:20:00 | 28 |
2018-03-01 00:30:00 | 20 |
2018-03-01 00:40:00 | 32 |
# Анализ значений датафрейма
data.hist()
data.describe()
num_orders | |
---|---|
count | 26496.000000 |
mean | 14.070463 |
std | 9.211330 |
min | 0.000000 |
25% | 8.000000 |
50% | 13.000000 |
75% | 19.000000 |
max | 119.000000 |
Выводы из предварительного анализа
- Датафрейм изначально содержал всего два признака:
datetime
(дата со временем) иnum_orders
(число заказов), а также 26496 объектов без пропусков. - Признак
datetime
использован в качестве индексов. - Признак
num_orders
имеет небольшие значения и может быть приведён к типуint32
. - Согласно условию задачи требуется ресемплировать датафрейм до 1 часа.
Предварительная подготовка данных¶
# Ресемплирование данных до 1 часа
data = data.resample('1H').sum()
# Оптимизация типа данных
# признака "num_orders"
data['num_orders'] = data['num_orders'].astype('int32')
# Анализ датафрейма после
# предварительной подготовки
data.info()
data.head()
<class 'pandas.core.frame.DataFrame'> DatetimeIndex: 4416 entries, 2018-03-01 00:00:00 to 2018-08-31 23:00:00 Freq: H Data columns (total 1 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 num_orders 4416 non-null int32 dtypes: int32(1) memory usage: 51.8 KB
num_orders | |
---|---|
datetime | |
2018-03-01 00:00:00 | 124 |
2018-03-01 01:00:00 | 85 |
2018-03-01 02:00:00 | 71 |
2018-03-01 03:00:00 | 66 |
2018-03-01 04:00:00 | 43 |
# Анализ значений датафрейма после
# предварительной подготовки
data.rename(
columns={'num_orders':'Количество заказов'}
).hist()
plt.title('Распределение количества заказов в час')
data.describe()
num_orders | |
---|---|
count | 4416.000000 |
mean | 84.422781 |
std | 45.023853 |
min | 0.000000 |
25% | 54.000000 |
50% | 78.000000 |
75% | 107.000000 |
max | 462.000000 |
data.rename(
columns={'num_orders':'Количество заказов'}
).plot(
title='Количество заказов, распределенное по времени'
);
Результаты предварительной обработки данных
- При загрузке датафрейма индексы заменены на значения признака
datetime
. - Данные ресемплированы до 1 часа.
- Оптимизировать данных признака
num_orders
. - В результате предварительной обработки датафрейм имеет 1 признак и 4416 объектов, а индексы в виде даты и времени с 00:00 1 марта по 23:00 31 августа 2018 года. Каждый объект соответствует одному часу, следующему в хронологческом порядке.
Результаты подготовки¶
Изначально датафрейм содержал два признака: datetime
(дата со временем) и num_orders
(число заказов), а также 26496 объектов без пропусков. Признак datetime
был использован в качестве индексов при загрузке датафрейма. Признак num_orders
оптимизирован путем приведения к типу int32
. Согласно условию задачи датафрейм ресемплирован до 1 часа.
В результате предварительной обработки датафрейм имеет 1 признак num_orders
и 4416 объектов, а индексы в виде даты и времени.
Анализ¶
В рамках данного проекта будт проанализированы:
- Распределение количества заказов в разных временных интервалах.
- Сдвиги и скользящие средние.
- Декомпозиция количества заказов на сезонную и трендовую составляющую с шагом в 1 день
Распределение количества заказов в разных временных интервалах¶
# Распределение заказов по часам в течение суток
data.groupby(data.index.hour).mean().rename(
columns={'num_orders':'Количество заказов'}
).plot(
title='Распределение заказов по часам в течение суток',
xlabel='Часы суток',
xlim=(0, 23),
grid=True
);
# Распределение заказов по дням в течение недели
data.groupby(data.index.weekday).mean().rename(
columns={'num_orders':'Количество заказов'},
index={0:'Понедельник', 1:'Вторник', 2:'Среда', 3:'Четверг', 4:'Пятница', 5:'Суббота', 6:'Воскресенье'}
).plot(
title='Распределение заказов по дням в течение недели',
xlabel='Дни недели',
xlim=(0, 6),
grid=True
);
# Распределение заказов по дням в течение месяца
data.groupby(data.index.day).mean().rename(
columns={'num_orders':'Количество заказов'}
).plot(
title='Распределение заказов по дням в течение месяца',
xlabel='Числа месяца',
xlim=(1, 31),
grid=True
);
# Распределение заказов по месяцам за все время набюдений
data.groupby(data.index.month).mean().rename(
columns={'num_orders':'Количество заказов'},
index={3:'Март', 4:'Апрель', 5:'Май', 6:'Июнь', 7:'Июль', 8:'Август'}
).plot(
title='Распределение заказов по месяцам за все время набюдений',
xlabel='Месяца',
xlim=(0, 5),
grid=True
);
Выводы из распределения количества заказов в разных временных интервалах
- В среднем в течение суток минимум заказов бывает в 6 часов утра, а максимум в полночь. Между максимумом и минимумом набюдаются выраженные эстремумы. Так, максимальное количество заказов наблюдается в 10 утра и 16 часов дня, а минимальное в полдень и 18 часов вечера.
- В среднем в течение недели наблюдается выраженное снижение активности во втроник и воскресенье, между которыми стабильный рост до максимумомов в понедельник и пятницу.
- В среднем в течение месяца наблюдается рост от начала к концу месяца. Минимум заказов 5 числа, максимум 31 числа.
- Наблюдается рост количества заказов от месяца к месяцу.
Анализ сдвигов и скользящих средних¶
# Создание аналитического датафрейма
data_analisys = data.copy()
data_analisys.head()
num_orders | |
---|---|
datetime | |
2018-03-01 00:00:00 | 124 |
2018-03-01 01:00:00 | 85 |
2018-03-01 02:00:00 | 71 |
2018-03-01 03:00:00 | 66 |
2018-03-01 04:00:00 | 43 |
# Функция для создания новых признаков
# путем сдвига итерационно
def shift(data, target, start, finish, step):
for i in range(start, finish, step):
string = 'shift_' + str(i)
data[string] = data[target].shift(i)
return data
# Функция для создания новых признаков
# со скользящими средними путем сдвига на 1 час итерационно
def rolling_mean_shift(data, target, start, finish, step):
for i in range(start, finish, step):
string = 'rolling_mean_' + str(i) + '_shift_1'
data[string] = data[target].shift().rolling(i).mean()
return data
# Функция для создания новых ризнаков
# со скользящими средними итерационно
def rolling_mean(data, target, start, finish, step):
for i in range(start, finish, step):
string = 'rolling_mean_' + str(i)
data[string] = data[target].rolling(i).mean()
return data
# Сдвиг с разным количеством часов
data_analisys = shift(data_analisys, 'num_orders', 1, 25, 1)
data_analisys = shift(data_analisys, 'num_orders', 48, 169, 24)
# Скользящее среднее со сдвигом в 1 час
# и окном 168 часов, т.к. наблюдается
# "недельная сезонность". С таким окном
# тренд будет более гладким
data_analisys = rolling_mean_shift(data_analisys, 'num_orders', 168, 168, 1)
# Удаление объектов с пустыми значениями
data_analisys = data_analisys.dropna()
# Вывод на печать RMSE сдвигов и скользящих средних
for i in data_analisys.columns:
if i != 'num_orders':
print('RMSE', i, '=', mean_squared_error(
data_analisys[i],
data_analisys['num_orders'],
squared=False
))
RMSE shift_1 = 39.30292142790866 RMSE shift_2 = 46.945798072729566 RMSE shift_3 = 52.952578009864375 RMSE shift_4 = 55.63490331656302 RMSE shift_5 = 56.233426099809954 RMSE shift_6 = 55.86740254089737 RMSE shift_7 = 53.87424107330057 RMSE shift_8 = 53.090801832062915 RMSE shift_9 = 54.62524483131728 RMSE shift_10 = 54.888993010186624 RMSE shift_11 = 52.785766072035706 RMSE shift_12 = 52.10447367572656 RMSE shift_13 = 52.879965251946615 RMSE shift_14 = 54.96986847972229 RMSE shift_15 = 55.78799702492222 RMSE shift_16 = 54.35478632659841 RMSE shift_17 = 55.052089697282746 RMSE shift_18 = 56.698811089014214 RMSE shift_19 = 57.2117354587293 RMSE shift_20 = 56.45780510312234 RMSE shift_21 = 54.76864110934485 RMSE shift_22 = 49.250063332241055 RMSE shift_23 = 42.425399798029 RMSE shift_24 = 35.399379297015024 RMSE shift_48 = 33.229463873243596 RMSE shift_72 = 34.112244793051836 RMSE shift_96 = 34.15574098427023 RMSE shift_120 = 33.298664881080796 RMSE shift_144 = 36.11251771851234 RMSE shift_168 = 27.573911972531693
# Удаление не востребованной далее переменной
del data_analisys
Выводы из анализа сдвигов и скользящих средних
- Минимальную ошибку RMSE дает скользящее среднее с окном в 2 часа.
- Минимальные ошибки RMSE среди скользящих и сдвигов, применимые для предсказания будущих периодов, наблюдаются при сдвиге на 168 часов (1 неделя). С большим отрывом, но тоже хороший результат показали сдвиги на 48, 72, 96 и 120 часов (2, 3, 4 и 5 суток).
Из полученных результатов можно сделать вывод о том, что почасовые значения аналогичного периода предыдущей недели дают наиболее точное предсказание текущих почасовых предсказываемых значений. Т.е. для предсказания количества заказов с 11:00 до 12:00 10.08.2023 требуется использовать количество заказов с 11:00 до 12:00 03.08.2023.
Декомпозиция количества заказов на сезонную и трендовую составляющую с шагом в 1 день¶
# Декомпозиция на сезонную
# и трендовую составляющую
# с шагом в 1 день
decomposed = seasonal_decompose(data['num_orders'].resample('1H').sum())
#plt.figure(figsize=(10, 14))
#plt.subplot(311)
plt.show()
# Чтобы график корректно отобразился, указываем его
# оси ax, равными plt.gca() (англ. get current axis,
# получить текущие оси)
decomposed.trend.plot(ax=plt.gca())
plt.title('Тренд')
#plt.subplot(312)
plt.show()
decomposed.seasonal.plot(ax=plt.gca())
plt.title('Сезонность')
#plt.subplot(313)
plt.show()
decomposed.resid.plot(ax=plt.gca())
plt.title('Остатки')
plt.tight_layout()