WIP: Regression for gas prices

This commit is contained in:
Patrick Müller 2021-12-25 10:48:20 +01:00
parent 71a065e52c
commit 5b8af888f4
5 changed files with 159 additions and 15 deletions

16
main.py
View File

@ -1,16 +1,4 @@
# This is a sample Python script. import sql_connection_handler as sql
# Press Shift+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
def print_hi(name):
# Use a breakpoint in the code line below to debug your script.
print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint.
# Press the green button in the gutter to run the script.
if __name__ == '__main__': if __name__ == '__main__':
print_hi('PyCharm') conn = sql.get_db_connection()
# See PyCharm help at https://www.jetbrains.com/help/pycharm/

132
regression_gas_prices.py Normal file
View File

@ -0,0 +1,132 @@
import pymysql
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import sql_connection_handler as sql
# Print NumPy arrays with 3 decimal places and without scientific notation.
np.set_printoptions(precision=3, suppress=True)
def get_data_from_sql() -> pd.DataFrame:
    """Load the E5 price history of station 1 from the database.

    Returns:
        A DataFrame built directly from the fetched rows. With PyMySQL's
        default tuple cursor the columns are labelled positionally
        (0 = timestamp, 1 = price), which is what the rest of this module
        relies on.

    Raises:
        pymysql.Error: If no database connection is available or the query
            returned no rows.
    """
    query = 'SELECT timestamp, price FROM prices WHERE fuel_type = "E5" AND station = 1'

    conn = sql.get_db_connection()
    # The connection helper logs failures and hands back None instead of
    # raising, so guard here rather than crash with an AttributeError.
    if conn is None:
        raise pymysql.Error("No database connection available")

    try:
        with conn.cursor() as cur:
            # execute() returns the number of affected rows; 0 means no data.
            if not cur.execute(query):
                raise pymysql.Error("Error loading data from SQL")
            res = cur.fetchall()
    finally:
        # Always release the connection, even when the query fails.
        conn.close()

    return pd.DataFrame(res)
def prepare_data(dataset: pd.DataFrame):
    """Split a dataset into training/test features and labels.

    Rows containing missing values are discarded first. 80 % of the
    remaining rows are sampled (fixed seed 0) as training data; the rest
    become test data. Column ``1`` is removed from each part and returned
    as the label series.

    Args:
        dataset: Frame with positionally labelled columns; column ``1``
            holds the labels.

    Returns:
        Tuple ``(train_features, test_features, train_labels, test_labels)``.
    """
    cleaned = dataset.dropna()

    # Training / test partition
    train_part = cleaned.sample(frac=0.8, random_state=0)
    test_part = cleaned.drop(index=train_part.index)

    # Work on copies so popping the label column leaves the parts intact
    train_features, test_features = train_part.copy(), test_part.copy()
    train_labels, test_labels = train_features.pop(1), test_features.pop(1)

    return train_features, test_features, train_labels, test_labels
def normalize_data(train_features, test_features, train_labels, test_labels):
    """Build a normalization layer adapted to the training input column.

    Args:
        train_features: Training features; only column ``0`` is used.
        test_features: Unused; kept so existing callers keep working.
        train_labels: Unused; kept so existing callers keep working.
        test_labels: Unused; kept so existing callers keep working.

    Returns:
        A ``layers.Normalization`` layer for scalar inputs whose statistics
        were adapted to column ``0`` of ``train_features``.
    """
    # Only the single-input normalizer is returned and used by the model, so
    # the previously adapted (and discarded) all-feature normalizer is gone.
    model_input = np.asarray(train_features[0]).astype('float32')

    price_normalizer = layers.Normalization(input_shape=[1, ], axis=None)
    price_normalizer.adapt(model_input)
    return price_normalizer
def generate_model(price_normalizer):
    """Create and compile a single-unit linear regression model.

    Args:
        price_normalizer: Already adapted normalization layer that is used
            as the first layer of the model.

    Returns:
        The compiled ``tf.keras.Sequential`` model (normalizer followed by
        one ``Dense`` unit), using Adam with learning rate 0.1 and mean
        absolute error as loss.
    """
    output_layer = layers.Dense(units=1)
    model = tf.keras.Sequential([price_normalizer, output_layer])

    model.compile(
        loss='mean_absolute_error',
        optimizer=tf.optimizers.Adam(learning_rate=0.1),
    )
    return model
def train_model(price_model, train_features, train_labels):
    """Fit the model on the training data and plot the loss curves.

    Args:
        price_model: Compiled Keras model to train.
        train_features: Training features; column ``0`` is the model input.
        train_labels: Target values matching ``train_features``.

    Returns:
        The same ``price_model`` instance after training.
    """
    inputs = np.asarray(train_features[0]).astype('float32')

    # 100 silent epochs; 20 % of the training data is held out for validation
    fit_options = dict(epochs=100, verbose=0, validation_split=0.2)
    history = price_model.fit(inputs, train_labels, **fit_options)

    # Visualise training vs. validation loss
    plot_loss(history)
    return price_model
