From 5b8af888f4edcc1afe75dd4c7b6db1f7e13f46ee Mon Sep 17 00:00:00 2001
From: Patrick Mueller
Date: Sat, 25 Dec 2021 10:48:20 +0100
Subject: [PATCH] WIP: Regression for gas prices

---
Review notes (free text, not part of the commit):
* Line breaks of this patch were restored; hunk sizes and the diffstat
  below were recomputed for the edited files.
* regression_gas_prices.py: evaluate on test_features instead of the
  global train_features, pass the training data to predict_prices(),
  close cursor and connection, drop the unused normalizer, relabel the
  loss axis, add docstrings.
* sql_connection_handler.py: fall back to port 3306 when DB_PORT is
  unset; the return annotation is Optional because None is returned on
  failure.
* The "index" blob ids were not regenerated after these edits, and the
  From: header carries no e-mail address.

 main.py                                |  16 +--
 regression.py => regression_example.py |   0
 regression_gas_prices.py               | 157 +++++++++++++++++++++++++
 requirements.txt                       |   4 +-
 sql_connection_handler.py              |  31 ++++++
 5 files changed, 193 insertions(+), 15 deletions(-)
 rename regression.py => regression_example.py (100%)
 create mode 100644 regression_gas_prices.py
 create mode 100644 sql_connection_handler.py

diff --git a/main.py b/main.py
index 5596b44..e501c67 100644
--- a/main.py
+++ b/main.py
@@ -1,16 +1,4 @@
-# This is a sample Python script.
+import sql_connection_handler as sql
 
-# Press Shift+F10 to execute it or replace it with your code.
-# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
-
-
-def print_hi(name):
-    # Use a breakpoint in the code line below to debug your script.
-    print(f'Hi, {name}')  # Press Ctrl+F8 to toggle the breakpoint.
-
-
-# Press the green button in the gutter to run the script.
 if __name__ == '__main__':
