diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..4e3224b
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,77 @@
+name: ci
+on:
+ push:
+ branches:
+ - '*'
+permissions:
+ contents: write
+jobs:
+ deploy:
+ if: github.ref == 'refs/heads/main'
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Configure Git Credentials
+ run: |
+ git config user.name github-actions[bot]
+ git config user.email 41898282+github-actions[bot]@users.noreply.github.com
+ - uses: actions/setup-python@v5
+ with:
+ python-version: 3.x
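+      # %V = ISO week number, so the mkdocs cache key below rotates weekly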
+ - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
+ - uses: actions/cache@v4
+ with:
+ key: mkdocs-material-${{ env.cache_id }}
+ path: .cache
+ restore-keys: |
+ mkdocs-material-
+ - run: pip install mkdocs "mkdocstrings[python]" mkdocs-material
+ - run: mkdocs gh-deploy --force
+
+ mypy:
+ runs-on: ubuntu-latest
+ name: Mypy
+ steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.x
+        uses: actions/setup-python@v5
+ with:
+ python-version: 3.x
+ - name: Install Dependencies
+ run: |
+          pip install -r requirements.txt  # mypy is pinned in requirements.txt
+ - name: mypy
+ run: |
+ mypy src/
+
+ pylint:
+ runs-on: ubuntu-latest
+ name: pylint
+ steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.x
+        uses: actions/setup-python@v5
+ with:
+ python-version: 3.x
+ - name: Install Dependencies
+ run: |
+          pip install pylint -r requirements.txt
+ - name: pylint
+ run: |
+ pylint src/
+
+ pytest:
+ runs-on: ubuntu-latest
+ name: pytest
+ steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.x
+        uses: actions/setup-python@v5
+ with:
+ python-version: 3.x
+ - name: Install Dependencies
+ run: |
+          pip install -r requirements.txt  # pytest is pinned in requirements.txt
+ - name: pytest
+ run: |
+ pytest tests/
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..5166d7f
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,18 @@
+FROM ubuntu:22.04
+
+# set a directory for the app
+WORKDIR /usr/workspace/
+
+# install python
+RUN apt-get update && \
+    apt-get install -y python3 python3-pip && \
+    rm -rf /var/lib/apt/lists/*
+
+# install the dependencies first so this layer is cached across code changes
+COPY requirements.txt .
+RUN pip3 install --no-cache-dir -r requirements.txt
+
+# copy the rest of the files to the container
+COPY . .
+
+# EXPOSE 8000
+
+# To start the server manually inside the container:
+#   cd django-stock-tracker && python3 manage.py runserver 0.0.0.0:8000
+# Remember to publish the port when running the image, e.g.:
+#   docker run -it -p 8888:8000 stock
\ No newline at end of file
diff --git a/django-stock-tracker/home/views.py b/django-stock-tracker/home/views.py
index cedb49b..513f707 100644
--- a/django-stock-tracker/home/views.py
+++ b/django-stock-tracker/home/views.py
@@ -1,7 +1,8 @@
from django.shortcuts import render
from django.http import HttpResponse, JsonResponse
import sys, json
-sys.path.append("/home/wleong/Personal_project/StockTracker/")
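+# make src/ importable both inside the Docker container and from a local checkout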
+sys.path.append("/usr/workspace/")
+sys.path.append("/home/wleong/Personal_project/StockTracker")
from src.model import Model
from src.data_processing import DataProcessing
from src.live_price_display import LivePriceDisplay
diff --git a/docs/explanation.md b/docs/explanation.md
new file mode 100644
index 0000000..1a74ec0
--- /dev/null
+++ b/docs/explanation.md
@@ -0,0 +1,15 @@
+This part of the project documentation focuses on an
+**understanding-oriented** approach. You'll get a
+chance to read about the background of the project,
+as well as the reasoning behind how it was implemented.
+
+> **Note:** Expand this section by considering the
+> following points:
+
+- Give context and background on your library
+- Explain why you created it
+- Provide multiple examples and approaches of how
+ to work with it
+- Help the reader make connections
+- Avoid writing instructions or technical descriptions
+ here
\ No newline at end of file
diff --git a/docs/how-to-guides.md b/docs/how-to-guides.md
new file mode 100644
index 0000000..18c2ad3
--- /dev/null
+++ b/docs/how-to-guides.md
@@ -0,0 +1,6 @@
+This part of the project documentation focuses on a
+**problem-oriented** approach. You'll tackle common
+tasks you might face, with the help of the code
+provided in this project, starting with the example below.
+
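+As a first example, here is how you might fetch the most
+recent closing price for a ticker. A minimal sketch,
+assuming the `src` package is importable and network
+access is available:
+
+```python
+from src.live_price_display import LivePriceDisplay
+
+# returns a rounded closing price, or "Error fetching price" on failure
+price = LivePriceDisplay.display_final_price_yf("AAPL")
+print(price)
+```
+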
+To be continued.
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 0000000..f7dd8ac
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,30 @@
+This site contains the project documentation for the
+'Stock Tracker' project, a tool for following stock
+prices. Its aim is to give users the most recent
+closing price and news for a selected company, along
+with the data processing behind the price history.
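+
+As a quick, minimal sketch of what the code offers
+(assuming the `src` package is on your import path and
+the API keys in `src/parameters.py` are configured):
+
+```python
+from src.live_price_display import LivePriceDisplay
+from src.news_display import NewsDisplay
+
+# most recent closing price, fetched via Yahoo Finance
+print(LivePriceDisplay.display_final_price_yf("AAPL"))
+
+# the five most recent headlines as {"title": ..., "url": ...} dicts
+print(NewsDisplay().format_news_django("AAPL"))
+```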
+
+## Table Of Contents
+
+The documentation follows the best practice for
+project documentation as described by Daniele Procida
+in the [Diátaxis documentation framework](https://diataxis.fr/)
+and consists of four separate parts:
+
+1. [Tutorials](tutorials.md)
+2. [How-To Guides](how-to-guides.md)
+3. [Reference](reference.md)
+4. [Explanation](explanation.md)
+
+Each part serves a different purpose, so head straight
+to the page that matches your use case.
+
+## Project Overview
+
+::: src
+
+## Acknowledgements
+I want to thank my house plants for providing me with
+a negligible amount of oxygen each day. Also, I want
+to thank the sun for providing more than half of their
+nourishment free of charge.
diff --git a/docs/reference.md b/docs/reference.md
new file mode 100644
index 0000000..162a8ec
--- /dev/null
+++ b/docs/reference.md
@@ -0,0 +1,8 @@
+This part of the project documentation focuses on
+an **information-oriented** approach. Use it as a
+reference for the technical implementation of the
+`Stock Tracker` project code. The sections below are
+rendered from the module docstrings by mkdocstrings.
+
+::: src.live_price_display
+::: src.news_display
+::: src.model
\ No newline at end of file
diff --git a/docs/tutorials.md b/docs/tutorials.md
new file mode 100644
index 0000000..384f1d6
--- /dev/null
+++ b/docs/tutorials.md
@@ -0,0 +1,16 @@
+This part of the project documentation focuses on a
+**learning-oriented** approach. You'll learn how to
+get started with the code in this project; a first
+runnable example follows the checklist below.
+
+> **Note:** Expand this section by considering the
+> following points:
+
+- Help newcomers with getting started
+- Teach readers about your library by making them
+ write code
+- Inspire confidence through examples that work for
+ everyone, repeatably
+- Give readers an immediate sense of achievement
+- Show concrete examples, no abstractions
+- Provide the minimum necessary explanation
+- Avoid any distractions
\ No newline at end of file
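+
+As a first, minimal sketch (assuming you run it from the
+repository root and the `individual_stocks_5yr/` CSV data
+is present; adjust `path` to wherever your data lives):
+
+```python
+from src.model import Model
+
+model = Model()
+model.path = "individual_stocks_5yr/"  # assumed data location
+# lists the companies whose CSV files have valid date/close data
+print(model.generate_company_list())
+```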
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 0000000..cc785b7
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,14 @@
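+# Preview the site locally with `mkdocs serve`; CI publishes it with `mkdocs gh-deploy --force`.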
+site_name: Stock Tracker
+
+theme:
+ name: "material"
+
+plugins:
+ - mkdocstrings
+
+nav:
+  - Stock Tracker Docs: index.md
+  - Tutorials: tutorials.md
+  - How-To Guides: how-to-guides.md
+  - Reference: reference.md
+  - Explanation: explanation.md
diff --git a/models/forecasting-algorithms/ARIMA_forecast.py b/models/forecasting-algorithms/ARIMA_forecast.py
deleted file mode 100644
index 77618ee..0000000
--- a/models/forecasting-algorithms/ARIMA_forecast.py
+++ /dev/null
@@ -1,188 +0,0 @@
-import json, argparse
-from statsmodels.tsa.arima.model import ARIMA
-from collections import deque
-import pmdarima as pm
-import numpy as np
-import pandas as pd
-
-class ARIMAForecast:
-
- def __init__(self) -> None:
- with open("data.txt", "r") as file:
- raw_data = file.read()
- raw_data = raw_data.replace("'", "\"")
- data = json.loads(raw_data)
- self.processed_data = [{"date": data_point["date"], "close": float(data_point["close"])} for data_point in data]
- self.df = pd.DataFrame.from_dict(self.processed_data)
-
-
- def find_nearest_date(self, date_offset, start_date, direction):
- if direction == "backwards":
- req_date = pd.to_datetime(start_date) - pd.DateOffset(days = date_offset)
- elif direction == "forwards":
- req_date = pd.to_datetime(start_date) + pd.DateOffset(days = date_offset)
-
- queue = deque([req_date])
- visited_dates = set()
- last_available_date = self.processed_data[0]["date"]
- while queue:
- req_date = queue.popleft().strftime('%Y-%m-%d')
- if req_date in visited_dates:
- continue
- idx = self.df[self.df.date == req_date].index.values
- if idx.size > 0:
- if (direction == "forwards" and req_date >= start_date) or (direction == "backwards" and req_date < start_date):
- break
- visited_dates.add(req_date)
- queue.append(pd.to_datetime(req_date) - pd.DateOffset(days = 1))
- if req_date < last_available_date:
- queue.append(pd.to_datetime(req_date) + pd.DateOffset(days = 1))
-
- return start_date, req_date, idx[0]
-
- def slice_data(self, start_date, **kwargs):
- final_idx = kwargs.get("final_idx", None)
- start_idx = self.df[self.df.date == start_date].index.values[0]
- if final_idx:
- if final_idx > start_idx:
- return self.df.iloc[start_idx:final_idx]
- else:
- return self.df.iloc[start_idx:final_idx:-1]
- else: return self.df.iloc[:start_idx]
-
- def window_slice_optimisation(self, start_date):
- best_results = {"AIC": float("inf"), "combination":{"p": 0, "d": 0, "q": 0}}
- date_offset = 180
- curr_start_date = start_date
- _, curr_end_date, curr_end_idx = self.find_nearest_date(3*365, curr_start_date, "forwards")
- _, _, curr_goal_idx = self.find_nearest_date(date_offset, curr_end_date, "forwards")
- curr_end_sliced_data = self.slice_data(curr_start_date, final_idx = curr_end_idx)
- curr_goal_sliced_data = self.slice_data(curr_end_date, final_idx = curr_goal_idx)
-
- for p in range(0,4):
- for d in range(0, 3):
- for q in range(0, 4):
- arima_model_manual = ARIMA(curr_end_sliced_data.close, order=(p, d, q), enforce_invertibility=False, enforce_stationarity=False)
- model_manual = arima_model_manual.fit(method_kwargs={"warn_convergence": False})
- aic_value_manual = model_manual.aic
-
- if aic_value_manual < best_results["AIC"]:
- best_results["AIC"] = float(aic_value_manual)
- best_results["combination"]["p"] = p
- best_results["combination"]["d"] = d
- best_results["combination"]["q"] = q
- p_manual, d_manual, q_manual = list(best_results["combination"].values())
- arima_model_manual = ARIMA(curr_end_sliced_data.close, order=(p_manual, d_manual, q_manual), enforce_invertibility=False, enforce_stationarity=False)
- model_manual = arima_model_manual.fit(method_kwargs={"warn_convergence": False})
- try:
- forecast_length = len(curr_goal_sliced_data)
- forecasted_values_manual = pd.Series(model_manual.forecast(forecast_length), index=self.df.close[curr_end_idx:curr_goal_idx:-1].index)
- actual_values = self.df.close[curr_end_idx:curr_goal_idx:-1]
- Mean_Absolute_Percentage_Error_manual = np.mean(np.abs(forecasted_values_manual - actual_values)/np.abs(actual_values)) * 100
-
- model_auto = pm.auto_arima(curr_end_sliced_data.close, seasonal=True, m=12)
- (p_auto, d_auto, q_auto) = model_auto.get_params()["order"]
- arima_model_auto = ARIMA(curr_end_sliced_data.close, order=(p_auto, d_auto, q_auto), enforce_invertibility=False, enforce_stationarity=False)
- model_auto = arima_model_auto.fit(method_kwargs={"warn_convergence": False})
- forecast_length = len(curr_goal_sliced_data)
- forecasted_values_auto = pd.Series(model_auto.forecast(forecast_length), index=self.df.close[curr_end_idx:curr_goal_idx:-1].index)
- actual_values = self.df.close[curr_end_idx:curr_goal_idx:-1]
- Mean_Absolute_Percentage_Error_auto = np.mean(np.abs(forecasted_values_auto - actual_values)/np.abs(actual_values)) * 100
-
- return Mean_Absolute_Percentage_Error_manual, Mean_Absolute_Percentage_Error_auto
- except ValueError as e:
- return None, None
-
- def train_test_optimisation(self, backwards_duration):
- best_results_trained_manual = {"AIC": float("inf"), "combination":{"p": 0, "d": 0, "q": 0}}
- _, first_data_date, _ = self.find_nearest_date(backwards_duration, self.processed_data[0]["date"], "backwards")
- sliced_data = self.slice_data(first_data_date)
- train_value_index = len(sliced_data) * 8 // 10
- for p in range(0,4):
- for d in range(0, 3):
- for q in range(0, 4):
- arima_model_manual = ARIMA(sliced_data.close[:train_value_index], order=(p, d, q), enforce_invertibility=False, enforce_stationarity=False)
- model_manual = arima_model_manual.fit(method_kwargs={"warn_convergence": False})
- aic_value = model_manual.aic
- if aic_value < best_results_trained_manual["AIC"]:
- best_results_trained_manual["AIC"] = aic_value
- best_results_trained_manual["combination"]["p"] = p
- best_results_trained_manual["combination"]["d"] = d
- best_results_trained_manual["combination"]["q"] = q
-
- p_manual, d_manual, q_manual = list(best_results_trained_manual["combination"].values())
- arima_model_manual = ARIMA(sliced_data.close[:train_value_index], order=(p_manual, d_manual, q_manual), enforce_invertibility=False, enforce_stationarity=False)
- model_manual = arima_model_manual.fit(method_kwargs={"warn_convergence": False})
- forecasted_values_manual = pd.Series(model_manual.forecast(len(sliced_data) - train_value_index),
- index=sliced_data.close[train_value_index:].index)
- actual_values = sliced_data.close[train_value_index:]
-
- Mean_Absolute_Percentage_Error_manual = np.mean(np.abs(forecasted_values_manual - actual_values)/np.abs(actual_values)) * 100
-
- model_auto = pm.auto_arima(sliced_data.close, seasonal=True, m=12)
- (p_auto, d_auto, q_auto) = model_auto.get_params()["order"]
- arima_model_auto = ARIMA(sliced_data.close[:train_value_index], order=(p_auto, d_auto, q_auto), enforce_invertibility=False, enforce_stationarity=False)
- model_auto = arima_model_auto.fit(method_kwargs={"warn_convergence": False})
- forecasted_values_auto = pd.Series(model_auto.forecast(len(sliced_data) - train_value_index),
- index=sliced_data.close[train_value_index:].index)
- actual_values = self.df.close[train_value_index:]
- Mean_Absolute_Percentage_Error_auto = np.mean(np.abs(forecasted_values_auto - actual_values)/np.abs(actual_values)) * 100
-
- return Mean_Absolute_Percentage_Error_manual, Mean_Absolute_Percentage_Error_auto
-
- def generate_mape(self, start_date, slice_window, prediction_length, backwards_duration):
- dates = []
- dates.append(start_date)
- slice_window = eval(slice_window)
- prediction_length = eval(prediction_length)
- backwards_duration = eval(backwards_duration)
- manual_series, auto_series, mape_manual, mape_auto = None, None, None, None
- slice_final_date = self.find_nearest_date(slice_window, start_date, "forwards")[1]
- slice_window_manual_mape_list, slice_window_auto_mape_list = [], []
- manual_result, auto_result = self.window_slice_optimisation(dates[-1])
- slice_window_manual_mape_list.append(manual_result)
- slice_window_auto_mape_list.append(auto_result)
- while slice_final_date < self.processed_data[0]["date"]:
- start_date = self.find_nearest_date(prediction_length, dates[-1], "forwards")[1]
- dates.append(start_date)
- manual_result, auto_result = self.window_slice_optimisation(dates[-1])
- if manual_result and auto_result:
- slice_window_manual_mape_list.append(manual_result)
- slice_window_auto_mape_list.append(auto_result)
- slice_final_date = self.find_nearest_date(slice_window, start_date, "forwards")[1]
-
- if slice_window_manual_mape_list[0] != None and slice_window_auto_mape_list[0] != None:
- manual_series = pd.Series(slice_window_manual_mape_list)
- auto_series = pd.Series(slice_window_auto_mape_list)
- mape_manual = np.mean(manual_series)
- mape_auto = np.mean(auto_series)
- else:
- print("Not enough data provided, please provide more data, or reduce the slice window or prediction length")
-
- try:
- train_test_manual_mape, train_test_auto_mape = self.train_test_optimisation(backwards_duration)
- except ValueError as e:
- print("Data too short to split")
- train_test_manual_mape, train_test_auto_mape = None, None
- except IndexError as e:
- print("Need more data points")
- train_test_manual_mape, train_test_auto_mape = None, None
-
-
- return(f"""Results:\n
- sliced window manual mape: {mape_manual},\n
- sliced window auto mape: {mape_auto},\n
- train test manual mape: {train_test_manual_mape},\n
- train test auto mape: {train_test_auto_mape}""")
-
-if __name__ == "__main__":
- af = ARIMAForecast()
-
- parser = argparse.ArgumentParser(description='Finding Mean Absolute Percentage Error using two different ARIMA methods')
- parser.add_argument('start_date', help='Provide date to start the slice, ensure date has data')
- parser.add_argument('slice_window', help='The window size of the slice used for analysis, in days')
- parser.add_argument('prediction_length', help='The number of data points to be predicted')
- parser.add_argument('backwards_duration', help='How far back would the first data be, in days')
- args = parser.parse_args()
-
- af.generate_mape(args.start_date, args.slice_window, args.prediction_length, args.backwards_duration)
diff --git a/models/forecasting-algorithms/monte_carlo_forecast.py b/models/forecasting-algorithms/monte_carlo_forecast.py
deleted file mode 100644
index e28970c..0000000
--- a/models/forecasting-algorithms/monte_carlo_forecast.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import json, argparse
-import numpy as np
-import pandas as pd
-import matplotlib.pyplot as plt
-from scipy.stats import norm
-
-
-class MonteCarloForecast:
-
- def __init__(self) -> None:
- with open("data.txt", "r") as file:
- raw_data = file.read()
- raw_data = raw_data.replace("'", "\"")
- data = json.loads(raw_data)
- self.processed_data = [{"date": data_point["date"], "close": float(data_point["close"])} for data_point in data]
- self.df = pd.DataFrame.from_dict(self.processed_data)
-
- def generate_mape(self, days_to_test, days_to_predict, number_of_simulations, return_mode):
- self.df.date = pd.to_datetime(mc.df.date)
- daily_return = np.log(1 + self.df.close.pct_change())
- average_daily_return = daily_return.mean()
- variance = daily_return.var()
- drift = average_daily_return - (variance/2)
- standard_deviation = daily_return.std()
- days_to_test = eval(days_to_test)
- days_to_predict = eval(days_to_predict)
- number_of_simulations = eval(number_of_simulations)
- predictions = np.zeros(days_to_test + days_to_predict)
- predictions[0] = self.df.close[days_to_test + days_to_predict]
- pred_collection = np.ndarray(shape=(number_of_simulations, days_to_test + days_to_predict))
- curr_mean_absolute_error = 0
- differences = np.array([])
- for sim_idx in range(0,number_of_simulations):
- for prediction_idx in range(1, days_to_test + days_to_predict):
- random_value = standard_deviation * norm.ppf(np.random.rand())
- predictions[prediction_idx] = predictions[prediction_idx - 1] * np.exp(drift + random_value)
- pred_collection[sim_idx] = predictions
- actual_values = self.df.close[:days_to_test]
- predicted_values = predictions[:days_to_test]
- curr_mean_absolute_error += np.mean(np.abs(predicted_values - actual_values) / np.abs(actual_values))
- if return_mode != "MAPE only":
- difference_array = np.subtract(predicted_values, actual_values)
- difference_value = np.sum(np.abs(difference_array))
- differences = np.append(differences, difference_value)
-
- if return_mode != "MAPE only":
- best_fit = np.argmin(differences)
- future_prices = pred_collection[best_fit][days_to_predict * -1:]
-
-
- Mean_Absolute_Percentage_Error = curr_mean_absolute_error / number_of_simulations * 100
- if return_mode == "forecast only":
- return future_prices
- elif return_mode == "both":
- return Mean_Absolute_Percentage_Error, future_prices
- elif return_mode == "MAPE only":
- return Mean_Absolute_Percentage_Error
-
-
-if __name__ == "__main__":
- mc = MonteCarloForecast()
-
- parser = argparse.ArgumentParser(description='Finding Mean Absolute Percentage Error using Monte Carlo Simulation')
- parser.add_argument('days_to_test', help='Provide the number of days to test')
- parser.add_argument('days_to_predict', help='Provide the number of days to predict')
- parser.add_argument('number_of_simulations', help='Provide the number of simulations to run')
- parser.add_argument('return_mode', help='Output to be returned, choose one of the modes: "forecast only", "both", or "MAPE only"')
- args = parser.parse_args()
-
- mc.generate_mape(args.days_to_test, args.days_to_predict, args.number_of_simulations, args.return_mode)
diff --git a/models/forecasting-algorithms/moving_average.py b/models/forecasting-algorithms/moving_average.py
deleted file mode 100644
index e5b0428..0000000
--- a/models/forecasting-algorithms/moving_average.py
+++ /dev/null
@@ -1,106 +0,0 @@
-from parameters import mongodb_connection
-from pymongo import MongoClient
-import pandas as pd
-import numpy as np
-import argparse
-
-
-class MovingAverage:
-
- def __init__(self, company) -> None:
- client = MongoClient(mongodb_connection)
- database = client.StockTracker
- collection = database.Companies
- projection = {"_id": 1, "price": 1}
- cursor = collection.find({"_id": company}, projection)
- for doc in cursor:
- all_points = doc["price"]
- self.dataset = [float(closing_price["close"]) for closing_price in all_points]
- self.window_size = [window for window in range(10, 1000)]
- self.smoothing_factor = [smoothing_factor / 10 for smoothing_factor in range(1, 10)]
- self.sma_results = {}
- self.sma_predictions = []
- self.ema_results = {}
- self.ema_predictions = []
- self.best_results = {"algo": None, "MAPE": float("inf"), "window": None, "smoothing_factor": None}
- self.mape = float("inf")
-
- def simple_moving_average(self, window):
- dataset_length = len(self.dataset)
- start, end = 0, window
- curr_sum = sum(self.dataset[:end])
- actual_dataset, forecasted_dataset = [], []
- actual_data = self.dataset[end]
- actual_dataset.append(actual_data)
- forecasted_data = curr_sum / window
- forecasted_dataset.append(forecasted_data)
- for end in range(window + 1, dataset_length):
- curr_sum = curr_sum + self.dataset[end - 1] - self.dataset[start]
- start += 1
- actual_data = self.dataset[end]
- actual_dataset.append(actual_data)
- forecasted_data = curr_sum / window
- forecasted_dataset.append(forecasted_data)
- actual_dataset = pd.Series(actual_dataset)
- forecasted_dataset = pd.Series(forecasted_dataset)
- curr_mape = np.mean(np.abs(forecasted_dataset - actual_dataset)/np.abs(actual_dataset)) * 100
- self.sma_results[window] = {
- "MAPE": curr_mape
- }
- if curr_mape < self.best_results["MAPE"]:
- self.best_results["algo"] = "sma"
- self.best_results["MAPE"] = curr_mape
- self.best_results["window"] = window
- self.best_results["smoothing_factor"] = None
- return (curr_sum + self.dataset[end] - self.dataset[start]) / window
-
- def exponential_moving_average(self, smoothing_factor):
- dataset_length = len(self.dataset)
- total_percentage_error = 0
- first_data = self.dataset[0]
- second_data = self.dataset[1]
- actual_dataset, forecasted_dataset = [], []
- actual_dataset.append(second_data)
- forecasted_dataset.append(first_data)
- curr_error = second_data - first_data
- total_percentage_error += (abs(curr_error) / second_data) * 100
- for end in range(2, dataset_length):
- forecasted_value = smoothing_factor * second_data + (1 - smoothing_factor) * first_data
- actual_data = self.dataset[end]
- actual_dataset.append(actual_data)
- forecasted_dataset.append(forecasted_value)
- curr_error = forecasted_value - actual_data
- total_percentage_error += (abs(curr_error) / actual_data) * 100
- first_data = forecasted_value
- second_data = actual_data
- actual_dataset = pd.Series(actual_dataset)
- forecasted_dataset = pd.Series(forecasted_dataset)
- curr_mape = np.mean(np.abs(forecasted_dataset - actual_dataset)/np.abs(actual_dataset)) * 100
- self.ema_results[smoothing_factor] = {
- "MAPE": curr_mape
- }
- if curr_mape < self.best_results["MAPE"]:
- self.best_results["algo"] = "ema"
- self.best_results["MAPE"] = curr_mape
- self.best_results["window"] = None
- self.best_results["smoothing_factor"] = smoothing_factor
- return smoothing_factor * second_data + (1 - smoothing_factor) * first_data
-
- def run_forecast(self):
- for window in self.window_size:
- forecasted_value = self.simple_moving_average(window)
- self.sma_predictions.append(forecasted_value)
-
- for smoothing_factor in self.smoothing_factor:
- forecasted_value = self.exponential_moving_average(smoothing_factor)
- self.ema_predictions.append(forecasted_value)
-
- return self.sma_results, self.sma_predictions, self.ema_results, self.ema_predictions
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description='Finding Mean Absolute Percentage Error using two different moving averages')
- parser.add_argument('company_name', help='Provide company name to analyse')
- args = parser.parse_args()
- ma = MovingAverage(args.company_name)
- ma.run_forecast()
diff --git a/models/forecasting-algorithms/proph_forecast.py b/models/forecasting-algorithms/proph_forecast.py
deleted file mode 100644
index e2b09a3..0000000
--- a/models/forecasting-algorithms/proph_forecast.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import pandas as pd
-from prophet import Prophet
-from parameters import mongodb_connection
-from pymongo import MongoClient
-import numpy as np
-import argparse
-
-class ProphForecast:
-
- def __init__(self) -> None:
- client = MongoClient(mongodb_connection)
- database = client.StockTracker
- collection = database.Companies
- projection = {"_id": 1, "price": 1}
- cursor = collection.find({"_id": "AAPL"}, projection)
- for doc in cursor:
- self.all_points = doc["price"]
-
- def generate_mape(self, days_to_test, days_to_predict):
- days_to_test = eval(days_to_test)
- days_to_predict = eval(days_to_predict)
- df = pd.DataFrame.from_dict(self.all_points[:days_to_test][::-1])
- new_headers = {"date": "ds",
- "close": "y"}
- df.rename(columns=new_headers,
- inplace=True)
- m = Prophet()
- m.fit(df)
- future = m.make_future_dataframe(periods=days_to_predict)
- forecast = m.predict(future)
- actual_prices = pd.Series([float(price) for price in df["y"].values.tolist()])
- forecasted_prices = pd.Series([price[0] for price in forecast[["yhat"]].values.tolist()[:-1]])
- Mean_Absolute_Percentage_Error = np.mean(np.abs(forecasted_prices - actual_prices)/np.abs(actual_prices)) * 100
- return Mean_Absolute_Percentage_Error
-
-
-if __name__ == "__main__":
- pf = ProphForecast()
-
- parser = argparse.ArgumentParser(description='Finding Mean Absolute Percentage Error using Prophet Forecast')
- parser.add_argument('days_to_test', help='Provide the number of days to test')
- parser.add_argument('days_to_predict', help='Provide the number of days to predict')
- args = parser.parse_args()
-
- pf.generate_mape(args.days_to_test, args.days_to_predict)
diff --git a/requirements.txt b/requirements.txt
index a2b8b8a..d9b273b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,36 +1,47 @@
-bokeh==3.1.1
+appdirs==1.4.4
+asgiref==3.7.2
+beautifulsoup4==4.12.3
certifi==2023.5.7
charset-normalizer==3.1.0
-contourpy==1.0.7
-cycler==0.11.0
+coverage==7.4.3
+Django==4.1.12
+django-admin-volt==1.0.10
exceptiongroup==1.1.1
-fonttools==4.42.1
+frozendict==2.4.0
+gunicorn==21.2.0
+html5lib==1.1
idna==3.4
iniconfig==2.0.0
-Jinja2==3.1.2
-kiwisolver==1.4.5
-MarkupSafe==2.1.2
-matplotlib==3.7.2
-numpy==1.24.3
+install==1.3.5
+jmespath==1.0.1
+lxml==4.9.3
+multitasking==0.0.11
+mypy==1.8.0
+mypy-extensions==1.0.0
+numpy==1.26.4
packaging==23.1
-pandas==2.0.1
-Pillow==9.5.0
+pandas==2.2.1
+pandas-stubs==2.2.0.240218
+peewee==3.17.1
pluggy==1.0.0
-pyparsing==3.0.9
-PyQt5==5.15.9
-PyQt5-Qt5==5.15.2
-PyQt5-sip==12.12.1
-pyqtgraph==0.13.3
-PyQtWebEngine==5.15.6
-PyQtWebEngine-Qt5==5.15.2
pytest==7.3.1
+pytest-cov==4.1.0
python-dateutil==2.8.2
+python-dotenv==1.0.0
pytz==2023.3
-PyYAML==6.0
requests==2.31.0
+ruff==0.3.0
six==1.16.0
+soupsieve==2.5
+sqlparse==0.4.4
tomli==2.0.1
-tornado==6.3.2
+types-openpyxl==3.1.0.20240301
+types-pytz==2024.1.0.20240203
+types-requests==2.31.0.20240218
+typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.0.2
-xyzservices==2023.5.0
+utils==1.0.1
+webencodings==0.5.1
+whitenoise==6.5.0
+yfinance==0.2.37
diff --git a/models/forecasting-algorithms/__init__.py b/src/__init__.py
similarity index 100%
rename from models/forecasting-algorithms/__init__.py
rename to src/__init__.py
diff --git a/src/company_selection.py b/src/company_selection.py
deleted file mode 100644
index 433c880..0000000
--- a/src/company_selection.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from PyQt5.QtWidgets import QWidget, QComboBox
-from graph_plotting import GraphPlotting
-from live_price_display import LivePriceDisplay
-from news_display import NewsDisplay
-from models.model import Model
-
-class CompanySelection(QWidget):
- """Displays and handles company selection menu"""
-
- def __init__(self, parent=None):
- super().__init__()
- self.parent = parent
- self.combo_box = QComboBox()
- self.combo_box.setMaximumSize(700, 50)
- self.placeholder_text: str = "Select a company"
- self.combo_box.addItem(self.placeholder_text)
-
- models = Model()
- # Populate the combo box with available company names from csv files
- for company in models.company_list:
- self.combo_box.addItem(company)
-
- # Create instances of widgets for graph plotting, live price display, and news display
- self.graph_plotting_widget: GraphPlotting = GraphPlotting()
- self.live_price_display_widget: LivePriceDisplay = LivePriceDisplay()
- self.news_display_widget: NewsDisplay = NewsDisplay()
- self.combo_box.currentTextChanged.connect(lambda: self.company_selected())
-
- # Connect the combo box's currentTextChanged signal to update the selected company in the widgets
- def company_selected(self) -> None:
- """
- If the selected company is not equal to the placeholder text, it will obtain information of the selected company
- , if it is then it will not do anything.
-
- :return: None
- """
- if self.combo_box.currentText() != self.placeholder_text:
- self.graph_plotting_widget.plot_selected_graph(self.combo_box.currentText())
- self.live_price_display_widget.display_final_price(self.combo_box.currentText())
- self.news_display_widget.display_company_news(self.combo_box.currentText())
diff --git a/src/graph_plotting.py b/src/graph_plotting.py
deleted file mode 100644
index c7744b6..0000000
--- a/src/graph_plotting.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from PyQt5.QtWidgets import QWidget, QVBoxLayout
-import pyqtgraph as pg
-import os, sys
-sys.path.append(os.getcwd())
-from processing.data_processing import DataProcessing
-
-import numpy as np
-import pandas as pd
-
-PLOT_HEIGHT = 700
-PLOT_WIDTH = 1100
-
-
-class GraphPlotting(QWidget):
- """Plots graph from available data"""
-
- def __init__(self, parent=None):
- super().__init__()
- self.web_view: str = None
- self.parent = parent
- self.vbox_left = QVBoxLayout()
- self.placeholder_graph()
- self.data_processor = DataProcessing()
-
- def placeholder_graph(self) -> None:
- """
- Creates a placeholder graph using Bokeh.
-
- :return: None
- """
- self.plot = pg.PlotWidget()
- self.x = [1, 2, 3, 4, 5]
- self.y = [5, 4, 3, 2, 1]
-
- self.dataline = self.plot.plot(self.x, self.y)
- self.vbox_left.addWidget(self.plot)
-
- return self.plot
-
- def plot_selected_graph(self, company_name: str) -> None:
- """
- Plots selected company.
-
- :param company_name: (str) The name of the selected company
- :return: None
- """
- data: dict = self.data_processor.companies_data[company_name].to_dict()
- self.x = data["date"]
- self.y = data["close"]
- # print(self.x)
- # print(self.y)
- # print(len(self.x)==len(self.y))
- # self.dataline.setData(self.x, self.y)
\ No newline at end of file
diff --git a/src/live_price_display.py b/src/live_price_display.py
index bddab02..e113efb 100644
--- a/src/live_price_display.py
+++ b/src/live_price_display.py
@@ -1,67 +1,77 @@
-from PyQt5.QtWidgets import QWidget, QLabel
-from PyQt5.QtGui import QFont
-from PyQt5 import QtCore
-import requests
-import yfinance as yf
+"""This module returns the most recent price of the selected company."""
+from typing import Union, Any
+import requests
+import yfinance as yf # type: ignore[import-untyped] # pylint: disable=E0401
+import pandas as pd
-from src.data_processing import DataProcessing
-from src.parameters import ALPHA_VANTAGE_API_KEY, mongodb_connection
+from src.parameters import ALPHA_VANTAGE_API_KEY
ALPHA_VANTAGE_ENDPOINT = "https://www.alphavantage.co/query"
class LivePriceDisplay:
- """Shows the live prices"""
-
- def __init__(self, parent=None):
- super().__init__()
- self.price = None
- self.parent = parent
+ """
+ Returns the most recent price of the selected company.
+ """
-
- def display_final_price_av(self, company_name: str) -> None:
+ @staticmethod
+ def display_final_price_av(company_name: str) -> Union[str, dict, Any]:
"""
- Attempts to display the final price of the selected company.
+        Returns the most recent closing price using Alpha Vantage.
- :param company_name: (str) The name of the name to display the final price for.
- :return:
- """
+ Args:
+ company_name: The ticker symbol of the company
+ Returns:
+            The most recent closing price as a string, or the raw error response
+ """
try:
# Gets last available price by default
price_params: dict = {
"apikey": ALPHA_VANTAGE_API_KEY,
"function": "TIME_SERIES_DAILY",
- "symbol": company_name
+ "symbol": company_name,
}
- price_response: requests.models.Response = requests.get(ALPHA_VANTAGE_ENDPOINT, params=price_params)
+ price_response: requests.models.Response = requests.get(
+ ALPHA_VANTAGE_ENDPOINT, params=price_params
+ )
if price_response.ok:
response_data: dict = price_response.json()
if "Time Series (Daily)" in response_data:
price_list: dict = response_data["Time Series (Daily)"]
most_recent_day: str = next(iter(price_list))
return price_list[most_recent_day]["4. close"]
+ return response_data
+ return price_response
- except (requests.RequestException, KeyError, IndexError):
- raise
+ except (
+ requests.exceptions.MissingSchema,
+ requests.RequestException,
+ KeyError,
+ IndexError,
+ ):
+ return "Error fetching price"
-
- def display_final_price_yf(self, company_name: str) -> None:
+ @staticmethod
+ def display_final_price_yf(company_name: str) -> Union[float, str]:
"""
- Attempts to display the final price of the selected company.
+        Returns the most recent closing price using Yahoo Finance.
- :param company_name: (str) The name of the name to display the final price for.
- :return:
- """
+ Args:
+ company_name: The ticker symbol of the company
+ Returns:
+            The most recent closing price as a float, or an error message string
+ """
try:
- df = yf.download(company_name)
- price = df.iloc[-1]["Close"]
+ df: pd.DataFrame = yf.download(company_name) # pylint: disable=C0103
+ price: float = df.iloc[-1]["Close"]
return round(price, 5)
except IndexError:
return "Error fetching price"
+
# from pymongo import MongoClient
# client = MongoClient(mongodb_connection)
# database = client.StockTracker
@@ -80,6 +90,6 @@ def display_final_price_yf(self, company_name: str) -> None:
# "outputsize": "full"
# }
# a = requests.get(ALPHA_VANTAGE_ENDPOINT, params=price_params).json()
-# company = {"_id": symbol, "price":[{"date": b, "close": a["Time Series (Daily)"][b]["4. close"]} for b in a["Time Series (Daily)"]]}
+# company = {"_id": symbol, "price":[{"date": b, "close": a["Time Series (Daily)"][b]["4. close"]} for b in a["Time Series (Daily)"]]} # pylint: disable=C0301
# result = collection.insert_one(company)
# print(f"Inserted document ID: {result.inserted_id}")
diff --git a/src/main.py b/src/main.py
deleted file mode 100644
index 80fc450..0000000
--- a/src/main.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from PyQt5.QtWidgets import QMainWindow, QWidget, QVBoxLayout, QApplication, QHBoxLayout
-import sys
-
-from company_selection import CompanySelection
-
-
-class MainWindow(QMainWindow):
- """Main GUI window"""
-
- def __init__(self, parent=None):
- super().__init__()
- self.parent = parent
- self.setWindowTitle("Stock Viewer")
- self.setGeometry(100, 100, 800, 600)
- self.showMaximized()
-
- central_widget = QWidget()
- self.setCentralWidget(central_widget)
-
- vbox_right = QVBoxLayout()
-
- cs: CompanySelection = CompanySelection()
-
- vbox_right.addWidget(cs.combo_box)
- vbox_right.addWidget(cs.live_price_display_widget.share_price_label)
- vbox_right.addWidget(cs.news_display_widget.company_news_section)
-
- window_layout = QHBoxLayout()
-
- window_layout.addLayout(cs.graph_plotting_widget.vbox_left)
- window_layout.addLayout(vbox_right)
-
- central_widget.setLayout(window_layout)
-
-
-if __name__ == '__main__':
- app = QApplication(sys.argv)
- window = MainWindow()
- window.show()
- app.exec_()
diff --git a/src/model.py b/src/model.py
index 27a6660..0b696a7 100644
--- a/src/model.py
+++ b/src/model.py
@@ -1,16 +1,17 @@
-# import csv
-import os, sys
+"""This module reads csv data files and processes them into the required format"""
+
+from typing import Union
+import os
import pandas as pd
-sys.path.append(os.getcwd())
class Model:
"""Processes data and returns data in required format"""
- def __init__(self, parent=None):
- super().__init__()
- self.company_list = None
- self.parent = parent
- self.path: str = "/home/wleong/Personal_project/StockTracker/individual_stocks_5yr/" # This is relative from where you run the script, not where this script is
+
+ def __init__(self) -> None:
+        self.path: str = "../individual_stocks_5yr/"  # relative to the working directory
def generate_company_list(self) -> list:
"""
@@ -19,8 +20,8 @@ def generate_company_list(self) -> list:
:return: (list) A list of companies.
"""
company_list: list = []
- expected_headers = ["date", "close"]
- for (dirpath, dirnames, filenames) in os.walk(self.path):
+ expected_headers: list = ["date", "close"]
+ for _, _, filenames in os.walk(self.path):
for file in filenames:
if file.endswith(".csv"):
if self.check_headers_and_data(file, expected_headers):
@@ -36,15 +37,18 @@ def check_headers_and_data(self, file, expected_headers) -> bool:
:param expected_headers: (list) The list of headers required
:return: (bool) The results of the file
"""
- has_expected_headers = False
- has_data = False
+ has_expected_headers: bool = False
+ has_data: bool = False
try:
- parse_dates = ["date"]
- df = pd.read_csv(self.path + file, skip_blank_lines=True,
- dtype={"date": "string", "close": "float64"},
- parse_dates=parse_dates)
- headers = set(df.columns.to_list())
- expected_headers_copy = expected_headers[:]
+ parse_dates: list = ["date"]
+ df: pd.DataFrame = pd.read_csv( # pylint: disable=C0103
+ self.path + file,
+ skip_blank_lines=True,
+ dtype={"date": "string", "close": "float64"},
+ parse_dates=parse_dates,
+ )
+ headers: set = set(df.columns.to_list())
+ expected_headers_copy: list = expected_headers[:]
# Two conditions the while loop should break:
# 1. No more headers in expected_headers_copy (all are met)
# 2. At least one header is not met
@@ -60,62 +64,43 @@ def check_headers_and_data(self, file, expected_headers) -> bool:
except pd.errors.EmptyDataError:
return False
try:
- df.iloc[[0]]
+ df.iloc[[0]] # pylint: disable=E1101,W0104
has_data = True
except (ValueError, IndexError, NameError):
return False
return has_expected_headers and has_data
- # no data (done)
- # data on top and headers at the bottom (treated same way as wrong headers)
- # only headers (done)
- # only data (done)
- # when is it time to use numpy or pandas to filter:
- # empty rows of data (solved using pandas dropna)
- # Nan data(solved, same as whole row empty, whole row will be removed)
- # one empty data (solved, same as whole row empty, whole row will be removed)
- # blank rows before headers and data, might need to remove before check headers and data? can be within this function (solved using skip_blank_lines from pandas)
-
- # string data in float or float data in string (not tested here, should be done in pandas read_csv, which I have mocked)
- # repeated headers in csv (should be checked in process_data) (done)
- # repeated headers in expected headers (should be checked in process_data)(done)
- # what happens if expected headers is empty(does it return whole dataframe?)(done, nothing)
- def process_data(self, expected_headers) -> pd.DataFrame:
+ def process_data(self, expected_headers: list) -> Union[pd.DataFrame, str]:
"""
Slices the data as required.
:return: (DataFrame) A DataFrame containing required information of all companies.
"""
- companies_list = self.generate_company_list()
+ companies_list: list = self.generate_company_list()
companies_data: dict = {}
try:
for company in companies_list:
csv_file: str = f"{self.path}{company}_data.csv"
- parse_dates = ["date"]
- df: pd.DataFrame = pd.read_csv(
- csv_file, header=0, usecols=expected_headers, skip_blank_lines=True,
+ parse_dates: list = ["date"]
+ df: pd.DataFrame = pd.read_csv( # pylint: disable=C0103
+ csv_file,
+ header=0,
+ usecols=expected_headers,
+ skip_blank_lines=True,
dtype={"date": "string", "close": "float64"},
- parse_dates=parse_dates)
+ parse_dates=parse_dates,
+ )
df.dropna(how="all", subset="date", inplace=True)
- df.interpolate(method='linear', inplace=True)
+ df.interpolate(method="linear", inplace=True) # pylint: disable=E1101
df["date"] = pd.to_datetime(df["date"])
df["date"] = df["date"].dt.strftime("%Y-%m-%d")
df["close"] = pd.to_numeric(df["close"])
modified_data: dict = df.to_dict("list")
companies_data[company] = modified_data
- all_companies_data = pd.DataFrame(companies_data)
+ all_companies_data: pd.DataFrame = pd.DataFrame(companies_data)
return all_companies_data
except (ValueError, TypeError, KeyError):
- return "Please ensure each header is unique, data is correct, or expected_headers and process_data are configured correctly"
-# a = Model()
-# a.path = "tests/sample_data/"
-# filename = "CompanyH_data.csv"
-# expected_headers = ["date", "closing"]
-# b = a.check_headers_and_data(filename, expected_headers)
-# print(b)
-
-# hard to make multiple csv files and push
-# use pandas to read csv and save it as df?
-# then don't need to make csv, can make pandas df for tests
-# maybe only make csv for generate_company_list?
-# df for check_headers_and_data?
\ No newline at end of file
+ return (
+ "Please ensure each header is unique, data is correct, "
+ "or expected_headers and process_data are configured correctly"
+ )
diff --git a/src/news_display.py b/src/news_display.py
index c9cc933..8f3c6fe 100644
--- a/src/news_display.py
+++ b/src/news_display.py
@@ -1,3 +1,5 @@
+"""This module displays the most recent news of the selected company if available"""
+
import requests
from src.parameters import NEWS_API_KEY
@@ -6,37 +8,52 @@
class NewsDisplay:
- """Obtains and handles recent news"""
-
- def __init__(self, parent=None):
- super().__init__()
- self.parent = parent
+ """
+ Returns the most recent news of the selected company, if any.
+ """
- def _collect_news(self, company_name: str) -> list:
+ @staticmethod
+ def _collect_news(company_name: str) -> list:
"""
- Collect recent news articles related to a specific company and format them.
+ Collect recent news articles related to the selected company and format them.
+ Args:
+ company_name: The ticker symbol of the company
- :param company_name: (str) The name of the company to collect news for.
- :return: (list) A list of formatted news headlines with respective URLs.
+ Returns:
+            five_articles: The five most recent articles
"""
- news_params: dict = {
- "apiKey": NEWS_API_KEY,
- "qInTitle": company_name
- }
+ news_params: dict = {"apiKey": NEWS_API_KEY, "qInTitle": company_name}
- news_response: requests.models.Response = requests.get(NEWS_ENDPOINT, params=news_params)
+ news_response: requests.models.Response = requests.get(
+ NEWS_ENDPOINT, params=news_params
+ )
articles: list = news_response.json()["articles"]
five_articles: list = articles[:5]
return five_articles
-
- def format_news_pyqt(self, company_name):
- # Generate formatted headlines with clickable URLs
- news = self._collect_news(company_name)
+
+ def format_news_pyqt(self, company_name: str) -> list:
+ """
+        Formats the collected news for the PyQt5 UI.
+ Args:
+ company_name: The ticker symbol of the company
+
+ Returns:
+            A list of headline strings, each with its article URL
+ """
+ news: list = self._collect_news(company_name)
return [
f"{article['title']}: ''{article['url']}''"
for article in news
]
-
- def format_news_django(self, company_name):
- news = self._collect_news(company_name)
- return [{"title":article["title"], "url":article["url"]} for article in news]
+
+ def format_news_django(self, company_name: str) -> list:
+ """
+        Formats the collected news for the Django UI.
+ Args:
+ company_name: The ticker symbol of the company
+
+ Returns:
+            A list of dicts holding the title and URL of the five most recent articles
+ """
+ news: list = self._collect_news(company_name)
+ return [{"title": article["title"], "url": article["url"]} for article in news]
diff --git a/tests/sample_data/CompanyD_data.txt b/tests/__init__.py
similarity index 100%
rename from tests/sample_data/CompanyD_data.txt
rename to tests/__init__.py
diff --git a/tests/sample_data/CompanyA_data.csv b/tests/sample_data/CompanyA_data.csv
deleted file mode 100644
index 95b3797..0000000
--- a/tests/sample_data/CompanyA_data.csv
+++ /dev/null
@@ -1,2 +0,0 @@
-date,open,volume,close
-12/08/2021,12,2000,13
diff --git a/tests/sample_data/CompanyB_data.csv b/tests/sample_data/CompanyB_data.csv
deleted file mode 100644
index 505f3ff..0000000
--- a/tests/sample_data/CompanyB_data.csv
+++ /dev/null
@@ -1,2 +0,0 @@
-date,,,close
-12/08/2021,,,13
diff --git a/tests/sample_data/CompanyC_data.csv b/tests/sample_data/CompanyC_data.csv
deleted file mode 100644
index a24f91f..0000000
--- a/tests/sample_data/CompanyC_data.csv
+++ /dev/null
@@ -1,2 +0,0 @@
-date,close
-12/08/2021,13
diff --git a/tests/sample_data/CompanyD_data.csv b/tests/sample_data/CompanyD_data.csv
deleted file mode 100644
index 19daa0b..0000000
--- a/tests/sample_data/CompanyD_data.csv
+++ /dev/null
@@ -1,2 +0,0 @@
-date,open,high,low ,close
-12/08/2021,12,14,11,13
diff --git a/tests/sample_data/CompanyE_data.xml b/tests/sample_data/CompanyE_data.xml
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/sample_data/CompanyF_data.xlsx b/tests/sample_data/CompanyF_data.xlsx
deleted file mode 100644
index ab28378..0000000
Binary files a/tests/sample_data/CompanyF_data.xlsx and /dev/null differ
diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py
deleted file mode 100644
index 89576d0..0000000
--- a/tests/test_data_processing.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import unittest
-import sys
-
-pkg_dir = "../src"
-sys.path.append(pkg_dir)
-
-from src import data_processing, model
-
-
-class MyDataProcessingTestCase(unittest.TestCase):
-
- def test_process_data(self):
- """
- Test the data processing functionality.
-
- This test case checks if the data processing works correctly by comparing
- the number of data points generated for each company.
-
- :return: None
- """
-
- data_processor = data_processing.DataProcessing()
- model_test = model.Model()
- model_test.path = "sample_data/"
- company_list = model_test.generate_company_list()
-
- data_processor.companies_list = company_list
- data_processor.path = "sample_data/"
- test_data = data_processor.process_data()
-
- companies_passed = []
- for company in company_list:
- if len(test_data[company]) == 2:
- companies_passed.append(company)
- self.assertEqual(company_list, companies_passed)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/test_live_price_display.py b/tests/test_live_price_display.py
index a222baa..9ed3c46 100644
--- a/tests/test_live_price_display.py
+++ b/tests/test_live_price_display.py
@@ -1,4 +1,4 @@
-from unittest.mock import patch
+from unittest.mock import patch, Mock
from src.live_price_display import LivePriceDisplay
import pytest
@@ -13,12 +13,30 @@ def test_mock_display_final_price_av(mock_get, test_live_price_display):
assert isinstance(results, float)
assert results == 123.45
+@patch("requests.get")
+def test_mock_display_final_price_av_price_response_not_ok(mock_get, test_live_price_display):
+ mock_get.return_value = ""
+ mock_response = Mock()
+ mock_response.ok = False
+ mock_response.response = mock_get.return_value
+ results = mock_response.response
+ assert mock_response.ok == False
+ assert isinstance(results, str)
+ assert results == ""
+
@patch("src.live_price_display.LivePriceDisplay.display_final_price_av")
def test_display_final_price_av_exception(mock_get, test_live_price_display):
mock_get.return_value = "Error fetching price"
result = test_live_price_display.display_final_price_av("Unknown_Company")
assert result == "Error fetching price"
+@patch("src.live_price_display.LivePriceDisplay.display_final_price_av")
+def test_display_final_price_av_invalid_api_call(mock_get, test_live_price_display):
+ mock_get.return_value = {"Error Message":
+ "Invalid API call. Please retry or visit the documentation (https://www.alphavantage.co/documentation/) for TIME_SERIES_DAILY."}
+ result = test_live_price_display.display_final_price_av("Unknown_Company")
+ assert "Error Message" in result
+
@patch("src.live_price_display.LivePriceDisplay.display_final_price_yf")
def test_mock_display_final_price_yf(mock_get, test_live_price_display):
mock_get.return_value = 291.97