def collect_results(price_model, test_features, test_labels):
    """Evaluate the model on the held-out test data and print the result.

    Args:
        price_model: Trained model exposing ``evaluate(x, y, verbose=...)``.
        test_features: Test features; column ``0`` is the model input.
        test_labels: Target values matching ``test_features``.

    Returns:
        Dict mapping ``'price_model'`` to the value returned by
        ``price_model.evaluate``.
    """
    # Evaluate on the *test* split. The previous code read the global
    # ``train_features`` here, which mismatched ``test_labels``.
    test_inputs = np.asarray(test_features[0]).astype('float32')

    test_results = {
        'price_model': price_model.evaluate(
            test_inputs,
            test_labels,
            verbose=0
        )
    }
    print(test_results)
    return test_results
def predict_prices(price_model, train_features, train_labels):
    """Predict over a fixed input grid and plot it against the training data.

    Args:
        price_model: Trained model exposing ``predict``.
        train_features: Training features, forwarded to ``plot_prices``.
        train_labels: Training labels, forwarded to ``plot_prices``.
    """
    # NOTE(review): the grid is 251 evenly spaced points from 0 to 250 and is
    # not derived from the data - confirm this range matches the input column.
    grid = tf.linspace(0.0, 250, 251)
    predictions = price_model.predict(grid)
    plot_prices(grid, predictions, train_features, train_labels)
def plot_loss(history):
    """Plot training and validation loss per epoch.

    Args:
        history: Object returned by ``model.fit``; its ``history`` dict must
            contain the keys ``'loss'`` and ``'val_loss'``.
    """
    plt.plot(history.history['loss'], label='loss')
    plt.plot(history.history['val_loss'], label='val_loss')
    plt.ylim([0, 10])
    plt.xlabel('Epoch')
    # The model predicts prices; the former '[MPG]' unit was a leftover from
    # the fuel-efficiency tutorial this code was adapted from.
    plt.ylabel('Error [Price]')
    plt.legend()
    plt.grid(True)
    plt.show()
def plot_prices(x, y, train_features, train_labels):
    """Show the training data as a scatter plot with the prediction line.

    Args:
        x: Input values of the prediction curve.
        y: Predicted values for ``x``.
        train_features: Training features; column ``0`` is the x-axis data.
        train_labels: Training labels, plotted on the y-axis.
    """
    # Axis captions first; they do not depend on the plotted data
    plt.xlabel('Datetime')
    plt.ylabel('Price')

    # Raw observations, then the fitted line drawn in black on top
    plt.scatter(train_features[0], train_labels, label='Data')
    plt.plot(x, y, color='k', label='Predictions')

    plt.legend()
    plt.show()
if __name__ == '__main__':
    # Pipeline: load -> split -> normalize -> build -> train -> evaluate -> plot
    dataset = get_data_from_sql()
    train_features, test_features, train_labels, test_labels = prepare_data(dataset)
    price_normalizer = normalize_data(train_features, test_features, train_labels, test_labels)
    price_model = generate_model(price_normalizer)
    price_model = train_model(price_model, train_features, train_labels)
    collect_results(price_model, test_features, test_labels)
    # predict_prices() needs the training data for the scatter plot; calling
    # it with the model alone raised a TypeError.
    predict_prices(price_model, train_features, train_labels)

View File

@ -4,3 +4,5 @@ seaborn
matplotlib matplotlib
numpy numpy
pandas pandas
python-dotenv
pymysql

22
sql_connection_handler.py Normal file
View File

@ -0,0 +1,22 @@
import logging
import pymysql
import os
from dotenv import load_dotenv
def get_db_connection() -> "pymysql.Connection | None":
    """Open a MySQL connection configured through environment variables.

    Settings are read from ``DB_USER``, ``DB_PASSWORD``, ``DB_HOST``,
    ``DB_PORT`` and ``DB_NAME`` after loading the local ``.env`` file.

    Returns:
        An open connection, or ``None`` if connecting failed. Failures are
        logged rather than raised, so callers must check for ``None``.
    """
    conn = None
    try:
        load_dotenv(".env")
        conn = pymysql.connect(
            user=os.environ.get('DB_USER'),
            password=os.environ.get('DB_PASSWORD'),
            host=os.environ.get('DB_HOST'),
            # Fall back to the MySQL default port when DB_PORT is unset or
            # empty instead of failing inside int().
            port=int(os.environ.get('DB_PORT') or 3306),
            database=os.environ.get('DB_NAME')
        )
    except Exception as e:
        # Deliberate best-effort: report the problem and return None.
        logging.error("SQL Connection Error:%s", e)
    return conn