WIP: Regression for gas prices
This commit is contained in:
parent
71a065e52c
commit
5b8af888f4
16
main.py
16
main.py
|
@ -1,16 +1,4 @@
|
||||||
# This is a sample Python script.
|
import sql_connection_handler as sql
|
||||||
|
|
||||||
# Press Shift+F10 to execute it or replace it with your code.
|
|
||||||
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
|
|
||||||
|
|
||||||
|
|
||||||
def print_hi(name):
    """Print a one-line greeting for ``name`` to standard output."""
    greeting = f'Hi, {name}'
    print(greeting)
|
|
||||||
|
|
||||||
|
|
||||||
# Press the green button in the gutter to run the script.
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print_hi('PyCharm')
|
conn = sql.get_db_connection()
|
||||||
|
|
||||||
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
|
|
132
regression_gas_prices.py
Normal file
132
regression_gas_prices.py
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
import pymysql
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow import keras
|
||||||
|
from tensorflow.keras import layers
|
||||||
|
|
||||||
|
import sql_connection_handler as sql
|
||||||
|
|
||||||
|
np.set_printoptions(precision=3, suppress=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_from_sql() -> pd.DataFrame:
    """Load the timestamp/price rows for fuel type E5 at station 1.

    Returns:
        A DataFrame built directly from the fetched rows. No column names
        are supplied, so pandas assigns its default column labels.

    Raises:
        pymysql.Error: if no database connection could be obtained, or if
            the query reported zero rows.
    """
    query = 'SELECT timestamp, price FROM prices WHERE fuel_type = "E5" AND station = 1'

    conn = sql.get_db_connection()
    # get_db_connection() logs and returns None on failure; surface that as
    # a database error instead of an AttributeError on ``None.cursor()``.
    if conn is None:
        raise pymysql.Error("Error loading data from SQL")

    try:
        # The cursor context manager closes the cursor; ``finally`` releases
        # the connection even when the query fails.
        with conn.cursor() as cur:
            if not cur.execute(query):
                raise pymysql.Error("Error loading data from SQL")
            res = cur.fetchall()
    finally:
        conn.close()

    return pd.DataFrame(res)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_data(dataset: pd.DataFrame):
    """Split ``dataset`` into training/test features and labels.

    Rows containing missing values are dropped first. 80% of the remaining
    rows are sampled (fixed ``random_state=0``) as training data; the rest
    form the test set. The column labelled ``1`` is removed from each part
    and returned as that part's labels.

    Returns:
        ``(train_features, test_features, train_labels, test_labels)``.
    """
    def _split_off_labels(frame):
        # Work on a copy so the sampled frame itself is left untouched.
        features = frame.copy()
        labels = features.pop(1)
        return features, labels

    cleaned = dataset.dropna()
    train_part = cleaned.sample(frac=0.8, random_state=0)
    test_part = cleaned.drop(train_part.index)

    train_features, train_labels = _split_off_labels(train_part)
    test_features, test_labels = _split_off_labels(test_part)

    return train_features, test_features, train_labels, test_labels
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_data(train_features, test_features, train_labels, test_labels):
    """Build a Keras ``Normalization`` layer adapted to the training feature.

    Only column ``0`` of ``train_features`` is used. The other three
    parameters are accepted for call-site compatibility and are not read.

    Returns:
        The adapted single-input ``Normalization`` layer.
    """
    # Column 0 is the sole model input. The previous version also adapted a
    # second, whole-frame normalizer that was never used; that dead work is
    # removed here.
    feature_values = np.asarray(train_features[0]).astype('float32')

    price_normalizer = layers.Normalization(input_shape=[1, ], axis=None)
    price_normalizer.adapt(feature_values)

    return price_normalizer
|
||||||
|
|
||||||
|
|
||||||
|
def generate_model(price_normalizer):
    """Assemble and compile the single-input linear regression model.

    The model is ``price_normalizer`` followed by one ``Dense`` unit. It is
    compiled with Adam (learning rate 0.1) and mean-absolute-error loss.

    Returns:
        The compiled ``Sequential`` model.
    """
    output_layer = layers.Dense(units=1)
    price_model = tf.keras.Sequential([price_normalizer, output_layer])

    adam = tf.optimizers.Adam(learning_rate=0.1)
    price_model.compile(optimizer=adam, loss='mean_absolute_error')

    return price_model
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(price_model, train_features, train_labels):
    """Fit ``price_model`` on feature column 0 and show the loss curves.

    Training runs for 100 epochs without console output, with 20% of the
    training data held out for validation. The resulting history is passed
    to ``plot_loss``.

    Returns:
        The same ``price_model`` object, now fitted.
    """
    inputs = np.asarray(train_features[0]).astype('float32')
    fit_options = {'epochs': 100, 'verbose': 0, 'validation_split': 0.2}

    history = price_model.fit(inputs, train_labels, **fit_options)

    # Show loss plot
    plot_loss(history)

    return price_model
|
||||||
|
|
||||||
|
|
||||||
|
def collect_results(price_model, test_features, test_labels):
    """Evaluate ``price_model`` on the held-out test data and print the result.

    Args:
        price_model: object exposing ``evaluate(x, y, verbose=...)``.
        test_features: frame whose column ``0`` holds the model input.
        test_labels: target values aligned with ``test_features``.

    Returns:
        A dict with the evaluation result under the key ``'price_model'``
        (the same dict that is printed).
    """
    # Evaluate on the TEST inputs. The previous version read
    # ``train_features``, which is not a parameter of this function and was
    # paired with the test labels.
    test_inputs = np.asarray(test_features[0]).astype('float32')

    test_results = {
        'price_model': price_model.evaluate(test_inputs, test_labels, verbose=0),
    }

    print(test_results)
    return test_results
|
||||||
|
|
||||||
|
|
||||||
|
def predict_prices(price_model, train_features, train_labels):
    """Predict over an evenly spaced input range and plot against the data.

    The predictions and the training data are handed to ``plot_prices``.
    """
    # NOTE(review): the 0..250 range (251 points) is hard-coded; confirm it
    # matches the scale of the feature column the model was trained on.
    sample_points = tf.linspace(0.0, 250, 251)
    predicted = price_model.predict(sample_points)
    plot_prices(sample_points, predicted, train_features, train_labels)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_loss(history):
    """Plot training and validation loss per epoch and show the figure.

    Args:
        history: object whose ``history`` mapping contains the ``'loss'``
            and ``'val_loss'`` series.
    """
    plt.plot(history.history['loss'], label='loss')
    plt.plot(history.history['val_loss'], label='val_loss')
    plt.ylim([0, 10])
    plt.xlabel('Epoch')
    # The model in this file regresses prices; the former '[MPG]' unit was a
    # leftover from a different dataset.
    plt.ylabel('Error [Price]')
    plt.legend()
    plt.grid(True)
    plt.show()
|
||||||
|
|
||||||
|
def plot_prices(x, y, train_features, train_labels):
    """Scatter the training data and overlay the prediction line.

    Column ``0`` of ``train_features`` is plotted against ``train_labels``;
    ``x``/``y`` are drawn as a black line on top, then the figure is shown.
    """
    observed_inputs = train_features[0]

    plt.scatter(observed_inputs, train_labels, label='Data')
    plt.plot(x, y, color='k', label='Predictions')

    plt.xlabel('Datetime')
    plt.ylabel('Price')
    plt.legend()
    plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Pipeline: load -> split -> normalise -> build -> train -> evaluate -> plot.
    dataset = get_data_from_sql()
    train_features, test_features, train_labels, test_labels = prepare_data(dataset)
    price_normalizer = normalize_data(train_features, test_features, train_labels, test_labels)
    price_model = generate_model(price_normalizer)
    price_model = train_model(price_model, train_features, train_labels)
    collect_results(price_model, test_features, test_labels)
    # predict_prices() takes the training data as well (it is plotted next
    # to the predictions); passing only the model raised a TypeError.
    predict_prices(price_model, train_features, train_labels)
|
|
@ -4,3 +4,5 @@ seaborn
|
||||||
matplotlib
|
matplotlib
|
||||||
numpy
|
numpy
|
||||||
pandas
|
pandas
|
||||||
|
python-dotenv
|
||||||
|
pymysql
|
22
sql_connection_handler.py
Normal file
22
sql_connection_handler.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
import logging
|
||||||
|
import pymysql
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_connection() -> pymysql.Connection:
    """Open a database connection configured through environment variables.

    Variables are loaded from a local ``.env`` file first. The settings read
    are ``DB_USER``, ``DB_PASSWORD``, ``DB_HOST``, ``DB_PORT`` and
    ``DB_NAME``.

    Returns:
        The open connection, or ``None`` if anything went wrong while
        connecting. The error is logged rather than raised, so callers must
        check for ``None`` before using the result.
    """
    try:
        load_dotenv(".env")

        return pymysql.connect(
            user=os.environ.get('DB_USER'),
            password=os.environ.get('DB_PASSWORD'),
            host=os.environ.get('DB_HOST'),
            # An unset or empty DB_PORT would make int() fail; fall back to
            # the standard MySQL port instead.
            port=int(os.environ.get('DB_PORT') or 3306),
            database=os.environ.get('DB_NAME'),
        )
    except Exception as e:
        # Best-effort by design: report the failure and signal it via None.
        logging.error("SQL Connection Error:%s", e)
        return None
|
Loading…
Reference in New Issue
Block a user