-    print_hi('PyCharm')
-
-# See PyCharm help at https://www.jetbrains.com/help/pycharm/
+    conn = sql.get_db_connection()
\ No newline at end of file
diff --git a/regression.py b/regression_example.py
similarity index 100%
rename from regression.py
rename to regression_example.py
diff --git a/regression_gas_prices.py b/regression_gas_prices.py
new file mode 100644
index 0000000..ee8f845
--- /dev/null
+++ b/regression_gas_prices.py
@@ -0,0 +1,157 @@
+import pymysql
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+
+import sql_connection_handler as sql
+
+np.set_printoptions(precision=3, suppress=True)
+
+
+def get_data_from_sql() -> pd.DataFrame:
+    """Load the E5 price history of station 1 from the prices table.
+
+    The query selects (timestamp, price), so with pandas' default integer
+    column labels column 0 holds the timestamps and column 1 the prices.
+
+    Raises:
+        pymysql.Error: if no connection is available or no rows match.
+    """
+    conn = sql.get_db_connection()
+    if conn is None:
+        # get_db_connection() logs connection errors and returns None.
+        raise pymysql.Error("No SQL connection available")
+
+    query = 'SELECT timestamp, price FROM prices WHERE fuel_type = "E5" AND station = 1'
+    try:
+        with conn.cursor() as cur:
+            # execute() returns the affected-row count; 0 means no data.
+            if not cur.execute(query):
+                raise pymysql.Error("Error loading data from SQL")
+            res = cur.fetchall()
+    finally:
+        # Release the connection even when the query fails.
+        conn.close()
+
+    return pd.DataFrame(res)
+
+
+def prepare_data(dataset: pd.DataFrame):
+    """Split the rows 80/20 into train/test features and price labels."""
+    dataset = dataset.dropna()
+
+    # Split into training and test data
+    train_dataset = dataset.sample(frac=0.8, random_state=0)
+    test_dataset = dataset.drop(train_dataset.index)
+
+    # Split into features and labels
+    train_features = train_dataset.copy()
+    test_features = test_dataset.copy()
+
+    train_labels = train_features.pop(1)
+    test_labels = test_features.pop(1)
+
+    return train_features, test_features, train_labels, test_labels
+
+
+def normalize_data(train_features, test_features, train_labels, test_labels):
+    """Return a Normalization layer adapted to the training timestamps.
+
+    Only train_features is read; the remaining parameters are kept so
+    existing calls stay valid.
+    """
+    timestamps = np.asarray(train_features[0]).astype('float32')
+
+    price_normalizer = layers.Normalization(input_shape=[1, ], axis=None)
+    price_normalizer.adapt(timestamps)
+
+    return price_normalizer
+
+
+def generate_model(price_normalizer):
+    """Build and compile a single-unit linear model behind the normalizer."""
+    price_model = tf.keras.Sequential([
+        price_normalizer,
+        layers.Dense(units=1)
+    ])
+
+    price_model.compile(
+        optimizer=tf.optimizers.Adam(learning_rate=0.1),
+        loss='mean_absolute_error'
+    )
+
+    return price_model
+
+
+def train_model(price_model, train_features, train_labels):
+    """Fit the model for 100 epochs, show the loss plot, return the model."""
+    history = price_model.fit(
+        np.asarray(train_features[0]).astype('float32'),
+        train_labels,
+        epochs=100,
+        verbose=0,
+        validation_split=0.2
+    )
+
+    # Show loss plot
+    plot_loss(history)
+
+    return price_model
+
+
+def collect_results(price_model, test_features, test_labels):
+    """Evaluate the model on the test split and print the resulting loss."""
+    test_results = {}
+
+    # Use the test features here; they must pair up with test_labels.
+    test_results['price_model'] = price_model.evaluate(
+        np.asarray(test_features[0]).astype('float32'),
+        test_labels,
+        verbose=0
+    )
+
+    print(test_results)
+
+
+def predict_prices(price_model, train_features, train_labels):
+    """Predict for x = 0..250 and plot the result over the training data."""
+    # NOTE(review): the 0..250 range may not cover the timestamps - confirm.
+    x = tf.linspace(0.0, 250, 251)
+    y = price_model.predict(x)
+    plot_prices(x, y, train_features, train_labels)
+
+
+def plot_loss(history):
+    """Plot the training and validation loss per epoch."""
+    plt.plot(history.history['loss'], label='loss')
+    plt.plot(history.history['val_loss'], label='val_loss')
+    plt.ylim([0, 10])
+    plt.xlabel('Epoch')
+    plt.ylabel('Error [Price]')
+    plt.legend()
+    plt.grid(True)
+    plt.show()
+
+
+def plot_prices(x, y, train_features, train_labels):
+    """Scatter the training data and draw the predicted line on top."""
+    plt.scatter(train_features[0], train_labels, label='Data')
+    plt.plot(x, y, color='k', label='Predictions')
+    plt.xlabel('Datetime')
+    plt.ylabel('Price')
+    plt.legend()
+    plt.show()
+
+
+if __name__ == '__main__':
+    dataset = get_data_from_sql().copy()
+    train_features, test_features, train_labels, test_labels = prepare_data(dataset)
+    price_normalizer = normalize_data(train_features, test_features, train_labels, test_labels)
+    price_model = generate_model(price_normalizer)
+    price_model = train_model(price_model, train_features, train_labels)
+    collect_results(price_model, test_features, test_labels)
+    predict_prices(price_model, train_features, train_labels)
diff --git a/requirements.txt b/requirements.txt
index 307faf5..5789774 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,6 @@ tensorflow
 seaborn
 matplotlib
 numpy
-pandas
\ No newline at end of file
+pandas
+python-dotenv
+pymysql
\ No newline at end of file
diff --git a/sql_connection_handler.py b/sql_connection_handler.py
new file mode 100644
index 0000000..c4d1f05
--- /dev/null
+++ b/sql_connection_handler.py
@@ -0,0 +1,31 @@
+import logging
+import pymysql
+import os
+from typing import Optional
+from dotenv import load_dotenv
+
+
+def get_db_connection() -> Optional[pymysql.Connection]:
+    """Open a MySQL connection configured from the DB_* variables in .env.
+
+    Returns:
+        The open connection, or None when connecting failed; the error is
+        logged rather than raised.
+    """
+    conn = None
+    try:
+        load_dotenv(".env")
+
+        conn = pymysql.connect(
+            user=os.environ.get('DB_USER'),
+            password=os.environ.get('DB_PASSWORD'),
+            host=os.environ.get('DB_HOST'),
+            # Fall back to the MySQL default port when DB_PORT is unset;
+            # int(None) would otherwise raise a TypeError.
+            port=int(os.environ.get('DB_PORT', '3306')),
+            database=os.environ.get('DB_NAME')
+        )
+    except Exception as e:
+        logging.error("SQL Connection Error:%s", e)
+
+    return conn